mxnet
op.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef NNVM_OP_H_
25 #define NNVM_OP_H_
26 
27 #include <dmlc/parameter.h>
28 
29 #include <functional>
30 #include <limits>
31 #include <string>
32 #include <typeinfo>
33 #include <utility>
34 #include <vector>
35 
36 #include "base.h"
37 #include "c_api.h"
38 
39 namespace nnvm {
40 
41 // forward declarations
42 class Node;
43 struct NodeAttrs;
44 template <typename ValueType>
45 class OpMap;
46 class OpGroup;
47 class OpRegistryEntry;
48 using dmlc::ParamFieldInfo;
49 
/*!
 * \brief Sentinel equal to the largest uint32_t value.
 *  NOTE(review): the name suggests it marks a variable number of arguments
 *  (e.g. as a num_inputs value) — confirm against callers.
 */
static constexpr uint32_t kVarg = std::numeric_limits<uint32_t>::max();
52 
105 class NNVM_DLL Op {
106  public:
108  std::string name;
113  std::string description;
114  /* \brief description of inputs and keyword arguments*/
115  std::vector<ParamFieldInfo> arguments;
123  uint32_t num_inputs = 1;
131  uint32_t num_outputs = 1;
137  uint32_t support_level = 10;
143  std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;
149  std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;
182  std::function<void(NodeAttrs* attrs)> attr_parser = nullptr;
183  // function fields.
190  inline Op& describe(const std::string& descr); // NOLINT(*)
198  inline Op& add_argument(const std::string& name, const std::string& type,
199  const std::string& description);
205  inline Op& add_arguments(const std::vector<ParamFieldInfo>& args);
211  inline Op& set_num_inputs(uint32_t n); // NOLINT(*)
217  inline Op& set_support_level(uint32_t level); // NOLINT(*)
223  inline Op& set_num_inputs(std::function<uint32_t(const NodeAttrs& attr)> fn); // NOLINT(*)
229  inline Op& set_num_outputs(uint32_t n); // NOLINT(*)
235  inline Op& set_num_outputs(std::function<uint32_t(const NodeAttrs& attr)> fn); // NOLINT(*)
241  inline Op& set_attr_parser(std::function<void(NodeAttrs* attrs)> fn); // NOLINT(*)
255  template <typename ValueType>
256  inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
257  const ValueType& value, int plevel = 10);
264  Op& add_alias(const std::string& alias); // NOLINT(*)
272  Op& include(const std::string& group_name);
279  static const Op* Get(const std::string& op_name);
287  template <typename ValueType>
288  static const OpMap<ValueType>& GetAttr(const std::string& attr_name);
289 
290  private:
291  template <typename ValueType>
292  friend class OpMap;
293  friend class OpGroup;
294  friend class dmlc::Registry<Op>;
295  // Program internal unique index of operator.
296  // Used to help index the program.
297  uint32_t index_{0};
298  // internal constructor
299  Op();
300  // get const reference to certain attribute
301  static const any* GetAttrMap(const std::string& key);
302  // update the attribute OpMap
303  static void UpdateAttrMap(const std::string& key, std::function<void(any*)> updater);
304  // add a trigger based on tag matching on certain tag attribute
305  // This will apply trigger on all the op such that
306  // include the corresponding group.
307  // The trigger will also be applied to all future registrations
308  // that calls include
309  static void AddGroupTrigger(const std::string& group_name, std::function<void(Op*)> trigger);
310 };
311 
317 template <typename ValueType>
318 class OpMap {
319  public:
325  inline const ValueType& operator[](const Op* op) const;
332  inline const ValueType& get(const Op* op, const ValueType& def_value) const;
338  inline int count(const Op* op) const;
339 
345  inline bool contains(const Op* op) const;
346 
347  private:
348  friend class Op;
349  // internal attribute name
350  std::string attr_name_;
351  // internal data
352  std::vector<std::pair<ValueType, int>> data_;
353  OpMap() = default;
354 };
355 
/*!
 * \brief Auxiliary structure used to set attributes on a group of operators.
 */
class OpGroup {
 public:
  /*! \brief Name of the group; operators join it through Op::include. */
  std::string group_name;
  /*!
   * \brief Register an attribute on every operator of this group, including
   *  operators that include the group after this call.
   * \param attr_name name of the attribute.
   * \param value value to store.
   * \param plevel priority level forwarded to Op::set_attr; the default of 1
   *  is lower than Op::set_attr's default of 10.
   * \tparam ValueType type of the attribute value.
   * \return reference to self.
   */
  template <typename ValueType>
  inline OpGroup& set_attr(const std::string& attr_name,  // NOLINT(*)
                           const ValueType& value, int plevel = 1);
};
381 
// Internal helper macros: each expands to the declaration of a static
// variable that holds the result of an operator / operator-group registration.
#define NNVM_REGISTER_VAR_DEF(OpName) \
  static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op& __make_NnvmOp_##OpName

