mxnet
packed_func.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 // Acknowledgement: This file originates from incubator-tvm
25 #ifndef MXNET_RUNTIME_PACKED_FUNC_H_
26 #define MXNET_RUNTIME_PACKED_FUNC_H_
27 
28 #include <dmlc/logging.h>
30 #include <mxnet/runtime/object.h>
31 #include <mxnet/runtime/ndarray.h>
35 #include <mxnet/node/container.h>
36 #include <mxnet/ir/expr.h>
37 #include <mxnet/ndarray.h>
38 #include <mxnet/base.h>
39 #include <functional>
40 #include <tuple>
41 #include <vector>
42 #include <string>
43 #include <limits>
44 #include <memory>
45 #include <utility>
46 #include <type_traits>
47 #include <sstream>
48 
49 namespace mxnet {
50 // forward declarations
51 // class Integer;
52 // class Expr;
53 
54 namespace runtime {
55 
61 inline DLDataType String2DLDataType(std::string s);
62 
63 // forward declarations
64 class MXNetArgs;
65 class MXNetArgValue;
66 class MXNetRetValue;
67 class MXNetArgsSetter;
68 
77 class PackedFunc {
78  public:
97  using FType = std::function<void (MXNetArgs args, MXNetRetValue* rv)>;
101  PackedFunc(std::nullptr_t null) {} // NOLINT(*)
106  explicit PackedFunc(FType body) : body_(body) {}
121  template<typename... Args>
122  inline MXNetRetValue operator()(Args&& ...args) const;
128  inline void CallPacked(MXNetArgs args, MXNetRetValue* rv) const;
130  inline FType body() const;
132  bool operator==(std::nullptr_t null) const {
133  return body_ == nullptr;
134  }
136  bool operator!=(std::nullptr_t null) const {
137  return body_ != nullptr;
138  }
139 
140  private:
142  FType body_;
143 };
144 
148 template<typename FType>
150 
183 template<typename R, typename ...Args>
184 class TypedPackedFunc<R(Args...)> {
185  public:
187  using TSelf = TypedPackedFunc<R(Args...)>;
191  TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*)
209  inline TypedPackedFunc(PackedFunc packed); // NOLINT(*)
214  inline TypedPackedFunc(const MXNetRetValue& value); // NOLINT(*)
219  inline TypedPackedFunc(const MXNetArgValue& value); // NOLINT(*)
235  template<typename FLambda,
236  typename = typename std::enable_if<
237  std::is_convertible<FLambda,
238  std::function<R(Args...)>
239  >::value>::type>
240  TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
241  this->AssignTypedLambda(typed_lambda);
242  }
259  template<typename FLambda,
260  typename = typename std::enable_if<
261  std::is_convertible<FLambda,
262  std::function<R(Args...)>
263  >::value>::type>
264  TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
265  this->AssignTypedLambda(typed_lambda);
266  return *this;
267  }
274  packed_ = packed;
275  return *this;
276  }
282  inline R operator()(Args ...args) const;
287  operator PackedFunc() const {
288  return packed();
289  }
293  const PackedFunc& packed() const {
294  return packed_;
295  }
297  bool operator==(std::nullptr_t null) const {
298  return packed_ == nullptr;
299  }
301  bool operator!=(std::nullptr_t null) const {
302  return packed_ != nullptr;
303  }
304 
305  private:
306  friend class MXNetRetValue;
308  PackedFunc packed_;
316  template<typename FLambda>
317  inline void AssignTypedLambda(FLambda flambda);
318 };
319 
321 class MXNetArgs {
322  public:
324  const int* type_codes;
325  int num_args;
332  MXNetArgs(const MXNetValue* values,
333  const int* type_codes,
334  int num_args)
335  : values(values),
336  type_codes(type_codes),
337  num_args(num_args) { }
339  inline int size() const;
345  inline MXNetArgValue operator[](int i) const;
346 };
347 
353 inline const char* TypeCode2Str(int type_code);
354 
360 // inline TVMType String2TVMType(std::string s);
361 
362 // macro to check type code.
// Check that a runtime type code equals the expected one. On mismatch the
// failure message names both the expected and the actual type (via
// TypeCode2Str). The expansion ends in a CHECK_EQ stream expression with no
// trailing semicolon, so call sites terminate it with ';' themselves.
// NOTE: the last line intentionally has no trailing '\' so the following
// source line can never be spliced into the macro body.
#define MXNET_CHECK_TYPE_CODE(CODE, T)                         \
  CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T)         \
                    << " but get " << TypeCode2Str(CODE)
