diff --git a/apps/query-ui/queryui/Sycamore_Query.py b/apps/query-ui/queryui/Sycamore_Query.py index 1d48b3a05..a9573291d 100644 --- a/apps/query-ui/queryui/Sycamore_Query.py +++ b/apps/query-ui/queryui/Sycamore_Query.py @@ -10,6 +10,7 @@ import streamlit as st from streamlit_ace import st_ace +from sycamore import ExecMode from sycamore.executor import sycamore_ray_init from sycamore.query.client import SycamoreQueryClient from sycamore.query.logical_plan import LogicalPlan @@ -62,6 +63,7 @@ def run_query(): s3_cache_path=st.session_state.llm_cache_dir, trace_dir=st.session_state.trace_dir, cache_dir=st.session_state.cache_dir, + exec_mode=ExecMode.LOCAL if st.session_state.local_mode else ExecMode.RAY, ) with st.spinner("Generating plan..."): t1 = time.time() @@ -93,7 +95,7 @@ def run_query(): def main(): argparser = argparse.ArgumentParser() - argparser.add_argument("--external-ray", action="store_true", help="Use external Ray process.") + argparser.add_argument("--local-mode", action="store_true", help="Enable Sycamore local execution mode.") argparser.add_argument( "--index", help="OpenSearch index name to use. If specified, only this index will be queried." ) @@ -111,11 +113,15 @@ def main(): if "llm_cache_dir" not in st.session_state: st.session_state.llm_cache_dir = args.llm_cache_dir + if "local_mode" not in st.session_state: + st.session_state.local_mode = args.local_mode + if "trace_dir" not in st.session_state: st.session_state.trace_dir = args.trace_dir - sycamore_ray_init(address="auto") - client = get_sycamore_query_client() + if not args.local_mode: + sycamore_ray_init(address="auto") + client = get_sycamore_query_client(exec_mode=ExecMode.LOCAL if args.local_mode else ExecMode.RAY) st.title("Sycamore Query") st.write(f"Query cache dir: `{st.session_state.cache_dir}`") @@ -141,11 +147,13 @@ def main(): show_schema(client, st.session_state.index) with st.form("query_form"): st.text_input("Query", key="query") - col1, col2 = st.columns(2) + col1, col2, col3 = st.columns(3) with col1: submitted = st.form_submit_button("Run query") with col2: st.toggle("Plan only", key="plan_only", value=False) + with col3: + st.toggle("Use Ray", key="use_ray", value=True) if submitted: run_query() diff --git a/apps/query-ui/queryui/configuration.py b/apps/query-ui/queryui/configuration.py index e86a08175..e85432d2a 100644 --- a/apps/query-ui/queryui/configuration.py +++ b/apps/query-ui/queryui/configuration.py @@ -1,5 +1,6 @@ from typing import Optional +from sycamore import ExecMode from sycamore.query.client import SycamoreQueryClient """ @@ -24,6 +25,11 @@ def get_sycamore_query_client( - s3_cache_path: Optional[str] = None, trace_dir: Optional[str] = None, cache_dir: Optional[str] = None + s3_cache_path: Optional[str] = None, + trace_dir: Optional[str] = None, + cache_dir: Optional[str] = None, + exec_mode: ExecMode = ExecMode.RAY, ) -> SycamoreQueryClient: - return SycamoreQueryClient(s3_cache_path=s3_cache_path, trace_dir=trace_dir, cache_dir=cache_dir) + return SycamoreQueryClient( + s3_cache_path=s3_cache_path, trace_dir=trace_dir, cache_dir=cache_dir, sycamore_exec_mode=exec_mode + ) diff --git a/apps/query-ui/queryui/main.py b/apps/query-ui/queryui/main.py index 9997f8723..8844f9048 100755 --- a/apps/query-ui/queryui/main.py +++ b/apps/query-ui/queryui/main.py @@ -35,6 +35,9 @@ def ray_init(**ray_args): def main(): argparser = argparse.ArgumentParser() + argparser.add_argument( + "--exec-mode", type=str, choices=["ray", "local"], default="ray", help="Configure Sycamore execution 
mode." + ) argparser.add_argument("--chat", action="store_true", help="Only show the chat demo pane.") argparser.add_argument( "--index", help="OpenSearch index name to use. If specified, only this index will be queried." @@ -83,8 +86,10 @@ def main(): trace_dir = args.trace_dir cmdline_args.extend(["--trace-dir", trace_dir]) - ray_init() - + if args.exec_mode == "ray": + ray_init() + elif args.exec_mode == "local": + cmdline_args.extend(["--local-mode"]) while True: print("Starting streamlit process...", flush=True) # Streamlit requires the -- separator to separate streamlit arguments from script arguments. diff --git a/apps/query-ui/queryui/pages/Chat.py b/apps/query-ui/queryui/pages/Chat.py index bbf1f2bbe..29b189959 100644 --- a/apps/query-ui/queryui/pages/Chat.py +++ b/apps/query-ui/queryui/pages/Chat.py @@ -16,6 +16,7 @@ import streamlit as st import sycamore from sycamore.data import OpenSearchQuery +from sycamore import ExecMode from sycamore.executor import sycamore_ray_init from sycamore.transforms.query import OpenSearchQueryExecutor from sycamore.query.client import SycamoreQueryClient @@ -177,6 +178,7 @@ def query_data_source(query: str, index: str) -> Tuple[Any, Optional[Any], Optio s3_cache_path=st.session_state.llm_cache_dir, trace_dir=st.session_state.trace_dir, cache_dir=st.session_state.cache_dir, + sycamore_exec_mode=ExecMode.LOCAL if st.session_state.local_mode else ExecMode.RAY, ) with st.spinner("Generating plan..."): plan = util.generate_plan(sqclient, query, index, examples=PLANNER_EXAMPLES) @@ -310,6 +312,7 @@ def main(): argparser.add_argument( "--index", help="OpenSearch index name to use. If specified, only this index will be queried." ) + argparser.add_argument("--local-mode", action="store_true", help="Enable Sycamore local execution mode.") argparser.add_argument("--title", type=str, help="Title text.") argparser.add_argument("--cache-dir", type=str, help="Query execution cache dir.") argparser.add_argument("--llm-cache-dir", type=str, help="LLM query cache dir.") @@ -331,6 +334,9 @@ def main(): if "use_cache" not in st.session_state: st.session_state.use_cache = True + if "local_mode" not in st.session_state: + st.session_state.local_mode = args.local_mode + if "next_message_id" not in st.session_state: st.session_state.next_message_id = 0 @@ -343,7 +349,8 @@ def main(): if "trace_dir" not in st.session_state: st.session_state.trace_dir = os.path.join(os.getcwd(), "traces") - sycamore_ray_init(address="auto") + if not args.local_mode: + sycamore_ray_init(address="auto") st.title("Sycamore Query Chat") st.toggle("Use RAG only", key="rag_only") diff --git a/apps/query-ui/queryui/test_configuration.py b/apps/query-ui/queryui/test_configuration.py index 998b19b6b..a3b955b43 100644 --- a/apps/query-ui/queryui/test_configuration.py +++ b/apps/query-ui/queryui/test_configuration.py @@ -7,4 +7,4 @@ def test_configuration_sha(): sha = sha256(bytes).hexdigest() # If the change was intentional, update the hash # Think about whether you have to make the change since everyone with a custom config will need to update it - assert sha == "0d5a24a1edfb9e523814dc847fa34cdcde2d2e78aff03e3f4489755e12be2c54", f"hash mismatch got {sha}" + assert sha == "cf89894116604d4b002f2c5b6c9acf25982bf764310a9a50827608dcdc6b1b2c", f"hash mismatch got {sha}" diff --git a/apps/query-ui/queryui/util.py b/apps/query-ui/queryui/util.py index 5e82609ad..060f157a7 100644 --- a/apps/query-ui/queryui/util.py +++ b/apps/query-ui/queryui/util.py @@ -17,6 +17,8 @@ from queryui.configuration import 
get_sycamore_query_client +from sycamore.data.document import DocumentSource + def get_schema(_client: SycamoreQueryClient, index: str) -> Dict[str, Tuple[str, Set[str]]]: """Return the OpenSearch schema for the given index.""" @@ -81,7 +83,7 @@ def docset_to_string(docset: DocSet, html: bool = True) -> str: if isinstance(doc, MetadataDocument): continue if html: - retval += f"**{doc.properties.get('path')}** page: {doc.properties.get('page_number', 'meta')} \n" + retval += f"**{doc.properties.get('path')}** \n" retval += "| Property | Value |\n" retval += "|----------|-------|\n" @@ -96,13 +98,16 @@ def docset_to_string(docset: DocSet, html: bool = True) -> str: text_content = ( doc.text_representation[:NUM_TEXT_CHARS_GENERATE] if doc.text_representation is not None else None ) - retval += f'*..."{text_content}"...*

' + if text_content: + retval += f'*..."{text_content}"...*

' else: props_dict = doc.properties.get("entity", {}) props_dict.update({p: doc.properties[p] for p in set(doc.properties) - set(BASE_PROPS)}) props_dict["text_representation"] = ( doc.text_representation[:NUM_TEXT_CHARS_GENERATE] if doc.text_representation is not None else None ) + props_dict.get("_doc_source", {}).apply(lambda x: x.value if isinstance(x, DocumentSource) else x) + retval += json.dumps(props_dict, indent=2) + "\n" return retval @@ -242,17 +247,19 @@ def readdata(self): data[col].append(row.get(col)) self.df = pd.DataFrame(data) - def show(self): + def show(self, node): """Render the trace data.""" + st.subheader(f"Node {self.node_id}") + st.markdown(f"*Description: {node.description if node else 'n/a'}*") if self.df is None or not len(self.df): - st.write(f"Result of node {self.node_id} — :red[0] documents") + st.write(":red[0] documents") st.write("No data.") return all_columns = list(self.df.columns) column_order = [c for c in self.COLUMNS if c in all_columns] column_order += [c for c in all_columns if c not in column_order] - st.write(f"Result of node {self.node_id} — **{len(self.df)}** documents") + st.write(f"**{len(self.df)}** documents") st.dataframe(self.df, column_order=column_order) @@ -261,11 +268,29 @@ class QueryTrace: def __init__(self, trace_dir: str): self.trace_dir = trace_dir - self.node_traces = [QueryNodeTrace(trace_dir, node_id) for node_id in sorted(os.listdir(self.trace_dir))] + self.node_traces = [] + self.query_plan = self._get_query_plan(self.trace_dir) + for dir in sorted(os.listdir(self.trace_dir)): + if "metadata" not in dir: + self.node_traces += [QueryNodeTrace(trace_dir, dir)] + + def _get_query_plan(self, trace_dir: str): + metadata_dir = os.path.join(trace_dir, "metadata") + if os.path.isfile(os.path.join(trace_dir, "metadata", "query_plan.json")): + return LogicalPlan.parse_file(os.path.join(metadata_dir, "query_plan.json")) + return None def show(self): - for node_trace in self.node_traces: - node_trace.show() + tab1, tab2 = st.tabs(["Node data", "Query plan"]) + with tab1: + for node_trace in self.node_traces: + node_trace.show(self.query_plan.nodes.get(int(node_trace.node_id), None)) + with tab2: + if self.query_plan is not None: + st.write(f"Query: {self.query_plan.query}") + st.write(self.query_plan) + else: + st.write("No query plan found") @st.fragment diff --git a/lib/sycamore/sycamore/data/document.py b/lib/sycamore/sycamore/data/document.py index 8081c958d..04de7ae8e 100644 --- a/lib/sycamore/sycamore/data/document.py +++ b/lib/sycamore/sycamore/data/document.py @@ -9,9 +9,9 @@ class DocumentSource(Enum): - UNKNOWN: str = "UNKNOWN" - DB_QUERY: str = "DB_QUERY" - DOCUMENT_RECONSTRUCTION_RETRIEVAL: str = "DOCUMENT_RECONSTRUCTION_RETRIEVAL" + UNKNOWN = "UNKNOWN" + DB_QUERY = "DB_QUERY" + DOCUMENT_RECONSTRUCTION_RETRIEVAL = "DOCUMENT_RECONSTRUCTION_RETRIEVAL" class DocumentPropertyTypes: diff --git a/lib/sycamore/sycamore/data/element.py b/lib/sycamore/sycamore/data/element.py index e6cd12c80..d99996846 100644 --- a/lib/sycamore/sycamore/data/element.py +++ b/lib/sycamore/sycamore/data/element.py @@ -89,6 +89,21 @@ def __str__(self) -> str: } return json.dumps(d, indent=2) + def field_to_value(self, field: str) -> Any: + """ + Extracts the value for a particular element field. + + Args: + field: The field in dotted notation to indicate nesting, e.g. properties.schema + + Returns: + The value associated with the document field. + Returns None if field does not exist in document. 
+ """ + from sycamore.utils.nested import dotted_lookup + + return dotted_lookup(self, field) + class ImageElement(Element): def __init__( diff --git a/lib/sycamore/sycamore/query/client.py b/lib/sycamore/sycamore/query/client.py index c3e98f177..016378f6c 100755 --- a/lib/sycamore/sycamore/query/client.py +++ b/lib/sycamore/sycamore/query/client.py @@ -19,10 +19,12 @@ import structlog import sycamore -from sycamore import Context +from sycamore import Context, ExecMode +from sycamore.context import OperationTypes from sycamore.llms.openai import OpenAI, OpenAIModels from sycamore.transforms.embed import SentenceTransformerEmbedder from sycamore.transforms.query import OpenSearchQueryExecutor +from sycamore.transforms.similarity import HuggingFaceTransformersSimilarityScorer from sycamore.utils.cache import cache_from_path from sycamore.utils.import_utils import requires_modules @@ -124,6 +126,7 @@ def __init__( os_client_args: Optional[dict] = None, trace_dir: Optional[str] = None, cache_dir: Optional[str] = None, + sycamore_exec_mode: ExecMode = ExecMode.RAY, ): from opensearchpy import OpenSearch @@ -131,6 +134,7 @@ def __init__( self.os_config = os_config self.trace_dir = trace_dir self.cache_dir = cache_dir + self.sycamore_exec_mode = sycamore_exec_mode # TODO: remove these assertions and simplify the code to get all customization via the # context. @@ -141,7 +145,7 @@ def __init__( raise AssertionError("setting s3_cache_path requires context==None. See Notes in class documentation.") os_client_args = os_client_args or DEFAULT_OS_CLIENT_ARGS - self.context = context or self._get_default_context(s3_cache_path, os_client_args) + self.context = context or self._get_default_context(s3_cache_path, os_client_args, sycamore_exec_mode) assert self.context.params, "Could not find required params in Context" self.os_client_args = self.context.params.get("opensearch", {}).get("os_client_args", os_client_args) @@ -242,17 +246,25 @@ def default_text_embedder(): return SentenceTransformerEmbedder(batch_size=100, model_name="sentence-transformers/all-MiniLM-L6-v2") @staticmethod - def _get_default_context(s3_cache_path, os_client_args) -> Context: + def _get_default_context(s3_cache_path, os_client_args, sycamore_exec_mode) -> Context: context_params = { "default": {"llm": OpenAI(OpenAIModels.GPT_4O.value, cache=cache_from_path(s3_cache_path))}, "opensearch": { "os_client_args": os_client_args, "text_embedder": SycamoreQueryClient.default_text_embedder(), }, + OperationTypes.BINARY_CLASSIFIER: { + "llm": OpenAI(OpenAIModels.GPT_4O_MINI.value, cache=cache_from_path(s3_cache_path)) + }, + OperationTypes.INFORMATION_EXTRACTOR: { + "llm": OpenAI(OpenAIModels.GPT_4O_MINI.value, cache=cache_from_path(s3_cache_path)) + }, + OperationTypes.TEXT_SIMILARITY: {"similarity_scorer": HuggingFaceTransformersSimilarityScorer()}, } - return sycamore.init(params=context_params) + return sycamore.init(params=context_params, exec_mode=sycamore_exec_mode) - def result_to_str(self, result: Any, max_docs: int = 100, max_chars_per_doc: int = 2500) -> str: + @staticmethod + def result_to_str(result: Any, max_docs: int = 100, max_chars_per_doc: int = 2500) -> str: """Convert a query result to a string. 
Args: diff --git a/lib/sycamore/sycamore/query/execution/physical_operator.py b/lib/sycamore/sycamore/query/execution/physical_operator.py index 411ed65bd..d86b5f3d5 100644 --- a/lib/sycamore/sycamore/query/execution/physical_operator.py +++ b/lib/sycamore/sycamore/query/execution/physical_operator.py @@ -82,12 +82,12 @@ def execute(self) -> Any: def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]: assert isinstance(self.logical_node, Math) - assert self.logical_node.dependencies is not None and len(self.logical_node.dependencies) == 2 + assert self.logical_node.get_dependencies() is not None and len(self.logical_node.get_dependencies()) == 2 operator = self.logical_node.operation result = f""" {output_var or get_var_name(self.logical_node)} = math_operation( - val1={input_var or get_var_name(self.logical_node.dependencies[0])}, - val2={input_var or get_var_name(self.logical_node.dependencies[1])}, + val1={input_var or get_var_name(self.logical_node.get_dependencies()[0])}, + val2={input_var or get_var_name(self.logical_node.get_dependencies()[1])}, operator='{operator}' ) """ diff --git a/lib/sycamore/sycamore/query/execution/sycamore_executor.py b/lib/sycamore/sycamore/query/execution/sycamore_executor.py index 1cb8452f0..71f55a9c8 100644 --- a/lib/sycamore/sycamore/query/execution/sycamore_executor.py +++ b/lib/sycamore/sycamore/query/execution/sycamore_executor.py @@ -110,8 +110,8 @@ def process_node(self, logical_node: LogicalOperator, query_id: str) -> Any: cache_dir = None # Process dependencies - if logical_node.dependencies: - for dependency in logical_node.dependencies: + if logical_node.get_dependencies(): + for dependency in logical_node.get_dependencies(): assert isinstance(dependency, LogicalOperator) inputs += [self.process_node(dependency, query_id)] @@ -212,7 +212,7 @@ def process_node(self, logical_node: LogicalOperator, query_id: str) -> Any: raise ValueError(f"Unsupported node type: {str(logical_node)}") code, imports = operation.script( - output_var=(self.OUTPUT_VAR_NAME if not logical_node.downstream_nodes else None) + output_var=(self.OUTPUT_VAR_NAME if not logical_node.get_downstream_nodes() else None) ) self.imports += imports self.node_id_to_code[logical_node.node_id] = code @@ -257,12 +257,24 @@ def get_code_string(self): """ return result + def _write_query_plan_to_trace_dir(self, plan: LogicalPlan, query_id: str): + assert self.trace_dir is not None, "Writing query_plan requires trace_dir to be set" + path = os.path.join(self.trace_dir, query_id, "metadata") + os.makedirs(path, exist_ok=True) + with open(os.path.join(path, "query_plan.json"), "w") as f: + f.write(plan.model_dump_json()) + def execute(self, plan: LogicalPlan, query_id: Optional[str] = None) -> Any: try: """Execute a logical plan using Sycamore.""" if not query_id: query_id = str(uuid.uuid4()) bind_contextvars(query_id=query_id) + + log.info("Writing query plan to trace dir") + if self.trace_dir: + self._write_query_plan_to_trace_dir(plan, query_id) + log.info("Executing query") assert isinstance(plan.result_node, LogicalOperator) result = self.process_node(plan.result_node, query_id) diff --git a/lib/sycamore/sycamore/query/execution/sycamore_operator.py b/lib/sycamore/sycamore/query/execution/sycamore_operator.py index d2367b8fa..61729205a 100644 --- a/lib/sycamore/sycamore/query/execution/sycamore_operator.py +++ b/lib/sycamore/sycamore/query/execution/sycamore_operator.py @@ -96,7 +96,9 @@ def execute(self) -> Any: os_query = {"query": 
self.logical_node.query} else: os_query = {} - result = self.context.read.opensearch(index_name=self.logical_node.index, query=os_query) + result = self.context.read.opensearch( + index_name=self.logical_node.index, query=os_query, reconstruct_document=True + ) return result def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]: @@ -108,7 +110,7 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No return ( f""" {output_var or get_var_name(self.logical_node)} = context.read.opensearch( - index_name='{self.logical_node.index}', query={os_query} + index_name='{self.logical_node.index}', query={os_query}, reconstruct_document=True ) """, [], @@ -141,7 +143,9 @@ def execute(self) -> Any: os_query = get_knn_query(query_phrase=self.logical_node.query_phrase, context=self.context) if self.logical_node.opensearch_filter: os_query["query"]["knn"]["embedding"]["filter"] = self.logical_node.opensearch_filter - result = self.context.read.opensearch(index_name=self.logical_node.index, query=os_query) + result = self.context.read.opensearch( + index_name=self.logical_node.index, query=os_query, reconstruct_document=True + ).rerank(query=self.logical_node.query_phrase) return result def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]: @@ -153,8 +157,9 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No result += f""" {output_var or get_var_name(self.logical_node)} = context.read.opensearch( index_name='{self.logical_node.index}', - query=os_query -) + query=os_query, + reconstruct_document=True +).rerank(query={self.logical_node.query_phrase}) """ return ( result, @@ -198,12 +203,12 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No assert isinstance(self.logical_node, SummarizeData) question = self.logical_node.question description = self.logical_node.description - assert self.logical_node.dependencies is not None and len(self.logical_node.dependencies) >= 1 + assert self.logical_node.get_dependencies() is not None and len(self.logical_node.get_dependencies()) >= 1 logical_deps_str = "" - for i, inp in enumerate(self.logical_node.dependencies): + for i, inp in enumerate(self.logical_node.get_dependencies()): logical_deps_str += input_var or get_var_name(inp) - if i != len(self.logical_node.dependencies) - 1: + if i != len(self.logical_node.get_dependencies()) - 1: logical_deps_str += ", " result = f""" @@ -223,7 +228,8 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No class SycamoreLlmFilter(SycamoreOperator): """ Use an LLM to filter records on a Docset. - Args: + If field == text_representation, the filter is run + on the elements of the document (i.e. 
use_elements = True) """ def __init__( @@ -245,7 +251,7 @@ def execute(self) -> Any: context=self.context, val_key="llm", param_names=[OperationTypes.BINARY_CLASSIFIER.value] ), LLM, - ), "LLMFilter requies an 'llm' configured on the Context" + ), "SyamoreLlmFilter requires an 'llm' configured on the Context" question = self.logical_node.question field = self.logical_node.field @@ -257,14 +263,15 @@ def execute(self) -> Any: new_field="_autogen_LLMFilterOutput", prompt=prompt, field=field, + use_elements=(field == "text_representation"), **self.get_node_args(), ) return result def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]: - assert self.logical_node.dependencies is not None and len(self.logical_node.dependencies) == 1 + assert self.logical_node.get_dependencies() is not None and len(self.logical_node.get_dependencies()) == 1 assert isinstance(self.logical_node, LlmFilter) - input_str = input_var or get_var_name(self.logical_node.dependencies[0]) + input_str = input_var or get_var_name(self.logical_node.get_dependencies()[0]) output_str = output_var or get_var_name(self.logical_node) result = f""" prompt = LlmFilterMessagesPrompt(filter_question='{self.logical_node.question}').as_messages() @@ -272,6 +279,7 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No new_field='_autogen_LLMFilterOutput', prompt=prompt, field='{self.logical_node.field}', + use_elements={(self.logical_node.field == "text_representation")}, **{self.get_node_args()}, ) """ @@ -321,10 +329,10 @@ def execute(self) -> Any: def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]: assert isinstance(self.logical_node, BasicFilter) - assert self.logical_node.dependencies is not None and len(self.logical_node.dependencies) == 1 + assert self.logical_node.get_dependencies() is not None and len(self.logical_node.get_dependencies()) == 1 imports: list[str] = [] - input_str = input_var or get_var_name(self.logical_node.dependencies[0]) + input_str = input_var or get_var_name(self.logical_node.get_dependencies()[0]) output_str = output_var or get_var_name(self.logical_node) if self.logical_node.range_filter: field = self.logical_node.field @@ -391,15 +399,15 @@ def execute(self) -> Any: def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]: assert isinstance(self.logical_node, Count) - assert self.logical_node.dependencies is not None and len(self.logical_node.dependencies) == 1 + assert self.logical_node.get_dependencies() is not None and len(self.logical_node.get_dependencies()) == 1 distinct_field = self.logical_node.distinct_field imports: list[str] = [] script = f"""{output_var or get_var_name(self.logical_node)} =""" if distinct_field is None: - script += f"""{input_var or get_var_name(self.logical_node.dependencies[0])}.count(""" + script += f"""{input_var or get_var_name(self.logical_node.get_dependencies()[0])}.count(""" else: - script += f"""{input_var or get_var_name(self.logical_node.dependencies[0])}.count_distinct(""" + script += f"""{input_var or get_var_name(self.logical_node.get_dependencies()[0])}.count_distinct(""" script += f"""field='{distinct_field}', """ script += f"""**{get_str_for_dict(self.get_execute_args())})""" return script, imports @@ -432,7 +440,7 @@ def execute(self) -> Any: context=self.context, val_key="llm", param_names=[OperationTypes.INFORMATION_EXTRACTOR.value] ), LLM, - ), "LLMExtractEntity 
requies an 'llm' configured on the Context" + ), "LLMExtractEntity requires an 'llm' configured on the Context" question = logical_node.question new_field = logical_node.new_field @@ -446,7 +454,7 @@ def execute(self) -> Any: entity_extractor = OpenAIEntityExtractor( entity_name=new_field, - use_elements=False, + use_elements=True, prompt=prompt, field=field, ) @@ -461,9 +469,9 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No field = logical_node.field fmt = logical_node.new_field_type discrete = logical_node.discrete - assert logical_node.dependencies is not None and len(logical_node.dependencies) == 1 + assert logical_node.get_dependencies() is not None and len(logical_node.get_dependencies()) == 1 - input_str = input_var or get_var_name(logical_node.dependencies[0]) + input_str = input_var or get_var_name(logical_node.get_dependencies()[0]) output_str = output_var or get_var_name(logical_node) result = f""" @@ -473,7 +481,7 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No entity_extractor = OpenAIEntityExtractor( entity_name='{new_field}', - use_elements=False, + use_elements=True, prompt=prompt, field='{field}', ) @@ -524,10 +532,10 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No descending = logical_node.descending field = logical_node.field default_value = logical_node.default_value - assert logical_node.dependencies is not None and len(logical_node.dependencies) == 1 + assert logical_node.get_dependencies() is not None and len(logical_node.get_dependencies()) == 1 result = f""" -{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.dependencies[0])}.sort( +{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.get_dependencies()[0])}.sort( descending={descending}, field='{field}' default_val={default_value} @@ -538,6 +546,7 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No class SycamoreTopK(SycamoreOperator): """ + Note: top_k clustering only operators on properties, it will not cluster on text_representation currently. 
Return the Top-K values from a DocSet """ @@ -550,6 +559,9 @@ def __init__( trace_dir: Optional[str] = None, ) -> None: super().__init__(context, logical_node, query_id, inputs, trace_dir=trace_dir) + assert ( + self.logical_node.primary_field != "text_representation" # type: ignore[attr-defined] + ), "TopK can only operate on properties" def execute(self) -> Any: assert self.inputs and len(self.inputs) == 1, "TopK requires 1 input node" @@ -580,10 +592,10 @@ def execute(self) -> Any: def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]: logical_node = self.logical_node assert isinstance(logical_node, TopK) - assert logical_node.dependencies is not None and len(logical_node.dependencies) == 1 + assert logical_node.get_dependencies() is not None and len(logical_node.get_dependencies()) == 1 result = f""" -{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.dependencies[0])}.top_k( +{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.get_dependencies()[0])}.top_k( field='{logical_node.field}', k={logical_node.K}, descending={logical_node.descending}, @@ -636,15 +648,15 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No assert isinstance(logical_node, FieldIn) field1 = logical_node.field_one field2 = logical_node.field_two - assert logical_node.dependencies is not None and len(logical_node.dependencies) == 2 + assert logical_node.get_dependencies() is not None and len(logical_node.get_dependencies()) == 2 result = f""" -{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.dependencies[0])}.field_in( - docset2={input_var or get_var_name(logical_node.dependencies[2])}, +{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.get_dependencies()[0])}.field_in( + docset2={input_var or get_var_name(logical_node.get_dependencies()[2])}, field1='{field1}', field2='{field2}' ) -""" +""" # noqa: E501 return result, [] @@ -676,10 +688,10 @@ def execute(self) -> Any: def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]: logical_node = self.logical_node assert isinstance(logical_node, Limit) - assert logical_node.dependencies is not None and len(logical_node.dependencies) == 1 + assert logical_node.get_dependencies() is not None and len(logical_node.get_dependencies()) == 1 result = f""" -{output_var or get_var_name(logical_node)} = {input_var or get_var_name(logical_node.dependencies[0])}.limit( +{output_var or get_var_name(logical_node)} = {input_var or get_var_name(logical_node.get_dependencies()[0])}.limit( {logical_node.num_records}, **{get_str_for_dict(self.get_execute_args())}, ) diff --git a/lib/sycamore/sycamore/query/logical_plan.py b/lib/sycamore/sycamore/query/logical_plan.py index eac8f5b07..35dea21a5 100644 --- a/lib/sycamore/sycamore/query/logical_plan.py +++ b/lib/sycamore/sycamore/query/logical_plan.py @@ -5,7 +5,7 @@ from hashlib import sha256 -from pydantic import BaseModel, ConfigDict, SerializeAsAny +from pydantic import BaseModel, ConfigDict, SerializeAsAny, computed_field def exclude_from_comparison(func): @@ -33,6 +33,9 @@ class Node(BaseModel): node_id: int """A unique integer ID representing this node.""" + description: Optional[str] = None + """A detailed description of why this operator was chosen for this query plan.""" + # These are underscored here to prevent them from leaking out to the 
# input_schema used by the planner. @@ -40,16 +43,24 @@ class Node(BaseModel): _downstream_nodes: List["Node"] = [] _cache_key: Optional[str] = None - @property - def dependencies(self) -> Optional[List["Node"]]: + def get_dependencies(self) -> List["Node"]: """The nodes that this node depends on.""" return self._dependencies - @property - def downstream_nodes(self) -> Optional[List["Node"]]: + def get_downstream_nodes(self) -> List["Node"]: """The nodes that depend on this node.""" return self._downstream_nodes + @property + @computed_field + def dependencies(self) -> List[int]: + return [dep.node_id for dep in self._dependencies] + + @property + @computed_field + def downstream_nodes(self) -> List[int]: + return [dep.node_id for dep in self._downstream_nodes] + def __str__(self) -> str: return f"Id: {self.node_id} Op: {type(self).__name__}" diff --git a/lib/sycamore/sycamore/query/operators/logical_operator.py b/lib/sycamore/sycamore/query/operators/logical_operator.py index ce0517684..c5bd214c7 100644 --- a/lib/sycamore/sycamore/query/operators/logical_operator.py +++ b/lib/sycamore/sycamore/query/operators/logical_operator.py @@ -16,9 +16,6 @@ class LogicalOperator(Node): Logical operator class for LLM prompting. """ - description: Optional[str] = None - """A detailed description of why this operator was chosen for this query plan.""" - input: Optional[List[int]] = None """A list of node IDs that this operation depends on.""" diff --git a/lib/sycamore/sycamore/query/visualize.py b/lib/sycamore/sycamore/query/visualize.py index 256cc8247..da7300388 100644 --- a/lib/sycamore/sycamore/query/visualize.py +++ b/lib/sycamore/sycamore/query/visualize.py @@ -12,8 +12,8 @@ def build_graph(plan: LogicalPlan): else: description = None graph.add_node(node.node_id, description=f"{type(node).__name__}\n{description}") - if node.dependencies: - for dep in node.dependencies: + if node.get_dependencies(): + for dep in node.get_dependencies(): graph.add_edge(dep.node_id, node.node_id) return graph diff --git a/lib/sycamore/sycamore/tests/integration/query/execution/test_sycamore_query.py b/lib/sycamore/sycamore/tests/integration/query/execution/test_sycamore_query.py index 57794e043..7b913f84d 100644 --- a/lib/sycamore/sycamore/tests/integration/query/execution/test_sycamore_query.py +++ b/lib/sycamore/sycamore/tests/integration/query/execution/test_sycamore_query.py @@ -76,7 +76,7 @@ def test_vector_search(self, query_integration_test_index: str, codegen_mode: bo "were there any environmentally caused incidents?", query_integration_test_index, schema, - natural_language_response=True, + natural_language_response=False, ) assert len(plan.nodes) == 2 assert isinstance(plan.nodes[0], QueryVectorDatabase) diff --git a/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py b/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py index 42020aaf1..5ff8b2689 100644 --- a/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py +++ b/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py @@ -144,6 +144,7 @@ def test_vector_query_database(): mock_docset_reader_impl.opensearch.assert_called_once_with( index_name=context.params["opensearch"]["index_name"], query={"query": {"knn": {"embedding": {"vector": embedding, "k": 500, "filter": os_filter}}}}, + reconstruct_document=True, ) @@ -188,6 +189,7 @@ def test_llm_filter(): prompt=ANY, field=logical_node.field, name=str(logical_node.node_id), + use_elements=False, ) assert result == 
return_doc_set @@ -314,7 +316,7 @@ def test_llm_extract_entity(): # assert OpenAIEntityExtractor called with expected arguments MockOpenAIEntityExtractor.assert_called_once_with( entity_name=logical_node.new_field, - use_elements=False, + use_elements=True, prompt=ANY, field=logical_node.field, ) diff --git a/lib/sycamore/sycamore/tests/unit/query/test_logical_operator.py b/lib/sycamore/sycamore/tests/unit/query/test_logical_operator.py index 8f5fbbabc..d2d421199 100644 --- a/lib/sycamore/sycamore/tests/unit/query/test_logical_operator.py +++ b/lib/sycamore/sycamore/tests/unit/query/test_logical_operator.py @@ -33,7 +33,12 @@ def test_node_cache_dict(): assert node4.cache_dict() == { "operator_type": "Count", "dependencies": [ - {"operator_type": "QueryDatabase", "dependencies": [], "index": "ntsb", "query": {"match_all": {}}} + { + "operator_type": "QueryDatabase", + "dependencies": [], + "index": "ntsb", + "query": {"match_all": {}}, + } ], "distinct_field": "temperature", } diff --git a/lib/sycamore/sycamore/tests/unit/query/test_plan.py b/lib/sycamore/sycamore/tests/unit/query/test_plan.py index 86d24c58d..408a6f932 100644 --- a/lib/sycamore/sycamore/tests/unit/query/test_plan.py +++ b/lib/sycamore/sycamore/tests/unit/query/test_plan.py @@ -212,5 +212,6 @@ def test_compare_plans_structure_changed(llm_filter_plan): assert diff[0].diff_type == LogicalNodeDiffType.PLAN_STRUCTURE assert isinstance(diff[0].node_a, LlmFilter) assert isinstance(diff[0].node_b, LlmFilter) - assert len(diff[0].node_a.downstream_nodes) == 1 - assert len(diff[0].node_b.downstream_nodes) == 0 + assert len(diff[0].node_a.get_downstream_nodes()) == 1 + assert diff[0].node_a.get_downstream_nodes()[0].node_id == 2 + assert len(diff[0].node_b.get_downstream_nodes()) == 0 diff --git a/lib/sycamore/sycamore/tests/unit/query/test_planner.py b/lib/sycamore/sycamore/tests/unit/query/test_planner.py index 5e2be7dec..bc683f04e 100644 --- a/lib/sycamore/sycamore/tests/unit/query/test_planner.py +++ b/lib/sycamore/sycamore/tests/unit/query/test_planner.py @@ -108,18 +108,21 @@ def mock_generate_from_llm(self, query): plan = planner.plan("Dummy query") assert plan.result_node.node_id == 3 assert plan.result_node.description == "Generate an English response to the question" - assert len(plan.result_node.dependencies) == 1 - assert plan.result_node.dependencies[0].node_id == 2 - assert plan.result_node.dependencies[0].description == "Determine how many incidents occurred in Piper aircrafts" - assert len(plan.result_node.dependencies[0].dependencies) == 1 - assert plan.result_node.dependencies[0].dependencies[0].node_id == 1 + assert len(plan.result_node.get_dependencies()) == 1 + assert plan.result_node.get_dependencies()[0].node_id == 2 assert ( - plan.result_node.dependencies[0].dependencies[0].description + plan.result_node.get_dependencies()[0].description == "Determine how many incidents occurred in Piper aircrafts" + ) + assert len(plan.result_node.get_dependencies()[0].get_dependencies()) == 1 + assert plan.result_node.get_dependencies()[0].get_dependencies()[0].node_id == 1 + assert ( + plan.result_node.get_dependencies()[0].get_dependencies()[0].description == "Filter to only include Piper aircraft incidents" ) - assert len(plan.result_node.dependencies[0].dependencies[0].dependencies) == 1 - assert plan.result_node.dependencies[0].dependencies[0].dependencies[0].node_id == 0 - assert plan.result_node.dependencies[0].dependencies[0].dependencies[0].query == {"match_all": {}} + assert 
len(plan.result_node.get_dependencies()[0].get_dependencies()[0].get_dependencies()) == 1 + assert plan.result_node.get_dependencies()[0].get_dependencies()[0].get_dependencies()[0].node_id == 0 + assert plan.result_node.get_dependencies()[0].get_dependencies()[0].get_dependencies()[0].query == {"match_all": {}} assert ( - plan.result_node.dependencies[0].dependencies[0].dependencies[0].description == "Get all the airplane incidents" + plan.result_node.get_dependencies()[0].get_dependencies()[0].get_dependencies()[0].description + == "Get all the airplane incidents" ) diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_entity_extraction.py b/lib/sycamore/sycamore/tests/unit/transforms/test_entity_extraction.py index 50aec433f..e865df266 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_entity_extraction.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_entity_extraction.py @@ -19,6 +19,11 @@ def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None): return "alt_title" if prompt_kwargs == {"prompt": "s3://path"} and llm_kwargs == {}: return "alt_title" + if ( + prompt_kwargs == {"messages": [{"role": "user", "content": "ELEMENT 1: Jack Black\nELEMENT 2: None\n"}]} + and llm_kwargs == {} + ): + return "Jack Black" return "title" def is_chat_mode(self): @@ -38,7 +43,7 @@ class TestEntityExtraction: { "type": "title", "content": {"binary": None, "text": "text1"}, - "properties": {"coordinates": [(1, 2)], "page_number": 1}, + "properties": {"coordinates": [(1, 2)], "page_number": 1, "entity": {"author": "Jack Black"}}, }, { "type": "table", @@ -59,6 +64,18 @@ def test_extract_entity_zero_shot(self, mocker): output_dataset = extract_entity.execute() assert Document.from_row(output_dataset.take(1)[0]).properties.get("title") == "title" + def test_extract_entity_zero_shot_custom_field(self, mocker): + node = mocker.Mock(spec=Node) + llm = MockLLM() + extract_entity = ExtractEntity( + node, entity_extractor=OpenAIEntityExtractor("title", llm=llm, field="properties.entity.author") + ) + input_dataset = ray.data.from_items([{"doc": self.doc.serialize()}]) + execute = mocker.patch.object(node, "execute") + execute.return_value = input_dataset + output_dataset = extract_entity.execute() + assert Document.from_row(output_dataset.take(1)[0]).properties.get("title") == "Jack Black" + def test_extract_entity_w_context_llm(self, mocker): node = mocker.Mock(spec=Node) llm = MockLLM() diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py index d5402ad67..bdb95a053 100644 --- a/lib/sycamore/sycamore/transforms/extract_entity.py +++ b/lib/sycamore/sycamore/transforms/extract_entity.py @@ -13,10 +13,11 @@ from sycamore.utils.time_trace import timetrace -def element_list_formatter(elements: list[Element]) -> str: +def element_list_formatter(elements: list[Element], field: str = "text_representation") -> str: query = "" for i in range(len(elements)): - query += f"ELEMENT {i + 1}: {elements[i].text_representation}\n" + value = str(elements[i].field_to_value(field)) + query += f"ELEMENT {i + 1}: {value}\n" return query @@ -68,10 +69,10 @@ def __init__( llm: Optional[LLM] = None, prompt_template: Optional[str] = None, num_of_elements: int = 10, - prompt_formatter: Callable[[list[Element]], str] = element_list_formatter, + prompt_formatter: Callable[[list[Element], str], str] = element_list_formatter, use_elements: Optional[bool] = True, prompt: Optional[Union[list[dict], str]] = [], - field: Optional[str] 
= None, + field: str = "text_representation", ): super().__init__(entity_name) self._llm = llm @@ -89,10 +90,7 @@ def extract_entity( ) -> Document: self._llm = llm or self._llm if self._use_elements: - if self._prompt_template: - entities = self._handle_few_shot_prompting(document) - else: - entities = self._handle_zero_shot_prompting(document) + entities = self._handle_element_prompting(document) else: if self._prompt is None: raise Exception("prompt must be specified if use_elements is False") @@ -102,33 +100,27 @@ def extract_entity( return document - def _handle_few_shot_prompting(self, document: Document) -> Any: - assert self._llm is not None - sub_elements = [document.elements[i] for i in range((min(self._num_of_elements, len(document.elements))))] - - prompt = EntityExtractorFewShotGuidancePrompt() - - entities = self._llm.generate( - prompt_kwargs={ - "prompt": prompt, - "entity": self._entity_name, - "examples": self._prompt_template, - "query": self._prompt_formatter(sub_elements), - } - ) - return entities - - def _handle_zero_shot_prompting(self, document: Document) -> Any: + def _handle_element_prompting(self, document: Document) -> Any: assert self._llm is not None sub_elements = [document.elements[i] for i in range((min(self._num_of_elements, len(document.elements))))] - - prompt = EntityExtractorZeroShotGuidancePrompt() - - entities = self._llm.generate( - prompt_kwargs={"prompt": prompt, "entity": self._entity_name, "query": self._prompt_formatter(sub_elements)} - ) - - return entities + content = self._prompt_formatter(sub_elements, self._field) + if self._prompt is None: + prompt: Any = None + if self._prompt_template: + prompt = EntityExtractorFewShotGuidancePrompt() + else: + prompt = EntityExtractorZeroShotGuidancePrompt() + entities = self._llm.generate( + prompt_kwargs={ + "prompt": prompt, + "entity": self._entity_name, + "query": content, + "examples": self._prompt_template, + } + ) + return entities + else: + return self._get_entities(content) def _handle_document_field_prompting(self, document: Document) -> Any: assert self._llm is not None @@ -137,13 +129,18 @@ def _handle_document_field_prompting(self, document: Document) -> Any: value = str(document.field_to_value(self._field)) + return self._get_entities(value) + + def _get_entities(self, content: str, prompt: Optional[Union[list[dict], str]] = None): + assert self._llm is not None + prompt = prompt or self._prompt + assert prompt is not None, "No prompt found for entity extraction" if isinstance(self._prompt, str): - prompt = self._prompt + value + prompt = self._prompt + content response = self._llm.generate(prompt_kwargs={"prompt": prompt}, llm_kwargs={}) else: - messages = (self._prompt or []) + [{"role": "user", "content": value}] + messages = (self._prompt or []) + [{"role": "user", "content": content}] response = self._llm.generate(prompt_kwargs={"messages": messages}, llm_kwargs={}) - return response
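
Usage sketch (illustrative only, not part of the patch): the exec_mode plumbing added above lets a caller of the query UI helper run Sycamore without a Ray cluster. The cache path below is hypothetical.

    from sycamore import ExecMode
    from queryui.configuration import get_sycamore_query_client

    # ExecMode.LOCAL runs the Sycamore pipeline in-process, so the UI skips
    # sycamore_ray_init(address="auto"); ExecMode.RAY remains the default and
    # preserves the previous behavior.
    client = get_sycamore_query_client(
        cache_dir="/tmp/query-cache",  # hypothetical cache location
        exec_mode=ExecMode.LOCAL,
    )

    # Per the changes above, this constructs
    # SycamoreQueryClient(..., sycamore_exec_mode=ExecMode.LOCAL), which passes the
    # mode through to sycamore.init() when building the default Context.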