@@ -1,7 +1,17 @@
-from modalities.training.gradient_clipping.fsdp_gradient_clipper import DummyGradientClipper, FSDP1GradientClipper, FSDP1LoggingOnlyGradientClipper, FSDP2GradientClipper, FSDP2LoggingOnlyGradientClipper, GradientClippingMode
-import torch
 from unittest.mock import MagicMock
 
+import torch
+
+from modalities.training.gradient_clipping.fsdp_gradient_clipper import (
+    DummyGradientClipper,
+    FSDP1GradientClipper,
+    FSDP1LoggingOnlyGradientClipper,
+    FSDP2GradientClipper,
+    FSDP2LoggingOnlyGradientClipper,
+    GradientClippingMode,
+)
+
+
 class MockFSDP2Model:
     def __init__(self):
         self.param1 = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
@@ -11,8 +21,11 @@ def __init__(self):
 
     def parameters(self):
         return [self.param1, self.param2]
+
+
 # Note: Replace 'your_module' above with the correct module path where the gradient clipping classes are defined.
 
+
 def test_fsdp1_gradient_clipper():
     """
     Test that FSDP1GradientClipper correctly calls the wrapped model's clip_grad_norm_ method
@@ -25,7 +38,7 @@ def test_fsdp1_gradient_clipper():
     clipper = FSDP1GradientClipper(wrapped_model=mock_model, max_norm=max_norm, norm_type=norm_type)
 
     # Call clip_gradients
-    norm = clipper.clip_gradients()
+    clipper.clip_gradients()
 
     # Verify that clip_grad_norm_ was called with the correct arguments
     mock_model.clip_grad_norm_.assert_called_once_with(max_norm=max_norm, norm_type=norm_type.value)
@@ -43,7 +56,7 @@ def test_fsdp1_logging_only_gradient_clipper():
     clipper = FSDP1LoggingOnlyGradientClipper(wrapped_model=mock_model, norm_type=norm_type)
 
     # Call clip_gradients
-    norm = clipper.clip_gradients()
+    clipper.clip_gradients()
 
     # Verify that clip_grad_norm_ was called with max_norm=torch.inf
     mock_model.clip_grad_norm_.assert_called_once_with(max_norm=torch.inf, norm_type=norm_type.value)
@@ -66,18 +79,14 @@ def test_fsdp2_clip_grad_norm():
 
     # Test case 1: max_norm > total_norm (no clipping)
     max_norm = expected_norm + 1  # 3.0
-    norm = FSDP2GradientClipper.clip_grad_norm_(
-        parameters=parameters, max_norm=max_norm, norm_type=2.0
-    )
+    norm = FSDP2GradientClipper.clip_grad_norm_(parameters=parameters, max_norm=max_norm, norm_type=2.0)
     assert torch.allclose(norm, torch.tensor(expected_norm)), "Norm should match expected total norm"
     assert torch.allclose(param1.grad, torch.tensor([1.0, 1.0])), "Gradients should not be clipped"
     assert torch.allclose(param2.grad, torch.tensor([1.0, 1.0])), "Gradients should not be clipped"
 
     # Test case 2: max_norm < total_norm (clipping occurs)
     max_norm = expected_norm / 2  # 1.0
-    norm = FSDP2GradientClipper.clip_grad_norm_(
-        parameters=parameters, max_norm=max_norm, norm_type=2.0
-    )
+    norm = FSDP2GradientClipper.clip_grad_norm_(parameters=parameters, max_norm=max_norm, norm_type=2.0)
     assert torch.allclose(norm, torch.tensor(expected_norm)), "Norm should match pre-clipping total norm"
     scale = max_norm / expected_norm  # 1.0 / 2.0 = 0.5
     expected_grad = torch.tensor([1.0 * scale, 1.0 * scale])
@@ -91,7 +100,6 @@ def test_fsdp2_gradient_clipper():
     """
     # Create a mock FSDP2 model with parameters
 
-
     mock_model = MockFSDP2Model()
     max_norm = 1.0
     norm_type = GradientClippingMode.P2_NORM
@@ -133,4 +141,4 @@ def test_dummy_gradient_clipper():
     """
     clipper = DummyGradientClipper()
     norm = clipper.clip_gradients()
-    assert torch.allclose(norm, torch.tensor([-1.0])), "Norm should be -1.0 indicating no clipping"
+    assert torch.allclose(norm, torch.tensor([-1.0])), "Norm should be -1.0 indicating no clipping"
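For reference, the contract that test_fsdp2_clip_grad_norm asserts (return the pre-clipping total norm, and rescale gradients by max_norm / total_norm only when the norm exceeds max_norm) matches PyTorch's stock torch.nn.utils.clip_grad_norm_. A minimal sketch of that math using the stock utility rather than the FSDP wrappers under test (parameter values are arbitrary; only the gradients matter for clipping):

import torch

# Two parameters with unit gradients, as in the test:
# total L2 norm = sqrt(1^2 + 1^2 + 1^2 + 1^2) = 2.0
param1 = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
param2 = torch.nn.Parameter(torch.tensor([3.0, 4.0]))
param1.grad = torch.tensor([1.0, 1.0])
param2.grad = torch.tensor([1.0, 1.0])

# max_norm (1.0) < total_norm (2.0), so gradients are scaled by 1.0 / 2.0 = 0.5
total_norm = torch.nn.utils.clip_grad_norm_([param1, param2], max_norm=1.0, norm_type=2.0)
assert torch.allclose(total_norm, torch.tensor(2.0))          # returns the pre-clipping norm
assert torch.allclose(param1.grad, torch.tensor([0.5, 0.5]))  # gradients are clipped in place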