Go to the documentation of this file.
29 #ifndef MXNET_OPERATOR_UTIL_H_
30 #define MXNET_OPERATOR_UTIL_H_
33 #pragma warning(disable : 4503) // disable warning: decorated name length exceeded.
76 std::vector<std::pair<std::string, std::string> >
kwargs;
419 return Get()->fmap_.at(name);
428 std::map<std::string, SimpleOpRegEntry*> fmap_;
439 #define ASSIGN_DISPATCH(out, req, exp) \
445 case kWriteInplace: \
452 LOG(FATAL) << "not reached"; \
459 #define MXNET_SPECIAL_MAX_NDIM 5
481 #define MXNET_REGISTER_SIMPLE_OP(Name, DEV) \
482 static ::mxnet::op::SimpleOpRegEntry& __make_##SimpleOpRegEntry##_##Name##__##DEV##__ = \
483 ::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name)
487 #endif // MXNET_OPERATOR_UTIL_H_
namespace of mxnet
Definition: api_registry.h:33
virtual ~SimpleOpRegEntry()
virtual destructor
Definition: operator_util.h:401
virtual TSelf & set_function(int dev_mask, SourceFunction fsource, SimpleOpRegOption register_symbolic=kRegisterSymbolic)=0
set function of the function to be fsource
SimpleOpRegOption
options in the registry to set symbolic registration
Definition: operator_util.h:248
virtual TSelf & describe(const std::string &description)=0
Describe the function.
The resources that can be requested by Operator.
Definition: resource.h:38
virtual TSelf & set_enable_kwargs(bool enable_kwargs)=0
set whether to enable kwargs. A function cannot have both kwargs and scalar arguments....
virtual TSelf & set_shape_function(SourceShapeFunction fshapeinfer)=0
set source inference function.
@ kInplaceLhsOut
in binary forward, allow inplace left operand with out
Definition: operator_util.h:239
real_t scalar
scalar argument, if enabled
Definition: operator_util.h:74
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
registry for TBlob functions
Definition: operator_util.h:405
Provide lightweight util to do parameter setup and checking.
void(* BinaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes only output gradient and computes gradient with respect to input....
Definition: operator_util.h:201
void(* UnaryFunction)(const TBlob &src, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Unary function that takes a src and save result to ret. The result container is pre-allocated with th...
Definition: operator_util.h:107
execution time context. The information needed in runtime for actual execution.
Definition: base.h:343
mxnet::TShape(* BinaryShapeFunction)(const mxnet::TShape &lhs, const mxnet::TShape &rhs, const EnvArguments &env)
Shape inference function to get the correct shape given source shapes.
Definition: operator_util.h:187
SimpleOpRegEntry TSelf
declare self type
Definition: operator_util.h:254
SimpleOpInplaceOption
options in the registry to set inplace of operator
Definition: operator_util.h:231
super class of all gradient function argument
Definition: operator_util.h:53
SimpleOpScalarOption
options in the registry to set the relative order of the scalar and array arguments
Definition: operator_util.h:245
registry entry to register simple operators via functions.
Definition: operator_util.h:251
@ kScalarBeforeArray
Definition: operator_util.h:245
virtual TSelf & set_resource_request(const std::vector< ResourceRequest > &reqs)=0
set resource request. By default there is no resource request. The resource will be presented in both ...
std::string name
name of the operator
Definition: operator_util.h:256
virtual TSelf & add_arguments(const std::vector< dmlc::ParamFieldInfo > &args)=0
Add argument information (descriptions of the arguments) to the function.
virtual TSelf & set_gradient(int dev_mask, UnaryGradFunctionT0 fgrad, SimpleOpInplaceOption inplace_out_in_grad)=0
set gradient of the function of this function.
@ kInplaceOutIn
in unary backward, allow inplace out_grad with in_grad
Definition: operator_util.h:237
Output value of the function to the function.
Definition: operator_util.h:64
std::vector< std::pair< std::string, std::string > > kwargs
keyword arguments
Definition: operator_util.h:76
tensor blob class that can be used to hold tensor of any dimension, any device and any data type,...
Definition: tensor_blob.h:65
@ kArrayBeforeScalar
Definition: operator_util.h:245
TBlob data
The real data.
Definition: operator_util.h:55
Environment arguments that are used by the function. These can be things like scalar arguments when ad...
Definition: operator_util.h:72
mxnet::TShape(* UnaryShapeFunction)(const mxnet::TShape &src, const EnvArguments &env)
Shape inference function to get the correct shape given source.
Definition: operator_util.h:118
static SimpleOpRegistry * Get()
void(* UnaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes only the output gradient and computes gradient with respect to input.
Definition: operator_util.h:128
@ kNotRegisterSymbolic
Definition: operator_util.h:248
virtual TSelf & set_symbol_op_name(char const *symbol_name)=0
set a separate name for symbol. This must be called before set_function. Default: this is set to be sa...
SimpleOpRegEntry & __REGISTER_OR_FIND__(char const *name)
Internal function to register a function under name.
@ kInplaceOutLhs
in binary backward, allow inplace out_grad with lhs_grad
Definition: operator_util.h:241
Gradient of output value.
Definition: operator_util.h:66
@ kNoInplace
do not allow inplace in arguments
Definition: operator_util.h:233
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:440
mxnet::TShape(* SourceShapeFunction)(const EnvArguments &env)
Shape inference function to get the correct shape.
Definition: operator_util.h:96
void(* BinaryGradFunctionT1)(const OutputGrad &out_grad, const Input0 &lhs, const Input1 &rhs, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes inputs of function and computes gradient with respect to input.
Definition: operator_util.h:220
Operator interface of mxnet.
void(* UnaryGradFunctionT2)(const OutputGrad &out_grad, const Input0 &in_data0, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes input value of function and computes gradient with respect to input.
Definition: operator_util.h:157
virtual TSelf & set_enable_scalar(bool enable_scalar, SimpleOpScalarOption type_mask=kArrayBeforeScalar)=0
set number of scalar arguments needed to be passed in env. A function cannot have both kwargs and scal...
@ kRegisterSymbolic
Definition: operator_util.h:248
Registry utility that helps to build registry singletons.
std::vector< Resource > resource
pointer to the resources requested
Definition: operator_util.h:78
void(* SourceFunction)(const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
source function that generates output based on env. The result container is pre-allocated with the corr...
Definition: operator_util.h:89
@ kInplaceInOut
in unary forward, allow inplace in with out
Definition: operator_util.h:235
configuration of MXNet as well as basic data structure.
void(* UnaryGradFunctionT1)(const OutputGrad &out_grad, const OutputValue &out_value, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes output value of function and computes gradient with respect to input.
Definition: operator_util.h:142
static const SimpleOpRegEntry * Find(const std::string &name)
Find the entry with corresponding name.
Definition: operator_util.h:418
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:85
void(* BinaryFunction)(const TBlob &lhs, const TBlob &rhs, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Binary function that takes lhs, rhs and save result to ret. The result container is pre-allocated wit...
Definition: operator_util.h:173