diff --git a/top2vec/Top2Vec.py b/top2vec/Top2Vec.py index 0cdc437..c14b053 100644 --- a/top2vec/Top2Vec.py +++ b/top2vec/Top2Vec.py @@ -26,6 +26,13 @@ except ImportError: _HAVE_CUMAP = False +try: + from cuml.cluster import HDBSCAN as cuHDBSCAN + + _HAVE_CUHDBSCAN = True +except ImportError: + _HAVE_CUHDBSCAN = False + try: import hnswlib @@ -1369,13 +1376,19 @@ def compute_topics(self, 'metric': 'euclidean', 'cluster_selection_method': 'eom'} - cluster = hdbscan.HDBSCAN(**hdbscan_args).fit(umap_embedding) + if gpu_hdbscan and _HAVE_CUHDBSCAN: + cluster = cuHDBSCAN(**hdbscan_args) + labels = cluster.fit_predict(umap_embedding) + + else: + cluster = hdbscan.HDBSCAN(**hdbscan_args).fit(umap_embedding) + labels = cluster.labels_ # calculate topic vectors from dense areas of documents logger.info('Finding topics') # create topic vectors - self._create_topic_vectors(cluster.labels_) + self._create_topic_vectors(labels) # deduplicate topics self._deduplicate_topics(topic_merge_delta)