mxnet
symbol.h
Go to the documentation of this file.
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

27 #ifndef MXNET_CPP_SYMBOL_H_
28 #define MXNET_CPP_SYMBOL_H_
29 
30 #include <map>
31 #include <string>
32 #include <vector>
33 #include "mxnet-cpp/base.h"
34 #include "mxnet-cpp/ndarray.h"
35 #include "mxnet-cpp/op_map.h"
36 
37 namespace mxnet {
38 namespace cpp {
39 
40 class Executor;
41 
45 struct SymBlob {
46  public:
50  SymBlob() : handle_(nullptr) {}
54  explicit SymBlob(SymbolHandle handle) : handle_(handle) {}
63 
64  private:
65  SymBlob(const SymBlob &);
66  SymBlob &operator=(const SymBlob &);
67 };
68 
72 class Symbol {
73  public:
74  Symbol() {}
79  explicit Symbol(SymbolHandle handle);
84  explicit Symbol(const char *name);
89  explicit Symbol(const std::string &name);
90  Symbol operator+(const Symbol &rhs) const;
91  Symbol operator-(const Symbol &rhs) const;
92  Symbol operator*(const Symbol &rhs) const;
93  Symbol operator/(const Symbol &rhs) const;
94  Symbol operator%(const Symbol &rhs) const;
95 
96  Symbol operator+(mx_float scalar) const;
97  Symbol operator-(mx_float scalar) const;
98  Symbol operator*(mx_float scalar) const;
99  Symbol operator/(mx_float scalar) const;
100  Symbol operator%(mx_float scalar) const;
101  Symbol Copy() const;
106  static Symbol Variable(const std::string &name = "");
107  Symbol operator[](int index);
108  Symbol operator[](const std::string &index);
113  static Symbol Group(const std::vector<Symbol> &symbols);
118  static Symbol Load(const std::string &file_name);
123  static Symbol LoadJSON(const std::string &json_str);
128  void Save(const std::string &file_name) const;
132  std::string ToJSON() const;
137  Symbol GetInternals() const;
141  SymbolHandle GetHandle() const { return blob_ptr_->handle_; }
150  Symbol(const std::string &operator_name, const std::string &name,
151  std::vector<const char *> input_keys,
152  std::vector<SymbolHandle> input_values,
153  std::vector<const char *> config_keys,
154  std::vector<const char *> config_values);
163  void InferShape(
164  const std::map<std::string, std::vector<mx_uint> > &arg_shapes,
165  std::vector<std::vector<mx_uint> > *in_shape,
166  std::vector<std::vector<mx_uint> > *aux_shape,
167  std::vector<std::vector<mx_uint> > *out_shape) const;
176  std::vector<std::string> ListArguments() const;
178  std::vector<std::string> ListOutputs() const;
180  std::vector<std::string> ListAuxiliaryStates() const;
195  void InferExecutorArrays(
196  const Context &context, std::vector<NDArray> *arg_arrays,
197  std::vector<NDArray> *grad_arrays, std::vector<OpReqType> *grad_reqs,
198  std::vector<NDArray> *aux_arrays,
199  const std::map<std::string, NDArray> &args_map,
200  const std::map<std::string, NDArray> &arg_grad_store =
201  std::map<std::string, NDArray>(),
202  const std::map<std::string, OpReqType> &grad_req_type =
203  std::map<std::string, OpReqType>(),
204  const std::map<std::string, NDArray> &aux_map =
205  std::map<std::string, NDArray>()) const;
213  void InferArgsMap(const Context &context,
214  std::map<std::string, NDArray> *args_map,
215  const std::map<std::string, NDArray> &known_args) const;
234  Executor *SimpleBind(const Context &context,
235  const std::map<std::string, NDArray> &args_map,
236  const std::map<std::string, NDArray> &arg_grad_store =
237  std::map<std::string, NDArray>(),
238  const std::map<std::string, OpReqType> &grad_req_type =
239  std::map<std::string, OpReqType>(),
240  const std::map<std::string, NDArray> &aux_map =
241  std::map<std::string, NDArray>());
260  Executor *Bind(const Context &context, const std::vector<NDArray> &arg_arrays,
261  const std::vector<NDArray> &grad_arrays,
262  const std::vector<OpReqType> &grad_reqs,
263  const std::vector<NDArray> &aux_arrays,
264  const std::map<std::string, Context> &group_to_ctx =
265  std::map<std::string, Context>(),
266  Executor *shared_exec = nullptr);
267 
268  private:
269  std::shared_ptr<SymBlob> blob_ptr_;
270  static OpMap*& op_map();
271 };
272 Symbol operator+(mx_float lhs, const Symbol &rhs);
273 Symbol operator-(mx_float lhs, const Symbol &rhs);
274 Symbol operator*(mx_float lhs, const Symbol &rhs);
275 Symbol operator/(mx_float lhs, const Symbol &rhs);
276 Symbol operator%(mx_float lhs, const Symbol &rhs);
277 } // namespace cpp
278 } // namespace mxnet
279 #endif // MXNET_CPP_SYMBOL_H_
Symbol operator/(mx_float lhs, const Symbol &rhs)
Symbol()
Definition: symbol.h:74
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name...
Definition: op_map.h:43
definition of OpMap
namespace of mxnet
Definition: base.h:127
Executor interface.
Definition: executor.h:45
SymBlob(SymbolHandle handle)
construct with SymbolHandle to store
Definition: symbol.h:54
SymbolHandle handle_
the SymbolHandle to store
Definition: symbol.h:62
struct to store SymbolHandle
Definition: symbol.h:45
Symbol operator%(mx_float lhs, const Symbol &rhs)
void * SymbolHandle
handle to a symbol that can be bound as an operator
Definition: c_api.h:73
SymBlob()
default constructor
Definition: symbol.h:50
~SymBlob()
destructor, free the SymbolHandle
Definition: symbol.h:58
Symbol operator+(mx_float lhs, const Symbol &rhs)
float mx_float
manually define float
Definition: c_api.h:60
Symbol operator-(mx_float lhs, const Symbol &rhs)
SymbolHandle GetHandle() const
Definition: symbol.h:141
Context interface.
Definition: ndarray.h:50
MXNET_DLL int MXSymbolFree(SymbolHandle symbol)
Free the symbol handle.
Symbol operator*(mx_float lhs, const Symbol &rhs)
Symbol interface.
Definition: symbol.h:72