diff --git a/resources/scenarios/ln_init.py b/resources/scenarios/ln_init.py index 9f24e5040..c2f3d3474 100644 --- a/resources/scenarios/ln_init.py +++ b/resources/scenarios/ln_init.py @@ -183,3 +183,6 @@ def funded_lnnodes(): def main(): LNInit().main() + +if __name__ == "__main__": + main() diff --git a/resources/scenarios/miner_std.py b/resources/scenarios/miner_std.py index d91736d82..ac507171b 100755 --- a/resources/scenarios/miner_std.py +++ b/resources/scenarios/miner_std.py @@ -70,3 +70,6 @@ def run_test(self): def main(): MinerStd().main() + +if __name__ == "__main__": + main() diff --git a/resources/scenarios/reconnaissance.py b/resources/scenarios/reconnaissance.py index 1c539f5f7..453d6fc4a 100755 --- a/resources/scenarios/reconnaissance.py +++ b/resources/scenarios/reconnaissance.py @@ -82,3 +82,6 @@ def run_test(self): def main(): Reconnaissance().main() + +if __name__ == "__main__": + main() diff --git a/resources/scenarios/signet_miner.py b/resources/scenarios/signet_miner.py index 9a20ecc97..e4375515b 100644 --- a/resources/scenarios/signet_miner.py +++ b/resources/scenarios/signet_miner.py @@ -564,3 +564,6 @@ def get_args(parser): def main(): SignetMinerScenario().main() + +if __name__ == "__main__": + main() diff --git a/resources/scenarios/test_scenarios/__init__.py b/resources/scenarios/test_scenarios/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/resources/scenarios/testscenario_buggy_failure.py b/resources/scenarios/test_scenarios/buggy_failure.py similarity index 93% rename from resources/scenarios/testscenario_buggy_failure.py rename to resources/scenarios/test_scenarios/buggy_failure.py index 700f78ea1..249001bd4 100644 --- a/resources/scenarios/testscenario_buggy_failure.py +++ b/resources/scenarios/test_scenarios/buggy_failure.py @@ -22,3 +22,6 @@ def run_test(self): def main(): Failure().main() + +if __name__ == "__main__": + main() diff --git a/resources/scenarios/testscenario_connect_dag.py b/resources/scenarios/test_scenarios/connect_dag.py similarity index 99% rename from resources/scenarios/testscenario_connect_dag.py rename to resources/scenarios/test_scenarios/connect_dag.py index 4019e8944..2df50bc1b 100644 --- a/resources/scenarios/testscenario_connect_dag.py +++ b/resources/scenarios/test_scenarios/connect_dag.py @@ -119,3 +119,6 @@ def assert_connection(self, connector, connectee_index, connection_type: Connect def main(): ConnectDag().main() + +if __name__ == "__main__": + main() diff --git a/resources/scenarios/testscenario_p2p_interface.py b/resources/scenarios/test_scenarios/p2p_interface.py similarity index 97% rename from resources/scenarios/testscenario_p2p_interface.py rename to resources/scenarios/test_scenarios/p2p_interface.py index 47eee9006..4f88f49cc 100644 --- a/resources/scenarios/testscenario_p2p_interface.py +++ b/resources/scenarios/test_scenarios/p2p_interface.py @@ -54,3 +54,6 @@ def run_test(self): def main(): GetdataTest().main() + +if __name__ == "__main__": + main() diff --git a/resources/scenarios/tx_flood.py b/resources/scenarios/tx_flood.py index 1197fc8a6..7a60bccc5 100755 --- a/resources/scenarios/tx_flood.py +++ b/resources/scenarios/tx_flood.py @@ -69,3 +69,7 @@ def run_test(self): def main(): TXFlood().main() + + +if __name__ == "__main__": + main() diff --git a/src/warnet/control.py b/src/warnet/control.py index 41c12d6c1..12dc19cf7 100644 --- a/src/warnet/control.py +++ b/src/warnet/control.py @@ -163,14 +163,15 @@ def get_active_network(namespace): @click.command(context_settings={"ignore_unknown_options": True}) @click.argument("scenario_file", type=click.Path(exists=True, file_okay=True, dir_okay=False)) +@click.option("--source_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True), required=False) @click.argument("additional_args", nargs=-1, type=click.UNPROCESSED) -def run(scenario_file: str, additional_args: tuple[str]): +def run(scenario_file: str, source_dir, additional_args: tuple[str]): """ Run a scenario from a file. Pass `-- --help` to get individual scenario help """ scenario_path = Path(scenario_file).resolve() - scenario_dir = scenario_path.parent + scenario_dir = scenario_path.parent if not source_dir else Path(source_dir).resolve() scenario_name = scenario_path.stem if additional_args and ("--help" in additional_args or "-h" in additional_args): @@ -203,15 +204,24 @@ def run(scenario_file: str, additional_args: tuple[str]): def filter(path): if any(needle in str(path) for needle in [".pyc", ".csv", ".DS_Store"]): return False - return any( - needle in str(path) for needle in ["commander.py", "test_framework", scenario_name] - ) - + if any(needle in str(path) for needle in ["__init__.py", "commander.py", "test_framework", scenario_path.name]): + print(f"Including: {path}") + return True + return False + + # In case the scenario file is not in the root of the archive directory, + # we need to specify its relative path as a submodule + # First get the path of the file relative to the source directory + relative_path = scenario_path.relative_to(scenario_dir) + # Remove the '.py' extension + relative_name = relative_path.with_suffix("") + # Replace path separators with dots and pray the user included __init__.py + module_name = ".".join(relative_name.parts) # Compile python archive zipapp.create_archive( source=scenario_dir, target=archive_buffer, - main=f"{scenario_name}:main", + main=f"{module_name}:main", compressed=True, filter=filter, ) diff --git a/src/warnet/network.py b/src/warnet/network.py index f78eda42b..401ab5106 100644 --- a/src/warnet/network.py +++ b/src/warnet/network.py @@ -44,7 +44,7 @@ def copy_scenario_defaults(directory: Path): directory, SCENARIOS_DIR.name, SCENARIOS_DIR, - ["__pycache__", "testscenario_*.py"], + ["__pycache__", "test_scenarios"], ) diff --git a/test/dag_connection_test.py b/test/dag_connection_test.py index 4d8d953eb..dee38356a 100755 --- a/test/dag_connection_test.py +++ b/test/dag_connection_test.py @@ -26,9 +26,9 @@ def setup_network(self): self.wait_for_all_edges() def run_connect_dag_scenario(self): - scenario_file = self.scen_dir / "testscenario_connect_dag.py" + scenario_file = self.scen_dir / "test_scenarios" / "connect_dag.py" self.log.info(f"Running scenario from: {scenario_file}") - self.warnet(f"run {scenario_file}") + self.warnet(f"run {scenario_file} --source_dir={self.scen_dir}") self.wait_for_all_scenarios() diff --git a/test/scenarios_test.py b/test/scenarios_test.py index 835c273c5..0b8ba7a4a 100755 --- a/test/scenarios_test.py +++ b/test/scenarios_test.py @@ -83,9 +83,9 @@ def run_and_check_miner_scenario_from_file(self): self.stop_scenario() def run_and_check_scenario_from_file(self): - scenario_file = self.scen_dir / "testscenario_p2p_interface.py" + scenario_file = self.scen_dir / "test_scenarios" / "p2p_interface.py" self.log.info(f"Running scenario from: {scenario_file}") - self.warnet(f"run {scenario_file}") + self.warnet(f"run {scenario_file} --source_dir={self.scen_dir}") self.wait_for_predicate(self.check_scenario_clean_exit) def check_regtest_recon(self): @@ -95,9 +95,9 @@ def check_regtest_recon(self): self.wait_for_predicate(self.check_scenario_clean_exit) def check_active_count(self): - scenario_file = self.scen_dir / "testscenario_buggy_failure.py" + scenario_file = self.scen_dir / "test_scenarios" / "buggy_failure.py" self.log.info(f"Running scenario from: {scenario_file}") - self.warnet(f"run {scenario_file}") + self.warnet(f"run {scenario_file} --source_dir={self.scen_dir}") def two_pass_one_fail(): deployed = scenarios_deployed()