Skip to content

Commit

Permalink
Improve Origin and Upstream (#2872)
Browse files Browse the repository at this point in the history
We move the `Upstream` trait closer to the `Origin`, and make it more
obvious than a hot `Origin` is:

- an `Upstream` referencing the upstream HTLCs
- an actor requesting the outgoing payment

We also improve the cold trampoline relay class to record the incoming
HTLC amount, which we previously didn't bother encoding but is useful to
compute the fees collected during relay. To ensure backwards-compat, it
is set to `0 msat` for pending HTLCs. It will only affect HTLCs that
were pending during the upgrade, which is acceptable.
  • Loading branch information
t-bast authored Jun 27, 2024
1 parent c53b32c commit 791edf7
Show file tree
Hide file tree
Showing 39 changed files with 481 additions and 403 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ trait CustomCommitmentsPlugin extends PluginParams {
* result upstream to preserve channels. If you have non-standard HTLCs that may be in this situation, they should be
* returned by this method.
*/
def getHtlcsRelayedOut(htlcsIn: Seq[IncomingHtlc], nodeParams: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]]
def getHtlcsRelayedOut(htlcsIn: Seq[IncomingHtlc], nodeParams: NodeParams, log: LoggingAdapter): Map[Origin.Cold, Set[(ByteVector32, Long)]]
}

// @formatter:off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ object CheckBalance {
// and succeed if they were sent from this node
val htlcOut = localCommit.spec.htlcs.collect(outgoing)
.filterNot(htlc => htlcsOutOnChain.contains(htlc.id)) // we filter the htlc that already pay us on-chain
.filterNot(htlc => originChannels.get(htlc.id).exists(_.isInstanceOf[Origin.Local]))
.filterNot(htlc => originChannels.get(htlc.id).exists(_.upstream.isInstanceOf[Upstream.Local]))
.filterNot(htlc => remoteHasPreimage(changes, htlc.id))
.sumAmount
// all claim txs have possibly been published
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@ import fr.acinq.eclair.channel.LocalFundingStatus.DualFundedUnconfirmedFundingTx
import fr.acinq.eclair.channel.fund.InteractiveTxBuilder._
import fr.acinq.eclair.channel.fund.{InteractiveTxBuilder, InteractiveTxSigningSession}
import fr.acinq.eclair.io.Peer
import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream
import fr.acinq.eclair.transactions.CommitmentSpec
import fr.acinq.eclair.transactions.Transactions._
import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelReady, ChannelReestablish, ChannelUpdate, ClosingSigned, CommitSig, FailureMessage, FundingCreated, FundingSigned, Init, OnionRoutingPacket, OpenChannel, OpenDualFundedChannel, Shutdown, SpliceInit, Stfu, TxSignatures, UpdateAddHtlc, UpdateFailHtlc, UpdateFailMalformedHtlc, UpdateFulfillHtlc}
import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, Features, InitFeature, MilliSatoshi, MilliSatoshiLong, RealShortChannelId, UInt64}
import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, Features, InitFeature, MilliSatoshi, MilliSatoshiLong, RealShortChannelId, TimestampMilli, UInt64}
import scodec.bits.ByteVector

import java.util.UUID
Expand Down Expand Up @@ -134,51 +133,58 @@ case class INPUT_RESTORED(data: PersistentChannelData)
"Y8888P" "Y88888P" 888 888 888 888 d88P 888 888 Y888 8888888P" "Y8888P"
*/

