feat: Support providing a configurable limit to squash job (#28054)
tkaemming authored Jan 29, 2025
1 parent bfeabae commit 376f261
Showing 2 changed files with 51 additions and 29 deletions.
41 changes: 22 additions & 19 deletions dags/person_overrides.py
@@ -36,7 +36,6 @@ def create_resource(self, context: dagster.InitResourceContext) -> ClickhouseCluster:
 @dataclass
 class PersonOverridesSnapshotTable:
     id: uuid.UUID
-    timestamp: str
 
     @property
     def name(self) -> str:
@@ -66,22 +65,25 @@ def exists(self, client: Client) -> None:
     def drop(self, client: Client) -> None:
         client.execute(f"DROP TABLE IF EXISTS {self.qualified_name} SYNC")
 
-    def populate(self, client: Client) -> None:
+    def populate(self, client: Client, timestamp: str, limit: int | None = None) -> None:
         # NOTE: this is theoretically subject to replication lag and accuracy of this result is not a guarantee
         # this could optionally support truncate as a config option if necessary to reset the table state, or
         # force an optimize after insertion to compact the table before dictionary insertion (if that's even needed)
         [[count]] = client.execute(f"SELECT count() FROM {self.qualified_name}")
         assert count == 0
 
+        limit_clause = f"LIMIT {limit}" if limit else ""
+
         client.execute(
             f"""
             INSERT INTO {self.qualified_name} (team_id, distinct_id, person_id, version)
             SELECT team_id, distinct_id, argMax(person_id, version), max(version)
             FROM {settings.CLICKHOUSE_DATABASE}.{PERSON_DISTINCT_ID_OVERRIDES_TABLE}
             WHERE _timestamp < %(timestamp)s
             GROUP BY team_id, distinct_id
+            {limit_clause}
             """,
-            {"timestamp": self.timestamp},
+            {"timestamp": timestamp},
         )
 
     def sync(self, client: Client) -> None:
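
The populate() change above is the heart of the commit: the cutoff timestamp moves from table state into a call argument, and an optional LIMIT is interpolated into the INSERT ... SELECT (since the clause is built with "if limit", a limit of 0 behaves like no limit at all). A minimal sketch of a direct single-host call, assuming a clickhouse_driver Client and hypothetical connection details:

    import uuid

    from clickhouse_driver import Client

    from dags.person_overrides import PersonOverridesSnapshotTable

    # Hypothetical walkthrough of the new signature; in practice the Dagster
    # ops below drive these calls across the cluster.
    client = Client(host="localhost")  # hypothetical connection details
    table = PersonOverridesSnapshotTable(id=uuid.uuid4())
    table.create(client)  # the table no longer carries a timestamp of its own
    # Snapshot at most 100,000 override rows older than the cutoff;
    # limit=None (the default) preserves the previous unbounded behavior.
    table.populate(client, timestamp="2025-01-01T00:00:00", limit=100_000)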
@@ -204,7 +206,18 @@ def overrides_delete_mutation_runner(self) -> MutationRunner:
 # Snapshot Table Management
 
 
-class SnapshotTableConfig(dagster.Config):
+@dagster.op
+def create_snapshot_table(
+    context: dagster.OpExecutionContext,
+    cluster: dagster.ResourceParam[ClickhouseCluster],
+) -> PersonOverridesSnapshotTable:
+    """Create the snapshot table on all hosts in the cluster."""
+    table = PersonOverridesSnapshotTable(id=uuid.UUID(context.run.run_id))
+    cluster.map_all_hosts(table.create).result()
+    return table
+
+
+class PopulateSnapshotTableConfig(dagster.Config):
     """
     Configuration for creating and populating the initial snapshot table.
     """
@@ -215,30 +228,20 @@ class SnapshotTableConfig(dagster.Config):
"the past that there is no reasonable likelihood that events or overrides prior to this time have not yet been "
"written to the database and replicated to all hosts in the cluster."
)


@dagster.op
def create_snapshot_table(
context: dagster.OpExecutionContext,
cluster: dagster.ResourceParam[ClickhouseCluster],
config: SnapshotTableConfig,
) -> PersonOverridesSnapshotTable:
"""Create the snapshot table on all hosts in the cluster."""
table = PersonOverridesSnapshotTable(
id=uuid.UUID(context.run.run_id),
timestamp=config.timestamp,
limit: int | None = pydantic.Field(
description="The number of rows to include in the snapshot. If provided, this can be used to limit the total "
"amount of memory consumed by the squash process during execution."
)
cluster.map_all_hosts(table.create).result()
return table


@dagster.op
def populate_snapshot_table(
cluster: dagster.ResourceParam[ClickhouseCluster],
table: PersonOverridesSnapshotTable,
config: PopulateSnapshotTableConfig,
) -> PersonOverridesSnapshotTable:
"""Fill the snapshot data with the selected overrides based on the configuration timestamp."""
cluster.any_host(table.populate).result()
cluster.any_host(partial(table.populate, timestamp=config.timestamp, limit=config.limit)).result()
return table


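With the config split as above, create_snapshot_table no longer takes any configuration, and the limit travels with the timestamp in PopulateSnapshotTableConfig. A hedged sketch of launching the job with a bounded snapshot; cluster stands in for a ClickhouseCluster resource (e.g. from get_cluster), and the timestamp and limit values are hypothetical:

    import dagster

    from dags.person_overrides import (
        PopulateSnapshotTableConfig,
        populate_snapshot_table,
        squash_person_overrides,
    )

    # 'cluster' is assumed to be a ClickhouseCluster resource obtained elsewhere.
    result = squash_person_overrides.execute_in_process(
        run_config=dagster.RunConfig(
            {
                populate_snapshot_table.name: PopulateSnapshotTableConfig(
                    timestamp="2025-01-01T00:00:00",  # hypothetical cutoff
                    limit=1_000_000,  # hypothetical cap; omit to snapshot everything
                )
            }
        ),
        resources={"cluster": cluster},
    )
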
39 changes: 29 additions & 10 deletions dags_test/test_person_overrides.py
@@ -9,8 +9,8 @@
 from dags.person_overrides import (
     PersonOverridesSnapshotDictionary,
     PersonOverridesSnapshotTable,
-    SnapshotTableConfig,
-    create_snapshot_table,
+    PopulateSnapshotTableConfig,
+    populate_snapshot_table,
     squash_person_overrides,
 )
 from posthog.clickhouse.cluster import ClickhouseCluster, get_cluster
@@ -76,23 +76,42 @@ def get_distinct_ids_with_overrides(client: Client) -> set[str]:
     }
     assert cluster.any_host(get_distinct_ids_with_overrides).result() == {"c", "d", "e", "z"}
 
-    result = squash_person_overrides.execute_in_process(
+    # run with limit
+    limited_run_result = squash_person_overrides.execute_in_process(
         run_config=dagster.RunConfig(
-            {create_snapshot_table.name: SnapshotTableConfig(timestamp=timestamp.isoformat())}
+            {populate_snapshot_table.name: PopulateSnapshotTableConfig(timestamp=timestamp.isoformat(), limit=2)}
         ),
         resources={"cluster": cluster},
     )
 
+    # ensure we cleaned up after ourselves
+    table = PersonOverridesSnapshotTable(UUID(limited_run_result.dagster_run.run_id))
+    dictionary = PersonOverridesSnapshotDictionary(table)
+    assert not any(cluster.map_all_hosts(table.exists).result().values())
+    assert not any(cluster.map_all_hosts(dictionary.exists).result().values())
+
+    remaining_overrides = cluster.any_host(get_distinct_ids_with_overrides).result()
+    assert len(remaining_overrides) == 2  # one candidate discarded due to limit, one out of timestamp range
+    assert "z" in remaining_overrides  # outside of timestamp range
+
+    # run without limit to handle the remaining item(s)
+    full_run_result = squash_person_overrides.execute_in_process(
+        run_config=dagster.RunConfig(
+            {populate_snapshot_table.name: PopulateSnapshotTableConfig(timestamp=timestamp.isoformat())}
+        ),
+        resources={"cluster": cluster},
+    )
+
+    # ensure we cleaned up after ourselves again
+    table = PersonOverridesSnapshotTable(UUID(full_run_result.dagster_run.run_id))
+    dictionary = PersonOverridesSnapshotDictionary(table)
+    assert not any(cluster.map_all_hosts(table.exists).result().values())
+    assert not any(cluster.map_all_hosts(dictionary.exists).result().values())
+
     # check postconditions
     assert cluster.any_host(get_distinct_ids_on_events_by_person).result() == {
         UUID(int=0): {"a", "c"},
         UUID(int=1): {"b", "d", "e"},
         UUID(int=100): {"z"},
     }
     assert cluster.any_host(get_distinct_ids_with_overrides).result() == {"z"}
-
-    # ensure we cleaned up after ourselves
-    table = PersonOverridesSnapshotTable(UUID(result.dagster_run.run_id), timestamp)
-    dictionary = PersonOverridesSnapshotDictionary(table)
-    assert not any(cluster.map_all_hosts(table.exists).result().values())
-    assert not any(cluster.map_all_hosts(dictionary.exists).result().values())

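The test above demonstrates the intended operational pattern: a limited run squashes only part of the backlog and leaves the rest untouched, so repeated bounded runs converge on the same end state as a single unbounded one. A hypothetical driver for that pattern (squash_in_batches and count_pending are illustrative names, not part of this commit):

    import dagster

    from dags.person_overrides import (
        PopulateSnapshotTableConfig,
        populate_snapshot_table,
        squash_person_overrides,
    )

    def squash_in_batches(cluster, cutoff: str, count_pending, batch_size: int = 1_000_000) -> None:
        # count_pending: caller-supplied callable that returns how many override
        # rows older than the cutoff remain (e.g. a count() over the overrides table).
        while count_pending() > 0:
            squash_person_overrides.execute_in_process(
                run_config=dagster.RunConfig(
                    {populate_snapshot_table.name: PopulateSnapshotTableConfig(timestamp=cutoff, limit=batch_size)}
                ),
                resources={"cluster": cluster},
            )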