scran_aggregate
Aggregate expression values across cells
Loading...
Searching...
No Matches
aggregate_across_cells.hpp
Go to the documentation of this file.
1#ifndef SCRAN_AGGREGATE_AGGREGATE_ACROSS_CELLS_HPP
2#define SCRAN_AGGREGATE_AGGREGATE_ACROSS_CELLS_HPP
3
4#include <algorithm>
5#include <vector>
6#include <cstddef>
7#include <type_traits>
8
9#include "tatami/tatami.hpp"
10#include "tatami_stats/tatami_stats.hpp"
11#include "sanisizer/sanisizer.hpp"
12
13#include "utils.hpp"
14
20namespace scran_aggregate {
21
// Options controlling aggregate_across_cells().
// NOTE(review): the line that opens this struct is not present in this listing; only the members are visible.

// Whether to compute per-group sums. When false, the result-returning overload never resizes
// its `sums` vectors, so the kernels see zero sum buffers and skip that statistic. Defaults to true.
30 bool compute_sums = true;
31
// Whether to compute the per-group count of detected values (values greater than zero).
// Handled the same way as compute_sums, but for the `detected` vectors. Defaults to true.
36 bool compute_detected = true;
37
// Number of workers handed to tatami::parallelize() by both kernels. Defaults to 1.
42 int num_threads = 1;
43};
44
// Caller-owned output buffers for aggregate_across_cells().
// NOTE(review): the line that opens this struct is not present in this listing; only the members are visible.
52template <typename Sum_, typename Detected_>
// One pointer per group. The kernels write the sum for row `x` of group `g` to `sums[g][x]`,
// so each array must be valid for every row index of the input matrix.
// The values in `group` are used directly as indices into this vector.
// An empty vector disables the sum calculation.
61 std::vector<Sum_*> sums;
62
// One pointer per group, indexed the same way as `sums`; receives the number of values
// greater than zero for each row/group combination. An empty vector disables this calculation.
70 std::vector<Detected_*> detected;
71
72};
73
// Owning results returned by the convenience overload of aggregate_across_cells().
// NOTE(review): the line that opens this struct is not present in this listing; only the members are visible.
81template <typename Sum_, typename Detected_>
// Outer vector: one entry per group. Inner vector: one entry per row of the input matrix.
// Only resized (and filled) when AggregateAcrossCellsOptions::compute_sums is true.
90 std::vector<std::vector<Sum_> > sums;
91
// Same layout as `sums`, holding the count of values greater than zero.
// Only resized (and filled) when AggregateAcrossCellsOptions::compute_detected is true.
99 std::vector<std::vector<Detected_> > detected;
100};
101
// Row-wise kernel: extracts one row at a time, accumulates its values into per-group
// temporaries, then copies those temporaries into the output buffers at that row's position.
// `sparse_` selects sparse or dense extraction at compile time.
//
// NOTE(review): listing lines 107, 109, 123 and 126 are absent from this page. The names `p`,
// `buffers` and `vbuffer`, and the sparse-branch value of `ibuffer`, are declared on those lines.
// From usage, `p` is the input matrix and `buffers` holds the output pointers - confirm against
// the real header.
105template<bool sparse_, typename Data_, typename Index_, typename Group_, typename Sum_, typename Detected_>
106void aggregate_across_cells_by_row(
108 const Group_* const group,
110 const AggregateAcrossCellsOptions& options)
111{
112 tatami::Options opt;
// The accumulation loops below never rely on the order of `row.index`.
// Presumably this lets tatami skip sorting sparse indices - TODO confirm against tatami docs.
113 opt.sparse_ordered_index = false;
114
// Each worker handles the contiguous block of rows [s, s + l); the first (thread id) argument is unused.
115 tatami::parallelize([&](const int, const Index_ s, const Index_ l) -> void {
// `true` requests row extraction, iterating over rows s, s + 1, ..., s + l - 1.
116 auto ext = tatami::consecutive_extractor<sparse_>(p, true, s, l, opt);
// Per-worker accumulators with one slot per output buffer (i.e. per group); reused for every row.
117 const auto nsums = buffers.sums.size();
118 auto tmp_sums = sanisizer::create<std::vector<Sum_> >(nsums);
119 const auto ndetected = buffers.detected.size();
120 auto tmp_detected = sanisizer::create<std::vector<Detected_> >(ndetected);
121
122 const auto NC = p.ncol();
// The index buffer is only used by the sparse fetch below; the dense case gets a `false` placeholder.
124 auto ibuffer = [&]{
125 if constexpr(sparse_) {
127 } else {
128 return false;
129 }
130 }();
131
132 for (Index_ x = s, end = s + l; x < end; ++x) {
// Sparse fetch yields a range with `number`, `index` and `value` members; dense fetch yields something indexable by column.
133 const auto row = [&]{
134 if constexpr(sparse_) {
135 return ext->fetch(vbuffer.data(), ibuffer.data());
136 } else {
137 return ext->fetch(vbuffer.data());
138 }
139 }();
140
// Skipped entirely when no sum buffers were supplied.
141 if (nsums) {
142 std::fill(tmp_sums.begin(), tmp_sums.end(), 0);
143
144 if constexpr(sparse_) {
// Only the structurally non-zero entries are visited; the column index selects the group.
145 for (Index_ j = 0; j < row.number; ++j) {
146 tmp_sums[group[row.index[j]]] += row.value[j];
147 }
148 } else {
149 for (Index_ j = 0; j < NC; ++j) {
150 tmp_sums[group[j]] += row[j];
151 }
152 }
153
// Note: this `l` shadows the lambda's length parameter. That is harmless here because
// `end` was already computed in the initializer of the enclosing for loop.
154 // Computing before transferring for more cache-friendliness.
155 for (I<decltype(nsums)> l = 0; l < nsums; ++l) {
156 buffers.sums[l][x] = tmp_sums[l];
157 }
158 }
159
// Same pattern as the sums, but each entry contributes 1 if its value is greater than zero and 0 otherwise.
160 if (ndetected) {
161 std::fill(tmp_detected.begin(), tmp_detected.end(), 0);
162
163 if constexpr(sparse_) {
164 for (Index_ j = 0; j < row.number; ++j) {
165 tmp_detected[group[row.index[j]]] += (row.value[j] > 0);
166 }
167 } else {
168 for (Index_ j = 0; j < NC; ++j) {
169 tmp_detected[group[j]] += (row[j] > 0);
170 }
171 }
172
173 for (I<decltype(ndetected)> l = 0; l < ndetected; ++l) {
174 buffers.detected[l][x] = tmp_detected[l];
175 }
176 }
177 }
// Work is split over the rows of the matrix.
178 }, p.nrow(), options.num_threads);
179}
180
// Column-wise kernel: every worker walks over all columns but only looks at its own block of rows,
// adding each column's values into the output array of that column's group.
// `sparse_` selects sparse or dense extraction at compile time.
//
// NOTE(review): listing lines 183, 194 and 197 are absent from this page. The names `p` and `vbuffer`,
// and the sparse-branch value of `ibuffer`, are declared on those lines. From usage, `p` is the
// input matrix - confirm against the real header.
181template<bool sparse_, typename Data_, typename Index_, typename Group_, typename Sum_, typename Detected_>
182void aggregate_across_cells_by_column(
184 const Group_* const group,
185 const AggregateAcrossCellsBuffers<Sum_, Detected_>& buffers,
186 const AggregateAcrossCellsOptions& options)
187{
188 tatami::Options opt;
// The accumulation loops below never rely on the order of `col.index`.
// Presumably this lets tatami skip sorting sparse indices - TODO confirm against tatami docs.
189 opt.sparse_ordered_index = false;
190
// `t` is the worker id; the worker is responsible for rows [start, start + length).
191 tatami::parallelize([&](const int t, const Index_ start, const Index_ length) -> void {
192 const auto NC = p.ncol();
// `false` requests column extraction, iterating over columns 0 .. NC - 1. `start` and `length` are
// forwarded as extra arguments; from their use below they appear to restrict each column to this worker's rows.
193 auto ext = tatami::consecutive_extractor<sparse_>(p, false, static_cast<Index_>(0), NC, start, length, opt);
// The index buffer is only used by the sparse fetch below; the dense case gets a `false` placeholder.
195 auto ibuffer = [&]{
196 if constexpr(sparse_) {
198 } else {
199 return false;
200 }
201 }();
202
// `get_sum` maps a group number to its output array; the wrapper is built from the worker id, the
// number of groups and this worker's row range.
// NOTE(review): presumably LocalOutputBuffers provides per-worker staging storage that transfer()
// copies back to the real output - confirm against tatami_stats docs.
203 const auto num_sums = buffers.sums.size();
204 auto get_sum = [&](Index_ i) -> Sum_* { return buffers.sums[i]; };
205 tatami_stats::LocalOutputBuffers<Sum_, I<decltype(get_sum)>> local_sums(t, num_sums, start, length, std::move(get_sum));
206
207 const auto num_detected = buffers.detected.size();
208 auto get_detected = [&](Index_ i) -> Detected_* { return buffers.detected[i]; };
209 tatami_stats::LocalOutputBuffers<Detected_, I<decltype(get_detected)>> local_detected(t, num_detected, start, length, std::move(get_detected));
210
211 for (Index_ x = 0; x < NC; ++x) {
// Group of the current column; selects which output array is updated.
212 const auto current = group[x];
213
214 if constexpr(sparse_) {
215 const auto col = ext->fetch(vbuffer.data(), ibuffer.data());
216 if (num_sums) {
217 const auto cursum = local_sums.data(current);
218 for (Index_ i = 0; i < col.number; ++i) {
// Subtracting `start` converts the row index into a position within this worker's block.
219 cursum[col.index[i] - start] += col.value[i];
220 }
221 }
222 if (num_detected) {
223 const auto curdetected = local_detected.data(current);
224 for (Index_ i = 0; i < col.number; ++i) {
// Adds 1 when the value is greater than zero, 0 otherwise.
225 curdetected[col.index[i] - start] += (col.value[i] > 0);
226 }
227 }
228
229 } else {
// Dense case: the fetched values are indexed directly by position within the block, 0 .. length - 1.
230 const auto col = ext->fetch(vbuffer.data());
231 if (num_sums) {
232 const auto cursum = local_sums.data(current);
233 for (Index_ i = 0; i < length; ++i) {
234 cursum[i] += col[i];
235 }
236 }
237 if (num_detected) {
238 const auto curdetected = local_detected.data(current);
239 for (Index_ i = 0; i < length; ++i) {
240 curdetected[i] += (col[i] > 0);
241 }
242 }
243 }
244 }
245
// Called once per worker after all columns have been processed.
246 local_sums.transfer();
247 local_detected.transfer();
// Work is split over the rows of the matrix, even though extraction is by column.
248 }, p.nrow(), options.num_threads);
249}
// Entry point that fills caller-supplied buffers: picks the row-wise or column-wise kernel, and
// its sparse or dense variant, from the matrix's own prefer_rows() and sparse() answers.
//
// NOTE(review): listing lines 275, 276 and 278 are absent from this page. They carry the function
// name and the `input` and `buffers` parameters; the body below uses both names.
274template<typename Data_, typename Index_, typename Group_, typename Sum_, typename Detected_>
277 const Group_* const group,
279 const AggregateAcrossCellsOptions& options)
280{
// The matrix reports its preferred access dimension; extraction follows that preference.
281 if (input.prefer_rows()) {
// `sparse_` is a template argument of the kernels, so the runtime flag is turned into a compile-time one here.
282 if (input.sparse()) {
283 aggregate_across_cells_by_row<true>(input, group, buffers, options);
284 } else {
285 aggregate_across_cells_by_row<false>(input, group, buffers, options);
286 }
287 } else {
288 if (input.sparse()) {
289 aggregate_across_cells_by_column<true>(input, group, buffers, options);
290 } else {
291 aggregate_across_cells_by_column<false>(input, group, buffers, options);
292 }
293 }
294}
295
// Convenience overload: allocates the result vectors, points a buffer struct at their storage,
// runs the buffer-filling aggregate_across_cells() and returns the results.
// `Sum_` defaults to double and `Detected_` to int.
//
// NOTE(review): listing lines 316, 317, 331 and 332 are absent from this page. They carry the return
// type, the function name, the `input` parameter and the declarations of `output` and `buffers`.
315template<typename Sum_ = double, typename Detected_ = int, typename Data_, typename Index_, typename Group_>
318 const Group_* const group,
319 const AggregateAcrossCellsOptions& options)
320{
321 const Index_ NR = input.nrow();
322 const Index_ NC = input.ncol();
// Number of groups is the largest group id plus one; `group` is read for NC entries.
// With no columns there is nothing to scan, so the count is zero.
// NOTE(review): presumably sanisizer::sum guards the "+ 1" against overflow - confirm against sanisizer docs.
323 const std::size_t ngroups = [&]{
324 if (NC) {
325 return sanisizer::sum<std::size_t>(*std::max_element(group, group + NC), 1);
326 } else {
327 return static_cast<std::size_t>(0);
328 }
329 }();
330
333
// When disabled, both output.sums and buffers.sums are never resized here, so the kernels see zero sum buffers.
334 if (options.compute_sums) {
335 sanisizer::resize(output.sums, ngroups);
336 sanisizer::resize(buffers.sums, ngroups);
337 for (I<decltype(ngroups)> l = 0; l < ngroups; ++l) {
338 auto& cursum = output.sums[l];
// One entry per row. If SCRAN_AGGREGATE_TEST_INIT is defined it is passed as an extra argument;
// presumably this is an initial fill value used by the tests - TODO confirm.
339 tatami::resize_container_to_Index_size<I<decltype(cursum)>>(cursum, NR
340#ifdef SCRAN_AGGREGATE_TEST_INIT
341 , SCRAN_AGGREGATE_TEST_INIT
342#endif
343 );
// The buffer struct does not own memory; it points into the vector owned by `output`.
344 buffers.sums[l] = cursum.data();
345 }
346 }
347
// Same set-up for the detected counts.
348 if (options.compute_detected) {
349 sanisizer::resize(output.detected, ngroups);
350 sanisizer::resize(buffers.detected, ngroups);
351 for (I<decltype(ngroups)> l = 0; l < ngroups; ++l) {
352 auto& curdet = output.detected[l];
353 tatami::resize_container_to_Index_size<I<decltype(curdet)>>(curdet, NR
354#ifdef SCRAN_AGGREGATE_TEST_INIT
355 , SCRAN_AGGREGATE_TEST_INIT
356#endif
357 );
358 buffers.detected[l] = curdet.data();
359 }
360 }
361
// Delegates to the buffer-filling overload, which writes through the pointers set up above.
362 aggregate_across_cells(input, group, buffers, options);
363 return output;
364}
365
366}
367
368#endif
virtual Index_ ncol() const=0
virtual Index_ nrow() const=0
virtual bool prefer_rows() const=0
virtual std::unique_ptr< MyopicSparseExtractor< Value_, Index_ > > sparse(bool row, const Options &opt) const=0
Aggregate single-cell expression values.
Definition aggregate_across_cells.hpp:20
void aggregate_across_cells(const tatami::Matrix< Data_, Index_ > &input, const Group_ *const group, const AggregateAcrossCellsBuffers< Sum_, Detected_ > &buffers, const AggregateAcrossCellsOptions &options)
Definition aggregate_across_cells.hpp:275
void resize_container_to_Index_size(Container_ &container, const Index_ x, Args_ &&... args)
int parallelize(Function_ fun, const Index_ tasks, const int workers)
Container_ create_container_of_Index_size(const Index_ x, Args_ &&... args)
auto consecutive_extractor(const Matrix< Value_, Index_ > &matrix, const bool row, const Index_ iter_start, const Index_ iter_length, Args_ &&... args)
Buffers for aggregate_across_cells().
Definition aggregate_across_cells.hpp:53
std::vector< Detected_ * > detected
Definition aggregate_across_cells.hpp:70
std::vector< Sum_ * > sums
Definition aggregate_across_cells.hpp:61
Options for aggregate_across_cells().
Definition aggregate_across_cells.hpp:25
int num_threads
Definition aggregate_across_cells.hpp:42
bool compute_detected
Definition aggregate_across_cells.hpp:36
bool compute_sums
Definition aggregate_across_cells.hpp:30
Results of aggregate_across_cells().
Definition aggregate_across_cells.hpp:82
std::vector< std::vector< Detected_ > > detected
Definition aggregate_across_cells.hpp:99
std::vector< std::vector< Sum_ > > sums
Definition aggregate_across_cells.hpp:90
bool sparse_ordered_index