Skip to content

Commit

Permalink
Improve the cluster titling and fix a few client-side bugs (#1058)
Browse files Browse the repository at this point in the history
- Use sort_by on cluster_id so we can avoid any thread locks and get better
resumability
- Improve the prompts
- Fix a few client-side and server-side bugs related to sorting
- Titling is now 2x faster -- takes 10 min to title 1M HN comments,
and 13 min to title OpenHermes-2.5
  • Loading branch information
dsmilkov authored Jan 12, 2024
1 parent 6b315be commit 54f608b
Show file tree
Hide file tree
Showing 10 changed files with 503 additions and 330 deletions.
170 changes: 90 additions & 80 deletions lilac/data/cluster_test.py

Large diffs are not rendered by default.

550 changes: 315 additions & 235 deletions lilac/data/clustering.py

Large diffs are not rendered by default.

14 changes: 8 additions & 6 deletions lilac/data/dataset_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
open_file,
)
from . import dataset
from .clustering import cluster, summarize_request
from .clustering import cluster_impl, summarize_request
from .dataset import (
BINARY_OPS,
DELETED_LABEL_NAME,
Expand Down Expand Up @@ -1939,9 +1939,11 @@ def select_rows(

order_query = ''
if sort_sql_before_udf:
order_query = (
f'ORDER BY {", ".join(sort_sql_before_udf)} ' f'{cast(SortOrder, sort_order).value}'
)
# TODO(smilkov): Make the sort order also a list to align with the sort_by list.
sort_with_order = [
f'{sql} {cast(SortOrder, sort_order).value}' for sql in sort_sql_before_udf
]
order_query = f'ORDER BY {", ".join(sort_with_order)}'

limit_query = ''
if limit:
Expand Down Expand Up @@ -2882,7 +2884,7 @@ def map(
with open_file(map_manifest_filepath, 'w') as f:
f.write(map_manifest.model_dump_json(exclude_none=True, indent=2))

log(f'Wrote map output to {parquet_dir}')
log(f'Wrote map output to {parquet_filename}')

# Promote any new string columns as media fields if the length is above a threshold.
for path, field in map_schema.leafs.items():
Expand All @@ -2905,7 +2907,7 @@ def cluster(
task_id: Optional[TaskId] = None,
) -> None:
topic_fn = topic_fn or summarize_request
return cluster(
return cluster_impl(
self, input, output_path, min_cluster_size, topic_fn, overwrite, remote, task_id=task_id
)

Expand Down
2 changes: 1 addition & 1 deletion lilac/data/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def flatten_keys(


def sparse_to_dense_compute(
sparse_input: Iterable[Optional[Tin]], func: Callable[[Iterable[Tin]], Iterator[Tout]]
sparse_input: Iterator[Optional[Tin]], func: Callable[[Iterator[Tin]], Iterator[Tout]]
) -> Iterator[Optional[Tout]]:
"""Densifies the input before calling the provided `func` and sparsifies the output."""
sparse_input = iter(sparse_input)
Expand Down
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,7 @@ follow_imports = skip
[mypy-modal.*]
ignore_missing_imports = True
follow_imports = skip

[mypy-hdbscan.*]
ignore_missing_imports = True
follow_imports = skip
80 changes: 74 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ google-generativeai = { version = "^0.1.0", optional = true } # PaLM / MakerSuit
google-cloud-aiplatform = { version = "^1.35.0", optional = true, extras = [
"grpcio",
] } # PaLM via GCP
openai = { version = "^1.3.7", optional = true }
openai = { version = "^1.7.1", optional = true }
sentence-transformers = { version = "^2.2.2", optional = true } # SBERT on-device embeddings.

# Gmail source.
Expand Down Expand Up @@ -85,6 +85,7 @@ llama-hub = { version = "^0.0.50", optional = true, python = ">=3.9,<3.12" }

# For HDBScan to reduce dimensionality before running clustering.
umap-learn = { version = "^0.5.4", optional = true }
hdbscan = { version = "^0.8.33", optional = true }

[tool.poetry.extras]
all = [
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ filterwarnings =
ignore::DeprecationWarning:huggingface_hub.*:
ignore::DeprecationWarning:pydantic_core.*:
ignore:PydanticDeprecatedSince20:DeprecationWarning
ignore::DeprecationWarning:hdbscan.*:
markers =
largedownload: Marks a test as having a large download. Wont run on github. (deselect with '-m "not largedownload"')
asyncio_mode = auto
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
$: rowsQuery = querySelectRows(
$store.namespace,
$store.datasetName,
{...selectOptions, columns: [ROWID], limit},
{
...selectOptions,
columns: [ROWID],
limit,
// Sort by ROWID on top of any other sort_by option to ensure that the result order is stable.
sort_by: [...(selectOptions.sort_by || []), ROWID]
},
$selectRowsSchema.data?.schema
);
Expand Down
1 change: 1 addition & 0 deletions web/blueprint/src/lib/stores/datasetViewStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ export function getSelectRowsOptions(
fieldName == CLUSTER_TITLE_FIELD ? CLUSTER_MEMBERSHIP_PROB : CATEGORY_MEMBERSHIP_PROB
);
options.sort_by = [membershipProbPath];
options.sort_order = 'DESC';
}
options.searches = options.searches || [];
options.searches.push({
Expand Down

0 comments on commit 54f608b

Please sign in to comment.