Skip to content

Commit a41f00e

Browse files
janeyx99 authored and pytorchmergebot committed
[optim][sgd] group tensors in foreach to maximize perf (pytorch#92338)
Make foreach faster for SGD Pull Request resolved: pytorch#92338 Approved by: https://github.com/albanD
1 parent 98b78aa commit a41f00e

File tree

2 files changed

+56
-43
lines changed

2 files changed

+56
-43
lines changed

Diff for: torch/optim/sgd.py

+39-36
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from torch import Tensor
33
from .optimizer import Optimizer, required, _use_grad_for_differentiable
44
from typing import List, Optional
5+
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
56

67
__all__ = ['SGD', 'sgd']
78

@@ -271,48 +272,50 @@ def _multi_tensor_sgd(params: List[Tensor],
271272
if len(params) == 0:
272273
return
273274

274-
if has_sparse_grad is None:
275-
has_sparse_grad = any(grad.is_sparse for grad in grads)
275+
grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
276+
for device_params, device_grads, device_momentum_buffer_list, indices in grouped_tensors.values():
277+
device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)
276278

277-
if maximize:
278-
grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment]
279+
if maximize:
280+
device_grads = torch._foreach_neg(tuple(device_grads)) # type: ignore[assignment]
279281

280-
if weight_decay != 0:
281-
grads = torch._foreach_add(grads, params, alpha=weight_decay)
282+
if weight_decay != 0:
283+
device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
284+
285+
if momentum != 0:
286+
bufs = []
282287

283-
if momentum != 0:
284-
bufs = []
288+
all_states_with_momentum_buffer = True
289+
for i in range(len(device_momentum_buffer_list)):
290+
if device_momentum_buffer_list[i] is None:
291+
all_states_with_momentum_buffer = False
292+
break
293+
else:
294+
bufs.append(device_momentum_buffer_list[i])
285295

286-
all_states_with_momentum_buffer = True
287-
for i in range(len(momentum_buffer_list)):
288-
if momentum_buffer_list[i] is None:
289-
all_states_with_momentum_buffer = False
290-
break
296+
if all_states_with_momentum_buffer:
297+
torch._foreach_mul_(bufs, momentum)
298+
torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
291299
else:
292-
bufs.append(momentum_buffer_list[i])
300+
bufs = []
301+
for i in range(len(device_momentum_buffer_list)):
302+
if device_momentum_buffer_list[i] is None:
303+
buf = device_momentum_buffer_list[i] = momentum_buffer_list[indices[i]] = \
304+
torch.clone(device_grads[i]).detach()
305+
else:
306+
buf = device_momentum_buffer_list[i]
307+
buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)
293308

294-
if all_states_with_momentum_buffer:
295-
torch._foreach_mul_(bufs, momentum)
296-
torch._foreach_add_(bufs, grads, alpha=1 - dampening)
297-
else:
298-
bufs = []
299-
for i in range(len(momentum_buffer_list)):
300-
if momentum_buffer_list[i] is None:
301-
buf = momentum_buffer_list[i] = torch.clone(grads[i]).detach()
302-
else:
303-
buf = momentum_buffer_list[i]
304-
buf.mul_(momentum).add_(grads[i], alpha=1 - dampening)
309+
bufs.append(buf)
305310

306-
bufs.append(buf)
311+
if nesterov:
312+
torch._foreach_add_(device_grads, bufs, alpha=momentum)
313+
else:
314+
device_grads = bufs
307315

308-
if nesterov:
309-
torch._foreach_add_(grads, bufs, alpha=momentum)
316+
if not device_has_sparse_grad:
317+
torch._foreach_add_(device_params, device_grads, alpha=-lr)
310318
else:
311-
grads = bufs
312-
313-
if not has_sparse_grad:
314-
torch._foreach_add_(params, grads, alpha=-lr)
315-
else:
316-
# foreach APIs dont support sparse
317-
for i in range(len(params)):
318-
params[i].add_(grads[i], alpha=-lr)
319+
# foreach APIs don't support sparse
320+
for i in range(len(device_params)):
321+
device_params[i].add_(device_grads[i], alpha=-lr)

Diff for: torch/utils/_foreach_utils.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,37 @@
11
from collections import defaultdict
2-
from typing import List, Dict, Tuple
2+
from typing import List, Dict, Tuple, Optional, Union
33

44
import torch
55
from torch import Tensor
66

77

8-
# This util function splits tensors into groups by device and dtype, which is useful before sending
9-
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
10-
# Currently, this function is only used in torch.optim.
8+
# _group_tensors_by_device_and_dtype is a util function that splits tensors into groups by device and dtype,
9+
# which is useful before sending tensors off to a foreach implementation, which requires tensors to be on
10+
# one device and dtype.
1111
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
1212
# - tensorlists CAN be None
1313
# - all tensors in the first specified list cannot be None
1414
# - given an index i, all specified tensorlist[i]s match in dtype and device
15+
# with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry.
16+
# It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
17+
# Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
18+
# original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
19+
# may be necessary. Check out torch/optim/sgd.py for an example.
1520
@torch.no_grad()
16-
def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]]) -> Dict[Tuple[str, torch.dtype], List[List[Tensor]]]:
21+
def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
22+
with_indices: Optional[bool] = False) -> \
23+
Dict[Tuple[str, torch.dtype], List[List[Union[Tensor, int]]]]:
1724
assert all([not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist]), (
1825
"all specified tensorlists must match in length")
19-
per_device_and_dtype_tensors: Dict[Tuple[str, torch.dtype], List[List[Tensor]]] = defaultdict(
20-
lambda: [[] for _ in range(len(tensorlistlist))])
26+
per_device_and_dtype_tensors: Dict[Tuple[str, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
27+
lambda: [[] for _ in range(len(tensorlistlist) + (1 if with_indices else 0))])
2128
for i, t in enumerate(tensorlistlist[0]):
2229
key = (str(t.device), t.dtype)
2330
for j in range(len(tensorlistlist)):
2431
# a tensorlist may be empty/None
2532
if tensorlistlist[j]:
2633
per_device_and_dtype_tensors[key][j].append(tensorlistlist[j][i])
34+
if with_indices:
35+
# tack on previous index
36+
per_device_and_dtype_tensors[key][j + 1].append(i)
2737
return per_device_and_dtype_tensors

0 commit comments

Comments
 (0)