mxnet
node.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef NNVM_NODE_H_
25 #define NNVM_NODE_H_
26 
27 #include <memory>
28 #include <string>
29 #include <unordered_map>
30 #include <utility>
31 #include <vector>
32 
33 #include "base.h"
34 #include "c_api.h"
35 #include "op.h"
36 
37 namespace nnvm {
38 
// Forward declarations. Node is defined later in this header; Symbol is
// only referenced here through std::shared_ptr.
class Symbol;
class Node;

/*!
 * \brief Reference-counted handle to a Node.
 *
 *  All code refers to nodes through this alias, so the underlying smart
 *  pointer type can be changed in a single place if ever required.
 */
using ObjectPtr = std::shared_ptr<Node>;
50 
52 struct NodeEntry {
53  NodeEntry(ObjectPtr node, uint32_t index, uint32_t version)
54  : node(std::move(node)), index(index), version(version) {}
55 
56  explicit NodeEntry(ObjectPtr node) : node(std::move(node)), index(), version() {}
57 
62  NodeEntry() : node(nullptr), index(), version() {}
63 
67  uint32_t index;
75  uint32_t version;
76 };
77 
82 struct NodeEntryHash {
83  size_t operator()(const NodeEntry& e) const {
84  return std::hash<Node*>()(e.node.get()) ^ (std::hash<size_t>()(e.index) << 1 >> 1) ^
85  (std::hash<size_t>()(e.version) << 1);
86  }
87 };
88 
94  size_t operator()(const NodeEntry& a, const NodeEntry& b) const {
95  return (a.node.get() == b.node.get()) && (a.index == b.index) && (a.version == b.version);
96  }
97 };
98 
100 template <typename ValueType>
101 using NodeEntryMap = std::unordered_map<NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual>;
102 
107 struct NodeAttrs {
112  const Op* op{nullptr};
114  std::string name;
116  std::unordered_map<std::string, std::string> dict;
122  any parsed;
137  std::vector<std::shared_ptr<Symbol> > subgraphs;
138 };
139 
143 class NNVM_DLL Node {
144  public:
145  Node() = default;
146  Node(const Op* op, const std::string& name) {
147  this->attrs.op = op;
148  this->attrs.name = name;
149  }
153  std::vector<NodeEntry> inputs;
158  std::vector<ObjectPtr> control_deps;
160  any info;
162  ~Node();
164  inline const Op* op() const;
170  inline bool is_variable() const;
172  inline uint32_t num_outputs() const;
174  inline uint32_t num_inputs() const;
179  template <class... Args>
180  static ObjectPtr Create(Args&&... args) {
181  return std::make_shared<Node>(std::forward<Args>(args)...);
182  }
183 };
184 
193 inline NodeEntry MakeNode(const char* op_name, std::string node_name, std::vector<NodeEntry> inputs,
194  std::unordered_map<std::string, std::string> attrs =
195  std::unordered_map<std::string, std::string>()) {
196  ObjectPtr p = Node::Create();
197  p->attrs.op = nnvm::Op::Get(op_name);
198  p->attrs.name = std::move(node_name);
199  p->attrs.dict = attrs;
200  if (p->attrs.op->attr_parser) {
201  p->attrs.op->attr_parser(&(p->attrs));
202  }
203  p->inputs = std::move(inputs);
204  return NodeEntry(p, 0, 0);
205 }
206 
207 // implementation of functions.
208 inline const Op* Node::op() const { return this->attrs.op; }
209 
210 inline bool Node::is_variable() const { return this->op() == nullptr; }
211 
212 inline uint32_t Node::num_outputs() const {
213  if (is_variable()) return 1;
214  if (this->op()->get_num_outputs == nullptr) {
215  return this->op()->num_outputs;
216  } else {
217  return this->op()->get_num_outputs(this->attrs);
218  }
219 }
220 
221 inline uint32_t Node::num_inputs() const {
222  if (is_variable()) return 1;
223  if (this->op()->get_num_inputs == nullptr) {
224  return this->op()->num_inputs;
225  } else {
226  return this->op()->get_num_inputs(this->attrs);
227  }
228 }
229 
230 } // namespace nnvm
231 
232 #endif // NNVM_NODE_H_
nnvm::NodeAttrs::name
std::string name
name of the node
Definition: node.h:114
nnvm::Node::is_variable
bool is_variable() const
return whether node is placeholder variable. This is equivalent to op == nullptr
Definition: node.h:210
nnvm::Op::get_num_inputs
std::function< uint32_t(const NodeAttrs &attrs)> get_num_inputs
get number of inputs given information about the node.
Definition: op.h:149
nnvm::Op::num_inputs
uint32_t num_inputs
number of inputs to the operator; -1 means it is variable length. When get_num_inputs is present,...
Definition: op.h:123
nnvm::NodeEntry::NodeEntry
NodeEntry(ObjectPtr node, uint32_t index, uint32_t version)
Definition: node.h:53
nnvm::Node::op
const Op * op() const
Definition: node.h:208
nnvm::NodeEntryEqual::operator()
size_t operator()(const NodeEntry &a, const NodeEntry &b) const
Definition: node.h:94
nnvm::Node
Node represents an operation in a computation graph.
Definition: node.h:143
nnvm::Op::name
std::string name
name of the operator
Definition: op.h:108
nnvm::NodeEntryHash
This lets you use a NodeEntry as a key in an unordered_map of the form unordered_map<NodeEntry,...
Definition: node.h:82
nnvm::NodeEntry::NodeEntry
NodeEntry()
Definition: node.h:62
nnvm::MakeNode
NodeEntry MakeNode(const char *op_name, std::string node_name, std::vector< NodeEntry > inputs, std::unordered_map< std::string, std::string > attrs=std::unordered_map< std::string, std::string >())
Quick utilities make node.
Definition: node.h:193
nnvm::Node::num_outputs
uint32_t num_outputs() const
Definition: node.h:212
nnvm::NodeAttrs::op
const Op * op
The operator this node uses. For place holder variable, op == nullptr.
Definition: node.h:112
nnvm::Node::Node
Node(const Op *op, const std::string &name)
Definition: node.h:146
nnvm::Node::inputs
std::vector< NodeEntry > inputs
inputs to this node
Definition: node.h:153
nnvm::NodeEntry::index
uint32_t index
index of output from the source.
Definition: node.h:67
nnvm::NodeEntry::version
uint32_t version
version of input Variable. This field can only be nonzero when this->node is a Variable node....
Definition: node.h:75
nnvm::Op::Get
static const Op * Get(const std::string &op_name)
Get an Op for a given operator name. Will raise an error if the op has not been registered.
op.h
Operator information structure.
nnvm::Node::attrs
NodeAttrs attrs
The attributes in the node.
Definition: node.h:151
nnvm::NodeAttrs::dict
std::unordered_map< std::string, std::string > dict
The dictionary representation of attributes.
Definition: node.h:116
nnvm::Node::Create
static ObjectPtr Create(Args &&... args)
create a new Node owned by a shared_ptr, forwarding the arguments to the Node constructor.
Definition: node.h:180
base.h
Configuration of nnvm as well as basic data structure.
nnvm::NodeEntryMap
std::unordered_map< NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual > NodeEntryMap
Definition: node.h:101
nnvm::NodeAttrs
The attributes of the current operation node. These are usually additional parameters like axis,...
Definition: node.h:107
nnvm::ObjectPtr
std::shared_ptr< Node > ObjectPtr
We always use ObjectPtr as the reference pointer to a node, so this alias can be changed in one place if needed.
Definition: node.h:49
nnvm::NodeAttrs::subgraphs
std::vector< std::shared_ptr< Symbol > > subgraphs
Some operators take graphs as input. These operators include control flow operators and high-order fu...
Definition: node.h:137
nnvm::NodeEntryHash::operator()
size_t operator()(const NodeEntry &e) const
Definition: node.h:83
c_api.h
C API of NNVM symbolic construction and pass. Enables construction and transformation of Graph in any...
nnvm::Node::info
any info
additional fields for this node
Definition: node.h:160
nnvm::Node::num_inputs
uint32_t num_inputs() const
Definition: node.h:221
std
Definition: optional.h:251
nnvm::NodeEntry
an entry that represents output data from a node
Definition: node.h:52
nnvm::Op::get_num_outputs
std::function< uint32_t(const NodeAttrs &attrs)> get_num_outputs
get number of outputs given information about the node.
Definition: op.h:143
NNVM_DLL
#define NNVM_DLL
NNVM_DLL prefix for Windows.
Definition: c_api.h:37
nnvm::NodeEntry::node
ObjectPtr node
the source node of this data
Definition: node.h:65
nnvm::Op
Operator structure.
Definition: op.h:105
nnvm::NodeEntryEqual
This lets you use a NodeEntry as a key in an unordered_map of the form unordered_map<NodeEntry,...
Definition: node.h:93
nnvm::NodeAttrs::parsed
any parsed
A parsed version of attributes. This is generated if OpProperty.attr_parser is registered....
Definition: node.h:122
nnvm::Op::num_outputs
uint32_t num_outputs
number of outputs of the operator. When get_num_outputs is present, the number of outputs will be de...
Definition: op.h:131
nnvm
Definition: base.h:35
nnvm::Node::control_deps
std::vector< ObjectPtr > control_deps
Optional control-flow dependencies: operations that must be performed before this operation.
Definition: node.h:158
nnvm::NodeEntry::NodeEntry
NodeEntry(ObjectPtr node)
Definition: node.h:56