Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable document model in sycamore.query + query-ui improvements #884

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions apps/query-ui/queryui/Sycamore_Query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import streamlit as st
from streamlit_ace import st_ace

from sycamore import ExecMode
from sycamore.executor import sycamore_ray_init
from sycamore.query.client import SycamoreQueryClient
from sycamore.query.logical_plan import LogicalPlan
Expand Down Expand Up @@ -62,6 +63,7 @@ def run_query():
s3_cache_path=st.session_state.llm_cache_dir,
trace_dir=st.session_state.trace_dir,
cache_dir=st.session_state.cache_dir,
exec_mode=ExecMode.LOCAL if st.session_state.local_mode else ExecMode.RAY,
)
with st.spinner("Generating plan..."):
t1 = time.time()
Expand Down Expand Up @@ -93,7 +95,7 @@ def run_query():

def main():
argparser = argparse.ArgumentParser()
argparser.add_argument("--external-ray", action="store_true", help="Use external Ray process.")
argparser.add_argument("--local-mode", action="store_true", help="Enable Sycamore local execution mode.")
argparser.add_argument(
"--index", help="OpenSearch index name to use. If specified, only this index will be queried."
)
Expand All @@ -111,16 +113,21 @@ def main():
if "llm_cache_dir" not in st.session_state:
st.session_state.llm_cache_dir = args.llm_cache_dir

if "local_mode" not in st.session_state:
st.session_state.local_mode = args.local_mode

if "trace_dir" not in st.session_state:
st.session_state.trace_dir = args.trace_dir

sycamore_ray_init(address="auto")
client = get_sycamore_query_client()
if not args.local_mode:
sycamore_ray_init(address="auto")
client = get_sycamore_query_client(exec_mode=ExecMode.LOCAL if args.local_mode else ExecMode.RAY)

st.title("Sycamore Query")
st.write(f"Query cache dir: `{st.session_state.cache_dir}`")
st.write(f"LLM cache dir: `{st.session_state.llm_cache_dir}`")
st.write(f"Trace dir: `{st.session_state.trace_dir}`")
st.write(f"Local mode: `{st.session_state.local_mode}`")

if not args.index:
with st.spinner("Loading indices..."):
Expand All @@ -141,11 +148,13 @@ def main():
show_schema(client, st.session_state.index)
with st.form("query_form"):
st.text_input("Query", key="query")
col1, col2 = st.columns(2)
col1, col2, col3 = st.columns(3)
with col1:
submitted = st.form_submit_button("Run query")
with col2:
st.toggle("Plan only", key="plan_only", value=False)
with col3:
st.toggle("Local mode", key="local_mode", value=False)
baitsguy marked this conversation as resolved.
Show resolved Hide resolved

if submitted:
run_query()
Expand Down
10 changes: 8 additions & 2 deletions apps/query-ui/queryui/configuration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

from sycamore import ExecMode
from sycamore.query.client import SycamoreQueryClient

