diff --git a/src/flash/core/data/utilities/classification.py b/src/flash/core/data/utilities/classification.py index 19a40e0449..6bd6992a5a 100644 --- a/src/flash/core/data/utilities/classification.py +++ b/src/flash/core/data/utilities/classification.py @@ -42,8 +42,11 @@ def _as_list(x: Union[List, Tensor, np.ndarray]) -> List: return x -def _strip(x: str) -> str: - return x.strip(", ") +def _strip(x: Union[str, int]) -> str: + """Replace both ` ` and `,` from str.""" + if isinstance(x, str): + return x.strip(", ") + return str(x) @dataclass