Skip to content

Commit

Permalink
Merge branch 'rapidsai:branch-24.10' into hdbscan-nndescent
Browse files Browse the repository at this point in the history
  • Loading branch information
jinsolp authored Aug 23, 2024
2 parents 732a06e + 4587f6a commit 517f7ab
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
12 changes: 10 additions & 2 deletions python/cuml/cuml/manifold/umap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ class UMAP(UniversalBase,
smaller than or equal to 50K. Otherwise, runs with nn descent.
build_kwds: dict (optional, default=None)
Build algorithm argument {'nnd_graph_degree': 64, 'nnd_intermediate_graph_degree': 128,
'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True}
'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True,
'nnd_n_clusters': 1}
Note that nnd_n_clusters > 1 will result in batch-building with NN Descent.
Notes
-----
Expand Down Expand Up @@ -440,7 +442,8 @@ class UMAP(UniversalBase,
# https://github.com/rapidsai/cuml/issues/5985
logger.info("build_algo set to brute_force_knn because random_state is given")
self.build_algo ="brute_force_knn"
self.build_algo = build_algo
else:
self.build_algo = build_algo
else:
raise Exception("Invalid build algo: {}. Only support auto, brute_force_knn and nn_descent" % build_algo)

Expand Down Expand Up @@ -488,12 +491,17 @@ class UMAP(UniversalBase,
umap_params.nn_descent_params.max_iterations = <uint64_t> 20
umap_params.nn_descent_params.termination_threshold = <float> 0.0001
umap_params.nn_descent_params.return_distances = <bool> True
umap_params.nn_descent_params.n_clusters = <uint64_t> 1
else:
umap_params.nn_descent_params.graph_degree = <uint64_t> cls.build_kwds.get("nnd_graph_degree", 64)
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> cls.build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.nn_descent_params.max_iterations = <uint64_t> cls.build_kwds.get("nnd_max_iterations", 20)
umap_params.nn_descent_params.termination_threshold = <float> cls.build_kwds.get("nnd_termination_threshold", 0.0001)
umap_params.nn_descent_params.return_distances = <bool> cls.build_kwds.get("nnd_return_distances", True)
if cls.build_kwds.get("nnd_n_clusters", 1) < 1:
logger.info("Negative number of nnd_n_clusters not allowed. Changing nnd_n_clusters to 1")
umap_params.nn_descent_params.n_clusters = <uint64_t> cls.build_kwds.get("nnd_n_clusters", 1)

umap_params.target_weight = <float> cls.target_weight
umap_params.random_state = <uint64_t> cls.random_state
umap_params.deterministic = <bool> cls.deterministic
Expand Down
9 changes: 5 additions & 4 deletions python/cuml/cuml/manifold/umap_utils.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ cdef extern from "cuml/common/callback.hpp" namespace "ML::Internals":

cdef extern from "raft/neighbors/nn_descent_types.hpp" namespace "raft::neighbors::experimental::nn_descent":
cdef struct index_params:
int64_t graph_degree,
int64_t intermediate_graph_degree,
int64_t max_iterations,
uint64_t graph_degree,
uint64_t intermediate_graph_degree,
uint64_t max_iterations,
float termination_threshold,
bool return_distances
bool return_distances,
uint64_t n_clusters,

cdef extern from "cuml/manifold/umapparams.h" namespace "ML":

Expand Down
22 changes: 22 additions & 0 deletions python/cuml/cuml/tests/test_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,3 +838,25 @@ def test_umap_distance_metrics_fit_transform_trust_on_sparse_input(

if umap_learn_supported:
assert array_equal(umap_trust, cuml_trust, 0.05, with_sign=True)


@pytest.mark.parametrize("data_on_host", [True, False])
@pytest.mark.parametrize("num_clusters", [0, 3, 5])
def test_umap_trustworthiness_on_batch_nnd(data_on_host, num_clusters):

digits = datasets.load_digits()

cuml_model = cuUMAP(
n_neighbors=10,
min_dist=0.01,
build_algo="nn_descent",
build_kwds={"nnd_n_clusters": num_clusters},
)

cuml_embedding = cuml_model.fit_transform(
digits.data, convert_dtype=True, data_on_host=data_on_host
)

cuml_trust = trustworthiness(digits.data, cuml_embedding, n_neighbors=10)

assert cuml_trust > 0.9

0 comments on commit 517f7ab

Please sign in to comment.