Skip to content

Commit

Permalink
Fix bug in flattening list of lists. (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
JyChang012 authored Jul 17, 2022
1 parent be5e723 commit b525b02
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/rsdiv/evaluation/diversity_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class DiversityMetrics:
def _get_histogram(
items: Union[Iterable[Hashable], Iterable[Sequence[Hashable]]],
) -> np.ndarray:
if isinstance(next(iter(items)), Sequence):
first_element = next(iter(items))
if isinstance(first_element, Sequence) and not isinstance(first_element, str):
items = chain(*items)
flatten_items = list(items)
return np.asarray(pd.Series(flatten_items).value_counts())
Expand Down Expand Up @@ -80,7 +81,8 @@ def get_lorenz_curve(
def get_distribution(
cls, items: Union[Iterable[Hashable], Iterable[Sequence[Hashable]]]
) -> pd.DataFrame:
if isinstance(next(iter(items)), Sequence):
first_element = next(iter(items))
if isinstance(first_element, Sequence) and not isinstance(first_element, str):
items = chain(*items)
counter: pd.DataFrame = pd.DataFrame(Counter(items).most_common())
counter.columns = pd.Index(["category", "percentage"])
Expand Down

0 comments on commit b525b02

Please sign in to comment.