forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_utils.py
559 lines (467 loc) · 22.2 KB
/
_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
import torch
from typing import Optional, List, DefaultDict, Any
import warnings
from collections import defaultdict
import sys
import traceback
def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Return this object's type name, or cast it to ``dtype``.

    With no ``dtype`` the fully qualified type name (``<module>.<class>``) is
    returned as a string. Otherwise the object is cast; if ``dtype`` already
    equals ``type(self)`` no copy is made and ``self`` itself is returned.

    Args:
        dtype (type or string): The desired type. A string is resolved as a
            dotted import path.
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.

    Raises:
        RuntimeError: when asked to cast sparse -> dense or dense -> sparse.
    """
    non_blocking = _get_async_or_non_blocking('type', non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + '.' + self.__class__.__name__

    if isinstance(dtype, str):
        dtype = _import_dotted_name(dtype)
    if dtype == type(self):
        return self

    if not self.is_sparse:
        if dtype.is_sparse:
            raise RuntimeError("Cannot cast dense tensor to sparse tensor")
        return dtype(self.size()).copy_(self, non_blocking)

    if not dtype.is_sparse:
        raise RuntimeError("Cannot cast sparse tensor to dense tensor")
    # The dense counterpart lives in the same module minus the '.sparse' part;
    # values take the target's dense type, indices are always LongTensor.
    dense_module = dtype.__module__.replace('.sparse', '')
    values = torch.Tensor._values(self).type(dense_module + '.' + dtype.__name__, non_blocking)
    indices = torch.Tensor._indices(self).type(dense_module + '.LongTensor', non_blocking)
    return dtype(indices, values, self.size())
def _cuda(self, device=None, non_blocking=False, **kwargs):
    """Return a copy of this object in CUDA memory.

    If the object already lives on CUDA device ``device`` (the current device
    when ``device`` is ``None``), no copy is made and ``self`` is returned.

    Args:
        device (int): The destination GPU id. Defaults to the current device.
            For a non-CUDA source with ``device=None``, ``-1`` is passed to
            ``torch.cuda.device`` (presumably leaving the active device as is).
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking('cuda', non_blocking, kwargs)
    if not self.is_cuda:
        if device is None:
            device = -1
    else:
        if device is None:
            device = torch.cuda.current_device()
        if self.get_device() == device:
            return self

    with torch.cuda.device(device):
        if not self.is_sparse:
            dense_cls = getattr(torch.cuda, self.__class__.__name__)
            return dense_cls(self.size()).copy_(self, non_blocking)
        sparse_cls = getattr(torch.cuda.sparse, self.__class__.__name__)
        idx = torch.Tensor._indices(self).cuda(device, non_blocking)
        vals = torch.Tensor._values(self).cuda(device, non_blocking)
        return sparse_cls(idx, vals, self.size())
