kmeans
k-means clustering in C++
Loading...
Searching...
No Matches
RefineMiniBatch.hpp
Go to the documentation of this file.
1#ifndef KMEANS_REFINE_MINIBATCH_HPP
2#define KMEANS_REFINE_MINIBATCH_HPP
3
4#include <vector>
5#include <algorithm>
6#include <cstddef>
7#include <random>
8
9#include "sanisizer/sanisizer.hpp"
10#include "aarand/aarand.hpp"
11
12#include "Refine.hpp"
13#include "Details.hpp"
14#include "QuickSearch.hpp"
15#include "is_edge_case.hpp"
16#include "parallelize.hpp"
17
24namespace kmeans {
25
// Random number generator type for mini-batch refinement.
// RefineMiniBatch::run() constructs one instance from the `seed` option and passes it to aarand::sample() when drawing each batch.
29typedef std::mt19937_64 RefineMiniBatchRng;
30
39 int max_iterations = 100;
40
45 int batch_size = 500;
46
51 double max_change_proportion = 0.01;
52
58
62 typename RefineMiniBatchRng::result_type seed = sanisizer::cap<typename RefineMiniBatchRng::result_type>(1234567890);
63
68 int num_threads = 1;
69};
70
99template<typename Index_, typename Data_, typename Cluster_, typename Float_, typename Matrix_ = Matrix<Index_, Data_> >
100class RefineMiniBatch : public Refine<Index_, Data_, Cluster_, Float_, Matrix_> {
101public:
// Constructs the refiner from caller-supplied options.
// `options` is taken by value and moved into the stored `my_options` member.
105 RefineMiniBatch(RefineMiniBatchOptions options) : my_options(std::move(options)) {}
106
// Default constructor: `my_options` is default-initialized, so each option takes the in-class default given in its declaration.
110 RefineMiniBatch() = default;
111
112public:
118 return my_options;
119 }
120
121private:
122 RefineMiniBatchOptions my_options;
123
124public:
/**
 * Refine a set of cluster centers by repeatedly sampling batches of observations.
 *
 * Each iteration draws a batch of observations, assigns every drawn observation to the center returned by
 * `QuickSearch::find()` (presumably the closest one - confirm against QuickSearch.hpp), and then moves that
 * center towards the observation with a running-mean update. Every `convergence_history` iterations, the loop stops
 * early if, for every cluster, the number of reassignments is a small enough fraction of the number of samples.
 * After the loop, all observations are reassigned against the final centers and the centroids are recomputed
 * from those assignments via `internal::compute_centroids()`.
 *
 * @param data Matrix supplying the observations.
 * @param ncenters Number of cluster centers.
 * @param centers Array of center coordinates. It is read before the first batch is assigned, so it must hold the
 * starting centers on entry; it is updated in place.
 * @param clusters Array indexed by observation. Every entry is overwritten by the final assignment pass.
 * NOTE(review): entries are also read before this function has necessarily written them (see the note in the loop).
 *
 * @return A `Details` object built from the per-cluster sizes, the iteration count and a status code.
 * The status is 2 if the loop ran for `max_iterations` without breaking, and 0 otherwise.
 * If `internal::is_edge_case()` is true, the result of `internal::process_edge_case()` is returned instead.
 */
