Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Short-time Fourier transform and its inverse #587

Merged
merged 13 commits into from
Jul 4, 2024
8 changes: 0 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,6 @@ jobs:
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()'
- run: |
julia --color=yes --project=docs/ -e '
using NNlib
# using Pkg; Pkg.activate("docs")
using Documenter
using Documenter: doctest
DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive=true)
doctest(NNlib)'
- run: julia --project=docs docs/make.jl
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
26 changes: 2 additions & 24 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.9.17"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this a weak dependence and move any function relying on FFTW under an extension?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -27,7 +28,7 @@ NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"

[compat]
AMDGPU = "0.8, 0.9"
AMDGPU = "0.9.4"
Adapt = "3.2, 4"
Atomix = "0.1"
CUDA = "4, 5"
Expand All @@ -42,26 +43,3 @@ Requires = "1.0"
Statistics = "1"
cuDNN = "1"
julia = "1.9"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeCore", "EnzymeTestUtils", "Interpolations", "ImageTransformations"]
5 changes: 5 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FLAC = "abae9e3b-a9a0-4778-b5c6-ca109b507d99"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
26 changes: 14 additions & 12 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
using Documenter, NNlib

DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive = true)
DocMeta.setdocmeta!(NNlib, :DocTestSetup,
:(using NNlib, UnicodePlots); recursive = true)

