
create skeleton of derive_annotations function in AM
akoziol98 committed Feb 20, 2024
1 parent a8eb13a commit a623c99
Showing 1 changed file with 123 additions and 1 deletion.
124 changes: 123 additions & 1 deletion ChildProject/annotations.py
@@ -458,7 +458,6 @@ def _check_for_outdated_merged_sets(self, sets: set = None):
                warnings.append("set {} is outdated because the {} set it is merged from was modified. Consider updating or rerunning the creation of the {} set.".format(i, j, i))

        return warnings

    def _import_annotation(
            self, import_function: Callable[[str], pd.DataFrame],
            params: dict,
@@ -711,6 +710,129 @@ def import_annotations(
                logger_annotations.warning("warning: %s", warning)

        return (imported, errors)
    def derive_annotations(self,
                           input_set: str,
                           output_set: str,
                           derivation_function: Union[str, Callable],
                           threads: int = -1,
                           overwrite_existing: bool = False,
                           ) -> (pd.DataFrame, pd.DataFrame):
        """Derive annotations.

        :param input_set: name of the set of annotations to derive from
        :type input_set: str
        :param output_set: name of the new set of derived annotations
        :type output_set: str
        :param derivation_function: name of the derivation type to be performed, or a callable performing it
        :type derivation_function: Union[str, Callable]
        :param threads: If > 1, derivations will be run on ``threads`` threads, defaults to -1
        :type threads: int, optional
        :param overwrite_existing: choose if lines with the same set and annotation_filename should be overwritten
        :type overwrite_existing: bool, optional
        :return: tuple of the dataframe of derived annotations, as in :ref:`format-annotations`, and the dataframe of errors
        :rtype: tuple (pd.DataFrame, pd.DataFrame)
        """
        # skeleton note: output_set is not yet used; the rows to derive are
        # assumed to be the index entries of input_set (the original line
        # referenced an undefined `input` variable)
        input_processed = self.annotations[self.annotations["set"] == input_set].copy().reset_index()

        required_columns = {
            c.name
            for c in AnnotationManager.INDEX_COLUMNS
            if c.required and not c.generated
        }

        assert_dataframe("input", input_processed)
        assert_columns_presence("input", input_processed, required_columns)

input_processed["range_onset"] = input_processed["range_onset"].astype(np.int64)
input_processed["range_offset"] = input_processed["range_offset"].astype(np.int64)

assert (input_processed["range_offset"] > input_processed[
"range_onset"]).all(), "range_offset must be greater than range_onset"
assert (input_processed["range_onset"] >= 0).all(), "range_onset must be greater or equal to 0"
if "duration" in self.project.recordings.columns:
assert (input_processed["range_offset"] <= input_processed.merge(self.project.recordings,
how='left',
on='recording_filename',
validate='m:1'
).reset_index()["duration"]
).all(), "range_offset must be smaller than the duration of the recording"

        missing_recordings = input_processed[
            ~input_processed["recording_filename"].isin(
                self.project.recordings["recording_filename"]
            )
        ]
        missing_recordings = missing_recordings["recording_filename"]

        if len(missing_recordings) > 0:
            raise ValueError(
                "cannot derive annotations, because the following recordings are not referenced in the metadata:\n{}".format(
                    "\n".join(missing_recordings)
                )
            )

        # skeleton note: this thread-safety check is inherited from
        # import_annotations; derivations may need a check of their own
        builtin = input_processed[input_processed["format"].isin(converters.keys())]
        if not builtin["format"].map(lambda f: converters[f].THREAD_SAFE).all():
            logger_annotations.warning(
                "warning: some of the converters do not support multi-threaded processing; running on 1 thread")
            threads = 1

        # if the input set has overlaps in it, raise an error immediately; nothing will be derived
        ovl_ranges = find_lines_involved_in_overlap(input_processed, labels=['recording_filename', 'set'])
        if ovl_ranges[ovl_ranges == True].shape[0] > 0:
            ovl_ranges = ovl_ranges[ovl_ranges].index.values.tolist()
            raise ValueError(f"the ranges of the input set have overlaps on indexes: {ovl_ranges}")

        # skeleton note: derivation_function is forwarded to _import_annotation
        # as the row worker for now (the original referenced the undefined names
        # `import_function` and `new_tiers`); resolving a str derivation_function
        # to a callable is not implemented yet, and no extra parameters are passed
        if threads == 1:
            derived = input_processed.apply(
                partial(self._import_annotation, derivation_function,
                        {},
                        overwrite_existing=overwrite_existing
                        ), axis=1
            ).to_dict(orient="records")
        else:
            with mp.Pool(processes=threads if threads > 0 else mp.cpu_count()) as pool:
                derived = pool.map(
                    partial(self._import_annotation, derivation_function,
                            {},
                            overwrite_existing=overwrite_existing
                            ),
                    input_processed.to_dict(orient="records"),
                )

        derived = pd.DataFrame(derived)
        derived.drop(
            list(set(derived.columns) - {c.name for c in self.INDEX_COLUMNS}),
            axis=1,
            inplace=True,
        )

        if 'error' in derived.columns:
            errors = derived[~derived["error"].isnull()]
            derived = derived[derived["error"].isnull()]
            # when errors occur, export them to a separate csv in extra
            if errors.shape[0] > 0:
                output = os.path.join(self.project.path, "extra",
                                      "errors_derive_{}.csv".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
                errors.to_csv(output, index=False)
                logger_annotations.info("Errors summary exported to %s", output)
        else:
            errors = None

        self.read()
        self.annotations = pd.concat([self.annotations, derived], sort=False)
        # at this point, two lines with the same set and annotation_filename can exist
        # if overwrite_existing was requested; dropping duplicates removes the first
        # import and keeps the most recent one
        self.annotations = self.annotations.sort_values('imported_at').drop_duplicates(
            subset=["set", "recording_filename", "range_onset", "range_offset"], keep='last')
        self.write()

        sets = set(input_processed['set'].unique())
        outdated_sets = self._check_for_outdated_merged_sets(sets=sets)
        for warning in outdated_sets:
            logger_annotations.warning("warning: %s", warning)

        return derived, errors
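
    # Usage sketch (hypothetical set names, for illustration only; a str
    # derivation_function is not yet resolved to a callable by this skeleton):
    #
    #     am = AnnotationManager(project)
    #     derived, errors = am.derive_annotations(
    #         input_set="vtc",
    #         output_set="vtc/chi",
    #         derivation_function=keep_chi,
    #         threads=4,
    #     )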

    def get_subsets(self, annotation_set: str, recursive: bool = False) -> List[str]:
        """Retrieve the list of subsets belonging to a given set of annotations.
