mxnet
lib_api.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
31 #ifndef MXNET_LIB_API_H_
32 #define MXNET_LIB_API_H_
33 
34 #include <stdint.h>
35 #include <stdlib.h>
36 #include <string.h>
37 #include <vector>
38 #include <map>
39 #include <unordered_set>
40 #include <unordered_map>
41 #include <string>
42 #include <iostream>
43 #include <utility>
44 #include <stdexcept>
45 #include <functional>
46 #include <random>
47 #include <sstream>
48 
49 #if defined(__NVCC__)
50 #include <cuda_runtime.h>
51 #include <curand_kernel.h>
52 #endif
53 
54 /* Make sure to update the version number every time you make changes */
55 #define MX_LIBRARY_VERSION 11
56 
62 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
63 #define PRIVATE_SYMBOL
64 #else
65 #define PRIVATE_SYMBOL __attribute__((visibility("hidden")))
66 #endif
67 
68 /*
69  * Import from DLPack https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
70  */
71 #ifndef DLPACK_VERSION
72 #ifdef __cplusplus
73 #define DLPACK_EXTERN_C extern "C"
74 #else
75 #define DLPACK_EXTERN_C
76 #endif
77 
79 #define DLPACK_VERSION 020
80 
82 #ifdef _WIN32
83 #ifdef DLPACK_EXPORTS
84 #define DLPACK_DLL __declspec(dllexport)
85 #else
86 #define DLPACK_DLL __declspec(dllimport)
87 #endif
88 #else
89 #define DLPACK_DLL
90 #endif
91 
92 #include <stdint.h>
93 #include <stddef.h>
94 
95 #ifdef __cplusplus
96 extern "C" {
97 #endif
98 
101 typedef enum {
103  kDLCPU = 1,
105  kDLGPU = 2,
116  kDLMetal = 8,
118  kDLVPI = 9,
120  kDLROCM = 10,
126  kDLExtDev = 12,
127 } DLDeviceType;
128 
/*! \brief Device context of a tensor (imported from DLPack): the kind of
 *         device and which instance of that kind holds the memory. */
132 typedef struct {
/*! \brief Kind of device, one of the DLDeviceType values declared above. */
134  DLDeviceType device_type;
/*! \brief Device instance id; per upstream DLPack, the ordinal among devices
 *         of the same type (TODO confirm against the DLPack spec). */
136  int device_id;
137 } DLContext;
138 
142 typedef enum {
143  kDLInt = 0U,
144  kDLUInt = 1U,
145  kDLFloat = 2U,
147 
/*! \brief Element type descriptor (imported from DLPack). */
156 typedef struct {
/*! \brief Type category; expected to hold a DLDataTypeCode value
 *         (kDLInt / kDLUInt / kDLFloat). */
162  uint8_t code;
/*! \brief Bit width of one lane, presumably e.g. 32 for float32
 *         (per upstream DLPack — confirm). */
166  uint8_t bits;
/*! \brief Number of lanes; per upstream DLPack, 1 for scalar element types
 *         and >1 for vector types (TODO confirm). */
168  uint16_t lanes;
169 } DLDataType;
170 
/*! \brief Plain tensor descriptor (imported from DLPack); MXTensor can expose
 *         itself in this form so functions taking DLTensor can be reused. */
174 typedef struct {
/*! \brief Opaque pointer to the tensor's data buffer. */
194  void* data;
/*! \brief Device on which `data` resides. */
196  DLContext ctx;
/*! \brief Number of dimensions. */
198  int ndim;
/*! \brief Element type of the tensor. */
200  DLDataType dtype;
/*! \brief Extent of each dimension; presumably an array of `ndim` entries. */
202  int64_t* shape;
/*! \brief Per-dimension strides in elements; per upstream DLPack this may be
 *         NULL to mean a compact row-major layout (TODO confirm). */
207  int64_t* strides;
/*! \brief Offset in bytes from `data` to the first element (per DLPack). */
209  uint64_t byte_offset;
210 } DLTensor;
211 #ifdef __cplusplus
212 } // DLPACK_EXTERN_C
213 #endif
214 #endif
215 
216 namespace mxnet {
217 namespace ext {
218 
/*!
 * \brief Process-wide collector for error messages produced by an extension
 *        library, so they can later be handed back to MXNet.
 *
 * Instances cannot be created or destroyed from outside the class; the single
 * shared object is reached through get().
 */
class MXerrorMsgs {
 public:
  /*! \brief Return a pointer to the shared instance. */
  static MXerrorMsgs* get();

  /*!
   * \brief Begin recording a new message tagged with a source location.
   * \param file name of the source file reporting the error
   * \param line line number in that file
   * \return stream that the message text is written into
   */
  std::stringstream& add(const char* file, int line);

  /*! \brief Number of messages recorded so far. */
  int size();

  /*!
   * \brief Look up a recorded message.
   * \param idx position of the message to fetch
   */
  const std::string* get(int idx);

 private:
  // Hidden so that only the class itself can construct/destroy the instance.
  MXerrorMsgs() = default;
  ~MXerrorMsgs() = default;

  // Backing storage: one stream per recorded message.
  std::vector<std::stringstream> messages;
};
242 
243 // Add a new error message, example: MX_ERROR_MSG << "my error msg";
244 #define MX_ERROR_MSG mxnet::ext::MXerrorMsgs::get()->add(__FILE__, __LINE__)
245 
/*!
 * \brief Element data types an MXTensor can hold.
 * NOTE(review): these integers cross the library boundary (e.g. the `int dtype`
 * argument of nd_malloc_t and the int type arrays of the C prototypes), so the
 * numeric values must not be changed; presumably they mirror MXNet's internal
 * dtype flags — confirm before adding entries.
 */
249 enum MXDType {
250  kFloat32 = 0,
251  kFloat64 = 1,
252  kFloat16 = 2,
253  kUint8 = 3,
254  kInt32 = 4,
255  kInt8 = 5,
256  kInt64 = 6,
// Deliberately far from the real types; appears to mark "dtype not set yet".
257  kUNSET = 100,
258 };
259 
260 /*
261  * MXTensor storage type.
262  */
264  // dense
266  // row sparse
268  // csr
270 };
271 
/*!
 * \brief Identifies the device a tensor lives on: a device-type name plus a
 *        device index. All member functions are defined outside this header.
 */
struct MXContext {
  // --- construction ---
  MXContext();
  explicit MXContext(const char* dev_type_, int dev_id_);
  explicit MXContext(std::string dev_type_, int dev_id_);

  // --- named factories (the no-argument forms choose their own device id) ---
  static MXContext CPU();
  static MXContext CPU(int dev_id);
  static MXContext GPU();
  static MXContext GPU(int dev_id);

  // Device type name; presumably "cpu"/"gpu" — TODO confirm exact strings.
  std::string dev_type;
  // Index of the device.
  int dev_id;
};
289 
291  MX_FAIL = 0,
293 };
294 
// For sparse tensors, read/write the data from NDarray via pointers.
/*!
 * \brief Pointer-based view over the buffers of a sparse tensor
 *        (CSR or row-sparse storage).
 *
 * Every member carries a default initializer, so a default-constructed
 * MXSparse holds null pointers and zero lengths instead of indeterminate
 * values (previously only `data` and `indptr` were initialized).
 */
struct MXSparse {
  // Pointer to data.
  void* data{nullptr};
  // Length of (non-zero) data.
  int64_t data_len{0};

  // Aux data for sparse storage.
  // For CSR, indices stores the column index of each non-zero element.
  // For row sparse, indices stores the row index of each row that has
  // non-zero elements.
  int64_t* indices{nullptr};
  int64_t indices_len{0};

  // For CSR, indptr gives the start and end index of data for each row.
  // For row sparse, indptr is not used.
  int64_t* indptr{nullptr};
  int64_t indptr_len{0};

  /*!
   * \brief Point this view at externally owned sparse buffers.
   * The definition lives outside this header; the parameter roles below are
   * inferred from the names — TODO confirm against the implementation.
   * \param data_ptr    buffer of (non-zero) values
   * \param dims        shape of the data buffer, presumably used to derive data_len
   * \param ndims       number of entries in dims
   * \param idx         indices buffer
   * \param num_idx     number of entries in idx
   * \param idx_ptr     indptr buffer (CSR only; nullptr for row sparse)
   * \param num_idx_ptr number of entries in idx_ptr (0 when unused)
   */
  void set(void* data_ptr,
           const int64_t* dims,
           int ndims,
           void* idx,
           int64_t num_idx,
           void* idx_ptr = nullptr,
           int64_t num_idx_ptr = 0);
};
321 
325 struct MXTensor {
326  MXTensor();
327  MXTensor(const MXTensor& oth);
328  MXTensor(void* data_ptr,
329  std::vector<int64_t> shape,
330  MXDType dtype,
331  size_t vID,
332  MXContext mx_ctx,
334 
336  void setTensor(void* dptr,
337  MXDType type,
338  const int64_t* dims,
339  int ndims,
340  size_t vID,
341  MXContext mx_ctx,
342  MXStorageType storage_type);
343 
345  void setDLTensor();
346 
348  template <typename data_type>
349  inline data_type* data() {
350  return reinterpret_cast<data_type*>(data_ptr);
351  }
352 
354  int64_t size() const;
355 
357  bool isSame(const MXTensor& oth) const;
358 
359  // For dense, data_ptr points to 1D flattened tensor data
360  // For sparse, data_ptr points to MXSparse
361  void* data_ptr;
362 
363  // shape is in [2,3,4] format to represent high-dim tensor
364  std::vector<int64_t> shape;
365 
366  // type can only be MXDType enum types
368 
369  // version number updated if the tensor has changed since the last use by custom op
370  size_t verID;
371 
372  // context of MXTensor representing which device the tensor data is located
374 
375  // corresponding DLTensor repr of MXTensor
376  // easy way to reuse functions taking DLTensor
378 
379  // storage type
381 };
382 
384 typedef void* (*xpu_malloc_t)(void*, int);
386 typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t**);
388 typedef void (*nd_malloc_t)(const void* _ndarray_alloc,
389  const int64_t* shapes,
390  int num_shapes,
391  const char* dev_str,
392  int dev_id,
393  int dtype,
394  const char* name,
395  int isArg,
396  void** data);
398 #if defined(__NVCC__)
399 typedef cudaStream_t mx_stream_t;
400 typedef curandStatePhilox4_32_10_t mx_gpu_rand_t;
401 #else
402 typedef void* mx_stream_t;
403 typedef void* mx_gpu_rand_t;
404 #endif
405 typedef std::mt19937 mx_cpu_rand_t;
406 
408 /* Each thread should use its own state so that it generates a unique random-number sequence */
409 #define MX_NUM_CPU_RANDOM_STATES 1024
410 #define MX_NUM_GPU_RANDOM_STATES 32768
411 
412 /* \brief Class to help allocate new args/aux params in graph passes */
414  public:
415  PassResource(std::unordered_map<std::string, MXTensor>* new_args,
416  std::unordered_map<std::string, MXTensor>* new_aux,
417  nd_malloc_t nd_malloc,
418  const void* nd_alloc);
419 
420  // allocate new arg param, adds to args map, returns newly allocated tensor
421  MXTensor* alloc_arg(const std::string& name,
422  const std::vector<int64_t>& shapes,
423  const MXContext& ctx,
424  MXDType dtype) const;
425 
426  // allocate new aux param, adds to aux map, returns newly allocated tensor
427  MXTensor* alloc_aux(const std::string& name,
428  const std::vector<int64_t>& shapes,
429  const MXContext& ctx,
430  MXDType dtype) const;
431 
432  private:
433  std::unordered_map<std::string, MXTensor>* new_args_;
434  std::unordered_map<std::string, MXTensor>* new_aux_;
435  nd_malloc_t nd_malloc_;
436  const void* nd_alloc_;
437 };
438 
442 class OpResource {
443  public:
444  OpResource(xpu_malloc_t cpu_malloc_fp,
445  void* cpu_alloc_fp,
446  xpu_malloc_t gpu_malloc_fp,
447  void* gpu_alloc_fp,
448  void* stream,
449  sparse_malloc_t sparse_malloc_fp,
450  void* sparse_alloc_fp,
451  void* rng_cpu_states,
452  void* rng_gpu_states);
453 
455  void* alloc_cpu(int size) const;
456 
458  void* alloc_gpu(int size) const;
459 
461  inline mx_stream_t get_cuda_stream() const {
462  return static_cast<mx_stream_t>(cuda_stream);
463  }
464 
466  void alloc_sparse(MXSparse* sparse, int index, int indices_len, int indptr_len = 0) const;
467 
469  /* Access each state by states[id], where id must be < MX_NUM_CPU_RANDOM_STATES */
471 
473  /* Access each state by states[id], where id must be < MX_NUM_GPU_RANDOM_STATES */
474  /* Note that if you are using cpu build, it will return a nullptr */
476  return static_cast<mx_gpu_rand_t*>(rand_gpu_states);
477  }
478 
479  private:
481  xpu_malloc_t cpu_malloc, gpu_malloc;
483  void *cpu_alloc, *gpu_alloc;
485  void* cuda_stream;
487  sparse_malloc_t sparse_malloc;
489  void* sparse_alloc;
491  void *rand_cpu_states, *rand_gpu_states;
492 };
493 
495 #define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json"
496 
497 #define MX_STR_DTYPE "__ext_dtype__"
498 
499 #define MX_STR_SHAPE "__ext_shape__"
500 
501 #define MX_STR_EXTRA_INPUTS "__ext_extra_inputs__"
502 
503 /* \brief get shape value from list of shapes string
504  *
505  * Examples:
506  *
507  * getShapeAt("[[1]]", 0) returns "[1]"
508  * getShapeAt("[[1],[2,3]]", 1) returns "[2,3]"
509  */
510 std::string getShapeAt(const std::string& shape, unsigned index);
511 
512 /* \brief get dtype value from list of dtypes string
513  *
514  * Examples:
515  *
516  * getDtypeAt("[1]", 0) returns "1"
517  * getDtypeAt("[1,2]", 1) returns "2"
518  */
519 std::string getDtypeAt(const std::string& dtype, unsigned index);
520 
525 enum JsonType { ERR, STR, NUM, LIST, MAP };
526 
528 struct JsonVal {
529  JsonVal(); // default constructor
530  // construct a JSON object by type
531  explicit JsonVal(JsonType t);
532  // construct a string JSON object
533  explicit JsonVal(std::string s);
534  // construct a number JSON object
535  explicit JsonVal(int n);
536  // complex constructor
537  JsonVal(JsonType t, int n, std::string s);
538  bool operator<(const JsonVal& o) const;
539 
540  // convert JSON object back to JSON-compatible string
541  std::string dump() const;
542 
543  // convert JSON-compatible string to JSON object
544  static JsonVal parse(const std::string& json);
545 
546  // parse a string JSON object
547  static JsonVal parse_string(const std::string& json, unsigned int* idx);
548 
549  // parse a number JSON object
550  static JsonVal parse_num(const std::string& json, unsigned int* idx);
551 
552  // parse a list of JSON objects
553  static JsonVal parse_list(const std::string& json, unsigned int* idx);
554 
555  // parse a map of JSON objects
556  static JsonVal parse_map(const std::string& json, unsigned int* idx);
557 
558  // generic parse function
559  static JsonVal parse(const std::string& json, unsigned int* idx);
560 
561  // debug function to convert data structure to a debugstring
562  std::string toString() const;
563 
565  int num;
566  std::string str;
567  std::vector<JsonVal> list;
568  std::map<JsonVal, JsonVal> map;
569 };
570 
574 class Node;
575 class Graph;
576 
577 // Representation of an input/output to a node
578 struct NodeEntry {
579  Node* node; // other node thats producing/consuming inputs/outputs
580  int entry; // entry index from other node (ie. output index from producing node)
581 };
582 
583 // Representation of a node in the graph
584 class Node {
585  public:
586  Node();
587 
588  // internally set passResource to enable tensor allocation for graph passes
589  void _setPassResource(PassResource* res_);
590 
591  /* \brief allocate an arg tensor for this node */
592  void alloc_arg(const std::vector<int64_t>& shapes, const MXContext& ctx, MXDType dtype);
593 
594  /* \brief allocate an aux tensor for this node */
595  void alloc_aux(const std::vector<int64_t>& shapes, const MXContext& ctx, MXDType dtype);
596 
597  std::string op; // operator name (ie. Convolution)
598  std::string name; // unique node name (ie. conv_0 or conv_1)
599  MXTensor* tensor; // tensor data for input nodes
600  std::vector<NodeEntry> inputs; // set of inputs to the node
601  std::vector<NodeEntry> outputs; // set of outputs from the node
602  std::vector<Graph*> subgraphs; // set of subgraphs within this node
603  std::unordered_map<std::string, std::string> attrs; // node attributes
604 
605  private:
606  PassResource* res;
607 };
608 
609 // Representation of the graph
610 class Graph {
611  public:
612  Graph();
613 
614  /* \brief deleted nodes when deleting the graph */
615  ~Graph();
616 
617  /* \brief create a graph object from an unparsed string */
618  static Graph* fromString(const std::string& json);
619 
620  /* \brief create a graph object from a parsed JSON object */
621  static Graph* fromJson(JsonVal val);
622 
623  /* \brief convert graph object back to JSON object */
624  JsonVal toJson() const;
625 
626  /* \brief convert graph object to JSON string */
627  std::string toString() const;
628 
629  /* \brief visits a node "n" */
630  void _dfs_util(Node* n,
631  std::unordered_set<Node*>* to_visit,
632  std::function<void(Node*)> handler) const;
633 
634  /* \brief post-order DFS graph traversal */
635  void DFS(std::function<void(Node*)> handler) const;
636 
637  /* \brief sort graph nodes in topological order */
638  std::vector<Node*> topological_sort() const;
639 
640  /* \brief print out graph details */
641  void print(int indent = 0) const;
642 
643  /* \brief add a new node to this graph */
644  Node* addNode(const std::string& name, const std::string& op);
645 
646  /* \brief get node at index in graph */
647  Node* getNode(size_t idx);
648 
649  /* \brief get const node at index in const graph */
650  const Node* getNode(size_t idx) const;
651 
652  /* \brief get attribute on graph */
653  const JsonVal& getAttr(const std::string& key) const;
654 
655  /* \brief get number of nodes in the graph */
656  size_t size() const;
657 
658  // internally set passResource to enable tensor allocation for graph passes
659  void _setPassResource(PassResource* res_);
660 
661  // internally set arg/aux params when available
662  void _setParams(std::unordered_map<std::string, mxnet::ext::MXTensor>* args,
663  std::unordered_map<std::string, mxnet::ext::MXTensor>* aux);
664 
665  std::vector<Node*> inputs;
666  std::vector<NodeEntry> outputs;
667  std::map<std::string, JsonVal> attrs;
668 
669  private:
670  std::vector<Node*> nodes;
671  PassResource* res;
672 };
673 
674 /* \brief An abstract class for library authors creating custom
675  * partitioners. Optional, can just implement supportedOps instead
676  */
678  public:
679  /* \brief Select a node to include in subgraph, return true to include node
680  * nodeID - index of node in graph
681  */
682  virtual bool Select(int nodeID) = 0;
683  /* \brief Select an input node from current node to include in subgraph
684  * return true to include node
685  * nodeID - index of node in graph
686  * input_nodeID - index of input node in graph
687  */
688  virtual bool SelectInput(int nodeID, int input_nodeID) = 0;
689  /* \brief Select an output node from current node to include in subgraph
690  * return true to include node
691  * nodeID - index of node in graph
692  * output_nodeID - index of output node in graph
693  */
694  virtual bool SelectOutput(int nodeID, int output_nodeID) = 0;
695  /* \brief Review nodes to include in subgraph
696  * return set of candidate nodes to keep in subgraph
697  * candidates - indices of nodes to include in subgraph
698  * keep - indices of nodes to keep in subgraph
699  */
700  virtual void Filter(const std::vector<int>& candidates, std::vector<int>* keep) {
701  keep->insert(keep->end(), candidates.begin(), candidates.end());
702  }
703  /* \brief Reset any selector state, called after growing subgraph, before filter
704  * Called after finished calling SelectInput/SelectOutput and growing subgraph
705  */
706  virtual void Reset() {}
707 };
708 
715  public:
717  virtual ~CustomStatefulOp();
718 
719  template <class A, typename... Ts>
720  static CustomStatefulOp* create(Ts... args) {
721  CustomStatefulOp* op = new A(args...);
722  op->created = true;
723  return op;
724  }
725 
726  bool wasCreated() {
727  return created;
728  }
729 
730  virtual MXReturnValue Forward(std::vector<MXTensor>* inputs,
731  std::vector<MXTensor>* outputs,
732  const OpResource& op_res) = 0;
733  virtual MXReturnValue Backward(std::vector<MXTensor>* inputs,
734  std::vector<MXTensor>* outputs,
735  const OpResource& op_res) {
736  MX_ERROR_MSG << "Error! Operator does not support backward" << std::endl;
737  return MX_FAIL;
738  }
739 
741 
742  private:
743  bool created;
744 };
745 
747 typedef MXReturnValue (*fcomp_t)(const std::unordered_map<std::string, std::string>& attributes,
748  std::vector<MXTensor>* inputs,
749  std::vector<MXTensor>* outputs,
750  const OpResource& res);
752  const std::unordered_map<std::string, std::string>& attributes,
753  int* num_inputs,
754  int* num_outputs);
755 typedef MXReturnValue (*inferType_t)(const std::unordered_map<std::string, std::string>& attributes,
756  std::vector<int>* in_types,
757  std::vector<int>* out_types);
759  const std::unordered_map<std::string, std::string>& attributes,
760  std::vector<int>* in_storage_types,
761  std::vector<int>* out_storage_types);
763  const std::unordered_map<std::string, std::string>& attributes,
764  std::vector<std::vector<unsigned int> >* in_shapes,
765  std::vector<std::vector<unsigned int> >* out_shapes);
767  const std::unordered_map<std::string, std::string>& attributes,
768  std::vector<int>* input_indices);
770  const std::unordered_map<std::string, std::string>& attributes,
771  const MXContext& ctx,
772  const std::vector<std::vector<unsigned int> >& in_shapes,
773  const std::vector<int> in_types,
774  CustomStatefulOp**);
775 
779 class CustomOp {
780  public:
781  explicit CustomOp(const char* op_name);
782 
783  CustomOp& setForward(fcomp_t fcomp, const char* ctx);
784 
785  CustomOp& setBackward(fcomp_t fgrad, const char* ctx);
786 
788 
790 
792 
794 
796 
797  CustomOp& setCreateOpState(createOpState_t func, const char* ctx);
798 
800 
801  void mapToVector();
802 
804  const char* name;
805 
812  bool isSGop;
813 
816  std::vector<fcomp_t> forward_fp, backward_fp;
817  std::vector<createOpState_t> create_op_fp;
818 
819  private:
820  void raiseDuplicateContextError();
821 
823  std::unordered_map<const char*, fcomp_t> forward_ctx_map, backward_ctx_map;
824  std::unordered_map<const char*, createOpState_t> create_op_ctx_map;
825 };
826 
829  const std::unordered_map<std::string, std::string>& options);
830 
834 class CustomPass {
835  public:
836  CustomPass();
837  explicit CustomPass(const char* pass_name);
838 
840 
842  const char* name;
845 };
846 
849  const mxnet::ext::Graph* graph,
850  std::vector<int>* ids,
851  const std::unordered_map<std::string, std::string>& options);
853  const mxnet::ext::Graph* graph,
854  CustomOpSelector** sel_inst,
855  const std::unordered_map<std::string, std::string>& options);
857  const mxnet::ext::Graph* subgraph,
858  int subgraph_id,
859  bool* accept,
860  const std::unordered_map<std::string, std::string>& options,
861  std::unordered_map<std::string, std::string>* attrs);
862 
867  public:
869 
870  explicit CustomPartitioner(const char* backend_name);
871 
872  CustomPartitioner& addStrategy(const char* prop_name, const char* sg_name);
873 
874  CustomPartitioner& setSupportedOps(const char* prop_name, supportedOps_t fn);
875 
876  CustomPartitioner& setCreateSelector(const char* prop_name, createSelector_t fn);
877 
878  CustomPartitioner& setReviewSubgraph(const char* prop_name, reviewSubgraph_t fn);
879 
880  supportedOps_t getSupportedOps(int stg_id);
881 
883 
885 
887  const char* name;
888  std::map<std::string, supportedOps_t> supported_map;
889  std::map<std::string, createSelector_t> selector_map;
890  std::map<std::string, reviewSubgraph_t> review_map;
892  std::vector<const char*> strategies;
894  std::vector<const char*> op_names;
895 };
896 
901 template <class T>
902 class Registry {
903  public:
909  static Registry inst;
910  return &inst;
911  }
916  T& add(const char* name) {
917  T* entry = new T(name);
918  entries.push_back(entry);
919  return *entry;
920  }
921  int size() {
922  return entries.size();
923  }
924  T& get(int idx) {
925  return *(entries.at(idx));
926  }
927 
928  private:
930  Registry() {}
932  ~Registry() {}
934  std::vector<T*> entries;
935 };
936 
942 #define MX_STR_CONCAT_(__a, __b) __a##__b
943 #define MX_STR_CONCAT(__a, __b) MX_STR_CONCAT_(__a, __b)
944 
946 #define MX_STRINGIFY(x) #x
947 #define MX_TOSTRING(x) MX_STRINGIFY(x)
948 
950 #define MX_REGISTER_NAME_(Name) MXNet##_CustomOp##_##Name
951 #define MX_REGISTER_DEF_(Name) mxnet::ext::CustomOp MX_REGISTER_NAME_(Name)
952 
953 #define MX_REGISTER_PROP_NAME_(Name) MXNet##_CustomSubProp##_##Name
954 #define MX_REGISTER_PROP_DEF_(Name) mxnet::ext::CustomPartitioner MX_REGISTER_PROP_NAME_(Name)
955 
956 #define MX_REGISTER_PASS_NAME_(Name) MXNet##_CustomPass##_##Name
957 #define MX_REGISTER_PASS_DEF_(Name) mxnet::ext::CustomPass MX_REGISTER_PASS_NAME_(Name)
958 
960 #define REGISTER_OP(Name) \
961  MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \
962  mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->add(MX_TOSTRING(Name))
963 
964 #define REGISTER_PARTITIONER(Name) \
965  MX_STR_CONCAT(MX_REGISTER_PROP_DEF_(Name), __COUNTER__) = \
966  mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->add(MX_TOSTRING(Name))
967 
968 #define REGISTER_PASS(Name) \
969  MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \
970  mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->add(MX_TOSTRING(Name))
971 
972 /* -------------- BELOW ARE CTYPE FUNCTIONS PROTOTYPES --------------- */
973 
979 #define MXLIB_OPREGSIZE_STR "_opRegSize"
980 typedef int (*opRegSize_t)(void);
981 
982 #define MXLIB_OPREGGET_STR "_opRegGet"
983 typedef int (*opRegGet_t)(int idx,
984  const char** name,
985  int* isSGop,
986  const char*** forward_ctx,
987  mxnet::ext::fcomp_t** forward_fp,
988  int* forward_count,
989  const char*** backward_ctx,
990  mxnet::ext::fcomp_t** backward_fp,
991  int* backward_count,
992  const char*** create_op_ctx,
993  mxnet::ext::createOpState_t** create_op_fp,
994  int* create_op_count,
1000 
1001 #define MXLIB_OPCALLFREE_STR "_opCallFree"
1002 typedef int (*opCallFree_t)(void* ptr);
1003 
1004 #define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs"
1005 typedef int (*opCallParseAttrs_t)(parseAttrs_t parseAttrs,
1006  const char* const* keys,
1007  const char* const* vals,
1008  int num,
1009  int* num_in,
1010  int* num_out);
1011 
1012 #define MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape"
1013 typedef int (*opCallInferShape_t)(inferShape_t inferShape,
1014  const char* const* keys,
1015  const char* const* vals,
1016  int num,
1017  unsigned int** inshapes,
1018  int* indims,
1019  int num_in,
1020  unsigned int*** mod_inshapes,
1021  int** mod_indims,
1022  unsigned int*** outshapes,
1023  int** outdims,
1024  int num_out);
1025 
1026 #define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType"
1027 typedef int (*opCallInferType_t)(inferType_t inferType,
1028  const char* const* keys,
1029  const char* const* vals,
1030  int num,
1031  int* intypes,
1032  int num_in,
1033  int* outtypes,
1034  int num_out);
1035 
1036 #define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType"
1037 typedef int (*opCallInferSType_t)(inferSType_t inferSType,
1038  const char* const* keys,
1039  const char* const* vals,
1040  int num,
1041  int* intypes,
1042  int num_in,
1043  int* outtypes,
1044  int num_out);
1045 
1046 #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
1047 typedef int (*opCallFComp_t)(fcomp_t fcomp,
1048  const char* const* keys,
1049  const char* const* vals,
1050  int num,
1051  const int64_t** inshapes,
1052  int* indims,
1053  void** indata,
1054  int* intypes,
1055  size_t* inIDs,
1056  const char** indev_type,
1057  int* indev_id,
1058  int num_in,
1059  const int64_t** outshapes,
1060  int* outdims,
1061  void** outdata,
1062  int* outtypes,
1063  size_t* outIDs,
1064  const char** outdev_type,
1065  int* outdev_id,
1066  int num_out,
1067  xpu_malloc_t cpu_malloc,
1068  void* cpu_alloc,
1069  xpu_malloc_t gpu_malloc,
1070  void* gpu_alloc,
1071  void* cuda_stream,
1072  sparse_malloc_t sparse_malloc,
1073  void* sparse_alloc,
1074  int* instypes,
1075  int* outstypes,
1076  void** in_indices,
1077  void** out_indices,
1078  void** in_indptr,
1079  void** out_indptr,
1080  int64_t* in_indices_shapes,
1081  int64_t* out_indices_shapes,
1082  int64_t* in_indptr_shapes,
1083  int64_t* out_indptr_shapes,
1084  void* rng_cpu_states,
1085  void* rng_gpu_states);
1086 
1087 #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
1088 typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate,
1089  const char* const* keys,
1090  const char* const* vals,
1091  int num,
1092  int** mutate_indices,
1093  int* indices_size);
1094 
1095 #define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState"
1096 typedef int (*opCallCreateOpState_t)(createOpState_t create_op,
1097  const char* const* keys,
1098  const char* const* vals,
1099  int num,
1100  const char* dev_type,
1101  int dev_id,
1102  unsigned int** inshapes,
1103  int* indims,
1104  int num_in,
1105  const int* intypes,
1106  void** state_op);
1107 
1108 #define MXLIB_OPCALLDESTROYOPSTATE_STR "_opCallDestroyOpState"
1109 typedef int (*opCallDestroyOpState_t)(void* state_op);
1110 
1111 #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
1112 typedef int (*opCallFStatefulComp_t)(int is_forward,
1113  void* state_op,
1114  const int64_t** inshapes,
1115  int* indims,
1116  void** indata,
1117  int* intypes,
1118  size_t* inIDs,
1119  const char** indev_type,
1120  int* indev_id,
1121  int num_in,
1122  const int64_t** outshapes,
1123  int* outdims,
1124  void** outdata,
1125  int* outtypes,
1126  size_t* outIDs,
1127  const char** outdev_type,
1128  int* outdev_id,
1129  int num_out,
1130  xpu_malloc_t cpu_malloc,
1131  void* cpu_alloc,
1132  xpu_malloc_t gpu_malloc,
1133  void* gpu_alloc,
1134  void* stream,
1135  sparse_malloc_t sparse_malloc,
1136  void* sparse_alloc,
1137  int* instypes,
1138  int* outstypes,
1139  void** in_indices,
1140  void** out_indices,
1141  void** in_indptr,
1142  void** out_indptr,
1143  int64_t* in_indices_shapes,
1144  int64_t* out_indices_shapes,
1145  int64_t* in_indptr_shapes,
1146  int64_t* out_indptr_shapes,
1147  void* rng_cpu_states,
1148  void* rng_gpu_states);
1149 
1150 #define MXLIB_PARTREGSIZE_STR "_partRegSize"
1151 typedef int (*partRegSize_t)(void);
1152 
1153 #define MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount"
1154 typedef int (*partRegGetCount_t)(int idx, const char** name);
1155 
1156 #define MXLIB_PARTREGGET_STR "_partRegGet"
1157 typedef void (*partRegGet_t)(int part_idx,
1158  int stg_idx,
1159  const char** strategy,
1160  supportedOps_t* supportedOps,
1161  createSelector_t* createSelector,
1162  reviewSubgraph_t* reviewSubgraph,
1163  const char** op_name);
1164 
1165 #define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps"
1166 typedef int (*partCallSupportedOps_t)(supportedOps_t supportedOps,
1167  const char* json,
1168  int num_ids,
1169  int* ids,
1170  const char* const* opt_keys,
1171  const char* const* opt_vals,
1172  int num_opts);
1173 
1174 #define MXLIB_PARTCALLCREATESELECTOR_STR "_partCallCreateSelector"
1175 typedef int (*partCallCreateSelector_t)(createSelector_t createSelector,
1176  const char* json,
1177  void** selector,
1178  const char* const* opt_keys,
1179  const char* const* opt_vals,
1180  int num_opts);
1181 
1182 #define MXLIB_PARTCALLSELECT_STR "_partCallSelect"
1183 typedef void (*partCallSelect_t)(void* sel_inst, int nodeID, int* selected);
1184 
1185 #define MXLIB_PARTCALLSELECTINPUT_STR "_partCallSelectInput"
1186 typedef void (*partCallSelectInput_t)(void* sel_inst, int nodeID, int input_nodeID, int* selected);
1187 
1188 #define MXLIB_PARTCALLSELECTOUTPUT_STR "_partCallSelectOutput"
1189 typedef void (*partCallSelectOutput_t)(void* sel_inst,
1190  int nodeID,
1191  int output_nodeID,
1192  int* selected);
1193 
1194 #define MXLIB_PARTCALLFILTER_STR "_partCallFilter"
1195 typedef void (*partCallFilter_t)(void* sel_inst,
1196  int* candidates,
1197  int num_candidates,
1198  int** keep,
1199  int* num_keep);
1200 
1201 #define MXLIB_PARTCALLRESET_STR "_partCallReset"
1202 typedef void (*partCallReset_t)(void* sel_inst);
1203 
1204 #define MXLIB_PARTCALLREVIEWSUBGRAPH_STR "_partCallReviewSubgraph"
1205 typedef int (*partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph,
1206  const char* json,
1207  int subgraph_id,
1208  int* accept,
1209  const char* const* opt_keys,
1210  const char* const* opt_vals,
1211  int num_opts,
1212  char*** attr_keys,
1213  char*** attr_vals,
1214  int* num_attrs,
1215  const char* const* arg_names,
1216  int num_args,
1217  void* const* arg_data,
1218  const int64_t* const* arg_shapes,
1219  const int* arg_dims,
1220  const int* arg_types,
1221  const size_t* arg_IDs,
1222  const char* const* arg_dev_type,
1223  const int* arg_dev_id,
1224  const char* const* aux_names,
1225  int num_aux,
1226  void* const* aux_data,
1227  const int64_t* const* aux_shapes,
1228  const int* aux_dims,
1229  const int* aux_types,
1230  const size_t* aux_IDs,
1231  const char* const* aux_dev_type,
1232  const int* aux_dev_id);
1233 
1234 #define MXLIB_PASSREGSIZE_STR "_passRegSize"
1235 typedef int (*passRegSize_t)(void);
1236 
1237 #define MXLIB_PASSREGGET_STR "_passRegGet"
1238 typedef void (*passRegGet_t)(int pass_idx, graphPass_t* graphPass, const char** pass_name);
1239 
1240 #define MXLIB_PASSCALLGRAPHPASS_STR "_passCallGraphPass"
1241 typedef int (*passCallGraphPass_t)(graphPass_t graphPass,
1242  const char* in_graph,
1243  char** out_graph,
1244  const char* const* opt_keys,
1245  const char* const* opt_vals,
1246  int num_opts,
1247  const char* pass_name,
1248  const char* const* arg_names,
1249  int num_args,
1250  void* const* arg_data,
1251  const int64_t* const* arg_shapes,
1252  const int* arg_dims,
1253  const int* arg_types,
1254  const size_t* arg_IDs,
1255  const char* const* arg_dev_type,
1256  const int* arg_dev_id,
1257  const char* const* aux_names,
1258  int num_aux,
1259  void* const* aux_data,
1260  const int64_t* const* aux_shapes,
1261  const int* aux_dims,
1262  const int* aux_types,
1263  const size_t* aux_IDs,
1264  const char* const* aux_dev_type,
1265  const int* aux_dev_id,
1266  nd_malloc_t nd_malloc,
1267  const void* nd_alloc);
1268 
1269 #define MXLIB_INITIALIZE_STR "initialize"
1270 typedef int (*initialize_t)(int version);
1271 
1272 #define MXLIB_OPVERSION_STR "_opVersion"
1273 typedef int (*opVersion_t)();
1274 
1275 #define MXLIB_MSGSIZE_STR "_msgSize"
1276 typedef int (*msgSize_t)(void);
1277 
1278 #define MXLIB_MSGGET_STR "_msgGet"
1279 typedef int (*msgGet_t)(int idx, const char** msg);
1280 
1283  public:
1286  : instance(inst), destroy_(destroy) {}
1288  return instance;
1289  }
1290 
1291  private:
1292  CustomStatefulOp* instance;
1293  opCallDestroyOpState_t destroy_;
1294 };
1295 
1296 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
1297 #define MX_INT_RET __declspec(dllexport) int __cdecl
1298 #define MX_VOID_RET __declspec(dllexport) void __cdecl
1299 #else
1300 #define MX_INT_RET int
1301 #define MX_VOID_RET void
1302 #endif
1303 
1304 } // namespace ext
1305 } // namespace mxnet
1306 
1307 extern "C" {
1310 
1313 
1315 MX_VOID_RET _opRegGet(int idx,
1316  const char** name,
1317  int* isSGop,
1318  const char*** forward_ctx,
1319  mxnet::ext::fcomp_t** forward_fp,
1320  int* forward_count,
1321  const char*** backward_ctx,
1322  mxnet::ext::fcomp_t** backward_fp,
1323  int* backward_count,
1324  const char*** create_op_ctx,
1325  mxnet::ext::createOpState_t** create_op_fp,
1326  int* create_op_count,
1327  mxnet::ext::parseAttrs_t* parse,
1329  mxnet::ext::inferSType_t* stype,
1330  mxnet::ext::inferShape_t* shape,
1331  mxnet::ext::mutateInputs_t* mutate);
1332 
1334 MX_VOID_RET _opCallFree(void* ptr);
1335 
1338  const char* const* keys,
1339  const char* const* vals,
1340  int num,
1341  int* num_in,
1342  int* num_out);
1343 
1346  const char* const* keys,
1347  const char* const* vals,
1348  int num,
1349  unsigned int** inshapes,
1350  int* indims,
1351  int num_in,
1352  unsigned int*** mod_inshapes,
1353  int** mod_indims,
1354  unsigned int*** outshapes,
1355  int** outdims,
1356  int num_out);
1357 
1360  const char* const* keys,
1361  const char* const* vals,
1362  int num,
1363  int* intypes,
1364  int num_in,
1365  int* outtypes,
1366  int num_out);
1367 
1370  const char* const* keys,
1371  const char* const* vals,
1372  int num,
1373  int* instypes,
1374  int num_in,
1375  int* outstypes,
1376  int num_out);
1377 
1380  const char* const* keys,
1381  const char* const* vals,
1382  int num,
1383  const int64_t** inshapes,
1384  int* indims,
1385  void** indata,
1386  int* intypes,
1387  size_t* inIDs,
1388  const char** indev_type,
1389  int* indev_id,
1390  int num_in,
1391  const int64_t** outshapes,
1392  int* outdims,
1393  void** outdata,
1394  int* outtypes,
1395  size_t* outIDs,
1396  const char** outdev_type,
1397  int* outdev_id,
1398  int num_out,
1399  mxnet::ext::xpu_malloc_t cpu_malloc,
1400  void* cpu_alloc,
1401  mxnet::ext::xpu_malloc_t gpu_malloc,
1402  void* gpu_alloc,
1403  void* cuda_stream,
1404  mxnet::ext::sparse_malloc_t sparse_malloc,
1405  void* sparse_alloc,
1406  int* instypes,
1407  int* outstypes,
1408  void** in_indices,
1409  void** out_indices,
1410  void** in_indptr,
1411  void** out_indptr,
1412  int64_t* in_indices_shapes,
1413  int64_t* out_indices_shapes,
1414  int64_t* in_indptr_shapes,
1415  int64_t* out_indptr_shapes,
1416  void* rng_cpu_states,
1417  void* rng_gpu_states);
1418 
1421  const char* const* keys,
1422  const char* const* vals,
1423  int num,
1424  int** mutate_indices,
1425  int* indices_size);
1426 
1429  const char* const* keys,
1430  const char* const* vals,
1431  int num,
1432  const char* dev_type,
1433  int dev_id,
1434  unsigned int** inshapes,
1435  int* indims,
1436  int num_in,
1437  const int* intypes,
1438  void** state_op);
1439 
1441 MX_VOID_RET _opCallDestroyOpState(void* state_op);
1442 
1444 MX_INT_RET _opCallFStatefulCompute(int is_forward,
1445  void* state_op,
1446  const int64_t** inshapes,
1447  int* indims,
1448  void** indata,
1449  int* intypes,
1450  size_t* inIDs,
1451  const char** indev_type,
1452  int* indev_id,
1453  int num_in,
1454  const int64_t** outshapes,
1455  int* outdims,
1456  void** outdata,
1457  int* outtypes,
1458  size_t* outIDs,
1459  const char** outdev_type,
1460  int* outdev_id,
1461  int num_out,
1462  mxnet::ext::xpu_malloc_t cpu_malloc,
1463  void* cpu_alloc,
1464  mxnet::ext::xpu_malloc_t gpu_malloc,
1465  void* gpu_alloc,
1466  void* stream,
1467  mxnet::ext::sparse_malloc_t sparse_malloc,
1468  void* sparse_alloc,
1469  int* instypes,
1470  int* outstypes,
1471  void** in_indices,
1472  void** out_indices,
1473  void** in_indptr,
1474  void** out_indptr,
1475  int64_t* in_indices_shapes,
1476  int64_t* out_indices_shapes,
1477  int64_t* in_indptr_shapes,
1478  int64_t* out_indptr_shapes,
1479  void* rng_cpu_states,
1480  void* rng_gpu_states);
1481 
1484 
1485 /* returns number of strategies registered for partitioner
1486  * at specified index */
1487 MX_INT_RET _partRegGetCount(int idx, const char** name);
1488 
1490 MX_VOID_RET _partRegGet(int part_idx,
1491  int stg_idx,
1492  const char** strategy,
1493  mxnet::ext::supportedOps_t* supportedOps,
1494  mxnet::ext::createSelector_t* createSelector,
1495  mxnet::ext::reviewSubgraph_t* reviewSubgraph,
1496  const char** op_name);
1497 
1500  const char* json,
1501  int num_ids,
1502  int* ids,
1503  const char* const* opt_keys,
1504  const char* const* opt_vals,
1505  int num_opts);
1506 
1509  const char* json,
1510  void** selector,
1511  const char* const* opt_keys,
1512  const char* const* opt_vals,
1513  int num_opts);
1514 
1516 MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected);
1517 
1519 MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID, int input_nodeID, int* selected);
1520 
1522 MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID, int output_nodeID, int* selected);
1523 
1525 MX_VOID_RET _partCallFilter(void* sel_inst,
1526  int* candidates,
1527  int num_candidates,
1528  int** keep,
1529  int* num_keep);
1530 
1532 MX_VOID_RET _partCallReset(void* sel_inst);
1533 
1536  const char* json,
1537  int subgraph_id,
1538  int* accept,
1539  const char* const* opt_keys,
1540  const char* const* opt_vals,
1541  int num_opts,
1542  char*** attr_keys,
1543  char*** attr_vals,
1544  int* num_attrs,
1545  const char* const* arg_names,
1546  int num_args,
1547  void* const* arg_data,
1548  const int64_t* const* arg_shapes,
1549  const int* arg_dims,
1550  const int* arg_types,
1551  const size_t* arg_IDs,
1552  const char* const* arg_dev_type,
1553  const int* arg_dev_id,
1554  const char* const* aux_names,
1555  int num_aux,
1556  void* const* aux_data,
1557  const int64_t* const* aux_shapes,
1558  const int* aux_dims,
1559  const int* aux_types,
1560  const size_t* aux_IDs,
1561  const char* const* aux_dev_type,
1562  const int* aux_dev_id);
1563 
1566 
1568 MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t* graphPass, const char** pass_name);
1569 
1572  const char* json,
1573  char** out_graph,
1574  const char* const* opt_keys,
1575  const char* const* opt_vals,
1576  int num_opts,
1577  const char* pass_name,
1578  const char* const* arg_names,
1579  int num_args,
1580  void* const* arg_data,
1581  const int64_t* const* arg_shapes,
1582  const int* arg_dims,
1583  const int* arg_types,
1584  const size_t* arg_IDs,
1585  const char* const* arg_dev_type,
1586  const int* arg_dev_id,
1587  const char* const* aux_names,
1588  int num_aux,
1589  void* const* aux_data,
1590  const int64_t* const* aux_shapes,
1591  const int* aux_dims,
1592  const int* aux_types,
1593  const size_t* aux_IDs,
1594  const char* const* aux_dev_type,
1595  const int* aux_dev_id,
1596  mxnet::ext::nd_malloc_t nd_malloc,
1597  const void* nd_alloc);
1598 
1606 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
1607 __declspec(dllexport) mxnet::ext::MXReturnValue __cdecl
1608 #else
1610 #endif
1611  initialize(int version);
1612 
1614 
1616 MX_VOID_RET _msgGet(int idx, const char** msg);
1617 } // extern "C"
1618 
1619 #endif // MXNET_LIB_API_H_
mxnet::ext::CustomOp::infer_storage_type
inferSType_t infer_storage_type
Definition: lib_api.h:809
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::ext::partCallSelect_t
void(* partCallSelect_t)(void *sel_inst, int nodeID, int *selected)
Definition: lib_api.h:1183
mxnet::ext::Graph::~Graph
~Graph()
_partCallSelect
MX_VOID_RET _partCallSelect(void *sel_inst, int nodeID, int *selected)
calls the select function from the library; the result is written to `selected`
mxnet::ext::CustomOp::isSGop
bool isSGop
Definition: lib_api.h:812
mxnet::ext::mx_gpu_rand_t
void * mx_gpu_rand_t
Definition: lib_api.h:403
mxnet::ext::JsonType
JsonType
Json utility to parse serialized subgraph symbol.
Definition: lib_api.h:525
mxnet::ext::MXTensor::dtype
MXDType dtype
Definition: lib_api.h:367
mxnet::ext::JsonVal::parse
static JsonVal parse(const std::string &json)
mxnet::ext::MXSparse::indices
int64_t * indices
Definition: lib_api.h:305
mxnet::ext::fcomp_t
MXReturnValue(* fcomp_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &res)
Custom Operator function templates.
Definition: lib_api.h:747
mxnet::ext::MXTensor::data_ptr
void * data_ptr
Definition: lib_api.h:361
DLDataType
The data type the tensor can hold.
Definition: dlpack.h:94
mxnet::ext::CustomOp::setBackward
CustomOp & setBackward(fcomp_t fgrad, const char *ctx)
mxnet::ext::passRegSize_t
int(* passRegSize_t)(void)
Definition: lib_api.h:1235
mxnet::ext::opCallInferSType_t
int(* opCallInferSType_t)(inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:1037
kDLOpenCL
@ kDLOpenCL
OpenCL devices.
Definition: lib_api.h:112
mxnet::ext::Node::op
std::string op
Definition: lib_api.h:597
_partCallCreateSelector
MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling create selector function from library
mxnet::ext::NodeEntry
Definition: lib_api.h:578
mxnet::ext::OpResource::alloc_gpu
void * alloc_gpu(int size) const
allocate gpu memory controlled by MXNet
mxnet::ext::CustomStatefulOp::Backward
virtual MXReturnValue Backward(std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &op_res)
Definition: lib_api.h:733
mxnet::ext::JsonVal::JsonVal
JsonVal()
mxnet::ext::Graph
Definition: lib_api.h:610
mxnet::ext::CustomPartitioner
An abstract class for subgraph property.
Definition: lib_api.h:866
mxnet::ext::MXTensor::shape
std::vector< int64_t > shape
Definition: lib_api.h:364
mxnet::ext::Graph::getAttr
const JsonVal & getAttr(const std::string &key) const
mxnet::ext::getDtypeAt
std::string getDtypeAt(const std::string &dtype, unsigned index)
mxnet::ext::CustomPass::setBody
CustomPass & setBody(graphPass_t fn)
mxnet::ext::CustomOp::setIsSubgraphOp
CustomOp & setIsSubgraphOp()
mxnet::ext::opCallFree_t
int(* opCallFree_t)(void *ptr)
Definition: lib_api.h:1002
mxnet::ext::reviewSubgraph_t
MXReturnValue(* reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id, bool *accept, const std::unordered_map< std::string, std::string > &options, std::unordered_map< std::string, std::string > *attrs)
Definition: lib_api.h:856
mxnet::ext::MXContext::CPU
static MXContext CPU()
mxnet::ext::MXTensor
Tensor data structure used by custom operator.
Definition: lib_api.h:325
mxnet::ext::CustomOp::mutate_inputs
mutateInputs_t mutate_inputs
Definition: lib_api.h:811
_opCallCreateOpState
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
returns status of calling createStatefulOp function for operator from library
mxnet::ext::MXSparse::data
void * data
Definition: lib_api.h:298
mxnet::ext::mx_cpu_rand_t
std::mt19937 mx_cpu_rand_t
Definition: lib_api.h:405
_partRegGetCount
MX_INT_RET _partRegGetCount(int idx, const char **name)
mxnet::ext::Graph::fromString
static Graph * fromString(const std::string &json)
mxnet::ext::CustomPass
An abstract class for graph passes.
Definition: lib_api.h:834
mxnet::ext::kCSRStorage
@ kCSRStorage
Definition: lib_api.h:269
_opRegGet
MX_VOID_RET _opRegGet(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
returns operator registration at specified index
mxnet::ext::Node::inputs
std::vector< NodeEntry > inputs
Definition: lib_api.h:600
mxnet::ext::partCallReviewSubgraph_t
int(* partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
Definition: lib_api.h:1205
_opCallFCompute
MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Forward/Backward function for operator from library
mxnet::ext::inferType_t
MXReturnValue(* inferType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_types, std::vector< int > *out_types)
Definition: lib_api.h:755
mxnet::ext::CustomOpSelector::SelectOutput
virtual bool SelectOutput(int nodeID, int output_nodeID)=0
mxnet::ext::CustomOpSelector
Definition: lib_api.h:677
mxnet::ext::CustomPass::CustomPass
CustomPass()
mxnet::ext::MXSparse::indptr_len
int64_t indptr_len
Definition: lib_api.h:311
kDLROCM
@ kDLROCM
ROCm GPUs for AMD GPUs.
Definition: lib_api.h:120
mxnet::ext::supportedOps_t
MXReturnValue(* supportedOps_t)(const mxnet::ext::Graph *graph, std::vector< int > *ids, const std::unordered_map< std::string, std::string > &options)
Custom Subgraph Create function template.
Definition: lib_api.h:848
mxnet::ext::Node::_setPassResource
void _setPassResource(PassResource *res_)
mxnet::ext::opCallFStatefulComp_t
int(* opCallFStatefulComp_t)(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:1112
mxnet::ext::MX_SUCCESS
@ MX_SUCCESS
Definition: lib_api.h:292
DLContext
A Device context for Tensor and operator.
Definition: dlpack.h:69
mxnet::ext::parseAttrs_t
MXReturnValue(* parseAttrs_t)(const std::unordered_map< std::string, std::string > &attributes, int *num_inputs, int *num_outputs)
Definition: lib_api.h:751
mxnet::ext::partCallCreateSelector_t
int(* partCallCreateSelector_t)(createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1175
mxnet::ext::PassResource
Definition: lib_api.h:413
mxnet::ext::partCallFilter_t
void(* partCallFilter_t)(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
Definition: lib_api.h:1195
mxnet::ext::Node::alloc_aux
void alloc_aux(const std::vector< int64_t > &shapes, const MXContext &ctx, MXDType dtype)
mxnet::ext::mx_stream_t
void * mx_stream_t
GPU stream pointer, is void* when not compiled with CUDA.
Definition: lib_api.h:402
_partRegGet
MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char **strategy, mxnet::ext::supportedOps_t *supportedOps, mxnet::ext::createSelector_t *createSelector, mxnet::ext::reviewSubgraph_t *reviewSubgraph, const char **op_name)
returns partitioner registration at specified index
mxnet::ext::JsonVal
definition of JSON objects
Definition: lib_api.h:528
mxnet::ext::kUint8
@ kUint8
Definition: lib_api.h:253
mxnet::ext::CustomPass::pass
graphPass_t pass
pass function
Definition: lib_api.h:844
mxnet::ext::opCallCreateOpState_t
int(* opCallCreateOpState_t)(createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
Definition: lib_api.h:1096
mxnet::ext::MXSparse::indptr
int64_t * indptr
Definition: lib_api.h:310
mxnet::ext::JsonVal::operator<
bool operator<(const JsonVal &o) const
mxnet::ext::CustomStatefulOpWrapper::get_instance
CustomStatefulOp * get_instance()
Definition: lib_api.h:1287
mxnet::ext::LIST
@ LIST
Definition: lib_api.h:525
_opCallInferType
MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
returns status of calling inferType function for operator from library
mxnet::ext::Graph::_dfs_util
void _dfs_util(Node *n, std::unordered_set< Node * > *to_visit, std::function< void(Node *)> handler) const
mxnet::ext::OpResource::get_cuda_stream
mx_stream_t get_cuda_stream() const
return the cuda stream object with correct type
Definition: lib_api.h:461
DLDeviceType
DLDeviceType
The device type in DLContext.
Definition: dlpack.h:38
mxnet::ext::JsonVal::list
std::vector< JsonVal > list
Definition: lib_api.h:567
mxnet::ext::opCallDestroyOpState_t
int(* opCallDestroyOpState_t)(void *state_op)
Definition: lib_api.h:1109
mxnet::ext::CustomStatefulOp::Forward
virtual MXReturnValue Forward(std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &op_res)=0
mxnet::ext::PassResource::PassResource
PassResource(std::unordered_map< std::string, MXTensor > *new_args, std::unordered_map< std::string, MXTensor > *new_aux, nd_malloc_t nd_malloc, const void *nd_alloc)
mxnet::ext::partRegSize_t
int(* partRegSize_t)(void)
Definition: lib_api.h:1151
mxnet::ext::CustomOp::backward_fp
std::vector< fcomp_t > backward_fp
Definition: lib_api.h:816
mxnet::ext::Graph::print
void print(int indent=0) const
mxnet::ext::MXContext::dev_id
int dev_id
Definition: lib_api.h:287
mxnet::ext::opRegSize_t
int(* opRegSize_t)(void)
Definition: lib_api.h:980
mxnet::ext::Graph::toString
std::string toString() const
mxnet::ext::passRegGet_t
void(* passRegGet_t)(int pass_idx, graphPass_t *graphPass, const char **pass_name)
Definition: lib_api.h:1238
mxnet::ext::sparse_malloc_t
void(* sparse_malloc_t)(void *, int, int, int, void **, int64_t **, int64_t **)
sparse alloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:386
mxnet::ext::CustomOp::infer_type
inferType_t infer_type
Definition: lib_api.h:808
mxnet::ext::CustomPartitioner::setReviewSubgraph
CustomPartitioner & setReviewSubgraph(const char *prop_name, reviewSubgraph_t fn)
mxnet::ext::CustomOpSelector::SelectInput
virtual bool SelectInput(int nodeID, int input_nodeID)=0
kDLUInt
@ kDLUInt
Definition: lib_api.h:144
mxnet::ext::nd_malloc_t
void(* nd_malloc_t)(const void *_ndarray_alloc, const int64_t *shapes, int num_shapes, const char *dev_str, int dev_id, int dtype, const char *name, int isArg, void **data)
resource malloc function to allocate ndarrays for graph passes
Definition: lib_api.h:388
mxnet::ext::MXTensor::isSame
bool isSame(const MXTensor &oth) const
helper function to compare two MXTensors
mxnet::ext::OpResource::alloc_cpu
void * alloc_cpu(int size) const
allocate cpu memory controlled by MXNet
mxnet::ext::opCallInferType_t
int(* opCallInferType_t)(inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:1027
mxnet::ext::CustomOp::setCreateOpState
CustomOp & setCreateOpState(createOpState_t func, const char *ctx)
mxnet::ext::Node::tensor
MXTensor * tensor
Definition: lib_api.h:599
mxnet::ext::CustomPartitioner::setSupportedOps
CustomPartitioner & setSupportedOps(const char *prop_name, supportedOps_t fn)
mxnet::ext::Graph::attrs
std::map< std::string, JsonVal > attrs
Definition: lib_api.h:667
mxnet::ext::partRegGetCount_t
int(* partRegGetCount_t)(int idx, const char **name)
Definition: lib_api.h:1154
kDLMetal
@ kDLMetal
Metal for Apple GPU.
Definition: lib_api.h:116
mxnet::ext::CustomOp::forward_ctx_cstr
std::vector< const char * > forward_ctx_cstr
vector representation of the ctx map, so it can be easily loaded from the C API
Definition: lib_api.h:815
_partCallReviewSubgraph
MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
returns status of calling review subgraph function from library
mxnet::ext::CustomPartitioner::getReviewSubgraph
reviewSubgraph_t getReviewSubgraph(int stg_id)
mxnet::ext::MXDType
MXDType
Tensor data type, consistent with mshadow data type.
Definition: lib_api.h:249
mxnet::ext::CustomStatefulOp::CustomStatefulOp
CustomStatefulOp()
mxnet::ext::initialize_t
int(* initialize_t)(int version)
Definition: lib_api.h:1270
_opCallDestroyOpState
MX_VOID_RET _opCallDestroyOpState(void *state_op)
deletes the StatefulOp instance for an operator from the library
mxnet::ext::Graph::_setPassResource
void _setPassResource(PassResource *res_)
mxnet::ext::CustomPartitioner::name
const char * name
partitioner name
Definition: lib_api.h:887
mxnet::ext::mutateInputs_t
MXReturnValue(* mutateInputs_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *input_indices)
Definition: lib_api.h:766
mxnet::ext::MX_FAIL
@ MX_FAIL
Definition: lib_api.h:291
mxnet::ext::NodeEntry::entry
int entry
Definition: lib_api.h:580
mxnet::ext::JsonVal::parse_map
static JsonVal parse_map(const std::string &json, unsigned int *idx)
mxnet::ext::Registry::get
T & get(int idx)
Definition: lib_api.h:924
mxnet::ext::MXTensor::setDLTensor
void setDLTensor()
populate DLTensor fields
mxnet::ext::CustomPartitioner::CustomPartitioner
CustomPartitioner()
mxnet::ext::CustomOpSelector::Reset
virtual void Reset()
Definition: lib_api.h:706
_partCallSupportedOps
MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling supported ops function from library
_opCallParseAttrs
MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
returns status of calling parse attributes function for operator from library
mxnet::ext::CustomPartitioner::op_names
std::vector< const char * > op_names
subgraph operator name
Definition: lib_api.h:894
MX_ERROR_MSG
#define MX_ERROR_MSG
Definition: lib_api.h:244
mxnet::ext::CustomPass::name
const char * name
pass name
Definition: lib_api.h:842
mxnet::ext::Node::attrs
std::unordered_map< std::string, std::string > attrs
Definition: lib_api.h:603
mxnet::ext::kFloat16
@ kFloat16
Definition: lib_api.h:252
mxnet::ext::CustomOp::name
const char * name
operator name
Definition: lib_api.h:804
mxnet::ext::MXSparse
Definition: lib_api.h:296
mxnet::ext::MXContext
Context info passed from MXNet OpContext; dev_type is the string repr of the supported context,...
Definition: lib_api.h:277
mxnet::ext::Registry::get
static Registry * get() PRIVATE_SYMBOL
get singleton pointer to class
Definition: lib_api.h:908
mxnet::ext::MXTensor::verID
size_t verID
Definition: lib_api.h:370
mxnet::ext::STR
@ STR
Definition: lib_api.h:525
mxnet::ext::ERR
@ ERR
Definition: lib_api.h:525
mxnet::ext::CustomPartitioner::getSupportedOps
supportedOps_t getSupportedOps(int stg_id)
mxnet::ext::JsonVal::toString
std::string toString() const
kDLVPI
@ kDLVPI
Verilog simulator buffer.
Definition: lib_api.h:118
mxnet::ext::JsonVal::map
std::map< JsonVal, JsonVal > map
Definition: lib_api.h:568
mxnet::ext::MXTensor::setTensor
void setTensor(void *dptr, MXDType type, const int64_t *dims, int ndims, size_t vID, MXContext mx_ctx, MXStorageType storage_type)
populate internal tensor fields
mxnet::ext::partCallSupportedOps_t
int(* partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1166
_opCallFStatefulCompute
MX_INT_RET _opCallFStatefulCompute(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Stateful Forward/Backward for operator from library
mxnet::ext::JsonVal::dump
std::string dump() const
mxnet::ext::PassResource::alloc_aux
MXTensor * alloc_aux(const std::string &name, const std::vector< int64_t > &shapes, const MXContext &ctx, MXDType dtype) const
kDLInt
@ kDLInt
Definition: lib_api.h:143
mxnet::ext::JsonVal::type
JsonType type
Definition: lib_api.h:564
mxnet::ext::Graph::Graph
Graph()
mxnet::ext::Graph::_setParams
void _setParams(std::unordered_map< std::string, mxnet::ext::MXTensor > *args, std::unordered_map< std::string, mxnet::ext::MXTensor > *aux)
mxnet::ext::getShapeAt
std::string getShapeAt(const std::string &shape, unsigned index)
mxnet::ext::JsonVal::parse_num
static JsonVal parse_num(const std::string &json, unsigned int *idx)
mxnet::ext::msgSize_t
int(* msgSize_t)(void)
Definition: lib_api.h:1276
mxnet::ext::opCallParseAttrs_t
int(* opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
Definition: lib_api.h:1005
PRIVATE_SYMBOL
#define PRIVATE_SYMBOL
For loading multiple custom op libraries in Linux, exporting same symbol multiple times may lead to u...
Definition: lib_api.h:65
mxnet::ext::NodeEntry::node
Node * node
Definition: lib_api.h:579
_partRegSize
MX_INT_RET _partRegSize()
returns number of partitioners registered in this library
_opCallInferSType
MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *instypes, int num_in, int *outstypes, int num_out)
returns status of calling inferSType function for operator from library
_opCallFree
MX_VOID_RET _opCallFree(void *ptr)
calls free from the external library on library-allocated arrays
mxnet::ext::Node::subgraphs
std::vector< Graph * > subgraphs
Definition: lib_api.h:602
_passRegGet
MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t *graphPass, const char **pass_name)
returns pass registration at specified index
mxnet::ext::opCallFComp_t
int(* opCallFComp_t)(fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:1047
mxnet::ext::OpResource
provides resource APIs (e.g. memory allocation) to Forward/Backward functions
Definition: lib_api.h:442
mxnet::ext::partCallReset_t
void(* partCallReset_t)(void *sel_inst)
Definition: lib_api.h:1202
mxnet::ext::MXerrorMsgs::size
int size()
mxnet::ext::MXContext::dev_type
std::string dev_type
Definition: lib_api.h:286
mxnet::ext::Graph::getNode
Node * getNode(size_t idx)
mxnet::ext::Node
Definition: lib_api.h:584
mxnet::ext::PassResource::alloc_arg
MXTensor * alloc_arg(const std::string &name, const std::vector< int64_t > &shapes, const MXContext &ctx, MXDType dtype) const
mxnet::ext::CustomOp::create_op_ctx_cstr
std::vector< const char * > create_op_ctx_cstr
Definition: lib_api.h:815
mxnet::ext::opVersion_t
int(* opVersion_t)()
Definition: lib_api.h:1273
mxnet::ext::kUNSET
@ kUNSET
Definition: lib_api.h:257
initialize
mxnet::ext::MXReturnValue initialize(int version)
Checks if the MXNet version is supported by the library. If supported, initializes the library.
mxnet::ext::kFloat32
@ kFloat32
Definition: lib_api.h:250
mxnet::ext::Graph::toJson
JsonVal toJson() const
_opCallMutateInputs
MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
returns status of calling mutateInputs function for operator from library
mxnet::ext::MXerrorMsgs::get
static MXerrorMsgs * get()
mxnet::ext::CustomOp::setParseAttrs
CustomOp & setParseAttrs(parseAttrs_t func)
mxnet::ext::MXSparse::set
void set(void *data_ptr, const int64_t *dims, int ndims, void *idx, int64_t num_idx, void *idx_ptr=nullptr, int64_t num_idx_ptr=0)
mxnet::ext::CustomOp::setInferShape
CustomOp & setInferShape(inferShape_t func)
DLDataTypeCode
DLDataTypeCode
The type code options DLDataType.
Definition: lib_api.h:142
mxnet::ext::CustomOp
Class to hold custom operator registration.
Definition: lib_api.h:779
_partCallSelectOutput
MX_VOID_RET _partCallSelectOutput(void *sel_inst, int nodeID, int output_nodeID, int *selected)
calls the select output function from the library; the result is written to `selected`
mxnet::ext::kRowSparseStorage
@ kRowSparseStorage
Definition: lib_api.h:267
_partCallSelectInput
MX_VOID_RET _partCallSelectInput(void *sel_inst, int nodeID, int input_nodeID, int *selected)
returns status of calling select input function from library
mxnet::ext::inferShape_t
MXReturnValue(* inferShape_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< std::vector< unsigned int > > *in_shapes, std::vector< std::vector< unsigned int > > *out_shapes)
Definition: lib_api.h:762
mxnet::ext::Graph::addNode
Node * addNode(const std::string &name, const std::string &op)
mxnet::ext::MXerrorMsgs
Definition: lib_api.h:220
mxnet::ext::passCallGraphPass_t
int(* passCallGraphPass_t)(graphPass_t graphPass, const char *in_graph, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, nd_malloc_t nd_malloc, const void *nd_alloc)
Definition: lib_api.h:1241
MX_VOID_RET
#define MX_VOID_RET
Definition: lib_api.h:1301
_partCallFilter
MX_VOID_RET _partCallFilter(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
returns status of calling filter function from library
mxnet::ext::CustomStatefulOp::create
static CustomStatefulOp * create(Ts... args)
Definition: lib_api.h:720
mxnet::ext::msgGet_t
int(* msgGet_t)(int idx, const char **msg)
Definition: lib_api.h:1279
mxnet::ext::kFloat64
@ kFloat64
Definition: lib_api.h:251
mxnet::ext::OpResource::get_cpu_rand_states
mx_cpu_rand_t * get_cpu_rand_states() const
get pointer to initialized and seeded random number states located on CPU
mxnet::ext::Node::Node
Node()
mxnet::ext::Registry
Registry class to registers things (ops, properties) Singleton class.
Definition: lib_api.h:902
_opCallInferShape
MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
returns status of calling inferShape function for operator from library
_opRegSize
MX_INT_RET _opRegSize()
returns number of ops registered in this library
mxnet::ext::CustomOp::setForward
CustomOp & setForward(fcomp_t fcomp, const char *ctx)
_passRegSize
MX_INT_RET _passRegSize()
returns number of graph passes registered in this library
mxnet::ext::CustomOp::CustomOp
CustomOp(const char *op_name)
mxnet::ext::MXTensor::data
data_type * data()
helper function to cast data pointer
Definition: lib_api.h:349
mxnet::ext::Node::alloc_arg
void alloc_arg(const std::vector< int64_t > &shapes, const MXContext &ctx, MXDType dtype)
mxnet::ext::OpResource::get_gpu_rand_states
mx_gpu_rand_t * get_gpu_rand_states() const
get pointer to initialized and seeded random number states located on GPU
Definition: lib_api.h:475
mxnet::ext::MXTensor::MXTensor
MXTensor()
mxnet::ext::Graph::outputs
std::vector< NodeEntry > outputs
Definition: lib_api.h:666
kDLCPU
@ kDLCPU
CPU device.
Definition: lib_api.h:103
mxnet::ext::CustomOp::create_op_fp
std::vector< createOpState_t > create_op_fp
Definition: lib_api.h:817
kDLGPU
@ kDLGPU
CUDA GPU device.
Definition: lib_api.h:105
mxnet::ext::OpResource::alloc_sparse
void alloc_sparse(MXSparse *sparse, int index, int indices_len, int indptr_len=0) const
allocate sparse memory controlled by MXNet
mxnet::ext::CustomOpSelector::Select
virtual bool Select(int nodeID)=0
_passCallGraphPass
MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc, const void *nd_alloc)
returns status of calling graph pass function from library
mxnet::ext::CustomOp::backward_ctx_cstr
std::vector< const char * > backward_ctx_cstr
Definition: lib_api.h:815
mxnet::ext::partCallSelectInput_t
void(* partCallSelectInput_t)(void *sel_inst, int nodeID, int input_nodeID, int *selected)
Definition: lib_api.h:1186
mxnet::ext::CustomPartitioner::addStrategy
CustomPartitioner & addStrategy(const char *prop_name, const char *sg_name)
mxnet::ext::graphPass_t
MXReturnValue(* graphPass_t)(mxnet::ext::Graph *graph, const std::unordered_map< std::string, std::string > &options)
Custom Pass Create function template.
Definition: lib_api.h:828
mxnet::ext::CustomOpSelector::Filter
virtual void Filter(const std::vector< int > &candidates, std::vector< int > *keep)
Definition: lib_api.h:700
mxnet::ext::MXContext::MXContext
MXContext()
mxnet::ext::MXTensor::stype
MXStorageType stype
Definition: lib_api.h:380
mxnet::ext::CustomStatefulOpWrapper
StatefulOp wrapper class to pass to backend OpState.
Definition: lib_api.h:1282
mxnet::ext::partRegGet_t
void(* partRegGet_t)(int part_idx, int stg_idx, const char **strategy, supportedOps_t *supportedOps, createSelector_t *createSelector, reviewSubgraph_t *reviewSubgraph, const char **op_name)
Definition: lib_api.h:1157
mxnet::ext::xpu_malloc_t
void *(* xpu_malloc_t)(void *, int)
resource malloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:384
mxnet::ext::MXTensor::dltensor
DLTensor dltensor
Definition: lib_api.h:377
mxnet::ext::CustomStatefulOp::wasCreated
bool wasCreated()
Definition: lib_api.h:726
mxnet::ext::CustomPartitioner::review_map
std::map< std::string, reviewSubgraph_t > review_map
Definition: lib_api.h:890
mxnet::ext::CustomPartitioner::strategies
std::vector< const char * > strategies
strategy names
Definition: lib_api.h:892
mxnet::ext::MXStorageType
MXStorageType
Definition: lib_api.h:263
mxnet::ext::CustomOp::infer_shape
inferShape_t infer_shape
Definition: lib_api.h:810
mxnet::ext::MXReturnValue
MXReturnValue
Definition: lib_api.h:290
mxnet::ext::Registry::add
T & add(const char *name)
add a new entry
Definition: lib_api.h:916
mxnet::ext::CustomOp::setMutateInputs
CustomOp & setMutateInputs(mutateInputs_t func)
mxnet::ext::kDefaultStorage
@ kDefaultStorage
Definition: lib_api.h:265
mxnet::ext::partCallSelectOutput_t
void(* partCallSelectOutput_t)(void *sel_inst, int nodeID, int output_nodeID, int *selected)
Definition: lib_api.h:1189
mxnet::ext::CustomPartitioner::setCreateSelector
CustomPartitioner & setCreateSelector(const char *prop_name, createSelector_t fn)
mxnet::ext::Graph::DFS
void DFS(std::function< void(Node *)> handler) const
mxnet::ext::JsonVal::parse_string
static JsonVal parse_string(const std::string &json, unsigned int *idx)
mxnet::ext::CustomStatefulOp::ignore_warn
bool ignore_warn
Definition: lib_api.h:740
_partCallReset
MX_VOID_RET _partCallReset(void *sel_inst)
returns status of calling reset selector function from library
kDLFloat
@ kDLFloat
Definition: lib_api.h:145
mxnet::ext::Graph::fromJson
static Graph * fromJson(JsonVal val)
mxnet::ext::MXTensor::ctx
MXContext ctx
Definition: lib_api.h:373
mxnet::ext::MXerrorMsgs::add
std::stringstream & add(const char *file, int line)
mxnet::ext::MXContext::GPU
static MXContext GPU()
mxnet::ext::CustomOp::parse_attrs
parseAttrs_t parse_attrs
operator functions
Definition: lib_api.h:807
mxnet::ext::Node::name
std::string name
Definition: lib_api.h:598
mxnet::ext::MXTensor::size
int64_t size() const
helper function to get data size
mxnet::ext::MXSparse::data_len
int64_t data_len
Definition: lib_api.h:300
MX_INT_RET
#define MX_INT_RET
Definition: lib_api.h:1300
mxnet::ext::CustomStatefulOp
An abstract class for library authors creating stateful op custom library should override Forward and...
Definition: lib_api.h:714
mxnet::ext::Graph::topological_sort
std::vector< Node * > topological_sort() const
DLDeviceType
DLDeviceType
The device type in DLContext.
Definition: lib_api.h:101
mxnet::ext::CustomOp::setInferSType
CustomOp & setInferSType(inferSType_t func)
mxnet::ext::JsonVal::num
int num
Definition: lib_api.h:565
mxnet::ext::opCallInferShape_t
int(* opCallInferShape_t)(inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
Definition: lib_api.h:1013
mxnet::ext::CustomStatefulOpWrapper::~CustomStatefulOpWrapper
~CustomStatefulOpWrapper()
DLTensor
Plain C Tensor object, does not manage memory.
Definition: dlpack.h:112
mxnet::ext::CustomPartitioner::selector_map
std::map< std::string, createSelector_t > selector_map
Definition: lib_api.h:889
_msgGet
MX_VOID_RET _msgGet(int idx, const char **msg)
returns operator registration at specified index
mxnet::ext::createSelector_t
MXReturnValue(* createSelector_t)(const mxnet::ext::Graph *graph, CustomOpSelector **sel_inst, const std::unordered_map< std::string, std::string > &options)
Definition: lib_api.h:852
mxnet::ext::createOpState_t
MXReturnValue(* createOpState_t)(const std::unordered_map< std::string, std::string > &attributes, const MXContext &ctx, const std::vector< std::vector< unsigned int > > &in_shapes, const std::vector< int > in_types, CustomStatefulOp **)
Definition: lib_api.h:769
mxnet::ext::OpResource::OpResource
OpResource(xpu_malloc_t cpu_malloc_fp, void *cpu_alloc_fp, xpu_malloc_t gpu_malloc_fp, void *gpu_alloc_fp, void *stream, sparse_malloc_t sparse_malloc_fp, void *sparse_alloc_fp, void *rng_cpu_states, void *rng_gpu_states)
mxnet::ext::Graph::size
size_t size() const
mxnet::ext::CustomStatefulOp::~CustomStatefulOp
virtual ~CustomStatefulOp()
mxnet::ext::CustomPartitioner::supported_map
std::map< std::string, supportedOps_t > supported_map
Definition: lib_api.h:888
kDLExtDev
@ kDLExtDev
Reserved extension device type, used for quickly test extension device The semantics can differ depen...
Definition: lib_api.h:126
_msgSize
MX_INT_RET _msgSize()
mxnet::ext::CustomOp::forward_fp
std::vector< fcomp_t > forward_fp
Definition: lib_api.h:816
_opVersion
MX_INT_RET _opVersion()
returns MXNet library version
mxnet::ext::kInt64
@ kInt64
Definition: lib_api.h:256
mxnet::ext::opRegGet_t
int(* opRegGet_t)(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
Definition: lib_api.h:983
mxnet::ext::MXSparse::indices_len
int64_t indices_len
Definition: lib_api.h:306
kDLCPUPinned
@ kDLCPUPinned
Pinned CUDA GPU device by cudaMallocHost.
Definition: lib_api.h:110
mxnet::ext::CustomStatefulOpWrapper::CustomStatefulOpWrapper
CustomStatefulOpWrapper(CustomStatefulOp *inst, opCallDestroyOpState_t destroy)
Definition: lib_api.h:1285
mxnet::ext::NUM
@ NUM
Definition: lib_api.h:525
mxnet::ext::kInt32
@ kInt32
Definition: lib_api.h:254
mxnet::ext::Registry::size
int size()
Definition: lib_api.h:921
mxnet::ext::MAP
@ MAP
Definition: lib_api.h:525
mxnet::ext::Graph::inputs
std::vector< Node * > inputs
Definition: lib_api.h:665
kDLVulkan
@ kDLVulkan
Vulkan buffer for next generation graphics.
Definition: lib_api.h:114
mxnet::ext::CustomPartitioner::getCreateSelector
createSelector_t getCreateSelector(int stg_id)
mxnet::ext::CustomOp::setInferType
CustomOp & setInferType(inferType_t func)
mxnet::ext::opCallMutateInputs_t
int(* opCallMutateInputs_t)(mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
Definition: lib_api.h:1088
mxnet::ext::Node::outputs
std::vector< NodeEntry > outputs
Definition: lib_api.h:601
mxnet::ext::JsonVal::str
std::string str
Definition: lib_api.h:566
mxnet::ext::kInt8
@ kInt8
Definition: lib_api.h:255
mxnet::ext::CustomOp::mapToVector
void mapToVector()
mxnet::ext::inferSType_t
MXReturnValue(* inferSType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_storage_types, std::vector< int > *out_storage_types)
Definition: lib_api.h:758
mxnet::ext::JsonVal::parse_list
static JsonVal parse_list(const std::string &json, unsigned int *idx)