Skip to content

Commit

Permalink
Merge pull request #283 from AzamatB/ab-kmeans-broadcasting-bugfix
Browse files Browse the repository at this point in the history
Fix broadcasting bug in `repick_unused_centers()` for K-means
  • Loading branch information
alyst authored Jan 6, 2025
2 parents 24a30ae + 5f280eb commit a3768b6
Showing 1 changed file with 26 additions and 34 deletions.
60 changes: 26 additions & 34 deletions src/kmeans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
update_centers!(X, weights, assignments, to_update, centers, wcounts)

if !isempty(unused)
repick_unused_centers(X, costs, centers, unused, distance, rng)
repick_unused_centers!(centers, unused, X, costs, distance, rng)
to_update[unused] .= true
end

Expand Down Expand Up @@ -211,18 +211,16 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
wcounts, objv, t, converged)
end

#
# Updates assignments, costs, and counts based on
# an updated (squared) distance matrix
#
# Update point assignments, costs, and cluster counts based on
# an updated (squared) distance matrix
function update_assignments!(dmat::Matrix{<:Real}, # in: distance matrix (k x n)
is_init::Bool, # in: whether it is the initial run
assignments::Vector{Int}, # out: assignment vector (n)
costs::Vector{<:Real}, # out: costs of the resultant assignment (n)
counts::Vector{Int}, # out: # of points assigned to each cluster (k)
to_update::Vector{Bool}, # out: whether a center needs update (k)
unused::Vector{Int} # out: list of centers with no points assigned
)
unused::Vector{Int}, # out: list of centers with no points assigned
)
k, n = size(dmat)

# re-initialize the counting vector
Expand Down Expand Up @@ -272,17 +270,15 @@ function update_assignments!(dmat::Matrix{<:Real}, # in: distance matrix (k
end
end

#
# Update centers based on updated assignments
#
# (specific to the case where points are not weighted)
#
# Update cluster centers and weights to match updated assignments
# (non-weighted points case)
function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n)
weights::Nothing, # in: point weights
assignments::Vector{Int}, # in: assignments (n)
to_update::Vector{Bool}, # in: whether a center needs update (k)
centers::AbstractMatrix{<:AbstractFloat}, # out: updated centers (d x k)
wcounts::Vector{Int}) # out: updated cluster weights (k)
wcounts::Vector{Int}, # out: updated cluster weights (k)
)
d, n = size(X)
k = size(centers, 2)

Expand Down Expand Up @@ -318,18 +314,15 @@ function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d
end
end

#
# Update centers based on updated assignments
#
# (specific to the case where points are weighted)
#
# Update cluster centers and weights to match updated assignments
# (weighted points case)
function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n)
weights::Vector{W}, # in: point weights (n)
assignments::Vector{Int}, # in: assignments (n)
to_update::Vector{Bool}, # in: whether a center needs update (k)
centers::AbstractMatrix{<:Real}, # out: updated centers (d x k)
wcounts::Vector{W} # out: updated cluster weights (k)
) where W<:Real
wcounts::Vector{W}, # out: updated cluster weights (k)
) where W<:Real
d, n = size(X)
k = size(centers, 2)

Expand Down Expand Up @@ -368,26 +361,25 @@ function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n)
end


#
# Re-picks centers that have no points assigned to them.
#
function repick_unused_centers(X::AbstractMatrix{<:Real}, # in: the data matrix (d x n)
costs::Vector{<:Real}, # in: the current assignment costs (n)
centers::AbstractMatrix{<:AbstractFloat}, # out: the centers (d x k)
unused::Vector{Int}, # in: indices of centers to be updated
distance::SemiMetric, # in: function to calculate the distance with
rng::AbstractRNG) # in: RNG object
# Re-pick centers that have no points assigned to them.
function repick_unused_centers!(centers::AbstractMatrix{<:AbstractFloat}, # out: the centers (d x k)
unused::Vector{Int}, # in: indices of the unused centers to be re-picked (at most k)
X::AbstractMatrix{<:Real}, # in: the data matrix (d x n)
costs::Vector{<:Real}, # in: the current assignment costs (n)
distance::SemiMetric, # in: function to calculate the distance with
rng::AbstractRNG, # in: RNG object
)
# pick new centers using a scheme like kmeans++
ds = similar(costs)
tcosts = copy(costs)
tcosts = copy(costs) # temporary costs used as sampling weights
n = size(X, 2)

for i in unused
# sample a point as the new center, with probability proportional to its current cost
j = wsample(rng, 1:n, tcosts)
tcosts[j] = 0
v = view(X, :, j)
centers[:, i] = v
centers[:, i] = v = view(X, :, j)
colwise!(distance, ds, v, X)
tcosts = min(tcosts, ds)
ds[j] = 0 # the calculated distance might not be exactly zero
tcosts .= min.(tcosts, ds)
end
end

0 comments on commit a3768b6

Please sign in to comment.