mxnet
random_generator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_RANDOM_GENERATOR_H_
25 #define MXNET_RANDOM_GENERATOR_H_
26 
27 #include <random>
28 #include <new>
29 #include "./base.h"
30 
31 #if MXNET_USE_CUDA
32 #include <curand_kernel.h>
33 #include <math.h>
34 #endif // MXNET_USE_CUDA
35 
36 namespace mxnet {
37 namespace common {
38 namespace random {
39 
40 template <typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
42 
43 template <typename DType>
44 class RandGenerator<cpu, DType> {
45  public:
46  // at least how many random numbers should be generated by one CPU thread.
47  static const int kMinNumRandomPerThread;
48  // store how many global random states for CPU.
49  static const int kNumRandomStates;
50 
51  // implementation class for random number generator
52  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
53  class Impl {
54  public:
55  typedef
56  typename std::conditional<std::is_floating_point<DType>::value, DType, double>::type FType;
57  explicit Impl(RandGenerator<cpu, DType>* gen, int state_idx)
58  : engine_(gen->states_ + state_idx) {}
59 
60  Impl(const Impl&) = delete;
61  Impl& operator=(const Impl&) = delete;
62 
64  return engine_->operator()();
65  }
66 
68  return static_cast<int64_t>(engine_->operator()() << 31) + engine_->operator()();
69  }
70 
72  typedef typename std::conditional<std::is_integral<DType>::value,
73  std::uniform_int_distribution<DType>,
74  std::uniform_real_distribution<FType>>::type GType;
75  GType dist_uniform;
76  return dist_uniform(*engine_);
77  }
78 
80  std::normal_distribution<FType> dist_normal;
81  return dist_normal(*engine_);
82  }
83 
84  private:
85  std::mt19937* engine_;
86  }; // class RandGenerator<cpu, DType>::Impl
87 
89  inst->states_ = new std::mt19937[kNumRandomStates];
90  }
91 
92  static void FreeState(RandGenerator<cpu, DType>* inst) {
93  delete[] inst->states_;
94  }
95 
96  MSHADOW_XINLINE void Seed(mshadow::Stream<cpu>*, uint32_t seed) {
97  for (int i = 0; i < kNumRandomStates; ++i)
98  (states_ + i)->seed(seed + i);
99  }
100 
101  // export global random states, used by c++ custom operator
103  return static_cast<void*>(states_);
104  }
105 
106  private:
107  std::mt19937* states_;
108 }; // class RandGenerator<cpu, DType>
109 
110 template <typename DType>
112 
113 template <typename DType>
115 
116 #if MXNET_USE_CUDA
117 
template <typename DType>
class RandGenerator<gpu, DType> {
 public:
  // at least how many random numbers should be generated by one GPU thread.
  static const int kMinNumRandomPerThread;
  // store how many global random states for GPU.
  static const int kNumRandomStates;

  // Device-side handle on one of the generator's curand Philox states.
  //
  // The constructor copies states_[state_idx] into the member state_, every
  // draw advances that copy, and the destructor writes it back to
  // states_[state_idx]. The next Impl built on the same index therefore
  // continues the sequence instead of repeating it.
  //
  // uniform() is made consistent with the STL (includes 0, excludes 1) by
  // returning 1 - curand_uniform(). Some samplers in sampler.h cannot deal
  // with one of the boundary cases.
  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
  class Impl {
   public:
    Impl& operator=(const Impl&) = delete;
    Impl(const Impl&) = delete;

    // Copy the selected global state into this object so draws operate on
    // the copy rather than on gen->states_ directly.
    __device__ explicit Impl(RandGenerator<gpu, DType>* gen, int state_idx)
        : global_gen_(gen), global_state_idx_(state_idx), state_(*(gen->states_ + state_idx)) {}

    __device__ ~Impl() {
      // store the curand state back into global memory
      global_gen_->states_[global_state_idx_] = state_;
    }

    // One raw curand draw, converted to int (the result may be negative).
    MSHADOW_FORCE_INLINE __device__ int rand() {
      return curand(&state_);
    }

    // Non-negative random integer in [0, 2^63) assembled from two draws.
    // The first draw supplies the upper 31 bits and the second the lower 32.
    MSHADOW_FORCE_INLINE __device__ int64_t rand_int64() {
      // Widen BEFORE shifting: curand() yields a 32-bit unsigned value, so
      // shifting it in its own type would drop all but its lowest bit.
      // Masking to 31 bits keeps (high << 32) | low inside int64_t without
      // overflow. Separate statements also pin the order of the two draws.
      const int64_t high = static_cast<int64_t>(curand(&state_) & 0x7FFFFFFFu);
      const int64_t low = static_cast<int64_t>(curand(&state_));
      return (high << 32) | low;
    }

