Skip to content

Commit

Permalink
Merge pull request #5575 from pathwaycom/berke/vs-metadata-proc
Browse files Browse the repository at this point in the history
Berke/vs metadata proc

GitOrigin-RevId: 8c45097bd7046d04620ef9822f8cf94e6bfa191c
  • Loading branch information
berkecanrizai authored and Manul from Pathway committed Feb 4, 2024
1 parent ce37acd commit 02cfbbc
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions python/pathway/xpacks/llm/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,23 @@ def __init__(
embedder: Callable[[str], list[float]],
parser: Callable[[bytes], list[tuple[str, dict]]] | None = None,
splitter: Callable[[str], list[tuple[str, dict]]] | None = None,
doc_post_processors: list[Callable[[str, dict], tuple[str, dict]]]
| None = None,
):
self.docs = docs

self.parser: Callable[[bytes], list[tuple[str, dict]]] = _unwrap_udf(
parser if parser is not None else pathway.xpacks.llm.parsers.ParseUtf8()
)
self.doc_post_processors = []

if doc_post_processors:
self.doc_post_processors = [
_unwrap_udf(processor)
for processor in doc_post_processors
if processor is not None
]

self.splitter = _unwrap_udf(
splitter
if splitter is not None
Expand Down Expand Up @@ -131,11 +142,25 @@ def parse_doc(data: bytes, metadata) -> list[pw.Json]:
pw.this.data
)

@pw.udf
def post_proc_docs(data_json: pw.Json) -> pw.Json:
    """Run every configured document post-processor over one parsed document.

    The incoming Json wraps a dict with ``"text"`` and ``"metadata"`` keys
    (the shape produced by the parsing UDF upstream). Each callable in
    ``self.doc_post_processors`` is invoked as ``processor(text, metadata)``
    and must return the (possibly modified) ``(text, metadata)`` pair,
    which is fed to the next processor in registration order.

    Returns a plain dict with the same two keys; the ``pw.Json`` return
    annotation lets the engine wrap it back into a Json value.
    """
    data: dict = data_json.value  # type:ignore
    text = data["text"]
    metadata = data["metadata"]

    # Thread (text, metadata) through the processors in order; when no
    # post-processors were configured the loop is a no-op and the input
    # is returned unchanged.
    for processor in self.doc_post_processors:
        text, metadata = processor(text, metadata)

    return dict(text=text, metadata=metadata)  # type: ignore

# Replace the parsed documents column with its post-processed form.
# NOTE(review): this select runs even when doc_post_processors is empty;
# presumably cheap, since the UDF then just round-trips the dict.
parsed_docs = parsed_docs.select(data=post_proc_docs(pw.this.data))

@pw.udf
def split_doc(data_json: pw.Json) -> list[pw.Json]:
data: dict = data_json.value # type:ignore
text = data["text"]
metadata = data["metadata"]

rets = self.splitter(text)
return [
dict(text=ret[0], metadata={**metadata, **ret[1]}) # type:ignore
Expand Down Expand Up @@ -269,6 +294,7 @@ def format_inputs(
metadata_filter, m.value, options=_knn_lsh._glob_options
)
]

return metadatas

input_results = input_queries.join_left(all_metas, id=input_queries.id).select(
Expand Down Expand Up @@ -336,9 +362,8 @@ def run_server(
port,
threaded: bool = False,
with_cache: bool = True,
cache_backend: (
pw.persistence.Backend | None
) = pw.persistence.Backend.filesystem("./Cache"),
cache_backend: pw.persistence.Backend
| None = pw.persistence.Backend.filesystem("./Cache"),
):
"""
Builds the document processing pipeline and runs it.
Expand Down

0 comments on commit 02cfbbc

Please sign in to comment.