Go to the documentation of this file.
25 #ifndef MXNET_OPERATOR_H_
26 #define MXNET_OPERATOR_H_
30 #include <dmlc/logging.h>
70 const std::vector<TBlob>& in_data,
71 const std::vector<OpReqType>& req,
72 const std::vector<TBlob>& out_data,
73 const std::vector<TBlob>& aux_states) = 0;
103 const std::vector<TBlob>& out_grad,
104 const std::vector<TBlob>& in_data,
105 const std::vector<TBlob>& out_data,
106 const std::vector<OpReqType>& req,
107 const std::vector<TBlob>& in_grad,
108 const std::vector<TBlob>& aux_states) {
109 LOG(FATAL) <<
"Backward is not implemented";
138 virtual void Init(
const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
143 virtual std::map<std::string, std::string>
GetParams()
const = 0;
222 std::vector<int>* out_type,
223 std::vector<int>* aux_type)
const {
224 CHECK_LE(in_type->size(), this->ListArguments().size());
226 for (
unsigned i = 0; i < in_type->size(); ++i) {
228 <<
"Unsupported data type " << in_type->at(i);
231 for (
int i = 0; i < n_in; ++i)
236 for (
int i = 0; i < n_out; ++i)
241 for (
int i = 0; i < n_aux; ++i)
263 std::vector<int>* in_type)
const {
264 std::vector<int> out_type, aux_type;
270 CHECK(
InferType(in_type, &out_type, &aux_type));
271 CHECK(
InferShape(in_shape, &out_shape, &aux_shape));
291 return std::vector<ResourceRequest>();
301 return std::vector<ResourceRequest>();
326 const std::vector<int>& in_data,
327 const std::vector<int>& out_data)
const {
330 std::vector<int> ret = out_grad;
331 ret.insert(ret.end(), in_data.begin(), in_data.end());
332 ret.insert(ret.end(), out_data.begin(), out_data.end());
357 const std::vector<int>& in_data,
358 const std::vector<void*>& out_data)
const {
359 return std::vector<std::pair<int, void*> >();
388 const std::vector<int>& out_grad,
389 const std::vector<int>& in_data,
390 const std::vector<int>& out_data,
391 const std::vector<void*>& in_grad)
const {
392 return std::vector<std::pair<int, void*> >();
406 template <
typename T>
408 const std::vector<T>& in_data,
409 const std::vector<T>& out_data)
const {
411 std::vector<int> out_grad_index(out_grad.size());
412 std::vector<int> in_data_index(in_data.size());
413 std::vector<int> out_data_index(out_data.size());
414 for (
size_t i = 0; i < out_grad_index.size(); ++i) {
415 out_grad_index[i] = counter++;
417 for (
size_t i = 0; i < in_data_index.size(); ++i) {
418 in_data_index[i] = counter++;
420 for (
size_t i = 0; i < out_data_index.size(); ++i) {
421 out_data_index[i] = counter++;
423 std::vector<T> all_data;
424 all_data.insert(all_data.end(), out_grad.begin(), out_grad.end());
425 all_data.insert(all_data.end(), in_data.begin(), in_data.end());
426 all_data.insert(all_data.end(), out_data.begin(), out_data.end());
428 std::vector<int> ret_index =
431 std::vector<T> ret(ret_index.size());
432 for (
size_t i = 0; i < ret_index.size(); ++i) {
433 ret[i] = all_data[ret_index[i]];
479 CHECK_EQ(this->
name, type) <<
"Register Name and TypeString mismatch, name=\"" << this->
name
481 <<
" but TypeString=\"" << type <<
"\"";
503 #define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType) \
504 DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \
505 .set_body([]() { return new OperatorPropertyType(); }) \
506 .set_return_type("NDArray-or-Symbol") \
509 #endif // DMLC_USE_CXX11
511 #endif // MXNET_OPERATOR_H_
namespace of mxnet
Definition: api_registry.h:33
virtual std::map< std::string, std::string > GetParams() const =0
Get a map representation of internal parameters. This can be used by Init to recover the state of Ope...
const int default_type_flag
type enum value for default real type
Definition: base.h:492
Common base class for function registry.
Definition: registry.h:151
OperatorProperty is an object that stores all information about Operator. It also contains method to g...
Definition: operator.h:127
OperatorPropertyReg & check_name()
Check if TypeString of the type matches the registered name.
Definition: operator.h:475
virtual void Init(const std::vector< std::pair< std::string, std::string > > &kwargs)=0
Initialize the Operator by setting the parameters. This function needs to be called before all other fu...
virtual bool InferType(std::vector< int > *in_type, std::vector< int > *out_type, std::vector< int > *aux_type) const
infer the data types of outputs and unknown input arguments
Definition: operator.h:221
Operator interface. Operator defines basic operation unit of optimized computation graph in mxnet....
Definition: operator.h:54
virtual Operator * CreateOperator(Context ctx) const =0
Create an Operator on a specific context.
virtual ExecType exec_type() const
Definition: operator.h:444
defines configuration macros
virtual std::string TypeString() const =0
Return the type string of the Operator. Subclasses override this function.
virtual ~OperatorProperty()
virtual destructor
Definition: operator.h:132
virtual bool InferShape(mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape, mxnet::ShapeVector *aux_shape) const =0
infer the shapes of outputs and unknown input arguments
All the possible information needed by Operator. This is the superset of RunContext....
Definition: op_attr_types.h:66
Registry entry for OperatorProperty factory functions.
Definition: operator.h:454
virtual int NumOutputs() const
Definition: operator.h:166
OperatorPropertyReg & set_key_var_num_args(const std::string &key)
Set key_var_num_args. When this is set, the API caller is required to pass in an argument with key=key_...
Definition: operator.h:468
virtual void Forward(const OpContext &ctx, const std::vector< TBlob > &in_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &out_data, const std::vector< TBlob > &aux_states)=0
perform a forward operation of Operator, save the output to TBlob.
std::string type_name()
the string representation of type name
Definition: type_traits.h:101
std::function< OperatorProperty *()> OperatorPropertyFactory
typedef the factory function of operator property
Definition: operator.h:450
virtual OperatorProperty * Copy() const =0
Copy this OperatorProperty.
virtual std::vector< int > DeclareBackwardDependency(const std::vector< int > &out_grad, const std::vector< int > &in_data, const std::vector< int > &out_data) const
Declare the input requirement of Backward pass.
Definition: operator.h:325
Global resource allocation handling.
ExecType
the execution type of the operator
Definition: op_attr_types.h:98
@ kSync
Forward/Backward are synchronous calls.
virtual std::vector< std::string > ListOutputs() const
Get name of output values of Operator.
Definition: operator.h:155
Context information about the execution environment.
Definition: base.h:90
virtual std::vector< std::pair< int, void * > > BackwardInplaceOption(const std::vector< int > &out_grad, const std::vector< int > &in_data, const std::vector< int > &out_data, const std::vector< void * > &in_grad) const
Get possible backward inplace options. This function enables optimization to reuse memory of inputs i...
Definition: operator.h:387
Additional operator attributes beside the ones provided by NNVM.
std::string key_var_num_args
The key num_args name.
Definition: operator.h:486
virtual std::vector< std::pair< int, void * > > ForwardInplaceOption(const std::vector< int > &in_data, const std::vector< void * > &out_data) const
Get possible forward inplace options. This function enables optimization to reuse memory of inputs in...
Definition: operator.h:356
virtual std::vector< std::string > ListArguments() const
Get input arguments of the Operator.
Definition: operator.h:148
virtual int NumVisibleOutputs() const
get number of visible return values during Symbol creation. If NumVisibleOutputs() = k,...
Definition: operator.h:181
virtual ExecType exec_type() const final
Definition: operator.h:112
virtual std::vector< ResourceRequest > BackwardResource(const mxnet::ShapeVector &in_shape) const
Declare additional resource required in backward pass. These additional resources will be presented i...
Definition: operator.h:300
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:830
Lightweight JSON Reader/Writer that reads from and saves into C++ data structs. This includes STL composites and...
OperatorPropertyFactory body
Function body to create ProductType.
Definition: registry.h:160
Registry utility that helps to build registry singletons.
Graph node data structure.
virtual Operator * CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, std::vector< int > *in_type) const
Create an Operator on a specific context and input shape/type.
Definition: operator.h:261
std::vector< T > BackwardInputs(const std::vector< T > &out_grad, const std::vector< T > &in_data, const std::vector< T > &out_data) const
Get Backward Input Dependency for generic types of data. Normally T can be a pointer to Symbol::DataEnt...
Definition: operator.h:407
virtual ~Operator()
destructor
Definition: operator.h:57
static OperatorProperty * Create(const char *type_name)
create OperatorProperty
virtual void Backward(const OpContext &ctx, const std::vector< TBlob > &out_grad, const std::vector< TBlob > &in_data, const std::vector< TBlob > &out_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &in_grad, const std::vector< TBlob > &aux_states)
Perform a Backward Operation, write gradient to the in_grad.
Definition: operator.h:102
virtual std::vector< ResourceRequest > ForwardResource(const mxnet::ShapeVector &in_shape) const
Declare additional resource required in forward pass. These additional resources will be presented in...
Definition: operator.h:290
std::string name
name of the entry
Definition: registry.h:154
virtual std::vector< std::string > ListAuxiliaryStates() const
Get name of auxiliary states of Operator.
Definition: operator.h:162