Go to the documentation of this file.
24 #ifndef MXNET_RANDOM_GENERATOR_H_
25 #define MXNET_RANDOM_GENERATOR_H_
32 #include <curand_kernel.h>
34 #endif // MXNET_USE_CUDA
40 template <
typename Device,
typename DType MSHADOW_DEFAULT_DTYPE>
43 template <
typename DType>
56 typename std::conditional<std::is_floating_point<DType>::value, DType,
double>::type
FType;
58 : engine_(gen->states_ + state_idx) {}
60 Impl(
const Impl&) =
delete;
61 Impl& operator=(
const Impl&) =
delete;
64 return engine_->operator()();
68 return static_cast<int64_t
>(engine_->operator()() << 31) + engine_->operator()();
72 typedef typename std::conditional<std::is_integral<DType>::value,
73 std::uniform_int_distribution<DType>,
74 std::uniform_real_distribution<FType>>::type GType;
76 return dist_uniform(*engine_);
80 std::normal_distribution<FType> dist_normal;
81 return dist_normal(*engine_);
85 std::mt19937* engine_;
89 inst->states_ =
new std::mt19937[kNumRandomStates];
93 delete[] inst->states_;
97 for (
int i = 0; i < kNumRandomStates; ++i)
98 (states_ + i)->seed(seed + i);
103 return static_cast<void*
>(states_);
107 std::mt19937* states_;
110 template <
typename DType>
113 template <
typename DType>
118 template <
typename DType>
133 Impl& operator=(
const Impl&) =
delete;
134 Impl(
const Impl&) =
delete;
138 : global_gen_(gen), global_state_idx_(state_idx), state_(*(gen->states_ + state_idx)) {}
142 global_gen_->states_[global_state_idx_] = state_;
146 return curand(&state_);
150 return static_cast<int64_t
>(curand(&state_) << 31) + curand(&state_);
154 return static_cast<float>(1.0) - curand_uniform(&state_);
158 return curand_normal(&state_);
163 int global_state_idx_;
164 curandStatePhilox4_32_10_t state_;
177 curandStatePhilox4_32_10_t* states_;
190 Impl& operator=(
const Impl&) =
delete;
191 Impl(
const Impl&) =
delete;
195 : global_gen_(gen), global_state_idx_(state_idx), state_(*(gen->states_ + state_idx)) {}
199 global_gen_->states_[global_state_idx_] = state_;
203 return curand(&state_);
207 return static_cast<int64_t
>(curand(&state_) << 31) + curand(&state_);
211 return static_cast<float>(1.0) - curand_uniform_double(&state_);
215 return curand_normal_double(&state_);
220 int global_state_idx_;
221 curandStatePhilox4_32_10_t state_;
225 curandStatePhilox4_32_10_t* states_;
228 #endif // MXNET_USE_CUDA
233 #endif // MXNET_RANDOM_GENERATOR_H_
namespace of mxnet
Definition: api_registry.h:33
MSHADOW_FORCE_INLINE __device__ double uniform()
Definition: random_generator.h:210
computation stream structure, used for asynchronous computations
Definition: tensor.h:488
Definition: random_generator.h:181
MSHADOW_XINLINE int rand()
Definition: random_generator.h:63
Definition: random_generator.h:41
#define MSHADOW_XINLINE
Definition: base.h:228
Definition: random_generator.h:44
MSHADOW_FORCE_INLINE __device__ float uniform()
Definition: random_generator.h:153
MSHADOW_XINLINE void * GetStates()
Definition: random_generator.h:102
Impl(RandGenerator< cpu, DType > *gen, int state_idx)
Definition: random_generator.h:57
MSHADOW_FORCE_INLINE __device__ double normal()
Definition: random_generator.h:214
MSHADOW_XINLINE FType uniform()
Definition: random_generator.h:71
device name GPU
Definition: tensor.h:46
MSHADOW_XINLINE int64_t rand_int64()
Definition: random_generator.h:67
device name CPU
Definition: tensor.h:39
Definition: random_generator.h:119
__device__ Impl(RandGenerator< gpu, double > *gen, int state_idx)
Definition: random_generator.h:194
__device__ ~Impl()
Definition: random_generator.h:197
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:145
Definition: stream_gpu-inl.h:37
#define MSHADOW_FORCE_INLINE
Definition: base.h:223
static const int kMinNumRandomPerThread
Definition: random_generator.h:122
static const int kNumRandomStates
Definition: random_generator.h:124
static const int kMinNumRandomPerThread
Definition: random_generator.h:47
__device__ Impl(RandGenerator< gpu, DType > *gen, int state_idx)
Definition: random_generator.h:137
MSHADOW_XINLINE FType normal()
Definition: random_generator.h:79
std::conditional< std::is_floating_point< DType >::value, DType, double >::type FType
Definition: random_generator.h:56
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:202
static const int kNumRandomStates
Definition: random_generator.h:49
static void AllocState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:88
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:206
MSHADOW_FORCE_INLINE __device__ float normal()
Definition: random_generator.h:157
static void FreeState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:92
__device__ ~Impl()
Definition: random_generator.h:140
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:149
Configuration of MXNet as well as basic data structures.
MSHADOW_XINLINE void Seed(mshadow::Stream< cpu > *, uint32_t seed)
Definition: random_generator.h:96