diff --git a/src/haliax/partitioning.py b/src/haliax/partitioning.py index 10260b5..903e4ff 100644 --- a/src/haliax/partitioning.py +++ b/src/haliax/partitioning.py @@ -313,7 +313,7 @@ def _call(self, is_lower, *args, **kwargs): output_shape = _cached_filter_eval_shape(self._fn, *args, **kwargs) my_pjit_args = dict(**self._pjit_args) - if in_axis_resources is not None or axis_resources is not None: + if in_axis_resources is not None: in_resources = infer_resource_partitions( (dynamic_donated, dynamic_reserved), in_axis_resources,