Add a parameter sweep results type (#4)
alan-cooney authored Nov 2, 2023
1 parent 46ca9c8 commit e159304
Showing 4 changed files with 98 additions and 13 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -70,6 +70,7 @@ exclude = '''
 
 [tool.pylint.DESIGN]
 max-args = 10
+max-attributes = 20
 
 [tool.pylint."MESSAGES CONTROL"]
 disable = "redefined-builtin" # Disable redefined builtin functions
39 changes: 28 additions & 11 deletions sparse_autoencoder/train/sweep_config.py
@@ -1,45 +1,57 @@
 """Sweep Config."""
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
+from sparse_autoencoder.train.utils.results_dataclass import (
+    convert_parameters_to_results_type,
+)
 from sparse_autoencoder.train.utils.wandb_sweep_types import (
     Method,
     Parameter,
+    Parameters,
     WandbSweepConfig,
 )
 
 
 @dataclass
-class SweepParameterConfig(dict):
+class SweepParameterConfig(Parameters):
     """Sweep Parameter Config."""
 
-    lr: Parameter[float] = Parameter(value=0.001)
+    lr: Parameter[float] = field(default_factory=lambda: Parameter(value=0.001))
     """Adam Learning Rate."""
 
-    adam_beta_1: Parameter[float] = Parameter(value=0.9)
+    adam_beta_1: Parameter[float] = field(default_factory=lambda: Parameter(value=0.9))
     """Adam Beta 1.
     The exponential decay rate for the first moment estimates (mean) of the gradient.
     """
 
-    adam_beta_2: Parameter[float] = Parameter(value=0.999)
+    adam_beta_2: Parameter[float] = field(
+        default_factory=lambda: Parameter(value=0.999)
+    )
     """Adam Beta 2.
     The exponential decay rate for the second moment estimates (variance) of the gradient.
     """
 
-    adam_epsilon: Parameter[float] = Parameter(value=1e-8)
+    adam_epsilon: Parameter[float] = field(
+        default_factory=lambda: Parameter(value=1e-8)
+    )
     """Adam Epsilon.
     A small constant for numerical stability.
     """
 
-    adam_weight_decay: Parameter[float] = Parameter(value=0)
+    adam_weight_decay: Parameter[float] = field(
+        default_factory=lambda: Parameter(value=0)
+    )
     """Adam Weight Decay.
     Weight decay (L2 penalty).
     """
 
-    l1_coefficient: Parameter[float] = Parameter(value=[0.001, 0.004, 0.006, 0.008, 1])
+    l1_coefficient: Parameter[float] = field(
+        default_factory=lambda: Parameter(value=[0.001, 0.004, 0.006, 0.008, 1])
+    )
     """L1 Penalty Coefficient.
     The L1 penalty is the absolute sum of learned (hidden) activations, multiplied by this constant.
@@ -50,14 +62,19 @@ class SweepParameterConfig(dict):
     paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html).
     """
 
-    width_multiplier: Parameter[int] = Parameter(value=8, min=1, max=256)
+    width_multiplier: Parameter[int] = field(
+        default_factory=lambda: Parameter(value=8, min=1, max=256)
+    )
     """Source-to-Trained Activations Width Multiplier."""
 
 
+SweepRunParameters = convert_parameters_to_results_type(SweepParameterConfig)
+
+
 @dataclass
 class SweepConfig(WandbSweepConfig):
     """Sweep Config."""
 
-    method: Method = Method.bayes
-
     parameters: SweepParameterConfig
 
+    method: Method = Method.bayes
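
A note on the pattern above: the bare Parameter(...) defaults are replaced with
field(default_factory=...) because class-level dataclass defaults are single
shared objects, and dataclasses reject unhashable defaults outright (a
non-frozen, eq-generated dataclass has __hash__ = None). A minimal sketch of
both behaviours, using a hypothetical stand-in Parameter rather than the
repository's real class:

from dataclasses import dataclass, field


@dataclass
class Parameter:
    # Hypothetical stand-in for the repository's Parameter type.
    value: object = None


# A bare default raises at class-definition time, because an eq-generated
# dataclass instance has __hash__ = None and is treated as mutable:
#
#     @dataclass
#     class Broken:
#         lr: Parameter = Parameter(value=0.001)
#
#     ValueError: mutable default <class '...Parameter'> for field lr is not
#     allowed: use default_factory


@dataclass
class Works:
    # default_factory builds a fresh Parameter for each instance instead.
    lr: Parameter = field(default_factory=lambda: Parameter(value=0.001))


a, b = Works(), Works()
assert a.lr == b.lr       # equal values...
assert a.lr is not b.lr   # ...but separate objects per instance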
64 changes: 64 additions & 0 deletions sparse_autoencoder/train/utils/results_dataclass.py
@@ -0,0 +1,64 @@
+"""Results dataclass utilities."""
+from dataclasses import MISSING, field, fields, make_dataclass
+from typing import Any, Type, get_args, get_origin
+
+from sparse_autoencoder.train.utils.wandb_sweep_types import Parameter
+
+
+def convert_parameters_to_results_type(config_dataclass: Type) -> Type:
+    """Convert Parameters to a Results Dataclass Type.
+
+    Converts a :class:`sparse_autoencoder.train.utils.wandb_sweep_types.Parameters` dataclass type
+    signature into a parameter results dataclass type signature.
+
+    Example:
+        >>> from dataclasses import dataclass, field
+        >>> @dataclass
+        ... class SweepParameterConfig:
+        ...
+        ...     lr: Parameter[float] = field(default_factory=lambda: Parameter(value=0.001))
+        ...
+        ...     lr_list: Parameter[float] = field(
+        ...         default_factory=lambda: Parameter(value=[0.002, 0.004])
+        ...     )
+        ...
+        >>> SweepParameterResults = convert_parameters_to_results_type(SweepParameterConfig)
+        >>> SweepParameterResults.__annotations__['lr']
+        <class 'float'>
+        >>> SweepParameterResults.__annotations__['lr_list']
+        <class 'float'>
+
+    Args:
+        config_dataclass: The config dataclass to convert.
+    """
+    new_fields: list[tuple[str, Any, Any]] = []
+
+    for f in fields(config_dataclass):
+        # Determine if the default should come from a default or a factory
+        if f.default is not MISSING:
+            default_value = f.default
+        elif f.default_factory is not MISSING:  # Use the default factory if provided
+            default_value = field(  # pylint: disable=invalid-field-call
+                default_factory=f.default_factory
+            )
+        else:
+            default_value = MISSING
+
+        # If the field is a Parameter, replace it with the contained type
+        if get_origin(f.type) == Parameter:
+            contained_type = get_args(f.type)[0]
+            # If the contained type is a list, go one level deeper
+            if get_origin(contained_type) == list:
+                list_contained_type = get_args(contained_type)[0]
+                new_fields.append(
+                    (f.name, list[list_contained_type], default_value)  # type: ignore
+                )
+            else:
+                new_fields.append((f.name, contained_type, default_value))
+        else:
+            new_fields.append((f.name, f.type, default_value))
+
+    # Create a new dataclass with the new fields
+    return make_dataclass(config_dataclass.__name__ + "Results", new_fields)
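
The conversion relies on three standard-library pieces: typing.get_origin and
typing.get_args to unwrap Parameter[float] down to float, and
dataclasses.make_dataclass to assemble the results type from
(name, type, default) triples. A self-contained sketch of that mechanism, with
a hypothetical Box generic standing in for Parameter:

from dataclasses import make_dataclass
from typing import Generic, TypeVar, get_args, get_origin

T = TypeVar("T")


class Box(Generic[T]):
    # Hypothetical stand-in for the Parameter generic.
    pass


# Unwrapping a parameterised generic, as the utility does for Parameter[...]:
annotation = Box[float]
assert get_origin(annotation) is Box      # the unsubscripted class
assert get_args(annotation) == (float,)   # the contained type

# Assembling a new dataclass type from (name, type, default) triples:
DemoResults = make_dataclass("DemoResults", [("lr", float, 0.001)])
print(DemoResults.__annotations__["lr"])  # <class 'float'>
print(DemoResults(lr=0.004))              # DemoResults(lr=0.004)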
7 changes: 5 additions & 2 deletions sparse_autoencoder/train/utils/wandb_sweep_types.py
@@ -183,13 +183,16 @@ class Parameter(Generic[ParamType]):
     parameters: dict[str, "Parameter"] | None = None
 
 
+Parameters = dict[str, Parameter]
+
+
 @dataclass
 class WandbSweepConfig:
     """Weights & Biases Sweep Configuration."""
 
-    method: Method
+    parameters: Parameters
 
-    parameters: dict[str, Parameter]
+    method: Method
 
     apiVersion: str | None = None
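
Reordering the fields of WandbSweepConfig is what lets the SweepConfig subclass
above give method a default (Method.bayes) while parameters stays required:
dataclasses preserve inherited field order, and a defaulted field may not
precede a required one. A standalone sketch with simplified stand-in types:

from dataclasses import dataclass


@dataclass
class Base:
    parameters: dict  # required, so it must precede any defaulted field
    method: str


@dataclass
class Sub(Base):
    method: str = "bayes"  # overriding with a default is legal in this order


print(Sub(parameters={"lr": 0.001}))  # Sub(parameters={'lr': 0.001}, method='bayes')

# Had Base declared method before parameters, the override above would fail
# with "TypeError: non-default argument 'parameters' follows default argument",
# because dataclasses keep the inherited declaration order.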
