diff --git a/.gitignore b/.gitignore index 7acfd5a..31cce71 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ # Repo specific ignores audio/ - +plots/ # Python basic ignores # Byte-compiled / optimized / DLL files diff --git a/docs/ladr/LADR_0001_lightweight_detections.md b/docs/ladr/LADR_0001_lightweight_detections.md index 0decccf..3458d09 100644 --- a/docs/ladr/LADR_0001_lightweight_detections.md +++ b/docs/ladr/LADR_0001_lightweight_detections.md @@ -1,38 +1,51 @@ # Lightweight Detection Mechanisms This doc's purpose is to consider the different options for simple whale vocalization detection mechanisms. +The chosen mechanism will sift through the large volume of audio data to find the most interesting parts to feed into the machine learning model. ## Background -The purpose of this pipeline to is efficiently detect vocalizations from whales encountered on [HappyWhale](https://happywhale.com/). +The main goal of this pipeline is to efficiently detect vocalizations from whales encountered on [HappyWhale](https://happywhale.com/). Since audio data is notoriously large, we want to quickly find which chunks of audio are most important to look at, before classifying them via the [humpback_whale model](https://tfhub.dev/google/humpback_whale/1). Initially, the data that does not make it past the filter will not be fed through the machine learning model, in order to keep costs down. This means, we need a filtering mechanism that is "generuous" in what it flags, but not too generous that it flags everything. +I'm still learning about different options, so this doc will be updated as I learn more. + ## Options ### Energy filter Simplest of filters. Just measures peaks in audio signal, above a specified threshold. +We would need to normalize and then threshold the audio signal to detect peaks. +We could also take the root mean square (RMS) over a short window to detect high-energy sections. #### Pros - very lightweight - easy to implement #### Cons -- too much noise can make it through +- too much noise will make it through - not very specific - prioritizes loudness over frequency - sounds from a distance will likely not be detected +- too rudimentary alone, but good to combine w/ frequency filters, for example -### Butterworth Bandpass Filter -[Wikipedia](https://en.wikipedia.org/wiki/Butterworth_filter) +### [Butterworth Bandpass Filter](https://en.wikipedia.org/wiki/Butterworth_filter) Filters out audio that does not contain a certain frequency range. Only allows a particular band of frequencies to pass through. These are determined via the filter's order and cutoff frequencies, low and high. +The order of the filter determines how steep the roll-off is, i.e. how "boxy" the filter is. +The higher the order, the steeper the roll-off, and the sharper the corners of the top of the filter. + +![bandpass filter](https://upload.wikimedia.org/wikipedia/commons/thumb/f/f6/Bandwidth.svg/320px-Bandwidth.svg.png) + +For our use case, we want to find audio whose frequency content matches expected whale vocalization frequencies. +If the audio over a specified time window contains frequencies inside of the band, it is flagged as a detection. +![detections](../../img/detections.png) -![alt text](https://upload.wikimedia.org/wikipedia/commons/thumb/f/f6/Bandwidth.svg/320px-Bandwidth.svg.png) +To get the most out of this type of filter, we would likely need to use a handful of them, each focusing on its own frequency range, as sketched below. 
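As a rough illustration of how a small bank of such bandpass filters could flag interesting windows, here is a minimal sketch. This is not the pipeline's implementation (that lives in `src/pipeline/stages/sift.py` later in this diff); the band edges, order, window size, and threshold below are placeholder assumptions.

```python
# Sketch only: a small bank of Butterworth bandpass filters, each flagging
# windows whose normalized in-band energy crosses a threshold.
# Band edges, order, window size, and threshold are placeholder values.
import numpy as np
from scipy.signal import butter, sosfilt


def bandpass_sos(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    return butter(order, [lowcut / nyq, highcut / nyq], btype="band", output="sos")


def band_detections(signal, fs, bands=((50, 500), (500, 1500)), order=5,
                    window_size=512, threshold=0.015):
    detections = {}
    for lowcut, highcut in bands:
        filtered = sosfilt(bandpass_sos(lowcut, highcut, fs, order), signal)
        # energy per non-overlapping window, normalized to the loudest window
        energy = np.array([np.sum(filtered[i:i + window_size] ** 2)
                           for i in range(0, len(filtered), window_size)])
        energy = energy / max(np.max(energy), 1e-12)
        # sample indices of windows whose in-band energy crosses the threshold
        detections[(lowcut, highcut)] = np.flatnonzero(energy > threshold) * window_size
    return detections
```

Each band gets its own cutoffs, and a window only counts as a detection for the bands whose normalized energy crosses the threshold.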
#### Pros - more specific than energy filter @@ -42,29 +55,57 @@ These are determined via the filter's order and cutoff frequencies, low and high - easy to implement - can be used together w/ other filtering methods + #### Cons -- room for improvement on specificity +- fixed window size (needs to be tuned) - assumes a certain frequency range is the most important - disregards harmonics - clicks may not be detected - different ages/individuals may have different frequency ranges (?) -- not great at detecting sounds from a distance +- not great at detecting vocalizations from a distance + + +### [Chebyshev Bandpass Filter](https://en.wikipedia.org/wiki/Chebyshev_filter) +Slightly more complex than the Butterworth filter, but with a steeper roll-off. +The Chebyshev filter allows ripple in either the passband (Type I) or the stopband (Type II), which can be tuned to be more or less aggressive. +It can also suppress specific frequencies in the stopband, for example a ship's engine noise. + +![chebyshev](https://upload.wikimedia.org/wikipedia/commons/thumb/b/b7/ChebyshevII_response-en.svg/2880px-ChebyshevII_response-en.svg.png) + + +#### Pros +- more specific than Butterworth +- can be tuned to specific frequencies +- can be used together w/ other filtering methods +- can remove specific frequencies + +#### Cons +- more computationally expensive than Butterworth (likely not by much, but this needs testing) +- more difficult to interpret +- no experience with it +- ... + +### [Spectrogram](https://en.wikipedia.org/wiki/Spectrogram) +Computed via a [short-time Fourier transform](https://en.wikipedia.org/wiki/Short-time_Fourier_transform), it visualizes the frequencies of a signal as they vary with time. +This represents the audio signal along three dimensions: the x-axis is time, the y-axis is frequency, and color is the energy of that frequency. + +![spectrogram](../../img/spectrogram.png) -### Spectrogram -A visual representation of the spectrum of frequencies of a signal as it varies with time. -This is a 2D representation of the audio signal, where the x-axis is time, y-axis is frequency, and color is amplitude. #### Pros - can be used to detect harmonics - can be used to detect clicks - can be used to detect sounds from a distance - can be used to detect multiple species +- can be calibrated to find specific frequencies +- energy threshold can be applied to find the exact time range of matching frequencies + - might need to be dynamic to adjust for different distances +- visually intuitive #### Cons -- computationally expensive -- not lightweight -- more difficult to work with (2D data) +- slightly more computationally expensive (still less than the model) + ### Distilled Student-Teacher model We could set up a lightweight NN classifier @@ -96,6 +137,8 @@ any whale prescence, i.e. most of the time. Final option is to just directly use the model on the data surronding a encounter. This is the most expensive option, but also the most accurate. +![model results](../../img/model_results.png) + #### Pros - most accurate - "smartest" filter @@ -106,6 +149,7 @@ This is the most expensive option, but also the most accurate. ## Decision -Initially, I will use the Butterworth Bandpass Filter, as it is the most lightweight and specific of the options. +In my brainstorming notebook, I used the Butterworth Bandpass Filter, since it was the simplest of the filters that still target a specific frequency range. +The pipeline should maybe be built to allow easy swapping of filters via the config file, which will enable easier experimentation at some later time (see the sketch below). 
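As a rough sketch of what that config-driven swapping could look like: the `mechanism` key and the registry below are hypothetical and not part of the current config.yaml or app.py (which wires `Butterworth()` in directly).

```python
# Hypothetical sketch: select the sift transform from config instead of
# hard-coding Butterworth() in app.py. The "mechanism" key and SIFT_REGISTRY
# are illustrative assumptions, not existing pipeline code.
from stages.sift import Butterworth

SIFT_REGISTRY = {
    "butterworth": Butterworth,
    # future entries, e.g. "chebyshev": Chebyshev, "spectrogram": Spectrogram
}


def build_sift_transform(config):
    """Return the sift PTransform named in config.sift.mechanism."""
    name = getattr(config.sift, "mechanism", "butterworth")
    return SIFT_REGISTRY[name]()

# in app.py, roughly:
#   sifted_audio = audio_results | "Sift Audio" >> build_sift_transform(config)
```

Each registry entry could then read its own parameter block under `sift:` in config.yaml, the same way the `butterworth:` block is read today.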
From this, I'll gather results, and see if going straight to model would be better. This doc will be updated with the results of the experiment. diff --git a/examples/butterworth.py b/examples/butterworth.py new file mode 100644 index 0000000..5f126fa --- /dev/null +++ b/examples/butterworth.py @@ -0,0 +1,80 @@ +""" +Example provided at: https://scipy-cookbook.readthedocs.io/items/ButterworthBandpass.html +""" +from scipy.signal import butter, lfilter, sosfilt + +def butter_bandpass(lowcut, highcut, fs, order=5, output="ba"): + nyq = 0.5 * fs + low = lowcut / nyq + high = highcut / nyq + return butter(order, [low, high], btype='band', output=output) + + +def butter_bandpass_filter(data, lowcut, highcut, fs, order=5, output="sos"): + butter_values = butter_bandpass(lowcut, highcut, fs, order=order, output=output) + if output == "ba": + b, a = butter_values + y = lfilter(b, a, data) + elif output == "sos": + sos = butter_values + y = sosfilt(sos, data) + return y + + +def run(): + import numpy as np + import matplotlib.pyplot as plt + from scipy.signal import freqz + + # Sample rate and desired cutoff frequencies (in Hz). + fs = 5000.0 + lowcut = 600.0 + highcut = 1250.0 + + # Plot the frequency response for a few different orders. + plt.figure(1) + plt.clf() + for order in [2, 4, 6]: #[3, 6, 9]: + b, a = butter_bandpass(lowcut, highcut, fs, order=order, output="ba") + w, h = freqz(b, a, worN=2000) + plt.plot((fs * 0.5 / np.pi) * w, abs(h), label="order = %d" % order) + + plt.plot([0, 0.5 * fs], [np.sqrt(0.5), np.sqrt(0.5)], + '--', label='sqrt(0.5)') + plt.xlabel('Frequency (Hz)') + plt.ylabel('Gain') + plt.grid(True) + plt.legend(loc='best') + + # Filter a noisy signal. + T = 0.05 + nsamples = int(T * fs) + t = np.linspace(0, T, nsamples, endpoint=False) + a = 0.05 # sinewave amplitude (used here to emphasize our desired frequency) + f0 = 600.0 + x = 0.1 * np.sin(2 * np.pi * 1.2 * np.sqrt(t)) + x += 0.01 * np.cos(2 * np.pi * 312 * t + 0.1) + x += 0.01 * np.cos(2 * np.pi * 510 * t + 0.1) # another frequency in the lowcut and highcut range + x += 0.01 * np.cos(2 * np.pi * 520 * t + 0.1) # another frequency in the lowcut and highcut range + x += 0.01 * np.cos(2 * np.pi * 530 * t + 0.1) # another frequency in the lowcut and highcut range + x += 0.01 * np.cos(2 * np.pi * 540 * t + 0.1) # another frequency in the lowcut and highcut range + x += 0.01 * np.cos(2 * np.pi * 550 * t + 0.1) # another frequency in the lowcut and highcut range + x += 0.01 * np.cos(2 * np.pi * 1200 * t + 0.1) # another frequency in the lowcut and highcut range + x += a * np.cos(2 * np.pi * f0 * t + .11) + x += 0.03 * np.cos(2 * np.pi * 2000 * t) + plt.figure(2) + plt.clf() + plt.plot(t, x, label='Noisy signal') + + y = butter_bandpass_filter(x, lowcut, highcut, fs, order=6) + plt.plot(t, y, label='Filtered signal') + plt.xlabel('time (seconds)') + plt.hlines([-a, a], 0, T, linestyles='--') + plt.grid(True) + plt.axis('tight') + plt.legend(loc='upper left') + + plt.show() + + +run() \ No newline at end of file diff --git a/examples/notebooks/adaptedpacificsounddetecthumpbacksong.py b/examples/notebooks/adaptedpacificsounddetecthumpbacksong.py index e3e872a..15823fc 100644 --- a/examples/notebooks/adaptedpacificsounddetecthumpbacksong.py +++ b/examples/notebooks/adaptedpacificsounddetecthumpbacksong.py @@ -461,7 +461,6 @@ def run( """## Filters -### Me I want to play around with different filters for the whale noises. Applying the filter to the signal will allow me to emphasize some signals, and mitigate others. 
""" @@ -477,29 +476,29 @@ def run( sample_rate = sample_rate, ) -def low_pass( +def average_filter( signal = signal, - factor = 0.01, size = 100, + factor = 0.1, ): - low_pass_filter = np.ones(size)*factor - filtered_signal = np.convolve(signal, low_pass_filter) + average_array = (np.ones(size)/size)*factor + filtered_signal = np.convolve(signal, average_array) return filtered_signal factor = 0.1 size = 100 -filtered_signal = low_pass(signal, factor=factor, size=size) +filtered_signal = average_filter(signal, factor=factor, size=size) # plot_play(filtered_signal, title=f"Low-pass filter - factor:{factor} size:{size}") factor = 0.1 size = 10 -filtered_signal = low_pass(signal, factor=factor, size=size) +filtered_signal = average_filter(signal, factor=factor, size=size) # plot_play(filtered_signal, title=f"Low-pass filter - factor:{factor} size:{size}") -"""### Claude""" +"""### Ask an agent""" ( sample_start, diff --git a/img/detections.png b/img/detections.png new file mode 100644 index 0000000..f5dd2b0 Binary files /dev/null and b/img/detections.png differ diff --git a/img/model_results.png b/img/model_results.png new file mode 100644 index 0000000..ecf68d3 Binary files /dev/null and b/img/model_results.png differ diff --git a/img/spectrogram.png b/img/spectrogram.png new file mode 100644 index 0000000..07e7bc3 Binary files /dev/null and b/img/spectrogram.png differ diff --git a/src/pipeline/app.py b/src/pipeline/app.py index 96f24b9..c3a15b1 100644 --- a/src/pipeline/app.py +++ b/src/pipeline/app.py @@ -2,22 +2,32 @@ from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions from stages.search import GeometrySearch -from stages.audio import RetrieveAudio, WriteAudio +from stages.audio import RetrieveAudio, WriteAudio, WriteSiftedAudio +from stages.sift import Butterworth +from config import load_pipeline_config +config = load_pipeline_config() def run(): # Initialize pipeline options pipeline_options = PipelineOptions() pipeline_options.view_as(SetupOptions).save_main_session = True + args = { + "start": config.input.start, + "end": config.input.end + } with beam.Pipeline(options=pipeline_options) as p: - input_data = p | "Create Input" >> beam.Create([{'start': '2016-12-21T00:30:0', 'end':"2016-12-21T00:40:0"}]) - search_results = input_data | "Run Geometry Search" >> beam.ParDo(GeometrySearch()) - audio_results = search_results | "Retrieve Audio" >> beam.ParDo(RetrieveAudio()) - # filtered_audio = audio_results | "Filter Frequency" >> FilterFrequency() + input_data = p | "Create Input" >> beam.Create([args]) + search_results = input_data | "Run Geometry Search" >> beam.ParDo(GeometrySearch()) + + audio_results = search_results | "Retrieve Audio" >> beam.ParDo(RetrieveAudio()) + audio_files = audio_results | "Store Audio (temp)" >> beam.ParDo(WriteAudio()) + + sifted_audio = audio_results | "Sift Audio" >> Butterworth() + sifted_audio_files = sifted_audio | "Store Sifted Audio" >> beam.ParDo(WriteSiftedAudio("butterworth")) # For debugging, you can write the output to a text file - audio_files = audio_results | "Store Audio (temp)" >> beam.ParDo(WriteAudio()) # audio_files | "Write Audio Output" >> beam.io.WriteToText('audio_files.txt') # search_results | "Write Search Output" >> beam.io.WriteToText('search_results.txt') diff --git a/src/pipeline/config.yaml b/src/pipeline/config.yaml index c650f03..1b9d2b6 100644 --- a/src/pipeline/config.yaml +++ b/src/pipeline/config.yaml @@ -1,6 +1,7 @@ pipeline: general: verbose: true + debug: true input: start: 
"2016-12-21T00:30:00" @@ -28,16 +29,24 @@ pipeline: source_sample_rate: 16000 margin: 30 # TODO set to 900 # seconds offset: 13 # TODO set to 0 # hours - output_path_template: "data/audio/{year}/{month:02}/{filename}" - skip_existing: false - - detection_filter: - highcut: 1500 - lowcut: 50 - order: 10 - frequency_threshold: 0.015 + output_path_template: "data/audio/raw/{year}/{month:02}/{filename}" + skip_existing: false # if true, skip downstream processing of existing audio files (false during development) + + sift: + output_path_template: "data/audio/{sift}/{year}/{month:02}/{filename}" + max_duration: 600 # seconds + plot: true + plot_path_template: "data/plots/{sift}/{year}/{month:02}/{day:02}/{plot_name}.png" window_size: 512 + # Specific sift-mechanism parameters + butterworth: + highcut: 1500 + lowcut: 50 + order: 5 + output: "sos" # "sos" or "ba" + sift_threshold: 0.015 + model: url: https://tfhub.dev/google/humpback_whale/1 model_sample_rate: 10000 diff --git a/src/pipeline/stages/audio.py b/src/pipeline/stages/audio.py index 531c3df..4260596 100644 --- a/src/pipeline/stages/audio.py +++ b/src/pipeline/stages/audio.py @@ -1,5 +1,6 @@ from apache_beam.io import filesystems from datetime import timedelta, datetime +from functools import partial from six.moves.urllib.request import urlopen # pyright: ignore from typing import Tuple @@ -17,6 +18,17 @@ class AudioTask(beam.DoFn): + debug = config.general.debug + filename_template = config.audio.filename_template + output_path_template = config.audio.output_path_template + skip_existing = config.audio.skip_existing + source_sample_rate = config.audio.source_sample_rate + url_template = config.audio.url_template + + def __init__(self): + # init certrain attributes for mockable testing + self.margin = config.audio.margin + self.offset = config.audio.offset def _save_audio(self, audio:np.array, file_path:str): # Write the numpy array to the file as .npy format @@ -39,17 +51,17 @@ def _get_export_path( encounter_ids: list[str], ): - file_path_prefix = "{date}-{start_time}-{end_time}-{ids}".format( + file_path_prefix = "{date}T{start_time}-{end_time}-{ids}".format( date=start.strftime("%Y%m%d"), - start_time=start.strftime("%H%M"), - end_time=end.strftime("%H%M"), + start_time=start.strftime("%H%M%S"), + end_time=end.strftime("%H%M%S"), ids="_".join(encounter_ids) ) # Create a unique file name for each element filename = f"{file_path_prefix}.npy" # f"{file_path_prefix}_{hash(element)}.npy" - file_path = config.audio.output_path_template.format( + file_path = self.output_path_template.format( year=start.year, month=start.month, filename=filename @@ -86,8 +98,7 @@ def process(self, search_results: pd.DataFrame): if self._file_exists_for_input(start_time, end_time, row.encounter_ids): audio_path = self._get_export_path(start_time, end_time, row.encounter_ids) logging.info(f"Audio already exists for {start_time} to {end_time}") - - if config.audio.skip_existing: + if self.skip_existing: logging.info(f"Skipping downstream processing for {audio_path}") continue else: @@ -100,7 +111,7 @@ def process(self, search_results: pd.DataFrame): audio, _ = self._download_audio( start_time, end_time, - config.audio.source_sample_rate, + self.source_sample_rate, ) # Yield the audio and the search_results @@ -115,7 +126,7 @@ def _preprocess(self, df: pd.DataFrame): def _build_time_frames(self, df: pd.DataFrame): - margin = config.audio.margin + margin = self.margin df["startTime"] = pd.to_datetime(df["startDate"].astype(str) + ' ' + 
df["startTime"].astype(str)) df["endTime"] = pd.to_datetime(df["startDate"].astype(str) + ' ' + df["endTime"].astype(str)) @@ -123,9 +134,9 @@ def _build_time_frames(self, df: pd.DataFrame): df["start_time"] = df.startTime - timedelta(seconds=margin) df["end_time"] = df.endTime + timedelta(seconds=margin) - # offset TODO remove this. Only testing for now - df["start_time"] = df["start_time"] - timedelta(hours=config.audio.offset) - df["end_time"] = df["end_time"] - timedelta(hours=config.audio.offset) + # TODO remove this. Only use during development + df["start_time"] = df["start_time"] - timedelta(hours=self.offset) + df["end_time"] = df["end_time"] - timedelta(hours=self.offset) assert pd.api.types.is_datetime64_any_dtype(df["start_time"]) assert pd.api.types.is_datetime64_any_dtype(df["end_time"]) @@ -178,8 +189,8 @@ def _get_file_url( month: int, day: int, ): - filename = str(config.audio.filename_template).format(year=year, month=month, day=day) - url = config.audio.url_template.format(year=year, month=month, day=day, filename=filename) + filename = str(self.filename_template).format(year=year, month=month, day=day) + url = self.url_template.format(year=year, month=month, day=day, filename=filename) return url @@ -263,3 +274,14 @@ def process(self, element): logging.info(f"Audio shape: {array.shape}") yield self._save_audio(array, file_path) + + +class WriteSiftedAudio(WriteAudio): + output_path_template = config.sift.output_path_template + + def __init__(self, sift="sift"): + super().__init__() + self.output_path_template = self.output_path_template.replace("{sift}", sift) + logging.info(f"Sifted output path template: {self.output_path_template}") + + # everything is used from WriteAudio diff --git a/src/pipeline/stages/sift.py b/src/pipeline/stages/sift.py new file mode 100644 index 0000000..7cfbf67 --- /dev/null +++ b/src/pipeline/stages/sift.py @@ -0,0 +1,294 @@ +from apache_beam.io import filesystems +from datetime import datetime +from scipy.signal import butter, lfilter, find_peaks, sosfilt + +import apache_beam as beam +import io +import logging +import math +import matplotlib.pyplot as plt +import numpy as np +import os + +from config import load_pipeline_config + + +config = load_pipeline_config() + + +class BaseSift(beam.PTransform): + """ + All our Sift transforms will have the structure: + - input: full audio signal + - output: timeframes + """ + name = "BaseSift" + + # general params + debug = config.general.debug + sample_rate = config.audio.source_sample_rate + + # sift-specific params + max_duration = config.sift.max_duration + window_size = config.sift.window_size + + # plot params + plot = config.sift.plot + plot_path_template = config.sift.plot_path_template + + def _build_key( + self, + start_time: datetime, + end_time: datetime, + encounter_ids: list, + ): + start_str = start_time.strftime('%Y%m%dT%H%M%S') + end_str = end_time.strftime('%H%M%S') + encounter_str = "_".join(encounter_ids) + return f"{start_str}-{end_str}_{encounter_str}" + + def _preprocess(self, pcoll): + """ + pcoll: tuple(audio, start_time, end_time, row.encounter_ids) + """ + signal, start, end, encounter_ids = pcoll + key = self._build_key(start, end, encounter_ids) + + max_samples = self.max_duration * self.sample_rate + + if signal.shape[0] > max_samples: + logging.debug(f"Signal size exceeds max sample size {max_samples}.") + split_indices = [max_samples*(i+1) for i in range(math.floor(signal.shape[0] / max_samples))] + signal_batches = np.array_split(signal, split_indices) + 
logging.debug(f"Split signal into {len(signal_batches)} batches of size {max_samples}.") + logging.debug(f"Size fo final batch {len(signal_batches[1])}") + + for batch in signal_batches: + yield (key, batch) + else: + yield (key, signal) + + def _postprocess(self, pcoll, min_max_detections): + signal, start, end, encounter_ids = pcoll + key = self._build_key(start, end, encounter_ids) + + logging.info(f"Postprocessing {self.name} sifted signal.") + logging.debug(f"Signal: {signal}") + logging.debug(f"Min-max detections: {min_max_detections}") + logging.debug(f"Key: {key}") + + global_detection_range = [ + min_max_detections[key]["min"], + min_max_detections[key]["max"] + ] + + return signal[global_detection_range[0]:global_detection_range[-1]], start, end, encounter_ids + + def _plot_signal_detections(self, pcoll, min_max_detections, all_detections, params=None): + signal, start, end, encounter_ids = pcoll + key = self._build_key(start, end, encounter_ids) + + min_max_detection_samples = [ + min_max_detections[key]["min"], # maually fix ordering + min_max_detections[key]["max"] + ] + logging.info(f"Plotting signal detections: {min_max_detection_samples}") + + # datetime format matches original audio file name + plot_path = self.plot_path_template.format( + sift=self.name, + year=start.year, + month=start.month, + day=start.day, + plot_name=key + ) + + # TODO make dirs cleaner + if not os.path.isdir(os.path.sep.join(plot_path.split(os.path.sep)[:-1])): + os.makedirs(os.path.sep.join(plot_path.split(os.path.sep)[:-1])) + + # normalize and center + signal = signal / np.max(signal) # normalize + signal = signal - np.mean(signal) # center + + plt.figure(figsize=(20, 10)) + plt.plot(signal) + + # NOTE: for legend logic, plot min_max window first + if len(min_max_detection_samples): + # shade window that resulted in detection + plt.axvspan( + min_max_detection_samples[0], + min_max_detection_samples[-1], + alpha=0.3, + color='y', + zorder=7, # on top of all detections + ) + + if len(all_detections[key]): + # shade window that resulted in detection + for detection in all_detections[key]: + plt.axvspan( + detection - self.window_size/2, + detection + self.window_size/2, + alpha=0.5, + color='r', + zorder=5, # on top of signal + ) + + + plt.legend(['Input signal', 'detection window', 'all detections']).set_zorder(10) + plt.xlabel(f'Samples (seconds * {self.sample_rate} Hz)') + plt.ylabel('Amplitude (normalized and centered)') + + title = f"({self.name}) Signal detections: {start.strftime('%Y-%m-%d %H:%M:%S')}-{end.strftime('%H:%M:%S')}\n" + title += f"Params: {params} \n" if params else "" + title += f"Encounters: {encounter_ids}" + plt.title(title) + plt.savefig(plot_path) + plt.show() + + +class Butterworth(BaseSift): + name = "Butterworth" + + def __init__( + self, + lowcut: int = None, + highcut: int = None, + order: int = None, + output: str = None, + sift_threshold: float = None, + ): + super().__init__() + + # define bandpass + self.lowcut = config.sift.butterworth.lowcut if not lowcut else lowcut + self.highcut = config.sift.butterworth.highcut if not highcut else highcut + self.order = config.sift.butterworth.order if not order else order + self.output = config.sift.butterworth.output if not output else output + + # apply bandpass + self.sift_threshold = ( + config.sift.butterworth.sift_threshold + if not sift_threshold else sift_threshold + ) + + + def expand(self, pcoll): + """ + pcoll: tuple(audio, start_time, end_time, row.encounter_ids) + """ + # full-input preprocess + batch = pcoll | 
"Preprocess" >> beam.ParDo(self._preprocess) + + # batched process + detections = batch | "Sift Frequency" >> beam.ParDo(self._frequency_filter_sift) + min_max_detections = detections | "Min-Max Detections" >> beam.CombinePerKey(MinMaxCombine()) + all_detections = detections | "List Detections" >> beam.CombinePerKey(ListCombine()) + + # full-input postprocess + sifted_output = pcoll | "Postprocess" >> beam.Map( + self._postprocess, + min_max_detections=beam.pvalue.AsDict(min_max_detections), + ) + + # plots for debugging purposes + if self.debug or self.plot: + pcoll | "Plot Sifted Output" >> beam.Map( + self._plot_signal_detections, + min_max_detections=beam.pvalue.AsDict(min_max_detections), + all_detections=beam.pvalue.AsDict(all_detections), + params={ + "lowcut": self.lowcut, + "highcut": self.highcut, + "order": self.order, + "threshold": self.sift_threshold, + } + ) + + return sifted_output + + def _butter_bandpass(self): + """ + Returns specific Butterworth filter (IIR) from parameters in config.sift.butterworth + """ + nyq = 0.5 * self.sample_rate + low = self.lowcut / nyq + high = self.highcut / nyq + return butter(self.order, [low, high], btype='band', output=self.output) + + def _frequency_filter_sift( + self, + batch: tuple, + ): + key, signal = batch + + logging.info(f"Start frequency detection on (key, signal): {(key, signal.shape)}") + + # Apply bandpass filter + butter_coeffients = self._butter_bandpass() + if self.output == "ba": + filtered_signal = lfilter(butter_coeffients[0], butter_coeffients[1], signal) + elif self.output == "sos": + filtered_signal = sosfilt(butter_coeffients, signal) + else: + raise ValueError(f"Invalid output type for config.sift.butterworth.output: {self.output}") + logging.debug(f"Filtered signal: {filtered_signal}") + + # Calculate energy in windows + energy = np.array([ + sum(abs(filtered_signal[i:i+self.window_size]**2)) + for i in range(0, len(filtered_signal), self.window_size) + ]) + logging.debug(f"Energy: {energy}") + + # Normalize energy + energy = energy / np.max(energy) + logging.debug(f"Normalized energy: {energy}") + + # Find peaks above threshold + peaks, _ = find_peaks(energy, height=self.sift_threshold) + logging.debug(f"Peaks: {peaks}") + + # Convert peak indices to time + peak_samples = peaks * (self.window_size) + logging.debug(f"Peak samples: {peak_samples}") + + yield (key, peak_samples) + + +class ListCombine(beam.CombineFn): + name = "ListCombine" + + def create_accumulator(self): + return [] + + def add_input(self, accumulator, input): + """ + Key is not available in this method, + though inputs are only added to accumulator under correct key. 
+ """ + logging.debug(f"Adding input {input} to {self.name} accumulator.") + if isinstance(input, np.ndarray): + input = input.tolist() + accumulator += input + return accumulator + + def merge_accumulators(self, accumulators): + return [item for sublist in accumulators for item in sublist] + + def extract_output(self, accumulator): + return accumulator + + +class MinMaxCombine(ListCombine): + name = "MinMaxCombine" + + def extract_output(self, accumulator): + if not accumulator: + return None + + logging.debug(f"Extracting min-max output from {accumulator}") + return {"min": min(accumulator), "max": max(accumulator)} + diff --git a/tests/data/20161221T004930-005030-9182.npy b/tests/data/20161221T004930-005030-9182.npy new file mode 100644 index 0000000..8d1b77d Binary files /dev/null and b/tests/data/20161221T004930-005030-9182.npy differ diff --git a/tests/test_audio.py b/tests/test_audio.py index 79ab7e6..b401e1d 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -31,7 +31,6 @@ def sample_time_frame_df(): return pd.DataFrame(data) - @patch('stages.audio.config') def test_build_time_frames(mock_config, sample_search_results_df): # Assemble @@ -78,8 +77,8 @@ def test_find_overlapping(sample_time_frame_df): def test_preprocess(mock_config, sample_search_results_df): # Assemble # mock config values to avoid failing tests on config changes - mock_config.audio.margin = 5 * 60 # 5 minutes - mock_config.audio.offset = 1 # 1 hour # TODO remove. only used for development + mock_config.audio.margin = 5 * 60 # 5 minutes + mock_config.audio.offset = 1 # 1 hour # TODO remove. only used for development expected_df = pd.DataFrame({ 'encounter_ids':[["1"], ["3", "2"]], diff --git a/tests/test_config.py b/tests/test_config.py index 5f89fd9..cfce822 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,7 +7,7 @@ def test_load_pipeline_config(): This test checks that all expected params are present in the pipeline options, via set reduction set(a) - set(b) == 0. 
""" - expected_keys = ['general', 'input', 'search', 'audio', 'detection_filter', 'model', 'postprocess'] + expected_keys = ['general', 'input', 'search', 'audio', 'sift', 'model', 'postprocess'] actual_config = load_pipeline_config().__dict__ actual_keys = actual_config.keys() diff --git a/tests/test_sift.py b/tests/test_sift.py new file mode 100644 index 0000000..4d40923 --- /dev/null +++ b/tests/test_sift.py @@ -0,0 +1,139 @@ +from datetime import datetime + +import numpy as np +import pandas as pd +import pytest + +from stages.sift import BaseSift, Butterworth +from unittest.mock import patch + + +@pytest.fixture +def sample_audio_results_df(): + data = { + 'audio': [np.array([0, 1, 2, 3, 2, 1]), np.array([1, 2, 3, 2]), np.array([0, 1, 2, 3, 4, 5, 4, 3, 2, 1]*16_500)], + 'start_time': [datetime(2024, 7, 7, 23, 8, 0), datetime(2024, 9, 10, 0, 32, 0), datetime(2024, 9, 10, 0, 27, 0)], + 'end_time': [datetime(2024, 7, 7, 23, 18, 0), datetime(2024, 9, 10, 0, 42, 0), datetime(2024, 9, 10, 0, 47, 0)], + 'encounter_ids':[["1"], ["3", "2"], ["2"]] + } + return pd.DataFrame(data) + +@pytest.fixture +def sample_audio_results_row(): + audio = np.array([0, 1, 2, 3, 4, 5, 4, 3, 2, 1]*16_500*60) # bigger than max_duration + start_time = datetime(2024, 7, 7, 23, 8, 0) + end_time = datetime(2024, 7, 7, 23, 18, 0) + encounter_ids = ["encounter1", "encounter2"] + yield audio, start_time, end_time, encounter_ids + + +@pytest.fixture +def sample_batch(): + """ + Cherry picked example data with a sginal that has butterowrth detection w/ params: + - lowcut = 50 + - highcut = 1500 + - order = 2 + + """ + signal = np.load("tests/data/20161221T004930-005030-9182.npy") + key = "20161221T004930-005030-9182" + return (key, signal) + + +def test_build_key(sample_audio_results_row): + # Assemble + _, start, end, encounter_ids = sample_audio_results_row + + expected_key = "20240707T230800-231800_encounter1_encounter2" + + # Act + actual_key = BaseSift()._build_key(start, end, encounter_ids) + + # Assert + assert expected_key == actual_key + + +def test_preprocess(sample_audio_results_row): + # Assemble + expected_data = [ + ("20240707T230800-231800_encounter1_encounter2", np.array([0, 1, 2, 3, 4, 5, 4, 3, 2, 1]*16_000*60)), + ("20240707T230800-231800_encounter1_encounter2", np.array([0, 1, 2, 3, 4, 5, 4, 3, 2, 1]*500*60)), + ] + + # Act + actual_data_generator = Butterworth()._preprocess(sample_audio_results_row) + + # Assert + for expected in expected_data: + actual = next(actual_data_generator) # unload generator + + assert expected[0] == actual[0] # key + assert expected[1].shape == actual[1].shape # data + + +def test_postprocess(sample_audio_results_row): + # Assemble + pcoll = sample_audio_results_row + min_max_detections = { + "20240707T230800-231800_encounter1_encounter2": {"min":0, "max": 5} # samples + } + + expected_data = ( + np.array([0, 1, 2, 3, 4]), # audio + datetime(2024, 7, 7, 23, 8, 0), # start_time + datetime(2024, 7, 7, 23, 18, 0), # end_time + ["encounter1", "encounter2"] # encounter_ids + ) + + # Act + actual_data = Butterworth()._postprocess(pcoll, min_max_detections) + + # Assert + assert len(expected_data) == len(actual_data) + for expected, actual in zip(expected_data, actual_data): + assert np.array_equal(expected, actual) + + +@patch('stages.sift.config') +def test_butter_bandpass(mock_config): + # Assemble + mock_config.sift.butterworth.lowcut = 50 + mock_config.sift.butterworth.highcut = 1500 + mock_config.sift.butterworth.order = 2 + mock_config.sift.butterworth.output = "sos" + + 
expected_coefficients = [ + [ 0.05711683, 0.11423366, 0.05711683, 1. , -1.22806805, 0.4605427 ], + [ 1. , -2. , 1. , 1. , -1.97233136, 0.97273604] + ] + + # Act + actual_coefficients = Butterworth()._butter_bandpass() + + # Assert + assert len(actual_coefficients) == 2 # order + assert np.allclose(expected_coefficients, actual_coefficients) + + +@patch('stages.sift.config') +def test_frequency_filter_sift(mock_config, sample_batch): + # Assemble + mock_config.sift.butterworth.lowcut = 50 + mock_config.sift.butterworth.highcut = 1500 + mock_config.sift.butterworth.order = 5 + mock_config.sift.butterworth.output = "sos" + mock_config.sift.butterworth.sift_threshold = 0.015 + + expected_detections = ( + "20161221T004930-005030-9182", + np.array([13824]) + ) + + # Act + actual_detections_generator = Butterworth()._frequency_filter_sift(sample_batch) + actual_detections = next(actual_detections_generator) + + + # Assert + assert expected_detections[0] == actual_detections[0] + assert np.allclose(expected_detections[1], actual_detections[1])