package ftp

import (
	"bytes"
	"crypto/subtle"
	"crypto/tls"
	"fmt"
	"log"
	"net"
	"os"
	"strings"
	"sync"
	"time"

	"ftp-server/config"
	"ftp-server/database"

	ftpserver "github.com/fclairamb/ftpserverlib"
	"github.com/spf13/afero"
)

// Server is the FTP server. It implements ftpserverlib's MainDriver
// interface and keeps track of the users that are currently logged in.
type Server struct {
	config    *config.Config
	db        *database.DB
	ftpServer *ftpserver.FtpServer

	// onlineMu guards onlineUsers.
	onlineMu sync.RWMutex
	// onlineUsers is keyed by username + "_" + remote address (host:port).
	onlineUsers map[string]*database.OnlineUser
}

// NewServer creates an FTP server that reads its settings from cfg and
// authenticates users against db. Call Start to begin serving.
func NewServer(cfg *config.Config, db *database.DB) *Server {
	return &Server{
		config:      cfg,
		db:          db,
		onlineUsers: make(map[string]*database.OnlineUser),
	}
}

// Start creates the FTP root directory if needed, binds the listening
// socket, and then serves clients in a background goroutine.
//
// Binding is done synchronously so that problems such as an address that is
// already in use are returned to the caller instead of only being logged
// from the background goroutine.
func (s *Server) Start() error {
	ftpCfg := s.config.Get().FTP

	// Make sure the FTP root directory exists.
	if err := os.MkdirAll(ftpCfg.RootDir, 0755); err != nil {
		return fmt.Errorf("创建FTP根目录失败: %w", err)
	}

	server := ftpserver.NewFtpServer(s)
	if err := server.Listen(); err != nil {
		return fmt.Errorf("FTP服务器监听失败: %w", err)
	}
	s.ftpServer = server

	go func() {
		if err := server.Serve(); err != nil {
			log.Printf("FTP服务器错误: %v", err)
		}
	}()

	log.Printf("FTP服务器已启动: %s:%d", ftpCfg.Host, ftpCfg.Port)
	return nil
}

// Stop stops the FTP server if it was started.
func (s *Server) Stop() {
	if s.ftpServer != nil {
		s.ftpServer.Stop()
		log.Println("FTP服务器已停止")
	}
}

// GetOnlineUsers returns a snapshot copy of the currently logged-in users.
// The order of the result is unspecified (map iteration order).
func (s *Server) GetOnlineUsers() []database.OnlineUser {
	s.onlineMu.RLock()
	defer s.onlineMu.RUnlock()

	result := make([]database.OnlineUser, 0, len(s.onlineUsers))
	for _, u := range s.onlineUsers {
		result = append(result, *u)
	}
	return result
}

// --- ftpserverlib.MainDriver implementation ---

// GetSettings returns the FTP server settings built from the current config.
func (s *Server) GetSettings() (*ftpserver.Settings, error) {
	ftpCfg := s.config.Get().FTP
	return &ftpserver.Settings{
		// JoinHostPort brackets IPv6 literals ("[::1]:21"); a plain
		// "%s:%d" format would produce an unparsable address for them.
		ListenAddr: net.JoinHostPort(ftpCfg.Host, fmt.Sprint(ftpCfg.Port)),
		PassiveTransferPortRange: ftpserver.PortRange{
			Start: ftpCfg.PassivePortMin,
			End:   ftpCfg.PassivePortMax,
		},
		// ftpserverlib timeouts are whole seconds. The previous code passed
		// a time.Duration (nanoseconds) into ConnectionTimeout, which the
		// library multiplies by time.Second again and overflows. The config
		// value is an idle timeout, so map it to IdleTimeout.
		// NOTE(review): assumes ftpCfg.IdleTimeout is in seconds — confirm.
		IdleTimeout: int(ftpCfg.IdleTimeout),
	}, nil
}

// ClientConnected is called when a client opens a control connection. Only
// global IP rules are checked here; per-user rules are checked in AuthUser.
// The returned string is the welcome message. ftpserverlib adds the 220
// status code itself, so it must not be included here.
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
	clientIP := remoteIPOf(cc)
	if err := s.checkIPAccess(clientIP, ""); err != nil {
		log.Printf("IP %s 被拒绝连接(全局规则): %v", clientIP, err)
		return "", fmt.Errorf("连接被拒绝: %w", err)
	}
	return "Welcome to FTP Server", nil
}