def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
if not kwargs:
return non_blocking
if len(kwargs) != 1 or 'async' not in kwargs:
message = "{}() got an unexpected keyword argument '{}'"
argument = list(kwargs.keys()).pop()
raise TypeError(message.format(function_name, argument))
warnings.warn("'async' is deprecated; use 'non_blocking'")
return kwargs['async']
# Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with
# variables. This kind of half-worked--Python can pickle global functions
# (but not closures!)--but there were problems.
#
# - It's fragile. If you serialize a backward hook into a saved
# model, and then you rename the function associated with the hook,
# now your saved model is broken and you can't load it anymore.
#
# - It's not actually used. The standard recommendation is to
# serialize the *state_dict* of a model, not the model itself
# (since this is more stable to code changes affecting the model
# serialization), and the state dict saves "data" only, thus
# stripping the backward hooks. In some cases, hooks are
# essential to the well-functioning of a model (e.g., DDP),
# but DDP already manages re-adding the hooks!
#
# - We didn't serialize them in many cases. Prior to #10220, we
# were dropping backward hooks in ForkingPickler. We "fixed" this
# to be convenient with other serialization sites, but lack of
# serializing backward hooks wasn't actually the root cause of
# the bug.
#
# With these cases in mind, we have decided that a better strategy
# is to just NOT serialize hooks at all.
#
# Since this is a BC-breaking change, we should warn when we previously
# serialized a hook, but no longer do so. A special sentinel property on hooks
# can be used to suppress this warning. If a hook
# has the property _torch_serialize_ignore, we will not emit a warning if we
# attempt to serialize a Tensor with this hook attached to it.
#
# By the way, when _backward_hooks is skipped, we must give an EMPTY
# OrderedDict(); if you pass None you'll run afoul of #12219.
# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
# first construct a tensor with the correct dtype/device
t = torch.tensor([], dtype=storage.dtype, device=storage._untyped().device)
return t.set_(storage._untyped(), storage_offset, size, stride)
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
    """Rebuild a tensor (see ``_rebuild_tensor``) and restore autograd metadata.

    Sets ``requires_grad`` and reattaches ``backward_hooks`` verbatim.
    """
    rebuilt = _rebuild_tensor(storage, storage_offset, size, stride)
    rebuilt.requires_grad = requires_grad
    # NB: kept only for backwards compatibility; backward_hooks is normally an
    # empty OrderedDict. See Note [Don't serialize hooks].
    rebuilt._backward_hooks = backward_hooks
    return rebuilt
_sparse_tensors_to_validate: List["torch.Tensor"] = []
# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have been already unpickled. Those storages contain data necessary for
# validating sparse tensors: indices and values. That's why sparse tensors are
# first unpickled without any validation, and then this function is called just
# before _legacy_load() returns, so that all the sparse tensors can be validated
# in bulk.
#
# The same procedure must be followed by _load() in serialization.py because due
# to Pickler semantics, we have to use the same (non-validating) function for
# unpickling sparse tensors, regardless of the caller.
def _validate_loaded_sparse_tensors():
try:
for t in _sparse_tensors_to_validate:
torch._validate_sparse_coo_tensor_args(t._indices(), t._values(),
t.size())
finally:
_sparse_tensors_to_validate.clear()
def _rebuild_sparse_tensor(layout, data):
    """Rebuild a sparse tensor from pickled ``data`` without validating it.

    Only ``torch.sparse_coo`` is supported; ``data`` is
    ``(indices, values, size)``. The result is queued in
    ``_sparse_tensors_to_validate`` for deferred validation.

    Raises:
        NotImplementedError: for any other layout.
    """
    if layout != torch.sparse_coo:
        raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout))
    indices, values, size = data
    rebuilt = torch._sparse_coo_tensor_unsafe(indices, values, size)
    _sparse_tensors_to_validate.append(rebuilt)
    return rebuilt
