forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cwrap_common.py
213 lines (186 loc) · 7.93 KB
/
cwrap_common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# This code should be shared between cwrap and ATen preprocessing.
# For now it lives here in one place, but it is currently a copy of the cwrap version.
import copy
from typing import Any, Dict, Iterable, List, Union
Arg = Dict[str, Any]


def parse_arguments(args: List[Union[str, Arg]]) -> List[Arg]:
    """Normalize argument declarations into dicts with ``'type'`` and ``'name'``.

    Each element of *args* may be:

    * a string ``"<type> <name>"``, converted to a new
      ``{'type': ..., 'name': ...}`` dict (split at the first space);
    * a dict with an ``'arg'`` key holding such a string; the ``'arg'`` key is
      replaced in place by ``'type'`` and ``'name'`` keys;
    * any other dict, passed through unchanged.

    Dicts are mutated in place, and the same dict objects are returned.

    Raises:
        AssertionError: if an element is neither a str nor a dict.
    """
    new_args = []
    for arg in args:
        if isinstance(arg, str):
            # Simple arg declaration of form "<type> <name>"
            t, _, name = arg.partition(' ')
            new_args.append({'type': t, 'name': name})
        elif isinstance(arg, dict):
            if 'arg' in arg:
                arg['type'], _, arg['name'] = arg['arg'].partition(' ')
                del arg['arg']
            new_args.append(arg)
        else:
            # Same exception type as before, but say what was wrong.
            raise AssertionError(
                'unsupported argument declaration {!r} of type {}; '
                'expected str or dict'.format(arg, type(arg).__name__))
    return new_args
Declaration = Dict[str, Any]


def set_declaration_defaults(declaration: Declaration) -> None:
    """Fill in defaults on a declaration dict and expand it into options.

    Mutates *declaration* in place:

    * sets ``schema_string`` (``''``), ``arguments`` (``[]``), ``return``
      (``'void'``), ``cname`` (copied from ``name``) and ``backends``
      (``['CPU', 'CUDA']``) when missing, and always sets ``api_name`` from
      ``name``;
    * derives ``type_wrapper_name`` from ``name`` and, when truthy,
      ``overload_name``;
    * derives ``operator_name_with_overload`` (the schema text before the first
      ``'('``) and the unqualified ``unqual_schema_string`` /
      ``unqual_operator_name_with_overload`` (everything after the first
      ``'::'``), or ``''`` for an empty schema string;
    * when there is no ``options`` list, moves deep copies of ``arguments`` and
      ``schema_order_arguments`` into a single option and removes them from
      the top level;
    * runs every option's ``arguments`` and ``schema_order_arguments`` through
      ``parse_arguments``, then copies every other top-level key into each
      option that does not already define it.

    Raises:
        AssertionError: if ``api_name`` is already present.
        KeyError: if a required key is missing (e.g. ``name``, or
            ``schema_order_arguments`` when no ``options`` are given).
        IndexError: if a non-empty schema string contains no ``'::'``.
    """
    if 'schema_string' not in declaration:
        # This happens for legacy TH bindings like
        # _thnn_conv_depthwise2d_backward
        declaration['schema_string'] = ''
    declaration.setdefault('arguments', [])
    declaration.setdefault('return', 'void')
    if 'cname' not in declaration:
        declaration['cname'] = declaration['name']
    if 'backends' not in declaration:
        declaration['backends'] = ['CPU', 'CUDA']
    # Explicit check instead of `assert` so it still fires under `python -O`.
    if 'api_name' in declaration:
        raise AssertionError(
            "declaration {!r} already has an 'api_name'".format(
                declaration.get('name')))
    declaration['api_name'] = declaration['name']
    # NB: keep this in sync with gen_autograd.py
    if declaration.get('overload_name'):
        declaration['type_wrapper_name'] = "{}_{}".format(
            declaration['name'], declaration['overload_name'])
    else:
        declaration['type_wrapper_name'] = declaration['name']
    # TODO: Uggggh, parsing the schema string here, really???
    schema_string = declaration['schema_string']
    operator_name = schema_string.split('(')[0]
    declaration['operator_name_with_overload'] = operator_name
    if schema_string:
        # Strip only the leading namespace. maxsplit=1 keeps any later '::'
        # in the schema instead of truncating the string there.
        declaration['unqual_schema_string'] = schema_string.split('::', 1)[1]
        declaration['unqual_operator_name_with_overload'] = operator_name.split('::', 1)[1]
    else:
        declaration['unqual_schema_string'] = ''
        declaration['unqual_operator_name_with_overload'] = ''
    # Simulate multiple dispatch, even if it's not necessary
    if 'options' not in declaration:
        declaration['options'] = [{
            'arguments': copy.deepcopy(declaration['arguments']),
            'schema_order_arguments': copy.deepcopy(declaration['schema_order_arguments']),
        }]
        del declaration['arguments']
        del declaration['schema_order_arguments']
    # Parse arguments (some of them can be strings)
    for option in declaration['options']:
        option['arguments'] = parse_arguments(option['arguments'])
        option['schema_order_arguments'] = parse_arguments(option['schema_order_arguments'])
    # Propagate defaults from declaration to options
    for option in declaration['options']:
        for k, v in declaration.items():
            # TODO(zach): why does cwrap not propagate 'name'? I need it
            # propagated for ATen
            if k != 'options':
                option.setdefault(k, v)
