Add CalmPropertyDataset & CalmLinearProbe Callbacks #34
base: main
Changes from all commits: 825743b, 5d0eecc, 68bb4fd, ee11722, 6f3e07c, ab21614, be039a0
@@ -0,0 +1,159 @@
from collections import defaultdict
from typing import Optional, Sequence, Tuple

import lightning as L
import numpy as np
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from lobster.constants import CALM_TASKS
from lobster.datasets import CalmPropertyDataset
from lobster.tokenization import NucleotideTokenizerFast
from lobster.transforms import TokenizerTransform

from ._linear_probe_callback import LinearProbeCallback


class CalmLinearProbeCallback(LinearProbeCallback):
    """Callback for evaluating embedding models on the CALM dataset collection."""

    def __init__(
        self,
        max_length: int,
        tasks: Optional[Sequence[str]] = None,
        species: Optional[Sequence[str]] = None,
        batch_size: int = 32,
        run_every_n_epochs: Optional[int] = None,
        test_size: float = 0.2,
        max_samples: int = 3000,
        seed: int = 42,
    ):
        tokenizer_transform = TokenizerTransform(
            tokenizer=NucleotideTokenizerFast(),
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

        super().__init__(
            transform_fn=tokenizer_transform,
            task_type="regression",
            batch_size=batch_size,
            run_every_n_epochs=run_every_n_epochs,
        )

        self.tasks = set(tasks) if tasks else set(CALM_TASKS.keys())
        self.species = set(species) if species else None

        self.test_size = test_size
        self.max_samples = max_samples
        self.seed = seed

        self.dataset_splits = {}
        self.aggregate_metrics = defaultdict(list)

    def _create_split_datasets(
        self, task: str, species: Optional[str] = None
    ) -> Tuple[CalmPropertyDataset, CalmPropertyDataset]:
        """Create train/test splits for a given task."""

        rng = np.random.RandomState(self.seed)  # TODO - seed everything fn
Reviewer comment: hmm, does this work better or equivalently to lightning's seed everything?
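As an editorial aside on the seeding question above (not part of the diff), a minimal sketch contrasting the local NumPy generator used here with Lightning's global seeding helper:

# Illustrative only: the local RandomState seeds just this generator,
# while Lightning's helper seeds Python, NumPy, and torch globally.
import numpy as np
from lightning.pytorch import seed_everything

rng = np.random.RandomState(42)    # affects only draws made through `rng`
seed_everything(42, workers=True)  # seeds random, numpy, and torch (and dataloader workers)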
        # Check cache for existing splits
        split_key = f"{task}_{species}" if species else task
        if split_key in self.dataset_splits:
            return self.dataset_splits[split_key]

        dataset = CalmPropertyDataset(task=task, species=species, transform_fn=self.transform_fn)

        indices = np.arange(len(dataset))

        # If dataset is too large, subsample it first
        if len(indices) > self.max_samples:
            indices = rng.choice(indices, size=self.max_samples, replace=False)

        # Create train/test split from (possibly subsampled) indices
        test_size = int(len(indices) * self.test_size)
        train_size = len(indices) - test_size
        shuffled_indices = rng.permutation(indices)
        train_indices = shuffled_indices[:train_size]
        test_indices = shuffled_indices[train_size:]

        train_dataset = Subset(dataset, train_indices)
        test_dataset = Subset(dataset, test_indices)

        # Cache the splits
        self.dataset_splits[split_key] = (train_dataset, test_dataset)

        return train_dataset, test_dataset

    def _evaluate_task(
        self,
        task_key: str,
        task: str,
        train_dataset,
        test_dataset,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
    ):
        """Evaluate a single task."""

        task_type, num_classes = CALM_TASKS[task]

        self._set_metrics(task_type, num_classes)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        try:
            train_embeddings, train_targets = self._get_embeddings(pl_module, train_loader)
            test_embeddings, test_targets = self._get_embeddings(pl_module, test_loader)

            probe = self._train_probe(train_embeddings, train_targets)
            self.probes[task_key] = probe

            metrics = self._evaluate_probe(probe, test_embeddings, test_targets)

            # Log metrics and store for averaging
            for metric_name, value in metrics.items():
                metric_key = f"calm_linear_probe/{task_key}/{metric_name}"
                trainer.logger.log_metrics({metric_key: value}, step=trainer.current_epoch)
                self.aggregate_metrics[metric_name].append(value)

        except Exception as e:
            print(f"Error in _evaluate_task for {task_key}: {str(e)}")
            raise

    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        if self._skip(trainer):
            return

        self.device = pl_module.device
        self.aggregate_metrics.clear()  # Reset aggregation for new epoch

        for task in tqdm(self.tasks, desc=f"{self.__class__.__name__}"):
            # Handle species-specific tasks
            if task in ["protein_abundance", "transcript_abundance"]:
                if not self.species:
                    continue
                for species in self.species:
                    task_key = f"{task}_{species}"
                    try:
                        train_dataset, test_dataset = self._create_split_datasets(task, species)
                        self._evaluate_task(task_key, task, train_dataset, test_dataset, trainer, pl_module)
                    except Exception as e:
                        print(f"Error processing {task_key}: {str(e)}")
            else:
                try:
                    train_dataset, test_dataset = self._create_split_datasets(task)
                    self._evaluate_task(task, task, train_dataset, test_dataset, trainer, pl_module)
                except Exception as e:
                    print(f"Error processing {task}: {str(e)}")

        # Calculate and log aggregate metrics
        for metric_name, values in self.aggregate_metrics.items():
            if values:  # Only log if we have values
                avg_value = sum(values) / len(values)
                trainer.logger.log_metrics(
                    {f"calm_linear_probe/mean/{metric_name}": avg_value}, step=trainer.current_epoch
                )
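A minimal usage sketch (editorial, not part of the diff; the import path for the callback is assumed) showing how this callback might be attached to a Lightning Trainer:

import lightning as L
from lobster.callbacks import CalmLinearProbeCallback  # assumed import path

callback = CalmLinearProbeCallback(
    max_length=1024,
    tasks=["meltome", "protein_abundance"],
    species=["hsapiens", "ecoli"],
    run_every_n_epochs=1,
)

trainer = L.Trainer(max_epochs=10, callbacks=[callback])
# trainer.fit(model, datamodule=dm)  # probes are fit and scored at each validation epoch end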
@@ -1,15 +1,17 @@
from typing import Callable, Dict, Literal, Optional, Tuple

import lightning as L
import numpy as np
import torch
from beignet.transforms import Transform
from lightning.pytorch.callbacks import Callback
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from torch import Tensor
from torch.utils.data import DataLoader
from torchmetrics import AUROC, Accuracy, F1Score, MeanSquaredError, R2Score, SpearmanCorrCoef

TaskType = Literal["regression", "binary", "multiclass", "multilabel"]
Reviewer comment: is multilabel different from multiclass?
class LinearProbeCallback(Callback):

@@ -31,21 +33,36 @@ def __init__(
        self.run_every_n_epochs = run_every_n_epochs

        # Initialize metrics based on task type
        self._set_metrics(task_type, num_classes)

    def _set_metrics(self, task_type: TaskType, num_classes: Optional[int] = None) -> None:
        """Initialize metrics based on task type."""
        if task_type == "regression":
            self.mse = MeanSquaredError()
            self.r2 = R2Score()
            self.spearman = SpearmanCorrCoef()
            self.accuracy = None
            self.f1 = None
            self.auroc = None

        elif task_type in {"binary", "multiclass", "multilabel"}:
            # For multilabel, we use num_classes as num_labels
            metric_task = task_type
            self.accuracy = Accuracy(task=metric_task, num_labels=num_classes)
            self.f1 = F1Score(task=metric_task, num_labels=num_classes)
            self.auroc = AUROC(task=metric_task, num_labels=num_classes)
            self.mse = None
            self.r2 = None
            self.spearman = None

        else:
            raise ValueError("task_type must be: regression, binary, multiclass, or multilabel")

        # Dictionary to store trained probes
        self.probes: Dict[str, LinearRegression | LogisticRegression] = {}
        self.task_type = task_type
        self.num_classes = num_classes

    def _skip(self, trainer: L.Trainer) -> bool:
        """Determine if we should skip validation this epoch."""
@@ -88,34 +105,45 @@ def _train_probe(self, embeddings: Tensor, targets: Tensor):

        if self.task_type == "regression":
            probe = LinearRegression()
            probe.fit(embeddings, targets)

        elif self.task_type == "multilabel":
            base_classifier = LogisticRegression(random_state=42)
            probe = MultiOutputClassifier(base_classifier)
            probe.fit(embeddings, targets)

        else:  # binary or multiclass
            probe = LogisticRegression(
                multi_class="ovr" if self.task_type == "binary" else "multinomial",
                random_state=42,
            )
            probe.fit(embeddings, targets.ravel())

        return probe

    def _evaluate_probe(self, probe, embeddings: Tensor, targets: Tensor) -> Dict[str, float]:
        """Evaluate a trained probe using task-appropriate metrics."""
        embeddings_np = embeddings.numpy()  # Convert to numpy for probe prediction
        metrics = {}

        if self.task_type == "regression":
            predictions_np = probe.predict(embeddings_np)
            predictions = torch.from_numpy(predictions_np).float()

            metrics["mse"] = self.mse(predictions, targets).item()
            metrics["r2"] = self.r2(predictions, targets).item()
            metrics["spearman"] = self.spearman(predictions.squeeze(), targets.squeeze()).item()

        else:  # binary, multiclass, or multilabel
            if self.task_type == "multilabel":
                # Get probabilities for each label
                predictions_np = np.stack([est.predict_proba(embeddings_np)[:, 1] for est in probe.estimators_], axis=1)
            else:  # binary or multiclass
                predictions_np = probe.predict_proba(embeddings_np)
                if self.task_type == "binary":
                    predictions_np = predictions_np[:, 1]

            predictions = torch.from_numpy(predictions_np).float()
            metrics["accuracy"] = self.accuracy(predictions, targets).item()
            metrics["f1"] = self.f1(predictions, targets).item()
            metrics["auroc"] = self.auroc(predictions, targets).item()
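On the multilabel-vs-multiclass question above, a small illustrative aside (not part of the diff) using torchmetrics: multiclass targets are mutually exclusive class indices, while multilabel targets are independent multi-hot vectors, which is why the number of labels is passed separately.

import torch
from torchmetrics import Accuracy

# Multiclass: exactly one class per sample, given as class indices.
multiclass_acc = Accuracy(task="multiclass", num_classes=3)
multiclass_acc(torch.tensor([0, 2, 1]), torch.tensor([0, 1, 1]))

# Multilabel: several labels may be active at once, given as multi-hot targets.
multilabel_acc = Accuracy(task="multilabel", num_labels=3)
multilabel_acc(
    torch.tensor([[0.9, 0.1, 0.8], [0.2, 0.7, 0.4]]),  # per-label probabilities
    torch.tensor([[1, 0, 1], [0, 1, 0]]),
)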
@@ -1,3 +1,4 @@
from ._calm_tasks import CALM_TASKS
from ._moleculeace_tasks import MOLECULEACE_TASKS

__all__ = ["CALM_TASKS", "MOLECULEACE_TASKS"]
@@ -0,0 +1,84 @@
from enum import Enum

CALM_TASKS = {
    "meltome": ("regression", None),  # (task_type, num_classes)
    "solubility": ("regression", None),
    "localization": ("multilabel", 10),  # 10 cellular locations
    "protein_abundance": ("regression", None),
    "transcript_abundance": ("regression", None),
    "function_bp": ("multilabel", 4),  # 4 GO terms
    "function_cc": ("multilabel", 4),
    "function_mf": ("multilabel", 4),
}


CALM_DATA_GITHUB_URL = "https://raw.githubusercontent.com/oxpig/CaLM/main/data"
FUNCTION_ZENODO_BASE_URL = (
    "https://zenodo.org/records/14890750/files"  # Gene Ontology datasets processed & uploaded on Zenodo
)

Reviewer comment: switch to HF
Reviewer comment: I was going to write something similar about CALM_DATA_GITHUB_URL -- is this a subset of the calm dataset from Nathan's HF account?
Reviewer comment: if it's in HF, we can also load it more easily with


class Species(str, Enum):
    ATHALIANA = "athaliana"
    DMELANOGASTER = "dmelanogaster"
    ECOLI = "ecoli"
    HSAPIENS = "hsapiens"
    HVOLCANII = "hvolcanii"
    PPASTORIS = "ppastoris"
    SCEREVISIAE = "scerevisiae"


class Task(Enum):
    MELTOME = "meltome"
    SOLUBILITY = "solubility"
    LOCALIZATION = "localization"
    PROTEIN_ABUNDANCE = "protein_abundance"
    TRANSCRIPT_ABUNDANCE = "transcript_abundance"
    FUNCTION_BP = "function_bp"
    FUNCTION_CC = "function_cc"
    FUNCTION_MF = "function_mf"


TASK_SPECIES = {
    Task.PROTEIN_ABUNDANCE: [
        Species.ATHALIANA,
        Species.DMELANOGASTER,
        Species.ECOLI,
        Species.HSAPIENS,
        Species.SCEREVISIAE,
    ],
    Task.TRANSCRIPT_ABUNDANCE: [
        Species.ATHALIANA,
        Species.DMELANOGASTER,
        Species.ECOLI,
        Species.HSAPIENS,
        Species.HVOLCANII,
        Species.PPASTORIS,
        Species.SCEREVISIAE,
    ],
}
# File hashes to check upstream data files haven't been changed. Makes data download cleaner.

Reviewer comment: One thing I don't love about hashes is that pooch currently just errors out when the files are different. Maybe we'd want to add custom logic where it just downloads new files instead. Could be something for a future MR though.
FILE_HASHES = {
    "meltome.csv": "sha256:699074debc9e3d66e0c084bca594ce81d26b3126b645d43b0597dbe466153ad4",
    "solubility.csv": "sha256:94b351d0f36b490423b3e80b2ff0ea5114423165948183157cf63d4f57c08d38",
    "localization.csv": "sha256:efedb7c394b96f4569c72d03eac54ca5a3b4a24e15c9c24d9429f9b1a4e29320",
    # Function prediction tasks
    "calm_GO_bp_middle_normal.parquet": "md5:898265de59ba1ac97270bffc3621f334",
    "calm_GO_cc_middle_normal.parquet": "md5:a6af91fe40e523c9adf47e6abd98d9c6",
    "calm_GO_mf_middle_normal.parquet": "md5:cafe14db5dda19837bae536399e47e35",
    # Protein abundance
    "protein_abundance_athaliana.csv": "sha256:83f8d995ee3a0ff6f1ed4e74a9cb891546e2edb6e86334fef3b6901a0039b118",
    "protein_abundance_dmelanogaster.csv": "sha256:6f9541d38217f71a4f9addec4ad567d60eee4e4cebb1d16775909b88e1775da4",
    "protein_abundance_ecoli.csv": "sha256:a6a8f91901a4ea4da62931d1e7c91b3a6aa72e4d6c6a83a6612c0988e94421fb",
    "protein_abundance_hsapiens.csv": "sha256:94ded0486f2f64575bd2d8f2a3e00611a6e8b28b691d0f367ca9210058771a23",
    "protein_abundance_scerevisiae.csv": "sha256:0ce0b6a5b0400c3cc1c568f6c5478a974e80aaecbab93299f81bb94eb2373076",
    # Transcript abundance
    "transcript_abundance_athaliana.csv": "sha256:de7a6f57bcfbb60445d17b8461a8a3ea8353942e129f08ac2c6de5874cd6c139",
    "transcript_abundance_dmelanogaster.csv": "sha256:0124d101da004e7a66f4303ff097da39d5e4dd474548fa57e2f9fa7231544c32",
    "transcript_abundance_ecoli.csv": "sha256:5e480d961c8b9043f6039211241ecf904b188214e9951352a9b2fc3d6a630a59",
    "transcript_abundance_hsapiens.csv": "sha256:21b4b3f3f7267d28dbf6434090dfc0c58fde6f15393537d569f0b29e3eeec491",
    "transcript_abundance_hvolcanii.csv": "sha256:91782d2839f78b7c3a4c4d2c0f685605fa746e9b3936579fbd91ce875f9860aa",
    "transcript_abundance_ppastoris.csv": "sha256:4ebd4783e1b90e76e481c25bce801d4f6984f85d382f5d458d26f554e114798a",
    "transcript_abundance_scerevisiae.csv": "sha256:2e0f3b4c0cee77f47ab4009be881f671b709df368495f92dad66f34b2c88ac36",
}
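On the hash comment above, a rough sketch (editorial, not part of the diff) of how one of these pinned hashes could be used with pooch; the URL path is hypothetical and only illustrates the known_hash check the reviewer mentions, where pooch raises an error if the downloaded file does not match:

import pooch

# Hypothetical file path under the CaLM data repo; the real layout may differ.
url = f"{CALM_DATA_GITHUB_URL}/meltome.csv"

# pooch downloads the file, caches it locally, and verifies it against the pinned hash.
path = pooch.retrieve(url=url, known_hash=FILE_HASHES["meltome.csv"], fname="meltome.csv")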
Reviewer comment: can you add a docstring, especially for explaining tasks and species?
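One possible shape for that docstring (a sketch only; the parameter descriptions are inferred from the code in this diff and would need the PR author's confirmation):

"""Callback for evaluating embedding models on the CALM dataset collection with linear probes.

Parameters
----------
max_length : int
    Maximum sequence length used by the nucleotide tokenizer transform.
tasks : Optional[Sequence[str]]
    CALM task names to evaluate (keys of CALM_TASKS, e.g. "meltome", "localization").
    Defaults to all tasks.
species : Optional[Sequence[str]]
    Species to evaluate for the species-specific tasks ("protein_abundance",
    "transcript_abundance"). If None, those tasks are skipped.
batch_size : int
    Batch size for embedding extraction.
run_every_n_epochs : Optional[int]
    Only run the probe evaluation every n epochs; None runs it at every validation epoch.
test_size : float
    Fraction of each task's (possibly subsampled) data held out for probe evaluation.
max_samples : int
    Maximum number of samples drawn per task before splitting.
seed : int
    Random seed for subsampling and the train/test split.
"""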