Skip to content

Commit b880666

Browse files
feat: Make default runner connection available
There is no reason to keep a generic interface in the runner and provide the implementation in the core lightning runner if the only way to establish a connection with the node is through the pyln.proto package. This commit moves the CLightningConn inside the runner interface to allow everyone access to the default implementation. Signed-off-by: Vincenzo Palazzo <[email protected]>
1 parent e51fb97 commit b880666

File tree

3 files changed

+65
-121
lines changed

3 files changed

+65
-121
lines changed

lnprototest/__init__.py

Lines changed: 1 addition & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from .runner import (
4646
Runner,
4747
Conn,
48+
RunnerConn,
4849
remote_revocation_basepoint,
4950
remote_payment_basepoint,
5051
remote_delayed_payment_basepoint,
@@ -84,75 +85,3 @@
8485
AddWitnesses,
8586
)
8687
from .proposals import dual_fund_csv, channel_type_csv
87-
88-
__all__ = [
89-
"EventError",
90-
"SpecFileError",
91-
"Resolvable",
92-
"ResolvableInt",
93-
"ResolvableStr",
94-
"ResolvableBool",
95-
"Event",
96-
"Connect",
97-
"Disconnect",
98-
"DualFundAccept",
99-
"CreateDualFunding",
100-
"AddInput",
101-
"AddOutput",
102-
"FinalizeFunding",
103-
"AddWitnesses",
104-
"Msg",
105-
"RawMsg",
106-
"ExpectMsg",
107-
"Block",
108-
"ExpectTx",
109-
"FundChannel",
110-
"InitRbf",
111-
"Invoice",
112-
"AddHtlc",
113-
"ExpectError",
114-
"Sequence",
115-
"OneOf",
116-
"AnyOrder",
117-
"TryAll",
118-
"CheckEq",
119-
"MustNotMsg",
120-
"SigType",
121-
"Sig",
122-
"DummyRunner",
123-
"Runner",
124-
"Conn",
125-
"KeySet",
126-
"peer_message_namespace",
127-
"namespace",
128-
"assign_namespace",
129-
"make_namespace",
130-
"bitfield",
131-
"has_bit",
132-
"bitfield_len",
133-
"msat",
134-
"negotiated",
135-
"remote_revocation_basepoint",
136-
"remote_payment_basepoint",
137-
"remote_delayed_payment_basepoint",
138-
"remote_htlc_basepoint",
139-
"remote_per_commitment_point",
140-
"remote_per_commitment_secret",
141-
"remote_funding_pubkey",
142-
"remote_funding_privkey",
143-
"Commit",
144-
"HTLC",
145-
"UpdateCommit",
146-
"Side",
147-
"AcceptFunding",
148-
"CreateFunding",
149-
"Funding",
150-
"regtest_hash",
151-
"privkey_expand",
152-
"Wait",
153-
"dual_fund_csv",
154-
"channel_type_csv",
155-
"wait_for",
156-
"CloseChannel",
157-
"ExpectDisconnect",
158-
]

lnprototest/clightning/clightning.py

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/python3
21
# This script exercises the core-lightning implementation
32

43
# Released by Rusty Russell under CC0:
@@ -27,6 +26,7 @@
2726
SpecFileError,
2827
KeySet,
2928
Conn,
29+
RunnerConn,
3030
namespace,
3131
MustNotMsg,
3232
)
@@ -37,23 +37,6 @@
3737
LIGHTNING_SRC = os.path.join(os.getcwd(), os.getenv("LIGHTNING_SRC", "../lightning/"))
3838

3939

40-
class CLightningConn(lnprototest.Conn):
41-
def __init__(self, connprivkey: str, port: int):
42-
super().__init__(connprivkey)
43-
# FIXME: pyln.proto.wire should just use coincurve PrivateKey!
44-
self.connection = pyln.proto.wire.connect(
45-
pyln.proto.wire.PrivateKey(bytes.fromhex(self.connprivkey.to_hex())),
46-
# FIXME: Ask node for pubkey
47-
pyln.proto.wire.PublicKey(
48-
bytes.fromhex(
49-
"0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"
50-
)
51-
),
52-
"127.0.0.1",
53-
port,
54-
)
55-
56-
5740
class Runner(lnprototest.Runner):
5841
def __init__(self, config: Any):
5942
super().__init__(config)
@@ -194,7 +177,7 @@ def stop(self, print_logs: bool = False, also_bitcoind: bool = True) -> None:
194177
self.shutdown(also_bitcoind=also_bitcoind)
195178
self.running = False
196179
for c in self.conns.values():
197-
cast(CLightningConn, c).connection.connection.close()
180+
c.connection.connection.close()
198181
if print_logs:
199182
log_path = f"{self.lightning_dir}/regtest/log"
200183
with open(log_path) as log:
@@ -216,7 +199,13 @@ def restart(self) -> None:
216199
self.start(also_bitcoind=False)
217200

