diff --git a/Makefile b/Makefile index 5a21d00..e409d68 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,10 @@ POSSIBLE_PYTEST_NAMES=pytest-3 pytest3 pytest PYTEST := $(shell for p in $(POSSIBLE_PYTEST_NAMES); do if type $$p > /dev/null; then echo $$p; break; fi done) TEST_DIR=tests -default: check-source check check-quotes +# We disable the default target for the moment because we +# need to fix pytest +# default: check-source check check-quotes +default: fmt check check-pytest-found: @if [ -z "$(PYTEST)" ]; then echo "Cannot find any pytest: $(POSSIBLE_PYTEST_NAMES)" >&2; exit 1; fi diff --git a/lnprototest/__init__.py b/lnprototest/__init__.py index 687f225..c3170f9 100644 --- a/lnprototest/__init__.py +++ b/lnprototest/__init__.py @@ -45,6 +45,7 @@ from .runner import ( Runner, Conn, + RunnerConn, remote_revocation_basepoint, remote_payment_basepoint, remote_delayed_payment_basepoint, @@ -54,7 +55,6 @@ remote_funding_pubkey, remote_funding_privkey, ) -from .dummyrunner import DummyRunner from .namespace import ( peer_message_namespace, namespace, @@ -84,75 +84,3 @@ AddWitnesses, ) from .proposals import dual_fund_csv, channel_type_csv - -__all__ = [ - "EventError", - "SpecFileError", - "Resolvable", - "ResolvableInt", - "ResolvableStr", - "ResolvableBool", - "Event", - "Connect", - "Disconnect", - "DualFundAccept", - "CreateDualFunding", - "AddInput", - "AddOutput", - "FinalizeFunding", - "AddWitnesses", - "Msg", - "RawMsg", - "ExpectMsg", - "Block", - "ExpectTx", - "FundChannel", - "InitRbf", - "Invoice", - "AddHtlc", - "ExpectError", - "Sequence", - "OneOf", - "AnyOrder", - "TryAll", - "CheckEq", - "MustNotMsg", - "SigType", - "Sig", - "DummyRunner", - "Runner", - "Conn", - "KeySet", - "peer_message_namespace", - "namespace", - "assign_namespace", - "make_namespace", - "bitfield", - "has_bit", - "bitfield_len", - "msat", - "negotiated", - "remote_revocation_basepoint", - "remote_payment_basepoint", - "remote_delayed_payment_basepoint", - "remote_htlc_basepoint", - "remote_per_commitment_point", - "remote_per_commitment_secret", - "remote_funding_pubkey", - "remote_funding_privkey", - "Commit", - "HTLC", - "UpdateCommit", - "Side", - "AcceptFunding", - "CreateFunding", - "Funding", - "regtest_hash", - "privkey_expand", - "Wait", - "dual_fund_csv", - "channel_type_csv", - "wait_for", - "CloseChannel", - "ExpectDisconnect", -] diff --git a/lnprototest/clightning/clightning.py b/lnprototest/clightning/clightning.py index 1749bdd..1e4e5d1 100644 --- a/lnprototest/clightning/clightning.py +++ b/lnprototest/clightning/clightning.py @@ -1,4 +1,3 @@ -#!/usr/bin/python3 # This script exercises the core-lightning implementation # Released by Rusty Russell under CC0: @@ -27,6 +26,7 @@ SpecFileError, KeySet, Conn, + RunnerConn, namespace, MustNotMsg, ) @@ -37,23 +37,6 @@ LIGHTNING_SRC = os.path.join(os.getcwd(), os.getenv("LIGHTNING_SRC", "../lightning/")) -class CLightningConn(lnprototest.Conn): - def __init__(self, connprivkey: str, port: int): - super().__init__(connprivkey) - # FIXME: pyln.proto.wire should just use coincurve PrivateKey! - self.connection = pyln.proto.wire.connect( - pyln.proto.wire.PrivateKey(bytes.fromhex(self.connprivkey.to_hex())), - # FIXME: Ask node for pubkey - pyln.proto.wire.PublicKey( - bytes.fromhex( - "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798" - ) - ), - "127.0.0.1", - port, - ) - - class Runner(lnprototest.Runner): def __init__(self, config: Any): super().__init__(config) @@ -207,7 +190,7 @@ def stop(self, print_logs: bool = False, also_bitcoind: bool = True) -> None: self.shutdown(also_bitcoind=also_bitcoind) self.running = False for c in self.conns.values(): - cast(CLightningConn, c).connection.connection.close() + c.connection.connection.close() if print_logs: log_path = f"{self.lightning_dir}/regtest/log" with open(log_path) as log: @@ -228,8 +211,14 @@ def restart(self) -> None: self.bitcoind.restart() self.start(also_bitcoind=False) - def connect(self, _: Event, connprivkey: str) -> None: - self.add_conn(CLightningConn(connprivkey, self.lightning_port)) + def connect(self, _: Event, connprivkey: str) -> RunnerConn: + conn = RunnerConn( + connprivkey, + "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798", + self.lightning_port, + ) + self.add_conn(conn) + return conn def getblockheight(self) -> int: return self.bitcoind.rpc.getblockcount() @@ -245,15 +234,13 @@ def add_blocks(self, event: Event, txs: List[str], n: int) -> None: wait_for(lambda: self.rpc.getinfo()["blockheight"] == self.getblockheight()) - def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None: + def recv(self, event: Event, conn: RunnerConn, outbuf: bytes) -> None: try: - cast(CLightningConn, conn).connection.send_message(outbuf) + conn.connection.send_message(outbuf) except BrokenPipeError: # This happens when they've sent an error and closed; try # reading it to figure out what went wrong. - fut = self.executor.submit( - cast(CLightningConn, conn).connection.read_message - ) + fut = self.executor.submit(conn.connection.read_message) try: msg = fut.result(1) except futures.TimeoutError: @@ -268,7 +255,7 @@ def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None: def fundchannel( self, event: Event, - conn: Conn, + conn: RunnerConn, amount: int, feerate: int = 253, expect_fail: bool = False, @@ -292,7 +279,7 @@ def fundchannel( def _fundchannel( runner: Runner, - conn: Conn, + conn: RunnerConn, amount: int, feerate: int, expect_fail: bool = False, @@ -365,7 +352,7 @@ def kill_fundchannel(self) -> None: def init_rbf( self, event: Event, - conn: Conn, + conn: RunnerConn, channel_id: str, amount: int, utxo_txid: str, @@ -429,7 +416,9 @@ def accept_add_fund(self, event: Event) -> None: }, ) - def addhtlc(self, event: Event, conn: Conn, amount: int, preimage: str) -> None: + def addhtlc( + self, event: Event, conn: RunnerConn, amount: int, preimage: str + ) -> None: payhash = hashlib.sha256(bytes.fromhex(preimage)).hexdigest() routestep = { "msatoshi": amount, @@ -442,9 +431,9 @@ def addhtlc(self, event: Event, conn: Conn, amount: int, preimage: str) -> None: self.rpc.sendpay([routestep], payhash) def get_output_message( - self, conn: Conn, event: Event, timeout: int = TIMEOUT + self, conn: RunnerConn, event: Event, timeout: int = TIMEOUT ) -> Optional[bytes]: - fut = self.executor.submit(cast(CLightningConn, conn).connection.read_message) + fut = self.executor.submit(conn.connection.read_message) try: return fut.result(timeout) except futures.TimeoutError as ex: @@ -454,7 +443,7 @@ def get_output_message( logging.error(f"{ex}") return None - def check_error(self, event: Event, conn: Conn) -> Optional[str]: + def check_error(self, event: Event, conn: RunnerConn) -> Optional[str]: # We get errors in form of err msgs, always. super().check_error(event, conn) msg = self.get_output_message(conn, event) @@ -465,13 +454,13 @@ def check_error(self, event: Event, conn: Conn) -> Optional[str]: def check_final_error( self, event: Event, - conn: Conn, + conn: RunnerConn, expected: bool, must_not_events: List[MustNotMsg], ) -> None: if not expected: # Inject raw packet to ensure it hangs up *after* processing all previous ones. - cast(CLightningConn, conn).connection.connection.send(bytes(18)) + conn.connection.connection.send(bytes(18)) while True: binmsg = self.get_output_message(conn, event) @@ -488,7 +477,7 @@ def check_final_error( if msgtype == namespace().get_msgtype("error").number: raise EventError(event, "Got error msg: {}".format(binmsg.hex())) - cast(CLightningConn, conn).connection.connection.close() + conn.connection.connection.close() def expect_tx(self, event: Event, txid: str) -> None: # Ah bitcoin endianness... diff --git a/lnprototest/dummyrunner.py b/lnprototest/dummyrunner.py deleted file mode 100644 index ee005ab..0000000 --- a/lnprototest/dummyrunner.py +++ /dev/null @@ -1,220 +0,0 @@ -#! /usr/bin/python3 -# #### Dummy runner which you should replace with real one. #### -import io -from .runner import Runner, Conn -from .event import Event, ExpectMsg, MustNotMsg -from typing import List, Optional -from .keyset import KeySet -from pyln.proto.message import ( - Message, - FieldType, - DynamicArrayType, - EllipsisArrayType, - SizedArrayType, -) -from typing import Any - - -class DummyRunner(Runner): - def __init__(self, config: Any): - super().__init__(config) - - def _is_dummy(self) -> bool: - """The DummyRunner returns True here, as it can't do some things""" - return True - - def get_keyset(self) -> KeySet: - return KeySet( - revocation_base_secret="11", - payment_base_secret="12", - htlc_base_secret="14", - delayed_payment_base_secret="13", - shachain_seed="FF" * 32, - ) - - def add_startup_flag(self, flag: str) -> None: - if self.config.getoption("verbose"): - print("[ADD STARTUP FLAG {}]".format(flag)) - return - - def get_node_privkey(self) -> str: - return "01" - - def get_node_bitcoinkey(self) -> str: - return "10" - - def has_option(self, optname: str) -> Optional[str]: - return None - - def start(self) -> None: - self.blockheight = 102 - - def stop(self, print_logs: bool = False) -> None: - pass - - def restart(self) -> None: - super().restart() - if self.config.getoption("verbose"): - print("[RESTART]") - self.blockheight = 102 - - def connect(self, event: Event, connprivkey: str) -> None: - if self.config.getoption("verbose"): - print("[CONNECT {} {}]".format(event, connprivkey)) - self.add_conn(Conn(connprivkey)) - - def getblockheight(self) -> int: - return self.blockheight - - def trim_blocks(self, newheight: int) -> None: - if self.config.getoption("verbose"): - print("[TRIMBLOCK TO HEIGHT {}]".format(newheight)) - self.blockheight = newheight - - def add_blocks(self, event: Event, txs: List[str], n: int) -> None: - if self.config.getoption("verbose"): - print("[ADDBLOCKS {} WITH {} TXS]".format(n, len(txs))) - self.blockheight += n - - def disconnect(self, event: Event, conn: Conn) -> None: - super().disconnect(event, conn) - if self.config.getoption("verbose"): - print("[DISCONNECT {}]".format(conn)) - - def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None: - if self.config.getoption("verbose"): - print("[RECV {} {}]".format(event, outbuf.hex())) - - def fundchannel( - self, - event: Event, - conn: Conn, - amount: int, - feerate: int = 253, - expect_fail: bool = False, - ) -> None: - if self.config.getoption("verbose"): - print( - "[FUNDCHANNEL TO {} for {} at feerate {}. Expect fail? {}]".format( - conn, amount, feerate, expect_fail - ) - ) - - def init_rbf( - self, - event: Event, - conn: Conn, - channel_id: str, - amount: int, - utxo_txid: str, - utxo_outnum: int, - feerate: int, - ) -> None: - if self.config.getoption("verbose"): - print( - "[INIT_RBF TO {} (channel {}) for {} at feerate {}. {}:{}".format( - conn, channel_id, amount, feerate, utxo_txid, utxo_outnum - ) - ) - - def invoice(self, event: Event, amount: int, preimage: str) -> None: - if self.config.getoption("verbose"): - print("[INVOICE for {} with PREIMAGE {}]".format(amount, preimage)) - - def accept_add_fund(self, event: Event) -> None: - if self.config.getoption("verbose"): - print("[ACCEPT_ADD_FUND]") - - def addhtlc(self, event: Event, conn: Conn, amount: int, preimage: str) -> None: - if self.config.getoption("verbose"): - print( - "[ADDHTLC TO {} for {} with PREIMAGE {}]".format(conn, amount, preimage) - ) - - @staticmethod - def fake_field(ftype: FieldType) -> str: - if isinstance(ftype, DynamicArrayType) or isinstance(ftype, EllipsisArrayType): - # Byte arrays are literal hex strings - if ftype.elemtype.name == "byte": - return "" - return "[]" - elif isinstance(ftype, SizedArrayType): - # Byte arrays are literal hex strings - if ftype.elemtype.name == "byte": - return "00" * ftype.arraysize - return ( - "[" - + ",".join([DummyRunner.fake_field(ftype.elemtype)] * ftype.arraysize) - + "]" - ) - elif ftype.name in ( - "byte", - "u8", - "u16", - "u32", - "u64", - "tu16", - "tu32", - "tu64", - "bigsize", - "varint", - ): - return "0" - elif ftype.name in ("chain_hash", "channel_id", "sha256"): - return "00" * 32 - elif ftype.name == "point": - return "038f1573b4238a986470d250ce87c7a91257b6ba3baf2a0b14380c4e1e532c209d" - elif ftype.name == "short_channel_id": - return "0x0x0" - elif ftype.name == "signature": - return "01" * 64 - else: - raise NotImplementedError( - "don't know how to fake {} type!".format(ftype.name) - ) - - def get_output_message(self, conn: Conn, event: ExpectMsg) -> Optional[bytes]: - if self.config.getoption("verbose"): - print("[GET_OUTPUT_MESSAGE {}]".format(conn)) - - # We make the message they were expecting. - msg = Message(event.msgtype, **event.resolve_args(self, event.kwargs)) - - # Fake up the other fields. - for m in msg.missing_fields(): - ftype = msg.messagetype.find_field(m.name) - msg.set_field(m.name, self.fake_field(ftype.fieldtype)) - - binmsg = io.BytesIO() - msg.write(binmsg) - return binmsg.getvalue() - - def expect_tx(self, event: Event, txid: str) -> None: - if self.config.getoption("verbose"): - print("[EXPECT-TX {}]".format(txid)) - - def check_error(self, event: Event, conn: Conn) -> Optional[str]: - super().check_error(event, conn) - if self.config.getoption("verbose"): - print("[CHECK-ERROR {}]".format(event)) - return "Dummy error" - - def check_final_error( - self, - event: Event, - conn: Conn, - expected: bool, - must_not_events: List[MustNotMsg], - ) -> None: - pass - - def close_channel(self, channel_id: str) -> bool: - if self.config.getoption("verbose"): - print("[CLOSE-CHANNEL {}]".format(channel_id)) - return True - - def is_running(self) -> bool: - return True - - def teardown(self): - pass diff --git a/lnprototest/event.py b/lnprototest/event.py index 250371f..35d5cb3 100644 --- a/lnprototest/event.py +++ b/lnprototest/event.py @@ -1,4 +1,3 @@ -#! /usr/bin/python3 import logging import traceback import collections diff --git a/lnprototest/runner.py b/lnprototest/runner.py index efdd6ec..0ef5a40 100644 --- a/lnprototest/runner.py +++ b/lnprototest/runner.py @@ -1,19 +1,24 @@ -#! /usr/bin/python3 +import io import logging import shutil import tempfile +import pyln import coincurve import functools +from abc import ABC, abstractmethod +from typing import Dict, Optional, List, Union, Any, Callable + +from pyln.proto.message import Message + from .bitfield import bitfield from .errors import SpecFileError from .structure import Sequence from .event import Event, MustNotMsg, ExpectMsg -from .utils import privkey_expand +from .utils import privkey_expand, ResolvableStr, ResolvableInt, resolve_args from .keyset import KeySet -from abc import ABC, abstractmethod -from typing import Dict, Optional, List, Union, Any, Callable +from .namespace import namespace class Conn(object): @@ -32,6 +37,74 @@ def __str__(self) -> str: return self.name +class RunnerConn(Conn): + """ + Default Connection implementation for a runner that use the pyln.proto + to open a connection over a socket. + + Each connection has an internal memory to stash information + and keep connection state. + """ + + def __init__( + self, + connprivkey: str, + counterparty_pubkey: str, + port: int, + host: str = "127.0.0.1", + ): + super().__init__(connprivkey) + # FIXME: pyln.proto.wire should just use coincurve PrivateKey! + self.connection = pyln.proto.wire.connect( + pyln.proto.wire.PrivateKey(bytes.fromhex(self.connprivkey.to_hex())), + # FIXME: Ask node for pubkey + pyln.proto.wire.PublicKey(bytes.fromhex(counterparty_pubkey)), + host, + port, + ) + self.stash: Dict[str, Dict[str, Any]] = {} + self.logger = logging.getLogger(__name__) + + def add_stash(self, stashname: str, vals: Any) -> None: + """Add a dict to the stash.""" + self.stash[stashname] = vals + + def get_stash(self, event: Event, stashname: str, default: Any = None) -> Any: + """Get an entry from the stash.""" + if stashname not in self.stash: + if default is not None: + return default + raise SpecFileError(event, "Unknown stash name {}".format(stashname)) + return self.stash[stashname] + + def recv_msg( + self, timeout: int = 1000, skip_filter: Optional[int] = None + ) -> Message: + """Listen on the connection for incoming message. + + If the {skip_filter} is specified, the message that + match the filters are skipped. + """ + raw_msg = self.connection.read_message() + msg = Message.read(namespace(), io.BytesIO(raw_msg)) + self.add_stash(msg.messagetype.name, msg) + return msg + + def send_msg( + self, msg_name: str, **kwargs: Union[ResolvableStr, ResolvableInt] + ) -> None: + """Send a message through the last connection""" + msgtype = namespace().get_msgtype(msg_name) + msg = Message(msgtype, **resolve_args(self, kwargs)) + missing = msg.missing_fields() + if missing: + raise SpecFileError(self, "Missing fields {}".format(missing)) + binmsg = io.BytesIO() + msg.write(binmsg) + self.connection.send_message(binmsg.getvalue()) + # FIXME: we should listen to possible connection here + + class Runner(ABC): """Abstract base class for runners. @@ -43,8 +116,8 @@ def __init__(self, config: Any): self.config = config self.directory = tempfile.mkdtemp(prefix="lnpt-cl-") # key == connprivkey, value == Conn - self.conns: Dict[str, Conn] = {} - self.last_conn: Optional[Conn] = None + self.conns: Dict[str, RunnerConn] = {} + self.last_conn: Optional[RunnerConn] = None self.stash: Dict[str, Dict[str, Any]] = {} self.logger = logging.getLogger(__name__) if self.config.getoption("verbose"): @@ -56,7 +129,7 @@ def _is_dummy(self) -> bool: """The DummyRunner returns True here, as it can't do some things""" return False - def find_conn(self, connprivkey: Optional[str]) -> Optional[Conn]: + def find_conn(self, connprivkey: Optional[str]) -> Optional[RunnerConn]: # Default is whatever we specified last. if connprivkey is None: return self.last_conn @@ -65,17 +138,17 @@ def find_conn(self, connprivkey: Optional[str]) -> Optional[Conn]: return self.last_conn return None - def add_conn(self, conn: Conn) -> None: + def add_conn(self, conn: RunnerConn) -> None: self.conns[conn.name] = conn self.last_conn = conn - def disconnect(self, event: Event, conn: Conn) -> None: + def disconnect(self, event: Event, conn: RunnerConn) -> None: if conn is None: raise SpecFileError(event, "Unknown conn") del self.conns[conn.name] self.check_final_error(event, conn, conn.expected_error, conn.must_not_events) - def check_error(self, event: Event, conn: Conn) -> Optional[str]: + def check_error(self, event: Event, conn: RunnerConn) -> Optional[str]: conn.expected_error = True return None @@ -142,9 +215,31 @@ def is_running(self) -> bool: pass @abstractmethod - def connect(self, event: Event, connprivkey: str) -> None: + def connect(self, event: Event, connprivkey: str) -> RunnerConn: pass + def send_msg(self, msg: Message) -> None: + """Send a message through the last connection""" + missing = msg.missing_fields() + if missing: + raise SpecFileError(self, "Missing fields {}".format(missing)) + binmsg = io.BytesIO() + msg.write(binmsg) + self.last_conn.connection.send_message(msg.getvalue()) + + def recv_msg( + self, timeout: int = 1000, skip_filter: Optional[int] = None + ) -> Message: + """Listen on the connection for incoming message. + + If the {skip_filter} is specified, the message that + match the filters are skipped. + """ + raw_msg = self.last_conn.connection.read_message() + msg = Message.read(namespace(), io.BytesIO(raw_msg)) + self.add_stash(msg.messagetype.name, msg) + return msg + @abstractmethod def check_final_error( self, @@ -171,11 +266,11 @@ def stop(self, print_logs: bool = False) -> None: pass @abstractmethod - def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None: + def recv(self, event: Event, conn: RunnerConn, outbuf: bytes) -> None: pass @abstractmethod - def get_output_message(self, conn: Conn, event: ExpectMsg) -> Optional[bytes]: + def get_output_message(self, conn: RunnerConn, event: ExpectMsg) -> Optional[bytes]: pass @abstractmethod @@ -206,7 +301,7 @@ def accept_add_fund(self, event: Event) -> None: def fundchannel( self, event: Event, - conn: Conn, + conn: RunnerConn, amount: int, feerate: int = 0, expect_fail: bool = False, @@ -217,7 +312,7 @@ def fundchannel( def init_rbf( self, event: Event, - conn: Conn, + conn: RunnerConn, channel_id: str, amount: int, utxo_txid: str, @@ -227,7 +322,9 @@ def init_rbf( pass @abstractmethod - def addhtlc(self, event: Event, conn: Conn, amount: int, preimage: str) -> None: + def addhtlc( + self, event: Event, conn: RunnerConn, amount: int, preimage: str + ) -> None: pass @abstractmethod diff --git a/lnprototest/utils/__init__.py b/lnprototest/utils/__init__.py index ea3025c..790d473 100644 --- a/lnprototest/utils/__init__.py +++ b/lnprototest/utils/__init__.py @@ -14,6 +14,12 @@ check_hex, privkey_for_index, merge_events_sequences, + Resolvable, + ResolvableBool, + ResolvableInt, + ResolvableStr, + resolve_arg, + resolve_args, ) from .bitcoin_utils import ( ScriptType, diff --git a/lnprototest/utils/utils.py b/lnprototest/utils/utils.py index 4b28bd3..3d1fc16 100644 --- a/lnprototest/utils/utils.py +++ b/lnprototest/utils/utils.py @@ -8,11 +8,17 @@ import logging import traceback -from typing import Union, Sequence, List +from typing import Union, Sequence, List, Dict, Callable, Any from enum import IntEnum from lnprototest.keyset import KeySet +# Type for arguments: either strings, or functions to call at runtime +ResolvableStr = Union[str, Callable[["RunnerConn", "Event", str], str]] +ResolvableInt = Union[int, Callable[["RunnerConn", "Event", str], int]] +ResolvableBool = Union[int, Callable[["RunnerConn", "Event", str], bool]] +Resolvable = Union[Any, Callable[["RunnerConn", "Event", str], Any]] + class Side(IntEnum): local = 0 @@ -106,3 +112,19 @@ def merge_events_sequences( """Merge the two list in the pre-post order""" pre.extend(post) return pre + + +def resolve_arg(fieldname: str, conn: "RunnerConn", arg: Resolvable) -> Any: + """If this is a string, return it, otherwise call it to get result""" + if callable(arg): + return arg(conn, fieldname) + else: + return arg + + +def resolve_args(conn: "RunnerConn", kwargs: Dict[str, Resolvable]) -> Dict[str, Any]: + """Take a dict of args, replace callables with their return values""" + ret: Dict[str, str] = {} + for field, str_or_func in kwargs.items(): + ret[field] = resolve_arg(field, conn, str_or_func) + return ret diff --git a/tests/conftest.py b/tests/conftest.py index e3e737e..02fe20d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,12 @@ -#! /usr/bin/python3 import pytest import importlib + import lnprototest import pyln.spec.bolt1 import pyln.spec.bolt2 import pyln.spec.bolt7 from pyln.proto.message import MessageNamespace + from typing import Any, Callable, Generator, List @@ -14,7 +15,6 @@ def pytest_addoption(parser: Any) -> None: "--runner", action="store", help="runner to use", - default="lnprototest.DummyRunner", ) parser.addoption( "--runner-args", @@ -26,6 +26,11 @@ def pytest_addoption(parser: Any) -> None: @pytest.fixture() # type: ignore def runner(pytestconfig: Any) -> Any: + runner_opt = pytestconfig.getoption("runner") + if runner_opt is None: + pytest.skip( + "Runner need to be specified eg. `make check PYTEST_ARGS='--runner=lnprototest.clightning.Runner'`" + ) parts = pytestconfig.getoption("runner").rpartition(".") runner = importlib.import_module(parts[0]).__dict__[parts[2]](pytestconfig) yield runner diff --git a/tests/test_v2_bolt1-01-init.py b/tests/test_v2_bolt1-01-init.py new file mode 100644 index 0000000..42646b1 --- /dev/null +++ b/tests/test_v2_bolt1-01-init.py @@ -0,0 +1,20 @@ +from typing import Any + +from lnprototest.runner import Runner + + +def test_v2_init_is_first_msg(runner: Runner, namespaceoverride: Any) -> None: + """Tests workflow + + runner --- connect ---> ln node + runner <-- init ------ ln node + """ + runner.start() + + conn1 = runner.connect(None, connprivkey="03") + init_msg = conn1.recv_msg() + assert ( + init_msg.messagetype.number == 16 + ), f"received not an init msg but: {init_msg.to_str()}" + conn1.send_msg("init", globalfeatures="", features="") + runner.stop()