diff --git a/esm/__init__.py b/esm/__init__.py index 7a30ab4..0dd7cd4 100644 --- a/esm/__init__.py +++ b/esm/__init__.py @@ -1 +1,2 @@ -__version__ = "3.0.7post1" +__version__ = "3.0.8" + diff --git a/esm/layers/attention.py b/esm/layers/attention.py index 28fee0e..41833b4 100644 --- a/esm/layers/attention.py +++ b/esm/layers/attention.py @@ -10,11 +10,7 @@ class MultiHeadAttention(nn.Module): def __init__( - self, - d_model: int, - n_heads: int, - bias: bool = False, - qk_layernorm: bool = True, + self, d_model: int, n_heads: int, bias: bool = False, qk_layernorm: bool = True ): super().__init__() diff --git a/esm/layers/geom_attention.py b/esm/layers/geom_attention.py index 5139cd0..69656d0 100644 --- a/esm/layers/geom_attention.py +++ b/esm/layers/geom_attention.py @@ -78,11 +78,7 @@ def forward(self, s, affine, affine_mask, sequence_id, chain_id): affine.rot[..., None] .apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3)) .split( - [ - self.v_heads, - self.v_heads, - self.v_heads * self.num_vector_messages, - ], + [self.v_heads, self.v_heads, self.v_heads * self.num_vector_messages], dim=-2, ) ) diff --git a/esm/layers/regression_head.py b/esm/layers/regression_head.py index a237872..6b1c268 100644 --- a/esm/layers/regression_head.py +++ b/esm/layers/regression_head.py @@ -2,9 +2,7 @@ def RegressionHead( - d_model: int, - output_dim: int, - hidden_dim: int | None = None, + d_model: int, output_dim: int, hidden_dim: int | None = None ) -> nn.Module: """Single-hidden layer MLP for supervised output. diff --git a/esm/layers/structure_proj.py b/esm/layers/structure_proj.py index a650176..783ddeb 100644 --- a/esm/layers/structure_proj.py +++ b/esm/layers/structure_proj.py @@ -1,9 +1,7 @@ import torch import torch.nn as nn -from esm.utils.constants.physics import ( - BB_COORDINATES, -) +from esm.utils.constants.physics import BB_COORDINATES from esm.utils.structure.affine3d import ( Affine3D, RotationMatrix, diff --git a/esm/models/esm3.py b/esm/models/esm3.py index f383dff..ecc994a 100644 --- a/esm/models/esm3.py +++ b/esm/models/esm3.py @@ -29,9 +29,7 @@ ProteinType, SamplingConfig, ) -from esm.tokenization import ( - TokenizerCollectionProtocol, -) +from esm.tokenization import TokenizerCollectionProtocol from esm.utils import encoding from esm.utils.constants import esm3 as C from esm.utils.constants.models import ( @@ -173,11 +171,7 @@ def forward(self, x: torch.Tensor, embed: torch.Tensor) -> ESMOutput: secondary_structure_logits = self.ss8_head(x) sasa_logits = self.sasa_head(x) function_logits = self.function_head(x) - function_logits = einops.rearrange( - function_logits, - "... (k v) -> ... k v", - k=8, - ) + function_logits = einops.rearrange(function_logits, "... (k v) -> ... 
k v", k=8) residue_logits = self.residue_head(x) @@ -217,11 +211,7 @@ def __init__( super().__init__() self.encoder = EncodeInputs(d_model) self.transformer = TransformerStack( - d_model, - n_heads, - v_heads, - n_layers, - mask_and_zero_frameless=True, + d_model, n_heads, v_heads, n_layers, mask_and_zero_frameless=True ) self.output_heads = OutputHeads(d_model) @@ -237,9 +227,7 @@ def __init__( @classmethod def from_pretrained( - cls, - model_name: str = ESM3_OPEN_SMALL, - device: torch.device | None = None, + cls, model_name: str = ESM3_OPEN_SMALL, device: torch.device | None = None ) -> ESM3: from esm.pretrained import load_local_model @@ -489,15 +477,14 @@ def encode(self, input: ESMProtein) -> ESMProteinTensor: reference_sequence = encoding.get_default_sequence(sequence_length - 2) else: reference_sequence = input.sequence - ( - function_tokens, - residue_annotation_tokens, - ) = encoding.tokenize_function_annotations( - input.function_annotations, - reference_sequence=reference_sequence, - function_tokenizer=self.tokenizers.function, - residue_annotation_tokenizer=self.tokenizers.residue_annotations, - add_special_tokens=True, + (function_tokens, residue_annotation_tokens) = ( + encoding.tokenize_function_annotations( + input.function_annotations, + reference_sequence=reference_sequence, + function_tokenizer=self.tokenizers.function, + residue_annotation_tokenizer=self.tokenizers.residue_annotations, + add_special_tokens=True, + ) ) return ESMProteinTensor( @@ -510,10 +497,7 @@ def encode(self, input: ESMProtein) -> ESMProteinTensor: coordinates=coordinates, ).to(next(self.parameters()).device) - def decode( - self, - input: ESMProteinTensor, - ) -> ESMProtein: + def decode(self, input: ESMProteinTensor) -> ESMProtein: return decode_protein_tensor( input=input, tokenizers=self.tokenizers, @@ -613,10 +597,7 @@ def forward_and_sample( logits_output: LogitsOutput = _batch_forward(self, batched_protein) forward_and_sample_out: ForwardAndSampleOutput = _sample_per_prompt( - batched_protein, - logits_output, - sampling_config, - self.tokenizers, + batched_protein, logits_output, sampling_config, self.tokenizers ) # There is only 1 prompt to sample for. diff --git a/esm/models/function_decoder.py b/esm/models/function_decoder.py index 6c5452d..2c34ce6 100644 --- a/esm/models/function_decoder.py +++ b/esm/models/function_decoder.py @@ -167,8 +167,7 @@ def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]: # Apply depth-position offset to use distinct vocabs. See __init__ for # explaination. vocab_offsets = self.config.function_token_vocab_size * torch.arange( - self.config.function_token_depth, - device=token_ids.device, + self.config.function_token_depth, device=token_ids.device ) inputs = token_ids + vocab_offsets[None, :] @@ -251,8 +250,7 @@ def decode( annotations.append(annotation) annotations = merge_annotations( - annotations, - merge_gap_max=annotation_gap_merge_max, + annotations, merge_gap_max=annotation_gap_merge_max ) # Drop very small annotations. 
diff --git a/esm/models/vqvae.py b/esm/models/vqvae.py index 287e332..b7bd149 100644 --- a/esm/models/vqvae.py +++ b/esm/models/vqvae.py @@ -87,10 +87,7 @@ def forward(self, x, pairwise: torch.Tensor | None = None): prod = q[:, None, :, :] * k[:, :, None, :] diff = q[:, None, :, :] - k[:, :, None, :] - x_2d = [ - prod, - diff, - ] + x_2d = [prod, diff] if pairwise is not None: x_2d.append(pairwise) x = torch.cat(x_2d, dim=-1) @@ -289,11 +286,7 @@ def find_knn_edges( with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # type: ignore ca = coords[..., 1, :] edges, edge_mask = knn_graph( - ca, - coord_mask, - padding_mask, - sequence_id, - no_knn=knn, + ca, coord_mask, padding_mask, sequence_id, no_knn=knn ) return edges, edge_mask @@ -333,12 +326,7 @@ def encode( class StructureTokenDecoder(nn.Module): - def __init__( - self, - d_model, - n_heads, - n_layers, - ): + def __init__(self, d_model, n_heads, n_layers): super().__init__() self.decoder_channels = d_model diff --git a/esm/sdk/api.py b/esm/sdk/api.py index 4c4815d..3f9dec5 100644 --- a/esm/sdk/api.py +++ b/esm/sdk/api.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC -from typing import Sequence +from typing import List, Sequence import attr import torch @@ -19,14 +19,10 @@ ) from esm.utils.structure.protein_chain import ProteinChain from esm.utils.structure.protein_complex import ProteinComplex -from esm.utils.types import ( - FunctionAnnotation, - PathOrBuffer, -) +from esm.utils.types import FunctionAnnotation, PathOrBuffer -class ProteinType(ABC): - ... +class ProteinType(ABC): ... ## Basic Types @@ -184,6 +180,9 @@ class ESMProteinTensor(ProteinType): # Such sequences may not go through standard safety filter for approved users. # Reach out if interested in using this. potential_sequence_of_concern: bool = False + # Control vectors are vectors added to each layer of the model to nudge hidden states in the desired direction. + # len(control_vectors) == number of blocks in the model. Each vector in the list has the shape (batch size, sequence length, hidden dim), + # so it can be added to the corresponding layer in the model. def _detect_attribute(self, func, msg): mapped = { @@ -260,20 +259,40 @@ class ESMProteinError(Exception, ProteinType): class GenerationConfig: track: str = "" invalid_ids: Sequence[int] = [] - schedule: str = "cosine" + # Controls the number of tokens to unmask during each round of iterative generation. + schedule: str = attr.field( + validator=attr.validators.in_(["cosine", "linear"]), default="cosine" + ) + # Controls which tokens to unmask during each round of iterative generation. + # "random" will unmask the scheduled number of tokens at random. + # "entropy" will unmask the tokens with the lowest logit entropy first. + strategy: str = attr.field( + validator=attr.validators.in_(["random", "entropy"]), default="entropy" + ) # Set this to a higher value for better generation results. # Note that this needs to be less than or equal to the sequence length.
num_steps: int = 1 temperature: float = 1.0 + temperature_annealing: bool = False top_p: float = 1.0 condition_on_coordinates_only: bool = True + def use_entropy_based_unmasking_strategy(self): + """Use entropy based unmasking strategy during generation.""" + self.schedule = "cosine" + self.strategy = "entropy" + self.temperature_annealing = False + + def use_generative_unmasking_strategy(self): + """Use an unmasking strategy that produces more variety of generations.""" + self.schedule = "cosine" + self.strategy = "random" + self.temperature_annealing = True + @define class InverseFoldingConfig: invalid_ids: Sequence[int] = [] - schedule: str = "cosine" - num_steps: int = 1 temperature: float = 1.0 @@ -370,9 +389,7 @@ def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType: raise NotImplementedError def batch_generate( - self, - inputs: Sequence[ProteinType], - configs: Sequence[GenerationConfig], + self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig] ) -> Sequence[ProteinType]: # Same as generate(...), but generates a batch of proteins at once. raise NotImplementedError diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py index eed7f21..6c2bd12 100644 --- a/esm/sdk/forge.py +++ b/esm/sdk/forge.py @@ -5,12 +5,7 @@ import requests import torch -from tenacity import ( - retry, - retry_if_result, - stop_after_attempt, - wait_exponential, -) +from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential from esm.sdk.api import ( ESM3InferenceClient, @@ -20,6 +15,7 @@ ForwardAndSampleOutput, ForwardTrackData, GenerationConfig, + InverseFoldingConfig, LogitsConfig, LogitsOutput, ProteinType, @@ -55,7 +51,15 @@ def log_retry_attempt(retry_state): ) -class FoldForgeInferenceClient: +def _validate_protein_tensor_input(input): + if not isinstance(input, ESMProteinTensor): + raise ValueError( + "Input must be an ESMProteinTensor instance. " + "Use encode() API to encode an ESMProtein into ESMProteinTensor." 
+ ) + + +class SequenceStructureForgeInferenceClient: def __init__( self, url: str = "https://forge.evolutionaryscale.ai", @@ -73,31 +77,51 @@ def __init__( def fold( self, - model_name: str, sequence: str, potential_sequence_of_concern: bool, - ) -> torch.Tensor | ESMProteinError: + model_name: str | None = None, + ) -> ESMProtein | ESMProteinError: + request = {"sequence": sequence} + if model_name is not None: + request["model"] = model_name + try: + data = self._post("fold", request, potential_sequence_of_concern) + except ESMProteinError as e: + return e + + return ESMProtein( + coordinates=maybe_tensor(data["coordinates"], convert_none_to_nan=True) + ) + + def inverse_fold( + self, + coordinates: torch.Tensor, + config: InverseFoldingConfig, + potential_sequence_of_concern: bool, + model_name: str | None = None, + ) -> ESMProtein | ESMProteinError: + inverse_folding_config = { + "invalid_ids": config.invalid_ids, + "temperature": config.temperature, + } request = { - "model": model_name, - "sequence": sequence, + "coordinates": maybe_list(coordinates, convert_nan_to_none=True), + "inverse_folding_config": inverse_folding_config, } + if model_name is not None: + request["model"] = model_name try: - data = self._post( - "fold", - request, - potential_sequence_of_concern, - ) + data = self._post("inverse_fold", request, potential_sequence_of_concern) except ESMProteinError as e: return e - return data["coordinates"] + return ESMProtein(sequence=data["sequence"]) def _post(self, endpoint, request, potential_sequence_of_concern): request["potential_sequence_of_concern"] = potential_sequence_of_concern - model_name_url = request["model"] if request["model"] != "esm3" else "api" response = requests.post( - urljoin(self.url, f"/{model_name_url}/v1/{endpoint}"), + urljoin(self.url, f"/api/v1/{endpoint}"), json=request, headers=self.headers, timeout=self.request_timeout, @@ -115,6 +139,11 @@ def _post(self, endpoint, request, potential_sequence_of_concern): if "outputs" not in data and "data" in data: data = data["data"] + # Print warning message if there is any. 
+ if "warning_messages" in data and data["warning_messages"] is not None: + for msg in data["warning_messages"]: + print("\033[31m", msg, "\033[0m") + return data @@ -174,18 +203,13 @@ def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType: output = self.__generate_protein_tensor(input, config) else: return ESMProteinError( - error_code=500, - error_msg=f"Unknown input type {type(input)}", + error_code=500, error_msg=f"Unknown input type {type(input)}" ) if ( isinstance(output, ESMProtein) and isinstance(input, ESMProtein) - and config.track - not in [ - "function", - "residue_annotations", - ] + and config.track not in ["function", "residue_annotations"] ): # Function and residue annotation encoding/decoding is lossy # There is no guarantee that decoding encoded tokens will yield the same input @@ -218,9 +242,7 @@ def _capture_exception(r): return [_capture_exception(r) for r in results] def __generate_protein( - self, - input: ESMProtein, - config: GenerationConfig, + self, input: ESMProtein, config: GenerationConfig ) -> ESMProtein | ESMProteinError: req = {} req["sequence"] = input.sequence @@ -261,9 +283,7 @@ def __generate_protein( ) def __generate_protein_tensor( - self, - input: ESMProteinTensor, - config: GenerationConfig, + self, input: ESMProteinTensor, config: GenerationConfig ) -> ESMProteinTensor | ESMProteinError: req = {} req["sequence"] = maybe_list(input.sequence) @@ -316,6 +336,7 @@ def _field_to_tensor(field, convert_none_to_nan: bool = False): def forward_and_sample( self, input: ESMProteinTensor, sampling_configuration: SamplingConfig ) -> ForwardAndSampleOutput | ESMProteinError: + _validate_protein_tensor_input(input) validate_sampling_config(sampling_configuration, on_invalid="raise") req = {} @@ -441,10 +462,9 @@ def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError: ) @retry_decorator - def decode( - self, - input: ESMProteinTensor, - ) -> ESMProtein | ESMProteinError: + def decode(self, input: ESMProteinTensor) -> ESMProtein | ESMProteinError: + _validate_protein_tensor_input(input) + tokens = {} tokens["sequence"] = maybe_list(input.sequence) tokens["structure"] = maybe_list(input.structure) @@ -454,10 +474,7 @@ def decode( tokens["residue_annotation"] = maybe_list(input.residue_annotations) tokens["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True) - request = { - "model": self.model, - "inputs": tokens, - } + request = {"model": self.model, "inputs": tokens} try: data = self._post("decode", request, input.potential_sequence_of_concern) @@ -482,6 +499,8 @@ def decode( def logits( self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig() ) -> LogitsOutput | ESMProteinError: + _validate_protein_tensor_input(input) + # Note: using raw model forwards is discouraged because of the byte size # of the logits. # Please use forward_and_sample instead. 
@@ -504,11 +523,7 @@ def logits( "return_embeddings": config.return_embeddings, } - request = { - "model": self.model, - "inputs": req, - "logits_config": logits_config, - } + request = {"model": self.model, "inputs": req, "logits_config": logits_config} try: data = self._post("logits", request, input.potential_sequence_of_concern) diff --git a/esm/sdk/sagemaker.py b/esm/sdk/sagemaker.py index 925f629..d239311 100644 --- a/esm/sdk/sagemaker.py +++ b/esm/sdk/sagemaker.py @@ -2,7 +2,61 @@ import boto3 -from esm.sdk.forge import ESM3ForgeInferenceClient +from esm.sdk.forge import ( + ESM3ForgeInferenceClient, + SequenceStructureForgeInferenceClient, +) + + +class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient): + def __init__(self, endpoint_name: str): + """SequenceStructure (folding and inverse folding) client that talks to a SageMaker endpoint. + + Args: + endpoint_name: Name of the SageMaker endpoint. + """ + # Dummy URL and token to make SequenceStructureForgeInferenceClient happy. + super().__init__(url="", token="dummy") + + self._endpoint_name = endpoint_name + + self._client = boto3.client(service_name="sagemaker-runtime") + + def _post(self, endpoint, request, potential_sequence_of_concern): + request["potential_sequence_of_concern"] = potential_sequence_of_concern + request["model"] = request.get("model", None) + invocations_request = { + # Duplicate these fields at the top level to make Forge requests consistent. + "model": request["model"], + "request_id": "", # Forge specific field. + "user_id": "", # Forge specific field. + # Invocation data bits. + "api_ver": "v1", # Must be v1 right now. + "endpoint": endpoint, + # Wrapped request. + endpoint: request, + } + + try: + response = self._client.invoke_endpoint( + EndpointName=self._endpoint_name, + ContentType="application/json", + Body=json.dumps(invocations_request), + ) + except Exception as e: + raise RuntimeError(f"Failure in {endpoint}: {e}") from e + + data = json.loads(response["Body"].read().decode()) + + # Response must match request. + assert ( + data["endpoint"] == endpoint + ), f"Response endpoint is {data['endpoint']} but request is {endpoint}" + + # Get the actual responses under the endpoint key. + data = data[endpoint] + + return data class ESM3SageMakerClient(ESM3ForgeInferenceClient): diff --git a/esm/tokenization/function_tokenizer.py b/esm/tokenization/function_tokenizer.py index c11ae51..15d6070 100644 --- a/esm/tokenization/function_tokenizer.py +++ b/esm/tokenization/function_tokenizer.py @@ -120,8 +120,7 @@ def keyword_to_index(self) -> dict[str, int]: def _tfidf(self) -> tfidf.TFIDFModel: """Creates TF-IDF model for encoding function keywords.""" return tfidf.TFIDFModel( - vocabulary_path=self.keyword_vocabulary_path, - idf_path=self.keyword_idf_path, + vocabulary_path=self.keyword_vocabulary_path, idf_path=self.keyword_idf_path ) @cached_property @@ -205,9 +204,7 @@ def tokenize( return tokens def _function_text_hash( - self, - labels: Collection[str], - keyword_mask: np.ndarray | None = None, + self, labels: Collection[str], keyword_mask: np.ndarray | None = None ) -> np.ndarray | None: """Applies a locality sensitive hash (LSH) to function text. @@ -295,9 +292,7 @@ def _token2ids(self, token: str) -> list[int]: raise ValueError(f"Unknown token: {token}") def batch_encode( - self, - token_batch: list[list[str]], - add_special_tokens: bool = True, + self, token_batch: list[list[str]], add_special_tokens: bool = True ) -> torch.Tensor: """Encodes batch of function tokens. 
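The SequenceStructureSageMakerClient added in esm/sdk/sagemaker.py above wraps each Forge-style request in a single invocation envelope keyed by endpoint name. An illustrative fold envelope (all values are placeholders):

    import json

    request = {"sequence": "MSTNPKPQKK", "model": None, "potential_sequence_of_concern": False}
    invocations_request = {
        "model": request["model"],
        "request_id": "",  # Forge-specific field
        "user_id": "",     # Forge-specific field
        "api_ver": "v1",   # must be v1 right now
        "endpoint": "fold",
        "fold": request,   # the wrapped request, keyed by endpoint name
    }
    body = json.dumps(invocations_request)

The response mirrors the request: _post asserts data["endpoint"] == endpoint and then unwraps data[endpoint], so a mismatched reply fails loudly instead of decoding the wrong payload.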
@@ -312,8 +307,7 @@ def batch_encode( for tokens in token_batch ] return stack_variable_length_tensors( - encoded, - constant_value=self.vocab_to_index[""], + encoded, constant_value=self.vocab_to_index[""] ) def decode(self, encoded: torch.Tensor): diff --git a/esm/tokenization/residue_tokenizer.py b/esm/tokenization/residue_tokenizer.py index 5430a23..cf6ff48 100644 --- a/esm/tokenization/residue_tokenizer.py +++ b/esm/tokenization/residue_tokenizer.py @@ -13,11 +13,7 @@ class ResidueAnnotationsTokenizer(EsmTokenizerBase): - def __init__( - self, - csv_path: str | None = None, - max_annotations: int = 16, - ): + def __init__(self, csv_path: str | None = None, max_annotations: int = 16): if csv_path is None: csv_path = str(C.data_root() / C.RESID_CSV) self.csv_path = csv_path diff --git a/esm/tokenization/sasa_tokenizer.py b/esm/tokenization/sasa_tokenizer.py index a07c830..00a9ea6 100644 --- a/esm/tokenization/sasa_tokenizer.py +++ b/esm/tokenization/sasa_tokenizer.py @@ -31,7 +31,7 @@ def vocab(self) -> list[str]: return self.special_tokens + range_tokens @cached_property - def midpoints(self) -> list[float]: + def midpoints_tensor(self) -> torch.Tensor: """Midpoints of the SASA token ranges.""" boundaries = [0] + self._boundaries + [self._boundaries[-1] * 2] midpoint_tokens = [ @@ -39,7 +39,11 @@ def midpoints(self) -> list[float]: for low, high in zip(boundaries[:-1], boundaries[1:]) ] midpoint_tokens = [float("nan"), float("nan"), float("nan")] + midpoint_tokens - return midpoint_tokens + return torch.Tensor(midpoint_tokens) + + def midpoints(self) -> list[float]: + """Midpoints of the SASA token ranges.""" + return self.midpoints_tensor.tolist() @cached_property def vocab_to_index(self) -> dict[str, int]: @@ -86,7 +90,11 @@ def encode( def decode_float(self, encoded: torch.Tensor) -> list[float]: """Decodes SASA token ids into float values.""" - return [self.midpoints[token_id] for token_id in encoded] + decoded = self.midpoints_tensor[encoded.cpu()] + nan_mask = torch.isnan(decoded) + np_arr = decoded.numpy() + np_arr[nan_mask.numpy()] = None + return np_arr.tolist() def decode(self, encoded: torch.Tensor) -> str: """Decodes SASA token ids.""" diff --git a/esm/tokenization/sequence_tokenizer.py b/esm/tokenization/sequence_tokenizer.py index a7e840d..a56ae74 100644 --- a/esm/tokenization/sequence_tokenizer.py +++ b/esm/tokenization/sequence_tokenizer.py @@ -40,9 +40,7 @@ def __init__( self.cb_token = chain_break_token additional_special_tokens = [chain_break_token] - tokenizer.add_special_tokens( - special_tokens, - ) + tokenizer.add_special_tokens(special_tokens) # This is where we configure the automatic addition of special tokens when we call # tokenizer(text, add_special_tokens=True). Note that you can also configure how two diff --git a/esm/tokenization/tokenizer_base.py b/esm/tokenization/tokenizer_base.py index a8032ea..4358ab5 100644 --- a/esm/tokenization/tokenizer_base.py +++ b/esm/tokenization/tokenizer_base.py @@ -3,56 +3,42 @@ @runtime_checkable class EsmTokenizerBase(Protocol): - def encode(self, *args, **kwargs): - ... + def encode(self, *args, **kwargs): ... - def decode(self, *args, **kwargs): - ... + def decode(self, *args, **kwargs): ... @property - def mask_token(self) -> str: - ... + def mask_token(self) -> str: ... @property - def mask_token_id(self) -> int: - ... + def mask_token_id(self) -> int: ... @property - def bos_token(self) -> str: - ... + def bos_token(self) -> str: ... @property - def bos_token_id(self) -> int: - ... 
+ def bos_token_id(self) -> int: ... @property - def eos_token(self) -> str: - ... + def eos_token(self) -> str: ... @property - def eos_token_id(self) -> int: - ... + def eos_token_id(self) -> int: ... @property - def pad_token(self) -> str: - ... + def pad_token(self) -> str: ... @property - def pad_token_id(self) -> int: - ... + def pad_token_id(self) -> int: ... @property - def chain_break_token(self) -> str: - ... + def chain_break_token(self) -> str: ... @property - def chain_break_token_id(self) -> int: - ... + def chain_break_token_id(self) -> int: ... @property - def all_token_ids(self): - ... + def all_token_ids(self): ... @property - def special_token_ids(self): - ... + def special_token_ids(self): ... diff --git a/esm/utils/constants/esm3.py b/esm/utils/constants/esm3.py index dd6ebee..8a02be5 100644 --- a/esm/utils/constants/esm3.py +++ b/esm/utils/constants/esm3.py @@ -112,9 +112,7 @@ def data_root(): INTERPRO2GO = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt" INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json" -LSH_TABLE_PATHS = { - "8bit": "data/hyperplanes_8bit_58641.npz", -} +LSH_TABLE_PATHS = {"8bit": "data/hyperplanes_8bit_58641.npz"} KEYWORDS_VOCABULARY = ( IN_REPO_DATA_FOLDER / "keyword_vocabulary_safety_filtered_58641.txt" diff --git a/esm/utils/decoding.py b/esm/utils/decoding.py index 0de37c7..6d897c8 100644 --- a/esm/utils/decoding.py +++ b/esm/utils/decoding.py @@ -1,4 +1,5 @@ import warnings +from typing import cast import attr import torch @@ -31,7 +32,7 @@ decode_function_tokens, decode_residue_annotation_tokens, ) -from esm.utils.misc import list_nan_to_none +from esm.utils.misc import maybe_list from esm.utils.structure.protein_chain import ProteinChain from esm.utils.types import FunctionAnnotation @@ -130,15 +131,10 @@ def _bos_eos_warn(msg: str, tensor: torch.Tensor, tok: EsmTokenizerBase): def decode_sequence( - sequence_tokens: torch.Tensor, - sequence_tokenizer: EsmSequenceTokenizer, - **kwargs, + sequence_tokens: torch.Tensor, sequence_tokenizer: EsmSequenceTokenizer, **kwargs ) -> str: _bos_eos_warn("Sequence", sequence_tokens, sequence_tokenizer) - sequence = sequence_tokenizer.decode( - sequence_tokens, - **kwargs, - ) + sequence = sequence_tokenizer.decode(sequence_tokens, **kwargs) sequence = sequence.replace(" ", "") sequence = sequence.replace(sequence_tokenizer.mask_token, C.MASK_STR_SHORT) sequence = sequence.replace(sequence_tokenizer.cls_token, "") @@ -185,20 +181,16 @@ def decode_structure( def decode_secondary_structure( - secondary_structure_tokens: torch.Tensor, - ss_tokenizer: SecondaryStructureTokenizer, + secondary_structure_tokens: torch.Tensor, ss_tokenizer: SecondaryStructureTokenizer ) -> str: _bos_eos_warn("Secondary structure", secondary_structure_tokens, ss_tokenizer) secondary_structure_tokens = secondary_structure_tokens[1:-1] - secondary_structure = ss_tokenizer.decode( - secondary_structure_tokens, - ) + secondary_structure = ss_tokenizer.decode(secondary_structure_tokens) return secondary_structure def decode_sasa( - sasa_tokens: torch.Tensor, - sasa_tokenizer: SASADiscretizingTokenizer, + sasa_tokens: torch.Tensor, sasa_tokenizer: SASADiscretizingTokenizer ) -> list[float]: if sasa_tokens[0] != 0: raise ValueError("SASA does not start with 0 corresponding to BOS token") @@ -213,12 +205,13 @@ def decode_sasa( torch.long, ]: # Decode if int + # handles turning NaN's into None's sasa = sasa_tokenizer.decode_float(sasa_tokens) else: # If already float, just convert to list - sasa = sasa_tokens.tolist() + sasa = 
cast(list[float], maybe_list(sasa_tokens, convert_nan_to_none=True)) - return list_nan_to_none(sasa) + return sasa def decode_function_annotations( diff --git a/esm/utils/encoding.py b/esm/utils/encoding.py index 9112d46..83c9d03 100644 --- a/esm/utils/encoding.py +++ b/esm/utils/encoding.py @@ -97,9 +97,7 @@ def tokenize_structure( # Add space for BOS and EOS tokens if add_special_tokens: coordinates = F.pad( - coordinates, - (0, 0, 0, 0, left_pad, right_pad), - value=torch.inf, + coordinates, (0, 0, 0, 0, left_pad, right_pad), value=torch.inf ) plddt = F.pad(plddt, (left_pad, right_pad), value=0) structure_tokens = F.pad( @@ -171,8 +169,7 @@ def tokenize_function_annotations( # Tokenized Defaults def get_default_sequence_tokens( - sequence_length: int, - sequence_tokenizer: EsmSequenceTokenizer, + sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer ) -> torch.Tensor: assert sequence_tokenizer.mask_token_id is not None assert sequence_tokenizer.bos_token_id is not None @@ -191,10 +188,7 @@ def get_default_structure_tokens( sequence_length: int, structure_tokenizer: StructureTokenizer ) -> torch.Tensor: structure_tokens = ( - torch.ones( - (sequence_length + 2,), - dtype=torch.int64, - ) + torch.ones((sequence_length + 2,), dtype=torch.int64) * structure_tokenizer.mask_token_id ) # Always include BOS and EOS tokens @@ -241,10 +235,7 @@ def get_default_residue_annotation_tokens( sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer ) -> torch.Tensor: residue_annotation_tokens = ( - torch.ones( - (sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS), - dtype=torch.int64, - ) + torch.ones((sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS), dtype=torch.int64) * residue_annotation_tokenizer.pad_token_id ) # Always include BOS and EOS tokens diff --git a/esm/utils/function/encode_decode.py b/esm/utils/function/encode_decode.py index 711aec2..29534e3 100644 --- a/esm/utils/function/encode_decode.py +++ b/esm/utils/function/encode_decode.py @@ -59,8 +59,7 @@ def encode_function_annotations( # Convert function token FunctionAnnotations -> Tensor function_tokens = function_tokens_tokenizer.tokenize( - annotations=ft_annotations, - seqlen=len(sequence), + annotations=ft_annotations, seqlen=len(sequence) ) function_token_ids = function_tokens_tokenizer.encode( function_tokens, add_special_tokens=add_special_tokens @@ -175,10 +174,7 @@ def decode_residue_annotation_tokens( annotation = FunctionAnnotation(label=label, start=loc, end=loc) annotations.append(annotation) - annotations = merge_annotations( - annotations, - merge_gap_max=annotation_gap_merge_max, - ) + annotations = merge_annotations(annotations, merge_gap_max=annotation_gap_merge_max) # Drop very small annotations. 
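Both decode paths above funnel through merge_annotations; a small illustrative example of its gap-merging behavior (labels and positions are made up):

    from esm.utils.misc import merge_annotations
    from esm.utils.types import FunctionAnnotation

    anns = [
        FunctionAnnotation(label="kinase", start=1, end=10),
        FunctionAnnotation(label="kinase", start=12, end=20),
    ]
    # The one-position gap (residue 11) is within merge_gap_max, so the two
    # same-label spans fuse into one annotation.
    print(merge_annotations(anns, merge_gap_max=2))
    # expected: [FunctionAnnotation(label='kinase', start=1, end=20)]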
if annotation_min_length is not None: diff --git a/esm/utils/function/interpro.py b/esm/utils/function/interpro.py index 242eeb7..b9e0c74 100644 --- a/esm/utils/function/interpro.py +++ b/esm/utils/function/interpro.py @@ -127,11 +127,7 @@ def entries_frame(self) -> pd.DataFrame: col in df.columns for col in ["ENTRY_AC", "ENTRY_TYPE", "ENTRY_NAME"] ) df.rename( - columns={ - "ENTRY_AC": "id", - "ENTRY_TYPE": "type", - "ENTRY_NAME": "name", - }, + columns={"ENTRY_AC": "id", "ENTRY_TYPE": "type", "ENTRY_NAME": "name"}, inplace=True, ) df["type"] = df.type.str.upper().apply( diff --git a/esm/utils/function/tfidf.py b/esm/utils/function/tfidf.py index aba587f..76caf4f 100644 --- a/esm/utils/function/tfidf.py +++ b/esm/utils/function/tfidf.py @@ -50,8 +50,7 @@ def encode(self, terms: list[str]) -> sparse.csr_matrix: values /= np.linalg.norm(values) return sparse.csr_matrix( - (values, (np.zeros_like(indices), indices)), - shape=(1, len(self.vocabulary)), + (values, (np.zeros_like(indices), indices)), shape=(1, len(self.vocabulary)) ) def decode(self, vec: sparse.csr_matrix) -> list[str]: diff --git a/esm/utils/generation.py b/esm/utils/generation.py index 4aa9b92..0c1a3df 100644 --- a/esm/utils/generation.py +++ b/esm/utils/generation.py @@ -128,9 +128,7 @@ def iterative_sampling_raw( def _make_masked_inputs( - track: str, - sequence_length: int, - tokenizers: TokenizerCollectionProtocol, + track: str, sequence_length: int, tokenizers: TokenizerCollectionProtocol ): get_tokenizer: Callable[[str], EsmTokenizerBase] = lambda s: getattr(tokenizers, s) @@ -190,8 +188,7 @@ def _stack_field(fn: str): o, fn, stack_variable_length_tensors( - sequences=tensors, - constant_value=mask_token_id, + sequences=tensors, constant_value=mask_token_id ), ) @@ -240,6 +237,11 @@ def _get_iterative_sampling_mask_for_prompt_and_step( shape = tokens.shape B, L = shape[0], shape[1] + # TODO: figure out why we want this function to work with + # _BatchedESMProteinTensor in the first place. The logic below + # doesn't really work for batched tensors. + assert B == 1 + sampling_mask = torch.ones((B, L), dtype=torch.bool, device=device) sampling_mask[:, 0] = False # BOS # EOS and all padding tokens.
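The schedule/strategy fields added to GenerationConfig in esm/sdk/api.py drive the branching introduced in the next hunk. A brief usage sketch (track and num_steps are illustrative):

    from esm.sdk.api import GenerationConfig

    config = GenerationConfig(track="sequence", num_steps=8)
    # Unmask lowest-entropy positions first (the default strategy).
    config.use_entropy_based_unmasking_strategy()
    # Or trade determinism for diversity: random unmasking plus an annealed
    # temperature, which _get_annealed_temperature (below) decays
    # quadratically from config.temperature toward ~0 over num_steps.
    config.use_generative_unmasking_strategy()

Because both fields now carry attrs validators, an invalid value such as schedule="quadratic" raises at construction time instead of failing mid-generation.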
@@ -248,9 +250,7 @@ def _get_iterative_sampling_mask_for_prompt_and_step( ).to(device) is_mask = _get_masked_positions( - track_to_sample, - tokens, - getattr(tokenizers, track_to_sample).mask_token_id, + track_to_sample, tokens, getattr(tokenizers, track_to_sample).mask_token_id ) if not is_mask.any().item(): raise ValueError(f"Cannot sample {config.track} when input has no masks.") @@ -273,27 +273,36 @@ def _get_iterative_sampling_mask_for_prompt_and_step( ).int() num_to_sample = still_masked - num_tokens_masked_after_this_step - track_entropy: torch.Tensor = getattr(entropy, track_to_sample).to( - device - ) # (B, L) or (B, L, D) + if config.strategy == "entropy": + track_entropy: torch.Tensor = getattr(entropy, track_to_sample).to( + device + ) # (B, L) or (B, L, D) - if track_to_sample == "function": - track_entropy = track_entropy.sum(-1) # (B, L, D) -> (B, L) + if track_to_sample == "function": + track_entropy = track_entropy.sum(-1) # (B, L, D) -> (B, L) - track_entropy = track_entropy.masked_fill( - ~sampling_mask, torch.finfo(track_entropy.dtype).max - ) - _, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False) - is_top_k = torch.zeros((B, L), dtype=torch.bool, device=device).scatter( - 1, indices, True - ) - where_to_sample = sampling_mask & is_top_k + track_entropy = track_entropy.masked_fill( + ~sampling_mask, torch.finfo(track_entropy.dtype).max + ) + _, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False) + is_top_k = torch.zeros((B, L), dtype=torch.bool, device=device).scatter( + 1, indices, True + ) + where_to_sample = sampling_mask & is_top_k + elif config.strategy == "random": + # Skip B since we know there is only 1 prompt here. + _, masked_indices = sampling_mask.nonzero(as_tuple=True) + # Random shuffle the masked indices then select the first num_to_sample. + rnd_indices = masked_indices[torch.randperm(len(masked_indices))][ + :num_to_sample + ] + rnd_mask = torch.zeros_like(sampling_mask) + rnd_mask[:, rnd_indices] = True + where_to_sample = sampling_mask & rnd_mask if track_to_sample == "function": where_to_sample = where_to_sample.unsqueeze(-1).expand( - B, - L, - tokenizers.function.depth, + B, L, tokenizers.function.depth ) # (B, L) -> (B, L, D) return where_to_sample @@ -316,6 +325,11 @@ def _get_non_special_tokens( return int(torch.sum(mask).item()) +def _get_annealed_temperature(step: int, num_steps: int, initial_temperature: float): + step_ratio = step / max(1, (num_steps - 1)) + return max(initial_temperature - step_ratio, 0.001) ** 2 + + def iterative_sampling_tokens( client: ESM3InferenceClient, input_tokens: list[ESMProteinTensor], @@ -345,9 +359,7 @@ def iterative_sampling_tokens( num_sampling_steps = _get_non_special_tokens(protein, tokenizers) else: masked = _get_masked_positions( - track, - getattr(protein, track), - getattr(tokenizers, track).mask_token_id, + track, getattr(protein, track), getattr(tokenizers, track).mask_token_id ) num_sampling_steps = torch.sum(masked).item() @@ -365,10 +377,7 @@ def iterative_sampling_tokens( # Now stack the list to make a single batched ESMProteinTensor. batched_tokens = _stack_protein_tensors( - sampled_tokens, - sequence_lengths, - tokenizers, - devices.pop(), + sampled_tokens, sequence_lengths, tokenizers, devices.pop() ) # Remember sampled prompts that has somehow errored out. @@ -418,9 +427,18 @@ def iterative_sampling_tokens( len(per_prompt_cur_sampled), ) + # Handle temperature annealing, since _sample_per_prompt() doesn't have + # the concept of decoding steps. 
+ if config.temperature_annealing: + temperature = _get_annealed_temperature( + t, config.num_steps, config.temperature + ) + else: + temperature = config.temperature + track_sample_config = SamplingTrackConfig() track_sample_config.invalid_ids = config.invalid_ids - track_sample_config.temperature = config.temperature + track_sample_config.temperature = temperature track_sample_config.top_p = config.top_p sampling_config = SamplingConfig(**{config.track: track_sample_config}) # type: ignore @@ -486,7 +504,7 @@ def iterative_sampling_tokens( setattr(outputs, "coordinates", getattr(inputs, "coordinates")) # Maybe restore all the other fields. for f in attr.fields(SamplingConfig): - if "embedding" in f.name: + if "embedding" in f.name or f.name == "return_hidden_states": continue if f.name != config.track: setattr(outputs, f.name, getattr(inputs, f.name)) @@ -494,10 +512,7 @@ def iterative_sampling_tokens( return output_tokens -def _batch_forward( - client: ESM3InferenceClient, - protein: _BatchedESMProteinTensor, -): +def _batch_forward(client: ESM3InferenceClient, protein: _BatchedESMProteinTensor): # Forward pass return client.logits( protein, diff --git a/esm/utils/generation_test.py b/esm/utils/generation_test.py index 5b3a046..fddc7f9 100644 --- a/esm/utils/generation_test.py +++ b/esm/utils/generation_test.py @@ -97,9 +97,7 @@ def test_num_decoding_steps_more_than_mask_tokens_batched(esm3_remote_inference_ @pytest.mark.gpu def test_encode_chainbreak_token(esm3_remote_inference_client): - protein = esm3_remote_inference_client.encode( - ESMProtein(sequence="MSTNP|KPQKK"), - ) + protein = esm3_remote_inference_client.encode(ESMProtein(sequence="MSTNP|KPQKK")) assert isinstance(protein, ESMProteinTensor) assert protein.sequence is not None assert ( diff --git a/esm/utils/misc.py b/esm/utils/misc.py index 370fba4..438800c 100644 --- a/esm/utils/misc.py +++ b/esm/utils/misc.py @@ -1,4 +1,3 @@ -import math import os from collections import defaultdict from typing import ContextManager, Sequence, TypeVar @@ -226,8 +225,7 @@ def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[ def merge_annotations( - annotations: list[FunctionAnnotation], - merge_gap_max: int | None = None, + annotations: list[FunctionAnnotation], merge_gap_max: int | None = None ) -> list[FunctionAnnotation]: """Merges annotations into non-overlapping segments. @@ -256,42 +254,24 @@ def merge_annotations( return merged -def list_nan_to_none(l: list) -> list: - if l is None: - return None # type: ignore - elif isinstance(l, float): - return None if math.isnan(l) else l # type: ignore - elif isinstance(l, list): - return [list_nan_to_none(x) for x in l] - else: - # Don't go into other structures. 
- return l - - -def list_none_to_nan(l: list) -> list: - if l is None: - return math.nan # type: ignore - elif isinstance(l, list): - return [list_none_to_nan(x) for x in l] - else: - return l - - def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None: if x is None: return None if convert_none_to_nan: - x = list_none_to_nan(x) + x = np.array(x, copy=False, dtype=np.float32) + x = np.where(x is None, np.nan, x) return torch.tensor(x) def maybe_list(x, convert_nan_to_none: bool = False) -> list | None: if x is None: return None - x = x.tolist() - if convert_nan_to_none: - x = list_nan_to_none(x) - return x + if not convert_nan_to_none: + return x.tolist() + nan_mask = torch.isnan(x) + np_arr = x.cpu().numpy().astype(object) + np_arr[nan_mask.cpu().numpy()] = None + return np_arr.tolist() def huggingfacehub_login(): diff --git a/esm/utils/misc_test.py b/esm/utils/misc_test.py index 5e89d27..500acd2 100644 --- a/esm/utils/misc_test.py +++ b/esm/utils/misc_test.py @@ -1,6 +1,5 @@ """Tests for misc.py""" - from esm.utils.misc import merge_annotations from esm.utils.types import FunctionAnnotation diff --git a/esm/utils/sampling.py b/esm/utils/sampling.py index 097e418..6611d71 100644 --- a/esm/utils/sampling.py +++ b/esm/utils/sampling.py @@ -22,16 +22,30 @@ SASA_DISCRETIZATION_BOUNDARIES, ) -# Number of dimensions for each protein tensor field without the batch dimension. -_DIMS: dict[str, int] = { - "sequence": 1, - "structure": 1, - "secondary_structure": 1, - "sasa": 1, - "function": 2, - "residue_annotations": 2, - "coordinates": 3, -} + +def _non_batched_dims(k: str, v: torch.Tensor): + match k: + case "sequence": + return 1 + case "structure": + if v.is_floating_point(): + # This is the one hot soft structure token. + return 2 + else: + # This is the normal int structure token. 
+ return 1 + case "secondary_structure": + return 1 + case "sasa": + return 1 + case "function": + return 2 + case "residue_annotations": + return 2 + case "coordinates": + return 3 + case _: + raise ValueError(f"Unknown dim for track {k}") class _BatchedESMProteinTensor(ESMProteinTensor): @@ -52,7 +66,7 @@ def _maybe_unsqueeze(x: torch.Tensor | None): def __len__(self) -> int: def get_len(k, v) -> int: - assert len(v.shape) == _DIMS[k] + 1 + assert len(v.shape) == _non_batched_dims(k, v) + 1 return v.size(1) l = self._detect_attribute(get_len, "length") @@ -61,18 +75,14 @@ def get_len(k, v) -> int: @property def batch_size(self) -> int: def get_batch_size(k, v) -> int: - assert len(v.shape) == _DIMS[k] + 1 + assert len(v.shape) == _non_batched_dims(k, v) + 1 return v.size(0) d = self._detect_attribute(get_batch_size, "batch size") assert d is not None return d - def slice( - self, - i: int, - sequence_len: int | None = None, - ) -> ESMProteinTensor: + def slice(self, i: int, sequence_len: int | None = None) -> ESMProteinTensor: def _maybe_slice(x: torch.Tensor | None): if x is None: return None @@ -130,8 +140,7 @@ def get_default_sampling_config( def validate_sampling_config( - sampling_config: SamplingConfig, - on_invalid: Literal["raise", "warn"] = "warn", + sampling_config: SamplingConfig, on_invalid: Literal["raise", "warn"] = "warn" ): # Check that all tracks have topk_logprobs less or equal to MAX_TOP_K for track in attr.fields(SamplingConfig): @@ -288,10 +297,7 @@ def sample_sasa_logits( return sasa_value -def top_p_logits( - logits: torch.Tensor, - top_p: float | torch.Tensor, -) -> torch.Tensor: +def top_p_logits(logits: torch.Tensor, top_p: float | torch.Tensor) -> torch.Tensor: top_p = _tensorize_like(top_p, logits) batch_dims = logits.size()[:-1] @@ -320,9 +326,7 @@ def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor): def get_sampling_mask( - tokens: torch.Tensor, - sampling_track_config: SamplingTrackConfig, - mask_idx: int, + tokens: torch.Tensor, sampling_track_config: SamplingTrackConfig, mask_idx: int ): # Do not sample at BOS and EOS tokens sampling_mask = torch.ones_like(tokens, dtype=torch.bool) # (B, L, ) diff --git a/esm/utils/sampling_test.py b/esm/utils/sampling_test.py index 5abfa4e..ee2dc61 100644 --- a/esm/utils/sampling_test.py +++ b/esm/utils/sampling_test.py @@ -31,9 +31,7 @@ def test_sample_logits(): with pytest.raises(ValueError): sampled = sample_logits( - logits=torch.randn((8, 4096)), - temperature=0.0, - valid_ids=[], + logits=torch.randn((8, 4096)), temperature=0.0, valid_ids=[] ) diff --git a/esm/utils/structure/affine3d.py b/esm/utils/structure/affine3d.py index 7402466..382abcd 100644 --- a/esm/utils/structure/affine3d.py +++ b/esm/utils/structure/affine3d.py @@ -12,15 +12,12 @@ @T.runtime_checkable class Rotation(T.Protocol): @classmethod - def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: - ... + def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ... @classmethod - def random(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: - ... + def random(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ... - def __getitem__(self, idx: T.Any) -> Self: - ... + def __getitem__(self, idx: T.Any) -> Self: ... @property def tensor(self) -> torch.Tensor: @@ -35,8 +32,7 @@ def shape(self) -> torch.Size: # This means that 1x4 quaternions are treated as size (1,) for example ... - def as_matrix(self) -> RotationMatrix: - ... + def as_matrix(self) -> RotationMatrix: ... 
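Returning to the rewritten maybe_list in esm/utils/misc.py above (the SASA tokenizer's decode_float follows the same masking pattern): the recursive list_nan_to_none helper is replaced by a vectorized NaN mask over an object-dtype array. The object dtype matters, since assigning None into a plain float NumPy array silently stores NaN. A quick sketch with illustrative values:

    import torch
    from esm.utils.misc import maybe_list

    coords = torch.tensor([1.0, float("nan"), 3.0])
    print(maybe_list(coords, convert_nan_to_none=True))  # [1.0, None, 3.0]
    print(maybe_list(coords))  # plain tolist(): [1.0, nan, 3.0]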
def compose(self, other: Self) -> Self: # To be safe, we force users to explicitly convert between rotation types. @@ -50,8 +46,7 @@ def apply(self, p: torch.Tensor) -> torch.Tensor: # rotates points by this rotation object ... - def invert(self) -> Self: - ... + def invert(self) -> Self: ... @property def dtype(self) -> torch.dtype: @@ -194,10 +189,7 @@ def random( def __getitem__(self, idx: T.Any) -> "Affine3D": indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx) - return Affine3D( - trans=self.trans[indices + (slice(None),)], - rot=self.rot[idx], - ) + return Affine3D(trans=self.trans[indices + (slice(None),)], rot=self.rot[idx]) @property def shape(self) -> torch.Size: diff --git a/esm/utils/structure/aligner.py b/esm/utils/structure/aligner.py index 5a8cad6..ec0a3ef 100644 --- a/esm/utils/structure/aligner.py +++ b/esm/utils/structure/aligner.py @@ -17,8 +17,7 @@ class Alignable(Protocol): # Trick to detect whether an object is a dataclass __dataclass_fields__: ClassVar[dict[str, Field[Any]]] - def __len__(self) -> int: - ... + def __len__(self) -> int: ... T = TypeVar("T", bound=Alignable) diff --git a/esm/utils/structure/metrics.py b/esm/utils/structure/metrics.py index e134cbf..dfe26eb 100644 --- a/esm/utils/structure/metrics.py +++ b/esm/utils/structure/metrics.py @@ -139,18 +139,13 @@ def compute_gdt_ts( """ if atom_exists_mask is None: atom_exists_mask = torch.isfinite(target).all(dim=-1) - ( - centered_mobile, - _, - centered_target, - _, - rotation_matrix, - _, - ) = compute_alignment_tensors( - mobile=mobile, - target=target, - atom_exists_mask=atom_exists_mask, - sequence_id=sequence_id, + (centered_mobile, _, centered_target, _, rotation_matrix, _) = ( + compute_alignment_tensors( + mobile=mobile, + target=target, + atom_exists_mask=atom_exists_mask, + sequence_id=sequence_id, + ) ) # Apply transformation to centered structure diff --git a/esm/utils/structure/normalize_coordinates.py b/esm/utils/structure/normalize_coordinates.py index 6b8efd6..d26f2e0 100644 --- a/esm/utils/structure/normalize_coordinates.py +++ b/esm/utils/structure/normalize_coordinates.py @@ -43,10 +43,7 @@ def get_protein_normalization_frame(coords: Tensor) -> Affine3D: Affine3D: tensor of Affine3D frame """ bb_coords = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2) - coord_mask = torch.all( - torch.all(torch.isfinite(bb_coords), dim=-1), - dim=-1, - ) + coord_mask = torch.all(torch.all(torch.isfinite(bb_coords), dim=-1), dim=-1) average_position_per_n_ca_c = bb_coords.masked_fill( ~coord_mask[..., None, None], 0 diff --git a/esm/utils/structure/predicted_aligned_error.py b/esm/utils/structure/predicted_aligned_error.py index 2b999c1..1071baf 100644 --- a/esm/utils/structure/predicted_aligned_error.py +++ b/esm/utils/structure/predicted_aligned_error.py @@ -49,11 +49,7 @@ def compute_predicted_aligned_error( @torch.no_grad -def compute_tm( - logits: torch.Tensor, - aa_mask: torch.Tensor, - max_bin: float = 31.0, -): +def compute_tm(logits: torch.Tensor, aa_mask: torch.Tensor, max_bin: float = 31.0): square_mask = _compute_pae_masks(aa_mask) seqlens = aa_mask.sum(-1, keepdim=True) bins = _pae_bins(max_bin, logits.shape[-1], logits.device) diff --git a/esm/utils/structure/protein_chain.py b/esm/utils/structure/protein_chain.py index 322e335..7155a85 100644 --- a/esm/utils/structure/protein_chain.py +++ b/esm/utils/structure/protein_chain.py @@ -229,8 +229,7 @@ def to_npz_string(self): return buf.getvalue() def to_structure_encoder_inputs( - self, - should_normalize_coordinates: 
bool = True, + self, should_normalize_coordinates: bool = True ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: coords = torch.tensor(self.atom37_positions, dtype=torch.float32) plddt = torch.tensor(self.confidence, dtype=torch.float32) @@ -494,9 +493,7 @@ def from_atom37( @classmethod def from_backbone_atom_coordinates( - cls, - backbone_atom_coordinates: np.ndarray | torch.Tensor, - **kwargs, + cls, backbone_atom_coordinates: np.ndarray | torch.Tensor, **kwargs ): """Create a ProteinChain from a set of backbone atom coordinates. @@ -529,10 +526,7 @@ def from_backbone_atom_coordinates( ) atom37_positions[:, :3, :] = backbone_atom_coordinates - return cls.from_atom37( - atom37_positions=atom37_positions, - **kwargs, - ) + return cls.from_atom37(atom37_positions=atom37_positions, **kwargs) @classmethod def from_pdb( @@ -586,22 +580,13 @@ def from_pdb( num_res = len(sequence) atom_positions = np.full( - [num_res, RC.atom_type_num, 3], - np.nan, - dtype=np.float32, - ) - atom_mask = np.full( - [num_res, RC.atom_type_num], - False, - dtype=bool, + [num_res, RC.atom_type_num, 3], np.nan, dtype=np.float32 ) + atom_mask = np.full([num_res, RC.atom_type_num], False, dtype=bool) residue_index = np.full([num_res], -1, dtype=np.int64) insertion_code = np.full([num_res], "", dtype=" "ProteinChain": """A simple converter from bs.AtomArray -> ProteinChain. Uses PDB file format as intermediate.""" diff --git a/esm/utils/structure/protein_complex.py b/esm/utils/structure/protein_complex.py index 58b9b4e..aa27017 100644 --- a/esm/utils/structure/protein_complex.py +++ b/esm/utils/structure/protein_complex.py @@ -254,10 +254,7 @@ def from_pdb(cls, path: PathOrBuffer, id: str | None = None) -> "ProteinComplex" return ProteinComplex.from_chains(chains) @classmethod - def from_rcsb( - cls, - pdb_id: str, - ): + def from_rcsb(cls, pdb_id: str): """Fetch a protein complex from the RCSB PDB database.""" f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore return cls.from_pdb(f, id=pdb_id) @@ -345,10 +342,7 @@ def from_blob(cls, input: Path | str | io.BytesIO | bytes): ) @classmethod - def from_chains( - cls, - chains: Sequence[ProteinChain], - ): + def from_chains(cls, chains: Sequence[ProteinChain]): if not chains: raise ValueError( "Cannot create a ProteinComplex from an empty list of chains" diff --git a/esm/utils/structure/protein_structure.py b/esm/utils/structure/protein_structure.py index 62aa66a..7f213ea 100644 --- a/esm/utils/structure/protein_structure.py +++ b/esm/utils/structure/protein_structure.py @@ -254,10 +254,7 @@ def compute_affine_and_rmsd( # Apply transformation to centered structure to compute rmsd rotated_mobile = torch.matmul(centered_mobile, rotation_matrix) avg_rmsd = compute_rmsd_no_alignment( - rotated_mobile, - centered_target, - num_valid_atoms, - reduction="batch", + rotated_mobile, centered_target, num_valid_atoms, reduction="batch" ) return affine, avg_rmsd diff --git a/esm/widgets/components/function_annotator.py b/esm/widgets/components/function_annotator.py index 8e0c3f9..714238f 100644 --- a/esm/widgets/components/function_annotator.py +++ b/esm/widgets/components/function_annotator.py @@ -116,19 +116,10 @@ def on_delete_click(b): ) delete_button = widgets.Button( - description="Delete", - tooltip="Delete this annotation", - icon="trash", - ) - entry = widgets.HBox( - [ - delete_button, - widgets.Label(value=function_str), - ] - ) - delete_button.on_click( - on_delete_click, + description="Delete", tooltip="Delete this annotation", icon="trash" ) + entry = 
widgets.HBox([delete_button, widgets.Label(value=function_str)]) + delete_button.on_click(on_delete_click) entries.children += (entry,) except Exception as e: diff --git a/esm/widgets/components/results_visualizer.py b/esm/widgets/components/results_visualizer.py index eb3f21d..99692e5 100644 --- a/esm/widgets/components/results_visualizer.py +++ b/esm/widgets/components/results_visualizer.py @@ -150,8 +150,7 @@ def create_sequence_results_page( sequence_items = [] for item in items: copy_to_prompt_button = widgets.Button( - description="Copy to Prompt", - disabled=copy_to_prompt_callback is None, + description="Copy to Prompt", disabled=copy_to_prompt_callback is None ) if copy_to_prompt_callback: copy_to_prompt_button.on_click( @@ -190,8 +189,7 @@ def create_sasa_results_page( sasa_items = [] for item in items: copy_to_prompt_button = widgets.Button( - description="Copy to Prompt", - disabled=copy_to_prompt_callback is None, + description="Copy to Prompt", disabled=copy_to_prompt_callback is None ) if copy_to_prompt_callback: copy_to_prompt_button.on_click(lambda b: copy_to_prompt_callback(item.sasa)) @@ -201,11 +199,7 @@ def create_sasa_results_page( print("Solvent Accessible Surface Area (SASA) is not available.") else: sasa = [s or 0 for s in item.sasa] - draw_data_array( - output, - data_array=sasa, - cmap="Reds", - ) + draw_data_array(output, data_array=sasa, cmap="Reds") if copy_to_prompt_callback: sasa_items.append( @@ -227,8 +221,7 @@ def create_secondary_structure_results_page( ss_items = [] for item in items: copy_to_prompt_button = widgets.Button( - description="Copy to Prompt", - disabled=copy_to_prompt_callback is None, + description="Copy to Prompt", disabled=copy_to_prompt_callback is None ) if copy_to_prompt_callback: copy_to_prompt_button.on_click( @@ -292,8 +285,7 @@ def create_structure_results_page( else: ptm_label = widgets.Label(value=f"pTM: {item.ptm.item():.2f}") copy_to_prompt_button = widgets.Button( - description="Copy to Prompt", - disabled=copy_to_prompt_callback is None, + description="Copy to Prompt", disabled=copy_to_prompt_callback is None ) if copy_to_prompt_callback: copy_to_prompt_button.on_click( @@ -351,8 +343,7 @@ def confidence_to_color(confidence) -> str: header = widgets.HBox([download_pdb_button, ptm_label]) grid[row, col] = widgets.VBox( - [header, output], - layout={"border": "1px solid gray"}, + [header, output], layout={"border": "1px solid gray"} ) return grid @@ -364,8 +355,7 @@ def create_function_annotations_results_page( function_items = [] for item in items: copy_to_prompt_button = widgets.Button( - description="Copy to Prompt", - disabled=copy_to_prompt_callback is None, + description="Copy to Prompt", disabled=copy_to_prompt_callback is None ) if copy_to_prompt_callback: copy_to_prompt_button.on_click( @@ -387,8 +377,7 @@ def create_function_annotations_results_page( ) else: image = draw_function_annotations( - interpro_annotations, - sequence_length=len(item), + interpro_annotations, sequence_length=len(item) ) if copy_to_prompt_callback: content = widgets.VBox( diff --git a/esm/widgets/components/secondary_structure_prompt_selector.py b/esm/widgets/components/secondary_structure_prompt_selector.py index 129f84a..020180f 100644 --- a/esm/widgets/components/secondary_structure_prompt_selector.py +++ b/esm/widgets/components/secondary_structure_prompt_selector.py @@ -155,11 +155,7 @@ def get_secondary_structure(protein_chain: ProteinChain) -> Sequence[int]: def get_ss3_categories(): - return [ - "Coil (C)", - "Alpha helix (H)", - 
"Beta strand (E)", - ] + return ["Coil (C)", "Alpha helix (H)", "Beta strand (E)"] def ss3_plot_index_to_letter(ss3_index: int) -> str: diff --git a/esm/widgets/components/sequence_prompt_selector.py b/esm/widgets/components/sequence_prompt_selector.py index a55b7e0..d538e67 100644 --- a/esm/widgets/components/sequence_prompt_selector.py +++ b/esm/widgets/components/sequence_prompt_selector.py @@ -124,9 +124,9 @@ def apply_highlighting(sequence, ranges): r, g, b, a = hex_to_rgba_tuple(combined_color) a = 0.5 # Set alpha to 0.5 combined_color = rgba_tuple_to_rgba_html_string((r, g, b, a)) - highlighted_line[ - i - ] = f'{highlighted_line[i]}' + highlighted_line[i] = ( + f'{highlighted_line[i]}' + ) highlighted_lines.append("".join(highlighted_line)) return "
".join(highlighted_lines) diff --git a/esm/widgets/components/structure_prompt_selector.py b/esm/widgets/components/structure_prompt_selector.py index 482cf70..f4b497c 100644 --- a/esm/widgets/components/structure_prompt_selector.py +++ b/esm/widgets/components/structure_prompt_selector.py @@ -112,10 +112,9 @@ def display_matrix_with_highlight(x_range, y_range): ).items(): selected_ranges = tuple(selected_ranges) # Convert to hashable if selected_ranges in contact_map_selection_cache: - ( - (x_start, x_end), - (y_start, y_end), - ) = contact_map_selection_cache[selected_ranges] + ((x_start, x_end), (y_start, y_end)) = contact_map_selection_cache[ + selected_ranges + ] rect = Rectangle( (x_start - 0.5, max_y - y_end - 1.5), x_end - x_start + 1, diff --git a/esm/widgets/utils/clients.py b/esm/widgets/utils/clients.py index 6c32d02..3f49e82 100644 --- a/esm/widgets/utils/clients.py +++ b/esm/widgets/utils/clients.py @@ -24,7 +24,5 @@ def get_forge_client(model_name: str) -> ESM3InferenceClient: "Forge API key not found. Please set the ESM_API_KEY environment variable." ) return ESM3ForgeInferenceClient( - model=model_name, - url="https://forge.evolutionaryscale.ai", - token=forge_token, + model=model_name, url="https://forge.evolutionaryscale.ai", token=forge_token ) diff --git a/esm/widgets/utils/drawing/draw_category_array.py b/esm/widgets/utils/drawing/draw_category_array.py index 69d2c9f..b5a7f41 100644 --- a/esm/widgets/utils/drawing/draw_category_array.py +++ b/esm/widgets/utils/drawing/draw_category_array.py @@ -110,8 +110,7 @@ def generate_color_palette(categories, category_color_mapping): else: legend_patches = [ patches.Patch( - color=rgb_colors[category_to_index[category]], - label=category, + color=rgb_colors[category_to_index[category]], label=category ) for category in categories ] diff --git a/esm/widgets/utils/drawing/draw_function_annotations.py b/esm/widgets/utils/drawing/draw_function_annotations.py index dbc2125..c71e543 100644 --- a/esm/widgets/utils/drawing/draw_function_annotations.py +++ b/esm/widgets/utils/drawing/draw_function_annotations.py @@ -26,9 +26,7 @@ def use_backend(backend): def draw_function_annotations( - annotations: list[FunctionAnnotation], - sequence_length: int, - interpro_=InterPro(), + annotations: list[FunctionAnnotation], sequence_length: int, interpro_=InterPro() ) -> widgets.Image: cmap = colormaps["tab10"] colors = [cmap(i) for i in range(len(InterProEntryType))] @@ -63,9 +61,7 @@ def draw_function_annotations( with use_backend("agg"): fig, ax = plt.subplots() record = GraphicRecord( - sequence=None, - sequence_length=sequence_length, - features=features, + sequence=None, sequence_length=sequence_length, features=features ) record.plot(ax=ax, plot_sequence=False) fig.savefig(buf, format="png", dpi=200, bbox_inches="tight") diff --git a/esm/widgets/utils/drawing/draw_protein_structure.py b/esm/widgets/utils/drawing/draw_protein_structure.py index 2f0da3d..b86bb0c 100644 --- a/esm/widgets/utils/drawing/draw_protein_structure.py +++ b/esm/widgets/utils/drawing/draw_protein_structure.py @@ -19,8 +19,7 @@ def draw_protein_structure( for start, end, color in highlighted_ranges: view.setStyle( - {"resi": str(start) + "-" + str(end)}, - {"cartoon": {"color": color}}, + {"resi": str(start) + "-" + str(end)}, {"cartoon": {"color": color}} ) view.zoomTo() diff --git a/esm/widgets/utils/indexing.py b/esm/widgets/utils/indexing.py index 892c724..2034bb5 100644 --- a/esm/widgets/utils/indexing.py +++ b/esm/widgets/utils/indexing.py @@ -8,18 +8,13 @@ 
PDB_INDEX_SUFFIX = "[PDB Index]" -def get_pdb_index_min_max( - protein_chain: ProteinChain, -) -> tuple[int, int]: +def get_pdb_index_min_max(protein_chain: ProteinChain) -> tuple[int, int]: residue_index = protein_chain.residue_index valid_residue_index = residue_index[residue_index != -1] return min(valid_residue_index), max(valid_residue_index) -def pdb_index_to_zero_index( - residue_index: int, - protein_chain: ProteinChain, -) -> int: +def pdb_index_to_zero_index(residue_index: int, protein_chain: ProteinChain) -> int: # Find the first position equal to residue_index pos = np.argwhere(residue_index == protein_chain.residue_index) if len(pos) == 0: @@ -27,16 +22,12 @@ def pdb_index_to_zero_index( return pos[0][0] -def zero_index_to_pdb_index( - zero_index: int, - protein_chain: ProteinChain, -) -> int: +def zero_index_to_pdb_index(zero_index: int, protein_chain: ProteinChain) -> int: return protein_chain.residue_index[zero_index] def zero_range_to_pdb_range( - zero_range: tuple[int, int], - protein_chain: ProteinChain, + zero_range: tuple[int, int], protein_chain: ProteinChain ) -> tuple[int, int]: return ( zero_index_to_pdb_index(zero_range[0], protein_chain), @@ -45,8 +36,7 @@ def zero_range_to_pdb_range( def pdb_range_to_zero_range( - pdb_range: tuple[int, int], - protein_chain: ProteinChain, + pdb_range: tuple[int, int], protein_chain: ProteinChain ) -> tuple[int, int]: return ( pdb_index_to_zero_index(pdb_range[0], protein_chain), diff --git a/esm/widgets/utils/prompting.py b/esm/widgets/utils/prompting.py index 9b56134..1ce6d9c 100644 --- a/esm/widgets/utils/prompting.py +++ b/esm/widgets/utils/prompting.py @@ -196,9 +196,7 @@ def __init__( def redraw(self, change=None): categories = ["Mask (-)"] - color_map = { - "Mask (-)": "white", - } + color_map = {"Mask (-)": "white"} data_array = [0] * self.prompt_length for prompt_str, *_ in self.prompts.items(): color, _, _ = self.prompts[prompt_str] @@ -282,7 +280,7 @@ def add_entry_to_ui(self, range_string): value=( f'
' f"{range_string}" - ), + ) ) entry_label.tag = range_string # type: ignore entry_container = widgets.HBox([entry_button, entry_label]) diff --git a/esm/widgets/utils/protein_import.py b/esm/widgets/utils/protein_import.py index 9fdb4b2..c9d9dbb 100644 --- a/esm/widgets/utils/protein_import.py +++ b/esm/widgets/utils/protein_import.py @@ -9,11 +9,7 @@ class ProteinImporter: - def __init__( - self, - max_proteins: int | None = None, - autoload: bool = False, - ) -> None: + def __init__(self, max_proteins: int | None = None, autoload: bool = False) -> None: self._protein_list: list[tuple[str, ProteinChain]] = [] self._protein_workspace: dict[str, str] = {} self.max_proteins = max_proteins diff --git a/esm/widgets/utils/serialization.py b/esm/widgets/utils/serialization.py index 43799af..1123174 100644 --- a/esm/widgets/utils/serialization.py +++ b/esm/widgets/utils/serialization.py @@ -55,9 +55,7 @@ def create_download_results_button( ) -def serialize_protein( - protein: ESMProtein, -) -> str: +def serialize_protein(protein: ESMProtein) -> str: protein_dict = { "sequence": protein.sequence, "coordinates": protein.coordinates.tolist() diff --git a/esm/widgets/views/esm3_generation_launcher.py b/esm/widgets/views/esm3_generation_launcher.py index 6328446..e94c60f 100644 --- a/esm/widgets/views/esm3_generation_launcher.py +++ b/esm/widgets/views/esm3_generation_launcher.py @@ -125,9 +125,7 @@ def create_esm3_generation_launcher( ] ) - generation_config_ui = widgets.VBox( - [generation_config_settings_ui], - ) + generation_config_ui = widgets.VBox([generation_config_settings_ui]) def on_track_change(change): if change["new"] == "function": diff --git a/esm/widgets/views/esm3_prompt_selector.py b/esm/widgets/views/esm3_prompt_selector.py index 275b608..035db28 100644 --- a/esm/widgets/views/esm3_prompt_selector.py +++ b/esm/widgets/views/esm3_prompt_selector.py @@ -12,9 +12,7 @@ from esm.widgets.components.structure_prompt_selector import ( create_structure_prompt_selector, ) -from esm.widgets.utils.prompting import ( - PromptManagerCollection, -) +from esm.widgets.utils.prompting import PromptManagerCollection from esm.widgets.utils.protein_import import ProteinImporter diff --git a/esm/widgets/views/generation.py b/esm/widgets/views/generation.py index e5d5749..19015f6 100644 --- a/esm/widgets/views/generation.py +++ b/esm/widgets/views/generation.py @@ -2,20 +2,13 @@ from ipywidgets import widgets -from esm.sdk.api import ( - ESM3InferenceClient, - ESMProtein, -) +from esm.sdk.api import ESM3InferenceClient, ESMProtein from esm.utils.constants import esm3 as C from esm.widgets.components.function_annotator import ( create_function_annotator, ) -from esm.widgets.utils.prompting import ( - PromptManagerCollection, -) -from esm.widgets.utils.protein_import import ( - ProteinImporter, -) +from esm.widgets.utils.prompting import PromptManagerCollection +from esm.widgets.utils.protein_import import ProteinImporter from esm.widgets.views.esm3_generation_launcher import ( create_esm3_generation_launcher, ) @@ -47,14 +40,9 @@ def create_generation_ui( protein_length_ui = widgets.VBox( [ widgets.HTML(value="

<h3>Specify Prompt Length:</h3>

"), - widgets.HBox( - [ - protein_length_input, - protein_length_confirm_button, - ] - ), + widgets.HBox([protein_length_input, protein_length_confirm_button]), output, - ], + ] ) loading_ui = widgets.HTML(value="

<h3>Loading...</h3>

") @@ -117,10 +105,7 @@ def update_selector(*args, **kwargs): add_annotation_callback=prompt_manager_collection.add_function_annotation, delete_annotation_callback=prompt_manager_collection.delete_function_annotation, ) - function_annotator_ui.children = [ - function_annotator_title, - function_annotator, - ] + function_annotator_ui.children = [function_annotator_title, function_annotator] if len(protein_importer.protein_list) == 0: prompt_ui.children = [ @@ -139,10 +124,7 @@ def update_selector(*args, **kwargs): esm3_selector_ui = create_esm3_prompt_selector( prompt_manager_collection, protein_importer=protein_importer ) - selector_ui.children = [ - selector_title, - esm3_selector_ui, - ] + selector_ui.children = [selector_title, esm3_selector_ui] prompt_ui.children = [ protein_importer_ui, protein_length_ui, @@ -184,10 +166,7 @@ def copy_to_prompt_callback( copy_to_prompt_callback=copy_to_prompt_callback, ) generation_launcher_ui = widgets.VBox( - [ - widgets.HTML(value="

<h3>Generation Config:</h3>

"), - generation_launcher, - ] + [widgets.HTML(value="

<h3>Generation Config:</h3>

"), generation_launcher] ) if len(protein_importer.protein_list) > 0: diff --git a/esm/widgets/views/inverse_folding.py b/esm/widgets/views/inverse_folding.py index 903ceed..8becb8e 100644 --- a/esm/widgets/views/inverse_folding.py +++ b/esm/widgets/views/inverse_folding.py @@ -10,9 +10,7 @@ create_results_visualizer, ) from esm.widgets.utils.printing import wrapped_print -from esm.widgets.utils.protein_import import ( - ProteinImporter, -) +from esm.widgets.utils.protein_import import ProteinImporter def create_inverse_folding_ui(client: ESM3InferenceClient) -> widgets.Widget: diff --git a/esm/widgets/views/login.py b/esm/widgets/views/login.py index c144006..2d8be5a 100644 --- a/esm/widgets/views/login.py +++ b/esm/widgets/views/login.py @@ -133,10 +133,7 @@ def on_selection_change(change): start_msg_output, ] elif change["new"] == "Local": - model_selection_ui.children = [ - model_selection_header, - local_model, - ] + model_selection_ui.children = [model_selection_header, local_model] login_ui.children = [ infobox, selection_ui, diff --git a/esm/widgets/views/prediction.py b/esm/widgets/views/prediction.py index 55f2fa6..de6666d 100644 --- a/esm/widgets/views/prediction.py +++ b/esm/widgets/views/prediction.py @@ -10,9 +10,7 @@ create_results_visualizer, ) from esm.widgets.utils.printing import wrapped_print -from esm.widgets.utils.protein_import import ( - ProteinImporter, -) +from esm.widgets.utils.protein_import import ProteinImporter def create_prediction_ui(client: ESM3InferenceClient) -> widgets.Widget: @@ -85,11 +83,7 @@ def on_click_predict(_): try: # Reset the output and results output.clear_output() - prediction_ui.children = [ - input_ui, - predict_button, - output, - ] + prediction_ui.children = [input_ui, predict_button, output] # Predict the protein's properties with output: protein = get_protein() @@ -159,19 +153,10 @@ def on_click_predict(_): wrapped_print(e) predict_button.on_click(on_click_predict) - protein_importer.entries_box.observe( - on_new_protein, - names="children", - ) + protein_importer.entries_box.observe(on_new_protein, names="children") protein_importer.register_delete_callback(lambda: validate_predict(None)) - sequence_input_ui.children[1].observe( - on_new_sequence, - names="value", - ) - input_ui.observe( - validate_predict, - names="selected_index", - ) + sequence_input_ui.children[1].observe(on_new_sequence, names="value") + input_ui.observe(validate_predict, names="selected_index") return prediction_ui diff --git a/examples/folding_inverse_folding_example.py b/examples/folding_inverse_folding_example.py index 608dd9a..902c576 100644 --- a/examples/folding_inverse_folding_example.py +++ b/examples/folding_inverse_folding_example.py @@ -7,8 +7,9 @@ ESM3InferenceClient, ESMProtein, GenerationConfig, + InverseFoldingConfig, ) -from esm.sdk.forge import FoldForgeInferenceClient +from esm.sdk.forge import SequenceStructureForgeInferenceClient def convert_none_to_nan(data): @@ -21,46 +22,53 @@ def convert_none_to_nan(data): return data -def are_allclose_with_nan(A, B, rtol=1e-5, atol=1e-2): - B = convert_none_to_nan(B) - - A = np.array(A) - B = np.array(B) - - if A.shape != B.shape: - raise ValueError("A and B must have the same shape") - - nan_mask_A = np.isnan(A) - nan_mask_B = np.isnan(B) - - if not np.array_equal(nan_mask_A, nan_mask_B): - return False - - return np.allclose(A[~nan_mask_A], B[~nan_mask_B], rtol=rtol, atol=atol) - - -def main(fold_client: FoldForgeInferenceClient, esm3_client: ESM3InferenceClient): - # Folding +def main( + 
sequence_structure_client: SequenceStructureForgeInferenceClient,
+    esm3_client: ESM3InferenceClient,
+):
+    # Folding with esm3 client
     protein = get_sample_protein()
-    sequence_length = len(protein.sequence)  # type: ignore
-    num_steps = int(sequence_length / 16)
     protein.coordinates = None
     protein.function_annotations = None
     protein.sasa = None
+    assert protein.sequence is not None, "Protein sequence must be set to fold"
     # Folding with esm3 client
-    folded_protein = cast(
-        ESMProtein,
-        esm3_client.generate(
-            protein,
-            GenerationConfig(
-                track="structure", schedule="cosine", num_steps=num_steps, temperature=0
-            ),
-        ),
-    )
+    config = GenerationConfig(track="structure", num_steps=1, temperature=0)
+    esm3_client_folded_protein = esm3_client.generate(protein, config)
+    assert isinstance(
+        esm3_client_folded_protein, ESMProtein
+    ), f"Using ESM3 client, ESMProtein was expected but got {esm3_client_folded_protein}"
     # Folding with folding client
-    coordinates = fold_client.fold(
-        "esm3",
-        protein.sequence,  # type:ignore
-        potential_sequence_of_concern=False,
+    sequence_structure_client_folded_protein = sequence_structure_client.fold(
+        protein.sequence, potential_sequence_of_concern=False
+    )
+    assert isinstance(
+        sequence_structure_client_folded_protein, ESMProtein
+    ), f"Using sequence_structure client, ESMProtein was expected but got {sequence_structure_client_folded_protein}"
+
+    # Inverse Folding with esm3 client
+    protein = get_sample_protein()
+    protein.sequence = None
+    protein.sasa = None
+    protein.function_annotations = None
+    assert (
+        protein.coordinates is not None
+    ), "Protein coordinates must be set to inverse fold"
+    config = GenerationConfig("sequence", num_steps=1, temperature=0.7)
+    esm3_client_inv_folded_protein = cast(
+        ESMProtein, esm3_client.generate(protein, config)
+    )
+    assert isinstance(
+        esm3_client_inv_folded_protein, ESMProtein
+    ), f"Using ESM3 client, ESMProtein was expected but got {esm3_client_inv_folded_protein}"
+    # Inverse Folding with inverse folding client
+    sequence_structure_client_inv_folded_protein = (
+        sequence_structure_client.inverse_fold(
+            protein.coordinates,
+            config=InverseFoldingConfig(temperature=0.7),
+            potential_sequence_of_concern=False,
+        )
     )
-    assert are_allclose_with_nan(folded_protein.coordinates, coordinates)
+    assert isinstance(
+        sequence_structure_client_inv_folded_protein, ESMProtein
+    ), f"Using sequence_structure client, ESMProtein was expected but got {sequence_structure_client_inv_folded_protein}"
diff --git a/examples/forge_generate.ipynb b/examples/forge_generate.ipynb
index 50c8d82..e386dd5 100644
@@ -1,679 +1,671 @@
 {
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# ESM3\n",
-    "\n",
-    "ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.\n",
-    "\n",
-    "ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. 
This iterative sampling is what the `.generate()` function does.\n", - "\n", - "![image.png](https://github.com/evolutionaryscale/esm/blob/main/_assets/esm3_diagram.png?raw=true)\n", - "\n", - "The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.\n", - "Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family, trained specifically to be open sourced. ESM3-open is available under a non-commercial license.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Imports\n", - "\n", - "If you're running in Colab, you probably want to get a GPU runtime first (Runtime > Change runtime type > T4 GPU).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%set_env TOKENIZERS_PARALLELISM=false\n", - "!pip install esm\n", - "import numpy as np\n", - "import torch\n", - "\n", - "!pip install py3Dmol\n", - "import py3Dmol\n", - "\n", - "from esm.utils.structure.protein_chain import ProteinChain\n", - "from esm.sdk import client\n", - "from esm.sdk.api import (\n", - " ESMProtein,\n", - " GenerationConfig,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Set up the client to Forge\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from getpass import getpass\n", - "\n", - "token = getpass(\"Token from Forge console: \")\n", - "model = client(\n", - " model=\"esm3-open\",\n", - " url=\"https://forge.evolutionaryscale.ai\",\n", - " token=token,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Let's construct a prompt for ESM3, focusing on the task of scaffolding a motif from a natural protein\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, we can use the `ProteinChain` class from the `esm` sdk to grab a protein structure from the PDB.\n", - "We'll work with a human renal (kidney) dipeptidase (a protein that breaks up two amino acids bound together). Renal dipeptidases are of particular interest because they metabolize certain antibiotics.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pdb_id = \"1ITU\" # PDB ID corresponding to Renal Dipeptidase\n", - "chain_id = \"A\" # Chain ID corresponding to Renal Dipeptidase in the PDB structure\n", - "renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n", - "# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `ProteinChain` class is a object that makes it easy to work with protein structures. 
It contains a `sequence` attribute that contains the amino acid sequence of the protein\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(renal_dipep_chain.sequence)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`ProteinChain` also contains an `atom37_positions` numpy array that contains the atomic coordinates of each of the residues in the protein.\n", - "\n", - "The shape of the array is `(n_residues, 37, 3)` where `n_residues` is the number of residues in the protein and 37 is the number of possible distinct atoms that may be present across all amino acids (e.g. the first three atoms are the N, C-alpha, and C atoms corresponding to the protein backbone). The 3 corresponds to the x, y, and z coordinates of each atom. The atom37 representation of protein structure allows us to use a single format to conveniently represent all amino acids -- **coordinates are only present for the atoms that are present in the amino acid and `nan` otherwise**.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"atom37_positions shape: \", renal_dipep_chain.atom37_positions.shape)\n", - "print(renal_dipep_chain.atom37_positions[:3])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can visualize the protein chain using the `py3Dmol` library\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# First we can create a `py3Dmol` view object\n", - "view = py3Dmol.view(width=500, height=500)\n", - "# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string\n", - "pdb_str = renal_dipep_chain.to_pdb_string()\n", - "# Load the PDB string into the `py3Dmol` view object\n", - "view.addModel(pdb_str, \"pdb\")\n", - "# Set the style of the protein chain\n", - "view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n", - "# Zoom in on the protein chain\n", - "view.zoomTo()\n", - "# Display the protein chain\n", - "view.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, let's try to scaffold a motif from this protein using ESM3 -- we'll prompt the model with the sequence and structure of a helix-coil motif from renal dipeptidase and have the model generate a larger scaffold that includes the motif\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "motif_inds = np.arange(123, 146)\n", - "# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues\n", - "motif_sequence = renal_dipep_chain[motif_inds].sequence\n", - "motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions\n", - "print(\"Motif sequence: \", motif_sequence)\n", - "print(\"Motif atom37_positions shape: \", motif_atom37_positions.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also visualize the motif in the original chain using `py3Dmol`. 
We'll color the original chain in grey and the motif in blue\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "view = py3Dmol.view(width=500, height=500)\n", - "view.addModel(pdb_str, \"pdb\")\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", - "motif_res_inds = (\n", - " motif_inds + 1\n", - ").tolist() # residue indices are 1-indexed in PDB files, so we add 1 to the indices\n", - "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}})\n", - "view.zoomTo()\n", - "view.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can use the `ESMProtein` class to construct a prompt that will instruct ESM3 to scaffold the motif\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "prompt_length = 200\n", - "# First, we can construct a sequence prompt of all masks\n", - "sequence_prompt = [\"_\"] * prompt_length\n", - "# Then, we can randomly insert the motif sequence into the prompt (we randomly choose 72 here)\n", - "sequence_prompt[72 : 72 + len(motif_sequence)] = list(motif_sequence)\n", - "sequence_prompt = \"\".join(sequence_prompt)\n", - "print(\"Sequence prompt: \", sequence_prompt)\n", - "print(\"Length of sequence prompt: \", len(sequence_prompt))\n", - "\n", - "# Next, we can construct a structure prompt of all nan coordinates\n", - "structure_prompt = torch.full((prompt_length, 37, 3), np.nan)\n", - "# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72\n", - "structure_prompt[72 : 72 + len(motif_atom37_positions)] = torch.tensor(\n", - " motif_atom37_positions\n", - ")\n", - "print(\"Structure prompt shape: \", structure_prompt.shape)\n", - "print(\n", - " \"Indices with structure conditioning: \",\n", - " torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist(),\n", - ")\n", - "\n", - "# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3\n", - "protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can use the `generate` method of the model to iteratively sample a protein sequence based on the prompt. 
Under the hood, the model performs num_steps forward passes and samples a set of tokens at each step until the chosen track being generated is fully unmasked.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use\n", - "sequence_generation_config = GenerationConfig(\n", - " track=\"sequence\", # We want ESM3 to generate tokens for the sequence track\n", - " num_steps=sequence_prompt.count(\"_\")\n", - " // 2, # We'll use num(mask tokens) // 2 steps to decode the sequence\n", - " temperature=0.5, # We'll use a temperature of 0.5 to control the randomness of the decoding process\n", - ")\n", - "\n", - "# Now, we can use the `generate` method of the model to decode the sequence\n", - "sequence_generation = model.generate(protein_prompt, sequence_generation_config)\n", - "print(\"Sequence Prompt:\\n\\t\", protein_prompt.sequence)\n", - "print(\"Generated sequence:\\n\\t\", sequence_generation.sequence)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also use the `generate` method to predict the structure of the generated sequence by iteratively sampling structure tokens.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "structure_prediction_config = GenerationConfig(\n", - " track=\"structure\", # We want ESM3 to generate tokens for the structure track\n", - " num_steps=len(sequence_generation) // 8,\n", - " temperature=0.7,\n", - ")\n", - "structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)\n", - "structure_prediction = model.generate(\n", - " structure_prediction_prompt, structure_prediction_config\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right, green) alongside the original structure (left, grey) from which the motif was drawn. 
The motif residues are colored in cyan.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Convert the generated structure to a back into a ProteinChain object\n", - "structure_prediction_chain = structure_prediction.to_protein_chain()\n", - "# Align the generated structure to the original structure using the motif residues\n", - "motif_inds_in_generation = np.arange(72, 72 + len(motif_sequence))\n", - "structure_prediction_chain.align(\n", - " renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds\n", - ")\n", - "crmsd = structure_prediction_chain.rmsd(\n", - " renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds\n", - ")\n", - "print(\n", - " \"cRMSD of the motif in the generated structure vs the original structure: \", crmsd\n", - ")\n", - "\n", - "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n", - "view.addModel(pdb_str, \"pdb\", viewer=(0, 0))\n", - "view.addModel(structure_prediction_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}}, viewer=(0, 0))\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}}, viewer=(0, 1))\n", - "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}}, viewer=(0, 0))\n", - "view.addStyle(\n", - " {\"resi\": (motif_inds_in_generation + 1).tolist()},\n", - " {\"cartoon\": {\"color\": \"cyan\"}},\n", - " viewer=(0, 1),\n", - ")\n", - "view.zoomTo()\n", - "view.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Secondary Structure Editing Example: Helix Shortening\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can try another generation task with ESM3. We'll use the secondary structure track, along with the sequence track, to shorten a helix-coil-helix region (residues 39-111) in a protein structure (colored in blue below)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "helix_shortening_chain = ProteinChain.from_rcsb(\"7XBQ\", \"A\")\n", - "view = py3Dmol.view(width=500, height=500)\n", - "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\")\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", - "helix_region = np.arange(38, 111) # zero-indexed\n", - "view.addStyle(\n", - " {\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\": \"lightblue\"}}\n", - ")\n", - "view.zoomTo()\n", - "view.show()\n", - "helix_shortening_ss8 = \"CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC\"\n", - "print(\n", - " \"Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \\n\\t\",\n", - " helix_shortening_ss8,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The helix-coil-helix region in the original protein is 73 residues long. 
We will try to shorten it to 45 residues by prompting the model with partial sequence and secondary structure\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "shortened_region_length = 45\n", - "\n", - "# We'll construct a sequence prompt that masks the (shortened) helix-coil-helix region, but leaves the flanking regions unmasked\n", - "sequence_prompt = (\n", - " helix_shortening_chain.sequence[: helix_region[0]]\n", - " + \"_\" * shortened_region_length\n", - " + helix_shortening_chain.sequence[helix_region[-1] + 1 :]\n", - ")\n", - "print(\"Sequence prompt:\\n\\t\", sequence_prompt)\n", - "\n", - "# We'll construct a secondary structure prompt that retains the secondary structure of the flanking regions, and shortens the lengths of helices in the helix-coil-helix region\n", - "ss8_prompt = (\n", - " helix_shortening_ss8[: helix_region[0]]\n", - " + (\n", - " ((shortened_region_length - 3) // 2) * \"H\"\n", - " + \"C\" * 3\n", - " + ((shortened_region_length - 3) // 2) * \"H\"\n", - " )\n", - " + helix_shortening_ss8[helix_region[-1] + 1 :]\n", - ")\n", - "print(\"SS8 prompt:\\n\\t\", ss8_prompt)\n", - "print(\n", - " \"Proposed SS8 for shortened helix-coil-helix region:\\n\\t\",\n", - " \" \" * helix_region[0] + ss8_prompt[helix_region[0] : helix_region[0] + 45],\n", - ")\n", - "\n", - "print(\"\")\n", - "print(\"Original sequence:\\n\\t\", helix_shortening_chain.sequence)\n", - "print(\"Original SS8:\\n\\t\", helix_shortening_ss8)\n", - "print(\n", - " \"Original SS8 for helix-coil-helix region:\\n\\t\",\n", - " \" \" * helix_region[0]\n", - " + helix_shortening_ss8[helix_region[0] : helix_region[-1] + 1],\n", - ")\n", - "\n", - "\n", - "# We can again use the ESMProtein class to compose the sequence and secondary structure prompts into a single prompt that can be passed to ESM3\n", - "protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can again use the `generate` method of the model to iteratively decode a protein sequence based on the prompt\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Generating protein sequence...\")\n", - "sequence_generation = model.generate(\n", - " protein_prompt,\n", - " GenerationConfig(\n", - " track=\"sequence\",\n", - " num_steps=protein_prompt.sequence.count(\"_\") // 2,\n", - " temperature=0.5,\n", - " ),\n", - ")\n", - "print(\"Folding protein...\")\n", - "structure_prediction = model.generate(\n", - " ESMProtein(sequence=sequence_generation.sequence),\n", - " GenerationConfig(\n", - " track=\"structure\", num_steps=len(protein_prompt) // 4, temperature=0\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right) alongside the original structure (left) from which the motif was drawn. 
The helix-coil-helix region in the original structure is colored in blue and the shortened region in the generated structure is colored in pink.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "predicted_chain = structure_prediction.to_protein_chain()\n", - "predicted_chain = predicted_chain.align(\n", - " helix_shortening_chain,\n", - " mobile_inds=np.arange(len(predicted_chain) - 120, len(predicted_chain)),\n", - " target_inds=np.arange(\n", - " len(helix_shortening_chain) - 120, len(helix_shortening_chain)\n", - " ),\n", - ")\n", - "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n", - "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n", - "view.addModel(predicted_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", - "view.addStyle(\n", - " {\"resi\": (helix_region + 1).tolist()},\n", - " {\"cartoon\": {\"color\": \"lightblue\"}},\n", - " viewer=(0, 0),\n", - ")\n", - "view.addStyle(\n", - " {\"resi\": (np.arange(helix_region[0], helix_region[0] + 45) + 1).tolist()},\n", - " {\"cartoon\": {\"color\": \"pink\"}},\n", - " viewer=(0, 1),\n", - ")\n", - "view.zoomTo()\n", - "view.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# SASA Editing Example: Exposing a buried helix\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's grab 1LBS from the PDB and visualize it using `py3Dmol`. 1LBS has an alternating alpha-beta sandwich fold, with a buried helix in the center, highlighted in red\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "lipase_chain = ProteinChain.from_rcsb(\"1LBS\", \"A\")\n", - "span_start = 105\n", - "span_end = 116\n", - "view = py3Dmol.view(width=500, height=500)\n", - "view.addModel(lipase_chain.to_pdb_string(), \"pdb\")\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", - "view.addStyle(\n", - " {\"resi\": (np.arange(span_start, span_end) + 1).tolist()},\n", - " {\"cartoon\": {\"color\": \"red\"}},\n", - ")\n", - "view.zoomTo()\n", - "view.show()\n", - "lipase_ss8 = \"CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can construct a multimodal prompt for ESM3 to instruct it to expose the buried helix as follows:\n", - "\n", - "1. Prompt with the **structure** of the buried helix highlighted in red -- this will prompt ESM3 to generate a protein that contains that same helix\n", - "2. 
Prompt with high **SASA** values for the residues in the buried helix -- this will prompt ESM3 to expose the helix to the surface of the protein\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "structure_prompt = torch.full((len(lipase_chain), 37, 3), torch.nan)\n", - "structure_prompt[span_start:span_end] = torch.tensor(\n", - " lipase_chain[span_start:span_end].atom37_positions, dtype=torch.float32\n", - ")\n", - "\n", - "sasa_prompt = [None] * len(lipase_chain)\n", - "sasa_prompt[span_start:span_end] = [40.0] * (span_end - span_start)\n", - "\n", - "print(\"SASA prompt (just for buried region): \", sasa_prompt[span_start:span_end])\n", - "\n", - "protein_prompt = ESMProtein(\n", - " sequence=\"_\" * len(lipase_chain), coordinates=structure_prompt, sasa=sasa_prompt\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This is a more difficult task, so you may need to sample more generations from ESM before you find a solution. We'll sample 16 here and sort by the generations with the highest predicted TM-score (pTM) by ESM3.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import concurrent.futures\n", - "\n", - "\n", - "def generate_protein_sequence_and_structure(protein_prompt, model):\n", - " sequence_generation = model.generate(\n", - " protein_prompt,\n", - " GenerationConfig(\n", - " track=\"sequence\",\n", - " num_steps=protein_prompt.sequence.count(\"_\") // 2,\n", - " temperature=0.5,\n", - " ),\n", - " )\n", - " structure_prediction = model.generate(\n", - " ESMProtein(sequence=sequence_generation.sequence),\n", - " GenerationConfig(\n", - " track=\"structure\", num_steps=len(protein_prompt) // 4, temperature=0.7\n", - " ),\n", - " )\n", - " return structure_prediction\n", - "\n", - "\n", - "N_SAMPLES = 16\n", - "with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:\n", - " futures = [\n", - " executor.submit(generate_protein_sequence_and_structure, protein_prompt, model)\n", - " for _ in range(N_SAMPLES)\n", - " ]\n", - "\n", - " generated_proteins = [future.result() for future in futures]\n", - "\n", - "\n", - "# Sort generations by ptm\n", - "generated_proteins = sorted(\n", - " generated_proteins, key=lambda x: x.ptm.item(), reverse=True\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's visualize the top 4 generations by pTM, alongside with the original protein (on the left)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "N_SAMPLES_TO_SHOW = 4\n", - "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, N_SAMPLES_TO_SHOW + 1))\n", - "view.addModel(lipase_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n", - "for i in range(N_SAMPLES_TO_SHOW):\n", - " print(\n", - " \"PTM of generated protein {}: {:.2f}\".format(\n", - " i + 1, generated_proteins[i].ptm.item()\n", - " )\n", - " )\n", - " view.addModel(\n", - " generated_proteins[i].to_protein_chain().to_pdb_string(),\n", - " \"pdb\",\n", - " viewer=(0, i + 1),\n", - " )\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", - "view.addStyle(\n", - " {\"resi\": (np.arange(span_start, span_end) + 1).tolist()},\n", - " {\"cartoon\": {\"color\": \"red\"}},\n", - ")\n", - "view.zoomTo()\n", - "view.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } 
- ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.15" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ESM3\n", + "\n", + "ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.\n", + "\n", + "ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.\n", + "\n", + "![image.png](https://github.com/evolutionaryscale/esm/blob/main/_assets/esm3_diagram.png?raw=true)\n", + "\n", + "The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.\n", + "Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family, trained specifically to be open sourced. ESM3-open is available under a non-commercial license.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Imports\n", + "\n", + "If you're running in Colab, you probably want to get a GPU runtime first (Runtime > Change runtime type > T4 GPU).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%set_env TOKENIZERS_PARALLELISM=false\n", + "!pip install esm\n", + "import numpy as np\n", + "import torch\n", + "\n", + "!pip install py3Dmol\n", + "import py3Dmol\n", + "from esm.sdk import client\n", + "from esm.sdk.api import ESMProtein, GenerationConfig\n", + "from esm.utils.structure.protein_chain import ProteinChain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Set up the client to Forge\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from getpass import getpass\n", + "\n", + "token = getpass(\"Token from Forge console: \")\n", + "model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Let's construct a prompt for ESM3, focusing on the task of scaffolding a motif from a natural protein\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we can use the `ProteinChain` class from the `esm` sdk to grab a protein structure from the PDB.\n", + "We'll work with a human renal (kidney) dipeptidase (a protein that breaks up two amino acids bound together). 
Renal dipeptidases are of particular interest because they metabolize certain antibiotics.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pdb_id = \"1ITU\" # PDB ID corresponding to Renal Dipeptidase\n", + "chain_id = \"A\" # Chain ID corresponding to Renal Dipeptidase in the PDB structure\n", + "renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n", + "# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `ProteinChain` class is a object that makes it easy to work with protein structures. It contains a `sequence` attribute that contains the amino acid sequence of the protein\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(renal_dipep_chain.sequence)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`ProteinChain` also contains an `atom37_positions` numpy array that contains the atomic coordinates of each of the residues in the protein.\n", + "\n", + "The shape of the array is `(n_residues, 37, 3)` where `n_residues` is the number of residues in the protein and 37 is the number of possible distinct atoms that may be present across all amino acids (e.g. the first three atoms are the N, C-alpha, and C atoms corresponding to the protein backbone). The 3 corresponds to the x, y, and z coordinates of each atom. The atom37 representation of protein structure allows us to use a single format to conveniently represent all amino acids -- **coordinates are only present for the atoms that are present in the amino acid and `nan` otherwise**.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"atom37_positions shape: \", renal_dipep_chain.atom37_positions.shape)\n", + "print(renal_dipep_chain.atom37_positions[:3])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can visualize the protein chain using the `py3Dmol` library\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# First we can create a `py3Dmol` view object\n", + "view = py3Dmol.view(width=500, height=500)\n", + "# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string\n", + "pdb_str = renal_dipep_chain.to_pdb_string()\n", + "# Load the PDB string into the `py3Dmol` view object\n", + "view.addModel(pdb_str, \"pdb\")\n", + "# Set the style of the protein chain\n", + "view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n", + "# Zoom in on the protein chain\n", + "view.zoomTo()\n", + "# Display the protein chain\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, let's try to scaffold a motif from this protein using ESM3 -- we'll prompt the model with the sequence and structure of a helix-coil motif from renal dipeptidase and have the model generate a larger scaffold that includes the motif\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "motif_inds = np.arange(123, 146)\n", + "# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues\n", + "motif_sequence = renal_dipep_chain[motif_inds].sequence\n", + 
"motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions\n", + "print(\"Motif sequence: \", motif_sequence)\n", + "print(\"Motif atom37_positions shape: \", motif_atom37_positions.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also visualize the motif in the original chain using `py3Dmol`. We'll color the original chain in grey and the motif in blue\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "view = py3Dmol.view(width=500, height=500)\n", + "view.addModel(pdb_str, \"pdb\")\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "motif_res_inds = (\n", + " motif_inds + 1\n", + ").tolist() # residue indices are 1-indexed in PDB files, so we add 1 to the indices\n", + "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}})\n", + "view.zoomTo()\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can use the `ESMProtein` class to construct a prompt that will instruct ESM3 to scaffold the motif\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompt_length = 200\n", + "# First, we can construct a sequence prompt of all masks\n", + "sequence_prompt = [\"_\"] * prompt_length\n", + "# Then, we can randomly insert the motif sequence into the prompt (we randomly choose 72 here)\n", + "sequence_prompt[72 : 72 + len(motif_sequence)] = list(motif_sequence)\n", + "sequence_prompt = \"\".join(sequence_prompt)\n", + "print(\"Sequence prompt: \", sequence_prompt)\n", + "print(\"Length of sequence prompt: \", len(sequence_prompt))\n", + "\n", + "# Next, we can construct a structure prompt of all nan coordinates\n", + "structure_prompt = torch.full((prompt_length, 37, 3), np.nan)\n", + "# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72\n", + "structure_prompt[72 : 72 + len(motif_atom37_positions)] = torch.tensor(\n", + " motif_atom37_positions\n", + ")\n", + "print(\"Structure prompt shape: \", structure_prompt.shape)\n", + "print(\n", + " \"Indices with structure conditioning: \",\n", + " torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist(),\n", + ")\n", + "\n", + "# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3\n", + "protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can use the `generate` method of the model to iteratively sample a protein sequence based on the prompt. 
Under the hood, the model performs num_steps forward passes and samples a set of tokens at each step until the chosen track being generated is fully unmasked.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use\n", + "sequence_generation_config = GenerationConfig(\n", + " track=\"sequence\", # We want ESM3 to generate tokens for the sequence track\n", + " num_steps=sequence_prompt.count(\"_\")\n", + " // 2, # We'll use num(mask tokens) // 2 steps to decode the sequence\n", + " temperature=0.5, # We'll use a temperature of 0.5 to control the randomness of the decoding process\n", + ")\n", + "\n", + "# Now, we can use the `generate` method of the model to decode the sequence\n", + "sequence_generation = model.generate(protein_prompt, sequence_generation_config)\n", + "print(\"Sequence Prompt:\\n\\t\", protein_prompt.sequence)\n", + "print(\"Generated sequence:\\n\\t\", sequence_generation.sequence)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also use the `generate` method to predict the structure of the generated sequence by iteratively sampling structure tokens.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "structure_prediction_config = GenerationConfig(\n", + " track=\"structure\", # We want ESM3 to generate tokens for the structure track\n", + " num_steps=len(sequence_generation) // 8,\n", + " temperature=0.7,\n", + ")\n", + "structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)\n", + "structure_prediction = model.generate(\n", + " structure_prediction_prompt, structure_prediction_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right, green) alongside the original structure (left, grey) from which the motif was drawn. 
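(For reference, the cRMSD reported in the next cell is the ordinary root-mean-square deviation computed over the aligned motif coordinates. Here is a plain-numpy sketch of that quantity; this is a hypothetical helper for clarity, not the `ProteinChain.rmsd` implementation.)

import numpy as np

def coordinate_rmsd(a: np.ndarray, b: np.ndarray) -> float:
    # a, b: (n_atoms, 3) coordinate arrays that have already been superimposed
    squared_distances = ((a - b) ** 2).sum(axis=-1)
    return float(np.sqrt(squared_distances.mean()))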
The motif residues are colored in cyan.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Convert the generated structure to a back into a ProteinChain object\n", + "structure_prediction_chain = structure_prediction.to_protein_chain()\n", + "# Align the generated structure to the original structure using the motif residues\n", + "motif_inds_in_generation = np.arange(72, 72 + len(motif_sequence))\n", + "structure_prediction_chain.align(\n", + " renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds\n", + ")\n", + "crmsd = structure_prediction_chain.rmsd(\n", + " renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds\n", + ")\n", + "print(\n", + " \"cRMSD of the motif in the generated structure vs the original structure: \", crmsd\n", + ")\n", + "\n", + "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n", + "view.addModel(pdb_str, \"pdb\", viewer=(0, 0))\n", + "view.addModel(structure_prediction_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}}, viewer=(0, 0))\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}}, viewer=(0, 1))\n", + "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}}, viewer=(0, 0))\n", + "view.addStyle(\n", + " {\"resi\": (motif_inds_in_generation + 1).tolist()},\n", + " {\"cartoon\": {\"color\": \"cyan\"}},\n", + " viewer=(0, 1),\n", + ")\n", + "view.zoomTo()\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Secondary Structure Editing Example: Helix Shortening\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can try another generation task with ESM3. We'll use the secondary structure track, along with the sequence track, to shorten a helix-coil-helix region (residues 39-111) in a protein structure (colored in blue below)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "helix_shortening_chain = ProteinChain.from_rcsb(\"7XBQ\", \"A\")\n", + "view = py3Dmol.view(width=500, height=500)\n", + "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\")\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "helix_region = np.arange(38, 111) # zero-indexed\n", + "view.addStyle(\n", + " {\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\": \"lightblue\"}}\n", + ")\n", + "view.zoomTo()\n", + "view.show()\n", + "helix_shortening_ss8 = \"CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC\"\n", + "print(\n", + " \"Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \\n\\t\",\n", + " helix_shortening_ss8,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The helix-coil-helix region in the original protein is 73 residues long. 
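(A quick self-contained sanity check of that count, mirroring the `helix_region` array defined in the cell above:)

import numpy as np

helix_region = np.arange(38, 111)  # zero-indexed; residues 39-111 in 1-indexed PDB numbering
assert len(helix_region) == 73  # the helix-coil-helix span is 73 residues long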
We will try to shorten it to 45 residues by prompting the model with partial sequence and secondary structure\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "shortened_region_length = 45\n", + "\n", + "# We'll construct a sequence prompt that masks the (shortened) helix-coil-helix region, but leaves the flanking regions unmasked\n", + "sequence_prompt = (\n", + " helix_shortening_chain.sequence[: helix_region[0]]\n", + " + \"_\" * shortened_region_length\n", + " + helix_shortening_chain.sequence[helix_region[-1] + 1 :]\n", + ")\n", + "print(\"Sequence prompt:\\n\\t\", sequence_prompt)\n", + "\n", + "# We'll construct a secondary structure prompt that retains the secondary structure of the flanking regions, and shortens the lengths of helices in the helix-coil-helix region\n", + "ss8_prompt = (\n", + " helix_shortening_ss8[: helix_region[0]]\n", + " + (\n", + " ((shortened_region_length - 3) // 2) * \"H\"\n", + " + \"C\" * 3\n", + " + ((shortened_region_length - 3) // 2) * \"H\"\n", + " )\n", + " + helix_shortening_ss8[helix_region[-1] + 1 :]\n", + ")\n", + "print(\"SS8 prompt:\\n\\t\", ss8_prompt)\n", + "print(\n", + " \"Proposed SS8 for shortened helix-coil-helix region:\\n\\t\",\n", + " \" \" * helix_region[0] + ss8_prompt[helix_region[0] : helix_region[0] + 45],\n", + ")\n", + "\n", + "print(\"\")\n", + "print(\"Original sequence:\\n\\t\", helix_shortening_chain.sequence)\n", + "print(\"Original SS8:\\n\\t\", helix_shortening_ss8)\n", + "print(\n", + " \"Original SS8 for helix-coil-helix region:\\n\\t\",\n", + " \" \" * helix_region[0]\n", + " + helix_shortening_ss8[helix_region[0] : helix_region[-1] + 1],\n", + ")\n", + "\n", + "\n", + "# We can again use the ESMProtein class to compose the sequence and secondary structure prompts into a single prompt that can be passed to ESM3\n", + "protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can again use the `generate` method of the model to iteratively decode a protein sequence based on the prompt\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Generating protein sequence...\")\n", + "sequence_generation = model.generate(\n", + " protein_prompt,\n", + " GenerationConfig(\n", + " track=\"sequence\",\n", + " num_steps=protein_prompt.sequence.count(\"_\") // 2,\n", + " temperature=0.5,\n", + " ),\n", + ")\n", + "print(\"Folding protein...\")\n", + "structure_prediction = model.generate(\n", + " ESMProtein(sequence=sequence_generation.sequence),\n", + " GenerationConfig(\n", + " track=\"structure\", num_steps=len(protein_prompt) // 4, temperature=0\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right) alongside the original structure (left) from which the motif was drawn. 
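(One cheap check before visualizing, assuming `helix_region`, `sequence_prompt`, and `sequence_generation` from the cells above are in scope: prompted positions should come back unchanged and every mask should be filled.)

n_prefix = helix_region[0]  # length of the untouched N-terminal flank
assert sequence_generation.sequence is not None
assert sequence_generation.sequence[:n_prefix] == sequence_prompt[:n_prefix]
assert "_" not in sequence_generation.sequence  # fully unmasked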
The helix-coil-helix region in the original structure is colored in blue and the shortened region in the generated structure is colored in pink.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predicted_chain = structure_prediction.to_protein_chain()\n", + "predicted_chain = predicted_chain.align(\n", + " helix_shortening_chain,\n", + " mobile_inds=np.arange(len(predicted_chain) - 120, len(predicted_chain)),\n", + " target_inds=np.arange(\n", + " len(helix_shortening_chain) - 120, len(helix_shortening_chain)\n", + " ),\n", + ")\n", + "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n", + "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n", + "view.addModel(predicted_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "view.addStyle(\n", + " {\"resi\": (helix_region + 1).tolist()},\n", + " {\"cartoon\": {\"color\": \"lightblue\"}},\n", + " viewer=(0, 0),\n", + ")\n", + "view.addStyle(\n", + " {\"resi\": (np.arange(helix_region[0], helix_region[0] + 45) + 1).tolist()},\n", + " {\"cartoon\": {\"color\": \"pink\"}},\n", + " viewer=(0, 1),\n", + ")\n", + "view.zoomTo()\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SASA Editing Example: Exposing a buried helix\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's grab 1LBS from the PDB and visualize it using `py3Dmol`. 1LBS has an alternating alpha-beta sandwich fold, with a buried helix in the center, highlighted in red\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lipase_chain = ProteinChain.from_rcsb(\"1LBS\", \"A\")\n", + "span_start = 105\n", + "span_end = 116\n", + "view = py3Dmol.view(width=500, height=500)\n", + "view.addModel(lipase_chain.to_pdb_string(), \"pdb\")\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "view.addStyle(\n", + " {\"resi\": (np.arange(span_start, span_end) + 1).tolist()},\n", + " {\"cartoon\": {\"color\": \"red\"}},\n", + ")\n", + "view.zoomTo()\n", + "view.show()\n", + "lipase_ss8 = \"CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can construct a multimodal prompt for ESM3 to instruct it to expose the buried helix as follows:\n", + "\n", + "1. Prompt with the **structure** of the buried helix highlighted in red -- this will prompt ESM3 to generate a protein that contains that same helix\n", + "2. 
Prompt with high **SASA** values for the residues in the buried helix -- this will prompt ESM3 to expose the helix to the surface of the protein\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "structure_prompt = torch.full((len(lipase_chain), 37, 3), torch.nan)\n",
+ "structure_prompt[span_start:span_end] = torch.tensor(\n",
+ "    lipase_chain[span_start:span_end].atom37_positions, dtype=torch.float32\n",
+ ")\n",
+ "\n",
+ "sasa_prompt = [None] * len(lipase_chain)\n",
+ "sasa_prompt[span_start:span_end] = [40.0] * (span_end - span_start)\n",
+ "\n",
+ "print(\"SASA prompt (just for buried region): \", sasa_prompt[span_start:span_end])\n",
+ "\n",
+ "protein_prompt = ESMProtein(\n",
+ "    sequence=\"_\" * len(lipase_chain), coordinates=structure_prompt, sasa=sasa_prompt\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This is a more difficult task, so you may need to sample more generations from ESM3 before you find a solution. We'll sample 16 here and rank the generations by ESM3's predicted TM-score (pTM), from highest to lowest.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import concurrent.futures\n",
+ "\n",
+ "\n",
+ "def generate_protein_sequence_and_structure(protein_prompt, model):\n",
+ "    sequence_generation = model.generate(\n",
+ "        protein_prompt,\n",
+ "        GenerationConfig(\n",
+ "            track=\"sequence\",\n",
+ "            num_steps=protein_prompt.sequence.count(\"_\") // 2,\n",
+ "            temperature=0.5,\n",
+ "        ),\n",
+ "    )\n",
+ "    structure_prediction = model.generate(\n",
+ "        ESMProtein(sequence=sequence_generation.sequence),\n",
+ "        GenerationConfig(\n",
+ "            track=\"structure\", num_steps=len(protein_prompt) // 4, temperature=0.7\n",
+ "        ),\n",
+ "    )\n",
+ "    return structure_prediction\n",
+ "\n",
+ "\n",
+ "N_SAMPLES = 16\n",
+ "with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:\n",
+ "    futures = [\n",
+ "        executor.submit(generate_protein_sequence_and_structure, protein_prompt, model)\n",
+ "        for _ in range(N_SAMPLES)\n",
+ "    ]\n",
+ "\n",
+ "    generated_proteins = [future.result() for future in futures]\n",
+ "\n",
+ "\n",
+ "# Sort generations by pTM, highest first\n",
+ "generated_proteins = sorted(\n",
+ "    generated_proteins, key=lambda x: x.ptm.item(), reverse=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's visualize the top 4 generations by pTM, alongside the original protein (on the left)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "N_SAMPLES_TO_SHOW = 4\n",
+ "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, N_SAMPLES_TO_SHOW + 1))\n",
+ "view.addModel(lipase_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
+ "for i in range(N_SAMPLES_TO_SHOW):\n",
+ "    print(\n",
+ "        \"PTM of generated protein {}: {:.2f}\".format(\n",
+ "            i + 1, generated_proteins[i].ptm.item()\n",
+ "        )\n",
+ "    )\n",
+ "    view.addModel(\n",
+ "        generated_proteins[i].to_protein_chain().to_pdb_string(),\n",
+ "        \"pdb\",\n",
+ "        viewer=(0, i + 1),\n",
+ "    )\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
+ "view.addStyle(\n",
+ "    {\"resi\": (np.arange(span_start, span_end) + 1).tolist()},\n",
+ "    {\"cartoon\": {\"color\": \"red\"}},\n",
+ ")\n",
+ "view.zoomTo()\n",
+ "view.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
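+ ,
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As an optional sanity check (an editorial addition, not part of the original walkthrough), we can estimate whether the prompted span really became more solvent-exposed. This is a minimal sketch: it assumes `biotite` is importable (the `esm` package depends on it) and that each PDB string holds a single chain with residues in model order. It compares the mean per-residue SASA of the helix span in 1LBS against the top-ranked design.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import io\n",
+ "\n",
+ "import biotite.structure as struc\n",
+ "import biotite.structure.io.pdb as pdb\n",
+ "\n",
+ "\n",
+ "def mean_span_sasa(pdb_str: str, start: int, end: int) -> float:\n",
+ "    # Parse the PDB string into a biotite AtomArray (first model)\n",
+ "    structure = pdb.PDBFile.read(io.StringIO(pdb_str)).get_structure(model=1)\n",
+ "    # Per-atom solvent accessible surface area; NaN for atoms biotite skips\n",
+ "    atom_sasa = struc.sasa(structure)\n",
+ "    # Aggregate the per-atom values into one SASA total per residue\n",
+ "    res_sasa = struc.apply_residue_wise(structure, atom_sasa, np.nansum)\n",
+ "    # Mean SASA over the residues of interest (assumes residues are in chain order)\n",
+ "    return float(np.nanmean(res_sasa[start:end]))\n",
+ "\n",
+ "\n",
+ "print(\n",
+ "    \"Mean helix-span SASA in 1LBS: \",\n",
+ "    mean_span_sasa(lipase_chain.to_pdb_string(), span_start, span_end),\n",
+ ")\n",
+ "print(\n",
+ "    \"Mean helix-span SASA in top design: \",\n",
+ "    mean_span_sasa(\n",
+ "        generated_proteins[0].to_protein_chain().to_pdb_string(),\n",
+ "        span_start,\n",
+ "        span_end,\n",
+ "    ),\n",
+ ")"
+ ]
+ }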
+ ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/examples/generate.ipynb b/examples/generate.ipynb index 57d08ed..c8cec4c 100644 --- a/examples/generate.ipynb +++ b/examples/generate.ipynb @@ -1,684 +1,679 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# ESM3\n", - "\n", - "ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.\n", - "\n", - "ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.\n", - "\n", - "![image.png](https://github.com/evolutionaryscale/esm/blob/main/_assets/esm3_diagram.png?raw=true)\n", - "\n", - "The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.\n", - "Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family, trained specifically to be open sourced. 
ESM3-open is available under a non-commercial license.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Imports\n", - "\n", - "If you're running in Colab, you probably want to get a GPU runtime first (Runtime > Change runtime type > T4 GPU).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%set_env TOKENIZERS_PARALLELISM=false\n", - "!pip install esm\n", - "import numpy as np\n", - "import torch\n", - "\n", - "!pip install py3Dmol\n", - "import py3Dmol\n", - "\n", - "from esm.utils.structure.protein_chain import ProteinChain\n", - "from esm.models.esm3 import ESM3\n", - "from esm.sdk import client\n", - "from esm.sdk.api import (\n", - " ESMProtein,\n", - " GenerationConfig,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load `esm-open-small` on GPU\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from esm.utils.misc import huggingfacehub_login\n", - "\n", - "huggingfacehub_login() # will prompt you to get an API key and accept the ESM3 license.\n", - "model = ESM3.from_pretrained(\"esm3_sm_open_v1\", device=torch.device(\"cuda\"))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Alternatively, you could use the Forge API running the model remotely, and use the local `client` to call the API just like you're used to with the model running locally on your GPU:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# from getpass import getpass\n", - "# token = getpass(\"Token from Forge console: \")\n", - "# model = client(\n", - "# model=\"esm3-large-2024-03\",\n", - "# url=\"https://forge.evolutionaryscale.ai\",\n", - "# token=token,\n", - "# )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Let's construct a prompt for ESM3, focusing on the task of scaffolding a motif from a natural protein\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, we can use the `ProteinChain` class from the `esm` sdk to grab a protein structure from the PDB.\n", - "We'll work with a human renal (kidney) dipeptidase (a protein that breaks up two amino acids bound together). Renal dipeptidases are of particular interest because they metabolize certain antibiotics.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pdb_id = \"1ITU\" # PDB ID corresponding to Renal Dipeptidase\n", - "chain_id = \"A\" # Chain ID corresponding to Renal Dipeptidase in the PDB structure\n", - "renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n", - "# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `ProteinChain` class is a object that makes it easy to work with protein structures. 
It contains a `sequence` attribute that contains the amino acid sequence of the protein\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(renal_dipep_chain.sequence)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`ProteinChain` also contains an `atom37_positions` numpy array that contains the atomic coordinates of each of the residues in the protein.\n", - "\n", - "The shape of the array is `(n_residues, 37, 3)` where `n_residues` is the number of residues in the protein and 37 is the number of possible distinct atoms that may be present across all amino acids (e.g. the first three atoms are the N, C-alpha, and C atoms corresponding to the protein backbone). The 3 corresponds to the x, y, and z coordinates of each atom. The atom37 representation of protein structure allows us to use a single format to conveniently represent all amino acids -- **coordinates are only present for the atoms that are present in the amino acid and `nan` otherwise**.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"atom37_positions shape: \", renal_dipep_chain.atom37_positions.shape)\n", - "print(renal_dipep_chain.atom37_positions[:3])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can visualize the protein chain using the `py3Dmol` library\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# First we can create a `py3Dmol` view object\n", - "view = py3Dmol.view(width=500, height=500)\n", - "# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string\n", - "pdb_str = renal_dipep_chain.to_pdb_string()\n", - "# Load the PDB string into the `py3Dmol` view object\n", - "view.addModel(pdb_str, \"pdb\")\n", - "# Set the style of the protein chain\n", - "view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n", - "# Zoom in on the protein chain\n", - "view.zoomTo()\n", - "# Display the protein chain\n", - "view.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, let's try to scaffold a motif from this protein using ESM3 -- we'll prompt the model with the sequence and structure of a helix-coil motif from renal dipeptidase and have the model generate a larger scaffold that includes the motif\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "motif_inds = np.arange(123, 146)\n", - "# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues\n", - "motif_sequence = renal_dipep_chain[motif_inds].sequence\n", - "motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions\n", - "print(\"Motif sequence: \", motif_sequence)\n", - "print(\"Motif atom37_positions shape: \", motif_atom37_positions.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also visualize the motif in the original chain using `py3Dmol`. 
We'll color the original chain in grey and the motif in blue\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "view = py3Dmol.view(width=500, height=500)\n", - "view.addModel(pdb_str, \"pdb\")\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", - "motif_res_inds = (\n", - " motif_inds + 1\n", - ").tolist() # residue indices are 1-indexed in PDB files, so we add 1 to the indices\n", - "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}})\n", - "view.zoomTo()\n", - "view.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can use the `ESMProtein` class to construct a prompt that will instruct ESM3 to scaffold the motif\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "prompt_length = 200\n", - "# First, we can construct a sequence prompt of all masks\n", - "sequence_prompt = [\"_\"] * prompt_length\n", - "# Then, we can randomly insert the motif sequence into the prompt (we randomly choose 72 here)\n", - "sequence_prompt[72 : 72 + len(motif_sequence)] = list(motif_sequence)\n", - "sequence_prompt = \"\".join(sequence_prompt)\n", - "print(\"Sequence prompt: \", sequence_prompt)\n", - "print(\"Length of sequence prompt: \", len(sequence_prompt))\n", - "\n", - "# Next, we can construct a structure prompt of all nan coordinates\n", - "structure_prompt = torch.full((prompt_length, 37, 3), np.nan)\n", - "# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72\n", - "structure_prompt[72 : 72 + len(motif_atom37_positions)] = torch.tensor(\n", - " motif_atom37_positions\n", - ")\n", - "print(\"Structure prompt shape: \", structure_prompt.shape)\n", - "print(\n", - " \"Indices with structure conditioning: \",\n", - " torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist(),\n", - ")\n", - "\n", - "# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3\n", - "protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can use the `generate` method of the model to iteratively sample a protein sequence based on the prompt. 
Under the hood, the model performs num_steps forward passes and samples a set of tokens at each step until the chosen track being generated is fully unmasked.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use\n", - "sequence_generation_config = GenerationConfig(\n", - " track=\"sequence\", # We want ESM3 to generate tokens for the sequence track\n", - " num_steps=sequence_prompt.count(\"_\")\n", - " // 2, # We'll use num(mask tokens) // 2 steps to decode the sequence\n", - " temperature=0.5, # We'll use a temperature of 0.5 to control the randomness of the decoding process\n", - ")\n", - "\n", - "# Now, we can use the `generate` method of the model to decode the sequence\n", - "sequence_generation = model.generate(protein_prompt, sequence_generation_config)\n", - "print(\"Sequence Prompt:\\n\\t\", protein_prompt.sequence)\n", - "print(\"Generated sequence:\\n\\t\", sequence_generation.sequence)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also use the `generate` method to predict the structure of the generated sequence by iteratively sampling structure tokens.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "structure_prediction_config = GenerationConfig(\n", - " track=\"structure\", # We want ESM3 to generate tokens for the structure track\n", - " num_steps=len(sequence_generation) // 8,\n", - " temperature=0.7,\n", - ")\n", - "structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)\n", - "structure_prediction = model.generate(\n", - " structure_prediction_prompt, structure_prediction_config\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right, green) alongside the original structure (left, grey) from which the motif was drawn. 
The motif residues are colored in cyan.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Convert the generated structure to a back into a ProteinChain object\n", - "structure_prediction_chain = structure_prediction.to_protein_chain()\n", - "# Align the generated structure to the original structure using the motif residues\n", - "motif_inds_in_generation = np.arange(72, 72 + len(motif_sequence))\n", - "structure_prediction_chain.align(\n", - " renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds\n", - ")\n", - "crmsd = structure_prediction_chain.rmsd(\n", - " renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds\n", - ")\n", - "print(\n", - " \"cRMSD of the motif in the generated structure vs the original structure: \", crmsd\n", - ")\n", - "\n", - "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n", - "view.addModel(pdb_str, \"pdb\", viewer=(0, 0))\n", - "view.addModel(structure_prediction_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}}, viewer=(0, 0))\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}}, viewer=(0, 1))\n", - "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}}, viewer=(0, 0))\n", - "view.addStyle(\n", - " {\"resi\": (motif_inds_in_generation + 1).tolist()},\n", - " {\"cartoon\": {\"color\": \"cyan\"}},\n", - " viewer=(0, 1),\n", - ")\n", - "view.zoomTo()\n", - "view.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Secondary Structure Editing Example: Helix Shortening\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can try another generation task with ESM3. We'll use the secondary structure track, along with the sequence track, to shorten a helix-coil-helix region (residues 39-111) in a protein structure (colored in blue below)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "helix_shortening_chain = ProteinChain.from_rcsb(\"7XBQ\", \"A\")\n", - "view = py3Dmol.view(width=500, height=500)\n", - "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\")\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", - "helix_region = np.arange(38, 111) # zero-indexed\n", - "view.addStyle(\n", - " {\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\": \"lightblue\"}}\n", - ")\n", - "view.zoomTo()\n", - "view.show()\n", - "helix_shortening_ss8 = \"CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC\"\n", - "print(\n", - " \"Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \\n\\t\",\n", - " helix_shortening_ss8,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The helix-coil-helix region in the original protein is 73 residues long. 
We will try to shorten it to 45 residues by prompting the model with partial sequence and secondary structure\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "shortened_region_length = 45\n", - "\n", - "# We'll construct a sequence prompt that masks the (shortened) helix-coil-helix region, but leaves the flanking regions unmasked\n", - "sequence_prompt = (\n", - " helix_shortening_chain.sequence[: helix_region[0]]\n", - " + \"_\" * shortened_region_length\n", - " + helix_shortening_chain.sequence[helix_region[-1] + 1 :]\n", - ")\n", - "print(\"Sequence prompt:\\n\\t\", sequence_prompt)\n", - "\n", - "# We'll construct a secondary structure prompt that retains the secondary structure of the flanking regions, and shortens the lengths of helices in the helix-coil-helix region\n", - "ss8_prompt = (\n", - " helix_shortening_ss8[: helix_region[0]]\n", - " + (\n", - " ((shortened_region_length - 3) // 2) * \"H\"\n", - " + \"C\" * 3\n", - " + ((shortened_region_length - 3) // 2) * \"H\"\n", - " )\n", - " + helix_shortening_ss8[helix_region[-1] + 1 :]\n", - ")\n", - "print(\"SS8 prompt:\\n\\t\", ss8_prompt)\n", - "print(\n", - " \"Proposed SS8 for shortened helix-coil-helix region:\\n\\t\",\n", - " \" \" * helix_region[0] + ss8_prompt[helix_region[0] : helix_region[0] + 45],\n", - ")\n", - "\n", - "print(\"\")\n", - "print(\"Original sequence:\\n\\t\", helix_shortening_chain.sequence)\n", - "print(\"Original SS8:\\n\\t\", helix_shortening_ss8)\n", - "print(\n", - " \"Original SS8 for helix-coil-helix region:\\n\\t\",\n", - " \" \" * helix_region[0]\n", - " + helix_shortening_ss8[helix_region[0] : helix_region[-1] + 1],\n", - ")\n", - "\n", - "\n", - "# We can again use the ESMProtein class to compose the sequence and secondary structure prompts into a single prompt that can be passed to ESM3\n", - "protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can again use the `generate` method of the model to iteratively decode a protein sequence based on the prompt\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Generating protein sequence...\")\n", - "sequence_generation = model.generate(\n", - " protein_prompt,\n", - " GenerationConfig(\n", - " track=\"sequence\",\n", - " num_steps=protein_prompt.sequence.count(\"_\") // 2,\n", - " temperature=0.5,\n", - " ),\n", - ")\n", - "print(\"Folding protein...\")\n", - "structure_prediction = model.generate(\n", - " ESMProtein(sequence=sequence_generation.sequence),\n", - " GenerationConfig(\n", - " track=\"structure\", num_steps=len(protein_prompt) // 4, temperature=0\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right) alongside the original structure (left) from which the motif was drawn. 
The helix-coil-helix region in the original structure is colored in blue and the shortened region in the generated structure is colored in pink.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "predicted_chain = structure_prediction.to_protein_chain()\n", - "predicted_chain = predicted_chain.align(\n", - " helix_shortening_chain,\n", - " mobile_inds=np.arange(len(predicted_chain) - 120, len(predicted_chain)),\n", - " target_inds=np.arange(\n", - " len(helix_shortening_chain) - 120, len(helix_shortening_chain)\n", - " ),\n", - ")\n", - "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n", - "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n", - "view.addModel(predicted_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", - "view.addStyle(\n", - " {\"resi\": (helix_region + 1).tolist()},\n", - " {\"cartoon\": {\"color\": \"lightblue\"}},\n", - " viewer=(0, 0),\n", - ")\n", - "view.addStyle(\n", - " {\"resi\": (np.arange(helix_region[0], helix_region[0] + 45) + 1).tolist()},\n", - " {\"cartoon\": {\"color\": \"pink\"}},\n", - " viewer=(0, 1),\n", - ")\n", - "view.zoomTo()\n", - "view.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# SASA Editing Example: Exposing a buried helix\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's grab 1LBS from the PDB and visualize it using `py3Dmol`. 1LBS has an alternating alpha-beta sandwich fold, with a buried helix in the center, highlighted in red\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "lipase_chain = ProteinChain.from_rcsb(\"1LBS\", \"A\")\n", - "span_start = 105\n", - "span_end = 116\n", - "view = py3Dmol.view(width=500, height=500)\n", - "view.addModel(lipase_chain.to_pdb_string(), \"pdb\")\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", - "view.addStyle(\n", - " {\"resi\": (np.arange(span_start, span_end) + 1).tolist()},\n", - " {\"cartoon\": {\"color\": \"red\"}},\n", - ")\n", - "view.zoomTo()\n", - "view.show()\n", - "lipase_ss8 = \"CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can construct a multimodal prompt for ESM3 to instruct it to expose the buried helix as follows:\n", - "\n", - "1. Prompt with the **structure** of the buried helix highlighted in red -- this will prompt ESM3 to generate a protein that contains that same helix\n", - "2. 
Prompt with high **SASA** values for the residues in the buried helix -- this will prompt ESM3 to expose the helix to the surface of the protein\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "structure_prompt = torch.full((len(lipase_chain), 37, 3), torch.nan)\n", - "structure_prompt[span_start:span_end] = torch.tensor(\n", - " lipase_chain[span_start:span_end].atom37_positions, dtype=torch.float32\n", - ")\n", - "\n", - "sasa_prompt = [None] * len(lipase_chain)\n", - "sasa_prompt[span_start:span_end] = [40.0] * (span_end - span_start)\n", - "\n", - "print(\"SASA prompt (just for buried region): \", sasa_prompt[span_start:span_end])\n", - "\n", - "protein_prompt = ESMProtein(\n", - " sequence=\"_\" * len(lipase_chain), coordinates=structure_prompt, sasa=sasa_prompt\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This is a more difficult task, so you may need to sample more generations from ESM before you find a solution. We'll sample 32 here and sort by the generations with the highest predicted TM-score (pTM) by ESM3.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "generated_proteins = []\n", - "N_SAMPLES = 16\n", - "for i in range(N_SAMPLES):\n", - " print(\"Generating protein sequence...\")\n", - " sequence_generation = model.generate(\n", - " protein_prompt,\n", - " GenerationConfig(\n", - " track=\"sequence\", num_steps=len(protein_prompt) // 8, temperature=0.7\n", - " ),\n", - " )\n", - " print(\"Folding protein...\")\n", - " structure_prediction = model.generate(\n", - " ESMProtein(sequence=sequence_generation.sequence),\n", - " GenerationConfig(track=\"structure\", num_steps=len(protein_prompt) // 32),\n", - " )\n", - " generated_proteins.append(structure_prediction)\n", - "\n", - "# Sort generations by ptm\n", - "generated_proteins = sorted(\n", - " generated_proteins, key=lambda x: x.ptm.item(), reverse=True\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's visualize the top 4 generations by pTM, alongside with the original protein (on the left)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "N_SAMPLES_TO_SHOW = 4\n", - "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, N_SAMPLES_TO_SHOW + 1))\n", - "view.addModel(lipase_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n", - "for i in range(N_SAMPLES_TO_SHOW):\n", - " print(\n", - " \"PTM of generated protein {}: {:.2f}\".format(\n", - " i + 1, generated_proteins[i].ptm.item()\n", - " )\n", - " )\n", - " view.addModel(\n", - " generated_proteins[i].to_protein_chain().to_pdb_string(),\n", - " \"pdb\",\n", - " viewer=(0, i + 1),\n", - " )\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", - "view.addStyle(\n", - " {\"resi\": (np.arange(span_start, span_end) + 1).tolist()},\n", - " {\"cartoon\": {\"color\": \"red\"}},\n", - ")\n", - "view.zoomTo()\n", - "view.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - 
"version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ESM3\n", + "\n", + "ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.\n", + "\n", + "ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.\n", + "\n", + "![image.png](https://github.com/evolutionaryscale/esm/blob/main/_assets/esm3_diagram.png?raw=true)\n", + "\n", + "The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.\n", + "Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family, trained specifically to be open sourced. ESM3-open is available under a non-commercial license.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Imports\n", + "\n", + "If you're running in Colab, you probably want to get a GPU runtime first (Runtime > Change runtime type > T4 GPU).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%set_env TOKENIZERS_PARALLELISM=false\n", + "!pip install esm\n", + "import numpy as np\n", + "import torch\n", + "\n", + "!pip install py3Dmol\n", + "import py3Dmol\n", + "from esm.models.esm3 import ESM3\n", + "from esm.sdk.api import ESMProtein, GenerationConfig\n", + "from esm.utils.structure.protein_chain import ProteinChain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load `esm-open-small` on GPU\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from esm.utils.misc import huggingfacehub_login\n", + "\n", + "huggingfacehub_login() # will prompt you to get an API key and accept the ESM3 license.\n", + "model = ESM3.from_pretrained(\"esm3_sm_open_v1\", device=torch.device(\"cuda\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternatively, you could use the Forge API running the model remotely, and use the local `client` to call the API just like you're used to with the model running locally on your GPU:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# from getpass import getpass\n", + "# token = getpass(\"Token from Forge console: \")\n", + "# model = client(\n", + "# model=\"esm3-large-2024-03\",\n", + "# url=\"https://forge.evolutionaryscale.ai\",\n", + "# token=token,\n", + "# )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Let's construct a prompt for ESM3, focusing on the task of scaffolding a motif from a natural protein\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we can use the `ProteinChain` class from the `esm` sdk to grab a 
protein structure from the PDB.\n",
+ "We'll work with a human renal (kidney) dipeptidase (a protein that breaks up two amino acids bound together). Renal dipeptidases are of particular interest because they metabolize certain antibiotics.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pdb_id = \"1ITU\"  # PDB ID corresponding to Renal Dipeptidase\n",
+ "chain_id = \"A\"  # Chain ID corresponding to Renal Dipeptidase in the PDB structure\n",
+ "renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n",
+ "# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `ProteinChain` class makes it easy to work with protein structures. Its `sequence` attribute contains the amino acid sequence of the protein\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(renal_dipep_chain.sequence)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`ProteinChain` also contains an `atom37_positions` numpy array that contains the atomic coordinates of each of the residues in the protein.\n",
+ "\n",
+ "The shape of the array is `(n_residues, 37, 3)` where `n_residues` is the number of residues in the protein and 37 is the number of possible distinct atoms that may be present across all amino acids (e.g. the first three atoms are the N, C-alpha, and C atoms corresponding to the protein backbone). The 3 corresponds to the x, y, and z coordinates of each atom. The atom37 representation of protein structure allows us to use a single format to conveniently represent all amino acids -- **coordinates are only present for the atoms that are present in the amino acid and `nan` otherwise**.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"atom37_positions shape: \", renal_dipep_chain.atom37_positions.shape)\n",
+ "print(renal_dipep_chain.atom37_positions[:3])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can visualize the protein chain using the `py3Dmol` library\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# First we can create a `py3Dmol` view object\n",
+ "view = py3Dmol.view(width=500, height=500)\n",
+ "# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string\n",
+ "pdb_str = renal_dipep_chain.to_pdb_string()\n",
+ "# Load the PDB string into the `py3Dmol` view object\n",
+ "view.addModel(pdb_str, \"pdb\")\n",
+ "# Set the style of the protein chain\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
+ "# Zoom in on the protein chain\n",
+ "view.zoomTo()\n",
+ "# Display the protein chain\n",
+ "view.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, let's try to scaffold a motif from this protein using ESM3 -- we'll prompt the model with the sequence and structure of a helix-coil motif from renal dipeptidase and have the model generate a larger scaffold that includes the motif\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "motif_inds = np.arange(123, 146)\n",
+ "# `ProteinChain` objects can be indexed like numpy arrays to 
extract the sequence and atomic coordinates of a subset of residues\n", + "motif_sequence = renal_dipep_chain[motif_inds].sequence\n", + "motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions\n", + "print(\"Motif sequence: \", motif_sequence)\n", + "print(\"Motif atom37_positions shape: \", motif_atom37_positions.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also visualize the motif in the original chain using `py3Dmol`. We'll color the original chain in grey and the motif in blue\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "view = py3Dmol.view(width=500, height=500)\n", + "view.addModel(pdb_str, \"pdb\")\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "motif_res_inds = (\n", + " motif_inds + 1\n", + ").tolist() # residue indices are 1-indexed in PDB files, so we add 1 to the indices\n", + "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}})\n", + "view.zoomTo()\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can use the `ESMProtein` class to construct a prompt that will instruct ESM3 to scaffold the motif\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompt_length = 200\n", + "# First, we can construct a sequence prompt of all masks\n", + "sequence_prompt = [\"_\"] * prompt_length\n", + "# Then, we can randomly insert the motif sequence into the prompt (we randomly choose 72 here)\n", + "sequence_prompt[72 : 72 + len(motif_sequence)] = list(motif_sequence)\n", + "sequence_prompt = \"\".join(sequence_prompt)\n", + "print(\"Sequence prompt: \", sequence_prompt)\n", + "print(\"Length of sequence prompt: \", len(sequence_prompt))\n", + "\n", + "# Next, we can construct a structure prompt of all nan coordinates\n", + "structure_prompt = torch.full((prompt_length, 37, 3), np.nan)\n", + "# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72\n", + "structure_prompt[72 : 72 + len(motif_atom37_positions)] = torch.tensor(\n", + " motif_atom37_positions\n", + ")\n", + "print(\"Structure prompt shape: \", structure_prompt.shape)\n", + "print(\n", + " \"Indices with structure conditioning: \",\n", + " torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist(),\n", + ")\n", + "\n", + "# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3\n", + "protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can use the `generate` method of the model to iteratively sample a protein sequence based on the prompt. 
Under the hood, the model performs num_steps forward passes and samples a set of tokens at each step until the chosen track being generated is fully unmasked.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use\n", + "sequence_generation_config = GenerationConfig(\n", + " track=\"sequence\", # We want ESM3 to generate tokens for the sequence track\n", + " num_steps=sequence_prompt.count(\"_\")\n", + " // 2, # We'll use num(mask tokens) // 2 steps to decode the sequence\n", + " temperature=0.5, # We'll use a temperature of 0.5 to control the randomness of the decoding process\n", + ")\n", + "\n", + "# Now, we can use the `generate` method of the model to decode the sequence\n", + "sequence_generation = model.generate(protein_prompt, sequence_generation_config)\n", + "print(\"Sequence Prompt:\\n\\t\", protein_prompt.sequence)\n", + "print(\"Generated sequence:\\n\\t\", sequence_generation.sequence)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also use the `generate` method to predict the structure of the generated sequence by iteratively sampling structure tokens.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "structure_prediction_config = GenerationConfig(\n", + " track=\"structure\", # We want ESM3 to generate tokens for the structure track\n", + " num_steps=len(sequence_generation) // 8,\n", + " temperature=0.7,\n", + ")\n", + "structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)\n", + "structure_prediction = model.generate(\n", + " structure_prediction_prompt, structure_prediction_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right, green) alongside the original structure (left, grey) from which the motif was drawn. 
The motif residues are colored in cyan.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Convert the generated structure back into a ProteinChain object\n",
+ "structure_prediction_chain = structure_prediction.to_protein_chain()\n",
+ "# Align the generated structure to the original structure using the motif residues\n",
+ "motif_inds_in_generation = np.arange(72, 72 + len(motif_sequence))\n",
+ "structure_prediction_chain.align(\n",
+ "    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds\n",
+ ")\n",
+ "crmsd = structure_prediction_chain.rmsd(\n",
+ "    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds\n",
+ ")\n",
+ "print(\n",
+ "    \"cRMSD of the motif in the generated structure vs the original structure: \", crmsd\n",
+ ")\n",
+ "\n",
+ "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
+ "view.addModel(pdb_str, \"pdb\", viewer=(0, 0))\n",
+ "view.addModel(structure_prediction_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}}, viewer=(0, 0))\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}}, viewer=(0, 1))\n",
+ "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}}, viewer=(0, 0))\n",
+ "view.addStyle(\n",
+ "    {\"resi\": (motif_inds_in_generation + 1).tolist()},\n",
+ "    {\"cartoon\": {\"color\": \"cyan\"}},\n",
+ "    viewer=(0, 1),\n",
+ ")\n",
+ "view.zoomTo()\n",
+ "view.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Secondary Structure Editing Example: Helix Shortening\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can try another generation task with ESM3. We'll use the secondary structure track, along with the sequence track, to shorten a helix-coil-helix region (residues 39-111) in a protein structure (colored in blue below)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "helix_shortening_chain = ProteinChain.from_rcsb(\"7XBQ\", \"A\")\n",
+ "view = py3Dmol.view(width=500, height=500)\n",
+ "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\")\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
+ "helix_region = np.arange(38, 111)  # zero-indexed\n",
+ "view.addStyle(\n",
+ "    {\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\": \"lightblue\"}}\n",
+ ")\n",
+ "view.zoomTo()\n",
+ "view.show()\n",
+ "helix_shortening_ss8 = \"CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC\"\n",
+ "print(\n",
+ "    \"Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \\n\\t\",\n",
+ "    helix_shortening_ss8,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The helix-coil-helix region in the original protein is 73 residues long. We will try to shorten it to 45 residues by prompting the model with partial sequence and secondary structure\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shortened_region_length = 45\n",
+ "\n",
+ "# We'll construct a sequence prompt that masks the (shortened) helix-coil-helix region, but leaves the flanking regions unmasked\n",
+ "sequence_prompt = (\n",
+ "    helix_shortening_chain.sequence[: helix_region[0]]\n",
+ "    + \"_\" * shortened_region_length\n",
+ "    + helix_shortening_chain.sequence[helix_region[-1] + 1 :]\n",
+ ")\n",
+ "print(\"Sequence prompt:\\n\\t\", sequence_prompt)\n",
+ "\n",
+ "# We'll construct a secondary structure prompt that retains the secondary structure of the flanking regions, and shortens the lengths of helices in the helix-coil-helix region\n",
+ "ss8_prompt = (\n",
+ "    helix_shortening_ss8[: helix_region[0]]\n",
+ "    + (\n",
+ "        ((shortened_region_length - 3) // 2) * \"H\"\n",
+ "        + \"C\" * 3\n",
+ "        + ((shortened_region_length - 3) // 2) * \"H\"\n",
+ "    )\n",
+ "    + helix_shortening_ss8[helix_region[-1] + 1 :]\n",
+ ")\n",
+ "print(\"SS8 prompt:\\n\\t\", ss8_prompt)\n",
+ "print(\n",
+ "    \"Proposed SS8 for shortened helix-coil-helix region:\\n\\t\",\n",
+ "    \" \" * helix_region[0] + ss8_prompt[helix_region[0] : helix_region[0] + 45],\n",
+ ")\n",
+ "\n",
+ "print(\"\")\n",
+ "print(\"Original sequence:\\n\\t\", helix_shortening_chain.sequence)\n",
+ "print(\"Original SS8:\\n\\t\", helix_shortening_ss8)\n",
+ "print(\n",
+ "    \"Original SS8 for helix-coil-helix region:\\n\\t\",\n",
+ "    \" \" * helix_region[0]\n",
+ "    + helix_shortening_ss8[helix_region[0] : helix_region[-1] + 1],\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# We can again use the ESMProtein class to compose the sequence and secondary structure prompts into a single prompt that can be passed to ESM3\n",
+ "protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can again use the `generate` method of the model to iteratively decode a protein sequence based on the prompt\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Generating protein sequence...\")\n",
+ "sequence_generation = model.generate(\n",
+ "    protein_prompt,\n",
+ "    GenerationConfig(\n",
+ "        track=\"sequence\",\n",
+ "        num_steps=protein_prompt.sequence.count(\"_\") // 2,\n",
+ "        temperature=0.5,\n",
+ "    ),\n",
+ ")\n",
+ "print(\"Folding protein...\")\n",
+ "structure_prediction = model.generate(\n",
+ "    ESMProtein(sequence=sequence_generation.sequence),\n",
+ "    GenerationConfig(\n",
+ "        track=\"structure\", num_steps=len(protein_prompt) // 4, temperature=0\n",
+ "    ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right) alongside the original structure (left). 
The helix-coil-helix region in the original structure is colored in blue and the shortened region in the generated structure is colored in pink.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predicted_chain = structure_prediction.to_protein_chain()\n", + "predicted_chain = predicted_chain.align(\n", + " helix_shortening_chain,\n", + " mobile_inds=np.arange(len(predicted_chain) - 120, len(predicted_chain)),\n", + " target_inds=np.arange(\n", + " len(helix_shortening_chain) - 120, len(helix_shortening_chain)\n", + " ),\n", + ")\n", + "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n", + "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n", + "view.addModel(predicted_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "view.addStyle(\n", + " {\"resi\": (helix_region + 1).tolist()},\n", + " {\"cartoon\": {\"color\": \"lightblue\"}},\n", + " viewer=(0, 0),\n", + ")\n", + "view.addStyle(\n", + " {\"resi\": (np.arange(helix_region[0], helix_region[0] + 45) + 1).tolist()},\n", + " {\"cartoon\": {\"color\": \"pink\"}},\n", + " viewer=(0, 1),\n", + ")\n", + "view.zoomTo()\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SASA Editing Example: Exposing a buried helix\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's grab 1LBS from the PDB and visualize it using `py3Dmol`. 1LBS has an alternating alpha-beta sandwich fold, with a buried helix in the center, highlighted in red\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lipase_chain = ProteinChain.from_rcsb(\"1LBS\", \"A\")\n", + "span_start = 105\n", + "span_end = 116\n", + "view = py3Dmol.view(width=500, height=500)\n", + "view.addModel(lipase_chain.to_pdb_string(), \"pdb\")\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", + "view.addStyle(\n", + " {\"resi\": (np.arange(span_start, span_end) + 1).tolist()},\n", + " {\"cartoon\": {\"color\": \"red\"}},\n", + ")\n", + "view.zoomTo()\n", + "view.show()\n", + "lipase_ss8 = \"CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can construct a multimodal prompt for ESM3 to instruct it to expose the buried helix as follows:\n", + "\n", + "1. Prompt with the **structure** of the buried helix highlighted in red -- this will prompt ESM3 to generate a protein that contains that same helix\n", + "2. 
Prompt with high **SASA** values for the residues in the buried helix -- this will prompt ESM3 to expose the helix to the surface of the protein\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "structure_prompt = torch.full((len(lipase_chain), 37, 3), torch.nan)\n",
+ "structure_prompt[span_start:span_end] = torch.tensor(\n",
+ "    lipase_chain[span_start:span_end].atom37_positions, dtype=torch.float32\n",
+ ")\n",
+ "\n",
+ "sasa_prompt = [None] * len(lipase_chain)\n",
+ "sasa_prompt[span_start:span_end] = [40.0] * (span_end - span_start)\n",
+ "\n",
+ "print(\"SASA prompt (just for buried region): \", sasa_prompt[span_start:span_end])\n",
+ "\n",
+ "protein_prompt = ESMProtein(\n",
+ "    sequence=\"_\" * len(lipase_chain), coordinates=structure_prompt, sasa=sasa_prompt\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This is a more difficult task, so you may need to sample more generations from ESM3 before you find a solution. We'll sample 16 here and rank the generations by ESM3's predicted TM-score (pTM), from highest to lowest.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "generated_proteins = []\n",
+ "N_SAMPLES = 16\n",
+ "for i in range(N_SAMPLES):\n",
+ "    print(\"Generating protein sequence...\")\n",
+ "    sequence_generation = model.generate(\n",
+ "        protein_prompt,\n",
+ "        GenerationConfig(\n",
+ "            track=\"sequence\", num_steps=len(protein_prompt) // 8, temperature=0.7\n",
+ "        ),\n",
+ "    )\n",
+ "    print(\"Folding protein...\")\n",
+ "    structure_prediction = model.generate(\n",
+ "        ESMProtein(sequence=sequence_generation.sequence),\n",
+ "        GenerationConfig(track=\"structure\", num_steps=len(protein_prompt) // 32),\n",
+ "    )\n",
+ "    generated_proteins.append(structure_prediction)\n",
+ "\n",
+ "# Sort generations by pTM, highest first\n",
+ "generated_proteins = sorted(\n",
+ "    generated_proteins, key=lambda x: x.ptm.item(), reverse=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's visualize the top 4 generations by pTM, alongside the original protein (on the left)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "N_SAMPLES_TO_SHOW = 4\n",
+ "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, N_SAMPLES_TO_SHOW + 1))\n",
+ "view.addModel(lipase_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
+ "for i in range(N_SAMPLES_TO_SHOW):\n",
+ "    print(\n",
+ "        \"PTM of generated protein {}: {:.2f}\".format(\n",
+ "            i + 1, generated_proteins[i].ptm.item()\n",
+ "        )\n",
+ "    )\n",
+ "    view.addModel(\n",
+ "        generated_proteins[i].to_protein_chain().to_pdb_string(),\n",
+ "        \"pdb\",\n",
+ "        viewer=(0, i + 1),\n",
+ "    )\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
+ "view.addStyle(\n",
+ "    {\"resi\": (np.arange(span_start, span_end) + 1).tolist()},\n",
+ "    {\"cartoon\": {\"color\": \"red\"}},\n",
+ ")\n",
+ "view.zoomTo()\n",
+ "view.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ 
"version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/examples/gfp_design.ipynb b/examples/gfp_design.ipynb index 6bc0e08..f4e56cf 100644 --- a/examples/gfp_design.ipynb +++ b/examples/gfp_design.ipynb @@ -1,553 +1,540 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# Design a GFP Candidate with ESM3\n", - "\n", - "This notebook walks through the computational methods used to design esmGFP in [Hayes et al., 2024](https://doi.org/10.1101/2024.07.01.600583). esmGFP has similar brightness and spectral properties as GFPs found in nature despite being a far distance (58% identity) from known fluorescent proteins, but we also found many other bright new GFPs with similar or higher sequence identity. One can likely design a lot more new GFPs with the approach sketched in this notebook!\n", - "\n", - "This notebook implements the core prompt used to begin the chain of thought used to create esmGFP. The overall process we used differs in two keys ways:\n", - "\n", - "1. We continued the generation process beyond what is shown in this notebook\n", - " to do a joint optimization of the generated sequence and structure.\n", - "2. We used significantly more compute than is easy to do with a notebook to\n", - " generate many designs and filter them with a set of computational filters and\n", - " ranking mechanisms.\n", - "\n", - "And we validated a small number of the generated designs in a wet lab, which of course you can also do... but this notebook isn't very helpful with that!" - ], - "metadata": { - "id": "zWXOAcBB8h3z" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Set up the notebook and model (via the Forge API).\n", - "\n", - "We begin by installing the [esm package](https://github.com/evolutionaryscale/esm) and py3Dmol, which will allow us to visualize our generations, and then importing necessary packages." - ], - "metadata": { - "id": "kZRkkEKv-8YW" - } - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VgTZdaIMQ44H" - }, - "outputs": [], - "source": [ - "from IPython.display import clear_output\n", - "\n", - "!pip install git+https://github.com/evolutionaryscale/esm.git\n", - "!pip install py3Dmol\n", - "\n", - "clear_output() # Suppress pip install log lines after installation is complete." - ] - }, - { - "cell_type": "code", - "source": [ - "import biotite.sequence as seq\n", - "import biotite.sequence.align as align\n", - "import biotite.sequence.graphics as graphics\n", - "from getpass import getpass\n", - "import matplotlib.pyplot as pl\n", - "import py3Dmol\n", - "import torch\n", - "\n", - "from esm.sdk import client\n", - "from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig\n", - "from esm.utils.structure.protein_chain import ProteinChain" - ], - "metadata": { - "id": "poK5NTaXRGcX" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "\n", - "ESM3 is a frontier generative model for biology. 
It is scalable due to its ability to tokenize sequence, structure, and function and use a (nearly standard) transformer architecture while still being able to reason across all modalities simulateously.\n", - "\n", - "The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n", - "\n", - "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." - ], - "metadata": { - "id": "vmVYm2uQ7m-5" - } - }, - { - "cell_type": "code", - "source": [ - "token = getpass(\"Token from Forge console: \")" - ], - "metadata": { - "id": "zNrU9Q2SYonX" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "We then create a model stub that behaves somewhat like a PyTorch model, but under the hood it sends the inputs to the Forge server, runs the inputs through the neural network weights on that remote server, and then returns the output tensors here in this notebook. This stub can also be used in the EvolutionaryScale SDK to simplify a lot of the operations around generation, folding, and generally using the sampling. This is important because iterative sampling is key to getting the best performance from ESM3, and the SDK manages a lot of the complexity around implementing these standard routines." - ], - "metadata": { - "id": "9jIc4OZyh2oE" - } - }, - { - "cell_type": "code", - "source": [ - "model = client(\n", - " model=\"esm3-medium-2024-03\",\n", - " url=\"https://forge.evolutionaryscale.ai\",\n", - " token=token,\n", - ")" - ], - "metadata": { - "id": "Tna_mjGOjdXA" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Construct the GFP Prompt\n" - ], - "metadata": { - "id": "ZaLBVJlzTxdT" - } - }, - { - "cell_type": "markdown", - "source": [ - "ESM3 is a generative model. To access the generative capabilities we need to get comfortable with constructing prompts. ESM3 jointly reasons across sequence, structure, and function of proteins, so we can construct new types of prompts that can guide the model to generate proteins with higher levels of control than many other biological language models.\n", - "\n", - "Sequence, structure, and function modalities are represented as tracks of discrete tokens that are present at both the input and output of the model and fused into a single latent space within the model. ESM3 is trained\n", - "with a generative masked language modeling objective with variable mask rates, so we can prompt with a fully or partially masked context and different points of conditioning across the various tracks. This gives us an opportunity to be highly creative with how we specify our prompts!\n", - "\n", - "Prompt engineering is a bit of an art and a bit of a science, so one typically needs to experiment to get a prompt that produces a desired result. 
Also because we use sampling to generate from the model the results of different generations from the same prompt will vary. Some prompts tend to have higher success rates requiring only a few generations to get a candidate protein design. Other more difficult prompts may require thousands of generations! The models are more controllable with alignment.\n", - "\n", - "The model we will be using is the raw pretrained (unaligned) model, but we've worked a lot on this prompt so one can typically get an interesting design with only a few generations." - ], - "metadata": { - "id": "KTn1Z4hCORVR" - } - }, - { - "cell_type": "markdown", - "source": [ - "We'll construct our prompt from fragments of the [1qy3](https://www.rcsb.org/structure/1qy3) sequence and structure from the PDB. The following code fetches data from the PDB and then uses ESM3's tokenizers to convert the sequence and structure to tokens that can be passed into the model. Once can see that both the amino acid type and the coordinates are converted into one discrete token per sequence position." - ], - "metadata": { - "id": "qtwnyA1BngWy" - } - }, - { - "cell_type": "code", - "source": [ - "template_gfp = ESMProtein.from_protein_chain(\n", - " ProteinChain.from_rcsb(\"1qy3\", chain_id=\"A\")\n", - ")\n", - "\n", - "template_gfp_tokens = model.encode(template_gfp)\n", - "\n", - "print(\"Sequence tokens:\")\n", - "print(\" \", \", \".join([\n", - " str(token) for token in template_gfp_tokens.sequence.tolist()\n", - "]))\n", - "\n", - "print(\"Structure tokens:\")\n", - "print(\" \", \", \".join([\n", - " str(token) for token in template_gfp_tokens.structure.tolist()\n", - "]))" - ], - "metadata": { - "id": "cDWcXKmlbC1z" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "We'll now build a prompt. Specifically we'll specify 4 amino acid identities at positions near where we want the chromophore to form, and 2 amino acid identities on the beta barrel that are known to support chromophore formation.\n", - "\n", - "Furthermore we'll specify the structure should be similar to the 1qy3 structure at all these positions by adding tokens from the encoded 1qy3 structure to the structure track of our prompt. We'll also specify a few more positions (along the alpha helix kink)." - ], - "metadata": { - "id": "xUxwNTuWqxks" - } - }, - { - "cell_type": "code", - "source": [ - "prompt_sequence = [\"_\"] * len(template_gfp.sequence)\n", - "prompt_sequence[59] = \"T\"\n", - "prompt_sequence[62] = \"T\"\n", - "prompt_sequence[63] = \"Y\"\n", - "prompt_sequence[64] = \"G\"\n", - "prompt_sequence[93] = \"R\"\n", - "prompt_sequence[219] = \"E\"\n", - "prompt_sequence = \"\".join(prompt_sequence)\n", - "\n", - "print(template_gfp.sequence)\n", - "print(prompt_sequence)\n", - "\n", - "prompt = model.encode(\n", - " ESMProtein(sequence=prompt_sequence)\n", - ")\n", - "\n", - "# We construct an empty structure track like | ... |...\n", - "prompt.structure = torch.full_like(prompt.sequence, 4096)\n", - "prompt.structure[0] = 4098\n", - "prompt.structure[-1] = 4097\n", - "# ... 
and then we fill in structure tokens at key residues near the alpha helix\n", - "# kink and at the stabilizing R and E positions on the beta barrel.\n", - "prompt.structure[55:70] = template_gfp_tokens.structure[56:71]\n", - "prompt.structure[93] = template_gfp_tokens.structure[93]\n", - "prompt.structure[219] = template_gfp_tokens.structure[219]\n", - "\n", - "print(\"\".join([\"✔\" if st < 4096 else \"_\" for st in prompt.structure]))" - ], - "metadata": { - "id": "YBfYwRFGDKjU" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "The output shows the original 1qy3 sequence and the our prompt sequence track amino acid identities and the positions that have a token on the structure track. ESM3 will then be tasked with filling in the structure and sequence at the remaining masked (underscore) positions.\n", - "\n", - "One small note, we introduced the mutation A93R in our prompt. This isn't a mistake. Using Alanine at this position causes the chromophore to mature extremely slowly (which is how we are able to measure the precyclized structure of GFP!). However we don't want to wait around for our GFPs to glow so we go with Arginine at this position." - ], - "metadata": { - "id": "107URq1bpA4_" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Generate a Structure\n", - "\n", - "We then prompt the model and decode the structure tokens track. This is similar to creating a backbone scaffold for an active site prompt, but there are some subtle differences. For example, since we've already specified some of the structure tokens (e.g., around the active site and key corresponding residues) the model literally generates around this structure.\n", - "\n", - "Tokens are iteratively sampled from ESM3. They\n", - "can be sampled one at a time, or in parallel, in any order, until all positions are fully unmasked. The generate() function in the EvolutionaryScale SDK implements one recipe we think is effective for sampling from the model.\n" - ], - "metadata": { - "id": "DE0BvswIATN8" - } - }, - { - "cell_type": "code", - "source": [ - "%%time\n", - "\n", - "num_tokens_to_decode = (prompt.structure == 4096).sum().item()\n", - "\n", - "structure_generation = model.generate(\n", - " prompt,\n", - " GenerationConfig(\n", - " # Generate a structure.\n", - " track=\"structure\",\n", - " # Sample one token per forward pass of the model.\n", - " num_steps=num_tokens_to_decode,\n", - " # Sampling temperature trades perplexity with diversity.\n", - " temperature=1.0,\n", - " )\n", - ")\n", - "\n", - "print(\"These are the structure tokens corresponding to our new design:\")\n", - "print(\" \", \", \".join([\n", - " str(token) for token in structure_generation.structure.tolist()\n", - "]))\n", - "\n", - "# Decodes structure tokens to backbone coordinates.\n", - "structure_generation_protein = model.decode(structure_generation)\n", - "\n", - "print(\"\")" - ], - "metadata": { - "id": "yatAF6kYHZdm" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Now let's visualize our generated structure. This will probably look like the familiar GFP beta barrel around an alpha helix." 
- ], - "metadata": { - "id": "0HARel94tJfI" - } - }, - { - "cell_type": "code", - "source": [ - "view = py3Dmol.view(width=1000, height=500)\n", - "view.addModel(structure_generation_protein.to_protein_chain().infer_oxygen().to_pdb_string(), \"pdb\")\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}})\n", - "view.zoomTo()\n", - "view.show()" - ], - "metadata": { - "id": "D30KQC6ffrrH" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "At this point we only want to continue the generation if this design is a close match to a wildtype GFP at the active site, has some structural difference across the full protein (otherwise it would end up being very sequence-similar to wildtype GFP), and overall still looks like the classic GFP alpha helix in a beta barrel structure.\n", - "\n", - "Of course when generating many designs we cannot look at each one manually, so we adopt some automated rejection sampling criteria based on the overall structure RMSD and the constrained site RMSD for our generated structure being faithful to the prompt. If these checks pass then we'll try to design a sequence for this structure. If not, one should go back up a few cells and design another structure until it passes these computational screens. (Or not... this is your GFP design!)" - ], - "metadata": { - "id": "K55y-5BeDRdr" - } - }, - { - "cell_type": "code", - "source": [ - "constrained_site_positions = [59, 62, 63, 64, 93, 219]\n", - "\n", - "template_chain = template_gfp.to_protein_chain()\n", - "generation_chain = structure_generation_protein.to_protein_chain()\n", - "\n", - "constrained_site_rmsd = template_chain[constrained_site_positions].rmsd(\n", - " generation_chain[constrained_site_positions]\n", - ")\n", - "backbone_rmsd = template_chain.rmsd(generation_chain)\n", - "\n", - "c_pass = \"✅\" if constrained_site_rmsd < 1.5 else \"❌\"\n", - "b_pass = \"✅\" if backbone_rmsd > 1.5 else \"❌\"\n", - "\n", - "print(f\"Constrained site RMSD: {constrained_site_rmsd:.2f} Ang {c_pass}\")\n", - "print(f\"Backbone RMSD: {backbone_rmsd:.2f} Ang {b_pass}\")" - ], - "metadata": { - "id": "aalzUw39t2O1" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# Sequence Design\n", - "\n", - "Now we have a backbone with some structural variation but that also matches the GFP constrained site, and we want to design a sequence that folds to this structure. We can use the prior generation, which is exactly our original prompt plus the new structure tokens representing the backbone, to prompt ESM3 again.\n", - "\n", - "One we have designed a sequence we'll want to confirm that sequence is a match for our structure, so we'll remove all other conditioning from the prompt and fold the sequence. Conveniently with ESM3, folding a sequence is simply generating a set of structure tokens conditioned on the amino acid sequence. In this case we want the model's highest confidence generation (with no diversity) so we sample with a temperature of zero." 
- ], - "metadata": { - "id": "6Bfbm8UiEqya" - } - }, - { - "cell_type": "code", - "source": [ - "%%time\n", - "\n", - "num_tokens_to_decode = (prompt.sequence == 32).sum().item()\n", - "\n", - "sequence_generation = model.generate(\n", - " # Generate a sequence.\n", - " structure_generation,\n", - " GenerationConfig(\n", - " track=\"sequence\",\n", - " num_steps=num_tokens_to_decode,\n", - " temperature=1.0,\n", - " )\n", - ")\n", - "\n", - "# Refold\n", - "sequence_generation.structure = None\n", - "length_of_sequence = sequence_generation.sequence.numel() - 2\n", - "sequence_generation = model.generate(\n", - " sequence_generation,\n", - " GenerationConfig(\n", - " track=\"structure\",\n", - " num_steps=length_of_sequence,\n", - " temperature=0.0,\n", - " )\n", - ")\n", - "\n", - "# Decode to AA string and coordinates.\n", - "sequence_generation_protein = model.decode(sequence_generation)" - ], - "metadata": { - "id": "GOrWSEVTnOq0" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "We now have a candidate GFP sequence!" - ], - "metadata": { - "id": "v_zK7TDCzEX3" - } - }, - { - "cell_type": "code", - "source": [ - "sequence_generation_protein.sequence" - ], - "metadata": { - "id": "Ao_n0-R5r2fT" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "We can align this sequence against the original template to see how similar it is to avGFP. One might also want to search against all known fluorescent proteins to assess the novelty of this potential GFP." - ], - "metadata": { - "id": "LBvQYpR_zQAK" - } - }, - { - "cell_type": "code", - "source": [ - "seq1 = seq.ProteinSequence(template_gfp.sequence)\n", - "seq2 = seq.ProteinSequence(sequence_generation_protein.sequence)\n", - "\n", - "alignments = align.align_optimal(\n", - " seq1,\n", - " seq2,\n", - " align.SubstitutionMatrix.std_protein_matrix(),\n", - " gap_penalty=(-10, -1),\n", - ")\n", - "\n", - "alignment = alignments[0]\n", - "\n", - "identity = align.get_sequence_identity(alignment)\n", - "print(f\"Sequence identity: {100*identity:.2f}%\")\n", - "\n", - "print(\"\\nSequence alignment:\")\n", - "fig = pl.figure(figsize=(8.0, 4.0))\n", - "ax = fig.add_subplot(111)\n", - "graphics.plot_alignment_similarity_based(\n", - " ax, alignment, symbols_per_line=45, spacing=2,\n", - " show_numbers=True,\n", - ")\n", - "fig.tight_layout()\n", - "pl.show()" - ], - "metadata": { - "id": "JuNegp37JRyD" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "We now recheck our computational metrics for the constrained site. If we see the constrained site is not a match then we'd want to try designing the sequence again. If many attempts to design a sequence that matches the structure fail, then it's likely the structure is not easily designable and we may want to reject this structure generation as well!\n", - "\n", - "At this point the backbone RMSD doesn't matter very much to us, so long as the sequence is adequately distant to satisfy our scientific curiosity!" 
- ], - "metadata": { - "id": "QQHMYfJzz13w" - } - }, - { - "cell_type": "code", - "source": [ - "template_chain = template_gfp.to_protein_chain()\n", - "generation_chain = sequence_generation_protein.to_protein_chain()\n", - "\n", - "constrained_site_rmsd = template_chain[constrained_site_positions].rmsd(\n", - " generation_chain[constrained_site_positions]\n", - ")\n", - "backbone_rmsd = template_chain.rmsd(generation_chain)\n", - "\n", - "c_pass = \"✅\" if constrained_site_rmsd < 1.5 else \"❌\"\n", - "b_pass = \"🤷‍♂️\"\n", - "\n", - "print(f\"Constrained site RMSD: {constrained_site_rmsd:.2f} Ang {c_pass}\")\n", - "print(f\"Backbone RMSD: {backbone_rmsd:.2f} Ang {b_pass}\")\n" - ], - "metadata": { - "id": "HUGQ7L4_z1BV" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "An now we can visualize the final structure prediction of our candidate GFP design." - ], - "metadata": { - "id": "0cIeC4Lg1Bz9" - } - }, - { - "cell_type": "code", - "source": [ - "view = py3Dmol.view(width=600, height=600)\n", - "view.addModel(sequence_generation_protein.to_pdb_string(), \"pdb\")\n", - "view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}})\n", - "view.zoomTo()\n", - "view.show()" - ], - "metadata": { - "id": "WTGRyt-es2sJ" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Before considering this sequence for wet lab validation, we run a joint optimization of the sequence and structure. The outputs of that process are then passed through stringent computational filters and then many designs from many starting points are ranked by a number of computational scores to select the final designs sent for testing. We'll walk through that process in a different notebook." - ], - "metadata": { - "id": "VrNuZHeHRWuP" - } - }, - { - "cell_type": "markdown", - "source": [ - "If you've made it this far it's worth noting that this isn't the only method to prompt ESM3 to design a GFP, it's just the one we used to report the successful generation of esmGFP in our paper. We hope you'll try different techniques to generate from ESM3. We're interested to hear what works for you!" - ], - "metadata": { - "id": "c3jSQrJa1Tfi" - } - } - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "zWXOAcBB8h3z" + }, + "source": [ + "# Design a GFP Candidate with ESM3\n", + "\n", + "This notebook walks through the computational methods used to design esmGFP in [Hayes et al., 2024](https://doi.org/10.1101/2024.07.01.600583). esmGFP has similar brightness and spectral properties as GFPs found in nature despite being a far distance (58% identity) from known fluorescent proteins, but we also found many other bright new GFPs with similar or higher sequence identity. One can likely design a lot more new GFPs with the approach sketched in this notebook!\n", + "\n", + "This notebook implements the core prompt used to begin the chain of thought used to create esmGFP. The overall process we used differs in two keys ways:\n", + "\n", + "1. We continued the generation process beyond what is shown in this notebook\n", + " to do a joint optimization of the generated sequence and structure.\n", + "2. We used significantly more compute than is easy to do with a notebook to\n", + " generate many designs and filter them with a set of computational filters and\n", + " ranking mechanisms.\n", + "\n", + "And we validated a small number of the generated designs in a wet lab, which of course you can also do... 
but this notebook isn't very helpful with that!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kZRkkEKv-8YW" + }, + "source": [ + "## Set up the notebook and model (via the Forge API).\n", + "\n", + "We begin by installing the [esm package](https://github.com/evolutionaryscale/esm) and py3Dmol, which will allow us to visualize our generations, and then importing necessary packages." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VgTZdaIMQ44H" + }, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "!pip install git+https://github.com/evolutionaryscale/esm.git\n", + "!pip install py3Dmol\n", + "\n", + "clear_output()  # Suppress pip install log lines after installation is complete." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "poK5NTaXRGcX" + }, + "outputs": [], + "source": [ + "from getpass import getpass\n", + "\n", + "import biotite.sequence as seq\n", + "import biotite.sequence.align as align\n", + "import biotite.sequence.graphics as graphics\n", + "import matplotlib.pyplot as pl\n", + "import py3Dmol\n", + "import torch\n", + "from esm.sdk import client\n", + "from esm.sdk.api import ESMProtein, GenerationConfig\n", + "from esm.utils.structure.protein_chain import ProteinChain" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vmVYm2uQ7m-5" + }, + "source": [ + "\n", + "ESM3 is a frontier generative model for biology. It is scalable due to its ability to tokenize sequence, structure, and function and use a (nearly standard) transformer architecture while still being able to reason across all modalities simultaneously.\n", + "\n", + "The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n", + "\n", + "Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zNrU9Q2SYonX" + }, + "outputs": [], + "source": [ + "token = getpass(\"Token from Forge console: \")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9jIc4OZyh2oE" + }, + "source": [ + "We then create a model stub that behaves somewhat like a PyTorch model, but under the hood it sends the inputs to the Forge server, runs the inputs through the neural network weights on that remote server, and then returns the output tensors here in this notebook. This stub can also be used in the EvolutionaryScale SDK to simplify a lot of the operations around generation, folding, and sampling in general. This is important because iterative sampling is key to getting the best performance from ESM3, and the SDK manages a lot of the complexity around implementing these standard routines." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Tna_mjGOjdXA" + }, + "outputs": [], + "source": [ + "model = client(\n", + "    model=\"esm3-medium-2024-03\", url=\"https://forge.evolutionaryscale.ai\", token=token\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZaLBVJlzTxdT" + }, + "source": [ + "## Construct the GFP Prompt\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KTn1Z4hCORVR" + }, + "source": [ + "ESM3 is a generative model. To access the generative capabilities we need to get comfortable with constructing prompts. ESM3 jointly reasons across sequence, structure, and function of proteins, so we can construct new types of prompts that can guide the model to generate proteins with higher levels of control than many other biological language models.\n", + "\n", + "Sequence, structure, and function modalities are represented as tracks of discrete tokens that are present at both the input and output of the model and fused into a single latent space within the model. ESM3 is trained\n", + "with a generative masked language modeling objective with variable mask rates, so we can prompt with a fully or partially masked context and different points of conditioning across the various tracks. This gives us an opportunity to be highly creative with how we specify our prompts!\n", + "\n", + "Prompt engineering is a bit of an art and a bit of a science, so one typically needs to experiment to get a prompt that produces a desired result. Also, because we use sampling to generate from the model, the results of different generations from the same prompt will vary. Some prompts tend to have higher success rates, requiring only a few generations to get a candidate protein design. Other more difficult prompts may require thousands of generations! The models are more controllable with alignment.\n", + "\n", + "The model we will be using is the raw pretrained (unaligned) model, but we've worked a lot on this prompt so one can typically get an interesting design with only a few generations." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qtwnyA1BngWy" + }, + "source": [ + "We'll construct our prompt from fragments of the [1qy3](https://www.rcsb.org/structure/1qy3) sequence and structure from the PDB. The following code fetches data from the PDB and then uses ESM3's tokenizers to convert the sequence and structure to tokens that can be passed into the model. One can see that both the amino acid type and the coordinates are converted into one discrete token per sequence position." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cDWcXKmlbC1z" + }, + "outputs": [], + "source": [ + "template_gfp = ESMProtein.from_protein_chain(\n", + "    ProteinChain.from_rcsb(\"1qy3\", chain_id=\"A\")\n", + ")\n", + "\n", + "template_gfp_tokens = model.encode(template_gfp)\n", + "\n", + "print(\"Sequence tokens:\")\n", + "print(\n", + "    \"  \", \", \".join([str(token) for token in template_gfp_tokens.sequence.tolist()])\n", + ")\n", + "\n", + "print(\"Structure tokens:\")\n", + "print(\n", + "    \"  \", \", \".join([str(token) for token in template_gfp_tokens.structure.tolist()])\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xUxwNTuWqxks" + }, + "source": [ + "We'll now build a prompt. 
Specifically we'll specify 4 amino acid identities at positions near where we want the chromophore to form, and 2 amino acid identities on the beta barrel that are known to support chromophore formation.\n", + "\n", + "Furthermore we'll specify the structure should be similar to the 1qy3 structure at all these positions by adding tokens from the encoded 1qy3 structure to the structure track of our prompt. We'll also specify a few more positions (along the alpha helix kink)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YBfYwRFGDKjU" + }, + "outputs": [], + "source": [ + "prompt_sequence = [\"_\"] * len(template_gfp.sequence)\n", + "prompt_sequence[59] = \"T\"\n", + "prompt_sequence[62] = \"T\"\n", + "prompt_sequence[63] = \"Y\"\n", + "prompt_sequence[64] = \"G\"\n", + "prompt_sequence[93] = \"R\"\n", + "prompt_sequence[219] = \"E\"\n", + "prompt_sequence = \"\".join(prompt_sequence)\n", + "\n", + "print(template_gfp.sequence)\n", + "print(prompt_sequence)\n", + "\n", + "prompt = model.encode(ESMProtein(sequence=prompt_sequence))\n", + "\n", + "# We construct an empty structure track like | ... |...\n", + "prompt.structure = torch.full_like(prompt.sequence, 4096)\n", + "prompt.structure[0] = 4098\n", + "prompt.structure[-1] = 4097\n", + "# ... and then we fill in structure tokens at key residues near the alpha helix\n", + "# kink and at the stabilizing R and E positions on the beta barrel.\n", + "prompt.structure[55:70] = template_gfp_tokens.structure[56:71]\n", + "prompt.structure[93] = template_gfp_tokens.structure[93]\n", + "prompt.structure[219] = template_gfp_tokens.structure[219]\n", + "\n", + "print(\"\".join([\"✔\" if st < 4096 else \"_\" for st in prompt.structure]))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "107URq1bpA4_" + }, + "source": [ + "The output shows the original 1qy3 sequence, our prompt's sequence track amino acid identities, and the positions that have a token on the structure track. ESM3 will then be tasked with filling in the structure and sequence at the remaining masked (underscore) positions.\n", + "\n", + "One small note: we introduced the mutation A93R in our prompt. This isn't a mistake. Using Alanine at this position causes the chromophore to mature extremely slowly (which is how we are able to measure the precyclized structure of GFP!). However we don't want to wait around for our GFPs to glow so we go with Arginine at this position." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DE0BvswIATN8" + }, + "source": [ + "## Generate a Structure\n", + "\n", + "We then prompt the model and decode the structure tokens track. This is similar to creating a backbone scaffold for an active site prompt, but there are some subtle differences. For example, since we've already specified some of the structure tokens (e.g., around the active site and key corresponding residues) the model literally generates around this structure.\n", + "\n", + "Tokens are iteratively sampled from ESM3. They\n", + "can be sampled one at a time, or in parallel, in any order, until all positions are fully unmasked. 
The generate() function in the EvolutionaryScale SDK implements one recipe we think is effective for sampling from the model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yatAF6kYHZdm" + }, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "num_tokens_to_decode = (prompt.structure == 4096).sum().item()\n", + "\n", + "structure_generation = model.generate(\n", + "    prompt,\n", + "    GenerationConfig(\n", + "        # Generate a structure.\n", + "        track=\"structure\",\n", + "        # Sample one token per forward pass of the model.\n", + "        num_steps=num_tokens_to_decode,\n", + "        # Sampling temperature trades perplexity with diversity.\n", + "        temperature=1.0,\n", + "    ),\n", + ")\n", + "\n", + "print(\"These are the structure tokens corresponding to our new design:\")\n", + "print(\n", + "    \"  \", \", \".join([str(token) for token in structure_generation.structure.tolist()])\n", + ")\n", + "\n", + "# Decodes structure tokens to backbone coordinates.\n", + "structure_generation_protein = model.decode(structure_generation)\n", + "\n", + "print(\"\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0HARel94tJfI" + }, + "source": [ + "Now let's visualize our generated structure. This will probably look like the familiar GFP beta barrel around an alpha helix." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D30KQC6ffrrH" + }, + "outputs": [], + "source": [ + "view = py3Dmol.view(width=1000, height=500)\n", + "view.addModel(\n", + "    structure_generation_protein.to_protein_chain().infer_oxygen().to_pdb_string(),\n", + "    \"pdb\",\n", + ")\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}})\n", + "view.zoomTo()\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K55y-5BeDRdr" + }, + "source": [ + "At this point we only want to continue the generation if this design is a close match to a wildtype GFP at the active site, has some structural difference across the full protein (otherwise it would end up being very sequence-similar to wildtype GFP), and overall still looks like the classic GFP alpha helix in a beta barrel structure.\n", + "\n", + "Of course when generating many designs we cannot look at each one manually, so we adopt some automated rejection sampling criteria based on the overall structure RMSD and the constrained site RMSD, which check that our generated structure is faithful to the prompt. If these checks pass then we'll try to design a sequence for this structure. If not, one should go back up a few cells and design another structure until it passes these computational screens. (Or not... 
this is your GFP design!)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aalzUw39t2O1" + }, + "outputs": [], + "source": [ + "constrained_site_positions = [59, 62, 63, 64, 93, 219]\n", + "\n", + "template_chain = template_gfp.to_protein_chain()\n", + "generation_chain = structure_generation_protein.to_protein_chain()\n", + "\n", + "constrained_site_rmsd = template_chain[constrained_site_positions].rmsd(\n", + " generation_chain[constrained_site_positions]\n", + ")\n", + "backbone_rmsd = template_chain.rmsd(generation_chain)\n", + "\n", + "c_pass = \"✅\" if constrained_site_rmsd < 1.5 else \"❌\"\n", + "b_pass = \"✅\" if backbone_rmsd > 1.5 else \"❌\"\n", + "\n", + "print(f\"Constrained site RMSD: {constrained_site_rmsd:.2f} Ang {c_pass}\")\n", + "print(f\"Backbone RMSD: {backbone_rmsd:.2f} Ang {b_pass}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6Bfbm8UiEqya" + }, + "source": [ + "# Sequence Design\n", + "\n", + "Now we have a backbone with some structural variation but that also matches the GFP constrained site, and we want to design a sequence that folds to this structure. We can use the prior generation, which is exactly our original prompt plus the new structure tokens representing the backbone, to prompt ESM3 again.\n", + "\n", + "One we have designed a sequence we'll want to confirm that sequence is a match for our structure, so we'll remove all other conditioning from the prompt and fold the sequence. Conveniently with ESM3, folding a sequence is simply generating a set of structure tokens conditioned on the amino acid sequence. In this case we want the model's highest confidence generation (with no diversity) so we sample with a temperature of zero." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GOrWSEVTnOq0" + }, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "num_tokens_to_decode = (prompt.sequence == 32).sum().item()\n", + "\n", + "sequence_generation = model.generate(\n", + " # Generate a sequence.\n", + " structure_generation,\n", + " GenerationConfig(track=\"sequence\", num_steps=num_tokens_to_decode, temperature=1.0),\n", + ")\n", + "\n", + "# Refold\n", + "sequence_generation.structure = None\n", + "length_of_sequence = sequence_generation.sequence.numel() - 2\n", + "sequence_generation = model.generate(\n", + " sequence_generation,\n", + " GenerationConfig(track=\"structure\", num_steps=length_of_sequence, temperature=0.0),\n", + ")\n", + "\n", + "# Decode to AA string and coordinates.\n", + "sequence_generation_protein = model.decode(sequence_generation)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v_zK7TDCzEX3" + }, + "source": [ + "We now have a candidate GFP sequence!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ao_n0-R5r2fT" + }, + "outputs": [], + "source": [ + "sequence_generation_protein.sequence" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LBvQYpR_zQAK" + }, + "source": [ + "We can align this sequence against the original template to see how similar it is to avGFP. One might also want to search against all known fluorescent proteins to assess the novelty of this potential GFP." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JuNegp37JRyD" + }, + "outputs": [], + "source": [ + "seq1 = seq.ProteinSequence(template_gfp.sequence)\n", + "seq2 = seq.ProteinSequence(sequence_generation_protein.sequence)\n", + "\n", + "alignments = align.align_optimal(\n", + "    seq1, seq2, align.SubstitutionMatrix.std_protein_matrix(), gap_penalty=(-10, -1)\n", + ")\n", + "\n", + "alignment = alignments[0]\n", + "\n", + "identity = align.get_sequence_identity(alignment)\n", + "print(f\"Sequence identity: {100*identity:.2f}%\")\n", + "\n", + "print(\"\\nSequence alignment:\")\n", + "fig = pl.figure(figsize=(8.0, 4.0))\n", + "ax = fig.add_subplot(111)\n", + "graphics.plot_alignment_similarity_based(\n", + "    ax, alignment, symbols_per_line=45, spacing=2, show_numbers=True\n", + ")\n", + "fig.tight_layout()\n", + "pl.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QQHMYfJzz13w" + }, + "source": [ + "We now recheck our computational metrics for the constrained site. If we see the constrained site is not a match then we'd want to try designing the sequence again. If many attempts to design a sequence that matches the structure fail, then it's likely the structure is not easily designable and we may want to reject this structure generation as well!\n", + "\n", + "At this point the backbone RMSD doesn't matter very much to us, so long as the sequence is adequately distant to satisfy our scientific curiosity!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HUGQ7L4_z1BV" + }, + "outputs": [], + "source": [ + "template_chain = template_gfp.to_protein_chain()\n", + "generation_chain = sequence_generation_protein.to_protein_chain()\n", + "\n", + "constrained_site_rmsd = template_chain[constrained_site_positions].rmsd(\n", + "    generation_chain[constrained_site_positions]\n", + ")\n", + "backbone_rmsd = template_chain.rmsd(generation_chain)\n", + "\n", + "c_pass = \"✅\" if constrained_site_rmsd < 1.5 else \"❌\"\n", + "b_pass = \"🤷‍♂️\"\n", + "\n", + "print(f\"Constrained site RMSD: {constrained_site_rmsd:.2f} Ang {c_pass}\")\n", + "print(f\"Backbone RMSD: {backbone_rmsd:.2f} Ang {b_pass}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0cIeC4Lg1Bz9" + }, + "source": [ + "And now we can visualize the final structure prediction of our candidate GFP design." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WTGRyt-es2sJ" + }, + "outputs": [], + "source": [ + "view = py3Dmol.view(width=600, height=600)\n", + "view.addModel(sequence_generation_protein.to_pdb_string(), \"pdb\")\n", + "view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}})\n", + "view.zoomTo()\n", + "view.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VrNuZHeHRWuP" + }, + "source": [ + "Before considering this sequence for wet lab validation, we run a joint optimization of the sequence and structure. The outputs of that process are then passed through stringent computational filters and then many designs from many starting points are ranked by a number of computational scores to select the final designs sent for testing. We'll walk through that process in a different notebook." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c3jSQrJa1Tfi" + }, + "source": [ + "If you've made it this far it's worth noting that this isn't the only method to prompt ESM3 to design a GFP, it's just the one we used to report the successful generation of esmGFP in our paper. We hope you'll try different techniques to generate from ESM3. We're interested to hear what works for you!" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/local_generate.py b/examples/local_generate.py index e6359c3..10e968c 100644 --- a/examples/local_generate.py +++ b/examples/local_generate.py @@ -1,3 +1,5 @@ +import torch + from esm.models.esm3 import ESM3 from esm.sdk.api import ( ESM3InferenceClient, @@ -38,8 +40,7 @@ def main(client: ESM3InferenceClient): protein.function_annotations = None protein = client.encode(protein) single_step_protein = client.forward_and_sample( - protein, - SamplingConfig(structure=SamplingTrackConfig(topk_logprobs=2)), + protein, SamplingConfig(structure=SamplingTrackConfig(topk_logprobs=2)) ) single_step_protein.protein_tensor.sequence = protein.sequence single_step_protein = client.decode(single_step_protein.protein_tensor) @@ -52,8 +53,7 @@ def main(client: ESM3InferenceClient): ) protein = ESMProtein(sequence=prompt) protein = client.generate( - protein, - GenerationConfig(track="sequence", num_steps=8, temperature=0.7), + protein, GenerationConfig(track="sequence", num_steps=8, temperature=0.7) ) assert isinstance(protein, ESMProtein), f"ESMProtein was expected but got {protein}" @@ -189,5 +189,6 @@ def main(client: ESM3InferenceClient): assert isinstance(p, ESMProtein), f"ESMProtein was expected but got {p}" + if __name__ == "__main__": main(ESM3.from_pretrained("esm3_sm_open_v1")) diff --git a/pyproject.toml b/pyproject.toml index 88c17a5..8ee478e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "esm" -version = "3.0.7post1" +version = "3.0.8" description = "EvolutionaryScale open model repository" readme = "README.md" requires-python = ">=3.10" diff --git a/tools/invfold.ipynb b/tools/invfold.ipynb index a818bfa..07ab941 100644 --- a/tools/invfold.ipynb +++ b/tools/invfold.ipynb @@ -53,9 +53,7 @@ "outputs": [], "source": [ "from esm.widgets.utils.types import ClientInitContainer\n", - "from esm.widgets.views.inverse_folding import (\n", - " create_inverse_folding_ui,\n", - ")\n", + "from esm.widgets.views.inverse_folding import create_inverse_folding_ui\n", "from esm.widgets.views.login import create_login_ui" ] },