mxnet
op.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef NNVM_OP_H_
26 #define NNVM_OP_H_
27 
28 #include <dmlc/parameter.h>
29 #include <string>
30 #include <vector>
31 #include <utility>
32 #include <typeinfo>
33 #include <limits>
34 #include <functional>
35 #include "base.h"
36 #include "c_api.h"
37 
38 namespace nnvm {
39 
40 // forward declarations
41 class Node;
42 struct NodeAttrs;
43 template<typename ValueType>
44 class OpMap;
45 class OpGroup;
46 class OpRegistryEntry;
47 using dmlc::ParamFieldInfo;
48 
/*!
 * \brief Sentinel constant equal to the largest representable uint32_t.
 *  NOTE(review): presumably assigned to Op::num_inputs / Op::num_outputs to
 *  mean "variable number of arguments" (hence "varg") — confirm with callers.
 */
static constexpr uint32_t kVarg = std::numeric_limits<uint32_t>::max();
51 
/*!
 * \brief Operator structure.
 *
 *  One instance exists per registered operator. Instances are created through
 *  dmlc::Registry<Op> (see NNVM_REGISTER_OP); the constructor is private.
 *  Registration code configures an Op through the chainable setters below,
 *  each of which returns *this.
 */
class NNVM_DLL Op {
 public:
  /*! \brief name of the operator */
  std::string name;
  /*!
   * \brief detailed description of the operator.
   *  This can be used to generate docstrings automatically for the operator.
   */
  std::string description;
  /*! \brief description of inputs and keyword arguments */
  std::vector<ParamFieldInfo> arguments;
  /*!
   * \brief number of inputs to the operator (defaults to 1).
   *  NOTE(review): kVarg is presumably the value used for a variable count.
   */
  uint32_t num_inputs = 1;
  /*! \brief number of outputs of the operator (defaults to 1). */
  uint32_t num_outputs = 1;
  /*!
   * \brief support level of the operator (defaults to 10).
   *  NOTE(review): the meaning of individual levels is not defined here.
   */
  uint32_t support_level = 10;
  /*!
   * \brief optional callback computing the number of outputs from the node
   *  attributes; nullptr when unset.
   *  NOTE(review): presumably takes precedence over num_outputs when set.
   */
  std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;
  /*!
   * \brief optional callback computing the number of inputs from the node
   *  attributes; nullptr when unset.
   *  NOTE(review): presumably takes precedence over num_inputs when set.
   */
  std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;
  /*!
   * \brief optional hook that receives a node's attributes so it can parse them;
   *  nullptr when unset.
   */
  std::function<void(NodeAttrs* attrs)> attr_parser = nullptr;
  // Setter functions used during registration; each returns *this for chaining.
  /*! \brief Set the description of the operator. */
  inline Op& describe(const std::string& descr); // NOLINT(*)
  /*! \brief Add argument information (name, type, description) to the operator. */
  inline Op& add_argument(const std::string &name,
                          const std::string &type,
                          const std::string &description);
  /*! \brief Append a list of arguments to the end of `arguments`. */
  inline Op& add_arguments(const std::vector<ParamFieldInfo> &args);
  /*! \brief Set num_inputs. */
  inline Op& set_num_inputs(uint32_t n); // NOLINT(*)
  /*! \brief Set support_level. */
  inline Op& set_support_level(uint32_t level); // NOLINT(*)
  /*! \brief Set the get_num_inputs callback. */
  inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
  /*! \brief Set num_outputs. */
  inline Op& set_num_outputs(uint32_t n); // NOLINT(*)
  /*! \brief Set the get_num_outputs callback. */
  inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
  /*! \brief Set the attr_parser function. */
  inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
  /*!
   * \brief Register an additional attribute for this operator.
   * \param attr_name name of the attribute.
   * \param value value to store.
   * \param plevel priority level; must be > 0. A higher plevel overrides a
   *  previously registered lower one, and registering the same plevel twice
   *  is an error.
   * \tparam ValueType type of the attribute; must be the same type for every
   *  operator that registers attr_name.
   */
  template<typename ValueType>
  inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
                      const ValueType& value,
                      int plevel = 10);
  /*! \brief Add an alias name for this operator (implemented out of line). */
  Op& add_alias(const std::string& alias); // NOLINT(*)
  /*!
   * \brief Include this operator in the group group_name, so that attributes
   *  set on that group through OpGroup::set_attr are applied to it.
   */
  Op& include(const std::string& group_name);
  /*!
   * \brief Look up a registered operator by name (implemented out of line).
   *  NOTE(review): behavior for unknown names is defined in the .cc file.
   */
  static const Op* Get(const std::string& op_name);
  /*!
   * \brief Get the attribute map registered under attr_name.
   *  If nothing has been registered, an empty OpMap is created and returned.
   */
  template<typename ValueType>
  static const OpMap<ValueType>& GetAttr(const std::string& attr_name);

 private:
  template<typename ValueType>
  friend class OpMap;
  friend class OpGroup;
  friend class dmlc::Registry<Op>;
  // Program internal unique index of operator.
  // Used to help index the program (OpMap::data_ is indexed by it).
  uint32_t index_{0};
  // internal constructor
  Op();
  // get const reference to certain attribute
  static const any* GetAttrMap(const std::string& key);
  // update the attribute OpMap
  static void UpdateAttrMap(const std::string& key,
                            std::function<void(any*)> updater);
  // add a trigger based on tag matching on certain tag attribute
  // This will apply trigger on all the op such that
  // include the corresponding group.
  // The trigger will also be applied to all future registrations
  // that calls include
  static void AddGroupTrigger(const std::string& group_name,
                              std::function<void(Op*)> trigger);
};
314 
/*!
 * \brief A map data structure that takes Op* as key and returns ValueType.
 *  Instances are obtained through Op::GetAttr and filled through Op::set_attr.
 * \tparam ValueType the attribute value type.
 */
