Refactor the input argument of BlockInfo.allocate_zeros_tensor() (#69)
Summary:
Pull Request resolved: #69

1. Change the first argument from `shape` to `size` so the default value of `BlockInfo.allocate_zeros_tensor` can be as simple as `partial(torch.zeros)` (see the sketch just after this list). This was also reported in OSS (#25).
2. Comment and typing fixes and improvements.
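As a quick illustration of point 1, here is a minimal sketch (illustrative shapes and dtypes, not code from this diff) of why the keyword must be `size` for the default allocator to collapse to `partial(torch.zeros)`:

```python
import torch
from functools import partial

# torch.zeros names its shape argument `size`, so once every call site passes
# `size=...` the default allocator no longer needs an adapter lambda.
allocate_zeros_tensor = partial(torch.zeros)

buf = allocate_zeros_tensor(size=(3, 4), dtype=torch.float32, device=torch.device("cpu"))
print(buf.shape)  # torch.Size([3, 4])

# The previous default had to translate the keyword by hand:
#   lambda shape, dtype, device: torch.zeros(size=shape, dtype=dtype, device=device)
```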

Reviewed By: chuanhaozhuge

Differential Revision: D67550399

fbshipit-source-id: d9a01fc6e361d7a006e5c5cb8522156f22cd27a9
tsunghsienlee authored and facebook-github-bot committed Dec 22, 2024
1 parent 50d0649 commit 9a43622
Showing 7 changed files with 30 additions and 40 deletions.
4 changes: 2 additions & 2 deletions distributed_shampoo/distributed_shampoo.py
@@ -604,7 +604,7 @@ def _instantiate_momentum(self) -> None:
block_state = self.state[block_info.param][block_index]

block_state[MOMENTUM] = block_info.allocate_zeros_tensor(
- shape=block.size(),
+ size=block.size(),
dtype=block.dtype,
device=block.device,
)
@@ -647,7 +647,7 @@ def _instantiate_filtered_grads(self) -> None:
block_state = self.state[block_info.param][block_index]

block_state[FILTERED_GRAD] = block_info.allocate_zeros_tensor(
- shape=block.size(),
+ size=block.size(),
dtype=block.dtype,
device=block.device,
)
15 changes: 7 additions & 8 deletions distributed_shampoo/utils/shampoo_block_info.py
@@ -9,6 +9,7 @@

from collections.abc import Callable
from dataclasses import dataclass
+ from functools import partial

import torch
from torch import Tensor
@@ -33,21 +34,19 @@ class BlockInfo:
parameter p1 on rank 0 will have the composable_block_ids being (0, "rank_0-block_0"), while block 0 of parameter p1
on rank 1 will have composable_block_ids being (0, "rank_1-block_0").
- allocate_zeros_tensor (Callable): A function that returns a zero-initialized tensor.
+ allocate_zeros_tensor (Callable[..., Tensor]): A function that returns a zero-initialized tensor.
This tensor must be saved in the state dictionary for checkpointing.
This tensor might be DTensor. get_tensor() must be used to access the value.
- Its function signature is (shape, dtype, device) -> Tensor.
- (Default: lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device))
- get_tensor (Callable): A function that takes a tensor allocated by allocator and returns its local tensor.
- This tensor might be DTensor. get_tensor() must be used to access the value.
+ Its function signature is (size, dtype, device) -> Tensor.
+ (Default: lambda size, dtype, device: torch.zeros(size, dtype=dtype, device=device))
+ get_tensor (Callable[..., Tensor]): A function that takes a tensor allocated by allocator and returns its local tensor.
Its function signature is (tensor: Tensor) -> Tensor.
(Default: lambda tensor: tensor)
"""

param: Tensor
composable_block_ids: tuple[int, str]
- allocate_zeros_tensor: Callable[..., Tensor] = (
-     lambda shape, dtype, device: torch.zeros(size=shape, dtype=dtype, device=device)
- )
+ allocate_zeros_tensor: Callable[..., Tensor] = partial(torch.zeros)
get_tensor: Callable[..., Tensor] = lambda tensor_obj: tensor_obj


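A usage sketch of the refactored dataclass (assuming the module path shown in the file header above; the shapes and block ids here are illustrative, not taken from the commit):

```python
import torch
from distributed_shampoo.utils.shampoo_block_info import BlockInfo

param = torch.nn.Parameter(torch.randn(8, 8))
info = BlockInfo(param=param, composable_block_ids=(0, "rank_0-block_0"))

# The default allocator is now partial(torch.zeros), so it is called with the
# same (size, dtype, device) keywords that the distributors use.
state_buf = info.allocate_zeros_tensor(size=(4, 4), dtype=param.dtype, device=param.device)

# get_tensor() is the identity by default; it only matters when the allocator
# returns a DTensor and the local shard has to be extracted.
assert info.get_tensor(state_buf).shape == (4, 4)
```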
13 changes: 5 additions & 8 deletions distributed_shampoo/utils/shampoo_ddp_distributor.py
@@ -460,20 +460,17 @@ def merge_and_block_gradients(

def _allocate_zeros_distributed_tensor(
self,
- shape: tuple[int, ...],
+ size: tuple[int, ...],
dtype: torch.dtype,
device: torch.device,
group_source_rank: int,
) -> torch.Tensor:
"""Instantiates distributed tensor using DTensor.
Args:
- shape (shape type accepted by torch.zeros() including tuple[int, ...]):
-     Shape of desired tensor.
- dtype (dtype type accepted by torch.zeros() including torch.dtype):
-     DType of desired tensor.
- device (device type accepted by torch.zeros() including torch.device):
-     Device of desired tensor.
+ size (tuple[int, ...]): Shape of desired tensor.
+ dtype (torch.dtype): DType of desired tensor.
+ device (torch.device): Device of desired tensor.
group_source_rank (int): Desired source rank of allocated zeros tensor within the process group.
Returns:
Expand All @@ -489,7 +486,7 @@ def _allocate_zeros_distributed_tensor(
)

return dtensor_zeros(
- shape,
+ size,
dtype=dtype,
device_mesh=get_device_mesh(
device_type=device.type, mesh=device_mesh_ranks
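Since `_allocate_zeros_distributed_tensor` still takes the extra `group_source_rank` argument, a distributor would presumably bind it up front (e.g. with `functools.partial`) so the stored callable keeps the plain `(size, dtype, device)` interface of `BlockInfo.allocate_zeros_tensor`. A self-contained sketch of that pattern, using a stand-in function rather than the real method:

```python
import torch
from functools import partial

def allocate_zeros_distributed_tensor(
    size: tuple[int, ...],
    dtype: torch.dtype,
    device: torch.device,
    group_source_rank: int,
) -> torch.Tensor:
    # Stand-in for the method above; the real implementation returns a DTensor
    # placed according to the source rank's device mesh.
    return torch.zeros(size, dtype=dtype, device=device)

# Binding group_source_rank leaves a callable with the same keyword interface
# as partial(torch.zeros), so BlockInfo call sites stay identical.
allocator = partial(allocate_zeros_distributed_tensor, group_source_rank=0)
buf = allocator(size=(2, 3), dtype=torch.float32, device=torch.device("cpu"))
assert buf.shape == (2, 3)
```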
13 changes: 5 additions & 8 deletions distributed_shampoo/utils/shampoo_hsdp_distributor.py
@@ -913,20 +913,17 @@ def block_within_tensor_shard_recovery(

def _allocate_zeros_distributed_tensor(
self,
- shape: tuple[int, ...],
+ size: tuple[int, ...],
dtype: torch.dtype,
device: torch.device,
group_source_rank: int,
) -> torch.Tensor:
"""Instantiates distributed tensor using DTensor.
Args:
- shape (shape type accepted by torch.zeros() including tuple[int, ...]):
-     Shape of desired tensor.
- dtype (dtype type accepted by torch.zeros() including torch.dtype):
-     DType of desired tensor.
- device (device type accepted by torch.zeros() including torch.device):
-     Device of desired tensor.
+ size (tuple[int, ...]): Shape of desired tensor.
+ dtype (torch.dtype): DType of desired tensor.
+ device (torch.device): Device of desired tensor.
group_source_rank (int): Group rank (with respect to the sharded group of
the 2D submesh) that determines which ranks the DTensor is allocated on.
@@ -958,7 +955,7 @@ def _allocate_zeros_distributed_tensor(
)[group_source_rank]

return dtensor_zeros(
- shape,
+ size,
dtype=dtype,
device_mesh=replicate_submesh,
placements=[dtensor.Replicate()],
13 changes: 5 additions & 8 deletions distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py
@@ -580,20 +580,17 @@ def merge_and_block_gradients(

def _allocate_zeros_distributed_tensor(
self,
- shape: tuple[int, ...],
+ size: tuple[int, ...],
dtype: torch.dtype,
device: torch.device,
group_source_rank: int,
) -> torch.Tensor:
"""Instantiates distributed tensor using DTensor.
Args:
- shape (shape type accepted by torch.zeros() including tuple[int, ...]):
-     Shape of desired tensor.
- dtype (dtype type accepted by torch.zeros() including torch.dtype):
-     DType of desired tensor.
- device (device type accepted by torch.zeros() including torch.device):
-     Device of desired tensor.
+ size (tuple[int, ...]): Shape of desired tensor.
+ dtype (torch.dtype): DType of desired tensor.
+ device (torch.device): Device of desired tensor.
group_source_rank (int): Group rank (with respect to the sharded group of
the 2D submesh) that determines which ranks the DTensor is allocated on.
@@ -625,7 +622,7 @@ def _allocate_zeros_distributed_tensor(
)[group_source_rank]

return dtensor_zeros(
- shape,
+ size,
dtype=dtype,
device_mesh=replicate_submesh,
placements=[dtensor.Replicate()],
10 changes: 5 additions & 5 deletions distributed_shampoo/utils/shampoo_preconditioner_list.py
@@ -186,7 +186,7 @@ def __init__(
# Instantiate AdaGrad optimizer state for this block.
preconditioner_index = str(param_index) + "." + str(block_index)
block_state[ADAGRAD] = block_info.allocate_zeros_tensor(
- shape=block.size(),
+ size=block.size(),
dtype=block.dtype,
device=block.device,
)
@@ -436,7 +436,7 @@ def _create_base_kronecker_factors(
"""
factor_matrices = tuple(
block_info.allocate_zeros_tensor(
- shape=(dim, dim),
+ size=(dim, dim),
dtype=self._factor_matrix_dtype,
device=block_info.param.device,
)
@@ -773,7 +773,7 @@ def _create_kronecker_factors_state_for_block(
) -> ShampooKroneckerFactorsState:
inv_factor_matrices = tuple(
block_info.allocate_zeros_tensor(
- shape=(dim, dim),
+ size=(dim, dim),
dtype=block.dtype,
device=block_info.param.device,
)
@@ -930,14 +930,14 @@ def _create_kronecker_factors_state_for_block(
) -> EigenvalueCorrectedShampooKroneckerFactorsState:
factor_matrices_eigenvectors = tuple(
block_info.allocate_zeros_tensor(
- shape=(dim, dim),
+ size=(dim, dim),
dtype=block.dtype,
device=block_info.param.device,
)
for dim in dims
)
corrected_eigenvalues = block_info.allocate_zeros_tensor(
- shape=tuple(dims),
+ size=tuple(dims),
dtype=block.dtype,
device=block_info.param.device,
)
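The preconditioner call sites above only rename the keyword; the shapes are unchanged. For a block with dimensions `dims = (m, n)`, each Kronecker factor is `(dim, dim)` and the corrected eigenvalues keep the full block shape. A small illustrative check, with plain `torch.zeros` standing in for the block allocator:

```python
import torch

block = torch.zeros(4, 6)      # hypothetical parameter block
dims = tuple(block.size())     # (4, 6)

# One square factor matrix per tensor dimension, matching size=(dim, dim) above.
factor_matrices = tuple(torch.zeros(size=(dim, dim), dtype=block.dtype) for dim in dims)
assert [f.shape for f in factor_matrices] == [torch.Size([4, 4]), torch.Size([6, 6])]

# Corrected eigenvalues use the full block shape, matching size=tuple(dims) above.
corrected_eigenvalues = torch.zeros(size=dims, dtype=block.dtype)
assert corrected_eigenvalues.shape == (4, 6)
```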
2 changes: 1 addition & 1 deletion distributed_shampoo/utils/shampoo_quantization.py
@@ -62,7 +62,7 @@ def init_from_dequantized_tensor(
block_info: BlockInfo,
) -> "QuantizedTensor":
quantized_values = block_info.allocate_zeros_tensor(
- shape=dequantized_values.shape,
+ size=dequantized_values.shape,
dtype=quantized_dtype,
device=dequantized_values.device,
)
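In the quantization call site, `dequantized_values.shape` is a `torch.Size` (a tuple subclass), so it passes straight through as the new `size=` keyword. A minimal illustrative check with `torch.zeros` in place of the block allocator:

```python
import torch

dequantized_values = torch.randn(5, 7)

quantized_values = torch.zeros(
    size=dequantized_values.shape,    # torch.Size works directly as size=
    dtype=torch.int8,                 # hypothetical quantized dtype
    device=dequantized_values.device,
)
assert quantized_values.shape == dequantized_values.shape
```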
