mypy_einsum is a Mypy plugin for type checking np.einsum, jax.numpy.einsum, and torch.einsum operations.
The Einstein summation convention can express many multi-dimensional linear-algebra array operations, and einsum provides a succinct way of writing them.
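To make this concrete, here is a small NumPy sketch (the arrays are arbitrary examples, not taken from the project) showing a few familiar operations written as einsum equations, each checked against its conventional NumPy equivalent. Labels that appear in the inputs but not in the output are summed over.

```python
import numpy as np

# Arbitrary example arrays.
a = np.arange(6).reshape(2, 3)   # shape (2, 3)
b = np.arange(12).reshape(3, 4)  # shape (3, 4)
m = np.arange(9).reshape(3, 3)   # square matrix
v = np.arange(3)                 # vector of length 3

# Transpose: reorder the index labels in the output.
transposed = np.einsum("ij->ji", a)

# Matrix product: k appears in both inputs but not the output, so it is summed.
product = np.einsum("ik,kj->ij", a, b)

# Trace: a repeated label on one operand selects the diagonal;
# an empty output then sums it.
trace = np.einsum("ii->", m)

# Outer product: no shared labels, so nothing is summed.
outer = np.einsum("i,j->ij", v, v)

# Column sums: i is dropped from the output, so it is summed away.
col_sums = np.einsum("ij->j", a)

# Each einsum agrees with the equivalent conventional NumPy call.
assert np.array_equal(transposed, a.T)
assert np.array_equal(product, a @ b)
assert trace == np.trace(m)
assert np.array_equal(outer, np.outer(v, v))
assert np.array_equal(col_sums, a.sum(axis=0))
```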
However, because einsum equations are passed as strings, typos and other bugs are easy to overlook: linters cannot check what is inside a string.
mypy_einsum is a Mypy plugin that statically verifies the correctness of einsum equations without needing to execute the code.
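The kinds of rules such a check involves can be sketched in plain Python. The function below, check_einsum, is hypothetical: it is not mypy_einsum's implementation or API. It assumes the rank (number of dimensions) of each operand is already known, and it covers only a few common rules: one subscript per operand, one label per dimension unless an ellipsis is present, and output labels that come from the inputs without repeating.

```python
import string


def check_einsum(equation: str, ranks: list[int]) -> list[str]:
    """Hypothetical checker: return the problems found in an einsum equation.

    `ranks` holds the number of dimensions of each operand and is assumed
    to be known ahead of time.
    """
    errors: list[str] = []
    equation = equation.replace(" ", "")
    if "->" in equation:
        lhs, output = equation.split("->", 1)
    else:
        lhs, output = equation, None  # implicit mode: no explicit output
    inputs = lhs.split(",")

    # There must be exactly one comma-separated subscript per operand.
    if len(inputs) != len(ranks):
        errors.append(f"{len(inputs)} subscript(s) for {len(ranks)} operand(s)")
        return errors

    seen: set[str] = set()
    for position, (subscript, rank) in enumerate(zip(inputs, ranks)):
        labels = subscript.replace("...", "")
        # Labels must be ASCII letters.
        for label in labels:
            if label not in string.ascii_letters:
                errors.append(f"operand {position}: invalid label {label!r}")
        # Without an ellipsis, each dimension needs exactly one label;
        # with one, the ellipsis absorbs any remaining dimensions.
        if "..." in subscript:
            if len(labels) > rank:
                errors.append(f"operand {position}: too many labels for rank {rank}")
        elif len(labels) != rank:
            errors.append(f"operand {position}: {len(labels)} labels for rank {rank}")
        seen.update(labels)

    if output is not None:
        out_labels = output.replace("...", "")
        # Every output label must come from an input and appear only once.
        for label in out_labels:
            if label not in seen:
                errors.append(f"output label {label!r} not in any input")
        if len(set(out_labels)) != len(out_labels):
            errors.append("output contains a repeated label")
    return errors


print(check_einsum("ik,kj->ij", [2]))     # → ['2 subscript(s) for 1 operand(s)']
print(check_einsum("ik,kj->ij", [2, 2]))  # → []
```

Working out the operand ranks from the surrounding code is left out of this sketch.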
mypy_einsum can be installed with pip:
pip install mypy-einsum
To enable the plugin, add it to your project's Mypy configuration file, usually mypy.ini:
[mypy]
plugins = mypy_einsum
or pyproject.toml:
[tool.mypy]
plugins = ["mypy_einsum"]
Can you spot the 🐛 without running the code?
import numpy as np
a = np.arange(9).reshape(3, 3)
np.einsum("ik,kj->ij", a)
mypy_einsum will catch it for you:
❯ mypy example.py --pretty
example.py:5: error: Number of einsum subscripts must be equal to the
number of operands. [einsum]
np.einsum("ik,kj->ij", a)
^~~~~~~~~~~
Found 1 error in 1 file (checked 1 source file)
After fixing it, mypy will succeed 🎉:
np.einsum("ik,kj->ij", a, a)
❯ mypy example.py
Success: no issues found in 1 source file
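For comparison, NumPy itself only reports the original mistake when the line actually runs. A small sketch, assuming a NumPy version where a mismatch between subscripts and operands raises ValueError:

```python
import numpy as np

a = np.arange(9).reshape(3, 3)

# The buggy call fails only at runtime, with a ValueError from NumPy.
try:
    np.einsum("ik,kj->ij", a)
    buggy_raised = False
except ValueError as exc:
    buggy_raised = True
    print(f"ValueError: {exc}")

# The fixed call is a plain matrix product of a with itself.
fixed = np.einsum("ik,kj->ij", a, a)
assert np.array_equal(fixed, a @ a)
```

The static check reports the same problem before the code is ever executed.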
mypy_einsum aims never to raise warnings for valid einsum operations. If you encounter a warning that you believe is incorrect, or think mypy_einsum is missing an error, please let us know. Contributions are very welcome!