27 #ifndef MXNET_CPP_OPTIMIZER_H_ 28 #define MXNET_CPP_OPTIMIZER_H_ 37 #include "dmlc/logging.h" 54 explicit Optimizer(
unsigned begin_num_update);
59 virtual std::string
GetType()
const = 0;
72 std::string value_str;
127 static Optimizer* Find(
const std::string& name);
128 static int __REGISTER__(
const std::string& name, OptimizerCreator creator);
130 static std::map<std::string, OptimizerCreator>& cmap();
134 #define MXNETCPP_REGISTER_OPTIMIZER(Name, OptimizerType)\ 135 OptimizerRegistry::__REGISTER__(#Name, [](){return new OptimizerType();}) 140 std::string
GetType()
const override;
145 std::map<int, NDArray*> states_;
153 std::string
GetType()
const override;
158 std::map<int, NDArray*> states_;
167 std::string
GetType()
const override;
172 std::map<int, NDArray*> n_, g_, delta_;
180 std::string
GetType()
const override;
185 std::map<int, NDArray*> mean_;
186 std::map<int, NDArray*> var_;
193 std::string
GetType()
const override;
198 std::map<int, NDArray*> history_;
204 std::string
GetType()
const override;
209 std::map<int, NDArray*> acc_g_, acc_delta_;
215 #endif // MXNET_CPP_OPTIMIZER_H_ Definition: optimizer.h:137
An OpMap instance holds a map of all the symbol creators, so that symbol creators can be looked up by name...
Definition: op_map.h:43
Definition: optimizer.h:177
unsigned UpdateCount_(int index)
namespace of mxnet
Definition: api_registry.h:33
Optimizer(unsigned begin_num_update)
constructor
const std::vector< const char * > GetParamKeys_() const
void * AtomicSymbolCreator
handle to a function that takes param and creates symbol
Definition: c_api.h:71
virtual std::string GetType() const =0
get optimizer type
Scheduling learning rate.
unsigned begin_num_update_
Definition: optimizer.h:115
Definition: optimizer.h:150
unsigned num_update_
Definition: optimizer.h:115
Optimizer interface.
Definition: optimizer.h:48
std::map< int, unsigned > count_
Definition: optimizer.h:114
A faster implementation of strtof and strtod.
Definition: optimizer.h:201
Definition: optimizer.h:190
NDArray interface.
Definition: ndarray.h:121
virtual void Update(int index, NDArray weight, NDArray grad)=0
Update a weight with gradient.
std::unique_ptr< LRScheduler > lrScheduler_
Definition: optimizer.h:120
float stof(const std::string &value, size_t *pos=nullptr)
A faster implementation of stof(). See the documentation of std::stof() for more information. This function tests for overflow and invalid arguments. TODO: the current version does not support hexadecimal numbers. TODO: the current version does not handle long decimals: you may have at most 19 digits after the decimal point, and you cannot have too many digits before the decimal point either.
Definition: strtonum.h:467
Definition: optimizer.h:125
Definition: optimizer.h:164
std::map< std::string, std::string > params_
Definition: optimizer.h:110
virtual ~Optimizer()
destructor
const std::vector< const char * > GetParamValues_() const
Optimizer * SetParam(const std::string &name, const T &value)
set config parameters
Definition: optimizer.h:71
virtual void CreateState_(int index, NDArray weight)
std::string Serialize() const
Serialize the optimizer parameters to a string.
std::function< Optimizer *()> OptimizerCreator
Definition: optimizer.h:123
Optimizer * SetLRScheduler(std::unique_ptr< LRScheduler > lrScheduler)
Definition: optimizer.h:85