Skip to content

Commit 26d3536

Browse files
committed
Add assertFileExists test method
1 parent a6cc559 commit 26d3536

File tree

2 files changed

+43
-43
lines changed

2 files changed

+43
-43
lines changed

integration_tests/model_visualization_test.py

Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
import re
2-
from pathlib import Path
32

43
import keras
54
from keras.src import testing
65
from keras.src.utils import model_to_dot
76
from keras.src.utils import plot_model
87

98

10-
def assert_file_exists(path):
11-
assert Path(path).is_file(), "File does not exist"
12-
13-
149
def parse_text_from_html(html):
1510
pattern = r"<font[^>]*>(.*?)</font>"
1611
matches = re.findall(pattern, html)
@@ -61,11 +56,11 @@ def test_plot_sequential_model(self):
6156

6257
file_name = "sequential.png"
6358
plot_model(model, file_name)
64-
assert_file_exists(file_name)
59+
self.assertFileExists(file_name)
6560

6661
file_name = "sequential-show_shapes.png"
6762
plot_model(model, file_name, show_shapes=True)
68-
assert_file_exists(file_name)
63+
self.assertFileExists(file_name)
6964

7065
file_name = "sequential-show_shapes-show_dtype.png"
7166
plot_model(
@@ -74,7 +69,7 @@ def test_plot_sequential_model(self):
7469
show_shapes=True,
7570
show_dtype=True,
7671
)
77-
assert_file_exists(file_name)
72+
self.assertFileExists(file_name)
7873

7974
file_name = "sequential-show_shapes-show_dtype-show_layer_names.png"
8075
plot_model(
@@ -84,7 +79,7 @@ def test_plot_sequential_model(self):
8479
show_dtype=True,
8580
show_layer_names=True,
8681
)
87-
assert_file_exists(file_name)
82+
self.assertFileExists(file_name)
8883

8984
file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501
9085
plot_model(
@@ -95,7 +90,7 @@ def test_plot_sequential_model(self):
9590
show_layer_names=True,
9691
show_layer_activations=True,
9792
)
98-
assert_file_exists(file_name)
93+
self.assertFileExists(file_name)
9994

10095
file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501
10196
plot_model(
@@ -107,7 +102,7 @@ def test_plot_sequential_model(self):
107102
show_layer_activations=True,
108103
show_trainable=True,
109104
)
110-
assert_file_exists(file_name)
105+
self.assertFileExists(file_name)
111106

112107
file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
113108
plot_model(
@@ -120,7 +115,7 @@ def test_plot_sequential_model(self):
120115
show_trainable=True,
121116
rankdir="LR",
122117
)
123-
assert_file_exists(file_name)
118+
self.assertFileExists(file_name)
124119

125120
file_name = "sequential-show_layer_activations-show_trainable.png"
126121
plot_model(
@@ -129,7 +124,7 @@ def test_plot_sequential_model(self):
129124
show_layer_activations=True,
130125
show_trainable=True,
131126
)
132-
assert_file_exists(file_name)
127+
self.assertFileExists(file_name)
133128

134129
def test_plot_functional_model(self):
135130
inputs = keras.Input((3,), name="input")
@@ -167,11 +162,11 @@ def test_plot_functional_model(self):
167162

168163
file_name = "functional.png"
169164
plot_model(model, file_name)
170-
assert_file_exists(file_name)
165+
self.assertFileExists(file_name)
171166

172167
file_name = "functional-show_shapes.png"
173168
plot_model(model, file_name, show_shapes=True)
174-
assert_file_exists(file_name)
169+
self.assertFileExists(file_name)
175170

176171
file_name = "functional-show_shapes-show_dtype.png"
177172
plot_model(
@@ -180,7 +175,7 @@ def test_plot_functional_model(self):
180175
show_shapes=True,
181176
show_dtype=True,
182177
)
183-
assert_file_exists(file_name)
178+
self.assertFileExists(file_name)
184179

185180
file_name = "functional-show_shapes-show_dtype-show_layer_names.png"
186181
plot_model(
@@ -190,7 +185,7 @@ def test_plot_functional_model(self):
190185
show_dtype=True,
191186
show_layer_names=True,
192187
)
193-
assert_file_exists(file_name)
188+
self.assertFileExists(file_name)
194189

195190
file_name = (
196191
"functional-show_shapes-show_dtype-show_layer_activations.png"
@@ -203,7 +198,7 @@ def test_plot_functional_model(self):
203198
show_layer_names=True,
204199
show_layer_activations=True,
205200
)
206-
assert_file_exists(file_name)
201+
self.assertFileExists(file_name)
207202

