Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for simplicial_set_embedding #6043

Merged
21 changes: 21 additions & 0 deletions cpp/include/cuml/manifold/umap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,27 @@ void refine(const raft::handle_t& handle,
UMAPParams* params,
float* embeddings);

/**
* Initializes embeddings and performs a UMAP fit on them, which enables
* iterative fitting without callbacks.
*
* @param[in] handle: raft::handle_t
* @param[in] X: pointer to input array
* @param[in] n: n_samples of input array
* @param[in] d: n_features of input array
* @param[in] graph: pointer to raft::sparse::COO object computed using ML::UMAP::get_graph
* @param[in] params: pointer to ML::UMAPParams object
* @param[out] embeddings: pointer to current embedding with shape n * n_components, stores updated
* embeddings on executing refine
*/
void init_and_refine(const raft::handle_t& handle,
float* X,
int n,
int d,
raft::sparse::COO<float, int>* graph,
UMAPParams* params,
float* embeddings);

/**
* Dense fit
*
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/umap/runner.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,31 @@ void _refine(const raft::handle_t& handle,
value_t* embeddings)
{
cudaStream_t stream = handle.get_stream();
ML::Logger::get().setLevel(params->verbosity);

/**
* Run simplicial set embedding to approximate low-dimensional representation
*/
SimplSetEmbed::run<TPB_X, value_t>(inputs.n, inputs.d, graph, params, embeddings, stream);
}

template <typename value_idx, typename value_t, typename umap_inputs, int TPB_X>
void _init_and_refine(const raft::handle_t& handle,
const umap_inputs& inputs,
UMAPParams* params,
raft::sparse::COO<value_t>* graph,
value_t* embeddings)
{
cudaStream_t stream = handle.get_stream();
ML::Logger::get().setLevel(params->verbosity);

// Initialize embeddings
InitEmbed::run(handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init);

// Run simplicial set embedding
SimplSetEmbed::run<TPB_X, value_t>(inputs.n, inputs.d, graph, params, embeddings, stream);
}

template <typename value_idx, typename value_t, typename umap_inputs, int TPB_X>
void _fit(const raft::handle_t& handle,
const umap_inputs& inputs,
Expand Down
14 changes: 14 additions & 0 deletions cpp/src/umap/umap.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ void refine(const raft::handle_t& handle,
handle, inputs, params, graph, embeddings);
}

void init_and_refine(const raft::handle_t& handle,
float* X,
int n,
int d,
raft::sparse::COO<float>* graph,
UMAPParams* params,
float* embeddings)
{
CUML_LOG_DEBUG("Calling UMAP::init_and_refine() with precomputed KNN");
manifold_dense_inputs_t<float> inputs(X, nullptr, n, d);
UMAPAlgo::_init_and_refine<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
handle, inputs, params, graph, embeddings);
}

void fit(const raft::handle_t& handle,
float* X,
float* y,
Expand Down
96 changes: 70 additions & 26 deletions python/cuml/cuml/manifold/simpl_set.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# distutils: language = c++

import warnings
from cuml.internals.safe_imports import cpu_only_import
np = cpu_only_import('numpy')
from cuml.internals.safe_imports import gpu_only_import
Expand All @@ -26,7 +27,7 @@ from cuml.manifold.umap_utils cimport *
from cuml.manifold.umap_utils import GraphHolder, find_ab_params, \
metric_parsing

from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals.input_utils import input_to_cuml_array, is_array_like
from cuml.internals.array import CumlArray

from pylibraft.common.handle cimport handle_t
Expand Down Expand Up @@ -56,6 +57,14 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP":
UMAPParams* params,
float* embeddings)

void init_and_refine(handle_t &handle,
float* X,
int n,
int d,
COO* cgraph_coo,
UMAPParams* params,
float* embeddings)


