scran_aggregate
Aggregate expression values across cells
Loading...
Searching...
No Matches
aggregate_across_cells.hpp
Go to the documentation of this file.
1#ifndef SCRAN_AGGREGATE_AGGREGATE_ACROSS_CELLS_HPP
2#define SCRAN_AGGREGATE_AGGREGATE_ACROSS_CELLS_HPP
3
4#include <algorithm>
5#include <vector>
6#include <cstddef>
7#include <type_traits>
8
9#include "tatami/tatami.hpp"
10#include "tatami_stats/tatami_stats.hpp"
11#include "sanisizer/sanisizer.hpp"
12
13#include "utils.hpp"
14
20namespace scran_aggregate {
21
// Members of the options struct for aggregate_across_cells().
// NOTE(review): the struct's opening line (original line 25) is not present in this listing.
// Whether per-group sums are requested; the allocating overload of aggregate_across_cells()
// only sizes its `sums` outputs when this is true.
30 bool compute_sums = true;
31
// Whether per-group detected counts are requested; the allocating overload only sizes its
// `detected` outputs when this is true.
36 bool compute_detected = true;
37
// Thread count forwarded as the last argument of tatami::parallelize().
42 int num_threads = 1;
43};
44
// Caller-supplied output pointers for aggregate_across_cells().
// NOTE(review): the struct's opening line (original line 51) is not present in this listing.
50template <typename Sum_, typename Detected_>
// One pointer per group; the kernels write `sums[g][r]` for group `g` and matrix row `r`.
// An empty vector makes the kernels skip the sum computation.
// Assumes each pointer addresses at least nrow() elements - TODO confirm against the upstream docs.
59 std::vector<Sum_*> sums;
60
// One pointer per group; the kernels write `detected[g][r]`, the count of values > 0 in row `r`
// among the columns assigned to group `g`. An empty vector skips the detected computation.
68 std::vector<Detected_*> detected;
69
70};
71
// Owning result storage returned by the allocating overload of aggregate_across_cells().
// NOTE(review): the struct's opening line (original line 78) is not present in this listing.
77template <typename Sum_, typename Detected_>
// Outer vector has one entry per group when sums are requested, and is left empty otherwise.
// The inner length is set on a line missing from this listing - presumably nrow(); verify.
86 std::vector<std::vector<Sum_> > sums;
87
// Same layout as `sums`, holding per-group detected counts when they are requested.
95 std::vector<std::vector<Detected_> > detected;
96};
97
101namespace internal {
102
// Row-wise kernel: processes blocks of rows of `p` in parallel. For every row it accumulates the
// per-group sum and per-group count of positive values into scratch vectors, then copies those
// into `buffers.sums[g][row]` / `buffers.detected[g][row]`.
// `group[c]` is used as the group index of column `c`.
// NOTE(review): the parameter lines for `p` and `buffers` (original lines 105, 107), the
// declaration of `vbuffer` (line 121) and the sparse branch of `ibuffer` (line 124) are missing
// from this listing.
103template<bool sparse_, typename Data_, typename Index_, typename Group_, typename Sum_, typename Detected_>
104void compute_aggregate_by_row(
106 const Group_* const group,
108 const AggregateAcrossCellsOptions& options)
109{
110 tatami::Options opt;
// The accumulation loops below do not depend on the order of sparse elements.
// Presumably this lets tatami return sparse indices unsorted - confirm in the tatami docs.
111 opt.sparse_ordered_index = false;
112
// Lambda arguments: (unused int), first row of this block `s`, number of rows in this block `l`.
113 tatami::parallelize([&](const int, const Index_ s, const Index_ l) -> void {
114 auto ext = tatami::consecutive_extractor<sparse_>(p, true, s, l, opt);
// Scratch accumulators, one slot per entry of the corresponding buffer vector.
115 const auto nsums = buffers.sums.size();
116 auto tmp_sums = sanisizer::create<std::vector<Sum_> >(nsums);
117 const auto ndetected = buffers.detected.size();
118 auto tmp_detected = sanisizer::create<std::vector<Detected_> >(ndetected);
119
120 const auto NC = p.ncol();
// In the dense case `ibuffer` is just a placeholder `false`; it is only dereferenced when sparse_.
122 auto ibuffer = [&]{
123 if constexpr(sparse_) {
125 } else {
126 return false;
127 }
128 }();
129
130 for (Index_ x = s, end = s + l; x < end; ++x) {
// One fetch per row: a sparse range (number/index/value) or a dense array indexed by column.
131 const auto row = [&]{
132 if constexpr(sparse_) {
133 return ext->fetch(vbuffer.data(), ibuffer.data());
134 } else {
135 return ext->fetch(vbuffer.data());
136 }
137 }();
138
// Skipped entirely when the caller supplied no sum buffers.
139 if (nsums) {
140 std::fill(tmp_sums.begin(), tmp_sums.end(), 0);
141
142 if constexpr(sparse_) {
143 for (Index_ j = 0; j < row.number; ++j) {
144 tmp_sums[group[row.index[j]]] += row.value[j];
145 }
146 } else {
147 for (Index_ j = 0; j < NC; ++j) {
148 tmp_sums[group[j]] += row[j];
149 }
150 }
151
152 // Computing before transferring for more cache-friendliness.
// This `l` shadows the lambda's block-length parameter; harmless because `end` was already
// computed from it in the enclosing for-loop initializer.
153 for (decltype(I(nsums)) l = 0; l < nsums; ++l) {
154 buffers.sums[l][x] = tmp_sums[l];
155 }
156 }
157
// Skipped entirely when the caller supplied no detected buffers.
158 if (ndetected) {
159 std::fill(tmp_detected.begin(), tmp_detected.end(), 0);
160
// An element counts as detected when its value is strictly greater than zero.
161 if constexpr(sparse_) {
162 for (Index_ j = 0; j < row.number; ++j) {
163 tmp_detected[group[row.index[j]]] += (row.value[j] > 0);
164 }
165 } else {
166 for (Index_ j = 0; j < NC; ++j) {
167 tmp_detected[group[j]] += (row[j] > 0);
168 }
169 }
170
171 for (decltype(I(ndetected)) l = 0; l < ndetected; ++l) {
172 buffers.detected[l][x] = tmp_detected[l];
173 }
174 }
175 }
176 }, p.nrow(), options.num_threads);
177}
178
// Column-wise kernel: the rows of `p` are split into blocks across threads. Each thread visits
// every column restricted to its own block of rows and adds the column's values (and its count
// of positive values) into the output of the column's group, `group[x]`.
// NOTE(review): the parameter line for `p` (original line 181), the declaration of `vbuffer`
// (line 192) and the sparse branch of `ibuffer` (line 195) are missing from this listing.
179template<bool sparse_, typename Data_, typename Index_, typename Group_, typename Sum_, typename Detected_>
180void compute_aggregate_by_column(
182 const Group_* const group,
183 const AggregateAcrossCellsBuffers<Sum_, Detected_>& buffers,
184 const AggregateAcrossCellsOptions& options)
185{
186 tatami::Options opt;
// The accumulation loops below do not depend on the order of sparse elements.
// Presumably this lets tatami return sparse indices unsorted - confirm in the tatami docs.
187 opt.sparse_ordered_index = false;
188
// Lambda arguments: `t` is passed on to LocalOutputBuffers (presumably the thread index);
// `start`/`length` delimit this thread's block of rows.
189 tatami::parallelize([&](const int t, const Index_ start, const Index_ length) -> void {
190 const auto NC = p.ncol();
// Iterates over columns [0, NC), each restricted to rows [start, start + length).
191 auto ext = tatami::consecutive_extractor<sparse_>(p, false, static_cast<Index_>(0), NC, start, length, opt);
// In the dense case `ibuffer` is just a placeholder `false`; it is only dereferenced when sparse_.
193 auto ibuffer = [&]{
194 if constexpr(sparse_) {
196 } else {
197 return false;
198 }
199 }();
200
// NOTE(review): LocalOutputBuffers is opaque here. It is given the per-group output pointers and
// the row block, and exposes data(group) plus transfer(); presumably it provides per-thread
// scratch space that transfer() copies into the final buffers - verify against tatami_stats.
201 const auto num_sums = buffers.sums.size();
202 auto get_sum = [&](Index_ i) -> Sum_* { return buffers.sums[i]; };
203 tatami_stats::LocalOutputBuffers<Sum_, decltype(I(get_sum))> local_sums(t, num_sums, start, length, std::move(get_sum));
204
205 const auto num_detected = buffers.detected.size();
206 auto get_detected = [&](Index_ i) -> Detected_* { return buffers.detected[i]; };
207 tatami_stats::LocalOutputBuffers<Detected_, decltype(I(get_detected))> local_detected(t, num_detected, start, length, std::move(get_detected));
208
209 for (Index_ x = 0; x < NC; ++x) {
// Group of the current column; selects which output vector receives this column's values.
210 const auto current = group[x];
211
212 if constexpr(sparse_) {
213 const auto col = ext->fetch(vbuffer.data(), ibuffer.data());
214 if (num_sums) {
215 const auto cursum = local_sums.data(current);
216 for (Index_ i = 0; i < col.number; ++i) {
// Subtracting `start` turns the row index into an offset within this thread's block.
217 cursum[col.index[i] - start] += col.value[i];
218 }
219 }
220 if (num_detected) {
221 const auto curdetected = local_detected.data(current);
222 for (Index_ i = 0; i < col.number; ++i) {
223 curdetected[col.index[i] - start] += (col.value[i] > 0);
224 }
225 }
226
227 } else {
// Dense fetch: `col[i]` is indexed by offset within the block, so no index adjustment is needed.
228 const auto col = ext->fetch(vbuffer.data());
229 if (num_sums) {
230 const auto cursum = local_sums.data(current);
231 for (Index_ i = 0; i < length; ++i) {
232 cursum[i] += col[i];
233 }
234 }
235 if (num_detected) {
236 const auto curdetected = local_detected.data(current);
237 for (Index_ i = 0; i < length; ++i) {
238 curdetected[i] += (col[i] > 0);
239 }
240 }
241 }
242 }
243
// Called unconditionally once all columns have been processed, even when a buffer set is empty.
244 local_sums.transfer();
245 local_detected.transfer();
246 }, p.nrow(), options.num_threads);
247}
248
249}
// Buffer-filling overload of aggregate_across_cells(): picks one of the four internal kernels
// based on the matrix's preferred access dimension (prefer_rows()) and sparsity (sparse()),
// then forwards `input`, `group`, `buffers` and `options` to it unchanged.
// NOTE(review): the function-name line and the `input`/`buffers` parameter lines (original
// lines 273-274, 276) are missing from this listing.
272template<typename Data_, typename Index_, typename Group_, typename Sum_, typename Detected_>
275 const Group_* const group,
277 const AggregateAcrossCellsOptions& options)
278{
279 if (input.prefer_rows()) {
280 if (input.sparse()) {
281 internal::compute_aggregate_by_row<true>(input, group, buffers, options);
282 } else {
283 internal::compute_aggregate_by_row<false>(input, group, buffers, options);
284 }
285 } else {
286 if (input.sparse()) {
287 internal::compute_aggregate_by_column<true>(input, group, buffers, options);
288 } else {
289 internal::compute_aggregate_by_column<false>(input, group, buffers, options);
290 }
291 }
292}
293
// Allocating overload of aggregate_across_cells(): derives the number of groups from `group`,
// sizes the result vectors that the options request, points a buffer struct at that storage,
// runs the buffer-filling overload and returns the results.
// NOTE(review): the return type, function-name and `input` parameter lines (original lines
// 312-313), the declarations of `output` and `buffers` (lines 327-328) and the opening of the two
// per-group sizing calls (lines 335, 349) are missing from this listing.
311template<typename Sum_ = double, typename Detected_ = int, typename Data_, typename Index_, typename Group_>
314 const Group_* const group,
315 const AggregateAcrossCellsOptions& options)
316{
// NR has no visible use below; presumably it is the length passed to the sizing calls on the
// missing lines 335/349 - verify.
317 const Index_ NR = input.nrow();
318 const Index_ NC = input.ncol();
// Number of groups is the largest group index plus one, or zero for a matrix with no columns
// (where `group` is never dereferenced).
// NOTE(review): sanisizer::sum presumably guards the +1 against overflow - confirm.
319 const std::size_t ngroups = [&]{
320 if (NC) {
321 return sanisizer::sum<std::size_t>(*std::max_element(group, group + NC), 1);
322 } else {
323 return static_cast<std::size_t>(0);
324 }
325 }();
326
329
// When sums are not requested both vectors stay empty, which makes the kernels skip them.
330 if (options.compute_sums) {
331 sanisizer::resize(output.sums, ngroups);
332 sanisizer::resize(buffers.sums, ngroups);
333 for (decltype(I(ngroups)) l = 0; l < ngroups; ++l) {
334 auto& cursum = output.sums[l];
336#ifdef SCRAN_AGGREGATE_TEST_INIT
337 , SCRAN_AGGREGATE_TEST_INIT
338#endif
339 );
340 buffers.sums[l] = cursum.data();
341 }
342 }
343
// Same arrangement for the detected counts.
344 if (options.compute_detected) {
345 sanisizer::resize(output.detected, ngroups);
346 sanisizer::resize(buffers.detected, ngroups);
347 for (decltype(I(ngroups)) l = 0; l < ngroups; ++l) {
348 auto& curdet = output.detected[l];
350#ifdef SCRAN_AGGREGATE_TEST_INIT
351 , SCRAN_AGGREGATE_TEST_INIT
352#endif
353 );
354 buffers.detected[l] = curdet.data();
355 }
356 }
357
// Fill the freshly allocated storage through the pointer-based overload.
358 aggregate_across_cells(input, group, buffers, options);
359 return output;
360}
361
362}
363
364#endif
virtual Index_ ncol() const=0
virtual Index_ nrow() const=0
virtual bool prefer_rows() const=0
virtual std::unique_ptr< MyopicSparseExtractor< Value_, Index_ > > sparse(bool row, const Options &opt) const=0
Aggregate single-cell expression values.
Definition aggregate_across_cells.hpp:20
void aggregate_across_cells(const tatami::Matrix< Data_, Index_ > &input, const Group_ *const group, const AggregateAcrossCellsBuffers< Sum_, Detected_ > &buffers, const AggregateAcrossCellsOptions &options)
Definition aggregate_across_cells.hpp:273
void parallelize(Function_ fun, Index_ tasks, int threads)
void resize_container_to_Index_size(Container_ &container, Index_ x, Args_ &&... args)
Container_ create_container_of_Index_size(Index_ x, Args_ &&... args)
auto consecutive_extractor(const Matrix< Value_, Index_ > &matrix, bool row, Index_ iter_start, Index_ iter_length, Args_ &&... args)
Buffers for aggregate_across_cells().
Definition aggregate_across_cells.hpp:51
std::vector< Detected_ * > detected
Definition aggregate_across_cells.hpp:68
std::vector< Sum_ * > sums
Definition aggregate_across_cells.hpp:59
Options for aggregate_across_cells().
Definition aggregate_across_cells.hpp:25
int num_threads
Definition aggregate_across_cells.hpp:42
bool compute_detected
Definition aggregate_across_cells.hpp:36
bool compute_sums
Definition aggregate_across_cells.hpp:30
Results of aggregate_across_cells().
Definition aggregate_across_cells.hpp:78
std::vector< std::vector< Detected_ > > detected
Definition aggregate_across_cells.hpp:95
std::vector< std::vector< Sum_ > > sums
Definition aggregate_across_cells.hpp:86
bool sparse_ordered_index