Skip to content

Commit

Permalink
Fix: distance of kmeans bad-defined
Browse files Browse the repository at this point in the history
  • Loading branch information
qianmingxue-msft committed Apr 8, 2021
1 parent f85c2d0 commit c170b06
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions 14.Cluster/KMeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def fit(self, X):
while (pre_centers != self.centers).any():
pre_centers = self.centers.copy()
dis = euc_dis(X[:, None, :], self.centers[None, :, :])
cluster = dis.argmax(axis=-1)
cluster = dis.argmin(axis=-1)
for i in range(self.k):
self.centers[i] = X[cluster == i].mean(axis=0)
step += 1
Expand All @@ -32,7 +32,7 @@ def fit(self, X):

def predict(self, X):
dis = euc_dis(X[:, None, :], self.centers[None, :, :])
return dis.argmax(axis=-1)
return dis.argmin(axis=-1)

if __name__ == "__main__":
def demonstrate(X, k, desc):
Expand All @@ -46,15 +46,17 @@ def demonstrate(X, k, desc):
plt.show()

# -------------------------- Example 1 ----------------------------------------
X = np.array([[0, 0], [0, 1], [1, 0], [2, 2], [2, 1], [1, 2]])
# generate grid-shaped test data
X = np.array([[0, 0], [0, 1], [1, 0], [2, 2], [2, 1], [1, 2]]).astype(float)
demonstrate(X, 2, "Example 1")

# -------------------------- Example 2 ----------------------------------------
X = np.concatenate([
np.random.normal([0, 0], [.3, .3], [100, 2]),
np.random.normal([0, 1], [.3, .3], [100, 2]),
np.random.normal([1, 0], [.3, .3], [100, 2]),
])
# generate grid-shaped test data
demonstrate(X, 3, "Example 2: it is very sensitive to noise")
]).astype(float)
demonstrate(X, 3, "Example 2")

# -------------------------- Example 3 ----------------------------------------
X = np.array([[0, 0], [0, 1], [0, 3]]).astype(float)
demonstrate(X, 2, "Example 3: K-Means doesn't always return the best answer. (try to run multiple times!)")

0 comments on commit c170b06

Please sign in to comment.