From 011b8b0ae39c0e6287e5a2af1d71c5945dfc662e Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Mon, 20 Jan 2025 20:15:03 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(arrayish):=20unary=20generic?= =?UTF-8?q?=20mixin=20(#123)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/quaxed/experimental/_arrayish/unary.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/quaxed/experimental/_arrayish/unary.py b/src/quaxed/experimental/_arrayish/unary.py index 272368d..c7f4ebc 100644 --- a/src/quaxed/experimental/_arrayish/unary.py +++ b/src/quaxed/experimental/_arrayish/unary.py @@ -220,9 +220,11 @@ def __abs__(self) -> R: # Combined Mixins -class LaxUnaryMixin(LaxPosMixin, LaxNegMixin, LaxAbsMixin): +class LaxUnaryMixin(LaxPosMixin, LaxNegMixin[R], LaxAbsMixin[R]): """Combined mixin for unary operations using quaxified `jax.lax`.""" -class NumpyUnaryMixin(NumpyPosMixin, NumpyNegMixin, NumpyAbsMixin, NumpyInvertMixin): +class NumpyUnaryMixin( + NumpyPosMixin, NumpyNegMixin[R], NumpyAbsMixin[R], NumpyInvertMixin[R] +): """Combined mixin for unary operations using quaxified `jax.numpy`."""