forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_jit_internal.py
106 lines (83 loc) · 2.81 KB
/
_jit_internal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
The weak_script annotation needs to be here instead of inside torch/jit/ so it
can be used in other places in torch/ (namely torch.nn) without running into
circular dependency problems
"""
import weakref
import inspect
try:
import builtins # PY3
except Exception:
import __builtin__ as builtins # PY2
# Tracks standalone weak script functions.
# Keys are the decorated Python functions; values are the dicts written by
# weak_script() with keys "status", "compiled_fn" and "rcb". Keys are held
# weakly, so an entry goes away once its function is garbage-collected.
_compiled_weak_fns = weakref.WeakKeyDictionary()
# Tracks which methods should be converted to strong methods.
# Keys are the decorated methods; values are the dicts written by
# weak_script_method() with keys "rcb" and "original_method".
_weak_script_methods = weakref.WeakKeyDictionary()
# Converted modules and their corresponding WeakScriptModuleProxy objects.
# NOTE(review): nothing in this file writes to it; presumably populated by
# torch/jit — confirm.
_weak_modules = weakref.WeakKeyDictionary()
# Types that have been declared as weak modules.
# Keys are classes passed to weak_module(); values are dicts whose
# "method_stubs" entry starts out as None.
_weak_types = weakref.WeakKeyDictionary()
# Unique sentinel used as the initial "status" of a weak_script() entry.
COMPILATION_PENDING = object()
# Unique sentinel; not referenced in this file. Presumably stored as "status"
# once a weak script function has been compiled — confirm against torch/jit.
COMPILED = object()
def createResolutionCallback(frames_up=0):
    """
    Create a function that, given a variable name as a string, returns the
    value bound to that name in the scope of the function that called
    createResolutionCallback (by default).

    This is used to give TorchScript fragments access to in-scope Python
    variables.

    Names are looked up in the target frame's locals, then its globals, then
    the builtins module; a name found in none of them resolves to None.

    Args:
        frames_up: number of additional frames to go up on the stack. The
            default value, 0, selects the frame of the caller of
            createResolutionCallback. If frames_up is 1, the frame of the
            caller's caller is used instead, and so on.

    Raises:
        ValueError: if the call stack is not deep enough to go up the
            requested number of frames (or frame introspection is
            unavailable, in which case inspect.currentframe() returns None).

    For example, the following program prints 2::

        def bar():
            cb = createResolutionCallback(1)
            print(cb("foo"))

        def baz():
            foo = 2
            bar()

        baz()
    """
    frame = inspect.currentframe()
    depth = 0
    # Step past our own frame (to the caller), then frames_up frames more.
    while frame is not None and depth < frames_up + 1:
        frame = frame.f_back
        depth += 1
    if frame is None:
        # Running off the top of the stack (or an interpreter without frame
        # support) would otherwise fail later as None.f_back / None.f_locals.
        raise ValueError(
            "createResolutionCallback: cannot go up {} frame(s); the call "
            "stack is not deep enough".format(frames_up + 1))

    f_locals = frame.f_locals
    f_globals = frame.f_globals
    # Only the namespaces are needed; drop the frame reference explicitly,
    # as the inspect module documentation recommends.
    del frame

    def env(key):
        if key in f_locals:
            return f_locals[key]
        elif key in f_globals:
            return f_globals[key]
        # Fall back to builtins; unknown names resolve to None, not an error.
        return getattr(builtins, key, None)

    return env
def weak_script(fn, _frames_up=0):
    """
    Register fn as a weak script function and hand it back unchanged.

    A weak script function is compiled lazily and inlined into the graph
    when it is used from a script function or ScriptModule; in plain Python
    the decorator has no effect. The registry entry starts out pending, with
    no compiled function, and carries a resolution callback for looking up
    names in the scope that applied the decorator (moved further out by
    _frames_up).
    """
    # Must call createResolutionCallback directly from here: the frame
    # offset (_frames_up + 1) assumes no extra call frames in between.
    resolver = createResolutionCallback(_frames_up + 1)
    entry = dict(status=COMPILATION_PENDING, compiled_fn=None, rcb=resolver)
    _compiled_weak_fns[fn] = entry
    return fn
def weak_module(cls):
    """
    Declare cls as a weak module type.

    Records cls in _weak_types with its "method_stubs" entry unset (None),
    then returns cls itself so this can be used as a class decorator.
    """
    registration = dict(method_stubs=None)
    _weak_types[cls] = registration
    return cls
def weak_script_method(fn):
    """
    Flag method fn for conversion to a strong script method.

    Stores the original method alongside a resolution callback in
    _weak_script_methods and returns fn unchanged. The callback resolves
    names two frames above this function's caller (frames_up=2); when this
    is applied as a decorator inside a class body, that is presumably the
    scope enclosing the class statement — confirm.
    """
    # Called directly (no helper in between) so the frame offset stays exact.
    lookup = createResolutionCallback(frames_up=2)
    record = {"rcb": lookup, "original_method": fn}
    _weak_script_methods[fn] = record
    return fn