From 72415df2300b8bcb3e0f68cdf6f96d7d92490fd9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Agata=20Kozio=C5=82?=
Date: Tue, 27 Feb 2024 16:40:52 +0100
Subject: [PATCH] Created _summarise_conversations

---
Review note: the added code was amended after export (undefined
``annotation_result``, lost error reporting, bare ``except``, stale docstrings),
so the blob hash on the ``index`` line no longer describes the result.

 ChildProject/annotations.py | 173 +++++++++++++++++++++++++++++++++++-
 1 file changed, 172 insertions(+), 1 deletion(-)

diff --git a/ChildProject/annotations.py b/ChildProject/annotations.py
index c730c700..0118f68a 100644
--- a/ChildProject/annotations.py
+++ b/ChildProject/annotations.py
@@ -11,7 +11,7 @@ import logging
 
 
 from . import __version__
-from .pipelines.derivations import DERIVATIONS
+from .pipelines.derivations import DERIVATIONS, conversations
 from .projects import ChildProject
 from .converters import *
 from .tables import IndexTable, IndexColumn, assert_dataframe, assert_columns_presence
@@ -923,5 +923,176 @@ def derive_annotations(self,
 
         return imported, errors
 
+    def _summarise_conversations(
+        self,
+        annotation: dict,
+        overwrite_existing: bool = False,
+    ):
+        """Summarise the conversations of one converted annotation file.
+
+        The converted CSV referenced by ``annotation`` is read, passed to
+        :func:`ChildProject.pipelines.derivations.conversations`, clipped to the
+        annotated range and written to the ``extra`` folder of the dataset.
+        This function should not be called outside of this class.
+
+        :param annotation: row of the annotations index describing the input file (reads ``set``, ``recording_filename``, ``annotation_filename``, ``raw_filename``, ``time_seek``, ``range_onset`` and ``range_offset``)
+        :type annotation: dict
+        :param overwrite_existing: currently has no effect; the checks for already existing or overlapping output are not implemented yet
+        :type overwrite_existing: bool
+        :return: copy of ``annotation`` updated with the output ``annotation_filename``, ``imported_at`` and ``package_version``; it carries an ``error`` entry (formatted traceback) when the input could not be processed
+        :rtype: dict
+        """
+
+        # copy() keeps the input type (dict from the process pool, pandas Series
+        # from DataFrame.apply) and leaves the caller's row untouched
+        annotation_result = annotation.copy()
+
+        source_recording = os.path.splitext(annotation["recording_filename"])[0]
+        annotation_filename = "{}_{}_{}.csv".format(
+            source_recording, annotation["range_onset"], annotation["range_offset"]
+        )
+        output_filename = os.path.join("extra", annotation_filename)
+
+        # TODO: check whether the output file already exists (honouring
+        # ``overwrite_existing``) and reject ranges that overlap existing index lines
+
+        path = os.path.join(
+            self.project.path,
+            "annotations",
+            annotation["set"],
+            "converted",  # EXPAND
+            annotation["annotation_filename"],
+        )
+
+        df = None
+        try:
+            # TODO CHECK FOR DTYPES
+            df_input = pd.read_csv(path)
+            df = conversations(df_input)
+        except Exception:
+            # the error is stored on the returned row so that the caller can
+            # separate failed files from successful ones
+            annotation_result["error"] = traceback.format_exc()
+            logger_annotations.error("An error occurred while processing '%s'", path, exc_info=True)
+
+        if df is None or not isinstance(df, pd.DataFrame):
+            annotation_result["imported_at"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            return annotation_result
+
+        if not df.shape[1]:
+            df = pd.DataFrame(columns=[c.name for c in self.SEGMENTS_COLUMNS])
+
+        df["raw_filename"] = annotation["raw_filename"]
+
+        df["segment_onset"] += np.int64(annotation["time_seek"])
+        df["segment_offset"] += np.int64(annotation["time_seek"])
+        df["segment_onset"] = df["segment_onset"].astype(np.int64)
+        df["segment_offset"] = df["segment_offset"].astype(np.int64)
+
+        annotation_result["time_seek"] = np.int64(annotation["time_seek"])
+        annotation_result["range_onset"] = np.int64(annotation["range_onset"])
+        annotation_result["range_offset"] = np.int64(annotation["range_offset"])
+
+        df = AnnotationManager.clip_segments(
+            df, annotation_result["range_onset"], annotation_result["range_offset"]
+        )
+
+        sort_columns = ["segment_onset", "segment_offset"]
+        if "speaker_type" in df.columns:
+            sort_columns.append("speaker_type")
+
+        df.sort_values(sort_columns, inplace=True)
+
+        os.makedirs(
+            os.path.dirname(os.path.join(self.project.path, output_filename)),
+            exist_ok=True,
+        )
+        df.to_csv(os.path.join(self.project.path, output_filename), index=False)
+
+        annotation_result["annotation_filename"] = annotation_filename
+        annotation_result["imported_at"] = datetime.datetime.now().strftime(
+            "%Y-%m-%d %H:%M:%S"
+        )
+        annotation_result["package_version"] = __version__
+
+        return annotation_result
+
+    def summarise_conversations(self,
+                                input_set: str,
+                                output_set: str,
+                                derivation_function: Union[str, Callable],
+                                threads: int = -1,
+                                overwrite_existing: bool = False,
+                                ) -> (pd.DataFrame, pd.DataFrame):
+        """Summarise the conversations of every annotation file of ``input_set``.
+
+        :param input_set: name of the set of annotations to be summarised
+        :type input_set: str
+        :param output_set: currently unused
+        :type output_set: str
+        :param derivation_function: currently unused
+        :type derivation_function: Union[str, Callable]
+        :param threads: If 1, files are processed sequentially; otherwise a pool of ``threads`` processes is used (one per CPU when ``threads`` <= 0), defaults to -1
+        :type threads: int, optional
+        :param overwrite_existing: forwarded to ``_summarise_conversations``, where it currently has no effect
+        :type overwrite_existing: bool, optional
+        :return: tuple of dataframe of summarised annotations, as in :ref:`format-annotations` and dataframe of errors (``None`` when no ``error`` column was produced)
+        :rtype: tuple (pd.DataFrame, pd.DataFrame)
+        """
+        input_processed = self.annotations[self.annotations['set'] == input_set].copy()
+        assert not input_processed.empty, "Input set {0} does not exist".format(input_set)
+
+        if threads == 1:
+            imported = input_processed.apply(
+                partial(self._summarise_conversations,
+                        overwrite_existing=overwrite_existing
+                        ), axis=1
+            ).to_dict(orient="records")
+        else:
+
+            with mp.Pool(processes=threads if threads > 0 else mp.cpu_count()) as pool:
+                imported = pool.map(
+                    partial(self._summarise_conversations,
+                            overwrite_existing=overwrite_existing
+                            ),
+                    input_processed.to_dict(orient="records"),
+                )
+
+        imported = pd.DataFrame(imported)
+        imported.drop(
+            list(set(imported.columns) - {c.name for c in self.INDEX_COLUMNS}),
+            axis=1,
+            inplace=True,
+        )
+
+        if 'error' in imported.columns:
+            errors = imported[~imported["error"].isnull()]
+            imported = imported[imported["error"].isnull()]
+            # when errors occur, separate them in a different csv in extra
+            if errors.shape[0] > 0:
+                output = os.path.join(self.project.path, "extra",
+                                      "errors_conv_summary_{}.csv".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
+                errors.to_csv(output, index=False)
+                logger_annotations.info("Errors summary exported to %s", output)
+        else:
+            errors = None
+
+        # NOTE(review): the rows added below still carry the ``set`` of the input
+        # annotations; ``output_set`` is not applied anywhere yet
+        self.read()
+        self.annotations = pd.concat([self.annotations, imported], sort=False)
+        # at this point, 2 lines with same set and annotation_filename can happen if specified overwrite,
+        # dropping duplicates remove the first importation and keeps the more recent one
+        self.annotations = self.annotations.sort_values('imported_at').drop_duplicates(
+            subset=["set", "recording_filename", "range_onset", "range_offset"], keep='last')
+        self.write()
+
+        sets = set(input_processed['set'].unique())
+        outdated_sets = self._check_for_outdated_merged_sets(sets=sets)
+        for warning in outdated_sets:
+            logger_annotations.warning("warning: %s", warning)
+
+        return imported, errors
+
     def get_subsets(self, annotation_set: str, recursive: bool = False) -> List[str]:
         """Retrieve the list of subsets belonging to a given set of annotations.