Skip to content

Commit

Permalink
Refactor Create Final Community reports to simplify code (#1456)
Browse files Browse the repository at this point in the history
* Optimize prep claims

* Optimize community hierarchy restore

* Partial optimization of prepare_community_reports

* More optimization code

* Fix context string generation

* Filter community -1

* Fix cache, add more optimization fixes

* Fix local search community ids

* Cleanup

* Format

* Semver

* Remove perf counter

* Unused import

* Format

* Fix edge addition to reports

* Add edge by edge context creation

* Re-org of the optimization code

* Format

* Ruff

* Some Ruff fixes

* More pyright

* More pyright

* Pyright

* Pyright

* Update tests
  • Loading branch information
AlonsoGuevara authored Dec 5, 2024
1 parent b001422 commit d43124e
Show file tree
Hide file tree
Showing 21 changed files with 294 additions and 376 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241130004740004072.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Optimize Final Community Reports calculation and stabilize cache"
}
66 changes: 31 additions & 35 deletions graphrag/index/flows/create_final_community_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
CLAIM_STATUS,
CLAIM_SUBJECT,
CLAIM_TYPE,
COMMUNITY_ID,
EDGE_DEGREE,
EDGE_DESCRIPTION,
EDGE_DETAILS,
Expand Down Expand Up @@ -83,9 +84,7 @@ async def create_final_community_reports(

community_reports["community"] = community_reports["community"].astype(int)
community_reports["human_readable_id"] = community_reports["community"]
community_reports["id"] = community_reports["community"].apply(
lambda _x: str(uuid4())
)
community_reports["id"] = [uuid4().hex for _ in range(len(community_reports))]

# Merge with communities to add size and period
merged = community_reports.merge(
Expand Down Expand Up @@ -115,45 +114,42 @@ async def create_final_community_reports(


def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame:
    """Prepare nodes by filtering, filling missing descriptions, and creating NODE_DETAILS.

    Rows whose community id is -1 (nodes not assigned to any community) are
    dropped, missing node descriptions are replaced with a placeholder, and the
    id/name/description/degree columns are folded into a single NODE_DETAILS
    column holding one dict per row.
    """
    # Filter rows where community is not -1 (unassigned nodes carry no report).
    input = input.loc[input.loc[:, COMMUNITY_ID] != -1]

    # Fill missing values in NODE_DESCRIPTION with a readable placeholder.
    input.loc[:, NODE_DESCRIPTION] = input.loc[:, NODE_DESCRIPTION].fillna(
        "No Description"
    )

    # Create NODE_DETAILS column: to_dict(orient="records") builds the per-row
    # dicts in one vectorized pass instead of a row-wise apply/lambda.
    input[NODE_DETAILS] = input.loc[
        :, [NODE_ID, NODE_NAME, NODE_DESCRIPTION, NODE_DEGREE]
    ].to_dict(orient="records")

    return input


def _prep_edges(input: pd.DataFrame) -> pd.DataFrame:
    """Prepare edges by filling missing descriptions and creating EDGE_DETAILS.

    The id/source/target/description/degree columns are folded into a single
    EDGE_DETAILS column holding one dict per row.
    """
    # Fill missing NODE_DESCRIPTION.
    # NOTE(review): fillna(inplace=True) mutates the caller's frame, unlike
    # _prep_nodes which rebinds a filtered copy — confirm callers do not rely
    # on their input being left untouched.
    input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)

    # Create EDGE_DETAILS column: one dict per row with the edge's core fields,
    # built vectorized via to_dict(orient="records").
    input[EDGE_DETAILS] = input.loc[
        :, [EDGE_ID, EDGE_SOURCE, EDGE_TARGET, EDGE_DESCRIPTION, EDGE_DEGREE]
    ].to_dict(orient="records")

    return input


def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
    """Prepare claims by filling missing descriptions and creating CLAIM_DETAILS.

    The id/subject/type/status/description columns are folded into a single
    CLAIM_DETAILS column holding one dict per row.
    """
    # Fill missing NODE_DESCRIPTION.
    # NOTE(review): mutates the caller's frame via inplace=True — confirm
    # callers do not rely on their input being left untouched.
    input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)

    # Create CLAIM_DETAILS column: one dict per row with the claim's core
    # fields, built vectorized via to_dict(orient="records").
    input[CLAIM_DETAILS] = input.loc[
        :, [CLAIM_ID, CLAIM_SUBJECT, CLAIM_TYPE, CLAIM_STATUS, CLAIM_DESCRIPTION]
    ].to_dict(orient="records")

    return input
14 changes: 0 additions & 14 deletions graphrag/index/graph/extractors/community_reports/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,11 @@
prep_community_report_context,
)
from graphrag.index.graph.extractors.community_reports.sort_context import sort_context
from graphrag.index.graph.extractors.community_reports.utils import (
filter_claims_to_nodes,
filter_edges_to_nodes,
filter_nodes_to_level,
get_levels,
set_context_exceeds_flag,
set_context_size,
)

