diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index f07c45ea..88080174 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -10,7 +10,7 @@ ARG USER_GID=$USER_UID # https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user RUN groupadd --gid $USER_GID $USERNAME \ && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \ - && usermod -a -G video user \ + && usermod -a -G video user \ && apt-get update \ && apt-get install -y sudo \ && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \ diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index f3d5961e..88607ecd 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,5 +1,3 @@ -// For format details, see https://aka.ms/devcontainer.json. For config options, see the README at: -// https://github.com/microsoft/vscode-dev-containers/tree/v0.238.1/containers/python-3 { "name": "Python 3", "build": { @@ -12,46 +10,6 @@ "--gpus", "all" ], - // Configure tool-specific properties. - "customizations": { - // Configure properties specific to VS Code. - "vscode": { - // Set *default* container specific settings.json values on container create. - "settings": { - "mypy.dmypyExecutable": "dmypy" - }, - // Add the IDs of extensions you want installed when the container is created. - "extensions": [ - "christian-kohler.path-intellisense", - "davidanson.vscode-markdownlint", - "donjayamanne.githistory", - "donjayamanne.python-extension-pack", - "github.copilot", - "github.vscode-github-actions", - "github.vscode-pull-request-github", - "ionutvmi.path-autocomplete", - "matangover.mypy", - "mikoz.autoflake-extension", - "ms-python.black-formatter", - "ms-python.isort", - "ms-python.pylint", - "ms-python.python", - "ms-python.vscode-pylance", - "ms-toolsai.jupyter-keymap", - "ms-toolsai.jupyter-renderers", - "ms-toolsai.jupyter", - "ms-vsliveshare.vsliveshare-pack", - "njpwerner.autodocstring", - "richie5um2.vscode-sort-json", - "stkb.rewrap", - "streetsidesoftware.code-spell-checker-british-english", - "streetsidesoftware.code-spell-checker", - "tushortz.python-extended-snippets", - "yzhang.markdown-all-in-one" - ] - } - }, "containerUser": "user", - // Install any dependencies "postCreateCommand": "poetry env use 3.11 && poetry install --with dev,jupyter" -} \ No newline at end of file +} diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 99893a21..60b4d6d4 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -2,7 +2,7 @@ name: Checks on: push: - branches: + branches: - main paths-ignore: - '.devcontainer/**' @@ -10,7 +10,7 @@ on: - '.gitignore' - 'README.md' pull_request: - branches: + branches: - main paths-ignore: - '.devcontainer/**' @@ -35,7 +35,7 @@ jobs: - "3.10" - "3.11" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install Poetry uses: snok/install-poetry@v1 - name: Set up Python @@ -43,13 +43,16 @@ jobs: with: python-version: ${{ matrix.python-version }} cache: 'poetry' + allow-prereleases: true - name: Install dependencies run: poetry install --with dev - - name: Pytest - run: poetry run pytest - name: Pyright type check run: poetry run pyright - - name: Ruff (lint & format) - run: poetry run ruff --check sparse_autoencoder + - name: Ruff lint + run: poetry run ruff check sparse_autoencoder --output-format=github + - name: Ruff format + run: poetry run ruff format sparse_autoencoder --check + - name: Pytest + run: poetry run pytest - 
name: Build check - run: poetry build \ No newline at end of file + run: poetry build diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9bd999e8..0be2934c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Parse semver string - id: semver_parser + id: semver_parser uses: booxmedialtd/ws-action-parse-semver@v1.4.7 with: input_string: ${{ github.event.release.tag_name }} @@ -51,4 +51,4 @@ jobs: - name: Publish run: poetry publish env: - POETRY_PYPI_TOKEN_PYPI: ${{ secrets.POETRY_PYPI_TOKEN }} \ No newline at end of file + POETRY_PYPI_TOKEN_PYPI: ${{ secrets.POETRY_PYPI_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..330cc3d1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,36 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-json + - id: check-added-large-files + - id: check-merge-conflict + - id: check-symlinks + - id: destroyed-symlinks + - id: detect-private-key + - repo: local + hooks: + - id: ruff_lint + name: Ruff Lint + entry: poetry run ruff check sparse_autoencoder + language: system + types: [python] + require_serial: true + - id: ruff_format + name: Ruff Format + entry: poetry run ruff format sparse_autoencoder --check + language: system + types: [python] + require_serial: true + - id: typecheck + name: Pyright Type Check + entry: poetry run pyright + language: system + types: [python] + require_serial: true diff --git a/.vscode/cspell.json b/.vscode/cspell.json index 69b6e0c3..7f6dd48d 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -2,23 +2,28 @@ "language": "en,en-GB", "words": [ "allclose", + "astroid", "autocast", - "Autoencoder", + "autoencoder", "categoricalwprobabilities", "circuitsvis", "colab", "cuda", "cudnn", + "davidanson", "devcontainer", "devel", + "dmypy", "docstrings", + "donjayamanne", "dtype", "dunder", "earlyterminate", "einops", "endoftext", "gelu", - "Hobbhahn", + "githistory", + "hobbhahn", "hyperband", "imageuri", "imputewhilerunning", @@ -41,6 +46,7 @@ "nelement", "numel", "optim", + "penality", "polysemantic", "polysemantically", "pyproject", @@ -59,6 +65,7 @@ "sharded", "tqdm", "transformer_lens", + "typecheck", "uncopyrighted", "unsqueeze", "venv", diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 917ff3b9..60684bab 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -21,4 +21,4 @@ "kevinrose.vsc-python-indent", "donjayamanne.python-environment-manager" ] -} \ No newline at end of file +} diff --git a/.vscode/settings.json b/.vscode/settings.json index 07416e1c..047ff8ec 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,14 +1,44 @@ { - "rewrap.autoWrap.enabled": true, + "[jsonc]": { + "editor.defaultFormatter": "vscode.json-language-features" + }, + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff" + }, + "[toml]": { + "editor.defaultFormatter": "tamasfe.even-better-toml" + }, "editor.codeActionsOnSave": { "source.fixAll.eslint": true, "source.organizeImports": true }, - "rewrap.reformat": false, "editor.formatOnSave": true, - "python.testing.pytestEnabled": true, - "rewrap.wrappingColumn": 100, + 
"evenBetterToml.formatter.alignComments": true, + "evenBetterToml.formatter.alignEntries": true, + "evenBetterToml.formatter.allowedBlankLines": 2, + "evenBetterToml.formatter.arrayAutoCollapse": true, + "evenBetterToml.formatter.arrayAutoExpand": true, + "evenBetterToml.formatter.arrayTrailingComma": true, + "evenBetterToml.formatter.columnWidth": 100, + "evenBetterToml.formatter.compactArrays": true, + "evenBetterToml.formatter.compactEntries": true, + "evenBetterToml.formatter.compactInlineTables": true, + "evenBetterToml.formatter.indentEntries": true, + "evenBetterToml.formatter.indentTables": true, + "evenBetterToml.formatter.inlineTableExpand": false, + "evenBetterToml.formatter.reorderArrays": true, + "evenBetterToml.formatter.reorderKeys": true, + "evenBetterToml.formatter.trailingNewline": true, "notebook.formatOnCellExecution": true, "notebook.formatOnSave.enabled": true, - "editor.defaultFormatter": "charliermarsh.ruff", -} \ No newline at end of file + "python.analysis.autoFormatStrings": true, + "python.analysis.autoImportCompletions": true, + "python.analysis.inlayHints.functionReturnTypes": false, + "python.analysis.typeCheckingMode": "basic", + "python.languageServer": "Pylance", + "python.terminal.activateEnvInCurrentTerminal": true, + "python.testing.pytestEnabled": true, + "rewrap.autoWrap.enabled": true, + "rewrap.reformat": true, + "rewrap.wrappingColumn": 100 +} diff --git a/demo.ipynb b/demo.ipynb index 720b2e2c..ce50380b 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -16,13 +16,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Autoreload\n", - "# %load_ext autoreload\n", - "# %autoreload 2\n", + "%load_ext autoreload\n", + "%autoreload 2\n", "\n", "from sparse_autoencoder import (\n", " SparseAutoencoder,\n", @@ -31,6 +31,7 @@ " create_src_dataloader,\n", ")\n", "from transformer_lens import HookedTransformer\n", + "from transformer_lens.utils import get_device\n", "from sparse_autoencoder.src_data.datasets.neel_c4_tokenized import (\n", " collate_neel_c4_tokenized,\n", ")\n", @@ -39,6 +40,15 @@ "import os" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "device = get_device()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -48,15 +58,31 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "129121d84045498ea87783e8c7e32e5a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/28 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /Users/alan/Documents/Repos/sparse_autoencoder/wandb/run-20231104_163710-v4rzv049" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run toasty-jazz-13 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/alan-cooney/sparse-autoencoder" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/alan-cooney/sparse-autoencoder/runs/v4rzv049" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error in callback > (for post_run_cell), with arguments args ( result=>,),kwargs {}:\n" + ] + }, + { + "ename": "TypeError", + "evalue": "_WandbInit._pause_backend() takes 1 positional argument but 2 were given", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;31mTypeError\u001b[0m: _WandbInit._pause_backend() takes 1 positional argument but 2 were given" + ] + } + ], "source": [ - "# Disable TOKENIZERS_PARALLELISM warning\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" + "wandb.init(project=\"sparse-autoencoder\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "183144c8bf514fdf899ae1370411c41f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Generate/Train Cycles: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "52050dd0401149a0b3fde019b5423e3b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Generate Activations: 0%| | 0/1966080 [00:00 1\u001b[0m pipeline(\n\u001b[1;32m 2\u001b[0m src_model\u001b[39m=\u001b[39;49msrc_model,\n\u001b[1;32m 3\u001b[0m src_model_activation_hook_point\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mblocks.0.mlp.hook_post\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 4\u001b[0m src_model_activation_layer\u001b[39m=\u001b[39;49m\u001b[39m0\u001b[39;49m,\n\u001b[1;32m 5\u001b[0m src_dataloader\u001b[39m=\u001b[39;49msrc_dataloader,\n\u001b[1;32m 6\u001b[0m activation_store\u001b[39m=\u001b[39;49mstore,\n\u001b[1;32m 7\u001b[0m num_activations_before_training\u001b[39m=\u001b[39;49mmax_items,\n\u001b[1;32m 8\u001b[0m autoencoder\u001b[39m=\u001b[39;49mautoencoder,\n\u001b[1;32m 9\u001b[0m device\u001b[39m=\u001b[39;49mdevice,\n\u001b[1;32m 10\u001b[0m )\n", + "File \u001b[0;32m~/Documents/Repos/sparse_autoencoder/sparse_autoencoder/train/pipeline.py:88\u001b[0m, in \u001b[0;36mpipeline\u001b[0;34m(src_model, src_model_activation_hook_point, src_model_activation_layer, src_dataloader, activation_store, num_activations_before_training, autoencoder, sweep_parameters, device)\u001b[0m\n\u001b[1;32m 82\u001b[0m dataloader \u001b[39m=\u001b[39m DataLoader(\n\u001b[1;32m 83\u001b[0m activation_store,\n\u001b[1;32m 84\u001b[0m batch_size\u001b[39m=\u001b[39m\u001b[39m8192\u001b[39m,\n\u001b[1;32m 85\u001b[0m )\n\u001b[1;32m 87\u001b[0m \u001b[39m# Train the autoencoder\u001b[39;00m\n\u001b[0;32m---> 88\u001b[0m train_autoencoder(\n\u001b[1;32m 89\u001b[0m 
activations_dataloader\u001b[39m=\u001b[39;49mdataloader,\n\u001b[1;32m 90\u001b[0m autoencoder\u001b[39m=\u001b[39;49mautoencoder,\n\u001b[1;32m 91\u001b[0m optimizer\u001b[39m=\u001b[39;49moptimizer,\n\u001b[1;32m 92\u001b[0m sweep_parameters\u001b[39m=\u001b[39;49msweep_parameters,\n\u001b[1;32m 93\u001b[0m device\u001b[39m=\u001b[39;49mdevice,\n\u001b[1;32m 94\u001b[0m )\n\u001b[1;32m 96\u001b[0m \u001b[39m# Empty the store so we can fill it up again\u001b[39;00m\n\u001b[1;32m 97\u001b[0m activation_store\u001b[39m.\u001b[39mempty()\n", + "File \u001b[0;32m~/Documents/Repos/sparse_autoencoder/sparse_autoencoder/train/train_autoencoder.py:73\u001b[0m, in \u001b[0;36mtrain_autoencoder\u001b[0;34m(activations_dataloader, autoencoder, optimizer, sweep_parameters, log_interval, device)\u001b[0m\n\u001b[1;32m 65\u001b[0m total_loss \u001b[39m=\u001b[39m sae_training_loss(\n\u001b[1;32m 66\u001b[0m reconstruction_loss_mse,\n\u001b[1;32m 67\u001b[0m l1_loss_learned_activations,\n\u001b[1;32m 68\u001b[0m sweep_parameters\u001b[39m.\u001b[39ml1_coefficient,\n\u001b[1;32m 69\u001b[0m )\n\u001b[1;32m 70\u001b[0m \u001b[39m# TODO: Log dead neurons metric (get_frequencies in Neel's code)\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \n\u001b[1;32m 72\u001b[0m \u001b[39m# Backwards pass\u001b[39;00m\n\u001b[0;32m---> 73\u001b[0m total_loss\u001b[39m.\u001b[39;49mbackward()\n\u001b[1;32m 75\u001b[0m \u001b[39m# TODO: Make decoder weights and grad unit norm\u001b[39;00m\n\u001b[1;32m 77\u001b[0m optimizer\u001b[39m.\u001b[39mstep()\n", + "File \u001b[0;32m/opt/homebrew/lib/python3.11/site-packages/torch/_tensor.py:492\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[39mif\u001b[39;00m has_torch_function_unary(\u001b[39mself\u001b[39m):\n\u001b[1;32m 483\u001b[0m \u001b[39mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 484\u001b[0m Tensor\u001b[39m.\u001b[39mbackward,\n\u001b[1;32m 485\u001b[0m (\u001b[39mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 490\u001b[0m inputs\u001b[39m=\u001b[39minputs,\n\u001b[1;32m 491\u001b[0m )\n\u001b[0;32m--> 492\u001b[0m torch\u001b[39m.\u001b[39;49mautograd\u001b[39m.\u001b[39;49mbackward(\n\u001b[1;32m 493\u001b[0m \u001b[39mself\u001b[39;49m, gradient, retain_graph, create_graph, inputs\u001b[39m=\u001b[39;49minputs\n\u001b[1;32m 494\u001b[0m )\n", + "File \u001b[0;32m/opt/homebrew/lib/python3.11/site-packages/torch/autograd/__init__.py:251\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 246\u001b[0m retain_graph \u001b[39m=\u001b[39m create_graph\n\u001b[1;32m 248\u001b[0m \u001b[39m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 249\u001b[0m \u001b[39m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 250\u001b[0m \u001b[39m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 251\u001b[0m Variable\u001b[39m.\u001b[39;49m_execution_engine\u001b[39m.\u001b[39;49mrun_backward( \u001b[39m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 252\u001b[0m tensors,\n\u001b[1;32m 253\u001b[0m grad_tensors_,\n\u001b[1;32m 254\u001b[0m retain_graph,\n\u001b[1;32m 255\u001b[0m create_graph,\n\u001b[1;32m 256\u001b[0m inputs,\n\u001b[1;32m 257\u001b[0m 
allow_unreachable\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 258\u001b[0m accumulate_grad\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 259\u001b[0m )\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], "source": [ "pipeline(\n", " src_model=src_model,\n", @@ -141,6 +377,7 @@ " activation_store=store,\n", " num_activations_before_training=max_items,\n", " autoencoder=autoencoder,\n", + " device=device,\n", ")" ] } diff --git a/poetry.lock b/poetry.lock index eba09a7b..11596893 100644 --- a/poetry.lock +++ b/poetry.lock @@ -270,20 +270,6 @@ types-python-dateutil = ">=2.8.10" doc = ["doc8", "sphinx (>=7.0.0)", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx_rtd_theme (>=1.3.0)"] test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", "pytest-mock", "pytz (==2021.1)", "simplejson (==3.*)"] -[[package]] -name = "astroid" -version = "3.0.1" -description = "An abstract syntax tree for Python with inference support." -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "astroid-3.0.1-py3-none-any.whl", hash = "sha256:7d5895c9825e18079c5aeac0572bc2e4c83205c95d416e0b4fee8bc361d2d9ca"}, - {file = "astroid-3.0.1.tar.gz", hash = "sha256:86b0bb7d7da0be1a7c4aedb7974e391b32d4ed89e33de6ed6902b4b15c97577e"}, -] - -[package.dependencies] -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} - [[package]] name = "asttokens" version = "2.4.1" @@ -491,6 +477,17 @@ files = [ [package.dependencies] pycparser = "*" +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "charset-normalizer" version = "3.3.2" @@ -806,6 +803,17 @@ files = [ [package.extras] graph = ["objgraph (>=1.7.2)"] +[[package]] +name = "distlib" +version = "0.3.7" +description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.7-py2.py3-none-any.whl", hash = "sha256:2e24928bc811348f0feb63014e97aaae3037f2cf48712d51ae61df7fd6075057"}, + {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"}, +] + [[package]] name = "docker-pycreds" version = "0.4.0" @@ -1084,6 +1092,20 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jed torch = ["torch"] typing = ["pydantic (<2.0)", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] +[[package]] +name = "identify" +version = "2.5.31" +description = "File identification library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "identify-2.5.31-py2.py3-none-any.whl", hash = "sha256:90199cb9e7bd3c5407a9b7e81b4abec4bb9d249991c79439ec8af740afc6293d"}, + {file = "identify-2.5.31.tar.gz", hash = "sha256:7736b3c7a28233637e3c36550646fc6389bedd74ae84cb788200cc8e2dd60b75"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "idna" version = "3.4" @@ -1222,23 +1244,6 @@ files = [ [package.dependencies] arrow = ">=0.15.0" -[[package]] -name = "isort" -version = "5.12.0" -description = "A Python utility / library to sort Python imports." 
-optional = false -python-versions = ">=3.8.0" -files = [ - {file = "isort-5.12.0-py3-none-any.whl", hash = "sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6"}, - {file = "isort-5.12.0.tar.gz", hash = "sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504"}, -] - -[package.extras] -colors = ["colorama (>=0.4.3)"] -pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"] -plugins = ["setuptools"] -requirements-deprecated-finder = ["pip-api", "pipreqs"] - [[package]] name = "jaxtyping" version = "0.2.23" @@ -1724,17 +1729,6 @@ files = [ [package.dependencies] traitlets = "*" -[[package]] -name = "mccabe" -version = "0.7.0" -description = "McCabe checker, plugin for flake8" -optional = false -python-versions = ">=3.6" -files = [ - {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, - {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, -] - [[package]] name = "mdurl" version = "0.1.2" @@ -2347,6 +2341,17 @@ files = [ qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] testing = ["docopt", "pytest (<6.0.0)"] +[[package]] +name = "pastel" +version = "0.2.1" +description = "Bring colors to your terminal." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364"}, + {file = "pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d"}, +] + [[package]] name = "pathtools" version = "0.1.2" @@ -2416,6 +2421,42 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "poethepoet" +version = "0.24.2" +description = "A task runner that works well with poetry." +optional = false +python-versions = ">=3.8" +files = [ + {file = "poethepoet-0.24.2-py3-none-any.whl", hash = "sha256:affaf7669542f54df05ed1e2be3d24028c9f4bd2ea514813fae7f01c5ca6e686"}, + {file = "poethepoet-0.24.2.tar.gz", hash = "sha256:f600ecdbf58b474f7dba273060b194242566ffccb41d75f5b0d1cb8f5aa8bf2e"}, +] + +[package.dependencies] +pastel = ">=0.2.1,<0.3.0" +tomli = ">=1.2.2" + +[package.extras] +poetry-plugin = ["poetry (>=1.0,<2.0)"] + +[[package]] +name = "pre-commit" +version = "3.5.0" +description = "A framework for managing and maintaining multi-language pre-commit hooks." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "pre_commit-3.5.0-py2.py3-none-any.whl", hash = "sha256:841dc9aef25daba9a0238cd27984041fa0467b4199fc4852e27950664919f660"}, + {file = "pre_commit-3.5.0.tar.gz", hash = "sha256:5804465c675b659b0862f07907f96295d490822a450c4c40e747d0b1c6ebcb32"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "prometheus-client" version = "0.18.0" @@ -2590,35 +2631,6 @@ files = [ [package.extras] plugins = ["importlib-metadata"] -[[package]] -name = "pylint" -version = "3.0.2" -description = "python code static checker" -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "pylint-3.0.2-py3-none-any.whl", hash = "sha256:60ed5f3a9ff8b61839ff0348b3624ceeb9e6c2a92c514d81c9cc273da3b6bcda"}, - {file = "pylint-3.0.2.tar.gz", hash = "sha256:0d4c286ef6d2f66c8bfb527a7f8a629009e42c99707dec821a03e1b51a4c1496"}, -] - -[package.dependencies] -astroid = ">=3.0.1,<=3.1.0-dev0" -colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} -dill = [ - {version = ">=0.2", markers = "python_version < \"3.11\""}, - {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, - {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, -] -isort = ">=4.2.5,<6" -mccabe = ">=0.6,<0.8" -platformdirs = ">=2.2.0" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -tomlkit = ">=0.10.1" - -[package.extras] -spelling = ["pyenchant (>=3.2,<4.0)"] -testutils = ["gitpython (>3)"] - [[package]] name = "pyright" version = "1.1.334" @@ -3819,17 +3831,6 @@ files = [ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] -[[package]] -name = "tomlkit" -version = "0.12.2" -description = "Style preserving TOML library" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tomlkit-0.12.2-py3-none-any.whl", hash = "sha256:eeea7ac7563faeab0a1ed8fe12c2e5a51c61f933f2502f7e9db0241a65163ad0"}, - {file = "tomlkit-0.12.2.tar.gz", hash = "sha256:df32fab589a81f0d7dc525a4267b6d7a64ee99619cbd1eeb0fae32c1dd426977"}, -] - [[package]] name = "torch" version = "2.1.0" @@ -4139,6 +4140,26 @@ secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17. 
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "virtualenv" +version = "20.24.6" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.24.6-py3-none-any.whl", hash = "sha256:520d056652454c5098a00c0f073611ccbea4c79089331f60bf9d7ba247bb7381"}, + {file = "virtualenv-20.24.6.tar.gz", hash = "sha256:02ece4f56fbf939dbbc33c0715159951d6bf14aaf5457b092e4548e1382455af"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<4" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [[package]] name = "wandb" version = "0.15.12" @@ -4445,65 +4466,7 @@ files = [ idna = ">=2.0" multidict = ">=4.0" -[[package]] -name = "zstandard" -version = "0.21.0" -description = "Zstandard bindings for Python" -optional = false -python-versions = ">=3.7" -files = [ - {file = "zstandard-0.21.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:649a67643257e3b2cff1c0a73130609679a5673bf389564bc6d4b164d822a7ce"}, - {file = "zstandard-0.21.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:144a4fe4be2e747bf9c646deab212666e39048faa4372abb6a250dab0f347a29"}, - {file = "zstandard-0.21.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b72060402524ab91e075881f6b6b3f37ab715663313030d0ce983da44960a86f"}, - {file = "zstandard-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8257752b97134477fb4e413529edaa04fc0457361d304c1319573de00ba796b1"}, - {file = "zstandard-0.21.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c053b7c4cbf71cc26808ed67ae955836232f7638444d709bfc302d3e499364fa"}, - {file = "zstandard-0.21.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2769730c13638e08b7a983b32cb67775650024632cd0476bf1ba0e6360f5ac7d"}, - {file = "zstandard-0.21.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7d3bc4de588b987f3934ca79140e226785d7b5e47e31756761e48644a45a6766"}, - {file = "zstandard-0.21.0-cp310-cp310-win32.whl", hash = "sha256:67829fdb82e7393ca68e543894cd0581a79243cc4ec74a836c305c70a5943f07"}, - {file = "zstandard-0.21.0-cp310-cp310-win_amd64.whl", hash = "sha256:e6048a287f8d2d6e8bc67f6b42a766c61923641dd4022b7fd3f7439e17ba5a4d"}, - {file = "zstandard-0.21.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7f2afab2c727b6a3d466faee6974a7dad0d9991241c498e7317e5ccf53dbc766"}, - {file = "zstandard-0.21.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff0852da2abe86326b20abae912d0367878dd0854b8931897d44cfeb18985472"}, - {file = "zstandard-0.21.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d12fa383e315b62630bd407477d750ec96a0f438447d0e6e496ab67b8b451d39"}, - {file = "zstandard-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1b9703fe2e6b6811886c44052647df7c37478af1b4a1a9078585806f42e5b15"}, - 
{file = "zstandard-0.21.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:df28aa5c241f59a7ab524f8ad8bb75d9a23f7ed9d501b0fed6d40ec3064784e8"}, - {file = "zstandard-0.21.0-cp311-cp311-win32.whl", hash = "sha256:0aad6090ac164a9d237d096c8af241b8dcd015524ac6dbec1330092dba151657"}, - {file = "zstandard-0.21.0-cp311-cp311-win_amd64.whl", hash = "sha256:48b6233b5c4cacb7afb0ee6b4f91820afbb6c0e3ae0fa10abbc20000acdf4f11"}, - {file = "zstandard-0.21.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e7d560ce14fd209db6adacce8908244503a009c6c39eee0c10f138996cd66d3e"}, - {file = "zstandard-0.21.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e6e131a4df2eb6f64961cea6f979cdff22d6e0d5516feb0d09492c8fd36f3bc"}, - {file = "zstandard-0.21.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1e0c62a67ff425927898cf43da2cf6b852289ebcc2054514ea9bf121bec10a5"}, - {file = "zstandard-0.21.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1545fb9cb93e043351d0cb2ee73fa0ab32e61298968667bb924aac166278c3fc"}, - {file = "zstandard-0.21.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe6c821eb6870f81d73bf10e5deed80edcac1e63fbc40610e61f340723fd5f7c"}, - {file = "zstandard-0.21.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ddb086ea3b915e50f6604be93f4f64f168d3fc3cef3585bb9a375d5834392d4f"}, - {file = "zstandard-0.21.0-cp37-cp37m-win32.whl", hash = "sha256:57ac078ad7333c9db7a74804684099c4c77f98971c151cee18d17a12649bc25c"}, - {file = "zstandard-0.21.0-cp37-cp37m-win_amd64.whl", hash = "sha256:1243b01fb7926a5a0417120c57d4c28b25a0200284af0525fddba812d575f605"}, - {file = "zstandard-0.21.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ea68b1ba4f9678ac3d3e370d96442a6332d431e5050223626bdce748692226ea"}, - {file = "zstandard-0.21.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8070c1cdb4587a8aa038638acda3bd97c43c59e1e31705f2766d5576b329e97c"}, - {file = "zstandard-0.21.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4af612c96599b17e4930fe58bffd6514e6c25509d120f4eae6031b7595912f85"}, - {file = "zstandard-0.21.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cff891e37b167bc477f35562cda1248acc115dbafbea4f3af54ec70821090965"}, - {file = "zstandard-0.21.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9fec02ce2b38e8b2e86079ff0b912445495e8ab0b137f9c0505f88ad0d61296"}, - {file = "zstandard-0.21.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0bdbe350691dec3078b187b8304e6a9c4d9db3eb2d50ab5b1d748533e746d099"}, - {file = "zstandard-0.21.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b69cccd06a4a0a1d9fb3ec9a97600055cf03030ed7048d4bcb88c574f7895773"}, - {file = "zstandard-0.21.0-cp38-cp38-win32.whl", hash = "sha256:9980489f066a391c5572bc7dc471e903fb134e0b0001ea9b1d3eff85af0a6f1b"}, - {file = "zstandard-0.21.0-cp38-cp38-win_amd64.whl", hash = "sha256:0e1e94a9d9e35dc04bf90055e914077c80b1e0c15454cc5419e82529d3e70728"}, - {file = "zstandard-0.21.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d2d61675b2a73edcef5e327e38eb62bdfc89009960f0e3991eae5cc3d54718de"}, - {file = 
"zstandard-0.21.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:25fbfef672ad798afab12e8fd204d122fca3bc8e2dcb0a2ba73bf0a0ac0f5f07"}, - {file = "zstandard-0.21.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62957069a7c2626ae80023998757e27bd28d933b165c487ab6f83ad3337f773d"}, - {file = "zstandard-0.21.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14e10ed461e4807471075d4b7a2af51f5234c8f1e2a0c1d37d5ca49aaaad49e8"}, - {file = "zstandard-0.21.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9cff89a036c639a6a9299bf19e16bfb9ac7def9a7634c52c257166db09d950e7"}, - {file = "zstandard-0.21.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52b2b5e3e7670bd25835e0e0730a236f2b0df87672d99d3bf4bf87248aa659fb"}, - {file = "zstandard-0.21.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b1367da0dde8ae5040ef0413fb57b5baeac39d8931c70536d5f013b11d3fc3a5"}, - {file = "zstandard-0.21.0-cp39-cp39-win32.whl", hash = "sha256:db62cbe7a965e68ad2217a056107cc43d41764c66c895be05cf9c8b19578ce9c"}, - {file = "zstandard-0.21.0-cp39-cp39-win_amd64.whl", hash = "sha256:a8d200617d5c876221304b0e3fe43307adde291b4a897e7b0617a61611dfff6a"}, - {file = "zstandard-0.21.0.tar.gz", hash = "sha256:f08e3a10d01a247877e4cb61a82a319ea746c356a3786558bed2481e6c405546"}, -] - -[package.dependencies] -cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\""} - -[package.extras] -cffi = ["cffi (>=1.11)"] - [metadata] lock-version = "2.0" python-versions = ">=3.10, <3.13" -content-hash = "a927f27f5e762f635f437f607655cf51d80a849084485ccf936a9861f5480000" +content-hash = "299ec5e65d2968c30ec18ec90143406a1c9fd6e707d30c9bb2cdc8a490e9c791" diff --git a/pyproject.toml b/pyproject.toml index 4b25e5da..589daeb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,64 +1,91 @@ [tool.poetry] -name = "sparse_autoencoder" -version = "0.0.0" -description = "Sparse Autoencoder for Mechanistic Interpretability" -authors = ["Alan Cooney <41682961+alan-cooney@users.noreply.github.com>"] -license = "MIT" -readme = "README.md" -include = ["sparse_autoencoder"] + authors =["Alan Cooney <41682961+alan-cooney@users.noreply.github.com>"] + description="Sparse Autoencoder for Mechanistic Interpretability" + include =["sparse_autoencoder"] + license ="MIT" + name ="sparse_autoencoder" + readme ="README.md" + version ="0.0.0" -[tool.poetry.dependencies] -python = ">=3.10, <3.13" -einops = ">=0.6" -torch = ">=2.1" -zstandard = "^0.21.0" -wandb = "^0.15.12" + [tool.poetry.dependencies] + einops=">=0.6" + python=">=3.10, <3.13" + torch =">=2.1" + wandb =">=0.15.12" -[tool.poetry.group.dev.dependencies] -pytest = ">=7" -pytest-cov = ">=4" -jupyter = ">=1" -plotly = ">=5" -pylint = "^3.0.2" -ruff = "^0.1.4" -pyright = "^1.1.334" + [tool.poetry.group.dev.dependencies] + jupyter =">=1" + plotly =">=5" + poethepoet=">=0.24.2" + pre-commit=">=3.5.0" + pyright =">=1.1.334" + pytest =">=7" + pytest-cov=">=4" + ruff =">=0.1.4" -[tool.poetry.group.demos.dependencies] -jupyterlab = ">=3" -transformer-lens = "^1.9.0" -pandas = ">=2.1.2" + [tool.poetry.group.demos.dependencies] + jupyterlab =">=3" + pandas =">=2.1.2" + transformer-lens=">=1.9.0" + +[tool.poe.tasks] + check =["format", "lint", "test", "typecheck"] + format ="ruff format sparse_autoencoder" + lint ="ruff check sparse_autoencoder --fix" + precommit="pre-commit 
run --all-files" + test ="pytest" + typecheck="pyright" [build-system] -requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" + build-backend="poetry.core.masonry.api" + requires =["poetry-core"] + +[tool.pytest] -[tool.pytest.ini_options] -filterwarnings = [ - "ignore:pkg_resources is deprecated as an API:DeprecationWarning", - # Ignore numpy.distutils deprecation warning caused by pandas - # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils - "ignore:distutils Version classes are deprecated:DeprecationWarning" -] -addopts = """--jaxtyping-packages=sparse_autoencoder,beartype.beartype \ --W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning \ ---doctest-modules""" + [tool.pytest.ini_options] + addopts="""--jaxtyping-packages=sparse_autoencoder,beartype.beartype \ + -W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning \ + --doctest-modules""" + filterwarnings=[ + "ignore:pkg_resources is deprecated as an API:DeprecationWarning", + # Ignore numpy.distutils deprecation warning caused by pandas + # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils + "ignore:distutils Version classes are deprecated:DeprecationWarning", + ] +[tool.pyright] + include =["sparse_autoencoder"] + reportIncompatibleMethodOverride=true [tool.ruff] -exclude = [ - "/.venv", - "*/snapshots/", -] -select = ["E", "F", "I001"] # errors, flake8, isort equivalents -ignore = ["E402", "E721", "E731", "E741", "F722"] -ignore-init-module-imports = true -fixable = ["I001", "F401"] # Auto-fix isort, unused variables -line-length = 100 + exclude=["*/snapshots/", "/.venv"] + ignore=[ + "ANN101", # self type annotation (it's inferred) + "ANN204", # __init__() return type (it's inferred) + "E731", # No lambdas (can be useful) + "F722", # Forward annotations check (conflicts with jaxtyping) + "FA102", # Annotations support (Python >= 3.9 is fine) + "FIX002", # TODO issue link (overkill) + "INP001", # __init__.py for all packages (Python >= 3.3 is fine) + "PGH003", # No general type: ignore (too strict) + "S101", # Use of assert detected (it's needed for tests) + "TCH002", # Type checking imports (conflicts with beartype) + "TD00", # TODO banned (we're in alpha) + # Rules that conflict with ruff format + "COM812", + "ISC001", + ] + ignore-init-module-imports=true + line-length=100 + required-version="0.1.4" + select=["ALL"] -[tool.ruff.isort] -combine-as-imports = true # Combine imports from the same module -lines-after-imports = 2 + [tool.ruff.lint.isort] + force-sort-within-sections=true + lines-after-imports =2 -[tool.pyright] -reportIncompatibleMethodOverride = true -include = ["sparse_autoencoder"] \ No newline at end of file + [tool.ruff.lint.pydocstyle] + convention="google" + +[tool.ruff.pylint] + max-args=10 diff --git a/sparse_autoencoder/__init__.py b/sparse_autoencoder/__init__.py index 9de10a13..b8137dfc 100644 --- a/sparse_autoencoder/__init__.py +++ b/sparse_autoencoder/__init__.py @@ -13,13 +13,13 @@ __all__ = [ - ActivationStore, - ActivationStoreBatch, - ActivationStoreItem, - DiskActivationStore, - ListActivationStore, - TensorActivationStore, - SparseAutoencoder, - create_src_dataloader, - pipeline, + "ActivationStore", + "ActivationStoreBatch", + "ActivationStoreItem", + "DiskActivationStore", + "ListActivationStore", + "TensorActivationStore", + "SparseAutoencoder", + "create_src_dataloader", + "pipeline", ] diff --git a/sparse_autoencoder/activation_store/base_store.py 
b/sparse_autoencoder/activation_store/base_store.py index 30c0005d..84d30090 100644 --- a/sparse_autoencoder/activation_store/base_store.py +++ b/sparse_autoencoder/activation_store/base_store.py @@ -1,5 +1,6 @@ """Activation Store Base Class.""" from abc import ABC, abstractmethod +from concurrent.futures import Future from jaxtyping import Float from torch import Tensor @@ -34,7 +35,6 @@ class ActivationStore(Dataset, ABC): `__getitem__` and `__len__` methods from the underlying `torch.utils.data.Dataset` class). Example: - >>> import torch >>> class MyActivationStore(ActivationStore): ... def __init__(self): @@ -63,17 +63,17 @@ class ActivationStore(Dataset, ABC): """ @abstractmethod - def append(self, item: ActivationStoreItem): + def append(self, item: ActivationStoreItem) -> Future | None: """Add a Single Item to the Store.""" raise NotImplementedError @abstractmethod - def extend(self, batch: ActivationStoreBatch): + def extend(self, batch: ActivationStoreBatch) -> Future | None: """Add a Batch to the Store.""" raise NotImplementedError @abstractmethod - def empty(self): + def empty(self) -> None: """Empty the Store.""" raise NotImplementedError @@ -91,5 +91,10 @@ def __getitem__(self, index: int) -> ActivationStoreItem: class StoreFullError(IndexError): """Exception raised when the activation store is full.""" - def __init__(self, message="Activation store is full"): + def __init__(self, message: str = "Activation store is full"): + """Initialise the exception. + + Args: + message: Override the default message. + """ super().__init__(message) diff --git a/sparse_autoencoder/activation_store/disk_store.py b/sparse_autoencoder/activation_store/disk_store.py index 4269aa2b..ef77776a 100644 --- a/sparse_autoencoder/activation_store/disk_store.py +++ b/sparse_autoencoder/activation_store/disk_store.py @@ -1,9 +1,9 @@ """Disk Activation Store.""" -import tempfile from concurrent.futures import Future, ThreadPoolExecutor from multiprocessing import Manager from multiprocessing.managers import ListProxy, ValueProxy from pathlib import Path +import tempfile from threading import Lock import torch @@ -30,25 +30,11 @@ class DiskActivationStore(ActivationStore): Multiprocess safe (supports writing from multiple GPU workers). Warning: - Unless you want to keep and use existing .pt files in the storage directory when initialized, set `empty_dir` to `True`. Note also that :meth:`close` must be called to ensure all activation vectors are written to disk after the last batch has been added to the store. - - Args: - storage_path: Path to the directory where the activation vectors will be stored. Defaults to - the OS temporary directory. - empty_dir: Whether to empty the directory before writing. Generally you want to set this to - `True` as otherwise the directory may contain stale activation vectors from previous - runs. - max_cache_size: The maximum number of activation vectors to cache in memory before writing - to disk. Note this is only followed approximately. - num_workers: Number of CPU workers to use for non-blocking writes to the file system (so - that the model can keep running whilst it writes the previous activations to disk). This - should be less than the number of CPU cores available. You don't need multiple GPUs to - take advantage of this feature. """ _storage_path: Path @@ -56,7 +42,7 @@ class DiskActivationStore(ActivationStore): _cache: ListProxy """Cache for Activation Vectors. 
- + Activation vectors are buffered in memory until the cache is full, at which point they are written to disk. """ @@ -72,17 +58,32 @@ class DiskActivationStore(ActivationStore): _disk_n_activation_vectors: ValueProxy[int] """Length of the Store (on disk). - + Minus 1 signifies not calculated yet. """ def __init__( self, storage_path: Path = DEFAULT_DISK_ACTIVATION_STORE_PATH, - empty_dir: bool = False, max_cache_size: int = 10_000, num_workers: int = 6, + *, + empty_dir: bool = False, ): + """Initialize the Disk Activation Store. + + Args: + storage_path: Path to the directory where the activation vectors will be stored. + max_cache_size: The maximum number of activation vectors to cache in memory before + writing to disk. Note this is only followed approximately. + num_workers: Number of CPU workers to use for non-blocking writes to the file system (so + that the model can keep running whilst it writes the previous activations to disk). + This should be less than the number of CPU cores available. You don't need multiple + GPUs to take advantage of this feature. + empty_dir: Whether to empty the directory before writing. Generally you want to set this + to `True` as otherwise the directory may contain stale activation vectors from + previous runs. + """ super().__init__() # Setup the storage directory @@ -103,7 +104,7 @@ def __init__( # Create a threadpool for non-blocking writes to the cache self._thread_pool = ThreadPoolExecutor(num_workers) - def _write_to_disk(self, wait_for_max: bool = False) -> None: + def _write_to_disk(self, *, wait_for_max: bool = False) -> None: """Write the contents of the queue to disk. Args: @@ -123,7 +124,7 @@ def _write_to_disk(self, wait_for_max: bool = False) -> None: del self._cache[0:size_to_get] # Update the length cache - if not self._disk_n_activation_vectors.value == -1: + if self._disk_n_activation_vectors.value != -1: self._disk_n_activation_vectors.value += len(activations) stacked_activations = torch.stack(activations) @@ -135,7 +136,6 @@ def append(self, item: ActivationStoreItem) -> Future | None: """Add a Single Item to the Store. Example: - >>> store = DiskActivationStore(max_cache_size=1, empty_dir=True) >>> future = store.append(torch.randn(100)) >>> future.result() @@ -162,7 +162,6 @@ def extend(self, batch: ActivationStoreBatch) -> Future | None: """Add a Batch to the Store. Example: - >>> store = DiskActivationStore(max_cache_size=10, empty_dir=True) >>> future = store.extend(torch.randn(10, 100)) >>> future.result() @@ -194,7 +193,6 @@ def wait_for_writes_to_complete(self) -> None: all activation vectors to be written to disk. Example: - >>> store = DiskActivationStore(max_cache_size=1, empty_dir=True) >>> future = store.append(torch.randn(100)) >>> store.wait_for_writes_to_complete() @@ -209,15 +207,13 @@ def _all_filenames(self) -> list[Path]: """Return a List of All Activation Vector Filenames.""" return list(self._storage_path.glob("*.pt")) - def empty(self): + def empty(self) -> None: """Empty the Store. Warning: - This will delete all .pt files in the top level of the storage directory. Example: - >>> store = DiskActivationStore(max_cache_size=1, empty_dir=True) >>> future = store.append(torch.randn(100)) >>> future.result() @@ -253,7 +249,6 @@ def __len__(self) -> int: """Length Dunder Method. 
Example: - >>> store = DiskActivationStore(max_cache_size=1, empty_dir=True) >>> print(len(store)) 0 diff --git a/sparse_autoencoder/activation_store/list_store.py b/sparse_autoencoder/activation_store/list_store.py index 229d901d..d194110c 100644 --- a/sparse_autoencoder/activation_store/list_store.py +++ b/sparse_autoencoder/activation_store/list_store.py @@ -1,9 +1,9 @@ """List Activation Store.""" -import random -import time from concurrent.futures import Future, ProcessPoolExecutor, as_completed from multiprocessing import Manager from multiprocessing.managers import ListProxy +import random +import time import torch @@ -39,7 +39,6 @@ class ListActivationStore(ActivationStore): dataset to the loader and then set the DataLoader `shuffle` argument to `False`. Examples: - Create an empty activation dataset: >>> import torch @@ -68,45 +67,48 @@ class ListActivationStore(ActivationStore): >>> next_item = next(iter(loader)) >>> next_item.shape torch.Size([2, 100]) - - Args: - data: Data to initialize the dataset with. - device: Device to store the activation vectors on. - multiprocessing_enabled: Support reading/writing to the dataset with multiple GPU workers. - This creates significant overhead, so you should only enable it if you have multiple - GPUs (and experiment with enabling/disabling it). - max_workers: Max CPU workers if multiprocessing is enabled, for writing to the list. - Default is the number of cores you have. """ _data: list[ActivationStoreItem] | ListProxy """Underlying List Data Store.""" - _device: torch.device + _device: torch.device | None """Device to Store the Activation Vectors On.""" _pool: ProcessPoolExecutor | None = None """Multiprocessing Pool.""" - _pool_exceptions: ListProxy | list = [] + _pool_exceptions: ListProxy | list """Pool Exceptions. - + Used to keep track of exceptions. """ - _pool_futures: list[Future] = [] + _pool_futures: list[Future] """Pool Futures. - + Used to keep track of processes running in the pool. """ def __init__( self, data: list[ActivationStoreItem] | None = None, - device: torch.device = torch.device("cpu"), - multiprocessing_enabled=False, + device: torch.device | None = None, max_workers: int | None = None, + *, + multiprocessing_enabled: bool = False, ) -> None: + """Initialize the List Activation Store. + + Args: + data: Data to initialize the dataset with. + device: Device to store the activation vectors on. + max_workers: Max CPU workers if multiprocessing is enabled, for writing to the list. + Default is the number of cores you have. + multiprocessing_enabled: Support reading/writing to the dataset with multiple GPU + workers. This creates significant overhead, so you should only enable it if you have + multiple GPUs (and experiment with enabling/disabling it). + """ # Default to empty if data is None: data = [] @@ -121,6 +123,9 @@ def __init__( self._pool_exceptions = manager.list() else: self._data = data + self._pool_exceptions = [] + + self._pool_futures = [] # Device for storing the activation vectors self._device = device @@ -131,7 +136,6 @@ def __len__(self) -> int: Returns the number of activation vectors in the dataset. Example: - >>> import torch >>> store = ListActivationStore() >>> store.append(torch.randn(100)) @@ -166,7 +170,6 @@ def __getitem__(self, index: int) -> ActivationStoreItem: """Get Item Dunder Method. 
Example: - >>> import torch >>> store = ListActivationStore() >>> store.append(torch.zeros(5)) @@ -182,13 +185,12 @@ def __getitem__(self, index: int) -> ActivationStoreItem: """ return self._data[index] - def shuffle(self): + def shuffle(self) -> None: """Shuffle the Data In-Place. This is much faster than using the shuffle argument on `torch.utils.data.DataLoader`. Example: - >>> import torch >>> _seed = torch.manual_seed(42) >>> store = ListActivationStore() @@ -203,13 +205,12 @@ def shuffle(self): self.wait_for_writes_to_complete() random.shuffle(self._data) - def append(self, item: ActivationStoreItem) -> None: + def append(self, item: ActivationStoreItem) -> Future | None: """Append a single item to the dataset. Note **append is blocking**. For better performance use extend instead with batches. Example: - >>> import torch >>> store = ListActivationStore() >>> store.append(torch.randn(100)) @@ -228,21 +229,20 @@ def _extend(self, batch: ActivationStoreBatch) -> None: To be called by :meth:`extend`. Args: - items: A list of items to add to the dataset. + batch: A batch of items to add to the dataset. """ try: # Unstack to a list of tensors items: list[ActivationStoreItem] = resize_to_list_vectors(batch) self._data.extend(items) - except Exception as e: # pylint: disable=broad-except + except Exception as e: # noqa: BLE001 self._pool_exceptions.append(e) - def extend(self, batch: ActivationStoreBatch) -> None: + def extend(self, batch: ActivationStoreBatch) -> Future | None: """Extend the dataset with multiple items (non-blocking). Example: - >>> import torch >>> store = ListActivationStore() >>> batch = torch.randn(10, 100) @@ -251,7 +251,7 @@ def extend(self, batch: ActivationStoreBatch) -> None: 10 Args: - items: A list of items to add to the dataset. + batch: A batch of items to add to the dataset. """ # Schedule _extend to run in a separate process if self._pool: @@ -262,12 +262,11 @@ def extend(self, batch: ActivationStoreBatch) -> None: self._extend(batch) def wait_for_writes_to_complete(self) -> None: - """Wait for Writes to Complete + """Wait for Writes to Complete. Wait for any non-blocking writes (e.g. calls to :meth:`append`) to complete. Example: - >>> import torch >>> store = ListActivationStore(multiprocessing_enabled=True) >>> store.extend(torch.randn(3, 100)) @@ -284,18 +283,14 @@ def wait_for_writes_to_complete(self) -> None: time.sleep(1) if self._pool_exceptions: - exceptions_report = "\n".join( - f"{e}\n{tb}" for e, tb in self._pool_exceptions - ) - raise RuntimeError( - f"Exceptions occurred in background workers:\n{exceptions_report}" - ) - - def empty(self): + exceptions_report = "\n".join(f"{e}\n{tb}" for e, tb in self._pool_exceptions) + msg = f"Exceptions occurred in background workers:\n{exceptions_report}" + raise RuntimeError(msg) + + def empty(self) -> None: """Empty the dataset. 
Example: - >>> import torch >>> store = ListActivationStore() >>> store.append(torch.randn(100)) diff --git a/sparse_autoencoder/activation_store/tensor_store.py b/sparse_autoencoder/activation_store/tensor_store.py index d8e8db90..09885d07 100644 --- a/sparse_autoencoder/activation_store/tensor_store.py +++ b/sparse_autoencoder/activation_store/tensor_store.py @@ -1,6 +1,6 @@ """Tensor Activation Store.""" -import torch from jaxtyping import Float +import torch from torch import Tensor from sparse_autoencoder.activation_store.base_store import ( @@ -28,7 +28,6 @@ class TensorActivationStore(ActivationStore): additional :meth:`append` and :meth:`extend` methods (the latter of which is non-blocking). Examples: - Create an empty activation dataset: >>> import torch @@ -58,11 +57,6 @@ class TensorActivationStore(ActivationStore): >>> next_item = next(iter(loader)) >>> next_item.shape torch.Size([2, 100]) - - Args: - max_items: Maximum number of items to store (individual activation vectors) - num_neurons: Number of neurons in each activation vector. - device: Device to store the activation vectors on. """ _data: TensorActivationStoreData @@ -78,9 +72,15 @@ def __init__( self, max_items: int, num_neurons: int, - device: torch.device = torch.device("cpu"), + device: torch.device | None = None, ) -> None: - # Initialise the datastore + """Initialise the Tensor Activation Store. + + Args: + max_items: Maximum number of items to store (individual activation vectors) + num_neurons: Number of neurons in each activation vector. + device: Device to store the activation vectors on. + """ self._data = torch.empty((max_items, num_neurons), device=device) self._max_items = max_items @@ -90,7 +90,6 @@ def __len__(self) -> int: Returns the number of activation vectors in the dataset. Example: - >>> import torch >>> store = TensorActivationStore(max_items=10_000_000, num_neurons=100) >>> store.append(torch.randn(100)) @@ -106,7 +105,6 @@ def __sizeof__(self) -> int: Returns the size of the underlying tensor in bytes. Example: - >>> import torch >>> store = TensorActivationStore(max_items=2, num_neurons=100) >>> store.__sizeof__() # Pre-allocated tensor of 2x100 @@ -118,7 +116,6 @@ def __getitem__(self, index: int) -> ActivationStoreItem: """Get Item Dunder Method. Example: - >>> import torch >>> store = TensorActivationStore(max_items=2, num_neurons=5) >>> store.append(torch.zeros(5)) @@ -137,19 +134,17 @@ def __getitem__(self, index: int) -> ActivationStoreItem: """ # Check in range if index >= self.items_stored: - raise IndexError( - f"Index {index} out of range (only {self.items_stored} items stored)" - ) + msg = f"Index {index} out of range (only {self.items_stored} items stored)" + raise IndexError(msg) return self._data[index] - def shuffle(self): + def shuffle(self) -> None: """Shuffle the Data In-Place. This is much faster than using the shuffle argument on `torch.utils.data.DataLoader`. Example: - >>> import torch >>> _seed = torch.manual_seed(42) >>> store = TensorActivationStore(max_items=10, num_neurons=1) @@ -170,7 +165,6 @@ def append(self, item: ActivationStoreItem) -> None: """Add a single item to the store. 
Example: - >>> import torch >>> store = TensorActivationStore(max_items=10, num_neurons=5) >>> store.append(torch.zeros(5)) @@ -186,7 +180,7 @@ def append(self, item: ActivationStoreItem) -> None: """ # Check we have space if self.items_stored + 1 > self._max_items: - raise StoreFullError() + raise StoreFullError self._data[self.items_stored] = item.to( self._data.device, @@ -197,7 +191,6 @@ def extend(self, batch: ActivationStoreBatch) -> None: """Add a batch to the store. Examples: - >>> import torch >>> store = TensorActivationStore(max_items=10, num_neurons=5) >>> store.extend(torch.zeros(2, 5)) @@ -216,30 +209,28 @@ def extend(self, batch: ActivationStoreBatch) -> None: IndexError: If there is no space remaining. """ reshaped: Float[Tensor, "subset_item neuron"] = resize_to_single_item_dimension( - batch + batch, ) # Check we have space num_activation_tensors: int = reshaped.shape[0] if self.items_stored + num_activation_tensors > self._max_items: if reshaped.shape[0] > self._max_items: - raise ValueError( - f"Single batch of {num_activation_tensors} activations is larger \ - than the total maximum in the store of {self._max_items}." - ) + msg = f"Single batch of {num_activation_tensors} activations is larger than the \ + total maximum in the store of {self._max_items}." + raise ValueError(msg) - raise StoreFullError() + raise StoreFullError - self._data[ - self.items_stored : self.items_stored + num_activation_tensors - ] = reshaped.to(self._data.device) + self._data[self.items_stored : self.items_stored + num_activation_tensors] = reshaped.to( + self._data.device + ) self.items_stored += num_activation_tensors def empty(self) -> None: """Empty the store. Example: - >>> import torch >>> store = TensorActivationStore(max_items=10, num_neurons=5) >>> store.extend(torch.zeros(2, 5)) diff --git a/sparse_autoencoder/activation_store/utils/extend_resize.py b/sparse_autoencoder/activation_store/utils/extend_resize.py index b30f3a9b..d079e845 100644 --- a/sparse_autoencoder/activation_store/utils/extend_resize.py +++ b/sparse_autoencoder/activation_store/utils/extend_resize.py @@ -10,7 +10,7 @@ def resize_to_list_vectors( - input: ActivationStoreBatch, + batched_tensor: ActivationStoreBatch, ) -> list[ActivationStoreItem]: """Resize Extend List Vectors. @@ -18,7 +18,6 @@ def resize_to_list_vectors( the neurons dimension), and returns a list of vectors each of size [neurons]. Examples: - With 2 axis (e.g. pos neuron): >>> import torch @@ -42,20 +41,21 @@ def resize_to_list_vectors( '27 items of shape torch.Size([100])' Args: - input: Input Activation Store Batch + batched_tensor: Input Activation Store Batch Returns: List of Activation Store Item Vectors """ rearranged: Float[Tensor, "batch neuron"] = rearrange( - input, "... neurons -> (...) neurons" + batched_tensor, + "... neurons -> (...) neurons", ) res = rearranged.unbind(0) return list(res) def resize_to_single_item_dimension( - input: ActivationStoreBatch, + batch_activations: ActivationStoreBatch, ) -> Float[Tensor, "item neuron"]: """Resize Extend Single Item Dimension. @@ -63,7 +63,6 @@ def resize_to_single_item_dimension( the neurons dimension), and returns a single tensor of size [item, neurons]. Examples: - With 2 axis (e.g. pos neuron): >>> import torch @@ -87,9 +86,9 @@ def resize_to_single_item_dimension( torch.Size([27, 100]) Args: - input: Input Activation Store Batch + batch_activations: Input Activation Store Batch Returns: Single Tensor of Activation Store Items """ - return rearrange(input, "... neurons -> (...) 
neurons") + return rearrange(batch_activations, "... neurons -> (...) neurons") diff --git a/sparse_autoencoder/activation_store/utils/tests/test_extend_resize.py b/sparse_autoencoder/activation_store/utils/tests/test_extend_resize.py index dc1d7f13..7bc2755a 100644 --- a/sparse_autoencoder/activation_store/utils/tests/test_extend_resize.py +++ b/sparse_autoencoder/activation_store/utils/tests/test_extend_resize.py @@ -13,14 +13,19 @@ class TestResizeListVectors: """Resize to List Vectors Tests.""" @pytest.mark.parametrize( - "input_shape, expected_len, expected_shape", + ("input_shape", "expected_len", "expected_shape"), [ ((3, 100), 3, torch.Size([100])), ((3, 3, 100), 9, torch.Size([100])), ((3, 3, 3, 100), 27, torch.Size([100])), ], ) - def test_resize_to_list_vectors(self, input_shape, expected_len, expected_shape): + def test_resize_to_list_vectors( + self, + input_shape: tuple[int], + expected_len: int, + expected_shape: torch.Tensor, + ) -> None: """Check each item's shape in the resulting list.""" input_tensor = torch.rand(input_shape) result = resize_to_list_vectors(ActivationStoreBatch(input_tensor)) @@ -30,7 +35,7 @@ def test_resize_to_list_vectors(self, input_shape, expected_len, expected_shape) item.shape == expected_shape for item in result ), f"All items should have shape {expected_shape}" - def test_resize_to_list_vectors_values(self): + def test_resize_to_list_vectors_values(self) -> None: """Check each item's values in the resulting list.""" input_tensor = torch.tensor([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]]) expected_output = [ @@ -41,9 +46,9 @@ def test_resize_to_list_vectors_values(self): ] result = resize_to_list_vectors(ActivationStoreBatch(input_tensor)) - for expected, output in zip(expected_output, result): + for expected, output in zip(expected_output, result, strict=True): assert torch.all( - torch.eq(expected, output) + torch.eq(expected, output), ), "Tensor values do not match expected" @@ -51,26 +56,30 @@ class TestResizeSingleItemDimension: """Resize to Single Item Dimension Tests.""" @pytest.mark.parametrize( - "input_shape, expected_shape", + ("input_shape", "expected_shape"), [ ((3, 100), (3, 100)), ((3, 3, 100), (9, 100)), ((3, 3, 3, 100), (27, 100)), ], ) - def test_resize_to_single_item_dimension(self, input_shape, expected_shape): + def test_resize_to_single_item_dimension( + self, + input_shape: tuple[int], + expected_shape: tuple[int], + ) -> None: """Check the resulting tensor shape.""" input_tensor = torch.randn(input_shape) result = resize_to_single_item_dimension(ActivationStoreBatch(input_tensor)) assert result.shape == expected_shape, f"Expected tensor shape {expected_shape}" - def test_resize_to_single_item_dimension_values(self): + def test_resize_to_single_item_dimension_values(self) -> None: """Check the resulting tensor values.""" input_tensor = torch.tensor([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]]) expected_output = torch.tensor([[1.0, 2], [3, 4], [5, 6], [7, 8]]) result = resize_to_single_item_dimension(ActivationStoreBatch(input_tensor)) assert torch.all( - torch.eq(expected_output, result) + torch.eq(expected_output, result), ), "Tensor values do not match expected" diff --git a/sparse_autoencoder/autoencoder/loss.py b/sparse_autoencoder/autoencoder/loss.py index 3b6f6c53..9a63bcf6 100644 --- a/sparse_autoencoder/autoencoder/loss.py +++ b/sparse_autoencoder/autoencoder/loss.py @@ -1,6 +1,6 @@ """Loss function for the Sparse Autoencoder.""" -import torch from jaxtyping import Float +import torch from torch import Tensor from 
torch.nn.functional import mse_loss @@ -18,7 +18,6 @@ def reconstruction_loss( polysemantic and monosemantic representations of true features. Examples: - >>> input_activations = torch.tensor([[3.0, 4]]) >>> output_activations = torch.tensor([[1.0, 5]]) >>> reconstruction_loss(input_activations, output_activations) @@ -35,13 +34,12 @@ def l1_loss(learned_activations: Float[Tensor, "*batch learned_activations"]) -> Tensor: - """L1 Loss on Learned Activations + """L1 Loss on Learned Activations. - L1 loss penalty is the absolute sum of the learned activations. The L1 penality is this + L1 loss penalty is the absolute sum of the learned activations. The L1 penalty is this multiplied by the l1_coefficient (designed to encourage sparsity). Examples: - >>> learned_activations = torch.tensor([[2.0, -3]]) >>> l1_loss(learned_activations) tensor([5.]) @@ -68,7 +66,6 @@ def sae_training_loss( https://transformer-circuits.pub/2023/monosemantic-features/index.html#setup-autoencoder-motivation Examples: - >>> reconstruction_loss_mse = torch.tensor([2.5000]) >>> l1_loss_learned_activations = torch.tensor([1.]) >>> l1_coefficient = 0.5 diff --git a/sparse_autoencoder/autoencoder/model.py b/sparse_autoencoder/autoencoder/model.py index a3e5358e..8fcb048a 100644 --- a/sparse_autoencoder/autoencoder/model.py +++ b/sparse_autoencoder/autoencoder/model.py @@ -1,6 +1,6 @@ """The Sparse Autoencoder Model.""" -import torch from jaxtyping import Float +import torch from torch import Tensor from torch.nn import Linear, Module, ReLU, Sequential from torch.nn.parameter import Parameter @@ -9,25 +9,17 @@ class SparseAutoencoder(Module): - """Sparse Autoencoder Model. - - Args: - n_input_features: Number of input features (e.g. `d_mlp` if training on MLP activations from - TransformerLens). - n_learned_features: Number of learned features. The initial paper experimented with 1× to - 256× the number of input features, and primarily used 8x. - geometric_median_dataset: Estimated geometric median of the dataset. - """ + """Sparse Autoencoder Model.""" geometric_median_dataset: Float[Tensor, " input_activations"] """Estimated Geometric Median of the Dataset. - + Used for initialising :attr:`tied_bias`. """ tied_bias: Float[Parameter, " input_activations"] """Tied Bias Parameter. - + The same bias is used pre-encoder and post-decoder. """ @@ -43,6 +35,15 @@ def __init__( n_learned_features: int, geometric_median_dataset: Float[Tensor, " input_activations"], ) -> None: + """Initialise the Sparse Autoencoder Model. + + Args: + n_input_features: Number of input features (e.g. `d_mlp` if training on MLP activations + from TransformerLens). + n_learned_features: Number of learned features. The initial paper experimented with 1 to + 256 times the number of input features, and primarily used 8 times. + geometric_median_dataset: Estimated geometric median of the dataset. + """ super().__init__() self.n_input_features = n_input_features @@ -70,7 +71,8 @@ def __init__( ) def forward( - self, x: Float[Tensor, "batch input_activations"] + self, + x: Float[Tensor, "batch input_activations"], ) -> tuple[ Float[Tensor, "batch learned_activations"], Float[Tensor, "batch input_activations"], @@ -78,7 +80,7 @@ def forward( """Forward Pass. Args: - input: Input activations (e.g. activations from an MLP layer in a transformer model). + x: Input activations (e.g. activations from an MLP layer in a transformer model). Returns: Tuple of learned activations and decoded activations.
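For orientation, the loss functions and model refactored above compose into a single training step as sketched here (not part of the patch): the feature counts, the random batch, and the 0.01 L1 coefficient are assumed values, and the zero geometric median is a placeholder for a real dataset estimate.

import torch

from sparse_autoencoder.autoencoder.loss import (
    l1_loss,
    reconstruction_loss,
    sae_training_loss,
)
from sparse_autoencoder.autoencoder.model import SparseAutoencoder

# Placeholder for a real geometric median estimate of the source activations
geometric_median = torch.zeros(4)
autoencoder = SparseAutoencoder(
    n_input_features=4,
    n_learned_features=32,
    geometric_median_dataset=geometric_median,
)

# Hypothetical batch of source model activations
batch = torch.randn(8, 4)
learned_activations, reconstructed_activations = autoencoder(batch)

# Total loss is the MSE reconstruction loss plus the L1 penalty on the
# learned activations, scaled by the (assumed) coefficient of 0.01
total_loss = sae_training_loss(
    reconstruction_loss(batch, reconstructed_activations),
    l1_loss(learned_activations),
    0.01,
)

This mirrors the sequence used in the train_autoencoder loop further down in this diff.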
diff --git a/sparse_autoencoder/autoencoder/tests/test_loss.py b/sparse_autoencoder/autoencoder/tests/test_loss.py index df1f74c3..34978e94 100644 --- a/sparse_autoencoder/autoencoder/tests/test_loss.py +++ b/sparse_autoencoder/autoencoder/tests/test_loss.py @@ -16,19 +16,20 @@ def test_loss() -> None: l1_coefficient = 0.5 squared_errors: float = 0.0 - for i, o in zip(input_activations, output_activations): + for i, o in zip(input_activations, output_activations, strict=True): squared_errors += (i - o) ** 2 mse = squared_errors / len(input_activations) l1_penalty: float = 0.0 - for l in learned_activations: - l1_penalty += abs(l) * l1_coefficient + for neuron in learned_activations: + l1_penalty += abs(neuron) * l1_coefficient expected: float = mse + l1_penalty # Compute the reconstruction_loss, l1_loss, and sae_training_loss mse_tensor = reconstruction_loss( - torch.tensor(input_activations), torch.tensor(output_activations) + torch.tensor(input_activations), + torch.tensor(output_activations), ) l1_tensor = l1_loss(torch.tensor(learned_activations)) result = sae_training_loss(mse_tensor, l1_tensor, l1_coefficient) diff --git a/sparse_autoencoder/autoencoder/tests/test_model.py b/sparse_autoencoder/autoencoder/tests/test_model.py index d8cc1e1b..feaba6f9 100644 --- a/sparse_autoencoder/autoencoder/tests/test_model.py +++ b/sparse_autoencoder/autoencoder/tests/test_model.py @@ -4,14 +4,14 @@ from sparse_autoencoder.autoencoder.model import SparseAutoencoder -def test_initialize_tied_bias(): +def test_initialize_tied_bias() -> None: """Check the tied bias is initialised correctly.""" geometric_median = torch.tensor([1.0, 2.0, 3.0]) model = SparseAutoencoder(3, 6, geometric_median) assert torch.allclose(model.tied_bias, geometric_median) -def test_encoded_decoded_shape_same(): +def test_encoded_decoded_shape_same() -> None: """Check the input and output are the same shape.""" geometric_median = torch.tensor([1.0, 2.0, 3.0]) model = SparseAutoencoder(3, 6, geometric_median) diff --git a/sparse_autoencoder/autoencoder/tied_bias.py b/sparse_autoencoder/autoencoder/tied_bias.py index 578ec4f3..59084e75 100644 --- a/sparse_autoencoder/autoencoder/tied_bias.py +++ b/sparse_autoencoder/autoencoder/tied_bias.py @@ -30,7 +30,8 @@ def __init__( self.bias = bias def forward( - self, x: Float[Tensor, "*batch input_activations"] + self, + x: Float[Tensor, "*batch input_activations"], ) -> Float[Tensor, "*batch input_activations"]: """Forward Pass.""" return x - self.bias @@ -61,7 +62,8 @@ def __init__( self.bias = bias def forward( - self, x: Float[Tensor, "*batch input_activations"] + self, + x: Float[Tensor, "*batch input_activations"], ) -> Float[Tensor, "*batch input_activations"]: """Forward Pass.""" return x + self.bias diff --git a/sparse_autoencoder/src_data/datasets/dummy.py b/sparse_autoencoder/src_data/datasets/dummy.py index fdd9e704..259e830b 100644 --- a/sparse_autoencoder/src_data/datasets/dummy.py +++ b/sparse_autoencoder/src_data/datasets/dummy.py @@ -1,4 +1,5 @@ """Dummy dataset for testing/examples.""" +from jaxtyping import Int import torch from torch import Tensor from torch.utils.data import DataLoader, Dataset @@ -9,18 +10,41 @@ class RandomIntDataset(Dataset): """Dummy dataset for testing/examples.""" - def __init__(self, num_samples, batch_size, pos, vocab_size=50000): + def __init__( + self, + num_samples: int, + batch_size: int, + pos: int, + vocab_size: int = 50000, + ) -> None: + """Initialise the dataset. + + Args: + num_samples: Number of items in the dataset.
+            batch_size: Batch size. + pos: Number of tokens in each item. + vocab_size: Size of the vocabulary. + """ self.num_samples = num_samples self.batch_size = batch_size self.pos = pos self.vocab_size = vocab_size - def __len__(self): + def __len__(self) -> int: + """Length Dunder Method.""" return self.num_samples - def __getitem__(self, idx): + def __getitem__(self, _idx: int) -> Int[Tensor, " pos"]: + """Get Item Dunder Method. + + Args: + _idx: Index of the item to get (unused, as every item is random). + """ return torch.randint( - low=0, high=self.vocab_size, size=(self.pos,), dtype=torch.long + low=0, + high=self.vocab_size, + size=(self.pos,), + dtype=torch.long, ) @@ -32,7 +56,10 @@ def dummy_collate_fn( def create_dummy_dataloader( - num_samples: int, batch_size: int, pos: int = 512, vocab_size: int = 50000 + num_samples: int, + batch_size: int, + pos: int = 512, + vocab_size: int = 50000, ) -> DataLoader: """Create dummy dataloader.""" dataset = RandomIntDataset(num_samples, batch_size, pos, vocab_size) diff --git a/sparse_autoencoder/src_data/datasets/tests/test_neel_c4_tokenized.py b/sparse_autoencoder/src_data/datasets/tests/test_neel_c4_tokenized.py index c8db6e32..1b583ba4 100644 --- a/sparse_autoencoder/src_data/datasets/tests/test_neel_c4_tokenized.py +++ b/sparse_autoencoder/src_data/datasets/tests/test_neel_c4_tokenized.py @@ -6,20 +6,19 @@ ) -def test_collate_neel_c4_tokenized(): +def test_collate_neel_c4_tokenized() -> None: """Test the collate result is shaped as expected.""" - dataset = load_dataset( "NeelNanda/c4-code-tokenized-2b", streaming=True, split="train", keep_in_memory=True, ) dataset_iter = iter(dataset) first_item = next(dataset_iter) tokens = collate_neel_c4_tokenized([first_item]) + expected_tokens_per_batch = 1024 # The dataset is all 1024 tokens per batch item - # The dataset is all 1024 tokens per batch item - assert tokens.shape[1] == 1024 + assert tokens.shape[1] == expected_tokens_per_batch diff --git a/sparse_autoencoder/src_data/src_data.py b/sparse_autoencoder/src_data/src_data.py index 81216f12..e12c9799 100644 --- a/sparse_autoencoder/src_data/src_data.py +++ b/sparse_autoencoder/src_data/src_data.py @@ -3,10 +3,10 @@ Gets large amounts of text that can be used as prompts for the source model, to be used in getting activations. -Note that for shared types, we include the shape in the docstring, as code hints aren't supported +Note that for shared types, we include the shape in the docstring, as code hints aren't supported by jaxtyping. """ -from typing import Callable +from collections.abc import Callable from datasets import IterableDataset, load_dataset from jaxtyping import Int @@ -37,7 +37,6 @@ def create_src_dataloader( Supports distributed training across GPUs with `torch.nn.DataParallel`, but not across nodes. Examples: - You can create a dataloader with the GPT2 tokenizer and pile uncopyrighted dataset as follows: >>> from sparse_autoencoder.src_data.datasets.neel_c4_tokenized import collate_neel_c4_tokenized >>> src_dataloader = create_src_dataloader( @@ -76,11 +75,12 @@ # This dataset fills a buffer with buffer_size elements, then randomly samples elements from # this buffer, replacing the selected elements with new elements.
shuffled_dataset = dataset.shuffle( - seed=random_seed, buffer_size=shuffle_buffer_size + seed=random_seed, + buffer_size=shuffle_buffer_size, ) return DataLoader( - shuffled_dataset, # type: ignore # TODO: Consider using the dataset directly + shuffled_dataset, # type: ignore batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers, diff --git a/sparse_autoencoder/src_model/store_activations_hook.py b/sparse_autoencoder/src_model/store_activations_hook.py index 6c4f5b98..9100408d 100644 --- a/sparse_autoencoder/src_model/store_activations_hook.py +++ b/sparse_autoencoder/src_model/store_activations_hook.py @@ -8,7 +8,7 @@ def store_activations_hook( value: Float[Tensor, "*any neuron"], - hook: HookPoint, # pylint: disable=unused-argument + hook: HookPoint, # noqa: ARG001 as needed by TransformerLens store: ActivationStore, ) -> Float[Tensor, "*any neuron"]: """Store Activations Hook. @@ -16,7 +16,6 @@ def store_activations_hook( Useful for getting just the specific activations wanted, rather than the full cache. Example: - First we'll need a source model from TransformerLens and an activation store. >>> from functools import partial diff --git a/sparse_autoencoder/src_model/tests/test_store_activations_hook.py b/sparse_autoencoder/src_model/tests/test_store_activations_hook.py index 8cfb1308..f9dddfa8 100644 --- a/sparse_autoencoder/src_model/tests/test_store_activations_hook.py +++ b/sparse_autoencoder/src_model/tests/test_store_activations_hook.py @@ -8,13 +8,14 @@ from sparse_autoencoder.src_model.store_activations_hook import store_activations_hook -def test_hook_stores_activations(): +def test_hook_stores_activations() -> None: """Test that the hook stores activations correctly.""" store = ListActivationStore() model = HookedTransformer.from_pretrained("tiny-stories-1M") model.add_hook( - "blocks.1.mlp.hook_post", partial(store_activations_hook, store=store) + "blocks.1.mlp.hook_post", + partial(store_activations_hook, store=store), ) tokens = model.to_tokens("Hello world") diff --git a/sparse_autoencoder/train/generate_activations.py b/sparse_autoencoder/train/generate_activations.py index 9ad27b42..c204e9ca 100644 --- a/sparse_autoencoder/train/generate_activations.py +++ b/sparse_autoencoder/train/generate_activations.py @@ -1,8 +1,8 @@ """Generate activations for training a Sparse Autoencoder.""" from functools import partial +from jaxtyping import Int import torch -from jaxtyping import Float from torch import Tensor from torch.utils.data import DataLoader from tqdm.auto import tqdm @@ -18,18 +18,17 @@ def generate_activations( model: HookedTransformer, layer: int, - hook_name: str, + cache_name: str, store: ActivationStore, dataloader: DataLoader, num_items: int, - device: torch.device = torch.device("cpu"), + device: torch.device | None = None, ) -> None: """Generate activations for training a Sparse Autoencoder. Generates activations and updates the activation store in place. Warning: - This function is a little confusing as it uses side effects. The way it works is to add a hook to the model, which will automatically store activations every time the model runs. When it has filled up the store to the desired size, it will return `None`. Your activations will then be @@ -54,42 +53,42 @@ def generate_activations( than strict limit. device: Device to run the model on. 
""" - model.to(device, print_details=False) + if isinstance(device, torch.device): + model.to(device, print_details=False) # Add the hook to the model (will automatically store the activations every time the model runs) model.remove_all_hook_fns() hook = partial(store_activations_hook, store=store) - model.add_hook(hook_name, hook) + model.add_hook(cache_name, hook) # Get the input dimensions for logging - first_item: Float[Tensor, "batch pos"] = next(iter(dataloader)) + first_item: Int[Tensor, "batch pos"] = next(iter(dataloader)) batch_size: int = first_item.shape[0] context_size: int = first_item.shape[1] activations_per_batch: int = context_size * batch_size total: int = num_items - num_items % activations_per_batch - with torch.no_grad(): - # Loop through the dataloader until the store reaches the desired size - with tqdm( - dataloader, - desc="Generate Activations", - total=total, - colour="green", - position=1, - leave=False, - dynamic_ncols=True, - ) as progress_bar: - for input_ids in dataloader: - try: - input_ids = input_ids.to(device) - model.forward(input_ids, stop_at_layer=layer + 1) # type: ignore - progress_bar.update(activations_per_batch) + # Loop through the dataloader until the store reaches the desired size + with torch.no_grad(), tqdm( + dataloader, + desc="Generate Activations", + total=total, + colour="green", + position=1, + leave=False, + dynamic_ncols=True, + ) as progress_bar: + for input_ids in dataloader: + try: + input_ids = input_ids.to(device) # noqa: PLW2901 + model.forward(input_ids, stop_at_layer=layer + 1) # type: ignore + progress_bar.update(activations_per_batch) - # Break the loop if the store is full - except StoreFullError: - break + # Break the loop if the store is full + except StoreFullError: + break - if len(store) >= total: - return + if len(store) >= total: + return - progress_bar.close() + progress_bar.close() diff --git a/sparse_autoencoder/train/pipeline.py b/sparse_autoencoder/train/pipeline.py index 9caec934..b99ba3e2 100644 --- a/sparse_autoencoder/train/pipeline.py +++ b/sparse_autoencoder/train/pipeline.py @@ -21,9 +21,9 @@ def pipeline( activation_store: ActivationStore, num_activations_before_training: int, autoencoder: SparseAutoencoder, - sweep_parameters: SweepParametersRuntime = SweepParametersRuntime(), - device: torch.device = torch.device("cpu"), -): + sweep_parameters: SweepParametersRuntime = SweepParametersRuntime(), # noqa: B008 + device: torch.device | None = None, +) -> None: """Full pipeline for training the sparse autoEncoder. The pipeline alternates between generating activations and training the autoencoder. 
@@ -55,45 +55,46 @@ def pipeline( ) # Run loop until source data is exhausted: - with logging_redirect_tqdm(): - with tqdm( - desc="Generate/Train Cycles", position=0, dynamic_ncols=True - ) as progress_bar: - while True: - # Add activations to the store - generate_activations( - src_model, - src_model_activation_layer, - src_model_activation_hook_point, - activation_store, - src_dataloader, - device=device, - num_items=num_activations_before_training, - ) - if len(activation_store) == 0: - break + with logging_redirect_tqdm(), tqdm( + desc="Generate/Train Cycles", + position=0, + dynamic_ncols=True, + ) as progress_bar: + while True: + # Add activations to the store + generate_activations( + src_model, + src_model_activation_layer, + src_model_activation_hook_point, + activation_store, + src_dataloader, + device=device, + num_items=num_activations_before_training, + ) + if len(activation_store) == 0: + break - # Shuffle the store if it has a shuffle method - it is often more efficient to - # create a shuffle method ourselves rather than get the DataLoader to shuffle - if hasattr(activation_store, "shuffle"): - activation_store.shuffle() # type: ignore + # Shuffle the store if it has a shuffle method - it is often more efficient to + # create a shuffle method ourselves rather than get the DataLoader to shuffle + if hasattr(activation_store, "shuffle"): + activation_store.shuffle() # type: ignore - # Create a dataloader from the store - dataloader = DataLoader( - activation_store, - batch_size=8192, - ) + # Create a dataloader from the store + dataloader = DataLoader( + activation_store, + batch_size=8192, + ) - # Train the autoencoder - train_autoencoder( - activations_dataloader=dataloader, - autoencoder=autoencoder, - optimizer=optimizer, - sweep_parameters=sweep_parameters, - device=device, - ) + # Train the autoencoder + train_autoencoder( + activations_dataloader=dataloader, + autoencoder=autoencoder, + optimizer=optimizer, + sweep_parameters=sweep_parameters, + device=device, + ) - # Empty the store so we can fill it up again - activation_store.empty() + # Empty the store so we can fill it up again + activation_store.empty() - progress_bar.update(1) + progress_bar.update(1) diff --git a/sparse_autoencoder/train/sweep_config.py b/sparse_autoencoder/train/sweep_config.py index 352ff6f5..565d3ddf 100644 --- a/sparse_autoencoder/train/sweep_config.py +++ b/sparse_autoencoder/train/sweep_config.py @@ -20,35 +20,35 @@ class SweepParameterConfig(Parameters): adam_beta_1: Parameter[float] | None """Adam Beta 1. - + The exponential decay rate for the first moment estimates (mean) of the gradient. """ adam_beta_2: Parameter[float] | None """Adam Beta 2. - + The exponential decay rate for the second moment estimates (variance) of the gradient. """ adam_epsilon: Parameter[float] | None """Adam Epsilon. - + A small constant for numerical stability. """ adam_weight_decay: Parameter[float] | None """Adam Weight Decay. - + Weight decay (L2 penalty). """ l1_coefficient: Parameter[float] | None """L1 Penalty Coefficient. - + The L1 penalty is the absolute sum of learned (hidden) activations, multiplied by this constant. The penalty encourages sparsity in the learned activations. This loss penalty can be reduced by using more features, or using a lower L1 coefficient. - + Default values from the [original paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html). 
""" @@ -71,7 +71,7 @@ class SweepParametersRuntime(dict): l1_coefficient: float = 0.01 - def to_dict(self): + def to_dict(self) -> dict: """Return dict representation of this object.""" return asdict(self) @@ -86,28 +86,11 @@ class SweepConfig(WandbSweepConfig): metric: Metric = field(default_factory=lambda: Metric(name="loss")) - def to_dict(self): + def to_dict(self) -> dict: """Return dict representation of this object.""" dict_representation = asdict(self) # Convert StrEnums to strings dict_representation["method"] = dict_representation["method"].value - def remove_none_values(d): - """Recursively remove all None values from the dictionary.""" - if isinstance(d, dict): - return { - k: remove_none_values(v) - for k, v in d.items() - if v is not None and remove_none_values(v) is not None - } - elif isinstance(d, list): - return [ - remove_none_values(v) - for v in d - if v is not None and remove_none_values(v) is not None - ] - else: - return d - - return remove_none_values(dict_representation) + return dict_representation diff --git a/sparse_autoencoder/train/tests/test_generate_activations.py b/sparse_autoencoder/train/tests/test_generate_activations.py index 7e71cfb6..6ca731fd 100644 --- a/sparse_autoencoder/train/tests/test_generate_activations.py +++ b/sparse_autoencoder/train/tests/test_generate_activations.py @@ -16,13 +16,15 @@ def test_activations_generated() -> None: batch_size = 2 dataloader = create_dummy_dataloader(num_samples, batch_size) + num_items = 2 + generate_activations( model=model, layer=1, - hook_name="blocks.1.mlp.hook_post", + cache_name="blocks.1.mlp.hook_post", store=store, dataloader=dataloader, - num_items=2, + num_items=num_items, ) - assert len(store) >= 2 + assert len(store) >= num_items diff --git a/sparse_autoencoder/train/train_autoencoder.py b/sparse_autoencoder/train/train_autoencoder.py index e24acc50..afc09129 100644 --- a/sparse_autoencoder/train/train_autoencoder.py +++ b/sparse_autoencoder/train/train_autoencoder.py @@ -1,10 +1,10 @@ """Training Pipeline.""" -import torch +from torch import device, set_grad_enabled from torch.optim import Optimizer from torch.utils.data import DataLoader from tqdm.auto import tqdm - import wandb + from sparse_autoencoder.autoencoder.loss import ( l1_loss, reconstruction_loss, @@ -20,8 +20,8 @@ def train_autoencoder( optimizer: Optimizer, sweep_parameters: SweepParametersRuntime, log_interval: int = 10, - device: torch.device | None = None, -): + device: device | None = None, +) -> None: """Sparse Autoencoder Training Loop. Args: @@ -30,60 +30,61 @@ def train_autoencoder( optimizer: The optimizer to use. sweep_parameters: The sweep parameters to use. log_interval: How often to log progress. + device: Decide to use. 
""" n_dataset_items: int = len(activations_dataloader.dataset) # type: ignore batch_size: int = activations_dataloader.batch_size # type: ignore - with torch.set_grad_enabled(True): - with tqdm( - desc="Train Autoencoder", - total=n_dataset_items, - colour="green", - position=1, - leave=False, - dynamic_ncols=True, - ) as progress_bar: - for step, batch in enumerate(activations_dataloader): - # Zero the gradients - optimizer.zero_grad() + with set_grad_enabled(True), tqdm( # noqa: FBT003 + desc="Train Autoencoder", + total=n_dataset_items, + colour="green", + position=1, + leave=False, + dynamic_ncols=True, + ) as progress_bar: + for step, batch in enumerate(activations_dataloader): + # Zero the gradients + optimizer.zero_grad() - # Move the batch to the device (in place) - batch = batch.to(device) + # Move the batch to the device (in place) + batch = batch.to(device) # noqa: PLW2901 - # Forward pass - learned_activations, reconstructed_activations = autoencoder(batch) + # Forward pass + learned_activations, reconstructed_activations = autoencoder(batch) - # Get metrics - reconstruction_loss_mse = reconstruction_loss( - batch, reconstructed_activations - ) - l1_loss_learned_activations = l1_loss(learned_activations) - total_loss = sae_training_loss( - reconstruction_loss_mse, - l1_loss_learned_activations, - sweep_parameters.l1_coefficient, - ) - # TODO: Log dead neurons metric (get_frequencies in Neel's code) + # Get metrics + reconstruction_loss_mse = reconstruction_loss( + batch, + reconstructed_activations, + ) + l1_loss_learned_activations = l1_loss(learned_activations) + total_loss = sae_training_loss( + reconstruction_loss_mse, + l1_loss_learned_activations, + sweep_parameters.l1_coefficient, + ) + # TODO: Log dead neurons metric (get_frequencies in Neel's code) - # Backwards pass - total_loss.backward() + # Backwards pass + total_loss.backward() - # TODO: Make decoder weights and grad unit norm + # TODO: Make decoder weights and grad unit norm - optimizer.step() + optimizer.step() - # TODO: Enable neuron resampling + # TODO: Enable neuron resampling here - # Log - if step % log_interval == 0 and wandb.run is not None: - wandb.log( - { - "reconstruction_loss": reconstruction_loss_mse, - "l1_loss": l1_loss_learned_activations, - "loss": total_loss, - } - ) + # Log + if step % log_interval == 0 and wandb.run is not None: + wandb.log( + { + "reconstruction_loss": reconstruction_loss_mse, + "l1_loss": l1_loss_learned_activations, + "loss": total_loss, + }, + ) - progress_bar.update(batch_size) + progress_bar.update(batch_size) - progress_bar.close() + progress_bar.close() diff --git a/sparse_autoencoder/train/utils/wandb_sweep_types.py b/sparse_autoencoder/train/utils/wandb_sweep_types.py index f59a1ffa..45a1e253 100644 --- a/sparse_autoencoder/train/utils/wandb_sweep_types.py +++ b/sparse_autoencoder/train/utils/wandb_sweep_types.py @@ -2,6 +2,7 @@ Weights & Biases just provide a JSON Schema, so we've converted here to dataclasses. """ +# ruff: noqa from dataclasses import dataclass from enum import Enum from typing import Any, Generic, TypeVar @@ -18,7 +19,7 @@ class ControllerType(Enum): class Controller: """Controller.""" - type: ControllerType + type: ControllerType # noqa: A003 class HyperbandStoppingType(Enum): @@ -35,25 +36,25 @@ class HyperbandStopping: than successful training runs. """ - type: HyperbandStoppingType + type: HyperbandStoppingType # noqa: A003 eta: float | None = None """ETA. 
- + At every eta^n steps, hyperband continues running the top 1/eta runs and stops all other runs. """ maxiter: int | None = None """Max Iterations. - + Set the last epoch to finish trimming runs, and hyperband will automatically calculate the prior epochs to trim runs. """ miniter: int | None = None """Min Iterations. - + Set the first epoch to start trimming runs, and hyperband will automatically calculate the subsequent epochs to trim runs. """ @@ -155,9 +156,9 @@ class Parameter(Generic[ParamType]): value: ParamType | list[ParamType] - max: ParamType | None = None + max: ParamType | None = None # noqa: A003 - min: ParamType | None = None + min: ParamType | None = None # noqa: A003 a: float | None = None @@ -231,6 +232,6 @@ class WandbSweepConfig: runcap: int | None = None """Run Cap. - + Sweep will run no more than this number of runs, across any number of agents. """
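Taken together, the refactored training entry points can be driven end to end as sketched below (not part of the patch). The store size, feature counts, batch size, and plain Adam optimiser are illustrative assumptions; a real run would fill the store via generate_activations (or the full pipeline) and configure the optimiser from the sweep parameters.

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.train.sweep_config import SweepParametersRuntime
from sparse_autoencoder.train.train_autoencoder import train_autoencoder

# Stand-in activations; a real run would fill the store with generate_activations
store = TensorActivationStore(max_items=4096, num_neurons=64)
store.extend(torch.randn(4096, 64))
store.shuffle()  # in-place shuffle is faster than shuffling in the DataLoader

autoencoder = SparseAutoencoder(
    n_input_features=64,
    n_learned_features=512,
    geometric_median_dataset=torch.zeros(64),  # placeholder median estimate
)

train_autoencoder(
    activations_dataloader=DataLoader(store, batch_size=1024),
    autoencoder=autoencoder,
    optimizer=Adam(autoencoder.parameters()),
    sweep_parameters=SweepParametersRuntime(),  # defaults, e.g. l1_coefficient=0.01
)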