From 493e6c61a7df50995bf9e69cf7f742b890e8035d Mon Sep 17 00:00:00 2001 From: Niko Aarnio Date: Tue, 21 May 2024 08:41:17 +0300 Subject: [PATCH] Fix test metrics prints for classifier and regressor test CLI functions --- eis_toolkit/cli.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/eis_toolkit/cli.py b/eis_toolkit/cli.py index 76064307..9e5aa3f3 100644 --- a/eis_toolkit/cli.py +++ b/eis_toolkit/cli.py @@ -2341,8 +2341,7 @@ def classifier_test_cli( predictions, reference_profile["height"], reference_profile["width"], nodata_mask ) - metrics_dict = score_predictions(y, predictions, test_metrics) - # json_str = json.dumps(metrics_dict) + metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics)) typer.echo("Progress: 80%") out_profile = reference_profile.copy() @@ -2353,14 +2352,12 @@ def classifier_test_cli( with rasterio.open(output_raster_classified, "w", **out_profile) as dst: dst.write(predictions_reshaped, 1) - typer.echo("Progress: 100%\n") - # typer.echo(f"Results:") - + typer.echo("\n") for key, value in metrics_dict.items(): typer.echo(f"{key}: {value}") - typer.echo("\n") + typer.echo("Progress: 100%") typer.echo( ( "Testing classifier model completed, writing rasters to " @@ -2388,13 +2385,12 @@ def regressor_test_cli( model = load_model(model_file) predictions = predict_regressor(X, model) - metrics_dict = score_predictions(y, predictions, test_metrics) predictions_reshaped = reshape_predictions( predictions, reference_profile["height"], reference_profile["width"], nodata_mask ) - typer.echo("Progress: 80%") - # json_str = json.dumps(metrics_dict) + metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics)) + typer.echo("Progress: 80%") out_profile = reference_profile.copy() out_profile.update({"count": 1, "dtype": np.float32}) @@ -2402,14 +2398,13 @@ def regressor_test_cli( with rasterio.open(output_raster, "w", **out_profile) as dst: dst.write(predictions_reshaped, 1) - typer.echo("Progress: 100%\n") - # typer.echo("Results: ") - + typer.echo("\n") for key, value in metrics_dict.items(): typer.echo(f"{key}: {value}") - typer.echo("\n") + typer.echo("Progress: 100%\n") + typer.echo(f"Testing regressor model completed, writing raster to {output_raster}.")