feat(l2gprediction): add score explanation based on features #939

Open · wants to merge 49 commits into base: dev

Changes from all commits (49 commits)
72259fc
feat(prediction): add `model` as instance attribute
ireneisdoomed Dec 3, 2024
9e8c491
feat: added `convert_map_type_to_columns` spark util
ireneisdoomed Dec 3, 2024
450a937
feat(prediction): new method `explain` returns shapley values
ireneisdoomed Dec 3, 2024
08ae6bd
feat(prediction): `explain` returns predictions with shapley values
ireneisdoomed Dec 4, 2024
9d40e62
chore: compute `shapleyValues` in the l2g step
ireneisdoomed Dec 4, 2024
125425f
Merge branch 'dev' of https://github.com/opentargets/gentropy into il…
ireneisdoomed Dec 4, 2024
f407512
refactor: use pandas udf instead
ireneisdoomed Dec 5, 2024
f542395
refactor: forget about udfs and get shaps single threaded
ireneisdoomed Dec 6, 2024
9403fe6
chore: remove reference to chromatin interaction data in HF card
ireneisdoomed Dec 6, 2024
1bc6f3a
fix(l2g_prediction): methods that return new instance preserve attribute
ireneisdoomed Dec 6, 2024
8420933
feat(dataset): `filter` method preserves all instance attributes
ireneisdoomed Dec 6, 2024
8a85f4f
Merge branch 'dev' of https://github.com/opentargets/gentropy into il…
ireneisdoomed Dec 6, 2024
e6249c0
Merge branch 'dev' of https://github.com/opentargets/gentropy into il…
ireneisdoomed Jan 27, 2025
b987496
feat(l2gmodel): add features_list as model attribute and load it from…
ireneisdoomed Jan 27, 2025
3e99415
chore: merge
ireneisdoomed Jan 28, 2025
12de669
fix: pass correct order of features to shapley explainer
ireneisdoomed Jan 28, 2025
78027da
feat(l2g): predict mode to extract feature list from model, not from …
ireneisdoomed Jan 28, 2025
48b78ab
feat(l2g): pass default features list if model is loaded from a path
ireneisdoomed Jan 28, 2025
cb9d2e3
feat(l2gmodel): add features_list as model attribute and load it from…
ireneisdoomed Jan 27, 2025
eedc6ab
feat(l2g): predict mode to extract feature list from model, not from …
ireneisdoomed Jan 28, 2025
624602e
feat(l2gprediction): add `model` as attribute
ireneisdoomed Jan 28, 2025
e5853b5
chore: fix typo
ireneisdoomed Jan 28, 2025
1766ea1
feat(l2gmodel): add features_list as model attribute and load it from…
ireneisdoomed Jan 27, 2025
c0657dc
feat(l2g): predict mode to extract feature list from model, not from …
ireneisdoomed Jan 28, 2025
0f5c244
feat(l2gprediction): add `model` as attribute
ireneisdoomed Jan 28, 2025
267cb84
Merge branch 'dev' into il-l2g-feature-list-injection
ireneisdoomed Jan 28, 2025
66ea538
Merge branch 'dev' of https://github.com/opentargets/gentropy into il…
ireneisdoomed Jan 28, 2025
875aac2
Merge branch 'dev' of https://github.com/opentargets/gentropy into il…
ireneisdoomed Jan 28, 2025
775856f
Merge branch 'il-l2g-feature-list-injection' of https://github.com/op…
ireneisdoomed Jan 28, 2025
f30c2e3
Merge branch 'il-l2g-feature-list-injection' of https://github.com/op…
ireneisdoomed Jan 28, 2025
7423cac
chore: fix typo
ireneisdoomed Jan 28, 2025
fa3e5df
Merge branch 'il-l2g-feature-list-injection' of https://github.com/op…
ireneisdoomed Jan 28, 2025
450a45a
chore: remove `convert_map_type_to_columns`
ireneisdoomed Jan 28, 2025
2bf1112
feat(l2gprediction): refactor feature annotation and change schema
ireneisdoomed Jan 28, 2025
30a4676
chore: pre-commit auto fixes [...]
pre-commit-ci[bot] Jan 28, 2025
9a98332
feat: report as log odds
ireneisdoomed Feb 4, 2025
1fc73ca
feat: calculate scaled probabilities
ireneisdoomed Feb 4, 2025
625992a
chore(l2gprediction): remove shapBaseProbability
ireneisdoomed Feb 4, 2025
134bc51
chore: correct typo in add_features and make schemas non nullable
ireneisdoomed Feb 5, 2025
ee44c46
fix: rename columns in pandas df after pivoting
ireneisdoomed Feb 5, 2025
e927b44
fix: add raw shap contributions
ireneisdoomed Feb 6, 2025
4bca7c1
chore: merge
ireneisdoomed Feb 6, 2025
7b9aa03
Merge branch 'il-shapley-predictions' of https://github.com/opentarge…
ireneisdoomed Feb 6, 2025
cfc4529
fix(model): when saving create directory if not exists
ireneisdoomed Feb 12, 2025
fc32ba4
feat(l2g): bundle model and training data in hf
ireneisdoomed Feb 12, 2025
37b83ac
feat(model): include data when loading model
ireneisdoomed Feb 13, 2025
62f45b4
feat: final version of shap explanations
ireneisdoomed Feb 13, 2025
c388ff5
Merge branch 'dev' of https://github.com/opentargets/gentropy into il…
ireneisdoomed Feb 13, 2025
c635e18
fix: do not infer features_list from df
ireneisdoomed Feb 13, 2025
37 changes: 32 additions & 5 deletions src/gentropy/assets/schemas/l2g_predictions.json
@@ -21,14 +21,41 @@
},
{
"metadata": {},
"name": "features",
"nullable": true,
"type": {
"containsNull": false,
"elementType": {
"fields": [
{
"metadata": {},
"name": "name",
"nullable": false,
"type": "string"
},
{
"metadata": {},
"name": "value",
"nullable": false,
"type": "float"
},
{
"metadata": {},
"name": "shapValue",
"nullable": true,
"type": "float"
}
],
"type": "struct"
},
"type": "array"
}
},
{
"name": "shapBaseValue",
"type": "float",
"nullable": true,
"metadata": {}
}
]
}
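In the schema change above, the old `locusToGeneFeatures` float map becomes a `features` array of structs, with the SHAP annotation stored inline per feature and the explainer's base value stored once per row. A plain-Python sketch of one record matching the new shape (identifiers and values are invented for illustration):

```python
# Hypothetical record shaped like the new l2g_predictions schema:
# the old float map becomes an array of {name, value, shapValue} structs,
# with the explainer's base value stored once per row.
record = {
    "studyLocusId": "sl_001",         # invented identifier
    "geneId": "ENSG00000000001",      # invented identifier
    "score": 0.83,
    "features": [
        {"name": "distanceTssMean", "value": 0.9, "shapValue": 0.12},
        # shapValue is nullable, so unexplained rows can still carry features:
        {"name": "eQtlColocClppMaximum", "value": 0.0, "shapValue": None},
    ],
    "shapBaseValue": 0.05,
}

# Every struct carries the same three fields, so consumers no longer need
# to know the set of feature names up front, as they did with the map.
assert all({"name", "value", "shapValue"} == set(feat) for feat in record["features"])
```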
14 changes: 12 additions & 2 deletions src/gentropy/dataset/dataset.py
@@ -174,15 +174,25 @@ def from_parquet(
def filter(self: Self, condition: Column) -> Self:
"""Creates a new instance of a Dataset with the DataFrame filtered by the condition.

Preserves all attributes from the original instance.

Args:
condition (Column): Condition to filter the DataFrame

Returns:
Self: Filtered Dataset with preserved attributes
"""
df = self._df.filter(condition)
class_constructor = self.__class__
# Get all attributes from the current instance
attrs = {
key: value
for key, value in self.__dict__.items()
if key not in ["_df", "_schema"]
}
return class_constructor(
_df=df, _schema=class_constructor.get_schema(), **attrs
)

def validate_schema(self: Dataset) -> None:
"""Validate DataFrame schema against expected class schema.
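The `Dataset.filter` change threads every extra instance attribute (for example `L2GPrediction.model`) through to the new instance, instead of rebuilding it from `_df` and `_schema` alone. The pattern in isolation, with a minimal stand-in class (names invented; the real `Dataset` wraps a Spark DataFrame):

```python
class Box:
    """Stand-in for Dataset: holds _df/_schema plus arbitrary extras."""

    def __init__(self, _df, _schema, **kwargs):
        self._df = _df
        self._schema = _schema
        for key, value in kwargs.items():
            setattr(self, key, value)

    def filter(self, keep):
        # Rebuild via the class constructor, forwarding every attribute
        # except _df and _schema -- mirroring the Dataset.filter change above.
        attrs = {
            k: v for k, v in self.__dict__.items() if k not in ["_df", "_schema"]
        }
        return self.__class__(
            _df=[row for row in self._df if keep(row)],
            _schema=self._schema,
            **attrs,
        )


b = Box(_df=[1, 2, 3], _schema="int", model="my-model")
filtered = b.filter(lambda x: x > 1)
assert filtered._df == [2, 3]
assert filtered.model == "my-model"  # the extra attribute survives filtering
```

Without this, chaining `predictions.filter(...).add_features(...)` would drop `model` and the later `explain` call would fail.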
198 changes: 163 additions & 35 deletions src/gentropy/dataset/l2g_prediction.py
@@ -2,21 +2,26 @@

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import pyspark.sql.functions as f
import shap
from pyspark.sql import DataFrame

from gentropy.common.schemas import parse_spark_schema
from gentropy.common.session import Session
from gentropy.common.spark_helpers import pivot_df
from gentropy.dataset.dataset import Dataset
from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
from gentropy.dataset.study_index import StudyIndex
from gentropy.dataset.study_locus import StudyLocus
from gentropy.method.l2g.model import LocusToGeneModel

if TYPE_CHECKING:
from pandas import DataFrame as pd_dataframe
from pyspark.sql.types import StructType


@@ -47,6 +52,7 @@ def from_credible_set(
credible_set: StudyLocus,
feature_matrix: L2GFeatureMatrix,
model_path: str | None,
features_list: list[str] | None = None,
hf_token: str | None = None,
download_from_hub: bool = True,
) -> L2GPrediction:
@@ -57,19 +63,29 @@
credible_set (StudyLocus): Dataset containing credible sets from GWAS only
feature_matrix (L2GFeatureMatrix): Dataset containing all credible sets and their annotations
model_path (str | None): Path to the model file. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name).
features_list (list[str] | None): Default list of features the model uses. Only used if the model is not downloaded from the Hub. CAUTION: This default list can differ from the actual list the model was trained on.
hf_token (str | None): Hugging Face token to download the model from the Hub. Only required if the model is private.
download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True.

Returns:
L2GPrediction: L2G scores for a set of credible sets.

Raises:
AttributeError: If `features_list` is not provided and the model is not downloaded from the Hub.
"""
# Load the model
if download_from_hub:
# Model ID defaults to "opentargets/locus_to_gene" and it assumes the name of the classifier is "classifier.skops".
model_id = model_path or "opentargets/locus_to_gene"
l2g_model = LocusToGeneModel.load_from_hub(session, model_id, hf_token)
elif model_path:
if not features_list:
raise AttributeError(
"features_list is required if the model is not downloaded from the Hub"
)
l2g_model = LocusToGeneModel.load_from_disk(
session, path=model_path, features_list=features_list
)

# Prepare data
fm = (
@@ -79,7 +95,8 @@
.select("studyLocusId")
.join(feature_matrix._df, "studyLocusId")
.filter(f.col("isProteinCoding") == 1)
),
features_list=l2g_model.features_list,
)
.fill_na()
.select_features(l2g_model.features_list)
@@ -127,7 +144,129 @@ def to_disease_target_evidence(
)
)

def explain(
self: L2GPrediction, feature_matrix: L2GFeatureMatrix | None = None
) -> L2GPrediction:
"""Extract Shapley values for the L2G predictions and annotate each feature with its Shapley value.

Args:
feature_matrix (L2GFeatureMatrix | None): Feature matrix in case the predictions are missing the feature annotation. If None, the features are fetched from the dataset.

Returns:
L2GPrediction: L2GPrediction object with a Shapley value annotated on each feature in the `features` column

Raises:
ValueError: If the model is not set, or if the feature matrix is not provided and the predictions lack feature annotations
"""
# Fetch features if they are not present:
if "features" not in self.df.columns:
if feature_matrix is None:
raise ValueError(
"Feature matrix is required to explain the L2G predictions"
)
self.add_features(feature_matrix)

if self.model is None:
raise ValueError("Model not set, explainer cannot be created")

# Format and pivot the dataframe before calculating Shapley values
pdf = pivot_df(
df=self.df.withColumn("feature", f.explode("features")).select(
"studyLocusId",
"geneId",
"score",
f.col("feature.name").alias("feature_name"),
f.col("feature.value").alias("feature_value"),
),
pivot_col="feature_name",
value_col="feature_value",
grouping_cols=[f.col("studyLocusId"), f.col("geneId"), f.col("score")],
).toPandas()
pdf = pdf.rename(
# trim the suffix that is added after pivoting the df
columns={
col: col.replace("_feature_value", "")
for col in pdf.columns
if col.endswith("_feature_value")
}
)

features_list = self.model.features_list  # the matrix must present the features in the same order the model was trained on
base_value, shap_values = L2GPrediction._explain(
model=self.model,
pdf=pdf.filter(items=features_list),
)
for i, feature in enumerate(features_list):
pdf[f"shap_{feature}"] = [row[i] for row in shap_values]

spark_session = self.df.sparkSession
return L2GPrediction(
_df=(
spark_session.createDataFrame(pdf.to_dict(orient="records"))
.withColumn(
"features",
f.array(
*(
f.struct(
f.lit(feature).alias("name"),
f.col(feature).cast("float").alias("value"),
f.col(f"shap_{feature}")
.cast("float")
.alias("shapValue"),
)
for feature in features_list
)
),
)
.withColumn("shapBaseValue", f.lit(base_value).cast("float"))
.select(*L2GPrediction.get_schema().names)
),
_schema=self.get_schema(),
model=self.model,
)

@staticmethod
def _explain(
model: LocusToGeneModel, pdf: pd_dataframe
) -> tuple[float, list[list[float]]]:
"""Calculate SHAP values. Output is in probability form (approximated from the log odds ratios).

Args:
model (LocusToGeneModel): L2G model
pdf (pd_dataframe): Pandas dataframe containing the feature matrix in the same order that the model was trained on

Returns:
tuple[float, list[list[float]]]: A tuple containing:
- base_value (float): Base value of the model
- shap_values (list[list[float]]): SHAP values for prediction

Raises:
AttributeError: If `model.training_data` is not set, the background dataset for Shapley values cannot be created.
"""
if not model.training_data:
raise AttributeError(
"`model.training_data` is missing, seed dataset to get shapley values cannot be created."
)
background_data = model.training_data._df.select(
*model.features_list
).toPandas()
explainer = shap.TreeExplainer(
model.model,
data=background_data,
model_output="probability",
)
if pdf.shape[0] >= 10_000:
logging.warning(
"Calculating SHAP values for more than 10,000 rows. This may take a while..."
)
shap_values = explainer.shap_values(
pdf.to_numpy(),
check_additivity=False,
)
base_value = explainer.expected_value
return (base_value, shap_values)

def add_features(
self: L2GPrediction,
feature_matrix: L2GFeatureMatrix,
) -> L2GPrediction:
@@ -137,41 +276,30 @@
feature_matrix (L2GFeatureMatrix): Feature matrix dataset

Returns:
L2GPrediction: L2G predictions with additional column `features`

Raises:
ValueError: If the model is not set, the feature list is not available
"""
if self.model is None:
raise ValueError("Model not set, feature annotation cannot be created.")
# Testing if `features` column already exists:
if "features" in self.df.columns:
self.df = self.df.drop("features")

features_list = self.model.features_list
feature_expressions = [
f.struct(f.lit(col).alias("name"), f.col(col).alias("value"))
for col in features_list
]
self.df = self.df.join(
feature_matrix._df.select(*features_list, "studyLocusId", "geneId"),
on=["studyLocusId", "geneId"],
how="left",
).select(
"studyLocusId",
"geneId",
"score",
f.array(*feature_expressions).alias("features"),
)
return self
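Because `_explain` builds the `TreeExplainer` with `model_output="probability"`, the contributions are additive on the probability scale: the prediction equals the base value plus the per-feature SHAP values, up to the tolerance that `check_additivity=False` stops enforcing. A numeric sketch of the property that `shapBaseValue` and the per-feature `shapValue`s rely on (all numbers invented):

```python
# Invented numbers illustrating SHAP additivity in probability space:
# prediction ~= expected_value + sum(per-feature contributions).
shap_base_value = 0.05                      # explainer.expected_value
shap_values = [0.40, 0.25, -0.02, 0.15]     # one contribution per feature

score = shap_base_value + sum(shap_values)  # reconstructed prediction
assert abs(score - 0.83) < 1e-9
```

This is what makes the flattened `features` array self-explanatory: a consumer can recover the score from `shapBaseValue` and the `shapValue` fields alone.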
9 changes: 4 additions & 5 deletions src/gentropy/l2g.py
@@ -284,14 +284,13 @@ def run_predict(self) -> None:
self.credible_set,
self.feature_matrix,
model_path=self.model_path,
features_list=self.features_list,
hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"),
download_from_hub=self.download_from_hub,
)
predictions.filter(f.col("score") >= self.l2g_threshold).add_features(
self.feature_matrix,
).explain().df.coalesce(self.session.output_partitions).write.mode(
self.session.write_mode
).parquet(self.predictions_path)
self.session.logger.info("L2G predictions saved successfully.")
@@ -331,7 +330,7 @@ def run_train(self) -> None:
"hfhub-key", "open-targets-genetics-dev"
)
trained_model.export_to_hugging_face_hub(
# we upload the model saved in the filesystem
self.model_path.split("/")[-1],
hf_hub_token,
data=trained_model.training_data._df.drop(
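The rename step in `explain` exists because `pivot_df` suffixes every pivoted column with the value-column name; stripping `_feature_value` restores the raw feature names the explainer expects. The mapping in isolation (column names invented, same dict comprehension as in the diff):

```python
# Columns as they come out of pivot_df: grouping columns untouched,
# pivoted feature columns suffixed with the value-column name.
columns = [
    "studyLocusId",
    "geneId",
    "score",
    "distanceTssMean_feature_value",
    "eQtlColocClppMaximum_feature_value",
]

# Trim the suffix added after pivoting, leaving other columns alone.
rename = {
    col: col.replace("_feature_value", "")
    for col in columns
    if col.endswith("_feature_value")
}
renamed = [rename.get(col, col) for col in columns]
assert renamed == [
    "studyLocusId",
    "geneId",
    "score",
    "distanceTssMean",
    "eQtlColocClppMaximum",
]
```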