mxnet
op.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef NNVM_OP_H_
25 #define NNVM_OP_H_
26 
27 #include <dmlc/parameter.h>
28 #include <string>
29 #include <vector>
30 #include <utility>
31 #include <typeinfo>
32 #include <limits>
33 #include <functional>
34 #include "base.h"
35 #include "c_api.h"
36 
37 namespace nnvm {
38 
39 // forward declarations
40 class Node;
41 struct NodeAttrs;
42 template<typename ValueType>
43 class OpMap;
44 class OpGroup;
45 class OpRegistryEntry;
46 using dmlc::ParamFieldInfo;
47 
/*!
 * \brief Sentinel equal to the largest representable uint32_t.
 *  NOTE(review): the name suggests it marks a variable-length ("vararg")
 *  input/output count; confirm against the code that reads it.
 */
static constexpr uint32_t kVarg = std::numeric_limits<uint32_t>::max();
50 
103 class NNVM_DLL Op {
104  public:
106  std::string name;
111  std::string description;
112  /* \brief description of inputs and keyword arguments*/
113  std::vector<ParamFieldInfo> arguments;
121  uint32_t num_inputs = 1;
129  uint32_t num_outputs = 1;
135  uint32_t support_level = 10;
141  std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;
147  std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;
180  std::function<void(NodeAttrs* attrs)> attr_parser = nullptr;
181  // function fields.
188  inline Op& describe(const std::string& descr); // NOLINT(*)
196  inline Op& add_argument(const std::string &name,
197  const std::string &type,
198  const std::string &description);
204  inline Op& add_arguments(const std::vector<ParamFieldInfo> &args);
210  inline Op& set_num_inputs(uint32_t n); // NOLINT(*)
216  inline Op& set_support_level(uint32_t level); // NOLINT(*)
222  inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
228  inline Op& set_num_outputs(uint32_t n); // NOLINT(*)
234  inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
240  inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
254  template<typename ValueType>
255  inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
256  const ValueType& value,
257  int plevel = 10);
264  Op& add_alias(const std::string& alias); // NOLINT(*)
272  Op& include(const std::string& group_name);
279  static const Op* Get(const std::string& op_name);
287  template<typename ValueType>
288  static const OpMap<ValueType>& GetAttr(const std::string& attr_name);
289 
290  private:
291  template<typename ValueType>
292  friend class OpMap;
293  friend class OpGroup;
294  friend class dmlc::Registry<Op>;
295  // Program internal unique index of operator.
296  // Used to help index the program.
297  uint32_t index_{0};
298  // internal constructor
299  Op();
300  // get const reference to certain attribute
301  static const any* GetAttrMap(const std::string& key);
302  // update the attribute OpMap
303  static void UpdateAttrMap(const std::string& key,
304  std::function<void(any*)> updater);
305  // add a trigger based on tag matching on certain tag attribute
306  // This will apply trigger on all the op such that
307  // include the corresponding group.
308  // The trigger will also be applied to all future registrations
309  // that calls include
310  static void AddGroupTrigger(const std::string& group_name,
311  std::function<void(Op*)> trigger);
312 };
313 
319 template<typename ValueType>
320 class OpMap {
321  public:
327  inline const ValueType& operator[](const Op* op) const;
334  inline const ValueType& get(const Op* op, const ValueType& def_value) const;
340  inline int count(const Op* op) const;
341 
347  inline bool contains(const Op* op) const;
348 
349  private:
350  friend class Op;
351  // internal attribute name
352  std::string attr_name_;
353  // internal data
354  std::vector<std::pair<ValueType, int> > data_;
355  OpMap() = default;
356 };
357 
/*!
 * \brief Auxiliary structure used to set attributes on a group of operators.
 *  Every operator that calls Op::include() with group_name receives the
 *  attributes set through this object.
 */
class OpGroup {
 public:
  /*! \brief The group tag to be matched by Op::include(). */
  std::string group_name;
  /*!
   * \brief Register an attribute for every operator in this group.
   * \param attr_name Name of the attribute.
   * \param value Value to store.
   * \param plevel Priority level forwarded to Op::set_attr (default 1, lower
   *  than Op::set_attr's default of 10).
   * \return Reference to this group, for chaining.
   */
  template<typename ValueType>
  inline OpGroup& set_attr(const std::string& attr_name,  // NOLINT(*)
                           const ValueType& value,
                           int plevel = 1);
};
384 
// Internal helper macros: each declares the (otherwise unused) file-static
// variable that the public registration macros below bind to.
#define NNVM_REGISTER_VAR_DEF(Name) \
  static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## Name

