Gan.py
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 7 16:02:33 2022
@author: Admin
"""
import tensorflow as tf
from tensorflow.keras.models import Model


class Gan(Model):
    """Keras Model subclass that pairs a generator with a discriminator and
    implements the adversarial training loop in train_step."""

    def __init__(self, discriminator, generator, latent_dim):
        super(Gan, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(Gan, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]

        # Update the discriminator twice for every generator update.
        for _ in range(2):
            ## Train the discriminator on generated (fake) images, labelled 0
            random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
            # training=True so layers such as Dropout/BatchNorm run in training mode
            generated_images = self.generator(random_latent_vectors, training=True)
            generated_labels = tf.zeros((batch_size, 1))
            with tf.GradientTape() as ftape:
                predictions = self.discriminator(generated_images, training=True)
                d1_loss = self.loss_fn(generated_labels, predictions)
            grads = ftape.gradient(d1_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

            ## Train the discriminator on real images, labelled 1
            labels = tf.ones((batch_size, 1))
            with tf.GradientTape() as rtape:
                predictions = self.discriminator(real_images, training=True)
                d2_loss = self.loss_fn(labels, predictions)
            grads = rtape.gradient(d2_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        ## Train the generator: push the discriminator to label fakes as real
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        misleading_labels = tf.ones((batch_size, 1))
        with tf.GradientTape() as gtape:
            predictions = self.discriminator(self.generator(random_latent_vectors, training=True), training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = gtape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        return {"d1_loss": d1_loss, "d2_loss": d2_loss, "g_loss": g_loss}
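

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The generator,
# discriminator, layer sizes, optimizers, learning rates, and MNIST pipeline
# below are illustrative assumptions, included only to show how Gan is
# instantiated, compiled via compile(), and trained with the built-in fit()
# loop; swap in your own models and data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from tensorflow.keras import layers

    latent_dim = 128

    # Hypothetical generator: latent vector -> 28x28x1 image in [0, 1].
    generator = tf.keras.Sequential([
        layers.Input(shape=(latent_dim,)),
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, 4, strides=2, padding="same"),  # 7x7 -> 14x14
        layers.LeakyReLU(0.2),
        layers.Conv2DTranspose(128, 4, strides=2, padding="same"),  # 14x14 -> 28x28
        layers.LeakyReLU(0.2),
        layers.Conv2D(1, 7, padding="same", activation="sigmoid"),
    ])

    # Hypothetical discriminator: image -> single real/fake logit.
    discriminator = tf.keras.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, 4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Conv2D(128, 4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Flatten(),
        layers.Dense(1),  # raw logit; paired with from_logits=True below
    ])

    gan = Gan(discriminator=discriminator, generator=generator,
              latent_dim=latent_dim)
    gan.compile(
        d_optimizer=tf.keras.optimizers.Adam(1e-4),
        g_optimizer=tf.keras.optimizers.Adam(1e-4),
        loss_fn=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    )

    # Illustrative data: MNIST digits scaled to [0, 1], batched for fit().
    (x_train, _), _ = tf.keras.datasets.mnist.load_data()
    x_train = x_train.astype("float32")[..., None] / 255.0
    dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(1024).batch(64)

    gan.fit(dataset, epochs=1)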