Skip to content

Commit 198fd93

Browse files
Allow sending message to route that starts with us (#2585)
In case we are the introduction node of a message blinded route, we would not be able to send a message to that route. We now unwrap the first hop in case the route starts with us.
1 parent f901aea commit 198fd93

File tree

9 files changed

+141
-78
lines changed

9 files changed

+141
-78
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ class Setup(val datadir: File,
371371
_ = triggerer ! AsyncPaymentTriggerer.Start(switchboard.toTyped)
372372
balanceActor = system.spawn(BalanceActor(nodeParams.db, bitcoinClient, channelsListener, nodeParams.balanceCheckInterval), name = "balance-actor")
373373

374-
postman = system.spawn(Behaviors.supervise(Postman(switchboard.toTyped)).onFailure(typed.SupervisorStrategy.restart), name = "postman")
374+
postman = system.spawn(Behaviors.supervise(Postman(nodeParams, switchboard.toTyped)).onFailure(typed.SupervisorStrategy.restart), name = "postman")
375375

376376
kit = Kit(
377377
nodeParams = nodeParams,

eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import scodec.{Attempt, DecodeResult}
2828

2929
import scala.annotation.tailrec
3030
import scala.concurrent.duration.FiniteDuration
31-
import scala.util.Try
3231

3332
object OnionMessages {
3433

@@ -42,35 +41,55 @@ object OnionMessages {
4241
case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None) extends Destination
4342
// @formatter:on
4443

45-
def buildRoute(blindingSecret: PrivateKey,
46-
intermediateNodes: Seq[IntermediateNode],
47-
destination: Destination): Sphinx.RouteBlinding.BlindedRoute = {
48-
val last = destination match {
49-
case Recipient(nodeId, _, _) => OutgoingNodeId(nodeId) :: Nil
50-
case BlindedPath(Sphinx.RouteBlinding.BlindedRoute(nodeId, blindingKey, _)) => OutgoingNodeId(nodeId) :: NextBlinding(blindingKey) :: Nil
51-
}
52-
val intermediatePayloads = if (intermediateNodes.isEmpty) {
44+
private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], nextTlvs: Set[RouteBlindingEncryptedDataTlv]): Seq[ByteVector] = {
45+
if (intermediateNodes.isEmpty) {
5346
Nil
5447
} else {
55-
(intermediateNodes.tail.map(node => Set(OutgoingNodeId(node.nodeId))) :+ last)
48+
(intermediateNodes.tail.map(node => Set(OutgoingNodeId(node.nodeId))) :+ nextTlvs)
5649
.zip(intermediateNodes).map { case (tlvs, hop) => hop.padding.map(Padding).toSet[RouteBlindingEncryptedDataTlv] ++ tlvs }
5750
.map(tlvs => RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs)).require.bytes)
5851
}
52+
}
53+
54+
def buildRoute(blindingSecret: PrivateKey,
55+
intermediateNodes: Seq[IntermediateNode],
56+
recipient: Recipient): Sphinx.RouteBlinding.BlindedRoute = {
57+
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, Set(OutgoingNodeId(recipient.nodeId)))
58+
val tlvs: Set[RouteBlindingEncryptedDataTlv] = Set(recipient.padding.map(Padding), recipient.pathId.map(PathId)).flatten
59+
val lastPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs)).require.bytes
60+
Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId) :+ recipient.nodeId, intermediatePayloads :+ lastPayload).route
61+
}
62+
63+
private def buildRouteFrom(originKey: PrivateKey,
64+
blindingSecret: PrivateKey,
65+
intermediateNodes: Seq[IntermediateNode],
66+
destination: Destination): Option[Sphinx.RouteBlinding.BlindedRoute] = {
5967
destination match {
60-
case Recipient(nodeId, pathId, padding) =>
61-
val tlvs: Set[RouteBlindingEncryptedDataTlv] = Set(padding.map(Padding), pathId.map(PathId)).flatten
62-
val lastPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs)).require.bytes
63-
Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId) :+ nodeId, intermediatePayloads :+ lastPayload).route
64-
case BlindedPath(route) =>
65-
if (intermediateNodes.isEmpty) {
66-
route
67-
} else {
68-
val routePrefix = Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId), intermediatePayloads).route
69-
Sphinx.RouteBlinding.BlindedRoute(routePrefix.introductionNodeId, routePrefix.blindingKey, routePrefix.blindedNodes ++ route.blindedNodes)
68+
case recipient: Recipient => Some(buildRoute(blindingSecret, intermediateNodes, recipient))
69+
case BlindedPath(route) if route.introductionNodeId == originKey.publicKey =>
70+
RouteBlindingEncryptedDataCodecs.decode(originKey, route.blindingKey, route.blindedNodes.head.encryptedPayload) match {
71+
case Left(_) => None
72+
case Right(decoded) =>
73+
decoded.tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId] match {
74+
case None => None
75+
case Some(RouteBlindingEncryptedDataTlv.OutgoingNodeId(nextNodeId)) =>
76+
Some(Sphinx.RouteBlinding.BlindedRoute(nextNodeId, decoded.nextBlinding, route.blindedNodes.tail))
77+
}
7078
}
79+
case BlindedPath(route) if intermediateNodes.isEmpty => Some(route)
80+
case BlindedPath(route) =>
81+
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, Set(OutgoingNodeId(route.introductionNodeId), NextBlinding(route.blindingKey)))
82+
val routePrefix = Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId), intermediatePayloads).route
83+
Some(Sphinx.RouteBlinding.BlindedRoute(routePrefix.introductionNodeId, routePrefix.blindingKey, routePrefix.blindedNodes ++ route.blindedNodes))
7184
}
7285
}
7386

