feat: 支持用户级别IP黑白名单 - 可为单个用户设置独立的IP访问规则

This commit is contained in:
Your Name
2026-05-07 11:17:20 +08:00
parent d92cf341b8
commit dec263312a
6 changed files with 141 additions and 53 deletions
+65 -19
View File
@@ -92,16 +92,15 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
}, nil
}
// ClientConnected 客户端连接
// ClientConnected 客户端连接(只检查全局规则)
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
// IP白名单/黑名单检查
clientIP, _, err := net.SplitHostPort(cc.RemoteAddr().String())
if err != nil {
clientIP = cc.RemoteAddr().String()
}
if err := s.checkIPAccess(clientIP); err != nil {
log.Printf("IP %s 被拒绝连接: %v", clientIP, err)
if err := s.checkIPAccess(clientIP, ""); err != nil {
log.Printf("IP %s 被拒绝连接(全局规则): %v", clientIP, err)
return "", fmt.Errorf("连接被拒绝: %s", err)
}
@@ -156,6 +155,22 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string)
return nil, fmt.Errorf("密码错误")
}
// 检查用户级别IP规则
clientIP, _, _ := net.SplitHostPort(cc.RemoteAddr().String())
if clientIP == "" {
clientIP = cc.RemoteAddr().String()
}
if err := s.checkIPAccess(clientIP, username); err != nil {
log.Printf("用户 %s IP %s 被拒绝: %v", username, clientIP, err)
s.db.AddLog(&database.FTPLog{
Username: username,
IP: cc.RemoteAddr().String(),
Action: "login_blocked",
Status: "blocked",
})
return nil, fmt.Errorf("登录被拒绝: %s", err)
}
// 记录登录日志
s.db.AddLog(&database.FTPLog{
Username: username,
@@ -197,44 +212,75 @@ func (s *Server) GetTLSConfig() (*tls.Config, error) {
return nil, fmt.Errorf("TLS未配置")
}
// checkIPAccess 检查IP是否允许访问
func (s *Server) checkIPAccess(clientIP string) error {
// checkIPAccess 检查IP是否允许访问username为空时只检查全局规则
func (s *Server) checkIPAccess(clientIP, username string) error {
if s.db == nil {
return nil
}
rules, err := s.db.GetEnabledIPRules()
rules, err := s.db.GetEnabledIPRules(username)
if err != nil {
return nil // 查询失败时允许连接
}
var whitelists, blacklists []database.IPAccessRule
// 分离全局规则和用户规则
var globalWhitelist, globalBlacklist []database.IPAccessRule
var userWhitelist, userBlacklist []database.IPAccessRule
for _, rule := range rules {
if rule.Type == "whitelist" {
whitelists = append(whitelists, rule)
} else if rule.Type == "blacklist" {
blacklists = append(blacklists, rule)
if rule.Username == "" {
if rule.Type == "whitelist" {
globalWhitelist = append(globalWhitelist, rule)
} else {
globalBlacklist = append(globalBlacklist, rule)
}
} else {
if rule.Type == "whitelist" {
userWhitelist = append(userWhitelist, rule)
} else {
userBlacklist = append(userBlacklist, rule)
}
}
}
// 如果有白名单规则,只允许白名单中的IP
if len(whitelists) > 0 {
// 1. 先检查全局黑名单
for _, rule := range globalBlacklist {
if matchIP(clientIP, rule.IP) {
return fmt.Errorf("IP已被全局黑名单拦截")
}
}
// 2. 检查全局白名单(如果有全局白名单,必须在其中)
if len(globalWhitelist) > 0 {
matched := false
for _, rule := range whitelists {
for _, rule := range globalWhitelist {
if matchIP(clientIP, rule.IP) {
matched = true
break
}
}
if !matched {
return fmt.Errorf("IP不在白名单中")
return fmt.Errorf("IP不在全局白名单中")
}
}
// 检查黑名单
for _, rule := range blacklists {
// 3. 检查用户黑名单
for _, rule := range userBlacklist {
if matchIP(clientIP, rule.IP) {
return fmt.Errorf("IP已被列入黑名单")
return fmt.Errorf("IP已被用户黑名单拦截")
}
}
// 4. 检查用户白名单(如果有用户白名单,必须在其中)
if len(userWhitelist) > 0 {
matched := false
for _, rule := range userWhitelist {
if matchIP(clientIP, rule.IP) {
matched = true
break
}
}
if !matched {
return fmt.Errorf("IP不在用户白名单中")
}
}