Skip to content

Commit

Permalink
added some streamlit logging
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 committed Jan 29, 2024
1 parent 4fc06b0 commit e140360
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 0 deletions.
11 changes: 11 additions & 0 deletions mdagent/subagents/subagent_fxns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import os
from typing import Optional

import streamlit as st

from .subagent_setup import SubAgentInitializer, SubAgentSettings


Expand Down Expand Up @@ -76,6 +78,7 @@ def _run_loop(self, task, full_history, skills):
"""
critique = None
print("\n\033[46m action agent is running, writing code\033[0m")
st.markdown("action agent is running, writing code", unsafe_allow_html=True)
success, code, fxn_name, code_output = self.action._run_code(
full_history, task, skills
)
Expand Down Expand Up @@ -126,12 +129,20 @@ def _run_iterations(self, run, task):

# give successful code to tool/skill manager
print("\n\033[46mThe new code is complete, running skill agent\033[0m")
st.markdown(
"The new code is complete, running skill agent",
unsafe_allow_html=True,
)
tool_name = self.skill.add_new_tool(fxn_name, code)
return success, tool_name
iter += 1

# if max iterations reached without success, save failures to file
print("\n\033[46m Max iterations reached, saving failed history to file\033[0m")
st.markdown(
"Max iterations reached, saving failed history to file",
unsafe_allow_html=True,
)
tool_name = None
full_failed = self._add_to_history(
full_history,
Expand Down
13 changes: 13 additions & 0 deletions mdagent/tools/base_tools/analysis_tools/rmsd_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib.pyplot as plt
import MDAnalysis as mda
import numpy as np
import streamlit as st
from langchain.tools import BaseTool
from MDAnalysis.analysis import align, diffusionmap, rms
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -44,15 +45,27 @@ def calculate_rmsd(
if rmsd_type == "rmsd":
if self.ref_file:
print("Calculating 1-D RMSD between two sets of coordinates...")
st.markdown(
"Calculating 1-D RMSD between two sets of coordinates...",
unsafe_allow_html=True,
)
return self.compute_rmsd_2sets(selection=selection)
else:
print("Calculating time-dependent RMSD...")
st.markdown(
"Calculating time-dependent RMSD...", unsafe_allow_html=True
)
return self.compute_rmsd(selection=selection, plot=plot)
elif rmsd_type == "pairwise_rmsd":
print("Calculating pairwise RMSD...")
st.markdown("Calculating pairwise RMSD...", unsafe_allow_html=True)
return self.compute_2d_rmsd(selection=selection, plot_heatmap=plot)
elif rmsd_type == "rmsf":
print("Calculating root mean square fluctuation (RMSF)...")
st.markdown(
"Calculating root mean square fluctuation (RMSF)...",
unsafe_allow_html=True,
)
return self.compute_rmsf(selection=selection, plot=plot)
else:
raise ValueError(
Expand Down
2 changes: 2 additions & 0 deletions mdagent/tools/base_tools/preprocess_tools/pdb_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Dict, List, Optional, Type, Union

import requests
import streamlit as st
from langchain.tools import BaseTool
from pdbfixer import PDBFixer
from pydantic import BaseModel, Field, ValidationError, root_validator
Expand Down Expand Up @@ -39,6 +40,7 @@ def get_pdb(query_string, path_registry=None):
if "result_set" in r.json() and len(r.json()["result_set"]) > 0:
pdbid = r.json()["result_set"][0]["identifier"]
print(f"PDB file found with this ID: {pdbid}")
st.markdown(f"PDB file found with this ID: {pdbid}", unsafe_allow_html=True)
url = f"https://files.rcsb.org/download/{pdbid}.{filetype}"
pdb = requests.get(url)
filename = path_registry.write_file_name(
Expand Down
15 changes: 15 additions & 0 deletions mdagent/tools/base_tools/simulation_tools/setup_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Dict, List, Optional, Type

import langchain
import streamlit as st
from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
Expand Down Expand Up @@ -317,6 +318,7 @@ def _setup_and_run_simulation(self, query, PathRegistry):
Forcefield = Forcefield_files[0]
Water_model = Forcefield_files[1]
print("Setting up forcields :", Forcefield, Water_model)
st.markdown("Setting up forcields", unsafe_allow_html=True)
# check if forcefields end in .xml
if Forcefield.endswith(".xml") and Water_model.endswith(".xml"):
forcefield = ForceField(Forcefield, Water_model)
Expand Down Expand Up @@ -355,6 +357,7 @@ def _setup_and_run_simulation(self, query, PathRegistry):
_timestep,
"fs",
)
st.markdown("Setting up Langevin integrator", unsafe_allow_html=True)
if params["Ensemble"] == "NPT":
_pressure = params["Pressure"].split(" ")[0].strip()
system.addForce(MonteCarloBarostat(_pressure * bar, _temp * kelvin))
Expand All @@ -378,6 +381,7 @@ def _setup_and_run_simulation(self, query, PathRegistry):
"bar",
)
print("Setting up Verlet integrator with Parameters:", _timestep, "fs")
st.markdown("Setting up Verlet integrator", unsafe_allow_html=True)
integrator = VerletIntegrator(float(_timestep) * picoseconds)

simulation = Simulation(modeller.topology, system, integrator)
Expand Down Expand Up @@ -682,6 +686,7 @@ def __init__(

def setup_system(self):
print("Building system...")
st.markdown("Building system", unsafe_allow_html=True)
self.pdb_id = self.params["pdb_id"]
self.pdb_path = self.path_registry.get_mapped_path(name=self.pdb_id)
self.pdb = PDBFile(self.pdb_path)
Expand All @@ -703,6 +708,7 @@ def setup_system(self):

def setup_integrator(self):
print("Setting up integrator...")
st.markdown("Setting up integrator", unsafe_allow_html=True)
int_params = self.int_params
integrator_type = int_params.get("integrator_type", "LangevinMiddle")

Expand All @@ -727,6 +733,7 @@ def setup_integrator(self):

def create_simulation(self):
print("Creating simulation...")
st.markdown("Creating simulation", unsafe_allow_html=True)
self.simulation = Simulation(
self.pdb.topology,
self.system,
Expand Down Expand Up @@ -1049,23 +1056,28 @@ def remove_leading_spaces(text):
file.write(script_content)

print(f"Standalone simulation script written to {directory}/{filename}")
st.markdown("Standalone simulation script written", unsafe_allow_html=True)

def run(self):
# Minimize and Equilibrate
print("Performing energy minimization...")
st.markdown("Performing energy minimization", unsafe_allow_html=True)

self.simulation.minimizeEnergy()
print("Minimization complete!")
st.markdown("Minimization complete! Equilibrating...", unsafe_allow_html=True)
print("Equilibrating...")
_temp = self.int_params["Temperature"]
self.simulation.context.setVelocitiesToTemperature(_temp)
_eq_steps = self.sim_params.get("equilibrationSteps", 1000)
self.simulation.step(_eq_steps)
# Simulate
print("Simulating...")
st.markdown("Simulating...", unsafe_allow_html=True)
self.simulation.currentStep = 0
self.simulation.step(self.sim_params["Number of Steps"])
print("Done!")
st.markdown("Done!", unsafe_allow_html=True)
if not self.save:
if os.path.exists("temp_trajectory.dcd"):
os.remove("temp_trajectory.dcd")
Expand Down Expand Up @@ -1134,6 +1146,7 @@ def _run(self, **input_args):
input, self.path_registry, save, sim_id, pdb_id
)
print("simulation set!")
st.markdown("simulation set!", unsafe_allow_html=True)
except ValueError as e:
return str(e) + f"This were the inputs {input_args}"
except FileNotFoundError:
Expand Down Expand Up @@ -1594,9 +1607,11 @@ def check_system_params(cls, values):
forcefield_files = values.get("forcefield_files")
if forcefield_files is None or forcefield_files is []:
print("Setting default forcefields")
st.markdown("Setting default forcefields", unsafe_allow_html=True)
forcefield_files = ["amber14-all.xml", "amber14/tip3pfb.xml"]
elif len(forcefield_files) == 0:
print("Setting default forcefields v2")
st.markdown("Setting default forcefields", unsafe_allow_html=True)
forcefield_files = ["amber14-all.xml", "amber14/tip3pfb.xml"]
else:
for file in forcefield_files:
Expand Down
7 changes: 7 additions & 0 deletions mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from typing import Optional, Type

import streamlit as st
from dotenv import load_dotenv
from langchain import agents
from langchain.base_language import BaseLanguageModel
Expand Down Expand Up @@ -179,6 +180,10 @@ def get_tools(
print(f"Invalid index {index}.")
print("Some tools may be duplicated.")
print(f"Try to delete vector DB at {ckpt_dir}/all_tools_vectordb.")
st.markdown(
"Invalid index. Some tools may be duplicated Try to delete VDB.",
unsafe_allow_html=True,
)
return retrieved_tools


Expand Down Expand Up @@ -232,6 +237,7 @@ def _run(self, task, orig_prompt, curr_tools, execute=True, args=None):
current_tools=curr_tools,
)
print("running iterator to draft a new tool")
st.markdown("Running iterator to draft a new tool", unsafe_allow_html=True)
tool_name = newcode_iterator.run(task, orig_prompt)
if not tool_name:
return "The 'CreateNewTool' tool failed to build a new tool."
Expand All @@ -242,6 +248,7 @@ def _run(self, task, orig_prompt, curr_tools, execute=True, args=None):
if execute:
try:
print("\nexecuting tool")
st.markdown("Executing tool", unsafe_allow_html=True)
agent_initializer = SubAgentInitializer(self.subagent_settings)
skill = agent_initializer.create_skill_manager(resume=True)
if skill is None:
Expand Down

0 comments on commit e140360

Please sign in to comment.