"""
Expand All @@ -24,6 +25,11 @@


# NOTE(review): test_configuration_sha appears to pin a hash of this configuration file;
# if so, the expected hash must be refreshed whenever this block changes — confirm.
def get_sycamore_query_client(
    s3_cache_path: Optional[str] = None,
    trace_dir: Optional[str] = None,
    cache_dir: Optional[str] = None,
    exec_mode: ExecMode = ExecMode.RAY,
) -> SycamoreQueryClient:
    """Build the SycamoreQueryClient used by the query UI.

    Args:
        s3_cache_path: Passed through to the client as ``s3_cache_path``.
        trace_dir: Passed through to the client as ``trace_dir``.
        cache_dir: Passed through to the client as ``cache_dir``.
        exec_mode: Sycamore execution mode, passed to the client as
            ``sycamore_exec_mode``. Defaults to ``ExecMode.RAY``.

    Returns:
        A newly constructed SycamoreQueryClient.
    """
    return SycamoreQueryClient(
        s3_cache_path=s3_cache_path,
        trace_dir=trace_dir,
        cache_dir=cache_dir,
        sycamore_exec_mode=exec_mode,
    )
7 changes: 5 additions & 2 deletions apps/query-ui/queryui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def ray_init(**ray_args):

def main():
argparser = argparse.ArgumentParser()
argparser.add_argument("--exec-mode", type=str, default="ray", help="Configure Sycamore execution mode.")
baitsguy marked this conversation as resolved.
Show resolved Hide resolved
argparser.add_argument("--chat", action="store_true", help="Only show the chat demo pane.")
argparser.add_argument(
"--index", help="OpenSearch index name to use. If specified, only this index will be queried."
Expand Down Expand Up @@ -83,8 +84,10 @@ def main():
trace_dir = args.trace_dir
cmdline_args.extend(["--trace-dir", trace_dir])

ray_init()

if args.exec_mode == "ray":
baitsguy marked this conversation as resolved.
Show resolved Hide resolved
ray_init()
elif args.exec_mode == "local":
cmdline_args.extend(["--local-mode"])
while True:
print("Starting streamlit process...", flush=True)
# Streamlit requires the -- separator to separate streamlit arguments from script arguments.
Expand Down
8 changes: 7 additions & 1 deletion apps/query-ui/queryui/pages/Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def query_data_source(query: str, index: str) -> Tuple[Any, Optional[Any], Optio
s3_cache_path=st.session_state.llm_cache_dir,
trace_dir=st.session_state.trace_dir,
cache_dir=st.session_state.cache_dir,
sycamore_exec_mode=st.session_state.exec_mode,
)
with st.spinner("Generating plan..."):
plan = util.generate_plan(sqclient, query, index, examples=PLANNER_EXAMPLES)
Expand Down Expand Up @@ -310,6 +311,7 @@ def main():
argparser.add_argument(
"--index", help="OpenSearch index name to use. If specified, only this index will be queried."
)
argparser.add_argument("--local-mode", action="store_true", help="Enable Sycamore local execution mode.")
argparser.add_argument("--title", type=str, help="Title text.")
argparser.add_argument("--cache-dir", type=str, help="Query execution cache dir.")
argparser.add_argument("--llm-cache-dir", type=str, help="LLM query cache dir.")
Expand All @@ -331,6 +333,9 @@ def main():
if "use_cache" not in st.session_state:
st.session_state.use_cache = True

if "local_mode" not in st.session_state:
st.session_state.local_mode = args.local_mode

if "next_message_id" not in st.session_state:
st.session_state.next_message_id = 0

Expand All @@ -343,7 +348,8 @@ def main():
if "trace_dir" not in st.session_state:
st.session_state.trace_dir = os.path.join(os.getcwd(), "traces")

sycamore_ray_init(address="auto")
if not args.local_mode:
sycamore_ray_init(address="auto")
st.title("Sycamore Query Chat")
st.toggle("Use RAG only", key="rag_only")

Expand Down
2 changes: 1 addition & 1 deletion apps/query-ui/queryui/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ def test_configuration_sha():
sha = sha256(bytes).hexdigest()
# If the change was intentional, update the hash
# Think about whether you have to make the change since everyone with a custom config will need to update it
assert sha == "0d5a24a1edfb9e523814dc847fa34cdcde2d2e78aff03e3f4489755e12be2c54", f"hash mismatch got {sha}"
assert sha == "cf89894116604d4b002f2c5b6c9acf25982bf764310a9a50827608dcdc6b1b2c", f"hash mismatch got {sha}"
mdwelsh marked this conversation as resolved.
Show resolved Hide resolved
60 changes: 52 additions & 8 deletions apps/query-ui/queryui/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def docset_to_string(docset: DocSet, html: bool = True) -> str:
if isinstance(doc, MetadataDocument):
continue
if html:
retval += f"**{doc.properties.get('path')}** page: {doc.properties.get('page_number', 'meta')} \n"
retval += f"**{doc.properties.get('path')}** \n"

retval += "| Property | Value |\n"
retval += "|----------|-------|\n"
Expand All @@ -96,7 +96,8 @@ def docset_to_string(docset: DocSet, html: bool = True) -> str:
text_content = (
doc.text_representation[:NUM_TEXT_CHARS_GENERATE] if doc.text_representation is not None else None
)
retval += f'*..."{text_content}"...* <br><br>'
if text_content:
retval += f'*..."{text_content}"...* <br><br>'
else:
props_dict = doc.properties.get("entity", {})
props_dict.update({p: doc.properties[p] for p in set(doc.properties) - set(BASE_PROPS)})
Expand Down Expand Up @@ -242,30 +243,73 @@ def readdata(self):
data[col].append(row.get(col))
self.df = pd.DataFrame(data)

def show(self, node_descriptions: dict[str, str]):
    """Render the trace data for this node.

    Args:
        node_descriptions: Mapping from node id to a human-readable description.
            The entry for ``self.node_id`` is shown under the node header;
            'n/a' is shown when the entry is missing or empty.
    """
    st.subheader(f"Node {self.node_id}")
    st.markdown(f"*Description: {node_descriptions.get(self.node_id) or 'n/a'}*")
    if self.df is None or not len(self.df):
        st.write(":red[0] documents")
        st.write("No data.")
        return

    # Preferred columns (self.COLUMNS) come first, followed by any remaining columns.
    all_columns = list(self.df.columns)
    column_order = [c for c in self.COLUMNS if c in all_columns]
    column_order += [c for c in all_columns if c not in column_order]
    st.write(f"**{len(self.df)}** documents")
    st.dataframe(self.df, column_order=column_order)


class QueryMetadataTrace:
    """Reads and renders the query-level metadata stored in a metadata directory."""

    def __init__(self, metadata_dir: str):
        self.metadata_dir = metadata_dir
        self.query_plan = None
        self.readdata()

    def readdata(self):
        """Load ``query_plan.json`` from the metadata directory when that file exists."""
        plan_path = os.path.join(self.metadata_dir, "query_plan.json")
        if not os.path.isfile(plan_path):
            return
        self.query_plan = LogicalPlan.parse_file(plan_path)

    def get_node_to_description(self) -> dict[str, str]:
        """Map each plan node id (as a string) to its description; empty if no plan is loaded."""
        if self.query_plan is None:
            return {}
        return {str(node_id): node.description for node_id, node in self.query_plan.nodes.items()}

    def show(self):
        """Write the query text and the plan, or a placeholder when no plan was found."""
        if self.query_plan is None:
            st.write("No query plan found")
            return
        st.write(f"Query: {self.query_plan.query}")
        st.write(self.query_plan)


class QueryTrace:
    """Helper class used to read and display query traces."""

    def __init__(self, trace_dir: str):
        """Collect the per-node traces and the optional metadata found under trace_dir.

        Entries whose name contains "metadata" are read as query metadata; every
        other entry is treated as a node trace directory.
        """
        self.trace_dir = trace_dir
        self.node_traces = []
        # Stays None when the trace directory has no metadata entry, so show()
        # can always test it safely.
        self.metadata = None
        for entry in sorted(os.listdir(self.trace_dir)):
            if "metadata" in entry:
                self.metadata = QueryMetadataTrace(os.path.join(self.trace_dir, entry))
            else:
                self.node_traces.append(QueryNodeTrace(trace_dir, entry))

    def show(self):
        """Render node data and the query plan in two tabs."""
        node_descriptions = {}
        tab1, tab2 = st.tabs(["Node data", "Query plan"])
        if self.metadata:
            node_descriptions = self.metadata.get_node_to_description()
        with tab1:
            for node_trace in self.node_traces:
                node_trace.show(node_descriptions)
        with tab2:
            if self.metadata:
                self.metadata.show()
            else:
                st.write("No query plan found")


@st.fragment
Expand Down
15 changes: 15 additions & 0 deletions lib/sycamore/sycamore/data/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,21 @@ def __str__(self) -> str:
}
return json.dumps(d, indent=2)

def field_to_value(self, field: str) -> Any:
    """
    Look up the value stored at a field of this element.

    Args:
        field: The field name in dotted notation, where dots indicate nesting, e.g. properties.schema

    Returns:
        The value associated with the element field.
        Returns None if the field does not exist in the element.
    """
    # Imported at call time rather than at module level; presumably to avoid an
    # import cycle — TODO confirm.
    from sycamore.utils.nested import dotted_lookup

    value = dotted_lookup(self, field)
    return value


class ImageElement(Element):
def __init__(
Expand Down
22 changes: 17 additions & 5 deletions lib/sycamore/sycamore/query/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import structlog

import sycamore
from sycamore import Context
from sycamore import Context, ExecMode
from sycamore.context import OperationTypes
from sycamore.llms.openai import OpenAI, OpenAIModels
from sycamore.transforms.embed import SentenceTransformerEmbedder
from sycamore.transforms.query import OpenSearchQueryExecutor
from sycamore.transforms.similarity import HuggingFaceTransformersSimilarityScorer
from sycamore.utils.cache import cache_from_path
from sycamore.utils.import_utils import requires_modules

Expand Down Expand Up @@ -124,13 +126,15 @@ def __init__(
os_client_args: Optional[dict] = None,
trace_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
sycamore_exec_mode: ExecMode = ExecMode.RAY,
):
from opensearchpy import OpenSearch

self.s3_cache_path = s3_cache_path
self.os_config = os_config
self.trace_dir = trace_dir
self.cache_dir = cache_dir
self.sycamore_exec_mode = sycamore_exec_mode

# TODO: remove these assertions and simplify the code to get all customization via the
# context.
Expand All @@ -141,7 +145,7 @@ def __init__(
raise AssertionError("setting s3_cache_path requires context==None. See Notes in class documentation.")

os_client_args = os_client_args or DEFAULT_OS_CLIENT_ARGS
self.context = context or self._get_default_context(s3_cache_path, os_client_args)
self.context = context or self._get_default_context(s3_cache_path, os_client_args, sycamore_exec_mode)

assert self.context.params, "Could not find required params in Context"
self.os_client_args = self.context.params.get("opensearch", {}).get("os_client_args", os_client_args)
Expand Down Expand Up @@ -242,17 +246,25 @@ def default_text_embedder():
return SentenceTransformerEmbedder(batch_size=100, model_name="sentence-transformers/all-MiniLM-L6-v2")

@staticmethod
def _get_default_context(s3_cache_path, os_client_args, sycamore_exec_mode=ExecMode.RAY) -> Context:
    """Create the Sycamore context used when the caller does not supply one.

    Args:
        s3_cache_path: Path handed to ``cache_from_path`` to build the cache for each LLM.
        os_client_args: OpenSearch client arguments stored under the "opensearch" params.
        sycamore_exec_mode: Execution mode passed to ``sycamore.init``. Defaults to
            ``ExecMode.RAY`` so two-argument calls keep working.

    Returns:
        The Context returned by ``sycamore.init``.
    """
    context_params = {
        "default": {"llm": OpenAI(OpenAIModels.GPT_4O.value, cache=cache_from_path(s3_cache_path))},
        "opensearch": {
            "os_client_args": os_client_args,
            "text_embedder": SycamoreQueryClient.default_text_embedder(),
        },
        # Classification and extraction use GPT_4O_MINI rather than the default GPT_4O.
        OperationTypes.BINARY_CLASSIFIER: {
            "llm": OpenAI(OpenAIModels.GPT_4O_MINI.value, cache=cache_from_path(s3_cache_path))
        },
        OperationTypes.INFORMATION_EXTRACTOR: {
            "llm": OpenAI(OpenAIModels.GPT_4O_MINI.value, cache=cache_from_path(s3_cache_path))
        },
        OperationTypes.TEXT_SIMILARITY: {"similarity_scorer": HuggingFaceTransformersSimilarityScorer()},
    }
    return sycamore.init(params=context_params, exec_mode=sycamore_exec_mode)

def result_to_str(self, result: Any, max_docs: int = 100, max_chars_per_doc: int = 2500) -> str:
@staticmethod
def result_to_str(result: Any, max_docs: int = 100, max_chars_per_doc: int = 2500) -> str:
"""Convert a query result to a string.

Args:
Expand Down
12 changes: 12 additions & 0 deletions lib/sycamore/sycamore/query/execution/sycamore_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,24 @@ def get_code_string(self):
"""
return result

def _write_query_plan_to_trace_dir(self, plan: LogicalPlan, query_id: str):
assert self.trace_dir is not None, "Writing query_plan requires trace_dir to be set"
path = os.path.join(self.trace_dir, f"{query_id}/metadata/")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're already using os.path.join, please do

path = os.path.join(self.trace_dir, query_id, "metadata")

instead.

os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "query_plan.json"), "w") as f:
f.write(plan.json())
baitsguy marked this conversation as resolved.
Show resolved Hide resolved

def execute(self, plan: LogicalPlan, query_id: Optional[str] = None) -> Any:
try:
"""Execute a logical plan using Sycamore."""
if not query_id:
query_id = str(uuid.uuid4())
bind_contextvars(query_id=query_id)

log.info("Writing query plan to trace dir")
if self.trace_dir:
self._write_query_plan_to_trace_dir(plan, query_id)

log.info("Executing query")
assert isinstance(plan.result_node, LogicalOperator)
result = self.process_node(plan.result_node, query_id)
Expand Down
Loading
Loading