diff --git a/rnsh/rnsh.py b/rnsh/rnsh.py index b1e1904..9b1c53e 100644 --- a/rnsh/rnsh.py +++ b/rnsh/rnsh.py @@ -391,7 +391,7 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness timeout=timeout, ) - if not _link or _link.status != RNS.Link.ACTIVE: + if not _link or _link.status not in [RNS.Link.ACTIVE, RNS.Link.PENDING]: _finished.set() return 255 diff --git a/rnsh/session.py b/rnsh/session.py index 2b71b82..e551b50 100644 --- a/rnsh/session.py +++ b/rnsh/session.py @@ -101,8 +101,12 @@ def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop): self.return_code: int | None = None self.return_code_sent = False self.process: process.CallbackSubprocess | None = None - self._set_state(LSState.LSSTATE_WAIT_IDENT) + if self.allow_all: + self._set_state(LSState.LSSTATE_WAIT_VERS) + else: + self._set_state(LSState.LSSTATE_WAIT_IDENT) self.sessions.append(self) + self.outlet.set_packet_received_callback(self._packet_received) def _terminated(self, return_code: int): self.return_code = return_code @@ -176,14 +180,13 @@ def _initiator_identified(self, outlet, identity): return self._log.info(f"initiator_identified {identity} on link {outlet}") - if self.state != LSState.LSSTATE_WAIT_IDENT: + if self.state not in [LSState.LSSTATE_WAIT_IDENT, LSState.LSSTATE_WAIT_VERS]: self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name) if not self.allow_all and identity.hash not in self.allowed_identity_hashes: self.terminate("Identity is not allowed.") self.remote_identity = identity - self.outlet.set_packet_received_callback(self._packet_received) self._set_state(LSState.LSSTATE_WAIT_VERS) @classmethod @@ -277,7 +280,8 @@ def stderr(data: bytes): try: self.process = process.CallbackSubprocess(argv=self.cmdline, env={"TERM": self.term or os.environ.get("TERM", None), - "RNS_REMOTE_IDENTITY": RNS.prettyhexrep(self.remote_identity.hash) or ""}, + "RNS_REMOTE_IDENTITY": (RNS.prettyhexrep(self.remote_identity.hash) + if self.remote_identity and self.remote_identity.hash else "")}, loop=self.loop, stdout_callback=stdout, stderr_callback=stderr, @@ -306,6 +310,9 @@ def _received_stdin(self, data: bytes, eof: bool): self.process.close_stdin() def _handle_message(self, message: protocol.Message): + if self.state == LSState.LSSTATE_WAIT_IDENT: + self._protocol_error("Identification required") + return if self.state == LSState.LSSTATE_WAIT_VERS: if not isinstance(message, protocol.VersionInfoMessage): self._protocol_error(self.state.name) diff --git a/tests/helpers.py b/tests/helpers.py index 53604b8..3d32045 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -84,6 +84,7 @@ def cleanup(self): self._log.debug(f"cleanup()") if self.process and self.process.running: self.process.terminate(kill_delay=0.1) + time.sleep(0.5) def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, __traceback: types.TracebackType) -> bool: diff --git a/tests/test_args.py b/tests/test_args.py index c3fcd72..f9c7cc4 100644 --- a/tests/test_args.py +++ b/tests/test_args.py @@ -33,11 +33,26 @@ def test_program_initiate_no_args(): args = rnsh.args.Args(shlex.split("rnsh one")) assert not args.listen assert args.destination == "one" + assert not args.no_id assert args.command_line == [] except docopt.DocoptExit: docopt_threw = True assert not docopt_threw + +def test_program_initiate_no_auth(): + docopt_threw = False + try: + args = rnsh.args.Args(shlex.split("rnsh -N one")) + assert not args.listen + assert args.destination == "one" + assert args.no_id + assert args.command_line == [] + except docopt.DocoptExit: + docopt_threw = True + assert not docopt_threw + + def test_program_initiate_dash_args(): docopt_threw = False try: diff --git a/tests/test_rnsh.py b/tests/test_rnsh.py index 17a3f94..9057d6f 100644 --- a/tests/test_rnsh.py +++ b/tests/test_rnsh.py @@ -123,14 +123,17 @@ async def do_connected_test(listener_args: str, initiator_args: str, test: calla assert len(ih) == 32 assert len(dh) == 32 assert len(iih) == 32 + assert "dh" in initiator_args + initiator_args = initiator_args.replace("dh", dh) + listener_args = listener_args.replace("iih", iih) with tests.helpers.SubprocessReader(name="listener", argv=shlex.split(f"poetry run -- rnsh -l -c \"{td}\" {listener_args}")) as listener, \ - tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -c \"{td}\" {dh} {initiator_args}")) as initiator: + tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -c \"{td}\" {initiator_args}")) as initiator: # listener startup listener.start() await asyncio.sleep(0.1) assert listener.process.running # wait for process to start up - await asyncio.sleep(5) + await asyncio.sleep(2) # read the output text = listener.read().decode("utf-8") assert text.index(dh) is not None @@ -166,7 +169,55 @@ async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.Subp text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") assert text[len(text)-len(cwd):] == cwd - await do_connected_test("-n -C -- /bin/pwd", "", test) + await do_connected_test("-n -C -- /bin/pwd", "dh", test) + + +@pytest.mark.skip_ci +@pytest.mark.asyncio +async def test_rnsh_no_ident(): + cwd = os.getcwd() + + async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader, + initiator: tests.helpers.SubprocessReader): + start_time = time.time() + while initiator.return_code is None and time.time() - start_time < 3: + await asyncio.sleep(0.1) + text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") + assert text[len(text)-len(cwd):] == cwd + + await do_connected_test("-n -C -- /bin/pwd", "-N dh", test) + + +@pytest.mark.skip_ci +@pytest.mark.asyncio +async def test_rnsh_invalid_ident(): + cwd = os.getcwd() + + async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader, + initiator: tests.helpers.SubprocessReader): + start_time = time.time() + while initiator.return_code is None and time.time() - start_time < 3: + await asyncio.sleep(0.1) + text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") + assert "not allowed" in text + + await do_connected_test("-a 12345678901234567890123456789012 -C -- /bin/pwd", "dh", test) + + +@pytest.mark.skip_ci +@pytest.mark.asyncio +async def test_rnsh_valid_ident(): + cwd = os.getcwd() + + async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader, + initiator: tests.helpers.SubprocessReader): + start_time = time.time() + while initiator.return_code is None and time.time() - start_time < 3: + await asyncio.sleep(0.1) + text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") + assert text[len(text)-len(cwd):] == cwd + + await do_connected_test("-a iih -C -- /bin/pwd", "dh", test)