
Commit 1755afb

Collapse create base text units (#1178)
* Collapse non-attribute verbs
* Include document_column_attributes in collapse
* Remove merge_override verb
* Semver
* Setup initial test and config
* Collapse create_base_text_units
* Semver
* Spelling
* Fix smoke tests
* Address PR comments

Co-authored-by: Alonso Guevara <[email protected]>
1 parent be7d3eb commit 1755afb

File tree

11 files changed: +180 -96 lines

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Collapse create_base_text_units."
+}

dictionary.txt

Lines changed: 1 addition & 0 deletions

@@ -100,6 +100,7 @@ aembed
 dedupe
 dropna
 dtypes
+notna

 # LLM Terms
 AOAI

graphrag/index/verbs/genid.py

Lines changed: 25 additions & 11 deletions

@@ -16,7 +16,7 @@ def genid(
     input: VerbInput,
     to: str,
     method: str = "md5_hash",
-    hash: list[str] = [], # noqa A002
+    hash: list[str] | None = None, # noqa A002
     **_kwargs: dict,
 ) -> TableContainer:
     """
@@ -52,15 +52,29 @@ def genid(
     """
     data = cast(pd.DataFrame, input.source.table)

-    if method == "md5_hash":
-        if len(hash) == 0:
-            msg = 'Must specify the "hash" columns to use md5_hash method'
+    output = genid_df(data, to, method, hash)
+
+    return TableContainer(table=output)
+
+
+def genid_df(
+    input: pd.DataFrame,
+    to: str,
+    method: str = "md5_hash",
+    hash: list[str] | None = None, # noqa A002
+):
+    """Generate a unique id for each row in the tabular data."""
+    data = input
+    match method:
+        case "md5_hash":
+            if not hash:
+                msg = 'Must specify the "hash" columns to use md5_hash method'
+                raise ValueError(msg)
+            data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
+        case "increment":
+            data[to] = data.index + 1
+        case _:
+            msg = f"Unknown method {method}"
             raise ValueError(msg)

-        data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
-    elif method == "increment":
-        data[to] = data.index + 1
-    else:
-        msg = f"Unknown method {method}"
-        raise ValueError(msg)
-    return TableContainer(table=data)
+    return data
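
The refactor splits the datashaper verb from a plain-DataFrame helper, so genid_df can now be exercised without a VerbInput wrapper. A minimal sketch of calling it directly; the sample frame and column names are invented for illustration:

import pandas as pd

from graphrag.index.verbs.genid import genid_df

# Invented sample data; any columns can be hashed.
docs = pd.DataFrame({"title": ["doc-a", "doc-b"], "text": ["hello", "world"]})

# md5_hash derives a deterministic id from the named columns of each row.
hashed = genid_df(docs, to="chunk_id", method="md5_hash", hash=["title", "text"])

# increment simply numbers rows 1..N from the index.
numbered = genid_df(docs.copy(), to="row_id", method="increment")

Note that genid_df mutates and returns the frame it is given, which is why the second call works on a copy.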

graphrag/index/verbs/text/chunk/text_chunk.py

Lines changed: 17 additions & 2 deletions

@@ -85,9 +85,24 @@ def chunk(
         type: sentence
     ```
     """
+    input_table = cast(pd.DataFrame, input.get_input())
+
+    output = chunk_df(input_table, column, to, callbacks, strategy)
+
+    return TableContainer(table=output)
+
+
+def chunk_df(
+    input: pd.DataFrame,
+    column: str,
+    to: str,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any] | None = None,
+) -> pd.DataFrame:
+    """Chunk a piece of text into smaller pieces."""
+    output = input
     if strategy is None:
         strategy = {}
-    output = cast(pd.DataFrame, input.get_input())
     strategy_name = strategy.get("type", ChunkStrategyType.tokens)
     strategy_config = {**strategy}
     strategy_exec = load_strategy(strategy_name)
@@ -102,7 +117,7 @@ def chunk(
         ),
         axis=1,
     )
-    return TableContainer(table=output)
+    return output


 def run_strategy(
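
chunk_df is likewise callable on a bare DataFrame. A rough sketch, assuming datashaper's NoopVerbCallbacks is available as a no-op callbacks object; the sample frame is invented, and the real pipeline puts (id, text) tuples rather than bare strings in the texts cells:

import pandas as pd
from datashaper import NoopVerbCallbacks  # assumed no-op VerbCallbacks implementation

from graphrag.index.verbs.text.chunk.text_chunk import chunk_df

# One row whose "texts" cell holds the documents to split (invented sample data).
frame = pd.DataFrame({"texts": [["first long document ...", "second long document ..."]]})

chunked = chunk_df(
    frame,
    column="texts",
    to="chunks",
    callbacks=NoopVerbCallbacks(),
    strategy=None,  # None falls back to the tokens strategy with its defaults
)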

graphrag/index/workflows/v1/create_base_text_units.py

Lines changed: 6 additions & 81 deletions

@@ -22,91 +22,16 @@ def build_steps(
     chunk_column_name = config.get("chunk_column", "chunk")
     chunk_by_columns = config.get("chunk_by", []) or []
     n_tokens_column_name = config.get("n_tokens_column", "n_tokens")
+    text_chunk = config.get("text_chunk", {})
     return [
         {
-            "verb": "orderby",
+            "verb": "create_base_text_units",
             "args": {
-                "orders": [
-                    # sort for reproducibility
-                    {"column": "id", "direction": "asc"},
-                ]
+                "chunk_column_name": chunk_column_name,
+                "n_tokens_column_name": n_tokens_column_name,
+                "chunk_by_columns": chunk_by_columns,
+                **text_chunk,
             },
             "input": {"source": DEFAULT_INPUT_NAME},
         },
-        {
-            "verb": "zip",
-            "args": {
-                # Pack the document ids with the text
-                # So when we unpack the chunks, we can restore the document id
-                "columns": ["id", "text"],
-                "to": "text_with_ids",
-            },
-        },
-        {
-            "verb": "aggregate_override",
-            "args": {
-                "groupby": [*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
-                "aggregations": [
-                    {
-                        "column": "text_with_ids",
-                        "operation": "array_agg",
-                        "to": "texts",
-                    }
-                ],
-            },
-        },
-        {
-            "verb": "chunk",
-            "args": {"column": "texts", "to": "chunks", **config.get("text_chunk", {})},
-        },
-        {
-            "verb": "select",
-            "args": {
-                "columns": [*chunk_by_columns, "chunks"],
-            },
-        },
-        {
-            "verb": "unroll",
-            "args": {
-                "column": "chunks",
-            },
-        },
-        {
-            "verb": "rename",
-            "args": {
-                "columns": {
-                    "chunks": chunk_column_name,
-                }
-            },
-        },
-        {
-            "verb": "genid",
-            "args": {
-                # Generate a unique id for each chunk
-                "to": "chunk_id",
-                "method": "md5_hash",
-                "hash": [chunk_column_name],
-            },
-        },
-        {
-            "verb": "unzip",
-            "args": {
-                "column": chunk_column_name,
-                "to": ["document_ids", chunk_column_name, n_tokens_column_name],
-            },
-        },
-        {"verb": "copy", "args": {"column": "chunk_id", "to": "id"}},
-        {
-            # ELIMINATE EMPTY CHUNKS
-            "verb": "filter",
-            "args": {
-                "column": chunk_column_name,
-                "criteria": [
-                    {
-                        "type": "value",
-                        "operator": "is not empty",
-                    }
-                ],
-            },
-        },
     ]
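
With the collapse, build_steps emits a single workflow step instead of the ten-verb chain it replaces. A sketch under an invented config (the text_chunk values below are illustrative only):

from graphrag.index.workflows.v1.create_base_text_units import build_steps

config = {
    "chunk_column": "chunk",
    "chunk_by": [],
    "n_tokens_column": "n_tokens",
    "text_chunk": {"strategy": {"type": "tokens", "chunk_size": 1200}},
}

steps = build_steps(config)
# Expected shape of the single remaining step:
# {
#     "verb": "create_base_text_units",
#     "args": {
#         "chunk_column_name": "chunk",
#         "n_tokens_column_name": "n_tokens",
#         "chunk_by_columns": [],
#         "strategy": {"type": "tokens", "chunk_size": 1200},
#     },
#     "input": {"source": DEFAULT_INPUT_NAME},  # datashaper's default input table name
# }
assert len(steps) == 1 and steps[0]["verb"] == "create_base_text_units"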

graphrag/index/workflows/v1/subflows/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@
 """The Indexing Engine workflows -> subflows package root."""

 from .create_base_documents import create_base_documents
+from .create_base_text_units import create_base_text_units
 from .create_final_communities import create_final_communities
 from .create_final_nodes import create_final_nodes
 from .create_final_relationships_post_embedding import (
@@ -16,6 +17,7 @@

 __all__ = [
     "create_base_documents",
+    "create_base_text_units",
     "create_final_communities",
     "create_final_nodes",
     "create_final_relationships_post_embedding",

graphrag/index/workflows/v1/subflows/create_base_text_units.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to transform base text_units."""
+
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.index.verbs.genid import genid_df
+from graphrag.index.verbs.overrides.aggregate import aggregate_df
+from graphrag.index.verbs.text.chunk.text_chunk import chunk_df
+
+
+@verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
+def create_base_text_units(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    chunk_column_name: str,
+    n_tokens_column_name: str,
+    chunk_by_columns: list[str],
+    strategy: dict[str, Any] | None = None,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform base text_units."""
+    table = cast(pd.DataFrame, input.get_input())
+
+    sort = table.sort_values(by=["id"], ascending=[True])
+
+    sort["text_with_ids"] = list(
+        zip(*[sort[col] for col in ["id", "text"]], strict=True)
+    )
+
+    aggregated = aggregate_df(
+        sort,
+        groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
+        aggregations=[
+            {
+                "column": "text_with_ids",
+                "operation": "array_agg",
+                "to": "texts",
+            }
+        ],
+    )
+
+    chunked = chunk_df(
+        aggregated,
+        column="texts",
+        to="chunks",
+        callbacks=callbacks,
+        strategy=strategy,
+    )
+
+    chunked = cast(pd.DataFrame, chunked[[*chunk_by_columns, "chunks"]])
+    chunked = chunked.explode("chunks")
+    chunked.rename(
+        columns={
+            "chunks": chunk_column_name,
+        },
+        inplace=True,
+    )
+
+    chunked = genid_df(
+        chunked, to="chunk_id", method="md5_hash", hash=[chunk_column_name]
+    )
+
+    chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame(
+        chunked[chunk_column_name].tolist(), index=chunked.index
+    )
+    chunked["id"] = chunked["chunk_id"]
+
+    filtered = chunked[chunked[chunk_column_name].notna()].reset_index(drop=True)
+
+    return create_verb_result(
+        cast(
+            Table,
+            filtered,
+        )
+    )
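
The subflow above replaces the zip/aggregate/chunk/unroll/rename/genid/unzip/copy/filter verbs with direct pandas calls. A toy, self-contained sketch of the same DataFrame idioms (zip with ids, array_agg-style collection, explode, tuple unpacking into columns, notna filtering); the data and the stand-in chunker are invented:

import pandas as pd

docs = pd.DataFrame({"id": ["d1", "d2"], "text": ["alpha beta", "gamma delta"]})

# Pair each document id with its text (mirrors the old "zip" verb).
docs["text_with_ids"] = list(zip(docs["id"], docs["text"], strict=True))

# Collect all pairs into one cell (mirrors aggregate_override / array_agg).
agg = pd.DataFrame({"texts": [docs["text_with_ids"].tolist()]})

# Stand-in chunker: one (document_ids, chunk, n_tokens) tuple per pair.
agg["chunks"] = agg["texts"].apply(
    lambda pairs: [([doc_id], text, len(text.split())) for doc_id, text in pairs]
)

# explode + tuple unpacking replace the old "unroll" and "unzip" verbs.
exploded = agg[["chunks"]].explode("chunks").reset_index(drop=True)
exploded[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(
    exploded["chunks"].tolist(), index=exploded.index
)

# Drop empty chunks, like the notna() filter at the end of the subflow.
result = exploded[exploded["chunk"].notna()].reset_index(drop=True)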

tests/fixtures/min-csv/config.json

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@
             1,
             2000
         ],
-        "subworkflows": 11,
+        "subworkflows": 1,
         "max_runtime": 10
     },
     "create_base_extracted_entities": {

tests/fixtures/text/config.json

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@
             1,
             2000
         ],
-        "subworkflows": 11,
+        "subworkflows": 1,
        "max_runtime": 10
     },
     "create_base_extracted_entities": {

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag.index.workflows.v1.create_base_text_units import (
+    build_steps,
+    workflow_name,
+)
+
+from .util import (
+    compare_outputs,
+    get_config_for_workflow,
+    get_workflow_output,
+    load_expected,
+    load_input_tables,
+)
+
+
+async def test_create_base_text_units():
+    input_tables = load_input_tables(inputs=[])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+    # test data was created with 4o, so we need to match the encoding for chunks to be identical
+    config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base"
+
+    steps = build_steps(config)
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    compare_outputs(actual, expected)
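
The encoding pin in the test matters because token-based chunk boundaries depend on the tokenizer. A small tiktoken sketch (the sample sentence is made up) showing that cl100k_base and o200k_base tokenize the same text differently, so expected chunk fixtures only match when the encoding matches:

import tiktoken

text = "GraphRAG splits documents into token-sized chunks."
for name in ("cl100k_base", "o200k_base"):
    encoding = tiktoken.get_encoding(name)
    # Different encodings yield different token counts, hence different chunk boundaries.
    print(name, len(encoding.encode(text)))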
