mnncorrect
Batch correction with mutual nearest neighbors
Loading...
Searching...
No Matches
utils.hpp
Go to the documentation of this file.
1#ifndef MNNCORRECT_UTILS_HPP
2#define MNNCORRECT_UTILS_HPP
3
4#include <vector>
5#include <algorithm>
6#include <memory>
7#include <cstddef>
8#include <type_traits>
9
10#include "knncolle/knncolle.hpp"
11
12#ifndef MNNCORRECT_CUSTOM_PARALLEL
13#include "subpar/subpar.hpp"
14#endif
15
21namespace mnncorrect {
22
/**
 * Unsigned integer type used for batch indices; an alias for `std::size_t`.
 * (Presumably each input batch is identified by its position in the input - confirm against callers.)
 */
using BatchIndex = std::size_t;
27
/**
 * Policy for merging batches, stored compactly with `char` as the underlying type.
 *
 * NOTE(review): the meanings below are inferred from the enumerator names and
 * should be confirmed against the library's user documentation.
 */
enum class MergePolicy : char {
    INPUT,    // presumably: merge in the order the batches were supplied
    SIZE,     // presumably: prioritize batches by their number of observations
    VARIANCE, // presumably: prioritize batches by their variance
    RSS       // presumably: prioritize batches by a residual sum of squares
};
44
57template<typename Task_, class Run_>
58void parallelize(const int num_workers, const Task_ num_tasks, Run_ run_task_range) {
59#ifndef MNNCORRECT_CUSTOM_PARALLEL
60 // Methods could allocate or throw, so nothrow_ = false is safest.
61 subpar::parallelize_range<false>(num_workers, num_tasks, std::move(run_task_range));
62#else
63 MNNCORRECT_CUSTOM_PARALLEL(num_workers, num_tasks, run_task_range);
64#endif
65}
66
70namespace internal {
71
72template<typename Index_, typename Distance_>
73using NeighborSet = std::vector<std::vector<std::pair<Index_, Distance_> > >;
74
75template<typename Index_, typename Float_>
76struct Corrected {
77 Corrected() = default;
78 Corrected(std::unique_ptr<knncolle::Prebuilt<Index_, Float_, Float_> > index, std::vector<Index_> ids) : index(std::move(index)), ids(std::move(ids)) {}
79 std::unique_ptr<knncolle::Prebuilt<Index_, Float_, Float_> > index;
80 std::vector<Index_> ids;
81};
82
83template<typename Index_, typename Float_>
84struct BatchInfo {
85 Index_ offset, num_obs;
86 std::unique_ptr<knncolle::Prebuilt<Index_, Float_, Float_> > index;
87 std::vector<Corrected<Index_, Float_> > extras;
88};
89
90}
91
/**
 * Return a copy of `x` whose type has no reference or cv-qualifiers.
 *
 * NOTE(review): presumably a helper for pinning an expression to its bare value
 * type (e.g. for template argument deduction) - confirm at call sites.
 *
 * @tparam Input_ Type of the input; normally deduced.
 * @param x Value to copy.
 * @return `x`, converted to `Input_` stripped of references and cv-qualifiers.
 */
template<typename Input_>
auto I(const Input_ x) -> typename std::remove_cv<typename std::remove_reference<Input_>::type>::type {
    return x;
}
100}
101
102#endif
Batch correction with mutual nearest neighbors.
Definition utils.hpp:21
MergePolicy
Definition utils.hpp:43
std::size_t BatchIndex
Definition utils.hpp:26
void parallelize(const int num_workers, const Task_ num_tasks, Run_ run_task_range)
Definition utils.hpp:58
void parallelize_range(int num_workers, Task_ num_tasks, Run_ run_task_range)