25 #ifndef MXNET_RANDOM_GENERATOR_H_
26 #define MXNET_RANDOM_GENERATOR_H_
33 #include <curand_kernel.h>
35 #endif // MXNET_USE_CUDA
41 template<
typename Device,
typename DType MSHADOW_DEFAULT_DTYPE>
44 template<
typename DType>
56 typedef typename std::conditional<std::is_floating_point<DType>::value,
59 : engine_(gen->states_ + state_idx) {}
61 Impl(
const Impl &) =
delete;
62 Impl &operator=(
const Impl &) =
delete;
67 return static_cast<int64_t
>(engine_->operator()() << 31) + engine_->operator()();
71 typedef typename std::conditional<std::is_integral<DType>::value,
72 std::uniform_int_distribution<DType>,
73 std::uniform_real_distribution<FType>>::type GType;
75 return dist_uniform(*engine_);
79 std::normal_distribution<FType> dist_normal;
80 return dist_normal(*engine_);
84 std::mt19937 *engine_;
88 inst->states_ =
new std::mt19937[kNumRandomStates];
92 delete[] inst->states_;
96 for (
int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
101 return static_cast<void*
>(states_);
105 std::mt19937 *states_;
108 template<
typename DType>
111 template<
typename DType>
116 template<
typename DType>
131 Impl &operator=(
const Impl &) =
delete;
132 Impl(
const Impl &) =
delete;
137 global_state_idx_(state_idx),
138 state_(*(gen->states_ + state_idx)) {}
142 global_gen_->states_[global_state_idx_] = state_;
146 return curand(&state_);
150 return static_cast<int64_t
>(curand(&state_) << 31) + curand(&state_);
154 return static_cast<float>(1.0) - curand_uniform(&state_);
158 return curand_normal(&state_);
163 int global_state_idx_;
164 curandStatePhilox4_32_10_t state_;
177 curandStatePhilox4_32_10_t *states_;
190 Impl &operator=(
const Impl &) =
delete;
191 Impl(
const Impl &) =
delete;
196 global_state_idx_(state_idx),
197 state_(*(gen->states_ + state_idx)) {}
201 global_gen_->states_[global_state_idx_] = state_;
205 return curand(&state_);
209 return static_cast<int64_t
>(curand(&state_) << 31) + curand(&state_);
213 return static_cast<float>(1.0) - curand_uniform_double(&state_);
217 return curand_normal_double(&state_);
222 int global_state_idx_;
223 curandStatePhilox4_32_10_t state_;
227 curandStatePhilox4_32_10_t *states_;
230 #endif // MXNET_USE_CUDA
235 #endif // MXNET_RANDOM_GENERATOR_H_
MSHADOW_FORCE_INLINE __device__ float normal()
Definition: random_generator.h:157
static void AllocState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:87
Definition: random_generator.h:181
namespace of mxnet
Definition: api_registry.h:33
Definition: stream_gpu-inl.h:38
static const int kMinNumRandomPerThread
Definition: random_generator.h:120
__device__ Impl(RandGenerator< gpu, DType > *gen, int state_idx)
Definition: random_generator.h:135
MSHADOW_XINLINE int rand()
Definition: random_generator.h:64
MSHADOW_XINLINE FType normal()
Definition: random_generator.h:78
__device__ ~Impl()
Definition: random_generator.h:140
#define MSHADOW_FORCE_INLINE
Definition: base.h:218
static const int kNumRandomStates
Definition: random_generator.h:122
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:204
static void FreeState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:91
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:208
device name CPU
Definition: tensor.h:40
Impl(RandGenerator< cpu, DType > *gen, int state_idx)
Definition: random_generator.h:58
device name GPU
Definition: tensor.h:47
#define MSHADOW_XINLINE
Definition: base.h:223
MSHADOW_FORCE_INLINE __device__ float uniform()
Definition: random_generator.h:153
std::conditional< std::is_floating_point< DType >::value, DType, double >::type FType
Definition: random_generator.h:57
static const int kNumRandomStates
Definition: random_generator.h:50
Definition: random_generator.h:117
MSHADOW_XINLINE int64_t rand_int64()
Definition: random_generator.h:66
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:149
static const int kMinNumRandomPerThread
Definition: random_generator.h:48
MSHADOW_XINLINE FType uniform()
Definition: random_generator.h:70
Definition: random_generator.h:45
MSHADOW_XINLINE void * GetStates()
Definition: random_generator.h:100
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:145
__device__ ~Impl()
Definition: random_generator.h:199
Definition: random_generator.h:42
MSHADOW_FORCE_INLINE __device__ double normal()
Definition: random_generator.h:216
MSHADOW_FORCE_INLINE __device__ double uniform()
Definition: random_generator.h:212
__device__ Impl(RandGenerator< gpu, double > *gen, int state_idx)
Definition: random_generator.h:194
MSHADOW_XINLINE void Seed(mshadow::Stream< cpu > *, uint32_t seed)
Definition: random_generator.h:95
computation stream structure, used for asynchronous computations
Definition: tensor.h:384