Skip to content

Commit

Permalink
Merge pull request #137 from Jammy2211/feature/jax_tracer
Browse files Browse the repository at this point in the history
Wrap class as PyTree for Tracer
  • Loading branch information
CKrawczyk authored Oct 25, 2024
2 parents c597f33 + e641904 commit f4da19e
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions autoarray/operators/over_sampling/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from autoarray import exc
from autoarray.operators.over_sampling import over_sample_util

from autofit.jax_wrapper import register_pytree_node_class


@register_pytree_node_class
class OverSamplingUniform(AbstractOverSampling):
def __init__(self, sub_size: Union[int, Array2D]):
"""
Expand Down Expand Up @@ -319,6 +322,15 @@ def over_sampler_from(self, mask: Mask2D) -> "OverSamplerUniform":
mask=mask,
sub_size=self.sub_size,
)

def tree_flatten(self):
children = (self.sub_size,)
aux_data = None
return (children, aux_data)

@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)


class OverSamplerUniform(AbstractOverSampler):
Expand Down

0 comments on commit f4da19e

Please sign in to comment.