add audio spectrogram transformer, and full audio clip #406

Open · wants to merge 11 commits into main
1 change: 1 addition & 0 deletions requirements-training.txt
@@ -1,5 +1,6 @@
torch>=1.9.0
torchvision
torchaudio
webdataset>=0.2.5
regex
ftfy
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,5 +1,6 @@
torch>=1.9.0
torchvision
torchaudio
regex
ftfy
tqdm
3 changes: 2 additions & 1 deletion src/open_clip/__init__.py
@@ -5,9 +5,10 @@
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .audio import AudioCLIP, CLIPAudioCfg, AudioSpectrogramTransformer
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
from .tokenizer import SimpleTokenizer, tokenize, decode
from .transform import image_transform, AugmentationCfg
from .transform import image_transform, AugmentationCfg
277 changes: 277 additions & 0 deletions src/open_clip/audio.py
@@ -0,0 +1,277 @@
from typing import Callable, Optional, Sequence, Tuple

import numpy as np
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import functional as F

from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking

from .utils import to_2tuple
from .model import CLIPTextCfg, CLIPVisionCfg, _build_text_tower

from .transformer import (
VisionTransformer,
LayerNormFp32,
LayerNorm,
QuickGELU
)

# audio spectrogram transformer

class AudioSpectrogramTransformer(nn.Module):
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
ls_init_value: float = None,
global_average_pool: bool = False,
attentional_pool: bool = False,
n_queries: int = 256,
attn_pooler_heads: int = 8,
output_dim: int = 512,
patch_dropout: float = 0.,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False,
spec_n_fft: int = 128,
spec_power: int = 2,
spec_win_length: int = 24,
spec_hop_length: Optional[int] = None,
spec_pad: int = 0,
spec_center: bool = True,
spec_pad_mode: str = 'reflect',
aug_stretch_factor: float = 0.8,
aug_freq_mask: int = 80,
aug_time_mask: int = 80,
):
super().__init__()

self.patch_size = to_2tuple(patch_size)

self.spec = Spectrogram(
n_fft=spec_n_fft,
power=spec_power,
win_length=spec_win_length,
hop_length=spec_hop_length,
pad=spec_pad,
center=spec_center,
pad_mode=spec_pad_mode
)

# spec augment - https://arxiv.org/abs/1904.08779
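# (time stretch, then random frequency- and time-band masking; applied only in training mode, see forward)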

self.aug = torch.nn.Sequential(
TimeStretch(aug_stretch_factor, fixed_rate=True),
FrequencyMasking(freq_mask_param=aug_freq_mask),
TimeMasking(time_mask_param=aug_time_mask),
)

self.vit = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
width=width,
layers=layers,
heads=heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
global_average_pool=global_average_pool,
attentional_pool=attentional_pool,
n_queries=n_queries,
attn_pooler_heads=attn_pooler_heads,
output_dim=output_dim,
patch_dropout=patch_dropout,
act_layer=act_layer,
norm_layer=norm_layer,
output_tokens=output_tokens,
channels=1
)

def lock(self, unlocked_groups=0, freeze_bn_stats=False):
self.vit.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

def init_parameters(self):
self.vit.init_parameters()

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.vit.set_grad_checkpointing(enable=enable)

def forward(self, x: torch.Tensor):
x = self.spec(x)

if self.training:
x = self.aug(x)

# automatically crop the spectrogram if its height / width are not divisible by the patch size
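# (e.g. with the default 16 x 16 patches, a 65 x 500 spectrogram would be cropped to 64 x 496)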

height, width = x.shape[-2:]
patch_height, patch_width = self.patch_size

rounded_height = height // patch_height * patch_height
rounded_width = width // patch_width * patch_width

if (height, width) != (rounded_height, rounded_width):
print(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')

x = x[..., None, :rounded_height, :rounded_width]

# pass the (possibly cropped) spectrogram to the ViT

return self.vit(x)

# audio class config

@dataclass
class CLIPAudioCfg(CLIPVisionCfg):
spec_n_fft: int = 128
spec_power: int = 2
spec_win_length: int = 24
spec_hop_length: Optional[int] = None
spec_pad: int = 0
spec_center: bool = True
spec_pad_mode: str = 'reflect'
aug_stretch_factor: float = 0.8
aug_freq_mask: int = 80
aug_time_mask: int = 80

# factory method for building audio tower

def _build_audio_tower(
embed_dim: int,
audio_cfg: CLIPAudioCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None
):
if isinstance(audio_cfg, dict):
audio_cfg = CLIPAudioCfg(**audio_cfg)

# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# memory efficient in recent PyTorch releases (>= 1.10).
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
act_layer = QuickGELU if quick_gelu else nn.GELU

audio_heads = audio_cfg.width // audio_cfg.head_width
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm

audio = AudioSpectrogramTransformer(
image_size=audio_cfg.image_size,
patch_size=audio_cfg.patch_size,
width=audio_cfg.width,
layers=audio_cfg.layers,
heads=audio_heads,
mlp_ratio=audio_cfg.mlp_ratio,
ls_init_value=audio_cfg.ls_init_value,
patch_dropout=audio_cfg.patch_dropout,
global_average_pool=audio_cfg.global_average_pool,
attentional_pool=audio_cfg.attentional_pool,
n_queries=audio_cfg.n_queries,
attn_pooler_heads=audio_cfg.attn_pooler_heads,
output_tokens=audio_cfg.output_tokens,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
spec_n_fft=audio_cfg.spec_n_fft,
spec_power=audio_cfg.spec_power,
spec_win_length=audio_cfg.spec_win_length,
spec_hop_length=audio_cfg.spec_hop_length,
spec_pad=audio_cfg.spec_pad,
spec_center=audio_cfg.spec_center,
spec_pad_mode=audio_cfg.spec_pad_mode,
aug_stretch_factor=audio_cfg.aug_stretch_factor,
aug_freq_mask=audio_cfg.aug_freq_mask,
aug_time_mask=audio_cfg.aug_time_mask
)

return audio

# audio clip

class AudioCLIP(nn.Module):
Contributor Author commented:

should decide whether to extend CLIP

similarly, decide whether to just extend CoCa to AudioCoCa and override the visual modality transformer

output_dict: torch.jit.Final[bool]

def __init__(
self,
embed_dim,
text_cfg: CLIPTextCfg,
audio_cfg: CLIPAudioCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
super().__init__()
self.output_dict = output_dict

text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
audio_cfg = CLIPAudioCfg(**audio_cfg) if isinstance(audio_cfg, dict) else audio_cfg

self.visual = _build_audio_tower(
embed_dim=embed_dim,
audio_cfg=audio_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)

text = _build_text_tower(
embed_dim=embed_dim,
text_cfg=text_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)

self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer('attn_mask', text.attn_mask, persistent=False)

self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable

def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features

def encode_text(self, text, normalize: bool = False):
cast_dtype = self.transformer.get_cast_dtype()

x = self.token_embedding(text).to(cast_dtype)

x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2)
x = self.transformer(x, attn_mask=self.attn_mask)
x = x.permute(1, 0, 2)
x = self.ln_final(x)

# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x

def forward(self, audio, text, audio_latent=None):
text_latent = self.encode_text(text)

if audio_latent is None:
audio_latent = self.encode_image(audio)

logit_scale = self.logit_scale.exp()

if self.output_dict:
return {
"image_features": audio_latent,
"text_features": text_latent,
"logit_scale": logit_scale
}

return audio_latent, text_latent, logit_scale
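
A minimal usage sketch of the new AudioCLIP API (not part of the diff): the batch size, waveform length, and default configs below are illustrative assumptions, with the waveform kept short enough that the cropped spectrogram fits the default 224 / 16 positional-embedding grid.

import torch
from open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg

# the default configs give a ViT-B/16-style audio tower and the standard CLIP text tower
model = AudioCLIP(embed_dim=512, text_cfg=CLIPTextCfg(), audio_cfg=CLIPAudioCfg())
model.eval()  # eval mode skips the SpecAugment branch in AudioSpectrogramTransformer.forward

waveform = torch.randn(2, 8000)          # (batch, samples) raw audio, illustrative length
text = torch.randint(0, 49408, (2, 77))  # (batch, context_length) token ids

with torch.no_grad():
    audio_features, text_features, logit_scale = model(waveform, text)

# audio_features: (2, 512), text_features: (2, 512)
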
9 changes: 5 additions & 4 deletions src/open_clip/transformer.py
@@ -341,7 +341,8 @@ def __init__(
input_patchnorm: bool = False,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False
output_tokens: bool = False,
channels: int = 3
):
super().__init__()
self.output_tokens = output_tokens
@@ -354,12 +355,12 @@ def __init__(
self.input_patchnorm = input_patchnorm

if input_patchnorm:
patch_input_dim = patch_height * patch_width * 3
patch_input_dim = patch_height * patch_width * channels
self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
self.conv1 = nn.Linear(patch_input_dim, width)
else:
self.patchnorm_pre_ln = nn.Identity()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
self.conv1 = nn.Conv2d(in_channels=channels, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

# class embeddings and positional embeddings
scale = width ** -0.5
@@ -474,7 +475,7 @@ def forward(self, x: torch.Tensor):
x = torch.cat(
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = x + self.positional_embedding[:x.shape[1]].to(x.dtype)

# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.patch_dropout(x)