forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathPythonTorchFunctionTLS.h
41 lines (32 loc) · 1.25 KB
/
PythonTorchFunctionTLS.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#pragma once

#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>

#include <cstdint>
#include <memory>
#include <vector>
namespace at {
namespace impl {
struct TORCH_API PythonTorchFunctionTLS {
static void set_disabled(bool);
static bool is_disabled();
static void set_mode(std::shared_ptr<c10::SafePyObject>);
static const std::shared_ptr<c10::SafePyObject>& get_mode();
static void swap_mode(std::shared_ptr<c10::SafePyObject>&);
static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
static const std::shared_ptr<SafePyObject> pop_stack();
static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
static int64_t stack_len();
static const PythonTorchFunctionTLS& get_state();
static void set_state(const PythonTorchFunctionTLS& state);
private:
// The mode TLS is split into
// - disabled_, which says whether or not to disable all torch function
// modes
// - mode_, which is the C++ mode, that can only be the mode handling mode
// or null
// - stack_, which is a vector of modes representing the stack of user
// defined modes
bool disabled_;
std::shared_ptr<c10::SafePyObject> mode_ = nullptr;
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
};
TORCH_API bool function_mode_enabled();
} // namespace impl
} // namespace at