mxnet
operator_util.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
29 #ifndef MXNET_OPERATOR_UTIL_H_
30 #define MXNET_OPERATOR_UTIL_H_
31 
32 #ifdef _MSC_VER
33 #pragma warning(disable : 4503) // disable warning: decorated name length exceeded.
34 #endif
35 
36 #include <dmlc/registry.h>
37 #include <dmlc/parameter.h>
38 #include <map>
39 #include <vector>
40 #include <string>
41 #include <utility>
42 #include "./base.h"
43 #include "./operator.h"
44 
45 #if DMLC_USE_CXX11
46 #include <functional>
47 #endif
48 
49 namespace mxnet {
51 namespace op {
56 };
57 
62 
67 
72 struct EnvArguments {
76  std::vector<std::pair<std::string, std::string> > kwargs;
78  std::vector<Resource> resource;
79 };
80 
89 typedef void (*SourceFunction)(const EnvArguments& env, TBlob* ret, OpReqType req, RunContext ctx);
90 
97 
107 typedef void (*UnaryFunction)(const TBlob& src,
108  const EnvArguments& env,
109  TBlob* ret,
110  OpReqType req,
111  RunContext ctx);
118 typedef mxnet::TShape (*UnaryShapeFunction)(const mxnet::TShape& src, const EnvArguments& env);
119 
128 typedef void (*UnaryGradFunctionT0)(const OutputGrad& out_grad,
129  const EnvArguments& env,
130  TBlob* in_grad,
131  OpReqType req,
132  RunContext ctx);
142 typedef void (*UnaryGradFunctionT1)(const OutputGrad& out_grad,
143  const OutputValue& out_value,
144  const EnvArguments& env,
145  TBlob* in_grad,
146  OpReqType req,
147  RunContext ctx);
157 typedef void (*UnaryGradFunctionT2)(const OutputGrad& out_grad,
158  const Input0& in_data0,
159  const EnvArguments& env,
160  TBlob* in_grad,
161  OpReqType req,
162  RunContext ctx);
173 typedef void (*BinaryFunction)(const TBlob& lhs,
174  const TBlob& rhs,
175  const EnvArguments& env,
176  TBlob* ret,
177  OpReqType req,
178  RunContext ctx);
179 
188  const mxnet::TShape& rhs,
189  const EnvArguments& env);
201 typedef void (*BinaryGradFunctionT0)(const OutputGrad& out_grad,
202  const EnvArguments& env,
203  TBlob* lhs_grad,
204  TBlob* rhs_grad,
205  OpReqType req_lhs_grad,
206  OpReqType req_rhs_grad,
207  RunContext ctx);
220 typedef void (*BinaryGradFunctionT1)(const OutputGrad& out_grad,
221  const Input0& lhs,
222  const Input1& rhs,
223  const EnvArguments& env,
224  TBlob* lhs_grad,
225  TBlob* rhs_grad,
226  OpReqType req_lhs_grad,
227  OpReqType req_rhs_grad,
228  RunContext ctx);
229 
242 };
243 
246 
249 
252  public:
256  std::string name;
263  virtual TSelf& set_symbol_op_name(char const* symbol_name) = 0;
271  virtual TSelf& set_enable_scalar(bool enable_scalar,
272  SimpleOpScalarOption type_mask = kArrayBeforeScalar) = 0;
279  virtual TSelf& set_enable_kwargs(bool enable_kwargs) = 0;
286  virtual TSelf& set_resource_request(const std::vector<ResourceRequest>& reqs) = 0;
293  virtual TSelf& set_resource_request(ResourceRequest req) = 0;
298  virtual TSelf& set_shape_function(SourceShapeFunction fshapeinfer) = 0;
304  virtual TSelf& set_shape_function(UnaryShapeFunction fshapeinfer) = 0;
310  virtual TSelf& set_shape_function(BinaryShapeFunction fshapeinfer) = 0;
317  virtual TSelf& set_function(int dev_mask,
318  SourceFunction fsource,
319  SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
327  virtual TSelf& set_function(int dev_mask,
328  UnaryFunction funary,
329  SimpleOpInplaceOption inplace_in_out,
330  SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
338  virtual TSelf& set_function(int dev_mask,
339  BinaryFunction fbinary,
340  SimpleOpInplaceOption inplace_lhs_out,
341  SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
348  virtual TSelf& set_gradient(int dev_mask,
349  UnaryGradFunctionT0 fgrad,
350  SimpleOpInplaceOption inplace_out_in_grad) = 0;
357  virtual TSelf& set_gradient(int dev_mask,
358  UnaryGradFunctionT1 fgrad,
359  SimpleOpInplaceOption inplace_out_in_grad) = 0;
366  virtual TSelf& set_gradient(int dev_mask,
367  UnaryGradFunctionT2 fgrad,
368  SimpleOpInplaceOption inplace_out_in_grad) = 0;
375  virtual TSelf& set_gradient(int dev_mask,
376  BinaryGradFunctionT0 fgrad,
377  SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
384  virtual TSelf& set_gradient(int dev_mask,
385  BinaryGradFunctionT1 fgrad,
386  SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
392  virtual TSelf& describe(const std::string& description) = 0;
399  virtual TSelf& add_arguments(const std::vector<dmlc::ParamFieldInfo>& args) = 0;
401  virtual ~SimpleOpRegEntry() {}
402 };
403 
406  public:
412  SimpleOpRegEntry& __REGISTER_OR_FIND__(char const* name);
418  inline static const SimpleOpRegEntry* Find(const std::string& name) {
419  return Get()->fmap_.at(name);
420  }
422  static SimpleOpRegistry* Get();
423 
424  private:
425  // destructor
426  ~SimpleOpRegistry();
428  std::map<std::string, SimpleOpRegEntry*> fmap_;
429 };
430 
/*!
 * \brief Write exp into out according to the request type req.
 *
 *  - kNullOp:                 nothing happens and exp is not evaluated.
 *  - kWriteTo / kWriteInplace: out = exp.
 *  - kAddTo:                  out += exp.
 *  - any other value:         LOG(FATAL) << "not reached".
 *
 *  The body is wrapped in do { ... } while (0) so that the macro expands to
 *  exactly one statement and must be followed by a semicolon. A bare { ... }
 *  block breaks unbraced if/else, e.g.
 *    if (c) ASSIGN_DISPATCH(o, r, e); else ...   // "else without if"
 *  The break statements inside only leave the switch, as before.
 *
 * \param out the data to be assigned (evaluated as an lvalue)
 * \param req the assignment request
 * \param exp the expression; evaluated at most once
 */
#define ASSIGN_DISPATCH(out, req, exp) \
  do {                                 \
    switch (req) {                     \
      case kNullOp:                    \
        break;                         \
      case kWriteTo:                   \
      case kWriteInplace:              \
        (out) = (exp);                 \
        break;                         \
      case kAddTo:                     \
        (out) += (exp);                \
        break;                         \
      default:                         \
        LOG(FATAL) << "not reached";   \
    }                                  \
  } while (0)
455 
/*!
 * \brief Maximum ndim constant (value 5).
 *  NOTE(review): the consumers that give "special" its meaning live outside
 *  this header; confirm the intended semantics there.
 */
#define MXNET_SPECIAL_MAX_NDIM 5

//--------------------------------------------------------------
// The following part is the API registration of simple operators
//--------------------------------------------------------------

/*!
 * \brief Register (or look up) a simple operator named Name for device DEV.
 *
 *  Expands to the declaration of a static SimpleOpRegEntry reference whose
 *  identifier is built by pasting Name and DEV, bound to
 *  SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name). The expansion has no
 *  trailing semicolon. Because every SimpleOpRegEntry setter returns TSelf&,
 *  the macro can be followed by chained setter calls and terminated with ';'.
 *
 *  NOTE(review): the generated identifier starts with "__", which is a
 *  reserved-name pattern in C++; left unchanged to keep the expansion stable.
 */
#define MXNET_REGISTER_SIMPLE_OP(Name, DEV)                      \
  static ::mxnet::op::SimpleOpRegEntry&                           \
      __make_##SimpleOpRegEntry##_##Name##__##DEV##__ =           \
          ::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name)
484 
485 } // namespace op
486 } // namespace mxnet
487 #endif // MXNET_OPERATOR_UTIL_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::op::SimpleOpRegEntry::~SimpleOpRegEntry
virtual ~SimpleOpRegEntry()
virtual destructor
Definition: operator_util.h:401
mxnet::op::SimpleOpRegEntry::set_function
virtual TSelf & set_function(int dev_mask, SourceFunction fsource, SimpleOpRegOption register_symbolic=kRegisterSymbolic)=0
set function of the function to be fsource
mxnet::op::SimpleOpRegOption
SimpleOpRegOption
options in the registry to set symbolic registration
Definition: operator_util.h:248
mxnet::op::SimpleOpRegEntry::describe
virtual TSelf & describe(const std::string &description)=0
Describe the function.
mxnet::ResourceRequest
The resources that can be requested by Operator.
Definition: resource.h:38
mxnet::op::SimpleOpRegEntry::set_enable_kwargs
virtual TSelf & set_enable_kwargs(bool enable_kwargs)=0
set whether to enable kwargs. A function cannot have both kwargs and scalar arguments....
mxnet::op::SimpleOpRegEntry::set_shape_function
virtual TSelf & set_shape_function(SourceShapeFunction fshapeinfer)=0
set source inference function.
mxnet::op::Input1
Second input to the function.
Definition: operator_util.h:61
mxnet::op::kInplaceLhsOut
@ kInplaceLhsOut
in binary forward, allow inplace left operand with out
Definition: operator_util.h:239
mxnet::op::EnvArguments::scalar
real_t scalar
scalar argument, if enabled
Definition: operator_util.h:74
mxnet::OpReqType
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
mxnet::op::SimpleOpRegistry
registry for TBlob functions
Definition: operator_util.h:405
parameter.h
Provide lightweight util to do parameter setup and checking.
mxnet::op::BinaryGradFunctionT0
void(* BinaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradient w.r.t. the input....
Definition: operator_util.h:201
mxnet::op::UnaryFunction
void(* UnaryFunction)(const TBlob &src, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Unary function that takes src and saves the result to ret. The result container is pre-allocated with th...
Definition: operator_util.h:107
mxnet::RunContext
execution time context. The information needed in runtime for actual execution.
Definition: base.h:343
mxnet::op::BinaryShapeFunction
mxnet::TShape(* BinaryShapeFunction)(const mxnet::TShape &lhs, const mxnet::TShape &rhs, const EnvArguments &env)
Shape inference function to get the correct shape given source shapes.
Definition: operator_util.h:187
mxnet::op::SimpleOpRegEntry::TSelf
SimpleOpRegEntry TSelf
declare self type
Definition: operator_util.h:254
mxnet::op::SimpleOpInplaceOption
SimpleOpInplaceOption
options in the registry to set inplace of operator
Definition: operator_util.h:231
mxnet::op::GradFunctionArgument
super class of all gradient function argument
Definition: operator_util.h:53
mxnet::op::SimpleOpScalarOption
SimpleOpScalarOption
options in the registry to set the order of array and scalar arguments
Definition: operator_util.h:245
mxnet::op::SimpleOpRegEntry
registry entry to register simple operators via functions.
Definition: operator_util.h:251
mxnet::op::kScalarBeforeArray
@ kScalarBeforeArray
Definition: operator_util.h:245
mxnet::op::SimpleOpRegEntry::set_resource_request
virtual TSelf & set_resource_request(const std::vector< ResourceRequest > &reqs)=0
set resource request. By default there is no resource request. The resource will be presented in both ...
mxnet::op::Input0
First input to the function.
Definition: operator_util.h:59
mxnet::op::SimpleOpRegEntry::name
std::string name
name of the operator
Definition: operator_util.h:256
mxnet::op::SimpleOpRegEntry::add_arguments
virtual TSelf & add_arguments(const std::vector< dmlc::ParamFieldInfo > &args)=0
Describe the arguments of the function.
mxnet::op::SimpleOpRegEntry::set_gradient
virtual TSelf & set_gradient(int dev_mask, UnaryGradFunctionT0 fgrad, SimpleOpInplaceOption inplace_out_in_grad)=0
set the gradient function of this function.
mxnet::op::kInplaceOutIn
@ kInplaceOutIn
in unary backward, allow inplace out_grad with in_grad
Definition: operator_util.h:237
mxnet::op::OutputValue
Output value of the function.
Definition: operator_util.h:64
mxnet::op::EnvArguments::kwargs
std::vector< std::pair< std::string, std::string > > kwargs
keyword arguments
Definition: operator_util.h:76
mxnet::TBlob
tensor blob class that can be used to hold tensor of any dimension, any device and any data type,...
Definition: tensor_blob.h:65
mxnet::op::kArrayBeforeScalar
@ kArrayBeforeScalar
Definition: operator_util.h:245
mxnet::op::GradFunctionArgument::data
TBlob data
The real data.
Definition: operator_util.h:55
mxnet::op::EnvArguments
Environment arguments that are used by the function. These can be things like scalar arguments when ad...
Definition: operator_util.h:72
mxnet::op::UnaryShapeFunction
mxnet::TShape(* UnaryShapeFunction)(const mxnet::TShape &src, const EnvArguments &env)
Shape inference function to get the correct shape given source.
Definition: operator_util.h:118
mxnet::op::SimpleOpRegistry::Get
static SimpleOpRegistry * Get()
mxnet::op::UnaryGradFunctionT0
void(* UnaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradient w.r.t. the input.
Definition: operator_util.h:128
mxnet::op::kNotRegisterSymbolic
@ kNotRegisterSymbolic
Definition: operator_util.h:248
mxnet::op::SimpleOpRegEntry::set_symbol_op_name
virtual TSelf & set_symbol_op_name(char const *symbol_name)=0
set a separate name for the symbol. This must be called before set_function. Default: this is set to be sa...
mxnet::op::SimpleOpRegistry::__REGISTER_OR_FIND__
SimpleOpRegEntry & __REGISTER_OR_FIND__(char const *name)
Internal function to register a function under name (or find the existing entry).
mxnet::op::kInplaceOutLhs
@ kInplaceOutLhs
in binary backward, allow inplace out_grad with lhs_grad
Definition: operator_util.h:241
mxnet::op::OutputGrad
Gradient of output value.
Definition: operator_util.h:66
mxnet::op::kNoInplace
@ kNoInplace
do not allow inplace in arguments
Definition: operator_util.h:233
mxnet::TShape
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:440
mxnet::op::SourceShapeFunction
mxnet::TShape(* SourceShapeFunction)(const EnvArguments &env)
Shape inference function to get the correct shape.
Definition: operator_util.h:96
mxnet::op::BinaryGradFunctionT1
void(* BinaryGradFunctionT1)(const OutputGrad &out_grad, const Input0 &lhs, const Input1 &rhs, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes the inputs of the function and computes the gradient w.r.t. the input.
Definition: operator_util.h:220
operator.h
Operator interface of mxnet.
mxnet::op::UnaryGradFunctionT2
void(* UnaryGradFunctionT2)(const OutputGrad &out_grad, const Input0 &in_data0, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the input value of the function and computes the gradient w.r.t. the input.
Definition: operator_util.h:157
mxnet::op::SimpleOpRegEntry::set_enable_scalar
virtual TSelf & set_enable_scalar(bool enable_scalar, SimpleOpScalarOption type_mask=kArrayBeforeScalar)=0
set whether a scalar argument needs to be passed in env. A function cannot have both kwargs and scal...
mxnet::op::kRegisterSymbolic
@ kRegisterSymbolic
Definition: operator_util.h:248
registry.h
Registry utility that helps to build registry singletons.
mxnet::op::EnvArguments::resource
std::vector< Resource > resource
pointer to the resources requested
Definition: operator_util.h:78
mxnet::op::SourceFunction
void(* SourceFunction)(const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
source function that generates output based on env. The result container is pre-allocated with the corr...
Definition: operator_util.h:89
mxnet::op::kInplaceInOut
@ kInplaceInOut
in unary forward, allow inplace in with out
Definition: operator_util.h:235
base.h
configuration of MXNet as well as basic data structure.
mxnet::op::UnaryGradFunctionT1
void(* UnaryGradFunctionT1)(const OutputGrad &out_grad, const OutputValue &out_value, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the output value of the function and computes the gradient w.r.t. the input.
Definition: operator_util.h:142
mxnet::op::SimpleOpRegistry::Find
static const SimpleOpRegEntry * Find(const std::string &name)
Find the entry with corresponding name.
Definition: operator_util.h:418
mxnet::real_t
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:85
mxnet::op::BinaryFunction
void(* BinaryFunction)(const TBlob &lhs, const TBlob &rhs, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Binary function that takes lhs and rhs and saves the result to ret. The result container is pre-allocated wit...
Definition: operator_util.h:173