feat: 添加IP白名单/黑名单功能 - 支持单IP、CIDR、IP范围

This commit is contained in:
Your Name
2026-05-07 10:05:03 +08:00
parent 0f9663045c
commit d92cf341b8
5 changed files with 467 additions and 0 deletions
+107
View File
@@ -4,7 +4,9 @@ import (
"crypto/tls"
"fmt"
"log"
"net"
"os"
"strings"
"sync"
"time"
@@ -92,6 +94,17 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
// ClientConnected 客户端连接
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
// IP白名单/黑名单检查
clientIP, _, err := net.SplitHostPort(cc.RemoteAddr().String())
if err != nil {
clientIP = cc.RemoteAddr().String()
}
if err := s.checkIPAccess(clientIP); err != nil {
log.Printf("IP %s 被拒绝连接: %v", clientIP, err)
return "", fmt.Errorf("连接被拒绝: %s", err)
}
return "220 Welcome to FTP Server\r\n", nil
}
@@ -183,3 +196,97 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string)
func (s *Server) GetTLSConfig() (*tls.Config, error) {
return nil, fmt.Errorf("TLS未配置")
}
// checkIPAccess 检查IP是否允许访问
func (s *Server) checkIPAccess(clientIP string) error {
if s.db == nil {
return nil
}
rules, err := s.db.GetEnabledIPRules()
if err != nil {
return nil // 查询失败时允许连接
}
var whitelists, blacklists []database.IPAccessRule
for _, rule := range rules {
if rule.Type == "whitelist" {
whitelists = append(whitelists, rule)
} else if rule.Type == "blacklist" {
blacklists = append(blacklists, rule)
}
}
// 如果有白名单规则,只允许白名单中的IP
if len(whitelists) > 0 {
matched := false
for _, rule := range whitelists {
if matchIP(clientIP, rule.IP) {
matched = true
break
}
}
if !matched {
return fmt.Errorf("IP不在白名单中")
}
}
// 检查黑名单
for _, rule := range blacklists {
if matchIP(clientIP, rule.IP) {
return fmt.Errorf("IP已被列入黑名单")
}
}
return nil
}
// matchIP 检查IP是否匹配规则
func matchIP(clientIP, rule string) bool {
// 单个IP
if !strings.Contains(rule, "/") && !strings.Contains(rule, "-") {
return clientIP == rule
}
// CIDR 表示法 (192.168.1.0/24)
if strings.Contains(rule, "/") {
_, ipNet, err := net.ParseCIDR(rule)
if err != nil {
return clientIP == rule
}
ip := net.ParseIP(clientIP)
if ip == nil {
return false
}
return ipNet.Contains(ip)
}
// IP范围 (192.168.1.1-192.168.1.100)
if strings.Contains(rule, "-") {
parts := strings.SplitN(rule, "-", 2)
startIP := net.ParseIP(strings.TrimSpace(parts[0]))
endIP := net.ParseIP(strings.TrimSpace(parts[1]))
ip := net.ParseIP(clientIP)
if startIP == nil || endIP == nil || ip == nil {
return false
}
return bytesCompare(ip, startIP) >= 0 && bytesCompare(ip, endIP) <= 0
}
return false
}
// bytesCompare 比较两个IP的字节
func bytesCompare(a, b net.IP) int {
a = a.To16()
b = b.To16()
for i := range a {
if a[i] < b[i] {
return -1
}
if a[i] > b[i] {
return 1
}
}
return 0
}