mxnet
graph.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef NNVM_GRAPH_H_
25 #define NNVM_GRAPH_H_
26 
27 #include <vector>
28 #include <string>
29 #include <utility>
30 #include <algorithm>
31 #include <memory>
32 #include <unordered_map>
33 #include <unordered_set>
34 #include "base.h"
35 #include "node.h"
36 #include "symbolic.h"
37 
38 namespace nnvm {
39 
40 class IndexedGraph;
41 
/*!
 * \brief Symbolic computation graph; the intermediate representation that
 *  optimization passes work on.
 */
46 class Graph {
47  public:
/*! \brief Outputs of the computation graph. */
49  std::vector<NodeEntry> outputs;
/*!
 * \brief Attributes of the graph, keyed by attribute name.
 *  Each value is held through a shared pointer, so an attribute can be
 *  shared across graphs.
 */
60  std::unordered_map<std::string, std::shared_ptr<any> > attrs;
/*!
 * \brief Get an immutable attribute from attrs.
 * \param attr_name Name of the attribute to look up.
 * \return Const reference to the stored value, viewed as type T.
 * \note The inline definition fails a CHECK when attr_name is not present.
 * \tparam T Type the stored attribute is read as.
 */
67  template<typename T>
68  inline const T& GetAttr(const std::string& attr_name) const;
/*!
 * \brief Check whether the graph has a specific attribute.
 * \param attr_name Name of the attribute to look up.
 * \return true when attrs contains an entry named attr_name.
 */
74  inline bool HasAttr(const std::string& attr_name) const;
/*!
 * \brief Take an attribute out of the graph by value.
 *  The entry is erased from attrs. The stored value is moved out when this
 *  graph held the only reference to it, and copied otherwise.
 * \param attr_name Name of the attribute to extract.
 * \return The attribute value.
 * \note The inline definition fails a CHECK when attr_name is not present.
 * \tparam T Type the stored attribute is read as.
 */
84  template<typename T>
85  inline T MoveCopyAttr(const std::string& attr_name);
/*!
 * \brief Get an indexed graph of the current graph; per the generated
 *  documentation it is created on demand if it does not exist yet.
 * \return Const reference to the IndexedGraph.
 */
91  const IndexedGraph& indexed_graph() const;
92 
93  private:
94  // internal structure of indexed graph
// Declared mutable, presumably so the const accessor indexed_graph() can
// populate it lazily -- the accessor's definition is not in this header.
95  mutable std::shared_ptr<const IndexedGraph> indexed_graph_;
96 };
97 
108  public:
/*! \brief Represents a data entry in the indexed graph. */
110  struct NodeEntry {
/*! \brief The source node id in the computation graph. */
112  uint32_t node_id;
/*! \brief Index of the output from the source node. */
114  uint32_t index;
/*! \brief Version of the node. */
116  uint32_t version;
117  };
/*!
 * \brief Node data structure in IndexedGraph.
 * NOTE(review): the generated member index for this header also lists
 * `source`, `inputs` and `control_deps` members for this struct (header
 * lines 121-125); they do not appear in this listing -- confirm against the
 * original header before relying on this declaration.
 */
119  struct Node {
/*! \brief Weak reference to the node. */
127  std::weak_ptr<nnvm::Node> weak_ref;
128  };
130  inline size_t num_nodes() const {
131  return nodes_.size();
132  }
134  inline size_t num_node_entries() const {
135  return entry_rptr_.back();
136  }
144  inline uint32_t entry_id(uint32_t node_id, uint32_t index) const {
145  return entry_rptr_[node_id] + index;
146  }
153  inline uint32_t entry_id(const NodeEntry& e) const {
154  return entry_rptr_[e.node_id] + e.index;
155  }
162  inline uint32_t entry_id(const nnvm::NodeEntry& e) const {
163  return entry_rptr_[node_id(e.node.get())] + e.index;
164  }
170  inline uint32_t node_id(const nnvm::Node* node) const {
171  return node2index_.at(node);
172  }
178  inline const Node& operator[](uint32_t node_id) const {
179  return nodes_[node_id];
180  }
186  inline const Node& operator[](const nnvm::Node* node) const {
187  return nodes_[node_id(node)];
188  }
190  inline const std::vector<uint32_t>& input_nodes() const {
191  return input_nodes_;
192  }
194  inline const std::unordered_set<uint32_t>& mutable_input_nodes() const {
195  return mutable_input_nodes_;
196  }
198  inline const std::vector<NodeEntry>& outputs() const {
199  return outputs_;
200  }
201 
203  inline bool exist(const nnvm::Node* node) const {
204  return node2index_.count(node);
205  }
206 
207  // disalllow copy assign
208  IndexedGraph(const IndexedGraph&) = delete;
209 
210  private:
// Graph is declared a friend so that it can reach the private constructor.
211  friend class Graph;
// Builds the index from `other`. Only declared here; the definition is not
// part of this header listing.
216  explicit IndexedGraph(const Graph& other);
217  // Node pointers in CSR structure.
// Indexed by node id (see operator[] and num_nodes()).
218  std::vector<Node> nodes_;
219  // Index to all input nodes.
220  std::vector<uint32_t> input_nodes_;
221  // Index to all mutable input nodes.
222  std::unordered_set<uint32_t> mutable_input_nodes_;
223  // space to store the outputs entries
224  std::vector<NodeEntry> outputs_;
225  // mapping from node to index.
// Key is the raw node pointer; value is the id returned by node_id().
226  std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
227  // CSR pointer of node entries
// entry_id() computes entry_rptr_[node_id] + output index, and
// num_node_entries() returns the last element.
228  std::vector<size_t> entry_rptr_;
229  // space to store input entries of each node
// NOTE(review): the original comment ended at "each"; "node" is presumed.
230  std::vector<NodeEntry> input_entries_;
231  // control flow dependencies
232  std::vector<uint32_t> control_deps_;
233 };
234 
/*!
 * \brief Perform a post-order DFS visit of every node reachable from heads.
 *  Forward declaration; the inline definition appears later in this file.
 *  Per that definition, a node's data inputs and then its control
 *  dependencies are visited before the node itself, and each distinct node
 *  (keyed by its raw Node pointer) is visited once.
 * \param heads Entries whose source nodes are the starting points.
 * \param fvisit Callable invoked with the ObjectPtr of each visited node.
 * \tparam FVisit Type of the visitor callable.
 */
