diff --git a/apps/query-ui/queryui/configuration.py b/apps/query-ui/queryui/configuration.py index 8c73abe99..e85432d2a 100644 --- a/apps/query-ui/queryui/configuration.py +++ b/apps/query-ui/queryui/configuration.py @@ -25,6 +25,11 @@ def get_sycamore_query_client( - s3_cache_path: Optional[str] = None, trace_dir: Optional[str] = None, cache_dir: Optional[str] = None, exec_mode: ExecMode = ExecMode.RAY + s3_cache_path: Optional[str] = None, + trace_dir: Optional[str] = None, + cache_dir: Optional[str] = None, + exec_mode: ExecMode = ExecMode.RAY, ) -> SycamoreQueryClient: - return SycamoreQueryClient(s3_cache_path=s3_cache_path, trace_dir=trace_dir, cache_dir=cache_dir, sycamore_exec_mode=exec_mode) + return SycamoreQueryClient( + s3_cache_path=s3_cache_path, trace_dir=trace_dir, cache_dir=cache_dir, sycamore_exec_mode=exec_mode + ) diff --git a/apps/query-ui/queryui/pages/Chat.py b/apps/query-ui/queryui/pages/Chat.py index 098df0ce0..9a4c4d530 100644 --- a/apps/query-ui/queryui/pages/Chat.py +++ b/apps/query-ui/queryui/pages/Chat.py @@ -350,7 +350,6 @@ def main(): if "trace_dir" not in st.session_state: st.session_state.trace_dir = os.path.join(os.getcwd(), "traces") - if not args.local_mode and args.external_ray: sycamore_ray_init(address="auto") st.title("Sycamore Query Chat") diff --git a/apps/query-ui/queryui/util.py b/apps/query-ui/queryui/util.py index f78e9a9ce..a564029a6 100644 --- a/apps/query-ui/queryui/util.py +++ b/apps/query-ui/queryui/util.py @@ -243,12 +243,12 @@ def readdata(self): data[col].append(row.get(col)) self.df = pd.DataFrame(data) - def show(self, node_descriptions: dict[str, str]): + def show(self, node_descriptions: dict[str, str]): """Render the trace data.""" st.subheader(f"Node {self.node_id}") st.markdown(f"*Description: {node_descriptions.get(self.node_id) or 'n/a'}*") if self.df is None or not len(self.df): - st.write(f":red[0] documents") + st.write(":red[0] documents") st.write("No data.") return @@ -258,6 +258,7 @@ def show(self, node_descriptions: dict[str, str]): st.write(f"**{len(self.df)}** documents") st.dataframe(self.df, column_order=column_order) + class QueryMetadataTrace: """Helper class to read and display metadata about a query.""" @@ -271,13 +272,13 @@ def readdata(self): if os.path.isfile(f): with open(f, "rb") as file: self.query_plan = pickle.load(file) - + def get_node_to_description(self) -> dict[str, str]: if self.query_plan is None: return {} result = dict() for node_id, node in self.query_plan.nodes.items(): - result[str(node_id)] = node.description + result[str(node_id)] = node.description return result def show(self):