diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala index c39bb1b226..a416d73c8d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala @@ -140,7 +140,7 @@ object Upstream { sealed trait Hot extends Upstream object Hot { /** Our node is forwarding a single incoming HTLC. */ - case class Channel(add: UpdateAddHtlc, receivedAt: TimestampMilli) extends Hot { + case class Channel(add: UpdateAddHtlc, receivedAt: TimestampMilli, receivedFrom: PublicKey) extends Hot { override val amountIn: MilliSatoshi = add.amountMsat val expiryIn: CltvExpiry = add.cltvExpiry } @@ -158,7 +158,7 @@ object Upstream { object Cold { def apply(hot: Hot): Cold = hot match { case Local(id) => Local(id) - case Hot.Channel(add, _) => Cold.Channel(add.channelId, add.id, add.amountMsat) + case Hot.Channel(add, _, _) => Cold.Channel(add.channelId, add.id, add.amountMsat) case Hot.Trampoline(received) => Cold.Trampoline(received.map(r => Cold.Channel(r.add.channelId, r.add.id, r.add.amountMsat)).toList) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala index b9c39913c5..e55d86d612 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala @@ -637,7 +637,7 @@ class Channel(val nodeParams: NodeParams, val wallet: OnChainChannelFunder with actions.foreach { case PostRevocationAction.RelayHtlc(add) => log.debug("forwarding incoming htlc {} to relayer", add) - relayer ! Relayer.RelayForward(add) + relayer ! Relayer.RelayForward(add, remoteNodeId) case PostRevocationAction.RejectHtlc(add) => log.debug("rejecting incoming htlc {}", add) // NB: we don't set commit = true, we will sign all updates at once afterwards. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index 75891e2179..bb6b4f3b58 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -22,6 +22,7 @@ import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import fr.acinq.bitcoin.scalacompat.ByteVector32 +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.PendingCommandsDb @@ -53,15 +54,16 @@ object ChannelRelay { case class RelaySuccess(selectedChannelId: ByteVector32, cmdAdd: CMD_ADD_HTLC) extends RelayResult // @formatter:on - def apply(nodeParams: NodeParams, register: ActorRef, channels: Map[ByteVector32, Relayer.OutgoingChannel], relayId: UUID, r: IncomingPaymentPacket.ChannelRelayPacket): Behavior[Command] = + def apply(nodeParams: NodeParams, register: ActorRef, channels: Map[ByteVector32, Relayer.OutgoingChannel], originNode: PublicKey, relayId: UUID, r: IncomingPaymentPacket.ChannelRelayPacket): Behavior[Command] = Behaviors.setup { context => Behaviors.withMdc(Logs.mdc( category_opt = Some(Logs.LogCategory.PAYMENT), parentPaymentId_opt = Some(relayId), // for a channel relay, parent payment id = relay id paymentHash_opt = Some(r.add.paymentHash), nodeAlias_opt = Some(nodeParams.alias))) { + val upstream = Upstream.Hot.Channel(r.add.removeUnknownTlvs(), TimestampMilli.now(), originNode) context.self ! DoRelay - new ChannelRelay(nodeParams, register, channels, r, context).relay(Seq.empty) + new ChannelRelay(nodeParams, register, channels, r, upstream, context).relay(Seq.empty) } } @@ -104,16 +106,14 @@ class ChannelRelay private(nodeParams: NodeParams, register: ActorRef, channels: Map[ByteVector32, Relayer.OutgoingChannel], r: IncomingPaymentPacket.ChannelRelayPacket, - context: ActorContext[ChannelRelay.Command], - startedAt: TimestampMilli = TimestampMilli.now()) { + upstream: Upstream.Hot.Channel, + context: ActorContext[ChannelRelay.Command]) { import ChannelRelay._ private val forwardFailureAdapter = context.messageAdapter[Register.ForwardFailure[CMD_ADD_HTLC]](WrappedForwardFailure) private val addResponseAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedAddResponse) - private val upstream = Upstream.Hot.Channel(r.add.removeUnknownTlvs(), startedAt) - private case class PreviouslyTried(channelId: ByteVector32, failure: RES_ADD_FAILED[ChannelException]) def relay(previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { @@ -159,7 +159,7 @@ class ChannelRelay private(nodeParams: NodeParams, case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fulfill: HtlcResult.Fulfill)) => context.log.debug("relaying fulfill to upstream") val cmd = CMD_FULFILL_HTLC(upstream.add.id, fulfill.paymentPreimage, commit = true) - context.system.eventStream ! EventStream.Publish(ChannelPaymentRelayed(upstream.amountIn, htlc.amountMsat, htlc.paymentHash, upstream.add.channelId, htlc.channelId, startedAt, TimestampMilli.now())) + context.system.eventStream ! EventStream.Publish(ChannelPaymentRelayed(upstream.amountIn, htlc.amountMsat, htlc.paymentHash, upstream.add.channelId, htlc.channelId, upstream.receivedAt, TimestampMilli.now())) recordRelayDuration(isSuccess = true) safeSendAndStop(upstream.add.channelId, cmd) @@ -320,5 +320,5 @@ class ChannelRelay private(nodeParams: NodeParams, Metrics.RelayedPaymentDuration .withTag(Tags.Relay, Tags.RelayType.Channel) .withTag(Tags.Success, isSuccess) - .record((TimestampMilli.now() - startedAt).toMillis, TimeUnit.MILLISECONDS) + .record((TimestampMilli.now() - upstream.receivedAt).toMillis, TimeUnit.MILLISECONDS) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala index 59eb7b58b0..39d61a22c4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala @@ -42,7 +42,7 @@ object ChannelRelayer { // @formatter:off sealed trait Command case class GetOutgoingChannels(replyTo: ActorRef, getOutgoingChannels: Relayer.GetOutgoingChannels) extends Command - case class Relay(channelRelayPacket: IncomingPaymentPacket.ChannelRelayPacket) extends Command + case class Relay(channelRelayPacket: IncomingPaymentPacket.ChannelRelayPacket, originNode: PublicKey) extends Command private[payment] case class WrappedLocalChannelUpdate(localChannelUpdate: LocalChannelUpdate) extends Command private[payment] case class WrappedLocalChannelDown(localChannelDown: LocalChannelDown) extends Command private[payment] case class WrappedAvailableBalanceChanged(availableBalanceChanged: AvailableBalanceChanged) extends Command @@ -66,10 +66,9 @@ object ChannelRelayer { context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[LocalChannelDown](WrappedLocalChannelDown)) context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[AvailableBalanceChanged](WrappedAvailableBalanceChanged)) context.system.eventStream ! EventStream.Publish(SubscriptionsComplete(this.getClass)) - context.messageAdapter[IncomingPaymentPacket.ChannelRelayPacket](Relay) Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT), nodeAlias_opt = Some(nodeParams.alias)), mdc) { Behaviors.receiveMessage { - case Relay(channelRelayPacket) => + case Relay(channelRelayPacket, originNode) => val relayId = UUID.randomUUID() val nextNodeId_opt: Option[PublicKey] = scid2channels.get(channelRelayPacket.payload.outgoingChannelId) match { case Some(channelId) => channels.get(channelId).map(_.nextNodeId) @@ -80,7 +79,7 @@ object ChannelRelayer { case None => Map.empty } context.log.debug(s"spawning a new handler with relayId=$relayId to nextNodeId={} with channels={}", nextNodeId_opt.getOrElse(""), nextChannels.keys.mkString(",")) - context.spawn(ChannelRelay.apply(nodeParams, register, nextChannels, relayId, channelRelayPacket), name = relayId.toString) + context.spawn(ChannelRelay.apply(nodeParams, register, nextChannels, originNode, relayId, channelRelayPacket), name = relayId.toString) Behaviors.same case GetOutgoingChannels(replyTo, Relayer.GetOutgoingChannels(enabledOnly)) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index da9f1a17de..4f625e9fe7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -23,6 +23,7 @@ import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import akka.actor.{ActorRef, typed} import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.ByteVector32 +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream} import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket @@ -53,7 +54,7 @@ object NodeRelay { // @formatter:off sealed trait Command - case class Relay(nodeRelayPacket: IncomingPaymentPacket.NodeRelayPacket) extends Command + case class Relay(nodeRelayPacket: IncomingPaymentPacket.NodeRelayPacket, originNode: PublicKey) extends Command case object Stop extends Command private case class WrappedMultiPartExtraPaymentReceived(mppExtraReceived: MultiPartPaymentFSM.ExtraPaymentReceived[HtlcPart]) extends Command private case class WrappedMultiPartPaymentFailed(mppFailed: MultiPartPaymentFSM.MultiPartPaymentFailed) extends Command @@ -203,11 +204,11 @@ class NodeRelay private(nodeParams: NodeParams, */ private def receiving(htlcs: Queue[Upstream.Hot.Channel], nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket], handler: ActorRef): Behavior[Command] = Behaviors.receiveMessagePartial { - case Relay(packet: IncomingPaymentPacket.NodeRelayPacket) => + case Relay(packet: IncomingPaymentPacket.NodeRelayPacket, originNode) => require(packet.outerPayload.paymentSecret == paymentSecret, "payment secret mismatch") context.log.debug("forwarding incoming htlc #{} from channel {} to the payment FSM", packet.add.id, packet.add.channelId) handler ! MultiPartPaymentFSM.HtlcPart(packet.outerPayload.totalAmount, packet.add) - receiving(htlcs :+ Upstream.Hot.Channel(packet.add.removeUnknownTlvs(), TimestampMilli.now()), nextPayload, nextPacket_opt, handler) + receiving(htlcs :+ Upstream.Hot.Channel(packet.add.removeUnknownTlvs(), TimestampMilli.now(), originNode), nextPayload, nextPacket_opt, handler) case WrappedMultiPartPaymentFailed(MultiPartPaymentFSM.MultiPartPaymentFailed(_, failure, parts)) => context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.amount).sum, failure) Metrics.recordPaymentRelayFailed(failure.getClass.getSimpleName, Tags.RelayType.Trampoline) @@ -384,7 +385,7 @@ class NodeRelay private(nodeParams: NodeParams, } private def rejectExtraHtlcPartialFunction: PartialFunction[Command, Behavior[Command]] = { - case Relay(nodeRelayPacket) => + case Relay(nodeRelayPacket, _) => rejectExtraHtlc(nodeRelayPacket.add) Behaviors.same // NB: this message would be sent from the payment FSM which we stopped before going to this state, but all this is asynchronous. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala index d74ef9e0c5..20d65b1991 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala @@ -20,6 +20,7 @@ import akka.actor.typed import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.{ActorRef, Behavior} import fr.acinq.bitcoin.scalacompat.ByteVector32 +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.payment._ import fr.acinq.eclair.{Logs, NodeParams} @@ -38,7 +39,7 @@ object NodeRelayer { // @formatter:off sealed trait Command - case class Relay(nodeRelayPacket: IncomingPaymentPacket.NodeRelayPacket) extends Command + case class Relay(nodeRelayPacket: IncomingPaymentPacket.NodeRelayPacket, originNode: PublicKey) extends Command case class RelayComplete(childHandler: ActorRef[NodeRelay.Command], paymentHash: ByteVector32, paymentSecret: ByteVector32) extends Command private[relay] case class GetPendingPayments(replyTo: akka.actor.ActorRef) extends Command // @formatter:on @@ -61,20 +62,20 @@ object NodeRelayer { Behaviors.setup { context => Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT)), mdc) { Behaviors.receiveMessage { - case Relay(nodeRelayPacket) => + case Relay(nodeRelayPacket, originNode) => val htlcIn = nodeRelayPacket.add val childKey = PaymentKey(htlcIn.paymentHash, nodeRelayPacket.outerPayload.paymentSecret) children.get(childKey) match { case Some(handler) => context.log.debug("forwarding incoming htlc #{} from channel {} to existing handler", htlcIn.id, htlcIn.channelId) - handler ! NodeRelay.Relay(nodeRelayPacket) + handler ! NodeRelay.Relay(nodeRelayPacket, originNode) Behaviors.same case None => val relayId = UUID.randomUUID() context.log.debug(s"spawning a new handler with relayId=$relayId") val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, outgoingPaymentFactory, triggerer, router), relayId.toString) context.log.debug("forwarding incoming htlc #{} from channel {} to new handler", htlcIn.id, htlcIn.channelId) - handler ! NodeRelay.Relay(nodeRelayPacket) + handler ! NodeRelay.Relay(nodeRelayPacket, originNode) apply(nodeParams, register, outgoingPaymentFactory, triggerer, router, children + (childKey -> handler)) } case RelayComplete(childHandler, paymentHash, paymentSecret) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala index dd20d82397..f9f5c0039b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala @@ -62,20 +62,20 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym def receive: Receive = { case init: PostRestartHtlcCleaner.Init => postRestartCleaner forward init - case RelayForward(add) => + case RelayForward(add, originNode) => log.debug(s"received forwarding request for htlc #${add.id} from channelId=${add.channelId}") IncomingPaymentPacket.decrypt(add, nodeParams.privateKey, nodeParams.features) match { case Right(p: IncomingPaymentPacket.FinalPacket) => log.debug(s"forwarding htlc #${add.id} to payment-handler") paymentHandler forward p case Right(r: IncomingPaymentPacket.ChannelRelayPacket) => - channelRelayer ! ChannelRelayer.Relay(r) + channelRelayer ! ChannelRelayer.Relay(r, originNode) case Right(r: IncomingPaymentPacket.NodeRelayPacket) => if (!nodeParams.enableTrampolinePayment) { log.warning(s"rejecting htlc #${add.id} from channelId=${add.channelId} reason=trampoline disabled") PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, CMD_FAIL_HTLC(add.id, Right(RequiredNodeFeatureMissing()), commit = true)) } else { - nodeRelayer ! NodeRelayer.Relay(r) + nodeRelayer ! NodeRelayer.Relay(r, originNode) } case Left(badOnion: BadOnion) => log.warning(s"couldn't parse onion: reason=${badOnion.message}") @@ -108,7 +108,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym override def mdc(currentMessage: Any): MDC = { val paymentHash_opt = currentMessage match { - case RelayForward(add) => Some(add.paymentHash) + case RelayForward(add, _) => Some(add.paymentHash) case addFailed: RES_ADD_FAILED[_] => Some(addFailed.c.paymentHash) case addCompleted: RES_ADD_SETTLED[_, _] => Some(addCompleted.htlc.paymentHash) case _ => None @@ -145,7 +145,7 @@ object Relayer extends Logging { } } - case class RelayForward(add: UpdateAddHtlc) + case class RelayForward(add: UpdateAddHtlc, originNode: PublicKey) case class ChannelBalance(remoteNodeId: PublicKey, shortIds: ShortIds, canSend: MilliSatoshi, canReceive: MilliSatoshi, isPublic: Boolean, isEnabled: Boolean) sealed trait OutgoingChannelParams { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/balance/CheckBalanceSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/balance/CheckBalanceSpec.scala index f102dca530..a5e4d68c73 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/balance/CheckBalanceSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/balance/CheckBalanceSpec.scala @@ -172,7 +172,7 @@ class CheckBalanceSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val (ra2, htlca2) = addHtlc(100000000 msat, alice, bob, alice2bob, bob2alice) val (_, htlca3) = addHtlc(10000 msat, alice, bob, alice2bob, bob2alice) // for this one we set a non-local upstream to simulate a relayed payment - val (_, htlca4) = addHtlc(30000000 msat, CltvExpiryDelta(144), alice, bob, alice2bob, bob2alice, upstream = Upstream.Hot.Trampoline(Upstream.Hot.Channel(UpdateAddHtlc(randomBytes32(), 42, 30003000 msat, randomBytes32(), CltvExpiry(144), TestConstants.emptyOnionPacket, TlvStream.empty[UpdateAddHtlcTlv]), TimestampMilli(1687345927000L)) :: Nil), replyTo = TestProbe().ref) + val (_, htlca4) = addHtlc(30000000 msat, CltvExpiryDelta(144), alice, bob, alice2bob, bob2alice, upstream = Upstream.Hot.Trampoline(Upstream.Hot.Channel(UpdateAddHtlc(randomBytes32(), 42, 30003000 msat, randomBytes32(), CltvExpiry(144), TestConstants.emptyOnionPacket, TlvStream.empty[UpdateAddHtlcTlv]), TimestampMilli(1687345927000L), TestConstants.Alice.nodeParams.nodeId) :: Nil), replyTo = TestProbe().ref) val (rb1, htlcb1) = addHtlc(50000000 msat, bob, alice, bob2alice, alice2bob) val (_, _) = addHtlc(55000000 msat, bob, alice, bob2alice, alice2bob) crossSign(alice, bob, alice2bob, bob2alice) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalQuiescentStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalQuiescentStateSpec.scala index d39afd2ecb..fdf09f68ae 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalQuiescentStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalQuiescentStateSpec.scala @@ -311,7 +311,7 @@ class NormalQuiescentStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteL import f._ val (preimage, add) = addHtlc(50_000_000 msat, bob, alice, bob2alice, alice2bob) crossSign(bob, alice, bob2alice, alice2bob) - alice2relayer.expectMsg(RelayForward(add)) + alice2relayer.expectMsg(RelayForward(add, TestConstants.Bob.nodeParams.nodeId)) initiateQuiescence(f, sendInitialStfu = true) val forbiddenMsg = UpdateFulfillHtlc(channelId(bob), add.id, preimage) // both parties will respond to a forbidden msg while quiescent with a warning (and disconnect) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala index 4e6e6ad446..5c91d3fbef 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala @@ -116,7 +116,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val sender = TestProbe() val h = randomBytes32() val originHtlc = UpdateAddHtlc(channelId = randomBytes32(), id = 5656, amountMsat = 50000000 msat, cltvExpiry = CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), paymentHash = h, onionRoutingPacket = TestConstants.emptyOnionPacket, blinding_opt = None) - val origin = Origin.Hot(sender.ref, Upstream.Hot.Channel(originHtlc, TimestampMilli.now())) + val origin = Origin.Hot(sender.ref, Upstream.Hot.Channel(originHtlc, TimestampMilli.now(), randomKey().publicKey)) val cmd = CMD_ADD_HTLC(sender.ref, originHtlc.amountMsat - 10_000.msat, h, originHtlc.cltvExpiry - CltvExpiryDelta(7), TestConstants.emptyOnionPacket, None, origin) alice ! cmd sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]] @@ -135,7 +135,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val h = randomBytes32() val originHtlc1 = UpdateAddHtlc(randomBytes32(), 47, 30000000 msat, h, CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket, None) val originHtlc2 = UpdateAddHtlc(randomBytes32(), 32, 20000000 msat, h, CltvExpiryDelta(160).toCltvExpiry(currentBlockHeight), TestConstants.emptyOnionPacket, None) - val origin = Origin.Hot(sender.ref, Upstream.Hot.Trampoline(Seq(originHtlc1, originHtlc2).map(htlc => Upstream.Hot.Channel(htlc, TimestampMilli.now())))) + val origin = Origin.Hot(sender.ref, Upstream.Hot.Trampoline(Seq(originHtlc1, originHtlc2).map(htlc => Upstream.Hot.Channel(htlc, TimestampMilli.now(), randomKey().publicKey)))) val cmd = CMD_ADD_HTLC(sender.ref, originHtlc1.amountMsat + originHtlc2.amountMsat - 10000.msat, h, originHtlc2.cltvExpiry - CltvExpiryDelta(7), TestConstants.emptyOnionPacket, None, origin) alice ! cmd sender.expectMsgType[RES_SUCCESS[CMD_ADD_HTLC]] @@ -1407,9 +1407,9 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with // Alice forwards HTLCs that fit in the dust exposure. alice2relayer.expectMsgAllOf( - RelayForward(nonDust), - RelayForward(almostTrimmed), - RelayForward(trimmed2), + RelayForward(nonDust, TestConstants.Bob.nodeParams.nodeId), + RelayForward(almostTrimmed, TestConstants.Bob.nodeParams.nodeId), + RelayForward(trimmed2, TestConstants.Bob.nodeParams.nodeId), ) alice2relayer.expectNoMessage(100 millis) // And instantly fails the others. @@ -1454,7 +1454,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with bob2alice.forward(alice) // Alice forwards HTLCs that fit in the dust exposure and instantly fails the others. - alice2relayer.expectMsg(RelayForward(acceptedHtlc)) + alice2relayer.expectMsg(RelayForward(acceptedHtlc, TestConstants.Bob.nodeParams.nodeId)) alice2relayer.expectNoMessage(100 millis) assert(alice2bob.expectMsgType[UpdateFailHtlc].id == rejectedHtlc.id) alice2bob.expectMsgType[CommitSig] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index a010fc1026..06d43507a0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -315,7 +315,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards the trampoline payment to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L))))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) assert(payment_e.cmd.amount == amount_cd) assert(payment_e.cmd.cltvExpiry == expiry_cd) @@ -367,7 +367,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards the trampoline payment to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, inner_c.paymentSecret.get, invoice.extraEdges, inner_c.paymentMetadata) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L))))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) assert(payment_e.cmd.amount == amount_cd) assert(payment_e.cmd.cltvExpiry == expiry_cd) @@ -408,7 +408,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards the trampoline payment to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, inner_c.paymentSecret.get, invoice.extraEdges, inner_c.paymentMetadata) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L))))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) val Right(ChannelRelayPacket(add_d2, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -467,7 +467,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid trampoline onion to e through d. val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e.copy(payload = trampolinePacket_e.payload.reverse))) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L))))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) assert(payment_e.outgoingChannel == channelUpdate_cd.shortChannelId) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) val Right(ChannelRelayPacket(_, _, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -617,7 +617,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid amount to e through (the outer total amount doesn't match the inner amount). val invalidTotalAmount = inner_c.amountToForward - 1.msat val recipient_e = ClearRecipient(e, Features.empty, invalidTotalAmount, inner_c.outgoingCltv, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L))))), paymentHash, Route(invalidTotalAmount, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(invalidTotalAmount, afterTrampolineChannelHops, None), recipient_e) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) val Right(ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) @@ -633,7 +633,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // c forwards an invalid amount to e through (the outer expiry doesn't match the inner expiry). val invalidExpiry = inner_c.outgoingCltv - CltvExpiryDelta(12) val recipient_e = ClearRecipient(e, Features.empty, inner_c.amountToForward, invalidExpiry, randomBytes32(), nextTrampolineOnion_opt = Some(trampolinePacket_e)) - val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L))))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) + val Right(payment_e) = buildOutgoingPayment(Origin.Hot(ActorRef.noSender, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_c, TimestampMilli(1687345927000L), b)))), paymentHash, Route(inner_c.amountToForward, afterTrampolineChannelHops, None), recipient_e) val add_d = UpdateAddHtlc(randomBytes32(), 3, payment_e.cmd.amount, paymentHash, payment_e.cmd.cltvExpiry, payment_e.cmd.onion, None) val Right(ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features.empty) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index ea290454a1..a2f5d9a901 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -337,9 +337,9 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit val htlc_upstream_1 = Seq(buildHtlcIn(0, channelId_ab_1, paymentHash1), buildHtlcIn(5, channelId_ab_1, paymentHash2)) val htlc_upstream_2 = Seq(buildHtlcIn(7, channelId_ab_2, paymentHash1), buildHtlcIn(9, channelId_ab_2, paymentHash2)) val htlc_upstream_3 = Seq(buildHtlcIn(11, randomBytes32(), paymentHash3)) - val upstream_1 = Upstream.Hot.Trampoline(Upstream.Hot.Channel(htlc_upstream_1.head.add, TimestampMilli(1687345927000L)) :: Upstream.Hot.Channel(htlc_upstream_2.head.add, TimestampMilli(1687345967000L)) :: Nil) - val upstream_2 = Upstream.Hot.Trampoline(Upstream.Hot.Channel(htlc_upstream_1(1).add, TimestampMilli(1687345902000L)) :: Upstream.Hot.Channel(htlc_upstream_2(1).add, TimestampMilli(1687345999000L)) :: Nil) - val upstream_3 = Upstream.Hot.Trampoline(Upstream.Hot.Channel(htlc_upstream_3.head.add, TimestampMilli(1687345980000L)) :: Nil) + val upstream_1 = Upstream.Hot.Trampoline(Upstream.Hot.Channel(htlc_upstream_1.head.add, TimestampMilli(1687345927000L), a) :: Upstream.Hot.Channel(htlc_upstream_2.head.add, TimestampMilli(1687345967000L), a) :: Nil) + val upstream_2 = Upstream.Hot.Trampoline(Upstream.Hot.Channel(htlc_upstream_1(1).add, TimestampMilli(1687345902000L), a) :: Upstream.Hot.Channel(htlc_upstream_2(1).add, TimestampMilli(1687345999000L), a) :: Nil) + val upstream_3 = Upstream.Hot.Trampoline(Upstream.Hot.Channel(htlc_upstream_3.head.add, TimestampMilli(1687345980000L), a) :: Nil) val data_upstream_1 = ChannelCodecsSpec.makeChannelDataNormal(htlc_upstream_1, Map.empty) val data_upstream_2 = ChannelCodecsSpec.makeChannelDataNormal(htlc_upstream_2, Map.empty) val data_upstream_3 = ChannelCodecsSpec.makeChannelDataNormal(htlc_upstream_3, Map.empty) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala index 78693075cc..539f394ba0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala @@ -85,7 +85,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val r = createValidIncomingPacket(payload) channelRelayer ! WrappedLocalChannelUpdate(lcu) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) if (success) { expectFwdAdd(register, lcu.channelId, outgoingAmount, outgoingExpiry) @@ -134,7 +134,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val payload1 = ChannelRelay.Standard(realScid1, outgoingAmount, outgoingExpiry) val r1 = createValidIncomingPacket(payload1) channelRelayer ! WrappedLocalChannelUpdate(lcu1) - channelRelayer ! Relay(r1) + channelRelayer ! Relay(r1, TestConstants.Alice.nodeParams.nodeId) expectFwdAdd(register, lcu1.channelId, outgoingAmount, outgoingExpiry) // reorg happens @@ -145,10 +145,10 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a channelRelayer ! WrappedLocalChannelUpdate(lcu2) // both old and new real scids work - channelRelayer ! Relay(r1) + channelRelayer ! Relay(r1, TestConstants.Alice.nodeParams.nodeId) expectFwdAdd(register, lcu1.channelId, outgoingAmount, outgoingExpiry) // new real scid works - channelRelayer ! Relay(r2) + channelRelayer ! Relay(r2, TestConstants.Alice.nodeParams.nodeId) expectFwdAdd(register, lcu2.channelId, outgoingAmount, outgoingExpiry) } @@ -160,7 +160,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry) } @@ -179,7 +179,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val u2 = createLocalUpdate(channelId2, balance = 80_000_000 msat) channelRelayer ! WrappedLocalChannelUpdate(u2) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) // first try val fwd1 = expectFwdAdd(register, channelIds(realScid2), outgoingAmount, outgoingExpiry) @@ -201,7 +201,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val payload = ChannelRelay.Standard(realScid1, outgoingAmount, outgoingExpiry) val r = createValidIncomingPacket(payload) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdFail(register, r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) } @@ -214,7 +214,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val u = createLocalUpdate(channelId1) channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) val fwd = expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry) fwd.replyTo ! Register.ForwardFailure(fwd) @@ -232,7 +232,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a channelRelayer ! WrappedLocalChannelUpdate(u) channelRelayer ! WrappedLocalChannelDown(d) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdFail(register, r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) } @@ -245,7 +245,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val u = createLocalUpdate(channelId1, enabled = false) channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdFail(register, r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(ChannelDisabled(u.channelUpdate.messageFlags, u.channelUpdate.channelFlags, Some(u.channelUpdate))), commit = true)) } @@ -258,7 +258,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val u = createLocalUpdate(channelId1, htlcMinimum = outgoingAmount + 1.msat) channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdFail(register, r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(AmountBelowMinimum(outgoingAmount, Some(u.channelUpdate))), commit = true)) } @@ -272,7 +272,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) val cmd = register.expectMessageType[Register.Forward[channel.Command]] assert(cmd.channelId == r.add.channelId) @@ -300,7 +300,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val r = createValidIncomingPacket(payload, expiryIn = outgoingExpiry + u.channelUpdate.cltvExpiryDelta + CltvExpiryDelta(1)) channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdAdd(register, channelIds(realScid1), r.amountToForward, r.outgoingCltv).message } @@ -313,7 +313,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val r = createValidIncomingPacket(payload, expiryIn = outgoingExpiry + u.channelUpdate.cltvExpiryDelta - CltvExpiryDelta(1)) channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdFail(register, r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(IncorrectCltvExpiry(r.outgoingCltv, Some(u.channelUpdate))), commit = true)) } @@ -326,7 +326,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val u = createLocalUpdate(channelId1) channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdFail(register, r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(FeeInsufficient(r.add.amountMsat, Some(u.channelUpdate))), commit = true)) } @@ -339,7 +339,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val u1 = createLocalUpdate(channelId1, timestamp = TimestampSecond.now(), feeBaseMsat = 1 msat, feeProportionalMillionths = 0) channelRelayer ! WrappedLocalChannelUpdate(u1) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) // relay succeeds with current channel update (u1) with lower fees expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry) @@ -347,7 +347,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a val u2 = createLocalUpdate(channelId1, timestamp = TimestampSecond.now() - 530) channelRelayer ! WrappedLocalChannelUpdate(u2) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) // relay succeeds because the current update (u2) with higher fees occurred less than 10 minutes ago expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry) @@ -356,7 +356,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a channelRelayer ! WrappedLocalChannelUpdate(u1) channelRelayer ! WrappedLocalChannelUpdate(u3) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) // relay fails because the current update (u3) with higher fees occurred more than 10 minutes ago expectFwdFail(register, r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(FeeInsufficient(r.add.amountMsat, Some(u3.channelUpdate))), commit = true)) @@ -385,7 +385,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a testCases.foreach { testCase => channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) val fwd = expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry) fwd.message.replyTo ! RES_ADD_FAILED(fwd.message, testCase.exc, Some(testCase.update)) expectFwdFail(register, r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(testCase.failure), commit = true)) @@ -420,7 +420,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a { val payload = ChannelRelay.Standard(ShortChannelId(12345), 998900 msat, CltvExpiry(60)) val r = createValidIncomingPacket(payload, 1000000 msat, CltvExpiry(70)) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) // select the channel to the same node, with the lowest capacity and balance but still high enough to handle the payment val cmd1 = expectFwdAdd(register, channelUpdates(ShortChannelId(22223)).channelId, r.amountToForward, r.outgoingCltv).message cmd1.replyTo ! RES_ADD_FAILED(cmd1, ChannelUnavailable(randomBytes32()), None) @@ -440,35 +440,35 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a // higher amount payment (have to increased incoming htlc amount for fees to be sufficient) val payload = ChannelRelay.Standard(ShortChannelId(12345), 50000000 msat, CltvExpiry(60)) val r = createValidIncomingPacket(payload, 60000000 msat, CltvExpiry(70)) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdAdd(register, channelUpdates(ShortChannelId(11111)).channelId, r.amountToForward, r.outgoingCltv).message } { // lower amount payment val payload = ChannelRelay.Standard(ShortChannelId(12345), 1000 msat, CltvExpiry(60)) val r = createValidIncomingPacket(payload, 60000000 msat, CltvExpiry(70)) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdAdd(register, channelUpdates(ShortChannelId(33333)).channelId, r.amountToForward, r.outgoingCltv).message } { // payment too high, no suitable channel found, we keep the requested one val payload = ChannelRelay.Standard(ShortChannelId(12345), 1000000000 msat, CltvExpiry(60)) val r = createValidIncomingPacket(payload, 1010000000 msat, CltvExpiry(70)) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdAdd(register, channelUpdates(ShortChannelId(12345)).channelId, r.amountToForward, r.outgoingCltv).message } { // cltv expiry larger than our requirements val payload = ChannelRelay.Standard(ShortChannelId(12345), 998900 msat, CltvExpiry(50)) val r = createValidIncomingPacket(payload, 1000000 msat, CltvExpiry(70)) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdAdd(register, channelUpdates(ShortChannelId(22223)).channelId, r.amountToForward, r.outgoingCltv).message } { // cltv expiry too small, no suitable channel found val payload = ChannelRelay.Standard(ShortChannelId(12345), 998900 msat, CltvExpiry(61)) val r = createValidIncomingPacket(payload, 1000000 msat, CltvExpiry(70)) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) expectFwdFail(register, r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(IncorrectCltvExpiry(CltvExpiry(61), Some(channelUpdates(ShortChannelId(12345)).channelUpdate))), commit = true)) } } @@ -494,7 +494,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a testCases.foreach { testCase => channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) val fwd = expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry) fwd.message.replyTo ! RES_SUCCESS(fwd.message, channelId1) fwd.message.origin.replyTo ! RES_ADD_SETTLED(fwd.message.origin, downstream_htlc, testCase.result) @@ -520,7 +520,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a testCases.foreach { htlcResult => val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) val fwd = expectFwdAdd(register, channelId1, outgoingAmount, outgoingExpiry) fwd.message.replyTo ! RES_SUCCESS(fwd.message, channelId1) fwd.message.origin.replyTo ! RES_ADD_SETTLED(fwd.message.origin, downstream, htlcResult) @@ -563,7 +563,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a testCases.foreach { testCase => channelRelayer ! WrappedLocalChannelUpdate(u) - channelRelayer ! Relay(r) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) val fwd1 = expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry) fwd1.message.replyTo ! RES_SUCCESS(fwd1.message, channelId1) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 79ace35328..ea6672d9fd 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -112,26 +112,26 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val (paymentHash1, paymentSecret1) = (randomBytes32(), randomBytes32()) val payment1 = createPartialIncomingPacket(paymentHash1, paymentSecret1) - parentRelayer ! NodeRelayer.Relay(payment1) + parentRelayer ! NodeRelayer.Relay(payment1, randomKey().publicKey) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) val pending1 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] assert(pending1.keySet == Set(PaymentKey(paymentHash1, paymentSecret1))) val (paymentHash2, paymentSecret2) = (randomBytes32(), randomBytes32()) val payment2 = createPartialIncomingPacket(paymentHash2, paymentSecret2) - parentRelayer ! NodeRelayer.Relay(payment2) + parentRelayer ! NodeRelayer.Relay(payment2, randomKey().publicKey) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) val pending2 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] assert(pending2.keySet == Set(PaymentKey(paymentHash1, paymentSecret1), PaymentKey(paymentHash2, paymentSecret2))) val payment3a = createPartialIncomingPacket(paymentHash1, paymentSecret2) - parentRelayer ! NodeRelayer.Relay(payment3a) + parentRelayer ! NodeRelayer.Relay(payment3a, randomKey().publicKey) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) val pending3 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] assert(pending3.keySet == Set(PaymentKey(paymentHash1, paymentSecret1), PaymentKey(paymentHash2, paymentSecret2), PaymentKey(paymentHash1, paymentSecret2))) val payment3b = createPartialIncomingPacket(paymentHash1, paymentSecret2) - parentRelayer ! NodeRelayer.Relay(payment3b) + parentRelayer ! NodeRelayer.Relay(payment3b, randomKey().publicKey) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) val pending4 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] assert(pending4.keySet == Set(PaymentKey(paymentHash1, paymentSecret1), PaymentKey(paymentHash2, paymentSecret2), PaymentKey(paymentHash1, paymentSecret2))) @@ -180,7 +180,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl } { val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic)) - parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head) + parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head, randomKey().publicKey) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) val pending1 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] assert(pending1.size == 1) @@ -190,7 +190,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(Map.empty) - parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head) + parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head, randomKey().publicKey) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) val pending2 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] assert(pending2.size == 1) @@ -204,7 +204,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val (nodeRelayer, parent) = f.createNodeRelay(incomingMultiPart.head) // Receive a partial upstream multi-part payment. - incomingMultiPart.dropRight(1).foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + incomingMultiPart.dropRight(1).foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) // after a while the payment times out incomingMultiPart.dropRight(1).foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]](30 seconds) @@ -222,14 +222,14 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val (nodeRelayer, _) = f.createNodeRelay(incomingMultiPart.head) // We send all the parts of a mpp - incomingMultiPart.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + incomingMultiPart.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) // and then one extra val extra = IncomingPaymentPacket.RelayToTrampolinePacket( UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None), FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId), nextTrampolinePacket) - nodeRelayer ! NodeRelay.Relay(extra) + nodeRelayer ! NodeRelay.Relay(extra, randomKey().publicKey) // the extra payment will be rejected val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -245,10 +245,10 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val (nodeRelayer, _) = f.createNodeRelay(incomingMultiPart.head) // Receive a complete upstream multi-part payment, which we relay out. - incomingMultiPart.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + incomingMultiPart.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now())))) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)))) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) @@ -258,7 +258,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId), nextTrampolinePacket) - nodeRelayer ! NodeRelay.Relay(i1) + nodeRelayer ! NodeRelay.Relay(i1, randomKey().publicKey) val fwd1 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] assert(fwd1.channelId == i1.add.channelId) @@ -271,7 +271,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl PaymentOnion.FinalPayload.Standard.createPayload(1500 msat, 1500 msat, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(1250 msat, outgoingExpiry, outgoingNodeId), nextTrampolinePacket) - nodeRelayer ! NodeRelay.Relay(i2) + nodeRelayer ! NodeRelay.Relay(i2, randomKey().publicKey) val fwd2 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] assert(fwd1.channelId == i1.add.channelId) @@ -288,7 +288,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val expiryOut = CltvExpiry(499900) val p = createValidIncomingPacket(2000000 msat, 2000000 msat, expiryIn, 1000000 msat, expiryOut) val (nodeRelayer, _) = f.createNodeRelay(p) - nodeRelayer ! NodeRelay.Relay(p) + nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey) val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] assert(fwd.channelId == p.add.channelId) @@ -304,7 +304,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val expiryOut = CltvExpiry(300000) // not ok (chain height = 400000) val p = createValidIncomingPacket(2000000 msat, 2000000 msat, expiryIn, 1000000 msat, expiryOut) val (nodeRelayer, _) = f.createNodeRelay(p) - nodeRelayer ! NodeRelay.Relay(p) + nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey) val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] assert(fwd.channelId == p.add.channelId) @@ -324,7 +324,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl createValidIncomingPacket(1000000 msat, 3000000 msat, expiryIn2, 2100000 msat, expiryOut) ) val (nodeRelayer, _) = f.createNodeRelay(p.head) - p.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + p.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) p.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -339,7 +339,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import f._ val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) // wait until the NodeRelay is waiting for the trigger eventListener.expectMessageType[WaitingToRelayPayment] @@ -363,7 +363,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import f._ val (nodeRelayer, parent) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) // wait until the NodeRelay is waiting for the trigger eventListener.expectMessageType[WaitingToRelayPayment] @@ -377,7 +377,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // upstream payment relayed val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingAsyncPayment.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now())))) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingAsyncPayment.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)))) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) // those are adapters for pay-fsm messages @@ -405,7 +405,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import f._ val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment // publish notification that peer is unavailable at the cancel-safety-before-timeout-block threshold before the current incoming payment expires (and before the timeout height) @@ -427,7 +427,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import f._ val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) // wait until the NodeRelay is waiting for the trigger eventListener.expectMessageType[WaitingToRelayPayment] @@ -450,11 +450,11 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl assert(!nodeParams.features.hasFeature(AsyncPaymentPrototype)) val (nodeRelayer, parent) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) // upstream payment relayed val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingAsyncPayment.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now())))) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingAsyncPayment.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)))) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) // those are adapters for pay-fsm messages @@ -483,7 +483,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val p = createValidIncomingPacket(2000000 msat, 2000000 msat, CltvExpiry(500000), 1999000 msat, CltvExpiry(490000)) val (nodeRelayer, _) = f.createNodeRelay(p) - nodeRelayer ! NodeRelay.Relay(p) + nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey) val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] assert(fwd.channelId == p.add.channelId) @@ -500,7 +500,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl createValidIncomingPacket(1000000 msat, 3000000 msat, CltvExpiry(500000), 2999000 msat, CltvExpiry(400000)) ) val (nodeRelayer, _) = f.createNodeRelay(p.head) - p.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + p.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) p.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -516,7 +516,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val p = createValidIncomingPacket(5000000 msat, 5000000 msat, CltvExpiry(500000), 0 msat, CltvExpiry(490000)) val (nodeRelayer, _) = f.createNodeRelay(p) - nodeRelayer ! NodeRelay.Relay(p) + nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey) val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] assert(fwd.channelId == p.add.channelId) @@ -533,7 +533,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl createValidIncomingPacket(1000000 msat, 5000000 msat, CltvExpiry(500000), 0 msat, CltvExpiry(490000)) ) val (nodeRelayer, _) = f.createNodeRelay(p.head) - p.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + p.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) p.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -549,7 +549,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // Receive an upstream multi-part payment. val (nodeRelayer, _) = f.createNodeRelay(incomingMultiPart.head) - incomingMultiPart.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + incomingMultiPart.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) mockPayFSM.expectMessageType[SendPaymentConfig] // those are adapters for pay-fsm messages @@ -576,7 +576,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl createValidIncomingPacket(outgoingAmount, outgoingAmount * 2, CltvExpiry(500000), outgoingAmount, outgoingExpiry), ) val (nodeRelayer, _) = f.createNodeRelay(incoming.head, useRealPaymentFactory = true) - incoming.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + incoming.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) val payFSM = mockPayFSM.expectMessageType[akka.actor.ActorRef] router.expectMessageType[RouteRequest] @@ -596,7 +596,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // Receive an upstream multi-part payment. val (nodeRelayer, _) = f.createNodeRelay(incomingMultiPart.head, useRealPaymentFactory = true) - incomingMultiPart.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + incomingMultiPart.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) val payFSM = mockPayFSM.expectMessageType[akka.actor.ActorRef] router.expectMessageType[RouteRequest] @@ -619,7 +619,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // Receive an upstream multi-part payment. val (nodeRelayer, _) = f.createNodeRelay(incomingMultiPart.head, useRealPaymentFactory = true) - incomingMultiPart.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + incomingMultiPart.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) val payFSM = mockPayFSM.expectMessageType[akka.actor.ActorRef] router.expectMessageType[RouteRequest] @@ -641,7 +641,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // Receive an upstream payment. val (nodeRelayer, _) = f.createNodeRelay(incomingSinglePart, useRealPaymentFactory = true) - nodeRelayer ! NodeRelay.Relay(incomingSinglePart) + nodeRelayer ! NodeRelay.Relay(incomingSinglePart, randomKey().publicKey) val routeRequest = router.expectMessageType[RouteRequest] val routeParams = routeRequest.routeParams @@ -656,13 +656,13 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // Receive an upstream multi-part payment. val (nodeRelayer, parent) = f.createNodeRelay(incomingMultiPart.head) - incomingMultiPart.dropRight(1).foreach(p => nodeRelayer ! NodeRelay.Relay(p)) + incomingMultiPart.dropRight(1).foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment - nodeRelayer ! NodeRelay.Relay(incomingMultiPart.last) + nodeRelayer ! NodeRelay.Relay(incomingMultiPart.last, randomKey().publicKey) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now())))) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)))) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) // those are adapters for pay-fsm messages @@ -695,10 +695,10 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // Receive an upstream single-part payment. val (nodeRelayer, parent) = f.createNodeRelay(incomingSinglePart) - nodeRelayer ! NodeRelay.Relay(incomingSinglePart) + nodeRelayer ! NodeRelay.Relay(incomingSinglePart, randomKey().publicKey) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(Upstream.Hot.Channel(incomingSinglePart.add, TimestampMilli.now()) :: Nil)) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(Upstream.Hot.Channel(incomingSinglePart.add, TimestampMilli.now(), randomKey().publicKey) :: Nil)) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] validateOutgoingPayment(outgoingPayment) // those are adapters for pay-fsm messages @@ -730,10 +730,10 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incoming.innerPayload.amountToForward, outgoingAmount * 3, outgoingExpiry, outgoingNodeId, invoice ))) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) - incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now())))) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)))) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] assert(outgoingPayment.recipient.nodeId == outgoingNodeId) assert(outgoingPayment.recipient.totalAmount == outgoingAmount) @@ -774,10 +774,10 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incoming.innerPayload.amountToForward, incoming.innerPayload.amountToForward, outgoingExpiry, outgoingNodeId, invoice ))) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) - incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now())))) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey)))) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.recipient.nodeId == outgoingNodeId) assert(outgoingPayment.amount == outgoingAmount) @@ -818,7 +818,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incoming.copy(innerPayload = invalidPayload) }) val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head) - incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) incomingMultiPart.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -843,10 +843,10 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incoming.innerPayload.amountToForward, outgoingExpiry, invoice ))) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) - incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now()))), ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -882,10 +882,10 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incoming.innerPayload.amountToForward, outgoingExpiry, invoice ))) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) - incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now()))), ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] assert(outgoingPayment.recipient.totalAmount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -924,7 +924,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incoming.innerPayload.amountToForward, outgoingExpiry, invoice ))) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) - incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val getNodeId = router.expectMessageType[Router.GetNodeId] assert(getNodeId.isNode1 == scidDir.isNode1) @@ -932,7 +932,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl getNodeId.replyTo ! Some(outgoingNodeId) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now()))), ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -971,7 +971,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incoming.innerPayload.amountToForward, outgoingExpiry, invoice ))) val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head) - incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val getNodeId = router.expectMessageType[Router.GetNodeId] assert(getNodeId.isNode1 == scidDir.isNode1) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala index c024a29e50..ea3dba19aa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala @@ -93,7 +93,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val Right(payment) = buildOutgoingPayment(TestConstants.emptyOrigin, paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) // and then manually build an htlc val add_ab = UpdateAddHtlc(randomBytes32(), 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) - relayer ! RelayForward(add_ab) + relayer ! RelayForward(add_ab, priv_a.publicKey) register.expectMessageType[Register.Forward[CMD_ADD_HTLC]] } @@ -102,7 +102,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val Right(payment) = buildOutgoingPayment(TestConstants.emptyOrigin, paymentHash, Route(finalAmount, hops.take(1), None), ClearRecipient(b, Features.empty, finalAmount, finalExpiry, paymentSecret)) val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) - relayer ! RelayForward(add_ab) + relayer ! RelayForward(add_ab, priv_a.publicKey) val fp = paymentHandler.expectMessageType[FinalPacket] assert(fp.add == add_ab) @@ -123,7 +123,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat assert(payment.cmd.amount == finalAmount) assert(payment.cmd.cltvExpiry == finalExpiry) val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) - relayer ! RelayForward(add_ab) + relayer ! RelayForward(add_ab, priv_a.publicKey) val fp = paymentHandler.expectMessageType[FinalPacket] assert(fp.add == add_ab) @@ -143,7 +143,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val Right(payment) = buildOutgoingPayment(TestConstants.emptyOrigin, paymentHash, Route(finalAmount, hops, None), ClearRecipient(e, Features.empty, finalAmount, finalExpiry, paymentSecret)) // and then manually build an htlc with an invalid onion (hmac) val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion.copy(hmac = payment.cmd.onion.hmac.reverse), None) - relayer ! RelayForward(add_ab) + relayer ! RelayForward(add_ab, priv_a.publicKey) val fail = register.expectMessageType[Register.Forward[CMD_FAIL_MALFORMED_HTLC]].message assert(fail.id == add_ab.id) @@ -162,7 +162,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val route = Route(finalAmount, Seq(channelHopFromUpdate(priv_a.publicKey, b, channelUpdate_ab)), Some(blindedHop)) val Right(payment) = buildOutgoingPayment(TestConstants.emptyOrigin, paymentHash, route, recipient) val add_ab = UpdateAddHtlc(channelId_ab, 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) - relayer ! RelayForward(add_ab) + relayer ! RelayForward(add_ab, priv_a.publicKey) val fail = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]].message assert(fail.id == add_ab.id) @@ -178,7 +178,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat // we use an expired blinded route. val Right(payment) = buildOutgoingBlindedPaymentAB(paymentHash, routeExpiry = CltvExpiry(nodeParams.currentBlockHeight - 1)) val add_ab = UpdateAddHtlc(channelId_ab, 0, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) - relayer ! RelayForward(add_ab) + relayer ! RelayForward(add_ab, priv_a.publicKey) val fail = register.expectMessageType[Register.Forward[CMD_FAIL_MALFORMED_HTLC]].message assert(fail.id == add_ab.id) @@ -200,7 +200,7 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat // and then manually build an htlc val add_ab = UpdateAddHtlc(channelId_ab, 123456, payment.cmd.amount, payment.cmd.paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None) - relayer ! RelayForward(add_ab) + relayer ! RelayForward(add_ab, priv_a.publicKey) val fail = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]].message assert(fail.id == add_ab.id) @@ -215,8 +215,8 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val replyTo = TestProbe[Any]() val add_ab = UpdateAddHtlc(channelId = channelId_ab, id = 42, amountMsat = 11000000 msat, paymentHash = ByteVector32.Zeroes, CltvExpiry(4200), TestConstants.emptyOnionPacket, None) val add_bc = UpdateAddHtlc(channelId_bc, 72, 1000 msat, paymentHash, CltvExpiry(1), TestConstants.emptyOnionPacket, None) - val channelOrigin = Origin.Hot(replyTo.ref.toClassic, Upstream.Hot.Channel(add_ab, TimestampMilli.now())) - val trampolineOrigin = Origin.Hot(replyTo.ref.toClassic, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_ab, TimestampMilli.now())))) + val channelOrigin = Origin.Hot(replyTo.ref.toClassic, Upstream.Hot.Channel(add_ab, TimestampMilli.now(), priv_a.publicKey)) + val trampolineOrigin = Origin.Hot(replyTo.ref.toClassic, Upstream.Hot.Trampoline(Seq(Upstream.Hot.Channel(add_ab, TimestampMilli.now(), priv_a.publicKey)))) val addSettled = Seq( RES_ADD_SETTLED(channelOrigin, add_bc, HtlcResult.OnChainFulfill(randomBytes32())), diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1Spec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1Spec.scala index 858ce59b32..7c72c64128 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1Spec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1Spec.scala @@ -11,7 +11,7 @@ import fr.acinq.eclair.transactions.{CommitmentSpec, DirectedHtlc, IncomingHtlc, import fr.acinq.eclair.wire.internal.channel.version0.ChannelTypes0.ChannelVersion import fr.acinq.eclair.wire.internal.channel.version1.ChannelCodecs1.Codecs._ import fr.acinq.eclair.wire.protocol.UpdateAddHtlc -import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, MilliSatoshiLong, TestConstants, TimestampMilli, randomBytes32} +import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, MilliSatoshiLong, TestConstants, TimestampMilli, randomBytes32, randomKey} import org.scalatest.funsuite.AnyFunSuite import scodec.bits._ import scodec.{Attempt, DecodeResult} @@ -123,7 +123,7 @@ class ChannelCodecs1Spec extends AnyFunSuite { assert(originCodec.decodeValue(originCodec.encode(localCold).require).require == localCold) val add = UpdateAddHtlc(randomBytes32(), 4324, 11000000 msat, randomBytes32(), CltvExpiry(400000), TestConstants.emptyOnionPacket, None) - val relayedHot = Origin.Hot(replyTo, Upstream.Hot.Channel(add, TimestampMilli(0))) + val relayedHot = Origin.Hot(replyTo, Upstream.Hot.Channel(add, TimestampMilli(0), randomKey().publicKey)) val relayedCold = Origin.Cold(Upstream.Cold.Channel(add.channelId, add.id, add.amountMsat)) assert(originCodec.decodeValue(originCodec.encode(relayedHot).require).require == relayedCold) assert(originCodec.decodeValue(originCodec.encode(relayedCold).require).require == relayedCold) @@ -133,7 +133,7 @@ class ChannelCodecs1Spec extends AnyFunSuite { UpdateAddHtlc(randomBytes32(), 1L, 2000 msat, randomBytes32(), CltvExpiry(400000), TestConstants.emptyOnionPacket, None), UpdateAddHtlc(randomBytes32(), 2L, 3000 msat, randomBytes32(), CltvExpiry(400000), TestConstants.emptyOnionPacket, None), ) - val trampolineRelayedHot = Origin.Hot(replyTo, Upstream.Hot.Trampoline(adds.map(add => Upstream.Hot.Channel(add, TimestampMilli(0))))) + val trampolineRelayedHot = Origin.Hot(replyTo, Upstream.Hot.Trampoline(adds.map(add => Upstream.Hot.Channel(add, TimestampMilli(0), randomKey().publicKey)))) // We didn't encode the incoming HTLC amount. val trampolineRelayed = Origin.Cold(Upstream.Cold.Trampoline(adds.map(add => Upstream.Cold.Channel(add.channelId, add.id, 0 msat)).toList)) assert(originCodec.decodeValue(originCodec.encode(trampolineRelayedHot).require).require == trampolineRelayed)