template<typename ValueType>
class OpMap {
 public:
  /*!
   * \brief get the value registered for op.
   *  Fails a CHECK when op is null or has no value registered.
   */
  inline const ValueType& operator[](const Op* op) const;
  /*!
   * \brief get the value registered for op, or def_value when op is null or
   *  has nothing registered.
   *  Caution: def_value itself is returned by reference, so callers must not
   *  hold on to the result when passing a temporary default.
   */
  inline const ValueType& get(const Op* op, const ValueType& def_value) const;
  /*! \brief Return 1 if the map has op as key, 0 otherwise. */
  inline int count(const Op* op) const;

  /*! \brief Check if the map has op as key (false for a null op). */
  inline bool contains(const Op* op) const;

 private:
  friend class Op;
  // internal attribute name
  std::string attr_name_;
  // internal data, indexed by Op::index_; the second element is the plevel
  // the value was registered with, and 0 means "nothing registered".
  std::vector<std::pair<ValueType, int> > data_;
  OpMap() = default;
};
358 
/*!
 * \brief auxiliary data structure used to set attributes to a group of operators.
 *  Created by NNVM_REGISTER_OP_GROUP via brace-initialization of group_name.
 */
class OpGroup {
 public:
  /*! \brief the tag key to be matched */
  std::string group_name;
  /*!
   * \brief Register additional attributes to every operator in the group,
   *  including operators that include the group later.
   * \param attr_name name of the attribute.
   * \param value value to store for each member operator.
   * \param plevel priority level (default 1). Because Op::set_attr defaults to
   *  plevel 10 and higher levels win, a per-operator registration with default
   *  levels overrides a group-level one.
   * \return *this, for chaining.
   */
  template<typename ValueType>
  inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*)
                           const ValueType& value,
                           int plevel = 1);
};
385 
// Internal macros that declare the static variable a registration binds to.
// The variable name is built from the op/group name; the public macros below
// then append __COUNTER__ (via DMLC_STR_CONCAT) so that repeated registrations
// of the same name presumably still get distinct variable names.
#define NNVM_REGISTER_VAR_DEF(OpName)                                   \
  static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName

#define NNVM_REGISTER_GVAR_DEF(TagName)                                     \
  static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName

