diff --git a/backend/Pipfile b/backend/Pipfile
index 778105e..bc9520f 100644
--- a/backend/Pipfile
+++ b/backend/Pipfile
@@ -23,6 +23,7 @@
 httpx = "==0.27.0"
 pymupdf = "==1.24.2"
 pydantic = {extras = ["email"], version = "==2.7.1"}
 python-multipart = "==0.0.9"
+instructor = "*"
 
 [dev-packages]
diff --git a/backend/Pipfile.lock b/backend/Pipfile.lock
index 26e665f..d59b4fd 100644
--- a/backend/Pipfile.lock
+++ b/backend/Pipfile.lock
@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "65638357e5510f64031dc36b3f7368a06de6512a2f4459dfea5f802c371cc4e4"
+            "sha256": "b15b64af901673decb13323cf00df182362aa87ec84e5640a596dd09df68f6df"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -317,6 +317,14 @@
             "markers": "python_version >= '3.8'",
             "version": "==2.6.1"
         },
+        "docstring-parser": {
+            "hashes": [
+                "sha256:48ddc093e8b1865899956fcc03b03e66bb7240c310fac5af81814580c55bf682",
+                "sha256:d1679b86250d269d06a99670924d6bce45adc00b08069dae8c47d98e89b667a9"
+            ],
+            "markers": "python_version >= '3.6' and python_version < '4.0'",
+            "version": "==0.15"
+        },
         "email-validator": {
             "hashes": [
                 "sha256:200a70680ba08904be6d1eef729205cc0d687634399a5924d842533efb824b84",
@@ -578,6 +586,15 @@
             "markers": "python_version >= '3.5'",
             "version": "==3.7"
         },
+        "instructor": {
+            "hashes": [
+                "sha256:9c9a85e236054f0723560b1a671f512e5fc8b6c639c2f29849e0b2867e65a030",
+                "sha256:de0993d4fa58cea9e6a17322184172971ca41f4aaa517e3ce71038aa2c0703da"
+            ],
+            "index": "pypi",
+            "markers": "python_version >= '3.10' and python_version < '4.0'",
+            "version": "==1.0.3"
+        },
         "joblib": {
             "hashes": [
                 "sha256:1eb0dc091919cd384490de890cb5dfd538410a6d4b3b54eef09fb8c50b409b1c",
@@ -667,11 +684,11 @@
         },
         "llama-index-core": {
             "hashes": [
-                "sha256:0078c06d9143390e14c86a40e69716c88c7828533341559edd15e52249ede65a",
-                "sha256:215f7389dadb78f2df13c20312a3e1e03c41f23e3063907469c4bae67bfd458c"
+                "sha256:21b98b2c45e0c6b673aa505c7add1e8b730f472ad58d4572b909a34f4a22c36c",
+                "sha256:943114fb02dfe62fec5d882d749ad8adf113081aadcb0d4cb2c083b2c9052ed0"
             ],
             "markers": "python_version < '4.0' and python_full_version >= '3.8.1'",
-            "version": "==0.10.32"
+            "version": "==0.10.33"
         },
         "llama-index-readers-file": {
             "hashes": [
@@ -845,6 +862,14 @@
             "markers": "python_version >= '3.7'",
             "version": "==1.1.0"
         },
+        "markdown-it-py": {
+            "hashes": [
+                "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1",
+                "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"
+            ],
+            "markers": "python_version >= '3.8'",
+            "version": "==3.0.0"
+        },
         "marshmallow": {
             "hashes": [
                 "sha256:4e65e9e0d80fc9e609574b9983cf32579f305c718afb30d7233ab818571768c3",
@@ -853,6 +878,14 @@
             "markers": "python_version >= '3.8'",
             "version": "==3.21.1"
         },
+        "mdurl": {
+            "hashes": [
+                "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8",
+                "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"
+            ],
+            "markers": "python_version >= '3.7'",
+            "version": "==0.1.2"
+        },
         "multidict": {
             "hashes": [
                 "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556",
@@ -1311,6 +1344,14 @@
             "markers": "python_version >= '3.8'",
             "version": "==2.18.2"
         },
+        "pygments": {
+            "hashes": [
+                "sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c",
+                "sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367"
+            ],
+            "markers": "python_version >= '3.7'",
+            "version": "==2.17.2"
+        },
         "pymupdf": {
             "hashes": [
                 "sha256:007586883fbc8acb900d46aa95520aaeb8943d05a956b26c54053ddb58dbdd5f",
@@ -1575,6 +1616,14 @@
             "markers": "python_version >= '3.7'",
             "version": "==2.31.0"
         },
+        "rich": {
+            "hashes": [
+                "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222",
+                "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"
+            ],
+            "markers": "python_full_version >= '3.7.0'",
+            "version": "==13.7.1"
+        },
         "setuptools": {
             "hashes": [
                 "sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987",
@@ -1787,6 +1836,14 @@
             "markers": "python_version >= '3.6'",
             "version": "==1.8.1"
         },
+        "typer": {
+            "hashes": [
+                "sha256:aa6c4a4e2329d868b80ecbaf16f807f2b54e192209d7ac9dd42691d63f7a54eb",
+                "sha256:f714c2d90afae3a7929fcd72a3abb08df305e1ff61719381384211c4070af57f"
+            ],
+            "markers": "python_version >= '3.6'",
+            "version": "==0.9.4"
+        },
         "typing-extensions": {
             "hashes": [
                 "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0",
diff --git a/backend/app/db.py b/backend/app/db.py
index 0802674..06cf8e9 100644
--- a/backend/app/db.py
+++ b/backend/app/db.py
@@ -89,11 +89,20 @@ def update_urls(self, urls: list[URL]):
     def create_text_nodes(self, nodes: list[TextNode], user_id: str):
         text_nodes_to_persist = []
         text_node_chunks_to_persist = []
+        text_nodes_to_text_node_concepts_to_persist = []
         for node in nodes:
             text_node, text_node_chunks = node.to_persistence()
             text_node["user_id"] = user_id
+
             for chunk in text_node_chunks:
                 chunk["user_id"] = user_id
+
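+            # Build join-table rows linking this text node to each concept id.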
'3.7'", "version": "==2.31.0" }, + "rich": { + "hashes": [ + "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222", + "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432" + ], + "markers": "python_full_version >= '3.7.0'", + "version": "==13.7.1" + }, "setuptools": { "hashes": [ "sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987", @@ -1787,6 +1836,14 @@ "markers": "python_version >= '3.6'", "version": "==1.8.1" }, + "typer": { + "hashes": [ + "sha256:aa6c4a4e2329d868b80ecbaf16f807f2b54e192209d7ac9dd42691d63f7a54eb", + "sha256:f714c2d90afae3a7929fcd72a3abb08df305e1ff61719381384211c4070af57f" + ], + "markers": "python_version >= '3.6'", + "version": "==0.9.4" + }, "typing-extensions": { "hashes": [ "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0", diff --git a/backend/app/db.py b/backend/app/db.py index 0802674..06cf8e9 100644 --- a/backend/app/db.py +++ b/backend/app/db.py @@ -89,11 +89,19 @@ def update_urls(self, urls: list[URL]): def create_text_nodes(self, nodes: list[TextNode], user_id: str): text_nodes_to_persist = [] text_node_chunks_to_persist = [] + text_nodes_to_text_node_concepts_to_persist = [] for node in nodes: text_node, text_node_chunks = node.to_persistence() text_node["user_id"] = user_id + for chunk in text_node_chunks: chunk["user_id"] = user_id + + for concept_id in node.concept_ids: + text_nodes_to_text_node_concepts_to_persist.append( + {"text_node_id": node.id, "text_node_concept_id": concept_id} + ) + text_nodes_to_persist.append(text_node) text_node_chunks_to_persist.extend(text_node_chunks) @@ -101,6 +109,10 @@ def create_text_nodes(self, nodes: list[TextNode], user_id: str): self._client.table("text_node_chunks").insert( text_node_chunks_to_persist ).execute() + print(text_nodes_to_text_node_concepts_to_persist) + self._client.table("text_node_to_text_node_concepts").insert( + text_nodes_to_text_node_concepts_to_persist + ).execute() def get_urls_feed(self, user_id: str): result = ( @@ -123,3 +135,26 @@ def get_user_id_by_email_alias(self, app_email_alias: str): if len(result.data) != 1: return None return result.data[0]["id"] + + def get_text_node_concept_ids(self, concepts: list[str]) -> list[int]: + existing_concepts = ( + self._client.table("text_node_concepts") + .select("id, name") + .in_("name", concepts) + .execute() + .data + ) + + existing_concept_names = [c["name"] for c in existing_concepts] + new_concept_names = [ + {"name": c} for c in concepts if c not in existing_concept_names + ] + new_concepts = ( + self._client.table("text_node_concepts") + .insert(new_concept_names) + .execute() + .data + ) + concept_ids = [c["id"] for c in (existing_concepts + new_concepts)] + + return concept_ids diff --git a/backend/app/domain/node.py b/backend/app/domain/node.py index 276b606..a2aae24 100644 --- a/backend/app/domain/node.py +++ b/backend/app/domain/node.py @@ -18,7 +18,13 @@ def __init__(self, text: str, text_node_id: str) -> None: class TextNode: def __init__( - self, url_feed_id: str, url: str, title: str, text: str, summary: str + self, + url_feed_id: str, + url: str, + title: str, + text: str, + summary: str, + concept_ids: list[int], ) -> None: self.id = uuid7() self.url_feed_id = url_feed_id @@ -26,10 +32,15 @@ def __init__( self.title = title self.text = text self.summary = summary + self._concept_ids = concept_ids self.embedding = None self.chunks: list[TextNodeChunk] = [] self.create_title_if_missing() + @property + def concept_ids(self) -> list[int]: + return 
+        existing_concepts = (
+            self._client.table("text_node_concepts")
+            .select("id, name")
+            .in_("name", concepts)
+            .execute()
+            .data
+        )
+
+        existing_concept_names = [c["name"] for c in existing_concepts]
+        new_concept_names = [
+            {"name": c} for c in concepts if c not in existing_concept_names
+        ]
+        new_concepts = (
+            self._client.table("text_node_concepts")
+            .insert(new_concept_names)
+            .execute()
+            .data
+        )
+        concept_ids = [c["id"] for c in (existing_concepts + new_concepts)]
+
+        return concept_ids
diff --git a/backend/app/domain/node.py b/backend/app/domain/node.py
index 276b606..a2aae24 100644
--- a/backend/app/domain/node.py
+++ b/backend/app/domain/node.py
@@ -18,7 +18,13 @@ def __init__(self, text: str, text_node_id: str) -> None:
 
 class TextNode:
     def __init__(
-        self, url_feed_id: str, url: str, title: str, text: str, summary: str
+        self,
+        url_feed_id: str,
+        url: str,
+        title: str,
+        text: str,
+        summary: str,
+        concept_ids: list[int],
     ) -> None:
         self.id = uuid7()
         self.url_feed_id = url_feed_id
@@ -26,10 +32,15 @@ def __init__(
         self.title = title
         self.text = text
         self.summary = summary
+        self._concept_ids = concept_ids
         self.embedding = None
         self.chunks: list[TextNodeChunk] = []
         self.create_title_if_missing()
 
+    @property
+    def concept_ids(self) -> list[int]:
+        return self._concept_ids
+
     def create_chunks(self, chunker: NodeChunker) -> None:
         self.chunks = chunker.chunk(self.id, self.text)
 
diff --git a/backend/app/llm.py b/backend/app/llm.py
index 6ac19a3..ca0ec9d 100644
--- a/backend/app/llm.py
+++ b/backend/app/llm.py
@@ -1,10 +1,13 @@
 import json
 from typing import List, Generator, Any
+
+import instructor
 from openai import OpenAI
+from pydantic import BaseModel
 
 client = OpenAI()
 
-MODEL_16K = "gpt-3.5-turbo-16k"
+MODEL_16K = "gpt-3.5-turbo-0125"
 
 PROMPT_TEMPLATE = (
     "A question and context documents are provided below."
@@ -22,6 +25,23 @@
     "{question}"
 )
 
+EXTRACT_CONCEPTS_PROMPT_TEMPLATE = (
+    "Please extract ONLY THE MOST IMPORTANT concepts, entities & topics from the provided text. "
+    "DO NOT provide more than 8 results per text article. "
+    "MAKE SURE the concepts, entities & topics you select are relevant to the overall article, and are not ads or examples.\n"
+    "---------------------\n"
+    "TEXT:\n"
+    "{text}"
+)
+
+EXTRACT_CONCEPTS_SYSTEM_PROMPT_TEMPLATE = (
+    "You are an information extraction system. You respond to each message with a list of useful named entities. "
+    "Each named entity appears as one entry in a list. "
+    "Ignore unimportant entities, e.g., of type formatting, citations, and references. "
+    "The types of entities that we are most interested in are human, artificial object, spatio-temporal entity, corporate body, concrete object, talk, geographical feature, natural object, product, system. "
+    "IMPORTANT: you only include entities that appear in the text."
+)
+
 
 def format_chunks(chunks: List[dict]) -> str:
     result = ""
@@ -74,3 +94,37 @@ def summarise_text(text: str) -> str:
         temperature=0,
     )
     return result.choices[0].message.content
+
+
+class NodeConcepts(BaseModel):
+    """
+    Represents a list of key concepts and entities extracted from text.
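+    # instructor wraps the OpenAI client so the completion is parsed and
+    # validated straight into a NodeConcepts instance rather than raw text.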
+ """ + + concepts: list[str] + + +def extract_concepts(text: str) -> list[str]: + client = instructor.from_openai(OpenAI()) + + node_concepts = client.chat.completions.create( + model=MODEL_16K, + temperature=0, + response_model=NodeConcepts, + messages=[ + { + "role": "system", + "content": "EXTRACT_CONCEPTS_SYSTEM_PROMPT_TEMPLATE", + }, + { + "role": "user", + "content": EXTRACT_CONCEPTS_PROMPT_TEMPLATE.format(text=text), + }, + ], + ) + + concepts = [n.lower().replace(" ", "-") for n in node_concepts.concepts] + + return concepts diff --git a/backend/app/services/indexing.py b/backend/app/services/indexing.py index cc691b0..48cd821 100644 --- a/backend/app/services/indexing.py +++ b/backend/app/services/indexing.py @@ -1,5 +1,6 @@ from app.db import DB from app.llm import summarise_text +from app.llm import extract_concepts from app.utils import URLProcessor from app.utils import URLProcessingResult from app.utils import NodeChunker @@ -20,12 +21,16 @@ async def index(self, urls: list[URL], user_id: str): for idx, processed_url in enumerate(processed_urls): try: if isinstance(processed_url, URLProcessingResult): + concepts = extract_concepts(processed_url.text) + print(concepts) + concept_ids = db.get_text_node_concept_ids(concepts) text_node = TextNode( url=processed_url.url, url_feed_id=urls[idx].id, title=processed_url.title, text=processed_url.text, summary=summarise_text(processed_url.text), + concept_ids=concept_ids, ) text_node.create_chunks(NodeChunker) text_node.create_embeddings(NodeEmbedder) diff --git a/backend/supabase/migrations/20240428112033_add_text_node_concepts.sql b/backend/supabase/migrations/20240428112033_add_text_node_concepts.sql new file mode 100644 index 0000000..1a969f0 --- /dev/null +++ b/backend/supabase/migrations/20240428112033_add_text_node_concepts.sql @@ -0,0 +1,21 @@ +-- Create a table to store text node concepts. +create table + public.text_node_concepts ( + id serial primary key, + "name" varchar not null + ); + +alter table public.text_node_concepts enable row level security; + +-- Create a table to connect text nodes to text node concepts, with a many-to-many relationship. +create table + public.text_node_to_text_node_concepts ( + text_node_id uuid not null, + text_node_concept_id int not null, + primary key (text_node_id, text_node_concept_id), + foreign key (text_node_id) references public.text_nodes (id), + foreign key (text_node_concept_id) references public.text_node_concepts (id) +); + +alter table public.text_node_to_text_node_concepts enable row level security; +