Skip to content

Commit

Permalink
Fix test metrics prints for classifier and regressor test CLI functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nmaarnio committed May 21, 2024
1 parent f1ed697 commit 493e6c6
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions eis_toolkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2341,8 +2341,7 @@ def classifier_test_cli(
predictions, reference_profile["height"], reference_profile["width"], nodata_mask
)

metrics_dict = score_predictions(y, predictions, test_metrics)
# json_str = json.dumps(metrics_dict)
metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics))
typer.echo("Progress: 80%")

out_profile = reference_profile.copy()
Expand All @@ -2353,14 +2352,12 @@ def classifier_test_cli(
with rasterio.open(output_raster_classified, "w", **out_profile) as dst:
dst.write(predictions_reshaped, 1)

typer.echo("Progress: 100%\n")
# typer.echo(f"Results:")

typer.echo("\n")
for key, value in metrics_dict.items():
typer.echo(f"{key}: {value}")

typer.echo("\n")

typer.echo("Progress: 100%")
typer.echo(
(
"Testing classifier model completed, writing rasters to "
Expand Down Expand Up @@ -2388,28 +2385,26 @@ def regressor_test_cli(

model = load_model(model_file)
predictions = predict_regressor(X, model)
metrics_dict = score_predictions(y, predictions, test_metrics)
predictions_reshaped = reshape_predictions(
predictions, reference_profile["height"], reference_profile["width"], nodata_mask
)
typer.echo("Progress: 80%")

# json_str = json.dumps(metrics_dict)
metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics))
typer.echo("Progress: 80%")

out_profile = reference_profile.copy()
out_profile.update({"count": 1, "dtype": np.float32})

with rasterio.open(output_raster, "w", **out_profile) as dst:
dst.write(predictions_reshaped, 1)

typer.echo("Progress: 100%\n")
# typer.echo("Results: ")

typer.echo("\n")
for key, value in metrics_dict.items():
typer.echo(f"{key}: {value}")

typer.echo("\n")

typer.echo("Progress: 100%\n")

typer.echo(f"Testing regressor model completed, writing raster to {output_raster}.")


Expand Down

0 comments on commit 493e6c6

Please sign in to comment.