# Public API of the community-reports extractor package.
# NOTE: the filter_*/get_levels/set_context_* helpers were removed along with
# their imports; listing them here would make `from ... import *` fail.
__all__ = [
    "CommunityReportsExtractor",
    "build_mixed_context",
    "prep_community_report_context",
    "schemas",
    "sort_context",
]
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
build_mixed_context,
)
from graphrag.index.graph.extractors.community_reports.sort_context import sort_context
from graphrag.index.graph.extractors.community_reports.utils import set_context_size
from graphrag.index.utils.dataframes import (
antijoin,
drop_columns,
Expand All @@ -23,6 +22,7 @@
union,
where_column_equals,
)
from graphrag.query.llm.text_utils import num_tokens

log = logging.getLogger(__name__)

Expand All @@ -31,7 +31,7 @@ def prep_community_report_context(
report_df: pd.DataFrame | None,
community_hierarchy_df: pd.DataFrame,
local_context_df: pd.DataFrame,
level: int | str,
level: int,
max_tokens: int,
) -> pd.DataFrame:
"""
Expand All @@ -44,22 +44,32 @@ def prep_community_report_context(
if report_df is None:
report_df = pd.DataFrame()

level = int(level)
level_context_df = _at_level(level, local_context_df)
valid_context_df = _within_context(level_context_df)
invalid_context_df = _exceeding_context(level_context_df)
# Filter by community level
level_context_df = local_context_df.loc[
local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level
]

# Filter valid and invalid contexts using boolean logic
valid_context_df = level_context_df.loc[
~level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
]
invalid_context_df = level_context_df.loc[
level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
]

# there is no report to substitute with, so we just trim the local context of the invalid context records
# this case should only happen at the bottom level of the community hierarchy where there are no sub-communities
if invalid_context_df.empty:
return valid_context_df

if report_df.empty:
invalid_context_df[schemas.CONTEXT_STRING] = _sort_and_trim_context(
invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
invalid_context_df, max_tokens
)
set_context_size(invalid_context_df)
invalid_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG] = 0
invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
:, schemas.CONTEXT_STRING
].map(num_tokens)
invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = 0
return union(valid_context_df, invalid_context_df)

level_context_df = _antijoin_reports(level_context_df, report_df)
Expand All @@ -74,12 +84,13 @@ def prep_community_report_context(
    # handle any remaining invalid records that can't be substituted with sub-community reports
# this should be rare, but if it happens, we will just trim the local context to fit the limit
remaining_df = _antijoin_reports(invalid_context_df, community_df)
remaining_df[schemas.CONTEXT_STRING] = _sort_and_trim_context(
remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
remaining_df, max_tokens
)

result = union(valid_context_df, community_df, remaining_df)
set_context_size(result)
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)

result[schemas.CONTEXT_EXCEED_FLAG] = 0
return result

Expand All @@ -94,16 +105,6 @@ def _at_level(level: int, df: pd.DataFrame) -> pd.DataFrame:
return where_column_equals(df, schemas.COMMUNITY_LEVEL, level)


def _exceeding_context(df: pd.DataFrame) -> pd.DataFrame:
    """Select the records whose local context is flagged as over the size limit."""
    exceed_flag = schemas.CONTEXT_EXCEED_FLAG
    return where_column_equals(df, exceed_flag, 1)


def _within_context(df: pd.DataFrame) -> pd.DataFrame:
    """Select the records whose local context still fits inside the size limit."""
    fits_df = where_column_equals(df, schemas.CONTEXT_EXCEED_FLAG, 0)
    return fits_df


def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
    """Keep only the rows of df whose community is absent from reports."""
    join_key = schemas.NODE_COMMUNITY
    return antijoin(df, reports, join_key)
Expand Down
Loading

0 comments on commit d43124e

Please sign in to comment.