44 template <
typename ValueType>
47 class OpRegistryEntry;
48 using dmlc::ParamFieldInfo;
51 static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
123 uint32_t num_inputs = 1;
131 uint32_t num_outputs = 1;
137 uint32_t support_level = 10;
143 std::function<uint32_t(
const NodeAttrs& attrs)> get_num_outputs =
nullptr;
149 std::function<uint32_t(
const NodeAttrs& attrs)> get_num_inputs =
nullptr;
182 std::function<void(
NodeAttrs* attrs)> attr_parser =
nullptr;
190 inline Op& describe(
const std::string& descr);
198 inline Op& add_argument(
const std::string& name,
const std::string& type,
199 const std::string& description);
205 inline Op& add_arguments(
const std::vector<ParamFieldInfo>& args);
211 inline Op& set_num_inputs(uint32_t n);
217 inline Op& set_support_level(uint32_t level);
223 inline Op& set_num_inputs(std::function<uint32_t(
const NodeAttrs& attr)> fn);
229 inline Op& set_num_outputs(uint32_t n);
235 inline Op& set_num_outputs(std::function<uint32_t(
const NodeAttrs& attr)> fn);
241 inline Op& set_attr_parser(std::function<
void(
NodeAttrs* attrs)> fn);
255 template <
typename ValueType>
256 inline Op& set_attr(
const std::string& attr_name,
257 const ValueType& value,
int plevel = 10);
264 Op& add_alias(
const std::string& alias);
272 Op& include(
const std::string& group_name);
279 static const Op* Get(
const std::string& op_name);
287 template <
typename ValueType>
291 template <
typename ValueType>
301 static const any* GetAttrMap(
const std::string& key);
303 static void UpdateAttrMap(
const std::string& key, std::function<
void(any*)> updater);
309 static void AddGroupTrigger(
const std::string& group_name, std::function<
void(
Op*)> trigger);
317 template <
typename ValueType>
332 inline const ValueType&
get(
const Op* op,
const ValueType& def_value)
const;
338 inline int count(
const Op* op)
const;
350 std::string attr_name_;
352 std::vector<std::pair<ValueType, int>> data_;
377 template <
typename ValueType>
379 const ValueType& value,
int plevel = 1);
383 #define NNVM_REGISTER_VAR_DEF(OpName) \
384 static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op& __make_##NnvmOp##_##OpName
386 #define NNVM_REGISTER_GVAR_DEF(TagName) \
387 static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_##NnvmOpGroup##_##TagName
404 #define NNVM_REGISTER_OP(OpName) \
405 DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
406 ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
429 #define NNVM_REGISTER_OP_GROUP(GroupName) \
430 DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = ::nnvm::OpGroup { #GroupName }
434 template <
typename ValueType>
436 const any* ref = GetAttrMap(key);
437 if (ref ==
nullptr) {
439 UpdateAttrMap(key, [key](any* pmap) {
444 *pmap = std::move(pm);
447 ref = GetAttrMap(key);
449 return nnvm::get<OpMap<ValueType>>(*ref);
452 template <
typename ValueType>
454 const std::string& attr_name,
const ValueType& value,
int plevel) {
455 CHECK_GT(plevel, 0) <<
"plevel in set_attr must be greater than 0";
457 UpdateAttrMap(attr_name, [
this, attr_name, value, plevel](any* pmap) {
461 pm.attr_name_ = attr_name;
462 *pmap = std::move(pm);
465 <<
"Attribute " << attr_name <<
" of operator " << this->name
466 <<
" is registered as inconsistent types"
467 <<
" previously " << pmap->type().name() <<
" current " <<
typeid(
OpMap<ValueType>).name();
468 std::vector<std::pair<ValueType, int>>& vec = nnvm::get<
OpMap<ValueType>>(*pmap).data_;
470 if (vec.size() <= index_) {
471 vec.resize(index_ + 1, std::make_pair(ValueType(), 0));
473 std::pair<ValueType, int>& p = vec[index_];
474 CHECK(p.second != plevel) <<
"Attribute " << attr_name <<
" of operator " << this->name
475 <<
" is already registered with same plevel=" << plevel;
476 if (p.second < plevel) {
477 vec[index_] = std::make_pair(value, plevel);
484 this->description = descr;
489 const std::string& description) {
495 this->arguments.insert(
arguments.end(), args.begin(), args.end());
500 this->num_inputs = n;
505 this->support_level = n;
510 this->get_num_inputs = fn;
515 this->num_outputs = n;
520 this->get_num_outputs = fn;
525 this->attr_parser = fn;
530 template <
typename ValueType>
539 template <
typename ValueType>
544 const uint32_t idx = op->index_;
545 return idx < data_.size() ? (data_[idx].second != 0) :
false;
548 template <
typename ValueType>
550 CHECK(op !=
nullptr);
551 const uint32_t idx = op->index_;
552 CHECK(idx < data_.size() && data_[idx].second)
553 <<
"Attribute " << attr_name_ <<
" has not been registered for Operator " << op->
name;
554 return data_[idx].first;
557 template <
typename ValueType>
559 if (op ==
nullptr)
return def_value;
560 const uint32_t idx = op->index_;
561 if (idx < data_.size() && data_[idx].second) {
562 return data_[idx].first;
568 template <
typename ValueType>
571 auto trigger = [attr_name, value, plevel](
Op* op) {
572 op->
set_attr<ValueType>(attr_name, value, plevel);