MoleculeACE linear probe evaluation callback #31

Merged · 2 commits · Feb 19, 2025
4 changes: 4 additions & 0 deletions src/lobster/callbacks/__init__.py
@@ -0,0 +1,4 @@
from ._linear_probe_callback import LinearProbeCallback
from ._moleculeace_linear_probe_callback import MoleculeACELinearProbeCallback

__all__ = ["MoleculeACELinearProbeCallback", "LinearProbeCallback"]
127 changes: 127 additions & 0 deletions src/lobster/callbacks/_linear_probe_callback.py
@@ -0,0 +1,127 @@
from typing import Callable, Dict, Literal, Optional, Tuple

import lightning as L
import torch
from beignet.transforms import Transform
from lightning.pytorch.callbacks import Callback
from sklearn.linear_model import LinearRegression, LogisticRegression
from torch import Tensor
from torch.utils.data import DataLoader
from torchmetrics import AUROC, Accuracy, F1Score, MeanSquaredError, R2Score, SpearmanCorrCoef

TaskType = Literal["regression", "binary", "multiclass"]


class LinearProbeCallback(Callback):
"""Callback for evaluating embedding models using scikit-learn linear probes."""

def __init__(
self,
task_type: TaskType = "regression",
transform_fn: Transform | Callable | None = None,
num_classes: Optional[int] = None,
batch_size: int = 32,
run_every_n_epochs: int | None = None,
):
super().__init__()
self.transform_fn = transform_fn
self.task_type = task_type
self.num_classes = num_classes
self.batch_size = batch_size
self.run_every_n_epochs = run_every_n_epochs

# Initialize metrics based on task type
if task_type == "regression":
self.mse = MeanSquaredError()
self.r2 = R2Score()
self.spearman = SpearmanCorrCoef()

elif task_type in {"binary", "multiclass"}:
self.accuracy = Accuracy(task=task_type, num_classes=num_classes)
self.f1 = F1Score(task=task_type, num_classes=num_classes)
            self.auroc = AUROC(task=task_type, num_classes=num_classes)

else:
raise ValueError("task_type must be: regression, binary, or multiclass")

# Dictionary to store trained probes
self.probes: Dict[str, LinearRegression | LogisticRegression] = {}

def _skip(self, trainer: L.Trainer) -> bool:
"""Determine if we should skip validation this epoch."""
if self.run_every_n_epochs is None:
return False

return trainer.current_epoch % self.run_every_n_epochs != 0

def _get_embeddings(self, module: L.LightningModule, dataloader: DataLoader) -> Tuple[Tensor, Tensor]:
"""Extract embeddings from the model for a given dataloader."""
embeddings = []
targets = []

module.eval()
with torch.no_grad():
for batch in dataloader:
x, y = batch
x = {k: v.to(module.device) for k, v in x.items()}

# Get token-level embeddings
batch_embeddings = module.tokens_to_latents(**x)

# Reshape to (batch_size, seq_len, hidden_size)
batch_size = len(y)
seq_len = x["input_ids"].size(-1)
batch_embeddings = batch_embeddings.view(batch_size, seq_len, -1)

# Simple mean pooling over sequence length dimension
seq_embeddings = batch_embeddings.mean(dim=1)

embeddings.append(seq_embeddings.cpu())
targets.append(y.cpu())

return torch.cat(embeddings), torch.cat(targets)

def _train_probe(self, embeddings: Tensor, targets: Tensor):
"""Train a probe on the given embeddings and targets."""
embeddings = embeddings.numpy()
targets = targets.numpy()

if self.task_type == "regression":
probe = LinearRegression()
else:
probe = LogisticRegression(
multi_class="ovr" if self.task_type == "binary" else "multinomial",
)

probe.fit(embeddings, targets)

return probe

def _evaluate_probe(self, probe, embeddings: Tensor, targets: Tensor) -> Dict[str, float]:
"""Evaluate a trained probe using task-appropriate metrics."""
metrics = {}

if self.task_type == "regression":
predictions = probe.predict(embeddings.numpy())
predictions = torch.from_numpy(predictions).float()

metrics["mse"] = self.mse(predictions, targets).item()
metrics["r2"] = self.r2(predictions, targets).item()
metrics["spearman"] = self.spearman(predictions.squeeze(), targets.squeeze()).item()

else: # binary or multiclass
pred_probs = probe.predict_proba(embeddings.numpy())
predictions = torch.from_numpy(pred_probs).float()

if self.task_type == "binary":
predictions = predictions[:, 1]

metrics["accuracy"] = self.accuracy(predictions, targets).item()
metrics["f1"] = self.f1(predictions, targets).item()
metrics["auroc"] = self.auroc(predictions, targets).item()

return metrics

def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
"""Train and evaluate linear probes, optionally at specified epochs."""
raise NotImplementedError("Subclasses must implement on_validation_epoch_end")
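
Editor's note (not part of the diff): LinearProbeCallback is a base class; it owns embedding extraction, probe fitting, and metric computation, while a subclass supplies the data and implements on_validation_epoch_end. A minimal sketch of such a subclass, mirroring the MoleculeACE implementation below; the dataloaders here are hypothetical placeholders for a task-specific data source:

import lightning as L
from torch.utils.data import DataLoader

from lobster.callbacks import LinearProbeCallback


class MyProbeCallback(LinearProbeCallback):
    def __init__(self, train_loader: DataLoader, test_loader: DataLoader, **kwargs):
        super().__init__(**kwargs)
        self.train_loader = train_loader
        self.test_loader = test_loader

    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        if self._skip(trainer):
            return

        # Embed with the frozen model, fit a scikit-learn probe, then evaluate.
        train_emb, train_y = self._get_embeddings(pl_module, self.train_loader)
        test_emb, test_y = self._get_embeddings(pl_module, self.test_loader)

        probe = self._train_probe(train_emb, train_y)
        self.probes["my_task"] = probe

        metrics = self._evaluate_probe(probe, test_emb, test_y)
        trainer.logger.log_metrics(
            {f"my_probe/{k}": v for k, v in metrics.items()}, step=trainer.current_epoch
        )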
94 changes: 94 additions & 0 deletions src/lobster/callbacks/_moleculeace_linear_probe_callback.py
@@ -0,0 +1,94 @@
from collections import defaultdict
from typing import Optional, Sequence

import lightning as L
from torch.utils.data import DataLoader
from tqdm import tqdm

from lobster.constants import MOLECULEACE_TASKS
from lobster.datasets import MoleculeACEDataset
from lobster.tokenization import SmilesTokenizerFast
from lobster.transforms import TokenizerTransform

from ._linear_probe_callback import LinearProbeCallback


class MoleculeACELinearProbeCallback(LinearProbeCallback):
"""Callback for evaluating embedding models on the Molecule Activity Cliff
    Estimation (MoleculeACE) dataset from van Tilborg et al. (2022).

This callback assesses how well a molecular embedding model captures activity
cliffs - pairs of molecules that are structurally similar but show large
differences in biological activity (potency). It does this by training linear
probes on the frozen embeddings to predict pEC50/pKi values for 30 different
protein targets from ChEMBL.

Reference:
van Tilborg et al. (2022) "Exposing the Limitations of Molecular Machine
Learning with Activity Cliffs"
https://pubs.acs.org/doi/10.1021/acs.jcim.2c01073
"""

def __init__(
self,
max_length: int,
tasks: Optional[Sequence[str]] = None,
batch_size: int = 32,
run_every_n_epochs: int | None = None,
):
tokenizer_transform = TokenizerTransform(
tokenizer=SmilesTokenizerFast(),
padding="max_length",
truncation=True,
max_length=max_length,
)

super().__init__(
transform_fn=tokenizer_transform,
task_type="regression",
batch_size=batch_size,
run_every_n_epochs=run_every_n_epochs,
)

# Set tasks
self.tasks = set(tasks) if tasks is not None else MOLECULEACE_TASKS

def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
"""Train and evaluate linear probes at specified epochs."""
if self._skip(trainer):
return

aggregate_metrics = defaultdict(list)

        for task in tqdm(self.tasks, desc=self.__class__.__name__):
# Create datasets
train_dataset = MoleculeACEDataset(task=task, transform_fn=self.transform_fn, train=True)
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
test_dataset = MoleculeACEDataset(task=task, transform_fn=self.transform_fn, train=False)
test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

# Get embeddings
train_embeddings, train_targets = self._get_embeddings(pl_module, train_loader)
test_embeddings, test_targets = self._get_embeddings(pl_module, test_loader)

# Train probe
probe = self._train_probe(train_embeddings, train_targets)
self.probes[task] = probe

# Evaluate
metrics = self._evaluate_probe(probe, test_embeddings, test_targets)

# Log metrics and store for averaging
for metric_name, value in metrics.items():
trainer.logger.log_metrics(
{f"moleculeace_linear_probe/{task}/{metric_name}": value}, step=trainer.current_epoch
)

aggregate_metrics[metric_name].append(value)

# Calculate and log aggregate metrics
for metric_name, values in aggregate_metrics.items():
avg_value = sum(values) / len(values)
trainer.logger.log_metrics(
{f"moleculeace_linear_probe/mean/{metric_name}": avg_value}, step=trainer.current_epoch
)
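
Editor's note (not part of the diff): as a usage sketch, the callback attaches to a standard Lightning Trainer. The max_length value is illustrative, and model/datamodule stand in for whatever LightningModule and datamodule are being trained:

import lightning as L

from lobster.callbacks import MoleculeACELinearProbeCallback

callback = MoleculeACELinearProbeCallback(
    max_length=512,                             # illustrative; match the model's context size
    tasks=["CHEMBL204_Ki", "CHEMBL2047_EC50"],  # optional subset; defaults to all 30 tasks
    run_every_n_epochs=5,                       # probe every 5th validation epoch
)

trainer = L.Trainer(max_epochs=100, callbacks=[callback])
trainer.fit(model, datamodule=datamodule)  # model and datamodule defined elsewhere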
3 changes: 3 additions & 0 deletions src/lobster/constants/__init__.py
@@ -0,0 +1,3 @@
from ._moleculeace_tasks import MOLECULEACE_TASKS

__all__ = ["MOLECULEACE_TASKS"]
32 changes: 32 additions & 0 deletions src/lobster/constants/_moleculeace_tasks.py
@@ -0,0 +1,32 @@
MOLECULEACE_TASKS = {
"CHEMBL1862_Ki",
"CHEMBL1871_Ki",
"CHEMBL2034_Ki",
"CHEMBL2047_EC50",
"CHEMBL204_Ki",
"CHEMBL2147_Ki",
"CHEMBL214_Ki",
"CHEMBL218_EC50",
"CHEMBL219_Ki",
"CHEMBL228_Ki",
"CHEMBL231_Ki",
"CHEMBL233_Ki",
"CHEMBL234_Ki",
"CHEMBL235_EC50",
"CHEMBL236_Ki",
"CHEMBL237_EC50",
"CHEMBL237_Ki",
"CHEMBL238_Ki",
"CHEMBL239_EC50",
"CHEMBL244_Ki",
"CHEMBL262_Ki",
"CHEMBL264_Ki",
"CHEMBL2835_Ki",
"CHEMBL287_Ki",
"CHEMBL2971_Ki",
"CHEMBL3979_EC50",
"CHEMBL4005_Ki",
"CHEMBL4203_Ki",
"CHEMBL4616_EC50",
"CHEMBL4792_Ki",
}
9 changes: 9 additions & 0 deletions src/lobster/datasets/__init__.py
@@ -1,3 +1,12 @@
from ._calm_dataset import CalmDataset
from ._fasta_dataset import FASTADataset
from ._m3_20m_dataset import M320MDataset
from ._moleculeace_dataset import MoleculeACEDataset

__all__ = [
    "M320MDataset",
    "CalmDataset",
    "FASTADataset",
    "MoleculeACEDataset",
]
2 changes: 1 addition & 1 deletion src/lobster/datasets/_m3_20m_dataset.py
@@ -54,7 +54,7 @@ def __init__(
List of columns to be used from the dataset.
"""
super().__init__()
url = "https://huggingface.co/datasets/karina-zadorozhny/M320M-multi-modal-molecular-dataset/resolve/main/M320M-Dataset.parquet.gzip"
url = "https://huggingface.co/datasets/karina-zadorozhny/M320M/resolve/main/M320M-Dataset.parquet.gzip"

suffix = ".parquet.gzip"

95 changes: 95 additions & 0 deletions src/lobster/datasets/_moleculeace_dataset.py
@@ -0,0 +1,95 @@
from pathlib import Path
from typing import Callable, Tuple

import pandas
import pooch
import torch
from beignet.transforms import Transform
from torch import Tensor
from torch.utils.data import Dataset

from lobster.constants import MOLECULEACE_TASKS


class MoleculeACEDataset(Dataset):
def __init__(
self,
root: str | Path | None = None,
*,
task: str,
download: bool = True,
transform_fn: Callable | Transform | None = None,
target_transform_fn: Callable | Transform | None = None,
train: bool = True,
known_hash: str | None = None,
) -> None:
"""
        Molecule Activity Cliff Estimation (MoleculeACE) dataset from van Tilborg et al. (2022).

        Reference: https://pubs.acs.org/doi/10.1021/acs.jcim.2c01073

        Contains activity data for 30 different ChEMBL targets, each treated as a separate task.
"""
super().__init__()

if root is None:
root = pooch.os_cache("lbster")

if isinstance(root, str):
root = Path(root)

self._root = root.resolve()

self._download = download
self.transform_fn = transform_fn
self.target_transform_fn = target_transform_fn
self.column = "smiles"
self.target_column = "y [pEC50/pKi]"
self.task = task
self.train = train

if self.task not in MOLECULEACE_TASKS:
raise ValueError(f"`task` must be one of {MOLECULEACE_TASKS}, got {self.task}")

suffix = ".csv"
url = "https://raw.githubusercontent.com/molML/MoleculeACE/main/MoleculeACE/Data/benchmark_data/"
url = f"{url}/{self.task}{suffix}"

if self._download:
pooch.retrieve(
url=url,
fname=f"{self.__class__.__name__}_{self.task}_{suffix}",
known_hash=known_hash,
path=self._root / self.__class__.__name__,
progressbar=True,
)

        data = pandas.read_csv(
            self._root / self.__class__.__name__ / f"{self.__class__.__name__}_{self.task}{suffix}"
        ).reset_index(drop=True)

if train:
self.data = data[data["split"] == "train"]
else:
self.data = data[data["split"] == "test"]

def __getitem__(self, index: int) -> Tuple[str | Tensor, Tensor]:
item = self.data.iloc[index]

x = item[self.column]

if self.transform_fn is not None:
x = self.transform_fn(x)

y = item[self.target_column]

if self.target_transform_fn is not None:
y = self.target_transform_fn(y)

if not isinstance(y, Tensor):
y = torch.tensor(y).unsqueeze(-1)

return x, y

def __len__(self) -> int:
return len(self.data)
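
Editor's note (not part of the diff): a usage sketch for the dataset on its own. Without a transform_fn, x is the raw SMILES string and y a one-element float tensor:

from torch.utils.data import DataLoader

from lobster.datasets import MoleculeACEDataset

train_dataset = MoleculeACEDataset(task="CHEMBL204_Ki", train=True)
x, y = train_dataset[0]  # raw SMILES string, tensor of shape (1,)

# With a tokenizer transform (as in the callback above), the dataset can
# feed a DataLoader directly.
loader = DataLoader(train_dataset, batch_size=32, shuffle=True)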
3 changes: 3 additions & 0 deletions (new Hydra callback config; file path not shown)
@@ -0,0 +1,3 @@
moleculeace_linear_probe:
_target_: lobster.callbacks.MoleculeACELinearProbeCallback
max_length: ???
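
Note that ??? is Hydra's mandatory-value marker, so max_length must be supplied when the config is composed, for example with an override such as callbacks.moleculeace_linear_probe.max_length=512 (the exact config path depends on how the callback group is wired in).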
1 change: 0 additions & 1 deletion src/lobster/hydra_config/data/m320m.yaml
@@ -3,7 +3,6 @@ defaults:
- transform_fn: smiles_tokenizer_transform

_target_: lobster.data.M320MLightningDataModule
columns: ["smiles"]
download: true
lengths: [0.9,0.05,0.05]
batch_size: 64