Official code repository of paper Equivariance via Minimal Frame Averaging for More Symmetries and Efficiency. In this repository, we have provided decorators to convert any non-equivariant/invariant neural network functions into group equivariant/invariant ones. This enables neural networks to handle transformations from various groups such as
In this code base, we provide the equivariance error test for all the groups included in the paper, which suffices to show our idea and corresponding algorithms.
Note: It is recommended to use a conda
environment. First, you can install torch
with the following command (if GPU is available):
conda install pytorch pytorch-cuda=11.8 -c pytorch -c nvidia
Then, install our package using:
pip install -e .
The od_equivariant_decorator
can be used to wrap any forward function to make it
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# Define your layers here
def forward(self, x):
# Your forward pass implementation
return x
To apply this decorator to this neural network class’s forward function:
import torch.nn as nn
from minimal_frame.group_decorator import od_equivariant_decorator
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# Define your layers here
@od_equivariant_decorator
def forward(self, x):
# Your forward pass implementation
return x
To apply this decorator to a neural network instance's forward function:
# Instantiate your model
model = MyModel()
# Apply the O(d)-equivariant decorator
model.forward = od_equivariant_decorator(model.forward)
Similarly, the od_invariant_decorator
can be used to make the network
This repository provides a collection of decorators to enforce equivariance or invariance properties in neural network models. Equivariance ensures that the output of the model transforms in the same way as the input under a given group of transformations. Invariance ensures that the output remains unchanged under such transformations.
These decorators in group_decorator.py
are model-agnostic and can be applied to any model with input and output shapes of
- Orthogonal Group
$O(d)$ :-
od_equivariant_decorator/od_invariant_decorator
: Our MFA method for$O(d)$ -
od_equivariant_puny_decorator/od_invariant_puny_decorator
: Puny's FA method [1] for$O(d)$ -
od_equivariant_puny_improve_decorator/od_invariant_puny_improve_decorator
: Improved [1]'s FA for degenerate eigenvalues, see our Appendix H.3 -
od_equivariant_sfa_decorator/od_invariant_sfa_decorator
: Stochastic FA method [2]
-
- Special Orthogonal Group
$SO(d)$ :sod_equivariant_decorator/sod_invariant_decorator
- Euclidean Group
$E(d)$ :ed_equivariant_decorator/ed_invariant_decorator
- Special Euclidean Group
$SE(d)$ :sed_equivariant_decorator/sed_invariant_decorator
- Unitary Group
$U(d)$ :ud_equivariant_decorator/ud_invariant_decorator
- Special Unitary Group
$SU(d)$ :sud_equivariant_decorator/sud_invariant_decorator
- Lorentz Group
$O(1,d-1)$ :o1d_equivariant_decorator/o1d_invariant_decorator
- Proper Lorentz Group
$SO(1,d-1)$ :so1d_equivariant_decorator/so1d_invariant_decorator
- General Linear Group
$GL(d,\mathbb{R})$ (requiring full-column-rank input):gld_equivariant_decorator/gld_invariant_decorator
- Special Linear Group
$SL(d,\mathbb{R})$ (requiring full-column-rank input):sld_equivariant_decorator/sld_invariant_decorator
- Affine Group
$Aff(d,\mathbb{R})$ (requiring full-column-rank input):affd_equivariant_decorator/affd_invariant_decorator
- Permutation Group
$S_n$ :sn_equivariant_decorator/sn_invariant_decorator
- Direct product between Permutation Group
$S_n$ and other groups-
$S_n \times O(d)$ :sn_od_equivariant_decorator/sn_od_invariant_decorator
-
$S_n \times SO(d)$ :sn_sod_equivariant_decorator/sn_sod_invariant_decorator
-
$S_n \times O(1,d-1)$ :sn_o1d_equivariant_decorator/sn_o1d_invariant_decorator
-
Additionally, the graph_group_decorator.py
for undirected graphs are provided for any model with undirected adjacency matrices as input. The corresponding equivariance analysis can be found here, including decorators for
- Undirected unweighted adjacency matrix:
undirected_unweighted_sn_equivariant_decorator/undirected_unweighted_sn_invariant_decorator
- Undirected weighted adjacency matrix:
undirected_weighted_sn_equivariant_decorator/undirected_weighted_sn_invariant_decorator
[1] Puny, Omri, et al. "Frame averaging for invariant and equivariant network design." arXiv preprint arXiv:2110.03336 (2021).
[2] Duval, Alexandre Agm, et al. "Faenet: Frame averaging equivariant gnn for materials modeling." International Conference on Machine Learning. PMLR, 2023.
[3] McKay, Brendan D. "Practical graph isomorphism." (1981): 45-87.
This project is licensed under the MIT License. Please note that the nauty
software is covered by its own licensing terms. All rights to the files mentioned in nauty.py
are reserved by Brendan McKay under the Apache License 2.0.
We gratefully acknowledge the insightful discussions with Derek Lim and Hannah Lawrence. This work was supported in part by National Science Foundation grant IIS-2006861 and National Institutes of Health grant U01AG070112.
If you find this work useful, please consider citing:
@inproceedings{
lin2024equivariance,
title={Equivariance via Minimal Frame Averaging for More Symmetries and Efficiency},
author={Yuchao Lin and Jacob Helwig and Shurui Gui and Shuiwang Ji},
booktitle={Forty-first International Conference on Machine Learning},
year={2024},
url={https://openreview.net/forum?id=guFsTBXsov}
}