1#ifndef KMEANS_HARTIGAN_WONG_HPP
2#define KMEANS_HARTIGAN_WONG_HPP
9#include "sanisizer/sanisizer.hpp"
13#include "QuickSearch.hpp"
15#include "compute_centroids.hpp"
16#include "is_edge_case.hpp"
61namespace RefineHartiganWong_internal {
/**
 * Per-cluster record of when the cluster was last modified.
 *
 * The record is a pair of (iteration, observation). Three negative sentinel
 * iterations are defined here; any other iteration value is stored verbatim by
 * `set_quick()`. `reset()` treats every value greater than
 * `current_optimal_transfer` as having been written by `set_quick()`.
 *
 * @tparam Index_ Integer type used to index observations.
 */
template<typename Index_>
class UpdateHistory {
private:
    // Observation at which the most recent recorded change happened.
    Index_ my_last_observation = 0;

    // Sentinel iteration values, ordered so that a plain integer comparison
    // ranks "more recent" above "less recent".
    static constexpr int current_optimal_transfer = -1;
    static constexpr int previous_optimal_transfer = -2;
    static constexpr int ancient_history = -3;

    // A fresh history starts at the oldest sentinel, i.e. no change recorded.
    int my_last_iteration = ancient_history;

public:
    /**
     * Ages the record by one step.
     *
     * - A record newer than `current_optimal_transfer` is demoted to
     *   `previous_optimal_transfer` and its observation is set to `total_obs`.
     * - A record equal to `current_optimal_transfer` is demoted to
     *   `previous_optimal_transfer`, keeping its observation.
     * - Anything older becomes `ancient_history`.
     *
     * @param total_obs Value stored as the last observation in the first case.
     */
    void reset(const Index_ total_obs) {
        if (my_last_iteration > current_optimal_transfer) {
            my_last_observation = total_obs;
            my_last_iteration = previous_optimal_transfer;
        } else if (my_last_iteration == current_optimal_transfer) {
            // Keep my_last_observation as it is; only the iteration is demoted.
            my_last_iteration = previous_optimal_transfer;
        } else {
            my_last_iteration = ancient_history;
        }
    }

    /**
     * Stamps the record with `current_optimal_transfer` at observation `obs`.
     */
    void set_optimal(const Index_ obs) {
        my_last_observation = obs;
        my_last_iteration = current_optimal_transfer;
    }

    /**
     * Stamps the record with the caller-supplied iteration `iter` at observation `obs`.
     */
    void set_quick(const int iter, const Index_ obs) {
        my_last_iteration = iter;
        my_last_observation = obs;
    }

public:
    /**
     * @return Whether the recorded change is strictly later than (`iter`, `obs`):
     * a later iteration, or the same iteration with a larger observation.
     */
    bool changed_after(const int iter, const Index_ obs) const {
        if (my_last_iteration == iter) {
            return my_last_observation > obs;
        } else {
            return my_last_iteration > iter;
        }
    }

    /**
     * @return Same as `changed_after()`, except that an identical
     * (`iter`, `obs`) pair also counts.
     */
    bool changed_after_or_at(const int iter, const Index_ obs) const {
        if (my_last_iteration == iter) {
            return my_last_observation >= obs;
        } else {
            return my_last_iteration > iter;
        }
    }

    /**
     * @return Whether the record is later than `obs` within
     * `previous_optimal_transfer`, or belongs to any newer iteration.
     */
    bool is_live(const Index_ obs) const {
        return changed_after(previous_optimal_transfer, obs);
    }
};
194template<
typename Float_,
typename Index_,
typename Cluster_>
197 std::vector<Cluster_> best_destination_cluster;
198 std::vector<Index_> cluster_sizes;
200 std::vector<Float_> loss_multiplier;
201 std::vector<Float_> gain_multiplier;
202 std::vector<Float_> wcss_loss;
204 std::vector<UpdateHistory<Index_> > update_history;
206 Index_ optra_steps_since_last_transfer = 0;
209 Workspace(Index_ nobs, Cluster_ ncenters) :
211 best_destination_cluster(sanisizer::cast<I<decltype(best_destination_cluster.size())> >(nobs)),
212 cluster_sizes(sanisizer::cast<I<decltype(cluster_sizes.size())> >(ncenters)),
213 loss_multiplier(sanisizer::cast<I<decltype(loss_multiplier.size())> >(ncenters)),
214 gain_multiplier(sanisizer::cast<I<decltype(gain_multiplier.size())> >(ncenters)),
215 wcss_loss(sanisizer::cast<I<decltype(wcss_loss.size())> >(nobs)),
216 update_history(sanisizer::cast<I<decltype(update_history.size())> >(ncenters))
/**
 * Squared Euclidean distance between an observation and a cluster center.
 *
 * @tparam Data_ Numeric type of the observation's coordinates.
 * @tparam Float_ Floating-point type of the center and of the result.
 * @param data Pointer to `ndim` coordinates of the observation.
 * @param center Pointer to `ndim` coordinates of the center.
 * @param ndim Number of dimensions; zero yields a distance of zero.
 * @return Sum over dimensions of the squared coordinate differences.
 */
template<typename Data_, typename Float_>
Float_ squared_distance_from_cluster(const Data_* const data, const Float_* const center, const std::size_t ndim) {
    Float_ output = 0;
    // ndim is a std::size_t, so the loop index uses that type directly.
    for (std::size_t d = 0; d < ndim; ++d) {
        // Convert the observation's coordinate first so the subtraction is done in Float_.
        const Float_ delta = static_cast<Float_>(data[d]) - center[d];
        output += delta * delta;
    }
    return output;
}
230template<
class Matrix_,
typename Cluster_,
typename Float_>
231void find_closest_two_centers(
233 const Cluster_ ncenters,
234 const Float_*
const centers,
235 Cluster_*
const best_cluster,
236 std::vector<Cluster_>& best_destination_cluster,
239 const auto ndim = data.num_dimensions();
243 internal::QuickSearch<Float_, Cluster_> index(ndim, ncenters);
244 index.reset(centers);
246 const auto nobs = data.num_observations();
247 parallelize(nthreads, nobs, [&](
const int,
const I<
decltype(nobs)> start,
const I<
decltype(nobs)> length) ->
void {
248 auto matwork = data.new_known_extractor(start, length);
249 auto qswork = index.new_workspace();
250 for (I<
decltype(start)> obs = start, end = start + length; obs < end; ++obs) {
251 const auto optr = matwork->get_observation();
252 const auto res2 = index.find2(optr, qswork);
253 best_cluster[obs] = res2[0];
254 best_destination_cluster[obs] = res2[1];
// Compile-time helper returning a value of type Float_. The visible callers use it as
// the loss multiplier of a cluster whose size is not greater than 1, in place of
// n / (n - 1).
// NOTE(review): the body of this function is not visible in this excerpt; confirm the
// actual constant against the full file.
259template<
typename Float_>
260constexpr Float_ big_number() {
264template<
typename Data_,
typename Index_,
typename Cluster_,
typename Float_>
266 const std::size_t ndim,
267 const Data_*
const obs_ptr,
271 Float_*
const centers,
272 Cluster_*
const best_cluster,
273 Workspace<Float_, Index_, Cluster_>& work)
277 const Float_ al1 = work.cluster_sizes[l1], alw = al1 - 1;
278 const Float_ al2 = work.cluster_sizes[l2], alt = al2 + 1;
280 for (I<
decltype(ndim)> d = 0; d < ndim; ++d) {
281 const Float_ oval = obs_ptr[d];
282 auto& c1 = centers[sanisizer::nd_offset<std::size_t>(d, ndim, l1)];
283 c1 = (c1 * al1 - oval) / alw;
284 auto& c2 = centers[sanisizer::nd_offset<std::size_t>(d, ndim, l2)];
285 c2 = (c2 * al2 + oval) / alt;
288 --work.cluster_sizes[l1];
289 ++work.cluster_sizes[l2];
291 work.gain_multiplier[l1] = alw / al1;
292 work.loss_multiplier[l1] = (alw > 1 ? alw / (alw - 1) : big_number<Float_>());
293 work.loss_multiplier[l2] = alt / al2;
294 work.gain_multiplier[l2] = alt / (alt + 1);
296 best_cluster[obs_id] = l2;
297 work.best_destination_cluster[obs_id] = l1;
// One sweep of the optimal-transfer stage over all observations.
// NOTE(review): several lines of this definition are not visible in this excerpt
// (the last parameter, the update of `l2` inside the lambda, the `else` branches and
// the return statements); the comments below describe only what the visible lines do.
//
// Visible parameters:
//   data         - matrix-like object providing num_observations(), num_dimensions()
//                  and new_known_extractor().
//   work         - workspace whose losses, destinations, histories and step counter are updated.
//   ncenters     - number of clusters.
//   centers      - center coordinates, `ndim` values per cluster.
//   best_cluster - current cluster of every observation.
// `all_live`, used below, is not declared in the visible lines; the visible caller
// passes `(iter == 0)` as the final argument - presumably this flag.
307template<
class Matrix_,
typename Cluster_,
typename Float_>
308bool optimal_transfer(
309 const Matrix_& data, Workspace<Float_, Index<Matrix_>, Cluster_>& work,
310 const Cluster_ ncenters,
311 Float_*
const centers,
312 Cluster_*
const best_cluster,
315 const auto nobs = data.num_observations();
316 const auto ndim = data.num_dimensions();
317 auto matwork = data.new_known_extractor();
319 for (I<
decltype(nobs)> obs = 0; obs < nobs; ++obs) {
// Every visited observation is counted; the counter is zeroed again when a transfer happens.
320 ++work.optra_steps_since_last_transfer;
322 const auto l1 = best_cluster[obs];
// An observation that is the only member of its cluster is not considered for transfer.
323 if (work.cluster_sizes[l1] != 1) {
324 const auto obs_ptr = matwork->get_observation(obs);
// Loss of removing the observation from l1: squared distance to l1's center times l1's loss multiplier.
336 auto& wcss_loss = work.wcss_loss[obs];
337 const auto l1_ptr = centers + sanisizer::product_unsafe<std::size_t>(ndim, l1);
338 wcss_loss = squared_distance_from_cluster(obs_ptr, l1_ptr, ndim) * work.loss_multiplier[l1];
// Starting gain: the same quantity for the currently stored destination cluster.
341 auto l2 = work.best_destination_cluster[obs];
342 const auto original_l2 = l2;
343 const auto l2_ptr = centers + sanisizer::product_unsafe<std::size_t>(ndim, l2);
344 auto wcss_gain = squared_distance_from_cluster(obs_ptr, l2_ptr, ndim) * work.gain_multiplier[l2];
// Helper that evaluates cluster `cen` and keeps its gain when it is smaller than the best so far.
346 const auto update_destination_cluster = [&](
const Cluster_ cen) ->
void {
347 auto cen_ptr = centers + sanisizer::product_unsafe<std::size_t>(ndim, cen);
348 auto candidate = squared_distance_from_cluster(obs_ptr, cen_ptr, ndim) * work.gain_multiplier[cen];
349 if (candidate < wcss_gain) {
350 wcss_gain = candidate;
// When all_live is set or l1 is live for this observation, every cluster other than
// l1 and the original destination is evaluated.
368 if (all_live || work.update_history[l1].is_live(obs)) {
369 for (Cluster_ cen = 0; cen < ncenters; ++cen) {
370 if (cen != l1 && cen != original_l2) {
371 update_destination_cluster(cen);
// Second scan: the same exclusions, additionally restricted to clusters that are live.
375 for (Cluster_ cen = 0; cen < ncenters; ++cen) {
376 if (cen != l1 && cen != original_l2 && work.update_history[cen].is_live(obs)) {
377 update_destination_cluster(cen);
// When the best gain is not below the loss, only the destination is stored.
383 if (wcss_gain >= wcss_loss) {
384 work.best_destination_cluster[obs] = l2;
// Transfer path: zero the step counter, stamp both clusters' histories and move the point.
386 work.optra_steps_since_last_transfer = 0;
387 work.update_history[l1].set_optimal(obs);
388 work.update_history[l2].set_optimal(obs);
389 transfer_point(ndim, obs_ptr, obs, l1, l2, centers, best_cluster, work);
// True once nobs consecutive observations have been visited without a transfer.
395 if (work.optra_steps_since_last_transfer == nobs) {
// Quick-transfer stage: repeated sweeps in which each observation is only compared
// against its stored destination cluster.
// NOTE(review): some lines of this definition are not visible in this excerpt (the
// first parameter, the statement that updates `had_transfer`, and the closing
// braces); the comments below describe only what the visible lines do.
//
// Visible parameters:
//   work             - workspace whose losses and histories are updated.
//   centers          - center coordinates, `ndim` values per cluster.
//   best_cluster     - current cluster of every observation.
//   quick_iterations - upper bound on the number of sweeps.
// Returns a pair whose first member is `had_transfer`. The second member is `false`
// on the return that follows the "nobs steps without a transfer" check and `true` on
// the other visible return.
414template<
class Matrix_,
typename Cluster_,
typename Float_>
415std::pair<bool, bool> quick_transfer(
417 Workspace<Float_, Index<Matrix_>, Cluster_>& work,
418 Float_*
const centers,
419 Cluster_*
const best_cluster,
420 const int quick_iterations)
422 bool had_transfer =
false;
424 const auto nobs = data.num_observations();
425 const auto ndim = data.num_dimensions();
426 auto matwork = data.new_known_extractor();
// Observations visited since the last transfer made by this stage.
428 I<
decltype(nobs)> steps_since_last_quick_transfer = 0;
430 for (
int it = 0; it < quick_iterations; ++it) {
// History comparisons below are made against the previous sweep (-1 on the first sweep).
431 const int prev_it = it - 1;
433 for (I<
decltype(nobs)> obs = 0; obs < nobs; ++obs) {
434 ++steps_since_last_quick_transfer;
435 const auto l1 = best_cluster[obs];
// An observation that is the only member of its cluster is not considered for transfer.
437 if (work.cluster_sizes[l1] != 1) {
// The observation is fetched lazily; NULL marks "not fetched yet".
438 I<
decltype(matwork->get_observation(obs))> obs_ptr = NULL;
// The stored loss is refreshed only if l1 changed at or after (prev_it, obs).
447 auto& history1 = work.update_history[l1];
448 if (history1.changed_after_or_at(prev_it, obs)) {
449 const auto l1_ptr = centers + sanisizer::product_unsafe<std::size_t>(l1, ndim);
450 obs_ptr = matwork->get_observation(obs);
451 work.wcss_loss[obs] = squared_distance_from_cluster(obs_ptr, l1_ptr, ndim) * work.loss_multiplier[l1];
// The gain is evaluated only if either cluster changed after (prev_it, obs).
458 const auto l2 = work.best_destination_cluster[obs];
459 auto& history2 = work.update_history[l2];
460 if (history1.changed_after(prev_it, obs) || history2.changed_after(prev_it, obs)) {
461 if (obs_ptr == NULL) {
462 obs_ptr = matwork->get_observation(obs);
464 const auto l2_ptr = centers + sanisizer::product_unsafe<std::size_t>(l2, ndim);
465 const auto wcss_gain = squared_distance_from_cluster(obs_ptr, l2_ptr, ndim) * work.gain_multiplier[l2];
// Transfer when the gain is strictly below the loss: zero the counter, stamp both
// histories with the current sweep and move the point.
467 if (wcss_gain < work.wcss_loss[obs]) {
469 steps_since_last_quick_transfer = 0;
470 history1.set_quick(it, obs);
471 history2.set_quick(it, obs);
472 transfer_point(ndim, obs_ptr, obs, l1, l2, centers, best_cluster, work);
// nobs consecutive observations were visited without a transfer.
477 if (steps_since_last_quick_transfer == nobs) {
480 return std::make_pair(had_transfer,
false);
485 return std::make_pair(had_transfer,
true);
// NOTE(review): the template header below precedes a declaration whose own line is
// not visible in this excerpt; `run()` is a const member function, and the iteration
// loop around the transfer stages is also not visible. The comments describe only
// what the visible lines do.
528template<
typename Index_,
typename Data_,
typename Cluster_,
typename Float_,
class Matrix_ = Matrix<Index_, Data_> >
// Refines `centers` and `clusters` in place and returns a Details object.
//   data     - input matrix.
//   ncenters - number of clusters.
//   centers  - center coordinates, read and updated.
//   clusters - per-observation cluster assignment, written.
557 Details<Index_> run(
const Matrix_& data,
const Cluster_ ncenters, Float_*
const centers, Cluster_*
const clusters)
const {
558 const auto nobs = data.num_observations();
// Inputs flagged by is_edge_case() are delegated to process_edge_case().
559 if (internal::is_edge_case(nobs, ncenters)) {
560 return internal::process_edge_case(data, ncenters, centers, clusters);
563 RefineHartiganWong_internal::Workspace<Float_, Index_, Cluster_> work(nobs, ncenters);
// Fills `clusters` and `work.best_destination_cluster` for every observation.
565 RefineHartiganWong_internal::find_closest_two_centers(data, ncenters, centers, clusters, work.best_destination_cluster, my_options.
num_threads);
// Tally the number of observations assigned to each cluster.
566 for (Index_ obs = 0; obs < nobs; ++obs) {
567 ++work.cluster_sizes[clusters[obs]];
569 internal::compute_centroids(data, ncenters, centers, clusters, work.cluster_sizes);
// Initial multipliers per cluster of size n: gain n / (n + 1), loss n / (n - 1),
// with big_number() used for the loss when n is not greater than 1.
571 for (Cluster_ cen = 0; cen < ncenters; ++cen) {
572 const Float_ num = work.cluster_sizes[cen];
573 work.gain_multiplier[cen] = num / (num + 1);
574 work.loss_multiplier[cen] = (num > 1 ? num / (num - 1) : RefineHartiganWong_internal::big_number<Float_>());
// The final argument is true only when `iter` is 0.
580 const bool finished = RefineHartiganWong_internal::optimal_transfer(data, work, ncenters, centers, clusters, (iter == 0));
585 const auto quick_status = RefineHartiganWong_internal::quick_transfer(
// Centroids are recomputed from the current assignments and sizes.
599 internal::compute_centroids(data, ncenters, centers, clusters, work.cluster_sizes);
601 if (quick_status.second) {
// If the quick-transfer stage reported a transfer, the optimal-transfer step counter is zeroed.
614 if (quick_status.first) {
615 work.optra_steps_since_last_transfer = 0;
// Age every cluster's update history.
618 for (Cluster_ c = 0; c < ncenters; ++c) {
619 work.update_history[c].reset(nobs);
// The cluster sizes are moved out of the workspace into the returned Details.
629 return Details(std::move(work.cluster_sizes), iter, ifault);
Report detailed clustering statistics.
Interface for k-means refinement.
Implements the Hartigan-Wong algorithm for k-means clustering.
Definition RefineHartiganWong.hpp:529
RefineHartiganWongOptions & get_options()
Definition RefineHartiganWong.hpp:549
RefineHartiganWong(RefineHartiganWongOptions options)
Definition RefineHartiganWong.hpp:534
RefineHartiganWong()=default
Interface for k-means refinement algorithms.
Definition Refine.hpp:30
virtual Details< Index_ > run(const Matrix_ &data, Cluster_ num_centers, Float_ *centers, Cluster_ *clusters) const =0
Perform k-means clustering.
Definition compute_wcss.hpp:16
void parallelize(const int num_workers, const Task_ num_tasks, Run_ run_task_range)
Definition parallelize.hpp:28
Utilities for parallelization.
Additional statistics from the k-means algorithm.
Definition Details.hpp:20
Options for RefineHartiganWong.
Definition RefineHartiganWong.hpp:30
bool quit_on_quick_transfer_convergence_failure
Definition RefineHartiganWong.hpp:49
int max_iterations
Definition RefineHartiganWong.hpp:35
int num_threads
Definition RefineHartiganWong.hpp:55
int max_quick_transfer_iterations
Definition RefineHartiganWong.hpp:41