Skip to content

Commit

Permalink
Merge pull request #1073 from mglisse/tomato-kde-rescale
Browse files Browse the repository at this point in the history
Rescale values for KDE to avoid underflow
  • Loading branch information
VincentRouvreau authored Jun 17, 2024
2 parents 318c354 + d857587 commit b15c964
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/python/gudhi/clustering/tomato.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
graph_type (str): 'manual', 'knn' or 'radius'. Default is 'knn'.
density_type (str): 'manual', 'DTM', 'logDTM', 'KDE' or 'logKDE'. When you have many points,
'KDE' and 'logKDE' tend to be slower. Default is 'logDTM'.
The values computed for 'DTM' or 'KDE' are not normalized (this does not affect the clustering).
metric (str|Callable): metric used when calculating the distance between instances in a feature array.
Defaults to Minkowski of parameter p.
kde_params (dict): if density_type is 'KDE' or 'logKDE', additional parameters passed directly to
Expand Down Expand Up @@ -224,6 +225,8 @@ def fit(self, X, y=None, weights=None):

weights = KernelDensity(**kde_params).fit(self.points_).score_samples(self.points_)
if self.density_type_ == "KDE":
# First rescale to avoid computing exp(-1000)
weights -= numpy.max(weights)
weights = numpy.exp(weights)

# TODO: do it at the C++ level and/or in parallel if this is too slow?
Expand Down
10 changes: 10 additions & 0 deletions src/python/test/test_tomato.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,13 @@ def test_tomato_1():
assert t.diagram_.size == 0
assert t.max_weight_per_cc_.size == 1
t.plot_diagram()


def test_tomato_kde_underflow():
    """Check that no 'KDE' weight is exactly zero on a high-dimensional input.

    The input is 200 points in a 1000-dimensional space whose only nonzero
    coordinate is the first one: 100 samples from a normal distribution
    centred at -2 and 100 samples from one centred at 2 (both with scale 1).
    After fitting with ``density_type="KDE"``, every entry of ``weights_``
    must be different from 0.
    """
    # A single generator with a fixed seed keeps the test reproducible: if the
    # assertion ever fails, rerunning the test replays the exact same data.
    rng = np.random.default_rng(0)

    # 1D construction with 2 Gaussians, embedded in high dimension
    X = np.zeros((200, 1000))
    X[:100, 0] = rng.normal(-2, 1, 100)
    X[100:, 0] = rng.normal(2, 1, 100)

    t = Tomato(density_type="KDE").fit(X)
    assert (t.weights_ != 0).all()

0 comments on commit b15c964

Please sign in to comment.