Skip to content

Commit

Permalink
Update predict.py
Browse files Browse the repository at this point in the history
  • Loading branch information
aladelca authored Apr 26, 2024
1 parent 0a129d2 commit 781ae51
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions 09_model_pipeline/predict/src/predict.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from utils import *
import numpy
import pickle
class Predict:
def main_predict(data, bucket_name, path):
model = load_models_from_s3(bucket_name, path )

#model = load_models_from_s3(bucket_name, path )
model = pickle.load(open("trained_model.pickle",'r'))
probas = model.predict_proba(data)
preds = np.where(probas[:,1] >= 0.4, 1, 0)
data['prob_0'] = probas[:,0]
data['prob_1'] = probas[:,1]
data['predictions'] = preds
record = data.to_dict(orient='records')[0]
json_result = data.to_json(orient='records', lines=True).splitlines()
return json_result
return json_result

0 comments on commit 781ae51

Please sign in to comment.