Skip to content

Commit

Permalink
Try hacky fix
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney committed Nov 4, 2023
1 parent bd1a82b commit aba5824
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
44 changes: 22 additions & 22 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,44 +1,44 @@
[tool.poetry]
authors=["Alan Cooney <[email protected]>"]
authors =["Alan Cooney <[email protected]>"]
description="Sparse Autoencoder for Mechanistic Interpretability"
include=["sparse_autoencoder"]
license="MIT"
name="sparse_autoencoder"
readme="README.md"
version="0.0.0"
include =["sparse_autoencoder"]
license ="MIT"
name ="sparse_autoencoder"
readme ="README.md"
version ="0.0.0"

[tool.poetry.dependencies]
einops=">=0.6"
python=">=3.10, <3.13"
torch=">=2.1"
wandb=">=0.15.12"
torch =">=2.1"
wandb =">=0.15.12"

[tool.poetry.group.dev.dependencies]
jupyter=">=1"
plotly=">=5"
jupyter =">=1"
plotly =">=5"
poethepoet=">=0.24.2"
pre-commit=">=3.5.0"
pyright=">=1.1.334"
pytest=">=7"
pyright =">=1.1.334"
pytest =">=7"
pytest-cov=">=4"
ruff=">=0.1.4"
ruff =">=0.1.4"

[tool.poetry.group.demos.dependencies]
jupyterlab=">=3"
pandas=">=2.1.2"
jupyterlab =">=3"
pandas =">=2.1.2"
transformer-lens=">=1.9.0"

[tool.poe.tasks]
check=["format", "lint", "test", "typecheck"]
format="ruff format sparse_autoencoder"
lint="ruff check sparse_autoencoder --fix"
check =["format", "lint", "test", "typecheck"]
format ="ruff format sparse_autoencoder"
lint ="ruff check sparse_autoencoder --fix"
precommit="pre-commit run --all-files"
test="pytest"
test ="pytest"
typecheck="pyright"

[build-system]
build-backend="poetry.core.masonry.api"
requires=["poetry-core"]
requires =["poetry-core"]

[tool.pytest]

Expand All @@ -54,7 +54,7 @@
]

[tool.pyright]
include=["sparse_autoencoder"]
include =["sparse_autoencoder"]
reportIncompatibleMethodOverride=true

[tool.ruff]
Expand Down Expand Up @@ -82,7 +82,7 @@

[tool.ruff.lint.isort]
force-sort-within-sections=true
lines-after-imports=2
lines-after-imports =2

[tool.ruff.lint.pydocstyle]
convention="google"
Expand Down
6 changes: 3 additions & 3 deletions sparse_autoencoder/train/train_autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Training Pipeline."""
import torch
from torch import device, set_grad_enabled
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
Expand All @@ -20,7 +20,7 @@ def train_autoencoder(
optimizer: Optimizer,
sweep_parameters: SweepParametersRuntime,
log_interval: int = 10,
device: torch.device | None = None,
device: device | None = None,
) -> None:
"""Sparse Autoencoder Training Loop.
Expand All @@ -35,7 +35,7 @@ def train_autoencoder(
n_dataset_items: int = len(activations_dataloader.dataset) # type: ignore
batch_size: int = activations_dataloader.batch_size # type: ignore

with torch.set_grad_enabled(True), tqdm( # noqa: FBT003
with set_grad_enabled(True), tqdm( # noqa: FBT003
desc="Train Autoencoder",
total=n_dataset_items,
colour="green",
Expand Down

0 comments on commit aba5824

Please sign in to comment.