Skip to content

Commit

Permalink
Enable int alpha
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Feb 28, 2024
1 parent 0bb7d4f commit a4967b6
Showing 1 changed file with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__( # pylint: disable=R0913
self,
partition_sizes: Union[List[int], NDArrayInt],
partition_by: str,
alpha: Union[float, List[float], NDArrayFloat],
alpha: Union[int, float, List[float], NDArrayFloat],
shuffle: bool = True,
seed: Optional[int] = 42,
) -> None:
Expand Down Expand Up @@ -109,7 +109,7 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
return self.dataset.select(self._node_id_to_indices[node_id])

def _initialize_alpha_if_needed(
self, alpha: Union[float, List[float], NDArrayFloat]
self, alpha: Union[int, float, List[float], NDArrayFloat]
) -> NDArrayFloat:
"""Convert alpha to the used format in the code a NDArrayFloat.
Expand All @@ -130,7 +130,10 @@ def _initialize_alpha_if_needed(
if self._initialized_alpha:
assert self._alpha is not None
return self._alpha
if isinstance(alpha, float):
if isinstance(alpha, int):
assert self._num_unique_classes is not None
alpha = np.array([float(alpha)], dtype=float).repeat(self._num_unique_classes)
elif isinstance(alpha, float):
assert self._num_unique_classes is not None
alpha = np.array([alpha], dtype=float).repeat(self._num_unique_classes)
elif isinstance(alpha, List):
Expand Down

0 comments on commit a4967b6

Please sign in to comment.