Skip to content

Commit

Permalink
Basic online BeatTracker
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianPoell committed Sep 1, 2017
1 parent c7746a8 commit 0f7c279
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 10 deletions.
2 changes: 1 addition & 1 deletion bin/BeatTracker
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def main():
# version
p.add_argument('--version', action='version', version='BeatTracker.2016')
# input/output arguments
io_arguments(p, output_suffix='.beats.txt')
io_arguments(p, output_suffix='.beats.txt', online=True)
ActivationsProcessor.add_arguments(p)
# signal processing arguments
SignalProcessor.add_arguments(p, norm=False, gain=0)
Expand Down
119 changes: 110 additions & 9 deletions madmom/features/beats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..audio.signal import smooth as smooth_signal
from ..ml.nn import average_predictions
from ..processors import (OnlineProcessor, ParallelProcessor, Processor,
SequentialProcessor, )
SequentialProcessor, BufferProcessor)


# classes for tracking (down-)beats with RNNs
Expand Down Expand Up @@ -392,7 +392,7 @@ def recursive(position):


# classes for detecting/tracking of beat inside a beat activation function
class BeatTrackingProcessor(Processor):
class BeatTrackingProcessor(OnlineProcessor):
"""
Track the beats according to previously determined (local) tempo by
iteratively aligning them around the estimated position [1]_.
Expand All @@ -406,6 +406,15 @@ class BeatTrackingProcessor(Processor):
look_ahead : float, optional
Look `look_ahead` seconds in both directions to determine the local
tempo and align the beats accordingly.
threshold : float, optional
Only accept activations as beat which exceed that threshold.
Currently only available in online mode.
buffer_size : float, optional
Use past `buffer_size` seconds for detecting if the current frame is a
beat in online mode.
look_back : int, optional
In online mode a window of 'look_back' frames is used for allowing
slight variations of tempo and detections.
tempo_estimator : :class:`TempoEstimationProcessor`, optional
Use this processor to estimate the (local) tempo. If 'None' a default
tempo estimator will be created and used.
Expand Down Expand Up @@ -458,23 +467,48 @@ class BeatTrackingProcessor(Processor):
"""
LOOK_ASIDE = 0.2
LOOK_AHEAD = 10.

def __init__(self, look_aside=LOOK_ASIDE, look_ahead=LOOK_AHEAD, fps=None,
tempo_estimator=None, **kwargs):
LOOK_AHEAD = 10
THRESHOLD = 0.1
BUFFER_SIZE = 3.
LOOK_BACK = 5

def __init__(self, look_aside=LOOK_ASIDE, look_ahead=LOOK_AHEAD,
threshold=THRESHOLD, buffer_size=BUFFER_SIZE,
look_back=LOOK_BACK, tempo_estimator=None, fps=None,
online=False, **kwargs):
# pylint: disable=unused-argument
super(BeatTrackingProcessor, self).__init__(online=online)
# save variables
self.look_aside = look_aside
self.look_ahead = look_ahead
self.threshold = threshold
self.fps = fps
# tempo estimator
if tempo_estimator is None:
# import the TempoEstimation here otherwise we have a loop
from .tempo import TempoEstimationProcessor
# create default tempo estimator
tempo_estimator = TempoEstimationProcessor(fps=fps, **kwargs)
tempo_estimator = TempoEstimationProcessor(fps=fps,
online=online,
**kwargs)
self.tempo_estimator = tempo_estimator
if self.online:
self.visualize = kwargs.get('verbose', False)
self.buffer = BufferProcessor(int(buffer_size * self.fps))
self.look_back = look_back
self.counter = 0
self.beat_counter = 0
self.last_beat = 0

def process(self, activations, **kwargs):
def reset(self):
"""Reset the BeatTrackingProcessor."""
self.tempo_estimator.reset()
self.buffer.reset()
self.counter = 0
self.beat_counter = 0
self.last_beat = 0