/** Detailed upstream parent(s) of a payment in the HTLC chain. */
sealed trait Upstream { def amountIn: MilliSatoshi }
object Upstream {
/** We haven't restarted and have full information about the upstream parent(s). */
sealed trait Hot extends Upstream
object Hot {
/** Our node is forwarding a single incoming HTLC. */
case class Channel(add: UpdateAddHtlc, receivedAt: TimestampMilli) extends Hot {
override val amountIn: MilliSatoshi = add.amountMsat
val expiryIn: CltvExpiry = add.cltvExpiry
}
/** Our node is forwarding a payment based on a set of HTLCs from potentially multiple upstream channels. */
case class Trampoline(received: Seq[Channel]) extends Hot {
override val amountIn: MilliSatoshi = received.map(_.add.amountMsat).sum
// We must use the lowest expiry of the incoming HTLC set.
val expiryIn: CltvExpiry = received.map(_.add.cltvExpiry).min
val receivedAt: TimestampMilli = received.map(_.receivedAt).max
}
}

/** We have restarted and stored limited information about the upstream parent(s). */
sealed trait Cold extends Upstream
object Cold {
def apply(hot: Hot): Cold = hot match {
case Local(id) => Local(id)
case Hot.Channel(add, _) => Cold.Channel(add.channelId, add.id, add.amountMsat)
case Hot.Trampoline(received) => Cold.Trampoline(received.map(r => Cold.Channel(r.add.channelId, r.add.id, r.add.amountMsat)).toList)
}

/** Our node is forwarding a single incoming HTLC. */
case class Channel(originChannelId: ByteVector32, originHtlcId: Long, amountIn: MilliSatoshi) extends Cold
/** Our node is forwarding a payment based on a set of HTLCs from potentially multiple upstream channels. */
case class Trampoline(originHtlcs: List[Channel]) extends Cold { override val amountIn: MilliSatoshi = originHtlcs.map(_.amountIn).sum }
}

/** Our node is the origin of the payment: there are no matching upstream HTLCs. */
case class Local(id: UUID) extends Hot with Cold { override val amountIn: MilliSatoshi = 0 msat }
}

