diff --git a/.semversioner/next-release/patch-20241130004740004072.json b/.semversioner/next-release/patch-20241130004740004072.json
new file mode 100644
index 0000000000..b361849a0d
--- /dev/null
+++ b/.semversioner/next-release/patch-20241130004740004072.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Optimize Final Community Reports calculation and stabilize cache"
+}
diff --git a/graphrag/index/flows/create_final_community_reports.py b/graphrag/index/flows/create_final_community_reports.py
index 681cab7c90..e49a01aee5 100644
--- a/graphrag/index/flows/create_final_community_reports.py
+++ b/graphrag/index/flows/create_final_community_reports.py
@@ -19,6 +19,7 @@
     CLAIM_STATUS,
     CLAIM_SUBJECT,
     CLAIM_TYPE,
+    COMMUNITY_ID,
     EDGE_DEGREE,
     EDGE_DESCRIPTION,
     EDGE_DETAILS,
@@ -83,9 +84,7 @@ async def create_final_community_reports(
     community_reports["community"] = community_reports["community"].astype(int)
     community_reports["human_readable_id"] = community_reports["community"]
-    community_reports["id"] = community_reports["community"].apply(
-        lambda _x: str(uuid4())
-    )
+    community_reports["id"] = [uuid4().hex for _ in range(len(community_reports))]

     # Merge with communities to add size and period
     merged = community_reports.merge(
@@ -115,45 +114,42 @@ def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame:
-    input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
-    # merge values of four columns into a map column
-    input[NODE_DETAILS] = input.apply(
-        lambda x: {
-            NODE_ID: x[NODE_ID],
-            NODE_NAME: x[NODE_NAME],
-            NODE_DESCRIPTION: x[NODE_DESCRIPTION],
-            NODE_DEGREE: x[NODE_DEGREE],
-        },
-        axis=1,
+    """Prepare nodes by filtering, filling missing descriptions, and creating NODE_DETAILS."""
+    # Filter rows where community is not -1
+    input = input.loc[input.loc[:, COMMUNITY_ID] != -1]
+
+    # Fill missing values in NODE_DESCRIPTION
+    input.loc[:, NODE_DESCRIPTION] = input.loc[:, NODE_DESCRIPTION].fillna(
+        "No Description"
     )
+
+    # Create NODE_DETAILS column
+    input[NODE_DETAILS] = input.loc[
+        :, [NODE_ID, NODE_NAME, NODE_DESCRIPTION, NODE_DEGREE]
+    ].to_dict(orient="records")
+
     return input


 def _prep_edges(input: pd.DataFrame) -> pd.DataFrame:
-    input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
-    input[EDGE_DETAILS] = input.apply(
-        lambda x: {
-            EDGE_ID: x[EDGE_ID],
-            EDGE_SOURCE: x[EDGE_SOURCE],
-            EDGE_TARGET: x[EDGE_TARGET],
-            EDGE_DESCRIPTION: x[EDGE_DESCRIPTION],
-            EDGE_DEGREE: x[EDGE_DEGREE],
-        },
-        axis=1,
-    )
+    # Fill missing NODE_DESCRIPTION
+    input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)
+
+    # Create EDGE_DETAILS column
+    input[EDGE_DETAILS] = input.loc[
+        :, [EDGE_ID, EDGE_SOURCE, EDGE_TARGET, EDGE_DESCRIPTION, EDGE_DEGREE]
+    ].to_dict(orient="records")
+
     return input


 def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
-    input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
-    input[CLAIM_DETAILS] = input.apply(
-        lambda x: {
-            CLAIM_ID: x[CLAIM_ID],
-            CLAIM_SUBJECT: x[CLAIM_SUBJECT],
-            CLAIM_TYPE: x[CLAIM_TYPE],
-            CLAIM_STATUS: x[CLAIM_STATUS],
-            CLAIM_DESCRIPTION: x[CLAIM_DESCRIPTION],
-        },
-        axis=1,
-    )
+    # Fill missing NODE_DESCRIPTION
+    input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)
+
+    # Create CLAIM_DETAILS column
+    input[CLAIM_DETAILS] = input.loc[
+        :, [CLAIM_ID, CLAIM_SUBJECT, CLAIM_TYPE, CLAIM_STATUS, CLAIM_DESCRIPTION]
+    ].to_dict(orient="records")
+
     return input
diff --git a/graphrag/index/graph/extractors/community_reports/__init__.py b/graphrag/index/graph/extractors/community_reports/__init__.py
index bac91674c2..63d75de87e 100644
--- a/graphrag/index/graph/extractors/community_reports/__init__.py
+++ b/graphrag/index/graph/extractors/community_reports/__init__.py
@@ -14,25 +14,11 @@
     prep_community_report_context,
 )
 from graphrag.index.graph.extractors.community_reports.sort_context import sort_context
