diff --git a/apps/query-ui/queryui/Sycamore_Query.py b/apps/query-ui/queryui/Sycamore_Query.py
index 1d48b3a05..a9573291d 100644
--- a/apps/query-ui/queryui/Sycamore_Query.py
+++ b/apps/query-ui/queryui/Sycamore_Query.py
@@ -10,6 +10,7 @@
import streamlit as st
from streamlit_ace import st_ace
+from sycamore import ExecMode
from sycamore.executor import sycamore_ray_init
from sycamore.query.client import SycamoreQueryClient
from sycamore.query.logical_plan import LogicalPlan
@@ -62,6 +63,7 @@ def run_query():
s3_cache_path=st.session_state.llm_cache_dir,
trace_dir=st.session_state.trace_dir,
cache_dir=st.session_state.cache_dir,
+ exec_mode=ExecMode.LOCAL if st.session_state.local_mode else ExecMode.RAY,
)
with st.spinner("Generating plan..."):
t1 = time.time()
@@ -93,7 +95,7 @@ def run_query():
def main():
argparser = argparse.ArgumentParser()
- argparser.add_argument("--external-ray", action="store_true", help="Use external Ray process.")
+ argparser.add_argument("--local-mode", action="store_true", help="Enable Sycamore local execution mode.")
argparser.add_argument(
"--index", help="OpenSearch index name to use. If specified, only this index will be queried."
)
@@ -111,11 +113,15 @@ def main():
if "llm_cache_dir" not in st.session_state:
st.session_state.llm_cache_dir = args.llm_cache_dir
+ if "local_mode" not in st.session_state:
+ st.session_state.local_mode = args.local_mode
+
if "trace_dir" not in st.session_state:
st.session_state.trace_dir = args.trace_dir
- sycamore_ray_init(address="auto")
- client = get_sycamore_query_client()
+ if not args.local_mode:
+ sycamore_ray_init(address="auto")
+ client = get_sycamore_query_client(exec_mode=ExecMode.LOCAL if args.local_mode else ExecMode.RAY)
st.title("Sycamore Query")
st.write(f"Query cache dir: `{st.session_state.cache_dir}`")
@@ -141,11 +147,13 @@ def main():
show_schema(client, st.session_state.index)
with st.form("query_form"):
st.text_input("Query", key="query")
- col1, col2 = st.columns(2)
+ col1, col2, col3 = st.columns(3)
with col1:
submitted = st.form_submit_button("Run query")
with col2:
st.toggle("Plan only", key="plan_only", value=False)
+ with col3:
+ st.toggle("Use Ray", key="use_ray", value=True)
if submitted:
run_query()
diff --git a/apps/query-ui/queryui/configuration.py b/apps/query-ui/queryui/configuration.py
index e86a08175..e85432d2a 100644
--- a/apps/query-ui/queryui/configuration.py
+++ b/apps/query-ui/queryui/configuration.py
@@ -1,5 +1,6 @@
from typing import Optional
+from sycamore import ExecMode
from sycamore.query.client import SycamoreQueryClient
"""
@@ -24,6 +25,11 @@
def get_sycamore_query_client(
- s3_cache_path: Optional[str] = None, trace_dir: Optional[str] = None, cache_dir: Optional[str] = None
+ s3_cache_path: Optional[str] = None,
+ trace_dir: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ exec_mode: ExecMode = ExecMode.RAY,
) -> SycamoreQueryClient:
- return SycamoreQueryClient(s3_cache_path=s3_cache_path, trace_dir=trace_dir, cache_dir=cache_dir)
+ return SycamoreQueryClient(
+ s3_cache_path=s3_cache_path, trace_dir=trace_dir, cache_dir=cache_dir, sycamore_exec_mode=exec_mode
+ )
diff --git a/apps/query-ui/queryui/main.py b/apps/query-ui/queryui/main.py
index 9997f8723..8844f9048 100755
--- a/apps/query-ui/queryui/main.py
+++ b/apps/query-ui/queryui/main.py
@@ -35,6 +35,9 @@ def ray_init(**ray_args):
def main():
argparser = argparse.ArgumentParser()
+ argparser.add_argument(
+ "--exec-mode", type=str, choices=["ray", "local"], default="ray", help="Configure Sycamore execution mode."
+ )
argparser.add_argument("--chat", action="store_true", help="Only show the chat demo pane.")
argparser.add_argument(
"--index", help="OpenSearch index name to use. If specified, only this index will be queried."
@@ -83,8 +86,10 @@ def main():
trace_dir = args.trace_dir
cmdline_args.extend(["--trace-dir", trace_dir])
- ray_init()
-
+ if args.exec_mode == "ray":
+ ray_init()
+ elif args.exec_mode == "local":
+ cmdline_args.extend(["--local-mode"])
while True:
print("Starting streamlit process...", flush=True)
# Streamlit requires the -- separator to separate streamlit arguments from script arguments.
diff --git a/apps/query-ui/queryui/pages/Chat.py b/apps/query-ui/queryui/pages/Chat.py
index bbf1f2bbe..29b189959 100644
--- a/apps/query-ui/queryui/pages/Chat.py
+++ b/apps/query-ui/queryui/pages/Chat.py
@@ -16,6 +16,7 @@
import streamlit as st
import sycamore
from sycamore.data import OpenSearchQuery
+from sycamore import ExecMode
from sycamore.executor import sycamore_ray_init
from sycamore.transforms.query import OpenSearchQueryExecutor
from sycamore.query.client import SycamoreQueryClient
@@ -177,6 +178,7 @@ def query_data_source(query: str, index: str) -> Tuple[Any, Optional[Any], Optio
s3_cache_path=st.session_state.llm_cache_dir,
trace_dir=st.session_state.trace_dir,
cache_dir=st.session_state.cache_dir,
+ sycamore_exec_mode=ExecMode.LOCAL if st.session_state.local_mode else ExecMode.RAY,
)
with st.spinner("Generating plan..."):
plan = util.generate_plan(sqclient, query, index, examples=PLANNER_EXAMPLES)
@@ -310,6 +312,7 @@ def main():
argparser.add_argument(
"--index", help="OpenSearch index name to use. If specified, only this index will be queried."
)
+ argparser.add_argument("--local-mode", action="store_true", help="Enable Sycamore local execution mode.")
argparser.add_argument("--title", type=str, help="Title text.")
argparser.add_argument("--cache-dir", type=str, help="Query execution cache dir.")
argparser.add_argument("--llm-cache-dir", type=str, help="LLM query cache dir.")
@@ -331,6 +334,9 @@ def main():
if "use_cache" not in st.session_state:
st.session_state.use_cache = True
+ if "local_mode" not in st.session_state:
+ st.session_state.local_mode = args.local_mode
+
if "next_message_id" not in st.session_state:
st.session_state.next_message_id = 0
@@ -343,7 +349,8 @@ def main():
if "trace_dir" not in st.session_state:
st.session_state.trace_dir = os.path.join(os.getcwd(), "traces")
- sycamore_ray_init(address="auto")
+ if not args.local_mode:
+ sycamore_ray_init(address="auto")
st.title("Sycamore Query Chat")
st.toggle("Use RAG only", key="rag_only")
diff --git a/apps/query-ui/queryui/test_configuration.py b/apps/query-ui/queryui/test_configuration.py
index 998b19b6b..a3b955b43 100644
--- a/apps/query-ui/queryui/test_configuration.py
+++ b/apps/query-ui/queryui/test_configuration.py
@@ -7,4 +7,4 @@ def test_configuration_sha():
sha = sha256(bytes).hexdigest()
# If the change was intentional, update the hash
# Think about whether you have to make the change since everyone with a custom config will need to update it
- assert sha == "0d5a24a1edfb9e523814dc847fa34cdcde2d2e78aff03e3f4489755e12be2c54", f"hash mismatch got {sha}"
+ assert sha == "cf89894116604d4b002f2c5b6c9acf25982bf764310a9a50827608dcdc6b1b2c", f"hash mismatch got {sha}"
diff --git a/apps/query-ui/queryui/util.py b/apps/query-ui/queryui/util.py
index 5e82609ad..060f157a7 100644
--- a/apps/query-ui/queryui/util.py
+++ b/apps/query-ui/queryui/util.py
@@ -17,6 +17,8 @@
from queryui.configuration import get_sycamore_query_client
+from sycamore.data.document import DocumentSource
+
def get_schema(_client: SycamoreQueryClient, index: str) -> Dict[str, Tuple[str, Set[str]]]:
"""Return the OpenSearch schema for the given index."""
@@ -81,7 +83,7 @@ def docset_to_string(docset: DocSet, html: bool = True) -> str:
if isinstance(doc, MetadataDocument):
continue
if html:
- retval += f"**{doc.properties.get('path')}** page: {doc.properties.get('page_number', 'meta')} \n"
+ retval += f"**{doc.properties.get('path')}** \n"
retval += "| Property | Value |\n"
retval += "|----------|-------|\n"
@@ -96,13 +98,16 @@ def docset_to_string(docset: DocSet, html: bool = True) -> str:
text_content = (
doc.text_representation[:NUM_TEXT_CHARS_GENERATE] if doc.text_representation is not None else None
)
-            retval += f'*..."{text_content}"...* \n'
+            if text_content:
+                retval += f'*..."{text_content}"...* \n'
else:
props_dict = doc.properties.get("entity", {})
props_dict.update({p: doc.properties[p] for p in set(doc.properties) - set(BASE_PROPS)})
props_dict["text_representation"] = (
doc.text_representation[:NUM_TEXT_CHARS_GENERATE] if doc.text_representation is not None else None
)
+ props_dict.get("_doc_source", {}).apply(lambda x: x.value if isinstance(x, DocumentSource) else x)
+
retval += json.dumps(props_dict, indent=2) + "\n"
return retval
@@ -242,17 +247,19 @@ def readdata(self):
data[col].append(row.get(col))
self.df = pd.DataFrame(data)
- def show(self):
+ def show(self, node):
"""Render the trace data."""
+ st.subheader(f"Node {self.node_id}")
+ st.markdown(f"*Description: {node.description if node else 'n/a'}*")
if self.df is None or not len(self.df):
- st.write(f"Result of node {self.node_id} — :red[0] documents")
+ st.write(":red[0] documents")
st.write("No data.")
return
all_columns = list(self.df.columns)
column_order = [c for c in self.COLUMNS if c in all_columns]
column_order += [c for c in all_columns if c not in column_order]
- st.write(f"Result of node {self.node_id} — **{len(self.df)}** documents")
+ st.write(f"**{len(self.df)}** documents")
st.dataframe(self.df, column_order=column_order)
@@ -261,11 +268,29 @@ class QueryTrace:
def __init__(self, trace_dir: str):
self.trace_dir = trace_dir
- self.node_traces = [QueryNodeTrace(trace_dir, node_id) for node_id in sorted(os.listdir(self.trace_dir))]
+ self.node_traces = []
+ self.query_plan = self._get_query_plan(self.trace_dir)
+ for dir in sorted(os.listdir(self.trace_dir)):
+ if "metadata" not in dir:
+ self.node_traces += [QueryNodeTrace(trace_dir, dir)]
+
+ def _get_query_plan(self, trace_dir: str):
+ metadata_dir = os.path.join(trace_dir, "metadata")
+ if os.path.isfile(os.path.join(trace_dir, "metadata", "query_plan.json")):
+ return LogicalPlan.parse_file(os.path.join(metadata_dir, "query_plan.json"))
+ return None
def show(self):
- for node_trace in self.node_traces:
- node_trace.show()
+ tab1, tab2 = st.tabs(["Node data", "Query plan"])
+ with tab1:
+ for node_trace in self.node_traces:
+                node_trace.show(self.query_plan.nodes.get(int(node_trace.node_id), None) if self.query_plan else None)
+ with tab2:
+ if self.query_plan is not None:
+ st.write(f"Query: {self.query_plan.query}")
+ st.write(self.query_plan)
+ else:
+ st.write("No query plan found")
@st.fragment
diff --git a/lib/sycamore/sycamore/data/document.py b/lib/sycamore/sycamore/data/document.py
index 8081c958d..04de7ae8e 100644
--- a/lib/sycamore/sycamore/data/document.py
+++ b/lib/sycamore/sycamore/data/document.py
@@ -9,9 +9,9 @@
class DocumentSource(Enum):
- UNKNOWN: str = "UNKNOWN"
- DB_QUERY: str = "DB_QUERY"
- DOCUMENT_RECONSTRUCTION_RETRIEVAL: str = "DOCUMENT_RECONSTRUCTION_RETRIEVAL"
+ UNKNOWN = "UNKNOWN"
+ DB_QUERY = "DB_QUERY"
+ DOCUMENT_RECONSTRUCTION_RETRIEVAL = "DOCUMENT_RECONSTRUCTION_RETRIEVAL"
class DocumentPropertyTypes:
diff --git a/lib/sycamore/sycamore/data/element.py b/lib/sycamore/sycamore/data/element.py
index e6cd12c80..d99996846 100644
--- a/lib/sycamore/sycamore/data/element.py
+++ b/lib/sycamore/sycamore/data/element.py
@@ -89,6 +89,21 @@ def __str__(self) -> str:
}
return json.dumps(d, indent=2)
+ def field_to_value(self, field: str) -> Any:
+ """
+ Extracts the value for a particular element field.
+
+ Args:
+ field: The field in dotted notation to indicate nesting, e.g. properties.schema
+
+ Returns:
+ The value associated with the document field.
+ Returns None if field does not exist in document.
+ """
+ from sycamore.utils.nested import dotted_lookup
+
+ return dotted_lookup(self, field)
+
class ImageElement(Element):
def __init__(
diff --git a/lib/sycamore/sycamore/query/client.py b/lib/sycamore/sycamore/query/client.py
index c3e98f177..016378f6c 100755
--- a/lib/sycamore/sycamore/query/client.py
+++ b/lib/sycamore/sycamore/query/client.py
@@ -19,10 +19,12 @@
import structlog
import sycamore
-from sycamore import Context
+from sycamore import Context, ExecMode
+from sycamore.context import OperationTypes
from sycamore.llms.openai import OpenAI, OpenAIModels
from sycamore.transforms.embed import SentenceTransformerEmbedder
from sycamore.transforms.query import OpenSearchQueryExecutor
+from sycamore.transforms.similarity import HuggingFaceTransformersSimilarityScorer
from sycamore.utils.cache import cache_from_path
from sycamore.utils.import_utils import requires_modules
@@ -124,6 +126,7 @@ def __init__(
os_client_args: Optional[dict] = None,
trace_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
+ sycamore_exec_mode: ExecMode = ExecMode.RAY,
):
from opensearchpy import OpenSearch
@@ -131,6 +134,7 @@ def __init__(
self.os_config = os_config
self.trace_dir = trace_dir
self.cache_dir = cache_dir
+ self.sycamore_exec_mode = sycamore_exec_mode
# TODO: remove these assertions and simplify the code to get all customization via the
# context.
@@ -141,7 +145,7 @@ def __init__(
raise AssertionError("setting s3_cache_path requires context==None. See Notes in class documentation.")
os_client_args = os_client_args or DEFAULT_OS_CLIENT_ARGS
- self.context = context or self._get_default_context(s3_cache_path, os_client_args)
+ self.context = context or self._get_default_context(s3_cache_path, os_client_args, sycamore_exec_mode)
assert self.context.params, "Could not find required params in Context"
self.os_client_args = self.context.params.get("opensearch", {}).get("os_client_args", os_client_args)
@@ -242,17 +246,25 @@ def default_text_embedder():
return SentenceTransformerEmbedder(batch_size=100, model_name="sentence-transformers/all-MiniLM-L6-v2")
@staticmethod
- def _get_default_context(s3_cache_path, os_client_args) -> Context:
+ def _get_default_context(s3_cache_path, os_client_args, sycamore_exec_mode) -> Context:
context_params = {
"default": {"llm": OpenAI(OpenAIModels.GPT_4O.value, cache=cache_from_path(s3_cache_path))},
"opensearch": {
"os_client_args": os_client_args,
"text_embedder": SycamoreQueryClient.default_text_embedder(),
},
+ OperationTypes.BINARY_CLASSIFIER: {
+ "llm": OpenAI(OpenAIModels.GPT_4O_MINI.value, cache=cache_from_path(s3_cache_path))
+ },
+ OperationTypes.INFORMATION_EXTRACTOR: {
+ "llm": OpenAI(OpenAIModels.GPT_4O_MINI.value, cache=cache_from_path(s3_cache_path))
+ },
+ OperationTypes.TEXT_SIMILARITY: {"similarity_scorer": HuggingFaceTransformersSimilarityScorer()},
}
- return sycamore.init(params=context_params)
+ return sycamore.init(params=context_params, exec_mode=sycamore_exec_mode)
- def result_to_str(self, result: Any, max_docs: int = 100, max_chars_per_doc: int = 2500) -> str:
+ @staticmethod
+ def result_to_str(result: Any, max_docs: int = 100, max_chars_per_doc: int = 2500) -> str:
"""Convert a query result to a string.
Args:
diff --git a/lib/sycamore/sycamore/query/execution/physical_operator.py b/lib/sycamore/sycamore/query/execution/physical_operator.py
index 411ed65bd..d86b5f3d5 100644
--- a/lib/sycamore/sycamore/query/execution/physical_operator.py
+++ b/lib/sycamore/sycamore/query/execution/physical_operator.py
@@ -82,12 +82,12 @@ def execute(self) -> Any:
def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]:
assert isinstance(self.logical_node, Math)
- assert self.logical_node.dependencies is not None and len(self.logical_node.dependencies) == 2
+ assert self.logical_node.get_dependencies() is not None and len(self.logical_node.get_dependencies()) == 2
operator = self.logical_node.operation
result = f"""
{output_var or get_var_name(self.logical_node)} = math_operation(
- val1={input_var or get_var_name(self.logical_node.dependencies[0])},
- val2={input_var or get_var_name(self.logical_node.dependencies[1])},
+ val1={input_var or get_var_name(self.logical_node.get_dependencies()[0])},
+ val2={input_var or get_var_name(self.logical_node.get_dependencies()[1])},
operator='{operator}'
)
"""
diff --git a/lib/sycamore/sycamore/query/execution/sycamore_executor.py b/lib/sycamore/sycamore/query/execution/sycamore_executor.py
index 1cb8452f0..71f55a9c8 100644
--- a/lib/sycamore/sycamore/query/execution/sycamore_executor.py
+++ b/lib/sycamore/sycamore/query/execution/sycamore_executor.py
@@ -110,8 +110,8 @@ def process_node(self, logical_node: LogicalOperator, query_id: str) -> Any:
cache_dir = None
# Process dependencies
- if logical_node.dependencies:
- for dependency in logical_node.dependencies:
+ if logical_node.get_dependencies():
+ for dependency in logical_node.get_dependencies():
assert isinstance(dependency, LogicalOperator)
inputs += [self.process_node(dependency, query_id)]
@@ -212,7 +212,7 @@ def process_node(self, logical_node: LogicalOperator, query_id: str) -> Any:
raise ValueError(f"Unsupported node type: {str(logical_node)}")
code, imports = operation.script(
- output_var=(self.OUTPUT_VAR_NAME if not logical_node.downstream_nodes else None)
+ output_var=(self.OUTPUT_VAR_NAME if not logical_node.get_downstream_nodes() else None)
)
self.imports += imports
self.node_id_to_code[logical_node.node_id] = code
@@ -257,12 +257,24 @@ def get_code_string(self):
"""
return result
+ def _write_query_plan_to_trace_dir(self, plan: LogicalPlan, query_id: str):
+ assert self.trace_dir is not None, "Writing query_plan requires trace_dir to be set"
+ path = os.path.join(self.trace_dir, query_id, "metadata")
+ os.makedirs(path, exist_ok=True)
+ with open(os.path.join(path, "query_plan.json"), "w") as f:
+ f.write(plan.model_dump_json())
+
def execute(self, plan: LogicalPlan, query_id: Optional[str] = None) -> Any:
try:
"""Execute a logical plan using Sycamore."""
if not query_id:
query_id = str(uuid.uuid4())
bind_contextvars(query_id=query_id)
+
+ log.info("Writing query plan to trace dir")
+ if self.trace_dir:
+ self._write_query_plan_to_trace_dir(plan, query_id)
+
log.info("Executing query")
assert isinstance(plan.result_node, LogicalOperator)
result = self.process_node(plan.result_node, query_id)
diff --git a/lib/sycamore/sycamore/query/execution/sycamore_operator.py b/lib/sycamore/sycamore/query/execution/sycamore_operator.py
index d2367b8fa..61729205a 100644
--- a/lib/sycamore/sycamore/query/execution/sycamore_operator.py
+++ b/lib/sycamore/sycamore/query/execution/sycamore_operator.py
@@ -96,7 +96,9 @@ def execute(self) -> Any:
os_query = {"query": self.logical_node.query}
else:
os_query = {}
- result = self.context.read.opensearch(index_name=self.logical_node.index, query=os_query)
+ result = self.context.read.opensearch(
+ index_name=self.logical_node.index, query=os_query, reconstruct_document=True
+ )
return result
def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]:
@@ -108,7 +110,7 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No
return (
f"""
{output_var or get_var_name(self.logical_node)} = context.read.opensearch(
- index_name='{self.logical_node.index}', query={os_query}
+ index_name='{self.logical_node.index}', query={os_query}, reconstruct_document=True
)
""",
[],
@@ -141,7 +143,9 @@ def execute(self) -> Any:
os_query = get_knn_query(query_phrase=self.logical_node.query_phrase, context=self.context)
if self.logical_node.opensearch_filter:
os_query["query"]["knn"]["embedding"]["filter"] = self.logical_node.opensearch_filter
- result = self.context.read.opensearch(index_name=self.logical_node.index, query=os_query)
+ result = self.context.read.opensearch(
+ index_name=self.logical_node.index, query=os_query, reconstruct_document=True
+ ).rerank(query=self.logical_node.query_phrase)
return result
def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]:
@@ -153,8 +157,9 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No
result += f"""
{output_var or get_var_name(self.logical_node)} = context.read.opensearch(
index_name='{self.logical_node.index}',
- query=os_query
-)
+ query=os_query,
+ reconstruct_document=True
+).rerank(query={self.logical_node.query_phrase})
"""
return (
result,
@@ -198,12 +203,12 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No
assert isinstance(self.logical_node, SummarizeData)
question = self.logical_node.question
description = self.logical_node.description
- assert self.logical_node.dependencies is not None and len(self.logical_node.dependencies) >= 1
+ assert self.logical_node.get_dependencies() is not None and len(self.logical_node.get_dependencies()) >= 1
logical_deps_str = ""
- for i, inp in enumerate(self.logical_node.dependencies):
+ for i, inp in enumerate(self.logical_node.get_dependencies()):
logical_deps_str += input_var or get_var_name(inp)
- if i != len(self.logical_node.dependencies) - 1:
+ if i != len(self.logical_node.get_dependencies()) - 1:
logical_deps_str += ", "
result = f"""
@@ -223,7 +228,8 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No
class SycamoreLlmFilter(SycamoreOperator):
"""
Use an LLM to filter records on a Docset.
- Args:
+ If field == text_representation, the filter is run
+ on the elements of the document (i.e. use_elements = True)
"""
def __init__(
@@ -245,7 +251,7 @@ def execute(self) -> Any:
context=self.context, val_key="llm", param_names=[OperationTypes.BINARY_CLASSIFIER.value]
),
LLM,
- ), "LLMFilter requies an 'llm' configured on the Context"
+        ), "SycamoreLlmFilter requires an 'llm' configured on the Context"
question = self.logical_node.question
field = self.logical_node.field
@@ -257,14 +263,15 @@ def execute(self) -> Any:
new_field="_autogen_LLMFilterOutput",
prompt=prompt,
field=field,
+ use_elements=(field == "text_representation"),
**self.get_node_args(),
)
return result
def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]:
- assert self.logical_node.dependencies is not None and len(self.logical_node.dependencies) == 1
+ assert self.logical_node.get_dependencies() is not None and len(self.logical_node.get_dependencies()) == 1
assert isinstance(self.logical_node, LlmFilter)
- input_str = input_var or get_var_name(self.logical_node.dependencies[0])
+ input_str = input_var or get_var_name(self.logical_node.get_dependencies()[0])
output_str = output_var or get_var_name(self.logical_node)
result = f"""
prompt = LlmFilterMessagesPrompt(filter_question='{self.logical_node.question}').as_messages()
@@ -272,6 +279,7 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No
new_field='_autogen_LLMFilterOutput',
prompt=prompt,
field='{self.logical_node.field}',
+ use_elements={(self.logical_node.field == "text_representation")},
**{self.get_node_args()},
)
"""
@@ -321,10 +329,10 @@ def execute(self) -> Any:
def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]:
assert isinstance(self.logical_node, BasicFilter)
- assert self.logical_node.dependencies is not None and len(self.logical_node.dependencies) == 1
+ assert self.logical_node.get_dependencies() is not None and len(self.logical_node.get_dependencies()) == 1
imports: list[str] = []
- input_str = input_var or get_var_name(self.logical_node.dependencies[0])
+ input_str = input_var or get_var_name(self.logical_node.get_dependencies()[0])
output_str = output_var or get_var_name(self.logical_node)
if self.logical_node.range_filter:
field = self.logical_node.field
@@ -391,15 +399,15 @@ def execute(self) -> Any:
def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]:
assert isinstance(self.logical_node, Count)
- assert self.logical_node.dependencies is not None and len(self.logical_node.dependencies) == 1
+ assert self.logical_node.get_dependencies() is not None and len(self.logical_node.get_dependencies()) == 1
distinct_field = self.logical_node.distinct_field
imports: list[str] = []
script = f"""{output_var or get_var_name(self.logical_node)} ="""
if distinct_field is None:
- script += f"""{input_var or get_var_name(self.logical_node.dependencies[0])}.count("""
+ script += f"""{input_var or get_var_name(self.logical_node.get_dependencies()[0])}.count("""
else:
- script += f"""{input_var or get_var_name(self.logical_node.dependencies[0])}.count_distinct("""
+ script += f"""{input_var or get_var_name(self.logical_node.get_dependencies()[0])}.count_distinct("""
script += f"""field='{distinct_field}', """
script += f"""**{get_str_for_dict(self.get_execute_args())})"""
return script, imports
@@ -432,7 +440,7 @@ def execute(self) -> Any:
context=self.context, val_key="llm", param_names=[OperationTypes.INFORMATION_EXTRACTOR.value]
),
LLM,
- ), "LLMExtractEntity requies an 'llm' configured on the Context"
+ ), "LLMExtractEntity requires an 'llm' configured on the Context"
question = logical_node.question
new_field = logical_node.new_field
@@ -446,7 +454,7 @@ def execute(self) -> Any:
entity_extractor = OpenAIEntityExtractor(
entity_name=new_field,
- use_elements=False,
+ use_elements=True,
prompt=prompt,
field=field,
)
@@ -461,9 +469,9 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No
field = logical_node.field
fmt = logical_node.new_field_type
discrete = logical_node.discrete
- assert logical_node.dependencies is not None and len(logical_node.dependencies) == 1
+ assert logical_node.get_dependencies() is not None and len(logical_node.get_dependencies()) == 1
- input_str = input_var or get_var_name(logical_node.dependencies[0])
+ input_str = input_var or get_var_name(logical_node.get_dependencies()[0])
output_str = output_var or get_var_name(logical_node)
result = f"""
@@ -473,7 +481,7 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No
entity_extractor = OpenAIEntityExtractor(
entity_name='{new_field}',
- use_elements=False,
+ use_elements=True,
prompt=prompt,
field='{field}',
)
@@ -524,10 +532,10 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No
descending = logical_node.descending
field = logical_node.field
default_value = logical_node.default_value
- assert logical_node.dependencies is not None and len(logical_node.dependencies) == 1
+ assert logical_node.get_dependencies() is not None and len(logical_node.get_dependencies()) == 1
result = f"""
-{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.dependencies[0])}.sort(
+{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.get_dependencies()[0])}.sort(
descending={descending},
field='{field}'
default_val={default_value}
@@ -538,6 +546,7 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No
class SycamoreTopK(SycamoreOperator):
"""
+    Note: top_k clustering only operates on properties; it will not cluster on text_representation currently.
Return the Top-K values from a DocSet
"""
@@ -550,6 +559,9 @@ def __init__(
trace_dir: Optional[str] = None,
) -> None:
super().__init__(context, logical_node, query_id, inputs, trace_dir=trace_dir)
+ assert (
+ self.logical_node.primary_field != "text_representation" # type: ignore[attr-defined]
+ ), "TopK can only operate on properties"
def execute(self) -> Any:
assert self.inputs and len(self.inputs) == 1, "TopK requires 1 input node"
@@ -580,10 +592,10 @@ def execute(self) -> Any:
def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]:
logical_node = self.logical_node
assert isinstance(logical_node, TopK)
- assert logical_node.dependencies is not None and len(logical_node.dependencies) == 1
+ assert logical_node.get_dependencies() is not None and len(logical_node.get_dependencies()) == 1
result = f"""
-{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.dependencies[0])}.top_k(
+{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.get_dependencies()[0])}.top_k(
field='{logical_node.field}',
k={logical_node.K},
descending={logical_node.descending},
@@ -636,15 +648,15 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No
assert isinstance(logical_node, FieldIn)
field1 = logical_node.field_one
field2 = logical_node.field_two
- assert logical_node.dependencies is not None and len(logical_node.dependencies) == 2
+ assert logical_node.get_dependencies() is not None and len(logical_node.get_dependencies()) == 2
result = f"""
-{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.dependencies[0])}.field_in(
- docset2={input_var or get_var_name(logical_node.dependencies[2])},
+{output_var or get_var_name(self.logical_node)} = {input_var or get_var_name(logical_node.get_dependencies()[0])}.field_in(
+    docset2={input_var or get_var_name(logical_node.get_dependencies()[1])},
field1='{field1}',
field2='{field2}'
)
-"""
+""" # noqa: E501
return result, []
@@ -676,10 +688,10 @@ def execute(self) -> Any:
def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]:
logical_node = self.logical_node
assert isinstance(logical_node, Limit)
- assert logical_node.dependencies is not None and len(logical_node.dependencies) == 1
+ assert logical_node.get_dependencies() is not None and len(logical_node.get_dependencies()) == 1
result = f"""
-{output_var or get_var_name(logical_node)} = {input_var or get_var_name(logical_node.dependencies[0])}.limit(
+{output_var or get_var_name(logical_node)} = {input_var or get_var_name(logical_node.get_dependencies()[0])}.limit(
{logical_node.num_records},
**{get_str_for_dict(self.get_execute_args())},
)
diff --git a/lib/sycamore/sycamore/query/logical_plan.py b/lib/sycamore/sycamore/query/logical_plan.py
index eac8f5b07..35dea21a5 100644
--- a/lib/sycamore/sycamore/query/logical_plan.py
+++ b/lib/sycamore/sycamore/query/logical_plan.py
@@ -5,7 +5,7 @@
from hashlib import sha256
-from pydantic import BaseModel, ConfigDict, SerializeAsAny
+from pydantic import BaseModel, ConfigDict, SerializeAsAny, computed_field
def exclude_from_comparison(func):
@@ -33,6 +33,9 @@ class Node(BaseModel):
node_id: int
"""A unique integer ID representing this node."""
+ description: Optional[str] = None
+ """A detailed description of why this operator was chosen for this query plan."""
+
# These are underscored here to prevent them from leaking out to the
# input_schema used by the planner.
@@ -40,16 +43,24 @@ class Node(BaseModel):
_downstream_nodes: List["Node"] = []
_cache_key: Optional[str] = None
- @property
- def dependencies(self) -> Optional[List["Node"]]:
+ def get_dependencies(self) -> List["Node"]:
"""The nodes that this node depends on."""
return self._dependencies
- @property
- def downstream_nodes(self) -> Optional[List["Node"]]:
+ def get_downstream_nodes(self) -> List["Node"]:
"""The nodes that depend on this node."""
return self._downstream_nodes
+ @property
+ @computed_field
+ def dependencies(self) -> List[int]:
+ return [dep.node_id for dep in self._dependencies]
+
+ @property
+ @computed_field
+ def downstream_nodes(self) -> List[int]:
+ return [dep.node_id for dep in self._downstream_nodes]
+
def __str__(self) -> str:
return f"Id: {self.node_id} Op: {type(self).__name__}"
diff --git a/lib/sycamore/sycamore/query/operators/logical_operator.py b/lib/sycamore/sycamore/query/operators/logical_operator.py
index ce0517684..c5bd214c7 100644
--- a/lib/sycamore/sycamore/query/operators/logical_operator.py
+++ b/lib/sycamore/sycamore/query/operators/logical_operator.py
@@ -16,9 +16,6 @@ class LogicalOperator(Node):
Logical operator class for LLM prompting.
"""
- description: Optional[str] = None
- """A detailed description of why this operator was chosen for this query plan."""
-
input: Optional[List[int]] = None
"""A list of node IDs that this operation depends on."""
diff --git a/lib/sycamore/sycamore/query/visualize.py b/lib/sycamore/sycamore/query/visualize.py
index 256cc8247..da7300388 100644
--- a/lib/sycamore/sycamore/query/visualize.py
+++ b/lib/sycamore/sycamore/query/visualize.py
@@ -12,8 +12,8 @@ def build_graph(plan: LogicalPlan):
else:
description = None
graph.add_node(node.node_id, description=f"{type(node).__name__}\n{description}")
- if node.dependencies:
- for dep in node.dependencies:
+ if node.get_dependencies():
+ for dep in node.get_dependencies():
graph.add_edge(dep.node_id, node.node_id)
return graph
diff --git a/lib/sycamore/sycamore/tests/integration/query/execution/test_sycamore_query.py b/lib/sycamore/sycamore/tests/integration/query/execution/test_sycamore_query.py
index 57794e043..7b913f84d 100644
--- a/lib/sycamore/sycamore/tests/integration/query/execution/test_sycamore_query.py
+++ b/lib/sycamore/sycamore/tests/integration/query/execution/test_sycamore_query.py
@@ -76,7 +76,7 @@ def test_vector_search(self, query_integration_test_index: str, codegen_mode: bo
"were there any environmentally caused incidents?",
query_integration_test_index,
schema,
- natural_language_response=True,
+ natural_language_response=False,
)
assert len(plan.nodes) == 2
assert isinstance(plan.nodes[0], QueryVectorDatabase)
diff --git a/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py b/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py
index 42020aaf1..5ff8b2689 100644
--- a/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py
+++ b/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py
@@ -144,6 +144,7 @@ def test_vector_query_database():
mock_docset_reader_impl.opensearch.assert_called_once_with(
index_name=context.params["opensearch"]["index_name"],
query={"query": {"knn": {"embedding": {"vector": embedding, "k": 500, "filter": os_filter}}}},
+ reconstruct_document=True,
)
@@ -188,6 +189,7 @@ def test_llm_filter():
prompt=ANY,
field=logical_node.field,
name=str(logical_node.node_id),
+ use_elements=False,
)
assert result == return_doc_set
@@ -314,7 +316,7 @@ def test_llm_extract_entity():
# assert OpenAIEntityExtractor called with expected arguments
MockOpenAIEntityExtractor.assert_called_once_with(
entity_name=logical_node.new_field,
- use_elements=False,
+ use_elements=True,
prompt=ANY,
field=logical_node.field,
)
diff --git a/lib/sycamore/sycamore/tests/unit/query/test_logical_operator.py b/lib/sycamore/sycamore/tests/unit/query/test_logical_operator.py
index 8f5fbbabc..d2d421199 100644
--- a/lib/sycamore/sycamore/tests/unit/query/test_logical_operator.py
+++ b/lib/sycamore/sycamore/tests/unit/query/test_logical_operator.py
@@ -33,7 +33,12 @@ def test_node_cache_dict():
assert node4.cache_dict() == {
"operator_type": "Count",
"dependencies": [
- {"operator_type": "QueryDatabase", "dependencies": [], "index": "ntsb", "query": {"match_all": {}}}
+ {
+ "operator_type": "QueryDatabase",
+ "dependencies": [],
+ "index": "ntsb",
+ "query": {"match_all": {}},
+ }
],
"distinct_field": "temperature",
}
diff --git a/lib/sycamore/sycamore/tests/unit/query/test_plan.py b/lib/sycamore/sycamore/tests/unit/query/test_plan.py
index 86d24c58d..408a6f932 100644
--- a/lib/sycamore/sycamore/tests/unit/query/test_plan.py
+++ b/lib/sycamore/sycamore/tests/unit/query/test_plan.py
@@ -212,5 +212,6 @@ def test_compare_plans_structure_changed(llm_filter_plan):
assert diff[0].diff_type == LogicalNodeDiffType.PLAN_STRUCTURE
assert isinstance(diff[0].node_a, LlmFilter)
assert isinstance(diff[0].node_b, LlmFilter)
- assert len(diff[0].node_a.downstream_nodes) == 1
- assert len(diff[0].node_b.downstream_nodes) == 0
+ assert len(diff[0].node_a.get_downstream_nodes()) == 1
+ assert diff[0].node_a.get_downstream_nodes()[0].node_id == 2
+ assert len(diff[0].node_b.get_downstream_nodes()) == 0
diff --git a/lib/sycamore/sycamore/tests/unit/query/test_planner.py b/lib/sycamore/sycamore/tests/unit/query/test_planner.py
index 5e2be7dec..bc683f04e 100644
--- a/lib/sycamore/sycamore/tests/unit/query/test_planner.py
+++ b/lib/sycamore/sycamore/tests/unit/query/test_planner.py
@@ -108,18 +108,21 @@ def mock_generate_from_llm(self, query):
plan = planner.plan("Dummy query")
assert plan.result_node.node_id == 3
assert plan.result_node.description == "Generate an English response to the question"
- assert len(plan.result_node.dependencies) == 1
- assert plan.result_node.dependencies[0].node_id == 2
- assert plan.result_node.dependencies[0].description == "Determine how many incidents occurred in Piper aircrafts"
- assert len(plan.result_node.dependencies[0].dependencies) == 1
- assert plan.result_node.dependencies[0].dependencies[0].node_id == 1
+ assert len(plan.result_node.get_dependencies()) == 1
+ assert plan.result_node.get_dependencies()[0].node_id == 2
assert (
- plan.result_node.dependencies[0].dependencies[0].description
+ plan.result_node.get_dependencies()[0].description == "Determine how many incidents occurred in Piper aircrafts"
+ )
+ assert len(plan.result_node.get_dependencies()[0].get_dependencies()) == 1
+ assert plan.result_node.get_dependencies()[0].get_dependencies()[0].node_id == 1
+ assert (
+ plan.result_node.get_dependencies()[0].get_dependencies()[0].description
== "Filter to only include Piper aircraft incidents"
)
- assert len(plan.result_node.dependencies[0].dependencies[0].dependencies) == 1
- assert plan.result_node.dependencies[0].dependencies[0].dependencies[0].node_id == 0
- assert plan.result_node.dependencies[0].dependencies[0].dependencies[0].query == {"match_all": {}}
+ assert len(plan.result_node.get_dependencies()[0].get_dependencies()[0].get_dependencies()) == 1
+ assert plan.result_node.get_dependencies()[0].get_dependencies()[0].get_dependencies()[0].node_id == 0
+ assert plan.result_node.get_dependencies()[0].get_dependencies()[0].get_dependencies()[0].query == {"match_all": {}}
assert (
- plan.result_node.dependencies[0].dependencies[0].dependencies[0].description == "Get all the airplane incidents"
+ plan.result_node.get_dependencies()[0].get_dependencies()[0].get_dependencies()[0].description
+ == "Get all the airplane incidents"
)
diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_entity_extraction.py b/lib/sycamore/sycamore/tests/unit/transforms/test_entity_extraction.py
index 50aec433f..e865df266 100644
--- a/lib/sycamore/sycamore/tests/unit/transforms/test_entity_extraction.py
+++ b/lib/sycamore/sycamore/tests/unit/transforms/test_entity_extraction.py
@@ -19,6 +19,11 @@ def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None):
return "alt_title"
if prompt_kwargs == {"prompt": "s3://path"} and llm_kwargs == {}:
return "alt_title"
+ if (
+ prompt_kwargs == {"messages": [{"role": "user", "content": "ELEMENT 1: Jack Black\nELEMENT 2: None\n"}]}
+ and llm_kwargs == {}
+ ):
+ return "Jack Black"
return "title"
def is_chat_mode(self):
@@ -38,7 +43,7 @@ class TestEntityExtraction:
{
"type": "title",
"content": {"binary": None, "text": "text1"},
- "properties": {"coordinates": [(1, 2)], "page_number": 1},
+ "properties": {"coordinates": [(1, 2)], "page_number": 1, "entity": {"author": "Jack Black"}},
},
{
"type": "table",
@@ -59,6 +64,18 @@ def test_extract_entity_zero_shot(self, mocker):
output_dataset = extract_entity.execute()
assert Document.from_row(output_dataset.take(1)[0]).properties.get("title") == "title"
+ def test_extract_entity_zero_shot_custom_field(self, mocker):
+ node = mocker.Mock(spec=Node)
+ llm = MockLLM()
+ extract_entity = ExtractEntity(
+ node, entity_extractor=OpenAIEntityExtractor("title", llm=llm, field="properties.entity.author")
+ )
+ input_dataset = ray.data.from_items([{"doc": self.doc.serialize()}])
+ execute = mocker.patch.object(node, "execute")
+ execute.return_value = input_dataset
+ output_dataset = extract_entity.execute()
+ assert Document.from_row(output_dataset.take(1)[0]).properties.get("title") == "Jack Black"
+
def test_extract_entity_w_context_llm(self, mocker):
node = mocker.Mock(spec=Node)
llm = MockLLM()
diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py
index d5402ad67..bdb95a053 100644
--- a/lib/sycamore/sycamore/transforms/extract_entity.py
+++ b/lib/sycamore/sycamore/transforms/extract_entity.py
@@ -13,10 +13,11 @@
from sycamore.utils.time_trace import timetrace
-def element_list_formatter(elements: list[Element]) -> str:
+def element_list_formatter(elements: list[Element], field: str = "text_representation") -> str:
query = ""
for i in range(len(elements)):
- query += f"ELEMENT {i + 1}: {elements[i].text_representation}\n"
+ value = str(elements[i].field_to_value(field))
+ query += f"ELEMENT {i + 1}: {value}\n"
return query
@@ -68,10 +69,10 @@ def __init__(
llm: Optional[LLM] = None,
prompt_template: Optional[str] = None,
num_of_elements: int = 10,
- prompt_formatter: Callable[[list[Element]], str] = element_list_formatter,
+ prompt_formatter: Callable[[list[Element], str], str] = element_list_formatter,
use_elements: Optional[bool] = True,
prompt: Optional[Union[list[dict], str]] = [],
- field: Optional[str] = None,
+ field: str = "text_representation",
):
super().__init__(entity_name)
self._llm = llm
@@ -89,10 +90,7 @@ def extract_entity(
) -> Document:
self._llm = llm or self._llm
if self._use_elements:
- if self._prompt_template:
- entities = self._handle_few_shot_prompting(document)
- else:
- entities = self._handle_zero_shot_prompting(document)
+ entities = self._handle_element_prompting(document)
else:
if self._prompt is None:
raise Exception("prompt must be specified if use_elements is False")
@@ -102,33 +100,27 @@ def extract_entity(
return document
- def _handle_few_shot_prompting(self, document: Document) -> Any:
- assert self._llm is not None
- sub_elements = [document.elements[i] for i in range((min(self._num_of_elements, len(document.elements))))]
-
- prompt = EntityExtractorFewShotGuidancePrompt()
-
- entities = self._llm.generate(
- prompt_kwargs={
- "prompt": prompt,
- "entity": self._entity_name,
- "examples": self._prompt_template,
- "query": self._prompt_formatter(sub_elements),
- }
- )
- return entities
-
- def _handle_zero_shot_prompting(self, document: Document) -> Any:
+ def _handle_element_prompting(self, document: Document) -> Any:
assert self._llm is not None
sub_elements = [document.elements[i] for i in range((min(self._num_of_elements, len(document.elements))))]
-
- prompt = EntityExtractorZeroShotGuidancePrompt()
-
- entities = self._llm.generate(
- prompt_kwargs={"prompt": prompt, "entity": self._entity_name, "query": self._prompt_formatter(sub_elements)}
- )
-
- return entities
+ content = self._prompt_formatter(sub_elements, self._field)
+ if self._prompt is None:
+ prompt: Any = None
+ if self._prompt_template:
+ prompt = EntityExtractorFewShotGuidancePrompt()
+ else:
+ prompt = EntityExtractorZeroShotGuidancePrompt()
+ entities = self._llm.generate(
+ prompt_kwargs={
+ "prompt": prompt,
+ "entity": self._entity_name,
+ "query": content,
+ "examples": self._prompt_template,
+ }
+ )
+ return entities
+ else:
+ return self._get_entities(content)
def _handle_document_field_prompting(self, document: Document) -> Any:
assert self._llm is not None
@@ -137,13 +129,18 @@ def _handle_document_field_prompting(self, document: Document) -> Any:
value = str(document.field_to_value(self._field))
+ return self._get_entities(value)
+
+ def _get_entities(self, content: str, prompt: Optional[Union[list[dict], str]] = None):
+ assert self._llm is not None
+ prompt = prompt or self._prompt
+ assert prompt is not None, "No prompt found for entity extraction"
if isinstance(self._prompt, str):
- prompt = self._prompt + value
+ prompt = self._prompt + content
response = self._llm.generate(prompt_kwargs={"prompt": prompt}, llm_kwargs={})
else:
- messages = (self._prompt or []) + [{"role": "user", "content": value}]
+ messages = (self._prompt or []) + [{"role": "user", "content": content}]
response = self._llm.generate(prompt_kwargs={"messages": messages}, llm_kwargs={})
-
return response