Go to the documentation of this file.
30 #include <unordered_map>
31 #include <unordered_set>
61 std::unordered_map<std::string, std::shared_ptr<any> >
attrs;
69 inline const T&
GetAttr(
const std::string& attr_name)
const;
75 inline bool HasAttr(
const std::string& attr_name)
const;
96 mutable std::shared_ptr<const IndexedGraph> indexed_graph_;
131 inline size_t num_nodes() const {
  // Node count of this indexed graph: nodes_ holds exactly one entry per node.
  return this->nodes_.size();
}
142 return entry_rptr_[
node_id] + index;
179 inline const std::vector<uint32_t>& input_nodes() const {
  // Read-only reference to the stored input-node list (presumably node ids
  // into this indexed graph -- confirm); no copy is made for the caller.
  return this->input_nodes_;
}
182 return mutable_input_nodes_;
185 inline const std::vector<NodeEntry>& outputs() const {
  // Read-only reference to the graph's output entries; returned by const
  // reference so callers observe the stored vector without copying it.
  return this->outputs_;
}
201 std::vector<Node> nodes_;
203 std::vector<uint32_t> input_nodes_;
205 std::unordered_set<uint32_t> mutable_input_nodes_;
207 std::vector<NodeEntry> outputs_;
209 std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
211 std::vector<size_t> entry_rptr_;
213 std::vector<NodeEntry> input_entries_;
215 std::vector<uint32_t> control_deps_;
225 template <
typename FVisit>
226 inline void DFSVisit(
const std::vector<NodeEntry>& heads, FVisit fvisit);
229 template <
typename T>
231 auto it =
attrs.find(attr_name);
232 CHECK(it !=
attrs.end()) <<
"Cannot find attribute " << attr_name <<
" in the graph";
233 return nnvm::unsafe_get<T>(*it->second);
237 auto it =
attrs.find(attr_name);
238 return it !=
attrs.end();
241 template <
typename T>
243 auto it =
attrs.find(attr_name);
244 CHECK(it !=
attrs.end()) <<
"Cannot find attribute " << attr_name <<
" in the graph";
245 std::shared_ptr<any> sptr = it->second;
248 return std::move(nnvm::get<T>(*sptr));
250 return nnvm::get<T>(*sptr);
254 template <
typename GNode,
typename HashType,
typename FVisit,
typename HashFunc,
typename InDegree,
257 InDegree indegree, GetInput getinput) {
258 std::vector<std::pair<GNode, uint32_t> > stack;
259 std::unordered_set<HashType> visited;
260 for (
auto& head : heads) {
261 HashType head_hash = hash(head);
262 if (visited.count(head_hash) == 0) {
263 stack.push_back(std::make_pair(head, 0));
264 visited.insert(head_hash);
266 while (!stack.empty()) {
267 std::pair<GNode, uint32_t>& back = stack.back();
268 if (back.second == indegree(back.first)) {
272 const GNode& input = getinput(back.first, back.second++);
273 HashType input_hash = hash(input);
274 if (visited.count(input_hash) == 0) {
275 stack.push_back(std::make_pair(input, 0));
276 visited.insert(input_hash);
283 template <
typename FVisit>
284 inline void DFSVisit(
const std::vector<NodeEntry>& heads, FVisit fvisit) {
286 std::vector<GNode> head_nodes(heads.size());
287 std::transform(heads.begin(), heads.end(), head_nodes.begin(),
288 [](
const NodeEntry& e) -> GNode { return &e.node; });
289 PostOrderDFSVisit<GNode, Node*>(
290 head_nodes, [fvisit](GNode n) { fvisit(*n); },
291 [](GNode n) ->
Node* {
return n->get(); },
292 [](GNode n) -> uint32_t {
294 return (*n)->
inputs.size() + (*n)->control_deps.size();
296 [](GNode n, uint32_t index) -> GNode {
297 if (index < (*n)->inputs.size()) {
298 return &(*n)->inputs.at(index).node;
300 return &(*n)->control_deps.at(index - (*n)->inputs.size());
307 #endif // NNVM_GRAPH_H_
bool HasAttr(const std::string &attr_name) const
Check whether the graph has a specific attribute.
Definition: graph.h:236
uint32_t node_id
the source node id in the computation graph
Definition: graph.h:113
const std::unordered_set< uint32_t > & mutable_input_nodes() const
Definition: graph.h:181
Node represents an operation in a computation graph.
Definition: node.h:143
uint32_t entry_id(const NodeEntry &e) const
Get a unique entry id between 0 and num_node_entries() for a given IndexedGraph::NodeEntry.
Definition: graph.h:150
const IndexedGraph & indexed_graph() const
Get the indexed graph of the current graph, creating it on demand if it does not exist.
const Node & operator[](uint32_t node_id) const
Get the corresponding Node structure for a given node_id.
Definition: graph.h:171
Symbolic computation graph. This is the intermediate representation for optimization pass.
Definition: graph.h:47
std::vector< NodeEntry > inputs
inputs to this node
Definition: node.h:153
uint32_t index
index of output from the source.
Definition: node.h:67
size_t num_node_entries() const
Definition: graph.h:133
uint32_t entry_id(uint32_t node_id, uint32_t index) const
Get a unique entry id between 0 and num_node_entries() for a given IndexedGraph::NodeEntry.
Definition: graph.h:141
Auxiliary data structure to index a graph. It maps Nodes in the graph to consecutive integers node_id...
Definition: graph.h:108
const nnvm::Node * source
pointer to the source node
Definition: graph.h:122
Read only data structure to reference continuous memory region of array. Provide unified view for vec...
Definition: array_view.h:36
uint32_t node_id(const nnvm::Node *node) const
Get the corresponding node id for a given Node in the IndexedGraph.
Definition: graph.h:165
Configuration of nnvm as well as basic data structure.
uint32_t version
version of the node
Definition: graph.h:117
void PostOrderDFSVisit(const std::vector< GNode > &heads, FVisit fvisit, HashFunc hash, InDegree indegree, GetInput getinput)
Definition: graph.h:256
std::unordered_map< std::string, std::shared_ptr< any > > attrs
Attributes of the graph. Note that each attribute is held by a shared pointer and can be shared across graphs.
Definition: graph.h:61
std::weak_ptr< nnvm::Node > weak_ref
weak reference to node
Definition: graph.h:128
const std::vector< NodeEntry > & outputs() const
Definition: graph.h:185
array_view< uint32_t > control_deps
control flow dependencies to the node
Definition: graph.h:126
std::shared_ptr< Node > ObjectPtr
ObjectPtr is always used as the reference pointer to a node, so that this alias can be changed in one place if needed.
Definition: node.h:49
IndexedGraph(const IndexedGraph &)=delete
Symbolic graph construction API.
array_view< NodeEntry > inputs
inputs to the node
Definition: graph.h:124
Node data structure in IndexedGraph.
Definition: graph.h:120
uint32_t index
index of output from the source.
Definition: graph.h:115
const T & GetAttr(const std::string &attr_name) const
Get the immutable attribute from attrs.
Definition: graph.h:230
void DFSVisit(const std::vector< NodeEntry > &heads, FVisit fvisit)
Perform a post-order DFS visit of each node in the graph. This order is deterministic and is also topologically sorted.
Definition: graph.h:284
const std::vector< uint32_t > & input_nodes() const
Definition: graph.h:179
T MoveCopyAttr(const std::string &attr_name)
Get a move copy of the attribute, implement copy on write semantics. The content is moved if the refe...
Definition: graph.h:242
uint32_t entry_id(const nnvm::NodeEntry &e) const
Get a unique entry id between 0 and num_node_entries() for a given NodeEntry.
Definition: graph.h:157
an entry that represents output data from a node
Definition: node.h:52
bool exist(const nnvm::Node *node) const
Definition: graph.h:188
Graph node data structure.
ObjectPtr node
the source node of this data
Definition: node.h:65
const Node & operator[](const nnvm::Node *node) const
Get the corresponding Node structure.
Definition: graph.h:177
represents a data in the graph
Definition: graph.h:111
size_t num_nodes() const
Definition: graph.h:131
std::vector< NodeEntry > outputs
outputs of the computation graph.
Definition: graph.h:50