diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 394f049f6..a2b97cc88 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,4 +38,4 @@ repos: hooks: - id: mypy args: [--ignore-missing-imports] - additional_dependencies: [wandb, types-PyYAML] + additional_dependencies: [wandb==0.17.8, types-PyYAML] diff --git a/pyproject.toml b/pyproject.toml index 7ba0b4c32..f936e5afe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "tokenizers>=0.15.2", "transformers>=4.41.2", "optax>=0.1.9", - "wandb>=0.16.6,<0.17.6", + "wandb>=0.17.8", "scipy<=1.12.0", "draccus>=0.8.0", "pyarrow>=11.0.0", @@ -105,13 +105,6 @@ test = [ "pytest-asyncio", ] -#[tool.setuptools.packages.find] -#where = ["src"] -#include = ["levanter", "levanter.*"] - - -[tool.setuptools] -packages = ["levanter"] - -[tool.setuptools.package-dir] -levanter = "src/levanter" +[tool.setuptools.packages.find] +where = ["src"] +include = ["levanter", "levanter.*"] diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index c98c0727c..1b0254261 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -182,9 +182,9 @@ def init(self, run_id: Optional[str]) -> WandbTracker: if wandb.run is not None: wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") - wandb.summary["num_devices"] = jax.device_count() - wandb.summary["num_hosts"] = jax.process_count() - wandb.summary["backend"] = jax.default_backend() + wandb.summary["num_devices"] = jax.device_count() # type: ignore + wandb.summary["num_hosts"] = jax.process_count() # type: ignore + wandb.summary["backend"] = jax.default_backend() # type: ignore return WandbTracker(r)