Skip to content

Commit c170b06

Browse files
Fix: distance of kmeans bad-defined
1 parent f85c2d0 commit c170b06

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

14.Cluster/KMeans.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def fit(self, X):
2323
while (pre_centers != self.centers).any():
2424
pre_centers = self.centers.copy()
2525
dis = euc_dis(X[:, None, :], self.centers[None, :, :])
26-
cluster = dis.argmax(axis=-1)
26+
cluster = dis.argmin(axis=-1)
2727
for i in range(self.k):
2828
self.centers[i] = X[cluster == i].mean(axis=0)
2929
step += 1
@@ -32,7 +32,7 @@ def fit(self, X):
3232

3333
def predict(self, X):
3434
dis = euc_dis(X[:, None, :], self.centers[None, :, :])
35-
return dis.argmax(axis=-1)
35+
return dis.argmin(axis=-1)
3636

3737
if __name__ == "__main__":
3838
def demonstrate(X, k, desc):
@@ -46,15 +46,17 @@ def demonstrate(X, k, desc):
4646
plt.show()
4747

4848
# -------------------------- Example 1 ----------------------------------------
49-
X = np.array([[0, 0], [0, 1], [1, 0], [2, 2], [2, 1], [1, 2]])
50-
# generate grid-shaped test data
49+
X = np.array([[0, 0], [0, 1], [1, 0], [2, 2], [2, 1], [1, 2]]).astype(float)
5150
demonstrate(X, 2, "Example 1")
5251

5352
# -------------------------- Example 2 ----------------------------------------
5453
X = np.concatenate([
5554
np.random.normal([0, 0], [.3, .3], [100, 2]),
5655
np.random.normal([0, 1], [.3, .3], [100, 2]),
5756
np.random.normal([1, 0], [.3, .3], [100, 2]),
58-
])
59-
# generate grid-shaped test data
60-
demonstrate(X, 3, "Example 2: it is very sensitive to noise")
57+
]).astype(float)
58+
demonstrate(X, 3, "Example 2")
59+
60+
# -------------------------- Example 3 ----------------------------------------
61+
X = np.array([[0, 0], [0, 1], [0, 3]]).astype(float)
62+
demonstrate(X, 2, "Example 3: K-Means doesn't always return the best answer. (try to run multiple times!)")

0 commit comments

Comments
 (0)