diff --git a/ftp/logging.go b/ftp/logging.go new file mode 100644 index 0000000..9bbc037 --- /dev/null +++ b/ftp/logging.go @@ -0,0 +1,116 @@ +package ftp + +import ( + "io" + "os" + + "ftp-server/database" + + "github.com/spf13/afero" +) + +// loggingFs 带日志功能的文件系统包装器 +type loggingFs struct { + afero.Fs + db *database.DB + username string +} + +// loggingFile 带日志功能的文件包装器 +type loggingFile struct { + afero.File + fs *loggingFs + name string + flags int + written int64 + read int64 + logged bool +} + +// newLoggingFs 创建带日志的文件系统 +func newLoggingFs(base afero.Fs, db *database.DB, username string) *loggingFs { + return &loggingFs{ + Fs: base, + db: db, + username: username, + } +} + +// GetHandle 实现 ftpserverlib.ClientDriverExtentionFileTransfer 接口 +// flags: os.O_RDONLY 表示下载, os.O_WRONLY/os.O_CREATE 表示上传 +func (fs *loggingFs) GetHandle(name string, flags int, offset int64) (interface { + io.Reader + io.Writer + io.Seeker + io.Closer +}, error) { + file, err := fs.Fs.OpenFile(name, flags, 0666) + if err != nil { + return nil, err + } + + // 如果有偏移量(断点续传),先 Seek + if offset > 0 { + if _, err := file.Seek(offset, 0); err != nil { + file.Close() + return nil, err + } + } + + return &loggingFile{ + File: file, + fs: fs, + name: name, + flags: flags, + logged: false, + }, nil +} + +func (f *loggingFile) Write(p []byte) (int, error) { + n, err := f.File.Write(p) + f.written += int64(n) + return n, err +} + +func (f *loggingFile) Read(p []byte) (int, error) { + n, err := f.File.Read(p) + f.read += int64(n) + return n, err +} + +func (f *loggingFile) Close() error { + err := f.File.Close() + + // 避免重复记录 + if f.logged { + return err + } + f.logged = true + + if f.fs.db == nil { + return err + } + + // 判断是上传还是下载 + isWrite := (f.flags & os.O_WRONLY) != 0 || (f.flags & os.O_RDWR) != 0 || (f.flags & os.O_CREATE) != 0 + + if isWrite && f.written > 0 { + f.fs.db.AddLog(&database.FTPLog{ + Username: f.fs.username, + Action: "upload", + FilePath: f.name, + FileSize: f.written, + Status: "success", + }) + } else if !isWrite && f.read > 0 { + f.fs.db.AddLog(&database.FTPLog{ + Username: f.fs.username, + Action: "download", + FilePath: f.name, + FileSize: f.read, + Status: "success", + }) + } + + return err +} diff --git a/ftp/server.go b/ftp/server.go index 0073125..36c9f19 100644 --- a/ftp/server.go +++ b/ftp/server.go @@ -134,7 +134,7 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string) } osFs := afero.NewOsFs() boundedFs := afero.NewBasePathFs(osFs, ftpCfg.RootDir) - return boundedFs, nil + return newLoggingFs(boundedFs, s.db, "anonymous"), nil } // 数据库用户认证 @@ -195,16 +195,17 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string) return nil, fmt.Errorf("创建用户目录失败: %v", err) } - // 返回 afero.Fs 作为 ClientDriver + // 返回 afero.Fs 作为 ClientDriver(带日志包装) osFs := afero.NewOsFs() boundedFs := afero.NewBasePathFs(osFs, user.HomeDir) + loggedFs := newLoggingFs(boundedFs, s.db, username) // 根据权限设置只读 if user.Permissions == "read" { - return afero.NewReadOnlyFs(boundedFs), nil + return afero.NewReadOnlyFs(loggedFs), nil } - return boundedFs, nil + return loggedFs, nil } // GetTLSConfig 获取TLS配置