#define NNVM_REGISTER_GVAR_DEF(Tag) \
  static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## Tag

/*!
 * \brief Register the operator named OpName through the registry's
 *  __REGISTER_OR_GET__ and bind it to a file-static Op reference, so setters
 *  can be chained after the macro:
 *
 * \code
 *   NNVM_REGISTER_OP(my_op)
 *   .describe("...")
 *   .set_num_inputs(2);
 * \endcode
 *
 *  __COUNTER__ is appended to the variable name, presumably so that repeated
 *  uses within one translation unit produce distinct variables.
 */
#define NNVM_REGISTER_OP(Name) \
  DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(Name), __COUNTER__) = \
      ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#Name)

/*!
 * \brief Declare a file-static OpGroup whose group_name is the stringized
 *  GroupName. Attributes set on it apply to every operator that include()s
 *  the group.
 */
#define NNVM_REGISTER_OP_GROUP(Name) \
  DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(Name), __COUNTER__) = \
      ::nnvm::OpGroup {#Name}
435 
436 // implementations of template functions after this.
437 // member function of Op
438 template<typename ValueType>
439 inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
440  const any* ref = GetAttrMap(key);
441  if (ref == nullptr) {
442  // update the attribute map of the key by creating new empty OpMap
443  UpdateAttrMap(key, [key](any* pmap) {
444  // use callback so it is in lockscope
445  if (pmap->empty()) {
446  OpMap<ValueType> pm;
447  pm.attr_name_ = key;
448  *pmap = std::move(pm);
449  }
450  });
451  ref = GetAttrMap(key);
452  }
453  return nnvm::get<OpMap<ValueType> >(*ref);
454 }
455 
456 template<typename ValueType>
457 inline Op& Op::set_attr( // NOLINT(*)
458  const std::string& attr_name,
459  const ValueType& value,
460  int plevel) {
461  CHECK_GT(plevel, 0)
462  << "plevel in set_attr must be greater than 0";
463  // update the attribute map of the key by creating new empty if needed.
464  UpdateAttrMap(attr_name,
465  [this, attr_name, value, plevel](any* pmap) {
466  // the callback is in lockscope so is threadsafe.
467  if (pmap->empty()) {
468  OpMap<ValueType> pm;
469  pm.attr_name_ = attr_name;
470  *pmap = std::move(pm);
471  }
472  CHECK(pmap->type() == typeid(OpMap<ValueType>))
473  << "Attribute " << attr_name
474  << " of operator " << this->name
475  << " is registered as inconsistent types"
476  << " previously " << pmap->type().name()
477  << " current " << typeid(OpMap<ValueType>).name();
478  std::vector<std::pair<ValueType, int> >& vec =
479  nnvm::get<OpMap<ValueType> >(*pmap).data_;
480  // resize the value type.
481  if (vec.size() <= index_) {
482  vec.resize(index_ + 1,
483  std::make_pair(ValueType(), 0));
484  }
485  std::pair<ValueType, int>& p = vec[index_];
486  CHECK(p.second != plevel)
487  << "Attribute " << attr_name
488  << " of operator " << this->name
489  << " is already registered with same plevel=" << plevel;
490  if (p.second < plevel) {
491  vec[index_] = std::make_pair(value, plevel);
492  }
493  });
494  return *this;
495 }
496 
497 
498 inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
499  this->description = descr;
500  return *this;
501 }
502 
503 inline Op& Op::add_argument(const std::string &name,
504  const std::string &type,
505  const std::string &description) {
506  arguments.push_back({name, type, type, description});
507  return *this;
508 }
509 
510 inline Op& Op::add_arguments(const std::vector<ParamFieldInfo> &args) {
511  this->arguments.insert(arguments.end(), args.begin(), args.end());
512  return *this;
513 }
514 
515 inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
516  this->num_inputs = n;
517  return *this;
518 }
519 
520 inline Op& Op::set_support_level(uint32_t n) { // NOLINT(*)
521  this->support_level = n;
522  return *this;
523 }
524 
525 inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
526  this->get_num_inputs = fn;
527  return *this;
528 }
529 
530 inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
531  this->num_outputs = n;
532  return *this;
533 }
534 
535 inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
536  this->get_num_outputs = fn;
537  return *this;
538 }
539 
540 inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { // NOLINT(*)
541  this->attr_parser = fn;
542  return *this;
543 }
544 
545 // member functions of OpMap
546 template<typename ValueType>
547 inline int OpMap<ValueType>::count(const Op* op) const {
548  if (contains(op)) {
549  return 1;
550  } else {
551  return 0;
552  }
553 }
554 
555 template<typename ValueType>
556 inline bool OpMap<ValueType>::contains(const Op* op) const {
557  if (op == nullptr) {
558  return false;
559  }
560  const uint32_t idx = op->index_;
561  return idx < data_.size() ? (data_[idx].second != 0) : false;
562 }
563 
564 template<typename ValueType>
565 inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
566  CHECK(op != nullptr);
567  const uint32_t idx = op->index_;
568  CHECK(idx < data_.size() && data_[idx].second)
569  << "Attribute " << attr_name_
570  << " has not been registered for Operator " << op->name;
571  return data_[idx].first;
572 }
573 
574 template<typename ValueType>
575 inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
576  if (op == nullptr) return def_value;
577  const uint32_t idx = op->index_;
578  if (idx < data_.size() && data_[idx].second) {
579  return data_[idx].first;
580  } else {
581  return def_value;
582  }
583 }
584 
585 template<typename ValueType>
586 inline OpGroup& OpGroup::set_attr(const std::string& attr_name,
587  const ValueType& value,
588  int plevel) {
589  auto trigger = [attr_name, value, plevel](Op* op) {
590  op->set_attr<ValueType>(attr_name, value, plevel);
591  };
592  Op::AddGroupTrigger(group_name, trigger);
593  return *this;
594 }
595 
596 } // namespace nnvm
597 
598 #endif // NNVM_OP_H_
std::vector< ParamFieldInfo > arguments
Definition: op.h:113
Definition: base.h:35
Registry class. Registry can be used to register global singletons. The most commonly use case are fa...
Definition: registry.h:27
Op & add_argument(const std::string &name, const std::string &type, const std::string &description)
Add argument information to the function.
Definition: op.h:503
std::string description
detailed description of the operator. This can be used to generate docstrings automatically for the ope...
Definition: op.h:111
Op & set_attr_parser(std::function< void(NodeAttrs *attrs)> fn)
Set the attr_parser function.
Definition: op.h:540
The attributes of the current operation node. These are usually additional parameters like axis...
Definition: node.h:119
const ValueType & operator[](const Op *op) const
get the corresponding value element at op
Definition: op.h:565
static const OpMap< ValueType > & GetAttr(const std::string &attr_name)
Get additional registered attributes of operators. If nothing has been registered, an empty OpMap will be returned.
Definition: op.h:439
Op & describe(const std::string &descr)
setter function during registration. Set the description of the operator.
Definition: op.h:498
Op & set_num_outputs(uint32_t n)
Set the num_outputs.
Definition: op.h:530
OpGroup & set_attr(const std::string &attr_name, const ValueType &value, int plevel=1)
Register additional attributes to operator group.
Definition: op.h:586
Op & set_support_level(uint32_t level)
Set the support level of op.
Definition: op.h:520
#define NNVM_DLL
NNVM_DLL prefix for windows.
Definition: c_api.h:37
Op & set_num_inputs(uint32_t n)
Set the num_inputs.
Definition: op.h:515
bool contains(const Op *op) const
Check if the map has op as key.
Definition: op.h:556
auxiliary data structure used to set attributes to a group of operators
Definition: op.h:362
Op & add_arguments(const std::vector< ParamFieldInfo > &args)
Append a list of arguments to the end.
Definition: op.h:510
std::string name
name of the operator
Definition: op.h:106
nnvm::Op Op
operator structure from NNVM
Definition: base.h:99
int count(const Op *op) const
Check if the map has op as key.
Definition: op.h:547
A map data structure that takes Op* as key and returns ValueType.
Definition: op.h:43
const ValueType & get(const Op *op, const ValueType &def_value) const
get the corresponding value element at op with default value.
Definition: op.h:575
std::string group_name
the tag key to be matched
Definition: op.h:365
C API of NNVM symbolic construction and pass. Enables construction and transformation of Graph in any...
Op & set_attr(const std::string &attr_name, const ValueType &value, int plevel=10)
Register additional attributes to operator.
Definition: op.h:457
Provide lightweight util to do parameter setup and checking.
Configuration of nnvm as well as basic data structure.
Operator structure.
Definition: op.h:103