diff --git a/09_model_pipeline/predict/src/predict.py b/09_model_pipeline/predict/src/predict.py index 6d20c8dc..86cb4b3d 100644 --- a/09_model_pipeline/predict/src/predict.py +++ b/09_model_pipeline/predict/src/predict.py @@ -1,9 +1,10 @@ from utils import * import numpy +import pickle class Predict: def main_predict(data, bucket_name, path): - model = load_models_from_s3(bucket_name, path ) - + #model = load_models_from_s3(bucket_name, path ) + model = pickle.load(open("trained_model.pickle",'r')) probas = model.predict_proba(data) preds = np.where(probas[:,1] >= 0.4, 1, 0) data['prob_0'] = probas[:,0] @@ -11,4 +12,4 @@ def main_predict(data, bucket_name, path): data['predictions'] = preds record = data.to_dict(orient='records')[0] json_result = data.to_json(orient='records', lines=True).splitlines() - return json_result \ No newline at end of file + return json_result