Skip to content

Commit

Permalink
Power-law histogram generations
Browse files Browse the repository at this point in the history
  • Loading branch information
brilee committed Feb 5, 2024
1 parent 95e03c1 commit 65424cf
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 23 deletions.
9 changes: 5 additions & 4 deletions lilac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import enum
import pathlib
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import datetime, date
from typing import Any, Callable, Iterable, Iterator, Literal, Optional, Sequence, Union

import numpy as np
Expand Down Expand Up @@ -112,9 +112,10 @@ class StatsResult(BaseModel):
# The approximate number of distinct leaf values.
approx_count_distinct: int

# Defined for ordinal features.
min_val: Optional[Union[float, datetime]] = None
max_val: Optional[Union[float, datetime]] = None
# Defined for numeric features.
min_val: Optional[Union[float, date, datetime]] = None
max_val: Optional[Union[float, date, datetime]] = None
value_samples: Optional[list[float]] = None # Used for approximating histogram bins

# Defined for text features.
avg_text_length: Optional[float] = None
Expand Down
53 changes: 34 additions & 19 deletions lilac/data/dataset_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from datasets import Dataset as HuggingFaceDataset
from pandas.api.types import is_object_dtype
from pydantic import BaseModel, SerializeAsAny, field_validator
from sklearn.preprocessing import PowerTransformer
from typing_extensions import override

from ..auth import UserInfo
Expand Down Expand Up @@ -182,7 +183,7 @@

SQLITE_LABEL_COLNAME = 'label'
SQLITE_CREATED_COLNAME = 'created'
NUM_AUTO_BINS = 15
MAX_AUTO_BINS = 15

BINARY_OP_TO_SQL: dict[BinaryOp, str] = {
'equals': '=',
Expand Down Expand Up @@ -1820,12 +1821,19 @@ def stats(self, leaf_path: Path, include_deleted: bool = False) -> StatsResult:
# Compute min/max values for ordinal leafs, without sampling the data.
if is_ordinal(leaf.dtype):
min_max_query = f"""
SELECT MIN(val) AS minVal, MAX(val) AS maxVal
SELECT MIN(val) AS minVal, MAX(val) AS maxVal,
FROM (SELECT {inner_select} as val FROM t {where_clause})
{'WHERE NOT isnan(val)' if is_float(leaf.dtype) else ''}
"""
row = self._query(min_max_query)[0]
result.min_val, result.max_val = row
sample_query = f"""
SELECT ARRAY_AGG(val)
FROM (SELECT {inner_select} as val FROM t {where_clause})
{'WHERE NOT isnan(val)' if is_float(leaf.dtype) else ''}
USING SAMPLE 50;
"""
result.value_samples = list(self._query(sample_query)[0][0])

return result

Expand Down Expand Up @@ -1873,7 +1881,7 @@ def select_groups(
if not leaf.categorical and (leaf_is_float or leaf_is_integer):
if named_bins is None:
# Auto-bin.
named_bins = _auto_bins(stats, NUM_AUTO_BINS)
named_bins = _auto_bins(stats)

sql_bounds = []
for label, start, end in named_bins:
Expand Down Expand Up @@ -3914,7 +3922,7 @@ def _normalize_bins(bins: Optional[Union[Sequence[Bin], Sequence[float]]]) -> Op
return named_bins


def _auto_bins(stats: StatsResult, num_bins: int) -> list[Bin]:
def _auto_bins(stats: StatsResult) -> list[Bin]:
if stats.min_val is None or stats.max_val is None:
return [('0', None, None)]

Expand All @@ -3923,22 +3931,29 @@ def _auto_bins(stats: StatsResult, num_bins: int) -> list[Bin]:
const_val = cast(float, stats.min_val)
return [('0', const_val, None)]

min_val = cast(float, stats.min_val)
max_val = cast(float, stats.max_val)
value_range = max_val - min_val
# Select a round ndigits as a function of the value range. We offset it by 2 to allow for some
# decimal places as a function of the range.
round_ndigits = -1 * round(math.log10(value_range)) + 3
bin_width = (max_val - min_val) / num_bins
bins: list = []
last_end_val = None
for i in range(num_bins):
end = None if i == num_bins - 1 else min_val + (i + 1) * bin_width
if end:
end = round(end, round_ndigits)
start = last_end_val
last_end_val = end
is_integer = stats.value_samples and all(isinstance(val, int) for val in stats.value_samples)
def _round(value: float) -> float:
# Select a round ndigits as a function of the value range. We offset it by 2 to allow for some
# decimal places as a function of the range.
if not value:
return round(value)
round_ndigits = -1 * round(math.log10(abs(value))) + 1
if is_integer:
round_ndigits = min(round_ndigits, 0)
return round(value, round_ndigits)

num_bins = min(int(np.log2(stats.total_count)), MAX_AUTO_BINS)
transformer = PowerTransformer().fit(np.array(stats.value_samples).reshape(-1, 1))
# [-2, 2], assuming normal distribution, should cover central 95% of the data.
buckets = transformer.inverse_transform(np.linspace(-2, 2, num_bins).reshape(-1, 1)).ravel()
# Sometimes the autogenerated buckets round to the same value.
# Sometimes PowerTransformer returns NaN for some unusually shaped distributions.
buckets = sorted(set(_round(val) for val in buckets if not np.isnan(val)))
buckets = [None] + buckets + [None]
bins = []
for i, (start, end) in enumerate(zip(buckets[:-1], buckets[1:])):
bins.append((str(i), start, end))

return bins


Expand Down
1 change: 1 addition & 0 deletions web/lib/fastapi_client/models/StatsResult.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export type StatsResult = {
approx_count_distinct: number;
min_val?: (number | string | null);
max_val?: (number | string | null);
value_samples?: (Array<number> | null);
avg_text_length?: (number | null);
};

0 comments on commit 65424cf

Please sign in to comment.