@@ -21,15 +21,16 @@ struct Hamerly <: AbstractKMeansAlg end
2121function kmeans! (alg:: Hamerly , containers, X, k;
2222 n_threads = Threads. nthreads (),
2323 k_init = " k-means++" , max_iters = 300 ,
24- tol = 1e-6 , verbose = false , init = nothing )
24+ tol = eltype (X)( 1e-6 ) , verbose = false , init = nothing )
2525 nrow, ncol = size (X)
2626 centroids = init == nothing ? smart_init (X, k, n_threads, init= k_init). centroids : deepcopy (init)
2727
2828 @parallelize n_threads ncol chunk_initialize (alg, containers, centroids, X)
2929
30+ T = eltype (X)
3031 converged = false
3132 niters = 0
32- J_previous = 0.0
33+ J_previous = zero (T)
3334 p = containers. p
3435
3536 # Update centroids & labels with closest members until convergence
@@ -70,35 +71,36 @@ function kmeans!(alg::Hamerly, containers, X, k;
7071 # TODO empty placeholder vectors should be calculated
7172 # TODO Float64 type definitions are too restrictive, should be relaxed
7273 # especially during GPU related development
73- return KmeansResult (centroids, containers. labels, Float64 [], Int[], Float64 [], totalcost, niters, converged)
74+ return KmeansResult (centroids, containers. labels, T [], Int[], T [], totalcost, niters, converged)
7475end
7576
76- function create_containers (alg:: Hamerly , k, nrow, ncol, n_threads)
77+ function create_containers (alg:: Hamerly , X, k, nrow, ncol, n_threads)
78+ T = eltype (X)
7779 lng = n_threads + 1
78- centroids_new = Vector {Array{Float64,2 }} (undef, lng)
79- centroids_cnt = Vector {Vector{Int }} (undef, lng)
80+ centroids_new = Vector {Matrix{T }} (undef, lng)
81+ centroids_cnt = Vector {Vector{T }} (undef, lng)
8082
8183 for i = 1 : lng
82- centroids_new[i] = zeros (nrow, k)
83- centroids_cnt[i] = zeros (k)
84+ centroids_new[i] = zeros (T, nrow, k)
85+ centroids_cnt[i] = zeros (T, k)
8486 end
8587
8688 # Upper bound to the closest center
87- ub = Vector {Float64 } (undef, ncol)
89+ ub = Vector {T } (undef, ncol)
8890
8991 # lower bound to the second closest center
90- lb = Vector {Float64 } (undef, ncol)
92+ lb = Vector {T } (undef, ncol)
9193
9294 labels = zeros (Int, ncol)
9395
9496 # distance that centroid has moved
95- p = Vector {Float64 } (undef, k)
97+ p = Vector {T } (undef, k)
9698
9799 # distance from the center to the closest other center
98- s = Vector {Float64 } (undef, k)
100+ s = Vector {T } (undef, k)
99101
100102 # total_sum_calculation
101- sum_of_squares = Vector {Float64 } (undef, n_threads)
103+ sum_of_squares = Vector {T } (undef, n_threads)
102104
103105 return (
104106 centroids_new = centroids_new,
@@ -118,12 +120,13 @@ end
118120Initial calculation of all bounds and points labeling.
119121"""
120122function chunk_initialize (alg:: Hamerly , containers, centroids, X, r, idx)
123+ T = eltype (X)
121124 centroids_cnt = containers. centroids_cnt[idx]
122125 centroids_new = containers. centroids_new[idx]
123126
124127 @inbounds for i in r
125128 label = point_all_centers! (containers, centroids, X, i)
126- centroids_cnt[label] += 1
129+ centroids_cnt[label] += one (T)
127130 for j in axes (X, 1 )
128131 centroids_new[j, label] += X[j, i]
129132 end
@@ -136,12 +139,13 @@ end
136139Calculates minimum distances from centers to each other.
137140"""
138141function update_containers (:: Hamerly , containers, centroids, n_threads)
142+ T = eltype (centroids)
139143 s = containers. s
140- s .= Inf
144+ s .= T ( Inf )
141145 @inbounds for i in axes (centroids, 2 )
142146 for j in i+ 1 : size (centroids, 2 )
143147 d = distance (centroids, centroids, i, j)
144- d = 0.25 * d
148+ d = T ( 0.25 ) * d
145149 s[i] = s[i] > d ? d : s[i]
146150 s[j] = s[j] > d ? d : s[j]
147151 end
@@ -164,6 +168,7 @@ function chunk_update_centroids(alg::Hamerly, containers, centroids, X, r, idx)
164168 s = containers. s
165169 lb = containers. lb
166170 ub = containers. ub
171+ T = eltype (X)
167172
168173 @inbounds for i in r
169174 # m ← max(s(a(i))/2, l(i))
@@ -178,8 +183,8 @@ function chunk_update_centroids(alg::Hamerly, containers, centroids, X, r, idx)
178183 label_new = point_all_centers! (containers, centroids, X, i)
179184 if label != label_new
180185 labels[i] = label_new
181- centroids_cnt[label_new] += 1
182- centroids_cnt[label] -= 1
186+ centroids_cnt[label_new] += one (T)
187+ centroids_cnt[label] -= one (T)
183188 for j in axes (X, 1 )
184189 centroids_new[j, label_new] += X[j, i]
185190 centroids_new[j, label] -= X[j, i]
@@ -199,9 +204,10 @@ function point_all_centers!(containers, centroids, X, i)
199204 ub = containers. ub
200205 lb = containers. lb
201206 labels = containers. labels
207+ T = eltype (X)
202208
203- min_distance = Inf
204- min_distance2 = Inf
209+ min_distance = T ( Inf )
210+ min_distance2 = T ( Inf )
205211 label = 1
206212 @inbounds for k in axes (centroids, 2 )
207213 dist = distance (X, centroids, i, k)
@@ -230,9 +236,10 @@ in `centroids` and `p` respectively.
230236function move_centers (:: Hamerly , containers, centroids)
231237 centroids_new = containers. centroids_new[end ]
232238 p = containers. p
239+ T = eltype (centroids)
233240
234241 @inbounds for i in axes (centroids, 2 )
235- d = 0.0
242+ d = zero (T)
236243 for j in axes (centroids, 1 )
237244 d += (centroids[j, i] - centroids_new[j, i])^ 2
238245 centroids[j, i] = centroids_new[j, i]
@@ -251,6 +258,7 @@ function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
251258 ub = containers. ub
252259 lb = containers. lb
253260 labels = containers. labels
261+ T = eltype (containers. ub)
254262
255263 # Since bounds are squared distances, `sqrt` is used to make the corresponding estimation, unlike
256264 # the original paper, where the usual metric is used.
@@ -270,11 +278,11 @@ function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
270278 # The same applies to the lower bounds.
271279 @inbounds for i in r
272280 label = labels[i]
273- ub[i] += 2 * sqrt (abs (ub[i] * p[label])) + p[label]
281+ ub[i] += T ( 2 ) * sqrt (abs (ub[i] * p[label])) + p[label]
274282 if r1 == label
275- lb[i] = lb[i] <= pr2 ? 0.0 : lb[i] + pr2 - 2 * sqrt (abs (pr2* lb[i]))
283+ lb[i] = lb[i] <= pr2 ? zero (T) : lb[i] + pr2 - T ( 2 ) * sqrt (abs (pr2* lb[i]))
276284 else
277- lb[i] = lb[i] <= pr1 ? 0.0 : lb[i] + pr1 - 2 * sqrt (abs (pr1* lb[i]))
285+ lb[i] = lb[i] <= pr1 ? zero (T) : lb[i] + pr1 - T ( 2 ) * sqrt (abs (pr1* lb[i]))
278286 end
279287 end
280288end
@@ -284,10 +292,10 @@ end
284292
285293Finds maximum and next after maximum arguments.
286294"""
287- function double_argmax (p)
295+ function double_argmax (p:: AbstractVector{T} ) where T
288296 r1, r2 = 1 , 1
289297 d1 = p[1 ]
290- d2 = - 1.0
298+ d2 = T ( - Inf )
291299 for i in 2 : length (p)
292300 if p[i] > d1
293301 r2 = r1
0 commit comments