config.py
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


class TBVariant(Enum):
    """See algo.trajectory_balance.TrajectoryBalance for details."""

    TB = 0
    SubTB1 = 1
    DB = 2
    SubTBMC = 3


@dataclass
class TBConfig:
    """Trajectory Balance config.

    Attributes
    ----------
    bootstrap_own_reward : bool
        Whether to bootstrap the reward with the own reward. (deprecated)
    epsilon : Optional[float]
        The epsilon parameter in log-flow smoothing (see paper)
    reward_loss_multiplier : float
        The multiplier for the reward loss when bootstrapping the reward. (deprecated)
    variant : TBVariant
        The loss variant. See algo.trajectory_balance.TrajectoryBalance for details.
    do_correct_idempotent : bool
        Whether to correct for idempotent actions
    do_parameterize_p_b : bool
        Whether to parameterize the P_B distribution (otherwise it is uniform)
    do_length_normalize : bool
        Whether to normalize the loss by the length of the trajectory
    subtb_max_len : int
        The maximum trajectory length, used to cache SubTB computation indices
    Z_learning_rate : float
        The learning rate for the logZ parameter (only relevant for loss variants that learn logZ)
    Z_lr_decay : float
        The learning rate decay for the logZ parameter (only relevant for loss variants that learn logZ)
    """

    bootstrap_own_reward: bool = False
    epsilon: Optional[float] = None
    min_entropy_alpha: Optional[float] = None  # not used
    softmax_temper: Optional[float] = None  # not used
    reward_loss_multiplier: float = 1.0
    variant: TBVariant = TBVariant.TB
    do_correct_idempotent: bool = True
    do_parameterize_p_b: bool = False
    do_length_normalize: bool = True
    subtb_max_len: int = 128
    Z_learning_rate: float = 1e-4
    Z_lr_decay: float = 50_000
    cum_subtb: bool = False


@dataclass
class MOQLConfig:
    """Multi-objective Q-learning config."""

    gamma: float = 1
    num_omega_samples: int = 32
    num_objectives: int = 2
    lambda_decay: int = 10_000
    penalty: float = -10


@dataclass
class A2CConfig:
    """Advantage Actor-Critic (A2C) config."""

    entropy: float = 0.01
    gamma: float = 1
    penalty: float = -10


@dataclass
class FMConfig:
    """Flow Matching config."""

    epsilon: float = 1e-38
    balanced_loss: bool = False
    leaf_coef: float = 10
    correct_idempotent: bool = True


@dataclass
class SQLConfig:
    """Soft Q-Learning config."""

    alpha: float = 0.01
    gamma: float = 1
    penalty: float = -10


@dataclass
class AlgoConfig:
    """Generic configuration for algorithms

    Attributes
    ----------
    method : str
        The name of the algorithm to use (e.g. "TB")
    global_batch_size : int
        The batch size for training
    max_len : int
        The maximum length of a trajectory
    max_nodes : int
        The maximum number of nodes in a generated graph
    max_edges : int
        The maximum number of edges in a generated graph
    illegal_action_logreward : float
        The log reward an agent gets for illegal actions
    offline_ratio : float
        The ratio of samples drawn from `self.training_data` during training. The rest is drawn from
        `self.sampling_model`
    valid_offline_ratio : float
        Same as `offline_ratio`, but for validation and `self.test_data`.
    offline_sampling_g_distribution : str
        In offline training, this selects the distribution P(x) from which x is sampled.
        Options = ["uniform", "log_rewards", "log_p", "loss_gfn", "error_gfn"]
    true_log_Z : float
        TODO
    use_true_log_Z : bool
        Only used in the offline setting, to control for the effects of learning log_Z
    l2_reg_log_Z_lambda : float
        TODO
    l1_reg_log_Z_lambda : float
        TODO
    flow_reg : bool
        TODO
    dir_model_pretrain_for_sampling : str
        TODO
    supervised_reward_predictor : str
        TODO
    alpha : float
        TODO
    train_random_action_prob : float
        The probability of taking a random action during training
    valid_random_action_prob : float
        The probability of taking a random action during validation
    valid_sample_cond_info : bool
        Whether to sample conditioning information during validation (if False, expects a validation set of cond_info)
    sampling_tau : float
        The EMA factor for the sampling model (theta_sampler = tau * theta_sampler + (1-tau) * theta)
    """

    method: str = "TB"
    global_batch_size: int = 64
    max_len: int = 128
    max_nodes: int = 128
    max_edges: int = 128
    illegal_action_logreward: float = -100
    offline_ratio: float = 0.5
    valid_offline_ratio: float = 1
    offline_sampling_g_distribution: Optional[str] = None
    use_true_log_Z: bool = False
    true_log_Z: Optional[float] = None
    l2_reg_log_Z_lambda: float = 0.0
    l1_reg_log_Z_lambda: float = 0.0
    flow_reg: bool = False
    dir_model_pretrain_for_sampling: Optional[str] = None
    supervised_reward_predictor: Optional[str] = None
    alpha: float = 0.0
    train_random_action_prob: float = 0.0
    valid_random_action_prob: float = 0.0
    valid_sample_cond_info: bool = True
    sampling_tau: float = 0.0
    # Use default_factory so each AlgoConfig gets its own sub-config instance
    # (a shared mutable default is rejected by dataclasses on Python 3.11+).
    tb: TBConfig = field(default_factory=TBConfig)
    moql: MOQLConfig = field(default_factory=MOQLConfig)
    a2c: A2CConfig = field(default_factory=A2CConfig)
    fm: FMConfig = field(default_factory=FMConfig)
    sql: SQLConfig = field(default_factory=SQLConfig)
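

# ---------------------------------------------------------------------------
# Minimal usage sketch: how these dataclasses compose, assuming only the
# definitions above. The field values are arbitrary examples for illustration,
# not recommended settings.
if __name__ == "__main__":
    # Configure the TB algorithm with the SubTB(1) loss variant; nested
    # sub-configs can be overridden field by field.
    cfg = AlgoConfig(
        method="TB",
        global_batch_size=128,
        max_len=128,
        tb=TBConfig(
            variant=TBVariant.SubTB1,
            subtb_max_len=128,  # used to cache SubTB computation indices (see docstring)
            Z_learning_rate=1e-3,
        ),
    )
    # The other algorithms keep their defaults and remain accessible.
    print(cfg.method, cfg.tb.variant.name, cfg.tb.Z_learning_rate)
    print(cfg.sql.alpha, cfg.fm.epsilon)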