mxnet
mxnet::Imperative Class Reference

Runtime functions for NDArray.

#include <imperative.h>


Classes

class  AGInfo
 

Public Member Functions

bool is_training () const
 Whether training mode is on.
 
bool set_is_training (bool is_train)
 Turn training mode on or off; returns the previous setting.
 
bool is_recording () const
 Whether operator recording is on.
 
bool set_is_recording (bool is_recording)
 Turn operator recording for autograd on or off; returns the previous setting.
 
int is_np_shape () const
 Return the current NumPy shape-compatibility status: GlobalOn (2), ThreadLocalOn (1), or Off (0).
 
bool is_np_default_dtype () const
 Return the current NumPy default-dtype compatibility status.
 
bool set_is_np_shape (int is_np_shape)
 Set NumPy shape compatibility to off, thread-local on, or global on; returns the previous setting.
 
void RecordOp (nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const OpStatePtr &state=OpStatePtr(), std::vector< bool > *p_save_inputs=nullptr, std::vector< bool > *p_save_outputs=nullptr)
 Record an operator invocation for autograd.
 
OpStatePtr Invoke (const Context &default_ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
 
OpStatePtr InvokeOp (const Context &ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const std::vector< OpReqType > &req, const DispatchMode dispatch_mode, OpStatePtr state=OpStatePtr())
 
void MarkVariables (const std::vector< NDArray * > &variables, const std::vector< uint32_t > &grad_reqs, const std::vector< NDArray * > &gradients)
 Mark variables for computing gradients.
 
std::vector< NDArray * > Backward (const std::vector< NDArray * > &outputs, const std::vector< NDArray * > &ograds, const std::vector< NDArray * > &variables, bool is_train, bool retain_graph, bool create_graph)
 Compute the gradients of outputs with respect to variables.
 

Static Public Member Functions

static Imperative * Get ()
 
static bool PreferBulkExecInference ()
 Whether op execution bulking should be employed during inference.
 
static bool PreferBulkExecTrain ()
 Whether op execution bulking should be employed during training.
 
static int BulkExecMaxNodeTrainFwd ()
 The maximum number of op nodes in a bulk during the forward pass of training.
 
static int BulkExecMaxNodeTrainBwd ()
 The maximum number of op nodes in a bulk during the backward pass of training.
 

Friends

class NDArray
 

Detailed Description

Runtime functions for NDArray.

Member Function Documentation

std::vector<NDArray*> mxnet::Imperative::Backward ( const std::vector< NDArray * > &  outputs,
const std::vector< NDArray * > &  ograds,
const std::vector< NDArray * > &  variables,
bool  is_train,
bool  retain_graph,
bool  create_graph 
)

Compute the gradients of outputs with respect to variables.
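
To make the idea of "gradient of outputs with respect to variables" concrete, here is a minimal reverse-mode sketch. It is a hypothetical toy tape on scalars, not MXNet's implementation: the names ToyTape, Mul, Add and the ograd parameter are assumptions for illustration. The ograd argument plays the role of a head gradient that seeds the output; retain_graph and create_graph are not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical toy tape illustrating reverse-mode autodiff; not MXNet code.
struct ToyTape {
  std::vector<double> value;  // value of every variable and intermediate
  struct Node {
    std::size_t out;                                     // index of the result
    std::vector<std::size_t> in;                         // indices of the inputs
    std::function<std::vector<double>(double)> backward; // d(out)/d(in_i) * grad
  };
  std::vector<Node> nodes;  // recorded in forward (execution) order

  std::size_t Variable(double v) {
    value.push_back(v);
    return value.size() - 1;
  }

  std::size_t Mul(std::size_t a, std::size_t b) {
    value.push_back(value[a] * value[b]);
    std::size_t out = value.size() - 1;
    double va = value[a], vb = value[b];  // inputs saved for the backward pass
    nodes.push_back({out, {a, b}, [va, vb](double g) {
                       return std::vector<double>{g * vb, g * va};
                     }});
    return out;
  }

  std::size_t Add(std::size_t a, std::size_t b) {
    value.push_back(value[a] + value[b]);
    std::size_t out = value.size() - 1;
    nodes.push_back({out, {a, b}, [](double g) {
                       return std::vector<double>{g, g};
                     }});
    return out;
  }

  // Seed the output gradient with ograd, walk the recorded nodes in reverse,
  // accumulate gradients into each input, and return those of `variables`.
  std::vector<double> Backward(std::size_t output,
                               const std::vector<std::size_t>& variables,
                               double ograd = 1.0) const {
    std::vector<double> grad(value.size(), 0.0);
    grad[output] = ograd;
    for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
      std::vector<double> local = it->backward(grad[it->out]);
      for (std::size_t i = 0; i < it->in.size(); ++i) grad[it->in[i]] += local[i];
    }
    std::vector<double> result;
    for (std::size_t v : variables) result.push_back(grad[v]);
    return result;
  }
};
```

For z = x * y + x with x = 3 and y = 4, the sketch returns dz/dx = y + 1 = 5 and dz/dy = x = 3; passing ograd = 2 doubles both.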

static int mxnet::Imperative::BulkExecMaxNodeTrainBwd ( )
inline static

The maximum number of op nodes in a bulk during the backward pass of training.

static int mxnet::Imperative::BulkExecMaxNodeTrainFwd ( )
inline static

The maximum number of op nodes in a bulk during the forward pass of training.
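
A cap on "op nodes in a bulk" can be pictured as splitting a run of consecutive operators into segments that are handed to the engine together. The sketch below is a hypothetical illustration of that cap only: SegmentIntoBulks is an invented name, and the real engine's segmentation rules (for example, operators that cannot be bulked) are not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: group consecutive op node ids into bulks of at most
// max_nodes each; a new bulk starts once the current one is full.
std::vector<std::vector<int>> SegmentIntoBulks(const std::vector<int>& node_ids,
                                               int max_nodes) {
  std::vector<std::vector<int>> bulks;
  if (max_nodes <= 0) max_nodes = 1;  // treat non-positive caps as "no bulking"
  for (int id : node_ids) {
    if (bulks.empty() || static_cast<int>(bulks.back().size()) >= max_nodes) {
      bulks.emplace_back();  // start a new bulk
    }
    bulks.back().push_back(id);
  }
  return bulks;
}
```

With seven nodes and a cap of 3, the sketch yields bulks of sizes 3, 3 and 1. Separate forward and backward caps let the two passes trade scheduling overhead against parallelism independently.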

static Imperative* mxnet::Imperative::Get ( )
static
Returns
The Imperative singleton.
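
The reference does not show how Get() is implemented. One common C++ idiom for a static accessor that returns a pointer to a single process-wide instance is a function-local static (the "Meyers singleton"). The sketch below uses a hypothetical ToyRuntime class to show that idiom; it is not the MXNet source.

```cpp
#include <cassert>

// Hypothetical stand-in mirroring the "static Get() returns a pointer to one
// shared instance" shape; not MXNet code.
class ToyRuntime {
 public:
  static ToyRuntime* Get() {
    // Constructed once, on first call; since C++11 this initialization is
    // guaranteed to be thread-safe.
    static ToyRuntime instance;
    return &instance;
  }
  ToyRuntime(const ToyRuntime&) = delete;
  ToyRuntime& operator=(const ToyRuntime&) = delete;

  int counter = 0;  // example of shared state

 private:
  ToyRuntime() = default;  // only Get() can create the instance
};
```

Every call to ToyRuntime::Get() returns the same address, so state written through one call is visible through the next.
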
OpStatePtr mxnet::Imperative::Invoke ( const Context &  default_ctx,
const nnvm::NodeAttrs &  attrs,
const std::vector< NDArray * > &  inputs,
const std::vector< NDArray * > &  outputs 
)
OpStatePtr mxnet::Imperative::InvokeOp ( const Context &  ctx,
const nnvm::NodeAttrs &  attrs,
const std::vector< NDArray * > &  inputs,
const std::vector< NDArray * > &  outputs,
const std::vector< OpReqType > &  req,
const DispatchMode  dispatch_mode,
OpStatePtr  state = OpStatePtr() 
)
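
InvokeOp takes a req vector of OpReqType, presumably one entry per output, which tells the operator how to treat each output buffer. As a hedged illustration, the sketch below defines a local enum that mirrors the names of mxnet::OpReqType (kNullOp, kWriteTo, kWriteInplace, kAddTo) and a hypothetical Assign helper showing the usual meaning of each request: skip, overwrite, or accumulate.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Local stand-in mirroring the names of mxnet::OpReqType (not the real header).
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Hypothetical helper: store `result` into `out` according to `req`.
void Assign(std::vector<double>* out, const std::vector<double>& result,
            OpReqType req) {
  assert(out->size() == result.size());
  for (std::size_t i = 0; i < result.size(); ++i) {
    switch (req) {
      case kNullOp:  // output not requested: leave the buffer untouched
        break;
      case kWriteTo:
      case kWriteInplace:  // overwrite the existing contents
        (*out)[i] = result[i];
        break;
      case kAddTo:  // accumulate, e.g. summing gradients from several uses
        (*out)[i] += result[i];
        break;
    }
  }
}
```

Accumulation (kAddTo) is what lets gradients from a variable used in several places add up instead of overwriting each other.
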
bool mxnet::Imperative::is_np_default_dtype ( ) const
inline

Return the current NumPy default-dtype compatibility status.

int mxnet::Imperative::is_np_shape ( ) const
inline

Return the current NumPy shape-compatibility status: GlobalOn (2), ThreadLocalOn (1), or Off (0).

bool mxnet::Imperative::is_recording ( ) const
inline

Whether operator recording is on.

bool mxnet::Imperative::is_training ( ) const
inline

Whether training mode is on.

void mxnet::Imperative::MarkVariables ( const std::vector< NDArray * > &  variables,
const std::vector< uint32_t > &  grad_reqs,
const std::vector< NDArray * > &  gradients 
)

Mark variables for computing gradients.
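
The three arguments of MarkVariables are parallel vectors: variable i is paired with grad_reqs[i] and gradients[i]. The sketch below is a hypothetical bookkeeping routine (MarkedVar and MarkVariablesSketch are invented names) that checks the lengths line up and records each pairing. It assumes the uint32_t request codes follow the OpReqType order, so 0 means null, 1 write, 2 write in place, and 3 add; that mapping is an assumption here, not something this reference states.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical record of one marked variable: where its value lives, where its
// gradient should be written, and how (assumed OpReqType-style code).
struct MarkedVar {
  const double* variable;
  double* gradient;
  std::uint32_t grad_req;
};

std::vector<MarkedVar> MarkVariablesSketch(
    const std::vector<const double*>& variables,
    const std::vector<std::uint32_t>& grad_reqs,
    const std::vector<double*>& gradients) {
  // The vectors are parallel, so their lengths must agree.
  assert(variables.size() == grad_reqs.size());
  assert(variables.size() == gradients.size());
  std::vector<MarkedVar> marked;
  for (std::size_t i = 0; i < variables.size(); ++i) {
    marked.push_back({variables[i], gradients[i], grad_reqs[i]});
  }
  return marked;
}
```
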

static bool mxnet::Imperative::PreferBulkExecInference ( )
inline static

Whether op execution bulking should be employed during inference.

static bool mxnet::Imperative::PreferBulkExecTrain ( )
inline static

Whether op execution bulking should be employed during training.
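
Switches like PreferBulkExecTrain() are natural candidates for process-level configuration. The sketch below assumes the setting comes from an integer environment variable with a default of "on". The variable name MXNET_EXEC_BULK_EXEC_TRAIN is taken from MXNet's environment-variable documentation and should be checked against your version. ParseIntSetting and PreferBulkExecTrainSketch are hypothetical helpers, not dmlc::GetEnv or the real member function.

```cpp
#include <cassert>
#include <cstdlib>

// Parse an integer setting: unset (nullptr) or empty strings fall back to
// default_value. std::atoi returns 0 for non-numeric text.
int ParseIntSetting(const char* raw, int default_value) {
  if (raw == nullptr || *raw == '\0') return default_value;
  return std::atoi(raw);
}

// Hypothetical query in the spirit of PreferBulkExecTrain(): bulking stays on
// unless the (assumed) environment variable is explicitly set to 0.
bool PreferBulkExecTrainSketch() {
  return ParseIntSetting(std::getenv("MXNET_EXEC_BULK_EXEC_TRAIN"), 1) != 0;
}
```

Separating parsing from the std::getenv lookup keeps the parsing rules testable without touching the process environment.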

void mxnet::Imperative::RecordOp ( nnvm::NodeAttrs &&  attrs,
const std::vector< NDArray * > &  inputs,
const std::vector< NDArray * > &  outputs,
const OpStatePtr &  state = OpStatePtr(),
std::vector< bool > *  p_save_inputs = nullptr,
std::vector< bool > *  p_save_outputs = nullptr 
)

Record an operator invocation for autograd.
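
The optional p_save_inputs and p_save_outputs parameters suggest that a recorded entry need not keep every tensor alive, only those the backward pass depends on. The sketch below is a hypothetical recorder on scalars (TapeEntry and ToyRecorder are invented names) showing that idea with an input mask. It also ignores calls made while recording is off; whether the real RecordOp checks that itself is not stated here.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical tape entry: the op name plus copies of only those inputs that
// the backward pass will need.
struct TapeEntry {
  std::string op;
  std::vector<double> saved_inputs;
};

struct ToyRecorder {
  bool is_recording = false;
  std::vector<TapeEntry> tape;

  void RecordOp(const std::string& op, const std::vector<double>& inputs,
                const std::vector<bool>* save_inputs = nullptr) {
    if (!is_recording) return;  // nothing is recorded outside a recording scope
    TapeEntry entry{op, {}};
    for (std::size_t i = 0; i < inputs.size(); ++i) {
      // With no mask, conservatively save every input.
      bool keep = save_inputs == nullptr || (*save_inputs)[i];
      if (keep) entry.saved_inputs.push_back(inputs[i]);
    }
    tape.push_back(entry);
  }
};
```

An addition needs none of its inputs to compute its gradient, while a multiplication needs the other operand, so a mask lets the recorder free tensors early.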

bool mxnet::Imperative::set_is_np_shape ( int  is_np_shape)
inline

Set NumPy shape compatibility to off, thread-local on, or global on. Returns the previous setting.
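
The status codes GlobalOn (2), ThreadLocalOn (1) and Off (0) imply two layers of state: a process-wide flag and a per-thread flag. The sketch below models that with a std::atomic<bool> and a thread_local bool. It is a hypothetical reconstruction (NumpyShapeMode, IsNpShape and SetIsNpShape are invented names), not the MXNet implementation, and its precedence rule (global wins) is an assumption.

```cpp
#include <atomic>
#include <cassert>

// Hypothetical three-state switch using the encoding listed above.
enum NumpyShapeMode { kOff = 0, kThreadLocalOn = 1, kGlobalOn = 2 };

std::atomic<bool> np_shape_global{false};     // shared by all threads
thread_local bool np_shape_thread_local = false;  // one copy per thread

int IsNpShape() {
  if (np_shape_global.load()) return kGlobalOn;
  return np_shape_thread_local ? kThreadLocalOn : kOff;
}

// Apply the requested mode; return whether shape semantics were on before.
bool SetIsNpShape(int mode) {
  bool old = IsNpShape() != kOff;
  switch (mode) {
    case kGlobalOn:
      np_shape_global.store(true);
      np_shape_thread_local = true;
      break;
    case kThreadLocalOn:
      np_shape_thread_local = true;
      break;
    default:  // kOff: clear both layers
      np_shape_global.store(false);
      np_shape_thread_local = false;
      break;
  }
  return old;
}
```

Under this model, a thread-local setting affects only the calling thread, while the global flag is visible to every thread.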

bool mxnet::Imperative::set_is_recording ( bool  is_recording)
inline

Turn operator recording for autograd on or off. Returns the previous setting.

bool mxnet::Imperative::set_is_training ( bool  is_train)
inline

Turn training mode on or off. Returns the previous setting.
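
Because set_is_recording() and set_is_training() hand back the previous value, a caller can switch a mode for a scope and restore it afterwards, even when scopes nest. The sketch below shows that save-and-restore pattern as a C++ RAII guard over a hypothetical ToyAutogradState class with the same setter shape; RecordScope is an invented name. MXNet's Python autograd.record() context manager follows a similar idea, but this is not its code.

```cpp
#include <cassert>

// Hypothetical stand-in: each setter stores the new flag and returns the old one.
class ToyAutogradState {
 public:
  bool is_recording() const { return recording_; }
  bool set_is_recording(bool v) { bool old = recording_; recording_ = v; return old; }
  bool is_training() const { return training_; }
  bool set_is_training(bool v) { bool old = training_; training_ = v; return old; }

 private:
  bool recording_ = false;
  bool training_ = false;
};

// RAII scope: apply the requested modes on entry, restore the previous ones on
// exit (also when an exception unwinds the scope).
class RecordScope {
 public:
  RecordScope(ToyAutogradState* s, bool record, bool train)
      : s_(s),
        prev_record_(s->set_is_recording(record)),
        prev_train_(s->set_is_training(train)) {}
  ~RecordScope() {
    s_->set_is_training(prev_train_);
    s_->set_is_recording(prev_record_);
  }
  RecordScope(const RecordScope&) = delete;
  RecordScope& operator=(const RecordScope&) = delete;

 private:
  ToyAutogradState* s_;
  bool prev_record_;
  bool prev_train_;
};
```

Restoring the saved value, rather than unconditionally switching the flag off, is what keeps nested scopes correct: an inner "pause" scope returns control to the enclosing recording scope instead of disabling it.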

Friends And Related Function Documentation

friend class NDArray
friend

The documentation for this class was generated from the following file: