mxnet
node.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef NNVM_NODE_H_
26 #define NNVM_NODE_H_
27 
28 #include <memory>
29 #include <string>
30 #include <vector>
31 #include <utility>
32 #include <unordered_map>
33 #include "base.h"
34 #include "op.h"
35 #include "c_api.h"
36 
37 namespace nnvm {
38 
39 // Forward declarations of Node and Symbol; both are only referenced
// through pointers (NodePtr, NodeAttrs::subgraphs) before their definitions.
40 class Node;
41 class Symbol;
42 
/*!
 * \brief Shared reference pointer to a Node.
 *  We always use NodePtr for a reference pointer to a node, so this alias
 *  can be changed in one place if the pointer type ever needs to change.
 */
49 using NodePtr = std::shared_ptr<Node>;
50 
52 struct NodeEntry {
53  NodeEntry(NodePtr node, uint32_t index, uint32_t version):
54  node(std::move(node)),
55  index(index),
56  version(version)
57  {}
58 
59  explicit NodeEntry(NodePtr node):
60  node(std::move(node)),
61  index(),
62  version()
63  {}
64 
70  node(nullptr),
71  index(),
72  version()
73  {}
74 
78  uint32_t index;
85  uint32_t version;
86 };
87 
92 struct NodeEntryHash {
93  size_t operator()(const NodeEntry& e) const {
94  return std::hash<Node*>()(e.node.get()) ^
95  (std::hash<size_t>()(e.index) << 1 >> 1) ^
96  (std::hash<size_t>()(e.version) << 1);
97  }
98 };
99 
105  size_t operator()(const NodeEntry& a, const NodeEntry& b) const {
106  return (a.node.get() == b.node.get()) &&
107  (a.index == b.index) &&
108  (a.version == b.version);
109  }
110 };
111 
/*!
 * \brief An unordered_map keyed by NodeEntry, hashed with NodeEntryHash and
 *  compared with NodeEntryEqual.
 * \tparam ValueType the mapped value type.
 */
113 template<typename ValueType>
114 using NodeEntryMap = std::unordered_map<NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual>;
115 
120 struct NodeAttrs {
125  const Op *op{nullptr};
127  std::string name;
129  std::unordered_map<std::string, std::string> dict;
135  any parsed;
150  std::vector<std::shared_ptr<Symbol> > subgraphs;
151 };
152 
156 class NNVM_DLL Node {
157  public:
158  Node() = default;
159  Node(const Op* op, const std::string& name) {
160  this->attrs.op = op;
161  this->attrs.name = name;
162  }
166  std::vector<NodeEntry> inputs;
171  std::vector<NodePtr> control_deps;
173  any info;
175  ~Node();
177  inline const Op* op() const;
183  inline bool is_variable() const;
185  inline uint32_t num_outputs() const;
187  inline uint32_t num_inputs() const;
192  template<class ...Args>
193  static NodePtr Create(Args&&... args) {
194  return std::make_shared<Node>(std::forward<Args>(args)...);
195  }
196 };
197 
207  const char* op_name,
208  std::string node_name,
209  std::vector<NodeEntry> inputs,
210  std::unordered_map<std::string, std::string> attrs =
211  std::unordered_map<std::string, std::string>()) {
212  NodePtr p = Node::Create();
213  p->attrs.op = nnvm::Op::Get(op_name);
214  p->attrs.name = std::move(node_name);
215  p->attrs.dict = attrs;
216  if (p->attrs.op->attr_parser) {
217  p->attrs.op->attr_parser(&(p->attrs));
218  }
219  p->inputs = std::move(inputs);
220  return NodeEntry(p, 0, 0);
221 }
222 
223 // implementation of functions.
224 inline const Op* Node::op() const {
225  return this->attrs.op;
226 }
227 
228 inline bool Node::is_variable() const {
229  return this->op() == nullptr;
230 }
231 
232 inline uint32_t Node::num_outputs() const {
233  if (is_variable()) return 1;
234  if (this->op()->get_num_outputs == nullptr) {
235  return this->op()->num_outputs;
236  } else {
237  return this->op()->get_num_outputs(this->attrs);
238  }
239 }
240 
241 inline uint32_t Node::num_inputs() const {
242  if (is_variable()) return 1;
243  if (this->op()->get_num_inputs == nullptr) {
244  return this->op()->num_inputs;
245  } else {
246  return this->op()->get_num_inputs(this->attrs);
247  }
248 }
249 
250 } // namespace nnvm
251 
252 #endif // NNVM_NODE_H_
uint32_t version
version of input Variable. This field can only be nonzero when this->node is a Variable node...
Definition: node.h:85
Definition: base.h:36
size_t operator()(const NodeEntry &a, const NodeEntry &b) const
Definition: node.h:105
Node(const Op *op, const std::string &name)
Definition: node.h:159
bool is_variable() const
Returns whether the node is a placeholder variable. This is equivalent to op == nullptr.
Definition: node.h:228
The attributes of the current operation node, usually additional parameters such as axis...
Definition: node.h:120
size_t operator()(const NodeEntry &e) const
Definition: node.h:93
This lets you use a NodeEntry as a key in an unordered_map of the form unordered_map<NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual>
Definition: node.h:104
NodeEntry(NodePtr node, uint32_t index, uint32_t version)
Definition: node.h:53
Definition: optional.h:241
std::shared_ptr< Node > NodePtr
We always use NodePtr as the reference pointer to a node, so this alias can be changed in case...
Definition: node.h:49
std::vector< NodeEntry > inputs
inputs to this node
Definition: node.h:166
NodeAttrs attrs
The attributes in the node.
Definition: node.h:164
std::vector< std::shared_ptr< Symbol > > subgraphs
Some operators take graphs as input. These operators include control flow operators and high-order fu...
Definition: node.h:150
NodePtr node
the source node of this data
Definition: node.h:76
NodeEntry(NodePtr node)
Definition: node.h:59
Node represents an operation in a computation graph.
Definition: node.h:156
any parsed
A parsed version of the attributes. This is generated if OpProperty.attr_parser is registered. The object can be used to quickly access attributes.
Definition: node.h:135
#define NNVM_DLL
NNVM_DLL prefix for Windows.
Definition: c_api.h:38
This lets you use a NodeEntry as a key in an unordered_map of the form unordered_map<NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual>
Definition: node.h:92
std::string name
name of the node
Definition: node.h:127
std::string name
name of the operator
Definition: op.h:107
uint32_t num_inputs() const
Definition: node.h:241
std::vector< NodePtr > control_deps
Optional control-flow dependencies: operations that must be performed before this operation.
Definition: node.h:171
uint32_t num_outputs() const
Definition: node.h:232
an entry that represents output data from a node
Definition: node.h:52
std::unordered_map< NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual > NodeEntryMap
Definition: node.h:114
NodeEntry()
Definition: node.h:69
NodeEntry MakeNode(const char *op_name, std::string node_name, std::vector< NodeEntry > inputs, std::unordered_map< std::string, std::string > attrs=std::unordered_map< std::string, std::string >())
Quick utility to make a node.
Definition: node.h:206
std::unordered_map< std::string, std::string > dict
The dictionary representation of attributes.
Definition: node.h:129
Operator information structure.
static const Op * Get(const std::string &op_name)
Get an Op for a given operator name. Will raise an error if the op has not been registered.
uint32_t index
index of output from the source.
Definition: node.h:78
C API of NNVM symbolic construction and pass. Enables construction and transformation of Graph in any...
static NodePtr Create(Args &&...args)
create a new empty shared_ptr of Node.
Definition: node.h:193
any info
additional fields for this node
Definition: node.h:173
Configuration of NNVM as well as basic data structures.
Operator structure.
Definition: op.h:104
const Op * op() const
Definition: node.h:224