mnncorrect
Batch correction with mutual nearest neighbors
Loading...
Searching...
No Matches
compute.hpp
Go to the documentation of this file.
1#ifndef MNNCORRECT_COMPUTE_HPP
2#define MNNCORRECT_COMPUTE_HPP
3
4#include <algorithm>
5#include <vector>
6#include <numeric>
7#include <stdexcept>
8#include <cstddef>
9
10#include "knncolle/knncolle.hpp"
11
12#include "AutomaticOrder.hpp"
13#include "CustomOrder.hpp"
14#include "Options.hpp"
15#include "restore_order.hpp"
16#include "utils.hpp"
17
24namespace mnncorrect {
25
29struct Details {
33 Details() = default;
34
35 Details(std::vector<BatchIndex> merge_order, std::vector<unsigned long long> num_pairs) : merge_order(std::move(merge_order)), num_pairs(std::move(num_pairs)) {}
45 std::vector<BatchIndex> merge_order;
46
51 std::vector<unsigned long long> num_pairs;
52};
53
57namespace internal {
58
59template<typename Index_, typename Float_, class Matrix_>
60Details compute(std::size_t num_dim, const std::vector<Index_>& num_obs, const std::vector<const Float_*>& batches, Float_* output, const Options<Index_, Float_, Matrix_>& options) {
61 auto builder = options.builder;
62 if (!builder) {
64 builder.reset(new knncolle::VptreeBuilder<Index_, Float_, Float_, Matrix_, Euclidean>(std::make_shared<Euclidean>()));
65 }
66
67 if (!options.order.empty()) {
68 CustomOrder<Index_, Float_, Matrix_> runner(num_dim, num_obs, batches, output, *builder, options.num_neighbors, options.order, options.mass_cap, options.num_threads);
69 runner.run(options.num_mads, options.robust_iterations, options.robust_trim);
70 return Details(runner.get_order(), runner.get_num_pairs());
71
72 } else if (options.automatic_order) {
73 AutomaticOrder<Index_, Float_, Matrix_> runner(num_dim, num_obs, batches, output, *builder, options.num_neighbors, options.reference_policy, options.mass_cap, options.num_threads);
74 runner.run(options.num_mads, options.robust_iterations, options.robust_trim);
75 return Details(runner.get_order(), runner.get_num_pairs());
76
77 } else {
78 std::vector<BatchIndex> trivial_order(num_obs.size());
79 std::iota(trivial_order.begin(), trivial_order.end(), static_cast<BatchIndex>(0));
80 CustomOrder<Index_, Float_, Matrix_> runner(num_dim, num_obs, batches, output, *builder, options.num_neighbors, trivial_order, options.mass_cap, options.num_threads);
81 runner.run(options.num_mads, options.robust_iterations, options.robust_trim);
82 return Details(std::move(trivial_order), runner.get_num_pairs());
83 }
84}
85
86}
133template<typename Index_, typename Float_, class Matrix_>
134Details compute(std::size_t num_dim, const std::vector<Index_>& num_obs, const std::vector<const Float_*>& batches, Float_* output, const Options<Index_, Float_, Matrix_>& options) {
135 auto stats = internal::compute(num_dim, num_obs, batches, output, options);
136 internal::restore_order(num_dim, stats.merge_order, num_obs, output);
137 return stats;
138}
139
163template<typename Index_, typename Float_, class Matrix_>
164Details compute(std::size_t num_dim, const std::vector<Index_>& num_obs, const Float_* input, Float_* output, const Options<Index_, Float_, Matrix_>& options) {
165 std::vector<const Float_*> batches;
166 batches.reserve(num_obs.size());
167 for (auto n : num_obs) {
168 batches.push_back(input);
169 input += static_cast<std::size_t>(n) * num_dim; // cast to size_t's to avoid overflow.
170 }
171 return compute(num_dim, num_obs, batches, output, options);
172}
173
197template<typename Index_, typename Float_, typename Batch_, class Matrix_>
198Details compute(std::size_t num_dim, Index_ num_obs, const Float_* input, const Batch_* batch, Float_* output, const Options<Index_, Float_, Matrix_>& options) {
199 const BatchIndex nbatches = (num_obs ? static_cast<BatchIndex>(*std::max_element(batch, batch + num_obs)) + 1 : 0);
200 std::vector<Index_> sizes(nbatches);
201 for (Index_ o = 0; o < num_obs; ++o) {
202 ++sizes[batch[o]];
203 }
204
205 // Avoiding the need to allocate a temporary buffer
206 // if we're already dealing with contiguous batches.
207 bool already_sorted = true;
208 for (Index_ o = 1; o < num_obs; ++o) {
209 if (batch[o] < batch[o-1]) {
210 already_sorted = false;
211 break;
212 }
213 }
214 if (already_sorted) {
215 return compute(num_dim, sizes, input, output, options);
216 }
217
218 std::size_t accumulated = 0; // use size_t to avoid overflow issues during later multiplication.
219 std::vector<std::size_t> offsets(nbatches);
220 for (BatchIndex b = 0; b < nbatches; ++b) {
221 offsets[b] = accumulated;
222 accumulated += sizes[b];
223 }
224
225 // Dumping everything by order into another vector.
226 std::vector<Float_> tmp(num_dim * static_cast<std::size_t>(num_obs)); // cast to size_t to avoid overflow.
227 std::vector<const Float_*> ptrs(nbatches);
228 for (BatchIndex b = 0; b < nbatches; ++b) {
229 ptrs[b] = tmp.data() + offsets[b] * num_dim; // already size_t's, so no need to cast to avoid overflow.
230 }
231
232 for (Index_ o = 0; o < num_obs; ++o) {
233 auto current = input + static_cast<std::size_t>(o) * num_dim; // cast to size_t to avoid overflow.
234 auto& offset = offsets[batch[o]];
235 auto destination = tmp.data() + num_dim * offset; // already size_t's, so no need to cast to avoid overflow.
236 std::copy_n(current, num_dim, destination);
237 ++offset;
238 }
239
240 auto stats = internal::compute(num_dim, sizes, ptrs, output, options);
241 internal::restore_order(num_dim, stats.merge_order, sizes, batch, output);
242 return stats;
243}
244
245}
246
247#endif
Options for MNN correction.
Batch correction with mutual nearest neighbors.
Definition compute.hpp:24
std::size_t BatchIndex
Definition utils.hpp:20
Details compute(std::size_t num_dim, const std::vector< Index_ > &num_obs, const std::vector< const Float_ * > &batches, Float_ *output, const Options< Index_, Float_, Matrix_ > &options)
Definition compute.hpp:134
Correction details from compute().
Definition compute.hpp:29
std::vector< unsigned long long > num_pairs
Definition compute.hpp:51
std::vector< BatchIndex > merge_order
Definition compute.hpp:45
Options for compute().
Definition Options.hpp:23
Index_ mass_cap
Definition Options.hpp:96
int robust_iterations
Definition Options.hpp:77
double robust_trim
Definition Options.hpp:83
std::shared_ptr< knncolle::Builder< Index_, Float_, Float_, Matrix_ > > builder
Definition Options.hpp:44
double num_mads
Definition Options.hpp:38
int num_threads
Definition Options.hpp:102
ReferencePolicy reference_policy
Definition Options.hpp:88
int num_neighbors
Definition Options.hpp:32
std::vector< BatchIndex > order
Definition Options.hpp:58
bool automatic_order
Definition Options.hpp:71
Utilities for MNN correction.