From 825743b5df290acfb83db354f9018cd11c17ca25 Mon Sep 17 00:00:00 2001 From: Ed Wagstaff Date: Thu, 6 Feb 2025 18:09:49 +0000 Subject: [PATCH 1/7] FlexBERT tweaks --- src/lobster/model/modern_bert/_modern_bert.py | 76 +++++-------------- 1 file changed, 18 insertions(+), 58 deletions(-) diff --git a/src/lobster/model/modern_bert/_modern_bert.py b/src/lobster/model/modern_bert/_modern_bert.py index c9d8ae5..63d21f1 100644 --- a/src/lobster/model/modern_bert/_modern_bert.py +++ b/src/lobster/model/modern_bert/_modern_bert.py @@ -1,18 +1,12 @@ import importlib.resources from importlib.util import find_spec -from typing import Literal, Union - import lightning.pytorch as pl import torch from torch import nn +from lobster.tokenization._pmlm_tokenizer import PmlmTokenizer from transformers.optimization import get_linear_schedule_with_warmup -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast - -from lobster.tokenization import PmlmTokenizer, SmilesTokenizerFast, AminoAcidTokenizerFast, NucleotideTokenizerFast -from lobster.tokenization._pmlm_tokenizer_transform import \ - PmlmTokenizerTransform -from lobster.transforms import TokenizerTransform +from lobster.tokenization._pmlm_tokenizer_transform import PmlmTokenizerTransform from ._config import FlexBertConfig from ._model import FlexBertModel, FlexBertPredictionHead @@ -20,13 +14,12 @@ if find_spec("flash_attn"): from flash_attn.losses.cross_entropy import CrossEntropyLoss - _FLASH_ATTN_AVAILABLE = True else: from torch.nn import CrossEntropyLoss - class FlexBERT(pl.LightningModule): + def __init__( self, lr: float = 1e-3, @@ -35,8 +28,8 @@ def __init__( eps: float = 1e-12, num_training_steps: int = 10_000, num_warmup_steps: int = 1_000, - tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast] = "amino_acid_tokenizer", - mask_percentage: float = 0.25, + tokenizer_dir: str = "pmlm_tokenizer", + mask_percentage: float = 0.15, max_length: int = 512, **model_kwargs, ): @@ -50,46 +43,18 @@ def __init__( self._mask_percentage = mask_percentage self.max_length = max_length - # TODO zadorozk: currently only accepts one tokenizer at a time - # Extend to accept multiple tokenizers for each modality - if isinstance(tokenizer, str): - if tokenizer == "pmlm_tokenizer": - path = importlib.resources.files("lobster") / "assets" / "plm_tokenizer" - tokenizer = PmlmTokenizer.from_pretrained(path, do_lower_case=False) - tokenizer_transform_class = PmlmTokenizerTransform - - elif tokenizer == "amino_acid_tokenizer": - tokenizer = AminoAcidTokenizerFast() - tokenizer_transform_class = TokenizerTransform - - elif tokenizer == "nucleotide_tokenizer": - tokenizer = NucleotideTokenizerFast() - tokenizer_transform_class = TokenizerTransform - - elif tokenizer == "smiles_tokenizer": - tokenizer = SmilesTokenizerFast() - tokenizer_transform_class = TokenizerTransform - else: - raise NotImplementedError(f"Tokenizer `{tokenizer}` not supported") - else: - if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): - raise ValueError("Custom `tokenizer` must be an instance of `PreTrainedTokenizer` or `PreTrainedTokenizerFast`") - - tokenizer = tokenizer - tokenizer_transform_class = TokenizerTransform - - self.tokenizer = tokenizer - - self.tokenize_transform = tokenizer_transform_class( - tokenizer, - max_length=max_length, + path = importlib.resources.files("lobster") / "assets" / tokenizer_dir + self.tokenizer: PmlmTokenizer = PmlmTokenizer.from_pretrained(path, do_lower_case=False) + self.tokenize_transform = 
PmlmTokenizerTransform( + path, padding="max_length", - truncation=True - ) + max_length=self.max_length, + truncation=True, + ) self.config = FlexBertConfig( vocab_size=self.tokenizer.vocab_size, - pad_token_id=self.tokenizer.pad_token_id, + pad_token_id = self.tokenizer.pad_token_id, **model_kwargs, ) self.model = FlexBertModel(self.config) @@ -163,27 +128,22 @@ def tokens_to_latents(self, input_ids: torch.Tensor, attention_mask: torch.Tenso def sequences_to_latents(self, sequences: list[str]) -> list[torch.Tensor]: transformed_sequences = self.tokenize_transform(sequences) input_ids = torch.concat([batch["input_ids"].squeeze(0) for batch in transformed_sequences]).to(self.device) - attention_mask = torch.concat([batch["attention_mask"].squeeze(0) for batch in transformed_sequences]).to( - self.device - ) + attention_mask = torch.concat([batch["attention_mask"].squeeze(0) for batch in transformed_sequences]).to(self.device) seqlens = [batch["input_ids"].size(1) for batch in transformed_sequences] - cu_seqlens = torch.tensor([sum(seqlens[:i]) for i in range(len(seqlens) + 1)], dtype=torch.int32).to( - self.device - ) + cu_seqlens = torch.tensor([sum(seqlens[:i]) for i in range(len(seqlens) + 1)], dtype=torch.int32).to(self.device) with torch.inference_mode(): hidden_states = self.model(input_ids, attention_mask=attention_mask, cu_seqlens=cu_seqlens, max_seqlen=self.max_length) - return [hidden_states[cu_seqlens[i] : cu_seqlens[i + 1]] for i in range(len(cu_seqlens) - 1)] + return [hidden_states[cu_seqlens[i]:cu_seqlens[i+1]] for i in range(len(cu_seqlens) - 1)] def _compute_loss(self, batch): - if isinstance(batch, tuple) and len(batch) == 2: - batch, _targets = batch + batch, _targets = batch tokens = batch["input_ids"].squeeze(1) B, length = tokens.shape tokens = tokens.view(-1) - attention_mask = batch["attention_mask"].squeeze(1).view(-1) + attention_mask=batch["attention_mask"].squeeze(1).view(-1) labels = tokens.clone() From 5d0eecce3ccaedce22868bcd8be6d00150234355 Mon Sep 17 00:00:00 2001 From: Ed Wagstaff Date: Thu, 6 Feb 2025 18:30:24 +0000 Subject: [PATCH 2/7] ruff --- tests/lobster/model/test__lobsterfold.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/lobster/model/test__lobsterfold.py b/tests/lobster/model/test__lobsterfold.py index 87ecf0b..3515202 100644 --- a/tests/lobster/model/test__lobsterfold.py +++ b/tests/lobster/model/test__lobsterfold.py @@ -1,14 +1,15 @@ import os -from io import StringIO import pytest import torch -from Bio.PDB import PDBParser, Superimposer from lobster.data import PDBDataModule from lobster.extern.openfold_utils import backbone_loss from lobster.model import LobsterPLMFold from lobster.transforms import StructureFeaturizer from torch import Size, Tensor +from Bio.PDB import PDBParser, Superimposer +from io import StringIO + torch.backends.cuda.matmul.allow_tf32 = True @@ -28,7 +29,7 @@ def example_fv(): @pytest.fixture -def example_fv_pdb(scope="session"): +def example_fv(scope="session"): return os.path.join(os.path.dirname(__file__), "../../../test_data/fv.pdb") @@ -79,10 +80,6 @@ def test_dataloader_tokenizer(self, model): def test_predict_fv(self, model, example_fv): pdb_string = model.predict_fv(example_fv[0], example_fv[1]) - # NOTE from zadorozk: ruff checks were failing because ground_truth_file was not defined - # TODO FIXME - ground_truth_file = None - # Parse the input PDB string parser = PDBParser(QUIET=True) structure1 = parser.get_structure("pdb_string_structure", StringIO(pdb_string)) From 
68bb4fd32fd685d53dac66041b43670b66bcbf9d Mon Sep 17 00:00:00 2001 From: Taylor Joren Date: Wed, 19 Feb 2025 06:01:36 +0000 Subject: [PATCH 3/7] add calm tasks melt, solu, cell_loc, function --- src/lobster/constants/_calm_tasks.py | 62 ++++++ src/lobster/datasets/__init__.py | 2 + .../datasets/_calm_property_dataset.py | 181 ++++++++++++++++++ 3 files changed, 245 insertions(+) create mode 100644 src/lobster/constants/_calm_tasks.py create mode 100644 src/lobster/datasets/_calm_property_dataset.py diff --git a/src/lobster/constants/_calm_tasks.py b/src/lobster/constants/_calm_tasks.py new file mode 100644 index 0000000..465e6a7 --- /dev/null +++ b/src/lobster/constants/_calm_tasks.py @@ -0,0 +1,62 @@ +from enum import Enum + +# TODO - add to __init__ + +CALM_DATA_GITHUB_URL = "https://raw.githubusercontent.com/oxpig/CaLM/main/data" + +class Species(str, Enum): + ATHALIANA = "athaliana" + DMELANOGASTER = "dmelanogaster" + ECOLI = "ecoli" + HSAPIENS = "hsapiens" + HVOLCANII = "hvolcanii" + PPASTORIS = "ppastoris" + SCEREVISIAE = "scerevisiae" + +class Task(Enum): + MELTOME = "meltome" + SOLUBILITY = "solubility" + LOCALIZATION = "localization" + PROTEIN_ABUNDANCE = "protein_abundance" + TRANSCRIPT_ABUNDANCE = "transcript_abundance" + SPECIES = "species" + FUNCTION_BP = "function_bp" # Separate task for each function type + FUNCTION_CC = "function_cc" + FUNCTION_MF = "function_mf" + +FUNCTION_ZENODO_BASE_URL = "https://zenodo.org/records/14890750/files" +FUNCTION_HASHES = { + "function_bp": "md5:898265de59ba1ac97270bffc3621f334", + "function_cc": "md5:a6af91fe40e523c9adf47e6abd98d9c6", + "function_mf": "md5:cafe14db5dda19837bae536399e47e35" +} + + +TASK_SPECIES = { + Task.PROTEIN_ABUNDANCE: [ + Species.ATHALIANA, + Species.DMELANOGASTER, + Species.ECOLI, + Species.HSAPIENS, + Species.SCEREVISIAE, + ], + Task.TRANSCRIPT_ABUNDANCE: [ + Species.ATHALIANA, + Species.DMELANOGASTER, + Species.ECOLI, + Species.HSAPIENS, + Species.HVOLCANII, + Species.PPASTORIS, + Species.SCEREVISIAE, + ], + Task.SPECIES: [ + Species.ATHALIANA, + Species.DMELANOGASTER, + Species.ECOLI, + Species.HSAPIENS, + Species.HVOLCANII, + Species.PPASTORIS, + Species.SCEREVISIAE, + ] +} + diff --git a/src/lobster/datasets/__init__.py b/src/lobster/datasets/__init__.py index 9f28151..8fcee47 100644 --- a/src/lobster/datasets/__init__.py +++ b/src/lobster/datasets/__init__.py @@ -1,4 +1,5 @@ from ._calm_dataset import CalmDataset +from ._calm_property_dataset import CalmPropertyDataset from ._fasta_dataset import FASTADataset from ._m3_20m_dataset import M320MDataset from ._moleculeace_dataset import MoleculeACEDataset @@ -7,6 +8,7 @@ "M320MDataset", "ChEMBLDataset", "CalmDataset", + "CalmPropertyDataset", "FASTADataset", "MoleculeACEDataset", ] diff --git a/src/lobster/datasets/_calm_property_dataset.py b/src/lobster/datasets/_calm_property_dataset.py new file mode 100644 index 0000000..cb0b092 --- /dev/null +++ b/src/lobster/datasets/_calm_property_dataset.py @@ -0,0 +1,181 @@ +from enum import Enum +from pathlib import Path +import pandas as pd +import pooch +from torch.utils.data import Dataset +from typing import Optional, Literal, Callable, Sequence, Tuple + +from lobster.transforms import Transform +from lobster.constants._calm_tasks import CALM_DATA_GITHUB_URL, Species, Task, TASK_SPECIES, FUNCTION_HASHES, FUNCTION_ZENODO_BASE_URL + + +class CalmPropertyDataset(Dataset): + """ + Dataset from Outeiral, C., Deane, C.M. with additional function prediction support. 
+ Each function type (BP, CC, MF) is treated as a separate task. + """ + + def __init__( + self, + task: Task | str, + root: str | Path | None = None, + *, + species: Optional[Species | str] = None, + split: Optional[Literal["train", "validation", "test"]] = None, + download: bool = True, + transform: Optional[Callable | Transform] = None, + columns: Optional[Sequence[str]] = None, + known_hash: Optional[str] = None, + ): + super().__init__() + + if isinstance(task, str): + task = Task(task) + if isinstance(species, str): + species = Species(species) + + self.task = task + self.species = species + self.split = split + self.transform = transform + + if root is None: + root = pooch.os_cache("lbster") + if isinstance(root, str): + root = Path(root) + self.root = root.resolve() + + # Determine file name and URL based on task type + if task in [Task.FUNCTION_BP, Task.FUNCTION_CC, Task.FUNCTION_MF]: + function_type = task.value.split('_')[1].lower() # Extract bp, cc, or mf + fname = f"calm_GO_{function_type}_middle_normal.parquet" + url = f"{FUNCTION_ZENODO_BASE_URL}/{fname}" + storage_fname = fname + + # Use predefined hash if none provided + if known_hash is None: + known_hash = FUNCTION_HASHES[task.value] + + else: + if task in [Task.MELTOME, Task.SOLUBILITY, Task.LOCALIZATION]: + fname = f"{task.value}/{task.value}_data.csv" + url = f"{CALM_DATA_GITHUB_URL}/{fname}" + elif task in [Task.PROTEIN_ABUNDANCE, Task.TRANSCRIPT_ABUNDANCE]: + if species is None: + raise ValueError(f"Must specify species for {task.value} task") + if species not in TASK_SPECIES[task]: + raise ValueError(f"Species {species.value} not available for {task.value} task") + fname = f"{task.value}/{species.value}.csv" + url = f"{CALM_DATA_GITHUB_URL}/{task.value}/{fname}" + else: # species task + if split is None and species is None: + raise ValueError("Must specify either split or species for species task") + if split is not None: + fname = f"{split}.fasta" + url = f"{CALM_DATA_GITHUB_URL}/{task.value}/{split}/{fname}" + else: + fname = f"{species.value}.fasta" + url = f"{CALM_DATA_GITHUB_URL}/{task.value}/{fname}" + + storage_fname = f"{self.__class__.__name__}_{task.value}" + if species: + storage_fname += f"_{species.value}" + if split: + storage_fname += f"_{split}" + storage_fname += Path(fname).suffix + + # Download or load the file + file_path = Path(self.root / self.__class__.__name__ / storage_fname) + if download: + file_path = pooch.retrieve( + url=url, + fname=storage_fname, + known_hash=known_hash, + path=self.root / self.__class__.__name__, + progressbar=True, + ) + elif not file_path.exists(): + raise FileNotFoundError( + f"Data file {file_path} not found and download=False" + ) + + # Load the data based on file type + if str(file_path).endswith('.fasta'): + self.data = parse_fasta(file_path) + if species: + self.data['species'] = species.value + elif str(file_path).endswith('.parquet'): + self.data = pd.read_parquet(file_path) + elif str(file_path).endswith('.tsv'): + self.data = pd.read_csv(file_path, sep='\t') + else: + self.data = pd.read_csv(file_path) + + if columns is None: + if task == Task.FUNCTION_BP: + columns = ['sequence', "GO:0051092", "GO:0016573", "GO:0031146", "GO:0071427", "GO:0006613"] + elif task == Task.FUNCTION_CC: + columns = ['sequence', "GO:0022627", "GO:0000502", "GO:0034705", "GO:0030665", "GO:0005925"] + elif task == Task.FUNCTION_MF: + columns = ['sequence', "GO:0004843", "GO:0004714", "GO:0003774", "GO:0008227", "GO:0004866"] + elif task == Task.SPECIES: + columns = ["sequence", 
"description"] + if species: + columns.append("species") + elif task == Task.LOCALIZATION: + columns = ['Sequence', 'Cell membrane', 'Cytoplasm', 'Endoplasmic reticulum', + 'Extracellular', 'Golgi apparatus', 'Lysosome/Vacuole', + 'Mitochondrion', 'Nucleus', 'Peroxisome', 'Plastid'] + elif task == Task.MELTOME: + columns = ['sequence', 'melting_temperature'] + elif task == Task.SOLUBILITY: + columns = ['cds', 'solubility'] + else: + columns = list(self.data.columns) + + self.columns = columns + self._x = self.data[self.columns].apply(tuple, axis=1) + + def __getitem__(self, index: int) -> Tuple: + x = self._x[index] + + if len(x) == 1: + x = x[0] + + if self.transform is not None: + x = self.transform(x) + + return x + + def __len__(self) -> int: + return len(self._x) + + +def parse_fasta(path: Path) -> pd.DataFrame: + """Parse a FASTA file into a DataFrame with sequence and description columns.""" + sequences = [] + descriptions = [] + + with open(path) as f: + current_description = None + current_sequence = [] + + for line in f: + line = line.strip() + if line.startswith('>'): + if current_description is not None: + sequences.append(''.join(current_sequence)) + descriptions.append(current_description) + current_sequence = [] + current_description = line[1:] + else: + current_sequence.append(line) + + if current_description is not None: + sequences.append(''.join(current_sequence)) + descriptions.append(current_description) + + return pd.DataFrame({ + 'sequence': sequences, + 'description': descriptions + }) From ee11722b8cb1f5f6f568ebdb6a47107d9e88adef Mon Sep 17 00:00:00 2001 From: Taylor Joren Date: Wed, 19 Feb 2025 18:28:32 +0000 Subject: [PATCH 4/7] add complete calm property dataset --- src/lobster/constants/_calm_tasks.py | 48 +++++++---- .../datasets/_calm_property_dataset.py | 86 ++++--------------- 2 files changed, 44 insertions(+), 90 deletions(-) diff --git a/src/lobster/constants/_calm_tasks.py b/src/lobster/constants/_calm_tasks.py index 465e6a7..01bd170 100644 --- a/src/lobster/constants/_calm_tasks.py +++ b/src/lobster/constants/_calm_tasks.py @@ -3,6 +3,34 @@ # TODO - add to __init__ CALM_DATA_GITHUB_URL = "https://raw.githubusercontent.com/oxpig/CaLM/main/data" +FUNCTION_ZENODO_BASE_URL = "https://zenodo.org/records/14890750/files" # Gene Ontology datasets processed & uploaded on Zenodo +FILE_HASHES = { + "meltome.csv": "sha256:699074debc9e3d66e0c084bca594ce81d26b3126b645d43b0597dbe466153ad4", + "solubility.csv": "sha256:94b351d0f36b490423b3e80b2ff0ea5114423165948183157cf63d4f57c08d38", + "localization.csv": "sha256:efedb7c394b96f4569c72d03eac54ca5a3b4a24e15c9c24d9429f9b1a4e29320", + + # Function prediction tasks + "calm_GO_bp_middle_normal.parquet": "md5:898265de59ba1ac97270bffc3621f334", + "calm_GO_cc_middle_normal.parquet": "md5:a6af91fe40e523c9adf47e6abd98d9c6", + "calm_GO_mf_middle_normal.parquet": "md5:cafe14db5dda19837bae536399e47e35", + + # Protein abundance + "protein_abundance_athaliana.csv": "sha256:83f8d995ee3a0ff6f1ed4e74a9cb891546e2edb6e86334fef3b6901a0039b118", + "protein_abundance_dmelanogaster.csv": "sha256:6f9541d38217f71a4f9addec4ad567d60eee4e4cebb1d16775909b88e1775da4", + "protein_abundance_ecoli.csv": "sha256:a6a8f91901a4ea4da62931d1e7c91b3a6aa72e4d6c6a83a6612c0988e94421fb", + "protein_abundance_hsapiens.csv": "sha256:94ded0486f2f64575bd2d8f2a3e00611a6e8b28b691d0f367ca9210058771a23", + "protein_abundance_scerevisiae.csv": "sha256:0ce0b6a5b0400c3cc1c568f6c5478a974e80aaecbab93299f81bb94eb2373076", + + # Transcript abundance + 
"transcript_abundance_athaliana.csv": "sha256:de7a6f57bcfbb60445d17b8461a8a3ea8353942e129f08ac2c6de5874cd6c139", + "transcript_abundance_dmelanogaster.csv": "sha256:0124d101da004e7a66f4303ff097da39d5e4dd474548fa57e2f9fa7231544c32", + "transcript_abundance_ecoli.csv": "sha256:5e480d961c8b9043f6039211241ecf904b188214e9951352a9b2fc3d6a630a59", + "transcript_abundance_hsapiens.csv": "sha256:21b4b3f3f7267d28dbf6434090dfc0c58fde6f15393537d569f0b29e3eeec491", + "transcript_abundance_hvolcanii.csv": "sha256:91782d2839f78b7c3a4c4d2c0f685605fa746e9b3936579fbd91ce875f9860aa", + "transcript_abundance_ppastoris.csv": "sha256:4ebd4783e1b90e76e481c25bce801d4f6984f85d382f5d458d26f554e114798a", + "transcript_abundance_scerevisiae.csv": "sha256:2e0f3b4c0cee77f47ab4009be881f671b709df368495f92dad66f34b2c88ac36" +} + class Species(str, Enum): ATHALIANA = "athaliana" @@ -19,19 +47,10 @@ class Task(Enum): LOCALIZATION = "localization" PROTEIN_ABUNDANCE = "protein_abundance" TRANSCRIPT_ABUNDANCE = "transcript_abundance" - SPECIES = "species" - FUNCTION_BP = "function_bp" # Separate task for each function type + FUNCTION_BP = "function_bp" FUNCTION_CC = "function_cc" FUNCTION_MF = "function_mf" -FUNCTION_ZENODO_BASE_URL = "https://zenodo.org/records/14890750/files" -FUNCTION_HASHES = { - "function_bp": "md5:898265de59ba1ac97270bffc3621f334", - "function_cc": "md5:a6af91fe40e523c9adf47e6abd98d9c6", - "function_mf": "md5:cafe14db5dda19837bae536399e47e35" -} - - TASK_SPECIES = { Task.PROTEIN_ABUNDANCE: [ Species.ATHALIANA, @@ -48,15 +67,6 @@ class Task(Enum): Species.HVOLCANII, Species.PPASTORIS, Species.SCEREVISIAE, - ], - Task.SPECIES: [ - Species.ATHALIANA, - Species.DMELANOGASTER, - Species.ECOLI, - Species.HSAPIENS, - Species.HVOLCANII, - Species.PPASTORIS, - Species.SCEREVISIAE, ] } diff --git a/src/lobster/datasets/_calm_property_dataset.py b/src/lobster/datasets/_calm_property_dataset.py index cb0b092..0a824e9 100644 --- a/src/lobster/datasets/_calm_property_dataset.py +++ b/src/lobster/datasets/_calm_property_dataset.py @@ -6,15 +6,10 @@ from typing import Optional, Literal, Callable, Sequence, Tuple from lobster.transforms import Transform -from lobster.constants._calm_tasks import CALM_DATA_GITHUB_URL, Species, Task, TASK_SPECIES, FUNCTION_HASHES, FUNCTION_ZENODO_BASE_URL +from lobster.constants._calm_tasks import CALM_DATA_GITHUB_URL, Species, Task, TASK_SPECIES, FUNCTION_ZENODO_BASE_URL, FILE_HASHES class CalmPropertyDataset(Dataset): - """ - Dataset from Outeiral, C., Deane, C.M. with additional function prediction support. - Each function type (BP, CC, MF) is treated as a separate task. 
- """ - def __init__( self, task: Task | str, @@ -51,40 +46,24 @@ def __init__( fname = f"calm_GO_{function_type}_middle_normal.parquet" url = f"{FUNCTION_ZENODO_BASE_URL}/{fname}" storage_fname = fname - - # Use predefined hash if none provided - if known_hash is None: - known_hash = FUNCTION_HASHES[task.value] - else: if task in [Task.MELTOME, Task.SOLUBILITY, Task.LOCALIZATION]: - fname = f"{task.value}/{task.value}_data.csv" - url = f"{CALM_DATA_GITHUB_URL}/{fname}" + fname = f"{task.value}_data.csv" + url = f"{CALM_DATA_GITHUB_URL}/{task.value}/{fname}" + storage_fname = f"{task.value}.csv" elif task in [Task.PROTEIN_ABUNDANCE, Task.TRANSCRIPT_ABUNDANCE]: if species is None: raise ValueError(f"Must specify species for {task.value} task") if species not in TASK_SPECIES[task]: raise ValueError(f"Species {species.value} not available for {task.value} task") - fname = f"{task.value}/{species.value}.csv" + fname = f"{species.value}.csv" url = f"{CALM_DATA_GITHUB_URL}/{task.value}/{fname}" - else: # species task - if split is None and species is None: - raise ValueError("Must specify either split or species for species task") - if split is not None: - fname = f"{split}.fasta" - url = f"{CALM_DATA_GITHUB_URL}/{task.value}/{split}/{fname}" - else: - fname = f"{species.value}.fasta" - url = f"{CALM_DATA_GITHUB_URL}/{task.value}/{fname}" + storage_fname = f"{task.value}_{species.value}.csv" - storage_fname = f"{self.__class__.__name__}_{task.value}" - if species: - storage_fname += f"_{species.value}" - if split: - storage_fname += f"_{split}" - storage_fname += Path(fname).suffix + # Get hash for the storage filename --> to ensure file has not changed + if known_hash is None and storage_fname in FILE_HASHES: + known_hash = FILE_HASHES[storage_fname] - # Download or load the file file_path = Path(self.root / self.__class__.__name__ / storage_fname) if download: file_path = pooch.retrieve( @@ -99,12 +78,7 @@ def __init__( f"Data file {file_path} not found and download=False" ) - # Load the data based on file type - if str(file_path).endswith('.fasta'): - self.data = parse_fasta(file_path) - if species: - self.data['species'] = species.value - elif str(file_path).endswith('.parquet'): + if str(file_path).endswith('.parquet'): self.data = pd.read_parquet(file_path) elif str(file_path).endswith('.tsv'): self.data = pd.read_csv(file_path, sep='\t') @@ -118,10 +92,6 @@ def __init__( columns = ['sequence', "GO:0022627", "GO:0000502", "GO:0034705", "GO:0030665", "GO:0005925"] elif task == Task.FUNCTION_MF: columns = ['sequence', "GO:0004843", "GO:0004714", "GO:0003774", "GO:0008227", "GO:0004866"] - elif task == Task.SPECIES: - columns = ["sequence", "description"] - if species: - columns.append("species") elif task == Task.LOCALIZATION: columns = ['Sequence', 'Cell membrane', 'Cytoplasm', 'Endoplasmic reticulum', 'Extracellular', 'Golgi apparatus', 'Lysosome/Vacuole', @@ -130,6 +100,10 @@ def __init__( columns = ['sequence', 'melting_temperature'] elif task == Task.SOLUBILITY: columns = ['cds', 'solubility'] + elif task == Task.PROTEIN_ABUNDANCE: + columns = ['cds', 'abundance'] + elif task == Task.TRANSCRIPT_ABUNDANCE: + columns = ['cds', 'logtpm'] else: columns = list(self.data.columns) @@ -148,34 +122,4 @@ def __getitem__(self, index: int) -> Tuple: return x def __len__(self) -> int: - return len(self._x) - - -def parse_fasta(path: Path) -> pd.DataFrame: - """Parse a FASTA file into a DataFrame with sequence and description columns.""" - sequences = [] - descriptions = [] - - with open(path) as f: 
- current_description = None - current_sequence = [] - - for line in f: - line = line.strip() - if line.startswith('>'): - if current_description is not None: - sequences.append(''.join(current_sequence)) - descriptions.append(current_description) - current_sequence = [] - current_description = line[1:] - else: - current_sequence.append(line) - - if current_description is not None: - sequences.append(''.join(current_sequence)) - descriptions.append(current_description) - - return pd.DataFrame({ - 'sequence': sequences, - 'description': descriptions - }) + return len(self._x) \ No newline at end of file From 6f3e07cf4b1de532a5e595dc10bbcdc7929868c8 Mon Sep 17 00:00:00 2001 From: Taylor Joren Date: Thu, 20 Feb 2025 19:18:17 +0000 Subject: [PATCH 5/7] add calm linear probe, multilabel evals, tests --- src/lobster/callbacks/__init__.py | 1 + .../callbacks/_calm_linear_probe_callback.py | 161 ++++++++++++++++++ .../callbacks/_linear_probe_callback.py | 86 +++++++--- src/lobster/constants/__init__.py | 5 +- src/lobster/constants/_calm_tasks.py | 68 +++++--- .../datasets/_calm_property_dataset.py | 49 ++++-- 6 files changed, 301 insertions(+), 69 deletions(-) create mode 100644 src/lobster/callbacks/_calm_linear_probe_callback.py diff --git a/src/lobster/callbacks/__init__.py b/src/lobster/callbacks/__init__.py index 59536fb..087676b 100644 --- a/src/lobster/callbacks/__init__.py +++ b/src/lobster/callbacks/__init__.py @@ -1,3 +1,4 @@ +from ._calm_linear_probe_callback import CalmLinearProbeCallback from ._linear_probe_callback import LinearProbeCallback from ._moleculeace_linear_probe_callback import MoleculeACELinearProbeCallback diff --git a/src/lobster/callbacks/_calm_linear_probe_callback.py b/src/lobster/callbacks/_calm_linear_probe_callback.py new file mode 100644 index 0000000..44cba58 --- /dev/null +++ b/src/lobster/callbacks/_calm_linear_probe_callback.py @@ -0,0 +1,161 @@ +from collections import defaultdict +from typing import Optional, Sequence, Tuple +import torch +from torch.utils.data import DataLoader, random_split, Subset +import numpy as np +import lightning as L +from tqdm import tqdm + +from lobster.constants import CALM_TASKS +from lobster.datasets import CalmPropertyDataset +from lobster.tokenization import NucleotideTokenizerFast +from lobster.transforms import TokenizerTransform +from ._linear_probe_callback import LinearProbeCallback + + +class CalmLinearProbeCallback(LinearProbeCallback): + """Callback for evaluating embedding models on the CALM dataset collection.""" + + def __init__( + self, + max_length: int, + tasks: Optional[Sequence[str]] = None, + species: Optional[Sequence[str]] = None, + batch_size: int = 32, + run_every_n_epochs: Optional[int] = None, + test_size: float = 0.2, + max_samples: int = 3000, + seed: int = 42, + ): + tokenizer_transform = TokenizerTransform( + tokenizer=NucleotideTokenizerFast(), + padding="max_length", + truncation=True, + max_length=max_length, + ) + + super().__init__( + transform_fn=tokenizer_transform, + task_type="regression", + batch_size=batch_size, + run_every_n_epochs=run_every_n_epochs, + ) + + self.tasks = set(tasks) if tasks else set(CALM_TASKS.keys()) + self.species = set(species) if species else None + + self.test_size = test_size + self.max_samples = max_samples + self.seed = seed + + self.dataset_splits = {} + self.aggregate_metrics = defaultdict(list) + + def _create_split_datasets( + self, + task: str, + species: Optional[str] = None + ) -> Tuple[CalmPropertyDataset, CalmPropertyDataset]: + """Create train/test 
splits for a given task.""" + + rng = np.random.RandomState(self.seed) # TODO - seed everything fn + + # Check cache for existing splits + split_key = f"{task}_{species}" if species else task + if split_key in self.dataset_splits: + return self.dataset_splits[split_key] + + dataset = CalmPropertyDataset(task=task, species=species, transform_fn=self.transform_fn) + + indices = np.arange(len(dataset)) + + # If dataset is too large, subsample it first + if len(indices) > self.max_samples: + indices = rng.choice(indices, size=self.max_samples, replace=False) + + # Create train/test split from (possibly subsampled) indices + test_size = int(len(indices) * self.test_size) + train_size = len(indices) - test_size + shuffled_indices = rng.permutation(indices) + train_indices = shuffled_indices[:train_size] + test_indices = shuffled_indices[train_size:] + + train_dataset = Subset(dataset, train_indices) + test_dataset = Subset(dataset, test_indices) + + # Cache the splits + self.dataset_splits[split_key] = (train_dataset, test_dataset) + + return train_dataset, test_dataset + + def _evaluate_task( + self, + task_key: str, + task: str, + train_dataset, + test_dataset, + trainer: L.Trainer, + pl_module: L.LightningModule, + ): + """Evaluate a single task.""" + + task_type, num_classes = CALM_TASKS[task] + + self._set_metrics(task_type, num_classes) + + train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) + test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False) + + try: + train_embeddings, train_targets = self._get_embeddings(pl_module, train_loader) + test_embeddings, test_targets = self._get_embeddings(pl_module, test_loader) + + probe = self._train_probe(train_embeddings, train_targets) + self.probes[task_key] = probe + + metrics = self._evaluate_probe(probe, test_embeddings, test_targets) + + # Log metrics and store for averaging + for metric_name, value in metrics.items(): + metric_key = f"calm_linear_probe/{task_key}/{metric_name}" + trainer.logger.log_metrics({metric_key: value}, step=trainer.current_epoch) + self.aggregate_metrics[metric_name].append(value) + + except Exception as e: + print(f"Error in _evaluate_task for {task_key}: {str(e)}") + raise + + def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None: + if self._skip(trainer): + return + + self.device = pl_module.device + self.aggregate_metrics.clear() # Reset aggregation for new epoch + + for task in tqdm(self.tasks, desc=f"{self.__class__.__name__}"): + # Handle species-specific tasks + if task in ["protein_abundance", "transcript_abundance"]: + if not self.species: + continue + for species in self.species: + task_key = f"{task}_{species}" + try: + train_dataset, test_dataset = self._create_split_datasets(task, species) + self._evaluate_task(task_key, task, train_dataset, test_dataset, trainer, pl_module) + except Exception as e: + print(f"Error processing {task_key}: {str(e)}") + else: + try: + train_dataset, test_dataset = self._create_split_datasets(task) + self._evaluate_task(task, task, train_dataset, test_dataset, trainer, pl_module) + except Exception as e: + print(f"Error processing {task}: {str(e)}") + + # Calculate and log aggregate metrics + for metric_name, values in self.aggregate_metrics.items(): + if values: # Only log if we have values + avg_value = sum(values) / len(values) + trainer.logger.log_metrics( + {f"calm_linear_probe/mean/{metric_name}": avg_value}, + step=trainer.current_epoch + ) \ No newline at end of file diff 
--git a/src/lobster/callbacks/_linear_probe_callback.py b/src/lobster/callbacks/_linear_probe_callback.py index ae72c27..8e547bc 100644 --- a/src/lobster/callbacks/_linear_probe_callback.py +++ b/src/lobster/callbacks/_linear_probe_callback.py @@ -4,12 +4,14 @@ import torch from beignet.transforms import Transform from lightning.pytorch.callbacks import Callback +import numpy as np from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.multioutput import MultiOutputClassifier from torch import Tensor from torch.utils.data import DataLoader from torchmetrics import AUROC, Accuracy, F1Score, MeanSquaredError, R2Score, SpearmanCorrCoef -TaskType = Literal["regression", "binary", "multiclass"] +TaskType = Literal["regression", "binary", "multiclass", "multilabel"] class LinearProbeCallback(Callback): @@ -31,21 +33,37 @@ def __init__( self.run_every_n_epochs = run_every_n_epochs # Initialize metrics based on task type + self._set_metrics(task_type, num_classes) + + # Dictionary to store trained probes + self.probes: Dict[str, LinearRegression | LogisticRegression] = {} + + + def _set_metrics(self, task_type: TaskType, num_classes: Optional[int] = None) -> None: + """Initialize metrics based on task type.""" if task_type == "regression": self.mse = MeanSquaredError() self.r2 = R2Score() self.spearman = SpearmanCorrCoef() - - elif task_type in {"binary", "multiclass"}: - self.accuracy = Accuracy(task=task_type, num_classes=num_classes) - self.f1 = F1Score(task=task_type, num_classes=num_classes) - self.auroc = AUROC(task_type=task_type, num_classes=num_classes) - + self.accuracy = None + self.f1 = None + self.auroc = None + + elif task_type in {"binary", "multiclass", "multilabel"}: + # For multilabel, we use num_classes as num_labels + metric_task = task_type + self.accuracy = Accuracy(task=metric_task, num_labels=num_classes) + self.f1 = F1Score(task=metric_task, num_labels=num_classes) + self.auroc = AUROC(task=metric_task, num_labels=num_classes) + self.mse = None + self.r2 = None + self.spearman = None + else: - raise ValueError("task_type must be: regression, binary, or multiclass") - - # Dictionary to store trained probes - self.probes: Dict[str, LinearRegression | LogisticRegression] = {} + raise ValueError("task_type must be: regression, binary, multiclass, or multilabel") + + self.task_type = task_type + self.num_classes = num_classes def _skip(self, trainer: L.Trainer) -> bool: """Determine if we should skip validation this epoch.""" @@ -64,7 +82,7 @@ def _get_embeddings(self, module: L.LightningModule, dataloader: DataLoader) -> for batch in dataloader: x, y = batch x = {k: v.to(module.device) for k, v in x.items()} - + # Get token-level embeddings batch_embeddings = module.tokens_to_latents(**x) @@ -78,7 +96,7 @@ def _get_embeddings(self, module: L.LightningModule, dataloader: DataLoader) -> embeddings.append(seq_embeddings.cpu()) targets.append(y.cpu()) - + return torch.cat(embeddings), torch.cat(targets) def _train_probe(self, embeddings: Tensor, targets: Tensor): @@ -88,34 +106,46 @@ def _train_probe(self, embeddings: Tensor, targets: Tensor): if self.task_type == "regression": probe = LinearRegression() - else: + probe.fit(embeddings, targets) + + elif self.task_type == "multilabel": + base_classifier = LogisticRegression(random_state=42) + probe = MultiOutputClassifier(base_classifier) + probe.fit(embeddings, targets) + + else: # binary or multiclass probe = LogisticRegression( multi_class="ovr" if self.task_type == "binary" else "multinomial", + 
random_state=42, ) - - probe.fit(embeddings, targets) + probe.fit(embeddings, targets.ravel()) return probe def _evaluate_probe(self, probe, embeddings: Tensor, targets: Tensor) -> Dict[str, float]: """Evaluate a trained probe using task-appropriate metrics.""" + embeddings_np = embeddings.numpy() # Convert to numpy for probe prediction metrics = {} if self.task_type == "regression": - predictions = probe.predict(embeddings.numpy()) - predictions = torch.from_numpy(predictions).float() - + predictions_np = probe.predict(embeddings_np) + predictions = torch.from_numpy(predictions_np).float() + metrics["mse"] = self.mse(predictions, targets).item() metrics["r2"] = self.r2(predictions, targets).item() metrics["spearman"] = self.spearman(predictions.squeeze(), targets.squeeze()).item() - - else: # binary or multiclass - pred_probs = probe.predict_proba(embeddings.numpy()) - predictions = torch.from_numpy(pred_probs).float() - - if self.task_type == "binary": - predictions = predictions[:, 1] - + + else: # binary, multiclass, or multilabel + if self.task_type == "multilabel": + # Get probabilities for each label + predictions_np = np.stack([est.predict_proba(embeddings_np)[:, 1] + for est in probe.estimators_], axis=1) + else: # binary or multiclass + predictions_np = probe.predict_proba(embeddings_np) + if self.task_type == "binary": + predictions_np = predictions_np[:, 1] + + predictions = torch.from_numpy(predictions_np).float() metrics["accuracy"] = self.accuracy(predictions, targets).item() metrics["f1"] = self.f1(predictions, targets).item() metrics["auroc"] = self.auroc(predictions, targets).item() @@ -124,4 +154,4 @@ def _evaluate_probe(self, probe, embeddings: Tensor, targets: Tensor) -> Dict[st def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None: """Train and evaluate linear probes, optionally at specified epochs.""" - raise NotImplementedError("Subclasses must implement on_validation_epoch_end") + raise NotImplementedError("Subclasses must implement on_validation_epoch_end") \ No newline at end of file diff --git a/src/lobster/constants/__init__.py b/src/lobster/constants/__init__.py index a4972ea..ee7a4c2 100644 --- a/src/lobster/constants/__init__.py +++ b/src/lobster/constants/__init__.py @@ -1,3 +1,6 @@ +from ._calm_tasks import CALM_TASKS from ._moleculeace_tasks import MOLECULEACE_TASKS -__all__ = ["MOLECULEACE_TASKS"] +__all__ = ["CALM_TASKS", + "MOLECULEACE_TASKS" + ] diff --git a/src/lobster/constants/_calm_tasks.py b/src/lobster/constants/_calm_tasks.py index 01bd170..062c298 100644 --- a/src/lobster/constants/_calm_tasks.py +++ b/src/lobster/constants/_calm_tasks.py @@ -1,35 +1,19 @@ from enum import Enum -# TODO - add to __init__ +CALM_TASKS = { + "meltome": ("regression", None), # (task_type, num_classes) + "solubility": ("regression", None), + "localization": ("multilabel", 10), # 10 cellular locations + "protein_abundance": ("regression", None), + "transcript_abundance": ("regression", None), + "function_bp": ("multilabel", 4), # 4 GO terms + "function_cc": ("multilabel", 4), + "function_mf": ("multilabel", 4) +} + CALM_DATA_GITHUB_URL = "https://raw.githubusercontent.com/oxpig/CaLM/main/data" FUNCTION_ZENODO_BASE_URL = "https://zenodo.org/records/14890750/files" # Gene Ontology datasets processed & uploaded on Zenodo -FILE_HASHES = { - "meltome.csv": "sha256:699074debc9e3d66e0c084bca594ce81d26b3126b645d43b0597dbe466153ad4", - "solubility.csv": "sha256:94b351d0f36b490423b3e80b2ff0ea5114423165948183157cf63d4f57c08d38", - 
"localization.csv": "sha256:efedb7c394b96f4569c72d03eac54ca5a3b4a24e15c9c24d9429f9b1a4e29320", - - # Function prediction tasks - "calm_GO_bp_middle_normal.parquet": "md5:898265de59ba1ac97270bffc3621f334", - "calm_GO_cc_middle_normal.parquet": "md5:a6af91fe40e523c9adf47e6abd98d9c6", - "calm_GO_mf_middle_normal.parquet": "md5:cafe14db5dda19837bae536399e47e35", - - # Protein abundance - "protein_abundance_athaliana.csv": "sha256:83f8d995ee3a0ff6f1ed4e74a9cb891546e2edb6e86334fef3b6901a0039b118", - "protein_abundance_dmelanogaster.csv": "sha256:6f9541d38217f71a4f9addec4ad567d60eee4e4cebb1d16775909b88e1775da4", - "protein_abundance_ecoli.csv": "sha256:a6a8f91901a4ea4da62931d1e7c91b3a6aa72e4d6c6a83a6612c0988e94421fb", - "protein_abundance_hsapiens.csv": "sha256:94ded0486f2f64575bd2d8f2a3e00611a6e8b28b691d0f367ca9210058771a23", - "protein_abundance_scerevisiae.csv": "sha256:0ce0b6a5b0400c3cc1c568f6c5478a974e80aaecbab93299f81bb94eb2373076", - - # Transcript abundance - "transcript_abundance_athaliana.csv": "sha256:de7a6f57bcfbb60445d17b8461a8a3ea8353942e129f08ac2c6de5874cd6c139", - "transcript_abundance_dmelanogaster.csv": "sha256:0124d101da004e7a66f4303ff097da39d5e4dd474548fa57e2f9fa7231544c32", - "transcript_abundance_ecoli.csv": "sha256:5e480d961c8b9043f6039211241ecf904b188214e9951352a9b2fc3d6a630a59", - "transcript_abundance_hsapiens.csv": "sha256:21b4b3f3f7267d28dbf6434090dfc0c58fde6f15393537d569f0b29e3eeec491", - "transcript_abundance_hvolcanii.csv": "sha256:91782d2839f78b7c3a4c4d2c0f685605fa746e9b3936579fbd91ce875f9860aa", - "transcript_abundance_ppastoris.csv": "sha256:4ebd4783e1b90e76e481c25bce801d4f6984f85d382f5d458d26f554e114798a", - "transcript_abundance_scerevisiae.csv": "sha256:2e0f3b4c0cee77f47ab4009be881f671b709df368495f92dad66f34b2c88ac36" -} class Species(str, Enum): @@ -70,3 +54,33 @@ class Task(Enum): ] } +# Files hashes to check upstream data files haven't been changed. 
Makes data download cleaner +FILE_HASHES = { + "meltome.csv": "sha256:699074debc9e3d66e0c084bca594ce81d26b3126b645d43b0597dbe466153ad4", + "solubility.csv": "sha256:94b351d0f36b490423b3e80b2ff0ea5114423165948183157cf63d4f57c08d38", + "localization.csv": "sha256:efedb7c394b96f4569c72d03eac54ca5a3b4a24e15c9c24d9429f9b1a4e29320", + + # Function prediction tasks + "calm_GO_bp_middle_normal.parquet": "md5:898265de59ba1ac97270bffc3621f334", + "calm_GO_cc_middle_normal.parquet": "md5:a6af91fe40e523c9adf47e6abd98d9c6", + "calm_GO_mf_middle_normal.parquet": "md5:cafe14db5dda19837bae536399e47e35", + + # Protein abundance + "protein_abundance_athaliana.csv": "sha256:83f8d995ee3a0ff6f1ed4e74a9cb891546e2edb6e86334fef3b6901a0039b118", + "protein_abundance_dmelanogaster.csv": "sha256:6f9541d38217f71a4f9addec4ad567d60eee4e4cebb1d16775909b88e1775da4", + "protein_abundance_ecoli.csv": "sha256:a6a8f91901a4ea4da62931d1e7c91b3a6aa72e4d6c6a83a6612c0988e94421fb", + "protein_abundance_hsapiens.csv": "sha256:94ded0486f2f64575bd2d8f2a3e00611a6e8b28b691d0f367ca9210058771a23", + "protein_abundance_scerevisiae.csv": "sha256:0ce0b6a5b0400c3cc1c568f6c5478a974e80aaecbab93299f81bb94eb2373076", + + # Transcript abundance + "transcript_abundance_athaliana.csv": "sha256:de7a6f57bcfbb60445d17b8461a8a3ea8353942e129f08ac2c6de5874cd6c139", + "transcript_abundance_dmelanogaster.csv": "sha256:0124d101da004e7a66f4303ff097da39d5e4dd474548fa57e2f9fa7231544c32", + "transcript_abundance_ecoli.csv": "sha256:5e480d961c8b9043f6039211241ecf904b188214e9951352a9b2fc3d6a630a59", + "transcript_abundance_hsapiens.csv": "sha256:21b4b3f3f7267d28dbf6434090dfc0c58fde6f15393537d569f0b29e3eeec491", + "transcript_abundance_hvolcanii.csv": "sha256:91782d2839f78b7c3a4c4d2c0f685605fa746e9b3936579fbd91ce875f9860aa", + "transcript_abundance_ppastoris.csv": "sha256:4ebd4783e1b90e76e481c25bce801d4f6984f85d382f5d458d26f554e114798a", + "transcript_abundance_scerevisiae.csv": "sha256:2e0f3b4c0cee77f47ab4009be881f671b709df368495f92dad66f34b2c88ac36" +} + + + diff --git a/src/lobster/datasets/_calm_property_dataset.py b/src/lobster/datasets/_calm_property_dataset.py index 0a824e9..4d17319 100644 --- a/src/lobster/datasets/_calm_property_dataset.py +++ b/src/lobster/datasets/_calm_property_dataset.py @@ -2,11 +2,21 @@ from pathlib import Path import pandas as pd import pooch +import torch from torch.utils.data import Dataset +from torch import Tensor from typing import Optional, Literal, Callable, Sequence, Tuple from lobster.transforms import Transform -from lobster.constants._calm_tasks import CALM_DATA_GITHUB_URL, Species, Task, TASK_SPECIES, FUNCTION_ZENODO_BASE_URL, FILE_HASHES +from lobster.constants._calm_tasks import ( + CALM_DATA_GITHUB_URL, + Species, + Task, + TASK_SPECIES, + FUNCTION_ZENODO_BASE_URL, + FILE_HASHES, + CALM_TASKS, +) class CalmPropertyDataset(Dataset): @@ -18,7 +28,8 @@ def __init__( species: Optional[Species | str] = None, split: Optional[Literal["train", "validation", "test"]] = None, download: bool = True, - transform: Optional[Callable | Transform] = None, + transform_fn: Optional[Callable | Transform] = None, + target_transform_fn: Optional[Callable | Transform] = None, columns: Optional[Sequence[str]] = None, known_hash: Optional[str] = None, ): @@ -32,7 +43,10 @@ def __init__( self.task = task self.species = species self.split = split - self.transform = transform + self.transform_fn = transform_fn + self.target_transform_fn = target_transform_fn + + self.task_type, self.num_classes = CALM_TASKS[task.value] if root is None: root = 
pooch.os_cache("lbster") @@ -40,7 +54,7 @@ def __init__( root = Path(root) self.root = root.resolve() - # Determine file name and URL based on task type + # Get file name and URL based on task type if task in [Task.FUNCTION_BP, Task.FUNCTION_CC, Task.FUNCTION_MF]: function_type = task.value.split('_')[1].lower() # Extract bp, cc, or mf fname = f"calm_GO_{function_type}_middle_normal.parquet" @@ -108,18 +122,27 @@ def __init__( columns = list(self.data.columns) self.columns = columns - self._x = self.data[self.columns].apply(tuple, axis=1) - def __getitem__(self, index: int) -> Tuple: - x = self._x[index] + def __getitem__(self, index: int) -> Tuple[str | Tensor, Tensor]: + item = self.data.iloc[index] - if len(x) == 1: - x = x[0] + x = item[self.columns[0]] # First column is always the input sequence/data + + if self.transform_fn is not None: + x = self.transform_fn(x) + + y_cols = self.columns[1:] + y_values = pd.to_numeric(item[y_cols]).values - if self.transform is not None: - x = self.transform(x) + if self.task_type == "regression": + y = torch.tensor(y_values, dtype=torch.float32) + else: # multilabel tasks (localization and function prediction) + y = torch.tensor(y_values, dtype=torch.long) + + if self.target_transform_fn is not None: + y = self.target_transform_fn(y) - return x + return x, y def __len__(self) -> int: - return len(self._x) \ No newline at end of file + return len(self.data) \ No newline at end of file From ab21614d408e8ac86dfc3766455c7cac612c1ce2 Mon Sep 17 00:00:00 2001 From: Taylor Joren Date: Thu, 20 Feb 2025 20:14:19 +0000 Subject: [PATCH 6/7] fix ruff --- .../callbacks/_calm_linear_probe_callback.py | 51 ++++++++--------- .../callbacks/_linear_probe_callback.py | 26 ++++----- src/lobster/constants/_calm_tasks.py | 6 +- .../datasets/_calm_property_dataset.py | 56 +++++++++---------- 4 files changed, 70 insertions(+), 69 deletions(-) diff --git a/src/lobster/callbacks/_calm_linear_probe_callback.py b/src/lobster/callbacks/_calm_linear_probe_callback.py index 44cba58..b61bf89 100644 --- a/src/lobster/callbacks/_calm_linear_probe_callback.py +++ b/src/lobster/callbacks/_calm_linear_probe_callback.py @@ -1,21 +1,22 @@ from collections import defaultdict from typing import Optional, Sequence, Tuple -import torch -from torch.utils.data import DataLoader, random_split, Subset -import numpy as np + import lightning as L +import numpy as np +from torch.utils.data import DataLoader, Subset from tqdm import tqdm from lobster.constants import CALM_TASKS from lobster.datasets import CalmPropertyDataset from lobster.tokenization import NucleotideTokenizerFast from lobster.transforms import TokenizerTransform + from ._linear_probe_callback import LinearProbeCallback class CalmLinearProbeCallback(LinearProbeCallback): """Callback for evaluating embedding models on the CALM dataset collection.""" - + def __init__( self, max_length: int, @@ -43,11 +44,11 @@ def __init__( self.tasks = set(tasks) if tasks else set(CALM_TASKS.keys()) self.species = set(species) if species else None - + self.test_size = test_size self.max_samples = max_samples self.seed = seed - + self.dataset_splits = {} self.aggregate_metrics = defaultdict(list) @@ -57,35 +58,35 @@ def _create_split_datasets( species: Optional[str] = None ) -> Tuple[CalmPropertyDataset, CalmPropertyDataset]: """Create train/test splits for a given task.""" - + rng = np.random.RandomState(self.seed) # TODO - seed everything fn # Check cache for existing splits - split_key = f"{task}_{species}" if species else task + split_key = 
f"{task}_{species}" if species else task if split_key in self.dataset_splits: return self.dataset_splits[split_key] - + dataset = CalmPropertyDataset(task=task, species=species, transform_fn=self.transform_fn) - + indices = np.arange(len(dataset)) - + # If dataset is too large, subsample it first if len(indices) > self.max_samples: indices = rng.choice(indices, size=self.max_samples, replace=False) - + # Create train/test split from (possibly subsampled) indices test_size = int(len(indices) * self.test_size) train_size = len(indices) - test_size shuffled_indices = rng.permutation(indices) train_indices = shuffled_indices[:train_size] test_indices = shuffled_indices[train_size:] - + train_dataset = Subset(dataset, train_indices) test_dataset = Subset(dataset, test_indices) - + # Cache the splits self.dataset_splits[split_key] = (train_dataset, test_dataset) - + return train_dataset, test_dataset def _evaluate_task( @@ -100,27 +101,27 @@ def _evaluate_task( """Evaluate a single task.""" task_type, num_classes = CALM_TASKS[task] - + self._set_metrics(task_type, num_classes) - + train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False) try: train_embeddings, train_targets = self._get_embeddings(pl_module, train_loader) test_embeddings, test_targets = self._get_embeddings(pl_module, test_loader) - + probe = self._train_probe(train_embeddings, train_targets) self.probes[task_key] = probe - + metrics = self._evaluate_probe(probe, test_embeddings, test_targets) - + # Log metrics and store for averaging for metric_name, value in metrics.items(): metric_key = f"calm_linear_probe/{task_key}/{metric_name}" trainer.logger.log_metrics({metric_key: value}, step=trainer.current_epoch) self.aggregate_metrics[metric_name].append(value) - + except Exception as e: print(f"Error in _evaluate_task for {task_key}: {str(e)}") raise @@ -128,10 +129,10 @@ def _evaluate_task( def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None: if self._skip(trainer): return - + self.device = pl_module.device self.aggregate_metrics.clear() # Reset aggregation for new epoch - + for task in tqdm(self.tasks, desc=f"{self.__class__.__name__}"): # Handle species-specific tasks if task in ["protein_abundance", "transcript_abundance"]: @@ -150,7 +151,7 @@ def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModu self._evaluate_task(task, task, train_dataset, test_dataset, trainer, pl_module) except Exception as e: print(f"Error processing {task}: {str(e)}") - + # Calculate and log aggregate metrics for metric_name, values in self.aggregate_metrics.items(): if values: # Only log if we have values @@ -158,4 +159,4 @@ def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModu trainer.logger.log_metrics( {f"calm_linear_probe/mean/{metric_name}": avg_value}, step=trainer.current_epoch - ) \ No newline at end of file + ) diff --git a/src/lobster/callbacks/_linear_probe_callback.py b/src/lobster/callbacks/_linear_probe_callback.py index 8e547bc..072bc35 100644 --- a/src/lobster/callbacks/_linear_probe_callback.py +++ b/src/lobster/callbacks/_linear_probe_callback.py @@ -1,10 +1,10 @@ from typing import Callable, Dict, Literal, Optional, Tuple import lightning as L +import numpy as np import torch from beignet.transforms import Transform from lightning.pytorch.callbacks import Callback -import numpy as np from sklearn.linear_model import LinearRegression, 
LogisticRegression from sklearn.multioutput import MultiOutputClassifier from torch import Tensor @@ -48,7 +48,7 @@ def _set_metrics(self, task_type: TaskType, num_classes: Optional[int] = None) - self.accuracy = None self.f1 = None self.auroc = None - + elif task_type in {"binary", "multiclass", "multilabel"}: # For multilabel, we use num_classes as num_labels metric_task = task_type @@ -58,10 +58,10 @@ def _set_metrics(self, task_type: TaskType, num_classes: Optional[int] = None) - self.mse = None self.r2 = None self.spearman = None - + else: raise ValueError("task_type must be: regression, binary, multiclass, or multilabel") - + self.task_type = task_type self.num_classes = num_classes @@ -82,7 +82,7 @@ def _get_embeddings(self, module: L.LightningModule, dataloader: DataLoader) -> for batch in dataloader: x, y = batch x = {k: v.to(module.device) for k, v in x.items()} - + # Get token-level embeddings batch_embeddings = module.tokens_to_latents(**x) @@ -96,7 +96,7 @@ def _get_embeddings(self, module: L.LightningModule, dataloader: DataLoader) -> embeddings.append(seq_embeddings.cpu()) targets.append(y.cpu()) - + return torch.cat(embeddings), torch.cat(targets) def _train_probe(self, embeddings: Tensor, targets: Tensor): @@ -107,12 +107,12 @@ def _train_probe(self, embeddings: Tensor, targets: Tensor): if self.task_type == "regression": probe = LinearRegression() probe.fit(embeddings, targets) - + elif self.task_type == "multilabel": base_classifier = LogisticRegression(random_state=42) probe = MultiOutputClassifier(base_classifier) probe.fit(embeddings, targets) - + else: # binary or multiclass probe = LogisticRegression( multi_class="ovr" if self.task_type == "binary" else "multinomial", @@ -130,21 +130,21 @@ def _evaluate_probe(self, probe, embeddings: Tensor, targets: Tensor) -> Dict[st if self.task_type == "regression": predictions_np = probe.predict(embeddings_np) predictions = torch.from_numpy(predictions_np).float() - + metrics["mse"] = self.mse(predictions, targets).item() metrics["r2"] = self.r2(predictions, targets).item() metrics["spearman"] = self.spearman(predictions.squeeze(), targets.squeeze()).item() - + else: # binary, multiclass, or multilabel if self.task_type == "multilabel": # Get probabilities for each label - predictions_np = np.stack([est.predict_proba(embeddings_np)[:, 1] + predictions_np = np.stack([est.predict_proba(embeddings_np)[:, 1] for est in probe.estimators_], axis=1) else: # binary or multiclass predictions_np = probe.predict_proba(embeddings_np) if self.task_type == "binary": predictions_np = predictions_np[:, 1] - + predictions = torch.from_numpy(predictions_np).float() metrics["accuracy"] = self.accuracy(predictions, targets).item() metrics["f1"] = self.f1(predictions, targets).item() @@ -154,4 +154,4 @@ def _evaluate_probe(self, probe, embeddings: Tensor, targets: Tensor) -> Dict[st def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None: """Train and evaluate linear probes, optionally at specified epochs.""" - raise NotImplementedError("Subclasses must implement on_validation_epoch_end") \ No newline at end of file + raise NotImplementedError("Subclasses must implement on_validation_epoch_end") diff --git a/src/lobster/constants/_calm_tasks.py b/src/lobster/constants/_calm_tasks.py index 062c298..85a07ab 100644 --- a/src/lobster/constants/_calm_tasks.py +++ b/src/lobster/constants/_calm_tasks.py @@ -55,11 +55,11 @@ class Task(Enum): } # Files hashes to check upstream data files haven't been changed. 
Makes data download cleaner -FILE_HASHES = { +FILE_HASHES = { "meltome.csv": "sha256:699074debc9e3d66e0c084bca594ce81d26b3126b645d43b0597dbe466153ad4", "solubility.csv": "sha256:94b351d0f36b490423b3e80b2ff0ea5114423165948183157cf63d4f57c08d38", "localization.csv": "sha256:efedb7c394b96f4569c72d03eac54ca5a3b4a24e15c9c24d9429f9b1a4e29320", - + # Function prediction tasks "calm_GO_bp_middle_normal.parquet": "md5:898265de59ba1ac97270bffc3621f334", "calm_GO_cc_middle_normal.parquet": "md5:a6af91fe40e523c9adf47e6abd98d9c6", @@ -71,7 +71,7 @@ class Task(Enum): "protein_abundance_ecoli.csv": "sha256:a6a8f91901a4ea4da62931d1e7c91b3a6aa72e4d6c6a83a6612c0988e94421fb", "protein_abundance_hsapiens.csv": "sha256:94ded0486f2f64575bd2d8f2a3e00611a6e8b28b691d0f367ca9210058771a23", "protein_abundance_scerevisiae.csv": "sha256:0ce0b6a5b0400c3cc1c568f6c5478a974e80aaecbab93299f81bb94eb2373076", - + # Transcript abundance "transcript_abundance_athaliana.csv": "sha256:de7a6f57bcfbb60445d17b8461a8a3ea8353942e129f08ac2c6de5874cd6c139", "transcript_abundance_dmelanogaster.csv": "sha256:0124d101da004e7a66f4303ff097da39d5e4dd474548fa57e2f9fa7231544c32", diff --git a/src/lobster/datasets/_calm_property_dataset.py b/src/lobster/datasets/_calm_property_dataset.py index 4d17319..f16a545 100644 --- a/src/lobster/datasets/_calm_property_dataset.py +++ b/src/lobster/datasets/_calm_property_dataset.py @@ -1,22 +1,22 @@ -from enum import Enum from pathlib import Path +from typing import Callable, Literal, Optional, Sequence, Tuple + import pandas as pd import pooch import torch -from torch.utils.data import Dataset from torch import Tensor -from typing import Optional, Literal, Callable, Sequence, Tuple +from torch.utils.data import Dataset -from lobster.transforms import Transform from lobster.constants._calm_tasks import ( - CALM_DATA_GITHUB_URL, - Species, - Task, - TASK_SPECIES, - FUNCTION_ZENODO_BASE_URL, - FILE_HASHES, + CALM_DATA_GITHUB_URL, CALM_TASKS, + FILE_HASHES, + FUNCTION_ZENODO_BASE_URL, + TASK_SPECIES, + Species, + Task, ) +from lobster.transforms import Transform class CalmPropertyDataset(Dataset): @@ -34,26 +34,26 @@ def __init__( known_hash: Optional[str] = None, ): super().__init__() - + if isinstance(task, str): task = Task(task) if isinstance(species, str): species = Species(species) - + self.task = task self.species = species self.split = split self.transform_fn = transform_fn self.target_transform_fn = target_transform_fn - + self.task_type, self.num_classes = CALM_TASKS[task.value] - + if root is None: root = pooch.os_cache("lbster") if isinstance(root, str): root = Path(root) self.root = root.resolve() - + # Get file name and URL based on task type if task in [Task.FUNCTION_BP, Task.FUNCTION_CC, Task.FUNCTION_MF]: function_type = task.value.split('_')[1].lower() # Extract bp, cc, or mf @@ -77,7 +77,7 @@ def __init__( # Get hash for the storage filename --> to ensure file has not changed if known_hash is None and storage_fname in FILE_HASHES: known_hash = FILE_HASHES[storage_fname] - + file_path = Path(self.root / self.__class__.__name__ / storage_fname) if download: file_path = pooch.retrieve( @@ -91,14 +91,14 @@ def __init__( raise FileNotFoundError( f"Data file {file_path} not found and download=False" ) - + if str(file_path).endswith('.parquet'): self.data = pd.read_parquet(file_path) elif str(file_path).endswith('.tsv'): self.data = pd.read_csv(file_path, sep='\t') else: self.data = pd.read_csv(file_path) - + if columns is None: if task == Task.FUNCTION_BP: columns = ['sequence', "GO:0051092", 
"GO:0016573", "GO:0031146", "GO:0071427", "GO:0006613"] @@ -120,29 +120,29 @@ def __init__( columns = ['cds', 'logtpm'] else: columns = list(self.data.columns) - + self.columns = columns - + def __getitem__(self, index: int) -> Tuple[str | Tensor, Tensor]: item = self.data.iloc[index] - + x = item[self.columns[0]] # First column is always the input sequence/data if self.transform_fn is not None: x = self.transform_fn(x) - + y_cols = self.columns[1:] - y_values = pd.to_numeric(item[y_cols]).values - + y_values = pd.to_numeric(item[y_cols]).values + if self.task_type == "regression": y = torch.tensor(y_values, dtype=torch.float32) else: # multilabel tasks (localization and function prediction) y = torch.tensor(y_values, dtype=torch.long) - + if self.target_transform_fn is not None: y = self.target_transform_fn(y) - + return x, y - + def __len__(self) -> int: - return len(self.data) \ No newline at end of file + return len(self.data) From be039a0c0237b9e5566c5eb55b3be71a4f60199d Mon Sep 17 00:00:00 2001 From: Taylor Joren Date: Thu, 20 Feb 2025 20:23:08 +0000 Subject: [PATCH 7/7] re-ruff --- .../callbacks/_calm_linear_probe_callback.py | 11 +- .../callbacks/_linear_probe_callback.py | 4 +- src/lobster/constants/__init__.py | 4 +- src/lobster/constants/_calm_tasks.py | 18 +- .../datasets/_calm_property_dataset.py | 44 ++-- .../test__calm_linear_probe_callback.py | 193 ++++++++++++++++++ 6 files changed, 233 insertions(+), 41 deletions(-) create mode 100644 tests/lobster/callbacks/test__calm_linear_probe_callback.py diff --git a/src/lobster/callbacks/_calm_linear_probe_callback.py b/src/lobster/callbacks/_calm_linear_probe_callback.py index b61bf89..ffce370 100644 --- a/src/lobster/callbacks/_calm_linear_probe_callback.py +++ b/src/lobster/callbacks/_calm_linear_probe_callback.py @@ -53,13 +53,11 @@ def __init__( self.aggregate_metrics = defaultdict(list) def _create_split_datasets( - self, - task: str, - species: Optional[str] = None + self, task: str, species: Optional[str] = None ) -> Tuple[CalmPropertyDataset, CalmPropertyDataset]: """Create train/test splits for a given task.""" - rng = np.random.RandomState(self.seed) # TODO - seed everything fn + rng = np.random.RandomState(self.seed) # TODO - seed everything fn # Check cache for existing splits split_key = f"{task}_{species}" if species else task @@ -97,7 +95,7 @@ def _evaluate_task( test_dataset, trainer: L.Trainer, pl_module: L.LightningModule, - ): + ): """Evaluate a single task.""" task_type, num_classes = CALM_TASKS[task] @@ -157,6 +155,5 @@ def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModu if values: # Only log if we have values avg_value = sum(values) / len(values) trainer.logger.log_metrics( - {f"calm_linear_probe/mean/{metric_name}": avg_value}, - step=trainer.current_epoch + {f"calm_linear_probe/mean/{metric_name}": avg_value}, step=trainer.current_epoch ) diff --git a/src/lobster/callbacks/_linear_probe_callback.py b/src/lobster/callbacks/_linear_probe_callback.py index 072bc35..e5a5c65 100644 --- a/src/lobster/callbacks/_linear_probe_callback.py +++ b/src/lobster/callbacks/_linear_probe_callback.py @@ -38,7 +38,6 @@ def __init__( # Dictionary to store trained probes self.probes: Dict[str, LinearRegression | LogisticRegression] = {} - def _set_metrics(self, task_type: TaskType, num_classes: Optional[int] = None) -> None: """Initialize metrics based on task type.""" if task_type == "regression": @@ -138,8 +137,7 @@ def _evaluate_probe(self, probe, embeddings: Tensor, targets: Tensor) -> Dict[st 
else: # binary, multiclass, or multilabel if self.task_type == "multilabel": # Get probabilities for each label - predictions_np = np.stack([est.predict_proba(embeddings_np)[:, 1] - for est in probe.estimators_], axis=1) + predictions_np = np.stack([est.predict_proba(embeddings_np)[:, 1] for est in probe.estimators_], axis=1) else: # binary or multiclass predictions_np = probe.predict_proba(embeddings_np) if self.task_type == "binary": diff --git a/src/lobster/constants/__init__.py b/src/lobster/constants/__init__.py index ee7a4c2..b5540a1 100644 --- a/src/lobster/constants/__init__.py +++ b/src/lobster/constants/__init__.py @@ -1,6 +1,4 @@ from ._calm_tasks import CALM_TASKS from ._moleculeace_tasks import MOLECULEACE_TASKS -__all__ = ["CALM_TASKS", - "MOLECULEACE_TASKS" - ] +__all__ = ["CALM_TASKS", "MOLECULEACE_TASKS"] diff --git a/src/lobster/constants/_calm_tasks.py b/src/lobster/constants/_calm_tasks.py index 85a07ab..1b087e2 100644 --- a/src/lobster/constants/_calm_tasks.py +++ b/src/lobster/constants/_calm_tasks.py @@ -8,12 +8,14 @@ "transcript_abundance": ("regression", None), "function_bp": ("multilabel", 4), # 4 GO terms "function_cc": ("multilabel", 4), - "function_mf": ("multilabel", 4) + "function_mf": ("multilabel", 4), } CALM_DATA_GITHUB_URL = "https://raw.githubusercontent.com/oxpig/CaLM/main/data" -FUNCTION_ZENODO_BASE_URL = "https://zenodo.org/records/14890750/files" # Gene Ontology datasets processed & uploaded on Zenodo +FUNCTION_ZENODO_BASE_URL = ( + "https://zenodo.org/records/14890750/files" # Gene Ontology datasets processed & uploaded on Zenodo +) class Species(str, Enum): @@ -25,6 +27,7 @@ class Species(str, Enum): PPASTORIS = "ppastoris" SCEREVISIAE = "scerevisiae" + class Task(Enum): MELTOME = "meltome" SOLUBILITY = "solubility" @@ -35,6 +38,7 @@ class Task(Enum): FUNCTION_CC = "function_cc" FUNCTION_MF = "function_mf" + TASK_SPECIES = { Task.PROTEIN_ABUNDANCE: [ Species.ATHALIANA, @@ -51,7 +55,7 @@ class Task(Enum): Species.HVOLCANII, Species.PPASTORIS, Species.SCEREVISIAE, - ] + ], } # Files hashes to check upstream data files haven't been changed. 
Makes data download cleaner @@ -59,19 +63,16 @@ class Task(Enum): "meltome.csv": "sha256:699074debc9e3d66e0c084bca594ce81d26b3126b645d43b0597dbe466153ad4", "solubility.csv": "sha256:94b351d0f36b490423b3e80b2ff0ea5114423165948183157cf63d4f57c08d38", "localization.csv": "sha256:efedb7c394b96f4569c72d03eac54ca5a3b4a24e15c9c24d9429f9b1a4e29320", - # Function prediction tasks "calm_GO_bp_middle_normal.parquet": "md5:898265de59ba1ac97270bffc3621f334", "calm_GO_cc_middle_normal.parquet": "md5:a6af91fe40e523c9adf47e6abd98d9c6", "calm_GO_mf_middle_normal.parquet": "md5:cafe14db5dda19837bae536399e47e35", - # Protein abundance "protein_abundance_athaliana.csv": "sha256:83f8d995ee3a0ff6f1ed4e74a9cb891546e2edb6e86334fef3b6901a0039b118", "protein_abundance_dmelanogaster.csv": "sha256:6f9541d38217f71a4f9addec4ad567d60eee4e4cebb1d16775909b88e1775da4", "protein_abundance_ecoli.csv": "sha256:a6a8f91901a4ea4da62931d1e7c91b3a6aa72e4d6c6a83a6612c0988e94421fb", "protein_abundance_hsapiens.csv": "sha256:94ded0486f2f64575bd2d8f2a3e00611a6e8b28b691d0f367ca9210058771a23", "protein_abundance_scerevisiae.csv": "sha256:0ce0b6a5b0400c3cc1c568f6c5478a974e80aaecbab93299f81bb94eb2373076", - # Transcript abundance "transcript_abundance_athaliana.csv": "sha256:de7a6f57bcfbb60445d17b8461a8a3ea8353942e129f08ac2c6de5874cd6c139", "transcript_abundance_dmelanogaster.csv": "sha256:0124d101da004e7a66f4303ff097da39d5e4dd474548fa57e2f9fa7231544c32", @@ -79,8 +80,5 @@ class Task(Enum): "transcript_abundance_hsapiens.csv": "sha256:21b4b3f3f7267d28dbf6434090dfc0c58fde6f15393537d569f0b29e3eeec491", "transcript_abundance_hvolcanii.csv": "sha256:91782d2839f78b7c3a4c4d2c0f685605fa746e9b3936579fbd91ce875f9860aa", "transcript_abundance_ppastoris.csv": "sha256:4ebd4783e1b90e76e481c25bce801d4f6984f85d382f5d458d26f554e114798a", - "transcript_abundance_scerevisiae.csv": "sha256:2e0f3b4c0cee77f47ab4009be881f671b709df368495f92dad66f34b2c88ac36" + "transcript_abundance_scerevisiae.csv": "sha256:2e0f3b4c0cee77f47ab4009be881f671b709df368495f92dad66f34b2c88ac36", } - - - diff --git a/src/lobster/datasets/_calm_property_dataset.py b/src/lobster/datasets/_calm_property_dataset.py index f16a545..3116fbf 100644 --- a/src/lobster/datasets/_calm_property_dataset.py +++ b/src/lobster/datasets/_calm_property_dataset.py @@ -56,7 +56,7 @@ def __init__( # Get file name and URL based on task type if task in [Task.FUNCTION_BP, Task.FUNCTION_CC, Task.FUNCTION_MF]: - function_type = task.value.split('_')[1].lower() # Extract bp, cc, or mf + function_type = task.value.split("_")[1].lower() # Extract bp, cc, or mf fname = f"calm_GO_{function_type}_middle_normal.parquet" url = f"{FUNCTION_ZENODO_BASE_URL}/{fname}" storage_fname = fname @@ -88,36 +88,44 @@ def __init__( progressbar=True, ) elif not file_path.exists(): - raise FileNotFoundError( - f"Data file {file_path} not found and download=False" - ) + raise FileNotFoundError(f"Data file {file_path} not found and download=False") - if str(file_path).endswith('.parquet'): + if str(file_path).endswith(".parquet"): self.data = pd.read_parquet(file_path) - elif str(file_path).endswith('.tsv'): - self.data = pd.read_csv(file_path, sep='\t') + elif str(file_path).endswith(".tsv"): + self.data = pd.read_csv(file_path, sep="\t") else: self.data = pd.read_csv(file_path) if columns is None: if task == Task.FUNCTION_BP: - columns = ['sequence', "GO:0051092", "GO:0016573", "GO:0031146", "GO:0071427", "GO:0006613"] + columns = ["sequence", "GO:0051092", "GO:0016573", "GO:0031146", "GO:0071427", "GO:0006613"] elif task == 
Task.FUNCTION_CC:
-                columns = ['sequence', "GO:0022627", "GO:0000502", "GO:0034705", "GO:0030665", "GO:0005925"]
+                columns = ["sequence", "GO:0022627", "GO:0000502", "GO:0034705", "GO:0030665", "GO:0005925"]
             elif task == Task.FUNCTION_MF:
-                columns = ['sequence', "GO:0004843", "GO:0004714", "GO:0003774", "GO:0008227", "GO:0004866"]
+                columns = ["sequence", "GO:0004843", "GO:0004714", "GO:0003774", "GO:0008227", "GO:0004866"]
             elif task == Task.LOCALIZATION:
-                columns = ['Sequence', 'Cell membrane', 'Cytoplasm', 'Endoplasmic reticulum',
-                          'Extracellular', 'Golgi apparatus', 'Lysosome/Vacuole',
-                          'Mitochondrion', 'Nucleus', 'Peroxisome', 'Plastid']
+                columns = [
+                    "Sequence",
+                    "Cell membrane",
+                    "Cytoplasm",
+                    "Endoplasmic reticulum",
+                    "Extracellular",
+                    "Golgi apparatus",
+                    "Lysosome/Vacuole",
+                    "Mitochondrion",
+                    "Nucleus",
+                    "Peroxisome",
+                    "Plastid",
+                ]
             elif task == Task.MELTOME:
-                columns = ['sequence', 'melting_temperature']
+                columns = ["sequence", "melting_temperature"]
             elif task == Task.SOLUBILITY:
-                columns = ['cds', 'solubility']
+                columns = ["cds", "solubility"]
             elif task == Task.PROTEIN_ABUNDANCE:
-                columns = ['cds', 'abundance']
+                columns = ["cds", "abundance"]
             elif task == Task.TRANSCRIPT_ABUNDANCE:
-                columns = ['cds', 'logtpm']
+                columns = ["cds", "logtpm"]
             else:
                 columns = list(self.data.columns)

@@ -126,7 +134,7 @@ def __init__(
     def __getitem__(self, index: int) -> Tuple[str | Tensor, Tensor]:
         item = self.data.iloc[index]

-        x = item[self.columns[0]] # First column is always the input sequence/data
+        x = item[self.columns[0]]  # First column is always the input sequence/data
         if self.transform_fn is not None:
             x = self.transform_fn(x)

diff --git a/tests/lobster/callbacks/test__calm_linear_probe_callback.py b/tests/lobster/callbacks/test__calm_linear_probe_callback.py
new file mode 100644
index 0000000..c46bd90
--- /dev/null
+++ b/tests/lobster/callbacks/test__calm_linear_probe_callback.py
@@ -0,0 +1,193 @@
+from unittest.mock import Mock, patch
+
+import numpy as np
+import pytest
+import torch
+from torch.utils.data import Dataset
+
+from lobster.callbacks import CalmLinearProbeCallback
+from lobster.constants._calm_tasks import CALM_TASKS
+
+# Maximum sequence length shared by the mock model and the callbacks in these tests
+MAX_LENGTH = 1024
+
+
+class MockDataset(Dataset):
+    def __init__(self, size=10, task_type="regression"):
+        self.size = size
+        self.task_type = task_type
+
+        # Generate random data
+        if task_type == "regression":
+            self.targets = torch.randn(size)
+        elif task_type == "binary":
+            self.targets = torch.randint(0, 2, (size,), dtype=torch.long)
+        elif task_type == "multiclass":
+            self.num_classes = 10
+            self.targets = torch.randint(0, self.num_classes, (size,), dtype=torch.long)
+        elif task_type == "multilabel":
+            self.num_classes = 10  # Match localization task
+            self.targets = torch.randint(0, 2, (size, self.num_classes), dtype=torch.long)
+        else:
+            raise ValueError(f"Invalid task_type: {task_type}")
+        # Generate short random nucleotide sequences
+        self.sequences = []
+        for _ in range(size):
+            sequence = "".join(np.random.choice(["A", "C", "G", "T"]) for _ in range(10))
+            self.sequences.append({"text": sequence})
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        return self.sequences[idx], self.targets[idx]
+
+
+class MockLightningModule(torch.nn.Module):
+    def __init__(self, hidden_size=32):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.device = "cpu"
+
+    def tokens_to_latents(self, input_ids, **kwargs):
+        batch_size = input_ids.size(0)
+        return 
torch.randn(batch_size, MAX_LENGTH, self.hidden_size)
+
+
+@pytest.fixture
+def mock_trainer():
+    trainer = Mock()
+    trainer.current_epoch = 0
+    trainer.logger = Mock()
+    return trainer
+
+
+@pytest.fixture
+def mock_pl_module():
+    return MockLightningModule()
+
+
+def test_callback_initialization():
+    """Test basic initialization of the callback."""
+    callback = CalmLinearProbeCallback(
+        max_length=MAX_LENGTH, tasks=["meltome", "solubility"], species=["hsapiens", "ecoli"], batch_size=16
+    )
+
+    assert callback.tasks == {"meltome", "solubility"}
+    assert callback.species == {"hsapiens", "ecoli"}
+    assert callback.batch_size == 16
+    assert callback.test_size == 0.2
+    assert callback.max_samples == 3000
+
+
+@pytest.mark.parametrize(
+    "task",
+    [
+        "meltome",  # regression task
+        "localization",  # multilabel task
+    ],
+)
+def test_single_task_evaluation(task, mock_trainer, mock_pl_module):
+    """Test evaluation of individual tasks."""
+    task_type, num_classes = CALM_TASKS[task]
+
+    callback = CalmLinearProbeCallback(max_length=MAX_LENGTH, tasks=[task])
+
+    mock_dataset = MockDataset(size=10, task_type=task_type)
+
+    with patch("lobster.datasets._calm_property_dataset.CalmPropertyDataset", return_value=mock_dataset):
+        callback.task_type = task_type  # Need to explicitly set the task type before evaluation
+
+        if num_classes:
+            callback.num_classes = num_classes
+
+        mock_pl_module.device = "cpu"
+        callback.on_validation_epoch_end(mock_trainer, mock_pl_module)
+
+        # Verify metrics were logged
+        mock_trainer.logger.log_metrics.assert_called()
+
+
+@pytest.mark.parametrize(
+    "task,species",
+    [
+        ("protein_abundance", "hsapiens"),
+        ("transcript_abundance", "ecoli"),
+    ],
+)
+def test_species_specific_task_evaluation(task, species, mock_trainer, mock_pl_module):
+    """Test evaluation of species-specific tasks."""
+    task_type, num_classes = CALM_TASKS[task]
+
+    callback = CalmLinearProbeCallback(max_length=MAX_LENGTH, tasks=[task], species=[species])
+
+    mock_dataset = MockDataset(size=50, task_type=task_type)
+
+    with patch("lobster.datasets._calm_property_dataset.CalmPropertyDataset", return_value=mock_dataset):
+        callback.task_type = task_type
+        if num_classes:
+            callback.num_classes = num_classes
+
+        mock_pl_module.device = "cpu"
+        callback.on_validation_epoch_end(mock_trainer, mock_pl_module)
+
+        # Verify metrics were logged for the specific species
+        calls = mock_trainer.logger.log_metrics.call_args_list
+        logged_keys = set()
+        for call in calls:
+            args, _ = call
+            logged_keys.update(args[0].keys())
+
+        expected_task_key = f"{task}_{species}"
+        assert any(expected_task_key in key for key in logged_keys)
+
+
+def test_dataset_caching():
+    """Test that dataset splits are properly cached."""
+    callback = CalmLinearProbeCallback(max_length=MAX_LENGTH, tasks=["meltome"])
+
+    mock_dataset = MockDataset(size=50)
+
+    with patch("lobster.datasets._calm_property_dataset.CalmPropertyDataset", return_value=mock_dataset):
+        # First call should create and cache the split
+        train1, test1 = callback._create_split_datasets("meltome")
+        # Second call should return cached split
+        train2, test2 = callback._create_split_datasets("meltome")
+
+        # Verify same splits are returned
+        assert id(train1) == id(train2)
+        assert id(test1) == id(test2)
+
+
+def test_max_samples_limit():
+    """Test that datasets are properly subsampled when exceeding max_samples."""
+    max_samples = 100
+    callback = CalmLinearProbeCallback(max_length=MAX_LENGTH, tasks=["meltome"], max_samples=max_samples)
+
+    mock_dataset = 
MockDataset(size=max_samples * 2) + + with patch("lobster.datasets._calm_property_dataset.CalmPropertyDataset", return_value=mock_dataset): + train, test = callback._create_split_datasets("meltome") + + assert len(train) + len(test) <= max_samples + + +def test_aggregate_metrics_reset(mock_trainer, mock_pl_module): + """Test that aggregate metrics are properly reset between epochs.""" + callback = CalmLinearProbeCallback(max_length=MAX_LENGTH, tasks=["meltome"]) + callback.task_type = "regression" # Set task type explicitly + + mock_dataset = MockDataset(size=10, task_type="regression") + + with patch("lobster.datasets._calm_property_dataset.CalmPropertyDataset", return_value=mock_dataset): + # First epoch + callback.on_validation_epoch_end(mock_trainer, mock_pl_module) + first_metrics = callback.aggregate_metrics.copy() + + # Second epoch + callback.on_validation_epoch_end(mock_trainer, mock_pl_module) + second_metrics = callback.aggregate_metrics + + # Verify metrics were reset + assert first_metrics.keys() == second_metrics.keys() + assert all(len(first_metrics[k]) == len(second_metrics[k]) for k in first_metrics)
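
For reference, a minimal usage sketch of the new callback, assuming a LightningModule that implements `tokens_to_latents` (as MockLightningModule above does); the hyperparameter values are illustrative, not taken from this patch:

    import lightning as L

    from lobster.callbacks import CalmLinearProbeCallback

    # Linear probes are fit on pooled embeddings and scored at each
    # validation epoch end; per-task and mean metrics go to the logger.
    probe_callback = CalmLinearProbeCallback(
        max_length=512,
        tasks=["meltome", "solubility"],
        species=["hsapiens", "ecoli"],
        batch_size=16,
    )

    trainer = L.Trainer(max_epochs=10, callbacks=[probe_callback])
    # trainer.fit(model, datamodule=dm)  # `model` and `dm` are placeholders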