From 6cf2edf6dbee7a496a9c03780c02c2bf3e635255 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 1 Feb 2025 15:35:30 +0100 Subject: [PATCH] ENH: add dtype argument to fft.{fftfreq,rfftfreq} --- array_api_compat/common/_fft.py | 32 +++++++++++++++++++++++++----- array_api_compat/common/_typing.py | 1 + 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index 666b0b1f..e5caebef 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Union, Optional, Literal if TYPE_CHECKING: - from ._typing import Device, ndarray + from ._typing import Device, ndarray, DType from collections.abc import Sequence # Note: NumPy fft functions improperly upcast float32 and complex64 to @@ -149,15 +149,37 @@ def ihfft( return res.astype(xp.complex64) return res -def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: +def fftfreq( + n: int, + /, + xp, + *, + d: float = 1.0, + dtype: Optional[DType] = None, + device: Optional[Device] = None +) -> ndarray: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") - return xp.fft.fftfreq(n, d=d) + res = xp.fft.fftfreq(n, d=d) + if dtype is not None: + return res.astype(dtype) + return res -def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: +def rfftfreq( + n: int, + /, + xp, + *, + d: float = 1.0, + dtype: Optional[DType] = None, + device: Optional[Device] = None +) -> ndarray: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") - return xp.fft.rfftfreq(n, d=d) + res = xp.fft.rfftfreq(n, d=d) + if dtype is not None: + return res.astype(dtype) + return res def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: return xp.fft.fftshift(x, axes=axes) diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index 07f3850d..1f916cd8 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -21,3 +21,4 @@ def __len__(self, /) -> int: ... Array = Any Device = Any +DType = Any