-from graphrag.index.graph.extractors.community_reports.utils import (
-    filter_claims_to_nodes,
-    filter_edges_to_nodes,
-    filter_nodes_to_level,
-    get_levels,
-    set_context_exceeds_flag,
-    set_context_size,
-)

 __all__ = [
     "CommunityReportsExtractor",
     "build_mixed_context",
-    "filter_claims_to_nodes",
-    "filter_edges_to_nodes",
-    "filter_nodes_to_level",
-    "get_levels",
     "prep_community_report_context",
     "schemas",
-    "set_context_exceeds_flag",
-    "set_context_size",
     "sort_context",
 ]
diff --git a/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py b/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py
index a4df7d533a..22e707c67f 100644
--- a/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py
+++ b/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py
@@ -13,7 +13,6 @@
     build_mixed_context,
 )
 from graphrag.index.graph.extractors.community_reports.sort_context import sort_context
-from graphrag.index.graph.extractors.community_reports.utils import set_context_size
 from graphrag.index.utils.dataframes import (
     antijoin,
     drop_columns,
@@ -23,6 +22,7 @@
     union,
     where_column_equals,
 )
+from graphrag.query.llm.text_utils import num_tokens

 log = logging.getLogger(__name__)

@@ -31,7 +31,7 @@ def prep_community_report_context(
     report_df: pd.DataFrame | None,
     community_hierarchy_df: pd.DataFrame,
     local_context_df: pd.DataFrame,
-    level: int | str,
+    level: int,
     max_tokens: int,
 ) -> pd.DataFrame:
     """
@@ -44,10 +44,18 @@
     if report_df is None:
         report_df = pd.DataFrame()

-    level = int(level)
-    level_context_df = _at_level(level, local_context_df)
-    valid_context_df = _within_context(level_context_df)
-    invalid_context_df = _exceeding_context(level_context_df)
+    # Filter by community level
+    level_context_df = local_context_df.loc[
+        local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level
+    ]
+
+    # Filter valid and invalid contexts using boolean logic
+    valid_context_df = level_context_df.loc[
+        ~level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
+    ]
+    invalid_context_df = level_context_df.loc[
+        level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
+    ]

     # there is no report to substitute with, so we just trim the local context of the invalid context records
     # this case should only happen at the bottom level of the community hierarchy where there are no sub-communities
@@ -55,11 +63,13 @@
         return valid_context_df

     if report_df.empty:
-        invalid_context_df[schemas.CONTEXT_STRING] = _sort_and_trim_context(
+        invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
             invalid_context_df, max_tokens
         )
-        set_context_size(invalid_context_df)
-        invalid_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG] = 0
+        invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
+            :, schemas.CONTEXT_STRING
+        ].map(num_tokens)
+        invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = 0
         return union(valid_context_df, invalid_context_df)

     level_context_df = _antijoin_reports(level_context_df, report_df)
@@ -74,12 +84,13 @@
     # handle any remaining invalid records that can't be subsituted with sub-community reports
     # this should be rare, but if it happens, we will just trim the local context to fit the limit
     remaining_df = _antijoin_reports(invalid_context_df, community_df)
-    remaining_df[schemas.CONTEXT_STRING] = _sort_and_trim_context(
+    remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
         remaining_df, max_tokens
     )

     result = union(valid_context_df, community_df, remaining_df)
-    set_context_size(result)
+    result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)
+    result[schemas.CONTEXT_EXCEED_FLAG] = 0
     return result
@@ -94,16 +105,6 @@ def _at_level(level: int, df: pd.DataFrame) -> pd.DataFrame:
     return where_column_equals(df, schemas.COMMUNITY_LEVEL, level)


-def _exceeding_context(df: pd.DataFrame) -> pd.DataFrame:
-    """Return records where the context exceeds the limit."""
-    return where_column_equals(df, schemas.CONTEXT_EXCEED_FLAG, 1)
-
-
-def _within_context(df: pd.DataFrame) -> pd.DataFrame:
-    """Return records where the context is within the limit."""
-    return where_column_equals(df, schemas.CONTEXT_EXCEED_FLAG, 0)
-
-
 def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
     """Return records in df that are not in reports."""
     return antijoin(df, reports, schemas.NODE_COMMUNITY)
diff --git a/graphrag/index/graph/extractors/community_reports/sort_context.py b/graphrag/index/graph/extractors/community_reports/sort_context.py
index c62710e1c8..ab56083438 100644
--- a/graphrag/index/graph/extractors/community_reports/sort_context.py
+++ b/graphrag/index/graph/extractors/community_reports/sort_context.py
@@ -12,7 +12,6 @@ def sort_context(
     local_context: list[dict],
     sub_community_reports: list[dict] | None = None,
     max_tokens: int | None = None,
-    node_id_column: str = schemas.NODE_ID,
     node_name_column: str = schemas.NODE_NAME,
     node_details_column: str = schemas.NODE_DETAILS,
     edge_id_column: str = schemas.EDGE_ID,
@@ -20,14 +19,9 @@
     edge_degree_column: str = schemas.EDGE_DEGREE,
     edge_source_column: str = schemas.EDGE_SOURCE,
     edge_target_column: str = schemas.EDGE_TARGET,
-    claim_id_column: str = schemas.CLAIM_ID,
     claim_details_column: str = schemas.CLAIM_DETAILS,
-    community_id_column: str = schemas.COMMUNITY_ID,
 ) -> str:
-    """Sort context by degree in descending order.
-
-    If max tokens is provided, we will return the context string that fits within the token limit.
- """ + """Sort context by degree in descending order, optimizing for performance.""" def _get_context_string( entities: list[dict], @@ -38,119 +32,123 @@ def _get_context_string( """Concatenate structured data into a context string.""" contexts = [] if sub_community_reports: - sub_community_reports = [ - report - for report in sub_community_reports - if community_id_column in report - and report[community_id_column] - and str(report[community_id_column]).strip() != "" - ] - report_df = pd.DataFrame(sub_community_reports).drop_duplicates() + report_df = pd.DataFrame(sub_community_reports) if not report_df.empty: - if report_df[community_id_column].dtype == float: - report_df[community_id_column] = report_df[ - community_id_column - ].astype(int) - report_string = ( + contexts.append( f"----Reports-----\n{report_df.to_csv(index=False, sep=',')}" ) - contexts.append(report_string) - - entities = [ - entity - for entity in entities - if node_id_column in entity - and entity[node_id_column] - and str(entity[node_id_column]).strip() != "" - ] - entity_df = pd.DataFrame(entities).drop_duplicates() - if not entity_df.empty: - if entity_df[node_id_column].dtype == float: - entity_df[node_id_column] = entity_df[node_id_column].astype(int) - entity_string = ( - f"-----Entities-----\n{entity_df.to_csv(index=False, sep=',')}" - ) - contexts.append(entity_string) - - if claims and len(claims) > 0: - claims = [ - claim - for claim in claims - if claim_id_column in claim - and claim[claim_id_column] - and str(claim[claim_id_column]).strip() != "" - ] - claim_df = pd.DataFrame(claims).drop_duplicates() - if not claim_df.empty: - if claim_df[claim_id_column].dtype == float: - claim_df[claim_id_column] = claim_df[claim_id_column].astype(int) - claim_string = ( - f"-----Claims-----\n{claim_df.to_csv(index=False, sep=',')}" - ) - contexts.append(claim_string) - - edges = [ - edge - for edge in edges - if edge_id_column in edge - and edge[edge_id_column] - and str(edge[edge_id_column]).strip() != "" - ] - edge_df = pd.DataFrame(edges).drop_duplicates() - if not edge_df.empty: - if edge_df[edge_id_column].dtype == float: - edge_df[edge_id_column] = edge_df[edge_id_column].astype(int) - edge_string = ( - f"-----Relationships-----\n{edge_df.to_csv(index=False, sep=',')}" - ) - contexts.append(edge_string) + + for label, data in [ + ("Entities", entities), + ("Claims", claims), + ("Relationships", edges), + ]: + if data: + data_df = pd.DataFrame(data) + if not data_df.empty: + contexts.append( + f"-----{label}-----\n{data_df.to_csv(index=False, sep=',')}" + ) return "\n\n".join(contexts) - # sort node details by degree in descending order - edges = [] - node_details = {} - claim_details = {} - - for record in local_context: - node_name = record[node_name_column] - record_edges = record.get(edge_details_column, []) - record_edges = [e for e in record_edges if not pd.isna(e)] - record_node_details = record[node_details_column] - record_claims = record.get(claim_details_column, []) - record_claims = [c for c in record_claims if not pd.isna(c)] - - edges.extend(record_edges) - node_details[node_name] = record_node_details - claim_details[node_name] = record_claims - - edges = [edge for edge in edges if isinstance(edge, dict)] - edges = sorted(edges, key=lambda x: x[edge_degree_column], reverse=True) - - sorted_edges = [] - sorted_nodes = [] - sorted_claims = [] + # Preprocess local context + edges = [ + {**e, schemas.EDGE_ID: int(e[schemas.EDGE_ID])} + for record in local_context + for e in 
+        for e in record.get(edge_details_column, [])
+        if isinstance(e, dict)
+    ]
+
+    node_details = {
+        record[node_name_column]: {
+            **record[node_details_column],
+            schemas.NODE_ID: int(record[node_details_column][schemas.NODE_ID]),
+        }
+        for record in local_context
+    }
+
+    claim_details = {
+        record[node_name_column]: [
+            {**c, schemas.CLAIM_ID: int(c[schemas.CLAIM_ID])}
+            for c in record.get(claim_details_column, [])
+            if isinstance(c, dict) and c.get(schemas.CLAIM_ID) is not None
+        ]
+        for record in local_context
+        if isinstance(record.get(claim_details_column), list)
+    }
+
+    # Sort edges by degree (desc) and ID (asc)
+    edges.sort(key=lambda x: (-x.get(edge_degree_column, 0), x.get(edge_id_column, "")))
+
+    # Deduplicate and build context incrementally
+    edge_ids, nodes_ids, claims_ids = set(), set(), set()
+    sorted_edges, sorted_nodes, sorted_claims = [], [], []
     context_string = ""
+
     for edge in edges:
-        source_details = node_details.get(edge[edge_source_column], {})
-        target_details = node_details.get(edge[edge_target_column], {})
-        sorted_nodes.extend([source_details, target_details])
-        sorted_edges.append(edge)
-        source_claims = claim_details.get(edge[edge_source_column], [])
-        target_claims = claim_details.get(edge[edge_target_column], [])
-        sorted_claims.extend(source_claims if source_claims else [])
-        sorted_claims.extend(target_claims if source_claims else [])
-        if max_tokens:
-            new_context_string = _get_context_string(
-                sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
+        source, target = edge[edge_source_column], edge[edge_target_column]
+
+        # Add source and target node details
+        for node in [node_details.get(source), node_details.get(target)]:
+            if node and node[schemas.NODE_ID] not in nodes_ids:
+                nodes_ids.add(node[schemas.NODE_ID])
+                sorted_nodes.append(node)
+
+        # Add claims related to source and target
+        for claims in [claim_details.get(source), claim_details.get(target)]:
+            if claims:
+                for claim in claims:
+                    if claim[schemas.CLAIM_ID] not in claims_ids:
+                        claims_ids.add(claim[schemas.CLAIM_ID])
+                        sorted_claims.append(claim)
+
+        # Add the edge
+        if edge[schemas.EDGE_ID] not in edge_ids:
+            edge_ids.add(edge[schemas.EDGE_ID])
+            sorted_edges.append(edge)
+
+        # Generate new context string
+        new_context_string = _get_context_string(
+            sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
+        )
+        if max_tokens and num_tokens(new_context_string) > max_tokens:
+            break
+        context_string = new_context_string
+
+    # Return the final context string
+    return context_string or _get_context_string(
+        sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
+    )
+
+
+def parallel_sort_context_batch(community_df, max_tokens, parallel=False):
+    """Calculate context using parallelization if enabled."""
+    if parallel:
+        # Use ThreadPoolExecutor for parallel execution
+        from concurrent.futures import ThreadPoolExecutor
+
+        with ThreadPoolExecutor(max_workers=None) as executor:
+            context_strings = list(
+                executor.map(
+                    lambda x: sort_context(x, max_tokens=max_tokens),
+                    community_df[schemas.ALL_CONTEXT],
+                )
             )
-            if num_tokens(new_context_string) > max_tokens:
-                break
-            context_string = new_context_string
+        community_df[schemas.CONTEXT_STRING] = context_strings

-    if context_string == "":
-        return _get_context_string(
-            sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
+    else:
+        # Assign context strings directly to the DataFrame
+        community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
+            lambda context_list: sort_context(context_list, max_tokens=max_tokens)
         )

-    return context_string
+    # Calculate other columns
+    community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
+        num_tokens
+    )
+    community_df[schemas.CONTEXT_EXCEED_FLAG] = (
+        community_df[schemas.CONTEXT_SIZE] > max_tokens
+    )
+
+    return community_df
diff --git a/graphrag/index/graph/extractors/community_reports/utils.py b/graphrag/index/graph/extractors/community_reports/utils.py
index 32d3196f6c..48afd671b9 100644
--- a/graphrag/index/graph/extractors/community_reports/utils.py
+++ b/graphrag/index/graph/extractors/community_reports/utils.py
@@ -3,53 +3,13 @@

 """A module containing community report generation utilities."""

-from typing import cast
-
 import pandas as pd

 import graphrag.index.graph.extractors.community_reports.schemas as schemas
-from graphrag.query.llm.text_utils import num_tokens
-
-
-def set_context_size(df: pd.DataFrame) -> None:
-    """Measure the number of tokens in the context."""
-    df.loc[:, schemas.CONTEXT_SIZE] = df.loc[:, schemas.CONTEXT_STRING].apply(
-        lambda x: num_tokens(x)
-    )
-
-
-def set_context_exceeds_flag(df: pd.DataFrame, max_tokens: int) -> None:
-    """Set a flag to indicate if the context exceeds the limit."""
-    df.loc[:, schemas.CONTEXT_EXCEED_FLAG] = df.loc[:, schemas.CONTEXT_SIZE].apply(
-        lambda x: x > max_tokens
-    )


 def get_levels(df: pd.DataFrame, level_column: str = schemas.NODE_LEVEL) -> list[int]:
     """Get the levels of the communities."""
-    result = sorted(df[level_column].fillna(-1).unique().tolist(), reverse=True)
-    return [r for r in result if r != -1]
-
-
-def filter_nodes_to_level(node_df: pd.DataFrame, level: int) -> pd.DataFrame:
-    """Filter nodes to level."""
-    return cast(pd.DataFrame, node_df[node_df[schemas.NODE_LEVEL] == level])
-
-
-def filter_edges_to_nodes(edge_df: pd.DataFrame, nodes: list[str]) -> pd.DataFrame:
-    """Filter edges to nodes."""
-    return cast(
-        pd.DataFrame,
-        edge_df[
-            edge_df[schemas.EDGE_SOURCE].isin(nodes)
-            & edge_df[schemas.EDGE_TARGET].isin(nodes)
-        ],
-    )
-
-
-def filter_claims_to_nodes(claims_df: pd.DataFrame, nodes: list[str]) -> pd.DataFrame:
-    """Filter edges to nodes."""
-    return cast(
-        pd.DataFrame,
-        claims_df[claims_df[schemas.CLAIM_SUBJECT].isin(nodes)],
-    )
+    levels = df[level_column].dropna().unique()
+    levels = [int(lvl) for lvl in levels if lvl != -1]
+    return sorted(levels, reverse=True)
diff --git a/graphrag/index/operations/summarize_communities/prepare_community_reports.py b/graphrag/index/operations/summarize_communities/prepare_community_reports.py
index b402e6afc8..c2156edd69 100644
--- a/graphrag/index/operations/summarize_communities/prepare_community_reports.py
+++ b/graphrag/index/operations/summarize_communities/prepare_community_reports.py
@@ -4,7 +4,6 @@
 """A module containing create_community_reports and load_strategy methods definition."""

 import logging
-from typing import cast

 import pandas as pd
 from datashaper import (
@@ -13,15 +12,10 @@
 )

 import graphrag.index.graph.extractors.community_reports.schemas as schemas
-from graphrag.index.graph.extractors.community_reports import (
-    filter_claims_to_nodes,
-    filter_edges_to_nodes,
-    filter_nodes_to_level,
-    get_levels,
-    set_context_exceeds_flag,
-    set_context_size,
-    sort_context,
+from graphrag.index.graph.extractors.community_reports.sort_context import (
+    parallel_sort_context_batch,
 )
+from graphrag.index.graph.extractors.community_reports.utils import get_levels

 log = logging.getLogger(__name__)

@@ -35,12 +29,15 @@ def prepare_community_reports(
 ):
     """Prep communities for report generation."""
     levels = get_levels(nodes, schemas.NODE_LEVEL)
+
     dfs = []

     for level in progress_iterable(levels, callbacks.progress, len(levels)):
         communities_at_level_df = _prepare_reports_at_level(
             nodes, edges, claims, level, max_tokens
         )
+
+        communities_at_level_df.loc[:, schemas.COMMUNITY_LEVEL] = level
         dfs.append(communities_at_level_df)

     # build initial local context for all communities
@@ -53,127 +50,121 @@ def _prepare_reports_at_level(
     claim_df: pd.DataFrame | None,
     level: int,
     max_tokens: int = 16_000,
-    community_id_column: str = schemas.COMMUNITY_ID,
-    node_id_column: str = schemas.NODE_ID,
-    node_name_column: str = schemas.NODE_NAME,
-    node_details_column: str = schemas.NODE_DETAILS,
-    node_level_column: str = schemas.NODE_LEVEL,
-    node_degree_column: str = schemas.NODE_DEGREE,
-    node_community_column: str = schemas.NODE_COMMUNITY,
-    edge_id_column: str = schemas.EDGE_ID,
-    edge_source_column: str = schemas.EDGE_SOURCE,
-    edge_target_column: str = schemas.EDGE_TARGET,
-    edge_degree_column: str = schemas.EDGE_DEGREE,
-    edge_details_column: str = schemas.EDGE_DETAILS,
-    claim_id_column: str = schemas.CLAIM_ID,
-    claim_subject_column: str = schemas.CLAIM_SUBJECT,
-    claim_details_column: str = schemas.CLAIM_DETAILS,
-):
-    def get_edge_details(node_df: pd.DataFrame, edge_df: pd.DataFrame, name_col: str):
-        return node_df.merge(
-            cast(
-                pd.DataFrame,
-                edge_df[[name_col, schemas.EDGE_DETAILS]],
-            ).rename(columns={name_col: schemas.NODE_NAME}),
-            on=schemas.NODE_NAME,
-            how="left",
-        )
-
-    level_node_df = filter_nodes_to_level(node_df, level)
+) -> pd.DataFrame:
+    """Prepare reports at a given level."""
+    # Filter and prepare node details
+    level_node_df = node_df[node_df[schemas.NODE_LEVEL] == level]
     log.info("Number of nodes at level=%s => %s", level, len(level_node_df))
-    nodes = level_node_df[node_name_column].tolist()
-
-    # Filter edges & claims to those containing the target nodes
-    level_edge_df = filter_edges_to_nodes(edge_df, nodes)
-    level_claim_df = (
-        filter_claims_to_nodes(claim_df, nodes) if claim_df is not None else None
-    )
-
-    # concat all edge details per node
-    merged_node_df = pd.concat(
+    nodes_set = set(level_node_df[schemas.NODE_NAME])
+
+    # Filter and prepare edge details
+    level_edge_df = edge_df[
+        edge_df.loc[:, schemas.EDGE_SOURCE].isin(nodes_set)
+        & edge_df.loc[:, schemas.EDGE_TARGET].isin(nodes_set)
+    ]
+    level_edge_df.loc[:, schemas.EDGE_DETAILS] = level_edge_df.loc[
+        :,
         [
-            get_edge_details(level_node_df, level_edge_df, edge_source_column),
-            get_edge_details(level_node_df, level_edge_df, edge_target_column),
+            schemas.EDGE_ID,
+            schemas.EDGE_SOURCE,
+            schemas.EDGE_TARGET,
+            schemas.EDGE_DESCRIPTION,
+            schemas.EDGE_DEGREE,
         ],
-        axis=0,
+    ].to_dict(orient="records")
+
+    level_claim_df = pd.DataFrame()
+    if claim_df is not None:
+        level_claim_df = claim_df[
+            claim_df.loc[:, schemas.CLAIM_SUBJECT].isin(nodes_set)
+        ]
+
+    # Merge node and edge details
+    # Group edge details by node and aggregate into lists
+    source_edges = (
+        level_edge_df.groupby(schemas.EDGE_SOURCE)
+        .agg({schemas.EDGE_DETAILS: "first"})
+        .reset_index()
+        .rename(columns={schemas.EDGE_SOURCE: schemas.NODE_NAME})
     )
-    merged_node_df = (
-        merged_node_df.groupby([
-            node_name_column,
-            node_community_column,
-            node_degree_column,
-            node_level_column,
-        ])
-        .agg({node_details_column: "first", edge_details_column: list})
+
+    target_edges = (
+        level_edge_df.groupby(schemas.EDGE_TARGET)
+        .agg({schemas.EDGE_DETAILS: "first"})
         .reset_index()
+        .rename(columns={schemas.EDGE_TARGET: schemas.NODE_NAME})
     )

-    # concat claim details per node
-    if level_claim_df is not None:
-        merged_node_df = merged_node_df.merge(
-            cast(
-                pd.DataFrame,
-                level_claim_df[[claim_subject_column, claim_details_column]],
-            ).rename(columns={claim_subject_column: node_name_column}),
-            on=node_name_column,
-            how="left",
-        )
+    # Merge aggregated edges into the node DataFrame
+    merged_node_df = level_node_df.merge(
+        source_edges, on=schemas.NODE_NAME, how="left"
+    ).merge(target_edges, on=schemas.NODE_NAME, how="left")
+
+    # Combine source and target edge details into a single column
+    merged_node_df.loc[:, schemas.EDGE_DETAILS] = merged_node_df.loc[
+        :, f"{schemas.EDGE_DETAILS}_x"
+    ].combine_first(merged_node_df.loc[:, f"{schemas.EDGE_DETAILS}_y"])
+
+    # Drop intermediate columns
+    merged_node_df.drop(
+        columns=[f"{schemas.EDGE_DETAILS}_x", f"{schemas.EDGE_DETAILS}_y"], inplace=True
+    )
+
+    # Aggregate node and edge details
     merged_node_df = (
         merged_node_df.groupby([
-            node_name_column,
-            node_community_column,
-            node_level_column,
-            node_degree_column,
+            schemas.NODE_NAME,
+            schemas.NODE_COMMUNITY,
+            schemas.NODE_LEVEL,
+            schemas.NODE_DEGREE,
         ])
         .agg({
-            node_details_column: "first",
-            edge_details_column: "first",
-            **({claim_details_column: list} if level_claim_df is not None else {}),
+            schemas.NODE_DETAILS: "first",
+            schemas.EDGE_DETAILS: lambda x: list(x.dropna()),
         })
         .reset_index()
     )

-    # concat all node details, including name, degree, node_details, edge_details, and claim_details
-    merged_node_df[schemas.ALL_CONTEXT] = merged_node_df.apply(
-        lambda x: {
-            node_name_column: x[node_name_column],
-            node_degree_column: x[node_degree_column],
-            node_details_column: x[node_details_column],
-            edge_details_column: x[edge_details_column],
-            claim_details_column: x[claim_details_column]
-            if level_claim_df is not None
-            else [],
-        },
-        axis=1,
+    # Add ALL_CONTEXT column
+    # Ensure schemas.CLAIM_DETAILS exists with the correct length
+    # Merge claim details if available
+    if claim_df is not None:
+        merged_node_df = merged_node_df.merge(
+            level_claim_df.loc[
+                :, [schemas.CLAIM_SUBJECT, schemas.CLAIM_DETAILS]
+            ].rename(columns={schemas.CLAIM_SUBJECT: schemas.NODE_NAME}),
+            on=schemas.NODE_NAME,
+            how="left",
+        )
+
+    # Create the ALL_CONTEXT column
+    merged_node_df[schemas.ALL_CONTEXT] = (
+        merged_node_df.loc[
+            :,
+            [
+                schemas.NODE_NAME,
+                schemas.NODE_DEGREE,
+                schemas.NODE_DETAILS,
+                schemas.EDGE_DETAILS,
+            ],
+        ]
+        .assign(
+            **{schemas.CLAIM_DETAILS: merged_node_df[schemas.CLAIM_DETAILS]}
+            if claim_df is not None
+            else {}
+        )
+        .to_dict(orient="records")
     )

     # group all node details by community
     community_df = (
-        merged_node_df.groupby(node_community_column)
+        merged_node_df.groupby(schemas.NODE_COMMUNITY)
         .agg({schemas.ALL_CONTEXT: list})
         .reset_index()
     )

-    community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
-        lambda x: sort_context(
-            x,
-            node_id_column=node_id_column,
-            node_name_column=node_name_column,
-            node_details_column=node_details_column,
-            edge_id_column=edge_id_column,
-            edge_details_column=edge_details_column,
-            edge_degree_column=edge_degree_column,
-            edge_source_column=edge_source_column,
-            edge_target_column=edge_target_column,
-            claim_id_column=claim_id_column,
-            claim_details_column=claim_details_column,
-            community_id_column=community_id_column,
-        )
-    )
-    set_context_size(community_df)
-    set_context_exceeds_flag(community_df, max_tokens)
-    community_df[schemas.COMMUNITY_LEVEL] = level
-    community_df[node_community_column] = community_df[node_community_column].astype(
-        int
+    # Generate community-level context strings using vectorized batch processing
+    return parallel_sort_context_batch(
+        community_df,
+        max_tokens=max_tokens,
     )
-    return community_df
diff --git a/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py b/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py
index 2430472abd..2512db484f 100644
--- a/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py
+++ b/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py
@@ -4,6 +4,7 @@
 """A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition."""

 import logging
+from itertools import pairwise

 import pandas as pd

@@ -19,50 +20,39 @@ def restore_community_hierarchy(
     level_column: str = schemas.NODE_LEVEL,
 ) -> pd.DataFrame:
     """Restore the community hierarchy from the node data."""
+    # Group by community and level, aggregate names as lists
     community_df = (
-        input.groupby([community_column, level_column])
-        .agg({name_column: list})
+        input.groupby([community_column, level_column])[name_column]
+        .apply(set)
         .reset_index()
     )
-    community_levels = {}
-    for _, row in community_df.iterrows():
-        level = row[level_column]
-        name = row[name_column]
-        community = row[community_column]
-        if community_levels.get(level) is None:
-            community_levels[level] = {}
-        community_levels[level][community] = name
+    # Build dictionary with levels as integers
+    community_levels = {
+        level: group.set_index(community_column)[name_column].to_dict()
+        for level, group in community_df.groupby(level_column)
+    }

     # get unique levels, sorted in ascending order
-    levels = sorted(community_levels.keys())
+    levels = sorted(community_levels.keys())  # type: ignore

     community_hierarchy = []

-    for idx in range(len(levels) - 1):
-        level = levels[idx]
-        next_level = levels[idx + 1]
-        current_level_communities = community_levels[level]
-        next_level_communities = community_levels[next_level]
+    # Iterate through adjacent levels
+    for current_level, next_level in pairwise(levels):
+        current_communities = community_levels[current_level]
+        next_communities = community_levels[next_level]

-        for current_community in current_level_communities:
-            current_entities = current_level_communities[current_community]
-
-            # loop through next level's communities to find all the subcommunities
-            entities_found = 0
-            for next_level_community in next_level_communities:
-                next_entities = next_level_communities[next_level_community]
-                if set(next_entities).issubset(set(current_entities)):
+        # Find sub-communities
+        for curr_comm, curr_entities in current_communities.items():
+            for next_comm, next_entities in next_communities.items():
+                if next_entities.issubset(curr_entities):
                     community_hierarchy.append({
-                        community_column: current_community,
-                        schemas.COMMUNITY_LEVEL: level,
-                        schemas.SUB_COMMUNITY: next_level_community,
+                        community_column: curr_comm,
+                        schemas.COMMUNITY_LEVEL: current_level,
+                        schemas.SUB_COMMUNITY: next_comm,
                         schemas.SUB_COMMUNITY_SIZE: len(next_entities),
                     })
-                    entities_found += len(next_entities)
-                    if entities_found == len(current_entities):
-                        break
-
     return pd.DataFrame(
         community_hierarchy,
     )
diff --git a/graphrag/index/operations/summarize_communities/summarize_communities.py b/graphrag/index/operations/summarize_communities/summarize_communities.py
index bab02ef5e1..35a0192b82 100644
--- a/graphrag/index/operations/summarize_communities/summarize_communities.py
+++ b/graphrag/index/operations/summarize_communities/summarize_communities.py
@@ -18,9 +18,9 @@
 import graphrag.index.graph.extractors.community_reports.schemas as schemas
 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.index.graph.extractors.community_reports import (
-    get_levels,
     prep_community_report_context,
 )
+from graphrag.index.graph.extractors.community_reports.utils import get_levels
 from graphrag.index.operations.summarize_communities.typing import (
     CommunityReport,
     CommunityReportsStrategy,
diff --git a/graphrag/index/utils/dataframes.py b/graphrag/index/utils/dataframes.py
index ea65d71d7a..b5a71bf397 100644
--- a/graphrag/index/utils/dataframes.py
+++ b/graphrag/index/utils/dataframes.py
@@ -28,15 +28,7 @@ def antijoin(df: pd.DataFrame, exclude: pd.DataFrame, column: str) -> pd.DataFra
     * exclude: The DataFrame containing rows to remove.
     * column: The join-on column.
     """
