Skip to content

Commit f4f3e5b

Browse files
committed
feat: 方法单独提取
1 parent 1e4880c commit f4f3e5b

File tree

15 files changed

+325
-620
lines changed

15 files changed

+325
-620
lines changed

Diff for: .gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ experimental/data/**/*.txt
5454
experimental/pre_trained
5555
experimental/scripts/**/checkpoints
5656
experimental/scripts/**/example*.json
57-
experimental/scripts/**/seqeval
57+
experimental/metrics/**
5858
experimental/scripts/test*.py
5959
experimental/scripts/test
6060
test*.py

Diff for: examples/get_knowledge_graph_pdf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
set_logger(console=True, file=False)
1414

1515
parser = argparse.ArgumentParser()
16-
parser.add_argument('-u', '--url', default='http://localhost:7474')
16+
parser.add_argument('-u', '--url', default='bolt://localhost:7687')
1717
parser.add_argument('-n', '--user', default='neo4j')
1818
parser.add_argument('-p', '--password', default='neo4j')
1919
parser.add_argument('-f', '--file')

Diff for: examples/knowledge_server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
parser = argparse.ArgumentParser()
13-
parser.add_argument('-u', '--url', default='http://localhost:7474')
13+
parser.add_argument('-u', '--url', default='bolt://localhost:7687')
1414
parser.add_argument('-n', '--user', default='neo4j')
1515
parser.add_argument('-p', '--password', default='neo4j')
1616
args = parser.parse_args()

Diff for: experimental/README.md

+9-5
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ experimental/
1313
│ ├── pipeline.py # 原始数据处理脚本
1414
│ └── txt/
1515
│ └── *.txt # 原始纯文本数据
16-
├── README.md
17-
└── scripts/
18-
├── ner/ # 实体识别模型
19-
├── overview.py # 数据概览
20-
└── pre_trained/ # 预训练模型
16+
├── scripts/
17+
│ ├── ner/ # 实体识别模型
18+
│ ├── overview.py # 数据概览
19+
│ └── pre_trained/ # 预训练模型
20+
├── results/ # 结果
21+
└── README.md
22+
2123
```
2224

2325
### 实验准备
@@ -29,6 +31,8 @@ experimental/
2931
export SWANLAB_API_KEY=
3032
```
3133

34+
#### 评估指标
35+
3236
#### 数据
3337

3438
### 数据格式说明

Diff for: experimental/scripts/ke/model.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from torch.nn import functional as F
1212

13+
1314
class BertBiLSTMCRF(PreTrainedModel):
1415
def __init__(self, config):
1516
super().__init__(config)
@@ -55,10 +56,10 @@ def forward(self, input_ids, attention_mask, token_type_ids, labels=None, **kwar
5556
"loss": loss,
5657
"pred_label_ids": pred_label_ids
5758
}
58-
else:
59-
return {
60-
"pred_label_ids": pred_label_ids
61-
}
59+
60+
return {
61+
"pred_label_ids": pred_label_ids
62+
}
6263

6364

6465
class BertForRE(PreTrainedModel):

Diff for: experimental/scripts/ke/train_with_trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def main(args):
207207
parser.add_argument("--checkpoint", type=str, default="experimental/scripts/ke/checkpoints", help="检查点存储路径")
208208
parser.add_argument("--log", type=str, default="experimental/scripts/ke/logs", help="日志存储路径")
209209
parser.add_argument("--use_local_metric", type=bool, default=True, help="是否使用本地评估指标")
210-
parser.add_argument("--seqeval_path", type=str, default="experimental/scripts/ke/seqeval", help="本地seqeval指标代码路径")
210+
parser.add_argument("--seqeval_path", type=str, default="experimental/metrics/seqeval", help="本地seqeval指标代码路径")
211211
parser.add_argument("--lr", type=float, default=5e-5)
212212
parser.add_argument("--epochs", type=int, default=5)
213213
parser.add_argument("--batch_size", type=int, default=2)

Diff for: pyproject.toml

-12
Original file line numberDiff line numberDiff line change
@@ -26,35 +26,23 @@ dependencies = [
2626
"tqdm==4.66.4",
2727
"transformers>=4.43.3",
2828
"vllm>=0.7.3",
29-
"pymongo==4.8.0",
3029
"paddleocr==2.8.1",
3130
"paddlepaddle==2.6.1",
32-
"ollama==0.3.3",
33-
"json5==0.9.25",
3431
"ray>=2.35.0",
35-
"verovio==4.3.1",
3632
"maturin==1.7.4",
37-
"patchelf==0.17.2.1",
3833
"docstring-parser==0.16",
3934
"doclayout-yolo==0.0.3",
4035
"shortuuid>=1.0.13",
4136
"jupyter>=1.1.1",
42-
"dashscope>=1.20.13",
43-
"tenacity>=9.0.0",
4437
"fastapi>=0.115.6",
45-
"shuangchentools>=0.0.6",
46-
"qwen-vl-utils>=0.0.8",
4738
"trl>=0.14.0",
48-
"deprecated>=1.2.18",
4939
"singleton-decorator>=1.0.0",
5040
"tabulate>=0.9.0",
5141
"torchcrf>=1.1.0",
52-
"seqeval>=1.2.2",
5342
"scikit-learn>=1.6.1",
5443
"swanlab>=0.5.2",
5544
"evaluate>=0.4.3",
5645
"pymilvus>=2.5.6",
57-
"openai-agents>=0.0.7",
5846
"neo4j>=5.28.1",
5947
]
6048

Diff for: src/course_graph/agent/agent.py

+27-15
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,21 @@ def __init__(
6868
self.tool_choice = tool_choice
6969

7070
self.messages: list[ChatCompletionMessageParam] = []
71-
72-
for server in mcp_server:
73-
tools = server.tools
74-
for tool in tools:
75-
self.tools.append({
76-
'type': 'function',
77-
'function': {
78-
'name': tool.name,
79-
'description': tool.description,
80-
'parameters': tool.inputSchema
81-
}
82-
}) # 注意不能使用 add_tools 方法
83-
self.mcp_functions[tool.name] = server
84-
85-
def chat(self, message: str = None) -> ChatCompletionMessage:
71+
if mcp_server:
72+
for server in mcp_server:
73+
tools = server.tools
74+
for tool in tools:
75+
self.tools.append({
76+
'type': 'function',
77+
'function': {
78+
'name': tool.name,
79+
'description': tool.description,
80+
'parameters': tool.inputSchema
81+
}
82+
}) # 注意不能使用 add_tools 方法
83+
self.mcp_functions[tool.name] = server
84+
85+
def chat_completion(self, message: str = None) -> ChatCompletionMessage:
8686
""" Agent 多轮对话
8787
8888
Args:
@@ -111,6 +111,18 @@ def chat(self, message: str = None) -> ChatCompletionMessage:
111111
self.messages.append(resp) # 比 add_assistant_message 信息更详细
112112

113113
return response
114+
115+
def chat(self, message: str = None) -> str:
116+
""" Agent 多轮对话
117+
118+
Args:
119+
message (str): 用户输入
120+
121+
Returns:
122+
ChatCompletionMessage: 模型输出
123+
"""
124+
response = self.chat_completion(message)
125+
return response.content
114126

115127
def add_user_message(self, message: str) -> None:
116128
""" 添加用户记录

Diff for: src/course_graph/agent/controller.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
9292
data={'message': message}
9393
))
9494

95-
assistant_output = agent.chat(message)
95+
assistant_output = agent.chat_completion(message)
9696

9797
self._add_trace_event(TraceEvent(
9898
timestamp=datetime.now(),
@@ -188,7 +188,7 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
188188

189189
self.set_agent_instruction(agent)
190190

191-
assistant_output = agent.chat()
191+
assistant_output = agent.chat_completion()
192192
turn += 1
193193
if turn > self.max_turns:
194194
raise MaxTurnsException

Diff for: src/course_graph/database/neo4j_.py

+80-18
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
# File: course_graph/database/neo4j.py
55
# Description: 定义图数据库连接
66

7-
from neo4j import GraphDatabase, Driver
7+
from neo4j import GraphDatabase
88
from neo4j.graph import Node, Relationship
99
from singleton_decorator import singleton
10-
from functools import lru_cache
10+
from functools import cache
1111

1212

1313
@singleton
@@ -39,8 +39,8 @@ def close(self):
3939
self.session.close()
4040
self.driver.close()
4141

42-
@lru_cache
43-
def match_nodes(self, skip: int = None, limit: int = None) -> list[Node]:
42+
@cache
43+
def get_nodes(self, skip: int = None, limit: int = None) -> list[Node]:
4444
""" 获取所有 Node
4545
4646
Args:
@@ -50,15 +50,16 @@ def match_nodes(self, skip: int = None, limit: int = None) -> list[Node]:
5050
Returns:
5151
list: 所有 Node
5252
"""
53-
records, _, _ = self.driver.execute_query(
54-
"MATCH (n) RETURN n skip $skip limit $limit",
55-
limit=limit,
56-
skip=skip
57-
)
53+
query = "MATCH (n) RETURN n"
54+
if skip is not None:
55+
query += f" skip $skip"
56+
if limit is not None:
57+
query += f" limit $limit"
58+
records, _, _ = self.driver.execute_query(query, limit=limit, skip=skip)
5859
return [record['n'] for record in records]
5960

60-
@lru_cache
61-
def match_relations(self, skip: int = None, limit: int = None) -> list[Relationship]:
61+
@cache
62+
def get_relations(self, skip: int = None, limit: int = None) -> list[Relationship]:
6263
""" 获取所有 Relation
6364
6465
Args:
@@ -68,14 +69,15 @@ def match_relations(self, skip: int = None, limit: int = None) -> list[Relations
6869
Returns:
6970
list: 所有 Relation
7071
"""
71-
records, _, _ = self.driver.execute_query(
72-
"MATCH ()-[r]->() RETURN r skip $skip limit $limit",
73-
limit=limit,
74-
skip=skip
75-
)
72+
query = "MATCH (m)-[r]->(n) RETURN m, r, n"
73+
if skip is not None:
74+
query += f" skip $skip"
75+
if limit is not None:
76+
query += f" limit $limit"
77+
records, _, _ = self.driver.execute_query(query, limit=limit, skip=skip)
7678
return [record['r'] for record in records]
7779

78-
@lru_cache
80+
@cache
7981
def get_nodes_count(self) -> int:
8082
""" 获取所有 Node 的数量
8183
@@ -87,7 +89,7 @@ def get_nodes_count(self) -> int:
8789
)
8890
return records[0]['count(n)']
8991

90-
@lru_cache
92+
@cache
9193
def get_relations_count(self) -> int:
9294
""" 获取所有 Relation 的数量
9395
@@ -99,5 +101,65 @@ def get_relations_count(self) -> int:
99101
)
100102
return records[0]['count(r)']
101103

104+
@cache
105+
def get_max_relation_count(self) -> int:
106+
""" 获取所有 Relation 的最大 ID
107+
108+
Returns:
109+
int: 所有 Relation 的最大 ID
110+
"""
111+
records, _, _ = self.driver.execute_query("""
112+
MATCH (n)-[r]-()
113+
RETURN n, count(r) AS relation_count
114+
ORDER BY relation_count DESC
115+
LIMIT 1""")
116+
return records[0]['relation_count']
117+
118+
@cache
119+
def get_nodes_with_relation_count(self, skip: int = None, limit: int = None) -> list[tuple[Node, int]]:
120+
""" 获取所有 Node 及其关系数量
121+
122+
Args:
123+
skip (int, optional): 跳过. Defaults to None.
124+
limit (int, optional): 限制. Defaults to None.
125+
126+
Returns:
127+
list[tuple[Node, int]]: 所有 Node 及其关系数量
128+
"""
129+
query = "MATCH (n)-[r]-() RETURN n, count(r)"
130+
if skip is not None:
131+
query += f" skip $skip"
132+
if limit is not None:
133+
query += f" limit $limit"
134+
records, _, _ = self.driver.execute_query(query, limit=limit, skip=skip)
135+
return [(record['n'], record['count(r)']) for record in records]
136+
137+
@cache
138+
def get_node_by_id(self, id: int) -> Node:
139+
""" 获取指定 id 的 Node
140+
141+
Args:
142+
id (int): 指定 id
143+
144+
Returns:
145+
Node: 相应 Node
146+
"""
147+
records, _, _ = self.driver.execute_query("MATCH (n) WHERE n.id = $id RETURN n", id=id)
148+
return records[0]['n']
149+
150+
@cache
151+
def get_relations_by_node_id(self, id: int) -> list[Relationship]:
152+
""" 获取指定 id 的 Node 的所有 Relation
153+
154+
Args:
155+
id (int): 指定 id
156+
157+
Returns:
158+
list[Relationship]: 相应 Node 的所有 Relation
159+
"""
160+
records, _, _ = self.driver.execute_query("MATCH (n)-[r]->() WHERE n.id = $id RETURN r", id=id)
161+
return [record['r'] for record in records]
162+
102163
def __hash__(self):
103164
return hash(self.url)
165+

0 commit comments

Comments
 (0)