mxnet
symbol.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_SYMBOL_H_
27 #define MXNET_CPP_SYMBOL_H_
28 
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 
39 class Executor;
40 
44 struct SymBlob {
45  public:
49  SymBlob() : handle_(nullptr) {}
53  explicit SymBlob(SymbolHandle handle) : handle_(handle) {}
59  }
64 
65  private:
66  SymBlob(const SymBlob&);
67  SymBlob& operator=(const SymBlob&);
68 };
69 
73 class Symbol {
74  public:
75  Symbol() {}
80  explicit Symbol(SymbolHandle handle);
85  explicit Symbol(const char* name);
90  explicit Symbol(const std::string& name);
91  Symbol operator+(const Symbol& rhs) const;
92  Symbol operator-(const Symbol& rhs) const;
93  Symbol operator*(const Symbol& rhs) const;
94  Symbol operator/(const Symbol& rhs) const;
95  Symbol operator%(const Symbol& rhs) const;
96 
102  Symbol Copy() const;
107  static Symbol Variable(const std::string& name = "");
108  Symbol operator[](int index);
109  Symbol operator[](const std::string& index);
114  static Symbol Group(const std::vector<Symbol>& symbols);
119  static Symbol Load(const std::string& file_name);
124  static Symbol LoadJSON(const std::string& json_str);
129  void Save(const std::string& file_name) const;
133  std::string ToJSON() const;
138  Symbol GetInternals() const;
143  return (blob_ptr_) ? blob_ptr_->handle_ : nullptr;
144  }
153  Symbol(const std::string& operator_name,
154  const std::string& name,
155  std::vector<const char*> input_keys,
156  std::vector<SymbolHandle> input_values,
157  std::vector<const char*> config_keys,
158  std::vector<const char*> config_values);
167  void InferShape(const std::map<std::string, std::vector<mx_uint> >& arg_shapes,
168  std::vector<std::vector<mx_uint> >* in_shape,
169  std::vector<std::vector<mx_uint> >* aux_shape,
170  std::vector<std::vector<mx_uint> >* out_shape) const;
179  std::vector<std::string> ListArguments() const;
181  std::vector<std::string> ListInputs() const;
183  std::vector<std::string> ListOutputs() const;
185  std::vector<std::string> ListAuxiliaryStates() const;
187  std::map<std::string, std::string> ListAttributes() const;
193  void SetAttribute(const std::string& key, const std::string& value);
198  void SetAttributes(const std::map<std::string, std::string>& attrs);
200  mx_uint GetNumOutputs() const;
202  mxnet::cpp::Symbol GetBackendSymbol(const std::string& backendName) const;
204  std::string GetName() const;
219  void InferExecutorArrays(
220  const Context& context,
221  std::vector<NDArray>* arg_arrays,
222  std::vector<NDArray>* grad_arrays,
223  std::vector<OpReqType>* grad_reqs,
224  std::vector<NDArray>* aux_arrays,
225  const std::map<std::string, NDArray>& args_map,
226  const std::map<std::string, NDArray>& arg_grad_store = std::map<std::string, NDArray>(),
227  const std::map<std::string, OpReqType>& grad_req_type = std::map<std::string, OpReqType>(),
228  const std::map<std::string, NDArray>& aux_map = std::map<std::string, NDArray>()) const;
236  void InferArgsMap(const Context& context,
237  std::map<std::string, NDArray>* args_map,
238  const std::map<std::string, NDArray>& known_args) const;
258  const Context& context,
259  const std::map<std::string, NDArray>& args_map,
260  const std::map<std::string, NDArray>& arg_grad_store = std::map<std::string, NDArray>(),
261  const std::map<std::string, OpReqType>& grad_req_type = std::map<std::string, OpReqType>(),
262  const std::map<std::string, NDArray>& aux_map = std::map<std::string, NDArray>());
281  Executor* Bind(
282  const Context& context,
283  const std::vector<NDArray>& arg_arrays,
284  const std::vector<NDArray>& grad_arrays,
285  const std::vector<OpReqType>& grad_reqs,
286  const std::vector<NDArray>& aux_arrays,
287  const std::map<std::string, Context>& group_to_ctx = std::map<std::string, Context>(),
288  Executor* shared_exec = nullptr);
289 
290  private:
291  std::shared_ptr<SymBlob> blob_ptr_;
292  static OpMap*& op_map();
293 };
294 Symbol operator+(mx_float lhs, const Symbol& rhs);
295 Symbol operator-(mx_float lhs, const Symbol& rhs);
296 Symbol operator*(mx_float lhs, const Symbol& rhs);
297 Symbol operator/(mx_float lhs, const Symbol& rhs);
298 Symbol operator%(mx_float lhs, const Symbol& rhs);
299 } // namespace cpp
300 } // namespace mxnet
301 #endif // MXNET_CPP_SYMBOL_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::cpp::Symbol::ListOutputs
std::vector< std::string > ListOutputs() const
mxnet::cpp::SymBlob::SymBlob
SymBlob(SymbolHandle handle)
construct with SymbolHandle to store
Definition: symbol.h:53
mxnet::cpp::operator+
Symbol operator+(mx_float lhs, const Symbol &rhs)
mshadow::expr::scalar
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:103
mxnet::cpp::Symbol::InferArgsMap
void InferArgsMap(const Context &context, std::map< std::string, NDArray > *args_map, const std::map< std::string, NDArray > &known_args) const
infer and construct all the input argument arrays to bind to the executor by providing some known arguments.
MXSymbolFree
MXNET_DLL int MXSymbolFree(SymbolHandle symbol)
Free the symbol handle.
mxnet::cpp::Symbol::ToJSON
std::string ToJSON() const
save Symbol into a JSON string
mxnet::cpp::Symbol::operator-
Symbol operator-(const Symbol &rhs) const
mxnet::cpp::Symbol::ListArguments
std::vector< std::string > ListArguments() const
List the argument names.
mxnet::cpp::Symbol::Save
void Save(const std::string &file_name) const
save Symbol to a file
mxnet::cpp::operator-
Symbol operator-(mx_float lhs, const Symbol &rhs)
mxnet::cpp::Symbol::InferExecutorArrays
void InferExecutorArrays(const Context &context, std::vector< NDArray > *arg_arrays, std::vector< NDArray > *grad_arrays, std::vector< OpReqType > *grad_reqs, std::vector< NDArray > *aux_arrays, const std::map< std::string, NDArray > &args_map, const std::map< std::string, NDArray > &arg_grad_store=std::map< std::string, NDArray >(), const std::map< std::string, OpReqType > &grad_req_type=std::map< std::string, OpReqType >(), const std::map< std::string, NDArray > &aux_map=std::map< std::string, NDArray >()) const
infer and construct all the arrays to bind to the executor by providing some known arrays.
mxnet::cpp::Symbol::ListInputs
std::vector< std::string > ListInputs() const
mxnet::cpp::SymBlob::~SymBlob
~SymBlob()
destructor, free the SymbolHandle
Definition: symbol.h:57
mxnet::cpp::Symbol::operator*
Symbol operator*(const Symbol &rhs) const
mxnet::cpp::Symbol::SetAttribute
void SetAttribute(const std::string &key, const std::string &value)
set a key-value attribute on the symbol
ndarray.h
definition of ndarray
mxnet::cpp::Context
Context interface.
Definition: ndarray.h:45
mxnet::cpp::Symbol::ListAuxiliaryStates
std::vector< std::string > ListAuxiliaryStates() const
nnvm::Symbol
Symbol is a helper class used to represent an operator node in a Graph.
Definition: symbolic.h:50
mxnet::cpp::operator*
Symbol operator*(mx_float lhs, const Symbol &rhs)
mxnet::cpp::OpMap
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name....
Definition: op_map.h:42
mxnet::cpp::Symbol::GetName
std::string GetName() const
mx_float
float mx_float
manually define float
Definition: c_api.h:67
mxnet::cpp::Symbol::Bind
Executor * Bind(const Context &context, const std::vector< NDArray > &arg_arrays, const std::vector< NDArray > &grad_arrays, const std::vector< OpReqType > &grad_reqs, const std::vector< NDArray > &aux_arrays, const std::map< std::string, Context > &group_to_ctx=std::map< std::string, Context >(), Executor *shared_exec=nullptr)
Create an executor by binding the symbol with a context and arguments. If the user does not want to compute the grad...
mxnet::cpp::operator/
Symbol operator/(mx_float lhs, const Symbol &rhs)
mxnet::cpp::Symbol::operator+
Symbol operator+(const Symbol &rhs) const
mxnet::cpp::SymBlob::SymBlob
SymBlob()
default constructor
Definition: symbol.h:49
mxnet::cpp::Symbol::Load
static Symbol Load(const std::string &file_name)
load Symbol from a JSON file
mxnet::cpp::Symbol::operator/
Symbol operator/(const Symbol &rhs) const
mxnet::cpp::Symbol::GetNumOutputs
mx_uint GetNumOutputs() const
mxnet::cpp::Symbol::GetInternals
Symbol GetInternals() const
Returns the symbol whose outputs are all the internals.
mxnet::cpp::Symbol::GetBackendSymbol
mxnet::cpp::Symbol GetBackendSymbol(const std::string &backendName) const
SymbolHandle
void * SymbolHandle
handle to a symbol that can be bound as an operator
Definition: c_api.h:82
mxnet::cpp::Symbol::Symbol
Symbol()
Definition: symbol.h:75
mxnet::cpp::SymBlob::handle_
SymbolHandle handle_
the SymbolHandle to store
Definition: symbol.h:63
mxnet::cpp::Symbol::ListAttributes
std::map< std::string, std::string > ListAttributes() const
mxnet::cpp::Executor
Executor interface.
Definition: executor.h:45
mxnet::cpp::Symbol::GetHandle
SymbolHandle GetHandle() const
Definition: symbol.h:142
mxnet::cpp::Symbol::operator[]
Symbol operator[](int index)
mxnet::cpp::SymBlob
struct to store SymbolHandle
Definition: symbol.h:44
mxnet::cpp::Symbol::SimpleBind
Executor * SimpleBind(const Context &context, const std::map< std::string, NDArray > &args_map, const std::map< std::string, NDArray > &arg_grad_store=std::map< std::string, NDArray >(), const std::map< std::string, OpReqType > &grad_req_type=std::map< std::string, OpReqType >(), const std::map< std::string, NDArray > &aux_map=std::map< std::string, NDArray >())
Create an executor by binding the symbol with a context and arguments. If the user does not want to compute the grad...
mxnet::cpp::Symbol::SetAttributes
void SetAttributes(const std::map< std::string, std::string > &attrs)
set a series of key-value attributes on the symbol
mxnet::cpp::Symbol::operator%
Symbol operator%(const Symbol &rhs) const
base.h
base definitions for mxnetcpp
mxnet::cpp::operator%
Symbol operator%(mx_float lhs, const Symbol &rhs)
mxnet::cpp::Symbol::Variable
static Symbol Variable(const std::string &name="")
construct a variable Symbol
mx_uint
uint32_t mx_uint
manually define unsigned int
Definition: c_api.h:65
mxnet::cpp::Symbol::Group
static Symbol Group(const std::vector< Symbol > &symbols)
Create a symbol that groups symbols together.
mxnet::cpp::Symbol::LoadJSON
static Symbol LoadJSON(const std::string &json_str)
load Symbol from a JSON string
mxnet::cpp::Symbol
Symbol interface.
Definition: symbol.h:73
op_map.h
definition of OpMap
mxnet::cpp::Symbol::InferShape
void InferShape(const std::map< std::string, std::vector< mx_uint > > &arg_shapes, std::vector< std::vector< mx_uint > > *in_shape, std::vector< std::vector< mx_uint > > *aux_shape, std::vector< std::vector< mx_uint > > *out_shape) const
infer the shapes by providing the shapes of known arguments.
mxnet::cpp::Symbol::Copy
Symbol Copy() const