mxnet
operator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_OPERATOR_H_
26 #define MXNET_OPERATOR_H_
27 
28 #include <dmlc/base.h>
29 #include <dmlc/json.h>
30 #include <dmlc/logging.h>
31 #include <dmlc/registry.h>
32 #include <nnvm/node.h>
33 #include <vector>
34 #include <map>
35 #include <string>
36 #include <utility>
37 #include "./base.h"
38 #include "./resource.h"
39 #include "./op_attr_types.h"
40 
41 namespace mxnet {
54 class Operator {
55  public:
57  virtual ~Operator() {}
69  virtual void Forward(const OpContext& ctx,
70  const std::vector<TBlob>& in_data,
71  const std::vector<OpReqType>& req,
72  const std::vector<TBlob>& out_data,
73  const std::vector<TBlob>& aux_states) = 0;
102  virtual void Backward(const OpContext& ctx,
103  const std::vector<TBlob>& out_grad,
104  const std::vector<TBlob>& in_data,
105  const std::vector<TBlob>& out_data,
106  const std::vector<OpReqType>& req,
107  const std::vector<TBlob>& in_grad,
108  const std::vector<TBlob>& aux_states) {
109  LOG(FATAL) << "Backward is not implemented";
110  }
112  virtual ExecType exec_type() // NOLINT(*) exec_type has been moved to OperatorProperty
113  const final { // NOLINT(*) exec_type has been moved to OperatorProperty
114  return ExecType::kSync;
115  }
116 };
117 
118 #if DMLC_USE_CXX11
119 // OperatorProperty may use C++11 features, while Operator does not rely on them.
128  public:
132  virtual ~OperatorProperty() {}
138  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
143  virtual std::map<std::string, std::string> GetParams() const = 0;
148  virtual std::vector<std::string> ListArguments() const {
149  return {"data"};
150  }
155  virtual std::vector<std::string> ListOutputs() const {
156  return {"output"};
157  }
162  virtual std::vector<std::string> ListAuxiliaryStates() const {
163  return {};
164  }
166  virtual int NumOutputs() const {
167  return this->ListOutputs().size();
168  }
181  virtual int NumVisibleOutputs() const {
182  return NumOutputs();
183  }
201  virtual bool InferShape(mxnet::ShapeVector* in_shape,
202  mxnet::ShapeVector* out_shape,
203  mxnet::ShapeVector* aux_shape) const = 0;
221  virtual bool InferType(std::vector<int>* in_type,
222  std::vector<int>* out_type,
223  std::vector<int>* aux_type) const {
224  CHECK_LE(in_type->size(), this->ListArguments().size());
225  int n_in = this->ListArguments().size();
226  for (unsigned i = 0; i < in_type->size(); ++i) {
227  CHECK(in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1)
228  << "Unsupported data type " << in_type->at(i);
229  }
230  in_type->clear();
231  for (int i = 0; i < n_in; ++i)
232  in_type->push_back(mshadow::default_type_flag);
233 
234  int n_out = this->ListOutputs().size();
235  out_type->clear();
236  for (int i = 0; i < n_out; ++i)
237  out_type->push_back(mshadow::default_type_flag);
238 
239  int n_aux = this->ListAuxiliaryStates().size();
240  aux_type->clear();
241  for (int i = 0; i < n_aux; ++i)
242  aux_type->push_back(mshadow::default_type_flag);
243  return true;
244  }
249  virtual OperatorProperty* Copy() const = 0;
253  virtual Operator* CreateOperator(Context ctx) const = 0;
262  mxnet::ShapeVector* in_shape,
263  std::vector<int>* in_type) const {
264  std::vector<int> out_type, aux_type;
265  mxnet::ShapeVector out_shape, aux_shape;
266  out_type.resize(this->ListOutputs().size());
267  out_shape.resize(this->ListOutputs().size());
268  aux_type.resize(this->ListAuxiliaryStates().size());
269  aux_shape.resize(this->ListAuxiliaryStates().size());
270  CHECK(InferType(in_type, &out_type, &aux_type));
271  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
272  return CreateOperator(ctx);
273  }
279  virtual std::string TypeString() const = 0;
280  //--------------------------------------------------------
281  // All the below functions are optional to override.
282  //--------------------------------------------------------
290  virtual std::vector<ResourceRequest> ForwardResource(const mxnet::ShapeVector& in_shape) const {
291  return std::vector<ResourceRequest>();
292  }
300  virtual std::vector<ResourceRequest> BackwardResource(const mxnet::ShapeVector& in_shape) const {
301  return std::vector<ResourceRequest>();
302  }
325  virtual std::vector<int> DeclareBackwardDependency(const std::vector<int>& out_grad,
326  const std::vector<int>& in_data,
327  const std::vector<int>& out_data) const {
328  // By default requires to see all the things.
329  // remember to override this function to get a better performance.
330  std::vector<int> ret = out_grad;
331  ret.insert(ret.end(), in_data.begin(), in_data.end());
332  ret.insert(ret.end(), out_data.begin(), out_data.end());
333  return ret;
334  }
356  virtual std::vector<std::pair<int, void*> > ForwardInplaceOption(
357  const std::vector<int>& in_data,
358  const std::vector<void*>& out_data) const {
359  return std::vector<std::pair<int, void*> >();
360  }
387  virtual std::vector<std::pair<int, void*> > BackwardInplaceOption(
388  const std::vector<int>& out_grad,
389  const std::vector<int>& in_data,
390  const std::vector<int>& out_data,
391  const std::vector<void*>& in_grad) const {
392  return std::vector<std::pair<int, void*> >();
393  }
406  template <typename T>
407  inline std::vector<T> BackwardInputs(const std::vector<T>& out_grad,
408  const std::vector<T>& in_data,
409  const std::vector<T>& out_data) const {
410  int counter = 0;
411  std::vector<int> out_grad_index(out_grad.size());
412  std::vector<int> in_data_index(in_data.size());
413  std::vector<int> out_data_index(out_data.size());
414  for (size_t i = 0; i < out_grad_index.size(); ++i) {
415  out_grad_index[i] = counter++;
416  }
417  for (size_t i = 0; i < in_data_index.size(); ++i) {
418  in_data_index[i] = counter++;
419  }
420  for (size_t i = 0; i < out_data_index.size(); ++i) {
421  out_data_index[i] = counter++;
422  }
423  std::vector<T> all_data;
424  all_data.insert(all_data.end(), out_grad.begin(), out_grad.end());
425  all_data.insert(all_data.end(), in_data.begin(), in_data.end());
426  all_data.insert(all_data.end(), out_data.begin(), out_data.end());
427 
428  std::vector<int> ret_index =
429  this->DeclareBackwardDependency(out_grad_index, in_data_index, out_data_index);
430 
431  std::vector<T> ret(ret_index.size());
432  for (size_t i = 0; i < ret_index.size(); ++i) {
433  ret[i] = all_data[ret_index[i]];
434  }
435  return ret;
436  }
442  static OperatorProperty* Create(const char* type_name);
444  virtual ExecType exec_type() const {
445  return ExecType::kSync;
446  }
447 };
448 
450 typedef std::function<OperatorProperty*()> OperatorPropertyFactory;
455  : public dmlc::FunctionRegEntryBase<OperatorPropertyReg, OperatorPropertyFactory> {
468  inline OperatorPropertyReg& set_key_var_num_args(const std::string& key) { // NOLINT(*)
469  this->key_var_num_args = key;
470  return *this;
471  }
476  OperatorProperty* p = this->body();
477  std::string type = p->TypeString();
478  delete p;
479  CHECK_EQ(this->name, type) << "Register Name and TypeString mismatch, name=\"" << this->name
480  << "\","
481  << " but TypeString=\"" << type << "\"";
482  return *this;
483  }
484 
/*! \brief Name of the argument holding the variadic argument count; set via set_key_var_num_args. */
std::string key_var_num_args;
487 };
488 
489 //---------------------------------------------------------------------------------
490 // The following part is the API for registering operators.
491 // See also MXNET_REGISTER_SIMPLE_OP in operator_util.h for registering simple ops.
492 //---------------------------------------------------------------------------------
// Register OperatorPropertyType under `name`. The factory default-constructs a
// new OperatorPropertyType, the return type is tagged "NDArray-or-Symbol", and
// check_name() verifies that the type's TypeString() equals `name`.
#define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType)                             \
  DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name)          \
      .set_body([]() -> ::mxnet::OperatorProperty* { return new OperatorPropertyType(); }) \
      .set_return_type("NDArray-or-Symbol")                                                \
      .check_name()
