mxnet
symbol.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_SYMBOL_H_
27 #define MXNET_CPP_SYMBOL_H_
28 
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 
39 class Executor;
40 
44 struct SymBlob {
45  public:
49  SymBlob() : handle_(nullptr) {}
53  explicit SymBlob(SymbolHandle handle) : handle_(handle) {}
62 
63  private:
64  SymBlob(const SymBlob &);
65  SymBlob &operator=(const SymBlob &);
66 };
67 
71 class Symbol {
72  public:
73  Symbol() {}
78  explicit Symbol(SymbolHandle handle);
83  explicit Symbol(const char *name);
88  explicit Symbol(const std::string &name);
89  Symbol operator+(const Symbol &rhs) const;
90  Symbol operator-(const Symbol &rhs) const;
91  Symbol operator*(const Symbol &rhs) const;
92  Symbol operator/(const Symbol &rhs) const;
93  Symbol operator%(const Symbol &rhs) const;
94 
95  Symbol operator+(mx_float scalar) const;
96  Symbol operator-(mx_float scalar) const;
97  Symbol operator*(mx_float scalar) const;
98  Symbol operator/(mx_float scalar) const;
99  Symbol operator%(mx_float scalar) const;
100  Symbol Copy() const;
105  static Symbol Variable(const std::string &name = "");
106  Symbol operator[](int index);
107  Symbol operator[](const std::string &index);
112  static Symbol Group(const std::vector<Symbol> &symbols);
117  static Symbol Load(const std::string &file_name);
122  static Symbol LoadJSON(const std::string &json_str);
127  void Save(const std::string &file_name) const;
131  std::string ToJSON() const;
136  Symbol GetInternals() const;
140  SymbolHandle GetHandle() const { return blob_ptr_->handle_; }
149  Symbol(const std::string &operator_name, const std::string &name,
150  std::vector<const char *> input_keys,
151  std::vector<SymbolHandle> input_values,
152  std::vector<const char *> config_keys,
153  std::vector<const char *> config_values);
162  void InferShape(
163  const std::map<std::string, std::vector<mx_uint> > &arg_shapes,
164  std::vector<std::vector<mx_uint> > *in_shape,
165  std::vector<std::vector<mx_uint> > *aux_shape,
166  std::vector<std::vector<mx_uint> > *out_shape) const;
175  std::vector<std::string> ListArguments() const;
177  std::vector<std::string> ListOutputs() const;
179  std::vector<std::string> ListAuxiliaryStates() const;
194  void InferExecutorArrays(
195  const Context &context, std::vector<NDArray> *arg_arrays,
196  std::vector<NDArray> *grad_arrays, std::vector<OpReqType> *grad_reqs,
197  std::vector<NDArray> *aux_arrays,
198  const std::map<std::string, NDArray> &args_map,
199  const std::map<std::string, NDArray> &arg_grad_store =
200  std::map<std::string, NDArray>(),
201  const std::map<std::string, OpReqType> &grad_req_type =
202  std::map<std::string, OpReqType>(),
203  const std::map<std::string, NDArray> &aux_map =
204  std::map<std::string, NDArray>()) const;
212  void InferArgsMap(const Context &context,
213  std::map<std::string, NDArray> *args_map,
214  const std::map<std::string, NDArray> &known_args) const;
233  Executor *SimpleBind(const Context &context,
234  const std::map<std::string, NDArray> &args_map,
235  const std::map<std::string, NDArray> &arg_grad_store =
236  std::map<std::string, NDArray>(),
237  const std::map<std::string, OpReqType> &grad_req_type =
238  std::map<std::string, OpReqType>(),
239  const std::map<std::string, NDArray> &aux_map =
240  std::map<std::string, NDArray>());
259  Executor *Bind(const Context &context, const std::vector<NDArray> &arg_arrays,
260  const std::vector<NDArray> &grad_arrays,
261  const std::vector<OpReqType> &grad_reqs,
262  const std::vector<NDArray> &aux_arrays,
263  const std::map<std::string, Context> &group_to_ctx =
264  std::map<std::string, Context>(),
265  Executor *shared_exec = nullptr);
266 
267  private:
268  std::shared_ptr<SymBlob> blob_ptr_;
269  static OpMap*& op_map();
270 };
271 Symbol operator+(mx_float lhs, const Symbol &rhs);
272 Symbol operator-(mx_float lhs, const Symbol &rhs);
273 Symbol operator*(mx_float lhs, const Symbol &rhs);
274 Symbol operator/(mx_float lhs, const Symbol &rhs);
275 Symbol operator%(mx_float lhs, const Symbol &rhs);
276 } // namespace cpp
277 } // namespace mxnet
278 #endif // MXNET_CPP_SYMBOL_H_
Symbol operator/(mx_float lhs, const Symbol &rhs)
Symbol()
Definition: symbol.h:73
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name...
Definition: op_map.h:42
definition of OpMap
namespace of mxnet
Definition: base.h:126
Executor interface.
Definition: executor.h:44
SymBlob(SymbolHandle handle)
construct with the SymbolHandle to store
Definition: symbol.h:53
SymbolHandle handle_
the SymbolHandle to store
Definition: symbol.h:61
struct to store SymbolHandle
Definition: symbol.h:44
Symbol operator%(mx_float lhs, const Symbol &rhs)
void * SymbolHandle
handle to a symbol that can be bound as an operator
Definition: c_api.h:72
SymBlob()
default constructor
Definition: symbol.h:49
~SymBlob()
destructor; frees the SymbolHandle
Definition: symbol.h:57
Symbol operator+(mx_float lhs, const Symbol &rhs)
float mx_float
manually define float
Definition: c_api.h:59
Symbol operator-(mx_float lhs, const Symbol &rhs)
SymbolHandle GetHandle() const
Definition: symbol.h:140
Context interface.
Definition: ndarray.h:49
MXNET_DLL int MXSymbolFree(SymbolHandle symbol)
Free the symbol handle.
Symbol operator*(mx_float lhs, const Symbol &rhs)
Symbol interface.
Definition: symbol.h:71