Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: encapsulate process episode logic #250

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ RUN poetry build && pip install dist/*.whl

# Install server dependencies
WORKDIR /app/server
-RUN poetry install --no-interaction --no-ansi --no-dev
+RUN poetry install --no-interaction --no-ansi --only main --no-root

FROM python:3.12-slim

Expand Down
250 changes: 130 additions & 120 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,6 @@ async def add_episode_endpoint(episode_data: EpisodeData):
return {"message": "Episode processing started"}
"""
try:
start = time()

entity_edges: list[EntityEdge] = []
now = utc_now()

previous_episodes = await self.retrieve_episodes(
Expand All @@ -331,155 +328,168 @@ async def add_episode_endpoint(episode_data: EpisodeData):
valid_at=reference_time,
)
)
episode, nodes, entity_edges = await self.process_episode(
group_id=group_id,
episode=episode,
previous_episodes=previous_episodes,
update_communities=update_communities,
)
return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)

except Exception as e:
raise e

async def process_episode(
self,
group_id: str,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
update_communities: bool,
):
start = time()

entity_edges: list[EntityEdge] = []
now = utc_now()

# Extract entities as nodes

# Extract entities as nodes
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
# Calculate Embeddings

# Calculate Embeddings
await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
)

# Find relevant nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list(
await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
*[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
)
)

# Find relevant nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list(
await semaphore_gather(
*[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
)
)
# Resolve extracted nodes with nodes already in the graph and extract facts
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

# Resolve extracted nodes with nodes already in the graph and extract facts
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

(mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
resolve_extracted_nodes(
self.llm_client,
extracted_nodes,
existing_nodes_lists,
episode,
previous_episodes,
),
extract_edges(
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
),
)
logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes = mentioned_nodes
(mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
resolve_extracted_nodes(
self.llm_client,
extracted_nodes,
existing_nodes_lists,
episode,
previous_episodes,
),
extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes, group_id),
)
logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes = mentioned_nodes

extracted_edges_with_resolved_pointers = resolve_edge_pointers(
extracted_edges, uuid_map
)
extracted_edges_with_resolved_pointers = resolve_edge_pointers(extracted_edges, uuid_map)

# calculate embeddings
# calculate embeddings
await semaphore_gather(
*[
edge.generate_embedding(self.embedder)
for edge in extracted_edges_with_resolved_pointers
]
)

# Resolve extracted edges with related edges already in the graph
related_edges_list: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
edge.generate_embedding(self.embedder)
get_relevant_edges(
self.driver,
[edge],
edge.source_node_uuid,
edge.target_node_uuid,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
)
logger.debug(
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
)
logger.debug(
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
)

# Resolve extracted edges with related edges already in the graph
related_edges_list: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
get_relevant_edges(
self.driver,
[edge],
edge.source_node_uuid,
edge.target_node_uuid,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
)
logger.debug(
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
)
logger.debug(
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
existing_source_edges_list: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
get_relevant_edges(
self.driver,
[edge],
edge.source_node_uuid,
None,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
)

existing_source_edges_list: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
get_relevant_edges(
self.driver,
[edge],
edge.source_node_uuid,
None,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
existing_target_edges_list: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
get_relevant_edges(
self.driver,
[edge],
None,
edge.target_node_uuid,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
)

existing_target_edges_list: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
get_relevant_edges(
self.driver,
[edge],
None,
edge.target_node_uuid,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
existing_edges_list: list[list[EntityEdge]] = [
source_lst + target_lst
for source_lst, target_lst in zip(
existing_source_edges_list, existing_target_edges_list
)
]

existing_edges_list: list[list[EntityEdge]] = [
source_lst + target_lst
for source_lst, target_lst in zip(
existing_source_edges_list, existing_target_edges_list
)
]
resolved_edges, invalidated_edges = await resolve_extracted_edges(
self.llm_client,
extracted_edges_with_resolved_pointers,
related_edges_list,
existing_edges_list,
episode,
previous_episodes,
)

resolved_edges, invalidated_edges = await resolve_extracted_edges(
self.llm_client,
extracted_edges_with_resolved_pointers,
related_edges_list,
existing_edges_list,
episode,
previous_episodes,
)
entity_edges.extend(resolved_edges + invalidated_edges)

entity_edges.extend(resolved_edges + invalidated_edges)
logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)

episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
logger.debug(f'Built episodic edges: {episodic_edges}')

logger.debug(f'Built episodic edges: {episodic_edges}')
episode.entity_edges = [edge.uuid for edge in entity_edges]

episode.entity_edges = [edge.uuid for edge in entity_edges]
if not self.store_raw_episode_content:
episode.content = ''

if not self.store_raw_episode_content:
episode.content = ''
await add_nodes_and_edges_bulk(self.driver, [episode], episodic_edges, nodes, entity_edges)

await add_nodes_and_edges_bulk(
self.driver, [episode], episodic_edges, nodes, entity_edges
# Update any communities
if update_communities:
await semaphore_gather(
*[
update_community(self.driver, self.llm_client, self.embedder, node)
for node in nodes
]
)
end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')

# Update any communities
if update_communities:
await semaphore_gather(
*[
update_community(self.driver, self.llm_client, self.embedder, node)
for node in nodes
]
)
end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')

return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)

except Exception as e:
raise e
return episode, nodes, entity_edges

async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
"""
Expand Down
Loading