diff --git a/test/testsuite/spectral.jl b/test/testsuite/spectral.jl index 223f63bc..0bf46c72 100644 --- a/test/testsuite/spectral.jl +++ b/test/testsuite/spectral.jl @@ -26,10 +26,10 @@ function spectral_testsuite(Backend) @testset "Grads" begin if Backend != CPU x = rand(Float32, 16, batch...) - gradtest_fn(s -> abs.(stft(s; n_fft=16)), x) + window = hann_window(16) - window = device(hann_window(16)) - gradtest_fn(s -> abs.(stft(s; n_fft=16, window)), x) + gradtest_fn(s -> abs.(stft(s; n_fft=16)), x) + gradtest_fn((s, w) -> abs.(stft(s; n_fft=16, window=w)), x, window) end end