-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathconfigs.py
94 lines (69 loc) · 2.02 KB
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
'''
==========
Date: July 5, 2022
Maintainer:
Xinyi Zhong ([email protected])
Xinchen Du ([email protected])
Zhiyuan Long ([email protected])
==========
Config class and CLI to set config
'''
import argparse
import math
from csv import DictReader
from dataclasses import dataclass, field
from typing import Optional

#from omegaconf import OmegaConf, MISSING
def parse_argv(argv=None):
    """Parse this module's command-line options.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is what
            the original zero-argument call did.

    Returns:
        argparse.Namespace with ``row`` (int, 1-based index of the
        parameter.csv data row to use) and ``processor`` (str, or ``None``
        when the option is omitted).

    Raises:
        SystemExit: raised by argparse when ``--row`` is missing or an
            argument cannot be parsed.
    """
    parser = argparse.ArgumentParser(prog='SC')
    # --row is mandatory: the module indexes parameter.csv with it at import
    # time, and a missing value would otherwise surface later as an opaque
    # TypeError from ``None - 1``.
    parser.add_argument("--row", type=int, required=True,
                        help="1-based row of parameter.csv to load")
    parser.add_argument("--processor", type=str)
    return parser.parse_args(argv)
args = parse_argv()

# All data rows of parameter.csv, each a dict keyed by the CSV header; the
# selected row supplies the defaults read by the config dataclasses.
# newline='' is what the csv module requires for files handed to its readers,
# so quoted fields containing line breaks are parsed correctly.
with open('parameter.csv', 'r', newline='') as read_obj:
    params = list(DictReader(read_obj))

# --row is 1-based (row 1 is the first data row after the header). Reject
# values outside the table explicitly: without this check, 0 or a negative
# number would silently wrap around to rows counted from the end of the file.
i = args.row
if i is None or not 1 <= i <= len(params):
    raise IndexError(
        f"--row must be between 1 and {len(params)} "
        f"(data rows in parameter.csv), got {i!r}"
    )
param = params[i - 1]
def get_argsrow():
    """Return the 1-based parameter.csv row index that was selected on the command line."""
    selected_row = i
    return selected_row
def get_argsprocessor():
    """Return the ``--processor`` value parsed at import time (``None`` if it was not given)."""
    parsed = args
    return parsed.processor
def get_neuron_shape(x):
    """Return the square grid shape ``(side, side)`` for ``x`` neurons.

    ``side`` is the integer square root of ``x`` rounded down, so a
    non-square count is truncated to the largest square grid that fits.

    Args:
        x: Non-negative neuron count.

    Returns:
        tuple: ``(side, side)`` with ``side * side <= x < (side + 1) ** 2``.

    Raises:
        ValueError: if ``x`` is negative (propagated from ``math.sqrt``).
    """
    side = int(math.sqrt(x))
    # math.sqrt works in double precision, so for large x the estimate can be
    # off by one (e.g. x = 2**54 + 2**28 yields 2**27 + 1, which is too big).
    # Nudge the estimate to the exact floor square root.
    while side * side > x:
        side -= 1
    while (side + 1) * (side + 1) <= x:
        side += 1
    return (side, side)
@dataclass
class KernelConfig:
    """Kernel hyper-parameters, defaulting to the selected parameter.csv row.

    ``ri``, ``re``, ``wi`` and ``we`` default to the integer values of the
    row's columns of the same name. ``leaky`` defaults to ``wi + we``,
    computed from the values this instance was actually built with.
    """
    # NOTE(review): the meaning of ri/re/wi/we is not visible in this file.
    ri: int = int(param['ri'])
    re: int = int(param['re'])
    wi: int = int(param['wi'])
    we: int = int(param['we'])
    # None means "derive from wi + we" (done in __post_init__). Previously the
    # sum was frozen when the class body ran, so KernelConfig(wi=...) kept a
    # stale leaky computed from the CSV defaults.
    leaky: Optional[int] = None

    def __post_init__(self):
        # Fill the derived default only when the caller did not pass leaky.
        if self.leaky is None:
            self.leaky = self.wi + self.we
@dataclass
class NDMConfig:
    """Placeholder section for NDM settings; it declares no fields yet."""
@dataclass
class ExperimentConfig:
    """Experiment settings: fixed component names plus per-row hyper-parameters.

    The CSV-backed defaults are evaluated once, when this class body runs,
    from the module-level ``param`` row selected by ``--row``; constructor
    arguments override them per instance. Field order below defines the
    positional constructor signature, so do not reorder.
    """
    # Hard-coded component identifiers. NOTE(review): presumably resolved to
    # implementations elsewhere in the project — confirm.
    loader_name: str = "unigram97"
    ndm_name: str = "l1ActDoubleDecker"
    dl_name: str = "gradientDescent"
    # Fixed input dimensionality. NOTE(review): appears to correspond to the
    # "97" in loader_name — keep the two in sync if either changes.
    input_dim: int = 97
    # parameter.csv stores a neuron count; get_neuron_shape turns it into a
    # square (side, side) grid, flooring counts that are not perfect squares.
    neuron_shape: tuple = get_neuron_shape(int(param['neuron_shape']))
    # Remaining defaults are the row's columns of the same name, cast to the
    # annotated type.
    gradient_steps: int = int(param['gradient_steps'])
    batch_size: int = int(param['batch_size'])
    lr_act: float = float(param['lr_act'])
    lr_codebook: float = float(param['lr_codebook'])
    l0_target: float = float(param['l0_target'])
    threshold: float = float(param['threshold'])
@dataclass
class Config:
    """Top-level configuration bundling the kernel, NDM and experiment sections.

    Attributes:
        kernel: Kernel hyper-parameters (a fresh KernelConfig per Config).
        ndm: NDM settings (a fresh NDMConfig per Config).
        exp: Experiment settings (a fresh ExperimentConfig per Config).
        srepr: Free-form tag for this configuration; ``None`` when unset.
    """
    # default_factory gives every Config its own section objects. A plain
    # instance default would be shared by all Config objects, and Python
    # 3.11+ rejects it at class creation because non-frozen dataclass
    # instances are unhashable ("mutable default ... use default_factory").
    kernel: KernelConfig = field(default_factory=KernelConfig)
    ndm: NDMConfig = field(default_factory=NDMConfig)
    exp: ExperimentConfig = field(default_factory=ExperimentConfig)
    srepr: Optional[str] = None
# Tag stored on the module-level configuration; kept as its own module-level
# name as well as being passed to Config.
srepr = "test"

# Module-level Config instance: all sections use their class defaults, only
# the tag is set explicitly.
cfg = Config(srepr=srepr)