diff --git a/Makefile b/Makefile index 090f203db..bd04e28ae 100644 --- a/Makefile +++ b/Makefile @@ -120,7 +120,7 @@ gen-mock: $(MOCKGEN) mockgen -typed -destination=./internal/api/auth/fake/mock_ticket_store.go -package=fake github.com/artefactual-sdps/enduro/internal/api/auth TicketStore mockgen -typed -destination=./internal/package_/fake/mock_package_.go -package=fake github.com/artefactual-sdps/enduro/internal/package_ Service mockgen -typed -destination=./internal/persistence/fake/mock_persistence.go -package=fake github.com/artefactual-sdps/enduro/internal/persistence Service - mockgen -typed -destination=./internal/sftp/fake/mock_sftp.go -package=fake github.com/artefactual-sdps/enduro/internal/sftp Client + mockgen -typed -destination=./internal/sftp/fake/mock_sftp.go -package=fake github.com/artefactual-sdps/enduro/internal/sftp Client,AsyncUpload mockgen -typed -destination=./internal/storage/fake/mock_storage.go -package=fake github.com/artefactual-sdps/enduro/internal/storage Service mockgen -typed -destination=./internal/storage/persistence/fake/mock_persistence.go -package=fake github.com/artefactual-sdps/enduro/internal/storage/persistence Storage mockgen -typed -destination=./internal/upload/fake/mock_upload.go -package=fake github.com/artefactual-sdps/enduro/internal/upload Service diff --git a/cmd/enduro-am-worker/main.go b/cmd/enduro-am-worker/main.go index 819393c9e..56200a88f 100644 --- a/cmd/enduro-am-worker/main.go +++ b/cmd/enduro-am-worker/main.go @@ -151,7 +151,7 @@ func main() { activities.NewZipActivity(logger).Execute, temporalsdk_activity.RegisterOptions{Name: activities.ZipActivityName}, ) w.RegisterActivityWithOptions( - am.NewUploadTransferActivity(logger, sftpClient).Execute, + am.NewUploadTransferActivity(logger, sftpClient, cfg.AM.PollInterval).Execute, temporalsdk_activity.RegisterOptions{Name: am.UploadTransferActivityName}, ) w.RegisterActivityWithOptions( diff --git a/internal/am/delete_transfer_test.go 
b/internal/am/delete_transfer_test.go index 82557f411..678cfe224 100644 --- a/internal/am/delete_transfer_test.go +++ b/internal/am/delete_transfer_test.go @@ -27,10 +27,10 @@ func TestDeleteTransferActivity(t *testing.T) { ) type test struct { - name string - params am.DeleteTransferActivityParams - recorder func(*sftp_fake.MockClientMockRecorder, am.DeleteTransferActivityParams) - errMsg string + name string + params am.DeleteTransferActivityParams + mock func(*gomock.Controller) *sftp_fake.MockClient + errMsg string } for _, tt := range []test{ { @@ -38,11 +38,13 @@ func TestDeleteTransferActivity(t *testing.T) { params: am.DeleteTransferActivityParams{ Destination: td.Path(), }, - recorder: func(m *sftp_fake.MockClientMockRecorder, params am.DeleteTransferActivityParams) { - m.Delete( - mockutil.Context(), - params.Destination, - ).Return(nil) + mock: func(ctrl *gomock.Controller) *sftp_fake.MockClient { + client := sftp_fake.NewMockClient(ctrl) + client.EXPECT(). + Delete(mockutil.Context(), td.Path()). + Return(nil) + + return client }, }, { @@ -50,13 +52,15 @@ func TestDeleteTransferActivity(t *testing.T) { params: am.DeleteTransferActivityParams{ Destination: td.Join("missing"), }, - recorder: func(m *sftp_fake.MockClientMockRecorder, params am.DeleteTransferActivityParams) { - m.Delete( - mockutil.Context(), - params.Destination, - ).Return( - errors.New("SFTP: unable to remove file \"test.txt\": file does not exist"), - ) + mock: func(ctrl *gomock.Controller) *sftp_fake.MockClient { + client := sftp_fake.NewMockClient(ctrl) + client.EXPECT(). + Delete(mockutil.Context(), td.Join("missing")). 
+ Return( + errors.New("SFTP: unable to remove file \"test.txt\": file does not exist"), + ) + + return client }, errMsg: fmt.Sprintf("delete transfer: path: %q: %v", td.Join("missing"), errors.New("SFTP: unable to remove file \"test.txt\": file does not exist")), }, @@ -65,13 +69,15 @@ func TestDeleteTransferActivity(t *testing.T) { params: am.DeleteTransferActivityParams{ Destination: td.Join(filename), }, - recorder: func(m *sftp_fake.MockClientMockRecorder, params am.DeleteTransferActivityParams) { - m.Delete( - mockutil.Context(), - params.Destination, - ).Return( - errors.New("SSH: failed to connect: dial tcp 127.0.0.1:2200: connect: connection refused"), - ) + mock: func(ctrl *gomock.Controller) *sftp_fake.MockClient { + client := sftp_fake.NewMockClient(ctrl) + client.EXPECT(). + Delete(mockutil.Context(), td.Join(filename)). + Return( + errors.New("SSH: failed to connect: dial tcp 127.0.0.1:2200: connect: connection refused"), + ) + + return client }, errMsg: fmt.Sprintf("delete transfer: path: %q: %v", td.Join(filename), errors.New("SSH: failed to connect: dial tcp 127.0.0.1:2200: connect: connection refused")), }, @@ -82,14 +88,10 @@ func TestDeleteTransferActivity(t *testing.T) { ts := &temporalsdk_testsuite.WorkflowTestSuite{} env := ts.NewTestActivityEnvironment() - msvc := sftp_fake.NewMockClient(gomock.NewController(t)) - - if tt.recorder != nil { - tt.recorder(msvc.EXPECT(), tt.params) - } + ctrl := gomock.NewController(t) env.RegisterActivityWithOptions( - am.NewDeleteTransferActivity(logr.Discard(), msvc).Execute, + am.NewDeleteTransferActivity(logr.Discard(), tt.mock(ctrl)).Execute, temporalsdk_activity.RegisterOptions{ Name: am.DeleteTransferActivityName, }, diff --git a/internal/am/upload_transfer.go b/internal/am/upload_transfer.go index 5597aed64..d973aa318 100644 --- a/internal/am/upload_transfer.go +++ b/internal/am/upload_transfer.go @@ -5,9 +5,11 @@ import ( "fmt" "os" "path/filepath" + "time" "github.com/go-logr/logr" 
"go.artefactual.dev/tools/temporal" + temporalsdk_activity "go.temporal.io/sdk/activity" "github.com/artefactual-sdps/enduro/internal/sftp" ) @@ -15,26 +17,42 @@ import ( const UploadTransferActivityName = "UploadTransferActivity" type UploadTransferActivityParams struct { + // Local path of the source file. SourcePath string } type UploadTransferActivityResult struct { - BytesCopied int64 - // Full path including `remoteDir` config path. + // Bytes copied to the remote file over the SFTP connection. + BytesCopied uint64 + // Full path of the destination file including `remoteDir` config path. RemoteFullPath string - // Relative path to the `remoteDir` config path. + // Path of the destination file relative to the `remoteDir` config path. RemoteRelativePath string } +// UploadTransferActivity uploads a transfer via the SFTP client, and sends +// a periodic Temporal Heartbeat at the given heartRate. type UploadTransferActivity struct { - client sftp.Client - logger logr.Logger + client sftp.Client + logger logr.Logger + heartRate time.Duration } -func NewUploadTransferActivity(logger logr.Logger, client sftp.Client) *UploadTransferActivity { - return &UploadTransferActivity{client: client, logger: logger} +// NewUploadTransferActivity initializes and returns a new +// UploadTransferActivity. +func NewUploadTransferActivity( + logger logr.Logger, + client sftp.Client, + heartRate time.Duration, +) *UploadTransferActivity { + return &UploadTransferActivity{ + client: client, + logger: logger, + heartRate: heartRate, + } } +// Execute copies the source transfer to the destination via SFTP. 
func (a *UploadTransferActivity) Execute(ctx context.Context, params *UploadTransferActivityParams) (*UploadTransferActivityResult, error) { a.logger.V(1).Info("Execute UploadTransferActivity", "SourcePath", params.SourcePath) @@ -45,7 +63,7 @@ func (a *UploadTransferActivity) Execute(ctx context.Context, params *UploadTran defer src.Close() filename := filepath.Base(params.SourcePath) - bytes, path, err := a.client.Upload(ctx, src, filename) + path, upload, err := a.client.Upload(ctx, src, filename) if err != nil { e := fmt.Errorf("%s: %v", UploadTransferActivityName, err) @@ -57,9 +75,43 @@ func (a *UploadTransferActivity) Execute(ctx context.Context, params *UploadTran } } + fi, err := src.Stat() + if err != nil { + return nil, fmt.Errorf("%s: %v", UploadTransferActivityName, err) + } + + // Block (with a heartbeat) until ctx is cancelled, the upload is done, or + // it stops with an error. + err = a.Heartbeat(ctx, upload, fi.Size()) + if err != nil { + return nil, err + } + return &UploadTransferActivityResult{ - BytesCopied: bytes, + BytesCopied: upload.Bytes(), RemoteFullPath: path, RemoteRelativePath: filename, }, nil } + +// Heartbeat sends a periodic Temporal heartbeat, which includes the number of +// bytes uploaded, until the upload is complete, cancelled or returns an error. 
+func (a *UploadTransferActivity) Heartbeat(ctx context.Context, upload sftp.AsyncUpload, fileSize int64) error { + ticker := time.NewTicker(a.heartRate) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-upload.Err(): + return err + case <-upload.Done(): + return nil + case <-ticker.C: + temporalsdk_activity.RecordHeartbeat(ctx, + fmt.Sprintf("Uploaded %d bytes of %d.", upload.Bytes(), fileSize), + ) + } + } +} diff --git a/internal/am/upload_transfer_test.go b/internal/am/upload_transfer_test.go index 125fe967c..55b93cd7e 100644 --- a/internal/am/upload_transfer_test.go +++ b/internal/am/upload_transfer_test.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "testing" + "time" "github.com/go-logr/logr" "go.artefactual.dev/tools/mockutil" @@ -29,7 +30,7 @@ func TestUploadTransferActivity(t *testing.T) { type test struct { name string params am.UploadTransferActivityParams - recorder func(*sftp_fake.MockClientMockRecorder) + mock func(*gomock.Controller) (sftp.Client, sftp.AsyncUpload) want am.UploadTransferActivityResult wantErr string wantNonRetryErr bool @@ -40,16 +41,36 @@ func TestUploadTransferActivity(t *testing.T) { params: am.UploadTransferActivityParams{ SourcePath: td.Join(filename), }, - recorder: func(m *sftp_fake.MockClientMockRecorder) { - var t *os.File - m.Upload( - mockutil.Context(), - gomock.AssignableToTypeOf(t), - filename, - ).Return(int64(14), "/transfer_dir/"+filename, nil) + mock: func(ctrl *gomock.Controller) (sftp.Client, sftp.AsyncUpload) { + var fp *os.File + + client := sftp_fake.NewMockClient(ctrl) + upload := sftp_fake.NewMockAsyncUpload(ctrl) + + client.EXPECT(). + Upload( + mockutil.Context(), + gomock.AssignableToTypeOf(fp), + filename, + ). 
+ Return("/transfer_dir/"+filename, upload, nil) + + doneCh := make(chan bool, 1) + upload.EXPECT().Done().Return(doneCh).Times(2) + + errCh := make(chan error, 1) + upload.EXPECT().Err().Return(errCh).Times(2) + + upload.EXPECT().Bytes().DoAndReturn(func() uint64 { + doneCh <- true + return uint64(7) + }) + upload.EXPECT().Bytes().Return(14) + + return client, upload }, want: am.UploadTransferActivityResult{ - BytesCopied: int64(14), + BytesCopied: uint64(14), RemoteFullPath: "/transfer_dir/" + filename, RemoteRelativePath: filename, }, @@ -66,17 +87,23 @@ func TestUploadTransferActivity(t *testing.T) { params: am.UploadTransferActivityParams{ SourcePath: td.Join(filename), }, - recorder: func(m *sftp_fake.MockClientMockRecorder) { - var t *os.File - m.Upload( - mockutil.Context(), - gomock.AssignableToTypeOf(t), - filename, - ).Return( - 0, - "", - errors.New("ssh: failed to connect: dial tcp 127.0.0.1:2200: connect: connection refused"), - ) + mock: func(ctrl *gomock.Controller) (sftp.Client, sftp.AsyncUpload) { + var fp *os.File + + client := sftp_fake.NewMockClient(ctrl) + client.EXPECT(). + Upload( + mockutil.Context(), + gomock.AssignableToTypeOf(fp), + filename, + ). 
+ Return( + "", + nil, + errors.New("ssh: failed to connect: dial tcp 127.0.0.1:2200: connect: connection refused"), + ) + + return client, nil }, wantErr: "activity error (type: UploadTransferActivity, scheduledEventID: 0, startedEventID: 0, identity: ): UploadTransferActivity: ssh: failed to connect: dial tcp 127.0.0.1:2200: connect: connection refused", }, @@ -85,19 +112,25 @@ func TestUploadTransferActivity(t *testing.T) { params: am.UploadTransferActivityParams{ SourcePath: td.Join(filename), }, - recorder: func(m *sftp_fake.MockClientMockRecorder) { - var t *os.File - m.Upload( - mockutil.Context(), - gomock.AssignableToTypeOf(t), - filename, - ).Return( - 0, - "", - &sftp.AuthError{ - Message: "ssh: handshake failed: ssh: unable to authenticate, attempted methods [none publickey], no supported methods remain", - }, - ) + mock: func(ctrl *gomock.Controller) (sftp.Client, sftp.AsyncUpload) { + var fp *os.File + + client := sftp_fake.NewMockClient(ctrl) + client.EXPECT(). + Upload( + mockutil.Context(), + gomock.AssignableToTypeOf(fp), + filename, + ). 
+ Return( + "", + nil, + &sftp.AuthError{ + Message: "ssh: handshake failed: ssh: unable to authenticate, attempted methods [none publickey], no supported methods remain", + }, + ) + + return client, nil }, wantErr: "activity error (type: UploadTransferActivity, scheduledEventID: 0, startedEventID: 0, identity: ): UploadTransferActivity: auth: ssh: handshake failed: ssh: unable to authenticate, attempted methods [none publickey], no supported methods remain", wantNonRetryErr: true, @@ -109,14 +142,15 @@ func TestUploadTransferActivity(t *testing.T) { ts := &temporalsdk_testsuite.WorkflowTestSuite{} env := ts.NewTestActivityEnvironment() - msvc := sftp_fake.NewMockClient(gomock.NewController(t)) + ctrl := gomock.NewController(t) - if tt.recorder != nil { - tt.recorder(msvc.EXPECT()) + var client sftp.Client + if tt.mock != nil { + client, _ = tt.mock(ctrl) } env.RegisterActivityWithOptions( - am.NewUploadTransferActivity(logr.Discard(), msvc).Execute, + am.NewUploadTransferActivity(logr.Discard(), client, 2*time.Millisecond).Execute, temporalsdk_activity.RegisterOptions{ Name: am.UploadTransferActivityName, }, diff --git a/internal/sftp/async_upload.go b/internal/sftp/async_upload.go new file mode 100644 index 000000000..c80fc3dc7 --- /dev/null +++ b/internal/sftp/async_upload.go @@ -0,0 +1,57 @@ +package sftp + +import "sync/atomic" + +// AsyncUploadImpl provides an asynchronous upload implementation. +type AsyncUploadImpl struct { + conn *connection + done chan bool + err chan error + + bytes atomic.Uint64 +} + +var _ AsyncUpload = (*AsyncUploadImpl)(nil) + +// NewAsyncUpload returns an initialized AsyncUploadImpl struct that wraps the +// underlying SFTP connection. +func NewAsyncUpload(conn *connection) AsyncUploadImpl { + return AsyncUploadImpl{ + conn: conn, + done: make(chan bool, 1), + err: make(chan error, 1), + } +} + +// Bytes returns the current number of bytes uploaded. 
+func (u *AsyncUploadImpl) Bytes() uint64 { + return uint64(u.bytes.Load()) +} + +// Close closes SFTP connection used for the upload. Close must be called +// when the upload is complete to prevent memory leaks. +func (u *AsyncUploadImpl) Close() error { + return u.conn.Close() +} + +// Done returns a done channel that receives a true value when the upload is +// complete. +func (u *AsyncUploadImpl) Done() chan bool { + return u.done +} + +// Err returns an error channel that receives an error if the upload +// encounters an error. +func (u *AsyncUploadImpl) Err() chan error { + return u.err +} + +// Write adds the length of p to the total number of bytes written on the +// connection. +// +// Write implements the io.Writer interface. +func (u *AsyncUploadImpl) Write(p []byte) (int, error) { + n := len(p) + u.bytes.Add(uint64(n)) + return n, nil +} diff --git a/internal/sftp/client.go b/internal/sftp/client.go index 159700d41..128eb717a 100644 --- a/internal/sftp/client.go +++ b/internal/sftp/client.go @@ -6,14 +6,18 @@ import ( "io" ) +// AuthError represents an SFTP authentication error. type AuthError struct { Message string } +// Error implements the error interface. func (e *AuthError) Error() string { return fmt.Sprintf("auth: %s", e.Message) } +// NewAuthError returns a pointer to a new AuthError from the underlying |e| +// error message. func NewAuthError(e error) error { return &AuthError{Message: e.Error()} } @@ -24,9 +28,28 @@ func NewAuthError(e error) error { // authentication, and other intricacies associated with different SFTP // servers and protocols. type Client interface { - // Upload transfers data from the provided source reader to a specified - // destination on the SFTP server. - Upload(ctx context.Context, src io.Reader, dest string) (bytes int64, remotePath string, err error) - // Delete removes data from the provided dest on the SFTP server. 
- Delete(ctx context.Context, dest string) (err error) + // Delete removes dest from the SFTP server. + Delete(ctx context.Context, dest string) error + // Upload asynchronously copies data from the src reader to the specified + // dest on the SFTP server. + Upload(ctx context.Context, src io.Reader, dest string) (remotePath string, upload AsyncUpload, err error) +} + +// AsyncUpload provides information about an upload happening asynchronously in +// a separate goroutine. +type AsyncUpload interface { + // Bytes returns the number of bytes copied to the SFTP destination. + Bytes() uint64 + // Close closes SFTP connection used for the upload. Close must be called + // when the upload is complete to prevent memory leaks. + Close() error + // Done returns a channel that receives a true value when the upload is + // complete. A done signal should not be sent on error. + Done() chan bool + // Err returns a channel that receives an error if the upload encounters + // an error. + Err() chan error + // Write implements the io.Writer interface and adds len(p) to the count of + // bytes uploaded. + Write(p []byte) (int, error) } diff --git a/internal/sftp/config.go b/internal/sftp/config.go index e3cb84389..750c22413 100644 --- a/internal/sftp/config.go +++ b/internal/sftp/config.go @@ -5,6 +5,7 @@ import ( "path/filepath" ) +// Config represents the configuration needed to connect to an SFTP server. type Config struct { // Host address, e.g. 127.0.0.1 (default), sftp.example.org. Host string @@ -28,12 +29,13 @@ type Config struct { RemoteDir string } +// PrivateKey represents an SSH private key, with an optional passphrase. type PrivateKey struct { // Path to private key file used for authentication (default: // "$HOME/.ssh/id_rsa") Path string - // Passphrase (if any) used to decrypt private key. + // Passphrase (if any) used to decrypt the private key. 
Passphrase string } diff --git a/internal/sftp/connection.go b/internal/sftp/connection.go new file mode 100644 index 000000000..16befb02c --- /dev/null +++ b/internal/sftp/connection.go @@ -0,0 +1,33 @@ +package sftp + +import ( + "errors" + + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +// connection represents an SFTP connection and the underlying SSH connection. +type connection struct { + *sftp.Client + sshClient *ssh.Client +} + +// Close closes the SFTP connection then the underlying SSH connection. +func (c *connection) Close() error { + var errs error + + if c.Client != nil { + if err := c.Client.Close(); err != nil { + errs = errors.Join(err, errs) + } + } + + if c.sshClient != nil { + if err := c.sshClient.Close(); err != nil { + errs = errors.Join(err, errs) + } + } + + return errs +} diff --git a/internal/sftp/fake/mock_sftp.go b/internal/sftp/fake/mock_sftp.go index ade6e9b01..f3eea63d7 100644 --- a/internal/sftp/fake/mock_sftp.go +++ b/internal/sftp/fake/mock_sftp.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/artefactual-sdps/enduro/internal/sftp (interfaces: Client) +// Source: github.com/artefactual-sdps/enduro/internal/sftp (interfaces: Client,AsyncUpload) // // Generated by this command: // -// mockgen -typed -destination=./internal/sftp/fake/mock_sftp.go -package=fake github.com/artefactual-sdps/enduro/internal/sftp Client +// mockgen -typed -destination=./internal/sftp/fake/mock_sftp.go -package=fake github.com/artefactual-sdps/enduro/internal/sftp Client,AsyncUpload // // Package fake is a generated GoMock package. package fake @@ -13,6 +13,7 @@ import ( io "io" reflect "reflect" + sftp "github.com/artefactual-sdps/enduro/internal/sftp" gomock "go.uber.org/mock/gomock" ) @@ -78,11 +79,11 @@ func (c *ClientDeleteCall) DoAndReturn(f func(context.Context, string) error) *C } // Upload mocks base method. 
-func (m *MockClient) Upload(arg0 context.Context, arg1 io.Reader, arg2 string) (int64, string, error) { +func (m *MockClient) Upload(arg0 context.Context, arg1 io.Reader, arg2 string) (string, sftp.AsyncUpload, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Upload", arg0, arg1, arg2) - ret0, _ := ret[0].(int64) - ret1, _ := ret[1].(string) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(sftp.AsyncUpload) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } @@ -100,19 +101,233 @@ type ClientUploadCall struct { } // Return rewrite *gomock.Call.Return -func (c *ClientUploadCall) Return(arg0 int64, arg1 string, arg2 error) *ClientUploadCall { +func (c *ClientUploadCall) Return(arg0 string, arg1 sftp.AsyncUpload, arg2 error) *ClientUploadCall { c.Call = c.Call.Return(arg0, arg1, arg2) return c } // Do rewrite *gomock.Call.Do -func (c *ClientUploadCall) Do(f func(context.Context, io.Reader, string) (int64, string, error)) *ClientUploadCall { +func (c *ClientUploadCall) Do(f func(context.Context, io.Reader, string) (string, sftp.AsyncUpload, error)) *ClientUploadCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *ClientUploadCall) DoAndReturn(f func(context.Context, io.Reader, string) (int64, string, error)) *ClientUploadCall { +func (c *ClientUploadCall) DoAndReturn(f func(context.Context, io.Reader, string) (string, sftp.AsyncUpload, error)) *ClientUploadCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockAsyncUpload is a mock of AsyncUpload interface. +type MockAsyncUpload struct { + ctrl *gomock.Controller + recorder *MockAsyncUploadMockRecorder +} + +// MockAsyncUploadMockRecorder is the mock recorder for MockAsyncUpload. +type MockAsyncUploadMockRecorder struct { + mock *MockAsyncUpload +} + +// NewMockAsyncUpload creates a new mock instance. 
+func NewMockAsyncUpload(ctrl *gomock.Controller) *MockAsyncUpload { + mock := &MockAsyncUpload{ctrl: ctrl} + mock.recorder = &MockAsyncUploadMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAsyncUpload) EXPECT() *MockAsyncUploadMockRecorder { + return m.recorder +} + +// Bytes mocks base method. +func (m *MockAsyncUpload) Bytes() uint64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Bytes") + ret0, _ := ret[0].(uint64) + return ret0 +} + +// Bytes indicates an expected call of Bytes. +func (mr *MockAsyncUploadMockRecorder) Bytes() *AsyncUploadBytesCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bytes", reflect.TypeOf((*MockAsyncUpload)(nil).Bytes)) + return &AsyncUploadBytesCall{Call: call} +} + +// AsyncUploadBytesCall wrap *gomock.Call +type AsyncUploadBytesCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *AsyncUploadBytesCall) Return(arg0 uint64) *AsyncUploadBytesCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *AsyncUploadBytesCall) Do(f func() uint64) *AsyncUploadBytesCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *AsyncUploadBytesCall) DoAndReturn(f func() uint64) *AsyncUploadBytesCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Close mocks base method. +func (m *MockAsyncUpload) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. 
+func (mr *MockAsyncUploadMockRecorder) Close() *AsyncUploadCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAsyncUpload)(nil).Close)) + return &AsyncUploadCloseCall{Call: call} +} + +// AsyncUploadCloseCall wrap *gomock.Call +type AsyncUploadCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *AsyncUploadCloseCall) Return(arg0 error) *AsyncUploadCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *AsyncUploadCloseCall) Do(f func() error) *AsyncUploadCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *AsyncUploadCloseCall) DoAndReturn(f func() error) *AsyncUploadCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Done mocks base method. +func (m *MockAsyncUpload) Done() chan bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Done") + ret0, _ := ret[0].(chan bool) + return ret0 +} + +// Done indicates an expected call of Done. +func (mr *MockAsyncUploadMockRecorder) Done() *AsyncUploadDoneCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Done", reflect.TypeOf((*MockAsyncUpload)(nil).Done)) + return &AsyncUploadDoneCall{Call: call} +} + +// AsyncUploadDoneCall wrap *gomock.Call +type AsyncUploadDoneCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *AsyncUploadDoneCall) Return(arg0 chan bool) *AsyncUploadDoneCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *AsyncUploadDoneCall) Do(f func() chan bool) *AsyncUploadDoneCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *AsyncUploadDoneCall) DoAndReturn(f func() chan bool) *AsyncUploadDoneCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Err mocks base method. 
+func (m *MockAsyncUpload) Err() chan error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(chan error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockAsyncUploadMockRecorder) Err() *AsyncUploadErrCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockAsyncUpload)(nil).Err)) + return &AsyncUploadErrCall{Call: call} +} + +// AsyncUploadErrCall wrap *gomock.Call +type AsyncUploadErrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *AsyncUploadErrCall) Return(arg0 chan error) *AsyncUploadErrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *AsyncUploadErrCall) Do(f func() chan error) *AsyncUploadErrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *AsyncUploadErrCall) DoAndReturn(f func() chan error) *AsyncUploadErrCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Write mocks base method. +func (m *MockAsyncUpload) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. 
+func (mr *MockAsyncUploadMockRecorder) Write(arg0 any) *AsyncUploadWriteCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockAsyncUpload)(nil).Write), arg0) + return &AsyncUploadWriteCall{Call: call} +} + +// AsyncUploadWriteCall wrap *gomock.Call +type AsyncUploadWriteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *AsyncUploadWriteCall) Return(arg0 int, arg1 error) *AsyncUploadWriteCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *AsyncUploadWriteCall) Do(f func([]byte) (int, error)) *AsyncUploadWriteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *AsyncUploadWriteCall) DoAndReturn(f func([]byte) (int, error)) *AsyncUploadWriteCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/internal/sftp/goclient.go b/internal/sftp/goclient.go index 0a7d35b89..8b821b863 100644 --- a/internal/sftp/goclient.go +++ b/internal/sftp/goclient.go @@ -9,12 +9,10 @@ import ( "os" "regexp" "strconv" - "strings" "github.com/dolmen-go/contextio" "github.com/go-logr/logr" "github.com/pkg/sftp" - "golang.org/x/crypto/ssh" ) // GoClient implements the SFTP service using native Go SSH and SFTP packages. @@ -32,51 +30,18 @@ func NewGoClient(logger logr.Logger, cfg Config) *GoClient { return &GoClient{cfg: cfg, logger: logger} } -// Upload writes the data from src to the remote file at dest and returns the -// number of bytes written. A new SFTP connection is opened before writing, and -// closed when the upload is complete or cancelled. -func (c *GoClient) Upload(ctx context.Context, src io.Reader, dest string) (int64, string, error) { - conn, err := c.dial(ctx) - if err != nil { - return 0, "", err - } - defer conn.close() - - // SFTP assumes that "/" is used as the directory separator. 
See: - // https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-02#section-6.2 - remotePath := strings.TrimSuffix(c.cfg.RemoteDir, "/") + "/" + dest - - // Note: Some SFTP servers don't support O_RDWR mode. - w, err := conn.sftp.OpenFile(remotePath, (os.O_WRONLY | os.O_CREATE | os.O_TRUNC)) - if err != nil { - return 0, "", fmt.Errorf("SFTP: open remote file %q: %v", dest, err) - } - defer w.Close() - - // Use contextio to stop the upload if a context cancellation signal is - // received. - bytes, err := io.Copy(contextio.NewWriter(ctx, w), contextio.NewReader(ctx, src)) - if err != nil { - return 0, "", fmt.Errorf("SFTP: upload to %q: %v", dest, err) - } - - return bytes, remotePath, nil -} - // Delete removes the data from dest. A new SFTP connection is opened before // removing the file, and closed when the delete is complete. func (c *GoClient) Delete(ctx context.Context, dest string) error { + remotePath := sftp.Join(c.cfg.RemoteDir, dest) + conn, err := c.dial(ctx) if err != nil { - return fmt.Errorf("SFTP: unable to dial: %w", err) + return fmt.Errorf("sftp: dial: %v", err) } - defer conn.close() - - // SFTP assumes that "/" is used as the directory separator. See: - // https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-02#section-6.2 - remotePath := strings.TrimSuffix(c.cfg.RemoteDir, "/") + "/" + dest + defer conn.Close() - if err := conn.sftp.Remove(remotePath); err != nil { + if err := conn.Remove(remotePath); err != nil { head := fmt.Sprintf("SFTP: unable to remove file %q", dest) if errors.Is(err, fs.ErrNotExist) || errors.Is(err, fs.ErrPermission) { return fmt.Errorf("%s: %w", head, err) @@ -90,7 +55,33 @@ func (c *GoClient) Delete(ctx context.Context, dest string) error { return nil } -// Dial connects to an SSH host, creates an SFTP client on the connection, then +// Upload asynchronously copies the src data to dest over an SFTP connection. 
+// +// When Upload is called it starts the upload in an asynchronous goroutine, then +// immediately returns the full remote path, and an AsyncUpload struct that +// provides access to the upload status and progress. +// +// When the upload completes, the `AsyncUpload.Done()` channel is sent a `true` +// value. If an error occurs during the upload the error is sent to the +// `AsyncUpload.Err()` channel and the upload is terminated. If a ctx +// cancellation signal is received, the `ctx.Err()` error will be sent to the +// `AsyncUpload.Err()` channel, and the upload is terminated. +func (c *GoClient) Upload(ctx context.Context, src io.Reader, dest string) (string, AsyncUpload, error) { + remotePath := sftp.Join(c.cfg.RemoteDir, dest) + + conn, err := c.dial(ctx) + if err != nil { + return "", nil, err + } + + // Asynchronously upload file. + upload := NewAsyncUpload(conn) + go remoteCopy(ctx, &upload, src, remotePath) + + return remotePath, &upload, nil +} + +// dial connects to an SSH host, creates an SFTP client on the connection, then // returns conn. When conn is no longer needed, conn.close() must be called to // prevent leaks. func (c *GoClient) dial(ctx context.Context) (*connection, error) { @@ -99,20 +90,49 @@ func (c *GoClient) dial(ctx context.Context) (*connection, error) { err error ) - conn.ssh, err = sshConnect(ctx, c.logger, c.cfg) + conn.sshClient, err = sshConnect(ctx, c.logger, c.cfg) if err != nil { return nil, err } - conn.sftp, err = sftp.NewClient(conn.ssh) + conn.Client, err = sftp.NewClient(conn.sshClient) if err != nil { - _ = conn.ssh.Close() + _ = conn.sshClient.Close() return nil, fmt.Errorf("start SFTP subsystem: %v", err) } return &conn, nil } +// remoteCopy copies data from the src reader to a remote file at dest, and +// updates upload progress asynchronously. Upload status and progress will be +// sent to the upload struct via the `upload.Done()` and `upload.Err()` channels. 
+func remoteCopy(ctx context.Context, upload *AsyncUploadImpl, src io.Reader, dest string) { + defer upload.Close() + + // Note: Some SFTP servers don't support O_RDWR mode. + w, err := upload.conn.OpenFile(dest, (os.O_WRONLY | os.O_CREATE | os.O_TRUNC)) + if err != nil { + upload.Err() <- fmt.Errorf("sftp: open remote file %q: %v", dest, err) + return + } + defer w.Close() + + // Write the number of bytes copied to upload. + src = contextio.NewReader(ctx, src) + src = io.TeeReader(src, upload) + + // Use contextio to stop the upload if a context cancellation signal is + // received. + _, err = io.Copy(contextio.NewWriter(ctx, w), src) + if err != nil { + upload.Err() <- fmt.Errorf("remote copy: %v", err) + return + } + + upload.Done() <- true +} + var statusCodeRegex = regexp.MustCompile(`\(SSH_[A-Z_]+\)$`) // formatStatusError extracts/formats the SFTP status error code and message. @@ -132,27 +152,3 @@ func formatStatusError(err *sftp.StatusError) string { return fmt.Sprintf("%s (%s)", codeMsg, code) } - -type connection struct { - ssh *ssh.Client - sftp *sftp.Client -} - -// close closes the SFTP client first, then the SSH client. 
-func (conn *connection) close() error { - var errs error - - if conn.sftp != nil { - if err := conn.sftp.Close(); err != nil { - errs = errors.Join(err, errs) - } - } - - if conn.ssh != nil { - if err := conn.ssh.Close(); err != nil { - errs = errors.Join(err, errs) - } - } - - return errs -} diff --git a/internal/sftp/goclient_test.go b/internal/sftp/goclient_test.go index 1f559ba9c..6e6aebb58 100644 --- a/internal/sftp/goclient_test.go +++ b/internal/sftp/goclient_test.go @@ -147,7 +147,7 @@ func startSFTPServer(t *testing.T) (string, string) { return host, port } -func TestUpload(t *testing.T) { +func TestGoClient(t *testing.T) { t.Parallel() host, port := startSFTPServer(t) @@ -166,7 +166,7 @@ func TestUpload(t *testing.T) { dest string } type results struct { - Bytes int64 + Bytes uint64 Paths []tfs.PathOp } @@ -179,7 +179,7 @@ func TestUpload(t *testing.T) { } for _, tc := range []test{ { - name: "Uploads a file using a key with no passphrase", + name: "Uploads a file using private key auth", cfg: sftp.Config{ Host: host, Port: port, @@ -198,7 +198,7 @@ func TestUpload(t *testing.T) { }, }, { - name: "Uploads a file using a key with a passphrase", + name: "Uploads a file using private key + password auth", cfg: sftp.Config{ Host: host, Port: port, @@ -306,19 +306,25 @@ func TestUpload(t *testing.T) { remoteDir := tfs.NewDir(t, "sftp_test_remote") tc.cfg.RemoteDir = remoteDir.Path() - sftpc := sftp.NewGoClient(logr.Discard(), tc.cfg) - bytes, remotePath, err := sftpc.Upload(context.Background(), tc.params.src, tc.params.dest) - + client := sftp.NewGoClient(logr.Discard(), tc.cfg) + remotePath, upload, err := client.Upload(context.Background(), tc.params.src, tc.params.dest) if tc.wantErr != nil { assert.Error(t, err, tc.wantErr.Error()) assert.Assert(t, reflect.TypeOf(err) == reflect.TypeOf(tc.wantErr)) return } - assert.NilError(t, err) - assert.Equal(t, bytes, tc.want.Bytes) + assert.Equal(t, remotePath, tc.cfg.RemoteDir+"/"+tc.params.dest) - 
assert.Assert(t, tfs.Equal(remoteDir.Path(), tfs.Expected(t, tc.want.Paths...))) + assert.Equal(t, upload.Bytes(), uint64(0)) // Upload hasn't started yet. + + select { + case <-upload.Done(): + assert.Equal(t, upload.Bytes(), tc.want.Bytes) + assert.Assert(t, tfs.Equal(remoteDir.Path(), tfs.Expected(t, tc.want.Paths...))) + case err = <-upload.Err(): + t.Fatal(err) + } }) } } @@ -396,9 +402,8 @@ func TestDelete(t *testing.T) { assert.NilError(t, err) } - sftpc := sftp.NewGoClient(logr.Discard(), cfg) - err := sftpc.Delete(context.Background(), tc.params.file) - + client := sftp.NewGoClient(logr.Discard(), cfg) + err := client.Delete(context.Background(), tc.params.file) if tc.wantErr != "" { assert.Error(t, err, tc.wantErr) return diff --git a/internal/workflow/processing.go b/internal/workflow/processing.go index 1dd5e2ec8..ecd7e0373 100644 --- a/internal/workflow/processing.go +++ b/internal/workflow/processing.go @@ -668,6 +668,7 @@ func (w *ProcessingWorkflow) transferAM(sessCtx temporalsdk_workflow.Context, ti activityOpts = temporalsdk_workflow.WithActivityOptions(sessCtx, temporalsdk_workflow.ActivityOptions{ StartToCloseTimeout: time.Hour * 2, + HeartbeatTimeout: 2 * tinfo.req.PollInterval, RetryPolicy: &temporalsdk_temporal.RetryPolicy{ InitialInterval: time.Second * 5, BackoffCoefficient: 2, @@ -706,8 +707,8 @@ func (w *ProcessingWorkflow) transferAM(sessCtx temporalsdk_workflow.Context, ti pollOpts := temporalsdk_workflow.WithActivityOptions( sessCtx, temporalsdk_workflow.ActivityOptions{ - HeartbeatTimeout: 2 * tinfo.req.PollInterval, - ScheduleToCloseTimeout: tinfo.req.TransferDeadline, + HeartbeatTimeout: 2 * tinfo.req.PollInterval, + StartToCloseTimeout: tinfo.req.TransferDeadline, RetryPolicy: &temporalsdk_temporal.RetryPolicy{ InitialInterval: 5 * time.Second, BackoffCoefficient: 2, diff --git a/internal/workflow/processing_test.go b/internal/workflow/processing_test.go index 5e23f308e..5699cf35b 100644 --- a/internal/workflow/processing_test.go 
+++ b/internal/workflow/processing_test.go @@ -71,7 +71,7 @@ func (s *ProcessingWorkflowTestSuite) SetupWorkflowTest(taskQueue string) { temporalsdk_activity.RegisterOptions{Name: activities.ZipActivityName}, ) s.env.RegisterActivityWithOptions( - am.NewUploadTransferActivity(logger, sftpc).Execute, + am.NewUploadTransferActivity(logger, sftpc, 10*time.Second).Execute, temporalsdk_activity.RegisterOptions{Name: am.UploadTransferActivityName}, ) s.env.RegisterActivityWithOptions(