Skip to content

Commit

Permalink
Unify default node type in llm graph transformer (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasonjo authored Oct 1, 2024
1 parent 900c6df commit 6fdfdbe
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions libs/experimental/langchain_experimental/graph_transformers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field, create_model

DEFAULT_NODE_TYPE = "Node"

examples = [
{
"text": (
Expand Down Expand Up @@ -479,7 +481,7 @@ def _parse_and_clean_json(
nodes.append(
Node(
id=node["id"],
type=node.get("type", "Node"),
type=node.get("type", DEFAULT_NODE_TYPE),
properties=node_properties,
)
)
Expand All @@ -502,7 +504,7 @@ def _parse_and_clean_json(
if el["id"] == rel["source_node_id"]
][0]
except IndexError:
rel["source_node_type"] = None
rel["source_node_type"] = DEFAULT_NODE_TYPE
if not rel.get("target_node_type"):
try:
rel["target_node_type"] = [
Expand All @@ -511,7 +513,7 @@ def _parse_and_clean_json(
if el["id"] == rel["target_node_id"]
][0]
except IndexError:
rel["target_node_type"] = None
rel["target_node_type"] = DEFAULT_NODE_TYPE

rel_properties = {}
if "properties" in rel and rel["properties"]:
Expand Down Expand Up @@ -543,7 +545,7 @@ def _format_nodes(nodes: List[Node]) -> List[Node]:
id=el.id.title() if isinstance(el.id, str) else el.id,
type=el.type.capitalize() # type: ignore[arg-type]
if el.type
else None, # handle empty strings # type: ignore[arg-type]
else DEFAULT_NODE_TYPE, # handle empty strings # type: ignore[arg-type]
properties=el.properties,
)
for el in nodes
Expand Down Expand Up @@ -758,11 +760,15 @@ def process_response(
continue
# Nodes need to be deduplicated using a set
# Use default Node label for nodes if missing
nodes_set.add((rel["head"], rel.get("head_type", "Node")))
nodes_set.add((rel["tail"], rel.get("tail_type", "Node")))
nodes_set.add((rel["head"], rel.get("head_type", DEFAULT_NODE_TYPE)))
nodes_set.add((rel["tail"], rel.get("tail_type", DEFAULT_NODE_TYPE)))

source_node = Node(id=rel["head"], type=rel.get("head_type", "Node"))
target_node = Node(id=rel["tail"], type=rel.get("tail_type", "Node"))
source_node = Node(
id=rel["head"], type=rel.get("head_type", DEFAULT_NODE_TYPE)
)
target_node = Node(
id=rel["tail"], type=rel.get("tail_type", DEFAULT_NODE_TYPE)
)
relationships.append(
Relationship(
source=source_node, target=target_node, type=rel["relation"]
Expand Down Expand Up @@ -838,11 +844,15 @@ async def aprocess_response(
continue
# Nodes need to be deduplicated using a set
# Use default Node label for nodes if missing
nodes_set.add((rel["head"], rel.get("head_type", "Node")))
nodes_set.add((rel["tail"], rel.get("tail_type", "Node")))
nodes_set.add((rel["head"], rel.get("head_type", DEFAULT_NODE_TYPE)))
nodes_set.add((rel["tail"], rel.get("tail_type", DEFAULT_NODE_TYPE)))

source_node = Node(id=rel["head"], type=rel.get("head_type", "Node"))
target_node = Node(id=rel["tail"], type=rel.get("tail_type", "Node"))
source_node = Node(
id=rel["head"], type=rel.get("head_type", DEFAULT_NODE_TYPE)
)
target_node = Node(
id=rel["tail"], type=rel.get("tail_type", DEFAULT_NODE_TYPE)
)
relationships.append(
Relationship(
source=source_node, target=target_node, type=rel["relation"]
Expand Down

0 comments on commit 6fdfdbe

Please sign in to comment.