Skip to content

Commit

Permalink
Added unit tests for entity extraction; added an `order` field for relationships
Browse files: browse the repository at this point in the history
  • Loading branch information
NumberChiffre committed Sep 12, 2024
1 parent f83057d commit 1268afa
Show file tree
Hide file tree
Showing 7 changed files with 430 additions and 93 deletions.
2 changes: 1 addition & 1 deletion examples/benchmarks/dspy_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def run_benchmark(text: str, system_prompt: str):
base_url=os.environ["DEEPSEEK_BASE_URL"],
system_prompt=system_prompt,
temperature=0.3,
top_p=1.0,
top_p=1,
max_tokens=4096
)
dspy.settings.configure(lm=lm)
Expand Down
210 changes: 195 additions & 15 deletions examples/finetune_entity_relationship_dspy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,19 @@
"outputs": [],
"source": [
"import dspy\n",
"from dspy.teleprompt.mipro_optimizer_v2 import MIPROv2\n",
"from dspy.evaluate import Evaluate\n",
"import asyncio\n",
"import os\n",
"import numpy as np\n",
"from dotenv import load_dotenv\n",
"from datasets import load_dataset\n",
"import logging\n",
"\n",
"from nano_graphrag._utils import compute_mdhash_id\n",
"from nano_graphrag.entity_extraction.extract import generate_dataset, compile_model\n",
"from nano_graphrag.entity_extraction.module import EntityRelationshipExtractor"
"from nano_graphrag.entity_extraction.extract import generate_dataset\n",
"from nano_graphrag.entity_extraction.module import EntityRelationshipExtractor\n",
"from nano_graphrag.entity_extraction.metric import relationship_recall_metric, relationship_similarity_metric, entity_recall_metric"
]
},
{
Expand All @@ -39,7 +43,9 @@
"load_dotenv()\n",
"\n",
"logging.basicConfig(level=logging.WARNING)\n",
"logging.getLogger(\"nano-graphrag\").setLevel(logging.DEBUG)"
"logging.getLogger(\"nano-graphrag\").setLevel(logging.DEBUG)\n",
"\n",
"np.random.seed(1337)"
]
},
{
Expand All @@ -64,6 +70,12 @@
" top_p=1.0,\n",
" max_tokens=4096\n",
")\n",
"llama_lm = dspy.OllamaLocal(\n",
" model=\"llama3.1\", \n",
" model_type=\"chat\",\n",
" system=system_prompt,\n",
" max_tokens=4096\n",
")\n",
"dspy.settings.configure(lm=lm)"
]
},
Expand All @@ -75,10 +87,47 @@
"source": [
"os.makedirs(WORKING_DIR, exist_ok=True)\n",
"train_len = 20\n",
"entity_relationship_dataset_path = os.path.join(WORKING_DIR, \"entity_relationship_extraction_news.pkl\")\n",
"val_len = 2\n",
"dev_len = 3\n",
"entity_relationship_trainset_path = os.path.join(WORKING_DIR, \"entity_relationship_extraction_news_trainset.pkl\")\n",
"entity_relationship_valset_path = os.path.join(WORKING_DIR, \"entity_relationship_extraction_news_valset.pkl\")\n",
"entity_relationship_devset_path = os.path.join(WORKING_DIR, \"entity_relationship_extraction_news_devset.pkl\")\n",
"entity_relationship_module_path = os.path.join(WORKING_DIR, \"entity_relationship_extraction_news.json\")\n",
"ds = load_dataset(\"ashraq/financial-news-articles\")\n",
"train_data = ds['train'][-train_len:]"
"fin_news = load_dataset(\"ashraq/financial-news-articles\")\n",
"cnn_news = load_dataset(\"AyoubChLin/CNN_News_Articles_2011-2022\")\n",
"fin_shuffled_indices = np.random.permutation(len(fin_news['train']))\n",
"cnn_train_shuffled_indices = np.random.permutation(len(cnn_news['train']))\n",
"cnn_test_shuffled_indices = np.random.permutation(len(cnn_news['test']))\n",
"train_data = cnn_news['train'].select(cnn_train_shuffled_indices[:train_len])\n",
"val_data = cnn_news['test'].select(cnn_test_shuffled_indices[:val_len])\n",
"dev_data = fin_news['train'].select(fin_shuffled_indices[:dev_len])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data['text'][:2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val_data['text']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dev_data['text'][:2]"
]
},
{
Expand All @@ -87,7 +136,8 @@
"metadata": {},
"outputs": [],
"source": [
"train_data['text'][-1]"
"train_chunks = {compute_mdhash_id(text, prefix=f\"chunk-\"): {\"content\": text} for text in train_data[\"text\"]}\n",
"trainset = asyncio.run(generate_dataset(chunks=train_chunks, filepath=entity_relationship_trainset_path))"
]
},
{
Expand All @@ -96,8 +146,10 @@
"metadata": {},
"outputs": [],
"source": [
"chunks = {compute_mdhash_id(text, prefix=f\"chunk-\"): {\"content\": text} for text in train_data[\"text\"]}\n",
"dataset = asyncio.run(generate_dataset(chunks=chunks, filepath=entity_relationship_dataset_path))"
"for example in trainset:\n",
" for relationship in example.relationships.context:\n",
" if relationship.order == 2:\n",
" print(relationship)"
]
},
{
Expand All @@ -106,7 +158,10 @@
"metadata": {},
"outputs": [],
"source": [
"dataset[0]"
"for example in trainset:\n",
" for relationship in example.relationships.context:\n",
" if relationship.order == 3:\n",
" print(relationship)"
]
},
{
Expand All @@ -115,7 +170,93 @@
"metadata": {},
"outputs": [],
"source": [
"dataset[0].relationships.context"
"trainset[0].relationships.context[:2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val_chunks = {compute_mdhash_id(text, prefix=f\"chunk-\"): {\"content\": text} for text in val_data[\"text\"]}\n",
"valset = asyncio.run(generate_dataset(chunks=val_chunks, filepath=entity_relationship_valset_path))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"valset[0].relationships.context[:2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for example in valset:\n",
" for relationship in example.relationships.context:\n",
" if relationship.order == 2:\n",
" print(relationship)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for example in valset:\n",
" for relationship in example.relationships.context:\n",
" if relationship.order == 3:\n",
" print(relationship)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dev_chunks = {compute_mdhash_id(text, prefix=f\"chunk-\"): {\"content\": text} for text in dev_data[\"text\"]}\n",
"devset = asyncio.run(generate_dataset(chunks=dev_chunks, filepath=entity_relationship_devset_path))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"devset[0].relationships.context[:2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for example in devset:\n",
" for relationship in example.relationships.context:\n",
" if relationship.order == 2:\n",
" print(relationship)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for example in devset:\n",
" for relationship in example.relationships.context:\n",
" if relationship.order == 3:\n",
" print(relationship)"
]
},
{
Expand All @@ -134,13 +275,52 @@
"metadata": {},
"outputs": [],
"source": [
"optimized_model = compile_model(\n",
" model=model,\n",
" dataset_path=entity_relationship_dataset_path,\n",
" module_path=entity_relationship_module_path\n",
"metrics = [relationship_recall_metric, entity_recall_metric, relationship_similarity_metric]\n",
"for metric in metrics:\n",
" evaluate = Evaluate(\n",
" devset=devset, \n",
" metric=metric, \n",
" num_threads=os.cpu_count(), \n",
" display_progress=True,\n",
" display_table=5,\n",
" )\n",
" evaluate(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimizer = MIPROv2(\n",
" prompt_model=lm,\n",
" task_model=llama_lm,\n",
" metric=relationship_recall_metric,\n",
" init_temperature=0.7,\n",
" num_candidates=4\n",
")\n",
"optimized_model = optimizer.compile(model, trainset=trainset, valset=valset, num_batches=5, max_labeled_demos=5, max_bootstrapped_demos=3)\n",
"optimized_model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"metrics = [relationship_recall_metric, entity_recall_metric, relationship_similarity_metric]\n",
"for metric in metrics:\n",
" evaluate = Evaluate(\n",
" devset=devset, \n",
" metric=metric, \n",
" num_threads=os.cpu_count(), \n",
" display_progress=True,\n",
" display_table=5,\n",
" )\n",
" evaluate(optimized_model)"
]
}
],
"metadata": {
Expand Down
4 changes: 4 additions & 0 deletions nano_graphrag/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,17 @@ async def _merge_edges_then_upsert(
already_weights = []
already_source_ids = []
already_description = []
already_order = []
if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
already_weights.append(already_edge["weight"])
already_source_ids.extend(
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
)
already_description.append(already_edge["description"])
already_order.append(already_edge.get("order", 1))

order = min([dp["order"] for dp in edges_data] + already_order)
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in edges_data] + already_description))
Expand Down Expand Up @@ -212,6 +215,7 @@ async def _merge_edges_then_upsert(
weight=weight,
description=description,
source_id=source_id,
order=order
),
)

Expand Down
Loading

0 comments on commit 1268afa

Please sign in to comment.