-
Notifications
You must be signed in to change notification settings - Fork 618
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
97 changed files
with
30,990 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# PicoAudio: Enabling Precise Timing and Frequency Controllability of Audio Events in Text-to-audio Generation | ||
Duplicate of [github repo](https://github.com/zeyuxie29/PicoAudio) | ||
[![arXiv](https://img.shields.io/badge/arXiv-2407.02869v2-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2407.02869v2) | ||
[![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://zeyuxie29.github.io/PicoAudio.github.io/) | ||
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/ZeyuXie/PicoAudio) | ||
[![Hugging Face data](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Dataset-blue)](https://huggingface.co/datasets/ZeyuXie/PicoAudio/tree/main) | ||
|
||
**Bullet contribution**: | ||
* A data simulation pipeline tailored specifically for controllable audio generation frameworks; | ||
* Propose a timing-controllable audio generation framework, enabling precise control over the timing and frequency of sound event; | ||
* Achieve any precise control related to timing by integrating of large language models. | ||
|
||
## Inference | ||
You can see the demo on the website [Huggingface Online Inference](https://huggingface.co/spaces/ZeyuXie/PicoAudio) and [Github Demo](https://zeyuxie29.github.io/PicoAudio.github.io). | ||
Or you can use the *"inference.py"* script provided by website [Huggingface Inference](https://huggingface.co/spaces/ZeyuXie/PicoAudio/tree/main) to generate. | ||
Huggingface Online Inference uses Gemini as a preprocessor, and we also provide a GPT preprocessing script consistent with the paper in *"llm_preprocess.py"* | ||
|
||
## Simulated Dataset | ||
Simulated data can be downloaded from (1) [HuggingfaceDataset](https://huggingface.co/datasets/ZeyuXie/PicoAudio/tree/main) or (2) [BaiduNetDisk](https://pan.baidu.com/s/1rGrcjtQCEYFpr3o6y9wI8Q?pwd=pico) with the extraction code "pico". | ||
The metadata is stored in *"data/meta_data/{}.json"*, one instance is as follows: | ||
```python | ||
{ | ||
"filepath": "data/multi_event_test/syn_1.wav", | ||
"onoffCaption": "cat meowing at 0.5-2.0, 3.0-4.5 and whistling at 5.0-6.5 and explosion at 7.0-8.0, 8.5-9.5", | ||
"frequencyCaption": "cat meowing two times and whistling one times and explosion two times" | ||
} | ||
``` | ||
where: | ||
* *"filepath"* indicates the path to the audio file. | ||
* *"frequencyCaption"* contains information about the occurrence frequency. | ||
* *"onoffCaption"* contains on- & off-set information. | ||
* For test file *"test-frequency-control_onoffFromGpt_{}.json"*, the *"onoffCaption"* is derived from *"frequencyCaption"* transformed by GPT-4, which is used for evaluation in the frequency control task. | ||
|
||
## Training | ||
Download data into the *"data"* folder. | ||
The training and inference code can be found in the *"picoaudio"* folder. | ||
```shell | ||
cd picoaudio | ||
pip install -r requirements.txt | ||
``` | ||
To start traning: | ||
```python | ||
accelerate launch runner/controllable_train.py | ||
``` | ||
|
||
## Acknowledgement | ||
Our code referred to the [AudioLDM](https://github.com/haoheliu/AudioLDM) and [Tango](https://github.com/declare-lab/tango). We appreciate their open-sourcing of their code. | ||
|
||
<!-- | ||
### Hi there 👋 | ||
**PicoAudio/PicoAudio** is a ✨ _special_ ✨ repository because its `README.md` (this file) appears on your GitHub profile. | ||
Here are some ideas to get you started: | ||
- 🔭 I’m currently working on ... | ||
- 🌱 I’m currently learning ... | ||
- 👯 I’m looking to collaborate on ... | ||
- 🤔 I’m looking for help with ... | ||
- 💬 Ask me about ... | ||
- 📫 How to reach me: ... | ||
- 😄 Pronouns: ... | ||
- ⚡ Fun fact: ... | ||
--> |
200 changes: 200 additions & 0 deletions
200
...ally_controllable_tta/data/meta_data/test-frequency-control_onoffFromGpt_multi-event.json
Large diffs are not rendered by default.
Oops, something went wrong.
400 changes: 400 additions & 0 deletions
400
...lly_controllable_tta/data/meta_data/test-frequency-control_onoffFromGpt_single-event.json
Large diffs are not rendered by default.
Oops, something went wrong.
200 changes: 200 additions & 0 deletions
200
models/temporally_controllable_tta/data/meta_data/test-onoff-control_multi-event.json
Large diffs are not rendered by default.
Oops, something went wrong.
400 changes: 400 additions & 0 deletions
400
models/temporally_controllable_tta/data/meta_data/test-onoff-control_single-event.json
Large diffs are not rendered by default.
Oops, something went wrong.
5,000 changes: 5,000 additions & 0 deletions
5,000
models/temporally_controllable_tta/data/meta_data/train.json
Large diffs are not rendered by default.
Oops, something went wrong.
8 changes: 8 additions & 0 deletions
8
models/temporally_controllable_tta/picoaudio/audioldm/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from .ldm import LatentDiffusion | ||
from .utils import seed_everything, save_wave, get_time, get_duration | ||
from .pipeline import * | ||
|
||
|
||
|
||
|
||
|
183 changes: 183 additions & 0 deletions
183
models/temporally_controllable_tta/picoaudio/audioldm/__main__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
#!/usr/bin/python3 | ||
import os | ||
from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration | ||
import argparse | ||
|
||
CACHE_DIR = os.getenv( | ||
"AUDIOLDM_CACHE_DIR", | ||
os.path.join(os.path.expanduser("~"), ".cache/audioldm")) | ||
|
||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument( | ||
"--mode", | ||
type=str, | ||
required=False, | ||
default="generation", | ||
help="generation: text-to-audio generation; transfer: style transfer", | ||
choices=["generation", "transfer"] | ||
) | ||
|
||
parser.add_argument( | ||
"-t", | ||
"--text", | ||
type=str, | ||
required=False, | ||
default="", | ||
help="Text prompt to the model for audio generation", | ||
) | ||
|
||
parser.add_argument( | ||
"-f", | ||
"--file_path", | ||
type=str, | ||
required=False, | ||
default=None, | ||
help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio", | ||
) | ||
|
||
parser.add_argument( | ||
"--transfer_strength", | ||
type=float, | ||
required=False, | ||
default=0.5, | ||
help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text", | ||
) | ||
|
||
parser.add_argument( | ||
"-s", | ||
"--save_path", | ||
type=str, | ||
required=False, | ||
help="The path to save model output", | ||
default="./output", | ||
) | ||
|
||
parser.add_argument( | ||
"--model_name", | ||
type=str, | ||
required=False, | ||
help="The checkpoint you gonna use", | ||
default="audioldm-s-full", | ||
choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2"] | ||
) | ||
|
||
parser.add_argument( | ||
"-ckpt", | ||
"--ckpt_path", | ||
type=str, | ||
required=False, | ||
help="The path to the pretrained .ckpt model", | ||
default=None, | ||
) | ||
|
||
parser.add_argument( | ||
"-b", | ||
"--batchsize", | ||
type=int, | ||
required=False, | ||
default=1, | ||
help="Generate how many samples at the same time", | ||
) | ||
|
||
parser.add_argument( | ||
"--ddim_steps", | ||
type=int, | ||
required=False, | ||
default=200, | ||
help="The sampling step for DDIM", | ||
) | ||
|
||
parser.add_argument( | ||
"-gs", | ||
"--guidance_scale", | ||
type=float, | ||
required=False, | ||
default=2.5, | ||
help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)", | ||
) | ||
|
||
parser.add_argument( | ||
"-dur", | ||
"--duration", | ||
type=float, | ||
required=False, | ||
default=10.0, | ||
help="The duration of the samples", | ||
) | ||
|
||
parser.add_argument( | ||
"-n", | ||
"--n_candidate_gen_per_text", | ||
type=int, | ||
required=False, | ||
default=3, | ||
help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation", | ||
) | ||
|
||
parser.add_argument( | ||
"--seed", | ||
type=int, | ||
required=False, | ||
default=42, | ||
help="Change this value (any integer number) will lead to a different generation result.", | ||
) | ||
|
||
args = parser.parse_args() | ||
|
||
if(args.ckpt_path is not None): | ||
print("Warning: ckpt_path has no effect after version 0.0.20.") | ||
|
||
assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5" | ||
|
||
mode = args.mode | ||
if(mode == "generation" and args.file_path is not None): | ||
mode = "generation_audio_to_audio" | ||
if(len(args.text) > 0): | ||
print("Warning: You have specified the --file_path. --text will be ignored") | ||
args.text = "" | ||
|
||
save_path = os.path.join(args.save_path, mode) | ||
|
||
if(args.file_path is not None): | ||
save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0])) | ||
|
||
text = args.text | ||
random_seed = args.seed | ||
duration = args.duration | ||
guidance_scale = args.guidance_scale | ||
n_candidate_gen_per_text = args.n_candidate_gen_per_text | ||
|
||
os.makedirs(save_path, exist_ok=True) | ||
audioldm = build_model(model_name=args.model_name) | ||
|
||
if(args.mode == "generation"): | ||
waveform = text_to_audio( | ||
audioldm, | ||
text, | ||
args.file_path, | ||
random_seed, | ||
duration=duration, | ||
guidance_scale=guidance_scale, | ||
ddim_steps=args.ddim_steps, | ||
n_candidate_gen_per_text=n_candidate_gen_per_text, | ||
batchsize=args.batchsize, | ||
) | ||
|
||
elif(args.mode == "transfer"): | ||
assert args.file_path is not None | ||
assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path | ||
waveform = style_transfer( | ||
audioldm, | ||
text, | ||
args.file_path, | ||
args.transfer_strength, | ||
random_seed, | ||
duration=duration, | ||
guidance_scale=guidance_scale, | ||
ddim_steps=args.ddim_steps, | ||
batchsize=args.batchsize, | ||
) | ||
waveform = waveform[:,None,:] | ||
|
||
save_wave(waveform, save_path, name="%s_%s" % (get_time(), text)) |
2 changes: 2 additions & 0 deletions
2
models/temporally_controllable_tta/picoaudio/audioldm/audio/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .tools import wav_to_fbank, read_wav_file | ||
from .stft import TacotronSTFT |
100 changes: 100 additions & 0 deletions
100
models/temporally_controllable_tta/picoaudio/audioldm/audio/audio_processing.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import torch | ||
import numpy as np | ||
import librosa.util as librosa_util | ||
from scipy.signal import get_window | ||
|
||
|
||
def window_sumsquare( | ||
window, | ||
n_frames, | ||
hop_length, | ||
win_length, | ||
n_fft, | ||
dtype=np.float32, | ||
norm=None, | ||
): | ||
""" | ||
# from librosa 0.6 | ||
Compute the sum-square envelope of a window function at a given hop length. | ||
This is used to estimate modulation effects induced by windowing | ||
observations in short-time fourier transforms. | ||
Parameters | ||
---------- | ||
window : string, tuple, number, callable, or list-like | ||
Window specification, as in `get_window` | ||
n_frames : int > 0 | ||
The number of analysis frames | ||
hop_length : int > 0 | ||
The number of samples to advance between frames | ||
win_length : [optional] | ||
The length of the window function. By default, this matches `n_fft`. | ||
n_fft : int > 0 | ||
The length of each analysis frame. | ||
dtype : np.dtype | ||
The data type of the output | ||
Returns | ||
------- | ||
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` | ||
The sum-squared envelope of the window function | ||
""" | ||
if win_length is None: | ||
win_length = n_fft | ||
|
||
n = n_fft + hop_length * (n_frames - 1) | ||
x = np.zeros(n, dtype=dtype) | ||
|
||
# Compute the squared window at the desired length | ||
win_sq = get_window(window, win_length, fftbins=True) | ||
win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 | ||
win_sq = librosa_util.pad_center(win_sq, n_fft) | ||
|
||
# Fill the envelope | ||
for i in range(n_frames): | ||
sample = i * hop_length | ||
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] | ||
return x | ||
|
||
|
||
def griffin_lim(magnitudes, stft_fn, n_iters=30): | ||
""" | ||
PARAMS | ||
------ | ||
magnitudes: spectrogram magnitudes | ||
stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods | ||
""" | ||
|
||
angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) | ||
angles = angles.astype(np.float32) | ||
angles = torch.autograd.Variable(torch.from_numpy(angles)) | ||
signal = stft_fn.inverse(magnitudes, angles).squeeze(1) | ||
|
||
for i in range(n_iters): | ||
_, angles = stft_fn.transform(signal) | ||
signal = stft_fn.inverse(magnitudes, angles).squeeze(1) | ||
return signal | ||
|
||
|
||
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): | ||
""" | ||
PARAMS | ||
------ | ||
C: compression factor | ||
""" | ||
return normalize_fun(torch.clamp(x, min=clip_val) * C) | ||
|
||
|
||
def dynamic_range_decompression(x, C=1): | ||
""" | ||
PARAMS | ||
------ | ||
C: compression factor used to compress | ||
""" | ||
return torch.exp(x) / C |
Oops, something went wrong.