
Commit c2fca33

Support None for max_shard_size (#2261)

* Support None for max_shard_size
* Add unit test

1 parent 275acef · commit c2fca33

2 files changed: +46 −1 lines changed

keras_hub/src/utils/preset_utils.py (5 additions, 1 deletion)

@@ -772,7 +772,11 @@ def save_backbone(self, backbone, max_shard_size=10):
         backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
         # If the size of the backbone is larger than `max_shard_size`, save
         # sharded weights.
-        if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
+        if (
+            sharded_weights_available()
+            and max_shard_size is not None
+            and backbone_size_in_gb > max_shard_size
+        ):
             backbone_sharded_weights_config_path = os.path.join(
                 self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
             )
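
The fix relies on short-circuit evaluation: `and` evaluates left to right, so when `max_shard_size` is `None` the size comparison never runs and a `TypeError` from comparing a float with `None` is avoided. A minimal standalone sketch of the patched condition (the `should_shard` helper name is hypothetical, not part of KerasHub):

def should_shard(sharding_available, backbone_size_in_gb, max_shard_size):
    # Hypothetical helper mirroring the patched condition. `and`
    # short-circuits, so `backbone_size_in_gb > max_shard_size` is never
    # evaluated when max_shard_size is None (which would raise TypeError).
    return (
        sharding_available
        and max_shard_size is not None
        and backbone_size_in_gb > max_shard_size
    )

assert should_shard(True, 12.5, 10) is True      # large backbone: shard
assert should_shard(True, 12.5, None) is False   # None disables sharding
assert should_shard(True, 2.0, 10) is False      # fits in one file
assert should_shard(False, 12.5, 10) is False    # backend lacks sharding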

keras_hub/src/utils/preset_utils_test.py (41 additions, 0 deletions)

@@ -59,6 +59,47 @@ def test_sharded_weights(self):
         ):
             self.assertAllClose(v1, v2)
 
+    @pytest.mark.large
+    def test_disabled_sharding(self):
+        init_kwargs = {
+            "vocabulary_size": 1024,
+            "num_layers": 12,
+            "num_query_heads": 8,
+            "num_key_value_heads": 4,
+            "hidden_dim": 32,
+            "intermediate_dim": 64,
+            "head_dim": 4,
+            "sliding_window_size": 5,
+            "attention_logit_soft_cap": 50,
+            "final_logit_soft_cap": 30,
+            "layer_norm_epsilon": 1e-6,
+            "query_head_dim_normalize": False,
+            "use_post_ffw_norm": True,
+            "use_post_attention_norm": True,
+            "use_sliding_window_attention": True,
+        }
+        backbone = GemmaBackbone(**init_kwargs)
+
+        # Save the weights with `max_shard_size=None`
+        preset_dir = self.get_temp_dir()
+        backbone.save_to_preset(preset_dir, max_shard_size=None)
+        self.assertTrue(
+            os.path.exists(os.path.join(preset_dir, "model.weights.h5"))
+        )
+        self.assertFalse(
+            os.path.exists(os.path.join(preset_dir, "model.weights.json"))
+        )
+        self.assertFalse(
+            os.path.exists(os.path.join(preset_dir, "model_00000.weights.h5"))
+        )
+
+        # Load the weights.
+        revived_backbone = GemmaBackbone.from_preset(preset_dir)
+        for v1, v2 in zip(
+            backbone.trainable_variables, revived_backbone.trainable_variables
+        ):
+            self.assertAllClose(v1, v2)
+
     @pytest.mark.large
     def test_preset_errors(self):
         with self.assertRaisesRegex(ValueError, "must be a string"):
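
For reference, a sketch of the user-facing flow this test exercises (the tiny constructor values and the preset path below are illustrative, not from the commit; the `save_to_preset(..., max_shard_size=None)` and `from_preset` calls are the ones the test itself uses):

import keras_hub

# A deliberately tiny Gemma backbone; config values are illustrative.
backbone = keras_hub.models.GemmaBackbone(
    vocabulary_size=1024,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=32,
    intermediate_dim=64,
    head_dim=8,
)

# max_shard_size=None writes a single model.weights.h5 and never shards,
# regardless of how large the backbone is.
backbone.save_to_preset("./tiny_gemma_preset", max_shard_size=None)

# Round-trip: reload from the local preset directory.
revived = keras_hub.models.GemmaBackbone.from_preset("./tiny_gemma_preset")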
