diff --git a/birdnet_analyzer/analyze.py b/birdnet_analyzer/analyze.py index de2f037d..4e3ad327 100644 --- a/birdnet_analyzer/analyze.py +++ b/birdnet_analyzer/analyze.py @@ -65,7 +65,19 @@ def loadCodes(): return codes -def generate_raven_table(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str) -> str: +def generate_raven_table(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str): + """ + Generates a Raven selection table from the given timestamps and prediction results. + + Args: + timestamps (list[str]): List of timestamp strings in the format "start-end". + result (dict[str, list]): Dictionary where keys are timestamp strings and values are lists of predictions. + afile_path (str): Path to the audio file being analyzed. + result_path (str): Path where the resulting Raven selection table will be saved. + + Returns: + None + """ selection_id = 0 out_string = RAVEN_TABLE_HEADER @@ -104,7 +116,19 @@ def generate_raven_table(timestamps: list[str], result: dict[str, list], afile_p utils.save_result_file(result_path, out_string) -def generate_audacity(timestamps: list[str], result: dict[str, list], result_path: str) -> str: +def generate_audacity(timestamps: list[str], result: dict[str, list], result_path: str): + """ + Generates an Audacity timeline label file from the given timestamps and results. + + Args: + timestamps (list[str]): A list of timestamp strings. + result (dict[str, list]): A dictionary where keys are timestamps and values are lists of tuples, + each containing a label and a confidence score. + result_path (str): The file path where the result string will be saved. 
+ + Returns: + None + """ out_string = "" # Audacity timeline labels @@ -124,7 +148,20 @@ def generate_audacity(timestamps: list[str], result: dict[str, list], result_pat utils.save_result_file(result_path, out_string) -def generate_rtable(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str) -> str: +def generate_rtable(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str): + """ + Generates an R table string from the given timestamps and result data, and saves it to a file. + + Args: + timestamps (list[str]): A list of timestamp strings in the format "start-end". + result (dict[str, list]): A dictionary where keys are timestamp strings and values are lists of tuples containing + classification results (label, confidence). + afile_path (str): The path to the audio file being analyzed. + result_path (str): The path where the result table file will be saved. + + Returns: + None + """ out_string = RTABLE_HEADER for timestamp in timestamps: @@ -157,7 +194,20 @@ def generate_rtable(timestamps: list[str], result: dict[str, list], afile_path: utils.save_result_file(result_path, out_string) -def generate_kaleidoscope(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str) -> str: +def generate_kaleidoscope(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str): + """ + Generates a Kaleidoscope-compatible CSV string from the given timestamps and results, and saves it to a file. + + Args: + timestamps (list[str]): List of timestamp strings in the format "start-end". + result (dict[str, list]): Dictionary where keys are timestamp strings and values are lists of tuples containing + species label and confidence score. + afile_path (str): Path to the audio file being analyzed. + result_path (str): Path where the resulting CSV file will be saved. 
+ + Returns: + None + """ out_string = KALEIDOSCOPE_HEADER folder_path, filename = os.path.split(afile_path) @@ -192,7 +242,20 @@ def generate_kaleidoscope(timestamps: list[str], result: dict[str, list], afile_ utils.save_result_file(result_path, out_string) -def generate_csv(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str) -> str: +def generate_csv(timestamps: list[str], result: dict[str, list], afile_path: str, result_path: str): + """ + Generates a CSV file from the given timestamps and results. + + Args: + timestamps (list[str]): A list of timestamp strings in the format "start-end". + result (dict[str, list]): A dictionary where keys are timestamp strings and values are lists of tuples. + Each tuple contains a label and a confidence score. + afile_path (str): The file path of the audio file being analyzed. + result_path (str): The file path where the resulting CSV file will be saved. + + Returns: + None + """ out_string = CSV_HEADER for timestamp in timestamps: @@ -212,12 +275,16 @@ def generate_csv(timestamps: list[str], result: dict[str, list], afile_path: str def saveResultFiles(r: dict[str, list], result_files: dict[str, str], afile_path: str): - """Saves the results to the hard drive. + """ + Saves the result files in various formats based on the provided configuration. Args: - r: The dictionary with {segment: scores}. - path: The path where the result should be saved. - afile_path: The path to audio file. + r (dict[str, list]): A dictionary containing the analysis results with timestamps as keys. + result_files (dict[str, str]): A dictionary mapping result types to their respective file paths. + afile_path (str): The path to the audio file being analyzed. 
+ + Returns: + None """ os.makedirs(cfg.OUTPUT_PATH, exist_ok=True) @@ -242,6 +309,15 @@ def saveResultFiles(r: dict[str, list], result_files: dict[str, str], afile_path def combine_raven_tables(saved_results: list[str]): + """ + Combines multiple Raven selection table files into a single file and adjusts the selection IDs and times. + + Args: + saved_results (list[str]): List of file paths to the Raven selection table files to be combined. + + Returns: + None + """ # Combine all files s_id = 1 time_offset = 0 @@ -303,6 +379,15 @@ def combine_raven_tables(saved_results: list[str]): def combine_rtable_files(saved_results: list[str]): + """ + Combines multiple R table files into a single file. + + Args: + saved_results (list[str]): A list of file paths to the result table files to be combined. + + Returns: + None + """ # Combine all files with open(os.path.join(cfg.OUTPUT_PATH, cfg.OUTPUT_RTABLE_FILENAME), "w", encoding="utf-8") as f: f.write(RTABLE_HEADER) @@ -326,6 +411,15 @@ def combine_rtable_files(saved_results: list[str]): def combine_kaleidoscope_files(saved_results: list[str]): + """ + Combines multiple Kaleidoscope result files into a single file. + + Args: + saved_results (list[str]): A list of file paths to the saved Kaleidoscope result files. + + Returns: + None + """ # Combine all files with open(os.path.join(cfg.OUTPUT_PATH, cfg.OUTPUT_KALEIDOSCOPE_FILENAME), "w", encoding="utf-8") as f: f.write(KALEIDOSCOPE_HEADER) @@ -349,6 +443,12 @@ def combine_kaleidoscope_files(saved_results: list[str]): def combine_csv_files(saved_results: list[str]): + """ + Combines multiple CSV files into a single CSV file. + + Args: + saved_results (list[str]): A list of file paths to the CSV files to be combined. 
+ """ # Combine all files with open(os.path.join(cfg.OUTPUT_PATH, cfg.OUTPUT_CSV_FILENAME), "w", encoding="utf-8") as f: f.write(CSV_HEADER) @@ -372,6 +472,19 @@ def combine_csv_files(saved_results: list[str]): def combineResults(saved_results: list[dict[str, str]]): + """ + Combines various types of result files based on the configuration settings. + This function checks the types of results specified in the configuration + and combines the corresponding files from the saved results list. + + Args: + saved_results (list[dict[str, str]]): A list of dictionaries containing + file paths for different result types. Each dictionary represents + a set of result files for a particular analysis. + + Returns: + None + """ if "table" in cfg.RESULT_TYPES: combine_raven_tables([f["table"] for f in saved_results if f]) @@ -398,9 +511,7 @@ def getSortedTimestamps(results: dict[str, list]): def getRawAudioFromFile(fpath: str, offset, duration): - """Reads an audio file. - - Reads the file and splits the signal into chunks. + """Reads an audio file and splits the signal into chunks. Args: fpath: Path to the audio file. @@ -438,6 +549,16 @@ def predict(samples): def get_result_file_names(fpath: str): + """ + Generates a dictionary of result file names based on the input file path and configured result types. + + Args: + fpath (str): The file path of the input file. + + Returns: + dict: A dictionary where the keys are result types (e.g., "table", "audacity", "r", "kaleidoscope", "csv") + and the values are the corresponding output file paths. + """ result_names = {} rpath = fpath.replace(cfg.INPUT_PATH, "") @@ -466,15 +587,17 @@ def get_result_file_names(fpath: str): def analyzeFile(item): - """Analyzes a file. - - Predicts the scores for the file and saves the results. + """ + Analyzes an audio file and generates prediction results. Args: - item: Tuple containing (file path, config) + item (tuple): A tuple containing the file path (str) and configuration settings. 
Returns: - The `True` if the file was analyzed successfully. + dict or None: A dictionary of result file names if analysis is successful, + None if the file is skipped or an error occurs. + Raises: + Exception: If there is an error in reading the audio file or saving the results. """ # Get file path and restore cfg fpath: str = item[0] diff --git a/birdnet_analyzer/audio.py b/birdnet_analyzer/audio.py index 0330af80..23713a3d 100644 --- a/birdnet_analyzer/audio.py +++ b/birdnet_analyzer/audio.py @@ -37,12 +37,31 @@ def openAudioFile(path: str, sample_rate=48000, offset=0.0, duration=None, fmin= def getAudioFileLength(path, sample_rate=48000): + """ + Get the length of an audio file in seconds. + + Args: + path (str): The file path to the audio file. + sample_rate (int, optional): The sample rate to use for reading the audio file. Default is 48000. + + Returns: + float: The duration of the audio file in seconds. + """ # Open file with librosa (uses ffmpeg or libav) return librosa.get_duration(filename=path, sr=sample_rate) def get_sample_rate(path: str): + """ + Get the sample rate of an audio file. + + Args: + path (str): The file path to the audio file. + + Returns: + int: The sample rate of the audio file. + """ return librosa.get_samplerate(path) @@ -52,15 +71,16 @@ def saveSignal(sig, fname: str): Args: sig: The signal to be saved. fname: The file path. + + Returns: + None """ sf.write(fname, sig, 48000, "PCM_16") def pad(sig, seconds, srate, amount=None): - """Creates noise. - - Creates a noise vector with the given shape. + """Creates a noise vector with the given shape. Args: sig: The original audio signal. @@ -165,6 +185,9 @@ def cropCenter(sig, rate, seconds): sig: The original signal. rate: The sampling rate. seconds: The length of the signal. + + Returns: + The cropped signal. 
""" if len(sig) > int(seconds * rate): start = int((len(sig) - int(seconds * rate)) / 2) @@ -179,6 +202,19 @@ def cropCenter(sig, rate, seconds): def bandpass(sig, rate, fmin, fmax, order=5): + """ + Apply a bandpass filter to the input signal. + + Args: + sig (numpy.ndarray): The input signal to be filtered. + rate (int): The sampling rate of the input signal. + fmin (float): The minimum frequency for the bandpass filter. + fmax (float): The maximum frequency for the bandpass filter. + order (int, optional): The order of the filter. Default is 5. + + Returns: + numpy.ndarray: The filtered signal as a float32 array. + """ # Check if we have to bandpass at all if fmin == cfg.SIG_FMIN and fmax == cfg.SIG_FMAX or fmin > fmax: return sig @@ -216,6 +252,18 @@ def bandpass(sig, rate, fmin, fmax, order=5): # For a complete description of this method, see Discrete-Time Signal Processing # (Second Edition), by Alan Oppenheim, Ronald Schafer, and John Buck, Prentice Hall 1998, pp. 474-476. def bandpassKaiserFIR(sig, rate, fmin, fmax, width=0.02, stopband_attenuation_db=100): + """ + Applies a bandpass filter to the given signal using a Kaiser window FIR filter. + Args: + sig (numpy.ndarray): The input signal to be filtered. + rate (int): The sample rate of the input signal. + fmin (float): The minimum frequency of the bandpass filter. + fmax (float): The maximum frequency of the bandpass filter. + width (float, optional): The transition width of the filter. Default is 0.02. + stopband_attenuation_db (float, optional): The desired attenuation in the stopband, in decibels. Default is 100. + Returns: + numpy.ndarray: The filtered signal as a float32 numpy array. 
+ """ # Check if we have to bandpass at all if fmin == cfg.SIG_FMIN and fmax == cfg.SIG_FMAX or fmin > fmax: return sig diff --git a/birdnet_analyzer/embeddings.py b/birdnet_analyzer/embeddings.py index 720c7c99..ae0a4813 100644 --- a/birdnet_analyzer/embeddings.py +++ b/birdnet_analyzer/embeddings.py @@ -17,6 +17,12 @@ def writeErrorLog(msg): + """ + Appends an error message to the error log file. + + Args: + msg (str): The error message to be logged. + """ with open(cfg.ERROR_LOG_FILE, "a") as elog: elog.write(msg + "\n") diff --git a/birdnet_analyzer/gui/analysis.py b/birdnet_analyzer/gui/analysis.py index 41c71725..96cf8196 100644 --- a/birdnet_analyzer/gui/analysis.py +++ b/birdnet_analyzer/gui/analysis.py @@ -17,6 +17,18 @@ def analyzeFile_wrapper(entry): + """ + Wrapper function for analyzing a file. + + Args: + entry (tuple): A tuple whose first element is the file path. The whole + tuple (including the path) is passed unchanged to the + analyze.analyzeFile function. + + Returns: + tuple: A tuple where the first element is the file path and the second + element is the result of the analyze.analyzeFile function. + """ return (entry[0], analyze.analyzeFile(entry)) diff --git a/birdnet_analyzer/gui/utils.py b/birdnet_analyzer/gui/utils.py index ebb8e24d..dc12ba96 100644 --- a/birdnet_analyzer/gui/utils.py +++ b/birdnet_analyzer/gui/utils.py @@ -48,6 +48,17 @@ # Nishant - Following two functions (select_folder andget_files_and_durations) are written for Folder selection def select_folder(state_key=None): + """ + Opens a folder selection dialog and returns the selected folder path. + On Windows, it uses tkinter's filedialog to open the folder selection dialog. + On other platforms, it uses webview's FOLDER_DIALOG to open the folder selection dialog. + If a state_key is provided, the initial directory for the dialog is retrieved from the state. + If a folder is selected and a state_key is provided, the selected folder path is saved to the state. 
+ Args: + state_key (str, optional): The key to retrieve and save the folder path in the state. Defaults to None. + Returns: + str: The path of the selected folder, or None if no folder was selected. + """ if sys.platform == "win32": from tkinter import Tk, filedialog @@ -70,6 +81,14 @@ def select_folder(state_key=None): def get_files_and_durations(folder, max_files=None): + """ + Collects audio files from a specified folder and retrieves their durations. + Args: + folder (str): The path to the folder containing audio files. + max_files (int, optional): The maximum number of files to collect. If None, all files are collected. + Returns: + list: A list of lists, where each inner list contains the relative file path and its duration as a string. + """ files_and_durations = [] files = utils.collect_audio_files(folder, max_files=max_files) # Use the collect_audio_files function @@ -85,6 +104,12 @@ def get_files_and_durations(folder, max_files=None): def set_window(window): + """ + Sets the global _WINDOW variable to the provided window object. + + Args: + window: The window object to be set as the global _WINDOW. + """ global _WINDOW _WINDOW = window @@ -148,8 +173,6 @@ def select_directory(collect_files=True, max_files=None, state_key=None): def build_header(): - # Custom HTML header with gr.Markdown - # There has to be another way, but this works for now; paths are weird in gradio with gr.Row(): gr.Markdown( f""" @@ -489,6 +512,11 @@ def _get_win_drives(): def open_window(builder: list[Callable] | Callable): + """ + Opens a GUI window using the Gradio library and the webview module. + Args: + builder (list[Callable] | Callable): A callable or a list of callables that build the GUI components. 
+ """ multiprocessing.freeze_support() with gr.Blocks( diff --git a/birdnet_analyzer/localization.py b/birdnet_analyzer/localization.py index 56ae57c4..2e890457 100644 --- a/birdnet_analyzer/localization.py +++ b/birdnet_analyzer/localization.py @@ -13,6 +13,12 @@ def ensure_settings_file(): + """ + Ensures that the settings file exists at the specified path. If the file does not exist, + it creates a new settings file with default settings. + + If the file creation fails, the error is logged. + """ if not os.path.exists(GUI_SETTINGS_PATH): try: with open(GUI_SETTINGS_PATH, "w") as f: @@ -23,6 +29,15 @@ def ensure_settings_file(): def get_state_dict() -> dict: + """ + Retrieves the state dictionary from a JSON file specified by STATE_SETTINGS_PATH. + + If the file does not exist, it creates an empty JSON file and returns an empty dictionary. + If any other exception occurs during file operations, it logs the error and returns an empty dictionary. + + Returns: + dict: The state dictionary loaded from the JSON file, or an empty dictionary if the file does not exist or an error occurs. + """ try: with open(STATE_SETTINGS_PATH, "r", encoding="utf-8") as f: return json.load(f) @@ -37,10 +52,27 @@ def get_state_dict() -> dict: def get_state(key: str, default=None) -> str: + """ + Retrieves the value associated with the given key from the state dictionary. + + Args: + key (str): The key to look up in the state dictionary. + default: The value to return if the key is not found. Defaults to None. + + Returns: + str: The value associated with the key if found, otherwise the default value. + """ return get_state_dict().get(key, default) def set_state(key: str, value: str): + """ + Updates the state dictionary with the given key-value pair and writes it to a JSON file. + + Args: + key (str): The key to update in the state dictionary. + value (str): The value to associate with the key in the state dictionary. 
+ """ state = get_state_dict() state[key] = value try: @@ -51,6 +83,10 @@ def set_state(key: str, value: str): def load_local_state(): + """ + Loads the local language settings and populates the LANGUAGE_LOOKUP dictionary with the appropriate translations. + Declares LANGUAGE_LOOKUP and TARGET_LANGUAGE as globals so that both module-level values can be updated. + """ global LANGUAGE_LOOKUP global TARGET_LANGUAGE @@ -79,10 +115,28 @@ def load_local_state(): def localize(key: str) -> str: + """ + Translates a given key into its corresponding localized string. + + Args: + key (str): The key to be localized. + + Returns: + str: The localized string corresponding to the given key. If the key is not found in the localization lookup, the original key is returned. + """ return LANGUAGE_LOOKUP.get(key, key) def set_language(language: str): + """ + Sets the language for the application by updating the GUI settings file. + This function ensures that the settings file exists, reads the current settings, + updates the "language-id" field with the provided language, and writes the updated + settings back to the file. + + Args: + language (str): The language identifier to set in the settings file. + """ if language: ensure_settings_file() settings = {} diff --git a/birdnet_analyzer/model.py b/birdnet_analyzer/model.py index e603ac61..432f7d5d 100644 --- a/birdnet_analyzer/model.py +++ b/birdnet_analyzer/model.py @@ -35,6 +35,11 @@ def resetCustomClassifier(): + """ + Resets the custom classifier by setting the global variables C_INTERPRETER and C_PBMODEL to None. + This function is used to clear any existing custom classifier models and interpreters, effectively + resetting the state of the custom classifier. + """ global C_INTERPRETER global C_PBMODEL @@ -43,10 +48,15 @@ def resetCustomClassifier(): def loadModel(class_output=True): - """Initializes the BirdNET Model. + """ + Loads the machine learning model based on the configuration provided. 
+ This function loads either a TensorFlow Lite (TFLite) model or a protobuf model + depending on the file extension of the model path specified in the configuration. + It sets up the global variables for the model interpreter and input/output layer indices. Args: - class_output: Omits the last layer when False. + class_output (bool): If True, sets the output layer index to the classification output. + If False, sets the output layer index to the feature embeddings. """ global PBMODEL global INTERPRETER @@ -82,7 +92,12 @@ def loadModel(class_output=True): def loadCustomClassifier(): - """Loads the custom classifier.""" + """ + Loads a custom classifier model based on the file extension of the provided model path. + If the model file ends with ".tflite", it loads a TensorFlow Lite model and sets up the interpreter, + input layer index, output layer index, and input size. + If the model file does not end with ".tflite", it loads a TensorFlow SavedModel. + """ global C_INTERPRETER global C_INPUT_LAYER_INDEX global C_OUTPUT_LAYER_INDEX @@ -144,6 +159,7 @@ def buildLinearClassifier(num_labels, input_size, hidden_units=0, dropout=0.0): num_labels: Output size. input_size: Size of the input. hidden_units: If > 0, creates another hidden layer with the given number of units. + dropout: Dropout rate. Returns: A new classifier. @@ -202,6 +218,11 @@ def trainLinearClassifier( epochs: Number of epochs to train. batch_size: Batch size. learning_rate: The learning rate during training. + val_split: Validation split ratio. + upsampling_ratio: Upsampling ratio. + upsampling_mode: Upsampling mode. + train_with_mixup: If True, applies mixup to the training data. + train_with_label_smoothing: If True, applies label smoothing to the training data. on_epoch_end: Optional callback `function(epoch, logs)`. 
Returns: @@ -300,9 +321,7 @@ def on_epoch_end(self, epoch, logs=None): def saveLinearClassifier(classifier, model_path: str, labels: list[str], mode="replace"): - """Saves a custom classifier on the hard drive. - - Saves the classifier as a tflite model, as well as the used labels in a .txt. + """Saves the classifier as a tflite model, as well as the used labels in a .txt. Args: classifier: The custom classifier. @@ -358,6 +377,22 @@ def saveLinearClassifier(classifier, model_path: str, labels: list[str], mode="r def save_raven_model(classifier, model_path, labels: list[str], mode="replace"): + """ + Save a TensorFlow model with a custom classifier and associated metadata for use with BirdNET. + + Args: + classifier (tf.keras.Model): The custom classifier model to be saved. + model_path (str): The path where the model will be saved. + labels (list[str]): A list of labels associated with the classifier. + mode (str, optional): The mode for saving the model. Can be either "replace" or "append". + Defaults to "replace". + + Raises: + ValueError: If the mode is not "replace" or "append". + + Returns: + None + """ import csv import json @@ -545,6 +580,19 @@ def custom_loss(y_true, y_pred, epsilon=1e-7): def flat_sigmoid(x, sensitivity=-1): + """ + Applies a flat sigmoid function to the input array. + + The flat sigmoid function is defined as: + f(x) = 1 / (1 + exp(sensitivity * clip(x, -15, 15))) + + Args: + x (array-like): Input data. + sensitivity (float, optional): Sensitivity parameter for the sigmoid function. Default is -1. + + Returns: + numpy.ndarray: Transformed data after applying the flat sigmoid function. 
+ """ return 1 / (1.0 + np.exp(sensitivity * np.clip(x, -15, 15))) diff --git a/birdnet_analyzer/segments.py b/birdnet_analyzer/segments.py index 59c99460..311cd85a 100644 --- a/birdnet_analyzer/segments.py +++ b/birdnet_analyzer/segments.py @@ -41,6 +41,15 @@ def detectRType(line: str): def getHeaderMapping(line: str) -> dict: + """ + Parses a header line and returns a mapping of column names to their indices. + + Args: + line (str): A string representing the header line of a file. + + Returns: + dict: A dictionary where the keys are column names and the values are their respective indices. + """ rtype = detectRType(line) if rtype == "table" or rtype == "audacity": sep = "\t" @@ -62,12 +71,12 @@ def parseFolders(apath: str, rpath: str, allowed_result_filetypes: list[str] = [ Reads all audio files and BirdNET output inside directory recursively. Args: - apath: Path to search for audio files. - rpath: Path to search for result files. - allowed_result_filetypes: List of extensions for the result files. + apath (str): Path to search for audio files. + rpath (str): Path to search for result files. + allowed_result_filetypes (list[str]): List of extensions for the result files. Returns: - A list of {"audio": path_to_audio, "result": path_to_result }. + list[dict]: A list of {"audio": path_to_audio, "result": path_to_result }. """ data = {} apath = apath.replace("/", os.sep).replace("\\", os.sep) @@ -112,14 +121,25 @@ def parseFolders(apath: str, rpath: str, allowed_result_filetypes: list[str] = [ def parseFiles(flist: list[dict], max_segments=100): - """Extracts the segments for all files. + """ + Parses a list of files to extract and organize bird call segments by species. Args: - flist: List of dict with {"audio": path_to_audio, "result": path_to_result }. - max_segments: Number of segments per species. - + flist (list[dict]): A list of dictionaries, each containing 'audio' and 'result' file paths. 
+ Optionally, a dictionary can have 'isCombinedFile' set to True to indicate + that it is a combined result file. + max_segments (int, optional): The maximum number of segments to retain per species. Defaults to 100. Returns: - TODO @kahst + list[tuple]: A list of tuples where each tuple contains an audio file path and a list of segments + associated with that audio file. + Raises: + KeyError: If the dictionaries in flist do not contain the required keys ('audio' and 'result'). + Example: + flist = [ + {"audio": "path/to/audio1.wav", "result": "path/to/result1.csv"}, + {"audio": "path/to/audio2.wav", "result": "path/to/result2.csv"} + ] + segments = parseFiles(flist, max_segments=50) """ species_segments: dict[str, list] = {} @@ -176,14 +196,14 @@ def parseFiles(flist: list[dict], max_segments=100): return flist -def findSegmentsFromCombined(rfile: str): +def findSegmentsFromCombined(rfile: str) -> list[dict]: """Extracts the segments from a combined results file Args: - rfile: Path to the result file. + rfile (str): Path to the result file. Returns: - A list of dicts in the form of + list[dict]: A list of dicts in the form of {"audio": afile, "start": start, "end": end, "species": species, "confidence": confidence} """ segments: list[dict] = [] @@ -321,12 +341,19 @@ def findSegments(afile: str, rfile: str): def extractSegments(item: tuple[tuple[str, list[dict]], float, dict[str]]): - """Saves each segment separately. - - Creates an audio file for each species segment. - + """ + Extracts audio segments from a given audio file based on provided segment information. Args: - item: A tuple that contains ((audio file path, segments), segment length, config) + item (tuple): A tuple containing: + - A tuple with: + - A string representing the path to the audio file. + - A list of dictionaries, each containing segment information with keys "start", "end", "species", "confidence", and "audio". + - A float representing the segment length. 
+ - A dictionary containing configuration settings. + Returns: + bool: True if segments were successfully extracted, False otherwise. + Raises: + Exception: If there is an error opening the audio file or extracting segments. """ # Paths and config afile = item[0][0] diff --git a/birdnet_analyzer/species.py b/birdnet_analyzer/species.py index b6cfca6d..f8ba04fc 100644 --- a/birdnet_analyzer/species.py +++ b/birdnet_analyzer/species.py @@ -38,6 +38,18 @@ def getSpeciesList(lat: float, lon: float, week: int, threshold=0.05, sort=False def run(output_path, lat, lon, week, threshold, sortby): + """ + Generates a species list for a given location and time, and saves it to the specified output path. + Args: + output_path (str): The path where the species list will be saved. If it's a directory, the list will be saved as "species_list.txt" inside it. + lat (float): Latitude of the location. + lon (float): Longitude of the location. + week (int): Week of the year in BirdNET's 48-week convention (1-48, four weeks per month) for which the species list is generated. + threshold (float): Threshold for location filtering. + sortby (str): Sorting criteria for the species list. Can be "freq" for frequency or any other value for alphabetical sorting. + Returns: + None + """ # Set paths relative to script path (requested in #3) cfg.LABELS_FILE = os.path.join(SCRIPT_DIR, cfg.LABELS_FILE) cfg.MDATA_MODEL_PATH = os.path.join(SCRIPT_DIR, cfg.MDATA_MODEL_PATH) diff --git a/birdnet_analyzer/train.py b/birdnet_analyzer/train.py index b2c28256..13ce4aac 100644 --- a/birdnet_analyzer/train.py +++ b/birdnet_analyzer/train.py @@ -22,6 +22,16 @@ def save_sample_counts(labels, y_train): + """ + Saves the count of samples per label combination to a CSV file. + + The function creates a dictionary where the keys are label combinations (joined by '+') and the values are the counts of samples for each combination. + It then writes this information to a CSV file named "_sample_counts.csv" with two columns: "Label" and "Count". 
+ + Args: + labels (list of str): List of label names corresponding to the columns in y_train. + y_train (numpy.ndarray): 2D array where each row is a binary vector indicating the presence (1) or absence (0) of each label. + """ samples_per_label = {} label_combinations = np.unique(y_train, axis=0)