diff --git a/pkg/handler/sftp.go b/pkg/handler/sftp.go index a340dc63..6d973e43 100644 --- a/pkg/handler/sftp.go +++ b/pkg/handler/sftp.go @@ -77,13 +77,24 @@ func (s *SftpHandler) Filewrite(r *sftp.Request) (io.WriterAt, error) { if err != nil { return nil, err } + go func() { <-r.Context().Done() + + fileInfo, err2 := f.Stat() + if err2 != nil { + logger.Errorf("Get file %s stat err: %s", r.Filepath, err2) + return + } + + if err1 := s.recorder.ChunkedRecord(f.FTPLog, f, 0, fileInfo.Size()); err1 != nil { + logger.Errorf("Record file %s err: %s", r.Filepath, err1) + } + if err := f.Close(); err != nil { logger.Errorf("Remote sftp file %s close err: %s", r.Filepath, err) } logger.Infof("Sftp file write %s done", r.Filepath) - s.recorder.FinishFTPFile(f.FTPLog.ID) }() return NewWriterAt(f, s.recorder), err } @@ -100,20 +111,18 @@ func (s *SftpHandler) Fileread(r *sftp.Request) (io.ReaderAt, error) { return nil, err } - if err1 := s.recorder.ChunkedRecord(f.FTPLog, f, 0, fileInfo.Size()); err1 != nil { - logger.Errorf("Record file %s err: %s", r.Filepath, err1) - } - - // 重置文件指针 - _, _ = f.Seek(0, io.SeekStart) go func() { <-r.Context().Done() + + if err1 := s.recorder.ChunkedRecord(f.FTPLog, f, 0, fileInfo.Size()); err1 != nil { + logger.Errorf("Record file %s err: %s", r.Filepath, err1) + } + if err2 := f.Close(); err2 != nil { logger.Errorf("Remote sftp file %s close err: %s", r.Filepath, err2) } - logger.Infof("Sftp File read %s done", r.Filepath) - s.recorder.FinishFTPFile(f.FTPLog.ID) + logger.Infof("Sftp File read %s done", r.Filepath) }() // 包裹一层,兼容 WinSCP 目录的批量下载 return NewReaderAt(f), err @@ -153,18 +162,10 @@ type clientReadWritAt struct { } func (c *clientReadWritAt) WriteAt(p []byte, off int64) (n int, err error) { - c.mu.Lock() - defer c.mu.Unlock() - if err1 := c.recorder.RecordFtpChunk(c.f.FTPLog, p, off); err1 != nil { - logger.Errorf("Record write err: %s", err1) - } - _, _ = c.f.Seek(off, 0) - return c.f.Write(p) + return c.f.WriteAt(p, off) } func (c *clientReadWritAt) ReadAt(p []byte, off int64) (n int, err error) { - c.mu.Lock() - defer c.mu.Unlock() return c.f.ReadAt(p, off) } diff --git a/pkg/proxy/recorder.go b/pkg/proxy/recorder.go index 7f56a933..e2e96d5e 100644 --- a/pkg/proxy/recorder.go +++ b/pkg/proxy/recorder.go @@ -351,24 +351,6 @@ func (r *FTPFileRecorder) CreateFTPFileInfo(logData *model.FTPLog) (info *FTPFil return info, nil } -func (r *FTPFileRecorder) RecordFtpChunk(ftpLog *model.FTPLog, p []byte, off int64) (err error) { - if r.isNullStorage() { - return - } - info := r.getFTPFile(ftpLog.ID) - if info == nil { - info, err = r.CreateFTPFileInfo(ftpLog) - } - if err != nil { - return - } - if info.isExceedWrittenSize() { - logger.Errorf("FTP file %s is exceeds the max limit and discard it", ftpLog.ID) - return nil - } - return info.WriteChunk(p, off) -} - func (r *FTPFileRecorder) FinishFTPFile(id string) { info := r.getFTPFile(id) if info == nil { @@ -409,7 +391,7 @@ func (r *FTPFileRecorder) ChunkedRecord(ftpLog *model.FTPLog, readerAt io.Reader return err } - if info.isExceedWrittenSize() { + if info.isExceedWrittenSize() || totalSize >= info.maxWrittenSize { logger.Errorf("FTP file %s is exceeds the max limit and discard it", ftpLog.ID) return nil } @@ -495,23 +477,6 @@ type FTPFileInfo struct { writtenBytes int64 } -func (f *FTPFileInfo) WriteChunk(p []byte, off int64) error { - var ( - nw int - err error - ) - _, err = f.fd.Seek(off, io.SeekStart) - if err != nil { - return err - } - nw, err = f.fd.Write(p) - if nw > 0 { - f.writtenBytes += int64(nw) - } - return err - -} - func (f *FTPFileInfo) WriteFromReader(r io.Reader) error { buf := make([]byte, 32*1024) var err error