Skip to content

Commit

Permalink
fix some parameters and docstrings
Browse files (browse the repository at this point in the history)
  • Loading branch information
LoannPeurey committed Jul 4, 2024
1 parent 4e2514f commit 1181f84
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
2 changes: 1 addition & 1 deletion ChildProject/pipelines/conversationFunctions.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -74,7 +74,7 @@ def who_finished(segments: pd.DataFrame):
Required keyword arguments:
"""
return segments[segments['segment_offset'] == segments['segment_offset'].max()]['speaker_type']
return segments[segments['segment_offset'] == segments['segment_offset'].max()].iloc[0]['speaker_type']

@conversationFunction()
def participants(segments: pd.DataFrame):
Expand Down
19 changes: 8 additions & 11 deletions ChildProject/pipelines/conversations.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -111,6 +111,8 @@ def check_callable(row):
)
self.set = setname
self.features_dict = features_list.to_dict(orient="index")
features_list['name'] = features_list.index
self.features_df = features_list

# necessary columns to construct the conversations
join_columns = {
Expand Down Expand Up @@ -260,21 +262,15 @@ def extract(self):
:rtype: pandas.DataFrame
"""
if self.threads == 1:
extractions = []
for rec in self.recordings:
segments = self.retrieve_segments(rec)

conversations = segments.groupby(grouper)

extractions += [self._process_conversation(block) for block in conversations]
self.conversations = pd.DataFrame(extractions) if len(extractions) else pd.DataFrame(columns=grouper)
results = list(itertools.chain.from_iterable(map(self._process_recording, self.recordings)))
else:
with mp.Pool(
processes=self.threads if self.threads >= 1 else mp.cpu_count()
) as pool:
results = list(itertools.chain.from_iterable(pool.map(self._process_recording, self.recordings)))

self.conversations = pd.DataFrame(results) if len(results) else pd.DataFrame(columns=grouper)
self.conversations = pd.DataFrame(results) if len(results) else pd.DataFrame(columns=grouper)

# now add the rec_cols and child_cols in the result
if self.rec_cols:
Expand Down Expand Up @@ -336,7 +332,7 @@ def retrieve_segments(self, recording: str):
# no annotations for that unit
return pd.DataFrame(columns=list(set([c.name for c in AnnotationManager.SEGMENTS_COLUMNS if c.required]
+ list(annotations.columns) + ['conv_count'])))
segments = segments.dropna(subset='conv_count')
segments = segments.dropna(subset=['conv_count'])
else:
# no annotations for that unit
return pd.DataFrame(columns=list(set([c.name for c in AnnotationManager.SEGMENTS_COLUMNS if c.required]
Expand Down Expand Up @@ -509,7 +505,7 @@ def run(self, path, destination, pipeline, func=None, **kwargs):
self.conversations.to_csv(self.destination, index=False)

# get the df of features used from the Conversations class
features_df = conversations.features_list
features_df = conversations.features_df
features_df['callable'] = features_df.apply(lambda row: row['callable'].__name__,
axis=1) # from the callables used, find their name back
parameters['features_list'] = [{k: v for k, v in m.items() if pd.notnull(v)} for m in
Expand Down Expand Up @@ -658,7 +654,8 @@ def run(self, parameters_input, func=None):
self.conversations.to_csv(self.destination, index=False)

# get the df of features used from the Conversations class
features_df = conversations.features_list
features_df = conversations.features_df
print(features_df)
features_df['callable'] = features_df.apply(lambda row: row['callable'].__name__,
axis=1) # from the callables used, find their name back
parameters['features_list'] = [{k: v for k, v in m.items() if pd.notnull(v)} for m in
Expand Down

0 comments on commit 1181f84

Please sign in to comment.