mxnet
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
c_api.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_C_API_H_
26 #define MXNET_C_API_H_
27 
#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

/*!
 * \brief Attach a default argument value to a parameter declaration.
 *
 * Expands to "= x" when the header is compiled as C++ and to nothing when it
 * is compiled as C, so one set of prototypes serves both languages (C has no
 * default arguments).
 */
#ifdef __cplusplus
#define DEFAULT(x) = x
#else
#define DEFAULT(x)
#endif  // __cplusplus

/* Each standard header is included exactly once (stdint.h was duplicated). */
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>

/*!
 * \brief Linkage annotation placed on every exported API function.
 *
 * On Windows: __declspec(dllexport) while building the library (when
 * MXNET_EXPORTS is defined), __declspec(dllimport) for consumers.
 * Elsewhere it expands to nothing.
 */
#ifdef _WIN32
#ifdef MXNET_EXPORTS
#define MXNET_DLL __declspec(dllexport)
#else
#define MXNET_DLL __declspec(dllimport)
#endif
#else
#define MXNET_DLL
#endif
56 
/* Scalar types shared by the whole C API. */
typedef unsigned int mx_uint;   /* unsigned index/count type (shapes, sizes) */
typedef float mx_float;         /* floating-point type for scalar arguments */

/*
 * Opaque handles.
 *
 * Every handle is an untyped pointer that the library converts back to its
 * concrete internal type.  The separate typedef names carry no type safety;
 * they only make the function signatures self-describing.
 */
typedef void *NDArrayHandle;          /* an NDArray */
typedef const void *FunctionHandle;   /* a registered NDArray function (read-only) */
typedef void *AtomicSymbolCreator;    /* creator of an atomic symbol / operator */
typedef void *CachedOpHandle;         /* a cached operator */
typedef void *SymbolHandle;           /* a symbol */
typedef void *AtomicSymbolHandle;     /* an atomic symbol */
typedef void *ExecutorHandle;         /* an executor */
typedef void *DataIterCreator;        /* creator of a data iterator */
typedef void *DataIterHandle;         /* a data iterator */
typedef void *KVStoreHandle;          /* a key-value store */
typedef void *RecordIOHandle;         /* a RecordIO reader or writer */
typedef void *RtcHandle;              /* an MXRtc object */
typedef void *CudaModuleHandle;       /* a CUDA RTC module */
typedef void *CudaKernelHandle;       /* a kernel obtained from a CUDA module */
92 
93 typedef void (*ExecutorMonitorCallback)(const char*,
95  void *);
96 
97 struct NativeOpInfo {
98  void (*forward)(int, float**, int*, unsigned**, int*, void*);
99  void (*backward)(int, float**, int*, unsigned**, int*, void*);
100  void (*infer_shape)(int, int*, unsigned**, void*);
101  void (*list_outputs)(char***, void*);
102  void (*list_arguments)(char***, void*);
103  // all functions also pass a payload void* pointer
104  void* p_forward;
105  void* p_backward;
109 };
110 
112  bool (*forward)(int, void**, int*, void*);
113  bool (*backward)(int, void**, int*, void*);
114  bool (*infer_shape)(int, int*, unsigned**, void*);
115  bool (*list_outputs)(char***, void*);
116  bool (*list_arguments)(char***, void*);
117  bool (*declare_backward_dependency)(const int*, const int*, const int*,
118  int*, int**, void*);
119  // all functions also pass a payload void* pointer
120  void* p_forward;
121  void* p_backward;
126 };
127 
/*! \brief pointer to a callback that takes no arguments and returns an int */
typedef int (*MXGenericCallback)(void);
129 
132  int (**callbacks)(void);
133  void **contexts;
134 };
135 
140 };
141 
151 };
152 
153 
/*
 * Callback signatures for custom (frontend-defined) operators.
 *
 * Each callback returns an int.  Every callback except CustomOpPropCreator
 * takes a trailing opaque `state` pointer.
 * NOTE(review): `state` is presumably the frontend's context for that
 * callback — confirm.
 */

/* Forward/backward pass: `size` entries in `ptrs`, each with a tag and req. */
typedef int (*CustomOpFBFunc)(
    int size, void **ptrs, int *tags, const int *reqs, const int is_train,
    void *state);

/* Takes only the state pointer. */
typedef int (*CustomOpDelFunc)(void *state);

/* Fills `args` with a list of names. */
typedef int (*CustomOpListFunc)(char ***args, void *state);

/* Shape inference over `num_input` inputs, given per-input rank and shape. */
typedef int (*CustomOpInferShapeFunc)(
    int num_input, int *ndims, unsigned **shapes, void *state);

/* Type inference over `num_input` inputs. */
typedef int (*CustomOpInferTypeFunc)(int num_input, int *types, void *state);

/* Backward dependency declaration. */
typedef int (*CustomOpBwdDepFunc)(
    const int *out_grad, const int *in_data, const int *out_data,
    int *num_deps, int **rdeps, void *state);

/* Operator creation for a context string; results returned through `ret`. */
typedef int (*CustomOpCreateFunc)(
    const char *ctx, int num_inputs, unsigned **shapes, const int *ndims,
    const int *dtypes, struct MXCallbackList *ret, void *state);

/* Property creation from an op type plus key/value keyword arguments. */
typedef int (*CustomOpPropCreator)(
    const char *op_type, const int num_kwargs, const char **keys,
    const char **values, struct MXCallbackList *ret);
172 
173 
177 };
178 
/*
 * Callback signatures for custom autograd functions.
 * NOTE(review): presumably these are delivered through the callback list
 * passed to MXCustomFunctionRecord — confirm.
 */
typedef int (*CustomFunctionBwdFunc)(
    int num_ograds, int num_igrads, void **ptrs, const int *reqs,
    const int is_train, void *state);
