Skip to content

Commit

Permalink
added ckpt_dir to mdagent agent
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 committed Feb 22, 2024
1 parent f9a1439 commit 6c17f5b
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 22 deletions.
25 changes: 18 additions & 7 deletions mdagent/mainagent/agent.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from dotenv import load_dotenv
import os

from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
from langchain.agents.structured_chat.base import StructuredChatAgent
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI

from mdagent.subagents import SubAgentSettings
from mdagent.utils import PathRegistry, _make_llm
from mdagent.utils import (
PathRegistry,
_make_llm,
make_ckpt_path,
move_files_to_ckpt_path,
)

from ..tools import get_tools, make_all_tools
from .prompt import openaifxn_prompt, structured_prompt

load_dotenv()


class AgentType:
valid_models = {
Expand Down Expand Up @@ -51,11 +55,18 @@ def __init__(
curriculum=True,
uploaded_files=[], # user input files to add to path registry
):
self.ckpt = ckpt_dir
if self.ckpt == "ckpt":
ckpt_path = make_ckpt_path(self.ckpt)
if path_registry is None:
path_registry = PathRegistry.get_instance()
path_registry = PathRegistry.get_instance(ckpt_dir=ckpt_path)
self.uploaded_files = uploaded_files
for file in uploaded_files: # todo -> allow users to add descriptions?
path_registry.map_path(file, file, description="User uploaded file")

for file in uploaded_files: # move files to ckpt path & map
new_path = move_files_to_ckpt_path(file, ckpt_path)
path_registry.map_path(
os.path.basename(new_path), new_path, description="User uploaded file"
)

self.agent_type = agent_type
self.user_tools = tools
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/preprocess_tools/pdb_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict:
}

# Further validation, e.g., checking if files exist
registry = PathRegistry()
registry = PathRegistry.get_instance()
file_ids = registry.list_path_names()

for pdbfile_id in pdbfiles:
Expand Down
10 changes: 9 additions & 1 deletion mdagent/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from .ckpt_handler import get_ckpt_folder_path, make_ckpt_path, move_files_to_ckpt_path
from .makellm import _make_llm
from .path_registry import FileType, PathRegistry

__all__ = ["_make_llm", "PathRegistry", "FileType"]
__all__ = [
"_make_llm",
"PathRegistry",
"FileType",
"make_ckpt_path",
"get_ckpt_folder_path",
"move_files_to_ckpt_path",
]
21 changes: 14 additions & 7 deletions mdagent/utils/ckpt_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import shutil


def find_repo_root(start_path="current"):
def find_repo_root(start_path: str = "current") -> str:
"""
Finds the folder containing setup.py
"""
Expand All @@ -17,21 +18,27 @@ def find_repo_root(start_path="current"):
)


def get_ckpt_folder_path(ckpt_dir="ckpt", start_path="current"):
def get_ckpt_folder_path(ckpt_dir: str = "ckpt", start_path: str = "current") -> str:
"""Returns the path to the ckpt folder in the repository."""
repo_root = find_repo_root(start_path)
return os.path.join(repo_root, ckpt_dir)


def make_ckpt_path():
def make_ckpt_path(ckpt: str = "ckpt") -> str:
root_repo = find_repo_root()
for i in range(10):
ckpt_path = os.path.join(root_repo, "ckpt", f"ckpt_{i}")
for i in range(100):
ckpt_path = os.path.join(root_repo, "ckpt", f"{ckpt}_{i}")
if not os.path.exists(ckpt_path):
# make that folder
os.makedirs(ckpt_path)
return ckpt_path
return "Could not make a new ckpt folder"


if __name__ == "__main__":
print(make_ckpt_path())
def move_files_to_ckpt_path(file: str, ckpt_dir: str) -> str:
"""This function moves files to the current ckpt
directory. It should only be used for user-uploaded files."""
dest_path = os.path.join(ckpt_dir, os.path.basename(file))
shutil.move(file, dest_path)
print(f"Moved {file} to {dest_path}")
return dest_path
14 changes: 8 additions & 6 deletions mdagent/utils/path_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,21 @@ class FileType(Enum):

class PathRegistry:
instance = None
ckpt_dir = "ckpt"

@classmethod
def get_instance(cls):
def get_instance(cls, ckpt_dir=None):
if not cls.instance:
cls.instance = cls()
cls.instance = cls(ckpt_dir)
return cls.instance

def __init__(self):
def __init__(self, ckpt_dir=None):
self.ckpt_dir = ckpt_dir
self.json_file_path = "paths_registry.json"
self._init_path_registry()
self._init_path_registry(ckpt_dir)

def _init_path_registry(self):
base_directory = "files"
def _init_path_registry(self, ckpt_dir):
base_directory = ckpt_dir + "files"
subdirectories = ["pdb", "records", "simulations", "solvents"]
existing_registry = self._load_existing_registry()
file_names_in_registry = []
Expand Down

0 comments on commit 6c17f5b

Please sign in to comment.