mxnet
graph.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef NNVM_GRAPH_H_
25 #define NNVM_GRAPH_H_
26 
27 #include <algorithm>
28 #include <memory>
29 #include <string>
30 #include <unordered_map>
31 #include <unordered_set>
32 #include <utility>
33 #include <vector>
34 
35 #include "base.h"
36 #include "node.h"
37 #include "symbolic.h"
38 
39 namespace nnvm {
40 
41 class IndexedGraph;
42 
47 class Graph {
48  public:
50  std::vector<NodeEntry> outputs;
61  std::unordered_map<std::string, std::shared_ptr<any> > attrs;
68  template <typename T>
69  inline const T& GetAttr(const std::string& attr_name) const;
75  inline bool HasAttr(const std::string& attr_name) const;
85  template <typename T>
86  inline T MoveCopyAttr(const std::string& attr_name);
92  const IndexedGraph& indexed_graph() const;
93 
94  private:
95  // internal structure of indexed graph
96  mutable std::shared_ptr<const IndexedGraph> indexed_graph_;
97 };
98 
109  public:
111  struct NodeEntry {
113  uint32_t node_id;
115  uint32_t index;
117  uint32_t version;
118  };
120  struct Node {
128  std::weak_ptr<nnvm::Node> weak_ref;
129  };
131  inline size_t num_nodes() const { return nodes_.size(); }
133  inline size_t num_node_entries() const { return entry_rptr_.back(); }
141  inline uint32_t entry_id(uint32_t node_id, uint32_t index) const {
142  return entry_rptr_[node_id] + index;
143  }
150  inline uint32_t entry_id(const NodeEntry& e) const { return entry_rptr_[e.node_id] + e.index; }
157  inline uint32_t entry_id(const nnvm::NodeEntry& e) const {
158  return entry_rptr_[node_id(e.node.get())] + e.index;
159  }
165  inline uint32_t node_id(const nnvm::Node* node) const { return node2index_.at(node); }
171  inline const Node& operator[](uint32_t node_id) const { return nodes_[node_id]; }
177  inline const Node& operator[](const nnvm::Node* node) const { return nodes_[node_id(node)]; }
179  inline const std::vector<uint32_t>& input_nodes() const { return input_nodes_; }
181  inline const std::unordered_set<uint32_t>& mutable_input_nodes() const {
182  return mutable_input_nodes_;
183  }
185  inline const std::vector<NodeEntry>& outputs() const { return outputs_; }
186 
188  inline bool exist(const nnvm::Node* node) const { return node2index_.count(node); }
189 
190  // disalllow copy assign
191  IndexedGraph(const IndexedGraph&) = delete;
192 
193  private:
194  friend class Graph;
199  explicit IndexedGraph(const Graph& other);
200  // Node pointers in CSR structure.
201  std::vector<Node> nodes_;
202  // Index to all input nodes.
203  std::vector<uint32_t> input_nodes_;
204  // Index to all mutable input nodes.
205  std::unordered_set<uint32_t> mutable_input_nodes_;
206  // space to store the outputs entries
207  std::vector<NodeEntry> outputs_;
208  // mapping from node to index.
209  std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
210  // CSR pointer of node entries
211  std::vector<size_t> entry_rptr_;
212  // space to store input entries of each
213  std::vector<NodeEntry> input_entries_;
214  // control flow dependencies
215  std::vector<uint32_t> control_deps_;
216 };
217 
225 template <typename FVisit>
226 inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit);
227 
228 // inline function implementations
229 template <typename T>
230 inline const T& Graph::GetAttr(const std::string& attr_name) const {
231  auto it = attrs.find(attr_name);
232  CHECK(it != attrs.end()) << "Cannot find attribute " << attr_name << " in the graph";
233  return nnvm::unsafe_get<T>(*it->second);
234 }
235 
236 inline bool Graph::HasAttr(const std::string& attr_name) const {
237  auto it = attrs.find(attr_name);
238  return it != attrs.end();
239 }
240 
241 template <typename T>
242 inline T Graph::MoveCopyAttr(const std::string& attr_name) {
243  auto it = attrs.find(attr_name);
244  CHECK(it != attrs.end()) << "Cannot find attribute " << attr_name << " in the graph";
245  std::shared_ptr<any> sptr = it->second;
246  attrs.erase(it);
247  if (sptr.unique()) {
248  return std::move(nnvm::get<T>(*sptr));
249  } else {
250  return nnvm::get<T>(*sptr);
251  }
252 }
253 
254 template <typename GNode, typename HashType, typename FVisit, typename HashFunc, typename InDegree,
255  typename GetInput>
256 void PostOrderDFSVisit(const std::vector<GNode>& heads, FVisit fvisit, HashFunc hash,
257  InDegree indegree, GetInput getinput) {
258  std::vector<std::pair<GNode, uint32_t> > stack;
259  std::unordered_set<HashType> visited;
260  for (auto& head : heads) {
261  HashType head_hash = hash(head);
262  if (visited.count(head_hash) == 0) {
263  stack.push_back(std::make_pair(head, 0));
264  visited.insert(head_hash);
265  }
266  while (!stack.empty()) {
267  std::pair<GNode, uint32_t>& back = stack.back();
268  if (back.second == indegree(back.first)) {
269  fvisit(back.first);
270  stack.pop_back();
271  } else {
272  const GNode& input = getinput(back.first, back.second++);
273  HashType input_hash = hash(input);
274  if (visited.count(input_hash) == 0) {
275  stack.push_back(std::make_pair(input, 0));
276  visited.insert(input_hash);
277  }
278  }
279  }
280  }
281 }
282 
283 template <typename FVisit>
284 inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit) {
285  typedef const ObjectPtr* GNode;
286  std::vector<GNode> head_nodes(heads.size());
287  std::transform(heads.begin(), heads.end(), head_nodes.begin(),
288  [](const NodeEntry& e) -> GNode { return &e.node; });
289  PostOrderDFSVisit<GNode, Node*>(
290  head_nodes, [fvisit](GNode n) { fvisit(*n); }, // FVisit
291  [](GNode n) -> Node* { return n->get(); }, // HashFunc
292  [](GNode n) -> uint32_t { // InDegree
293  if (!(*n)) return 0;
294  return (*n)->inputs.size() + (*n)->control_deps.size();
295  },
296  [](GNode n, uint32_t index) -> GNode { // GetInput
297  if (index < (*n)->inputs.size()) {
298  return &(*n)->inputs.at(index).node;
299  } else {
300  return &(*n)->control_deps.at(index - (*n)->inputs.size());
301  }
302  });
303 }
304 
305 } // namespace nnvm
306 
307 #endif // NNVM_GRAPH_H_
nnvm::Graph::HasAttr
bool HasAttr(const std::string &attr_name) const
Check whether the graph has a specific attribute.
Definition: graph.h:236
nnvm::IndexedGraph::NodeEntry::node_id
uint32_t node_id
the source node id in the computation graph
Definition: graph.h:113
nnvm::IndexedGraph::mutable_input_nodes
const std::unordered_set< uint32_t > & mutable_input_nodes() const
Definition: graph.h:181
nnvm::Node
Node represents an operation in a computation graph.
Definition: node.h:143
nnvm::IndexedGraph::entry_id
uint32_t entry_id(const NodeEntry &e) const
Get a unique entry id between 0 and num_node_entries() for a given IndexedGraph::NodeEntry.
Definition: graph.h:150
nnvm::Graph::indexed_graph
const IndexedGraph & indexed_graph() const
Get an indexed graph of the current graph; if it does not exist, create it on demand.
nnvm::IndexedGraph::operator[]
const Node & operator[](uint32_t node_id) const
Get the corresponding Node structure for a given node_id.
Definition: graph.h:171
nnvm::Graph
Symbolic computation graph. This is the intermediate representation for optimization pass.
Definition: graph.h:47
nnvm::Node::inputs
std::vector< NodeEntry > inputs
inputs to this node
Definition: node.h:153
nnvm::NodeEntry::index
uint32_t index
index of output from the source.
Definition: node.h:67
nnvm::IndexedGraph::num_node_entries
size_t num_node_entries() const
Definition: graph.h:133
nnvm::IndexedGraph::entry_id
uint32_t entry_id(uint32_t node_id, uint32_t index) const
Get a unique entry id between 0 and num_node_entries() for a given IndexedGraph::NodeEntry.
Definition: graph.h:141
nnvm::IndexedGraph
Auxiliary data structure to index a graph. It maps Nodes in the graph to consecutive integers node_id...
Definition: graph.h:108
nnvm::IndexedGraph::Node::source
const nnvm::Node * source
pointer to the source node
Definition: graph.h:122
dmlc::array_view
Read only data structure to reference continuous memory region of array. Provide unified view for vec...
Definition: array_view.h:36
nnvm::IndexedGraph::node_id
uint32_t node_id(const nnvm::Node *node) const
Get the corresponding node id for a given Node in the IndexedGraph.
Definition: graph.h:165
base.h
Configuration of nnvm as well as basic data structure.
nnvm::IndexedGraph::NodeEntry::version
uint32_t version
version of the node
Definition: graph.h:117
nnvm::PostOrderDFSVisit
void PostOrderDFSVisit(const std::vector< GNode > &heads, FVisit fvisit, HashFunc hash, InDegree indegree, GetInput getinput)
Definition: graph.h:256
nnvm::Graph::attrs
std::unordered_map< std::string, std::shared_ptr< any > > attrs
Attributes of a graph. Note that each attribute is a shared pointer and can be shared across graphs.
Definition: graph.h:61
nnvm::IndexedGraph::Node::weak_ref
std::weak_ptr< nnvm::Node > weak_ref
weak reference to node
Definition: graph.h:128
nnvm::IndexedGraph::outputs
const std::vector< NodeEntry > & outputs() const
Definition: graph.h:185
nnvm::IndexedGraph::Node::control_deps
array_view< uint32_t > control_deps
control flow dependencies to the node
Definition: graph.h:126
nnvm::ObjectPtr
std::shared_ptr< Node > ObjectPtr
We always use ObjectPtr as the reference pointer to a node, so this alias can be changed if needed.
Definition: node.h:49
nnvm::IndexedGraph::IndexedGraph
IndexedGraph(const IndexedGraph &)=delete
symbolic.h
Symbolic graph construction API.
nnvm::IndexedGraph::Node::inputs
array_view< NodeEntry > inputs
inputs to the node
Definition: graph.h:124
nnvm::IndexedGraph::Node
Node data structure in IndexedGraph.
Definition: graph.h:120
nnvm::IndexedGraph::NodeEntry::index
uint32_t index
index of output from the source.
Definition: graph.h:115
nnvm::Graph::GetAttr
const T & GetAttr(const std::string &attr_name) const
Get the immutable attribute from attrs.
Definition: graph.h:230
nnvm::DFSVisit
void DFSVisit(const std::vector< NodeEntry > &heads, FVisit fvisit)
perform a Post Order DFS visit to each node in the graph. This order is deterministic and is also top...
Definition: graph.h:284
nnvm::IndexedGraph::input_nodes
const std::vector< uint32_t > & input_nodes() const
Definition: graph.h:179
nnvm::Graph::MoveCopyAttr
T MoveCopyAttr(const std::string &attr_name)
Get a move copy of the attribute, implement copy on write semantics. The content is moved if the refe...
Definition: graph.h:242
nnvm::IndexedGraph::entry_id
uint32_t entry_id(const nnvm::NodeEntry &e) const
Get a unique entry id between 0 and num_node_entries() for a given NodeEntry.
Definition: graph.h:157
nnvm::NodeEntry
an entry that represents output data from a node
Definition: node.h:52
nnvm::IndexedGraph::exist
bool exist(const nnvm::Node *node) const
Definition: graph.h:188
node.h
Graph node data structure.
nnvm::NodeEntry::node
ObjectPtr node
the source node of this data
Definition: node.h:65
nnvm::IndexedGraph::operator[]
const Node & operator[](const nnvm::Node *node) const
Get the corresponding Node structure.
Definition: graph.h:177
nnvm::IndexedGraph::NodeEntry
represents a data in the graph
Definition: graph.h:111
nnvm::IndexedGraph::num_nodes
size_t num_nodes() const
Definition: graph.h:131
nnvm::Graph::outputs
std::vector< NodeEntry > outputs
outputs of the computation graph.
Definition: graph.h:50
nnvm
Definition: base.h:35