Files
FTP-Server/ftp/server.go
T

339 lines
8.0 KiB
Go

package ftp
import (
"crypto/tls"
"fmt"
"log"
"net"
"os"
"strings"
"sync"
"time"
"ftp-server/config"
"ftp-server/database"
ftpserver "github.com/fclairamb/ftpserverlib"
"github.com/spf13/afero"
)
// Server is the FTP server and the ftpserverlib main driver: Start passes it to
// ftpserver.NewFtpServer. It authenticates users against the database and keeps
// a table of logged-in sessions.
type Server struct {
config *config.Config
db *database.DB
// ftpServer is the library server created by Start; nil until Start has run.
ftpServer *ftpserver.FtpServer
// onlineMu guards onlineUsers.
onlineMu sync.RWMutex
// onlineUsers maps "<username>_<remote addr>" to the session recorded at login.
onlineUsers map[string]*database.OnlineUser
}
// NewServer builds a Server for the given configuration and database handle.
// The returned server has an empty online-user table and does not listen for
// connections until Start is called.
func NewServer(cfg *config.Config, db *database.DB) *Server {
	srv := new(Server)
	srv.config = cfg
	srv.db = db
	srv.onlineUsers = map[string]*database.OnlineUser{}
	return srv
}
// Start makes sure the FTP root directory exists, binds the listening socket
// and then serves connections in a background goroutine.
//
// Binding is done synchronously so that failures such as "address already in
// use" are returned to the caller. Previously ListenAndServe ran entirely in
// the goroutine, so such errors were only logged while Start still reported
// success. Errors from the background accept loop are only logged.
func (s *Server) Start() error {
	ftpCfg := s.config.Get().FTP

	// Make sure the FTP root directory exists.
	if err := os.MkdirAll(ftpCfg.RootDir, 0755); err != nil {
		return fmt.Errorf("创建FTP根目录失败: %w", err)
	}

	server := ftpserver.NewFtpServer(s)
	if err := server.Listen(); err != nil {
		return fmt.Errorf("FTP服务器监听失败: %w", err)
	}
	// Only publish the server once it is actually listening, so Stop does not
	// act on a server that never bound its socket.
	s.ftpServer = server

	go func() {
		if err := server.Serve(); err != nil {
			log.Printf("FTP服务器错误: %v", err)
		}
	}()

	log.Printf("FTP服务器已启动: %s", net.JoinHostPort(ftpCfg.Host, fmt.Sprint(ftpCfg.Port)))
	return nil
}
// Stop shuts down the FTP server created by Start and logs that it stopped.
// It does nothing when no server has been created.
func (s *Server) Stop() {
	srv := s.ftpServer
	if srv == nil {
		return
	}
	srv.Stop()
	log.Println("FTP服务器已停止")
}
// GetOnlineUsers returns a snapshot of the recorded online sessions. The
// returned values are copies of the stored structs, taken under the read lock.
// Order is unspecified (it follows map iteration order).
func (s *Server) GetOnlineUsers() []database.OnlineUser {
	s.onlineMu.RLock()
	defer s.onlineMu.RUnlock()

	snapshot := make([]database.OnlineUser, len(s.onlineUsers))
	i := 0
	for _, entry := range s.onlineUsers {
		snapshot[i] = *entry
		i++
	}
	return snapshot
}
// --- ftpserverlib.MainDriver implementation ---

