diff --git a/baselines/fedrep/fedrep/client.py b/baselines/fedrep/fedrep/client.py index 06447524c18c..a130b5f51bc7 100644 --- a/baselines/fedrep/fedrep/client.py +++ b/baselines/fedrep/fedrep/client.py @@ -238,7 +238,15 @@ def get_client_fn_simulation( if config.dataset.name.lower() == "cifar100": use_fine_label = True - partitioner = PathologicalPartitioner( + partitioner_train = PathologicalPartitioner( + num_partitions=config.num_clients, + partition_by="fine_label" if use_fine_label else "label", + num_classes_per_partition=config.dataset.num_classes, + class_assignment_mode="random", + shuffle=True, + seed=config.dataset.seed, + ) + partitioner_test = PathologicalPartitioner( num_partitions=config.num_clients, partition_by="fine_label" if use_fine_label else "label", num_classes_per_partition=config.dataset.num_classes, @@ -251,7 +259,7 @@ def get_client_fn_simulation( if FEDERATED_DATASET is None: FEDERATED_DATASET = FederatedDataset( dataset=config.dataset.name.lower(), - partitioners={"train": partitioner, "test": partitioner}, + partitioners={"train": partitioner_train, "test": partitioner_test}, ) def apply_train_transforms(batch):