242 template<typename FVisit>
243 inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit);
244 
245 // inline function implementations
246 template<typename T>
247 inline const T& Graph::GetAttr(const std::string& attr_name) const {
248  auto it = attrs.find(attr_name);
249  CHECK(it != attrs.end())
250  << "Cannot find attribute " << attr_name << " in the graph";
251  return nnvm::unsafe_get<T>(*it->second);
252 }
253 
254 inline bool Graph::HasAttr(const std::string& attr_name) const {
255  auto it = attrs.find(attr_name);
256  return it != attrs.end();
257 }
258 
259 template<typename T>
260 inline T Graph::MoveCopyAttr(const std::string& attr_name) {
261  auto it = attrs.find(attr_name);
262  CHECK(it != attrs.end())
263  << "Cannot find attribute " << attr_name << " in the graph";
264  std::shared_ptr<any> sptr = it->second;
265  attrs.erase(it);
266  if (sptr.unique()) {
267  return std::move(nnvm::get<T>(*sptr));
268  } else {
269  return nnvm::get<T>(*sptr);
270  }
271 }
272 
/*!
 * \brief Generic, non-recursive post-order depth-first traversal.
 *
 *  Starting from each element of heads in order, every reachable node is
 *  passed to fvisit exactly once, after all of its inputs have been visited.
 *  Nodes are de-duplicated by the key that `hash` returns.
 *
 * \param heads Starting nodes of the traversal.
 * \param fvisit Called as fvisit(node) once per distinct node, in post-order.
 * \param hash Maps a node to its de-duplication key (convertible to HashType).
 * \param indegree Returns the number of inputs of a node.
 * \param getinput Called as getinput(node, i) to fetch the i-th input.
 * \tparam GNode Node handle type; stored by value on the explicit stack.
 * \tparam HashType Key type stored in the visited set.
 */
template <typename GNode, typename HashType,
          typename FVisit, typename HashFunc,
          typename InDegree, typename GetInput>
