mxnet
random_generator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_RANDOM_GENERATOR_H_
26 #define MXNET_RANDOM_GENERATOR_H_
27 
28 #include <random>
29 #include <new>
30 #include "./base.h"
31 
32 #if MXNET_USE_CUDA
33 #include <curand_kernel.h>
34 #include <math.h>
35 #endif // MXNET_USE_CUDA
36 
37 namespace mxnet {
38 namespace common {
39 namespace random {
40 
/*!
 * \brief Random number generator, specialized per device (and, for gpu/double,
 *        per element type) below.
 * \tparam Device device tag the generator runs on (cpu or gpu).
 * \tparam DType  element type the generator's samples are produced for.
 */
template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
class RandGenerator;

/*!
 * \brief CPU specialization of the random number generator.
 *
 * Owns kNumRandomStates std::mt19937 engines in states_ (created by AllocState,
 * released by FreeState, seeded by Seed). An Impl draws from one of those engines.
 */
template<typename DType>
class RandGenerator<cpu, DType> {
 public:
  // at least how many random numbers should be generated by one CPU thread.
  static const int kMinNumRandomPerThread;
  // store how many global random states for CPU.
  static const int kNumRandomStates;

  // implementation class for random number generator
  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
  class Impl {
   public:
    // Floating-point type returned by uniform()/normal(): DType when DType is a
    // floating-point type, double otherwise.
    typedef typename std::conditional<std::is_floating_point<DType>::value,
                                      DType, double>::type FType;

    // Binds this Impl to engine `state_idx` of `gen`. The engine is used in place
    // through a pointer (not copied), so draws advance the generator's state.
    explicit Impl(RandGenerator<cpu, DType> *gen, int state_idx)
        : engine_(gen->states_ + state_idx) {}

    Impl(const Impl &) = delete;
    Impl &operator=(const Impl &) = delete;

    // Raw 32-bit mt19937 draw converted to int.
    // NOTE(review): draws >= 2^31 do not fit in int and typically come out negative.
    MSHADOW_XINLINE int rand() { return engine_->operator()(); }

    // Uniform integer in [0, 2^63), assembled from two consecutive 32-bit draws:
    // the first supplies bits 31..62, the top 31 bits of the second supply bits 0..30.
    MSHADOW_XINLINE int64_t rand_int64() {
      // std::mt19937::result_type (uint_fast32_t) may be only 32 bits wide, so the
      // draw must be widened *before* shifting or its high bits are lost. Separate
      // statements also fix the draw order, which is unspecified for the operands
      // of a single `a + b` expression.
      const uint64_t hi = static_cast<uint64_t>(engine_->operator()());
      const uint64_t lo = static_cast<uint64_t>(engine_->operator()());
      return static_cast<int64_t>((hi << 31) | (lo >> 1));
    }

    // Uniform sample returned as FType. For integral DType this is
    // std::uniform_int_distribution<DType> with its default range
    // [0, numeric_limits<DType>::max()]; otherwise
    // std::uniform_real_distribution<FType> over its default range [0, 1).
    MSHADOW_XINLINE FType uniform() {
      typedef typename std::conditional<std::is_integral<DType>::value,
                                        std::uniform_int_distribution<DType>,
                                        std::uniform_real_distribution<FType>>::type GType;
      GType dist_uniform;
      return dist_uniform(*engine_);
    }

    // Standard normal sample (std::normal_distribution defaults: mean 0, stddev 1).
    MSHADOW_XINLINE FType normal() {
      std::normal_distribution<FType> dist_normal;
      return dist_normal(*engine_);
    }

   private:
    std::mt19937 *engine_;
  };  // class RandGenerator<cpu, DType>::Impl

  // Allocates the kNumRandomStates engines; release them with FreeState.
  static void AllocState(RandGenerator<cpu, DType> *inst) {
    inst->states_ = new std::mt19937[kNumRandomStates];
  }

  // Releases the engines allocated by AllocState.
  static void FreeState(RandGenerator<cpu, DType> *inst) {
    delete[] inst->states_;
  }

  // Seeds engine i with (seed + i), giving every engine a distinct seed.
  MSHADOW_XINLINE void Seed(mshadow::Stream<cpu> *, uint32_t seed) {
    for (int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
  }

  // export global random states, used by c++ custom operator
  MSHADOW_XINLINE void* GetStates() {
    return static_cast<void*>(states_);
  }

 private:
  std::mt19937 *states_;
};  // class RandGenerator<cpu, DType>
107 
// Out-of-class definitions of the RandGenerator<cpu, DType> static constants
// (presumably kMinNumRandomPerThread and kNumRandomStates, which the class
// declares as `static const int` without initializers).
// NOTE(review): the declarator lines of both definitions (listing lines 109 and
// 112) are missing from this listing, leaving two dangling
// `template<typename DType>` headers; restore them from the upstream header.
108 template<typename DType>
110 
111 template<typename DType>
113 
114 #if MXNET_USE_CUDA
115 
/*!
 * \brief GPU specialization of the random number generator.
 *
 * Holds an array of curandStatePhilox4_32_10_t states (states_). Allocation,
 * release, seeding and export are declared here and defined out of line.
 */
template<typename DType>
class RandGenerator<gpu, DType> {
 public:
  // at least how many random numbers should be generated by one GPU thread.
  static const int kMinNumRandomPerThread;
  // store how many global random states for GPU.
  static const int kNumRandomStates;

  // Per-thread handle on one of the generator's curand states.
  //
  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
  // by using 1.0-curand_uniform().
  // Needed as some samplers in sampler.h won't be able to deal with
  // one of the boundary cases.
  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
  class Impl {
   public:
    Impl &operator=(const Impl &) = delete;
    Impl(const Impl &) = delete;

    // Copies state `state_idx` of `gen` into this object for efficiency; the
    // destructor writes it back, so draws made here advance the stored state.
    __device__ explicit Impl(RandGenerator<gpu, DType> *gen, int state_idx)
        : global_gen_(gen),
          global_state_idx_(state_idx),
          state_(*(gen->states_ + state_idx)) {}

    __device__ ~Impl() {
      // store the curand state back into global memory
      global_gen_->states_[global_state_idx_] = state_;
    }

    // Raw 32-bit curand() draw converted to int.
    // NOTE(review): draws >= 2^31 do not fit in int and typically come out negative.
    MSHADOW_FORCE_INLINE __device__ int rand() {
      return curand(&state_);
    }

    // Uniform integer in [0, 2^63), assembled from two consecutive 32-bit draws:
    // the first supplies bits 31..62, the top 31 bits of the second supply bits 0..30.
    MSHADOW_FORCE_INLINE __device__ int64_t rand_int64() {
      // curand() returns a 32-bit unsigned int; it must be widened *before* the
      // shift, otherwise all but its lowest bit are shifted out. Separate
      // statements also make the order of the two draws well defined.
      const uint64_t hi = curand(&state_);
      const uint64_t lo = curand(&state_);
      return static_cast<int64_t>((hi << 31) | (lo >> 1));
    }

    // Uniform float in [0, 1).
    MSHADOW_FORCE_INLINE __device__ float uniform() {
      // curand_uniform() lies in (0, 1], so 1 - x lies in [0, 1) mathematically.
      // In float arithmetic, however, 1.0f - x rounds to exactly 1.0f for draws
      // below 2^-25; clamp those to the largest float below 1 (1 - 2^-24).
      const float u = 1.0f - curand_uniform(&state_);
      return u < 1.0f ? u : 0.99999994f;
    }

    // Standard normal float (curand_normal: mean 0, stddev 1).
    MSHADOW_FORCE_INLINE __device__ float normal() {
      return curand_normal(&state_);
    }

   private:
    RandGenerator<gpu, DType> *global_gen_;
    int global_state_idx_;
    curandStatePhilox4_32_10_t state_;
  };  // class RandGenerator<gpu, DType>::Impl

  static void AllocState(RandGenerator<gpu, DType> *inst);

  static void FreeState(RandGenerator<gpu, DType> *inst);

  void Seed(mshadow::Stream<gpu> *s, uint32_t seed);

  // export global random states, used by c++ custom operator
  void* GetStates();

 private:
  curandStatePhilox4_32_10_t *states_;
};  // class RandGenerator<gpu, DType>
179 
/*!
 * \brief GPU generator specialization for double-precision samples.
 *
 * Only provides Impl; unlike the generic GPU specialization it declares no
 * AllocState/FreeState/Seed/GetStates and no kMinNumRandomPerThread/kNumRandomStates.
 * NOTE(review): its data layout (a single states_ pointer) matches the generic GPU
 * specialization -- presumably so a generator set up through another DType can be
 * reinterpreted as this one; confirm against the callers.
 */
template<>
class RandGenerator<gpu, double> {
 public:
  // Per-thread handle on one of the generator's curand states.
  //
  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
  // by using 1.0-curand_uniform().
  // Needed as some samplers in sampler.h won't be able to deal with
  // one of the boundary cases.
  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
  class Impl {
   public:
    Impl &operator=(const Impl &) = delete;
    Impl(const Impl &) = delete;

    // Copies state `state_idx` of `gen` into this object for efficiency; the
    // destructor writes it back, so draws made here advance the stored state.
    __device__ explicit Impl(RandGenerator<gpu, double> *gen, int state_idx)
        : global_gen_(gen),
          global_state_idx_(state_idx),
          state_(*(gen->states_ + state_idx)) {}

    __device__ ~Impl() {
      // store the curand state back into global memory
      global_gen_->states_[global_state_idx_] = state_;
    }

    // Raw 32-bit curand() draw converted to int.
    // NOTE(review): draws >= 2^31 do not fit in int and typically come out negative.
    MSHADOW_FORCE_INLINE __device__ int rand() {
      return curand(&state_);
    }

    // Uniform integer in [0, 2^63), assembled from two consecutive 32-bit draws:
    // the first supplies bits 31..62, the top 31 bits of the second supply bits 0..30.
    MSHADOW_FORCE_INLINE __device__ int64_t rand_int64() {
      // curand() returns a 32-bit unsigned int; it must be widened *before* the
      // shift, otherwise all but its lowest bit are shifted out. Separate
      // statements also make the order of the two draws well defined.
      const uint64_t hi = curand(&state_);
      const uint64_t lo = curand(&state_);
      return static_cast<int64_t>((hi << 31) | (lo >> 1));
    }

    // Uniform double in [0, 1).
    MSHADOW_FORCE_INLINE __device__ double uniform() {
      // curand_uniform_double() lies in (0, 1], so 1 - x lies in [0, 1)
      // mathematically; guard against 1.0 - x rounding up to exactly 1.0 for the
      // smallest draws by clamping to the largest double below 1 (1 - 2^-53).
      const double u = 1.0 - curand_uniform_double(&state_);
      return u < 1.0 ? u : 0.99999999999999989;
    }

    // Standard normal double (curand_normal_double: mean 0, stddev 1).
    MSHADOW_FORCE_INLINE __device__ double normal() {
      return curand_normal_double(&state_);
    }

   private:
    RandGenerator<gpu, double> *global_gen_;
    int global_state_idx_;
    curandStatePhilox4_32_10_t state_;
  };  // class RandGenerator<gpu, double>::Impl

 private:
  curandStatePhilox4_32_10_t *states_;
};  // class RandGenerator<gpu, double>
229 
230 #endif // MXNET_USE_CUDA
231 
232 } // namespace random
233 } // namespace common
234 } // namespace mxnet
235 #endif // MXNET_RANDOM_GENERATOR_H_
MSHADOW_FORCE_INLINE __device__ float normal()
Definition: random_generator.h:157
static void AllocState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:87
Definition: random_generator.h:181
namespace of mxnet
Definition: api_registry.h:33
Definition: stream_gpu-inl.h:38
static const int kMinNumRandomPerThread
Definition: random_generator.h:120
__device__ Impl(RandGenerator< gpu, DType > *gen, int state_idx)
Definition: random_generator.h:135
MSHADOW_XINLINE int rand()
Definition: random_generator.h:64
MSHADOW_XINLINE FType normal()
Definition: random_generator.h:78
__device__ ~Impl()
Definition: random_generator.h:140
#define MSHADOW_FORCE_INLINE
Definition: base.h:218
static const int kNumRandomStates
Definition: random_generator.h:122
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:204
static void FreeState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:91
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:208
device name CPU
Definition: tensor.h:40
Impl(RandGenerator< cpu, DType > *gen, int state_idx)
Definition: random_generator.h:58
device name GPU
Definition: tensor.h:47
#define MSHADOW_XINLINE
Definition: base.h:223
MSHADOW_FORCE_INLINE __device__ float uniform()
Definition: random_generator.h:153
std::conditional< std::is_floating_point< DType >::value, DType, double >::type FType
Definition: random_generator.h:57
static const int kNumRandomStates
Definition: random_generator.h:50
Definition: random_generator.h:117
MSHADOW_XINLINE int64_t rand_int64()
Definition: random_generator.h:66
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:149
static const int kMinNumRandomPerThread
Definition: random_generator.h:48
MSHADOW_XINLINE FType uniform()
Definition: random_generator.h:70
MSHADOW_XINLINE void * GetStates()
Definition: random_generator.h:100
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:145
__device__ ~Impl()
Definition: random_generator.h:199
Definition: random_generator.h:42
MSHADOW_FORCE_INLINE __device__ double normal()
Definition: random_generator.h:216
MSHADOW_FORCE_INLINE __device__ double uniform()
Definition: random_generator.h:212
__device__ Impl(RandGenerator< gpu, double > *gen, int state_idx)
Definition: random_generator.h:194
MSHADOW_XINLINE void Seed(mshadow::Stream< cpu > *, uint32_t seed)
Definition: random_generator.h:95
computation stream structure, used for asynchronous computations
Definition: tensor.h:384