Skip to content

Commit

Permalink
refactor: dspy extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
gusye1234 committed Sep 24, 2024
1 parent 9f7efdc commit 9245a4f
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 105 deletions.
79 changes: 23 additions & 56 deletions nano_graphrag/entity_extraction/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,54 +16,35 @@


async def generate_dataset(
<<<<<<< HEAD
chunks: dict[str, TextChunkSchema],
filepath: str,
save_dataset: bool = True,
global_config: dict = {}
=======
chunks: dict[str, TextChunkSchema], filepath: str, save_dataset: bool = True
>>>>>>> 0a9d8a9 (refactor: make _storage a folder)
global_config: dict = {},
) -> list[dspy.Example]:
entity_extractor = TypedEntityRelationshipExtractor()

if global_config.get("use_compiled_dspy_entity_relationship", False):
entity_extractor.load(global_config["entity_relationship_module_path"])

ordered_chunks = list(chunks.items())
already_processed = 0
already_entities = 0
already_relations = 0

<<<<<<< HEAD
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]) -> dspy.Example:
async def _process_single_content(
chunk_key_dp: tuple[str, TextChunkSchema]
) -> dspy.Example:
nonlocal already_processed, already_entities, already_relations
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]
try:
prediction = await asyncio.to_thread(
entity_extractor, input_text=content
)
prediction = await asyncio.to_thread(entity_extractor, input_text=content)
entities, relationships = prediction.entities, prediction.relationships
except BadRequestError as e:
logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
entities, relationships = [], []
example = dspy.Example(
input_text=content,
entities=entities,
relationships=relationships
=======
async def _process_single_content(
chunk_key_dp: tuple[str, TextChunkSchema]
) -> dspy.Example:
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]
prediction = await asyncio.to_thread(entity_extractor, input_text=content)
example = dspy.Example(
input_text=content,
entities=prediction.entities,
relationships=prediction.relationships,
>>>>>>> 0a9d8a9 (refactor: make _storage a folder)
input_text=content, entities=entities, relationships=relationships
).with_inputs("input_text")
already_entities += len(entities)
already_relations += len(relationships)
Expand All @@ -81,12 +62,18 @@ async def _process_single_content(
examples = await asyncio.gather(
*[_process_single_content(c) for c in ordered_chunks]
)
filtered_examples = [example for example in examples if len(example.entities) > 0 and len(example.relationships) > 0]
filtered_examples = [
example
for example in examples
if len(example.entities) > 0 and len(example.relationships) > 0
]
num_filtered_examples = len(examples) - len(filtered_examples)
if save_dataset:
with open(filepath, 'wb') as f:
with open(filepath, "wb") as f:
pickle.dump(filtered_examples, f)
logger.info(f"Saved {len(filtered_examples)} examples with keys: {filtered_examples[0].keys()}, filtered {num_filtered_examples} examples")
logger.info(
f"Saved {len(filtered_examples)} examples with keys: {filtered_examples[0].keys()}, filtered {num_filtered_examples} examples"
)

return filtered_examples

Expand All @@ -112,46 +99,26 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]
<<<<<<< HEAD
try:
prediction = await asyncio.to_thread(
entity_extractor, input_text=content
)
prediction = await asyncio.to_thread(entity_extractor, input_text=content)
entities, relationships = prediction.entities, prediction.relationships
except BadRequestError as e:
logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
entities, relationships = [], []

maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)

for entity in entities:
entity["source_id"] = chunk_key
maybe_nodes[entity['entity_name']].append(entity)
maybe_nodes[entity["entity_name"]].append(entity)
already_entities += 1

for relationship in relationships:
relationship["source_id"] = chunk_key
maybe_edges[(relationship['src_id'], relationship['tgt_id'])].append(relationship)
=======
prediction = await asyncio.to_thread(entity_extractor, input_text=content)

maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)

for entity in prediction.entities.context:
entity_dict = entity.dict()
entity_dict["source_id"] = chunk_key
maybe_nodes[entity_dict["entity_name"]].append(entity_dict)
already_entities += 1

for relationship in prediction.relationships.context:
relationship_dict = relationship.dict()
relationship_dict["source_id"] = chunk_key
maybe_edges[
(relationship_dict["src_id"], relationship_dict["tgt_id"])
].append(relationship_dict)
>>>>>>> 0a9d8a9 (refactor: make _storage a folder)
maybe_edges[(relationship["src_id"], relationship["tgt_id"])].append(
relationship
)
already_relations += 1

already_processed += 1
Expand Down
45 changes: 32 additions & 13 deletions nano_graphrag/entity_extraction/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,42 @@ class AssessRelationships(dspy.Signature):
- Balance the impact of matched and unmatched relationships in the final score.
"""

gold_relationships: list[Relationship] = dspy.InputField(desc="The gold-standard relationships to compare against.")
predicted_relationships: list[Relationship] = dspy.InputField(desc="The predicted relationships to compare against the gold-standard relationships.")
similarity_score: float = dspy.OutputField(desc="Similarity score between 0 and 1, with 1 being the highest similarity.")


def relationships_similarity_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
gold_relationships: list[Relationship] = dspy.InputField(
desc="The gold-standard relationships to compare against."
)
predicted_relationships: list[Relationship] = dspy.InputField(
desc="The predicted relationships to compare against the gold-standard relationships."
)
similarity_score: float = dspy.OutputField(
desc="Similarity score between 0 and 1, with 1 being the highest similarity."
)


def relationships_similarity_metric(
gold: dspy.Example, pred: dspy.Prediction, trace=None
) -> float:
model = dspy.TypedChainOfThought(AssessRelationships)
gold_relationships = [Relationship(**item) for item in gold['relationships']]
predicted_relationships = [Relationship(**item) for item in pred['relationships']]
similarity_score = float(model(gold_relationships=gold_relationships, predicted_relationships=predicted_relationships).similarity_score)
gold_relationships = [Relationship(**item) for item in gold["relationships"]]
predicted_relationships = [Relationship(**item) for item in pred["relationships"]]
similarity_score = float(
model(
gold_relationships=gold_relationships,
predicted_relationships=predicted_relationships,
).similarity_score
)
return similarity_score


def entity_recall_metric(
    gold: dspy.Example, pred: dspy.Prediction, trace=None
) -> float:
    """Compute entity-level recall of predicted entities against the gold set.

    Recall = |gold ∩ predicted| / |gold|, where entities are matched by
    their ``entity_name`` field only.

    Args:
        gold: Example carrying the gold ``entities`` list of dicts.
        pred: Prediction carrying the predicted ``entities`` list of dicts.
        trace: Unused; accepted for dspy metric-signature compatibility.

    Returns:
        Recall in [0, 1]; 0.0 when there are no gold entities (avoids
        division by zero and matches the declared float return type).
    """
    true_set = {item["entity_name"] for item in gold["entities"]}
    pred_set = {item["entity_name"] for item in pred["entities"]}
    true_positives = len(pred_set & true_set)
    false_negatives = len(true_set - pred_set)
    # Hoist the denominator: it was previously recomputed three times.
    total = true_positives + false_negatives
    return true_positives / total if total > 0 else 0.0
Loading

0 comments on commit 9245a4f

Please sign in to comment.