Merge pull request #213 from facebookresearch/update_kld
Update KLD with PaSST to reproduce MusicGen results
adiyoss authored Aug 11, 2023
2 parents d6df4f0 + b960e69 commit 1b4012a
Showing 5 changed files with 33 additions and 26 deletions.
6 changes: 3 additions & 3 deletions Makefile
@@ -2,11 +2,11 @@ INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear de
dataset.train.num_samples=10 dataset.valid.num_samples=10 \
dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \
logging.level=DEBUG
INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 616d7b3c
INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 5091833e
INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 616d7b3c
transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 616d7b3c
transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \
checkpoint.save_last=false # Using compression model from 616d7b3c

38 changes: 20 additions & 18 deletions audiocraft/metrics/kld.py
@@ -169,22 +169,31 @@ def _load_base_model(self, pretrained_length: tp.Optional[float]):
model = get_basic_model(mode='logits')
return model, model_sample_rate, max_input_frames, min_input_frames

def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.Optional[torch.Tensor]:
def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
"""Process audio to feed to the pretrained model."""
wav = wav.unsqueeze(0)
wav = wav[..., :wav_len]
wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
wav = wav.squeeze(0)
# create chunks of audio to match the classifier processing length
# we don't pad but return a list of audio segments as this otherwise affects the KLD computation
segments = torch.split(wav, self.max_input_frames, dim=-1)
valid_segments = []
for s in segments:
# ignoring too small segments that are breaking the model inference
if s.size(-1) > self.min_input_frames:
s = torch.nn.functional.pad(s, (0, self.max_input_frames - s.shape[-1]))
valid_segments.append(s)
if len(valid_segments) > 0:
return torch.stack(valid_segments, dim=0)
else:
return None
return [s[None] for s in valid_segments]

def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
"""Run the pretrained model and get the predictions."""
assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
wav = wav.mean(dim=1)
# PaSST is printing a lot of garbage that we are not interested in
with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
with torch.no_grad(), _patch_passt_stft():
logits = self.model(wav.to(self.device))
probs = torch.softmax(logits, dim=-1)
return probs

def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
@@ -201,17 +210,10 @@ def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
for i, wav in enumerate(x):
sample_rate = int(sample_rates[i].item())
wav_len = int(sizes[i].item())
wav = self._process_audio(wav, sample_rate, wav_len)
if wav is not None:
assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
wav = wav.mean(dim=1)
# PaSST is printing a lot of infos that we are not interested in
with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
with torch.no_grad(), _patch_passt_stft():
logits = self.model(wav.to(self.device))
probs = torch.softmax(logits, dim=-1)
probs = probs.mean(dim=0)
all_probs.append(probs)
wav_segments = self._process_audio(wav, sample_rate, wav_len)
for segment in wav_segments:
probs = self._get_model_preds(segment).mean(dim=0)
all_probs.append(probs)
if len(all_probs) > 0:
return torch.stack(all_probs, dim=0)
else:
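Stripped of the PaSST plumbing (stdout silencing, the `_patch_passt_stft` context manager, device moves), the updated control flow amounts to the sketch below. The function names and the `model` argument are illustrative stand-ins, not the repository's API; the real logic lives in `_process_audio` and `_get_model_preds` above.

```python
import typing as tp

import torch


def split_wav_into_segments(wav: torch.Tensor, max_frames: int, min_frames: int) -> tp.List[torch.Tensor]:
    """Chunk a preprocessed waveform of shape [channels, time] into classifier-sized
    segments, dropping (rather than padding) segments that are too short."""
    segments = torch.split(wav, max_frames, dim=-1)
    # each kept segment gets a leading batch dim, giving shape [1, channels, time']
    return [s[None] for s in segments if s.size(-1) > min_frames]


def segment_label_distributions(model: torch.nn.Module,
                                segments: tp.List[torch.Tensor]) -> tp.Optional[torch.Tensor]:
    """Run a (placeholder) classifier on each segment and collect softmax label distributions."""
    all_probs = []
    for segment in segments:
        x = segment.mean(dim=1)  # average channels, shape [1, time']
        with torch.no_grad():
            probs = torch.softmax(model(x), dim=-1).mean(dim=0)
        all_probs.append(probs)
    return torch.stack(all_probs, dim=0) if all_probs else None
```

Returning a list of unpadded segments (rather than a zero-padded, stacked batch) is the key change: as the new comment in `_process_audio` notes, padded frames would otherwise shift PaSST's label distribution and hence the reported KLD.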
5 changes: 5 additions & 0 deletions config/solver/audiogen/evaluation/objective_eval.yaml
@@ -22,3 +22,8 @@ evaluate:
fad: true
kld: true
text_consistency: true

metrics:
kld:
passt:
pretrained_length: 10 # similarly to reported results in AudioGen paper
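The added `metrics.kld.passt.pretrained_length: 10` entry configures the pretrained input length of the PaSST model backing the KLD metric (see `_load_base_model(self, pretrained_length)` in the kld.py hunk above), matching the 10-second setting used for the results reported in the AudioGen paper. The metric then compares the classifier's label distributions for reference and generated audio with a KL divergence; a minimal sketch of that final step, with an illustrative function name, epsilon, and reduction that may not match `audiocraft/metrics/kld.py` exactly:

```python
import torch
import torch.nn.functional as F


def kld_over_label_distributions(p_ref: torch.Tensor, q_gen: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Average KL(p_ref || q_gen) over per-segment label distributions of shape
    [num_segments, num_classes]. Reduction and epsilon handling here are illustrative."""
    # kl_div expects the first argument in log space
    kl = F.kl_div((q_gen + eps).log(), p_ref, reduction="none").sum(dim=-1)
    return kl.mean()
```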
2 changes: 1 addition & 1 deletion model_cards/AUDIOGEN_MODEL_CARD.md
@@ -58,7 +58,7 @@ Below are the objective metrics obtained with the released model on AudioCaps (c

| Model | Frechet Audio Distance | KLD | Text consistency |
|---|---|---|---|
| facebook/audiogen-medium | 1.77 | 1.41 | 0.299 |
| facebook/audiogen-medium | 1.77 | 1.58 | 0.30 |

More information can be found in the paper [AudioGen: Textually Guided Audio Generation][audiogen], in the Experiments section.

8 changes: 4 additions & 4 deletions model_cards/MUSICGEN_MODEL_CARD.md
@@ -60,10 +60,10 @@ Below are the objective metrics obtained on MusicCaps with the released model. N

| Model | Frechet Audio Distance | KLD | Text Consistency | Chroma Cosine Similarity |
|---|---|---|---|---|
| facebook/musicgen-small | 4.88 | 1.28 | 0.27 | - |
| facebook/musicgen-medium | 5.14 | 1.24 | 0.28 | - |
| facebook/musicgen-large | 5.48 | 1.22 | 0.28 | - |
| facebook/musicgen-melody | 4.93 | 1.26 | 0.27 | 0.44 |
| facebook/musicgen-small | 4.88 | 1.42 | 0.27 | - |
| facebook/musicgen-medium | 5.14 | 1.38 | 0.28 | - |
| facebook/musicgen-large | 5.48 | 1.37 | 0.28 | - |
| facebook/musicgen-melody | 4.93 | 1.41 | 0.27 | 0.44 |

More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Results section.
