scran_aggregate
Aggregate expression values across cells
Loading...
Searching...
No Matches
aggregate_across_cells.hpp
Go to the documentation of this file.
1#ifndef SCRAN_AGGREGATE_AGGREGATE_ACROSS_CELLS_HPP
2#define SCRAN_AGGREGATE_AGGREGATE_ACROSS_CELLS_HPP
3
4#include <algorithm>
5#include <vector>
6#include <cstddef>
7#include <type_traits>
8#include <cassert>
#include <optional>
9
10#include "tatami/tatami.hpp"
11#include "tatami_stats/tatami_stats.hpp"
12#include "sanisizer/sanisizer.hpp"
13
14#include "utils.hpp"
15
21namespace scran_aggregate {
22
// Whether to compute, for each row, the sum of values over the columns assigned to each group.
// The results-returning overload only allocates sum vectors when this is true.
31 bool compute_sums = true;
32
// Whether to count, for each row, the columns in each group whose value is greater than zero.
37 bool compute_detected = true;
38
// Whether to compute, for each row, the median of values over the columns in each group.
// Medians are only computed by the row-wise implementation, so requesting them (with at
// least one group) makes aggregate_across_cells() iterate by row even if the matrix
// prefers column access.
43 bool compute_medians = false; // false by default as we usually don't need this.
44
// Number of workers passed to tatami::parallelize().
49 int num_threads = 1;
50};
51
// Caller-provided output buffers for aggregate_across_cells().
// Each vector holds one pointer per group. Each pointer refers to an array with one
// entry per row of the input matrix, and entry x receives that group's statistic for row x.
// An empty vector means the corresponding statistic is not computed.
// A non-empty vector must be longer than every value in `group`, because group IDs
// are used directly as indices into the per-group accumulators.
60template <typename Sum_, typename Detected_, typename Float_>
// sums[g][x]: sum of row x's values over the columns assigned to group g.
69 std::vector<Sum_*> sums;
70
// detected[g][x]: number of columns in group g whose value in row x is greater than zero.
78 std::vector<Detected_*> detected;
79
// medians[g][x]: median of row x's values over the columns assigned to group g.
// A non-empty vector forces the row-wise implementation.
87 std::vector<Float_*> medians;
88};
89
// Results returned by the aggregate_across_cells() overload that allocates its own output.
// Each outer vector has one inner vector per group (one more than the largest group ID),
// and each inner vector has one entry per row of the input matrix.
// Outer vectors for statistics not requested in AggregateAcrossCellsOptions are left empty.
98template <typename Sum_, typename Detected_, typename Float_>
// sums[g][x]: sum of row x's values over the columns assigned to group g.
107 std::vector<std::vector<Sum_> > sums;
108
// detected[g][x]: number of columns in group g whose value in row x is greater than zero.
116 std::vector<std::vector<Detected_> > detected;
117
// medians[g][x]: median of row x's values over the columns assigned to group g.
125 std::vector<std::vector<Float_> > medians;
126};
127
// Row-wise implementation of aggregate_across_cells().
// Each worker processes a contiguous range of rows. For every row it accumulates
// per-group statistics over all columns into small per-group scratch vectors, then
// copies them into entry x of each group's buffer.
// `sparse_` selects sparse or dense extraction at compile time.
// This is the only implementation that computes medians.
131template<bool sparse_, typename Data_, typename Index_, typename Group_, typename Sum_, typename Detected_, typename Float_>
132void aggregate_across_cells_by_row(
134 const Group_* const group,
136 const AggregateAcrossCellsOptions& options
137) {
138 tatami::Options opt;
// Sorted indices are not requested. The accumulation below indexes by group and does
// not rely on the order of the non-zero entries within a row.
139 opt.sparse_ordered_index = false;
140
// Group sizes are only needed for medians. They size the per-group scratch space and,
// in the sparse branch, are passed to the median routine alongside the stored values
// (presumably so that unstored zeros are accounted for).
141 std::optional<std::vector<Index_> > group_sizes;
142 const auto NC = p.ncol();
143 if (!buffers.medians.empty()) {
144 group_sizes = tatami_stats::tabulate_groups(group, NC);
145 }
146
// s = first row of this worker's range, l = number of rows in it.
147 tatami::parallelize([&](const int, const Index_ s, const Index_ l) -> void {
148 auto ext = tatami::consecutive_extractor<sparse_>(p, true, s, l, opt);
149
// Per-worker scratch space with one slot per group. It is allocated once and reused for every row.
150 std::vector<Sum_> tmp_sums;
151 const auto nsums = buffers.sums.size();
152 if (nsums) {
153 sanisizer::resize(tmp_sums, nsums);
154 }
155
156 std::vector<Detected_> tmp_detected;
157 const auto ndetected = buffers.detected.size();
158 if (ndetected) {
159 sanisizer::resize(tmp_detected, ndetected);
160 }
161
162 std::vector<std::vector<Float_> > tmp_medians;
163 const auto nmedians = buffers.medians.size();
164 if (nmedians) {
165 sanisizer::resize(tmp_medians, nmedians);
// NOTE(review): this loop assumes group_sizes has at least nmedians entries.
// Confirm this holds when a caller supplies more median buffers than max(group) + 1.
166 for (I<decltype(nmedians)> l = 0; l < nmedians; ++l) {
167 sanisizer::reserve(tmp_medians[l], (*group_sizes)[l]);
168 }
169 }
170
// NOTE(review): this shadows the identical NC computed outside the lambda.
171 const auto NC = p.ncol();
// Sparse extraction needs an index buffer. The dense instantiation uses a dummy `false` instead.
173 auto ibuffer = [&]{
174 if constexpr(sparse_) {
176 } else {
177 return false;
178 }
179 }();
180
// x is the absolute row index. It is also the entry written in each group's buffer.
181 for (Index_ x = s, end = s + l; x < end; ++x) {
182 const auto row = [&]{
183 if constexpr(sparse_) {
184 return ext->fetch(vbuffer.data(), ibuffer.data());
185 } else {
186 return ext->fetch(vbuffer.data());
187 }
188 }();
189
190 if (nsums) {
// Reset the per-group accumulators for this row.
191 std::fill(tmp_sums.begin(), tmp_sums.end(), 0);
192
193 if constexpr(sparse_) {
194 for (Index_ j = 0; j < row.number; ++j) {
195 tmp_sums[group[row.index[j]]] += row.value[j];
196 }
197 } else {
198 for (Index_ j = 0; j < NC; ++j) {
199 tmp_sums[group[j]] += row[j];
200 }
201 }
202
203 // Computing before transferring for more cache-friendliness.
204 for (I<decltype(nsums)> l = 0; l < nsums; ++l) {
205 buffers.sums[l][x] = tmp_sums[l];
206 }
207 }
208
209 if (ndetected) {
210 std::fill(tmp_detected.begin(), tmp_detected.end(), 0);
211
// Only values strictly greater than zero count as detected.
212 if constexpr(sparse_) {
213 for (Index_ j = 0; j < row.number; ++j) {
214 tmp_detected[group[row.index[j]]] += (row.value[j] > 0);
215 }
216 } else {
217 for (Index_ j = 0; j < NC; ++j) {
218 tmp_detected[group[j]] += (row[j] > 0);
219 }
220 }
221
222 for (I<decltype(ndetected)> l = 0; l < ndetected; ++l) {
223 buffers.detected[l][x] = tmp_detected[l];
224 }
225 }
226
// Medians: gather each group's values for this row, compute the median, then clear the
// scratch vector (std::vector::clear keeps its capacity) ready for the next row.
227 if (nmedians) {
228 if constexpr(sparse_) {
229 for (Index_ j = 0; j < row.number; ++j) {
230 tmp_medians[group[row.index[j]]].push_back(row.value[j]);
231 }
// Only stored values are collected here; the group's size from tabulate_groups() is passed with them.
// NOTE(review): the counter is typed on ndetected rather than nmedians. This is harmless, since both are
// std::vector size types, but it is inconsistent.
232 for (I<decltype(ndetected)> l = 0; l < nmedians; ++l) {
233 auto& current = tmp_medians[l];
234 buffers.medians[l][x] = tatami_stats::medians::direct<Float_>(current.data(), static_cast<Index_>(current.size()), (*group_sizes)[l], false);
235 current.clear();
236 }
237
238 } else {
239 for (Index_ j = 0; j < NC; ++j) {
240 tmp_medians[group[j]].push_back(row[j]);
241 }
// NOTE(review): unlike the sparse branch, this call omits the explicit <Float_> template argument.
// The median is therefore produced in tatami_stats' default output type and then converted to Float_.
// Confirm this is intended.
242 for (I<decltype(ndetected)> l = 0; l < nmedians; ++l) {
243 auto& current = tmp_medians[l];
244 buffers.medians[l][x] = tatami_stats::medians::direct(current.data(), current.size(), false);
245 current.clear();
246 }
247 }
248 }
249 }
250 }, p.nrow(), options.num_threads);
251}
252
// Column-wise implementation of aggregate_across_cells().
// The rows are split into contiguous blocks across workers. Each worker walks over every
// column, restricted to its block of rows, and adds that column into the accumulator of
// the column's group.
// Medians are not supported here (see the assert). The dispatching overload routes any
// median request to the row-wise implementation instead.
253template<bool sparse_, typename Data_, typename Index_, typename Group_, typename Sum_, typename Detected_, typename Float_>
254void aggregate_across_cells_by_column(
256 const Group_* const group,
257 const AggregateAcrossCellsBuffers<Sum_, Detected_, Float_>& buffers,
258 const AggregateAcrossCellsOptions& options
259) {
260 tatami::Options opt;
261 opt.sparse_ordered_index = false;
262 assert(buffers.medians.empty());
263
// t = worker index, start/length = this worker's block of rows.
264 tatami::parallelize([&](const int t, const Index_ start, const Index_ length) -> void {
265 const auto NC = p.ncol();
// Iterate over all columns (0 to NC). The trailing start/length arguments restrict each
// extracted column to this worker's rows; the `- start` offsets and `i < length` loops below rely on this.
266 auto ext = tatami::consecutive_extractor<sparse_>(p, false, static_cast<Index_>(0), NC, start, length, opt);
268 auto ibuffer = [&]{
269 if constexpr(sparse_) {
271 } else {
272 return false;
273 }
274 }();
275
// Per-worker accumulators for this block of rows, one per group, backed by `buffers` via the getters.
// NOTE(review): the `+=` updates below assume these start at zero. Confirm against
// tatami_stats::LocalOutputBuffers.
276 const auto num_sums = buffers.sums.size();
277 auto get_sum = [&](Index_ i) -> Sum_* { return buffers.sums[i]; };
278 tatami_stats::LocalOutputBuffers<Sum_, I<decltype(get_sum)>> local_sums(t, num_sums, start, length, std::move(get_sum));
279
280 const auto num_detected = buffers.detected.size();
281 auto get_detected = [&](Index_ i) -> Detected_* { return buffers.detected[i]; };
282 tatami_stats::LocalOutputBuffers<Detected_, I<decltype(get_detected)>> local_detected(t, num_detected, start, length, std::move(get_detected));
283
284 for (Index_ x = 0; x < NC; ++x) {
// Group of column x; it must be less than the number of buffers for each requested statistic.
285 const auto current = group[x];
286
287 if constexpr(sparse_) {
288 const auto col = ext->fetch(vbuffer.data(), ibuffer.data());
// Stored indices are absolute row indices; subtract `start` to index into this block.
289 if (num_sums) {
290 const auto cursum = local_sums.data(current);
291 for (Index_ i = 0; i < col.number; ++i) {
292 cursum[col.index[i] - start] += col.value[i];
293 }
294 }
// Only values strictly greater than zero count as detected.
295 if (num_detected) {
296 const auto curdetected = local_detected.data(current);
297 for (Index_ i = 0; i < col.number; ++i) {
298 curdetected[col.index[i] - start] += (col.value[i] > 0);
299 }
300 }
301
302 } else {
303 const auto col = ext->fetch(vbuffer.data());
304 if (num_sums) {
305 const auto cursum = local_sums.data(current);
306 for (Index_ i = 0; i < length; ++i) {
307 cursum[i] += col[i];
308 }
309 }
310 if (num_detected) {
311 const auto curdetected = local_detected.data(current);
312 for (Index_ i = 0; i < length; ++i) {
313 curdetected[i] += (col[i] > 0);
314 }
315 }
316 }
317 }
318
// Flush this worker's accumulated results into the caller's buffers (presumed semantics of transfer()).
319 local_sums.transfer();
320 local_detected.transfer();
321 }, p.nrow(), options.num_threads);
322}
// Fills caller-provided buffers with per-group statistics for every row of `input`, where
// group[j] is the group ID of column j.
// The row-wise implementation is used when the matrix prefers row access, or when medians
// are requested, since only that implementation computes medians. Otherwise the
// column-wise implementation is used.
// Sparse or dense extraction is chosen according to input.sparse().
349template<typename Data_, typename Index_, typename Group_, typename Sum_, typename Detected_, typename Float_>
352 const Group_* const group,
354 const AggregateAcrossCellsOptions& options
355) {
356 if (input.prefer_rows() || !buffers.medians.empty()) {
357 if (input.sparse()) {
358 aggregate_across_cells_by_row<true>(input, group, buffers, options);
359 } else {
360 aggregate_across_cells_by_row<false>(input, group, buffers, options);
361 }
362 } else {
363 if (input.sparse()) {
364 aggregate_across_cells_by_column<true>(input, group, buffers, options);
365 } else {
366 aggregate_across_cells_by_column<false>(input, group, buffers, options);
367 }
368 }
369}
370
// Convenience overload that allocates its own output.
// For each statistic requested in `options`, it creates one vector per group, each with one
// entry per row of `input`. It points the buffers at that storage, runs the buffer-based
// overload, and returns the filled results.
391template<typename Sum_ = double, typename Detected_ = int, typename Float_ = double, typename Data_, typename Index_, typename Group_>
394 const Group_* const group,
395 const AggregateAcrossCellsOptions& options
396) {
397 const Index_ NR = input.nrow();
398 const Index_ NC = input.ncol();
// Number of groups = largest group ID + 1, or zero if there are no columns.
// Group IDs are therefore treated as 0-based indices.
399 const std::size_t ngroups = [&]{
400 if (NC) {
401 return sanisizer::sum<std::size_t>(*std::max_element(group, group + NC), 1);
402 } else {
403 return static_cast<std::size_t>(0);
404 }
405 }();
406
409
// In each allocation loop below, SCRAN_AGGREGATE_TEST_INIT (when defined) is forwarded as an
// extra argument to resize_container_to_Index_size(). It is presumably a fill value that lets
// tests detect entries that are never written.
410 if (options.compute_sums) {
411 sanisizer::resize(output.sums, ngroups);
412 sanisizer::resize(buffers.sums, ngroups);
413 for (I<decltype(ngroups)> l = 0; l < ngroups; ++l) {
414 auto& cursum = output.sums[l];
415 tatami::resize_container_to_Index_size<I<decltype(cursum)>>(cursum, NR
416#ifdef SCRAN_AGGREGATE_TEST_INIT
417 , SCRAN_AGGREGATE_TEST_INIT
418#endif
419 );
420 buffers.sums[l] = cursum.data();
421 }
422 }
423
424 if (options.compute_detected) {
425 sanisizer::resize(output.detected, ngroups);
426 sanisizer::resize(buffers.detected, ngroups);
427 for (I<decltype(ngroups)> l = 0; l < ngroups; ++l) {
428 auto& curdet = output.detected[l];
429 tatami::resize_container_to_Index_size<I<decltype(curdet)>>(curdet, NR
430#ifdef SCRAN_AGGREGATE_TEST_INIT
431 , SCRAN_AGGREGATE_TEST_INIT
432#endif
433 );
434 buffers.detected[l] = curdet.data();
435 }
436 }
437
// A non-empty buffers.medians makes the buffer-based overload take the row-wise path.
438 if (options.compute_medians) {
439 sanisizer::resize(output.medians, ngroups);
440 sanisizer::resize(buffers.medians, ngroups);
441 for (I<decltype(ngroups)> l = 0; l < ngroups; ++l) {
// NOTE(review): `curdet` holds a median vector here; the name was carried over from the detected loop.
442 auto& curdet = output.medians[l];
443 tatami::resize_container_to_Index_size<I<decltype(curdet)>>(curdet, NR
444#ifdef SCRAN_AGGREGATE_TEST_INIT
445 , SCRAN_AGGREGATE_TEST_INIT
446#endif
447 );
448 buffers.medians[l] = curdet.data();
449 }
450 }
451
452
453 aggregate_across_cells(input, group, buffers, options);
454 return output;
455}
456
457}
458
459#endif
virtual Index_ ncol() const=0
virtual Index_ nrow() const=0
virtual bool prefer_rows() const=0
virtual std::unique_ptr< MyopicSparseExtractor< Value_, Index_ > > sparse(bool row, const Options &opt) const=0
Aggregate single-cell expression values.
Definition aggregate_across_cells.hpp:21
void aggregate_across_cells(const tatami::Matrix< Data_, Index_ > &input, const Group_ *const group, const AggregateAcrossCellsBuffers< Sum_, Detected_, Float_ > &buffers, const AggregateAcrossCellsOptions &options)
Definition aggregate_across_cells.hpp:350
void resize_container_to_Index_size(Container_ &container, const Index_ x, Args_ &&... args)
int parallelize(Function_ fun, const Index_ tasks, const int workers)
Container_ create_container_of_Index_size(const Index_ x, Args_ &&... args)
auto consecutive_extractor(const Matrix< Value_, Index_ > &matrix, const bool row, const Index_ iter_start, const Index_ iter_length, Args_ &&... args)
Buffers for aggregate_across_cells().
Definition aggregate_across_cells.hpp:61
std::vector< Float_ * > medians
Definition aggregate_across_cells.hpp:87
std::vector< Sum_ * > sums
Definition aggregate_across_cells.hpp:69
std::vector< Detected_ * > detected
Definition aggregate_across_cells.hpp:78
Options for aggregate_across_cells().
Definition aggregate_across_cells.hpp:26
int num_threads
Definition aggregate_across_cells.hpp:49
bool compute_medians
Definition aggregate_across_cells.hpp:43
bool compute_detected
Definition aggregate_across_cells.hpp:37
bool compute_sums
Definition aggregate_across_cells.hpp:31
Results of aggregate_across_cells().
Definition aggregate_across_cells.hpp:99
std::vector< std::vector< Float_ > > medians
Definition aggregate_across_cells.hpp:125
std::vector< std::vector< Sum_ > > sums
Definition aggregate_across_cells.hpp:107
std::vector< std::vector< Detected_ > > detected
Definition aggregate_across_cells.hpp:116
bool sparse_ordered_index