mxnet
optimizer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_OPTIMIZER_H_
28 #define MXNET_CPP_OPTIMIZER_H_
29 
#include <dmlc/strtonum.h>
#include <map>
#include <vector>
#include <string>
#include <memory>
#include <functional>
#include <sstream>
#include <utility>
#include "mxnet-cpp/base.h"
#include "dmlc/logging.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
#include "mxnet-cpp/lr_scheduler.h"
41 
42 namespace mxnet {
43 namespace cpp {
44 
48 class Optimizer {
49  public:
54  explicit Optimizer(unsigned begin_num_update);
59  virtual std::string GetType() const = 0;
63  virtual ~Optimizer();
70  template <typename T>
71  Optimizer *SetParam(const std::string &name, const T &value) {
72  std::string value_str;
73  std::stringstream ss;
74  ss << value;
75  ss >> value_str;
76 
77  params_[name] = value_str;
78  return this;
79  }
85  Optimizer *SetLRScheduler(std::unique_ptr<LRScheduler> lrScheduler) {
86  CHECK(lrScheduler);
87  lrScheduler_ = std::move(lrScheduler);
88  lrScheduler_->SetLR(dmlc::stof(params_["lr"]));
89  return this;
90  }
97  virtual void Update(int index, NDArray weight, NDArray grad) = 0;
98  // TODO(zhangcheng-qinyinghua)
99  // implement Update a list of arrays, maybe in the form of map
100  // void Update(int index, std::vector<NDArray> weights, std::vector<NDArray>
101  // grad, mx_float lr);
102 
107  std::string Serialize() const;
108 
109  protected:
110  std::map<std::string, std::string> params_;
111  static OpMap*& op_map();
112  const std::vector<const char*> GetParamKeys_() const;
113  const std::vector<const char*> GetParamValues_() const;
114  std::map<int, unsigned> count_;
116  unsigned UpdateCount_(int index);
117  float GetLR_(int index);
118  float GetWD_(int index);
119  virtual void CreateState_(int index, NDArray weight);
120  std::unique_ptr<LRScheduler> lrScheduler_ = nullptr;
121 };
122 
123 typedef std::function<Optimizer*()> OptimizerCreator;
124 
126  public:
127  static Optimizer* Find(const std::string& name);
128  static int __REGISTER__(const std::string& name, OptimizerCreator creator);
129  private:
130  static std::map<std::string, OptimizerCreator>& cmap();
131  OptimizerRegistry() = delete;
132  ~OptimizerRegistry() = delete;
133 };
/*
 * Registers OptimizerType under the stringized Name. Expands to a call to
 * OptimizerRegistry::__REGISTER__ whose creator is a lambda returning a
 * default-constructed OptimizerType allocated with new; the expression
 * evaluates to the int returned by __REGISTER__.
 */
#define MXNETCPP_REGISTER_OPTIMIZER(Name, OptimizerType) \
  OptimizerRegistry::__REGISTER__(#Name, []() { return new OptimizerType(); })
136 
137 class SGDOptimizer : public Optimizer {
138  public:
139  explicit SGDOptimizer(unsigned begin_num_update = 0);
140  std::string GetType() const override;
141  void Update(int index, NDArray weight, NDArray grad) override;
142  private:
143  virtual ~SGDOptimizer();
144  void CreateState_(int index, NDArray weight) override;
145  std::map<int, NDArray*> states_;
146  AtomicSymbolCreator update_handle_;
147  AtomicSymbolCreator mom_update_handle_;
148 };
149 
150 class SignumOptimizer : public Optimizer {
151  public:
152  explicit SignumOptimizer(unsigned begin_num_update = 0);
153  std::string GetType() const override;
154  void Update(int index, NDArray weight, NDArray grad) override;
155  private:
156  virtual ~SignumOptimizer();
157  void CreateState_(int index, NDArray weight) override;
158  std::map<int, NDArray*> states_;
159  AtomicSymbolCreator update_handle_;
160  AtomicSymbolCreator mom_update_handle_;
161 };
162 
163 
164 class RMSPropOptimizer : public Optimizer {
165  public:
166  explicit RMSPropOptimizer(unsigned begin_num_update = 0);
167  std::string GetType() const override;
168  void Update(int index, NDArray weight, NDArray grad) override;
169  private:
170  virtual ~RMSPropOptimizer();
171  void CreateState_(int index, NDArray weight) override;
172  std::map<int, NDArray*> n_, g_, delta_;
173  AtomicSymbolCreator update_handle_;
174  AtomicSymbolCreator alex_update_handle_;
175 };
176 
177 class AdamOptimizer : public Optimizer {
178  public:
179  explicit AdamOptimizer(unsigned begin_num_update = 0);
180  std::string GetType() const override;
181  void Update(int index, NDArray weight, NDArray grad) override;
182  private:
183  virtual ~AdamOptimizer();
184  void CreateState_(int index, NDArray weight) override;
185  std::map<int, NDArray*> mean_;
186  std::map<int, NDArray*> var_;
187  AtomicSymbolCreator update_handle_;
188 };
189 
190 class AdaGradOptimizer : public Optimizer {
191  public:
192  explicit AdaGradOptimizer(unsigned begin_num_update = 0);
193  std::string GetType() const override;
194  void Update(int index, NDArray weight, NDArray grad) override;
195  private:
196  virtual ~AdaGradOptimizer();
197  void CreateState_(int index, NDArray weight) override;
198  std::map<int, NDArray*> history_;
199 };
200 
201 class AdaDeltaOptimizer : public Optimizer {
202  public:
203  explicit AdaDeltaOptimizer(unsigned begin_num_update = 0);
204  std::string GetType() const override;
205  void Update(int index, NDArray weight, NDArray grad) override;
206  private:
207  virtual ~AdaDeltaOptimizer();
208  void CreateState_(int index, NDArray weight) override;
209  std::map<int, NDArray*> acc_g_, acc_delta_;
210 };
211 
212 } // namespace cpp
213 } // namespace mxnet
214 
215 #endif // MXNET_CPP_OPTIMIZER_H_
Definition: optimizer.h:137
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name...
Definition: op_map.h:43
Definition: optimizer.h:177
definition of OpMap
unsigned UpdateCount_(int index)
namespace of mxnet
Definition: api_registry.h:33
Optimizer(unsigned begin_num_update)
constructor
const std::vector< const char * > GetParamKeys_() const
void * AtomicSymbolCreator
handle to a function that takes param and creates symbol
Definition: c_api.h:71
virtual std::string GetType() const =0
get optimizer type
Scheduling learning rate.
unsigned begin_num_update_
Definition: optimizer.h:115
Definition: optimizer.h:150
unsigned num_update_
Definition: optimizer.h:115
Optimizer interface.
Definition: optimizer.h:48
std::map< int, unsigned > count_
Definition: optimizer.h:114
A faster implementation of strtof and strtod.
Definition: optimizer.h:201
Definition: optimizer.h:190
NDArray interface.
Definition: ndarray.h:121
float GetWD_(int index)
virtual void Update(int index, NDArray weight, NDArray grad)=0
Update a weight with gradient.
std::unique_ptr< LRScheduler > lrScheduler_
Definition: optimizer.h:120
float stof(const std::string &value, size_t *pos=nullptr)
A faster implementation of stof(). See documentation of std::stof() for more information. This function will test for overflow and invalid arguments. TODO: the current version does not support hex number TODO: the current version does not handle long decimals: you may only have up to 19 digits after the decimal point, and you cannot have too many digits before the decimal point either.
Definition: strtonum.h:467
Definition: optimizer.h:125
Definition: optimizer.h:164
std::map< std::string, std::string > params_
Definition: optimizer.h:110
virtual ~Optimizer()
destructor
const std::vector< const char * > GetParamValues_() const
Optimizer * SetParam(const std::string &name, const T &value)
set config parameters
Definition: optimizer.h:71
virtual void CreateState_(int index, NDArray weight)
std::string Serialize() const
Serialize the optimizer parameters to a string.
static OpMap *& op_map()
std::function< Optimizer *()> OptimizerCreator
Definition: optimizer.h:123
Optimizer * SetLRScheduler(std::unique_ptr< LRScheduler > lrScheduler)
Definition: optimizer.h:85
float GetLR_(int index)