scran_variances
Model per-gene variance in expression
Loading...
Searching...
No Matches
model_gene_variances.hpp
Go to the documentation of this file.
1#ifndef SCRAN_MODEL_GENE_VARIANCES_HPP
2#define SCRAN_MODEL_GENE_VARIANCES_HPP
3
4#include <algorithm>
5#include <vector>
6#include <limits>
7#include <cstddef>
8#include <cassert>
9#include <optional>
10
11#include "tatami/tatami.hpp"
12#include "tatami_stats/tatami_stats.hpp"
14#include "sanisizer/sanisizer.hpp"
15
17#include "utils.hpp"
18
24namespace scran_variances {
25
/**
 * Strategy used by `model_gene_variances_blocked()` to combine per-block statistics into a single average.
 */
enum class BlockAveragePolicy : unsigned char {
    MEAN,     /**< Weighted mean across blocks, with weights derived from the block sizes and `block_weight_policy`. */
    QUANTILE, /**< Quantile across blocks, at the probability given by `block_quantile`. */
    NONE      /**< No average is computed; only per-block statistics are reported. */
};
34
// NOTE(review): this listing omits the opening `struct ModelGeneVariancesOptions {` line and the
// declarations of `fit_variance_trend_options` and `variable_block_weight_parameters` (header
// lines 50 and 76 according to the index at the end of this page).

// Whether to fit a mean-variance trend. When false, the results-returning overloads of
// model_gene_variances() and model_gene_variances_blocked() set the `fitted`/`residual`
// buffers to NULL and leave the corresponding result vectors empty.
44 bool trend = true;
45
51
// How per-block statistics are combined into the average by model_gene_variances_blocked():
// MEAN uses weighted means, QUANTILE uses `block_quantile`, and NONE computes no average.
57 BlockAveragePolicy block_average_policy = BlockAveragePolicy::MEAN;
58
// Block weighting for BlockAveragePolicy::MEAN; passed to scran_blocks::compute_weights()
// together with the number of cells in each block.
69 scran_blocks::WeightPolicy block_weight_policy = scran_blocks::WeightPolicy::VARIABLE;
70
77
81 // Back-compatibility only.
// If false, the results-returning model_gene_variances_blocked() skips the average even when
// `block_average_policy` is not NONE.
82 bool compute_average = true;
// Probability passed to scran_blocks::parallel_quantiles() when `block_average_policy` is QUANTILE.
92 double block_quantile = 0.5;
93
// Number of threads. Forwarded to the tatami_stats variance calculations and copied into the
// trend-fitting options.
98 int num_threads = 1;
99};
100
// Output buffers for model_gene_variances() and friends. These are non-owning pointers to arrays
// allocated by the caller, and each array should hold one value per row (gene) of the input matrix.
// NOTE(review): the `struct ModelGeneVariancesBuffers {` line is not shown in this listing.
109template<typename Stat_>
// Per-gene means.
114 Stat_* mean;
115
// Per-gene variances.
119 Stat_* variance;
120
// Fitted values of the mean-variance trend. The trend is only fitted when both `fitted` and
// `residual` are non-NULL.
126 Stat_* fitted;
127
// Residuals from the trend. Uses the same NULL convention as `fitted`.
133 Stat_* residual;
134};
135
// Compute per-gene means and variances of `mat` into `buffers`. If both `buffers.fitted` and
// `buffers.residual` are non-NULL, also fit a mean-variance trend.
// NOTE(review): this listing omits the signature lines (header lines 154-156; per the index below,
// the parameters are `mat`, `buffers` and `options`) and the declaration of the trend workspace
// `work` (header line 167).
153template<typename Value_, typename Index_, typename Stat_>
157 const ModelGeneVariancesOptions& options
158) {
159 tatami_stats::VarianceBuffers<Stat_> vbuf;
160 vbuf.mean = buffers.mean;
161 vbuf.variance = buffers.variance;
162
163 tatami_stats::VarianceOptions vopt;
164 vopt.num_threads = options.num_threads;
// The leading `true` presumably requests per-row (per-gene) statistics. Confirm against the tatami_stats docs.
165 tatami_stats::variance(true, mat, vbuf, vopt);
166
// Copy the trend options so that options.num_threads can override the nested setting without
// modifying the caller's options.
168 auto fopt = options.fit_variance_trend_options;
169 fopt.num_threads = options.num_threads;
170
171 if (buffers.fitted != NULL && buffers.residual != NULL) {
172 const auto NR = mat.nrow();
// At least two cells are needed to compute a variance. Otherwise the trend outputs are filled with NaN.
173 if (mat.ncol() >= 2) {
174 fit_variance_trend(NR, buffers.mean, buffers.variance, buffers.fitted, buffers.residual, work, fopt);
175 } else {
176 std::fill_n(buffers.fitted, NR, std::numeric_limits<double>::quiet_NaN());
177 std::fill_n(buffers.residual, NR, std::numeric_limits<double>::quiet_NaN());
178 }
179 }
180}
181
// Results of model_gene_variances(). The per-gene statistics are owned by these vectors.
// NOTE(review): the `struct ModelGeneVariancesResults {` line is not shown in this listing.
186template<typename Stat_>
// Default-constructed results have empty vectors.
191 ModelGeneVariancesResults() = default;
192
// Allocate `ngenes` entries for `mean` and `variance`. `fitted`/`residual` get `ngenes` entries
// only if `trend` is true and are left empty otherwise. sanisizer::cast converts `ngenes` to each
// vector's size type, presumably failing if the value does not fit. If SCRAN_VARIANCES_TEST_INIT
// is defined, it supplies the initial value of every element (presumably only in test builds).
193 ModelGeneVariancesResults(const std::size_t ngenes, const bool trend) :
194 mean(sanisizer::cast<I<decltype(mean.size())> >(ngenes)
195#ifdef SCRAN_VARIANCES_TEST_INIT
196 , SCRAN_VARIANCES_TEST_INIT
197#endif
198 ),
199 variance(sanisizer::cast<I<decltype(variance.size())> >(ngenes)
200#ifdef SCRAN_VARIANCES_TEST_INIT
201 , SCRAN_VARIANCES_TEST_INIT
202#endif
203 ),
204 fitted(sanisizer::cast<I<decltype(fitted.size())> >(trend ? ngenes : 0)
205#ifdef SCRAN_VARIANCES_TEST_INIT
206 , SCRAN_VARIANCES_TEST_INIT
207#endif
208 ),
209 residual(sanisizer::cast<I<decltype(residual.size())> >(trend ? ngenes : 0)
210#ifdef SCRAN_VARIANCES_TEST_INIT
211 , SCRAN_VARIANCES_TEST_INIT
212#endif
213 )
214 {}
// Per-gene means.
222 std::vector<Stat_> mean;
223
// Per-gene variances.
227 std::vector<Stat_> variance;
228
// Fitted values of the mean-variance trend. Empty if the trend was not requested.
234 std::vector<Stat_> fitted;
235
// Residuals from the trend. Empty if the trend was not requested.
241 std::vector<Stat_> residual;
242};
243
// Convenience overload of model_gene_variances() that allocates, fills and returns the results.
// Buffers for `fitted`/`residual` are only provided when options.trend is true.
// NOTE(review): the signature line (header line 258) and the declaration of the local `buffers`
// (header line 261) are not shown in this listing.
257template<typename Stat_ = double, typename Value_, typename Index_>
259 ModelGeneVariancesResults<Stat_> output(mat.nrow(), options.trend); // cast is safe, as any tatami Index_ can always fit into a size_t.
260
// Point the buffers at the vectors owned by `output`.
262 buffers.mean = output.mean.data();
263 buffers.variance = output.variance.data();
264
265 if (options.trend) {
266 buffers.fitted = output.fitted.data();
267 buffers.residual = output.residual.data();
268 } else {
269 buffers.fitted = NULL;
270 buffers.residual = NULL;
271 }
272
// Delegate to the buffer-based overload, which writes into `output` through the pointers above.
273 model_gene_variances(mat, std::move(buffers), options);
274 return output;
275}
276
// Buffers for model_gene_variances_blocked().
// NOTE(review): this listing omits the `struct ModelGeneVariancesBlockedBuffers {` line and the
// `average` member (a ModelGeneVariancesBuffers<Stat_>, header line 298 per the index below).
281template<typename Stat_>
// One set of output buffers per block. model_gene_variances_blocked() throws if the length of
// this vector differs from its `num_blocks` argument.
287 std::vector<ModelGeneVariancesBuffers<Stat_> > per_block;
288
299};
300
// Results of model_gene_variances_blocked().
// NOTE(review): this listing omits the struct header line and the `average` member
// (a ModelGeneVariancesResults<Stat_>, header line 333 per the index below).
305template<typename Stat_>
311
// Allocate results for `nblocks` blocks, each covering `ngenes` genes. `average` gets zero-length
// vectors when `do_average` is false. `do_trend` controls whether `fitted`/`residual` are allocated,
// both per block and for the average.
312 ModelGeneVariancesBlockedResults(const std::size_t ngenes, const std::size_t nblocks, const bool do_average, const bool do_trend) :
313 average(do_average ? ngenes : 0, do_trend)
314 {
315 per_block.reserve(nblocks);
316 for (I<decltype(nblocks)> b = 0; b < nblocks; ++b) {
317 per_block.emplace_back(ngenes, do_trend);
318 }
319 }
// Results for each block, in block order.
327 std::vector<ModelGeneVariancesResults<Stat_> > per_block;
328
334};
335
/*
 * Collect the weights of all blocks that contain at least `min_size` cells.
 *
 * `tmp_weights` is cleared first and then filled in block order, so it lines up with the
 * pointers produced by extract_blocked_pointers() for the same `block_sizes` and `min_size`.
 * Both input vectors must have exactly `num_blocks` entries (checked in debug builds).
 */
