diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 479b7f9a..25ec9503 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,9 +38,6 @@ jobs: uv pip install -e .[test] --system - # TODO remove next line installing ase from main branch when FrechetCellFilter is released - uv pip install --upgrade 'ase@git+https://gitlab.com/ase/ase' --system - - name: Run Tests run: pytest --capture=no --cov --cov-report=xml env: diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index b41aba6a..31a3ad33 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -242,13 +242,13 @@ def relax( Default = True ase_filter (str | ase.filters.Filter): The filter to apply to the atoms object for relaxation. Default = FrechetCellFilter - Used to default to ExpCellFilter but was removed due to bug reported in - https://gitlab.com/ase/ase/-/issues/1321 and fixed in + Default used to be ExpCellFilter which was removed due to bug reported + in https://gitlab.com/ase/ase/-/issues/1321 and fixed in https://gitlab.com/ase/ase/-/merge_requests/3024. save_path (str | None): The path to save the trajectory. Default = None - loginterval (int | None): Interval for logging trajectory and crystal feas - Default = 1 + loginterval (int | None): Interval for logging trajectory and crystal + features. Default = 1 crystal_feas_save_path (str | None): Path to save crystal feature vectors which are logged at a loginterval rage Default = None @@ -262,30 +262,18 @@ def relax( dict[str, Structure | TrajectoryObserver]: A dictionary with 'final_structure' and 'trajectory'. """ - try: - import ase.filters as filter_classes - from ase.filters import Filter - - except ImportWarning: - import ase.constraints as filter_classes - from ase.constraints import Filter - - if ase_filter == "FrechetCellFilter": - ase_filter = "ExpCellFilter" - print( - "Failed to import ase.filters. Default filter to ExpCellFilter. " - "For better relaxation accuracy with the new FrechetCellFilter, " - "run pip install git+https://gitlab.com/ase/ase" - ) + import ase.filters as filters + from ase.filters import Filter + valid_filter_names = [ name - for name, cls in inspect.getmembers(filter_classes, inspect.isclass) + for name, cls in inspect.getmembers(filters, inspect.isclass) if issubclass(cls, Filter) ] if isinstance(ase_filter, str): if ase_filter in valid_filter_names: - ase_filter = getattr(filter_classes, ase_filter) + ase_filter = getattr(filters, ase_filter) else: raise ValueError( f"Invalid {ase_filter=}, must be one of {valid_filter_names}. " diff --git a/examples/basics.ipynb b/examples/basics.ipynb index 7b0101d3..814c299c 100644 --- a/examples/basics.ipynb +++ b/examples/basics.ipynb @@ -19,8 +19,7 @@ " from chgnet.model import CHGNet\n", "except ImportError:\n", " # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)\n", - " !pip install chgnet\n", - " !pip install git+https://gitlab.com/ase/ase" + " !pip install chgnet" ] }, { diff --git a/pyproject.toml b/pyproject.toml index ab1e324c..7764f980 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,3 @@ -[build-system] -requires = ["Cython", "setuptools>=65.0", "wheel"] -build-backend = "setuptools.build_meta" - [project] name = "chgnet" version = "0.3.6" @@ -11,7 +7,7 @@ requires-python = ">=3.9" readme = "README.md" license = { text = "Modified BSD" } dependencies = [ - "ase", + "ase>=3.23.0", "cython>=0.29.26", "numpy>=1.21.6", "nvidia-ml-py3>=7.352.0", @@ -48,9 +44,14 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] } "chgnet" = ["*.json"] "chgnet.pretrained" = ["*", "**/*"] +[build-system] +requires = ["Cython", "setuptools>=65.0", "wheel"] +build-backend = "setuptools.build_meta" + [tool.ruff] target-version = "py39" -include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"] +extend-include = ["*.ipynb"] + [tool.ruff.lint] select = ["ALL"] ignore = [ @@ -90,6 +91,9 @@ pydocstyle.convention = "google" isort.required-imports = ["from __future__ import annotations"] isort.split-on-trailing-comma = false +[tool.ruff.format] +docstring-code-format = true + [tool.ruff.lint.per-file-ignores] "site/*" = ["INP001", "S602"] "tests/*" = ["ANN201", "D100", "D103", "FBT001", "FBT002", "INP001", "S101"] @@ -98,7 +102,6 @@ isort.split-on-trailing-comma = false "chgnet/**/*" = ["T201"] "__init__.py" = ["F401"] - [tool.coverage.run] source = ["chgnet"]