1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
 *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
24 #ifndef NNVM_OP_ATTR_TYPES_H_
25 #define NNVM_OP_ATTR_TYPES_H_
27 #include <vector>
28 #include <string>
29 #include <utility>
30 #include <functional>
31 #include <unordered_map>
32 #include "base.h"
33 #include "node.h"
34 #include "tuple.h"
35 #include "layout.h"
37 namespace nnvm {
39 // These types are optional attributes in each operator.
40 // Each attribute can be required by some passes.
51 using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
63 using FNumVisibleOutputs = std::function<uint32_t (const NodeAttrs& attrs)>;
74 using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
84 using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attrs)>;
91 template<typename AttrType>
92 using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
93  std::vector<AttrType> *in_attrs,
94  std::vector<AttrType> *out_attrs)>;
103 using FGetAttrDict = std::function<
104  std::unordered_map<std::string, std::string>
105  (const NodeAttrs& attrs)>;
136 using TIsBackward = bool;
147 using TIsGhost = bool;
158 using FInplaceOption = std::function<
159  std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
171 using FInplaceIdentity = std::function<std::vector<bool> (const NodeAttrs& attrs)>;
182 using FIgnoreInputs = std::function<
183  std::vector<uint32_t> (const NodeAttrs& attrs)>;
194 using FGradient = std::function<std::vector<NodeEntry>(
195  const ObjectPtr& nodeptr,
196  const std::vector<NodeEntry>& out_grads)>;
205 using FSetInputVarAttrOnCompose = std::function<void(
206  const NodeAttrs& attrs,
207  ObjectPtr var,
208  const int index)>;
229 using FCorrectLayout = std::function<bool(
230  const NodeAttrs& attrs,
231  std::vector<Layout> *ilayouts,
232  const std::vector<Layout> *last_ilayouts,
233  std::vector<Layout> *olayouts)>;
245 using FInputGraph = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;
247 } // namespace nnvm
249 #endif // NNVM_OP_ATTR_TYPES_H_
