mxnet
parameter.h
Go to the documentation of this file.
1 
6 #ifndef DMLC_PARAMETER_H_
7 #define DMLC_PARAMETER_H_
8 
9 #include <cstddef>
10 #include <cstdlib>
11 #include <cmath>
12 #include <sstream>
13 #include <limits>
14 #include <map>
15 #include <set>
16 #include <typeinfo>
17 #include <string>
18 #include <vector>
19 #include <algorithm>
20 #include <utility>
21 #include <stdexcept>
22 #include <iostream>
23 #include <iomanip>
24 #include <cerrno>
25 #include "./base.h"
26 #include "./json.h"
27 #include "./logging.h"
28 #include "./type_traits.h"
29 #include "./optional.h"
30 #include "./strtonum.h"
31 
32 namespace dmlc {
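// NOTE: this file was historically backward compatible with pre-C++11 compilers,
// but it now uses C++11 features (nullptr, auto, override) unconditionally.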
35 struct ParamError : public dmlc::Error {
40  explicit ParamError(const std::string &msg)
41  : dmlc::Error(msg) {}
42 };
43 
/*!
 * \brief Read environment variable \p key and parse it as ValueType.
 * \tparam ValueType type of the value; parsed the same way a parameter field is
 * \param key name of the environment variable
 * \param default_value value returned when the variable is unset or empty
 * \return the parsed value, or default_value
 */
template<typename ValueType>
inline ValueType GetEnv(const char *key, ValueType default_value);

/*!
 * \brief Set environment variable \p key to the string form of \p value.
 * \tparam ValueType type of the value; printed the same way a parameter field is
 * \param key name of the environment variable
 * \param value value to store
 */
template<typename ValueType>
inline void SetEnv(const char *key, ValueType value);
62 
namespace parameter {
// Forward declarations of the classes defined further below.
class ParamManager;
class FieldAccessEntry;
template<typename DType> class FieldEntry;
template<typename PType> struct ParamManagerSingleton;

/*! \brief how Init/Update treat keys that match no declared field */
enum ParamInitOption {
  /*! \brief unknown keys are never an error (collected if requested, else ignored) */
  kAllowUnknown = 0,
  /*! \brief an unknown key is an error unless the caller collects unknown keys */
  kAllMatch = 1,
  /*! \brief like kAllMatch, but keys of the form __name__ (length > 4) are skipped */
  kAllowHidden = 2
};
}  // namespace parameter
/*! \brief Description of one declared field, as returned by __FIELDS__(). */
struct ParamFieldInfo {
  // key under which the field was registered
  std::string name;
  // type name of the field (dmlc::type_name<DType>() unless preset)
  std::string type;
  // printable summary, e.g. "float, optional, default=0.5" or "int, required"
  std::string type_info_str;
  // text supplied through describe()
  std::string description;
};
102 
127 template<typename PType>
128 struct Parameter {
129  public:
140  template<typename Container>
141  inline void Init(const Container &kwargs,
142  parameter::ParamInitOption option = parameter::kAllowHidden) {
143  PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
144  kwargs.begin(), kwargs.end(),
145  NULL,
146  option);
147  }
157  template<typename Container>
158  inline std::vector<std::pair<std::string, std::string> >
159  InitAllowUnknown(const Container &kwargs) {
160  std::vector<std::pair<std::string, std::string> > unknown;
161  PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
162  kwargs.begin(), kwargs.end(),
163  &unknown, parameter::kAllowUnknown);
164  return unknown;
165  }
166 
179  template <typename Container>
180  std::vector<std::pair<std::string, std::string> >
181  UpdateAllowUnknown(Container const& kwargs, bool* out_changed = nullptr) {
182  std::vector<std::pair<std::string, std::string> > unknown;
183  bool changed {false};
184  changed = PType::__MANAGER__()->RunUpdate(static_cast<PType*>(this),
185  kwargs.begin(), kwargs.end(),
186  parameter::kAllowUnknown, &unknown, nullptr);
187  if (out_changed) { *out_changed = changed; }
188  return unknown;
189  }
190 
197  template<typename Container>
198  inline void UpdateDict(Container *dict) const {
199  PType::__MANAGER__()->UpdateDict(this->head(), dict);
200  }
205  inline std::map<std::string, std::string> __DICT__() const {
206  std::vector<std::pair<std::string, std::string> > vec
207  = PType::__MANAGER__()->GetDict(this->head());
208  return std::map<std::string, std::string>(vec.begin(), vec.end());
209  }
214  inline void Save(dmlc::JSONWriter *writer) const {
215  writer->Write(this->__DICT__());
216  }
222  inline void Load(dmlc::JSONReader *reader) {
223  std::map<std::string, std::string> kwargs;
224  reader->Read(&kwargs);
225  this->Init(kwargs);
226  }
231  inline static std::vector<ParamFieldInfo> __FIELDS__() {
232  return PType::__MANAGER__()->GetFieldInfo();
233  }
238  inline static std::string __DOC__() {
239  std::ostringstream os;
240  PType::__MANAGER__()->PrintDocString(os);
241  return os.str();
242  }
243 
244  protected:
251  template<typename DType>
252  inline parameter::FieldEntry<DType>& DECLARE(
253  parameter::ParamManagerSingleton<PType> *manager,
254  const std::string &key, DType &ref) { // NOLINT(*)
255  parameter::FieldEntry<DType> *e =
256  new parameter::FieldEntry<DType>();
257  e->Init(key, this->head(), ref);
258  manager->manager.AddEntry(key, e);
259  return *e;
260  }
261 
262  private:
264  inline PType *head() const {
265  return static_cast<PType*>(const_cast<Parameter<PType>*>(this));
266  }
267 };
268 
270 
/*
 * Declare the members a parameter struct needs. Use inside the struct body,
 * followed by the braced body of __DECLARE__, which declares the fields
 * with DMLC_DECLARE_FIELD. PType is the name of the struct.
 */
#define DMLC_DECLARE_PARAMETER(PType)                                          \
  static ::dmlc::parameter::ParamManager *__MANAGER__();                       \
  inline void __DECLARE__(                                                     \
      ::dmlc::parameter::ParamManagerSingleton<PType> *manager)
293 
/*
 * Declare a field inside the body of __DECLARE__. The member name doubles as
 * its key; the returned entry allows chaining set_default()/describe()/...
 */
#define DMLC_DECLARE_FIELD(FieldName) \
  (this->DECLARE(manager, #FieldName, FieldName))
298 
/*
 * Register AliasName as an additional key for the already-declared field
 * FieldName. Use inside the body of __DECLARE__.
 */
#define DMLC_DECLARE_ALIAS(FieldName, AliasName) \
  (manager->manager.AddAlias(#FieldName, #AliasName))
305 
/*
 * Define PType::__MANAGER__() (backed by a function-local static
 * ParamManagerSingleton) and bind a file-static reference to it so the
 * manager is built during static initialization. The invocation must be
 * followed by a semicolon; since it defines __MANAGER__() out of line it
 * presumably belongs in a single source file (confirm with callers).
 */
#define DMLC_REGISTER_PARAMETER(PType)                                         \
  ::dmlc::parameter::ParamManager *PType::__MANAGER__() {                      \
    static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType);       \
    return &inst.manager;                                                      \
  }                                                                            \
  static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager &               \
      __make__ ## PType ## ParamManager__ = (*PType::__MANAGER__())
323 
328 namespace parameter {
335 class FieldAccessEntry {
336  public:
337  FieldAccessEntry()
338  : has_default_(false), index_(0) {}
340  virtual ~FieldAccessEntry() {}
346  virtual void SetDefault(void *head) const = 0;
352  virtual void Set(void *head, const std::string &value) const = 0;
358  virtual bool Same(void* head, const std::string& value) const = 0;
359  // check if value is OK
360  virtual void Check(void *head) const {}
365  virtual std::string GetStringValue(void *head) const = 0;
370  virtual ParamFieldInfo GetFieldInfo() const = 0;
371 
372  protected:
374  bool has_default_;
376  size_t index_;
378  std::string key_;
380  std::string type_;
382  std::string description_;
383  // internal offset of the field
384  ptrdiff_t offset_;
386  char* GetRawPtr(void* head) const {
387  return reinterpret_cast<char*>(head) + offset_;
388  }
393  virtual void PrintDefaultValueString(std::ostream &os) const = 0; // NOLINT(*)
394  // allow ParamManager to modify self
395  friend class ParamManager;
396 };
397 
402 class ParamManager {
403  public:
405  ~ParamManager() {
406  for (size_t i = 0; i < entry_.size(); ++i) {
407  delete entry_[i];
408  }
409  }
415  inline FieldAccessEntry *Find(const std::string &key) const {
416  std::map<std::string, FieldAccessEntry*>::const_iterator it =
417  entry_map_.find(key);
418  if (it == entry_map_.end()) return NULL;
419  return it->second;
420  }
431  template<typename RandomAccessIterator>
432  inline void RunInit(void *head,
433  RandomAccessIterator begin,
434  RandomAccessIterator end,
435  std::vector<std::pair<std::string, std::string> > *unknown_args,
436  parameter::ParamInitOption option) const {
437  std::set<FieldAccessEntry*> selected_args;
438  RunUpdate(head, begin, end, option, unknown_args, &selected_args);
439  for (auto const& kv : entry_map_) {
440  if (selected_args.find(kv.second) == selected_args.cend()) {
441  kv.second->SetDefault(head);
442  }
443  }
444  for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin();
445  it != entry_map_.end(); ++it) {
446  if (selected_args.count(it->second) == 0) {
447  it->second->SetDefault(head);
448  }
449  }
450  }
463  template <typename RandomAccessIterator>
464  bool RunUpdate(void *head,
465  RandomAccessIterator begin,
466  RandomAccessIterator end,
467  parameter::ParamInitOption option,
468  std::vector<std::pair<std::string, std::string> > *unknown_args,
469  std::set<FieldAccessEntry*>* selected_args = nullptr) const {
470  bool changed {false};
471  for (RandomAccessIterator it = begin; it != end; ++it) {
472  if (FieldAccessEntry *e = Find(it->first)) {
473  if (!e->Same(head, it->second)) {
474  changed = true;
475  }
476  e->Set(head, it->second);
477  e->Check(head);
478  if (selected_args) {
479  selected_args->insert(e);
480  }
481  } else {
482  if (unknown_args != NULL) {
483  unknown_args->push_back(*it);
484  } else {
485  if (option != parameter::kAllowUnknown) {
486  if (option == parameter::kAllowHidden &&
487  it->first.length() > 4 &&
488  it->first.find("__") == 0 &&
489  it->first.rfind("__") == it->first.length()-2) {
490  continue;
491  }
492  std::ostringstream os;
493  os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
494  os << "----------------\n";
495  PrintDocString(os);
496  throw dmlc::ParamError(os.str());
497  }
498  }
499  }
500  }
501  return changed;
502  }
509  inline void AddEntry(const std::string &key, FieldAccessEntry *e) {
510  e->index_ = entry_.size();
511  // TODO(bing) better error message
512  if (entry_map_.count(key) != 0) {
513  LOG(FATAL) << "key " << key << " has already been registered in " << name_;
514  }
515  entry_.push_back(e);
516  entry_map_[key] = e;
517  }
524  inline void AddAlias(const std::string& field, const std::string& alias) {
525  if (entry_map_.count(field) == 0) {
526  LOG(FATAL) << "key " << field << " has not been registered in " << name_;
527  }
528  if (entry_map_.count(alias) != 0) {
529  LOG(FATAL) << "Alias " << alias << " has already been registered in " << name_;
530  }
531  entry_map_[alias] = entry_map_[field];
532  }
537  inline void set_name(const std::string &name) {
538  name_ = name;
539  }
544  inline std::vector<ParamFieldInfo> GetFieldInfo() const {
545  std::vector<ParamFieldInfo> ret(entry_.size());
546  for (size_t i = 0; i < entry_.size(); ++i) {
547  ret[i] = entry_[i]->GetFieldInfo();
548  }
549  return ret;
550  }
555  inline void PrintDocString(std::ostream &os) const { // NOLINT(*)
556  for (size_t i = 0; i < entry_.size(); ++i) {
557  ParamFieldInfo info = entry_[i]->GetFieldInfo();
558  os << info.name << " : " << info.type_info_str << '\n';
559  if (info.description.length() != 0) {
560  os << " " << info.description << '\n';
561  }
562  }
563  }
570  inline std::vector<std::pair<std::string, std::string> > GetDict(void * head) const {
571  std::vector<std::pair<std::string, std::string> > ret;
572  for (std::map<std::string, FieldAccessEntry*>::const_iterator
573  it = entry_map_.begin(); it != entry_map_.end(); ++it) {
574  ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head)));
575  }
576  return ret;
577  }
584  template<typename Container>
585  inline void UpdateDict(void * head, Container* dict) const {
586  for (std::map<std::string, FieldAccessEntry*>::const_iterator
587  it = entry_map_.begin(); it != entry_map_.end(); ++it) {
588  (*dict)[it->first] = it->second->GetStringValue(head);
589  }
590  }
591 
592  private:
594  std::string name_;
596  std::vector<FieldAccessEntry*> entry_;
598  std::map<std::string, FieldAccessEntry*> entry_map_;
599 };
600 
602 
603 // The following piece of code will be template heavy and less documented
604 // singleton parameter manager for certain type, used for initialization
605 template<typename PType>
606 struct ParamManagerSingleton {
607  ParamManager manager;
608  explicit ParamManagerSingleton(const std::string &param_name) {
609  PType param;
610  manager.set_name(param_name);
611  param.__DECLARE__(this);
612  }
613 };
614 
615 // Base class of FieldEntry
616 // implement set_default
617 template<typename TEntry, typename DType>
618 class FieldEntryBase : public FieldAccessEntry {
619  public:
620  // entry type
621  typedef TEntry EntryType;
622  // implement set value
623  void Set(void *head, const std::string &value) const override {
624  std::istringstream is(value);
625  is >> this->Get(head);
626  if (!is.fail()) {
627  while (!is.eof()) {
628  int ch = is.get();
629  if (ch == EOF) {
630  is.clear(); break;
631  }
632  if (!isspace(ch)) {
633  is.setstate(std::ios::failbit); break;
634  }
635  }
636  }
637 
638  if (is.fail()) {
639  std::ostringstream os;
640  os << "Invalid Parameter format for " << key_
641  << " expect " << type_ << " but value=\'" << value<< '\'';
642  throw dmlc::ParamError(os.str());
643  }
644  }
645  bool Same(void* head, std::string const& value) const override {
646  DType old = this->Get(head);
647  DType now;
648  std::istringstream is(value);
649  is >> now;
650  // don't require = operator
651  bool is_same = std::equal(
652  reinterpret_cast<char*>(&now), reinterpret_cast<char*>(&now) + sizeof(now),
653  reinterpret_cast<char*>(&old));
654  return is_same;
655  }
656  std::string GetStringValue(void *head) const override {
657  std::ostringstream os;
658  PrintValue(os, this->Get(head));
659  return os.str();
660  }
661  ParamFieldInfo GetFieldInfo() const override {
662  ParamFieldInfo info;
663  std::ostringstream os;
664  info.name = key_;
665  info.type = type_;
666  os << type_;
667  if (has_default_) {
668  os << ',' << " optional, default=";
669  PrintDefaultValueString(os);
670  } else {
671  os << ", required";
672  }
673  info.type_info_str = os.str();
674  info.description = description_;
675  return info;
676  }
677  // implement set head to default value
678  void SetDefault(void *head) const override {
679  if (!has_default_) {
680  std::ostringstream os;
681  os << "Required parameter " << key_
682  << " of " << type_ << " is not presented";
683  throw dmlc::ParamError(os.str());
684  } else {
685  this->Get(head) = default_value_;
686  }
687  }
688  // return reference of self as derived type
689  inline TEntry &self() {
690  return *(static_cast<TEntry*>(this));
691  }
692  // implement set_default
693  inline TEntry &set_default(const DType &default_value) {
694  default_value_ = default_value;
695  has_default_ = true;
696  // return self to allow chaining
697  return this->self();
698  }
699  // implement describe
700  inline TEntry &describe(const std::string &description) {
701  description_ = description;
702  // return self to allow chaining
703  return this->self();
704  }
705  // initialization function
706  inline void Init(const std::string &key,
707  void *head, DType &ref) { // NOLINT(*)
708  this->key_ = key;
709  if (this->type_.length() == 0) {
710  this->type_ = dmlc::type_name<DType>();
711  }
712  this->offset_ = ((char*)&ref) - ((char*)head); // NOLINT(*)
713  }
714 
715  protected:
716  // print the value
717  virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*)
718  os << value;
719  }
720  void PrintDefaultValueString(std::ostream &os) const override { // NOLINT(*)
721  PrintValue(os, default_value_);
722  }
723  // get the internal representation of parameter
724  // for example if this entry corresponds field param.learning_rate
725  // then Get(&param) will return reference to param.learning_rate
726  inline DType &Get(void *head) const {
727  return *(DType*)this->GetRawPtr(head); // NOLINT(*)
728  }
729  // default value of field
730  DType default_value_;
731 };
732 
733 // parameter base for numeric types that have range
734 template<typename TEntry, typename DType>
735 class FieldEntryNumeric
736  : public FieldEntryBase<TEntry, DType> {
737  public:
738  FieldEntryNumeric()
739  : has_begin_(false), has_end_(false) {}
740  // implement set_range
741  virtual TEntry &set_range(DType begin, DType end) {
742  begin_ = begin; end_ = end;
743  has_begin_ = true; has_end_ = true;
744  return this->self();
745  }
746  // implement set_range
747  virtual TEntry &set_lower_bound(DType begin) {
748  begin_ = begin; has_begin_ = true;
749  return this->self();
750  }
751  // consistency check for numeric ranges
752  virtual void Check(void *head) const {
753  FieldEntryBase<TEntry, DType>::Check(head);
754  DType v = this->Get(head);
755  if (has_begin_ && has_end_) {
756  if (v < begin_ || v > end_) {
757  std::ostringstream os;
758  os << "value " << v << " for Parameter " << this->key_
759  << " exceed bound [" << begin_ << ',' << end_ <<']' << '\n';
760  os << this->key_ << ": " << this->description_;
761  throw dmlc::ParamError(os.str());
762  }
763  } else if (has_begin_ && v < begin_) {
764  std::ostringstream os;
765  os << "value " << v << " for Parameter " << this->key_
766  << " should be greater equal to " << begin_ << '\n';
767  os << this->key_ << ": " << this->description_;
768  throw dmlc::ParamError(os.str());
769  } else if (has_end_ && v > end_) {
770  std::ostringstream os;
771  os << "value " << v << " for Parameter " << this->key_
772  << " should be smaller equal to " << end_ << '\n';
773  os << this->key_ << ": " << this->description_;
774  throw dmlc::ParamError(os.str());
775  }
776  }
777 
778  protected:
779  // whether it have begin and end range
780  bool has_begin_, has_end_;
781  // data bound
782  DType begin_, end_;
783 };
784 
790 template<typename DType>
791 class FieldEntry :
792  public IfThenElseType<dmlc::is_arithmetic<DType>::value,
793  FieldEntryNumeric<FieldEntry<DType>, DType>,
794  FieldEntryBase<FieldEntry<DType>, DType> >::Type {
795 };
796 
797 // specialize define for int(enum)
798 template<>
799 class FieldEntry<int>
800  : public FieldEntryNumeric<FieldEntry<int>, int> {
801  public:
802  // construct
803  FieldEntry<int>() : is_enum_(false) {}
804  // parent
805  typedef FieldEntryNumeric<FieldEntry<int>, int> Parent;
806  // override set
807  virtual void Set(void *head, const std::string &value) const {
808  if (is_enum_) {
809  std::map<std::string, int>::const_iterator it = enum_map_.find(value);
810  std::ostringstream os;
811  if (it == enum_map_.end()) {
812  os << "Invalid Input: \'" << value;
813  os << "\', valid values are: ";
814  PrintEnums(os);
815  throw dmlc::ParamError(os.str());
816  } else {
817  os << it->second;
818  Parent::Set(head, os.str());
819  }
820  } else {
821  Parent::Set(head, value);
822  }
823  }
824  virtual ParamFieldInfo GetFieldInfo() const {
825  if (is_enum_) {
826  ParamFieldInfo info;
827  std::ostringstream os;
828  info.name = key_;
829  info.type = type_;
830  PrintEnums(os);
831  if (has_default_) {
832  os << ',' << "optional, default=";
833  PrintDefaultValueString(os);
834  } else {
835  os << ", required";
836  }
837  info.type_info_str = os.str();
838  info.description = description_;
839  return info;
840  } else {
841  return Parent::GetFieldInfo();
842  }
843  }
844  // add enum
845  inline FieldEntry<int> &add_enum(const std::string &key, int value) {
846  if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
847  enum_back_map_.count(value) != 0) {
848  std::ostringstream os;
849  os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n";
850  os << "Enums: ";
851  for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
852  it != enum_map_.end(); ++it) {
853  os << "(" << it->first << ": " << it->second << "), ";
854  }
855  throw dmlc::ParamError(os.str());
856  }
857  enum_map_[key] = value;
858  enum_back_map_[value] = key;
859  is_enum_ = true;
860  return this->self();
861  }
862 
863  protected:
864  // enum flag
865  bool is_enum_;
866  // enum map
867  std::map<std::string, int> enum_map_;
868  // enum map
869  std::map<int, std::string> enum_back_map_;
870  // override print behavior
871  virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
872  os << '\'';
873  PrintValue(os, default_value_);
874  os << '\'';
875  }
876  // override print default
877  virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*)
878  if (is_enum_) {
879  CHECK_NE(enum_back_map_.count(value), 0U)
880  << "Value not found in enum declared";
881  os << enum_back_map_.at(value);
882  } else {
883  os << value;
884  }
885  }
886 
887 
888  private:
889  inline void PrintEnums(std::ostream &os) const { // NOLINT(*)
890  os << '{';
891  for (std::map<std::string, int>::const_iterator
892  it = enum_map_.begin(); it != enum_map_.end(); ++it) {
893  if (it != enum_map_.begin()) {
894  os << ", ";
895  }
896  os << "\'" << it->first << '\'';
897  }
898  os << '}';
899  }
900 };
901 
902 
903 // specialize define for optional<int>(enum)
904 template<>
905 class FieldEntry<optional<int> >
906  : public FieldEntryBase<FieldEntry<optional<int> >, optional<int> > {
907  public:
908  // construct
909  FieldEntry<optional<int> >() : is_enum_(false) {}
910  // parent
911  typedef FieldEntryBase<FieldEntry<optional<int> >, optional<int> > Parent;
912  // override set
913  virtual void Set(void *head, const std::string &value) const {
914  if (is_enum_ && value != "None") {
915  std::map<std::string, int>::const_iterator it = enum_map_.find(value);
916  std::ostringstream os;
917  if (it == enum_map_.end()) {
918  os << "Invalid Input: \'" << value;
919  os << "\', valid values are: ";
920  PrintEnums(os);
921  throw dmlc::ParamError(os.str());
922  } else {
923  os << it->second;
924  Parent::Set(head, os.str());
925  }
926  } else {
927  Parent::Set(head, value);
928  }
929  }
930  virtual ParamFieldInfo GetFieldInfo() const {
931  if (is_enum_) {
932  ParamFieldInfo info;
933  std::ostringstream os;
934  info.name = key_;
935  info.type = type_;
936  PrintEnums(os);
937  if (has_default_) {
938  os << ',' << "optional, default=";
939  PrintDefaultValueString(os);
940  } else {
941  os << ", required";
942  }
943  info.type_info_str = os.str();
944  info.description = description_;
945  return info;
946  } else {
947  return Parent::GetFieldInfo();
948  }
949  }
950  // add enum
951  inline FieldEntry<optional<int> > &add_enum(const std::string &key, int value) {
952  CHECK_NE(key, "None") << "None is reserved for empty optional<int>";
953  if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
954  enum_back_map_.count(value) != 0) {
955  std::ostringstream os;
956  os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n";
957  os << "Enums: ";
958  for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
959  it != enum_map_.end(); ++it) {
960  os << "(" << it->first << ": " << it->second << "), ";
961  }
962  throw dmlc::ParamError(os.str());
963  }
964  enum_map_[key] = value;
965  enum_back_map_[value] = key;
966  is_enum_ = true;
967  return this->self();
968  }
969 
970  protected:
971  // enum flag
972  bool is_enum_;
973  // enum map
974  std::map<std::string, int> enum_map_;
975  // enum map
976  std::map<int, std::string> enum_back_map_;
977  // override print behavior
978  virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
979  os << '\'';
980  PrintValue(os, default_value_);
981  os << '\'';
982  }
983  // override print default
984  virtual void PrintValue(std::ostream &os, optional<int> value) const { // NOLINT(*)
985  if (is_enum_) {
986  if (!value) {
987  os << "None";
988  } else {
989  CHECK_NE(enum_back_map_.count(value.value()), 0U)
990  << "Value not found in enum declared";
991  os << enum_back_map_.at(value.value());
992  }
993  } else {
994  os << value;
995  }
996  }
997 
998 
999  private:
1000  inline void PrintEnums(std::ostream &os) const { // NOLINT(*)
1001  os << "{None";
1002  for (std::map<std::string, int>::const_iterator
1003  it = enum_map_.begin(); it != enum_map_.end(); ++it) {
1004  os << ", ";
1005  os << "\'" << it->first << '\'';
1006  }
1007  os << '}';
1008  }
1009 };
1010 
1011 // specialize define for string
1012 template<>
1013 class FieldEntry<std::string>
1014  : public FieldEntryBase<FieldEntry<std::string>, std::string> {
1015  public:
1016  // parent class
1017  typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent;
1018  // override set
1019  virtual void Set(void *head, const std::string &value) const {
1020  this->Get(head) = value;
1021  }
1022  // override print default
1023  virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
1024  os << '\'' << default_value_ << '\'';
1025  }
1026 };
1027 
1028 // specialize define for bool
1029 template<>
1030 class FieldEntry<bool>
1031  : public FieldEntryBase<FieldEntry<bool>, bool> {
1032  public:
1033  // parent class
1034  typedef FieldEntryBase<FieldEntry<bool>, bool> Parent;
1035  // override set
1036  virtual void Set(void *head, const std::string &value) const {
1037  std::string lower_case; lower_case.resize(value.length());
1038  std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower);
1039  bool &ref = this->Get(head);
1040  if (lower_case == "true") {
1041  ref = true;
1042  } else if (lower_case == "false") {
1043  ref = false;
1044  } else if (lower_case == "1") {
1045  ref = true;
1046  } else if (lower_case == "0") {
1047  ref = false;
1048  } else {
1049  std::ostringstream os;
1050  os << "Invalid Parameter format for " << key_
1051  << " expect " << type_ << " but value=\'" << value<< '\'';
1052  throw dmlc::ParamError(os.str());
1053  }
1054  }
1055 
1056  protected:
1057  // print default string
1058  virtual void PrintValue(std::ostream &os, bool value) const { // NOLINT(*)
1059  os << static_cast<int>(value);
1060  }
1061 };
1062 
1063 
1064 // specialize define for float. Uses stof for platform independent handling of
1065 // INF, -INF, NAN, etc.
1066 #if DMLC_USE_CXX11
1067 template <>
1068 class FieldEntry<float> : public FieldEntryNumeric<FieldEntry<float>, float> {
1069  public:
1070  // parent
1071  typedef FieldEntryNumeric<FieldEntry<float>, float> Parent;
1072  // override set
1073  virtual void Set(void *head, const std::string &value) const {
1074  size_t pos = 0; // number of characters processed by dmlc::stof()
1075  try {
1076  this->Get(head) = dmlc::stof(value, &pos);
1077  } catch (const std::invalid_argument &) {
1078  std::ostringstream os;
1079  os << "Invalid Parameter format for " << key_ << " expect " << type_
1080  << " but value=\'" << value << '\'';
1081  throw dmlc::ParamError(os.str());
1082  } catch (const std::out_of_range&) {
1083  std::ostringstream os;
1084  os << "Out of range value for " << key_ << ", value=\'" << value << '\'';
1085  throw dmlc::ParamError(os.str());
1086  }
1087  CHECK_LE(pos, value.length()); // just in case
1088  if (pos < value.length()) {
1089  std::ostringstream os;
1090  os << "Some trailing characters could not be parsed: \'"
1091  << value.substr(pos) << "\'";
1092  throw dmlc::ParamError(os.str());
1093  }
1094  }
1095 
1096  protected:
1097  // print the value
1098  virtual void PrintValue(std::ostream &os, float value) const { // NOLINT(*)
1099  os << std::setprecision(std::numeric_limits<float>::max_digits10) << value;
1100  }
1101 };
1102 
1103 // specialize define for double. Uses stod for platform independent handling of
1104 // INF, -INF, NAN, etc.
1105 template <>
1106 class FieldEntry<double>
1107  : public FieldEntryNumeric<FieldEntry<double>, double> {
1108  public:
1109  // parent
1110  typedef FieldEntryNumeric<FieldEntry<double>, double> Parent;
1111  // override set
1112  virtual void Set(void *head, const std::string &value) const {
1113  size_t pos = 0; // number of characters processed by dmlc::stod()
1114  try {
1115  this->Get(head) = dmlc::stod(value, &pos);
1116  } catch (const std::invalid_argument &) {
1117  std::ostringstream os;
1118  os << "Invalid Parameter format for " << key_ << " expect " << type_
1119  << " but value=\'" << value << '\'';
1120  throw dmlc::ParamError(os.str());
1121  } catch (const std::out_of_range&) {
1122  std::ostringstream os;
1123  os << "Out of range value for " << key_ << ", value=\'" << value << '\'';
1124  throw dmlc::ParamError(os.str());
1125  }
1126  CHECK_LE(pos, value.length()); // just in case
1127  if (pos < value.length()) {
1128  std::ostringstream os;
1129  os << "Some trailing characters could not be parsed: \'"
1130  << value.substr(pos) << "\'";
1131  throw dmlc::ParamError(os.str());
1132  }
1133  }
1134 
1135  protected:
1136  // print the value
1137  virtual void PrintValue(std::ostream &os, double value) const { // NOLINT(*)
1138  os << std::setprecision(std::numeric_limits<double>::max_digits10) << value;
1139  }
1140 };
1141 #endif // DMLC_USE_CXX11
1142 
1143 } // namespace parameter
1145 
1146 // implement GetEnv
1147 template<typename ValueType>
1148 inline ValueType GetEnv(const char *key,
1149  ValueType default_value) {
1150  const char *val = getenv(key);
1151  // On some implementations, if the var is set to a blank string (i.e. "FOO="), then
1152  // a blank string will be returned instead of NULL. In order to be consistent, if
1153  // the environment var is a blank string, then also behave as if a null was returned.
1154  if (val == nullptr || !*val) {
1155  return default_value;
1156  }
1157  ValueType ret;
1158  parameter::FieldEntry<ValueType> e;
1159  e.Init(key, &ret, ret);
1160  e.Set(&ret, val);
1161  return ret;
1162 }
1163 
1164 // implement SetEnv
1165 template<typename ValueType>
1166 inline void SetEnv(const char *key,
1167  ValueType value) {
1168  parameter::FieldEntry<ValueType> e;
1169  e.Init(key, &value, value);
1170 #ifdef _WIN32
1171  _putenv_s(key, e.GetStringValue(&value).c_str());
1172 #else
1173  setenv(key, e.GetStringValue(&value).c_str(), 1);
1174 #endif // _WIN32
1175 }
1176 } // namespace dmlc
1177 #endif // DMLC_PARAMETER_H_
Container to hold optional data.
Definition: optional.h:241
double stod(const std::string &value, size_t *pos=nullptr)
A faster implementation of stod(). See the documentation of std::stod() for more information. This function tests for overflow and invalid arguments. TODO: the current version does not support hexadecimal numbers. TODO: the current version does not handle long decimals: you may only have up to 19 digits after the decimal point, and you cannot have too many digits before the decimal point either.
Definition: strtonum.h:497
A faster implementation of strtof and strtod.
Lightweight JSON reader/writer that reads and saves C++ data structures. This includes STL composites and...
Lightweight JSON reader that reads any STL compositions and structs. The user needs to know the schema of...
Definition: json.h:44
bool isspace(char c)
Inline implementation of isspace(). Tests whether the given character is a whitespace letter...
Definition: strtonum.h:26
namespace for dmlc
Definition: array_view.h:12
void Write(const ValueType &value)
Write value to json.
float stof(const std::string &value, size_t *pos=nullptr)
A faster implementation of stof(). See the documentation of std::stof() for more information. This function tests for overflow and invalid arguments. TODO: the current version does not support hexadecimal numbers. TODO: the current version does not handle long decimals: you may only have up to 19 digits after the decimal point, and you cannot have too many digits before the decimal point either.
Definition: strtonum.h:467
void Read(ValueType *out_value)
Read next ValueType.
type traits information header
Lightweight json to write any STL compositions.
Definition: json.h:189