Skip to content

Commit

Permalink
fixed tests with new file_path function & unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 committed Feb 22, 2024
1 parent 3e5dacc commit 40761c4
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 25 deletions.
3 changes: 2 additions & 1 deletion mdagent/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .general_utils import find_file_path
from .makellm import _make_llm
from .path_registry import FileType, PathRegistry

__all__ = ["_make_llm", "PathRegistry", "FileType"]
__all__ = ["_make_llm", "PathRegistry", "FileType", "find_file_path"]
22 changes: 22 additions & 0 deletions mdagent/utils/general_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os


def find_file_path(file_name: str, exact_match: bool = True):
"""get the path of a file, if it exists in repo"""
setup_dir = None
for dirpath, dirnames, filenames in os.walk("."):
if "setup.py" in filenames:
setup_dir = dirpath
break

if setup_dir is None:
raise FileNotFoundError("Unable to find root directory.")

for dirpath, dirnames, filenames in os.walk(setup_dir):
for filename in filenames:
if (exact_match and filename == file_name) or (
not exact_match and file_name in filename
):
return os.path.join(dirpath, filename)

return None
51 changes: 28 additions & 23 deletions tests/test_sims_and_clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from mdagent.tools.base_tools import CleaningTools, SimulationFunctions
from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool
from mdagent.utils import PathRegistry
from mdagent.utils import PathRegistry, find_file_path

warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")

Expand Down Expand Up @@ -99,29 +99,31 @@ def test_setup_simulation_from_json(mock_json_load, mock_file_open, sim_fxns):
def test_small_molecule_pdb(molpdb, get_registry):
# Test with a valid SMILES string
valid_smiles = "C1=CC=CC=C1" # Benzene
expected_output = (
"PDB file for C1=CC=CC=C1 successfully created and saved to "
"files/pdb/benzene.pdb."
expected_output_success = "successfully created and saved to "
assert expected_output_success in molpdb.small_molecule_pdb(
valid_smiles, get_registry
)
assert molpdb.small_molecule_pdb(valid_smiles, get_registry) == expected_output
assert os.path.exists("files/pdb/benzene.pdb")
os.remove("files/pdb/benzene.pdb") # Clean up
file_path = find_file_path("benzene", exact_match=False)
assert file_path is not None # assert file was found
os.remove(file_path) # Clean up

# test with invalid SMILES string and invalid molecule name
invalid_smiles = "C1=CC=CC=C1X"
invalid_name = "NotAMolecule"
expected_output = (
"There was an error getting pdb. Please input a single molecule name."
)
assert molpdb.small_molecule_pdb(invalid_smiles, get_registry) == expected_output
assert molpdb.small_molecule_pdb(invalid_name, get_registry) == expected_output
assert expected_output in molpdb.small_molecule_pdb(invalid_smiles, get_registry)
assert expected_output in molpdb.small_molecule_pdb(invalid_name, get_registry)

# test with valid molecule name
valid_name = "water"
assert "successfully" in molpdb.small_molecule_pdb(valid_name, get_registry)
# assert os.path.exists("files/pdb/water.pdb")
if os.path.exists("files/pdb/water.pdb"):
os.remove("files/pdb/water.pdb")
assert expected_output_success in molpdb.small_molecule_pdb(
valid_name, get_registry
)
file_path = find_file_path("water", exact_match=False)
assert file_path is not None # assert file was found
os.remove(file_path) # Clean up


def test_packmol_sm_download_called(packmol):
Expand Down Expand Up @@ -156,12 +158,14 @@ def test_packmol_download_only(packmol):
path_registry._remove_path_from_json("benzene")
small_molecules = ["water", "benzene"]
packmol._get_sm_pdbs(small_molecules)
# assert os.path.exists("files/pdb/water.pdb")
# assert os.path.exists("files/pdb/benzene.pdb")
if os.path.exists("files/pdb/water.pdb"):
os.remove("files/pdb/water.pdb")
if os.path.exists("files/pdb/benzene.pdb"):
os.remove("files/pdb/benzene.pdb")

water_path = find_file_path("water", exact_match=False)
assert water_path is not None
os.remove(water_path)

benzene_path = find_file_path("benzene", exact_match=False)
assert benzene_path is not None
os.remove(benzene_path)


@pytest.mark.skip(reason="Resume this test when ckpt is implemented")
Expand All @@ -170,14 +174,15 @@ def test_packmol_download_only_once(packmol):
path_registry._remove_path_from_json("water")
small_molecules = ["water"]
packmol._get_sm_pdbs(small_molecules)
assert os.path.exists("files/pdb/water.pdb")
water_time = os.path.getmtime("files/pdb/water.pdb")
path_name = find_file_path("water", exact_match=False)
assert path_name is not None
water_time = os.path.getmtime(path_name)
time.sleep(5)

# Call the function again with the same molecule
packmol._get_sm_pdbs(small_molecules)
water_time_after = os.path.getmtime("files/pdb/water.pdb")
water_time_after = os.path.getmtime(path_name)

assert water_time == water_time_after
# Clean up
os.remove("files/pdb/water.pdb")
os.remove(path_name)
15 changes: 14 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
import os
import warnings
from unittest.mock import mock_open, patch

import pytest

from mdagent.utils import FileType, PathRegistry
from mdagent.utils import FileType, PathRegistry, find_file_path

warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")

Expand Down Expand Up @@ -154,3 +155,15 @@ def test_init_path_registry(path_registry_with_mocked_fs):
# you may need to check the internal state or the contents of the JSON file.
# For example:
assert "water_000000" in path_registry_with_mocked_fs.list_path_names()


def test_find_file_path():
file_name = "test_utils.py"
file_path_current = os.path.abspath(file_name)
file_path_test = find_file_path(file_name, exact_match=True)
assert file_path_current == file_path_test

file_name_short = file_name[-4]
file_path_current_short = os.path.abspath(file_name_short)
file_path_test_short = find_file_path(file_name_short, exact_match=False)
assert file_path_current_short == file_path_test_short

0 comments on commit 40761c4

Please sign in to comment.