Skip to content

Commit 9af947f

Browse files
author
Andrey Oskin
committed
Fixed elkan and yinyang
1 parent 8049b0d commit 9af947f

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

src/elkan.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ kmeans(Elkan(), X, 3) # 3 clusters, Elkan algorithm
1818
"""
1919
struct Elkan <: AbstractKMeansAlg end
2020

21-
function kmeans!(alg::Elkan, containers, X, k;
21+
function kmeans!(alg::Elkan, containers, X, k, weights;
2222
n_threads = Threads.nthreads(),
2323
k_init = "k-means++", max_iters = 300,
2424
tol = eltype(X)(1e-6), verbose = false, init = nothing)
2525
nrow, ncol = size(X)
2626
centroids = init == nothing ? smart_init(X, k, n_threads, init=k_init).centroids : deepcopy(init)
2727

2828
update_containers(alg, containers, centroids, n_threads)
29-
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X)
29+
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights)
3030

3131
T = eltype(X)
3232
converged = false
@@ -37,7 +37,7 @@ function kmeans!(alg::Elkan, containers, X, k;
3737
while niters < max_iters
3838
niters += 1
3939
# Core iteration
40-
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X)
40+
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights)
4141

4242
# Collect distributed containers (such as centroids_new, centroids_cnt)
4343
# in paper it is step 4
@@ -70,7 +70,7 @@ function kmeans!(alg::Elkan, containers, X, k;
7070
J_previous = J
7171
end
7272

73-
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids)
73+
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights)
7474
totalcost = sum(containers.sum_of_squares)
7575

7676
# Terminate algorithm with the assumption that K-means has converged
@@ -127,7 +127,7 @@ function create_containers(alg::Elkan, X, k, nrow, ncol, n_threads)
127127
)
128128
end
129129

130-
function chunk_initialize(::Elkan, containers, centroids, X, r, idx)
130+
function chunk_initialize(::Elkan, containers, centroids, X, weights, r, idx)
131131
ub = containers.ub
132132
lb = containers.lb
133133
centroids_dist = containers.centroids_dist
@@ -153,9 +153,9 @@ function chunk_initialize(::Elkan, containers, centroids, X, r, idx)
153153
end
154154
ub[i] = min_dist
155155
labels[i] = label
156-
centroids_cnt[label] += one(T)
156+
centroids_cnt[label] += isnothing(weights) ? one(T) : weights[i]
157157
for j in axes(X, 1)
158-
centroids_new[j, label] += X[j, i]
158+
centroids_new[j, label] += isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
159159
end
160160
end
161161
end
@@ -188,7 +188,7 @@ function update_containers(::Elkan, containers, centroids, n_threads)
188188
return centroids_dist
189189
end
190190

191-
function chunk_update_centroids(::Elkan, containers, centroids, X, r, idx)
191+
function chunk_update_centroids(::Elkan, containers, centroids, X, weights, r, idx)
192192
# unpack
193193
ub = containers.ub
194194
lb = containers.lb
@@ -231,11 +231,11 @@ function chunk_update_centroids(::Elkan, containers, centroids, X, r, idx)
231231

232232
if label != label_old
233233
labels[i] = label
234-
centroids_cnt[label_old] -= one(T)
235-
centroids_cnt[label] += one(T)
234+
centroids_cnt[label_old] -= isnothing(weights) ? one(T) : weights[i]
235+
centroids_cnt[label] += isnothing(weights) ? one(T) : weights[i]
236236
for j in axes(X, 1)
237-
centroids_new[j, label_old] -= X[j, i]
238-
centroids_new[j, label] += X[j, i]
237+
centroids_new[j, label_old] -= isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
238+
centroids_new[j, label] += isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
239239
end
240240
end
241241
end

src/yinyang.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Yinyang() = Yinyang(true, 7)
3030
Yinyang(auto::Bool) = Yinyang(auto, 7)
3131
Yinyang(group_size::Int) = Yinyang(true, group_size)
3232

33-
function kmeans!(alg::Yinyang, containers, X, k;
33+
function kmeans!(alg::Yinyang, containers, X, k, weights;
3434
n_threads = Threads.nthreads(),
3535
k_init = "k-means++", max_iters = 300,
3636
tol = 1e-6, verbose = false, init = nothing)
@@ -40,7 +40,7 @@ function kmeans!(alg::Yinyang, containers, X, k;
4040
# create initial groups of centers, step 1 in original paper
4141
initialize(alg, containers, centroids, n_threads)
4242
# construct initial bounds, step 2
43-
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X)
43+
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights)
4444
collect_containers(alg, containers, n_threads)
4545

4646
# update centers and calculate drifts. Step 3.1 of the original paper.
@@ -69,14 +69,14 @@ function kmeans!(alg::Yinyang, containers, X, k;
6969

7070
# push!(containers.debug, [0, 0, 0])
7171
# Core calculation of the Yinyang, 3.2-3.3 steps of the original paper
72-
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X)
72+
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights)
7373
collect_containers(alg, containers, n_threads)
7474

7575
# update centers and calculate drifts. Step 3.1 of the original paper.
7676
calculate_centroids_movement(alg, containers, centroids)
7777
end
7878

79-
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids)
79+
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights)
8080
totalcost = sum(containers.sum_of_squares)
8181

8282
# Terminate algorithm with the assumption that K-means has converged
@@ -166,16 +166,16 @@ function initialize(alg::Yinyang, containers, centroids, n_threads)
166166
end
167167
end
168168

169-
function chunk_initialize(alg::Yinyang, containers, centroids, X, r, idx)
169+
function chunk_initialize(alg::Yinyang, containers, centroids, X, weights, r, idx)
170170
T = eltype(X)
171171
centroids_cnt = containers.centroids_cnt[idx]
172172
centroids_new = containers.centroids_new[idx]
173173

174174
@inbounds for i in r
175175
label = point_all_centers!(alg, containers, centroids, X, i)
176-
centroids_cnt[label] += one(T)
176+
centroids_cnt[label] += isnothing(weights) ? one(T) : weights[i]
177177
for j in axes(X, 1)
178-
centroids_new[j, label] += X[j, i]
178+
centroids_new[j, label] += isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
179179
end
180180
end
181181
end
@@ -202,7 +202,7 @@ function calculate_centroids_movement(alg::Yinyang, containers, centroids)
202202
end
203203
end
204204

205-
function chunk_update_centroids(alg, containers, centroids, X, r, idx)
205+
function chunk_update_centroids(alg, containers, centroids, X, weights, r, idx)
206206
# unpack containers for easier manipulations
207207
centroids_new = containers.centroids_new[idx]
208208
centroids_cnt = containers.centroids_cnt[idx]
@@ -330,11 +330,11 @@ function chunk_update_centroids(alg, containers, centroids, X, r, idx)
330330
ub[i] = ubx
331331
if old_label != label
332332
labels[i] = label
333-
centroids_cnt[label] += one(T)
334-
centroids_cnt[old_label] -= one(T)
333+
centroids_cnt[label] += isnothing(weights) ? one(T) : weights[i]
334+
centroids_cnt[old_label] -= isnothing(weights) ? one(T) : weights[i]
335335
for j in axes(X, 1)
336-
centroids_new[j, label] += X[j, i]
337-
centroids_new[j, old_label] -= X[j, i]
336+
centroids_new[j, label] += isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
337+
centroids_new[j, old_label] -= isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
338338
end
339339
end
340340
end

0 commit comments

Comments
 (0)