Skip to content

Commit

Permalink
[Feat] Fixed a few type hints manually
Browse files Browse the repository at this point in the history
  • Loading branch information
fjosw committed Dec 25, 2024
1 parent 9fe375a commit 8d86295
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions pyerrors/obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ def __add__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
else:
return derived_observable(lambda x, **kwargs: x[0] + y, [self], man_grad=[1])

def __radd__(self, y: Union[float, int]) -> "Obs":
def __radd__(self, y: Union[float, int]) -> Obs:
return self + y

def __mul__(self, y: Any) -> Union[Obs, ndarray, CObs, NotImplementedType]:
Expand All @@ -812,7 +812,7 @@ def __mul__(self, y: Any) -> Union[Obs, ndarray, CObs, NotImplementedType]:
else:
return derived_observable(lambda x, **kwargs: x[0] * y, [self], man_grad=[y])

def __rmul__(self, y: Union[float, int]) -> "Obs":
def __rmul__(self, y: Union[float, int]) -> Obs:
return self * y

def __sub__(self, y: Any) -> Union[Obs, NotImplementedType, ndarray]:
Expand All @@ -826,13 +826,13 @@ def __sub__(self, y: Any) -> Union[Obs, NotImplementedType, ndarray]:
else:
return derived_observable(lambda x, **kwargs: x[0] - y, [self], man_grad=[1])

def __rsub__(self, y: Union[float, int]) -> "Obs":
def __rsub__(self, y: Union[float, int]) -> Obs:
return -1 * (self - y)

def __pos__(self) -> "Obs":
def __pos__(self) -> Obs:
return self

def __neg__(self) -> "Obs":
def __neg__(self) -> Obs:
return -1 * self

def __truediv__(self, y: Any) -> Union[Obs, NotImplementedType, ndarray]:
Expand All @@ -846,7 +846,7 @@ def __truediv__(self, y: Any) -> Union[Obs, NotImplementedType, ndarray]:
else:
return derived_observable(lambda x, **kwargs: x[0] / y, [self], man_grad=[1 / y])

def __rtruediv__(self, y: Union[float, int]) -> "Obs":
def __rtruediv__(self, y: Union[float, int]) -> Obs:
if isinstance(y, Obs):
return derived_observable(lambda x, **kwargs: x[0] / x[1], [y, self], man_grad=[1 / self.value, - y.value / self.value ** 2])
else:
Expand All @@ -857,62 +857,62 @@ def __rtruediv__(self, y: Union[float, int]) -> "Obs":
else:
return derived_observable(lambda x, **kwargs: y / x[0], [self], man_grad=[-y / self.value ** 2])

def __pow__(self, y: Union[Obs, float, int]) -> "Obs":
def __pow__(self, y: Union[Obs, float, int]) -> Obs:
if isinstance(y, Obs):
return derived_observable(lambda x, **kwargs: x[0] ** x[1], [self, y], man_grad=[y.value * self.value ** (y.value - 1), self.value ** y.value * np.log(self.value)])
else:
return derived_observable(lambda x, **kwargs: x[0] ** y, [self], man_grad=[y * self.value ** (y - 1)])

def __rpow__(self, y: Union[float, int]) -> "Obs":
def __rpow__(self, y: Union[float, int]) -> Obs:
return derived_observable(lambda x, **kwargs: y ** x[0], [self], man_grad=[y ** self.value * np.log(y)])

def __abs__(self) -> "Obs":
def __abs__(self) -> Obs:
return derived_observable(lambda x: anp.abs(x[0]), [self])

# Overload numpy functions
def sqrt(self) -> "Obs":
def sqrt(self) -> Obs:
return derived_observable(lambda x, **kwargs: np.sqrt(x[0]), [self], man_grad=[1 / 2 / np.sqrt(self.value)])

def log(self) -> "Obs":
def log(self) -> Obs:
return derived_observable(lambda x, **kwargs: np.log(x[0]), [self], man_grad=[1 / self.value])

