-
Notifications
You must be signed in to change notification settings - Fork 416
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
223 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
package gochannel | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"sync" | ||
|
||
"github.com/ThreeDotsLabs/watermill" | ||
"github.com/ThreeDotsLabs/watermill/message" | ||
) | ||
|
||
// FanOut is a component that receives messages from the subscriber and passes them | ||
// to all publishers. In effect, messages are "multiplied". | ||
// | ||
// A typical use case for using FanOut is having one external subscription and multiple workers | ||
// inside the process. | ||
// | ||
// You need to call AddSubscription method for all topics that you want to listen to. | ||
// This needs to be done *before* starting the FanOut. | ||
// | ||
// FanOut exposes the standard Subscriber interface. | ||
type FanOut struct { | ||
internalPubSub *GoChannel | ||
internalRouter *message.Router | ||
|
||
subscriber message.Subscriber | ||
|
||
logger watermill.LoggerAdapter | ||
|
||
subscribedTopics map[string]struct{} | ||
subscribedLock sync.Mutex | ||
} | ||
|
||
// NewFanOut creates a new FanOut. | ||
func NewFanOut( | ||
subscriber message.Subscriber, | ||
logger watermill.LoggerAdapter, | ||
) (*FanOut, error) { | ||
if subscriber == nil { | ||
return nil, errors.New("missing subscriber") | ||
} | ||
if logger == nil { | ||
logger = watermill.NopLogger{} | ||
} | ||
|
||
router, err := message.NewRouter(message.RouterConfig{}, logger) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return &FanOut{ | ||
internalPubSub: NewGoChannel(Config{}, logger), | ||
internalRouter: router, | ||
|
||
subscriber: subscriber, | ||
|
||
logger: logger, | ||
|
||
subscribedTopics: map[string]struct{}{}, | ||
}, nil | ||
} | ||
|
||
// AddSubscription add an internal subscription for the given topic. | ||
// You need to call this method with all topics that you want to listen to, before the FanOut is started. | ||
// AddSubscription is idempotent. | ||
func (f *FanOut) AddSubscription(topic string) { | ||
f.subscribedLock.Lock() | ||
defer f.subscribedLock.Unlock() | ||
|
||
_, ok := f.subscribedTopics[topic] | ||
if ok { | ||
// Subscription already exists | ||
return | ||
} | ||
|
||
f.logger.Trace("Adding fan-out subscription for topic", watermill.LogFields{ | ||
"topic": topic, | ||
}) | ||
|
||
f.internalRouter.AddHandler( | ||
fmt.Sprintf("fanout-%s", topic), | ||
topic, | ||
f.subscriber, | ||
topic, | ||
f.internalPubSub, | ||
message.PassthroughHandler, | ||
) | ||
|
||
f.subscribedTopics[topic] = struct{}{} | ||
} | ||
|
||
// Run runs the FanOut. | ||
func (f *FanOut) Run(ctx context.Context) error { | ||
return f.internalRouter.Run(ctx) | ||
} | ||
|
||
// Running is closed when FanOut is running. | ||
func (f *FanOut) Running() chan struct{} { | ||
return f.internalRouter.Running() | ||
} | ||
|
||
func (f *FanOut) Subscribe(ctx context.Context, topic string) (<-chan *message.Message, error) { | ||
return f.internalPubSub.Subscribe(ctx, topic) | ||
} | ||
|
||
func (f *FanOut) Close() error { | ||
return f.internalPubSub.Close() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
package gochannel_test | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/ThreeDotsLabs/watermill" | ||
"github.com/ThreeDotsLabs/watermill/message" | ||
"github.com/ThreeDotsLabs/watermill/pubsub/gochannel" | ||
) | ||
|
||
func TestFanOut(t *testing.T) { | ||
logger := watermill.NopLogger{} | ||
|
||
upstreamPubSub := gochannel.NewGoChannel(gochannel.Config{}, logger) | ||
upstreamTopic := "upstream-topic" | ||
|
||
router, err := message.NewRouter(message.RouterConfig{}, logger) | ||
require.NoError(t, err) | ||
|
||
fanout, err := gochannel.NewFanOut(upstreamPubSub, logger) | ||
require.NoError(t, err) | ||
|
||
fanout.AddSubscription(upstreamTopic) | ||
|
||
workersCount := 10 | ||
messagesCount := 100 | ||
|
||
receivedMessages := make(chan struct{}, workersCount*messagesCount*2) | ||
|
||
for i := 0; i < workersCount; i++ { | ||
router.AddNoPublisherHandler( | ||
fmt.Sprintf("worker-%v", i), | ||
upstreamTopic, | ||
fanout, | ||
func(msg *message.Message) error { | ||
receivedMessages <- struct{}{} | ||
return nil | ||
}, | ||
) | ||
} | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) | ||
defer cancel() | ||
|
||
go func() { | ||
err := router.Run(ctx) | ||
require.NoError(t, err) | ||
}() | ||
|
||
go func() { | ||
err := fanout.Run(ctx) | ||
require.NoError(t, err) | ||
}() | ||
|
||
<-router.Running() | ||
<-fanout.Running() | ||
|
||
go func() { | ||
for i := 0; i < messagesCount; i++ { | ||
msg := message.NewMessage(watermill.NewUUID(), nil) | ||
err := upstreamPubSub.Publish(upstreamTopic, msg) | ||
if err != nil { | ||
panic(err) | ||
} | ||
} | ||
}() | ||
|
||
<-ctx.Done() | ||
|
||
counter := 0 | ||
|
||
loop: | ||
for { | ||
select { | ||
case <-receivedMessages: | ||
counter += 1 | ||
case <-time.After(time.Second): | ||
close(receivedMessages) | ||
break loop | ||
} | ||
} | ||
|
||
require.Equal(t, workersCount*messagesCount, counter) | ||
} | ||
|
||
func TestFanOut_RouterClosed(t *testing.T) { | ||
logger := watermill.NopLogger{} | ||
pubSub := gochannel.NewGoChannel(gochannel.Config{}, logger) | ||
|
||
fanout, err := gochannel.NewFanOut(pubSub, logger) | ||
require.NoError(t, err) | ||
|
||
fanout.AddSubscription("some-topic") | ||
|
||
go func() { | ||
err := fanout.Run(context.Background()) | ||
require.NoError(t, err) | ||
}() | ||
|
||
<-fanout.Running() | ||
|
||
err = fanout.Close() | ||
require.NoError(t, err) | ||
} |