diff --git a/ariautils/midi.py b/ariautils/midi.py index 5977b25..1813bdf 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -1,4 +1,4 @@ -"""Utils for data/MIDI processing.""" +"""Utils for MIDI processing.""" import re import os @@ -7,25 +7,28 @@ import unicodedata import mido +from mido.midifiles.units import tick2second from collections import defaultdict from pathlib import Path from typing import ( - List, - Dict, Any, - Tuple, Final, Concatenate, Callable, TypeAlias, Literal, TypedDict, + cast, ) -from mido.midifiles.units import tick2second from ariautils.utils import load_config, load_maestro_metadata_json +# TODO: +# - Remove unneeded comments +# - Add asserts + + class MetaMessage(TypedDict): """Meta message type corresponding text or copyright MIDI meta messages.""" @@ -83,37 +86,37 @@ class NoteMessage(TypedDict): class MidiDictData(TypedDict): """Type for MidiDict attributes in dictionary form.""" - meta_msgs: List[MetaMessage] - tempo_msgs: List[TempoMessage] - pedal_msgs: List[PedalMessage] - instrument_msgs: List[InstrumentMessage] - note_msgs: List[NoteMessage] + meta_msgs: list[MetaMessage] + tempo_msgs: list[TempoMessage] + pedal_msgs: list[PedalMessage] + instrument_msgs: list[InstrumentMessage] + note_msgs: list[NoteMessage] ticks_per_beat: int - metadata: Dict[str, Any] + metadata: dict[str, Any] class MidiDict: """Container for MIDI data in dictionary form. Args: - meta_msgs (List[MetaMessage]): List of text or copyright MIDI meta messages. - tempo_msgs (List[TempoMessage]): List of tempo change messages. - pedal_msgs (List[PedalMessage]): List of sustain pedal messages. - instrument_msgs (List[InstrumentMessage]): List of program change messages. - note_msgs (List[NoteMessage]): List of note messages from paired note-on/off events. + meta_msgs (list[MetaMessage]): List of text or copyright MIDI meta messages. + tempo_msgs (list[TempoMessage]): List of tempo change messages. + pedal_msgs (list[PedalMessage]): List of sustain pedal messages. + instrument_msgs (list[InstrumentMessage]): List of program change messages. + note_msgs (list[NoteMessage]): List of note messages from paired note-on/off events. ticks_per_beat (int): MIDI ticks per beat. metadata (dict): Optional metadata key-value pairs (e.g., {"genre": "classical"}). 
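For orientation, the container documented above is typically built straight from a file via MidiDict.from_midi; a minimal usage sketch (the input path is hypothetical, not part of this diff):

    from ariautils.midi import MidiDict

    midi_dict = MidiDict.from_midi("example.mid")  # hypothetical input file
    first_note = midi_dict.note_msgs[0]
    # NoteMessage "data" holds pitch/velocity and start/end ticks; tick_to_ms
    # converts an absolute tick to milliseconds using the parsed tempo messages.
    print(first_note["data"]["pitch"], midi_dict.tick_to_ms(first_note["tick"]))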
""" def __init__( self, - meta_msgs: List[MetaMessage], - tempo_msgs: List[TempoMessage], - pedal_msgs: List[PedalMessage], - instrument_msgs: List[InstrumentMessage], - note_msgs: List[NoteMessage], + meta_msgs: list[MetaMessage], + tempo_msgs: list[TempoMessage], + pedal_msgs: list[PedalMessage], + instrument_msgs: list[InstrumentMessage], + note_msgs: list[NoteMessage], ticks_per_beat: int, - metadata: Dict[str, Any], + metadata: dict[str, Any], ): self.meta_msgs = meta_msgs self.tempo_msgs = tempo_msgs @@ -147,10 +150,10 @@ def __init__( self.program_to_instrument = self.get_program_to_instrument() @classmethod - def get_program_to_instrument(cls) -> Dict[int, str]: + def get_program_to_instrument(cls) -> dict[int, str]: """Return a map of MIDI program to instrument name.""" - PROGRAM_TO_INSTRUMENT: Final[Dict[int, str]] = ( + PROGRAM_TO_INSTRUMENT: Final[dict[int, str]] = ( {i: "piano" for i in range(0, 7 + 1)} | {i: "chromatic" for i in range(8, 15 + 1)} | {i: "organ" for i in range(16, 23 + 1)} @@ -213,7 +216,7 @@ def from_midi(cls, mid_path: str | Path) -> "MidiDict": return cls(**midi_to_dict(mid)) def calculate_hash(self) -> str: - msg_dict_to_hash = dict(self.get_msg_dict()) + msg_dict_to_hash = cast(dict, self.get_msg_dict()) # Remove metadata before calculating hash msg_dict_to_hash.pop("meta_msgs") @@ -234,12 +237,12 @@ def tick_to_ms(self, tick: int) -> int: ticks_per_beat=self.ticks_per_beat, ) - def _build_pedal_intervals(self) -> Dict[int, List[List[int]]]: + def _build_pedal_intervals(self) -> dict[int, list[list[int]]]: """Returns a mapping of channels to sustain pedal intervals.""" self.pedal_msgs.sort(key=lambda msg: msg["tick"]) channel_to_pedal_intervals = defaultdict(list) - pedal_status: Dict[int, int] = {} + pedal_status: dict[int, int] = {} for pedal_msg in self.pedal_msgs: tick = pedal_msg["tick"] @@ -276,7 +279,7 @@ def resolve_overlaps(self) -> "MidiDict": """ # Organize notes by channel and pitch - note_msgs_c: Dict[int, Dict[int, List[NoteMessage]]] = defaultdict( + note_msgs_c: dict[int, dict[int, list[NoteMessage]]] = defaultdict( lambda: defaultdict(list) ) for msg in self.note_msgs: @@ -330,7 +333,7 @@ def resolve_pedal(self) -> "MidiDict": return self - # TODO: Needs to be refactored and tested + # TODO: Needs to be refactored def remove_redundant_pedals(self) -> "MidiDict": """Removes redundant pedal messages from the MIDI data in place. @@ -342,7 +345,7 @@ def remove_redundant_pedals(self) -> "MidiDict": def _is_pedal_useful( pedal_start_tick: int, pedal_end_tick: int, - note_msgs: List[NoteMessage], + note_msgs: list[NoteMessage], ) -> bool: # This logic loops through the note_msgs that could possibly # be effected by the pedal which starts at pedal_start_tick @@ -486,7 +489,7 @@ def remove_instruments(self, config: dict) -> "MidiDict": channels_to_remove = [i for i in channels_to_remove if i != 9] # Remove unwanted messages all type by looping over msgs types - _msg_dict: Dict[str, List] = { + _msg_dict: dict[str, list] = { "meta_msgs": self.meta_msgs, "tempo_msgs": self.tempo_msgs, "pedal_msgs": self.pedal_msgs, @@ -511,20 +514,20 @@ def remove_instruments(self, config: dict) -> "MidiDict": # TODO: The sign has been changed. 
Make sure this function isn't used anywhere else def _extract_track_data( track: mido.MidiTrack, -) -> Tuple[ - List[MetaMessage], - List[TempoMessage], - List[PedalMessage], - List[InstrumentMessage], - List[NoteMessage], +) -> tuple[ + list[MetaMessage], + list[TempoMessage], + list[PedalMessage], + list[InstrumentMessage], + list[NoteMessage], ]: """Converts MIDI messages into format used by MidiDict.""" - meta_msgs: List[MetaMessage] = [] - tempo_msgs: List[TempoMessage] = [] - pedal_msgs: List[PedalMessage] = [] - instrument_msgs: List[InstrumentMessage] = [] - note_msgs: List[NoteMessage] = [] + meta_msgs: list[MetaMessage] = [] + tempo_msgs: list[TempoMessage] = [] + pedal_msgs: list[PedalMessage] = [] + instrument_msgs: list[InstrumentMessage] = [] + note_msgs: list[NoteMessage] = [] last_note_on = defaultdict(list) for message in track: @@ -684,7 +687,7 @@ def midi_to_dict(mid: mido.MidiFile) -> MidiDictData: metadata_fn = get_metadata_fn( metadata_process_name=metadata_process_name ) - fn_args: Dict = metadata_process_config["args"] + fn_args: dict = metadata_process_config["args"] collected_metadata = metadata_fn(mid, midi_dict_data, **fn_args) if collected_metadata: @@ -788,11 +791,11 @@ def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile: ) # Magic sorting function - def _sort_fn(msg: mido.Message) -> Tuple[int, int]: + def _sort_fn(msg: mido.Message) -> tuple[int, int]: if hasattr(msg, "velocity"): - return (msg.time, msg.velocity) + return (msg.time, msg.velocity) # pyright: ignore else: - return (msg.time, 1000) + return (msg.time, 1000) # pyright: ignore # Sort and convert from abs_time -> delta_time track = sorted(track, key=_sort_fn) @@ -812,7 +815,7 @@ def _sort_fn(msg: mido.Message) -> Tuple[int, int]: def get_duration_ms( start_tick: int, end_tick: int, - tempo_msgs: List[TempoMessage], + tempo_msgs: list[TempoMessage], ticks_per_beat: int, ) -> int: """Calculates elapsed time (in ms) between start_tick and end_tick.""" @@ -897,7 +900,7 @@ def to_ascii(s: str) -> str: def meta_composer_filename( mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list -) -> Dict[str, str]: +) -> dict[str, str]: file_name = Path(str(mid.filename)).stem matched_names_unique = set() for name in composer_names: @@ -914,7 +917,7 @@ def meta_composer_filename( def meta_form_filename( mid: mido.MidiFile, msg_data: MidiDictData, form_names: list -) -> Dict[str, str]: +) -> dict[str, str]: file_name = Path(str(mid.filename)).stem matched_names_unique = set() for name in form_names: @@ -931,7 +934,7 @@ def meta_form_filename( def meta_composer_metamsg( mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list -) -> Dict[str, str]: +) -> dict[str, str]: matched_names_unique = set() for msg in msg_data["meta_msgs"]: for name in composer_names: @@ -952,7 +955,7 @@ def meta_maestro_json( msg_data: MidiDictData, composer_names: list, form_names: list, -) -> Dict[str, str]: +) -> dict[str, str]: """Loads composer and form metadata from MAESTRO metadata json file. 
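All of these metadata hooks share the Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]] shape resolved by get_metadata_fn below; a hypothetical extractor (not part of this PR, shown only to illustrate the contract) would look like:

    from pathlib import Path

    import mido

    from ariautils.midi import MidiDictData

    def meta_genre_filename(
        mid: mido.MidiFile, msg_data: MidiDictData, genre_names: list
    ) -> dict[str, str]:
        # Hypothetical example, not registered in name_to_fn: match genre
        # keywords against the filename, mirroring meta_form_filename above.
        file_name = Path(str(mid.filename)).stem
        matched = {g for g in genre_names if g.lower() in file_name.lower()}
        return {"genre": matched.pop()} if len(matched) == 1 else {}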
@@ -990,16 +993,16 @@ def meta_maestro_json( return res -def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> Dict[str, str]: +def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> dict[str, str]: return {"abs_path": str(Path(str(mid.filename)).absolute())} def get_metadata_fn( metadata_process_name: str, -) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]]: - name_to_fn: Dict[ +) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]]: + name_to_fn: dict[ str, - Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]], + Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]], ] = { "composer_filename": meta_composer_filename, "composer_metamsg": meta_composer_metamsg, @@ -1017,7 +1020,7 @@ def get_metadata_fn( return fn -def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: +def test_max_programs(midi_dict: MidiDict, max: int) -> tuple[bool, int]: """Returns false if midi_dict uses more than {max} programs.""" present_programs = set( map( @@ -1032,7 +1035,7 @@ def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: return False, len(present_programs) -def test_max_instruments(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: +def test_max_instruments(midi_dict: MidiDict, max: int) -> tuple[bool, int]: present_instruments = set( map( lambda msg: midi_dict.program_to_instrument[msg["data"]], @@ -1048,7 +1051,7 @@ def test_max_instruments(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: def test_note_frequency( midi_dict: MidiDict, max_per_second: float, min_per_second: float -) -> Tuple[bool, float]: +) -> tuple[bool, float]: if not midi_dict.note_msgs: return False, 0.0 @@ -1073,7 +1076,7 @@ def test_note_frequency( def test_note_frequency_per_instrument( midi_dict: MidiDict, max_per_second: float, min_per_second: float -) -> Tuple[bool, float]: +) -> tuple[bool, float]: num_instruments = len( set( map( @@ -1111,7 +1114,7 @@ def test_note_frequency_per_instrument( def test_min_length( midi_dict: MidiDict, min_seconds: int -) -> Tuple[bool, float]: +) -> tuple[bool, float]: if not midi_dict.note_msgs: return False, 0.0 @@ -1130,9 +1133,9 @@ def test_min_length( def get_test_fn( test_name: str, -) -> Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]]: - name_to_fn: Dict[ - str, Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]] +) -> Callable[Concatenate[MidiDict, ...], tuple[bool, Any]]: + name_to_fn: dict[ + str, Callable[Concatenate[MidiDict, ...], tuple[bool, Any]] ] = { "max_programs": test_max_programs, "max_instruments": test_max_instruments, diff --git a/ariautils/tokenizer/__init__.py b/ariautils/tokenizer/__init__.py new file mode 100644 index 0000000..792d910 --- /dev/null +++ b/ariautils/tokenizer/__init__.py @@ -0,0 +1,5 @@ +"""Includes Tokenizers and pre-processing utilities.""" + +from ariautils.tokenizer._base import Tokenizer + +__all__ = ["Tokenizer"] diff --git a/ariautils/tokenizer/_base.py b/ariautils/tokenizer/_base.py new file mode 100644 index 0000000..35e51ab --- /dev/null +++ b/ariautils/tokenizer/_base.py @@ -0,0 +1,244 @@ +"""Contains abstract tokenizer class.""" + +import functools + +from typing import ( + Any, + Final, + Callable, + TypeAlias, +) + +from ariautils.midi import MidiDict + + +SpecialToken: TypeAlias = str +Token: TypeAlias = tuple[Any, ...] | str + + +class Tokenizer: + """Abstract Tokenizer class for tokenizing MidiDict objects. 
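Stepping back to the midi.py preprocessing checks above: get_test_fn resolves a name to one of the test_* predicates, each returning a (passed, value) pair. A usage sketch (the path and threshold are illustrative):

    from ariautils.midi import MidiDict, get_test_fn

    midi_dict = MidiDict.from_midi("example.mid")  # hypothetical path
    max_programs_fn = get_test_fn("max_programs")
    passed, num_programs = max_programs_fn(midi_dict, max=12)  # illustrative threshold
    if not passed:
        print(f"rejected: {num_programs} distinct programs present")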
+ + Args: + return_tensors (bool, optional): If True, encode will return tensors. + Defaults to False. + """ + + def __init__( + self, + ) -> None: + self.name: str = "" + + self.bos_tok: Final[SpecialToken] = "" + self.eos_tok: Final[SpecialToken] = "" + self.pad_tok: Final[SpecialToken] = "

" + self.unk_tok: Final[SpecialToken] = "" + self.dim_tok: Final[SpecialToken] = "" + + self.special_tokens: list[SpecialToken] = [ + self.bos_tok, + self.eos_tok, + self.pad_tok, + self.unk_tok, + self.dim_tok, + ] + + # These must be implemented in child class (abstract params) + self.config: dict[str, Any] = {} + self.vocab: tuple[Token, ...] = () + self.instruments_wd: list[str] = [] + self.instruments_nd: list[str] = [] + self.tok_to_id: dict[Token, int] = {} + self.id_to_tok: dict[int, Token] = {} + self.vocab_size: int = -1 + self.pad_id: int = -1 + + def _tokenize_midi_dict(self, midi_dict: MidiDict) -> list[Token]: + """Abstract method for tokenizing a MidiDict object into a sequence of + tokens. + + Args: + midi_dict (MidiDict): The MidiDict to tokenize. + + Returns: + list[Token]: A sequence of tokens representing the MIDI content. + """ + + raise NotImplementedError + + def tokenize(self, midi_dict: MidiDict, **kwargs: Any) -> list[Token]: + """Tokenizes a MidiDict object. + + This function should be overridden if additional transformations are + required, e.g., adding additional tokens. The default behavior is to + call tokenize_midi_dict. + + Args: + midi_dict (MidiDict): The MidiDict to tokenize. + **kwargs (Any): Additional keyword arguments passed to _tokenize_midi_dict. + + Returns: + list[Token]: A sequence of tokens representing the MIDI content. + """ + + return self._tokenize_midi_dict(midi_dict, **kwargs) + + def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict: + """Abstract method for de-tokenizing a sequence of tokens into a + MidiDict Object. + + Args: + tokenized_seq (list[int]): The sequence of tokens to detokenize. + + Returns: + MidiDict: A MidiDict reconstructed from the tokens. + """ + + raise NotImplementedError + + def detokenize(self, tokenized_seq: list[Token], **kwargs: Any) -> MidiDict: + """Detokenizes a MidiDict object. + + This function should be overridden if additional are required during + detokenization. The default behavior is to call detokenize_midi_dict. + + Args: + tokenized_seq (list): The sequence of tokens to detokenize. + **kwargs (Any): Additional keyword arguments passed to detokenize_midi_dict. + + Returns: + MidiDict: A MidiDict reconstructed from the tokens. 
+ """ + + return self._detokenize_midi_dict(tokenized_seq, **kwargs) + + def export_data_aug(cls) -> list[Callable[[list[Token]], list[Token]]]: + """Export a list of implemented data augmentation functions.""" + + raise NotImplementedError + + def encode(self, unencoded_seq: list[Token]) -> list[int]: + """Converts tokenized sequence into the corresponding list of ids.""" + + def _enc_fn(tok: Token) -> int: + return self.tok_to_id.get(tok, self.tok_to_id[self.unk_tok]) + + if self.tok_to_id is None: + raise NotImplementedError("tok_to_id") + + encoded_seq = [_enc_fn(tok) for tok in unencoded_seq] + + return encoded_seq + + def decode(self, encoded_seq: list[int]) -> list[Token]: + """Converts list of ids into the corresponding list of tokens.""" + + def _dec_fn(id: int) -> Token: + return self.id_to_tok.get(id, self.unk_tok) + + if self.id_to_tok is None: + raise NotImplementedError("id_to_tok") + + decoded_seq = [_dec_fn(idx) for idx in encoded_seq] + + return decoded_seq + + @classmethod + def _find_closest_int(cls, n: int, sorted_list: list[int]) -> int: + # Selects closest integer to n from sorted_list + # Time ~ Log(n) + + if not sorted_list: + raise ValueError("List is empty") + + left, right = 0, len(sorted_list) - 1 + closest = float("inf") + + while left <= right: + mid = (left + right) // 2 + diff = abs(sorted_list[mid] - n) + + if diff < abs(closest - n): + closest = sorted_list[mid] + + if sorted_list[mid] < n: + left = mid + 1 + else: + right = mid - 1 + + return closest # type: ignore[return-value] + + def add_tokens_to_vocab(self, tokens: list[Token] | tuple[Token]) -> None: + """Utility function for safely adding extra tokens to vocab.""" + + for token in tokens: + assert token not in self.vocab + + self.vocab = self.vocab + tuple(tokens) + self.tok_to_id = {tok: idx for idx, tok in enumerate(self.vocab)} + self.id_to_tok = {v: k for k, v in self.tok_to_id.items()} + self.vocab_size = len(self.vocab) + + def export_aug_fn_concat( + self, aug_fn: Callable[[list[Token]], list[Token]] + ) -> Callable[[list[Token]], list[Token]]: + """Transforms an augmentation function for concatenated sequences. + + This is useful for augmentation functions that are only defined for + sequences which start with "" and end with "". + + Args: + aug_fn (Callable[[list[Token]], list[Token]]): The augmentation + function to transform. + + Returns: + Callable[[list[Token]], list[Token]]: A transformed augmentation + function that can handle concatenated sequences. 
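Putting the pieces of this base class together with the AbsTokenizer added later in this PR, a typical round trip looks like the following sketch (the input path is hypothetical):

    from ariautils.midi import MidiDict
    from ariautils.tokenizer.absolute import AbsTokenizer

    tokenizer = AbsTokenizer()
    midi_dict = MidiDict.from_midi("example.mid")  # hypothetical path
    seq = tokenizer.tokenize(midi_dict)           # list[Token]
    ids = tokenizer.encode(seq)                   # ids; unknown tokens fall back to unk_tok's id
    tokens = tokenizer.decode(ids)                # back to list[Token]
    reconstructed = tokenizer.detokenize(tokens)  # a MidiDict again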
+ """ + + def _aug_fn_concat( + src: list[Token], + _aug_fn: Callable[[list[Token]], list[Token]], + pad_tok: str, + eos_tok: str, + **kwargs: Any, + ) -> list[Token]: + # Split list on "" + initial_seq_len = len(src) + src_sep = [] + prev_idx = 0 + for curr_idx, tok in enumerate(src, start=1): + if tok == eos_tok: + src_sep.append(src[prev_idx:curr_idx]) + prev_idx = curr_idx + + # Last sequence + if prev_idx != curr_idx: + src_sep.append(src[prev_idx:]) + + # Augment + src_sep = [ + _aug_fn( + _src, + **kwargs, + ) + for _src in src_sep + ] + + # Concatenate + src_aug_concat = [tok for src_aug in src_sep for tok in src_aug] + + # Pad or truncate to original sequence length as necessary + src_aug_concat = src_aug_concat[:initial_seq_len] + src_aug_concat += [pad_tok] * ( + initial_seq_len - len(src_aug_concat) + ) + + return src_aug_concat + + return functools.partial( + _aug_fn_concat, + _aug_fn=aug_fn, + pad_tok=self.pad_tok, + eos_tok=self.eos_tok, + ) diff --git a/ariautils/tokenizer/absolute.py b/ariautils/tokenizer/absolute.py new file mode 100644 index 0000000..6ac465b --- /dev/null +++ b/ariautils/tokenizer/absolute.py @@ -0,0 +1,848 @@ +"""Contains MIDI tokenizer with absolute onset timings.""" + +import functools +import itertools +import random +import copy + +from collections import defaultdict +from typing import Final, Callable + +from ariautils.midi import ( + MidiDict, + MetaMessage, + TempoMessage, + PedalMessage, + InstrumentMessage, + NoteMessage, + get_duration_ms, +) +from ariautils.utils import load_config, get_logger +from ariautils.tokenizer._base import Tokenizer, Token + + +logger = get_logger(__package__) + + +# TODO: +# - Add asserts to the tokenization / detokenization for user error + + +class AbsTokenizer(Tokenizer): + """MidiDict tokenizer implemented with absolute onset timings. + + The tokenizer processes MIDI files in 5000ms segments, with each segment separated by + a special token. Within each segment, note timings are represented relative to the + segment start. + + Tokenization Schema: + For non-percussion instruments: + - Each note is represented by three consecutive tokens: + 1. [instrument, pitch, velocity]: Instrument class, MIDI pitch, and velocity + 2. [onset]: Absolute time in milliseconds from segment start + 3. [duration]: Note duration in milliseconds + + For percussion instruments: + - Each note is represented by two consecutive tokens: + 1. [drum, note_number]: Percussion instrument and MIDI note number + 2. [onset]: Absolute time in milliseconds from segment start + + Notes: + - Notes are ordered according to onset time + - Sustain pedals effects are incorporated directly into note durations + - Various configuration settings effecting instrument processing, + timing resolution, and quantization levels can be adjusted the + config.json at 'tokenizer.abs'. 
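To make the schema above concrete, here is an illustrative fragment of a tokenized sequence (all values invented rather than taken from a real file):

    # One non-percussion note occupies three tokens, one drum hit occupies two:
    example_fragment = [
        ("piano", 60, 75),  # instrument class, MIDI pitch, quantized velocity
        ("onset", 1500),    # ms relative to the current 5000ms segment
        ("dur", 250),       # quantized duration in ms, capped at max_dur_ms
        ("drum", 38),       # percussion MIDI note number
        ("onset", 1510),    # drum notes carry no duration token
    ]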
+ """ + + def __init__(self) -> None: # Not sure why this is required by + super().__init__() + self.config = load_config()["tokenizer"]["abs"] + self.name = "abs" + + # Calculate time quantizations (in ms) + self.abs_time_step: int = self.config["abs_time_step_ms"] + self.max_dur: int = self.config["max_dur_ms"] + self.time_step: int = self.config["time_step_ms"] + + self.dur_time_quantizations = [ + self.time_step * i + for i in range((self.max_dur // self.time_step) + 1) + ] + self.onset_time_quantizations = [ + self.time_step * i for i in range((self.max_dur // self.time_step)) + ] + + # Calculate velocity quantizations + self.velocity_step: int = self.config["velocity_quantization"]["step"] + self.velocity_quantizations = [ + i * self.velocity_step + for i in range(int(127 / self.velocity_step) + 1) + ] + self.max_velocity = self.velocity_quantizations[-1] + + # _nd = no drum; _wd = with drum + self.instruments_nd = [ + k + for k, v in self.config["ignore_instruments"].items() + if v is False + ] + self.instruments_wd = self.instruments_nd + ["drum"] + + # Prefix tokens + self.prefix_tokens: list[Token] = [ + ("prefix", "instrument", x) for x in self.instruments_wd + ] + self.composer_names: list[Token] = self.config["composer_names"] + self.form_names: list[str] = self.config["form_names"] + self.genre_names: list[str] = self.config["genre_names"] + self.prefix_tokens += [ + ("prefix", "composer", x) for x in self.composer_names + ] + self.prefix_tokens += [("prefix", "form", x) for x in self.form_names] + self.prefix_tokens += [("prefix", "genre", x) for x in self.genre_names] + + # Build vocab + self.time_tok = "" + self.onset_tokens: list[Token] = [ + ("onset", i) for i in self.onset_time_quantizations + ] + self.dur_tokens: list[Token] = [ + ("dur", i) for i in self.dur_time_quantizations + ] + self.drum_tokens: list[Token] = [("drum", i) for i in range(35, 82)] + + self.note_tokens: list[Token] = list( + itertools.product( + self.instruments_nd, + [i for i in range(128)], + self.velocity_quantizations, + ) + ) + + self.special_tokens.append(self.time_tok) + self.add_tokens_to_vocab( + self.special_tokens + + self.prefix_tokens + + self.note_tokens + + self.drum_tokens + + self.dur_tokens + + self.onset_tokens + ) + self.pad_id = self.tok_to_id[self.pad_tok] + + def export_data_aug(self) -> list[Callable[[list[Token]], list[Token]]]: + return [ + self.export_tempo_aug(tempo_aug_range=0.2, mixup=True), + self.export_pitch_aug(5), + self.export_velocity_aug(1), + ] + + def _quantize_dur(self, time: int) -> int: + # This function will return values res >= 0 (inc. 0) + return self._find_closest_int(time, self.dur_time_quantizations) + + def _quantize_onset(self, time: int) -> int: + # This function will return values res >= 0 (inc. 
0) + return self._find_closest_int(time, self.onset_time_quantizations) + + def _quantize_velocity(self, velocity: int) -> int: + # This function will return values in the range 0 < res =< 127 + velocity_quantized = self._find_closest_int( + velocity, self.velocity_quantizations + ) + + if velocity_quantized == 0 and velocity != 0: + return self.velocity_step + else: + return velocity_quantized + + def _format( + self, prefix: list[Token], unformatted_seq: list[Token] + ) -> list[Token]: + # If unformatted_seq is longer than 150 tokens insert diminish tok + idx = -100 + random.randint(-10, 10) + if len(unformatted_seq) > 150: + if ( + unformatted_seq[idx][0] == "onset" + ): # Don't want: note, , onset, due + unformatted_seq.insert(idx - 1, self.dim_tok) + elif ( + unformatted_seq[idx][0] == "dur" + ): # Don't want: note, onset, , dur + unformatted_seq.insert(idx - 2, self.dim_tok) + else: + unformatted_seq.insert(idx, self.dim_tok) + + res = prefix + [self.bos_tok] + unformatted_seq + [self.eos_tok] + + return res + + def calc_length_ms(self, seq: list[Token], onset: bool = False) -> int: + """Calculates time (ms) end of sequence to the end of the last note. If + onset=True, then it will return the onset time of the last note instead + """ + + # Find the index of the last onset or dur token + seq = copy.deepcopy(seq) + for _idx in range(len(seq) - 1, -1, -1): + tok = seq[_idx] + if type(tok) is tuple and tok[0] in {"onset", "dur"}: + break + else: + seq.pop() + + time_offset_ms = seq.count(self.time_tok) * self.abs_time_step + idx = len(seq) - 1 + for tok in seq[::-1]: + if type(tok) is tuple and tok[0] == "dur": + assert seq[idx][0] == "dur", "Expected duration token" + assert seq[idx - 1][0] == "onset", "Expect onset token" + + onset_ms = seq[idx - 1][1] + duration_ms = seq[idx][1] + assert isinstance(onset_ms, int), "Expected int" + assert isinstance(duration_ms, int), "Expected int" + + if onset is False: + return time_offset_ms + onset_ms + duration_ms + elif onset is True: + return time_offset_ms + onset_ms # Ignore dur + + idx -= 1 + + # If it gets to this point, an error has occurred + raise Exception("Invalid sequence format") + + def truncate_by_time( + self, tokenized_seq: list[Token], trunc_time_ms: int + ) -> list[Token]: + """Truncates notes with onset_ms > trunc_time_ms.""" + time_offset_ms = 0 + for idx, tok in enumerate(tokenized_seq): + if tok == self.time_tok: + time_offset_ms += self.abs_time_step + elif type(tok) is tuple and tok[0] == "onset": + if time_offset_ms + tok[1] > trunc_time_ms: + return tokenized_seq[: idx - 1] + + return tokenized_seq + + def _tokenize_midi_dict( + self, midi_dict: MidiDict, remove_preceding_silence: bool = True + ) -> list[Token]: + ticks_per_beat = midi_dict.ticks_per_beat + midi_dict.remove_instruments(self.config["ignore_instruments"]) + + if len(midi_dict.note_msgs) == 0: + raise Exception("note_msgs is empty after ignoring instruments") + + channel_to_pedal_intervals = midi_dict._build_pedal_intervals() + + channels_used = {msg["channel"] for msg in midi_dict.note_msgs} + + channel_to_instrument = { + msg["channel"]: midi_dict.program_to_instrument[msg["data"]] + for msg in midi_dict.instrument_msgs + if msg["channel"] != 9 # Exclude drums + } + # If non-drum channel is missing from instrument_msgs, default to piano + for c in channels_used: + if channel_to_instrument.get(c) is None and c != 9: + channel_to_instrument[c] = "piano" + + # Calculate prefix + prefix: list[Token] = [ + ("prefix", "instrument", x) + for x in 
set(channel_to_instrument.values()) + ] + if 9 in channels_used: + prefix.append(("prefix", "instrument", "drum")) + composer = midi_dict.metadata.get("composer") + if composer and (composer in self.composer_names): + prefix.insert(0, ("prefix", "composer", composer)) + form = midi_dict.metadata.get("form") + if form and (form in self.form_names): + prefix.insert(0, ("prefix", "form", form)) + genre = midi_dict.metadata.get("genre") + if genre and (genre in self.genre_names): + prefix.insert(0, ("prefix", "genre", genre)) + random.shuffle(prefix) + + tokenized_seq: list[Token] = [] + + if remove_preceding_silence is False: + initial_onset_tick = 0 + else: + initial_onset_tick = midi_dict.note_msgs[0]["data"]["start"] + + curr_time_since_onset = 0 + for _, msg in enumerate(midi_dict.note_msgs): + # Extract msg data + _channel = msg["channel"] + _pitch = msg["data"]["pitch"] + _velocity = msg["data"]["velocity"] + _start_tick = msg["data"]["start"] + _end_tick = msg["data"]["end"] + + # Calculate time data + prev_time_since_onset = curr_time_since_onset + curr_time_since_onset = get_duration_ms( + start_tick=initial_onset_tick, + end_tick=_start_tick, + tempo_msgs=midi_dict.tempo_msgs, + ticks_per_beat=ticks_per_beat, + ) + + # Add abs time token if necessary + time_toks_to_append = ( + curr_time_since_onset // self.abs_time_step + ) - (prev_time_since_onset // self.abs_time_step) + if time_toks_to_append > 0: + for _ in range(time_toks_to_append): + tokenized_seq.append(self.time_tok) + + # Special case instrument is a drum. This occurs exclusively when + # MIDI channel is 9 when 0 indexing + if _channel == 9: + _note_onset = self._quantize_onset( + curr_time_since_onset % self.abs_time_step + ) + tokenized_seq.append(("drum", _pitch)) + tokenized_seq.append(("onset", _note_onset)) + + else: # Non drum case (i.e. 
an instrument note) + _instrument = channel_to_instrument[_channel] + + # Update _end_tick if affected by pedal + for pedal_interval in channel_to_pedal_intervals[_channel]: + pedal_start, pedal_end = ( + pedal_interval[0], + pedal_interval[1], + ) + if pedal_start < _end_tick < pedal_end: + _end_tick = pedal_end + break + + _note_duration = get_duration_ms( + start_tick=_start_tick, + end_tick=_end_tick, + tempo_msgs=midi_dict.tempo_msgs, + ticks_per_beat=ticks_per_beat, + ) + + # Quantize + _velocity = self._quantize_velocity(_velocity) + _note_onset = self._quantize_onset( + curr_time_since_onset % self.abs_time_step + ) + _note_duration = self._quantize_dur(_note_duration) + if _note_duration == 0: + _note_duration = self.time_step + + tokenized_seq.append((_instrument, _pitch, _velocity)) + tokenized_seq.append(("onset", _note_onset)) + tokenized_seq.append(("dur", _note_duration)) + + return self._format( + prefix=prefix, + unformatted_seq=tokenized_seq, + ) + + def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict: + # NOTE: These values chosen so that 1000 ticks = 1000ms, allowing us to + # skip converting between ticks and ms + instrument_programs = self.config["instrument_programs"] + TICKS_PER_BEAT: Final[int] = 500 + TEMPO: Final[int] = 500000 + + # Set message tempos + tempo_msgs: list[TempoMessage] = [ + {"type": "tempo", "data": TEMPO, "tick": 0} + ] + meta_msgs: list[MetaMessage] = [] + pedal_msgs: list[PedalMessage] = [] + instrument_msgs: list[InstrumentMessage] = [] + + instrument_to_channel: dict[str, int] = {} + + # Add non-drum instrument_msgs, breaks at first note token + channel_idx = 0 + curr_tick = 0 + for idx, tok in enumerate(tokenized_seq): + if channel_idx == 9: # Skip channel reserved for drums + channel_idx += 1 + + if tok in self.special_tokens: + if tok == self.time_tok: + curr_tick += self.abs_time_step + continue + elif ( + tok[0] == "prefix" + and tok[1] == "instrument" + and tok[2] in self.instruments_wd + ): + # Process instrument prefix tokens + if tok[2] in instrument_to_channel.keys(): + logger.warning(f"Duplicate prefix {tok[2]}") + continue + elif tok[2] == "drum": + instrument_msgs.append( + { + "type": "instrument", + "data": 0, + "tick": 0, + "channel": 9, + } + ) + instrument_to_channel["drum"] = 9 + else: + instrument_msgs.append( + { + "type": "instrument", + "data": instrument_programs[tok[2]], + "tick": 0, + "channel": channel_idx, + } + ) + instrument_to_channel[tok[2]] = channel_idx + channel_idx += 1 + elif tok in self.prefix_tokens: + continue + else: + # Note, wait, or duration token + start = idx + break + + # Note messages + note_msgs: list[NoteMessage] = [] + for tok_1, tok_2, tok_3 in zip( + tokenized_seq[start:], + tokenized_seq[start + 1 :], + tokenized_seq[start + 2 :], + ): + if tok_1 in self.special_tokens: + _tok_type_1 = "special" + else: + _tok_type_1 = tok_1[0] + if tok_2 in self.special_tokens: + _tok_type_2 = "special" + else: + _tok_type_2 = tok_2[0] + if tok_3 in self.special_tokens: + _tok_type_3 = "special" + else: + _tok_type_3 = tok_3[0] + + if tok_1 == self.time_tok: + curr_tick += self.abs_time_step + + elif ( + _tok_type_1 == "special" + or _tok_type_1 == "prefix" + or _tok_type_1 == "onset" + or _tok_type_1 == "dur" + ): + continue + elif _tok_type_1 == "drum" and _tok_type_2 == "onset": + assert isinstance( + tok_2[1], int + ), f"Expected int for onset, got {tok_2[1]}" + assert isinstance( + tok_1[1], int + ), f"Expected int for pitch, got {tok_1[1]}" + + _start_tick: int = curr_tick + 
tok_2[1] + _end_tick: int = _start_tick + self.time_step + _pitch: int = tok_1[1] + _channel: int = instrument_to_channel[tok_1[0]] + _velocity: int = self.config["drum_velocity"] + + if _channel is None: + logger.warning( + "Tried to decode note message for unexpected instrument" + ) + else: + note_msgs.append( + { + "type": "note", + "data": { + "pitch": _pitch, + "start": _start_tick, + "end": _end_tick, + "velocity": _velocity, + }, + "tick": _start_tick, + "channel": _channel, + } + ) + + elif ( + _tok_type_1 in self.instruments_nd + and _tok_type_2 == "onset" + and _tok_type_3 == "dur" + ): + assert isinstance( + tok_1[1], int + ), f"Expected int for pitch, got {tok_1[1]}" + assert isinstance( + tok_1[2], int + ), f"Expected int for velocity, got {tok_1[2]}" + assert isinstance( + tok_2[1], int + ), f"Expected int for onset, got {tok_2[1]}" + assert isinstance( + tok_3[1], int + ), f"Expected int for duration, got {tok_3[1]}" + + _pitch = tok_1[1] + _channel = instrument_to_channel[tok_1[0]] + _velocity = tok_1[2] + _start_tick = curr_tick + tok_2[1] + _end_tick = _start_tick + tok_3[1] + + if _channel is None: + logger.warning( + "Tried to decode note message for unexpected instrument" + ) + else: + note_msgs.append( + { + "type": "note", + "data": { + "pitch": _pitch, + "start": _start_tick, + "end": _end_tick, + "velocity": _velocity, + }, + "tick": _start_tick, + "channel": _channel, + } + ) + else: + logger.warning( + f"Unexpected token sequence: {tok_1}, {tok_2}, {tok_3}" + ) + + return MidiDict( + meta_msgs=meta_msgs, + tempo_msgs=tempo_msgs, + pedal_msgs=pedal_msgs, + instrument_msgs=instrument_msgs, + note_msgs=note_msgs, + ticks_per_beat=TICKS_PER_BEAT, + metadata={}, + ) + + def export_pitch_aug( + self, aug_range: int + ) -> Callable[[list[Token]], list[Token]]: + """Exports a function that augments the pitch of all note tokens. + + Notes which fall out of the range (0, 127) will be replaced + with the unknown token ''. + + Args: + aug_range (int): Returned function will randomly augment the pitch + from a value in the range (-aug_range, aug_range). + + Returns: + Callable[[list[Token]], list[Token]]: Exported function. 
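Usage sketch for the exported function (the token values are illustrative):

    from ariautils.tokenizer.absolute import AbsTokenizer

    tokenizer = AbsTokenizer()
    pitch_aug_fn = tokenizer.export_pitch_aug(aug_range=5)
    seq = [("piano", 60, 75), ("onset", 0), ("dur", 100)]  # illustrative tokens
    augmented = pitch_aug_fn(seq)
    # Every note token gets its pitch shifted by a single random offset in
    # [-5, 5]; notes pushed outside 0-127 become the unknown token.

Like the other exported augmentations, the returned callable is wrapped with export_aug_fn_concat, so it also handles sequences formed by concatenating multiple tokenized files.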
+ """ + + def pitch_aug_seq( + src: list[Token], + unk_tok: str, + _aug_range: int, + pitch_aug: int | None = None, + ) -> list[Token]: + def pitch_aug_tok(tok: Token, _pitch_aug: int) -> Token: + if isinstance(tok, str): # Stand in for SpecialToken + _tok_type = "special" + else: + _tok_type = tok[0] + + if ( + _tok_type == "special" + or _tok_type == "prefix" + or _tok_type == "dur" + or _tok_type == "drum" + or _tok_type == "onset" + ): + # Return without changing + return tok + else: + # Return augmented tok + assert ( + isinstance(tok, tuple) and len(tok) == 3 + ), f"Invalid note token" + (_instrument, _pitch, _velocity) = tok + + assert isinstance( + _pitch, int + ), f"Expected int for pitch, got {_pitch}" + assert isinstance( + _velocity, int + ), f"Expected int for velocity, got {_velocity}" + + if 0 <= _pitch + _pitch_aug <= 127: + return (_instrument, _pitch + _pitch_aug, _velocity) + else: + return unk_tok + + if not pitch_aug: + pitch_aug = random.randint(-_aug_range, _aug_range) + + return [pitch_aug_tok(x, pitch_aug) for x in src] + + # See functools.partial docs + return self.export_aug_fn_concat( + functools.partial( + pitch_aug_seq, + unk_tok=self.unk_tok, + _aug_range=aug_range, + ) + ) + + def export_velocity_aug( + self, aug_steps_range: int + ) -> Callable[[list[Token]], list[Token]]: + """Exports a function which augments the velocity of all pitch tokens. + + Velocity values are clipped so that they don't fall outside of the + valid range. + + Args: + aug_steps_range (int): Returned function will randomly augment + velocity in the range aug_steps_range * (-self.velocity_step, + self.velocity step). + + Returns: + Callable[[list[Token]], list[Token]]: Exported function. + """ + + def velocity_aug_seq( + src: list[Token], + velocity_step: int, + max_velocity: int, + _aug_steps_range: int, + velocity_aug: int | None = None, + ) -> list[Token]: + def velocity_aug_tok(tok: Token, _velocity_aug: int) -> Token: + if isinstance(tok, str): # Stand in for SpecialToken + _tok_type = "special" + else: + _tok_type = tok[0] + + if ( + _tok_type == "special" + or _tok_type == "prefix" + or _tok_type == "dur" + or _tok_type == "drum" + or _tok_type == "onset" + ): + # Return without changing + return tok + else: + assert isinstance(tok, tuple) and len(tok) == 3 + (_instrument, _pitch, _velocity) = tok + + assert isinstance(_pitch, int) + assert isinstance(_velocity, int) + + # Check it doesn't go out of bounds + if _velocity + _velocity_aug >= max_velocity: + return (_instrument, _pitch, max_velocity) + elif _velocity + _velocity_aug <= velocity_step: + return (_instrument, _pitch, velocity_step) + + return (_instrument, _pitch, _velocity + _velocity_aug) + + if not velocity_aug: + velocity_aug = velocity_step * random.randint( + -_aug_steps_range, _aug_steps_range + ) + + return [velocity_aug_tok(x, velocity_aug) for x in src] + + # See functools.partial docs + return self.export_aug_fn_concat( + functools.partial( + velocity_aug_seq, + velocity_step=self.velocity_step, + max_velocity=self.max_velocity, + _aug_steps_range=aug_steps_range, + ) + ) + + # TODO: Adjust this so it can handle other tokens like + def export_tempo_aug( + self, tempo_aug_range: float, mixup: bool + ) -> Callable[[list[Token]], list[Token]]: + """Exports a function which augments the tempo of a sequence of tokens. + + Additionally this function performs note-mixup: randomly re-ordering + the note subsequences which occur on the same onset. + + IMPORTANT: This function doesn't support additional tokens. 
If you + have modified the tokenizer in any way, this should not be added to + export_data_aug. + + Args: + tempo_aug_range (int): Returned function will randomly augment + tempo by a factor in the range (1 - tempo_aug_range, + 1 + tempo_aug_range). + + Returns: + Callable[[list[Token]], list[Token]]: Exported function. + """ + + def tempo_aug( + src: list[Token], + abs_time_step: int, + max_dur: int, + time_step: int, + unk_tok: str, + time_tok: str, + dim_tok: str, + start_tok: str, + end_tok: str, + instruments_wd: list, + tokenizer_name: str, + _tempo_aug_range: float, + _mixup: bool, + tempo_aug: float | None = None, + ) -> list[Token]: + """This must be used with export_aug_fn_concat in order to work + properly for concatenated sequences.""" + + def _quantize_time(_n: int) -> int: + return round(_n / time_step) * time_step + + assert ( + tokenizer_name == "abs" + ), f"Augmentation function only supports base AbsTokenizer" + + if not tempo_aug: + tempo_aug = random.uniform( + 1 - _tempo_aug_range, 1 + _tempo_aug_range + ) + + src_time_tok_cnt = 0 + dim_tok_seen = None + res: list[Token] = [] + note_buffer: dict[str, Token | None] | None = None + + # Buffer tracks + buffer: dict[int, dict[int, list[dict[str, Token | None]]]] = ( + defaultdict(lambda: defaultdict(list)) + ) + for tok_1, tok_2, tok_3 in zip(src, src[1:], src[2:]): + if tok_1 == time_tok: + _tok_type = "time" + elif tok_1 == unk_tok: + _tok_type = "unk" + elif tok_1 == start_tok: + res.append(tok_1) + continue + elif tok_1 == dim_tok and note_buffer: + assert isinstance(note_buffer["onset"], int) + dim_tok_seen = (src_time_tok_cnt, note_buffer["onset"][1]) + continue + elif tok_1[0] == "prefix": + res.append(tok_1) + continue + elif tok_1[0] in instruments_wd: + _tok_type = tok_1[0] + else: + # This only triggers for incomplete notes at the beginning, + # e.g. 
an onset token before a note token is seen + continue + + if _tok_type == "time": + src_time_tok_cnt += 1 + elif _tok_type == "drum": + assert isinstance(tok_2[1], int) + note_buffer = { + "note": tok_1, + "onset": tok_2, + "dur": None, + } + buffer[src_time_tok_cnt][tok_2[1]].append(note_buffer) + else: # unk or in instruments_wd + assert isinstance(tok_2[1], int) + note_buffer = { + "note": tok_1, + "onset": tok_2, + "dur": tok_3, + } + buffer[src_time_tok_cnt][tok_2[1]].append(note_buffer) + + prev_tgt_time_tok_cnt = 0 + for src_time_tok_cnt, interval_notes in sorted(buffer.items()): + for src_onset, notes_by_onset in sorted(interval_notes.items()): + src_time = src_time_tok_cnt * abs_time_step + src_onset + tgt_time = round(src_time * tempo_aug) + curr_tgt_time_tok_cnt = tgt_time // abs_time_step + curr_tgt_onset = _quantize_time(tgt_time % abs_time_step) + + if curr_tgt_onset == abs_time_step: + curr_tgt_onset -= time_step + + for _ in range( + curr_tgt_time_tok_cnt - prev_tgt_time_tok_cnt + ): + res.append(time_tok) + prev_tgt_time_tok_cnt = curr_tgt_time_tok_cnt + + if _mixup == True: + random.shuffle(notes_by_onset) + + for note in notes_by_onset: + _src_note_tok = note["note"] + _src_dur_tok = note["dur"] + assert _src_note_tok is not None + + if _src_dur_tok is not None: + assert isinstance(_src_dur_tok[1], int) + tgt_dur = _quantize_time( + round(_src_dur_tok[1] * tempo_aug) + ) + tgt_dur = min(tgt_dur, max_dur) + else: + tgt_dur = None + + res.append(_src_note_tok) + res.append(("onset", curr_tgt_onset)) + if tgt_dur: + res.append(("dur", tgt_dur)) + + if dim_tok_seen is not None and dim_tok_seen == ( + src_time_tok_cnt, + src_onset, + ): + res.append(dim_tok) + dim_tok_seen = None + + if src[-1] == end_tok: + res.append(end_tok) + + return res + + return self.export_aug_fn_concat( + functools.partial( + tempo_aug, + abs_time_step=self.abs_time_step, + max_dur=self.max_dur, + time_step=self.time_step, + unk_tok=self.unk_tok, + time_tok=self.time_tok, + dim_tok=self.dim_tok, + end_tok=self.eos_tok, + start_tok=self.bos_tok, + instruments_wd=self.instruments_wd, + tokenizer_name=self.name, + _tempo_aug_range=tempo_aug_range, + _mixup=mixup, + ) + ) diff --git a/ariautils/utils/__init__.py b/ariautils/utils/__init__.py index eb1ccad..c2ccbbf 100644 --- a/ariautils/utils/__init__.py +++ b/ariautils/utils/__init__.py @@ -4,19 +4,24 @@ import logging from importlib import resources -from typing import Dict, Any, cast +from typing import Any, cast from .config import load_config -def get_logger(name: str) -> logging.Logger: +def get_logger(name: str | None) -> logging.Logger: logger = logging.getLogger(name) if not logger.handlers: logger.propagate = False logger.setLevel(logging.DEBUG) - formatter = logging.Formatter( - "[%(asctime)s]: [%(levelname)s] [%(name)s] %(message)s" - ) + if name is not None: + formatter = logging.Formatter( + "[%(asctime)s]: [%(levelname)s] [%(name)s] %(message)s" + ) + else: + formatter = logging.Formatter( + "[%(asctime)s]: [%(levelname)s] %(message)s" + ) ch = logging.StreamHandler() ch.setLevel(logging.INFO) @@ -26,14 +31,14 @@ def get_logger(name: str) -> logging.Logger: return logger -def load_maestro_metadata_json() -> Dict[str, Any]: +def load_maestro_metadata_json() -> dict[str, Any]: """Loads MAESTRO metadata json .""" with ( resources.files("ariautils.config") .joinpath("maestro_metadata.json") .open("r") as f ): - return cast(Dict[str, Any], json.load(f)) + return cast(dict[str, Any], json.load(f)) __all__ = ["load_config", 
"load_maestro_metadata_json", "get_logger"] diff --git a/ariautils/utils/config.py b/ariautils/utils/config.py index a4fd267..4093318 100644 --- a/ariautils/utils/config.py +++ b/ariautils/utils/config.py @@ -1,17 +1,16 @@ """Includes functionality for loading config files.""" -import os import json from importlib import resources -from typing import Dict, Any, cast +from typing import Any, cast -def load_config() -> Dict[str, Any]: +def load_config() -> dict[str, Any]: """Returns a dictionary loaded from the config.json file.""" with ( resources.files("ariautils.config") .joinpath("config.json") .open("r") as f ): - return cast(Dict[str, Any], json.load(f)) + return cast(dict[str, Any], json.load(f)) diff --git a/tests/test_midi.py b/tests/test_midi.py index 74a7402..5d6aa02 100644 --- a/tests/test_midi.py +++ b/tests/test_midi.py @@ -1,4 +1,6 @@ import unittest +import tempfile +import shutil from importlib import resources from pathlib import Path @@ -40,6 +42,33 @@ def test_save(self) -> None: midi_dict = MidiDict.from_midi(mid_path=load_path) midi_dict.to_midi().save(save_path) + def test_tick_to_ms(self) -> None: + CORRECT_LAST_NOTE_ONSET_MS: Final[int] = 220140 + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + midi_dict = MidiDict.from_midi(load_path) + last_note = midi_dict.note_msgs[-1] + last_note_onset_tick = last_note["tick"] + last_note_onset_ms = midi_dict.tick_to_ms(last_note_onset_tick) + self.assertEqual(last_note_onset_ms, CORRECT_LAST_NOTE_ONSET_MS) + + def test_calculate_hash(self) -> None: + # Load two identical files with different filenames and metadata + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + midi_dict_orig = MidiDict.from_midi(load_path) + + with tempfile.NamedTemporaryFile(delete=True) as temp_file: + shutil.copy(load_path, temp_file.name) + midi_dict_temp = MidiDict.from_midi(temp_file.name) + + midi_dict_temp.meta_msgs.append({"type": "text", "data": "test"}) + midi_dict_temp.metadata["composer"] = "test" + midi_dict_temp.metadata["composer"] = "test" + midi_dict_temp.metadata["ticks_per_beat"] = -1 + + self.assertEqual( + midi_dict_orig.calculate_hash(), midi_dict_temp.calculate_hash() + ) + def test_resolve_pedal(self) -> None: load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") save_path = RESULTS_DATA_DIRECTORY.joinpath(