// ClientDisconnected removes every online-user entry recorded for the
// disconnecting client's remote address.
func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) {
	addr := cc.RemoteAddr().String()

	s.onlineMu.Lock()
	defer s.onlineMu.Unlock()
	// One control connection may have logged in more than once (for example
	// as different users), so remove all matching entries, not just the
	// first. Deleting during range over a map is safe in Go.
	for key, u := range s.onlineUsers {
		if u.IP == addr {
			delete(s.onlineUsers, key)
		}
	}
}

// AuthUser authenticates a login and returns the file system the session
// is confined to.
//
// "anonymous" is accepted only when anonymous access is enabled and is
// confined to the FTP root. Other users are looked up in the database,
// checked against per-user IP rules, and confined to their home directory
// (created on demand). Users whose Permissions is "read" get a read-only
// file system.
func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string) (ftpserver.ClientDriver, error) {
	ftpCfg := s.config.Get().FTP
	remoteAddr := cc.RemoteAddr().String()

	// Anonymous login.
	// NOTE(review): anonymous sessions get a writable file system, are not
	// checked against per-user IP rules, and are not recorded as online or
	// in the login log — confirm this is intended.
	if username == "anonymous" {
		if !ftpCfg.EnableAnonymous {
			return nil, fmt.Errorf("匿名访问未启用")
		}
		if err := os.MkdirAll(ftpCfg.RootDir, 0755); err != nil {
			log.Printf("创建根目录失败: %v", err)
			return nil, fmt.Errorf("创建根目录失败")
		}
		boundedFs := afero.NewBasePathFs(afero.NewOsFs(), ftpCfg.RootDir)
		return newLoggingFs(boundedFs, s.db, "anonymous"), nil
	}

	// Database-backed authentication.
	user, err := s.db.GetUser(username)
	if err != nil {
		log.Printf("查询用户 %s 失败: %v", username, err)
		return nil, fmt.Errorf("认证失败")
	}
	if user == nil || !user.Enabled {
		return nil, fmt.Errorf("用户不存在或已禁用")
	}
	// Constant-time comparison, so response timing does not reveal how much
	// of the password matched. Only the length can still leak.
	if subtle.ConstantTimeCompare([]byte(user.Password), []byte(password)) != 1 {
		s.db.AddLog(&database.FTPLog{
			Username: username,
			IP:       remoteAddr,
			Action:   "login_failed",
			Status:   "failed",
		})
		return nil, fmt.Errorf("密码错误")
	}

	// Per-user IP rules.
	clientIP := remoteIPOf(cc)
	if err := s.checkIPAccess(clientIP, username); err != nil {
		log.Printf("用户 %s IP %s 被拒绝: %v", username, clientIP, err)
		s.db.AddLog(&database.FTPLog{
			Username: username,
			IP:       remoteAddr,
			Action:   "login_blocked",
			Status:   "blocked",
		})
		return nil, fmt.Errorf("登录被拒绝: %w", err)
	}

	// Create the home directory before the login is reported as successful.
	// Otherwise a failure here leaves a phantom online entry and a
	// misleading success log.
	if err := os.MkdirAll(user.HomeDir, 0755); err != nil {
		return nil, fmt.Errorf("创建用户目录失败: %w", err)
	}

	s.db.AddLog(&database.FTPLog{
		Username: username,
		IP:       remoteAddr,
		Action:   "login",
		Status:   "success",
	})

	now := time.Now()
	s.onlineMu.Lock()
	s.onlineUsers[username+"_"+remoteAddr] = &database.OnlineUser{
		Username:     username,
		IP:           remoteAddr,
		LoginTime:    now,
		LastActivity: now,
		CurrentDir:   user.HomeDir,
	}
	s.onlineMu.Unlock()

	// Confine the session to the home directory and wrap it with logging.
	boundedFs := afero.NewBasePathFs(afero.NewOsFs(), user.HomeDir)
	loggedFs := newLoggingFs(boundedFs, s.db, username)

	if user.Permissions == "read" {
		return afero.NewReadOnlyFs(loggedFs), nil
	}
	return loggedFs, nil
}

// GetTLSConfig always reports that TLS is not configured.
func (s *Server) GetTLSConfig() (*tls.Config, error) {
	return nil, fmt.Errorf("TLS未配置")
}

// remoteIPOf returns the client's IP address without the port. If the remote
// address cannot be split into host and port, or the host part is empty, the
// full address string is returned instead.
func remoteIPOf(cc ftpserver.ClientContext) string {
	addr := cc.RemoteAddr().String()
	host, _, err := net.SplitHostPort(addr)
	if err != nil || host == "" {
		return addr
	}
	return host
}