def exp(self) -> "Obs":
def exp(self) -> Obs:
return derived_observable(lambda x, **kwargs: np.exp(x[0]), [self], man_grad=[np.exp(self.value)])

def sin(self) -> "Obs":
def sin(self) -> Obs:
return derived_observable(lambda x, **kwargs: np.sin(x[0]), [self], man_grad=[np.cos(self.value)])

def cos(self) -> "Obs":
def cos(self) -> Obs:
return derived_observable(lambda x, **kwargs: np.cos(x[0]), [self], man_grad=[-np.sin(self.value)])

def tan(self) -> "Obs":
def tan(self) -> Obs:
return derived_observable(lambda x, **kwargs: np.tan(x[0]), [self], man_grad=[1 / np.cos(self.value) ** 2])

def arcsin(self) -> "Obs":
def arcsin(self) -> Obs:
return derived_observable(lambda x: anp.arcsin(x[0]), [self])

def arccos(self) -> "Obs":
def arccos(self) -> Obs:
return derived_observable(lambda x: anp.arccos(x[0]), [self])

def arctan(self) -> "Obs":
def arctan(self) -> Obs:
return derived_observable(lambda x: anp.arctan(x[0]), [self])

def sinh(self) -> "Obs":
def sinh(self) -> Obs:
return derived_observable(lambda x, **kwargs: np.sinh(x[0]), [self], man_grad=[np.cosh(self.value)])

def cosh(self) -> "Obs":
def cosh(self) -> Obs:
return derived_observable(lambda x, **kwargs: np.cosh(x[0]), [self], man_grad=[np.sinh(self.value)])

def tanh(self) -> "Obs":
def tanh(self) -> Obs:
return derived_observable(lambda x, **kwargs: np.tanh(x[0]), [self], man_grad=[1 / np.cosh(self.value) ** 2])

def arcsinh(self) -> "Obs":
def arcsinh(self) -> Obs:
return derived_observable(lambda x: anp.arcsinh(x[0]), [self])

def arccosh(self) -> "Obs":
def arccosh(self) -> Obs:
return derived_observable(lambda x: anp.arccosh(x[0]), [self])

def arctanh(self) -> "Obs":
def arctanh(self) -> Obs:
return derived_observable(lambda x: anp.arctanh(x[0]), [self])


Expand Down Expand Up @@ -944,7 +944,7 @@ def is_zero(self) -> bool:
"""Checks whether both real and imaginary part are zero within machine precision."""
return self.real == 0.0 and self.imag == 0.0

def conjugate(self) -> "CObs":
def conjugate(self) -> CObs:
return CObs(self.real, -self.imag)

def __add__(self, other: Any) -> Union[CObs, ndarray]:
Expand Down Expand Up @@ -989,7 +989,7 @@ def __mul__(self, other: Any) -> Union[CObs, ndarray]:
else:
return CObs(self.real * other, self.imag * other)

def __rmul__(self, other: Union[complex, Obs, float, int]) -> "CObs":
def __rmul__(self, other: Union[complex, Obs, CObs, float, int]) -> "CObs":
return self * other

def __truediv__(self, other: Any) -> Union[CObs, ndarray]:
Expand All @@ -1001,7 +1001,7 @@ def __truediv__(self, other: Any) -> Union[CObs, ndarray]:
else:
return CObs(self.real / other, self.imag / other)

def __rtruediv__(self, other: Union[complex, float, Obs, int]) -> "CObs":
def __rtruediv__(self, other: Union[complex, float, Obs, CObs, int]) -> CObs:
r = self.real ** 2 + self.imag ** 2
if hasattr(other, 'real') and hasattr(other, 'imag'):
return CObs((self.real * other.real + self.imag * other.imag) / r, (self.real * other.imag - self.imag * other.real) / r)
Expand Down

0 comments on commit 8d86295

Please sign in to comment.