diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index 355a53f6..0f174a18 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -1,9 +1,9 @@ import hdbscan import numpy as np - + def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): - """ Function used to select the HDBSCAN-like model for generating + """ Function used to select the HDBSCAN-like model for generating predictions and probabilities. Arguments: @@ -42,7 +42,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): return cuml_hdbscan.all_points_membership_vectors(model) return None - + # membership_vector if func == "membership_vector": if isinstance(model, hdbscan.HDBSCAN): @@ -51,8 +51,16 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): str_type_model = str(type(model)).lower() if "cuml" in str_type_model and "hdbscan" in str_type_model: - from cuml.cluster.hdbscan.prediction import approximate_predict - probabilities = approximate_predict(model, embeddings) + from cuml.cluster.hdbscan import prediction + try: + probabilities = prediction.membership_vector( + model, embeddings, + # bacth size cannot be larger than the number of docs + # this will be unnecessary in cuml 23.08 + batch_size=min(embeddings.shape[0], 4096)) + # membership_vector available in cuml 23.04 and up + except AttributeError: + _, probabilities = prediction.approximate_predict(model, embeddings) return probabilities return None diff --git a/tests/conftest.py b/tests/conftest.py index 0418ac5e..57ab7223 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -126,3 +126,19 @@ def online_topic_model(documents, document_embeddings, embedding_model): topics.extend(model.topics_) model.topics_ = topics return model + + +@pytest.fixture(scope="session") +def cuml_base_topic_model(documents, document_embeddings, embedding_model): + try: + from cuml import HDBSCAN as cuml_hdbscan, UMAP as cuml_umap + model = BERTopic(embedding_model=embedding_model, + calculate_probabilities=True, + umap_model=cuml_umap(random_state=42), + hdbscan_model=cuml_hdbscan( + min_cluster_size=3, + prediction_data=True)) + model.fit(documents, document_embeddings) + return model + except ModuleNotFoundError: + return None diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py index be5904e7..8b5952eb 100644 --- a/tests/test_bertopic.py +++ b/tests/test_bertopic.py @@ -3,7 +3,7 @@ from bertopic import BERTopic -@pytest.mark.parametrize('model', [("base_topic_model"), ('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model'), ('supervised_topic_model'), ('representation_topic_model')]) +@pytest.mark.parametrize('model', [("base_topic_model"), ('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model'), ('supervised_topic_model'), ('representation_topic_model'), ('cuml_base_topic_model')]) def test_full_model(model, documents, request): """ Tests the entire pipeline in one go. This serves as a sanity check to see if the default settings result in a good separation of topics. @@ -11,6 +11,9 @@ def test_full_model(model, documents, request): NOTE: This does not cover all cases but merely combines it all together """ topic_model = copy.deepcopy(request.getfixturevalue(model)) + if model == 'cuml_base_topic_model' and topic_model is None: + # cuml not installed, can't run test + return if model == "base_topic_model": topic_model.save("model_dir", serialization="pytorch", save_ctfidf=True, save_embedding_model="sentence-transformers/all-MiniLM-L6-v2") topic_model = BERTopic.load("model_dir") @@ -110,3 +113,15 @@ def test_full_model(model, documents, request): # if topic_model.topic_embeddings_ is not None: # topic_model.save("model_dir", serialization="pytorch", save_ctfidf=True) # loaded_model = BERTopic.load("model_dir") + +def test_cuml(cuml_base_topic_model, documents, request, monkeypatch): + """Specific tests for cuml-based models.""" + + if cuml_base_topic_model is None: + # cuml not installed, can't run test + return + # make sure calculating probabilities does not fail if the cuml version + # does not yet support membership_vector (cuml 23.04 and higher) + with monkeypatch.context() as m: + m.delattr('cuml.cluster.hdbscan.prediction.membership_vector', raising=False) + predictions, probabilities = cuml_base_topic_model.transform(documents)