mnncorrect
Batch correction with mutual nearest neighbors
Loading...
Searching...
No Matches
compute.hpp
Go to the documentation of this file.
1#ifndef MNNCORRECT_COMPUTE_HPP
2#define MNNCORRECT_COMPUTE_HPP
3
4#include <algorithm>
5#include <vector>
6#include <numeric>
7#include <stdexcept>
8#include <cstdint>
9
10#include "knncolle/knncolle.hpp"
11
12#include "AutomaticOrder.hpp"
13#include "CustomOrder.hpp"
14#include "Options.hpp"
15#include "restore_order.hpp"
16
23namespace mnncorrect {
24
/**
 * @brief Correction details returned by `compute()`.
 */
struct Details {
    /**
     * Batch indices in the order reported by the merge runner
     * (or `0, 1, 2, ...` when the batches are processed in their input order).
     */
    std::vector<size_t> merge_order;

    /**
     * Pair counts reported by the merge runner.
     * NOTE(review): presumably the number of MNN pairs found at each merge step - confirm against the runner classes.
     */
    std::vector<size_t> num_pairs;

    /**
     * Creates an instance with empty `merge_order` and `num_pairs`.
     */
    Details() = default;

    /**
     * Takes ownership of the supplied vectors.
     *
     * @param order Contents for `merge_order`.
     * @param pairs Contents for `num_pairs`.
     */
    Details(std::vector<size_t> order, std::vector<size_t> pairs) :
        merge_order(std::move(order)),
        num_pairs(std::move(pairs))
    {}
};
52
56namespace internal {
57
58template<typename Dim_, typename Index_, typename Float_>
59Details compute(size_t num_dim, const std::vector<size_t>& num_obs, const std::vector<const Float_*>& batches, Float_* output, const Options<Dim_, Index_, Float_>& options) {
60 auto builder = options.builder;
61 if (!builder) {
63 }
64
65 if (!options.order.empty()) {
66 { // Running some checks on the 'order' vector.
67 size_t nbatches = num_obs.size();
68 if (options.order.size() != nbatches) {
69 throw std::runtime_error("'order' should have the same length as the number of batches");
70 }
71
72 std::vector<uint8_t> found(nbatches);
73 for (auto o : options.order) {
74 if (o >= nbatches) {
75 throw std::runtime_error("out-of-range batch indices in 'order'");
76 }
77 if (found[o]) {
78 throw std::runtime_error("duplicated batch indices in 'order'");
79 }
80 found[o] = 1;
81 }
82 }
83
84 CustomOrder<Dim_, Index_, Float_> runner(num_dim, num_obs, batches, output, *builder, options.num_neighbors, options.order.data(), options.mass_cap, options.num_threads);
85 runner.run(options.num_mads, options.robust_iterations, options.robust_trim);
86 return Details(runner.get_order(), runner.get_num_pairs());
87
88 } else if (options.automatic_order) {
89 AutomaticOrder<Dim_, Index_, Float_> runner(num_dim, num_obs, batches, output, *builder, options.num_neighbors, options.reference_policy, options.mass_cap, options.num_threads);
90 runner.run(options.num_mads, options.robust_iterations, options.robust_trim);
91 return Details(runner.get_order(), runner.get_num_pairs());
92
93 } else {
94 std::vector<size_t> trivial_order(num_obs.size());
95 std::iota(trivial_order.begin(), trivial_order.end(), 0);
96 CustomOrder<Dim_, Index_, Float_> runner(num_dim, num_obs, batches, output, *builder, options.num_neighbors, trivial_order.data(), options.mass_cap, options.num_threads);
97 runner.run(options.num_mads, options.robust_iterations, options.robust_trim);
98 return Details(std::move(trivial_order), runner.get_num_pairs());
99 }
100}
101
102}
147template<typename Dim_, typename Index_, typename Float_>
148Details compute(size_t num_dim, const std::vector<size_t>& num_obs, const std::vector<const Float_*>& batches, Float_* output, const Options<Dim_, Index_, Float_>& options) {
149 auto stats = internal::compute(num_dim, num_obs, batches, output, options);
150 internal::restore_order(num_dim, stats.merge_order, num_obs, output);
151 return stats;
152}
153
175template<typename Dim_, typename Index_, typename Float_>
176Details compute(size_t num_dim, const std::vector<size_t>& num_obs, const Float_* input, Float_* output, const Options<Dim_, Index_, Float_>& options) {
177 std::vector<const Float_*> batches;
178 batches.reserve(num_obs.size());
179 for (auto n : num_obs) {
180 batches.push_back(input);
181 input += n * num_dim; // already size_t's, so no need to worry about overflow.
182 }
183 return compute(num_dim, num_obs, batches, output, options);
184}
185
207template<typename Dim_, typename Index_, typename Float_, typename Batch_>
208Details compute(size_t num_dim, size_t num_obs, const Float_* input, const Batch_* batch, Float_* output, const Options<Dim_, Index_, Float_>& options) {
209 const size_t nbatches = (num_obs ? static_cast<size_t>(*std::max_element(batch, batch + num_obs)) + 1 : 0);
210 std::vector<size_t> sizes(nbatches);
211 for (size_t o = 0; o < num_obs; ++o) {
212 ++sizes[batch[o]];
213 }
214
215 // Avoiding the need to allocate a temporary buffer
216 // if we're already dealing with contiguous batches.
217 bool already_sorted = true;
218 for (size_t o = 1; o < num_obs; ++o) {
219 if (batch[o] < batch[o-1]) {
220 already_sorted = false;
221 break;
222 }
223 }
224 if (already_sorted) {
225 return compute(num_dim, sizes, input, output, options);
226 }
227
228 size_t accumulated = 0;
229 std::vector<size_t> offsets(nbatches);
230 for (size_t b = 0; b < nbatches; ++b) {
231 offsets[b] = accumulated;
232 accumulated += sizes[b];
233 }
234
235 // Dumping everything by order into another vector.
236 std::vector<Float_> tmp(num_dim * num_obs);
237 std::vector<const Float_*> ptrs(nbatches);
238 for (size_t b = 0; b < nbatches; ++b) {
239 ptrs[b] = tmp.data() + offsets[b] * num_dim;
240 }
241
242 for (size_t o = 0; o < num_obs; ++o) {
243 auto current = input + o * num_dim;
244 auto& offset = offsets[batch[o]];
245 auto destination = tmp.data() + num_dim * offset; // already size_t's, so no need to cast to avoid overflow.
246 std::copy_n(current, num_dim, destination);
247 ++offset;
248 }
249
250 auto stats = internal::compute(num_dim, sizes, ptrs, output, options);
251 internal::restore_order(num_dim, stats.merge_order, sizes, batch, output);
252 return stats;
253}
254
255}
256
257#endif
Options for MNN correction.
Batch correction with mutual nearest neighbors.
Definition compute.hpp:23
Details compute(size_t num_dim, const std::vector< size_t > &num_obs, const std::vector< const Float_ * > &batches, Float_ *output, const Options< Dim_, Index_, Float_ > &options)
Definition compute.hpp:148
Correction details from compute().
Definition compute.hpp:28
std::vector< size_t > num_pairs
Definition compute.hpp:50
std::vector< size_t > merge_order
Definition compute.hpp:44
Options for compute().
Definition Options.hpp:21
int num_threads
Definition Options.hpp:100
std::vector< size_t > order
Definition Options.hpp:56
double robust_trim
Definition Options.hpp:81
double num_mads
Definition Options.hpp:36
ReferencePolicy reference_policy
Definition Options.hpp:86
int num_neighbors
Definition Options.hpp:30
std::shared_ptr< knncolle::Builder< knncolle::SimpleMatrix< Dim_, Index_, Float_ >, Float_ > > builder
Definition Options.hpp:42
size_t mass_cap
Definition Options.hpp:94
bool automatic_order
Definition Options.hpp:69
int robust_iterations
Definition Options.hpp:75