umappp
A C++ library for UMAP
Loading...
Searching...
No Matches
initialize.hpp
Go to the documentation of this file.
1#ifndef UMAPPP_INITIALIZE_HPP
2#define UMAPPP_INITIALIZE_HPP
3
4#include "NeighborList.hpp"
5#include "combine_neighbor_sets.hpp"
6#include "find_ab.hpp"
7#include "neighbor_similarities.hpp"
8#include "spectral_init.hpp"
9#include "Status.hpp"
10
11#include "knncolle/knncolle.hpp"
12
13#include <random>
14#include <cstdint>
15
21namespace umappp {
22
26namespace internal {
27
// Choose the number of optimization epochs.
//
// A non-negative 'num_epochs' is treated as an explicit user choice and is
// returned unchanged. A negative value requests an automatic default that
// depends on the number of observations 'size':
//   - size <= 10000: 200 + 300 = 500 epochs;
//   - size >  10000: 200 + ceil(300 * 10000 / size) epochs.
//
// The aim is that the 'extra work' beyond the minimal 200 epochs is the same
// regardless of the number of observations. Assuming one calculation per
// observation per epoch, the extra work at the 10000-observation threshold is
// 300 * 10000 calculations; for larger datasets, we choose the number of extra
// epochs that keeps that total (approximately, due to rounding up) constant.
//
// @param num_epochs Requested number of epochs; negative means "choose automatically".
// @param size Number of observations.
// @return Number of epochs to run.
inline int choose_num_epochs(int num_epochs, size_t size) {
    if (num_epochs >= 0) {
        // Explicitly requested by the caller; respect it as-is.
        return num_epochs;
    }

    constexpr int limit = 10000, minimal = 200, maximal = 300;
    if (size <= static_cast<size_t>(limit)) {
        return minimal + maximal;
    }

    // Scale the extra epochs down so that (extra epochs * size) stays near
    // maximal * limit. Dividing by a double avoids integer truncation before
    // rounding up.
    return minimal + static_cast<int>(std::ceil(maximal * limit / static_cast<double>(size)));
}
48
49}
/*
 * Initialize the UMAP algorithm from a precomputed list of neighbors.
 *
 * Steps performed here:
 *   1. Convert the neighbor list into per-neighbor weights via
 *      internal::neighbor_similarities(), configured with
 *      options.local_connectivity, options.bandwidth and options.num_threads.
 *   2. Combine the neighbor sets via internal::combine_neighbor_sets(),
 *      using options.mix_ratio.
 *   3. Fill the initial coordinates in 'embedding' according to
 *      options.initialize (see the comments in the body).
 *   4. If options.a or options.b is non-positive, compute both from
 *      options.spread and options.min_dist with internal::find_ab().
 *   5. Resolve options.num_epochs via internal::choose_num_epochs(), where a
 *      negative value selects a default based on x.size().
 * Finally, a Status is constructed from the epoch data, the updated options,
 * num_dim and embedding.
 *
 * @param x Neighbor list for each observation. Taken by value, so the in-place
 *   modifications made here do not affect the caller's object.
 * @param num_dim Number of dimensions of the embedding.
 * @param embedding Pointer to the embedding coordinates. NOTE(review): presumably
 *   expected to hold num_dim * x.size() values; confirm the layout against
 *   spectral_init()/random_init().
 * @param options Algorithm options. Taken by value; a, b and num_epochs may be
 *   overwritten in this local copy before it is passed to the Status.
 * @return Status object for the subsequent optimization iterations.
 */
72template<typename Index_, typename Float_>
73Status<Index_, Float_> initialize(NeighborList<Index_, Float_> x, int num_dim, Float_* embedding, Options options) {
 // Parameters for turning neighbor distances into weights
 // (presumably modifies 'x' in place; confirm in neighbor_similarities.hpp).
74 internal::NeighborSimilaritiesOptions<Float_> nsopt;
75 nsopt.local_connectivity = options.local_connectivity;
76 nsopt.bandwidth = options.bandwidth;
77 nsopt.num_threads = options.num_threads;
78 internal::neighbor_similarities(x, nsopt);
79
 // Merge the per-observation neighbor sets, weighted by mix_ratio.
80 internal::combine_neighbor_sets(x, static_cast<Float_>(options.mix_ratio));
81
82 // Choosing the manner of initialization.
 // SPECTRAL: try spectral_init(); if it reports failure, fall back to random_init().
 // SPECTRAL_ONLY: try spectral_init(); if it fails, 'embedding' is not refilled here.
 // RANDOM: use random_init() directly.
 // Any other method: 'embedding' is left untouched, so its existing contents
 // are used as the starting coordinates.
83 if (options.initialize == InitializeMethod::SPECTRAL || options.initialize == InitializeMethod::SPECTRAL_ONLY) {
84 bool attempt = internal::spectral_init(x, num_dim, embedding, options.num_threads);
85 if (!attempt && options.initialize == InitializeMethod::SPECTRAL) {
86 internal::random_init(x.size(), num_dim, embedding);
87 }
88 } else if (options.initialize == InitializeMethod::RANDOM) {
89 internal::random_init(x.size(), num_dim, embedding);
90 }
91
92 // Finding a good a/b pair.
 // A non-positive a or b means "not supplied"; both are then derived together
 // from spread and min_dist.
93 if (options.a <= 0 || options.b <= 0) {
94 auto found = internal::find_ab(options.spread, options.min_dist);
95 options.a = found.first;
96 options.b = found.second;
97 }
98
 // Negative num_epochs requests an automatic, size-dependent default.
99 options.num_epochs = internal::choose_num_epochs(options.num_epochs, x.size());
100
 // NOTE(review): the line that opens this return expression (source line 101)
 // is not visible in this listing. The arguments below are the per-epoch
 // schedule derived from 'x', the resolved options, num_dim and embedding.
102 internal::similarities_to_epochs<Index_, Float_>(x, options.num_epochs, options.negative_sample_rate),
103 options,
104 num_dim,
105 embedding
106 );
107}
108
/*
 * Initialize the UMAP algorithm from a prebuilt knncolle search index.
 *
 * Finds the options.num_neighbors nearest neighbors of every observation with
 * knncolle::find_nearest_neighbors(), using options.num_threads threads. The
 * resulting NeighborList is then forwarded, together with num_dim, embedding
 * and options, to the NeighborList overload of initialize().
 *
 * NOTE(review): the parameter list (source line 127) is not visible in this
 * listing. 'prebuilt', 'num_dim', 'embedding' and 'options' are presumably its
 * parameters, with 'prebuilt' being a knncolle::Prebuilt<Dim_, Index_, Float_>.
 *
 * @return Status object for the subsequent optimization iterations.
 */
