Skip to content

Commit

Permalink
ml/model update
Browse files Browse the repository at this point in the history
  • Loading branch information
XxRemsteelexX committed Aug 2, 2024
1 parent 387b4ca commit 5be3633
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions ml/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Databricks notebook source
import pandas as pd

import multiprocessing
import logging
from sklearn.ensemble import RandomForestClassifier
import numpy as np
import pickle
from sklearn.metrics import fbeta_score, precision_score, recall_score
from ml.data import process_data
Expand All @@ -20,7 +27,9 @@ def train_model(X_train, y_train):
Trained machine learning model.
"""
# TODO: implement the function
pass
hgc = RandomForestClassifier()
model = hgc.fit(X_train, y_train)
return model


def compute_model_metrics(y, preds):
Expand Down Expand Up @@ -60,7 +69,8 @@ def inference(model, X):
Predictions from the model.
"""
# TODO: implement the function
pass
preds = model.predict(X)
return preds

def save_model(model, path):
""" Serializes model to a file.
Expand All @@ -73,12 +83,16 @@ def save_model(model, path):
Path to save pickle file.
"""
# TODO: implement the function
pass
with open(path, 'wb') as f:
pickle.dump(model, f)

def load_model(path):
""" Loads pickle file from `path` and returns it."""
# TODO: implement the function
pass
with open(path, 'rb') as f:
model = pickle.load(f)
return model



def performance_on_categorical_slice(
Expand Down Expand Up @@ -122,7 +136,14 @@ def performance_on_categorical_slice(
# your code here
# for input data, use data in column given as "column_name", with the slice_value
# use training = False
X = slice_data,
categorical_features=categorical_features,
label = label,
training = False,
encoder = encoder,
lb = lb

)
preds = # your code here to get prediction on X_slice using the inference function
preds = inference(model, X_slice)
precision, recall, fbeta = compute_model_metrics(y_slice, preds)
return precision, recall, fbeta

0 comments on commit 5be3633

Please sign in to comment.