From f124195ae9a62c4cc708a6d5d66f5429fbcca01d Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 26 Jan 2024 18:48:19 -0800 Subject: [PATCH 1/4] msgmux: add new abstract message router In this commit, we add a new abstract message router. Over time, the goal is that this message router replaces the logic we currently have in the readHandler (the giant switch for each message). With this new abstraction, can reduce the responsibilities of the readHandler to *just* reading messages off the wire and handing them off to the msg router. The readHandler no longer needs to know *where* the messages should go, or how they should be dispatched. This will be used in tandem with the new `protofsm` module in an upcoming PR implementing the new rbf-coop close. --- go.mod | 2 +- go.sum | 4 +- msgmux/log.go | 32 +++++ msgmux/msg_router.go | 274 ++++++++++++++++++++++++++++++++++++++ msgmux/msg_router_test.go | 157 ++++++++++++++++++++++ 5 files changed, 466 insertions(+), 3 deletions(-) create mode 100644 msgmux/log.go create mode 100644 msgmux/msg_router.go create mode 100644 msgmux/msg_router_test.go diff --git a/go.mod b/go.mod index 0e628d18b0..5671a4c423 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb github.com/lightningnetwork/lnd/cert v1.2.2 github.com/lightningnetwork/lnd/clock v1.1.1 - github.com/lightningnetwork/lnd/fn v1.2.0 + github.com/lightningnetwork/lnd/fn v1.2.1 github.com/lightningnetwork/lnd/healthcheck v1.2.5 github.com/lightningnetwork/lnd/kvdb v1.4.10 github.com/lightningnetwork/lnd/queue v1.1.1 diff --git a/go.sum b/go.sum index d556042dd3..0d49f76e37 100644 --- a/go.sum +++ b/go.sum @@ -450,8 +450,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ= -github.com/lightningnetwork/lnd/fn v1.2.0 h1:YTb2m8NN5ZiJAskHeBZAmR1AiPY8SXziIYPAX1VI/ZM= -github.com/lightningnetwork/lnd/fn v1.2.0/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0= +github.com/lightningnetwork/lnd/fn v1.2.1 h1:pPsVGrwi9QBwdLJzaEGK33wmiVKOxs/zc8H7+MamFf0= +github.com/lightningnetwork/lnd/fn v1.2.1/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0= github.com/lightningnetwork/lnd/healthcheck v1.2.5 h1:aTJy5xeBpcWgRtW/PGBDe+LMQEmNm/HQewlQx2jt7OA= github.com/lightningnetwork/lnd/healthcheck v1.2.5/go.mod h1:G7Tst2tVvWo7cx6mSBEToQC5L1XOGxzZTPB29g9Rv2I= github.com/lightningnetwork/lnd/kvdb v1.4.10 h1:vK89IVv1oVH9ubQWU+EmoCQFeVRaC8kfmOrqHbY5zoY= diff --git a/msgmux/log.go b/msgmux/log.go new file mode 100644 index 0000000000..63da4791f9 --- /dev/null +++ b/msgmux/log.go @@ -0,0 +1,32 @@ +package msgmux + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// Subsystem defines the logging code for this subsystem. +const Subsystem = "MSGX" + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// DisableLog disables all library log output. Logging output is disabled +// by default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/msgmux/msg_router.go b/msgmux/msg_router.go new file mode 100644 index 0000000000..cd69b14a61 --- /dev/null +++ b/msgmux/msg_router.go @@ -0,0 +1,274 @@ +package msgmux + +// For some reason golangci-lint has a false positive on the sort order of the +// imports for the new "maps" package... We need the nolint directive here to +// ignore that. +// +//nolint:gci +import ( + "fmt" + "maps" + "sync" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // ErrDuplicateEndpoint is returned when an endpoint is registered with + // a name that already exists. + ErrDuplicateEndpoint = fmt.Errorf("endpoint already registered") + + // ErrUnableToRouteMsg is returned when a message is unable to be + // routed to any endpoints. + ErrUnableToRouteMsg = fmt.Errorf("unable to route message") +) + +// EndpointName is the name of a given endpoint. This MUST be unique across all +// registered endpoints. +type EndpointName = string + +// PeerMsg is a wire message that includes the public key of the peer that sent +// it. +type PeerMsg struct { + lnwire.Message + + // PeerPub is the public key of the peer that sent this message. + PeerPub btcec.PublicKey +} + +// Endpoint is an interface that represents a message endpoint, or the +// sub-system that will handle processing an incoming wire message. +type Endpoint interface { + // Name returns the name of this endpoint. This MUST be unique across + // all registered endpoints. + Name() EndpointName + + // CanHandle returns true if the target message can be routed to this + // endpoint. + CanHandle(msg PeerMsg) bool + + // SendMessage handles the target message, and returns true if the + // message was able being processed. + SendMessage(msg PeerMsg) bool +} + +// MsgRouter is an interface that represents a message router, which is generic +// sub-system capable of routing any incoming wire message to a set of +// registered endpoints. +type Router interface { + // RegisterEndpoint registers a new endpoint with the router. If a + // duplicate endpoint exists, an error is returned. + RegisterEndpoint(Endpoint) error + + // UnregisterEndpoint unregisters the target endpoint from the router. + UnregisterEndpoint(EndpointName) error + + // RouteMsg attempts to route the target message to a registered + // endpoint. If ANY endpoint could handle the message, then nil is + // returned. Otherwise, ErrUnableToRouteMsg is returned. + RouteMsg(PeerMsg) error + + // Start starts the peer message router. + Start() + + // Stop stops the peer message router. + Stop() +} + +// sendQuery sends a query to the main event loop, and returns the response. +func sendQuery[Q any, R any](sendChan chan fn.Req[Q, R], queryArg Q, + quit chan struct{}) fn.Result[R] { + + query, respChan := fn.NewReq[Q, R](queryArg) + + if !fn.SendOrQuit(sendChan, query, quit) { + return fn.Errf[R]("router shutting down") + } + + return fn.NewResult(fn.RecvResp(respChan, nil, quit)) +} + +// sendQueryErr is a helper function based on sendQuery that can be used when +// the query only needs an error response. +func sendQueryErr[Q any](sendChan chan fn.Req[Q, error], queryArg Q, + quitChan chan struct{}) error { + + return fn.ElimEither( + fn.Iden, fn.Iden, + sendQuery(sendChan, queryArg, quitChan).Either, + ) +} + +// EndpointsMap is a map of all registered endpoints. +type EndpointsMap map[EndpointName]Endpoint + +// MultiMsgRouter is a type of message router that is capable of routing new +// incoming messages, permitting a message to be routed to multiple registered +// endpoints. +type MultiMsgRouter struct { + startOnce sync.Once + stopOnce sync.Once + + // registerChan is the channel that all new endpoints will be sent to. + registerChan chan fn.Req[Endpoint, error] + + // unregisterChan is the channel that all endpoints that are to be + // removed are sent to. + unregisterChan chan fn.Req[EndpointName, error] + + // msgChan is the channel that all messages will be sent to for + // processing. + msgChan chan fn.Req[PeerMsg, error] + + // endpointsQueries is a channel that all queries to the endpoints map + // will be sent to. + endpointQueries chan fn.Req[Endpoint, EndpointsMap] + + wg sync.WaitGroup + quit chan struct{} +} + +// NewMultiMsgRouter creates a new instance of a peer message router. +func NewMultiMsgRouter() *MultiMsgRouter { + return &MultiMsgRouter{ + registerChan: make(chan fn.Req[Endpoint, error]), + unregisterChan: make(chan fn.Req[EndpointName, error]), + msgChan: make(chan fn.Req[PeerMsg, error]), + endpointQueries: make(chan fn.Req[Endpoint, EndpointsMap]), + quit: make(chan struct{}), + } +} + +// Start starts the peer message router. +func (p *MultiMsgRouter) Start() { + log.Infof("Starting Router") + + p.startOnce.Do(func() { + p.wg.Add(1) + go p.msgRouter() + }) +} + +// Stop stops the peer message router. +func (p *MultiMsgRouter) Stop() { + log.Infof("Stopping Router") + + p.stopOnce.Do(func() { + close(p.quit) + p.wg.Wait() + }) +} + +// RegisterEndpoint registers a new endpoint with the router. If a duplicate +// endpoint exists, an error is returned. +func (p *MultiMsgRouter) RegisterEndpoint(endpoint Endpoint) error { + return sendQueryErr(p.registerChan, endpoint, p.quit) +} + +// UnregisterEndpoint unregisters the target endpoint from the router. +func (p *MultiMsgRouter) UnregisterEndpoint(name EndpointName) error { + return sendQueryErr(p.unregisterChan, name, p.quit) +} + +// RouteMsg attempts to route the target message to a registered endpoint. If +// ANY endpoint could handle the message, then nil is returned. +func (p *MultiMsgRouter) RouteMsg(msg PeerMsg) error { + return sendQueryErr(p.msgChan, msg, p.quit) +} + +// Endpoints returns a list of all registered endpoints. +func (p *MultiMsgRouter) endpoints() fn.Result[EndpointsMap] { + return sendQuery(p.endpointQueries, nil, p.quit) +} + +// msgRouter is the main goroutine that handles all incoming messages. +func (p *MultiMsgRouter) msgRouter() { + defer p.wg.Done() + + // endpoints is a map of all registered endpoints. + endpoints := make(map[EndpointName]Endpoint) + + for { + select { + // A new endpoint was just sent in, so we'll add it to our set + // of registered endpoints. + case newEndpointMsg := <-p.registerChan: + endpoint := newEndpointMsg.Request + + log.Infof("MsgRouter: registering new "+ + "Endpoint(%s)", endpoint.Name()) + + // If this endpoint already exists, then we'll return + // an error as we require unique names. + if _, ok := endpoints[endpoint.Name()]; ok { + log.Errorf("MsgRouter: rejecting "+ + "duplicate endpoint: %v", + endpoint.Name()) + + newEndpointMsg.Resolve(ErrDuplicateEndpoint) + + continue + } + + endpoints[endpoint.Name()] = endpoint + + newEndpointMsg.Resolve(nil) + + // A request to unregister an endpoint was just sent in, so + // we'll attempt to remove it. + case endpointName := <-p.unregisterChan: + delete(endpoints, endpointName.Request) + + log.Infof("MsgRouter: unregistering "+ + "Endpoint(%s)", endpointName.Request) + + endpointName.Resolve(nil) + + // A new message was just sent in. We'll attempt to route it to + // all the endpoints that can handle it. + case msgQuery := <-p.msgChan: + msg := msgQuery.Request + + // Loop through all the endpoints and send the message + // to those that can handle it the message. + var couldSend bool + for _, endpoint := range endpoints { + if endpoint.CanHandle(msg) { + log.Tracef("MsgRouter: sending "+ + "msg %T to endpoint %s", msg, + endpoint.Name()) + + sent := endpoint.SendMessage(msg) + couldSend = couldSend || sent + } + } + + var err error + if !couldSend { + log.Tracef("MsgRouter: unable to route "+ + "msg %T", msg) + + err = ErrUnableToRouteMsg + } + + msgQuery.Resolve(err) + + // A query for the endpoint state just came in, we'll send back + // a copy of our current state. + case endpointQuery := <-p.endpointQueries: + endpointsCopy := make(EndpointsMap, len(endpoints)) + maps.Copy(endpointsCopy, endpoints) + + endpointQuery.Resolve(endpointsCopy) + + case <-p.quit: + return + } + } +} + +// A compile time check to ensure MultiMsgRouter implements the MsgRouter +// interface. +var _ Router = (*MultiMsgRouter)(nil) diff --git a/msgmux/msg_router_test.go b/msgmux/msg_router_test.go new file mode 100644 index 0000000000..af8cdedef6 --- /dev/null +++ b/msgmux/msg_router_test.go @@ -0,0 +1,157 @@ +package msgmux + +import ( + "testing" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockEndpoint struct { + mock.Mock +} + +func (m *mockEndpoint) Name() string { + args := m.Called() + + return args.String(0) +} + +func (m *mockEndpoint) CanHandle(msg PeerMsg) bool { + args := m.Called(msg) + + return args.Bool(0) +} + +func (m *mockEndpoint) SendMessage(msg PeerMsg) bool { + args := m.Called(msg) + + return args.Bool(0) +} + +// TestMessageRouterOperation tests the basic operation of the message router: +// add new endpoints, route to them, remove, them, etc. +func TestMessageRouterOperation(t *testing.T) { + msgRouter := NewMultiMsgRouter() + msgRouter.Start() + defer msgRouter.Stop() + + openChanMsg := PeerMsg{ + Message: &lnwire.OpenChannel{}, + } + commitSigMsg := PeerMsg{ + Message: &lnwire.CommitSig{}, + } + + errorMsg := PeerMsg{ + Message: &lnwire.Error{}, + } + + // For this test, we'll have two endpoints, each with distinct names. + // One endpoint will only handle OpenChannel, while the other will + // handle the CommitSig message. + fundingEndpoint := &mockEndpoint{} + fundingEndpointName := "funding" + fundingEndpoint.On("Name").Return(fundingEndpointName) + fundingEndpoint.On("CanHandle", openChanMsg).Return(true) + fundingEndpoint.On("CanHandle", errorMsg).Return(false) + fundingEndpoint.On("CanHandle", commitSigMsg).Return(false) + fundingEndpoint.On("SendMessage", openChanMsg).Return(true) + + commitEndpoint := &mockEndpoint{} + commitEndpointName := "commit" + commitEndpoint.On("Name").Return(commitEndpointName) + commitEndpoint.On("CanHandle", commitSigMsg).Return(true) + commitEndpoint.On("CanHandle", openChanMsg).Return(false) + commitEndpoint.On("CanHandle", errorMsg).Return(false) + commitEndpoint.On("SendMessage", commitSigMsg).Return(true) + + t.Run("add endpoints", func(t *testing.T) { + // First, we'll add the funding endpoint to the router. + require.NoError(t, msgRouter.RegisterEndpoint(fundingEndpoint)) + + endpoints, err := msgRouter.endpoints().Unpack() + require.NoError(t, err) + + // There should be a single endpoint registered. + require.Len(t, endpoints, 1) + + // The name of the registered endpoint should be "funding". + require.Equal( + t, "funding", endpoints[fundingEndpointName].Name(), + ) + }) + + t.Run("duplicate endpoint reject", func(t *testing.T) { + // Next, we'll attempt to add the funding endpoint again. This + // should return an ErrDuplicateEndpoint error. + require.ErrorIs( + t, msgRouter.RegisterEndpoint(fundingEndpoint), + ErrDuplicateEndpoint, + ) + }) + + t.Run("route to endpoint", func(t *testing.T) { + // Next, we'll add our other endpoint, then attempt to route a + // message. + require.NoError(t, msgRouter.RegisterEndpoint(commitEndpoint)) + + // If we try to route a message none of the endpoints know of, + // we should get an error. + require.ErrorIs( + t, msgRouter.RouteMsg(errorMsg), ErrUnableToRouteMsg, + ) + + fundingEndpoint.AssertCalled(t, "CanHandle", errorMsg) + commitEndpoint.AssertCalled(t, "CanHandle", errorMsg) + + // Next, we'll route the open channel message. Only the + // fundingEndpoint should be used. + require.NoError(t, msgRouter.RouteMsg(openChanMsg)) + + fundingEndpoint.AssertCalled(t, "CanHandle", openChanMsg) + commitEndpoint.AssertCalled(t, "CanHandle", openChanMsg) + + fundingEndpoint.AssertCalled(t, "SendMessage", openChanMsg) + commitEndpoint.AssertNotCalled(t, "SendMessage", openChanMsg) + + // We'll do the same for the commit sig message. + require.NoError(t, msgRouter.RouteMsg(commitSigMsg)) + + fundingEndpoint.AssertCalled(t, "CanHandle", commitSigMsg) + commitEndpoint.AssertCalled(t, "CanHandle", commitSigMsg) + + commitEndpoint.AssertCalled(t, "SendMessage", commitSigMsg) + fundingEndpoint.AssertNotCalled(t, "SendMessage", commitSigMsg) + }) + + t.Run("remove endpoints", func(t *testing.T) { + // Finally, we'll remove both endpoints. + require.NoError( + t, msgRouter.UnregisterEndpoint(fundingEndpointName), + ) + require.NoError( + t, msgRouter.UnregisterEndpoint(commitEndpointName), + ) + + endpoints, err := msgRouter.endpoints().Unpack() + require.NoError(t, err) + + // There should be no endpoints registered. + require.Len(t, endpoints, 0) + + // Trying to route a message should fail. + require.ErrorIs( + t, msgRouter.RouteMsg(openChanMsg), + ErrUnableToRouteMsg, + ) + require.ErrorIs( + t, msgRouter.RouteMsg(commitSigMsg), + ErrUnableToRouteMsg, + ) + }) + + commitEndpoint.AssertExpectations(t) + fundingEndpoint.AssertExpectations(t) +} From 927aa84b5fcb716e34182f8aba6256bf2f633118 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 30 Jan 2024 18:00:11 -0800 Subject: [PATCH 2/4] peer: update readHandler to dispatch to msgRouter if set Over time with this, we should be able to significantly reduce the size of the peer.Brontide struct as we only need all those deps as the peer needs to recognize and handle each incoming wire message itself. --- peer/brontide.go | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/peer/brontide.go b/peer/brontide.go index 920a1bcca8..d561948502 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -42,6 +42,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chancloser" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/msgmux" "github.com/lightningnetwork/lnd/netann" "github.com/lightningnetwork/lnd/pool" "github.com/lightningnetwork/lnd/queue" @@ -522,6 +523,10 @@ type Brontide struct { // potentially holding lots of un-consumed events. channelEventClient *subscribe.Client + // msgRouter is an instance of the msgmux.Router which is used to send + // off new wire messages for handing. + msgRouter fn.Option[msgmux.Router] + startReady chan struct{} quit chan struct{} wg sync.WaitGroup @@ -559,6 +564,9 @@ func NewBrontide(cfg Config) *Brontide { startReady: make(chan struct{}), quit: make(chan struct{}), log: build.NewPrefixLog(logPrefix, peerLog), + msgRouter: fn.Some[msgmux.Router]( + msgmux.NewMultiMsgRouter(), + ), } if cfg.Conn != nil && cfg.Conn.RemoteAddr() != nil { @@ -738,6 +746,12 @@ func (p *Brontide) Start() error { return err } + // Register the message router now as we may need to register some + // endpoints while loading the channels below. + p.msgRouter.WhenSome(func(router msgmux.Router) { + router.Start() + }) + msgs, err := p.loadActiveChannels(activeChans) if err != nil { return fmt.Errorf("unable to load channels: %w", err) @@ -913,7 +927,8 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( p.cfg.Signer, dbChan, p.cfg.SigPool, ) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create channel "+ + "state machine: %w", err) } chanPoint := dbChan.FundingOutpoint @@ -1368,6 +1383,10 @@ func (p *Brontide) Disconnect(reason error) { p.cfg.Conn.Close() close(p.quit) + + p.msgRouter.WhenSome(func(router msgmux.Router) { + router.Stop() + }) } // String returns the string representation of this peer. @@ -1809,6 +1828,22 @@ out: } } + // If a message router is active, then we'll try to have it + // handle this message. If it can, then we're able to skip the + // rest of the message handling logic. + err = fn.MapOptionZ(p.msgRouter, func(r msgmux.Router) error { + return r.RouteMsg(msgmux.PeerMsg{ + PeerPub: *p.IdentityKey(), + Message: nextMsg, + }) + }) + + // No error occurred, and the message was handled by the + // router. + if err == nil { + continue + } + var ( targetChan lnwire.ChannelID isLinkUpdate bool From b028af1836b22a2550cb4ad47b1e1d52f500d5d0 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 4 Apr 2024 16:13:30 -0700 Subject: [PATCH 3/4] multi: make MsgRouter available in the ImplementationCfg With this commit, we allow the `MsgRouter` to be available in the `ImplementationCfg`. With this, programs outside of lnd itself are able to now hook into the message processing flow to direct handle custom messages, and even normal wire messages. --- config_builder.go | 14 ++++++++++++++ lnd.go | 1 + peer/brontide.go | 15 ++++++++++++--- server.go | 7 ++++++- 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/config_builder.go b/config_builder.go index bf6274cdf5..bef59b9a04 100644 --- a/config_builder.go +++ b/config_builder.go @@ -33,6 +33,7 @@ import ( "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -42,6 +43,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/btcwallet" "github.com/lightningnetwork/lnd/lnwallet/rpcwallet" "github.com/lightningnetwork/lnd/macaroons" + "github.com/lightningnetwork/lnd/msgmux" "github.com/lightningnetwork/lnd/rpcperms" "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/sqldb" @@ -118,6 +120,14 @@ type ChainControlBuilder interface { *btcwallet.Config) (*chainreg.ChainControl, func(), error) } +// AuxComponents is a set of auxiliary components that can be used by lnd for +// certain custom channel types. +type AuxComponents struct { + // MsgRouter is an optional message router that if set will be used in + // place of a new blank default message router. + MsgRouter fn.Option[msgmux.Router] +} + // ImplementationCfg is a struct that holds all configuration items for // components that can be implemented outside lnd itself. type ImplementationCfg struct { @@ -144,6 +154,10 @@ type ImplementationCfg struct { // ChainControlBuilder is a type that can provide a custom wallet // implementation. ChainControlBuilder + + // AuxComponents is a set of auxiliary components that can be used by + // lnd for certain custom channel types. + AuxComponents } // DefaultWalletImpl is the default implementation of our normal, btcwallet diff --git a/lnd.go b/lnd.go index 38f5c0d759..e483d5512f 100644 --- a/lnd.go +++ b/lnd.go @@ -600,6 +600,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, activeChainControl.Cfg.WalletUnlockParams.ChansToRestore, multiAcceptor, torController, tlsManager, leaderElector, + implCfg, ) if err != nil { return mkErr("unable to create server: %v", err) diff --git a/peer/brontide.go b/peer/brontide.go index d561948502..a6cb4da889 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -387,6 +387,11 @@ type Config struct { // This value will be passed to created links. MaxFeeExposure lnwire.MilliSatoshi + // MsgRouter is an optional instance of the main message router that + // the peer will use. If None, then a new default version will be used + // in place. + MsgRouter fn.Option[msgmux.Router] + // Quit is the server's quit channel. If this is closed, we halt operation. Quit chan struct{} } @@ -542,6 +547,12 @@ var _ lnpeer.Peer = (*Brontide)(nil) func NewBrontide(cfg Config) *Brontide { logPrefix := fmt.Sprintf("Peer(%x):", cfg.PubKeyBytes) + // We'll either use the msg router instance passed in, or create a new + // blank instance. + msgRouter := cfg.MsgRouter.Alt(fn.Some[msgmux.Router]( + msgmux.NewMultiMsgRouter(), + )) + p := &Brontide{ cfg: cfg, activeSignal: make(chan struct{}), @@ -564,9 +575,7 @@ func NewBrontide(cfg Config) *Brontide { startReady: make(chan struct{}), quit: make(chan struct{}), log: build.NewPrefixLog(logPrefix, peerLog), - msgRouter: fn.Some[msgmux.Router]( - msgmux.NewMultiMsgRouter(), - ), + msgRouter: msgRouter, } if cfg.Conn != nil && cfg.Conn.RemoteAddr() != nil { diff --git a/server.go b/server.go index 555ee26274..0cedb2f48b 100644 --- a/server.go +++ b/server.go @@ -160,6 +160,8 @@ type server struct { cfg *Config + implCfg *ImplementationCfg + // identityECDH is an ECDH capable wrapper for the private key used // to authenticate any incoming connections. identityECDH keychain.SingleKeyECDH @@ -486,7 +488,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, chansToRestore walletunlocker.ChannelsToRecover, chanPredicate chanacceptor.ChannelAcceptor, torController *tor.Controller, tlsManager *TLSManager, - leaderElector cluster.LeaderElector) (*server, error) { + leaderElector cluster.LeaderElector, + implCfg *ImplementationCfg) (*server, error) { var ( err error @@ -571,6 +574,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s := &server{ cfg: cfg, + implCfg: implCfg, graphDB: dbs.GraphDB.ChannelGraph(), chanStateDB: dbs.ChanStateDB.ChannelStateDB(), addrSource: dbs.ChanStateDB, @@ -3986,6 +3990,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, DisallowRouteBlinding: s.cfg.ProtocolOptions.NoRouteBlinding(), MaxFeeExposure: thresholdMSats, Quit: s.quit, + MsgRouter: s.implCfg.MsgRouter, } copy(pCfg.PubKeyBytes[:], peerAddr.IdentityKey.SerializeCompressed()) From f09c517bc6efb5a015616e9afe331fb1d3809003 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 9 Jul 2024 14:05:33 -0700 Subject: [PATCH 4/4] peer: don't stop global msg router In this commit, we fix a bug that would cause a global message router to be stopped anytime a peer disconnected. The global msg router only allows `Start` to be called once, so afterwards, no messages would properly be routed. --- peer/brontide.go | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/peer/brontide.go b/peer/brontide.go index a6cb4da889..25c7cea6f2 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -532,6 +532,11 @@ type Brontide struct { // off new wire messages for handing. msgRouter fn.Option[msgmux.Router] + // globalMsgRouter is a flag that indicates whether we have a global + // msg router. If so, then we don't worry about stopping the msg router + // when a peer disconnects. + globalMsgRouter bool + startReady chan struct{} quit chan struct{} wg sync.WaitGroup @@ -547,6 +552,11 @@ var _ lnpeer.Peer = (*Brontide)(nil) func NewBrontide(cfg Config) *Brontide { logPrefix := fmt.Sprintf("Peer(%x):", cfg.PubKeyBytes) + // We have a global message router if one was passed in via the config. + // In this case, we don't need to attempt to tear it down when the peer + // is stopped. + globalMsgRouter := cfg.MsgRouter.IsSome() + // We'll either use the msg router instance passed in, or create a new // blank instance. msgRouter := cfg.MsgRouter.Alt(fn.Some[msgmux.Router]( @@ -576,6 +586,7 @@ func NewBrontide(cfg Config) *Brontide { quit: make(chan struct{}), log: build.NewPrefixLog(logPrefix, peerLog), msgRouter: msgRouter, + globalMsgRouter: globalMsgRouter, } if cfg.Conn != nil && cfg.Conn.RemoteAddr() != nil { @@ -1393,9 +1404,13 @@ func (p *Brontide) Disconnect(reason error) { close(p.quit) - p.msgRouter.WhenSome(func(router msgmux.Router) { - router.Stop() - }) + // If our msg router isn't global (local to this instance), then we'll + // stop it. Otherwise, we'll leave it running. + if !p.globalMsgRouter { + p.msgRouter.WhenSome(func(router msgmux.Router) { + router.Stop() + }) + } } // String returns the string representation of this peer.