Fast table queries with interpolation search (huggingface#2122)
* add interpolation search

* update dataset and formatting

* update test_formatting

* test interpolation search

* docstrings

* add benchmark

* update benchmarks

* add indexed table test
lhoestq authored Apr 6, 2021
1 parent 76c2a61 commit ae8b940
Showing 10 changed files with 327 additions and 106 deletions.
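
Note: the speedup in this commit comes from how a row index is located inside a chunked Arrow table. Instead of scanning chunks linearly, the lookup can use interpolation search over the cumulative chunk lengths, which takes only a handful of probes when chunk sizes are roughly uniform. The snippet below is a minimal standalone sketch of that idea; the function name, offsets list and example sizes are illustrative only, not the library's actual implementation (which presumably lives in the table module, one of the changed files not shown in this excerpt).

# Minimal standalone sketch (not the library's actual code) of interpolation
# search over the cumulative chunk lengths of a chunked table. Assumes the
# offsets are sorted, start at 0, and that chunks are non-empty.
from typing import List


def interpolation_search(offsets: List[int], i: int) -> int:
    """Return k such that offsets[k] <= i < offsets[k + 1]."""
    low, high = 0, len(offsets) - 1
    while low < high:
        # guess the position proportionally to where i falls between the bounds
        guess = low + (i - offsets[low]) * (high - low) // (offsets[high] - offsets[low])
        guess = min(max(guess, low), high - 1)
        if offsets[guess] <= i < offsets[guess + 1]:
            return guess
        if i < offsets[guess]:
            high = guess
        else:
            low = guess + 1
    raise IndexError(f"Index {i} is out of range")


# Example: a table made of 4 chunks of 10_000 rows each
offsets = [0, 10_000, 20_000, 30_000, 40_000]  # cumulative chunk lengths
print(interpolation_search(offsets, 25_123))  # -> 2, i.e. the third chunk
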
8 changes: 4 additions & 4 deletions .github/workflows/benchmarks.yaml
@@ -16,8 +16,8 @@ jobs:
           pip install setuptools wheel
           pip install -e .[benchmarks]
-          # pyarrow==0.17.1
-          pip install pyarrow==0.17.1
+          # pyarrow==1.0.0
+          pip install pyarrow==1.0.0
           dvc repro --force
@@ -26,7 +26,7 @@ jobs:
           python ./benchmarks/format.py report.json report.md
-          echo "<details>\n<summary>Show benchmarks</summary>\n\nPyArrow==0.17.1\n" > final_report.md
+          echo "<details>\n<summary>Show benchmarks</summary>\n\nPyArrow==1.0.0\n" > final_report.md
           cat report.md >> final_report.md
           # pyarrow
@@ -39,7 +39,7 @@ jobs:
           python ./benchmarks/format.py report.json report.md
-          echo "\nPyArrow==1.0\n" >> final_report.md
+          echo "\nPyArrow==latest\n" >> final_report.md
           cat report.md >> final_report.md
           echo "\n</details>" >> final_report.md
78 changes: 78 additions & 0 deletions benchmarks/benchmark_getitem_100B.py
@@ -0,0 +1,78 @@
import json
import os
from dataclasses import dataclass

import numpy as np
import pyarrow as pa
from utils import get_duration

import datasets


SPEED_TEST_N_EXAMPLES = 100_000_000_000
SPEED_TEST_CHUNK_SIZE = 10_000

RESULTS_BASEPATH, RESULTS_FILENAME = os.path.split(__file__)
RESULTS_FILE_PATH = os.path.join(RESULTS_BASEPATH, "results", RESULTS_FILENAME.replace(".py", ".json"))


def generate_100B_dataset(num_examples: int, chunk_size: int) -> datasets.Dataset:
    table = pa.Table.from_pydict({"col": [0] * chunk_size})
    table = pa.concat_tables([table] * (num_examples // chunk_size))
    return datasets.Dataset(table, fingerprint="table_100B")


@dataclass
class RandIter:
    low: int
    high: int
    size: int
    seed: int

    def __post_init__(self):
        rng = np.random.default_rng(self.seed)
        self._sampled_values = rng.integers(low=self.low, high=self.high, size=self.size).tolist()

    def __iter__(self):
        return iter(self._sampled_values)

    def __len__(self):
        return self.size


@get_duration
def get_first_row(dataset: datasets.Dataset):
    _ = dataset[0]


@get_duration
def get_last_row(dataset: datasets.Dataset):
    _ = dataset[-1]


@get_duration
def get_batch_of_1024_rows(dataset: datasets.Dataset):
    _ = dataset[range(len(dataset) // 2, len(dataset) // 2 + 1024)]


@get_duration
def get_batch_of_1024_random_rows(dataset: datasets.Dataset):
    _ = dataset[RandIter(0, len(dataset), 1024, seed=42)]


def benchmark_table_100B():
    times = {"num examples": SPEED_TEST_N_EXAMPLES}
    functions = (get_first_row, get_last_row, get_batch_of_1024_rows, get_batch_of_1024_random_rows)
    print("generating dataset")
    dataset = generate_100B_dataset(num_examples=SPEED_TEST_N_EXAMPLES, chunk_size=SPEED_TEST_CHUNK_SIZE)
    print("Functions")
    for func in functions:
        print(func.__name__)
        times[func.__name__] = func(dataset)

    with open(RESULTS_FILE_PATH, "wb") as f:
        f.write(json.dumps(times).encode("utf-8"))


if __name__ == "__main__":  # useful to run the profiler
    benchmark_table_100B()
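
Note: `get_duration` is imported from the local benchmarks `utils` module, which is not part of this diff. Below is a plausible sketch of such a timing decorator; it is an assumption about its behavior, not the repository's actual helper. The relevant details are that it returns the elapsed time and preserves `__name__`, since the benchmark stores results under `func.__name__`.

# Hypothetical sketch of the `get_duration` decorator imported from the local
# benchmarks utils module (not part of this diff): run the wrapped function
# once and return the elapsed wall-clock time in seconds.
import functools
import timeit


def get_duration(func):
    @functools.wraps(func)  # keeps func.__name__, which the benchmark uses as the result key
    def wrapper(*args, **kwargs):
        start = timeit.default_timer()
        func(*args, **kwargs)
        return timeit.default_timer() - start
    return wrapper
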
1 change: 1 addition & 0 deletions benchmarks/results/benchmark_getitem_100B.json
@@ -0,0 +1 @@
{"num examples": 100000000000, "get_first_row": 0.00019991099999927542, "get_last_row": 5.4411000000698095e-05, "get_batch_of_1024_rows": 0.0004897069999998394, "get_batch_of_1024_random_rows": 0.01800621099999944}
8 changes: 8 additions & 0 deletions dvc.yaml
@@ -30,3 +30,11 @@ stages:
     metrics:
     - ./benchmarks/results/benchmark_iterating.json:
         cache: false
+
+  benchmark_getitem_100B:
+    cmd: python ./benchmarks/benchmark_getitem_100B.py
+    deps:
+    - ./benchmarks/benchmark_getitem_100B.py
+    metrics:
+    - ./benchmarks/results/benchmark_getitem_100B.json:
+        cache: false
30 changes: 20 additions & 10 deletions src/datasets/arrow_dataset.py
Expand Up @@ -191,6 +191,18 @@ def wrapper(*args, **kwargs):
return wrapper


def _check_table(table) -> Table:
"""We check the table type to make sure it's an instance of :class:`datasets.table.Table`"""
if isinstance(table, pa.Table):
# for a pyarrow table, we can just consider it as a in-memory table
# this is here for backward compatibility
return InMemoryTable(table)
elif isinstance(table, Table):
return table
else:
raise TypeError(f"Expected a pyarrow.Table or a datasets.table.Table object, but got {table}.")


class Dataset(DatasetInfoMixin, IndexableMixin):
"""A Dataset backed by an Arrow table."""

Expand All @@ -206,8 +218,8 @@ def __init__(
DatasetInfoMixin.__init__(self, info=info, split=split)
IndexableMixin.__init__(self)

self._data: Table = arrow_table
self._indices: Optional[Table] = indices_table
self._data: Table = _check_table(arrow_table)
self._indices: Optional[Table] = _check_table(indices_table) if indices_table is not None else None

self._format_type: Optional[str] = None
self._format_kwargs: dict = {}
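
Note: `_check_table` is what keeps `__init__` backward compatible: a raw `pyarrow.Table` is silently wrapped into an `InMemoryTable`, while `datasets.table.Table` instances pass through untouched. A rough illustration (the import paths are assumptions based on the identifiers used in this commit):

# Rough illustration of the backward compatibility added by _check_table
# (import paths assumed from the identifiers used in this commit).
import pyarrow as pa

import datasets
from datasets.table import InMemoryTable, Table

pa_table = pa.Table.from_pydict({"col": [1, 2, 3]})

ds = datasets.Dataset(pa_table)  # a raw pyarrow.Table is wrapped into an InMemoryTable
assert isinstance(ds._data, Table)

ds2 = datasets.Dataset(InMemoryTable(pa_table))  # a Table instance passes through unchanged
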
Expand Down Expand Up @@ -1157,9 +1169,7 @@ def _getitem(
"""
format_kwargs = format_kwargs if format_kwargs is not None else {}
formatter = get_formatter(format_type, **format_kwargs)
pa_subtable = query_table(
self._data, key, indices=self._indices.column(0) if self._indices is not None else None
)
pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
formatted_output = format_table(
pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
)
@@ -1870,7 +1880,7 @@ def select(
         if self._indices is not None:
             indices_array = self._indices.column(0).take(indices_array)

-        indices_table = InMemoryTable.from_arrays([indices_array], names=["indices"])
+        indices_table = pa.Table.from_arrays([indices_array], names=["indices"])

         with writer:
             try:
@@ -2427,15 +2437,15 @@ def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Un
             return query_table(
                 table=self._data,
                 key=slice(0, len(self)),
-                indices=self._indices.column(0) if self._indices is not None else None,
+                indices=self._indices if self._indices is not None else None,
             ).to_pydict()
         else:
             batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
             return (
                 query_table(
                     table=self._data,
                     key=slice(offset, offset + batch_size),
-                    indices=self._indices.column(0) if self._indices is not None else None,
+                    indices=self._indices if self._indices is not None else None,
                 ).to_pydict()
                 for offset in range(0, len(self), batch_size)
             )
@@ -2458,15 +2468,15 @@ def to_pandas(
             return query_table(
                 table=self._data,
                 key=slice(0, len(self)),
-                indices=self._indices.column(0) if self._indices is not None else None,
+                indices=self._indices if self._indices is not None else None,
             ).to_pandas()
         else:
             batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
             return (
                 query_table(
                     table=self._data,
                     key=slice(offset, offset + batch_size),
-                    indices=self._indices.column(0) if self._indices is not None else None,
+                    indices=self._indices if self._indices is not None else None,
                 ).to_pandas()
                 for offset in range(0, len(self), batch_size)
             )
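
Note: when `batched=True`, both exports above return a generator that calls `query_table` on one slice at a time, so the full table is never materialized at once. A hedged usage sketch, assuming `to_pandas` exposes the same `batched`/`batch_size` parameters as `to_dict` above (`dataset` and `process` are placeholders):

# Placeholder usage sketch: iterate over DataFrame batches instead of
# converting the whole dataset in a single call.
for df in dataset.to_pandas(batched=True, batch_size=10_000):
    process(df)
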
51 changes: 26 additions & 25 deletions src/datasets/formatting/formatting.py
@@ -42,51 +42,55 @@ def _raise_bad_key_type(key: Any):


 def _query_table_with_indices_mapping(
-    pa_table: pa.Table, key: Union[int, slice, range, str, Iterable], indices: pa.lib.UInt64Array
+    table: Table, key: Union[int, slice, range, str, Iterable], indices: Table
 ) -> pa.Table:
     """
     Query a pyarrow Table to extract the subtable that correspond to the given key.
     The :obj:`indices` parameter corresponds to the indices mapping in case we cant to take into
     account a shuffling or an indices selection for example.
+    The indices table must contain one column named "indices" of type uint64.
     """
     if isinstance(key, int):
-        return _query_table(pa_table, indices[key].as_py())
+        key = indices.fast_slice(key % indices.num_rows, 1).column(0)[0].as_py()
+        return _query_table(table, key)
     if isinstance(key, slice):
-        key = range(*key.indices(pa_table.num_rows))
+        key = range(*key.indices(table.num_rows))
     if isinstance(key, range):
-        if _is_range_contiguous(key):
-            return _query_table(pa_table, [i.as_py() for i in indices.slice(key.start, key.stop - key.start)])
+        if _is_range_contiguous(key) and key.start >= 0:
+            return _query_table(
+                table, [i.as_py() for i in indices.fast_slice(key.start, key.stop - key.start).column(0)]
+            )
         else:
             pass  # treat as an iterable
     if isinstance(key, str):
-        pa_table = _query_table(pa_table, key)
-        return _query_table(pa_table, indices.to_pylist())
+        table = table.drop([column for column in table.column_names if column != key])
+        return _query_table(table, indices.column(0).to_pylist())
     if isinstance(key, Iterable):
-        return _query_table(pa_table, [indices[i].as_py() for i in key])
+        return _query_table(table, [indices.fast_slice(i, 1).column(0)[0].as_py() for i in key])

     _raise_bad_key_type(key)


-def _query_table(pa_table: pa.Table, key: Union[int, slice, range, str, Iterable]) -> pa.Table:
+def _query_table(table: Table, key: Union[int, slice, range, str, Iterable]) -> pa.Table:
     """
     Query a pyarrow Table to extract the subtable that correspond to the given key.
     """
     if isinstance(key, int):
-        return pa_table.slice(key % pa_table.num_rows, 1)
+        return table.fast_slice(key % table.num_rows, 1)
     if isinstance(key, slice):
-        key = range(*key.indices(pa_table.num_rows))
+        key = range(*key.indices(table.num_rows))
     if isinstance(key, range):
         if _is_range_contiguous(key) and key.start >= 0:
-            return pa_table.slice(key.start, key.stop - key.start)
+            return table.fast_slice(key.start, key.stop - key.start)
         else:
             pass  # treat as an iterable
     if isinstance(key, str):
-        return pa_table.drop([column for column in pa_table.column_names if column != key])
+        return table.table.drop([column for column in table.column_names if column != key])
     if isinstance(key, Iterable):
         if len(key) == 0:
-            return pa_table.slice(0, 0)
+            return table.table.slice(0, 0)
         # don't use pyarrow.Table.take even for pyarrow >=1.0 (see https://issues.apache.org/jira/browse/ARROW-9773)
-        return pa.concat_tables(pa_table.slice(int(i) % pa_table.num_rows, 1) for i in key)
+        return pa.concat_tables(table.fast_slice(int(i) % table.num_rows, 1) for i in key)

     _raise_bad_key_type(key)
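
Note: the indices mapping above is how a shuffled or filtered dataset stays cheap: position k in the dataset is first translated to a physical row of the data table via the "indices" column, and only then is that row sliced out. A rough standalone illustration in plain pyarrow (not the library's code):

# Rough illustration in plain pyarrow (not the library's code) of the indices
# mapping: position 1 of the dataset maps to row 0 of the underlying data table.
import pyarrow as pa

data = pa.Table.from_pydict({"col": ["a", "b", "c", "d"]})
# e.g. after a shuffle or a select, positions 0..2 point at rows 2, 0 and 3 of `data`
indices = pa.Table.from_pydict({"indices": pa.array([2, 0, 3], type=pa.uint64())})

key = 1  # position in the dataset
row = indices.column(0)[key].as_py()  # -> 0, the physical row in the data table
print(data.slice(row, 1).to_pydict())  # {'col': ['a']}
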

@@ -306,7 +310,7 @@ def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
 def query_table(
     table: Table,
     key: Union[int, slice, range, str, Iterable],
-    indices: Optional[pa.lib.UInt64Array] = None,
+    indices: Optional[Table] = None,
 ) -> pa.Table:
     """
     Query a Table to extract the subtable that correspond to the given key.
@@ -319,30 +323,27 @@ def query_table(
             - a range(i, j, k): the subtable containing the rows that correspond to this range
             - a string c: the subtable containing all the rows but only the column c
             - an iterable l: the subtable that is the concatenation of all the i-th rows for all i in the iterable
-        indices (Optional ``pyarrow.lib.UInt64Array``): If not None, it is used to re-map the given key to the table rows.
+        indices (Optional ``datasets.table.Table``): If not None, it is used to re-map the given key to the table rows.
+            The indices table must contain one column named "indices" of type uint64.
             This is used in case of shuffling or rows selection.
     Returns:
         ``pyarrow.Table``: the result of the query on the input table
     """
-    if isinstance(table, Table):
-        pa_table = table.table
-    else:
-        pa_table = table
     # Check if key is valid
     if not isinstance(key, (int, slice, range, str, Iterable)):
         _raise_bad_key_type(key)
     if isinstance(key, str):
-        _check_valid_column_key(key, pa_table.column_names)
+        _check_valid_column_key(key, table.column_names)
     else:
-        size = len(indices) if indices is not None else pa_table.num_rows
+        size = indices.num_rows if indices is not None else table.num_rows
         _check_valid_index_key(key, size)
     # Query the main table
     if indices is None:
-        pa_subtable = _query_table(pa_table, key)
+        pa_subtable = _query_table(table, key)
     else:
-        pa_subtable = _query_table_with_indices_mapping(pa_table, key, indices=indices)
+        pa_subtable = _query_table_with_indices_mapping(table, key, indices=indices)
     return pa_subtable


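Note: a short usage sketch of `query_table` with the key types listed in the docstring above; the imports are assumptions based on the module paths touched by this commit.

# Usage sketch of query_table with the key types listed in the docstring
# (imports assumed from the module paths touched by this commit).
import pyarrow as pa

from datasets.formatting.formatting import query_table
from datasets.table import InMemoryTable

table = InMemoryTable(pa.Table.from_pydict({"a": [1, 2, 3, 4], "b": ["w", "x", "y", "z"]}))

query_table(table, 0)            # the first row
query_table(table, slice(1, 3))  # rows 1 and 2
query_table(table, "b")          # all rows, only column "b"
query_table(table, [3, 0])       # rows 3 and 0, concatenated in that order
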
2 changes: 1 addition & 1 deletion src/datasets/io/csv.py
@@ -89,7 +89,7 @@ def _write(
             batch = query_table(
                 table=self.dataset._data,
                 key=slice(offset, offset + batch_size),
-                indices=self.dataset._indices.column(0) if self.dataset._indices is not None else None,
+                indices=self.dataset._indices if self.dataset._indices is not None else None,
             )
             csv_str = batch.to_pandas().to_csv(
                 path_or_buf=None, header=header if (offset == 0) else False, encoding=encoding, **to_csv_kwargs