From b040d93108cac4e6c50769b511eb2e20cdc991de Mon Sep 17 00:00:00 2001
From: aboombadev
Date: Thu, 3 Oct 2024 13:42:49 -0700
Subject: [PATCH] make usageparams a class

---
 .gitignore                                 |  2 +-
 paper/paper.md                             |  2 ++
 .../audio_processing_validation.py         | 13 ++++----
 src/data_processing/encode_audio_data.py   |  9 +++---
 src/generate.py                            |  2 +-
 src/stereo_sample_gan.py                   |  7 +++--
 src/usage_params.py                        | 30 +++++++++++--------
 src/utils/file_helpers.py                  |  7 +++--
 src/utils/generation_helpers.py            | 26 ++++++++--------
 src/utils/signal_helpers.py                | 24 +++++++--------
 10 files changed, 70 insertions(+), 52 deletions(-)

diff --git a/.gitignore b/.gitignore
index 5c6dbd5..5492b88 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,7 +165,7 @@ cython_debug/
 data/
 .vscode/
 .DS_Store
-outputs/training_progress
+outputs/spectrogram_images
 outputs/generated_audio.wav
 outputs/generated_audio_[0-9]*
 test.wav
\ No newline at end of file
diff --git a/paper/paper.md b/paper/paper.md
index b36ba6a..f29eba3 100644
--- a/paper/paper.md
+++ b/paper/paper.md
@@ -6,6 +6,8 @@ Continuation of UCLA COSMOS 2024 Research
 
 ## 1. Abstract
 
+Existing convolutional approaches to audio generation are often limited to producing low-fidelity, single-channel, monophonic audio while demanding significant computational resources for both training and inference. To address these challenges, this work introduces StereoSampleGAN, a novel audio generation architecture that combines a Deep Convolutional Wasserstein GAN (WGAN), attention mechanisms, and loss optimization techniques. StereoSampleGAN enables high-fidelity stereo audio generation for audio samples while remaining computationally efficient. Trained on three distinct sample datasets with varying spectral overlap (two of kick drums and one of tonal one-shots), StereoSampleGAN demonstrates promising results in generating high-quality, simple stereo sounds. However, the model displays notable limitations when generating audio structures with greater spectral variation, indicating areas for further improvement.
+
 ## 2. Introduction
 
 ## 3. Data Manipulation
diff --git a/src/data_processing/audio_processing_validation.py b/src/data_processing/audio_processing_validation.py
index 06a2b5e..096c190 100644
--- a/src/data_processing/audio_processing_validation.py
+++ b/src/data_processing/audio_processing_validation.py
@@ -4,21 +4,24 @@
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 import random
-from usage_params import compiled_data_path, training_sample_length
+from usage_params import UsageParams
 from utils.signal_helpers import (
     stft_and_istft,
 )
 
+# Initialize sample selection
+params = UsageParams()
+
 
 def choose_random_sample():
     audio_files = [
         f
-        for f in os.listdir(compiled_data_path)
-        if os.path.isfile(os.path.join(compiled_data_path, f))
+        for f in os.listdir(params.compiled_data_path)
+        if os.path.isfile(os.path.join(params.compiled_data_path, f))
     ]
     if audio_files:
         sample_name = random.choice(audio_files)
-        sample_path = os.path.join(compiled_data_path, sample_name)
+        sample_path = os.path.join(params.compiled_data_path, sample_name)
         return sample_path, sample_name
     else:
         return None, None
@@ -27,4 +30,4 @@ def choose_random_sample():
 
 # Analyze fourier transform audio degradation
 sample_path, sample_name = choose_random_sample()
-stft_and_istft(sample_path, "test", training_sample_length)
+stft_and_istft(sample_path, "test", params.training_sample_length)
diff --git a/src/data_processing/encode_audio_data.py b/src/data_processing/encode_audio_data.py
index 7167e01..ed5c389 100644
--- a/src/data_processing/encode_audio_data.py
+++ b/src/data_processing/encode_audio_data.py
@@ -8,18 +8,19 @@
     load_loudness_data,
 )
 from utils.signal_helpers import encode_sample_directory
-from usage_params import training_audio_dir, compiled_data_path
+from usage_params import UsageParams
 
 
 # Encode audio samples
+params = UsageParams()
 if len(sys.argv) > 1:
     visualize = sys.argv[1].lower() == "visualize"
 else:
     visualize = False
 
-encode_sample_directory(training_audio_dir, compiled_data_path, visualize)
+encode_sample_directory(params.training_audio_dir, params.compiled_data_path, visualize)
 
 real_data = load_loudness_data(
-    compiled_data_path
+    params.compiled_data_path
 )  # datapts, channels, frames, freq bins
