Skip to content

Commit 9c92ba4

Browse files
committed
Add Keras version check.
1 parent 10a6368 commit 9c92ba4

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

keras_hub/src/utils/preset_utils.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from keras_hub.src.api_export import keras_hub_export
1414
from keras_hub.src.utils.keras_utils import print_msg
15+
from keras_hub.src.utils.keras_utils import sharded_weights_available
1516

1617
try:
1718
import kagglehub
@@ -743,6 +744,12 @@ def _load_backbone_weights(self, backbone):
743744
if has_single_file_weights:
744745
filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
745746
else:
747+
if not sharded_weights_available():
748+
raise RuntimeError(
749+
"Sharded weights loading is not supported in the current "
750+
f"Keras version {keras.__version__}. "
751+
"Please update to a newer version."
752+
)
746753
filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
747754
sharded_filenames = self._get_sharded_filenames(filepath)
748755
for sharded_filename in sharded_filenames:
@@ -768,7 +775,7 @@ def save_backbone(self, backbone, max_shard_size=None):
768775
# If the size of the backbone is larger than `MAX_SHARD_SIZE`, save
769776
# sharded weights.
770777
max_shard_size = max_shard_size or MAX_SHARD_SIZE
771-
if backbone_size_in_gb > max_shard_size:
778+
if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
772779
backbone_sharded_weights_config_path = os.path.join(
773780
self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
774781
)

keras_hub/src/utils/preset_utils_test.py

+4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
1313
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
1414
from keras_hub.src.tests.test_case import TestCase
15+
from keras_hub.src.utils.keras_utils import sharded_weights_available
1516
from keras_hub.src.utils.preset_utils import CONFIG_FILE
1617
from keras_hub.src.utils.preset_utils import get_preset_saver
1718
from keras_hub.src.utils.preset_utils import upload_preset
@@ -20,6 +21,9 @@
2021
class PresetUtilsTest(TestCase):
2122
@pytest.mark.large
2223
def test_sharded_weights(self):
24+
if not sharded_weights_available():
25+
self.skipTest("Sharded weights are not available.")
26+
2327
# Gemma2 config.
2428
init_kwargs = {
2529
"vocabulary_size": 4096, # 256128

0 commit comments

Comments (0)