508 
509 #endif // DMLC_USE_CXX11
510 } // namespace mxnet
511 #endif // MXNET_OPERATOR_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::OperatorProperty::GetParams
virtual std::map< std::string, std::string > GetParams() const =0
Get a map representation of internal parameters. This can be used by Init to recover the state of Ope...
mshadow::default_type_flag
const int default_type_flag
type enum value for default real type
Definition: base.h:492
dmlc::FunctionRegEntryBase
Common base class for function registry.
Definition: registry.h:151
mxnet::OperatorProperty
OperatorProperty is an object that stores all information about an Operator. It also contains methods to g...
Definition: operator.h:127
mxnet::OperatorPropertyReg::check_name
OperatorPropertyReg & check_name()
Check if TypeString of the type matches the registered name.
Definition: operator.h:475
mxnet::OperatorProperty::Init
virtual void Init(const std::vector< std::pair< std::string, std::string > > &kwargs)=0
Initialize the Operator by setting the parameters. This function needs to be called before all other fu...
mxnet::OperatorProperty::InferType
virtual bool InferType(std::vector< int > *in_type, std::vector< int > *out_type, std::vector< int > *aux_type) const
infer the data types of outputs and unknown input arguments
Definition: operator.h:221
mxnet::Operator
Operator interface. Operator defines the basic operation unit of the optimized computation graph in mxnet....
Definition: operator.h:54
mxnet::OperatorProperty::CreateOperator
virtual Operator * CreateOperator(Context ctx) const =0
Create an Operator on a specific context.
mxnet::OperatorProperty::exec_type
virtual ExecType exec_type() const
Definition: operator.h:444
base.h
defines configuration macros
mxnet::OperatorProperty::TypeString
virtual std::string TypeString() const =0
Return the type string of the Operator; subclasses override this function.
mxnet::OperatorProperty::~OperatorProperty
virtual ~OperatorProperty()
virtual destructor
Definition: operator.h:132
mxnet::OperatorProperty::InferShape
virtual bool InferShape(mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape, mxnet::ShapeVector *aux_shape) const =0
infer the shapes of outputs and unknown input arguments
mxnet::OpContext
All the possible information needed by Operator. This is the superset of RunContext....
Definition: op_attr_types.h:66
mxnet::OperatorPropertyReg
Registry entry for OperatorProperty factory functions.
Definition: operator.h:454
mxnet::OperatorProperty::NumOutputs
virtual int NumOutputs() const
Definition: operator.h:166
mxnet::OperatorPropertyReg::set_key_var_num_args
OperatorPropertyReg & set_key_var_num_args(const std::string &key)
Set key_var_num_args. When this is set, the API caller is required to pass in an argument with key=key_...
Definition: operator.h:468
mxnet::Operator::Forward
virtual void Forward(const OpContext &ctx, const std::vector< TBlob > &in_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &out_data, const std::vector< TBlob > &aux_states)=0
perform a forward operation of Operator, save the output to TBlob.
dmlc::type_name
std::string type_name()
the string representation of type name
Definition: type_traits.h:101
mxnet::OperatorPropertyFactory
std::function< OperatorProperty *()> OperatorPropertyFactory
typedef the factory function of operator property
Definition: operator.h:450
mxnet::OperatorProperty::Copy
virtual OperatorProperty * Copy() const =0
Copy this OperatorProperty.
mxnet::OperatorProperty::DeclareBackwardDependency
virtual std::vector< int > DeclareBackwardDependency(const std::vector< int > &out_grad, const std::vector< int > &in_data, const std::vector< int > &out_data) const
Declare the input requirement of Backward pass.
Definition: operator.h:325
resource.h
Global resource allocation handling.
mxnet::ExecType
ExecType
the execution type of the operator
Definition: op_attr_types.h:98
mxnet::ExecType::kSync
@ kSync
Forward/Backward are synchronous calls.
mxnet::OperatorProperty::ListOutputs
virtual std::vector< std::string > ListOutputs() const
Get name of output values of Operator.
Definition: operator.h:155
mxnet::Context
Context information about the execution environment.
Definition: base.h:90
mxnet::OperatorProperty::BackwardInplaceOption
virtual std::vector< std::pair< int, void * > > BackwardInplaceOption(const std::vector< int > &out_grad, const std::vector< int > &in_data, const std::vector< int > &out_data, const std::vector< void * > &in_grad) const
Get possible backward inplace options. This function enables optimization to reuse memory of inputs i...
Definition: operator.h:387
op_attr_types.h
Additional operator attributes beside the ones provided by NNVM.
mxnet::OperatorPropertyReg::key_var_num_args
std::string key_var_num_args
The key num_args name.
Definition: operator.h:486
mxnet::OperatorProperty::ForwardInplaceOption
virtual std::vector< std::pair< int, void * > > ForwardInplaceOption(const std::vector< int > &in_data, const std::vector< void * > &out_data) const
Get possible forward inplace options. This function enables optimization to reuse memory of inputs in...
Definition: operator.h:356
mxnet::OperatorProperty::ListArguments
virtual std::vector< std::string > ListArguments() const
Get input arguments of the Operator.
Definition: operator.h:148
mxnet::OperatorProperty::NumVisibleOutputs
virtual int NumVisibleOutputs() const
Get the number of visible return values during Symbol creation. If NumVisibleOutputs() = k,...
Definition: operator.h:181
mxnet::Operator::exec_type
virtual ExecType exec_type() const final
Definition: operator.h:112
mxnet::OperatorProperty::BackwardResource
virtual std::vector< ResourceRequest > BackwardResource(const mxnet::ShapeVector &in_shape) const
Declare additional resource required in backward pass. These additional resources will be presented i...
Definition: operator.h:300
mxnet::ShapeVector
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:830
json.h
Lightweight JSON Reader/Writer that read save into C++ data structs. This includes STL composites and...
dmlc::FunctionRegEntryBase< OperatorPropertyReg, OperatorPropertyFactory >::body
OperatorPropertyFactory body
Function body to create ProductType.
Definition: registry.h:160
registry.h
Registry utility that helps to build registry singletons.
node.h
Graph node data structure.
mxnet::OperatorProperty::CreateOperatorEx
virtual Operator * CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, std::vector< int > *in_type) const
Create an Operator on a specific context and input shape/type.
Definition: operator.h:261
mxnet::OperatorProperty::BackwardInputs
std::vector< T > BackwardInputs(const std::vector< T > &out_grad, const std::vector< T > &in_data, const std::vector< T > &out_data) const
Get the backward input dependency for generic types of data. Normally T can be a pointer to Symbol::DataEnt...
Definition: operator.h:407
mxnet::Operator::~Operator
virtual ~Operator()
destructor
Definition: operator.h:57
mxnet::OperatorProperty::Create
static OperatorProperty * Create(const char *type_name)
create OperatorProperty
mxnet::Operator::Backward
virtual void Backward(const OpContext &ctx, const std::vector< TBlob > &out_grad, const std::vector< TBlob > &in_data, const std::vector< TBlob > &out_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &in_grad, const std::vector< TBlob > &aux_states)
Perform a Backward Operation, write gradient to the in_grad.
Definition: operator.h:102
mxnet::OperatorProperty::ForwardResource
virtual std::vector< ResourceRequest > ForwardResource(const mxnet::ShapeVector &in_shape) const
Declare additional resource required in forward pass. These additional resources will be presented in...
Definition: operator.h:290
dmlc::FunctionRegEntryBase< OperatorPropertyReg, OperatorPropertyFactory >::name
std::string name
name of the entry
Definition: registry.h:154
mxnet::OperatorProperty::ListAuxiliaryStates
virtual std::vector< std::string > ListAuxiliaryStates() const
Get name of auxiliary states of Operator.
Definition: operator.h:162