From b525b026a250ae416a0d8299f9d8dc65720f2609 Mon Sep 17 00:00:00 2001 From: Peter Chang <40067028+JyChang012@users.noreply.github.com> Date: Sat, 16 Jul 2022 22:26:53 -0400 Subject: [PATCH] Fix bug in flattening list of lists. (#32) --- src/rsdiv/evaluation/diversity_metrics.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/rsdiv/evaluation/diversity_metrics.py b/src/rsdiv/evaluation/diversity_metrics.py index 68dc1cb..91f2440 100644 --- a/src/rsdiv/evaluation/diversity_metrics.py +++ b/src/rsdiv/evaluation/diversity_metrics.py @@ -15,7 +15,8 @@ class DiversityMetrics: def _get_histogram( items: Union[Iterable[Hashable], Iterable[Sequence[Hashable]]], ) -> np.ndarray: - if isinstance(next(iter(items)), Sequence): + first_element = next(iter(items)) + if isinstance(first_element, Sequence) and not isinstance(first_element, str): items = chain(*items) flatten_items = list(items) return np.asarray(pd.Series(flatten_items).value_counts()) @@ -80,7 +81,8 @@ def get_lorenz_curve( def get_distribution( cls, items: Union[Iterable[Hashable], Iterable[Sequence[Hashable]]] ) -> pd.DataFrame: - if isinstance(next(iter(items)), Sequence): + first_element = next(iter(items)) + if isinstance(first_element, Sequence) and not isinstance(first_element, str): items = chain(*items) counter: pd.DataFrame = pd.DataFrame(Counter(items).most_common()) counter.columns = pd.Index(["category", "percentage"])