lib_api.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file lib_api.h
 * \brief APIs to write extension libraries (custom operators, partitioners,
 * and graph passes) that MXNet loads at runtime.
 */
#ifndef MXNET_LIB_API_H_
#define MXNET_LIB_API_H_

#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <vector>
#include <map>
#include <unordered_map>
#include <string>
#include <iostream>
#include <utility>
#include <stdexcept>
#include <random>

#if defined(__NVCC__)
  #include <curand_kernel.h>
#endif

/* Make sure to update the version number every time you make changes */
#define MX_LIBRARY_VERSION 7

/* Hide private symbols on non-Windows platforms */
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
  #define PRIVATE_SYMBOL
#else
  #define PRIVATE_SYMBOL __attribute__ ((visibility ("hidden")))
#endif

/*
 * Import from DLPack https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
 */
#ifndef DLPACK_VERSION
#ifdef __cplusplus
#define DLPACK_EXTERN_C extern "C"
#else
#define DLPACK_EXTERN_C
#endif

/*! \brief The current version of dlpack */
#define DLPACK_VERSION 020

/*! \brief DLPACK_DLL prefix for Windows */
#ifdef _WIN32
#ifdef DLPACK_EXPORTS
#define DLPACK_DLL __declspec(dllexport)
#else
#define DLPACK_DLL __declspec(dllimport)
#endif
#else
#define DLPACK_DLL
#endif

#include <stdint.h>
#include <stddef.h>

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief The device type in DLContext. */
typedef enum {
  /*! \brief CPU device */
  kDLCPU = 1,
  /*! \brief CUDA GPU device */
  kDLGPU = 2,
  /*! \brief Pinned CUDA GPU device allocated by cudaMallocHost */
  kDLCPUPinned = 3,
  /*! \brief OpenCL device */
  kDLOpenCL = 4,
  /*! \brief Vulkan buffer */
  kDLVulkan = 7,
  /*! \brief Metal for Apple GPU */
  kDLMetal = 8,
  /*! \brief Verilog simulator buffer */
  kDLVPI = 9,
  /*! \brief ROCm GPU for AMD */
  kDLROCM = 10,
  /*! \brief Reserved extension device type */
  kDLExtDev = 12,
} DLDeviceType;

/*! \brief A device context for a DLTensor. */
typedef struct {
  /*! \brief The device type used in the device. */
  DLDeviceType device_type;
  /*! \brief The device index. */
  int device_id;
} DLContext;

/*! \brief The type code options in DLDataType. */
typedef enum {
  kDLInt = 0U,
  kDLUInt = 1U,
  kDLFloat = 2U,
} DLDataTypeCode;

/*! \brief The data type the tensor can hold. */
typedef struct {
  /*! \brief Type code of base types (kDLInt, kDLUInt, kDLFloat). */
  uint8_t code;
  /*! \brief Number of bits, e.g. 32 for float32. */
  uint8_t bits;
  /*! \brief Number of lanes in the type, used for vector types. */
  uint16_t lanes;
} DLDataType;

/*! \brief Plain C tensor object; does not manage memory. */
typedef struct {
  /*! \brief The opaque data pointer to the allocated data. */
  void* data;
  /*! \brief The device context of the tensor. */
  DLContext ctx;
  /*! \brief Number of dimensions. */
  int ndim;
  /*! \brief The data type of the pointer. */
  DLDataType dtype;
  /*! \brief The shape of the tensor. */
  int64_t* shape;
  /*! \brief Strides of the tensor; can be NULL for a compact row-major layout. */
  int64_t* strides;
  /*! \brief The offset in bytes to the beginning pointer to data. */
  uint64_t byte_offset;
} DLTensor;
#ifdef __cplusplus
}  // DLPACK_EXTERN_C
#endif
#endif

/*! \brief Tensor data type, consistent with MXNet data types. */
enum MXDType {
  kFloat32 = 0,
  kFloat64 = 1,
  kFloat16 = 2,
  kUint8 = 3,
  kInt32 = 4,
  kInt8 = 5,
  kInt64 = 6,
  kUNSET = 100,
};

/*
 * MXTensor storage type.
 */
enum MXStorageType {
  // dense
  kDefaultStorage = 0,
  // row sparse
  kRowSparseStorage = 1,
  // csr
  kCSRStorage = 2,
};

/*!
 * \brief Context info passed from MXNet OpContext.
 * dev_type is a string repr of the context ("cpu" or "gpu"),
 * dev_id is the device index where the tensor is located.
 */
struct MXContext {
  MXContext() : dev_type("error"), dev_id(-1) {}
  explicit MXContext(std::string dev_type_, int dev_id_)
    : dev_type(dev_type_), dev_id(dev_id_) {}
  explicit MXContext(const char* dev_type_, int dev_id_)
    : dev_type(dev_type_), dev_id(dev_id_) {}
  static MXContext CPU() { return MXContext("cpu", 0); }
  static MXContext GPU() { return MXContext("gpu", 0); }
  static MXContext CPU(int dev_id) { return MXContext("cpu", dev_id); }
  static MXContext GPU(int dev_id) { return MXContext("gpu", dev_id); }

  std::string dev_type;
  int dev_id;
};

enum MXReturnValue {
  MX_FAIL = 0,
  MX_SUCCESS = 1,
};

// For sparse tensors, read/write the data from NDArray via pointers.
struct MXSparse {
  // Pointer to data.
  void *data{nullptr};
  // Length of (non-zero) data.
  int64_t data_len;

  // To store aux data for sparse.
  // For CSR, indices stores the column index of each non-zero element.
  // For row sparse, indices stores the row index of each row that has non-zero elements.
  int64_t* indices;
  int64_t indices_len;

  // For CSR, indptr gives the start and end index of data for each row.
  // For row sparse, indptr is not used.
  int64_t* indptr = nullptr;
  int64_t indptr_len;

  void set(void *data_ptr, const int64_t* dims, int ndims, void *idx,
           int64_t num_idx, void *idx_ptr = nullptr, int64_t num_idx_ptr = 0) {
    data = data_ptr;
    // If CSR, the number of non-zero elements is num_idx;
    // if row sparse, the number of elements is num_idx * width.
    data_len = num_idx;
    if (!idx_ptr) {
      for (int i = 1; i < ndims; ++i)
        data_len *= dims[i];
    }

    indices = reinterpret_cast<int64_t*>(idx);
    indices_len = num_idx;

    if (idx_ptr) {
      indptr = reinterpret_cast<int64_t*>(idx_ptr);
      indptr_len = num_idx_ptr;
    }
  }
};
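
/* Worked example (editor's illustration, not part of the original header):
 * how a small CSR matrix maps onto the MXSparse fields above. For the 3x4
 * matrix
 *
 *   [[0, 2, 0, 0],
 *    [0, 0, 0, 0],
 *    [7, 0, 0, 9]]
 *
 * the non-zero values, their column indices, and the per-row ranges are:
 *
 *   data    = {2, 7, 9}       // data_len    = 3
 *   indices = {1, 0, 3}       // indices_len = 3 (column of each value)
 *   indptr  = {0, 1, 1, 3}    // indptr_len  = 4; row i spans [indptr[i], indptr[i+1])
 *
 * For row-sparse storage only rows 0 and 2 are kept: indices = {0, 2} and
 * data holds those two full rows (data_len = 2 * 4), with indptr unused.
 */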

/*!
 * \brief Tensor data structure used by custom operators.
 */
struct MXTensor {
  MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0), stype(kDefaultStorage) {}
  MXTensor(const MXTensor& oth) : data_ptr(oth.data_ptr), shape(oth.shape),
    dtype(oth.dtype), verID(oth.verID), ctx(oth.ctx), stype(oth.stype) {
    setDLTensor();
  }
  MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype,
           size_t vID, MXContext mx_ctx, MXStorageType stype = kDefaultStorage)
    : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx), stype(stype) {
    setDLTensor();
  }

  /*! \brief populate the internal tensor fields */
  void setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims,
                 size_t vID, MXContext mx_ctx, MXStorageType storage_type) {
    data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx; stype = storage_type;
    shape.clear();
    for (int j = 0; j < ndims; j++) {
      shape.push_back(dims[j]);
    }
    setDLTensor();
  }

  /*! \brief populate the DLTensor representation */
  void setDLTensor() {
    dltensor.data = data_ptr;
    dltensor.ndim = shape.size();
    dltensor.shape = const_cast<int64_t*>(shape.data());
    dltensor.strides = nullptr;
    dltensor.byte_offset = 0;
    dltensor.dtype.lanes = 1;
    dltensor.ctx.device_id = ctx.dev_id;
    if (ctx.dev_type == "cpu")
      dltensor.ctx.device_type = kDLCPU;
    else if (ctx.dev_type == "gpu")
      dltensor.ctx.device_type = kDLGPU;
    else if (ctx.dev_type == "opencl")
      dltensor.ctx.device_type = kDLOpenCL;
    else if (ctx.dev_type == "vulkan")
      dltensor.ctx.device_type = kDLVulkan;
    else if (ctx.dev_type == "metal")
      dltensor.ctx.device_type = kDLMetal;
    else if (ctx.dev_type == "vpi")
      dltensor.ctx.device_type = kDLVPI;
    else if (ctx.dev_type == "rocm")
      dltensor.ctx.device_type = kDLROCM;
    else
      dltensor.ctx.device_type = kDLExtDev;
    switch (dtype) {
    case kFloat32:
      dltensor.dtype.code = kDLFloat;
      dltensor.dtype.bits = 32;
      break;
    case kFloat64:
      dltensor.dtype.code = kDLFloat;
      dltensor.dtype.bits = 64;
      break;
    case kFloat16:
      dltensor.dtype.code = kDLFloat;
      dltensor.dtype.bits = 16;
      break;
    case kUint8:
      dltensor.dtype.code = kDLUInt;
      dltensor.dtype.bits = 8;
      break;
    case kInt32:
      dltensor.dtype.code = kDLInt;
      dltensor.dtype.bits = 32;
      break;
    case kInt8:
      dltensor.dtype.code = kDLInt;
      dltensor.dtype.bits = 8;
      break;
    case kInt64:
      dltensor.dtype.code = kDLInt;
      dltensor.dtype.bits = 64;
      break;
    default:
      dltensor.dtype.code = 0;
      dltensor.dtype.bits = 0;
      throw std::runtime_error("Error! Invalid dtype flag: "
                               + std::to_string(static_cast<int>(dtype))
                               + " when constructing MXTensor");
    }
  }

  /*! \brief helper function to cast the data pointer */
  template<typename data_type>
  inline data_type* data() {
    return reinterpret_cast<data_type*>(data_ptr);
  }

  /*! \brief helper function to get the number of elements */
  inline int64_t size() const {
    int64_t size = 1;
    for (unsigned int i = 0; i < shape.size(); i++) {
      size *= shape[i];
    }
    return size;
  }

  /*! \brief helper function to compare two MXTensors */
  inline bool isSame(const MXTensor &oth) const {
    return data_ptr == oth.data_ptr &&
           dtype == oth.dtype &&
           verID == oth.verID &&
           ctx.dev_type == oth.ctx.dev_type &&
           ctx.dev_id == oth.ctx.dev_id &&
           shape == oth.shape &&
           stype == oth.stype;
  }

  // For dense, data_ptr points to 1D flattened tensor data
  // For sparse, data_ptr points to MXSparse
  void *data_ptr;

  // shape is in [2,3,4] format to represent a high-dim tensor
  std::vector<int64_t> shape;

  // type can only be MXDType enum types
  MXDType dtype;

  // version number updated if the tensor has changed since the last use by custom op
  size_t verID;

  // context of MXTensor representing which device the tensor data is located on
  MXContext ctx;

  // corresponding DLTensor repr of MXTensor
  // easy way to reuse functions taking DLTensor
  DLTensor dltensor;

  // storage type
  MXStorageType stype;
};
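
/* Usage sketch (editor's illustration, not part of the original header):
 * reading a float32 MXTensor inside a custom operator's compute function.
 *
 *   MXTensor &in = inputs->at(0);
 *   if (in.dtype == kFloat32 && in.stype == kDefaultStorage) {
 *     float* p = in.data<float>();    // typed view of the flattened data
 *     for (int64_t i = 0; i < in.size(); i++)
 *       p[i] *= 2.0f;                 // size() is the product of the shape dims
 *   }
 */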

/*! \brief resource malloc function to allocate memory inside Forward/Backward functions */
typedef void* (*xpu_malloc_t)(void*, int);
/*! \brief sparse alloc function to allocate sparse memory inside Forward/Backward functions */
typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t**);
/*! \brief resource malloc function to allocate NDArrays for graph passes */
typedef void (*nd_malloc_t)(const void* _ndarray_alloc, const int64_t* shapes, int num_shapes,
                            const char* dev_str, int dev_id, int dtype, const char* name,
                            int isArg, void** data);
/*! \brief GPU stream and random-number-generator types */
#if defined(__NVCC__)
  typedef cudaStream_t mx_stream_t;
  typedef curandStatePhilox4_32_10_t mx_gpu_rand_t;
#else
  typedef void* mx_stream_t;
  typedef void* mx_gpu_rand_t;
#endif
typedef std::mt19937 mx_cpu_rand_t;

/* Each thread should generate its random-number sequence from a distinct state */
#define MX_NUM_CPU_RANDOM_STATES 1024
#define MX_NUM_GPU_RANDOM_STATES 32768

/*! \brief provides resource APIs for graph passes */
class PassResource {
 public:
  PassResource(std::unordered_map<std::string, MXTensor>* new_args,
               std::unordered_map<std::string, MXTensor>* new_aux,
               nd_malloc_t nd_malloc, const void* nd_alloc)
    : new_args_(new_args), new_aux_(new_aux), nd_malloc_(nd_malloc), nd_alloc_(nd_alloc) {}
  MXTensor* alloc_arg(const std::string& name, const std::vector<int64_t>& shapes,
                      const MXContext &ctx, MXDType dtype) const {
    void* data;
    nd_malloc_(nd_alloc_, shapes.data(), shapes.size(), ctx.dev_type.c_str(), ctx.dev_id,
               dtype, name.c_str(), 1, &data);
    MXTensor tensor(data, shapes, dtype, 0, ctx, kDefaultStorage);
    (*new_args_)[name] = tensor;
    return &(new_args_->at(name));
  }
  MXTensor* alloc_aux(const std::string& name, const std::vector<int64_t>& shapes,
                      const MXContext &ctx, MXDType dtype) const {
    void* data;
    nd_malloc_(nd_alloc_, shapes.data(), shapes.size(), ctx.dev_type.c_str(), ctx.dev_id,
               dtype, name.c_str(), 0, &data);
    MXTensor tensor(data, shapes, dtype, 0, ctx, kDefaultStorage);
    (*new_aux_)[name] = tensor;
    return &(new_aux_->at(name));
  }

 private:
  std::unordered_map<std::string, MXTensor>* new_args_;
  std::unordered_map<std::string, MXTensor>* new_aux_;
  nd_malloc_t nd_malloc_;
  const void* nd_alloc_;
};
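
/* Usage sketch (editor's illustration, not part of the original header): a
 * graph pass can create a new argument tensor through its PassResource. The
 * name and shape here are hypothetical.
 *
 *   MXTensor* w = res.alloc_arg("new_weight", {64, 128}, MXContext::CPU(0), kFloat32);
 *   // w->data<float>() is now backed by storage allocated through nd_malloc
 */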

/*!
 * \brief provides resource APIs for custom operators' Forward/Backward functions
 */
class OpResource {
 public:
  OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
             xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream,
             sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp,
             void* rng_cpu_states, void* rng_gpu_states)
    : cpu_malloc(cpu_malloc_fp), gpu_malloc(gpu_malloc_fp),
      cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream),
      sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp),
      rand_cpu_states(rng_cpu_states), rand_gpu_states(rng_gpu_states) {}

  /*! \brief allocate CPU memory controlled by MXNet */
  void* alloc_cpu(int size) const {
    return cpu_malloc(cpu_alloc, size);
  }

  /*! \brief allocate GPU memory controlled by MXNet */
  void* alloc_gpu(int size) const {
    return gpu_malloc(gpu_alloc, size);
  }

  /*! \brief return the cuda stream object this operator should run on */
  mx_stream_t get_cuda_stream() const {
    return static_cast<mx_stream_t>(cuda_stream);
  }

  /*! \brief allocate sparse memory controlled by MXNet */
  void alloc_sparse(MXSparse* sparse, int index, int indices_len, int indptr_len = 0) const {
    sparse_malloc(sparse_alloc, index, indices_len, indptr_len,
                  &(sparse->data), &(sparse->indices), &(sparse->indptr));
  }

  /*! \brief get pointer to the initialized and seeded CPU random number states */
  /* Access each state by states[id], but this id should be <= MX_NUM_CPU_RANDOM_STATES */
  mx_cpu_rand_t* get_cpu_rand_states() const {
    return static_cast<mx_cpu_rand_t*>(rand_cpu_states);
  }

  /*! \brief get pointer to the initialized and seeded GPU random number states */
  /* Access each state by states[id], but this id should be <= MX_NUM_GPU_RANDOM_STATES */
  /* Note that on a CPU build this returns a nullptr */
  mx_gpu_rand_t* get_gpu_rand_states() const {
    return static_cast<mx_gpu_rand_t*>(rand_gpu_states);
  }

 private:
  /*! \brief allocation lambdas */
  xpu_malloc_t cpu_malloc, gpu_malloc;
  /*! \brief lambda captures */
  void *cpu_alloc, *gpu_alloc;
  /*! \brief cuda stream passed from MXNet */
  void *cuda_stream;
  /*! \brief sparse allocation lambda */
  sparse_malloc_t sparse_malloc;
  /*! \brief lambda capture */
  void *sparse_alloc;
  /*! \brief random number generator states */
  void *rand_cpu_states, *rand_gpu_states;
};
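
/* Usage sketch (editor's illustration, not part of the original header):
 * requesting scratch space and a seeded RNG state inside a Forward function.
 *
 *   float* scratch = static_cast<float*>(op_res.alloc_cpu(n * sizeof(float)));
 *   mx_cpu_rand_t* states = op_res.get_cpu_rand_states();
 *   std::uniform_real_distribution<float> dist(0.0f, 1.0f);
 *   float sample = dist(states[0]);  // one of MX_NUM_CPU_RANDOM_STATES states
 */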

#define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json"
#define MX_STR_DTYPE "__ext_dtype__"
#define MX_STR_SHAPE "__ext_shape__"

/* \brief get shape value from list of shapes string
 *
 * Examples:
 *
 * getShapeAt("[[1]]", 0) returns "[1]"
 * getShapeAt("[[1],[2,3]]", 1) returns "[2,3]"
 */
std::string getShapeAt(const std::string& shape, unsigned index) {
  int idx = 1;  // start at 1 to skip the first square bracket [
  // find the beginning of the output shape for the particular output index
  for (unsigned x = 0; x < index; x++)
    idx = shape.find("[", idx+1);
  int stop = shape.find("]", idx);  // find stop index for this output shape
  // add this shape to the list
  return shape.substr(idx, stop-idx+1);
}

/* \brief get dtype value from list of dtypes string
 *
 * Examples:
 *
 * getDtypeAt("[1]", 0) returns "1"
 * getDtypeAt("[1,2]", 1) returns "2"
 */
std::string getDtypeAt(const std::string& dtype, unsigned index) {
  // find the beginning of the output dtype for the particular output index
  int idx = 0;
  for (unsigned x = 0; x < index; x++)
    idx = dtype.find(",", idx+1);
  int stop = dtype.find(",", idx+1);  // find stop index for this output dtype
  if (stop == -1) stop = dtype.find("]", idx+1);
  return dtype.substr(idx+1, stop-idx-1);
}

/*!
 * \brief Types of JSON objects
 */
enum JsonType {ERR, STR, NUM, LIST, MAP};

/*! \brief definition of JSON objects */
struct JsonVal {
  JsonVal() : type(ERR), num(-1), str("") {}  // default constructor
  // construct a JSON object by type
  explicit JsonVal(JsonType t) : type(t), num(-1), str("") {}
  // construct a string JSON object
  explicit JsonVal(std::string s) : type(STR), num(-1), str(s) {}
  // construct a number JSON object
  explicit JsonVal(int n) : type(NUM), num(n), str(std::to_string(n)) {}
  // complex constructor
  JsonVal(JsonType t, int n, std::string s) : type(t), num(n), str(s) {}
  bool operator<(const JsonVal &o) const {
    // for string JSON objects compare the string
    if (type == STR) return type == o.type && str < o.str;
    // for number JSON objects compare the number
    if (type == NUM) return type == o.type && num < o.num;
    // for list JSON objects, compare the size of the lists, and then each object in the list
    if (type == LIST) {
      if (list.size() != o.list.size()) return false;
      for (unsigned int i = 0; i < list.size(); i++)
        if (list[i] < o.list[i])
          return false;  // if we find an object that doesn't match, return
      return true;  // all objects in lists matched
    }
    // for map JSON objects, compare the size of the maps, and then each key/value in the maps
    if (type == MAP) {
      if (map.size() != o.map.size()) return false;
      for (auto &item : map) {
        // if one map is missing a key present in the other, return
        if (o.map.find(item.first) == o.map.end()) return false;
        if (item.second < o.map.at(item.first)) return false;
      }
      return true;
    }
    return type < o.type;
  }
  JsonType type;
  int num;
  std::string str;
  std::vector<JsonVal> list;
  std::map<JsonVal, JsonVal> map;
};

/*! \brief functions used for parsing JSON */
struct JsonParser {
  JsonVal parse_to_json(const std::string& json) {
    unsigned int idx = 0;
    return parse(json, &idx);
  }
  void print_json_val(const JsonVal& val) {
    std::cout << json_val_string(val) << std::endl;
  }
  // debug function to dump data structure to string
  std::string json_val_string(const JsonVal &val) {
    std::string ret;
    switch (val.type) {
    case ERR:
      ret = "json(Error)";
      break;
    case STR:
      ret = "json(STR:" + val.str + ")";
      break;
    case NUM:
      ret = "json(INT:" + val.str + ")";
      break;
    case LIST:
      ret = "json(LIST:[";
      for (auto &item : val.list)
        ret += json_val_string(item) + ",";
      ret += "])";
      break;
    case MAP:
      ret = "json(MAP:{";
      for (auto &item : val.map)
        ret += json_val_string(item.first) + " : " + json_val_string(item.second) + ",";
      ret += "})";
      break;
    }
    return ret;
  }
  // parse a string JSON object
  JsonVal parse_string(const std::string& json, unsigned int* idx) {
    JsonVal ret(STR);
    while (*idx < json.size()) {
      if (json[*idx] == '"') {
        ++(*idx);
        return ret;
      } else {
        ret.str += json[*idx];
        ++(*idx);
      }
    }
    std::cout << "Error! Unable to parse string" << std::endl;
    return JsonVal();
  }
  // parse a number JSON object
  JsonVal parse_num(const std::string& json, unsigned int* idx) {
    JsonVal ret(NUM);
    while (*idx < json.size()) {
      if (json[*idx] >= '0' && json[*idx] <= '9') {
        ret.str += json[*idx];
        ++(*idx);
      } else {
        break;
      }
    }
    ret.num = std::stoi(ret.str);
    return ret;
  }
  // parse a list of JSON objects
  JsonVal parse_list(const std::string& json, unsigned int* idx) {
    JsonVal ret(LIST);
    while (*idx < json.size()) {
      if (json[*idx] == ']') {
        ++(*idx);
        return ret;
      } else {
        JsonVal item = parse(json, idx);
        if (item.type != ERR)
          ret.list.push_back(item);
      }
    }
    std::cout << "Error! Unable to parse list" << std::endl;
    return JsonVal();
  }
  // parse a map of JSON objects
  JsonVal parse_map(const std::string& json, unsigned int* idx) {
    JsonVal ret(MAP), key;
    while (*idx < json.size()) {
      if (json[*idx] == '}') {
        ++(*idx);
        return ret;
      } else {
        JsonVal item = parse(json, idx);
        if (key.type == ERR) {
          key = item;
        } else {
          ret.map[key] = item;
          key.type = ERR;
        }
      }
    }
    std::cout << "Error! Unable to parse map" << std::endl;
    return JsonVal();
  }
  // generic parse function
  JsonVal parse(const std::string& json, unsigned int *idx) {
    JsonVal ret;
    while (*idx < json.size()) {
      if (json[*idx] == '"') {
        ++(*idx);
        ret = parse_string(json, idx);
      } else if (json[*idx] >= '0' && json[*idx] <= '9') {
        ret = parse_num(json, idx);
      } else if (json[*idx] == '[') {
        ++(*idx);
        ret = parse_list(json, idx);
      } else if (json[*idx] == '{') {
        ++(*idx);
        ret = parse_map(json, idx);
      } else if (json[*idx] == ']' || json[*idx] == '}') {
        return ret;
      }
      if (ret.type != ERR) return ret;
      ++(*idx);
    }
    return ret;
  }
  // convert JSON object back to a JSON-compatible string
  std::string dump(const JsonVal &val) {
    std::string ret;
    switch (val.type) {
    case ERR:
      ret = "json(Error)";
      break;
    case STR:
      ret = "\"" + val.str + "\"";
      break;
    case NUM:
      ret = val.str;
      break;
    case LIST:
      ret = "[";
      for (unsigned i = 0; i < val.list.size(); i++) {
        auto &item = val.list[i];
        ret += dump(item);
        if (i < val.list.size()-1)
          ret += ",";
      }
      ret += "]";
      break;
    case MAP:
      ret = "{";
      unsigned cnt = 0;
      for (auto &item : val.map) {
        ret += dump(item.first) + " : " + dump(item.second);
        if (cnt++ < val.map.size()-1)
          ret += ",";
      }
      ret += "}";
      break;
    }
    return ret;
  }
};
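
/* Usage sketch (editor's illustration, not part of the original header):
 * round-tripping a subgraph JSON string with JsonParser.
 *
 *   JsonParser parser;
 *   JsonVal graph = parser.parse_to_json("{\"nodes\" : [{\"op\" : \"null\"}]}");
 *   JsonVal nodes = graph.map[JsonVal("nodes")];  // LIST of node MAPs
 *   std::string round_trip = parser.dump(graph);  // back to a JSON string
 */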

/* \brief An abstract class for library authors creating custom
 * partitioners. Optional, can just implement supportedOps instead
 */
class CustomOpSelector {
 public:
  /* \brief Select a node to include in subgraph, return true to include node
   * nodeID - index of node in graph
   */
  virtual bool Select(int nodeID) = 0;
  /* \brief Select an input node from current node to include in subgraph
   * return true to include node
   * nodeID - index of node in graph
   * input_nodeID - index of input node in graph
   */
  virtual bool SelectInput(int nodeID, int input_nodeID) = 0;
  /* \brief Select an output node from current node to include in subgraph
   * return true to include node
   * nodeID - index of node in graph
   * output_nodeID - index of output node in graph
   */
  virtual bool SelectOutput(int nodeID, int output_nodeID) = 0;
  /* \brief Review nodes to include in subgraph
   * return set of candidate nodes to keep in subgraph
   * candidates - indices of nodes to include in subgraph
   * keep - indices of nodes to keep in subgraph
   */
  virtual void Filter(const std::vector<int>& candidates,
                      std::vector<int>* keep) {
    keep->insert(keep->end(), candidates.begin(), candidates.end());
  }
  /* \brief Reset any selector state, called after growing subgraph, before filter
   * Called after finished calling SelectInput/SelectOutput and growing subgraph
   */
  virtual void Reset() {}
};

/*!
 * \brief An abstract class for library authors creating stateful operators.
 * Operators that need persistent state across Forward/Backward calls
 * subclass this and return instances from their createOpState function.
 */
class CustomStatefulOp {
 public:
  virtual MXReturnValue Forward(std::vector<MXTensor>* inputs,
                                std::vector<MXTensor>* outputs,
                                const OpResource& op_res) = 0;
  virtual MXReturnValue Backward(std::vector<MXTensor>* inputs,
                                 std::vector<MXTensor>* outputs,
                                 const OpResource& op_res) {
    std::cout << "Error! Operator does not support backward" << std::endl;
    return MX_FAIL;
  }
};
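
/* Usage sketch (editor's illustration, not part of the original header): a
 * minimal stateful op that counts Forward calls. The class name and counter
 * are hypothetical.
 *
 *   class MyStatefulOp : public CustomStatefulOp {
 *    public:
 *     MXReturnValue Forward(std::vector<MXTensor>* inputs,
 *                           std::vector<MXTensor>* outputs,
 *                           const OpResource& op_res) override {
 *       ++count;  // state persists across calls on the same instance
 *       return MX_SUCCESS;
 *     }
 *    private:
 *     int count = 0;
 *   };
 *
 * A createOpState_t function returns `new MyStatefulOp(...)` through its
 * CustomStatefulOp** argument.
 */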

/*! \brief StatefulOp wrapper class to pass to the backend OpState */
class CustomStatefulOpWrapper {
 public:
  explicit CustomStatefulOpWrapper(CustomStatefulOp* inst) : instance(inst) {}
  CustomStatefulOp* get_instance() { return instance; }
 private:
  CustomStatefulOp* instance;
};

typedef MXReturnValue (*fcomp_t)(const std::unordered_map<std::string,
                                 std::string>& attributes,
                                 std::vector<MXTensor>* inputs,
                                 std::vector<MXTensor>* outputs,
                                 const OpResource& res);
typedef MXReturnValue (*parseAttrs_t)(const std::unordered_map<std::string,
                                      std::string>& attributes,
                                      int* num_inputs, int* num_outputs);
typedef MXReturnValue (*inferType_t)(const std::unordered_map<std::string,
                                     std::string>& attributes,
                                     std::vector<int>* in_types,
                                     std::vector<int>* out_types);
typedef MXReturnValue (*inferSType_t)(const std::unordered_map<std::string,
                                      std::string>& attributes,
                                      std::vector<int>* in_storage_types,
                                      std::vector<int>* out_storage_types);
typedef MXReturnValue (*inferShape_t)(const std::unordered_map<std::string,
                                      std::string>& attributes,
                                      std::vector<std::vector<unsigned int> >* in_shapes,
                                      std::vector<std::vector<unsigned int> >* out_shapes);
typedef MXReturnValue (*mutateInputs_t)(const std::unordered_map<std::string,
                                        std::string>& attributes,
                                        std::vector<int>* input_indices);
typedef MXReturnValue (*createOpState_t)(const std::unordered_map<std::string,
                                         std::string>& attributes,
                                         CustomStatefulOp**);

/*!
 * \brief Class to hold custom operator registration
 */
class CustomOp {
 public:
  explicit CustomOp(const char* op_name) : name(op_name),
    parse_attrs(NULL), infer_type(NULL), infer_storage_type(NULL), infer_shape(NULL),
    mutate_inputs(NULL), isSGop(false) {}
  CustomOp& setForward(fcomp_t fcomp, const char* ctx) {
    if (forward_ctx_map.count(ctx) > 0)
      raiseDuplicateContextError();
    forward_ctx_map[ctx] = fcomp;
    return *this;
  }
  CustomOp& setBackward(fcomp_t fgrad, const char* ctx) {
    if (backward_ctx_map.count(ctx) > 0)
      raiseDuplicateContextError();
    backward_ctx_map[ctx] = fgrad;
    return *this;
  }
  CustomOp& setParseAttrs(parseAttrs_t func) {
    parse_attrs = func;
    return *this;
  }
  CustomOp& setInferType(inferType_t func) {
    infer_type = func;
    return *this;
  }
  CustomOp& setInferSType(inferSType_t func) {
    infer_storage_type = func;
    return *this;
  }
  CustomOp& setInferShape(inferShape_t func) {
    infer_shape = func;
    return *this;
  }
  CustomOp& setMutateInputs(mutateInputs_t func) {
    mutate_inputs = func;
    return *this;
  }
  CustomOp& setCreateOpState(createOpState_t func, const char* ctx) {
    if (create_op_ctx_map.count(ctx) > 0)
      raiseDuplicateContextError();
    create_op_ctx_map[ctx] = func;
    return *this;
  }
  CustomOp& setIsSubgraphOp() {
    isSGop = true;
    return *this;
  }
  void mapToVector() {
    for (auto kv : forward_ctx_map) {
      forward_ctx_cstr.push_back(kv.first);
      forward_fp.push_back(kv.second);
    }
    for (auto kv : backward_ctx_map) {
      backward_ctx_cstr.push_back(kv.first);
      backward_fp.push_back(kv.second);
    }
    for (auto kv : create_op_ctx_map) {
      create_op_ctx_cstr.push_back(kv.first);
      create_op_fp.push_back(kv.second);
    }
  }

  /*! \brief operator name */
  const char* name;

  /*! \brief operator functions */
  parseAttrs_t parse_attrs;
  inferType_t infer_type;
  inferSType_t infer_storage_type;
  inferShape_t infer_shape;
  mutateInputs_t mutate_inputs;
  bool isSGop;

  /*! \brief vector repr of the context maps to communicate with the backend */
  std::vector<const char*> forward_ctx_cstr, backward_ctx_cstr, create_op_ctx_cstr;
  std::vector<fcomp_t> forward_fp, backward_fp;
  std::vector<createOpState_t> create_op_fp;

 private:
  void raiseDuplicateContextError() {
    std::string op_name_str(name);
    throw std::runtime_error(
      "Error! Cannot register multiple functions under the same context for operator '"
      + op_name_str + "'");
  }

  /*! \brief dedup context maps - static string ctx to custom function */
  std::unordered_map<const char*, fcomp_t> forward_ctx_map, backward_ctx_map;
  std::unordered_map<const char*, createOpState_t> create_op_ctx_map;
};

/*! \brief Custom Pass function template */
typedef MXReturnValue (*graphPass_t)(const std::string& in_graph, const std::string** out_graph,
                                     const std::unordered_map<std::string, std::string>& options,
                                     const std::unordered_map<std::string, MXTensor>& args,
                                     const std::unordered_map<std::string, MXTensor>& aux,
                                     const PassResource& res);

/*!
 * \brief Class to hold custom graph pass registration
 */
class CustomPass {
 public:
  CustomPass() : name("ERROR") {}
  explicit CustomPass(const char* pass_name)
    : name(pass_name) {}
  CustomPass& setBody(graphPass_t fn) {
    pass = fn;
    return *this;
  }

  /*! \brief pass name */
  const char* name;
  /*! \brief pass function */
  graphPass_t pass;
};

typedef MXReturnValue (*supportedOps_t)(const std::string& json, std::vector<int>* ids,
                                        const std::unordered_map<std::string,
                                        std::string>& options);
typedef MXReturnValue (*createSelector_t)(const std::string& json, CustomOpSelector** sel_inst,
                                          const std::unordered_map<std::string,
                                          std::string>& options);
typedef MXReturnValue (*reviewSubgraph_t)(const std::string& json, int subgraph_id, bool* accept,
                                          const std::unordered_map<std::string,
                                          std::string>& options,
                                          std::unordered_map<std::string, std::string>* attrs,
                                          const std::unordered_map<std::string, MXTensor>& args,
                                          const std::unordered_map<std::string, MXTensor>& aux);

/*!
 * \brief Class to hold custom partitioner registration
 */
class CustomPartitioner {
 public:
  CustomPartitioner() : name("ERROR") {}
  explicit CustomPartitioner(const char* backend_name) :
    name(backend_name) {}
  CustomPartitioner& addStrategy(const char* prop_name,
                                 const char* sg_name) {
    strategies.push_back(prop_name);
    op_names.push_back(sg_name);
    return *this;
  }
  CustomPartitioner& setSupportedOps(const char* prop_name, supportedOps_t fn) {
    supported_map[std::string(prop_name)] = fn;
    return *this;
  }
  CustomPartitioner& setCreateSelector(const char* prop_name, createSelector_t fn) {
    selector_map[std::string(prop_name)] = fn;
    return *this;
  }
  CustomPartitioner& setReviewSubgraph(const char* prop_name, reviewSubgraph_t fn) {
    review_map[std::string(prop_name)] = fn;
    return *this;
  }
  supportedOps_t getSupportedOps(int stg_id) {
    std::string prop(strategies[stg_id]);
    if (supported_map.count(prop) > 0)
      return supported_map[prop];
    else
      return nullptr;
  }
  createSelector_t getCreateSelector(int stg_id) {
    std::string prop(strategies[stg_id]);
    if (selector_map.count(prop) > 0)
      return selector_map[prop];
    else
      return nullptr;
  }
  reviewSubgraph_t getReviewSubgraph(int stg_id) {
    std::string prop(strategies[stg_id]);
    if (review_map.count(prop) > 0)
      return review_map[prop];
    else
      return nullptr;
  }

  /*! \brief partitioner name */
  const char* name;
  std::map<std::string, supportedOps_t> supported_map;
  std::map<std::string, createSelector_t> selector_map;
  std::map<std::string, reviewSubgraph_t> review_map;
  /*! \brief strategy names */
  std::vector<const char*> strategies;
  /*! \brief subgraph operator names, one per strategy */
  std::vector<const char*> op_names;
};

/*!
 * \brief Registry class to register things (ops, properties)
 * Singleton class
 */
template <class T>
class Registry {
 public:
  /*!
   * \brief get singleton pointer to class
   * \returns pointer to class
   */
  static Registry* get() PRIVATE_SYMBOL {
    static Registry inst;
    return &inst;
  }
  /*!
   * \brief add a new entry
   * \returns new object associated with registered name
   */
  T& add(const char* name) {
    T *entry = new T(name);
    entries.push_back(entry);
    return *entry;
  }
  int size() {
    return entries.size();
  }
  T& get(int idx) {
    return *(entries.at(idx));
  }

 private:
  /*! \brief constructor */
  Registry() {}
  /*! \brief destructor */
  ~Registry() {}
  /*! \brief list of entries in the registry */
  std::vector<T*> entries;
};

/*!
 * \brief Macros to help with string concat
 * Annoyingly, the concat_ and concat macros are necessary to
 * be able to use __COUNTER__ in an identifier name
 */
#define MX_STR_CONCAT_(__a, __b) __a ## __b
#define MX_STR_CONCAT(__a, __b) MX_STR_CONCAT_(__a, __b)

/*! \brief convert a token to a string */
#define MX_STRINGIFY(x) #x
#define MX_TOSTRING(x) MX_STRINGIFY(x)

/*! \brief declare a variable with a custom name */
#define MX_REGISTER_NAME_(Name) MXNet ## _CustomOp ## _
#define MX_REGISTER_DEF_(Name) CustomOp MX_REGISTER_NAME_(Name)

#define MX_REGISTER_PROP_NAME_(Name) MXNet ## _CustomSubProp ## _
#define MX_REGISTER_PROP_DEF_(Name) CustomPartitioner MX_REGISTER_PROP_NAME_(Name)

#define MX_REGISTER_PASS_NAME_(Name) MXNet ## _CustomPass ## _
#define MX_REGISTER_PASS_DEF_(Name) CustomPass MX_REGISTER_PASS_NAME_(Name)

/*! \brief assign a var to a value */
#define REGISTER_OP(Name) MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \
  Registry<CustomOp>::get()->add(MX_TOSTRING(Name))

#define REGISTER_PARTITIONER(Name) \
  MX_STR_CONCAT(MX_REGISTER_PROP_DEF_(Name), __COUNTER__) = \
  Registry<CustomPartitioner>::get()->add(MX_TOSTRING(Name))

#define REGISTER_PASS(Name) \
  MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \
  Registry<CustomPass>::get()->add(MX_TOSTRING(Name))

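/* Usage sketch (editor's illustration, not part of the original header):
 * registering a custom operator from a library. myParseAttrs/myForward are
 * hypothetical functions matching the parseAttrs_t and fcomp_t signatures.
 *
 *   MXReturnValue myForward(const std::unordered_map<std::string, std::string>& attrs,
 *                           std::vector<MXTensor>* inputs,
 *                           std::vector<MXTensor>* outputs,
 *                           const OpResource& res) {
 *     return MX_SUCCESS;
 *   }
 *
 *   REGISTER_OP(my_op)
 *   .setParseAttrs(myParseAttrs)
 *   .setForward(myForward, "cpu");
 *
 * REGISTER_OP expands to a uniquely named CustomOp variable (via __COUNTER__)
 * assigned from Registry<CustomOp>::get()->add("my_op"), so each setter call
 * chains on that registry entry.
 */
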
/* -------------- BELOW ARE CTYPE FUNCTIONS PROTOTYPES --------------- */

/*!
 * \brief Following are the C-type APIs implemented in the external library.
 * Each #defined string is the symbol name MXNet looks up in the library.
 */
#define MXLIB_OPREGSIZE_STR "_opRegSize"
typedef int (*opRegSize_t)(void);

#define MXLIB_OPREGGET_STR "_opRegGet"
typedef int (*opRegGet_t)(int idx, const char** name, int *isSGop,
                          const char*** forward_ctx, fcomp_t** forward_fp, int* forward_count,
                          const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count,
                          const char*** create_op_ctx, createOpState_t** create_op_fp,
                          int* create_op_count,
                          parseAttrs_t* parse, inferType_t* type, inferSType_t* stype,
                          inferShape_t* shape, mutateInputs_t* mutate);

#define MXLIB_OPCALLFREE_STR "_opCallFree"
typedef int (*opCallFree_t)(void* ptr);

#define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs"
typedef int (*opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char* const* keys,
                                  const char* const* vals, int num,
                                  int* num_in, int* num_out);

#define MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape"
typedef int (*opCallInferShape_t)(inferShape_t inferShape, const char* const* keys,
                                  const char* const* vals, int num,
                                  unsigned int** inshapes, int* indims, int num_in,
                                  unsigned int*** mod_inshapes, int** mod_indims,
                                  unsigned int*** outshapes, int** outdims, int num_out);

#define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType"
typedef int (*opCallInferType_t)(inferType_t inferType, const char* const* keys,
                                 const char* const* vals, int num,
                                 int* intypes, int num_in, int* outtypes, int num_out);

#define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType"
typedef int (*opCallInferSType_t)(inferSType_t inferSType, const char* const* keys,
                                  const char* const* vals, int num,
                                  int* intypes, int num_in, int* outtypes, int num_out);

#define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
                             const char* const* vals, int num,
                             const int64_t** inshapes, int* indims,
                             void** indata, int* intypes,
                             size_t* inIDs, const char** indev_type,
                             int* indev_id, int num_in,
                             const int64_t** outshapes, int* outdims,
                             void** outdata, int* outtypes,
                             size_t* outIDs, const char** outdev_type,
                             int* outdev_id, int num_out,
                             xpu_malloc_t cpu_malloc, void* cpu_alloc,
                             xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
                             sparse_malloc_t sparse_malloc, void* sparse_alloc,
                             int* instypes, int* outstypes,
                             void** in_indices, void** out_indices,
                             void** in_indptr, void** out_indptr,
                             int64_t* in_indices_shapes, int64_t* out_indices_shapes,
                             int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
                             void* rng_cpu_states, void* rng_gpu_states);

#define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys,
                                    const char* const* vals, int num,
                                    int** mutate_indices, int* indices_size);

#define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState"
typedef int (*opCallCreateOpState_t)(createOpState_t create_op, const char* const* keys,
                                     const char* const* vals, int num,
                                     void** state_op);

#define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
                                     const int64_t** inshapes, int* indims,
                                     void** indata, int* intypes,
                                     size_t* inIDs, const char** indev_type,
                                     int* indev_id, int num_in,
                                     const int64_t** outshapes, int* outdims,
                                     void** outdata, int* outtypes,
                                     size_t* outIDs, const char** outdev_type,
                                     int* outdev_id, int num_out,
                                     xpu_malloc_t cpu_malloc, void* cpu_alloc,
                                     xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream,
                                     sparse_malloc_t sparse_malloc, void* sparse_alloc,
                                     int* instypes, int* outstypes,
                                     void** in_indices, void** out_indices,
                                     void** in_indptr, void** out_indptr,
                                     int64_t* in_indices_shapes, int64_t* out_indices_shapes,
                                     int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
                                     void* rng_cpu_states, void* rng_gpu_states);

#define MXLIB_PARTREGSIZE_STR "_partRegSize"
typedef int (*partRegSize_t)(void);

#define MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount"
typedef int (*partRegGetCount_t)(int idx, const char** name);

#define MXLIB_PARTREGGET_STR "_partRegGet"
typedef void (*partRegGet_t)(int part_idx, int stg_idx, const char** strategy,
                             supportedOps_t* supportedOps, createSelector_t* createSelector,
                             reviewSubgraph_t* reviewSubgraph, const char** op_name);

#define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps"
typedef int (*partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json,
                                      int num_ids, int *ids, const char* const* opt_keys,
                                      const char* const* opt_vals, int num_opts);

#define MXLIB_PARTCALLCREATESELECTOR_STR "_partCallCreateSelector"
typedef int (*partCallCreateSelector_t)(createSelector_t createSelector, const char *json,
                                        void** selector, const char* const* opt_keys,
                                        const char* const* opt_vals, int num_opts);

#define MXLIB_PARTCALLSELECT_STR "_partCallSelect"
typedef void (*partCallSelect_t)(void* sel_inst, int nodeID, int* selected);

#define MXLIB_PARTCALLSELECTINPUT_STR "_partCallSelectInput"
typedef void (*partCallSelectInput_t)(void* sel_inst, int nodeID, int input_nodeID,
                                      int* selected);

#define MXLIB_PARTCALLSELECTOUTPUT_STR "_partCallSelectOutput"
typedef void (*partCallSelectOutput_t)(void* sel_inst, int nodeID, int output_nodeID,
                                       int* selected);

#define MXLIB_PARTCALLFILTER_STR "_partCallFilter"
typedef void (*partCallFilter_t)(void* sel_inst, int* candidates, int num_candidates,
                                 int** keep, int* num_keep);

#define MXLIB_PARTCALLRESET_STR "_partCallReset"
typedef void (*partCallReset_t)(void* sel_inst);

#define MXLIB_PARTCALLREVIEWSUBGRAPH_STR "_partCallReviewSubgraph"
typedef int (*partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const char *json,
                                        int subgraph_id, int *accept, const char* const* opt_keys,
                                        const char* const* opt_vals, int num_opts,
                                        char*** attr_keys, char*** attr_vals, int *num_attrs,
                                        const char* const* arg_names, int num_args,
                                        void* const* arg_data, const int64_t* const* arg_shapes,
                                        const int* arg_dims, const int* arg_types,
                                        const size_t* arg_IDs, const char* const* arg_dev_type,
                                        const int* arg_dev_id,
                                        const char* const* aux_names, int num_aux,
                                        void* const* aux_data, const int64_t* const* aux_shapes,
                                        const int* aux_dims, const int* aux_types,
                                        const size_t* aux_IDs, const char* const* aux_dev_type,
                                        const int* aux_dev_id);

#define MXLIB_PASSREGSIZE_STR "_passRegSize"
typedef int (*passRegSize_t)(void);

#define MXLIB_PASSREGGET_STR "_passRegGet"
typedef void (*passRegGet_t)(int pass_idx, graphPass_t* graphPass, const char** pass_name);

#define MXLIB_PASSCALLGRAPHPASS_STR "_passCallGraphPass"
typedef int (*passCallGraphPass_t)(graphPass_t graphPass, const char *in_graph,
                                   char** out_graph, const char* const* opt_keys,
                                   const char* const* opt_vals, int num_opts,
                                   const char* pass_name, const char* const* arg_names,
                                   int num_args, void* const* arg_data,
                                   const int64_t* const* arg_shapes, const int* arg_dims,
                                   const int* arg_types, const size_t* arg_IDs,
                                   const char* const* arg_dev_type, const int* arg_dev_id,
                                   const char* const* aux_names, int num_aux,
                                   void* const* aux_data, const int64_t* const* aux_shapes,
                                   const int* aux_dims, const int* aux_types,
                                   const size_t* aux_IDs, const char* const* aux_dev_type,
                                   const int* aux_dev_id, nd_malloc_t nd_malloc,
                                   const void* nd_alloc);

#define MXLIB_INITIALIZE_STR "initialize"
typedef int (*initialize_t)(int version);

#define MXLIB_OPVERSION_STR "_opVersion"
typedef int (*opVersion_t)();

#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
#define MX_INT_RET  __declspec(dllexport) int  __cdecl
#define MX_VOID_RET __declspec(dllexport) void __cdecl
#else
#define MX_INT_RET  int
#define MX_VOID_RET void
#endif

extern "C" {
  /*! \brief returns the MXNet library version this library was built against */
  MX_INT_RET _opVersion() {
    return MX_LIBRARY_VERSION;
  }

  /*! \brief returns the number of ops registered in this library */
  MX_INT_RET _opRegSize() {
    return Registry<CustomOp>::get()->size();
  }

  /*! \brief returns the operator registration at the specified index */
  MX_VOID_RET _opRegGet(int idx, const char** name, int *isSGop,
                        const char*** forward_ctx, fcomp_t** forward_fp,
                        int* forward_count, const char*** backward_ctx,
                        fcomp_t** backward_fp, int* backward_count,
                        const char*** create_op_ctx, createOpState_t** create_op_fp,
                        int* create_op_count, parseAttrs_t* parse, inferType_t* type,
                        inferSType_t* stype, inferShape_t* shape, mutateInputs_t* mutate) {
    CustomOp &op = Registry<CustomOp>::get()->get(idx);
    *name = op.name;
    *parse = op.parse_attrs;
    *type = op.infer_type;
    *stype = op.infer_storage_type;
    *shape = op.infer_shape;
    *mutate = op.mutate_inputs;
    *isSGop = op.isSGop;
    op.mapToVector();
    *forward_ctx = op.forward_ctx_cstr.data();
    *forward_fp = op.forward_fp.data();
    *forward_count = op.forward_fp.size();
    *backward_ctx = op.backward_ctx_cstr.data();
    *backward_fp = op.backward_fp.data();
    *backward_count = op.backward_fp.size();
    *create_op_ctx = op.create_op_ctx_cstr.data();
    *create_op_fp = op.create_op_fp.data();
    *create_op_count = op.create_op_fp.size();
  }

  /*! \brief calls free in the library for arrays the library allocated */
  MX_VOID_RET _opCallFree(void* ptr) {
    free(ptr);
  }

  /*! \brief returns status of calling the parse-attributes function for an operator */
  MX_INT_RET _opCallParseAttrs(parseAttrs_t parseAttrs, const char* const* keys,
                               const char* const* vals, int num,
                               int* num_in, int* num_out) {
    // create map of attributes from list
    std::unordered_map<std::string, std::string> attrs;
    for (int i = 0; i < num; i++) {
      attrs[std::string(keys[i])] = std::string(vals[i]);
    }

    return parseAttrs(attrs, num_in, num_out);
  }

  /*! \brief returns status of calling the inferShape function for an operator */
  MX_INT_RET _opCallInferShape(inferShape_t inferShape, const char* const* keys,
                               const char* const* vals, int num,
                               unsigned int** inshapes, int* indims, int num_in,
                               unsigned int*** mod_inshapes, int** mod_indims,
                               unsigned int*** outshapes, int** outdims, int num_out) {
    // create map of attributes from list
    std::unordered_map<std::string, std::string> attrs;
    for (int i = 0; i < num; i++) {
      attrs[std::string(keys[i])] = std::string(vals[i]);
    }

    // create a vector of shapes for inputs
    std::vector<std::vector<unsigned int> > in_shapes(num_in);
    for (int i = 0; i < num_in; i++) {
      for (int j = 0; j < indims[i]; j++) {
        in_shapes[i].push_back(inshapes[i][j]);
      }
    }

    // create a vector of shapes for outputs
    std::vector<std::vector<unsigned int> > out_shapes(num_out);

    int retval = inferShape(attrs, &in_shapes, &out_shapes);
    if (!retval) return retval;

    // allocate space for modified input dims, shape
    *mod_indims = static_cast<int*>(malloc(num_in * sizeof(int)));
    *mod_inshapes = static_cast<unsigned**>(malloc(num_in * sizeof(unsigned*)));

    // copy modified input shapes
    for (int i = 0; i < num_in; i++) {
      (*mod_indims)[i] = in_shapes[i].size();
      (*mod_inshapes)[i] = static_cast<unsigned*>(malloc((*mod_indims)[i] * sizeof(unsigned)));
      for (int j = 0; j < (*mod_indims)[i]; j++) {
        (*mod_inshapes)[i][j] = in_shapes[i][j];
      }
    }

    // allocate space for output dims, shape
    *outdims = static_cast<int*>(malloc(num_out * sizeof(int)));
    *outshapes = static_cast<unsigned**>(malloc(num_out * sizeof(unsigned*)));

    // copy output shapes
    for (int i = 0; i < num_out; i++) {
      (*outdims)[i] = out_shapes[i].size();
      (*outshapes)[i] = static_cast<unsigned*>(malloc((*outdims)[i] * sizeof(unsigned)));
      for (int j = 0; j < (*outdims)[i]; j++) {
        (*outshapes)[i][j] = out_shapes[i][j];
      }
    }

    return retval;
  }

  /*! \brief returns status of calling the inferType function for an operator */
  MX_INT_RET _opCallInferType(inferType_t inferType, const char* const* keys,
                              const char* const* vals, int num,
                              int* intypes, int num_in, int* outtypes, int num_out) {
    // create map of attributes from list
    std::unordered_map<std::string, std::string> attrs;
    for (int i = 0; i < num; i++) {
      attrs[std::string(keys[i])] = std::string(vals[i]);
    }

    // create a vector of types for inputs
    std::vector<int> in_types(num_in);
    for (int i = 0; i < num_in; i++) {
      in_types[i] = intypes[i];
    }

    // create a vector of types for outputs
    std::vector<int> out_types(num_out, -1);

    int retval = inferType(attrs, &in_types, &out_types);
    if (!retval)
      return retval;

    // copy modified input types
    for (int i = 0; i < num_in; i++) {
      intypes[i] = in_types[i];
    }
    // copy output types
    for (int i = 0; i < num_out; i++) {
      outtypes[i] = out_types[i];
    }

    return retval;
  }

  /*! \brief returns status of calling the inferSType function for an operator */
  MX_INT_RET _opCallInferSType(inferSType_t inferSType, const char* const* keys,
                               const char* const* vals, int num,
                               int* instypes, int num_in, int* outstypes, int num_out) {
    // create map of attributes from list
    std::unordered_map<std::string, std::string> attrs;
    for (int i = 0; i < num; i++) {
      attrs[std::string(keys[i])] = std::string(vals[i]);
    }

    // create a vector of storage types for inputs
    std::vector<int> in_stypes(num_in);
    for (int i = 0; i < num_in; i++) {
      in_stypes[i] = instypes[i];
    }

    // create a vector of storage types for outputs
    std::vector<int> out_stypes(num_out, -1);

    int retval = inferSType(attrs, &in_stypes, &out_stypes);

    if (!retval)
      return retval;

    // copy modified input storage types
    for (int i = 0; i < num_in; i++) {
      instypes[i] = in_stypes[i];
    }
    // copy output storage types
    for (int i = 0; i < num_out; i++) {
      outstypes[i] = out_stypes[i];
    }

    return retval;
  }

  /*! \brief returns status of calling the Forward/Backward function for an operator */
  MX_INT_RET _opCallFCompute(fcomp_t fcomp, const char* const* keys, const char* const* vals,
                             int num, const int64_t** inshapes, int* indims, void** indata,
                             int* intypes, size_t* inIDs, const char** indev_type, int* indev_id,
                             int num_in, const int64_t** outshapes, int* outdims, void** outdata,
                             int* outtypes, size_t* outIDs, const char** outdev_type,
                             int* outdev_id, int num_out, xpu_malloc_t cpu_malloc, void* cpu_alloc,
                             xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
                             sparse_malloc_t sparse_malloc, void* sparse_alloc,
                             int* instypes, int* outstypes, void** in_indices, void** out_indices,
                             void** in_indptr, void** out_indptr,
                             int64_t* in_indices_shapes, int64_t* out_indices_shapes,
                             int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
                             void* rng_cpu_states, void* rng_gpu_states) {
    // create map of attributes from list
    std::unordered_map<std::string, std::string> attrs;
    for (int i = 0; i < num; i++) {
      attrs[std::string(keys[i])] = std::string(vals[i]);
    }

    // create a vector of tensors for inputs
    std::vector<MXTensor> inputs(num_in);
    // create a vector for sparse inputs
    std::vector<MXSparse> in_sparse(num_in);

    for (int i = 0; i < num_in; i++) {
      // Dense representation.
      if (instypes[i] == 0) {
        inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
                            inIDs[i], MXContext(indev_type[i], indev_id[i]), kDefaultStorage);
      } else {
        // Sparse representation.
        MXStorageType type;
        if (instypes[i] == 1) {
          type = kRowSparseStorage;
          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
                           in_indices_shapes[i]);
        } else {
          type = kCSRStorage;
          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
                           in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]);
        }
        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (MXDType)intypes[i],
                            inshapes[i], indims[i], inIDs[i],
                            MXContext(indev_type[i], indev_id[i]), type);
      }
    }

    // create a vector of tensors for outputs
    std::vector<MXTensor> outputs(num_out);
    std::vector<MXSparse> out_sparse(num_out);

    for (int i = 0; i < num_out; i++) {
      // Dense representation.
      if (outstypes[i] == 0) {
        outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
                             outIDs[i], MXContext(outdev_type[i], outdev_id[i]), kDefaultStorage);
      } else {
        // Sparse representation.
        MXStorageType type;
        if (outstypes[i] == 1) {
          type = kRowSparseStorage;
          out_sparse[i].set(outdata[i], outshapes[i], outdims[i],
                            out_indices[i], out_indices_shapes[i]);
        } else {
          type = kCSRStorage;
          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
                            out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]);
        }
        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]), (MXDType)outtypes[i],
                             outshapes[i], outdims[i], outIDs[i],
                             MXContext(outdev_type[i], outdev_id[i]), type);
      }
    }

    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
                   cuda_stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);
    return fcomp(attrs, &inputs, &outputs, res);
  }

  /*! \brief returns status of calling the mutateInputs function for an operator */
  MX_INT_RET _opCallMutateInputs(mutateInputs_t mutate, const char* const* keys,
                                 const char* const* vals, int num,
                                 int** mutate_indices, int* indices_size) {
    // create map of attributes from list
    std::unordered_map<std::string, std::string> attrs;
    for (int i = 0; i < num; i++) {
      attrs[std::string(keys[i])] = std::string(vals[i]);
    }

    // create a vector of mutate input indices
    std::vector<int> mut_ind;

    int retval = mutate(attrs, &mut_ind);
    if (!retval)
      return retval;

    // output the input indices
    *indices_size = mut_ind.size();
    *mutate_indices = static_cast<int*>(malloc(*indices_size * sizeof(int)));
    for (int i = 0; i < *indices_size; i++) {
      (*mutate_indices)[i] = mut_ind[i];
    }

    return retval;
  }

  /*! \brief returns status of calling the createOpState function for an operator */
  MX_INT_RET _opCallCreateOpState(createOpState_t create_op, const char* const* keys,
                                  const char* const* vals, int num,
                                  void** state_op) {
    // create map of attributes from list
    std::unordered_map<std::string, std::string> attrs;
    for (int i = 0; i < num; i++) {
      attrs[std::string(keys[i])] = std::string(vals[i]);
    }

    // void pointer to hold custom state op instance created in custom library
    // eventually state_op pointer is populated by instance from custom library
    CustomStatefulOp** op_ptr = reinterpret_cast<CustomStatefulOp**>(state_op);
    return create_op(attrs, op_ptr);
  }

  /*! \brief returns status of calling stateful Forward/Backward for an operator */
  MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes,
                                     int* indims, void** indata, int* intypes, size_t* inIDs,
                                     const char** indev_type, int* indev_id, int num_in,
                                     const int64_t** outshapes, int* outdims, void** outdata,
                                     int* outtypes, size_t* outIDs, const char** outdev_type,
                                     int* outdev_id, int num_out, xpu_malloc_t cpu_malloc,
                                     void* cpu_alloc, xpu_malloc_t gpu_malloc, void* gpu_alloc,
                                     void* stream, sparse_malloc_t sparse_malloc,
                                     void* sparse_alloc, int* instypes, int* outstypes,
                                     void** in_indices, void** out_indices, void** in_indptr,
                                     void** out_indptr, int64_t* in_indices_shapes,
                                     int64_t* out_indices_shapes, int64_t* in_indptr_shapes,
                                     int64_t* out_indptr_shapes,
                                     void* rng_cpu_states, void* rng_gpu_states) {
    // create a vector of tensors for inputs
    std::vector<MXTensor> inputs(num_in);
    // create a vector for sparse inputs
    std::vector<MXSparse> in_sparse(num_in);

    for (int i = 0; i < num_in; i++) {
      if (instypes[i] == 0) {
        // Dense representation.
        inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
                            inIDs[i], MXContext(indev_type[i], indev_id[i]), kDefaultStorage);
      } else {
        // Sparse representation.
        MXStorageType type;
        if (instypes[i] == 1) {
          type = kRowSparseStorage;
          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
                           in_indices_shapes[i]);
        } else {
          type = kCSRStorage;
          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
                           in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]);
        }
        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (MXDType)intypes[i],
                            inshapes[i], indims[i], inIDs[i],
                            MXContext(indev_type[i], indev_id[i]), type);
      }
    }

    // create a vector of tensors for outputs
    std::vector<MXTensor> outputs(num_out);
    // create a vector for sparse outputs
    std::vector<MXSparse> out_sparse(num_out);

    for (int i = 0; i < num_out; i++) {
      if (outstypes[i] == 0) {
        // Dense representation.
        outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
                             outIDs[i], MXContext(outdev_type[i], outdev_id[i]), kDefaultStorage);
      } else {
        // Sparse representation.
        MXStorageType type;
        if (outstypes[i] == 1) {
          type = kRowSparseStorage;
          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
                            out_indices_shapes[i]);
        } else {
          type = kCSRStorage;
          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
                            out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]);
        }
        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]), (MXDType)outtypes[i],
                             outshapes[i], outdims[i], outIDs[i],
                             MXContext(outdev_type[i], outdev_id[i]), type);
      }
    }

    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
                   stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);

    CustomStatefulOp* op_ptr = reinterpret_cast<CustomStatefulOp*>(state_op);
    if (is_forward) {
      return op_ptr->Forward(&inputs, &outputs, res);
    }
    return op_ptr->Backward(&inputs, &outputs, res);
  }

  /*! \brief returns the number of partitioners registered in this library */
  MX_INT_RET _partRegSize() {
    return Registry<CustomPartitioner>::get()->size();
  }

  /* returns number of strategies registered for partitioner
   * at specified index */
  MX_INT_RET _partRegGetCount(int idx, const char** name) {
    CustomPartitioner part = Registry<CustomPartitioner>::get()->get(idx);
    *name = part.name;
    return part.strategies.size();
  }

  /*! \brief returns the partitioner registration at the specified index */
  MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char** strategy,
                          supportedOps_t* supportedOps, createSelector_t* createSelector,
                          reviewSubgraph_t* reviewSubgraph, const char** op_name) {
    CustomPartitioner part = Registry<CustomPartitioner>::get()->get(part_idx);
    *strategy = part.strategies[stg_idx];
    *op_name = part.op_names[stg_idx];
    *supportedOps = part.getSupportedOps(stg_idx);
    *createSelector = part.getCreateSelector(stg_idx);
    *reviewSubgraph = part.getReviewSubgraph(stg_idx);
  }
1763 
1765  MX_INT_RET _partCallSupportedOps(supportedOps_t supportedOps, const char *json,
1766  int num_ids, int *ids, const char* const* opt_keys,
1767  const char* const* opt_vals, int num_opts) {
1768  std::string subgraph_json(json);
1769  // create map of options from list
1770  std::unordered_map<std::string, std::string> opts;
1771  for (int i = 0; i < num_opts; i++)
1772  opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
1773 
1774  // create array of subgraph IDs for operator support
1775  std::vector<int> _ids(num_ids, -2);
1776  // call user's supportedOps function
1777  MXReturnValue retval = supportedOps(subgraph_json, &_ids, opts);
1778  if (!retval) return retval;
1779 
1780  // copy bools in ids to ints
1781  for (int i = 0; i < num_ids; i++)
1782  ids[i] = _ids[i];
1783 
1784  return retval;
1785  }
1786 
1788  MX_INT_RET _partCallCreateSelector(createSelector_t createSelector, const char *json,
1789  void** selector, const char* const* opt_keys,
1790  const char* const* opt_vals, int num_opts) {
1791  std::string symbol_json(json);
1792  // create map of options from list
1793  std::unordered_map<std::string, std::string> opts;
1794  for (int i = 0; i < num_opts; i++)
1795  opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
1796 
1797  // void pointer to hold selector instance created in custom library
1798  // eventually pointer is populated by instance from custom library
1799  CustomOpSelector** sel_ptr = reinterpret_cast<CustomOpSelector**>(selector);
1800 
1801  // call user's createSelector function
1802  return createSelector(symbol_json, sel_ptr, opts);
1803  }
1804 
1806  MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected) {
1807  CustomOpSelector* sel_ptr = reinterpret_cast<CustomOpSelector*>(sel_inst);
1808  *selected = sel_ptr->Select(nodeID);
1809  }
1810 
1812  MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID,
1813  int input_nodeID, int* selected) {
1814  CustomOpSelector* sel_ptr = reinterpret_cast<CustomOpSelector*>(sel_inst);
1815  *selected = sel_ptr->SelectInput(nodeID, input_nodeID);
1816  }
1817 
1819  MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID,
1820  int output_nodeID, int* selected) {
1821  CustomOpSelector* sel_ptr = reinterpret_cast<CustomOpSelector*>(sel_inst);
1822  *selected = sel_ptr->SelectOutput(nodeID, output_nodeID);
1823  }
1824 
1826  MX_VOID_RET _partCallFilter(void* sel_inst, int* candidates, int num_candidates,
1827  int** keep, int* num_keep) {
1828  CustomOpSelector* sel_ptr = reinterpret_cast<CustomOpSelector*>(sel_inst);
1829  std::vector<int> candidates_(num_candidates);
1830  for (int i=0; i < num_candidates; i++) {
1831  candidates_[i] = candidates[i];
1832  }
1833  std::vector<int> keep_;
1834 
1835  sel_ptr->Filter(candidates_, &keep_);
1836 
1837  *num_keep = keep_.size();
1838  *keep = static_cast<int*>(malloc(keep_.size() * sizeof(int)));
1839  for (unsigned i=0; i < keep_.size(); i++)
1840  (*keep)[i] = keep_[i];
1841  }
1842 
1844  MX_VOID_RET _partCallReset(void* sel_inst) {
1845  CustomOpSelector* sel_ptr = reinterpret_cast<CustomOpSelector*>(sel_inst);
1846  sel_ptr->Reset();
1847  }
1848 
1850  MX_INT_RET _partCallReviewSubgraph(reviewSubgraph_t reviewSubgraph, const char *json,
1851  int subgraph_id, int *accept, const char* const* opt_keys,
1852  const char* const* opt_vals, int num_opts,
1853  char*** attr_keys, char*** attr_vals, int *num_attrs,
1854  const char* const* arg_names, int num_args,
1855  void* const* arg_data, const int64_t* const* arg_shapes,
1856  const int* arg_dims, const int* arg_types,
1857  const size_t* arg_IDs, const char* const* arg_dev_type,
1858  const int* arg_dev_id,
1859  const char* const* aux_names, int num_aux,
1860  void* const* aux_data, const int64_t* const* aux_shapes,
1861  const int* aux_dims, const int* aux_types,
1862  const size_t* aux_IDs, const char* const* aux_dev_type,
1863  const int* aux_dev_id) {
1864  std::string subgraph_json(json);
1865  bool accept_bool = false;
1866  // create map of options from list
1867  std::unordered_map<std::string, std::string> opts;
1868  for (int i = 0; i < num_opts; i++)
1869  opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
1870 
1871  // create a map of named tensors for args
1872  std::unordered_map<std::string, MXTensor> args;
1873  for (int i = 0; i < num_args; i++) {
1874  std::vector<int64_t> shapes;
1875  for (int j = 0; j < arg_dims[i]; j++)
1876  shapes.push_back(arg_shapes[i][j]);
1877 
1878  MXTensor tensor(arg_data[i], shapes, (MXDType)arg_types[i],
1879  arg_IDs[i], MXContext(arg_dev_type[i], arg_dev_id[i]));
1880  args[arg_names[i]] = tensor;
1881  }
1882  // create a map of named tensors for aux
1883  std::unordered_map<std::string, MXTensor> aux;
1884  for (int i = 0; i < num_aux; i++) {
1885  std::vector<int64_t> shapes;
1886  for (int j = 0; j < aux_dims[i]; j++)
1887  shapes.push_back(aux_shapes[i][j]);
1888 
1889  MXTensor tensor(aux_data[i], shapes, (MXDType)aux_types[i],
1890  aux_IDs[i], MXContext(aux_dev_type[i], aux_dev_id[i]));
1891  aux[aux_names[i]] = tensor;
1892  }
1893 
1894  // attributes to set on subgraph node
1895  std::unordered_map<std::string, std::string> attrs;
1896 
1897  MXReturnValue retval = reviewSubgraph(subgraph_json, subgraph_id, &accept_bool,
1898  opts, &attrs, args, aux);
1899  if (!retval) return retval;
1900 
1901  *accept = accept_bool;
1902 
1903  if (attrs.size() > 0) {
1904  *num_attrs = attrs.size();
1905  // allocate space for attributes
1906  *attr_keys = static_cast<char**>(malloc(attrs.size() * sizeof(char*)));
1907  *attr_vals = static_cast<char**>(malloc(attrs.size() * sizeof(char*)));
1908 
1909  // copy attributes
1910  int i = 0;
1911  for (auto kv : attrs) {
1912  (*attr_keys)[i] = static_cast<char*>(malloc((kv.first.size()+1) * sizeof(char)));
1913  (*attr_vals)[i] = static_cast<char*>(malloc((kv.second.size()+1) * sizeof(char)));
1914  snprintf((*attr_keys)[i], kv.first.size()+1, "%s", kv.first.c_str());
1915  snprintf((*attr_vals)[i], kv.second.size()+1, "%s", kv.second.c_str());
1916  i++;
1917  }
1918  }
1919 
1920  return retval;
1921  }
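 /*
  * Editor's note: a minimal sketch of a user callback matching the
  * reviewSubgraph_t signature this wrapper invokes; the "myKey" attribute is
  * hypothetical. Setting *accept to false rejects the candidate subgraph,
  * and any entries added to attrs are copied back by the wrapper above as
  * attributes on the subgraph node.
  *
  *   MXReturnValue myReviewSubgraph(const std::string& json, int subgraph_id, bool* accept,
  *                                  const std::unordered_map<std::string, std::string>& options,
  *                                  std::unordered_map<std::string, std::string>* attrs,
  *                                  const std::unordered_map<std::string, MXTensor>& args,
  *                                  const std::unordered_map<std::string, MXTensor>& aux) {
  *     *accept = true;
  *     (*attrs)["myKey"] = "myValue";
  *     return MX_SUCCESS;
  *   }
  */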
1922 
1924  MX_INT_RET _passRegSize() {
1925  return Registry<CustomPass>::get()->size();
1926  }
1927 
1929  MX_VOID_RET _passRegGet(int pass_idx, graphPass_t* graphPass,
1930  const char** pass_name) {
1931  CustomPass pass = Registry<CustomPass>::get()->get(pass_idx);
1932  *graphPass = pass.pass;
1933  *pass_name = pass.name;
1934  }
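 /*
  * Editor's note: a sketch of the pass registration this accessor reads
  * back, assuming the REGISTER_PASS macro defined earlier in this file;
  * "custom_pass" and myGraphPass are hypothetical names.
  *
  *   REGISTER_PASS(custom_pass)
  *   .setBody(myGraphPass);
  */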
1935 
1937  MX_INT_RET _passCallGraphPass(graphPass_t graphPass, const char *json,
1938  char** graph, const char* const* opt_keys,
1939  const char* const* opt_vals, int num_opts,
1940  const char* pass_name, const char* const* arg_names, int num_args,
1941  void* const* arg_data, const int64_t* const* arg_shapes,
1942  const int* arg_dims, const int* arg_types,
1943  const size_t* arg_IDs, const char* const* arg_dev_type,
1944  const int* arg_dev_id, const char* const* aux_names, int num_aux,
1945  void* const* aux_data, const int64_t* const* aux_shapes,
1946  const int* aux_dims, const int* aux_types,
1947  const size_t* aux_IDs, const char* const* aux_dev_type,
1948  const int* aux_dev_id, nd_malloc_t nd_malloc,
1949  const void* nd_alloc) {
1950  std::string graph_json(json);
1951  const std::string* out_graph = nullptr;
1952  // create map of options from list
1953  std::unordered_map<std::string, std::string> opts;
1954  for (int i = 0; i < num_opts; i++)
1955  opts[std::string(opt_keys[i])] = std::string(opt_vals[i]);
1956 
1957  // create a map of named tensors for args
1958  std::unordered_map<std::string, MXTensor> args;
1959  for (int i = 0; i < num_args; i++) {
1960  std::vector<int64_t> shapes;
1961  for (int j = 0; j < arg_dims[i]; j++)
1962  shapes.push_back(arg_shapes[i][j]);
1963 
1964  MXTensor tensor(arg_data[i], shapes, (MXDType)arg_types[i],
1965  arg_IDs[i], MXContext(arg_dev_type[i], arg_dev_id[i]));
1966  args[arg_names[i]] = tensor;
1967  }
1968  // create a map of named tensors for aux
1969  std::unordered_map<std::string, MXTensor> aux;
1970  for (int i = 0; i < num_aux; i++) {
1971  std::vector<int64_t> shapes;
1972  for (int j = 0; j < aux_dims[i]; j++)
1973  shapes.push_back(aux_shapes[i][j]);
1974 
1975  MXTensor tensor(aux_data[i], shapes, (MXDType)aux_types[i],
1976  aux_IDs[i], MXContext(aux_dev_type[i], aux_dev_id[i]));
1977  aux[aux_names[i]] = tensor;
1978  }
1979 
1980  std::unordered_map<std::string, MXTensor> new_args, new_aux;
1981  PassResource res(&new_args, &new_aux, nd_malloc, nd_alloc);
1982  MXReturnValue retval = graphPass(graph_json, &out_graph, opts, args, aux, res);
1983  if (!retval) return retval;
1984 
1985  if (out_graph == nullptr) {
1986  std::cout << "Error calling graph pass '" << pass_name
1987  << "' returned out_graph string is null" << std::endl;
1988  return MX_FAIL;
1989  }
1990  *graph = static_cast<char*>(malloc((out_graph->length()+1) * sizeof(char)));
1991  snprintf(*graph, out_graph->length()+1, "%s", out_graph->c_str());  // null-terminates, unlike std::string::copy
1992  delete out_graph;
1993  return retval;
1994  }
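 /*
  * Editor's note: a minimal sketch of a user graph pass matching the
  * graphPass_t signature this wrapper invokes; the identity behavior is
  * illustrative. The pass must heap-allocate the output graph string, since
  * the wrapper above copies it into a C string and then deletes it.
  *
  *   MXReturnValue myGraphPass(const std::string& in_graph, const std::string** out_graph,
  *                             const std::unordered_map<std::string, std::string>& options,
  *                             const std::unordered_map<std::string, MXTensor>& args,
  *                             const std::unordered_map<std::string, MXTensor>& aux,
  *                             const PassResource& res) {
  *     *out_graph = new std::string(in_graph);  // return the graph unchanged
  *     return MX_SUCCESS;
  *   }
  */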
1995 
2003 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
2004  __declspec(dllexport) MXReturnValue __cdecl
2005 #else
2006  MXReturnValue
2007 #endif
2008  initialize(int version);
2009 }
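 /*
  * Editor's note: a sketch of the initialize function every custom library
  * must export, checking the MXNet version it is being loaded into; the
  * 10700 threshold (MXNet 1.7.0) is an illustrative cutoff, not a
  * requirement of this header.
  *
  *   MXReturnValue initialize(int version) {
  *     if (version >= 10700) {
  *       return MX_SUCCESS;
  *     }
  *     return MX_FAIL;
  *   }
  */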
2010 #endif // MXNET_LIB_API_H_