forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnet_async_task_future.h
76 lines (56 loc) · 1.88 KB
/
net_async_task_future.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#ifndef CAFFE2_NET_ASYNC_TASK_FUTURE_H
#define CAFFE2_NET_ASYNC_TASK_FUTURE_H
#include <atomic>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
namespace caffe2 {

// Represents the state of an AsyncTask execution, which can be queried with
// IsCompleted/IsFailed. Callbacks are supported through SetCallback and are
// called upon the future's completion.
//
// Instances are neither copyable nor movable: the class owns a std::mutex and
// a std::condition_variable, neither of which can be copied or moved.
class AsyncTaskFuture {
 public:
  AsyncTaskFuture();

  // Creates a future completed when all given futures are completed.
  // NOTE(review): presumably the ParentCounter is sized from futures.size()
  // and the given futures must outlive this one -- confirm in the .cc file.
  explicit AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures);

  ~AsyncTaskFuture();

  AsyncTaskFuture(const AsyncTaskFuture&) = delete;
  AsyncTaskFuture& operator=(const AsyncTaskFuture&) = delete;
  // Spelled out for the rule of five. The move operations were already
  // unavailable, because user-declared copy operations and a user-declared
  // destructor suppress their implicit declaration.
  AsyncTaskFuture(AsyncTaskFuture&&) = delete;
  AsyncTaskFuture& operator=(AsyncTaskFuture&&) = delete;

  // State queries; see the class comment.
  bool IsCompleted() const;
  bool IsFailed() const;
  // NOTE(review): presumably only meaningful once IsFailed() is true.
  std::string ErrorMessage() const;

  // NOTE(review): presumably blocks the caller until the future completes,
  // using cv_completed_ -- confirm in the .cc file.
  void Wait() const;

  // Registers a callback that is called upon the future's completion. The
  // callback receives a const AsyncTaskFuture*, presumably this future.
  // NOTE(review): behavior when the future has already completed is defined
  // in the .cc file -- confirm whether the callback then runs immediately.
  void SetCallback(std::function<void(const AsyncTaskFuture*)> callback);

  // Marks the future as completed. A non-null err_msg presumably also marks
  // it as failed -- confirm in the .cc file.
  void SetCompleted(const char* err_msg = nullptr);

  // NOTE(review): presumably returns the future to a not-completed state so
  // it can be reused -- confirm in the .cc file.
  void ResetState();

 private:
  // Declared mutable so that const members (Wait, ErrorMessage, ...) can use
  // them. NOTE(review): presumably mutex_ guards err_msg_ and callbacks_.
  mutable std::mutex mutex_;
  mutable std::condition_variable cv_completed_;

  // Before C++20, a default-constructed std::atomic holds an indeterminate
  // value. These default member initializers guarantee a defined starting
  // state; a constructor's mem-initializer list still takes precedence.
  std::atomic<bool> completed_{false};
  std::atomic<bool> failed_{false};
  std::string err_msg_;
  std::vector<std::function<void(const AsyncTaskFuture*)>> callbacks_;

  // Bookkeeping for a future built from several parent futures.
  // NOTE(review): presumably only created by the vector constructor.
  struct ParentCounter {
    explicit ParentCounter(int init_parent_count)
        : init_parent_count_(init_parent_count),
          parent_count(init_parent_count),
          parent_failed(false) {}

    // Restores the counter to its initial state while holding err_mutex.
    // parent_count goes back to the value given at construction,
    // parent_failed is cleared, and err_msg is emptied.
    void Reset() {
      std::lock_guard<std::mutex> lock(err_mutex);
      parent_count = init_parent_count_;
      parent_failed = false;
      err_msg.clear();
    }

    // Parent count given at construction; Reset() restores parent_count to it.
    const int init_parent_count_;
    std::atomic<int> parent_count;
    // Held by Reset(); presumably guards err_msg elsewhere as well.
    std::mutex err_mutex;
    std::atomic<bool> parent_failed;
    std::string err_msg;
  };
  std::unique_ptr<ParentCounter> parent_counter_;
};

} // namespace caffe2
#endif // CAFFE2_NET_ASYNC_TASK_FUTURE_H