From ba4fc9c6239f400105fa71ac0264764cb37f49b7 Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Mon, 5 Feb 2024 12:51:17 -0500 Subject: [PATCH 1/2] added streamlit app with ability to upload files (#74) --- README.md | 7 ++ mdagent/mainagent/agent.py | 5 ++ mdagent/subagents/subagent_fxns.py | 11 +++ .../base_tools/analysis_tools/rmsd_tools.py | 13 +++ .../base_tools/preprocess_tools/pdb_tools.py | 2 + .../simulation_tools/setup_and_run.py | 17 +++- mdagent/tools/maketools.py | 7 ++ setup.py | 1 + st_app.py | 80 +++++++++++++++++++ 9 files changed, 142 insertions(+), 1 deletion(-) create mode 100644 st_app.py diff --git a/README.md b/README.md index b38b46d8..e87e97e1 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,13 @@ Other tools require API keys, such as paper-qa for literature searches. We recom 1. Copy the `.env.example` file and rename it to `.env`: `cp .env.example .env` 2. Replace the placeholder values in `.env` with your actual keys +## Using Streamlit Interface +If you'd like to use MDAgent via the Streamlit app, make sure you have completed the steps above. Then, in your terminal, run `streamlit run st_app.py` from the project root directory. + +From there you may upload files to use during the run. Note: the app currently accepts only .pdb and .cif uploads, and the maximum upload size defaults to 200 MB. +- To upload larger files, instead run `streamlit run st_app.py --server.maxUploadSize=some_large_number` (the value is in megabytes). +- To support additional file types, add the desired extension to the `type` list passed to `st.file_uploader` in the [Streamlit app file](https://github.com/ur-whitelab/md-agent/blob/main/st_app.py). 
+ ## Contributing diff --git a/mdagent/mainagent/agent.py b/mdagent/mainagent/agent.py index 6f8d7f3b..61bff83d 100644 --- a/mdagent/mainagent/agent.py +++ b/mdagent/mainagent/agent.py @@ -47,9 +47,14 @@ def __init__( resume=False, top_k_tools=20, # set "all" if you want to use all tools (& skills if resume) use_human_tool=False, + uploaded_files=[], # user input files to add to path registry ): if path_registry is None: path_registry = PathRegistry.get_instance() + self.uploaded_files = uploaded_files + for file in uploaded_files: # todo -> allow users to add descriptions? + path_registry.map_path(file, file, description="User uploaded file") + self.agent_type = agent_type self.user_tools = tools self.tools_llm = _make_llm(tools_model, temp, verbose) diff --git a/mdagent/subagents/subagent_fxns.py b/mdagent/subagents/subagent_fxns.py index 8b1c84ce..b010ba60 100644 --- a/mdagent/subagents/subagent_fxns.py +++ b/mdagent/subagents/subagent_fxns.py @@ -2,6 +2,8 @@ import os from typing import Optional +import streamlit as st + from .subagent_setup import SubAgentInitializer, SubAgentSettings @@ -76,6 +78,7 @@ def _run_loop(self, task, full_history, skills): """ critique = None print("\n\033[46m action agent is running, writing code\033[0m") + st.markdown("action agent is running, writing code", unsafe_allow_html=True) success, code, fxn_name, code_output = self.action._run_code( full_history, task, skills ) @@ -126,12 +129,20 @@ def _run_iterations(self, run, task): # give successful code to tool/skill manager print("\n\033[46mThe new code is complete, running skill agent\033[0m") + st.markdown( + "The new code is complete, running skill agent", + unsafe_allow_html=True, + ) tool_name = self.skill.add_new_tool(fxn_name, code) return success, tool_name iter += 1 # if max iterations reached without success, save failures to file print("\n\033[46m Max iterations reached, saving failed history to file\033[0m") + st.markdown( + "Max iterations reached, saving failed 
history to file", + unsafe_allow_html=True, + ) tool_name = None full_failed = self._add_to_history( full_history, diff --git a/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py b/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py index b5433f82..684d5f37 100644 --- a/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py +++ b/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py @@ -4,6 +4,7 @@ import matplotlib.pyplot as plt import MDAnalysis as mda import numpy as np +import streamlit as st from langchain.tools import BaseTool from MDAnalysis.analysis import align, diffusionmap, rms from pydantic import BaseModel, Field @@ -44,15 +45,27 @@ def calculate_rmsd( if rmsd_type == "rmsd": if self.ref_file: print("Calculating 1-D RMSD between two sets of coordinates...") + st.markdown( + "Calculating 1-D RMSD between two sets of coordinates...", + unsafe_allow_html=True, + ) return self.compute_rmsd_2sets(selection=selection) else: print("Calculating time-dependent RMSD...") + st.markdown( + "Calculating time-dependent RMSD...", unsafe_allow_html=True + ) return self.compute_rmsd(selection=selection, plot=plot) elif rmsd_type == "pairwise_rmsd": print("Calculating pairwise RMSD...") + st.markdown("Calculating pairwise RMSD...", unsafe_allow_html=True) return self.compute_2d_rmsd(selection=selection, plot_heatmap=plot) elif rmsd_type == "rmsf": print("Calculating root mean square fluctuation (RMSF)...") + st.markdown( + "Calculating root mean square fluctuation (RMSF)...", + unsafe_allow_html=True, + ) return self.compute_rmsf(selection=selection, plot=plot) else: raise ValueError( diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py index fbad3a6f..45dacde9 100644 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Type, Union import requests +import streamlit 
as st from langchain.tools import BaseTool from pdbfixer import PDBFixer from pydantic import BaseModel, Field, ValidationError, root_validator @@ -39,6 +40,7 @@ def get_pdb(query_string, path_registry=None): if "result_set" in r.json() and len(r.json()["result_set"]) > 0: pdbid = r.json()["result_set"][0]["identifier"] print(f"PDB file found with this ID: {pdbid}") + st.markdown(f"PDB file found with this ID: {pdbid}", unsafe_allow_html=True) url = f"https://files.rcsb.org/download/{pdbid}.{filetype}" pdb = requests.get(url) filename = path_registry.write_file_name( diff --git a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py index 62176806..19ffaac6 100644 --- a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py +++ b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Type import langchain +import streamlit as st from langchain.base_language import BaseLanguageModel from langchain.chains import LLMChain from langchain.prompts import PromptTemplate @@ -316,7 +317,8 @@ def _setup_and_run_simulation(self, query, PathRegistry): ] Forcefield = Forcefield_files[0] Water_model = Forcefield_files[1] - print("Setting up forcields :", Forcefield, Water_model) + print("Setting up forcefields :", Forcefield, Water_model) + st.markdown("Setting up forcefields", unsafe_allow_html=True) # check if forcefields end in .xml if Forcefield.endswith(".xml") and Water_model.endswith(".xml"): forcefield = ForceField(Forcefield, Water_model) @@ -355,6 +357,7 @@ def _setup_and_run_simulation(self, query, PathRegistry): _timestep, "fs", ) + st.markdown("Setting up Langevin integrator", unsafe_allow_html=True) if params["Ensemble"] == "NPT": _pressure = params["Pressure"].split(" ")[0].strip() system.addForce(MonteCarloBarostat(_pressure * bar, _temp * kelvin)) @@ -378,6 +381,7 @@ def _setup_and_run_simulation(self, query, PathRegistry): 
"bar", ) print("Setting up Verlet integrator with Parameters:", _timestep, "fs") + st.markdown("Setting up Verlet integrator", unsafe_allow_html=True) integrator = VerletIntegrator(float(_timestep) * picoseconds) simulation = Simulation(modeller.topology, system, integrator) @@ -682,6 +686,7 @@ def __init__( def setup_system(self): print("Building system...") + st.markdown("Building system", unsafe_allow_html=True) self.pdb_id = self.params["pdb_id"] self.pdb_path = self.path_registry.get_mapped_path(name=self.pdb_id) self.pdb = PDBFile(self.pdb_path) @@ -703,6 +708,7 @@ def setup_system(self): def setup_integrator(self): print("Setting up integrator...") + st.markdown("Setting up integrator", unsafe_allow_html=True) int_params = self.int_params integrator_type = int_params.get("integrator_type", "LangevinMiddle") @@ -727,6 +733,7 @@ def setup_integrator(self): def create_simulation(self): print("Creating simulation...") + st.markdown("Creating simulation", unsafe_allow_html=True) self.simulation = Simulation( self.pdb.topology, self.system, @@ -1049,13 +1056,16 @@ def remove_leading_spaces(text): file.write(script_content) print(f"Standalone simulation script written to {directory}/{filename}") + st.markdown("Standalone simulation script written", unsafe_allow_html=True) def run(self): # Minimize and Equilibrate print("Performing energy minimization...") + st.markdown("Performing energy minimization", unsafe_allow_html=True) self.simulation.minimizeEnergy() print("Minimization complete!") + st.markdown("Minimization complete! 
Equilibrating...", unsafe_allow_html=True) print("Equilibrating...") _temp = self.int_params["Temperature"] self.simulation.context.setVelocitiesToTemperature(_temp) @@ -1063,9 +1073,11 @@ def run(self): self.simulation.step(_eq_steps) # Simulate print("Simulating...") + st.markdown("Simulating...", unsafe_allow_html=True) self.simulation.currentStep = 0 self.simulation.step(self.sim_params["Number of Steps"]) print("Done!") + st.markdown("Done!", unsafe_allow_html=True) if not self.save: if os.path.exists("temp_trajectory.dcd"): os.remove("temp_trajectory.dcd") @@ -1134,6 +1146,7 @@ def _run(self, **input_args): input, self.path_registry, save, sim_id, pdb_id ) print("simulation set!") + st.markdown("simulation set!", unsafe_allow_html=True) except ValueError as e: return str(e) + f"This were the inputs {input_args}" except FileNotFoundError: @@ -1594,9 +1607,11 @@ def check_system_params(cls, values): forcefield_files = values.get("forcefield_files") if forcefield_files is None or forcefield_files is []: print("Setting default forcefields") + st.markdown("Setting default forcefields", unsafe_allow_html=True) forcefield_files = ["amber14-all.xml", "amber14/tip3pfb.xml"] elif len(forcefield_files) == 0: print("Setting default forcefields v2") + st.markdown("Setting default forcefields", unsafe_allow_html=True) forcefield_files = ["amber14-all.xml", "amber14/tip3pfb.xml"] else: for file in forcefield_files: diff --git a/mdagent/tools/maketools.py b/mdagent/tools/maketools.py index 5fca1faf..d64f4284 100644 --- a/mdagent/tools/maketools.py +++ b/mdagent/tools/maketools.py @@ -2,6 +2,7 @@ import os from typing import Optional, Type +import streamlit as st from dotenv import load_dotenv from langchain import agents from langchain.base_language import BaseLanguageModel @@ -179,6 +180,10 @@ def get_tools( print(f"Invalid index {index}.") print("Some tools may be duplicated.") print(f"Try to delete vector DB at {ckpt_dir}/all_tools_vectordb.") + st.markdown( + "Invalid 
index. Some tools may be duplicated Try to delete VDB.", + unsafe_allow_html=True, + ) return retrieved_tools @@ -232,6 +237,7 @@ def _run(self, task, orig_prompt, curr_tools, execute=True, args=None): current_tools=curr_tools, ) print("running iterator to draft a new tool") + st.markdown("Running iterator to draft a new tool", unsafe_allow_html=True) tool_name = newcode_iterator.run(task, orig_prompt) if not tool_name: return "The 'CreateNewTool' tool failed to build a new tool." @@ -242,6 +248,7 @@ def _run(self, task, orig_prompt, curr_tools, execute=True, args=None): if execute: try: print("\nexecuting tool") + st.markdown("Executing tool", unsafe_allow_html=True) agent_initializer = SubAgentInitializer(self.subagent_settings) skill = agent_initializer.create_skill_manager(resume=True) if skill is None: diff --git a/setup.py b/setup.py index 531ed583..a25ff0db 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ "requests", "rmrkl", "tiktoken", + "streamlit", ], test_suite="tests", long_description=long_description, diff --git a/st_app.py b/st_app.py new file mode 100644 index 00000000..ab21a360 --- /dev/null +++ b/st_app.py @@ -0,0 +1,80 @@ +import os +from typing import List + +import streamlit as st +from dotenv import load_dotenv +from langchain.callbacks import StreamlitCallbackHandler +from langchain.callbacks.base import BaseCallbackHandler + +from mdagent import MDAgent + +load_dotenv() + + +st_callback = StreamlitCallbackHandler(st.container()) + + +# Streamlit app +st.title("MDAgent") + +# option = st.selectbox("Choose an option:", ["Explore & Learn", "Use Learned Skills"]) +# if option == "Explore & Learn": +# explore = True +# else: +# explore = False + +resume_op = st.selectbox("Resume:", ["False", "True"]) +if resume_op == "True": + resume = True +else: + resume = False + +# for now I'm just going to allow pdb and cif files - we can add more later +uploaded_files = st.file_uploader( + "Upload a .pdb or .cif file", type=["pdb", "cif"], 
accept_multiple_files=True +) +files: List[str] = [] +# write file to disk +if uploaded_files: + for file in uploaded_files: + with open(file.name, "wb") as f: + f.write(file.getbuffer()) + + st.write("Files successfully uploaded!") + uploaded_file = [os.path.join(os.getcwd(), file.name) for file in uploaded_files] +else: + uploaded_file = [] + +mdagent = MDAgent(resume=resume, uploaded_files=uploaded_file) + + +def generate_response(prompt): + result = mdagent.run(prompt) + return result + + +# make new container to store scratch +scratch = st.empty() +scratch.write( + """Hi! I am MDAgent, your MD automation assistant. + How can I help you today?""" +) + + +# This allows streaming of llm tokens +class TokenStreamlitCallbackHandler(BaseCallbackHandler): + def __init__(self, container): + self.container = container + + def on_llm_new_token(self, token, **kwargs): + self.container.write("".join(token)) + + +token_st_callback = TokenStreamlitCallbackHandler(scratch) + +if prompt := st.chat_input(): + st.chat_message("user").write(prompt) + with st.chat_message("assistant"): + st_callback = StreamlitCallbackHandler(st.container()) + response = mdagent.run(prompt, callbacks=[st_callback, token_st_callback]) + st.write(response) From a56cfe8602674003a881ad91a7657230e7bfdc96 Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Mon, 5 Feb 2024 13:12:41 -0500 Subject: [PATCH 2/2] moved element list to its own file (#77) --- .../base_tools/preprocess_tools/elements.py | 224 ++++++++++++ .../base_tools/preprocess_tools/pdb_tools.py | 345 +----------------- 2 files changed, 229 insertions(+), 340 deletions(-) create mode 100644 mdagent/tools/base_tools/preprocess_tools/elements.py diff --git a/mdagent/tools/base_tools/preprocess_tools/elements.py b/mdagent/tools/base_tools/preprocess_tools/elements.py new file mode 100644 index 00000000..445e6eeb --- /dev/null +++ b/mdagent/tools/base_tools/preprocess_tools/elements.py @@ -0,0 +1,224 @@ +list_of_elements = [ + " H", + "He", + "Li", + 
"Be", + " B", + " C", + " N", + " O", + " F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + " P", + " S", + "Cl", + "Ar", + " K", + "Ca", + "Sc", + "Ti", + " V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Ga", + "Ge", + "As", + "Se", + "Br", + "Kr", + "Rb", + "Sr", + " Y", + "Zr", + "Nb", + "Mo", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + " I", + "Xe", + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + " W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Po", + "At", + "Rn", + "Fr", + "Ra", + "Ac", + "Th", + "Pa", + " U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", + "Mt", + "Ds", + "Rg", + "Cn", + "Nh", + "Fl", + "Mc", + "Lv", + "Ts", + "Og", + "HE", + "LI", + "BE", + "NE", + "NA", + "MG", + "AL", + "SI", + "CL", + "AR", + "CA", + "SC", + "TI", + "CR", + "MN", + "FE", + "CO", + "NI", + "CU", + "ZN", + "GA", + "GE", + "AS", + "SE", + "BR", + "KR", + "RB", + "SR", + " Y", + "ZR", + "NB", + "MO", + "TC", + "RU", + "RH", + "PD", + "AG", + "CD", + "IN", + "SN", + "SB", + "TE", + "XE", + "CS", + "BA", + "LA", + "CE", + "PR", + "ND", + "PM", + "SM", + "EU", + "GD", + "TB", + "DY", + "HO", + "ER", + "TM", + "YB", + "LU", + "HF", + "TA", + "RE", + "OS", + "IR", + "PT", + "AU", + "HG", + "TL", + "PB", + "BI", + "PO", + "AT", + "RN", + "FR", + "RA", + "AC", + "TH", + "PA", + "NP", + "PU", + "AM", + "CM", + "BK", + "CF", + "ES", + "FM", + "MD", + "NO", + "LR", + "RF", + "DB", + "SG", + "BH", + "HS", + "MT", + "DS", + "RG", + "CN", + "NH", + "FL", + "MC", + "LV", + "TS", +] diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py index 45dacde9..24171be2 100644 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py +++ 
b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py @@ -13,6 +13,8 @@ from mdagent.utils import FileType, PathRegistry +from .elements import list_of_elements + def get_pdb(query_string, path_registry=None): """ @@ -504,230 +506,8 @@ async def _arun(self, values: str) -> str: class PDBsummarizerfxns: - list_of_elements = [ - " H", - "He", - "Li", - "Be", - " B", - " C", - " N", - " O", - " F", - "Ne", - "Na", - "Mg", - "Al", - "Si", - " P", - " S", - "Cl", - "Ar", - " K", - "Ca", - "Sc", - "Ti", - " V", - "Cr", - "Mn", - "Fe", - "Co", - "Ni", - "Cu", - "Zn", - "Ga", - "Ge", - "As", - "Se", - "Br", - "Kr", - "Rb", - "Sr", - " Y", - "Zr", - "Nb", - "Mo", - "Tc", - "Ru", - "Rh", - "Pd", - "Ag", - "Cd", - "In", - "Sn", - "Sb", - "Te", - " I", - "Xe", - "Cs", - "Ba", - "La", - "Ce", - "Pr", - "Nd", - "Pm", - "Sm", - "Eu", - "Gd", - "Tb", - "Dy", - "Ho", - "Er", - "Tm", - "Yb", - "Lu", - "Hf", - "Ta", - " W", - "Re", - "Os", - "Ir", - "Pt", - "Au", - "Hg", - "Tl", - "Pb", - "Bi", - "Po", - "At", - "Rn", - "Fr", - "Ra", - "Ac", - "Th", - "Pa", - " U", - "Np", - "Pu", - "Am", - "Cm", - "Bk", - "Cf", - "Es", - "Fm", - "Md", - "No", - "Lr", - "Rf", - "Db", - "Sg", - "Bh", - "Hs", - "Mt", - "Ds", - "Rg", - "Cn", - "Nh", - "Fl", - "Mc", - "Lv", - "Ts", - "Og", - "HE", - "LI", - "BE", - "NE", - "NA", - "MG", - "AL", - "SI", - "CL", - "AR", - "CA", - "SC", - "TI", - "CR", - "MN", - "FE", - "CO", - "NI", - "CU", - "ZN", - "GA", - "GE", - "AS", - "SE", - "BR", - "KR", - "RB", - "SR", - " Y", - "ZR", - "NB", - "MO", - "TC", - "RU", - "RH", - "PD", - "AG", - "CD", - "IN", - "SN", - "SB", - "TE", - "XE", - "CS", - "BA", - "LA", - "CE", - "PR", - "ND", - "PM", - "SM", - "EU", - "GD", - "TB", - "DY", - "HO", - "ER", - "TM", - "YB", - "LU", - "HF", - "TA", - "RE", - "OS", - "IR", - "PT", - "AU", - "HG", - "TL", - "PB", - "BI", - "PO", - "AT", - "RN", - "FR", - "RA", - "AC", - "TH", - "PA", - "NP", - "PU", - "AM", - "CM", - "BK", - "CF", - "ES", - "FM", - "MD", - "NO", - "LR", - 
"RF", - "DB", - "SG", - "BH", - "HS", - "MT", - "DS", - "RG", - "CN", - "NH", - "FL", - "MC", - "LV", - "TS", - ] + def __init__(self): + self.list_of_elements = list_of_elements def _record_inf(self, pdbfile): with open(pdbfile, "r") as f: @@ -897,121 +677,6 @@ def pdb_summarizer(pdb_file): def _fix_element_column(pdb_file, custom_element_dict=None): - elements = set( - ( - "H", - "D", - "HE", - "LI", - "BE", - "B", - "C", - "N", - "O", - "F", - "NE", - "NA", - "MG", - "AL", - "SI", - "P", - "S", - "CL", - "AR", - "K", - "CA", - "SC", - "TI", - "V", - "CR", - "MN", - "FE", - "CO", - "NI", - "CU", - "ZN", - "GA", - "GE", - "AS", - "SE", - "BR", - "KR", - "RB", - "SR", - "Y", - "ZR", - "NB", - "MO", - "TC", - "RU", - "RH", - "PD", - "AG", - "CD", - "IN", - "SN", - "SB", - "TE", - "I", - "XE", - "CS", - "BA", - "LA", - "CE", - "PR", - "ND", - "PM", - "SM", - "EU", - "GD", - "TB", - "DY", - "HO", - "ER", - "TM", - "YB", - "LU", - "HF", - "TA", - "W", - "RE", - "OS", - "IR", - "PT", - "AU", - "HG", - "TL", - "PB", - "BI", - "PO", - "AT", - "RN", - "FR", - "RA", - "AC", - "TH", - "PA", - "U", - "NP", - "PU", - "AM", - "CM", - "BK", - "CF", - "ES", - "FM", - "MD", - "NO", - "LR", - "RF", - "DB", - "SG", - "BH", - "HS", - "MT", - ) - ) - records = ("ATOM", "HETATM", "ANISOU") corrected_lines = [] for line in pdb_file: @@ -1027,7 +692,7 @@ def _fix_element_column(pdb_file, custom_element_dict=None): else: element = atom_name[0] - if element not in elements: + if element not in set(list_of_elements): element = " " # empty element in case we cannot assign line = line[:76] + element.rjust(2) + line[78:]