JAX prefers immutable objects, but neither Python nor JAX provides an immutable
dictionary. 😢
This repository defines a lightweight immutable map
(lower-level than a dict) that JAX recognizes as a PyTree. 🎉 🕶️
pip install xmmutablemap
xmmutablemap provides the class ImmutableMap, which is a full implementation of
Python's Mapping ABC.
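Implementing the Mapping ABC is what gives a class dict-like read behavior. The stdlib-only sketch below illustrates the idea. It is not xmmutablemap's actual implementation, and `FrozenMapSketch` is a hypothetical name. A Mapping subclass only has to define `__getitem__`, `__iter__` and `__len__`, and the ABC then supplies `get`, `keys`, `items`, `values`, `__contains__` and `__eq__`. Omitting `__setitem__` makes the map read-only, although only by convention here, since `_data` could still be reassigned. Because Mapping defines `__eq__`, Python sets its `__hash__` to None, so a hashable map has to define `__hash__` explicitly.

```python
from collections.abc import Hashable, Iterator, Mapping
from typing import Any


class FrozenMapSketch(Mapping):
    """Hypothetical read-only mapping, for illustration only."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Copy the input so later changes to the caller's dict don't leak in.
        self._data = dict(*args, **kwargs)

    # The three abstract methods the Mapping ABC requires.
    def __getitem__(self, key: Hashable) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    # Mapping defines __eq__, which leaves __hash__ = None, so add it back.
    # A frozenset of items hashes equal maps equally regardless of insertion
    # order; the values themselves must be hashable.
    def __hash__(self) -> int:
        return hash(frozenset(self._data.items()))

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._data!r})"


m = FrozenMapSketch(a=1, b=2)
print(m.get("c", 0), "a" in m, sorted(m.items()))  # 0 True [('a', 1), ('b', 2)]
print(m == {"a": 1, "b": 2})  # True: Mapping.__eq__ compares items
print(hash(m) == hash(FrozenMapSketch(b=2, a=1)))  # True
try:
    m["a"] = 3  # no __setitem__, so item assignment raises TypeError
except TypeError:
    print("read-only")
```

Only the three abstract methods and `__hash__` are hand-written. Everything else comes from the ABC, which is why the result "just works" wherever a read-only Mapping is expected.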
If you've used a dict, then you already know how to use ImmutableMap! ImmutableMap
adds two things: 1) immutability (and related benefits like hashability) and
2) compatibility with JAX.
from xmmutablemap import ImmutableMap
print(ImmutableMap(a=1, b=2, c=3))
# ImmutableMap({'a': 1, 'b': 2, 'c': 3})
print(ImmutableMap({"a": 1, "b": 2.0, "c": "3"}))
# ImmutableMap({'a': 1, 'b': 2.0, 'c': '3'})
We welcome contributions!