# TODO(zach): added option to remove keyword handling for C++ which cannot
# support it.
Option = Dict[str, Any]


def filter_unique_options(
    options: Iterable[Option],
    allow_kwarg: bool,
    type_to_signature: Dict[str, str],
    remove_self: bool,
) -> List[Option]:
    """Keep only the options whose argument signature has not been seen yet.

    Options are visited in order. For each option, the last ``k`` arguments
    are tried as keyword-only for ``k = 0, 1, ..., len(arguments)``. When
    *allow_kwarg* is false, only ``k = 0`` is tried. The option is kept under
    the first ``k`` that yields an unseen signature, and those last ``k``
    argument dicts get ``'kwarg_only': True`` set in place. An option with no
    unseen signature is dropped.

    The positional part of a signature joins the argument types with ``'#'``.
    Each type is mapped through *type_to_signature* when present. ``CONSTANT``
    arguments are skipped, and so are arguments named ``'self'`` when
    *remove_self* is set. The keyword-only part joins ``name#type`` pairs,
    skipping ``CONSTANT`` arguments, and is appended after ``'#-#'``.
    """

    def is_constant(arg: Dict[str, Any]) -> bool:
        return arg['type'] == 'CONSTANT'  # type: ignore[no-any-return]

    def skip_positional(arg: Dict[str, Any]) -> bool:
        if is_constant(arg):
            return True
        return bool(remove_self and arg['name'] == 'self')

    def make_signature(option: Option, n_kwarg: int) -> str:
        all_args = option['arguments']
        # None (not -0) keeps the whole list when nothing is keyword-only.
        cut = -n_kwarg if n_kwarg else None
        head = '#'.join([
            type_to_signature.get(a['type'], a['type'])
            for a in all_args[:cut]
            if not skip_positional(a)
        ])
        if cut is None:
            return head
        tail = '#'.join([
            a['name'] + '#' + a['type']
            for a in all_args[cut:]
            if not is_constant(a)
        ])
        return head + "#-#" + tail

    seen = set()
    kept = []
    for option in options:
        # Without allow_kwarg only the all-positional signature is tried.
        max_kwarg = len(option['arguments']) if allow_kwarg else 0
        for n_kwarg in range(max_kwarg + 1):
            sig = make_signature(option, n_kwarg)
            if sig in seen:
                continue
            if n_kwarg > 0:
                for a in option['arguments'][-n_kwarg:]:
                    a['kwarg_only'] = True
            kept.append(option)
            seen.add(sig)
            break
    return kept
