Skip to content

Commit c3d28fe

Browse files
committed
dbscan
1 parent 8cdd48c commit c3d28fe

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

Clustering/DBSCAN.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import numpy as np
2+
from scipy.spatial.distance import cdist
3+
import matplotlib.pyplot as plt
4+
from sklearn.cluster import DBSCAN
5+
6+
7+
class cDBSCAN(object):
8+
def __init__(self, min_pts=5, epsilon=0.5, metric='euclidean'):
9+
self.min_pts = min_pts
10+
self.epsilon = epsilon
11+
self.metric = metric
12+
13+
def fit(self, X, y=None):
14+
self.fit_predict(X, y)
15+
return self
16+
17+
def predict(self, X):
18+
pass
19+
20+
def fit_predict(self, X, y=None):
21+
n_samples, _ = X.shape
22+
nearin = cdist(X, X, metric=self.metric) <= self.epsilon
23+
near_num = np.sum(nearin, axis=1)
24+
core_ind = set(np.arange(n_samples)[near_num >= self.min_pts])
25+
print(core_ind)
26+
27+
n_clusters = 0
28+
this_set = set(range(n_samples))
29+
clusters = []
30+
31+
while core_ind:
32+
old_set = this_set.copy()
33+
ele = core_ind.pop()
34+
queue = [ele]
35+
this_set.remove(ele)
36+
while queue:
37+
q = queue.pop(0)
38+
if near_num[q] >= self.min_pts:
39+
dlt = this_set.intersection(np.arange(n_samples)[nearin[q]])
40+
queue.extend(dlt)
41+
this_set.difference_update(dlt)
42+
n_clusters += 1
43+
C = old_set.difference(this_set)
44+
clusters.append(C)
45+
core_ind.difference_update(C)
46+
labels = -1 * np.ones(n_samples, dtype=int)
47+
48+
for l, g in enumerate(clusters):
49+
labels[list(g)] = l
50+
self.labels = labels
51+
return labels
52+
53+
if __name__ == "__main__":
54+
np.random.seed(23)
55+
X = np.random.random((40, 2))
56+
cdb = cDBSCAN(epsilon=0.2)
57+
res = cdb.fit_predict(X)
58+
print(res)
59+
db = DBSCAN(eps=0.2)
60+
l = db.fit_predict(X)
61+
print(l)
62+
63+
fig, (a1, a2) = plt.subplots(1, 2)
64+
65+
a1.scatter(X[:, 0], X[:, 1], c=res)
66+
a2.scatter(X[:, 0], X[:, 1], c=l)
67+
plt.show()

0 commit comments

Comments
 (0)