mumosa
Multi-modal analyses of single-cell data
Loading...
Searching...
No Matches
mumosa.hpp
Go to the documentation of this file.
1#ifndef MUMOSA_HPP
2#define MUMOSA_HPP
3
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
11
12#include "knncolle/knncolle.hpp"
13#include "tatami_stats/tatami_stats.hpp"
14#include "sanisizer/sanisizer.hpp"
15
25namespace mumosa {
26
/**
 * @brief Options for compute_distance().
 */
struct Options {
    /**
     * Requested number of nearest neighbors. This is capped to the number of
     * observations via `knncolle::cap_k()` before searching, and the distance
     * to the last (i.e., furthest) returned neighbor is recorded for each cell.
     */
    int num_neighbors{20};

    /**
     * Number of threads handed to `knncolle::parallelize()` for the neighbor search.
     */
    int num_threads{1};
};
43
/**
 * Return a copy of the argument as a plain (non-reference, non-cv-qualified) value.
 * Used as `decltype(I(x))` to declare loop counters with the same type as `x`.
 *
 * @tparam Input_ Type of the argument.
 * @param x Value to copy.
 * @return Copy of `x`.
 */
template<typename Input_>
auto I(Input_ x) -> typename std::remove_cv<typename std::remove_reference<Input_>::type>::type {
    return x;
}
69template<typename Index_, typename Distance_>
70std::pair<Distance_, Distance_> compute_distance(const Index_ num_cells, Distance_* const distances) {
71 const Distance_ med = tatami_stats::medians::direct(distances, num_cells, /* skip_nan = */ false);
72 Distance_ rmsd = 0;
73 for (Index_ i = 0; i < num_cells; ++i) {
74 const auto d = distances[i];
75 rmsd += d * d;
76 }
77 rmsd = std::sqrt(rmsd);
78 return std::make_pair(med, rmsd);
79}
80
94template<typename Index_, typename Input_, typename Distance_>
95std::pair<Distance_, Distance_> compute_distance(const knncolle::Prebuilt<Index_, Input_, Distance_>& prebuilt, const Options& options) {
96 const Index_ nobs = prebuilt.num_observations();
97 const auto capped_k = knncolle::cap_k(options.num_neighbors, nobs);
98 auto dist = sanisizer::create<std::vector<Distance_> >(nobs);
99
100 knncolle::parallelize(options.num_threads, nobs, [&](const int, const Index_ start, const Index_ length) -> void {
101 const auto searcher = prebuilt.initialize();
102 std::vector<Distance_> distances;
103 for (Index_ i = start, end = start + length; i < end; ++i) {
104 searcher->search(i, capped_k, NULL, &distances);
105 if (distances.size()) {
106 dist[i] = distances.back();
107 }
108 }
109 });
110
111 return compute_distance(nobs, dist.data());
112}
113
132template<typename Index_, typename Input_, typename Distance_, class Matrix_ = knncolle::Matrix<Index_, Input_> >
133std::pair<Distance_, Distance_> compute_distance(
134 const std::size_t num_dim,
135 const Index_ num_cells,
136 const Input_* const data,
138 const Options& options)
139{
140 const auto prebuilt = builder.build_unique(knncolle::SimpleMatrix(num_dim, num_cells, data));
141 return compute_distance(*prebuilt, options);
142}
143
/**
 * Compute the factor by which a target modality's embedding should be multiplied
 * so that its neighbor distances match those of a reference modality.
 *
 * The ratio of medians (`first`) is preferred. If either median is zero, the
 * ratio of the `second` summaries is used instead, with two special cases:
 * - a zero target `second` yields +infinity;
 * - otherwise, a zero reference `second` yields zero.
 *
 * @tparam Distance_ Floating-point type of the distances.
 *
 * @param ref Distance summary (median, root-sum-of-squares) for the reference modality.
 * @param target Distance summary for the target modality.
 *
 * @return Scaling factor for the target modality.
 */
template<typename Distance_>
Distance_ compute_scale(const std::pair<Distance_, Distance_>& ref, const std::pair<Distance_, Distance_>& target) {
    // Medians can only be used when neither is zero.
    if (ref.first != 0 && target.first != 0) {
        return ref.first / target.first;
    }

    // Fall back to the secondary summary.
    if (target.second == 0) {
        return std::numeric_limits<Distance_>::infinity();
    }
    if (ref.second == 0) {
        return 0;
    }
    return ref.second / target.second;
}
178
192template<typename Distance_>
193std::vector<Distance_> compute_scale(const std::vector<std::pair<Distance_, Distance_> >& distances) {
194 const auto ndist = distances.size();
195 auto output = sanisizer::create<std::vector<Distance_> >(ndist);
196
197 // Use the first entry with a non-zero RMSD as the reference.
198 bool found_ref = false;
199 decltype(I(ndist)) ref = 0;
200 for (decltype(I(ndist)) e = 0; e < ndist; ++e) {
201 if (distances[e].second) {
202 found_ref = true;
203 ref = e;
204 break;
205 }
206 }
207
208 // If all of them have a zero RMSD, then all scalings are zero, because it doesn't matter.
209 if (found_ref) {
210 const auto& dref = distances[ref];
211 for (decltype(I(ndist)) e = 0; e < ndist; ++e) {
212 output[e] = (e == ref ? static_cast<Distance_>(1) : compute_scale(dref, distances[e]));
213 }
214 }
215
216 return output;
217}
218
240template<typename Index_, typename Input_, typename Scale_, typename Output_>
241void combine_scaled_embeddings(const std::vector<std::size_t>& num_dims, const Index_ num_cells, const std::vector<Input_*>& embeddings, const std::vector<Scale_>& scaling, Output_* const output) {
242 const auto nembed = num_dims.size();
243 if (embeddings.size() != nembed || scaling.size() != nembed) {
244 throw std::runtime_error("'num_dims', 'embeddings' and 'scale' should have the same length");
245 }
246
247 const std::size_t ntotal = std::accumulate(num_dims.begin(), num_dims.end(), static_cast<std::size_t>(0));
248 std::size_t starting_dim = 0;
249
250 for (decltype(I(nembed)) e = 0; e < nembed; ++e) {
251 const auto curdim = num_dims[e];
252 const auto inptr = embeddings[e];
253 const auto s = scaling[e];
254
255 if (std::isinf(s)) {
256 // If the scaling factor is infinite, it implies that the current
257 // embedding is all-zero, so we just fill with zeros, and move on.
258 for (Index_ c = 0; c < num_cells; ++c) {
259 const auto out_offset = sanisizer::nd_offset<std::size_t>(starting_dim, ntotal, c);
260 std::fill_n(output + out_offset, curdim, 0);
261 }
262 } else {
263 for (Index_ c = 0; c < num_cells; ++c) {
264 for (decltype(I(curdim)) d = 0; d < curdim; ++d) {
265 const auto out_offset = sanisizer::nd_offset<std::size_t>(starting_dim + d, ntotal, c);
266 const auto in_offset = sanisizer::nd_offset<std::size_t>(d, curdim, c);
267 output[out_offset] = inptr[in_offset] * s;
268 }
269 }
270 }
271
272 starting_dim += curdim;
273 }
274}
275
276}
277
278#endif
std::unique_ptr< Prebuilt< Index_, Data_, Distance_ > > build_unique(const Matrix_ &data) const
virtual Index_ num_observations() const=0
void parallelize(int num_workers, Task_ num_tasks, Run_ run_task_range)
int cap_k(int k, Index_ num_observations)
Scale multi-modal embeddings to adjust for differences in variance.
Distance_ compute_scale(const std::pair< Distance_, Distance_ > &ref, const std::pair< Distance_, Distance_ > &target)
Definition mumosa.hpp:165
std::pair< Distance_, Distance_ > compute_distance(const Index_ num_cells, Distance_ *const distances)
Definition mumosa.hpp:70
void combine_scaled_embeddings(const std::vector< std::size_t > &num_dims, const Index_ num_cells, const std::vector< Input_ * > &embeddings, const std::vector< Scale_ > &scaling, Output_ *const output)
Definition mumosa.hpp:241
Options for compute_distance().
Definition mumosa.hpp:30
int num_threads
Definition mumosa.hpp:41
int num_neighbors
Definition mumosa.hpp:35