27 #ifndef MXNET_CPP_SYMBOL_H_ 28 #define MXNET_CPP_SYMBOL_H_ 84 explicit Symbol(
const char *name);
89 explicit Symbol(
const std::string &name);
106 static Symbol Variable(
const std::string &name =
"");
107 Symbol operator[](
int index);
108 Symbol operator[](
const std::string &index);
113 static Symbol Group(
const std::vector<Symbol> &symbols);
118 static Symbol Load(
const std::string &file_name);
128 void Save(
const std::string &file_name)
const;
132 std::string ToJSON()
const;
137 Symbol GetInternals()
const;
150 Symbol(
const std::string &operator_name,
const std::string &name,
151 std::vector<const char *> input_keys,
152 std::vector<SymbolHandle> input_values,
153 std::vector<const char *> config_keys,
154 std::vector<const char *> config_values);
164 const std::map<std::string, std::vector<mx_uint> > &arg_shapes,
165 std::vector<std::vector<mx_uint> > *in_shape,
166 std::vector<std::vector<mx_uint> > *aux_shape,
167 std::vector<std::vector<mx_uint> > *out_shape)
const;
176 std::vector<std::string> ListArguments()
const;
178 std::vector<std::string> ListOutputs()
const;
180 std::vector<std::string> ListAuxiliaryStates()
const;
182 std::map<std::string, std::string> ListAttributes()
const;
188 void SetAttribute(
const std::string& key,
const std::string& value);
193 void SetAttributes(
const std::map<std::string, std::string>& attrs);
199 std::string GetName()
const;
214 void InferExecutorArrays(
215 const Context &context, std::vector<NDArray> *arg_arrays,
216 std::vector<NDArray> *grad_arrays, std::vector<OpReqType> *grad_reqs,
217 std::vector<NDArray> *aux_arrays,
218 const std::map<std::string, NDArray> &args_map,
219 const std::map<std::string, NDArray> &arg_grad_store =
220 std::map<std::string, NDArray>(),
221 const std::map<std::string, OpReqType> &grad_req_type =
222 std::map<std::string, OpReqType>(),
223 const std::map<std::string, NDArray> &aux_map =
224 std::map<std::string, NDArray>())
const;
232 void InferArgsMap(
const Context &context,
233 std::map<std::string, NDArray> *args_map,
234 const std::map<std::string, NDArray> &known_args)
const;
254 const std::map<std::string, NDArray> &args_map,
255 const std::map<std::string, NDArray> &arg_grad_store =
256 std::map<std::string, NDArray>(),
257 const std::map<std::string, OpReqType> &grad_req_type =
258 std::map<std::string, OpReqType>(),
259 const std::map<std::string, NDArray> &aux_map =
260 std::map<std::string, NDArray>());
279 Executor *Bind(
const Context &context,
const std::vector<NDArray> &arg_arrays,
280 const std::vector<NDArray> &grad_arrays,
281 const std::vector<OpReqType> &grad_reqs,
282 const std::vector<NDArray> &aux_arrays,
283 const std::map<std::string, Context> &group_to_ctx =
284 std::map<std::string, Context>(),
288 std::shared_ptr<SymBlob> blob_ptr_;
289 static OpMap*& op_map();
298 #endif // MXNET_CPP_SYMBOL_H_ Symbol operator/(mx_float lhs, const Symbol &rhs)
Symbol()
Definition: symbol.h:74
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name...
Definition: op_map.h:43
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:85
void * SymbolHandle
handle to a symbol that can be bound as an operator
Definition: c_api.h:75
float mx_float
manually define float
Definition: c_api.h:60
namespace of mxnet
Definition: base.h:89
void Copy(Tensor< cpu, dim, DType > dst, const Tensor< cpu, dim, DType > &src, Stream< cpu > *stream=NULL)
copy data from one tensor to another tensor of the same shape
Definition: tensor_cpu-inl.h:127
Executor interface.
Definition: executor.h:45
SymBlob(SymbolHandle handle)
construct with SymbolHandle to store
Definition: symbol.h:54
SymbolHandle handle_
the SymbolHandle to store
Definition: symbol.h:62
MXNET_DLL int MXSymbolFree(SymbolHandle symbol)
Free the symbol handle.
struct to store SymbolHandle
Definition: symbol.h:45
Symbol operator%(mx_float lhs, const Symbol &rhs)
SymBlob()
default constructor
Definition: symbol.h:50
~SymBlob()
destructor; frees the SymbolHandle
Definition: symbol.h:58
Symbol operator+(mx_float lhs, const Symbol &rhs)
Graph LoadJSON(const std::string &json_str)
Load a graph from a JSON string; redirects to the "LoadJSON" pass.
Definition: pass_functions.h:48
Symbol operator-(mx_float lhs, const Symbol &rhs)
Graph InferShape(Graph graph, ShapeVector shape_inputs, std::string shape_attr_key="")
Infer shapes in the graph given the input shape information.
Definition: pass_functions.h:98
SymbolHandle GetHandle() const
Definition: symbol.h:141
Context interface.
Definition: ndarray.h:50
Symbol operator*(mx_float lhs, const Symbol &rhs)
Symbol interface.
Definition: symbol.h:72
uint32_t mx_uint
manually define unsigned int
Definition: c_api.h:58