Skip to content

Commit

Permalink
Wrap class as PyTree
Browse files Browse the repository at this point in the history
Needed to make the `autolens.Tracer` example work.
  • Loading branch information
CKrawczyk committed Oct 25, 2024
1 parent c597f33 commit e641904
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions autoarray/operators/over_sampling/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from autoarray import exc
from autoarray.operators.over_sampling import over_sample_util

from autofit.jax_wrapper import register_pytree_node_class


@register_pytree_node_class
class OverSamplingUniform(AbstractOverSampling):
def __init__(self, sub_size: Union[int, Array2D]):
"""
Expand Down Expand Up @@ -319,6 +322,15 @@ def over_sampler_from(self, mask: Mask2D) -> "OverSamplerUniform":
mask=mask,
sub_size=self.sub_size,
)

def tree_flatten(self):
children = (self.sub_size,)
aux_data = None
return (children, aux_data)

@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)


class OverSamplerUniform(AbstractOverSampler):
Expand Down

0 comments on commit e641904

Please sign in to comment.