mxnet
mxnet::Imperative Class Reference

runtime functions for NDArray

#include <imperative.h>


Classes

class  AGInfo
 
class  CachedOp
 

Public Member Functions

bool is_training () const
 whether training mode is on.
 
bool set_is_training (bool is_train)
 turn on or turn off training mode.
 
bool is_recording () const
 whether operator recording is on.
 
bool set_is_recording (bool is_recording)
 turn on or turn off operator recording for autograd.
 
void RecordOp (nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const OpStatePtr &state=OpStatePtr(), std::vector< bool > *p_save_inputs=nullptr, std::vector< bool > *p_save_outputs=nullptr)
 record an operator invocation (its attributes, inputs and outputs) for autograd.
 
OpStatePtr Invoke (const Context &default_ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
 
OpStatePtr InvokeOp (const Context &ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const std::vector< OpReqType > &req, const DispatchMode dispatch_mode, OpStatePtr state=OpStatePtr())
 
void MarkVariables (const std::vector< NDArray * > &variables, const std::vector< mx_uint > &grad_reqs, const std::vector< NDArray * > &gradients)
 mark variables for computing gradients.
 
std::vector< NDArray * > Backward (const std::vector< NDArray * > &outputs, const std::vector< NDArray * > &ograds, const std::vector< NDArray * > &variables, bool is_train, bool retain_graph, bool create_graph)
 compute the gradient of outputs w.r.t. variables.
 

Static Public Member Functions

static Imperative * Get ()
 

Friends

class NDArray
 

Detailed Description

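runtime functions for NDArray: imperative operator invocation and autograd support (operator recording, marking variables, and backward computation).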

Member Function Documentation

std::vector<NDArray*> mxnet::Imperative::Backward ( const std::vector< NDArray * > &  outputs,
const std::vector< NDArray * > &  ograds,
const std::vector< NDArray * > &  variables,
bool  is_train,
bool  retain_graph,
bool  create_graph 
)

compute the gradient of outputs w.r.t. variables.
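The following is a toy, scalar-only sketch of what computing the gradient of recorded outputs with respect to marked variables involves. It is not MXNet's implementation: `Tape`, `Var`, `Add`, `Mul` and the single `ograd` argument (standing in for one entry of `ograds`, assumed here to default to 1) are hypothetical. Ops are appended in execution order, and `Backward` walks them in reverse, applying the chain rule and accumulating into `grad`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy tape, NOT MXNet's implementation. Every value is a scalar
// slot; every recorded op remembers its input and output slots.
struct Tape {
  enum class Kind { kAdd, kMul };
  struct Op { Kind kind; std::size_t a, b, out; };

  std::vector<double> value;  // forward values, indexed by slot id
  std::vector<double> grad;   // accumulated gradients, same indexing
  std::vector<Op> ops;        // recorded ops in execution order

  // Create a new slot (a "variable") holding v, with zero gradient.
  std::size_t Var(double v) {
    value.push_back(v);
    grad.push_back(0.0);
    return value.size() - 1;
  }
  // Compute a + b and record the op.
  std::size_t Add(std::size_t a, std::size_t b) {
    std::size_t out = Var(value[a] + value[b]);
    ops.push_back({Kind::kAdd, a, b, out});
    return out;
  }
  // Compute a * b and record the op.
  std::size_t Mul(std::size_t a, std::size_t b) {
    std::size_t out = Var(value[a] * value[b]);
    ops.push_back({Kind::kMul, a, b, out});
    return out;
  }
  // Seed the head gradient of `output`, then replay the tape backwards,
  // applying the chain rule and accumulating into the input slots.
  void Backward(std::size_t output, double ograd = 1.0) {
    grad[output] += ograd;
    for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
      double g = grad[it->out];
      if (it->kind == Kind::kAdd) {
        grad[it->a] += g;
        grad[it->b] += g;
      } else {
        grad[it->a] += g * value[it->b];
        grad[it->b] += g * value[it->a];
      }
    }
  }
};
```

For `y = a*b + a` with `a = 3` and `b = 4`, `Backward(y)` leaves `grad[a] == 5` (that is, `b + 1`) and `grad[b] == 3`.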

static Imperative* mxnet::Imperative::Get ( )
static
Returns
the Imperative singleton
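A common way to implement a `Get()` accessor that returns a process-wide singleton is a function-local static. The sketch below shows that pattern with a hypothetical `Runtime` class; it is not taken from MXNet's source, and whether `Imperative::Get()` is written this way is an assumption.

```cpp
#include <cassert>

// Hypothetical singleton; illustrates the pattern only, not MXNet's code.
class Runtime {
 public:
  // The function-local static is constructed on the first call and lives
  // until program exit; since C++11 its initialization is thread-safe.
  static Runtime* Get() {
    static Runtime inst;
    return &inst;
  }
  Runtime(const Runtime&) = delete;             // no copies of the singleton
  Runtime& operator=(const Runtime&) = delete;

  int counter = 0;  // example of state shared by every caller of Get()

 private:
  Runtime() = default;  // only Get() can construct the instance
};
```

Every call to `Runtime::Get()` returns the same pointer, so state written through one call is visible through the next.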
OpStatePtr mxnet::Imperative::Invoke ( const Context &  default_ctx,
const nnvm::NodeAttrs &  attrs,
const std::vector< NDArray * > &  inputs,
const std::vector< NDArray * > &  outputs 
)
OpStatePtr mxnet::Imperative::InvokeOp ( const Context &  ctx,
const nnvm::NodeAttrs &  attrs,
const std::vector< NDArray * > &  inputs,
const std::vector< NDArray * > &  outputs,
const std::vector< OpReqType > &  req,
const DispatchMode  dispatch_mode,
OpStatePtr  state = OpStatePtr() 
)
bool mxnet::Imperative::is_recording ( ) const
inline

whether operator recording is on.

bool mxnet::Imperative::is_training ( ) const
inline

whether training mode is on.

void mxnet::Imperative::MarkVariables ( const std::vector< NDArray * > &  variables,
const std::vector< mx_uint > &  grad_reqs,
const std::vector< NDArray * > &  gradients 
)

mark variables for computing gradients.
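`MarkVariables` takes three parallel arrays: presumably, `grad_reqs[i]` and `gradients[i]` describe `variables[i]`. The sketch below assumes each `mx_uint` request encodes a write mode in the style of `OpReqType` (no-op, write, write in place, add). The enum values and the helpers `ApplyGrad` and `CheckParallel` are hypothetical, chosen only to illustrate how such a request could decide whether a new gradient overwrites or accumulates into the stored buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>

// Hypothetical write-request enum; the numeric values are an assumption
// modeled on the usual OpReqType ordering.
enum GradReq : unsigned { kNullOp = 0, kWriteTo = 1, kWriteInplace = 2, kAddTo = 3 };

// Apply a freshly computed gradient `src` to the stored buffer `dst`
// according to the request type.
inline void ApplyGrad(unsigned req, double& dst, double src) {
  switch (static_cast<GradReq>(req)) {
    case kNullOp:
      break;                     // gradient not requested: leave dst alone
    case kWriteTo:
    case kWriteInplace:
      dst = src;                 // overwrite the previous gradient
      break;
    case kAddTo:
      dst += src;                // accumulate across backward passes
      break;
    default:
      throw std::invalid_argument("unknown grad_req");
  }
}

// Hypothetical check that the three arrays line up one-to-one.
inline void CheckParallel(std::size_t n_vars, std::size_t n_reqs, std::size_t n_grads) {
  if (n_reqs != n_vars || n_grads != n_vars) {
    throw std::invalid_argument("variables, grad_reqs and gradients must have equal length");
  }
}
```

With an add-style request, repeated backward passes sum their gradients into the same buffer; with a write-style request, only the latest gradient is kept.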

void mxnet::Imperative::RecordOp ( nnvm::NodeAttrs &&  attrs,
const std::vector< NDArray * > &  inputs,
const std::vector< NDArray * > &  outputs,
const OpStatePtr &  state = OpStatePtr(),
std::vector< bool > *  p_save_inputs = nullptr,
std::vector< bool > *  p_save_outputs = nullptr 
)

record an operator invocation (its attributes, inputs and outputs) for autograd.
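Recording only matters while it is switched on: an operator that runs while recording is off leaves nothing for `Backward` to differentiate. The following hypothetical `Recorder` (not MXNet's code) makes that gate explicit at the call site; the `Node` layout, the integer array ids, and the default of saving every input when `p_save_inputs` is null are all assumptions of this sketch.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical graph node: op name, ids of arrays read and written, and
// which inputs backward would need to keep.
struct Node {
  std::string op;
  std::vector<int> inputs;
  std::vector<int> outputs;
  std::vector<bool> save_inputs;
};

class Recorder {
 public:
  void set_recording(bool on) { recording_ = on; }

  // Append one node describing an op and the arrays it read and wrote.
  void RecordOp(const std::string& op, const std::vector<int>& in,
                const std::vector<int>& out,
                const std::vector<bool>* p_save_inputs = nullptr) {
    std::vector<bool> save(in.size(), true);  // assumption: keep every input
    if (p_save_inputs) save = *p_save_inputs;
    graph_.push_back(Node{op, in, out, save});
  }

  // Caller-side gate: the op always runs, but it is recorded only while
  // recording is on, so only recorded ops can later be differentiated.
  void Invoke(const std::string& op, const std::vector<int>& in,
              const std::vector<int>& out) {
    ++invoked_;  // stands in for actually executing the op
    if (recording_) RecordOp(op, in, out);
  }

  const std::vector<Node>& graph() const { return graph_; }
  int invoked() const { return invoked_; }

 private:
  bool recording_ = false;
  int invoked_ = 0;
  std::vector<Node> graph_;
};
```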

bool mxnet::Imperative::set_is_recording ( bool  is_recording)
inline

turn on or turn off operator recording for autograd.

bool mxnet::Imperative::set_is_training ( bool  is_train)
inline

turn on or turn off training mode.
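Assuming, as the `bool` return type suggests, that `set_is_recording` and `set_is_training` return the previous value, a caller can scope a mode change and restore it afterwards. The sketch below uses a hypothetical `ModeFlags` struct and `RecordScope` guard (neither is MXNet API) to show the pattern, including recording with training mode off, which the two independent flags allow.

```cpp
#include <cassert>

// Hypothetical flag holder mirroring the documented setters: each setter
// returns the previous value so callers can restore it later.
struct ModeFlags {
  bool is_recording = false;
  bool is_training = false;

  bool set_is_recording(bool on) { bool old = is_recording; is_recording = on; return old; }
  bool set_is_training(bool on) { bool old = is_training; is_training = on; return old; }
};

// RAII scope: switch both flags on construction and restore the saved
// values on destruction (also when an exception unwinds the scope).
class RecordScope {
 public:
  RecordScope(ModeFlags& f, bool recording, bool training)
      : flags_(f),
        prev_recording_(f.set_is_recording(recording)),
        prev_training_(f.set_is_training(training)) {}
  ~RecordScope() {
    flags_.set_is_training(prev_training_);
    flags_.set_is_recording(prev_recording_);
  }
  RecordScope(const RecordScope&) = delete;
  RecordScope& operator=(const RecordScope&) = delete;

 private:
  ModeFlags& flags_;
  bool prev_recording_;
  bool prev_training_;
};
```

Restoring in the destructor means nested scopes unwind correctly, and the flags are reset even if an exception leaves the scope.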

Friends And Related Function Documentation

friend class NDArray
friend

The documentation for this class was generated from the following file: