mxnet
operator_util.h File Reference

Utility functions and registries to help quickly build new operators. [Deprecated] Use the register functions in this file when possible to simplify operator creation. Operators registered in this file are exposed to both the NDArray API and the symbolic API.

#include <dmlc/registry.h>
#include <dmlc/parameter.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "./base.h"
#include "./operator.h"
#include <functional>


Classes

struct  mxnet::op::GradFunctionArgument
 Superclass of all gradient function arguments.
 
struct  mxnet::op::Input0
 First input to the function.
 
struct  mxnet::op::Input1
 Second input to the function.
 
struct  mxnet::op::OutputValue
 Output value of the function, passed to the gradient function.
 
struct  mxnet::op::OutputGrad
 Gradient of output value.
 
struct  mxnet::op::EnvArguments
 Environment arguments used by the function, such as the scalar argument when adding a scalar to an array.
 
class  mxnet::op::SimpleOpRegEntry
 Registry entry to register simple operators via functions.
 
class  mxnet::op::SimpleOpRegistry
 Registry for TBlob functions.
 

Namespaces

 mxnet
 namespace of mxnet
 
 mxnet::op
 namespace of arguments
 

Macros

#define ASSIGN_DISPATCH(out, req, exp)
 Assign the expression to out according to the request.
 
#define MXNET_SPECIAL_MAX_NDIM   5
 Maximum ndim supported for special operators, such as broadcasting with non-contiguous lhs/rhs.
 
#define MXNET_REGISTER_SIMPLE_OP(Name, DEV)
 Macro to register a simple operator to both the imperative and symbolic APIs.
 

Typedefs

typedef void(* mxnet::op::SourceFunction) (const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
 Source function that generates output based on env. The result container is pre-allocated with the correct shape.
 
typedef TShape(* mxnet::op::SourceShapeFunction) (const EnvArguments &env)
 Shape inference function to get the correct shape.
 
typedef void(* mxnet::op::UnaryFunction) (const TBlob &src, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
 Unary function that takes src and saves the result to ret. The result container is pre-allocated with the correct shape.
 
typedef TShape(* mxnet::op::UnaryShapeFunction) (const TShape &src, const EnvArguments &env)
 Shape inference function to get the correct shape given the source shape.
 
typedef void(* mxnet::op::UnaryGradFunctionT0) (const OutputGrad &out_grad, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
 Gradient function that takes only the output gradient and computes the gradient with respect to the input.
 
typedef void(* mxnet::op::UnaryGradFunctionT1) (const OutputGrad &out_grad, const OutputValue &out_value, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
 Gradient function that takes the output value of the function and computes the gradient with respect to the input.
 
typedef void(* mxnet::op::UnaryGradFunctionT2) (const OutputGrad &out_grad, const Input0 &in_data0, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
 Gradient function that takes the input value of the function and computes the gradient with respect to the input.
 
typedef void(* mxnet::op::BinaryFunction) (const TBlob &lhs, const TBlob &rhs, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
 Binary function that takes lhs and rhs and saves the result to ret. The result container is pre-allocated with the correct shape.
 
typedef TShape(* mxnet::op::BinaryShapeFunction) (const TShape &lhs, const TShape &rhs, const EnvArguments &env)
 Shape inference function to get the correct shape given the source shapes.
 
typedef void(* mxnet::op::BinaryGradFunctionT0) (const OutputGrad &out_grad, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
 Gradient function that takes only the output gradient and computes the gradients with respect to the inputs. Both input gradients are produced in a single call, which makes it easy to combine a few ops.
 
typedef void(* mxnet::op::BinaryGradFunctionT1) (const OutputGrad &out_grad, const Input0 &lhs, const Input1 &rhs, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
 Gradient function that takes the inputs of the function and computes the gradients with respect to the inputs.
 
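As a rough illustration of how these signatures fit together, the sketch below implements a sigmoid forward function, its shape function, and a gradient that reads the forward input, shaped after UnaryFunction, UnaryShapeFunction and UnaryGradFunctionT2. To stay self-contained it uses hypothetical stand-ins (Shape, Blob, Env, Ctx, ReqType) instead of the real TShape, TBlob, EnvArguments, RunContext and OpReqType, and it passes plain Blob values where the real gradient typedef uses the OutputGrad and Input0 wrappers.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins, NOT the real mxnet types.
using Shape = std::vector<std::size_t>;                 // role of TShape
struct Blob { Shape shape; std::vector<float> data; };  // role of TBlob
struct Env { float scalar = 0.0f; };                    // role of EnvArguments
struct Ctx {};                                          // role of RunContext
enum ReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };  // mirrors OpReqType

// Writes or accumulates one value according to the request type.
inline void Assign(float &out, ReqType req, float v) {
  switch (req) {
    case kNullOp: break;
    case kWriteTo:
    case kWriteInplace: out = v; break;
    case kAddTo: out += v; break;
  }
}

inline float SigmoidValue(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Shaped like UnaryShapeFunction: an elementwise op keeps the source shape.
Shape SigmoidShape(const Shape &src, const Env &) { return src; }

// Shaped like UnaryFunction: ret is assumed pre-allocated with the right shape.
void Sigmoid(const Blob &src, const Env &, Blob *ret, ReqType req, Ctx) {
  for (std::size_t i = 0; i < src.data.size(); ++i)
    Assign(ret->data[i], req, SigmoidValue(src.data[i]));
}

// Shaped like UnaryGradFunctionT2 (reads the forward input x):
// in_grad = out_grad * s * (1 - s), where s = sigmoid(x).
void SigmoidGrad(const Blob &out_grad, const Blob &in_data0, const Env &,
                 Blob *in_grad, ReqType req, Ctx) {
  for (std::size_t i = 0; i < in_data0.data.size(); ++i) {
    const float s = SigmoidValue(in_data0.data[i]);
    Assign(in_grad->data[i], req, out_grad.data[i] * s * (1.0f - s));
  }
}
```

Routing every write through req is what lets one kernel serve write, in-place and accumulate requests; in MXNet that dispatch is the job of the ASSIGN_DISPATCH macro.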

Enumerations

enum  mxnet::op::SimpleOpInplaceOption {
  mxnet::op::kNoInplace, mxnet::op::kInplaceInOut, mxnet::op::kInplaceOutIn, mxnet::op::kInplaceLhsOut,
  mxnet::op::kInplaceOutLhs
}
 Options in the registry to set the in-place behavior of the operator.
 
enum  mxnet::op::SimpleOpScalarOption { mxnet::op::kScalarBeforeArray, mxnet::op::kArrayBeforeScalar }
 Options in the registry to set whether the scalar argument comes before or after the array argument.
 
enum  mxnet::op::SimpleOpRegOption { mxnet::op::kNotRegisterSymbolic, mxnet::op::kRegisterSymbolic }
 Options in the registry to set symbolic registration.
 

Detailed Description


Author
Tianqi Chen

Macro Definition Documentation

#define ASSIGN_DISPATCH(out, req, exp)

Value:
```cpp
{                                    \
  switch (req) {                     \
    case kNullOp:                    \
      break;                         \
    case kWriteTo:                   \
    case kWriteInplace:              \
      (out) = (exp);                 \
      break;                         \
    case kAddTo:                     \
      (out) += (exp);                \
      break;                         \
    default:                         \
      LOG(FATAL) << "not reached";   \
  }                                  \
}
```

Request types (OpReqType, defined in op_attr_types.h) handled above:
kNullOp: no operation, do not write anything.
kWriteTo: write gradient to provided space.
kWriteInplace: perform an in-place write; the target shares memory with one of the input arguments.
kAddTo: add to the provided space.

Assign the expression to out according to the request.

Parameters
out: the data to be assigned
req: the assignment request
exp: the expression
Template Parameters
OType: output type
Exp: expression type
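The following standalone sketch shows how the dispatch behaves for each request type. It copies the macro body but applies it to plain float values rather than mshadow tensors and expressions, mirrors OpReqType with a local enum, and replaces LOG(FATAL) with assert because glog is assumed to be unavailable; TwiceInto is a hypothetical helper.

```cpp
#include <cassert>

// Local mirror of the OpReqType values (standalone copy, not the real header).
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Same dispatch as ASSIGN_DISPATCH; LOG(FATAL) is replaced by assert here.
#define ASSIGN_DISPATCH(out, req, exp)  \
  {                                     \
    switch (req) {                      \
      case kNullOp:                     \
        break;                          \
      case kWriteTo:                    \
      case kWriteInplace:               \
        (out) = (exp);                  \
        break;                          \
      case kAddTo:                      \
        (out) += (exp);                 \
        break;                          \
      default:                          \
        assert(false && "not reached"); \
    }                                   \
  }

// Hypothetical helper: applies 2 * x to out according to req and
// returns the resulting value of out.
inline float TwiceInto(float out, float x, OpReqType req) {
  ASSIGN_DISPATCH(out, req, 2.0f * x);
  return out;
}
```

For example, starting from out = 10 and x = 3, kNullOp leaves 10, kWriteTo and kWriteInplace produce 6, and kAddTo produces 16.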
#define MXNET_REGISTER_SIMPLE_OP(Name, DEV)

Value:
```cpp
static ::mxnet::op::SimpleOpRegEntry & \
__make_ ## SimpleOpRegEntry ## _ ## Name ## __ ## DEV ##__ = \
::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name)
```

Here static SimpleOpRegistry::Get() returns the registry, and SimpleOpRegEntry &__REGISTER_OR_FIND__(char const *name) is an internal function that registers a function under name.

Macro to register a simple operator to both the imperative and symbolic APIs.

See src/operator/elementwise_unary_op-inl.h for an example.

```cpp
// example of registering a sigmoid operator on GPU
// MySigmoid is of type UnaryFunction,
// MySigmoidGrad is of type UnaryGradFunctionT2

MXNET_REGISTER_SIMPLE_OP(sigmoid, gpu)
.set_function(MySigmoid<gpu>, true)
.set_gradient(MySigmoidGrad<gpu>, true)
.describe("Sigmoid function");
```
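The macro relies on a common static-registration pattern: token pasting produces a unique static reference per operator, that reference is bound to the registry entry found or created under the stringified name, and the chained setters fill in the entry while the program starts up. A minimal sketch of the same pattern follows; MiniEntry, MiniRegistry and MINI_REGISTER_OP are hypothetical names and not part of MXNet.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical miniature of the MXNET_REGISTER_SIMPLE_OP pattern.
struct MiniEntry {
  std::string name;
  std::string description;
  float (*function)(float) = nullptr;

  // Setters return *this so calls can be chained after the macro.
  MiniEntry &set_function(float (*f)(float)) { function = f; return *this; }
  MiniEntry &describe(const std::string &d) { description = d; return *this; }
};

class MiniRegistry {
 public:
  // A function-local static avoids static initialization order problems.
  static MiniRegistry *Get() {
    static MiniRegistry inst;
    return &inst;
  }
  // Returns the entry registered under name, creating it on first use.
  // std::map does not move its nodes, so the returned reference stays valid.
  MiniEntry &RegisterOrFind(const std::string &name) {
    MiniEntry &entry = entries_[name];
    entry.name = name;
    return entry;
  }
  const MiniEntry *Find(const std::string &name) const {
    auto it = entries_.find(name);
    return it == entries_.end() ? nullptr : &it->second;
  }

 private:
  std::map<std::string, MiniEntry> entries_;
};

// Token pasting yields a unique static reference per operator name;
// #Name turns the identifier into the registry key.
#define MINI_REGISTER_OP(Name)             \
  static MiniEntry &mini_entry_##Name##_ = \
      MiniRegistry::Get()->RegisterOrFind(#Name)

inline float Negate(float x) { return -x; }

// The static reference is initialized at program start-up.
MINI_REGISTER_OP(negate)
.set_function(Negate)
.describe("Negate the input");
```

Looking up "negate" afterwards returns the entry with its function and description already set, which is the same effect the real macro has on SimpleOpRegistry.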
#define MXNET_SPECIAL_MAX_NDIM   5

Maximum ndim supported for special operators, such as broadcasting with non-contiguous lhs/rhs.
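A plausible benefit of capping the dimension count (an assumption, not stated above) is that broadcast index arithmetic can use fixed-size arrays instead of heap-allocated shapes. The sketch below is not MXNet's implementation: it builds per-input strides in which broadcast dimensions get stride 0, so that input is read non-contiguously, and it rejects shapes with more than kMaxNdim dimensions, a hypothetical constant mirroring MXNET_SPECIAL_MAX_NDIM.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical constant mirroring MXNET_SPECIAL_MAX_NDIM.
constexpr std::size_t kMaxNdim = 5;
using Dims = std::array<std::size_t, kMaxNdim>;

// Fixed-size bookkeeping for one broadcast binary op.
struct BroadcastPlan {
  std::size_t ndim = 0;
  Dims out_shape{};
  Dims lhs_stride{};  // 0 where lhs is broadcast along that dimension
  Dims rhs_stride{};  // 0 where rhs is broadcast along that dimension
};

// Builds a plan for lhs/rhs shapes of equal ndim. Returns false if ndim
// exceeds kMaxNdim or a dimension pair is neither equal nor contains a 1.
bool MakePlan(const std::vector<std::size_t> &lhs,
              const std::vector<std::size_t> &rhs, BroadcastPlan *plan) {
  if (lhs.size() != rhs.size() || lhs.size() > kMaxNdim) return false;
  plan->ndim = lhs.size();
  std::size_t lhs_step = 1, rhs_step = 1;  // contiguous strides, innermost first
  for (std::size_t k = plan->ndim; k-- > 0;) {
    const std::size_t l = lhs[k], r = rhs[k];
    if (l != r && l != 1 && r != 1) return false;
    plan->out_shape[k] = l > r ? l : r;
    plan->lhs_stride[k] = (l == 1) ? 0 : lhs_step;
    plan->rhs_stride[k] = (r == 1) ? 0 : rhs_step;
    lhs_step *= l;
    rhs_step *= r;
  }
  return true;
}

// Maps a flat output index to a flat input offset using that input's strides.
std::size_t InputOffset(const BroadcastPlan &plan, const Dims &stride,
                        std::size_t out_index) {
  std::size_t offset = 0;
  for (std::size_t k = plan.ndim; k-- > 0;) {
    offset += (out_index % plan.out_shape[k]) * stride[k];
    out_index /= plan.out_shape[k];
  }
  return offset;
}
```

For lhs of shape (2, 3) and rhs of shape (1, 3), output index 4 (row 1, column 1) maps to lhs offset 4 and rhs offset 1, because the rhs row dimension has stride 0.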