-
Notifications
You must be signed in to change notification settings - Fork 0
/
unused_plot_callback.py
30 lines (30 loc) · 1.34 KB
/
unused_plot_callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# class DisplayCallback(k.callbacks.Callback):
# def __init__(self, epoch_interval=None):
# self.epoch_interval = epoch_interval
#
# def on_epoch_end(self, epoch, logs=None):
# if self.epoch_interval and epoch % self.epoch_interval == 0:
# pred_masks = self.model.predict(X_test)
# pred_masks = tf.math.argmax(pred_masks, axis=-1)
# pred_masks = pred_masks[..., tf.newaxis]
#
# fig = make_subplots(rows=3, cols=3)
#
# for i in range(3):
# # Randomly select an image from the test batch
# random_index = random.randint(0, batch_size - 1)
# random_image = X_test[random_index]
# random_pred_mask = pred_masks[random_index]
# random_true_mask = y_test[random_index]
#
# fig.add_trace(go.Image(z=random_image * 255), i+1, 1)
# fig.add_trace(go.Image(z=to_image(np.squeeze(random_true_mask)) * 255), i+1, 2)
# fig.add_trace(go.Image(z=to_image(np.squeeze(random_pred_mask)) * 255), i+1, 3)
#
# img_bytes = fig.to_image(format="png")
# fp = io.BytesIO(img_bytes)
# with fp:
# i = mpimg.imread(fp, format='png')
# plt.axis('off')
# plt.imshow(i, interpolation='nearest')
# plt.show()