mxnet
optimizer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_OPTIMIZER_H_
27 #define MXNET_CPP_OPTIMIZER_H_
28 
#include <dmlc/strtonum.h>
#include <functional>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "mxnet-cpp/base.h"
#include "dmlc/logging.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
#include "mxnet-cpp/lr_scheduler.h"
40 
41 namespace mxnet {
42 namespace cpp {
43 
47 class Optimizer {
48  public:
53  explicit Optimizer(unsigned begin_num_update);
58  virtual std::string GetType() const = 0;
62  virtual ~Optimizer();
69  template <typename T>
70  Optimizer* SetParam(const std::string& name, const T& value) {
71  std::string value_str;
72  std::stringstream ss;
73  ss << value;
74  ss >> value_str;
75 
76  params_[name] = value_str;
77  return this;
78  }
84  Optimizer* SetLRScheduler(std::unique_ptr<LRScheduler> lrScheduler) {
85  CHECK(lrScheduler);
86  lrScheduler_ = std::move(lrScheduler);
87  lrScheduler_->SetLR(dmlc::stof(params_["lr"]));
88  return this;
89  }
96  virtual void Update(int index, NDArray weight, NDArray grad) = 0;
97  // TODO(zhangcheng-qinyinghua)
98  // implement Update a list of arrays, maybe in the form of map
99  // void Update(int index, std::vector<NDArray> weights, std::vector<NDArray>
100  // grad, mx_float lr);
101 
106  std::string Serialize() const;
107 
108  protected:
109  std::map<std::string, std::string> params_;
110  static OpMap*& op_map();
111  const std::vector<const char*> GetParamKeys_() const;
112  const std::vector<const char*> GetParamValues_() const;
113  std::map<int, unsigned> count_;
115  unsigned UpdateCount_(int index);
116  float GetLR_(int index);
117  float GetWD_(int index);
118  virtual void CreateState_(int index, NDArray weight);
119  std::unique_ptr<LRScheduler> lrScheduler_ = nullptr;
120 };
121 
122 typedef std::function<Optimizer*()> OptimizerCreator;
123 
125  public:
126  static Optimizer* Find(const std::string& name);
127  static int __REGISTER__(const std::string& name, OptimizerCreator creator);
128 
129  private:
130  static std::map<std::string, OptimizerCreator>& cmap();
131  OptimizerRegistry() = delete;
132  ~OptimizerRegistry() = delete;
133 };
/*!
 * \brief Register OptimizerType under the stringified token Name.
 *
 * Expands to a call of OptimizerRegistry::__REGISTER__ whose creator
 * default-constructs OptimizerType with new. The registry is named with its
 * full qualification so the macro also expands correctly outside namespace
 * mxnet::cpp, and the expansion is parenthesised so it behaves as a single
 * expression wherever it is used.
 */
#define MXNETCPP_REGISTER_OPTIMIZER(Name, OptimizerType) \
  (::mxnet::cpp::OptimizerRegistry::__REGISTER__(        \
      #Name, []() { return new OptimizerType(); }))
136 
137 class SGDOptimizer : public Optimizer {
138  public:
139  explicit SGDOptimizer(unsigned begin_num_update = 0);
140  std::string GetType() const override;
141  void Update(int index, NDArray weight, NDArray grad) override;
142 
143  private:
144  virtual ~SGDOptimizer();
145  void CreateState_(int index, NDArray weight) override;
146  std::map<int, NDArray*> states_;
147  AtomicSymbolCreator update_handle_;
148  AtomicSymbolCreator mom_update_handle_;
149 };
150 
151 class SignumOptimizer : public Optimizer {
152  public:
153  explicit SignumOptimizer(unsigned begin_num_update = 0);
154  std::string GetType() const override;
155  void Update(int index, NDArray weight, NDArray grad) override;
156 
157  private:
158  virtual ~SignumOptimizer();
159  void CreateState_(int index, NDArray weight) override;
160  std::map<int, NDArray*> states_;
161  AtomicSymbolCreator update_handle_;
162  AtomicSymbolCreator mom_update_handle_;
163 };
164 
165 class RMSPropOptimizer : public Optimizer {
166  public:
167  explicit RMSPropOptimizer(unsigned begin_num_update = 0);
168  std::string GetType() const override;
169  void Update(int index, NDArray weight, NDArray grad) override;
170 
171  private:
172  virtual ~RMSPropOptimizer();
173  void CreateState_(int index, NDArray weight) override;
174  std::map<int, NDArray*> n_, g_, delta_;
175  AtomicSymbolCreator update_handle_;
176  AtomicSymbolCreator alex_update_handle_;
177 };
178 
179 class AdamOptimizer : public Optimizer {
180  public:
181  explicit AdamOptimizer(unsigned begin_num_update = 0);
182  std::string GetType() const override;
183  void Update(int index, NDArray weight, NDArray grad) override;
184 
185  private:
186  virtual ~AdamOptimizer();
187  void CreateState_(int index, NDArray weight) override;
188  std::map<int, NDArray*> mean_;
189  std::map<int, NDArray*> var_;
190  AtomicSymbolCreator update_handle_;
191 };
192 
193 class AdaGradOptimizer : public Optimizer {
194  public:
195  explicit AdaGradOptimizer(unsigned begin_num_update = 0);
196  std::string GetType() const override;
197  void Update(int index, NDArray weight, NDArray grad) override;
198 
199  private:
200  virtual ~AdaGradOptimizer();
201  void CreateState_(int index, NDArray weight) override;
202  std::map<int, NDArray*> history_;
203 };
204 
205 class AdaDeltaOptimizer : public Optimizer {
206  public:
207  explicit AdaDeltaOptimizer(unsigned begin_num_update = 0);
208  std::string GetType() const override;
209  void Update(int index, NDArray weight, NDArray grad) override;
210 
211  private:
212  virtual ~AdaDeltaOptimizer();
213  void CreateState_(int index, NDArray weight) override;
214  std::map<int, NDArray*> acc_g_, acc_delta_;
215 };
216 
217 } // namespace cpp
218 } // namespace mxnet
219 
220 #endif // MXNET_CPP_OPTIMIZER_H_
mxnet::cpp::Optimizer::~Optimizer
virtual ~Optimizer()
destructor
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::cpp::SignumOptimizer::Update
void Update(int index, NDArray weight, NDArray grad) override
Update a weight with gradient.
mxnet::cpp::SGDOptimizer::Update
void Update(int index, NDArray weight, NDArray grad) override
Update a weight with gradient.
mxnet::cpp::AdaGradOptimizer
Definition: optimizer.h:193
mxnet::cpp::Optimizer::op_map
static OpMap *& op_map()
mxnet::cpp::AdamOptimizer
Definition: optimizer.h:179
mxnet::cpp::SignumOptimizer
Definition: optimizer.h:151
mxnet::cpp::AdaDeltaOptimizer::GetType
std::string GetType() const override
get optimizer type
mxnet::cpp::AdaGradOptimizer::Update
void Update(int index, NDArray weight, NDArray grad) override
Update a weight with gradient.
mxnet::cpp::Optimizer::GetType
virtual std::string GetType() const =0
get optimizer type
mxnet::cpp::RMSPropOptimizer::RMSPropOptimizer
RMSPropOptimizer(unsigned begin_num_update=0)
mxnet::cpp::Optimizer::params_
std::map< std::string, std::string > params_
Definition: optimizer.h:109
mxnet::cpp::Optimizer::UpdateCount_
unsigned UpdateCount_(int index)
mxnet::cpp::AdaDeltaOptimizer::AdaDeltaOptimizer
AdaDeltaOptimizer(unsigned begin_num_update=0)
mxnet::cpp::Optimizer::begin_num_update_
unsigned begin_num_update_
Definition: optimizer.h:114
mxnet::cpp::NDArray
NDArray interface.
Definition: ndarray.h:122
mxnet::cpp::SGDOptimizer::GetType
std::string GetType() const override
get optimizer type
ndarray.h
definition of ndarray
mxnet::cpp::Optimizer::Update
virtual void Update(int index, NDArray weight, NDArray grad)=0
Update a weight with gradient.
strtonum.h
A faster implementation of strtof and strtod.
mxnet::cpp::SGDOptimizer
Definition: optimizer.h:137
mxnet::cpp::AdamOptimizer::Update
void Update(int index, NDArray weight, NDArray grad) override
Update a weight with gradient.
mxnet::cpp::AdaGradOptimizer::AdaGradOptimizer
AdaGradOptimizer(unsigned begin_num_update=0)
lr_scheduler.h
Scheduling learning rate.
mxnet::cpp::Optimizer::GetParamKeys_
const std::vector< const char * > GetParamKeys_() const
mxnet::cpp::OpMap
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name....
Definition: op_map.h:42
mxnet::cpp::AdaDeltaOptimizer
Definition: optimizer.h:205
mxnet::cpp::AdaGradOptimizer::GetType
std::string GetType() const override
get optimizer type
mxnet::cpp::Optimizer
Optimizer interface.
Definition: optimizer.h:47
dmlc::stof
float stof(const std::string &value, size_t *pos=nullptr)
A faster implementation of stof(). See documentation of std::stof() for more information....
Definition: strtonum.h:467
mxnet::cpp::SignumOptimizer::GetType
std::string GetType() const override
get optimizer type
mxnet::cpp::AdamOptimizer::GetType
std::string GetType() const override
get optimizer type
mxnet::cpp::OptimizerCreator
std::function< Optimizer *()> OptimizerCreator
Definition: optimizer.h:122
mxnet::cpp::RMSPropOptimizer::GetType
std::string GetType() const override
get optimizer type
mxnet::cpp::Optimizer::lrScheduler_
std::unique_ptr< LRScheduler > lrScheduler_
Definition: optimizer.h:119
mxnet::cpp::OptimizerRegistry::Find
static Optimizer * Find(const std::string &name)
mxnet::cpp::Optimizer::CreateState_
virtual void CreateState_(int index, NDArray weight)
mxnet::cpp::AdaDeltaOptimizer::Update
void Update(int index, NDArray weight, NDArray grad) override
Update a weight with gradient.
mxnet::cpp::RMSPropOptimizer::Update
void Update(int index, NDArray weight, NDArray grad) override
Update a weight with gradient.
AtomicSymbolCreator
void * AtomicSymbolCreator
handle to a function that takes param and creates symbol
Definition: c_api.h:78
mxnet::cpp::Optimizer::SetLRScheduler
Optimizer * SetLRScheduler(std::unique_ptr< LRScheduler > lrScheduler)
Definition: optimizer.h:84
mxnet::cpp::Optimizer::GetLR_
float GetLR_(int index)
mxnet::cpp::Optimizer::GetWD_
float GetWD_(int index)
mxnet::cpp::Optimizer::GetParamValues_
const std::vector< const char * > GetParamValues_() const
mxnet::cpp::SGDOptimizer::SGDOptimizer
SGDOptimizer(unsigned begin_num_update=0)
mxnet::cpp::Optimizer::count_
std::map< int, unsigned > count_
Definition: optimizer.h:113
mxnet::cpp::Optimizer::num_update_
unsigned num_update_
Definition: optimizer.h:114
mxnet::cpp::OptimizerRegistry
Definition: optimizer.h:124
mxnet::cpp::Optimizer::Optimizer
Optimizer(unsigned begin_num_update)
constructor
mxnet::cpp::Optimizer::Serialize
std::string Serialize() const
Serialize the optimizer parameters to a string.
mxnet::cpp::OptimizerRegistry::__REGISTER__
static int __REGISTER__(const std::string &name, OptimizerCreator creator)
base.h
base definitions for mxnetcpp
mxnet::cpp::AdamOptimizer::AdamOptimizer
AdamOptimizer(unsigned begin_num_update=0)
mxnet::cpp::Optimizer::SetParam
Optimizer * SetParam(const std::string &name, const T &value)
set config parameters
Definition: optimizer.h:70
mxnet::cpp::SignumOptimizer::SignumOptimizer
SignumOptimizer(unsigned begin_num_update=0)
op_map.h
definition of OpMap
mxnet::cpp::RMSPropOptimizer
Definition: optimizer.h:165