1#ifndef SCRAN_MODEL_GENE_VARIANCES_HPP
2#define SCRAN_MODEL_GENE_VARIANCES_HPP
12#include "tatami_stats/tatami_stats.hpp"
14#include "sanisizer/sanisizer.hpp"
82 bool compute_average =
true;
109template<
typename Stat_>
153template<
typename Value_,
typename Index_,
typename Stat_>
159 tatami_stats::VarianceBuffers<Stat_> vbuf;
160 vbuf.mean = buffers.
mean;
163 tatami_stats::VarianceOptions vopt;
165 tatami_stats::variance(
true, mat, vbuf, vopt);
172 const auto NR = mat.
nrow();
173 if (mat.
ncol() >= 2) {
176 std::fill_n(buffers.
fitted, NR, std::numeric_limits<double>::quiet_NaN());
177 std::fill_n(buffers.
residual, NR, std::numeric_limits<double>::quiet_NaN());
186template<
typename Stat_>
194 mean(sanisizer::cast<I<
decltype(
mean.size())> >(ngenes)
195#ifdef SCRAN_VARIANCES_TEST_INIT
196 , SCRAN_VARIANCES_TEST_INIT
200#ifdef SCRAN_VARIANCES_TEST_INIT
201 , SCRAN_VARIANCES_TEST_INIT
204 fitted(sanisizer::cast<I<
decltype(
fitted.size())> >(trend ? ngenes : 0)
205#ifdef SCRAN_VARIANCES_TEST_INIT
206 , SCRAN_VARIANCES_TEST_INIT
210#ifdef SCRAN_VARIANCES_TEST_INIT
211 , SCRAN_VARIANCES_TEST_INIT
257template<
typename Stat_ =
double,
typename Value_,
typename Index_>
281template<
typename Stat_>
287 std::vector<ModelGeneVariancesBuffers<Stat_> >
per_block;
305template<
typename Stat_>
313 average(do_average ? ngenes : 0, do_trend)
316 for (I<
decltype(nblocks)> b = 0; b < nblocks; ++b) {
317 per_block.emplace_back(ngenes, do_trend);
327 std::vector<ModelGeneVariancesResults<Stat_> >
per_block;
339template<
typename Stat_,
typename Index_>
340void extract_blocked_weights(
341 const std::size_t num_blocks,
342 const std::vector<Stat_>& block_weights,
343 const std::vector<Index_>& block_sizes,
344 const Index_ min_size,
345 std::vector<Stat_>& tmp_weights
347 assert(sanisizer::is_equal(num_blocks, block_weights.size()));
348 assert(sanisizer::is_equal(num_blocks, block_sizes.size()));
350 for (std::size_t b = 0; b < num_blocks; ++b) {
351 if (block_sizes[b] < min_size) {
354 tmp_weights.push_back(block_weights[b]);
358template<
typename Stat_,
typename Index_,
class Function_>
359void extract_blocked_pointers(
360 const std::size_t num_blocks,
361 const std::vector<ModelGeneVariancesBuffers<Stat_> >& per_block,
362 const std::vector<Index_>& block_sizes,
363 const Index_ min_size,
365 std::vector<Stat_*>& tmp_pointers
367 assert(sanisizer::is_equal(num_blocks, per_block.size()));
368 assert(sanisizer::is_equal(num_blocks, block_sizes.size()));
369 tmp_pointers.clear();
370 for (std::size_t b = 0; b < num_blocks; ++b) {
371 if (block_sizes[b] < min_size) {
374 tmp_pointers.push_back(fun(per_block[b]));
406template<
typename Value_,
typename Index_,
typename Block_,
typename Stat_>
409 const Block_*
const block,
410 const std::size_t num_blocks,
414 if (!sanisizer::is_equal(num_blocks, buffers.
per_block.size())) {
415 throw std::runtime_error(
"length of 'buffers.per_block' is not equal to 'num_blocks'");
417 assert(mat.
ncol() == 0 || sanisizer::is_less_than(*std::max_element(block, block + mat.
ncol()), num_blocks));
422 auto block_sizes = sanisizer::create<std::vector<Index_> >(num_blocks);
423 const auto NC = mat.
ncol();
424 for (Index_ c = 0; c < NC; ++c) {
425 block_sizes[block[c]] += 1;
428 tatami_stats::GroupVarianceBuffers<Stat_> vbuf;
429 vbuf.mean.reserve(num_blocks);
430 vbuf.variance.reserve(num_blocks);
431 for (std::size_t b = 0; b < num_blocks; ++b) {
432 vbuf.mean.push_back(buffers.
per_block[b].mean);
433 vbuf.variance.push_back(buffers.
per_block[b].variance);
436 tatami_stats::GroupVarianceOptions vopt;
438 tatami_stats::group_variance(
true, mat, block, num_blocks, vbuf, vopt);
443 bool all_trends_fitted =
true;
445 const auto NR = mat.
nrow();
446 for (std::size_t b = 0; b < num_blocks; ++b) {
447 const auto& current = buffers.
per_block[b];
448 if (current.fitted == NULL || current.residual == NULL) {
449 all_trends_fitted =
false;
452 if (block_sizes[b] >= 2) {
453 fit_variance_trend(NR, current.mean, current.variance, current.fitted, current.residual, work, fopt);
455 std::fill_n(current.fitted, NR, std::numeric_limits<double>::quiet_NaN());
456 std::fill_n(current.residual, NR, std::numeric_limits<double>::quiet_NaN());
460 const auto ave_means = buffers.
average.mean;
461 const auto ave_variances = buffers.
average.variance;
462 const auto ave_fitted = buffers.
average.fitted;
463 const auto ave_residuals = buffers.
average.residual;
465 if ((ave_fitted || ave_residuals) && !all_trends_fitted) {
466 throw std::runtime_error(
"cannot compute average fitted values/residuals without per-block trend fits");
469 std::vector<Stat_*> tmp_pointers;
470 tmp_pointers.reserve(num_blocks);
474 std::vector<Stat_> tmp_weights;
475 tmp_weights.reserve(num_blocks);
478 extract_blocked_weights(num_blocks, block_weight, block_sizes,
static_cast<Index_
>(1), tmp_weights);
479 extract_blocked_pointers(num_blocks, buffers.
per_block, block_sizes,
static_cast<Index_
>(1), [](
const auto& x) -> Stat_* { return x.mean; }, tmp_pointers);
484 extract_blocked_weights(num_blocks, block_weight, block_sizes,
static_cast<Index_
>(2), tmp_weights);
487 extract_blocked_pointers(num_blocks, buffers.
per_block, block_sizes,
static_cast<Index_
>(2), [](
const auto& x) -> Stat_* { return x.variance; }, tmp_pointers);
492 extract_blocked_pointers(num_blocks, buffers.
per_block, block_sizes,
static_cast<Index_
>(2), [](
const auto& x) -> Stat_* { return x.fitted; }, tmp_pointers);
497 extract_blocked_pointers(num_blocks, buffers.
per_block, block_sizes,
static_cast<Index_
>(2), [](
const auto& x) -> Stat_* { return x.residual; }, tmp_pointers);
503 extract_blocked_pointers(num_blocks, buffers.
per_block, block_sizes,
static_cast<Index_
>(1), [](
const auto& x) -> Stat_* { return x.mean; }, tmp_pointers);
510 extract_blocked_pointers(num_blocks, buffers.
per_block, block_sizes,
static_cast<Index_
>(2), [](
const auto& x) -> Stat_* { return x.variance; }, tmp_pointers);
515 extract_blocked_pointers(num_blocks, buffers.
per_block, block_sizes,
static_cast<Index_
>(2), [](
const auto& x) -> Stat_* { return x.fitted; }, tmp_pointers);
520 extract_blocked_pointers(num_blocks, buffers.
per_block, block_sizes,
static_cast<Index_
>(2), [](
const auto& x) -> Stat_* { return x.residual; }, tmp_pointers);
543template<
typename Stat_ =
double,
typename Value_,
typename Index_,
typename Block_>
546 const Block_*
const block,
547 const std::size_t num_blocks,
550 const bool do_average = options.compute_average && options.
block_average_policy != BlockAveragePolicy::NONE;
559 sanisizer::resize(buffers.
per_block, num_blocks);
560 for (std::size_t b = 0; b < num_blocks; ++b) {
562 current.mean = output.
per_block[b].mean.data();
563 current.variance = output.
per_block[b].variance.data();
566 current.fitted = output.
per_block[b].fitted.data();
567 current.residual = output.
per_block[b].residual.data();
569 current.fitted = NULL;
570 current.residual = NULL;
576 buffers.
average.variance = NULL;
578 buffers.
average.residual = NULL;
588 buffers.
average.residual = NULL;
virtual Index_ ncol() const=0
virtual Index_ nrow() const=0
Fit a mean-variance trend to log-count data.
void compute_weights(const std::size_t num_blocks, const Size_ *const sizes, const WeightPolicy policy, const VariableWeightParameters &variable, Weight_ *const weights)
void parallel_weighted_means(const std::size_t n, std::vector< Stat_ * > in, const Weight_ *const w, Output_ *const out, const bool skip_nan)
void parallel_quantiles(const std::size_t n, const std::vector< Stat_ * > &in, const double quantile, Output_ *const out, const bool skip_nan)
Variance modelling for single-cell expression data.
Definition choose_highly_variable_genes.hpp:15
void model_gene_variances(const tatami::Matrix< Value_, Index_ > &mat, const ModelGeneVariancesBuffers< Stat_ > buffers, const ModelGeneVariancesOptions &options)
Definition model_gene_variances.hpp:154
void fit_variance_trend(const std::size_t n, const Float_ *const mean, const Float_ *const variance, Float_ *const fitted, Float_ *const residual, FitVarianceTrendWorkspace< Float_ > &workspace, const FitVarianceTrendOptions &options)
Definition fit_variance_trend.hpp:149
void model_gene_variances_blocked(const tatami::Matrix< Value_, Index_ > &mat, const Block_ *const block, const std::size_t num_blocks, const ModelGeneVariancesBlockedBuffers< Stat_ > &buffers, const ModelGeneVariancesOptions &options)
Definition model_gene_variances.hpp:407
BlockAveragePolicy
Definition model_gene_variances.hpp:33
Options for fit_variance_trend().
Definition fit_variance_trend.hpp:24
int num_threads
Definition fit_variance_trend.hpp:96
Workspace for fit_variance_trend().
Definition fit_variance_trend.hpp:105
Buffers for model_gene_variances_blocked().
Definition model_gene_variances.hpp:282
ModelGeneVariancesBuffers< Stat_ > average
Definition model_gene_variances.hpp:298
std::vector< ModelGeneVariancesBuffers< Stat_ > > per_block
Definition model_gene_variances.hpp:287
Results of model_gene_variances_blocked().
Definition model_gene_variances.hpp:306
std::vector< ModelGeneVariancesResults< Stat_ > > per_block
Definition model_gene_variances.hpp:327
ModelGeneVariancesResults< Stat_ > average
Definition model_gene_variances.hpp:333
Buffers for model_gene_variances() and friends.
Definition model_gene_variances.hpp:110
Stat_ * mean
Definition model_gene_variances.hpp:114
Stat_ * fitted
Definition model_gene_variances.hpp:126
Stat_ * variance
Definition model_gene_variances.hpp:119
Stat_ * residual
Definition model_gene_variances.hpp:133
Options for model_gene_variances() and friends.
Definition model_gene_variances.hpp:38
FitVarianceTrendOptions fit_variance_trend_options
Definition model_gene_variances.hpp:50
double block_quantile
Definition model_gene_variances.hpp:92
bool trend
Definition model_gene_variances.hpp:44
BlockAveragePolicy block_average_policy
Definition model_gene_variances.hpp:57
scran_blocks::VariableWeightParameters variable_block_weight_parameters
Definition model_gene_variances.hpp:76
int num_threads
Definition model_gene_variances.hpp:98
scran_blocks::WeightPolicy block_weight_policy
Definition model_gene_variances.hpp:69
Results of model_gene_variances().
Definition model_gene_variances.hpp:187
std::vector< Stat_ > fitted
Definition model_gene_variances.hpp:234
std::vector< Stat_ > variance
Definition model_gene_variances.hpp:227
std::vector< Stat_ > residual
Definition model_gene_variances.hpp:241
std::vector< Stat_ > mean
Definition model_gene_variances.hpp:222