Skip to content

Commit 8f5ed91

Browse files
Author: Michael Fromm (committed)
Commit message: formatting fix
Parent: 061a7bc — Commit: 8f5ed91

File tree

1 file changed: +20 additions, -12 deletions

tests/test_gradient_clipping.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
1-
from modalities.training.gradient_clipping.fsdp_gradient_clipper import DummyGradientClipper, FSDP1GradientClipper, FSDP1LoggingOnlyGradientClipper, FSDP2GradientClipper, FSDP2LoggingOnlyGradientClipper, GradientClippingMode
2-
import torch
31
from unittest.mock import MagicMock
42

3+
import torch
4+
5+
from modalities.training.gradient_clipping.fsdp_gradient_clipper import (
6+
DummyGradientClipper,
7+
FSDP1GradientClipper,
8+
FSDP1LoggingOnlyGradientClipper,
9+
FSDP2GradientClipper,
10+
FSDP2LoggingOnlyGradientClipper,
11+
GradientClippingMode,
12+
)
13+
14+
515
class MockFSDP2Model:
616
def __init__(self):
717
self.param1 = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
@@ -11,8 +21,11 @@ def __init__(self):
1121

1222
def parameters(self):
1323
return [self.param1, self.param2]
24+
25+
1426
# Note: Replace 'your_module' above with the correct module path where the gradient clipping classes are defined.
1527

28+
1629
def test_fsdp1_gradient_clipper():
1730
"""
1831
Test that FSDP1GradientClipper correctly calls the wrapped model's clip_grad_norm_ method
@@ -25,7 +38,7 @@ def test_fsdp1_gradient_clipper():
2538
clipper = FSDP1GradientClipper(wrapped_model=mock_model, max_norm=max_norm, norm_type=norm_type)
2639

2740
# Call clip_gradients
28-
norm = clipper.clip_gradients()
41+
clipper.clip_gradients()
2942

3043
# Verify that clip_grad_norm_ was called with the correct arguments
3144
mock_model.clip_grad_norm_.assert_called_once_with(max_norm=max_norm, norm_type=norm_type.value)
@@ -43,7 +56,7 @@ def test_fsdp1_logging_only_gradient_clipper():
4356
clipper = FSDP1LoggingOnlyGradientClipper(wrapped_model=mock_model, norm_type=norm_type)
4457

4558
# Call clip_gradients
46-
norm = clipper.clip_gradients()
59+
clipper.clip_gradients()
4760

4861
# Verify that clip_grad_norm_ was called with max_norm=torch.inf
4962
mock_model.clip_grad_norm_.assert_called_once_with(max_norm=torch.inf, norm_type=norm_type.value)
@@ -66,18 +79,14 @@ def test_fsdp2_clip_grad_norm():
6679

6780
# Test case 1: max_norm > total_norm (no clipping)
6881
max_norm = expected_norm + 1 # 3.0
69-
norm = FSDP2GradientClipper.clip_grad_norm_(
70-
parameters=parameters, max_norm=max_norm, norm_type=2.0
71-
)
82+
norm = FSDP2GradientClipper.clip_grad_norm_(parameters=parameters, max_norm=max_norm, norm_type=2.0)
7283
assert torch.allclose(norm, torch.tensor(expected_norm)), "Norm should match expected total norm"
7384
assert torch.allclose(param1.grad, torch.tensor([1.0, 1.0])), "Gradients should not be clipped"
7485
assert torch.allclose(param2.grad, torch.tensor([1.0, 1.0])), "Gradients should not be clipped"
7586

7687
# Test case 2: max_norm < total_norm (clipping occurs)
7788
max_norm = expected_norm / 2 # 1.0
78-
norm = FSDP2GradientClipper.clip_grad_norm_(
79-
parameters=parameters, max_norm=max_norm, norm_type=2.0
80-
)
89+
norm = FSDP2GradientClipper.clip_grad_norm_(parameters=parameters, max_norm=max_norm, norm_type=2.0)
8190
assert torch.allclose(norm, torch.tensor(expected_norm)), "Norm should match pre-clipping total norm"
8291
scale = max_norm / expected_norm # 1.0 / 2.0 = 0.5
8392
expected_grad = torch.tensor([1.0 * scale, 1.0 * scale])
@@ -91,7 +100,6 @@ def test_fsdp2_gradient_clipper():
91100
"""
92101
# Create a mock FSDP2 model with parameters
93102

94-
95103
mock_model = MockFSDP2Model()
96104
max_norm = 1.0
97105
norm_type = GradientClippingMode.P2_NORM
@@ -133,4 +141,4 @@ def test_dummy_gradient_clipper():
133141
"""
134142
clipper = DummyGradientClipper()
135143
norm = clipper.clip_gradients()
136-
assert torch.allclose(norm, torch.tensor([-1.0])), "Norm should be -1.0 indicating no clipping"
144+
assert torch.allclose(norm, torch.tensor([-1.0])), "Norm should be -1.0 indicating no clipping"

Commit comments: 0