// Register (or fetch the already registered) operator OpName and bind a static
// Op& to it, so registration setters can be chained, e.g.
//   NNVM_REGISTER_OP(add).describe("...").set_num_inputs(2);
#define NNVM_REGISTER_OP(OpName)                                     \
  DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) =         \
      ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)

// Create a static OpGroup whose group_name is the stringized GroupName, so that
// group attributes can be chained, e.g.
//   NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr).set_attr<bool>("elemwise", true);
#define NNVM_REGISTER_OP_GROUP(GroupName)                            \
  DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) =     \
      ::nnvm::OpGroup {#GroupName}
436 
437 // implementations of template functions after this.
438 // member function of Op
439 template<typename ValueType>
440 inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
441  const any* ref = GetAttrMap(key);
442  if (ref == nullptr) {
443  // update the attribute map of the key by creating new empty OpMap
444  UpdateAttrMap(key, [key](any* pmap) {
445  // use callback so it is in lockscope
446  if (pmap->empty()) {
447  OpMap<ValueType> pm;
448  pm.attr_name_ = key;
449  *pmap = std::move(pm);
450  }
451  });
452  ref = GetAttrMap(key);
453  }
454  return nnvm::get<OpMap<ValueType> >(*ref);
455 }
456 
457 template<typename ValueType>
458 inline Op& Op::set_attr( // NOLINT(*)
459  const std::string& attr_name,
460  const ValueType& value,
461  int plevel) {
462  CHECK_GT(plevel, 0)
463  << "plevel in set_attr must be greater than 0";
464  // update the attribute map of the key by creating new empty if needed.
465  UpdateAttrMap(attr_name,
466  [this, attr_name, value, plevel](any* pmap) {
467  // the callback is in lockscope so is threadsafe.
468  if (pmap->empty()) {
469  OpMap<ValueType> pm;
470  pm.attr_name_ = attr_name;
471  *pmap = std::move(pm);
472  }
473  CHECK(pmap->type() == typeid(OpMap<ValueType>))
474  << "Attribute " << attr_name
475  << " of operator " << this->name
476  << " is registered as inconsistent types"
477  << " previously " << pmap->type().name()
478  << " current " << typeid(OpMap<ValueType>).name();
479  std::vector<std::pair<ValueType, int> >& vec =
480  nnvm::get<OpMap<ValueType> >(*pmap).data_;
481  // resize the value type.
482  if (vec.size() <= index_) {
483  vec.resize(index_ + 1,
484  std::make_pair(ValueType(), 0));
485  }
486  std::pair<ValueType, int>& p = vec[index_];
487  CHECK(p.second != plevel)
488  << "Attribute " << attr_name
489  << " of operator " << this->name
490  << " is already registered with same plevel=" << plevel;
491  if (p.second < plevel) {
492  vec[index_] = std::make_pair(value, plevel);
493  }
494  });
495  return *this;
496 }
497 
498 
499 inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
500  this->description = descr;
501  return *this;
502 }
503 
504 inline Op& Op::add_argument(const std::string &name,
505  const std::string &type,
506  const std::string &description) {
507  arguments.push_back({name, type, type, description});
508  return *this;
509 }
510 
511 inline Op& Op::add_arguments(const std::vector<ParamFieldInfo> &args) {
512  this->arguments.insert(arguments.end(), args.begin(), args.end());
513  return *this;
514 }
515 
516 inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
517  this->num_inputs = n;
518  return *this;
519 }
520 
521 inline Op& Op::set_support_level(uint32_t n) { // NOLINT(*)
522  this->support_level = n;
523  return *this;
524 }
525 
526 inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
527  this->get_num_inputs = fn;
528  return *this;
529 }
530 
531 inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
532  this->num_outputs = n;
533  return *this;
534 }
535 
536 inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
537  this->get_num_outputs = fn;
538  return *this;
539 }
540 
541 inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { // NOLINT(*)
542  this->attr_parser = fn;
543  return *this;
544 }
545 
546 // member functions of OpMap
547 template<typename ValueType>
548 inline int OpMap<ValueType>::count(const Op* op) const {
549  if (contains(op)) {
550  return 1;
551  } else {
552  return 0;
553  }
554 }
555 
556 template<typename ValueType>
557 inline bool OpMap<ValueType>::contains(const Op* op) const {
558  if (op == nullptr) {
559  return false;
560  }
561  const uint32_t idx = op->index_;
562  return idx < data_.size() ? (data_[idx].second != 0) : false;
563 }
564 
565 template<typename ValueType>
566 inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
567  CHECK(op != nullptr);
568  const uint32_t idx = op->index_;
569  CHECK(idx < data_.size() && data_[idx].second)
570  << "Attribute " << attr_name_
571  << " has not been registered for Operator " << op->name;
572  return data_[idx].first;
573 }
574 
575 template<typename ValueType>
576 inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
577  if (op == nullptr) return def_value;
578  const uint32_t idx = op->index_;
579  if (idx < data_.size() && data_[idx].second) {
580  return data_[idx].first;
581  } else {
582  return def_value;
583  }
584 }
585 
586 template<typename ValueType>
587 inline OpGroup& OpGroup::set_attr(const std::string& attr_name,
588  const ValueType& value,
589  int plevel) {
590  auto trigger = [attr_name, value, plevel](Op* op) {
591  op->set_attr<ValueType>(attr_name, value, plevel);
592  };
593  Op::AddGroupTrigger(group_name, trigger);
594  return *this;
595 }
596 
597 } // namespace nnvm
598 
599 #endif // NNVM_OP_H_
std::vector< ParamFieldInfo > arguments
Definition: op.h:114
Definition: base.h:36
Registry class. Registry can be used to register global singletons. The most common use cases are fa...
Definition: registry.h:27
Op & add_argument(const std::string &name, const std::string &type, const std::string &description)
Add argument information to the function.
Definition: op.h:504
std::string description
Detailed description of the operator. This can be used to generate docstrings automatically for the ope...
Definition: op.h:112
Op & set_attr_parser(std::function< void(NodeAttrs *attrs)> fn)
Set the attr_parser function.
Definition: op.h:541
The attributes of the current operation node, usually additional parameters such as axis...
Definition: node.h:120
const ValueType & operator[](const Op *op) const
get the corresponding value element at op
Definition: op.h:566
static const OpMap< ValueType > & GetAttr(const std::string &attr_name)
Get additional registered attribute about operators. If nothing has been registered, an empty OpMap will be returned.
Definition: op.h:440
Op & describe(const std::string &descr)
Setter function used during registration: set the description of the operator.
Definition: op.h:499
Op & set_num_outputs(uint32_t n)
Set the num_outputs.
Definition: op.h:531
OpGroup & set_attr(const std::string &attr_name, const ValueType &value, int plevel=1)
Register additional attributes to operator group.
Definition: op.h:587
Op & set_support_level(uint32_t level)
Set the support level of op.
Definition: op.h:521
#define NNVM_DLL
NNVM_DLL prefix for windows.
Definition: c_api.h:38
Op & set_num_inputs(uint32_t n)
Set the num_inputs.
Definition: op.h:516
bool contains(const Op *op) const
Check if the map has op as key.
Definition: op.h:557
auxiliary data structure used to set attributes to a group of operators
Definition: op.h:363
Op & add_arguments(const std::vector< ParamFieldInfo > &args)
Append a list of arguments to the end.
Definition: op.h:511
std::string name
name of the operator
Definition: op.h:107
nnvm::Op Op
operator structure from NNVM
Definition: base.h:99
int count(const Op *op) const
Return 1 if the map has op as key, 0 otherwise.
Definition: op.h:548
A map data structure that takes Op* as key and returns ValueType.
Definition: op.h:44
const ValueType & get(const Op *op, const ValueType &def_value) const
get the corresponding value element at op with default value.
Definition: op.h:576
std::string group_name
the tag key to be matched
Definition: op.h:366
C API of NNVM symbolic construction and pass. Enables construction and transformation of Graph in any...
Op & set_attr(const std::string &attr_name, const ValueType &value, int plevel=10)
Register additional attributes to operator.
Definition: op.h:458
Provide lightweight util to do parameter setup and checking.
Configuration of nnvm as well as basic data structure.
Operator structure.
Definition: op.h:104