mxnet
APIs to interact with libraries. This API specifies the function prototypes that library authors use to register custom ops, partitioners, and passes. See example/extension/lib_custom_op/README.md, example/extension/lib_subgraph/README.md, and example/extension/lib_pass/README.md.
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <vector>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <string>
#include <iostream>
#include <utility>
#include <stdexcept>
#include <functional>
#include <random>
#include <sstream>
#include <stddef.h>
Classes
struct | DLContext |
A device context for tensors and operators.
struct | DLDataType |
The data type the tensor can hold.
struct | DLTensor |
Plain C tensor object; does not manage memory.
class | mxnet::ext::MXerrorMsgs |
struct | mxnet::ext::MXContext |
Context info passed from MXNet OpContext. dev_type is the string representation of the supported context, currently only "cpu" and "gpu"; dev_id is the index of the device where the tensor is located.
struct | mxnet::ext::MXSparse |
struct | mxnet::ext::MXTensor |
Tensor data structure used by custom operators.
class | mxnet::ext::PassResource |
class | mxnet::ext::OpResource |
Provides resource APIs (memory allocation mechanisms) to Forward/Backward functions.
struct | mxnet::ext::JsonVal |
Definition of JSON objects.
struct | mxnet::ext::NodeEntry |
class | mxnet::ext::Node |
class | mxnet::ext::Graph |
class | mxnet::ext::CustomOpSelector |
class | mxnet::ext::CustomStatefulOp |
An abstract class for library authors creating stateful ops. The custom library should override Forward and the destructor, and may optionally implement Backward.
class | mxnet::ext::CustomOp |
Class to hold custom operator registration.
class | mxnet::ext::CustomPass |
An abstract class for graph passes.
class | mxnet::ext::CustomPartitioner |
An abstract class for subgraph properties.
class | mxnet::ext::Registry< T > |
Singleton registry class used to register things (ops, properties).
class | mxnet::ext::CustomStatefulOpWrapper |
StatefulOp wrapper class to pass to backend OpState.
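The Registry< T > class is described as a singleton that collects registrations (ops, properties). The sketch below shows that pattern in isolation; it is not the lib_api.h implementation. `MiniRegistry`, `MiniOp`, `add`, `size`, `at` and `setNumInputs` are hypothetical names, and the sketch assumes `add` creates an entry and returns a reference so that further configuration can be chained.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <string>

// Hypothetical stand-in for a registered custom operator.
struct MiniOp {
  std::string name;
  int num_inputs = 1;
  // Chainable setter, so registration reads as one expression.
  MiniOp& setNumInputs(int n) {
    num_inputs = n;
    return *this;
  }
};

// Hypothetical singleton registry (not the real mxnet::ext::Registry).
template <class T>
class MiniRegistry {
 public:
  // One process-wide instance, created on first use.
  static MiniRegistry* get() {
    static MiniRegistry inst;
    return &inst;
  }
  // Create a new entry and return a reference for further configuration.
  T& add(const char* name) {
    entries_.emplace_back();
    entries_.back().name = name;
    return entries_.back();
  }
  std::size_t size() const { return entries_.size(); }
  T& at(std::size_t idx) { return entries_.at(idx); }

 private:
  MiniRegistry() = default;
  // std::deque keeps references to existing elements valid on push_back.
  std::deque<T> entries_;
};
```

For example, `MiniRegistry<MiniOp>::get()->add("my_gemm").setNumInputs(2);` registers one entry. A `std::deque` is used so that references handed out by earlier `add` calls stay valid as more entries are appended.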
Namespaces
mxnet
Namespace of mxnet.
mxnet::ext
Macros
#define | MX_LIBRARY_VERSION 11 |
#define | PRIVATE_SYMBOL __attribute__((visibility("hidden"))) |
When multiple custom op libraries are loaded on Linux, exporting the same symbol more than once may lead to undefined behavior, so symbol visibility is set to hidden; see https://labjack.com/news/simple-cpp-symbol-visibility-demo for details.
#define | DLPACK_EXTERN_C |
#define | DLPACK_VERSION 020 |
The current version of dlpack.
#define | DLPACK_DLL |
DLPACK_DLL prefix for Windows.
#define | MX_ERROR_MSG mxnet::ext::MXerrorMsgs::get()->add(__FILE__, __LINE__) |
#define | MX_NUM_CPU_RANDOM_STATES 1024 |
Number of random states MXNet initializes for each device, used for parallelism.
#define | MX_NUM_GPU_RANDOM_STATES 32768 |
#define | MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json" |
Attribute key used to pass the serialized subgraph through the subgraph op's attributes.
#define | MX_STR_DTYPE "__ext_dtype__" |
dtype attribute key for ops after type propagation.
#define | MX_STR_SHAPE "__ext_shape__" |
shape attribute key for ops after shape propagation.
#define | MX_STR_EXTRA_INPUTS "__ext_extra_inputs__" |
Extra input attribute key for ops.
#define | MX_STR_CONCAT_(__a, __b) __a##__b |
Macros to help with string concatenation. Annoyingly, both MX_STR_CONCAT_ and MX_STR_CONCAT are necessary to be able to use __COUNTER__ in an identifier name.
#define | MX_STR_CONCAT(__a, __b) MX_STR_CONCAT_(__a, __b) |
#define | MX_STRINGIFY(x) #x |
Convert a token to a string.
#define | MX_TOSTRING(x) MX_STRINGIFY(x) |
#define | MX_REGISTER_NAME_(Name) MXNet##_CustomOp##_##Name |
Declare a variable with a custom name.
#define | MX_REGISTER_DEF_(Name) mxnet::ext::CustomOp MX_REGISTER_NAME_(Name) |
#define | MX_REGISTER_PROP_NAME_(Name) MXNet##_CustomSubProp##_##Name |
#define | MX_REGISTER_PROP_DEF_(Name) mxnet::ext::CustomPartitioner MX_REGISTER_PROP_NAME_(Name) |
#define | MX_REGISTER_PASS_NAME_(Name) MXNet##_CustomPass##_##Name |
#define | MX_REGISTER_PASS_DEF_(Name) mxnet::ext::CustomPass MX_REGISTER_PASS_NAME_(Name) |
#define | REGISTER_OP(Name) |
Assign a variable to a value (used to register a custom operator under Name).
#define | REGISTER_PARTITIONER(Name) |
#define | REGISTER_PASS(Name) |
#define | MXLIB_OPREGSIZE_STR "_opRegSize" |
The following are the C-style APIs implemented in the external library. Each API has a #define string that is used to look up the function in the library, followed by the function declaration.
#define | MXLIB_OPREGGET_STR "_opRegGet" |
#define | MXLIB_OPCALLFREE_STR "_opCallFree" |
#define | MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs" |
#define | MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape" |
#define | MXLIB_OPCALLINFERTYPE_STR "_opCallInferType" |
#define | MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType" |
#define | MXLIB_OPCALLFCOMP_STR "_opCallFCompute" |
#define | MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs" |
#define | MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState" |
#define | MXLIB_OPCALLDESTROYOPSTATE_STR "_opCallDestroyOpState" |
#define | MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute" |
#define | MXLIB_PARTREGSIZE_STR "_partRegSize" |
#define | MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount" |
#define | MXLIB_PARTREGGET_STR "_partRegGet" |
#define | MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps" |
#define | MXLIB_PARTCALLCREATESELECTOR_STR "_partCallCreateSelector" |
#define | MXLIB_PARTCALLSELECT_STR "_partCallSelect" |
#define | MXLIB_PARTCALLSELECTINPUT_STR "_partCallSelectInput" |
#define | MXLIB_PARTCALLSELECTOUTPUT_STR "_partCallSelectOutput" |
#define | MXLIB_PARTCALLFILTER_STR "_partCallFilter" |
#define | MXLIB_PARTCALLRESET_STR "_partCallReset" |
#define | MXLIB_PARTCALLREVIEWSUBGRAPH_STR "_partCallReviewSubgraph" |
#define | MXLIB_PASSREGSIZE_STR "_passRegSize" |
#define | MXLIB_PASSREGGET_STR "_passRegGet" |
#define | MXLIB_PASSCALLGRAPHPASS_STR "_passCallGraphPass" |
#define | MXLIB_INITIALIZE_STR "initialize" |
#define | MXLIB_OPVERSION_STR "_opVersion" |
#define | MXLIB_MSGSIZE_STR "_msgSize" |
#define | MXLIB_MSGGET_STR "_msgGet" |
#define | MX_INT_RET int |
#define | MX_VOID_RET void |
Typedefs
typedef void *(* | mxnet::ext::xpu_malloc_t) (void *, int) |
Resource malloc function to allocate memory inside Forward/Backward functions.
typedef void(* | mxnet::ext::sparse_malloc_t) (void *, int, int, int, void **, int64_t **, int64_t **) |
Sparse alloc function to allocate memory inside Forward/Backward functions.
typedef void(* | mxnet::ext::nd_malloc_t) (const void *_ndarray_alloc, const int64_t *shapes, int num_shapes, const char *dev_str, int dev_id, int dtype, const char *name, int isArg, void **data) |
Resource malloc function to allocate ndarrays for graph passes.
typedef void * | mxnet::ext::mx_stream_t |
GPU stream pointer; void* when not compiled with CUDA.
typedef void * | mxnet::ext::mx_gpu_rand_t |
typedef std::mt19937 | mxnet::ext::mx_cpu_rand_t |
typedef MXReturnValue(* | mxnet::ext::fcomp_t) (const std::unordered_map< std::string, std::string > &attributes, std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &res) |
Custom operator function templates.
typedef MXReturnValue(* | mxnet::ext::parseAttrs_t) (const std::unordered_map< std::string, std::string > &attributes, int *num_inputs, int *num_outputs) |
typedef MXReturnValue(* | mxnet::ext::inferType_t) (const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_types, std::vector< int > *out_types) |
typedef MXReturnValue(* | mxnet::ext::inferSType_t) (const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_storage_types, std::vector< int > *out_storage_types) |
typedef MXReturnValue(* | mxnet::ext::inferShape_t) (const std::unordered_map< std::string, std::string > &attributes, std::vector< std::vector< unsigned int > > *in_shapes, std::vector< std::vector< unsigned int > > *out_shapes) |
typedef MXReturnValue(* | mxnet::ext::mutateInputs_t) (const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *input_indices) |
typedef MXReturnValue(* | mxnet::ext::createOpState_t) (const std::unordered_map< std::string, std::string > &attributes, const MXContext &ctx, const std::vector< std::vector< unsigned int > > &in_shapes, const std::vector< int > in_types, CustomStatefulOp **) |
typedef MXReturnValue(* | mxnet::ext::graphPass_t) (mxnet::ext::Graph *graph, const std::unordered_map< std::string, std::string > &options) |
Custom pass create function template.
typedef MXReturnValue(* | mxnet::ext::supportedOps_t) (const mxnet::ext::Graph *graph, std::vector< int > *ids, const std::unordered_map< std::string, std::string > &options) |
Custom subgraph create function template.
typedef MXReturnValue(* | mxnet::ext::createSelector_t) (const mxnet::ext::Graph *graph, CustomOpSelector **sel_inst, const std::unordered_map< std::string, std::string > &options) |
typedef MXReturnValue(* | mxnet::ext::reviewSubgraph_t) (const mxnet::ext::Graph *subgraph, int subgraph_id, bool *accept, const std::unordered_map< std::string, std::string > &options, std::unordered_map< std::string, std::string > *attrs) |
typedef int(* | mxnet::ext::opRegSize_t) (void) |
typedef int(* | mxnet::ext::opRegGet_t) (int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate) |
typedef int(* | mxnet::ext::opCallFree_t) (void *ptr) |
typedef int(* | mxnet::ext::opCallParseAttrs_t) (parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out) |
typedef int(* | mxnet::ext::opCallInferShape_t) (inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out) |
typedef int(* | mxnet::ext::opCallInferType_t) (inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out) |
typedef int(* | mxnet::ext::opCallInferSType_t) (inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out) |
typedef int(* | mxnet::ext::opCallFComp_t) (fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states) |
typedef int(* | mxnet::ext::opCallMutateInputs_t) (mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size) |
typedef int(* | mxnet::ext::opCallCreateOpState_t) (createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op) |
typedef int(* | mxnet::ext::opCallDestroyOpState_t) (void *state_op) |
typedef int(* | mxnet::ext::opCallFStatefulComp_t) (int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states) |
typedef int(* | mxnet::ext::partRegSize_t) (void) |
typedef int(* | mxnet::ext::partRegGetCount_t) (int idx, const char **name) |
typedef void(* | mxnet::ext::partRegGet_t) (int part_idx, int stg_idx, const char **strategy, supportedOps_t *supportedOps, createSelector_t *createSelector, reviewSubgraph_t *reviewSubgraph, const char **op_name) |
typedef int(* | mxnet::ext::partCallSupportedOps_t) (supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts) |
typedef int(* | mxnet::ext::partCallCreateSelector_t) (createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts) |
typedef void(* | mxnet::ext::partCallSelect_t) (void *sel_inst, int nodeID, int *selected) |
typedef void(* | mxnet::ext::partCallSelectInput_t) (void *sel_inst, int nodeID, int input_nodeID, int *selected) |
typedef void(* | mxnet::ext::partCallSelectOutput_t) (void *sel_inst, int nodeID, int output_nodeID, int *selected) |
typedef void(* | mxnet::ext::partCallFilter_t) (void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep) |
typedef void(* | mxnet::ext::partCallReset_t) (void *sel_inst) |
typedef int(* | mxnet::ext::partCallReviewSubgraph_t) (reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id) |
typedef int(* | mxnet::ext::passRegSize_t) (void) |
typedef void(* | mxnet::ext::passRegGet_t) (int pass_idx, graphPass_t *graphPass, const char **pass_name) |
typedef int(* | mxnet::ext::passCallGraphPass_t) (graphPass_t graphPass, const char *in_graph, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, nd_malloc_t nd_malloc, const void *nd_alloc) |
typedef int(* | mxnet::ext::initialize_t) (int version) |
typedef int(* | mxnet::ext::opVersion_t) () |
typedef int(* | mxnet::ext::msgSize_t) (void) |
typedef int(* | mxnet::ext::msgGet_t) (int idx, const char **msg) |
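Several of the typedefs above (parseAttrs_t, inferType_t, inferShape_t) use only standard containers, so a library author's callbacks can be written against them directly. The sketch below shows matrix-multiply style callbacks with those signatures. `MXReturnValue` is re-declared locally as a stand-in because its definition is not shown on this page, the GEMM semantics are a hypothetical example, and the sketch assumes the caller pre-sizes the output vectors to the number of outputs.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Local stand-in for the library's status type (assumption, not shown above).
enum MXReturnValue { MX_FAIL = 0, MX_SUCCESS = 1 };

using Attrs = std::unordered_map<std::string, std::string>;
using Shapes = std::vector<std::vector<unsigned int>>;

// parseAttrs_t signature: report the op's arity (hypothetical GEMM: C = A * B).
MXReturnValue gemmParseAttrs(const Attrs& attributes, int* num_inputs, int* num_outputs) {
  *num_inputs = 2;
  *num_outputs = 1;
  return MX_SUCCESS;
}

// inferType_t signature: both inputs must share a dtype, which the output inherits.
MXReturnValue gemmInferType(const Attrs& attributes, std::vector<int>* in_types,
                            std::vector<int>* out_types) {
  if (in_types->size() != 2 || out_types->empty()) return MX_FAIL;
  if ((*in_types)[0] != (*in_types)[1]) return MX_FAIL;
  (*out_types)[0] = (*in_types)[0];
  return MX_SUCCESS;
}

// inferShape_t signature: (n x k) * (k x m) -> (n x m).
MXReturnValue gemmInferShape(const Attrs& attributes, Shapes* in_shapes, Shapes* out_shapes) {
  if (in_shapes->size() != 2 || out_shapes->empty()) return MX_FAIL;
  const std::vector<unsigned int>& a = (*in_shapes)[0];
  const std::vector<unsigned int>& b = (*in_shapes)[1];
  if (a.size() != 2 || b.size() != 2 || a[1] != b[0]) return MX_FAIL;
  (*out_shapes)[0] = {a[0], b[1]};
  return MX_SUCCESS;
}

// Local function-pointer aliases mirroring the typedefs above.
using parseAttrs_t = MXReturnValue (*)(const Attrs&, int*, int*);
using inferType_t = MXReturnValue (*)(const Attrs&, std::vector<int>*, std::vector<int>*);
using inferShape_t = MXReturnValue (*)(const Attrs&, Shapes*, Shapes*);
```

Because the callbacks match the function-pointer types exactly, they can be stored as `parseAttrs_t`, `inferType_t` and `inferShape_t` values and called later by whatever code bridges them to the C entry points.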
Functions
std::string | mxnet::ext::getShapeAt (const std::string &shape, unsigned index) |
std::string | mxnet::ext::getDtypeAt (const std::string &dtype, unsigned index) |
MX_INT_RET | _opVersion () |
Returns the MXNet library version.
MX_INT_RET | _opRegSize () |
Returns the number of ops registered in this library.
MX_VOID_RET | _opRegGet (int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate) |
Returns the operator registration at the specified index.
MX_VOID_RET | _opCallFree (void *ptr) |
Calls free from the external library for library-allocated arrays.
MX_INT_RET | _opCallParseAttrs (mxnet::ext::parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out) |
Returns the status of calling the parse attributes function for an operator from the library.
MX_INT_RET | _opCallInferShape (mxnet::ext::inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out) |
Returns the status of calling the inferShape function for an operator from the library.
MX_INT_RET | _opCallInferType (mxnet::ext::inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out) |
Returns the status of calling the inferType function for an operator from the library.
MX_INT_RET | _opCallInferSType (mxnet::ext::inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *instypes, int num_in, int *outstypes, int num_out) |
Returns the status of calling the inferSType function for an operator from the library.
MX_INT_RET | _opCallFCompute (mxnet::ext::fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states) |
Returns the status of calling the Forward/Backward function for an operator from the library.
MX_INT_RET | _opCallMutateInputs (mxnet::ext::mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size) |
Returns the status of calling the mutateInputs function for an operator from the library.
MX_INT_RET | _opCallCreateOpState (mxnet::ext::createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op) |
Returns the status of calling the createStatefulOp function for an operator from the library.
MX_VOID_RET | _opCallDestroyOpState (void *state_op) |
Deletes the StatefulOp instance for an operator from the library.
MX_INT_RET | _opCallFStatefulCompute (int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states) |
Returns the status of calling stateful Forward/Backward for an operator from the library.
MX_INT_RET | _partRegSize () |
Returns the number of partitioners registered in this library.
MX_INT_RET | _partRegGetCount (int idx, const char **name) |
MX_VOID_RET | _partRegGet (int part_idx, int stg_idx, const char **strategy, mxnet::ext::supportedOps_t *supportedOps, mxnet::ext::createSelector_t *createSelector, mxnet::ext::reviewSubgraph_t *reviewSubgraph, const char **op_name) |
Returns the partitioner registration at the specified index.
MX_INT_RET | _partCallSupportedOps (mxnet::ext::supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts) |
Returns the status of calling the supported ops function from the library.
MX_INT_RET | _partCallCreateSelector (mxnet::ext::createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts) |
Returns the status of calling the create selector function from the library.
MX_VOID_RET | _partCallSelect (void *sel_inst, int nodeID, int *selected) |
Calls the select function from the library; the result is written to selected.
MX_VOID_RET | _partCallSelectInput (void *sel_inst, int nodeID, int input_nodeID, int *selected) |
Calls the select input function from the library; the result is written to selected.
MX_VOID_RET | _partCallSelectOutput (void *sel_inst, int nodeID, int output_nodeID, int *selected) |
Calls the select output function from the library; the result is written to selected.
MX_VOID_RET | _partCallFilter (void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep) |
Calls the filter function from the library; the kept candidates are returned through keep and num_keep.
MX_VOID_RET | _partCallReset (void *sel_inst) |
Calls the reset selector function from the library.
MX_INT_RET | _partCallReviewSubgraph (mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id) |
Returns the status of calling the review subgraph function from the library.
MX_INT_RET | _passRegSize () |
Returns the number of graph passes registered in this library.
MX_VOID_RET | _passRegGet (int pass_idx, mxnet::ext::graphPass_t *graphPass, const char **pass_name) |
Returns the pass registration at the specified index.
MX_INT_RET | _passCallGraphPass (mxnet::ext::graphPass_t graphPass, const char *json, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc, const void *nd_alloc) |
Returns the status of calling the graph pass function from the library.
mxnet::ext::MXReturnValue | initialize (int version) |
Checks whether the MXNet version is supported by the library and, if so, initializes the library.
MX_INT_RET | _msgSize () |
MX_VOID_RET | _msgGet (int idx, const char **msg) |
Returns, through msg, the error message at the specified index.
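initialize(int version) is described as checking whether the running MXNet version is supported before the library initializes itself. A minimal sketch of such a check follows. `MXReturnValue` is a local stand-in, and the threshold `kMinSupportedVersion = 10700` is a hypothetical choice that assumes MXNet's integer version encoding of major*10000 + minor*100 + patch (so 1.7.0 becomes 10700).

```cpp
#include <cassert>
#include <iostream>

// Local stand-in for the library's status type.
enum MXReturnValue { MX_FAIL = 0, MX_SUCCESS = 1 };

// Hypothetical minimum version this library was written against
// (assumed encoding: major*10000 + minor*100 + patch).
constexpr int kMinSupportedVersion = 10700;

MXReturnValue initialize(int version) {
  if (version >= kMinSupportedVersion) {
    std::cout << "MXNet version " << version << " supported\n";
    // One-time library setup would go here.
    return MX_SUCCESS;
  }
  std::cerr << "MXNet version " << version << " not supported\n";
  return MX_FAIL;
}
```

Returning a failure status rather than throwing keeps the check usable across the C-style library boundary.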
#define DLPACK_DLL
DLPACK_DLL prefix for Windows.
#define DLPACK_EXTERN_C
#define DLPACK_VERSION 020
The current version of dlpack.
#define MX_ERROR_MSG mxnet::ext::MXerrorMsgs::get()->add(__FILE__, __LINE__)
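MX_ERROR_MSG expands to a call that records the current file and line and returns something the caller can stream a message into, which _msgSize and _msgGet can later hand back to MXNet. The sketch below reproduces that shape with a hypothetical `ErrorMsgs` singleton, a hypothetical `MY_ERROR_MSG` macro, and hypothetical `mySize`/`myMsgGet` accessors; it assumes the recorded object behaves like an output stream and is not the real MXerrorMsgs class.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for mxnet::ext::MXerrorMsgs.
class ErrorMsgs {
 public:
  static ErrorMsgs* get() {
    static ErrorMsgs inst;
    return &inst;
  }
  // Start a new message tagged with its source location; the caller
  // streams the message text into the returned reference.
  std::ostream& add(const char* file, int line) {
    entries_.push_back(std::make_unique<Entry>());
    entries_.back()->stream << file << "[" << line << "]: ";
    return entries_.back()->stream;
  }
  int size() const { return static_cast<int>(entries_.size()); }
  // The text is cached per entry so the returned C pointer stays valid.
  const char* message(int idx) {
    Entry& e = *entries_.at(static_cast<std::size_t>(idx));
    e.text = e.stream.str();
    return e.text.c_str();
  }

 private:
  struct Entry {
    std::stringstream stream;
    std::string text;
  };
  std::vector<std::unique_ptr<Entry>> entries_;
};

// Same shape as the documented macro, pointed at the stand-in class.
#define MY_ERROR_MSG ErrorMsgs::get()->add(__FILE__, __LINE__)

// msgSize/msgGet-style accessors (hypothetical names).
int mySize() { return ErrorMsgs::get()->size(); }
void myMsgGet(int idx, const char** msg) { *msg = ErrorMsgs::get()->message(idx); }
```

Usage mirrors the documented macro: `MY_ERROR_MSG << "unsupported dtype " << 7;` records one message prefixed with its file and line.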
#define MX_INT_RET int
#define MX_LIBRARY_VERSION 11
#define MX_NUM_CPU_RANDOM_STATES 1024
Number of random states MXNet initializes for each device, used for parallelism.
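The CPU generator type is documented above as `mx_cpu_rand_t` (std::mt19937), and MX_NUM_CPU_RANDOM_STATES says how many such states exist per device. The sketch below shows one way a kernel could consume such an array, assuming element i draws from state i % MX_NUM_CPU_RANDOM_STATES. `fill_uniform` and `make_states` are hypothetical; in the real API the state pointer comes from MXNet rather than being built by the library.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

#define MX_NUM_CPU_RANDOM_STATES 1024  // value as documented above
typedef std::mt19937 mx_cpu_rand_t;    // matches the documented typedef

// Hypothetical kernel: fill `out` with uniform samples, where element i
// draws from state (i % MX_NUM_CPU_RANDOM_STATES). If parallel workers are
// partitioned by state index, no two workers share a generator.
void fill_uniform(mx_cpu_rand_t* states, std::vector<float>* out) {
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  for (std::size_t i = 0; i < out->size(); ++i) {
    mx_cpu_rand_t& gen = states[i % MX_NUM_CPU_RANDOM_STATES];
    (*out)[i] = dist(gen);
  }
}

// Hypothetical host-side setup: seed each state differently.
std::vector<mx_cpu_rand_t> make_states(unsigned base_seed) {
  std::vector<mx_cpu_rand_t> states;
  states.reserve(MX_NUM_CPU_RANDOM_STATES);
  for (unsigned i = 0; i < MX_NUM_CPU_RANDOM_STATES; ++i)
    states.emplace_back(base_seed + i);
  return states;
}
```

With identical seeds the output is reproducible, which is the property that makes a fixed pool of per-device states useful.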
#define MX_NUM_GPU_RANDOM_STATES 32768
#define MX_REGISTER_DEF_(Name) mxnet::ext::CustomOp MX_REGISTER_NAME_(Name)
#define MX_REGISTER_NAME_(Name) MXNet##_CustomOp##_##Name
Declare a variable with a custom name.
#define MX_REGISTER_PASS_DEF_(Name) mxnet::ext::CustomPass MX_REGISTER_PASS_NAME_(Name)
#define MX_REGISTER_PASS_NAME_(Name) MXNet##_CustomPass##_##Name
#define MX_REGISTER_PROP_DEF_(Name) mxnet::ext::CustomPartitioner MX_REGISTER_PROP_NAME_(Name)
#define MX_REGISTER_PROP_NAME_(Name) MXNet##_CustomSubProp##_##Name
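The MX_REGISTER_*_NAME_ macros above build a variable name by token pasting a fixed prefix onto the user-supplied Name. The sketch copies two of those definitions verbatim and uses a hypothetical `Op` struct in place of mxnet::ext::CustomOp so that it compiles on its own; `my_gemm` and `myProp` are example names.

```cpp
#include <cassert>
#include <string>

// Copied from the definitions above.
#define MX_REGISTER_NAME_(Name) MXNet##_CustomOp##_##Name
#define MX_REGISTER_PROP_NAME_(Name) MXNet##_CustomSubProp##_##Name

// Hypothetical stand-in for mxnet::ext::CustomOp.
struct Op {
  std::string name;
};

// Expands to: Op MXNet_CustomOp_my_gemm{"my_gemm"};
Op MX_REGISTER_NAME_(my_gemm){"my_gemm"};
// Expands to: Op MXNet_CustomSubProp_myProp{"myProp"};
Op MX_REGISTER_PROP_NAME_(myProp){"myProp"};
```

The distinct prefixes keep op, partitioner and pass variables from colliding even when they share the same Name.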
#define MX_STR_CONCAT(__a, __b) MX_STR_CONCAT_(__a, __b)
#define MX_STR_CONCAT_(__a, __b) __a##__b
Macros to help with string concatenation. Annoyingly, both MX_STR_CONCAT_ and MX_STR_CONCAT are necessary to be able to use __COUNTER__ in an identifier name.
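The two-level concat is needed because operands of `##` are not macro-expanded: pasting `__COUNTER__` directly would produce the literal identifier `auto_reg___COUNTER__` every time. The sketch copies both macros and defines a hypothetical `AUTO_REGISTER` macro (with a hypothetical `note_registration` helper) to show that each use now declares a distinct variable; it illustrates the technique and is not the documented REGISTER_OP expansion.

```cpp
#include <cassert>

// Copied from the definitions above.
#define MX_STR_CONCAT_(__a, __b) __a##__b
#define MX_STR_CONCAT(__a, __b) MX_STR_CONCAT_(__a, __b)

int g_registered = 0;
int note_registration(int v) {
  g_registered += v;
  return v;
}

// Hypothetical registration macro: the outer MX_STR_CONCAT lets __COUNTER__
// expand to 0, 1, 2, ... before pasting, so every use gets a unique name.
#define AUTO_REGISTER(v) \
  static int MX_STR_CONCAT(auto_reg_, __COUNTER__) = note_registration(v)

AUTO_REGISTER(10);  // e.g. declares auto_reg_0
AUTO_REGISTER(20);  // e.g. declares auto_reg_1
```

If only MX_STR_CONCAT_ were used, the second AUTO_REGISTER line would redeclare the same variable and fail to compile.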
#define MX_STR_DTYPE "__ext_dtype__"
dtype attribute key for ops after type propagation.
#define MX_STR_EXTRA_INPUTS "__ext_extra_inputs__"
Extra input attribute key for ops.
#define MX_STR_SHAPE "__ext_shape__"
shape attribute key for ops after shape propagation.
#define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json"
Attribute key used to pass the serialized subgraph through the subgraph op's attributes.
#define MX_STRINGIFY(x) #x
Convert a token to a string.
#define MX_TOSTRING(x) MX_STRINGIFY(x)
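MX_TOSTRING exists for the same reason as the two concat macros: `#` stringifies its argument as written, so a macro passed directly to MX_STRINGIFY is quoted by name instead of by value. The sketch copies both definitions; `MY_OP_NAME` is a hypothetical macro holding an op name.

```cpp
#include <cassert>
#include <cstring>

// Copied from the definitions above.
#define MX_STRINGIFY(x) #x
#define MX_TOSTRING(x) MX_STRINGIFY(x)

#define MY_OP_NAME my_gemm  // hypothetical macro holding an op name

// # applies to the argument as written, so the macro name itself is quoted.
const char* raw = MX_STRINGIFY(MY_OP_NAME);      // "MY_OP_NAME"
// The extra level expands MY_OP_NAME to my_gemm before # is applied.
const char* expanded = MX_TOSTRING(MY_OP_NAME);  // "my_gemm"
```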
#define MX_VOID_RET void
#define MXLIB_INITIALIZE_STR "initialize"
#define MXLIB_MSGGET_STR "_msgGet"
#define MXLIB_MSGSIZE_STR "_msgSize"
#define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState"
#define MXLIB_OPCALLDESTROYOPSTATE_STR "_opCallDestroyOpState"
#define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
#define MXLIB_OPCALLFREE_STR "_opCallFree"
#define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
#define MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape"
#define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType"
#define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType"
#define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
#define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs"
#define MXLIB_OPREGGET_STR "_opRegGet"
#define MXLIB_OPREGSIZE_STR "_opRegSize"
The following are the C-style APIs implemented in the external library. Each API has a #define string that is used to look up the function in the library, followed by the function declaration.
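The lookup described above pairs each name string with a function-pointer typedef: the host resolves the symbol by name, casts it to the typedef, and calls it. The sketch simulates that round trip without a real shared library. `lookup_symbol` is a hypothetical stand-in for the platform loader (dlsym on POSIX, GetProcAddress on Windows), and `lib_opRegSize` is a hypothetical library function whose body returns a made-up count.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

#define MXLIB_OPREGSIZE_STR "_opRegSize"  // as documented above
typedef int (*opRegSize_t)(void);         // matches the documented typedef

// Library side: an exported C entry point (hypothetical body).
extern "C" int lib_opRegSize(void) { return 3; }

// Stand-in for the dynamic loader's symbol table.
void* lookup_symbol(const char* name) {
  static const std::unordered_map<std::string, void*> table = {
      {MXLIB_OPREGSIZE_STR, reinterpret_cast<void*>(&lib_opRegSize)},
  };
  auto it = table.find(name);
  return it == table.end() ? nullptr : it->second;
}

// Host side: resolve by string, cast to the typedef, then call.
int query_op_count() {
  void* sym = lookup_symbol(MXLIB_OPREGSIZE_STR);
  if (sym == nullptr) return -1;  // symbol missing: library is incompatible
  opRegSize_t fn = reinterpret_cast<opRegSize_t>(sym);
  return fn();
}
```

Keeping the name string and the typedef side by side, as the header does, means a mismatch between them shows up in one place rather than at every call site.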
#define MXLIB_OPVERSION_STR "_opVersion"
#define MXLIB_PARTCALLCREATESELECTOR_STR "_partCallCreateSelector"
#define MXLIB_PARTCALLFILTER_STR "_partCallFilter"
#define MXLIB_PARTCALLRESET_STR "_partCallReset"
#define MXLIB_PARTCALLREVIEWSUBGRAPH_STR "_partCallReviewSubgraph"
#define MXLIB_PARTCALLSELECT_STR "_partCallSelect"
#define MXLIB_PARTCALLSELECTINPUT_STR "_partCallSelectInput"
#define MXLIB_PARTCALLSELECTOUTPUT_STR "_partCallSelectOutput"
#define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps"
#define MXLIB_PARTREGGET_STR "_partRegGet"
#define MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount"
#define MXLIB_PARTREGSIZE_STR "_partRegSize"
#define MXLIB_PASSCALLGRAPHPASS_STR "_passCallGraphPass"
#define MXLIB_PASSREGGET_STR "_passRegGet"
#define MXLIB_PASSREGSIZE_STR "_passRegSize"
#define PRIVATE_SYMBOL __attribute__((visibility("hidden")))
When multiple custom op libraries are loaded on Linux, exporting the same symbol more than once may lead to undefined behavior, so symbol visibility is set to hidden; see https://labjack.com/news/simple-cpp-symbol-visibility-demo for details.
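The sketch below shows the intended split: internal helpers are marked hidden so they stay out of the shared object's dynamic symbol table, while the entry points that MXNet looks up by name keep default visibility and C linkage. `scale_helper` and `lib_entry` are hypothetical names, and the attribute syntax is the GCC/Clang form used by the macro.

```cpp
#include <cassert>

// As documented above (GCC/Clang syntax).
#define PRIVATE_SYMBOL __attribute__((visibility("hidden")))

// Internal helper: hidden, so two loaded libraries that both define
// scale_helper do not clash through the dynamic symbol table.
PRIVATE_SYMBOL int scale_helper(int x) { return 2 * x; }

// Entry point meant to be found by name: default visibility, C linkage.
extern "C" __attribute__((visibility("default"))) int lib_entry(int x) {
  return scale_helper(x) + 1;
}
```

Compiling with GCC or Clang's `-fvisibility=hidden` makes hidden the default, after which only symbols explicitly marked with default visibility are exported.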
#define REGISTER_OP(Name)
Assign a variable to a value (used to register a custom operator under Name).
#define REGISTER_PARTITIONER(Name)
#define REGISTER_PASS(Name)
enum DLDataTypeCode
The type code options for DLDataType.
Enumerator | |
---|---|
kDLInt | |
kDLUInt | |
kDLFloat | |
kDLBfloat | |
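A DLDataType combines one of these type codes with a bit width and a lane count, which is enough to compute a tensor's storage size. The fields and enumerator values below are re-declared locally following the upstream dlpack headers (an assumption: this page does not list them), and `storage_bytes` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstdint>

// Local re-declarations following upstream dlpack (assumed to match).
enum DLDataTypeCode { kDLInt = 0U, kDLUInt = 1U, kDLFloat = 2U, kDLBfloat = 4U };

struct DLDataType {
  uint8_t code;    // a DLDataTypeCode value
  uint8_t bits;    // bits per lane, e.g. 32 for float32
  uint16_t lanes;  // vector lanes; 1 for scalar types
};

// Bytes needed for a dense tensor: element count times bytes per element,
// rounding each element up to whole bytes.
uint64_t storage_bytes(DLDataType t, const int64_t* shape, int ndim) {
  uint64_t numel = 1;
  for (int i = 0; i < ndim; ++i) numel *= static_cast<uint64_t>(shape[i]);
  uint64_t bytes_per_elem = (static_cast<uint64_t>(t.bits) * t.lanes + 7) / 8;
  return numel * bytes_per_elem;
}
```

For a 2 x 3 tensor this gives 24 bytes for float32 (`{kDLFloat, 32, 1}`) and 12 bytes for bfloat16 (`{kDLBfloat, 16, 1}`).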
enum DLDeviceType
The device type in DLContext.
MX_VOID_RET _msgGet(int idx, const char **msg)
Returns, through msg, the error message at the specified index.
MX_INT_RET _msgSize()
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op,
  const char *const *keys,
  const char *const *vals,
  int num,
  const char *dev_type,
  int dev_id,
  unsigned int **inshapes,
  int *indims,
  int num_in,
  const int *intypes,
  void **state_op)
Returns the status of calling the createStatefulOp function for an operator from the library.
MX_VOID_RET _opCallDestroyOpState(void *state_op)
Deletes the StatefulOp instance for an operator from the library.
MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp,
  const char *const *keys,
  const char *const *vals,
  int num,
  const int64_t **inshapes,
  int *indims,
  void **indata,
  int *intypes,
  size_t *inIDs,
  const char **indev_type,
  int *indev_id,
  int num_in,
  const int64_t **outshapes,
  int *outdims,
  void **outdata,
  int *outtypes,
  size_t *outIDs,
  const char **outdev_type,
  int *outdev_id,
  int num_out,
  mxnet::ext::xpu_malloc_t cpu_malloc,
  void *cpu_alloc,
  mxnet::ext::xpu_malloc_t gpu_malloc,
  void *gpu_alloc,
  void *cuda_stream,
  mxnet::ext::sparse_malloc_t sparse_malloc,
  void *sparse_alloc,
  int *instypes,
  int *outstypes,
  void **in_indices,
  void **out_indices,
  void **in_indptr,
  void **out_indptr,
  int64_t *in_indices_shapes,
  int64_t *out_indices_shapes,
  int64_t *in_indptr_shapes,
  int64_t *out_indptr_shapes,
  void *rng_cpu_states,
  void *rng_gpu_states)
Returns the status of calling the Forward/Backward function for an operator from the library.
MX_VOID_RET _opCallFree(void *ptr)
Calls free from the external library for library-allocated arrays.
MX_INT_RET _opCallFStatefulCompute(
    int is_forward,
    void* state_op,
    const int64_t** inshapes,
    int* indims,
    void** indata,
    int* intypes,
    size_t* inIDs,
    const char** indev_type,
    int* indev_id,
    int num_in,
    const int64_t** outshapes,
    int* outdims,
    void** outdata,
    int* outtypes,
    size_t* outIDs,
    const char** outdev_type,
    int* outdev_id,
    int num_out,
    mxnet::ext::xpu_malloc_t cpu_malloc,
    void* cpu_alloc,
    mxnet::ext::xpu_malloc_t gpu_malloc,
    void* gpu_alloc,
    void* stream,
    mxnet::ext::sparse_malloc_t sparse_malloc,
    void* sparse_alloc,
    int* instypes,
    int* outstypes,
    void** in_indices,
    void** out_indices,
    void** in_indptr,
    void** out_indptr,
    int64_t* in_indices_shapes,
    int64_t* out_indices_shapes,
    int64_t* in_indptr_shapes,
    int64_t* out_indptr_shapes,
    void* rng_cpu_states,
    void* rng_gpu_states)
returns status of calling Stateful Forward/Backward for operator from library
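_opCallFStatefulCompute and _opCallDestroyOpState both receive the operator state only as an opaque void* state_op, with is_forward selecting Forward or Backward (stateful ops must implement Forward and may optionally implement Backward). A minimal sketch of that pattern, assuming a nonzero is_forward means Forward and that 1/0 signal success/failure. StatefulOpBase, CountingOp, callStateful and destroyState are hypothetical stand-ins, not the real CustomStatefulOp interface.

```cpp
#include <cassert>

// Hypothetical stand-in for a library's stateful op base class.
class StatefulOpBase {
 public:
  virtual ~StatefulOpBase() = default;
  virtual int Forward() = 0;            // 1 = success, 0 = failure (assumed)
  virtual int Backward() { return 0; }  // optional; fails unless overridden
};

// Example op that only implements Forward and counts its calls.
class CountingOp : public StatefulOpBase {
 public:
  int forward_calls = 0;
  int Forward() override { ++forward_calls; return 1; }
};

// Shim in the style of _opCallFStatefulCompute: the host holds only an
// opaque void*, and is_forward picks which pass to run.
int callStateful(int is_forward, void* state_op) {
  StatefulOpBase* op = static_cast<StatefulOpBase*>(state_op);
  return is_forward ? op->Forward() : op->Backward();
}

// Shim in the style of _opCallDestroyOpState: deletion runs inside the
// library, through the virtual destructor.
void destroyState(void* state_op) {
  delete static_cast<StatefulOpBase*>(state_op);
}
```

Converting to the base-class pointer before it becomes a void* keeps the later static_cast back from void* well-defined, and deleting inside the library lets the same code that allocated the object also free it.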
MX_INT_RET _opCallInferShape(
    mxnet::ext::inferShape_t inferShape,
    const char* const* keys,
    const char* const* vals,
    int num,
    unsigned int** inshapes,
    int* indims,
    int num_in,
    unsigned int*** mod_inshapes,
    int** mod_indims,
    unsigned int*** outshapes,
    int** outdims,
    int num_out)
returns status of calling inferShape function for operator from library
MX_INT_RET _opCallInferSType(
    mxnet::ext::inferSType_t inferSType,
    const char* const* keys,
    const char* const* vals,
    int num,
    int* instypes,
    int num_in,
    int* outstypes,
    int num_out)
returns status of calling inferSType function for operator from library
MX_INT_RET _opCallInferType(
    mxnet::ext::inferType_t inferType,
    const char* const* keys,
    const char* const* vals,
    int num,
    int* intypes,
    int num_in,
    int* outtypes,
    int num_out)
returns status of calling inferType function for operator from library
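_opCallInferType hands the library the input type codes (intypes, num_in) and an outtypes array of length num_out to fill. The sketch below shows one common rule, assuming 1/0 as success/failure: all inputs must share a type and every output inherits it. inferSameType is hypothetical, and type codes are treated as opaque integers rather than any particular MXNet enum.

```cpp
#include <cassert>

// Hypothetical inferType-style rule: all inputs must have the same type
// code, and every output takes the type of input 0.
// Returns 1 on success, 0 on failure (assumed convention).
int inferSameType(const int* intypes, int num_in, int* outtypes, int num_out) {
  if (num_in < 1) return 0;  // nothing to infer from
  for (int i = 1; i < num_in; ++i) {
    if (intypes[i] != intypes[0]) return 0;  // mismatched input types
  }
  for (int o = 0; o < num_out; ++o) {
    outtypes[o] = intypes[0];
  }
  return 1;
}
```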
MX_INT_RET _opCallMutateInputs(
    mxnet::ext::mutateInputs_t mutate,
    const char* const* keys,
    const char* const* vals,
    int num,
    int** mutate_indices,
    int* indices_size)
returns status of calling mutateInputs function for operator from library
MX_INT_RET _opCallParseAttrs(
    mxnet::ext::parseAttrs_t parseAttrs,
    const char* const* keys,
    const char* const* vals,
    int num,
    int* num_in,
    int* num_out)
returns status of calling parse attributes function for operator from library
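Operator attributes arrive as two parallel C-string arrays, keys and vals, of length num. A sketch, assuming that layout, of rebuilding them as a map and of a parseAttrs-style callback that reports the operator's input and output counts through num_in and num_out. toAttrMap, parseMyAttrs and the "transpose" attribute are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical helper: turns parallel (keys, vals, num) C-string arrays
// into a map. If a key repeats, the later value wins.
std::unordered_map<std::string, std::string> toAttrMap(
    const char* const* keys, const char* const* vals, int num) {
  std::unordered_map<std::string, std::string> attrs;
  for (int i = 0; i < num; ++i) {
    attrs[keys[i]] = vals[i];
  }
  return attrs;
}

// Hypothetical parseAttrs-style callback: requires a "transpose" attribute
// and declares two inputs and one output. Returns 1 on success, 0 on failure.
int parseMyAttrs(const char* const* keys, const char* const* vals, int num,
                 int* num_in, int* num_out) {
  auto attrs = toAttrMap(keys, vals, num);
  if (attrs.count("transpose") == 0) return 0;  // missing required attribute
  *num_in = 2;
  *num_out = 1;
  return 1;
}
```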
MX_VOID_RET _opRegGet(
    int idx,
    const char** name,
    int* isSGop,
    const char*** forward_ctx,
    mxnet::ext::fcomp_t** forward_fp,
    int* forward_count,
    const char*** backward_ctx,
    mxnet::ext::fcomp_t** backward_fp,
    int* backward_count,
    const char*** create_op_ctx,
    mxnet::ext::createOpState_t** create_op_fp,
    int* create_op_count,
    mxnet::ext::parseAttrs_t* parse,
    mxnet::ext::inferType_t* type,
    mxnet::ext::inferSType_t* stype,
    mxnet::ext::inferShape_t* shape,
    mxnet::ext::mutateInputs_t* mutate)
returns operator registration at specified index
MX_INT_RET _opRegSize()
returns number of ops registered in this library
MX_INT_RET _opVersion()
returns MXNet library version
MX_INT_RET _partCallCreateSelector(
    mxnet::ext::createSelector_t createSelector,
    const char* json,
    void** selector,
    const char* const* opt_keys,
    const char* const* opt_vals,
    int num_opts)
returns status of calling create selector function from library
MX_VOID_RET _partCallFilter(
    void* sel_inst,
    int* candidates,
    int num_candidates,
    int** keep,
    int* num_keep)
calls filter function from library
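The int** keep / int* num_keep pair suggests that the library allocates the array of kept node IDs and passes it back to MXNet, which is where _opCallFree fits: memory allocated inside the library can be released by the library's own free, which matters if the host and library do not share an allocator. A minimal sketch under those assumptions; callFilter, filterRule (keep even IDs) and libraryFree are hypothetical.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// Hypothetical filter rule: keep only even node IDs from the candidates.
static std::vector<int> filterRule(const int* candidates, int num) {
  std::vector<int> keep;
  for (int i = 0; i < num; ++i) {
    if (candidates[i] % 2 == 0) keep.push_back(candidates[i]);
  }
  return keep;
}

// Shim in the style of _partCallFilter: the result crosses the library
// boundary as a malloc'd array plus its length.
void callFilter(const int* candidates, int num_candidates,
                int** keep, int* num_keep) {
  std::vector<int> k = filterRule(candidates, num_candidates);
  *num_keep = static_cast<int>(k.size());
  *keep = static_cast<int*>(std::malloc(k.size() * sizeof(int)));
  for (std::size_t i = 0; i < k.size(); ++i) (*keep)[i] = k[i];
}

// Counterpart in the style of _opCallFree: the library releases memory
// with the same allocator that produced it.
void libraryFree(void* ptr) { std::free(ptr); }
```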
MX_VOID_RET _partCallReset(void* sel_inst)
calls reset selector function from library
MX_INT_RET _partCallReviewSubgraph(
    mxnet::ext::reviewSubgraph_t reviewSubgraph,
    const char* json,
    int subgraph_id,
    int* accept,
    const char* const* opt_keys,
    const char* const* opt_vals,
    int num_opts,
    char*** attr_keys,
    char*** attr_vals,
    int* num_attrs,
    const char* const* arg_names,
    int num_args,
    void* const* arg_data,
    const int64_t* const* arg_shapes,
    const int* arg_dims,
    const int* arg_types,
    const size_t* arg_IDs,
    const char* const* arg_dev_type,
    const int* arg_dev_id,
    const char* const* aux_names,
    int num_aux,
    void* const* aux_data,
    const int64_t* const* aux_shapes,
    const int* aux_dims,
    const int* aux_types,
    const size_t* aux_IDs,
    const char* const* aux_dev_type,
    const int* aux_dev_id)
returns status of calling review subgraph function from library
MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected)
calls select function from library
MX_VOID_RET _partCallSelectInput(
    void* sel_inst,
    int nodeID,
    int input_nodeID,
    int* selected)
calls select input function from library
MX_VOID_RET _partCallSelectOutput(
    void* sel_inst,
    int nodeID,
    int output_nodeID,
    int* selected)
calls select output function from library
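Going by their names and parameters, _partCallSelect decides whether nodeID may start a subgraph, while _partCallSelectInput and _partCallSelectOutput decide whether to grow the subgraph from nodeID to a neighboring input or output node, each writing its answer to *selected. A sketch assuming a simple rule (a node qualifies if its operator name is in a supported set); OpNameSelector and the shim names are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical selector: node i qualifies when its op name is supported.
struct OpNameSelector {
  std::vector<std::string> node_ops;          // op name per node ID
  std::unordered_set<std::string> supported;  // ops the backend handles
  bool eligible(int id) const { return supported.count(node_ops[id]) > 0; }
};

// Shims in the style of _partCallSelect / _partCallSelectInput /
// _partCallSelectOutput: the selector arrives as an opaque void*, and the
// decision is written to *selected as 0 or 1.
void selectNode(void* sel_inst, int nodeID, int* selected) {
  *selected = static_cast<OpNameSelector*>(sel_inst)->eligible(nodeID);
}
void selectInput(void* sel_inst, int nodeID, int input_nodeID, int* selected) {
  (void)nodeID;  // this rule only inspects the node being added
  *selected = static_cast<OpNameSelector*>(sel_inst)->eligible(input_nodeID);
}
void selectOutput(void* sel_inst, int nodeID, int output_nodeID, int* selected) {
  (void)nodeID;  // this rule only inspects the node being added
  *selected = static_cast<OpNameSelector*>(sel_inst)->eligible(output_nodeID);
}
```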
MX_INT_RET _partCallSupportedOps(
    mxnet::ext::supportedOps_t supportedOps,
    const char* json,
    int num_ids,
    int* ids,
    const char* const* opt_keys,
    const char* const* opt_vals,
    int num_opts)
returns status of calling supported ops function from library
MX_VOID_RET _partRegGet(
    int part_idx,
    int stg_idx,
    const char** strategy,
    mxnet::ext::supportedOps_t* supportedOps,
    mxnet::ext::createSelector_t* createSelector,
    mxnet::ext::reviewSubgraph_t* reviewSubgraph,
    const char** op_name)
returns partitioner registration at specified index
MX_INT_RET _partRegGetCount(int idx, const char** name)
returns number of strategies registered for partitioner at specified index
MX_INT_RET _partRegSize()
returns number of partitioners registered in this library
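Partitioner registration appears to be two-level: _partRegSize counts partitioners, _partRegGetCount(idx, name) presumably reports the name and number of strategies of partitioner idx, and _partRegGet(part_idx, stg_idx, ...) fetches one strategy. The sketch below models that enumeration with a hypothetical in-library registry; only the strategy name is returned here, whereas _partRegGet also hands back callback pointers.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical in-library registry: each partitioner owns named strategies.
struct Partitioner {
  std::string name;
  std::vector<std::string> strategies;
};
static std::vector<Partitioner> registry = {
    {"myProp", {"strategy1", "strategy2"}},
    {"otherProp", {"default"}},
};

// Functions in the style of _partRegSize, _partRegGetCount and _partRegGet.
int partRegSize() { return static_cast<int>(registry.size()); }
int partRegGetCount(int idx, const char** name) {
  *name = registry[idx].name.c_str();
  return static_cast<int>(registry[idx].strategies.size());
}
void partRegGet(int part_idx, int stg_idx, const char** strategy) {
  *strategy = registry[part_idx].strategies[stg_idx].c_str();
}

// Host-side enumeration: outer loop over partitioners, inner loop over each
// partitioner's strategies, producing "partitioner/strategy" labels.
std::vector<std::string> listStrategies() {
  std::vector<std::string> out;
  for (int p = 0; p < partRegSize(); ++p) {
    const char* pname = nullptr;
    int count = partRegGetCount(p, &pname);
    for (int s = 0; s < count; ++s) {
      const char* sname = nullptr;
      partRegGet(p, s, &sname);
      out.push_back(std::string(pname) + "/" + sname);
    }
  }
  return out;
}
```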
MX_INT_RET _passCallGraphPass(
    mxnet::ext::graphPass_t graphPass,
    const char* json,
    char** out_graph,
    const char* const* opt_keys,
    const char* const* opt_vals,
    int num_opts,
    const char* pass_name,
    const char* const* arg_names,
    int num_args,
    void* const* arg_data,
    const int64_t* const* arg_shapes,
    const int* arg_dims,
    const int* arg_types,
    const size_t* arg_IDs,
    const char* const* arg_dev_type,
    const int* arg_dev_id,
    const char* const* aux_names,
    int num_aux,
    void* const* aux_data,
    const int64_t* const* aux_shapes,
    const int* aux_dims,
    const int* aux_types,
    const size_t* aux_IDs,
    const char* const* aux_dev_type,
    const int* aux_dev_id,
    mxnet::ext::nd_malloc_t nd_malloc,
    const void* nd_alloc)
returns status of calling graph pass function from library
MX_VOID_RET _passRegGet(
    int pass_idx,
    mxnet::ext::graphPass_t* graphPass,
    const char** pass_name)
returns pass registration at specified index
MX_INT_RET _passRegSize()
returns number of graph passes registered in this library
mxnet::ext::MXReturnValue initialize(int version)
Checks if the MXNet version is supported by the library. If supported, initializes the library.
Parameter version: MXNet version number passed to the library, defined as MXNET_VERSION = MXNET_MAJOR*10000 + MXNET_MINOR*100 + MXNET_PATCH
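A worked example of the version encoding above, with a hypothetical initialize-style check. The minimum version 1.7.0 and the int return values (standing in for mxnet::ext::MXReturnValue) are illustrative assumptions, not MXNet requirements.

```cpp
#include <cassert>

// Encoding described above: MAJOR*10000 + MINOR*100 + PATCH.
constexpr int encodeVersion(int major, int minor, int patch) {
  return major * 10000 + minor * 100 + patch;
}

// Hypothetical initialize-style check: accept any MXNet version at or above
// a minimum chosen purely for illustration. Returns 1 to accept, 0 to reject.
int initializeSketch(int version) {
  const int kMinSupported = encodeVersion(1, 7, 0);  // 10700
  return version >= kMinSupported ? 1 : 0;
}
```

Because minor and patch each occupy two decimal digits, comparing encoded integers orders versions correctly as long as both stay below 100; for example, 1.9.1 encodes to 10901, which compares greater than 1.7.0 (10700).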