mxnet
initializer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_INITIALIZER_H_
27 #define MXNET_CPP_INITIALIZER_H_
28 
29 #include <cmath>
30 #include <string>
31 #include <vector>
32 #include "mxnet-cpp/ndarray.h"
33 
34 namespace mxnet {
35 namespace cpp {
36 
37 class Initializer {
38  public:
39  static bool StringStartWith(const std::string& name,
40  const std::string& check_str) {
41  return (name.size() >= check_str.size() &&
42  name.substr(0, check_str.size()) == check_str);
43  }
44  static bool StringEndWith(const std::string& name,
45  const std::string& check_str) {
46  return (name.size() >= check_str.size() &&
47  name.substr(name.size() - check_str.size(), check_str.size()) ==
48  check_str);
49  }
50  virtual void operator()(const std::string& name, NDArray* arr) {
51  if (StringStartWith(name, "upsampling")) {
52  InitBilinear(arr);
53  } else if (StringEndWith(name, "bias")) {
54  InitBias(arr);
55  } else if (StringEndWith(name, "gamma")) {
56  InitGamma(arr);
57  } else if (StringEndWith(name, "beta")) {
58  InitBeta(arr);
59  } else if (StringEndWith(name, "weight")) {
60  InitWeight(arr);
61  } else if (StringEndWith(name, "moving_mean")) {
62  InitZero(arr);
63  } else if (StringEndWith(name, "moving_var")) {
64  InitOne(arr);
65  } else if (StringEndWith(name, "moving_inv_var")) {
66  InitZero(arr);
67  } else if (StringEndWith(name, "moving_avg")) {
68  InitZero(arr);
69  } else {
70  InitDefault(arr);
71  }
72  }
73 
74  protected:
75  virtual void InitBilinear(NDArray* arr) {
76  Shape shape(arr->GetShape());
77  std::vector<float> weight(shape.Size(), 0);
78  int f = std::ceil(shape[3] / 2.0);
79  float c = (2 * f - 1 - f % 2) / (2. * f);
80  for (size_t i = 0; i < shape.Size(); ++i) {
81  int x = i % shape[3];
82  int y = (i / shape[3]) % shape[2];
83  weight[i] = (1 - std::abs(x / f - c)) * (1 - std::abs(y / f - c));
84  }
85  (*arr).SyncCopyFromCPU(weight);
86  }
87  virtual void InitZero(NDArray* arr) { (*arr) = 0.0f; }
88  virtual void InitOne(NDArray* arr) { (*arr) = 1.0f; }
89  virtual void InitBias(NDArray* arr) { (*arr) = 0.0f; }
90  virtual void InitGamma(NDArray* arr) { (*arr) = 1.0f; }
91  virtual void InitBeta(NDArray* arr) { (*arr) = 0.0f; }
92  virtual void InitWeight(NDArray* arr) {}
93  virtual void InitDefault(NDArray* arr) {}
94 };
95 
96 class Constant : public Initializer {
97  public:
98  explicit Constant(float value)
99  : value(value) {}
100  void operator()(const std::string &name, NDArray *arr) override {
101  (*arr) = value;
102  }
103  protected:
104  float value;
105 };
106 
107 class Zero : public Constant {
108  public:
109  Zero(): Constant(0.0f) {}
110 };
111 
112 class One : public Constant {
113  public:
114  One(): Constant(1.0f) {}
115 };
116 
117 class Uniform : public Initializer {
118  public:
119  explicit Uniform(float scale)
120  : Uniform(-scale, scale) {}
121  Uniform(float begin, float end)
122  : begin(begin), end(end) {}
123  void operator()(const std::string &name, NDArray *arr) override {
124  NDArray::SampleUniform(begin, end, arr);
125  }
126  protected:
127  float begin, end;
128 };
129 
130 class Normal : public Initializer {
131  public:
132  Normal(float mu, float sigma)
133  : mu(mu), sigma(sigma) {}
134  void operator()(const std::string &name, NDArray *arr) override {
135  NDArray::SampleGaussian(mu, sigma, arr);
136  }
137  protected:
138  float mu, sigma;
139 };
140 
141 class Bilinear : public Initializer {
142  public:
143  Bilinear() {}
144  void operator()(const std::string &name, NDArray *arr) override {
145  InitBilinear(arr);
146  }
147 };
148 
149 class Xavier : public Initializer {
150  public:
151  enum RandType {
153  uniform
154  } rand_type;
155  enum FactorType {
157  in,
158  out
159  } factor_type;
160  float magnitude;
161  Xavier(RandType rand_type = gaussian, FactorType factor_type = avg,
162  float magnitude = 3)
163  : rand_type(rand_type), factor_type(factor_type), magnitude(magnitude) {}
164 
165  void operator()(const std::string &name, NDArray* arr) override {
166  Shape shape(arr->GetShape());
167  float hw_scale = 1.0f;
168  if (shape.ndim() > 2) {
169  for (size_t i = 2; i < shape.ndim(); ++i) {
170  hw_scale *= shape[i];
171  }
172  }
173  float fan_in = shape[1] * hw_scale, fan_out = shape[0] * hw_scale;
174  float factor = 1.0f;
175  switch (factor_type) {
176  case avg:
177  factor = (fan_in + fan_out) / 2.0;
178  break;
179  case in:
180  factor = fan_in;
181  break;
182  case out:
183  factor = fan_out;
184  }
185  float scale = std::sqrt(magnitude / factor);
186  switch (rand_type) {
187  case uniform:
188  NDArray::SampleUniform(-scale, scale, arr);
189  break;
190  case gaussian:
191  NDArray::SampleGaussian(0, scale, arr);
192  break;
193  }
194  }
195 };
196 
197 } // namespace cpp
198 } // namespace mxnet
199 
200 #endif // MXNET_CPP_INITIALIZER_H_
static bool StringStartWith(const std::string &name, const std::string &check_str)
Definition: initializer.h:39
Uniform(float scale)
Definition: initializer.h:119
static bool StringEndWith(const std::string &name, const std::string &check_str)
Definition: initializer.h:44
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:144
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:134
Definition: initializer.h:141
Definition: initializer.h:149
namespace of mxnet
Definition: base.h:126
dynamic shape class that can hold shape of arbitrary dimension
Definition: shape.h:42
Definition: initializer.h:96
Definition: initializer.h:117
FactorType
Definition: initializer.h:155
Xavier(RandType rand_type=gaussian, FactorType factor_type=avg, float magnitude=3)
Definition: initializer.h:161
float value
Definition: initializer.h:104
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:3007
Uniform(float begin, float end)
Definition: initializer.h:121
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:123
virtual void InitOne(NDArray *arr)
Definition: initializer.h:88
virtual void InitGamma(NDArray *arr)
Definition: initializer.h:90
Definition: initializer.h:112
virtual void InitBias(NDArray *arr)
Definition: initializer.h:89
One()
Definition: initializer.h:114
virtual void InitZero(NDArray *arr)
Definition: initializer.h:87
NDArray interface.
Definition: ndarray.h:120
Definition: initializer.h:152
Definition: initializer.h:107
Definition: initializer.h:156
float sigma
Definition: initializer.h:138
virtual void InitBeta(NDArray *arr)
Definition: initializer.h:91
Definition: initializer.h:157
virtual void InitWeight(NDArray *arr)
Definition: initializer.h:92
Definition: initializer.h:130
float end
Definition: initializer.h:127
RandType
Definition: initializer.h:151
Zero()
Definition: initializer.h:109
static void SampleUniform(mx_float begin, mx_float end, NDArray *out)
Sample a uniform distribution for each element of out.
Bilinear()
Definition: initializer.h:143
virtual void InitBilinear(NDArray *arr)
Definition: initializer.h:75
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:2801
Normal(float mu, float sigma)
Definition: initializer.h:132
Symbol ceil(const std::string &symbol_name, Symbol data)
Definition: op.h:2891
Constant(float value)
Definition: initializer.h:98
virtual void InitDefault(NDArray *arr)
Definition: initializer.h:93
float magnitude
Definition: initializer.h:160
std::vector< mx_uint > GetShape() const
virtual void operator()(const std::string &name, NDArray *arr)
Definition: initializer.h:50
static void SampleGaussian(mx_float mu, mx_float sigma, NDArray *out)
Sample a Gaussian distribution for each element of out.
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:165
void operator()(const std::string &name, NDArray *arr) override
Definition: initializer.h:100
Definition: initializer.h:37