mxnet
optimizer.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_OPTIMIZER_H_
27 #define MXNET_CPP_OPTIMIZER_H_
28 
#include <functional>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "dmlc/logging.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
#include "mxnet-cpp/lr_scheduler.h"
39 
40 namespace mxnet {
41 namespace cpp {
42 
46 class Optimizer {
47  public:
52  explicit Optimizer(unsigned begin_num_update);
57  virtual std::string GetType() const = 0;
61  virtual ~Optimizer();
68  template <typename T>
69  Optimizer *SetParam(const std::string &name, const T &value) {
70  std::string value_str;
71  std::stringstream ss;
72  ss << value;
73  ss >> value_str;
74 
75  params_[name] = value_str;
76  return this;
77  }
83  Optimizer *SetLRScheduler(std::unique_ptr<LRScheduler> lrScheduler) {
84  CHECK(lrScheduler);
85  lrScheduler_ = std::move(lrScheduler);
86  lrScheduler_->SetLR(std::stof(params_["lr"]));
87  return this;
88  }
95  virtual void Update(int index, NDArray weight, NDArray grad) = 0;
96  // TODO(zhangcheng-qinyinghua)
97  // implement Update a list of arrays, maybe in the form of map
98  // void Update(int index, std::vector<NDArray> weights, std::vector<NDArray>
99  // grad, mx_float lr);
100 
105  std::string Serialize() const;
106 
107  protected:
108  std::map<std::string, std::string> params_;
109  static OpMap*& op_map();
110  const std::vector<const char*> GetParamKeys_() const;
111  const std::vector<const char*> GetParamValues_() const;
112  std::map<int, unsigned> count_;
114  unsigned UpdateCount_(int index);
115  float GetLR_(int index);
116  float GetWD_(int index);
117  virtual void CreateState_(int index, NDArray weight);
118  std::unique_ptr<LRScheduler> lrScheduler_ = nullptr;
119 };
120 
121 typedef std::function<Optimizer*()> OptimizerCreator;
122 
124  public:
125  static Optimizer* Find(const std::string& name);
126  static int __REGISTER__(const std::string& name, OptimizerCreator creator);
127  private:
128  static std::map<std::string, OptimizerCreator>& cmap();
129  OptimizerRegistry() = delete;
130  ~OptimizerRegistry() = delete;
131 };
132 
/*!
 * \brief Register OptimizerType with OptimizerRegistry under the name Name.
 *
 * Expands to the definition of a static int named
 * __make_<OptimizerType>_<Name>__ whose initializer calls
 * OptimizerRegistry::__REGISTER__ with the stringized Name and a factory that
 * default-constructs OptimizerType with new.
 */
#define MXNETCPP_REGISTER_OPTIMIZER(Name, OptimizerType)         \
  static int __make_##OptimizerType##_##Name##__ =              \
      OptimizerRegistry::__REGISTER__(                          \
          #Name, [] { return new OptimizerType(); })
136 
137 class SGDOptimizer : public Optimizer {
138  public:
139  explicit SGDOptimizer(unsigned begin_num_update = 0);
140  std::string GetType() const override;
141  void Update(int index, NDArray weight, NDArray grad) override;
142  private:
143  virtual ~SGDOptimizer();
144  void CreateState_(int index, NDArray weight) override;
145  std::map<int, NDArray*> states_;
146  AtomicSymbolCreator update_handle_;
147  AtomicSymbolCreator mom_update_handle_;
148 };
149 
150 class RMSPropOptimizer : public Optimizer {
151  public:
152  explicit RMSPropOptimizer(unsigned begin_num_update = 0);
153  std::string GetType() const override;
154  void Update(int index, NDArray weight, NDArray grad) override;
155  private:
156  virtual ~RMSPropOptimizer();
157  void CreateState_(int index, NDArray weight) override;
158  std::map<int, NDArray*> n_, g_, delta_;
159  AtomicSymbolCreator update_handle_;
160  AtomicSymbolCreator alex_update_handle_;
161 };
162 
163 class AdamOptimizer : public Optimizer {
164  public:
165  explicit AdamOptimizer(unsigned begin_num_update = 0);
166  std::string GetType() const override;
167  void Update(int index, NDArray weight, NDArray grad) override;
168  private:
169  virtual ~AdamOptimizer();
170  void CreateState_(int index, NDArray weight) override;
171  std::map<int, NDArray*> mean_;
172  std::map<int, NDArray*> var_;
173  AtomicSymbolCreator update_handle_;
174 };
175 
176 class AdaGradOptimizer : public Optimizer {
177  public:
178  explicit AdaGradOptimizer(unsigned begin_num_update = 0);
179  std::string GetType() const override;
180  void Update(int index, NDArray weight, NDArray grad) override;
181  private:
182  virtual ~AdaGradOptimizer();
183  void CreateState_(int index, NDArray weight) override;
184  std::map<int, NDArray*> history_;
185 };
186 
187 class AdaDeltaOptimizer : public Optimizer {
188  public:
189  explicit AdaDeltaOptimizer(unsigned begin_num_update = 0);
190  std::string GetType() const override;
191  void Update(int index, NDArray weight, NDArray grad) override;
192  private:
193  virtual ~AdaDeltaOptimizer();
194  void CreateState_(int index, NDArray weight) override;
195  std::map<int, NDArray*> acc_g_, acc_delta_;
196 };
197 
198 } // namespace cpp
199 } // namespace mxnet
200 
201 #endif // MXNET_CPP_OPTIMIZER_H_
Definition: optimizer.h:137
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name...
Definition: op_map.h:42
Definition: optimizer.h:163
definition of OpMap
unsigned UpdateCount_(int index)
namespace of mxnet
Definition: base.h:126
Optimizer(unsigned begin_num_update)
constructor
const std::vector< const char * > GetParamKeys_() const
virtual std::string GetType() const =0
get optimizer type
Scheduling learning rate.
unsigned begin_num_update_
Definition: optimizer.h:113
unsigned num_update_
Definition: optimizer.h:113
Optimizer interface.
Definition: optimizer.h:46
std::map< int, unsigned > count_
Definition: optimizer.h:112
Definition: optimizer.h:187
Definition: optimizer.h:176
NDArray interface.
Definition: ndarray.h:120
float GetWD_(int index)
virtual void Update(int index, NDArray weight, NDArray grad)=0
Update a weight with gradient.
std::unique_ptr< LRScheduler > lrScheduler_
Definition: optimizer.h:118
Definition: optimizer.h:123
Definition: optimizer.h:150
std::map< std::string, std::string > params_
Definition: optimizer.h:108
virtual ~Optimizer()
destructor
void * AtomicSymbolCreator
Handle to a function that takes parameters and creates a symbol.
Definition: c_api.h:68
const std::vector< const char * > GetParamValues_() const
Optimizer * SetParam(const std::string &name, const T &value)
set config parameters
Definition: optimizer.h:69
virtual void CreateState_(int index, NDArray weight)
std::string Serialize() const
Serialize the optimizer parameters to a string.
static OpMap *& op_map()
std::function< Optimizer *()> OptimizerCreator
Definition: optimizer.h:121
Optimizer * SetLRScheduler(std::unique_ptr< LRScheduler > lrScheduler)
Definition: optimizer.h:83
float GetLR_(int index)