-print(f"{training_audio_dir} data shape: {str(real_data.shape)}")
+print(f"{params.training_audio_dir} data shape: {str(real_data.shape)}")
diff --git a/src/generate.py b/src/generate.py
index 446db5f..244e3ac 100644
--- a/src/generate.py
+++ b/src/generate.py
@@ -6,4 +6,4 @@
 
 
 # Generate based on usage_params
-generate_audio(model_to_generate_with, training_sample_length)
+generate_audio(model_to_generate_with, training_sample_length, True)
diff --git a/src/stereo_sample_gan.py b/src/stereo_sample_gan.py
index 4d107ca..e3c841b 100644
--- a/src/stereo_sample_gan.py
+++ b/src/stereo_sample_gan.py
@@ -10,20 +10,23 @@
     load_loudness_data,
 )
 
-from usage_params import compiled_data_path
+from usage_params import UsageParams
 
 # Constants
 LR_G = 0.003
 LR_C = 0.004
 
 # Load data
-all_spectrograms = load_loudness_data(compiled_data_path)
+params = UsageParams()
+all_spectrograms = load_loudness_data(params.compiled_data_path)
 all_spectrograms = torch.FloatTensor(all_spectrograms)
+
 train_size = int(0.8 * len(all_spectrograms))
 val_size = len(all_spectrograms) - train_size
 train_dataset, val_dataset = random_split(
     TensorDataset(all_spectrograms), [train_size, val_size]
 )
+
 train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
 val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
diff --git a/src/usage_params.py b/src/usage_params.py
index a103673..a4520e2 100644
--- a/src/usage_params.py
+++ b/src/usage_params.py
@@ -1,16 +1,22 @@
 # Main params
-audio_generation_count = 2  # Audio examples to generate
+class UsageParams:
+    def __init__(self):
+        self.audio_generation_count = 2  # Audio examples to generate
 
-# Training params
-training_sample_length = 1.5  # seconds
-outputs_dir = "outputs"  # Where to save your generated audio & model
+        # Training params
+        self.training_sample_length = 1.5  # seconds
+        self.outputs_dir = "outputs"  # Where to save your generated audio & model
 
-model_save_name = "StereoSampleGAN-InstrumentOneShot"  # What to name your model save
-training_audio_dir = "data/one_shots"  # Your training data path
-compiled_data_path = "data/compiled_data.npy"  # Your compiled data/output path
-model_save_path = f"{outputs_dir}/{model_save_name}.pth"
+        self.model_save_name = (
+            "StereoSampleGAN-InstrumentOneShot"  # What to name your model save
+        )
+        self.training_audio_dir = "data/one_shots"  # Your training data path
+        self.compiled_data_path = (
+            "data/compiled_data.npy"  # Your compiled data/output path
+        )
+        self.model_save_path = f"{self.outputs_dir}/{self.model_save_name}.pth"
 
-# Generating audio
-model_to_generate_with = model_save_path  # Generation model path
-generated_audio_name = "generated_audio"  # Output file name
-visualize_generated = True  # Show generated audio spectrograms
+        # Generating audio
+        self.model_to_generate_with = self.model_save_path  # Generation model path
+        self.generated_audio_name = "generated_audio"  # Output file name
+        self.visualize_generated = True  # Show generated audio spectrograms
diff --git a/src/utils/file_helpers.py b/src/utils/file_helpers.py
index 03c3f3a..ae4e960 100644
--- a/src/utils/file_helpers.py
+++ b/src/utils/file_helpers.py
@@ -3,9 +3,10 @@
 import torch
 import soundfile as sf
 
-from usage_params import model_save_path
+from usage_params import UsageParams
 
 # Constants
+params = UsageParams()
 GLOBAL_SR = 44100
 
 
@@ -26,9 +27,9 @@ def save_model(model):
     # Save model
     torch.save(
         model.state_dict(),
-        model_save_path,
+        params.model_save_path,
     )
-    print(f"Model saved at {model_save_path}")
+    print(f"Model saved at {params.model_save_path}")
 
 
 def get_device():
diff --git a/src/utils/generation_helpers.py b/src/utils/generation_helpers.py
index 59bd46d..fb2d918 100644
--- a/src/utils/generation_helpers.py
+++ b/src/utils/generation_helpers.py
@@ -1,18 +1,16 @@
 import os
 import torch
 from architecture import Generator, LATENT_DIM
-from usage_params import (
-    outputs_dir,
-    generated_audio_name,
-    audio_generation_count,
-    visualize_generated,
-)
+from usage_params import UsageParams
 from utils.file_helpers import get_device, save_audio
 from utils.signal_helpers import audio_to_norm_db, graph_spectrogram, norm_db_to_audio
 
 
 # Generation function
-def generate_audio(generation_model_save, len_audio_in):
+params = UsageParams()
+
+
+def generate_audio(generation_model_save, len_audio_in, save_images=False):
     device = get_device()
 
     generator = Generator()
