diff --git a/integration_tests/model_visualization_test.py b/integration_tests/model_visualization_test.py index cec3299e67d..14597d70ebb 100644 --- a/integration_tests/model_visualization_test.py +++ b/integration_tests/model_visualization_test.py @@ -1,5 +1,4 @@ import re -from pathlib import Path import keras from keras.src import testing @@ -7,10 +6,6 @@ from keras.src.utils import plot_model -def assert_file_exists(path): - assert Path(path).is_file(), "File does not exist" - - def parse_text_from_html(html): pattern = r"]*>(.*?)" matches = re.findall(pattern, html) @@ -61,11 +56,11 @@ def test_plot_sequential_model(self): file_name = "sequential.png" plot_model(model, file_name) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "sequential-show_shapes.png" plot_model(model, file_name, show_shapes=True) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "sequential-show_shapes-show_dtype.png" plot_model( @@ -74,7 +69,7 @@ def test_plot_sequential_model(self): show_shapes=True, show_dtype=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "sequential-show_shapes-show_dtype-show_layer_names.png" plot_model( @@ -84,7 +79,7 @@ def test_plot_sequential_model(self): show_dtype=True, show_layer_names=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501 plot_model( @@ -95,7 +90,7 @@ def test_plot_sequential_model(self): show_layer_names=True, show_layer_activations=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501 plot_model( @@ -107,7 +102,7 @@ def test_plot_sequential_model(self): show_layer_activations=True, show_trainable=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 plot_model( @@ -120,7 +115,7 @@ def test_plot_sequential_model(self): show_trainable=True, rankdir="LR", ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "sequential-show_layer_activations-show_trainable.png" plot_model( @@ -129,7 +124,7 @@ def test_plot_sequential_model(self): show_layer_activations=True, show_trainable=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) def test_plot_functional_model(self): inputs = keras.Input((3,), name="input") @@ -167,11 +162,11 @@ def test_plot_functional_model(self): file_name = "functional.png" plot_model(model, file_name) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "functional-show_shapes.png" plot_model(model, file_name, show_shapes=True) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "functional-show_shapes-show_dtype.png" plot_model( @@ -180,7 +175,7 @@ def test_plot_functional_model(self): show_shapes=True, show_dtype=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "functional-show_shapes-show_dtype-show_layer_names.png" plot_model( @@ -190,7 +185,7 @@ def test_plot_functional_model(self): show_dtype=True, show_layer_names=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = ( "functional-show_shapes-show_dtype-show_layer_activations.png" @@ -203,7 +198,7 @@ def test_plot_functional_model(self): show_layer_names=True, show_layer_activations=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "functional-show_shapes-show_dtype-show_layer_activations-show_trainable.png" # noqa: E501 plot_model( @@ -215,7 +210,7 @@ def test_plot_functional_model(self): show_layer_activations=True, show_trainable=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 plot_model( @@ -228,7 +223,7 @@ def test_plot_functional_model(self): show_trainable=True, rankdir="LR", ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "functional-show_layer_activations-show_trainable.png" plot_model( @@ -237,7 +232,7 @@ def test_plot_functional_model(self): show_layer_activations=True, show_trainable=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = ( "functional-show_shapes-show_layer_activations-show_trainable.png" @@ -249,7 +244,7 @@ def test_plot_functional_model(self): show_layer_activations=True, show_trainable=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) def test_plot_subclassed_model(self): class MyModel(keras.Model): @@ -266,11 +261,11 @@ def call(self, x): file_name = "subclassed.png" plot_model(model, file_name) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "subclassed-show_shapes.png" plot_model(model, file_name, show_shapes=True) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "subclassed-show_shapes-show_dtype.png" plot_model( @@ -279,7 +274,7 @@ def call(self, x): show_shapes=True, show_dtype=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "subclassed-show_shapes-show_dtype-show_layer_names.png" plot_model( @@ -289,7 +284,7 @@ def call(self, x): show_dtype=True, show_layer_names=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = ( "subclassed-show_shapes-show_dtype-show_layer_activations.png" @@ -302,7 +297,7 @@ def call(self, x): show_layer_names=True, show_layer_activations=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501 plot_model( @@ -314,7 +309,7 @@ def call(self, x): show_layer_activations=True, show_trainable=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 plot_model( @@ -327,7 +322,7 @@ def call(self, x): show_trainable=True, rankdir="LR", ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "subclassed-show_layer_activations-show_trainable.png" plot_model( @@ -336,7 +331,7 @@ def call(self, x): show_layer_activations=True, show_trainable=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = ( "subclassed-show_shapes-show_layer_activations-show_trainable.png" @@ -348,7 +343,7 @@ def call(self, x): show_layer_activations=True, show_trainable=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) def test_plot_nested_functional_model(self): inputs = keras.Input((3,), name="input") @@ -387,7 +382,7 @@ def test_plot_nested_functional_model(self): file_name = "nested-functional.png" plot_model(model, file_name, expand_nested=True) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "nested-functional-show_shapes.png" plot_model( @@ -396,7 +391,7 @@ def test_plot_nested_functional_model(self): show_shapes=True, expand_nested=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "nested-functional-show_shapes-show_dtype.png" plot_model( @@ -406,7 +401,7 @@ def test_plot_nested_functional_model(self): show_dtype=True, expand_nested=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = ( "nested-functional-show_shapes-show_dtype-show_layer_names.png" @@ -419,7 +414,7 @@ def test_plot_nested_functional_model(self): show_layer_names=True, expand_nested=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501 plot_model( @@ -431,7 +426,7 @@ def test_plot_nested_functional_model(self): show_layer_activations=True, expand_nested=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501 plot_model( @@ -444,7 +439,7 @@ def test_plot_nested_functional_model(self): show_trainable=True, expand_nested=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 plot_model( @@ -458,7 +453,7 @@ def test_plot_nested_functional_model(self): rankdir="LR", expand_nested=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = ( "nested-functional-show_layer_activations-show_trainable.png" @@ -470,7 +465,7 @@ def test_plot_nested_functional_model(self): show_trainable=True, expand_nested=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "nested-functional-show_shapes-show_layer_activations-show_trainable.png" # noqa: E501 plot_model( @@ -481,7 +476,7 @@ def test_plot_nested_functional_model(self): show_trainable=True, expand_nested=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) def test_plot_functional_model_with_splits_and_merges(self): class SplitLayer(keras.Layer): @@ -518,7 +513,7 @@ def call(self, xs): file_name = "split-functional.png" plot_model(model, file_name, expand_nested=True) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "split-functional-show_shapes.png" plot_model( @@ -527,7 +522,7 @@ def call(self, xs): show_shapes=True, expand_nested=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) file_name = "split-functional-show_shapes-show_dtype.png" plot_model( @@ -537,4 +532,4 @@ def call(self, xs): show_dtype=True, expand_nested=True, ) - assert_file_exists(file_name) + self.assertFileExists(file_name) diff --git a/keras/src/testing/test_case.py b/keras/src/testing/test_case.py index c276414bb5d..d5a8f7d779f 100644 --- a/keras/src/testing/test_case.py +++ b/keras/src/testing/test_case.py @@ -2,6 +2,7 @@ import shutil import tempfile import unittest +from pathlib import Path import numpy as np @@ -113,6 +114,10 @@ def assertDType(self, x, dtype, msg=None): msg = msg or default_msg self.assertEqual(x_dtype, standardized_dtype, msg=msg) + def assertFileExists(self, path): + if not Path(path).is_file(): + raise AssertionError(f"File {path} does not exist") + def run_class_serialization_test(self, instance, custom_objects=None): from keras.src.saving import custom_object_scope from keras.src.saving import deserialize_keras_object