def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
tensor.requires_grad = requires_grad
return tensor
# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy
_rebuild_mlc_tensor = _rebuild_device_tensor_from_numpy
def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
return torch.empty_strided(size, stride, dtype=dtype, device='meta', requires_grad=requires_grad)
# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks):
qscheme = quantizer_params[0]
if qscheme == torch.per_tensor_affine:
_, scale, zero_point = quantizer_params
tensor = torch._empty_affine_quantized(size, scale=scale, zero_point=zero_point, dtype=storage.dtype)
elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
_, scales, zero_points, axis = quantizer_params
if type(scales) is list and type(zero_points) is list:
if qscheme == torch.per_channel_affine:
scales = torch.tensor(scales, dtype=torch.double)
zero_points = torch.tensor(zero_points, dtype=torch.long)
else:
scales = torch.tensor(scales, dtype=torch.float)
zero_points = torch.tensor(zero_points, dtype=torch.float)
tensor = torch._empty_per_channel_affine_quantized(
size, scales=scales, zero_points=zero_points, axis=axis, dtype=storage.dtype)
else:
raise RuntimeError("Can't deserialize quantized tensor with qscheme {}".format(qscheme))
tensor.set_(storage, storage_offset, size, stride)
tensor.requires_grad = requires_grad
# NB: This line exists only for backwards compatibility; the
# general expectation is that backward_hooks is an empty
# OrderedDict. See Note [Don't serialize hooks]
tensor._backward_hooks = backward_hooks
return tensor
def _rebuild_parameter(data, requires_grad, backward_hooks):
param = torch.nn.Parameter(data, requires_grad)
# NB: This line exists only for backwards compatibility; the
# general expectation is that backward_hooks is an empty
# OrderedDict. See Note [Don't serialize hooks]
param._backward_hooks = backward_hooks
return param
def _import_dotted_name(name):
components = name.split('.')
obj = __import__(components[0])
for component in components[1:]:
obj = getattr(obj, component)
return obj
# Taken from python 3.5 docs
def _accumulate(iterable, fn=lambda x, y: x + y):
'Return running totals'
# _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
# _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
it = iter(iterable)
try:
total = next(it)
except StopIteration:
return
yield total
for element in it:
total = fn(total, element)
yield total
def _flatten_dense_tensors(tensors):
"""Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
same dense type.
Since inputs are dense, the resulting tensor will be a concatenated 1D
buffer. Element-wise operation on this buffer will be equivalent to
operating individually.
Args:
tensors (Iterable[Tensor]): dense tensors to flatten.
Returns:
A contiguous 1D buffer containing input tensors.
"""
return torch._C._nn.flatten_dense_tensors(tensors)
def _flatten_sparse_tensors(tensors):
"""Flatten sparse tensors into two contiguous 1D buffers, one of indices and
one of values. Assume tensors are of same sparse type.
Args:
tensors (Iterable[Tensor]): sparse tensors to flatten.
Returns:
A tuple of two contiguous 1D buffers, one containing input tensors'
indices and the other containing the values.
"""
flat_indices = torch._C._nn.flatten_dense_tensors([torch.Tensor._indices(t) for t in tensors])
flat_values = torch._C._nn.flatten_dense_tensors([torch.Tensor._values(t) for t in tensors])
return flat_indices, flat_values
def _unflatten_dense_tensors(flat, tensors):
"""View a flat buffer using the sizes of tensors. Assume that tensors are of
same dense type, and that flat is given by _flatten_dense_tensors.
Args:
flat (Tensor): flattened dense tensors to unflatten.
tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
unflatten flat.
Returns:
Unflattened dense tensors with sizes same as tensors and values from
flat.
"""
return torch._C._nn.unflatten_dense_tensors(flat, tensors)
def _unflatten_sparse_tensors(flat, tensors):
"""View flat buffer (containing indices and values) using the sizes of
tensors. Assume that tensors are of same sparse type, and that flat is given
by _flatten_sparse_tensors.
Args:
flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
tensors to unflatten.
tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
unflatten flat.
Returns:
Unflattened sparse tensors with sizes same as tensors and values from
flat.
"""
flat_indices, flat_values = flat
indices = torch._C._nn.unflatten_dense_tensors(flat_indices, [torch.Tensor._indices(t) for t in tensors])
values = torch._C._nn.unflatten_dense_tensors(flat_values, [torch.Tensor._values(t) for t in tensors])
outputs = []
for t, i, v in zip(tensors, indices, values):
outputs.append(t.new(i, v, t.size()))
return tuple(outputs)
def _reorder_tensors_as(tensors, ordered_tensors):
"""Assume that tensors are of same order as ordered_tensors within their
types, e.g., from _take_tensors. Reorder them to be of same order as
ordered_tensors.
Args:
tensors (Iterable[Tensor]): tensors to be reordered. They should be of
the same order as ordered_tensors within their own types.
ordered_tensors (Iterable[Tensor]): tensors whose order will be the
reference.
Returns:
Ordered tuple of tensors with contents from tensors and order of
ordered_tensors.
"""
type_dict = defaultdict(list)
for tensor in tensors:
type_dict[tensor.type()].append(tensor)
type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)
def _take_tensors(tensors, size_limit):
"""Group tensors into chunks. This generator yields a chunk at each time,
each containing tensors of same type up to certain byte limit in total size.
Args:
tensors (Sequence): A sequence of tensors to be separated into chunks.
size_limit (int): The limit of each chunk in bytes.
Yields:
Blocks of tensors of same type and within size_limit. The yielded
tensors are only ordered as the original sequence within its types.
"""
buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
for tensor in tensors:
t = tensor.type()
if tensor.is_sparse:
indices = torch.Tensor._indices(tensor)
values = torch.Tensor._values(tensor)
size = indices.numel() * indices.element_size() + values.numel() * values.element_size()
else:
size = tensor.numel() * tensor.element_size()
buf_and_size = buf_dict[t]
if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
yield buf_and_size[0]
buf_and_size = buf_dict[t] = [[], 0]
buf_and_size[0].append(tensor)
buf_and_size[1] += size
for buf, _ in buf_dict.values():
if len(buf) > 0:
yield buf
# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
    """Decorator factory that sets ``__annotations__`` explicitly.

    The decorated function's ``__annotations__`` is replaced by a fresh dict
    of ``kwargs`` (parameter name -> annotation) plus ``'return': ret``. The
    function itself is returned unchanged otherwise.
    """
    def decorator(fun):
        fun.__annotations__ = dict(kwargs, **{'return': ret})
        return fun
    return decorator
# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding reference to the frame,
# and the frame (which holds references to all the objects in its temporary scope)
# holding a reference to the traceback.
class KeyErrorMessage(str):
    r"""A ``str`` whose ``repr()`` is the string itself, with no quoting or escaping.

    ``KeyError`` formats its argument with ``repr()``; wrapping a multi-line
    message in this class keeps it readable.
    """

    def __repr__(self):
        # Return the instance itself (still a str) rather than a quoted form.
        return self
class ExceptionWrapper(object):
    r"""Wraps an exception plus traceback to communicate across threads.

    Only the exception type and the formatted traceback text are stored (never
    the traceback object itself), so the wrapper is safe to keep around.
    """

    def __init__(self, exc_info=None, where="in background"):
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Reraises the wrapped exception in the current thread.

        The new exception's message is "Caught <Type> <where>.\nOriginal
        <traceback>". If the exception type cannot be constructed from that
        message, a ``RuntimeError`` carrying the same message is raised instead,
        so the original traceback text is never lost.
        """
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = "Caught {} {}.\nOriginal {}".format(
            self.exc_type.__name__, self.where, self.exc_msg)
        use_message_kwarg = False
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions have first argument as non-str but explicitly
            # have message field
            use_message_kwarg = True

        attempts = []
        if use_message_kwarg:
            attempts.append(((), {"message": msg}))
        attempts.append(((msg,), {}))
        for args, kwargs in attempts:
            try:
                exception = self.exc_type(*args, **kwargs)
            except TypeError:
                # This constructor signature is not supported; try the next one.
                continue
            raise exception
        # If the exception takes multiple arguments, don't try to
        # instantiate since we don't know how to
        raise RuntimeError(msg) from None
def _get_available_device_type():
if torch.cuda.is_available():
return "cuda"
# add more available device types here
return None
def _get_device_attr(get_member):
    """Apply ``get_member`` to the module of the available device backend.

    Currently only CUDA is recognized: returns ``get_member(torch.cuda)`` when
    CUDA is available, otherwise ``None``.
    """
    device_type = _get_available_device_type()
    if not device_type or device_type.lower() != "cuda":
        # add more available device types here
        return None
    return get_member(torch.cuda)
def _get_current_device_index():
    """Return the current device index of the available backend, or ``None``."""
    def current_of(backend):
        return backend.current_device()
    return _get_device_attr(current_of)
def _get_all_device_indices():
    """Return ``[0, ..., device_count - 1]`` for the available backend, or ``None``."""
    def indices_of(backend):
        return list(range(backend.device_count()))
    return _get_device_attr(indices_of)
def _get_devices_properties(device_ids):
    """Return one ``get_device_properties`` result per id in ``device_ids``.

    Each entry is ``None`` when no supported backend is available.
    """
    properties = []
    for device_id in device_ids:
        properties.append(
            _get_device_attr(lambda backend, idx=device_id: backend.get_device_properties(idx)))
    return properties
def get_current_device_index() -> int:
    r"""Return the index of the current default CUDA device, or ``-1`` if
    no CUDA device is present.

    Takes no arguments. Written without lambdas so it can be compiled by
    TorchScript.
    """
    return torch.cuda.current_device() if torch.cuda.device_count() > 0 else -1
def _get_device_index(device: Any, optional: bool = False, allow_cpu: bool = False) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a device string (converted with ``torch.device``), a Python
    integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    has index. For a device without a specified index, i.e.
    ``torch.device('xxx')``, this returns the current default device if
    :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, CPU devices
    are accepted and ``-1`` is returned for them.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this returns the current default device of
    the supported runtime platform if :attr:`optional` is ``True``, i.e. the
    current default CUDA device if the CUDA runtime is supported.

    Raises:
        ValueError: if :attr:`device` is a CPU device and :attr:`allow_cpu` is
            ``False``, or if no index can be determined and :attr:`optional`
            is ``False``.
    """
    if isinstance(device, str):
        device = torch.device(device)
    device_idx: Optional[int] = None
    if isinstance(device, torch.device):
        if not allow_cpu and device.type == 'cpu':
            raise ValueError('Expected a non cpu device, but got: {}'.format(device))
        # CPU maps to -1; other devices may have index None (no explicit index).
        device_idx = -1 if device.type == 'cpu' else device.index
    if isinstance(device, int):
        device_idx = device
    if device_idx is None:
        if optional:
            # The eager API _get_current_device_index uses `lambda` functions which are
            # not supported in JIT and hence not scriptable. The JIT equivalent API to get
            # the current device index is `get_current_device_index()` which can
            # be scripted. We use is_scripting to check the mode we are in and call the
            # appropriate API.
            # NOTE(review): both fallbacks only query CUDA, whatever device.type
            # was; and in eager mode _get_current_device_index() returns None
            # when CUDA is unavailable, despite the `-> int` annotation.
            if torch.jit.is_scripting():
                device_idx = get_current_device_index()
            else:
                device_idx = _get_current_device_index()
        else:
            raise ValueError('Expected a torch.device with a specified index '
                             'or an integer, but got:{}'.format(device))
    return device_idx
