diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index a0a2dc98556d..4bde009742f8 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -136,6 +136,7 @@ def new( framework_str = framework_str.lower() + llm_challenge_str = None if framework_str == "flowertune": llm_challenge_value = prompt_options( "Please select LLM challenge by typing in the number", @@ -171,7 +172,7 @@ def new( } # List of files to render - if framework_str == "flowertune": + if llm_challenge_str: files = { ".gitignore": {"template": "app/.gitignore.tpl"}, "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"}, @@ -228,10 +229,10 @@ def new( "README.md": {"template": "app/README.md.tpl"}, "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"}, f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"}, - f"{import_name}/server.py": { + f"{import_name}/server_app.py": { "template": f"app/code/server.{framework_str}.py.tpl" }, - f"{import_name}/client.py": { + f"{import_name}/client_app.py": { "template": f"app/code/client.{framework_str}.py.tpl" }, } diff --git a/src/py/flwr/cli/new/new_test.py b/src/py/flwr/cli/new/new_test.py index 33ad745efa93..7f22bd5f9825 100644 --- a/src/py/flwr/cli/new/new_test.py +++ b/src/py/flwr/cli/new/new_test.py @@ -86,8 +86,8 @@ def test_new_correct_name(tmp_path: str) -> None: } expected_files_module = { "__init__.py", - "server.py", - "client.py", + "server_app.py", + "client_app.py", "task.py", } diff --git a/src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl b/src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl index ecb87bd71e3f..a0f781df04a1 100644 --- a/src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl @@ -12,10 +12,10 @@ from flwr.client import ClientApp from flwr.common import ndarrays_to_parameters from flwr.server import ServerApp, ServerConfig -from $import_name.client import gen_client_fn, get_parameters +from $import_name.client_app import gen_client_fn, get_parameters from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting from $import_name.models import get_model -from $import_name.server import fit_weighted_average, get_evaluate_fn, get_on_fit_config +from $import_name.server_app import fit_weighted_average, get_evaluate_fn, get_on_fit_config # Avoid warnings warnings.filterwarnings("ignore", category=UserWarning) diff --git a/src/py/flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl b/src/py/flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl index 19223148bca5..5dd4d881f2f1 100644 --- a/src/py/flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl @@ -1,6 +1,6 @@ """$project_name: A Flower / FlowerTune app.""" -from $import_name.client import set_parameters +from $import_name.client_app import set_parameters from $import_name.models import get_model diff --git a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl index b39facbec5a0..92c954e754cf 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl @@ -24,8 +24,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl index 405decf38f16..e899f48f4c5c 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl @@ -21,8 +21,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl index a2b743800595..6004c076cf87 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl @@ -21,8 +21,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl index ad074b90d24a..543936ed4a89 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl @@ -19,8 +19,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl index ecd1497500ab..8a92cf0eca9a 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl @@ -21,8 +21,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl index 4bc407c34262..5c1ffa09aed2 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl @@ -20,8 +20,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl index 9dab874e50ff..de1a445e33f9 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl @@ -20,8 +20,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3"