Go to the documentation of this file.
31 #ifndef MXNET_LIB_API_H_
32 #define MXNET_LIB_API_H_
39 #include <unordered_set>
40 #include <unordered_map>
50 #include <cuda_runtime.h>
51 #include <curand_kernel.h>
55 #define MX_LIBRARY_VERSION 11
62 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
63 #define PRIVATE_SYMBOL
65 #define PRIVATE_SYMBOL __attribute__((visibility("hidden")))
71 #ifndef DLPACK_VERSION
73 #define DLPACK_EXTERN_C extern "C"
75 #define DLPACK_EXTERN_C
79 #define DLPACK_VERSION 020
84 #define DLPACK_DLL __declspec(dllexport)
86 #define DLPACK_DLL __declspec(dllimport)
209 uint64_t byte_offset;
226 std::stringstream&
add(
const char* file,
int line);
232 const std::string*
get(
int idx);
240 std::vector<std::stringstream> messages;
244 #define MX_ERROR_MSG mxnet::ext::MXerrorMsgs::get()->add(__FILE__, __LINE__)
279 explicit MXContext(std::string dev_type_,
int dev_id_);
280 explicit MXContext(
const char* dev_type_,
int dev_id_);
313 void set(
void* data_ptr,
318 void* idx_ptr =
nullptr,
319 int64_t num_idx_ptr = 0);
329 std::vector<int64_t>
shape,
348 template <
typename data_type>
350 return reinterpret_cast<data_type*
>(
data_ptr);
354 int64_t
size()
const;
384 typedef void* (*xpu_malloc_t)(
void*, int);
389 const int64_t* shapes,
398 #if defined(__NVCC__)
409 #define MX_NUM_CPU_RANDOM_STATES 1024
410 #define MX_NUM_GPU_RANDOM_STATES 32768
415 PassResource(std::unordered_map<std::string, MXTensor>* new_args,
416 std::unordered_map<std::string, MXTensor>* new_aux,
418 const void* nd_alloc);
422 const std::vector<int64_t>& shapes,
428 const std::vector<int64_t>& shapes,
433 std::unordered_map<std::string, MXTensor>* new_args_;
434 std::unordered_map<std::string, MXTensor>* new_aux_;
436 const void* nd_alloc_;
450 void* sparse_alloc_fp,
451 void* rng_cpu_states,
452 void* rng_gpu_states);
483 void *cpu_alloc, *gpu_alloc;
491 void *rand_cpu_states, *rand_gpu_states;
495 #define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json"
497 #define MX_STR_DTYPE "__ext_dtype__"
499 #define MX_STR_SHAPE "__ext_shape__"
501 #define MX_STR_EXTRA_INPUTS "__ext_extra_inputs__"
510 std::string
getShapeAt(
const std::string& shape,
unsigned index);
519 std::string
getDtypeAt(
const std::string& dtype,
unsigned index);
533 explicit JsonVal(std::string s);
541 std::string
dump()
const;
559 static JsonVal parse(
const std::string& json,
unsigned int* idx);
568 std::map<JsonVal, JsonVal>
map;
603 std::unordered_map<std::string, std::string>
attrs;
631 std::unordered_set<Node*>* to_visit,
632 std::function<
void(
Node*)> handler)
const;
635 void DFS(std::function<
void(
Node*)> handler)
const;
641 void print(
int indent = 0)
const;
644 Node*
addNode(
const std::string& name,
const std::string& op);
662 void _setParams(std::unordered_map<std::string, mxnet::ext::MXTensor>* args,
663 std::unordered_map<std::string, mxnet::ext::MXTensor>* aux);
667 std::map<std::string, JsonVal>
attrs;
670 std::vector<Node*> nodes;
682 virtual bool Select(
int nodeID) = 0;
688 virtual bool SelectInput(
int nodeID,
int input_nodeID) = 0;
694 virtual bool SelectOutput(
int nodeID,
int output_nodeID) = 0;
700 virtual void Filter(
const std::vector<int>& candidates, std::vector<int>* keep) {
701 keep->insert(keep->end(), candidates.begin(), candidates.end());
719 template <
class A,
typename... Ts>
731 std::vector<MXTensor>* outputs,
734 std::vector<MXTensor>* outputs,
736 MX_ERROR_MSG <<
"Error! Operator does not support backward" << std::endl;
748 std::vector<MXTensor>* inputs,
749 std::vector<MXTensor>* outputs,
752 const std::unordered_map<std::string, std::string>& attributes,
756 std::vector<int>* in_types,
757 std::vector<int>* out_types);
759 const std::unordered_map<std::string, std::string>& attributes,
760 std::vector<int>* in_storage_types,
761 std::vector<int>* out_storage_types);
763 const std::unordered_map<std::string, std::string>& attributes,
764 std::vector<std::vector<unsigned int> >* in_shapes,
765 std::vector<std::vector<unsigned int> >* out_shapes);
767 const std::unordered_map<std::string, std::string>& attributes,
768 std::vector<int>* input_indices);
770 const std::unordered_map<std::string, std::string>& attributes,
772 const std::vector<std::vector<unsigned int> >& in_shapes,
773 const std::vector<int> in_types,
781 explicit CustomOp(
const char* op_name);
820 void raiseDuplicateContextError();
823 std::unordered_map<const char*, fcomp_t> forward_ctx_map, backward_ctx_map;
824 std::unordered_map<const char*, createOpState_t> create_op_ctx_map;
829 const std::unordered_map<std::string, std::string>& options);
850 std::vector<int>* ids,
851 const std::unordered_map<std::string, std::string>& options);
855 const std::unordered_map<std::string, std::string>& options);
860 const std::unordered_map<std::string, std::string>& options,
861 std::unordered_map<std::string, std::string>* attrs);
916 T&
add(
const char* name) {
917 T* entry =
new T(name);
918 entries.push_back(entry);
922 return entries.size();
925 return *(entries.at(idx));
934 std::vector<T*> entries;
942 #define MX_STR_CONCAT_(__a, __b) __a##__b
943 #define MX_STR_CONCAT(__a, __b) MX_STR_CONCAT_(__a, __b)
946 #define MX_STRINGIFY(x) #x
947 #define MX_TOSTRING(x) MX_STRINGIFY(x)
950 #define MX_REGISTER_NAME_(Name) MXNet##_CustomOp##_##Name
951 #define MX_REGISTER_DEF_(Name) mxnet::ext::CustomOp MX_REGISTER_NAME_(Name)
953 #define MX_REGISTER_PROP_NAME_(Name) MXNet##_CustomSubProp##_##Name
954 #define MX_REGISTER_PROP_DEF_(Name) mxnet::ext::CustomPartitioner MX_REGISTER_PROP_NAME_(Name)
956 #define MX_REGISTER_PASS_NAME_(Name) MXNet##_CustomPass##_##Name
957 #define MX_REGISTER_PASS_DEF_(Name) mxnet::ext::CustomPass MX_REGISTER_PASS_NAME_(Name)
960 #define REGISTER_OP(Name) \
961 MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \
962 mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->add(MX_TOSTRING(Name))
964 #define REGISTER_PARTITIONER(Name) \
965 MX_STR_CONCAT(MX_REGISTER_PROP_DEF_(Name), __COUNTER__) = \
966 mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->add(MX_TOSTRING(Name))
968 #define REGISTER_PASS(Name) \
969 MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \
970 mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->add(MX_TOSTRING(Name))
979 #define MXLIB_OPREGSIZE_STR "_opRegSize"
982 #define MXLIB_OPREGGET_STR "_opRegGet"
986 const char*** forward_ctx,
989 const char*** backward_ctx,
992 const char*** create_op_ctx,
994 int* create_op_count,
1001 #define MXLIB_OPCALLFREE_STR "_opCallFree"
1004 #define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs"
1006 const char*
const* keys,
1007 const char*
const* vals,
1012 #define MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape"
1014 const char*
const* keys,
1015 const char*
const* vals,
1017 unsigned int** inshapes,
1020 unsigned int*** mod_inshapes,
1022 unsigned int*** outshapes,
1026 #define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType"
1028 const char*
const* keys,
1029 const char*
const* vals,
1036 #define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType"
1038 const char*
const* keys,
1039 const char*
const* vals,
1046 #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
1048 const char*
const* keys,
1049 const char*
const* vals,
1051 const int64_t** inshapes,
1056 const char** indev_type,
1059 const int64_t** outshapes,
1064 const char** outdev_type,
1080 int64_t* in_indices_shapes,
1081 int64_t* out_indices_shapes,
1082 int64_t* in_indptr_shapes,
1083 int64_t* out_indptr_shapes,
1084 void* rng_cpu_states,
1085 void* rng_gpu_states);
1087 #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
1089 const char*
const* keys,
1090 const char*
const* vals,
1092 int** mutate_indices,
1095 #define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState"
1097 const char*
const* keys,
1098 const char*
const* vals,
1100 const char* dev_type,
1102 unsigned int** inshapes,
1108 #define MXLIB_OPCALLDESTROYOPSTATE_STR "_opCallDestroyOpState"
1111 #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
1114 const int64_t** inshapes,
1119 const char** indev_type,
1122 const int64_t** outshapes,
1127 const char** outdev_type,
1143 int64_t* in_indices_shapes,
1144 int64_t* out_indices_shapes,
1145 int64_t* in_indptr_shapes,
1146 int64_t* out_indptr_shapes,
1147 void* rng_cpu_states,
1148 void* rng_gpu_states);
1150 #define MXLIB_PARTREGSIZE_STR "_partRegSize"
1153 #define MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount"
1156 #define MXLIB_PARTREGGET_STR "_partRegGet"
1159 const char** strategy,
1163 const char** op_name);
1165 #define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps"
1170 const char*
const* opt_keys,
1171 const char*
const* opt_vals,
1174 #define MXLIB_PARTCALLCREATESELECTOR_STR "_partCallCreateSelector"
1178 const char*
const* opt_keys,
1179 const char*
const* opt_vals,
1182 #define MXLIB_PARTCALLSELECT_STR "_partCallSelect"
1185 #define MXLIB_PARTCALLSELECTINPUT_STR "_partCallSelectInput"
1188 #define MXLIB_PARTCALLSELECTOUTPUT_STR "_partCallSelectOutput"
1194 #define MXLIB_PARTCALLFILTER_STR "_partCallFilter"
1201 #define MXLIB_PARTCALLRESET_STR "_partCallReset"
1204 #define MXLIB_PARTCALLREVIEWSUBGRAPH_STR "_partCallReviewSubgraph"
1209 const char*
const* opt_keys,
1210 const char*
const* opt_vals,
1215 const char*
const* arg_names,
1217 void*
const* arg_data,
1218 const int64_t*
const* arg_shapes,
1219 const int* arg_dims,
1220 const int* arg_types,
1221 const size_t* arg_IDs,
1222 const char*
const* arg_dev_type,
1223 const int* arg_dev_id,
1224 const char*
const* aux_names,
1226 void*
const* aux_data,
1227 const int64_t*
const* aux_shapes,
1228 const int* aux_dims,
1229 const int* aux_types,
1230 const size_t* aux_IDs,
1231 const char*
const* aux_dev_type,
1232 const int* aux_dev_id);
1234 #define MXLIB_PASSREGSIZE_STR "_passRegSize"
1237 #define MXLIB_PASSREGGET_STR "_passRegGet"
1240 #define MXLIB_PASSCALLGRAPHPASS_STR "_passCallGraphPass"
1242 const char* in_graph,
1244 const char*
const* opt_keys,
1245 const char*
const* opt_vals,
1247 const char* pass_name,
1248 const char*
const* arg_names,
1250 void*
const* arg_data,
1251 const int64_t*
const* arg_shapes,
1252 const int* arg_dims,
1253 const int* arg_types,
1254 const size_t* arg_IDs,
1255 const char*
const* arg_dev_type,
1256 const int* arg_dev_id,
1257 const char*
const* aux_names,
1259 void*
const* aux_data,
1260 const int64_t*
const* aux_shapes,
1261 const int* aux_dims,
1262 const int* aux_types,
1263 const size_t* aux_IDs,
1264 const char*
const* aux_dev_type,
1265 const int* aux_dev_id,
1267 const void* nd_alloc);
1269 #define MXLIB_INITIALIZE_STR "initialize"
1272 #define MXLIB_OPVERSION_STR "_opVersion"
1275 #define MXLIB_MSGSIZE_STR "_msgSize"
1278 #define MXLIB_MSGGET_STR "_msgGet"
1286 : instance(inst), destroy_(destroy) {}
1296 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
1297 #define MX_INT_RET __declspec(dllexport) int __cdecl
1298 #define MX_VOID_RET __declspec(dllexport) void __cdecl
1300 #define MX_INT_RET int
1301 #define MX_VOID_RET void
1318 const char*** forward_ctx,
1321 const char*** backward_ctx,
1323 int* backward_count,
1324 const char*** create_op_ctx,
1326 int* create_op_count,
1338 const char*
const* keys,
1339 const char*
const* vals,
1346 const char*
const* keys,
1347 const char*
const* vals,
1349 unsigned int** inshapes,
1352 unsigned int*** mod_inshapes,
1354 unsigned int*** outshapes,
1360 const char*
const* keys,
1361 const char*
const* vals,
1370 const char*
const* keys,
1371 const char*
const* vals,
1380 const char*
const* keys,
1381 const char*
const* vals,
1383 const int64_t** inshapes,
1388 const char** indev_type,
1391 const int64_t** outshapes,
1396 const char** outdev_type,
1412 int64_t* in_indices_shapes,
1413 int64_t* out_indices_shapes,
1414 int64_t* in_indptr_shapes,
1415 int64_t* out_indptr_shapes,
1416 void* rng_cpu_states,
1417 void* rng_gpu_states);
1421 const char*
const* keys,
1422 const char*
const* vals,
1424 int** mutate_indices,
1429 const char*
const* keys,
1430 const char*
const* vals,
1432 const char* dev_type,
1434 unsigned int** inshapes,
1446 const int64_t** inshapes,
1451 const char** indev_type,
1454 const int64_t** outshapes,
1459 const char** outdev_type,
1475 int64_t* in_indices_shapes,
1476 int64_t* out_indices_shapes,
1477 int64_t* in_indptr_shapes,
1478 int64_t* out_indptr_shapes,
1479 void* rng_cpu_states,
1480 void* rng_gpu_states);
1492 const char** strategy,
1496 const char** op_name);
1503 const char*
const* opt_keys,
1504 const char*
const* opt_vals,
1511 const char*
const* opt_keys,
1512 const char*
const* opt_vals,
1539 const char*
const* opt_keys,
1540 const char*
const* opt_vals,
1545 const char*
const* arg_names,
1547 void*
const* arg_data,
1548 const int64_t*
const* arg_shapes,
1549 const int* arg_dims,
1550 const int* arg_types,
1551 const size_t* arg_IDs,
1552 const char*
const* arg_dev_type,
1553 const int* arg_dev_id,
1554 const char*
const* aux_names,
1556 void*
const* aux_data,
1557 const int64_t*
const* aux_shapes,
1558 const int* aux_dims,
1559 const int* aux_types,
1560 const size_t* aux_IDs,
1561 const char*
const* aux_dev_type,
1562 const int* aux_dev_id);
1574 const char*
const* opt_keys,
1575 const char*
const* opt_vals,
1577 const char* pass_name,
1578 const char*
const* arg_names,
1580 void*
const* arg_data,
1581 const int64_t*
const* arg_shapes,
1582 const int* arg_dims,
1583 const int* arg_types,
1584 const size_t* arg_IDs,
1585 const char*
const* arg_dev_type,
1586 const int* arg_dev_id,
1587 const char*
const* aux_names,
1589 void*
const* aux_data,
1590 const int64_t*
const* aux_shapes,
1591 const int* aux_dims,
1592 const int* aux_types,
1593 const size_t* aux_IDs,
1594 const char*
const* aux_dev_type,
1595 const int* aux_dev_id,
1597 const void* nd_alloc);
1606 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
1619 #endif // MXNET_LIB_API_H_
inferSType_t infer_storage_type
Definition: lib_api.h:809
namespace of mxnet
Definition: api_registry.h:33
void(* partCallSelect_t)(void *sel_inst, int nodeID, int *selected)
Definition: lib_api.h:1183
MX_VOID_RET _partCallSelect(void *sel_inst, int nodeID, int *selected)
returns status of calling select function from library
bool isSGop
Definition: lib_api.h:812
void * mx_gpu_rand_t
Definition: lib_api.h:403
JsonType
Json utility to parse serialized subgraph symbol.
Definition: lib_api.h:525
MXDType dtype
Definition: lib_api.h:367
static JsonVal parse(const std::string &json)
int64_t * indices
Definition: lib_api.h:305
MXReturnValue(* fcomp_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &res)
Custom Operator function templates.
Definition: lib_api.h:747
void * data_ptr
Definition: lib_api.h:361
The data type the tensor can hold.
Definition: dlpack.h:94
CustomOp & setBackward(fcomp_t fgrad, const char *ctx)
int(* passRegSize_t)(void)
Definition: lib_api.h:1235
int(* opCallInferSType_t)(inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:1037
@ kDLOpenCL
OpenCL devices.
Definition: lib_api.h:112
std::string op
Definition: lib_api.h:597
MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling create selector function from library
Definition: lib_api.h:578
void * alloc_gpu(int size) const
allocate gpu memory controlled by MXNet
virtual MXReturnValue Backward(std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &op_res)
Definition: lib_api.h:733
Definition: lib_api.h:610
An abstract class for subgraph property.
Definition: lib_api.h:866
std::vector< int64_t > shape
Definition: lib_api.h:364
const JsonVal & getAttr(const std::string &key) const
std::string getDtypeAt(const std::string &dtype, unsigned index)
CustomPass & setBody(graphPass_t fn)
CustomOp & setIsSubgraphOp()
int(* opCallFree_t)(void *ptr)
Definition: lib_api.h:1002
MXReturnValue(* reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id, bool *accept, const std::unordered_map< std::string, std::string > &options, std::unordered_map< std::string, std::string > *attrs)
Definition: lib_api.h:856
Tensor data structure used by custom operator.
Definition: lib_api.h:325
mutateInputs_t mutate_inputs
Definition: lib_api.h:811
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
returns status of calling createStatefulOp function for operator from library
void * data
Definition: lib_api.h:298
std::mt19937 mx_cpu_rand_t
Definition: lib_api.h:405
MX_INT_RET _partRegGetCount(int idx, const char **name)
static Graph * fromString(const std::string &json)
An abstract class for graph passes.
Definition: lib_api.h:834
@ kCSRStorage
Definition: lib_api.h:269
MX_VOID_RET _opRegGet(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
returns operator registration at specified index
std::vector< NodeEntry > inputs
Definition: lib_api.h:600
int(* partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
Definition: lib_api.h:1205
MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Forward/Backward function for operator from library
MXReturnValue(* inferType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_types, std::vector< int > *out_types)
Definition: lib_api.h:755
virtual bool SelectOutput(int nodeID, int output_nodeID)=0
Definition: lib_api.h:677
int64_t indptr_len
Definition: lib_api.h:311
@ kDLROCM
ROCm GPUs for AMD GPUs.
Definition: lib_api.h:120
MXReturnValue(* supportedOps_t)(const mxnet::ext::Graph *graph, std::vector< int > *ids, const std::unordered_map< std::string, std::string > &options)
Custom Subgraph Create function template.
Definition: lib_api.h:848
void _setPassResource(PassResource *res_)
int(* opCallFStatefulComp_t)(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:1112
@ MX_SUCCESS
Definition: lib_api.h:292
A Device context for Tensor and operator.
Definition: dlpack.h:69
MXReturnValue(* parseAttrs_t)(const std::unordered_map< std::string, std::string > &attributes, int *num_inputs, int *num_outputs)
Definition: lib_api.h:751
int(* partCallCreateSelector_t)(createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1175
Definition: lib_api.h:413
void(* partCallFilter_t)(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
Definition: lib_api.h:1195
void alloc_aux(const std::vector< int64_t > &shapes, const MXContext &ctx, MXDType dtype)
void * mx_stream_t
GPU stream pointer, is void* when not compiled with CUDA.
Definition: lib_api.h:402
MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char **strategy, mxnet::ext::supportedOps_t *supportedOps, mxnet::ext::createSelector_t *createSelector, mxnet::ext::reviewSubgraph_t *reviewSubgraph, const char **op_name)
returns partitioner registration at specified index
definition of JSON objects
Definition: lib_api.h:528
@ kUint8
Definition: lib_api.h:253
graphPass_t pass
pass function
Definition: lib_api.h:844
int(* opCallCreateOpState_t)(createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
Definition: lib_api.h:1096
int64_t * indptr
Definition: lib_api.h:310
bool operator<(const JsonVal &o) const
CustomStatefulOp * get_instance()
Definition: lib_api.h:1287
@ LIST
Definition: lib_api.h:525
MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
returns status of calling inferType function for operator from library
void _dfs_util(Node *n, std::unordered_set< Node * > *to_visit, std::function< void(Node *)> handler) const
mx_stream_t get_cuda_stream() const
return the cuda stream object with correct type
Definition: lib_api.h:461
DLDeviceType
The device type in DLContext.
Definition: dlpack.h:38
std::vector< JsonVal > list
Definition: lib_api.h:567
int(* opCallDestroyOpState_t)(void *state_op)
Definition: lib_api.h:1109
virtual MXReturnValue Forward(std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &op_res)=0
PassResource(std::unordered_map< std::string, MXTensor > *new_args, std::unordered_map< std::string, MXTensor > *new_aux, nd_malloc_t nd_malloc, const void *nd_alloc)
int(* partRegSize_t)(void)
Definition: lib_api.h:1151
std::vector< fcomp_t > backward_fp
Definition: lib_api.h:816
void print(int indent=0) const
int dev_id
Definition: lib_api.h:287
int(* opRegSize_t)(void)
Definition: lib_api.h:980
std::string toString() const
void(* passRegGet_t)(int pass_idx, graphPass_t *graphPass, const char **pass_name)
Definition: lib_api.h:1238
void(* sparse_malloc_t)(void *, int, int, int, void **, int64_t **, int64_t **)
sparse alloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:386
inferType_t infer_type
Definition: lib_api.h:808
CustomPartitioner & setReviewSubgraph(const char *prop_name, reviewSubgraph_t fn)
virtual bool SelectInput(int nodeID, int input_nodeID)=0
@ kDLUInt
Definition: lib_api.h:144
void(* nd_malloc_t)(const void *_ndarray_alloc, const int64_t *shapes, int num_shapes, const char *dev_str, int dev_id, int dtype, const char *name, int isArg, void **data)
resource malloc function to allocate ndarrays for graph passes
Definition: lib_api.h:388
bool isSame(const MXTensor &oth) const
helper function to compare two MXTensors
void * alloc_cpu(int size) const
allocate cpu memory controlled by MXNet
int(* opCallInferType_t)(inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:1027
CustomOp & setCreateOpState(createOpState_t func, const char *ctx)
MXTensor * tensor
Definition: lib_api.h:599
CustomPartitioner & setSupportedOps(const char *prop_name, supportedOps_t fn)
std::map< std::string, JsonVal > attrs
Definition: lib_api.h:667
int(* partRegGetCount_t)(int idx, const char **name)
Definition: lib_api.h:1154
@ kDLMetal
Metal for Apple GPU.
Definition: lib_api.h:116
std::vector< const char * > forward_ctx_cstr
vector repr of ctx map to be easily loaded from c_api
Definition: lib_api.h:815
MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
returns status of calling review subgraph function from library
reviewSubgraph_t getReviewSubgraph(int stg_id)
MXDType
Tensor data type, consistent with mshadow data type.
Definition: lib_api.h:249
int(* initialize_t)(int version)
Definition: lib_api.h:1270
MX_VOID_RET _opCallDestroyOpState(void *state_op)
returns status of deleting StatefulOp instance for operator from library
void _setPassResource(PassResource *res_)
const char * name
partitioner name
Definition: lib_api.h:887
MXReturnValue(* mutateInputs_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *input_indices)
Definition: lib_api.h:766
@ MX_FAIL
Definition: lib_api.h:291
int entry
Definition: lib_api.h:580
static JsonVal parse_map(const std::string &json, unsigned int *idx)
T & get(int idx)
Definition: lib_api.h:924
void setDLTensor()
populate DLTensor fields
virtual void Reset()
Definition: lib_api.h:706
MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling supported ops function from library
MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
returns status of calling parse attributes function for operator from library
std::vector< const char * > op_names
subgraph operator name
Definition: lib_api.h:894
#define MX_ERROR_MSG
Definition: lib_api.h:244
const char * name
pass name
Definition: lib_api.h:842
std::unordered_map< std::string, std::string > attrs
Definition: lib_api.h:603
@ kFloat16
Definition: lib_api.h:252
const char * name
operator name
Definition: lib_api.h:804
Definition: lib_api.h:296
Context info passing from MXNet OpContext dev_type is string repr of supported context,...
Definition: lib_api.h:277
static Registry * get() PRIVATE_SYMBOL
get singleton pointer to class
Definition: lib_api.h:908
size_t verID
Definition: lib_api.h:370
@ STR
Definition: lib_api.h:525
@ ERR
Definition: lib_api.h:525
supportedOps_t getSupportedOps(int stg_id)
std::string toString() const
@ kDLVPI
Verilog simulator buffer.
Definition: lib_api.h:118
std::map< JsonVal, JsonVal > map
Definition: lib_api.h:568
void setTensor(void *dptr, MXDType type, const int64_t *dims, int ndims, size_t vID, MXContext mx_ctx, MXStorageType storage_type)
populate internal tensor fields
int(* partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1166
MX_INT_RET _opCallFStatefulCompute(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Stateful Forward/Backward for operator from library
MXTensor * alloc_aux(const std::string &name, const std::vector< int64_t > &shapes, const MXContext &ctx, MXDType dtype) const
@ kDLInt
Definition: lib_api.h:143
JsonType type
Definition: lib_api.h:564
void _setParams(std::unordered_map< std::string, mxnet::ext::MXTensor > *args, std::unordered_map< std::string, mxnet::ext::MXTensor > *aux)
std::string getShapeAt(const std::string &shape, unsigned index)
static JsonVal parse_num(const std::string &json, unsigned int *idx)
int(* msgSize_t)(void)
Definition: lib_api.h:1276
int(* opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
Definition: lib_api.h:1005
#define PRIVATE_SYMBOL
For loading multiple custom op libraries in Linux, exporting same symbol multiple times may lead to u...
Definition: lib_api.h:65
Node * node
Definition: lib_api.h:579
MX_INT_RET _partRegSize()
returns number of partitioners registered in this library
MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *instypes, int num_in, int *outstypes, int num_out)
returns status of calling inferSType function for operator from library
MX_VOID_RET _opCallFree(void *ptr)
calls free from the external library for library allocated arrays
std::vector< Graph * > subgraphs
Definition: lib_api.h:602
MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t *graphPass, const char **pass_name)
returns pass registration at specified index
int(* opCallFComp_t)(fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:1047
provide resource APIs memory allocation mechanism to Forward/Backward functions
Definition: lib_api.h:442
void(* partCallReset_t)(void *sel_inst)
Definition: lib_api.h:1202
std::string dev_type
Definition: lib_api.h:286
Node * getNode(size_t idx)
Definition: lib_api.h:584
MXTensor * alloc_arg(const std::string &name, const std::vector< int64_t > &shapes, const MXContext &ctx, MXDType dtype) const
std::vector< const char * > create_op_ctx_cstr
Definition: lib_api.h:815
int(* opVersion_t)()
Definition: lib_api.h:1273
@ kUNSET
Definition: lib_api.h:257
mxnet::ext::MXReturnValue initialize(int version)
Checks if the MXNet version is supported by the library. If supported, initializes the library.
@ kFloat32
Definition: lib_api.h:250
MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
returns status of calling mutateInputs function for operator from library
static MXerrorMsgs * get()
CustomOp & setParseAttrs(parseAttrs_t func)
void set(void *data_ptr, const int64_t *dims, int ndims, void *idx, int64_t num_idx, void *idx_ptr=nullptr, int64_t num_idx_ptr=0)
CustomOp & setInferShape(inferShape_t func)
DLDataTypeCode
The type code options DLDataType.
Definition: lib_api.h:142
Class to hold custom operator registration.
Definition: lib_api.h:779
MX_VOID_RET _partCallSelectOutput(void *sel_inst, int nodeID, int output_nodeID, int *selected)
returns status of calling select output function from library
@ kRowSparseStorage
Definition: lib_api.h:267
MX_VOID_RET _partCallSelectInput(void *sel_inst, int nodeID, int input_nodeID, int *selected)
returns status of calling select input function from library
MXReturnValue(* inferShape_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< std::vector< unsigned int > > *in_shapes, std::vector< std::vector< unsigned int > > *out_shapes)
Definition: lib_api.h:762
Node * addNode(const std::string &name, const std::string &op)
Definition: lib_api.h:220
int(* passCallGraphPass_t)(graphPass_t graphPass, const char *in_graph, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, nd_malloc_t nd_malloc, const void *nd_alloc)
Definition: lib_api.h:1241
#define MX_VOID_RET
Definition: lib_api.h:1301
MX_VOID_RET _partCallFilter(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
returns status of calling filter function from library
static CustomStatefulOp * create(Ts... args)
Definition: lib_api.h:720
int(* msgGet_t)(int idx, const char **msg)
Definition: lib_api.h:1279
@ kFloat64
Definition: lib_api.h:251
mx_cpu_rand_t * get_cpu_rand_states() const
get pointer to initialized and seeded random number states located on CPU
Registry class to registers things (ops, properties) Singleton class.
Definition: lib_api.h:902
MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
returns status of calling inferShape function for operator from library
MX_INT_RET _opRegSize()
returns number of ops registered in this library
CustomOp & setForward(fcomp_t fcomp, const char *ctx)
MX_INT_RET _passRegSize()
returns number of graph passes registered in this library
CustomOp(const char *op_name)
data_type * data()
helper function to cast data pointer
Definition: lib_api.h:349
void alloc_arg(const std::vector< int64_t > &shapes, const MXContext &ctx, MXDType dtype)
mx_gpu_rand_t * get_gpu_rand_states() const
get pointer to initialized and seeded random number states located on GPU
Definition: lib_api.h:475
std::vector< NodeEntry > outputs
Definition: lib_api.h:666
@ kDLCPU
CPU device.
Definition: lib_api.h:103
std::vector< createOpState_t > create_op_fp
Definition: lib_api.h:817
@ kDLGPU
CUDA GPU device.
Definition: lib_api.h:105
void alloc_sparse(MXSparse *sparse, int index, int indices_len, int indptr_len=0) const
allocate sparse memory controlled by MXNet
virtual bool Select(int nodeID)=0
MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc, const void *nd_alloc)
returns status of calling graph pass function from library
std::vector< const char * > backward_ctx_cstr
Definition: lib_api.h:815
void(* partCallSelectInput_t)(void *sel_inst, int nodeID, int input_nodeID, int *selected)
Definition: lib_api.h:1186
CustomPartitioner & addStrategy(const char *prop_name, const char *sg_name)
MXReturnValue(* graphPass_t)(mxnet::ext::Graph *graph, const std::unordered_map< std::string, std::string > &options)
Custom Pass Create function template.
Definition: lib_api.h:828
virtual void Filter(const std::vector< int > &candidates, std::vector< int > *keep)
Definition: lib_api.h:700
MXStorageType stype
Definition: lib_api.h:380
StatefulOp wrapper class to pass to backend OpState.
Definition: lib_api.h:1282
void(* partRegGet_t)(int part_idx, int stg_idx, const char **strategy, supportedOps_t *supportedOps, createSelector_t *createSelector, reviewSubgraph_t *reviewSubgraph, const char **op_name)
Definition: lib_api.h:1157
void *(* xpu_malloc_t)(void *, int)
resource malloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:384
DLTensor dltensor
Definition: lib_api.h:377
bool wasCreated()
Definition: lib_api.h:726
std::map< std::string, reviewSubgraph_t > review_map
Definition: lib_api.h:890
std::vector< const char * > strategies
strategy names
Definition: lib_api.h:892
MXStorageType
Definition: lib_api.h:263
inferShape_t infer_shape
Definition: lib_api.h:810
MXReturnValue
Definition: lib_api.h:290
T & add(const char *name)
add a new entry
Definition: lib_api.h:916
CustomOp & setMutateInputs(mutateInputs_t func)
@ kDefaultStorage
Definition: lib_api.h:265
void(* partCallSelectOutput_t)(void *sel_inst, int nodeID, int output_nodeID, int *selected)
Definition: lib_api.h:1189
CustomPartitioner & setCreateSelector(const char *prop_name, createSelector_t fn)
void DFS(std::function< void(Node *)> handler) const
static JsonVal parse_string(const std::string &json, unsigned int *idx)
bool ignore_warn
Definition: lib_api.h:740
MX_VOID_RET _partCallReset(void *sel_inst)
returns status of calling reset selector function from library
@ kDLFloat
Definition: lib_api.h:145
static Graph * fromJson(JsonVal val)
MXContext ctx
Definition: lib_api.h:373
std::stringstream & add(const char *file, int line)
parseAttrs_t parse_attrs
operator functions
Definition: lib_api.h:807
std::string name
Definition: lib_api.h:598
int64_t size() const
helper function to get data size
int64_t data_len
Definition: lib_api.h:300
#define MX_INT_RET
Definition: lib_api.h:1300
An abstract class for library authors creating stateful op custom library should override Forward and...
Definition: lib_api.h:714
std::vector< Node * > topological_sort() const
DLDeviceType
The device type in DLContext.
Definition: lib_api.h:101
CustomOp & setInferSType(inferSType_t func)
int num
Definition: lib_api.h:565
int(* opCallInferShape_t)(inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
Definition: lib_api.h:1013
~CustomStatefulOpWrapper()
Plain C Tensor object, does not manage memory.
Definition: dlpack.h:112
std::map< std::string, createSelector_t > selector_map
Definition: lib_api.h:889
MX_VOID_RET _msgGet(int idx, const char **msg)
returns operator registration at specified index
MXReturnValue(* createSelector_t)(const mxnet::ext::Graph *graph, CustomOpSelector **sel_inst, const std::unordered_map< std::string, std::string > &options)
Definition: lib_api.h:852
MXReturnValue(* createOpState_t)(const std::unordered_map< std::string, std::string > &attributes, const MXContext &ctx, const std::vector< std::vector< unsigned int > > &in_shapes, const std::vector< int > in_types, CustomStatefulOp **)
Definition: lib_api.h:769
OpResource(xpu_malloc_t cpu_malloc_fp, void *cpu_alloc_fp, xpu_malloc_t gpu_malloc_fp, void *gpu_alloc_fp, void *stream, sparse_malloc_t sparse_malloc_fp, void *sparse_alloc_fp, void *rng_cpu_states, void *rng_gpu_states)
virtual ~CustomStatefulOp()
std::map< std::string, supportedOps_t > supported_map
Definition: lib_api.h:888
@ kDLExtDev
Reserved extension device type, used for quickly test extension device The semantics can differ depen...
Definition: lib_api.h:126
std::vector< fcomp_t > forward_fp
Definition: lib_api.h:816
MX_INT_RET _opVersion()
returns MXNet library version
@ kInt64
Definition: lib_api.h:256
int(* opRegGet_t)(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
Definition: lib_api.h:983
int64_t indices_len
Definition: lib_api.h:306
@ kDLCPUPinned
Pinned CUDA GPU device by cudaMallocHost.
Definition: lib_api.h:110
CustomStatefulOpWrapper(CustomStatefulOp *inst, opCallDestroyOpState_t destroy)
Definition: lib_api.h:1285
@ NUM
Definition: lib_api.h:525
@ kInt32
Definition: lib_api.h:254
int size()
Definition: lib_api.h:921
@ MAP
Definition: lib_api.h:525
std::vector< Node * > inputs
Definition: lib_api.h:665
@ kDLVulkan
Vulkan buffer for next generation graphics.
Definition: lib_api.h:114
createSelector_t getCreateSelector(int stg_id)
CustomOp & setInferType(inferType_t func)
int(* opCallMutateInputs_t)(mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
Definition: lib_api.h:1088
std::vector< NodeEntry > outputs
Definition: lib_api.h:601
std::string str
Definition: lib_api.h:566
@ kInt8
Definition: lib_api.h:255
MXReturnValue(* inferSType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_storage_types, std::vector< int > *out_storage_types)
Definition: lib_api.h:758
static JsonVal parse_list(const std::string &json, unsigned int *idx)