diff --git a/gbus/saga/glue.go b/gbus/saga/glue.go index 1c2c732..6d68ff7 100644 --- a/gbus/saga/glue.go +++ b/gbus/saga/glue.go @@ -106,7 +106,6 @@ func (imsm *Glue) handler(invocation gbus.Invocation, message *gbus.BusMessage) imsm.lock.Lock() defer imsm.lock.Unlock() msgName := message.PayloadFQN - exchange, routingKey := invocation.Routing() defs := imsm.msgToDefMap[strings.ToLower(msgName)] @@ -125,7 +124,7 @@ func (imsm *Glue) handler(invocation gbus.Invocation, message *gbus.BusMessage) imsm.log(). WithFields(log.Fields{"saga_def": def.String(), "saga_id": newInstance.ID}). Info("created new saga") - if invkErr := newInstance.invoke(exchange, routingKey, invocation, message); invkErr != nil { + if invkErr := imsm.invokeSagaInstance(newInstance, invocation, message); invkErr != nil { imsm.log().WithError(invkErr).WithField("saga_id", newInstance.ID).Error("failed to invoke saga") return invkErr } @@ -156,7 +155,7 @@ func (imsm *Glue) handler(invocation gbus.Invocation, message *gbus.BusMessage) return e } - if invkErr := instance.invoke(exchange, routingKey, invocation, message); invkErr != nil { + if invkErr := imsm.invokeSagaInstance(instance, invocation, message); invkErr != nil { imsm.log().WithError(invkErr).WithField("saga_id", instance.ID).Error("failed to invoke saga") return invkErr } @@ -178,7 +177,7 @@ func (imsm *Glue) handler(invocation gbus.Invocation, message *gbus.BusMessage) for _, instance := range instances { - if invkErr := instance.invoke(exchange, routingKey, invocation, message); invkErr != nil { + if invkErr := imsm.invokeSagaInstance(instance, invocation, message); invkErr != nil { imsm.log().WithError(invkErr).WithField("saga_id", instance.ID).Error("failed to invoke saga") return invkErr } @@ -193,6 +192,19 @@ func (imsm *Glue) handler(invocation gbus.Invocation, message *gbus.BusMessage) return nil } +func (imsm *Glue) invokeSagaInstance(instance *Instance, invocation gbus.Invocation, message *gbus.BusMessage) error { + sginv := &sagaInvocation{ + decoratedBus: invocation.Bus(), + decoratedInvocation: invocation, + inboundMsg: message, + sagaID: instance.ID, + ctx: invocation.Ctx(), + invokingService: imsm.svcName} + + exchange, routingKey := invocation.Routing() + return instance.invoke(exchange, routingKey, sginv, message) +} + func (imsm *Glue) completeOrUpdateSaga(tx *sql.Tx, instance *Instance, lastMessage *gbus.BusMessage) error { _, timedOut := lastMessage.Payload.(gbus.SagaTimeoutMessage) diff --git a/gbus/saga/instance.go b/gbus/saga/instance.go index 50c8d1c..5cdf99d 100644 --- a/gbus/saga/instance.go +++ b/gbus/saga/instance.go @@ -30,18 +30,12 @@ func (si *Instance) invoke(exchange, routingKey string, invocation gbus.Invocati } valueOfMessage := reflect.ValueOf(message) - sginv := &sagaInvocation{ - decoratedBus: invocation.Bus(), - decoratedInvocation: invocation, - inboundMsg: message, - sagaID: si.ID, - ctx: invocation.Ctx(), - } + reflectedVal := reflect.ValueOf(si.UnderlyingInstance) for _, methodName := range methodsToInvoke { params := make([]reflect.Value, 0) - params = append(params, reflect.ValueOf(sginv), valueOfMessage) + params = append(params, reflect.ValueOf(invocation), valueOfMessage) method := reflectedVal.MethodByName(methodName) log.Printf(" invoking method %v on saga instance %v", methodName, si.ID) returns := method.Call(params) diff --git a/gbus/saga/invocation.go b/gbus/saga/invocation.go index ea59d5a..f2725af 100644 --- a/gbus/saga/invocation.go +++ b/gbus/saga/invocation.go @@ -14,19 +14,26 @@ type sagaInvocation struct { inboundMsg *gbus.BusMessage sagaID string ctx context.Context + invokingService string } func (si *sagaInvocation) setCorrelationIDs(message *gbus.BusMessage, isEvent bool) { message.CorrelationID = si.inboundMsg.ID + message.SagaID = si.sagaID if !isEvent { //support saga-to-saga communication if si.inboundMsg.SagaID != "" { + message.SagaCorrelationID = si.inboundMsg.SagaID + } + //if the saga is potentially invoking itself then set the SagaCorrelationID to reflect that + //https://github.com/wework/grabbit/issues/64 + _, targetService := si.decoratedInvocation.Routing() + if targetService == si.invokingService { message.SagaCorrelationID = message.SagaID } - message.SagaID = si.sagaID } } diff --git a/tests/saga_test.go b/tests/saga_test.go index e26ea73..3c86325 100644 --- a/tests/saga_test.go +++ b/tests/saga_test.go @@ -1,7 +1,9 @@ package tests import ( + "context" "log" + "reflect" "testing" "time" @@ -224,6 +226,37 @@ func TestSagaTimeout(t *testing.T) { <-proceed } +func TestSagaSelfMessaging(t *testing.T) { + proceed := make(chan bool) + b := createNamedBusForTest(testSvc1) + + handler := func(invocation gbus.Invocation, message *gbus.BusMessage) error { + + _, ok := message.Payload.(*Event1) + if !ok { + t.Errorf("handler invoced with wrong message type\r\nexpeted:%v\r\nactual:%v", reflect.TypeOf(Command1{}), reflect.TypeOf(message.Payload)) + } + proceed <- true + + return nil + } + + err := b.HandleEvent("test_exchange", "test_topic", Event1{}, handler) + if err != nil { + t.Errorf("Registering handler returned false, expected true with error: %s", err.Error()) + } + + b.RegisterSaga(&SelfSendingSaga{}) + + b.Start() + defer b.Shutdown() + + b.Send(context.TODO(), testSvc1, gbus.NewBusMessage(Command1{})) + + <-proceed + +} + /*Test Sagas*/ type SagaA struct { @@ -370,3 +403,40 @@ func (s *TimingOutSaga) Timeout(invocation gbus.Invocation, message *gbus.BusMes Data: "TimingOutSaga.Timeout", })) } + +type SelfSendingSaga struct { +} + +func (*SelfSendingSaga) StartedBy() []gbus.Message { + starters := make([]gbus.Message, 0) + return append(starters, Command1{}) +} + +func (s *SelfSendingSaga) IsComplete() bool { + return false +} + +func (s *SelfSendingSaga) New() gbus.Saga { + return &SelfSendingSaga{} +} + +func (s *SelfSendingSaga) RegisterAllHandlers(register gbus.HandlerRegister) { + register.HandleMessage(Command1{}, s.HandleCommand1) + register.HandleMessage(Command2{}, s.HandleCommand2) + register.HandleMessage(Reply2{}, s.HandleReply2) +} + +func (s *SelfSendingSaga) HandleCommand1(invocation gbus.Invocation, message *gbus.BusMessage) error { + cmd2 := gbus.NewBusMessage(Command2{}) + return invocation.Bus().Send(invocation.Ctx(), testSvc1, cmd2) +} + +func (s *SelfSendingSaga) HandleCommand2(invocation gbus.Invocation, message *gbus.BusMessage) error { + reply := gbus.NewBusMessage(Reply2{}) + return invocation.Reply(invocation.Ctx(), reply) +} + +func (s *SelfSendingSaga) HandleReply2(invocation gbus.Invocation, message *gbus.BusMessage) error { + evt1 := gbus.NewBusMessage(Event1{}) + return invocation.Bus().Publish(invocation.Ctx(), "test_exchange", "test_topic", evt1) +}