From 72259fca520b495c5eb4d142e74034b7b2ebbf48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 3 Dec 2024 09:33:13 +0000 Subject: [PATCH 01/38] feat(prediction): add `model` as instance attribute --- src/gentropy/dataset/l2g_prediction.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 2bc286a40..72b6afdef 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -29,6 +29,8 @@ class L2GPrediction(Dataset): confidence of the prediction that a gene is causal to an association. """ + model: LocusToGeneModel | None = None + @classmethod def get_schema(cls: type[L2GPrediction]) -> StructType: """Provides the schema for the L2GPrediction dataset. @@ -85,7 +87,9 @@ def from_credible_set( .select_features(features_list) ) - return l2g_model.predict(fm, session) + predictions = l2g_model.predict(fm, session) + predictions.model = l2g_model # Set the model attribute + return predictions def to_disease_target_evidence( self: L2GPrediction, From 9e8c491961d7f9315011c04ac9ad18c34d3dc545 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 3 Dec 2024 11:14:55 +0000 Subject: [PATCH 02/38] feat: added `convert_map_type_to_columns` spark util --- src/gentropy/common/spark_helpers.py | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/gentropy/common/spark_helpers.py b/src/gentropy/common/spark_helpers.py index 64a8bceb7..c55d77001 100644 --- a/src/gentropy/common/spark_helpers.py +++ b/src/gentropy/common/spark_helpers.py @@ -885,3 +885,35 @@ def calculate_harmonic_sum(input_array: Column) -> Column: / f.pow(x["pos"], 2) / f.lit(sum(1 / ((i + 1) ** 2) for i in range(1000))), ) + + +def convert_map_type_to_columns(df: DataFrame, map_column: Column) -> list[Column]: + """Convert a MapType column into multiple columns, one for each key in the map. 
+
+    Args:
+        df (DataFrame): A Spark DataFrame
+        map_column (Column): A Spark Column of MapType
+
+    Returns:
+        list[Column]: List of columns, one for each key in the map
+
+    Examples:
+        >>> df = spark.createDataFrame([({'a': 1, 'b': 2},), ({'c': 3},)], ["map_col"])
+        >>> df.select(*convert_map_type_to_columns(df, f.col("map_col"))).show()
+        +----+----+----+
+        |   a|   b|   c|
+        +----+----+----+
+        |   1|   2|null|
+        |null|null|   3|
+        +----+----+----+
+
+    """
+    # The schema is agnostic of the map keys, so they have to be collected from the data first
+    keys = (
+        df.select(f.explode(map_column))
+        .select("key")
+        .distinct()
+        .rdd.flatMap(lambda x: x)
+        .collect()
+    )
+    return [map_column.getItem(k).alias(k) for k in keys]

From 450a9375862cc0e94a813cab1d543cb319291dc0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Tue, 3 Dec 2024 11:38:41 +0000
Subject: [PATCH 03/38] feat(prediction): new method `explain` returns shapley values

---
 src/gentropy/dataset/l2g_prediction.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py
index 72b6afdef..c80c449d6 100644
--- a/src/gentropy/dataset/l2g_prediction.py
+++ b/src/gentropy/dataset/l2g_prediction.py
@@ -6,10 +6,12 @@
 from typing import TYPE_CHECKING, Type
 
 import pyspark.sql.functions as f
+import shap
 from pyspark.sql import DataFrame
 
 from gentropy.common.schemas import parse_spark_schema
 from gentropy.common.session import Session
+from gentropy.common.spark_helpers import convert_map_type_to_columns
 from gentropy.dataset.dataset import Dataset
 from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
 from gentropy.dataset.study_index import StudyIndex
@@ -17,6 +19,7 @@
 from gentropy.method.l2g.model import LocusToGeneModel
 
 if TYPE_CHECKING:
+    from numpy import ndarray as np_ndarray
     from pyspark.sql.types import StructType
 
 
@@ -132,6 +135,29 @@ def to_disease_target_evidence(
             )
         )
 
+    def explain(self: L2GPrediction) -> np_ndarray:
+        """Extract Shapley values for the L2G predictions.
+ + Returns: + np_ndarray: Shapley values + + Raises: + ValueError: If the model is not set + """ + if self.model is None: + raise ValueError("Model not set, explainer cannot be created") + explainer = shap.TreeExplainer( + self.model.model, feature_perturbation="tree_path_dependent" + ) + features_matrix = ( + self.df.select( + *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) + ) + .toPandas() + .to_numpy() + ) + return explainer.shap_values(features_matrix) + def add_locus_to_gene_features( self: L2GPrediction, feature_matrix: L2GFeatureMatrix ) -> L2GPrediction: From 08ae6bd294bf3ed6d8a94ab801d0cbcddce7ed51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 4 Dec 2024 15:58:23 +0000 Subject: [PATCH 04/38] feat(prediction): `explain` returns predictions with shapley values --- .../assets/schemas/l2g_predictions.json | 11 ++++ src/gentropy/dataset/l2g_prediction.py | 60 +++++++++++++++---- 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/src/gentropy/assets/schemas/l2g_predictions.json b/src/gentropy/assets/schemas/l2g_predictions.json index 57247a49a..8bda086a3 100644 --- a/src/gentropy/assets/schemas/l2g_predictions.json +++ b/src/gentropy/assets/schemas/l2g_predictions.json @@ -29,6 +29,17 @@ "valueContainsNull": true, "valueType": "float" } + }, + { + "metadata": {}, + "name": "shapleyValues", + "nullable": true, + "type": { + "keyType": "string", + "type": "map", + "valueContainsNull": false, + "valueType": "float" + } } ] } diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index c80c449d6..5833810e9 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -7,7 +7,7 @@ import pyspark.sql.functions as f import shap -from pyspark.sql import DataFrame +from pyspark.sql import DataFrame, Window from gentropy.common.schemas import parse_spark_schema from gentropy.common.session import Session @@ -19,7 +19,6 @@ from gentropy.method.l2g.model import LocusToGeneModel if TYPE_CHECKING: - from numpy import ndarray as np_ndarray from pyspark.sql.types import StructType @@ -135,11 +134,12 @@ def to_disease_target_evidence( ) ) - def explain(self: L2GPrediction) -> np_ndarray: - """Extract Shapley values for the L2G predictions. + def explain(self: L2GPrediction) -> L2GPrediction: + """Extract Shapley values for the L2G predictions and add them as a map column. 
 
         Returns:
-            np_ndarray: Shapley values
+            L2GPrediction: L2GPrediction object with an additional column 'shapleyValues' containing
+            feature name to Shapley value mappings
 
         Raises:
             ValueError: If the model is not set
         """
         if self.model is None:
             raise ValueError("Model not set, explainer cannot be created")
         explainer = shap.TreeExplainer(
             self.model.model, feature_perturbation="tree_path_dependent"
         )
-        features_matrix = (
-            self.df.select(
-                *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures"))
-            )
-            .toPandas()
-            .to_numpy()
+        features_matrix = self.df.select(
+            *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures"))
+        ).toPandas()
+        shapley_values = explainer.shap_values(features_matrix.to_numpy())
+
+        # Create arrays of Shapley values for each feature
+        features_list = list(features_matrix.columns)
+        shapley_arrays = {
+            feature: [row[i] for row in shapley_values]
+            for i, feature in enumerate(features_list)
+        }
+        return L2GPrediction(
+            _df=(
+                self.df.withColumn(
+                    # Add row index to ensure correct mapping between the predictions and the shapley values
+                    "tmp_idx",
+                    f.row_number().over(
+                        Window.orderBy(f.monotonically_increasing_id())
+                    ),
+                )
+                .withColumn(
+                    "shapleyValues",
+                    f.create_map(
+                        *[
+                            item
+                            for feature in features_list
+                            for item in [
+                                f.lit(feature),
+                                f.array(
+                                    [f.lit(float(x)) for x in shapley_arrays[feature]]
+                                )
+                                .getItem(
+                                    f.col("tmp_idx") - f.lit(1)
+                                )  # we subtract one because row_number starts counting from 1
+                                .cast("float"),
+                            ]
+                        ]
+                    ),
+                )
+                .drop("tmp_idx")
+            ),
+            _schema=self.get_schema(),
+            model=self.model,
         )
-        return explainer.shap_values(features_matrix)
 
     def add_locus_to_gene_features(
         self: L2GPrediction, feature_matrix: L2GFeatureMatrix

From 9d40e6254b18bc6a8112fe5f3571b11592c9d2cc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Wed, 4 Dec 2024 16:15:16 +0000
Subject: [PATCH 05/38] chore: compute `shapleyValues` in the l2g step

---
 src/gentropy/l2g.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py
index 16922ef78..1825c305e 100644
--- a/src/gentropy/l2g.py
+++ b/src/gentropy/l2g.py
@@ -285,7 +285,7 @@ def run_predict(self) -> None:
             )
             predictions.filter(
                 f.col("score") >= self.l2g_threshold
-            ).add_locus_to_gene_features(self.feature_matrix).df.coalesce(
+            ).add_locus_to_gene_features(self.feature_matrix).explain().df.coalesce(
                 self.session.output_partitions
             ).write.mode(self.session.write_mode).parquet(self.predictions_path)
             self.session.logger.info("L2G predictions saved successfully.")
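A note for readers following the series: the `tmp_idx` trick above pairs predictions with their Shapley rows positionally, via `monotonically_increasing_id` plus `row_number`, which only holds if Spark preserves row order across plans — presumably why the next two commits rework the alignment. For context, here is a minimal, self-contained sketch of the SHAP API these commits build on; the model, data shapes, and feature count are invented and nothing below is gentropy code:

```python
import numpy as np
import shap
from sklearn.ensemble import GradientBoostingClassifier

# Toy stand-in for the L2G classifier: 100 loci x 3 hypothetical features.
rng = np.random.default_rng(42)
X = rng.random((100, 3))
y = rng.integers(0, 2, 100)
model = GradientBoostingClassifier(random_state=42).fit(X, y)

# "tree_path_dependent" derives backgrounds from the trees' own cover
# statistics, so no separate background dataset is needed.
explainer = shap.TreeExplainer(model, feature_perturbation="tree_path_dependent")
shap_values = explainer.shap_values(X)  # row i explains row i of X
```

Because row `i` of `shap_values` explains row `i` of the input matrix, keeping the Spark rows and the NumPy matrix in lockstep is the delicate part that the following refactors address.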

From f407512d0f7c825c41c632c5abbb13285286c0e8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Thu, 5 Dec 2024 17:53:13 +0000
Subject: [PATCH 06/38] refactor: use pandas udf instead

---
 src/gentropy/dataset/l2g_prediction.py | 98 +++++++++++++++-----------
 1 file changed, 56 insertions(+), 42 deletions(-)

diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py
index 5833810e9..a0581d0d6 100644
--- a/src/gentropy/dataset/l2g_prediction.py
+++ b/src/gentropy/dataset/l2g_prediction.py
@@ -5,9 +5,11 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Type
 
+import pandas as pd
 import pyspark.sql.functions as f
 import shap
-from pyspark.sql import DataFrame, Window
+from pyspark.sql import DataFrame
+from pyspark.sql.functions import pandas_udf
 
 from gentropy.common.schemas import parse_spark_schema
 from gentropy.common.session import Session
@@ -146,52 +148,64 @@ def explain(self: L2GPrediction) -> L2GPrediction:
         """
         if self.model is None:
             raise ValueError("Model not set, explainer cannot be created")
+
+        # Create explainer once
         explainer = shap.TreeExplainer(
             self.model.model, feature_perturbation="tree_path_dependent"
         )
+
+        # Create UDF for Shapley calculation
+        @pandas_udf("array<double>")
+        def calculate_shapley_values(features_pd: pd.DataFrame) -> pd.Series:
+            """Calculate Shapley values for a batch of features.
+
+            Args:
+                features_pd (pd.DataFrame): Batch of features.
+
+            Returns:
+                pd.Series: Series of Shapley values for the batch.
+            """
+            feature_array = features_pd.to_numpy()
+            shapley_values = explainer.shap_values(feature_array)
+            return pd.Series([list(row) for row in shapley_values])
+
+        df_w_features = self.df.select(
+            "*", *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures"))
+        )
+        features_list = [
+            col for col in df_w_features.columns if col not in self.df.columns
+        ]
+        # Apply UDF and create map of feature names to Shapley values
+        result_df = (
+            df_w_features.withColumn(
+                "shapley_array",
+                calculate_shapley_values(
+                    f.array(*[f.col(feature) for feature in features_list])
+                ),
+            )
+            .withColumn(
+                "shapleyValues",
+                f.create_map(
+                    *sum(
+                        (
+                            (
+                                f.lit(feature),
+                                f.element_at("shapley_array", f.lit(pos + 1)),
+                            )
+                            for pos, feature in enumerate(features_list)
+                        ),
+                        (),
+                    )
+                ),
+            )
+            .drop("shapley_array")
+            .select(*[field.name for field in self.get_schema().fields])
+        )
+
+        return L2GPrediction(
+            _df=result_df,
+            model=self.model,
+            _schema=self.get_schema(),
+        )

From f542395075e45e7a4717d9b660261c9c279a7256 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Fri, 6 Dec 2024 15:22:41 +0000
Subject: [PATCH 07/38] refactor: forget about udfs and get shaps single threaded

---
 src/gentropy/dataset/l2g_prediction.py | 70 +++++++++----------
 1 file changed, 25 insertions(+), 45 deletions(-)

diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py
index a0581d0d6..d4b2e314c 100644
--- a/src/gentropy/dataset/l2g_prediction.py
+++ b/src/gentropy/dataset/l2g_prediction.py
@@ -2,14 +2,14 @@
 
 from __future__ import annotations
 
+import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Type
 
-import pandas as pd
 import pyspark.sql.functions as f
 import shap
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import pandas_udf
+from pyspark.sql.types import StructType
 
 from gentropy.common.schemas import parse_spark_schema
 from
gentropy.common.session import Session @@ -137,11 +137,10 @@ def to_disease_target_evidence( ) def explain(self: L2GPrediction) -> L2GPrediction: - """Extract Shapley values for the L2G predictions and add them as a map column. + """Extract Shapley values for the L2G predictions and add them as a map in an additional column. Returns: - L2GPrediction: L2GPrediction object with an additional column 'shapleyValues' containing - feature name to Shapley value mappings + L2GPrediction: L2GPrediction object with additional column containing feature name to Shapley value mappings Raises: ValueError: If the model is not set @@ -149,62 +148,43 @@ def explain(self: L2GPrediction) -> L2GPrediction: if self.model is None: raise ValueError("Model not set, explainer cannot be created") - # Create explainer once explainer = shap.TreeExplainer( self.model.model, feature_perturbation="tree_path_dependent" ) - - # Create UDF for Shapley calculation - @pandas_udf("array") - def calculate_shapley_values(features_pd: pd.DataFrame) -> pd.Series: - """Calculate Shapley values for a batch of features. - - Args: - features_pd (pd.DataFrame): Batch of features. - - Returns: - pd.Series: Series of Shapley values for the batch. - """ - feature_array = features_pd.to_numpy() - shapley_values = explainer.shap_values(feature_array) - return pd.Series([list(row) for row in shapley_values]) - df_w_features = self.df.select( "*", *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) - ) + ).drop("shapleyValues") features_list = [ - col for col in df_w_features.columns if col not in self.df.columns + col for col in df_w_features.columns if col not in self.get_schema().names ] - # Apply UDF and create map of feature names to Shapley values - result_df = ( - df_w_features.withColumn( - "shapley_array", - calculate_shapley_values( - f.array(*[f.col(feature) for feature in features_list]) - ), + pdf = df_w_features.select(features_list).toPandas() + + # Calculate SHAP values + if pdf.shape[0] >= 10_000: + logging.warning( + "Calculating SHAP values for more than 10,000 rows. This may take a while..." 
+ ) + shap_values = explainer.shap_values(pdf.to_numpy()) + for i, feature in enumerate(features_list): + pdf[f"shap_{feature}"] = [row[i] for row in shap_values] + + spark_session = df_w_features.sparkSession + return L2GPrediction( + _df=df_w_features.join( + # Convert df with shapley values to Spark and join with original df + spark_session.createDataFrame(pdf.to_dict(orient="records")), + features_list, ) .withColumn( "shapleyValues", f.create_map( *sum( - ( - ( - f.lit(feature), - f.element_at("shapley_array", f.lit(pos + 1)), - ) - for pos, feature in enumerate(features_list) - ), + ((f.lit(col), f.col(f"shap_{col}")) for col in features_list), (), ) ), ) - .drop("shapley_array") - .select(*[field.name for field in self.get_schema().fields]) - ) - - return L2GPrediction( - _df=result_df, - model=self.model, + .select(*self.get_schema().names), _schema=self.get_schema(), ) From 9403fe620a9303613fbc36c6051b50d3a16aeff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 6 Dec 2024 15:23:23 +0000 Subject: [PATCH 08/38] chore: remove reference to chromatin interaction data in HF card --- src/gentropy/method/l2g/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 336efeb7f..091a970d4 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -194,7 +194,6 @@ def _create_hugging_face_model_card( - Distance: (from credible set variants to gene) - Molecular QTL Colocalization - - Chromatin Interaction: (e.g., promoter-capture Hi-C) - Variant Pathogenicity: (from VEP) More information at: https://opentargets.github.io/gentropy/python_api/methods/l2g/_l2g/ From 1bc6f3a1d03c05f6c0e5065ae90ec58acd94d850 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 6 Dec 2024 16:04:53 +0000 Subject: [PATCH 09/38] fix(l2g_prediction): methods that return new instance preserve attribute --- src/gentropy/dataset/l2g_prediction.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index d4b2e314c..f8534c4f7 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -186,6 +186,7 @@ def explain(self: L2GPrediction) -> L2GPrediction: ) .select(*self.get_schema().names), _schema=self.get_schema(), + model=self.model, ) def add_locus_to_gene_features( @@ -237,4 +238,5 @@ def add_locus_to_gene_features( return L2GPrediction( _df=self.df.join(aggregated_features, on=prediction_id_columns, how="left"), _schema=self.get_schema(), + model=self.model, ) From 8420933ac7fb3f3c4ab23cfa1c186cf8c097eb16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 6 Dec 2024 16:06:03 +0000 Subject: [PATCH 10/38] feat(dataset): `filter` method preserves all instance attributes --- src/gentropy/dataset/dataset.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/gentropy/dataset/dataset.py b/src/gentropy/dataset/dataset.py index 67fe05eaf..442c0e6ba 100644 --- a/src/gentropy/dataset/dataset.py +++ b/src/gentropy/dataset/dataset.py @@ -156,15 +156,25 @@ def from_parquet( def filter(self: Self, condition: Column) -> Self: """Creates a new instance of a Dataset with the DataFrame filtered by the condition. + Preserves all attributes from the original instance. 
+ Args: condition (Column): Condition to filter the DataFrame Returns: - Self: Filtered Dataset + Self: Filtered Dataset with preserved attributes """ df = self._df.filter(condition) class_constructor = self.__class__ - return class_constructor(_df=df, _schema=class_constructor.get_schema()) + # Get all attributes from the current instance + attrs = { + key: value + for key, value in self.__dict__.items() + if key not in ["_df", "_schema"] + } + return class_constructor( + _df=df, _schema=class_constructor.get_schema(), **attrs + ) def validate_schema(self: Dataset) -> None: """Validate DataFrame schema against expected class schema. From b98749676e64f57d346874b691bb918ed005d3f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Mon, 27 Jan 2025 19:21:49 +0000 Subject: [PATCH 11/38] feat(l2gmodel): add features_list as model attribute and load it from the hub metadata --- src/gentropy/l2g.py | 5 ++++- src/gentropy/method/l2g/model.py | 27 ++++++++++++++++++++++++--- src/gentropy/method/l2g/trainer.py | 21 +++++++++++++-------- 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index b8b334944..71b4656d4 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -74,7 +74,9 @@ def __init__( else None ) target_index = ( - TargetIndex.from_parquet(session, target_index_path, recursiveFileLookup=True) + TargetIndex.from_parquet( + session, target_index_path, recursiveFileLookup=True + ) if target_index_path else None ) @@ -305,6 +307,7 @@ def run_train(self) -> None: l2g_model = LocusToGeneModel( model=GradientBoostingClassifier(random_state=42, loss="log_loss"), hyperparameters=self.hyperparameters, + features_list=self.features_list, ) # Calculate the gold standard features diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index c8cb1977a..2b011cc4e 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -27,6 +27,7 @@ class LocusToGeneModel: """Wrapper for the Locus to Gene classifier.""" model: Any = GradientBoostingClassifier(random_state=42) + features_list: list[str] = field(default_factory=list) hyperparameters: dict[str, Any] = field( default_factory=lambda: { "n_estimators": 100, @@ -51,11 +52,14 @@ def __post_init__(self: LocusToGeneModel) -> None: self.model.set_params(**self.hyperparameters_dict) @classmethod - def load_from_disk(cls: type[LocusToGeneModel], path: str) -> LocusToGeneModel: + def load_from_disk( + cls: type[LocusToGeneModel], path: str, **kwargs: Any + ) -> LocusToGeneModel: """Load a fitted model from disk. 
Args: path (str): Path to the model + **kwargs: Keyword arguments to pass to the constructor Returns: LocusToGeneModel: L2G model loaded from disk @@ -79,7 +83,7 @@ def load_from_disk(cls: type[LocusToGeneModel], path: str) -> LocusToGeneModel: if not loaded_model._is_fitted(): raise ValueError("Model has not been fitted yet.") - return cls(model=loaded_model) + return cls(model=loaded_model, **kwargs) @classmethod def load_from_hub( @@ -98,9 +102,26 @@ def load_from_hub( Returns: LocusToGeneModel: L2G model loaded from the Hugging Face Hub """ + + def get_features_list_from_metadata() -> list[str]: + """Get the features list (in the right order) from the metadata file from the Hub.""" + import json + + model_config_path = str(Path(local_path) / "config.json") + with open(model_config_path) as f: + model_config = json.load(f) + return [ + column + for column in model_config["sklearn"]["columns"] + if column != "studyLocusId" + ] + local_path = Path(model_id) hub_utils.download(repo_id=model_id, dst=local_path, token=hf_token) - return cls.load_from_disk(str(Path(local_path) / model_name)) + features_list = get_features_list_from_metadata() + return cls.load_from_disk( + str(Path(local_path) / model_name), features_list=features_list + ) @property def hyperparameters_dict(self) -> dict[str, Any]: diff --git a/src/gentropy/method/l2g/trainer.py b/src/gentropy/method/l2g/trainer.py index a43d6609d..a123cfda9 100644 --- a/src/gentropy/method/l2g/trainer.py +++ b/src/gentropy/method/l2g/trainer.py @@ -89,15 +89,20 @@ def fit( Raises: ValueError: Train data not set, nothing to fit. """ - if self.x_train is not None and self.y_train is not None: - assert ( - self.x_train.size != 0 and self.y_train.size != 0 - ), "Train data not set, nothing to fit." + if ( + self.x_train is not None + and self.y_train is not None + and self.features_list is not None + ): + assert self.x_train.size != 0 and self.y_train.size != 0, ( + "Train data not set, nothing to fit." + ) fitted_model = self.model.model.fit(X=self.x_train, y=self.y_train) self.model = LocusToGeneModel( model=fitted_model, hyperparameters=fitted_model.get_params(), training_data=self.feature_matrix, + features_list=self.features_list, ) return self.model raise ValueError("Train data not set, nothing to fit.") @@ -184,9 +189,9 @@ def log_to_wandb( or self.features_list is None ): raise RuntimeError("Train data not set, we cannot log to W&B.") - assert ( - self.x_train.size != 0 and self.y_train.size != 0 - ), "Train data not set, nothing to evaluate." + assert self.x_train.size != 0 and self.y_train.size != 0, ( + "Train data not set, nothing to evaluate." 
+ ) fitted_classifier = self.model.model y_predicted = fitted_classifier.predict(self.x_test) y_probas = fitted_classifier.predict_proba(self.x_test) @@ -456,7 +461,7 @@ def run_all_folds() -> None: cross_validate_single_fold( fold_index=fold_index, sweep_id=sweep_id, - sweep_run_name=f"{wandb_run_name}-fold{fold_index+1}", + sweep_run_name=f"{wandb_run_name}-fold{fold_index + 1}", config=config, ) From 12de6695de7086d4247dff3739c7c80ee8fde2cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 28 Jan 2025 11:40:49 +0000 Subject: [PATCH 12/38] fix: pass correct order of features to shapley explainer --- src/gentropy/dataset/l2g_prediction.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 5172cc1c4..da2702f85 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -154,10 +154,9 @@ def explain(self: L2GPrediction) -> L2GPrediction: df_w_features = self.df.select( "*", *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) ).drop("shapleyValues") - features_list = [ - col for col in df_w_features.columns if col not in self.get_schema().names - ] - pdf = df_w_features.select(features_list).toPandas() + # The matrix needs to present the features in the same order that the model was trained on + features_list = self.model.features_list + pdf = df_w_features.select(*features_list).toPandas() # Calculate SHAP values if pdf.shape[0] >= 10_000: From 78027da35ba54ca1b149adf649916b4bd9c1cfa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 28 Jan 2025 12:11:31 +0000 Subject: [PATCH 13/38] feat(l2g): predict mode to extract feature list from model, not from config --- src/gentropy/dataset/l2g_prediction.py | 20 +++++++++++++------- src/gentropy/l2g.py | 17 +++++++++++------ src/gentropy/method/l2g/model.py | 12 +++++++++--- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index da2702f85..3ff45f598 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -50,7 +50,6 @@ def from_credible_set( session: Session, credible_set: StudyLocus, feature_matrix: L2GFeatureMatrix, - features_list: list[str], model_path: str | None, hf_token: str | None = None, download_from_hub: bool = True, @@ -61,7 +60,6 @@ def from_credible_set( session (Session): Session object that contains the Spark session credible_set (StudyLocus): Dataset containing credible sets from GWAS only feature_matrix (L2GFeatureMatrix): Dataset containing all credible sets and their annotations - features_list (list[str]): List of features to use for the model model_path (str | None): Path to the model file. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name). hf_token (str | None): Hugging Face token to download the model from the Hub. Only required if the model is private. download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True. 
@@ -88,7 +86,7 @@ def from_credible_set( ) ) .fill_na() - .select_features(features_list) + .select_features(l2g_model.features_list) ) predictions = l2g_model.predict(fm, session) @@ -189,17 +187,22 @@ def explain(self: L2GPrediction) -> L2GPrediction: ) def add_locus_to_gene_features( - self: L2GPrediction, feature_matrix: L2GFeatureMatrix, features_list: list[str] + self: L2GPrediction, + feature_matrix: L2GFeatureMatrix, ) -> L2GPrediction: """Add features used to extract the L2G predictions. Args: feature_matrix (L2GFeatureMatrix): Feature matrix dataset - features_list (list[str]): List of features used in the model Returns: L2GPrediction: L2G predictions with additional features + + Raises: + ValueError: If model is not set, feature list won't be available """ + if self.model is None: + raise ValueError("Model not set, feature annotation cannot be created.") # Testing if `locusToGeneFeatures` column already exists: if "locusToGeneFeatures" in self.df.columns: self.df = self.df.drop("locusToGeneFeatures") @@ -210,7 +213,10 @@ def add_locus_to_gene_features( "locusToGeneFeatures", f.create_map( *sum( - ((f.lit(feature), f.col(feature)) for feature in features_list), + ( + (f.lit(feature), f.col(feature)) + for feature in self.model.features_list + ), (), ) ), @@ -219,7 +225,7 @@ def add_locus_to_gene_features( "locusToGeneFeatures", f.expr("map_filter(locusToGeneFeatures, (k, v) -> v != 0)"), ) - .drop(*features_list) + .drop(*self.model.features_list) ) return L2GPrediction( _df=self.df.join( diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 71b4656d4..9b7993bf3 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -104,7 +104,6 @@ def __init__( session: Session, *, run_mode: str, - features_list: list[str], hyperparameters: dict[str, Any], download_from_hub: bool, cross_validate: bool, @@ -112,6 +111,7 @@ def __init__( credible_set_path: str, feature_matrix_path: str, model_path: str | None = None, + features_list: list[str] | None, gold_standard_curation_path: str | None = None, variant_index_path: str | None = None, gene_interactions_path: str | None = None, @@ -125,7 +125,6 @@ def __init__( Args: session (Session): Session object that contains the Spark session run_mode (str): Run mode, either 'train' or 'predict' - features_list (list[str]): List of features to use for the model hyperparameters (dict[str, Any]): Hyperparameters for the model download_from_hub (bool): Whether to download the model from Hugging Face Hub cross_validate (bool): Whether to run cross validation (5-fold by default) to train the model. @@ -133,6 +132,7 @@ def __init__( credible_set_path (str): Path to the credible set dataset necessary to build the feature matrix feature_matrix_path (str): Path to the L2G feature matrix input dataset model_path (str | None): Path to the model. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name). 
+ features_list (list[str] | None): List of features to use to train the model gold_standard_curation_path (str | None): Path to the gold standard curation file variant_index_path (str | None): Path to the variant index gene_interactions_path (str | None): Path to the gene interactions dataset @@ -153,7 +153,7 @@ def __init__( self.run_mode = run_mode self.model_path = model_path self.predictions_path = predictions_path - self.features_list = list(features_list) + self.features_list = list(features_list) if features_list else None self.hyperparameters = dict(hyperparameters) self.wandb_run_name = wandb_run_name self.cross_validate = cross_validate @@ -283,7 +283,6 @@ def run_predict(self) -> None: self.session, self.credible_set, self.feature_matrix, - self.features_list, model_path=self.model_path, hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"), download_from_hub=self.download_from_hub, @@ -291,14 +290,20 @@ def run_predict(self) -> None: predictions.filter( f.col("score") >= self.l2g_threshold ).add_locus_to_gene_features( - self.feature_matrix, self.features_list + self.feature_matrix, ).explain().df.coalesce(self.session.output_partitions).write.mode( self.session.write_mode ).parquet(self.predictions_path) self.session.logger.info("L2G predictions saved successfully.") def run_train(self) -> None: - """Run the training step.""" + """Run the training step. + + Raises: + ValueError: If features list is not provided for model training. + """ + if self.features_list is None: + raise ValueError("Features list is required for model training.") # Initialize access to weights and biases wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev") wandb_login(key=wandb_key) diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 2b011cc4e..5bd8e6319 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -27,7 +27,9 @@ class LocusToGeneModel: """Wrapper for the Locus to Gene classifier.""" model: Any = GradientBoostingClassifier(random_state=42) - features_list: list[str] = field(default_factory=list) + features_list: list[str] = field( + default_factory=list + ) # TODO: default to list in config if not provided hyperparameters: dict[str, Any] = field( default_factory=lambda: { "n_estimators": 100, @@ -59,7 +61,7 @@ def load_from_disk( Args: path (str): Path to the model - **kwargs: Keyword arguments to pass to the constructor + **kwargs(Any): Keyword arguments to pass to the constructor Returns: LocusToGeneModel: L2G model loaded from disk @@ -104,7 +106,11 @@ def load_from_hub( """ def get_features_list_from_metadata() -> list[str]: - """Get the features list (in the right order) from the metadata file from the Hub.""" + """Get the features list (in the right order) from the metadata JSON file downloaded from the Hub. 
+ + Returns: + list[str]: Features list + """ import json model_config_path = str(Path(local_path) / "config.json") From 48b78ab01c272389078cbd20404fb74be81f1863 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 28 Jan 2025 12:40:27 +0000 Subject: [PATCH 14/38] feat(l2g): pass default features list if model is loaded from a path --- src/gentropy/dataset/l2g_prediction.py | 13 ++++++++++++- src/gentropy/l2g.py | 1 + src/gentropy/method/l2g/model.py | 4 +--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 3ff45f598..e3d223f78 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -51,6 +51,7 @@ def from_credible_set( credible_set: StudyLocus, feature_matrix: L2GFeatureMatrix, model_path: str | None, + features_list: list[str] | None = None, hf_token: str | None = None, download_from_hub: bool = True, ) -> L2GPrediction: @@ -61,11 +62,15 @@ def from_credible_set( credible_set (StudyLocus): Dataset containing credible sets from GWAS only feature_matrix (L2GFeatureMatrix): Dataset containing all credible sets and their annotations model_path (str | None): Path to the model file. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name). + features_list (list[str] | None): Default list of features the model uses. Only used if the model is not downloaded from the Hub. CAUTION: This default list can differ from the actual list the model was trained on. hf_token (str | None): Hugging Face token to download the model from the Hub. Only required if the model is private. download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True. Returns: L2GPrediction: L2G scores for a set of credible sets. + + Raises: + AttributeError: If `features_list` is not provided and the model is not downloaded from the Hub. 
""" # Load the model if download_from_hub: @@ -73,7 +78,13 @@ def from_credible_set( model_id = model_path or "opentargets/locus_to_gene" l2g_model = LocusToGeneModel.load_from_hub(model_id, hf_token) elif model_path: - l2g_model = LocusToGeneModel.load_from_disk(model_path) + if not features_list: + raise AttributeError( + "features_list is required if the model is not downloaded from the Hub" + ) + l2g_model = LocusToGeneModel.load_from_disk( + model_path, features_list=features_list + ) # Prepare data fm = ( diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 9b7993bf3..8177fefc2 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -284,6 +284,7 @@ def run_predict(self) -> None: self.credible_set, self.feature_matrix, model_path=self.model_path, + features_list=self.features_list, hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"), download_from_hub=self.download_from_hub, ) diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 5bd8e6319..baa6083de 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -27,9 +27,7 @@ class LocusToGeneModel: """Wrapper for the Locus to Gene classifier.""" model: Any = GradientBoostingClassifier(random_state=42) - features_list: list[str] = field( - default_factory=list - ) # TODO: default to list in config if not provided + features_list: list[str] = field(default_factory=list) hyperparameters: dict[str, Any] = field( default_factory=lambda: { "n_estimators": 100, From cb9d2e313530e1bf3bfc7dbbd7ae7f1bfa86f0f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Mon, 27 Jan 2025 19:21:49 +0000 Subject: [PATCH 15/38] feat(l2gmodel): add features_list as model attribute and load it from the hub metadata --- src/gentropy/l2g.py | 5 ++++- src/gentropy/method/l2g/model.py | 27 ++++++++++++++++++++++++--- src/gentropy/method/l2g/trainer.py | 21 +++++++++++++-------- 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 705725315..d7728f569 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -74,7 +74,9 @@ def __init__( else None ) target_index = ( - TargetIndex.from_parquet(session, target_index_path, recursiveFileLookup=True) + TargetIndex.from_parquet( + session, target_index_path, recursiveFileLookup=True + ) if target_index_path else None ) @@ -305,6 +307,7 @@ def run_train(self) -> None: l2g_model = LocusToGeneModel( model=GradientBoostingClassifier(random_state=42, loss="log_loss"), hyperparameters=self.hyperparameters, + features_list=self.features_list, ) # Calculate the gold standard features diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index c1aea9e08..cccb4a47b 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -27,6 +27,7 @@ class LocusToGeneModel: """Wrapper for the Locus to Gene classifier.""" model: Any = GradientBoostingClassifier(random_state=42) + features_list: list[str] = field(default_factory=list) hyperparameters: dict[str, Any] = field( default_factory=lambda: { "n_estimators": 100, @@ -51,11 +52,14 @@ def __post_init__(self: LocusToGeneModel) -> None: self.model.set_params(**self.hyperparameters_dict) @classmethod - def load_from_disk(cls: type[LocusToGeneModel], path: str) -> LocusToGeneModel: + def load_from_disk( + cls: type[LocusToGeneModel], path: str, **kwargs: Any + ) -> LocusToGeneModel: """Load a fitted model from disk. 
Args: path (str): Path to the model + **kwargs: Keyword arguments to pass to the constructor Returns: LocusToGeneModel: L2G model loaded from disk @@ -79,7 +83,7 @@ def load_from_disk(cls: type[LocusToGeneModel], path: str) -> LocusToGeneModel: if not loaded_model._is_fitted(): raise ValueError("Model has not been fitted yet.") - return cls(model=loaded_model) + return cls(model=loaded_model, **kwargs) @classmethod def load_from_hub( @@ -98,9 +102,26 @@ def load_from_hub( Returns: LocusToGeneModel: L2G model loaded from the Hugging Face Hub """ + + def get_features_list_from_metadata() -> list[str]: + """Get the features list (in the right order) from the metadata file from the Hub.""" + import json + + model_config_path = str(Path(local_path) / "config.json") + with open(model_config_path) as f: + model_config = json.load(f) + return [ + column + for column in model_config["sklearn"]["columns"] + if column != "studyLocusId" + ] + local_path = Path(model_id) hub_utils.download(repo_id=model_id, dst=local_path, token=hf_token) - return cls.load_from_disk(str(Path(local_path) / model_name)) + features_list = get_features_list_from_metadata() + return cls.load_from_disk( + str(Path(local_path) / model_name), features_list=features_list + ) @property def hyperparameters_dict(self) -> dict[str, Any]: diff --git a/src/gentropy/method/l2g/trainer.py b/src/gentropy/method/l2g/trainer.py index a43d6609d..a123cfda9 100644 --- a/src/gentropy/method/l2g/trainer.py +++ b/src/gentropy/method/l2g/trainer.py @@ -89,15 +89,20 @@ def fit( Raises: ValueError: Train data not set, nothing to fit. """ - if self.x_train is not None and self.y_train is not None: - assert ( - self.x_train.size != 0 and self.y_train.size != 0 - ), "Train data not set, nothing to fit." + if ( + self.x_train is not None + and self.y_train is not None + and self.features_list is not None + ): + assert self.x_train.size != 0 and self.y_train.size != 0, ( + "Train data not set, nothing to fit." + ) fitted_model = self.model.model.fit(X=self.x_train, y=self.y_train) self.model = LocusToGeneModel( model=fitted_model, hyperparameters=fitted_model.get_params(), training_data=self.feature_matrix, + features_list=self.features_list, ) return self.model raise ValueError("Train data not set, nothing to fit.") @@ -184,9 +189,9 @@ def log_to_wandb( or self.features_list is None ): raise RuntimeError("Train data not set, we cannot log to W&B.") - assert ( - self.x_train.size != 0 and self.y_train.size != 0 - ), "Train data not set, nothing to evaluate." + assert self.x_train.size != 0 and self.y_train.size != 0, ( + "Train data not set, nothing to evaluate." 
+ ) fitted_classifier = self.model.model y_predicted = fitted_classifier.predict(self.x_test) y_probas = fitted_classifier.predict_proba(self.x_test) @@ -456,7 +461,7 @@ def run_all_folds() -> None: cross_validate_single_fold( fold_index=fold_index, sweep_id=sweep_id, - sweep_run_name=f"{wandb_run_name}-fold{fold_index+1}", + sweep_run_name=f"{wandb_run_name}-fold{fold_index + 1}", config=config, ) From eedc6ab7adc293e2dd301e93160296bb9644e9f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 28 Jan 2025 12:11:31 +0000 Subject: [PATCH 16/38] feat(l2g): predict mode to extract feature list from model, not from config --- src/gentropy/dataset/l2g_prediction.py | 20 +++++++++++++------- src/gentropy/l2g.py | 19 ++++++++++++------- src/gentropy/method/l2g/model.py | 12 +++++++++--- 3 files changed, 34 insertions(+), 17 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 02818b7db..ed9d9dd62 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -44,7 +44,6 @@ def from_credible_set( session: Session, credible_set: StudyLocus, feature_matrix: L2GFeatureMatrix, - features_list: list[str], model_path: str | None, hf_token: str | None = None, download_from_hub: bool = True, @@ -55,7 +54,6 @@ def from_credible_set( session (Session): Session object that contains the Spark session credible_set (StudyLocus): Dataset containing credible sets from GWAS only feature_matrix (L2GFeatureMatrix): Dataset containing all credible sets and their annotations - features_list (list[str]): List of features to use for the model model_path (str | None): Path to the model file. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name). hf_token (str | None): Hugging Face token to download the model from the Hub. Only required if the model is private. download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True. @@ -82,7 +80,7 @@ def from_credible_set( ) ) .fill_na() - .select_features(features_list) + .select_features(l2g_model.features_list) ) return l2g_model.predict(fm, session) @@ -129,17 +127,22 @@ def to_disease_target_evidence( ) def add_locus_to_gene_features( - self: L2GPrediction, feature_matrix: L2GFeatureMatrix, features_list: list[str] + self: L2GPrediction, + feature_matrix: L2GFeatureMatrix, ) -> L2GPrediction: """Add features used to extract the L2G predictions. 
Args: feature_matrix (L2GFeatureMatrix): Feature matrix dataset - features_list (list[str]): List of features used in the model Returns: L2GPrediction: L2G predictions with additional features + + Raises: + ValueError: If model is not set, feature list won't be available """ + if self.model is None: + raise ValueError("Model not set, feature annotation cannot be created.") # Testing if `locusToGeneFeatures` column already exists: if "locusToGeneFeatures" in self.df.columns: self.df = self.df.drop("locusToGeneFeatures") @@ -150,7 +153,10 @@ def add_locus_to_gene_features( "locusToGeneFeatures", f.create_map( *sum( - ((f.lit(feature), f.col(feature)) for feature in features_list), + ( + (f.lit(feature), f.col(feature)) + for feature in self.model.features_list + ), (), ) ), @@ -159,7 +165,7 @@ def add_locus_to_gene_features( "locusToGeneFeatures", f.expr("map_filter(locusToGeneFeatures, (k, v) -> v != 0)"), ) - .drop(*features_list) + .drop(*self.model.features_list) ) return L2GPrediction( _df=self.df.join( diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index d7728f569..9b7993bf3 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -104,7 +104,6 @@ def __init__( session: Session, *, run_mode: str, - features_list: list[str], hyperparameters: dict[str, Any], download_from_hub: bool, cross_validate: bool, @@ -112,6 +111,7 @@ def __init__( credible_set_path: str, feature_matrix_path: str, model_path: str | None = None, + features_list: list[str] | None, gold_standard_curation_path: str | None = None, variant_index_path: str | None = None, gene_interactions_path: str | None = None, @@ -125,7 +125,6 @@ def __init__( Args: session (Session): Session object that contains the Spark session run_mode (str): Run mode, either 'train' or 'predict' - features_list (list[str]): List of features to use for the model hyperparameters (dict[str, Any]): Hyperparameters for the model download_from_hub (bool): Whether to download the model from Hugging Face Hub cross_validate (bool): Whether to run cross validation (5-fold by default) to train the model. @@ -133,6 +132,7 @@ def __init__( credible_set_path (str): Path to the credible set dataset necessary to build the feature matrix feature_matrix_path (str): Path to the L2G feature matrix input dataset model_path (str | None): Path to the model. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name). 
+ features_list (list[str] | None): List of features to use to train the model gold_standard_curation_path (str | None): Path to the gold standard curation file variant_index_path (str | None): Path to the variant index gene_interactions_path (str | None): Path to the gene interactions dataset @@ -153,7 +153,7 @@ def __init__( self.run_mode = run_mode self.model_path = model_path self.predictions_path = predictions_path - self.features_list = list(features_list) + self.features_list = list(features_list) if features_list else None self.hyperparameters = dict(hyperparameters) self.wandb_run_name = wandb_run_name self.cross_validate = cross_validate @@ -283,7 +283,6 @@ def run_predict(self) -> None: self.session, self.credible_set, self.feature_matrix, - self.features_list, model_path=self.model_path, hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"), download_from_hub=self.download_from_hub, @@ -291,14 +290,20 @@ def run_predict(self) -> None: predictions.filter( f.col("score") >= self.l2g_threshold ).add_locus_to_gene_features( - self.feature_matrix, self.features_list - ).df.coalesce(self.session.output_partitions).write.mode( + self.feature_matrix, + ).explain().df.coalesce(self.session.output_partitions).write.mode( self.session.write_mode ).parquet(self.predictions_path) self.session.logger.info("L2G predictions saved successfully.") def run_train(self) -> None: - """Run the training step.""" + """Run the training step. + + Raises: + ValueError: If features list is not provided for model training. + """ + if self.features_list is None: + raise ValueError("Features list is required for model training.") # Initialize access to weights and biases wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev") wandb_login(key=wandb_key) diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index cccb4a47b..338d2f8f0 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -27,7 +27,9 @@ class LocusToGeneModel: """Wrapper for the Locus to Gene classifier.""" model: Any = GradientBoostingClassifier(random_state=42) - features_list: list[str] = field(default_factory=list) + features_list: list[str] = field( + default_factory=list + ) # TODO: default to list in config if not provided hyperparameters: dict[str, Any] = field( default_factory=lambda: { "n_estimators": 100, @@ -59,7 +61,7 @@ def load_from_disk( Args: path (str): Path to the model - **kwargs: Keyword arguments to pass to the constructor + **kwargs(Any): Keyword arguments to pass to the constructor Returns: LocusToGeneModel: L2G model loaded from disk @@ -104,7 +106,11 @@ def load_from_hub( """ def get_features_list_from_metadata() -> list[str]: - """Get the features list (in the right order) from the metadata file from the Hub.""" + """Get the features list (in the right order) from the metadata JSON file downloaded from the Hub. 
+ + Returns: + list[str]: Features list + """ import json model_config_path = str(Path(local_path) / "config.json") From 624602e6d9488b3be0a132a318689f3a38e98736 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 28 Jan 2025 15:43:11 +0000 Subject: [PATCH 17/38] feat(l2gprediction): add `model` as attribute --- src/gentropy/dataset/l2g_prediction.py | 6 ++++-- src/gentropy/method/l2g/model.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index ed9d9dd62..255722414 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING import pyspark.sql.functions as f @@ -29,6 +29,8 @@ class L2GPrediction(Dataset): confidence of the prediction that a gene is causal to an association. """ + model: LocusToGeneModel | None = field(default=None, repr=False) + @classmethod def get_schema(cls: type[L2GPrediction]) -> StructType: """Provides the schema for the L2GPrediction dataset. @@ -82,7 +84,6 @@ def from_credible_set( .fill_na() .select_features(l2g_model.features_list) ) - return l2g_model.predict(fm, session) def to_disease_target_evidence( @@ -172,4 +173,5 @@ def add_locus_to_gene_features( aggregated_features, on=["studyLocusId", "geneId"], how="left" ), _schema=self.get_schema(), + model=self.model, ) diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 338d2f8f0..9d9011332 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -175,6 +175,7 @@ def predict( return L2GPrediction( _df=session.spark.createDataFrame(feature_matrix_pdf.filter(output_cols)), _schema=L2GPrediction.get_schema(), + model=self, ) def save(self: LocusToGeneModel, path: str) -> None: From 1766ea17eb1516533d6456c12961e48b9030b16b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Mon, 27 Jan 2025 19:21:49 +0000 Subject: [PATCH 18/38] feat(l2gmodel): add features_list as model attribute and load it from the hub metadata --- src/gentropy/l2g.py | 5 ++++- src/gentropy/method/l2g/model.py | 27 ++++++++++++++++++++++++--- src/gentropy/method/l2g/trainer.py | 21 +++++++++++++-------- 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 705725315..d7728f569 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -74,7 +74,9 @@ def __init__( else None ) target_index = ( - TargetIndex.from_parquet(session, target_index_path, recursiveFileLookup=True) + TargetIndex.from_parquet( + session, target_index_path, recursiveFileLookup=True + ) if target_index_path else None ) @@ -305,6 +307,7 @@ def run_train(self) -> None: l2g_model = LocusToGeneModel( model=GradientBoostingClassifier(random_state=42, loss="log_loss"), hyperparameters=self.hyperparameters, + features_list=self.features_list, ) # Calculate the gold standard features diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index c1aea9e08..cccb4a47b 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -27,6 +27,7 @@ class LocusToGeneModel: """Wrapper for the Locus to Gene classifier.""" model: Any = GradientBoostingClassifier(random_state=42) + features_list: list[str] = field(default_factory=list) hyperparameters: dict[str, Any] = field( 
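For context on patches 11-14: the feature list is recovered from the skops metadata that ships alongside the Hub model, and the only field the code reads is `sklearn.columns` inside `config.json`. A hedged sketch of that lookup — the column names below are invented, and only the `studyLocusId` exclusion is taken from `get_features_list_from_metadata` above:

```python
import json

# Hypothetical excerpt of a skops-generated config.json; real files carry
# more metadata and different feature names.
config_text = """
{
  "sklearn": {
    "columns": ["studyLocusId", "distanceTssMean", "eQtlColocH4Maximum"]
  }
}
"""
model_config = json.loads(config_text)
features_list = [
    column
    for column in model_config["sklearn"]["columns"]
    if column != "studyLocusId"  # identifier column, not a model feature
]
print(features_list)  # ['distanceTssMean', 'eQtlColocH4Maximum']
```

Because the columns are stored in training order, loading them from the metadata also guarantees the ordering that the SHAP explainer expects.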
+ + Returns: + list[str]: Features list + """ import json model_config_path = str(Path(local_path) / "config.json") From 0f5c2448db92f20e1b61cefd6e0967efb3e8c905 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 28 Jan 2025 15:43:11 +0000 Subject: [PATCH 20/38] feat(l2gprediction): add `model` as attribute --- src/gentropy/dataset/l2g_prediction.py | 6 ++++-- src/gentropy/method/l2g/model.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index ed9d9dd62..255722414 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING import pyspark.sql.functions as f @@ -29,6 +29,8 @@ class L2GPrediction(Dataset): confidence of the prediction that a gene is causal to an association. """ + model: LocusToGeneModel | None = field(default=None, repr=False) + @classmethod def get_schema(cls: type[L2GPrediction]) -> StructType: """Provides the schema for the L2GPrediction dataset. @@ -82,7 +84,6 @@ def from_credible_set( .fill_na() .select_features(l2g_model.features_list) ) - return l2g_model.predict(fm, session) def to_disease_target_evidence( @@ -172,4 +173,5 @@ def add_locus_to_gene_features( aggregated_features, on=["studyLocusId", "geneId"], how="left" ), _schema=self.get_schema(), + model=self.model, ) diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 338d2f8f0..9d9011332 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -175,6 +175,7 @@ def predict( return L2GPrediction( _df=session.spark.createDataFrame(feature_matrix_pdf.filter(output_cols)), _schema=L2GPrediction.get_schema(), + model=self, ) def save(self: LocusToGeneModel, path: str) -> None: From 7423cac4c0d20dd61a0abf6e8ce812ec68534849 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 28 Jan 2025 18:03:42 +0000 Subject: [PATCH 21/38] chore: fix typo --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a1a022b1d..843a326dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,7 +45,7 @@ repos: - id: python-check-blanket-noqa - repo: https://github.com/hadialqattan/pycln - rev: v2.4.0 + rev: v2.5.0 hooks: - id: pycln args: [--all] From 450a45a52bd7bd5f3c82477e401304c5f06cbd6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 28 Jan 2025 18:14:05 +0000 Subject: [PATCH 22/38] chore: remove `convert_map_type_to_columns` --- src/gentropy/common/spark_helpers.py | 32 ---------------------------- 1 file changed, 32 deletions(-) diff --git a/src/gentropy/common/spark_helpers.py b/src/gentropy/common/spark_helpers.py index a77cb1369..08bbe069c 100644 --- a/src/gentropy/common/spark_helpers.py +++ b/src/gentropy/common/spark_helpers.py @@ -886,35 +886,3 @@ def calculate_harmonic_sum(input_array: Column) -> Column: / f.pow(x["pos"], 2) / f.lit(sum(1 / ((i + 1) ** 2) for i in range(1000))), ) - - -def convert_map_type_to_columns(df: DataFrame, map_column: Column) -> list[Column]: - """Convert a MapType column into multiple columns, one for each key in the map. 
- - Args: - df (DataFrame): A Spark DataFrame - map_column (Column): A Spark Column of MapType - - Returns: - list[Column]: List of columns, one for each key in the map - - Examples: - >>> df = spark.createDataFrame([({'a': 1, 'b': 2},), ({'c':3},)], ["map_col"]) - >>> df.select(*convert_map_type_to_columns(df, f.col("map_col"))).show() - +----+----+----+ - | a| b| c| - +----+----+----+ - | 1| 2|null| - |null|null| 3| - +----+----+----+ - - """ - # Schema is agnostic of the map keys, I have to collect them first - keys = ( - df.select(f.explode(map_column)) - .select("key") - .distinct() - .rdd.flatMap(lambda x: x) - .collect() - ) - return [map_column.getItem(k).alias(k) for k in keys] From 2bf111254b19a95e839d6807c712836187aab459 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 28 Jan 2025 18:29:19 +0000 Subject: [PATCH 23/38] feat(l2gprediction): refactor feature annotation and change schema --- .../assets/schemas/l2g_predictions.json | 25 +++- src/gentropy/dataset/l2g_prediction.py | 141 +++++++++--------- src/gentropy/l2g.py | 4 +- 3 files changed, 89 insertions(+), 81 deletions(-) diff --git a/src/gentropy/assets/schemas/l2g_predictions.json b/src/gentropy/assets/schemas/l2g_predictions.json index 8bda086a3..c36e63979 100644 --- a/src/gentropy/assets/schemas/l2g_predictions.json +++ b/src/gentropy/assets/schemas/l2g_predictions.json @@ -21,13 +21,28 @@ }, { "metadata": {}, - "name": "locusToGeneFeatures", + "name": "features", "nullable": true, "type": { - "keyType": "string", - "type": "map", - "valueContainsNull": true, - "valueType": "float" + "containsNull": false, + "elementType": { + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": false, + "type": "string" + }, + { + "metadata": {}, + "name": "value", + "nullable": false, + "type": "float" + } + ], + "type": "struct" + }, + "type": "array" } }, { diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 661aa3b82..e0d647f4d 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -13,7 +13,6 @@ from gentropy.common.schemas import parse_spark_schema from gentropy.common.session import Session -from gentropy.common.spark_helpers import convert_map_type_to_columns from gentropy.dataset.dataset import Dataset from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.study_index import StudyIndex @@ -142,7 +141,7 @@ def to_disease_target_evidence( ) ) - def explain(self: L2GPrediction) -> L2GPrediction: + def explain(self: L2GPrediction) -> NotImplementedError: """Extract Shapley values for the L2G predictions and add them as a map in an additional column. Returns: @@ -151,50 +150,52 @@ def explain(self: L2GPrediction) -> L2GPrediction: Raises: ValueError: If the model is not set """ - if self.model is None: - raise ValueError("Model not set, explainer cannot be created") - - explainer = shap.TreeExplainer( - self.model.model, feature_perturbation="tree_path_dependent" - ) - df_w_features = self.df.select( - "*", *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) - ).drop("shapleyValues") - # The matrix needs to present the features in the same order that the model was trained on - features_list = self.model.features_list - pdf = df_w_features.select(*features_list).toPandas() - - # Calculate SHAP values - if pdf.shape[0] >= 10_000: - logging.warning( - "Calculating SHAP values for more than 10,000 rows. This may take a while..." 
- ) - shap_values = explainer.shap_values(pdf.to_numpy()) - for i, feature in enumerate(features_list): - pdf[f"shap_{feature}"] = [row[i] for row in shap_values] - - spark_session = df_w_features.sparkSession - return L2GPrediction( - _df=df_w_features.join( - # Convert df with shapley values to Spark and join with original df - spark_session.createDataFrame(pdf.to_dict(orient="records")), - features_list, - ) - .withColumn( - "shapleyValues", - f.create_map( - *sum( - ((f.lit(col), f.col(f"shap_{col}")) for col in features_list), - (), - ) - ), - ) - .select(*self.get_schema().names), - _schema=self.get_schema(), - model=self.model, - ) - - def add_locus_to_gene_features( + return NotImplementedError + + # if self.model is None: + # raise ValueError("Model not set, explainer cannot be created") + + # explainer = shap.TreeExplainer( + # self.model.model, feature_perturbation="tree_path_dependent" + # ) + # df_w_features = self.df.select( + # "*", *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) + # ).drop("shapleyValues") + # # The matrix needs to present the features in the same order that the model was trained on + # features_list = self.model.features_list + # pdf = df_w_features.select(*features_list).toPandas() + + # # Calculate SHAP values + # if pdf.shape[0] >= 10_000: + # logging.warning( + # "Calculating SHAP values for more than 10,000 rows. This may take a while..." + # ) + # shap_values = explainer.shap_values(pdf.to_numpy()) + # for i, feature in enumerate(features_list): + # pdf[f"shap_{feature}"] = [row[i] for row in shap_values] + + # spark_session = df_w_features.sparkSession + # return L2GPrediction( + # _df=df_w_features.join( + # # Convert df with shapley values to Spark and join with original df + # spark_session.createDataFrame(pdf.to_dict(orient="records")), + # features_list, + # ) + # .withColumn( + # "shapleyValues", + # f.create_map( + # *sum( + # ((f.lit(col), f.col(f"shap_{col}")) for col in features_list), + # (), + # ) + # ), + # ) + # .select(*self.get_schema().names), + # _schema=self.get_schema(), + # model=self.model, + # ) + + def add_features( self: L2GPrediction, feature_matrix: L2GFeatureMatrix, ) -> L2GPrediction: @@ -204,40 +205,34 @@ def add_locus_to_gene_features( feature_matrix (L2GFeatureMatrix): Feature matrix dataset Returns: - L2GPrediction: L2G predictions with additional features + L2GPrediction: L2G predictions with additional column `features` Raises: ValueError: If model is not set, feature list won't be available """ if self.model is None: raise ValueError("Model not set, feature annotation cannot be created.") - # Testing if `locusToGeneFeatures` column already exists: - if "locusToGeneFeatures" in self.df.columns: - self.df = self.df.drop("locusToGeneFeatures") - - # Aggregating all features into a single map column: - aggregated_features = ( - feature_matrix._df.withColumn( - "locusToGeneFeatures", - f.create_map( - *sum( - ( - (f.lit(feature), f.col(feature)) - for feature in self.model.features_list - ), - (), - ) - ), - ) - .withColumn( - "locusToGeneFeatures", - f.expr("map_filter(locusToGeneFeatures, (k, v) -> v != 0)"), - ) - .drop(*self.model.features_list) - ) + # Testing if `features` column already exists: + if "features" in self.df.columns: + self.df = self.df.drop("features") + + features_list = self.model.features_list + feature_expressions = [ + f.struct(f.lit(col).alias("name"), f.col(col).alias("value")) + for col in features_list + ] return L2GPrediction( - _df=self.df.join( - aggregated_features, 
on=["studyLocusId", "geneId"], how="left" + _df=( + self.df.join( + feature_matrix._df.select(*features_list), + on=["studyLocusId", "geneId"], + how="left", + ).select( + "studyLocusId", + "geneId", + "score", + f.array(*feature_expressions).alias("features"), + ) ), _schema=self.get_schema(), model=self.model, diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 8177fefc2..23ff670da 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -288,9 +288,7 @@ def run_predict(self) -> None: hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"), download_from_hub=self.download_from_hub, ) - predictions.filter( - f.col("score") >= self.l2g_threshold - ).add_locus_to_gene_features( + predictions.filter(f.col("score") >= self.l2g_threshold).add_features( self.feature_matrix, ).explain().df.coalesce(self.session.output_partitions).write.mode( self.session.write_mode From 30a4676d5ed19d96a4a3f0b0aaf12c6cab149099 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Jan 2025 18:30:04 +0000 Subject: [PATCH 24/38] chore: pre-commit auto fixes [...] --- src/gentropy/dataset/l2g_prediction.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index e0d647f4d..17d45b8b8 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -2,12 +2,10 @@ from __future__ import annotations -import logging from dataclasses import dataclass, field from typing import TYPE_CHECKING import pyspark.sql.functions as f -import shap from pyspark.sql import DataFrame from pyspark.sql.types import StructType From 9a98332b8db4b1bb049fc96015aa3bf123a0d181 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 4 Feb 2025 10:20:03 +0100 Subject: [PATCH 25/38] feat: report as log odds --- .../assets/schemas/l2g_predictions.json | 25 ++- src/gentropy/dataset/l2g_prediction.py | 163 +++++++++++------- 2 files changed, 118 insertions(+), 70 deletions(-) diff --git a/src/gentropy/assets/schemas/l2g_predictions.json b/src/gentropy/assets/schemas/l2g_predictions.json index c36e63979..62588671d 100644 --- a/src/gentropy/assets/schemas/l2g_predictions.json +++ b/src/gentropy/assets/schemas/l2g_predictions.json @@ -38,6 +38,18 @@ "name": "value", "nullable": false, "type": "float" + }, + { + "metadata": {}, + "name": "shapValue", + "nullable": false, + "type": "float" + }, + { + "metadata": {}, + "name": "shapProbabilityContribution", + "nullable": false, + "type": "float" } ], "type": "struct" @@ -46,15 +58,10 @@ } }, { - "metadata": {}, - "name": "shapleyValues", - "nullable": true, - "type": { - "keyType": "string", - "type": "map", - "valueContainsNull": false, - "valueType": "float" - } + "name": "shapBaseProbability", + "type": "double", + "nullable": false, + "metadata": {} } ] } diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index e0d647f4d..fbeb2cd8a 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -10,9 +10,11 @@ import shap from pyspark.sql import DataFrame from pyspark.sql.types import StructType +from scipy.special import expit from gentropy.common.schemas import parse_spark_schema from gentropy.common.session import Session +from gentropy.common.spark_helpers import pivot_df from gentropy.dataset.dataset import Dataset from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from 
gentropy.dataset.study_index import StudyIndex @@ -20,6 +22,7 @@ from gentropy.method.l2g.model import LocusToGeneModel if TYPE_CHECKING: + from pandas import DataFrame as pd_dataframe from pyspark.sql.types import StructType @@ -141,59 +144,102 @@ def to_disease_target_evidence( ) ) - def explain(self: L2GPrediction) -> NotImplementedError: + def explain( + self: L2GPrediction, feature_matrix: L2GFeatureMatrix | None + ) -> L2GPrediction: """Extract Shapley values for the L2G predictions and add them as a map in an additional column. + Args: + feature_matrix (L2GFeatureMatrix | None): Feature matrix in case the predictions are missing the feature annotation. If None, the features are fetched from the dataset. + Returns: L2GPrediction: L2GPrediction object with additional column containing feature name to Shapley value mappings Raises: - ValueError: If the model is not set + ValueError: If the model is not set or If feature matrix is not provided and the predictions do not have features """ - return NotImplementedError - - # if self.model is None: - # raise ValueError("Model not set, explainer cannot be created") - - # explainer = shap.TreeExplainer( - # self.model.model, feature_perturbation="tree_path_dependent" - # ) - # df_w_features = self.df.select( - # "*", *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) - # ).drop("shapleyValues") - # # The matrix needs to present the features in the same order that the model was trained on - # features_list = self.model.features_list - # pdf = df_w_features.select(*features_list).toPandas() - - # # Calculate SHAP values - # if pdf.shape[0] >= 10_000: - # logging.warning( - # "Calculating SHAP values for more than 10,000 rows. This may take a while..." - # ) - # shap_values = explainer.shap_values(pdf.to_numpy()) - # for i, feature in enumerate(features_list): - # pdf[f"shap_{feature}"] = [row[i] for row in shap_values] - - # spark_session = df_w_features.sparkSession - # return L2GPrediction( - # _df=df_w_features.join( - # # Convert df with shapley values to Spark and join with original df - # spark_session.createDataFrame(pdf.to_dict(orient="records")), - # features_list, - # ) - # .withColumn( - # "shapleyValues", - # f.create_map( - # *sum( - # ((f.lit(col), f.col(f"shap_{col}")) for col in features_list), - # (), - # ) - # ), - # ) - # .select(*self.get_schema().names), - # _schema=self.get_schema(), - # model=self.model, - # ) + # Fetch features if they are not present: + if "features" not in self.df.columns: + if feature_matrix is None: + raise ValueError( + "Feature matrix is required to explain the L2G predictions" + ) + self.add_features(feature_matrix) + + if self.model is None: + raise ValueError("Model not set, explainer cannot be created") + + # Format and pivot the dataframe to pass them before calculating shapley values + pdf = pivot_df( + df=self.df.withColumn("feature", f.explode("features")).select( + "studyLocusId", + "geneId", + "score", + f.col("feature.name").alias("feature_name"), + f.col("feature.value").alias("feature_value"), + ), + pivot_col="feature_name", + value_col="feature_value", + grouping_cols=[f.col("studyLocusId"), f.col("geneId"), f.col("score")], + ).toPandas() + + features_list = self.model.features_list # The matrix needs to present the features in the same order that the model was trained on) + base_value, shap_values = L2GPrediction._explain( + model=self.model, pdf=pdf.filter(items=features_list) + ) + + for i, feature in enumerate(features_list): + pdf[f"shap_{feature}"] = [row[i] 
for row in shap_values] + + spark_session = self.df.sparkSession + return L2GPrediction( + _df=( + spark_session.createDataFrame(pdf.to_dict(orient="records")) + .withColumn( + "features", + f.array( + *( + f.struct( + f.lit(feature).alias("name"), + f.col(feature).alias("value"), + f.col(f"shap_{feature}").alias("shapValue"), + ) + for feature in features_list + ) + ), + ) + .withColumn("shapBaseProbability", f.lit(base_value)) + .select(*L2GPrediction.get_schema().names) + ), + _schema=self.get_schema(), + model=self.model, + ) + + @staticmethod + def _explain( + model: LocusToGeneModel, pdf: pd_dataframe + ) -> tuple[float, list[list[float]]]: + """Calculate SHAP values. Output is log odds ratio (raw mode). + + Args: + model (LocusToGeneModel): L2G model + pdf (pd_dataframe): Pandas dataframe containing the feature matrix in the same order that the model was trained on + + Returns: + tuple[float, list[list[float]]]: A tuple containing: + - base_value (float): Base value of the model + - shap_values (list[list[float]]): SHAP values for prediction + """ + explainer = shap.TreeExplainer( + model.model, feature_perturbation="tree_path_dependent" + ) + if pdf.shape[0] >= 10_000: + logging.warning( + "Calculating SHAP values for more than 10,000 rows. This may take a while..." + ) + shap_values = explainer.shap_values(pdf.to_numpy()) + base_value = expit(explainer.expected_value[0]) + return (base_value, shap_values) def add_features( self: L2GPrediction, @@ -221,19 +267,14 @@ def add_features( f.struct(f.lit(col).alias("name"), f.col(col).alias("value")) for col in features_list ] - return L2GPrediction( - _df=( - self.df.join( - feature_matrix._df.select(*features_list), - on=["studyLocusId", "geneId"], - how="left", - ).select( - "studyLocusId", - "geneId", - "score", - f.array(*feature_expressions).alias("features"), - ) - ), - _schema=self.get_schema(), - model=self.model, + self.df = self.df.join( + feature_matrix._df.select(*features_list), + on=["studyLocusId", "geneId"], + how="left", + ).select( + "studyLocusId", + "geneId", + "score", + f.array(*feature_expressions).alias("features"), ) + return self From 1fc73ca7c6ca3ee9d025ed6806af24c730006cf6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 4 Feb 2025 11:11:50 +0100 Subject: [PATCH 26/38] feat: calculate scaled probabilities --- src/gentropy/dataset/l2g_prediction.py | 33 ++++++++++++++++--- tests/gentropy/dataset/test_l2g_prediction.py | 27 +++++++++++++++ 2 files changed, 55 insertions(+), 5 deletions(-) create mode 100644 tests/gentropy/dataset/test_l2g_prediction.py diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index fbeb2cd8a..7adce1426 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -184,17 +184,19 @@ def explain( ).toPandas() features_list = self.model.features_list # The matrix needs to present the features in the same order that the model was trained on) - base_value, shap_values = L2GPrediction._explain( + _, shap_values = L2GPrediction._explain( model=self.model, pdf=pdf.filter(items=features_list) ) - for i, feature in enumerate(features_list): pdf[f"shap_{feature}"] = [row[i] for row in shap_values] + # Normalise feature contributions so they sum to final probability + scaled_pdf = L2GPrediction._normalise_feature_contributions(pdf) + spark_session = self.df.sparkSession return L2GPrediction( _df=( - spark_session.createDataFrame(pdf.to_dict(orient="records")) + 
spark_session.createDataFrame(scaled_pdf.to_dict(orient="records")) .withColumn( "features", f.array( @@ -202,13 +204,14 @@ def explain( f.struct( f.lit(feature).alias("name"), f.col(feature).alias("value"), - f.col(f"shap_{feature}").alias("shapValue"), + f.col(f"scaled_prob_shap_{feature}").alias( + "scaledProbability" + ), ) for feature in features_list ) ), ) - .withColumn("shapBaseProbability", f.lit(base_value)) .select(*L2GPrediction.get_schema().names) ), _schema=self.get_schema(), @@ -241,6 +244,26 @@ def _explain( base_value = expit(explainer.expected_value[0]) return (base_value, shap_values) + @staticmethod + def _normalise_feature_contributions(pdf: pd_dataframe) -> pd_dataframe: + """Normalise feature contributions. + + Args: + pdf (pd_dataframe): Pandas dataframe containing the SHAP values (log odds) for each feature. + + Returns: + pd_dataframe: Pandas dataframe with normalised feature contributions + """ + shap_cols = [col for col in pdf if col.startswith("shap")] + for col in shap_cols: + pdf[f"prob_{col}"] = expit(pdf[col]) + prob_feature_sum = pdf[[f"prob_{col}" for col in shap_cols]].sum(axis=1) + pdf["scaling_factor"] = pdf["score"] / prob_feature_sum + for col in shap_cols: + pdf[f"scaled_prob_{col}"] = pdf[f"prob_{col}"] * pdf["scaling_factor"] + pdf.drop(columns=[f"prob_{col}" for col in shap_cols], inplace=True) + return pdf + def add_features( self: L2GPrediction, feature_matrix: L2GFeatureMatrix, diff --git a/tests/gentropy/dataset/test_l2g_prediction.py b/tests/gentropy/dataset/test_l2g_prediction.py new file mode 100644 index 000000000..3a4f610ad --- /dev/null +++ b/tests/gentropy/dataset/test_l2g_prediction.py @@ -0,0 +1,27 @@ +"""Test L2G Prediction methods.""" + +import numpy as np +import pandas as pd + +from gentropy.dataset.l2g_prediction import L2GPrediction + + +def test_normalise_feature_contributions() -> None: + """Tests that scaled probabilities per feature add up to the probability inferred by the model.""" + df = pd.DataFrame( + { + "score": [0.163311], # Final probability + "shap_feature1": [-3.850356], + "shap_feature2": [3.015085], + "shap_feature3": [0.063206], + } + ) + scaled_df = L2GPrediction._normalise_feature_contributions(df) + reconstructed_prob = ( + scaled_df["scaled_prob_shap_feature1"].sum() + + scaled_df["scaled_prob_shap_feature2"].sum() + + scaled_df["scaled_prob_shap_feature3"].sum() + ) + assert np.allclose( + reconstructed_prob, df["score"], atol=1e-6 + ), "SHAP probability contributions do not sum to the expected probability." 
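For reference, the normalisation introduced in this patch can be reproduced outside Spark with plain pandas. The sketch below mirrors `_normalise_feature_contributions` as defined above (each per-feature SHAP log-odds value is passed through the logistic sigmoid, then rescaled so the contributions add up to the predicted score) and reuses the toy numbers from the unit test; it is an illustration of the invariant the test asserts, not part of the patch itself.

    import numpy as np
    import pandas as pd
    from scipy.special import expit  # logistic sigmoid, as used in the patch

    # One toy prediction row: the model's final probability ("score") plus
    # the per-feature SHAP values expressed as log odds.
    pdf = pd.DataFrame(
        {
            "score": [0.163311],
            "shap_feature1": [-3.850356],
            "shap_feature2": [3.015085],
            "shap_feature3": [0.063206],
        }
    )

    shap_cols = [col for col in pdf.columns if col.startswith("shap")]
    # Map each log-odds contribution into probability space...
    probs = expit(pdf[shap_cols])
    # ...then rescale so the per-feature contributions sum to the score.
    scaling_factor = pdf["score"] / probs.sum(axis=1)
    scaled = probs.mul(scaling_factor, axis=0)

    # The invariant asserted by the unit test: the scaled contributions
    # reconstruct the probability predicted by the model.
    assert np.allclose(scaled.sum(axis=1), pdf["score"], atol=1e-6)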
From 625992a690da6c119c80edcab09b27136fcccc3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 4 Feb 2025 11:12:29 +0100 Subject: [PATCH 27/38] chore(l2gprediction): remove shapBaseProbability --- src/gentropy/assets/schemas/l2g_predictions.json | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/gentropy/assets/schemas/l2g_predictions.json b/src/gentropy/assets/schemas/l2g_predictions.json index 62588671d..3e1c60135 100644 --- a/src/gentropy/assets/schemas/l2g_predictions.json +++ b/src/gentropy/assets/schemas/l2g_predictions.json @@ -47,7 +47,7 @@ }, { "metadata": {}, - "name": "shapProbabilityContribution", + "name": "scaledProbability", "nullable": false, "type": "float" } @@ -56,12 +56,6 @@ }, "type": "array" } - }, - { - "name": "shapBaseProbability", - "type": "double", - "nullable": false, - "metadata": {} } ] } From 134bc517bf9df4b205a5692c6014a56af0f6ffd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 5 Feb 2025 17:59:56 +0100 Subject: [PATCH 28/38] chore: correct typo in add_features and make schemas non nullable --- .../assets/schemas/l2g_predictions.json | 4 +- src/gentropy/dataset/l2g_prediction.py | 74 +++++++++++++++---- tests/gentropy/dataset/test_l2g_prediction.py | 11 +-- 3 files changed, 67 insertions(+), 22 deletions(-) diff --git a/src/gentropy/assets/schemas/l2g_predictions.json b/src/gentropy/assets/schemas/l2g_predictions.json index 3e1c60135..2043b3bee 100644 --- a/src/gentropy/assets/schemas/l2g_predictions.json +++ b/src/gentropy/assets/schemas/l2g_predictions.json @@ -42,13 +42,13 @@ { "metadata": {}, "name": "shapValue", - "nullable": false, + "nullable": true, "type": "float" }, { "metadata": {}, "name": "scaledProbability", - "nullable": false, + "nullable": true, "type": "float" } ], diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 7adce1426..2a172fefb 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -145,7 +145,7 @@ def to_disease_target_evidence( ) def explain( - self: L2GPrediction, feature_matrix: L2GFeatureMatrix | None + self: L2GPrediction, feature_matrix: L2GFeatureMatrix | None = None ) -> L2GPrediction: """Extract Shapley values for the L2G predictions and add them as a map in an additional column. @@ -184,14 +184,14 @@ def explain( ).toPandas() features_list = self.model.features_list # The matrix needs to present the features in the same order that the model was trained on) - _, shap_values = L2GPrediction._explain( + base_value, shap_values = L2GPrediction._explain( model=self.model, pdf=pdf.filter(items=features_list) ) for i, feature in enumerate(features_list): pdf[f"shap_{feature}"] = [row[i] for row in shap_values] # Normalise feature contributions so they sum to final probability - scaled_pdf = L2GPrediction._normalise_feature_contributions(pdf) + scaled_pdf = L2GPrediction._normalise_feature_contributions(pdf, base_value) spark_session = self.df.sparkSession return L2GPrediction( @@ -245,23 +245,67 @@ def _explain( return (base_value, shap_values) @staticmethod - def _normalise_feature_contributions(pdf: pd_dataframe) -> pd_dataframe: - """Normalise feature contributions. + def _normalise_feature_contributions( + pdf: pd_dataframe, base_log_odds: float + ) -> pd_dataframe: + """Normalize SHAP contributions to probability space while preserving directionality. Args: - pdf (pd_dataframe): Pandas dataframe containing the SHAP values (log odds) for each feature. 
+ pdf (pd_dataframe): Input dataframe with SHAP values and scores + base_log_odds (float): Base log-odds from the SHAP explainer Returns: - pd_dataframe: Pandas dataframe with normalised feature contributions + pd_dataframe: Output dataframe with normalized probability contributions """ - shap_cols = [col for col in pdf if col.startswith("shap")] - for col in shap_cols: - pdf[f"prob_{col}"] = expit(pdf[col]) - prob_feature_sum = pdf[[f"prob_{col}" for col in shap_cols]].sum(axis=1) - pdf["scaling_factor"] = pdf["score"] / prob_feature_sum + # Calculate base probability and sigmoid derivative + prob_base = expit(base_log_odds) + sigmoid_slope = prob_base * (1 - prob_base) # Derivative at base log-odds + + # ---------------------------------- + # 1. Linear Approximation Phase + # ---------------------------------- + # Convert SHAP values to directional probability deltas + shap_cols = [col for col in pdf.columns if col.startswith("shap_")] + linear_deltas = pdf[shap_cols] * sigmoid_slope + + # ---------------------------------- + # 2. Base Probability Distribution + # ---------------------------------- + # Calculate total absolute SHAP magnitude per row + total_abs_shap = ( + pdf[shap_cols].abs().sum(axis=1).replace(0, 1) # Avoid division by zero + ) + + # Distribute base probability proportionally to SHAP magnitudes + base_distribution = ( + pdf[shap_cols].abs().div(total_abs_shap, axis=0).mul(prob_base, axis=0) + ) + + # ---------------------------------- + # 3. Contribution Scaling Phase + # ---------------------------------- + # Calculate required probability adjustment + target_diff = pdf["score"] - prob_base + + # Calculate scaling factor for linear deltas + raw_delta_sum = linear_deltas.sum(axis=1).replace( + 0, 1 + ) # Avoid division by zero + scaling_factor = target_diff / raw_delta_sum + + # Scale deltas to match target probability difference + scaled_deltas = linear_deltas.mul(scaling_factor, axis=0) + + # ---------------------------------- + # 4. 
Final Contribution Calculation + # ---------------------------------- + # Combine base distribution and scaled deltas + final_contributions = base_distribution + scaled_deltas + + # Assign results to new columns for col in shap_cols: - pdf[f"scaled_prob_{col}"] = pdf[f"prob_{col}"] * pdf["scaling_factor"] - pdf.drop(columns=[f"prob_{col}" for col in shap_cols], inplace=True) + feature_name = col.replace("shap_", "") + pdf[f"scaled_prob_shap_{feature_name}"] = final_contributions[col] return pdf def add_features( @@ -291,7 +335,7 @@ def add_features( for col in features_list ] self.df = self.df.join( - feature_matrix._df.select(*features_list), + feature_matrix._df.select(*features_list, "studyLocusId", "geneId"), on=["studyLocusId", "geneId"], how="left", ).select( diff --git a/tests/gentropy/dataset/test_l2g_prediction.py b/tests/gentropy/dataset/test_l2g_prediction.py index 3a4f610ad..7a150f519 100644 --- a/tests/gentropy/dataset/test_l2g_prediction.py +++ b/tests/gentropy/dataset/test_l2g_prediction.py @@ -10,13 +10,14 @@ def test_normalise_feature_contributions() -> None: """Tests that scaled probabilities per feature add up to the probability inferred by the model.""" df = pd.DataFrame( { - "score": [0.163311], # Final probability - "shap_feature1": [-3.850356], - "shap_feature2": [3.015085], - "shap_feature3": [0.063206], + "score": [0.45], # Final probability + "shap_feature1": [-3.85], + "shap_feature2": [3.015], + "shap_feature3": [0.063], } ) - scaled_df = L2GPrediction._normalise_feature_contributions(df) + base_log_odds = 0.56 + scaled_df = L2GPrediction._normalise_feature_contributions(df, base_log_odds) reconstructed_prob = ( scaled_df["scaled_prob_shap_feature1"].sum() + scaled_df["scaled_prob_shap_feature2"].sum() From ee44c46e3f3f01345a70a21d166ac494e3b5c7b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 5 Feb 2025 19:35:21 +0100 Subject: [PATCH 29/38] fix: rename columns in pandas df after pivoting --- src/gentropy/dataset/l2g_prediction.py | 27 +++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 2a172fefb..99b0ceee6 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -182,10 +182,19 @@ def explain( value_col="feature_value", grouping_cols=[f.col("studyLocusId"), f.col("geneId"), f.col("score")], ).toPandas() + pdf = pdf.rename( + # trim the suffix that is added after pivoting the df + columns={ + col: col.replace("_feature_value", "") + for col in pdf.columns + if col.endswith("_feature_value") + } + ) features_list = self.model.features_list # The matrix needs to present the features in the same order that the model was trained on) base_value, shap_values = L2GPrediction._explain( - model=self.model, pdf=pdf.filter(items=features_list) + model=self.model, + pdf=pdf.filter(items=features_list), ) for i, feature in enumerate(features_list): pdf[f"shap_{feature}"] = [row[i] for row in shap_values] @@ -203,10 +212,10 @@ def explain( *( f.struct( f.lit(feature).alias("name"), - f.col(feature).alias("value"), - f.col(f"scaled_prob_shap_{feature}").alias( - "scaledProbability" - ), + f.col(feature).cast("float").alias("value"), + f.col(f"scaled_prob_shap_{feature}") + .cast("float") + .alias("scaledProbability"), ) for feature in features_list ) @@ -234,13 +243,17 @@ def _explain( - shap_values (list[list[float]]): SHAP values for prediction """ explainer = shap.TreeExplainer( - 
model.model, feature_perturbation="tree_path_dependent" + model.model, + feature_perturbation="tree_path_dependent", ) if pdf.shape[0] >= 10_000: logging.warning( "Calculating SHAP values for more than 10,000 rows. This may take a while..." ) - shap_values = explainer.shap_values(pdf.to_numpy()) + shap_values = explainer.shap_values( + pdf.to_numpy(), + check_additivity=False, + ) base_value = expit(explainer.expected_value[0]) return (base_value, shap_values) From e927b44b2e222878781d0a4726c15be0218905c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Thu, 6 Feb 2025 12:14:19 +0100 Subject: [PATCH 30/38] fix: add raw shap contributions --- src/gentropy/dataset/l2g_prediction.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 99b0ceee6..4d8a05695 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -213,6 +213,9 @@ def explain( f.struct( f.lit(feature).alias("name"), f.col(feature).cast("float").alias("value"), + f.col(f"shap_{feature}") + .cast("float") + .alias("shapValue"), f.col(f"scaled_prob_shap_{feature}") .cast("float") .alias("scaledProbability"), From cfc4529f37ebce3c1579840f6d9f0cc21c480131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 12 Feb 2025 17:03:26 +0000 Subject: [PATCH 31/38] fix(model): when saving create directory if not exists --- src/gentropy/dataset/l2g_prediction.py | 2 ++ src/gentropy/method/l2g/model.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 2ee198c6e..4d8a05695 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -2,10 +2,12 @@ from __future__ import annotations +import logging from dataclasses import dataclass, field from typing import TYPE_CHECKING import pyspark.sql.functions as f +import shap from pyspark.sql import DataFrame from pyspark.sql.types import StructType from scipy.special import expit diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 3e7775ece..984b7ef51 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -194,6 +194,8 @@ def save(self: LocusToGeneModel, path: str) -> None: sio.dump(self.model, local_path) copy_to_gcs(local_path, path) else: + # create directory if path does not exist + Path(path).parent.mkdir(parents=True, exist_ok=True) sio.dump(self.model, path) @staticmethod From fc32ba4fb1fe150792830800e9b39e52616c6bfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 12 Feb 2025 17:19:31 +0000 Subject: [PATCH 32/38] feat(l2g): bundle model and training data in hf --- src/gentropy/method/l2g/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 984b7ef51..6ffeff659 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -269,7 +269,7 @@ def export_to_hugging_face_hub( repo_id: str = "opentargets/locus_to_gene", local_repo: str = "locus_to_gene", ) -> None: - """Share the model on Hugging Face Hub. + """Share the model and training dataset on Hugging Face Hub. Args: model_path (str): The path to the L2G model file. 
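The hunk below is the heart of this patch: the training set is written next to the serialized classifier before the repository is pushed, so both artefacts always travel together. Taken in isolation, the pattern looks roughly like the following sketch; it assumes the `skops.hub_utils` API already used in this module, and the function name, file layout and commit message are illustrative.

    from pathlib import Path

    import pandas as pd
    from skops import hub_utils

    def bundle_and_push(
        local_repo: str, training_data: pd.DataFrame, repo_id: str, token: str
    ) -> None:
        """Upload a skops model directory together with its training set (sketch)."""
        # Write the training set into the prepared repo directory, next to
        # classifier.skops and the generated model card.
        training_data.to_parquet(Path(local_repo) / "training_set.parquet")
        # Push the whole directory to the Hub in a single commit.
        hub_utils.push(
            repo_id=repo_id,
            source=local_repo,
            token=token,
            commit_message="chore: update model and bundled training set",
        )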
@@ -293,6 +293,7 @@ def export_to_hugging_face_hub( data=data, ) self._create_hugging_face_model_card(local_repo) + data.to_parquet(f"{local_repo}/training_set.parquet") hub_utils.push( repo_id=repo_id, source=local_repo, From 37b83acc72fe01a660b5c520095d9fb905bfe03a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Thu, 13 Feb 2025 08:32:04 +0000 Subject: [PATCH 33/38] feat(model): include data when loading model --- src/gentropy/dataset/l2g_prediction.py | 4 +-- src/gentropy/l2g.py | 2 +- src/gentropy/method/l2g/model.py | 43 +++++++++++++++++++------- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 4d8a05695..973baf382 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -78,14 +78,14 @@ def from_credible_set( if download_from_hub: # Model ID defaults to "opentargets/locus_to_gene" and it assumes the name of the classifier is "classifier.skops". model_id = model_path or "opentargets/locus_to_gene" - l2g_model = LocusToGeneModel.load_from_hub(model_id, hf_token) + l2g_model = LocusToGeneModel.load_from_hub(session, model_id, hf_token) elif model_path: if not features_list: raise AttributeError( "features_list is required if the model is not downloaded from the Hub" ) l2g_model = LocusToGeneModel.load_from_disk( - model_path, features_list=features_list + session, path=model_path, features_list=features_list ) # Prepare data diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 23ff670da..39b7fcca1 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -330,7 +330,7 @@ def run_train(self) -> None: "hfhub-key", "open-targets-genetics-dev" ) trained_model.export_to_hugging_face_hub( - # we upload the model in the filesystem + # we upload the model saved in the filesystem self.model_path.split("/")[-1], hf_hub_token, data=trained_model.training_data._df.drop( diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 6ffeff659..a13260974 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -16,9 +16,9 @@ from gentropy.common.session import Session from gentropy.common.utils import copy_to_gcs +from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix if TYPE_CHECKING: - from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.l2g_prediction import L2GPrediction @@ -53,12 +53,18 @@ def __post_init__(self: LocusToGeneModel) -> None: @classmethod def load_from_disk( - cls: type[LocusToGeneModel], path: str, **kwargs: Any + cls: type[LocusToGeneModel], + session: Session, + path: str, + model_name: str = "classifier.skops", + **kwargs: Any, ) -> LocusToGeneModel: """Load a fitted model from disk. Args: - path (str): Path to the model + session (Session): Session object that loads the training data + path (str): Path to the directory containing model and metadata + model_name (str): Name of the persisted model to load. Defaults to "classifier.skops". 
**kwargs(Any): Keyword arguments to pass to the constructor Returns: @@ -67,8 +73,9 @@ def load_from_disk( Raises: ValueError: If the model has not been fitted yet """ - if path.startswith("gs://"): - path = path.removeprefix("gs://") + model_path = (Path(path) / model_name).as_posix() + if model_path.startswith("gs://"): + path = model_path.removeprefix("gs://") bucket_name = path.split("/")[0] blob_name = "/".join(path.split("/")[1:]) from google.cloud import storage @@ -79,25 +86,37 @@ def load_from_disk( data = blob.download_as_string(client=client) loaded_model = sio.loads(data, trusted=sio.get_untrusted_types(data=data)) else: - loaded_model = sio.load(path, trusted=sio.get_untrusted_types(file=path)) + loaded_model = sio.load( + model_path, trusted=sio.get_untrusted_types(file=model_path) + ) + try: + # Try loading the training data if it is in the model directory + training_data = L2GFeatureMatrix( + _df=session.load_data( + (Path(path) / "training_data.parquet").as_posix() + ), + features_list=kwargs.get("features_list"), + ) + except Exception: + training_data = None if not loaded_model._is_fitted(): raise ValueError("Model has not been fitted yet.") - return cls(model=loaded_model, **kwargs) + return cls(model=loaded_model, training_data=training_data, **kwargs) @classmethod def load_from_hub( cls: type[LocusToGeneModel], + session: Session, model_id: str, hf_token: str | None = None, - model_name: str = "classifier.skops", ) -> LocusToGeneModel: """Load a model from the Hugging Face Hub. This will download the model from the hub and load it from disk. Args: + session (Session): Session object to load the training data model_id (str): Model ID on the Hugging Face Hub hf_token (str | None): Hugging Face Hub token to download the model (only required if private) - model_name (str): Name of the persisted model to load. Defaults to "classifier.skops". 
Returns: LocusToGeneModel: L2G model loaded from the Hugging Face Hub @@ -120,11 +139,13 @@ def get_features_list_from_metadata() -> list[str]: if column != "studyLocusId" ] - local_path = Path(model_id) + local_path = model_id hub_utils.download(repo_id=model_id, dst=local_path, token=hf_token) features_list = get_features_list_from_metadata() return cls.load_from_disk( - str(Path(local_path) / model_name), features_list=features_list + session, + local_path, + features_list=features_list, ) @property From 62f45b4845e8486e2ac3cb80dd2d3acefba605c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Thu, 13 Feb 2025 11:25:18 +0000 Subject: [PATCH 34/38] feat: final version of shap explanations --- .../assets/schemas/l2g_predictions.json | 12 +-- src/gentropy/dataset/l2g_prediction.py | 91 ++++--------------- tests/gentropy/dataset/test_l2g_prediction.py | 28 ------ 3 files changed, 22 insertions(+), 109 deletions(-) delete mode 100644 tests/gentropy/dataset/test_l2g_prediction.py diff --git a/src/gentropy/assets/schemas/l2g_predictions.json b/src/gentropy/assets/schemas/l2g_predictions.json index 2043b3bee..1d100bf94 100644 --- a/src/gentropy/assets/schemas/l2g_predictions.json +++ b/src/gentropy/assets/schemas/l2g_predictions.json @@ -44,18 +44,18 @@ "name": "shapValue", "nullable": true, "type": "float" - }, - { - "metadata": {}, - "name": "scaledProbability", - "nullable": true, - "type": "float" } ], "type": "struct" }, "type": "array" } + }, + { + "name": "shapBaseValue", + "type": "float", + "nullable": true, + "metadata": {} } ] } diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 973baf382..a04a4507a 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -10,7 +10,6 @@ import shap from pyspark.sql import DataFrame from pyspark.sql.types import StructType -from scipy.special import expit from gentropy.common.schemas import parse_spark_schema from gentropy.common.session import Session @@ -199,13 +198,10 @@ def explain( for i, feature in enumerate(features_list): pdf[f"shap_{feature}"] = [row[i] for row in shap_values] - # Normalise feature contributions so they sum to final probability - scaled_pdf = L2GPrediction._normalise_feature_contributions(pdf, base_value) - spark_session = self.df.sparkSession return L2GPrediction( _df=( - spark_session.createDataFrame(scaled_pdf.to_dict(orient="records")) + spark_session.createDataFrame(pdf.to_dict(orient="records")) .withColumn( "features", f.array( @@ -216,14 +212,12 @@ def explain( f.col(f"shap_{feature}") .cast("float") .alias("shapValue"), - f.col(f"scaled_prob_shap_{feature}") - .cast("float") - .alias("scaledProbability"), ) for feature in features_list ) ), ) + .withColumn("shapBaseValue", f.lit(base_value).cast("float")) .select(*L2GPrediction.get_schema().names) ), _schema=self.get_schema(), @@ -234,7 +228,7 @@ def explain( def _explain( model: LocusToGeneModel, pdf: pd_dataframe ) -> tuple[float, list[list[float]]]: - """Calculate SHAP values. Output is log odds ratio (raw mode). + """Calculate SHAP values. Output is in probability form (approximated from the log odds ratios). 
Args: model (LocusToGeneModel): L2G model @@ -244,10 +238,21 @@ def _explain( tuple[float, list[list[float]]]: A tuple containing: - base_value (float): Base value of the model - shap_values (list[list[float]]): SHAP values for prediction + + Raises: + AttributeError: If model.training_data is not set, seed dataset to get shapley values cannot be created. """ + if not model.training_data: + raise AttributeError( + "`model.training_data` is missing, seed dataset to get shapley values cannot be created." + ) + background_data = model.training_data._df.select( + *model.features_list + ).toPandas() explainer = shap.TreeExplainer( model.model, - feature_perturbation="tree_path_dependent", + data=background_data, + model_output="probability", ) if pdf.shape[0] >= 10_000: logging.warning( @@ -257,73 +262,9 @@ def _explain( pdf.to_numpy(), check_additivity=False, ) - base_value = expit(explainer.expected_value[0]) + base_value = explainer.expected_value return (base_value, shap_values) - @staticmethod - def _normalise_feature_contributions( - pdf: pd_dataframe, base_log_odds: float - ) -> pd_dataframe: - """Normalize SHAP contributions to probability space while preserving directionality. - - Args: - pdf (pd_dataframe): Input dataframe with SHAP values and scores - base_log_odds (float): Base log-odds from the SHAP explainer - - Returns: - pd_dataframe: Output dataframe with normalized probability contributions - """ - # Calculate base probability and sigmoid derivative - prob_base = expit(base_log_odds) - sigmoid_slope = prob_base * (1 - prob_base) # Derivative at base log-odds - - # ---------------------------------- - # 1. Linear Approximation Phase - # ---------------------------------- - # Convert SHAP values to directional probability deltas - shap_cols = [col for col in pdf.columns if col.startswith("shap_")] - linear_deltas = pdf[shap_cols] * sigmoid_slope - - # ---------------------------------- - # 2. Base Probability Distribution - # ---------------------------------- - # Calculate total absolute SHAP magnitude per row - total_abs_shap = ( - pdf[shap_cols].abs().sum(axis=1).replace(0, 1) # Avoid division by zero - ) - - # Distribute base probability proportionally to SHAP magnitudes - base_distribution = ( - pdf[shap_cols].abs().div(total_abs_shap, axis=0).mul(prob_base, axis=0) - ) - - # ---------------------------------- - # 3. Contribution Scaling Phase - # ---------------------------------- - # Calculate required probability adjustment - target_diff = pdf["score"] - prob_base - - # Calculate scaling factor for linear deltas - raw_delta_sum = linear_deltas.sum(axis=1).replace( - 0, 1 - ) # Avoid division by zero - scaling_factor = target_diff / raw_delta_sum - - # Scale deltas to match target probability difference - scaled_deltas = linear_deltas.mul(scaling_factor, axis=0) - - # ---------------------------------- - # 4. 
Final Contribution Calculation - # ---------------------------------- - # Combine base distribution and scaled deltas - final_contributions = base_distribution + scaled_deltas - - # Assign results to new columns - for col in shap_cols: - feature_name = col.replace("shap_", "") - pdf[f"scaled_prob_shap_{feature_name}"] = final_contributions[col] - return pdf - def add_features( self: L2GPrediction, feature_matrix: L2GFeatureMatrix, diff --git a/tests/gentropy/dataset/test_l2g_prediction.py b/tests/gentropy/dataset/test_l2g_prediction.py deleted file mode 100644 index 7a150f519..000000000 --- a/tests/gentropy/dataset/test_l2g_prediction.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Test L2G Prediction methods.""" - -import numpy as np -import pandas as pd - -from gentropy.dataset.l2g_prediction import L2GPrediction - - -def test_normalise_feature_contributions() -> None: - """Tests that scaled probabilities per feature add up to the probability inferred by the model.""" - df = pd.DataFrame( - { - "score": [0.45], # Final probability - "shap_feature1": [-3.85], - "shap_feature2": [3.015], - "shap_feature3": [0.063], - } - ) - base_log_odds = 0.56 - scaled_df = L2GPrediction._normalise_feature_contributions(df, base_log_odds) - reconstructed_prob = ( - scaled_df["scaled_prob_shap_feature1"].sum() - + scaled_df["scaled_prob_shap_feature2"].sum() - + scaled_df["scaled_prob_shap_feature3"].sum() - ) - assert np.allclose( - reconstructed_prob, df["score"], atol=1e-6 - ), "SHAP probability contributions do not sum to the expected probability." From c635e18497e43d8d49aa49eac07a5cd9fab39f66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Thu, 13 Feb 2025 14:43:53 +0000 Subject: [PATCH 35/38] fix: do not infer features_list from df --- src/gentropy/dataset/l2g_prediction.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index a04a4507a..adaf7751b 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -95,7 +95,8 @@ def from_credible_set( .select("studyLocusId") .join(feature_matrix._df, "studyLocusId") .filter(f.col("isProteinCoding") == 1) - ) + ), + features_list=l2g_model.features_list, ) .fill_na() .select_features(l2g_model.features_list) From 58f35d9022ca7df4fa55002d18100afd5895c5f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Thu, 13 Feb 2025 16:04:23 +0000 Subject: [PATCH 36/38] fix: get_features_list_from_metadata returned cols that were not features --- src/gentropy/dataset/l2g_prediction.py | 1 - src/gentropy/method/l2g/model.py | 8 +++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index adaf7751b..d7477a354 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -96,7 +96,6 @@ def from_credible_set( .join(feature_matrix._df, "studyLocusId") .filter(f.col("isProteinCoding") == 1) ), - features_list=l2g_model.features_list, ) .fill_na() .select_features(l2g_model.features_list) diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index a13260974..d0c0beeb8 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -136,7 +136,13 @@ def get_features_list_from_metadata() -> list[str]: return [ column for column in model_config["sklearn"]["columns"] - if column != "studyLocusId" + if column + not in [ + "studyLocusId", 
+                    "geneId",
+                    "traitFromSourceMappedId",
+                    "goldStandardSet",
+                ]
             ]
 
         local_path = model_id
From d45acea651028c9b41f755797dc9c1e08ef1eff1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Thu, 13 Feb 2025 16:56:23 +0000
Subject: [PATCH 37/38] refactor(model): read training data in the local filesystem with pandas

---
 src/gentropy/l2g.py              |  2 +-
 src/gentropy/method/l2g/model.py | 11 ++++++++---
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py
index 39b7fcca1..be3e42612 100644
--- a/src/gentropy/l2g.py
+++ b/src/gentropy/l2g.py
@@ -168,7 +168,7 @@ def __init__(
         # Load common inputs
         self.credible_set = StudyLocus.from_parquet(
             session, credible_set_path, recursiveFileLookup=True
-        )
+        ).filter(f.col("studyLocusId") == "2089b267ff0a27715af4b75d81abd834")
         self.feature_matrix = L2GFeatureMatrix(
             _df=session.load_data(feature_matrix_path),
         )
diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py
index d0c0beeb8..8d31597ee 100644
--- a/src/gentropy/method/l2g/model.py
+++ b/src/gentropy/method/l2g/model.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import json
+import logging
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
@@ -92,12 +93,16 @@ def load_from_disk(
         try:
             # Try loading the training data if it is in the model directory
             training_data = L2GFeatureMatrix(
-                _df=session.load_data(
-                    (Path(path) / "training_data.parquet").as_posix()
+                _df=session.spark.createDataFrame(
+                    # Parquet is read with Pandas to easily read local files
+                    pd.read_parquet(
+                        (Path(path) / "training_data.parquet").as_posix()
+                    )
                 ),
                 features_list=kwargs.get("features_list"),
             )
-        except Exception:
+        except Exception as e:
+            logging.error("Training data set to none. Error: %s", e)
             training_data = None
 
         if not loaded_model._is_fitted():
From 7b826a430037dea7bbee0287dce36fc2fe656092 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Thu, 13 Feb 2025 17:05:15 +0000
Subject: [PATCH 38/38] chore: remove temporary debug filter after successful run

---
 src/gentropy/l2g.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py
index be3e42612..39b7fcca1 100644
--- a/src/gentropy/l2g.py
+++ b/src/gentropy/l2g.py
@@ -168,7 +168,7 @@ def __init__(
         # Load common inputs
         self.credible_set = StudyLocus.from_parquet(
             session, credible_set_path, recursiveFileLookup=True
-        ).filter(f.col("studyLocusId") == "2089b267ff0a27715af4b75d81abd834")
+        )
         self.feature_matrix = L2GFeatureMatrix(
             _df=session.load_data(feature_matrix_path),
         )
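Taken together, the series leaves the predict-and-explain flow looking roughly like the sketch below. The session construction and the two input paths are illustrative, the call signatures follow the final state of the patches, and the model (together with its feature order and the bundled training set used as SHAP background data) comes from the Hugging Face Hub when `download_from_hub` is set.

    from gentropy.common.session import Session
    from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
    from gentropy.dataset.l2g_prediction import L2GPrediction
    from gentropy.dataset.study_locus import StudyLocus

    session = Session()  # illustrative; configure Spark as in the step config

    credible_set = StudyLocus.from_parquet(
        session, "gs://bucket/credible_set"  # illustrative path
    )
    feature_matrix = L2GFeatureMatrix(
        _df=session.load_data("gs://bucket/feature_matrix")  # illustrative path
    )

    # Score the credible sets; the classifier defaults to the
    # "opentargets/locus_to_gene" repository on the Hub.
    predictions = L2GPrediction.from_credible_set(
        session,
        credible_set,
        feature_matrix,
        model_path=None,
        download_from_hub=True,
    )

    # `add_features` attaches the raw feature values; `explain` then adds the
    # per-feature SHAP values (in probability space) and the explainer's base value.
    explained = predictions.add_features(feature_matrix).explain()
    explained.df.show()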