Skip to content

Commit

Permalink
Adding Pre-filter using outlines. 1. Changing versioning for Chroma, …
Browse files Browse the repository at this point in the history
…paperqa, and pydantic to solve dependencies issues, 2. Add a file pre_filter.py in the mainagent directory to contain the filter functionality
  • Loading branch information
Jgmedina95 committed Feb 28, 2024
1 parent 3639361 commit 14b2cc5
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 60 deletions.
7 changes: 5 additions & 2 deletions mdagent/mainagent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ..tools import get_tools, make_all_tools
from .prompt import openaifxn_prompt, structured_prompt
from .query_filter import create_filtered_query

load_dotenv()

Expand Down Expand Up @@ -124,5 +125,7 @@ def _initialize_tools_and_agent(self, user_input=None):
)

def run(self, user_input, callbacks=None):
self.agent = self._initialize_tools_and_agent(user_input)
return self.agent.run(self.prompt.format(input=user_input), callbacks=callbacks)
structured_query = create_filtered_query(user_input, model="gpt-3.5-turbo")
print(structured_query)
# self.agent=self._initialize_tools_and_agent(user_input)
# returnself.agent.run(self.prompt.format(input=user_input),callbacks=callbacks)
152 changes: 152 additions & 0 deletions mdagent/mainagent/query_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

import outlines
from outlines import generate, models
from pydantic import BaseModel


@dataclass
class Parameters:
Temperature: Optional[str]
Pressure: Optional[str]
Time: Optional[str]
ForceField: Optional[str]
WaterModel: Optional[str]
SaltConcentration: Optional[str]
pH: Optional[str]
Solvate: Optional[bool]
Ensemble: Optional[str]
Other_Parameters: Optional[str]


class Task_type(str, Enum):
question = "question"
preprocessing = "preprocessing"
simulation = "simulation"
postnalysis = "postanalysis"


class FilteredQuery(BaseModel):
Main_Task: str
Subtask_types: List[Task_type] # conlist(Task_type, min_length=1)
ProteinS: str
Parameters: Parameters
UserProposedPlan: List[str] # conlist(str, min_length=0]


@dataclass
class Example:
Raw_query: str
Filtered_Query: FilteredQuery


@outlines.prompt
def query_filter(raw_query, examples: list[Example]):
"""You are about to organize an user query. User will
ask for a specific Molecular Dynamics related task, from wich you will
extract:
1. The main task of the query
2. A list of subtasks that are part of the main task
3. The protein of interest mentioned in the raw query (as a PDB ID,
UniProt ID, name, or sequence)
4. Parameters or conditions specified by the user for the simulation
5. The plan proposed by the user for the simulation (if any)
{% for example in examples %}
Raw Query: "{{ example.Raw_query }}"
RESULT: {
"Main_Task": "{{ example.Filtered_Query.Main_Task }}",
"Subtask_types": "{{ example.Filtered_Query.Subtask_types }}",
"ProteinS": "{{ example.Filtered_Query.ProteinS }}",
"Parameters": "{{ example.Filtered_Query.Parameters }}",
"UserProposedPlan": "{{ example.Filtered_Query.UserProposedPlan }}"}
{% endfor %}
Here is the new raw query that you need to filter:
Raw Query: {{raw_query}}
RESULT:
"""


examples = [
Example(
Raw_query="I want a simulation of 1A3N at 280K",
Filtered_Query=FilteredQuery(
Main_Task="Simulate 1A3N at 280K",
Subtask_types=["simulation"],
ProteinS="1A3N",
Parameters=Parameters(
Temperature="280K",
Pressure=None,
Time=None,
ForceField=None,
WaterModel=None,
SaltConcentration=None,
pH=None,
Solvate=None,
Ensemble=None,
Other_Parameters=None,
),
UserProposedPlan=[],
),
),
Example(
Raw_query="What is the best force field for 1A3N?",
Filtered_Query=FilteredQuery(
Main_Task="Answer the question: best force field for 1A3N?",
Subtask_types=["question"],
ProteinS="1A3N",
Parameters=Parameters(
Temperature=None,
Pressure=None,
Time=None,
ForceField=None,
WaterModel=None,
SaltConcentration=None,
pH=None,
Solvate=None,
Ensemble=None,
Other_Parameters=None,
),
UserProposedPlan=[],
),
),
Example(
Raw_query="""Calculate the Radial Distribution Function of 1A3N with
water. Youll have to download the PDB file, clean it, and solvate it
for the simulation. The trajectory and
topology files can be used to calculate the RDF.""",
Filtered_Query=FilteredQuery(
Main_Task="Calculate the Radial Distribution Function of 1A3N with water.",
Subtask_types=["preprocessing", "simulation", "postanalysis"],
ProteinS="1A3N",
Parameters=Parameters(
Temperature=None,
Pressure=None,
Time=None,
ForceField=None,
WaterModel=None,
SaltConcentration=None,
pH=None,
Solvate=True,
Ensemble=None,
Other_Parameters=None,
),
UserProposedPlan=[
"Downlaod PDB file for 1A3N",
"Clean/Pre-process the PDB file",
"Calculate the Radial Distribution Function with water.",
"With the trajectory and topology files, calculate the RDF.",
],
),
),
]


