Skip to content

Commit

Permalink
Generalize common graph store tests (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
whimo authored Jul 31, 2024
1 parent 7f90b49 commit 552fae4
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 204 deletions.
9 changes: 9 additions & 0 deletions motleycrew/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,13 @@
"Defaults",
"MotleySupportedTool",
"MotleyAgentFactory",
"logger",
"configure_logging",
"AsyncBackend",
"GraphStoreType",
"LLMFamily",
"LLMFramework",
"LunaryEventName",
"LunaryRunType",
"TaskUnitStatus",
]
14 changes: 14 additions & 0 deletions motleycrew/common/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,29 @@ class LLMFamily:
OPENAI = "openai"
ANTHROPIC = "anthropic"

ALL = {OPENAI, ANTHROPIC}


class LLMFramework:
LANGCHAIN = "langchain"
LLAMA_INDEX = "llama_index"

ALL = {LANGCHAIN, LLAMA_INDEX}


class GraphStoreType:
KUZU = "kuzu"

ALL = {KUZU}


class TaskUnitStatus:
PENDING = "pending"
RUNNING = "running"
DONE = "done"

ALL = {PENDING, RUNNING, DONE}


class LunaryRunType:
LLM = "llm"
Expand All @@ -28,13 +36,17 @@ class LunaryRunType:
CHAIN = "chain"
EMBED = "embed"

ALL = {LLM, AGENT, TOOL, CHAIN, EMBED}


class LunaryEventName:
START = "start"
END = "end"
UPDATE = "update"
ERROR = "error"

ALL = {START, END, UPDATE, ERROR}


