Skip to content

Commit

Permalink
added helper functions for matching under rel2graph.neo4j
Browse files Browse the repository at this point in the history
  • Loading branch information
jkminder committed Mar 12, 2024
1 parent d7012ca commit adccd82
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 4 deletions.
6 changes: 5 additions & 1 deletion docs/source/api/neo4j.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
Neo4j Integration
-----------------

These functions abstract complexity of interacting with Neo4j. Instead of writing Cypher queries, you can use Python objects to create and merge nodes and relationships.
These functions abstract complexity of interacting with Neo4j. Instead of writing Cypher queries, you can use Python objects to create, merge and match nodes and relationships.

.. autofunction:: rel2graph.neo4j.create

.. autofunction:: rel2graph.neo4j.merge

.. autofunction:: rel2graph.neo4j.match_nodes

.. autofunction:: rel2graph.neo4j.match_relationships

Subgraph
~~~~~~~~

Expand Down
2 changes: 1 addition & 1 deletion docs/source/neo4j.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Neo4j Integration
The rel2graph library comes with a set of abstract classes that simplify the interaction with neo4j in python. They are derived from the now EOL library py2neo.
This includes python objects to represent |Node| and |Relationship| objects as well as a |Subgraph| object that can be used to represent a set of nodes and relationships.
|Node| and |Relationship| objects are themself a |Subgraph|. The two functions :py:func:`create <rel2graph.neo4j.create>` and :py:func:`merge <rel2graph.neo4j.create>` can be used to create or merge a |Subgraph| into a neo4j database given a neo4j session.

Further use the functions :py:func:`match_nodes <rel2graph.neo4j.match_nodes>` and :py:func:`match_relationships <rel2graph.neo4j.match_relationships>` to match elements in the graph and return a list of |Node| or |Relationship|.
We refer to the :doc:`neo4j documentation <api/neo4j>` for more information.

.. |Subgraph| replace:: :py:class:`Subgraph <rel2graph.neo4j.Subgraph>`
Expand Down
82 changes: 81 additions & 1 deletion rel2graph/neo4j/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from neo4j import Session
from typing import List, Union

from .graph_elements import Node, Relationship, Subgraph, Attribute
from .cypher import cypher_join, _match_clause, encode_value, encode_key

def create(graph: Subgraph, session: Session):
"""
Expand All @@ -22,4 +24,82 @@ def merge(graph: Subgraph, session: Session, primary_label=None, primary_key=Non
primary_label (str): The primary label to merge on. Has to be provided if the nodes themselves don't have a primary label (Default: None)
primary_key (str): The primary key to merge on. Has to be provided if the graph elements themselves don't have a primary label (Default: None)
"""
session.execute_write(graph.__db_merge__, primary_label=primary_label, primary_key=primary_key)
session.execute_write(graph.__db_merge__, primary_label=primary_label, primary_key=primary_key)



def match_nodes(session: Session, *labels: List[str], **properties: dict):
"""
Matches nodes in the database.
Args:
labels (List[str]): The labels to match.
session (Session): The `session <https://neo4j.com/docs/api/python-driver/current/api.html#session>`_ to use.
properties (dict): The properties to match.
"""
flat_params = [tuple(labels),]
data = []
for k, v in properties.items():
data.append(v)
flat_params.append(k)

unwind = "UNWIND $data as r" if len(data) > 0 else ""
clause = cypher_join(unwind, _match_clause('n', tuple(flat_params), "r"), "RETURN n, LABELS(n), ID(n)", data=data)
records = session.run(*clause).data()
print(clause)
# Convert to Node
out = []
for record in records:
node = Node.from_dict(record['LABELS(n)'], record['n'], identity=record['ID(n)'])
out.append(node)
return out


def match_relationships(session: Session, from_node: Node =None, to_node:Node =None, rel_type: str =None, **properties: dict):
"""
Matches relationships in the database.
Args:
session (Session): The `session <https://neo4j.com/docs/api/python-driver/current/api.html#session>`_ to use.
from_node (Node): The node to match the relationship from (Default: None)
to_node (Node): The node to match the relationship to (Default: None)
rel_type (str): The type of the relationship to match (Default: None)
properties (dict): The properties to match.
"""
if from_node is not None:
assert from_node.identity is not None, "from_node must have an identity"

if to_node is not None:
assert to_node.identity is not None, "to_node must have an identity"

