Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add post-process labels stage #4

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions data/output.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{'encounter_id': '11486', 'latitude': 36.91, 'longitude': -122.02, 'displayImgUrl': 'https://au-hw-media-m.happywhale.com/c5522187-058e-4a1a-83d7-893560ba6b2c.jpg', 'audio': array([], dtype=float32), 'start': Timestamp('2016-12-21 00:20:30'), 'end': Timestamp('2016-12-21 00:21:30'), 'classifications': []}
{'encounter_id': '9182', 'latitude': 36.91, 'longitude': -122.02, 'displayImgUrl': 'https://au-hw-media-m.happywhale.com/d40b9e6e-07cf-4f20-8cb4-4042ba22a00b.jpg', 'audio': array([-0.00352275, -0.00346267, -0.00334585, ..., -0.00339496,
-0.00333035, -0.00329852], dtype=float32), 'start': Timestamp('2016-12-21 00:49:30'), 'end': Timestamp('2016-12-21 00:50:30'), 'classifications': [[0.8753612041473389], [0.746759295463562], [0.26265254616737366], [0.45787951350212097], [0.35406064987182617], [0.42348742485046387], [0.4947870969772339], [0.7287474274635315], [0.7099379897117615], [0.2122703194618225], [0.044488538056612015], [0.00849922839552164], [0.024390267208218575], [0.33750119805336], [0.6530888080596924], [0.3057247996330261], [0.1243574470281601], [0.027093390002846718], [0.011367958970367908], [0.004032353404909372], [0.026372192427515984], [0.021978065371513367], [0.006407670211046934], [0.5405446887016296], [0.34207114577293396], [0.6080849766731262], [0.5394770503044128], [0.3662146031856537], [0.16772609949111938], [0.3641503155231476], [0.060217034071683884], [0.008764371275901794], [0.012523961253464222], [0.009186000563204288], [0.022050702944397926], [0.3908870816230774], [0.15179167687892914], [0.3454047441482544], [0.4770602285861969], [0.07589100301265717], [0.5439115166664124], [0.8634722232818604], [0.985602617263794], [0.3311924636363983], [0.8832067847251892], [0.6166273951530457], [0.42301759123802185], [0.03573732450604439], [0.09752023965120316], [0.01426385436207056], [0.022987568750977516], [0.012294118292629719], [0.010207954794168472], [0.00296270614489913]]}
36 changes: 17 additions & 19 deletions src/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from stages.audio import RetrieveAudio, WriteAudio, WriteSiftedAudio
from stages.sift import Butterworth
from stages.classify import WhaleClassifier, WriteClassifications
from stages.postprocess import PostprocessLabels


from config import load_pipeline_config
config = load_pipeline_config()
Expand All @@ -19,29 +21,25 @@ def run():
}

with beam.Pipeline(options=pipeline_options) as p:
input_data = p | "Create Input" >> beam.Create([args])
search_output = input_data | "Run Geometry Search" >> beam.ParDo(GeometrySearch())

audio_output = search_output | "Retrieve Audio" >> beam.ParDo(RetrieveAudio())
audio_output | "Store Audio (temp)" >> beam.ParDo(WriteAudio())

sifted_audio = audio_output | "Sift Audio" >> Butterworth()
sifted_audio | "Store Sifted Audio" >> beam.ParDo(WriteSiftedAudio("butterworth"))

classifications = sifted_audio | "Classify Audio" >> WhaleClassifier(config)
classifications | "Store Classifications" >> beam.ParDo(WriteClassifications(config))


# # Post-process the labels
# postprocessed_labels = classified_audio | "Postprocess Labels" >> PostprocessLabels()
input_data = p | "Create Input" >> beam.Create([args])
search_output = input_data | "Run Geometry Search" >> beam.ParDo(GeometrySearch())
audio_output = search_output | "Retrieve Audio" >> beam.ParDo(RetrieveAudio())
sifted_audio = audio_output | "Sift Audio" >> Butterworth()
classifications = sifted_audio | "Classify Audio" >> WhaleClassifier(config)
postprocess_labels = classifications | "Postprocess Labels" >> beam.ParDo(
PostprocessLabels(config),
search_output=beam.pvalue.AsSingleton(search_output),
)

# Store results
audio_output | "Store Audio (temp)" >> beam.ParDo(WriteAudio())
sifted_audio | "Store Sifted Audio" >> beam.ParDo(WriteSiftedAudio("butterworth"))
classifications | "Store Classifications" >> beam.ParDo(WriteClassifications(config))
postprocess_labels | "Write Results" >> beam.io.WriteToText("data/output.txt", shard_name_template="")

# Output results
# postprocessed_labels | "Write Results" >> beam.io.WriteToText("output.txt")

# For debugging, you can write the output to a text file
# audio_files | "Write Audio Output" >> beam.io.WriteToText('audio_files.txt')
# search_results | "Write Search Output" >> beam.io.WriteToText('search_results.txt')


if __name__ == "__main__":
run()
59 changes: 59 additions & 0 deletions src/stages/postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import apache_beam as beam

from datetime import datetime
from typing import Dict, Any, Tuple
from types import SimpleNamespace
from matplotlib import gridspec

import librosa
import logging
import numpy as np
import os
import time
import pandas as pd

import requests
import math
import matplotlib.pyplot as plt
import scipy.signal


class PostprocessLabels(beam.DoFn):
def __init__(self, config: SimpleNamespace):
self.config = config

self.search_output_path_template = config.search.export_template
self.sifted_audio_path_template = config.sift.output_path_template
self.classification_path = config.classify.classification_path


def process(self, element: Dict[str, Any], search_output: Dict[str, Any]):
logging.info(f"element \n{element}")
logging.info(f"search_output \n{search_output}")
breakpoint()

classifications_df = pd.DataFrame([element], columns=["audio", "start", "end", "encounter_ids", "classifications"])
classifications_df = classifications_df.explode("encounter_ids").rename(columns={"encounter_ids": "encounter_id"})
classifications_df["encounter_id"] = classifications_df["encounter_id"].astype(str)

# TODO pool classifications


search_output = search_output.rename(columns={"id": "encounter_id"})
search_output["encounter_id"] = search_output["encounter_id"].astype(str) # TODO do in one line
search_output = search_output[[
# TODO refactor to confing
"encounter_id",
"latitude",
"longitude",
"displayImgUrl",
# "species", # TODO add in geo search stage (require rm local file)
]]

# join dataframes
joined_df = pd.merge(search_output, classifications_df, how="inner", on="encounter_id")

logging.info(f"Final output: \n{joined_df.head()}")


return joined_df.to_dict(orient="records")
Loading