feat: 添加IP白名单/黑名单功能 - 支持单IP、CIDR、IP范围
This commit is contained in:
+107
@@ -4,7 +4,9 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -92,6 +94,17 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
|
||||
|
||||
// ClientConnected 客户端连接
|
||||
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
|
||||
// IP白名单/黑名单检查
|
||||
clientIP, _, err := net.SplitHostPort(cc.RemoteAddr().String())
|
||||
if err != nil {
|
||||
clientIP = cc.RemoteAddr().String()
|
||||
}
|
||||
|
||||
if err := s.checkIPAccess(clientIP); err != nil {
|
||||
log.Printf("IP %s 被拒绝连接: %v", clientIP, err)
|
||||
return "", fmt.Errorf("连接被拒绝: %s", err)
|
||||
}
|
||||
|
||||
return "220 Welcome to FTP Server\r\n", nil
|
||||
}
|
||||
|
||||
@@ -183,3 +196,97 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string)
|
||||
func (s *Server) GetTLSConfig() (*tls.Config, error) {
|
||||
return nil, fmt.Errorf("TLS未配置")
|
||||
}
|
||||
|
||||
// checkIPAccess 检查IP是否允许访问
|
||||
func (s *Server) checkIPAccess(clientIP string) error {
|
||||
if s.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rules, err := s.db.GetEnabledIPRules()
|
||||
if err != nil {
|
||||
return nil // 查询失败时允许连接
|
||||
}
|
||||
|
||||
var whitelists, blacklists []database.IPAccessRule
|
||||
for _, rule := range rules {
|
||||
if rule.Type == "whitelist" {
|
||||
whitelists = append(whitelists, rule)
|
||||
} else if rule.Type == "blacklist" {
|
||||
blacklists = append(blacklists, rule)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有白名单规则,只允许白名单中的IP
|
||||
if len(whitelists) > 0 {
|
||||
matched := false
|
||||
for _, rule := range whitelists {
|
||||
if matchIP(clientIP, rule.IP) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
return fmt.Errorf("IP不在白名单中")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查黑名单
|
||||
for _, rule := range blacklists {
|
||||
if matchIP(clientIP, rule.IP) {
|
||||
return fmt.Errorf("IP已被列入黑名单")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchIP 检查IP是否匹配规则
|
||||
func matchIP(clientIP, rule string) bool {
|
||||
// 单个IP
|
||||
if !strings.Contains(rule, "/") && !strings.Contains(rule, "-") {
|
||||
return clientIP == rule
|
||||
}
|
||||
|
||||
// CIDR 表示法 (192.168.1.0/24)
|
||||
if strings.Contains(rule, "/") {
|
||||
_, ipNet, err := net.ParseCIDR(rule)
|
||||
if err != nil {
|
||||
return clientIP == rule
|
||||
}
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return ipNet.Contains(ip)
|
||||
}
|
||||
|
||||
// IP范围 (192.168.1.1-192.168.1.100)
|
||||
if strings.Contains(rule, "-") {
|
||||
parts := strings.SplitN(rule, "-", 2)
|
||||
startIP := net.ParseIP(strings.TrimSpace(parts[0]))
|
||||
endIP := net.ParseIP(strings.TrimSpace(parts[1]))
|
||||
ip := net.ParseIP(clientIP)
|
||||
if startIP == nil || endIP == nil || ip == nil {
|
||||
return false
|
||||
}
|
||||
return bytesCompare(ip, startIP) >= 0 && bytesCompare(ip, endIP) <= 0
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// bytesCompare 比较两个IP的字节
|
||||
func bytesCompare(a, b net.IP) int {
|
||||
a = a.To16()
|
||||
b = b.To16()
|
||||
for i := range a {
|
||||
if a[i] < b[i] {
|
||||
return -1
|
||||
}
|
||||
if a[i] > b[i] {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user