mxnet
operator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_OPERATOR_H_
27 #define MXNET_CPP_OPERATOR_H_
28 
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/op_map.h"
#include "mxnet-cpp/symbol.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 class Mxnet;
42 class Operator {
43  public:
48  explicit Operator(const std::string& operator_name);
49  Operator& operator=(const Operator& rhs);
56  template <typename T>
57  Operator& SetParam(const std::string& name, const T& value) {
58  std::string value_str;
59  std::stringstream ss;
60  ss << value;
61  ss >> value_str;
62 
63  params_[name] = value_str;
64  return *this;
65  }
72  template <typename T>
73  Operator& SetParam(int pos, const T& value) {
74  std::string value_str;
75  std::stringstream ss;
76  ss << value;
77  ss >> value_str;
78 
79  params_[arg_names_[pos]] = value_str;
80  return *this;
81  }
88  Operator& SetInput(const std::string& name, const Symbol& symbol);
93  template <int N = 0>
94  void PushInput(const Symbol& symbol) {
95  input_symbols_.push_back(symbol.GetHandle());
96  }
102  return *this;
103  }
109  Operator& operator()(const Symbol& symbol) {
110  input_symbols_.push_back(symbol.GetHandle());
111  return *this;
112  }
118  Operator& operator()(const std::vector<Symbol>& symbols) {
119  for (auto& s : symbols) {
120  input_symbols_.push_back(s.GetHandle());
121  }
122  return *this;
123  }
129  Symbol CreateSymbol(const std::string& name = "");
130 
137  Operator& SetInput(const std::string& name, const NDArray& ndarray);
142  template <int N = 0>
143  Operator& PushInput(const NDArray& ndarray) {
144  input_ndarrays_.push_back(ndarray.GetHandle());
145  return *this;
146  }
150  template <class T, class... Args, int N = 0>
151  Operator& PushInput(const T& t, Args... args) {
152  SetParam(N, t);
153  PushInput<Args..., N + 1>(args...);
154  return *this;
155  }
159  template <class T, int N = 0>
160  Operator& PushInput(const T& t) {
161  SetParam(N, t);
162  return *this;
163  }
169  Operator& operator()(const NDArray& ndarray) {
170  input_ndarrays_.push_back(ndarray.GetHandle());
171  return *this;
172  }
178  Operator& operator()(const std::vector<NDArray>& ndarrays) {
179  for (auto& s : ndarrays) {
180  input_ndarrays_.push_back(s.GetHandle());
181  }
182  return *this;
183  }
188  template <typename... Args>
189  Operator& operator()(Args... args) {
190  PushInput(args...);
191  return *this;
192  }
193  std::vector<NDArray> Invoke();
194  void Invoke(NDArray& output);
195  void Invoke(std::vector<NDArray>& outputs);
196 
197  private:
198  std::map<std::string, std::string> params_desc_;
199  bool variable_params_ = false;
200  std::map<std::string, std::string> params_;
201  std::vector<SymbolHandle> input_symbols_;
202  std::vector<NDArrayHandle> input_ndarrays_;
203  std::vector<std::string> input_keys_;
204  std::vector<std::string> arg_names_;
205  AtomicSymbolCreator handle_;
206  static OpMap*& op_map();
207 };
208 } // namespace cpp
209 } // namespace mxnet
210 
211 #endif // MXNET_CPP_OPERATOR_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::cpp::Operator::SetInput
Operator & SetInput(const std::string &name, const Symbol &symbol)
add an input symbol
mxnet::cpp::Operator::PushInput
void PushInput(const Symbol &symbol)
add an input symbol
Definition: operator.h:94
mxnet::cpp::Operator::SetParam
Operator & SetParam(int pos, const T &value)
set config parameters from positional inputs
Definition: operator.h:73
mxnet::cpp::Operator::PushInput
Operator & PushInput(const T &t, Args... args)
add positional inputs
Definition: operator.h:151
mxnet::cpp::Operator::operator=
Operator & operator=(const Operator &rhs)
mxnet::cpp::Operator::operator()
Operator & operator()(const std::vector< Symbol > &symbols)
add a list of input symbols
Definition: operator.h:118
mxnet::cpp::Operator
Operator interface.
Definition: operator.h:42
mxnet::cpp::NDArray
NDArray interface.
Definition: ndarray.h:122
mxnet::cpp::Operator::PushInput
Operator & PushInput(const T &t)
add the last positional input
Definition: operator.h:160
mxnet::cpp::Operator::operator()
Operator & operator()(const std::vector< NDArray > &ndarrays)
add a list of input ndarrays
Definition: operator.h:178
mxnet::cpp::OpMap
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name....
Definition: op_map.h:42
mxnet::cpp::Operator::Invoke
std::vector< NDArray > Invoke()
mxnet::cpp::Operator::SetParam
Operator & SetParam(const std::string &name, const T &value)
set config parameters
Definition: operator.h:57
symbol.h
definition of symbol
AtomicSymbolCreator
void * AtomicSymbolCreator
handle to a function that takes param and creates symbol
Definition: c_api.h:78
mxnet::cpp::Operator::CreateSymbol
Symbol CreateSymbol(const std::string &name="")
create a Symbol from the current operator
mxnet::cpp::Operator::operator()
Operator & operator()(const Symbol &symbol)
add input symbols
Definition: operator.h:109
mxnet::cpp::Operator::operator()
Operator & operator()(const NDArray &ndarray)
add input ndarrays
Definition: operator.h:169
mxnet::cpp::Operator::operator()
Operator & operator()(Args... args)
add positional inputs (all arguments are forwarded to PushInput)
Definition: operator.h:189
mxnet::cpp::Operator::Operator
Operator(const std::string &operator_name)
Operator constructor.
mxnet::cpp::NDArray::GetHandle
NDArrayHandle GetHandle() const
Definition: ndarray.h:475
mxnet::cpp::Symbol::GetHandle
SymbolHandle GetHandle() const
Definition: symbol.h:142
base.h
base definitions for mxnetcpp
mxnet::cpp::Operator::operator()
Operator & operator()()
no-argument call form; adds no inputs and returns *this
Definition: operator.h:101
mxnet::cpp::Operator::PushInput
Operator & PushInput(const NDArray &ndarray)
add an input ndarray
Definition: operator.h:143
mxnet::cpp::Symbol
Symbol interface.
Definition: symbol.h:73
op_map.h
definition of OpMap