mxnet
optimizer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_OPTIMIZER_H_
28 #define MXNET_CPP_OPTIMIZER_H_
29 
#include <map>
#include <vector>
#include <string>
#include <memory>
#include <functional>
#include <sstream>
#include <utility>
#include "mxnet-cpp/base.h"
#include "dmlc/logging.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
#include "mxnet-cpp/lr_scheduler.h"
40 
41 namespace mxnet {
42 namespace cpp {
43 
47 class Optimizer {
48  public:
53  explicit Optimizer(unsigned begin_num_update);
58  virtual std::string GetType() const = 0;
62  virtual ~Optimizer();
69  template <typename T>
70  Optimizer *SetParam(const std::string &name, const T &value) {
71  std::string value_str;
72  std::stringstream ss;
73  ss << value;
74  ss >> value_str;
75 
76  params_[name] = value_str;
77  return this;
78  }
84  Optimizer *SetLRScheduler(std::unique_ptr<LRScheduler> lrScheduler) {
85  CHECK(lrScheduler);
86  lrScheduler_ = std::move(lrScheduler);
87  lrScheduler_->SetLR(std::stof(params_["lr"]));
88  return this;
89  }
96  virtual void Update(int index, NDArray weight, NDArray grad) = 0;
97  // TODO(zhangcheng-qinyinghua)
98  // implement Update a list of arrays, maybe in the form of map
99  // void Update(int index, std::vector<NDArray> weights, std::vector<NDArray>
100  // grad, mx_float lr);
101 
106  std::string Serialize() const;
107 
108  protected:
109  std::map<std::string, std::string> params_;
110  static OpMap*& op_map();
111  const std::vector<const char*> GetParamKeys_() const;
112  const std::vector<const char*> GetParamValues_() const;
113  std::map<int, unsigned> count_;
115  unsigned UpdateCount_(int index);
116  float GetLR_(int index);
117  float GetWD_(int index);
118  virtual void CreateState_(int index, NDArray weight);
119  std::unique_ptr<LRScheduler> lrScheduler_ = nullptr;
120 };
121 
122 typedef std::function<Optimizer*()> OptimizerCreator;
123 
125  public:
126  static Optimizer* Find(const std::string& name);
127  static int __REGISTER__(const std::string& name, OptimizerCreator creator);
128  private:
129  static std::map<std::string, OptimizerCreator>& cmap();
130  OptimizerRegistry() = delete;
131  ~OptimizerRegistry() = delete;
132 };
133 
/*!
 * \brief Register OptimizerType in OptimizerRegistry under the name Name.
 *
 * Expands to the definition of a static int named
 * __make_<OptimizerType>_<Name>__, initialised by calling
 * OptimizerRegistry::__REGISTER__ with the stringified Name and a lambda
 * that returns `new OptimizerType()`. The caller supplies the trailing
 * semicolon.
 */
#define MXNETCPP_REGISTER_OPTIMIZER(Name, OptimizerType)   \
  static int __make_##OptimizerType##_##Name##__ =         \
      OptimizerRegistry::__REGISTER__(                     \
          #Name, []() { return new OptimizerType(); })
137 
138 class SGDOptimizer : public Optimizer {
139  public:
140  explicit SGDOptimizer(unsigned begin_num_update = 0);
141  std::string GetType() const override;
142  void Update(int index, NDArray weight, NDArray grad) override;
143  private:
144  virtual ~SGDOptimizer();
145  void CreateState_(int index, NDArray weight) override;
146  std::map<int, NDArray*> states_;
147  AtomicSymbolCreator update_handle_;
148  AtomicSymbolCreator mom_update_handle_;
149 };
150 
151 class RMSPropOptimizer : public Optimizer {
152  public:
153  explicit RMSPropOptimizer(unsigned begin_num_update = 0);
154  std::string GetType() const override;
155  void Update(int index, NDArray weight, NDArray grad) override;
156  private:
157  virtual ~RMSPropOptimizer();
158  void CreateState_(int index, NDArray weight) override;
159  std::map<int, NDArray*> n_, g_, delta_;
160  AtomicSymbolCreator update_handle_;
161  AtomicSymbolCreator alex_update_handle_;
162 };
163 
164 class AdamOptimizer : public Optimizer {
165  public:
166  explicit AdamOptimizer(unsigned begin_num_update = 0);
167  std::string GetType() const override;
168  void Update(int index, NDArray weight, NDArray grad) override;
169  private:
170  virtual ~AdamOptimizer();
171  void CreateState_(int index, NDArray weight) override;
172  std::map<int, NDArray*> mean_;
173  std::map<int, NDArray*> var_;
174  AtomicSymbolCreator update_handle_;
175 };
176 
177 class AdaGradOptimizer : public Optimizer {
178  public:
179  explicit AdaGradOptimizer(unsigned begin_num_update = 0);
180  std::string GetType() const override;
181  void Update(int index, NDArray weight, NDArray grad) override;
182  private:
183  virtual ~AdaGradOptimizer();
184  void CreateState_(int index, NDArray weight) override;
185  std::map<int, NDArray*> history_;
186 };
187 
188 class AdaDeltaOptimizer : public Optimizer {
189  public:
190  explicit AdaDeltaOptimizer(unsigned begin_num_update = 0);
191  std::string GetType() const override;
192  void Update(int index, NDArray weight, NDArray grad) override;
193  private:
194  virtual ~AdaDeltaOptimizer();
195  void CreateState_(int index, NDArray weight) override;
196  std::map<int, NDArray*> acc_g_, acc_delta_;
197 };
198 
199 } // namespace cpp
200 } // namespace mxnet
201 
202 #endif // MXNET_CPP_OPTIMIZER_H_
mxnet::cpp::SGDOptimizer. Definition: optimizer.h:138
mxnet::cpp::OpMap: OpMap instance holds a map of all the symbol creators so we can get symbol creators by name... Definition: op_map.h:43
mxnet::cpp::AdamOptimizer. Definition: optimizer.h:164
op_map.h: definition of OpMap
unsigned UpdateCount_(int index)
mxnet: namespace of mxnet. Definition: base.h:127
Optimizer(unsigned begin_num_update): constructor
const std::vector< const char * > GetParamKeys_() const
virtual std::string GetType() const =0: get optimizer type
lr_scheduler.h: Scheduling learning rate.
unsigned begin_num_update_. Definition: optimizer.h:114
unsigned num_update_. Definition: optimizer.h:114
mxnet::cpp::Optimizer: Optimizer interface. Definition: optimizer.h:47
std::map< int, unsigned > count_. Definition: optimizer.h:113
mxnet::cpp::AdaDeltaOptimizer. Definition: optimizer.h:188
mxnet::cpp::AdaGradOptimizer. Definition: optimizer.h:177
mxnet::cpp::NDArray: NDArray interface. Definition: ndarray.h:121
float GetWD_(int index)
virtual void Update(int index, NDArray weight, NDArray grad)=0: Update a weight with gradient.
std::unique_ptr< LRScheduler > lrScheduler_. Definition: optimizer.h:119
mxnet::cpp::OptimizerRegistry. Definition: optimizer.h:124
mxnet::cpp::RMSPropOptimizer. Definition: optimizer.h:151
std::map< std::string, std::string > params_. Definition: optimizer.h:109
virtual ~Optimizer(): destructor
void * AtomicSymbolCreator: handle to a function that takes param and creates symbol. Definition: c_api.h:69
const std::vector< const char * > GetParamValues_() const
Optimizer * SetParam(const std::string &name, const T &value): set config parameters. Definition: optimizer.h:70
virtual void CreateState_(int index, NDArray weight)
std::string Serialize() const: Serialize the optimizer parameters to a string.
static OpMap *& op_map()
std::function< Optimizer *()> OptimizerCreator. Definition: optimizer.h:122
Optimizer * SetLRScheduler(std::unique_ptr< LRScheduler > lrScheduler). Definition: optimizer.h:84
float GetLR_(int index)