208203
file_name = "functional-show_shapes-show_dtype-show_layer_activations-show_trainable.png" # noqa: E501
209204
plot_model(
@@ -215,7 +210,7 @@ def test_plot_functional_model(self):
215210
show_layer_activations=True,
216211
show_trainable=True,
217212
)
218-
assert_file_exists(file_name)
213+
self.assertFileExists(file_name)
219214

220215
file_name = "functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
221216
plot_model(
@@ -228,7 +223,7 @@ def test_plot_functional_model(self):
228223
show_trainable=True,
229224
rankdir="LR",
230225
)
231-
assert_file_exists(file_name)
226+
self.assertFileExists(file_name)
232227

233228
file_name = "functional-show_layer_activations-show_trainable.png"
234229
plot_model(
@@ -237,7 +232,7 @@ def test_plot_functional_model(self):
237232
show_layer_activations=True,
238233
show_trainable=True,
239234
)
240-
assert_file_exists(file_name)
235+
self.assertFileExists(file_name)
241236

242237
file_name = (
243238
"functional-show_shapes-show_layer_activations-show_trainable.png"
@@ -249,7 +244,7 @@ def test_plot_functional_model(self):
249244
show_layer_activations=True,
250245
show_trainable=True,
251246
)
252-
assert_file_exists(file_name)
247+
self.assertFileExists(file_name)
253248

254249
def test_plot_subclassed_model(self):
255250
class MyModel(keras.Model):
@@ -266,11 +261,11 @@ def call(self, x):
266261

267262
file_name = "subclassed.png"
268263
plot_model(model, file_name)
269-
assert_file_exists(file_name)
264+
self.assertFileExists(file_name)
270265

271266
file_name = "subclassed-show_shapes.png"
272267
plot_model(model, file_name, show_shapes=True)
273-
assert_file_exists(file_name)
268+
self.assertFileExists(file_name)
274269

275270
file_name = "subclassed-show_shapes-show_dtype.png"
276271
plot_model(
@@ -279,7 +274,7 @@ def call(self, x):
279274
show_shapes=True,
280275
show_dtype=True,
281276
)
282-
assert_file_exists(file_name)
277+
self.assertFileExists(file_name)
283278

284279
file_name = "subclassed-show_shapes-show_dtype-show_layer_names.png"
285280
plot_model(
@@ -289,7 +284,7 @@ def call(self, x):
289284
show_dtype=True,
290285
show_layer_names=True,
291286
)
292-
assert_file_exists(file_name)
287+
self.assertFileExists(file_name)
293288

294289
file_name = (
295290
"subclassed-show_shapes-show_dtype-show_layer_activations.png"
@@ -302,7 +297,7 @@ def call(self, x):
302297
show_layer_names=True,
303298
show_layer_activations=True,
304299
)
305-
assert_file_exists(file_name)
300+
self.assertFileExists(file_name)
306301

307302
file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501
308303
plot_model(
@@ -314,7 +309,7 @@ def call(self, x):
314309
show_layer_activations=True,
315310
show_trainable=True,
316311
)
317-
assert_file_exists(file_name)
312+
self.assertFileExists(file_name)
318313

319314
file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
320315
plot_model(
@@ -327,7 +322,7 @@ def call(self, x):
327322
show_trainable=True,
328323
rankdir="LR",
329324
)
330-
assert_file_exists(file_name)
325+
self.assertFileExists(file_name)
331326

332327
file_name = "subclassed-show_layer_activations-show_trainable.png"
333328
plot_model(
@@ -336,7 +331,7 @@ def call(self, x):
336331
show_layer_activations=True,
337332
show_trainable=True,
338333
)
339-
assert_file_exists(file_name)
334+
self.assertFileExists(file_name)
340335

341336
file_name = (
342337
"subclassed-show_shapes-show_layer_activations-show_trainable.png"
@@ -348,7 +343,7 @@ def call(self, x):
348343
show_layer_activations=True,
349344
show_trainable=True,
350345
)
351-
assert_file_exists(file_name)
346+
self.assertFileExists(file_name)
352347

353348
def test_plot_nested_functional_model(self):
354349
inputs = keras.Input((3,), name="input")
@@ -387,7 +382,7 @@ def test_plot_nested_functional_model(self):
387382

