mxnet Namespace Reference

namespace of mxnet More...

Namespaces

 common
 
 cpp
 
 csr
 
 engine
 namespace of engine internal types.
 
 features
 
 op
 namespace of arguments
 
 rowsparse
 
 runtime
 

Classes

class  Array
 Array container of NodeRef in the DSL graph. Array implements copy-on-write semantics: the array is mutable, but a copy is made when the array is referenced in more than one place. More...
 
class  ArrayNode
 array node content in array More...
 
class  BaseExpr
 Managed reference to BaseExprNode. More...
 
class  BaseExprNode
 Base type of all the expressions. More...
 
struct  Context
 Context information about the execution environment. More...
 
struct  DataBatch
 DataBatch of NDArray, returned by Iterator. More...
 
struct  DataInst
 a single data instance More...
 
struct  DataIteratorReg
 Registry entry for DataIterator factory functions. More...
 
class  Engine
 Dependency engine that schedules operations. More...
 
class  Executor
 Executor of a computation graph. Executor can be created by Binding a symbol. More...
 
class  FloatImm
 Managed reference class to FloatImmNode. More...
 
class  FloatImmNode
 Constant floating point literals in the program. More...
 
class  GPUAuxStream
 Holds an auxiliary mshadow gpu stream that can be synced with a primary stream. More...
 
class  IIterator
 iterator type More...
 
class  Imperative
 runtime functions for NDArray More...
 
struct  InspectorManager
 this singleton struct mediates individual TensorInspector objects so that we can control the global behavior from each of them More...
 
class  IntImm
 Managed reference class to IntImmNode. More...
 
class  IntImmNode
 Constant integer literals in the program. More...
 
class  IterAdapter
 iterator adapter that adapts TIter to return another type. More...
 
class  KVStore
 distributed key-value store More...
 
class  NDArray
 ndarray interface More...
 
struct  NDArrayFunctionReg
 Registry entry for NDArrayFunction. More...
 
struct  OpContext
 All the possible information needed by Operator.Forward and Backward. This is a superset of RunContext. We use this data structure to bookkeep everything needed by Forward and Backward. More...
 
class  Operator
 Operator interface. Operator defines the basic operation unit of an optimized computation graph in mxnet. This interface relies on pre-allocated memory in TBlob; the caller needs to set the memory region in TBlob correctly before calling Forward and Backward. More...
 
class  OperatorProperty
 OperatorProperty is an object that stores all information about an Operator. It also contains methods to generate context- (device-) specific operators. More...
 
struct  OperatorPropertyReg
 Registry entry for OperatorProperty factory functions. More...
 
class  OpStatePtr
 Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const. More...
 
class  PrimExpr
 Reference to PrimExprNode. More...
 
class  PrimExprNode
 Base node of all primitive expressions. More...
 
struct  Resource
 Resources used by mxnet operations. A resource is something other than an NDArray that still participates in the computation. More...
 
class  ResourceManager
 Global resource manager. More...
 
struct  ResourceRequest
 The resources that can be requested by Operator. More...
 
struct  RunContext
 Execution-time context: the information needed at runtime for actual execution. More...
 
class  Storage
 Storage manager across multiple devices. More...
 
class  SyncedGPUAuxStream
 Provides automatic coordination of an auxiliary stream with a primary one. This object, upon construction, prepares an aux stream for use by syncing it with enqueued primary-stream work. Object destruction will sync again so future primary-stream work will wait on enqueued aux-stream work. If MXNET_GPU_WORKER_NSTREAMS == 1, then this defaults simply: the primary stream will equal the aux stream and the syncs will be executed as nops. See ./src/operator/cudnn/cudnn_convolution-inl.h for a usage example. More...
 
class  TBlob
 Tensor blob class that can be used to hold a tensor of any dimension, on any device, and of any data type. This is a weak type that can be used to transfer data through an interface. TBlob itself doesn't involve any arithmetic operations, but it can be converted to a tensor of fixed dimension for further operations. More...
 
class  TensorInspector
 This class provides a unified interface to inspect the value of all data types including Tensor, TBlob, and NDArray. If the tensor resides on GPU, then it will be copied from GPU memory back to CPU memory to be operated on. Internally, all data types are stored as a TBlob object tb_. More...
 
class  TShape
 A Shape class that is used to represent shape of each tensor. More...
 
class  Tuple
 A dynamically sized array data structure that is optimized for storing a small number of elements of the same type. More...
 

Typedefs

typedef mshadow::cpu cpu
 mxnet cpu More...
 
typedef mshadow::gpu gpu
 mxnet gpu More...
 
typedef mshadow::index_t index_t
 index type, usually unsigned More...
 
typedef mshadow::default_real_t real_t
 data type that will be used to store ndarray More...
 
using Op = nnvm::Op
 operator structure from NNVM More...
 
using StorageTypeVector = std::vector< int >
 The result holder of storage type of each NodeEntry in the graph. More...
 
using DispatchModeVector = std::vector< DispatchMode >
 The result holder of dispatch mode of each Node in the graph. More...
 
typedef std::function< IIterator< DataBatch > *()> DataIteratorFactory
 typedef the factory function of data iterator More...
 
typedef std::function< void(NDArray **used_vars, real_t *scalars, NDArray **mutate_vars, int num_params, char **param_keys, char **param_vals)> NDArrayAPIFunction
 definition of NDArray function More...
 
using FCreateOpState = std::function< OpStatePtr(const NodeAttrs &attrs, Context ctx, const mxnet::ShapeVector &in_shape, const std::vector< int > &in_type)>
 Create a Layer-style forward/backward operator. This makes it easy to write code that contains state. OpStatePtr is a pointer type; its content is mutable even if the OpStatePtr is constant. More...
 
using THasDeterministicOutput = bool
 Whether the operator always produces the same output given the same input. This enables certain optimizations like common expression elimination. More...
 
using FExecType = std::function< ExecType(const NodeAttrs &attrs)>
 Execution mode of this operator. More...
 
using FStatefulCompute = std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)>
 Register a compute function for a stateful operator. OpStatePtr is a pointer type; its content is mutable even if the OpStatePtr is constant. More...
 
using FStatefulComputeEx = std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)>
 Register a compute function for a stateful operator using the NDArray interface. OpStatePtr is a pointer type; its content is mutable even if the OpStatePtr is constant. More...
 
using FResourceRequest = std::function< std::vector< ResourceRequest >(const NodeAttrs &n)>
 The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither. More...
 
using FResourceRequestEx = std::function< std::vector< ResourceRequest >(const NodeAttrs &n, const int dev_mask, const DispatchMode dispatch_mode)>
 The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither. If an operator registers both ResourceRequestEx and ResourceRequest, ResourceRequest is ignored. More...
 
using FNDArrayFunction = std::function< void(const nnvm::NodeAttrs &attrs, const std::vector< NDArray > &inputs, std::vector< NDArray > *outputs)>
 Register an operator called as an NDArray function. More...
 
using FCompute = std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)>
 Register a compute function for simple stateless forward only operator. More...
 
using FComputeEx = std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)>
 Register an NDArray compute function for simple stateless forward only operator. More...
 
using FInferStorageType = std::function< bool(const NodeAttrs &attrs, const int dev_mask, DispatchMode *dispatch_mode, std::vector< int > *in_attrs, std::vector< int > *out_attrs)>
 Register a storage and dispatch mode inference function based on storage types of the inputs and outputs, and the dev_mask for the operator. More...
 
using FQuantizable = std::function< QuantizeType(const NodeAttrs &attrs)>
 Determine the quantization type of the operator based on the attrs of the node. More...
 
using FQuantizedOp = std::function< nnvm::ObjectPtr(const NodeAttrs &attrs)>
 Register a quantized node creation function based on the attrs of the node. More...
 
using FNeedRequantize = std::function< bool(const NodeAttrs &attrs)>
 Register a function to determine if the output of a quantized operator needs to be requantized. This is usually used for the operators taking int8 data types while accumulating in int32, e.g. quantized_conv. More...
 
using FAvoidQuantizeInput = std::function< bool(const NodeAttrs &attrs, const size_t index, const std::string quantize_granularity)>
 Register a function to determine if the input of a quantized operator needs to be quantized. This is usually used for the quantized operators which can handle fp32 inputs directly. More...
 
using FNeedCalibrateInput = std::function< std::vector< int >(const NodeAttrs &attrs)>
 Register a function to determine if the input of a quantized operator needs to be calibrated. This is usually used for quantized operators which need calibration on their inputs. More...
 
using FNeedCalibrateOutput = std::function< std::vector< int >(const NodeAttrs &attrs)>
 Register a function to determine if the output of a quantized operator needs to be calibrated. This is usually used for quantized operators which need calibration on their outputs. More...
 
typedef std::function< OperatorProperty *()> OperatorPropertyFactory
 typedef the factory function of operator property More...
 
using MXNetDataType = runtime::MXNetDataType
 
template<typename T >
using NodePtr = runtime::ObjectPtr< T >
 
using ShapeVector = std::vector< mxnet::TShape >
 The result holder of shape of each NodeEntry in the graph. More...
 
using FInferShape = nnvm::FInferNodeEntryAttr< mxnet::TShape >
 Shape inference function. Update the shapes given the input shape information. TShape.ndim() == -1 means the shape is still unknown. More...
 

Enumerations

enum  FnProperty {
  FnProperty::kNormal, FnProperty::kCopyFromGPU, FnProperty::kCopyToGPU, FnProperty::kCPUPrioritized,
  FnProperty::kAsync, FnProperty::kDeleteVar, FnProperty::kGPUPrioritized, FnProperty::kNoSkip
}
 Function property, used to hint what action is pushed to engine. More...
 
enum  NumpyShape { Off, ThreadLocalOn, GlobalOn }
 There are three numpy shape flags, ordered by priority. GlobalOn turns on the numpy shape flag globally, including the thread-local flag; it can be seen in any thread. ThreadLocalOn only turns on the thread-local numpy shape flag; it cannot be seen in other threads. Off turns off the numpy shape flag globally. More...
 
enum  KVStoreServerProfilerCommand { KVStoreServerProfilerCommand::kSetConfig, KVStoreServerProfilerCommand::kState, KVStoreServerProfilerCommand::kPause, KVStoreServerProfilerCommand::kDump }
 Enum to denote the types of commands the KVStore sends to the server regarding the profiler: kSetConfig sets profiler configs (similar to mx.profiler.set_config()); kState changes the profiler state to stop or run; kPause pauses and resumes the profiler; kDump asks the profiler to dump its output. More...
 
enum  NDArrayStorageType { kUndefinedStorage = -1, kDefaultStorage, kRowSparseStorage, kCSRStorage }
 
enum  NDArrayFormatErr {
  kNormalErr, kCSRShapeErr, kCSRIndPtrErr, kCSRIdxErr,
  kRSPShapeErr, kRSPIdxErr
}
 
enum  NDArrayFunctionTypeMask { kNDArrayArgBeforeScalar = 1, kScalarArgBeforeNDArray = 1 << 1, kAcceptEmptyMutateTarget = 1 << 2 }
 mask information on how functions can be exposed More...
 
enum  OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo }
 operation request type to Forward and Backward More...
 
enum  ExecType { ExecType::kSync, ExecType::kAsync, ExecType::kCrossDeviceCopy, ExecType::kSubgraphExec }
 the execution type of the operator More...
 
enum  DispatchMode {
  DispatchMode::kUndefined = -1, DispatchMode::kFCompute, DispatchMode::kFComputeEx, DispatchMode::kFComputeFallback,
  DispatchMode::kVariable
}
 the dispatch mode of the operator More...
 
enum  QuantizeType { QuantizeType::kNone = 0, QuantizeType::kMust, QuantizeType::kSupport }
 the quantization type of the operator More...
 
enum  CheckerType {
  NegativeChecker, PositiveChecker, ZeroChecker, NaNChecker,
  InfChecker, PositiveInfChecker, NegativeInfChecker, FiniteChecker,
  NormalChecker, AbnormalChecker
}
 Enum for building value checkers for TensorInspector::check_value() More...
 

Functions

void on_enter_api (const char *function)
 
void on_exit_api ()
 
template<typename ValueType >
PrimExpr MakeConstScalar (MXNetDataType t, ValueType value)
 
template<typename ValueType >
PrimExpr make_const (MXNetDataType t, ValueType value)
 
size_t num_aux_data (NDArrayStorageType stype)
 
void CopyFromTo (const NDArray &from, const NDArray *to, int priority=0)
 Issue a copy operation from one NDArray to another. The two NDArrays can sit on different devices; this operation will be scheduled by the engine. More...
 
void CopyFromTo (const NDArray &from, const NDArray &to, int priority=0, bool is_opr=false)
 Issue a copy operation from one NDArray to another. The two NDArrays can sit on different devices; this operation will be scheduled by the engine. More...
 
void ElementwiseSum (const std::vector< NDArray > &source, NDArray *out, int priority=0)
 Perform elementwise sum over each data from source, store result into out. More...
 
NDArray operator+ (const NDArray &lhs, const NDArray &rhs)
 elementwise add More...
 
NDArray operator+ (const NDArray &lhs, const real_t &rhs)
 elementwise add More...
 
NDArray operator- (const NDArray &lhs, const NDArray &rhs)
 elementwise subtraction More...
 
NDArray operator- (const NDArray &lhs, const real_t &rhs)
 elementwise subtraction More...
 
NDArray operator* (const NDArray &lhs, const NDArray &rhs)
 elementwise multiplication More...
 
NDArray operator* (const NDArray &lhs, const real_t &rhs)
 elementwise multiplication More...
 
NDArray operator/ (const NDArray &lhs, const NDArray &rhs)
 elementwise division More...
 
NDArray operator/ (const NDArray &lhs, const real_t &rhs)
 elementwise division More...
 
void RandomSeed (uint32_t seed)
 Seed all random number generators in mxnet. More...
 
void RandomSeed (Context ctx, uint32_t seed)
 Seed the random number generator of the device. More...
 
void SampleUniform (real_t begin, real_t end, NDArray *out)
 Sample uniform distribution for each element of out. More...
 
void SampleGaussian (real_t mu, real_t sigma, NDArray *out)
 Sample Gaussian distribution for each element of out. More...
 
void SampleGamma (real_t alpha, real_t beta, NDArray *out)
 Sample gamma distribution for each element of out. More...
 
void SampleExponential (real_t lambda, NDArray *out)
 Sample exponential distribution for each element of out. More...
 
void SamplePoisson (real_t lambda, NDArray *out)
 Sample Poisson distribution for each element of out. More...
 
void SampleNegBinomial (int32_t k, real_t p, NDArray *out)
 Sample negative binomial distribution for each element of out. More...
 
void SampleGenNegBinomial (real_t mu, real_t alpha, NDArray *out)
 Sample generalized negative binomial distribution for each element of out. More...
 
bool ndim_is_known (const int ndim)
 
bool dim_size_is_known (const dim_t dim_size)
 
bool ndim_is_known (const TShape &x)
 
bool dim_size_is_known (const TShape &x, const int idx)
 
bool shape_is_known (const TShape &x)
 
bool shape_is_known (const std::vector< TShape > &shapes)
 
template<typename SrcIter , typename DstIter >
DstIter ShapeTypeCast (const SrcIter begin, const SrcIter end, DstIter dst_begin)
 helper function to cast type of container elements More...
 
template<typename SrcIter >
TShape ShapeTypeCast (const SrcIter begin, const SrcIter end)
 helper function to transform a container to TShape with type cast More...
 

Variables

constexpr const int kCPU = kDLCPU
 
constexpr const int kGPU = kDLGPU
 
constexpr const int kTVMNDArrayTypeCode = 19
 

Detailed Description

namespace of mxnet

Copyright (c) 2015 by Contributors

Typedef Documentation

typedef mshadow::cpu mxnet::cpu

mxnet cpu

typedef std::function<IIterator<DataBatch> *()> mxnet::DataIteratorFactory

typedef the factory function of data iterator

using mxnet::DispatchModeVector = typedef std::vector<DispatchMode>

The result holder of dispatch mode of each Node in the graph.

Note
Stored under graph.attrs["dispatch_mode"], provided by Pass "InferStorageType"
Graph g = ApplyPass(src_graph, "InferStorageType");
const DispatchModeVector& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
// get dispatch mode by node id
DispatchMode dispatch_mode = dispatch_modes[nid];
See also
FInferStorageType
using mxnet::FAvoidQuantizeInput = typedef std::function<bool (const NodeAttrs& attrs, const size_t index, const std::string quantize_granularity)>

Register a function to determine if the input of a quantized operator needs to be quantized. This is usually used for the quantized operators which can handle fp32 inputs directly.

using mxnet::FCompute = typedef std::function<void (const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req, const std::vector<TBlob>& outputs)>

Register a compute function for simple stateless forward only operator.

Note
Register under "FCompute<cpu>" and "FCompute<gpu>"
using mxnet::FComputeEx = typedef std::function<void (const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector<NDArray>& inputs, const std::vector<OpReqType>& req, const std::vector<NDArray>& outputs)>

Register an NDArray compute function for simple stateless forward only operator.

Note
Register under "FComputeEx<xpu>" and "FComputeEx<xpu>" Dispatched only when inferred dispatch_mode is FDispatchComputeEx
using mxnet::FCreateOpState = typedef std::function<OpStatePtr (const NodeAttrs& attrs, Context ctx, const mxnet::ShapeVector& in_shape, const std::vector<int>& in_type)>

Create a Layer-style forward/backward operator. This makes it easy to write code that contains state. OpStatePtr is a pointer type; its content is mutable even if the OpStatePtr is constant.

This is not the only way to register an op execution function; simpler or more specialized operator forms can also be registered.

Note
Register under "FCreateLayerOp"
using mxnet::FExecType = typedef std::function<ExecType (const NodeAttrs& attrs)>

Execution mode of this operator.

using mxnet::FInferShape = typedef nnvm::FInferNodeEntryAttr<mxnet::TShape>

Shape inference function. Update the shapes given the input shape information. TShape.ndim() == -1 means the shape is still unknown.

Note
Register under "FInferShape", by default do not update any shapes.

FInferShape is needed by shape inference
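As an illustration, a minimal identity-shape inference function matching the FInferShape signature might look like the following sketch (names are made up; real operators typically go through shape-assignment helper macros instead):

// Hypothetical shape inference: the single output has the same shape as the
// single input. Returns true only once all shapes are fully known.
bool HypotheticalIdentityShape(const nnvm::NodeAttrs& attrs,
                               mxnet::ShapeVector* in_attrs,
                               mxnet::ShapeVector* out_attrs) {
  if (!mxnet::ndim_is_known((*in_attrs)[0])) return false;  // nothing to propagate yet
  (*out_attrs)[0] = (*in_attrs)[0];
  return mxnet::shape_is_known((*out_attrs)[0]);
}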

using mxnet::FInferStorageType = typedef std::function<bool (const NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, std::vector<int>* in_attrs, std::vector<int>* out_attrs)>

Register a storage and dispatch mode inference function based on storage types of the inputs and outputs, and the dev_mask for the operator.

Note
Register under "FInferStorageType"
using mxnet::FNDArrayFunction = typedef std::function<void (const nnvm::NodeAttrs& attrs, const std::vector<NDArray>& inputs, std::vector<NDArray>* outputs)>

Register an operator called as an NDArray function.

Note
Register under "FNDArrayFunction"
using mxnet::FNeedCalibrateInput = typedef std::function<std::vector<int> (const NodeAttrs& attrs)>

Register a function to determine if the input of a quantized operator needs to be calibrated. This is usually used for quantized operators which need calibration on their inputs.

using mxnet::FNeedCalibrateOutput = typedef std::function<std::vector<int> (const NodeAttrs& attrs)>

Register a function to determine if the output of a quantized operator needs to be calibrated. This is usually used for quantized operators which need calibration on their outputs.

using mxnet::FNeedRequantize = typedef std::function<bool (const NodeAttrs& attrs)>

Register a function to determine if the output of a quantized operator needs to be requantized. This is usually used for the operators taking int8 data types while accumulating in int32, e.g. quantized_conv.

Note
Register under "FNeedRequantize" for non-quantized operators
using mxnet::FQuantizable = typedef std::function<QuantizeType (const NodeAttrs& attrs)>

Determine the quantization type of the operator based on the attrs of the node.

Note
Register under "FQuantizable" for non-quantized operators
using mxnet::FQuantizedOp = typedef std::function<nnvm::ObjectPtr (const NodeAttrs& attrs)>

Register a quantized node creation function based on the attrs of the node.

Note
Register under "FQuantizedOp" for non-quantized operators
using mxnet::FResourceRequest = typedef std::function< std::vector<ResourceRequest> (const NodeAttrs& n)>

The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither.

Note
Register under "FResourceRequest"
using mxnet::FResourceRequestEx = typedef std::function< std::vector<ResourceRequest> (const NodeAttrs& n, const int dev_mask, const DispatchMode dispatch_mode)>

The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither. If an operator registers both ResourceRequestEx and ResourceRequest, ResourceRequest is ignored.

Note
Register under "FResourceRequestEx"
using mxnet::FStatefulCompute = typedef std::function<void (const OpStatePtr& state, const OpContext& ctx, const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req, const std::vector<TBlob>& outputs)>

Register a compute function for a stateful operator. OpStatePtr is a pointer type; its content is mutable even if the OpStatePtr is constant.

Note
Register under "FStatefulCompute<cpu>" and "FStatefulCompute<gpu>"
using mxnet::FStatefulComputeEx = typedef std::function<void (const OpStatePtr& state, const OpContext& ctx, const std::vector<NDArray>& inputs, const std::vector<OpReqType>& req, const std::vector<NDArray>& outputs)>

Register a compute function for a stateful operator using the NDArray interface. OpStatePtr is a pointer type; its content is mutable even if the OpStatePtr is constant.

Note
Register under "FStatefulComputeEx<cpu>" and "FStatefulComputeEx<gpu>"

typedef mshadow::gpu mxnet::gpu

mxnet gpu

typedef mshadow::index_t mxnet::index_t

index type, usually unsigned

typedef std::function<void (NDArray **used_vars, real_t *scalars, NDArray **mutate_vars, int num_params, char **param_keys, char **param_vals)> mxnet::NDArrayAPIFunction

definition of NDArray function

template<typename T >
using mxnet::NodePtr = typedef runtime::ObjectPtr<T>
using mxnet::Op = typedef nnvm::Op

operator structure from NNVM

typedef std::function<OperatorProperty *()> mxnet::OperatorPropertyFactory

typedef the factory function of operator property

typedef mshadow::default_real_t mxnet::real_t

data type that will be used to store ndarray

using mxnet::ShapeVector = typedef std::vector<mxnet::TShape>

The result holder of shape of each NodeEntry in the graph.

Note
Stored under graph.attrs["shape"], provided by Pass "InferShape"
Graph g = ApplyPass(src_graph, "InferShape");
const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape");
// get shape by entry id
TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)];
See also
FInferShape
using mxnet::StorageTypeVector = typedef std::vector<int>

The result holder of storage type of each NodeEntry in the graph.

Note
Stored under graph.attrs["storage_type"], provided by Pass "InferStorageType"
Graph g = ApplyPass(src_graph, "InferStorageType");
const StorageTypeVector& stypes = g.GetAttr<StorageTypeVector>("storage_type");
// get storage type by entry id
int entry_type = stypes[g.indexed_graph().entry_id(my_entry)];
See also
FInferStorageType
using mxnet::THasDeterministicOutput = typedef bool

Whether the operator always produces the same output given the same input. This enables certain optimizations like common expression elimination.

Note
Register under "THasDeterministicOutput"

Enumeration Type Documentation

enum mxnet::CheckerType

Enum for building value checkers for TensorInspector::check_value()

Enumerator
NegativeChecker 
PositiveChecker 
ZeroChecker 
NaNChecker 
InfChecker 
PositiveInfChecker 
NegativeInfChecker 
FiniteChecker 
NormalChecker 
AbnormalChecker 
enum mxnet::DispatchMode
strong

the dispatch mode of the operator

Enumerator
kUndefined 
kFCompute 
kFComputeEx 
kFComputeFallback 
kVariable 
enum mxnet::ExecType
strong

the execution type of the operator

Enumerator
kSync 

Forward/Backward are synchronous calls.

kAsync 

Forward/Backward are asynchronous, will call OpContext.async_on_complete when operation finishes.

kCrossDeviceCopy 

Cross device copy operation, this is a special operator that indicates it will copy across devices. For example the input and output for this type of operator can potentially reside on different devices. In the current implementation, a copy operator is specially handled by an executor. This flag is used for special case treatment and future extension of different copy ops.

kSubgraphExec 

A subgraph execution should happen in the main thread, instead of in the execution engine.

enum mxnet::FnProperty
strong

Function property, used to hint what action is pushed to engine.

Enumerator
kNormal 

Normal operation.

kCopyFromGPU 

Copy operation from GPU to other devices.

kCopyToGPU 

Copy operation from CPU to other devices.

kCPUPrioritized 

Prioritized sync operation on CPU.

kAsync 

Asynchronous function call.

kDeleteVar 

Delete variable call.

kGPUPrioritized 

Prioritized sync operation on GPU.

kNoSkip 

Operation not to be skipped even with associated exception.

enum mxnet::KVStoreServerProfilerCommand
strong

Enum to denote the types of commands the KVStore sends to the server regarding the profiler: kSetConfig sets profiler configs (similar to mx.profiler.set_config()); kState changes the profiler state to stop or run; kPause pauses and resumes the profiler; kDump asks the profiler to dump its output.

Enumerator
kSetConfig 
kState 
kPause 
kDump 
enum mxnet::NDArrayFormatErr

Enumerator
kNormalErr 
kCSRShapeErr 
kCSRIndPtrErr 
kCSRIdxErr 
kRSPShapeErr 
kRSPIdxErr 

enum mxnet::NDArrayFunctionTypeMask

mask information on how functions can be exposed

Enumerator
kNDArrayArgBeforeScalar 

all the use_vars should go before scalar

kScalarArgBeforeNDArray 

all the scalar should go before use_vars

kAcceptEmptyMutateTarget 

whether this function allows the handles in the target to be empty NDArrays that are not yet initialized, and will initialize them when the function is invoked.

most functions should support this, except copy between different devices, which requires the NDArray to be pre-initialized with a context

enum mxnet::NDArrayStorageType

Enumerator
kUndefinedStorage 
kDefaultStorage 
kRowSparseStorage 
kCSRStorage 

enum mxnet::NumpyShape

There are three numpy shape flags, ordered by priority. GlobalOn turns on the numpy shape flag globally, including the thread-local flag; it can be seen in any thread. ThreadLocalOn only turns on the thread-local numpy shape flag; it cannot be seen in other threads. Off turns off the numpy shape flag globally.

Enumerator
Off 
ThreadLocalOn 
GlobalOn 

enum mxnet::OpReqType

operation request type to Forward and Backward

Enumerator
kNullOp 

no operation, do not write anything

kWriteTo 

write gradient to provided space

kWriteInplace 

perform an in-place write; this option only happens when the target shares memory with one of the input arguments.

kAddTo 

add to the provided space
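A sketch of how a compute kernel typically honors the request type (float32 assumed; ApplyReq is an illustrative helper, not part of the API):

// kWriteInplace can be handled like kWriteTo because the output already
// aliases the input buffer.
void ApplyReq(mxnet::OpReqType req, const float* src, float* dst, size_t n) {
  switch (req) {
    case mxnet::kNullOp:
      return;                                            // do not write anything
    case mxnet::kWriteTo:
    case mxnet::kWriteInplace:
      for (size_t i = 0; i < n; ++i) dst[i] = src[i];    // overwrite
      break;
    case mxnet::kAddTo:
      for (size_t i = 0; i < n; ++i) dst[i] += src[i];   // accumulate
      break;
  }
}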

enum mxnet::QuantizeType
strong

the quantization type of the operator

Enumerator
kNone 
kMust 
kSupport 

Function Documentation

void mxnet::CopyFromTo (const NDArray &from, const NDArray *to, int priority = 0)

Issue a copy operation from one NDArray to another. The two NDArrays can sit on different devices; this operation will be scheduled by the engine.

Parameters
from      the ndarray we want to copy data from
to        the target ndarray
priority  Priority of the action.
Note
The function name explicitly marks the order of from and to due to the different possible conventions carried by copy functions.
void mxnet::CopyFromTo (const NDArray &from, const NDArray &to, int priority = 0, bool is_opr = false)

Issue a copy operation from one NDArray to another. The two NDArrays can sit on different devices; this operation will be scheduled by the engine.

Parameters
from      the ndarray we want to copy data from
to        the target ndarray
priority  Priority of the action.
is_opr    whether it is invoked by an operator. For example, false if invoked from KVStore, true if invoked from the _copyto operator.
Note
The function name explicitly marks the order of from and to due to the different possible conventions carried by copy functions.
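Usage sketch (shapes, values, and the single-CPU setup are illustrative). The copy is only enqueued on the engine, so the destination has to be waited on before reading it:

mxnet::NDArray src(mxnet::TShape({2, 3}), mxnet::Context::CPU());
mxnet::NDArray dst(mxnet::TShape({2, 3}), mxnet::Context::CPU());
src = 1.0f;                      // fill the source with a constant
mxnet::CopyFromTo(src, &dst);    // scheduled asynchronously by the engine
dst.WaitToRead();                // block until the copy has completed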
bool mxnet::dim_size_is_known ( const dim_t  dim_size)
inline

Check if a shape's dim size is known.

bool mxnet::dim_size_is_known (const TShape &x, const int idx)
inline

Check if a shape's dim size is known.

void mxnet::ElementwiseSum (const std::vector<NDArray> &source, NDArray *out, int priority = 0)

Perform elementwise sum over each data from source, store the result into out.

Parameters
source    the ndarrays we want to sum
out       the target ndarray
priority  Priority of the action.
template<typename ValueType >
PrimExpr mxnet::make_const ( MXNetDataType  t,
ValueType  value 
)
inline
template<typename ValueType >
PrimExpr mxnet::MakeConstScalar ( MXNetDataType  t,
ValueType  value 
)
inline
bool mxnet::ndim_is_known ( const int  ndim)
inline

Check if a shape's ndim is known.

bool mxnet::ndim_is_known (const TShape &x)
inline

Check if a shape's ndim is known.

size_t mxnet::num_aux_data ( NDArrayStorageType  stype)
Returns
the number of aux data used for given storage type
void mxnet::on_enter_api ( const char *  function)
void mxnet::on_exit_api ( )
NDArray mxnet::operator* (const NDArray &lhs, const NDArray &rhs)

elementwise multiplication

Parameters
lhs  left operand
rhs  right operand
Returns
a new result ndarray
NDArray mxnet::operator* (const NDArray &lhs, const real_t &rhs)

elementwise multiplication

Parameters
lhs  left operand
rhs  right operand
Returns
a new result ndarray
NDArray mxnet::operator+ (const NDArray &lhs, const NDArray &rhs)

elementwise add

Parameters
lhs  left operand
rhs  right operand
Returns
a new result ndarray
NDArray mxnet::operator+ (const NDArray &lhs, const real_t &rhs)

elementwise add

Parameters
lhs  left operand
rhs  right operand
Returns
a new result ndarray
NDArray mxnet::operator- (const NDArray &lhs, const NDArray &rhs)

elementwise subtraction

Parameters
lhs  left operand
rhs  right operand
Returns
a new result ndarray
NDArray mxnet::operator- (const NDArray &lhs, const real_t &rhs)

elementwise subtraction

Parameters
lhs  left operand
rhs  right operand
Returns
a new result ndarray
NDArray mxnet::operator/ (const NDArray &lhs, const NDArray &rhs)

elementwise division

Parameters
lhs  left operand
rhs  right operand
Returns
a new result ndarray
NDArray mxnet::operator/ (const NDArray &lhs, const real_t &rhs)

elementwise division

Parameters
lhs  left operand
rhs  right operand
Returns
a new result ndarray
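Usage sketch for the overloaded operators (shapes and values illustrative); each operator returns a new NDArray whose computation is scheduled by the engine:

mxnet::NDArray a(mxnet::TShape({2, 2}), mxnet::Context::CPU());
mxnet::NDArray b(mxnet::TShape({2, 2}), mxnet::Context::CPU());
a = 3.0f;
b = 2.0f;
mxnet::NDArray c = (a + b) * 2.0f - b / 4.0f;  // elementwise arithmetic
c.WaitToRead();                                // wait before reading the result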
void mxnet::RandomSeed ( uint32_t  seed)

Seed all random number generators in mxnet.

Parameters
seed  the seed to set to global random number generators.
void mxnet::RandomSeed ( Context  ctx,
uint32_t  seed 
)

Seed the random number generator of the device.

Parameters
seed  the seed to set to global random number generators.
void mxnet::SampleExponential (real_t lambda, NDArray *out)

Sample exponential distribution for each element of out.

Parameters
lambda  parameter (rate) of the exponential distribution
out     output NDArray.
void mxnet::SampleGamma (real_t alpha, real_t beta, NDArray *out)

Sample gamma distribution for each element of out.

Parameters
alpha  parameter (shape) of the gamma distribution
beta   parameter (scale) of the gamma distribution
out    output NDArray.
void mxnet::SampleGaussian (real_t mu, real_t sigma, NDArray *out)

Sample Gaussian distribution for each element of out.

Parameters
mu     mean of the Gaussian distribution.
sigma  standard deviation of the Gaussian distribution.
out    output NDArray.
void mxnet::SampleGenNegBinomial (real_t mu, real_t alpha, NDArray *out)

Sample generalized negative binomial distribution for each element of out.

Parameters
mu     parameter (mean) of the distribution
alpha  parameter (over-dispersion) of the distribution
out    output NDArray.
void mxnet::SampleNegBinomial (int32_t k, real_t p, NDArray *out)

Sample negative binomial distribution for each element of out.

Parameters
k    failure limit
p    success probability
out  output NDArray.
void mxnet::SamplePoisson (real_t lambda, NDArray *out)

Sample Poisson distribution for each element of out.

Parameters
lambda  parameter (rate) of the Poisson distribution
out     output NDArray.
void mxnet::SampleUniform (real_t begin, real_t end, NDArray *out)

Sample uniform distribution for each element of out.

Parameters
begin  lower bound of the distribution.
end    upper bound of the distribution.
out    output NDArray.
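Usage sketch (sizes and seed illustrative): the sampling functions fill a pre-allocated NDArray, and seeding first makes the draws reproducible:

mxnet::RandomSeed(42);
mxnet::NDArray uniform(mxnet::TShape({1000}), mxnet::Context::CPU());
mxnet::NDArray gaussian(mxnet::TShape({1000}), mxnet::Context::CPU());
mxnet::SampleUniform(0.0f, 1.0f, &uniform);    // U(0, 1)
mxnet::SampleGaussian(0.0f, 1.0f, &gaussian);  // N(0, 1)
uniform.WaitToRead();
gaussian.WaitToRead();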
bool mxnet::shape_is_known (const TShape &x)
inline

Check if the shape is known, using the NumPy-compatible definition: zero-dim and zero-size tensors are valid; -1 means unknown.

bool mxnet::shape_is_known ( const std::vector< TShape > &  shapes)
inline
template<typename SrcIter , typename DstIter >
DstIter mxnet::ShapeTypeCast ( const SrcIter  begin,
const SrcIter  end,
DstIter  dst_begin 
)
inline

helper function to cast type of container elements

template<typename SrcIter >
TShape mxnet::ShapeTypeCast ( const SrcIter  begin,
const SrcIter  end 
)
inline

helper function to transform a container to TShape with type cast
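Usage sketch (values illustrative): convert a container of a different integer type into a TShape:

std::vector<int> dims = {4, 3, 2};
mxnet::TShape shape = mxnet::ShapeTypeCast(dims.begin(), dims.end());
// shape.ndim() == 3, shape[0] == 4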

Variable Documentation

constexpr const int mxnet::kCPU = kDLCPU
constexpr const int mxnet::kGPU = kDLGPU
constexpr const int mxnet::kTVMNDArrayTypeCode = 19