Skip to content

Commit

Permalink
Merge pull request #8660 from GeorgeTsagk/interceptor-wire-records
Browse files Browse the repository at this point in the history
Enhance `update_add_htlc` with remote peer's custom records
  • Loading branch information
guggero committed May 22, 2024
2 parents b009db3 + 1b1969b commit 966f41f
Show file tree
Hide file tree
Showing 18 changed files with 1,228 additions and 778 deletions.
22 changes: 22 additions & 0 deletions channeldb/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2550,6 +2550,12 @@ type HTLC struct {
// HTLC. It is stored in the ExtraData field, which is used to store
// a TLV stream of additional information associated with the HTLC.
BlindingPoint lnwire.BlindingPointRecord

// CustomRecords is a set of custom TLV records that are associated with
// this HTLC. These records are used to store additional information
// about the HTLC that is not part of the standard HTLC fields. This
// field is encoded within the ExtraData field.
CustomRecords lnwire.CustomRecords
}

// serializeExtraData encodes a TLV stream of extra data to be stored with a
Expand All @@ -2568,6 +2574,11 @@ func (h *HTLC) serializeExtraData() error {
records = append(records, &b)
})

records, err := h.CustomRecords.ExtendRecordProducers(records)
if err != nil {
return err
}

return h.ExtraData.PackRecords(records...)
}

Expand All @@ -2589,7 +2600,18 @@ func (h *HTLC) deserializeExtraData() error {

if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil {
h.BlindingPoint = tlv.SomeRecordT(blindingPoint)

// Remove the entry from the TLV map. Anything left in the map
// will be included in the custom records field.
delete(tlvMap, h.BlindingPoint.TlvType())
}

// Set the custom records field to the remaining TLV records.
customRecords, err := lnwire.NewCustomRecordsFromTlvTypeMap(tlvMap)
if err != nil {
return err
}
h.CustomRecords = customRecords

return nil
}
Expand Down
84 changes: 47 additions & 37 deletions htlcswitch/interceptable_switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"sync"

"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb/models"
Expand Down Expand Up @@ -622,15 +623,16 @@ func (f *interceptedForward) Packet() InterceptedPacket {
ChanID: f.packet.incomingChanID,
HtlcID: f.packet.incomingHTLCID,
},
OutgoingChanID: f.packet.outgoingChanID,
Hash: f.htlc.PaymentHash,
OutgoingExpiry: f.htlc.Expiry,
OutgoingAmount: f.htlc.Amount,
IncomingAmount: f.packet.incomingAmount,
IncomingExpiry: f.packet.incomingTimeout,
CustomRecords: f.packet.customRecords,
OnionBlob: f.htlc.OnionBlob,
AutoFailHeight: f.autoFailHeight,
OutgoingChanID: f.packet.outgoingChanID,
Hash: f.htlc.PaymentHash,
OutgoingExpiry: f.htlc.Expiry,
OutgoingAmount: f.htlc.Amount,
IncomingAmount: f.packet.incomingAmount,
IncomingExpiry: f.packet.incomingTimeout,
CustomRecords: f.packet.customRecords,
OnionBlob: f.htlc.OnionBlob,
AutoFailHeight: f.autoFailHeight,
IncomingWireCustomRecords: f.packet.incomingCustomRecords,
}
}

Expand Down Expand Up @@ -659,50 +661,58 @@ func (f *interceptedForward) ResumeModified(
htlc.Amount = amount
})

//nolint:lll
err := fn.MapOptionZ(customRecords, func(records record.CustomSet) error {
if len(records) == 0 {
return nil
}
err := fn.MapOptionZ(
customRecords, func(records record.CustomSet) error {
if len(records) == 0 {
return nil
}

// Type cast and validate custom records.
htlc.CustomRecords = lnwire.CustomRecords(records)
err := htlc.CustomRecords.Validate()
if err != nil {
return fmt.Errorf("failed to validate custom "+
"records: %w", err)
}
// Type cast and validate custom records.
htlc.CustomRecords = lnwire.CustomRecords(
records,
)
err := htlc.CustomRecords.Validate()
if err != nil {
return fmt.Errorf("failed to validate "+
"custom records: %w", err)
}

return nil
})
return nil
},
)
if err != nil {
return fmt.Errorf("failed to encode custom records: %w",
err)
}

