Skip to content

Commit

Permalink
Merge pull request #233 from nel-lab/fix-get-params-diff
Browse files Browse the repository at this point in the history
fix `get_params_diffs()`, return as a DataFrame instead of a Series of dicts
  • Loading branch information
kushalkolar authored Oct 12, 2023
2 parents 7c4cdc2 + ce33691 commit 14abc7a
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions mesmerize_core/caiman_extensions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def remove_item(self, index: Union[int, str, UUID], remove_data: bool = True, sa
self._df.to_pickle(self._df.paths.get_batch_path())

@warning_experimental("This feature is new and the might improve in the future")
def get_params_diffs(self, algo: str, item_name: str) -> pd.Series:
def get_params_diffs(self, algo: str, item_name: str) -> pd.DataFrame:
"""
Get the parameters that differ for a given `item_name` run with a given `algo`
Expand All @@ -260,8 +260,8 @@ def get_params_diffs(self, algo: str, item_name: str) -> pd.Series:
Returns
-------
pd.Series
pandas Series (rows) with dicts containing only the
pd.DataFrame
pandas DataFrame containing only the
parameters that vary between batch items for the given
`item_name`. The returned index corresponds to the
index of the original DataFrame
Expand All @@ -286,9 +286,16 @@ def get_params_diffs(self, algo: str, item_name: str) -> pd.Series:
counts = Counter([av[0] for av in all_variants])
variants_exist = [param[0] for param in counts.items() if param[1] > 1]

diffs = sub_df["params"].apply(lambda p: {k: p["main"][k] for k in variants_exist})
# gives a series where each item is a dict that has the unique params that correspond to a row
# the indices of this series correspond to the index of the row in the parent dataframe
diffs: pd.Series = sub_df["params"].apply(
lambda p: {k: p["main"][k] for k in variants_exist if k in p["main"].keys()}
)

# return as a nicely formatted dataframe
diffs_df = pd.DataFrame.from_dict(diffs.tolist(), dtype=object).set_index(diffs.index)

return diffs
return diffs_df

@warning_experimental("This feature will change in the future and directly return the "
" a DataFrame of children (rows, ie. child batch items row) "
Expand Down

0 comments on commit 14abc7a

Please sign in to comment.