def process_offline(self, activations, **kwargs):
"""
Detect the beats in the given activation function.
Expand Down Expand Up @@ -544,6 +578,73 @@ def process(self, activations, **kwargs):
# remove beats with negative times and return them
return detections[np.searchsorted(detections, 0):]

def process_online(self, activations, reset=True, **kwargs):
"""
Detect the beats in the given activation function for online mode.
Parameters
----------
activations : numpy array
Beat activation function.
reset : bool, optional
Reset the BeatTrackingProcessor to its initial state before
processing.
Returns
-------
beats : numpy array
Detected beat positions [seconds].
"""
# reset to initial state
if reset:
self.reset()
beats_ = []
for activation in activations:
# shift buffer and put new activation at end of buffer
buffer = self.buffer(activation)
# create an online interval histogram
histogram = self.tempo_estimator.interval_histogram(
np.array([activation]), reset=reset)
# get the dominant interval
interval = self.tempo_estimator.dominant_interval(histogram)
# compute the current and the next possible beat time
cur_beat = self.counter / float(self.fps)
next_beat = self.last_beat + 60. / self.tempo_estimator.max_bpm
# only detect beats again after at least min_interval frames
detections = []
if cur_beat >= next_beat:
detections = detect_beats(buffer, interval, self.look_aside)
# if a detection falls within the last few frames and exceeds a
# certain threshold a beat was detected. this is important because
# for every frame the tempo or the detecctions may change and
# therefore the last beat can easily be missed.
if len(detections) and \
detections[-1] >= len(buffer) - self.look_back and \
buffer[detections[-1]] > self.threshold:
# append to beats
beats_.append(cur_beat)
# update last beat
self.last_beat = cur_beat
# visualize beats
if self.visualize:
display = ['']
if len(beats_) > 0 and beats_[-1] == cur_beat:
self.beat_counter = 10
if self.beat_counter > 0:
display.append('| X ')
else:
display.append('| ')
self.beat_counter -= 1
# display tempo
display.append('| %5.1f | ' % float(self.fps * 60 / interval))
sys.stderr.write('\r%s' % ''.join(display))
sys.stderr.flush()
# increase counter
self.counter += 1
# return beat(s)
return np.array(beats_)

@staticmethod
def add_arguments(parser, look_aside=LOOK_ASIDE,
look_ahead=LOOK_AHEAD):
Expand Down Expand Up @@ -985,7 +1086,7 @@ def __init__(self, min_bpm=MIN_BPM, max_bpm=MAX_BPM, num_tempi=NUM_TEMPI,
self.fps = fps
self.min_bpm = min_bpm
self.max_bpm = max_bpm
# kepp state in online mode
# keep state in online mode
self.online = online
# TODO: refactor the visualisation stuff
if self.online:
Expand Down
19 changes: 19 additions & 0 deletions tests/test_features_beats.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,25 @@ def test_process(self):
self.assertTrue(np.allclose(beats, [0.11, 0.45, 0.79, 1.13, 1.47,
1.81, 2.15, 2.49]))

def test_process_online(self):
processor = BeatTrackingProcessor(fps=sample_lstm_act.fps,
online=True)
# compute the beats at once
beats = processor.process_online(sample_lstm_act, reset=False)
self.assertTrue(np.allclose(beats, [0.68, 1.14, 1.48, 1.84, 2.18,
2.51]))
# compute the beats framewise
processor.reset()
beats = [processor.process_online(np.atleast_2d(act), reset=False)
for act in sample_lstm_act]
self.assertTrue(np.allclose(np.nonzero(beats),
[68, 114, 148, 184, 218, 251]))
# without resetting results are different
beats = [processor.process_online(np.atleast_2d(act), reset=False)
for act in sample_lstm_act]
self.assertTrue(np.allclose(np.nonzero(beats), [5, 148, 184, 217,
251]))


class TestBeatDetectionProcessorClass(unittest.TestCase):

Expand Down

0 comments on commit 0f7c279

Please sign in to comment.