218201
def connect(self, _: Event, connprivkey: str) -> None:
219-
self.add_conn(CLightningConn(connprivkey, self.lightning_port))
202+
self.add_conn(
203+
RunnerConn(
204+
connprivkey,
205+
"0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798",
206+
self.lightning_port,
207+
)
208+
)
220209

221210
def getblockheight(self) -> int:
222211
return self.bitcoind.rpc.getblockcount()
@@ -232,15 +221,13 @@ def add_blocks(self, event: Event, txs: List[str], n: int) -> None:
232221

233222
wait_for(lambda: self.rpc.getinfo()["blockheight"] == self.getblockheight())
234223

235-
def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None:
224+
def recv(self, event: Event, conn: RunnerConn, outbuf: bytes) -> None:
236225
try:
237-
cast(CLightningConn, conn).connection.send_message(outbuf)
226+
conn.connection.send_message(outbuf)
238227
except BrokenPipeError:
239228
# This happens when they've sent an error and closed; try
240229
# reading it to figure out what went wrong.
241-
fut = self.executor.submit(
242-
cast(CLightningConn, conn).connection.read_message
243-
)
230+
fut = self.executor.submit(conn.connection.read_message)
244231
try:
245232
msg = fut.result(1)
246233
except futures.TimeoutError:
@@ -255,7 +242,7 @@ def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None:
255242
def fundchannel(
256243
self,
257244
event: Event,
258-
conn: Conn,
245+
conn: RunnerConn,
259246
amount: int,
260247
feerate: int = 253,
261248
expect_fail: bool = False,
@@ -279,7 +266,7 @@ def fundchannel(
279266

280267
def _fundchannel(
281268
runner: Runner,
282-
conn: Conn,
269+
conn: RunnerConn,
283270
amount: int,
284271
feerate: int,
285272
expect_fail: bool = False,
@@ -352,7 +339,7 @@ def kill_fundchannel(self) -> None:
352339
def init_rbf(
353340
self,
354341
event: Event,
355-
conn: Conn,
342+
conn: RunnerConn,
356343
channel_id: str,
357344
amount: int,
358345
utxo_txid: str,
@@ -416,7 +403,9 @@ def accept_add_fund(self, event: Event) -> None:
416403
},
417404
)
418405

419-
def addhtlc(self, event: Event, conn: Conn, amount: int, preimage: str) -> None:
406+
def addhtlc(
407+
self, event: Event, conn: RunnerConn, amount: int, preimage: str
408+
) -> None:
420409
payhash = hashlib.sha256(bytes.fromhex(preimage)).hexdigest()
421410
routestep = {
422411
"msatoshi": amount,
@@ -429,9 +418,9 @@ def addhtlc(self, event: Event, conn: Conn, amount: int, preimage: str) -> None:
429418
self.rpc.sendpay([routestep], payhash)
430419

431420
def get_output_message(
432-
self, conn: Conn, event: Event, timeout: int = TIMEOUT
421+
self, conn: RunnerConn, event: Event, timeout: int = TIMEOUT
433422
) -> Optional[bytes]:
434-
fut = self.executor.submit(cast(CLightningConn, conn).connection.read_message)
423+
fut = self.executor.submit(conn.connection.read_message)
435424
try:
436425
return fut.result(timeout)
437426
except futures.TimeoutError as ex:
@@ -441,7 +430,7 @@ def get_output_message(
441430
logging.error(f"{ex}")
442431
return None
443432

444-
def check_error(self, event: Event, conn: Conn) -> Optional[str]:
433+
def check_error(self, event: Event, conn: RunnerConn) -> Optional[str]:
445434
# We get errors in form of err msgs, always.
446435
super().check_error(event, conn)
447436
msg = self.get_output_message(conn, event)
@@ -452,13 +441,13 @@ def check_error(self, event: Event, conn: Conn) -> Optional[str]:
452441
def check_final_error(
453442
self,
454443
event: Event,
455-
conn: Conn,
444+
conn: RunnerConn,
456445
expected: bool,
457446
must_not_events: List[MustNotMsg],
458447
) -> None:
459448
if not expected:
460449
# Inject raw packet to ensure it hangs up *after* processing all previous ones.
461-
cast(CLightningConn, conn).connection.connection.send(bytes(18))
450+
conn.connection.connection.send(bytes(18))
462451

463452
while True:
464453
binmsg = self.get_output_message(conn, event)
@@ -475,7 +464,7 @@ def check_final_error(
475464
if msgtype == namespace().get_msgtype("error").number:
476465
raise EventError(event, "Got error msg: {}".format(binmsg.hex()))
477466

478-
cast(CLightningConn, conn).connection.connection.close()
467+
conn.connection.connection.close()
479468

480469
def expect_tx(self, event: Event, txid: str) -> None:
481470
# Ah bitcoin endianness...

lnprototest/runner.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1-
#! /usr/bin/python3
21
import logging
32
import shutil
43
import tempfile
54

65
import coincurve
76
import functools
87

8+
import pyln
9+
10+
from abc import ABC, abstractmethod
11+
from typing import Dict, Optional, List, Union, Any, Callable
12+
913
from .errors import SpecFileError
1014
from .structure import Sequence
1115
from .event import Event, MustNotMsg, ExpectMsg
1216
from .utils import privkey_expand
1317
from .keyset import KeySet
14-
from abc import ABC, abstractmethod
15-
from typing import Dict, Optional, List, Union, Any, Callable
1618

1719

1820
class Conn(object):
@@ -31,6 +33,28 @@ def __str__(self) -> str:
3133
return self.name
3234

3335

36+
class RunnerConn(Conn):
37+
"""Default Connection implementation for a runner that use the pyln.proto
38+
to open a connection over a socket."""
39+
40+
def __init__(
41+
self,
42+
connprivkey: str,
43+
counterparty_pubkey: str,
44+
port: int,
45+
host: str = "127.0.0.1",
46+
):
47+
super().__init__(connprivkey)
48+
# FIXME: pyln.proto.wire should just use coincurve PrivateKey!
49+
self.connection = pyln.proto.wire.connect(
50+
pyln.proto.wire.PrivateKey(bytes.fromhex(self.connprivkey.to_hex())),
51+
# FIXME: Ask node for pubkey
52+
pyln.proto.wire.PublicKey(bytes.fromhex(counterparty_pubkey)),
53+
host,
54+
port,
55+
)
56+
57+
3458
class Runner(ABC):
3559
"""Abstract base class for runners.
3660
@@ -42,8 +66,8 @@ def __init__(self, config: Any):
4266
self.config = config
4367
self.directory = tempfile.mkdtemp(prefix="lnpt-cl-")
4468
# key == connprivkey, value == Conn
45-
self.conns: Dict[str, Conn] = {}
46-
self.last_conn: Optional[Conn] = None
69+
self.conns: Dict[str, RunnerConn] = {}
70+
self.last_conn: Optional[RunnerConn] = None
4771
self.stash: Dict[str, Dict[str, Any]] = {}
4872
self.logger = logging.getLogger(__name__)
4973
if self.config.getoption("verbose"):
@@ -55,7 +79,7 @@ def _is_dummy(self) -> bool:
5579
"""The DummyRunner returns True here, as it can't do some things"""
5680
return False
5781

58-
def find_conn(self, connprivkey: Optional[str]) -> Optional[Conn]:
82+
def find_conn(self, connprivkey: Optional[str]) -> Optional[RunnerConn]:
5983
# Default is whatever we specified last.
6084
if connprivkey is None:
6185
return self.last_conn
@@ -64,17 +88,17 @@ def find_conn(self, connprivkey: Optional[str]) -> Optional[Conn]:
6488
return self.last_conn
6589
return None
6690

67-
def add_conn(self, conn: Conn) -> None:
91+
def add_conn(self, conn: RunnerConn) -> None:
6892
self.conns[conn.name] = conn
6993
self.last_conn = conn
7094

71-
def disconnect(self, event: Event, conn: Conn) -> None:
95+
def disconnect(self, event: Event, conn: RunnerConn) -> None:
7296
if conn is None:
7397
raise SpecFileError(event, "Unknown conn")
7498
del self.conns[conn.name]
7599
self.check_final_error(event, conn, conn.expected_error, conn.must_not_events)
76100

77-
def check_error(self, event: Event, conn: Conn) -> Optional[str]:
101+
def check_error(self, event: Event, conn: RunnerConn) -> Optional[str]:
78102
conn.expected_error = True
79103
return None
80104

@@ -157,11 +181,11 @@ def stop(self, print_logs: bool = False) -> None:
157181
pass
158182

159183
@abstractmethod
160-
def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None:
184+
def recv(self, event: Event, conn: RunnerConn, outbuf: bytes) -> None:
161185
pass
162186

163187
@abstractmethod
164-
def get_output_message(self, conn: Conn, event: ExpectMsg) -> Optional[bytes]:
188+
def get_output_message(self, conn: RunnerConn, event: ExpectMsg) -> Optional[bytes]:
165189
pass
166190

167191
@abstractmethod
@@ -192,7 +216,7 @@ def accept_add_fund(self, event: Event) -> None:
192216
def fundchannel(
193217
self,
194218
event: Event,
195-
conn: Conn,
219+
conn: RunnerConn,
196220
amount: int,
197221
feerate: int = 0,
198222
expect_fail: bool = False,
@@ -203,7 +227,7 @@ def fundchannel(
203227
def init_rbf(
204228
self,
205229
event: Event,
206-
conn: Conn,
230+
conn: RunnerConn,
207231
channel_id: str,
208232
amount: int,
209233
utxo_txid: str,
@@ -213,7 +237,9 @@ def init_rbf(
213237
pass
214238

215239
@abstractmethod
216-
def addhtlc(self, event: Event, conn: Conn, amount: int, preimage: str) -> None:
240+
def addhtlc(
241+
self, event: Event, conn: RunnerConn, amount: int, preimage: str
242+
) -> None:
217243
pass
218244

219245
@abstractmethod

0 commit comments

Comments
 (0)