Merge pull request #298 from laserkelvin/training-documentation-update
Training documentation update
laserkelvin authored Sep 30, 2024
2 parents 3930ee4 + 17d7582 commit 5bdd353
Showing 3 changed files with 255 additions and 10 deletions.
89 changes: 89 additions & 0 deletions docs/source/best-practices.rst
@@ -181,6 +181,81 @@ the accelerator.
Training
--------

Target normalization
^^^^^^^^^^^^^^^^^^^^

Tasks can be provided with ``normalize_kwargs``, which are key/value mappings
that specify the mean and standard deviation of a target; an example is given below.

.. code-block:: python

   Task(
       ...,
       normalize_kwargs={
           "energy_mean": 0.0,
           "energy_std": 1.0,
       }
   )

The example above will normalize ``energy`` labels; ``energy`` can be substituted with
any other target key of interest (e.g. ``force``, ``bandgap``, etc.).
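
In practice, the normalization statistics are typically computed from the training split.
A minimal sketch, assuming the training-set energy labels have already been gathered into
a 1D tensor:

.. code-block:: python

   import torch

   # replace with your actual training-set energy labels
   train_energies = torch.tensor([-1.2, -0.8, -1.0])
   normalize_kwargs = {
       "energy_mean": train_energies.mean().item(),
       "energy_std": train_energies.std().item(),
   }
   task = Task(..., normalize_kwargs=normalize_kwargs)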

Target loss scaling
^^^^^^^^^^^^^^^^^^^

A common practice is to scale some targets relative to others (e.g. force over
energy, etc). To specify this, you can pass a ``task_loss_scaling`` dictionary to
any task module, which maps target keys to a floating point value that will be used
to multiply the corresponding target loss value before summation and backpropagation.

.. code-block:: python

   Task(
       ...,
       task_loss_scaling={
           "energy": 1.0,
           "force": 10.0,
       }
   )

A related but alternative way to specify target scaling is to apply a *schedule* to
the training loss contributions: essentially, this provides a way to smoothly ramp
up (or down) different targets, allowing for more complex training curricula.
To achieve this, you will need to use the ``LossScalingScheduler`` callback:

.. autoclass:: matsciml.lightning.callbacks.LossScalingScheduler
:members:


To configure this callback, you pass instances of ``BaseScalingSchedule`` subclasses as arguments.
Each schedule type implements the functional form of a schedule, and currently
there are two concrete schedules. Composed together, an example would look like this:

.. code-block:: python

   import pytorch_lightning as pl

   from matsciml.lightning.callbacks import LossScalingScheduler
   from matsciml.lightning.loss_scaling import LinearScalingSchedule

   scheduler = LossScalingScheduler(
       LinearScalingSchedule("energy", initial_value=1.0, end_value=5.0, step_frequency="epoch")
   )
   trainer = pl.Trainer(callbacks=[scheduler])

The stepping schedule is determined during ``setup`` (as training begins), where the callback will
inspect ``Trainer`` arguments to determine how many steps will be taken. The ``step_frequency``
simply specifies how often the scaling value is updated.
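
Multiple schedules, one per target key, can be passed to the same callback. A minimal
sketch, reusing the ``LinearScalingSchedule`` signature from the example above:

.. code-block:: python

   scheduler = LossScalingScheduler(
       LinearScalingSchedule("energy", initial_value=1.0, end_value=5.0, step_frequency="epoch"),
       LinearScalingSchedule("force", initial_value=10.0, end_value=1.0, step_frequency="epoch"),
   )
   trainer = pl.Trainer(callbacks=[scheduler])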


.. autoclass:: matsciml.lightning.loss_scaling.LinearScalingSchedule
:members:


.. autoclass:: matsciml.lightning.loss_scaling.SigmoidScalingSchedule
:members:


Quick debugging
^^^^^^^^^^^^^^^

@@ -223,6 +298,20 @@ inspired by observations made in LLM training research, where the breakdown of
assumptions in the convergent properties of ``Adam``-like optimizers causes large
spikes in the training loss. This callback can help identify these occurrences.

The ``devset``/``fast_dev_run`` approach detailed above is also useful for testing
engineering/infrastructure (e.g. accelerator offload), but not necessarily
for probing training dynamics. Instead, we recommend using the ``overfit_batches``
argument in ``pl.Trainer``:

.. code-block:: python

   import pytorch_lightning as pl

   trainer = pl.Trainer(overfit_batches=100)

This will disable shuffling in the training and validation splits (per the PyTorch Lightning
documentation), and ensure that the same batches are being reused every epoch.

.. _e3nn documentation: https://docs.e3nn.org/en/latest/

.. _IPEX installation: https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu
18 changes: 18 additions & 0 deletions docs/source/inference.rst
@@ -4,6 +4,24 @@ Inference
"Inference" can be a bit of an overloaded term, and this page is broken down into different possible
downstream use cases for trained models.

Task ``predict`` and ``forward`` methods
----------------------------------------

``matsciml`` tasks implement separate ``forward`` and ``predict`` methods. Both take a
``BatchDict`` as input, and the latter wraps the former. The difference is that
``predict`` is intended primarily for inference: it takes care of reversing the
normalization procedure, if normalization kwargs were provided during training, *and*,
perhaps more importantly, it ensures that the exponential moving average weights are
used instead of the training ones.

In the special case of force prediction tasks (where forces are computed as a derivative
of the energy), you should only need to specify normalization ``kwargs`` for the energy:
the scale value is taken automatically from the energy normalization and applied to the forces.

In short, if you are writing functionality that requires unnormalized outputs (e.g. ``ase`` calculators),
please ensure you are using ``predict`` instead of ``forward`` directly.
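
As a minimal sketch of inference usage (the trained ``task`` and the ``batch``, a
``BatchDict`` produced by a dataloader, are assumed to already exist):

.. code-block:: python

   # ``predict`` reverses target normalization and uses the EMA weights when available
   outputs = task.predict(batch)
   # calling the task directly (i.e. ``forward``) returns raw, normalized outputs
   raw_outputs = task(batch)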


Parity plots and model evaluations
----------------------------------

158 changes: 148 additions & 10 deletions docs/source/training.rst
@@ -1,14 +1,152 @@
Training pipeline
=================
Task abstraction
================

Training with the Open MatSci ML Toolkit utilizes—for the most part—the
PyTorch Lightning abstractions.
The Open MatSciML Toolkit uses PyTorch Lightning abstractions to manage the flow
of training: how data from a datamodule gets mapped, which loss terms are calculated,
and what gets logged are all defined in a base task class. From start to finish, this module
will take in the definition of an encoding architecture (through the ``encoder_class`` and
``encoder_kwargs`` keyword arguments), construct it, and, in concrete task implementations,
initialize the respective output heads for a set of provided or task-specific target keys.
The ``encoder_kwargs`` specification makes things a bit more verbose, but it ensures
that the hyperparameters are saved appropriately per the ``save_hyperparameters`` method
in PyTorch Lightning.
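
As an illustrative sketch, a concrete task might be constructed like this; ``SomeEncoder``
and its keyword arguments are hypothetical placeholders, and the ``task_keys`` argument
name is an assumption:

.. code-block:: python

   from matsciml.models.base import ScalarRegressionTask

   task = ScalarRegressionTask(
       encoder_class=SomeEncoder,  # any encoder architecture available in matsciml
       encoder_kwargs={"hidden_dim": 128, "num_layers": 3},  # assumed encoder kwargs
       task_keys=["band_gap"],  # assumed target key name
   )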

Task API reference
##################

.. autosummary::
:toctree: generated
:recursive:
``BaseTaskModule`` API reference
--------------------------------

matsciml.models.base
.. autoclass:: matsciml.models.base.BaseTaskModule
:members:


Multi task reference
--------------------------------

One core piece of functionality in ``matsciml`` is the ability to compose multiple tasks
together, (almost) seamlessly extending the single-task case.

.. important::
   The ``MultiTaskLitModule`` is not written in a particularly friendly way at
   the moment, and may be subject to a significant refactor later!


.. autoclass:: matsciml.models.base.MultiTaskLitModule
:members:


``OutputHead`` API reference
----------------------------

While there is a singular ``OutputHead`` definition, the blocks that constitute
an ``OutputHead`` can be specified depending on the type of model architecture
being used. The default stack is based on simple ``nn.Linear`` layers; however,
for architectures like MACE, which may depend on preserving irreducible representations,
the ``IrrepOutputBlock`` allows users to specify transformations per representation.

.. autoclass:: matsciml.models.common.OutputHead
:members:


.. autoclass:: matsciml.models.common.OutputBlock
:members:


.. autoclass:: matsciml.models.common.IrrepOutputBlock
:members:


Scalar regression
-----------------

This task is primarily designed for property prediction and adjacent problems: you can
predict an arbitrary number of properties (one per output head) based on a shared
embedding (i.e. one structure maps to a single embedding, which is used by each head).

A special case for using this class would be in tandem (as a multitask setup) with
the :ref:`gradfree_force`, which treats energy/force prediction as two
separate output heads, albeit with the same shared embedding.

Please use continuous valued (e.g. ``nn.MSELoss``) loss metrics for this task.
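
For instance, a single task instance can regress several scalar targets from the same
shared embedding; in this hedged sketch, the ``task_keys`` and ``loss_func`` keyword
names are assumptions, so check the class documentation below:

.. code-block:: python

   from torch import nn
   from matsciml.models.base import ScalarRegressionTask

   task = ScalarRegressionTask(
       encoder_class=SomeEncoder,  # hypothetical encoder placeholder
       encoder_kwargs={"hidden_dim": 128},  # assumed encoder kwargs
       task_keys=["band_gap", "formation_energy"],  # assumed target key names
       loss_func=nn.MSELoss(),  # continuous-valued loss, per the note above
   )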


.. autoclass:: matsciml.models.base.ScalarRegressionTask
:members:


Binary classification
-----------------------

This task, as the name suggests, uses a shared embedding to perform one or more binary
classifications. This can be something like a ``stability`` label, as in the Materials
Project. Keep in mind, however, that a special class exists for crystal symmetry
classification (see :ref:`crystal_symmetry`).

.. autoclass:: matsciml.models.base.BinaryClassificationTask
:members:

.. _crystal_symmetry:

Crystal symmetry classification
-------------------------------

This task is a specialized class for what is essentially multiclass classification,
where given an embedding, we predict which crystal space group the structure belongs
to using ``nn.CrossEntropyLoss``. This can be a good potential pretraining task.


.. note::
   This task expects that your data includes the ``spacegroup`` target key.

.. autoclass:: matsciml.models.base.CrystalSymmetryClassificationTask
:members:


Force regression task
---------------------

This task implements energy/force regression, where an ``OutputHead`` is used to first
predict the energy, followed by taking its derivative with respect to the input coordinates.
From a developer perspective, this task is quite mechanically different due to the need
for manual ``autograd``, which is not normally supported by PyTorch Lightning workflows.
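
Physically, the force on atom :math:`i` follows as the negative gradient of the
predicted energy with respect to that atom's position:

.. math::

   \mathbf{F}_i = -\nabla_{\mathbf{r}_i} E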


.. note::
   This task expects that your data includes the ``force`` target key.

.. autoclass:: matsciml.models.base.ForceRegressionTask
:members:


.. _gradfree_force:

Gradient-free force regression task
-----------------------------------

This task implements force prediction as a direct output head property prediction,
as opposed to taking the derivative of an energy value with ``autograd``.

.. note::
   This task expects that your data includes the ``force`` target key.

.. autoclass:: matsciml.models.base.GradFreeForceRegressionTask
:members:


Node denoising task
-------------------

This task implements a powerful and increasingly popular pre-training strategy
for graph neural networks. The premise is quite simple: the encoder learns as a denoising
autoencoder by taking in a perturbed structure and attempting to predict the amount of
noise added to the 3D coordinates.

This task requires the following data transform. You are able to specify the scale of
the noise added to the positions; intuitively, the larger the scale, the more difficult
the task.
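
As a rough sketch, the transform is added alongside your other dataset transforms; the
``scale`` keyword name used here is an assumption, so check the class documentation below:

.. code-block:: python

   from matsciml.datasets.transforms.pretraining import NoisyPositions

   # ``scale`` (assumed keyword) controls the magnitude of the noise added to positions
   transforms = [NoisyPositions(scale=0.1)]
   # pass ``transforms`` to your dataset/datamodule configuration as usual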

.. autoclass:: matsciml.datasets.transforms.pretraining.NoisyPositions
:members:


.. autoclass:: matsciml.models.base.NodeDenoisingTask
:members:
