Skip to content

Commit e159304

Browse files
authored
Add a parameter sweep results type (#4)
1 parent 46ca9c8 commit e159304

File tree

4 files changed

+98
-13
lines changed

4 files changed

+98
-13
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ exclude = '''
7070

7171
[tool.pylint.DESIGN]
7272
max-args = 10
73+
max-attributes = 20
7374

7475
[tool.pylint."MESSAGES CONTROL"]
7576
disable = "redefined-builtin" # Disable redefined builtin functions
Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,57 @@
11
"""Sweep Config."""
2-
from dataclasses import dataclass
2+
from dataclasses import dataclass, field
33

4+
from sparse_autoencoder.train.utils.results_dataclass import (
5+
convert_parameters_to_results_type,
6+
)
47
from sparse_autoencoder.train.utils.wandb_sweep_types import (
58
Method,
69
Parameter,
10+
Parameters,
711
WandbSweepConfig,
812
)
913

1014

1115
@dataclass
12-
class SweepParameterConfig(dict):
16+
class SweepParameterConfig(Parameters):
1317
"""Sweep Parameter Config."""
1418

15-
lr: Parameter[float] = Parameter(value=0.001)
19+
lr: Parameter[float] = field(default_factory=lambda: Parameter(value=0.001))
1620
"""Adam Learning Rate."""
1721

18-
adam_beta_1: Parameter[float] = Parameter(value=0.9)
22+
adam_beta_1: Parameter[float] = field(default_factory=lambda: Parameter(value=0.9))
1923
"""Adam Beta 1.
2024
2125
The exponential decay rate for the first moment estimates (mean) of the gradient.
2226
"""
2327

24-
adam_beta_2: Parameter[float] = Parameter(value=0.999)
28+
adam_beta_2: Parameter[float] = field(
29+
default_factory=lambda: Parameter(value=0.999)
30+
)
2531
"""Adam Beta 2.
2632
2733
The exponential decay rate for the second moment estimates (variance) of the gradient.
2834
"""
2935

30-
adam_epsilon: Parameter[float] = Parameter(value=1e-8)
36+
adam_epsilon: Parameter[float] = field(
37+
default_factory=lambda: Parameter(value=1e-8)
38+
)
3139
"""Adam Epsilon.
3240
3341
A small constant for numerical stability.
3442
"""
3543

36-
adam_weight_decay: Parameter[float] = Parameter(value=0)
44+
adam_weight_decay: Parameter[float] = field(
45+
default_factory=lambda: Parameter(value=0)
46+
)
3747
"""Adam Weight Decay.
3848
3949
Weight decay (L2 penalty).
4050
"""
4151

42-
l1_coefficient: Parameter[float] = Parameter(value=[0.001, 0.004, 0.006, 0.008, 1])
52+
l1_coefficient: Parameter[float] = field(
53+
default_factory=lambda: Parameter(value=[0.001, 0.004, 0.006, 0.008, 1])
54+
)
4355
"""L1 Penalty Coefficient.
4456
4557
The L1 penalty is the absolute sum of learned (hidden) activations, multiplied by this constant.
@@ -50,14 +62,19 @@ class SweepParameterConfig(dict):
5062
paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html).
5163
"""
5264

53-
width_multiplier: Parameter[int] = Parameter(value=8, min=1, max=256)
65+
width_multiplier: Parameter[int] = field(
66+
default_factory=lambda: Parameter(value=8, min=1, max=256)
67+
)
5468
"""Source-to-Trained Activations Width Multiplier."""
5569

5670

71+
SweepRunParameters = convert_parameters_to_results_type(SweepParameterConfig)
72+
73+
5774
@dataclass
5875
class SweepConfig(WandbSweepConfig):
5976
"""Sweep Config."""
6077

61-
method: Method = Method.bayes
62-
6378
parameters: SweepParameterConfig
79+
80+
method: Method = Method.bayes
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Results dataclass utilities."""
2+
from dataclasses import MISSING, field, fields, make_dataclass
3+
from typing import Any, Type, get_args, get_origin
4+
5+
from sparse_autoencoder.train.utils.wandb_sweep_types import Parameter
6+
7+
8+
def convert_parameters_to_results_type(config_dataclass: Type) -> Type:
9+
"""Convert Parameters to a Results Dataclass Type
10+
11+
Converts a :class:`sparse_autoencoder.train.utils.wandb_sweep_types.Parameters` dataclass type
12+
signature into a parameter results dataclass type signature.
13+
14+
Example:
15+
16+
>>> from dataclasses import dataclass, field
17+
>>> @dataclass
18+
... class SweepParameterConfig:
19+
...
20+
... lr: Parameter[float] = field(default_factory=lambda: Parameter(value=0.001))
21+
...
22+
... lr_list: Parameter[float] = field(
23+
... default_factory=lambda: Parameter(value=[0.002, 0.004])
24+
... )
25+
...
26+
>>> SweepParameterResults = convert_parameters_to_results_type(SweepParameterConfig)
27+
>>> SweepParameterResults.__annotations__['lr']
28+
<class 'float'>
29+
30+
>>> SweepParameterResults.__annotations__['lr_list']
31+
<class 'float'>
32+
33+
Args:
34+
config_dataclass: The config dataclass to convert.
35+
"""
36+
new_fields: list[tuple[str, Any, Any]] = []
37+
38+
for f in fields(config_dataclass):
39+
# Determine if the default should come from a default or a factory
40+
if f.default is not MISSING:
41+
default_value = f.default
42+
elif f.default_factory is not MISSING: # Use the default factory if provided
43+
default_value = field( # pylint: disable=invalid-field-call
44+
default_factory=f.default_factory
45+
)
46+
else:
47+
default_value = MISSING
48+
49+
# If the field is a Parameter, replace it with the contained type
50+
if get_origin(f.type) == Parameter:
51+
contained_type = get_args(f.type)[0]
52+
# If the contained type is a list, go one level deeper
53+
if get_origin(contained_type) == list:
54+
list_contained_type = get_args(contained_type)[0]
55+
new_fields.append(
56+
(f.name, list[list_contained_type], default_value) # type: ignore
57+
)
58+
else:
59+
new_fields.append((f.name, contained_type, default_value))
60+
else:
61+
new_fields.append((f.name, f.type, default_value))
62+
63+
# Create a new dataclass with the new fields
64+
return make_dataclass(config_dataclass.__name__ + "Results", new_fields)

sparse_autoencoder/train/utils/wandb_sweep_types.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,16 @@ class Parameter(Generic[ParamType]):
183183
parameters: dict[str, "Parameter"] | None = None
184184

185185

186+
Parameters = dict[str, Parameter]
187+
188+
186189
@dataclass
187190
class WandbSweepConfig:
188191
"""Weights & Biases Sweep Configuration."""
189192

190-
method: Method
193+
parameters: Parameters
191194

192-
parameters: dict[str, Parameter]
195+
method: Method
193196

194197
apiVersion: str | None = None
195198

0 commit comments

Comments
 (0)