void PostOrderDFSVisit(const std::vector<GNode>& heads,
                       FVisit fvisit,
                       HashFunc hash,
                       InDegree indegree,
                       GetInput getinput) {
  // Each frame is (node, index of the next input to expand).
  std::vector<std::pair<GNode, uint32_t> > stack;
  std::unordered_set<HashType> visited;
  for (const GNode& head : heads) {
    const HashType head_hash = hash(head);
    // insert() reports whether the key was new, so one lookup is enough.
    if (visited.insert(head_hash).second) {
      stack.push_back(std::make_pair(head, 0));
    }
    while (!stack.empty()) {
      std::pair<GNode, uint32_t>& back = stack.back();
      if (back.second == indegree(back.first)) {
        // All inputs handled: emit the node and drop its frame.
        fvisit(back.first);
        stack.pop_back();
      } else {
        // `back` is not used after the push below, so a reallocation of
        // the stack cannot leave a dangling reference in use.
        const GNode& input = getinput(back.first, back.second++);
        const HashType input_hash = hash(input);
        if (visited.insert(input_hash).second) {
          stack.push_back(std::make_pair(input, 0));
        }
      }
    }
  }
}
305 
306 template<typename FVisit>
307 inline void DFSVisit(const std::vector<NodeEntry>& heads,
308  FVisit fvisit) {
309  typedef const ObjectPtr* GNode;
310  std::vector<GNode> head_nodes(heads.size());
311  std::transform(heads.begin(), heads.end(), head_nodes.begin(),
312  [](const NodeEntry& e)->GNode {
313  return &e.node;
314  });
315  PostOrderDFSVisit<GNode, Node*>(
316  head_nodes,
317  [fvisit](GNode n) {
318  fvisit(*n);
319  }, // FVisit
320  [](GNode n)->Node* {
321  return n->get();
322  }, // HashFunc
323  [](GNode n)->uint32_t { // InDegree
324  if (!(*n)) return 0;
325  return (*n)->inputs.size() + (*n)->control_deps.size();
326  },
327  [](GNode n, uint32_t index)->GNode { // GetInput
328  if (index < (*n)->inputs.size()) {
329  return &(*n)->inputs.at(index).node;
330  } else {
331  return &(*n)->control_deps.at(index - (*n)->inputs.size());
332  }
333  });
334 }
335 
336 } // namespace nnvm
337 
338 #endif // NNVM_GRAPH_H_
Read only data structure to reference continuous memory region of array. Provide unified view for vec...
Definition: array_view.h:36
bool HasAttr(const std::string &attr_name) const
Check whether the graph has a specific attribute.
Definition: graph.h:254
Definition: base.h:35
void PostOrderDFSVisit(const std::vector< GNode > &heads, FVisit fvisit, HashFunc hash, InDegree indegree, GetInput getinput)
Definition: graph.h:276
const std::unordered_set< uint32_t > & mutable_input_nodes() const
Definition: graph.h:194
const IndexedGraph & indexed_graph() const
Get an indexed graph of the current graph; if it does not exist, create it on demand.
T MoveCopyAttr(const std::string &attr_name)
Get a move copy of the attribute, implement copy on write semantics. The content is moved if the refe...
Definition: graph.h:260
const std::vector< uint32_t > & input_nodes() const
Definition: graph.h:190
ObjectPtr node
the source node of this data
Definition: node.h:75
std::vector< NodeEntry > inputs
inputs to this node
Definition: node.h:165
Node data structure in IndexedGraph.
Definition: graph.h:119
array_view< uint32_t > control_deps
control flow dependencies to the node
Definition: graph.h:125
const nnvm::Node * source
pointer to the source node
Definition: graph.h:121
Graph node data structure.
std::unordered_map< std::string, std::shared_ptr< any > > attrs
Attributes of a graph. Note that each attribute is a shared pointer and can be shared across graphs...
Definition: graph.h:60
uint32_t entry_id(uint32_t node_id, uint32_t index) const
Get a unique entry id between 0 to num_node_entries() for a given IndexedGraph::NodeEntry.
Definition: graph.h:144
std::weak_ptr< nnvm::Node > weak_ref
weak reference to node
Definition: graph.h:127
bool exist(const nnvm::Node *node) const
Definition: graph.h:203
Node represents an operation in a computation graph.
Definition: node.h:155
size_t num_nodes() const
Definition: graph.h:130
void DFSVisit(const std::vector< NodeEntry > &heads, FVisit fvisit)
perform a Post Order DFS visit to each node in the graph. This order is deterministic and is also top...
Definition: graph.h:307
Symbolic computation graph. This is the intermediate representation for optimization pass...
Definition: graph.h:46
Auxiliary data structure to index a graph. It maps Nodes in the graph to consecutive integers node_id...
Definition: graph.h:107
uint32_t version
version of the node
Definition: graph.h:116
uint32_t node_id(const nnvm::Node *node) const
Get the corresponding node id for a given Node in the IndexedGraph.
Definition: graph.h:170
const Node & operator[](uint32_t node_id) const
Get the corresponding Node structure for a given node_id.
Definition: graph.h:178
const Node & operator[](const nnvm::Node *node) const
Get the corresponding Node structure.
Definition: graph.h:186
uint32_t entry_id(const nnvm::NodeEntry &e) const
Get a unique entry id between 0 to num_node_entries() for a given NodeEntry.
Definition: graph.h:162
an entry that represents output data from a node
Definition: node.h:51
uint32_t index
index of output from the source.
Definition: graph.h:114
const T & GetAttr(const std::string &attr_name) const
Get the immutable attribute from attrs.
Definition: graph.h:247
uint32_t node_id
the source node id in the computation graph
Definition: graph.h:112
size_t num_node_entries() const
Definition: graph.h:134
array_view< NodeEntry > inputs
inputs to the node
Definition: graph.h:123
Represents a data entry in the graph.
Definition: graph.h:110
std::vector< NodeEntry > outputs
outputs of the computation graph.
Definition: graph.h:49
uint32_t entry_id(const NodeEntry &e) const
Get a unique entry id between 0 to num_node_entries() for a given IndexedGraph::NodeEntry.
Definition: graph.h:153
const std::vector< NodeEntry > & outputs() const
Definition: graph.h:198
Symbolic graph construction API.
uint32_t index
index of output from the source.
Definition: node.h:77
std::shared_ptr< Node > ObjectPtr
We always use ObjectPtr as the reference pointer to a node, so this alias can be changed in case...
Definition: node.h:48
Configuration of nnvm as well as basic data structure.