Skip to content

Commit

Permalink
Add strict type checking (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney authored Nov 5, 2023
1 parent 8f2dfbc commit 1aeff6d
Show file tree
Hide file tree
Showing 14 changed files with 147 additions and 84 deletions.
13 changes: 8 additions & 5 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
"source.organizeImports": true
},
"editor.formatOnSave": true,
"evenBetterToml.formatter.alignComments": true,
"evenBetterToml.formatter.alignEntries": true,
"evenBetterToml.formatter.allowedBlankLines": 2,
"evenBetterToml.formatter.allowedBlankLines": 1,
"evenBetterToml.formatter.arrayAutoCollapse": true,
"evenBetterToml.formatter.arrayAutoExpand": true,
"evenBetterToml.formatter.arrayTrailingComma": true,
Expand All @@ -24,11 +22,15 @@
"evenBetterToml.formatter.compactEntries": true,
"evenBetterToml.formatter.compactInlineTables": true,
"evenBetterToml.formatter.indentEntries": true,
"evenBetterToml.formatter.indentString": " ",
"evenBetterToml.formatter.indentTables": true,
"evenBetterToml.formatter.inlineTableExpand": false,
"evenBetterToml.formatter.inlineTableExpand": true,
"evenBetterToml.formatter.reorderArrays": true,
"evenBetterToml.formatter.reorderKeys": true,
"evenBetterToml.formatter.trailingNewline": true,
"evenBetterToml.schema.enabled": true,
"evenBetterToml.schema.links": true,
"evenBetterToml.syntax.semanticTokens": false,
"notebook.formatOnCellExecution": true,
"notebook.formatOnSave.enabled": true,
"python.analysis.autoFormatStrings": true,
Expand All @@ -40,5 +42,6 @@
"python.testing.pytestEnabled": true,
"rewrap.autoWrap.enabled": true,
"rewrap.reformat": false,
"rewrap.wrappingColumn": 100
"rewrap.wrappingColumn": 100,
"python.analysis.diagnosticMode": "workspace"
}
156 changes: 104 additions & 52 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,74 +1,59 @@
[tool.poetry]
authors =["Alan Cooney <[email protected]>"]
authors=["Alan Cooney <[email protected]>"]
description="Sparse Autoencoder for Mechanistic Interpretability"
include =["sparse_autoencoder"]
license ="MIT"
name ="sparse_autoencoder"
readme ="README.md"
version ="0.0.0"
include=["sparse_autoencoder"]
license="MIT"
name="sparse_autoencoder"
readme="README.md"
version="0.0.0"

[tool.poetry.dependencies]
einops=">=0.6"
python=">=3.10, <3.13"
torch =">=2.1"
wandb =">=0.15.12"
torch=">=2.1"
wandb=">=0.15.12"

[tool.poetry.group.dev.dependencies]
jupyter =">=1"
plotly =">=5"
poethepoet=">=0.24.2"
pre-commit=">=3.5.0"
pyright =">=1.1.334"
pytest =">=7"
pytest-cov=">=4"
ruff =">=0.1.4"
[tool.poetry.group]
[tool.poetry.group.dev.dependencies]
jupyter=">=1"
plotly=">=5"
poethepoet=">=0.24.2"
pre-commit=">=3.5.0"
pyright=">=1.1.334"
pytest=">=7"
pytest-cov=">=4"
ruff=">=0.1.4"

[tool.poetry.group.demos.dependencies]
jupyterlab =">=3"
pandas =">=2.1.2"
transformer-lens=">=1.9.0"
[tool.poetry.group.demos.dependencies]
jupyterlab=">=3"
pandas=">=2.1.2"
transformer-lens=">=1.9.0"

[tool.poe.tasks]
check =["format", "lint", "test", "typecheck"]
format ="ruff format sparse_autoencoder"
lint ="ruff check sparse_autoencoder --fix"
check=["format", "lint", "test", "typecheck"]
format="ruff format sparse_autoencoder"
lint="ruff check sparse_autoencoder --fix"
precommit="pre-commit run --all-files"
test ="pytest"
test="pytest"
typecheck="pyright"

[build-system]
build-backend="poetry.core.masonry.api"
requires =["poetry-core"]
requires=["poetry-core"]

[tool.pytest]

[tool.pytest.ini_options]
addopts="""--jaxtyping-packages=sparse_autoencoder,beartype.beartype \
-W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning \
--doctest-modules"""
filterwarnings=[
"ignore:pkg_resources is deprecated as an API:DeprecationWarning",
# Ignore numpy.distutils deprecation warning caused by pandas
# More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils
"ignore:distutils Version classes are deprecated:DeprecationWarning",
]

[tool.pyright]
include =["sparse_autoencoder"]
reportIncompatibleMethodOverride=true
addopts="""--jaxtyping-packages=sparse_autoencoder,beartype.beartype --doctest-modules"""