typedef int (*CustomFunctionDelFunc)(void *state);
183 
193 MXNET_DLL const char *MXGetLastError();
194 
195 //-------------------------------------
196 // Part 0: Global State setups
197 //-------------------------------------
203 MXNET_DLL int MXRandomSeed(int seed);
220 MXNET_DLL int MXSetProfilerConfig(int mode, const char* filename);
228 MXNET_DLL int MXSetProfilerState(int state);
229 
232 
234 MXNET_DLL int MXSetNumOMPThreads(int thread_num);
235 
241 MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size);
242 
248 MXNET_DLL int MXGetVersion(int *out);
249 
250 //-------------------------------------
251 // Part 1: NDArray creation and deletion
252 //-------------------------------------
260 MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out);
272 MXNET_DLL int MXNDArrayCreate(const mx_uint *shape,
273  mx_uint ndim,
274  int dev_type,
275  int dev_id,
276  int delay_alloc,
277  NDArrayHandle *out);
278 
291 MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
292  mx_uint ndim,
293  int dev_type,
294  int dev_id,
295  int delay_alloc,
296  int dtype,
297  NDArrayHandle *out);
298 
299 
317 MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
318  const mx_uint *shape,
319  mx_uint ndim,
320  int dev_type,
321  int dev_id,
322  int delay_alloc,
323  int dtype,
324  mx_uint num_aux,
325  int *aux_type,
326  mx_uint *aux_ndims,
327  const mx_uint *aux_shape,
328  NDArrayHandle *out);
329 
337 MXNET_DLL int MXNDArrayLoadFromRawBytes(const void *buf,
338  size_t size,
339  NDArrayHandle *out);
347 MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle,
348  size_t *out_size,
349  const char **out_buf);
358 MXNET_DLL int MXNDArraySave(const char* fname,
359  mx_uint num_args,
360  NDArrayHandle* args,
361  const char** keys);
371 MXNET_DLL int MXNDArrayLoad(const char* fname,
372  mx_uint *out_size,
373  NDArrayHandle** out_arr,
374  mx_uint *out_name_size,
375  const char*** out_names);
387 MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle,
388  const void *data,
389  size_t size);
401 MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle,
402  void *data,
403  size_t size);
411 MXNET_DLL int MXNDArraySyncCopyFromNDArray(NDArrayHandle handle_dst,
412  const NDArrayHandle handle_src,
413  const int i);
414 
420 MXNET_DLL int MXNDArraySyncCheckFormat(NDArrayHandle handle, const bool full_check);
427 MXNET_DLL int MXNDArrayWaitToRead(NDArrayHandle handle);
434 MXNET_DLL int MXNDArrayWaitToWrite(NDArrayHandle handle);
446 MXNET_DLL int MXNDArrayFree(NDArrayHandle handle);
455 MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
456  mx_uint slice_begin,
457  mx_uint slice_end,
458  NDArrayHandle *out);
459 
467 MXNET_DLL int MXNDArrayAt(NDArrayHandle handle,
468  mx_uint idx,
469  NDArrayHandle *out);
470 
474 MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle,
475  int *out_storage_type);
476 
485 MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
486  int ndim,
487  int *dims,
488  NDArrayHandle *out);
496 MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
497  mx_uint *out_dim,
498  const mx_uint **out_pdata);
505 MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
506  void **out_pdata);
513 MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle,
514  int *out_dtype);
515 
523 MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
524  mx_uint i,
525  int *out_type);
526 
532 MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
533  mx_uint i,
534  NDArrayHandle *out);
535 
541 MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle,
542  NDArrayHandle *out);
550 MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle,
551  int *out_dev_type,
552  int *out_dev_id);
558 MXNET_DLL int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out);
564 MXNET_DLL int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out);
571 MXNET_DLL int MXNDArraySetGradState(NDArrayHandle handle, int state);
578 MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
579 //--------------------------------
580 // Part 2: functions on NDArray
581 //--------------------------------
589 MXNET_DLL int MXListFunctions(mx_uint *out_size,
590  FunctionHandle **out_array);
597 MXNET_DLL int MXGetFunction(const char *name,
598  FunctionHandle *out);
611 MXNET_DLL int MXFuncGetInfo(FunctionHandle fun,
612  const char **name,
613  const char **description,
614  mx_uint *num_args,
615  const char ***arg_names,
616  const char ***arg_type_infos,
617  const char ***arg_descriptions,
618  const char **return_type DEFAULT(NULL));
629 MXNET_DLL int MXFuncDescribe(FunctionHandle fun,
630  mx_uint *num_use_vars,
631  mx_uint *num_scalars,
632  mx_uint *num_mutate_vars,
633  int *type_mask);
644 MXNET_DLL int MXFuncInvoke(FunctionHandle fun,
645  NDArrayHandle *use_vars,
646  mx_float *scalar_args,
647  NDArrayHandle *mutate_vars);
661 MXNET_DLL int MXFuncInvokeEx(FunctionHandle fun,
662  NDArrayHandle *use_vars,
663  mx_float *scalar_args,
664  NDArrayHandle *mutate_vars,
665  int num_params,
666  char **param_keys,
667  char **param_vals);
680 MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator,
681  int num_inputs,
682  NDArrayHandle *inputs,
683  int *num_outputs,
684  NDArrayHandle **outputs,
685  int num_params,
686  const char **param_keys,
687  const char **param_vals);
701 MXNET_DLL int MXImperativeInvokeEx(AtomicSymbolCreator creator,
702  int num_inputs,
703  NDArrayHandle *inputs,
704  int *num_outputs,
705  NDArrayHandle **outputs,
706  int num_params,
707  const char **param_keys,
708  const char **param_vals,
709  const int **out_stypes);
716 MXNET_DLL int MXAutogradSetIsRecording(int is_recording, int* prev);
723 MXNET_DLL int MXAutogradSetIsTraining(int is_training, int* prev);
729 MXNET_DLL int MXAutogradIsRecording(bool* curr);
735 MXNET_DLL int MXAutogradIsTraining(bool* curr);
742 MXNET_DLL int MXAutogradMarkVariables(mx_uint num_var,
743  NDArrayHandle *var_handles,
744  mx_uint *reqs_array,
745  NDArrayHandle *grad_handles);
752 MXNET_DLL int MXAutogradComputeGradient(mx_uint num_output,
753  NDArrayHandle* output_handles);
762 MXNET_DLL int MXAutogradBackward(mx_uint num_output,
763  NDArrayHandle* output_handles,
764  NDArrayHandle* ograd_handles,
765  int retain_graph);
777 MXNET_DLL int MXAutogradBackwardEx(mx_uint num_output,
778  NDArrayHandle *output_handles,
779  NDArrayHandle *ograd_handles,
780  mx_uint num_variables,
781  NDArrayHandle *var_handles,
782  int retain_graph,
783  int create_graph,
784  int is_train,
785  NDArrayHandle **grad_handles,
786  int **grad_stypes);
787 /*
788  * \brief get the graph constructed by autograd.
789  * \param handle ndarray handle
790  * \param out output symbol handle
791  */
792 MXNET_DLL int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out);
796 MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out);
800 MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
801  int num_params,
802  const char** keys,
803  const char** vals,
804  CachedOpHandle *out);
808 MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle);
812 MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle,
813  int num_inputs,
814  NDArrayHandle *inputs,
815  int *num_outputs,
816  NDArrayHandle **outputs);
827 MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
828  int num_inputs,
829  NDArrayHandle *inputs,
830  int *num_outputs,
831  NDArrayHandle **outputs,
832  const int** out_stypes);
836 MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle,
837  int num_inputs,
838  NDArrayHandle *inputs,
839  int *num_outputs,
840  NDArrayHandle **outputs);
841 //--------------------------------------------
842 // Part 3: symbolic configuration generation
843 //--------------------------------------------
850 MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
851  const char ***out_array);
858 MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
859  AtomicSymbolCreator **out_array);
860 
866 MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
867  const char **name);
885 MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
886  const char **name,
887  const char **description,
888  mx_uint *num_args,
889  const char ***arg_names,
890  const char ***arg_type_infos,
891  const char ***arg_descriptions,
892  const char **key_var_num_args,
893  const char **return_type DEFAULT(NULL));
903 MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
904  mx_uint num_param,
905  const char **keys,
906  const char **vals,
907  SymbolHandle *out);
914 MXNET_DLL int MXSymbolCreateVariable(const char *name, SymbolHandle *out);
922 MXNET_DLL int MXSymbolCreateGroup(mx_uint num_symbols,
923  SymbolHandle *symbols,
924  SymbolHandle *out);
931 MXNET_DLL int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out);
938 MXNET_DLL int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out);
945 MXNET_DLL int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname);
952 MXNET_DLL int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json);
958 MXNET_DLL int MXSymbolFree(SymbolHandle symbol);
965 MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
972 MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str);
980 MXNET_DLL int MXSymbolGetName(SymbolHandle symbol,
981  const char** out,
982  int *success);
991 MXNET_DLL int MXSymbolGetAttr(SymbolHandle symbol,
992  const char* key,
993  const char** out,
994  int *success);
1011 MXNET_DLL int MXSymbolSetAttr(SymbolHandle symbol,
1012  const char* key,
1013  const char* value);
1021 MXNET_DLL int MXSymbolListAttr(SymbolHandle symbol,
1022  mx_uint *out_size,
1023  const char*** out);
1031 MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
1032  mx_uint *out_size,
1033  const char*** out);
1041 MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
1042  mx_uint *out_size,
1043  const char ***out_str_array);
1051 MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
1052  mx_uint *out_size,
1053  const char ***out_str_array);
1060 MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol,
1061  SymbolHandle *out);
1068 MXNET_DLL int MXSymbolGetChildren(SymbolHandle symbol,
1069  SymbolHandle *out);
1077 MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
1078  mx_uint index,
1079  SymbolHandle *out);
1087 MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
1088  mx_uint *out_size,
1089  const char ***out_str_array);
1104 MXNET_DLL int MXSymbolCompose(SymbolHandle sym,
1105  const char *name,
1106  mx_uint num_args,
1107  const char** keys,
1108  SymbolHandle* args);
1118 MXNET_DLL int MXSymbolGrad(SymbolHandle sym,
1119  mx_uint num_wrt,
1120  const char** wrt,
1121  SymbolHandle* out);
1144 MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
1145  mx_uint num_args,
1146  const char** keys,
1147  const mx_uint *arg_ind_ptr,
1148  const mx_uint *arg_shape_data,
1149  mx_uint *in_shape_size,
1150  const mx_uint **in_shape_ndim,
1151  const mx_uint ***in_shape_data,
1152  mx_uint *out_shape_size,
1153  const mx_uint **out_shape_ndim,
1154  const mx_uint ***out_shape_data,
1155  mx_uint *aux_shape_size,
1156  const mx_uint **aux_shape_ndim,
1157  const mx_uint ***aux_shape_data,
1158  int *complete);
1183 MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
1184  mx_uint num_args,
1185  const char** keys,
1186  const mx_uint *arg_ind_ptr,
1187  const mx_uint *arg_shape_data,
1188  mx_uint *in_shape_size,
1189  const mx_uint **in_shape_ndim,
1190  const mx_uint ***in_shape_data,
1191  mx_uint *out_shape_size,
1192  const mx_uint **out_shape_ndim,
1193  const mx_uint ***out_shape_data,
1194  mx_uint *aux_shape_size,
1195  const mx_uint **aux_shape_ndim,
1196  const mx_uint ***aux_shape_data,
1197  int *complete);
1198 
1217 MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
1218  mx_uint num_args,
1219  const char** keys,
1220  const int *arg_type_data,
1221  mx_uint *in_type_size,
1222  const int **in_type_data,
1223  mx_uint *out_type_size,
1224  const int **out_type_data,
1225  mx_uint *aux_type_size,
1226  const int **aux_type_data,
1227  int *complete);
1228 
1229 
1230 
1231 
1232 //--------------------------------------------
1233 // Part 4: Executor interface
1234 //--------------------------------------------
1240 MXNET_DLL int MXExecutorFree(ExecutorHandle handle);
1247 MXNET_DLL int MXExecutorPrint(ExecutorHandle handle, const char **out_str);
1255 MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train);
1265 MXNET_DLL int MXExecutorBackward(ExecutorHandle handle,
1266  mx_uint len,
1267  NDArrayHandle *head_grads);
1278 MXNET_DLL int MXExecutorBackwardEx(ExecutorHandle handle,
1279  mx_uint len,
1280  NDArrayHandle *head_grads,
1281  int is_train);
1290 MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle,
1291  mx_uint *out_size,
1292  NDArrayHandle **out);
1293 
1309 MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle,
1310  int dev_type,
1311  int dev_id,
1312  mx_uint len,
1313  NDArrayHandle *in_args,
1314  NDArrayHandle *arg_grad_store,
1315  mx_uint *grad_req_type,
1316  mx_uint aux_states_len,
1317  NDArrayHandle *aux_states,
1318  ExecutorHandle *out);
1340 MXNET_DLL int MXExecutorBindX(SymbolHandle symbol_handle,
1341  int dev_type,
1342  int dev_id,
1343  mx_uint num_map_keys,
1344  const char** map_keys,
1345  const int* map_dev_types,
1346  const int* map_dev_ids,
1347  mx_uint len,
1348  NDArrayHandle *in_args,
1349  NDArrayHandle *arg_grad_store,
1350  mx_uint *grad_req_type,
1351  mx_uint aux_states_len,
1352  NDArrayHandle *aux_states,
1353  ExecutorHandle *out);
1376 MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle,
1377  int dev_type,
1378  int dev_id,
1379  mx_uint num_map_keys,
1380  const char** map_keys,
1381  const int* map_dev_types,
1382  const int* map_dev_ids,
1383  mx_uint len,
1384  NDArrayHandle *in_args,
1385  NDArrayHandle *arg_grad_store,
1386  mx_uint *grad_req_type,
1387  mx_uint aux_states_len,
1388  NDArrayHandle *aux_states,
1389  ExecutorHandle shared_exec,
1390  ExecutorHandle *out);
1391 
1392 MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
1393  int dev_type,
1394  int dev_id,
1395  const mx_uint num_g2c_keys,
1396  const char** g2c_keys,
1397  const int* g2c_dev_types,
1398  const int* g2c_dev_ids,
1399  const mx_uint provided_grad_req_list_len,
1400  const char** provided_grad_req_names,
1401  const char** provided_grad_req_types,
1402  const mx_uint num_provided_arg_shapes,
1403  const char** provided_arg_shape_names,
1404  const mx_uint* provided_arg_shape_data,
1405  const mx_uint* provided_arg_shape_idx,
1406  const mx_uint num_provided_arg_dtypes,
1407  const char** provided_arg_dtype_names,
1408  const int* provided_arg_dtypes,
1409  const mx_uint num_provided_arg_stypes,
1410  const char** provided_arg_stype_names,
1411  const int* provided_arg_stypes,
1412  const mx_uint num_shared_arg_names,
1413  const char** shared_arg_name_list,
1414  int* shared_buffer_len,
1415  const char** shared_buffer_name_list,
1416  NDArrayHandle* shared_buffer_handle_list,
1417  const char*** updated_shared_buffer_name_list,
1418  NDArrayHandle** updated_shared_buffer_handle_list,
1419  mx_uint* num_in_args,
1420  NDArrayHandle** in_args,
1421  NDArrayHandle** arg_grads,
1422  mx_uint* num_aux_states,
1423  NDArrayHandle** aux_states,
1424  ExecutorHandle shared_exec_handle,
1425  ExecutorHandle* out);
1429 MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle,
1430  ExecutorMonitorCallback callback,
1431  void* callback_handle);
1432 //--------------------------------------------
1433 // Part 5: IO Interface
1434 //--------------------------------------------
1441 MXNET_DLL int MXListDataIters(mx_uint *out_size,
1442  DataIterCreator **out_array);
1453 MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle,
1454  mx_uint num_param,
1455  const char **keys,
1456  const char **vals,
1457  DataIterHandle *out);
1469 MXNET_DLL int MXDataIterGetIterInfo(DataIterCreator creator,
1470  const char **name,
1471  const char **description,
1472  mx_uint *num_args,
1473  const char ***arg_names,
1474  const char ***arg_type_infos,
1475  const char ***arg_descriptions);
1481 MXNET_DLL int MXDataIterFree(DataIterHandle handle);
1488 MXNET_DLL int MXDataIterNext(DataIterHandle handle,
1489  int *out);
1495 MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle);
1496 
1503 MXNET_DLL int MXDataIterGetData(DataIterHandle handle,
1504  NDArrayHandle *out);
1512 MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle,
1513  uint64_t **out_index,
1514  uint64_t *out_size);
1521 MXNET_DLL int MXDataIterGetPadNum(DataIterHandle handle,
1522  int *pad);
1523 
1530 MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle,
1531  NDArrayHandle *out);
1532 //--------------------------------------------
1533 // Part 6: basic KVStore interface
1534 //--------------------------------------------
1541 MXNET_DLL int MXInitPSEnv(mx_uint num_vars,
1542  const char **keys,
1543  const char **vals);
1544 
1545 
1552 MXNET_DLL int MXKVStoreCreate(const char *type,
1553  KVStoreHandle *out);
1554 
1562 MXNET_DLL int MXKVStoreSetGradientCompression(KVStoreHandle handle,
1563  mx_uint num_params,
1564  const char** keys,
1565  const char** vals);
1566 
1572 MXNET_DLL int MXKVStoreFree(KVStoreHandle handle);
1581 MXNET_DLL int MXKVStoreInit(KVStoreHandle handle,
1582  mx_uint num,
1583  const int* keys,
1584  NDArrayHandle* vals);
1585 
1594 MXNET_DLL int MXKVStoreInitEx(KVStoreHandle handle,
1595  mx_uint num,
1596  const char** keys,
1597  NDArrayHandle* vals);
1598 
1608 MXNET_DLL int MXKVStorePush(KVStoreHandle handle,
1609  mx_uint num,
1610  const int* keys,
1611  NDArrayHandle* vals,
1612  int priority);
1622 MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle,
1623  mx_uint num,
1624  const char** keys,
1625  NDArrayHandle* vals,
1626  int priority);
1636 MXNET_DLL int MXKVStorePull(KVStoreHandle handle,
1637  mx_uint num,
1638  const int* keys,
1639  NDArrayHandle* vals,
1640  int priority);
1650 MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
1651  mx_uint num,
1652  const char** keys,
1653  NDArrayHandle* vals,
1654  int priority);
1655 
1668 MXNET_DLL int MXKVStorePullRowSparse(KVStoreHandle handle,
1669  mx_uint num,
1670  const int* keys,
1671  NDArrayHandle* vals,
1672  const NDArrayHandle* row_ids,
1673  int priority);
1686 MXNET_DLL int MXKVStorePullRowSparseEx(KVStoreHandle handle,
1687  mx_uint num,
1688  const char** keys,
1689  NDArrayHandle* vals,
1690  const NDArrayHandle* row_ids,
1691  int priority);
1692 
1701 typedef void (MXKVStoreUpdater)(int key,
1702  NDArrayHandle recv,
1703  NDArrayHandle local,
1704  void *handle);
1713 typedef void (MXKVStoreStrUpdater)(const char* key,
1714  NDArrayHandle recv,
1715  NDArrayHandle local,
1716  void *handle);
1724 MXNET_DLL int MXKVStoreSetUpdater(KVStoreHandle handle,
1725  MXKVStoreUpdater updater,
1726  void *updater_handle);
1735 MXNET_DLL int MXKVStoreSetUpdaterEx(KVStoreHandle handle,
1736  MXKVStoreUpdater updater,
1737  MXKVStoreStrUpdater str_updater,
1738  void *updater_handle);
1745 MXNET_DLL int MXKVStoreGetType(KVStoreHandle handle,
1746  const char** type);
1747 //--------------------------------------------
1748 // Part 6: advanced KVStore for multi-machines
1749 //--------------------------------------------
1750 
1758 MXNET_DLL int MXKVStoreGetRank(KVStoreHandle handle,
1759  int *ret);
1760 
1770 MXNET_DLL int MXKVStoreGetGroupSize(KVStoreHandle handle,
1771  int *ret);
1772 
1778 MXNET_DLL int MXKVStoreIsWorkerNode(int *ret);
1779 
1780 
1786 MXNET_DLL int MXKVStoreIsServerNode(int *ret);
1787 
1788 
1794 MXNET_DLL int MXKVStoreIsSchedulerNode(int *ret);
1795 
1802 MXNET_DLL int MXKVStoreBarrier(KVStoreHandle handle);
1803 
1811 MXNET_DLL int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle,
1812  const int barrier_before_exit);
1813 
1820 typedef void (MXKVStoreServerController)(int head,
1821  const char *body,
1822  void *controller_handle);
1823 
1832 MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle,
1833  MXKVStoreServerController controller,
1834  void *controller_handle);
1835 
1844 MXNET_DLL int MXKVStoreSendCommmandToServers(KVStoreHandle handle,
1845  int cmd_id,
1846  const char* cmd_body);
1847 
1858 MXNET_DLL int MXKVStoreGetNumDeadNode(KVStoreHandle handle,
1859  const int node_id,
1860  int *number,
1861  const int timeout_sec DEFAULT(60));
1862 
1869 MXNET_DLL int MXRecordIOWriterCreate(const char *uri, RecordIOHandle *out);
1870 
1876 MXNET_DLL int MXRecordIOWriterFree(RecordIOHandle handle);
1877 
1885 MXNET_DLL int MXRecordIOWriterWriteRecord(RecordIOHandle handle,
1886  const char *buf, size_t size);
1887 
1894 MXNET_DLL int MXRecordIOWriterTell(RecordIOHandle handle, size_t *pos);
1895 
1902 MXNET_DLL int MXRecordIOReaderCreate(const char *uri, RecordIOHandle *out);
1903 
1909 MXNET_DLL int MXRecordIOReaderFree(RecordIOHandle handle);
1910 
1918 MXNET_DLL int MXRecordIOReaderReadRecord(RecordIOHandle handle,
1919  char const **buf, size_t *size);
1920 
1927 MXNET_DLL int MXRecordIOReaderSeek(RecordIOHandle handle, size_t pos);
1928 
1935 MXNET_DLL int MXRecordIOReaderTell(RecordIOHandle handle, size_t *pos);
1936 
1940 MXNET_DLL int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
1941  char** input_names, char** output_names,
1942  NDArrayHandle* inputs, NDArrayHandle* outputs,
1943  char* kernel, RtcHandle *out);
1944 
1948 MXNET_DLL int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
1949  NDArrayHandle* inputs, NDArrayHandle* outputs,
1950  mx_uint gridDimX,
1951  mx_uint gridDimY,
1952  mx_uint gridDimZ,
1953  mx_uint blockDimX,
1954  mx_uint blockDimY,
1955  mx_uint blockDimZ);
1956 
1960 MXNET_DLL int MXRtcFree(RtcHandle handle);
1961 /*
1962  * \brief register custom operators from frontend.
1963  * \param op_type name of custom op
1964  * \param creator
1965  */
1966 MXNET_DLL int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator);
1967 /*
1968  * \brief record custom function for backward later.
1969  * \param num_inputs number of input NDArrays.
1970  * \param inputs handle to input NDArrays.
1971  * \param num_outputs number of output NDArrays.
1972  * \param outputs handle to output NDArrays.
1973  * \param callbacks callbacks for backward function.
1974  */
1975 MXNET_DLL int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs,
1976  int num_outputs, NDArrayHandle *outputs,
1977  struct MXCallbackList *callbacks);
1978 /*
1979  * \brief create cuda rtc module
1980  * \param source cuda source code
1981  * \param num_options number of compiler flags
1982  * \param options compiler flags
1983  * \param num_exports number of exported function names
1984  * \param exported function names
1985  * \param out handle to created module
1986  */
1987 MXNET_DLL int MXRtcCudaModuleCreate(const char* source, int num_options,
1988  const char** options, int num_exports,
1989  const char** exports, CudaModuleHandle *out);
1990 /*
1991  * \brief delete cuda rtc module
1992  * \param handle handle to cuda module
1993  */
1994 MXNET_DLL int MXRtcCudaModuleFree(CudaModuleHandle handle);
1995 /*
1996  * \brief get kernel from module
1997  * \param handle handle to cuda module
1998  * \param name name of kernel function
1999  * \param num_args number of arguments
2000  * \param is_ndarray whether argument is ndarray
2001  * \param is_const whether argument is constant
2002  * \param arg_types data type of arguments
2003  * \param out created kernel
2004  */
2005 MXNET_DLL int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name,
2006  int num_args, int* is_ndarray, int* is_const,
2007  int* arg_types, CudaKernelHandle *out);
2008 /*
2009  * \brief delete kernel
2010  * \param handle handle to previously created kernel
2011  */
2012 MXNET_DLL int MXRtcCudaKernelFree(CudaKernelHandle handle);
2013 /*
2014  * \brief launch cuda kernel
2015  * \param handle handle to kernel
2016  * \param dev_id (GPU) device id
2017  * \param args pointer to arguments
2018  * \param grid_dim_x grid dimension x
2019  * \param grid_dim_y grid dimension y
2020  * \param grid_dim_z grid dimension z
2021  * \param block_dim_x block dimension x
2022  * \param block_dim_y block dimension y
2023  * \param block_dim_z block dimension z
2024  * \param shared_mem size of dynamically allocated shared memory
2025  */
2026 MXNET_DLL int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args,
2027  mx_uint grid_dim_x, mx_uint grid_dim_y,
2028  mx_uint grid_dim_z, mx_uint block_dim_x,
2029  mx_uint block_dim_y, mx_uint block_dim_z,
2030  mx_uint shared_mem);
2037 MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid,
2038  int* shared_id);
2048 MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
2049  mx_uint ndim, int dtype, NDArrayHandle *out);
2050 
2051 
2052 #ifdef __cplusplus
2053 }
2054 #endif // __cplusplus
2055 
2056 #endif // MXNET_C_API_H_
MXNET_DLL int MXKVStoreSendCommmandToServers(KVStoreHandle handle, int cmd_id, const char *cmd_body)
bool(* list_arguments)(char ***, void *)
Definition: c_api.h:116
MXNET_DLL int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out)
detach an ndarray from the computation graph by clearing entry_
MXNET_DLL int MXListAllOpNames(mx_uint *out_size, const char ***out_array)
list all the available operator names, including entries.
int(* CustomOpFBFunc)(int, void **, int *, const int *, const int, void *)
Definition: c_api.h:154
int(** callbacks)(void)
Definition: c_api.h:132
MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape, mx_uint ndim, int dtype, NDArrayHandle *out)
Reconstruct NDArray from shared memory handle.
MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle, mx_uint i, int *out_type)
get the type of the ith aux data in NDArray
MXNET_DLL int MXKVStoreSetUpdater(KVStoreHandle handle, MXKVStoreUpdater updater, void *updater_handle)
register a push updater
MXNET_DLL int MXKVStoreBarrier(KVStoreHandle handle)
global barrier among all worker machines
void * DataIterHandle
handle to a DataIterator
Definition: c_api.h:81
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, AtomicSymbolCreator **out_array)
list all the available AtomicSymbolEntry
Definition: c_api.h:137
MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle, mx_uint num_param, const char **keys, const char **vals, DataIterHandle *out)
Initialize an iterator with the given parameters; num_param is the size of the keys and vals arrays passed in.
MXNET_DLL int MXDataIterNext(DataIterHandle handle, int *out)
Move iterator to next position.
int(* CustomOpInferTypeFunc)(int, int *, void *)
Definition: c_api.h:161
MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle, int ndim, int *dims, NDArrayHandle *out)
Reshape the NDArray.
void * p_infer_shape
Definition: c_api.h:122
int(* CustomFunctionDelFunc)(void *)
Definition: c_api.h:182
MXNET_DLL int MXAutogradMarkVariables(mx_uint num_var, NDArrayHandle *var_handles, mx_uint *reqs_array, NDArrayHandle *grad_handles)
mark NDArrays as variables to compute gradient for autograd
Definition: c_api.h:147
MXNET_DLL int MXExecutorPrint(ExecutorHandle handle, const char **out_str)
Print the content of execution plan, used for debug.
MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle, MXKVStoreServerController controller, void *controller_handle)
MXNET_DLL int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json)
Save a symbol into a json string.
MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, mx_uint num_args, const char **keys, const mx_uint *arg_ind_ptr, const mx_uint *arg_shape_data, mx_uint *in_shape_size, const mx_uint **in_shape_ndim, const mx_uint ***in_shape_data, mx_uint *out_shape_size, const mx_uint **out_shape_ndim, const mx_uint ***out_shape_data, mx_uint *aux_shape_size, const mx_uint **aux_shape_ndim, const mx_uint ***aux_shape_data, int *complete)
infer shape of unknown input shapes given the known one. The shapes are packed into a CSR matrix repr...
MXNET_DLL int MXSymbolGetAttr(SymbolHandle symbol, const char *key, const char **out, int *success)
Get string attribute from symbol.
MXNET_DLL int MXRtcCreate(char *name, mx_uint num_input, mx_uint num_output, char **input_names, char **output_names, NDArrayHandle *inputs, NDArrayHandle *outputs, char *kernel, RtcHandle *out)
Create a MXRtc object.
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, const char **name)
Get the name of an atomic symbol.
MXNET_DLL int MXKVStorePull(KVStoreHandle handle, mx_uint num, const int *keys, NDArrayHandle *vals, int priority)
pull a list of (key, value) pairs from the kvstore
MXNET_DLL int MXNDArrayWaitToRead(NDArrayHandle handle)
Wait until all the pending writes with respect to the NDArray are finished. Always call this before reading dat...
MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle, int dev_type, int dev_id, mx_uint len, NDArrayHandle *in_args, NDArrayHandle *arg_grad_store, mx_uint *grad_req_type, mx_uint aux_states_len, NDArrayHandle *aux_states, ExecutorHandle *out)
Generate Executor from symbol.
int(* CustomFunctionBwdFunc)(int, int, void **, const int *, const int, void *)
Definition: c_api.h:179
void * p_forward
Definition: c_api.h:120
MXNET_DLL int MXKVStorePullRowSparse(KVStoreHandle handle, mx_uint num, const int *keys, NDArrayHandle *vals, const NDArrayHandle *row_ids, int priority)
pull a list of (key, value) pairs from the kvstore, where each key is an integer. The NDArray pulled ...
MXNET_DLL int MXDataIterFree(DataIterHandle handle)
Free the handle to the IO module.
MXNET_DLL int MXKVStoreSetUpdaterEx(KVStoreHandle handle, MXKVStoreUpdater updater, MXKVStoreStrUpdater str_updater, void *updater_handle)
register a push updater with int keys and one with string keys
MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle, int num_inputs, NDArrayHandle *inputs, int *num_outputs, NDArrayHandle **outputs)
invoke cached operator
MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle)
free cached operator
void * RecordIOHandle
handle to RecordIO
Definition: c_api.h:85
CustomOpCallbacks
Definition: c_api.h:136
MXNET_DLL int MXSymbolGetChildren(SymbolHandle symbol, SymbolHandle *out)
Get a symbol that contains only direct children.
MXNET_DLL int MXRecordIOReaderFree(RecordIOHandle handle)
Delete a RecordIO reader object.
MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle, mx_uint *out_size, NDArrayHandle **out)
Get executor's head NDArray.
MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle, int *out_dev_type, int *out_dev_id)
get the context of the NDArray
MXNET_DLL int MXFuncDescribe(FunctionHandle fun, mx_uint *num_use_vars, mx_uint *num_scalars, mx_uint *num_mutate_vars, int *type_mask)
get the argument requirements of the function
void(* infer_shape)(int, int *, unsigned **, void *)
Definition: c_api.h:100
MXNET_DLL int MXKVStoreIsServerNode(int *ret)
return whether or not this process is a server node.
MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, mx_uint slice_begin, mx_uint slice_end, NDArrayHandle *out)
Slice the NDArray along axis 0.
bool(* forward)(int, void **, int *, void *)
Definition: c_api.h:112
void * ExecutorHandle
handle to an Executor
Definition: c_api.h:77
CustomFunctionCallbacks
Definition: c_api.h:174
MXNET_DLL int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname)
Save a symbol into a json file.
void * AtomicSymbolHandle
handle to an AtomicSymbol
Definition: c_api.h:75
void * p_backward
Definition: c_api.h:105
MXNET_DLL int MXFuncGetInfo(FunctionHandle fun, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions, const char **return_type DEFAULT(NULL))
Get the information of the function handle.
MXNET_DLL int MXAutogradSetIsTraining(int is_training, int *prev)
set whether training mode is on for autograd
MXNET_DLL int MXNDArrayFree(NDArrayHandle handle)
free the narray handle
void(* list_arguments)(char ***, void *)
Definition: c_api.h:102
MXNET_DLL int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out)
Load a symbol from a json string.
CustomOpPropCallbacks
Definition: c_api.h:142
bool(* infer_shape)(int, int *, unsigned **, void *)
Definition: c_api.h:114
MXNET_DLL int MXKVStoreIsSchedulerNode(int *ret)
return whether or not this process is a scheduler node.
int(* CustomOpDelFunc)(void *)
Definition: c_api.h:157
MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle, int num_inputs, NDArrayHandle *inputs, int *num_outputs, NDArrayHandle **outputs, const int **out_stypes)
invoke a cached op
MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions, const char **key_var_num_args, const char **return_type DEFAULT(NULL))
Get the detailed information about atomic symbol.
void(* backward)(int, float **, int *, unsigned **, int *, void *)
Definition: c_api.h:99
#define DEFAULT(x)
Expands to a C++ default argument (= x) when compiled as C++; expands to nothing in C.
Definition: c_api.h:37
MXNET_DLL int MXSetProfilerState(int state)
Set up state of profiler.
MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator, int num_inputs, NDArrayHandle *inputs, int *num_outputs, NDArrayHandle **outputs, int num_params, const char **param_keys, const char **param_vals)
invoke a nnvm op and imperative function
MXNET_DLL int MXSymbolListAttr(SymbolHandle symbol, mx_uint *out_size, const char ***out)
Get all attributes from symbol, including all descendants.
MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle, int num_params, const char **keys, const char **vals, CachedOpHandle *out)
create cached operator
MXNET_DLL int MXRtcCudaKernelFree(CudaKernelHandle handle)
MXNET_DLL int MXNDArrayAt(NDArrayHandle handle, mx_uint idx, NDArrayHandle *out)
Index the NDArray along axis 0.
MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle, mx_uint i, NDArrayHandle *out)
Get a deep copy of the ith aux data blob in the form of an NDArray of default storage type...
int(* CustomOpInferShapeFunc)(int, int *, unsigned **, void *)
Definition: c_api.h:159
MXNET_DLL int MXNDArrayLoad(const char *fname, mx_uint *out_size, NDArrayHandle **out_arr, mx_uint *out_name_size, const char ***out_names)
Load list of narray from the file.
MXNET_DLL int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char *name, int num_args, int *is_ndarray, int *is_const, int *arg_types, CudaKernelHandle *out)
MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, const void *data, size_t size)
Perform a synchronous copy from a contiguous CPU memory region.
MXNET_DLL int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out)
Load a symbol from a json file.
MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle, mx_uint num, const char **keys, NDArrayHandle *vals, int priority)
Push a list of (key,value) pairs to kvstore, where each key is a string.
Definition: c_api.h:111
MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type, const mx_uint *shape, mx_uint ndim, int dev_type, int dev_id, int delay_alloc, int dtype, mx_uint num_aux, int *aux_type, mx_uint *aux_ndims, const mx_uint *aux_shape, NDArrayHandle *out)
create an empty sparse NDArray with specified shape and data type
MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, int dev_type, int dev_id, const mx_uint num_g2c_keys, const char **g2c_keys, const int *g2c_dev_types, const int *g2c_dev_ids, const mx_uint provided_grad_req_list_len, const char **provided_grad_req_names, const char **provided_grad_req_types, const mx_uint num_provided_arg_shapes, const char **provided_arg_shape_names, const mx_uint *provided_arg_shape_data, const mx_uint *provided_arg_shape_idx, const mx_uint num_provided_arg_dtypes, const char **provided_arg_dtype_names, const int *provided_arg_dtypes, const mx_uint num_provided_arg_stypes, const char **provided_arg_stype_names, const int *provided_arg_stypes, const mx_uint num_shared_arg_names, const char **shared_arg_name_list, int *shared_buffer_len, const char **shared_buffer_name_list, NDArrayHandle *shared_buffer_handle_list, const char ***updated_shared_buffer_name_list, NDArrayHandle **updated_shared_buffer_handle_list, mx_uint *num_in_args, NDArrayHandle **in_args, NDArrayHandle **arg_grads, mx_uint *num_aux_states, NDArrayHandle **aux_states, ExecutorHandle shared_exec_handle, ExecutorHandle *out)
void * p_forward
Definition: c_api.h:104
MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape, mx_uint ndim, int dev_type, int dev_id, int delay_alloc, int dtype, NDArrayHandle *out)
create a NDArray with specified shape and data type
MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int *prev_bulk_size)
set bulk execution limit
void * p_list_arguments
Definition: c_api.h:108
MXNET_DLL int MXKVStoreCreate(const char *type, KVStoreHandle *out)
Create a kvstore.
int(* CustomOpBwdDepFunc)(const int *, const int *, const int *, int *, int **, void *)
Definition: c_api.h:162
void * p_declare_backward_dependency
Definition: c_api.h:125
MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out)
create an NDArray handle that is not initialized; it can be used to pass in as mutate variables to hold th...
MXNET_DLL int MXKVStoreFree(KVStoreHandle handle)
Delete a KVStore handle.
void * p_list_arguments
Definition: c_api.h:124
MXNET_DLL int MXCustomOpRegister(const char *op_type, CustomOpPropCreator creator)
MXNET_DLL int MXExecutorBackwardEx(ExecutorHandle handle, mx_uint len, NDArrayHandle *head_grads, int is_train)
Executor run backward.
void * SymbolHandle
handle to a symbol that can be bound as an operator
Definition: c_api.h:73
void * CudaModuleHandle
handle to rtc cuda module
Definition: c_api.h:89
MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out)
Copy the symbol to another handle.
void(* list_outputs)(char ***, void *)
Definition: c_api.h:101
MXNET_DLL int MXGetVersion(int *out)
get the MXNet library version as an integer
void * CachedOpHandle
handle to cached operator
Definition: c_api.h:71
MXNET_DLL int MXSymbolSetAttr(SymbolHandle symbol, const char *key, const char *value)
Set string attribute from symbol. NOTE: Setting attribute to a symbol can affect the semantics(mutabl...
MXNET_DLL int MXSetNumOMPThreads(int thread_num)
Set the number of OMP threads to use.
MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle)
Call iterator.Reset.
MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym, mx_uint num_args, const char **keys, const mx_uint *arg_ind_ptr, const mx_uint *arg_shape_data, mx_uint *in_shape_size, const mx_uint **in_shape_ndim, const mx_uint ***in_shape_data, mx_uint *out_shape_size, const mx_uint **out_shape_ndim, const mx_uint ***out_shape_data, mx_uint *aux_shape_size, const mx_uint **aux_shape_ndim, const mx_uint ***aux_shape_data, int *complete)
partially infer shape of unknown input shapes given the known one.
MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array)
List returns in the symbol.
void( MXKVStoreUpdater)(int key, NDArrayHandle recv, NDArrayHandle local, void *handle)
user-defined updater for the kvstore. It's this updater's responsibility to delete recv and local ...
Definition: c_api.h:1701
MXNET_DLL int MXListDataIters(mx_uint *out_size, DataIterCreator **out_array)
List all the available iterator entries.
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata)
get the shape of the array
MXNET_DLL int MXDataIterGetPadNum(DataIterHandle handle, int *pad)
Get the padding number in current data batch.
MXNET_DLL int MXAutogradComputeGradient(mx_uint num_output, NDArrayHandle *output_handles)
compute the gradient of outputs w.r.t variables
MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int *shared_pid, int *shared_id)
Get shared memory handle from NDArray.
MXNET_DLL int MXFuncInvokeEx(FunctionHandle fun, NDArrayHandle *use_vars, mx_float *scalar_args, NDArrayHandle *mutate_vars, int num_params, char **param_keys, char **param_vals)
invoke a function, the array size of passed in arguments must match the values in the ...
MXNET_DLL int MXRtcCudaModuleCreate(const char *source, int num_options, const char **options, int num_exports, const char **exports, CudaModuleHandle *out)
int num_callbacks
Definition: c_api.h:131
MXNET_DLL int MXRandomSeed(int seed)
Seed the global random number generators in mxnet.
MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol, mx_uint index, SymbolHandle *out)
Get index-th outputs of the symbol.
MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle, int *out_dtype)
get the type of the data in NDArray
MXNET_DLL int MXAutogradIsTraining(bool *curr)
get whether training mode is on
MXNET_DLL int MXAutogradBackwardEx(mx_uint num_output, NDArrayHandle *output_handles, NDArrayHandle *ograd_handles, mx_uint num_variables, NDArrayHandle *var_handles, int retain_graph, int create_graph, int is_train, NDArrayHandle **grad_handles, int **grad_stypes)
compute the gradient of outputs w.r.t variables
void( MXKVStoreStrUpdater)(const char *key, NDArrayHandle recv, NDArrayHandle local, void *handle)
user-defined updater for the kvstore with string keys. It's this updater's responsibility to delete re...
Definition: c_api.h:1713
Definition: c_api.h:175
MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol, SymbolHandle *out)
Get a symbol that contains all the internals.
MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle, NDArrayHandle *out)
Get a deep copy of the data blob in the form of an NDArray of default storage type. This function blocks. Do not use it in performance critical code.
int(* CustomOpPropCreator)(const char *, const int, const char **, const char **, struct MXCallbackList *)
Definition: c_api.h:169
MXNET_DLL int MXInitPSEnv(mx_uint num_vars, const char **keys, const char **vals)
Initialized ps-lite environment variables.
MXNET_DLL int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out)
Get the handle to the NDArray of underlying data.
MXNET_DLL int MXRtcFree(RtcHandle handle)
Delete a MXRtc object.
MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle, void *data, size_t size)
Perform a synchronous copy to a contiguous CPU memory region.
MXNET_DLL int MXAutogradSetIsRecording(int is_recording, int *prev)
set whether to record operator for autograd
Definition: c_api.h:138
void * KVStoreHandle
handle to KVStore
Definition: c_api.h:83
MXNET_DLL int MXSetProfilerConfig(int mode, const char *filename)
Set up configuration of profiler.
void * NDArrayHandle
handle to NDArray
Definition: c_api.h:65
MXNET_DLL int MXNotifyShutdown()
Notify the engine about a shutdown, This can help engine to print less messages into display...
bool(* list_outputs)(char ***, void *)
Definition: c_api.h:115
void * p_list_outputs
Definition: c_api.h:123
bool(* declare_backward_dependency)(const int *, const int *, const int *, int *, int **, void *)
Definition: c_api.h:117
MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, uint64_t **out_index, uint64_t *out_size)
Get the image index by array.
MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out)
Get the handle to the NDArray of underlying label.
Definition: c_api.h:144
void(* forward)(int, float **, int *, unsigned **, int *, void *)
Definition: c_api.h:98
MXNET_DLL int MXNDArraySyncCopyFromNDArray(NDArrayHandle handle_dst, const NDArrayHandle handle_src, const int i)
Copy src.data() to dst.data() if i = -1, else dst.aux_data(i) if i >= 0. This function blocks...
MXNET_DLL int MXSymbolCreateGroup(mx_uint num_symbols, SymbolHandle *symbols, SymbolHandle *out)
Create a Symbol by grouping list of symbols together.
MXNET_DLL int MXNDArrayCreate(const mx_uint *shape, mx_uint ndim, int dev_type, int dev_id, int delay_alloc, NDArrayHandle *out)
create a NDArray with specified shape
MXNET_DLL int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle, const int barrier_before_exit)
whether to do barrier when finalize
MXNET_DLL int MXFuncInvoke(FunctionHandle fun, NDArrayHandle *use_vars, mx_float *scalar_args, NDArrayHandle *mutate_vars)
invoke a function, the array size of passed in arguments must match the values in the ...
MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array)
List arguments in the symbol.
MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle, int *out_storage_type)
get the storage type of the array
MXNET_DLL int MXRecordIOWriterCreate(const char *uri, RecordIOHandle *out)
Create a RecordIO writer object.
void * p_infer_shape
Definition: c_api.h:106
MXNET_DLL int MXKVStorePullRowSparseEx(KVStoreHandle handle, mx_uint num, const char **keys, NDArrayHandle *vals, const NDArrayHandle *row_ids, int priority)
pull a list of (key, value) pairs from the kvstore, where each key is a string. The NDArray pulled ba...
MXNET_DLL int MXNDArrayWaitAll()
wait until all delayed operations in the system are completed
MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle, mx_uint num, const char **keys, NDArrayHandle *vals, int priority)
pull a list of (key, value) pairs from the kvstore, where each key is a string
MXNET_DLL int MXRecordIOReaderCreate(const char *uri, RecordIOHandle *out)
Create a RecordIO reader object.
MXNET_DLL int MXKVStoreGetNumDeadNode(KVStoreHandle handle, const int node_id, int *number, const int timeout_sec DEFAULT(60))
Get the number of ps dead node(s) specified by {node_id}.
int(* MXGenericCallback)(void)
Definition: c_api.h:128
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:58
Definition: c_api.h:139
MXNET_DLL int MXRecordIOWriterTell(RecordIOHandle handle, size_t *pos)
Get the current writer pointer position.
Definition: c_api.h:150
MXNET_DLL int MXRecordIOReaderTell(RecordIOHandle handle, size_t *pos)
Get the current reader pointer position.
MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle, size_t *out_size, const char **out_buf)
save the NDArray into raw bytes.
MXNET_DLL int MXSymbolGetName(SymbolHandle symbol, const char **out, int *success)
Get string name from symbol.
MXNET_DLL int MXKVStoreSetGradientCompression(KVStoreHandle handle, mx_uint num_params, const char **keys, const char **vals)
Set parameters to use low-bit compressed gradients.
void * p_backward
Definition: c_api.h:121
Definition: c_api.h:143
int(* CustomOpCreateFunc)(const char *, int, unsigned **, const int *, const int *, struct MXCallbackList *, void *)
Definition: c_api.h:165
MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle, int dev_type, int dev_id, mx_uint num_map_keys, const char **map_keys, const int *map_dev_types, const int *map_dev_ids, mx_uint len, NDArrayHandle *in_args, NDArrayHandle *arg_grad_store, mx_uint *grad_req_type, mx_uint aux_states_len, NDArrayHandle *aux_states, ExecutorHandle shared_exec, ExecutorHandle *out)
Generate Executor from symbol, This is advanced function, allow specify group2ctx map...
int(* CustomOpListFunc)(char ***, void *)
Definition: c_api.h:158
MXNET_DLL int MXAutogradIsRecording(bool *curr)
get whether autograd recording is on
Definition: c_api.h:176
MXNET_DLL int MXRecordIOWriterFree(RecordIOHandle handle)
Delete a RecordIO writer object.
MXNET_DLL int MXExecutorBackward(ExecutorHandle handle, mx_uint len, NDArrayHandle *head_grads)
Executor run backward.
MXNET_DLL int MXKVStoreGetRank(KVStoreHandle handle, int *ret)
return The rank of this node in its group, which is in [0, GroupSize).
MXNET_DLL int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs, int num_outputs, NDArrayHandle *outputs, struct MXCallbackList *callbacks)
MXNET_DLL int MXKVStoreGetGroupSize(KVStoreHandle handle, int *ret)
return The number of nodes in this group, which is
MXNET_DLL int MXRecordIOWriterWriteRecord(RecordIOHandle handle, const char *buf, size_t size)
Write a record to a RecordIO object.
MXNET_DLL int MXSymbolCreateVariable(const char *name, SymbolHandle *out)
Create a Variable Symbol.
MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, void *callback_handle)
set a call back to notify the completion of operation
MXNET_DLL int MXAutogradBackward(mx_uint num_output, NDArrayHandle *output_handles, NDArrayHandle *ograd_handles, int retain_graph)
compute the gradient of outputs w.r.t variables
void * CudaKernelHandle
handle to rtc cuda kernel
Definition: c_api.h:91
MXNET_DLL int MXKVStoreIsWorkerNode(int *ret)
return whether or not this process is a worker node.
MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str)
Print the content of symbol, used for debug.
MXNET_DLL int MXSymbolCompose(SymbolHandle sym, const char *name, mx_uint num_args, const char **keys, SymbolHandle *args)
Compose the symbol on other symbols.
MXNET_DLL int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char **wrt, SymbolHandle *out)
Get the gradient graph of the symbol.
float mx_float
manually define float
Definition: c_api.h:60
MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out)
create cached operator
MXNET_DLL int MXRecordIOReaderSeek(RecordIOHandle handle, size_t pos)
Set the current reader pointer position.
#define MXNET_DLL
MXNET_DLL prefix for windows.
Definition: c_api.h:54
MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, mx_uint num_param, const char **keys, const char **vals, SymbolHandle *out)
Create an AtomicSymbol.
MXNET_DLL int MXNDArraySetGradState(NDArrayHandle handle, int state)
set the flag for gradient array state.
MXNET_DLL int MXNDArraySyncCheckFormat(NDArrayHandle handle, const bool full_check)
check whether the NDArray format is valid
void(* ExecutorMonitorCallback)(const char *, NDArrayHandle, void *)
Definition: c_api.h:93
bool(* backward)(int, void **, int *, void *)
Definition: c_api.h:113
MXNET_DLL const char * MXGetLastError()
return str message of the last error. All functions in this file will return 0 on success and -1 when...
MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out)
get the flag for gradient array state.
void ** contexts
Definition: c_api.h:133
void * AtomicSymbolCreator
handle to a function that takes param and creates symbol
Definition: c_api.h:69
MXNET_DLL int MXKVStoreInitEx(KVStoreHandle handle, mx_uint num, const char **keys, NDArrayHandle *vals)
Init a list of (key,value) pairs in kvstore, where each key is a string.
MXNET_DLL int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out)
MXNET_DLL int MXExecutorBindX(SymbolHandle symbol_handle, int dev_type, int dev_id, mx_uint num_map_keys, const char **map_keys, const int *map_dev_types, const int *map_dev_ids, mx_uint len, NDArrayHandle *in_args, NDArrayHandle *arg_grad_store, mx_uint *grad_req_type, mx_uint aux_states_len, NDArrayHandle *aux_states, ExecutorHandle *out)
Generate Executor from symbol, This is advanced function, allow specify group2ctx map...
void * p_list_outputs
Definition: c_api.h:107
void( MXKVStoreServerController)(int head, const char *body, void *controller_handle)
the prototype of a server controller
Definition: c_api.h:1820
MXNET_DLL int MXNDArraySave(const char *fname, mx_uint num_args, NDArrayHandle *args, const char **keys)
Save list of narray into the file.
MXNET_DLL int MXNDArrayLoadFromRawBytes(const void *buf, size_t size, NDArrayHandle *out)
create a NDArray handle that is loaded from raw bytes.
MXNET_DLL int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output, NDArrayHandle *inputs, NDArrayHandle *outputs, mx_uint gridDimX, mx_uint gridDimY, mx_uint gridDimZ, mx_uint blockDimX, mx_uint blockDimY, mx_uint blockDimZ)
Run cuda kernel.
MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array)
List auxiliary states in the symbol.
MXNET_DLL int MXGetFunction(const char *name, FunctionHandle *out)
get the function handle by name
MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol, mx_uint *out_size, const char ***out)
Get all attributes from symbol, excluding descendants.
void * DataIterCreator
handle to a dataiter creator
Definition: c_api.h:79
MXNET_DLL int MXDataIterGetIterInfo(DataIterCreator creator, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions)
Get the detailed information about data iterator.
MXNET_DLL int MXKVStoreGetType(KVStoreHandle handle, const char **type)
get the type of the kvstore
void * RtcHandle
handle to MXRtc
Definition: c_api.h:87
MXNET_DLL int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out)
return gradient buffer attached to this NDArray
MXNET_DLL int MXSymbolInferType(SymbolHandle sym, mx_uint num_args, const char **keys, const int *arg_type_data, mx_uint *in_type_size, const int **in_type_data, mx_uint *out_type_size, const int **out_type_data, mx_uint *aux_type_size, const int **aux_type_data, int *complete)
infer type of unknown input types given the known one. The types are packed into a CSR matrix represe...
Definition: c_api.h:145
MXNET_DLL int MXKVStoreInit(KVStoreHandle handle, mx_uint num, const int *keys, NDArrayHandle *vals)
Init a list of (key,value) pairs in kvstore.
MXNET_DLL int MXDumpProfile()
Save profile and stop profiler.
Definition: c_api.h:130
MXNET_DLL int MXImperativeInvokeEx(AtomicSymbolCreator creator, int num_inputs, NDArrayHandle *inputs, int *num_outputs, NDArrayHandle **outputs, int num_params, const char **param_keys, const char **param_vals, const int **out_stypes)
invoke a nnvm op and imperative function
MXNET_DLL int MXRtcCudaModuleFree(CudaModuleHandle handle)
const void * FunctionHandle
handle to an mxnet narray function that changes NDArray
Definition: c_api.h:67
Definition: c_api.h:97
MXNET_DLL int MXRecordIOReaderReadRecord(RecordIOHandle handle, char const **buf, size_t *size)
Read a record from a RecordIO object.
MXNET_DLL int MXKVStorePush(KVStoreHandle handle, mx_uint num, const int *keys, NDArrayHandle *vals, int priority)
Push a list of (key,value) pairs to kvstore.
MXNET_DLL int MXListFunctions(mx_uint *out_size, FunctionHandle **out_array)
list all the available functions handles most user can use it to list all the needed functions ...
Definition: c_api.h:146
Definition: c_api.h:149
MXNET_DLL int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void **args, mx_uint grid_dim_x, mx_uint grid_dim_y, mx_uint grid_dim_z, mx_uint block_dim_x, mx_uint block_dim_y, mx_uint block_dim_z, mx_uint shared_mem)
MXNET_DLL int MXSymbolFree(SymbolHandle symbol)
Free the symbol handle.
MXNET_DLL int MXExecutorFree(ExecutorHandle handle)
Delete the executor.
MXNET_DLL int MXNDArrayWaitToWrite(NDArrayHandle handle)
Wait until all the pending read/write with respect to the NDArray are finished. Always call this before writ...
MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train)
Executor forward method.
MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle, void **out_pdata)
get the content of the data in NDArray