@@ -30,7 +30,7 @@ Yinyang() = Yinyang(true, 7)
3030Yinyang (auto:: Bool ) = Yinyang (auto, 7 )
3131Yinyang (group_size:: Int ) = Yinyang (true , group_size)
3232
33- function kmeans! (alg:: Yinyang , containers, X, k;
33+ function kmeans! (alg:: Yinyang , containers, X, k, weights ;
3434 n_threads = Threads. nthreads (),
3535 k_init = " k-means++" , max_iters = 300 ,
3636 tol = 1e-6 , verbose = false , init = nothing )
@@ -40,7 +40,7 @@ function kmeans!(alg::Yinyang, containers, X, k;
4040 # create initial groups of centers, step 1 in original paper
4141 initialize (alg, containers, centroids, n_threads)
4242 # construct initial bounds, step 2
43- @parallelize n_threads ncol chunk_initialize (alg, containers, centroids, X)
43+ @parallelize n_threads ncol chunk_initialize (alg, containers, centroids, X, weights )
4444 collect_containers (alg, containers, n_threads)
4545
4646 # update centers and calculate drifts. Step 3.1 of the original paper.
@@ -69,14 +69,14 @@ function kmeans!(alg::Yinyang, containers, X, k;
6969
7070 # push!(containers.debug, [0, 0, 0])
7171 # Core calculation of the Yinyang, 3.2-3.3 steps of the original paper
72- @parallelize n_threads ncol chunk_update_centroids (alg, containers, centroids, X)
72+ @parallelize n_threads ncol chunk_update_centroids (alg, containers, centroids, X, weights )
7373 collect_containers (alg, containers, n_threads)
7474
7575 # update centers and calculate drifts. Step 3.1 of the original paper.
7676 calculate_centroids_movement (alg, containers, centroids)
7777 end
7878
79- @parallelize n_threads ncol sum_of_squares (containers, X, containers. labels, centroids)
79+ @parallelize n_threads ncol sum_of_squares (containers, X, containers. labels, centroids, weights )
8080 totalcost = sum (containers. sum_of_squares)
8181
8282 # Terminate algorithm with the assumption that K-means has converged
@@ -166,16 +166,16 @@ function initialize(alg::Yinyang, containers, centroids, n_threads)
166166 end
167167end
168168
169- function chunk_initialize (alg:: Yinyang , containers, centroids, X, r, idx)
169+ function chunk_initialize (alg:: Yinyang , containers, centroids, X, weights, r, idx)
170170 T = eltype (X)
171171 centroids_cnt = containers. centroids_cnt[idx]
172172 centroids_new = containers. centroids_new[idx]
173173
174174 @inbounds for i in r
175175 label = point_all_centers! (alg, containers, centroids, X, i)
176- centroids_cnt[label] += one (T)
176+ centroids_cnt[label] += isnothing (weights) ? one (T) : weights[i]
177177 for j in axes (X, 1 )
178- centroids_new[j, label] += X[j, i]
178+ centroids_new[j, label] += isnothing (weights) ? X[j, i] : weights[i] * X[j, i]
179179 end
180180 end
181181end
@@ -202,7 +202,7 @@ function calculate_centroids_movement(alg::Yinyang, containers, centroids)
202202 end
203203end
204204
205- function chunk_update_centroids (alg, containers, centroids, X, r, idx)
205+ function chunk_update_centroids (alg, containers, centroids, X, weights, r, idx)
206206 # unpack containers for easier manipulations
207207 centroids_new = containers. centroids_new[idx]
208208 centroids_cnt = containers. centroids_cnt[idx]
@@ -330,11 +330,11 @@ function chunk_update_centroids(alg, containers, centroids, X, r, idx)
330330 ub[i] = ubx
331331 if old_label != label
332332 labels[i] = label
333- centroids_cnt[label] += one (T)
334- centroids_cnt[old_label] -= one (T)
333+ centroids_cnt[label] += isnothing (weights) ? one (T) : weights[i]
334+ centroids_cnt[old_label] -= isnothing (weights) ? one (T) : weights[i]
335335 for j in axes (X, 1 )
336- centroids_new[j, label] += X[j, i]
337- centroids_new[j, old_label] -= X[j, i]
336+ centroids_new[j, label] += isnothing (weights) ? X[j, i] : weights[i] * X[j, i]
337+ centroids_new[j, old_label] -= isnothing (weights) ? X[j, i] : weights[i] * X[j, i]
338338 end
339339 end
340340 end
0 commit comments