
v3.0.8 (#146)
Co-authored-by: Zeming Lin <[email protected]>
ebetica and Zeming Lin authored Nov 25, 2024
1 parent 39a3a6c commit 1561962
Showing 62 changed files with 2,325 additions and 2,500 deletions.
3 changes: 2 additions & 1 deletion esm/__init__.py
@@ -1 +1,2 @@
__version__ = "3.0.7post1"
__version__ = "3.0.8"

6 changes: 1 addition & 5 deletions esm/layers/attention.py
@@ -10,11 +10,7 @@

class MultiHeadAttention(nn.Module):
def __init__(
self,
d_model: int,
n_heads: int,
bias: bool = False,
qk_layernorm: bool = True,
self, d_model: int, n_heads: int, bias: bool = False, qk_layernorm: bool = True
):
super().__init__()

6 changes: 1 addition & 5 deletions esm/layers/geom_attention.py
@@ -78,11 +78,7 @@ def forward(self, s, affine, affine_mask, sequence_id, chain_id):
affine.rot[..., None]
.apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
.split(
[
self.v_heads,
self.v_heads,
self.v_heads * self.num_vector_messages,
],
[self.v_heads, self.v_heads, self.v_heads * self.num_vector_messages],
dim=-2,
)
)
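For reference, the reformatted call above still splits the rotated vectors into query, key, and value heads by size along the head dimension. A standalone sketch of that split pattern in plain torch, with made-up head counts for illustration:

    import torch

    v_heads, num_vector_messages = 4, 2   # illustrative values, not the model's config
    x = torch.randn(2, 10, 2 * v_heads + v_heads * num_vector_messages, 3)
    q_vec, k_vec, v_vec = x.split(
        [v_heads, v_heads, v_heads * num_vector_messages], dim=-2
    )
    # q_vec: (2, 10, 4, 3), k_vec: (2, 10, 4, 3), v_vec: (2, 10, 8, 3)
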
4 changes: 1 addition & 3 deletions esm/layers/regression_head.py
@@ -2,9 +2,7 @@


def RegressionHead(
d_model: int,
output_dim: int,
hidden_dim: int | None = None,
d_model: int, output_dim: int, hidden_dim: int | None = None
) -> nn.Module:
"""Single-hidden layer MLP for supervised output.
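RegressionHead is documented above as a single-hidden-layer MLP factory returning an nn.Module. A minimal usage sketch, with illustrative dimensions and hidden_dim left at its default:

    import torch
    from esm.layers.regression_head import RegressionHead

    head = RegressionHead(d_model=512, output_dim=1)   # illustrative sizes
    scores = head(torch.randn(2, 16, 512))             # -> (2, 16, 1), one scalar per position
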
4 changes: 1 addition & 3 deletions esm/layers/structure_proj.py
@@ -1,9 +1,7 @@
import torch
import torch.nn as nn

from esm.utils.constants.physics import (
BB_COORDINATES,
)
from esm.utils.constants.physics import BB_COORDINATES
from esm.utils.structure.affine3d import (
Affine3D,
RotationMatrix,
47 changes: 14 additions & 33 deletions esm/models/esm3.py
@@ -29,9 +29,7 @@
ProteinType,
SamplingConfig,
)
from esm.tokenization import (
TokenizerCollectionProtocol,
)
from esm.tokenization import TokenizerCollectionProtocol
from esm.utils import encoding
from esm.utils.constants import esm3 as C
from esm.utils.constants.models import (
@@ -173,11 +171,7 @@ def forward(self, x: torch.Tensor, embed: torch.Tensor) -> ESMOutput:
secondary_structure_logits = self.ss8_head(x)
sasa_logits = self.sasa_head(x)
function_logits = self.function_head(x)
function_logits = einops.rearrange(
function_logits,
"... (k v) -> ... k v",
k=8,
)
function_logits = einops.rearrange(function_logits, "... (k v) -> ... k v", k=8)

residue_logits = self.residue_head(x)

@@ -217,11 +211,7 @@ def __init__(
super().__init__()
self.encoder = EncodeInputs(d_model)
self.transformer = TransformerStack(
d_model,
n_heads,
v_heads,
n_layers,
mask_and_zero_frameless=True,
d_model, n_heads, v_heads, n_layers, mask_and_zero_frameless=True
)
self.output_heads = OutputHeads(d_model)

@@ -237,9 +227,7 @@ def __init__(

@classmethod
def from_pretrained(
cls,
model_name: str = ESM3_OPEN_SMALL,
device: torch.device | None = None,
cls, model_name: str = ESM3_OPEN_SMALL, device: torch.device | None = None
) -> ESM3:
from esm.pretrained import load_local_model

@@ -489,15 +477,14 @@ def encode(self, input: ESMProtein) -> ESMProteinTensor:
reference_sequence = encoding.get_default_sequence(sequence_length - 2)
else:
reference_sequence = input.sequence
(
function_tokens,
residue_annotation_tokens,
) = encoding.tokenize_function_annotations(
input.function_annotations,
reference_sequence=reference_sequence,
function_tokenizer=self.tokenizers.function,
residue_annotation_tokenizer=self.tokenizers.residue_annotations,
add_special_tokens=True,
(function_tokens, residue_annotation_tokens) = (
encoding.tokenize_function_annotations(
input.function_annotations,
reference_sequence=reference_sequence,
function_tokenizer=self.tokenizers.function,
residue_annotation_tokenizer=self.tokenizers.residue_annotations,
add_special_tokens=True,
)
)

return ESMProteinTensor(
@@ -510,10 +497,7 @@ def encode(self, input: ESMProtein) -> ESMProteinTensor:
coordinates=coordinates,
).to(next(self.parameters()).device)

def decode(
self,
input: ESMProteinTensor,
) -> ESMProtein:
def decode(self, input: ESMProteinTensor) -> ESMProtein:
return decode_protein_tensor(
input=input,
tokenizers=self.tokenizers,
@@ -613,10 +597,7 @@ def forward_and_sample(

logits_output: LogitsOutput = _batch_forward(self, batched_protein)
forward_and_sample_out: ForwardAndSampleOutput = _sample_per_prompt(
batched_protein,
logits_output,
sampling_config,
self.tokenizers,
batched_protein, logits_output, sampling_config, self.tokenizers
)

# There is only 1 prompt to sample for.
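Taken together, the esm/models/esm3.py hunks above cover the client surface touched by this commit (from_pretrained, encode, decode, forward_and_sample). A minimal end-to-end sketch under the open-model defaults; the example sequence, the use of "_" as a mask character, and the generate() call are assumptions about the SDK's documented usage, not part of this diff:

    import torch
    from esm.models.esm3 import ESM3
    from esm.sdk.api import ESMProtein, GenerationConfig

    # model_name defaults to ESM3_OPEN_SMALL; device=None lets the loader choose.
    model = ESM3.from_pretrained(device=torch.device("cpu"))

    # Illustrative prompt; "_" marks positions for the model to fill in.
    protein = ESMProtein(sequence="MKTAYIAKQR____LEERLGL___Q")
    tensors = model.encode(protein)      # ESMProtein -> ESMProteinTensor
    roundtrip = model.decode(tensors)    # ESMProteinTensor -> ESMProtein

    # generate() belongs to the ESM3InferenceClient interface (see esm/sdk/api.py below).
    generated = model.generate(protein, GenerationConfig(track="sequence", num_steps=4))
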
6 changes: 2 additions & 4 deletions esm/models/function_decoder.py
@@ -167,8 +167,7 @@ def forward(self, token_ids: torch.Tensor) -> dict[str, torch.Tensor]:
# Apply depth-position offset to use distinct vocabs. See __init__ for
# explanation.
vocab_offsets = self.config.function_token_vocab_size * torch.arange(
self.config.function_token_depth,
device=token_ids.device,
self.config.function_token_depth, device=token_ids.device
)
inputs = token_ids + vocab_offsets[None, :]

@@ -251,8 +250,7 @@ def decode(
annotations.append(annotation)

annotations = merge_annotations(
annotations,
merge_gap_max=annotation_gap_merge_max,
annotations, merge_gap_max=annotation_gap_merge_max
)

# Drop very small annotations.
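The first hunk above keeps the depth-position vocab offset: token IDs at depth d are shifted by d * function_token_vocab_size so each depth indexes a distinct slice of the embedding table. A standalone sketch of that arithmetic, with illustrative config values rather than the model's real ones:

    import torch

    function_token_vocab_size, function_token_depth = 260, 8   # illustrative values
    token_ids = torch.randint(0, function_token_vocab_size, (2, function_token_depth))

    # Offset depth d by d * vocab_size so every depth gets its own vocab slice.
    vocab_offsets = function_token_vocab_size * torch.arange(function_token_depth)
    inputs = token_ids + vocab_offsets[None, :]    # shape (batch, depth)
    assert inputs.max() < function_token_vocab_size * function_token_depth
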
18 changes: 3 additions & 15 deletions esm/models/vqvae.py
@@ -87,10 +87,7 @@ def forward(self, x, pairwise: torch.Tensor | None = None):

prod = q[:, None, :, :] * k[:, :, None, :]
diff = q[:, None, :, :] - k[:, :, None, :]
x_2d = [
prod,
diff,
]
x_2d = [prod, diff]
if pairwise is not None:
x_2d.append(pairwise)
x = torch.cat(x_2d, dim=-1)
@@ -289,11 +286,7 @@ def find_knn_edges(
with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # type: ignore
ca = coords[..., 1, :]
edges, edge_mask = knn_graph(
ca,
coord_mask,
padding_mask,
sequence_id,
no_knn=knn,
ca, coord_mask, padding_mask, sequence_id, no_knn=knn
)

return edges, edge_mask
@@ -333,12 +326,7 @@ def encode(


class StructureTokenDecoder(nn.Module):
def __init__(
self,
d_model,
n_heads,
n_layers,
):
def __init__(self, d_model, n_heads, n_layers):
super().__init__()
self.decoder_channels = d_model

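The first vqvae.py hunk above builds pairwise features by broadcasting an elementwise product and difference over all residue pairs before concatenation. A shape-only sketch of that pattern, with illustrative dimensions:

    import torch

    batch, seq_len, dim = 2, 16, 32                  # illustrative sizes
    q = torch.randn(batch, seq_len, dim)
    k = torch.randn(batch, seq_len, dim)

    # Broadcast to all (i, j) residue pairs: both results are (batch, seq, seq, dim).
    prod = q[:, None, :, :] * k[:, :, None, :]
    diff = q[:, None, :, :] - k[:, :, None, :]
    x_2d = torch.cat([prod, diff], dim=-1)           # (batch, seq, seq, 2 * dim)
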
43 changes: 30 additions & 13 deletions esm/sdk/api.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC
from typing import Sequence
from typing import List, Sequence

import attr
import torch
@@ -19,14 +19,10 @@
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.structure.protein_complex import ProteinComplex
from esm.utils.types import (
FunctionAnnotation,
PathOrBuffer,
)
from esm.utils.types import FunctionAnnotation, PathOrBuffer


class ProteinType(ABC):
...
class ProteinType(ABC): ...


## Basic Types
@@ -184,6 +180,9 @@ class ESMProteinTensor(ProteinType):
# Such sequences may not go through standard safety filter for approved users.
# Reach out if interested in using this.
potential_sequence_of_concern: bool = False
# Control vectors are added to each layer of the model to nudge hidden states in the desired direction.
# len(control_vectors) == number of blocks in the model. Each vector in the list has shape (batch size, sequence length, hidden dim)
# so it can be added to the corresponding layer of the model.

def _detect_attribute(self, func, msg):
mapped = {
@@ -260,20 +259,40 @@ class ESMProteinError(Exception, ProteinType):
class GenerationConfig:
track: str = ""
invalid_ids: Sequence[int] = []
schedule: str = "cosine"
# Controls the number of tokens to unmask during each round of iterative generation.
schedule: str = attr.field(
validator=attr.validators.in_(["cosine", "linear"]), default="cosine"
)
# Controls which tokens to unmask during each round of iterative generation.
# "random" will unmask a correct number of tokens randomly.
# "entropy" will unmask the tokens with the lowest logit entropy first.
strategy: str = attr.field(
validator=attr.validators.in_(["random", "entropy"]), default="entropy"
)
# Set this to a higher value for better generation results.
# Note that this needs to be less than or equal to the sequence length.
num_steps: int = 1
temperature: float = 1.0
temperature_annealing: bool = False
top_p: float = 1.0
condition_on_coordinates_only: bool = True

def use_entropy_based_unmasking_strategy(self):
"""Use entropy based unmasking strategy during generation."""
self.schedule = "cosine"
self.strategy = "entropy"
self.temperature_annealing = False

def use_generative_unmasking_strategy(self):
"""Use an unmasking strategy that produces more variety of generations."""
self.schedule = "cosine"
self.strategy = "random"
self.temperature_annealing = True


@define
class InverseFoldingConfig:
invalid_ids: Sequence[int] = []
schedule: str = "cosine"
num_steps: int = 1
temperature: float = 1.0


@@ -370,9 +389,7 @@ def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
raise NotImplementedError

def batch_generate(
self,
inputs: Sequence[ProteinType],
configs: Sequence[GenerationConfig],
self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
) -> Sequence[ProteinType]:
# Same as generate(...), but generates a batch of proteins at once.
raise NotImplementedError
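For context, a hedged sketch of how the validated GenerationConfig fields and the strategy helpers added above might be used; the client instance is assumed to come from elsewhere (e.g. ESM3.from_pretrained()) and is not part of this diff:

    from esm.sdk.api import GenerationConfig

    # The attrs validators now reject schedules outside {"cosine", "linear"}
    # and strategies outside {"random", "entropy"}.
    config = GenerationConfig(
        track="sequence", schedule="cosine", strategy="entropy", num_steps=8
    )

    # Switch to the higher-variety preset introduced in this commit:
    # cosine schedule, random unmasking order, temperature annealing on.
    config.use_generative_unmasking_strategy()

    # protein = client.generate(protein, config)   # client: any ESM3InferenceClient
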