@@ -26,7 +24,7 @@ def generate_audio(generation_model_save, len_audio_in):
     generator.eval()
 
     # Generate audio
-    z = torch.randn(audio_generation_count, LATENT_DIM, 1, 1)
+    z = torch.randn(params.audio_generation_count, LATENT_DIM, 1, 1)
     with torch.no_grad():
         generated_output = generator(z)
 
@@ -34,16 +32,20 @@ def generate_audio(generation_model_save, len_audio_in):
 
     print("Generated output shape:", generated_output.shape)
 
     # Visualize and save audio
-    for i in range(audio_generation_count):
+    for i in range(params.audio_generation_count):
         current_sample = generated_output[i]
         audio_info = norm_db_to_audio(current_sample, len_audio_in)
         audio_save_path = os.path.join(
-            outputs_dir, f"{generated_audio_name}_{i + 1}.wav"
+            params.outputs_dir, f"{params.generated_audio_name}_{i + 1}.wav"
         )
 
         save_audio(audio_save_path, audio_info)
 
-        if visualize_generated is True:
+        if params.visualize_generated is True:
             vis_signal_after_istft = audio_to_norm_db(audio_info)
-            graph_spectrogram(vis_signal_after_istft, "generated audio (after istft)")
+            graph_spectrogram(
+                vis_signal_after_istft,
+                f"{params.generated_audio_name}_{i + 1}",
+                save_images,
+            )
diff --git a/src/utils/signal_helpers.py b/src/utils/signal_helpers.py
index 5f8c414..7e8ae11 100644
--- a/src/utils/signal_helpers.py
+++ b/src/utils/signal_helpers.py
@@ -10,10 +10,7 @@
 import plotly.subplots as sp
 import scipy
 
-from usage_params import (
-    training_sample_length,
-    outputs_dir,
-)
+from usage_params import UsageParams
 from utils.file_helpers import (
     GLOBAL_SR,
     delete_DSStore,
@@ -26,8 +23,9 @@
 DATA_SHAPE = 256
 
 # STFT Helpers
+params = UsageParams()
 GLOBAL_WIN = 510
-GLOBAL_HOP = int(training_sample_length * GLOBAL_SR) // (DATA_SHAPE - 1)
+GLOBAL_HOP = int(params.training_sample_length * GLOBAL_SR) // (DATA_SHAPE - 1)
 window = scipy.signal.windows.kaiser(GLOBAL_WIN, beta=12)
 
 
@@ -126,7 +124,9 @@ def load_audio(path):
     y, sr = librosa.load(path, sr=GLOBAL_SR, mono=False)
     if y.ndim == 1:
         y = np.stack((y, y), axis=0)
-    y = librosa.util.fix_length(y, size=int(training_sample_length * GLOBAL_SR), axis=1)
+    y = librosa.util.fix_length(
+        y, size=int(params.training_sample_length * GLOBAL_SR), axis=1
+    )
     return y
 
 
@@ -165,7 +165,7 @@ def scale_data_to_range(data, new_min, new_max):
 
 
 # Validation helpers
-def graph_spectrogram(audio_data, sample_name, save=False):
+def graph_spectrogram(audio_data, sample_name, save_images=False):
     fig = sp.make_subplots(rows=2, cols=1)
 
     for i in range(2):
@@ -197,19 +197,19 @@ def graph_spectrogram(audio_data, sample_name, save=False):
     )
 
     fig.update_layout(title_text=f"{sample_name}")
-    if save is False:
+    if save_images is False:
         fig.show()
     else:
-        fig.write_image(f"outputs/training_progress/{sample_name}")
+        fig.write_image(f"outputs/spectrogram_images/{sample_name}.png")
 
 
 def generate_sine_impulses(num_impulses=1, outPath="model"):
     amplitude = 1
     for i in range(num_impulses):
-        t = np.arange(0, training_sample_length, 1 / GLOBAL_SR)
+        t = np.arange(0, params.training_sample_length, 1 / GLOBAL_SR)
         freq = np.random.uniform(0, 20000)
         audio_wave = amplitude * np.sin(2 * np.pi * freq * t)
-        num_samples = int(training_sample_length * GLOBAL_SR)
+        num_samples = int(params.training_sample_length * GLOBAL_SR)
         audio_signal = np.zeros(num_samples)
         audio_wave = audio_wave[:num_samples]
 
@@ -240,7 +240,7 @@ def stft_and_istft(path, file_name, len_audio_in):
         vis_istft.shape,
     )
 
-    save_path = os.path.join(outputs_dir, f"{file_name}.wav")
+    save_path = os.path.join(params.outputs_dir, f"{file_name}.wav")
     save_audio(save_path, istft)
 
     graph_spectrogram(stft, "stft")