Convert Entity Relationship Extraction in DSPy to using CoT (#44)
* Converted TypedPredictor to CoT and removed pydantic models using experimental DSPy in notebook (see the module sketch below)

* Fix entity extraction unittests after removing pydantic models and changing to CoT

* Add working random search fine tuning with better metrics

* Still cannot get MIPROv2 to work

* Working MIPROv2 with TypedChainOfThought

* Updated metrics to compute all relationships at once, updated prompt instructions that work for qwen2-7b

* Add updated notebooks with fine tuning using MIPROv2 and qwen2-7b as task model

* Add compiled model for generate dataset with updated unittests

---------
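For orientation, the core of this change is swapping dspy.TypedPredictor for dspy.TypedChainOfThought inside the extraction module. Below is a minimal sketch of what the converted module plausibly looks like; the Entity/Relationship field sets, the signature name, and the docstrings are assumptions inferred from how extract.py and metric.py consume the module in the diffs below, not the actual contents of nano_graphrag/entity_extraction/module.py:

import dspy
from pydantic import BaseModel


class Entity(BaseModel):
    entity_name: str
    entity_type: str  # assumed field; extract.py only shows entity_name
    description: str


class Relationship(BaseModel):
    # Fields inferred from extract.py (src_id/tgt_id) and the
    # AssessRelationships docstring in metric.py (description/weight/order).
    src_id: str
    tgt_id: str
    description: str
    weight: float
    order: int


class CombinedExtraction(dspy.Signature):
    """Extract entities and relationships from the input text."""

    input_text: str = dspy.InputField(desc="Raw text chunk to extract from.")
    entities: list[Entity] = dspy.OutputField(desc="Entities found in the text.")
    relationships: list[Relationship] = dspy.OutputField(desc="Relationships between the extracted entities.")


class TypedEntityRelationshipExtractor(dspy.Module):
    def __init__(self):
        super().__init__()
        # TypedChainOfThought adds an intermediate reasoning field before
        # the typed outputs, replacing the previous TypedPredictor.
        self.extractor = dspy.TypedChainOfThought(CombinedExtraction)

    def forward(self, input_text: str) -> dspy.Prediction:
        prediction = self.extractor(input_text=input_text)
        # extract.py now consumes plain dicts (entity['entity_name'], ...),
        # so unwrap the pydantic models before returning.
        return dspy.Prediction(
            entities=[e.dict() for e in prediction.entities],
            relationships=[r.dict() for r in prediction.relationships],
        )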

Co-authored-by: terence-gpt <[email protected]>
NumberChiffre authored Sep 23, 2024
1 parent 9e027a6 commit 5adf21f
Showing 10 changed files with 16,693 additions and 807 deletions.
11 changes: 6 additions & 5 deletions examples/benchmarks/dspy_entity.py
@@ -114,19 +114,20 @@ async def run_benchmark(text: str):
 system_prompt_dspy = f"{system_prompt} Time: {time.time()}."
 lm = dspy.OpenAI(
     model="deepseek-chat",
-    model_type="chat",
+    model_type="chat",
     api_provider="openai",
     api_key=os.environ["DEEPSEEK_API_KEY"],
     base_url=os.environ["DEEPSEEK_BASE_URL"],
-    system_prompt=system_prompt_dspy,
+    system_prompt=system_prompt,
     temperature=1.0,
     top_p=1,
-    max_tokens=4096
+    max_tokens=8192
 )
-dspy.settings.configure(lm=lm)
+dspy.settings.configure(lm=lm, experimental=True)
 graph_storage_with_dspy, time_with_dspy = await benchmark_entity_extraction(text, system_prompt_dspy, use_dspy=True)
 print(f"Execution time with DSPy-AI: {time_with_dspy:.2f} seconds")
 print_extraction_results(graph_storage_with_dspy)

+import pdb; pdb.set_trace()
 print("Running benchmark without DSPy-AI:")
 system_prompt_no_dspy = f"{system_prompt} Time: {time.time()}."
 graph_storage_without_dspy, time_without_dspy = await benchmark_entity_extraction(text, system_prompt_no_dspy, use_dspy=False)
14,525 changes: 14,146 additions & 379 deletions examples/finetune_entity_relationship_dspy.ipynb

Large diffs are not rendered by default.

2,062 changes: 2,062 additions & 0 deletions examples/generate_entity_relationship_dspy.ipynb

Large diffs are not rendered by default.
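Since neither notebook renders here, a rough sketch of the fine-tuning loop they implement follows: MIPROv2 optimizing the extractor with qwen2-7b as the task model and deepseek-chat proposing prompts. Endpoints, model names, and optimizer arguments are assumptions (MIPROv2's signature has shifted across DSPy releases), not code lifted from the notebooks:

import os
import dspy
from dspy.teleprompt import MIPROv2
from nano_graphrag.entity_extraction.module import TypedEntityRelationshipExtractor
from nano_graphrag.entity_extraction.metric import relationships_similarity_metric

# Hypothetical: qwen2-7b served behind a local OpenAI-compatible endpoint.
task_lm = dspy.OpenAI(model="qwen2-7b-instruct", model_type="chat",
                      base_url="http://localhost:8000/v1", api_key="EMPTY",
                      max_tokens=8192)
prompt_lm = dspy.OpenAI(model="deepseek-chat", model_type="chat",
                        api_key=os.environ["DEEPSEEK_API_KEY"],
                        base_url=os.environ["DEEPSEEK_BASE_URL"],
                        max_tokens=8192)
dspy.settings.configure(lm=task_lm, experimental=True)

optimizer = MIPROv2(metric=relationships_similarity_metric,
                    task_model=task_lm, prompt_model=prompt_lm)
# trainset/valset come from generate_dataset (see extract.py below).
compiled = optimizer.compile(TypedEntityRelationshipExtractor(),
                             trainset=trainset, valset=valset)
compiled.save("entity_relationship_extraction.json")  # hypothetical filename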

8 changes: 4 additions & 4 deletions examples/using_dspy_entity_extraction.py
@@ -138,14 +138,14 @@ def query():
"""
lm = dspy.OpenAI(
model="deepseek-chat",
model_type="chat",
model_type="chat",
api_provider="openai",
api_key=os.environ["DEEPSEEK_API_KEY"],
base_url=os.environ["DEEPSEEK_BASE_URL"],
system_prompt=system_prompt,
temperature=1.0,
top_p=1,
max_tokens=4096
max_tokens=8192
)
dspy.settings.configure(lm=lm)
dspy.settings.configure(lm=lm, experimental=True)
insert()
query()
79 changes: 55 additions & 24 deletions nano_graphrag/entity_extraction/extract.py
@@ -1,6 +1,7 @@
 from typing import Union
 import pickle
 import asyncio
+from openai import BadRequestError
 from collections import defaultdict
 import dspy
 from nano_graphrag._storage import BaseGraphStorage
@@ -11,40 +12,67 @@
 )
 from nano_graphrag.prompt import PROMPTS
 from nano_graphrag._utils import logger, compute_mdhash_id
-from nano_graphrag.entity_extraction.module import EntityRelationshipExtractor
+from nano_graphrag.entity_extraction.module import TypedEntityRelationshipExtractor
 from nano_graphrag._op import _merge_edges_then_upsert, _merge_nodes_then_upsert


 async def generate_dataset(
     chunks: dict[str, TextChunkSchema],
     filepath: str,
-    save_dataset: bool = True
+    save_dataset: bool = True,
+    global_config: dict = {}
 ) -> list[dspy.Example]:
-    entity_extractor = EntityRelationshipExtractor()
+    entity_extractor = TypedEntityRelationshipExtractor()
+
+    if global_config.get("use_compiled_dspy_entity_relationship", False):
+        entity_extractor.load(global_config["entity_relationship_module_path"])
+
     ordered_chunks = list(chunks.items())
+    already_processed = 0
+    already_entities = 0
+    already_relations = 0

     async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]) -> dspy.Example:
+        nonlocal already_processed, already_entities, already_relations
         chunk_dp = chunk_key_dp[1]
         content = chunk_dp["content"]
-        prediction = await asyncio.to_thread(
-            entity_extractor, input_text=content
-        )
+        try:
+            prediction = await asyncio.to_thread(
+                entity_extractor, input_text=content
+            )
+            entities, relationships = prediction.entities, prediction.relationships
+        except BadRequestError as e:
+            logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
+            entities, relationships = [], []
         example = dspy.Example(
             input_text=content,
-            entities=prediction.entities,
-            relationships=prediction.relationships
+            entities=entities,
+            relationships=relationships
         ).with_inputs("input_text")
+        already_entities += len(entities)
+        already_relations += len(relationships)
+        already_processed += 1
+        now_ticks = PROMPTS["process_tickers"][
+            already_processed % len(PROMPTS["process_tickers"])
+        ]
+        print(
+            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
+            end="",
+            flush=True,
+        )
         return example

     examples = await asyncio.gather(
         *[_process_single_content(c) for c in ordered_chunks]
     )
+    filtered_examples = [example for example in examples if len(example.entities) > 0 and len(example.relationships) > 0]
+    num_filtered_examples = len(examples) - len(filtered_examples)
     if save_dataset:
         with open(filepath, 'wb') as f:
-            pickle.dump(examples, f)
-            logger.info(f"Saved {len(examples)} examples with keys: {examples[0].keys()}")
+            pickle.dump(filtered_examples, f)
+            logger.info(f"Saved {len(filtered_examples)} examples with keys: {filtered_examples[0].keys()}, filtered {num_filtered_examples} examples")

-    return examples
+    return filtered_examples


 async def extract_entities_dspy(
@@ -53,7 +81,7 @@ async def extract_entities_dspy(
     entity_vdb: BaseVectorStorage,
     global_config: dict,
 ) -> Union[BaseGraphStorage, None]:
-    entity_extractor = EntityRelationshipExtractor()
+    entity_extractor = TypedEntityRelationshipExtractor()

     if global_config.get("use_compiled_dspy_entity_relationship", False):
         entity_extractor.load(global_config["entity_relationship_module_path"])
@@ -68,23 +96,26 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
     chunk_key = chunk_key_dp[0]
     chunk_dp = chunk_key_dp[1]
     content = chunk_dp["content"]
-    prediction = await asyncio.to_thread(
-        entity_extractor, input_text=content
-    )

+    try:
+        prediction = await asyncio.to_thread(
+            entity_extractor, input_text=content
+        )
+        entities, relationships = prediction.entities, prediction.relationships
+    except BadRequestError as e:
+        logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
+        entities, relationships = [], []

     maybe_nodes = defaultdict(list)
     maybe_edges = defaultdict(list)

-    for entity in prediction.entities.context:
-        entity_dict = entity.dict()
-        entity_dict["source_id"] = chunk_key
-        maybe_nodes[entity_dict['entity_name']].append(entity_dict)
+    for entity in entities:
+        entity["source_id"] = chunk_key
+        maybe_nodes[entity['entity_name']].append(entity)
         already_entities += 1

-    for relationship in prediction.relationships.context:
-        relationship_dict = relationship.dict()
-        relationship_dict["source_id"] = chunk_key
-        maybe_edges[(relationship_dict['src_id'], relationship_dict['tgt_id'])].append(relationship_dict)
+    for relationship in relationships:
+        relationship["source_id"] = chunk_key
+        maybe_edges[(relationship['src_id'], relationship['tgt_id'])].append(relationship)
         already_relations += 1

     already_processed += 1
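A hedged usage sketch of the new generate_dataset flow — building the training pickle and splitting it for the optimizers. The path, split ratio, and chunk source are hypothetical; only the pickle format (a filtered list[dspy.Example] keyed by input_text/entities/relationships) comes from the code above:

import pickle
import dspy

DATASET_PATH = "entity_relationship_dataset.pkl"  # hypothetical path

# generate_dataset(chunks, DATASET_PATH, global_config=...) pickles a
# list[dspy.Example], dropping any example with no entities or relationships.
with open(DATASET_PATH, "rb") as f:
    examples: list[dspy.Example] = pickle.load(f)

# Simple holdout split for the teleprompters; the 80/20 ratio is an assumption.
split = int(0.8 * len(examples))
trainset, valset = examples[:split], examples[split:]
print(f"{len(trainset)} train / {len(valset)} val examples")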
63 changes: 35 additions & 28 deletions nano_graphrag/entity_extraction/metric.py
@@ -1,35 +1,42 @@
 import dspy
 import numpy as np


-class AssessRelationship(dspy.Signature):
-    """Assess the similarity of two relationships."""
-    gold_relationship = dspy.InputField()
-    predicted_relationship = dspy.InputField()
-    similarity_score = dspy.OutputField(desc="Similarity score between 0 and 1")
-
-
-def relationship_similarity_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
-    similarity_scores = []
-
-    for gold_rel, pred_rel in zip(gold.relationships.context, pred.relationships.context):
-        assessment = dspy.Predict(AssessRelationship)(
-            gold_relationship=gold_rel,
-            predicted_relationship=pred_rel
-        )
-
-        try:
-            score = float(assessment.similarity_score)
-            similarity_scores.append(score)
-        except ValueError:
-            similarity_scores.append(0.0)
-
-    return np.mean(similarity_scores) if similarity_scores else 0.0
+from nano_graphrag.entity_extraction.module import Relationship
+
+
+class AssessRelationships(dspy.Signature):
+    """
+    Assess the similarity between gold and predicted relationships:
+    1. Match relationships based on src_id and tgt_id pairs, allowing for slight variations in entity names.
+    2. For matched pairs, compare:
+        a) Description similarity (semantic meaning)
+        b) Weight similarity
+        c) Order similarity
+    3. Consider unmatched relationships as penalties.
+    4. Aggregate scores, accounting for precision and recall.
+    5. Return a final similarity score between 0 (no similarity) and 1 (perfect match).
+
+    Key considerations:
+    - Prioritize matching based on entity pairs over exact string matches.
+    - Use semantic similarity for descriptions rather than exact matches.
+    - Weight the importance of different aspects (e.g., entity matching, description, weight, order).
+    - Balance the impact of matched and unmatched relationships in the final score.
+    """
+
+    gold_relationships: list[Relationship] = dspy.InputField(desc="The gold-standard relationships to compare against.")
+    predicted_relationships: list[Relationship] = dspy.InputField(desc="The predicted relationships to compare against the gold-standard relationships.")
+    similarity_score: float = dspy.OutputField(desc="Similarity score between 0 and 1, with 1 being the highest similarity.")
+
+
+def relationships_similarity_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
+    model = dspy.TypedChainOfThought(AssessRelationships)
+    gold_relationships = [Relationship(**item) for item in gold['relationships']]
+    predicted_relationships = [Relationship(**item) for item in pred['relationships']]
+    similarity_score = float(model(gold_relationships=gold_relationships, predicted_relationships=predicted_relationships).similarity_score)
+    return similarity_score


 def entity_recall_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
-    true_set = set(item.entity_name for item in gold.entities.context)
-    pred_set = set(item.entity_name for item in pred.entities.context)
+    true_set = set(item['entity_name'] for item in gold['entities'])
+    pred_set = set(item['entity_name'] for item in pred['entities'])
     true_positives = len(pred_set.intersection(true_set))
     false_negatives = len(true_set - pred_set)
     recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
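Taken together with the extractor changes, these metrics are what the optimizers score against. A hedged sketch of combining them into a single objective and wiring the compiled program back into extraction — the 50/50 weighting and both file paths are illustrative assumptions, not what the notebooks necessarily do:

import dspy
from nano_graphrag.entity_extraction.metric import (
    entity_recall_metric,
    relationships_similarity_metric,
)


def combined_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    # Hypothetical weighting of structural recall vs. LLM-judged similarity.
    recall = entity_recall_metric(gold, pred, trace)
    similarity = relationships_similarity_metric(gold, pred, trace)
    return 0.5 * recall + 0.5 * similarity


# At query time, nano-graphrag loads the compiled program when global_config
# enables it (see the load() calls in extract.py above):
global_config = {
    "use_compiled_dspy_entity_relationship": True,
    "entity_relationship_module_path": "entity_relationship_extraction.json",  # hypothetical
}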