V1.0.1 release (#69)
* fix: Resolve "AttributeError: 'SpectrumDataFrame' object has no attribute 'df'"

* feat: update notebooks to v1.0.0

* feat: Automatic model download and improve residues
Co-Authored-By: Kevin Eloff <[email protected]>

* feat: update tests for v1.0.1 release
Co-Authored-By: Rachel Catzel <[email protected]>

* feat: update packages
BioGeek authored Jan 21, 2025
1 parent dca4423 commit 9b3ef20
Showing 29 changed files with 1,070 additions and 448 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-publish.yml
@@ -34,12 +34,12 @@ jobs:
- name: Build package
run: python -m build
- name: Publish package
-      uses: pypa/gh-action-pypi-publish@fb13cb306901256ace3dab689990e13a5550ffaa
+      uses: pypa/gh-action-pypi-publish@v1.12
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
- name: refresh PyPI badge
-      uses: fjogeleit/http-request-action@v1
+      uses: fjogeleit/http-request-action@v1.16
with:
url: https://camo.githubusercontent.com/a22fbcbadf81751212d5367cce341631bc28d7749b9cd5c317fbf0706a30c9ae/68747470733a2f2f62616467652e667572792e696f2f70792f696e7374616e6f766f2e737667
method: PURGE
2 changes: 1 addition & 1 deletion instanovo/__init__.py
@@ -1,3 +1,3 @@
from __future__ import annotations

__version__ = "1.0.0"
__version__ = "1.0.1"
10 changes: 6 additions & 4 deletions instanovo/configs/inference/default.yaml
@@ -1,6 +1,6 @@
# Data paths and output location
-data_path: # type: .mgf, .mzml or any other filetype supported by SpectruMataFrame
-model_path: # type: .ckpt
+data_path: # type: .mgf, .mzml or any other filetype supported by SpectrumDataFrame
+model_path: instanovo-extended # type: .ckpt or model id
output_path: # type: .csv
knapsack_path: # type: directory

@@ -17,9 +17,11 @@ use_knapsack: False
save_beams: False
subset: 1.0 # Subset of dataset to perform inference on, useful for debugging

+# These two only work in greedy search
 # Residues whose log probability will be set to -inf
-# Only works in greedy search
-# suppressed_residues: TODO
+suppressed_residues:
+# Stop model from predicting n-terminal modifications anywhere along the sequence
+disable_terminal_residues_anywhere: True

# Run config
num_workers: 16
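The two options above are new in this release. A minimal sketch of overriding them programmatically, assuming you run from the repo root with omegaconf installed (the residue token below is hypothetical; use names from your model's vocabulary):

from omegaconf import OmegaConf

# Load the shipped inference defaults.
config = OmegaConf.load("instanovo/configs/inference/default.yaml")
# "M[UNIMOD:35]" is a made-up token for illustration only.
config.suppressed_residues = ["M[UNIMOD:35]"]
config.disable_terminal_residues_anywhere = True
print(OmegaConf.to_yaml(config))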
5 changes: 3 additions & 2 deletions instanovo/configs/inference/unit_test.yaml
@@ -4,13 +4,14 @@ defaults:

# Data paths and output location
data_path: ./tests/instanovo_test_resources/example_data/test_sample.mgf # type: .ipc
-model_path: ./tests/instanovo_test_resources/train_test/epoch=4-step=2420.ckpt # type: .ckpt
-output_path: ./tests/instanovo_test_resources/train_test/test_sample_preds.csv # type: .csv
+model_path: ./tests/instanovo_test_resources/model.ckpt # type: .ckpt
+output_path: ./tests/instanovo_test_resources/test_sample_preds.csv # type: .csv
knapsack_path: ./tests/instanovo_test_resources/example_knapsack # type: directory
use_knapsack: False

num_beams: 5
max_length: 30
+max_charge: 3

subset: 1

5 changes: 4 additions & 1 deletion instanovo/configs/instanovo.yaml
@@ -36,8 +36,11 @@ train_subset: 1.0
valid_subset: 0.01
val_check_interval: 1.0 # 1.0 This doesn't work
lazy_loading: True # Use lazy loading mode
-max_shard_size: 100_000 # Max data shard size for lazy loading, may influence shuffling mechanics
+max_shard_size: 1_000_000 # Max data shard size for lazy loading, may influence shuffling mechanics
 preshuffle_shards: True # Perform a preshuffle across shards to ensure shards are homogeneous in lazy mode
+perform_data_checks: True # Check residues, check precursor masses, etc.
+validate_precursor_mass: False # Slow for large datasets
+verbose_loading: True # Verbose SDF logs when loading the dataset

# Checkpointing parameters
save_model: True
6 changes: 3 additions & 3 deletions instanovo/configs/instanovo_unit_test.yaml
@@ -9,15 +9,15 @@ defaults:
tb_summarywriter: "./logs/instanovo/instanovo-unit-test"

# Training parameters
-warmup_iters: 1000
+warmup_iters: 480
max_iters: 3_000_000
learning_rate: 1e-3
train_batch_size: 32
grad_accumulation: 1

# Logging parameters
logger:
-epochs: 5
+epochs: 1
num_sanity_val_steps: 10
console_logging_steps: 50
tensorboard_logging_steps: 500
@@ -29,4 +29,4 @@ valid_subset: 1.0

# Checkpointing parameters
model_save_folder_path: ./tests/instanovo_test_resources/train_test
-ckpt_interval: 2420
+ckpt_interval: 480
80 changes: 77 additions & 3 deletions instanovo/inference/greedy_search.py
@@ -21,9 +21,18 @@ class GreedyDecoder(Decoder):
models that conform to the `Decodable` interface.
"""

-    def __init__(self, model: Decodable, mass_scale: int = MASS_SCALE):
+    def __init__(
+        self,
+        model: Decodable,
+        suppressed_residues: list[str] | None = None,
+        mass_scale: int = MASS_SCALE,
+        disable_terminal_residues_anywhere: bool = True,
+    ):
super().__init__(model=model)
self.mass_scale = mass_scale
self.disable_terminal_residues_anywhere = disable_terminal_residues_anywhere

suppressed_residues = suppressed_residues or []

# NOTE: Greedy search requires `residue_set` class in the model, update all methods accordingly.
if not hasattr(model, "residue_set"):
@@ -37,10 +46,32 @@ def __init__(self, model: Decodable, mass_scale: int = MASS_SCALE):
self.residue_masses = torch.zeros(
(len(self.model.residue_set),), dtype=torch.float64
)
terminal_residues_idx: list[int] = []
suppressed_residues_idx: list[int] = []
for i, residue in enumerate(model.residue_set.vocab):
if residue in self.model.residue_set.special_tokens:
continue
self.residue_masses[i] = self.model.residue_set.get_mass(residue)
# If no residue is attached, assume it is a n-terminal residue
if not residue[0].isalpha():
terminal_residues_idx.append(i)

# Check if residue is suppressed
if residue in suppressed_residues:
suppressed_residues_idx.append(i)
suppressed_residues.remove(residue)

if len(suppressed_residues) > 0:
raise ValueError(
f"Suppressed residues not found in vocabulary: {suppressed_residues}"
)

self.terminal_residue_indices = torch.tensor(
terminal_residues_idx, dtype=torch.long
)
self.suppressed_residue_indices = torch.tensor(
suppressed_residues_idx, dtype=torch.long
)

self.vocab_size = len(self.model.residue_set)
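For clarity, here is the new vocabulary scan in isolation, with a toy vocabulary standing in for model.residue_set (all names below are made up for illustration):

import torch

vocab = ["<pad>", "<sos>", "<eos>", "G", "A", "[UNIMOD:1]"]  # toy vocabulary
special_tokens = {"<pad>", "<sos>", "<eos>"}
suppressed_residues = ["A"]

terminal_idx: list[int] = []
suppressed_idx: list[int] = []
for i, residue in enumerate(vocab):
    if residue in special_tokens:
        continue
    # Tokens that do not start with a letter are treated as n-terminal modifications.
    if not residue[0].isalpha():
        terminal_idx.append(i)
    if residue in suppressed_residues:
        suppressed_idx.append(i)

print(torch.tensor(terminal_idx, dtype=torch.long))    # tensor([5])
print(torch.tensor(suppressed_idx, dtype=torch.long))  # tensor([4])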

@@ -270,10 +301,53 @@ def decode( # type:ignore
next_token_probabilities_filtered[
:, self.model.residue_set.EOS_INDEX
] = -float("inf")
# Allow the model to predict PAD when all residues are -inf
# next_token_probabilities_filtered[
# :, self.model.residue_set.PAD_INDEX
# ] = -float("inf")
next_token_probabilities_filtered[
:, self.model.residue_set.SOS_INDEX
] = -float("inf")
# TODO set probability of n-terminal modifications to 0 when i > 0, requires n-terms to be specified in residue_set
next_token_probabilities_filtered[
:, self.suppressed_residue_indices
] = -float("inf")
# Set probability of n-terminal modifications to -inf when i > 0
if self.disable_terminal_residues_anywhere:
# Check if adding terminal residues would result in a complete sequence
# First generate remaining mass matrix with isotopes
remaining_mass_incomplete_isotope = remaining_mass_incomplete[
:, None
].expand(sub_batch_size, max_isotope + 1) - CARBON_MASS_DELTA * (
torch.arange(max_isotope + 1, device=device)
)
# Expand with terminal residues and subtract
remaining_mass_incomplete_isotope_delta = (
remaining_mass_incomplete_isotope[:, :, None].expand(
sub_batch_size,
max_isotope + 1,
self.terminal_residue_indices.shape[0],
)
- self.residue_masses[self.terminal_residue_indices]
)

# If within target delta, allow these residues to be predicted, otherwise set probability to -inf
allow_terminal = (
remaining_mass_incomplete_isotope_delta.abs()
< mass_target_incomplete[:, None, None]
).any(dim=1)
allow_terminal_full = torch.ones(
(sub_batch_size, self.vocab_size),
device=spectra.device,
dtype=bool,
)
allow_terminal_full[:, self.terminal_residue_indices] = (
allow_terminal
)

# Set to -inf
next_token_probabilities_filtered[~allow_terminal_full] = -float(
"inf"
)

# Step 5: Select next token:
next_token = next_token_probabilities_filtered.argmax(-1).unsqueeze(
@@ -362,7 +436,7 @@ def decode( # type:ignore
token_log_probabilities=[
x.cpu().item()
for x in all_log_probabilities[i, : len(sequence)]
-                ], # list[float] (sequence_length) excludes EOS
+                ][::-1], # list[float] (sequence_length) excludes EOS
)
)

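All of the decode() changes above use the same masking pattern: disallowed vocabulary entries get a log probability of -inf, so the argmax in step 5 can never select them. A self-contained toy version of that pattern (not repo code):

import torch

log_probs = torch.log_softmax(torch.randn(2, 6), dim=-1)  # (batch, vocab_size)
suppressed_residue_indices = torch.tensor([4], dtype=torch.long)

# Ban the suppressed entries, then greedily pick the best remaining token.
log_probs[:, suppressed_residue_indices] = -float("inf")
next_token = log_probs.argmax(-1)
print(next_token)  # never 4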
8 changes: 8 additions & 0 deletions instanovo/models.json
@@ -0,0 +1,8 @@
{
"transformer": {
"instanovo-extended": {
"url": "https://github.com/instadeepai/InstaNovo/releases/download/1.0.0/instanovo_extended.ckpt"
}
},
"diffusion": {}
}
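This new registry maps model ids to downloadable checkpoints; from_pretrained in instanovo/transformer/model.py (below) resolves ids through it. A sketch of the raw lookup, run from the repo root:

import json
from pathlib import Path

models = json.loads(Path("instanovo/models.json").read_text())
print(models["transformer"]["instanovo-extended"]["url"])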
2 changes: 1 addition & 1 deletion instanovo/scripts/convert_to_sdf.py
@@ -10,7 +10,7 @@


def main() -> None:
"""Convert data to ipc."""
"""Convert data to spectrum data frame and save as parquet."""
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser()

4 changes: 1 addition & 3 deletions instanovo/scripts/get_zenodo_record.py
@@ -68,15 +68,13 @@ def main(
extract_path + "/instanovo_test_resources"
):
print(
f"Skipping download and extraction. Path '{extract_path}'/instanovo_test_resources already exists and is non-empty."
f"Skipping download and extraction. Path '{extract_path}/instanovo_test_resources' already exists and is non-empty."
)
return

get_zenodo(zenodo_url, zip_path)
unzip_zenodo(zip_path, extract_path)

-    os.makedirs("./tests/instanovo_test_resources/train_test", exist_ok=True)


if __name__ == "__main__":
main()
77 changes: 76 additions & 1 deletion instanovo/transformer/model.py
@@ -4,6 +4,12 @@
from typing import Tuple

import torch
import json
import os
import requests
from urllib.parse import urlsplit
from tqdm import tqdm
from pathlib import Path
from jaxtyping import Bool
from jaxtyping import Float
from jaxtyping import Integer
@@ -27,6 +33,9 @@
from instanovo.types import SpectrumMask
from instanovo.utils import ResidueSet

MODELS_PATH = Path(__file__).parent.parent / "models.json"
MODEL_TYPE = "transformer"


class InstaNovo(nn.Module, Decodable):
"""The Instanovo model."""
@@ -113,9 +122,21 @@ def _get_causal_mask(seq_len: int, return_float: bool = False) -> PeptideMask:
)
return ~mask.bool()

@staticmethod
def get_pretrained() -> list[str]:
"""Get a list of pretrained model ids."""
# Load the models.json file
with open(MODELS_PATH, "r") as f:
models_config = json.load(f)

if MODEL_TYPE not in models_config:
return []

return list(models_config[MODEL_TYPE].keys())

@classmethod
def load(cls, path: str) -> Tuple["InstaNovo", "DictConfig"]:
"""Load model from checkpoint."""
"""Load model from checkpoint path."""
# Add to allow list
_whitelist_torch_omegaconf()
ckpt = torch.load(path, map_location="cpu", weights_only=True)
@@ -145,6 +166,60 @@ def load(cls, path: str) -> Tuple["InstaNovo", "DictConfig"]:

return model, config

@classmethod
def from_pretrained(cls, model_id: str) -> Tuple["InstaNovo", "DictConfig"]:
"""Download and load by model id or model path."""
# Check if model_id is a local file path
if "/" in model_id or "\\" in model_id or "." in model_id:
if os.path.isfile(model_id):
return cls.load(model_id)
else:
raise FileNotFoundError(f"No file found at path: {model_id}")

# Load the models.json file
with open(MODELS_PATH, "r") as f:
models_config = json.load(f)

# Find the model in the config
if MODEL_TYPE not in models_config or model_id not in models_config[MODEL_TYPE]:
raise ValueError(
f"Model {model_id} not found in models.json, options are [{', '.join(models_config[MODEL_TYPE].keys())}]"
)

model_info = models_config[MODEL_TYPE][model_id]
url = model_info["url"]

# Create cache directory if it doesn't exist
cache_dir = Path.home() / ".cache" / "instanovo"
cache_dir.mkdir(parents=True, exist_ok=True)

# Generate a filename for the cached model
file_name = urlsplit(url).path.split("/")[-1]
cached_file = cache_dir / file_name

# Check if the file is already cached
if not cached_file.exists():
# If not cached, download the file with a progress bar
response = requests.get(url, stream=True)
total_size = int(response.headers.get("content-length", 0))

with open(cached_file, "wb") as file, tqdm(
desc=file_name,
total=total_size,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as progress_bar:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
progress_bar.update(size)
# else:
# TODO: Optional verbose logging
# print(f"Model {model_id} already cached at {cached_file}")

# Load and return the model
return cls.load(str(cached_file))

def forward(
self,
x: Float[Spectrum, " batch"],
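Taken together, the additions above make loading a released checkpoint a one-liner. A usage sketch (the first call downloads the checkpoint to ~/.cache/instanovo; subsequent calls hit the cache):

from instanovo.transformer.model import InstaNovo

print(InstaNovo.get_pretrained())  # e.g. ['instanovo-extended']
model, config = InstaNovo.from_pretrained("instanovo-extended")
model.eval()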