mxnet
random.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_RANDOM_H_
27 #define MSHADOW_RANDOM_H_
28 
29 #include <cstdlib>
30 #include <algorithm>
31 #include <random>
32 #include "./base.h"
33 #include "./tensor.h"
34 #include "./tensor_container.h"
35 #include <random>
36 
37 
38 namespace mshadow {
44 template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
45 class Random {};
46 
48 template<typename DType>
49 class Random<cpu, DType> {
50  public:
55  explicit Random(int seed) {
56  this->Seed(seed);
57  buffer_.Resize(Shape1(kRandBufferSize));
58  }
59  ~Random(void) {
60  }
65  inline void Seed(int seed) {
66  rnd_engine_.seed(seed);
67  this->rseed_ = static_cast<unsigned>(seed);
68  }
73  inline unsigned GetSeed() const {
74  return rseed_;
75  }
80  inline void set_stream(Stream<cpu> *stream) {
81  }
82 
87  inline unsigned GetRandInt() {
88  return rnd_engine_();
89  }
90 
94  inline void GetRandInt(const Tensor<cpu, 1, unsigned>& dst) {
95  std::generate_n(dst.dptr_, dst.size(0), [&](){ return rnd_engine_(); });
96  }
97 
104  template<int dim, class Sampler>
105  inline void SampleDistribution(Tensor<cpu, dim, DType> *dst, Sampler sampler) {
106  if (dst->CheckContiguous()) {
107  std::generate_n(dst->dptr_, dst->shape_.Size(), sampler);
108  } else {
109  Tensor<cpu, 2, DType> mat = dst->FlatTo2D();
110  for (index_t i = 0; i < mat.size(0); ++i) {
111  std::generate_n(mat[i].dptr_, mat.size(1), sampler);
112  }
113  }
114  }
115 
123  template<int dim, typename PType>
125  PType a = 0.0f , PType b = 1.0f ) {
126  // Ensure that half_t is handled correctly.
127  typedef typename std::conditional<std::is_floating_point<DType>::value,
128  DType, double>::type FType;
129  typedef typename std::conditional<std::is_integral<DType>::value,
130  std::uniform_int_distribution<DType>,
131  std::uniform_real_distribution<FType>>::type GType;
132  GType dist_uniform(a, b);
133  SampleDistribution(dst, [&](){ return dist_uniform(rnd_engine_);});
134  }
135 
143  template<int dim, typename PType>
145  PType mu = 0.0f, PType sigma = 1.0f ) {
146  if (sigma <= 0) {
147  *dst = mu; return;
148  }
149  typedef typename std::conditional<std::is_floating_point<DType>::value,
150  DType, double>::type GType;
151  std::normal_distribution<GType> dist_normal(mu, sigma);
152  SampleDistribution(dst, [&](){ return dist_normal(rnd_engine_);});
153  }
154 
162  template<int dim, typename PType>
164  PType alpha, PType beta) {
165  typedef typename std::conditional<std::is_floating_point<DType>::value,
166  DType, double>::type GType;
167  std::gamma_distribution<GType> dist_gamma(alpha, beta);
168  SampleDistribution(dst, [&](){ return dist_gamma(rnd_engine_);});
169  }
170 
177  template<int dim, typename PType>
178  inline void SampleExponential(Tensor<cpu, dim, DType> *dst, PType lambda ) {
179  typedef typename std::conditional<std::is_floating_point<DType>::value,
180  DType, double>::type GType;
181  std::exponential_distribution<GType> dist_exp(lambda);
182  SampleDistribution(dst, [&](){ return dist_exp(rnd_engine_);});
183  }
184 
191  template<int dim, typename PType>
192  inline void SamplePoisson(Tensor<cpu, dim, DType> *dst, PType lambda) {
193  typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType;
194  std::poisson_distribution<GType> dist_poisson(lambda);
195  SampleDistribution(dst, [&](){ return static_cast<DType>(dist_poisson(rnd_engine_));});
196  }
197 
205  template<int dim, typename PType1, typename PType2>
206  inline void SampleNegativeBinomial(Tensor<cpu, dim, DType> *dst, PType1 k, PType2 p) {
207  typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType;
208  std::negative_binomial_distribution<GType> dist_negbinomial(k, p);
209  SampleDistribution(dst, [&](){ return static_cast<DType>(dist_negbinomial(rnd_engine_));});
210  }
211 
220  template<int dim, typename PType>
222  PType mu, PType alpha) {
223  if (alpha == PType(0)) {
224  SamplePoisson(dst, mu); // limit of Poisson
225  } else {
226  PType r(PType(1) / alpha);
227  PType beta = mu * alpha;
228  std::gamma_distribution<> dist_gamma(r, beta);
229  typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType;
230  SampleDistribution(dst,
231  [&](){ std::poisson_distribution<GType> dist_poisson(dist_gamma(rnd_engine_));
232  return static_cast<DType>(dist_poisson(rnd_engine_));});
233  }
234  }
235 
247  template<int dim>
248  inline expr::ReshapeExp<Tensor<cpu, 1, DType>, DType, dim, 1>
250  buffer_.Resize(Shape1(shape.Size()));
251  this->SampleGaussian(&buffer_, 0.0f, 1.0f);
252  return expr::reshape(buffer_, shape);
253  }
265  template<int dim>
266  inline expr::ReshapeExp<Tensor<cpu, 1, DType>, DType, dim, 1>
268  buffer_.Resize(Shape1(shape.Size()));
269  this->SampleUniform(&buffer_, 0.0f, 1.0f);
270  return expr::reshape(buffer_, shape);
271  }
272 
273  std::mt19937 &GetRndEngine() {
274  return rnd_engine_;
275  }
276 
277  private:
279  std::mt19937 rnd_engine_;
281  unsigned rseed_;
284 }; // class Random<cpu, DType>
285 
286 // only allow GPU PRNG when cuda is enabled
287 #if MSHADOW_USE_CUDA
288 
289 template<typename DType>
290 class Random<gpu, DType> {
291  public:
296  explicit Random(int seed) : gen_(NULL) {
297  this->Seed(seed);
298  buffer_.Resize(Shape1(kRandBufferSize));
299  }
301  DeleteGenerator();
302  }
307  inline void set_stream(Stream<gpu> *stream) {
308  curandStatus_t status;
309  status = curandSetStream(gen_, Stream<gpu>::GetStream(stream));
310 
311  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "set_stream CURAND failed";
312  }
317  inline void Seed(int seed) {
318  // Create a new rng, either initially or if the RNG type can't reset its offset.
319  if (gen_ == NULL || (curandSetGeneratorOffset(gen_, 0ULL) != CURAND_STATUS_SUCCESS))
320  CreateGenerator();
321  // Now set the seed.
322  curandStatus_t status;
323  status = curandSetPseudoRandomGeneratorSeed(gen_, static_cast<uint64_t>(seed));
324  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Set CURAND seed failed.";
325  }
329  inline void GetRandInt(const Tensor<gpu, 1, unsigned>& dst) {
330  curandStatus_t status;
331  status = curandGenerate(gen_, dst.dptr_, dst.size(0));
332  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen rand ints failed."
333  << " size = " << dst.size(0);
334  }
342  template<int dim>
343  inline void SampleUniform(Tensor<gpu, dim, DType> *dst,
344  DType a = 0.0f, DType b = 1.0f);
345 
353  template<int dim>
354  inline void SampleGaussian(Tensor<gpu, dim, DType> *dst,
355  DType mu = 0.0f, DType sigma = 1.0f);
369  template<int dim>
370  inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
371  gaussian(Shape<dim> shape, DType mu = 0.0f, DType sigma = 1.0f);
383  template<int dim>
384  inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
385  uniform(Shape<dim> shape);
386 
387  private:
388  inline void GenGaussian(float *dptr, size_t size, float mu, float sigma) {
389  curandStatus_t status;
390  status = curandGenerateNormal(gen_, dptr, size, mu, sigma);
391  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal float failed."
392  << " size = " << size
393  << ",mu = " << mu
394  << ",sigma = " << sigma;
395  }
396  inline void GenGaussian(double *dptr, size_t size, double mu, double sigma) {
397  curandStatus_t status;
398  status = curandGenerateNormalDouble(gen_, dptr, size, mu, sigma);
399  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal double failed."
400  << " size = " << size
401  << ",mu = " << mu
402  << ",sigma = " << sigma;
403  }
404  inline void GenUniform(float *dptr, size_t size) {
405  curandStatus_t status;
406  status = curandGenerateUniform(gen_, dptr, size);
407  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform float failed."
408  << " size = " << size;
409  }
410  inline void GenUniform(double *dptr, size_t size) {
411  curandStatus_t status;
412  status = curandGenerateUniformDouble(gen_, dptr, size);
413  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform double failed."
414  << " size = " << size;
415  }
416  inline void CreateGenerator() {
417  if (gen_ != NULL)
418  DeleteGenerator();
419  curandStatus_t status;
420  status = curandCreateGenerator(&gen_, CURAND_RNG_PSEUDO_DEFAULT);
421  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Cannot create CURAND Generator";
422  }
423  inline void DeleteGenerator() {
424  if (gen_ != NULL) {
425  curandStatus_t status;
426  status = curandDestroyGenerator(gen_);
427  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Destory CURAND Gen failed";
428  gen_ = NULL;
429  }
430  }
432  curandGenerator_t gen_;
434  TensorContainer<gpu, 1, DType> buffer_;
435 }; // class Random<gpu, DType>
436 #endif // MSHADOW_USE_CUDA
437 
438 #ifdef __CUDACC__
439 // implementations that depends on cuda kernels
440 template<typename DType>
441 template<int dim>
443  Tensor<gpu, dim, DType> *dst, DType a, DType b) {
444  if (a == 0.0f && b == 1.0f) {
445  if (dst->CheckContiguous()) {
446  this->GenUniform(dst->dptr_, dst->shape_.Size());
447  } else {
448  *dst = this->uniform(dst->shape_);
449  }
450  } else {
451  *dst = this->uniform(dst->shape_) * (b - a) + a;
452  }
453 }
454 template<typename DType>
455 template<int dim>
457  Tensor<gpu, dim, DType> *dst, DType mu, DType sigma) {
458  // We need to check whether the shape size is even since CuRand supports only normal distribution
459  // generation of even number of elements.
460  if (dst->CheckContiguous() && (dst->shape_.Size() % 2 == 0)) {
461  this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma);
462  } else {
463  *dst = this->gaussian(dst->shape_, mu, sigma);
464  }
465 }
466 
467 template<typename DType>
468 template<int dim>
469 inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
470 Random<gpu, DType>::gaussian(Shape<dim> shape, DType mu, DType sigma) {
471  size_t aligned_sz = ((shape.Size() + 1UL) >> 1) << 1;
472  // allocate alligned size
473  buffer_.Resize(Shape1(aligned_sz));
474  buffer_.Resize(Shape1(shape.Size()));
475  this->GenGaussian(buffer_.dptr_, aligned_sz, mu, sigma);
476  return expr::reshape(buffer_, shape);
477 }
478 
479 template<typename DType>
480 template<int dim>
481 inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
482 Random<gpu, DType>::uniform(Shape<dim> shape) {
483  buffer_.Resize(Shape1(shape.Size()));
484  this->GenUniform(buffer_.dptr_, buffer_.size(0));
485  return expr::reshape(buffer_, shape);
486 }
487 #endif // __CUDACC__
488 } // namespace mshadow
489 #endif // MSHADOW_RANDOM_H_
MSHADOW_THROW_EXCEPTION
#define MSHADOW_THROW_EXCEPTION
Definition: base.h:250
mshadow::Random< cpu, DType >::SampleGamma
void SampleGamma(Tensor< cpu, dim, DType > *dst, PType alpha, PType beta)
generate data from a gamma distribution
Definition: random.h:163
mshadow::Shape::Size
MSHADOW_XINLINE index_t Size(void) const
Definition: tensor.h:158
mxnet::SampleGaussian
void SampleGaussian(real_t mu, real_t sigma, NDArray *out)
Sample gaussian distribution for each elements of out.
mshadow::Stream
computation stream structure, used for asynchronous computations
Definition: tensor.h:488
mxnet::SamplePoisson
void SamplePoisson(real_t lambda, NDArray *out)
Sample Poisson distribution for each elements of out.
mshadow::Random< gpu, DType >::Random
Random(int seed)
constructor of random engine
Definition: random.h:296
mshadow::Random< cpu, DType >::SampleGeneralizedNegativeBinomial
void SampleGeneralizedNegativeBinomial(Tensor< cpu, dim, DType > *dst, PType mu, PType alpha)
generate data from a generalized negative binomial distribution
Definition: random.h:221
mshadow::Random
random number generator
Definition: random.h:45
mshadow::Random< cpu, DType >::GetRandInt
unsigned GetRandInt()
get some random integer
Definition: random.h:87
mshadow::Tensor
general tensor
Definition: tensor.h:525
mshadow::Random< cpu, DType >::SampleUniform
void SampleUniform(Tensor< cpu, dim, DType > *dst, PType a=0.0f, PType b=1.0f)
generate data from uniform [a,b)
Definition: random.h:124
mshadow::Random< cpu, DType >::~Random
~Random(void)
Definition: random.h:59
mshadow::Random< cpu, DType >::SampleNegativeBinomial
void SampleNegativeBinomial(Tensor< cpu, dim, DType > *dst, PType1 k, PType2 p)
generate data from a negative binomial distribution
Definition: random.h:206
mshadow::gpu
device name GPU
Definition: tensor.h:46
mshadow::Random< gpu, DType >::Seed
void Seed(int seed)
seed random number generator using this seed
Definition: random.h:317
mshadow::cpu
device name CPU
Definition: tensor.h:39
mshadow::Random< cpu, DType >::Seed
void Seed(int seed)
seed random number generator using this seed
Definition: random.h:65
mshadow::Random< gpu, DType >::GetRandInt
void GetRandInt(const Tensor< gpu, 1, unsigned > &dst)
get a set of random integers
Definition: random.h:329
tensor.h
header file of tensor data structure and functions This lib requires explicit memory allocation and d...
mshadow::Random< cpu, DType >::GetRandInt
void GetRandInt(const Tensor< cpu, 1, unsigned > &dst)
get a set of random integers
Definition: random.h:94
mshadow::expr::reshape
ReshapeExp< SrcExp, DType, dimdst, ExpInfo< SrcExp >::kDim > reshape(const Exp< SrcExp, DType, etype > &src, Shape< dimdst > oshape)
a expression that reshapes a tensor to another shape
Definition: reshape.h:66
mshadow::Stream< gpu >
Definition: stream_gpu-inl.h:37
mshadow::Random< gpu, DType >::~Random
~Random(void) MSHADOW_THROW_EXCEPTION
Definition: random.h:300
mshadow::kRandBufferSize
const unsigned kRandBufferSize
buffer size for each random number generator
Definition: base.h:321
mshadow::Tensor::CheckContiguous
MSHADOW_XINLINE bool CheckContiguous(void) const
Definition: tensor.h:596
mshadow::Random< cpu, DType >::SampleDistribution
void SampleDistribution(Tensor< cpu, dim, DType > *dst, Sampler sampler)
generate data from a distribution
Definition: random.h:105
mshadow::Tensor::shape_
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:541
mshadow::Random< cpu, DType >::set_stream
void set_stream(Stream< cpu > *stream)
set the stream of computation
Definition: random.h:80
mshadow::expr::ReshapeExp
reshape the content to another shape input: Tensor<Device,dimsrc>: ishape output: Tensor<Device,...
Definition: reshape.h:39
mshadow::index_t
int32_t index_t
type that will be used for index
Definition: base.h:328
mshadow::Random< cpu, DType >::SampleExponential
void SampleExponential(Tensor< cpu, dim, DType > *dst, PType lambda)
generate data from an exponential distribution
Definition: random.h:178
mshadow::Random< cpu, DType >::gaussian
expr::ReshapeExp< Tensor< cpu, 1, DType >, DType, dim, 1 > gaussian(Shape< dim > shape)
return a temporary expression storing standard Gaussian random variables; the temporary tensor is only v...
Definition: random.h:249
mshadow::Random< cpu, DType >::GetSeed
unsigned GetSeed() const
get random seed used in random generator
Definition: random.h:73
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::Tensor::FlatTo2D
MSHADOW_XINLINE Tensor< Device, 2, DType > FlatTo2D(void) const
flatten the tensor to 2 dimension, collapse the higher dimensions together
Definition: tensor.h:624
mshadow::TensorContainer
tensor container that does memory allocation and resize like STL, use it to save the lines of FreeSpa...
Definition: tensor_container.h:40
mxnet::SampleUniform
void SampleUniform(real_t begin, real_t end, NDArray *out)
Sample uniform distribution for each elements of out.
mshadow::Shape< dim >
mshadow::Tensor::dptr_
DType * dptr_
pointer to the data
Definition: tensor.h:539
tensor_container.h
tensor container that does memory allocation and resize like STL
mshadow::Shape1
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one dimension shape, stride will equal s0
Definition: tensor.h:220
mshadow::Random< cpu, DType >::SampleGaussian
void SampleGaussian(Tensor< cpu, dim, DType > *dst, PType mu=0.0f, PType sigma=1.0f)
generate data from a Gaussian distribution with mean mu and standard deviation sigma
Definition: random.h:144
mshadow::Random< cpu, DType >::SamplePoisson
void SamplePoisson(Tensor< cpu, dim, DType > *dst, PType lambda)
generate data from a poisson distribution
Definition: random.h:192
mshadow::Tensor::size
MSHADOW_XINLINE index_t size(int idx) const
return size of i-th dimension, start counting from highest dimension
Definition: tensor.h:610
mshadow::Random< cpu, DType >::GetRndEngine
std::mt19937 & GetRndEngine()
Definition: random.h:273
mshadow::Random< gpu, DType >::gaussian
expr::ReshapeExp< Tensor< gpu, 1, DType >, DType, dim, 1 > gaussian(Shape< dim > shape, DType mu=0.0f, DType sigma=1.0f)
return a temporal expression storing standard gaussian random variables the temporal tensor is only v...
mshadow::Random< cpu, DType >::uniform
expr::ReshapeExp< Tensor< cpu, 1, DType >, DType, dim, 1 > uniform(Shape< dim > shape)
return a temporal expression storing standard uniform [0,1) the temporal tensor is only valid before ...
Definition: random.h:267
mshadow::Random< gpu, DType >::uniform
expr::ReshapeExp< Tensor< gpu, 1, DType >, DType, dim, 1 > uniform(Shape< dim > shape)
return a temporal expression storing standard uniform [0,1) the temporal tensor is only valid before ...
base.h
definitions of base types, operators, macros functions
mshadow::Random< cpu, DType >::Random
Random(int seed)
constructor of random engine
Definition: random.h:55
mshadow::Random< gpu, DType >::set_stream
void set_stream(Stream< gpu > *stream)
set the stream of computation
Definition: random.h:307