@@ -59,6 +59,47 @@ def test_sharded_weights(self):
         ):
             self.assertAllClose(v1, v2)
 
+    @pytest.mark.large
+    def test_disabled_sharding(self):
+        init_kwargs = {
+            "vocabulary_size": 1024,
+            "num_layers": 12,
+            "num_query_heads": 8,
+            "num_key_value_heads": 4,
+            "hidden_dim": 32,
+            "intermediate_dim": 64,
+            "head_dim": 4,
+            "sliding_window_size": 5,
+            "attention_logit_soft_cap": 50,
+            "final_logit_soft_cap": 30,
+            "layer_norm_epsilon": 1e-6,
+            "query_head_dim_normalize": False,
+            "use_post_ffw_norm": True,
+            "use_post_attention_norm": True,
+            "use_sliding_window_attention": True,
+        }
+        backbone = GemmaBackbone(**init_kwargs)
+
+        # Save the weights with `max_shard_size=None`.
+        preset_dir = self.get_temp_dir()
+        backbone.save_to_preset(preset_dir, max_shard_size=None)
+        self.assertTrue(
+            os.path.exists(os.path.join(preset_dir, "model.weights.h5"))
+        )
+        self.assertFalse(
+            os.path.exists(os.path.join(preset_dir, "model.weights.json"))
+        )
+        self.assertFalse(
+            os.path.exists(os.path.join(preset_dir, "model_00000.weights.h5"))
+        )
+
+        # Load the weights.
+        revived_backbone = GemmaBackbone.from_preset(preset_dir)
+        for v1, v2 in zip(
+            backbone.trainable_variables, revived_backbone.trainable_variables
+        ):
+            self.assertAllClose(v1, v2)
+
     @pytest.mark.large
     def test_preset_errors(self):
         with self.assertRaisesRegex(ValueError, "must be a string"):
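For contrast with `test_disabled_sharding`, here is a minimal sketch of the opposite case, mirroring the `test_sharded_weights` test at the top of the hunk. It assumes the same `backbone` and `preset_dir` fixtures as the test body above; the `0.02` shard cap is an assumed value chosen only to force sharded output for this small model, and the expected filenames are taken from the assertions in the diff.

```python
import os

# Hypothetical counterpart to test_disabled_sharding: passing a small
# max_shard_size (assumed to be in GB here) should enable sharded saving.
backbone.save_to_preset(preset_dir, max_shard_size=0.02)

# With sharding active, the preset should contain the sharding index plus
# numbered shard files -- exactly the files test_disabled_sharding asserts
# are absent when max_shard_size=None.
assert os.path.exists(os.path.join(preset_dir, "model.weights.json"))
assert os.path.exists(os.path.join(preset_dir, "model_00000.weights.h5"))
```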