128 Details<Index_> run(const Matrix_& data, const Cluster_ ncenters, Float_* const centers, Cluster_* const clusters) const {
129 const auto nobs = data.num_observations();
// Edge cases are delegated entirely to a shared helper; nothing below runs for them.
130 if (internal::is_edge_case(nobs, ncenters)) {
131 return internal::process_edge_case(data, ncenters, centers, clusters);
132 }
133
// total_sampled is never reset, so it acts as the running-mean denominator across all iterations.
// last_changed/last_sampled are reset after every convergence check that does not terminate the loop.
134 auto total_sampled = sanisizer::create<std::vector<unsigned long long> >(ncenters); // holds the number of sampled observations across iterations, so we need a large integer.
135 auto last_changed = sanisizer::create<std::vector<unsigned long long> >(ncenters); // holds the number of sampled/changed observation for the last few iterations.
136 auto last_sampled = sanisizer::create<std::vector<unsigned long long> >(ncenters);
// Per-observation assignment as it stood just before the current iteration's reassignment.
137 auto previous = sanisizer::create<std::vector<Cluster_> >(nobs);
138
// The batch cannot be larger than the dataset.
139 const I<decltype(nobs)> actual_batch_size = sanisizer::min(nobs, my_options.batch_size);
140 sanisizer::cast<std::size_t>(actual_batch_size); // check that static_cast for new_known_extractor() calls will be safe.
141 auto chosen = sanisizer::create<std::vector<Index_> >(actual_batch_size);
142 RefineMiniBatchRng eng(my_options.seed);
143
144 const auto ndim = data.num_dimensions();
145 internal::QuickSearch<Float_, Cluster_> index(ndim, ncenters);
146
// iter is declared outside the loop because its final value is inspected afterwards to derive the status.
147 I<decltype(my_options.max_iterations)> iter = 0;
148 for (; iter < my_options.max_iterations; ++iter) {
// Draw this iteration's batch of observation indices into `chosen`.
149 aarand::sample(nobs, actual_batch_size, chosen.data(), eng);
// Remember the current assignments of the drawn observations so that changes can be counted below.
// NOTE(review): clusters[o] is read here even for observations that no earlier batch of this call has assigned,
// and the value is later used to index last_sampled/last_changed. This presumably relies on the caller's
// `clusters` array already holding valid cluster indices on entry - confirm against the Refine contract.
150 if (iter > 0) {
151 for (const auto o : chosen) {
152 previous[o] = clusters[o];
153 }
154 }
155
// Rebuild the search index from the current centers, then assign the batch in parallel.
// Each worker handles the [start, start + length) slice of `chosen` with its own extractor and search workspace.
156 index.reset(centers);
157 parallelize(my_options.num_threads, actual_batch_size, [&](const int, const Index_ start, const Index_ length) -> void {
158 auto matwork = data.new_known_extractor(chosen.data() + start, static_cast<std::size_t>(length));
159 auto qswork = index.new_workspace();
160 for (Index_ s = start, end = start + length; s < end; ++s) {
// get_observation() is called once per index; presumably it yields the observations in the order supplied to the extractor - confirm.
161 const auto ptr = matwork->get_observation();
162 clusters[chosen[s]] = index.find(ptr, qswork);
163 }
164 });
165
166 // Updating the means for each cluster.
// This pass is serial: each update modifies `centers` and the per-cluster counter in place.
167 auto work = data.new_known_extractor(chosen.data(), static_cast<std::size_t>(chosen.size()));
168 for (const auto o : chosen) {
169 const auto c = clusters[o];
170 auto& n = total_sampled[c];
171 ++n;
172
173 const auto ocopy = work->get_observation();
174 for (I<decltype(ndim)> d = 0; d < ndim; ++d) {
// Incremental mean: shift the coordinate by 1/n of its distance to the observation, where n counts
// every observation assigned to this cluster so far in this call.
175 auto& curcenter = centers[sanisizer::nd_offset<std::size_t>(d, ndim, c)];
176 curcenter += (static_cast<Float_>(ocopy[d]) - curcenter) / n; // cast to ensure consistent precision regardless of Matrix_::data_type.
177 }
178 }
179
180 // Checking for updates.
// Skipped on the first iteration, where `previous` has not been filled in.
181 if (iter != 0) {
// An unchanged observation adds one sample to its cluster. A changed observation adds one sample and
// one change to both its old and its new cluster.
182 for (const auto o : chosen) {
183 const auto p = previous[o];
184 ++(last_sampled[p]);
185 const auto c = clusters[o];
186 if (p != c) {
187 ++(last_sampled[c]);
188 ++(last_changed[p]);
189 ++(last_changed[c]);
190 }
191 }
192
// The counters accumulate between checks; the check itself only happens when iter is a multiple of convergence_history.
// NOTE(review): there is no guard against convergence_history being zero before the modulo; its declaration is
// not visible in this listing - confirm that a non-zero value is guaranteed.
193 if (iter % my_options.convergence_history == 0) {
194 bool too_many_changes = false;
195 for (Cluster_ c = 0; c < ncenters; ++c) {
// NOTE(review): because the comparison is >=, a cluster with no samples in this window (0 >= 0) also
// counts as having too many changes and prevents early termination - confirm this is intended.
196 if (static_cast<double>(last_changed[c]) >= static_cast<double>(last_sampled[c]) * my_options.max_change_proportion) {
197 too_many_changes = true;
198 break;
199 }
200 }
201
// Breaking here leaves iter below max_iterations, which the status logic after the loop relies on.
202 if (!too_many_changes) {
203 break;
204 }
// Not converged: start a fresh window of counters.
205 std::fill(last_sampled.begin(), last_sampled.end(), 0);
206 std::fill(last_changed.begin(), last_changed.end(), 0);
207 }
208 }
209 }
210
211 // Run through all observations to make sure they have the latest cluster assignments.
212 index.reset(centers);
213 parallelize(my_options.num_threads, nobs, [&](const int, const Index_ start, const Index_ length) -> void {
// This overload takes a start and a length, i.e. a contiguous run of observations rather than an index array.
214 auto matwork = data.new_known_extractor(start, length);
215 auto qswork = index.new_workspace();
216 for (Index_ s = start, end = start + length; s < end; ++s) {
217 const auto ptr = matwork->get_observation();
218 clusters[s] = index.find(ptr, qswork);
219 }
220 });
221
// Count the members of each cluster from the final assignments, then recompute the centroids from them.
222 auto cluster_sizes = sanisizer::create<std::vector<Index_> >(ncenters);
223 for (Index_ o = 0; o < nobs; ++o) {
224 ++cluster_sizes[clusters[o]];
225 }
226 internal::compute_centroids(data, ncenters, centers, clusters, cluster_sizes);
227
// iter equals max_iterations only if the loop was never broken out of, which is reported as status 2.
// After an early break, iter is the zero-based index of the last iteration, so it is incremented to report a count.
228 int status = 0;
229 if (iter == my_options.max_iterations) {
230 status = 2;
231 } else {
232 ++iter; // make it 1-based.
233 }
234 return Details<Index_>(std::move(cluster_sizes), iter, status);
235 }
239};
240
241}
242
243#endif
Report detailed clustering statistics.
Interface for k-means refinement.
Implements the mini-batch algorithm for k-means clustering.
Definition RefineMiniBatch.hpp:100
RefineMiniBatchOptions & get_options()
Definition RefineMiniBatch.hpp:117
RefineMiniBatch(RefineMiniBatchOptions options)
Definition RefineMiniBatch.hpp:105
Interface for k-means refinement algorithms.
Definition Refine.hpp:30
virtual Details< Index_ > run(const Matrix_ &data, Cluster_ num_centers, Float_ *centers, Cluster_ *clusters) const =0
Perform k-means clustering.
Definition compute_wcss.hpp:16
std::mt19937_64 RefineMiniBatchRng
Definition RefineMiniBatch.hpp:29
void parallelize(const int num_workers, const Task_ num_tasks, Run_ run_task_range)
Definition parallelize.hpp:28
Utilities for parallelization.
Additional statistics from the k-means algorithm.
Definition Details.hpp:20
Options for RefineMiniBatch.
Definition RefineMiniBatch.hpp:34
RefineMiniBatchRng::result_type seed
Definition RefineMiniBatch.hpp:62
int max_iterations
Definition RefineMiniBatch.hpp:39
double max_change_proportion
Definition RefineMiniBatch.hpp:51
int convergence_history
Definition RefineMiniBatch.hpp:57
int num_threads
Definition RefineMiniBatch.hpp:68
int batch_size
Definition RefineMiniBatch.hpp:45