-
-
Notifications
You must be signed in to change notification settings - Fork 23
/
mm_color_cluster.py
89 lines (64 loc) · 1.98 KB
/
mm_color_cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Auto-clustering, suggested by Matt Terry
from skimage import io, color, exposure
from sklearn import cluster, preprocessing
import numpy as np
import matplotlib.pyplot as plt
import os
from urllib.request import urlretrieve

# Location of the sample photograph of M&M's that the script analyses.
url = 'http://blogs.mathworks.com/images/steve/2010/mms.jpg'

# Download only when no local copy exists, so repeated runs reuse 'mm.jpg'.
if not os.path.exists('mm.jpg'):
    print("Downloading M&M's...")
    urlretrieve(url, 'mm.jpg')

print("Image I/O...")
mm = io.imread('mm.jpg')

# Convert to CIELAB and keep only the two chroma channels (a, b), so the
# clustering below is driven by colour and not by lightness.
mm_lab = color.rgb2lab(mm)
ab = mm_lab[..., 1:]

print("Mini-batch K-means...")
# One (a, b) sample per pixel.
X = ab.reshape(-1, 2)
kmeans = cluster.MiniBatchKMeans(n_clusters=6)
y = kmeans.fit(X).labels_
# Fold the flat label vector back into the image's row/column layout.
labels = y.reshape(mm.shape[:2])

# Number of clusters.  Labels are zero-based, so the largest label is one
# less than the cluster count; using labels.max() here would make every
# later ``range(N)`` loop skip the last cluster.
N = kmeans.n_clusters
def no_ticks(ax):
    """Remove the tick marks from both the x and the y axis of ``ax``.

    ``ax`` is any object exposing ``set_xticks`` and ``set_yticks``; each
    is called once with an empty list, x first and then y.
    """
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
# Display all clusters
# Bottom row of the 2 x N grid: one panel for each cluster index in
# range(N).  Pixels outside the cluster have their a/b channels zeroed,
# which leaves only their lightness, so they render without colour.
grid_shape = (2, N)
for i in range(N):
    outside = labels != i
    isolated = mm_lab.copy()
    isolated[outside, 1:] = 0

    panel = plt.subplot2grid(grid_shape, (1, i))
    panel.imshow(color.lab2rgb(isolated))
    no_ticks(panel)

# Top-left, spanning two grid columns: the unmodified input image.
ax = plt.subplot2grid(grid_shape, (0, 0), colspan=2)
ax.imshow(mm)
no_ticks(ax)
# Display histogram
# 2-D histogram of the image's chroma values: how many pixels fall into
# each (a, b) cell.  The lightness channel is not needed here.
_, a, b = mm_lab.T
left, right = -100, 100
# NOTE: np.arange excludes ``right``, so the bin edges run from ``left`` to
# ``right - 1``; samples outside that span are not counted.
bins = np.arange(left, right)
# ``density=True`` is the supported spelling of the former ``normed=True``
# flag, which current NumPy releases reject.
H, x_edges, y_edges = np.histogram2d(a.flatten(), b.flatten(), bins,
                                     density=True)

ax = plt.subplot2grid((2, N), (0, 2))
# Stretch the density range [0, 5e-4] over the output range so that
# sparsely populated cells become visible.
H_bright = exposure.rescale_intensity(H, in_range=(0, 5e-4))
# histogram2d puts its first argument (a) along rows and its second (b)
# along columns, hence 'a' on the vertical axis and 'b' on the horizontal.
ax.imshow(H_bright,
          extent=[left, right, right, left], cmap=plt.cm.gray)
ax.set_title('Histogram')
ax.set_xlabel('b')
ax.set_ylabel('a')
# Voronoi diagram
# Ask the fitted model which cluster the centre of every histogram cell
# belongs to, then draw that partition underneath a faint copy of the
# histogram.
mid_bins = 0.5 + bins[:-1]  # centres of the unit-width bins
L = mid_bins.size
# 'ij' indexing: xx varies along rows (a) and yy along columns (b), which
# matches the row/column layout of the histogram.
xx, yy = np.meshgrid(mid_bins, mid_bins, indexing='ij')
grid_points = np.column_stack((xx.ravel(), yy.ravel()))
Z = kmeans.predict(grid_points).reshape(L, L)

panel_extent = [left, right, right, left]
ax = plt.subplot2grid((2, N), (0, 3))
ax.imshow(Z, extent=panel_extent, interpolation='nearest',
          cmap=plt.cm.Spectral, alpha=0.8)
ax.imshow(H_bright, extent=panel_extent, cmap=plt.cm.gray, alpha=0.2)
ax.set_title('Clustered histogram')
no_ticks(ax)

plt.show()