diff --git a/database/db.go b/database/db.go index d04e26c..9572bb1 100644 --- a/database/db.go +++ b/database/db.go @@ -53,6 +53,16 @@ type OnlineUser struct { CurrentDir string `json:"current_dir"` } +// IPAccessRule IP访问规则 +type IPAccessRule struct { + ID int64 `json:"id"` + IP string `json:"ip"` // 支持单IP (192.168.1.1)、CIDR (192.168.1.0/24)、IP范围 (192.168.1.1-192.168.1.100) + Type string `json:"type"` // "whitelist" 或 "blacklist" + Note string `json:"note"` // 备注说明 + Enabled bool `json:"enabled"` + CreatedAt string `json:"created_at"` +} + // Open 打开数据库 func Open(dbPath string) (*DB, error) { dir := filepath.Dir(dbPath) @@ -118,6 +128,17 @@ func (db *DB) initTables() error { total_download_bytes INTEGER DEFAULT 0, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ); + CREATE TABLE IF NOT EXISTS ip_access_rules ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ip TEXT NOT NULL, + type TEXT NOT NULL DEFAULT 'blacklist', + note TEXT DEFAULT '', + enabled INTEGER DEFAULT 1, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_ip_rules_type ON ip_access_rules(type); + CREATE INDEX IF NOT EXISTS idx_ip_rules_enabled ON ip_access_rules(enabled); ` _, err := db.Exec(schema) @@ -328,3 +349,84 @@ func (db *DB) CleanOldLogs(days int) (int64, error) { } return result.RowsAffected() } + +// --- IP访问规则 CRUD --- + +// CreateIPRule 创建IP规则 +func (db *DB) CreateIPRule(rule *IPAccessRule) error { + now := time.Now().Format("2006-01-02 15:04:05") + result, err := db.Exec(` + INSERT INTO ip_access_rules (ip, type, note, enabled, created_at) + VALUES (?, ?, ?, ?, ?)`, + rule.IP, rule.Type, rule.Note, rule.Enabled, now) + if err != nil { + return fmt.Errorf("创建IP规则失败: %w", err) + } + rule.ID, _ = result.LastInsertId() + rule.CreatedAt = now + return nil +} + +// ListIPRules 列出所有IP规则 +func (db *DB) ListIPRules(ruleType string) ([]IPAccessRule, error) { + var rows *sql.Rows + var err error + if ruleType != "" { + rows, err = db.Query(`SELECT id, ip, type, note, enabled, created_at + FROM ip_access_rules WHERE type=? ORDER BY id`, ruleType) + } else { + rows, err = db.Query(`SELECT id, ip, type, note, enabled, created_at + FROM ip_access_rules ORDER BY id`) + } + if err != nil { + return nil, err + } + defer rows.Close() + + var rules []IPAccessRule + for rows.Next() { + var rule IPAccessRule + var enabled int + if err := rows.Scan(&rule.ID, &rule.IP, &rule.Type, &rule.Note, &enabled, &rule.CreatedAt); err != nil { + return nil, err + } + rule.Enabled = enabled == 1 + rules = append(rules, rule) + } + return rules, nil +} + +// DeleteIPRule 删除IP规则 +func (db *DB) DeleteIPRule(id int64) error { + _, err := db.Exec(`DELETE FROM ip_access_rules WHERE id=?`, id) + return err +} + +// UpdateIPRule 更新IP规则 +func (db *DB) UpdateIPRule(rule *IPAccessRule) error { + _, err := db.Exec(`UPDATE ip_access_rules SET ip=?, type=?, note=?, enabled=? WHERE id=?`, + rule.IP, rule.Type, rule.Note, rule.Enabled, rule.ID) + return err +} + +// GetEnabledIPRules 获取所有启用的IP规则 +func (db *DB) GetEnabledIPRules() ([]IPAccessRule, error) { + rows, err := db.Query(`SELECT id, ip, type, note, enabled, created_at + FROM ip_access_rules WHERE enabled=1 ORDER BY id`) + if err != nil { + return nil, err + } + defer rows.Close() + + var rules []IPAccessRule + for rows.Next() { + var rule IPAccessRule + var enabled int + if err := rows.Scan(&rule.ID, &rule.IP, &rule.Type, &rule.Note, &enabled, &rule.CreatedAt); err != nil { + return nil, err + } + rule.Enabled = enabled == 1 + rules = append(rules, rule) + } + return rules, nil +} diff --git a/ftp/server.go b/ftp/server.go index 45fb447..4e8f08a 100644 --- a/ftp/server.go +++ b/ftp/server.go @@ -4,7 +4,9 @@ import ( "crypto/tls" "fmt" "log" + "net" "os" + "strings" "sync" "time" @@ -92,6 +94,17 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) { // ClientConnected 客户端连接 func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) { + // IP白名单/黑名单检查 + clientIP, _, err := net.SplitHostPort(cc.RemoteAddr().String()) + if err != nil { + clientIP = cc.RemoteAddr().String() + } + + if err := s.checkIPAccess(clientIP); err != nil { + log.Printf("IP %s 被拒绝连接: %v", clientIP, err) + return "", fmt.Errorf("连接被拒绝: %s", err) + } + return "220 Welcome to FTP Server\r\n", nil } @@ -183,3 +196,97 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string) func (s *Server) GetTLSConfig() (*tls.Config, error) { return nil, fmt.Errorf("TLS未配置") } + +// checkIPAccess 检查IP是否允许访问 +func (s *Server) checkIPAccess(clientIP string) error { + if s.db == nil { + return nil + } + + rules, err := s.db.GetEnabledIPRules() + if err != nil { + return nil // 查询失败时允许连接 + } + + var whitelists, blacklists []database.IPAccessRule + for _, rule := range rules { + if rule.Type == "whitelist" { + whitelists = append(whitelists, rule) + } else if rule.Type == "blacklist" { + blacklists = append(blacklists, rule) + } + } + + // 如果有白名单规则,只允许白名单中的IP + if len(whitelists) > 0 { + matched := false + for _, rule := range whitelists { + if matchIP(clientIP, rule.IP) { + matched = true + break + } + } + if !matched { + return fmt.Errorf("IP不在白名单中") + } + } + + // 检查黑名单 + for _, rule := range blacklists { + if matchIP(clientIP, rule.IP) { + return fmt.Errorf("IP已被列入黑名单") + } + } + + return nil +} + +// matchIP 检查IP是否匹配规则 +func matchIP(clientIP, rule string) bool { + // 单个IP + if !strings.Contains(rule, "/") && !strings.Contains(rule, "-") { + return clientIP == rule + } + + // CIDR 表示法 (192.168.1.0/24) + if strings.Contains(rule, "/") { + _, ipNet, err := net.ParseCIDR(rule) + if err != nil { + return clientIP == rule + } + ip := net.ParseIP(clientIP) + if ip == nil { + return false + } + return ipNet.Contains(ip) + } + + // IP范围 (192.168.1.1-192.168.1.100) + if strings.Contains(rule, "-") { + parts := strings.SplitN(rule, "-", 2) + startIP := net.ParseIP(strings.TrimSpace(parts[0])) + endIP := net.ParseIP(strings.TrimSpace(parts[1])) + ip := net.ParseIP(clientIP) + if startIP == nil || endIP == nil || ip == nil { + return false + } + return bytesCompare(ip, startIP) >= 0 && bytesCompare(ip, endIP) <= 0 + } + + return false +} + +// bytesCompare 比较两个IP的字节 +func bytesCompare(a, b net.IP) int { + a = a.To16() + b = b.To16() + for i := range a { + if a[i] < b[i] { + return -1 + } + if a[i] > b[i] { + return 1 + } + } + return 0 +} diff --git a/static/index.html b/static/index.html index 7e36a4e..daab2e6 100644 --- a/static/index.html +++ b/static/index.html @@ -49,6 +49,9 @@
+ 规则说明:
+ - 支持单IP(如 192.168.1.1)、CIDR(如 192.168.1.0/24)、IP范围(如 192.168.1.1-192.168.1.100)
+ - 白名单:启用后只有白名单中的IP才能连接,黑名单中的IP会被拒绝
+ - 黑名单:黑名单中的IP将被禁止连接
+ - 如果没有白名单规则,则所有IP默认允许(除非在黑名单中)
+
| ID | +IP地址/网段 | +类型 | +备注 | +状态 | +创建时间 | +操作 | +
|---|
${r.ip}