case *lnwire.UpdateFulfillHTLC:
//nolint:lll
err := fn.MapOptionZ(customRecords, func(records record.CustomSet) error {
if len(records) == 0 {
return nil
}
err := fn.MapOptionZ(
customRecords, func(records record.CustomSet) error {
if len(records) == 0 {
return nil
}

// Type cast and validate custom records.
htlc.CustomRecords = lnwire.CustomRecords(records)
err := htlc.CustomRecords.Validate()
if err != nil {
return fmt.Errorf("failed to validate custom "+
"records: %w", err)
}
// Type cast and validate custom records.
htlc.CustomRecords = lnwire.CustomRecords(
records,
)
err := htlc.CustomRecords.Validate()
if err != nil {
return fmt.Errorf("failed to validate "+
"custom records: %w", err)
}

return nil
})
return nil
},
)
if err != nil {
return fmt.Errorf("failed to encode custom records: %w",
err)
}
}

log.Tracef("Forwarding packet %v", spew.Sdump(f.packet))

// Forward to the switch. A link quit channel isn't needed, because we
// are on a different thread now.
return f.htlcSwitch.ForwardPackets(nil, f.packet)
Expand Down
4 changes: 4 additions & 0 deletions htlcswitch/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,10 @@ type InterceptedPacket struct {
// OnionBlob is the onion packet for the next hop
OnionBlob [lnwire.OnionPacketSize]byte

// IncomingWireCustomRecords are user-defined records that were defined
// by the peer that forwarded this htlc to us.
IncomingWireCustomRecords record.CustomSet

// AutoFailHeight is the block height at which this intercept will be
// failed back automatically.
AutoFailHeight int32
Expand Down
77 changes: 52 additions & 25 deletions htlcswitch/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ import (
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/ticker"
"github.com/lightningnetwork/lnd/tlv"
)

func init() {
Expand Down Expand Up @@ -3354,6 +3356,27 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
continue
}

var customRecords record.CustomSet
err = fn.MapOptionZ(
pd.CustomRecords, func(b tlv.Blob) error {
r, err := lnwire.ParseCustomRecords(b)
if err != nil {
return err
}

customRecords = record.CustomSet(r)

return nil
},
)
if err != nil {
l.fail(LinkFailureError{
code: ErrInternalError,
}, err.Error())

return
}

switch fwdPkg.State {
case channeldb.FwdStateProcessed:
// This add was not forwarded on the previous
Expand All @@ -3367,7 +3390,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
}

// Otherwise, it was already processed, we can
// can collect it and continue.
// collect it and continue.
addMsg := &lnwire.UpdateAddHTLC{
Expiry: fwdInfo.OutgoingCTLV,
Amount: fwdInfo.AmountToForward,
Expand All @@ -3387,19 +3410,21 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,

inboundFee := l.cfg.FwrdingPolicy.InboundFee

//nolint:lll
updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef,
incomingAmount: pd.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(),
inboundFee: inboundFee,
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef,
incomingAmount: pd.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(),
inboundFee: inboundFee,
incomingCustomRecords: customRecords,
}
switchPackets = append(
switchPackets, updatePacket,
Expand Down Expand Up @@ -3455,19 +3480,21 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
if fwdPkg.State == channeldb.FwdStateLockedIn {
inboundFee := l.cfg.FwrdingPolicy.InboundFee

//nolint:lll
updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef,
incomingAmount: pd.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(),
inboundFee: inboundFee,
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef,
incomingAmount: pd.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(),
inboundFee: inboundFee,
incomingCustomRecords: customRecords,
}

fwdPkg.FwdFilter.Set(idx)
Expand Down
4 changes: 4 additions & 0 deletions htlcswitch/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ type htlcPacket struct {
// were included in the payload.
customRecords record.CustomSet

// incomingCustomRecords are custom type range TLVs that are included
// in the incoming update_add_htlc.
incomingCustomRecords record.CustomSet

// originalOutgoingChanID is used when sending back failure messages.
// It is only used for forwarded Adds on option_scid_alias channels.
// This is to avoid possible confusion if a payer uses the public SCID
Expand Down
4 changes: 4 additions & 0 deletions itest/list_on_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,10 @@ var allTestCases = []*lntest.TestCase{
Name: "forward interceptor modified htlc",
TestFunc: testForwardInterceptorModifiedHtlc,
},
{
Name: "forward interceptor wire records",
TestFunc: testForwardInterceptorWireRecords,
},
{
Name: "zero conf channel open",
TestFunc: testZeroConfChannelOpen,
Expand Down
Loading

0 comments on commit 966f41f

Please sign in to comment.