339 lines
8.0 KiB
Go
339 lines
8.0 KiB
Go
package ftp
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"ftp-server/config"
|
|
"ftp-server/database"
|
|
|
|
ftpserver "github.com/fclairamb/ftpserverlib"
|
|
"github.com/spf13/afero"
|
|
)
|
|
|
|
// Server FTP服务器
|
|
type Server struct {
|
|
config *config.Config
|
|
db *database.DB
|
|
ftpServer *ftpserver.FtpServer
|
|
onlineMu sync.RWMutex
|
|
onlineUsers map[string]*database.OnlineUser
|
|
}
|
|
|
|
// NewServer 创建FTP服务器
|
|
func NewServer(cfg *config.Config, db *database.DB) *Server {
|
|
return &Server{
|
|
config: cfg,
|
|
db: db,
|
|
onlineUsers: make(map[string]*database.OnlineUser),
|
|
}
|
|
}
|
|
|
|
// Start 启动FTP服务器
|
|
func (s *Server) Start() error {
|
|
ftpCfg := s.config.Get().FTP
|
|
|
|
// 确保FTP根目录存在
|
|
if err := os.MkdirAll(ftpCfg.RootDir, 0755); err != nil {
|
|
return fmt.Errorf("创建FTP根目录失败: %w", err)
|
|
}
|
|
|
|
server := ftpserver.NewFtpServer(s)
|
|
s.ftpServer = server
|
|
|
|
go func() {
|
|
if err := server.ListenAndServe(); err != nil {
|
|
log.Printf("FTP服务器错误: %v", err)
|
|
}
|
|
}()
|
|
|
|
log.Printf("FTP服务器已启动: %s:%d", ftpCfg.Host, ftpCfg.Port)
|
|
return nil
|
|
}
|
|
|
|
// Stop 停止FTP服务器
|
|
func (s *Server) Stop() {
|
|
if s.ftpServer != nil {
|
|
s.ftpServer.Stop()
|
|
log.Println("FTP服务器已停止")
|
|
}
|
|
}
|
|
|
|
// GetOnlineUsers 获取在线用户列表
|
|
func (s *Server) GetOnlineUsers() []database.OnlineUser {
|
|
s.onlineMu.RLock()
|
|
defer s.onlineMu.RUnlock()
|
|
|
|
result := make([]database.OnlineUser, 0, len(s.onlineUsers))
|
|
for _, u := range s.onlineUsers {
|
|
result = append(result, *u)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// --- ftpserverlib.MainDriver implementation ---
|
|
|
|
// GetSettings 返回FTP服务器设置
|
|
func (s *Server) GetSettings() (*ftpserver.Settings, error) {
|
|
ftpCfg := s.config.Get().FTP
|
|
return &ftpserver.Settings{
|
|
ListenAddr: fmt.Sprintf("%s:%d", ftpCfg.Host, ftpCfg.Port),
|
|
PassiveTransferPortRange: ftpserver.PortRange{
|
|
Start: ftpCfg.PassivePortMin,
|
|
End: ftpCfg.PassivePortMax,
|
|
},
|
|
ConnectionTimeout: int(time.Duration(ftpCfg.IdleTimeout) * time.Second),
|
|
}, nil
|
|
}
|
|
|
|
// ClientConnected 客户端连接(只检查全局规则)
|
|
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
|
|
clientIP, _, err := net.SplitHostPort(cc.RemoteAddr().String())
|
|
if err != nil {
|
|
clientIP = cc.RemoteAddr().String()
|
|
}
|
|
|
|
if err := s.checkIPAccess(clientIP, ""); err != nil {
|
|
log.Printf("IP %s 被拒绝连接(全局规则): %v", clientIP, err)
|
|
return "", fmt.Errorf("连接被拒绝: %s", err)
|
|
}
|
|
|
|
return "220 Welcome to FTP Server\r\n", nil
|
|
}
|
|
|
|
// ClientDisconnected 客户端断开
|
|
func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) {
|
|
s.onlineMu.Lock()
|
|
defer s.onlineMu.Unlock()
|
|
|
|
for id, u := range s.onlineUsers {
|
|
if u.IP == cc.RemoteAddr().String() {
|
|
delete(s.onlineUsers, id)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// AuthUser 认证用户
|
|
func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string) (ftpserver.ClientDriver, error) {
|
|
ftpCfg := s.config.Get().FTP
|
|
|
|
// 匿名登录
|
|
if username == "anonymous" {
|
|
if !ftpCfg.EnableAnonymous {
|
|
return nil, fmt.Errorf("匿名访问未启用")
|
|
}
|
|
if err := os.MkdirAll(ftpCfg.RootDir, 0755); err != nil {
|
|
return nil, fmt.Errorf("创建根目录失败")
|
|
}
|
|
osFs := afero.NewOsFs()
|
|
boundedFs := afero.NewBasePathFs(osFs, ftpCfg.RootDir)
|
|
return boundedFs, nil
|
|
}
|
|
|
|
// 数据库用户认证
|
|
user, err := s.db.GetUser(username)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("认证失败")
|
|
}
|
|
if user == nil || !user.Enabled {
|
|
return nil, fmt.Errorf("用户不存在或已禁用")
|
|
}
|
|
if user.Password != password {
|
|
s.db.AddLog(&database.FTPLog{
|
|
Username: username,
|
|
IP: cc.RemoteAddr().String(),
|
|
Action: "login_failed",
|
|
Status: "failed",
|
|
})
|
|
return nil, fmt.Errorf("密码错误")
|
|
}
|
|
|
|
// 检查用户级别IP规则
|
|
clientIP, _, _ := net.SplitHostPort(cc.RemoteAddr().String())
|
|
if clientIP == "" {
|
|
clientIP = cc.RemoteAddr().String()
|
|
}
|
|
if err := s.checkIPAccess(clientIP, username); err != nil {
|
|
log.Printf("用户 %s IP %s 被拒绝: %v", username, clientIP, err)
|
|
s.db.AddLog(&database.FTPLog{
|
|
Username: username,
|
|
IP: cc.RemoteAddr().String(),
|
|
Action: "login_blocked",
|
|
Status: "blocked",
|
|
})
|
|
return nil, fmt.Errorf("登录被拒绝: %s", err)
|
|
}
|
|
|
|
// 记录登录日志
|
|
s.db.AddLog(&database.FTPLog{
|
|
Username: username,
|
|
IP: cc.RemoteAddr().String(),
|
|
Action: "login",
|
|
Status: "success",
|
|
})
|
|
|
|
// 记录在线用户
|
|
s.onlineMu.Lock()
|
|
s.onlineUsers[username+"_"+cc.RemoteAddr().String()] = &database.OnlineUser{
|
|
Username: username,
|
|
IP: cc.RemoteAddr().String(),
|
|
LoginTime: time.Now(),
|
|
LastActivity: time.Now(),
|
|
CurrentDir: user.HomeDir,
|
|
}
|
|
s.onlineMu.Unlock()
|
|
|
|
// 确保用户目录存在(自动创建)
|
|
if err := os.MkdirAll(user.HomeDir, 0755); err != nil {
|
|
return nil, fmt.Errorf("创建用户目录失败: %v", err)
|
|
}
|
|
|
|
// 返回 afero.Fs 作为 ClientDriver
|
|
osFs := afero.NewOsFs()
|
|
boundedFs := afero.NewBasePathFs(osFs, user.HomeDir)
|
|
|
|
// 根据权限设置只读
|
|
if user.Permissions == "read" {
|
|
return afero.NewReadOnlyFs(boundedFs), nil
|
|
}
|
|
|
|
return boundedFs, nil
|
|
}
|
|
|
|
// GetTLSConfig 获取TLS配置
|
|
func (s *Server) GetTLSConfig() (*tls.Config, error) {
|
|
return nil, fmt.Errorf("TLS未配置")
|
|
}
|
|
|
|
// checkIPAccess 检查IP是否允许访问,username为空时只检查全局规则
|
|
func (s *Server) checkIPAccess(clientIP, username string) error {
|
|
if s.db == nil {
|
|
return nil
|
|
}
|
|
|
|
rules, err := s.db.GetEnabledIPRules(username)
|
|
if err != nil {
|
|
return nil // 查询失败时允许连接
|
|
}
|
|
|
|
// 分离全局规则和用户规则
|
|
var globalWhitelist, globalBlacklist []database.IPAccessRule
|
|
var userWhitelist, userBlacklist []database.IPAccessRule
|
|
for _, rule := range rules {
|
|
if rule.Username == "" {
|
|
if rule.Type == "whitelist" {
|
|
globalWhitelist = append(globalWhitelist, rule)
|
|
} else {
|
|
globalBlacklist = append(globalBlacklist, rule)
|
|
}
|
|
} else {
|
|
if rule.Type == "whitelist" {
|
|
userWhitelist = append(userWhitelist, rule)
|
|
} else {
|
|
userBlacklist = append(userBlacklist, rule)
|
|
}
|
|
}
|
|
}
|
|
|
|
// 1. 先检查全局黑名单
|
|
for _, rule := range globalBlacklist {
|
|
if matchIP(clientIP, rule.IP) {
|
|
return fmt.Errorf("IP已被全局黑名单拦截")
|
|
}
|
|
}
|
|
|
|
// 2. 检查全局白名单(如果有全局白名单,必须在其中)
|
|
if len(globalWhitelist) > 0 {
|
|
matched := false
|
|
for _, rule := range globalWhitelist {
|
|
if matchIP(clientIP, rule.IP) {
|
|
matched = true
|
|
break
|
|
}
|
|
}
|
|
if !matched {
|
|
return fmt.Errorf("IP不在全局白名单中")
|
|
}
|
|
}
|
|
|
|
// 3. 检查用户黑名单
|
|
for _, rule := range userBlacklist {
|
|
if matchIP(clientIP, rule.IP) {
|
|
return fmt.Errorf("IP已被用户黑名单拦截")
|
|
}
|
|
}
|
|
|
|
// 4. 检查用户白名单(如果有用户白名单,必须在其中)
|
|
if len(userWhitelist) > 0 {
|
|
matched := false
|
|
for _, rule := range userWhitelist {
|
|
if matchIP(clientIP, rule.IP) {
|
|
matched = true
|
|
break
|
|
}
|
|
}
|
|
if !matched {
|
|
return fmt.Errorf("IP不在用户白名单中")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// matchIP reports whether clientIP is covered by rule. Three rule forms are
// understood:
//
//   - a single address          ("192.168.1.10", "2001:db8::1")
//   - CIDR notation             ("192.168.1.0/24")
//   - an inclusive address range ("192.168.1.1-192.168.1.100")
//
// Addresses are compared by value rather than by spelling, so surrounding
// whitespace, IPv4-mapped IPv6 forms ("::ffff:1.2.3.4") and alternative IPv6
// spellings all match as expected. A range written high-to-low is treated as
// the same range written low-to-high. If a single-address or CIDR rule cannot
// be parsed, the function falls back to exact string equality; an unparsable
// range never matches.
func matchIP(clientIP, rule string) bool {
	clientIP = strings.TrimSpace(clientIP)
	rule = strings.TrimSpace(rule)
	ip := parseIPLoose(clientIP)

	switch {
	case strings.Contains(rule, "/"):
		// CIDR notation (192.168.1.0/24).
		_, ipNet, err := net.ParseCIDR(rule)
		if err != nil {
			return clientIP == rule
		}
		return ip != nil && ipNet.Contains(ip)

	case strings.Contains(rule, "-"):
		// Address range (192.168.1.1-192.168.1.100).
		parts := strings.SplitN(rule, "-", 2)
		startIP := parseIPLoose(parts[0])
		endIP := parseIPLoose(parts[1])
		if startIP == nil || endIP == nil || ip == nil {
			return false
		}
		if bytesCompare(startIP, endIP) > 0 {
			// Bounds given in reverse; a range that silently matches
			// nothing would disable the rule.
			startIP, endIP = endIP, startIP
		}
		return bytesCompare(ip, startIP) >= 0 && bytesCompare(ip, endIP) <= 0

	default:
		// Single address.
		if ruleIP := parseIPLoose(rule); ruleIP != nil && ip != nil {
			return ruleIP.Equal(ip)
		}
		return clientIP == rule
	}
}

// parseIPLoose parses s as an IP address after trimming surrounding
// whitespace and dropping an IPv6 zone suffix ("%eth0"), which net.ParseIP
// does not accept. It returns nil when s is not an IP address.
func parseIPLoose(s string) net.IP {
	s = strings.TrimSpace(s)
	if i := strings.IndexByte(s, '%'); i >= 0 {
		s = s[:i]
	}
	return net.ParseIP(s)
}

// bytesCompare orders two IP addresses by their 16-byte representation and
// returns -1, 0 or 1 when a is lower than, equal to or higher than b. A value
// that is not a valid IPv4/IPv6 address sorts before every valid address, and
// two such values compare equal; the function never panics.
func bytesCompare(a, b net.IP) int {
	a16, b16 := a.To16(), b.To16()
	switch {
	case a16 == nil && b16 == nil:
		return 0
	case a16 == nil:
		return -1
	case b16 == nil:
		return 1
	}

	for i := range a16 {
		if a16[i] < b16[i] {
			return -1
		}
		if a16[i] > b16[i] {
			return 1
		}
	}
	return 0
}
|