class AsyncBackend:
"""Backends for parallel crew execution.
Expand All @@ -48,3 +60,5 @@ class AsyncBackend:
ASYNCIO = "asyncio"
THREADING = "threading"
NONE = "none"

ALL = {ASYNCIO, THREADING, NONE}
21 changes: 21 additions & 0 deletions tests/test_storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import kuzu
import pytest

from motleycrew.storage import MotleyKuzuGraphStore


class GraphStoreFixtures:
@pytest.fixture
def kuzu_graph_store(self, tmpdir):
db_path = tmpdir / "test_db"
db = kuzu.Database(str(db_path))

graph_store = MotleyKuzuGraphStore(db)
return graph_store

@pytest.fixture
def graph_store(self, request, kuzu_graph_store):
graph_stores = {
"kuzu": kuzu_graph_store,
}
return graph_stores.get(request.param)
188 changes: 188 additions & 0 deletions tests/test_storage/test_graph_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from typing import Optional

import pytest

from motleycrew.common import GraphStoreType
from motleycrew.storage import MotleyGraphNode
from motleycrew.storage import MotleyKuzuGraphStore
from tests.test_storage import GraphStoreFixtures


class Entity(MotleyGraphNode):
int_param: int
optional_str_param: Optional[str] = None
optional_list_str_param: Optional[list[str]] = None


class TestMotleyGraphStore(GraphStoreFixtures):
@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_insert_new_node(self, graph_store):
entity = Entity(int_param=1)
created_entity = graph_store.insert_node(entity)
assert created_entity.id is not None
assert entity.id is not None # mutated in place

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_insert_node_and_retrieve(self, graph_store):
entity = Entity(int_param=1, optional_str_param="test", optional_list_str_param=["a", "b"])
inserted_entity = graph_store.insert_node(entity)
assert inserted_entity.id is not None

retrieved_entity = graph_store.get_node_by_class_and_id(
node_class=Entity, node_id=inserted_entity.id
)
assert inserted_entity == retrieved_entity

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_check_node_exists_true(self, graph_store):
entity = Entity(int_param=1)

graph_store.insert_node(entity)
assert graph_store.check_node_exists(entity)

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_check_node_exists_false(self, graph_store):
entity = Entity(int_param=1)
assert graph_store.check_node_exists(entity) is False

MotleyKuzuGraphStore._set_node_id(node=entity, node_id=2)
assert graph_store.check_node_exists(entity) is False

graph_store.ensure_node_table(Entity)
assert graph_store.check_node_exists(entity) is False

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_create_relation(self, graph_store):
entity1 = Entity(int_param=1)
entity2 = Entity(int_param=2)
graph_store.insert_node(entity1)
graph_store.insert_node(entity2)

graph_store.create_relation(from_node=entity1, to_node=entity2, label="p")

assert graph_store.check_relation_exists(from_node=entity1, to_node=entity2)
assert graph_store.check_relation_exists(from_node=entity1, to_node=entity2, label="p")
assert not graph_store.check_relation_exists(from_node=entity1, to_node=entity2, label="q")
assert not graph_store.check_relation_exists(from_node=entity2, to_node=entity1)

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_upsert_triplet(self, graph_store):
entity1 = Entity(int_param=1)
entity2 = Entity(int_param=2)
graph_store.upsert_triplet(from_node=entity1, to_node=entity2, label="p")

assert graph_store.check_node_exists(entity1)
assert graph_store.check_node_exists(entity2)

assert graph_store.check_relation_exists(from_node=entity1, to_node=entity2)
assert graph_store.check_relation_exists(from_node=entity1, to_node=entity2, label="p")
assert not graph_store.check_relation_exists(from_node=entity1, to_node=entity2, label="q")
assert not graph_store.check_relation_exists(from_node=entity2, to_node=entity1)

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_nodes_do_not_exist(self, graph_store):
entity1 = Entity(int_param=1)
entity2 = Entity(int_param=2)

assert not graph_store.check_node_exists(entity1)
assert not graph_store.check_node_exists(entity2)

assert not graph_store.check_relation_exists(from_node=entity1, to_node=entity2)
assert not graph_store.check_relation_exists(from_node=entity2, to_node=entity1, label="p")

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_relation_does_not_exist(self, graph_store):
entity1 = Entity(int_param=1)
entity2 = Entity(int_param=2)

assert not graph_store.check_relation_exists(from_node=entity1, to_node=entity2)
assert not graph_store.check_relation_exists(from_node=entity2, to_node=entity1)

graph_store.insert_node(entity1)
graph_store.insert_node(entity2)

assert not graph_store.check_relation_exists(from_node=entity1, to_node=entity2)
assert not graph_store.check_relation_exists(from_node=entity2, to_node=entity1)

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_delete_node(self, graph_store):
entity = Entity(int_param=1)
graph_store.insert_node(entity)
assert graph_store.check_node_exists(entity) is True

graph_store.delete_node(entity)
assert graph_store.check_node_exists(entity) is False

entity.int_param = 2 # check that entity is not frozen

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_delete_entity_with_relations(self, graph_store):
entity1 = Entity(int_param=1)
entity2 = Entity(int_param=2)

graph_store.insert_node(entity1)
graph_store.insert_node(entity2)
graph_store.create_relation(from_node=entity1, to_node=entity2, label="p")
assert graph_store.check_relation_exists(from_node=entity1, to_node=entity2) is True

graph_store.delete_node(entity1)
assert graph_store.check_node_exists(entity1) is False
assert graph_store.check_node_exists(entity2) is True
assert graph_store.check_relation_exists(from_node=entity1, to_node=entity2) is False

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_set_property(self, graph_store):
entity = Entity(int_param=1)
graph_store.insert_node(entity)
assert entity.optional_str_param is None
assert graph_store.get_node_by_class_and_id(Entity, entity.id).optional_str_param is None

entity.optional_str_param = "test"
assert graph_store.get_node_by_class_and_id(Entity, entity.id).optional_str_param == "test"

entity.optional_list_str_param = ["a", "b"]
assert graph_store.get_node_by_class_and_id(Entity, entity.id).optional_list_str_param == [
"a",
"b",
]

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_run_cypher_query(self, graph_store):
entity1 = Entity(int_param=1)
entity2 = Entity(int_param=2)

graph_store.insert_node(entity1)
graph_store.insert_node(entity2)
graph_store.create_relation(from_node=entity1, to_node=entity2, label="p")

query = """
MATCH (a:Entity {int_param: 1})-[r]->(b:Entity {int_param: 2})
RETURN a, r, b
"""
result = graph_store.run_cypher_query(query)
assert len(result) == 1
assert len(result[0]) == 3

a, r, b = result[0]
assert a["int_param"] == 1
assert b["int_param"] == 2
assert r["_label"] == "p"

@pytest.mark.parametrize("graph_store", GraphStoreType.ALL, indirect=True)
def test_run_cypher_query_with_container(self, graph_store):
entity1 = Entity(int_param=1)
entity2 = Entity(int_param=2, optional_list_str_param=["a", "b"])

graph_store.insert_node(entity1)
graph_store.insert_node(entity2)

query = """
MATCH (a:Entity)
WHERE a.int_param = 2
RETURN a
"""
result = graph_store.run_cypher_query(query, container=Entity)
assert len(result) == 1
assert isinstance(result[0], Entity)
assert result[0].int_param == 2
assert result[0].optional_list_str_param == ["a", "b"]
Loading

0 comments on commit 552fae4

Please sign in to comment.