388383
file_name = "nested-functional.png"
389384
plot_model(model, file_name, expand_nested=True)
390-
assert_file_exists(file_name)
385+
self.assertFileExists(file_name)
391386

392387
file_name = "nested-functional-show_shapes.png"
393388
plot_model(
@@ -396,7 +391,7 @@ def test_plot_nested_functional_model(self):
396391
show_shapes=True,
397392
expand_nested=True,
398393
)
399-
assert_file_exists(file_name)
394+
self.assertFileExists(file_name)
400395

401396
file_name = "nested-functional-show_shapes-show_dtype.png"
402397
plot_model(
@@ -406,7 +401,7 @@ def test_plot_nested_functional_model(self):
406401
show_dtype=True,
407402
expand_nested=True,
408403
)
409-
assert_file_exists(file_name)
404+
self.assertFileExists(file_name)
410405

411406
file_name = (
412407
"nested-functional-show_shapes-show_dtype-show_layer_names.png"
@@ -419,7 +414,7 @@ def test_plot_nested_functional_model(self):
419414
show_layer_names=True,
420415
expand_nested=True,
421416
)
422-
assert_file_exists(file_name)
417+
self.assertFileExists(file_name)
423418

424419
file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501
425420
plot_model(
@@ -431,7 +426,7 @@ def test_plot_nested_functional_model(self):
431426
show_layer_activations=True,
432427
expand_nested=True,
433428
)
434-
assert_file_exists(file_name)
429+
self.assertFileExists(file_name)
435430

436431
file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501
437432
plot_model(
@@ -444,7 +439,7 @@ def test_plot_nested_functional_model(self):
444439
show_trainable=True,
445440
expand_nested=True,
446441
)
447-
assert_file_exists(file_name)
442+
self.assertFileExists(file_name)
448443

449444
file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
450445
plot_model(
@@ -458,7 +453,7 @@ def test_plot_nested_functional_model(self):
458453
rankdir="LR",
459454
expand_nested=True,
460455
)
461-
assert_file_exists(file_name)
456+
self.assertFileExists(file_name)
462457

463458
file_name = (
464459
"nested-functional-show_layer_activations-show_trainable.png"
@@ -470,7 +465,7 @@ def test_plot_nested_functional_model(self):
470465
show_trainable=True,
471466
expand_nested=True,
472467
)
473-
assert_file_exists(file_name)
468+
self.assertFileExists(file_name)
474469

475470
file_name = "nested-functional-show_shapes-show_layer_activations-show_trainable.png" # noqa: E501
476471
plot_model(
@@ -481,7 +476,7 @@ def test_plot_nested_functional_model(self):
481476
show_trainable=True,
482477
expand_nested=True,
483478
)
484-
assert_file_exists(file_name)
479+
self.assertFileExists(file_name)
485480

486481
def test_plot_functional_model_with_splits_and_merges(self):
487482
class SplitLayer(keras.Layer):
@@ -518,7 +513,7 @@ def call(self, xs):
518513

519514
file_name = "split-functional.png"
520515
plot_model(model, file_name, expand_nested=True)
521-
assert_file_exists(file_name)
516+
self.assertFileExists(file_name)
522517

523518
file_name = "split-functional-show_shapes.png"
524519
plot_model(
@@ -527,7 +522,7 @@ def call(self, xs):
527522
show_shapes=True,
528523
expand_nested=True,
529524
)
530-
assert_file_exists(file_name)
525+
self.assertFileExists(file_name)
531526

532527
file_name = "split-functional-show_shapes-show_dtype.png"
533528
plot_model(
@@ -537,4 +532,4 @@ def call(self, xs):
537532
show_dtype=True,
538533
expand_nested=True,
539534
)
540-
assert_file_exists(file_name)
535+
self.assertFileExists(file_name)

keras/src/testing/test_case.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import shutil
33
import tempfile
44
import unittest
5+
from pathlib import Path
56

67
import numpy as np
78

@@ -113,6 +114,10 @@ def assertDType(self, x, dtype, msg=None):
113114
msg = msg or default_msg
114115
self.assertEqual(x_dtype, standardized_dtype, msg=msg)
115116

117+
def assertFileExists(self, path):
118+
if not Path(path).is_file():
119+
raise AssertionError(f"File {path} does not exist")
120+
116121
def run_class_serialization_test(self, instance, custom_objects=None):
117122
from keras.src.saving import custom_object_scope
118123
from keras.src.saving import deserialize_keras_object

0 commit comments

Comments
 (0)