diff --git a/eis_toolkit/cli.py b/eis_toolkit/cli.py index e78ccf7a..064306ca 100644 --- a/eis_toolkit/cli.py +++ b/eis_toolkit/cli.py @@ -2318,20 +2318,23 @@ def evaluate_trained_model_cli( output_raster: OUTPUT_FILE_OPTION, validation_metrics: Annotated[List[str], typer.Option()], ): - """Train and optionally validate a Gradient boosting regressor model using Sklearn.""" - from eis_toolkit.prediction.machine_learning_general import ( - evaluate_model, - load_model, - prepare_data_for_ml, - reshape_predictions, - ) + """Evaluate a trained machine learning model by predicting and scoring.""" + from sklearn.base import is_classifier - X, y, reference_profile, nodata_mask = prepare_data_for_ml(input_rasters, target_labels) + from eis_toolkit.evaluation.scoring import score_predictions + from eis_toolkit.prediction.machine_learning_general import load_model, prepare_data_for_ml, reshape_predictions + from eis_toolkit.prediction.machine_learning_predict import predict_classifier, predict_regressor + X, y, reference_profile, nodata_mask = prepare_data_for_ml(input_rasters, target_labels) + print(len(np.unique(y))) typer.echo("Progress: 30%") model = load_model(model_file) - predictions, metrics_dict = evaluate_model(X, y, model, validation_metrics) + if is_classifier(model): + predictions, probabilities = predict_classifier(X, model, True) + else: + predictions = predict_regressor(X, model) + metrics_dict = score_predictions(y, predictions, validation_metrics) predictions_reshaped = reshape_predictions( predictions, reference_profile["height"], reference_profile["width"], nodata_mask ) @@ -2359,20 +2362,22 @@ def predict_with_trained_model_cli( model_file: INPUT_FILE_OPTION, output_raster: OUTPUT_FILE_OPTION, ): - """Train and optionally validate a Gradient boosting regressor model using Sklearn.""" - from eis_toolkit.prediction.machine_learning_general import ( - load_model, - predict, - prepare_data_for_ml, - reshape_predictions, - ) + """Predict with a trained machine learning model.""" + from sklearn.base import is_classifier + + from eis_toolkit.prediction.machine_learning_general import load_model, prepare_data_for_ml, reshape_predictions + from eis_toolkit.prediction.machine_learning_predict import predict_classifier, predict_regressor X, _, reference_profile, nodata_mask = prepare_data_for_ml(input_rasters) typer.echo("Progress: 30%") model = load_model(model_file) - predictions = predict(X, model) + if is_classifier(model): + predictions, probabilities = predict_classifier(X, model, True) + else: + predictions = predict_regressor(X, model) + predictions_reshaped = reshape_predictions( predictions, reference_profile["height"], reference_profile["width"], nodata_mask )