Merge branch 'main' into mypy

runame committed Nov 6, 2024
2 parents 431f0f1 + 397ad17 commit 5172a14
Showing 9 changed files with 283 additions and 382 deletions.
3 changes: 3 additions & 0 deletions distributed_shampoo/examples/ddp_cifar10_example.py
@@ -125,12 +125,15 @@
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)

# checks for checkpointing
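Aside: the two new dtype fields threaded through these example scripts (corrected_eigenvalues_dtype and factor_matrix_eigenvectors_dtype) hold the extra state used by eigenvalue-corrected preconditioning. Below is a minimal sketch of building the precision config directly; it assumes PrecisionConfig is importable from the distributed_shampoo package and that its fields accept torch dtypes, neither of which is shown in this diff.

import torch

from distributed_shampoo import PrecisionConfig  # import path assumed

# All dtype flags default to FP32 in the example scripts; the two fields
# marked "new" are the ones added in this commit.
precision_config = PrecisionConfig(
    computation_dtype=torch.float32,
    factor_matrix_dtype=torch.float32,
    inv_factor_matrix_dtype=torch.float32,
    corrected_eigenvalues_dtype=torch.float32,  # new
    factor_matrix_eigenvectors_dtype=torch.float32,  # new
    filtered_grad_dtype=torch.float32,
    momentum_dtype=torch.float32,
    grafting_state_dtype=torch.float32,
)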
5 changes: 4 additions & 1 deletion distributed_shampoo/examples/default_cifar10_example.py
@@ -108,7 +108,7 @@ def train_default_model(

# instantiate data loader. Note that this is a single GPU training example,
# so we do not need to instantiate a sampler.
data_loader, _ = get_data_loader_and_sampler(args.data_path, 1, 1, args.batch_size)
data_loader, _ = get_data_loader_and_sampler(args.data_path, 1, 0, args.batch_size)

# instantiate optimizer (SGD, Adam, DistributedShampoo)
optimizer = instantiate_optimizer(
@@ -139,12 +139,15 @@ def train_default_model(
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)

# train model
3 changes: 3 additions & 0 deletions distributed_shampoo/examples/fsdp_cifar10_example.py
@@ -120,12 +120,15 @@
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)

# train model
3 changes: 3 additions & 0 deletions distributed_shampoo/examples/fully_shard_cifar10_example.py
@@ -142,12 +142,15 @@ def create_model_and_optimizer_and_loss_fn(args, device):
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)
return model, optimizer, loss_function

3 changes: 3 additions & 0 deletions distributed_shampoo/examples/hsdp_cifar10_example.py
@@ -135,12 +135,15 @@
computation_dtype=args.computation_dtype.value,
factor_matrix_dtype=args.factor_matrix_dtype.value,
inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
filtered_grad_dtype=args.filtered_grad_dtype.value,
momentum_dtype=args.momentum_dtype.value,
grafting_state_dtype=args.grafting_state_dtype.value,
),
use_protected_eigh=args.use_protected_eigh,
track_root_inv_residuals=args.track_root_inv_residuals,
preconditioner_computation_type=args.preconditioner_computation_type,
)

# train model
62 changes: 62 additions & 0 deletions distributed_shampoo/examples/trainer_utils.py
@@ -31,6 +31,13 @@
RMSpropGraftingConfig,
SGDGraftingConfig,
)
from matrix_functions_types import (
CoupledHigherOrderConfig,
CoupledNewtonConfig,
EigenConfig,
EighEigenvalueCorrectionConfig,
PreconditionerComputationConfig,
)
from torch import nn
from torchvision import datasets, transforms # type: ignore[import-untyped]

@@ -62,6 +69,13 @@ class GraftingType(enum.Enum):
ADAM = 4


class PreconditionerComputationType(enum.Enum):
EIGEN_ROOT_INV = 0
COUPLED_NEWTON_ROOT_INV = 1
COUPLED_HIGHER_ORDER_ROOT_INV = 2
EIGH_EIGENVALUE_CORRECTION = 3


###### ARGPARSER ######
def enum_type_parse(s: str, enum_type: enum.Enum):
try:
@@ -195,6 +209,12 @@ def get_args():
action="store_true",
help="Use debug mode for examining root inverse residuals.",
)
parser.add_argument(
"--preconditioner-computation-type",
type=lambda t: enum_type_parse(t, PreconditionerComputationType),
default=PreconditionerComputationType.EIGEN_ROOT_INV,
help="Preconditioner computation method for Shampoo.",
)

# Arguments for grafting.
parser.add_argument(
@@ -235,6 +255,18 @@ def get_args():
default=DType.FP32,
help="Data type for storing Shampoo inverse factor matrices.",
)
parser.add_argument(
"--corrected-eigenvalues-dtype",
type=lambda t: enum_type_parse(t, DType),
default=DType.FP32,
help="Data type for storing corrected eigenvalues of Shampoo preconditioner.",
)
parser.add_argument(
"--factor-matrix-eigenvectors-dtype",
type=lambda t: enum_type_parse(t, DType),
default=DType.FP32,
help="Data type for storing Shampoo factor matrices eigenvectors.",
)
parser.add_argument(
"--filtered-grad-dtype",
type=lambda t: enum_type_parse(t, DType),
@@ -410,6 +442,7 @@ def instantiate_optimizer(
precision_config: Optional[PrecisionConfig],
use_protected_eigh: bool,
track_root_inv_residuals: bool,
preconditioner_computation_type: PreconditionerComputationType,
) -> torch.optim.Optimizer:
if optimizer_type == OptimizerType.SGD:
optimizer = torch.optim.SGD(
@@ -464,6 +497,9 @@
precision_config=precision_config,
use_protected_eigh=use_protected_eigh,
track_root_inv_residuals=track_root_inv_residuals,
preconditioner_computation_config=instantiate_preconditioner_computation_config(
preconditioner_computation_type
),
) # type: ignore[assignment]
else:
raise ValueError(f"Invalid OptimizerType {optimizer_type}!")
@@ -501,6 +537,32 @@ def instantiate_grafting_config(
raise ValueError(f"Invalid GraftingType {grafting_type}!")


def instantiate_preconditioner_computation_config(
preconditioner_computation_type: PreconditionerComputationType,
) -> PreconditionerComputationConfig:
if preconditioner_computation_type == PreconditionerComputationType.EIGEN_ROOT_INV:
return EigenConfig()
elif (
preconditioner_computation_type
== PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV
):
return CoupledNewtonConfig()
elif (
preconditioner_computation_type
== PreconditionerComputationType.COUPLED_HIGHER_ORDER_ROOT_INV
):
return CoupledHigherOrderConfig()
elif (
preconditioner_computation_type
== PreconditionerComputationType.EIGH_EIGENVALUE_CORRECTION
):
return EighEigenvalueCorrectionConfig()
else:
raise ValueError(
f"Invalid PreconditionerComputationType {preconditioner_computation_type}!"
)


###### DATA LOADER ######
def get_data_loader_and_sampler(
data_path: str, world_size: int, rank: int, local_batch_size: int
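Design note: the if/elif chain added in instantiate_preconditioner_computation_config could equally be written as a dictionary lookup. The sketch below is an alternative, not the committed implementation; it reuses the enum and config classes introduced in this file and assumes PreconditionerComputationType is already in scope.

from matrix_functions_types import (
    CoupledHigherOrderConfig,
    CoupledNewtonConfig,
    EigenConfig,
    EighEigenvalueCorrectionConfig,
    PreconditionerComputationConfig,
)

# Map each enum member to its zero-argument config constructor.
_PRECONDITIONER_CONFIG_FACTORIES = {
    PreconditionerComputationType.EIGEN_ROOT_INV: EigenConfig,
    PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV: CoupledNewtonConfig,
    PreconditionerComputationType.COUPLED_HIGHER_ORDER_ROOT_INV: CoupledHigherOrderConfig,
    PreconditionerComputationType.EIGH_EIGENVALUE_CORRECTION: EighEigenvalueCorrectionConfig,
}


def instantiate_preconditioner_computation_config(
    preconditioner_computation_type: PreconditionerComputationType,
) -> PreconditionerComputationConfig:
    try:
        return _PRECONDITIONER_CONFIG_FACTORIES[preconditioner_computation_type]()
    except KeyError:
        raise ValueError(
            f"Invalid PreconditionerComputationType {preconditioner_computation_type}!"
        ) from None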
11 changes: 6 additions & 5 deletions distributed_shampoo/gpu_tests/shampoo_grafting_test.py
@@ -9,6 +9,7 @@

#!/usr/bin/env python3

import math
import unittest
from functools import partial
from itertools import product
@@ -118,7 +119,7 @@ def test_adagrad_grafting_on_quadratic(self) -> None:
momentum=0.0,
max_preconditioner_dim=10,
precondition_frequency=1,
start_preconditioning_step=1000,
start_preconditioning_step=math.inf,
use_decoupled_weight_decay=False,
grafting_config=AdaGradGraftingConfig(
epsilon=1e-10,
@@ -153,7 +154,7 @@ def test_adam_grafting_on_quadratic(self) -> None:
momentum=0.0,
max_preconditioner_dim=10,
precondition_frequency=1,
start_preconditioning_step=1000,
start_preconditioning_step=math.inf,
use_decoupled_weight_decay=False,
grafting_config=AdamGraftingConfig(
beta2=0.999,
Expand Down Expand Up @@ -189,7 +190,7 @@ def test_adamw_grafting_on_quadratic(self) -> None:
momentum=0.0,
max_preconditioner_dim=10,
precondition_frequency=1,
start_preconditioning_step=1000,
start_preconditioning_step=math.inf,
use_decoupled_weight_decay=True,
grafting_config=AdamGraftingConfig(
beta2=0.999,
Expand Down Expand Up @@ -228,7 +229,7 @@ def test_rmsprop_grafting_on_quadratic(self) -> None:
momentum=0.0,
max_preconditioner_dim=10,
precondition_frequency=1,
start_preconditioning_step=1000,
start_preconditioning_step=math.inf,
use_decoupled_weight_decay=False,
grafting_config=RMSpropGraftingConfig(
beta2=0.99,
Expand Down Expand Up @@ -269,7 +270,7 @@ def test_sgd_grafting_on_quadratic(self) -> None:
epsilon=1e-10,
max_preconditioner_dim=10,
precondition_frequency=1,
start_preconditioning_step=1000,
start_preconditioning_step=math.inf,
use_nesterov=use_nesterov,
use_decoupled_weight_decay=False,
grafting_config=SGDGraftingConfig(),
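The start_preconditioning_step change above is what makes these grafting tests exact: with an infinite start step, Shampoo never switches to its own preconditioned update, so every step follows the grafted method. A minimal sketch of the idea (constructor arguments abbreviated and the import path assumed; this is not the test's actual setup):

import math

import torch

from distributed_shampoo import DistributedShampoo, SGDGraftingConfig  # import path assumed

model = torch.nn.Linear(10, 1)

# With start_preconditioning_step=math.inf, preconditioning never begins,
# so the updates follow the SGD grafting configuration.
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    momentum=0.0,
    max_preconditioner_dim=10,
    precondition_frequency=1,
    start_preconditioning_step=math.inf,
    grafting_config=SGDGraftingConfig(),
)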
16 changes: 7 additions & 9 deletions distributed_shampoo/tests/shampoo_test_utils.py
@@ -8,7 +8,7 @@
"""

import itertools
from typing import Optional
from functools import reduce

import torch
from torch import nn
@@ -18,7 +18,7 @@ class _ModelWithLinearAndDeadLayers(nn.Module):
def __init__(
self,
model_linear_layers_dims: tuple[int, ...],
model_dead_layer_dims: Optional[tuple[int, ...]],
model_dead_layer_dims: tuple[int, ...] | None,
bias: bool = False,
) -> None:
super().__init__()
@@ -38,15 +38,13 @@ def __init__(
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
for linear_layer in self.linear_layers:
x = linear_layer(x)
return x
return reduce(lambda x, layer: layer(x), self.linear_layers, x)


def construct_training_problem(
model_linear_layers_dims: tuple[int, ...],
model_dead_layer_dims: Optional[tuple[int, ...]] = (10, 10),
device: Optional[torch.device] = None,
model_dead_layer_dims: tuple[int, ...] | None = (10, 10),
device: torch.device | None = None,
bias: bool = False,
fill: float | tuple[float, ...] = 0.0,
) -> tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]:
@@ -55,8 +53,8 @@
Args:
model_linear_layers_dims (tuple[int, ...]): The dimensions of the model linear layers.
model_dead_layer_dims (Optional[tuple[int, ...]]): The dimensions of the model dead linear layers. (Default: (10, 10))
device (Optional[torch.device]): The device to use. (Default: None)
model_dead_layer_dims (tuple[int, ...] | None): The dimensions of the model dead linear layers. (Default: (10, 10))
device (torch.device | None): The device to use. (Default: None)
bias (bool): Whether to use bias in the linear (non-dead) layers. (Default: False)
fill (float | tuple[float, ...]): The value(s) to fill the model parameters. If a tuple, each element should correspond to one layer. (Default: 0.0)
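For readers less used to functools.reduce, the rewritten forward above is behaviorally identical to the loop it replaces; a standalone toy illustration (layer sizes are arbitrary, not the repository's model):

from functools import reduce

import torch
from torch import nn

layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
x = torch.randn(2, 4)

# Threads x through each layer in order, exactly like
# `for layer in layers: x = layer(x)`.
y = reduce(lambda acc, layer: layer(acc), layers, x)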