87+
// @formatter:off
88+
sealed trait BuildMessageError
89+
case class MessageTooLarge(payloadSize: Long) extends BuildMessageError
90+
case class InvalidDestination(destination: Destination) extends BuildMessageError
91+
// @formatter:on
92+
7493
/**
7594
* Builds an encrypted onion containing a message that should be relayed to the destination.
7695
*
@@ -81,28 +100,32 @@ object OnionMessages {
81100
* @param content List of TLVs to send to the recipient of the message
82101
* @return The node id to send the onion to and the onion containing the message
83102
*/
84-
def buildMessage(sessionKey: PrivateKey,
103+
def buildMessage(nodeKey: PrivateKey,
104+
sessionKey: PrivateKey,
85105
blindingSecret: PrivateKey,
86106
intermediateNodes: Seq[IntermediateNode],
87107
destination: Destination,
88-
content: TlvStream[OnionMessagePayloadTlv]): Try[(PublicKey, OnionMessage)] = Try{
89-
val route = buildRoute(blindingSecret, intermediateNodes, destination)
90-
val lastPayload = MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(content.records + EncryptedData(route.encryptedPayloads.last), content.unknown)).require.bytes
91-
val payloads = route.encryptedPayloads.dropRight(1).map(encTlv => MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(EncryptedData(encTlv))).require.bytes) :+ lastPayload
92-
val payloadSize = payloads.map(_.length + Sphinx.MacLength).sum
93-
val packetSize = if (payloadSize <= 1300) {
94-
1300
95-
} else if (payloadSize <= 32768) {
96-
32768
97-
} else if (payloadSize > 65432) {
98-
// A payload of size 65432 corresponds to a total lightning message size of 65535.
99-
throw new Exception(s"Message is too large: payloadSize=$payloadSize")
100-
} else {
101-
payloadSize.toInt
108+
content: TlvStream[OnionMessagePayloadTlv]): Either[BuildMessageError, (PublicKey, OnionMessage)] = {
109+
buildRouteFrom(nodeKey, blindingSecret, intermediateNodes, destination) match {
110+
case None => Left(InvalidDestination(destination))
111+
case Some(route) =>
112+
val lastPayload = MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(content.records + EncryptedData(route.encryptedPayloads.last), content.unknown)).require.bytes
113+
val payloads = route.encryptedPayloads.dropRight(1).map(encTlv => MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(EncryptedData(encTlv))).require.bytes) :+ lastPayload
114+
val payloadSize = payloads.map(_.length + Sphinx.MacLength).sum
115+
val packetSize = if (payloadSize <= 1300) {
116+
1300
117+
} else if (payloadSize <= 32768) {
118+
32768
119+
} else if (payloadSize > 65432) {
120+
// A payload of size 65432 corresponds to a total lightning message size of 65535.
121+
return Left(MessageTooLarge(payloadSize))
122+
} else {
123+
payloadSize.toInt
124+
}
125+
// Since we are setting the packet size based on the payload, the onion creation should never fail (hence the `.get`).
126+
val Sphinx.PacketAndSecrets(packet, _) = Sphinx.create(sessionKey, packetSize, route.blindedNodes.map(_.blindedPublicKey), payloads, None).get
127+
Right((route.introductionNodeId, OnionMessage(route.blindingKey, packet)))
102128
}
103-
// Since we are setting the packet size based on the payload, the onion creation should never fail (hence the `.get`).
104-
val Sphinx.PacketAndSecrets(packet, _) = Sphinx.create(sessionKey, packetSize, route.blindedNodes.map(_.blindedPublicKey), payloads, None).get
105-
(route.introductionNodeId, OnionMessage(route.blindingKey, packet))
106129
}
107130

108131
// @formatter:off

eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,13 @@ import akka.actor.typed.{ActorRef, Behavior}
2222
import fr.acinq.bitcoin.scalacompat.ByteVector32
2323
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
2424
import fr.acinq.eclair.io.{MessageRelay, Switchboard}
25-
import fr.acinq.eclair.message.OnionMessages.{Destination, ReceiveMessage}
25+
import fr.acinq.eclair.message.OnionMessages.Destination
2626
import fr.acinq.eclair.wire.protocol.MessageOnion.FinalPayload
2727
import fr.acinq.eclair.wire.protocol.{OnionMessagePayloadTlv, TlvStream}
28-
import fr.acinq.eclair.{randomBytes32, randomKey}
28+
import fr.acinq.eclair.{NodeParams, randomBytes32, randomKey}
2929

3030
import scala.collection.mutable
3131
import scala.concurrent.duration.FiniteDuration
32-
import scala.util.{Failure, Success}
3332

3433
object Postman {
3534
// @formatter:off
@@ -62,9 +61,9 @@ object Postman {
6261
case class MessageFailed(reason: String) extends MessageStatus
6362
// @formatter:on
6463

65-
def apply(switchboard: ActorRef[Switchboard.RelayMessage]): Behavior[Command] = {
64+
def apply(nodeParams: NodeParams, switchboard: ActorRef[Switchboard.RelayMessage]): Behavior[Command] = {
6665
Behaviors.setup(context => {
67-
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[ReceiveMessage](r => WrappedMessage(r.finalPayload)))
66+
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[OnionMessages.ReceiveMessage](r => WrappedMessage(r.finalPayload)))
6867

6968
val relayMessageStatusAdapter = context.messageAdapter[MessageRelay.Status](SendingStatus)
7069

@@ -94,14 +93,15 @@ object Postman {
9493
OnionMessages.buildRoute(randomKey(), intermediateHops, lastHop)
9594
})
9695
OnionMessages.buildMessage(
96+
nodeParams.privateKey,
9797
randomKey(),
9898
randomKey(),
9999
intermediateNodes.map(OnionMessages.IntermediateNode(_)),
100100
destination,
101101
TlvStream(replyRoute.map(OnionMessagePayloadTlv.ReplyPath).toSet ++ messageContent.records, messageContent.unknown)) match {
102-
case Failure(f) =>
103-
replyTo ! MessageFailed(f.getMessage)
104-
case Success((nextNodeId, message)) =>
102+
case Left(failure) =>
103+
replyTo ! MessageFailed(failure.toString)
104+
case Right((nextNodeId, message)) =>
105105
if (replyPath.isEmpty) { // not expecting reply
106106
sendStatusTo += (messageId -> replyTo)
107107
} else { // expecting reply

eclair-core/src/test/scala/fr/acinq/eclair/integration/MessageIntegrationSpec.scala

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@ import fr.acinq.eclair.blockchain.bitcoind.rpc.BitcoinCoreClient
3131
import fr.acinq.eclair.channel.{CMD_CLOSE, RES_SUCCESS}
3232
import fr.acinq.eclair.io.Switchboard
3333
import fr.acinq.eclair.message.OnionMessages
34+
import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient, buildRoute}
3435
import fr.acinq.eclair.router.Router
3536
import fr.acinq.eclair.wire.protocol.TlvCodecs.genericTlv
3637
import fr.acinq.eclair.wire.protocol.{GenericTlv, NodeAnnouncement}
37-
import fr.acinq.eclair.{EclairImpl, Features, MilliSatoshi, SendOnionMessageResponse, UInt64, randomBytes}
38+
import fr.acinq.eclair.{EclairImpl, Features, MilliSatoshi, SendOnionMessageResponse, UInt64, randomBytes, randomKey}
3839
import scodec.bits.{ByteVector, HexStringSyntax}
3940

4041
import scala.concurrent.ExecutionContext.Implicits.global
@@ -73,6 +74,21 @@ class MessageIntegrationSpec extends IntegrationSpec {
7374
eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds)
7475
}
7576

77+
test("send to route that starts at ourselves") {
78+
val alice = new EclairImpl(nodes("A"))
79+
80+
val probe = TestProbe()
81+
val eventListener = TestProbe()
82+
nodes("B").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage])
83+
84+
val blindedRoute = buildRoute(randomKey(), Seq(IntermediateNode(nodes("A").nodeParams.nodeId), IntermediateNode(nodes("B").nodeParams.nodeId), IntermediateNode(nodes("B").nodeParams.nodeId)), Recipient(nodes("B").nodeParams.nodeId, None))
85+
assert(blindedRoute.introductionNodeId == nodes("A").nodeParams.nodeId)
86+
87+
alice.sendOnionMessage(Nil, Right(blindedRoute), None, ByteVector.empty).pipeTo(probe.ref)
88+
assert(probe.expectMsgType[SendOnionMessageResponse].sent)
89+
eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds)
90+
}
91+
7692
test("expect reply") {
7793
val alice = new EclairImpl(nodes("A"))
7894
val bob = new EclairImpl(nodes("B"))
@@ -140,7 +156,7 @@ class MessageIntegrationSpec extends IntegrationSpec {
140156
assert(probe.expectMsgType[SendOnionMessageResponse].sent)
141157

142158
val r = eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds)
143-
assert(r.finalPayload.records.unknown.toSet == Set(GenericTlv(UInt64(113), hex"010203"), GenericTlv(UInt64(117), hex"0102")))
159+
assert(r.finalPayload.records.unknown == Set(GenericTlv(UInt64(113), hex"010203"), GenericTlv(UInt64(117), hex"0102")))
144160
}
145161