[tool.ruff]
exclude=["*/snapshots/", "/.venv"]
ignore=[
"ANN101", # self type annotation (it's inferred)
"ANN204", # __init__() return type (it's inferred)
"E731", # No lambdas (can be useful)
"F722", # Forward annotations check (conflicts with jaxtyping)
"FA102", # Annotations support (Python >= 3.9 is fine)
"FIX002", # TODO issue link (overkill)
"INP001", # __init__.py for all packages (Python >= 3.3 is fine)
"PGH003", # No general type: ignore (too strict)
"S101", # Use of assert detected (it's needed for tests)
"PGH003", # No general type: ignore (not supported with pyright)
"TCH002", # Type checking imports (conflicts with beartype)
"TD00", # TODO banned (we're in alpha)
# Rules that conflict with ruff format
Expand All @@ -77,15 +62,82 @@
]
ignore-init-module-imports=true
line-length=100
required-version="0.1.4"
select=["ALL"]

[tool.ruff.lint.isort]
force-sort-within-sections=true
lines-after-imports =2
[tool.ruff.lint]
[tool.ruff.lint.flake8-annotations]
mypy-init-return=true

[tool.ruff.lint.isort]
force-sort-within-sections=true
lines-after-imports=2

[tool.ruff.lint.pydocstyle]
convention="google"
[tool.ruff.lint.per-file-ignores]
"**/tests/*"=["S101"] # Assert is needed in PyTest

[tool.ruff.pylint]
max-args=10
[tool.ruff.lint.pydocstyle]
convention="google"

[tool.ruff.lint.pylint]
max-args=10

[tool.pyright]
# Includes all rules in strict mode, with some set to warning
deprecateTypingAliases=true
disableBytesTypePromotions=true
include=["sparse_autoencoder"]
reportAssertAlwaysTrue=true
reportConstantRedefinition=true
reportDeprecated=true
reportDuplicateImport=true
reportFunctionMemberAccess=true
reportGeneralTypeIssues=true
reportIncompatibleMethodOverride=true
reportIncompatibleVariableOverride=true
reportIncompleteStub=true
reportInconsistentConstructor=true
reportInvalidStringEscapeSequence=true
reportInvalidStubStatement=true
reportInvalidTypeVarUse=true
reportMatchNotExhaustive=true
reportMissingParameterType=true
reportMissingTypeArgument="warning"
reportMissingTypeStubs="warning"
reportOptionalCall=true
reportOptionalContextManager=true
reportOptionalIterable=true
reportOptionalMemberAccess=true
reportOptionalOperand=true
reportOptionalSubscript=true
reportOverlappingOverload=true
reportPrivateImportUsage=true
reportPrivateUsage=true
reportSelfClsParameterName=true
reportTypeCommentUsage=true
reportTypedDictNotRequiredAccess=true
reportUnboundVariable=true
reportUnknownArgumentType="warning"
reportUnknownLambdaType=true
reportUnknownMemberType="warning"
reportUnknownParameterType="warning"
reportUnknownVariableType="warning"
reportUnnecessaryCast=true
reportUnnecessaryComparison=true
reportUnnecessaryContains=true
reportUnnecessaryIsInstance=true
reportUnsupportedDunderAll=true
reportUntypedBaseClass=true
reportUntypedClassDecorator=true
reportUntypedFunctionDecorator=true
reportUntypedNamedTuple=true
reportUnusedClass=true
reportUnusedCoroutine=true
reportUnusedExpression=true
reportUnusedFunction=true
reportUnusedImport=true
reportUnusedVariable=true
reportWildcardImportFromLibrary=true
strictDictionaryInference=true
strictListInference=true
strictParameterNoneValue=true
strictSetInference=true
2 changes: 1 addition & 1 deletion sparse_autoencoder/activation_store/base_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"""


class ActivationStore(Dataset, ABC):
class ActivationStore(Dataset[ActivationStoreItem], ABC):
"""Activation Store Abstract Class.
Extends the `torch.utils.data.Dataset` class to provide an activation store, with additional
Expand Down
4 changes: 2 additions & 2 deletions sparse_autoencoder/activation_store/disk_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,15 @@ def __len__(self) -> int:
0
"""
# Calculate the length if not cached
if self._disk_n_activation_vectors.value is None:
if self._disk_n_activation_vectors.value == -1:
cache_size: int = 0
for file in self._all_filenames:
cache_size += len(torch.load(file))
self._disk_n_activation_vectors.value = cache_size

return self._disk_n_activation_vectors.value

def __del__(self):
def __del__(self) -> None:
"""Delete Dunder Method."""
# Shutdown the thread pool after everything is complete
self._thread_pool.shutdown(wait=True, cancel_futures=False)
Expand Down
6 changes: 3 additions & 3 deletions sparse_autoencoder/activation_store/list_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class ListActivationStore(ActivationStore):
_pool: ProcessPoolExecutor | None = None
"""Multiprocessing Pool."""

_pool_exceptions: ListProxy | list
_pool_exceptions: ListProxy | list[Exception]
"""Pool Exceptions.
Used to keep track of exceptions.
Expand Down Expand Up @@ -283,7 +283,7 @@ def wait_for_writes_to_complete(self) -> None:
time.sleep(1)

