Architecture

Overview

The central focus of tmol is the PoseStack - a batch of structures. At its heart, tmol is a library for creating, scoring, manipulating, and exporting PoseStacks.

Creating PoseStacks

Under the hood, all PoseStack creation goes through a common function. Other PoseStack creation functions, such as loading from a PDB file or importing from RosettaFold2 or OpenFold, work by first converting the source data into a common representation - the CanonicalForm.

CanonicalForm

Because tmol represents structures with a higher chemical granularity than most other ML packages, it has to resolve the chemical structure of the molecules coming into it from other sources. Even the process of reading in a PDB file requires this chemical type resolution.

The CanonicalForm is a structure batch format that lets us represent data in tmol while deferring the chemical resolution. This makes loading the source data into tmol easier, and lets the chemical type resolution step use the same machinery regardless of data source. CanonicalForms are also stable, and can be serialized to disk and loaded years later to make a PoseStack.
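
As a sketch of this round trip (the loader and constructor names here are illustrative assumptions, not necessarily the exact tmol API, and we assume the CanonicalForm pickles cleanly):

```python
import torch

# Illustrative sketch; the loader name is an assumption.
canonical_form = canonical_form_from_pdb("1ubq.pdb")

# The CanonicalForm is a stable batch of tensors, so it can be written
# to disk and reloaded years later before being resolved into a
# PoseStack (see Chemical Type Resolution below).
torch.save(canonical_form, "1ubq.canonical_form.pt")
canonical_form = torch.load("1ubq.canonical_form.pt")
```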

Converting to CanonicalForm

The process of converting input data into a CanonicalForm varies by the input source.

Note

TODO: do we want to say anything about this? Maybe just delete this section?

Chemical Type Resolution

Once a batch of structures has been converted into a CanonicalForm, it can be combined with two other data-source specific objects to resolve the chemical structure and create a PoseStack:

  • The PackedBlockTypes object, which contains chemical information for the set of chemical types used by the data source.
  • The CanonicalOrdering object, which describes the mapping of chemical types to integers and, for each type, the mapping of atom names to unique integers.

Variants of these two objects for each data source are stored in the database, and more can be added in order to support PoseStack creation from new data sources.
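
A sketch of the resolution step; the accessor names and the argument order are assumptions, not the exact tmol signature:

```python
import torch

# Hypothetical accessors for the database-stored, data-source specific
# variants of the two objects described above.
device = torch.device("cuda")
canonical_ordering = default_canonical_ordering()
packed_block_types = default_packed_block_types(device)

# All creation paths funnel through the common resolution function.
pose_stack = pose_stack_from_canonical_form(
    canonical_ordering, packed_block_types, canonical_form
)
```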

Note

The mappings in the CanonicalOrdering object for a data source must not change over time; these mappings are what allow a saved CanonicalForm to be loaded years after it was originally created.

Scoring PoseStacks

tmol can evaluate the energy of a PoseStack using a ScoreFunction that is composed of one or more EnergyTerms. The ScoreFunction returns a weighted sum of the EnergyTerm energies, either on a whole-pose or block-pair basis.
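
Concretely, for a pose $p$ and active terms $t$ with weights $w_t$:

$$E(p) = \sum_{t} w_t \, E_t(p)$$

In block-pair mode, each $E_t(p)$ is reported per pair of blocks rather than pre-summed.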

ScoreFunction evaluation is separated into several steps:

  1. Precomputation of various EnergyTerm data needed by scoring.
  2. Packing of that data into tensors for efficient use on the GPU.
  3. Rendering of the torch modules.
  4. Calling the rendered ScoreFunction/EnergyTerms on the PoseStack coords.
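
These four steps correspond roughly to the following usage. This is an illustrative sketch: the ScoreFunction construction, term names, and render method shown here are assumptions, not necessarily tmol's exact API.

```python
import torch

# Illustrative sketch of the scoring pipeline; names are assumptions.
sfxn = ScoreFunction(device=torch.device("cuda"))
sfxn.set_weight("lj", 1.0)      # hypothetical weight setter / term name
sfxn.set_weight("hbond", 1.0)

# Rendering drives steps 1-3: per-block-type precomputation, packing of
# that data into GPU tensors, and instantiation of the term modules.
scorer = sfxn.render(pose_stack)

# Step 4: call the rendered ScoreFunction on the PoseStack coordinates.
energies = scorer(pose_stack.coords)
```
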
Precomputation

Before scoring a PoseStack, some precomputation must happen to ensure EnergyTerms have the data they need for every block type in the PoseStack.

The ScoreFunction has each EnergyTerm iterate over the complete list of RefinedResidueTypes used by any pose in the PoseStack, letting each term do whatever preprocessing of block-type-specific data it may need.

This step mostly involves pulling values from the database for the EnergyTerm in question and caching them on the RefinedResidueType.
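
As a sketch of this hook pattern (the class shape, method name, and attributes are assumptions; tmol's real terms take more arguments):

```python
class MyEnergyTerm:
    # Illustrative only; names here are assumptions, not tmol's API.
    def setup_block_type(self, block_type):
        """Called once per RefinedResidueType used in the PoseStack."""
        if hasattr(block_type, "my_term_params"):
            return  # values already cached on this block type
        # Pull the per-block-type values from the database and cache
        # them directly on the RefinedResidueType.
        params = self.database.lookup(block_type.name)
        setattr(block_type, "my_term_params", params)
```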

Packing

After any precomputation is finished, the ScoreFunction will 'pack' this data into an efficient tensor representation that can be transferred to the GPU so that it can be used directly by the torch modules.

This step centers around filling the PackedBlockTypes object. The ScoreFunction will have each EnergyTerm take the precomputed data from the first step and serialize it into compact tensors that are then stored in the PackedBlockTypes object.
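
The packing idea itself is generic: per-block-type arrays of differing lengths are padded to a common size and stacked into one dense tensor. A minimal, self-contained illustration (not tmol's actual code):

```python
import torch

# Per-block-type parameter arrays of differing lengths...
per_type_params = [torch.randn(n) for n in (5, 8, 3)]  # stand-in data

# ...padded to a common size and stacked into one dense tensor.
max_len = max(p.shape[0] for p in per_type_params)
packed = torch.full((len(per_type_params), max_len), float("nan"))
for i, p in enumerate(per_type_params):
    packed[i, : p.shape[0]] = p

# 'packed' would then be stored on the PackedBlockTypes object, moved
# to the GPU, and indexed by block-type id during scoring.
```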

Rendering a ScoringModule

For EnergyTerms:

In order for torch to actually use our EnergyTerms, we have to create a torch Module. The EnergyTerms use the function render_whole_pose_scoring_module to instantiate a module that is configured for running with the precomputed and packed data.

The ScoringModule itself defines a forward function that does the actual computation on the atom coordinates. This computation can either be pure torch Python code (for an example in code, look at the 'RefEnergyTerm'), or can be written in C++/CUDA ('CartBondedEnergyTerm', 'HBondEnergyTerm', etc).
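
A minimal whole-pose module in the spirit of RefEnergyTerm might look like the following; the class name, constructor arguments, and tensor shapes are illustrative assumptions.

```python
import torch

class WholePoseRefModule(torch.nn.Module):
    """Illustrative whole-pose module; names and shapes are assumptions."""

    def __init__(self, block_type_indices, ref_weights):
        super().__init__()
        # Precomputed, packed data is baked in at render time.
        self.register_buffer("block_type_indices", block_type_indices)  # [n_poses, max_blocks]
        self.register_buffer("ref_weights", ref_weights)                # [n_block_types]

    def forward(self, coords):
        # A reference energy ignores the coordinates entirely; real
        # terms (LJ, hbond, ...) would compute interactions from
        # 'coords' here instead.
        per_block = self.ref_weights[self.block_type_indices]
        return per_block.sum(dim=-1)  # one energy per pose
```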

For the ScoreFunction:

On the ScoreFunction level, the render function works differently: rendering a ScoreFunction does not produce an actual torch module the way an EnergyTerm's render does. Instead, it is in this function that the component EnergyTerms' precomputation, packing, and rendering functions are called. The returned value is a callable configured for the specified PoseStack and weights.
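
Schematically (the method names here mirror the steps above and are hypothetical, not tmol's actual code):

```python
def render(score_function, pose_stack):
    modules, weights = [], []
    for term, weight in score_function.terms:
        term.precompute(pose_stack)  # step 1 (hypothetical method)
        term.pack(pose_stack)        # step 2 (hypothetical method)
        modules.append(term.render_whole_pose_scoring_module(pose_stack))
        weights.append(weight)

    def scorer(coords):
        # Step 4: a plain callable closing over the term modules; note
        # that this is not itself a torch.nn.Module.
        return sum(w * m(coords) for w, m in zip(weights, modules))

    return scorer
```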

Note

TODO: Since the ScoreFunction's rendered 'module' isn't a torch module, does this mean it cannot be composed with other torch operations directly? Should we say something about this? Should we change this?

Note

TODO: Some sort of description of how the ScoringModules set up parameters (_p())

Calling the rendered ScoreFunction

The rendered ScoreFunction behaves like any other function operating on torch tensors. It takes as input the coordinates of a PoseStack, and outputs the energy of that PoseStack (or optionally a tensor of block-pair energies). In both cases, gradients are defined with respect to the input coordinates, so the rendered ScoreFunction can be used as part of a loss function in neural network training. The rendered ScoreFunction may be reused as long as only the coordinates change; any change to the amino acid sequence or to the individual score terms used will require re-rendering the ScoreFunction.
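
Continuing the earlier sketch, using the rendered ScoreFunction inside an autograd computation looks like ordinary torch code:

```python
import torch

# 'scorer' and 'pose_stack' continue the sketches above.
coords = pose_stack.coords.detach().clone().requires_grad_(True)

energy = scorer(coords).sum()   # total energy of the batch
energy.backward()               # gradients w.r.t. input coordinates
print(coords.grad.shape)        # same shape as the coordinates

# 'scorer' may be reused while only coordinates change; editing the
# sequence or the set of terms requires re-rendering.
```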

Whole-Pose vs Block-Pair scoring

ScoreFunctions and EnergyTerms can operate in two different modes.

Whole-pose mode aggregates all energies from all blocks into a single value.

Block-pair mode instead returns an N-block by N-block tensor that attributes each computed energy to the pair of blocks involved in the computation, with single-block contributions on the diagonal.

Block-pair mode returns more detailed information, but this comes at some performance cost.

In whole-pose mode, derivatives are typically computed in the forward pass and then cached for the backward pass to use later. In block-pair mode, the backward pass is computed separately in order to avoid having to cache N-by-N tensors for every EnergyTerm. The result is that block-pair scoring takes longer on the backward pass compared to whole-pose mode.

Unless you specifically need energies at the block-pair level, it is recommended to use whole-pose scoring.
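
In terms of output shapes, the difference looks roughly like this (the exact shapes, and how each mode is selected, are assumptions):

```python
# Whole-pose mode: one energy per pose in the batch.
totals = whole_pose_scorer(pose_stack.coords)   # shape: [n_poses]

# Block-pair mode: energies attributed to pairs of blocks, with
# single-block contributions on the diagonal.
pairs = block_pair_scorer(pose_stack.coords)    # shape: [n_poses, n_blocks, n_blocks]

# Summing over the block dimensions recovers the whole-pose totals.
totals_again = pairs.sum(dim=(-1, -2))
```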

Manipulating PoseStacks

tmol includes some basic functionality for manipulating and optimizing the structures in a PoseStack:

  1. setting xyzs
  2. setting torsion angles
  3. minimizing in either Cartesian or kinematic space using the included torch module (see Minimization below)

Minimization

The tmol Minimizer uses the derivatives calculated by a ScoreFunction to perform a gradient-descent optimization of the PoseStack's coordinates.

Currently the minimization happens in Cartesian space, though kinematic-based minimization should also be supported in the near future.

The included minimizer uses the L-BFGS algorithm with Armijo line search, with scaling and parameters taken from Rosetta. Other optimizers can theoretically be used, including torch's built-in minimization algorithms, but they are comparatively very slow at minimizing the energies of PoseStacks.
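
For comparison, driving the rendered ScoreFunction with torch's built-in L-BFGS (which offers strong-Wolfe rather than Armijo line search) looks like this; as noted above, tmol's own minimizer is the recommended, faster path:

```python
import torch

# 'scorer' and 'pose_stack' continue the scoring sketches above.
coords = pose_stack.coords.detach().clone().requires_grad_(True)
opt = torch.optim.LBFGS([coords], line_search_fn="strong_wolfe")

def closure():
    opt.zero_grad()
    energy = scorer(coords).sum()
    energy.backward()
    return energy

for _ in range(20):  # a handful of outer L-BFGS iterations
    opt.step(closure)
```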

Note

TODO: Other manipulations? phi/psi setting?

Exporting PoseStacks

Exporting PoseStacks is just the inverse of PoseStack creation. Like creation, going back to the CanonicalForm uses common code, but requires a data-source specific CanonicalOrdering object. From the CanonicalForm, data-source specific code must be written to convert back into the source's data format.
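
A sketch of the export path; the function names here are illustrative assumptions, not the exact tmol API:

```python
# Hypothetical sketch; these exporter names are assumptions.
canonical_form = canonical_form_from_pose_stack(
    canonical_ordering,  # the same data-source specific ordering object
    pose_stack,
)
write_pdb(canonical_form, "out.pdb")  # data-source specific exporter
```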

Note

TODO: we should be able to convert to data formats other than the original source, assuming they both have all the required block-types that are present in the PoseStack, right? It seems trivially true for PDBs at least. Might be worth talking about this.

Python, C++, and CUDA

tmol is primarily written in Python, with C++/CUDA being used to write optimized low level code for specific operations (most EnergyTerms, for example). C++ functions are exported to Python using the torch C++ interface.

When C++/CUDA is used, both a CPU and a CUDA version are compiled. This compilation is done just-in-time (JIT) by Ninja the first time the code is used. tmol makes use of a 'diamond' structure to share the implementation code between C++ and CUDA. Note that this means implementation code may only use functions that are available in both C++17 and CUDA (critically, things like std::cout are missing).
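
The JIT path uses torch's C++ extension machinery. Schematically (the extension name and source paths here are illustrative):

```python
from torch.utils.cpp_extension import load

# torch invokes Ninja to build the extension the first time it is
# loaded, caching the resulting object files locally.
my_term_ops = load(
    name="my_term_ops",
    sources=["my_term.cpp", "my_term.cu"],  # bindings + shared 'diamond' impl
    verbose=True,
)
score = my_term_ops.score  # hypothetical function exported via the torch C++ interface
```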

Warning

There is currently a bug in the CUDA compilation where the JIT compiling may fail to recognize updates to the code. If you notice a difference between the behavior of your C++ and CUDA implementations, you may need to delete the local cached object files to force a recompile.