-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcolor_generator.py
executable file
·92 lines (62 loc) · 2.1 KB
/
color_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import sys
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from itertools import permutations
from keras.models import load_model
import warnings
warnings.filterwarnings('ignore')
def generator_palette(generator, colors, width=256, slots=5):
    """Complete partial palettes by running every slot arrangement through a generator.

    Each input color is painted into one of ``slots`` vertical segments of a
    1 x ``width`` RGB strip (unused segments stay black). This is done once for
    every ordered choice of segments produced by
    ``itertools.permutations(range(slots), len(colors))``. All strips are
    scaled from [0, 255] to [-1, 1] and passed to ``generator.predict`` in a
    single batch. Each prediction is averaged per segment and mapped back to
    [0, 255], giving one color per slot.

    Args:
        generator: Object with a Keras-style ``predict(batch)`` method.
            Assumed (not checked here) to return an array of shape
            (N, H, width, 3) with values roughly in [-1, 1].
        colors: Sequence of (r, g, b) tuples with 0-255 components.
        width: Strip width in pixels (default 256, the original hard-coded value).
        slots: Number of palette slots / segments (default 5).

    Returns:
        One palette per arrangement, each a list of ``slots`` (r, g, b) tuples
        of ints. Returns ``[]`` when ``len(colors) > slots``, because no
        arrangement exists.
    """
    seg = width // slots

    def bounds(i):
        # The last segment absorbs the remainder (256 = 5 * 51 + 1). The same
        # bounds are used for painting and for reading back, so no column is
        # painted but ignored.
        return i * seg, (width if i == slots - 1 else (i + 1) * seg)

    arrangements = list(permutations(range(slots), len(colors)))
    if not arrangements:
        return []

    strips = np.zeros((len(arrangements), 1, width, 3), dtype=np.uint8)
    for strip, arrangement in zip(strips, arrangements):
        # `strip` is a view into `strips`, so these writes fill the batch in place.
        for color, slot in zip(colors, arrangement):
            lo, hi = bounds(slot)
            strip[:, lo:hi, :] = color

    # Scale [0, 255] -> [-1, 1] and predict all arrangements in one call.
    predictions = generator.predict(strips.astype('float32') / 255 * 2 - 1)

    palettes = []
    for pred in predictions:
        palette = []
        for i in range(slots):
            lo, hi = bounds(i)
            # Average over the segment's columns, then take the first row.
            mean = np.mean(pred[:, lo:hi, :], axis=1)[0]
            # Map [-1, 1] back to [0, 255]. Round rather than truncate, so a
            # color survives the scale round trip unchanged.
            rgb = np.rint((mean + 1) / 2 * 255).clip(0, 255).astype('uint8')
            palette.append(tuple(int(c) for c in rgb))
        palettes.append(palette)
    return palettes
def palette2img(palette, width=100, height=20):
    """Render a palette as adjacent equal-width vertical color stripes.

    Args:
        palette: Sequence of (r, g, b) colors. Components may be Python ints
            or numpy integer scalars.
        width: Output image width in pixels (default 100, as originally hard-coded).
        height: Output image height in pixels (default 20).

    Returns:
        A ``width`` x ``height`` RGB PIL image. An empty palette yields an
        all-black image instead of raising ZeroDivisionError.
    """
    canvas = Image.new("RGB", (width, height), (0, 0, 0))
    count = len(palette)
    if count == 0:
        return canvas
    stripe_width = width // count
    for i, color in enumerate(palette):
        left = i * stripe_width
        # The last stripe absorbs the remainder, so no black columns are left
        # on the right when width is not divisible by count (e.g. 3 colors).
        right = width if i == count - 1 else left + stripe_width
        stripe = Image.new("RGB", (right - left, height), tuple(int(c) for c in color))
        canvas.paste(stripe, (left, 0))
    return canvas
def main(argv=None):
    """Complete 1-2 random seed colors into palettes and plot them.

    Usage: color_generator.py <generator_model_path>

    Loads a Keras model from the given path and draws 1 or 2 random RGB seed
    colors. It runs ``generator_palette`` on the seeds and shows the seeds
    (top row) followed by every generated palette in a matplotlib figure.

    Args:
        argv: Command-line arguments excluding the program name; defaults to
            ``sys.argv[1:]``. Only the first argument is used.
    """
    argv = sys.argv[1:] if argv is None else argv
    if not argv:
        # Exit with a usage message instead of a bare IndexError.
        sys.exit("usage: %s <generator_model_path>" % os.path.basename(sys.argv[0]))
    generator = load_model(argv[0])

    # Same sequence of random calls as before: the seed count first, then r, g, b per color.
    colors = [
        tuple(random.randint(0, 255) for _ in range(3))
        for _ in range(random.randint(1, 2))
    ]
    full_palettes = generator_palette(generator, colors)

    rows = [colors] + full_palettes
    plt.figure(figsize=(5, 10))
    for position, palette in enumerate(rows, start=1):
        plt.subplot(len(rows), 1, position)
        plt.imshow(palette2img(palette))
        plt.axis('off')
    plt.show()


if __name__ == '__main__':
    main()