
Commit afdb010

refactor codes and fix bug of c_dim
1 parent 1fdb04b

File tree

- README.md
- main.py
- model.py
- utils.py

4 files changed: +31, -31 lines

README.md

Lines changed: 6 additions & 6 deletions

@@ -33,22 +33,22 @@ First, download dataset with:
 
 To train a model with downloaded dataset:
 
-    $ python main.py --dataset mnist --input_height=28 --output_height=28 --c_dim=1 --is_train
-    $ python main.py --dataset celebA --input_height=108 --is_train --is_crop True
+    $ python main.py --dataset mnist --input_height=28 --output_height=28 --train
+    $ python main.py --dataset celebA --input_height=108 --train --crop
 
 To test with an existing model:
 
-    $ python main.py --dataset mnist --input_height=28 --output_height=28 --c_dim=1
-    $ python main.py --dataset celebA --input_height=108 --is_crop True
+    $ python main.py --dataset mnist --input_height=28 --output_height=28
+    $ python main.py --dataset celebA --input_height=108 --crop
 
 Or, you can use your own dataset (without central crop) by:
 
     $ mkdir data/DATASET_NAME
     ... add images to data/DATASET_NAME ...
-    $ python main.py --dataset DATASET_NAME --is_train
+    $ python main.py --dataset DATASET_NAME --train
     $ python main.py --dataset DATASET_NAME
     $ # example
-    $ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --c_dim=1 --is_train
+    $ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --train
 
 ## Results
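Note on the renamed switches: `tf.app.flags.DEFINE_boolean` (the gflags-style interface in TF 1.x) treats a bare switch as True, so `--train` replaces the old `--is_train` and `--is_crop True` collapses to plain `--crop`. A minimal sketch of how the renamed flags are consumed; the script body is illustrative, not taken from the repo:

    import tensorflow as tf

    flags = tf.app.flags
    flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
    flags.DEFINE_boolean("crop", False, "True for center-cropping inputs [False]")
    FLAGS = flags.FLAGS

    def main(_):
      # `python demo.py --train --crop` sets both flags to True;
      # omitting a switch leaves it at its False default.
      print("train=%s crop=%s" % (FLAGS.train, FLAGS.crop))

    if __name__ == '__main__':
      tf.app.run()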

main.py

Lines changed: 7 additions & 6 deletions

@@ -21,8 +21,8 @@
 flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
 flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
 flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
-flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")
-flags.DEFINE_boolean("is_crop", False, "True for training, False for testing [False]")
+flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
+flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
 flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
 FLAGS = flags.FLAGS

@@ -56,7 +56,7 @@ def main(_):
         y_dim=10,
         dataset_name=FLAGS.dataset,
         input_fname_pattern=FLAGS.input_fname_pattern,
-        is_crop=FLAGS.is_crop,
+        crop=FLAGS.crop,
         checkpoint_dir=FLAGS.checkpoint_dir,
         sample_dir=FLAGS.sample_dir)
   else:

@@ -70,15 +70,16 @@ def main(_):
         sample_num=FLAGS.batch_size,
         dataset_name=FLAGS.dataset,
         input_fname_pattern=FLAGS.input_fname_pattern,
-        is_crop=FLAGS.is_crop,
+        crop=FLAGS.crop,
         checkpoint_dir=FLAGS.checkpoint_dir,
         sample_dir=FLAGS.sample_dir)
 
   show_all_variables()
-  if FLAGS.is_train:
+
+  if FLAGS.train:
     dcgan.train(FLAGS)
   else:
-    if not dcgan.load(FLAGS.checkpoint_dir):
+    if not dcgan.load(FLAGS.checkpoint_dir)[0]:
       raise Exception("[!] Train a model first, then run test mode")
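The new `[0]` index implies that `DCGAN.load` no longer returns a bare boolean but a tuple whose first element says whether a checkpoint was restored. A hedged sketch of the calling convention; the second element is assumed here to be the recovered step counter, which this diff does not show:

    def run_test(dcgan, checkpoint_dir):
      # Assumed return shape: (success, counter). Only the boolean
      # matters on the test path, hence load(...)[0] in main.py.
      could_load, checkpoint_counter = dcgan.load(checkpoint_dir)
      if not could_load:
        raise Exception("[!] Train a model first, then run test mode")
      return checkpoint_counter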

model.py

Lines changed: 11 additions & 12 deletions

@@ -14,7 +14,7 @@ def conv_out_size_same(size, stride):
   return int(math.ceil(float(size) / float(stride)))
 
 class DCGAN(object):
-  def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
+  def __init__(self, sess, input_height=108, input_width=108, crop=True,
          batch_size=64, sample_num = 64, output_height=64, output_width=64,
          y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
          gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',

@@ -33,7 +33,7 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
       c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
     """
     self.sess = sess
-    self.is_crop = is_crop
+    self.crop = crop
 
     self.batch_size = batch_size
     self.sample_num = sample_num

@@ -52,7 +52,6 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
     self.gfc_dim = gfc_dim
     self.dfc_dim = dfc_dim
 
-
     # batch normalization : deals with poor initialization helps gradient flow
     self.d_bn1 = batch_norm(name='d_bn1')
     self.d_bn2 = batch_norm(name='d_bn2')

@@ -76,17 +75,17 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
       self.c_dim = self.data_X[0].shape[-1]
     else:
       self.data = glob(os.path.join("./data", self.dataset_name, self.input_fname_pattern))
-      self.c_dim = self.data[0].shape[-1]
+      self.c_dim = imread(self.data[0]).shape[-1]
 
-    self.is_grayscale = (self.c_dim == 1)
+    self.grayscale = (self.c_dim == 1)
 
     self.build_model()
 
   def build_model(self):
     if self.y_dim:
       self.y= tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y')
 
-    if self.is_crop:
+    if self.crop:
       image_dims = [self.output_height, self.output_width, self.c_dim]
     else:
       image_dims = [self.input_height, self.input_width, self.c_dim]

@@ -179,9 +178,9 @@ def train(self, config):
                     input_width=self.input_width,
                     resize_height=self.output_height,
                     resize_width=self.output_width,
-                    is_crop=self.is_crop,
-                    is_grayscale=self.is_grayscale) for sample_file in sample_files]
-      if (self.is_grayscale):
+                    crop=self.crop,
+                    grayscale=self.grayscale) for sample_file in sample_files]
+      if (self.grayscale):
         sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None]
       else:
         sample_inputs = np.array(sample).astype(np.float32)

@@ -215,9 +214,9 @@ def train(self, config):
                       input_width=self.input_width,
                       resize_height=self.output_height,
                       resize_width=self.output_width,
-                      is_crop=self.is_crop,
-                      is_grayscale=self.is_grayscale) for batch_file in batch_files]
-        if (self.is_grayscale):
+                      crop=self.crop,
+                      grayscale=self.grayscale) for batch_file in batch_files]
+        if self.grayscale:
          batch_images = np.array(batch).astype(np.float32)[:, :, :, None]
        else:
          batch_images = np.array(batch).astype(np.float32)
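This file carries the c_dim fix named in the commit message. `glob` returns file paths (plain strings), and a string has no `.shape` attribute, so the old `self.data[0].shape[-1]` raised an AttributeError instead of inferring the channel count; the fix decodes the image first. A minimal before/after sketch, with a hypothetical dataset path:

    from glob import glob

    from utils import imread  # the repo's reader (scipy.misc.imread under the hood)

    data = glob("./data/celebA/*.jpg")  # hypothetical dataset location

    # Before: data[0] is a str, which has no .shape, so this raised
    # AttributeError rather than producing a channel count.
    # c_dim = data[0].shape[-1]

    # After: decode the first image, then take the channel axis.
    c_dim = imread(data[0]).shape[-1]  # e.g. 3 for an RGB image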

utils.py

Lines changed: 7 additions & 7 deletions

@@ -24,16 +24,16 @@ def show_all_variables():
 
 def get_image(image_path, input_height, input_width,
               resize_height=64, resize_width=64,
-              is_crop=True, is_grayscale=False):
-  image = imread(image_path, is_grayscale)
+              crop=True, grayscale=False):
+  image = imread(image_path, grayscale)
   return transform(image, input_height, input_width,
-                   resize_height, resize_width, is_crop)
+                   resize_height, resize_width, crop)
 
 def save_images(images, size, image_path):
   return imsave(inverse_transform(images), size, image_path)
 
-def imread(path, is_grayscale = False):
-  if (is_grayscale):
+def imread(path, grayscale = False):
+  if (grayscale):
     return scipy.misc.imread(path, flatten = True).astype(np.float)
   else:
     return scipy.misc.imread(path).astype(np.float)

@@ -77,8 +77,8 @@ def center_crop(x, crop_h, crop_w,
                 x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w])
 
 def transform(image, input_height, input_width,
-              resize_height=64, resize_width=64, is_crop=True):
-  if is_crop:
+              resize_height=64, resize_width=64, crop=True):
+  if crop:
     cropped_image = center_crop(
       image, input_height, input_width,
       resize_height, resize_width)
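With the renames threaded through, callers pass `crop` and `grayscale` as keyword arguments. A minimal usage sketch of the refactored helper; the file path is hypothetical:

    from utils import get_image

    # Read one celebA face, center-crop the 108x108 region, and
    # resize it to 64x64 for the 64x64 generator/discriminator.
    img = get_image("./data/celebA/000001.jpg",
                    input_height=108, input_width=108,
                    resize_height=64, resize_width=64,
                    crop=True, grayscale=False)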
