diff --git a/09_model_pipeline/predict/src/predict.py b/09_model_pipeline/predict/src/predict.py index 86cb4b3d..187a395f 100644 --- a/09_model_pipeline/predict/src/predict.py +++ b/09_model_pipeline/predict/src/predict.py @@ -4,7 +4,7 @@ class Predict: def main_predict(data, bucket_name, path): #model = load_models_from_s3(bucket_name, path ) - model = pickle.load(open("trained_model.pickle",'r')) + model = pickle.load(open("src/trained_model.pickle",'rb')) probas = model.predict_proba(data) preds = np.where(probas[:,1] >= 0.4, 1, 0) data['prob_0'] = probas[:,0]