mxnet
op_attr_types.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef NNVM_OP_ATTR_TYPES_H_
25 #define NNVM_OP_ATTR_TYPES_H_
26 
27 #include <functional>
28 #include <string>
29 #include <unordered_map>
30 #include <utility>
31 #include <vector>
32 
33 #include "base.h"
34 #include "layout.h"
35 #include "node.h"
36 #include "tuple.h"
37 
38 namespace nnvm {
39 
40 // These types are optional attributes in each operator.
41 // Each attribute can be required by some passes.
42 
52 using FListInputNames = std::function<std::vector<std::string>(const NodeAttrs& attrs)>;
53 
64 using FNumVisibleOutputs = std::function<uint32_t(const NodeAttrs& attrs)>;
65 
75 using FListOutputNames = std::function<std::vector<std::string>(const NodeAttrs& attrs)>;
76 
85 using FMutateInputs = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;
86 
92 template <typename AttrType>
93 using FInferNodeEntryAttr = std::function<bool(
94  const NodeAttrs& attrs, std::vector<AttrType>* in_attrs, std::vector<AttrType>* out_attrs)>;
95 
103 using FGetAttrDict =
104  std::function<std::unordered_map<std::string, std::string>(const NodeAttrs& attrs)>;
105 
117 
126 
135 using TIsBackward = bool;
136 
146 using TIsGhost = bool;
147 
157 using FInplaceOption = std::function<std::vector<std::pair<int, int> >(const NodeAttrs& attrs)>;
158 
169 using FInplaceIdentity = std::function<std::vector<bool>(const NodeAttrs& attrs)>;
170 
180 using FIgnoreInputs = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;
181 
191 using FGradient = std::function<std::vector<NodeEntry>(const ObjectPtr& nodeptr,
192  const std::vector<NodeEntry>& out_grads)>;
193 
202  std::function<void(const NodeAttrs& attrs, ObjectPtr var, const int index)>;
203 
223 using FCorrectLayout =
224  std::function<bool(const NodeAttrs& attrs, std::vector<Layout>* ilayouts,
225  const std::vector<Layout>* last_ilayouts, std::vector<Layout>* olayouts)>;
226 
237 using FInputGraph = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;
238 
239 } // namespace nnvm
240 
241 #endif // NNVM_OP_ATTR_TYPES_H_
nnvm::FSetInputVarAttrOnCompose
std::function< void(const NodeAttrs &attrs, ObjectPtr var, const int index)> FSetInputVarAttrOnCompose
Set the attributes of input variable. Usually used for setting initialization or weight decay.
Definition: op_attr_types.h:202
nnvm::FNumVisibleOutputs
std::function< uint32_t(const NodeAttrs &attrs)> FNumVisibleOutputs
Return the number of outputs visible to the user.
Definition: op_attr_types.h:64
nnvm::TIsBackward
bool TIsBackward
Whether this op is an explicit backward operator. If TIsBackward is true:
Definition: op_attr_types.h:135
nnvm::FInferShape
FInferNodeEntryAttr< TShape > FInferShape
Shape inference function. Update the shapes given the input shape information. TShape....
Definition: op_attr_types.h:116
nnvm::FGradient
std::function< std::vector< NodeEntry >(const ObjectPtr &nodeptr, const std::vector< NodeEntry > &out_grads)> FGradient
Get the gradient node of the op node. This function generates the backward graph of the node.
Definition: op_attr_types.h:192
tuple.h
Data structure Tuple and TShape to store dynamic sized shapes.
nnvm::FListOutputNames
std::function< std::vector< std::string >(const NodeAttrs &attrs)> FListOutputNames
Return the list of output argument names of each operator.
Definition: op_attr_types.h:75
base.h
Configuration of nnvm as well as basic data structure.
nnvm::FMutateInputs
std::function< std::vector< uint32_t >(const NodeAttrs &attrs)> FMutateInputs
Check whether the operator will mutate the k-th input.
Definition: op_attr_types.h:85
nnvm::FIgnoreInputs
std::function< std::vector< uint32_t >(const NodeAttrs &attrs)> FIgnoreInputs
Get the list of inputs in the op whose content is not actually used by the operator. These are dummy inputs...
Definition: op_attr_types.h:180
nnvm::FInplaceIdentity
std::function< std::vector< bool >(const NodeAttrs &attrs)> FInplaceIdentity
Get whether the inplace option is an identity. This function enables inplace optimization even when input r...
Definition: op_attr_types.h:169
nnvm::NodeAttrs
The attributes of the current operation node. These are usually additional parameters like axis, ...
Definition: node.h:107
nnvm::ObjectPtr
std::shared_ptr< Node > ObjectPtr
We always use ObjectPtr as a reference pointer to the node, so this alias can be changed if needed.
Definition: node.h:49
nnvm::FListInputNames
std::function< std::vector< std::string >(const NodeAttrs &attrs)> FListInputNames
Return the list of input argument names of each operator.
Definition: op_attr_types.h:52
nnvm::TIsGhost
bool TIsGhost
Whether this op is a ghost node. If TIsGhost is true:
Definition: op_attr_types.h:146
nnvm::FInputGraph
std::function< std::vector< uint32_t >(const NodeAttrs &attrs)> FInputGraph
Get a list of inputs that represent graphs instead of data. Normally, input symbols are considered as...
Definition: op_attr_types.h:237
nnvm::FInplaceOption
std::function< std::vector< std::pair< int, int > >(const NodeAttrs &attrs)> FInplaceOption
Get possible inplace options. This function enables the optimization that reuses the memory of inputs for outputs.
Definition: op_attr_types.h:157
layout.h
Layout expression. The layout is composed of upper-case letters, lower-case letters and numbers, ...
nnvm::FGetAttrDict
std::function< std::unordered_map< std::string, std::string >(const NodeAttrs &attrs)> FGetAttrDict
Get attribute dictionary from node.
Definition: op_attr_types.h:104
nnvm::FInferType
FInferNodeEntryAttr< int > FInferType
Type inference function. Update the type given the known type information.
Definition: op_attr_types.h:125
node.h
Graph node data structure.
nnvm::FInferNodeEntryAttr
std::function< bool(const NodeAttrs &attrs, std::vector< AttrType > *in_attrs, std::vector< AttrType > *out_attrs)> FInferNodeEntryAttr
Inference function of certain type.
Definition: op_attr_types.h:94
nnvm
Definition: base.h:35
nnvm::FCorrectLayout
std::function< bool(const NodeAttrs &attrs, std::vector< Layout > *ilayouts, const std::vector< Layout > *last_ilayouts, std::vector< Layout > *olayouts)> FCorrectLayout
Function to infer and correct the layout of a node. See Layout for the layout convention.
Definition: op_attr_types.h:225