    // Uniform float; 1 - curand_uniform() flips curand's range so that 0 can
    // be returned and 1 cannot.
    MSHADOW_FORCE_INLINE __device__ float uniform() {
      return 1.0f - curand_uniform(&state_);
    }

    // Standard-normal float sample from curand_normal().
    MSHADOW_FORCE_INLINE __device__ float normal() {
      return curand_normal(&state_);
    }

   private:
    RandGenerator<gpu, DType>* global_gen_;  // owner of the global state array
    int global_state_idx_;                   // index copied from / written back to
    curandStatePhilox4_32_10_t state_;       // working copy advanced by each draw
  };  // class RandGenerator<gpu, DType>::Impl

  // Defined out of line (not in this header).
  static void AllocState(RandGenerator<gpu, DType>* inst);

  static void FreeState(RandGenerator<gpu, DType>* inst);

  void Seed(mshadow::Stream<gpu>* s, uint32_t seed);

  // export global random states, used by c++ custom operator
  void* GetStates();

 private:
  curandStatePhilox4_32_10_t* states_;
};  // class RandGenerator<gpu, DType>
179 
template <>
class RandGenerator<gpu, double> {
 public:
  // Device-side handle on one of the generator's curand Philox states,
  // producing double-precision samples.
  //
  // The constructor copies states_[state_idx] into the member state_, every
  // draw advances that copy, and the destructor writes it back to
  // states_[state_idx].
  //
  // uniform() is made consistent with the STL (includes 0, excludes 1) by
  // returning 1 - curand_uniform_double(). Some samplers in sampler.h cannot
  // deal with one of the boundary cases.
  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
  class Impl {
   public:
    Impl& operator=(const Impl&) = delete;
    Impl(const Impl&) = delete;

    // Copy the selected global state into this object so draws operate on
    // the copy rather than on gen->states_ directly.
    __device__ explicit Impl(RandGenerator<gpu, double>* gen, int state_idx)
        : global_gen_(gen), global_state_idx_(state_idx), state_(*(gen->states_ + state_idx)) {}

    __device__ ~Impl() {
      // store the curand state back into global memory
      global_gen_->states_[global_state_idx_] = state_;
    }

    // One raw curand draw, converted to int (the result may be negative).
    MSHADOW_FORCE_INLINE __device__ int rand() {
      return curand(&state_);
    }

    // Non-negative random integer in [0, 2^63) assembled from two draws.
    // The first draw supplies the upper 31 bits and the second the lower 32.
    MSHADOW_FORCE_INLINE __device__ int64_t rand_int64() {
      // Widen BEFORE shifting: curand() yields a 32-bit unsigned value, so
      // shifting it in its own type would drop all but its lowest bit.
      // Masking to 31 bits keeps (high << 32) | low inside int64_t without
      // overflow. Separate statements also pin the order of the two draws.
      const int64_t high = static_cast<int64_t>(curand(&state_) & 0x7FFFFFFFu);
      const int64_t low = static_cast<int64_t>(curand(&state_));
      return (high << 32) | low;
    }

    // Uniform double; 1 - curand_uniform_double() flips curand's range so that
    // 0 can be returned and 1 cannot. Computed entirely in double.
    MSHADOW_FORCE_INLINE __device__ double uniform() {
      return 1.0 - curand_uniform_double(&state_);
    }

    // Standard-normal double sample from curand_normal_double().
    MSHADOW_FORCE_INLINE __device__ double normal() {
      return curand_normal_double(&state_);
    }

   private:
    RandGenerator<gpu, double>* global_gen_;  // owner of the global state array
    int global_state_idx_;                    // index copied from / written back to
    curandStatePhilox4_32_10_t state_;        // working copy advanced by each draw
  };  // class RandGenerator<gpu, double>::Impl

