scran_aggregate
Aggregate expression values across cells
Loading...
Searching...
No Matches
aggregate_across_cells.hpp
Go to the documentation of this file.
1#ifndef SCRAN_AGGREGATE_AGGREGATE_ACROSS_CELLS_HPP
2#define SCRAN_AGGREGATE_AGGREGATE_ACROSS_CELLS_HPP
3
4#include <algorithm>
5#include <vector>
6
7#include "tatami/tatami.hpp"
8#include "tatami_stats/tatami_stats.hpp"
9
15namespace scran_aggregate {
16
/**
 * @brief Options for `aggregate_across_cells()`.
 */
struct AggregateAcrossCellsOptions {
    /**
     * Whether to compute the per-level sum of each gene.
     * When false, the results-returning overload of `aggregate_across_cells()`
     * does not allocate any sum vectors.
     */
    bool compute_sums = true;

    /**
     * Whether to compute, for each gene, the number of cells in each level
     * with a value greater than zero.
     * When false, the results-returning overload of `aggregate_across_cells()`
     * does not allocate any detected-count vectors.
     */
    bool compute_detected = true;

    /**
     * Number of threads, forwarded to `tatami::parallelize()`.
     */
    int num_threads = 1;
};
39
/**
 * @brief Buffers for `aggregate_across_cells()`.
 *
 * @tparam Sum_ Type of each aggregated sum.
 * @tparam Detected_ Type of each detected count.
 */
template <typename Sum_, typename Detected_>
struct AggregateAcrossCellsBuffers {
    /**
     * One pointer per factor level.
     * `sums[l][g]` is overwritten with the sum of row (gene) `g` over all columns assigned to level `l`,
     * so each pointer must refer to an array with at least one entry per row of the input matrix.
     * Leave this vector empty to skip the calculation of sums.
     */
    std::vector<Sum_*> sums;

    /**
     * One pointer per factor level.
     * `detected[l][g]` is overwritten with the number of columns in level `l` where row (gene) `g` has a value greater than zero,
     * so each pointer must refer to an array with at least one entry per row of the input matrix.
     * Leave this vector empty to skip the calculation of detected counts.
     */
    std::vector<Detected_*> detected;

};
66
/**
 * @brief Results of `aggregate_across_cells()`.
 *
 * @tparam Sum_ Type of each aggregated sum.
 * @tparam Detected_ Type of each detected count.
 */
template <typename Sum_, typename Detected_>
struct AggregateAcrossCellsResults {
    /**
     * One inner vector per factor level, each with one entry per row (gene) of the input matrix.
     * Empty if sums were not requested.
     */
    std::vector<std::vector<Sum_> > sums;

    /**
     * One inner vector per factor level, each with one entry per row (gene) of the input matrix.
     * Empty if detected counts were not requested.
     */
    std::vector<std::vector<Detected_> > detected;
};
92
96namespace internal {
97
98template<bool sparse_, typename Data_, typename Index_, typename Factor_, typename Sum_, typename Detected_>
99void compute_aggregate_by_row(
101 const Factor_* factor,
103 const AggregateAcrossCellsOptions& options)
104{
105 tatami::Options opt;
106 opt.sparse_ordered_index = false;
107
108 tatami::parallelize([&](size_t, Index_ s, Index_ l) {
109 auto ext = tatami::consecutive_extractor<sparse_>(&p, true, s, l, opt);
110 size_t nsums = buffers.sums.size();
111 std::vector<Sum_> tmp_sums(nsums);
112 size_t ndetected = buffers.detected.size();
113 std::vector<Detected_> tmp_detected(ndetected);
114
115 auto NC = p.ncol();
116 std::vector<Data_> vbuffer(NC);
117 typename std::conditional<sparse_, std::vector<Index_>, Index_>::type ibuffer(NC);
118
119 for (Index_ x = s, end = s + l; x < end; ++x) {
120 auto row = [&]() {
121 if constexpr(sparse_) {
122 return ext->fetch(vbuffer.data(), ibuffer.data());
123 } else {
124 return ext->fetch(vbuffer.data());
125 }
126 }();
127
128 if (nsums) {
129 std::fill(tmp_sums.begin(), tmp_sums.end(), 0);
130
131 if constexpr(sparse_) {
132 for (Index_ j = 0; j < row.number; ++j) {
133 tmp_sums[factor[row.index[j]]] += row.value[j];
134 }
135 } else {
136 for (Index_ j = 0; j < NC; ++j) {
137 tmp_sums[factor[j]] += row[j];
138 }
139 }
140
141 // Computing before transferring for more cache-friendliness.
142 for (size_t l = 0; l < nsums; ++l) {
143 buffers.sums[l][x] = tmp_sums[l];
144 }
145 }
146
147 if (ndetected) {
148 std::fill(tmp_detected.begin(), tmp_detected.end(), 0);
149
150 if constexpr(sparse_) {
151 for (Index_ j = 0; j < row.number; ++j) {
152 tmp_detected[factor[row.index[j]]] += (row.value[j] > 0);
153 }
154 } else {
155 for (Index_ j = 0; j < NC; ++j) {
156 tmp_detected[factor[j]] += (row[j] > 0);
157 }
158 }
159
160 for (size_t l = 0; l < ndetected; ++l) {
161 buffers.detected[l][x] = tmp_detected[l];
162 }
163 }
164 }
165 }, p.nrow(), options.num_threads);
166}
167
168template<bool sparse_, typename Data_, typename Index_, typename Factor_, typename Sum_, typename Detected_>
169void compute_aggregate_by_column(
171 const Factor_* factor,
172 const AggregateAcrossCellsBuffers<Sum_, Detected_>& buffers,
173 const AggregateAcrossCellsOptions& options)
174{
175 tatami::Options opt;
176 opt.sparse_ordered_index = false;
177
178 tatami::parallelize([&](size_t t, Index_ start, Index_ length) {
179 auto NC = p.ncol();
180 auto ext = tatami::consecutive_extractor<sparse_>(&p, false, static_cast<Index_>(0), NC, start, length, opt);
181 std::vector<Data_> vbuffer(length);
182 typename std::conditional<sparse_, std::vector<Index_>, Index_>::type ibuffer(length);
183
184 size_t num_sums = buffers.sums.size();
185 auto get_sum = [&](Index_ i) -> Sum_* { return buffers.sums[i]; };
186 tatami_stats::LocalOutputBuffers<Sum_, decltype(get_sum)> local_sums(t, num_sums, start, length, std::move(get_sum));
187 auto get_detected = [&](Index_ i) -> Detected_* { return buffers.detected[i]; };
188 size_t num_detected = buffers.detected.size();
189 tatami_stats::LocalOutputBuffers<Detected_, decltype(get_detected)> local_detected(t, num_detected, start, length, std::move(get_detected));
190
191 for (Index_ x = 0; x < NC; ++x) {
192 auto current = factor[x];
193
194 if constexpr(sparse_) {
195 auto col = ext->fetch(vbuffer.data(), ibuffer.data());
196 if (num_sums) {
197 auto cursum = local_sums.data(current);
198 for (Index_ i = 0; i < col.number; ++i) {
199 cursum[col.index[i] - start] += col.value[i];
200 }
201 }
202 if (num_detected) {
203 auto curdetected = local_detected.data(current);
204 for (Index_ i = 0; i < col.number; ++i) {
205 curdetected[col.index[i] - start] += (col.value[i] > 0);
206 }
207 }
208
209 } else {
210 auto col = ext->fetch(vbuffer.data());
211 if (num_sums) {
212 auto cursum = local_sums.data(current);
213 for (Index_ i = 0; i < length; ++i) {
214 cursum[i] += col[i];
215 }
216 }
217 if (num_detected) {
218 auto curdetected = local_detected.data(current);
219 for (Index_ i = 0; i < length; ++i) {
220 curdetected[i] += (col[i] > 0);
221 }
222 }
223 }
224 }
225
226 local_sums.transfer();
227 local_detected.transfer();
228 }, p.nrow(), options.num_threads);
229}
230
231}
255template<typename Data_, typename Index_, typename Factor_, typename Sum_, typename Detected_>
258 const Factor_* factor,
260 const AggregateAcrossCellsOptions& options)
261{
262 if (input.prefer_rows()) {
263 if (input.sparse()) {
264 internal::compute_aggregate_by_row<true>(input, factor, buffers, options);
265 } else {
266 internal::compute_aggregate_by_row<false>(input, factor, buffers, options);
267 }
268 } else {
269 if (input.sparse()) {
270 internal::compute_aggregate_by_column<true>(input, factor, buffers, options);
271 } else {
272 internal::compute_aggregate_by_column<false>(input, factor, buffers, options);
273 }
274 }
275}
276
294template<typename Sum_ = double, typename Detected_ = int, typename Data_, typename Index_, typename Factor_>
297 const Factor_* factor,
298 const AggregateAcrossCellsOptions& options)
299{
300 size_t NC = input.ncol();
301 size_t nlevels = (NC ? *std::max_element(factor, factor + NC) + 1 : 0);
302 size_t ngenes = input.nrow();
303
306
307 if (options.compute_sums) {
308 output.sums.resize(nlevels, std::vector<Sum_>(ngenes
309#ifdef SCRAN_AGGREGATE_TEST_INIT
310 , SCRAN_AGGREGATE_TEST_INIT
311#endif
312 ));
313 buffers.sums.resize(nlevels);
314 for (size_t l = 0; l < nlevels; ++l) {
315 buffers.sums[l] = output.sums[l].data();
316 }
317 }
318
319 if (options.compute_detected) {
320 output.detected.resize(nlevels, std::vector<Detected_>(ngenes
321#ifdef SCRAN_AGGREGATE_TEST_INIT
322 , SCRAN_AGGREGATE_TEST_INIT
323#endif
324 ));
325 buffers.detected.resize(nlevels);
326 for (size_t l = 0; l < nlevels; ++l) {
327 buffers.detected[l] = output.detected[l].data();
328 }
329 }
330
331 aggregate_across_cells(input, factor, buffers, options);
332 return output;
333}
334
335}
336
337#endif
virtual Index_ ncol() const=0
virtual Index_ nrow() const=0
virtual bool prefer_rows() const=0
virtual std::unique_ptr< MyopicSparseExtractor< Value_, Index_ > > sparse(bool row, const Options &opt) const=0
Aggregate single-cell expression values.
Definition aggregate_across_cells.hpp:15
void aggregate_across_cells(const tatami::Matrix< Data_, Index_ > &input, const Factor_ *factor, const AggregateAcrossCellsBuffers< Sum_, Detected_ > &buffers, const AggregateAcrossCellsOptions &options)
Definition aggregate_across_cells.hpp:256
void parallelize(Function_ fun, Index_ tasks, int threads)
auto consecutive_extractor(const Matrix< Value_, Index_ > *mat, bool row, Index_ iter_start, Index_ iter_length, Args_ &&... args)
Buffers for aggregate_across_cells().
Definition aggregate_across_cells.hpp:46
std::vector< Detected_ * > detected
Definition aggregate_across_cells.hpp:63
std::vector< Sum_ * > sums
Definition aggregate_across_cells.hpp:54
Options for aggregate_across_cells().
Definition aggregate_across_cells.hpp:20
int num_threads
Definition aggregate_across_cells.hpp:37
bool compute_detected
Definition aggregate_across_cells.hpp:31
bool compute_sums
Definition aggregate_across_cells.hpp:25
Results of aggregate_across_cells().
Definition aggregate_across_cells.hpp:73
std::vector< std::vector< Detected_ > > detected
Definition aggregate_across_cells.hpp:90
std::vector< std::vector< Sum_ > > sums
Definition aggregate_across_cells.hpp:81
bool sparse_ordered_index