-
Notifications
You must be signed in to change notification settings - Fork 161
/
Copy path_apply_udf.py
53 lines (44 loc) · 1.62 KB
/
_apply_udf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import multiprocessing.synchronize as python_mp_synchronize
import queue
import threading
from typing import Callable, Union
import torch
import torch.multiprocessing as mp
from torch._utils import ExceptionWrapper
from .constants import QUEUE_TIMEOUT
def _apply_udf(
worker_id: int,
in_q: Union[queue.Queue, mp.Queue],
out_q: Union[queue.Queue, mp.Queue],
udf: Callable,
stop_event: Union[threading.Event, python_mp_synchronize.Event],
):
"""_apply_udf assumes in_q emits tuples of (x, idx) where x is the
payload, idx is the index of the result, potentially used for maintaining
ordered outputs. For every input it pulls, a tuple (y, idx) is put on the out_q
where the output of udf(x), an ExceptionWrapper, or StopIteration (if it pulled
StopIteration from in_q).
"""
torch.set_num_threads(1)
while True:
if stop_event.is_set() and in_q.empty():
break
try:
item, idx = in_q.get(block=True, timeout=QUEUE_TIMEOUT)
except queue.Empty:
continue
if isinstance(item, ExceptionWrapper):
out_q.put((item, idx), block=False)
elif isinstance(item, StopIteration):
out_q.put((item, idx), block=False)
else:
try:
y = udf(item)
except Exception:
y = ExceptionWrapper(where="in _apply_udf")
out_q.put((y, idx), block=False)