def create_filtered_query(raw_query, model="gpt-3.5-turbo", examples=examples):
filter_model = models.openai(model)
generator = generate.text(filter_model)
return generator(query_filter(raw_query, examples=examples))
10 changes: 2 additions & 8 deletions mdagent/tools/base_tools/preprocess_tools/clean_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from typing import Dict, Optional, Type
from typing import Optional, Type

from langchain.tools import BaseTool
from openmm.app import PDBFile, PDBxFile
from pdbfixer import PDBFixer
from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field

from mdagent.utils import FileType, PathRegistry

Expand Down Expand Up @@ -227,12 +227,6 @@ class CleaningToolFunctionInput(BaseModel):
)
add_hydrogens_ph: int = Field(7.0, description="pH at which hydrogens are added.")

@root_validator
def validate_query(cls, values) -> Dict:
"""Check that the input is valid."""

return values


class CleaningToolFunction(BaseTool):
name = "CleaningToolFunction"
Expand Down
96 changes: 49 additions & 47 deletions mdagent/tools/base_tools/preprocess_tools/pdb_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import re
import sys
import typing
from typing import Any, Dict, Optional, Type, Union
from typing import Dict, Optional, Type

from langchain.tools import BaseTool
from pdbfixer import PDBFixer
from pydantic import BaseModel, Field, ValidationError, root_validator
from pydantic import BaseModel, Field, ValidationError

from mdagent.utils import PathRegistry

Expand Down Expand Up @@ -660,51 +660,53 @@ class PDBFilesFixInp(BaseModel):
),
)

@root_validator
def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict:
if isinstance(values, str):
print("values is a string", values)
raise ValidationError("Input must be a dictionary")

pdbfile = values.get("pdbfiles", "")
occupancy = values.get("occupancy")
tempFactor = values.get("tempFactor")
ElemColum = values.get("ElemColum")

if occupancy is None and tempFactor is None and ElemColum is None:
if pdbfile == "":
return {"error": "No inputs given, failed use of tool."}
else:
return values
else:
if occupancy:
if len(occupancy) != 2:
return {
"error": (
"if you want to fix the occupancy"
"column argument must be a tuple of (bool, float)"
)
}
if not isinstance(occupancy[0], float):
return {"error": "occupancy first arg must be a float"}
if not isinstance(occupancy[1], bool):
return {"error": "occupancy second arg must be a bool"}
if tempFactor:
if len(tempFactor != 2):
return {
"error": (
"if you want to fix the tempFactor"
"column argument must be a tuple of (float, bool)"
)
}
if not isinstance(tempFactor[0], bool):
return {"error": "occupancy first arg must be a float"}
if not isinstance(tempFactor[1], float):
return {"error": "tempFactor second arg must be a float"}
if ElemColum is not None:
if not isinstance(ElemColum[1], bool):
return {"error": "ElemColum must be a bool"}
return values
# @model_validator(mode='before')
# def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict:
# if isinstance(values, str):
# print("values is a string", values)
# raise ValidationError("Input must be a dictionary")


#
# pdbfile = values.get("pdbfiles", "")
# occupancy = values.get("occupancy")
# tempFactor = values.get("tempFactor")
# ElemColum = values.get("ElemColum")
#
# if occupancy is None and tempFactor is None and ElemColum is None:
# if pdbfile == "":
# return {"error": "No inputs given, failed use of tool."}
# else:
# return values
# else:
# if occupancy:
# if len(occupancy) != 2:
# return {
# "error": (
# "if you want to fix the occupancy"
# "column argument must be a tuple of (bool, float)"
# )
# }
# if not isinstance(occupancy[0], float):
# return {"error": "occupancy first arg must be a float"}
# if not isinstance(occupancy[1], bool):
# return {"error": "occupancy second arg must be a bool"}
# if tempFactor:
# if len(tempFactor != 2):
# return {
# "error": (
# "if you want to fix the tempFactor"
# "column argument must be a tuple of (float, bool)"
# )
# }
# if not isinstance(tempFactor[0], bool):
# return {"error": "occupancy first arg must be a float"}
# if not isinstance(tempFactor[1], float):
# return {"error": "tempFactor second arg must be a float"}
# if ElemColum is not None:
# if not isinstance(ElemColum[1], bool):
# return {"error": "ElemColum must be a bool"}
# return values


class FixPDBFile(BaseTool):
Expand Down
8 changes: 5 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
license="MIT",
packages=find_packages(),
install_requires=[
"chromadb==0.3.29",
"chromadb==0.4.24",
"google-search-results",
"langchain==0.0.336",
"langchain_experimental",
Expand All @@ -28,9 +28,11 @@
"tiktoken",
"rdkit",
"streamlit",
"paper-qa",
"openmm",
"paper-qa==4.0.0rc8 ",
"MDAnalysis",
"pydantic>=2.6",
"outlines",
"mdtraj",
"paper-scraper @ git+https://github.com/blackadad/paper-scraper.git",
],
test_suite="tests",
Expand Down

0 comments on commit 14b2cc5

Please sign in to comment.