| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339 |
- package ftp
import (
	"crypto/subtle"
	"crypto/tls"
	"fmt"
	"log"
	"net"
	"os"
	"strings"
	"sync"
	"time"

	"ftp-server/config"
	"ftp-server/database"

	ftpserver "github.com/fclairamb/ftpserverlib"
	"github.com/spf13/afero"
)
// Server is the FTP server. Its methods implement the ftpserverlib MainDriver
// callbacks (GetSettings, ClientConnected, ClientDisconnected, AuthUser,
// GetTLSConfig), and it tracks logged-in sessions for GetOnlineUsers.
type Server struct {
	// config supplies the FTP settings; read through config.Get().FTP.
	config *config.Config
	// db provides user lookup, FTP logging and IP access rules.
	// checkIPAccess tolerates a nil db; AuthUser calls it without a nil check.
	db *database.DB
	// ftpServer is the ftpserverlib instance created by Start and stopped by
	// Stop; nil until Start has run.
	ftpServer *ftpserver.FtpServer
	// onlineMu guards onlineUsers.
	onlineMu sync.RWMutex
	// onlineUsers holds active sessions keyed by username + "_" + remote
	// address (as written by AuthUser); entries are removed in
	// ClientDisconnected.
	onlineUsers map[string]*database.OnlineUser
}
- // NewServer 创建FTP服务器
- func NewServer(cfg *config.Config, db *database.DB) *Server {
- return &Server{
- config: cfg,
- db: db,
- onlineUsers: make(map[string]*database.OnlineUser),
- }
- }
- // Start 启动FTP服务器
- func (s *Server) Start() error {
- ftpCfg := s.config.Get().FTP
- // 确保FTP根目录存在
- if err := os.MkdirAll(ftpCfg.RootDir, 0755); err != nil {
- return fmt.Errorf("创建FTP根目录失败: %w", err)
- }
- server := ftpserver.NewFtpServer(s)
- s.ftpServer = server
- go func() {
- if err := server.ListenAndServe(); err != nil {
- log.Printf("FTP服务器错误: %v", err)
- }
- }()
- log.Printf("FTP服务器已启动: %s:%d", ftpCfg.Host, ftpCfg.Port)
- return nil
- }
- // Stop 停止FTP服务器
- func (s *Server) Stop() {
- if s.ftpServer != nil {
- s.ftpServer.Stop()
- log.Println("FTP服务器已停止")
- }
- }
- // GetOnlineUsers 获取在线用户列表
- func (s *Server) GetOnlineUsers() []database.OnlineUser {
- s.onlineMu.RLock()
- defer s.onlineMu.RUnlock()
- result := make([]database.OnlineUser, 0, len(s.onlineUsers))
- for _, u := range s.onlineUsers {
- result = append(result, *u)
- }
- return result
- }
- // --- 实现 ftpserverlib.MainDriver 接口 ---
- // GetSettings 返回FTP服务器设置
- func (s *Server) GetSettings() (*ftpserver.Settings, error) {
- ftpCfg := s.config.Get().FTP
- return &ftpserver.Settings{
- ListenAddr: fmt.Sprintf("%s:%d", ftpCfg.Host, ftpCfg.Port),
- PassiveTransferPortRange: ftpserver.PortRange{
- Start: ftpCfg.PassivePortMin,
- End: ftpCfg.PassivePortMax,
- },
- ConnectionTimeout: int(time.Duration(ftpCfg.IdleTimeout) * time.Second),
- }, nil
- }
- // ClientConnected 客户端连接(只检查全局规则)
- func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
- clientIP, _, err := net.SplitHostPort(cc.RemoteAddr().String())
- if err != nil {
- clientIP = cc.RemoteAddr().String()
- }
- if err := s.checkIPAccess(clientIP, ""); err != nil {
- log.Printf("IP %s 被拒绝连接(全局规则): %v", clientIP, err)
- return "", fmt.Errorf("连接被拒绝: %s", err)
- }
- return "220 Welcome to FTP Server\r\n", nil
- }
- // ClientDisconnected 客户端断开
- func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) {
- s.onlineMu.Lock()
- defer s.onlineMu.Unlock()
- for id, u := range s.onlineUsers {
- if u.IP == cc.RemoteAddr().String() {
- delete(s.onlineUsers, id)
- break
- }
- }
- }
- // AuthUser 认证用户
- func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string) (ftpserver.ClientDriver, error) {
- ftpCfg := s.config.Get().FTP
- // 匿名登录
- if username == "anonymous" {
- if !ftpCfg.EnableAnonymous {
- return nil, fmt.Errorf("匿名访问未启用")
- }
- if err := os.MkdirAll(ftpCfg.RootDir, 0755); err != nil {
- return nil, fmt.Errorf("创建根目录失败")
- }
- osFs := afero.NewOsFs()
- boundedFs := afero.NewBasePathFs(osFs, ftpCfg.RootDir)
- return newLoggingFs(boundedFs, s.db, "anonymous"), nil
- }
- // 数据库用户认证
- user, err := s.db.GetUser(username)
- if err != nil {
- return nil, fmt.Errorf("认证失败")
- }
- if user == nil || !user.Enabled {
- return nil, fmt.Errorf("用户不存在或已禁用")
- }
- if user.Password != password {
- s.db.AddLog(&database.FTPLog{
- Username: username,
- IP: cc.RemoteAddr().String(),
- Action: "login_failed",
- Status: "failed",
- })
- return nil, fmt.Errorf("密码错误")
- }
- // 检查用户级别IP规则
- clientIP, _, _ := net.SplitHostPort(cc.RemoteAddr().String())
- if clientIP == "" {
- clientIP = cc.RemoteAddr().String()
- }
- if err := s.checkIPAccess(clientIP, username); err != nil {
- log.Printf("用户 %s IP %s 被拒绝: %v", username, clientIP, err)
- s.db.AddLog(&database.FTPLog{
- Username: username,
- IP: cc.RemoteAddr().String(),
- Action: "login_blocked",
- Status: "blocked",
- })
- return nil, fmt.Errorf("登录被拒绝: %s", err)
- }
- // 记录登录日志
- s.db.AddLog(&database.FTPLog{
- Username: username,
- IP: cc.RemoteAddr().String(),
- Action: "login",
- Status: "success",
- })
- // 记录在线用户
- s.onlineMu.Lock()
- s.onlineUsers[username+"_"+cc.RemoteAddr().String()] = &database.OnlineUser{
- Username: username,
- IP: cc.RemoteAddr().String(),
- LoginTime: time.Now(),
- LastActivity: time.Now(),
- CurrentDir: user.HomeDir,
- }
- s.onlineMu.Unlock()
- // 确保用户目录存在(自动创建)
- if err := os.MkdirAll(user.HomeDir, 0755); err != nil {
- return nil, fmt.Errorf("创建用户目录失败: %v", err)
- }
- // 返回 afero.Fs 作为 ClientDriver(带日志包装)
- osFs := afero.NewOsFs()
- boundedFs := afero.NewBasePathFs(osFs, user.HomeDir)
- loggedFs := newLoggingFs(boundedFs, s.db, username)
- // 根据权限设置只读
- if user.Permissions == "read" {
- return afero.NewReadOnlyFs(loggedFs), nil
- }
- return loggedFs, nil
- }
- // GetTLSConfig 获取TLS配置
- func (s *Server) GetTLSConfig() (*tls.Config, error) {
- return nil, fmt.Errorf("TLS未配置")
- }
- // checkIPAccess 检查IP是否允许访问,username为空时只检查全局规则
- func (s *Server) checkIPAccess(clientIP, username string) error {
- if s.db == nil {
- return nil
- }
- rules, err := s.db.GetEnabledIPRules(username)
- if err != nil {
- return nil // 查询失败时允许连接
- }
- // 分离全局规则和用户规则
- var globalWhitelist, globalBlacklist []database.IPAccessRule
- var userWhitelist, userBlacklist []database.IPAccessRule
- for _, rule := range rules {
- if rule.Username == "" {
- if rule.Type == "whitelist" {
- globalWhitelist = append(globalWhitelist, rule)
- } else {
- globalBlacklist = append(globalBlacklist, rule)
- }
- } else {
- if rule.Type == "whitelist" {
- userWhitelist = append(userWhitelist, rule)
- } else {
- userBlacklist = append(userBlacklist, rule)
- }
- }
- }
- // 1. 先检查全局黑名单
- for _, rule := range globalBlacklist {
- if matchIP(clientIP, rule.IP) {
- return fmt.Errorf("IP已被全局黑名单拦截")
- }
- }
- // 2. 检查全局白名单(如果有全局白名单,必须在其中)
- if len(globalWhitelist) > 0 {
- matched := false
- for _, rule := range globalWhitelist {
- if matchIP(clientIP, rule.IP) {
- matched = true
- break
- }
- }
- if !matched {
- return fmt.Errorf("IP不在全局白名单中")
- }
- }
- // 3. 检查用户黑名单
- for _, rule := range userBlacklist {
- if matchIP(clientIP, rule.IP) {
- return fmt.Errorf("IP已被用户黑名单拦截")
- }
- }
- // 4. 检查用户白名单(如果有用户白名单,必须在其中)
- if len(userWhitelist) > 0 {
- matched := false
- for _, rule := range userWhitelist {
- if matchIP(clientIP, rule.IP) {
- matched = true
- break
- }
- }
- if !matched {
- return fmt.Errorf("IP不在用户白名单中")
- }
- }
- return nil
- }
- // matchIP 检查IP是否匹配规则
- func matchIP(clientIP, rule string) bool {
- // 单个IP
- if !strings.Contains(rule, "/") && !strings.Contains(rule, "-") {
- return clientIP == rule
- }
- // CIDR 表示法 (192.168.1.0/24)
- if strings.Contains(rule, "/") {
- _, ipNet, err := net.ParseCIDR(rule)
- if err != nil {
- return clientIP == rule
- }
- ip := net.ParseIP(clientIP)
- if ip == nil {
- return false
- }
- return ipNet.Contains(ip)
- }
- // IP范围 (192.168.1.1-192.168.1.100)
- if strings.Contains(rule, "-") {
- parts := strings.SplitN(rule, "-", 2)
- startIP := net.ParseIP(strings.TrimSpace(parts[0]))
- endIP := net.ParseIP(strings.TrimSpace(parts[1]))
- ip := net.ParseIP(clientIP)
- if startIP == nil || endIP == nil || ip == nil {
- return false
- }
- return bytesCompare(ip, startIP) >= 0 && bytesCompare(ip, endIP) <= 0
- }
- return false
- }
// bytesCompare orders two IP addresses byte-wise in their 16-byte form, so the
// IPv4 and IPv4-mapped IPv6 spellings of one address compare equal. It returns
// -1, 0 or +1 in the same way as bytes.Compare.
//
// An argument that is not a valid IP (To16 returns nil) is treated as an empty
// byte string. It therefore sorts before every valid address instead of
// causing an index-out-of-range panic.
func bytesCompare(a, b net.IP) int {
	a16, b16 := a.To16(), b.To16()

	n := len(a16)
	if len(b16) < n {
		n = len(b16)
	}
	for i := 0; i < n; i++ {
		if a16[i] != b16[i] {
			if a16[i] < b16[i] {
				return -1
			}
			return 1
		}
	}

	// Common prefix is equal: the shorter (invalid/empty) one sorts first.
	switch {
	case len(a16) < len(b16):
		return -1
	case len(a16) > len(b16):
		return 1
	}
	return 0
}
|