From 40761c40ab14a6f3df2310dac2214ad4c586435f Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Thu, 22 Feb 2024 13:29:22 -0800 Subject: [PATCH] fixed tests with new file_path function & unit test --- mdagent/utils/__init__.py | 3 +- mdagent/utils/general_utils.py | 22 +++++++++++++++ tests/test_sims_and_clean.py | 51 +++++++++++++++++++--------------- tests/test_utils.py | 15 +++++++++- 4 files changed, 66 insertions(+), 25 deletions(-) create mode 100644 mdagent/utils/general_utils.py diff --git a/mdagent/utils/__init__.py b/mdagent/utils/__init__.py index ef0fa47b..ad59b1e4 100644 --- a/mdagent/utils/__init__.py +++ b/mdagent/utils/__init__.py @@ -1,4 +1,5 @@ +from .general_utils import find_file_path from .makellm import _make_llm from .path_registry import FileType, PathRegistry -__all__ = ["_make_llm", "PathRegistry", "FileType"] +__all__ = ["_make_llm", "PathRegistry", "FileType", "find_file_path"] diff --git a/mdagent/utils/general_utils.py b/mdagent/utils/general_utils.py new file mode 100644 index 00000000..e22764c0 --- /dev/null +++ b/mdagent/utils/general_utils.py @@ -0,0 +1,22 @@ +import os + + +def find_file_path(file_name: str, exact_match: bool = True): + """get the path of a file, if it exists in repo""" + setup_dir = None + for dirpath, dirnames, filenames in os.walk("."): + if "setup.py" in filenames: + setup_dir = dirpath + break + + if setup_dir is None: + raise FileNotFoundError("Unable to find root directory.") + + for dirpath, dirnames, filenames in os.walk(setup_dir): + for filename in filenames: + if (exact_match and filename == file_name) or ( + not exact_match and file_name in filename + ): + return os.path.join(dirpath, filename) + + return None diff --git a/tests/test_sims_and_clean.py b/tests/test_sims_and_clean.py index 1f466677..15d5c29c 100644 --- a/tests/test_sims_and_clean.py +++ b/tests/test_sims_and_clean.py @@ -7,7 +7,7 @@ from mdagent.tools.base_tools import CleaningTools, SimulationFunctions from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool -from mdagent.utils import PathRegistry +from mdagent.utils import PathRegistry, find_file_path warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") @@ -99,13 +99,13 @@ def test_setup_simulation_from_json(mock_json_load, mock_file_open, sim_fxns): def test_small_molecule_pdb(molpdb, get_registry): # Test with a valid SMILES string valid_smiles = "C1=CC=CC=C1" # Benzene - expected_output = ( - "PDB file for C1=CC=CC=C1 successfully created and saved to " - "files/pdb/benzene.pdb." + expected_output_success = "successfully created and saved to " + assert expected_output_success in molpdb.small_molecule_pdb( + valid_smiles, get_registry ) - assert molpdb.small_molecule_pdb(valid_smiles, get_registry) == expected_output - assert os.path.exists("files/pdb/benzene.pdb") - os.remove("files/pdb/benzene.pdb") # Clean up + file_path = find_file_path("benzene", exact_match=False) + assert file_path is not None # assert file was found + os.remove(file_path) # Clean up # test with invalid SMILES string and invalid molecule name invalid_smiles = "C1=CC=CC=C1X" @@ -113,15 +113,17 @@ def test_small_molecule_pdb(molpdb, get_registry): expected_output = ( "There was an error getting pdb. Please input a single molecule name." ) - assert molpdb.small_molecule_pdb(invalid_smiles, get_registry) == expected_output - assert molpdb.small_molecule_pdb(invalid_name, get_registry) == expected_output + assert expected_output in molpdb.small_molecule_pdb(invalid_smiles, get_registry) + assert expected_output in molpdb.small_molecule_pdb(invalid_name, get_registry) # test with valid molecule name valid_name = "water" - assert "successfully" in molpdb.small_molecule_pdb(valid_name, get_registry) - # assert os.path.exists("files/pdb/water.pdb") - if os.path.exists("files/pdb/water.pdb"): - os.remove("files/pdb/water.pdb") + assert expected_output_success in molpdb.small_molecule_pdb( + valid_name, get_registry + ) + file_path = find_file_path("water", exact_match=False) + assert file_path is not None # assert file was found + os.remove(file_path) # Clean up def test_packmol_sm_download_called(packmol): @@ -156,12 +158,14 @@ def test_packmol_download_only(packmol): path_registry._remove_path_from_json("benzene") small_molecules = ["water", "benzene"] packmol._get_sm_pdbs(small_molecules) - # assert os.path.exists("files/pdb/water.pdb") - # assert os.path.exists("files/pdb/benzene.pdb") - if os.path.exists("files/pdb/water.pdb"): - os.remove("files/pdb/water.pdb") - if os.path.exists("files/pdb/benzene.pdb"): - os.remove("files/pdb/benzene.pdb") + + water_path = find_file_path("water", exact_match=False) + assert water_path is not None + os.remove(water_path) + + benzene_path = find_file_path("benzene", exact_match=False) + assert benzene_path is not None + os.remove(benzene_path) @pytest.mark.skip(reason="Resume this test when ckpt is implemented") @@ -170,14 +174,15 @@ def test_packmol_download_only_once(packmol): path_registry._remove_path_from_json("water") small_molecules = ["water"] packmol._get_sm_pdbs(small_molecules) - assert os.path.exists("files/pdb/water.pdb") - water_time = os.path.getmtime("files/pdb/water.pdb") + path_name = find_file_path("water", exact_match=False) + assert path_name is not None + water_time = os.path.getmtime(path_name) time.sleep(5) # Call the function again with the same molecule packmol._get_sm_pdbs(small_molecules) - water_time_after = os.path.getmtime("files/pdb/water.pdb") + water_time_after = os.path.getmtime(path_name) assert water_time == water_time_after # Clean up - os.remove("files/pdb/water.pdb") + os.remove(path_name) diff --git a/tests/test_utils.py b/tests/test_utils.py index 400d64fd..c193b207 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,11 @@ import json +import os import warnings from unittest.mock import mock_open, patch import pytest -from mdagent.utils import FileType, PathRegistry +from mdagent.utils import FileType, PathRegistry, find_file_path warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") @@ -154,3 +155,15 @@ def test_init_path_registry(path_registry_with_mocked_fs): # you may need to check the internal state or the contents of the JSON file. # For example: assert "water_000000" in path_registry_with_mocked_fs.list_path_names() + + +def test_find_file_path(): + file_name = "test_utils.py" + file_path_current = os.path.abspath(file_name) + file_path_test = find_file_path(file_name, exact_match=True) + assert file_path_current == file_path_test + + file_name_short = file_name[-4] + file_path_current_short = os.path.abspath(file_name_short) + file_path_test_short = find_file_path(file_name_short, exact_match=False) + assert file_path_current_short == file_path_test_short