From 30bfb55981f97b7c5e78c01b728b5f2277a74a25 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Tue, 26 Nov 2024 17:52:27 +0100 Subject: [PATCH] [Feat] Provide derivatives for pow (#246) * [Feat] Provide manual derivatives for __pow__ * [Feat] Also applied changes to rpow * [Test] Another pow test added. --- pyerrors/obs.py | 9 +++------ tests/obs_test.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/pyerrors/obs.py b/pyerrors/obs.py index a1c2fd55..0caecfdc 100644 --- a/pyerrors/obs.py +++ b/pyerrors/obs.py @@ -856,15 +856,12 @@ def __rtruediv__(self, y): def __pow__(self, y): if isinstance(y, Obs): - return derived_observable(lambda x: x[0] ** x[1], [self, y]) + return derived_observable(lambda x, **kwargs: x[0] ** x[1], [self, y], man_grad=[y.value * self.value ** (y.value - 1), self.value ** y.value * np.log(self.value)]) else: - return derived_observable(lambda x: x[0] ** y, [self]) + return derived_observable(lambda x, **kwargs: x[0] ** y, [self], man_grad=[y * self.value ** (y - 1)]) def __rpow__(self, y): - if isinstance(y, Obs): - return derived_observable(lambda x: x[0] ** x[1], [y, self]) - else: - return derived_observable(lambda x: y ** x[0], [self]) + return derived_observable(lambda x, **kwargs: y ** x[0], [self], man_grad=[y ** self.value * np.log(y)]) def __abs__(self): return derived_observable(lambda x: anp.abs(x[0]), [self]) diff --git a/tests/obs_test.py b/tests/obs_test.py index 726ecffa..8b82213f 100644 --- a/tests/obs_test.py +++ b/tests/obs_test.py @@ -461,6 +461,18 @@ def test_cobs_overloading(): obs / cobs +def test_pow(): + data = [1, 2.341, pe.pseudo_Obs(4.8, 0.48, "test_obs"), pe.cov_Obs(1.1, 0.3 ** 2, "test_cov_obs")] + + for d in data: + assert d * d == d ** 2 + assert d * d * d == d ** 3 + + for d2 in data: + assert np.log(d ** d2) == d2 * np.log(d) + assert (d ** d2) ** (1 / d2) == d + + def test_reweighting(): my_obs = pe.Obs([np.random.rand(1000)], ['t']) assert not my_obs.reweighted