mxnet
initializer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_INITIALIZER_H_
27 #define MXNET_CPP_INITIALIZER_H_
28 
29 #include <cmath>
30 #include <string>
31 #include <vector>
32 #include <random>
33 #include "mxnet-cpp/ndarray.h"
34 
35 namespace mxnet {
36 namespace cpp {
37 
38 class Initializer {
39  public:
40  static bool StringStartWith(const std::string& name, const std::string& check_str) {
41  return (name.size() >= check_str.size() && name.substr(0, check_str.size()) == check_str);
42  }
43  static bool StringEndWith(const std::string& name, const std::string& check_str) {
44  return (name.size() >= check_str.size() &&
45  name.substr(name.size() - check_str.size(), check_str.size()) == check_str);
46  }
47  virtual void operator()(const std::string& name, NDArray* arr) {
48  if (StringStartWith(name, "upsampling")) {
49  InitBilinear(arr);
50  } else if (StringEndWith(name, "bias")) {
51  InitBias(arr);
52  } else if (StringEndWith(name, "gamma")) {
53  InitGamma(arr);
54  } else if (StringEndWith(name, "beta")) {
55  InitBeta(arr);
56  } else if (StringEndWith(name, "weight")) {
57  InitWeight(arr);
58  } else if (StringEndWith(name, "moving_mean")) {
59  InitZero(arr);
60  } else if (StringEndWith(name, "moving_var")) {
61  InitOne(arr);
62  } else if (StringEndWith(name, "moving_inv_var")) {
63  InitZero(arr);
64  } else if (StringEndWith(name, "moving_avg")) {
65  InitZero(arr);
66  } else if (StringEndWith(name, "min")) {
67  InitZero(arr);
68  } else if (StringEndWith(name, "max")) {
69  InitOne(arr);
70  } else if (StringEndWith(name, "weight_quantize")) {
72  } else if (StringEndWith(name, "bias_quantize")) {
73  InitQuantizedBias(arr);
74  } else {
75  InitDefault(arr);
76  }
77  }
78 
79  protected:
80  virtual void InitBilinear(NDArray* arr) {
81  Shape shape(arr->GetShape());
82  std::vector<float> weight(shape.Size(), 0);
83  int f = std::ceil(shape[3] / 2.0);
84  float c = (2 * f - 1 - f % 2) / (2. * f);
85  for (size_t i = 0; i < shape.Size(); ++i) {
86  int x = i % shape[3];
87  int y = (i / shape[3]) % shape[2];
88  weight[i] = (1 - std::abs(x / f - c)) * (1 - std::abs(y / f - c));
89  }
90  (*arr).SyncCopyFromCPU(weight);
91  }
92  virtual void InitZero(NDArray* arr) {
93  (*arr) = 0.0f;
94  }
95  virtual void InitOne(NDArray* arr) {
96  (*arr) = 1.0f;
97  }
98  virtual void InitBias(NDArray* arr) {
99  (*arr) = 0.0f;
100  }
101  virtual void InitGamma(NDArray* arr) {
102  (*arr) = 1.0f;
103  }
104  virtual void InitBeta(NDArray* arr) {
105  (*arr) = 0.0f;
106  }
107  virtual void InitWeight(NDArray* arr) {}
108  virtual void InitQuantizedWeight(NDArray* arr) {
109  std::default_random_engine generator;
110  std::uniform_int_distribution<int32_t> _val(-127, 127);
111  (*arr) = _val(generator);
112  }
113  virtual void InitQuantizedBias(NDArray* arr) {
114  (*arr) = 0;
115  }
116  virtual void InitDefault(NDArray* arr) {}
117 };
118 
119 class Constant : public Initializer {
120  public:
121  explicit Constant(float value) : value(value) {}
122  void operator()(const std::string& name, NDArray* arr) override {
123  (*arr) = value;
124  }
125 
126  protected:
127  float value;
128 };
129 
130 class Zero : public Constant {
131  public:
132  Zero() : Constant(0.0f) {}
133 };
134 
135 class One : public Constant {
136  public:
137  One() : Constant(1.0f) {}
138 };
139 
140 class Uniform : public Initializer {
141  public:
142  explicit Uniform(float scale) : Uniform(-scale, scale) {}
143  Uniform(float begin, float end) : begin(begin), end(end) {}
144  void operator()(const std::string& name, NDArray* arr) override {
145  if (StringEndWith(name, "weight_quantize")) {
146  InitQuantizedWeight(arr);
147  return;
148  }
149  if (StringEndWith(name, "bias_quantize")) {
150  InitQuantizedBias(arr);
151  return;
152  }
154  }
155 
156  protected:
157  float begin, end;
158 };
159 
160 class Normal : public Initializer {
161  public:
162  Normal(float mu, float sigma) : mu(mu), sigma(sigma) {}
163  void operator()(const std::string& name, NDArray* arr) override {
164  if (StringEndWith(name, "weight_quantize")) {
165  InitQuantizedWeight(arr);
166  return;
167  }
168  if (StringEndWith(name, "bias_quantize")) {
169  InitQuantizedBias(arr);
170  return;
171  }
173  }
174 
175  protected:
176  float mu, sigma;
177 };
178 
179 class Bilinear : public Initializer {
180  public:
181  Bilinear() {}
182  void operator()(const std::string& name, NDArray* arr) override {
183  if (StringEndWith(name, "weight_quantize")) {
184  InitQuantizedWeight(arr);
185  return;
186  }
187  if (StringEndWith(name, "bias_quantize")) {
188  InitQuantizedBias(arr);
189  return;
190  }
191  InitBilinear(arr);
192  }
193 };
194 
195 class Xavier : public Initializer {
196  public:
199  float magnitude;
201  FactorType factor_type = avg, // NOLINT
202  float magnitude = 3) // NOLINT
204 
205  void operator()(const std::string& name, NDArray* arr) override {
206  if (StringEndWith(name, "weight_quantize")) {
207  InitQuantizedWeight(arr);
208  return;
209  }
210  if (StringEndWith(name, "bias_quantize")) {
211  InitQuantizedBias(arr);
212  return;
213  }
214 
215  Shape shape(arr->GetShape());
216  float hw_scale = 1.0f;
217  if (shape.ndim() > 2) {
218  for (size_t i = 2; i < shape.ndim(); ++i) {
219  hw_scale *= shape[i];
220  }
221  }
222  float fan_in = shape[1] * hw_scale, fan_out = shape[0] * hw_scale;
223  float factor = 1.0f;
224  switch (factor_type) {
225  case avg:
226  factor = (fan_in + fan_out) / 2.0;
227  break;
228  case in:
229  factor = fan_in;
230  break;
231  case out:
232  factor = fan_out;
233  }
234  float scale = std::sqrt(magnitude / factor);
235  switch (rand_type) {
236  case uniform:
237  NDArray::SampleUniform(-scale, scale, arr);
238  break;
239  case gaussian:
240  NDArray::SampleGaussian(0, scale, arr);
241  break;
242  }
243  }
244 };
245 
246 class MSRAPrelu : public Xavier {
247  public:
248  explicit MSRAPrelu(FactorType factor_type = avg, float slope = 0.25f)
249  : Xavier(gaussian, factor_type, 2. / (1 + slope * slope)) {}
250 };
251 
252 } // namespace cpp
253 } // namespace mxnet
254 
255 #endif // MXNET_CPP_INITIALIZER_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::cpp::Xavier::avg
@ avg
Definition: initializer.h:198
mxnet::cpp::Zero::Zero
Zero()
Definition: initializer.h:132
mxnet::cpp::Bilinear::Bilinear
Bilinear()
Definition: initializer.h:181
mxnet::cpp::Normal::mu
float mu
Definition: initializer.h:176
mxnet::cpp::Initializer::InitZero
virtual void InitZero(NDArray *arr)
Definition: initializer.h:92
mxnet::cpp::NDArray::SampleGaussian
static void SampleGaussian(mx_float mu, mx_float sigma, NDArray *out)
Sample from a Gaussian distribution for each element of out.
mxnet::cpp::Xavier::uniform
@ uniform
Definition: initializer.h:197
mxnet::cpp::One::One
One()
Definition: initializer.h:137
mxnet::cpp::Initializer
Definition: initializer.h:38
mxnet::cpp::Initializer::InitGamma
virtual void InitGamma(NDArray *arr)
Definition: initializer.h:101
mxnet::cpp::Initializer::StringEndWith
static bool StringEndWith(const std::string &name, const std::string &check_str)
Definition: initializer.h:43
mxnet::cpp::MSRAPrelu
Definition: initializer.h:246
mxnet::cpp::Bilinear
Definition: initializer.h:179
mxnet::cpp::Initializer::InitWeight
virtual void InitWeight(NDArray *arr)
Definition: initializer.h:107
mxnet::cpp::Initializer::InitQuantizedWeight
virtual void InitQuantizedWeight(NDArray *arr)
Definition: initializer.h:108
mxnet::cpp::Uniform
Definition: initializer.h:140
mxnet::cpp::Xavier::RandType
RandType
Definition: initializer.h:197
mxnet::cpp::Xavier::gaussian
@ gaussian
Definition: initializer.h:197
mxnet::cpp::Constant::operator()
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:122
mxnet::cpp::Initializer::InitBias
virtual void InitBias(NDArray *arr)
Definition: initializer.h:98
mxnet::cpp::Xavier::rand_type
enum mxnet::cpp::Xavier::RandType rand_type
mxnet::cpp::One
Definition: initializer.h:135
mxnet::cpp::NDArray
NDArray interface.
Definition: ndarray.h:122
mxnet::cpp::Constant::value
float value
Definition: initializer.h:127
mxnet::cpp::Normal::sigma
float sigma
Definition: initializer.h:176
mxnet::cpp::Initializer::InitQuantizedBias
virtual void InitQuantizedBias(NDArray *arr)
Definition: initializer.h:113
ndarray.h
definition of ndarray
mxnet::cpp::NDArray::SampleUniform
static void SampleUniform(mx_float begin, mx_float end, NDArray *out)
Sample from a uniform distribution for each element of out.
mxnet::cpp::Bilinear::operator()
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:182
mxnet::cpp::Xavier::magnitude
float magnitude
Definition: initializer.h:199
mxnet::cpp::Normal::Normal
Normal(float mu, float sigma)
Definition: initializer.h:162
mxnet::cpp::Xavier
Definition: initializer.h:195
mxnet::cpp::Initializer::InitOne
virtual void InitOne(NDArray *arr)
Definition: initializer.h:95
mxnet::cpp::Constant
Definition: initializer.h:119
mxnet::cpp::Normal::operator()
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:163
mxnet::cpp::MSRAPrelu::MSRAPrelu
MSRAPrelu(FactorType factor_type=avg, float slope=0.25f)
Definition: initializer.h:248
mxnet::cpp::Xavier::Xavier
Xavier(RandType rand_type=gaussian, FactorType factor_type=avg, float magnitude=3)
Definition: initializer.h:200
mxnet::cpp::Normal
Definition: initializer.h:160
mxnet::cpp::Uniform::operator()
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:144
mxnet::cpp::Shape::ndim
index_t ndim(void) const
Return the number of dimensions of the tensor.
Definition: shape.h:240
mxnet::cpp::Constant::Constant
Constant(float value)
Definition: initializer.h:121
mxnet::cpp::Uniform::Uniform
Uniform(float begin, float end)
Definition: initializer.h:143
mxnet::cpp::Xavier::operator()
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:205
mxnet::cpp::Xavier::out
@ out
Definition: initializer.h:198
mxnet::cpp::Uniform::end
float end
Definition: initializer.h:157
mxnet::cpp::NDArray::GetShape
std::vector< mx_uint > GetShape() const
mxnet::cpp::Xavier::FactorType
FactorType
Definition: initializer.h:198
mxnet::cpp::Uniform::Uniform
Uniform(float scale)
Definition: initializer.h:142
mxnet::cpp::Shape
Dynamic shape class that can hold a shape of arbitrary dimension.
Definition: shape.h:42
mxnet::cpp::Initializer::StringStartWith
static bool StringStartWith(const std::string &name, const std::string &check_str)
Definition: initializer.h:40
mxnet::cpp::Initializer::InitDefault
virtual void InitDefault(NDArray *arr)
Definition: initializer.h:116
mxnet::cpp::Initializer::InitBilinear
virtual void InitBilinear(NDArray *arr)
Definition: initializer.h:80
mxnet::cpp::Zero
Definition: initializer.h:130
mxnet::cpp::Shape::Size
size_t Size(void) const
total number of elements in the tensor
Definition: shape.h:260
mxnet::cpp::Xavier::in
@ in
Definition: initializer.h:198
mxnet::cpp::Initializer::operator()
virtual void operator()(const std::string &name, NDArray *arr)
Definition: initializer.h:47
mxnet::cpp::Initializer::InitBeta
virtual void InitBeta(NDArray *arr)
Definition: initializer.h:104
mxnet::cpp::Xavier::factor_type
enum mxnet::cpp::Xavier::FactorType factor_type
mxnet::cpp::Uniform::begin
float begin
Definition: initializer.h:157