def sort_by_number_of_args(declaration: Declaration, reverse: bool = True) -> None:
    """Sort ``declaration['options']`` in place by argument count.

    With the default ``reverse=True``, options with the most arguments come
    first. The sort is stable, so options with equal counts keep their
    relative order.
    """
    declaration['options'].sort(
        key=lambda opt: len(opt['arguments']),
        reverse=reverse,
    )
class Function:
    """A function parsed from a header: a name plus an ordered argument list."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Appended to by add_argument() in the order arguments are added.
        self.arguments: List['Argument'] = []

    def add_argument(self, arg: 'Argument') -> None:
        """Append *arg*, which must be an ``Argument``, to the argument list."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self) -> str:
        rendered_args = ', '.join(repr(a) for a in self.arguments)
        return ''.join([self.name, '(', rendered_args, ')'])
class Argument:
    """A single parameter of a parsed function: type, name and optionality."""

    def __init__(self, _type: str, name: str, is_optional: bool) -> None:
        self.type = _type
        self.name = name
        self.is_optional = is_optional

    def __repr__(self) -> str:
        # Rendered as "<type> <name>".
        return ' '.join((self.type, self.name))
def parse_header(path: str) -> List[Function]:
    """Parse a THNN-style C header into ``Function`` objects.

    Lines that are empty or start with ``'#'`` in column 0 are ignored. Each
    remaining line is split at ``'//'`` into code and comment. Trailing ``')'``
    and ``';'`` characters and then a trailing ``','`` are stripped from the
    code, which is then split on commas into fragments.

    A fragment starting with one of the exported ``... void THNN_`` prefixes
    opens a new ``Function``. Any other fragment must have the form
    ``<type> <name>`` and is added as an ``Argument`` to the most recently
    opened function. The argument is marked optional when its line's comment
    contains ``[OPTIONAL]``.

    Args:
        path: Path of the header file to read.

    Returns:
        The parsed functions, in file order.

    Raises:
        ValueError: if an argument fragment does not split into exactly two
            whitespace-separated tokens.
        IndexError: if an argument fragment appears before any function
            declaration.
    """
    # Prefixes that start a new function declaration. None is a prefix of
    # another, so the order in which they are tried does not matter.
    decl_prefixes = (
        'TH_API void THNN_',
        'TORCH_CUDA_CPP_API void THNN_',
        'TORCH_CUDA_CU_API void THNN_',
    )

    with open(path, 'r') as f:
        raw_lines = f.read().split('\n')

    # Normalize the file into (code, comment) fragments: one per non-empty,
    # whitespace-stripped comma-separated piece of each line.
    fragments = []
    for raw in raw_lines:
        # Remove empty lines and preprocessor directives.
        if not raw or raw.startswith('#'):
            continue
        code, _, comment = raw.partition('//')
        # rstrip(');') removes any run of trailing ')' / ';' characters.
        code = code.strip().rstrip(');').rstrip(',')
        comment = comment.strip()
        for piece in code.split(','):
            piece = piece.strip()
            if piece:
                fragments.append((piece, comment))

    generic_functions: List[Function] = []
    for code, comment in fragments:
        prefix = next((p for p in decl_prefixes if code.startswith(p)), None)
        if prefix is not None:
            fn_name = code[len(prefix):]
            # "(Name)X" -> "Name"; otherwise drop only the final character,
            # presumably the '(' that opens the parameter list.
            if fn_name[0] == '(' and fn_name[-2] == ')':
                fn_name = fn_name[1:-2]
            else:
                fn_name = fn_name[:-1]
            generic_functions.append(Function(fn_name))
            continue

        arg_type, name = code.split()
        # Move leading '*'s from the name onto the type, so
        # "THTensor *input" -> ("THTensor*", "input") and
        # "int **p" -> ("int**", "p").
        stars = len(name) - len(name.lstrip('*'))
        if stars:
            arg_type = arg_type + '*' * stars
            name = name[stars:]
        generic_functions[-1].add_argument(
            Argument(arg_type, name, '[OPTIONAL]' in comment))
    return generic_functions