Explorar o código

feat: 添加IP白名单/黑名单功能 - 支持单IP、CIDR、IP范围

Your Name hai 2 semanas
pai
achega
d92cf341b8
Modificáronse 5 ficheiros con 467 adicións e 0 borrados
  1. 102 0
      database/db.go
  2. 107 0
      ftp/server.go
  3. 71 0
      static/index.html
  4. 105 0
      static/js/app.js
  5. 82 0
      web/server.go

+ 102 - 0
database/db.go

@@ -53,6 +53,16 @@ type OnlineUser struct {
 	CurrentDir   string    `json:"current_dir"`
 }
 
+// IPAccessRule IP访问规则
+type IPAccessRule struct {
+	ID        int64  `json:"id"`
+	IP        string `json:"ip"`   // 支持单IP (192.168.1.1)、CIDR (192.168.1.0/24)、IP范围 (192.168.1.1-192.168.1.100)
+	Type      string `json:"type"` // "whitelist" 或 "blacklist"
+	Note      string `json:"note"` // 备注说明
+	Enabled   bool   `json:"enabled"`
+	CreatedAt string `json:"created_at"`
+}
+
 // Open 打开数据库
 func Open(dbPath string) (*DB, error) {
 	dir := filepath.Dir(dbPath)
@@ -118,6 +128,17 @@ func (db *DB) initTables() error {
 		total_download_bytes INTEGER DEFAULT 0,
 		updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
 	);
+	CREATE TABLE IF NOT EXISTS ip_access_rules (
+		id INTEGER PRIMARY KEY AUTOINCREMENT,
+		ip TEXT NOT NULL,
+		type TEXT NOT NULL DEFAULT 'blacklist',
+		note TEXT DEFAULT '',
+		enabled INTEGER DEFAULT 1,
+		created_at DATETIME DEFAULT CURRENT_TIMESTAMP
+	);
+
+	CREATE INDEX IF NOT EXISTS idx_ip_rules_type ON ip_access_rules(type);
+	CREATE INDEX IF NOT EXISTS idx_ip_rules_enabled ON ip_access_rules(enabled);
 	`
 
 	_, err := db.Exec(schema)
@@ -328,3 +349,84 @@ func (db *DB) CleanOldLogs(days int) (int64, error) {
 	}
 	return result.RowsAffected()
 }
+
+// --- IP访问规则 CRUD ---
+
+// CreateIPRule 创建IP规则
+func (db *DB) CreateIPRule(rule *IPAccessRule) error {
+	now := time.Now().Format("2006-01-02 15:04:05")
+	result, err := db.Exec(`
+		INSERT INTO ip_access_rules (ip, type, note, enabled, created_at)
+		VALUES (?, ?, ?, ?, ?)`,
+		rule.IP, rule.Type, rule.Note, rule.Enabled, now)
+	if err != nil {
+		return fmt.Errorf("创建IP规则失败: %w", err)
+	}
+	rule.ID, _ = result.LastInsertId()
+	rule.CreatedAt = now
+	return nil
+}
+
+// ListIPRules 列出所有IP规则
+func (db *DB) ListIPRules(ruleType string) ([]IPAccessRule, error) {
+	var rows *sql.Rows
+	var err error
+	if ruleType != "" {
+		rows, err = db.Query(`SELECT id, ip, type, note, enabled, created_at
+			FROM ip_access_rules WHERE type=? ORDER BY id`, ruleType)
+	} else {
+		rows, err = db.Query(`SELECT id, ip, type, note, enabled, created_at
+			FROM ip_access_rules ORDER BY id`)
+	}
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var rules []IPAccessRule
+	for rows.Next() {
+		var rule IPAccessRule
+		var enabled int
+		if err := rows.Scan(&rule.ID, &rule.IP, &rule.Type, &rule.Note, &enabled, &rule.CreatedAt); err != nil {
+			return nil, err
+		}
+		rule.Enabled = enabled == 1
+		rules = append(rules, rule)
+	}
+	return rules, nil
+}
+
+// DeleteIPRule 删除IP规则
+func (db *DB) DeleteIPRule(id int64) error {
+	_, err := db.Exec(`DELETE FROM ip_access_rules WHERE id=?`, id)
+	return err
+}
+
+// UpdateIPRule 更新IP规则
+func (db *DB) UpdateIPRule(rule *IPAccessRule) error {
+	_, err := db.Exec(`UPDATE ip_access_rules SET ip=?, type=?, note=?, enabled=? WHERE id=?`,
+		rule.IP, rule.Type, rule.Note, rule.Enabled, rule.ID)
+	return err
+}
+
+// GetEnabledIPRules 获取所有启用的IP规则
+func (db *DB) GetEnabledIPRules() ([]IPAccessRule, error) {
+	rows, err := db.Query(`SELECT id, ip, type, note, enabled, created_at
+		FROM ip_access_rules WHERE enabled=1 ORDER BY id`)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var rules []IPAccessRule
+	for rows.Next() {
+		var rule IPAccessRule
+		var enabled int
+		if err := rows.Scan(&rule.ID, &rule.IP, &rule.Type, &rule.Note, &enabled, &rule.CreatedAt); err != nil {
+			return nil, err
+		}
+		rule.Enabled = enabled == 1
+		rules = append(rules, rule)
+	}
+	return rules, nil
+}

+ 107 - 0
ftp/server.go

@@ -4,7 +4,9 @@ import (
 	"crypto/tls"
 	"fmt"
 	"log"
+	"net"
 	"os"
+	"strings"
 	"sync"
 	"time"
 
@@ -92,6 +94,17 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
 
 // ClientConnected 客户端连接
 func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
+	// IP白名单/黑名单检查
+	clientIP, _, err := net.SplitHostPort(cc.RemoteAddr().String())
+	if err != nil {
+		clientIP = cc.RemoteAddr().String()
+	}
+
+	if err := s.checkIPAccess(clientIP); err != nil {
+		log.Printf("IP %s 被拒绝连接: %v", clientIP, err)
+		return "", fmt.Errorf("连接被拒绝: %s", err)
+	}
+
 	return "220 Welcome to FTP Server\r\n", nil
 }
 
@@ -183,3 +196,97 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string)
 func (s *Server) GetTLSConfig() (*tls.Config, error) {
 	return nil, fmt.Errorf("TLS未配置")
 }
+
+// checkIPAccess 检查IP是否允许访问
+func (s *Server) checkIPAccess(clientIP string) error {
+	if s.db == nil {
+		return nil
+	}
+
+	rules, err := s.db.GetEnabledIPRules()
+	if err != nil {
+		return nil // 查询失败时允许连接
+	}
+
+	var whitelists, blacklists []database.IPAccessRule
+	for _, rule := range rules {
+		if rule.Type == "whitelist" {
+			whitelists = append(whitelists, rule)
+		} else if rule.Type == "blacklist" {
+			blacklists = append(blacklists, rule)
+		}
+	}
+
+	// 如果有白名单规则,只允许白名单中的IP
+	if len(whitelists) > 0 {
+		matched := false
+		for _, rule := range whitelists {
+			if matchIP(clientIP, rule.IP) {
+				matched = true
+				break
+			}
+		}
+		if !matched {
+			return fmt.Errorf("IP不在白名单中")
+		}
+	}
+
+	// 检查黑名单
+	for _, rule := range blacklists {
+		if matchIP(clientIP, rule.IP) {
+			return fmt.Errorf("IP已被列入黑名单")
+		}
+	}
+
+	return nil
+}
+
+// matchIP 检查IP是否匹配规则
+func matchIP(clientIP, rule string) bool {
+	// 单个IP
+	if !strings.Contains(rule, "/") && !strings.Contains(rule, "-") {
+		return clientIP == rule
+	}
+
+	// CIDR 表示法 (192.168.1.0/24)
+	if strings.Contains(rule, "/") {
+		_, ipNet, err := net.ParseCIDR(rule)
+		if err != nil {
+			return clientIP == rule
+		}
+		ip := net.ParseIP(clientIP)
+		if ip == nil {
+			return false
+		}
+		return ipNet.Contains(ip)
+	}
+
+	// IP范围 (192.168.1.1-192.168.1.100)
+	if strings.Contains(rule, "-") {
+		parts := strings.SplitN(rule, "-", 2)
+		startIP := net.ParseIP(strings.TrimSpace(parts[0]))
+		endIP := net.ParseIP(strings.TrimSpace(parts[1]))
+		ip := net.ParseIP(clientIP)
+		if startIP == nil || endIP == nil || ip == nil {
+			return false
+		}
+		return bytesCompare(ip, startIP) >= 0 && bytesCompare(ip, endIP) <= 0
+	}
+
+	return false
+}
+
+// bytesCompare 比较两个IP的字节
+func bytesCompare(a, b net.IP) int {
+	a = a.To16()
+	b = b.To16()
+	for i := range a {
+		if a[i] < b[i] {
+			return -1
+		}
+		if a[i] > b[i] {
+			return 1
+		}
+	}
+	return 0
+}

+ 71 - 0
static/index.html

@@ -49,6 +49,9 @@
                 <li data-page="online">
                     <span class="icon">&#128279;</span> 在线用户
                 </li>
+                <li data-page="ip-rules">
+                    <span class="icon">&#128737;</span> IP白/黑名单
+                </li>
                 <li data-page="settings">
                     <span class="icon">&#9881;</span> 系统设置
                 </li>
@@ -200,6 +203,37 @@
                 </table>
             </div>
 
+            <!-- IP白/黑名单 -->
+            <div id="page-ip-rules" class="page">
+                <div class="page-header">
+                    <h2>IP 白名单/黑名单</h2>
+                    <button class="btn btn-primary" onclick="showAddIPRule()">添加规则</button>
+                </div>
+                <div class="ip-rules-info" style="margin-bottom:16px;padding:12px;background:#fff;border-radius:8px;box-shadow:0 2px 8px rgba(0,0,0,0.08)">
+                    <p style="color:#666;font-size:13px;line-height:1.8">
+                        <strong>规则说明:</strong><br>
+                        - 支持<strong>单IP</strong>(如 192.168.1.1)、<strong>CIDR</strong>(如 192.168.1.0/24)、<strong>IP范围</strong>(如 192.168.1.1-192.168.1.100)<br>
+                        - <strong>白名单</strong>:启用后只有白名单中的IP才能连接,黑名单中的IP会被拒绝<br>
+                        - <strong>黑名单</strong>:黑名单中的IP将被禁止连接<br>
+                        - 如果没有白名单规则,则所有IP默认允许(除非在黑名单中)
+                    </p>
+                </div>
+                <table class="data-table">
+                    <thead>
+                        <tr>
+                            <th>ID</th>
+                            <th>IP地址/网段</th>
+                            <th>类型</th>
+                            <th>备注</th>
+                            <th>状态</th>
+                            <th>创建时间</th>
+                            <th>操作</th>
+                        </tr>
+                    </thead>
+                    <tbody id="ip-rules-tbody"></tbody>
+                </table>
+            </div>
+
             <!-- 系统设置 -->
             <div id="page-settings" class="page">
                 <h2>系统设置</h2>
@@ -333,6 +367,43 @@
         </div>
     </div>
 
+    <!-- IP规则弹窗 -->
+    <div id="ip-rule-modal" class="modal" style="display:none">
+        <div class="modal-content">
+            <div class="modal-header">
+                <h3 id="ip-rule-modal-title">添加IP规则</h3>
+                <span class="modal-close" onclick="closeIPRuleModal()">&times;</span>
+            </div>
+            <form id="ip-rule-form">
+                <input type="hidden" id="ip-rule-edit-id" value="">
+                <div class="form-group">
+                    <label>IP地址/网段</label>
+                    <input type="text" id="ip-rule-ip" placeholder="如: 192.168.1.1 或 192.168.1.0/24 或 10.0.0.1-10.0.0.255" required>
+                </div>
+                <div class="form-group">
+                    <label>规则类型</label>
+                    <select id="ip-rule-type">
+                        <option value="blacklist">黑名单(禁止连接)</option>
+                        <option value="whitelist">白名单(允许连接)</option>
+                    </select>
+                </div>
+                <div class="form-group">
+                    <label>备注说明</label>
+                    <input type="text" id="ip-rule-note" placeholder="可选,填写备注说明">
+                </div>
+                <div class="form-group">
+                    <label>
+                        <input type="checkbox" id="ip-rule-enabled" checked> 启用此规则
+                    </label>
+                </div>
+                <div class="modal-footer">
+                    <button type="button" class="btn" onclick="closeIPRuleModal()">取消</button>
+                    <button type="submit" class="btn btn-primary">保存</button>
+                </div>
+            </form>
+        </div>
+    </div>
+
     <!-- 提示消息 -->
     <div id="toast" class="toast"></div>
 

+ 105 - 0
static/js/app.js

@@ -100,6 +100,7 @@ function loadPage(page) {
         case 'files': loadFiles(currentPath); break;
         case 'logs': loadLogs(); break;
         case 'online': loadOnline(); break;
+        case 'ip-rules': loadIPRules(); break;
         case 'settings': loadConfig(); break;
     }
 }
@@ -458,3 +459,107 @@ if (token) {
 } else {
     showLogin();
 }
+
+// --- IP规则管理 ---
+async function loadIPRules() {
+    try {
+        const rules = await api('GET', '/api/ip-rules');
+        const tbody = document.getElementById('ip-rules-tbody');
+        if (!rules || !rules.length) {
+            tbody.innerHTML = '<tr><td colspan="7" style="text-align:center;color:#999;padding:40px">暂无IP规则,所有IP默认允许连接</td></tr>';
+            return;
+        }
+        tbody.innerHTML = rules.map(r => {
+            const typeLabel = r.type === 'whitelist'
+                ? '<span style="color:#667eea;font-weight:600">白名单</span>'
+                : '<span style="color:#ff4d4f;font-weight:600">黑名单</span>';
+            const statusLabel = r.enabled
+                ? '<span class="status-enabled">启用</span>'
+                : '<span class="status-disabled">禁用</span>';
+            return `<tr>
+                <td>${r.id}</td>
+                <td><code style="background:#f5f5f5;padding:2px 6px;border-radius:3px">${r.ip}</code></td>
+                <td>${typeLabel}</td>
+                <td>${r.note || '-'}</td>
+                <td>${statusLabel}</td>
+                <td>${formatTime(r.created_at)}</td>
+                <td class="action-btns">
+                    <button class="btn btn-sm" onclick="editIPRule(${r.id}, '${r.ip}', '${r.type}', '${(r.note||'').replace(/'/g, "\\'")}', ${r.enabled})">编辑</button>
+                    <button class="btn btn-sm" onclick="toggleIPRule(${r.id}, '${r.ip}', '${r.type}', '${(r.note||'').replace(/'/g, "\\'")}', ${r.enabled})">${r.enabled ? '禁用' : '启用'}</button>
+                    <button class="btn btn-sm btn-danger" onclick="deleteIPRule(${r.id})">删除</button>
+                </td>
+            </tr>`;
+        }).join('');
+    } catch (err) {
+        showToast(err.message, 'error');
+    }
+}
+
+function showAddIPRule() {
+    document.getElementById('ip-rule-modal-title').textContent = '添加IP规则';
+    document.getElementById('ip-rule-edit-id').value = '';
+    document.getElementById('ip-rule-form').reset();
+    document.getElementById('ip-rule-enabled').checked = true;
+    document.getElementById('ip-rule-modal').style.display = 'flex';
+}
+
+function editIPRule(id, ip, type, note, enabled) {
+    document.getElementById('ip-rule-modal-title').textContent = '编辑IP规则';
+    document.getElementById('ip-rule-edit-id').value = id;
+    document.getElementById('ip-rule-ip').value = ip;
+    document.getElementById('ip-rule-type').value = type;
+    document.getElementById('ip-rule-note').value = note;
+    document.getElementById('ip-rule-enabled').checked = enabled;
+    document.getElementById('ip-rule-modal').style.display = 'flex';
+}
+
+function closeIPRuleModal() {
+    document.getElementById('ip-rule-modal').style.display = 'none';
+}
+
+document.getElementById('ip-rule-form').addEventListener('submit', async (e) => {
+    e.preventDefault();
+    const editId = document.getElementById('ip-rule-edit-id').value;
+    const data = {
+        ip: document.getElementById('ip-rule-ip').value,
+        type: document.getElementById('ip-rule-type').value,
+        note: document.getElementById('ip-rule-note').value,
+        enabled: document.getElementById('ip-rule-enabled').checked
+    };
+    try {
+        if (editId) {
+            await api('PUT', '/api/ip-rules/' + editId, data);
+            showToast('规则已更新');
+        } else {
+            await api('POST', '/api/ip-rules', data);
+            showToast('规则添加成功');
+        }
+        closeIPRuleModal();
+        loadIPRules();
+    } catch (err) {
+        showToast(err.message, 'error');
+    }
+});
+
+async function toggleIPRule(id, ip, type, note, enabled) {
+    try {
+        await api('PUT', '/api/ip-rules/' + id, {
+            ip, type, note, enabled: !enabled
+        });
+        showToast(!enabled ? '规则已启用' : '规则已禁用');
+        loadIPRules();
+    } catch (err) {
+        showToast(err.message, 'error');
+    }
+}
+
+async function deleteIPRule(id) {
+    if (!confirm('确定删除此IP规则吗?')) return;
+    try {
+        await api('DELETE', '/api/ip-rules/' + id);
+        showToast('规则已删除');
+        loadIPRules();
+    } catch (err) {
+        showToast(err.message, 'error');
+    }
+}

+ 82 - 0
web/server.go

@@ -62,6 +62,8 @@ func (s *Server) Start() error {
 	mux.HandleFunc("/api/online", s.authMiddleware(s.handleOnline))
 	mux.HandleFunc("/api/config", s.authMiddleware(s.handleConfig))
 	mux.HandleFunc("/api/upload", s.authMiddleware(s.handleUpload))
+	mux.HandleFunc("/api/ip-rules", s.authMiddleware(s.handleIPRules))
+	mux.HandleFunc("/api/ip-rules/", s.authMiddleware(s.handleIPRuleOperation))
 
 	addr := fmt.Sprintf("%s:%d", webCfg.Host, webCfg.Port)
 	log.Printf("Web管理界面已启动: http://localhost:%d", webCfg.Port)
@@ -583,3 +585,83 @@ func GenerateToken() string {
 	rand.Read(b)
 	return hex.EncodeToString(b)
 }
+
+// handleIPRules IP规则列表和创建
+func (s *Server) handleIPRules(w http.ResponseWriter, r *http.Request) {
+	switch r.Method {
+	case http.MethodGet:
+		ruleType := r.URL.Query().Get("type")
+		rules, err := s.db.ListIPRules(ruleType)
+		if err != nil {
+			s.jsonError(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+		if rules == nil {
+			rules = []database.IPAccessRule{}
+		}
+		s.jsonResponse(w, http.StatusOK, rules)
+
+	case http.MethodPost:
+		var rule database.IPAccessRule
+		if err := json.NewDecoder(r.Body).Decode(&rule); err != nil {
+			s.jsonError(w, "请求格式错误", http.StatusBadRequest)
+			return
+		}
+		if rule.IP == "" {
+			s.jsonError(w, "IP不能为空", http.StatusBadRequest)
+			return
+		}
+		if rule.Type != "whitelist" && rule.Type != "blacklist" {
+			rule.Type = "blacklist"
+		}
+		rule.Enabled = true
+		if err := s.db.CreateIPRule(&rule); err != nil {
+			s.jsonError(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+		s.jsonResponse(w, http.StatusOK, rule)
+
+	default:
+		s.jsonError(w, "方法不允许", http.StatusMethodNotAllowed)
+	}
+}
+
+// handleIPRuleOperation 单条IP规则操作
+func (s *Server) handleIPRuleOperation(w http.ResponseWriter, r *http.Request) {
+	pathParts := strings.Split(strings.TrimPrefix(r.URL.Path, "/api/ip-rules/"), "/")
+	idStr := pathParts[0]
+	if idStr == "" {
+		s.jsonError(w, "ID不能为空", http.StatusBadRequest)
+		return
+	}
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil {
+		s.jsonError(w, "无效的ID", http.StatusBadRequest)
+		return
+	}
+
+	switch r.Method {
+	case http.MethodPut:
+		var rule database.IPAccessRule
+		if err := json.NewDecoder(r.Body).Decode(&rule); err != nil {
+			s.jsonError(w, "请求格式错误", http.StatusBadRequest)
+			return
+		}
+		rule.ID = id
+		if err := s.db.UpdateIPRule(&rule); err != nil {
+			s.jsonError(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+		s.jsonResponse(w, http.StatusOK, map[string]string{"message": "规则已更新"})
+
+	case http.MethodDelete:
+		if err := s.db.DeleteIPRule(id); err != nil {
+			s.jsonError(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+		s.jsonResponse(w, http.StatusOK, map[string]string{"message": "规则已删除"})
+
+	default:
+		s.jsonError(w, "方法不允许", http.StatusMethodNotAllowed)
+	}
+}