Skip to content

Commit

Permalink
Add incoming peer to Hot.Channel (#2883)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomash-acinq authored Jul 18, 2024
1 parent 86373b4 commit 83d790e
Show file tree
Hide file tree
Showing 16 changed files with 135 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ object Upstream {
sealed trait Hot extends Upstream
object Hot {
/** Our node is forwarding a single incoming HTLC. */
case class Channel(add: UpdateAddHtlc, receivedAt: TimestampMilli) extends Hot {
case class Channel(add: UpdateAddHtlc, receivedAt: TimestampMilli, receivedFrom: PublicKey) extends Hot {
override val amountIn: MilliSatoshi = add.amountMsat
val expiryIn: CltvExpiry = add.cltvExpiry
}
Expand All @@ -158,7 +158,7 @@ object Upstream {
object Cold {
def apply(hot: Hot): Cold = hot match {
case Local(id) => Local(id)
case Hot.Channel(add, _) => Cold.Channel(add.channelId, add.id, add.amountMsat)
case Hot.Channel(add, _, _) => Cold.Channel(add.channelId, add.id, add.amountMsat)
case Hot.Trampoline(received) => Cold.Trampoline(received.map(r => Cold.Channel(r.add.channelId, r.add.id, r.add.amountMsat)).toList)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ class Channel(val nodeParams: NodeParams, val wallet: OnChainChannelFunder with
actions.foreach {
case PostRevocationAction.RelayHtlc(add) =>
log.debug("forwarding incoming htlc {} to relayer", add)
relayer ! Relayer.RelayForward(add)
relayer ! Relayer.RelayForward(add, remoteNodeId)
case PostRevocationAction.RejectHtlc(add) =>
log.debug("rejecting incoming htlc {}", add)
// NB: we don't set commit = true, we will sign all updates at once afterwards.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.channel._
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.db.PendingCommandsDb
Expand Down Expand Up @@ -53,15 +54,16 @@ object ChannelRelay {
case class RelaySuccess(selectedChannelId: ByteVector32, cmdAdd: CMD_ADD_HTLC) extends RelayResult
// @formatter:on

def apply(nodeParams: NodeParams, register: ActorRef, channels: Map[ByteVector32, Relayer.OutgoingChannel], relayId: UUID, r: IncomingPaymentPacket.ChannelRelayPacket): Behavior[Command] =
def apply(nodeParams: NodeParams, register: ActorRef, channels: Map[ByteVector32, Relayer.OutgoingChannel], originNode: PublicKey, relayId: UUID, r: IncomingPaymentPacket.ChannelRelayPacket): Behavior[Command] =
Behaviors.setup { context =>
Behaviors.withMdc(Logs.mdc(
category_opt = Some(Logs.LogCategory.PAYMENT),
parentPaymentId_opt = Some(relayId), // for a channel relay, parent payment id = relay id
paymentHash_opt = Some(r.add.paymentHash),
nodeAlias_opt = Some(nodeParams.alias))) {
val upstream = Upstream.Hot.Channel(r.add.removeUnknownTlvs(), TimestampMilli.now(), originNode)
context.self ! DoRelay
new ChannelRelay(nodeParams, register, channels, r, context).relay(Seq.empty)
new ChannelRelay(nodeParams, register, channels, r, upstream, context).relay(Seq.empty)
}
}

Expand Down Expand Up @@ -104,16 +106,14 @@ class ChannelRelay private(nodeParams: NodeParams,
register: ActorRef,
channels: Map[ByteVector32, Relayer.OutgoingChannel],
r: IncomingPaymentPacket.ChannelRelayPacket,
context: ActorContext[ChannelRelay.Command],
startedAt: TimestampMilli = TimestampMilli.now()) {
upstream: Upstream.Hot.Channel,
context: ActorContext[ChannelRelay.Command]) {

import ChannelRelay._

private val forwardFailureAdapter = context.messageAdapter[Register.ForwardFailure[CMD_ADD_HTLC]](WrappedForwardFailure)
private val addResponseAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedAddResponse)

private val upstream = Upstream.Hot.Channel(r.add.removeUnknownTlvs(), startedAt)

private case class PreviouslyTried(channelId: ByteVector32, failure: RES_ADD_FAILED[ChannelException])

def relay(previousFailures: Seq[PreviouslyTried]): Behavior[Command] = {
Expand Down Expand Up @@ -159,7 +159,7 @@ class ChannelRelay private(nodeParams: NodeParams,
case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fulfill: HtlcResult.Fulfill)) =>
context.log.debug("relaying fulfill to upstream")
val cmd = CMD_FULFILL_HTLC(upstream.add.id, fulfill.paymentPreimage, commit = true)
context.system.eventStream ! EventStream.Publish(ChannelPaymentRelayed(upstream.amountIn, htlc.amountMsat, htlc.paymentHash, upstream.add.channelId, htlc.channelId, startedAt, TimestampMilli.now()))
context.system.eventStream ! EventStream.Publish(ChannelPaymentRelayed(upstream.amountIn, htlc.amountMsat, htlc.paymentHash, upstream.add.channelId, htlc.channelId, upstream.receivedAt, TimestampMilli.now()))
recordRelayDuration(isSuccess = true)
safeSendAndStop(upstream.add.channelId, cmd)

Expand Down Expand Up @@ -320,5 +320,5 @@ class ChannelRelay private(nodeParams: NodeParams,
Metrics.RelayedPaymentDuration
.withTag(Tags.Relay, Tags.RelayType.Channel)
.withTag(Tags.Success, isSuccess)
.record((TimestampMilli.now() - startedAt).toMillis, TimeUnit.MILLISECONDS)
.record((TimestampMilli.now() - upstream.receivedAt).toMillis, TimeUnit.MILLISECONDS)
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ object ChannelRelayer {
// @formatter:off
sealed trait Command
case class GetOutgoingChannels(replyTo: ActorRef, getOutgoingChannels: Relayer.GetOutgoingChannels) extends Command
case class Relay(channelRelayPacket: IncomingPaymentPacket.ChannelRelayPacket) extends Command
case class Relay(channelRelayPacket: IncomingPaymentPacket.ChannelRelayPacket, originNode: PublicKey) extends Command
private[payment] case class WrappedLocalChannelUpdate(localChannelUpdate: LocalChannelUpdate) extends Command
private[payment] case class WrappedLocalChannelDown(localChannelDown: LocalChannelDown) extends Command
private[payment] case class WrappedAvailableBalanceChanged(availableBalanceChanged: AvailableBalanceChanged) extends Command
Expand All @@ -66,10 +66,9 @@ object ChannelRelayer {
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[LocalChannelDown](WrappedLocalChannelDown))
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[AvailableBalanceChanged](WrappedAvailableBalanceChanged))
context.system.eventStream ! EventStream.Publish(SubscriptionsComplete(this.getClass))
context.messageAdapter[IncomingPaymentPacket.ChannelRelayPacket](Relay)
Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT), nodeAlias_opt = Some(nodeParams.alias)), mdc) {
Behaviors.receiveMessage {
case Relay(channelRelayPacket) =>
case Relay(channelRelayPacket, originNode) =>
val relayId = UUID.randomUUID()
val nextNodeId_opt: Option[PublicKey] = scid2channels.get(channelRelayPacket.payload.outgoingChannelId) match {
case Some(channelId) => channels.get(channelId).map(_.nextNodeId)
Expand All @@ -80,7 +79,7 @@ object ChannelRelayer {
case None => Map.empty
}
context.log.debug(s"spawning a new handler with relayId=$relayId to nextNodeId={} with channels={}", nextNodeId_opt.getOrElse(""), nextChannels.keys.mkString(","))
context.spawn(ChannelRelay.apply(nodeParams, register, nextChannels, relayId, channelRelayPacket), name = relayId.toString)
context.spawn(ChannelRelay.apply(nodeParams, register, nextChannels, originNode, relayId, channelRelayPacket), name = relayId.toString)
Behaviors.same

case GetOutgoingChannels(replyTo, Relayer.GetOutgoingChannels(enabledOnly)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.{ActorRef, typed}
import com.softwaremill.quicklens.ModifyPimp
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream}
import fr.acinq.eclair.db.PendingCommandsDb
import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket
Expand Down Expand Up @@ -53,7 +54,7 @@ object NodeRelay {

// @formatter:off
sealed trait Command
case class Relay(nodeRelayPacket: IncomingPaymentPacket.NodeRelayPacket) extends Command
case class Relay(nodeRelayPacket: IncomingPaymentPacket.NodeRelayPacket, originNode: PublicKey) extends Command
case object Stop extends Command
private case class WrappedMultiPartExtraPaymentReceived(mppExtraReceived: MultiPartPaymentFSM.ExtraPaymentReceived[HtlcPart]) extends Command
private case class WrappedMultiPartPaymentFailed(mppFailed: MultiPartPaymentFSM.MultiPartPaymentFailed) extends Command
Expand Down Expand Up @@ -203,11 +204,11 @@ class NodeRelay private(nodeParams: NodeParams,
*/
private def receiving(htlcs: Queue[Upstream.Hot.Channel], nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket], handler: ActorRef): Behavior[Command] =
Behaviors.receiveMessagePartial {
case Relay(packet: IncomingPaymentPacket.NodeRelayPacket) =>
case Relay(packet: IncomingPaymentPacket.NodeRelayPacket, originNode) =>
require(packet.outerPayload.paymentSecret == paymentSecret, "payment secret mismatch")
context.log.debug("forwarding incoming htlc #{} from channel {} to the payment FSM", packet.add.id, packet.add.channelId)
handler ! MultiPartPaymentFSM.HtlcPart(packet.outerPayload.totalAmount, packet.add)
receiving(htlcs :+ Upstream.Hot.Channel(packet.add.removeUnknownTlvs(), TimestampMilli.now()), nextPayload, nextPacket_opt, handler)
receiving(htlcs :+ Upstream.Hot.Channel(packet.add.removeUnknownTlvs(), TimestampMilli.now(), originNode), nextPayload, nextPacket_opt, handler)
case WrappedMultiPartPaymentFailed(MultiPartPaymentFSM.MultiPartPaymentFailed(_, failure, parts)) =>
context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.amount).sum, failure)
Metrics.recordPaymentRelayFailed(failure.getClass.getSimpleName, Tags.RelayType.Trampoline)
Expand Down Expand Up @@ -384,7 +385,7 @@ class NodeRelay private(nodeParams: NodeParams,
}

private def rejectExtraHtlcPartialFunction: PartialFunction[Command, Behavior[Command]] = {
case Relay(nodeRelayPacket) =>
case Relay(nodeRelayPacket, _) =>
rejectExtraHtlc(nodeRelayPacket.add)
Behaviors.same
// NB: this message would be sent from the payment FSM which we stopped before going to this state, but all this is asynchronous.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import akka.actor.typed
import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.{ActorRef, Behavior}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.payment._
import fr.acinq.eclair.{Logs, NodeParams}

Expand All @@ -38,7 +39,7 @@ object NodeRelayer {

// @formatter:off
sealed trait Command
case class Relay(nodeRelayPacket: IncomingPaymentPacket.NodeRelayPacket) extends Command
case class Relay(nodeRelayPacket: IncomingPaymentPacket.NodeRelayPacket, originNode: PublicKey) extends Command
case class RelayComplete(childHandler: ActorRef[NodeRelay.Command], paymentHash: ByteVector32, paymentSecret: ByteVector32) extends Command
private[relay] case class GetPendingPayments(replyTo: akka.actor.ActorRef) extends Command
// @formatter:on
Expand All @@ -61,20 +62,20 @@ object NodeRelayer {
Behaviors.setup { context =>
Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT)), mdc) {
Behaviors.receiveMessage {
case Relay(nodeRelayPacket) =>
case Relay(nodeRelayPacket, originNode) =>
val htlcIn = nodeRelayPacket.add
val childKey = PaymentKey(htlcIn.paymentHash, nodeRelayPacket.outerPayload.paymentSecret)
children.get(childKey) match {
case Some(handler) =>
context.log.debug("forwarding incoming htlc #{} from channel {} to existing handler", htlcIn.id, htlcIn.channelId)
handler ! NodeRelay.Relay(nodeRelayPacket)
handler ! NodeRelay.Relay(nodeRelayPacket, originNode)
Behaviors.same
case None =>
val relayId = UUID.randomUUID()
context.log.debug(s"spawning a new handler with relayId=$relayId")
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, outgoingPaymentFactory, triggerer, router), relayId.toString)
context.log.debug("forwarding incoming htlc #{} from channel {} to new handler", htlcIn.id, htlcIn.channelId)
handler ! NodeRelay.Relay(nodeRelayPacket)
handler ! NodeRelay.Relay(nodeRelayPacket, originNode)
apply(nodeParams, register, outgoingPaymentFactory, triggerer, router, children + (childKey -> handler))
}
case RelayComplete(childHandler, paymentHash, paymentSecret) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym

def receive: Receive = {
case init: PostRestartHtlcCleaner.Init => postRestartCleaner forward init
case RelayForward(add) =>
case RelayForward(add, originNode) =>
log.debug(s"received forwarding request for htlc #${add.id} from channelId=${add.channelId}")
IncomingPaymentPacket.decrypt(add, nodeParams.privateKey, nodeParams.features) match {
case Right(p: IncomingPaymentPacket.FinalPacket) =>
log.debug(s"forwarding htlc #${add.id} to payment-handler")
paymentHandler forward p
case Right(r: IncomingPaymentPacket.ChannelRelayPacket) =>
channelRelayer ! ChannelRelayer.Relay(r)
channelRelayer ! ChannelRelayer.Relay(r, originNode)
case Right(r: IncomingPaymentPacket.NodeRelayPacket) =>
if (!nodeParams.enableTrampolinePayment) {
log.warning(s"rejecting htlc #${add.id} from channelId=${add.channelId} reason=trampoline disabled")
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, CMD_FAIL_HTLC(add.id, Right(RequiredNodeFeatureMissing()), commit = true))
} else {
nodeRelayer ! NodeRelayer.Relay(r)
nodeRelayer ! NodeRelayer.Relay(r, originNode)
}
case Left(badOnion: BadOnion) =>
log.warning(s"couldn't parse onion: reason=${badOnion.message}")
Expand Down Expand Up @@ -108,7 +108,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym

override def mdc(currentMessage: Any): MDC = {
val paymentHash_opt = currentMessage match {
case RelayForward(add) => Some(add.paymentHash)
case RelayForward(add, _) => Some(add.paymentHash)
case addFailed: RES_ADD_FAILED[_] => Some(addFailed.c.paymentHash)
case addCompleted: RES_ADD_SETTLED[_, _] => Some(addCompleted.htlc.paymentHash)
case _ => None
Expand Down Expand Up @@ -145,7 +145,7 @@ object Relayer extends Logging {
}
}

case class RelayForward(add: UpdateAddHtlc)
case class RelayForward(add: UpdateAddHtlc, originNode: PublicKey)
case class ChannelBalance(remoteNodeId: PublicKey, shortIds: ShortIds, canSend: MilliSatoshi, canReceive: MilliSatoshi, isPublic: Boolean, isEnabled: Boolean)

sealed trait OutgoingChannelParams {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class CheckBalanceSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with
val (ra2, htlca2) = addHtlc(100000000 msat, alice, bob, alice2bob, bob2alice)
val (_, htlca3) = addHtlc(10000 msat, alice, bob, alice2bob, bob2alice)
// for this one we set a non-local upstream to simulate a relayed payment
val (_, htlca4) = addHtlc(30000000 msat, CltvExpiryDelta(144), alice, bob, alice2bob, bob2alice, upstream = Upstream.Hot.Trampoline(Upstream.Hot.Channel(UpdateAddHtlc(randomBytes32(), 42, 30003000 msat, randomBytes32(), CltvExpiry(144), TestConstants.emptyOnionPacket, TlvStream.empty[UpdateAddHtlcTlv]), TimestampMilli(1687345927000L)) :: Nil), replyTo = TestProbe().ref)
val (_, htlca4) = addHtlc(30000000 msat, CltvExpiryDelta(144), alice, bob, alice2bob, bob2alice, upstream = Upstream.Hot.Trampoline(Upstream.Hot.Channel(UpdateAddHtlc(randomBytes32(), 42, 30003000 msat, randomBytes32(), CltvExpiry(144), TestConstants.emptyOnionPacket, TlvStream.empty[UpdateAddHtlcTlv]), TimestampMilli(1687345927000L), TestConstants.Alice.nodeParams.nodeId) :: Nil), replyTo = TestProbe().ref)
val (rb1, htlcb1) = addHtlc(50000000 msat, bob, alice, bob2alice, alice2bob)
val (_, _) = addHtlc(55000000 msat, bob, alice, bob2alice, alice2bob)
crossSign(alice, bob, alice2bob, bob2alice)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ class NormalQuiescentStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteL
import f._
val (preimage, add) = addHtlc(50_000_000 msat, bob, alice, bob2alice, alice2bob)
crossSign(bob, alice, bob2alice, alice2bob)
alice2relayer.expectMsg(RelayForward(add))
alice2relayer.expectMsg(RelayForward(add, TestConstants.Bob.nodeParams.nodeId))
initiateQuiescence(f, sendInitialStfu = true)
val forbiddenMsg = UpdateFulfillHtlc(channelId(bob), add.id, preimage)
// both parties will respond to a forbidden msg while quiescent with a warning (and disconnect)
Expand Down
Loading

0 comments on commit 83d790e

Please sign in to comment.