Skip to content

Commit

Permalink
fix: fix mutation types in delete DAG (#28110)
Browse files Browse the repository at this point in the history
  • Loading branch information
fuziontech authored Jan 30, 2025
1 parent 05709b1 commit 937fa79
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
10 changes: 4 additions & 6 deletions dags/deletes.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,17 +292,15 @@ def count_pending_deletes(client: Client) -> int:

@op
def wait_for_delete_mutations(
    context: OpExecutionContext,
    cluster: ResourceParam[ClickhouseCluster],
    delete_person_events: tuple[PendingPersonEventDeletesTable, ShardMutations],
) -> PendingPersonEventDeletesTable:
    """Block until the scheduled person-event delete mutations finish on every shard.

    Args:
        context: Dagster op execution context (supplied by the framework; unused here).
        cluster: ClickHouse cluster resource used to fan the wait out to the hosts
            of each shard.
        delete_person_events: Tuple of the pending-deletes table and the per-shard
            mutations previously scheduled against it.

    Returns:
        The same pending-deletes table, so downstream ops can continue the chain.
    """
    pending_person_deletions, shard_mutations = delete_person_events

    # Nothing was scheduled (e.g. there were no pending deletes) -- nothing to wait on.
    if not shard_mutations:
        return pending_person_deletions

    # Wait for each shard's mutation to complete on every host in that shard.
    cluster.map_all_hosts_in_shards({shard: mutation.wait for shard, mutation in shard_mutations.items()}).result()

    return pending_person_deletions


@op
Expand Down
35 changes: 32 additions & 3 deletions posthog/clickhouse/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,25 @@ def __init__(
self.__logger = logger
self.__client_settings = client_settings

@property
def shards(self) -> list[int]:
    """Distinct shard numbers present in this cluster.

    Hosts whose ``shard_num`` is ``None`` (i.e. not assigned to a shard)
    are ignored.
    """
    seen: set[int] = set()
    for host in self.__hosts:
        shard = host.shard_num
        if shard is not None:
            seen.add(shard)
    return list(seen)

def __get_task_function(self, host: HostInfo, fn: Callable[[Client], T]) -> Callable[[], T]:
    """Wrap ``fn`` into a zero-argument task bound to ``host``.

    Lazily creates (and caches) a connection pool for the host, then returns a
    closure that checks a client out of the pool, runs ``fn`` against it, and
    logs the outcome (re-raising on failure).
    """
    pool = self.__pools.get(host)
    if pool is None:
        # Pools are created on first use and cached per host for reuse by later tasks.
        pool = self.__pools[host] = host.connection_info.make_pool(self.__client_settings)

    def task():
        with pool.get_client() as client:
            self.__logger.info("Executing %r on %r...", fn, host)
            try:
                result = fn(client)
            except Exception:
                # ``Logger.warn`` is deprecated; ``warning`` is the supported spelling.
                self.__logger.warning("Failed to execute %r on %r!", fn, host, exc_info=True)
                raise
            else:
                self.__logger.info("Successfully executed %r on %r.", fn, host)
                return result

    return task
Expand Down Expand Up @@ -159,6 +163,31 @@ def map_all_hosts_in_shard(
}
)

def map_all_hosts_in_shards(
    self, shard_fns: dict[int, Callable[[Client], T]], concurrency: int | None = None
) -> FuturesMap[HostInfo, T]:
    """
    Execute each shard's callable once on every host belonging to that shard.

    The number of concurrent queries can be limited with the ``concurrency``
    parameter, or set to ``None`` to use the default limit of the executor.

    Raises:
        ValueError: If any requested shard number is not present in the cluster.

    NOTE(review): the ``with`` block below waits for all submitted tasks on exit
    (executor shutdown blocks), so the returned futures are already resolved
    when this method returns.
    """
    # Hoist the shard lookup out of the loop: the ``shards`` property rebuilds
    # the set of shard numbers on every access.
    known_shards = set(self.shards)
    for shard in shard_fns:
        if shard not in known_shards:
            raise ValueError(f"Shard {shard} not found in cluster")

    # Single pass over the hosts: each host in a requested shard gets that shard's callable.
    shard_host_fn = {
        host: shard_fns[host.shard_num] for host in self.__hosts if host.shard_num in shard_fns
    }

    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        return FuturesMap(
            {host: executor.submit(self.__get_task_function(host, fn)) for host, fn in shard_host_fn.items()}
        )

def map_one_host_per_shard(
self, fn: Callable[[Client], T], concurrency: int | None = None
) -> FuturesMap[HostInfo, T]:
Expand Down

0 comments on commit 937fa79

Please sign in to comment.