Update cpn inference
ericup committed Sep 30, 2024
1 parent 6e20336 commit f3bf775
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions celldetection_scripts/cpn_inference.py
@@ -183,7 +183,7 @@ def on_oom():
 
     v_ = None
     oom = cd.OomCatcher(2, callback=on_oom)
-    while oom:
+    for attempt in oom:
         if target_device is not None:
             v = v.to(target_device)
         with oom:
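
Both hunks replace the open-ended "while oom:" loop with a bounded "for attempt in oom:" iteration over the catcher. The diff does not show the internals of cd.OomCatcher, so the stand-in below is a hypothetical sketch of only the usage pattern visible here: a fixed retry budget, a "with" block that intercepts CUDA out-of-memory errors, and a callback that runs before the next attempt.

# Hypothetical stand-in for cd.OomCatcher; mirrors only the usage visible in
# this diff (bounded retries, context manager, OOM callback). The real
# celldetection implementation may differ.
import torch

class RetryOnOom:
    def __init__(self, attempts, callback=None):
        self.attempts = attempts
        self.callback = callback
        self._done = False

    def __iter__(self):
        for attempt in range(self.attempts):
            if self._done:
                return  # previous attempt completed without OOM
            yield attempt
        raise RuntimeError('still out of memory after all attempts')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is torch.cuda.OutOfMemoryError:
            if self.callback is not None:
                self.callback()  # e.g. free caches or switch to a fallback device
            return True  # suppress the OOM so the for-loop can retry
        self._done = exc_type is None  # clean exit ends the iteration
        return False

oom = RetryOnOom(2, callback=torch.cuda.empty_cache)
for attempt in oom:
    with oom:
        x = torch.empty(8)  # allocation that might raise an OOM error

Compared with "while oom:", the "for" form makes the retry budget explicit in the loop variable and terminates deterministically once an attempt succeeds or the budget is exhausted.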
@@ -261,34 +261,26 @@ def oom_safe_gather_dict(local_dict: Dict[str, torch.Tensor], dst=0, fallback_device
     if rank == dst:
         vs = []
         ds = tuple(v.shape[1:])
+
+        def on_oom():
+            warn(f'Not enough memory on {device}. Moving data to {fallback_device} in order to continue.')
+            nonlocal target_device, result, vs
+            target_device = fallback_device
+            result = cd.to_device(result, target_device)
+            vs = cd.to_device(vs, target_device)
+
+        oom = cd.OomCatcher(2, callback=on_oom)
         for src in range(ranks):
             recv_size = tuple(sizes[src].cpu().data.numpy()) + ds
-            if src == dst:
-                recv_tensor = v.to(device)
-                if target_device is None:  # if not target device, send to where everything else is
-                    def on_oom():
-                        nonlocal recv_tensor, target_device
-                        target_device = fallback_device
-                        recv_tensor = v.to(fallback_device)
-
-                    oom = cd.OomCatcher(2, callback=on_oom)
-                    while oom:
-                        with oom:
-                            recv_tensor = recv_tensor.to(device)
-            else:
+            if src == dst:  # data from own rank
+                for attempt in oom:
+                    with oom:
+                        recv_tensor = v.to(target_device or device)
+            else:  # data from other ranks
                 # Create OOM safe recv Tensor
-                def on_oom():
-                    nonlocal target_device, result, vs
-                    warn(f'Not enough memory on {device}. Moving data to {fallback_device} in order to continue.')
-                    target_device = fallback_device
-                    result = cd.to_device(result, target_device)
-                    vs = cd.to_device(vs, target_device)
-
-                oom = cd.OomCatcher(2, callback=on_oom)
-                while oom:
+                for attempt in oom:
                     with oom:
                         recv_tensor = torch.empty(recv_size, dtype=v.dtype, device=device)
 
                 torch.distributed.recv(recv_tensor, src=src)  # todo: receive unordered
                 if target_device is not None:  # move to other device right away
                     recv_tensor = recv_tensor.to(target_device)
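
Beyond the loop syntax, this hunk hoists a single on_oom callback above the per-source loop, so the own-rank copy and the torch.distributed.recv buffer allocation share one catcher: after the first OOM, everything gathered so far is moved to fallback_device, and subsequent tensors are placed there directly via v.to(target_device or device). Below is a minimal, self-contained sketch of that fallback idea without any distributed setup; to_device and gather_with_fallback are hypothetical names standing in for cd.to_device and the gather logic.

# Illustrative only: shows the gather-side fallback pattern from this diff
# without torch.distributed. to_device is a hypothetical stand-in for
# cd.to_device; the except-branch mirrors the on_oom callback above.
import torch

def to_device(obj, device):
    # Recursively move tensors inside lists, tuples, and dicts to device.
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device(o, device) for o in obj)
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    return obj

def gather_with_fallback(tensors, device='cuda', fallback_device='cpu'):
    vs = []
    target_device = None  # set once an allocation fails
    for t in tensors:
        try:
            vs.append(t.to(target_device or device))
        except torch.cuda.OutOfMemoryError:
            # Not enough memory on device: move what we have to the
            # fallback device and keep gathering there.
            target_device = fallback_device
            vs = to_device(vs, target_device)
            vs.append(t.to(target_device))
    return vs

# Usage (CPU-only example, so no OOM is expected):
parts = gather_with_fallback([torch.ones(4) for _ in range(3)], device='cpu')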