makedocs(modules = [NNlib],
sitename = "NNlib.jl",
doctest = false,
pages = ["Home" => "index.md",
"Reference" => "reference.md"],
format = Documenter.HTML(
canonical = "https://fluxml.ai/NNlib.jl/stable/",
# analytics = "UA-36890222-9",
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"),
warnonly=[:missing_docs,]
)
sitename = "NNlib.jl",
doctest = true,
pages = ["Home" => "index.md",
"Reference" => "reference.md",
"Audio" => "audio.md"],
format = Documenter.HTML(
canonical = "https://fluxml.ai/NNlib.jl/stable/",
# analytics = "UA-36890222-9",
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"),
warnonly=[:missing_docs,]
)

deploydocs(repo = "github.com/FluxML/NNlib.jl.git",
target = "build",
Expand Down
Binary file added docs/src/assets/jfk.flac
Binary file not shown.
57 changes: 57 additions & 0 deletions docs/src/audio.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Reference

## Window functions

```@docs
hann_window
hamming_window
```

## Spectral

```@docs
stft
istft
NNlib.power_to_db
NNlib.db_to_power
```

## Spectrogram

```@docs
melscale_filterbanks
spectrogram
```

Example:

```@example 1
using NNlib
using FileIO
using Makie, CairoMakie
CairoMakie.activate!()

waveform, sampling_rate = load("./assets/jfk.flac")
fig = lines(reshape(waveform, :))
save("waveform.png", fig)

# Spectrogram.

n_fft = 1024
spec = spectrogram(waveform; n_fft, hop_length=n_fft ÷ 4, window=hann_window(n_fft))
fig = heatmap(transpose(NNlib.power_to_db(spec)[:, :, 1]))
save("spectrogram.png", fig)

# Mel-scale spectrogram.

n_freqs = n_fft ÷ 2 + 1
fb = melscale_filterbanks(; n_freqs, n_mels=128, sample_rate=Int(sampling_rate))
mel_spec = permutedims(spec, (2, 1, 3)) ⊠ fb # (time, n_mels)
fig = heatmap(NNlib.power_to_db(mel_spec)[:, :, 1])
save("mel-spectrogram.png", fig)
nothing # hide
```

|Waveform|Spectrogram|Mel Spectrogram|
|:---:|:---:|:---:|
|![](waveform.png)|![](spectrogram.png)|![](mel-spectrogram.png)|
6 changes: 6 additions & 0 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using Random
using Requires
using Statistics
using Statistics: mean
using FFTW

const libblas = Base.libblas_name

Expand Down Expand Up @@ -126,4 +127,9 @@ include("deprecations.jl")
include("rotation.jl")
export imrotate, ∇imrotate

include("audio/stft.jl")
include("audio/spectrogram.jl")
include("audio/mel.jl")
export stft, istft, hann_window, hamming_window, spectrogram, melscale_filterbanks

end # module NNlib
102 changes: 102 additions & 0 deletions src/audio/mel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
melscale_filterbanks(;
n_freqs::Int, n_mels::Int, sample_rate::Int,
fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2))

Create triangular Mel scale filter banks
(ref: https://en.wikipedia.org/wiki/Mel_scale).
Each column is a filterbank that highlights its own frequency.

# Arguments:

- `n_freqs::Int`: Number of frequencies to highlight.
- `n_mels::Int`: Number of mel filterbanks.
- `sample_rate::Int`: Sample rate of the audio waveform.
- `fmin::Float32`: Minimum frequency in Hz.
- `fmax::Float32`: Maximum frequency in Hz.

# Returns:

Filterbank matrix of shape `(n_freqs, n_mels)` where each column is a filterbank.

```jldoctest
julia> n_mels = 8;

julia> fb = melscale_filterbanks(; n_freqs=200, n_mels, sample_rate=16000);

julia> plot = lineplot(fb[:, 1]);

julia> for i in 2:n_mels
lineplot!(plot, fb[:, i])
end

julia> plot
┌────────────────────────────────────────┐
1 │⠀⡀⢸⠀⢸⠀⠀⣧⠀⠀⢸⡄⠀⠀⠀⣷⠀⠀⠀⠀⠀⣷⠀⠀⠀⠀⠀⠀⢀⣿⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⡇⢸⡆⢸⡇⠀⣿⠀⠀⡜⡇⠀⠀⢰⠋⡆⠀⠀⠀⢰⠁⡇⠀⠀⠀⠀⠀⡸⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⣿⢸⡇⡇⡇⢰⠹⡄⠀⡇⢱⠀⠀⢸⠀⢣⠀⠀⠀⡜⠀⢸⡀⠀⠀⠀⢀⠇⠀⠈⡇⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⣿⡇⡇⡇⡇⢸⠀⡇⢀⠇⠸⡀⠀⡇⠀⠸⡀⠀⢀⠇⠀⠀⢇⠀⠀⠀⡸⠀⠀⠀⠸⡄⠀⠀⠀⠀⠀⠀⠀│
│⢠⢻⡇⡇⡇⢱⢸⠀⢇⢸⠀⠀⡇⢀⠇⠀⠀⡇⠀⢸⠀⠀⠀⠸⡀⠀⢠⠇⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀│
│⢸⢸⡇⢱⡇⢸⡇⠀⢸⢸⠀⠀⢣⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⢇⠀⡜⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀⠀│
│⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⡎⠀⠀⠀⠈⣶⠁⠀⠀⠀⠀⠸⣤⠃⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀⠀⠀│
│⢸⠀⡇⢸⠀⠀⡇⠀⠀⡇⠀⠀⠀⡇⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⠀⠀⢱⡀⠀⠀⠀⠀│
│⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⢇⠀⠀⠀⢀⠿⡀⠀⠀⠀⠀⢰⠛⡄⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀⠀⠀│
│⢸⢸⡇⡸⡇⢸⡇⠀⢸⢸⠀⠀⡜⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⡎⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀│
│⢸⢸⡇⡇⡇⡸⢸⠀⡎⢸⠀⠀⡇⠈⡆⠀⠀⡇⠀⢸⠀⠀⠀⢰⠁⠀⠘⡆⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠀⠀│
│⡇⢸⡇⡇⡇⡇⢸⠀⡇⠈⡆⢰⠁⠀⡇⠀⢰⠁⠀⠈⡆⠀⠀⡎⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀│
│⡇⢸⢸⡇⡇⡇⠸⣰⠃⠀⡇⡸⠀⠀⢸⠀⡜⠀⠀⠀⢣⠀⢸⠁⠀⠀⠀⠈⡆⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀│
│⡇⡇⢸⠇⢸⡇⠀⣿⠀⠀⢣⡇⠀⠀⠸⣄⠇⠀⠀⠀⠸⡀⡇⠀⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄│
0 │⣇⣇⣸⣀⣸⣀⣀⣟⣀⣀⣸⣃⣀⣀⣀⣿⣀⣀⣀⣀⣀⣿⣀⣀⣀⣀⣀⣀⣈⣇⣀⣀⣀⣀⣀⣀⣀⣀⣀⣱│
└────────────────────────────────────────┘
⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀200⠀
```
"""
function melscale_filterbanks(;
n_freqs::Int, n_mels::Int, sample_rate::Int,
fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2),
)
mel_min, mel_max = _hz_to_mel(fmin), _hz_to_mel(fmax)
mel_points = range(mel_min, mel_max; length=n_mels + 2)

all_freqs = collect(range(0f0, Float32(sample_rate ÷ 2); length=n_freqs))
freq_points = _mel_to_hz.(mel_points)
filter_banks = _triangular_filterbanks(freq_points, all_freqs)

if any(maximum(filter_banks; dims=1) .≈ 0f0)
@warn """At least one mel filterbank has all zero values.

Check warning on line 66 in src/audio/mel.jl

View check run for this annotation

Codecov / codecov/patch

src/audio/mel.jl#L66

Added line #L66 was not covered by tests
The value for `n_mels=$n_mels` may be set too high.
Or the value for `n_freqs=$n_freqs` may be set too low.
"""
end
return filter_banks
end

_hz_to_mel(freq::T) where T = T(2595) * log10(T(1) + (freq / T(700)))

_mel_to_hz(mel::T) where T = T(700) * (T(10)^(mel / T(2595)) - T(1))

"""
_triangular_filterbanks(
freq_points::Vector{Float32}, all_freqs::Vector{Float32})

Create triangular filter banks.

# Arguments:

- `freq_points::Vector{Float32}`: Filter midpoints of size `n_filters`.
- `all_freqs::Vector{Float32}`: Frequency points of size `n_freqs`.

# Returns:

Array of size `(n_freqs, n_filters)`.
"""
function _triangular_filterbanks(
freq_points::Vector{Float32}, all_freqs::Vector{Float32},
)
diff = @view(freq_points[2:end]) .- @view(freq_points[1:end - 1])
slopes = transpose(reshape(freq_points, :, 1) .- reshape(all_freqs, 1, :))

down_slopes = -(@view(slopes[:, 1:end - 2]) ./ reshape(@view(diff[1:end - 1]), 1, :))
up_slopes = @view(slopes[:, 3:end]) ./ reshape(@view(diff[2:end]), 1, :)
return max.(0f0, min.(down_slopes, up_slopes))
end
79 changes: 79 additions & 0 deletions src/audio/spectrogram.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
spectrogram(waveform;
pad::Int = 0, n_fft::Int, hop_length::Int, window,
center::Bool = true, power::Real = 2.0,
normalized::Bool = false, window_normalized::Bool = false,
)

pxl-th marked this conversation as resolved.
Show resolved Hide resolved
Create a spectrogram or a batch of spectrograms from a raw audio signal.

# Arguments

- `pad::Int`:
Then amount of padding to apply on both sides.
- `window_normalized::Bool`:
Whether to normalize the waveform by the window’s L2 energy.
- `power::Real`:
Exponent for the magnitude spectrogram (must be ≥ 0)
e.g., `1` for magnitude, `2` for power, etc.
If `0`, complex spectrum is returned instead.

See [`stft`](@ref) for other arguments.

# Returns

Spectrogram in the shape `(T, F, B)`, where
`T` is the number of window hops and `F = n_fft ÷ 2 + 1`.
"""
function spectrogram(waveform;
pad::Int = 0, n_fft::Int, hop_length::Int, window,
center::Bool = true, power::Real = 2.0,
normalized::Bool = false, window_normalized::Bool = false,
)
pad > 0 && (waveform = pad_zeros(waveform, pad; dims=1);)

# Pack batch dimensions.
sz = size(waveform)
spec_ = stft(reshape(waveform, (sz[1], :));
n_fft, hop_length, window, center, normalized)
# Unpack batch dimensions.
spec = reshape(spec_, (size(spec_)[1:2]..., sz[2:end]...))
window_normalized && (spec = spec .* inv(norm(window));)

if power > 0
p = real(eltype(spec)(power))
spec = abs.(spec).^p
end
return spec
end

"""
power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)

Convert a power spectrogram (amplitude squared) to decibel (dB) units.

# Arguments

- `s`: Input power.
- `ref`: Scalar w.r.t. which the input is scaled.
- `amin`: Minimum threshold for `s`.
- `top_db`: Threshold the output at `top_db` below the peak:
`max.(s_db, maximum(s_db) - top_db)`.

# Returns

`s_db ~= 10 * log10(s) - 10 * log10(ref)`
"""
function power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)
log_spec = 10f0 .* (log10.(max.(amin, s)) .- log10.(max.(amin, ref)))
return max.(log_spec, maximum(log_spec) - top_db)
end

"""
db_to_power(s_db; ref::Real = 1f0)

Inverse of [`power_to_db`](@ref).
"""
function db_to_power(s_db; ref::Real = 1f0)
return ref .* 10f0.^(s_db .* 0.1f0)
end
Loading
Loading