diff --git a/bin/BeatTracker b/bin/BeatTracker
index e207e12c5..6a8b62f5f 100755
--- a/bin/BeatTracker
+++ b/bin/BeatTracker
@@ -59,7 +59,7 @@ def main():
     # version
     p.add_argument('--version', action='version', version='BeatTracker.2016')
     # input/output arguments
-    io_arguments(p, output_suffix='.beats.txt')
+    io_arguments(p, output_suffix='.beats.txt', online=True)
     ActivationsProcessor.add_arguments(p)
     # signal processing arguments
     SignalProcessor.add_arguments(p, norm=False, gain=0)
diff --git a/madmom/features/beats.py b/madmom/features/beats.py
index f4460f381..b6dad7851 100755
--- a/madmom/features/beats.py
+++ b/madmom/features/beats.py
@@ -16,7 +16,7 @@
 from ..audio.signal import smooth as smooth_signal
 from ..ml.nn import average_predictions
 from ..processors import (OnlineProcessor, ParallelProcessor, Processor,
-                          SequentialProcessor, )
+                          SequentialProcessor, BufferProcessor)
 
 
 # classes for tracking (down-)beats with RNNs
@@ -392,7 +392,7 @@ def recursive(position):
 
 
 # classes for detecting/tracking of beat inside a beat activation function
-class BeatTrackingProcessor(Processor):
+class BeatTrackingProcessor(OnlineProcessor):
     """
     Track the beats according to previously determined (local) tempo by
     iteratively aligning them around the estimated position [1]_.
@@ -406,6 +406,15 @@ class BeatTrackingProcessor(Processor):
     look_ahead : float, optional
         Look `look_ahead` seconds in both directions to determine the local
         tempo and align the beats accordingly.
+    threshold : float, optional
+        Only accept an activation as a beat if it exceeds this threshold.
+        Currently only available in online mode.
+    buffer_size : float, optional
+        Use the past `buffer_size` seconds for detecting whether the current
+        frame is a beat in online mode.
+    look_back : int, optional
+        In online mode, a window of `look_back` frames is used to allow for
+        slight variations in tempo and beat detections.
     tempo_estimator : :class:`TempoEstimationProcessor`, optional
         Use this processor to estimate the (local) tempo. If 'None' a default
         tempo estimator will be created and used.
@@ -458,23 +467,48 @@ class BeatTrackingProcessor(Processor):
 
     """
     LOOK_ASIDE = 0.2
-    LOOK_AHEAD = 10.
-
-    def __init__(self, look_aside=LOOK_ASIDE, look_ahead=LOOK_AHEAD, fps=None,
-                 tempo_estimator=None, **kwargs):
+    LOOK_AHEAD = 10
+    THRESHOLD = 0.1
+    BUFFER_SIZE = 3.
+    LOOK_BACK = 5
+
+    def __init__(self, look_aside=LOOK_ASIDE, look_ahead=LOOK_AHEAD,
+                 threshold=THRESHOLD, buffer_size=BUFFER_SIZE,
+                 look_back=LOOK_BACK, tempo_estimator=None, fps=None,
+                 online=False, **kwargs):
+        # pylint: disable=unused-argument
+        super(BeatTrackingProcessor, self).__init__(online=online)
         # save variables
         self.look_aside = look_aside
         self.look_ahead = look_ahead
+        self.threshold = threshold
         self.fps = fps
         # tempo estimator
         if tempo_estimator is None:
             # import the TempoEstimation here otherwise we have a loop
             from .tempo import TempoEstimationProcessor
             # create default tempo estimator
-            tempo_estimator = TempoEstimationProcessor(fps=fps, **kwargs)
+            tempo_estimator = TempoEstimationProcessor(fps=fps,
+                                                       online=online,
+                                                       **kwargs)
         self.tempo_estimator = tempo_estimator
+        if self.online:
+            self.visualize = kwargs.get('verbose', False)
+            self.buffer = BufferProcessor(int(buffer_size * self.fps))
+            self.look_back = look_back
+            self.counter = 0
+            self.beat_counter = 0
+            self.last_beat = 0
 
-    def process(self, activations, **kwargs):
+    def reset(self):
+        """Reset the BeatTrackingProcessor."""
+        self.tempo_estimator.reset()
+        self.buffer.reset()
+        self.counter = 0
+        self.beat_counter = 0
+        self.last_beat = 0
+
+    def process_offline(self, activations, **kwargs):
         """
         Detect the beats in the given activation function.
 
@@ -544,6 +578,73 @@ def process(self, activations, **kwargs):
         # remove beats with negative times and return them
         return detections[np.searchsorted(detections, 0):]
 
+    def process_online(self, activations, reset=True, **kwargs):
+        """
+        Detect the beats in the given activation function in online mode.
+
+        Parameters
+        ----------
+        activations : numpy array
+            Beat activation function.
+        reset : bool, optional
+            Reset the BeatTrackingProcessor to its initial state before
+            processing.
+
+        Returns
+        -------
+        beats : numpy array
+            Detected beat positions [seconds].
+
+        """
+        # reset to initial state
+        if reset:
+            self.reset()
+        beats_ = []
+        for activation in activations:
+            # shift buffer and put new activation at end of buffer
+            buffer = self.buffer(activation)
+            # create an online interval histogram
+            histogram = self.tempo_estimator.interval_histogram(
+                np.array([activation]), reset=reset)
+            # get the dominant interval
+            interval = self.tempo_estimator.dominant_interval(histogram)
+            # compute the current and the next possible beat time
+            cur_beat = self.counter / float(self.fps)
+            next_beat = self.last_beat + 60. / self.tempo_estimator.max_bpm
+            # only look for a new beat after the minimum beat interval passed
+            detections = []
+            if cur_beat >= next_beat:
+                detections = detect_beats(buffer, interval, self.look_aside)
+            # if a detection falls within the last few frames and exceeds a
+            # certain threshold, a beat is detected. this is important
+            # because the tempo or the detections may change with every frame
+            # and therefore the last beat could easily be missed otherwise.
+            if len(detections) and \
+                    detections[-1] >= len(buffer) - self.look_back and \
+                    buffer[detections[-1]] > self.threshold:
+                # append to beats
+                beats_.append(cur_beat)
+                # update last beat
+                self.last_beat = cur_beat
+            # visualize beats
+            if self.visualize:
+                display = ['']
+                if len(beats_) > 0 and beats_[-1] == cur_beat:
+                    self.beat_counter = 10
+                if self.beat_counter > 0:
+                    display.append('| X ')
+                else:
+                    display.append('|   ')
+                self.beat_counter -= 1
+                # display tempo
+                display.append('| %5.1f | ' % float(self.fps * 60 / interval))
+                sys.stderr.write('\r%s' % ''.join(display))
+                sys.stderr.flush()
+            # increase counter
+            self.counter += 1
+        # return beat(s)
+        return np.array(beats_)
+
     @staticmethod
     def add_arguments(parser, look_aside=LOOK_ASIDE,
                       look_ahead=LOOK_AHEAD):
@@ -985,7 +1086,7 @@ def __init__(self, min_bpm=MIN_BPM, max_bpm=MAX_BPM, num_tempi=NUM_TEMPI,
         self.fps = fps
         self.min_bpm = min_bpm
         self.max_bpm = max_bpm
-        # kepp state in online mode
+        # keep state in online mode
         self.online = online
         # TODO: refactor the visualisation stuff
         if self.online:
diff --git a/tests/test_features_beats.py b/tests/test_features_beats.py
index 4cec91500..df5f5de8b 100644
--- a/tests/test_features_beats.py
+++ b/tests/test_features_beats.py
@@ -79,6 +79,25 @@ def test_process(self):
         self.assertTrue(np.allclose(beats, [0.11, 0.45, 0.79, 1.13, 1.47,
                                             1.81, 2.15, 2.49]))
 
+    def test_process_online(self):
+        processor = BeatTrackingProcessor(fps=sample_lstm_act.fps,
+                                          online=True)
+        # compute the beats at once
+        beats = processor.process_online(sample_lstm_act, reset=False)
+        self.assertTrue(np.allclose(beats, [0.68, 1.14, 1.48, 1.84, 2.18,
+                                            2.51]))
+        # compute the beats framewise
+        processor.reset()
+        beats = [processor.process_online(np.atleast_2d(act), reset=False)
+                 for act in sample_lstm_act]
+        self.assertTrue(np.allclose(np.nonzero(beats),
+                                    [68, 114, 148, 184, 218, 251]))
+        # without resetting results are different
+        beats = [processor.process_online(np.atleast_2d(act), reset=False)
+                 for act in sample_lstm_act]
+        self.assertTrue(np.allclose(np.nonzero(beats), [5, 148, 184, 217,
+                                                        251]))
+
 
 class TestBeatDetectionProcessorClass(unittest.TestCase):
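
For reference, the snippet below is a minimal usage sketch of the online mode introduced by this patch, mirroring the new test_process_online test. It is illustrative only and not part of the diff: the file name is a placeholder, and it assumes madmom's existing RNNBeatProcessor is used to compute the beat activations at its default rate of 100 frames per second.

import numpy as np

from madmom.features.beats import BeatTrackingProcessor, RNNBeatProcessor

# compute the beat activation function with madmom's existing RNN processor;
# 'some_file.wav' is a placeholder for any audio file
activations = RNNBeatProcessor()('some_file.wav')

# track beats in online mode, feeding one activation frame at a time
# (fps=100 is assumed to match the frame rate of the RNN activations)
processor = BeatTrackingProcessor(fps=100, online=True)
processor.reset()
for act in activations:
    beats = processor.process_online(np.atleast_2d(act), reset=False)
    if len(beats):
        print('beat at %.2f s' % beats[0])

Calling the processor directly (processor(activations)) should dispatch to process_offline or process_online depending on the online flag, which is the reason the patch renames process to process_offline and derives the class from OnlineProcessor.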