# config.py
import torch, json, os
# ------------ Model config --------------------- #
NUM_CLASSES = 10 #10 or 100 for CIFAR10 or CIFAR100 respectively
NON_IID = False # True to load non-IID data from TFF, False to load IID data from torchvision
assert not(NUM_CLASSES == 10 and NON_IID), "Non-IID is only supported for CIFAR100"
DEVICE: torch.device = torch.device("cpu")
TFF_DATA_DIR = lambda x: f'data/tff_dataloaders_10clients/{x}.pth'
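# Illustrative sketch (not part of the original config): each .pth file under
# data/tff_dataloaders_10clients/ is assumed to hold a dataloader object saved
# with torch.save, so a client's data could be recovered roughly like this.
# The key passed to TFF_DATA_DIR (e.g. a client id such as "client_0") is a
# hypothetical example; the real naming convention is set where the files are written.
#
#   client_loader = torch.load(TFF_DATA_DIR("client_0"))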
# Hugging Face models:
HF_MODELS = {
    "ViT": "google/vit-base-patch16-224",
    "DeiT": "facebook/deit-base-distilled-patch16-224",
    "DeiT-T": "facebook/deit-tiny-distilled-patch16-224",
    "DeiT-S": "facebook/deit-small-distilled-patch16-224",
    "BiT": "google/bit-50",
    "ConvNeXt": "facebook/convnext-tiny-224",
}
# Chosen model:
MODEL_NAME = HF_MODELS['DeiT-S']
PRE_TRAINED = True
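# Illustrative sketch (an assumption, not the project's actual model-building code):
# MODEL_NAME, PRE_TRAINED and NUM_CLASSES would typically be consumed via the
# Hugging Face transformers auto classes roughly like this, where
# ignore_mismatched_sizes=True swaps the pretrained 1000-class head for a
# NUM_CLASSES-way head:
#
#   from transformers import AutoConfig, AutoModelForImageClassification
#   if PRE_TRAINED:
#       model = AutoModelForImageClassification.from_pretrained(
#           MODEL_NAME, num_labels=NUM_CLASSES, ignore_mismatched_sizes=True
#       )
#   else:
#       model = AutoModelForImageClassification.from_config(
#           AutoConfig.from_pretrained(MODEL_NAME, num_labels=NUM_CLASSES)
#       )
#   model.to(DEVICE)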
# ------------ Training config ------------------ #
TRAIN_SIZE = 1000 # ignored for non-IID data, since each client is given 100 curated points by default
VAL_PORTION = 0.1 # 10% of the training set is for validation
TEST_SIZE = 100
BATCH_SIZE = 32
LEARNING_RATE = 0.0001 # 0.00001 for all models except ConvNeXt (0.0001)
EPOCHS = 1 # EPOCHS PER CLIENT in each round
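# Illustrative sketch (client_dataset is a hypothetical name, assumed to hold
# exactly TRAIN_SIZE samples): the sizes above would typically be used to carve
# a client's local train/validation split and build its dataloaders, e.g.
#
#   from torch.utils.data import DataLoader, random_split
#   n_val = int(TRAIN_SIZE * VAL_PORTION)
#   train_ds, val_ds = random_split(client_dataset, [TRAIN_SIZE - n_val, n_val])
#   train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
#   val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)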
# ------------ FL config ------------------------ #
SERVER_ADDRESS = "JCY-PC:8080" # LAN setup for actual FL env
NUM_CLIENTS = 5
NUM_ROUNDS = 10
DOUBLE_TRAIN = True # Double the training size for each client in each round (for non-IID only)
FRAC_FIT = 1 # Fraction of available clients sampled for training
FRAC_EVAL = 1 # Fraction of available clients sampled for evaluation
MIN_FIT = 2 # Never sample fewer than this many clients for training
MIN_EVAL = 2 # Never sample fewer than this many clients for evaluation
MIN_AVAIL = 2 # Wait until at least this many clients are available
FIT_CONFIG_FN = lambda srvr_rnd: {
    "server_round": srvr_rnd,
    "local_epochs": EPOCHS,
    "learning_rate": LEARNING_RATE,
}
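# Illustrative sketch (assuming the Flower framework with 1.x parameter names):
# FIT_CONFIG_FN and the sampling settings above would normally be wired into
# the server-side strategy, something like:
#
#   import flwr as fl
#   strategy = fl.server.strategy.FedAvg(
#       fraction_fit=FRAC_FIT,
#       fraction_evaluate=FRAC_EVAL,
#       min_fit_clients=MIN_FIT,
#       min_evaluate_clients=MIN_EVAL,
#       min_available_clients=MIN_AVAIL,
#       on_fit_config_fn=FIT_CONFIG_FN,
#   )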
# CPU and GPU resources for a single client.
# Supported keys are num_cpus and num_gpus.
# See the Ray documentation for details: https://docs.ray.io/en/latest/ray-core/tasks/using-ray-with-gpus.html
CLIENT_RESOURCES = None
if DEVICE.type == "cuda":
    CLIENT_RESOURCES = {"num_gpus": 1}

curr_dir_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "ray_spill")
if not os.path.exists(curr_dir_path):
    os.makedirs(curr_dir_path)
RAY_ARGS = dict(
    _system_config={
        "object_spilling_config": json.dumps(
            {
                "type": "filesystem",
                "params": {
                    # Multiple directories can be specified to distribute
                    # IO across multiple mounted physical devices.
                    "directory_path": [curr_dir_path]
                },
            }
        )
    },
)
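# Illustrative sketch (assuming Flower's Ray-based simulation engine; client_fn
# and strategy are assumed to be defined elsewhere in the project):
# CLIENT_RESOURCES and RAY_ARGS would be forwarded when starting a local
# simulation, while SERVER_ADDRESS is what clients connect to in the real LAN
# deployment via fl.server.start_server / fl.client.start_numpy_client.
#
#   history = fl.simulation.start_simulation(
#       client_fn=client_fn,
#       num_clients=NUM_CLIENTS,
#       client_resources=CLIENT_RESOURCES,
#       ray_init_args=RAY_ARGS,
#       config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
#       strategy=strategy,
#   )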