#define NNVM_REGISTER_GVAR_DEF(TagName) \
  static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_NnvmOpGroup_##TagName

/*!
 * \brief Register the operator OpName with dmlc::Registry<nnvm::Op> (through
 *  its __REGISTER_OR_GET__ entry point) and bind a static reference to the
 *  entry, so registration setters can be chained:
 *
 * \code
 *  NNVM_REGISTER_OP(add)
 *  .describe("add two inputs together")
 *  .set_num_inputs(2);
 * \endcode
 */
#define NNVM_REGISTER_OP(OpName)                                  \
  DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) =   \
      ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)

/*!
 * \brief Declare an operator group named GroupName and return it, so group
 *  attributes can be chained:
 *
 * \code
 *  NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
 *  .set_attr<std::string>("attr_name", "value");
 * \endcode
 */
#define NNVM_REGISTER_OP_GROUP(GroupName)                              \
  DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) =    \
      ::nnvm::OpGroup { #GroupName }
431 
432 // implementations of template functions after this.
433 // member function of Op
434 template <typename ValueType>
435 inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
436  const any* ref = GetAttrMap(key);
437  if (ref == nullptr) {
438  // update the attribute map of the key by creating new empty OpMap
439  UpdateAttrMap(key, [key](any* pmap) {
440  // use callback so it is in lockscope
441  if (pmap->empty()) {
442  OpMap<ValueType> pm;
443  pm.attr_name_ = key;
444  *pmap = std::move(pm);
445  }
446  });
447  ref = GetAttrMap(key);
448  }
449  return nnvm::get<OpMap<ValueType>>(*ref);
450 }
451 
452 template <typename ValueType>
453 inline Op& Op::set_attr( // NOLINT(*)
454  const std::string& attr_name, const ValueType& value, int plevel) {
455  CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
456  // update the attribute map of the key by creating new empty if needed.
457  UpdateAttrMap(attr_name, [this, attr_name, value, plevel](any* pmap) {
458  // the callback is in lockscope so is threadsafe.
459  if (pmap->empty()) {
460  OpMap<ValueType> pm;
461  pm.attr_name_ = attr_name;
462  *pmap = std::move(pm);
463  }
464  CHECK(pmap->type() == typeid(OpMap<ValueType>))
465  << "Attribute " << attr_name << " of operator " << this->name
466  << " is registered as inconsistent types"
467  << " previously " << pmap->type().name() << " current " << typeid(OpMap<ValueType>).name();
468  std::vector<std::pair<ValueType, int>>& vec = nnvm::get<OpMap<ValueType>>(*pmap).data_;
469  // resize the value type.
470  if (vec.size() <= index_) {
471  vec.resize(index_ + 1, std::make_pair(ValueType(), 0));
472  }
473  std::pair<ValueType, int>& p = vec[index_];
474  CHECK(p.second != plevel) << "Attribute " << attr_name << " of operator " << this->name
475  << " is already registered with same plevel=" << plevel;
476  if (p.second < plevel) {
477  vec[index_] = std::make_pair(value, plevel);
478  }
479  });
480  return *this;
481 }
482 
483 inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
484  this->description = descr;
485  return *this;
486 }
487 
488 inline Op& Op::add_argument(const std::string& name, const std::string& type,
489  const std::string& description) {
490  arguments.push_back({name, type, type, description});
491  return *this;
492 }
493 
494 inline Op& Op::add_arguments(const std::vector<ParamFieldInfo>& args) {
495  this->arguments.insert(arguments.end(), args.begin(), args.end());
496  return *this;
497 }
498 
499 inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
500  this->num_inputs = n;
501  return *this;
502 }
503 
504 inline Op& Op::set_support_level(uint32_t n) { // NOLINT(*)
505  this->support_level = n;
506  return *this;
507 }
508 
509 inline Op& Op::set_num_inputs(std::function<uint32_t(const NodeAttrs& attr)> fn) { // NOLINT(*)
510  this->get_num_inputs = fn;
511  return *this;
512 }
513 
514 inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
515  this->num_outputs = n;
516  return *this;
517 }
518 
519 inline Op& Op::set_num_outputs(std::function<uint32_t(const NodeAttrs& attr)> fn) { // NOLINT(*)
520  this->get_num_outputs = fn;
521  return *this;
522 }
523 
524 inline Op& Op::set_attr_parser(std::function<void(NodeAttrs* attrs)> fn) { // NOLINT(*)
525  this->attr_parser = fn;
526  return *this;
527 }
528 
529 // member functions of OpMap
530 template <typename ValueType>
531 inline int OpMap<ValueType>::count(const Op* op) const {
532  if (contains(op)) {
533  return 1;
534  } else {
535  return 0;
536  }
537 }
538 
539 template <typename ValueType>
540 inline bool OpMap<ValueType>::contains(const Op* op) const {
541  if (op == nullptr) {
542  return false;
543  }
544  const uint32_t idx = op->index_;
545  return idx < data_.size() ? (data_[idx].second != 0) : false;
546 }
547 
548 template <typename ValueType>
549 inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
550  CHECK(op != nullptr);
551  const uint32_t idx = op->index_;
552  CHECK(idx < data_.size() && data_[idx].second)
553  << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name;
554  return data_[idx].first;
555 }
556 
557 template <typename ValueType>
558 inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
559  if (op == nullptr) return def_value;
560  const uint32_t idx = op->index_;
561  if (idx < data_.size() && data_[idx].second) {
562  return data_[idx].first;
563  } else {
564  return def_value;
565  }
566 }
567 
568 template <typename ValueType>
569 inline OpGroup& OpGroup::set_attr(const std::string& attr_name, const ValueType& value,
570  int plevel) {
571  auto trigger = [attr_name, value, plevel](Op* op) {
572  op->set_attr<ValueType>(attr_name, value, plevel);
573  };
574  Op::AddGroupTrigger(group_name, trigger);
575  return *this;
576 }
577 
578 } // namespace nnvm
579 
580 #endif // NNVM_OP_H_
nnvm::OpMap::get
const ValueType & get(const Op *op, const ValueType &def_value) const
get the corresponding value element at op with default value.
Definition: op.h:558
nnvm::OpMap::count
int count(const Op *op) const
Check if the map has op as key.
Definition: op.h:531
nnvm::Op::name
std::string name
name of the operator
Definition: op.h:108
nnvm::Op::add_arguments
Op & add_arguments(const std::vector< ParamFieldInfo > &args)
Append a list of arguments to the end.
Definition: op.h:494
parameter.h
Provide lightweight util to do parameter setup and checking.
nnvm::OpMap::Op
friend class Op
operator structure from NNVM
Definition: op.h:348
base.h
Configuration of nnvm as well as basic data structure.
nnvm::OpMap
A map data structure that takes Op* as key and returns ValueType.
Definition: op.h:45
nnvm::Op::arguments
std::vector< ParamFieldInfo > arguments
Definition: op.h:115
nnvm::NodeAttrs
The attributes of the current operation node. Usually are additional parameters like axis,...
Definition: node.h:107
nnvm::Op::set_attr
Op & set_attr(const std::string &attr_name, const ValueType &value, int plevel=10)
Register additional attributes to operator.
Definition: op.h:453
dmlc::Registry
Registry class. Registry can be used to register global singletons. The most commonly use case are fa...
Definition: registry.h:27
nnvm::Op::add_argument
Op & add_argument(const std::string &name, const std::string &type, const std::string &description)
Add argument information to the function.
Definition: op.h:488
nnvm::Op::description
std::string description
detailed description of the operator. This can be used to generate docstrings automatically for the ope...
Definition: op.h:113
c_api.h
C API of NNVM symbolic construction and pass. Enables construction and transformation of Graph in any...
nnvm::OpGroup::set_attr
OpGroup & set_attr(const std::string &attr_name, const ValueType &value, int plevel=1)
Register additional attributes to operator group.
Definition: op.h:569
nnvm::OpGroup
auxiliary data structure used to set attributes to a group of operators
Definition: op.h:360
nnvm::Op::GetAttr
static const OpMap< ValueType > & GetAttr(const std::string &attr_name)
Get additional registered attribute about operators. If nothing has been registered,...
Definition: op.h:435
nnvm::Op::set_attr_parser
Op & set_attr_parser(std::function< void(NodeAttrs *attrs)> fn)
Set the attr_parser function.
Definition: op.h:524
nnvm::Op::set_num_outputs
Op & set_num_outputs(uint32_t n)
Set the num_outputs.
Definition: op.h:514
nnvm::Op::describe
Op & describe(const std::string &descr)
Setter function used during registration. Sets the description of the operator.
Definition: op.h:483
nnvm::OpMap::contains
bool contains(const Op *op) const
Check if the map has op as key.
Definition: op.h:540
nnvm::Op::set_num_inputs
Op & set_num_inputs(uint32_t n)
Set the num_inputs.
Definition: op.h:499
NNVM_DLL
#define NNVM_DLL
NNVM_DLL prefix for windows.
Definition: c_api.h:37
nnvm::OpMap::operator[]
const ValueType & operator[](const Op *op) const
get the corresponding value element at op
Definition: op.h:549
nnvm::OpGroup::group_name
std::string group_name
the tag key to be matched
Definition: op.h:363
mxnet::Op
nnvm::Op Op
operator structure from NNVM
Definition: base.h:87
nnvm::Op
Operator structure.
Definition: op.h:105
nnvm::Op::set_support_level
Op & set_support_level(uint32_t level)
Set the support level of op.
Definition: op.h:504
nnvm
Definition: base.h:35