add abstokenizer augmentation tests
loubbrad committed Nov 22, 2024
1 parent c736a4d commit d7677c1
Showing 2 changed files with 294 additions and 35 deletions.
74 changes: 39 additions & 35 deletions ariautils/tokenizer/absolute.py
@@ -6,7 +6,7 @@
import copy

from collections import defaultdict
- from typing import Final, Callable, Any
+ from typing import Final, Callable, Any, Concatenate

from ariautils.midi import (
MidiDict,
@@ -28,8 +28,8 @@
# - Add asserts to the tokenization / detokenization for user error
# - Need to add a tokenization or MidiDict check of how to resolve different
# channels, with the same instrument, have overlaping notes
- # - There are tons of edge cases here e.g., what if there are two indetical notes?
- # on different channels.
+ # - There are tons of edge cases here e.g., what if there are two identical
+ # notes on different channels.


class AbsTokenizer(Tokenizer):
@@ -139,7 +139,7 @@ def __init__(self) -> None: # Not sure why this is required by

def export_data_aug(self) -> list[Callable[[list[Token]], list[Token]]]:
return [
- self.export_tempo_aug(tempo_aug_range=0.2, mixup=True),
+ self.export_tempo_aug(max_tempo_aug=0.2, mixup=True),
self.export_pitch_aug(5),
self.export_velocity_aug(1),
]
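
These three partially applied callables are what the new tests exercise. A minimal usage sketch follows; the `AbsTokenizer()` construction, `tokenize` call, and `midi_dict` variable are assumptions about the surrounding API, not part of this diff:

```python
# Hypothetical usage sketch; assumes `midi_dict` is a loaded MidiDict.
tokenizer = AbsTokenizer()
seq = tokenizer.tokenize(midi_dict)

# export_data_aug returns list[Token] -> list[Token] callables, each
# pre-configured with functools.partial as shown in the hunks below.
for aug_fn in tokenizer.export_data_aug():
    seq = aug_fn(seq)
```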
@@ -574,9 +574,11 @@ def detokenize(self, tokenized_seq: list[Token], **kwargs: Any) -> MidiDict:

return self._detokenize_midi_dict(tokenized_seq=tokenized_seq)

+ from typing import Optional

def export_pitch_aug(
- self, aug_range: int
- ) -> Callable[[list[Token]], list[Token]]:
+ self, max_pitch_aug: int
+ ) -> Callable[Concatenate[list[Token], ...], list[Token]]:
"""Exports a function that augments the pitch of all note tokens.
Notes which fall out of the range (0, 127) will be replaced
@@ -593,7 +595,7 @@ def export_pitch_aug(
def pitch_aug_seq(
src: list[Token],
unk_tok: str,
- _aug_range: int,
+ _max_pitch_aug: int,
pitch_aug: int | None = None,
) -> list[Token]:
def pitch_aug_tok(tok: Token, _pitch_aug: int) -> Token:
@@ -630,8 +632,8 @@ def pitch_aug_tok(tok: Token, _pitch_aug: int) -> Token:
else:
return unk_tok

- if not pitch_aug:
- pitch_aug = random.randint(-_aug_range, _aug_range)
+ if pitch_aug is None:
+ pitch_aug = random.randint(-_max_pitch_aug, _max_pitch_aug)

return [pitch_aug_tok(x, pitch_aug) for x in src]

@@ -640,13 +642,13 @@ def pitch_aug_tok(tok: Token, _pitch_aug: int) -> Token:
functools.partial(
pitch_aug_seq,
unk_tok=self.unk_tok,
- _aug_range=aug_range,
+ _max_pitch_aug=max_pitch_aug,
)
)
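
Two behavioural points in this method's hunks: the rename to `max_pitch_aug`, and `if pitch_aug is None` replacing the falsy check so an explicit `pitch_aug=0` is no longer re-randomized. A self-contained sketch of the same rule, simplified to plain `(instrument, pitch, velocity)` tuples rather than the tokenizer's full Token handling:

```python
import random

def pitch_aug_sketch(notes, max_pitch_aug, unk_tok="<U>", pitch_aug=None):
    # `is None`, not `not pitch_aug`: a caller-supplied 0 is respected.
    if pitch_aug is None:
        pitch_aug = random.randint(-max_pitch_aug, max_pitch_aug)
    out = []
    for instrument, pitch, velocity in notes:
        shifted = pitch + pitch_aug
        if 0 <= shifted <= 127:
            out.append((instrument, shifted, velocity))
        else:
            out.append(unk_tok)  # out-of-range pitches collapse to unk_tok
    return out
```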

def export_velocity_aug(
- self, aug_steps_range: int
- ) -> Callable[[list[Token]], list[Token]]:
+ self, max_num_aug_steps: int
+ ) -> Callable[Concatenate[list[Token], ...], list[Token]]:
"""Exports a function which augments the velocity of all pitch tokens.
Velocity values are clipped so that they don't fall outside of the
@@ -663,10 +665,10 @@ def export_velocity_aug(

def velocity_aug_seq(
src: list[Token],
- velocity_step: int,
+ min_velocity_step: int,
max_velocity: int,
- _aug_steps_range: int,
- velocity_aug: int | None = None,
+ _max_num_aug_steps: int,
+ aug_step: int | None = None,
) -> list[Token]:
def velocity_aug_tok(tok: Token, _velocity_aug: int) -> Token:
if isinstance(tok, str): # Stand in for SpecialToken
@@ -693,32 +695,34 @@ def velocity_aug_tok(tok: Token, _velocity_aug: int) -> Token:
# Check it doesn't go out of bounds
if _velocity + _velocity_aug >= max_velocity:
return (_instrument, _pitch, max_velocity)
- elif _velocity + _velocity_aug <= velocity_step:
- return (_instrument, _pitch, velocity_step)
+ elif _velocity + _velocity_aug <= min_velocity_step:
+ return (_instrument, _pitch, min_velocity_step)

return (_instrument, _pitch, _velocity + _velocity_aug)

- if not velocity_aug:
- velocity_aug = velocity_step * random.randint(
- -_aug_steps_range, _aug_steps_range
+ if aug_step is None:
+ velocity_aug = min_velocity_step * random.randint(
+ -_max_num_aug_steps, _max_num_aug_steps
)
+ else:
+ velocity_aug = aug_step * min_velocity_step

return [velocity_aug_tok(x, velocity_aug) for x in src]

# See functools.partial docs
return self.export_aug_fn_concat(
functools.partial(
velocity_aug_seq,
- velocity_step=self.velocity_step,
+ min_velocity_step=self.velocity_step,
max_velocity=self.max_velocity,
- _aug_steps_range=aug_steps_range,
+ _max_num_aug_steps=max_num_aug_steps,
)
)
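
The shift here is always a multiple of the tokenizer's velocity step, clipped into `[velocity_step, max_velocity]`, and the new `aug_step` parameter makes it deterministic when supplied. A sketch of the clipping arithmetic with illustrative values (`velocity_step=10` and `max_velocity=120` are assumptions, not the tokenizer's actual config):

```python
import random

def shift_velocity(velocity, aug_step=None, velocity_step=10,
                   max_velocity=120, max_num_aug_steps=1):
    if aug_step is None:
        velocity_aug = velocity_step * random.randint(
            -max_num_aug_steps, max_num_aug_steps
        )
    else:
        velocity_aug = aug_step * velocity_step
    # Same bounds handling as velocity_aug_tok above.
    return max(velocity_step, min(velocity + velocity_aug, max_velocity))
```

For example, `shift_velocity(115, aug_step=1)` gives 120 rather than 125, and `shift_velocity(12, aug_step=-1)` floors at the step size, 10.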

# TODO: Adjust this so it can handle other tokens like <SEP>
# TODO: Refactor the logic
def export_tempo_aug(
- self, tempo_aug_range: float, mixup: bool
- ) -> Callable[[list[Token]], list[Token]]:
+ self, max_tempo_aug: float, mixup: bool
+ ) -> Callable[Concatenate[list[Token], ...], list[Token]]:
"""Exports a function which augments the tempo of a sequence of tokens.
Additionally this function performs note-mixup: randomly re-ordering
@@ -749,23 +753,23 @@ def tempo_aug(
end_tok: str,
instruments_wd: list,
tokenizer_name: str,
- _tempo_aug_range: float,
+ _max_tempo_aug: float,
_mixup: bool,
tempo_aug: float | None = None,
) -> list[Token]:
"""This must be used with export_aug_fn_concat in order to work
properly for concatenated sequences."""

- def _quantize_time(_n: int) -> int:
+ def _quantize_time(_n: int | float) -> int:
return round(_n / time_step) * time_step

assert (
tokenizer_name == "abs"
), f"Augmentation function only supports base AbsTokenizer"

- if not tempo_aug:
+ if tempo_aug is None:
tempo_aug = random.uniform(
- 1 - _tempo_aug_range, 1 + _tempo_aug_range
+ 1 - _max_tempo_aug, 1 + _max_tempo_aug
)

src_time_tok_cnt = 0
@@ -785,8 +789,8 @@ def _quantize_time(_n: int) -> int:
elif tok_1 == start_tok:
res.append(tok_1)
continue
- elif tok_1 == dim_tok and note_buffer:
- assert isinstance(note_buffer["onset"], int)
+ elif tok_1 == dim_tok and note_buffer is not None:
+ assert isinstance(note_buffer["onset"], tuple)
dim_tok_seen = (src_time_tok_cnt, note_buffer["onset"][1])
continue
elif tok_1[0] == "prefix":
@@ -822,9 +826,9 @@ def _quantize_time(_n: int) -> int:
for src_time_tok_cnt, interval_notes in sorted(buffer.items()):
for src_onset, notes_by_onset in sorted(interval_notes.items()):
src_time = src_time_tok_cnt * abs_time_step + src_onset
- tgt_time = round(src_time * tempo_aug)
+ tgt_time = _quantize_time(src_time * tempo_aug)
curr_tgt_time_tok_cnt = tgt_time // abs_time_step
- curr_tgt_onset = _quantize_time(tgt_time % abs_time_step)
+ curr_tgt_onset = tgt_time % abs_time_step

if curr_tgt_onset == abs_time_step:
curr_tgt_onset -= time_step
@@ -846,7 +850,7 @@ def _quantize_time(_n: int) -> int:
if _src_dur_tok is not None:
assert isinstance(_src_dur_tok[1], int)
tgt_dur = _quantize_time(
- round(_src_dur_tok[1] * tempo_aug)
+ _src_dur_tok[1] * tempo_aug
)
tgt_dur = min(tgt_dur, max_dur)
else:
@@ -882,7 +886,7 @@ def _quantize_time(_n: int) -> int:
start_tok=self.bos_tok,
instruments_wd=self.instruments_wd,
tokenizer_name=self.name,
- _tempo_aug_range=tempo_aug_range,
+ _max_tempo_aug=max_tempo_aug,
_mixup=mixup,
)
)
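
The substantive fix in the retiming hunks is ordering: the scaled time is now snapped to the `time_step` grid before being split into a time-token count and an onset, where previously only the onset was quantized, after the modulo, and could round up to the interval boundary. A sketch of the arithmetic with illustrative config values (`time_step=10` and `abs_time_step=5000`, in milliseconds, are assumptions):

```python
TIME_STEP, ABS_TIME_STEP = 10, 5000  # assumed config values, in ms

def quantize_time(n):
    # Snap to the TIME_STEP grid, as _quantize_time does above.
    return round(n / TIME_STEP) * TIME_STEP

def retime(src_time_tok_cnt, src_onset, tempo_aug):
    src_time = src_time_tok_cnt * ABS_TIME_STEP + src_onset
    tgt_time = quantize_time(src_time * tempo_aug)  # quantize first ...
    # ... so both halves of the split are already on-grid.
    return tgt_time // ABS_TIME_STEP, tgt_time % ABS_TIME_STEP
```

Under these assumed values, a note at time-token 2 with onset 1108 (11108 ms) scaled by 0.9 lands at 9997.2 ms: quantizing first gives 10000, exactly two time tokens. Quantizing after the modulo instead gave an onset of 4997, rounded up to 5000, which is what the retained `curr_tgt_onset == abs_time_step` guard had to patch back down.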