diff --git a/pkg/connector/errors.go b/pkg/connector/errors.go index 9587c40c2..9cf3df53d 100644 --- a/pkg/connector/errors.go +++ b/pkg/connector/errors.go @@ -22,4 +22,9 @@ var ( ErrInvalidConnectorStateType = cerrors.New("invalid connector state type") ErrProcessorIDNotFound = cerrors.New("processor ID not found") ErrConnectorRunning = cerrors.New("connector is running") + ErrInvalidCharacters = cerrors.New("connector ID contains invalid characters") + ErrIDOverLimit = cerrors.New("connector ID is over the character limit (64)") + ErrNameOverLimit = cerrors.New("connector name is over the character limit (64)") + ErrNameMissing = cerrors.New("must provide a connector name") + ErrIDMissing = cerrors.New("must provide a connector ID") ) diff --git a/pkg/connector/service.go b/pkg/connector/service.go index 08a0c7588..0e1433fec 100644 --- a/pkg/connector/service.go +++ b/pkg/connector/service.go @@ -16,6 +16,7 @@ package connector import ( "context" + "regexp" "strings" "time" @@ -23,8 +24,11 @@ import ( "github.com/conduitio/conduit/pkg/foundation/database" "github.com/conduitio/conduit/pkg/foundation/log" "github.com/conduitio/conduit/pkg/foundation/metrics/measure" + "github.com/conduitio/conduit/pkg/foundation/multierror" ) +var idRegex = regexp.MustCompile(`^[A-Za-z0-9-_:]*$`) + // Service manages connectors. type Service struct { logger log.CtxLogger @@ -111,6 +115,11 @@ func (s *Service) Create( cfg Config, p ProvisionType, ) (*Instance, error) { + err := s.validateConnector(cfg, id) + if err != nil { + return nil, cerrors.Errorf("connector is invalid: %w", err) + } + // determine the path of the Connector binary if plugin == "" { return nil, cerrors.New("must provide a plugin") @@ -142,7 +151,7 @@ func (s *Service) Create( } // persist instance - err := s.store.Set(ctx, id, conn) + err = s.store.Set(ctx, id, conn) if err != nil { return nil, err } @@ -275,3 +284,26 @@ func (s *Service) SetState(ctx context.Context, id string, state any) (*Instance return conn, err } +func (s *Service) validateConnector(cfg Config, id string) error { + // contains all the errors occurred while provisioning configuration files. + var multierr error + + if cfg.Name == "" { + multierr = multierror.Append(multierr, ErrNameMissing) + } + if len(cfg.Name) > 64 { + multierr = multierror.Append(multierr, ErrNameOverLimit) + } + if id == "" { + multierr = multierror.Append(multierr, ErrIDMissing) + } + matched := idRegex.MatchString(id) + if !matched { + multierr = multierror.Append(multierr, ErrInvalidCharacters) + } + if len(id) > 64 { + multierr = multierror.Append(multierr, ErrIDOverLimit) + } + + return multierr +} diff --git a/pkg/connector/service_test.go b/pkg/connector/service_test.go index 54bc8787b..fa13ef166 100644 --- a/pkg/connector/service_test.go +++ b/pkg/connector/service_test.go @@ -220,7 +220,10 @@ func TestService_CreateDLQ(t *testing.T) { TypeDestination, "test-plugin", uuid.NewString(), - Config{}, + Config{ + Name: "test-connector", + Settings: map[string]string{"foo": "bar"}, + }, ProvisionTypeDLQ, ) is.NoErr(err) @@ -271,7 +274,8 @@ func TestService_CreateError(t *testing.T) { Name: "test-connector", Settings: map[string]string{"foo": "bar"}, }, - }} + }, + } for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { @@ -290,6 +294,123 @@ func TestService_CreateError(t *testing.T) { } } +func TestService_Create_ValidateSuccess(t *testing.T) { + is := is.New(t) + ctx := context.Background() + logger := log.Nop() + db := &inmemory.DB{} + + service := NewService(logger, db, nil) + + testCases := []struct { + name string + connID string + data Config + }{{ + name: "valid config name", + connID: uuid.NewString(), + data: Config{ + Name: "Name#@-/_0%$", + Settings: map[string]string{"foo": "bar"}, + }, + }, { + name: "valid connector ID", + connID: "Aa0-_", + data: Config{ + Name: "test-connector", + Settings: map[string]string{"foo": "bar"}, + }, + }} + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + got, err := service.Create( + ctx, + tt.connID, + TypeSource, + "test-plugin", + uuid.NewString(), + tt.data, + ProvisionTypeAPI, + ) + is.True(got != nil) + is.Equal(err, nil) + }) + } +} + +func TestService_Create_ValidateError(t *testing.T) { + is := is.New(t) + ctx := context.Background() + logger := log.Nop() + db := &inmemory.DB{} + + service := NewService(logger, db, nil) + + testCases := []struct { + name string + connID string + errType error + data Config + }{{ + name: "empty config name", + connID: uuid.NewString(), + errType: ErrNameMissing, + data: Config{ + Name: "", + Settings: map[string]string{"foo": "bar"}, + }, + }, { + name: "connector name over 64 characters", + connID: uuid.NewString(), + errType: ErrNameOverLimit, + data: Config{ + Name: "aaaaaaaaa1bbbbbbbbb2ccccccccc3ddddddddd4eeeeeeeee5fffffffff6ggggg", + Settings: map[string]string{"foo": "bar"}, + }, + }, { + name: "connector ID over 64 characters", + connID: "aaaaaaaaa1bbbbbbbbb2ccccccccc3ddddddddd4eeeeeeeee5fffffffff6ggggg", + errType: ErrIDOverLimit, + data: Config{ + Name: "test-connector", + Settings: map[string]string{"foo": "bar"}, + }, + }, { + name: "invalid characters in connector ID", + connID: "a%bc", + errType: ErrInvalidCharacters, + data: Config{ + Name: "test-connector", + Settings: map[string]string{"foo": "bar"}, + }, + }, { + name: "empty connector ID", + connID: "", + errType: ErrIDMissing, + data: Config{ + Name: "test-connector", + Settings: map[string]string{"foo": "bar"}, + }, + }} + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + got, err := service.Create( + ctx, + tt.connID, + TypeSource, + "test-plugin", + uuid.NewString(), + tt.data, + ProvisionTypeAPI, + ) + is.True(cerrors.Is(err, tt.errType)) + is.Equal(got, nil) + }) + } +} + func TestService_GetInstanceNotFound(t *testing.T) { is := is.New(t) ctx := context.Background() diff --git a/pkg/pipeline/errors.go b/pkg/pipeline/errors.go index f338f851f..e015fb5e5 100644 --- a/pkg/pipeline/errors.go +++ b/pkg/pipeline/errors.go @@ -17,13 +17,18 @@ package pipeline import "github.com/conduitio/conduit/pkg/foundation/cerrors" var ( - ErrTimeout = cerrors.New("operation timed out") - ErrGracefulShutdown = cerrors.New("graceful shutdown") - ErrPipelineRunning = cerrors.New("pipeline is running") - ErrPipelineNotRunning = cerrors.New("pipeline not running") - ErrInstanceNotFound = cerrors.New("pipeline instance not found") - ErrNameMissing = cerrors.New("must provide a pipeline name") - ErrNameAlreadyExists = cerrors.New("pipeline name already exists") - ErrConnectorIDNotFound = cerrors.New("connector ID not found") - ErrProcessorIDNotFound = cerrors.New("processor ID not found") + ErrTimeout = cerrors.New("operation timed out") + ErrGracefulShutdown = cerrors.New("graceful shutdown") + ErrPipelineRunning = cerrors.New("pipeline is running") + ErrPipelineNotRunning = cerrors.New("pipeline not running") + ErrInstanceNotFound = cerrors.New("pipeline instance not found") + ErrNameMissing = cerrors.New("must provide a pipeline name") + ErrIDMissing = cerrors.New("must provide a pipeline ID") + ErrNameAlreadyExists = cerrors.New("pipeline name already exists") + ErrInvalidCharacters = cerrors.New("pipeline ID contains invalid characters") + ErrNameOverLimit = cerrors.New("pipeline name is over the character limit (64)") + ErrIDOverLimit = cerrors.New("pipeline ID is over the character limit (64)") + ErrDescriptionOverLimit = cerrors.New("pipeline description is over the character limit (8192)") + ErrConnectorIDNotFound = cerrors.New("connector ID not found") + ErrProcessorIDNotFound = cerrors.New("processor ID not found") ) diff --git a/pkg/pipeline/service.go b/pkg/pipeline/service.go index 3e7f56a18..4628c3568 100644 --- a/pkg/pipeline/service.go +++ b/pkg/pipeline/service.go @@ -16,6 +16,7 @@ package pipeline import ( "context" + "regexp" "strings" "time" @@ -23,8 +24,11 @@ import ( "github.com/conduitio/conduit/pkg/foundation/database" "github.com/conduitio/conduit/pkg/foundation/log" "github.com/conduitio/conduit/pkg/foundation/metrics/measure" + "github.com/conduitio/conduit/pkg/foundation/multierror" ) +var idRegex = regexp.MustCompile(`^[A-Za-z0-9-_:]*$`) + type FailureEvent struct { // ID is the ID of the pipeline which failed. ID string @@ -111,11 +115,9 @@ func (s *Service) Get(_ context.Context, id string) (*Instance, error) { // Create will create a new pipeline instance with the given config and return // it if it was successfully saved to the database. func (s *Service) Create(ctx context.Context, id string, cfg Config, p ProvisionType) (*Instance, error) { - if cfg.Name == "" { - return nil, ErrNameMissing - } - if s.instanceNames[cfg.Name] { - return nil, ErrNameAlreadyExists + err := s.validatePipeline(cfg, id) + if err != nil { + return nil, cerrors.Errorf("pipeline is invalid: %w", err) } t := time.Now() @@ -129,7 +131,7 @@ func (s *Service) Create(ctx context.Context, id string, cfg Config, p Provision DLQ: DefaultDLQ, } - err := s.store.Set(ctx, pl.ID, pl) + err = s.store.Set(ctx, pl.ID, pl) if err != nil { return nil, cerrors.Errorf("failed to save pipeline with ID %q: %w", pl.ID, err) } @@ -326,3 +328,32 @@ func (s *Service) notify(pipelineID string, err error) { handler(e) } } +func (s *Service) validatePipeline(cfg Config, id string) error { + // contains all the errors occurred while provisioning configuration files. + var multierr error + + if cfg.Name == "" { + multierr = multierror.Append(multierr, ErrNameMissing) + } + if s.instanceNames[cfg.Name] { + multierr = multierror.Append(multierr, ErrNameAlreadyExists) + } + if len(cfg.Name) > 64 { + multierr = multierror.Append(multierr, ErrNameOverLimit) + } + if len(cfg.Description) > 8192 { + multierr = multierror.Append(multierr, ErrDescriptionOverLimit) + } + if id == "" { + multierr = multierror.Append(multierr, ErrIDMissing) + } + matched := idRegex.MatchString(id) + if !matched { + multierr = multierror.Append(multierr, ErrInvalidCharacters) + } + if len(id) > 64 { + multierr = multierror.Append(multierr, ErrIDOverLimit) + } + + return multierr +} diff --git a/pkg/pipeline/service_test.go b/pkg/pipeline/service_test.go index fad96b6d5..9f2ae91ba 100644 --- a/pkg/pipeline/service_test.go +++ b/pkg/pipeline/service_test.go @@ -146,6 +146,109 @@ func TestService_CreateSuccess(t *testing.T) { } } +func TestService_Create_ValidateSuccess(t *testing.T) { + is := is.New(t) + ctx := context.Background() + logger := log.Nop() + db := &inmemory.DB{} + + service := NewService(logger, db) + + testCases := []struct { + name string + connID string + data Config + }{{ + name: "valid config name", + connID: uuid.NewString(), + data: Config{ + Name: "Name#@-/_0%$", + Description: "", + }, + }, { + name: "valid connector ID", + connID: "Aa0-_:", + data: Config{ + Name: "test-connector", + Description: "", + }, + }} + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + got, err := service.Create( + ctx, + tt.connID, + tt.data, + ProvisionTypeAPI, + ) + is.True(got != nil) + is.Equal(err, nil) + }) + } +} + +func TestService_Create_ValidateError(t *testing.T) { + is := is.New(t) + ctx := context.Background() + logger := log.Nop() + db := &inmemory.DB{} + + service := NewService(logger, db) + + testCases := []struct { + name string + connID string + errType error + data Config + }{{ + name: "empty config name", + connID: uuid.NewString(), + errType: ErrNameMissing, + data: Config{ + Name: "", + Description: "", + }, + }, { + name: "pipeline name over 64 characters", + connID: uuid.NewString(), + errType: ErrNameOverLimit, + data: Config{ + Name: "aaaaaaaaa1bbbbbbbbb2ccccccccc3ddddddddd4eeeeeeeee5fffffffff6ggggg", + Description: "", + }, + }, { + name: "pipeline ID over 64 characters", + connID: "aaaaaaaaa1bbbbbbbbb2ccccccccc3ddddddddd4eeeeeeeee5fffffffff6ggggg", + errType: ErrIDOverLimit, + data: Config{ + Name: "test-connector", + Description: "", + }, + }, { + name: "invalid characters in connector ID", + connID: "a%bc", + errType: ErrInvalidCharacters, + data: Config{ + Name: "test-connector", + Description: "", + }, + }} + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + got, err := service.Create( + ctx, + tt.connID, + tt.data, + ProvisionTypeAPI, + ) + is.True(cerrors.Is(err, tt.errType)) + is.Equal(got, nil) + }) + } +} + func TestService_Create_PipelineNameExists(t *testing.T) { is := is.New(t) ctx := context.Background() diff --git a/pkg/web/api/connector_v1.go b/pkg/web/api/connector_v1.go index fcd77744e..077a8398f 100644 --- a/pkg/web/api/connector_v1.go +++ b/pkg/web/api/connector_v1.go @@ -128,7 +128,6 @@ func (c *ConnectorAPIv1) CreateConnector( req.PipelineId, fromproto.ConnectorConfig(req.Config), ) - if err != nil { return nil, status.ConnectorError(cerrors.Errorf("failed to create connector: %w", err)) }