From c05a78c7717db9b97b4fa4b035206ea03ba6ae4f Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 4 Aug 2023 08:55:10 +0800 Subject: [PATCH] refactor msg channel using singleton pattern and write better unit testing Signed-off-by: Future Outlier --- pkg/async/notifications/factory.go | 16 +++++++++++-- .../implementations/sandbox_processor.go | 10 ++++---- .../implementations/sandbox_processor_test.go | 4 ++-- .../implementations/sandbox_publisher.go | 14 ++++++----- .../implementations/sandbox_publisher_test.go | 24 +++---------------- 5 files changed, 33 insertions(+), 35 deletions(-) diff --git a/pkg/async/notifications/factory.go b/pkg/async/notifications/factory.go index aba360eb5..a2fd4c357 100644 --- a/pkg/async/notifications/factory.go +++ b/pkg/async/notifications/factory.go @@ -3,6 +3,7 @@ package notifications import ( "context" "fmt" + "sync" "time" "github.com/flyteorg/flyteadmin/pkg/async" @@ -27,6 +28,9 @@ const maxRetries = 3 var enable64decoding = false +var msgChan chan []byte +var once sync.Once + type PublisherConfig struct { TopicName string } @@ -41,6 +45,13 @@ type EmailerConfig struct { BaseURL string } +// For sandbox only +func CreateMsgChan() { + once.Do(func() { + msgChan = make(chan []byte) + }) +} + func GetEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Emailer { // If an external email service is specified use that instead. // TODO: Handling of this is messy, see https://github.com/flyteorg/flyte/issues/1063 @@ -122,7 +133,7 @@ func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, sco return implementations.NewGcpProcessor(sub, emailer, scope) case common.Sandbox: emailer = GetEmailer(config, scope) - return implementations.NewSandboxProcessor(emailer) + return implementations.NewSandboxProcessor(msgChan, emailer) case common.Local: fallthrough default: @@ -175,7 +186,8 @@ func NewNotificationsPublisher(config runtimeInterfaces.NotificationsConfig, sco } return implementations.NewPublisher(publisher, scope) case common.Sandbox: - return implementations.NewSandboxPublisher() + CreateMsgChan() + return implementations.NewSandboxPublisher(msgChan) case common.Local: fallthrough default: diff --git a/pkg/async/notifications/implementations/sandbox_processor.go b/pkg/async/notifications/implementations/sandbox_processor.go index c654c6c3b..112c27901 100644 --- a/pkg/async/notifications/implementations/sandbox_processor.go +++ b/pkg/async/notifications/implementations/sandbox_processor.go @@ -12,7 +12,8 @@ import ( ) type SandboxProcessor struct { - email interfaces.Emailer + email interfaces.Emailer + msgChan chan []byte } func (p *SandboxProcessor) StartProcessing() { @@ -29,7 +30,7 @@ func (p *SandboxProcessor) run() error { for { select { - case msg := <-msgChan: + case msg := <-p.msgChan: err := proto.Unmarshal(msg, &emailMessage) if err != nil { logger.Errorf(context.Background(), "error with unmarshalling message [%v]", err) @@ -53,8 +54,9 @@ func (p *SandboxProcessor) StopProcessing() error { return nil } -func NewSandboxProcessor(emailer interfaces.Emailer) interfaces.Processor { +func NewSandboxProcessor(msgChan chan []byte, emailer interfaces.Emailer) interfaces.Processor { return &SandboxProcessor{ - email: emailer, + msgChan: msgChan, + email: emailer, } } diff --git a/pkg/async/notifications/implementations/sandbox_processor_test.go b/pkg/async/notifications/implementations/sandbox_processor_test.go index 9ce752097..45f96f226 100644 --- a/pkg/async/notifications/implementations/sandbox_processor_test.go +++ b/pkg/async/notifications/implementations/sandbox_processor_test.go @@ -24,8 +24,8 @@ func TestSandboxProcessor_UnmarshalMessage(t *testing.T) { } func TestSandboxProcessor_StartProcessing(t *testing.T) { - - testSandboxProcessor := NewSandboxProcessor(&mockSandboxEmailer) + msgChan := make(chan []byte, 1) + testSandboxProcessor := NewSandboxProcessor(msgChan, &mockSandboxEmailer) sendEmailValidationFunc := func(ctx context.Context, email admin.EmailMessage) error { assert.Equal(t, testEmail.Body, email.Body) diff --git a/pkg/async/notifications/implementations/sandbox_publisher.go b/pkg/async/notifications/implementations/sandbox_publisher.go index 735dc054b..a50e5fb36 100644 --- a/pkg/async/notifications/implementations/sandbox_publisher.go +++ b/pkg/async/notifications/implementations/sandbox_publisher.go @@ -7,9 +7,9 @@ import ( "github.com/golang/protobuf/proto" ) -type SandboxPublisher struct{} - -var msgChan = make(chan []byte) +type SandboxPublisher struct { + msgChan chan []byte +} func (p *SandboxPublisher) Publish(ctx context.Context, notificationType string, msg proto.Message) error { logger.Debugf(ctx, "Publishing the following message [%s]", msg.String()) @@ -21,11 +21,13 @@ func (p *SandboxPublisher) Publish(ctx context.Context, notificationType string, return err } - msgChan <- data + p.msgChan <- data return nil } -func NewSandboxPublisher() *SandboxPublisher { - return &SandboxPublisher{} +func NewSandboxPublisher(msgChan chan []byte) *SandboxPublisher { + return &SandboxPublisher{ + msgChan: msgChan, + } } diff --git a/pkg/async/notifications/implementations/sandbox_publisher_test.go b/pkg/async/notifications/implementations/sandbox_publisher_test.go index 97c103666..fed562a0a 100644 --- a/pkg/async/notifications/implementations/sandbox_publisher_test.go +++ b/pkg/async/notifications/implementations/sandbox_publisher_test.go @@ -3,34 +3,16 @@ package implementations import ( "context" "testing" - "time" "github.com/stretchr/testify/assert" ) func TestSandboxPublisher_Publish(t *testing.T) { - publisher := NewSandboxPublisher() - - errChan := make(chan string) - - go func() { - select { - case <-msgChan: - // if message received, no need to send an error - case <-time.After(time.Second * 5): - errChan <- "No data was received in the channel within the expected time frame" - } - }() + msgChan := make(chan []byte, 1) + publisher := NewSandboxPublisher(msgChan) err := publisher.Publish(context.Background(), "NOTIFICATION_TYPE", &testEmail) - // Check if there was an error in the goroutine - select { - case errMsg := <-errChan: - t.Fatal(errMsg) - default: - // no error from the goroutine - } - + assert.NotZero(t, len(msgChan)) assert.Nil(t, err) }