44 const std::vector<std::size_t>& num_dims,
45 const Index_ num_cells,
46 const std::vector<Input_*>& embeddings,
47 const std::vector<Scale_>& scaling,
50 const auto nembed = num_dims.size();
51 if (embeddings.size() != nembed || scaling.size() != nembed) {
52 throw std::runtime_error(
"'num_dims', 'embeddings' and 'scale' should have the same length");
55 const std::size_t ntotal = std::accumulate(num_dims.begin(), num_dims.end(),
static_cast<std::size_t
>(0));
56 std::size_t starting_dim = 0;
58 for (I<
decltype(nembed)> e = 0; e < nembed; ++e) {
59 const auto curdim = num_dims[e];
60 const auto inptr = embeddings[e];
61 const auto s = scaling[e];
66 for (Index_ c = 0; c < num_cells; ++c) {
67 const auto out_offset = sanisizer::nd_offset<std::size_t>(starting_dim, ntotal, c);
68 std::fill_n(output + out_offset, curdim, 0);
71 for (Index_ c = 0; c < num_cells; ++c) {
72 for (I<
decltype(curdim)> d = 0; d < curdim; ++d) {
73 const auto out_offset = sanisizer::nd_offset<std::size_t>(starting_dim + d, ntotal, c);
74 const auto in_offset = sanisizer::nd_offset<std::size_t>(d, curdim, c);
75 output[out_offset] = inptr[in_offset] * s;
80 starting_dim += curdim;
void combine_scaled_embeddings(const std::vector< std::size_t > &num_dims, const Index_ num_cells, const std::vector< Input_ * > &embeddings, const std::vector< Scale_ > &scaling, Output_ *const output)
Definition combine_scaled_embeddings.hpp:43