126template<typename Dim_, typename Index_, typename Float_>
 // Compute the neighbor list from the index, then hand ownership of it (and of
 // the options) to the NeighborList overload.
128 auto output = knncolle::find_nearest_neighbors(prebuilt, options.num_neighbors, options.num_threads);
129 return initialize(std::move(output), num_dim, embedding, std::move(options));
130}
131
153template<typename Dim_, typename Index_, typename Float_>
166
167}
168
169#endif
Defines the NeighborList alias.
Status of the UMAP algorithm.
Status of the UMAP optimization iterations.
Definition Status.hpp:25
int num_epochs() const
Definition Status.hpp:106
NeighborList< Index_, Float_ > find_nearest_neighbors(const Prebuilt< Dim_, Index_, Float_ > &index, int k, int num_threads=1)
Methods for UMAP.
Definition initialize.hpp:21
knncolle::NeighborList< Index_, Float_ > NeighborList
Lists of neighbors for each observation.
Definition NeighborList.hpp:29
Status< Index_, Float_ > initialize(NeighborList< Index_, Float_ > x, int num_dim, Float_ *embedding, Options options)
Definition initialize.hpp:73
Options for initialize().
Definition Options.hpp:28
double mix_ratio
Definition Options.hpp:48
double bandwidth
Definition Options.hpp:40
double min_dist
Definition Options.hpp:60
InitializeMethod initialize
Definition Options.hpp:88
double b
Definition Options.hpp:76
double a
Definition Options.hpp:68
int num_epochs
Definition Options.hpp:100
double spread
Definition Options.hpp:53
int num_threads
Definition Options.hpp:134
double local_connectivity
Definition Options.hpp:34