if self._pool_exceptions:
exceptions_report = "\n".join(f"{e}\n{tb}" for e, tb in self._pool_exceptions)
exceptions_report = "\n".join([str(e) for e in self._pool_exceptions])
msg = f"Exceptions occurred in background workers:\n{exceptions_report}"
raise RuntimeError(msg)

Expand All @@ -307,7 +307,7 @@ def empty(self) -> None:
# Clearing a list like this works for both standard and multiprocessing lists
self._data[:] = []

def __del__(self):
def __del__(self) -> None:
"""Delete Dunder Method."""
if self._pool:
self._pool.shutdown(wait=False, cancel_futures=True)
6 changes: 4 additions & 2 deletions sparse_autoencoder/autoencoder/tied_bias.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Tied Biases (Pre-Encoder and Post-Decoder)."""
from enum import StrEnum
from enum import Enum

from jaxtyping import Float
from torch import Tensor
from torch.nn import Module


class TiedBiasPosition(StrEnum):
class TiedBiasPosition(str, Enum):
"""Tied Bias Position."""

PRE_ENCODER = "pre_encoder"
Expand Down Expand Up @@ -44,6 +44,8 @@ def __init__(
super().__init__()

self._bias_reference = bias

# Support string literals as well as enums
self._bias_position = position

def forward(
Expand Down
7 changes: 4 additions & 3 deletions sparse_autoencoder/autoencoder/unit_norm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,14 @@ def __init__(
dtype: Data type to use.
"""
# Create the linear layer as per the standard PyTorch linear layer
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
self.weight = Parameter(
torch.empty((out_features, in_features), device=device, dtype=dtype)
)
if bias:
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
self.bias = Parameter(torch.empty(out_features, device=device, dtype=dtype))
else:
self.register_parameter("bias", None)
self.reset_parameters()
Expand Down
4 changes: 2 additions & 2 deletions sparse_autoencoder/src_data/datasets/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sparse_autoencoder.src_data.src_data import CollateResponseTokens


class RandomIntDataset(Dataset):
class RandomIntDataset(Dataset[Int[Tensor, " pos"]]):
"""Dummy dataset for testing/examples."""

def __init__(
Expand Down Expand Up @@ -60,7 +60,7 @@ def create_dummy_dataloader(
batch_size: int,
pos: int = 512,
vocab_size: int = 50000,
) -> DataLoader:
) -> DataLoader[Int[Tensor, " pos"]]:
"""Create dummy dataloader."""
dataset = RandomIntDataset(num_samples, batch_size, pos, vocab_size)
return DataLoader(dataset, collate_fn=dummy_collate_fn)
5 changes: 3 additions & 2 deletions sparse_autoencoder/src_data/src_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
by jaxtyping.
"""
from collections.abc import Callable
from typing import Any

from datasets import IterableDataset, load_dataset
from jaxtyping import Int
Expand All @@ -23,13 +24,13 @@

def create_src_dataloader(
dataset_name: str,
collate_fn: Callable[[list], CollateResponseTokens],
collate_fn: Callable[[list[Any]], CollateResponseTokens],
dataset_split: str = "train",
batch_size: int = 512,
shuffle_buffer_size: int = 10_000,
random_seed: int = 0,
num_workers: int = 1,
) -> DataLoader:
) -> DataLoader[Int[Tensor, "batch pos"]]:
"""Create a DataLoader with tokenized data.
Creates a DataLoader with a [HuggingFace Dataset](https://huggingface.co/datasets).
Expand Down
2 changes: 1 addition & 1 deletion sparse_autoencoder/train/generate_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def generate_activations(
layer: int,
cache_name: str,
store: ActivationStore,
dataloader: DataLoader,
dataloader: DataLoader[Int[Tensor, " pos"]],
num_items: int,
device: torch.device | None = None,
) -> None:
Expand Down
4 changes: 3 additions & 1 deletion sparse_autoencoder/train/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Training Pipeline."""
from jaxtyping import Int
import torch
from torch import Tensor
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
Expand All @@ -17,7 +19,7 @@ def pipeline(
src_model: HookedTransformer,
src_model_activation_hook_point: str,
src_model_activation_layer: int,
src_dataloader: DataLoader,
src_dataloader: DataLoader[Int[Tensor, " pos"]],
activation_store: ActivationStore,
num_activations_before_training: int,
autoencoder: SparseAutoencoder,
Expand Down
Loading

0 comments on commit 1aeff6d

Please sign in to comment.