From 413a7790be17a7194b72564b14a567931e7d71e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Adri=C3=A1n=20Alarc=C3=B3n?= <83436724+aladelca@users.noreply.github.com> Date: Fri, 26 Apr 2024 09:38:47 -0400 Subject: [PATCH] Update predict.py --- 09_model_pipeline/predict/src/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/09_model_pipeline/predict/src/predict.py b/09_model_pipeline/predict/src/predict.py index 86cb4b3d..187a395f 100644 --- a/09_model_pipeline/predict/src/predict.py +++ b/09_model_pipeline/predict/src/predict.py @@ -4,7 +4,7 @@ class Predict: def main_predict(data, bucket_name, path): #model = load_models_from_s3(bucket_name, path ) - model = pickle.load(open("trained_model.pickle",'r')) + model = pickle.load(open("src/trained_model.pickle",'rb')) probas = model.predict_proba(data) preds = np.where(probas[:,1] >= 0.4, 1, 0) data['prob_0'] = probas[:,0]