146162
test("send very large message with hop") {
@@ -157,7 +173,7 @@ class MessageIntegrationSpec extends IntegrationSpec {
157173
assert(probe.expectMsgType[SendOnionMessageResponse].sent)
158174

159175
val r = eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds)
160-
assert(r.finalPayload.records.unknown.toSet == Set(GenericTlv(UInt64(135), bytes)))
176+
assert(r.finalPayload.records.unknown == Set(GenericTlv(UInt64(135), bytes)))
161177
}
162178

163179
test("send too large message with hop") {
@@ -266,7 +282,7 @@ class MessageIntegrationSpec extends IntegrationSpec {
266282
assert(probe.expectMsgType[SendOnionMessageResponse].sent)
267283

268284
val r = eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds)
269-
assert(r.finalPayload.records.unknown.toSet == Set(GenericTlv(UInt64(113), hex"010203"), GenericTlv(UInt64(117), hex"0102")))
285+
assert(r.finalPayload.records.unknown == Set(GenericTlv(UInt64(113), hex"010203"), GenericTlv(UInt64(117), hex"0102")))
270286
}
271287

272288
test("channel relay with no-relay") {
@@ -344,7 +360,7 @@ class MessageIntegrationSpec extends IntegrationSpec {
344360

345361
val r = eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds)
346362
assert(r.finalPayload.pathId_opt.isEmpty)
347-
assert(r.finalPayload.records.unknown.toSet == Set(GenericTlv(UInt64(115), hex"")))
363+
assert(r.finalPayload.records.unknown == Set(GenericTlv(UInt64(115), hex"")))
348364
}
349365

350366
}

eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
5858
test("relay with new connection") { f =>
5959
import f._
6060

61-
val Success((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
61+
val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
6262
val messageId = randomBytes32()
6363
relay ! RelayMessage(messageId, switchboard.ref, randomKey().publicKey, bobId, message, RelayAll, None)
6464

@@ -71,7 +71,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
7171
test("relay with existing peer") { f =>
7272
import f._
7373

74-
val Success((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
74+
val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
7575
val messageId = randomBytes32()
7676
relay ! RelayMessage(messageId, switchboard.ref, randomKey().publicKey, bobId, message, RelayAll, None)
7777

@@ -84,7 +84,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
8484
test("can't open new connection") { f =>
8585
import f._
8686

87-
val Success((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
87+
val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
8888
val messageId = randomBytes32()
8989
relay ! RelayMessage(messageId, switchboard.ref, randomKey().publicKey, bobId, message, RelayAll, Some(probe.ref))
9090

@@ -97,7 +97,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
9797
test("no channel with previous node") { f =>
9898
import f._
9999

100-
val Success((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
100+
val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
101101
val messageId = randomBytes32()
102102
val previousNodeId = randomKey().publicKey
103103
relay ! RelayMessage(messageId, switchboard.ref, previousNodeId, bobId, message, RelayChannelsOnly, Some(probe.ref))
@@ -113,7 +113,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
113113
test("no channel with next node") { f =>
114114
import f._
115115

116-
val Success((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
116+
val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
117117
val messageId = randomBytes32()
118118
val previousNodeId = randomKey().publicKey
119119
relay ! RelayMessage(messageId, switchboard.ref, previousNodeId, bobId, message, RelayChannelsOnly, Some(probe.ref))
@@ -133,7 +133,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
133133
test("channels on both ends") { f =>
134134
import f._
135135

136-
val Success((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
136+
val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
137137
val messageId = randomBytes32()
138138
val previousNodeId = randomKey().publicKey
139139
relay ! RelayMessage(messageId, switchboard.ref, previousNodeId, bobId, message, RelayChannelsOnly, None)
@@ -152,7 +152,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app
152152
test("no relay") { f =>
153153
import f._
154154

155-
val Success((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
155+
val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty)
156156
val messageId = randomBytes32()
157157
val previousNodeId = randomKey().publicKey
158158
relay ! RelayMessage(messageId, switchboard.ref, previousNodeId, bobId, message, NoRelay, Some(probe.ref))

0 commit comments

Comments
 (0)