Skip to content

Commit

Permalink
clear mem
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 committed Feb 21, 2024
1 parent 427624a commit 53cf36b
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ jobs:
run: |
pip install pre-commit
pre-commit run --all-files || ( git status --short ; git diff ; exit 1 )
- name: Clean up files
run: |
python mdagent/utils/clear_mem.py
- name: Run Test
shell: bash -l {0}
env:
Expand Down
3 changes: 2 additions & 1 deletion mdagent/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .clear_mem import clear_memory
from .makellm import _make_llm
from .path_registry import FileType, PathRegistry

__all__ = ["_make_llm", "PathRegistry", "FileType"]
__all__ = ["_make_llm", "PathRegistry", "FileType", "clear_memory"]
50 changes: 50 additions & 0 deletions mdagent/utils/clear_mem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import shutil
from pathlib import Path


def find_repo_root(start_path):
path = Path(start_path).resolve()
while not (path / "setup.py").exists():
if path.parent == path:
raise FileNotFoundError("Could not find the repository root with setup.py.")
path = path.parent
print("path: ", path)
return path


def clear_memory(
clear_skill=True, clear_files=True, ask_confirmation=False, repo_root=None
):
print(
"""This script will delete the following:
1. All files in files/pdb, files/simulation, and files/records directories
2. All files starting with temp_ in the current directory
3. The file path_registry.json"""
)
if repo_root is None:
repo_root = find_repo_root(__file__)
else:
repo_root = Path(repo_root)
if ask_confirmation:
confirmation = input("Are you sure you want to proceed? (y/n): ")
else:
confirmation = "y"
if confirmation.lower() == "y":
directories_to_clear = []
if clear_files:
directories_to_clear += [Path(repo_root) / "files"]
if clear_skill:
directories_to_clear += [Path(repo_root) / "ckpt"]
if not clear_files and not clear_skill:
return None
for directory in directories_to_clear:
shutil.rmtree(directory, ignore_errors=True)
directory.mkdir(parents=True, exist_ok=True)

return "Deletion complete."
else:
return "Deletion aborted."


if __name__ == "__main__":
clear_memory()
27 changes: 26 additions & 1 deletion tests/test_fxns.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import os
import tempfile
import time
import warnings
from pathlib import Path
from unittest.mock import MagicMock, mock_open, patch

import pytest
Expand All @@ -14,7 +16,7 @@
)
from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv
from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool
from mdagent.utils import FileType, PathRegistry
from mdagent.utils import FileType, PathRegistry, clear_memory

warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")

Expand Down Expand Up @@ -357,6 +359,7 @@ def test_packmol_sm_download_called(packmol):
mock_get_sm_pdbs.assert_called_with(["water", "benzene"])


@pytest.skip("Skipping temporarily", allow_module_level=True)
def test_packmol_download_only(packmol):
path_registry = PathRegistry()
path_registry._remove_path_from_json("water")
Expand All @@ -369,6 +372,7 @@ def test_packmol_download_only(packmol):
os.remove("files/pdb/benzene.pdb")


@pytest.skip("Skipping temporarily", allow_module_level=True)
def test_packmol_download_only_once(packmol):
path_registry = PathRegistry()
path_registry._remove_path_from_json("water")
Expand Down Expand Up @@ -414,3 +418,24 @@ def test_init_path_registry(path_registry_with_mocked_fs):
# you may need to check the internal state or the contents of the JSON file.
# For example:
assert "water_000000" in path_registry_with_mocked_fs.list_path_names()


def test_clear_mem(monkeypatch):
with tempfile.TemporaryDirectory() as tmpdir:
repo_root = Path(tmpdir)
(repo_root / "setup.py").touch()
directories_to_create = ["files/pdb", "files/simulations", "files/records"]
for directory in directories_to_create:
(repo_root / "dir2" / directory).mkdir(parents=True)
# Create a dummy file in each directory
(repo_root / "dir2" / directory / "dummy_file.txt").touch()

(repo_root / "temp_file.txt").touch()
(repo_root / "path_registry.json").touch()
monkeypatch.setattr("builtins.input", lambda _: "y")
clear_memory()

for directory in directories_to_create:
assert not list((repo_root / "dir2" / directory).iterdir())
assert not (repo_root / "temp_file.txt").exists()
assert not (repo_root / "path_registry.json").exists()

0 comments on commit 53cf36b

Please sign in to comment.