-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
htlcswitch: use fn.GoroutineManager #9140
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ package htlcswitch | |
|
||
import ( | ||
"bytes" | ||
"context" | ||
"errors" | ||
"fmt" | ||
"math/rand" | ||
|
@@ -85,6 +86,9 @@ var ( | |
// fail payments if they increase our fee exposure. This is currently | ||
// set to 500m msats. | ||
DefaultMaxFeeExposure = lnwire.MilliSatoshi(500_000_000) | ||
|
||
// background is a shortcut for context.Background. | ||
background = context.Background() | ||
) | ||
|
||
// plexPacket encapsulates switch packet and adds error channel to receive | ||
|
@@ -245,8 +249,8 @@ type Switch struct { | |
// This will be retrieved by the registered links atomically. | ||
bestHeight uint32 | ||
|
||
wg sync.WaitGroup | ||
quit chan struct{} | ||
// gm starts and stops tasks in goroutines and waits for them. | ||
gm *fn.GoroutineManager | ||
|
||
// cfg is a copy of the configuration struct that the htlc switch | ||
// service was initialized with. | ||
|
@@ -368,8 +372,11 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { | |
return nil, err | ||
} | ||
|
||
gm := fn.NewGoroutineManager() | ||
|
||
s := &Switch{ | ||
bestHeight: currentHeight, | ||
gm: gm, | ||
cfg: &cfg, | ||
circuits: circuitMap, | ||
linkIndex: make(map[lnwire.ChannelID]ChannelLink), | ||
|
@@ -382,7 +389,6 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { | |
chanCloseRequests: make(chan *ChanClose), | ||
resolutionMsgs: make(chan *resolutionMsg), | ||
resMsgStore: resStore, | ||
quit: make(chan struct{}), | ||
} | ||
|
||
s.aliasToReal = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID) | ||
|
@@ -420,14 +426,14 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro | |
ResolutionMsg: msg, | ||
errChan: errChan, | ||
}: | ||
case <-s.quit: | ||
case <-s.gm.Done(): | ||
return ErrSwitchExiting | ||
} | ||
|
||
select { | ||
case err := <-errChan: | ||
return err | ||
case <-s.quit: | ||
case <-s.gm.Done(): | ||
return ErrSwitchExiting | ||
} | ||
} | ||
|
@@ -493,14 +499,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, | |
// Since the attempt was known, we can start a goroutine that can | ||
// extract the result when it is available, and pass it on to the | ||
// caller. | ||
s.wg.Add(1) | ||
go func() { | ||
defer s.wg.Done() | ||
|
||
ok := s.gm.Go(background, func(ctx context.Context) { | ||
var n *networkResult | ||
select { | ||
case n = <-nChan: | ||
case <-s.quit: | ||
case <-s.gm.Done(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think it is not great to refer to |
||
// We close the result channel to signal a shutdown. We | ||
// don't send any result in this case since the HTLC is | ||
// still in flight. | ||
|
@@ -524,7 +527,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, | |
return | ||
} | ||
resultChan <- result | ||
}() | ||
}) | ||
// The switch shutting down is signaled by closing the channel. | ||
if !ok { | ||
close(resultChan) | ||
} | ||
|
||
return resultChan, nil | ||
} | ||
|
@@ -704,12 +711,19 @@ func (s *Switch) ForwardPackets(linkQuit <-chan struct{}, | |
select { | ||
case <-linkQuit: | ||
return nil | ||
case <-s.quit: | ||
|
||
case <-s.gm.Done(): | ||
return nil | ||
|
||
default: | ||
// Spawn a goroutine to log the errors returned from failed packets. | ||
s.wg.Add(1) | ||
go s.logFwdErrs(&numSent, &wg, fwdChan) | ||
// Spawn a goroutine to log the errors returned from failed | ||
// packets. | ||
ok := s.gm.Go(background, func(ctx context.Context) { | ||
s.logFwdErrs(ctx, &numSent, &wg, fwdChan) | ||
}) | ||
if !ok { | ||
return nil | ||
ellemouton marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
// Make a first pass over the packets, forwarding any settles or fails. | ||
|
@@ -820,8 +834,8 @@ func (s *Switch) ForwardPackets(linkQuit <-chan struct{}, | |
} | ||
|
||
// logFwdErrs logs any errors received on `fwdChan`. | ||
func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) { | ||
defer s.wg.Done() | ||
func (s *Switch) logFwdErrs(ctx context.Context, num *int, wg *sync.WaitGroup, | ||
fwdChan chan error) { | ||
|
||
// Wait here until the outer function has finished persisting | ||
// and routing the packets. This guarantees we don't read from num until | ||
|
@@ -836,7 +850,8 @@ func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) { | |
log.Errorf("Unhandled error while reforwarding htlc "+ | ||
"settle/fail over htlcswitch: %v", err) | ||
} | ||
case <-s.quit: | ||
|
||
case <-s.gm.Done(): | ||
log.Errorf("unable to forward htlc packet " + | ||
"htlc switch was stopped") | ||
return | ||
|
@@ -862,7 +877,7 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error, | |
return nil | ||
case <-linkQuit: | ||
return ErrLinkShuttingDown | ||
case <-s.quit: | ||
case <-s.gm.Done(): | ||
return errors.New("htlc switch was stopped") | ||
} | ||
} | ||
|
@@ -940,8 +955,6 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) ( | |
// | ||
// NOTE: This method MUST be spawned as a goroutine. | ||
func (s *Switch) handleLocalResponse(pkt *htlcPacket) { | ||
defer s.wg.Done() | ||
|
||
attemptID := pkt.incomingHTLCID | ||
|
||
// The error reason will be unencypted in case this a local | ||
|
@@ -1436,7 +1449,7 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, | |
case s.chanCloseRequests <- command: | ||
return updateChan, errChan | ||
|
||
case <-s.quit: | ||
case <-s.gm.Done(): | ||
errChan <- ErrSwitchExiting | ||
close(updateChan) | ||
return updateChan, errChan | ||
|
@@ -1454,8 +1467,6 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, | |
// | ||
// NOTE: This MUST be run as a goroutine. | ||
func (s *Switch) htlcForwarder() { | ||
defer s.wg.Done() | ||
|
||
defer func() { | ||
s.blockEpochStream.Cancel() | ||
|
||
|
@@ -1489,6 +1500,8 @@ func (s *Switch) htlcForwarder() { | |
var wg sync.WaitGroup | ||
for _, link := range linksToStop { | ||
wg.Add(1) | ||
// Here it is ok to start a goroutine directly bypassing | ||
// s.gm, because we want for them to complete here. | ||
go func(l ChannelLink) { | ||
defer wg.Done() | ||
|
||
|
@@ -1628,15 +1641,16 @@ out: | |
// collect all the forwarding events since the last internal, | ||
// and write them out to our log. | ||
case <-s.cfg.FwdEventTicker.Ticks(): | ||
s.wg.Add(1) | ||
go func() { | ||
defer s.wg.Done() | ||
|
||
if err := s.FlushForwardingEvents(); err != nil { | ||
// The error of Go is ignored: if it is shutting down, | ||
// the loop will terminate on the next iteration, in | ||
// s.gm.Done case. | ||
ellemouton marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_ = s.gm.Go(background, func(ctx context.Context) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let |
||
err := s.FlushForwardingEvents() | ||
if err != nil { | ||
log.Errorf("Unable to flush "+ | ||
"forwarding events: %v", err) | ||
} | ||
}() | ||
}) | ||
|
||
// The log ticker has fired, so we'll calculate some forwarding | ||
// stats for the last 10 seconds to display within the logs to | ||
|
@@ -1739,7 +1753,7 @@ out: | |
// memory. | ||
s.pendingSettleFails = s.pendingSettleFails[:0] | ||
|
||
case <-s.quit: | ||
case <-s.gm.Done(): | ||
return | ||
} | ||
} | ||
|
@@ -1749,6 +1763,7 @@ out: | |
func (s *Switch) Start() error { | ||
if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { | ||
log.Warn("Htlc Switch already started") | ||
|
||
return errors.New("htlc switch already started") | ||
} | ||
|
||
|
@@ -1760,19 +1775,32 @@ func (s *Switch) Start() error { | |
} | ||
s.blockEpochStream = blockEpochStream | ||
|
||
s.wg.Add(1) | ||
go s.htlcForwarder() | ||
ok := s.gm.Go(background, func(ctx context.Context) { | ||
s.htlcForwarder() | ||
}) | ||
if !ok { | ||
// We are already stopping so we can ignore the error. | ||
_ = s.Stop() | ||
err = fmt.Errorf("unable to start htlc forwarder: %w", | ||
ErrSwitchExiting) | ||
log.Errorf("%v", err) | ||
|
||
return err | ||
} | ||
|
||
if err := s.reforwardResponses(); err != nil { | ||
s.Stop() | ||
// We are already stopping so we can ignore the error. | ||
_ = s.Stop() | ||
log.Errorf("unable to reforward responses: %v", err) | ||
|
||
return err | ||
} | ||
|
||
if err := s.reforwardResolutions(); err != nil { | ||
// We are already stopping so we can ignore the error. | ||
_ = s.Stop() | ||
log.Errorf("unable to reforward resolutions: %v", err) | ||
|
||
return err | ||
} | ||
|
||
|
@@ -1991,9 +2019,8 @@ func (s *Switch) Stop() error { | |
log.Info("HTLC Switch shutting down...") | ||
defer log.Debug("HTLC Switch shutdown complete") | ||
|
||
close(s.quit) | ||
|
||
s.wg.Wait() | ||
// Ask running goroutines to stop and wait for them. | ||
s.gm.Stop() | ||
|
||
// Wait until all active goroutines have finished exiting before | ||
// stopping the mailboxes, otherwise the mailbox map could still be | ||
|
@@ -2349,7 +2376,7 @@ func (s *Switch) RemoveLink(chanID lnwire.ChannelID) { | |
select { | ||
case <-stopChan: | ||
return | ||
case <-s.quit: | ||
case <-s.gm.Done(): | ||
return | ||
} | ||
} | ||
|
@@ -3020,8 +3047,12 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error { | |
// NOTE: `closeCircuit` modifies the state of `packet`. | ||
if localHTLC { | ||
// TODO(yy): remove the goroutine and send back the error here. | ||
s.wg.Add(1) | ||
go s.handleLocalResponse(packet) | ||
ok := s.gm.Go(background, func(ctx context.Context) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rather pass in a context to the calling func. Same for all the others |
||
s.handleLocalResponse(packet) | ||
}) | ||
if !ok { | ||
return ErrSwitchExiting | ||
} | ||
|
||
// If this is a locally initiated HTLC, there's no need to | ||
// forward it so we exit. | ||
|
@@ -3076,8 +3107,12 @@ func (s *Switch) handlePacketFail(packet *htlcPacket, | |
// NOTE: `closeCircuit` modifies the state of `packet`. | ||
if packet.incomingChanID == hop.Source { | ||
// TODO(yy): remove the goroutine and send back the error here. | ||
s.wg.Add(1) | ||
go s.handleLocalResponse(packet) | ||
ok := s.gm.Go(background, func(ctx context.Context) { | ||
s.handleLocalResponse(packet) | ||
}) | ||
if !ok { | ||
return ErrSwitchExiting | ||
} | ||
|
||
// If this is a locally initiated HTLC, there's no need to | ||
// forward it so we exit. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i dont think we should do this. Rather use a
context.TODO()
where neededThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you rebase on top of #9344, then we can also add a context guard here and then we only need a single context.TODO() in
Start()