Skip to content

Commit

Permalink
Surface the basic configuration with the user before proceeding with …
Browse files Browse the repository at this point in the history
…validation checking

Issue drivendataorg#196
  • Loading branch information
sambujangfofana authored Apr 26, 2024
1 parent 70e79f2 commit 497f3b1
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions zamba/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def train(
If an argument is specified in both the command line and in a yaml file, the command line input will take precedence.
"""

if config is not None:
with config.open() as f:
config_dict = yaml.safe_load(f)
Expand Down Expand Up @@ -128,6 +129,27 @@ def train(
if skip_load_validation is not None:
train_dict["skip_load_validation"] = skip_load_validation

# surface the configuration before validation checking
msg = f"""Review the following configuration before proceeding with validation checking:
Config file: {config_file}
Data directory: {data_dir if data_dir is not None else config_dict["train_config"].get("data_dir")}
Labels csv: {labels if labels is not None else config_dict["train_config"].get("labels")}
Checkpoint: {checkpoint if checkpoint is not None else config_dict["train_config"].get("checkpoint")}
"""

if yes:
typer.echo(f"{msg}\n\nSkipping confirmation and proceeding to validation checking.")
else:
yes = typer.confirm(
f"{msg}\n\nIs this correct?",
abort=False,
default=True,
)
if not yes:
print("\n\nPlease review and adjust the configuration and run the command again.")
return

try:
manager = ModelManager(
ModelConfig(
Expand Down Expand Up @@ -321,6 +343,27 @@ def predict(
if overwrite is not None:
predict_dict["overwrite"] = overwrite

# surface the configuration before validation checking
msg = f"""Review the following configuration before proceeding with validation checking:
Config file: {config_file}
Data directory: {data_dir if data_dir is not None else config_dict["predict_config"].get("data_dir")}
Filepath csv: {filepaths if filepaths is not None else config_dict["predict_config"].get("filepaths")}
Checkpoint: {checkpoint if checkpoint is not None else config_dict["predict_config"].get("checkpoint")}
"""

if yes:
typer.echo(f"{msg}\n\nSkipping confirmation and proceeding to validation checking.")
else:
yes = typer.confirm(
f"{msg}\n\nIs this correct?",
abort=False,
default=True,
)
if not yes:
print("\n\nPlease review and adjust the configuration and run the command again.")
return

try:
manager = ModelManager(
ModelConfig(
Expand Down

0 comments on commit 497f3b1

Please sign in to comment.