Skip to content

Commit

Permalink
Fix the neuron resampler approach (#141)
Browse files Browse the repository at this point in the history
This was normalising across an incorrect dimension and not firing at the correct point.
  • Loading branch information
alan-cooney authored Dec 9, 2023
1 parent 15378ab commit be08405
Show file tree
Hide file tree
Showing 30 changed files with 1,083 additions and 738 deletions.
51 changes: 32 additions & 19 deletions docs/content/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,18 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"# Check if we're in Colab\n",
"try:\n",
Expand All @@ -62,22 +71,23 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from sparse_autoencoder import (\n",
" sweep,\n",
" SweepConfig,\n",
" ActivationResamplerHyperparameters,\n",
" Hyperparameters,\n",
" SourceModelHyperparameters,\n",
" Parameter,\n",
" SourceDataHyperparameters,\n",
" Method,\n",
" LossHyperparameters,\n",
" Method,\n",
" OptimizerHyperparameters,\n",
" Parameter,\n",
" SourceDataHyperparameters,\n",
" SourceModelHyperparameters,\n",
" sweep,\n",
" SweepConfig,\n",
")\n",
"import wandb\n",
"\n",
Expand All @@ -103,7 +113,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand All @@ -112,28 +122,31 @@
"SweepConfig(parameters=Hyperparameters(\n",
" source_data=SourceDataHyperparameters(dataset_path=Parameter(value=NeelNanda/c4-code-tokenized-2b), context_size=Parameter(value=128))\n",
" source_model=SourceModelHyperparameters(name=Parameter(value=gelu-2l), hook_site=Parameter(value=mlp_out), hook_layer=Parameter(value=0), hook_dimension=Parameter(value=512), dtype=Parameter(value=float32))\n",
" activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_resamples=Parameter(value=4), n_steps_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=819200), dead_neuron_threshold=Parameter(value=0.0))\n",
" autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=4))\n",
" loss=LossHyperparameters(l1_coefficient=Parameter(values=[0.001, 0.0001, 1e-05]))\n",
" optimizer=OptimizerHyperparameters(lr=Parameter(values=[0.001, 0.0001, 1e-05]), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n",
" pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=12), train_batch_size=Parameter(value=4096), max_store_size=Parameter(value=3145728), max_activations=Parameter(value=2000000000), checkpoint_frequency=Parameter(value=100000000), validation_frequency=Parameter(value=314572800), validation_number_activations=Parameter(value=1024))\n",
" activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=197885952), max_n_resamples=Parameter(value=4), n_activations_activity_collate=Parameter(value=98942976), resample_dataset_size=Parameter(value=819200), threshold_is_dead_portion_fires=Parameter(value=1e-06))\n",
" autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=2))\n",
" loss=LossHyperparameters(l1_coefficient=Parameter(max=0.01, min=0.004))\n",
" optimizer=OptimizerHyperparameters(lr=Parameter(max=0.001, min=1e-05), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n",
" pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=16), train_batch_size=Parameter(value=8192), max_store_size=Parameter(value=2998272), max_activations=Parameter(value=1999847424), checkpoint_frequency=Parameter(value=47972352), validation_frequency=Parameter(value=99999744), validation_number_activations=Parameter(value=8192))\n",
" random_seed=Parameter(value=49)\n",
"), method=<Method.RANDOM: 'random'>, metric=Metric(name=total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None, runcap=None)"
"), method=<Method.RANDOM: 'random'>, metric=Metric(name=train/loss/total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None)"
]
},
"execution_count": 3,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sweep_config = SweepConfig(\n",
" parameters=Hyperparameters(\n",
" activation_resampler=ActivationResamplerHyperparameters(\n",
" threshold_is_dead_portion_fires=Parameter(1e-6),\n",
" ),\n",
" loss=LossHyperparameters(\n",
" l1_coefficient=Parameter(values=[1e-3, 1e-4, 1e-5]),\n",
" l1_coefficient=Parameter(max=1e-2, min=4e-3),\n",
" ),\n",
" optimizer=OptimizerHyperparameters(\n",
" lr=Parameter(values=[1e-3, 1e-4, 1e-5]),\n",
" lr=Parameter(max=1e-3, min=1e-5),\n",
" ),\n",
" source_model=SourceModelHyperparameters(\n",
" name=Parameter(\"gelu-2l\"),\n",
Expand Down
26 changes: 22 additions & 4 deletions docs/content/flexible_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,10 @@
"outputs": [],
"source": [
"activation_resampler = ActivationResampler(\n",
" resample_interval=10_000, n_steps_collate=10_000, max_resamples=5\n",
" resample_interval=10_000,\n",
" n_activations_activity_collate=10_000,\n",
" max_n_resamples=5,\n",
" n_learned_features=autoencoder.n_learned_features,\n",
")"
]
},
Expand All @@ -400,9 +403,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2fe4955deca9463dbed606c9452d518e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"source_data = PreTokenizedDataset(\n",
" dataset_path=\"NeelNanda/c4-code-tokenized-2b\", context_size=int(hyperparameters[\"context_size\"])\n",
Expand All @@ -429,7 +447,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand Down
Loading

0 comments on commit be08405

Please sign in to comment.