6 #ifndef DMLC_PARAMETER_H_ 7 #define DMLC_PARAMETER_H_ 27 #include "./logging.h" 35 struct ParamError :
public dmlc::Error {
40 explicit ParamError(
const std::string &msg)
50 template<
typename ValueType>
51 inline ValueType GetEnv(
const char *key,
52 ValueType default_value);
59 template<
typename ValueType>
60 inline void SetEnv(
const char *key,
68 class FieldAccessEntry;
70 template<
typename DType>
73 template<
typename PType>
74 struct ParamManagerSingleton;
77 enum ParamInitOption {
89 struct ParamFieldInfo {
98 std::string type_info_str;
100 std::string description;
127 template<
typename PType>
140 template<
typename Container>
141 inline void Init(
const Container &kwargs,
142 parameter::ParamInitOption option = parameter::kAllowHidden) {
143 PType::__MANAGER__()->RunInit(static_cast<PType*>(
this),
144 kwargs.begin(), kwargs.end(),
157 template<
typename Container>
158 inline std::vector<std::pair<std::string, std::string> >
159 InitAllowUnknown(
const Container &kwargs) {
160 std::vector<std::pair<std::string, std::string> > unknown;
161 PType::__MANAGER__()->RunInit(static_cast<PType*>(
this),
162 kwargs.begin(), kwargs.end(),
163 &unknown, parameter::kAllowUnknown);
179 template <
typename Container>
180 std::vector<std::pair<std::string, std::string> >
181 UpdateAllowUnknown(Container
const& kwargs,
bool* out_changed =
nullptr) {
182 std::vector<std::pair<std::string, std::string> > unknown;
183 bool changed {
false};
184 changed = PType::__MANAGER__()->RunUpdate(static_cast<PType*>(
this),
185 kwargs.begin(), kwargs.end(),
186 parameter::kAllowUnknown, &unknown,
nullptr);
187 if (out_changed) { *out_changed = changed; }
197 template<
typename Container>
198 inline void UpdateDict(Container *dict)
const {
199 PType::__MANAGER__()->UpdateDict(this->head(), dict);
205 inline std::map<std::string, std::string> __DICT__()
const {
206 std::vector<std::pair<std::string, std::string> > vec
207 = PType::__MANAGER__()->GetDict(this->head());
208 return std::map<std::string, std::string>(vec.begin(), vec.end());
215 writer->
Write(this->__DICT__());
223 std::map<std::string, std::string> kwargs;
224 reader->
Read(&kwargs);
231 inline static std::vector<ParamFieldInfo> __FIELDS__() {
232 return PType::__MANAGER__()->GetFieldInfo();
238 inline static std::string __DOC__() {
239 std::ostringstream os;
240 PType::__MANAGER__()->PrintDocString(os);
251 template<
typename DType>
252 inline parameter::FieldEntry<DType>& DECLARE(
253 parameter::ParamManagerSingleton<PType> *manager,
254 const std::string &key, DType &ref) {
255 parameter::FieldEntry<DType> *e =
256 new parameter::FieldEntry<DType>();
257 e->Init(key, this->head(), ref);
258 manager->manager.AddEntry(key, e);
264 inline PType *head()
const {
265 return static_cast<PType*
>(
const_cast<Parameter<PType>*
>(
this));
289 #define DMLC_DECLARE_PARAMETER(PType) \ 290 static ::dmlc::parameter::ParamManager *__MANAGER__(); \ 291 inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \ 297 #define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName) 304 #define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName) 314 #define DMLC_REGISTER_PARAMETER(PType) \ 315 ::dmlc::parameter::ParamManager *PType::__MANAGER__() { \ 316 static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \ 317 return &inst.manager; \ 319 static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \ 320 __make__ ## PType ## ParamManager__ = \ 321 (*PType::__MANAGER__()) \ 328 namespace parameter {
335 class FieldAccessEntry {
338 : has_default_(false), index_(0) {}
340 virtual ~FieldAccessEntry() {}
346 virtual void SetDefault(
void *head)
const = 0;
352 virtual void Set(
void *head,
const std::string &value)
const = 0;
358 virtual bool Same(
void* head,
const std::string& value)
const = 0;
360 virtual void Check(
void *head)
const {}
365 virtual std::string GetStringValue(
void *head)
const = 0;
370 virtual ParamFieldInfo GetFieldInfo()
const = 0;
382 std::string description_;
386 char* GetRawPtr(
void* head)
const {
387 return reinterpret_cast<char*
>(head) + offset_;
393 virtual void PrintDefaultValueString(std::ostream &os)
const = 0;
395 friend class ParamManager;
406 for (
size_t i = 0; i < entry_.size(); ++i) {
415 inline FieldAccessEntry *Find(
const std::string &key)
const {
416 std::map<std::string, FieldAccessEntry*>::const_iterator it =
417 entry_map_.find(key);
418 if (it == entry_map_.end())
return NULL;
431 template<
typename RandomAccessIterator>
432 inline void RunInit(
void *head,
433 RandomAccessIterator begin,
434 RandomAccessIterator end,
435 std::vector<std::pair<std::string, std::string> > *unknown_args,
436 parameter::ParamInitOption option)
const {
437 std::set<FieldAccessEntry*> selected_args;
438 RunUpdate(head, begin, end, option, unknown_args, &selected_args);
439 for (
auto const& kv : entry_map_) {
440 if (selected_args.find(kv.second) == selected_args.cend()) {
441 kv.second->SetDefault(head);
444 for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin();
445 it != entry_map_.end(); ++it) {
446 if (selected_args.count(it->second) == 0) {
447 it->second->SetDefault(head);
463 template <
typename RandomAccessIterator>
464 bool RunUpdate(
void *head,
465 RandomAccessIterator begin,
466 RandomAccessIterator end,
467 parameter::ParamInitOption option,
468 std::vector<std::pair<std::string, std::string> > *unknown_args,
469 std::set<FieldAccessEntry*>* selected_args =
nullptr)
const {
470 bool changed {
false};
471 for (RandomAccessIterator it = begin; it != end; ++it) {
472 if (FieldAccessEntry *e = Find(it->first)) {
473 if (!e->Same(head, it->second)) {
476 e->Set(head, it->second);
479 selected_args->insert(e);
482 if (unknown_args != NULL) {
483 unknown_args->push_back(*it);
485 if (option != parameter::kAllowUnknown) {
486 if (option == parameter::kAllowHidden &&
487 it->first.length() > 4 &&
488 it->first.find(
"__") == 0 &&
489 it->first.rfind(
"__") == it->first.length()-2) {
492 std::ostringstream os;
493 os <<
"Cannot find argument \'" << it->first <<
"\', Possible Arguments:\n";
494 os <<
"----------------\n";
496 throw dmlc::ParamError(os.str());
509 inline void AddEntry(
const std::string &key, FieldAccessEntry *e) {
510 e->index_ = entry_.size();
512 if (entry_map_.count(key) != 0) {
513 LOG(FATAL) <<
"key " << key <<
" has already been registered in " << name_;
524 inline void AddAlias(
const std::string& field,
const std::string& alias) {
525 if (entry_map_.count(field) == 0) {
526 LOG(FATAL) <<
"key " << field <<
" has not been registered in " << name_;
528 if (entry_map_.count(alias) != 0) {
529 LOG(FATAL) <<
"Alias " << alias <<
" has already been registered in " << name_;
531 entry_map_[alias] = entry_map_[field];
537 inline void set_name(
const std::string &name) {
544 inline std::vector<ParamFieldInfo> GetFieldInfo()
const {
545 std::vector<ParamFieldInfo> ret(entry_.size());
546 for (
size_t i = 0; i < entry_.size(); ++i) {
547 ret[i] = entry_[i]->GetFieldInfo();
555 inline void PrintDocString(std::ostream &os)
const {
556 for (
size_t i = 0; i < entry_.size(); ++i) {
557 ParamFieldInfo info = entry_[i]->GetFieldInfo();
558 os << info.name <<
" : " << info.type_info_str <<
'\n';
559 if (info.description.length() != 0) {
560 os <<
" " << info.description <<
'\n';
570 inline std::vector<std::pair<std::string, std::string> > GetDict(
void * head)
const {
571 std::vector<std::pair<std::string, std::string> > ret;
572 for (std::map<std::string, FieldAccessEntry*>::const_iterator
573 it = entry_map_.begin(); it != entry_map_.end(); ++it) {
574 ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head)));
584 template<
typename Container>
585 inline void UpdateDict(
void * head, Container* dict)
const {
586 for (std::map<std::string, FieldAccessEntry*>::const_iterator
587 it = entry_map_.begin(); it != entry_map_.end(); ++it) {
588 (*dict)[it->first] = it->second->GetStringValue(head);
596 std::vector<FieldAccessEntry*> entry_;
598 std::map<std::string, FieldAccessEntry*> entry_map_;
605 template<
typename PType>
606 struct ParamManagerSingleton {
607 ParamManager manager;
608 explicit ParamManagerSingleton(
const std::string ¶m_name) {
610 manager.set_name(param_name);
611 param.__DECLARE__(
this);
617 template<
typename TEntry,
typename DType>
618 class FieldEntryBase :
public FieldAccessEntry {
621 typedef TEntry EntryType;
623 void Set(
void *head,
const std::string &value)
const override {
624 std::istringstream is(value);
625 is >> this->Get(head);
633 is.setstate(std::ios::failbit);
break;
639 std::ostringstream os;
640 os <<
"Invalid Parameter format for " << key_
641 <<
" expect " << type_ <<
" but value=\'" << value<<
'\'';
642 throw dmlc::ParamError(os.str());
645 bool Same(
void* head, std::string
const& value)
const override {
646 DType old = this->Get(head);
648 std::istringstream is(value);
651 bool is_same = std::equal(
652 reinterpret_cast<char*>(&now), reinterpret_cast<char*>(&now) +
sizeof(now),
653 reinterpret_cast<char*>(&old));
656 std::string GetStringValue(
void *head)
const override {
657 std::ostringstream os;
658 PrintValue(os, this->Get(head));
661 ParamFieldInfo GetFieldInfo()
const override {
663 std::ostringstream os;
668 os <<
',' <<
" optional, default=";
669 PrintDefaultValueString(os);
673 info.type_info_str = os.str();
674 info.description = description_;
678 void SetDefault(
void *head)
const override {
680 std::ostringstream os;
681 os <<
"Required parameter " << key_
682 <<
" of " << type_ <<
" is not presented";
683 throw dmlc::ParamError(os.str());
685 this->Get(head) = default_value_;
689 inline TEntry &
self() {
690 return *(
static_cast<TEntry*
>(
this));
693 inline TEntry &set_default(
const DType &default_value) {
694 default_value_ = default_value;
700 inline TEntry &describe(
const std::string &description) {
701 description_ = description;
706 inline void Init(
const std::string &key,
707 void *head, DType &ref) {
709 if (this->type_.length() == 0) {
710 this->type_ = dmlc::type_name<DType>();
712 this->offset_ = ((
char*)&ref) - ((
char*)head);
717 virtual void PrintValue(std::ostream &os, DType value)
const {
720 void PrintDefaultValueString(std::ostream &os)
const override {
721 PrintValue(os, default_value_);
726 inline DType &Get(
void *head)
const {
727 return *(DType*)this->GetRawPtr(head);
730 DType default_value_;
734 template<
typename TEntry,
typename DType>
735 class FieldEntryNumeric
736 :
public FieldEntryBase<TEntry, DType> {
739 : has_begin_(false), has_end_(false) {}
741 virtual TEntry &set_range(DType begin, DType end) {
742 begin_ = begin; end_ = end;
743 has_begin_ =
true; has_end_ =
true;
747 virtual TEntry &set_lower_bound(DType begin) {
748 begin_ = begin; has_begin_ =
true;
752 virtual void Check(
void *head)
const {
753 FieldEntryBase<TEntry, DType>::Check(head);
754 DType v = this->Get(head);
755 if (has_begin_ && has_end_) {
756 if (v < begin_ || v > end_) {
757 std::ostringstream os;
758 os <<
"value " << v <<
" for Parameter " << this->key_
759 <<
" exceed bound [" << begin_ <<
',' << end_ <<
']' <<
'\n';
760 os << this->key_ <<
": " << this->description_;
761 throw dmlc::ParamError(os.str());
763 }
else if (has_begin_ && v < begin_) {
764 std::ostringstream os;
765 os <<
"value " << v <<
" for Parameter " << this->key_
766 <<
" should be greater equal to " << begin_ <<
'\n';
767 os << this->key_ <<
": " << this->description_;
768 throw dmlc::ParamError(os.str());
769 }
else if (has_end_ && v > end_) {
770 std::ostringstream os;
771 os <<
"value " << v <<
" for Parameter " << this->key_
772 <<
" should be smaller equal to " << end_ <<
'\n';
773 os << this->key_ <<
": " << this->description_;
774 throw dmlc::ParamError(os.str());
780 bool has_begin_, has_end_;
790 template<
typename DType>
792 public IfThenElseType<dmlc::is_arithmetic<DType>::value,
793 FieldEntryNumeric<FieldEntry<DType>, DType>,
794 FieldEntryBase<FieldEntry<DType>, DType> >::Type {
799 class FieldEntry<int>
800 :
public FieldEntryNumeric<FieldEntry<int>, int> {
803 FieldEntry<int>() : is_enum_(
false) {}
805 typedef FieldEntryNumeric<FieldEntry<int>,
int> Parent;
807 virtual void Set(
void *head,
const std::string &value)
const {
809 std::map<std::string, int>::const_iterator it = enum_map_.find(value);
810 std::ostringstream os;
811 if (it == enum_map_.end()) {
812 os <<
"Invalid Input: \'" << value;
813 os <<
"\', valid values are: ";
815 throw dmlc::ParamError(os.str());
818 Parent::Set(head, os.str());
821 Parent::Set(head, value);
824 virtual ParamFieldInfo GetFieldInfo()
const {
827 std::ostringstream os;
832 os <<
',' <<
"optional, default=";
833 PrintDefaultValueString(os);
837 info.type_info_str = os.str();
838 info.description = description_;
841 return Parent::GetFieldInfo();
845 inline FieldEntry<int> &add_enum(
const std::string &key,
int value) {
846 if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
847 enum_back_map_.count(value) != 0) {
848 std::ostringstream os;
849 os <<
"Enum " <<
"(" << key <<
": " << value <<
" exisit!" <<
")\n";
851 for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
852 it != enum_map_.end(); ++it) {
853 os <<
"(" << it->first <<
": " << it->second <<
"), ";
855 throw dmlc::ParamError(os.str());
857 enum_map_[key] = value;
858 enum_back_map_[value] = key;
867 std::map<std::string, int> enum_map_;
869 std::map<int, std::string> enum_back_map_;
871 virtual void PrintDefaultValueString(std::ostream &os)
const {
873 PrintValue(os, default_value_);
877 virtual void PrintValue(std::ostream &os,
int value)
const {
879 CHECK_NE(enum_back_map_.count(value), 0U)
880 <<
"Value not found in enum declared";
881 os << enum_back_map_.at(value);
889 inline void PrintEnums(std::ostream &os)
const {
891 for (std::map<std::string, int>::const_iterator
892 it = enum_map_.begin(); it != enum_map_.end(); ++it) {
893 if (it != enum_map_.begin()) {
896 os <<
"\'" << it->first <<
'\'';
905 class FieldEntry<optional<int> >
906 :
public FieldEntryBase<FieldEntry<optional<int> >, optional<int> > {
909 FieldEntry<optional<int> >() : is_enum_(
false) {}
911 typedef FieldEntryBase<FieldEntry<optional<int> >, optional<int> > Parent;
913 virtual void Set(
void *head,
const std::string &value)
const {
914 if (is_enum_ && value !=
"None") {
915 std::map<std::string, int>::const_iterator it = enum_map_.find(value);
916 std::ostringstream os;
917 if (it == enum_map_.end()) {
918 os <<
"Invalid Input: \'" << value;
919 os <<
"\', valid values are: ";
921 throw dmlc::ParamError(os.str());
924 Parent::Set(head, os.str());
927 Parent::Set(head, value);
930 virtual ParamFieldInfo GetFieldInfo()
const {
933 std::ostringstream os;
938 os <<
',' <<
"optional, default=";
939 PrintDefaultValueString(os);
943 info.type_info_str = os.str();
944 info.description = description_;
947 return Parent::GetFieldInfo();
951 inline FieldEntry<optional<int> > &add_enum(
const std::string &key,
int value) {
952 CHECK_NE(key,
"None") <<
"None is reserved for empty optional<int>";
953 if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
954 enum_back_map_.count(value) != 0) {
955 std::ostringstream os;
956 os <<
"Enum " <<
"(" << key <<
": " << value <<
" exisit!" <<
")\n";
958 for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
959 it != enum_map_.end(); ++it) {
960 os <<
"(" << it->first <<
": " << it->second <<
"), ";
962 throw dmlc::ParamError(os.str());
964 enum_map_[key] = value;
965 enum_back_map_[value] = key;
974 std::map<std::string, int> enum_map_;
976 std::map<int, std::string> enum_back_map_;
978 virtual void PrintDefaultValueString(std::ostream &os)
const {
980 PrintValue(os, default_value_);
984 virtual void PrintValue(std::ostream &os, optional<int> value)
const {
989 CHECK_NE(enum_back_map_.count(value.value()), 0U)
990 <<
"Value not found in enum declared";
991 os << enum_back_map_.at(value.value());
1000 inline void PrintEnums(std::ostream &os)
const {
1002 for (std::map<std::string, int>::const_iterator
1003 it = enum_map_.begin(); it != enum_map_.end(); ++it) {
1005 os <<
"\'" << it->first <<
'\'';
1013 class FieldEntry<
std::string>
1014 :
public FieldEntryBase<FieldEntry<std::string>, std::string> {
1017 typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent;
1019 virtual void Set(
void *head,
const std::string &value)
const {
1020 this->Get(head) = value;
1023 virtual void PrintDefaultValueString(std::ostream &os)
const {
1024 os <<
'\'' << default_value_ <<
'\'';
1030 class FieldEntry<bool>
1031 :
public FieldEntryBase<FieldEntry<bool>, bool> {
1034 typedef FieldEntryBase<FieldEntry<bool>,
bool> Parent;
1036 virtual void Set(
void *head,
const std::string &value)
const {
1037 std::string lower_case; lower_case.resize(value.length());
1038 std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower);
1039 bool &ref = this->Get(head);
1040 if (lower_case ==
"true") {
1042 }
else if (lower_case ==
"false") {
1044 }
else if (lower_case ==
"1") {
1046 }
else if (lower_case ==
"0") {
1049 std::ostringstream os;
1050 os <<
"Invalid Parameter format for " << key_
1051 <<
" expect " << type_ <<
" but value=\'" << value<<
'\'';
1052 throw dmlc::ParamError(os.str());
1058 virtual void PrintValue(std::ostream &os,
bool value)
const {
1059 os << static_cast<int>(value);
1068 class FieldEntry<float> :
public FieldEntryNumeric<FieldEntry<float>, float> {
1071 typedef FieldEntryNumeric<FieldEntry<float>,
float> Parent;
1073 virtual void Set(
void *head,
const std::string &value)
const {
1077 }
catch (
const std::invalid_argument &) {
1078 std::ostringstream os;
1079 os <<
"Invalid Parameter format for " << key_ <<
" expect " << type_
1080 <<
" but value=\'" << value <<
'\'';
1081 throw dmlc::ParamError(os.str());
1082 }
catch (
const std::out_of_range&) {
1083 std::ostringstream os;
1084 os <<
"Out of range value for " << key_ <<
", value=\'" << value <<
'\'';
1085 throw dmlc::ParamError(os.str());
1087 CHECK_LE(pos, value.length());
1088 if (pos < value.length()) {
1089 std::ostringstream os;
1090 os <<
"Some trailing characters could not be parsed: \'" 1091 << value.substr(pos) <<
"\'";
1092 throw dmlc::ParamError(os.str());
1098 virtual void PrintValue(std::ostream &os,
float value)
const {
1099 os << std::setprecision(std::numeric_limits<float>::max_digits10) << value;
1106 class FieldEntry<double>
1107 :
public FieldEntryNumeric<FieldEntry<double>, double> {
1110 typedef FieldEntryNumeric<FieldEntry<double>,
double> Parent;
1112 virtual void Set(
void *head,
const std::string &value)
const {
1116 }
catch (
const std::invalid_argument &) {
1117 std::ostringstream os;
1118 os <<
"Invalid Parameter format for " << key_ <<
" expect " << type_
1119 <<
" but value=\'" << value <<
'\'';
1120 throw dmlc::ParamError(os.str());
1121 }
catch (
const std::out_of_range&) {
1122 std::ostringstream os;
1123 os <<
"Out of range value for " << key_ <<
", value=\'" << value <<
'\'';
1124 throw dmlc::ParamError(os.str());
1126 CHECK_LE(pos, value.length());
1127 if (pos < value.length()) {
1128 std::ostringstream os;
1129 os <<
"Some trailing characters could not be parsed: \'" 1130 << value.substr(pos) <<
"\'";
1131 throw dmlc::ParamError(os.str());
1137 virtual void PrintValue(std::ostream &os,
double value)
const {
1138 os << std::setprecision(std::numeric_limits<double>::max_digits10) << value;
1141 #endif // DMLC_USE_CXX11 1147 template<
typename ValueType>
1148 inline ValueType GetEnv(
const char *key,
1149 ValueType default_value) {
1150 const char *val = getenv(key);
1154 if (val ==
nullptr || !*val) {
1155 return default_value;
1158 parameter::FieldEntry<ValueType> e;
1159 e.Init(key, &ret, ret);
1165 template<
typename ValueType>
1166 inline void SetEnv(
const char *key,
1168 parameter::FieldEntry<ValueType> e;
1169 e.Init(key, &value, value);
1171 _putenv_s(key, e.GetStringValue(&value).c_str());
1173 setenv(key, e.GetStringValue(&value).c_str(), 1);
1177 #endif // DMLC_PARAMETER_H_ Container to hold optional data.
Definition: optional.h:241
double stod(const std::string &value, size_t *pos=nullptr)
A faster implementation of stod(). See documentation of std::stod() for more information. This function will test for overflow and invalid arguments. TODO: the current version does not support hex number TODO: the current version does not handle long decimals: you may only have up to 19 digits after the decimal point, and you cannot have too many digits before the decimal point either.
Definition: strtonum.h:497
A faster implementation of strtof and strtod.
Lightweight JSON Reader/Writer that read save into C++ data structs. This includes STL composites and...
Lightweight JSON Reader to read any STL compositions and structs. The user need to know the schema of...
Definition: json.h:44
bool isspace(char c)
Inline implementation of isspace(). Tests whether the given character is a whitespace letter...
Definition: strtonum.h:26
namespace for dmlc
Definition: array_view.h:12
void Write(const ValueType &value)
Write value to json.
float stof(const std::string &value, size_t *pos=nullptr)
A faster implementation of stof(). See documentation of std::stof() for more information. This function will test for overflow and invalid arguments. TODO: the current version does not support hex number TODO: the current version does not handle long decimals: you may only have up to 19 digits after the decimal point, and you cannot have too many digits before the decimal point either.
Definition: strtonum.h:467
void Read(ValueType *out_value)
Read next ValueType.
type traits information header
Lightweight json to write any STL compositions.
Definition: json.h:189