From 8b2f1498a559c469fdd4b964c26448bc0286171a Mon Sep 17 00:00:00 2001
From: sambujangfofana <112506351+sambujangfofana@users.noreply.github.com>
Date: Fri, 26 Apr 2024 15:39:38 -0400
Subject: [PATCH 1/2] Update cli.py

---
Review notes (text between the three-dash line and the diff is not part of
the commit message or the patch):

* Purpose: in both `train` and `predict`, echo the resolved config file,
  data directory, labels/filepaths csv and checkpoint, and ask for
  confirmation before the ModelManager/ModelConfig construction that follows.
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. Line counts match every @@ header and the diffstat,
  but the exact indentation is inferred; run `git apply --check` (or use
  `--ignore-whitespace`) before applying.
* NOTE(review): `yes` is rebound to the result of typer.confirm(). Any later
  `if yes:` in these commands (not visible in the hunks) will then see True
  and may skip a second prompt; confirm that is intended.
* NOTE(review): `config_file` and `config_dict[...]` are read here but are not
  defined in the visible hunks; presumably both are set earlier in each
  command. Confirm before merging.
* The decline path reports through typer.echo (previously print) so that both
  branches of the prompt use the same output helper.

 zamba/cli.py | 43 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/zamba/cli.py b/zamba/cli.py
index f8cf1ae9..c7528e92 100644
--- a/zamba/cli.py
+++ b/zamba/cli.py
@@ -78,6 +78,7 @@ def train(
     If an argument is specified in both the command line and in a yaml file, the command line input
     will take precedence.
     """
+
     if config is not None:
         with config.open() as f:
             config_dict = yaml.safe_load(f)
@@ -128,6 +129,27 @@ def train(
     if skip_load_validation is not None:
         train_dict["skip_load_validation"] = skip_load_validation
 
+    # surface the configuration before validation checking
+    msg = f"""Review the following configuration before proceeding with validation checking:
+
+    Config file: {config_file}
+    Data directory: {data_dir if data_dir is not None else config_dict["train_config"].get("data_dir")}
+    Labels csv: {labels if labels is not None else config_dict["train_config"].get("labels")}
+    Checkpoint: {checkpoint if checkpoint is not None else config_dict["train_config"].get("checkpoint")}
+    """
+
+    if yes:
+        typer.echo(f"{msg}\n\nSkipping confirmation and proceeding to validation checking.")
+    else:
+        yes = typer.confirm(
+            f"{msg}\n\nIs this correct?",
+            abort=False,
+            default=True,
+        )
+        if not yes:
+            typer.echo("\n\nPlease review and adjust the configuration and run the command again.")
+            return
+
     try:
         manager = ModelManager(
             ModelConfig(
@@ -321,6 +343,27 @@ def predict(
     if overwrite is not None:
         predict_dict["overwrite"] = overwrite
 
+    # surface the configuration before validation checking
+    msg = f"""Review the following configuration before proceeding with validation checking:
+
+    Config file: {config_file}
+    Data directory: {data_dir if data_dir is not None else config_dict["predict_config"].get("data_dir")}
+    Filepath csv: {filepaths if filepaths is not None else config_dict["predict_config"].get("filepaths")}
+    Checkpoint: {checkpoint if checkpoint is not None else config_dict["predict_config"].get("checkpoint")}
+    """
+
+    if yes:
+        typer.echo(f"{msg}\n\nSkipping confirmation and proceeding to validation checking.")
+    else:
+        yes = typer.confirm(
+            f"{msg}\n\nIs this correct?",
+            abort=False,
+            default=True,
+        )
+        if not yes:
+            typer.echo("\n\nPlease review and adjust the configuration and run the command again.")
+            return
+
     try:
         manager = ModelManager(
             ModelConfig(

From 5cd74c20217a84704e28075249f4dde942bd8a4b Mon Sep 17 00:00:00 2001
From: sambujangfofana <112506351+sambujangfofana@users.noreply.github.com>
Date: Fri, 26 Apr 2024 15:52:44 -0400
Subject: [PATCH 2/2] Update config.py

---
Review notes (text between the three-dash line and the diff is not part of
the commit message or the patch):

* Purpose: only warn that the labels' split column takes precedence when
  `split_proportions` differs from {"train": 3, "val": 1, "holdout": 1}.
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. Line counts match the @@ header and the diffstat, but
  the exact indentation is inferred; run `git apply --check` (or use
  `--ignore-whitespace`) before applying.
* NOTE(review): the literal {"train": 3, "val": 1, "holdout": 1} presumably
  mirrors the default of the `split_proportions` field. Confirm, and consider
  reading the default from a single shared definition so the two cannot drift.
* NOTE(review): the comment and the `values["split_proportions"] = None` reset
  appear as a removed/added pair, i.e. they were re-indented under the new
  `if`. The reset therefore no longer runs when the proportions equal the
  literal above, so that value stays in the validated config even though the
  split column takes precedence. Confirm this is intended.

 zamba/models/config.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/zamba/models/config.py b/zamba/models/config.py
index bd826381..a4e1d90e 100644
--- a/zamba/models/config.py
+++ b/zamba/models/config.py
@@ -524,11 +524,13 @@ def validate_filepaths_and_labels(cls, values):
                 )
 
             elif values["split_proportions"] is not None:
-                logger.warning(
-                    "Labels contains split column yet split_proportions are also provided. Split column in labels takes precedence."
-                )
-                # set to None for clarity in final configuration.yaml
-                values["split_proportions"] = None
+                # Check to see if split_proportions contains the default values
+                if values.get("split_proportions") != {"train": 3, "val": 1, "holdout": 1}:
+                    logger.warning(
+                        "Labels contains split column yet split_proportions are also provided. Split column in labels takes precedence."
+                    )
+                    # set to None for clarity in final configuration.yaml
+                    values["split_proportions"] = None
 
         # error if labels are entirely null
         null_labels = labels.label.isnull()