template<typename Stat_, typename Index_>
void extract_blocked_weights(
    const std::size_t num_blocks,
    const std::vector<Stat_>& block_weights,
    const std::vector<Index_>& block_sizes,
    const Index_ min_size,
    std::vector<Stat_>& tmp_weights
) {
    // Both operands are std::size_t, so a plain equality check is exact.
    assert(num_blocks == block_weights.size());
    assert(num_blocks == block_sizes.size());

    tmp_weights.clear();
    for (std::size_t i = 0; i < num_blocks; ++i) {
        // Only sufficiently large blocks contribute a weight.
        if (block_sizes[i] >= min_size) {
            tmp_weights.push_back(block_weights[i]);
        }
    }
}
357
358template<typename Stat_, typename Index_, class Function_>
359void extract_blocked_pointers(
360 const std::size_t num_blocks,
361 const std::vector<ModelGeneVariancesBuffers<Stat_> >& per_block,
362 const std::vector<Index_>& block_sizes,
363 const Index_ min_size,
364 const Function_ fun,
365 std::vector<Stat_*>& tmp_pointers
366) {
367 assert(sanisizer::is_equal(num_blocks, per_block.size()));
368 assert(sanisizer::is_equal(num_blocks, block_sizes.size()));
369 tmp_pointers.clear();
370 for (std::size_t b = 0; b < num_blocks; ++b) {
371 if (block_sizes[b] < min_size) { // skip blocks with insufficient cells.
372 continue;
373 }
374 tmp_pointers.push_back(fun(per_block[b]));
375 }
376}
// Compute per-block means and variances of `mat`, optionally fit a mean-variance trend within
// each block, and optionally combine the per-block statistics into `buffers.average`.
// Throws std::runtime_error if `buffers.per_block` does not have `num_blocks` entries, or if
// average fitted values/residuals are requested while some block has no trend buffers.
// NOTE(review): this listing omits the name and `mat` parameter lines (header lines 407-408), the
// `buffers` parameter line (header line 411) and the declaration of the trend workspace `work`
// (header line 440). The full signature is given in the index at the end of this page.
406template<typename Value_, typename Index_, typename Block_, typename Stat_>
409 const Block_* const block,
410 const std::size_t num_blocks,
412 const ModelGeneVariancesOptions& options
413) {
414 if (!sanisizer::is_equal(num_blocks, buffers.per_block.size())) {
415 throw std::runtime_error("length of 'buffers.per_block' is not equal to 'num_blocks'");
416 }
// Debug-only check that every block index is less than num_blocks. The indexing of
// block_sizes below relies on this.
417 assert(mat.ncol() == 0 || sanisizer::is_less_than(*std::max_element(block, block + mat.ncol()), num_blocks));
418
419 // Just compute the block sizes here for simplicity.
420 // At some point, tatami_stats::variance() will accept the block sizes as input for greater efficiency.
421 // But, alas, today is not that day.
422 auto block_sizes = sanisizer::create<std::vector<Index_> >(num_blocks);
423 const auto NC = mat.ncol();
424 for (Index_ c = 0; c < NC; ++c) {
425 block_sizes[block[c]] += 1;
426 }
427
// Gather the per-block mean/variance output pointers in the form expected by tatami_stats::group_variance().
428 tatami_stats::GroupVarianceBuffers<Stat_> vbuf;
429 vbuf.mean.reserve(num_blocks);
430 vbuf.variance.reserve(num_blocks);
431 for (std::size_t b = 0; b < num_blocks; ++b) {
432 vbuf.mean.push_back(buffers.per_block[b].mean);
433 vbuf.variance.push_back(buffers.per_block[b].variance);
434 }
435
436 tatami_stats::GroupVarianceOptions vopt;
437 vopt.num_threads = options.num_threads;
438 tatami_stats::group_variance(true, mat, block, num_blocks, vbuf, vopt);
439
// Copy the trend options so that options.num_threads can override the nested setting without
// modifying the caller's options.
441 auto fopt = options.fit_variance_trend_options;
442 fopt.num_threads = options.num_threads;
443 bool all_trends_fitted = true;
444
// Fit the trend separately in each block. Blocks without fitted/residual buffers are skipped and
// recorded, so that averaging their trends can be refused below. Blocks with fewer than two cells
// cannot provide a variance, so their trend outputs are filled with NaN.
445 const auto NR = mat.nrow();
446 for (std::size_t b = 0; b < num_blocks; ++b) {
447 const auto& current = buffers.per_block[b];
448 if (current.fitted == NULL || current.residual == NULL) {
449 all_trends_fitted = false;
450 continue;
451 }
452 if (block_sizes[b] >= 2) {
453 fit_variance_trend(NR, current.mean, current.variance, current.fitted, current.residual, work, fopt);
454 } else {
455 std::fill_n(current.fitted, NR, std::numeric_limits<double>::quiet_NaN());
456 std::fill_n(current.residual, NR, std::numeric_limits<double>::quiet_NaN());
457 }
458 }
459
// A NULL average pointer means that the corresponding average is not requested.
460 const auto ave_means = buffers.average.mean;
461 const auto ave_variances = buffers.average.variance;
462 const auto ave_fitted = buffers.average.fitted;
463 const auto ave_residuals = buffers.average.residual;
464
465 if ((ave_fitted || ave_residuals) && !all_trends_fitted) {
466 throw std::runtime_error("cannot compute average fitted values/residuals without per-block trend fits");
467 }
468
469 std::vector<Stat_*> tmp_pointers;
470 tmp_pointers.reserve(num_blocks);
471
// Weighted mean across blocks. Means use every block with at least one cell. Variances, fitted
// values and residuals use only blocks with at least two cells, with the weights filtered to match.
472 if (options.block_average_policy == BlockAveragePolicy::MEAN) {
473 const auto block_weight = scran_blocks::compute_weights<Stat_>(block_sizes, options.block_weight_policy, options.variable_block_weight_parameters);
474 std::vector<Stat_> tmp_weights;
475 tmp_weights.reserve(num_blocks);
476
477 if (ave_means) {
478 extract_blocked_weights(num_blocks, block_weight, block_sizes, static_cast<Index_>(1), tmp_weights);
479 extract_blocked_pointers(num_blocks, buffers.per_block, block_sizes, static_cast<Index_>(1), [](const auto& x) -> Stat_* { return x.mean; }, tmp_pointers);
480 scran_blocks::parallel_weighted_means(NR, tmp_pointers, tmp_weights.data(), ave_means, /* skip_nan = */ false);
481 }
482
483 // Skip blocks without enough cells to compute the variance.
// These filtered weights are reused for the variances, fitted values and residuals below.
484 extract_blocked_weights(num_blocks, block_weight, block_sizes, static_cast<Index_>(2), tmp_weights);
485
486 if (ave_variances) {
487 extract_blocked_pointers(num_blocks, buffers.per_block, block_sizes, static_cast<Index_>(2), [](const auto& x) -> Stat_* { return x.variance; }, tmp_pointers);
488 scran_blocks::parallel_weighted_means(NR, tmp_pointers, tmp_weights.data(), ave_variances, /* skip_nan = */ false);
489 }
490
491 if (ave_fitted) {
492 extract_blocked_pointers(num_blocks, buffers.per_block, block_sizes, static_cast<Index_>(2), [](const auto& x) -> Stat_* { return x.fitted; }, tmp_pointers);
493 scran_blocks::parallel_weighted_means(NR, tmp_pointers, tmp_weights.data(), ave_fitted, /* skip_nan = */ false);
494 }
495
496 if (ave_residuals) {
497 extract_blocked_pointers(num_blocks, buffers.per_block, block_sizes, static_cast<Index_>(2), [](const auto& x) -> Stat_* { return x.residual; }, tmp_pointers);
498 scran_blocks::parallel_weighted_means(NR, tmp_pointers, tmp_weights.data(), ave_residuals, /* skip_nan = */ false);
499 }
500
501 } else if (options.block_average_policy == BlockAveragePolicy::QUANTILE) {
// Quantile across blocks at options.block_quantile. Blocks are filtered by cell count in the
// same way as in the MEAN branch.
502 if (ave_means) {
503 extract_blocked_pointers(num_blocks, buffers.per_block, block_sizes, static_cast<Index_>(1), [](const auto& x) -> Stat_* { return x.mean; }, tmp_pointers);
504 scran_blocks::parallel_quantiles(NR, tmp_pointers, options.block_quantile, ave_means, /* skip_nan = */ false);
505 }
506
507 // Again, skip blocks without enough cells to compute the variance.
508
509 if (ave_variances) {
510 extract_blocked_pointers(num_blocks, buffers.per_block, block_sizes, static_cast<Index_>(2), [](const auto& x) -> Stat_* { return x.variance; }, tmp_pointers);
511 scran_blocks::parallel_quantiles(NR, tmp_pointers, options.block_quantile, ave_variances, /* skip_nan = */ false);
512 }
513
514 if (ave_fitted) {
515 extract_blocked_pointers(num_blocks, buffers.per_block, block_sizes, static_cast<Index_>(2), [](const auto& x) -> Stat_* { return x.fitted; }, tmp_pointers);
516 scran_blocks::parallel_quantiles(NR, tmp_pointers, options.block_quantile, ave_fitted, /* skip_nan = */ false);
517 }
518
519 if (ave_residuals) {
520 extract_blocked_pointers(num_blocks, buffers.per_block, block_sizes, static_cast<Index_>(2), [](const auto& x) -> Stat_* { return x.residual; }, tmp_pointers);
521 scran_blocks::parallel_quantiles(NR, tmp_pointers, options.block_quantile, ave_residuals, /* skip_nan = */ false);
522 }
523 }
// Any other policy (i.e., NONE) leaves the average buffers untouched.
524}
525
// Convenience overload of model_gene_variances_blocked() that allocates, fills and returns the results.
// The average is computed only if options.compute_average is true and the policy is not NONE.
// Trend outputs are allocated only if options.trend is true.
// NOTE(review): this listing omits the name and `mat` parameter lines (header lines 544-545), the
// line that opens the declaration of `output` (header line 551) and the declaration of the local
// `buffers` (header line 558).
543template<typename Stat_ = double, typename Value_, typename Index_, typename Block_>
546 const Block_* const block,
547 const std::size_t num_blocks,
548 const ModelGeneVariancesOptions& options
549) {
550 const bool do_average = options.compute_average /* for back-compatibility */ && options.block_average_policy != BlockAveragePolicy::NONE;
552 mat.nrow(), // cast is safe, any tatami Index_ can always fit into a size_t.
553 num_blocks,
554 do_average,
555 options.trend
556 );
557
// Point each block's buffers at the vectors owned by `output`. Fitted/residual pointers are set
// only when the trend is requested.
559 sanisizer::resize(buffers.per_block, num_blocks);
560 for (std::size_t b = 0; b < num_blocks; ++b) {
561 auto& current = buffers.per_block[b];
562 current.mean = output.per_block[b].mean.data();
563 current.variance = output.per_block[b].variance.data();
564
565 if (options.trend) {
566 current.fitted = output.per_block[b].fitted.data();
567 current.residual = output.per_block[b].residual.data();
568 } else {
569 current.fitted = NULL;
570 current.residual = NULL;
571 }
572 }
573
// With no average requested, all average pointers are NULL so nothing is written. Otherwise they
// point into output.average, with fitted/residual only when the trend is requested.
574 if (!do_average) {
575 buffers.average.mean = NULL;
576 buffers.average.variance = NULL;
577 buffers.average.fitted = NULL;
578 buffers.average.residual = NULL;
579 } else {
580 buffers.average.mean = output.average.mean.data();
581 buffers.average.variance = output.average.variance.data();
582
583 if (options.trend) {
584 buffers.average.fitted = output.average.fitted.data();
585 buffers.average.residual = output.average.residual.data();
586 } else {
587 buffers.average.fitted = NULL;
588 buffers.average.residual = NULL;
589 }
590 }
591
// Delegate to the buffer-based overload, which writes into `output` through the pointers above.
592 model_gene_variances_blocked(mat, block, num_blocks, buffers, options);
593 return output;
594}
595
596}
597
598#endif
virtual Index_ ncol() const=0
virtual Index_ nrow() const=0
Fit a mean-variance trend to log-count data.
void compute_weights(const std::size_t num_blocks, const Size_ *const sizes, const WeightPolicy policy, const VariableWeightParameters &variable, Weight_ *const weights)
void parallel_weighted_means(const std::size_t n, std::vector< Stat_ * > in, const Weight_ *const w, Output_ *const out, const bool skip_nan)
void parallel_quantiles(const std::size_t n, const std::vector< Stat_ * > &in, const double quantile, Output_ *const out, const bool skip_nan)
Variance modelling for single-cell expression data.
Definition choose_highly_variable_genes.hpp:15
void model_gene_variances(const tatami::Matrix< Value_, Index_ > &mat, const ModelGeneVariancesBuffers< Stat_ > buffers, const ModelGeneVariancesOptions &options)
Definition model_gene_variances.hpp:154
void fit_variance_trend(const std::size_t n, const Float_ *const mean, const Float_ *const variance, Float_ *const fitted, Float_ *const residual, FitVarianceTrendWorkspace< Float_ > &workspace, const FitVarianceTrendOptions &options)
Definition fit_variance_trend.hpp:149
void model_gene_variances_blocked(const tatami::Matrix< Value_, Index_ > &mat, const Block_ *const block, const std::size_t num_blocks, const ModelGeneVariancesBlockedBuffers< Stat_ > &buffers, const ModelGeneVariancesOptions &options)
Definition model_gene_variances.hpp:407
BlockAveragePolicy
Definition model_gene_variances.hpp:33
Options for fit_variance_trend().
Definition fit_variance_trend.hpp:24
int num_threads
Definition fit_variance_trend.hpp:96
Workspace for fit_variance_trend().
Definition fit_variance_trend.hpp:105
Buffers for model_gene_variances_blocked().
Definition model_gene_variances.hpp:282
ModelGeneVariancesBuffers< Stat_ > average
Definition model_gene_variances.hpp:298
std::vector< ModelGeneVariancesBuffers< Stat_ > > per_block
Definition model_gene_variances.hpp:287
Results of model_gene_variances_blocked().
Definition model_gene_variances.hpp:306
std::vector< ModelGeneVariancesResults< Stat_ > > per_block
Definition model_gene_variances.hpp:327
ModelGeneVariancesResults< Stat_ > average
Definition model_gene_variances.hpp:333
Buffers for model_gene_variances() and friends.
Definition model_gene_variances.hpp:110
Stat_ * mean
Definition model_gene_variances.hpp:114
Stat_ * fitted
Definition model_gene_variances.hpp:126
Stat_ * variance
Definition model_gene_variances.hpp:119
Stat_ * residual
Definition model_gene_variances.hpp:133
Options for model_gene_variances() and friends.
Definition model_gene_variances.hpp:38
FitVarianceTrendOptions fit_variance_trend_options
Definition model_gene_variances.hpp:50
double block_quantile
Definition model_gene_variances.hpp:92
bool trend
Definition model_gene_variances.hpp:44
BlockAveragePolicy block_average_policy
Definition model_gene_variances.hpp:57
scran_blocks::VariableWeightParameters variable_block_weight_parameters
Definition model_gene_variances.hpp:76
int num_threads
Definition model_gene_variances.hpp:98
scran_blocks::WeightPolicy block_weight_policy
Definition model_gene_variances.hpp:69
Results of model_gene_variances().
Definition model_gene_variances.hpp:187
std::vector< Stat_ > fitted
Definition model_gene_variances.hpp:234
std::vector< Stat_ > variance
Definition model_gene_variances.hpp:227
std::vector< Stat_ > residual
Definition model_gene_variances.hpp:241
std::vector< Stat_ > mean
Definition model_gene_variances.hpp:222