From c3efd257bc06b116a95b64b341c87c6068a6ea40 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 16 Nov 2024 14:13:31 +0000 Subject: [PATCH 01/12] add skeleton --- README.md | 3 +- ariautils/ __init__.py | 0 ariautils/utils/__init__.py | 0 ariautils/utils/config.py | 0 pyproject.toml | 64 +++++++++++++++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 ariautils/ __init__.py create mode 100644 ariautils/utils/__init__.py create mode 100644 ariautils/utils/config.py create mode 100644 pyproject.toml diff --git a/README.md b/README.md index 048cea7..2f51d57 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,3 @@ -# ariautils +# aria-utils + MIDI tokenizers and pre-processing utils. diff --git a/ariautils/ __init__.py b/ariautils/ __init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ariautils/utils/__init__.py b/ariautils/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ariautils/utils/config.py b/ariautils/utils/config.py new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7275cd0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,64 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "ariautils" +version = "0.0.1" +description = "" +authors = [{name = "Louis Bradshaw", email = "loua19@outlook.com"}] +requires-python = ">=3.11" +license = {text = "Apache-2.0"} +dependencies = [ + "mido", +] +readme = "README.md" +keywords = [] +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] + +[project.urls] +Repository = "https://github.com/EleutherAI/aria-utils" + +[project.optional-dependencies] +dev = [ + "mypy", + "black", +] + + +[tool.setuptools] +packages = ["ariautils"] +include-package-data = true + +[tool.setuptools.package-data] +"ariautils.config" = ["*.json"] + +[tool.black] +line-length = 80 +target-version = ["py311"] +include = '\.pyi?$' + +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true +strict_equality = true + +[tool.pytest] +testpaths = ["tests"] +python_files = "test_*.py" +addopts = "-ra -q" From a4ec0df781d7a3a49b921394a12da27505813d57 Mon Sep 17 00:00:00 2001 From: Louis Date: Sun, 17 Nov 2024 22:07:19 +0000 Subject: [PATCH 02/12] port midi.py --- README.md | 4 +- ariautils/config/config.json | 242 +++++++ ariautils/midi.py | 1145 ++++++++++++++++++++++++++++++++++ ariautils/utils/__init__.py | 3 + ariautils/utils/config.py | 16 + pyproject.toml | 5 +- 6 files changed, 1413 insertions(+), 2 deletions(-) create mode 100644 ariautils/config/config.json create mode 100644 ariautils/midi.py diff --git a/README.md b/README.md index 2f51d57..45a7e1f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ # aria-utils -MIDI tokenizers and pre-processing utils. +An extremely lightweight and simple library for pre-processing and tokenizing MIDI files. 
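A minimal usage sketch of the MidiDict API that this patch introduces below (the file paths here are hypothetical):

from ariautils.midi import MidiDict

# Parse a MIDI file into the dictionary representation defined in ariautils/midi.py.
midi_dict = MidiDict.from_midi("example.mid")  # hypothetical input path

# Apply sustain-pedal information to note durations, then write the result back to disk.
midi_dict.resolve_pedal()
midi_dict.to_midi().save("example_processed.mid")  # hypothetical output path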
+ + diff --git a/ariautils/config/config.json b/ariautils/config/config.json new file mode 100644 index 0000000..f94d753 --- /dev/null +++ b/ariautils/config/config.json @@ -0,0 +1,242 @@ +{ + "data": { + "tests": { + "max_programs":{ + "run": false, + "args": { + "max": 12 + } + }, + "max_instruments":{ + "run": false, + "args": { + "max": 7 + } + }, + "total_note_frequency":{ + "run": false, + "args": { + "min_per_second": 1.5, + "max_per_second": 30 + } + }, + "note_frequency_per_instrument":{ + "run": false, + "args": { + "min_per_second": 1.0, + "max_per_second": 25 + } + }, + "min_length":{ + "run": false, + "args": { + "min_seconds": 30 + } + } + }, + "pre_processing": { + "remove_instruments": { + "run": true, + "args": { + "piano": false, + "chromatic": true, + "organ": false, + "guitar": false, + "bass": false, + "strings": false, + "ensemble": false, + "brass": false, + "reed": false, + "pipe": false, + "synth_lead": false, + "synth_pad": true, + "synth_effect": true, + "ethnic": true, + "percussive": true, + "sfx": true + } + } + }, + "metadata": { + "functions": { + "composer_filename": { + "run": false, + "args": { + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] + } + }, + "composer_metamsg": { + "run": false, + "args": { + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] + } + }, + "form_filename": { + "run": false, + "args": { + "form_names": ["sonata", "prelude", "nocturne", "etude", "waltz", "mazurka", "impromptu", "fugue"] + } + }, + "maestro_json": { + "run": true, + "args": { + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], + "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"] + } + }, + "listening_model": { + "run": false, + "args": { + "tag_names": ["happy", "sad"] + } + }, + "abs_path": { + "run": true, + "args": {} + } + }, + "manual": { + "genre": ["classical", "jazz"], + "form": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], + "composer": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] + } + }, + "finetuning": { + "min_noisy_interval_ms": 5000, + "max_noisy_interval_ms": 60000, + "min_clean_interval_ms": 60000, + "max_clean_interval_ms": 200000, + "noising": { + "activation_prob": 0.95, + "remove_notes": { + "activation_prob": 0.75, + "min_ratio": 0.1, + "max_ratio": 0.4 + }, + "adjust_velocity": { + "activation_prob": 0.3, + "min_adjust": 1, + "max_adjust": 30, + "max_ratio": 0.1, + "min_ratio": 0.30 + }, + "adjust_onsets": { + "activation_prob": 0.5, + "min_adjust_s": 0.03, + "max_adjust_s": 0.07, + "max_ratio": 0.15, + "min_ratio": 0.5 + }, + "quantize_onsets": { + "activation_prob": 0.15, + "min_quant_s": 0.05, + "max_quant_s": 0.15, + "max_vel_delta": 45 + } + } + } + }, + + "tokenizer": { + "rel": { + "ignore_instruments": { + "piano": false, + "chromatic": true, + "organ": false, + "guitar": false, + "bass": false, + "strings": false, + "ensemble": false, + "brass": false, + "reed": false, + "pipe": false, + "synth_lead": false, + "synth_pad": true, + "synth_effect": true, + "ethnic": true, + "percussive": true, + "sfx": true + }, + "instrument_programs": { + "piano": 0, + "chromatic": 13, 
+ "organ": 16, + "guitar": 24, + "bass": 32, + "strings": 40, + "ensemble": 48, + "brass": 56, + "reed": 64, + "pipe": 73, + "synth_lead": 80, + "synth_pad": 88, + "synth_effect": 96, + "ethnic": 104, + "percussive": 112, + "sfx": 120 + }, + "drum_velocity": 60, + "velocity_quantization": { + "step": 15 + }, + "time_quantization": { + "num_steps": 500, + "step": 10 + }, + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], + "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], + "genre_names": ["jazz", "classical"] + }, + "abs": { + "ignore_instruments": { + "piano": false, + "chromatic": true, + "organ": false, + "guitar": false, + "bass": false, + "strings": false, + "ensemble": false, + "brass": false, + "reed": false, + "pipe": false, + "synth_lead": false, + "synth_pad": true, + "synth_effect": true, + "ethnic": true, + "percussive": true, + "sfx": true + }, + "instrument_programs": { + "piano": 0, + "chromatic": 13, + "organ": 16, + "guitar": 24, + "bass": 32, + "strings": 40, + "ensemble": 48, + "brass": 56, + "reed": 64, + "pipe": 73, + "synth_lead": 80, + "synth_pad": 88, + "synth_effect": 96, + "ethnic": 104, + "percussive": 112, + "sfx": 120 + }, + "drum_velocity": 60, + "velocity_quantization": { + "step": 10 + }, + "abs_time_step_ms": 5000, + "max_dur_ms": 5000, + "time_step_ms": 10, + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], + "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], + "genre_names": ["jazz", "classical"] + }, + "lm": { + "tags": ["happy", "sad"] + } + } +} diff --git a/ariautils/midi.py b/ariautils/midi.py new file mode 100644 index 0000000..7f5382a --- /dev/null +++ b/ariautils/midi.py @@ -0,0 +1,1145 @@ +"""Utils for data/MIDI processing.""" + +import re +import os +import json +import hashlib +import unicodedata +import mido + +from collections import defaultdict +from pathlib import Path +from typing import ( + List, + Dict, + Any, + Tuple, + ParamSpec, + Concatenate, + Callable, + TypeAlias, + Literal, + TypedDict, +) + +from mido.midifiles.units import tick2second +from ariautils.utils import load_config + + +class MetaMessage(TypedDict): + """Meta message type corresponding text or copyright MIDI meta messages.""" + + type: Literal["text", "copyright"] + data: str + + +class TempoMessage(TypedDict): + """Tempo message type corresponding to the set_tempo MIDI message.""" + + type: Literal["tempo"] + data: int + tick: int + + +class PedalMessage(TypedDict): + """Sustain pedal message type corresponding to control_change 64 MIDI messages.""" + + type: Literal["pedal"] + data: Literal[0, 1] # 0 for off, 1 for on + tick: int + channel: int + + +class InstrumentMessage(TypedDict): + """Instrument message type corresponding to program_change MIDI messages.""" + + type: Literal["instrument"] + data: int + tick: int + channel: int + + +class NoteData(TypedDict): + pitch: int + start: int + end: int + velocity: int + + +class NoteMessage(TypedDict): + """Note message type corresponding to paired note_on and note_off MIDI messages.""" + + type: Literal["note"] + data: NoteData + tick: int + channel: int + + +MidiMessage: TypeAlias = ( + MetaMessage | TempoMessage | PedalMessage | InstrumentMessage | NoteMessage +) + + +class MidiDictData(TypedDict): + """Type for MidiDict 
attributes in dictionary form.""" + + meta_msgs: List[MetaMessage] + tempo_msgs: List[TempoMessage] + pedal_msgs: List[PedalMessage] + instrument_msgs: List[InstrumentMessage] + note_msgs: List[NoteMessage] + ticks_per_beat: int + metadata: Dict[str, Any] + + +class MidiDict: + """Container for MIDI data in dictionary form. + + Args: + meta_msgs (List[MetaMessage]): List of text or copyright MIDI meta messages. + tempo_msgs (List[TempoMessage]): List of tempo change messages. + pedal_msgs (List[PedalMessage]): List of sustain pedal messages. + instrument_msgs (List[InstrumentMessage]): List of program change messages. + note_msgs (List[NoteMessage]): List of note messages from paired note-on/off events. + ticks_per_beat (int): MIDI ticks per beat. + metadata (dict): Optional metadata key-value pairs (e.g., {"genre": "classical"}). + """ + + def __init__( + self, + meta_msgs: List[MetaMessage], + tempo_msgs: List[TempoMessage], + pedal_msgs: List[PedalMessage], + instrument_msgs: List[InstrumentMessage], + note_msgs: List[NoteMessage], + ticks_per_beat: int, + metadata: Dict[str, Any], + ): + self.meta_msgs = meta_msgs + self.tempo_msgs = tempo_msgs + self.pedal_msgs = pedal_msgs + self.instrument_msgs = instrument_msgs + self.note_msgs = sorted(note_msgs, key=lambda msg: msg["tick"]) + self.ticks_per_beat = ticks_per_beat + self.metadata = metadata + + # Tracks if resolve_pedal() has been called. + self.pedal_resolved = False + + # If tempo_msgs is empty, initalize to default + if not self.tempo_msgs: + DEFAULT_TEMPO_MSG: TempoMessage = { + "type": "tempo", + "data": 500000, + "tick": 0, + } + self.tempo_msgs = [DEFAULT_TEMPO_MSG] + # If tempo_msgs is empty, initalize to default (piano) + if not self.instrument_msgs: + DEFAULT_INSTRUMENT_MSG: InstrumentMessage = { + "type": "instrument", + "data": 0, + "tick": 0, + "channel": 0, + } + self.instrument_msgs = [DEFAULT_INSTRUMENT_MSG] + + self.program_to_instrument = self.get_program_to_instrument() + + @classmethod + def get_program_to_instrument(cls) -> Dict[int, str]: + """Get map of MIDI program to instrument name.""" + + return ( + {i: "piano" for i in range(0, 7 + 1)} + | {i: "chromatic" for i in range(8, 15 + 1)} + | {i: "organ" for i in range(16, 23 + 1)} + | {i: "guitar" for i in range(24, 31 + 1)} + | {i: "bass" for i in range(32, 39 + 1)} + | {i: "strings" for i in range(40, 47 + 1)} + | {i: "ensemble" for i in range(48, 55 + 1)} + | {i: "brass" for i in range(56, 63 + 1)} + | {i: "reed" for i in range(64, 71 + 1)} + | {i: "pipe" for i in range(72, 79 + 1)} + | {i: "synth_lead" for i in range(80, 87 + 1)} + | {i: "synth_pad" for i in range(88, 95 + 1)} + | {i: "synth_effect" for i in range(96, 103 + 1)} + | {i: "ethnic" for i in range(104, 111 + 1)} + | {i: "percussive" for i in range(112, 119 + 1)} + | {i: "sfx" for i in range(120, 127 + 1)} + ) + + def get_msg_dict(self) -> MidiDictData: + """Returns MidiDict data in dictionary form.""" + + return { + "meta_msgs": self.meta_msgs, + "tempo_msgs": self.tempo_msgs, + "pedal_msgs": self.pedal_msgs, + "instrument_msgs": self.instrument_msgs, + "note_msgs": self.note_msgs, + "ticks_per_beat": self.ticks_per_beat, + "metadata": self.metadata, + } + + def to_midi(self) -> mido.MidiFile: + """Inplace version of dict_to_midi.""" + + return dict_to_midi(self.get_msg_dict()) + + @classmethod + def from_msg_dict(cls, msg_dict: MidiDictData) -> "MidiDict": + """Inplace version of midi_to_dict.""" + + assert msg_dict.keys() == { + "meta_msgs", + "tempo_msgs", + "pedal_msgs", + 
"instrument_msgs", + "note_msgs", + "ticks_per_beat", + "metadata", + } + + return cls(**msg_dict) + + @classmethod + def from_midi(cls, mid_path: str | Path) -> "MidiDict": + """Loads a MIDI file from path and returns MidiDict.""" + + mid = mido.MidiFile(mid_path) + return cls(**midi_to_dict(mid)) + + def calculate_hash(self) -> str: + msg_dict_to_hash = dict(self.get_msg_dict()) + + # Remove metadata before calculating hash + msg_dict_to_hash.pop("meta_msgs") + msg_dict_to_hash.pop("ticks_per_beat") + msg_dict_to_hash.pop("metadata") + + return hashlib.md5( + json.dumps(msg_dict_to_hash, sort_keys=True).encode() + ).hexdigest() + + def tick_to_ms(self, tick: int) -> int: + """Calculate the time (in milliseconds) in current file at a MIDI tick.""" + + return get_duration_ms( + start_tick=0, + end_tick=tick, + tempo_msgs=self.tempo_msgs, + ticks_per_beat=self.ticks_per_beat, + ) + + def _build_pedal_intervals(self) -> Dict[int, List[List[int]]]: + """Returns a mapping of channels to sustain pedal intervals.""" + + self.pedal_msgs.sort(key=lambda msg: msg["tick"]) + channel_to_pedal_intervals = defaultdict(list) + pedal_status: Dict[int, int] = {} + + for pedal_msg in self.pedal_msgs: + tick = pedal_msg["tick"] + channel = pedal_msg["channel"] + data = pedal_msg["data"] + + if data == 1 and pedal_status.get(channel, None) is None: + pedal_status[channel] = tick + elif data == 0 and pedal_status.get(channel, None) is not None: + # Close pedal interval + _start_tick = pedal_status[channel] + _end_tick = tick + channel_to_pedal_intervals[channel].append( + [_start_tick, _end_tick] + ) + del pedal_status[channel] + + # Close all unclosed pedals at end of track + final_tick = self.note_msgs[-1]["data"]["end"] + for channel, start_tick in pedal_status.items(): + channel_to_pedal_intervals[channel].append([start_tick, final_tick]) + + return channel_to_pedal_intervals + + def resolve_overlaps(self) -> "MidiDict": + """Resolves any note overlaps (inplace) between notes with the same + pitch and channel. This is achieved by converting a pair of notes with + the same pitch (a0): + + [a, b+x], [b-y, c] -> [a, b-y], [b-y, c] + + Note that this should not occur if the note messages have not been + modified, e.g., by resolve_overlap(). 
+ """ + + # Organize notes by channel and pitch + note_msgs_c: Dict[int, Dict[int, List[NoteMessage]]] = defaultdict( + lambda: defaultdict(list) + ) + for msg in self.note_msgs: + _channel = msg["channel"] + _pitch = msg["data"]["pitch"] + note_msgs_c[_channel][_pitch].append(msg) + + # We can modify notes by reference as they are dictionaries + for channel, msgs_by_pitch in note_msgs_c.items(): + for pitch, msgs in msgs_by_pitch.items(): + msgs.sort( + key=lambda msg: (msg["data"]["start"], msg["data"]["end"]) + ) + prev_off_tick = -1 + for idx, msg in enumerate(msgs): + on_tick = msg["data"]["start"] + off_tick = msg["data"]["end"] + if prev_off_tick > on_tick: + # Adjust end of previous (idx - 1) msg to remove overlap + msgs[idx - 1]["data"]["end"] = on_tick + prev_off_tick = off_tick + + return self + + def resolve_pedal(self) -> "MidiDict": + """Extend note offsets according to pedal and resolve any note overlaps""" + + # If has been already resolved, we don't recalculate + if self.pedal_resolved == True: + print("Pedal has already been resolved") + + # Organize note messages by channel + note_msgs_c = defaultdict(list) + for msg in self.note_msgs: + _channel = msg["channel"] + note_msgs_c[_channel].append(msg) + + # We can modify notes by reference as they are dictionaries + channel_to_pedal_intervals = self._build_pedal_intervals() + for channel, msgs in note_msgs_c.items(): + for msg in msgs: + note_end_tick = msg["data"]["end"] + for pedal_interval in channel_to_pedal_intervals[channel]: + pedal_start, pedal_end = pedal_interval + if pedal_start < note_end_tick < pedal_end: + msg["data"]["end"] = pedal_end + break + + self.resolve_overlaps() + self.pedal_resolved = True + + return self + + # TODO: Needs to be refactored and tested + def remove_redundant_pedals(self) -> "MidiDict": + """Removes redundant pedal messages from the MIDI data in place. + + Removes all pedal on/off message pairs that don't extend any notes. + Makes an exception for pedal off messages that coincide exactly with + note offsets. + """ + + def _is_pedal_useful( + pedal_start_tick: int, + pedal_end_tick: int, + note_msgs: List[NoteMessage], + ) -> bool: + # This logic loops through the note_msgs that could possibly + # be effected by the pedal which starts at pedal_start_tick + # and ends at pedal_end_tick. If there is note effected by the + # pedal, then it returns early. + + note_idx = 0 + note_msg = note_msgs[0] + note_start = note_msg["data"]["start"] + + while note_start <= pedal_end_tick and note_idx < len(note_msgs): + note_msg = note_msgs[note_idx] + note_start, note_end = ( + note_msg["data"]["start"], + note_msg["data"]["end"], + ) + + if pedal_start_tick <= note_end <= pedal_end_tick: + # Found note for which pedal is useful + return True + + note_idx += 1 + + return False + + def _process_channel_pedals(channel: int) -> None: + pedal_msg_idxs_to_remove = [] + pedal_down_tick = None + pedal_down_msg_idx = None + + note_msgs = [ + msg for msg in self.note_msgs if msg["channel"] == channel + ] + + if not note_msgs: + # No notes to process. In this case we remove all pedal_msgs + # and then return early. 
+ for pedal_msg_idx, pedal_msg in enumerate(self.pedal_msgs): + pedal_msg_value, pedal_msg_tick, _channel = ( + pedal_msg["data"], + pedal_msg["tick"], + pedal_msg["channel"], + ) + + if _channel == channel: + pedal_msg_idxs_to_remove.append(pedal_msg_idx) + + # Remove messages + self.pedal_msgs = [ + msg + for _idx, msg in enumerate(self.pedal_msgs) + if _idx not in pedal_msg_idxs_to_remove + ] + return + + for pedal_msg_idx, pedal_msg in enumerate(self.pedal_msgs): + pedal_msg_value, pedal_msg_tick, _channel = ( + pedal_msg["data"], + pedal_msg["tick"], + pedal_msg["channel"], + ) + + # Only process pedal_msgs for specified MIDI channel + if _channel != channel: + continue + + # Remove never-closed pedal messages + if ( + pedal_msg_idx == len(self.pedal_msgs) - 1 + and pedal_msg_value == 1 + ): + # Current msg is last one and ON -> remove curr pedal_msg + pedal_msg_idxs_to_remove.append(pedal_msg_idx) + + # Logic for removing repeated pedal messages and updating + # pedal_down_tick and pedal_down_idx + if pedal_down_tick is None: + if pedal_msg_value == 1: + # Pedal is OFF and current msg is ON -> update + pedal_down_tick = pedal_msg_tick + pedal_down_msg_idx = pedal_msg_idx + continue + else: + # Pedal is OFF and current msg is OFF -> remove curr pedal_msg + pedal_msg_idxs_to_remove.append(pedal_msg_idx) + continue + else: + if pedal_msg_value == 1: + # Pedal is ON and current msg is ON -> remove curr pedal_msg + pedal_msg_idxs_to_remove.append(pedal_msg_idx) + continue + + pedal_is_useful = _is_pedal_useful( + pedal_start_tick=pedal_down_tick, + pedal_end_tick=pedal_msg_tick, + note_msgs=note_msgs, + ) + + if pedal_is_useful is False: + # Pedal hasn't effected any notes -> remove + pedal_msg_idxs_to_remove.append(pedal_down_msg_idx) + pedal_msg_idxs_to_remove.append(pedal_msg_idx) + + # Finished processing pedal, set pedal state to OFF + pedal_down_tick = None + pedal_down_msg_idx = None + + # Remove messages + self.pedal_msgs = [ + msg + for _idx, msg in enumerate(self.pedal_msgs) + if _idx not in pedal_msg_idxs_to_remove + ] + + for channel in set([msg["channel"] for msg in self.pedal_msgs]): + _process_channel_pedals(channel) + + return self + + def remove_instruments(self, config: dict) -> "MidiDict": + """Removes all messages with instruments specified in config at: + + data.preprocessing.remove_instruments + + Note that drum messages, defined as those which occur on MIDI channel 9 + are not removed. + """ + + programs_to_remove = [ + i + for i in range(1, 127 + 1) + if config[self.program_to_instrument[i]] is True + ] + channels_to_remove = [ + msg["channel"] + for msg in self.instrument_msgs + if msg["data"] in programs_to_remove + ] + + # Remove drums (channel 9) from channels to remove + channels_to_remove = [i for i in channels_to_remove if i != 9] + + # Remove unwanted messages all type by looping over msgs types + _msg_dict: Dict[str, List] = { + "meta_msgs": self.meta_msgs, + "tempo_msgs": self.tempo_msgs, + "pedal_msgs": self.pedal_msgs, + "instrument_msgs": self.instrument_msgs, + "note_msgs": self.note_msgs, + } + + for msgs_name, msgs_list in _msg_dict.items(): + setattr( + self, + msgs_name, + [ + msg + for msg in msgs_list + if msg.get("channel", -1) not in channels_to_remove + ], + ) + + return self + + +# TODO: The sign has been changed. 
Make sure this function isn't used anywhere else +def _extract_track_data( + track: mido.MidiTrack, +) -> Tuple[ + List[MetaMessage], + List[TempoMessage], + List[PedalMessage], + List[InstrumentMessage], + List[NoteMessage], +]: + """Converts MIDI messages into format used by MidiDict.""" + + meta_msgs: List[MetaMessage] = [] + tempo_msgs: List[TempoMessage] = [] + pedal_msgs: List[PedalMessage] = [] + instrument_msgs: List[InstrumentMessage] = [] + note_msgs: List[NoteMessage] = [] + + last_note_on = defaultdict(list) + for message in track: + # Meta messages + if message.is_meta is True: + if message.type == "text" or message.type == "copyright": + meta_msgs.append( + { + "type": message.type, + "data": message.text, + } + ) + # Tempo messages + elif message.type == "set_tempo": + tempo_msgs.append( + { + "type": "tempo", + "data": message.tempo, + "tick": message.time, + } + ) + # Instrument messages + elif message.type == "program_change": + instrument_msgs.append( + { + "type": "instrument", + "data": message.program, + "tick": message.time, + "channel": message.channel, + } + ) + # Pedal messages + elif message.type == "control_change" and message.control == 64: + # Consistent with pretty_midi and ableton-live default behavior + pedal_msgs.append( + { + "type": "pedal", + "data": 0 if message.value < 64 else 1, + "tick": message.time, + "channel": message.channel, + } + ) + # Note messages + elif message.type == "note_on" and message.velocity > 0: + last_note_on[(message.note, message.channel)].append( + (message.time, message.velocity) + ) + elif message.type == "note_off" or ( + message.type == "note_on" and message.velocity == 0 + ): + # Ignore non-existent note-ons + if (message.note, message.channel) in last_note_on: + end_tick = message.time + open_notes = last_note_on[(message.note, message.channel)] + + notes_to_close = [ + (start_tick, velocity) + for start_tick, velocity in open_notes + if start_tick != end_tick + ] + notes_to_keep = [ + (start_tick, velocity) + for start_tick, velocity in open_notes + if start_tick == end_tick + ] + + for start_tick, velocity in notes_to_close: + note_msgs.append( + { + "type": "note", + "data": { + "pitch": message.note, + "start": start_tick, + "end": end_tick, + "velocity": velocity, + }, + "tick": start_tick, + "channel": message.channel, + } + ) + + if len(notes_to_close) > 0 and len(notes_to_keep) > 0: + # Note-on on the same tick but we already closed + # some previous notes -> it will continue, keep it. + last_note_on[(message.note, message.channel)] = ( + notes_to_keep + ) + else: + # Remove the last note on for this instrument + del last_note_on[(message.note, message.channel)] + + return meta_msgs, tempo_msgs, pedal_msgs, instrument_msgs, note_msgs + + +def midi_to_dict(mid: mido.MidiFile) -> MidiDictData: + """Converts mid.MidiFile into MidiDictData representation. + + Additionally runs metadata extraction according to config specified at: + + data.metadata.functions + + Args: + mid (mido.MidiFile): A mido file object to parse. + + Returns: + MidiDictData: A dictionary containing extracted MIDI data including notes, + time signatures, key signatures, and other musical events. 
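+    Example (illustrative; the file path is hypothetical):
+
+        midi_dict_data = midi_to_dict(mido.MidiFile("example.mid"))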
+ """ + + metadata_config = load_config()["data"]["metadata"] + # Convert time in mid to absolute + for track in mid.tracks: + curr_tick = 0 + for message in track: + message.time += curr_tick + curr_tick = message.time + + midi_dict_data: MidiDictData = { + "meta_msgs": [], + "tempo_msgs": [], + "pedal_msgs": [], + "instrument_msgs": [], + "note_msgs": [], + "ticks_per_beat": mid.ticks_per_beat, + "metadata": {}, + } + + # Compile track data + for mid_track in mid.tracks: + meta_msgs, tempo_msgs, pedal_msgs, instrument_msgs, note_msgs = ( + _extract_track_data(mid_track) + ) + midi_dict_data["meta_msgs"] += meta_msgs + midi_dict_data["tempo_msgs"] += tempo_msgs + midi_dict_data["pedal_msgs"] += pedal_msgs + midi_dict_data["instrument_msgs"] += instrument_msgs + midi_dict_data["note_msgs"] += note_msgs + + # Sort by tick (for note msgs, this will be the same as data.start_tick) + midi_dict_data["tempo_msgs"] = sorted( + midi_dict_data["tempo_msgs"], key=lambda x: x["tick"] + ) + midi_dict_data["pedal_msgs"] = sorted( + midi_dict_data["pedal_msgs"], key=lambda x: x["tick"] + ) + midi_dict_data["instrument_msgs"] = sorted( + midi_dict_data["instrument_msgs"], key=lambda x: x["tick"] + ) + midi_dict_data["note_msgs"] = sorted( + midi_dict_data["note_msgs"], key=lambda x: x["tick"] + ) + + for metadata_process_name, metadata_process_config in metadata_config[ + "functions" + ].items(): + if metadata_process_config["run"] is True: + metadata_fn = get_metadata_fn( + metadata_process_name=metadata_process_name + ) + fn_args: Dict = metadata_process_config["args"] + + collected_metadata = metadata_fn(mid, midi_dict_data, **fn_args) + if collected_metadata: + for k, v in collected_metadata.items(): + midi_dict_data["metadata"][k] = v + + return midi_dict_data + + +def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile: + """Converts MIDI information from dictionary form into a mido.MidiFile. + + This function performs midi_to_dict in reverse. + + Args: + mid_data (dict): MIDI information in dictionary form. + + Returns: + mido.MidiFile: The MIDI parsed from the input data. + """ + + assert mid_data.keys() == { + "meta_msgs", + "tempo_msgs", + "pedal_msgs", + "instrument_msgs", + "note_msgs", + "ticks_per_beat", + "metadata", + }, "Invalid json/dict." 
+ + ticks_per_beat = mid_data["ticks_per_beat"] + + # Add all messages (not ordered) to one track + track = mido.MidiTrack() + end_msgs = defaultdict(list) + + for tempo_msg in mid_data["tempo_msgs"]: + track.append( + mido.MetaMessage( + "set_tempo", tempo=tempo_msg["data"], time=tempo_msg["tick"] + ) + ) + + for pedal_msg in mid_data["pedal_msgs"]: + track.append( + mido.Message( + "control_change", + control=64, + value=pedal_msg["data"] + * 127, # Stored in PedalMessage as 1 or 0 + channel=pedal_msg["channel"], + time=pedal_msg["tick"], + ) + ) + + for instrument_msg in mid_data["instrument_msgs"]: + track.append( + mido.Message( + "program_change", + program=instrument_msg["data"], + channel=instrument_msg["channel"], + time=instrument_msg["tick"], + ) + ) + + for note_msg in mid_data["note_msgs"]: + # Note on + track.append( + mido.Message( + "note_on", + note=note_msg["data"]["pitch"], + velocity=note_msg["data"]["velocity"], + channel=note_msg["channel"], + time=note_msg["data"]["start"], + ) + ) + # Note off + end_msgs[(note_msg["channel"], note_msg["data"]["pitch"])].append( + (note_msg["data"]["start"], note_msg["data"]["end"]) + ) + + # Only add end messages that don't interfere with other notes + for k, v in end_msgs.items(): + channel, pitch = k + for start, end in v: + add = True + for _start, _end in v: + if start < _start < end < _end: + add = False + + if add is True: + track.append( + mido.Message( + "note_on", + note=pitch, + velocity=0, + channel=channel, + time=end, + ) + ) + + # Magic sorting function + def _sort_fn(msg: mido.Message) -> Tuple[int, int]: + if hasattr(msg, "velocity"): + return (msg.time, msg.velocity) + else: + return (msg.time, 1000) + + # Sort and convert from abs_time -> delta_time + track = sorted(track, key=_sort_fn) + tick = 0 + for msg in track: + msg.time -= tick + tick += msg.time + + track.append(mido.MetaMessage("end_of_track", time=0)) + mid = mido.MidiFile(type=0) + mid.ticks_per_beat = ticks_per_beat + mid.tracks.append(track) + + return mid + + +def get_duration_ms( + start_tick: int, + end_tick: int, + tempo_msgs: List[TempoMessage], + ticks_per_beat: int, +) -> int: + """Calculates elapsed time (in ms) between start_tick and end_tick.""" + + # Finds idx such that: + # tempo_msg[idx]["tick"] < start_tick <= tempo_msg[idx+1]["tick"] + for idx, curr_msg in enumerate(tempo_msgs): + if start_tick <= curr_msg["tick"]: + break + if idx > 0: # Special case idx == 0 -> Don't -1 + idx -= 1 + + # It is important that we initialise curr_tick & curr_tempo here. In the + # case that there is a single tempo message the following loop will not run. 
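+    # mido's tick2second converts each tempo segment as
+    # seconds = delta_tick * tempo / (ticks_per_beat * 1e6),
+    # where tempo is in microseconds per beat.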
+ duration = 0.0 + curr_tick = start_tick + curr_tempo = tempo_msgs[idx]["data"] + + # Sums all tempo intervals before tempo_msgs[-1]["tick"] + for curr_msg, next_msg in zip(tempo_msgs[idx:], tempo_msgs[idx + 1 :]): + curr_tempo = curr_msg["data"] + if end_tick < next_msg["tick"]: + delta_tick = end_tick - curr_tick + else: + delta_tick = next_msg["tick"] - curr_tick + + duration += tick2second( + tick=delta_tick, + tempo=curr_tempo, + ticks_per_beat=ticks_per_beat, + ) + + if end_tick < next_msg["tick"]: + break + else: + curr_tick = next_msg["tick"] + + # Case end_tick > tempo_msgs[-1]["tick"] + if end_tick > tempo_msgs[-1]["tick"]: + curr_tempo = tempo_msgs[-1]["data"] + delta_tick = end_tick - curr_tick + + duration += tick2second( + tick=delta_tick, + tempo=curr_tempo, + ticks_per_beat=ticks_per_beat, + ) + + # Convert from seconds to milliseconds + duration = duration * 1e3 + duration = round(duration) + + return duration + + +def _match_word(text: str, word: str) -> bool: + def to_ascii(s: str) -> str: + # Remove accents + normalized = unicodedata.normalize("NFKD", s) + return "".join(c for c in normalized if not unicodedata.combining(c)) + + text = to_ascii(text) + word = to_ascii(word) + + # If name="bach" this pattern will match "bach", "Bach" or "BACH" if + # it is either proceeded or preceded by a "_" or " ". + pattern = ( + r"(^|[\s_])(" + + word.lower() + + r"|" + + word.upper() + + r"|" + + word.capitalize() + + r")([\s_]|$)" + ) + + if re.search(pattern, text, re.IGNORECASE): + return True + else: + return False + + +def meta_composer_filename( + mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list +) -> Dict[str, str]: + file_name = Path(str(mid.filename)).stem + matched_names_unique = set() + for name in composer_names: + if _match_word(file_name, name): + matched_names_unique.add(name) + + # Only return data if only one composer is found + matched_names = list(matched_names_unique) + if len(matched_names) == 1: + return {"composer": matched_names[0]} + else: + return {} + + +def meta_form_filename( + mid: mido.MidiFile, msg_data: MidiDictData, form_names: list +) -> Dict[str, str]: + file_name = Path(str(mid.filename)).stem + matched_names_unique = set() + for name in form_names: + if _match_word(file_name, name): + matched_names_unique.add(name) + + # Only return data if only one composer is found + matched_names = list(matched_names_unique) + if len(matched_names) == 1: + return {"form": matched_names[0]} + else: + return {} + + +def meta_composer_metamsg( + mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list +) -> Dict[str, str]: + matched_names_unique = set() + for msg in msg_data["meta_msgs"]: + for name in composer_names: + if _match_word(msg["data"], name): + matched_names_unique.add(name) + + # Only return data if only one composer is found + matched_names = list(matched_names_unique) + if len(matched_names) == 1: + return {"composer": matched_names[0]} + else: + return {} + + +# This should only be used when processing MAESTRO, it requires maestro.json +# to be in the working directory. 
This json files contains MAESTRO metadata in +# the form file_name: {"composer": str, "title": str} +def meta_maestro_json( + mid: mido.MidiFile, + msg_data: MidiDictData, + composer_names: list, + form_names: list, +) -> Dict[str, str]: + if os.path.isfile("maestro.json") is False: + print("MAESTRO metadata function enabled but ./maestro.json not found.") + return {} + + file_name = Path(str(mid.filename)).name + with open("maestro.json", "r") as f: + _file_name_without_ext = os.path.splitext(file_name)[0] + metadata = json.load(f).get(_file_name_without_ext + ".midi", None) + if metadata == None: + return {} + + matched_forms_unique = set() + for form in form_names: + if _match_word(metadata["title"], form): + matched_forms_unique.add(form) + + matched_composers_unique = set() + for composer in composer_names: + if _match_word(metadata["composer"], composer): + matched_composers_unique.add(composer) + + res = {} + matched_composers = list(matched_composers_unique) + matched_forms = list(matched_forms_unique) + if len(matched_forms) == 1: + res["form"] = matched_forms[0] + if len(matched_composers) == 1: + res["composer"] = matched_composers[0] + + return res + + +def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> Dict[str, str]: + return {"abs_path": str(Path(str(mid.filename)).absolute())} + + +def get_metadata_fn( + metadata_process_name: str, +) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]]: + name_to_fn: Dict[ + str, + Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]], + ] = { + "composer_filename": meta_composer_filename, + "composer_metamsg": meta_composer_metamsg, + "form_filename": meta_form_filename, + "maestro_json": meta_maestro_json, + "abs_path": meta_abs_path, + } + + fn = name_to_fn.get(metadata_process_name, None) + if fn is None: + raise ValueError( + f"Error finding metadata function for {metadata_process_name}" + ) + else: + return fn + + +def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: + """Returns false if midi_dict uses more than {max} programs.""" + present_programs = set( + map( + lambda msg: msg["data"], + midi_dict.instrument_msgs, + ) + ) + + if len(present_programs) <= max: + return True, len(present_programs) + else: + return False, len(present_programs) + + +def test_max_instruments(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: + present_instruments = set( + map( + lambda msg: midi_dict.program_to_instrument[msg["data"]], + midi_dict.instrument_msgs, + ) + ) + + if len(present_instruments) <= max: + return True, len(present_instruments) + else: + return False, len(present_instruments) + + +def test_note_frequency( + midi_dict: MidiDict, max_per_second: float, min_per_second: float +) -> Tuple[bool, float]: + if not midi_dict.note_msgs: + return False, 0.0 + + num_notes = len(midi_dict.note_msgs) + total_duration_ms = get_duration_ms( + start_tick=midi_dict.note_msgs[0]["data"]["start"], + end_tick=midi_dict.note_msgs[-1]["data"]["end"], + tempo_msgs=midi_dict.tempo_msgs, + ticks_per_beat=midi_dict.ticks_per_beat, + ) + + if total_duration_ms == 0: + return False, 0.0 + + notes_per_second = (num_notes * 1e3) / total_duration_ms + + if notes_per_second < min_per_second or notes_per_second > max_per_second: + return False, notes_per_second + else: + return True, notes_per_second + + +def test_note_frequency_per_instrument( + midi_dict: MidiDict, max_per_second: float, min_per_second: float +) -> Tuple[bool, float]: + num_instruments = len( + set( + map( + lambda 
msg: midi_dict.program_to_instrument[msg["data"]], + midi_dict.instrument_msgs, + ) + ) + ) + + if not midi_dict.note_msgs: + return False, 0.0 + + num_notes = len(midi_dict.note_msgs) + total_duration_ms = get_duration_ms( + start_tick=midi_dict.note_msgs[0]["data"]["start"], + end_tick=midi_dict.note_msgs[-1]["data"]["end"], + tempo_msgs=midi_dict.tempo_msgs, + ticks_per_beat=midi_dict.ticks_per_beat, + ) + + if total_duration_ms == 0: + return False, 0.0 + + notes_per_second = (num_notes * 1e3) / total_duration_ms + + note_freq_per_instrument = notes_per_second / num_instruments + if ( + note_freq_per_instrument < min_per_second + or note_freq_per_instrument > max_per_second + ): + return False, note_freq_per_instrument + else: + return True, note_freq_per_instrument + + +def test_min_length( + midi_dict: MidiDict, min_seconds: int +) -> Tuple[bool, float]: + if not midi_dict.note_msgs: + return False, 0.0 + + total_duration_ms = get_duration_ms( + start_tick=midi_dict.note_msgs[0]["data"]["start"], + end_tick=midi_dict.note_msgs[-1]["data"]["end"], + tempo_msgs=midi_dict.tempo_msgs, + ticks_per_beat=midi_dict.ticks_per_beat, + ) + + if total_duration_ms / 1e3 < min_seconds: + return False, total_duration_ms / 1e3 + else: + return True, total_duration_ms / 1e3 + + +def get_test_fn( + test_name: str, +) -> Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]]: + name_to_fn: Dict[ + str, Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]] + ] = { + "max_programs": test_max_programs, + "max_instruments": test_max_instruments, + "total_note_frequency": test_note_frequency, + "note_frequency_per_instrument": test_note_frequency_per_instrument, + "min_length": test_min_length, + } + + fn = name_to_fn.get(test_name, None) + if fn is None: + raise ValueError( + f"Error finding preprocessing function for {test_name}" + ) + else: + return fn diff --git a/ariautils/utils/__init__.py b/ariautils/utils/__init__.py index e69de29..fef643b 100644 --- a/ariautils/utils/__init__.py +++ b/ariautils/utils/__init__.py @@ -0,0 +1,3 @@ +from .config import load_config + +__all__ = ["load_config"] diff --git a/ariautils/utils/config.py b/ariautils/utils/config.py index e69de29..20c0f6c 100644 --- a/ariautils/utils/config.py +++ b/ariautils/utils/config.py @@ -0,0 +1,16 @@ +"""Includes functionality for loading config files.""" + +import json + +from importlib import resources +from typing import Dict, Any, cast + + +def load_config() -> Dict[str, Any]: + """Returns a dictionary loaded from the config.json file.""" + with ( + resources.files("ariautils.config") + .joinpath("config.json") + .open("r") as f + ): + return cast(Dict[str, Any], json.load(f)) diff --git a/pyproject.toml b/pyproject.toml index 7275cd0..724a51b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,8 +55,11 @@ no_implicit_optional = true warn_redundant_casts = true warn_unused_ignores = true warn_no_return = true -warn_unreachable = true +warn_unreachable = false strict_equality = true +ignore_missing_imports = true +namespace_packages = true +explicit_package_bases = true [tool.pytest] testpaths = ["tests"] From 70e6a803defaefb505d6846bcc2727d230cef578 Mon Sep 17 00:00:00 2001 From: Louis Date: Sun, 17 Nov 2024 22:20:36 +0000 Subject: [PATCH 03/12] update path for maestro metadata json --- ariautils/midi.py | 25 ++++++++++++++----------- ariautils/utils/__init__.py | 2 +- ariautils/utils/config.py | 12 ++++++++++++ 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/ariautils/midi.py b/ariautils/midi.py index 
7f5382a..b726322 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -23,7 +23,7 @@ ) from mido.midifiles.units import tick2second -from ariautils.utils import load_config +from ariautils.utils import load_config, load_maestro_metadata_json class MetaMessage(TypedDict): @@ -944,23 +944,26 @@ def meta_composer_metamsg( return {} -# This should only be used when processing MAESTRO, it requires maestro.json -# to be in the working directory. This json files contains MAESTRO metadata in -# the form file_name: {"composer": str, "title": str} +# TODO: Needs testing def meta_maestro_json( mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list, form_names: list, ) -> Dict[str, str]: - if os.path.isfile("maestro.json") is False: - print("MAESTRO metadata function enabled but ./maestro.json not found.") - return {} + """Loads composer and form metadata from MAESTRO metadata json file. + - file_name = Path(str(mid.filename)).name - with open("maestro.json", "r") as f: - _file_name_without_ext = os.path.splitext(file_name)[0] - metadata = json.load(f).get(_file_name_without_ext + ".midi", None) + This should only be used when processing MAESTRO, it requires maestro.json + to be in the working directory. This json files contains MAESTRO metadata in + the form file_name: {"composer": str, "title": str}. + """ + + _file_name = Path(str(mid.filename)).name + _file_name_without_ext = os.path.splitext(_file_name)[0] + metadata = load_maestro_metadata_json().get( + _file_name_without_ext + ".midi", None + ) if metadata == None: return {} diff --git a/ariautils/utils/__init__.py b/ariautils/utils/__init__.py index fef643b..6e7f754 100644 --- a/ariautils/utils/__init__.py +++ b/ariautils/utils/__init__.py @@ -1,3 +1,3 @@ -from .config import load_config +from .config import load_config, load_maestro_metadata_json __all__ = ["load_config"] diff --git a/ariautils/utils/config.py b/ariautils/utils/config.py index 20c0f6c..0b8abf9 100644 --- a/ariautils/utils/config.py +++ b/ariautils/utils/config.py @@ -1,5 +1,6 @@ """Includes functionality for loading config files.""" +import os import json from importlib import resources @@ -14,3 +15,14 @@ def load_config() -> Dict[str, Any]: .open("r") as f ): return cast(Dict[str, Any], json.load(f)) + + +# TODO: Move somewhere else +def load_maestro_metadata_json() -> Dict[str, Any]: + """Loads MAESTRO metadata json .""" + with ( + resources.files("ariautils.config") + .joinpath("maestro_metadata.json") + .open("r") as f + ): + return cast(Dict[str, Any], json.load(f)) From d194e1eba020d650085a26ac7237185aaa8f2442 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 18 Nov 2024 17:11:26 +0000 Subject: [PATCH 04/12] add tests and ci --- .github/workflows/python-ci.yml | 35 +++++++++++++ ariautils/config/config.json | 2 +- ariautils/midi.py | 8 +-- ariautils/utils/__init__.py | 40 ++++++++++++++- ariautils/utils/config.py | 11 ----- pyproject.toml | 13 ++--- tests/__init__.py | 0 tests/assets/data/arabesque.mid | Bin 0 -> 16975 bytes tests/assets/results/.gitkeep | 0 tests/test_midi.py | 85 ++++++++++++++++++++++++++++++++ 10 files changed, 171 insertions(+), 23 deletions(-) create mode 100644 .github/workflows/python-ci.yml create mode 100644 tests/__init__.py create mode 100644 tests/assets/data/arabesque.mid create mode 100644 tests/assets/results/.gitkeep create mode 100644 tests/test_midi.py diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml new file mode 100644 index 0000000..e66bd07 --- /dev/null +++ 
b/.github/workflows/python-ci.yml @@ -0,0 +1,35 @@ +name: Python CI + +on: + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: install + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: install + run: | + python -m pip install --upgrade pip + pip install .[dev] + + - name: black + run: | + black --check . + + - name: mypy + run: | + mypy ariautils + mypy tests + + - name: Run tests with pytest + run: | + pytest diff --git a/ariautils/config/config.json b/ariautils/config/config.json index f94d753..465a324 100644 --- a/ariautils/config/config.json +++ b/ariautils/config/config.json @@ -78,7 +78,7 @@ } }, "maestro_json": { - "run": true, + "run": false, "args": { "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"] diff --git a/ariautils/midi.py b/ariautils/midi.py index b726322..5977b25 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -14,7 +14,7 @@ Dict, Any, Tuple, - ParamSpec, + Final, Concatenate, Callable, TypeAlias, @@ -148,9 +148,9 @@ def __init__( @classmethod def get_program_to_instrument(cls) -> Dict[int, str]: - """Get map of MIDI program to instrument name.""" + """Return a map of MIDI program to instrument name.""" - return ( + PROGRAM_TO_INSTRUMENT: Final[Dict[int, str]] = ( {i: "piano" for i in range(0, 7 + 1)} | {i: "chromatic" for i in range(8, 15 + 1)} | {i: "organ" for i in range(16, 23 + 1)} @@ -169,6 +169,8 @@ def get_program_to_instrument(cls) -> Dict[int, str]: | {i: "sfx" for i in range(120, 127 + 1)} ) + return PROGRAM_TO_INSTRUMENT + def get_msg_dict(self) -> MidiDictData: """Returns MidiDict data in dictionary form.""" diff --git a/ariautils/utils/__init__.py b/ariautils/utils/__init__.py index 6e7f754..eb1ccad 100644 --- a/ariautils/utils/__init__.py +++ b/ariautils/utils/__init__.py @@ -1,3 +1,39 @@ -from .config import load_config, load_maestro_metadata_json +"""Miscellaneous utilities.""" -__all__ = ["load_config"] +import json +import logging + +from importlib import resources +from typing import Dict, Any, cast + +from .config import load_config + + +def get_logger(name: str) -> logging.Logger: + logger = logging.getLogger(name) + if not logger.handlers: + logger.propagate = False + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "[%(asctime)s]: [%(levelname)s] [%(name)s] %(message)s" + ) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return logger + + +def load_maestro_metadata_json() -> Dict[str, Any]: + """Loads MAESTRO metadata json .""" + with ( + resources.files("ariautils.config") + .joinpath("maestro_metadata.json") + .open("r") as f + ): + return cast(Dict[str, Any], json.load(f)) + + +__all__ = ["load_config", "load_maestro_metadata_json", "get_logger"] diff --git a/ariautils/utils/config.py b/ariautils/utils/config.py index 0b8abf9..a4fd267 100644 --- a/ariautils/utils/config.py +++ b/ariautils/utils/config.py @@ -15,14 +15,3 @@ def load_config() -> Dict[str, Any]: .open("r") as f ): return cast(Dict[str, Any], json.load(f)) - - -# TODO: Move somewhere else -def load_maestro_metadata_json() -> Dict[str, Any]: - """Loads MAESTRO metadata json .""" - with ( - resources.files("ariautils.config") - .joinpath("maestro_metadata.json") - .open("r") as f - ): - 
return cast(Dict[str, Any], json.load(f))
diff --git a/pyproject.toml b/pyproject.toml
index 724a51b..ca7e2dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,15 +28,14 @@ Repository = "https://github.com/EleutherAI/aria-utils"
 dev = [
     "mypy",
     "black",
+    "pytest",
 ]
-
 
 [tool.setuptools]
 packages = ["ariautils"]
-include-package-data = true
 
 [tool.setuptools.package-data]
-"ariautils.config" = ["*.json"]
+ariautils = ["config/*.json"]
 
 [tool.black]
 line-length = 80
@@ -45,6 +44,7 @@ include = '\.pyi?$'
 
 [tool.mypy]
 python_version = "3.11"
+packages = ["ariautils", "tests"]
 warn_return_any = true
 warn_unused_configs = true
 disallow_untyped_defs = true
@@ -61,7 +61,8 @@ ignore_missing_imports = true
 namespace_packages = true
 explicit_package_bases = true
 
-[tool.pytest]
+[tool.pytest.ini_options]
+minversion = "6.0"
+addopts = "-ra -q -s"
 testpaths = ["tests"]
-python_files = "test_*.py"
-addopts = "-ra -q"
+python_files = ["test_*.py"]
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/assets/data/arabesque.mid b/tests/assets/data/arabesque.mid
new file mode 100644
index 0000000000000000000000000000000000000000..38a0f85844d27d0a1a064115f248a34dc73f5f64
GIT binary patch
literal 16975
[base85-encoded binary MIDI data omitted]

diff --git a/tests/assets/results/.gitkeep b/tests/assets/results/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_midi.py b/tests/test_midi.py
new file mode 100644
index 0000000..74a7402
--- /dev/null
+++ b/tests/test_midi.py
@@ -0,0 +1,85 @@
+import unittest
+
+from importlib import resources
+from pathlib import Path
+from typing import Final
+
+from ariautils.midi import MidiDict
+from ariautils.utils import get_logger
+
+
+TEST_DATA_DIRECTORY: Final[Path] = Path(
+    str(resources.files("tests").joinpath("assets", "data"))
+)
+RESULTS_DATA_DIRECTORY: Final[Path] = Path(
+    str(resources.files("tests").joinpath("assets", "results"))
+)
+
+
+class TestMidiDict(unittest.TestCase):
+    def setUp(self) -> None:
+        self.logger = get_logger(__name__ + ".TestMidiDict")
+
+    def test_load(self) -> None:
+        load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid")
+        midi_dict = MidiDict.from_midi(load_path)
+
+        self.logger.info(f"Num meta_msgs: {len(midi_dict.meta_msgs)}")
+        self.logger.info(f"Num tempo_msgs: {len(midi_dict.tempo_msgs)}")
+        self.logger.info(f"Num pedal_msgs: 
{len(midi_dict.pedal_msgs)}") + self.logger.info( + f"Num instrument_msgs: {len(midi_dict.instrument_msgs)}" + ) + self.logger.info(f"Num note_msgs: {len(midi_dict.note_msgs)}") + self.logger.info(f"ticks_per_beat: {midi_dict.ticks_per_beat}") + self.logger.info(f"metadata: {midi_dict.metadata}") + + def test_save(self) -> None: + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + save_path = RESULTS_DATA_DIRECTORY.joinpath("arabesque.mid") + midi_dict = MidiDict.from_midi(mid_path=load_path) + midi_dict.to_midi().save(save_path) + + def test_resolve_pedal(self) -> None: + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + save_path = RESULTS_DATA_DIRECTORY.joinpath( + "arabesque_pedal_resolved.mid" + ) + midi_dict = MidiDict.from_midi(mid_path=load_path).resolve_pedal() + midi_dict.to_midi().save(save_path) + + def test_remove_redundant_pedals(self) -> None: + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + save_path = RESULTS_DATA_DIRECTORY.joinpath( + "arabesque_remove_redundant_pedals.mid" + ) + midi_dict = MidiDict.from_midi(mid_path=load_path) + self.logger.info( + f"Num pedal_msgs before remove_redundant_pedals: {len(midi_dict.pedal_msgs)}" + ) + + midi_dict_adj_resolve = ( + MidiDict.from_midi(mid_path=load_path) + .resolve_pedal() + .remove_redundant_pedals() + ) + midi_dict_resolve_adj = ( + MidiDict.from_midi(mid_path=load_path) + .remove_redundant_pedals() + .resolve_pedal() + ) + + self.logger.info( + f"Num pedal_msgs after remove_redundant_pedals: {len(midi_dict_adj_resolve.pedal_msgs)}" + ) + self.assertEqual( + len(midi_dict_adj_resolve.pedal_msgs), + len(midi_dict_resolve_adj.pedal_msgs), + ) + + for msg_1, msg_2 in zip( + midi_dict_adj_resolve.note_msgs, midi_dict_resolve_adj.note_msgs + ): + self.assertDictEqual(msg_1, msg_2) + + midi_dict_adj_resolve.to_midi().save(save_path) From bd49405fc76b6b00ac43c77dfacff5e4ffe161b0 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 18 Nov 2024 17:18:25 +0000 Subject: [PATCH 05/12] add space --- tests/test_midi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_midi.py b/tests/test_midi.py index 74a7402..82e7c89 100644 --- a/tests/test_midi.py +++ b/tests/test_midi.py @@ -45,6 +45,7 @@ def test_resolve_pedal(self) -> None: save_path = RESULTS_DATA_DIRECTORY.joinpath( "arabesque_pedal_resolved.mid" ) + midi_dict = MidiDict.from_midi(mid_path=load_path).resolve_pedal() midi_dict.to_midi().save(save_path) From fe0f484b30123e39b3f8828b9ec8d1e2e282f441 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 18 Nov 2024 18:05:30 +0000 Subject: [PATCH 06/12] update midi tests --- ariautils/midi.py | 9 +++++---- tests/test_midi.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/ariautils/midi.py b/ariautils/midi.py index 5977b25..984ef5f 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -20,6 +20,7 @@ TypeAlias, Literal, TypedDict, + cast, ) from mido.midifiles.units import tick2second @@ -213,7 +214,7 @@ def from_midi(cls, mid_path: str | Path) -> "MidiDict": return cls(**midi_to_dict(mid)) def calculate_hash(self) -> str: - msg_dict_to_hash = dict(self.get_msg_dict()) + msg_dict_to_hash = cast(dict, self.get_msg_dict()) # Remove metadata before calculating hash msg_dict_to_hash.pop("meta_msgs") @@ -330,7 +331,7 @@ def resolve_pedal(self) -> "MidiDict": return self - # TODO: Needs to be refactored and tested + # TODO: Needs to be refactored def remove_redundant_pedals(self) -> "MidiDict": """Removes redundant pedal messages from the MIDI data in 
place. @@ -790,9 +791,9 @@ def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile: # Magic sorting function def _sort_fn(msg: mido.Message) -> Tuple[int, int]: if hasattr(msg, "velocity"): - return (msg.time, msg.velocity) + return (msg.time, msg.velocity) # pyright: ignore else: - return (msg.time, 1000) + return (msg.time, 1000) # pyright: ignore # Sort and convert from abs_time -> delta_time track = sorted(track, key=_sort_fn) diff --git a/tests/test_midi.py b/tests/test_midi.py index 82e7c89..6ff28f4 100644 --- a/tests/test_midi.py +++ b/tests/test_midi.py @@ -1,4 +1,6 @@ import unittest +import tempfile +import shutil from importlib import resources from pathlib import Path @@ -40,12 +42,38 @@ def test_save(self) -> None: midi_dict = MidiDict.from_midi(mid_path=load_path) midi_dict.to_midi().save(save_path) + def test_tick_to_ms(self): + CORRECT_LAST_NOTE_ONSET_MS: Final[int] = 220140 + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + midi_dict = MidiDict.from_midi(load_path) + last_note = midi_dict.note_msgs[-1] + last_note_onset_tick = last_note["tick"] + last_note_onset_ms = midi_dict.tick_to_ms(last_note_onset_tick) + self.assertEqual(last_note_onset_ms, CORRECT_LAST_NOTE_ONSET_MS) + + def test_calculate_hash(self): + # Load two identical files with different filenames and metadata + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + midi_dict_orig = MidiDict.from_midi(load_path) + + with tempfile.NamedTemporaryFile(delete=True) as temp_file: + shutil.copy(load_path, temp_file.name) + midi_dict_temp = MidiDict.from_midi(temp_file.name) + + midi_dict_temp.meta_msgs.append({"type": "text", "data": "test"}) + midi_dict_temp.metadata["composer"] = "test" + midi_dict_temp.metadata["composer"] = "test" + midi_dict_temp.metadata["ticks_per_beat"] = -1 + + self.assertEqual( + midi_dict_orig.calculate_hash(), midi_dict_temp.calculate_hash() + ) + def test_resolve_pedal(self) -> None: load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") save_path = RESULTS_DATA_DIRECTORY.joinpath( "arabesque_pedal_resolved.mid" ) - midi_dict = MidiDict.from_midi(mid_path=load_path).resolve_pedal() midi_dict.to_midi().save(save_path) From 24217475c67a2c6f2177e6b0ec485cf12b0b40b5 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 18 Nov 2024 19:48:22 +0000 Subject: [PATCH 07/12] add abstract tokenizer class --- ariautils/midi.py | 96 ++++++------- ariautils/tokenizer/__init__.py | 246 ++++++++++++++++++++++++++++++++ ariautils/utils/__init__.py | 6 +- ariautils/utils/config.py | 6 +- tests/test_midi.py | 4 +- 5 files changed, 301 insertions(+), 57 deletions(-) create mode 100644 ariautils/tokenizer/__init__.py diff --git a/ariautils/midi.py b/ariautils/midi.py index 984ef5f..c34ac72 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -1,4 +1,4 @@ -"""Utils for data/MIDI processing.""" +"""Utils for MIDI processing.""" import re import os @@ -7,11 +7,10 @@ import unicodedata import mido +from mido.midifiles.units import tick2second from collections import defaultdict from pathlib import Path from typing import ( - List, - Dict, Any, Tuple, Final, @@ -23,7 +22,6 @@ cast, ) -from mido.midifiles.units import tick2second from ariautils.utils import load_config, load_maestro_metadata_json @@ -84,37 +82,37 @@ class NoteMessage(TypedDict): class MidiDictData(TypedDict): """Type for MidiDict attributes in dictionary form.""" - meta_msgs: List[MetaMessage] - tempo_msgs: List[TempoMessage] - pedal_msgs: List[PedalMessage] - instrument_msgs: List[InstrumentMessage] - note_msgs: 
List[NoteMessage] + meta_msgs: list[MetaMessage] + tempo_msgs: list[TempoMessage] + pedal_msgs: list[PedalMessage] + instrument_msgs: list[InstrumentMessage] + note_msgs: list[NoteMessage] ticks_per_beat: int - metadata: Dict[str, Any] + metadata: dict[str, Any] class MidiDict: """Container for MIDI data in dictionary form. Args: - meta_msgs (List[MetaMessage]): List of text or copyright MIDI meta messages. - tempo_msgs (List[TempoMessage]): List of tempo change messages. - pedal_msgs (List[PedalMessage]): List of sustain pedal messages. - instrument_msgs (List[InstrumentMessage]): List of program change messages. - note_msgs (List[NoteMessage]): List of note messages from paired note-on/off events. + meta_msgs (list[MetaMessage]): List of text or copyright MIDI meta messages. + tempo_msgs (list[TempoMessage]): List of tempo change messages. + pedal_msgs (list[PedalMessage]): List of sustain pedal messages. + instrument_msgs (list[InstrumentMessage]): List of program change messages. + note_msgs (list[NoteMessage]): List of note messages from paired note-on/off events. ticks_per_beat (int): MIDI ticks per beat. metadata (dict): Optional metadata key-value pairs (e.g., {"genre": "classical"}). """ def __init__( self, - meta_msgs: List[MetaMessage], - tempo_msgs: List[TempoMessage], - pedal_msgs: List[PedalMessage], - instrument_msgs: List[InstrumentMessage], - note_msgs: List[NoteMessage], + meta_msgs: list[MetaMessage], + tempo_msgs: list[TempoMessage], + pedal_msgs: list[PedalMessage], + instrument_msgs: list[InstrumentMessage], + note_msgs: list[NoteMessage], ticks_per_beat: int, - metadata: Dict[str, Any], + metadata: dict[str, Any], ): self.meta_msgs = meta_msgs self.tempo_msgs = tempo_msgs @@ -148,10 +146,10 @@ def __init__( self.program_to_instrument = self.get_program_to_instrument() @classmethod - def get_program_to_instrument(cls) -> Dict[int, str]: + def get_program_to_instrument(cls) -> dict[int, str]: """Return a map of MIDI program to instrument name.""" - PROGRAM_TO_INSTRUMENT: Final[Dict[int, str]] = ( + PROGRAM_TO_INSTRUMENT: Final[dict[int, str]] = ( {i: "piano" for i in range(0, 7 + 1)} | {i: "chromatic" for i in range(8, 15 + 1)} | {i: "organ" for i in range(16, 23 + 1)} @@ -235,12 +233,12 @@ def tick_to_ms(self, tick: int) -> int: ticks_per_beat=self.ticks_per_beat, ) - def _build_pedal_intervals(self) -> Dict[int, List[List[int]]]: + def _build_pedal_intervals(self) -> dict[int, list[list[int]]]: """Returns a mapping of channels to sustain pedal intervals.""" self.pedal_msgs.sort(key=lambda msg: msg["tick"]) channel_to_pedal_intervals = defaultdict(list) - pedal_status: Dict[int, int] = {} + pedal_status: dict[int, int] = {} for pedal_msg in self.pedal_msgs: tick = pedal_msg["tick"] @@ -277,7 +275,7 @@ def resolve_overlaps(self) -> "MidiDict": """ # Organize notes by channel and pitch - note_msgs_c: Dict[int, Dict[int, List[NoteMessage]]] = defaultdict( + note_msgs_c: dict[int, dict[int, list[NoteMessage]]] = defaultdict( lambda: defaultdict(list) ) for msg in self.note_msgs: @@ -343,7 +341,7 @@ def remove_redundant_pedals(self) -> "MidiDict": def _is_pedal_useful( pedal_start_tick: int, pedal_end_tick: int, - note_msgs: List[NoteMessage], + note_msgs: list[NoteMessage], ) -> bool: # This logic loops through the note_msgs that could possibly # be effected by the pedal which starts at pedal_start_tick @@ -487,7 +485,7 @@ def remove_instruments(self, config: dict) -> "MidiDict": channels_to_remove = [i for i in channels_to_remove if i != 9] # Remove unwanted messages all 
type by looping over msgs types - _msg_dict: Dict[str, List] = { + _msg_dict: dict[str, list] = { "meta_msgs": self.meta_msgs, "tempo_msgs": self.tempo_msgs, "pedal_msgs": self.pedal_msgs, @@ -513,19 +511,19 @@ def remove_instruments(self, config: dict) -> "MidiDict": def _extract_track_data( track: mido.MidiTrack, ) -> Tuple[ - List[MetaMessage], - List[TempoMessage], - List[PedalMessage], - List[InstrumentMessage], - List[NoteMessage], + list[MetaMessage], + list[TempoMessage], + list[PedalMessage], + list[InstrumentMessage], + list[NoteMessage], ]: """Converts MIDI messages into format used by MidiDict.""" - meta_msgs: List[MetaMessage] = [] - tempo_msgs: List[TempoMessage] = [] - pedal_msgs: List[PedalMessage] = [] - instrument_msgs: List[InstrumentMessage] = [] - note_msgs: List[NoteMessage] = [] + meta_msgs: list[MetaMessage] = [] + tempo_msgs: list[TempoMessage] = [] + pedal_msgs: list[PedalMessage] = [] + instrument_msgs: list[InstrumentMessage] = [] + note_msgs: list[NoteMessage] = [] last_note_on = defaultdict(list) for message in track: @@ -685,7 +683,7 @@ def midi_to_dict(mid: mido.MidiFile) -> MidiDictData: metadata_fn = get_metadata_fn( metadata_process_name=metadata_process_name ) - fn_args: Dict = metadata_process_config["args"] + fn_args: dict = metadata_process_config["args"] collected_metadata = metadata_fn(mid, midi_dict_data, **fn_args) if collected_metadata: @@ -813,7 +811,7 @@ def _sort_fn(msg: mido.Message) -> Tuple[int, int]: def get_duration_ms( start_tick: int, end_tick: int, - tempo_msgs: List[TempoMessage], + tempo_msgs: list[TempoMessage], ticks_per_beat: int, ) -> int: """Calculates elapsed time (in ms) between start_tick and end_tick.""" @@ -898,7 +896,7 @@ def to_ascii(s: str) -> str: def meta_composer_filename( mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list -) -> Dict[str, str]: +) -> dict[str, str]: file_name = Path(str(mid.filename)).stem matched_names_unique = set() for name in composer_names: @@ -915,7 +913,7 @@ def meta_composer_filename( def meta_form_filename( mid: mido.MidiFile, msg_data: MidiDictData, form_names: list -) -> Dict[str, str]: +) -> dict[str, str]: file_name = Path(str(mid.filename)).stem matched_names_unique = set() for name in form_names: @@ -932,7 +930,7 @@ def meta_form_filename( def meta_composer_metamsg( mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list -) -> Dict[str, str]: +) -> dict[str, str]: matched_names_unique = set() for msg in msg_data["meta_msgs"]: for name in composer_names: @@ -953,7 +951,7 @@ def meta_maestro_json( msg_data: MidiDictData, composer_names: list, form_names: list, -) -> Dict[str, str]: +) -> dict[str, str]: """Loads composer and form metadata from MAESTRO metadata json file. 
@@ -991,16 +989,16 @@ def meta_maestro_json( return res -def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> Dict[str, str]: +def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> dict[str, str]: return {"abs_path": str(Path(str(mid.filename)).absolute())} def get_metadata_fn( metadata_process_name: str, -) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]]: - name_to_fn: Dict[ +) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]]: + name_to_fn: dict[ str, - Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]], + Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]], ] = { "composer_filename": meta_composer_filename, "composer_metamsg": meta_composer_metamsg, @@ -1132,7 +1130,7 @@ def test_min_length( def get_test_fn( test_name: str, ) -> Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]]: - name_to_fn: Dict[ + name_to_fn: dict[ str, Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]] ] = { "max_programs": test_max_programs, diff --git a/ariautils/tokenizer/__init__.py b/ariautils/tokenizer/__init__.py new file mode 100644 index 0000000..382adda --- /dev/null +++ b/ariautils/tokenizer/__init__.py @@ -0,0 +1,246 @@ +"""Includes Tokenizers and pre-processing utilities.""" + +import functools + +from typing import ( + Any, + Final, + Callable, + TypeAlias, +) + +from ariautils.midi import MidiDict + + +SpecialToken: TypeAlias = str +Token: TypeAlias = tuple[Any, ...] | str + + +class Tokenizer: + """Abstract Tokenizer class for tokenizing midi_dict objects. + + Args: + return_tensors (bool, optional): If True, encode will return tensors. + Defaults to False. + """ + + def __init__( + self, + return_tensors: bool = False, + ): + self.name: str = "" + self.return_tensors = return_tensors # DELETE + + self.bos_tok: Final[SpecialToken] = "" + self.eos_tok: Final[SpecialToken] = "" + self.pad_tok: Final[SpecialToken] = "
<P>
" + self.unk_tok: Final[SpecialToken] = "" + self.dim_tok: Final[SpecialToken] = "" + + self.special_tokens: list[SpecialToken] = [ + self.bos_tok, + self.eos_tok, + self.pad_tok, + self.unk_tok, + self.dim_tok, + ] + + # These must be implemented in child class (abstract params) + self.vocab: tuple[Token, ...] = () + self.instruments_wd: list[str] = [] + self.instruments_nd: list[str] = [] + self.config: dict[str, Any] = {} + self.tok_to_id: dict[Token, int] = {} + self.id_to_tok: dict[int, Token] = {} + self.vocab_size: int = -1 + self.pad_id: int = -1 + + def _tokenize_midi_dict(self, midi_dict: MidiDict) -> list[Token]: + """Abstract method for tokenizing a MidiDict object into a sequence of + tokens. + + Args: + midi_dict (MidiDict): The MidiDict to tokenize. + + Returns: + list[Token]: A sequence of tokens representing the MIDI content. + """ + + raise NotImplementedError + + def tokenize(self, midi_dict: MidiDict, **kwargs: Any) -> list[Token]: + """Tokenizes a MidiDict object. + + This function should be overridden if additional transformations are + required, e.g., adding additional tokens. The default behavior is to + call tokenize_midi_dict. + + Args: + midi_dict (MidiDict): The MidiDict to tokenize. + **kwargs (Any): Additional keyword arguments passed to _tokenize_midi_dict. + + Returns: + list[Token]: A sequence of tokens representing the MIDI content. + """ + + return self._tokenize_midi_dict(midi_dict, **kwargs) + + def _detokenize_midi_dict(self, tokenized_seq: list[int]) -> MidiDict: + """Abstract method for de-tokenizing a sequence of tokens into a + MidiDict Object. + + Args: + tokenized_seq (list[int]): The sequence of tokens to detokenize. + + Returns: + MidiDict: A MidiDict reconstructed from the tokens. + """ + + raise NotImplementedError + + def detokenize(self, tokenized_seq: list[int], **kwargs: Any) -> MidiDict: + """Detokenizes a MidiDict object. + + This function should be overridden if additional are required during + detokenization. The default behavior is to call detokenize_midi_dict. + + Args: + tokenized_seq (list): The sequence of tokens to detokenize. + **kwargs (Any): Additional keyword arguments passed to detokenize_midi_dict. + + Returns: + MidiDict: A MidiDict reconstructed from the tokens. 
+ """ + + return self._detokenize_midi_dict(tokenized_seq, **kwargs) + + def export_data_aug(cls) -> list[Callable[[list[Token]], list[Token]]]: + """Export a list of implemented data augmentation functions.""" + + raise NotImplementedError + + def encode(self, unencoded_seq: list[Token]) -> list[int]: + """Converts tokenized sequence into the corresponding list of ids.""" + + def _enc_fn(tok: Token) -> int: + return self.tok_to_id.get(tok, self.tok_to_id[self.unk_tok]) + + if self.tok_to_id is None: + raise NotImplementedError("tok_to_id") + + encoded_seq = [_enc_fn(tok) for tok in unencoded_seq] + + return encoded_seq + + def decode(self, encoded_seq: list[int]) -> list[Token]: + """Converts list of ids into the corresponding list of tokens.""" + + def _dec_fn(id: int) -> Token: + return self.id_to_tok.get(id, self.unk_tok) + + if self.id_to_tok is None: + raise NotImplementedError("id_to_tok") + + decoded_seq = [_dec_fn(idx) for idx in encoded_seq] + + return decoded_seq + + @classmethod + def _find_closest_int(cls, n: int, sorted_list: list[int]) -> int: + # Selects closest integer to n from sorted_list + # Time ~ Log(n) + + if not sorted_list: + raise ValueError("List is empty") + + left, right = 0, len(sorted_list) - 1 + closest = float("inf") + + while left <= right: + mid = (left + right) // 2 + diff = abs(sorted_list[mid] - n) + + if diff < abs(closest - n): + closest = sorted_list[mid] + + if sorted_list[mid] < n: + left = mid + 1 + else: + right = mid - 1 + + return closest # type: ignore[return-value] + + def add_tokens_to_vocab(self, tokens: list[Token] | tuple[Token]) -> None: + """Utility function for safely adding extra tokens to vocab.""" + + for token in tokens: + assert token not in self.vocab + + self.vocab = self.vocab + tuple(tokens) + self.tok_to_id = {tok: idx for idx, tok in enumerate(self.vocab)} + self.id_to_tok = {v: k for k, v in self.tok_to_id.items()} + self.vocab_size = len(self.vocab) + + def export_aug_fn_concat( + self, aug_fn: Callable[[list[Token]], list[Token]] + ) -> Callable[[list[Token]], list[Token]]: + """Transforms an augmentation function for concatenated sequences. + + This is useful for augmentation functions that are only defined for + sequences which start with and end with . + + Args: + aug_fn (Callable[[list[Token]], list[Token]]): The augmentation + function to transform. + + Returns: + Callable[[list[Token]], list[Token]]: A transformed augmentation + function that can handle concatenated sequences. 
+ """ + + def _aug_fn_concat( + src: list[Token], + _aug_fn: Callable[[list[Token]], list[Token]], + pad_tok: str, + eos_tok: str, + **kwargs: Any, + ) -> list[Token]: + # Split list on '' + initial_seq_len = len(src) + src_sep = [] + prev_idx = 0 + for curr_idx, tok in enumerate(src, start=1): + if tok == eos_tok: + src_sep.append(src[prev_idx:curr_idx]) + prev_idx = curr_idx + + # Last sequence + if prev_idx != curr_idx: + src_sep.append(src[prev_idx:]) + + # Augment + src_sep = [ + _aug_fn( + _src, + **kwargs, + ) + for _src in src_sep + ] + + # Concatenate + src_aug_concat = [tok for src_aug in src_sep for tok in src_aug] + + # Pad or truncate to original sequence length as necessary + src_aug_concat = src_aug_concat[:initial_seq_len] + src_aug_concat += [pad_tok] * ( + initial_seq_len - len(src_aug_concat) + ) + + return src_aug_concat + + return functools.partial( + _aug_fn_concat, + _aug_fn=aug_fn, + pad_tok=self.pad_tok, + eos_tok=self.eos_tok, + ) diff --git a/ariautils/utils/__init__.py b/ariautils/utils/__init__.py index eb1ccad..4c51084 100644 --- a/ariautils/utils/__init__.py +++ b/ariautils/utils/__init__.py @@ -4,7 +4,7 @@ import logging from importlib import resources -from typing import Dict, Any, cast +from typing import Any, cast from .config import load_config @@ -26,14 +26,14 @@ def get_logger(name: str) -> logging.Logger: return logger -def load_maestro_metadata_json() -> Dict[str, Any]: +def load_maestro_metadata_json() -> dict[str, Any]: """Loads MAESTRO metadata json .""" with ( resources.files("ariautils.config") .joinpath("maestro_metadata.json") .open("r") as f ): - return cast(Dict[str, Any], json.load(f)) + return cast(dict[str, Any], json.load(f)) __all__ = ["load_config", "load_maestro_metadata_json", "get_logger"] diff --git a/ariautils/utils/config.py b/ariautils/utils/config.py index a4fd267..471a3f2 100644 --- a/ariautils/utils/config.py +++ b/ariautils/utils/config.py @@ -4,14 +4,14 @@ import json from importlib import resources -from typing import Dict, Any, cast +from typing import Any, cast -def load_config() -> Dict[str, Any]: +def load_config() -> dict[str, Any]: """Returns a dictionary loaded from the config.json file.""" with ( resources.files("ariautils.config") .joinpath("config.json") .open("r") as f ): - return cast(Dict[str, Any], json.load(f)) + return cast(dict[str, Any], json.load(f)) diff --git a/tests/test_midi.py b/tests/test_midi.py index 6ff28f4..5d6aa02 100644 --- a/tests/test_midi.py +++ b/tests/test_midi.py @@ -42,7 +42,7 @@ def test_save(self) -> None: midi_dict = MidiDict.from_midi(mid_path=load_path) midi_dict.to_midi().save(save_path) - def test_tick_to_ms(self): + def test_tick_to_ms(self) -> None: CORRECT_LAST_NOTE_ONSET_MS: Final[int] = 220140 load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") midi_dict = MidiDict.from_midi(load_path) @@ -51,7 +51,7 @@ def test_tick_to_ms(self): last_note_onset_ms = midi_dict.tick_to_ms(last_note_onset_tick) self.assertEqual(last_note_onset_ms, CORRECT_LAST_NOTE_ONSET_MS) - def test_calculate_hash(self): + def test_calculate_hash(self) -> None: # Load two identical files with different filenames and metadata load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") midi_dict_orig = MidiDict.from_midi(load_path) From 25eff8e2afdc50b35a2b14ebaafc5277f31d4c67 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 18 Nov 2024 19:49:54 +0000 Subject: [PATCH 08/12] fix mypy and upgrade to pep 585 --- ariautils/midi.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff 
--git a/ariautils/midi.py b/ariautils/midi.py index c34ac72..a59526e 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -12,7 +12,6 @@ from pathlib import Path from typing import ( Any, - Tuple, Final, Concatenate, Callable, @@ -510,7 +509,7 @@ def remove_instruments(self, config: dict) -> "MidiDict": # TODO: The sign has been changed. Make sure this function isn't used anywhere else def _extract_track_data( track: mido.MidiTrack, -) -> Tuple[ +) -> tuple[ list[MetaMessage], list[TempoMessage], list[PedalMessage], @@ -787,7 +786,7 @@ def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile: ) # Magic sorting function - def _sort_fn(msg: mido.Message) -> Tuple[int, int]: + def _sort_fn(msg: mido.Message) -> tuple[int, int]: if hasattr(msg, "velocity"): return (msg.time, msg.velocity) # pyright: ignore else: @@ -1016,7 +1015,7 @@ def get_metadata_fn( return fn -def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: +def test_max_programs(midi_dict: MidiDict, max: int) -> tuple[bool, int]: """Returns false if midi_dict uses more than {max} programs.""" present_programs = set( map( @@ -1031,7 +1030,7 @@ def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: return False, len(present_programs) -def test_max_instruments(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: +def test_max_instruments(midi_dict: MidiDict, max: int) -> tuple[bool, int]: present_instruments = set( map( lambda msg: midi_dict.program_to_instrument[msg["data"]], @@ -1047,7 +1046,7 @@ def test_max_instruments(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: def test_note_frequency( midi_dict: MidiDict, max_per_second: float, min_per_second: float -) -> Tuple[bool, float]: +) -> tuple[bool, float]: if not midi_dict.note_msgs: return False, 0.0 @@ -1072,7 +1071,7 @@ def test_note_frequency( def test_note_frequency_per_instrument( midi_dict: MidiDict, max_per_second: float, min_per_second: float -) -> Tuple[bool, float]: +) -> tuple[bool, float]: num_instruments = len( set( map( @@ -1110,7 +1109,7 @@ def test_note_frequency_per_instrument( def test_min_length( midi_dict: MidiDict, min_seconds: int -) -> Tuple[bool, float]: +) -> tuple[bool, float]: if not midi_dict.note_msgs: return False, 0.0 @@ -1129,9 +1128,9 @@ def test_min_length( def get_test_fn( test_name: str, -) -> Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]]: +) -> Callable[Concatenate[MidiDict, ...], tuple[bool, Any]]: name_to_fn: dict[ - str, Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]] + str, Callable[Concatenate[MidiDict, ...], tuple[bool, Any]] ] = { "max_programs": test_max_programs, "max_instruments": test_max_instruments, From 1f2476d581b24d4ee28de718025ce865fd453fcb Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 18 Nov 2024 19:50:09 +0000 Subject: [PATCH 09/12] rmv import --- ariautils/utils/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ariautils/utils/config.py b/ariautils/utils/config.py index 471a3f2..4093318 100644 --- a/ariautils/utils/config.py +++ b/ariautils/utils/config.py @@ -1,6 +1,5 @@ """Includes functionality for loading config files.""" -import os import json from importlib import resources From bdf3c8331ce35bc21b9a6a881ae05a1a5b669666 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 18 Nov 2024 19:53:23 +0000 Subject: [PATCH 10/12] fix docstring --- ariautils/tokenizer/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ariautils/tokenizer/__init__.py b/ariautils/tokenizer/__init__.py index 382adda..5172d34 100644 --- 
a/ariautils/tokenizer/__init__.py +++ b/ariautils/tokenizer/__init__.py @@ -187,7 +187,7 @@ def export_aug_fn_concat( """Transforms an augmentation function for concatenated sequences. This is useful for augmentation functions that are only defined for - sequences which start with and end with . + sequences which start with "" and end with "". Args: aug_fn (Callable[[list[Token]], list[Token]]): The augmentation @@ -205,7 +205,7 @@ def _aug_fn_concat( eos_tok: str, **kwargs: Any, ) -> list[Token]: - # Split list on '' + # Split list on "" initial_seq_len = len(src) src_sep = [] prev_idx = 0 From 149a4cb07e1a2d61a7401db19f345768b2b2e994 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 20 Nov 2024 16:33:03 +0000 Subject: [PATCH 11/12] migrate abstokenizer --- ariautils/tokenizer/__init__.py | 245 +-------- ariautils/tokenizer/_base.py | 244 +++++++++ ariautils/tokenizer/absolute.py | 848 ++++++++++++++++++++++++++++++++ 3 files changed, 1094 insertions(+), 243 deletions(-) create mode 100644 ariautils/tokenizer/_base.py create mode 100644 ariautils/tokenizer/absolute.py diff --git a/ariautils/tokenizer/__init__.py b/ariautils/tokenizer/__init__.py index 5172d34..792d910 100644 --- a/ariautils/tokenizer/__init__.py +++ b/ariautils/tokenizer/__init__.py @@ -1,246 +1,5 @@ """Includes Tokenizers and pre-processing utilities.""" -import functools +from ariautils.tokenizer._base import Tokenizer -from typing import ( - Any, - Final, - Callable, - TypeAlias, -) - -from ariautils.midi import MidiDict - - -SpecialToken: TypeAlias = str -Token: TypeAlias = tuple[Any, ...] | str - - -class Tokenizer: - """Abstract Tokenizer class for tokenizing midi_dict objects. - - Args: - return_tensors (bool, optional): If True, encode will return tensors. - Defaults to False. - """ - - def __init__( - self, - return_tensors: bool = False, - ): - self.name: str = "" - self.return_tensors = return_tensors # DELETE - - self.bos_tok: Final[SpecialToken] = "" - self.eos_tok: Final[SpecialToken] = "" - self.pad_tok: Final[SpecialToken] = "
<P>
" - self.unk_tok: Final[SpecialToken] = "" - self.dim_tok: Final[SpecialToken] = "" - - self.special_tokens: list[SpecialToken] = [ - self.bos_tok, - self.eos_tok, - self.pad_tok, - self.unk_tok, - self.dim_tok, - ] - - # These must be implemented in child class (abstract params) - self.vocab: tuple[Token, ...] = () - self.instruments_wd: list[str] = [] - self.instruments_nd: list[str] = [] - self.config: dict[str, Any] = {} - self.tok_to_id: dict[Token, int] = {} - self.id_to_tok: dict[int, Token] = {} - self.vocab_size: int = -1 - self.pad_id: int = -1 - - def _tokenize_midi_dict(self, midi_dict: MidiDict) -> list[Token]: - """Abstract method for tokenizing a MidiDict object into a sequence of - tokens. - - Args: - midi_dict (MidiDict): The MidiDict to tokenize. - - Returns: - list[Token]: A sequence of tokens representing the MIDI content. - """ - - raise NotImplementedError - - def tokenize(self, midi_dict: MidiDict, **kwargs: Any) -> list[Token]: - """Tokenizes a MidiDict object. - - This function should be overridden if additional transformations are - required, e.g., adding additional tokens. The default behavior is to - call tokenize_midi_dict. - - Args: - midi_dict (MidiDict): The MidiDict to tokenize. - **kwargs (Any): Additional keyword arguments passed to _tokenize_midi_dict. - - Returns: - list[Token]: A sequence of tokens representing the MIDI content. - """ - - return self._tokenize_midi_dict(midi_dict, **kwargs) - - def _detokenize_midi_dict(self, tokenized_seq: list[int]) -> MidiDict: - """Abstract method for de-tokenizing a sequence of tokens into a - MidiDict Object. - - Args: - tokenized_seq (list[int]): The sequence of tokens to detokenize. - - Returns: - MidiDict: A MidiDict reconstructed from the tokens. - """ - - raise NotImplementedError - - def detokenize(self, tokenized_seq: list[int], **kwargs: Any) -> MidiDict: - """Detokenizes a MidiDict object. - - This function should be overridden if additional are required during - detokenization. The default behavior is to call detokenize_midi_dict. - - Args: - tokenized_seq (list): The sequence of tokens to detokenize. - **kwargs (Any): Additional keyword arguments passed to detokenize_midi_dict. - - Returns: - MidiDict: A MidiDict reconstructed from the tokens. 
- """ - - return self._detokenize_midi_dict(tokenized_seq, **kwargs) - - def export_data_aug(cls) -> list[Callable[[list[Token]], list[Token]]]: - """Export a list of implemented data augmentation functions.""" - - raise NotImplementedError - - def encode(self, unencoded_seq: list[Token]) -> list[int]: - """Converts tokenized sequence into the corresponding list of ids.""" - - def _enc_fn(tok: Token) -> int: - return self.tok_to_id.get(tok, self.tok_to_id[self.unk_tok]) - - if self.tok_to_id is None: - raise NotImplementedError("tok_to_id") - - encoded_seq = [_enc_fn(tok) for tok in unencoded_seq] - - return encoded_seq - - def decode(self, encoded_seq: list[int]) -> list[Token]: - """Converts list of ids into the corresponding list of tokens.""" - - def _dec_fn(id: int) -> Token: - return self.id_to_tok.get(id, self.unk_tok) - - if self.id_to_tok is None: - raise NotImplementedError("id_to_tok") - - decoded_seq = [_dec_fn(idx) for idx in encoded_seq] - - return decoded_seq - - @classmethod - def _find_closest_int(cls, n: int, sorted_list: list[int]) -> int: - # Selects closest integer to n from sorted_list - # Time ~ Log(n) - - if not sorted_list: - raise ValueError("List is empty") - - left, right = 0, len(sorted_list) - 1 - closest = float("inf") - - while left <= right: - mid = (left + right) // 2 - diff = abs(sorted_list[mid] - n) - - if diff < abs(closest - n): - closest = sorted_list[mid] - - if sorted_list[mid] < n: - left = mid + 1 - else: - right = mid - 1 - - return closest # type: ignore[return-value] - - def add_tokens_to_vocab(self, tokens: list[Token] | tuple[Token]) -> None: - """Utility function for safely adding extra tokens to vocab.""" - - for token in tokens: - assert token not in self.vocab - - self.vocab = self.vocab + tuple(tokens) - self.tok_to_id = {tok: idx for idx, tok in enumerate(self.vocab)} - self.id_to_tok = {v: k for k, v in self.tok_to_id.items()} - self.vocab_size = len(self.vocab) - - def export_aug_fn_concat( - self, aug_fn: Callable[[list[Token]], list[Token]] - ) -> Callable[[list[Token]], list[Token]]: - """Transforms an augmentation function for concatenated sequences. - - This is useful for augmentation functions that are only defined for - sequences which start with "" and end with "". - - Args: - aug_fn (Callable[[list[Token]], list[Token]]): The augmentation - function to transform. - - Returns: - Callable[[list[Token]], list[Token]]: A transformed augmentation - function that can handle concatenated sequences. 
- """ - - def _aug_fn_concat( - src: list[Token], - _aug_fn: Callable[[list[Token]], list[Token]], - pad_tok: str, - eos_tok: str, - **kwargs: Any, - ) -> list[Token]: - # Split list on "" - initial_seq_len = len(src) - src_sep = [] - prev_idx = 0 - for curr_idx, tok in enumerate(src, start=1): - if tok == eos_tok: - src_sep.append(src[prev_idx:curr_idx]) - prev_idx = curr_idx - - # Last sequence - if prev_idx != curr_idx: - src_sep.append(src[prev_idx:]) - - # Augment - src_sep = [ - _aug_fn( - _src, - **kwargs, - ) - for _src in src_sep - ] - - # Concatenate - src_aug_concat = [tok for src_aug in src_sep for tok in src_aug] - - # Pad or truncate to original sequence length as necessary - src_aug_concat = src_aug_concat[:initial_seq_len] - src_aug_concat += [pad_tok] * ( - initial_seq_len - len(src_aug_concat) - ) - - return src_aug_concat - - return functools.partial( - _aug_fn_concat, - _aug_fn=aug_fn, - pad_tok=self.pad_tok, - eos_tok=self.eos_tok, - ) +__all__ = ["Tokenizer"] diff --git a/ariautils/tokenizer/_base.py b/ariautils/tokenizer/_base.py new file mode 100644 index 0000000..35e51ab --- /dev/null +++ b/ariautils/tokenizer/_base.py @@ -0,0 +1,244 @@ +"""Contains abstract tokenizer class.""" + +import functools + +from typing import ( + Any, + Final, + Callable, + TypeAlias, +) + +from ariautils.midi import MidiDict + + +SpecialToken: TypeAlias = str +Token: TypeAlias = tuple[Any, ...] | str + + +class Tokenizer: + """Abstract Tokenizer class for tokenizing MidiDict objects. + + Args: + return_tensors (bool, optional): If True, encode will return tensors. + Defaults to False. + """ + + def __init__( + self, + ) -> None: + self.name: str = "" + + self.bos_tok: Final[SpecialToken] = "" + self.eos_tok: Final[SpecialToken] = "" + self.pad_tok: Final[SpecialToken] = "
<P>
" + self.unk_tok: Final[SpecialToken] = "" + self.dim_tok: Final[SpecialToken] = "" + + self.special_tokens: list[SpecialToken] = [ + self.bos_tok, + self.eos_tok, + self.pad_tok, + self.unk_tok, + self.dim_tok, + ] + + # These must be implemented in child class (abstract params) + self.config: dict[str, Any] = {} + self.vocab: tuple[Token, ...] = () + self.instruments_wd: list[str] = [] + self.instruments_nd: list[str] = [] + self.tok_to_id: dict[Token, int] = {} + self.id_to_tok: dict[int, Token] = {} + self.vocab_size: int = -1 + self.pad_id: int = -1 + + def _tokenize_midi_dict(self, midi_dict: MidiDict) -> list[Token]: + """Abstract method for tokenizing a MidiDict object into a sequence of + tokens. + + Args: + midi_dict (MidiDict): The MidiDict to tokenize. + + Returns: + list[Token]: A sequence of tokens representing the MIDI content. + """ + + raise NotImplementedError + + def tokenize(self, midi_dict: MidiDict, **kwargs: Any) -> list[Token]: + """Tokenizes a MidiDict object. + + This function should be overridden if additional transformations are + required, e.g., adding additional tokens. The default behavior is to + call tokenize_midi_dict. + + Args: + midi_dict (MidiDict): The MidiDict to tokenize. + **kwargs (Any): Additional keyword arguments passed to _tokenize_midi_dict. + + Returns: + list[Token]: A sequence of tokens representing the MIDI content. + """ + + return self._tokenize_midi_dict(midi_dict, **kwargs) + + def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict: + """Abstract method for de-tokenizing a sequence of tokens into a + MidiDict Object. + + Args: + tokenized_seq (list[int]): The sequence of tokens to detokenize. + + Returns: + MidiDict: A MidiDict reconstructed from the tokens. + """ + + raise NotImplementedError + + def detokenize(self, tokenized_seq: list[Token], **kwargs: Any) -> MidiDict: + """Detokenizes a MidiDict object. + + This function should be overridden if additional are required during + detokenization. The default behavior is to call detokenize_midi_dict. + + Args: + tokenized_seq (list): The sequence of tokens to detokenize. + **kwargs (Any): Additional keyword arguments passed to detokenize_midi_dict. + + Returns: + MidiDict: A MidiDict reconstructed from the tokens. 
+ """ + + return self._detokenize_midi_dict(tokenized_seq, **kwargs) + + def export_data_aug(cls) -> list[Callable[[list[Token]], list[Token]]]: + """Export a list of implemented data augmentation functions.""" + + raise NotImplementedError + + def encode(self, unencoded_seq: list[Token]) -> list[int]: + """Converts tokenized sequence into the corresponding list of ids.""" + + def _enc_fn(tok: Token) -> int: + return self.tok_to_id.get(tok, self.tok_to_id[self.unk_tok]) + + if self.tok_to_id is None: + raise NotImplementedError("tok_to_id") + + encoded_seq = [_enc_fn(tok) for tok in unencoded_seq] + + return encoded_seq + + def decode(self, encoded_seq: list[int]) -> list[Token]: + """Converts list of ids into the corresponding list of tokens.""" + + def _dec_fn(id: int) -> Token: + return self.id_to_tok.get(id, self.unk_tok) + + if self.id_to_tok is None: + raise NotImplementedError("id_to_tok") + + decoded_seq = [_dec_fn(idx) for idx in encoded_seq] + + return decoded_seq + + @classmethod + def _find_closest_int(cls, n: int, sorted_list: list[int]) -> int: + # Selects closest integer to n from sorted_list + # Time ~ Log(n) + + if not sorted_list: + raise ValueError("List is empty") + + left, right = 0, len(sorted_list) - 1 + closest = float("inf") + + while left <= right: + mid = (left + right) // 2 + diff = abs(sorted_list[mid] - n) + + if diff < abs(closest - n): + closest = sorted_list[mid] + + if sorted_list[mid] < n: + left = mid + 1 + else: + right = mid - 1 + + return closest # type: ignore[return-value] + + def add_tokens_to_vocab(self, tokens: list[Token] | tuple[Token]) -> None: + """Utility function for safely adding extra tokens to vocab.""" + + for token in tokens: + assert token not in self.vocab + + self.vocab = self.vocab + tuple(tokens) + self.tok_to_id = {tok: idx for idx, tok in enumerate(self.vocab)} + self.id_to_tok = {v: k for k, v in self.tok_to_id.items()} + self.vocab_size = len(self.vocab) + + def export_aug_fn_concat( + self, aug_fn: Callable[[list[Token]], list[Token]] + ) -> Callable[[list[Token]], list[Token]]: + """Transforms an augmentation function for concatenated sequences. + + This is useful for augmentation functions that are only defined for + sequences which start with "" and end with "". + + Args: + aug_fn (Callable[[list[Token]], list[Token]]): The augmentation + function to transform. + + Returns: + Callable[[list[Token]], list[Token]]: A transformed augmentation + function that can handle concatenated sequences. 
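+
+ Example (illustrative sketch only; it assumes the eos/pad special
+ tokens are "<E>"/"<P>", that `tokenizer` is any Tokenizer instance,
+ and uses a toy augmentation function and placeholder tokens rather
+ than a real exported augmentation):
+
+ # toy augmentation: drop the first token of each "<E>"-terminated chunk
+ drop_first = lambda seq: seq[1:]
+ concat_fn = tokenizer.export_aug_fn_concat(drop_first)
+ concat_fn(["a", "b", "<E>", "c", "d", "<E>"])
+ # -> ["b", "<E>", "d", "<E>", "<P>", "<P>"]
+ # Each chunk is augmented separately, concatenated back together, and
+ # then padded with "<P>" to the original sequence length.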
+ """ + + def _aug_fn_concat( + src: list[Token], + _aug_fn: Callable[[list[Token]], list[Token]], + pad_tok: str, + eos_tok: str, + **kwargs: Any, + ) -> list[Token]: + # Split list on "" + initial_seq_len = len(src) + src_sep = [] + prev_idx = 0 + for curr_idx, tok in enumerate(src, start=1): + if tok == eos_tok: + src_sep.append(src[prev_idx:curr_idx]) + prev_idx = curr_idx + + # Last sequence + if prev_idx != curr_idx: + src_sep.append(src[prev_idx:]) + + # Augment + src_sep = [ + _aug_fn( + _src, + **kwargs, + ) + for _src in src_sep + ] + + # Concatenate + src_aug_concat = [tok for src_aug in src_sep for tok in src_aug] + + # Pad or truncate to original sequence length as necessary + src_aug_concat = src_aug_concat[:initial_seq_len] + src_aug_concat += [pad_tok] * ( + initial_seq_len - len(src_aug_concat) + ) + + return src_aug_concat + + return functools.partial( + _aug_fn_concat, + _aug_fn=aug_fn, + pad_tok=self.pad_tok, + eos_tok=self.eos_tok, + ) diff --git a/ariautils/tokenizer/absolute.py b/ariautils/tokenizer/absolute.py new file mode 100644 index 0000000..6ac465b --- /dev/null +++ b/ariautils/tokenizer/absolute.py @@ -0,0 +1,848 @@ +"""Contains MIDI tokenizer with absolute onset timings.""" + +import functools +import itertools +import random +import copy + +from collections import defaultdict +from typing import Final, Callable + +from ariautils.midi import ( + MidiDict, + MetaMessage, + TempoMessage, + PedalMessage, + InstrumentMessage, + NoteMessage, + get_duration_ms, +) +from ariautils.utils import load_config, get_logger +from ariautils.tokenizer._base import Tokenizer, Token + + +logger = get_logger(__package__) + + +# TODO: +# - Add asserts to the tokenization / detokenization for user error + + +class AbsTokenizer(Tokenizer): + """MidiDict tokenizer implemented with absolute onset timings. + + The tokenizer processes MIDI files in 5000ms segments, with each segment separated by + a special token. Within each segment, note timings are represented relative to the + segment start. + + Tokenization Schema: + For non-percussion instruments: + - Each note is represented by three consecutive tokens: + 1. [instrument, pitch, velocity]: Instrument class, MIDI pitch, and velocity + 2. [onset]: Absolute time in milliseconds from segment start + 3. [duration]: Note duration in milliseconds + + For percussion instruments: + - Each note is represented by two consecutive tokens: + 1. [drum, note_number]: Percussion instrument and MIDI note number + 2. [onset]: Absolute time in milliseconds from segment start + + Notes: + - Notes are ordered according to onset time + - Sustain pedals effects are incorporated directly into note durations + - Various configuration settings effecting instrument processing, + timing resolution, and quantization levels can be adjusted the + config.json at 'tokenizer.abs'. 
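+
+ Example (illustrative sketch; the exact onset/duration/velocity values
+ depend on the quantization settings configured under 'tokenizer.abs',
+ and "<S>", "<E>" and "<T>" denote the assumed start, end and
+ segment-boundary special tokens):
+
+ [("prefix", "instrument", "piano"), ("prefix", "instrument", "drum"), "<S>",
+ ("piano", 60, 75), ("onset", 0), ("dur", 500),
+ ("drum", 38), ("onset", 250),
+ "<T>", # the next note falls in the following 5000ms segment
+ ("piano", 64, 75), ("onset", 120), ("dur", 500),
+ "<E>"]
+
+ A typical round trip (with a hypothetical input path) looks like:
+
+ tokenizer = AbsTokenizer()
+ midi_dict = MidiDict.from_midi("example.mid") # hypothetical path
+ seq = tokenizer.tokenize(midi_dict)
+ tokenizer.detokenize(seq).to_midi().save("example_roundtrip.mid")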
+ """ + + def __init__(self) -> None: # Not sure why this is required by + super().__init__() + self.config = load_config()["tokenizer"]["abs"] + self.name = "abs" + + # Calculate time quantizations (in ms) + self.abs_time_step: int = self.config["abs_time_step_ms"] + self.max_dur: int = self.config["max_dur_ms"] + self.time_step: int = self.config["time_step_ms"] + + self.dur_time_quantizations = [ + self.time_step * i + for i in range((self.max_dur // self.time_step) + 1) + ] + self.onset_time_quantizations = [ + self.time_step * i for i in range((self.max_dur // self.time_step)) + ] + + # Calculate velocity quantizations + self.velocity_step: int = self.config["velocity_quantization"]["step"] + self.velocity_quantizations = [ + i * self.velocity_step + for i in range(int(127 / self.velocity_step) + 1) + ] + self.max_velocity = self.velocity_quantizations[-1] + + # _nd = no drum; _wd = with drum + self.instruments_nd = [ + k + for k, v in self.config["ignore_instruments"].items() + if v is False + ] + self.instruments_wd = self.instruments_nd + ["drum"] + + # Prefix tokens + self.prefix_tokens: list[Token] = [ + ("prefix", "instrument", x) for x in self.instruments_wd + ] + self.composer_names: list[Token] = self.config["composer_names"] + self.form_names: list[str] = self.config["form_names"] + self.genre_names: list[str] = self.config["genre_names"] + self.prefix_tokens += [ + ("prefix", "composer", x) for x in self.composer_names + ] + self.prefix_tokens += [("prefix", "form", x) for x in self.form_names] + self.prefix_tokens += [("prefix", "genre", x) for x in self.genre_names] + + # Build vocab + self.time_tok = "" + self.onset_tokens: list[Token] = [ + ("onset", i) for i in self.onset_time_quantizations + ] + self.dur_tokens: list[Token] = [ + ("dur", i) for i in self.dur_time_quantizations + ] + self.drum_tokens: list[Token] = [("drum", i) for i in range(35, 82)] + + self.note_tokens: list[Token] = list( + itertools.product( + self.instruments_nd, + [i for i in range(128)], + self.velocity_quantizations, + ) + ) + + self.special_tokens.append(self.time_tok) + self.add_tokens_to_vocab( + self.special_tokens + + self.prefix_tokens + + self.note_tokens + + self.drum_tokens + + self.dur_tokens + + self.onset_tokens + ) + self.pad_id = self.tok_to_id[self.pad_tok] + + def export_data_aug(self) -> list[Callable[[list[Token]], list[Token]]]: + return [ + self.export_tempo_aug(tempo_aug_range=0.2, mixup=True), + self.export_pitch_aug(5), + self.export_velocity_aug(1), + ] + + def _quantize_dur(self, time: int) -> int: + # This function will return values res >= 0 (inc. 0) + return self._find_closest_int(time, self.dur_time_quantizations) + + def _quantize_onset(self, time: int) -> int: + # This function will return values res >= 0 (inc. 
0) + return self._find_closest_int(time, self.onset_time_quantizations) + + def _quantize_velocity(self, velocity: int) -> int: + # This function will return values in the range 0 < res =< 127 + velocity_quantized = self._find_closest_int( + velocity, self.velocity_quantizations + ) + + if velocity_quantized == 0 and velocity != 0: + return self.velocity_step + else: + return velocity_quantized + + def _format( + self, prefix: list[Token], unformatted_seq: list[Token] + ) -> list[Token]: + # If unformatted_seq is longer than 150 tokens insert diminish tok + idx = -100 + random.randint(-10, 10) + if len(unformatted_seq) > 150: + if ( + unformatted_seq[idx][0] == "onset" + ): # Don't want: note, , onset, due + unformatted_seq.insert(idx - 1, self.dim_tok) + elif ( + unformatted_seq[idx][0] == "dur" + ): # Don't want: note, onset, , dur + unformatted_seq.insert(idx - 2, self.dim_tok) + else: + unformatted_seq.insert(idx, self.dim_tok) + + res = prefix + [self.bos_tok] + unformatted_seq + [self.eos_tok] + + return res + + def calc_length_ms(self, seq: list[Token], onset: bool = False) -> int: + """Calculates time (ms) end of sequence to the end of the last note. If + onset=True, then it will return the onset time of the last note instead + """ + + # Find the index of the last onset or dur token + seq = copy.deepcopy(seq) + for _idx in range(len(seq) - 1, -1, -1): + tok = seq[_idx] + if type(tok) is tuple and tok[0] in {"onset", "dur"}: + break + else: + seq.pop() + + time_offset_ms = seq.count(self.time_tok) * self.abs_time_step + idx = len(seq) - 1 + for tok in seq[::-1]: + if type(tok) is tuple and tok[0] == "dur": + assert seq[idx][0] == "dur", "Expected duration token" + assert seq[idx - 1][0] == "onset", "Expect onset token" + + onset_ms = seq[idx - 1][1] + duration_ms = seq[idx][1] + assert isinstance(onset_ms, int), "Expected int" + assert isinstance(duration_ms, int), "Expected int" + + if onset is False: + return time_offset_ms + onset_ms + duration_ms + elif onset is True: + return time_offset_ms + onset_ms # Ignore dur + + idx -= 1 + + # If it gets to this point, an error has occurred + raise Exception("Invalid sequence format") + + def truncate_by_time( + self, tokenized_seq: list[Token], trunc_time_ms: int + ) -> list[Token]: + """Truncates notes with onset_ms > trunc_time_ms.""" + time_offset_ms = 0 + for idx, tok in enumerate(tokenized_seq): + if tok == self.time_tok: + time_offset_ms += self.abs_time_step + elif type(tok) is tuple and tok[0] == "onset": + if time_offset_ms + tok[1] > trunc_time_ms: + return tokenized_seq[: idx - 1] + + return tokenized_seq + + def _tokenize_midi_dict( + self, midi_dict: MidiDict, remove_preceding_silence: bool = True + ) -> list[Token]: + ticks_per_beat = midi_dict.ticks_per_beat + midi_dict.remove_instruments(self.config["ignore_instruments"]) + + if len(midi_dict.note_msgs) == 0: + raise Exception("note_msgs is empty after ignoring instruments") + + channel_to_pedal_intervals = midi_dict._build_pedal_intervals() + + channels_used = {msg["channel"] for msg in midi_dict.note_msgs} + + channel_to_instrument = { + msg["channel"]: midi_dict.program_to_instrument[msg["data"]] + for msg in midi_dict.instrument_msgs + if msg["channel"] != 9 # Exclude drums + } + # If non-drum channel is missing from instrument_msgs, default to piano + for c in channels_used: + if channel_to_instrument.get(c) is None and c != 9: + channel_to_instrument[c] = "piano" + + # Calculate prefix + prefix: list[Token] = [ + ("prefix", "instrument", x) + for x in 
set(channel_to_instrument.values()) + ] + if 9 in channels_used: + prefix.append(("prefix", "instrument", "drum")) + composer = midi_dict.metadata.get("composer") + if composer and (composer in self.composer_names): + prefix.insert(0, ("prefix", "composer", composer)) + form = midi_dict.metadata.get("form") + if form and (form in self.form_names): + prefix.insert(0, ("prefix", "form", form)) + genre = midi_dict.metadata.get("genre") + if genre and (genre in self.genre_names): + prefix.insert(0, ("prefix", "genre", genre)) + random.shuffle(prefix) + + tokenized_seq: list[Token] = [] + + if remove_preceding_silence is False: + initial_onset_tick = 0 + else: + initial_onset_tick = midi_dict.note_msgs[0]["data"]["start"] + + curr_time_since_onset = 0 + for _, msg in enumerate(midi_dict.note_msgs): + # Extract msg data + _channel = msg["channel"] + _pitch = msg["data"]["pitch"] + _velocity = msg["data"]["velocity"] + _start_tick = msg["data"]["start"] + _end_tick = msg["data"]["end"] + + # Calculate time data + prev_time_since_onset = curr_time_since_onset + curr_time_since_onset = get_duration_ms( + start_tick=initial_onset_tick, + end_tick=_start_tick, + tempo_msgs=midi_dict.tempo_msgs, + ticks_per_beat=ticks_per_beat, + ) + + # Add abs time token if necessary + time_toks_to_append = ( + curr_time_since_onset // self.abs_time_step + ) - (prev_time_since_onset // self.abs_time_step) + if time_toks_to_append > 0: + for _ in range(time_toks_to_append): + tokenized_seq.append(self.time_tok) + + # Special case instrument is a drum. This occurs exclusively when + # MIDI channel is 9 when 0 indexing + if _channel == 9: + _note_onset = self._quantize_onset( + curr_time_since_onset % self.abs_time_step + ) + tokenized_seq.append(("drum", _pitch)) + tokenized_seq.append(("onset", _note_onset)) + + else: # Non drum case (i.e. 
an instrument note) + _instrument = channel_to_instrument[_channel] + + # Update _end_tick if affected by pedal + for pedal_interval in channel_to_pedal_intervals[_channel]: + pedal_start, pedal_end = ( + pedal_interval[0], + pedal_interval[1], + ) + if pedal_start < _end_tick < pedal_end: + _end_tick = pedal_end + break + + _note_duration = get_duration_ms( + start_tick=_start_tick, + end_tick=_end_tick, + tempo_msgs=midi_dict.tempo_msgs, + ticks_per_beat=ticks_per_beat, + ) + + # Quantize + _velocity = self._quantize_velocity(_velocity) + _note_onset = self._quantize_onset( + curr_time_since_onset % self.abs_time_step + ) + _note_duration = self._quantize_dur(_note_duration) + if _note_duration == 0: + _note_duration = self.time_step + + tokenized_seq.append((_instrument, _pitch, _velocity)) + tokenized_seq.append(("onset", _note_onset)) + tokenized_seq.append(("dur", _note_duration)) + + return self._format( + prefix=prefix, + unformatted_seq=tokenized_seq, + ) + + def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict: + # NOTE: These values chosen so that 1000 ticks = 1000ms, allowing us to + # skip converting between ticks and ms + instrument_programs = self.config["instrument_programs"] + TICKS_PER_BEAT: Final[int] = 500 + TEMPO: Final[int] = 500000 + + # Set message tempos + tempo_msgs: list[TempoMessage] = [ + {"type": "tempo", "data": TEMPO, "tick": 0} + ] + meta_msgs: list[MetaMessage] = [] + pedal_msgs: list[PedalMessage] = [] + instrument_msgs: list[InstrumentMessage] = [] + + instrument_to_channel: dict[str, int] = {} + + # Add non-drum instrument_msgs, breaks at first note token + channel_idx = 0 + curr_tick = 0 + for idx, tok in enumerate(tokenized_seq): + if channel_idx == 9: # Skip channel reserved for drums + channel_idx += 1 + + if tok in self.special_tokens: + if tok == self.time_tok: + curr_tick += self.abs_time_step + continue + elif ( + tok[0] == "prefix" + and tok[1] == "instrument" + and tok[2] in self.instruments_wd + ): + # Process instrument prefix tokens + if tok[2] in instrument_to_channel.keys(): + logger.warning(f"Duplicate prefix {tok[2]}") + continue + elif tok[2] == "drum": + instrument_msgs.append( + { + "type": "instrument", + "data": 0, + "tick": 0, + "channel": 9, + } + ) + instrument_to_channel["drum"] = 9 + else: + instrument_msgs.append( + { + "type": "instrument", + "data": instrument_programs[tok[2]], + "tick": 0, + "channel": channel_idx, + } + ) + instrument_to_channel[tok[2]] = channel_idx + channel_idx += 1 + elif tok in self.prefix_tokens: + continue + else: + # Note, wait, or duration token + start = idx + break + + # Note messages + note_msgs: list[NoteMessage] = [] + for tok_1, tok_2, tok_3 in zip( + tokenized_seq[start:], + tokenized_seq[start + 1 :], + tokenized_seq[start + 2 :], + ): + if tok_1 in self.special_tokens: + _tok_type_1 = "special" + else: + _tok_type_1 = tok_1[0] + if tok_2 in self.special_tokens: + _tok_type_2 = "special" + else: + _tok_type_2 = tok_2[0] + if tok_3 in self.special_tokens: + _tok_type_3 = "special" + else: + _tok_type_3 = tok_3[0] + + if tok_1 == self.time_tok: + curr_tick += self.abs_time_step + + elif ( + _tok_type_1 == "special" + or _tok_type_1 == "prefix" + or _tok_type_1 == "onset" + or _tok_type_1 == "dur" + ): + continue + elif _tok_type_1 == "drum" and _tok_type_2 == "onset": + assert isinstance( + tok_2[1], int + ), f"Expected int for onset, got {tok_2[1]}" + assert isinstance( + tok_1[1], int + ), f"Expected int for pitch, got {tok_1[1]}" + + _start_tick: int = curr_tick + 
tok_2[1] + _end_tick: int = _start_tick + self.time_step + _pitch: int = tok_1[1] + _channel: int = instrument_to_channel[tok_1[0]] + _velocity: int = self.config["drum_velocity"] + + if _channel is None: + logger.warning( + "Tried to decode note message for unexpected instrument" + ) + else: + note_msgs.append( + { + "type": "note", + "data": { + "pitch": _pitch, + "start": _start_tick, + "end": _end_tick, + "velocity": _velocity, + }, + "tick": _start_tick, + "channel": _channel, + } + ) + + elif ( + _tok_type_1 in self.instruments_nd + and _tok_type_2 == "onset" + and _tok_type_3 == "dur" + ): + assert isinstance( + tok_1[1], int + ), f"Expected int for pitch, got {tok_1[1]}" + assert isinstance( + tok_1[2], int + ), f"Expected int for velocity, got {tok_1[2]}" + assert isinstance( + tok_2[1], int + ), f"Expected int for onset, got {tok_2[1]}" + assert isinstance( + tok_3[1], int + ), f"Expected int for duration, got {tok_3[1]}" + + _pitch = tok_1[1] + _channel = instrument_to_channel[tok_1[0]] + _velocity = tok_1[2] + _start_tick = curr_tick + tok_2[1] + _end_tick = _start_tick + tok_3[1] + + if _channel is None: + logger.warning( + "Tried to decode note message for unexpected instrument" + ) + else: + note_msgs.append( + { + "type": "note", + "data": { + "pitch": _pitch, + "start": _start_tick, + "end": _end_tick, + "velocity": _velocity, + }, + "tick": _start_tick, + "channel": _channel, + } + ) + else: + logger.warning( + f"Unexpected token sequence: {tok_1}, {tok_2}, {tok_3}" + ) + + return MidiDict( + meta_msgs=meta_msgs, + tempo_msgs=tempo_msgs, + pedal_msgs=pedal_msgs, + instrument_msgs=instrument_msgs, + note_msgs=note_msgs, + ticks_per_beat=TICKS_PER_BEAT, + metadata={}, + ) + + def export_pitch_aug( + self, aug_range: int + ) -> Callable[[list[Token]], list[Token]]: + """Exports a function that augments the pitch of all note tokens. + + Notes which fall out of the range (0, 127) will be replaced + with the unknown token ''. + + Args: + aug_range (int): Returned function will randomly augment the pitch + from a value in the range (-aug_range, aug_range). + + Returns: + Callable[[list[Token]], list[Token]]: Exported function. 
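+
+ Example (illustrative sketch; assumes an AbsTokenizer instance named
+ `tokenizer`, the "<E>" end-of-sequence token, and pins the shift via
+ the optional pitch_aug keyword instead of sampling it randomly):
+
+ aug_fn = tokenizer.export_pitch_aug(aug_range=2)
+ seq = [("piano", 60, 75), ("onset", 0), ("dur", 100), "<E>"]
+ aug_fn(seq, pitch_aug=1) # illustrative values only
+ # -> [("piano", 61, 75), ("onset", 0), ("dur", 100), "<E>"]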
+ """ + + def pitch_aug_seq( + src: list[Token], + unk_tok: str, + _aug_range: int, + pitch_aug: int | None = None, + ) -> list[Token]: + def pitch_aug_tok(tok: Token, _pitch_aug: int) -> Token: + if isinstance(tok, str): # Stand in for SpecialToken + _tok_type = "special" + else: + _tok_type = tok[0] + + if ( + _tok_type == "special" + or _tok_type == "prefix" + or _tok_type == "dur" + or _tok_type == "drum" + or _tok_type == "onset" + ): + # Return without changing + return tok + else: + # Return augmented tok + assert ( + isinstance(tok, tuple) and len(tok) == 3 + ), f"Invalid note token" + (_instrument, _pitch, _velocity) = tok + + assert isinstance( + _pitch, int + ), f"Expected int for pitch, got {_pitch}" + assert isinstance( + _velocity, int + ), f"Expected int for velocity, got {_velocity}" + + if 0 <= _pitch + _pitch_aug <= 127: + return (_instrument, _pitch + _pitch_aug, _velocity) + else: + return unk_tok + + if not pitch_aug: + pitch_aug = random.randint(-_aug_range, _aug_range) + + return [pitch_aug_tok(x, pitch_aug) for x in src] + + # See functools.partial docs + return self.export_aug_fn_concat( + functools.partial( + pitch_aug_seq, + unk_tok=self.unk_tok, + _aug_range=aug_range, + ) + ) + + def export_velocity_aug( + self, aug_steps_range: int + ) -> Callable[[list[Token]], list[Token]]: + """Exports a function which augments the velocity of all pitch tokens. + + Velocity values are clipped so that they don't fall outside of the + valid range. + + Args: + aug_steps_range (int): Returned function will randomly augment + velocity in the range aug_steps_range * (-self.velocity_step, + self.velocity step). + + Returns: + Callable[[list[Token]], list[Token]]: Exported function. + """ + + def velocity_aug_seq( + src: list[Token], + velocity_step: int, + max_velocity: int, + _aug_steps_range: int, + velocity_aug: int | None = None, + ) -> list[Token]: + def velocity_aug_tok(tok: Token, _velocity_aug: int) -> Token: + if isinstance(tok, str): # Stand in for SpecialToken + _tok_type = "special" + else: + _tok_type = tok[0] + + if ( + _tok_type == "special" + or _tok_type == "prefix" + or _tok_type == "dur" + or _tok_type == "drum" + or _tok_type == "onset" + ): + # Return without changing + return tok + else: + assert isinstance(tok, tuple) and len(tok) == 3 + (_instrument, _pitch, _velocity) = tok + + assert isinstance(_pitch, int) + assert isinstance(_velocity, int) + + # Check it doesn't go out of bounds + if _velocity + _velocity_aug >= max_velocity: + return (_instrument, _pitch, max_velocity) + elif _velocity + _velocity_aug <= velocity_step: + return (_instrument, _pitch, velocity_step) + + return (_instrument, _pitch, _velocity + _velocity_aug) + + if not velocity_aug: + velocity_aug = velocity_step * random.randint( + -_aug_steps_range, _aug_steps_range + ) + + return [velocity_aug_tok(x, velocity_aug) for x in src] + + # See functools.partial docs + return self.export_aug_fn_concat( + functools.partial( + velocity_aug_seq, + velocity_step=self.velocity_step, + max_velocity=self.max_velocity, + _aug_steps_range=aug_steps_range, + ) + ) + + # TODO: Adjust this so it can handle other tokens like + def export_tempo_aug( + self, tempo_aug_range: float, mixup: bool + ) -> Callable[[list[Token]], list[Token]]: + """Exports a function which augments the tempo of a sequence of tokens. + + Additionally this function performs note-mixup: randomly re-ordering + the note subsequences which occur on the same onset. + + IMPORTANT: This function doesn't support additional tokens. 
If you + have modified the tokenizer in any way, this should not be added to + export_data_aug. + + Args: + tempo_aug_range (int): Returned function will randomly augment + tempo by a factor in the range (1 - tempo_aug_range, + 1 + tempo_aug_range). + + Returns: + Callable[[list[Token]], list[Token]]: Exported function. + """ + + def tempo_aug( + src: list[Token], + abs_time_step: int, + max_dur: int, + time_step: int, + unk_tok: str, + time_tok: str, + dim_tok: str, + start_tok: str, + end_tok: str, + instruments_wd: list, + tokenizer_name: str, + _tempo_aug_range: float, + _mixup: bool, + tempo_aug: float | None = None, + ) -> list[Token]: + """This must be used with export_aug_fn_concat in order to work + properly for concatenated sequences.""" + + def _quantize_time(_n: int) -> int: + return round(_n / time_step) * time_step + + assert ( + tokenizer_name == "abs" + ), f"Augmentation function only supports base AbsTokenizer" + + if not tempo_aug: + tempo_aug = random.uniform( + 1 - _tempo_aug_range, 1 + _tempo_aug_range + ) + + src_time_tok_cnt = 0 + dim_tok_seen = None + res: list[Token] = [] + note_buffer: dict[str, Token | None] | None = None + + # Buffer tracks + buffer: dict[int, dict[int, list[dict[str, Token | None]]]] = ( + defaultdict(lambda: defaultdict(list)) + ) + for tok_1, tok_2, tok_3 in zip(src, src[1:], src[2:]): + if tok_1 == time_tok: + _tok_type = "time" + elif tok_1 == unk_tok: + _tok_type = "unk" + elif tok_1 == start_tok: + res.append(tok_1) + continue + elif tok_1 == dim_tok and note_buffer: + assert isinstance(note_buffer["onset"], int) + dim_tok_seen = (src_time_tok_cnt, note_buffer["onset"][1]) + continue + elif tok_1[0] == "prefix": + res.append(tok_1) + continue + elif tok_1[0] in instruments_wd: + _tok_type = tok_1[0] + else: + # This only triggers for incomplete notes at the beginning, + # e.g. 
an onset token before a note token is seen + continue + + if _tok_type == "time": + src_time_tok_cnt += 1 + elif _tok_type == "drum": + assert isinstance(tok_2[1], int) + note_buffer = { + "note": tok_1, + "onset": tok_2, + "dur": None, + } + buffer[src_time_tok_cnt][tok_2[1]].append(note_buffer) + else: # unk or in instruments_wd + assert isinstance(tok_2[1], int) + note_buffer = { + "note": tok_1, + "onset": tok_2, + "dur": tok_3, + } + buffer[src_time_tok_cnt][tok_2[1]].append(note_buffer) + + prev_tgt_time_tok_cnt = 0 + for src_time_tok_cnt, interval_notes in sorted(buffer.items()): + for src_onset, notes_by_onset in sorted(interval_notes.items()): + src_time = src_time_tok_cnt * abs_time_step + src_onset + tgt_time = round(src_time * tempo_aug) + curr_tgt_time_tok_cnt = tgt_time // abs_time_step + curr_tgt_onset = _quantize_time(tgt_time % abs_time_step) + + if curr_tgt_onset == abs_time_step: + curr_tgt_onset -= time_step + + for _ in range( + curr_tgt_time_tok_cnt - prev_tgt_time_tok_cnt + ): + res.append(time_tok) + prev_tgt_time_tok_cnt = curr_tgt_time_tok_cnt + + if _mixup == True: + random.shuffle(notes_by_onset) + + for note in notes_by_onset: + _src_note_tok = note["note"] + _src_dur_tok = note["dur"] + assert _src_note_tok is not None + + if _src_dur_tok is not None: + assert isinstance(_src_dur_tok[1], int) + tgt_dur = _quantize_time( + round(_src_dur_tok[1] * tempo_aug) + ) + tgt_dur = min(tgt_dur, max_dur) + else: + tgt_dur = None + + res.append(_src_note_tok) + res.append(("onset", curr_tgt_onset)) + if tgt_dur: + res.append(("dur", tgt_dur)) + + if dim_tok_seen is not None and dim_tok_seen == ( + src_time_tok_cnt, + src_onset, + ): + res.append(dim_tok) + dim_tok_seen = None + + if src[-1] == end_tok: + res.append(end_tok) + + return res + + return self.export_aug_fn_concat( + functools.partial( + tempo_aug, + abs_time_step=self.abs_time_step, + max_dur=self.max_dur, + time_step=self.time_step, + unk_tok=self.unk_tok, + time_tok=self.time_tok, + dim_tok=self.dim_tok, + end_tok=self.eos_tok, + start_tok=self.bos_tok, + instruments_wd=self.instruments_wd, + tokenizer_name=self.name, + _tempo_aug_range=tempo_aug_range, + _mixup=mixup, + ) + ) From e6763a6fa95127b351114e07d445045568fafab3 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 20 Nov 2024 16:33:18 +0000 Subject: [PATCH 12/12] fix mypy --- ariautils/midi.py | 5 +++++ ariautils/utils/__init__.py | 13 +++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/ariautils/midi.py b/ariautils/midi.py index a59526e..1813bdf 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -24,6 +24,11 @@ from ariautils.utils import load_config, load_maestro_metadata_json +# TODO: +# - Remove unneeded comments +# - Add asserts + + class MetaMessage(TypedDict): """Meta message type corresponding text or copyright MIDI meta messages.""" diff --git a/ariautils/utils/__init__.py b/ariautils/utils/__init__.py index 4c51084..c2ccbbf 100644 --- a/ariautils/utils/__init__.py +++ b/ariautils/utils/__init__.py @@ -9,14 +9,19 @@ from .config import load_config -def get_logger(name: str) -> logging.Logger: +def get_logger(name: str | None) -> logging.Logger: logger = logging.getLogger(name) if not logger.handlers: logger.propagate = False logger.setLevel(logging.DEBUG) - formatter = logging.Formatter( - "[%(asctime)s]: [%(levelname)s] [%(name)s] %(message)s" - ) + if name is not None: + formatter = logging.Formatter( + "[%(asctime)s]: [%(levelname)s] [%(name)s] %(message)s" + ) + else: + formatter = logging.Formatter( + 
"[%(asctime)s]: [%(levelname)s] %(message)s" + ) ch = logging.StreamHandler() ch.setLevel(logging.INFO)