/**
* Origin of a payment, answering both questions:
* - what actor in the app sent that htlc? (Origin.replyTo)
* - what are the upstream parent(s) of this payment in the htlc chain?
* - what actor in the app sent that htlc and is waiting for its result?
* - what are the upstream parent(s) of this payment in the htlc chain?
*/
sealed trait Origin
sealed trait Origin { def upstream: Upstream }
object Origin {
/** We haven't restarted since we sent the payment downstream: the origin actor is known. */
sealed trait Hot extends Origin { def replyTo: ActorRef }
case class Hot(replyTo: ActorRef, upstream: Upstream.Hot) extends Origin
/** We have restarted after the payment was sent, we have limited info and the origin actor doesn't exist anymore. */
sealed trait Cold extends Origin

/** Our node is the origin of the payment. */
sealed trait Local extends Origin { def id: UUID }
case class LocalHot(replyTo: ActorRef, id: UUID) extends Local with Hot
case class LocalCold(id: UUID) extends Local with Cold

/** Our node forwarded a single incoming HTLC to an outgoing channel. */
sealed trait ChannelRelayed extends Origin {
def originChannelId: ByteVector32
def originHtlcId: Long
def amountIn: MilliSatoshi
def amountOut: MilliSatoshi
}
case class ChannelRelayedHot(replyTo: ActorRef, add: UpdateAddHtlc, override val amountOut: MilliSatoshi) extends ChannelRelayed with Hot {
override def originChannelId: ByteVector32 = add.channelId
override def originHtlcId: Long = add.id
override def amountIn: MilliSatoshi = add.amountMsat
}
case class ChannelRelayedCold(originChannelId: ByteVector32, originHtlcId: Long, amountIn: MilliSatoshi, amountOut: MilliSatoshi) extends ChannelRelayed with Cold

/** Our node forwarded an incoming HTLC set to a remote outgoing node (potentially producing multiple downstream HTLCs).*/
sealed trait TrampolineRelayed extends Origin { def htlcs: List[(ByteVector32, Long)] }
case class TrampolineRelayedHot(replyTo: ActorRef, adds: Seq[UpdateAddHtlc]) extends TrampolineRelayed with Hot {
override def htlcs: List[(ByteVector32, Long)] = adds.map(u => (u.channelId, u.id)).toList
val amountIn: MilliSatoshi = adds.map(_.amountMsat).sum
val expiryIn: CltvExpiry = adds.map(_.cltvExpiry).min
}
case class TrampolineRelayedCold(override val htlcs: List[(ByteVector32, Long)]) extends TrampolineRelayed with Cold

object Hot {
def apply(replyTo: ActorRef, upstream: Upstream): Hot = upstream match {
case u: Upstream.Local => Origin.LocalHot(replyTo, u.id)
case u: Upstream.Trampoline => Origin.TrampolineRelayedHot(replyTo, u.adds.map(_.add))
}
case class Cold(upstream: Upstream.Cold) extends Origin
object Cold {
def apply(hot: Hot): Cold = Cold(Upstream.Cold(hot.upstream))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1781,7 +1781,7 @@ class Channel(val nodeParams: NodeParams, val wallet: OnChainChannelFunder with
// for our outgoing payments, let's send events if we know that they will settle on chain
Closing
.onChainOutgoingHtlcs(d.commitments.latest.localCommit, d.commitments.latest.remoteCommit, d.commitments.latest.nextRemoteCommit_opt.map(_.commit), tx)
.map(add => (add, d.commitments.originChannels.get(add.id).collect { case o: Origin.Local => o.id })) // we resolve the payment id if this was a local payment
.map(add => (add, d.commitments.originChannels.get(add.id).map(_.upstream).collect { case Upstream.Local(id) => id })) // we resolve the payment id if this was a local payment
.collect { case (add, Some(id)) => context.system.eventStream.publish(PaymentSettlingOnChain(id, amount = add.amountMsat, add.paymentHash)) }
// then let's see if any of the possible close scenarios can be considered done
val closingType_opt = Closing.isClosed(d1, Some(tx))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,17 +514,37 @@ object ChannelEventSerializer extends MinimalSerializer({
})

object OriginSerializer extends MinimalSerializer({
case o: Origin.Local => JObject(JField("paymentId", JString(o.id.toString)))
case o: Origin.ChannelRelayed => JObject(
JField("channelId", JString(o.originChannelId.toHex)),
JField("htlcId", JLong(o.originHtlcId)),
)
case o: Origin.TrampolineRelayed => JArray(o.htlcs.map {
case (channelId, htlcId) => JObject(
JField("channelId", JString(channelId.toHex)),
JField("htlcId", JLong(htlcId)),
case o: Origin => o.upstream match {
case u: Upstream.Local => JObject(JField("paymentId", JString(u.id.toString)))
case u: Upstream.Hot.Channel => JObject(
JField("channelId", JString(u.add.channelId.toHex)),
JField("htlcId", JLong(u.add.id)),
JField("amount", JLong(u.add.amountMsat.toLong)),
JField("expiry", JLong(u.add.cltvExpiry.toLong)),
JField("receivedAt", JLong(u.receivedAt.toLong)),
)
})
case u: Upstream.Hot.Trampoline => JArray(u.received.map { htlc =>
JObject(
JField("channelId", JString(htlc.add.channelId.toHex)),
JField("htlcId", JLong(htlc.add.id)),
JField("amount", JLong(htlc.add.amountMsat.toLong)),
JField("expiry", JLong(htlc.add.cltvExpiry.toLong)),
JField("receivedAt", JLong(htlc.receivedAt.toLong)),
)
}.toList)
case o: Upstream.Cold.Channel => JObject(
JField("channelId", JString(o.originChannelId.toHex)),
JField("htlcId", JLong(o.originHtlcId)),
JField("amount", JLong(o.amountIn.toLong)),
)
case o: Upstream.Cold.Trampoline => JArray(o.originHtlcs.map { htlc =>
JObject(
JField("channelId", JString(htlc.originChannelId.toHex)),
JField("htlcId", JLong(htlc.originHtlcId)),
JField("amount", JLong(htlc.amountIn.toLong)),
)
}.toList)
}
})

// @formatter:off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package fr.acinq.eclair.payment

import akka.actor.ActorRef
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.eclair.channel.{CMD_ADD_HTLC, CMD_FAIL_HTLC, CannotExtractSharedSecret, Origin}
Expand All @@ -26,11 +25,10 @@ import fr.acinq.eclair.router.Router.Route
import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.OutgoingBlindedPaths
import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, PerHopPayload}
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, TimestampMilli, UInt64, randomKey}
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64, randomKey}
import scodec.bits.ByteVector
import scodec.{Attempt, DecodeResult}

import java.util.UUID
import scala.util.{Failure, Success}

/**
Expand Down Expand Up @@ -263,17 +261,6 @@ object OutgoingPaymentPacket {
case class MissingTrampolineHop(trampolineNodeId: PublicKey) extends OutgoingPaymentError { override def getMessage: String = s"expected route to trampoline node $trampolineNodeId" }
case class MissingBlindedHop(introductionNodeIds: Set[PublicKey]) extends OutgoingPaymentError { override def getMessage: String = s"expected blinded route using one of the following introduction nodes: ${introductionNodeIds.mkString(", ")}" }
case object EmptyRoute extends OutgoingPaymentError { override def getMessage: String = "route cannot be empty" }

sealed trait Upstream
object Upstream {
case class Local(id: UUID) extends Upstream
case class Trampoline(adds: Seq[ReceivedHtlc]) extends Upstream {
val amountIn: MilliSatoshi = adds.map(_.add.amountMsat).sum
val expiryIn: CltvExpiry = adds.map(_.add.cltvExpiry).min
}

case class ReceivedHtlc(add: UpdateAddHtlc, receivedAt: TimestampMilli)
}
// @formatter:on

/**
Expand All @@ -298,12 +285,12 @@ object OutgoingPaymentPacket {
}

/** Build the command to add an HTLC for the given recipient using the provided route. */
def buildOutgoingPayment(replyTo: ActorRef, upstream: Upstream, paymentHash: ByteVector32, route: Route, recipient: Recipient): Either[OutgoingPaymentError, OutgoingPaymentPacket] = {
def buildOutgoingPayment(origin: Origin.Hot, paymentHash: ByteVector32, route: Route, recipient: Recipient): Either[OutgoingPaymentError, OutgoingPaymentPacket] = {
for {
payment <- recipient.buildPayloads(paymentHash, route)
onion <- buildOnion(payment.payloads, paymentHash, Some(PaymentOnionCodecs.paymentOnionPayloadLength)) // BOLT 2 requires that associatedData == paymentHash
} yield {
val cmd = CMD_ADD_HTLC(replyTo, payment.amount, paymentHash, payment.expiry, onion.packet, payment.outerBlinding_opt, Origin.Hot(replyTo, upstream), commit = true)
val cmd = CMD_ADD_HTLC(origin.replyTo, payment.amount, paymentHash, payment.expiry, onion.packet, payment.outerBlinding_opt, origin, commit = true)
OutgoingPaymentPacket(cmd, route.hops.head.shortChannelId, onion.sharedSecrets)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class ChannelRelay private(nodeParams: NodeParams,
private val forwardFailureAdapter = context.messageAdapter[Register.ForwardFailure[CMD_ADD_HTLC]](WrappedForwardFailure)
private val addResponseAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedAddResponse)

private val upstream = Upstream.Hot.Channel(r.add.removeUnknownTlvs(), startedAt)

private case class PreviouslyTried(channelId: ByteVector32, failure: RES_ADD_FAILED[ChannelException])

def relay(previousFailures: Seq[PreviouslyTried]): Behavior[Command] = {
Expand All @@ -136,13 +138,13 @@ class ChannelRelay private(nodeParams: NodeParams,

def waitForAddResponse(selectedChannelId: ByteVector32, previousFailures: Seq[PreviouslyTried]): Behavior[Command] =
Behaviors.receiveMessagePartial {
case WrappedForwardFailure(Register.ForwardFailure(Register.Forward(_, channelId, CMD_ADD_HTLC(_, _, _, _, _, _, o: Origin.ChannelRelayedHot, _)))) =>
context.log.warn(s"couldn't resolve downstream channel $channelId, failing htlc #${o.add.id}")
val cmdFail = CMD_FAIL_HTLC(o.add.id, Right(UnknownNextPeer()), commit = true)
case WrappedForwardFailure(Register.ForwardFailure(Register.Forward(_, channelId, _))) =>
context.log.warn(s"couldn't resolve downstream channel $channelId, failing htlc #${upstream.add.id}")
val cmdFail = CMD_FAIL_HTLC(upstream.add.id, Right(UnknownNextPeer()), commit = true)
Metrics.recordPaymentRelayFailed(Tags.FailureType(cmdFail), Tags.RelayType.Channel)
safeSendAndStop(o.add.channelId, cmdFail)
safeSendAndStop(upstream.add.channelId, cmdFail)

case WrappedAddResponse(addFailed@RES_ADD_FAILED(CMD_ADD_HTLC(_, _, _, _, _, _, _: Origin.ChannelRelayedHot, _), _, _)) =>
case WrappedAddResponse(addFailed: RES_ADD_FAILED[_]) =>
context.log.info("attempt failed with reason={}", addFailed.t.getClass.getSimpleName)
context.self ! DoRelay
relay(previousFailures :+ PreviouslyTried(selectedChannelId, addFailed))
Expand All @@ -154,19 +156,19 @@ class ChannelRelay private(nodeParams: NodeParams,

def waitForAddSettled(): Behavior[Command] =
Behaviors.receiveMessagePartial {
case WrappedAddResponse(RES_ADD_SETTLED(o: Origin.ChannelRelayedHot, htlc, fulfill: HtlcResult.Fulfill)) =>
case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fulfill: HtlcResult.Fulfill)) =>
context.log.debug("relaying fulfill to upstream")
val cmd = CMD_FULFILL_HTLC(o.originHtlcId, fulfill.paymentPreimage, commit = true)
context.system.eventStream ! EventStream.Publish(ChannelPaymentRelayed(o.amountIn, o.amountOut, htlc.paymentHash, o.originChannelId, htlc.channelId, startedAt, TimestampMilli.now()))
val cmd = CMD_FULFILL_HTLC(upstream.add.id, fulfill.paymentPreimage, commit = true)
context.system.eventStream ! EventStream.Publish(ChannelPaymentRelayed(upstream.amountIn, htlc.amountMsat, htlc.paymentHash, upstream.add.channelId, htlc.channelId, startedAt, TimestampMilli.now()))
recordRelayDuration(isSuccess = true)
safeSendAndStop(o.originChannelId, cmd)
safeSendAndStop(upstream.add.channelId, cmd)

case WrappedAddResponse(RES_ADD_SETTLED(o: Origin.ChannelRelayedHot, _, fail: HtlcResult.Fail)) =>
case WrappedAddResponse(RES_ADD_SETTLED(_, _, fail: HtlcResult.Fail)) =>
context.log.debug("relaying fail to upstream")
Metrics.recordPaymentRelayFailed(Tags.FailureType.Remote, Tags.RelayType.Channel)
val cmd = translateRelayFailure(o.originHtlcId, fail)
val cmd = translateRelayFailure(upstream.add.id, fail)
recordRelayDuration(isSuccess = false)
safeSendAndStop(o.originChannelId, cmd)
safeSendAndStop(upstream.add.channelId, cmd)
}

def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = {
Expand Down Expand Up @@ -305,7 +307,7 @@ class ChannelRelay private(nodeParams: NodeParams,
outgoingChannel_opt.flatMap(_.prevChannelUpdate).forall(c => r.relayFeeMsat < nodeFee(c.relayFees, r.amountToForward))) =>
RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(FeeInsufficient(r.add.amountMsat, Some(c.channelUpdate))), commit = true))
case Some(c: OutgoingChannel) =>
val origin = Origin.ChannelRelayedHot(addResponseAdapter.toClassic, r.add, r.amountToForward)
val origin = Origin.Hot(addResponseAdapter.toClassic, upstream)
val nextBlindingKey_opt = r.payload match {
case payload: IntermediatePayload.ChannelRelay.Blinded => Some(payload.nextBlinding)
case _: IntermediatePayload.ChannelRelay.Standard => None
Expand Down
Loading

0 comments on commit 791edf7

Please sign in to comment.