mnncorrect
Batch correction with mutual nearest neighbors
Loading...
Searching...
No Matches
compute.hpp
Go to the documentation of this file.
1#ifndef MNNCORRECT_COMPUTE_HPP
2#define MNNCORRECT_COMPUTE_HPP
3
4#include <algorithm>
5#include <vector>
6#include <numeric>
7#include <stdexcept>
8#include <cstdint>
9
10#include "knncolle/knncolle.hpp"
11
12#include "AutomaticOrder.hpp"
13#include "CustomOrder.hpp"
14#include "Options.hpp"
15#include "restore_order.hpp"
16
23namespace mnncorrect {
24
/**
 * @brief Correction details from compute().
 */
struct Details {
    /**
     * Default constructor, leaving both vectors empty.
     */
    Details() = default;

    /**
     * @param order Merge order of the batches, see `merge_order`.
     * @param pairs Number of MNN pairs per merge, see `num_pairs`.
     */
    Details(std::vector<size_t> order, std::vector<size_t> pairs) :
        merge_order(std::move(order)), num_pairs(std::move(pairs)) {}

    /**
     * Order in which batches were merged during correction.
     * Each entry is an index into the supplied sequence of batches.
     */
    std::vector<size_t> merge_order;

    /**
     * Number of MNN pairs identified at each merge step.
     * This has one fewer entry than `merge_order` conceptually pairs with;
     * see compute() for details. (NOTE(review): exact alignment of entries
     * to merge steps is defined by the *Order classes — confirm there.)
     */
    std::vector<size_t> num_pairs;
};
52
56namespace internal {
57
58template<typename Dim_, typename Index_, typename Float_>
59Details compute(size_t num_dim, const std::vector<size_t>& num_obs, const std::vector<const Float_*>& batches, Float_* output, const Options<Dim_, Index_, Float_>& options) {
60 auto builder = options.builder;
61 if (!builder) {
63 }
64
65 if (!options.order.empty()) {
66 CustomOrder<Dim_, Index_, Float_> runner(num_dim, num_obs, batches, output, *builder, options.num_neighbors, options.order, options.mass_cap, options.num_threads);
67 runner.run(options.num_mads, options.robust_iterations, options.robust_trim);
68 return Details(runner.get_order(), runner.get_num_pairs());
69
70 } else if (options.automatic_order) {
71 AutomaticOrder<Dim_, Index_, Float_> runner(num_dim, num_obs, batches, output, *builder, options.num_neighbors, options.reference_policy, options.mass_cap, options.num_threads);
72 runner.run(options.num_mads, options.robust_iterations, options.robust_trim);
73 return Details(runner.get_order(), runner.get_num_pairs());
74
75 } else {
76 std::vector<size_t> trivial_order(num_obs.size());
77 std::iota(trivial_order.begin(), trivial_order.end(), 0);
78 CustomOrder<Dim_, Index_, Float_> runner(num_dim, num_obs, batches, output, *builder, options.num_neighbors, trivial_order, options.mass_cap, options.num_threads);
79 runner.run(options.num_mads, options.robust_iterations, options.robust_trim);
80 return Details(std::move(trivial_order), runner.get_num_pairs());
81 }
82}
83
84}
129template<typename Dim_, typename Index_, typename Float_>
130Details compute(size_t num_dim, const std::vector<size_t>& num_obs, const std::vector<const Float_*>& batches, Float_* output, const Options<Dim_, Index_, Float_>& options) {
131 auto stats = internal::compute(num_dim, num_obs, batches, output, options);
132 internal::restore_order(num_dim, stats.merge_order, num_obs, output);
133 return stats;
134}
135
157template<typename Dim_, typename Index_, typename Float_>
158Details compute(size_t num_dim, const std::vector<size_t>& num_obs, const Float_* input, Float_* output, const Options<Dim_, Index_, Float_>& options) {
159 std::vector<const Float_*> batches;
160 batches.reserve(num_obs.size());
161 for (auto n : num_obs) {
162 batches.push_back(input);
163 input += n * num_dim; // already size_t's, so no need to worry about overflow.
164 }
165 return compute(num_dim, num_obs, batches, output, options);
166}
167
/**
 * Batch correction using mutual nearest neighbors, given a per-observation
 * batch assignment vector rather than pre-separated batches.
 *
 * @param num_dim Number of dimensions.
 * @param num_obs Total number of observations across all batches.
 * @param input Pointer to a column-major array of all observations.
 * @param batch Pointer to an array of length `num_obs` of batch assignments;
 * batch IDs are assumed to be non-negative integers in [0, number of batches).
 * @param output Pointer to the output array, filled with corrected values in
 * the same per-observation order as `input`.
 * @param options Further options; `options.builder` must be set.
 *
 * @return Correction details, including the merge order and MNN pair counts.
 */
template<typename Dim_, typename Index_, typename Float_, typename Batch_>
Details compute(size_t num_dim, size_t num_obs, const Float_* input, const Batch_* batch, Float_* output, const Options<Dim_, Index_, Float_>& options) {
    // Number of batches is inferred from the largest batch ID observed.
    const size_t nbatches = (num_obs ? static_cast<size_t>(*std::max_element(batch, batch + num_obs)) + 1 : 0);
    std::vector<size_t> sizes(nbatches);
    for (size_t o = 0; o < num_obs; ++o) {
        ++sizes[batch[o]];
    }

    // Avoiding the need to allocate a temporary buffer
    // if we're already dealing with contiguous batches.
    bool already_sorted = true;
    for (size_t o = 1; o < num_obs; ++o) {
        if (batch[o] < batch[o-1]) {
            already_sorted = false;
            break;
        }
    }
    if (already_sorted) {
        // Delegate to the contiguous-input overload; no reshuffling needed.
        return compute(num_dim, sizes, input, output, options);
    }

    // Prefix sums of batch sizes give the start offset of each batch in the
    // batch-sorted temporary buffer.
    size_t accumulated = 0;
    std::vector<size_t> offsets(nbatches);
    for (size_t b = 0; b < nbatches; ++b) {
        offsets[b] = accumulated;
        accumulated += sizes[b];
    }

    // Dumping everything by order into another vector.
    std::vector<Float_> tmp(num_dim * num_obs);
    std::vector<const Float_*> ptrs(nbatches);
    for (size_t b = 0; b < nbatches; ++b) {
        ptrs[b] = tmp.data() + offsets[b] * num_dim;
    }

    // Scatter each observation to its batch's region. Note that 'offsets'
    // doubles as the per-batch write cursor: each write advances its batch's
    // offset by one, so observations keep their relative order within a batch.
    for (size_t o = 0; o < num_obs; ++o) {
        auto current = input + o * num_dim;
        auto& offset = offsets[batch[o]];
        auto destination = tmp.data() + num_dim * offset; // already size_t's, so no need to cast to avoid overflow.
        std::copy_n(current, num_dim, destination);
        ++offset;
    }

    // Correct on the batch-sorted copy, then permute results back into the
    // original per-observation order using the batch assignments.
    auto stats = internal::compute(num_dim, sizes, ptrs, output, options);
    internal::restore_order(num_dim, stats.merge_order, sizes, batch, output);
    return stats;
}
236
237}
238
239#endif
Options for MNN correction.
Batch correction with mutual nearest neighbors.
Definition compute.hpp:23
Details compute(size_t num_dim, const std::vector< size_t > &num_obs, const std::vector< const Float_ * > &batches, Float_ *output, const Options< Dim_, Index_, Float_ > &options)
Definition compute.hpp:130
Correction details from compute().
Definition compute.hpp:28
std::vector< size_t > num_pairs
Definition compute.hpp:50
std::vector< size_t > merge_order
Definition compute.hpp:44
Options for compute().
Definition Options.hpp:21
int num_threads
Definition Options.hpp:100
std::vector< size_t > order
Definition Options.hpp:56
double robust_trim
Definition Options.hpp:81
double num_mads
Definition Options.hpp:36
ReferencePolicy reference_policy
Definition Options.hpp:86
int num_neighbors
Definition Options.hpp:30
std::shared_ptr< knncolle::Builder< knncolle::SimpleMatrix< Dim_, Index_, Float_ >, Float_ > > builder
Definition Options.hpp:42
size_t mass_cap
Definition Options.hpp:94
bool automatic_order
Definition Options.hpp:69
int robust_iterations
Definition Options.hpp:75