// GetSettings returns the ftpserverlib settings derived from the current FTP
// configuration: listen address, passive port range and idle timeout.
func (s *Server) GetSettings() (*ftpserver.Settings, error) {
	ftpCfg := s.config.Get().FTP
	return &ftpserver.Settings{
		// JoinHostPort brackets IPv6 hosts ("::" -> "[::]:21"); plain
		// "%s:%d" formatting produced an unusable address for them.
		ListenAddr: net.JoinHostPort(ftpCfg.Host, fmt.Sprint(ftpCfg.Port)),
		PassiveTransferPortRange: ftpserver.PortRange{
			Start: ftpCfg.PassivePortMin,
			End:   ftpCfg.PassivePortMax,
		},
		// ftpserverlib timeouts are integer second counts. The configured
		// idle timeout (seconds) maps directly onto the library's IdleTimeout.
		// The previous code converted it to nanoseconds via time.Duration and
		// stored that in ConnectionTimeout, which yielded absurd (and for some
		// values overflowing) data-connection deadlines. ConnectionTimeout now
		// keeps the library default.
		IdleTimeout: int(ftpCfg.IdleTimeout),
	}, nil
}
// ClientConnected is called for each new control connection and returns the
// welcome banner text. Only global IP rules are checked here because the
// username is not known yet; per-user rules are checked later in AuthUser.
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
	remoteAddr := cc.RemoteAddr().String()
	clientIP, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		// Not in host:port form; fall back to the raw address string.
		clientIP = remoteAddr
	}

	if err := s.checkIPAccess(clientIP, ""); err != nil {
		log.Printf("IP %s 被拒绝连接(全局规则): %v", clientIP, err)
		return "", fmt.Errorf("连接被拒绝: %w", err)
	}

	// ftpserverlib sends the banner itself, prefixed with the 220 status code
	// and terminated with CRLF. Returning "220 ...\r\n" here made clients see
	// a doubled "220 220 ..." greeting, so only the text is returned.
	return "Welcome to FTP Server", nil
}
// ClientDisconnected removes every online-user entry recorded for the
// disconnecting client's remote address.
func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) {
	remoteAddr := cc.RemoteAddr().String()

	s.onlineMu.Lock()
	defer s.onlineMu.Unlock()
	// Entries are keyed by "<username>_<remote addr>" and the username is not
	// available here, so scan by address. Do not stop at the first match: if
	// AuthUser succeeded more than once on this connection with different
	// usernames, several entries share this address. Deleting map entries
	// while ranging over the map is safe in Go.
	for key, u := range s.onlineUsers {
		if u.IP == remoteAddr {
			delete(s.onlineUsers, key)
		}
	}
}
// AuthUser authenticates a client and returns the afero filesystem it will
// operate on.
//
// "anonymous" is accepted only when EnableAnonymous is set and is confined to
// the FTP root directory; anonymous logins are not logged or tracked as online.
//
// Any other user must pass these checks, in order:
//  1. exist in the database and be enabled;
//  2. pass the global and per-user IP rules;
//  3. supply the stored password.
//
// On success the login is logged and recorded as an online session. The user
// is confined to their home directory, which is created if missing. The
// filesystem is read-only when Permissions is "read".
func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string) (ftpserver.ClientDriver, error) {
	ftpCfg := s.config.Get().FTP
	remoteAddr := cc.RemoteAddr().String()

	// Anonymous login.
	if username == "anonymous" {
		if !ftpCfg.EnableAnonymous {
			return nil, fmt.Errorf("匿名访问未启用")
		}
		if err := os.MkdirAll(ftpCfg.RootDir, 0755); err != nil {
			return nil, fmt.Errorf("创建根目录失败: %w", err)
		}
		return afero.NewBasePathFs(afero.NewOsFs(), ftpCfg.RootDir), nil
	}

	// Database-backed authentication.
	user, err := s.db.GetUser(username)
	if err != nil {
		// Keep the client-facing message generic but make the cause visible.
		log.Printf("查询用户 %q 失败: %v", username, err)
		return nil, fmt.Errorf("认证失败")
	}
	if user == nil || !user.Enabled {
		return nil, fmt.Errorf("用户不存在或已禁用")
	}

	// Enforce IP rules BEFORE verifying the password. Otherwise a client from a
	// blocked address gets "password wrong" vs "login blocked" replies and can
	// use them to confirm guessed passwords.
	clientIP, _, splitErr := net.SplitHostPort(remoteAddr)
	if splitErr != nil {
		clientIP = remoteAddr
	}
	if err := s.checkIPAccess(clientIP, username); err != nil {
		log.Printf("用户 %s IP %s 被拒绝: %v", username, clientIP, err)
		s.db.AddLog(&database.FTPLog{
			Username: username,
			IP:       remoteAddr,
			Action:   "login_blocked",
			Status:   "blocked",
		})
		return nil, fmt.Errorf("登录被拒绝: %w", err)
	}

	// NOTE(review): user.Password is compared directly with the client-supplied
	// password, which implies passwords are stored in plaintext. Consider storing
	// hashes (e.g. bcrypt), or at least comparing with
	// crypto/subtle.ConstantTimeCompare.
	if user.Password != password {
		s.db.AddLog(&database.FTPLog{
			Username: username,
			IP:       remoteAddr,
			Action:   "login_failed",
			Status:   "failed",
		})
		return nil, fmt.Errorf("密码错误")
	}

	// Create the home directory before recording the login, so a failure here
	// does not leave a "success" log entry and a phantom online session.
	if err := os.MkdirAll(user.HomeDir, 0755); err != nil {
		return nil, fmt.Errorf("创建用户目录失败: %w", err)
	}

	// Record the successful login.
	s.db.AddLog(&database.FTPLog{
		Username: username,
		IP:       remoteAddr,
		Action:   "login",
		Status:   "success",
	})

	// Track the session; ClientDisconnected removes it by remote address.
	now := time.Now()
	s.onlineMu.Lock()
	s.onlineUsers[username+"_"+remoteAddr] = &database.OnlineUser{
		Username:     username,
		IP:           remoteAddr,
		LoginTime:    now,
		LastActivity: now,
		CurrentDir:   user.HomeDir,
	}
	s.onlineMu.Unlock()

	// The afero.Fs rooted at the home directory serves as the ClientDriver.
	homeFs := afero.NewBasePathFs(afero.NewOsFs(), user.HomeDir)
	if user.Permissions == "read" {
		return afero.NewReadOnlyFs(homeFs), nil
	}
	return homeFs, nil
}
// GetTLSConfig always reports that TLS is unavailable. This driver provides no
// TLS configuration, so it returns a nil config together with an error.
func (s *Server) GetTLSConfig() (*tls.Config, error) {
	var noTLS *tls.Config
	return noTLS, fmt.Errorf("TLS未配置")
}
// checkIPAccess decides whether clientIP may access the server. It returns nil
// when access is allowed and a descriptive error when it is denied.
//
// The rules come from GetEnabledIPRules(username). A rule with an empty
// Username is global; any other rule is a per-user rule. With an empty
// username the intent is to apply only global rules, presumably because
// GetEnabledIPRules("") returns only those (confirm against the database
// layer).
//
// Evaluation order:
//  1. global blacklist - any match denies
//  2. global whitelist - if non-empty, the IP must match one entry
//  3. user blacklist   - any match denies
//  4. user whitelist   - if non-empty, the IP must match one entry
//
// A rule whose Type is not "whitelist" is treated as a blacklist entry.
// The check fails open: with no database, or when the rules cannot be loaded,
// access is allowed.
func (s *Server) checkIPAccess(clientIP, username string) error {
	if s.db == nil {
		return nil
	}
	rules, err := s.db.GetEnabledIPRules(username)
	if err != nil {
		// Deliberately allow the connection so a rules-table problem does not
		// lock everyone out, but do not hide the failure.
		log.Printf("查询IP规则失败, 已放行 %s: %v", clientIP, err)
		return nil
	}

	// Split into global/user whitelist/blacklist buckets.
	var globalWhitelist, globalBlacklist, userWhitelist, userBlacklist []database.IPAccessRule
	for _, rule := range rules {
		isWhitelist := rule.Type == "whitelist"
		switch {
		case rule.Username == "" && isWhitelist:
			globalWhitelist = append(globalWhitelist, rule)
		case rule.Username == "":
			globalBlacklist = append(globalBlacklist, rule)
		case isWhitelist:
			userWhitelist = append(userWhitelist, rule)
		default:
			userBlacklist = append(userBlacklist, rule)
		}
	}

	// anyMatch reports whether clientIP matches at least one rule in list.
	anyMatch := func(list []database.IPAccessRule) bool {
		for _, rule := range list {
			if matchIP(clientIP, rule.IP) {
				return true
			}
		}
		return false
	}

	if anyMatch(globalBlacklist) {
		return fmt.Errorf("IP已被全局黑名单拦截")
	}
	if len(globalWhitelist) > 0 && !anyMatch(globalWhitelist) {
		return fmt.Errorf("IP不在全局白名单中")
	}
	if anyMatch(userBlacklist) {
		return fmt.Errorf("IP已被用户黑名单拦截")
	}
	if len(userWhitelist) > 0 && !anyMatch(userWhitelist) {
		return fmt.Errorf("IP不在用户白名单中")
	}
	return nil
}
// matchIP reports whether clientIP matches an access rule. A rule is one of:
//   - a single address, e.g. "192.168.1.10" or "2001:db8::1"
//   - a CIDR block, e.g. "192.168.1.0/24"
//   - an inclusive range, e.g. "192.168.1.1-192.168.1.100"
//
// Surrounding whitespace in the rule is ignored.
//
// Single addresses are compared after parsing, so different spellings of the
// same address match (e.g. "2001:DB8::1" vs "2001:db8::1", or an IPv4-mapped
// IPv6 form). If either side does not parse, exact string comparison is used.
// An unparsable CIDR also falls back to string comparison. A range whose
// endpoints (or clientIP) do not parse never matches.
func matchIP(clientIP, rule string) bool {
	rule = strings.TrimSpace(rule)
	ip := net.ParseIP(clientIP)

	switch {
	case strings.Contains(rule, "/"):
		// CIDR notation.
		_, ipNet, err := net.ParseCIDR(rule)
		if err != nil {
			return clientIP == rule
		}
		return ip != nil && ipNet.Contains(ip)

	case strings.Contains(rule, "-"):
		// Inclusive range "start-end".
		parts := strings.SplitN(rule, "-", 2)
		startIP := net.ParseIP(strings.TrimSpace(parts[0]))
		endIP := net.ParseIP(strings.TrimSpace(parts[1]))
		if startIP == nil || endIP == nil || ip == nil {
			return false
		}
		return bytesCompare(ip, startIP) >= 0 && bytesCompare(ip, endIP) <= 0

	default:
		// Single address.
		if ruleIP := net.ParseIP(rule); ip != nil && ruleIP != nil {
			return ip.Equal(ruleIP)
		}
		return clientIP == rule
	}
}
// bytesCompare orders two IP addresses by their 16-byte representation and
// returns -1, 0 or +1. IPv4 addresses are compared in their IPv4-mapped IPv6
// form, so "10.0.0.1" and "::ffff:10.0.0.1" compare equal.
//
// A value that is not a valid IP (To16 returns nil) sorts before every valid
// address. Previously an invalid b caused an index-out-of-range panic and an
// invalid a compared equal to anything.
func bytesCompare(a, b net.IP) int {
	a, b = a.To16(), b.To16()

	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	for i := 0; i < n; i++ {
		switch {
		case a[i] < b[i]:
			return -1
		case a[i] > b[i]:
			return 1
		}
	}

	// Common prefix is equal; the shorter (invalid, nil) value sorts first.
	switch {
	case len(a) < len(b):
		return -1
	case len(a) > len(b):
		return 1
	}
	return 0
}