mxnet
initializer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_INITIALIZER_H_
28 #define MXNET_CPP_INITIALIZER_H_
29 
30 #include <cmath>
31 #include <string>
32 #include <vector>
33 #include <random>
34 #include "mxnet-cpp/ndarray.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 
39 class Initializer {
40  public:
41  static bool StringStartWith(const std::string& name,
42  const std::string& check_str) {
43  return (name.size() >= check_str.size() &&
44  name.substr(0, check_str.size()) == check_str);
45  }
46  static bool StringEndWith(const std::string& name,
47  const std::string& check_str) {
48  return (name.size() >= check_str.size() &&
49  name.substr(name.size() - check_str.size(), check_str.size()) ==
50  check_str);
51  }
52  virtual void operator()(const std::string& name, NDArray* arr) {
53  if (StringStartWith(name, "upsampling")) {
54  InitBilinear(arr);
55  } else if (StringEndWith(name, "bias")) {
56  InitBias(arr);
57  } else if (StringEndWith(name, "gamma")) {
58  InitGamma(arr);
59  } else if (StringEndWith(name, "beta")) {
60  InitBeta(arr);
61  } else if (StringEndWith(name, "weight")) {
62  InitWeight(arr);
63  } else if (StringEndWith(name, "moving_mean")) {
64  InitZero(arr);
65  } else if (StringEndWith(name, "moving_var")) {
66  InitOne(arr);
67  } else if (StringEndWith(name, "moving_inv_var")) {
68  InitZero(arr);
69  } else if (StringEndWith(name, "moving_avg")) {
70  InitZero(arr);
71  } else if (StringEndWith(name, "min")) {
72  InitZero(arr);
73  } else if (StringEndWith(name, "max")) {
74  InitOne(arr);
75  } else if (StringEndWith(name, "weight_quantize")) {
77  } else if (StringEndWith(name, "bias_quantize")) {
78  InitQuantizedBias(arr);
79  } else {
80  InitDefault(arr);
81  }
82  }
83 
84  protected:
85  virtual void InitBilinear(NDArray* arr) {
86  Shape shape(arr->GetShape());
87  std::vector<float> weight(shape.Size(), 0);
88  int f = std::ceil(shape[3] / 2.0);
89  float c = (2 * f - 1 - f % 2) / (2. * f);
90  for (size_t i = 0; i < shape.Size(); ++i) {
91  int x = i % shape[3];
92  int y = (i / shape[3]) % shape[2];
93  weight[i] = (1 - std::abs(x / f - c)) * (1 - std::abs(y / f - c));
94  }
95  (*arr).SyncCopyFromCPU(weight);
96  }
97  virtual void InitZero(NDArray* arr) { (*arr) = 0.0f; }
98  virtual void InitOne(NDArray* arr) { (*arr) = 1.0f; }
99  virtual void InitBias(NDArray* arr) { (*arr) = 0.0f; }
100  virtual void InitGamma(NDArray* arr) { (*arr) = 1.0f; }
101  virtual void InitBeta(NDArray* arr) { (*arr) = 0.0f; }
102  virtual void InitWeight(NDArray* arr) {}
103  virtual void InitQuantizedWeight(NDArray* arr) {
104  std::default_random_engine generator;
105  std::uniform_int_distribution<int32_t> _val(-127, 127);
106  (*arr) = _val(generator);
107  }
108  virtual void InitQuantizedBias(NDArray* arr) {
109  (*arr) = 0;
110  }
111  virtual void InitDefault(NDArray* arr) {}
112 };
113 
114 class Constant : public Initializer {
115  public:
116  explicit Constant(float value)
117  : value(value) {}
118  void operator()(const std::string &name, NDArray *arr) override {
119  (*arr) = value;
120  }
121  protected:
122  float value;
123 };
124 
125 class Zero : public Constant {
126  public:
127  Zero(): Constant(0.0f) {}
128 };
129 
130 class One : public Constant {
131  public:
132  One(): Constant(1.0f) {}
133 };
134 
135 class Uniform : public Initializer {
136  public:
137  explicit Uniform(float scale)
138  : Uniform(-scale, scale) {}
139  Uniform(float begin, float end)
140  : begin(begin), end(end) {}
141  void operator()(const std::string &name, NDArray *arr) override {
142  if (StringEndWith(name, "weight_quantize")) {
143  InitQuantizedWeight(arr);
144  return;
145  }
146  if (StringEndWith(name, "bias_quantize")) {
147  InitQuantizedBias(arr);
148  return;
149  }
150  NDArray::SampleUniform(begin, end, arr);
151  }
152  protected:
153  float begin, end;
154 };
155 
156 class Normal : public Initializer {
157  public:
158  Normal(float mu, float sigma)
159  : mu(mu), sigma(sigma) {}
160  void operator()(const std::string &name, NDArray *arr) override {
161  if (StringEndWith(name, "weight_quantize")) {
162  InitQuantizedWeight(arr);
163  return;
164  }
165  if (StringEndWith(name, "bias_quantize")) {
166  InitQuantizedBias(arr);
167  return;
168  }
169  NDArray::SampleGaussian(mu, sigma, arr);
170  }
171  protected:
172  float mu, sigma;
173 };
174 
175 class Bilinear : public Initializer {
176  public:
177  Bilinear() {}
178  void operator()(const std::string &name, NDArray *arr) override {
179  if (StringEndWith(name, "weight_quantize")) {
180  InitQuantizedWeight(arr);
181  return;
182  }
183  if (StringEndWith(name, "bias_quantize")) {
184  InitQuantizedBias(arr);
185  return;
186  }
187  InitBilinear(arr);
188  }
189 };
190 
191 class Xavier : public Initializer {
192  public:
193  enum RandType {
195  uniform
196  } rand_type;
197  enum FactorType {
199  in,
200  out
201  } factor_type;
202  float magnitude;
203  Xavier(RandType rand_type = gaussian, FactorType factor_type = avg,
204  float magnitude = 3)
205  : rand_type(rand_type), factor_type(factor_type), magnitude(magnitude) {}
206 
207  void operator()(const std::string &name, NDArray* arr) override {
208  if (StringEndWith(name, "weight_quantize")) {
209  InitQuantizedWeight(arr);
210  return;
211  }
212  if (StringEndWith(name, "bias_quantize")) {
213  InitQuantizedBias(arr);
214  return;
215  }
216 
217  Shape shape(arr->GetShape());
218  float hw_scale = 1.0f;
219  if (shape.ndim() > 2) {
220  for (size_t i = 2; i < shape.ndim(); ++i) {
221  hw_scale *= shape[i];
222  }
223  }
224  float fan_in = shape[1] * hw_scale, fan_out = shape[0] * hw_scale;
225  float factor = 1.0f;
226  switch (factor_type) {
227  case avg:
228  factor = (fan_in + fan_out) / 2.0;
229  break;
230  case in:
231  factor = fan_in;
232  break;
233  case out:
234  factor = fan_out;
235  }
236  float scale = std::sqrt(magnitude / factor);
237  switch (rand_type) {
238  case uniform:
239  NDArray::SampleUniform(-scale, scale, arr);
240  break;
241  case gaussian:
242  NDArray::SampleGaussian(0, scale, arr);
243  break;
244  }
245  }
246 };
247 
248 class MSRAPrelu : public Xavier {
249  public:
250  explicit MSRAPrelu(FactorType factor_type = avg, float slope = 0.25f)
251  : Xavier(gaussian, factor_type, 2. / (1 + slope * slope)) {}
252 };
253 
254 } // namespace cpp
255 } // namespace mxnet
256 
257 #endif // MXNET_CPP_INITIALIZER_H_
static bool StringStartWith(const std::string &name, const std::string &check_str)
Definition: initializer.h:41
Uniform(float scale)
Definition: initializer.h:137
static bool StringEndWith(const std::string &name, const std::string &check_str)
Definition: initializer.h:46
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:178
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:160
Definition: initializer.h:175
Definition: initializer.h:191
namespace of mxnet
Definition: api_registry.h:33
virtual void InitQuantizedWeight(NDArray *arr)
Definition: initializer.h:103
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
Definition: initializer.h:114
MSRAPrelu(FactorType factor_type=avg, float slope=0.25f)
Definition: initializer.h:250
Definition: initializer.h:135
FactorType
Definition: initializer.h:197
Xavier(RandType rand_type=gaussian, FactorType factor_type=avg, float magnitude=3)
Definition: initializer.h:203
float value
Definition: initializer.h:122
Uniform(float begin, float end)
Definition: initializer.h:139
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:141
virtual void InitOne(NDArray *arr)
Definition: initializer.h:98
virtual void InitGamma(NDArray *arr)
Definition: initializer.h:100
Definition: initializer.h:130
virtual void InitBias(NDArray *arr)
Definition: initializer.h:99
One()
Definition: initializer.h:132
virtual void InitZero(NDArray *arr)
Definition: initializer.h:97
NDArray interface.
Definition: ndarray.h:121
Definition: initializer.h:194
Definition: initializer.h:125
Definition: initializer.h:198
float sigma
Definition: initializer.h:172
virtual void InitBeta(NDArray *arr)
Definition: initializer.h:101
Definition: initializer.h:199
virtual void InitWeight(NDArray *arr)
Definition: initializer.h:102
Definition: initializer.h:156
Definition: initializer.h:248
float end
Definition: initializer.h:153
RandType
Definition: initializer.h:193
Zero()
Definition: initializer.h:127
static void SampleUniform(mx_float begin, mx_float end, NDArray *out)
Sample from a uniform distribution for each element of out.
Bilinear()
Definition: initializer.h:177
virtual void InitBilinear(NDArray *arr)
Definition: initializer.h:85
Normal(float mu, float sigma)
Definition: initializer.h:158
Constant(float value)
Definition: initializer.h:116
virtual void InitDefault(NDArray *arr)
Definition: initializer.h:111
float magnitude
Definition: initializer.h:202
std::vector< mx_uint > GetShape() const
virtual void operator()(const std::string &name, NDArray *arr)
Definition: initializer.h:52
static void SampleGaussian(mx_float mu, mx_float sigma, NDArray *out)
Sample from a Gaussian distribution for each element of out.
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:207
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:118
Definition: initializer.h:39
virtual void InitQuantizedBias(NDArray *arr)
Definition: initializer.h:108