def _handle_complex(tensor):
"""
Returns a real view of a tensor if complex dtype else just the tensor
need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule
"""
return torch.view_as_real(tensor) if not isinstance(tensor,
torch.nn.UninitializedParameter) and tensor.is_complex() else tensor
def _element_size(dtype):
"""
Returns the element size for a dtype, in bytes
"""
if not isinstance(dtype, torch.dtype):
raise RuntimeError(f'expected torch.dtype, but got {type(dtype)}')
if dtype.is_complex:
return torch.finfo(dtype).bits >> 2
elif dtype.is_floating_point:
return torch.finfo(dtype).bits >> 3
elif dtype == torch.bool:
# NOTE: torch.bool is not supported in torch.iinfo()
return 1
else:
return torch.iinfo(dtype).bits >> 3
class _ClassPropertyDescriptor:
def __init__(self, fget, fset=None):
self.fget = fget
def __get__(self, instance, owner=None):
if owner is None:
owner = type(instance)
return self.fget.__get__(instance, owner)()
def classproperty(func):
    """Decorator that exposes ``func`` as a read-only, class-level property.

    A plain function is wrapped in ``classmethod``; an existing
    ``classmethod``/``staticmethod`` is used as is.
    """
    wrapped = func if isinstance(func, (classmethod, staticmethod)) else classmethod(func)
    return _ClassPropertyDescriptor(wrapped)