Add CalmPropertyDataset & CalmLinearProbe Callbacks #34

Status: Open. Wants to merge 7 commits into base: main.
1 change: 1 addition & 0 deletions src/lobster/callbacks/__init__.py
@@ -1,3 +1,4 @@
from ._calm_linear_probe_callback import CalmLinearProbeCallback
from ._linear_probe_callback import LinearProbeCallback
from ._moleculeace_linear_probe_callback import MoleculeACELinearProbeCallback

159 changes: 159 additions & 0 deletions src/lobster/callbacks/_calm_linear_probe_callback.py
@@ -0,0 +1,159 @@
from collections import defaultdict
from typing import Optional, Sequence, Tuple

import lightning as L
import numpy as np
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from lobster.constants import CALM_TASKS
from lobster.datasets import CalmPropertyDataset
from lobster.tokenization import NucleotideTokenizerFast
from lobster.transforms import TokenizerTransform

from ._linear_probe_callback import LinearProbeCallback


class CalmLinearProbeCallback(LinearProbeCallback):
    """Callback for evaluating embedding models on the CALM dataset collection."""

    def __init__(
        self,
        max_length: int,
        tasks: Optional[Sequence[str]] = None,
        species: Optional[Sequence[str]] = None,
        batch_size: int = 32,
        run_every_n_epochs: Optional[int] = None,
        test_size: float = 0.2,
        max_samples: int = 3000,
        seed: int = 42,
    ):
[Collaborator comment] can you add a docstring, especially for explaining tasks and species?
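One possible docstring along the requested lines; a sketch only, with parameter descriptions inferred from the constructor defaults and CALM_TASKS:

        """Evaluate embeddings with linear probes on CALM property-prediction tasks.

        Parameters
        ----------
        max_length : int
            Maximum sequence length for nucleotide tokenization.
        tasks : Optional[Sequence[str]]
            CALM task names to evaluate (keys of CALM_TASKS, e.g. "meltome",
            "solubility", "localization"). Defaults to all CALM tasks.
        species : Optional[Sequence[str]]
            Species to evaluate for the species-specific tasks
            ("protein_abundance", "transcript_abundance"). If None, those
            tasks are skipped.
        batch_size : int
            Batch size for embedding extraction.
        run_every_n_epochs : Optional[int]
            Run the probe only every N epochs; None means every validation epoch.
        test_size : float
            Fraction of each task's data held out for the probe's test split.
        max_samples : int
            Cap on samples per task; larger datasets are subsampled before splitting.
        seed : int
            Seed for the subsampling/split RNG.
        """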

        tokenizer_transform = TokenizerTransform(
            tokenizer=NucleotideTokenizerFast(),
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

        super().__init__(
            transform_fn=tokenizer_transform,
            task_type="regression",
            batch_size=batch_size,
            run_every_n_epochs=run_every_n_epochs,
        )

        self.tasks = set(tasks) if tasks else set(CALM_TASKS.keys())
        self.species = set(species) if species else None

        self.test_size = test_size
        self.max_samples = max_samples
        self.seed = seed

        self.dataset_splits = {}
        self.aggregate_metrics = defaultdict(list)

    def _create_split_datasets(
        self, task: str, species: Optional[str] = None
    ) -> Tuple[CalmPropertyDataset, CalmPropertyDataset]:
        """Create train/test splits for a given task."""

        rng = np.random.RandomState(self.seed)  # TODO - seed everything fn
[Collaborator comment] hmm, does this work better or equivalently to lightning's seed everything?
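For context, the two are not interchangeable: lightning's seed_everything seeds the global Python, NumPy, and torch RNGs, while a local RandomState gives a reproducible generator that is isolated from global state. A minimal sketch of the difference (illustrative only):

import lightning as L
import numpy as np

L.seed_everything(42)            # seeds random, numpy, and torch globally
rng = np.random.RandomState(42)  # local generator; unaffected by later global reseeding
print(rng.permutation(5))        # same output every run, regardless of other RNG use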


        # Check cache for existing splits
        split_key = f"{task}_{species}" if species else task
        if split_key in self.dataset_splits:
            return self.dataset_splits[split_key]

        dataset = CalmPropertyDataset(task=task, species=species, transform_fn=self.transform_fn)

        indices = np.arange(len(dataset))

        # If dataset is too large, subsample it first
        if len(indices) > self.max_samples:
            indices = rng.choice(indices, size=self.max_samples, replace=False)

        # Create train/test split from (possibly subsampled) indices
        test_size = int(len(indices) * self.test_size)
        train_size = len(indices) - test_size
        shuffled_indices = rng.permutation(indices)
        train_indices = shuffled_indices[:train_size]
        test_indices = shuffled_indices[train_size:]

        train_dataset = Subset(dataset, train_indices)
        test_dataset = Subset(dataset, test_indices)

        # Cache the splits
        self.dataset_splits[split_key] = (train_dataset, test_dataset)

        return train_dataset, test_dataset

    def _evaluate_task(
        self,
        task_key: str,
        task: str,
        train_dataset,
        test_dataset,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
    ):
        """Evaluate a single task."""

        task_type, num_classes = CALM_TASKS[task]

        self._set_metrics(task_type, num_classes)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        try:
            train_embeddings, train_targets = self._get_embeddings(pl_module, train_loader)
            test_embeddings, test_targets = self._get_embeddings(pl_module, test_loader)

            probe = self._train_probe(train_embeddings, train_targets)
            self.probes[task_key] = probe

            metrics = self._evaluate_probe(probe, test_embeddings, test_targets)

            # Log metrics and store for averaging
            for metric_name, value in metrics.items():
                metric_key = f"calm_linear_probe/{task_key}/{metric_name}"
                trainer.logger.log_metrics({metric_key: value}, step=trainer.current_epoch)
                self.aggregate_metrics[metric_name].append(value)

        except Exception as e:
            print(f"Error in _evaluate_task for {task_key}: {str(e)}")
            raise

    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        if self._skip(trainer):
            return

        self.device = pl_module.device
        self.aggregate_metrics.clear()  # Reset aggregation for new epoch

        for task in tqdm(self.tasks, desc=f"{self.__class__.__name__}"):
            # Handle species-specific tasks
            if task in ["protein_abundance", "transcript_abundance"]:
                if not self.species:
                    continue
                for species in self.species:
                    task_key = f"{task}_{species}"
                    try:
                        train_dataset, test_dataset = self._create_split_datasets(task, species)
                        self._evaluate_task(task_key, task, train_dataset, test_dataset, trainer, pl_module)
                    except Exception as e:
                        print(f"Error processing {task_key}: {str(e)}")
            else:
                try:
                    train_dataset, test_dataset = self._create_split_datasets(task)
                    self._evaluate_task(task, task, train_dataset, test_dataset, trainer, pl_module)
                except Exception as e:
                    print(f"Error processing {task}: {str(e)}")

        # Calculate and log aggregate metrics
        for metric_name, values in self.aggregate_metrics.items():
            if values:  # Only log if we have values
                avg_value = sum(values) / len(values)
                trainer.logger.log_metrics(
                    {f"calm_linear_probe/mean/{metric_name}": avg_value}, step=trainer.current_epoch
                )
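A minimal usage sketch (illustrative; assumes a LightningModule compatible with the base callback's _get_embeddings and an existing datamodule):

import lightning as L

from lobster.callbacks import CalmLinearProbeCallback

callback = CalmLinearProbeCallback(
    max_length=1024,
    tasks=["meltome", "protein_abundance"],
    species=["hsapiens", "ecoli"],
    run_every_n_epochs=5,
)
trainer = L.Trainer(max_epochs=50, callbacks=[callback])
# trainer.fit(model, datamodule)  # probe metrics are logged under "calm_linear_probe/..."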
68 changes: 48 additions & 20 deletions src/lobster/callbacks/_linear_probe_callback.py
@@ -1,15 +1,17 @@
from typing import Callable, Dict, Literal, Optional, Tuple

import lightning as L
import numpy as np
import torch
from beignet.transforms import Transform
from lightning.pytorch.callbacks import Callback
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from torch import Tensor
from torch.utils.data import DataLoader
from torchmetrics import AUROC, Accuracy, F1Score, MeanSquaredError, R2Score, SpearmanCorrCoef

TaskType = Literal["regression", "binary", "multiclass", "multilabel"]
[Collaborator comment] is multilabel different from multiclass?
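For context, they do differ: multiclass assigns exactly one of N classes per sample, while multilabel allows any subset of N labels, so targets are multi-hot vectors rather than single class indices. A minimal torchmetrics sketch of the two target shapes (illustrative only):

import torch
from torchmetrics import Accuracy

# Multiclass: each sample gets exactly one of num_classes classes (integer targets)
multiclass_acc = Accuracy(task="multiclass", num_classes=4)
multiclass_acc(torch.tensor([0, 2, 1]), torch.tensor([0, 2, 3]))

# Multilabel: each sample may carry any subset of num_labels labels (multi-hot targets)
multilabel_acc = Accuracy(task="multilabel", num_labels=4)
multilabel_acc(torch.tensor([[0.9, 0.1, 0.8, 0.2]]), torch.tensor([[1, 0, 1, 0]]))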



class LinearProbeCallback(Callback):
@@ -31,21 +33,36 @@ def __init__(
        self.run_every_n_epochs = run_every_n_epochs

        # Initialize metrics based on task type
        self._set_metrics(task_type, num_classes)

        # Dictionary to store trained probes
        self.probes: Dict[str, LinearRegression | LogisticRegression] = {}

    def _set_metrics(self, task_type: TaskType, num_classes: Optional[int] = None) -> None:
        """Initialize metrics based on task type."""
        if task_type == "regression":
            self.mse = MeanSquaredError()
            self.r2 = R2Score()
            self.spearman = SpearmanCorrCoef()
            self.accuracy = None
            self.f1 = None
            self.auroc = None

        elif task_type in {"binary", "multiclass", "multilabel"}:
            # For multilabel, we use num_classes as num_labels; passing both kwargs
            # lets the torchmetrics wrappers pick whichever the task requires
            self.accuracy = Accuracy(task=task_type, num_classes=num_classes, num_labels=num_classes)
            self.f1 = F1Score(task=task_type, num_classes=num_classes, num_labels=num_classes)
            self.auroc = AUROC(task=task_type, num_classes=num_classes, num_labels=num_classes)
            self.mse = None
            self.r2 = None
            self.spearman = None

        else:
            raise ValueError("task_type must be: regression, binary, multiclass, or multilabel")

        self.task_type = task_type
        self.num_classes = num_classes

    def _skip(self, trainer: L.Trainer) -> bool:
        """Determine if we should skip validation this epoch."""
@@ -88,34 +105,45 @@ def _train_probe(self, embeddings: Tensor, targets: Tensor):

if self.task_type == "regression":
probe = LinearRegression()
else:
probe.fit(embeddings, targets)

elif self.task_type == "multilabel":
base_classifier = LogisticRegression(random_state=42)
probe = MultiOutputClassifier(base_classifier)
probe.fit(embeddings, targets)

else: # binary or multiclass
probe = LogisticRegression(
multi_class="ovr" if self.task_type == "binary" else "multinomial",
random_state=42,
)

probe.fit(embeddings, targets)
probe.fit(embeddings, targets.ravel())

return probe
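For context, MultiOutputClassifier fits one independent binary LogisticRegression per label column; the per-estimator predict_proba stacking in _evaluate_probe below relies on this. A small self-contained check with synthetic shapes (illustrative only):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

X = np.random.rand(100, 32)            # 100 samples, 32-dim embeddings
Y = np.random.randint(0, 2, (100, 4))  # 4 binary labels per sample (multi-hot)

probe = MultiOutputClassifier(LogisticRegression(random_state=42)).fit(X, Y)
assert len(probe.estimators_) == 4     # one fitted classifier per label
probs = np.stack([est.predict_proba(X)[:, 1] for est in probe.estimators_], axis=1)
assert probs.shape == (100, 4)         # per-label P(label == 1)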

    def _evaluate_probe(self, probe, embeddings: Tensor, targets: Tensor) -> Dict[str, float]:
        """Evaluate a trained probe using task-appropriate metrics."""
        embeddings_np = embeddings.numpy()  # Convert to numpy for probe prediction
        metrics = {}

        if self.task_type == "regression":
            predictions_np = probe.predict(embeddings_np)
            predictions = torch.from_numpy(predictions_np).float()

            metrics["mse"] = self.mse(predictions, targets).item()
            metrics["r2"] = self.r2(predictions, targets).item()
            metrics["spearman"] = self.spearman(predictions.squeeze(), targets.squeeze()).item()

        else:  # binary, multiclass, or multilabel
            if self.task_type == "multilabel":
                # Get probabilities for each label
                predictions_np = np.stack(
                    [est.predict_proba(embeddings_np)[:, 1] for est in probe.estimators_], axis=1
                )
            else:  # binary or multiclass
                predictions_np = probe.predict_proba(embeddings_np)
                if self.task_type == "binary":
                    predictions_np = predictions_np[:, 1]

            predictions = torch.from_numpy(predictions_np).float()
            metrics["accuracy"] = self.accuracy(predictions, targets).item()
            metrics["f1"] = self.f1(predictions, targets).item()
            metrics["auroc"] = self.auroc(predictions, targets).item()
3 changes: 2 additions & 1 deletion src/lobster/constants/__init__.py
@@ -1,3 +1,4 @@
from ._calm_tasks import CALM_TASKS
from ._moleculeace_tasks import MOLECULEACE_TASKS

__all__ = ["MOLECULEACE_TASKS"]
__all__ = ["CALM_TASKS", "MOLECULEACE_TASKS"]
84 changes: 84 additions & 0 deletions src/lobster/constants/_calm_tasks.py
@@ -0,0 +1,84 @@
from enum import Enum

CALM_TASKS = {
    "meltome": ("regression", None),  # (task_type, num_classes)
    "solubility": ("regression", None),
    "localization": ("multilabel", 10),  # 10 cellular locations
    "protein_abundance": ("regression", None),
    "transcript_abundance": ("regression", None),
    "function_bp": ("multilabel", 4),  # 4 GO terms
    "function_cc": ("multilabel", 4),
    "function_mf": ("multilabel", 4),
}
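Each entry maps a task name to the (task_type, num_classes) pair that _evaluate_task unpacks, for example:

task_type, num_classes = CALM_TASKS["localization"]
assert task_type == "multilabel" and num_classes == 10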


CALM_DATA_GITHUB_URL = "https://raw.githubusercontent.com/oxpig/CaLM/main/data"
FUNCTION_ZENODO_BASE_URL = (
    "https://zenodo.org/records/14890750/files"  # Gene Ontology datasets processed & uploaded on Zenodo
)

[Collaborator comment] switch to HF

[Collaborator comment] I was going to write something similar about CALM_DATA_GITHUB_URL -- is this a subset of the calm dataset from Nathan's HF account?

[Collaborator comment] if it's in HF, we can also load it more easily with datahaha once we switch


class Species(str, Enum):
    ATHALIANA = "athaliana"
    DMELANOGASTER = "dmelanogaster"
    ECOLI = "ecoli"
    HSAPIENS = "hsapiens"
    HVOLCANII = "hvolcanii"
    PPASTORIS = "ppastoris"
    SCEREVISIAE = "scerevisiae"


class Task(Enum):
    MELTOME = "meltome"
    SOLUBILITY = "solubility"
    LOCALIZATION = "localization"
    PROTEIN_ABUNDANCE = "protein_abundance"
    TRANSCRIPT_ABUNDANCE = "transcript_abundance"
    FUNCTION_BP = "function_bp"
    FUNCTION_CC = "function_cc"
    FUNCTION_MF = "function_mf"


TASK_SPECIES = {
    Task.PROTEIN_ABUNDANCE: [
        Species.ATHALIANA,
        Species.DMELANOGASTER,
        Species.ECOLI,
        Species.HSAPIENS,
        Species.SCEREVISIAE,
    ],
    Task.TRANSCRIPT_ABUNDANCE: [
        Species.ATHALIANA,
        Species.DMELANOGASTER,
        Species.ECOLI,
        Species.HSAPIENS,
        Species.HVOLCANII,
        Species.PPASTORIS,
        Species.SCEREVISIAE,
    ],
}

# File hashes to check that upstream data files haven't changed; makes data download cleaner
[Collaborator comment] One thing I don't love about hashes is that pooch currently just errors out when the files are different. Maybe we'd want to add custom logic where it just downloads new files instead. Could be something for a future MR though

FILE_HASHES = {
    "meltome.csv": "sha256:699074debc9e3d66e0c084bca594ce81d26b3126b645d43b0597dbe466153ad4",
    "solubility.csv": "sha256:94b351d0f36b490423b3e80b2ff0ea5114423165948183157cf63d4f57c08d38",
    "localization.csv": "sha256:efedb7c394b96f4569c72d03eac54ca5a3b4a24e15c9c24d9429f9b1a4e29320",
    # Function prediction tasks
    "calm_GO_bp_middle_normal.parquet": "md5:898265de59ba1ac97270bffc3621f334",
    "calm_GO_cc_middle_normal.parquet": "md5:a6af91fe40e523c9adf47e6abd98d9c6",
    "calm_GO_mf_middle_normal.parquet": "md5:cafe14db5dda19837bae536399e47e35",
    # Protein abundance
    "protein_abundance_athaliana.csv": "sha256:83f8d995ee3a0ff6f1ed4e74a9cb891546e2edb6e86334fef3b6901a0039b118",
    "protein_abundance_dmelanogaster.csv": "sha256:6f9541d38217f71a4f9addec4ad567d60eee4e4cebb1d16775909b88e1775da4",
    "protein_abundance_ecoli.csv": "sha256:a6a8f91901a4ea4da62931d1e7c91b3a6aa72e4d6c6a83a6612c0988e94421fb",
    "protein_abundance_hsapiens.csv": "sha256:94ded0486f2f64575bd2d8f2a3e00611a6e8b28b691d0f367ca9210058771a23",
    "protein_abundance_scerevisiae.csv": "sha256:0ce0b6a5b0400c3cc1c568f6c5478a974e80aaecbab93299f81bb94eb2373076",
    # Transcript abundance
    "transcript_abundance_athaliana.csv": "sha256:de7a6f57bcfbb60445d17b8461a8a3ea8353942e129f08ac2c6de5874cd6c139",
    "transcript_abundance_dmelanogaster.csv": "sha256:0124d101da004e7a66f4303ff097da39d5e4dd474548fa57e2f9fa7231544c32",
    "transcript_abundance_ecoli.csv": "sha256:5e480d961c8b9043f6039211241ecf904b188214e9951352a9b2fc3d6a630a59",
    "transcript_abundance_hsapiens.csv": "sha256:21b4b3f3f7267d28dbf6434090dfc0c58fde6f15393537d569f0b29e3eeec491",
    "transcript_abundance_hvolcanii.csv": "sha256:91782d2839f78b7c3a4c4d2c0f685605fa746e9b3936579fbd91ce875f9860aa",
    "transcript_abundance_ppastoris.csv": "sha256:4ebd4783e1b90e76e481c25bce801d4f6984f85d382f5d458d26f554e114798a",
    "transcript_abundance_scerevisiae.csv": "sha256:2e0f3b4c0cee77f47ab4009be881f671b709df368495f92dad66f34b2c88ac36",
}
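One possible shape for the fallback logic floated in the review comment above; a sketch only, not part of this PR, assuming pooch raises ValueError on a hash mismatch and that fetch_calm_file is a hypothetical helper:

import pooch

def fetch_calm_file(filename: str, url: str, cache_dir: str) -> str:
    """Fetch a CALM data file, re-downloading unverified if the pinned hash is stale."""
    try:
        return pooch.retrieve(url=url, known_hash=FILE_HASHES.get(filename), fname=filename, path=cache_dir)
    except ValueError:
        # Upstream file changed: fall back to an unverified download instead of erroring out
        return pooch.retrieve(url=url, known_hash=None, fname=filename, path=cache_dir)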
2 changes: 2 additions & 0 deletions src/lobster/datasets/__init__.py
@@ -1,4 +1,5 @@
from ._calm_dataset import CalmDataset
from ._calm_property_dataset import CalmPropertyDataset
from ._fasta_dataset import FASTADataset
from ._m3_20m_dataset import M320MDataset
from ._moleculeace_dataset import MoleculeACEDataset
@@ -7,6 +8,7 @@
"M320MDataset",
"ChEMBLDataset",
"CalmDataset",
"CalmPropertyDataset",
"FASTADataset",
"MoleculeACEDataset",
]