diff --git a/autoarray/operators/over_sampling/uniform.py b/autoarray/operators/over_sampling/uniform.py index cd075178..5c8652b0 100644 --- a/autoarray/operators/over_sampling/uniform.py +++ b/autoarray/operators/over_sampling/uniform.py @@ -14,7 +14,10 @@ from autoarray import exc from autoarray.operators.over_sampling import over_sample_util +from autofit.jax_wrapper import register_pytree_node_class + +@register_pytree_node_class class OverSamplingUniform(AbstractOverSampling): def __init__(self, sub_size: Union[int, Array2D]): """ @@ -319,6 +322,15 @@ def over_sampler_from(self, mask: Mask2D) -> "OverSamplerUniform": mask=mask, sub_size=self.sub_size, ) + + def tree_flatten(self): + children = (self.sub_size,) + aux_data = None + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children) class OverSamplerUniform(AbstractOverSampler):