From 5646303becdff09411cf8aefaa9985781afffb0c Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Fri, 3 Mar 2023 02:32:48 +0800 Subject: [PATCH] rewrite config Signed-off-by: Zhiyuan Chen --- unifold/config/__init__.py | 7 ++ unifold/config/config.py | 182 ++++++++++++++++++++++++++++ unifold/config/data.py | 234 ++++++++++++++++++++++++++++++++++++ unifold/config/globals.py | 19 +++ unifold/config/globals.yaml | 11 ++ unifold/config/loss.py | 117 ++++++++++++++++++ unifold/config/loss.yaml | 59 +++++++++ unifold/config/model.py | 189 +++++++++++++++++++++++++++++ unifold/config/model.yaml | 138 +++++++++++++++++++++ unifold/config/variables.py | 14 +++ 10 files changed, 970 insertions(+) create mode 100755 unifold/config/__init__.py create mode 100755 unifold/config/config.py create mode 100755 unifold/config/data.py create mode 100755 unifold/config/globals.py create mode 100644 unifold/config/globals.yaml create mode 100755 unifold/config/loss.py create mode 100644 unifold/config/loss.yaml create mode 100755 unifold/config/model.py create mode 100644 unifold/config/model.yaml create mode 100644 unifold/config/variables.py diff --git a/unifold/config/__init__.py b/unifold/config/__init__.py new file mode 100755 index 0000000..163201c --- /dev/null +++ b/unifold/config/__init__.py @@ -0,0 +1,7 @@ +from .config import UniFoldConfig, base_config, model_config +from .data import DataConfig +from .globals import GlobalsConfig +from .loss import LossConfig +from .model import ModelConfig + +__all__ = ["base_config", "model_config", "UniFoldConfig", "GlobalsConfig", "DataConfig", "ModelConfig", "LossConfig"] diff --git a/unifold/config/config.py b/unifold/config/config.py new file mode 100755 index 0000000..6dbfc9e --- /dev/null +++ b/unifold/config/config.py @@ -0,0 +1,182 @@ +from typing import Any, Optional +from warnings import warn + +from chanfig import Config + +from .data import DataConfig +from .globals import GlobalsConfig +from .loss import LossConfig +from .model import ModelConfig + + +class UniFoldConfig(Config): + def __init__(self, *args, **kwargs): + self.globals = GlobalsConfig() + self.data = DataConfig() + self.model = ModelConfig() + self.loss = LossConfig() + super().__init__(*args, **kwargs) + + +def recursive_set(c: Config, key: str, value: Any, ignore: Optional[str] = None): + with c.unlocked(): + for k, v in c.items(): + if ignore is not None and k == ignore: + continue + if isinstance(v, Config): + recursive_set(v, key, value) + elif k == key: + c[k] = value + + +def base_config(): + deprecation_message = "`base_config` is deprecated.\nPlease call `UniFoldConfig()` instead" + warn(deprecation_message, DeprecationWarning, stacklevel=2) + return UniFoldConfig() + + +def model_config(name, train=False): + c = UniFoldConfig() + + def model_2_v2(c): + recursive_set(c, "v2_feature", True) + recursive_set(c, "gumbel_sample", True) + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.template.template_angle_embedder.d_in = 34 + return c + + def multimer(c): + recursive_set(c, "is_multimer", True) + recursive_set(c, "max_extra_msa", 1152) + recursive_set(c, "max_msa_clusters", 128) + recursive_set(c, "v2_feature", True) + recursive_set(c, "gumbel_sample", True) + c.model.template.template_angle_embedder.d_in = 34 + c.model.template.template_pair_stack.tri_attn_first = False + c.model.template.template_pointwise_attention.enabled = False + c.model.heads.pae.enabled = True + # we forget to enable it in our training, so disable it here + c.model.heads.pae.disable_enhance_head = True + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.structure_module.trans_scale_factor = 20 + c.loss.pae.weight = 0.1 + c.model.input_embedder.tf_dim = 21 + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.chain_centre_mass.weight = 1.0 + return c + + if name == "model_1": + pass + elif name == "model_1_ft": + recursive_set(c, "max_extra_msa", 5120) + recursive_set(c, "max_msa_clusters", 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == "model_1_af2": + recursive_set(c, "max_extra_msa", 5120) + recursive_set(c, "max_msa_clusters", 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + elif name == "model_2": + pass + elif name == "model_init": + pass + elif name == "model_init_af2": + c.globals.alphafold_original_mode = True + pass + elif name == "model_2_ft": + recursive_set(c, "max_extra_msa", 1024) + recursive_set(c, "max_msa_clusters", 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == "model_2_af2": + recursive_set(c, "max_extra_msa", 1024) + recursive_set(c, "max_msa_clusters", 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + elif name == "model_2_v2": + c = model_2_v2(c) + elif name == "model_2_v2_ft": + c = model_2_v2(c) + recursive_set(c, "max_extra_msa", 1024) + recursive_set(c, "max_msa_clusters", 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == "model_3_af2" or name == "model_4_af2": + recursive_set(c, "max_extra_msa", 5120) + recursive_set(c, "max_msa_clusters", 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + c.model.template.enabled = False + c.model.template.embed_angles = False + recursive_set(c, "use_templates", False) + recursive_set(c, "use_template_torsion_angles", False) + elif name == "model_5_af2": + recursive_set(c, "max_extra_msa", 1024) + recursive_set(c, "max_msa_clusters", 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + c.model.template.enabled = False + c.model.template.embed_angles = False + recursive_set(c, "use_templates", False) + recursive_set(c, "use_template_torsion_angles", False) + elif name == "multimer": + c = multimer(c) + elif name == "multimer_ft": + c = multimer(c) + recursive_set(c, "max_extra_msa", 1152) + recursive_set(c, "max_msa_clusters", 256) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.5 + elif name == "multimer_af2": + recursive_set(c, "max_extra_msa", 1152) + recursive_set(c, "max_msa_clusters", 256) + recursive_set(c, "is_multimer", True) + recursive_set(c, "v2_feature", True) + recursive_set(c, "gumbel_sample", True) + c.model.template.template_angle_embedder.d_in = 34 + c.model.template.template_pair_stack.tri_attn_first = False + c.model.template.template_pointwise_attention.enabled = False + c.model.heads.pae.enabled = True + c.model.heads.experimentally_resolved.enabled = True + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.structure_module.trans_scale_factor = 20 + c.loss.pae.weight = 0.1 + c.loss.violation.weight = 0.5 + c.loss.experimentally_resolved.weight = 0.01 + c.model.input_embedder.tf_dim = 21 + c.globals.alphafold_original_mode = True + c.data.train.crop_size = 384 + c.loss.repr_norm.weight = 0 + c.loss.chain_centre_mass.weight = 1.0 + recursive_set(c, "outer_product_mean_first", True) + else: + raise ValueError(f"invalid --model-name: {name}.") + if train: + c.globals.chunk_size = None + recursive_set(c, "inf", 3e4) + recursive_set(c, "eps", 1e-5, "loss") + return c diff --git a/unifold/config/data.py b/unifold/config/data.py new file mode 100755 index 0000000..c907c34 --- /dev/null +++ b/unifold/config/data.py @@ -0,0 +1,234 @@ +from chanfig import Config + +N_RES = "number of residues" +N_MSA = "number of MSA sequences" +N_EXTRA_MSA = "number of extra MSA sequences" +N_TPL = "number of templates" + +from .variables import is_multimer, max_recycling_iters, use_templates + + +class DataConfig(Config): + def __init__(self, *args, **kwargs): + self.common = CommonDataConfig() + self.supervised = SupervisedDataConfig() + self.train = TrainDataConfig() + self.eval = EvalDataConfig() + self.predict = PredictDataConfig() + super().__init__(*args, **kwargs) + + +class CommonDataConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.features = FeaturesConfig() + self.masked_msa = Config( + profile_prob=0.1, + same_prob=0.1, + uniform_prob=0.1, + ) + self.block_delete_msa = Config( + msa_fraction_per_block=0.3, + randomize_num_blocks=False, + num_blocks=5, + min_num_msa=16, + ) + self.random_delete_msa = Config( + max_msa_entry=1 << 25, # := 33554432 + ) + self.v2_feature = False + self.gumbel_sample = False + self.max_extra_msa = 1024 + self.msa_cluster_features = True + self.reduce_msa_clusters_by_max_templates = True + self.resample_msa_in_recycling = True + self.template_features = [ + "template_all_atom_positions", + "template_sum_probs", + "template_aatype", + "template_all_atom_mask", + ] + self.unsupervised_features = [ + "aatype", + "residue_index", + "msa", + "msa_chains", + "num_alignments", + "seq_length", + "between_segment_residues", + "deletion_matrix", + "num_recycling_iters", + "crop_and_fix_size_seed", + ] + self.recycling_features = [ + "msa_chains", + "msa_mask", + "msa_row_mask", + "bert_mask", + "true_msa", + "msa_feat", + "extra_msa_deletion_value", + "extra_msa_has_deletion", + "extra_msa", + "extra_msa_mask", + "extra_msa_row_mask", + "is_distillation", + ] + self.multimer_features = [ + "assembly_num_chains", + "asym_id", + "sym_id", + "num_sym", + "entity_id", + "asym_len", + "cluster_bias_mask", + ] + self.use_templates = use_templates + self.is_multimer = is_multimer + self.use_template_torsion_angles = use_templates + self.max_recycling_iters = max_recycling_iters + + +class FeaturesConfig(Config): + def __init__(self): + self.aatype = [N_RES] + self.all_atom_mask = [N_RES, None] + self.all_atom_positions = [N_RES, None, None] + self.alt_chi_angles = [N_RES, None] + self.atom14_alt_gt_exists = [N_RES, None] + self.atom14_alt_gt_positions = [N_RES, None, None] + self.atom14_atom_exists = [N_RES, None] + self.atom14_atom_is_ambiguous = [N_RES, None] + self.atom14_gt_exists = [N_RES, None] + self.atom14_gt_positions = [N_RES, None, None] + self.atom37_atom_exists = [N_RES, None] + self.frame_mask = [N_RES] + self.true_frame_tensor = [N_RES, None, None] + self.bert_mask = [N_MSA, N_RES] + self.chi_angles_sin_cos = [N_RES, None, None] + self.chi_mask = [N_RES, None] + self.extra_msa_deletion_value = [N_EXTRA_MSA, N_RES] + self.extra_msa_has_deletion = [N_EXTRA_MSA, N_RES] + self.extra_msa = [N_EXTRA_MSA, N_RES] + self.extra_msa_mask = [N_EXTRA_MSA, N_RES] + self.extra_msa_row_mask = [N_EXTRA_MSA] + self.is_distillation = [] + self.msa_feat = [N_MSA, N_RES, None] + self.msa_mask = [N_MSA, N_RES] + self.msa_chains = [N_MSA, None] + self.msa_row_mask = [N_MSA] + self.num_recycling_iters = [] + self.pseudo_beta = [N_RES, None] + self.pseudo_beta_mask = [N_RES] + self.residue_index = [N_RES] + self.residx_atom14_to_atom37 = [N_RES, None] + self.residx_atom37_to_atom14 = [N_RES, None] + self.resolution = [] + self.rigidgroups_alt_gt_frames = [N_RES, None, None, None] + self.rigidgroups_group_exists = [N_RES, None] + self.rigidgroups_group_is_ambiguous = [N_RES, None] + self.rigidgroups_gt_exists = [N_RES, None] + self.rigidgroups_gt_frames = [N_RES, None, None, None] + self.seq_length = [] + self.seq_mask = [N_RES] + self.target_feat = [N_RES, None] + self.template_aatype = [N_TPL, N_RES] + self.template_all_atom_mask = [N_TPL, N_RES, None] + self.template_all_atom_positions = [N_TPL, N_RES, None, None] + self.template_alt_torsion_angles_sin_cos = [ + N_TPL, + N_RES, + None, + None, + ] + + self.template_frame_mask = [N_TPL, N_RES] + self.template_frame_tensor = [N_TPL, N_RES, None, None] + self.template_mask = [N_TPL] + self.template_pseudo_beta = [N_TPL, N_RES, None] + self.template_pseudo_beta_mask = [N_TPL, N_RES] + self.template_sum_probs = [N_TPL, None] + self.template_torsion_angles_mask = [N_TPL, N_RES, None] + self.template_torsion_angles_sin_cos = [N_TPL, N_RES, None, None] + self.true_msa = [N_MSA, N_RES] + self.use_clamped_fape = [] + self.assembly_num_chains = [1] + self.asym_id = [N_RES] + self.sym_id = [N_RES] + self.entity_id = [N_RES] + self.num_sym = [N_RES] + self.asym_len = [None] + self.cluster_bias_mask = [N_MSA] + + +class SupervisedDataConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.use_clamped_fape_prob = 1.0 + self.supervised_features = [ + "all_atom_mask", + "all_atom_positions", + "resolution", + "use_clamped_fape", + "is_distillation", + ] + + +class TrainDataConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fixed_size = True + self.subsample_templates = True + self.block_delete_msa = True + self.random_delete_msa = True + self.masked_msa_replace_fraction = 0.15 + self.max_msa_clusters = 128 + self.max_templates = 4 + self.num_ensembles = 1 + self.crop = True + self.crop_size = 256 + self.spatial_crop_prob = 0.5 + self.ca_ca_threshold = 10.0 + self.supervised = True + self.use_clamped_fape_prob = 1.0 + self.max_distillation_msa_clusters = 1000 + self.biased_msa_by_chain = True + self.share_mask = True + + +class EvalDataConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fixed_size = True + self.subsample_templates = False + self.block_delete_msa = False + self.random_delete_msa = True + self.masked_msa_replace_fraction = 0.15 + self.max_msa_clusters = 128 + self.max_templates = 4 + self.num_ensembles = 1 + self.crop = False + self.crop_size = None + self.spatial_crop_prob = 0.5 + self.ca_ca_threshold = 10.0 + self.supervised = True + self.biased_msa_by_chain = False + self.share_mask = False + + +class PredictDataConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fixed_size = True + self.subsample_templates = False + self.block_delete_msa = False + self.random_delete_msa = True + self.masked_msa_replace_fraction = 0.15 + self.max_msa_clusters = 128 + self.max_templates = 4 + self.num_ensembles = 2 + self.crop = False + self.crop_size = None + self.supervised = False + self.biased_msa_by_chain = False + self.share_mask = False diff --git a/unifold/config/globals.py b/unifold/config/globals.py new file mode 100755 index 0000000..ba133be --- /dev/null +++ b/unifold/config/globals.py @@ -0,0 +1,19 @@ +from chanfig import Config + +from .variables import chunk_size, d_extra_msa, d_msa, d_pair, d_single, d_template, eps, inf, max_recycling_iters + + +class GlobalsConfig(Config): + def __init__(self, *args, **kwargs): + self.chunk_size = chunk_size + self.block_size = None + self.d_pair = d_pair + self.d_msa = d_msa + self.d_template = d_template + self.d_extra_msa = d_extra_msa + self.d_single = d_single + self.eps = eps + self.inf = inf + self.max_recycling_iters = max_recycling_iters + self.alphafold_original_mode = False + super().__init__(*args, **kwargs) diff --git a/unifold/config/globals.yaml b/unifold/config/globals.yaml new file mode 100644 index 0000000..87797d0 --- /dev/null +++ b/unifold/config/globals.yaml @@ -0,0 +1,11 @@ +alphafold_original_mode: false +block_size: null +chunk_size: 4 +d_extra_msa: 64 +d_msa: 256 +d_pair: 128 +d_single: 384 +d_template: 64 +eps: 1.0e-08 +inf: 30000.0 +max_recycling_iters: 3 diff --git a/unifold/config/loss.py b/unifold/config/loss.py new file mode 100755 index 0000000..b8d891c --- /dev/null +++ b/unifold/config/loss.py @@ -0,0 +1,117 @@ +from chanfig import Config + + +class LossConfig(Config): + def __init__(self, *args, **kwargs): + self.distogram = DistogramLossConfig() + self.experimentally_resolved = ExperimentallyResolvedLossConfig() + self.fape = FAPELossConfig() + self.plddt = PLDDTLossConfig() + self.masked_msa = MaskedMSALossConfig() + self.supervised_chi = SupervisedChiLossConfig() + self.violation = ViolationLossConfig() + self.pae = PAELossConfig() + self.repr_norm = ReprNormLossConfig() + self.chain_centre_mass = ChainCentreMassLossConfig() + super().__init__(*args, **kwargs) + + +class DistogramLossConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.min_bin = 2.3125 + self.max_bin = 21.6875 + self.num_bins = 64 + self.eps = 1e-6 + self.weight = 0.3 + + +class ExperimentallyResolvedLossConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.eps = 1e-8 + self.min_resolution = 0.1 + self.max_resolution = 3.0 + self.weight = 0.0 + + +class FAPELossConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.backbone = Config( + clamp_distance=10.0, + clamp_distance_between_chains=30.0, + loss_unit_distance=10.0, + loss_unit_distance_between_chains=20.0, + weight=0.5, + eps=1e-4, + ) + self.sidechain = Config( + clamp_distance=10.0, + length_scale=10.0, + weight=0.5, + eps=1e-4, + ) + self.weight = 1.0 + + +class PLDDTLossConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.min_resolution = 0.1 + self.max_resolution = 3.0 + self.cutoff = 15.0 + self.num_bins = 50 + self.eps = 1e-10 + self.weight = 0.01 + + +class MaskedMSALossConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.eps = 1e-8 + self.weight = 2.0 + + +class SupervisedChiLossConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.chi_weight = 0.5 + self.angle_norm_weight = 0.01 + self.eps = 1e-6 + self.weight = 1.0 + + +class ViolationLossConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.violation_tolerance_factor = 12.0 + self.clash_overlap_tolerance = 1.5 + self.bond_angle_loss_weight = 0.3 + self.eps = 1e-6 + self.weight = 0.0 + + +class PAELossConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_bin = 31 + self.num_bins = 64 + self.min_resolution = 0.1 + self.max_resolution = 3.0 + self.eps = 1e-8 + self.weight = 0.0 + + +class ReprNormLossConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.weight = 0.01 + self.tolerance = 1.0 + + +class ChainCentreMassLossConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.weight = 0.0 + self.eps = 1e-8 diff --git a/unifold/config/loss.yaml b/unifold/config/loss.yaml new file mode 100644 index 0000000..345e720 --- /dev/null +++ b/unifold/config/loss.yaml @@ -0,0 +1,59 @@ +chain_centre_mass: + eps: 1.0e-08 + weight: 0.0 +distogram: + eps: 1.0e-06 + max_bin: 21.6875 + min_bin: 2.3125 + num_bins: 64 + weight: 0.3 +experimentally_resolved: + eps: 1.0e-08 + max_resolution: 3.0 + min_resolution: 0.1 + weight: 0.0 +fape: + backbone: + clamp_distance: 10.0 + clamp_distance_between_chains: 30.0 + eps: 0.0001 + loss_unit_distance: 10.0 + loss_unit_distance_between_chains: 20.0 + weight: 0.5 + sidechain: + clamp_distance: 10.0 + eps: 0.0001 + length_scale: 10.0 + weight: 0.5 + weight: 1.0 +masked_msa: + eps: 1.0e-08 + weight: 2.0 +pae: + eps: 1.0e-08 + max_bin: 31 + max_resolution: 3.0 + min_resolution: 0.1 + num_bins: 64 + weight: 0.0 +plddt: + cutoff: 15.0 + eps: 1.0e-10 + max_resolution: 3.0 + min_resolution: 0.1 + num_bins: 50 + weight: 0.01 +repr_norm: + tolerance: 1.0 + weight: 0.01 +supervised_chi: + angle_norm_weight: 0.01 + chi_weight: 0.5 + eps: 1.0e-06 + weight: 1.0 +violation: + bond_angle_loss_weight: 0.3 + clash_overlap_tolerance: 1.5 + eps: 1.0e-06 + violation_tolerance_factor: 12.0 + weight: 0.0 diff --git a/unifold/config/model.py b/unifold/config/model.py new file mode 100755 index 0000000..3c41287 --- /dev/null +++ b/unifold/config/model.py @@ -0,0 +1,189 @@ +from chanfig import Config + +from .variables import aux_distogram_bins, d_extra_msa, d_msa, d_pair, d_single, d_template, is_multimer, use_templates + + +class ModelConfig(Config): + def __init__(self, *args, **kwargs): + self.is_multimer = is_multimer + self.input_embedder = Config( + tf_dim=22, + msa_dim=49, + d_pair=d_pair, + d_msa=d_msa, + relpos_k=32, + max_relative_chain=2, + ) + self.recycling_embedder = Config( + d_pair=d_pair, + d_msa=d_msa, + min_bin=3.25, + max_bin=20.75, + num_bins=15, + inf=1e8, + ) + self.template = TemplateConfig() + self.extra_msa = ExtraMSAConfig() + self.evoformer_stack = Config( + d_msa=d_msa, + d_pair=d_pair, + d_hid_msa_att=32, + d_hid_opm=32, + d_hid_mul=128, + d_hid_pair_att=32, + d_single=d_single, + num_heads_msa=8, + num_heads_pair=4, + num_blocks=48, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.25, + inf=1e9, + eps=1e-10, + outer_product_mean_first=False, + ) + self.structure_module = Config( + d_single=d_single, + d_pair=d_pair, + d_ipa=16, + d_angle=128, + num_heads_ipa=12, + num_qk_points=4, + num_v_points=8, + dropout_rate=0.1, + num_blocks=8, + no_transition_layers=1, + num_resnet_blocks=2, + num_angles=7, + trans_scale_factor=10, + epsilon=1e-12, + inf=1e5, + separate_kv=False, + ipa_bias=True, + ) + self.heads = HeadsConfig() + super().__init__(*args, **kwargs) + + +class TemplateConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.distogram = Config( + min_bin=3.25, + max_bin=50.75, + num_bins=39, + ) + self.template_angle_embedder = Config( + d_in=57, + d_out=d_msa, + ) + self.template_pair_embedder = Config( + d_in=88, + v2_d_in=[39, 1, 22, 22, 1, 1, 1, 1], + d_pair=d_pair, + d_out=d_template, + v2_feature=False, + ) + self.template_pair_stack = Config( + d_template=d_template, + d_hid_tri_att=16, + d_hid_tri_mul=64, + num_blocks=2, + num_heads=4, + pair_transition_n=2, + dropout_rate=0.25, + inf=1e9, + tri_attn_first=True, + ) + self.template_pointwise_attention = Config( + enabled=True, + d_template=d_template, + d_pair=d_pair, + d_hid=16, + num_heads=4, + inf=1e5, + ) + self.inf = 1e5 + self.eps = 1e-6 + self.enabled = use_templates + self.embed_angles = use_templates + + +class ExtraMSAConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.extra_msa_embedder = Config( + d_in=25, + d_out=d_extra_msa, + ) + self.extra_msa_stack = Config( + d_msa=d_extra_msa, + d_pair=d_pair, + d_hid_msa_att=8, + d_hid_opm=32, + d_hid_mul=128, + d_hid_pair_att=32, + num_heads_msa=8, + num_heads_pair=4, + num_blocks=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.25, + inf=1e9, + eps=1e-10, + outer_product_mean_first=False, + ) + self.enabled = True + + +class HeadsConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.plddt = PLDDTHeadConfig() + self.distogram = DistogramHeadConfig() + self.pae = PAEHeadConfig() + self.masked_msa = MaskedMSAtHeadConfig() + self.experimentally_resolved = ExperimentallyResolvedHeadConfig() + + +class PLDDTHeadConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_bins = 50 + self.d_in = d_single + self.d_hid = 128 + + +class DistogramHeadConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.d_pair = d_pair + self.num_bins = aux_distogram_bins + self.disable_enhance_head = False + + +class PAEHeadConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.d_pair = d_pair + self.num_bins = aux_distogram_bins + self.enabled = False + self.iptm_weight = 0.8 + self.disable_enhance_head = False + + +class MaskedMSAtHeadConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.d_msa = d_msa + self.d_out = 23 + self.disable_enhance_head = False + + +class ExperimentallyResolvedHeadConfig(Config): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.d_single = d_single + self.d_out = 37 + self.enabled = False + self.disable_enhance_head = False diff --git a/unifold/config/model.yaml b/unifold/config/model.yaml new file mode 100644 index 0000000..bc4f53d --- /dev/null +++ b/unifold/config/model.yaml @@ -0,0 +1,138 @@ +evoformer_stack: + d_hid_msa_att: 32 + d_hid_mul: 128 + d_hid_opm: 32 + d_hid_pair_att: 32 + d_msa: 256 + d_pair: 128 + d_single: 384 + eps: 1.0e-10 + inf: 1000000000.0 + msa_dropout: 0.15 + num_blocks: 48 + num_heads_msa: 8 + num_heads_pair: 4 + outer_product_mean_first: false + pair_dropout: 0.25 + transition_n: 4 +extra_msa: + enabled: true + extra_msa_embedder: + d_in: 25 + d_out: 64 + extra_msa_stack: + d_hid_msa_att: 8 + d_hid_mul: 128 + d_hid_opm: 32 + d_hid_pair_att: 32 + d_msa: 64 + d_pair: 128 + eps: 1.0e-10 + inf: 1000000000.0 + msa_dropout: 0.15 + num_blocks: 4 + num_heads_msa: 8 + num_heads_pair: 4 + outer_product_mean_first: false + pair_dropout: 0.25 + transition_n: 4 +heads: + distogram: + d_pair: 128 + disable_enhance_head: false + num_bins: 64 + experimentally_resolved: + d_out: 37 + d_single: 384 + disable_enhance_head: false + enabled: false + masked_msa: + d_msa: 256 + d_out: 23 + disable_enhance_head: false + pae: + d_pair: 128 + disable_enhance_head: false + enabled: false + iptm_weight: 0.8 + num_bins: 64 + plddt: + d_hid: 128 + d_in: 384 + num_bins: 50 +input_embedder: + d_msa: 256 + d_pair: 128 + max_relative_chain: 2 + msa_dim: 49 + relpos_k: 32 + tf_dim: 22 +is_multimer: false +recycling_embedder: + d_msa: 256 + d_pair: 128 + inf: 100000000.0 + max_bin: 20.75 + min_bin: 3.25 + num_bins: 15 +structure_module: + d_angle: 128 + d_ipa: 16 + d_pair: 128 + d_single: 384 + dropout_rate: 0.1 + epsilon: 1.0e-12 + inf: 100000.0 + ipa_bias: true + no_transition_layers: 1 + num_angles: 7 + num_blocks: 8 + num_heads_ipa: 12 + num_qk_points: 4 + num_resnet_blocks: 2 + num_v_points: 8 + separate_kv: false + trans_scale_factor: 10 +template: + distogram: + max_bin: 50.75 + min_bin: 3.25 + num_bins: 39 + embed_angles: true + enabled: true + eps: 1.0e-06 + inf: 100000.0 + template_angle_embedder: + d_in: 57 + d_out: 256 + template_pair_embedder: + d_in: 88 + d_out: 64 + d_pair: 128 + v2_d_in: + - 39 + - 1 + - 22 + - 22 + - 1 + - 1 + - 1 + - 1 + v2_feature: false + template_pair_stack: + d_hid_tri_att: 16 + d_hid_tri_mul: 64 + d_template: 64 + dropout_rate: 0.25 + inf: 1000000000.0 + num_blocks: 2 + num_heads: 4 + pair_transition_n: 2 + tri_attn_first: true + template_pointwise_attention: + d_hid: 16 + d_pair: 128 + d_template: 64 + enabled: true + inf: 100000.0 + num_heads: 4 diff --git a/unifold/config/variables.py b/unifold/config/variables.py new file mode 100644 index 0000000..dae6af0 --- /dev/null +++ b/unifold/config/variables.py @@ -0,0 +1,14 @@ +from chanfig import Variable + +d_pair = Variable(128) +d_msa = Variable(256) +d_template = Variable(64) +d_extra_msa = Variable(64) +d_single = Variable(384) +max_recycling_iters = Variable(3) +chunk_size = Variable(4) +aux_distogram_bins = Variable(64) +eps = Variable(1e-8) +inf = Variable(3e4) +use_templates = Variable(True) +is_multimer = Variable(False)