params = ""
for k, v in properties.items():
if params != "":
params += ", "
params += f"{encode_key(k)}: {encode_value(v)}"

clauses = []
if from_node is not None:
clauses.append(f"ID(from_node) = {from_node.identity}")
if to_node is not None:
clauses.append(f"ID(to_node) = {to_node.identity}")
if rel_type is not None:
clauses.append(f"type(r) = {encode_value(rel_type)}")

clause = cypher_join(
f"MATCH (from_node)-[r {{{params}}}]->(to_node)",
"WHERE" if len(clauses) > 0 else "",
" AND ".join(clauses),
"RETURN PROPERTIES(r), TYPE(r), ID(r), from_node, LABELS(from_node), ID(from_node), to_node, LABELS(to_node), ID(to_node)"
)

print(clause)
records = session.run(*clause).data()
out = []
for record in records:
print(record)
fn = Node.from_dict(record['LABELS(from_node)'], record['from_node'], identity=record['ID(from_node)']) if from_node is None else from_node
tn = Node.from_dict(record['LABELS(to_node)'], record['to_node'], identity=record['ID(to_node)']) if to_node is None else to_node
rel = Relationship.from_dict(fn, tn, record['TYPE(r)'], record['PROPERTIES(r)'], identity=record['ID(r)'])
out.append(rel)
return out
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
setup(
name = "rel2graph",
packages = find_packages(),
version = "1.0.2",
version = "1.1.0",
description = "Library for converting relational data into graph data (neo4j)",
author = "Julian Minder",
author_email = "[email protected]",
Expand Down
100 changes: 100 additions & 0 deletions tests/unit/neo4j/test_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Tests for create and merge functionality.
authors: Julian Minder
"""
import os
import pytest
import datetime
from neo4j import GraphDatabase, time

from rel2graph.neo4j import Node, Relationship, Subgraph, create, merge, match_nodes, match_relationships
from rel2graph.common_modules import MERGE_RELATIONSHIPS

@pytest.fixture
def session():
# Check if custom port
try:
port = os.environ["NEO4J PORT"]
except KeyError:
port = 7687
# Initialise graph

driver = GraphDatabase.driver("bolt://localhost:{}".format(port), auth=("neo4j", "password"))
with driver.session() as session:
try:
delete_all(session)
# generate test data
n1 = Node("test", "second", id=1, name="test1")
n2 = Node("test", id=2, name="test2")
n3 = Node("anotherlabel", id=3, name="test3")

r1 = Relationship(n1, "to", n2, id=1)
r2 = Relationship(n1, "to", n3, id=2)

graph = n1 | n2 | n3 | r1 | r2
create(graph, session)
yield session
finally:
session.close()
driver.close()
return

def check_node(nodes, id):
return len([node for node in nodes if node["id"] == id]) == 1

def check_rel(rels, id):
return len([rel for rel in rels if rel["id"] == id]) == 1

def delete_all(session):
session.run("MATCH (n) DETACH DELETE n")

def test_match_nodes(session):
# match by single label
nodes = match_nodes(session, "test")
assert(len(nodes) == 2)
assert(check_node(nodes, 1))
assert(check_node(nodes, 2))

# match by multiple labels
nodes = match_nodes(session, "test", "second")
assert(len(nodes) == 1)
assert(check_node(nodes, 1))

# match by properties with no label
nodes = match_nodes(session, name="test3")
assert(len(nodes) == 1)
assert(check_node(nodes, 3))

# match by properties with label
nodes = match_nodes(session, "test", name="test1")
assert(len(nodes) == 1)
assert(check_node(nodes, 1))

def test_match_relationships(session):
# match by type
rels = match_relationships(session, rel_type="to")
assert(len(rels) == 2)
assert(check_rel(rels, 1))
assert(check_rel(rels, 2))

# match by properties
rels = match_relationships(session, rel_type="to", id=1)
assert(len(rels) == 1)
assert(check_rel(rels, 1))

# match by from node
n1 = match_nodes(session, "test", id=1)[0]
rels = match_relationships(session, from_node=n1)
assert(len(rels) == 2)
assert(check_rel(rels, 1))
assert(check_rel(rels, 2))

# match by to node
n2 = match_nodes(session, "test", id=2)[0]
rels = match_relationships(session, to_node=n2)
assert(len(rels) == 1)
assert(check_rel(rels, 1))

0 comments on commit adccd82

Please sign in to comment.