-    result = df.merge(
-        exclude[[column]],
-        on=column,
-        how="outer",
-        indicator=True,
-    )
-    if "_merge" in result.columns:
-        result = result[result["_merge"] == "left_only"].drop("_merge", axis=1)
-    return cast(pd.DataFrame, result)
+    return df.loc[~df.loc[:, column].isin(exclude.loc[:, column])]


 def transform_series(series: pd.Series, fn: Callable[[Any], Any]) -> pd.Series:
diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py
index 1dbf38e66b..513b9e65b6 100644
--- a/graphrag/query/structured_search/local_search/mixed_context.py
+++ b/graphrag/query/structured_search/local_search/mixed_context.py
@@ -72,7 +72,7 @@ def __init__(
             text_units = []
         self.entities = {entity.id: entity for entity in entities}
         self.community_reports = {
-            community.id: community for community in community_reports
+            community.community_id: community for community in community_reports
        }
         self.text_units = {unit.id: unit for unit in text_units}
         self.relationships = {
@@ -254,7 +254,7 @@ def _build_community_context(
         for community in selected_communities:
             if community.attributes is None:
                 community.attributes = {}
-            community.attributes["matches"] = community_matches[community.id]
+            community.attributes["matches"] = community_matches[community.community_id]
         selected_communities.sort(
             key=lambda x: (x.attributes["matches"], x.rank),  # type: ignore
             reverse=True,  # type: ignore
diff --git a/tests/verbs/data/base_communities.parquet b/tests/verbs/data/base_communities.parquet
index e06c927ba9..232923a844 100644
Binary files a/tests/verbs/data/base_communities.parquet and b/tests/verbs/data/base_communities.parquet differ
diff --git a/tests/verbs/data/base_entity_nodes.parquet b/tests/verbs/data/base_entity_nodes.parquet
index 58a6d5ddf2..8db3f1ee29 100644
Binary files a/tests/verbs/data/base_entity_nodes.parquet and b/tests/verbs/data/base_entity_nodes.parquet differ
diff --git a/tests/verbs/data/base_relationship_edges.parquet b/tests/verbs/data/base_relationship_edges.parquet
index 3e66fd229e..deb64dd476 100644
Binary files a/tests/verbs/data/base_relationship_edges.parquet and b/tests/verbs/data/base_relationship_edges.parquet differ
diff --git a/tests/verbs/data/create_final_communities.parquet b/tests/verbs/data/create_final_communities.parquet
index add74505bf..2d23a0306d 100644
Binary files a/tests/verbs/data/create_final_communities.parquet and b/tests/verbs/data/create_final_communities.parquet differ
diff --git a/tests/verbs/data/create_final_community_reports.parquet b/tests/verbs/data/create_final_community_reports.parquet
index 3c72a762c2..1e27b83fce 100644
Binary files a/tests/verbs/data/create_final_community_reports.parquet and b/tests/verbs/data/create_final_community_reports.parquet differ
diff --git a/tests/verbs/data/create_final_covariates.parquet b/tests/verbs/data/create_final_covariates.parquet
index 7969f19565..2ee7160d19 100644
Binary files a/tests/verbs/data/create_final_covariates.parquet and b/tests/verbs/data/create_final_covariates.parquet differ
diff --git a/tests/verbs/data/create_final_entities.parquet b/tests/verbs/data/create_final_entities.parquet
index 8d3ed4451a..2eb48e6a5e 100644
Binary files a/tests/verbs/data/create_final_entities.parquet and b/tests/verbs/data/create_final_entities.parquet differ
diff --git a/tests/verbs/data/create_final_nodes.parquet b/tests/verbs/data/create_final_nodes.parquet
index 93432fea07..22c9db9c39 100644
Binary files a/tests/verbs/data/create_final_nodes.parquet and b/tests/verbs/data/create_final_nodes.parquet differ
diff --git a/tests/verbs/data/create_final_relationships.parquet b/tests/verbs/data/create_final_relationships.parquet
index e3c822fbe0..e104cd507f 100644
Binary files a/tests/verbs/data/create_final_relationships.parquet and b/tests/verbs/data/create_final_relationships.parquet differ
diff --git a/tests/verbs/data/create_final_text_units.parquet b/tests/verbs/data/create_final_text_units.parquet
index be6da45fa3..39cc20d415 100644
Binary files a/tests/verbs/data/create_final_text_units.parquet and b/tests/verbs/data/create_final_text_units.parquet differ