diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b7620fb6b..6b8493e43 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -93,14 +93,6 @@ jobs: using Pkg Pkg.develop(PackageSpec(path=pwd())) Pkg.instantiate()' - - run: | - julia --color=yes --project=docs/ -e ' - using NNlib - # using Pkg; Pkg.activate("docs") - using Documenter - using Documenter: doctest - DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive=true) - doctest(NNlib)' - run: julia --project=docs docs/make.jl env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/Project.toml b/Project.toml index 0496564e0..f3ad440ca 100644 --- a/Project.toml +++ b/Project.toml @@ -17,22 +17,26 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" [extensions] NNlibAMDGPUExt = "AMDGPU" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] NNlibCUDAExt = "CUDA" NNlibEnzymeCoreExt = "EnzymeCore" +NNlibFFTWExt = "FFTW" [compat] -AMDGPU = "0.8, 0.9" +AMDGPU = "0.9.4" Adapt = "3.2, 4" Atomix = "0.1" CUDA = "4, 5" +cuDNN = "1" ChainRulesCore = "1.13" EnzymeCore = "0.5, 0.6, 0.7" +FFTW = "1.8.0" GPUArraysCore = "0.1" KernelAbstractions = "0.9.2" LinearAlgebra = "<0.0.1, 1" @@ -40,28 +44,4 @@ Pkg = "<0.0.1, 1" Random = "<0.0.1, 1" Requires = "1.0" Statistics = "1" -cuDNN = "1" julia = "1.9" - -[extras] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" -Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[targets] -test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeCore", "EnzymeTestUtils", "Interpolations", "ImageTransformations"] diff --git a/docs/Project.toml b/docs/Project.toml index ce9aa3290..9de5539c2 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,9 @@ [deps] +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +FLAC = "abae9e3b-a9a0-4778-b5c6-ca109b507d99" +FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" diff --git a/docs/make.jl b/docs/make.jl index 599aaf9dc..4bffca944 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,19 +1,21 @@ using Documenter, NNlib -DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive = true) +DocMeta.setdocmeta!(NNlib, :DocTestSetup, + :(using FFTW, NNlib, UnicodePlots); recursive = true) makedocs(modules = [NNlib], - sitename = "NNlib.jl", - doctest = false, - pages = ["Home" => "index.md", - "Reference" => "reference.md"], - format = Documenter.HTML( - canonical = "https://fluxml.ai/NNlib.jl/stable/", - # analytics = "UA-36890222-9", - assets = ["assets/flux.css"], - prettyurls = get(ENV, "CI", nothing) == "true"), - warnonly=[:missing_docs,] - ) + sitename = "NNlib.jl", + doctest = true, + pages = ["Home" => "index.md", + "Reference" => "reference.md", + "Audio" => "audio.md"], + format = Documenter.HTML( + canonical = "https://fluxml.ai/NNlib.jl/stable/", + # analytics = "UA-36890222-9", + assets = ["assets/flux.css"], + prettyurls = get(ENV, "CI", nothing) == "true"), + warnonly=[:missing_docs,] +) deploydocs(repo = "github.com/FluxML/NNlib.jl.git", target = "build", diff --git a/docs/src/assets/jfk.flac b/docs/src/assets/jfk.flac new file mode 100644 index 000000000..24841d55a Binary files /dev/null and b/docs/src/assets/jfk.flac differ diff --git a/docs/src/audio.md b/docs/src/audio.md new file mode 100644 index 000000000..e56a5bf43 --- /dev/null +++ b/docs/src/audio.md @@ -0,0 +1,61 @@ +# Reference + +!!! note + Spectral functions require importing `FFTW` package to enable them. + +## Window functions + +```@docs +hann_window +hamming_window +``` + +## Spectral + +```@docs +stft +istft +NNlib.power_to_db +NNlib.db_to_power +``` + +## Spectrogram + +```@docs +melscale_filterbanks +spectrogram +``` + +Example: + +```@example 1 +using FFTW # <- required for STFT support. +using NNlib +using FileIO +using Makie, CairoMakie +CairoMakie.activate!() + +waveform, sampling_rate = load("./assets/jfk.flac") +fig = lines(reshape(waveform, :)) +save("waveform.png", fig) + +# Spectrogram. + +n_fft = 1024 +spec = spectrogram(waveform; n_fft, hop_length=n_fft ÷ 4, window=hann_window(n_fft)) +fig = heatmap(transpose(NNlib.power_to_db(spec)[:, :, 1])) +save("spectrogram.png", fig) + +# Mel-scale spectrogram. + +n_freqs = n_fft ÷ 2 + 1 +fb = melscale_filterbanks(; n_freqs, n_mels=128, sample_rate=Int(sampling_rate)) +mel_spec = permutedims(spec, (2, 1, 3)) ⊠ fb # (time, n_mels) +fig = heatmap(NNlib.power_to_db(mel_spec)[:, :, 1]) +save("mel-spectrogram.png", fig) +nothing # hide +``` + +|Waveform|Spectrogram|Mel Spectrogram| +|:---:|:---:|:---:| +|![](waveform.png)|![](spectrogram.png)|![](mel-spectrogram.png)| diff --git a/ext/NNlibFFTWExt/NNlibFFTWExt.jl b/ext/NNlibFFTWExt/NNlibFFTWExt.jl new file mode 100644 index 000000000..ee314cd51 --- /dev/null +++ b/ext/NNlibFFTWExt/NNlibFFTWExt.jl @@ -0,0 +1,9 @@ +module NNlibFFTWExt + +using FFTW +using NNlib +using KernelAbstractions + +include("stft.jl") + +end diff --git a/ext/NNlibFFTWExt/stft.jl b/ext/NNlibFFTWExt/stft.jl new file mode 100644 index 000000000..dda76cec1 --- /dev/null +++ b/ext/NNlibFFTWExt/stft.jl @@ -0,0 +1,127 @@ +function NNlib.stft(x; + n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing, + center::Bool = true, normalized::Bool = false, +) + kab = get_backend(x) + use_window = !isnothing(window) + + use_window && kab != get_backend(window) && throw(ArgumentError( + "`window` must be on the same device as stft input `x` ($kab), \ + instead: `$(get_backend(window))`.")) + use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError( + "Expected `0 < length(window) ≤ n_fft=$n_fft`, \ + but got `length(window)=$(length(window))`.")) + hop_length < 0 && throw(ArgumentError( + "Expected `hop_length > 0`, but got `hop_length=$hop_length`.")) + + # Pad window on both sides with `0` to `n_fft` length if needed. + if use_window && length(window) < n_fft + left = ((n_fft - length(window)) ÷ 2) + 1 + tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft) + tmp[left:left + length(window) - 1] .= window + window = tmp + end + + if center + pad_amount = n_fft ÷ 2 + x = pad_reflect(x, pad_amount; dims=1) + end + + n = size(x, 1) + (0 < n_fft ≤ n) || throw(ArgumentError( + "Expected `0 < n_fft ≤ size(x, 1)=$n`, but got `n_fft=$n_fft`.")) + + n_frames = 1 + (n - n_fft) ÷ hop_length + + # time2col. + # Reshape `x` to (n_fft, n_frames, B) if needed. + # Each row in `n_frames` is shifted by `hop_length`. + if n_frames > 1 + # TODO can be more efficient if we support something like torch.as_strided + ids = [ + row + hop_length * col + for row in 1:n_fft, col in 0:(n_frames - 1)] + x = x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...] + end + + region = 1 + use_window && (x = x .* window;) + y = eltype(x) <: Complex ? fft(x, region) : rfft(x, region) + + normalized && (y = y .* eltype(y)(n_fft^-0.5);) + return y +end + +function NNlib.istft(y; + n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing, + center::Bool = true, normalized::Bool = false, + return_complex::Bool = false, + original_length::Union{Nothing, Int} = nothing, +) + kab = get_backend(y) + use_window = !isnothing(window) + + use_window && kab != get_backend(window) && throw(ArgumentError( + "`window` must be on the same device as istft input `y` ($kab), \ + instead: `$(get_backend(window))`.")) + use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError( + "Expected `0 < length(window) ≤ n_fft=$n_fft`, \ + but got `length(window)=$(length(window))`.")) + hop_length < 0 && throw(ArgumentError( + "Expected `hop_length > 0`, but got `hop_length=$hop_length`.")) + + # TODO check `y` eltype is complex + + n_frames = size(y, 2) + + # Pad window on both sides with `0` to `n_fft` length if needed. + if use_window && length(window) < n_fft + left = ((n_fft - length(window)) ÷ 2) + 1 + tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft) + tmp[left:left + length(window) - 1] .= window + window = tmp + end + + # Denormalize. + normalized && (y = y .* eltype(y)(n_fft^0.5);) + + region = 1 + x = return_complex ? ifft(y, region) : irfft(y, n_fft, region) + + # De-apply window. + use_window && (x = x ./ window;) + + # col2time. + expected_output_len = n_fft + hop_length * (n_frames - 1) + + ids = Vector{Int}(undef, expected_output_len) + in_idx, out_idx = 0, 0 + prev_e, v = 0, 0 + + for col in 0:(n_frames - 1) + for row in 1:n_fft + in_idx += 1 + v = row + hop_length * col + v > prev_e || continue + + out_idx += 1 + ids[out_idx] = in_idx + end + prev_e = v + end + + # In case of batched input, reshaped it (n_fft, n_frames, batch) -> (:, batch). + nd = ntuple(_ -> Colon(), ndims(x) - 2) + ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));) + x = x[ids, nd...] + + # Trim padding. + left = center ? (n_fft ÷ 2 + 1) : 1 + right = if isnothing(original_length) + center ? (size(x, 1) - n_fft ÷ 2) : expected_output_len + else + left + original_length - 1 + end + x = x[left:right, nd...] + return x +end diff --git a/src/NNlib.jl b/src/NNlib.jl index 7bf7bd172..73d5c4f56 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -126,4 +126,9 @@ include("deprecations.jl") include("rotation.jl") export imrotate, ∇imrotate +include("audio/stft.jl") +include("audio/spectrogram.jl") +include("audio/mel.jl") +export stft, istft, hann_window, hamming_window, spectrogram, melscale_filterbanks + end # module NNlib diff --git a/src/audio/mel.jl b/src/audio/mel.jl new file mode 100644 index 000000000..6fda9a091 --- /dev/null +++ b/src/audio/mel.jl @@ -0,0 +1,102 @@ +""" + melscale_filterbanks(; + n_freqs::Int, n_mels::Int, sample_rate::Int, + fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2)) + +Create triangular Mel scale filter banks +(ref: https://en.wikipedia.org/wiki/Mel_scale). +Each column is a filterbank that highlights its own frequency. + +# Arguments: + +- `n_freqs::Int`: Number of frequencies to highlight. +- `n_mels::Int`: Number of mel filterbanks. +- `sample_rate::Int`: Sample rate of the audio waveform. +- `fmin::Float32`: Minimum frequency in Hz. +- `fmax::Float32`: Maximum frequency in Hz. + +# Returns: + +Filterbank matrix of shape `(n_freqs, n_mels)` where each column is a filterbank. + +```jldoctest +julia> n_mels = 8; + +julia> fb = melscale_filterbanks(; n_freqs=200, n_mels, sample_rate=16000); + +julia> plot = lineplot(fb[:, 1]); + +julia> for i in 2:n_mels + lineplot!(plot, fb[:, i]) + end + +julia> plot + ┌────────────────────────────────────────┐ + 1 │⠀⡀⢸⠀⢸⠀⠀⣧⠀⠀⢸⡄⠀⠀⠀⣷⠀⠀⠀⠀⠀⣷⠀⠀⠀⠀⠀⠀⢀⣿⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⡇⢸⡆⢸⡇⠀⣿⠀⠀⡜⡇⠀⠀⢰⠋⡆⠀⠀⠀⢰⠁⡇⠀⠀⠀⠀⠀⡸⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⣿⢸⡇⡇⡇⢰⠹⡄⠀⡇⢱⠀⠀⢸⠀⢣⠀⠀⠀⡜⠀⢸⡀⠀⠀⠀⢀⠇⠀⠈⡇⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⣿⡇⡇⡇⡇⢸⠀⡇⢀⠇⠸⡀⠀⡇⠀⠸⡀⠀⢀⠇⠀⠀⢇⠀⠀⠀⡸⠀⠀⠀⠸⡄⠀⠀⠀⠀⠀⠀⠀│ + │⢠⢻⡇⡇⡇⢱⢸⠀⢇⢸⠀⠀⡇⢀⠇⠀⠀⡇⠀⢸⠀⠀⠀⠸⡀⠀⢠⠇⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀│ + │⢸⢸⡇⢱⡇⢸⡇⠀⢸⢸⠀⠀⢣⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⢇⠀⡜⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀⠀│ + │⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⡎⠀⠀⠀⠈⣶⠁⠀⠀⠀⠀⠸⣤⠃⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀⠀⠀│ + │⢸⠀⡇⢸⠀⠀⡇⠀⠀⡇⠀⠀⠀⡇⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⠀⠀⢱⡀⠀⠀⠀⠀│ + │⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⢇⠀⠀⠀⢀⠿⡀⠀⠀⠀⠀⢰⠛⡄⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀⠀⠀│ + │⢸⢸⡇⡸⡇⢸⡇⠀⢸⢸⠀⠀⡜⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⡎⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀│ + │⢸⢸⡇⡇⡇⡸⢸⠀⡎⢸⠀⠀⡇⠈⡆⠀⠀⡇⠀⢸⠀⠀⠀⢰⠁⠀⠘⡆⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠀⠀│ + │⡇⢸⡇⡇⡇⡇⢸⠀⡇⠈⡆⢰⠁⠀⡇⠀⢰⠁⠀⠈⡆⠀⠀⡎⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀│ + │⡇⢸⢸⡇⡇⡇⠸⣰⠃⠀⡇⡸⠀⠀⢸⠀⡜⠀⠀⠀⢣⠀⢸⠁⠀⠀⠀⠈⡆⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀│ + │⡇⡇⢸⠇⢸⡇⠀⣿⠀⠀⢣⡇⠀⠀⠸⣄⠇⠀⠀⠀⠸⡀⡇⠀⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄│ + 0 │⣇⣇⣸⣀⣸⣀⣀⣟⣀⣀⣸⣃⣀⣀⣀⣿⣀⣀⣀⣀⣀⣿⣀⣀⣀⣀⣀⣀⣈⣇⣀⣀⣀⣀⣀⣀⣀⣀⣀⣱│ + └────────────────────────────────────────┘ + ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀200⠀ +``` +""" +function melscale_filterbanks(; + n_freqs::Int, n_mels::Int, sample_rate::Int, + fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2), +) + mel_min, mel_max = _hz_to_mel(fmin), _hz_to_mel(fmax) + mel_points = range(mel_min, mel_max; length=n_mels + 2) + + all_freqs = collect(range(0f0, Float32(sample_rate ÷ 2); length=n_freqs)) + freq_points = _mel_to_hz.(mel_points) + filter_banks = _triangular_filterbanks(freq_points, all_freqs) + + if any(maximum(filter_banks; dims=1) .≈ 0f0) + @warn """At least one mel filterbank has all zero values. + The value for `n_mels=$n_mels` may be set too high. + Or the value for `n_freqs=$n_freqs` may be set too low. + """ + end + return filter_banks +end + +_hz_to_mel(freq::T) where T = T(2595) * log10(T(1) + (freq / T(700))) + +_mel_to_hz(mel::T) where T = T(700) * (T(10)^(mel / T(2595)) - T(1)) + +""" + _triangular_filterbanks( + freq_points::Vector{Float32}, all_freqs::Vector{Float32}) + +Create triangular filter banks. + +# Arguments: + +- `freq_points::Vector{Float32}`: Filter midpoints of size `n_filters`. +- `all_freqs::Vector{Float32}`: Frequency points of size `n_freqs`. + +# Returns: + +Array of size `(n_freqs, n_filters)`. +""" +function _triangular_filterbanks( + freq_points::Vector{Float32}, all_freqs::Vector{Float32}, +) + diff = @view(freq_points[2:end]) .- @view(freq_points[1:end - 1]) + slopes = transpose(reshape(freq_points, :, 1) .- reshape(all_freqs, 1, :)) + + down_slopes = -(@view(slopes[:, 1:end - 2]) ./ reshape(@view(diff[1:end - 1]), 1, :)) + up_slopes = @view(slopes[:, 3:end]) ./ reshape(@view(diff[2:end]), 1, :) + return max.(0f0, min.(down_slopes, up_slopes)) +end diff --git a/src/audio/spectrogram.jl b/src/audio/spectrogram.jl new file mode 100644 index 000000000..116b23cf4 --- /dev/null +++ b/src/audio/spectrogram.jl @@ -0,0 +1,79 @@ +""" + spectrogram(waveform; + pad::Int = 0, n_fft::Int, hop_length::Int, window, + center::Bool = true, power::Real = 2.0, + normalized::Bool = false, window_normalized::Bool = false, + ) + +Create a spectrogram or a batch of spectrograms from a raw audio signal. + +# Arguments + +- `pad::Int`: + Then amount of padding to apply on both sides. +- `window_normalized::Bool`: + Whether to normalize the waveform by the window’s L2 energy. +- `power::Real`: + Exponent for the magnitude spectrogram (must be ≥ 0) + e.g., `1` for magnitude, `2` for power, etc. + If `0`, complex spectrum is returned instead. + +See [`stft`](@ref) for other arguments. + +# Returns + +Spectrogram in the shape `(T, F, B)`, where +`T` is the number of window hops and `F = n_fft ÷ 2 + 1`. +""" +function spectrogram(waveform; + pad::Int = 0, n_fft::Int, hop_length::Int, window, + center::Bool = true, power::Real = 2.0, + normalized::Bool = false, window_normalized::Bool = false, +) + pad > 0 && (waveform = pad_zeros(waveform, pad; dims=1);) + + # Pack batch dimensions. + sz = size(waveform) + spec_ = stft(reshape(waveform, (sz[1], :)); + n_fft, hop_length, window, center, normalized) + # Unpack batch dimensions. + spec = reshape(spec_, (size(spec_)[1:2]..., sz[2:end]...)) + window_normalized && (spec = spec .* inv(norm(window));) + + if power > 0 + p = real(eltype(spec)(power)) + spec = abs.(spec).^p + end + return spec +end + +""" + power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0) + +Convert a power spectrogram (amplitude squared) to decibel (dB) units. + +# Arguments + +- `s`: Input power. +- `ref`: Scalar w.r.t. which the input is scaled. +- `amin`: Minimum threshold for `s`. +- `top_db`: Threshold the output at `top_db` below the peak: + `max.(s_db, maximum(s_db) - top_db)`. + +# Returns + +`s_db ~= 10 * log10(s) - 10 * log10(ref)` +""" +function power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0) + log_spec = 10f0 .* (log10.(max.(amin, s)) .- log10.(max.(amin, ref))) + return max.(log_spec, maximum(log_spec) - top_db) +end + +""" + db_to_power(s_db; ref::Real = 1f0) + +Inverse of [`power_to_db`](@ref). +""" +function db_to_power(s_db; ref::Real = 1f0) + return ref .* 10f0.^(s_db .* 0.1f0) +end diff --git a/src/audio/stft.jl b/src/audio/stft.jl new file mode 100644 index 000000000..a5b84cff0 --- /dev/null +++ b/src/audio/stft.jl @@ -0,0 +1,206 @@ +""" + hamming_window( + window_length::Int, ::Type{T} = Float32; periodic::Bool = true, + α::T = T(0.54), β::T = T(0.46), + ) where T <: Real + +Hamming window function +(ref: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows). +Generalized version of `hann_window`. + +``w[n] = \\alpha - \\beta cos(\\frac{2 \\pi n}{N - 1})`` + +Where ``N`` is the window length. + +```julia +julia> lineplot(hamming_window(100); width=30, height=10) + ┌──────────────────────────────┐ + 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠚⠉⠉⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠁⠀⠀⠀⠀⠀⠈⢢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⢰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⣠⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⡀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⢰⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⡰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀│ + │⠀⠀⠀⢀⠴⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀│ + │⠀⢀⡠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⣀⠀│ + 0 │⠉⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉│ + └──────────────────────────────┘ + ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀ +``` + +# Arguments: + +- `window_length::Int`: Size of the window. +- `::Type{T}`: Elemet type of the window. + +# Keyword Arguments: + +- `periodic::Bool`: If `true` (default), returns a window to be used as + periodic function. If `false`, return a symmetric window. + + Following always holds: + +```jldoctest +julia> N = 256; + +julia> hamming_window(N; periodic=true) ≈ hamming_window(N + 1; periodic=false)[1:end - 1] +true +``` +- `α::Real`: Coefficient α in the equation above. +- `β::Real`: Coefficient β in the equation above. + +# Returns: + +Vector of length `window_length` and eltype `T`. +""" +function hamming_window( + window_length::Int, ::Type{T} = Float32; periodic::Bool = true, + α::T = T(0.54), β::T = T(0.46), +) where T <: Real + window_length < 1 && throw(ArgumentError( + "`window_length` must be > 0, instead: `$window_length`.")) + + n::T = ifelse(periodic, window_length, window_length - 1) + scale = T(2) * π / n + return [α - β * cos(scale * T(k)) for k in 0:(window_length - 1)] +end + +""" + hann_window( + window_length::Int, ::Type{T} = Float32; periodic::Bool = true, + ) where T <: Real + +Hann window function +(ref: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows). + +``w[n] = \\frac{1}{2}[1 - cos(\\frac{2 \\pi n}{N - 1})]`` + +Where ``N`` is the window length. + +```julia +julia> lineplot(hann_window(100); width=30, height=10) + ┌──────────────────────────────┐ + 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠚⠉⠉⠉⠢⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡔⠁⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⢀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢣⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⢀⡜⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⢀⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀│ + │⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠣⡀⠀⠀│ + 0 │⣀⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢤⣀│ + └──────────────────────────────┘ + ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀ +``` + +# Arguments: + +- `window_length::Int`: Size of the window. +- `::Type{T}`: Elemet type of the window. + +# Keyword Arguments: + +- `periodic::Bool`: If `true` (default), returns a window to be used as + periodic function. If `false`, return a symmetric window. + + Following always holds: + +```jldoctest +julia> N = 256; + +julia> hann_window(N; periodic=true) ≈ hann_window(N + 1; periodic=false)[1:end - 1] +true + +julia> hann_window(N) ≈ hamming_window(N; α=0.5f0, β=0.5f0) +true +``` + +# Returns: + +Vector of length `window_length` and eltype `T`. +""" +function hann_window( + window_length::Int, ::Type{T} = Float32; periodic::Bool = true, +) where T <: Real + hamming_window(window_length, T; periodic, α=T(0.5), β=T(0.5)) +end + +""" + stft(x; + n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing, + center::Bool = true, normalized::Bool = false, + ) + +Short-time Fourier transform (STFT). + +The STFT computes the Fourier transform of short overlapping windows of the input, +giving frequency components of the signal as they change over time. + +``Y[\\omega, m] = \\sum_{k = 0}^{N - 1} \\text{window}[k] \\text{input}[m \\times \\text{hop length} + k] exp(-j \\frac{2 \\pi \\omega k}{\\text{n fft}})`` + +where ``N`` is the window length, +``\\omega`` is the frequency ``0 \\le \\omega < \\text{n fft}`` +and ``m`` is the index of the sliding window. + +# Arguments: + +- `x`: Input, must be either a 1D time sequence (`(L,)` shape) + or a 2D batch of time sequence (`(L, B)` shape). + +# Positional Arguments: + +- `n_fft::Int`: Size of Fourier transform. +- `hop_length::Int`: Distance between neighboring sliding window frames. +- `window`: Optional window function to apply. + Must be 1D vector `0 < length(window) ≤ n_fft`. + If window is shorter than `n_fft`, it is padded with zeros on both sides. + If `nothing` (default), then no window is applied. +- `center::Bool`: Whether to pad input on both sides so that ``t``-th frame + is centered at time ``t \\times \\text{hop length}``. + Padding is done with `pad_reflect` function. +- `normalized::Bool`: Whether to return normalized STFT, + i.e. multiplied with ``\\text{n fft}^{-0.5}``. + +# Returns: + +Complex array of shape `(n_fft, n_frames, B)`, +where `B` is the optional batch dimension. +""" +function stft end + +""" + istft(y; + n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing, + center::Bool = true, normalized::Bool = false, + return_complex::Bool = false, + original_length::Union{Nothing, Int} = nothing, + ) + +Inverse Short-time Fourier Transform. + +Return the least squares estimation of the original signal + +# Arguments: + +- `y`: Input complex array in the `(n_fft, n_frames, B)` shape. + Where `B` is the optional batch dimension. + +# Positional Arguments: + +- `n_fft::Int`: Size of Fourier transform. +- `hop_length::Int`: Distance between neighboring sliding window frames. +- `window`: Window function that was applied to the input of `stft`. + If `nothing` (default), then no window was applied. +- `center::Bool`: Whether input to `stft` was padded on both sides + so that ``t``-th frame is centered at time ``t \\times \\text{hop length}``. + Padding is done with `pad_reflect` function. +- `normalized::Bool`: Whether input to `stft` was normalized. +- `return_complex::Bool`: Whether the output should be complex, + or if the input should be assumed to derive from a real signal and window. +- `original_length::Union{Nothing, Int}`: Optional size of the first dimension + of the input to `stft`. Helps restoring the exact `stft` input size. + Otherwise, the array might be a bit shorter. +""" +function istft end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 000000000..b30ef0e13 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,24 @@ +[deps] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" +ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" diff --git a/test/runtests.jl b/test/runtests.jl index b79a8a011..77f841d31 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,7 +13,9 @@ using Adapt using ImageTransformations using Interpolations: Constant using KernelAbstractions +using FFTW import ReverseDiff as RD # used in `pooling.jl` +import Pkg const Test_Enzyme = VERSION <= v"1.10-" @@ -40,10 +42,11 @@ end cpu(x) = adapt(CPU(), x) -include("gather.jl") -include("scatter.jl") -include("upsample.jl") -include("rotation.jl") +include("testsuite/gather.jl") +include("testsuite/scatter.jl") +include("testsuite/upsample.jl") +include("testsuite/rotation.jl") +include("testsuite/spectral.jl") function nnlib_testsuite(Backend; skip_tests = Set{String}()) @conditional_testset "Upsample" skip_tests begin @@ -58,12 +61,15 @@ function nnlib_testsuite(Backend; skip_tests = Set{String}()) @conditional_testset "Scatter" skip_tests begin scatter_testsuite(Backend) end + @conditional_testset "Spectral" skip_tests begin + spectral_testsuite(Backend) + end end @testset verbose=true "NNlib.jl" begin if get(ENV, "NNLIB_TEST_CPU", "true") == "true" - @testset "CPU" begin + @testset "CPU" begin @testset "Doctests" begin doctest(NNlib, manual=false) end @@ -133,6 +139,8 @@ end end if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" + Pkg.add(["CUDA", "cuDNN"]) + using CUDA if CUDA.functional() @testset "CUDA" begin @@ -145,19 +153,20 @@ end end else @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them" - end + end if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" + Pkg.add("AMDGPU") + using AMDGPU AMDGPU.versioninfo() if AMDGPU.functional() && AMDGPU.functional(:MIOpen) - @show AMDGPU.MIOpen.version() @testset "AMDGPU" begin nnlib_testsuite(ROCBackend) - AMDGPU.synchronize(; blocking=false) + AMDGPU.synchronize(; blocking=false, stop_hostcalls=true) include("ext_amdgpu/runtests.jl") - AMDGPU.synchronize(; blocking=false) + AMDGPU.synchronize(; blocking=false, stop_hostcalls=true) end else @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." diff --git a/test/gather.jl b/test/testsuite/gather.jl similarity index 100% rename from test/gather.jl rename to test/testsuite/gather.jl diff --git a/test/rotation.jl b/test/testsuite/rotation.jl similarity index 100% rename from test/rotation.jl rename to test/testsuite/rotation.jl diff --git a/test/scatter.jl b/test/testsuite/scatter.jl similarity index 100% rename from test/scatter.jl rename to test/testsuite/scatter.jl diff --git a/test/testsuite/spectral.jl b/test/testsuite/spectral.jl new file mode 100644 index 000000000..12e38cc4a --- /dev/null +++ b/test/testsuite/spectral.jl @@ -0,0 +1,151 @@ +function spectral_testsuite(Backend) + cpu(x) = adapt(CPU(), x) + device(x) = adapt(Backend(), x) + gradtest_fn = Backend == CPU ? gradtest : gputest + + @testset "Window functions" begin + for window_fn in (hann_window, hamming_window) + @inferred window_fn(10, Float32) + @inferred window_fn(10, Float64) + + w = window_fn(10) + @test length(w) == 10 + @test eltype(w) == Float32 + + wp = window_fn(10; periodic=false) + @test wp[1:5] ≈ reverse(wp[6:10]) + + @test window_fn(10; periodic=true) ≈ window_fn(10 + 1; periodic=false)[1:10] + end + end + + @testset "STFT" for batch in ((), (3,)) + @testset "Grads" begin + if Backend != CPU + x = rand(Float32, 16, batch...) + window = hann_window(16) + + gradtest_fn(s -> abs.(stft(s; n_fft=16)), x) + gradtest_fn((s, w) -> abs.(stft(s; n_fft=16, window=w)), x, window) + + x = rand(Float32, 2045, batch...) + n_fft = 256 + window = hann_window(n_fft) + gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w)), x, window) + gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w, center=false)), x, window) + gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w, normalized=true)), x, window) + end + end + + @testset "Batch $batch" begin + x = device(ones(Float32, 16, batch...)) + # TODO fix type stability for pad_reflect + # @inferred stft(x; n_fft=16) + + bd = ntuple(_ -> Colon(), length(batch)) + + y = stft(x; n_fft=16) + @test size(y) == (9, 5, batch...) + @test all(real(cpu(y))[1, :, bd...] .≈ 16) + + xx = istft(y; n_fft=16) + @test size(xx) == (16, batch...) + @test cpu(x) ≈ cpu(xx) + + # Test multiple hops. + x = device(rand(Float32, 2048, batch...)) + y = stft(x; n_fft=1024) + xx = istft(y; n_fft=1024) + @test cpu(x) ≈ cpu(xx) + + # Test odd sizes. + x = device(rand(Float32, 1111, batch...)) + y = stft(x; n_fft=256) + xx = istft(y; n_fft=256, original_length=size(x, 1)) + @test cpu(x) ≈ cpu(xx) + + # Output from inverse is cropped on the right + # without knowing the original size. + xx = istft(y; n_fft=256) + @test length(xx) < length(x) + @test cpu(x)[[1:s for s in size(xx)]...] ≈ cpu(xx) + + # Test different options. + + # Normalized. + x = device(rand(Float32, 1234, batch...)) + y = stft(x; n_fft=512, normalized=true) + xx = istft(y; n_fft=512, normalized=true, original_length=size(x, 1)) + @test cpu(x) ≈ cpu(xx) + + # With window. + window = device(hann_window(512)) + y = stft(x; n_fft=512, window) + xx = istft(y; n_fft=512, window, original_length=size(x, 1)) + @test cpu(x) ≈ cpu(xx) + + # Hop. + for hop_length in (32, 33, 255, 256, 511, 512) + y = stft(x; n_fft=512, hop_length) + xx = istft(y; n_fft=512, hop_length, original_length=size(x, 1)) + @test cpu(x) ≈ cpu(xx) + end + + # N FFT. + for n_fft in (32, 33, 64, 65, 128, 129, 512) + y = stft(x; n_fft) + xx = istft(y; n_fft, original_length=size(x, 1)) + @test cpu(x) ≈ cpu(xx) + end + end + end + + @testset "Spectrogram" begin + x = device(rand(Float32, 1024)) + window = device(hann_window(1024)) + + y = stft(x; + n_fft=1024, hop_length=128, window, + center=true, normalized=false) + spec = spectrogram(x; + n_fft=1024, hop_length=128, window, + center=true, normalized=false) + + @test abs.(y).^2 ≈ spec + + # Batched. + x = device(rand(Float32, 1024, 3)) + spec = spectrogram(x; + n_fft=1024, hop_length=128, window, + center=true, normalized=false) + for i in 1:3 + y = stft(x[:, i]; + n_fft=1024, hop_length=128, window, + center=true, normalized=false) + @test abs.(y).^2 ≈ spec[:, :, i] + end + + if Backend != CPU + @testset "Grads" begin + for batch in ((), (3,)) + x = rand(Float32, 2045, batch...) + n_fft = 256 + window = hann_window(n_fft) + gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w), x, window) + gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w, center=false), x, window) + gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w, normalized=true), x, window) + end + end + end + end + + @testset "Power to dB" begin + x = device(rand(Float32, 1024)) + window = device(hann_window(1024)) + spec = spectrogram(x; pad=0, n_fft=1024, hop_length=128, window) + + @test spec ≈ NNlib.db_to_power(NNlib.power_to_db(spec)) + @inferred NNlib.power_to_db(spec) + @inferred NNlib.db_to_power(NNlib.power_to_db(spec)) + end +end diff --git a/test/upsample.jl b/test/testsuite/upsample.jl similarity index 100% rename from test/upsample.jl rename to test/testsuite/upsample.jl