Skip to content

Commit

Permalink
threading backend for model training and predictions (#866)
Browse files Browse the repository at this point in the history
  • Loading branch information
shaycrk authored Dec 7, 2021
1 parent 829a1b7 commit 8bd22d8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
8 changes: 7 additions & 1 deletion src/triage/component/catwalk/model_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pandas as pd

from sklearn.model_selection import ParameterGrid
from sklearn.utils import parallel_backend
from sqlalchemy.orm import sessionmaker

from triage.util.random import generate_python_random_seed
Expand Down Expand Up @@ -122,7 +123,12 @@ def _train(self, matrix_store, class_path, parameters):
cls = getattr(module, class_name)
instance = cls(**parameters)

return instance.fit(matrix_store.design_matrix, matrix_store.labels)
# Use the threading backend here: the default loky backend doesn't allow
# nested parallelization (e.g., when triage itself already uses multiprocessing).
with parallel_backend('threading'):
fitted = instance.fit(matrix_store.design_matrix, matrix_store.labels)

return fitted

@db_retry
def _save_feature_importances(self, model_id, feature_importances, feature_names):
Expand Down
10 changes: 7 additions & 3 deletions src/triage/component/catwalk/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from sqlalchemy.orm import sessionmaker
from sqlalchemy import or_
from sklearn.utils import parallel_backend

from .utils import db_retry, retrieve_model_hash_from_id, save_db_objects, sort_predictions_and_labels, AVAILABLE_TIEBREAKERS
from triage.component.results_schema import Model
Expand Down Expand Up @@ -269,9 +270,12 @@ def predict(self, model_id, matrix_store, misc_db_parameters, train_matrix_colum
# Labels are popped from matrix (i.e. they are removed and returned)
labels = matrix_store.labels

predictions = model.predict_proba(
matrix_store.matrix_with_sorted_columns(train_matrix_columns)
)[:, 1] # Returning only the scores for the label == 1
# Use the threading backend here: the default loky backend doesn't allow
# nested parallelization (e.g., when triage itself already uses multiprocessing).
with parallel_backend('threading'):
predictions = model.predict_proba(
matrix_store.matrix_with_sorted_columns(train_matrix_columns)
)[:, 1] # Returning only the scores for the label == 1


logger.debug(
Expand Down

0 comments on commit 8bd22d8

Please sign in to comment.