@@ -2,6 +2,7 @@
 from torch import Tensor
 from .optimizer import Optimizer, required, _use_grad_for_differentiable
 from typing import List, Optional
+from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 
 __all__ = ['SGD', 'sgd']
 
@@ -271,48 +272,50 @@ def _multi_tensor_sgd(params: List[Tensor],
     if len(params) == 0:
         return
 
-    if has_sparse_grad is None:
-        has_sparse_grad = any(grad.is_sparse for grad in grads)
+    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
+    for device_params, device_grads, device_momentum_buffer_list, indices in grouped_tensors.values():
+        device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)
 
-    if maximize:
-        grads = torch._foreach_neg(tuple(grads))  # type: ignore[assignment]
+        if maximize:
+            device_grads = torch._foreach_neg(tuple(device_grads))  # type: ignore[assignment]
 
-    if weight_decay != 0:
-        grads = torch._foreach_add(grads, params, alpha=weight_decay)
+        if weight_decay != 0:
+            device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
+
+        if momentum != 0:
+            bufs = []
 
-    if momentum != 0:
-        bufs = []
+            all_states_with_momentum_buffer = True
+            for i in range(len(device_momentum_buffer_list)):
+                if device_momentum_buffer_list[i] is None:
+                    all_states_with_momentum_buffer = False
+                    break
+                else:
+                    bufs.append(device_momentum_buffer_list[i])
 
-        all_states_with_momentum_buffer = True
-        for i in range(len(momentum_buffer_list)):
-            if momentum_buffer_list[i] is None:
-                all_states_with_momentum_buffer = False
-                break
+            if all_states_with_momentum_buffer:
+                torch._foreach_mul_(bufs, momentum)
+                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
             else:
-                bufs.append(momentum_buffer_list[i])
+                bufs = []
+                for i in range(len(device_momentum_buffer_list)):
+                    if device_momentum_buffer_list[i] is None:
+                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[indices[i]] = \
+                            torch.clone(device_grads[i]).detach()
+                    else:
+                        buf = device_momentum_buffer_list[i]
+                    buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)
 
-            if all_states_with_momentum_buffer:
-                torch._foreach_mul_(bufs, momentum)
-                torch._foreach_add_(bufs, grads, alpha=1 - dampening)
-            else:
-                bufs = []
-                for i in range(len(momentum_buffer_list)):
-                    if momentum_buffer_list[i] is None:
-                        buf = momentum_buffer_list[i] = torch.clone(grads[i]).detach()
-                    else:
-                        buf = momentum_buffer_list[i]
-                    buf.mul_(momentum).add_(grads[i], alpha=1 - dampening)
+                    bufs.append(buf)
 
-                    bufs.append(buf)
+            if nesterov:
+                torch._foreach_add_(device_grads, bufs, alpha=momentum)
+            else:
+                device_grads = bufs
 
-            if nesterov:
-                torch._foreach_add_(grads, bufs, alpha=momentum)
+        if not device_has_sparse_grad:
+            torch._foreach_add_(device_params, device_grads, alpha=-lr)
         else:
-            grads = bufs
-
-    if not has_sparse_grad:
-        torch._foreach_add_(params, grads, alpha=-lr)
-    else:
-        # foreach APIs dont support sparse
-        for i in range(len(params)):
-            params[i].add_(grads[i], alpha=-lr)
+            # foreach APIs don't support sparse
+            for i in range(len(device_params)):
+                device_params[i].add_(device_grads[i], alpha=-lr)
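For readers unfamiliar with the foreach path: `torch._foreach_*` kernels operate on tensor lists whose members all live on one device and share one dtype, so the new loop first partitions `params`, `grads`, and `momentum_buffer_list` with `_group_tensors_by_device_and_dtype` and then issues one fused update per group. Below is a minimal, self-contained sketch of that grouping pattern; it uses a plain dict rather than the private helper, covers only the dense no-momentum update, and the name `_sgd_step_grouped` is hypothetical.

import torch
from collections import defaultdict
from typing import List

def _sgd_step_grouped(params: List[torch.Tensor],
                      grads: List[torch.Tensor],
                      lr: float) -> None:
    # Hypothetical illustration (not the PyTorch helper): bucket tensors by
    # (device, dtype) so each torch._foreach_* call only ever sees a
    # homogeneous tensor list, then apply a plain SGD step per bucket.
    groups = defaultdict(lambda: ([], []))  # (device, dtype) -> (params, grads)
    for p, g in zip(params, grads):
        groups[(p.device, p.dtype)][0].append(p)
        groups[(p.device, p.dtype)][1].append(g)

    for (device, dtype), (group_params, group_grads) in groups.items():
        # Fused update: equivalent to p.add_(g, alpha=-lr) for every pair,
        # but issued as a single foreach kernel per group.
        torch._foreach_add_(group_params, group_grads, alpha=-lr)

# Example: a float32 and a float64 parameter land in separate groups.
params = [torch.zeros(3), torch.zeros(3, dtype=torch.float64)]
grads = [torch.ones(3), torch.ones(3, dtype=torch.float64)]
_sgd_step_grouped(params, grads, lr=0.1)

One detail worth noting in the diff itself: when a momentum buffer is created lazily inside a device group, it is also written back through `momentum_buffer_list[indices[i]]`, so the caller's state list, which is ordered like the original `params`, still picks up the newly created buffer.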