diff --git a/predict.py b/predict.py index 9e01659..a4aafee 100644 --- a/predict.py +++ b/predict.py @@ -8,11 +8,10 @@ def predict_image( image_path: str = typer.Argument(help="image path", show_default=True), - # model_path: str = typer.Argument("checkpoint_notebook.pth", help="model path (pth)", show_default=True), - download: bool = typer.Argument(True, help="True for download the model automatically", show_default=True), + model_path: str = typer.Argument(None, help="path to your model (pth)", show_default=True), device: str = typer.Argument("cpu", help="use cuda if your device has cuda", show_default=True) ): - predictor = ImageRecognition(download=download, device=device) + predictor = ImageRecognition(model_path=model_path, device=device) result = predictor.predict(image=image_path) typer.echo(f"Prediction: {result}")