
Commit 0035ff0

Merge pull request #343 from Modalities/341-add-gradientclipper-tests
test: added test for gradient clipping in fsdpX
2 parents 33a45b2 + 97c3456 commit 0035ff0

File tree

1 file changed: +160 -0 lines changed


tests/test_gradient_clipping.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import types
from unittest.mock import MagicMock

import torch

from modalities.training.gradient_clipping.fsdp_gradient_clipper import (
    DummyGradientClipper,
    FSDP1GradientClipper,
    FSDP1LoggingOnlyGradientClipper,
    FSDP2GradientClipper,
    FSDP2LoggingOnlyGradientClipper,
    GradientClippingMode,
)


class MockFSDPModel:
    def __init__(self):
        self.param1 = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
        self.param2 = torch.nn.Parameter(torch.tensor([3.0, 4.0]))
        self.param1.grad = torch.tensor([1.0, 1.0])
        self.param2.grad = torch.tensor([1.0, 1.0])

    def parameters(self):
        return [self.param1, self.param2]


# Test for FSDP1 gradient clipper
def test_fsdp1_gradient_clipper():
    """
    Test FSDP1GradientClipper's ability to clip gradients correctly.
    Uses a mock model with a dynamically added clip_grad_norm_ method to verify norm calculation and gradient scaling.
    """
    mock_model = MockFSDPModel()
    max_norm = 1.0
    norm_type = GradientClippingMode.P2_NORM

    # FSDP1GradientClipper delegates clipping to the wrapped model's clip_grad_norm_ method,
    # which MockFSDPModel does not implement. Define an equivalent implementation here and
    # attach it to the mock for this test only.
    def clip_grad_norm_(self, max_norm, norm_type):
        params = [p for p in self.parameters() if p.grad is not None]
        total_norm = torch.norm(torch.stack([torch.norm(p.grad, norm_type) for p in params]), norm_type)
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for p in params:
                p.grad.data.mul_(clip_coef)
        return total_norm

    # Dynamically add the method for this test
    mock_model.clip_grad_norm_ = types.MethodType(clip_grad_norm_, mock_model)

    clipper = FSDP1GradientClipper(wrapped_model=mock_model, max_norm=max_norm, norm_type=norm_type)
    norm = clipper.clip_gradients()

    # Expected norm before clipping: sqrt(1^2 + 1^2 + 1^2 + 1^2) = 2.0
    expected_norm = torch.tensor(2.0)
    assert torch.allclose(norm, expected_norm), f"Expected norm {expected_norm}, got {norm}"

    # Gradients should be scaled by max_norm / total_norm = 1.0 / 2.0 = 0.5
    expected_grad = torch.tensor([0.5, 0.5])
    for param in mock_model.parameters():
        assert torch.allclose(param.grad, expected_grad), f"Expected grad {expected_grad}, got {param.grad}"


def test_fsdp1_logging_only_gradient_clipper():
    """
    Test that FSDP1LoggingOnlyGradientClipper calls clip_grad_norm_ with max_norm=torch.inf,
    ensuring no clipping occurs, and returns the gradient norm.
    """
    # Create a mock FSDP1 model
    mock_model = MagicMock()
    norm_type = GradientClippingMode.P2_NORM
    clipper = FSDP1LoggingOnlyGradientClipper(wrapped_model=mock_model, norm_type=norm_type)

    # Call clip_gradients
    clipper.clip_gradients()

    # Verify that clip_grad_norm_ was called with max_norm=torch.inf
    mock_model.clip_grad_norm_.assert_called_once_with(max_norm=torch.inf, norm_type=norm_type.value)


def test_fsdp2_clip_grad_norm():
    """
    Test the static clip_grad_norm_ method in FSDP2GradientClipper to ensure it correctly
    computes the gradient norm and clips gradients when necessary.
    """
    # Create parameters with gradients
    mock_model = MockFSDPModel()

    # Compute expected total norm (Euclidean norm, norm_type=2)
    expected_norm = (1**2 + 1**2 + 1**2 + 1**2) ** 0.5  # sqrt(4) = 2.0

    # Test case 1: max_norm > total_norm (no clipping)
    max_norm = expected_norm + 1  # 3.0
    norm = FSDP2GradientClipper.clip_grad_norm_(parameters=mock_model.parameters(), max_norm=max_norm, norm_type=2.0)
    assert torch.allclose(norm, torch.tensor(expected_norm)), "Norm should match expected total norm"
    assert torch.allclose(mock_model.param1.grad, torch.tensor([1.0, 1.0])), "Gradients should not be clipped"
    assert torch.allclose(mock_model.param2.grad, torch.tensor([1.0, 1.0])), "Gradients should not be clipped"

    # Test case 2: max_norm < total_norm (clipping occurs)
    max_norm = expected_norm / 2  # 1.0
    norm = FSDP2GradientClipper.clip_grad_norm_(parameters=mock_model.parameters(), max_norm=max_norm, norm_type=2.0)
    assert torch.allclose(norm, torch.tensor(expected_norm)), "Norm should match pre-clipping total norm"
    scale = max_norm / expected_norm  # 1.0 / 2.0 = 0.5
    expected_grad = torch.tensor([1.0 * scale, 1.0 * scale])
    assert torch.allclose(mock_model.param1.grad, expected_grad), "Gradients should be clipped"
    assert torch.allclose(mock_model.param2.grad, expected_grad), "Gradients should be clipped"


def test_fsdp2_gradient_clipper():
    """
    Test that FSDP2GradientClipper correctly calls clip_grad_norm_ on the wrapped model's parameters.
    """
    # Create a mock FSDP2 model with parameters
    mock_model = MockFSDPModel()

    max_norm = 1.0
    norm_type = GradientClippingMode.P2_NORM
    clipper = FSDP2GradientClipper(wrapped_model=mock_model, max_norm=max_norm, norm_type=norm_type)

    # Call clip_gradients
    norm = clipper.clip_gradients()

    expected_norm = (1**2 + 1**2 + 1**2 + 1**2) ** 0.5  # 2.0
    assert torch.allclose(norm, torch.tensor(expected_norm)), "Norm should match expected total norm"

    scale = max_norm / expected_norm  # 0.5
    expected_grad = torch.tensor([1.0 * scale, 1.0 * scale])
    for param in mock_model.parameters():
        assert torch.allclose(param.grad, expected_grad), "Gradients should be clipped"


def test_fsdp2_logging_only_gradient_clipper():
    """
    Test that FSDP2LoggingOnlyGradientClipper computes the gradient norm without clipping.
    """
    mock_model = MockFSDPModel()

    norm_type = GradientClippingMode.P2_NORM
    clipper = FSDP2LoggingOnlyGradientClipper(wrapped_model=mock_model, norm_type=norm_type)

    # Call clip_gradients
    norm = clipper.clip_gradients()

    # Verify the norm and that gradients are unchanged
    expected_norm = (1**2 + 1**2 + 1**2 + 1**2) ** 0.5  # 2.0
    assert torch.allclose(norm, torch.tensor(expected_norm)), "Norm should match expected total norm"
    for param in mock_model.parameters():
        assert torch.allclose(param.grad, torch.tensor([1.0, 1.0])), "Gradients should not be modified"


def test_dummy_gradient_clipper():
    """
    Test that DummyGradientClipper returns a tensor with -1.0 and does not affect gradients.
    """
    clipper = DummyGradientClipper()
    norm = clipper.clip_gradients()
    assert torch.allclose(norm, torch.tensor([-1.0])), "Norm should be -1.0 indicating no clipping"
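The new tests depend only on the in-file MockFSDPModel and unittest.mock.MagicMock (no GPU or distributed setup is required), so they should run in isolation with the repository's usual pytest workflow; the exact invocation below is an assumption, not part of the commit:

    pytest tests/test_gradient_clipping.py -v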

0 commit comments
