Skip to content

Commit

Permalink
Register can forward messages to nodes (#2863)
Browse files Browse the repository at this point in the history
We add a `ForwardNodeId` command to the `Register` to forward messages
to a `Peer` actor based on its `node_id`.
  • Loading branch information
t-bast authored Jun 12, 2024
1 parent f0e3985 commit 741ac49
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 48 deletions.
6 changes: 3 additions & 3 deletions eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,8 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {

override def channelsInfo(toRemoteNode_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[RES_GET_CHANNEL_INFO]] = {
val futureResponse = toRemoteNode_opt match {
case Some(pk) => (appKit.register ? Symbol("channelsTo")).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(_._2 == pk).keys)
case None => (appKit.register ? Symbol("channels")).mapTo[Map[ByteVector32, ActorRef]].map(_.keys)
case Some(pk) => (appKit.register ? Register.GetChannelsTo).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(_._2 == pk).keys)
case None => (appKit.register ? Register.GetChannels).mapTo[Map[ByteVector32, ActorRef]].map(_.keys)
}

for {
Expand Down Expand Up @@ -594,7 +594,7 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
/** Send a request to multiple channels using node ids */
private def sendToNodes[C <: Command, R <: CommandResponse[C]](nodeids: List[PublicKey], request: C)(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, R]]] = {
for {
channelIds <- (appKit.register ? Symbol("channelsTo")).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(kv => nodeids.contains(kv._2)).keys)
channelIds <- (appKit.register ? Register.GetChannelsTo).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(kv => nodeids.contains(kv._2)).keys)
res <- sendToChannels[C, R](channelIds.map(Left(_)).toList, request)
} yield res
}
Expand Down
63 changes: 44 additions & 19 deletions eclair-core/src/main/scala/fr/acinq/eclair/channel/Register.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,44 +17,45 @@
package fr.acinq.eclair.channel

import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed
import akka.actor.{Actor, ActorLogging, ActorRef, Props}
import akka.actor.{Actor, ActorLogging, ActorRef, Props, typed}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.channel.Register._
import fr.acinq.eclair.{SubscriptionsComplete, ShortChannelId}
import fr.acinq.eclair.io.PeerCreated
import fr.acinq.eclair.{ShortChannelId, SubscriptionsComplete}

/**
* Created by PM on 26/01/2016.
*/

