Skip to content

Commit

Permalink
using test_samples_per_user to reduce test set
Browse files Browse the repository at this point in the history
  • Loading branch information
kathrynle20 committed Dec 5, 2024
1 parent f9ad68a commit e24a2f8
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 12 deletions.
14 changes: 7 additions & 7 deletions src/algos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,13 +532,6 @@ def set_data_parameters(self, config: ConfigType) -> None:
if config.get("test_samples_per_class", None) is not None:
test_dset, _ = balanced_subset(test_dset, config["test_samples_per_class"])

#reduce test_dset size
if config.get("workflow_test", False):
print("Workflow testing: Reducing test size...")
reduced_test_size = 1000
indices = np.random.choice(len(test_dset), reduced_test_size, replace=False)
test_dset = Subset(test_dset, indices)

samples_per_user = config["samples_per_user"]
batch_size: int = config["batch_size"] # type: ignore
print(f"samples per user: {samples_per_user}, batch size: {batch_size}")
Expand Down Expand Up @@ -693,7 +686,14 @@ def is_same_dest(dset):
if self.dset.startswith("domainnet"):
test_dset = CacheDataset(test_dset)

# reduce test_dset size
if config.get("test_samples_per_user", 0) != 0:
print(f"Reducing test size to {config.get('test_samples_per_user', 0)}")
reduced_test_size = config.get("test_samples_per_user", 0)
indices = np.random.choice(len(test_dset), reduced_test_size, replace=False)
test_dset = Subset(test_dset, indices)
print(f"test_dset size: {len(test_dset)}")

self._test_loader = DataLoader(test_dset, batch_size=batch_size)
# TODO: fix print_data_summary
# self.print_data_summary(train_dset, test_dset, val_dset=val_dset)
Expand Down
1 change: 0 additions & 1 deletion src/configs/algo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st
"model": "resnet10",
"model_lr": 3e-4,
"batch_size": 256,
"workflow_test": False,
}

test_fl_inversion: ConfigType = {
Expand Down
1 change: 0 additions & 1 deletion src/configs/algo_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
"model": "resnet10",
"model_lr": 3e-4,
"batch_size": 256,
"workflow_test": True,
}

# default_config_list: List[ConfigType] = [fedstatic, fedstatic, fedstatic, fedstatic]
1 change: 0 additions & 1 deletion src/configs/sys_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,6 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"matlaber3": [0, 1, 2, 3],
"matlaber4": [0, 2, 3, 4, 5, 6, 7],
},
"workflow_test": False,
}

grpc_system_config_gia: ConfigType = {
Expand Down
3 changes: 1 addition & 2 deletions src/configs/sys_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def get_algo_configs(
"alpha_data": 1.0,
"exp_keys": [],
"dropout_dicts": dropout_dicts,
"test_samples_per_user": 100,
"test_samples_per_user": 1000,
"log_memory": True,
# "streaming_aggregation": True, # Make it true for fedstatic
"assign_based_on_host": True,
Expand All @@ -130,6 +130,5 @@ def get_algo_configs(
"matlaber3": [0, 1, 2, 3],
"matlaber4": [0, 2, 3, 4, 5, 6, 7],
},
"workflow_test": True,
}
current_config = grpc_system_config

0 comments on commit e24a2f8

Please sign in to comment.