From beeed4923581fa757f87293805c868bdf3fc75eb Mon Sep 17 00:00:00 2001 From: amolvdeshpande Date: Tue, 8 Oct 2024 16:09:19 -0400 Subject: [PATCH] fixing unit tests --- lib/sycamore/sycamore/query/planner.py | 34 +++++++++++-------- .../sycamore/tests/unit/query/test_plan.py | 2 +- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/lib/sycamore/sycamore/query/planner.py b/lib/sycamore/sycamore/query/planner.py index 0f7acf13b..bb34b47a6 100644 --- a/lib/sycamore/sycamore/query/planner.py +++ b/lib/sycamore/sycamore/query/planner.py @@ -62,7 +62,7 @@ Other than those, DO NOT USE ANY OTHER FIELD NAMES. 5. If an optional field does not have a value in the query plan, return null in its place. 6. If you cannot generate a plan to answer a question, return an empty list. - 7. The first step of each plan MUST be a **QueryDatabase** or **QueryVectorDatabase" operation that returns a + 7. The first step of each plan MUST be a **QueryDatabase** or **QueryVectorDatabase" operation that returns a database. Whenever possible, include all possible filtering operations in the QueryDatabase step. That is, you should strive to construct an OpenSearch query that filters the data as much as possible, reducing the need for further query operations. Use a QueryVectorDatabase step instead of @@ -78,7 +78,8 @@ def process_json_plan( - plan: typing.Union[str, Any], operators: Optional[List[Type[LogicalOperator]]] = None + plan: typing.Union[str, Any], operators: Optional[List[Type[LogicalOperator]]] = None, postProcess: bool = + False ) -> (Tuple)[LogicalOperator, Mapping[int, LogicalOperator]]: """Given the query plan provided by the LLM, return a tuple of (result_node, list of nodes).""" operators = operators or OPERATORS @@ -91,7 +92,10 @@ def process_json_plan( nodes: MutableMapping[int, LogicalOperator] = {} downstream_dependencies: Dict[int, List[int]] = {} - postprocessed_plan = postprocess_json_plan(parsed_plan) + if postProcess: + postprocessed_plan = postprocess_json_plan(parsed_plan) + else: + postprocessed_plan = parsed_plan # 1. Build nodes for step in postprocessed_plan: @@ -136,8 +140,8 @@ def postprocess_llm_helper(user_message: str) -> str: messages = [ { "role": "system", - "content": """You are a helpful agent that assists in small transformations - of input text as per the instructions. You should make minimal changes + "content": """You are a helpful agent that assists in small transformations + of input text as per the instructions. You should make minimal changes to the provided input and keep your response short""", }, {"role": "user", "content": user_message}, @@ -169,9 +173,9 @@ def postprocess_json_plan(parsed_plan: Any) -> Any: modified_description = postprocess_llm_helper( f""" - The following is the description of a Python function. I am modifying the function code - to remove any functionality that has to do with "{op['query_phrase']}". - Return only the modified description. + The following is the description of a Python function. I am modifying the function code + to remove any functionality that has to do with "{op['query_phrase']}". + Return only the modified description. {op['description']}""" ) @@ -204,13 +208,13 @@ def postprocess_json_plan(parsed_plan: Any) -> Any: ) llm_op_question = postprocess_llm_helper( f""" - Generate a one-line true/false question that is appropriate to check whether an input document - satisfies {op['query_phrase']}. Keep it as generic and short as possible. Do not make assumptions - about the intent of the question that are not explicitly specified. + Generate a one-line true/false question that is appropriate to check whether an input document + satisfies {op['query_phrase']}. Keep it as generic and short as possible. Do not make assumptions + about the intent of the question that are not explicitly specified. - Here are two examples: - (1) Was this incident caused by an environmental condition? - (2) Did this incident occur in Georgia? + Here are two examples: + (1) Was this incident caused by an environmental condition? + (2) Did this incident occur in Georgia? """ ) llm_op = { @@ -627,7 +631,7 @@ def plan(self, question: str) -> LogicalPlan: """Given a question from the user, generate a logical query plan.""" llm_prompt, llm_plan = self.generate_from_llm(question) try: - result_node, nodes = process_json_plan(llm_plan, self._operators) + result_node, nodes = process_json_plan(llm_plan, self._operators, True) except Exception as e: logging.error(f"Error processing LLM-generated query plan: {e}\nPlan is:\n{llm_plan}") raise diff --git a/lib/sycamore/sycamore/tests/unit/query/test_plan.py b/lib/sycamore/sycamore/tests/unit/query/test_plan.py index 86d24c58d..c1e9599b1 100644 --- a/lib/sycamore/sycamore/tests/unit/query/test_plan.py +++ b/lib/sycamore/sycamore/tests/unit/query/test_plan.py @@ -159,7 +159,7 @@ def vector_search_filter_plan(): def get_logical_plan(plan): - result_node, nodes = process_json_plan(plan) + result_node, nodes = process_json_plan(plan, postProcess = False) plan = LogicalPlan(result_node=result_node, nodes=nodes, query="", llm_prompt="", llm_plan="") return plan