mnncorrect
Batch correction with mutual nearest neighbors
mnncorrect.hpp
#ifndef MNNCORRECT_HPP
#define MNNCORRECT_HPP

#include <algorithm>
#include <vector>
#include <numeric>
#include <stdexcept>
#include <cstddef>
#include <memory> // for std::shared_ptr used in Options::builder.

#include "knncolle/knncolle.hpp"
#include "sanisizer/sanisizer.hpp"

#include "AutomaticOrder.hpp"
#include "restore_input_order.hpp"
#include "utils.hpp"

namespace mnncorrect {

// Options for compute().
template<typename Index_, typename Float_, class Matrix_ = knncolle::Matrix<Index_, Float_> >
struct Options {
    // Number of nearest neighbors used to identify mutual nearest neighbor (MNN) pairs.
    int num_neighbors = 15;

    int num_steps = 1;

    // Algorithm for building the nearest-neighbor search indices.
    // If NULL, a vantage point tree with Euclidean distances is used.
    std::shared_ptr<knncolle::Builder<Index_, Float_, Float_, Matrix_> > builder;

    // Policy for choosing the order in which batches are merged.
    MergePolicy merge_policy = MergePolicy::RSS;

    // Number of threads to use.
    int num_threads = 1;
};

namespace internal {

template<typename Index_, typename Float_, class Matrix_>
void compute(const std::size_t num_dim, const std::vector<Index_>& num_obs, const std::vector<const Float_*>& batches, Float_* const output, const Options<Index_, Float_, Matrix_>& options) {
    auto builder = options.builder;
    if (!builder) {
        // Fall back to a vantage point tree with Euclidean distances if no builder was supplied.
        typedef knncolle::EuclideanDistance<Float_, Float_> Euclidean;
        builder.reset(new knncolle::VptreeBuilder<Index_, Float_, Float_, Matrix_, Euclidean>(std::make_shared<Euclidean>()));
    }

    // Determine the merge order and correct each batch in turn, writing the corrected values to 'output'.
    AutomaticOrder<Index_, Float_, Matrix_> runner(
        num_dim,
        num_obs,
        batches,
        output,
        *builder,
        options.num_neighbors,
        options.num_steps,
        options.merge_policy,
        options.num_threads
    );

    runner.merge();
}

}

// Corrects batches supplied as separate arrays, where batches[b] holds the num_dim * num_obs[b] values for batch 'b'.
template<typename Index_, typename Float_, class Matrix_>
void compute(const std::size_t num_dim, const std::vector<Index_>& num_obs, const std::vector<const Float_*>& batches, Float_* const output, const Options<Index_, Float_, Matrix_>& options) {
    internal::compute(num_dim, num_obs, batches, output, options);
}

// As above, but with all batches stored contiguously in a single array,
// where each observation occupies 'num_dim' consecutive values.
template<typename Index_, typename Float_, class Matrix_>
void compute(const std::size_t num_dim, const std::vector<Index_>& num_obs, const Float_* const input, Float_* const output, const Options<Index_, Float_, Matrix_>& options) {
    std::vector<const Float_*> batches;
    batches.reserve(num_obs.size());

    Index_ accumulated = 0;
    for (const auto n : num_obs) {
        batches.push_back(input + sanisizer::product_unsafe<std::size_t>(accumulated, num_dim));

        // After this check, all internal functions may assume that the total number of observations fits in an Index_.
        accumulated = sanisizer::sum<decltype(I(accumulated))>(accumulated, n);
    }

    compute(num_dim, num_obs, batches, output, options);
}

// As above, but with observations in arbitrary order and a per-observation batch assignment.
template<typename Index_, typename Float_, typename Batch_, class Matrix_>
void compute(const std::size_t num_dim, const Index_ num_obs, const Float_* const input, const Batch_* const batch, Float_* const output, const Options<Index_, Float_, Matrix_>& options) {
    const BatchIndex nbatches = (num_obs ? sanisizer::sum<BatchIndex>(*std::max_element(batch, batch + num_obs), 1) : static_cast<BatchIndex>(0));
    auto sizes = sanisizer::create<std::vector<Index_> >(nbatches);
    for (Index_ o = 0; o < num_obs; ++o) {
        ++sizes[batch[o]];
    }

    // Avoid allocating a temporary buffer if we're already dealing with contiguous batches.
    bool already_sorted = true;
    for (Index_ o = 1; o < num_obs; ++o) {
        if (batch[o] < batch[o-1]) {
            already_sorted = false;
            break;
        }
    }
    if (already_sorted) {
        compute(num_dim, sizes, input, output, options);
        return;
    }

    // Otherwise, copy the observations into a temporary buffer so that each batch is contiguous.
    Index_ accumulated = 0;
    auto offsets = sanisizer::create<std::vector<Index_> >(nbatches);
    std::vector<Float_> tmp(sanisizer::product<typename std::vector<Float_>::size_type>(num_dim, num_obs));
    auto ptrs = sanisizer::create<std::vector<const Float_*> >(nbatches);
    for (BatchIndex b = 0; b < nbatches; ++b) {
        ptrs[b] = tmp.data() + sanisizer::product_unsafe<std::size_t>(accumulated, num_dim);
        offsets[b] = accumulated;
        accumulated += sizes[b]; // this won't overflow as we know that num_obs fits in an Index_.
    }

    for (Index_ o = 0; o < num_obs; ++o) {
        auto& offset = offsets[batch[o]];
        std::copy_n(
            input + sanisizer::product_unsafe<std::size_t>(o, num_dim),
            num_dim,
            tmp.data() + sanisizer::product_unsafe<std::size_t>(offset, num_dim)
        );
        ++offset;
    }

    // Run the correction on the permuted data, then restore the original observation order in 'output'.
    internal::compute(num_dim, sizes, ptrs, output, options);
    internal::restore_input_order(num_dim, sizes, batch, output);
}

}

#endif
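
For orientation, here is a minimal usage sketch for the last overload, which takes a per-observation batch assignment. It assumes that each observation occupies num_dim consecutive values in the input and output arrays, as implied by the pointer arithmetic above; the sizes, option values and include path are placeholders.

#include "mnncorrect.hpp"

#include <cstddef>
#include <vector>

int main() {
    const std::size_t num_dim = 10; // dimensions per observation.
    const int num_obs = 500;        // total number of observations across all batches.

    // Values for all observations, plus a batch assignment for each observation.
    std::vector<double> input(num_dim * static_cast<std::size_t>(num_obs));
    std::vector<int> batch(num_obs); // batch indices in [0, number of batches).
    // ... fill 'input' and 'batch' ...

    mnncorrect::Options<int, double> opt;
    opt.num_neighbors = 20; // override the default of 15.
    opt.num_threads = 4;

    std::vector<double> output(input.size());
    mnncorrect::compute(num_dim, num_obs, input.data(), batch.data(), output.data(), opt);

    // 'output' now holds the corrected values for each observation, in the same order as 'input'.
    return 0;
}

If the observations are already grouped by batch, the overloads that accept per-batch sizes (with a contiguous array or a vector of per-batch pointers) can be used directly, avoiding the temporary reordering buffer.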