def fuzzy_simplicial_set(X,
n_neighbors,
Expand All @@ -73,6 +82,7 @@ def fuzzy_simplicial_set(X,
locally approximating geodesic distance at each point, creating a fuzzy
simplicial set for each such point, and then combining all the local
fuzzy simplicial sets into a global one via a fuzzy union.

Parameters
----------
X: array of shape (n_samples, n_features)
Expand Down Expand Up @@ -212,7 +222,7 @@ def simplicial_set_embedding(
initial_alpha=1.0,
a=None,
b=None,
repulsion_strength=1.0,
gamma=1.0,
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
negative_sample_rate=5,
n_epochs=None,
init="spectral",
Expand All @@ -221,13 +231,15 @@ def simplicial_set_embedding(
metric_kwds=None,
output_metric="euclidean",
output_metric_kwds=None,
repulsion_strength=None,
convert_dtype=True,
verbose=False,
):
"""Perform a fuzzy simplicial set embedding, using a specified
initialisation method and then minimizing the fuzzy set cross entropy
between the 1-skeletons of the high and low dimensional fuzzy simplicial
sets.

Parameters
----------
data: array of shape (n_samples, n_features)
Expand All @@ -244,7 +256,7 @@ def simplicial_set_embedding(
Parameter of differentiable approximation of right adjoint functor
b: float
Parameter of differentiable approximation of right adjoint functor
repulsion_strength: float
gamma: float
Weight to apply to negative samples.
negative_sample_rate: int (optional, default 5)
The number of negative samples to select per positive sample
Expand All @@ -260,7 +272,7 @@ def simplicial_set_embedding(
How to initialize the low dimensional embedding. Options are:
* 'spectral': use a spectral embedding of the fuzzy 1-skeleton
* 'random': assign initial embedding positions at random.
* A numpy array of initial embedding positions.
* An array-like with initial embedding positions.
random_state: numpy RandomState or equivalent
A state capable being used as a numpy random state.
metric: string (default='euclidean').
Expand Down Expand Up @@ -294,9 +306,6 @@ def simplicial_set_embedding(
if output_metric_kwds is None:
output_metric_kwds = {}

if init not in ['spectral', 'random']:
raise Exception("Initialization strategy not supported: %d" % init)

if output_metric not in ['euclidean', 'categorical']:
raise Exception("Invalid output metric: {}" % output_metric)

Expand All @@ -320,17 +329,29 @@ def simplicial_set_embedding(
cdef UMAPParams* umap_params = new UMAPParams()
umap_params.n_components = <int> n_components
umap_params.initial_alpha = <int> initial_alpha
umap_params.a = <int> a
umap_params.b = <int> b
umap_params.repulsion_strength = <float> repulsion_strength
umap_params.a = <float> a
umap_params.b = <float> b

if repulsion_strength:
gamma = repulsion_strength
warnings.simplefilter(action="always", category=FutureWarning)
warnings.warn('Parameter "repulsion_strength" has been'
' deprecated. It will be removed in version 24.12.'
' Please use the "gamma" parameter instead.',
FutureWarning)

umap_params.repulsion_strength = <float> gamma
umap_params.negative_sample_rate = <int> negative_sample_rate
umap_params.n_epochs = <int> n_epochs
if init == 'spectral':
umap_params.init = <int> 1
else: # init == 'random'
umap_params.init = <int> 0
umap_params.random_state = <int> random_state
umap_params.deterministic = <bool> deterministic
if isinstance(init, str):
if init == "random":
umap_params.init = <int> 0
elif init == 'spectral':
umap_params.init = <int> 1
else:
raise ValueError("Invalid initialization strategy")
try:
umap_params.metric = metric_parsing[metric.lower()]
except KeyError:
Expand All @@ -344,7 +365,7 @@ def simplicial_set_embedding(
else: # output_metric == 'categorical'
umap_params.target_metric = MetricType.CATEGORICAL
umap_params.target_weight = <float> output_metric_kwds['p'] \
if 'p' in output_metric_kwds else 0
if 'p' in output_metric_kwds else 0.5
umap_params.verbosity = <int> verbose

X_m, _, _, _ = \
Expand All @@ -365,17 +386,40 @@ def simplicial_set_embedding(
handle,
graph)

embedding = CumlArray.zeros((X_m.shape[0], n_components),
order="C", dtype=np.float32,
index=X_m.index)

refine(handle_[0],
<float*><uintptr_t> X_m.ptr,
<int> X_m.shape[0],
<int> X_m.shape[1],
<COO*> fss_graph.get(),
<UMAPParams*> umap_params,
<float*><uintptr_t> embedding.ptr)
if isinstance(init, str):
if init in ['spectral', 'random']:
embedding = CumlArray.zeros((X_m.shape[0], n_components),
order="C", dtype=np.float32,
index=X_m.index)
init_and_refine(handle_[0],
<float*><uintptr_t> X_m.ptr,
<int> X_m.shape[0],
<int> X_m.shape[1],
<COO*> fss_graph.get(),
<UMAPParams*> umap_params,
<float*><uintptr_t> embedding.ptr)
else:
raise ValueError("Invalid initialization strategy")
elif is_array_like(init):
embedding, _, _, _ = \
input_to_cuml_array(init,
order='C',
convert_to_dtype=(np.float32 if convert_dtype
else None),
check_dtype=np.float32,
check_rows=X_m.shape[0],
check_cols=n_components)
refine(handle_[0],
<float*><uintptr_t> X_m.ptr,
<int> X_m.shape[0],
<int> X_m.shape[1],
<COO*> fss_graph.get(),
<UMAPParams*> umap_params,
<float*><uintptr_t> embedding.ptr)
else:
raise ValueError(
"Initialization not supported. Please provide a valid "
"initialization strategy or a pre-initialized embedding.")

free(umap_params)

Expand Down
12 changes: 6 additions & 6 deletions python/cuml/cuml/tests/test_simpl_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pytest
from cuml.datasets import make_blobs
from cuml.internals.safe_imports import cpu_only_import
from cuml.metrics import trustworthiness

np = cpu_only_import("numpy")
cp = gpu_only_import("cupy")
Expand Down Expand Up @@ -133,7 +134,7 @@ def test_simplicial_set_embedding(
metric = "euclidean"
initial_alpha = 1.0
a, b = UMAP.find_ab_params(1.0, 0.1)
gamma = 0
gamma = 1.0
negative_sample_rate = 5
n_epochs = 500
init = "random"
Expand Down Expand Up @@ -180,7 +181,6 @@ def test_simplicial_set_embedding(
cu_fss_graph = cu_fuzzy_simplicial_set(
X, n_neighbors, random_state, metric
)

cu_embedding = cu_simplicial_set_embedding(
X,
cu_fss_graph,
Expand All @@ -199,7 +199,7 @@ def test_simplicial_set_embedding(
output_metric_kwds=output_metric_kwds,
)

ref_embedding = cp.array(ref_embedding)
assert correctness_dense(
ref_embedding, cu_embedding, rtol=0.1, threshold=0.95
)
ref_t_score = trustworthiness(X, ref_embedding, n_neighbors=n_neighbors)
t_score = trustworthiness(X, cu_embedding, n_neighbors=n_neighbors)
abs_tol = 0.05
assert t_score >= ref_t_score - abs_tol
Loading