// checkIPAccess reports whether clientIP may access the server. With an
// empty username, only global rules (rules whose Username is "") take
// effect.
//
// Evaluation order:
//  1. global blacklist
//  2. global whitelist (if any exist, the IP must match one)
//  3. user blacklist
//  4. user whitelist (if any exist, the IP must match one)
//
// Any rule whose Type is not "whitelist" is treated as a blacklist rule. If
// there is no database, or the rule lookup fails, access is allowed.
func (s *Server) checkIPAccess(clientIP, username string) error {
	if s.db == nil {
		return nil
	}
	rules, err := s.db.GetEnabledIPRules(username)
	if err != nil {
		// Deliberately fail open, but leave a trace so that a broken rules
		// table does not silently disable access control.
		log.Printf("查询IP规则失败, 允许访问: %v", err)
		return nil
	}

	// Split rules into global and per-user sets.
	var globalWhitelist, globalBlacklist []database.IPAccessRule
	var userWhitelist, userBlacklist []database.IPAccessRule
	for _, rule := range rules {
		global := rule.Username == ""
		whitelist := rule.Type == "whitelist"
		switch {
		case global && whitelist:
			globalWhitelist = append(globalWhitelist, rule)
		case global:
			globalBlacklist = append(globalBlacklist, rule)
		case whitelist:
			userWhitelist = append(userWhitelist, rule)
		default:
			userBlacklist = append(userBlacklist, rule)
		}
	}

	if matchAnyIPRule(clientIP, globalBlacklist) {
		return fmt.Errorf("IP已被全局黑名单拦截")
	}
	if len(globalWhitelist) > 0 && !matchAnyIPRule(clientIP, globalWhitelist) {
		return fmt.Errorf("IP不在全局白名单中")
	}
	if matchAnyIPRule(clientIP, userBlacklist) {
		return fmt.Errorf("IP已被用户黑名单拦截")
	}
	if len(userWhitelist) > 0 && !matchAnyIPRule(clientIP, userWhitelist) {
		return fmt.Errorf("IP不在用户白名单中")
	}
	return nil
}

// matchAnyIPRule reports whether clientIP matches at least one rule's IP pattern.
func matchAnyIPRule(clientIP string, rules []database.IPAccessRule) bool {
	for _, rule := range rules {
		if matchIP(clientIP, rule.IP) {
			return true
		}
	}
	return false
}

// matchIP reports whether clientIP matches rule. A rule may be:
//   - a CIDR block, e.g. "192.168.1.0/24";
//   - an inclusive range, e.g. "192.168.1.1-192.168.1.100";
//   - a single address, e.g. "10.0.0.1".
//
// A rule containing "/" is always treated as CIDR. A malformed CIDR rule
// falls back to exact string equality.
func matchIP(clientIP, rule string) bool {
	rule = strings.TrimSpace(rule)
	ip := net.ParseIP(clientIP)

	switch {
	case strings.Contains(rule, "/"):
		_, ipNet, err := net.ParseCIDR(rule)
		if err != nil {
			return clientIP == rule
		}
		return ip != nil && ipNet.Contains(ip)

	case strings.Contains(rule, "-"):
		parts := strings.SplitN(rule, "-", 2)
		startIP := net.ParseIP(strings.TrimSpace(parts[0]))
		endIP := net.ParseIP(strings.TrimSpace(parts[1]))
		if startIP == nil || endIP == nil || ip == nil {
			return false
		}
		return bytesCompare(ip, startIP) >= 0 && bytesCompare(ip, endIP) <= 0

	default:
		// Compare parsed addresses so that equivalent spellings match, e.g.
		// "::1" vs "0:0:0:0:0:0:0:1", or an IPv4 client reported as
		// "::ffff:1.2.3.4" by a dual-stack listener.
		if ruleIP := net.ParseIP(rule); ruleIP != nil && ip != nil {
			return ip.Equal(ruleIP)
		}
		return clientIP == rule
	}
}

// bytesCompare compares two IPs in their 16-byte form and returns -1, 0 or
// +1. To16 maps IPv4 addresses into the IPv4-mapped IPv6 range, so an IPv4
// address and its IPv4-mapped IPv6 form compare equal.
func bytesCompare(a, b net.IP) int {
	return bytes.Compare(a.To16(), b.To16())
}