367 
378 template<typename T>
380  static const int code = 0;
381 };
382 
388  public:
389  operator double() const {
390  // Allow automatic conversion from int to float
391  // This avoids errors when user pass in int from
392  // the frontend while the API expects a float.
393  if (type_code_ == kDLInt) {
394  return static_cast<double>(value_.v_int64);
395  }
396  MXNET_CHECK_TYPE_CODE(type_code_, kDLFloat);
397  return value_.v_float64;
398  }
399  operator int64_t() const {
400  MXNET_CHECK_TYPE_CODE(type_code_, kDLInt);
401  return value_.v_int64;
402  }
403  operator uint64_t() const {
404  MXNET_CHECK_TYPE_CODE(type_code_, kDLInt);
405  return value_.v_int64;
406  }
407  operator int() const {
408  MXNET_CHECK_TYPE_CODE(type_code_, kDLInt);
409  CHECK_LE(value_.v_int64,
410  std::numeric_limits<int>::max());
411  return static_cast<int>(value_.v_int64);
412  }
413  operator bool() const {
414  MXNET_CHECK_TYPE_CODE(type_code_, kDLInt);
415  return value_.v_int64 != 0;
416  }
417  operator void*() const {
418  if (type_code_ == kNull) return nullptr;
419  if (type_code_ == kArrayHandle) return value_.v_handle;
420  MXNET_CHECK_TYPE_CODE(type_code_, kHandle);
421  return value_.v_handle;
422  }
423  operator ObjectRef() const {
424  if (type_code_ == kNull) {
425  return ObjectRef(ObjectPtr<Object>(nullptr));
426  }
428  return ObjectRef(
429  ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
430  }
431  template<typename TObjectRef,
432  typename = typename std::enable_if<
433  std::is_class<TObjectRef>::value>::type>
434  inline bool IsObjectRef() const;
435  int type_code() const {
436  return type_code_;
437  }
438 
444  template<typename T>
445  T* ptr() const {
446  return static_cast<T*>(value_.v_handle);
447  }
448 
449  protected:
450  friend class MXNetArgsSetter;
451  friend class MXNetRetValue;
452  MXNetPODValue_() : type_code_(kNull) {}
453  MXNetPODValue_(MXNetValue value, int type_code)
454  : value_(value), type_code_(type_code) {}
455 
460 };
461 
469  public:
477  MXNetArgValue(MXNetValue value, int type_code)
478  : MXNetPODValue_(value, type_code) {
479  }
480  // reuse converter from parent
481  using MXNetPODValue_::operator double;
482  using MXNetPODValue_::operator int64_t;
483  using MXNetPODValue_::operator uint64_t;
484  using MXNetPODValue_::operator int;
485  using MXNetPODValue_::operator bool;
486  using MXNetPODValue_::operator void*;
487  using MXNetPODValue_::operator ObjectRef;
489 
490  // conversion operator.
491  operator std::string() const {
492  if (type_code_ == kBytes) {
493  MXNetByteArray* arr = static_cast<MXNetByteArray*>(value_.v_handle);
494  return std::string(arr->data, arr->size);
495  } else {
496  MXNET_CHECK_TYPE_CODE(type_code_, kStr);
497  return std::string(value_.v_str);
498  }
499  }
500  operator DLDataType() const {
501  if (type_code_ == kStr) {
502  return String2DLDataType(operator std::string());
503  }
504  // None type
505  if (type_code_ == kNull) {
506  DLDataType t;
507  t.code = kHandle; t.bits = 0; t.lanes = 0;
508  return t;
509  }
510  MXNET_CHECK_TYPE_CODE(type_code_, kMXNetType);
511  return value_.v_type;
512  }
513  operator MXNetDataType() const {
514  return MXNetDataType(operator DLDataType());
515  }
516  operator ::mxnet::NDArray*() const {
517  if (type_code_ == kNull) {
518  return nullptr;
519  }
521  return reinterpret_cast<::mxnet::NDArray*>(value_.v_handle);
522  }
523  operator PackedFunc() const {
524  if (type_code_ == kNull) return PackedFunc();
525  MXNET_CHECK_TYPE_CODE(type_code_, kFuncHandle);
526  return *ptr<PackedFunc>();
527  }
528  template<typename FType>
529  operator TypedPackedFunc<FType>() const {
530  return TypedPackedFunc<FType>(operator PackedFunc());
531  }
532  const MXNetValue& value() const {
533  return value_;
534  }
535  // Deferred extension handler.
536  template<typename TObjectRef>
537  inline TObjectRef AsObjectRef() const;
538  template<typename T,
539  typename = typename std::enable_if<
540  std::is_class<T>::value>::type>
541  inline operator T() const;
542 };
543 
553  public:
561  : MXNetPODValue_(other.value_, other.type_code_) {
562  other.value_.v_handle = nullptr;
563  other.type_code_ = kNull;
564  }
567  this->Clear();
568  }
569  // reuse converter from parent
570  using MXNetPODValue_::operator double;
571  using MXNetPODValue_::operator int64_t;
572  using MXNetPODValue_::operator uint64_t;
573  using MXNetPODValue_::operator int;
574  using MXNetPODValue_::operator bool;
575  using MXNetPODValue_::operator void*;
576  using MXNetPODValue_::operator ObjectRef;
578 
580  this->Assign(other);
581  }
582  // conversion operators
583  operator std::string() const {
584  if (type_code_ == kBytes) {
585  return *ptr<std::string>();
586  }
587  MXNET_CHECK_TYPE_CODE(type_code_, kStr);
588  return *ptr<std::string>();
589  }
590  operator DLDataType() const {
591  if (type_code_ == kStr) {
592  return String2DLDataType(operator std::string());
593  }
594  MXNET_CHECK_TYPE_CODE(type_code_, kMXNetType);
595  return value_.v_type;
596  }
597  operator MXNetDataType() const {
598  return MXNetDataType(operator DLDataType());
599  }
600  operator PackedFunc() const {
601  if (type_code_ == kNull) return PackedFunc();
602  MXNET_CHECK_TYPE_CODE(type_code_, kFuncHandle);
603  return *ptr<PackedFunc>();
604  }
605  template<typename FType>
606  operator TypedPackedFunc<FType>() const {
607  return TypedPackedFunc<FType>(operator PackedFunc());
608  }
609  // Assign operators
611  this->Clear();
612  value_ = other.value_;
613  type_code_ = other.type_code_;
614  other.type_code_ = kNull;
615  return *this;
616  }
617  MXNetRetValue& operator=(double value) {
618  this->SwitchToPOD(kDLFloat);
619  value_.v_float64 = value;
620  return *this;
621  }
622  MXNetRetValue& operator=(std::nullptr_t value) {
623  this->SwitchToPOD(kNull);
624  value_.v_handle = value;
625  return *this;
626  }
627  MXNetRetValue& operator=(void* value) {
628  this->SwitchToPOD(kHandle);
629  value_.v_handle = value;
630  return *this;
631  }
632  MXNetRetValue& operator=(int64_t value) {
633  this->SwitchToPOD(kDLInt);
634  value_.v_int64 = value;
635  return *this;
636  }
637  MXNetRetValue& operator=(int value) {
638  this->SwitchToPOD(kDLInt);
639  value_.v_int64 = value;
640  return *this;
641  }
642  MXNetRetValue& operator=(bool value) {
643  this->SwitchToPOD(kDLInt);
644  value_.v_int64 = value;
645  return *this;
646  }
647  MXNetRetValue& operator=(std::string value) {
648  this->SwitchToClass(kStr, value);
649  return *this;
650  }
652  this->SwitchToPOD(kMXNetType);
653  value_.v_type = t;
654  return *this;
655  }
657  return operator=(other.operator DLDataType());
658  }
660  this->SwitchToClass(kBytes, std::string(value.data, value.size));
661  return *this;
662  }
663  MXNetRetValue& operator=(ObjectRef other) {
664  return operator=(std::move(other.data_));
665  }
666  template<typename T>
668  SwitchToObject(kObjectHandle, std::move(other));
669  return *this;
670  }
672  this->SwitchToClass(kFuncHandle, f);
673  return *this;
674  }
675  template<typename FType>
677  return operator=(f.packed());
678  }
679  MXNetRetValue& operator=(const MXNetRetValue& other) { // NOLINT(*0
680  this->Assign(other);
681  return *this;
682  }
684  this->Assign(other);
685  return *this;
686  }
688  this->SwitchToPOD(kNDArrayHandle);
689  value_.v_handle = reinterpret_cast<void*>(value);
690  return *this;
691  }
692  template<typename T,
693  typename = typename std::enable_if<
694  extension_type_info<T>::code != 0>::type>
695  MXNetRetValue& operator=(const T& other) {
696  this->SwitchToClass<T>(
698  return *this;
699  }
709  void MoveToCHost(MXNetValue* ret_value,
710  int* ret_type_code) {
711  // cannot move str; need specially handle.
712  CHECK(type_code_ != kStr && type_code_ != kBytes);
713  *ret_value = value_;
714  *ret_type_code = type_code_;
715  type_code_ = kNull;
716  }
718  const MXNetValue& value() const {
719  CHECK(type_code_ != kObjectHandle &&
720  type_code_ != kFuncHandle &&
721  type_code_ != kStr) << "MXNetRetValue.value can only be used for POD data";
722  return value_;
723  }
724  // ObjectRef related extenstions: in tvm/packed_func_ext.h
725  template<typename T,
726  typename = typename std::enable_if<
727  std::is_class<T>::value>::type>
728  inline operator T() const;
729  template<typename TObjectRef>
730  inline TObjectRef AsObjectRef() const;
731 
732  private:
733  template<typename T>
734  void Assign(const T& other) {
735  switch (other.type_code()) {
736  case kStr: {
737  SwitchToClass<std::string>(kStr, other);
738  break;
739  }
740  case kBytes: {
741  SwitchToClass<std::string>(kBytes, other);
742  break;
743  }
744  case kFuncHandle: {
745  SwitchToClass<PackedFunc>(kFuncHandle, other);
746  break;
747  }
748  case kObjectHandle: {
749  *this = other.operator ObjectRef();
750  break;
751  }
752  default: {
753  if (other.type_code() < kExtBegin) {
754  SwitchToPOD(other.type_code());
755  value_ = other.value_;
756  } else {
757  LOG(FATAL) << "Does not support ext type";
758  }
759  break;
760  }
761  }
762  }
763  // get the internal container.
764  void SwitchToPOD(int type_code) {
765  if (type_code_ != type_code) {
766  this->Clear();
767  type_code_ = type_code;
768  }
769  }
770  template<typename T>
771  void SwitchToClass(int type_code, T v) {
772  if (type_code_ != type_code) {
773  this->Clear();
774  type_code_ = type_code;
775  value_.v_handle = new T(v);
776  } else {
777  *static_cast<T*>(value_.v_handle) = v;
778  }
779  }
780  void SwitchToObject(int type_code, ObjectPtr<Object> other) {
781  if (other.data_ != nullptr) {
782  this->Clear();
783  type_code_ = type_code;
784  // move the handle out
785  value_.v_handle = other.data_;
786  other.data_ = nullptr;
787  } else {
788  SwitchToPOD(kNull);
789  }
790  }
791  void Clear() {
792  if (type_code_ == kNull) return;
793  switch (type_code_) {
794  case kStr: delete ptr<std::string>(); break;
795  case kFuncHandle: delete ptr<PackedFunc>(); break;
796  case kObjectHandle: {
797  static_cast<Object*>(value_.v_handle)->DecRef();
798  break;
799  }
800  }
801  if (type_code_ > kExtBegin) {
802  LOG(FATAL) << "Does not support ext type";
803  }
804  type_code_ = kNull;
805  }
806 };
807 
808 inline DLDataType String2DLDataType(std::string s) {
809  DLDataType t;
810  // handle None type
811  if (s.length() == 0) {
812  t.bits = 0; t.lanes = 0; t.code = kHandle;
813  return t;
814  }
815  t.bits = 32; t.lanes = 1;
816  const char* scan = nullptr;
817  if (s.substr(0, 3) == "int") {
818  t.code = kDLInt; scan = s.c_str() + 3;
819  } else if (s.substr(0, 4) == "uint") {
820  t.code = kDLUInt; scan = s.c_str() + 4;
821  } else if (s.substr(0, 5) == "float") {
822  t.code = kDLFloat; scan = s.c_str() + 5;
823  } else if (s.substr(0, 6) == "handle") {
824  t.code = kHandle;
825  t.bits = 64; // handle uses 64 bit by default.
826  scan = s.c_str() + 6;
827  } else if (s == "bool") {
828  t.code = kDLUInt;
829  t.bits = 1;
830  t.lanes = 1;
831  return t;
832  } else if (s.substr(0, 6) == "custom") {
833  LOG(FATAL) << "custom MXNetDataType is not supported";
834  // t.code = ParseCustomDatatype(s, &scan);
835  } else {
836  scan = s.c_str();
837  LOG(FATAL) << "unknown type " << s;
838  }
839  char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
840  uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
841  if (bits != 0) t.bits = bits;
842  char* endpt = xdelim;
843  if (*xdelim == 'x') {
844  t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
845  }
846  CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
847  return t;
848 }
849 
850 // implementation details
851 inline const char* TypeCode2Str(int type_code) {
852  switch (type_code) {
853  case kDLInt: return "int";
854  case kDLUInt: return "uint";
855  case kDLFloat: return "float";
856  case kStr: return "str";
857  case kBytes: return "bytes";
858  case kHandle: return "handle";
859  case kNull: return "NULL";
860  case kFuncHandle: return "FunctionHandle";
861  case kObjectHandle: return "ObjectCell";
862  default: LOG(FATAL) << "unknown type_code="
863  << static_cast<int>(type_code); return "";
864  }
865 }
866 
867 inline int String2MXNetTypeWithBool(const std::string& s) {
868  if (s == "float32") {
869  return mshadow::kFloat32;
870  } else if (s == "float64") {
871  return mshadow::kFloat64;
872  } else if (s == "float16") {
873  return mshadow::kFloat16;
874  } else if (s == "uint8") {
875  return mshadow::kUint8;
876  } else if (s == "int8") {
877  return mshadow::kInt8;
878  } else if (s == "int32") {
879  return mshadow::kInt32;
880  } else if (s == "int64") {
881  return mshadow::kInt64;
882  } else if (s == "bool") {
883  return mshadow::kBool;
884  } else {
885  LOG(FATAL) << "unknown type " << s;
886  }
887  LOG(FATAL) << "should not reach here ";
888  return 0;
889 }
890 
891 inline int String2MXNetType(const std::string& s) {
892  if (s == "float32") {
893  return mshadow::kFloat32;
894  } else if (s == "float64") {
895  return mshadow::kFloat64;
896  } else if (s == "float16") {
897  return mshadow::kFloat16;
898  } else if (s == "uint8") {
899  return mshadow::kUint8;
900  } else if (s == "int8") {
901  return mshadow::kInt8;
902  } else if (s == "int32") {
903  return mshadow::kInt32;
904  } else if (s == "int64") {
905  return mshadow::kInt64;
906  } else {
907  LOG(FATAL) << "unknown type " << s;
908  }
909  LOG(FATAL) << "should not reach here ";
910  return 0;
911 }
912 
913 inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
914  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
915  os << "bool"; return os;
916  }
917  if (t.code < kCustomBegin) {
918  os << TypeCode2Str(t.code);
919  } else {
920  LOG(FATAL) << "custom MXNetDataType is not supported";
921  // os << "custom[" << GetCustomTypeName(t.code) << "]";
922  }
923  if (t.code == kHandle) return os;
924  os << static_cast<int>(t.bits);
925  if (t.lanes != 1) {
926  os << 'x' << static_cast<int>(t.lanes);
927  }
928  return os;
929 }
930 
931 inline std::ostream& operator<<(std::ostream& os, const MXNetDataType& dtype) { // NOLINT(*)
932  return os << dtype.operator DLDataType();
933 }
934 
936  CHECK_LT(i, num_args)
937  << "not enough argument passed, "
938  << num_args << " passed"
939  << " but request arg[" << i << "].";
940  return MXNetArgValue(values[i], type_codes[i]);
941 }
942 
943 inline int MXNetArgs::size() const {
944  return num_args;
945 }
946 
947 inline void PackedFunc::CallPacked(MXNetArgs args, MXNetRetValue* rv) const {
948  body_(args, rv);
949 }
950 
952  return body_;
953 }
954 
955 // internal namespace
956 namespace detail {
957 
958 template<bool stop, std::size_t I, typename F>
960  template<typename T, typename ...Args>
961  static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*)
962  f(I, std::forward<T>(value));
963  for_each_dispatcher<sizeof...(Args) == 0, (I+1), F>
964  ::run(f, std::forward<Args>(args)...);
965  }
966 };
967 
968 template<std::size_t I, typename F>
969 struct for_each_dispatcher<true, I, F> {
970  static void run(const F& f) {} // NOLINT(*)
971 };
972 
973 template<typename F, typename ...Args>
974 inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
975  for_each_dispatcher<sizeof...(Args) == 0, 0, F>
976  ::run(f, std::forward<Args>(args)...);
977 }
978 } // namespace detail
979 
980 /* \brief argument setter to PackedFunc */
982  public:
983  MXNetArgsSetter(MXNetValue* values, int* type_codes)
984  : values_(values), type_codes_(type_codes) {}
985  // setters for POD types
986  template<typename T,
987  typename = typename std::enable_if<
988  std::is_integral<T>::value>::type>
989  void operator()(size_t i, T value) const {
990  values_[i].v_int64 = static_cast<int64_t>(value);
991  type_codes_[i] = kDLInt;
992  }
993  void operator()(size_t i, uint64_t value) const {
994  values_[i].v_int64 = static_cast<int64_t>(value);
995  CHECK_LE(value,
996  static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
997  type_codes_[i] = kDLInt;
998  }
999  void operator()(size_t i, double value) const {
1000  values_[i].v_float64 = value;
1001  type_codes_[i] = kDLFloat;
1002  }
1003  void operator()(size_t i, std::nullptr_t value) const {
1004  values_[i].v_handle = value;
1005  type_codes_[i] = kNull;
1006  }
1007  void operator()(size_t i, const MXNetArgValue& value) const {
1008  values_[i] = value.value_;
1009  type_codes_[i] = value.type_code_;
1010  }
1011  void operator()(size_t i, void* value) const {
1012  values_[i].v_handle = value;
1013  type_codes_[i] = kHandle;
1014  }
1015  void operator()(size_t i, DLTensor* value) const {
1016  values_[i].v_handle = value;
1017  type_codes_[i] = kArrayHandle;
1018  }
1019  void operator()(size_t i, const char* value) const {
1020  values_[i].v_str = value;
1021  type_codes_[i] = kStr;
1022  }
1023  // setters for container type
1024  // They must be reference(instead of const ref)
1025  // to make sure they are alive in the tuple(instead of getting converted)
1026  void operator()(size_t i, const std::string& value) const { // NOLINT(*)
1027  values_[i].v_str = value.c_str();
1028  type_codes_[i] = kStr;
1029  }
1030  void operator()(size_t i, DLDataType value) const {
1031  values_[i].v_type = value;
1032  type_codes_[i] = kMXNetType;
1033  }
1034  void operator()(size_t i, MXNetDataType dtype) const {
1035  operator()(i, dtype.operator DLDataType());
1036  }
1037  void operator()(size_t i, const MXNetByteArray& value) const { // NOLINT(*)
1038  values_[i].v_handle = const_cast<MXNetByteArray*>(&value);
1039  type_codes_[i] = kBytes;
1040  }
1041  void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*)
1042  values_[i].v_handle = const_cast<PackedFunc*>(&value);
1043  type_codes_[i] = kFuncHandle;
1044  }
1045  template<typename FType>
1046  void operator()(size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*)
1047  operator()(i, value.packed());
1048  }
1049  void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
1050  if (value.defined()) {
1051  values_[i].v_handle = value.data_.data_;
1052  type_codes_[i] = kObjectHandle;
1053  } else {
1054  type_codes_[i] = kNull;
1055  }
1056  }
1057  void operator()(size_t i, const MXNetRetValue& value) const { // NOLINT(*)
1058  if (value.type_code() == kStr) {
1059  values_[i].v_str = value.ptr<std::string>()->c_str();
1060  type_codes_[i] = kStr;
1061  } else {
1062  CHECK_NE(value.type_code(), kBytes) << "not handled.";
1063  values_[i] = value.value_;
1064  type_codes_[i] = value.type_code();
1065  }
1066  }
1067 
1068  private:
1070  MXNetValue* values_;
1072  int* type_codes_;
1073 };
1074 
1075 template<typename... Args>
1076 inline MXNetRetValue PackedFunc::operator()(Args&& ...args) const {
1077  const int kNumArgs = sizeof...(Args);
1078  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
1079  MXNetValue values[kArraySize];
1080  int type_codes[kArraySize];
1081  detail::for_each(MXNetArgsSetter(values, type_codes),
1082  std::forward<Args>(args)...);
1083  MXNetRetValue rv;
1084  body_(MXNetArgs(values, type_codes, kNumArgs), &rv);
1085  return rv;
1086 }
1087 
1088 namespace detail {
1089 template<typename R, int nleft, int index, typename F>
1091  template<typename ...Args>
1092  static void run(const F& f,
1093  const MXNetArgs& args_pack,
1094  MXNetRetValue* rv,
1095  Args&&... unpacked_args) {
1097  ::run(f, args_pack, rv,
1098  std::forward<Args>(unpacked_args)...,
1099  args_pack[index]);
1100  }
1101 };
1102 
1103 template<typename R, int index, typename F>
1104 struct unpack_call_dispatcher<R, 0, index, F> {
1105  template<typename ...Args>
1106  static void run(const F& f,
1107  const MXNetArgs& args_pack,
1108  MXNetRetValue* rv,
1109  Args&&... unpacked_args) {
1110  *rv = R(f(std::forward<Args>(unpacked_args)...));
1111  }
1112 };
1113 
1114 template<int index, typename F>
1115 struct unpack_call_dispatcher<void, 0, index, F> {
1116  template<typename ...Args>
1117  static void run(const F& f,
1118  const MXNetArgs& args_pack,
1119  MXNetRetValue* rv,
1120  Args&&... unpacked_args) {
1121  f(std::forward<Args>(unpacked_args)...);
1122  }
1123 };
1124 
1125 template<typename R, int nargs, typename F>
1126 inline void unpack_call(const F& f, const MXNetArgs& args, MXNetRetValue* rv) {
1128 }
1129 
1130 template<typename R, typename ...Args>
1131 inline R call_packed(const PackedFunc& pf, Args&& ...args) {
1132  return R(pf(std::forward<Args>(args)...));
1133 }
1134 
1135 template<typename R>
1137  template<typename ...Args>
1138  static inline R run(const PackedFunc& pf, Args&& ...args) {
1139  return pf(std::forward<Args>(args)...);
1140  }
1141 };
1142 
1143 template<>
1145  template<typename ...Args>
1146  static inline void run(const PackedFunc& pf, Args&& ...args) {
1147  pf(std::forward<Args>(args)...);
1148  }
1149 };
1150 } // namespace detail
1151 
1152 template<typename R, typename ...Args>
1154  : packed_(packed) {}
1155 
1156 template<typename R, typename ...Args>
1158  : packed_(value.operator PackedFunc()) {}
1159 
1160 template<typename R, typename ...Args>
1162  : packed_(value.operator PackedFunc()) {}
1163 
1164 template<typename R, typename ...Args>
1165 template<typename FType>
1166 inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
1167  packed_ = PackedFunc([flambda](const MXNetArgs& args, MXNetRetValue* rv) {
1168  detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
1169  });
1170 }
1171 
1172 template<typename R, typename ...Args>
1173 inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
1175  ::run(packed_, std::forward<Args>(args)...);
1176 }
1177 
1178 // extension and node type handling
1179 namespace detail {
1180 template<typename T, typename TSrc, bool is_ext, bool is_nd>
1182  static T Apply(const TSrc* self) {
1183  static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
1184  return self->template AsObjectRef<T>();
1185  }
1186 };
1187 
1188 } // namespace detail
1189 
1190 template<typename T, typename>
1191 inline MXNetRetValue::operator T() const {
1192  return detail::
1196  ::Apply(this);
1197 }
1198 
1199 } // namespace runtime
1200 } // namespace mxnet
1201 #endif // MXNET_RUNTIME_PACKED_FUNC_H_
MXNetArgValue()
default constructor
Definition: packed_func.h:471
Definition: c_runtime_api.h:46
Definition: base.h:352
Definition: packed_func.h:1181
MXNetRetValue & operator=(int value)
Definition: packed_func.h:637
void operator()(size_t i, const MXNetArgValue &value) const
Definition: packed_func.h:1007
void operator()(size_t i, DLTensor *value) const
Definition: packed_func.h:1015
void * v_handle
Definition: c_runtime_api.h:79
Definition: dlpack.h:81
Definition: c_runtime_api.h:62
MXNetPODValue_()
Definition: packed_func.h:452
MXNetRetValue & operator=(std::string value)
Definition: packed_func.h:647
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&...unpacked_args)
Definition: packed_func.h:1117
std::function< void(MXNetArgs args, MXNetRetValue *rv)> FType
The internal std::function.
Definition: packed_func.h:97
Definition: c_runtime_api.h:69
Definition: c_runtime_api.h:47
The type trait indicates a subclass of TVM's NDArray. For irrelevant classes, code = -1...
Definition: ndarray.h:38
DLDataType String2DLDataType(std::string s)
convert a string to TVM type.
Definition: packed_func.h:808
void operator()(size_t i, const ObjectRef &value) const
Definition: packed_func.h:1049
namespace of mxnet
Definition: api_registry.h:33
MXNetRetValue & operator=(const MXNetDataType &other)
Definition: packed_func.h:656
int size() const
Definition: packed_func.h:943
MXNetRetValue & operator=(PackedFunc f)
Definition: packed_func.h:671
void CallPacked(MXNetArgs args, MXNetRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:947
void operator()(size_t i, const PackedFunc &value) const
Definition: packed_func.h:1041
const char * TypeCode2Str(int type_code)
Convert type code to its name.
Definition: packed_func.h:851
Arguments into TVM functions.
Definition: packed_func.h:321
MXNetRetValue & operator=(MXNetRetValue &&other)
Definition: packed_func.h:610
void operator()(size_t i, const std::string &value) const
Definition: packed_func.h:1026
void operator()(size_t i, T value) const
Definition: packed_func.h:989
int String2MXNetType(const std::string &s)
Definition: packed_func.h:891
A custom smart pointer for Object.
Definition: object.h:345
MXNetArgsSetter(MXNetValue *values, int *type_codes)
Definition: packed_func.h:983
void operator()(size_t i, const TypedPackedFunc< FType > &value) const
Definition: packed_func.h:1046
TypedPackedFunc()
default constructor
Definition: packed_func.h:189
Definition: c_runtime_api.h:53
void operator()(size_t i, const MXNetByteArray &value) const
Definition: packed_func.h:1037
void operator()(size_t i, std::nullptr_t value) const
Definition: packed_func.h:1003
MXNetRetValue(MXNetRetValue &&other)
move constructor from another return value.
Definition: packed_func.h:560
int type_code_
the type code
Definition: packed_func.h:459
MXNetRetValue & operator=(const T &other)
Definition: packed_func.h:695
MXNetRetValue & operator=(MXNetByteArray value)
Definition: packed_func.h:659
Definition: dlpack.h:80
Base class of all object reference.
Definition: object.h:499
Definition: dlpack.h:82
void operator()(size_t i, const MXNetRetValue &value) const
Definition: packed_func.h:1057
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:132
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:77
uint8_t code
Type code of base types. We keep it uint8_t instead of DLDataTypeCode for minimal memory footprint...
Definition: dlpack.h:100
Definition: c_runtime_api.h:50
void operator()(size_t i, MXNetDataType dtype) const
Definition: packed_func.h:1034
void operator()(size_t i, double value) const
Definition: packed_func.h:999
MXNetRetValue & operator=(bool value)
Definition: packed_func.h:642
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:301
MXNetRetValue & operator=(std::nullptr_t value)
Definition: packed_func.h:622
MXNetRetValue()
default constructor
Definition: packed_func.h:555
void operator()(size_t i, const char *value) const
Definition: packed_func.h:1019
MXNetRetValue & operator=(const TypedPackedFunc< FType > &f)
Definition: packed_func.h:676
Definition: packed_func.h:981
A single argument value to PackedFunc. Containing both type_code and MXNetValue.
Definition: packed_func.h:468
MXNetPODValue_(MXNetValue value, int type_code)
Definition: packed_func.h:453
MXNetArgValue operator[](int i) const
Get i-th argument.
Definition: packed_func.h:935
R call_packed(const PackedFunc &pf, Args &&...args)
Definition: packed_func.h:1131
const MXNetValue & value() const
Definition: packed_func.h:718
const MXNetValue * values
Definition: packed_func.h:323
TSelf & operator=(FLambda typed_lambda)
copy assignment operator from typed lambda
Definition: packed_func.h:264
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:149
uint8_t bits
Number of bits, common choices are 8, 16, 32.
Definition: dlpack.h:104
void operator()(size_t i, uint64_t value) const
Definition: packed_func.h:993
MXNetRetValue & operator=(int64_t value)
Definition: packed_func.h:632
PackedFunc(FType body)
constructing a packed function from a std::function.
Definition: packed_func.h:106
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:72
Definition: base.h:357
Type traits to mark if a class is tvm extension type.
Definition: packed_func.h:379
~MXNetRetValue()
destructor
Definition: packed_func.h:566
base class of all object containers.
Definition: object.h:149
Definition: base.h:359
Definition: base.h:353
MXNetRetValue & operator=(double value)
Definition: packed_func.h:617
Runtime primitive data type.
Definition: data_type.h:41
void operator()(size_t i, void *value) const
Definition: packed_func.h:1011
void unpack_call(const F &f, const MXNetArgs &args, MXNetRetValue *rv)
Definition: packed_func.h:1126
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:572
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&...unpacked_args)
Definition: packed_func.h:1106
MXNetRetValue & operator=(DLDataType t)
Definition: packed_func.h:651
static void run(const F &f, T &&value, Args &&...args)
Definition: packed_func.h:961
PackedFunc()
default constructor
Definition: packed_func.h:99
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:136
Definition: base.h:356
MXNetValue value_
The value.
Definition: packed_func.h:457
#define MXNET_CHECK_TYPE_CODE(CODE, T)
convert a string to TVM type.
Definition: packed_func.h:363
T * ptr() const
return handle as specific pointer type.
Definition: packed_func.h:445
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&...unpacked_args)
Definition: packed_func.h:1092
A device-independent managed NDArray abstraction.
A managed object in MXNet runtime.
bool defined() const
Definition: object.h:538
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:76
void MoveToCHost(MXNetValue *ret_value, int *ret_type_code)
Move the value back to front-end via C API. This marks the current container as null. The managed resources is moved to front-end and the front end should take charge in managing them.
Definition: packed_func.h:709
FType body() const
Definition: packed_func.h:951
const PackedFunc & packed() const
Definition: packed_func.h:293
Definition: base.h:355
Definition: c_runtime_api.h:55
void operator()(size_t i, DLDataType value) const
Definition: packed_func.h:1030
MXNetRetValue & operator=(ObjectRef other)
Definition: packed_func.h:663
static R run(const PackedFunc &pf, Args &&...args)
Definition: packed_func.h:1138
MXNetRetValue & operator=(void *value)
Definition: packed_func.h:627
std::ostream & operator<<(std::ostream &os, DLDataType t)
Definition: packed_func.h:913
Definition: c_runtime_api.h:51
Byte array type used to pass in byte array When kBytes is used as data type.
Definition: c_runtime_api.h:88
MXNetRetValue & operator=(ObjectPtr< T > other)
Definition: packed_func.h:667
const int * type_codes
Definition: packed_func.h:324
int type_code() const
Definition: packed_func.h:435
static T Apply(const TSrc *self)
Definition: packed_func.h:1182
Base expr nodes in MXNet.
MXNetRetValue & operator=(const MXNetRetValue &other)
Definition: packed_func.h:679
Definition: c_runtime_api.h:57
Definition: c_runtime_api.h:54
MXNetRetValue & operator=(const MXNetArgValue &other)
Definition: packed_func.h:683
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:297
Definition: packed_func.h:959
MXNetArgs(const MXNetValue *values, const int *type_codes, int num_args)
constructor
Definition: packed_func.h:332
const char * data
Definition: c_runtime_api.h:89
TSelf & operator=(PackedFunc packed)
copy assignment operator from PackedFunc.
Definition: packed_func.h:273
TypedPackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:191
Definition: base.h:354
void for_each(const F &f, Args &&...args)
Definition: packed_func.h:974
Internal base class to handle conversion to POD values.
Definition: packed_func.h:387
MXNetRetValue(const MXNetRetValue &other)
Definition: packed_func.h:579
Definition: base.h:358
MXNetRetValue & operator=(::mxnet::NDArray *value)
Definition: packed_func.h:687
runtime::MXNetDataType MXNetDataType
Definition: data_type.h:214
The data type the tensor can hold.
Definition: dlpack.h:94
uint16_t lanes
Number of lanes in the type, used for vector types.
Definition: dlpack.h:106
Plain C Tensor object, does not manage memory.
Definition: dlpack.h:112
static void run(const PackedFunc &pf, Args &&...args)
Definition: packed_func.h:1146
size_t size
Definition: c_runtime_api.h:90
ndarray interface
Definition: ndarray.h:82
PackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:101
MXNetArgValue(MXNetValue value, int type_code)
constructor
Definition: packed_func.h:477
MXNetRetValue operator()(Args &&...args) const
Call packed function by directly passing in unpacked format.
Definition: packed_func.h:1076
Return Value container, Unlike MXNetArgValue, which only holds reference and do not delete the underl...
Definition: packed_func.h:552
const MXNetValue & value() const
Definition: packed_func.h:532
static void run(const F &f)
Definition: packed_func.h:970
int String2MXNetTypeWithBool(const std::string &s)
Definition: packed_func.h:867
Definition: c_runtime_api.h:48
int num_args
Definition: packed_func.h:325
TypedPackedFunc(const FLambda &typed_lambda)
construct from a lambda function with the same signature.
Definition: packed_func.h:240
A PackedFunc wrapper to provide typed function signature. It is backed by a PackedFunc internally...
Definition: packed_func.h:184