Skip to content

Commit

Permalink
Add assertFileExists test method
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Sep 5, 2024
1 parent a6cc559 commit 26d3536
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 43 deletions.
81 changes: 38 additions & 43 deletions integration_tests/model_visualization_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
import re
from pathlib import Path

import keras
from keras.src import testing
from keras.src.utils import model_to_dot
from keras.src.utils import plot_model


def assert_file_exists(path):
assert Path(path).is_file(), "File does not exist"


def parse_text_from_html(html):
pattern = r"<font[^>]*>(.*?)</font>"
matches = re.findall(pattern, html)
Expand Down Expand Up @@ -61,11 +56,11 @@ def test_plot_sequential_model(self):

file_name = "sequential.png"
plot_model(model, file_name)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "sequential-show_shapes.png"
plot_model(model, file_name, show_shapes=True)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "sequential-show_shapes-show_dtype.png"
plot_model(
Expand All @@ -74,7 +69,7 @@ def test_plot_sequential_model(self):
show_shapes=True,
show_dtype=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "sequential-show_shapes-show_dtype-show_layer_names.png"
plot_model(
Expand All @@ -84,7 +79,7 @@ def test_plot_sequential_model(self):
show_dtype=True,
show_layer_names=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501
plot_model(
Expand All @@ -95,7 +90,7 @@ def test_plot_sequential_model(self):
show_layer_names=True,
show_layer_activations=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501
plot_model(
Expand All @@ -107,7 +102,7 @@ def test_plot_sequential_model(self):
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
plot_model(
Expand All @@ -120,7 +115,7 @@ def test_plot_sequential_model(self):
show_trainable=True,
rankdir="LR",
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "sequential-show_layer_activations-show_trainable.png"
plot_model(
Expand All @@ -129,7 +124,7 @@ def test_plot_sequential_model(self):
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

def test_plot_functional_model(self):
inputs = keras.Input((3,), name="input")
Expand Down Expand Up @@ -167,11 +162,11 @@ def test_plot_functional_model(self):

file_name = "functional.png"
plot_model(model, file_name)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "functional-show_shapes.png"
plot_model(model, file_name, show_shapes=True)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "functional-show_shapes-show_dtype.png"
plot_model(
Expand All @@ -180,7 +175,7 @@ def test_plot_functional_model(self):
show_shapes=True,
show_dtype=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "functional-show_shapes-show_dtype-show_layer_names.png"
plot_model(
Expand All @@ -190,7 +185,7 @@ def test_plot_functional_model(self):
show_dtype=True,
show_layer_names=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = (
"functional-show_shapes-show_dtype-show_layer_activations.png"
Expand All @@ -203,7 +198,7 @@ def test_plot_functional_model(self):
show_layer_names=True,
show_layer_activations=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "functional-show_shapes-show_dtype-show_layer_activations-show_trainable.png" # noqa: E501
plot_model(
Expand All @@ -215,7 +210,7 @@ def test_plot_functional_model(self):
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
plot_model(
Expand All @@ -228,7 +223,7 @@ def test_plot_functional_model(self):
show_trainable=True,
rankdir="LR",
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "functional-show_layer_activations-show_trainable.png"
plot_model(
Expand All @@ -237,7 +232,7 @@ def test_plot_functional_model(self):
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = (
"functional-show_shapes-show_layer_activations-show_trainable.png"
Expand All @@ -249,7 +244,7 @@ def test_plot_functional_model(self):
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

def test_plot_subclassed_model(self):
class MyModel(keras.Model):
Expand All @@ -266,11 +261,11 @@ def call(self, x):

file_name = "subclassed.png"
plot_model(model, file_name)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "subclassed-show_shapes.png"
plot_model(model, file_name, show_shapes=True)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "subclassed-show_shapes-show_dtype.png"
plot_model(
Expand All @@ -279,7 +274,7 @@ def call(self, x):
show_shapes=True,
show_dtype=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "subclassed-show_shapes-show_dtype-show_layer_names.png"
plot_model(
Expand All @@ -289,7 +284,7 @@ def call(self, x):
show_dtype=True,
show_layer_names=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = (
"subclassed-show_shapes-show_dtype-show_layer_activations.png"
Expand All @@ -302,7 +297,7 @@ def call(self, x):
show_layer_names=True,
show_layer_activations=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501
plot_model(
Expand All @@ -314,7 +309,7 @@ def call(self, x):
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
plot_model(
Expand All @@ -327,7 +322,7 @@ def call(self, x):
show_trainable=True,
rankdir="LR",
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "subclassed-show_layer_activations-show_trainable.png"
plot_model(
Expand All @@ -336,7 +331,7 @@ def call(self, x):
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = (
"subclassed-show_shapes-show_layer_activations-show_trainable.png"
Expand All @@ -348,7 +343,7 @@ def call(self, x):
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

def test_plot_nested_functional_model(self):
inputs = keras.Input((3,), name="input")
Expand Down Expand Up @@ -387,7 +382,7 @@ def test_plot_nested_functional_model(self):

file_name = "nested-functional.png"
plot_model(model, file_name, expand_nested=True)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "nested-functional-show_shapes.png"
plot_model(
Expand All @@ -396,7 +391,7 @@ def test_plot_nested_functional_model(self):
show_shapes=True,
expand_nested=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "nested-functional-show_shapes-show_dtype.png"
plot_model(
Expand All @@ -406,7 +401,7 @@ def test_plot_nested_functional_model(self):
show_dtype=True,
expand_nested=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = (
"nested-functional-show_shapes-show_dtype-show_layer_names.png"
Expand All @@ -419,7 +414,7 @@ def test_plot_nested_functional_model(self):
show_layer_names=True,
expand_nested=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501
plot_model(
Expand All @@ -431,7 +426,7 @@ def test_plot_nested_functional_model(self):
show_layer_activations=True,
expand_nested=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501
plot_model(
Expand All @@ -444,7 +439,7 @@ def test_plot_nested_functional_model(self):
show_trainable=True,
expand_nested=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
plot_model(
Expand All @@ -458,7 +453,7 @@ def test_plot_nested_functional_model(self):
rankdir="LR",
expand_nested=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = (
"nested-functional-show_layer_activations-show_trainable.png"
Expand All @@ -470,7 +465,7 @@ def test_plot_nested_functional_model(self):
show_trainable=True,
expand_nested=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "nested-functional-show_shapes-show_layer_activations-show_trainable.png" # noqa: E501
plot_model(
Expand All @@ -481,7 +476,7 @@ def test_plot_nested_functional_model(self):
show_trainable=True,
expand_nested=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

def test_plot_functional_model_with_splits_and_merges(self):
class SplitLayer(keras.Layer):
Expand Down Expand Up @@ -518,7 +513,7 @@ def call(self, xs):

file_name = "split-functional.png"
plot_model(model, file_name, expand_nested=True)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "split-functional-show_shapes.png"
plot_model(
Expand All @@ -527,7 +522,7 @@ def call(self, xs):
show_shapes=True,
expand_nested=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)

file_name = "split-functional-show_shapes-show_dtype.png"
plot_model(
Expand All @@ -537,4 +532,4 @@ def call(self, xs):
show_dtype=True,
expand_nested=True,
)
assert_file_exists(file_name)
self.assertFileExists(file_name)
5 changes: 5 additions & 0 deletions keras/src/testing/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import shutil
import tempfile
import unittest
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -113,6 +114,10 @@ def assertDType(self, x, dtype, msg=None):
msg = msg or default_msg
self.assertEqual(x_dtype, standardized_dtype, msg=msg)

def assertFileExists(self, path):
if not Path(path).is_file():
raise AssertionError(f"File {path} does not exist")

def run_class_serialization_test(self, instance, custom_objects=None):
from keras.src.saving import custom_object_scope
from keras.src.saving import deserialize_keras_object
Expand Down

0 comments on commit 26d3536

Please sign in to comment.