mnncorrect
Batch correction with mutual nearest neighbors
mnncorrect.hpp
Go to the documentation of this file.
#ifndef MNNCORRECT_HPP
#define MNNCORRECT_HPP

#include <algorithm>
#include <vector>
#include <numeric>
#include <stdexcept>
#include <cstddef>
#include <memory> // for std::shared_ptr and std::make_shared.

#include "knncolle/knncolle.hpp"

#include "AutomaticOrder.hpp"
#include "restore_input_order.hpp"
#include "utils.hpp"

/**
 * @file mnncorrect.hpp
 * @brief Batch correction with mutual nearest neighbors.
 */

namespace mnncorrect {

/**
 * @brief Options for `compute()`.
 */
template<typename Index_, typename Float_, class Matrix_ = knncolle::Matrix<Index_, Float_> >
struct Options {
    /**
     * Number of neighbors to use in the nearest-neighbor searches,
     * e.g., when identifying mutual nearest neighbor (MNN) pairs.
     */
    int num_neighbors = 15;

    /**
     * Number of steps of the recursive neighbor search used to compute the centers of mass.
     */
    int num_steps = 1;

    /**
     * Algorithm for building the nearest-neighbor search indices.
     * If NULL, an exact vantage point tree search with Euclidean distances is used.
     */
    std::shared_ptr<knncolle::Builder<Index_, Float_, Float_, Matrix_> > builder;

    /**
     * Policy for choosing the order in which batches are merged.
     */
    MergePolicy merge_policy = MergePolicy::RSS;

    /**
     * Number of threads to use.
     */
    int num_threads = 1;
};
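
// For illustration only: a hypothetical sketch of configuring these options
// before calling compute(); the types and values here are made up.
//
//     mnncorrect::Options<int, double> opt;
//     opt.num_neighbors = 30; // trade speed for stability.
//     opt.num_threads = 4;    // parallelize the neighbor searches.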

/**
 * @cond
 */
namespace internal {

template<typename Index_, typename Float_, class Matrix_>
void compute(std::size_t num_dim, const std::vector<Index_>& num_obs, const std::vector<const Float_*>& batches, Float_* output, const Options<Index_, Float_, Matrix_>& options) {
    auto builder = options.builder;
    if (!builder) {
        // Default to an exact search with Euclidean distances.
        typedef knncolle::EuclideanDistance<Float_, Float_> Euclidean;
        builder.reset(new knncolle::VptreeBuilder<Index_, Float_, Float_, Matrix_, Euclidean>(std::make_shared<Euclidean>()));
    }

    AutomaticOrder<Index_, Float_, Matrix_> runner(
        num_dim,
        num_obs,
        batches,
        output,
        *builder,
        options.num_neighbors,
        options.num_steps,
        options.merge_policy,
        options.num_threads
    );

    runner.merge();
}

}
/**
 * @endcond
 */

/**
 * Correct for batch effects using mutual nearest neighbors,
 * given a separate coordinate array for each batch
 * (each observation occupying `num_dim` consecutive values).
 */
template<typename Index_, typename Float_, class Matrix_>
void compute(std::size_t num_dim, const std::vector<Index_>& num_obs, const std::vector<const Float_*>& batches, Float_* output, const Options<Index_, Float_, Matrix_>& options) {
    internal::compute(num_dim, num_obs, batches, output, options);
}
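
// A hypothetical sketch of calling the overload above with two batches in
// separate arrays; all sizes are made up, and the arrays would be filled with
// real embeddings in practice.
//
//     std::size_t num_dim = 10;
//     std::vector<double> b1(100 * num_dim), b2(200 * num_dim);
//     std::vector<const double*> batches{ b1.data(), b2.data() };
//     std::vector<int> num_obs{ 100, 200 };
//     std::vector<double> output(300 * num_dim);
//     mnncorrect::compute(num_dim, num_obs, batches, output.data(), mnncorrect::Options<int, double>());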

/**
 * Overload of `compute()` that accepts a single contiguous array of coordinates,
 * with the batches stored one after another in the same order as `num_obs`.
 */
template<typename Index_, typename Float_, class Matrix_>
void compute(std::size_t num_dim, const std::vector<Index_>& num_obs, const Float_* input, Float_* output, const Options<Index_, Float_, Matrix_>& options) {
    std::vector<const Float_*> batches;
    batches.reserve(num_obs.size());
    for (auto n : num_obs) {
        batches.push_back(input);
        input += static_cast<std::size_t>(n) * num_dim; // cast to size_t to avoid overflow.
    }
    compute(num_dim, num_obs, batches, output, options);
}
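
// A hypothetical sketch for the contiguous overload above, reusing the names
// from the previous sketch; the two batches are concatenated in one array.
//
//     std::vector<double> combined(300 * num_dim); // batch of 100, then batch of 200.
//     mnncorrect::compute(num_dim, num_obs, combined.data(), output.data(), mnncorrect::Options<int, double>());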

/**
 * Overload of `compute()` that accepts a batch assignment vector.
 * Observations are reordered into contiguous batches before correction,
 * and corrected coordinates are reported in `output` in the original input order.
 */
template<typename Index_, typename Float_, typename Batch_, class Matrix_>
void compute(std::size_t num_dim, Index_ num_obs, const Float_* input, const Batch_* batch, Float_* output, const Options<Index_, Float_, Matrix_>& options) {
    const BatchIndex nbatches = (num_obs ? static_cast<BatchIndex>(*std::max_element(batch, batch + num_obs)) + 1 : 0);
    std::vector<Index_> sizes(nbatches);
    for (Index_ o = 0; o < num_obs; ++o) {
        ++sizes[batch[o]];
    }

    // Avoid allocating a temporary buffer if the batches are already contiguous.
    bool already_sorted = true;
    for (Index_ o = 1; o < num_obs; ++o) {
        if (batch[o] < batch[o-1]) {
            already_sorted = false;
            break;
        }
    }
    if (already_sorted) {
        compute(num_dim, sizes, input, output, options);
        return;
    }

    std::size_t accumulated = 0; // use size_t to avoid overflow during later multiplication.
    std::vector<std::size_t> offsets(nbatches);
    for (BatchIndex b = 0; b < nbatches; ++b) {
        offsets[b] = accumulated;
        accumulated += sizes[b];
    }

    // Copying everything into a temporary buffer, ordered by batch.
    std::vector<Float_> tmp(num_dim * static_cast<std::size_t>(num_obs)); // cast to size_t to avoid overflow.
    std::vector<const Float_*> ptrs(nbatches);
    for (BatchIndex b = 0; b < nbatches; ++b) {
        ptrs[b] = tmp.data() + offsets[b] * num_dim; // already size_t, so no cast is needed.
    }

    for (Index_ o = 0; o < num_obs; ++o) {
        auto current = input + static_cast<std::size_t>(o) * num_dim; // cast to size_t to avoid overflow.
        auto& offset = offsets[batch[o]];
        auto destination = tmp.data() + num_dim * offset; // already size_t, so no cast is needed.
        std::copy_n(current, num_dim, destination);
        ++offset;
    }

    internal::compute(num_dim, sizes, ptrs, output, options);
    internal::restore_input_order(num_dim, sizes, batch, output);
}

}

#endif
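
For reference, a minimal end-to-end sketch of the batch-assignment overload above. The data, sizes and option values are hypothetical, and the include path assumes the header is installed under an mnncorrect/ directory:

#include "mnncorrect/mnncorrect.hpp"

#include <cstddef>
#include <vector>

int main() {
    std::size_t num_dim = 10;
    int num_obs = 500;

    // Each observation occupies 'num_dim' consecutive values;
    // this would be filled with real embeddings in practice.
    std::vector<double> input(num_obs * num_dim);

    // Hypothetical batch assignments in [0, 2), in arbitrary order.
    std::vector<int> batch(num_obs);
    for (int o = 0; o < num_obs; ++o) {
        batch[o] = o % 2;
    }

    mnncorrect::Options<int, double> opt;
    opt.num_neighbors = 20;

    // Corrected coordinates are reported in the original input order.
    std::vector<double> output(input.size());
    mnncorrect::compute(num_dim, num_obs, input.data(), batch.data(), output.data(), opt);
    return 0;
}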