
Commit

Docstrings added to nearly all methods and classes where appropriate. A few unit tests added, focused on loading and saving models.
jsschreck committed Aug 12, 2024
1 parent 4038e93 commit 2029ab1
Showing 10 changed files with 843 additions and 93 deletions.
118 changes: 118 additions & 0 deletions mlguess/tests/test_torch_checkpoint.py
@@ -0,0 +1,118 @@
import unittest
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import os
from mlguess.torch.checkpoint import TorchFSDPModel, FSDPOptimizerWrapper, TorchFSDPCheckpointIO


class TestFSDPCheckpointing(unittest.TestCase):
@classmethod
def setUpClass(cls):
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'

if not torch.cuda.is_available():
raise unittest.SkipTest("GPU is not available. Skipping torch.distributed tests.")

dist.init_process_group("nccl", rank=0, world_size=1)

def setUp(self):
self.model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2)
).to("cuda")
self.fsdp_model = TorchFSDPModel(self.model)
self.optimizer = optim.SGD(self.fsdp_model.parameters(), lr=0.01)
self.fsdp_optimizer = FSDPOptimizerWrapper(self.optimizer, self.fsdp_model)
self.checkpoint_io = TorchFSDPCheckpointIO()
self.rank = dist.get_rank()

def test_save_and_load_model(self):
# Save model
self.checkpoint_io.save_unsharded_model(self.fsdp_model, "test_model_checkpoint.pth", rank=self.rank, gather_dtensor=True, use_safetensors=False)

# Load model
loaded_model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2)
).to("cuda")
fsdp_loaded_model = TorchFSDPModel(loaded_model)
self.checkpoint_io.load_unsharded_model(fsdp_loaded_model, "test_model_checkpoint.pth")

# Compare models
original_state_dict = self.fsdp_model.state_dict()
loaded_state_dict = fsdp_loaded_model.state_dict()
for key in original_state_dict:
if key != '_flat_param':
self.assertTrue(torch.allclose(original_state_dict[key], loaded_state_dict[key]), f"Loaded model parameter {key} does not match original")

def test_save_and_load_optimizer(self):
# Save optimizer
self.checkpoint_io.save_unsharded_optimizer(self.fsdp_optimizer, "test_optimizer_checkpoint.pth", rank=self.rank, gather_dtensor=True)

# Load optimizer
loaded_model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2)
).to("cuda")
fsdp_loaded_model = TorchFSDPModel(loaded_model)
loaded_optimizer = optim.SGD(fsdp_loaded_model.parameters(), lr=0.01)
fsdp_loaded_optimizer = FSDPOptimizerWrapper(loaded_optimizer, fsdp_loaded_model)
self.checkpoint_io.load_unsharded_optimizer(fsdp_loaded_optimizer, "test_optimizer_checkpoint.pth")

# Compare optimizers
original_state_dict = self.fsdp_optimizer.optim.state_dict()
loaded_state_dict = fsdp_loaded_optimizer.optim.state_dict()

self.assertEqual(original_state_dict['param_groups'], loaded_state_dict['param_groups'], "Optimizer param_groups do not match")

for key in original_state_dict['state']:
self.assertTrue(key in loaded_state_dict['state'], f"Key {key} not found in loaded optimizer state")
for param_key, param_value in original_state_dict['state'][key].items():
if isinstance(param_value, torch.Tensor):
self.assertTrue(torch.allclose(param_value, loaded_state_dict['state'][key][param_key]),
f"Optimizer state for key {key}, param {param_key} does not match")
else:
self.assertEqual(param_value, loaded_state_dict['state'][key][param_key],
f"Optimizer state for key {key}, param {param_key} does not match")

def test_load_model_outside_fsdp(self):
# Save model
self.checkpoint_io.save_unsharded_model(self.fsdp_model, "test_model_checkpoint.pth", rank=self.rank, gather_dtensor=True, use_safetensors=False)

# Load model outside FSDP context
indy_model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2)
).to("cuda")
checkpoint = torch.load("test_model_checkpoint.pth")
indy_model.load_state_dict(checkpoint)

# Compare models
fsdp_state_dict = self.fsdp_model.state_dict()
indy_state_dict = indy_model.state_dict()

for (fsdp_key, fsdp_value), (indy_key, indy_value) in zip(fsdp_state_dict.items(), indy_state_dict.items()):
if '_flat_param' not in fsdp_key and '_flat_param' not in indy_key:
self.assertTrue(torch.allclose(fsdp_value, indy_value), f"Parameter {fsdp_key} does not match {indy_key}")

def tearDown(self):
# Clean up checkpoint files
if os.path.exists("test_model_checkpoint.pth"):
os.remove("test_model_checkpoint.pth")
if os.path.exists("test_optimizer_checkpoint.pth"):
os.remove("test_optimizer_checkpoint.pth")

@classmethod
def tearDownClass(cls):
dist.destroy_process_group()


if __name__ == '__main__':
unittest.main()
106 changes: 106 additions & 0 deletions mlguess/tests/test_torch_model.py
@@ -0,0 +1,106 @@
import unittest
import torch
import numpy as np
from mlguess.torch.models import DNN


class TestDNNModel(unittest.TestCase):

def test_evidential_regression_model(self):
for n_outputs in range(1, 3):
input_size = 10
x_train = torch.rand(10000, input_size)
y_train = torch.rand(10000, n_outputs)

model1 = DNN(input_size=input_size,
output_size=n_outputs,
layer_size=[100, 50],
dr=[0.2, 0.2],
batch_norm=True,
lng=True) # Using LinearNormalGamma for evidential regression

model2 = DNN(input_size=input_size,
output_size=n_outputs,
layer_size=[100, 50],
dr=[0.2, 0.2],
batch_norm=True,
                         lng=False)  # Standard network; no LinearNormalGamma (evidential) head

            # Training is skipped here; these tests only check prediction shapes
            # and saving/loading of model weights, not learning.

with torch.no_grad():
p_with_uncertainty = model1.predict(x_train, return_uncertainties=True)
p_without_uncertainty = model2.predict(x_train, return_uncertainties=False)

self.assertEqual(len(p_with_uncertainty), 4) # mu, aleatoric, epistemic, total
self.assertEqual(p_without_uncertainty.shape[-1], n_outputs)

# Save and load model
torch.save(model1.state_dict(), "test_evi_regression.pt")
loaded_model = DNN(
input_size=input_size,
output_size=n_outputs,
layer_size=[100, 50],
dr=[0.2, 0.2],
batch_norm=True,
lng=True) # Using LinearNormalGamma for evidential regression
loaded_model.load_state_dict(torch.load("test_evi_regression.pt"))

def test_standard_nn_regression(self):
for n_outputs in range(1, 3):
input_size = 10
x_train = torch.rand(10000, input_size)
y_train = torch.rand(10000, n_outputs)

model = DNN(input_size=input_size,
output_size=n_outputs,
layer_size=[100, 50],
dr=[0.2, 0.2],
batch_norm=True,
lng=False) # Standard neural network

            # Training is skipped here as well; only the forward pass and save/load are exercised.

with torch.no_grad():
p_without_uncertainty = model.predict(x_train, return_uncertainties=False)

self.assertEqual(p_without_uncertainty.shape[-1], n_outputs)

# Save and load model
torch.save(model.state_dict(), "test_standard_regression.pt")
loaded_model = DNN(
input_size=input_size,
output_size=n_outputs,
layer_size=[100, 50],
dr=[0.2, 0.2],
batch_norm=True,
lng=False
) # Standard neural network
loaded_model.load_state_dict(torch.load("test_standard_regression.pt"))

def test_monte_carlo_dropout(self):
input_size = 10
output_size = 2
x_train = torch.rand(1000, input_size)

model = DNN(input_size=input_size,
output_size=output_size,
layer_size=[100, 50],
dr=[0.2, 0.2],
batch_norm=True,
lng=False)

pred_probs, aleatoric, epistemic, entropy, mutual_info = model.predict_dropout(x_train)

self.assertEqual(pred_probs.shape, (1000, output_size))
self.assertEqual(aleatoric.shape, (1000, output_size))
self.assertEqual(epistemic.shape, (1000, output_size))
self.assertEqual(entropy.shape, (1000,))
self.assertEqual(mutual_info.shape, (1000,))


if __name__ == "__main__":
unittest.main()
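
The regression tests above deliberately skip training and only exercise the forward pass and weight save/load. For context, here is a minimal, generic training-loop sketch for the same DNN; it uses a plain MSE objective as a stand-in for the evidential loss that mlguess provides, so treat it as an illustrative assumption rather than the project's training recipe.

# Illustrative sketch only -- not part of this commit.
# Assumes DNN's forward pass returns a (batch, output_size) tensor when lng=False.
import torch
import torch.nn as nn
from mlguess.torch.models import DNN

model = DNN(input_size=10, output_size=1, layer_size=[100, 50],
            dr=[0.2, 0.2], batch_norm=True, lng=False)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()  # stand-in; lng=True models would use the evidential loss instead

x_train = torch.rand(1000, 10)
y_train = torch.rand(1000, 1)

model.train()
for epoch in range(5):
    optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    optimizer.step()
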
41 changes: 40 additions & 1 deletion mlguess/torch/checkpoint.py
@@ -21,6 +21,29 @@
# utils

def load_model_state(conf, model, device):
"""
Load the model state from a checkpoint file.
This function restores the model state from a saved checkpoint. It supports loading models from
different distributed training modes such as Fully Sharded Data Parallel (FSDP), Distributed Data Parallel (DDP),
or a standard non-distributed setup. Depending on the configuration, it either loads the unsharded model for FSDP
or directly updates the model's state dictionary for other modes.
Args:
conf (dict): Configuration dictionary containing paths and mode information.
- `save_loc` (str): Location where the checkpoint is saved.
- `trainer` (dict): Contains `mode` key which indicates the distributed training mode.
model (torch.nn.Module): The model to load the state into.
device (torch.device): The device to load the model state onto.
Returns:
torch.nn.Module: The model with its state loaded from the checkpoint.
Raises:
FileNotFoundError: If the checkpoint file does not exist at the specified location.
KeyError: If the checkpoint file does not contain the expected keys.
"""

save_loc = os.path.expandvars(conf['save_loc'])
    # Load the optimizer, gradient scaler, and learning-rate scheduler; the optimizer must be created after the model is wrapped with FSDP
ckpt = os.path.join(save_loc, "checkpoint.pt")
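
As a usage sketch of load_model_state with a hypothetical configuration (the keys follow the docstring above; the exact mode strings and model architecture depend on the rest of the repository):

# Illustrative sketch only -- not part of this commit.
import torch
from mlguess.torch.models import DNN
from mlguess.torch.checkpoint import load_model_state

conf = {
    "save_loc": "/path/to/run_dir",   # directory containing checkpoint.pt (hypothetical path)
    "trainer": {"mode": "none"},      # distributed-mode flag; accepted values depend on the implementation
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DNN(input_size=10, output_size=1, layer_size=[100, 50],
            dr=[0.2, 0.2], batch_norm=True, lng=False).to(device)
model = load_model_state(conf, model, device)
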
@@ -142,6 +165,22 @@ def is_safetensors_available() -> bool:


class TorchFSDPCheckpointIO:
"""
Handles loading and saving of checkpoints for models and optimizers
using Fully Sharded Data Parallel (FSDP) in PyTorch.
This class provides methods to load unsharded models and optimizers from
checkpoints, with special handling for FSDP models and optimizers. It
also manages the unwrapping of distributed models and the sharding of
optimizer state dictionaries.
Methods:
load_unsharded_model(model, checkpoint):
Loads the state dictionary into an unsharded model.
load_unsharded_optimizer(optimizer, checkpoint):
Loads the optimizer state dictionary into an unsharded optimizer.
"""
def __init__(self) -> None:
super().__init__()

@@ -175,7 +214,7 @@ def save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor, rank):
Save optimizer to checkpoint but only on master process.
"""
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
full_optimizer_state = FSDP.optim_state_dict(fsdp_model, optim=optimizer)
if rank == 0:
save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
