forked from Elvenson/piano_transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
114 lines (89 loc) · 3.53 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright 2019 Google LLC.
# Licensed under the Apache License, Version 2.0 (the "License");
# Modification copyright 2020 Bui Quoc Bao
# Change notebook script into package
"""Transform utilities."""
import numpy as np
import tensorflow as tf
from tensor2tensor.data_generators import text_encoder
import magenta.music as mm
from magenta.models.score2perf import score2perf
# Alias for TensorFlow's v1-compatible logging module; used for the
# warnings emitted while preparing primers.
LOGGER = tf.compat.v1.logging
class PianoPerformanceLanguageModelProblem(score2perf.Score2PerfProblem):  # pylint: disable=abstract-method
    """Piano performance language-model problem (per the class name).

    The only thing this subclass changes relative to
    ``score2perf.Score2PerfProblem`` is that ``add_eos_symbol`` is forced to
    ``True``. Presumably the base problem consults it when encoding
    sequences; confirm against score2perf.
    """

    @property
    def add_eos_symbol(self):
        """Request an EOS symbol; always ``True``."""
        return True
class MelodyToPianoPerformanceProblem(score2perf.AbsoluteMelody2PerfProblem):  # pylint: disable=abstract-method
    """Melody-conditioned piano performance problem (per the class name).

    The only thing this subclass changes relative to
    ``score2perf.AbsoluteMelody2PerfProblem`` is that ``add_eos_symbol`` is
    forced to ``True``.
    """

    @property
    def add_eos_symbol(self):
        """Request an EOS symbol; always ``True``."""
        return True
def decode(ids, encoder):
    """
    Decode token IDs, ignoring everything from the first EOS token onward.
    :param ids: Iterable of token IDs.
    :param encoder: Object exposing a ``decode(list_of_ids)`` method.
    :return:
        The result of ``encoder.decode`` on the IDs that precede the first
        ``text_encoder.EOS_ID`` (all IDs when no EOS is present).
    """
    id_list = list(ids)
    try:
        cutoff = id_list.index(text_encoder.EOS_ID)
    except ValueError:
        # No EOS token: keep every ID.
        cutoff = len(id_list)
    return encoder.decode(id_list[:cutoff])
def unconditional_input_generator(targets, decode_length):
    """
    Endlessly yield feature dicts for the unconditional Transformer.

    Meant to back an Estimator input function. Each iteration builds fresh
    arrays, so consumers never share buffers between batches.
    :param targets: Flat sequence of target token IDs, wrapped into a batch
        of one.
    :param decode_length: Decode length, emitted as a scalar.
    :return:
        Infinite generator of dicts with ``'targets'`` (int32, shape
        ``(1, len(targets))``) and ``'decode_length'`` (int32 scalar).
    """
    while True:
        batch_of_one = np.array([targets], dtype=np.int32)
        length = np.array(decode_length, dtype=np.int32)
        yield {'targets': batch_of_one, 'decode_length': length}
def melody_input_generator(inputs, decode_length):
    """
    Endlessly yield feature dicts for the melody-conditioned Transformer.

    Meant to back an Estimator input function. Each iteration builds fresh
    arrays.
    :param inputs: Flat sequence of melody token IDs. It is wrapped twice,
        giving shape ``(1, 1, len(inputs))``.
    :param decode_length: Decode length, emitted as a scalar.
    :return:
        Infinite generator of dicts with ``'inputs'`` (int32, 3-D),
        ``'targets'`` (empty int32 array of shape ``(1, 0)``) and
        ``'decode_length'`` (int32 scalar).
    """
    while True:
        features = dict(
            inputs=np.array([[inputs]], dtype=np.int32),
            # No target prefix: decoding starts from scratch.
            targets=np.zeros([1, 0], dtype=np.int32),
            decode_length=np.array(decode_length, dtype=np.int32),
        )
        yield features
def get_primer_ns(filename, max_length):
    """
    Convert a MIDI file to a note sequence for priming.

    The returned sequence:

    - has sustain-pedal control changes applied;
    - is truncated to ``max_length`` seconds when longer (with a warning);
    - has drum notes removed (with a warning);
    - has every note assigned to instrument 1, program 0.
    :param filename: MIDI file name.
    :param max_length: Maximum primer duration in seconds.
    :return:
        Note sequence for priming.
    """
    primer_ns = mm.midi_file_to_note_sequence(filename)
    # Handle the sustain pedal before measuring or trimming the primer.
    primer_ns = mm.apply_sustain_control_changes(primer_ns)

    # Trim to the desired number of seconds.
    if primer_ns.total_time > max_length:
        # Durations can be fractional seconds. '%d' would truncate them in the
        # message, so use '%.2f' and let the logger format lazily.
        LOGGER.warning(
            'Primer duration %.2f is longer than max second %.2f, truncating.',
            primer_ns.total_time, max_length)
        primer_ns = mm.extract_subsequence(primer_ns, 0, max_length)

    # Remove drums from the primer if present.
    if any(note.is_drum for note in primer_ns.notes):
        LOGGER.warning('Primer contains drums; they will be removed.')
        kept_notes = [note for note in primer_ns.notes if not note.is_drum]
        del primer_ns.notes[:]
        primer_ns.notes.extend(kept_notes)

    # Put every primer note on the same instrument/program. Program 0 is
    # Acoustic Grand Piano in General MIDI.
    for note in primer_ns.notes:
        note.instrument = 1
        note.program = 0
    return primer_ns
def get_melody_ns(filename):
    """
    Convert a melody MIDI file to a note sequence containing only the melody.

    The function proceeds in three steps:

    - keep only the notes on the instrument returned by
      ``mm.infer_melody_for_sequence``;
    - sort those notes by start time;
    - set each note's end time (except the last note's) to the start time of
      the following note.
    :param filename: MIDI file name.
    :return:
        Melody note sequence.
    """
    melody_ns = mm.midi_file_to_note_sequence(filename)
    # NOTE(review): assumes infer_melody_for_sequence returns the instrument
    # id of the inferred melody (and may add those notes to melody_ns);
    # confirm against the installed Magenta version.
    melody_instrument = mm.infer_melody_for_sequence(melody_ns)

    # Scoped suppression: pylint cannot see protobuf-generated members.
    # pylint: disable=no-member
    melody_notes = [
        note for note in melody_ns.notes if note.instrument == melody_instrument
    ]
    del melody_ns.notes[:]
    melody_ns.notes.extend(sorted(melody_notes, key=lambda note: note.start_time))

    # Each note (except the last) ends exactly where the next one starts.
    notes = melody_ns.notes
    for i in range(len(notes) - 1):
        notes[i].end_time = notes[i + 1].start_time
    # pylint: enable=no-member
    return melody_ns