Browse Source

feat: 支持用户级别IP黑白名单 - 可为单个用户设置独立的IP访问规则

Your Name 2 weeks ago
parent
commit
dec263312a
6 changed files with 141 additions and 53 deletions
  1. 42 23
      database/db.go
  2. BIN
      ftp-server
  3. 65 19
      ftp/server.go
  4. 15 3
      static/index.html
  5. 17 7
      static/js/app.js
  6. 2 1
      web/server.go

+ 42 - 23
database/db.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"os"
 	"path/filepath"
+	"strings"
 	"time"
 
 	_ "modernc.org/sqlite"
@@ -56,9 +57,10 @@ type OnlineUser struct {
 // IPAccessRule IP访问规则
 type IPAccessRule struct {
 	ID        int64  `json:"id"`
-	IP        string `json:"ip"`   // 支持单IP (192.168.1.1)、CIDR (192.168.1.0/24)、IP范围 (192.168.1.1-192.168.1.100)
-	Type      string `json:"type"` // "whitelist" 或 "blacklist"
-	Note      string `json:"note"` // 备注说明
+	Username  string `json:"username"` // 为空表示全局规则,有值表示用户专属规则
+	IP        string `json:"ip"`       // 支持单IP、CIDR(192.168.1.0/24)、IP范围(192.168.1.1-192.168.1.100)
+	Type      string `json:"type"`     // "whitelist" 或 "blacklist"
+	Note      string `json:"note"`     // 备注说明
 	Enabled   bool   `json:"enabled"`
 	CreatedAt string `json:"created_at"`
 }
@@ -130,6 +132,7 @@ func (db *DB) initTables() error {
 	);
 	CREATE TABLE IF NOT EXISTS ip_access_rules (
 		id INTEGER PRIMARY KEY AUTOINCREMENT,
+		username TEXT DEFAULT '',
 		ip TEXT NOT NULL,
 		type TEXT NOT NULL DEFAULT 'blacklist',
 		note TEXT DEFAULT '',
@@ -139,6 +142,7 @@ func (db *DB) initTables() error {
 
 	CREATE INDEX IF NOT EXISTS idx_ip_rules_type ON ip_access_rules(type);
 	CREATE INDEX IF NOT EXISTS idx_ip_rules_enabled ON ip_access_rules(enabled);
+	CREATE INDEX IF NOT EXISTS idx_ip_rules_username ON ip_access_rules(username);
 	`
 
 	_, err := db.Exec(schema)
@@ -356,9 +360,9 @@ func (db *DB) CleanOldLogs(days int) (int64, error) {
 func (db *DB) CreateIPRule(rule *IPAccessRule) error {
 	now := time.Now().Format("2006-01-02 15:04:05")
 	result, err := db.Exec(`
-		INSERT INTO ip_access_rules (ip, type, note, enabled, created_at)
-		VALUES (?, ?, ?, ?, ?)`,
-		rule.IP, rule.Type, rule.Note, rule.Enabled, now)
+		INSERT INTO ip_access_rules (username, ip, type, note, enabled, created_at)
+		VALUES (?, ?, ?, ?, ?, ?)`,
+		rule.Username, rule.IP, rule.Type, rule.Note, rule.Enabled, now)
 	if err != nil {
 		return fmt.Errorf("创建IP规则失败: %w", err)
 	}
@@ -367,17 +371,32 @@ func (db *DB) CreateIPRule(rule *IPAccessRule) error {
 	return nil
 }
 
-// ListIPRules 列出所有IP规则
-func (db *DB) ListIPRules(ruleType string) ([]IPAccessRule, error) {
-	var rows *sql.Rows
-	var err error
+// ListIPRules 列出IP规则,ruleType和username为空时列出全部
+func (db *DB) ListIPRules(ruleType, username string) ([]IPAccessRule, error) {
+	query := `SELECT id, username, ip, type, note, enabled, created_at FROM ip_access_rules`
+	var conditions []string
+	var args []interface{}
+
 	if ruleType != "" {
-		rows, err = db.Query(`SELECT id, ip, type, note, enabled, created_at
-			FROM ip_access_rules WHERE type=? ORDER BY id`, ruleType)
-	} else {
-		rows, err = db.Query(`SELECT id, ip, type, note, enabled, created_at
-			FROM ip_access_rules ORDER BY id`)
+		conditions = append(conditions, "type=?")
+		args = append(args, ruleType)
 	}
+	if username != "" {
+		if username == "__empty__" {
+			conditions = append(conditions, "username='')")
+		} else if username == "__has__" {
+			conditions = append(conditions, "username!='')")
+		} else {
+			conditions = append(conditions, "username=?")
+			args = append(args, username)
+		}
+	}
+	if len(conditions) > 0 {
+		query += " WHERE " + strings.Join(conditions, " AND ")
+	}
+	query += " ORDER BY id"
+
+	rows, err := db.Query(query, args...)
 	if err != nil {
 		return nil, err
 	}
@@ -387,7 +406,7 @@ func (db *DB) ListIPRules(ruleType string) ([]IPAccessRule, error) {
 	for rows.Next() {
 		var rule IPAccessRule
 		var enabled int
-		if err := rows.Scan(&rule.ID, &rule.IP, &rule.Type, &rule.Note, &enabled, &rule.CreatedAt); err != nil {
+		if err := rows.Scan(&rule.ID, &rule.Username, &rule.IP, &rule.Type, &rule.Note, &enabled, &rule.CreatedAt); err != nil {
 			return nil, err
 		}
 		rule.Enabled = enabled == 1
@@ -404,15 +423,15 @@ func (db *DB) DeleteIPRule(id int64) error {
 
 // UpdateIPRule 更新IP规则
 func (db *DB) UpdateIPRule(rule *IPAccessRule) error {
-	_, err := db.Exec(`UPDATE ip_access_rules SET ip=?, type=?, note=?, enabled=? WHERE id=?`,
-		rule.IP, rule.Type, rule.Note, rule.Enabled, rule.ID)
+	_, err := db.Exec(`UPDATE ip_access_rules SET username=?, ip=?, type=?, note=?, enabled=? WHERE id=?`,
+		rule.Username, rule.IP, rule.Type, rule.Note, rule.Enabled, rule.ID)
 	return err
 }
 
-// GetEnabledIPRules 获取所有启用的IP规则
-func (db *DB) GetEnabledIPRules() ([]IPAccessRule, error) {
-	rows, err := db.Query(`SELECT id, ip, type, note, enabled, created_at
-		FROM ip_access_rules WHERE enabled=1 ORDER BY id`)
+// GetEnabledIPRules 获取所有启用的IP规则(全局+指定用户)
+func (db *DB) GetEnabledIPRules(username string) ([]IPAccessRule, error) {
+	rows, err := db.Query(`SELECT id, username, ip, type, note, enabled, created_at
+		FROM ip_access_rules WHERE enabled=1 AND (username='' OR username=?) ORDER BY id`, username)
 	if err != nil {
 		return nil, err
 	}
@@ -422,7 +441,7 @@ func (db *DB) GetEnabledIPRules() ([]IPAccessRule, error) {
 	for rows.Next() {
 		var rule IPAccessRule
 		var enabled int
-		if err := rows.Scan(&rule.ID, &rule.IP, &rule.Type, &rule.Note, &enabled, &rule.CreatedAt); err != nil {
+		if err := rows.Scan(&rule.ID, &rule.Username, &rule.IP, &rule.Type, &rule.Note, &enabled, &rule.CreatedAt); err != nil {
 			return nil, err
 		}
 		rule.Enabled = enabled == 1

BIN
ftp-server


+ 65 - 19
ftp/server.go

@@ -92,16 +92,15 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
 	}, nil
 }
 
-// ClientConnected 客户端连接
+// ClientConnected 客户端连接(只检查全局规则)
 func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
-	// IP白名单/黑名单检查
 	clientIP, _, err := net.SplitHostPort(cc.RemoteAddr().String())
 	if err != nil {
 		clientIP = cc.RemoteAddr().String()
 	}
 
-	if err := s.checkIPAccess(clientIP); err != nil {
-		log.Printf("IP %s 被拒绝连接: %v", clientIP, err)
+	if err := s.checkIPAccess(clientIP, ""); err != nil {
+		log.Printf("IP %s 被拒绝连接(全局规则): %v", clientIP, err)
 		return "", fmt.Errorf("连接被拒绝: %s", err)
 	}
 
@@ -156,6 +155,22 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string)
 		return nil, fmt.Errorf("密码错误")
 	}
 
+	// 检查用户级别IP规则
+	clientIP, _, _ := net.SplitHostPort(cc.RemoteAddr().String())
+	if clientIP == "" {
+		clientIP = cc.RemoteAddr().String()
+	}
+	if err := s.checkIPAccess(clientIP, username); err != nil {
+		log.Printf("用户 %s IP %s 被拒绝: %v", username, clientIP, err)
+		s.db.AddLog(&database.FTPLog{
+			Username: username,
+			IP:       cc.RemoteAddr().String(),
+			Action:   "login_blocked",
+			Status:   "blocked",
+		})
+		return nil, fmt.Errorf("登录被拒绝: %s", err)
+	}
+
 	// 记录登录日志
 	s.db.AddLog(&database.FTPLog{
 		Username: username,
@@ -197,44 +212,75 @@ func (s *Server) GetTLSConfig() (*tls.Config, error) {
 	return nil, fmt.Errorf("TLS未配置")
 }
 
-// checkIPAccess 检查IP是否允许访问
-func (s *Server) checkIPAccess(clientIP string) error {
+// checkIPAccess 检查IP是否允许访问,username为空时只检查全局规则
+func (s *Server) checkIPAccess(clientIP, username string) error {
 	if s.db == nil {
 		return nil
 	}
 
-	rules, err := s.db.GetEnabledIPRules()
+	rules, err := s.db.GetEnabledIPRules(username)
 	if err != nil {
 		return nil // 查询失败时允许连接
 	}
 
-	var whitelists, blacklists []database.IPAccessRule
+	// 分离全局规则和用户规则
+	var globalWhitelist, globalBlacklist []database.IPAccessRule
+	var userWhitelist, userBlacklist []database.IPAccessRule
 	for _, rule := range rules {
-		if rule.Type == "whitelist" {
-			whitelists = append(whitelists, rule)
-		} else if rule.Type == "blacklist" {
-			blacklists = append(blacklists, rule)
+		if rule.Username == "" {
+			if rule.Type == "whitelist" {
+				globalWhitelist = append(globalWhitelist, rule)
+			} else {
+				globalBlacklist = append(globalBlacklist, rule)
+			}
+		} else {
+			if rule.Type == "whitelist" {
+				userWhitelist = append(userWhitelist, rule)
+			} else {
+				userBlacklist = append(userBlacklist, rule)
+			}
+		}
+	}
+
+	// 1. 先检查全局黑名单
+	for _, rule := range globalBlacklist {
+		if matchIP(clientIP, rule.IP) {
+			return fmt.Errorf("IP已被全局黑名单拦截")
 		}
 	}
 
-	// 如果有白名单规则,只允许白名单中的IP
-	if len(whitelists) > 0 {
+	// 2. 检查全局白名单(如果有全局白名单,必须在其中)
+	if len(globalWhitelist) > 0 {
 		matched := false
-		for _, rule := range whitelists {
+		for _, rule := range globalWhitelist {
 			if matchIP(clientIP, rule.IP) {
 				matched = true
 				break
 			}
 		}
 		if !matched {
-			return fmt.Errorf("IP不在白名单中")
+			return fmt.Errorf("IP不在全局白名单中")
 		}
 	}
 
-	// 检查黑名单
-	for _, rule := range blacklists {
+	// 3. 检查用户黑名单
+	for _, rule := range userBlacklist {
 		if matchIP(clientIP, rule.IP) {
-			return fmt.Errorf("IP已被列入黑名单")
+			return fmt.Errorf("IP已被用户黑名单拦截")
+		}
+	}
+
+	// 4. 检查用户白名单(如果有用户白名单,必须在其中)
+	if len(userWhitelist) > 0 {
+		matched := false
+		for _, rule := range userWhitelist {
+			if matchIP(clientIP, rule.IP) {
+				matched = true
+				break
+			}
+		}
+		if !matched {
+			return fmt.Errorf("IP不在用户白名单中")
 		}
 	}
 

+ 15 - 3
static/index.html

@@ -213,15 +213,23 @@
                     <p style="color:#666;font-size:13px;line-height:1.8">
                         <strong>规则说明:</strong><br>
                         - 支持<strong>单IP</strong>(如 192.168.1.1)、<strong>CIDR</strong>(如 192.168.1.0/24)、<strong>IP范围</strong>(如 192.168.1.1-192.168.1.100)<br>
-                        - <strong>白名单</strong>:启用后只有白名单中的IP才能连接,黑名单中的IP会被拒绝<br>
-                        - <strong>黑名单</strong>:黑名单中的IP将被禁止连接<br>
-                        - 如果没有白名单规则,则所有IP默认允许(除非在黑名单中)
+                        - <strong>全局规则</strong>:对所有用户生效,在连接时即检查<br>
+                        - <strong>用户规则</strong>:仅对指定用户生效,在用户登录时检查<br>
+                        - <strong>优先级</strong>:全局黑名单 > 全局白名单 > 用户黑名单 > 用户白名单
                     </p>
                 </div>
+                <div class="filter-bar" style="margin-bottom:12px">
+                    <select id="ip-rule-filter" class="input-sm" onchange="loadIPRules()">
+                        <option value="">全部规则</option>
+                        <option value="global">仅全局规则</option>
+                        <option value="user">仅用户规则</option>
+                    </select>
+                </div>
                 <table class="data-table">
                     <thead>
                         <tr>
                             <th>ID</th>
+                            <th>作用范围</th>
                             <th>IP地址/网段</th>
                             <th>类型</th>
                             <th>备注</th>
@@ -376,6 +384,10 @@
             </div>
             <form id="ip-rule-form">
                 <input type="hidden" id="ip-rule-edit-id" value="">
+                <div class="form-group">
+                    <label>作用用户(留空为全局规则)</label>
+                    <input type="text" id="ip-rule-username" placeholder="留空表示全局规则,填入用户名表示仅对该用户生效">
+                </div>
                 <div class="form-group">
                     <label>IP地址/网段</label>
                     <input type="text" id="ip-rule-ip" placeholder="如: 192.168.1.1 或 192.168.1.0/24 或 10.0.0.1-10.0.0.255" required>

+ 17 - 7
static/js/app.js

@@ -463,13 +463,20 @@ if (token) {
 // --- IP规则管理 ---
 async function loadIPRules() {
     try {
-        const rules = await api('GET', '/api/ip-rules');
+        const filter = document.getElementById('ip-rule-filter').value;
+        let url = '/api/ip-rules';
+        if (filter === 'global') url += '?username=__empty__';
+        else if (filter === 'user') url += '?username=__has__';
+        const rules = await api('GET', url);
         const tbody = document.getElementById('ip-rules-tbody');
         if (!rules || !rules.length) {
-            tbody.innerHTML = '<tr><td colspan="7" style="text-align:center;color:#999;padding:40px">暂无IP规则,所有IP默认允许连接</td></tr>';
+            tbody.innerHTML = '<tr><td colspan="8" style="text-align:center;color:#999;padding:40px">暂无IP规则,所有IP默认允许连接</td></tr>';
             return;
         }
         tbody.innerHTML = rules.map(r => {
+            const scopeLabel = r.username
+                ? `<span style="color:#e67e22;font-weight:600">用户: ${r.username}</span>`
+                : '<span style="color:#667eea;font-weight:600">全局</span>';
             const typeLabel = r.type === 'whitelist'
                 ? '<span style="color:#667eea;font-weight:600">白名单</span>'
                 : '<span style="color:#ff4d4f;font-weight:600">黑名单</span>';
@@ -478,14 +485,15 @@ async function loadIPRules() {
                 : '<span class="status-disabled">禁用</span>';
             return `<tr>
                 <td>${r.id}</td>
+                <td>${scopeLabel}</td>
                 <td><code style="background:#f5f5f5;padding:2px 6px;border-radius:3px">${r.ip}</code></td>
                 <td>${typeLabel}</td>
                 <td>${r.note || '-'}</td>
                 <td>${statusLabel}</td>
                 <td>${formatTime(r.created_at)}</td>
                 <td class="action-btns">
-                    <button class="btn btn-sm" onclick="editIPRule(${r.id}, '${r.ip}', '${r.type}', '${(r.note||'').replace(/'/g, "\\'")}', ${r.enabled})">编辑</button>
-                    <button class="btn btn-sm" onclick="toggleIPRule(${r.id}, '${r.ip}', '${r.type}', '${(r.note||'').replace(/'/g, "\\'")}', ${r.enabled})">${r.enabled ? '禁用' : '启用'}</button>
+                    <button class="btn btn-sm" onclick="editIPRule(${r.id}, '${r.username||''}', '${r.ip}', '${r.type}', '${(r.note||'').replace(/'/g, "\\'")}', ${r.enabled})">编辑</button>
+                    <button class="btn btn-sm" onclick="toggleIPRule(${r.id}, '${r.username||''}', '${r.ip}', '${r.type}', '${(r.note||'').replace(/'/g, "\\'")}', ${r.enabled})">${r.enabled ? '禁用' : '启用'}</button>
                     <button class="btn btn-sm btn-danger" onclick="deleteIPRule(${r.id})">删除</button>
                 </td>
             </tr>`;
@@ -503,9 +511,10 @@ function showAddIPRule() {
     document.getElementById('ip-rule-modal').style.display = 'flex';
 }
 
-function editIPRule(id, ip, type, note, enabled) {
+function editIPRule(id, username, ip, type, note, enabled) {
     document.getElementById('ip-rule-modal-title').textContent = '编辑IP规则';
     document.getElementById('ip-rule-edit-id').value = id;
+    document.getElementById('ip-rule-username').value = username;
     document.getElementById('ip-rule-ip').value = ip;
     document.getElementById('ip-rule-type').value = type;
     document.getElementById('ip-rule-note').value = note;
@@ -521,6 +530,7 @@ document.getElementById('ip-rule-form').addEventListener('submit', async (e) =>
     e.preventDefault();
     const editId = document.getElementById('ip-rule-edit-id').value;
     const data = {
+        username: document.getElementById('ip-rule-username').value,
         ip: document.getElementById('ip-rule-ip').value,
         type: document.getElementById('ip-rule-type').value,
         note: document.getElementById('ip-rule-note').value,
@@ -541,10 +551,10 @@ document.getElementById('ip-rule-form').addEventListener('submit', async (e) =>
     }
 });
 
-async function toggleIPRule(id, ip, type, note, enabled) {
+async function toggleIPRule(id, username, ip, type, note, enabled) {
     try {
         await api('PUT', '/api/ip-rules/' + id, {
-            ip, type, note, enabled: !enabled
+            username, ip, type, note, enabled: !enabled
         });
         showToast(!enabled ? '规则已启用' : '规则已禁用');
         loadIPRules();

+ 2 - 1
web/server.go

@@ -591,7 +591,8 @@ func (s *Server) handleIPRules(w http.ResponseWriter, r *http.Request) {
 	switch r.Method {
 	case http.MethodGet:
 		ruleType := r.URL.Query().Get("type")
-		rules, err := s.db.ListIPRules(ruleType)
+		username := r.URL.Query().Get("username")
+		rules, err := s.db.ListIPRules(ruleType, username)
 		if err != nil {
 			s.jsonError(w, err.Error(), http.StatusInternalServerError)
 			return