 private:
  curandStatePhilox4_32_10_t* states_;
};  // class RandGenerator<gpu, double>
227 
228 #endif // MXNET_USE_CUDA
229 
230 } // namespace random
231 } // namespace common
232 } // namespace mxnet
233 #endif // MXNET_RANDOM_GENERATOR_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::common::random::RandGenerator< gpu, double >::Impl::uniform
MSHADOW_FORCE_INLINE __device__ double uniform()
Definition: random_generator.h:210
mshadow::Stream
computation stream structure, used for asynchronous computations
Definition: tensor.h:488
mxnet::common::random::RandGenerator< gpu, double >
Definition: random_generator.h:181
mxnet::common::random::RandGenerator< cpu, DType >::Impl::rand
MSHADOW_XINLINE int rand()
Definition: random_generator.h:63
mxnet::common::random::RandGenerator
Definition: random_generator.h:41
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mxnet::common::random::RandGenerator< cpu, DType >
Definition: random_generator.h:44
mxnet::common::random::RandGenerator< gpu, DType >::Impl::uniform
MSHADOW_FORCE_INLINE __device__ float uniform()
Definition: random_generator.h:153
mxnet::common::random::RandGenerator< cpu, DType >::GetStates
MSHADOW_XINLINE void * GetStates()
Definition: random_generator.h:102
mxnet::common::random::RandGenerator< cpu, DType >::Impl::Impl
Impl(RandGenerator< cpu, DType > *gen, int state_idx)
Definition: random_generator.h:57
mxnet::common::random::RandGenerator< gpu, double >::Impl::normal
MSHADOW_FORCE_INLINE __device__ double normal()
Definition: random_generator.h:214
mxnet::common::random::RandGenerator< cpu, DType >::Impl::uniform
MSHADOW_XINLINE FType uniform()
Definition: random_generator.h:71
mshadow::gpu
device name GPU
Definition: tensor.h:46
mxnet::common::random::RandGenerator< cpu, DType >::Impl::rand_int64
MSHADOW_XINLINE int64_t rand_int64()
Definition: random_generator.h:67
mshadow::cpu
device name CPU
Definition: tensor.h:39
mxnet::common::random::RandGenerator< gpu, DType >
Definition: random_generator.h:119
mxnet::common::random::RandGenerator< gpu, double >::Impl::Impl
__device__ Impl(RandGenerator< gpu, double > *gen, int state_idx)
Definition: random_generator.h:194
mxnet::common::random::RandGenerator< gpu, double >::Impl::~Impl
__device__ ~Impl()
Definition: random_generator.h:197
mxnet::common::random::RandGenerator< gpu, DType >::Impl::rand
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:145
mshadow::Stream< gpu >
Definition: stream_gpu-inl.h:37
MSHADOW_FORCE_INLINE
#define MSHADOW_FORCE_INLINE
Definition: base.h:223
mxnet::common::random::RandGenerator< gpu, DType >::kMinNumRandomPerThread
static const int kMinNumRandomPerThread
Definition: random_generator.h:122
mxnet::common::random::RandGenerator< gpu, DType >::kNumRandomStates
static const int kNumRandomStates
Definition: random_generator.h:124
mxnet::common::random::RandGenerator< cpu, DType >::kMinNumRandomPerThread
static const int kMinNumRandomPerThread
Definition: random_generator.h:47
mxnet::common::random::RandGenerator< gpu, DType >::Impl::Impl
__device__ Impl(RandGenerator< gpu, DType > *gen, int state_idx)
Definition: random_generator.h:137
mxnet::common::random::RandGenerator< cpu, DType >::Impl::normal
MSHADOW_XINLINE FType normal()
Definition: random_generator.h:79
mxnet::common::random::RandGenerator< cpu, DType >::Impl::FType
std::conditional< std::is_floating_point< DType >::value, DType, double >::type FType
Definition: random_generator.h:56
mxnet::common::random::RandGenerator< gpu, double >::Impl::rand
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:202
mxnet::common::random::RandGenerator< cpu, DType >::kNumRandomStates
static const int kNumRandomStates
Definition: random_generator.h:49
mxnet::common::random::RandGenerator< cpu, DType >::AllocState
static void AllocState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:88
mxnet::common::random::RandGenerator< gpu, double >::Impl::rand_int64
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:206
mxnet::common::random::RandGenerator< gpu, DType >::Impl::normal
MSHADOW_FORCE_INLINE __device__ float normal()
Definition: random_generator.h:157
mxnet::common::random::RandGenerator< cpu, DType >::FreeState
static void FreeState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:92
mxnet::common::random::RandGenerator< gpu, DType >::Impl::~Impl
__device__ ~Impl()
Definition: random_generator.h:140
mxnet::common::random::RandGenerator< gpu, DType >::Impl::rand_int64
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:149
base.h
configuration of MXNet as well as basic data structure.
mxnet::common::random::RandGenerator< cpu, DType >::Seed
MSHADOW_XINLINE void Seed(mshadow::Stream< cpu > *, uint32_t seed)
Definition: random_generator.h:96