From 14b2cc5ce7eb0622223d84a725784cc6ca9f95f5 Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 27 Feb 2024 23:13:51 -0500 Subject: [PATCH] Adding Pre-filter using outlines. 1. Changing versioning for Chroma, paperqa, and pydantic to solve dependency issues, 2. Add a file query_filter.py in the mainagent directory to contain the filter functionality --- mdagent/mainagent/agent.py | 7 +- mdagent/mainagent/query_filter.py | 152 ++++++++++++++++++ .../preprocess_tools/clean_tools.py | 10 +- .../base_tools/preprocess_tools/pdb_fix.py | 96 +++++------ setup.py | 8 +- 5 files changed, 213 insertions(+), 60 deletions(-) create mode 100644 mdagent/mainagent/query_filter.py diff --git a/mdagent/mainagent/agent.py b/mdagent/mainagent/agent.py index cb9c81bc..6025ad99 100644 --- a/mdagent/mainagent/agent.py +++ b/mdagent/mainagent/agent.py @@ -9,6 +9,7 @@ from ..tools import get_tools, make_all_tools from .prompt import openaifxn_prompt, structured_prompt +from .query_filter import create_filtered_query load_dotenv() @@ -124,5 +125,7 @@ def _initialize_tools_and_agent(self, user_input=None): ) def run(self, user_input, callbacks=None): - self.agent = self._initialize_tools_and_agent(user_input) - return self.agent.run(self.prompt.format(input=user_input), callbacks=callbacks) + structured_query = create_filtered_query(user_input, model="gpt-3.5-turbo") + print(structured_query) + # self.agent=self._initialize_tools_and_agent(user_input) + # returnself.agent.run(self.prompt.format(input=user_input),callbacks=callbacks) diff --git a/mdagent/mainagent/query_filter.py b/mdagent/mainagent/query_filter.py new file mode 100644 index 00000000..05d961ae --- /dev/null +++ b/mdagent/mainagent/query_filter.py @@ -0,0 +1,152 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + +import outlines +from outlines import generate, models +from pydantic import BaseModel + + +@dataclass +class Parameters: + Temperature: Optional[str] + Pressure:
Optional[str] + Time: Optional[str] + ForceField: Optional[str] + WaterModel: Optional[str] + SaltConcentration: Optional[str] + pH: Optional[str] + Solvate: Optional[bool] + Ensemble: Optional[str] + Other_Parameters: Optional[str] + + +class Task_type(str, Enum): + question = "question" + preprocessing = "preprocessing" + simulation = "simulation" + postnalysis = "postanalysis" + + +class FilteredQuery(BaseModel): + Main_Task: str + Subtask_types: List[Task_type] # conlist(Task_type, min_length=1) + ProteinS: str + Parameters: Parameters + UserProposedPlan: List[str] # conlist(str, min_length=0] + + +@dataclass +class Example: + Raw_query: str + Filtered_Query: FilteredQuery + + +@outlines.prompt +def query_filter(raw_query, examples: list[Example]): + """You are about to organize an user query. User will + ask for a specific Molecular Dynamics related task, from wich you will + extract: + 1. The main task of the query + 2. A list of subtasks that are part of the main task + 3. The protein of interest mentioned in the raw query (as a PDB ID, + UniProt ID, name, or sequence) + 4. Parameters or conditions specified by the user for the simulation + 5. 
The plan proposed by the user for the simulation (if any) + + + {% for example in examples %} + Raw Query: "{{ example.Raw_query }}" + RESULT: { + "Main_Task": "{{ example.Filtered_Query.Main_Task }}", + "Subtask_types": "{{ example.Filtered_Query.Subtask_types }}", + "ProteinS": "{{ example.Filtered_Query.ProteinS }}", + "Parameters": "{{ example.Filtered_Query.Parameters }}", + "UserProposedPlan": "{{ example.Filtered_Query.UserProposedPlan }}"} + {% endfor %} + + Here is the new raw query that you need to filter: + Raw Query: {{raw_query}} + RESULT: + """ + + +examples = [ + Example( + Raw_query="I want a simulation of 1A3N at 280K", + Filtered_Query=FilteredQuery( + Main_Task="Simulate 1A3N at 280K", + Subtask_types=["simulation"], + ProteinS="1A3N", + Parameters=Parameters( + Temperature="280K", + Pressure=None, + Time=None, + ForceField=None, + WaterModel=None, + SaltConcentration=None, + pH=None, + Solvate=None, + Ensemble=None, + Other_Parameters=None, + ), + UserProposedPlan=[], + ), + ), + Example( + Raw_query="What is the best force field for 1A3N?", + Filtered_Query=FilteredQuery( + Main_Task="Answer the question: best force field for 1A3N?", + Subtask_types=["question"], + ProteinS="1A3N", + Parameters=Parameters( + Temperature=None, + Pressure=None, + Time=None, + ForceField=None, + WaterModel=None, + SaltConcentration=None, + pH=None, + Solvate=None, + Ensemble=None, + Other_Parameters=None, + ), + UserProposedPlan=[], + ), + ), + Example( + Raw_query="""Calculate the Radial Distribution Function of 1A3N with + water. Youll have to download the PDB file, clean it, and solvate it + for the simulation. 
The trajectory and + topology files can be used to calculate the RDF.""", + Filtered_Query=FilteredQuery( + Main_Task="Calculate the Radial Distribution Function of 1A3N with water.", + Subtask_types=["preprocessing", "simulation", "postanalysis"], + ProteinS="1A3N", + Parameters=Parameters( + Temperature=None, + Pressure=None, + Time=None, + ForceField=None, + WaterModel=None, + SaltConcentration=None, + pH=None, + Solvate=True, + Ensemble=None, + Other_Parameters=None, + ), + UserProposedPlan=[ + "Downlaod PDB file for 1A3N", + "Clean/Pre-process the PDB file", + "Calculate the Radial Distribution Function with water.", + "With the trajectory and topology files, calculate the RDF.", + ], + ), + ), +] + + +def create_filtered_query(raw_query, model="gpt-3.5-turbo", examples=examples): + filter_model = models.openai(model) + generator = generate.text(filter_model) + return generator(query_filter(raw_query, examples=examples)) diff --git a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py index 589a4294..ec7fbe36 100644 --- a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py @@ -1,10 +1,10 @@ import os -from typing import Dict, Optional, Type +from typing import Optional, Type from langchain.tools import BaseTool from openmm.app import PDBFile, PDBxFile from pdbfixer import PDBFixer -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field from mdagent.utils import FileType, PathRegistry @@ -227,12 +227,6 @@ class CleaningToolFunctionInput(BaseModel): ) add_hydrogens_ph: int = Field(7.0, description="pH at which hydrogens are added.") - @root_validator - def validate_query(cls, values) -> Dict: - """Check that the input is valid.""" - - return values - class CleaningToolFunction(BaseTool): name = "CleaningToolFunction" diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py 
b/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py index 4cef4ef0..d5bd9bab 100644 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py @@ -2,11 +2,11 @@ import re import sys import typing -from typing import Any, Dict, Optional, Type, Union +from typing import Dict, Optional, Type from langchain.tools import BaseTool from pdbfixer import PDBFixer -from pydantic import BaseModel, Field, ValidationError, root_validator +from pydantic import BaseModel, Field, ValidationError from mdagent.utils import PathRegistry @@ -660,51 +660,53 @@ class PDBFilesFixInp(BaseModel): ), ) - @root_validator - def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: - if isinstance(values, str): - print("values is a string", values) - raise ValidationError("Input must be a dictionary") - - pdbfile = values.get("pdbfiles", "") - occupancy = values.get("occupancy") - tempFactor = values.get("tempFactor") - ElemColum = values.get("ElemColum") - - if occupancy is None and tempFactor is None and ElemColum is None: - if pdbfile == "": - return {"error": "No inputs given, failed use of tool."} - else: - return values - else: - if occupancy: - if len(occupancy) != 2: - return { - "error": ( - "if you want to fix the occupancy" - "column argument must be a tuple of (bool, float)" - ) - } - if not isinstance(occupancy[0], float): - return {"error": "occupancy first arg must be a float"} - if not isinstance(occupancy[1], bool): - return {"error": "occupancy second arg must be a bool"} - if tempFactor: - if len(tempFactor != 2): - return { - "error": ( - "if you want to fix the tempFactor" - "column argument must be a tuple of (float, bool)" - ) - } - if not isinstance(tempFactor[0], bool): - return {"error": "occupancy first arg must be a float"} - if not isinstance(tempFactor[1], float): - return {"error": "tempFactor second arg must be a float"} - if ElemColum is not None: - if not isinstance(ElemColum[1], bool): - 
return {"error": "ElemColum must be a bool"} - return values + # @model_validator(mode='before') + # def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: + # if isinstance(values, str): + # print("values is a string", values) + # raise ValidationError("Input must be a dictionary") + + +# +# pdbfile = values.get("pdbfiles", "") +# occupancy = values.get("occupancy") +# tempFactor = values.get("tempFactor") +# ElemColum = values.get("ElemColum") +# +# if occupancy is None and tempFactor is None and ElemColum is None: +# if pdbfile == "": +# return {"error": "No inputs given, failed use of tool."} +# else: +# return values +# else: +# if occupancy: +# if len(occupancy) != 2: +# return { +# "error": ( +# "if you want to fix the occupancy" +# "column argument must be a tuple of (bool, float)" +# ) +# } +# if not isinstance(occupancy[0], float): +# return {"error": "occupancy first arg must be a float"} +# if not isinstance(occupancy[1], bool): +# return {"error": "occupancy second arg must be a bool"} +# if tempFactor: +# if len(tempFactor != 2): +# return { +# "error": ( +# "if you want to fix the tempFactor" +# "column argument must be a tuple of (float, bool)" +# ) +# } +# if not isinstance(tempFactor[0], bool): +# return {"error": "occupancy first arg must be a float"} +# if not isinstance(tempFactor[1], float): +# return {"error": "tempFactor second arg must be a float"} +# if ElemColum is not None: +# if not isinstance(ElemColum[1], bool): +# return {"error": "ElemColum must be a bool"} +# return values class FixPDBFile(BaseTool): diff --git a/setup.py b/setup.py index 6564c056..da45656d 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ license="MIT", packages=find_packages(), install_requires=[ - "chromadb==0.3.29", + "chromadb==0.4.24", "google-search-results", "langchain==0.0.336", "langchain_experimental", @@ -28,9 +28,11 @@ "tiktoken", "rdkit", "streamlit", - "paper-qa", - "openmm", + "paper-qa==4.0.0rc8 ", "MDAnalysis", + 
"pydantic>=2.6", + "outlines", + "mdtraj", "paper-scraper @ git+https://github.com/blackadad/paper-scraper.git", ], test_suite="tests",