Commit

Flatten create_base_text_units config

natoverse committed Dec 24, 2024
1 parent 19649cc commit 75cce3f
Showing 11 changed files with 116 additions and 113 deletions.
7 changes: 5 additions & 2 deletions graphrag/config/create_graphrag_config.py
@@ -31,7 +31,7 @@
from graphrag.config.input_models.graphrag_config_input import GraphRagConfigInput
from graphrag.config.input_models.llm_config_input import LLMConfigInput
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
@@ -412,12 +412,15 @@ def hydrate_parallelization_params(
encoding_model = (
reader.str(Fragment.encoding_model) or global_encoding_model
)

strategy = reader.str("strategy")
chunks_model = ChunkingConfig(
size=reader.int("size") or defs.CHUNK_SIZE,
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
group_by_columns=group_by_columns,
encoding_model=encoding_model,
strategy=ChunkStrategyType(strategy)
if strategy
else ChunkStrategyType.tokens,
)
with (
reader.envvar_prefix(Section.snapshot),
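As a quick illustration of the new strategy handling above, here is a minimal sketch (not from the repo; `build_chunks_model` is a hypothetical helper) of coercing an optional strategy string into the enum with a token-chunking fallback:

```python
# Hypothetical helper illustrating the fallback added above: an optional
# strategy string becomes a ChunkStrategyType, defaulting to token chunking.
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType


def build_chunks_model(strategy: str | None) -> ChunkingConfig:
    return ChunkingConfig(
        strategy=ChunkStrategyType(strategy) if strategy else ChunkStrategyType.tokens,
    )


assert build_chunks_model("sentence").strategy is ChunkStrategyType.sentence
assert build_chunks_model(None).strategy is ChunkStrategyType.tokens
```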
34 changes: 17 additions & 17 deletions graphrag/config/models/chunking_config.py
@@ -3,11 +3,24 @@

"""Parameterization settings for the default configuration."""

from enum import Enum

from pydantic import BaseModel, Field

import graphrag.config.defaults as defs


class ChunkStrategyType(str, Enum):
"""ChunkStrategy class definition."""

tokens = "tokens"
sentence = "sentence"

def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'


class ChunkingConfig(BaseModel):
"""Configuration section for chunking."""

@@ -19,22 +32,9 @@ class ChunkingConfig(BaseModel):
description="The chunk by columns to use.",
default=defs.CHUNK_GROUP_BY_COLUMNS,
)
strategy: dict | None = Field(
description="The chunk strategy to use, overriding the default tokenization strategy",
default=None,
strategy: ChunkStrategyType = Field(
description="The chunking strategy to use.", default=ChunkStrategyType.tokens
)
encoding_model: str | None = Field(
default=None, description="The encoding model to use."
encoding_model: str = Field(
description="The encoding model to use.", default=defs.ENCODING_MODEL
)

def resolved_strategy(self, encoding_model: str | None) -> dict:
"""Get the resolved chunking strategy."""
from graphrag.index.operations.chunk_text import ChunkStrategyType

return self.strategy or {
"type": ChunkStrategyType.tokens,
"chunk_size": self.size,
"chunk_overlap": self.overlap,
"group_by_columns": self.group_by_columns,
"encoding_name": encoding_model or self.encoding_model,
}
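In short, the nested `strategy` dict and the `resolved_strategy` helper are gone; the model now carries flat, fully defaulted fields. A small illustrative check (values are arbitrary; enum-by-value coercion is standard pydantic behavior):

```python
# Illustrative only: pydantic coerces the plain string "sentence" into the enum,
# and encoding_model now defaults to defs.ENCODING_MODEL instead of None.
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType

cfg = ChunkingConfig(size=500, overlap=50, strategy="sentence")
assert cfg.strategy is ChunkStrategyType.sentence
assert cfg.encoding_model is not None
```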
7 changes: 1 addition & 6 deletions graphrag/index/create_pipeline_config.py
@@ -176,13 +176,8 @@ def _text_unit_workflows(
PipelineWorkflowReference(
name=create_base_text_units,
config={
"chunks": settings.chunks,
"snapshot_transient": settings.snapshots.transient,
"chunk_by": settings.chunks.group_by_columns,
"text_chunk": {
"strategy": settings.chunks.resolved_strategy(
settings.encoding_model
)
},
},
),
PipelineWorkflowReference(
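The practical effect is that the workflow config no longer pre-resolves a strategy dict; downstream code reads the chunking parameters straight off the model. A hedged sketch of that usage (the config keys come from the diff above; the values are illustrative):

```python
# Hypothetical downstream usage: the workflow receives the whole model under
# the "chunks" key and reads the flat fields directly, with no nested
# "text_chunk"/"strategy" dict to unpack.
from graphrag.config.models.chunking_config import ChunkingConfig

config = {"chunks": ChunkingConfig(), "snapshot_transient": False}
chunks = config["chunks"]
size, overlap, strategy = chunks.size, chunks.overlap, chunks.strategy
```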
40 changes: 25 additions & 15 deletions graphrag/index/flows/create_base_text_units.py
@@ -14,15 +14,19 @@
aggregate_operation_mapping,
)

from graphrag.index.operations.chunk_text import chunk_text
from graphrag.config.models.chunking_config import ChunkStrategyType
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
from graphrag.index.utils.hashing import gen_sha512_hash


def create_base_text_units(
documents: pd.DataFrame,
callbacks: VerbCallbacks,
chunk_by_columns: list[str],
chunk_strategy: dict[str, Any] | None = None,
group_by_columns: list[str],
size: int,
overlap: int,
encoding_model: str,
strategy: ChunkStrategyType,
) -> pd.DataFrame:
"""All the steps to transform base text_units."""
sort = documents.sort_values(by=["id"], ascending=[True])
@@ -35,7 +39,7 @@ def create_base_text_units(

aggregated = _aggregate_df(
sort,
groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
groupby=[*group_by_columns] if len(group_by_columns) > 0 else None,
aggregations=[
{
"column": "text_with_ids",
@@ -47,30 +51,36 @@ def create_base_text_units(

callbacks.progress(Progress(percent=1))

chunked = chunk_text(
aggregated["chunks"] = chunk_text(
aggregated,
column="texts",
to="chunks",
size=size,
overlap=overlap,
encoding_model=encoding_model,
strategy=strategy,
callbacks=callbacks,
strategy=chunk_strategy,
)

chunked = cast("pd.DataFrame", chunked[[*chunk_by_columns, "chunks"]])
chunked = chunked.explode("chunks")
chunked.rename(
aggregated = cast("pd.DataFrame", aggregated[[*group_by_columns, "chunks"]])
aggregated = aggregated.explode("chunks")
aggregated.rename(
columns={
"chunks": "chunk",
},
inplace=True,
)
chunked["id"] = chunked.apply(lambda row: gen_sha512_hash(row, ["chunk"]), axis=1)
chunked[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(
chunked["chunk"].tolist(), index=chunked.index
aggregated["id"] = aggregated.apply(
lambda row: gen_sha512_hash(row, ["chunk"]), axis=1
)
aggregated[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(
aggregated["chunk"].tolist(), index=aggregated.index
)
# rename for downstream consumption
chunked.rename(columns={"chunk": "text"}, inplace=True)
aggregated.rename(columns={"chunk": "text"}, inplace=True)

return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))
return cast(
"pd.DataFrame", aggregated[aggregated["text"].notna()].reset_index(drop=True)
)


# TODO: would be nice to inline this completely in the main method with pandas
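For orientation, a hedged sketch of driving the reworked flow with its flattened signature; the input columns (`id`, `text`) and the `NoopVerbCallbacks` helper are assumptions for illustration, not taken from this diff:

```python
# A minimal, hypothetical driver for the reworked flow; values are arbitrary.
import pandas as pd
from datashaper import NoopVerbCallbacks  # assumed no-op callbacks helper

from graphrag.config.models.chunking_config import ChunkStrategyType
from graphrag.index.flows.create_base_text_units import create_base_text_units

docs = pd.DataFrame({"id": ["doc-1"], "text": ["some long document text ... " * 50]})

text_units = create_base_text_units(
    docs,
    callbacks=NoopVerbCallbacks(),
    group_by_columns=["id"],
    size=300,
    overlap=50,
    encoding_model="cl100k_base",
    strategy=ChunkStrategyType.tokens,
)
# Expect one row per chunk, with document_ids, text, n_tokens, and a hashed id.
```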
8 changes: 0 additions & 8 deletions graphrag/index/operations/chunk_text/__init__.py
@@ -2,11 +2,3 @@
# Licensed under the MIT License

"""The Indexing Engine text chunk package root."""

from graphrag.index.operations.chunk_text.chunk_text import (
ChunkStrategy,
ChunkStrategyType,
chunk_text,
)

__all__ = ["ChunkStrategy", "ChunkStrategyType", "chunk_text"]
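With the package root no longer re-exporting these symbols, callers import from the concrete modules instead (the same paths the other files in this commit now use):

```python
# New import paths after the re-exports were removed from the package root.
from graphrag.config.models.chunking_config import ChunkStrategyType
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
from graphrag.index.operations.chunk_text.typing import ChunkStrategy
```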
46 changes: 23 additions & 23 deletions graphrag/index/operations/chunk_text/chunk_text.py
@@ -12,20 +12,22 @@
progress_ticker,
)

from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
from graphrag.index.operations.chunk_text.typing import (
ChunkInput,
ChunkStrategy,
ChunkStrategyType,
)


def chunk_text(
input: pd.DataFrame,
column: str,
to: str,
size: int,
overlap: int,
encoding_model: str,
strategy: ChunkStrategyType,
callbacks: VerbCallbacks,
strategy: dict[str, Any] | None = None,
) -> pd.DataFrame:
) -> pd.Series:
"""
Chunk a piece of text into smaller pieces.
@@ -60,35 +62,33 @@ def chunk_text(
type: sentence
```
"""
output = input
if strategy is None:
strategy = {}
strategy_name = strategy.get("type", ChunkStrategyType.tokens)
strategy_config = {**strategy}
strategy_exec = load_strategy(strategy_name)

num_total = _get_num_total(output, column)
tick = progress_ticker(callbacks.progress, num_total)
strategy_exec = load_strategy(strategy)

output[to] = output.apply(
cast(
"Any",
lambda x: run_strategy(strategy_exec, x[column], strategy_config, tick),
num_total = _get_num_total(input, column)
tick = progress_ticker(callbacks.progress, num_total)
# collapse the config back to a single object to support "polymorphic" function call
config = ChunkingConfig(size=size, overlap=overlap, encoding_model=encoding_model)
return cast(
"pd.Series",
input.apply(
cast(
"Any",
lambda x: run_strategy(strategy_exec, x[column], config, tick),
),
axis=1,
),
axis=1,
)
return output


def run_strategy(
strategy: ChunkStrategy,
strategy_exec: ChunkStrategy,
input: ChunkInput,
strategy_args: dict[str, Any],
config: ChunkingConfig,
tick: ProgressTicker,
) -> list[str | tuple[list[str] | None, str, int]]:
"""Run strategy method definition."""
if isinstance(input, str):
return [item.text_chunk for item in strategy([input], {**strategy_args}, tick)]
return [item.text_chunk for item in strategy_exec([input], config, tick)]

# We can work with both just a list of text content
# or a list of tuples of (document_id, text content)
@@ -100,7 +100,7 @@ def run_strategy(
else:
texts.append(item[1])

strategy_results = strategy(texts, {**strategy_args}, tick)
strategy_results = strategy_exec(texts, config, tick)

results = []
for strategy_result in strategy_results:
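A hedged sketch of the new call shape: `chunk_text` now takes the flat chunking parameters plus a strategy enum and returns a `pd.Series`, which the caller assigns to a column itself. The sample frame and `NoopVerbCallbacks` are illustrative assumptions:

```python
# Illustrative call of the reworked chunk_text; values are arbitrary.
import pandas as pd
from datashaper import NoopVerbCallbacks

from graphrag.config.models.chunking_config import ChunkStrategyType
from graphrag.index.operations.chunk_text.chunk_text import chunk_text

df = pd.DataFrame({"texts": ["alpha bravo charlie " * 100]})
df["chunks"] = chunk_text(
    df,
    column="texts",
    size=64,
    overlap=8,
    encoding_model="cl100k_base",
    strategy=ChunkStrategyType.tokens,
    callbacks=NoopVerbCallbacks(),
)
# Each cell of df["chunks"] holds the list of chunks produced for that row.
```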
13 changes: 6 additions & 7 deletions graphrag/index/operations/chunk_text/strategies.py
@@ -4,24 +4,23 @@
"""A module containing chunk strategies."""

from collections.abc import Iterable
from typing import Any

import nltk
import tiktoken
from datashaper import ProgressTicker

import graphrag.config.defaults as defs
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.index.text_splitting.text_splitting import Tokenizer


def run_tokens(
input: list[str], args: dict[str, Any], tick: ProgressTicker
input: list[str], config: ChunkingConfig, tick: ProgressTicker
) -> Iterable[TextChunk]:
"""Chunks text into chunks based on encoding tokens."""
tokens_per_chunk = args.get("chunk_size", defs.CHUNK_SIZE)
chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP)
encoding_name = args.get("encoding_name", defs.ENCODING_MODEL)
tokens_per_chunk = config.size
chunk_overlap = config.overlap
encoding_name = config.encoding_model
enc = tiktoken.get_encoding(encoding_name)

def encode(text: str) -> list[int]:
@@ -83,7 +82,7 @@ def _split_text_on_tokens(


def run_sentences(
input: list[str], _args: dict[str, Any], tick: ProgressTicker
input: list[str], _config: ChunkingConfig, tick: ProgressTicker
) -> Iterable[TextChunk]:
"""Chunks text into multiple parts by sentence."""
for doc_idx, text in enumerate(input):
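The strategies now read their parameters off the config model rather than a loose args dict. A small, illustrative mapping between the old keys and the new fields (names are from the diff; the values are placeholders):

```python
# Illustrative only: the old args-dict keys map onto ChunkingConfig fields.
from graphrag.config.models.chunking_config import ChunkingConfig

old_args = {"chunk_size": 1200, "chunk_overlap": 100, "encoding_name": "cl100k_base"}
config = ChunkingConfig(
    size=old_args["chunk_size"],
    overlap=old_args["chunk_overlap"],
    encoding_model=old_args["encoding_name"],
)
assert (config.size, config.overlap, config.encoding_model) == (1200, 100, "cl100k_base")
```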
17 changes: 3 additions & 14 deletions graphrag/index/operations/chunk_text/typing.py
@@ -5,11 +5,11 @@

from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from typing import Any

from datashaper import ProgressTicker

from graphrag.config.models.chunking_config import ChunkingConfig


@dataclass
class TextChunk:
@@ -24,16 +24,5 @@ class TextChunk:
"""Input to a chunking strategy. Can be a string, a list of strings, or a list of tuples of (id, text)."""

ChunkStrategy = Callable[
[list[str], dict[str, Any], ProgressTicker], Iterable[TextChunk]
[list[str], ChunkingConfig, ProgressTicker], Iterable[TextChunk]
]


class ChunkStrategyType(str, Enum):
"""ChunkStrategy class definition."""

tokens = "tokens"
sentence = "sentence"

def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
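Because `ChunkStrategy` now takes a `ChunkingConfig`, a custom strategy only needs to match the new callable shape. A toy sketch (not part of the repo; the `TextChunk` field names are assumed from the existing dataclass):

```python
# Toy strategy matching the new ChunkStrategy signature: one chunk per
# non-empty line, ignoring the config. Purely illustrative.
from collections.abc import Iterable

from datashaper import ProgressTicker

from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.typing import TextChunk


def run_lines(
    input: list[str], _config: ChunkingConfig, tick: ProgressTicker
) -> Iterable[TextChunk]:
    """Chunk each document into one TextChunk per non-empty line."""
    for doc_idx, text in enumerate(input):
        for line in filter(None, text.splitlines()):
            yield TextChunk(text_chunk=line, source_doc_indices=[doc_idx])
        tick(1)
```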
(3 of the 11 changed files are not shown above.)