class Register() extends Actor with ActorLogging {
class Register extends Actor with ActorLogging {

context.system.eventStream.subscribe(self, classOf[PeerCreated])
context.system.eventStream.subscribe(self, classOf[ChannelCreated])
context.system.eventStream.subscribe(self, classOf[AbstractChannelRestored])
context.system.eventStream.subscribe(self, classOf[ChannelIdAssigned])
context.system.eventStream.subscribe(self, classOf[ShortChannelIdAssigned])
context.system.eventStream.publish(SubscriptionsComplete(this.getClass))

// @formatter:off
private case class ChannelTerminated(channel: ActorRef, channelId: ByteVector32)
// @formatter:on
override def receive: Receive = main(Map.empty, Map.empty, Map.empty, Map.empty)

override def receive: Receive = main(Map.empty, Map.empty, Map.empty)
def main(channels: Map[ByteVector32, ActorRef], shortIds: Map[ShortChannelId, ByteVector32], channelsTo: Map[ByteVector32, PublicKey], nodeIdToPeer: Map[PublicKey, ActorRef]): Receive = {
case PeerCreated(peer, remoteNodeId) =>
context.watchWith(peer, PeerTerminated(peer, remoteNodeId))
context become main(channels, shortIds, channelsTo, nodeIdToPeer + (remoteNodeId -> peer))

def main(channels: Map[ByteVector32, ActorRef], shortIds: Map[ShortChannelId, ByteVector32], channelsTo: Map[ByteVector32, PublicKey]): Receive = {
case ChannelCreated(channel, _, remoteNodeId, _, temporaryChannelId, _, _) =>
context.watchWith(channel, ChannelTerminated(channel, temporaryChannelId))
context become main(channels + (temporaryChannelId -> channel), shortIds, channelsTo + (temporaryChannelId -> remoteNodeId))
context become main(channels + (temporaryChannelId -> channel), shortIds, channelsTo + (temporaryChannelId -> remoteNodeId), nodeIdToPeer)

case event: AbstractChannelRestored =>
context.watchWith(event.channel, ChannelTerminated(event.channel, event.channelId))
context become main(channels + (event.channelId -> event.channel), shortIds, channelsTo + (event.channelId -> event.remoteNodeId))
context become main(channels + (event.channelId -> event.channel), shortIds, channelsTo + (event.channelId -> event.remoteNodeId), nodeIdToPeer)

case ChannelIdAssigned(channel, remoteNodeId, temporaryChannelId, channelId) =>
context.unwatch(channel)
context.watchWith(channel, ChannelTerminated(channel, channelId))
context become main(channels + (channelId -> channel) - temporaryChannelId, shortIds, channelsTo + (channelId -> remoteNodeId) - temporaryChannelId)
context become main(channels + (channelId -> channel) - temporaryChannelId, shortIds, channelsTo + (channelId -> remoteNodeId) - temporaryChannelId, nodeIdToPeer)

case scidAssigned: ShortChannelIdAssigned =>
// We map all known scids (real or alias) to the channel_id. The relayer is in charge of deciding whether a real
Expand All @@ -66,17 +67,24 @@ class Register() extends Actor with ActorLogging {
log.error("duplicate alias={} for channelIds={},{} this should never happen!", scidAssigned.shortIds.localAlias, channelId, scidAssigned.channelId)
case _ => ()
}
context become main(channels, shortIds ++ m, channelsTo)
context become main(channels, shortIds ++ m, channelsTo, nodeIdToPeer)

case ChannelTerminated(_, channelId) =>
val shortChannelIds = shortIds.collect { case (key, value) if value == channelId => key }
context become main(channels - channelId, shortIds -- shortChannelIds, channelsTo - channelId)

case Symbol("channels") => sender() ! channels
context become main(channels - channelId, shortIds -- shortChannelIds, channelsTo - channelId, nodeIdToPeer)

case PeerTerminated(peer, remoteNodeId) =>
// Note that peer actors can be stopped and recreated, which may lead to race conditions between PeerCreated and
// PeerTerminated messages: we only remove that nodeId from the map if the actor matches.
if (nodeIdToPeer.get(remoteNodeId).contains(peer)) {
context become main(channels, shortIds, channelsTo, nodeIdToPeer - remoteNodeId)
} else {
log.debug("ignoring obsolete PeerTerminated event for remoteNodeId={}", remoteNodeId)
}

case Symbol("shortIds") => sender() ! shortIds
case GetChannels => sender() ! channels

case Symbol("channelsTo") => sender() ! channelsTo
case GetChannelsTo => sender() ! channelsTo

case GetNextNodeId(replyTo, shortChannelId) =>
replyTo ! shortIds.get(shortChannelId).flatMap(cid => channelsTo.get(cid))
Expand All @@ -96,20 +104,37 @@ class Register() extends Actor with ActorLogging {
case Some(channel) => channel.tell(msg, compatReplyTo)
case None => compatReplyTo ! ForwardShortIdFailure(fwd)
}

case fwd@ForwardNodeId(replyTo, nodeId, msg) =>
nodeIdToPeer.get(nodeId) match {
case Some(peer) => peer.tell(msg, replyTo.toClassic)
case None => replyTo ! ForwardNodeIdFailure(fwd)
}
}
}

object Register {

def props(): Props = Props(new Register())

// @formatter:off
private[channel] case class PeerTerminated(peer: ActorRef, nodeId: PublicKey)
private case class ChannelTerminated(channel: ActorRef, channelId: ByteVector32)
// @formatter:on

// @formatter:off
case class Forward[T](replyTo: akka.actor.typed.ActorRef[ForwardFailure[T]], channelId: ByteVector32, message: T)
case class ForwardShortId[T](replyTo: akka.actor.typed.ActorRef[ForwardShortIdFailure[T]], shortChannelId: ShortChannelId, message: T)
case class ForwardNodeId[T](replyTo: akka.actor.typed.ActorRef[ForwardNodeIdFailure[T]], nodeId: PublicKey, message: T)

case class ForwardFailure[T](fwd: Forward[T])
case class ForwardShortIdFailure[T](fwd: ForwardShortId[T])
// @formatter:on
case class ForwardNodeIdFailure[T](fwd: ForwardNodeId[T])

case class GetNextNodeId(replyTo: typed.ActorRef[Option[PublicKey]], shortChannelId: ShortChannelId)

case object GetChannels
case object GetChannelsTo
// @formatter:on

}
11 changes: 6 additions & 5 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class Peer(val nodeParams: NodeParams,
channel ! INPUT_RESTORED(state)
FinalChannelId(state.channelId) -> channel
}.toMap
context.system.eventStream.publish(PeerCreated(self, remoteNodeId))
goto(DISCONNECTED) using DisconnectedData(channels) // when we restart, we will attempt to reconnect right away, but then we'll wait
}

Expand Down Expand Up @@ -374,7 +375,7 @@ class Peer(val nodeParams: NodeParams,
context.system.eventStream.publish(PeerDisconnected(self, remoteNodeId))
}

def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef]): State = {
private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef]): State = {
require(remoteNodeId == connectionReady.remoteNodeId, s"invalid nodeid: $remoteNodeId != ${connectionReady.remoteNodeId}")
log.debug("got authenticated connection to address {}", connectionReady.address)

Expand All @@ -394,29 +395,29 @@ class Peer(val nodeParams: NodeParams,
* We need to ignore [[LightningMessage]] not sent by the current [[PeerConnection]]. This may happen if we switch
* between connections.
*/
def dropStaleMessages(s: StateFunction): StateFunction = {
private def dropStaleMessages(s: StateFunction): StateFunction = {
case Event(msg: LightningMessage, d: ConnectedData) if sender() != d.peerConnection =>
log.warning("dropping message from stale connection: {}", msg)
stay()
case e if s.isDefinedAt(e) =>
s(e)
}

def spawnChannel(): ActorRef = {
private def spawnChannel(): ActorRef = {
val channel = channelFactory.spawn(context, remoteNodeId)
context watch channel
channel
}

def replyUnknownChannel(peerConnection: ActorRef, unknownChannelId: ByteVector32): Unit = {
private def replyUnknownChannel(peerConnection: ActorRef, unknownChannelId: ByteVector32): Unit = {
val msg = Error(unknownChannelId, "unknown channel")
self ! Peer.OutgoingMessage(msg, peerConnection)
}

// resume the openChannelInterceptor in case of failure, we always want the open channel request to succeed or fail
private val openChannelInterceptor = context.spawnAnonymous(Behaviors.supervise(OpenChannelInterceptor(context.self.toTyped, nodeParams, remoteNodeId, wallet, pendingChannelsRateLimiter)).onFailure(typed.SupervisorStrategy.resume))

def stopPeer(): State = {
private def stopPeer(): State = {
log.info("removing peer from db")
nodeParams.db.peers.removePeer(remoteNodeId)
stop(FSM.Normal)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import scala.concurrent.duration._

sealed trait PeerEvent

case class PeerCreated(peer: ActorRef, nodeId: PublicKey) extends PeerEvent

case class ConnectionInfo(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init)

case class PeerConnected(peer: ActorRef, nodeId: PublicKey, connectionInfo: ConnectionInfo) extends PeerEvent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I

eclair.channelsInfo(toRemoteNode_opt = None).pipeTo(sender.ref)

register.expectMsg(Symbol("channels"))
register.expectMsg(Register.GetChannels)
register.reply(map)

val c1 = register.expectMsgType[Register.Forward[CMD_GET_CHANNEL_INFO]]
Expand Down Expand Up @@ -544,7 +544,7 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I

eclair.channelsInfo(toRemoteNode_opt = Some(a)).pipeTo(sender.ref)

register.expectMsg(Symbol("channelsTo"))
register.expectMsg(Register.GetChannelsTo)
register.reply(channels2Nodes)

val c1 = register.expectMsgType[Register.Forward[CMD_GET_CHANNEL_INFO]]
Expand Down Expand Up @@ -676,7 +676,7 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I

eclair.updateRelayFee(List(a, b), 999 msat, 1234).pipeTo(sender.ref)

register.expectMsg(Symbol("channelsTo"))
register.expectMsg(Register.GetChannelsTo)
register.reply(map)

val u1 = register.expectMsgType[Register.Forward[CMD_UPDATE_RELAY_FEE]]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,74 @@
package fr.acinq.eclair.channel

import fr.acinq.eclair._

import akka.actor.{ActorRef, Props}
import akka.testkit.TestProbe
import akka.actor.typed.scaladsl.adapter._
import akka.actor.{ActorRef, PoisonPill}
import akka.testkit.{TestActorRef, TestProbe}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import org.scalatest.funsuite.AnyFunSuiteLike
import org.scalatest.ParallelTestExecution
import fr.acinq.eclair._
import fr.acinq.eclair.io.PeerCreated
import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import org.scalatest.{Outcome, ParallelTestExecution}

class RegisterSpec extends TestKitBaseClass with AnyFunSuiteLike with ParallelTestExecution {
class RegisterSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with ParallelTestExecution {

case class CustomChannelRestored(channel: ActorRef, channelId: ByteVector32, peer: ActorRef, remoteNodeId: PublicKey) extends AbstractChannelRestored

test("register processes custom restored events") {
val sender = TestProbe()
val registerRef = system.actorOf(Register.props())
case class FixtureParam(register: TestActorRef[Register], probe: TestProbe)

override def withFixture(test: OneArgTest): Outcome = {
val probe = TestProbe()
system.eventStream.subscribe(probe.ref, classOf[SubscriptionsComplete])
val register = TestActorRef(new Register())
probe.expectMsg(SubscriptionsComplete(classOf[Register]))
try {
withFixture(test.toNoArgTest(FixtureParam(register, probe)))
} finally {
system.stop(register)
}
}

test("process custom restored events") { f =>
import f._

val customRestoredEvent = CustomChannelRestored(TestProbe().ref, randomBytes32(), TestProbe().ref, randomKey().publicKey)
registerRef ! customRestoredEvent
sender.send(registerRef, Symbol("channels"))
sender.expectMsgType[Map[ByteVector32, ActorRef]] == Map(customRestoredEvent.channelId -> customRestoredEvent.channel)
system.eventStream.publish(customRestoredEvent)
awaitAssert({
probe.send(register, Register.GetChannels)
probe.expectMsgType[Map[ByteVector32, ActorRef]] == Map(customRestoredEvent.channelId -> customRestoredEvent.channel)
})
}

test("forward messages to peers") { f =>
import f._

val nodeId = randomKey().publicKey
val peer1 = TestProbe()
system.eventStream.publish(PeerCreated(peer1.ref, nodeId))

awaitAssert({
register ! Register.ForwardNodeId(probe.ref.toTyped, nodeId, "hello")
peer1.expectMsg("hello")
})

// We simulate a race condition, where the peer is recreated but we receive events out of order.
val peer2 = TestProbe()
system.eventStream.publish(PeerCreated(peer2.ref, nodeId))
awaitAssert({
register ! Register.ForwardNodeId(probe.ref.toTyped, nodeId, "world")
peer2.expectMsg("world")
})
register ! Register.PeerTerminated(peer1.ref, nodeId)

register ! Register.ForwardNodeId(probe.ref.toTyped, nodeId, "hello again")
peer2.expectMsg("hello again")

peer2.ref ! PoisonPill
awaitAssert({
val fwd = Register.ForwardNodeId(probe.ref.toTyped, nodeId, "d34d")
register ! fwd
probe.expectMsg(Register.ForwardNodeIdFailure(fwd))
})
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ class StandardChannelIntegrationSpec extends ChannelIntegrationSpec {
// mine the funding tx
generateBlocks(2)
// get the channelId
sender.send(fundee.register, Symbol("channels"))
sender.send(fundee.register, Register.GetChannels)
val Some((_, fundeeChannel)) = sender.expectMsgType[Map[ByteVector32, ActorRef]].find(_._1 == tempChannelId)

sender.send(fundeeChannel, CMD_GET_CHANNEL_DATA(ActorRef.noSender))
Expand Down Expand Up @@ -682,7 +682,7 @@ abstract class AnchorChannelIntegrationSpec extends ChannelIntegrationSpec {

// initially all the balance is on C side and F doesn't have an output
val sender = TestProbe()
sender.send(nodes("F").register, Symbol("channelsTo"))
sender.send(nodes("F").register, Register.GetChannelsTo)
// retrieve the channelId of C <--> F
val Some(channelId) = sender.expectMsgType[Map[ByteVector32, PublicKey]].find(_._2 == nodes("C").nodeParams.nodeId).map(_._1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import fr.acinq.eclair.TestUtils.waitEventStreamSynced
import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher
import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{Watch, WatchFundingConfirmed}
import fr.acinq.eclair.blockchain.bitcoind.rpc.BitcoinCoreClient
import fr.acinq.eclair.channel.{CMD_CLOSE, RES_SUCCESS}
import fr.acinq.eclair.channel.{CMD_CLOSE, RES_SUCCESS, Register}
import fr.acinq.eclair.io.Switchboard
import fr.acinq.eclair.message.OnionMessages
import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient, buildRoute}
Expand Down Expand Up @@ -328,10 +328,10 @@ class MessageIntegrationSpec extends IntegrationSpec {
// We close the channels A -> B -> C but we keep channels with D
// This ensures nodes still have an unrelated channel so we keep them in the network DB.
val probe = TestProbe()
probe.send(nodes("B").register, Symbol("channels"))
probe.send(nodes("B").register, Register.GetChannels)
val channelsB = probe.expectMsgType[Map[ByteVector32, ActorRef]]
assert(channelsB.size == 3)
probe.send(nodes("D").register, Symbol("channels"))
probe.send(nodes("D").register, Register.GetChannels)
val channelsD = probe.expectMsgType[Map[ByteVector32, ActorRef]]
assert(channelsD.size == 3)
channelsB.foreach {
Expand Down
2 changes: 2 additions & 0 deletions eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ class PeerSpec extends FixtureSpec {
test("restore existing channels") { f =>
import f._
val probe = TestProbe()
system.eventStream.subscribe(probe.ref, classOf[PeerCreated])
connect(remoteNodeId, peer, peerConnection, switchboard, channels = Set(ChannelCodecsSpec.normal))
probe.expectMsg(PeerCreated(peer.ref, remoteNodeId))
probe.send(peer, Peer.GetPeerInfo(None))
val peerInfo = probe.expectMsgType[PeerInfo]
assert(peerInfo.peer == peer)
Expand Down

0 comments on commit 741ac49

Please sign in to comment.