
add audio spectrogram transformer, and full audio clip #406

Open · lucidrains wants to merge 11 commits into main

Conversation

@lucidrains commented Feb 3, 2023

for building out MuLaN

import torch
from src.open_clip.transformer import AudioSpectrogramTransformer

model = AudioSpectrogramTransformer(
    image_size = 256,
    patch_size = 16,
    width = 512,
    heads = 8,
    mlp_ratio = 4,
    layers = 1,
    output_dim = 512
)

wav = torch.randn(1, 1024)

embed = model(wav) # (1, 512)

Now one can do

import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg

mulan = AudioCLIP(
    embed_dim = 512,
    audio_cfg = CLIPAudioCfg(),
    text_cfg = CLIPTextCfg()
)

wav = torch.randn(2, 1024)
text = torch.randint(0, 10, (2, 77))

audio_latents, text_latents, _  = mulan(wav, text)

print(audio_latents.shape) # (2, 512)
print(text_latents.shape) # (2, 512)
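For context, a minimal sketch of the symmetric contrastive (CLIP / MuLaN-style) objective these latents would be trained with. mulan_contrastive_loss and the fixed temperature below are stand-ins, not part of this PR; open_clip's actual losses use the learned logit scale (presumably the third value returned above).

import torch
import torch.nn.functional as F

def mulan_contrastive_loss(audio_latents, text_latents, temperature = 0.07):
    # symmetric InfoNCE over the batch; the diagonal pairs are the positives
    audio_latents = F.normalize(audio_latents, dim = -1)
    text_latents = F.normalize(text_latents, dim = -1)
    logits = audio_latents @ text_latents.t() / temperature  # (batch, batch)
    targets = torch.arange(logits.shape[0], device = logits.device)
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2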

…hich it is, as the premise is to change the waveform to a 2d "image" and patchify and attend)
@lucidrains changed the title from "add audio spectrogram transformer (probably reuse VisionTransformer code) - wip" to "add audio spectrogram transformer - wip" on Feb 4, 2023
@lucidrains mentioned this pull request on Feb 4, 2023
@rwightman

@lucidrains awesome, we should probably put this audio-specific stuff in a new file; was thinking of splitting the other sub-transformers at some point too ... audio_transformer.py?

@lucidrains

@rwightman sure, by modality, or by functionality, or both; either way is fine, just let me know

@lucidrains

will still need to add the functions for generating from cfg, as well as the full AudioCLIP

perhaps by modality is good

@rwightman

yeah, was thinking modality: leave the base transformer as the parent and split off modality-specific transformers, at least audio in this case since it's new. The others can be split later, as other PRs are probably based on the current structure.

@lucidrains

You got it, will make the changes next week

@lucidrains

Have a bunch of meetings with people around the valley this week, I'll get around to finishing this next week

@lucidrains changed the title from "add audio spectrogram transformer - wip" to "add audio spectrogram transformer, and full audio clip" on Feb 6, 2023
@lukewys commented Feb 6, 2023

Hi @lucidrains the current code looks great! Feel free to ping us (Ke and I) when you are finished!

@RetroCirce

Hi @lucidrains, we briefly scanned your code and it looks great to us. Once you finish the code, just let us know. We will mainly go over the spec-augment (time masking, freq masking, screeching, etc.) and the hyperparameters of the spectrogram transformer. If you could point us to the specific locations in your code, that would be even better.

Thanks!

@lucidrains commented Feb 7, 2023

@lukewys @RetroCirce Hello Yusong and Ke! Thank you so much for offering your audio expertise; it is more helpful than you realize

The hyperparameters that I am unsure about are listed here to here. But also whatever you think are reasonable default values would be good too!


# audio clip

class AudioCLIP(nn.Module):
@lucidrains (review comment on the lines above):

should decide whether to extend CLIP

similarly, decide whether to just extend CoCa to AudioCoCa and override the visual modality transformer

@lucidrains commented Feb 18, 2023

also, decided to keep a lot of the image naming in there, in case there is logic in the library using encode_image or accessing .visual. we are technically treating the audio as a 2d image (time and frequency) anyway
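To illustrate the point, a minimal sketch (using the same AudioCLIP defaults as the examples above; the exact shape returned by encode_image is not asserted here):

import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg

mulan = AudioCLIP(
    embed_dim = 512,
    audio_cfg = CLIPAudioCfg(),
    text_cfg = CLIPTextCfg()
)

# a 1-channel spectrogram treated as a 2d "image" of shape (batch, freqs, time)
spectrogram = torch.randn(2, 32, 1024)

# the CLIP-style entry point keeps its name even though the modality is audio
audio_embed = mulan.encode_image(spectrogram)
print(audio_embed.shape)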

@marianna13

Hi @lucidrains! Can a riffusion spectrogram be used as input to the encode_image function?

@lucidrains

@marianna13 oh hey Marianna! good to hear from you

yes, it should be able to accept spectrograms (you just have to pass in a tensor of shape (batch, freqs, time))

@lucidrains

@marianna13 can you make sure the following code can run?

import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg

mulan = AudioCLIP(
    embed_dim = 512,
    audio_cfg = CLIPAudioCfg(),
    text_cfg = CLIPTextCfg()
)

spectrogram = torch.randn(2, 32, 1024)
text = torch.randint(0, 10, (2, 77))

audio_latents, text_latents, _  = mulan(spectrogram, text)

print(audio_latents.shape) # (2, 512)
print(text_latents.shape) # (2, 512)

@marianna13

@lucidrains no, unfortunately I get this error: RuntimeError: Given groups=1, weight of size [768, 3, 16, 16], expected input[2, 1, 32, 1024] to have 3 channels, but got 1 channels instead
Can you also please tell me if it's possible to run encode_image over a batch of images? I have found out that the input should have 3 dimensions, right? https://github.com/lucidrains/open_clip/blob/audio-compatible/src/open_clip/audio.py#L107

@lucidrains

@marianna13 ohh, what is the shape of the input tensor you are passing in? i thought spectrograms only have 1 channel, but i am not really an audio expert

@lucidrains

@marianna13 i can make it accommodate 3 channels, if that is the case

@lucidrains

@marianna13

import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg

mulan = AudioCLIP(
    embed_dim = 512,
    audio_cfg = CLIPAudioCfg(channels = 3),
    text_cfg = CLIPTextCfg(),
)

spectrogram = torch.randn(2, 3, 32, 1024)
text = torch.randint(0, 10, (2, 77))

audio_latents, text_latents, _  = mulan(spectrogram, text)

print(audio_latents.shape) # (2, 512)
print(text_latents.shape) # (2, 512)

@lucidrains

> @lucidrains no, unfortunately I get this error: RuntimeError: Given groups=1, weight of size [768, 3, 16, 16], expected input[2, 1, 32, 1024] to have 3 channels, but got 1 channels instead. Can you also please tell me if it's possible to run encode_image over a batch of images? I have found out that the input should have 3 dimensions, right? https://github.com/lucidrains/open_clip/blob/audio-compatible/src/open_clip/audio.py#L107

hmm, how are you testing this? are you checking out the entire PR? this error may also suggest you don't have the necessary changes to the vision transformer (to be able to configure it to have 1 channel)

@marianna13

@lucidrains I checked again and now it works! (I just forgot that I had made changes to the code.) Sorry, that's my bad!

@lucidrains

@marianna13 oh great! can you confirm that you are using 1 channel then? i should revert that commit

@lucidrains commented Mar 4, 2023

@marianna13 i'll add the MulanCoCa version tomorrow too, so we can possibly leapfrog the state of the art going on within Google

@marianna13

@lucidrains yes, I changed back to 1 channel and it worked, but also I tried to run it over a batch of images but it didn't work :(

@marianna13

> @marianna13 i'll add the MulanCoCa version tomorrow too, so we can possibly leapfrog the state of the art going on within Google

That's great! Thank you :)

@lucidrains

> @lucidrains yes, I changed back to 1 channel and it worked, but also I tried to run it over a batch of images but it didn't work :(

oh, that's odd, what is the shape of the batch of images you are sending in?

@lucidrains

@marianna13 if you can show me a reproducible error like the sample script above, i can fix it

@marianna13

Hi @lucidrains ! Sorry for the late reply. Here's the code I'm using:

import torch
import cv2
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg
import webdataset as wds
import sys
import os
from torchvision import transforms
from PIL import Image
import numpy as np
import time

transform = transforms.Compose([
    transforms.ToTensor()
])


def preprocess(sample:tuple):
    image, json_data = sample
    # json_data = json.loads(json_data.decode())
   
    audio_meta = json_data.get('audio_meta', None)
    
    if audio_meta is not None:
        tags = audio_meta.get('tags', None)
        if tags is not None:
            try:
                title, artist, genre = '', '', ''
                for k in tags.keys():
                    if k in ['title', 'TITLE']:
                        title = f'titled {tags[k]}'
                    if k in ['artist', 'ARTIST']:
                        artist = f'by {tags[k]}'
                    if k in ['genre', 'GENRE']:
                        genre = tags[k]

                label = f'{genre} song "{title}" {artist}'
            except:
                pass
    label = f'{json_data["caption"]}'
    

    return image, {'label': label}

def get_dataset(urls: list):
    '''
    Pass s3 urls and get processed torch dataset
    '''
    dataset = (
           wds.WebDataset(urls)
           .decode("pil")
           .to_tuple("jpg", "json")
           .map_tuple(transform)
           .map(preprocess)
    )
    return dataset

urls = [f'{i:05}.tar' for i in range(1)]
dataset = get_dataset(urls)

batch_size = 32

loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)


mulan = AudioCLIP(
    embed_dim = 32,
    audio_cfg = CLIPAudioCfg(**{'image_size': (512, 1001), 'patch_size': 16}),
    text_cfg = CLIPTextCfg()
)


for i, batch in enumerate(loader):
    im, label = batch
    print(type(mulan.encode_image(im)))

An example of the dataset can be found here: https://drive.google.com/file/d/15VFMSovEWCHJcDeg9lXqFnlACJmi5gr5/view?usp=sharing

Thank you!

@lucidrains commented Mar 5, 2023

@marianna13 hey Marianna, thanks for sharing the script

it looks good except for the image dimensions, whose height and width need to be divisible by the patch size. however, that assert should be in the code somewhere, maybe left for a separate PR. it also does not matter for the vision transformer, other than for generating the absolute positions, as long as the image dimensions are the maximum of what you send in during training. The spectrogram must be of fixed shape during training as well, for now

Could you try rerunning your script? And also insert a print statement before the mulan invocation, in case it fails again even with my recent changes:

for i, batch in enumerate(loader):
    im, label = batch
    print('input shape is:', im.shape)
    print(type(mulan.encode_image(im)))
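For reference, a small helper sketching one way to satisfy the divisible-by-patch-size requirement mentioned above; pad_to_patch_multiple is a hypothetical name and not part of this PR.

import torch
import torch.nn.functional as F

def pad_to_patch_multiple(spec, patch_size = 16):
    # zero-pad a (batch, freqs, time) spectrogram so both spatial dims divide the patch size
    freqs, time = spec.shape[-2], spec.shape[-1]
    pad_f = (patch_size - freqs % patch_size) % patch_size
    pad_t = (patch_size - time % patch_size) % patch_size
    # F.pad pads the last two dimensions as (left, right, top, bottom)
    return F.pad(spec, (0, pad_t, 0, pad_f))

spec = torch.randn(2, 512, 1001)          # e.g. the (512, 1001) image_size from the script above
print(pad_to_patch_multiple(spec).shape)  # (2, 512, 1008)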

@marianna13

@lucidrains it works! Thank you! :)

@lucidrains

@marianna13 hey Marianna, were you able to do a small test run?

if we can even get a training run to overfit on a small training set, maybe we can try to get this PR merged

@marianna13

Hey @lucidrains, I tried to train a model with a small fraction of the dataset, but it gets stuck at the first epoch and then gets killed. I can post my training script (I think it might be an issue on my side), but anyway.

@lucidrains

@marianna13 ohh got it, what kind of error do you see before it dies?

@marianna13

@lucidrains it just says "terminated" (I think it's some OOM issue)

@lucidrains

@marianna13 ohh, yea i think Romain mentioned this to me

maybe i should take a look at your dataset class

you could also try just plucking the code from here. others have been training parts of audiolm successfully with it

@marianna13

Thank you @lucidrains! Does AudioCLIP accept only raw audio? I mean, I have a bunch of spectrograms :)

@lucidrains

@marianna13 ohh i see, you will probably have to set a max length on the time dimension of the spectrogram

do you know if the spectrogram is generated from the full piece of music, or just chunks of them?

@lucidrains

@marianna13 actually, i can also just allow the audio clip to take care of that (curtailing the time dimension to some maximum number of patches)
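A rough sketch of what that curtailing could look like; curtail_time and max_time_patches are hypothetical names, and the handling inside AudioCLIP may end up different.

import torch

def curtail_time(spec, patch_size = 16, max_time_patches = 64):
    # crop a (batch, freqs, time) spectrogram to at most max_time_patches patches along time
    return spec[..., :patch_size * max_time_patches]

spec = torch.randn(2, 64, 4096)
print(curtail_time(spec).shape)  # (2, 64, 1024)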

@marianna13

@lucidrains I split every audio file into 10-second pieces and then convert them into spectrograms, so they should all have the same time dimension
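For context, a rough sketch of that kind of chunk-then-spectrogram preprocessing, assuming torchaudio is available and 16 kHz audio; the mel parameters here are hypothetical, not values from this PR.

import torch
import torchaudio

sample_rate = 16000        # assumed; resample first if the source rate differs
chunk_seconds = 10
to_mel = torchaudio.transforms.MelSpectrogram(
    sample_rate = sample_rate, n_fft = 1024, hop_length = 160, n_mels = 64
)

wav, sr = torchaudio.load('song.wav')   # (channels, samples)
wav = wav.mean(dim = 0)                 # mix down to mono

chunk_len = sample_rate * chunk_seconds
specs = [
    to_mel(chunk)                       # (n_mels, time), identical time dim for full-length chunks
    for chunk in wav.split(chunk_len)
    if chunk.shape[-1] == chunk_len     # drop the ragged final piece
]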

@lucidrains

@marianna13 ohh, that should be ok then 🤔 i'm not sure what's going wrong

@leoauri commented Dec 13, 2023

Hi, excuse me, I would just like to ask if work on this stalled or if training is ongoing. I tried looking on https://discord.gg/xBPBXfcFHd as suggested on https://github.com/lucidrains/musiclm-pytorch but I could not find active discussion. This implementation would be really something. Many thanks, L

@SuperGoodGame

Can AudioCLIP be used for training? If so, how should I modify my config? Thank you.
