mxnet
packed_func.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 // Acknowledgement: This file originates from incubator-tvm
25 #ifndef MXNET_RUNTIME_PACKED_FUNC_H_
26 #define MXNET_RUNTIME_PACKED_FUNC_H_
27 
28 #include <dmlc/logging.h>
30 #include <mxnet/runtime/object.h>
31 #include <mxnet/runtime/ndarray.h>
37 #include <mxnet/runtime/py_arg.h>
38 #include <mxnet/node/container.h>
39 #include <mxnet/ir/expr.h>
40 #include <mxnet/ndarray.h>
41 #include <mxnet/base.h>
42 #include <functional>
43 #include <tuple>
44 #include <vector>
45 #include <string>
46 #include <limits>
47 #include <memory>
48 #include <utility>
49 #include <type_traits>
50 #include <sstream>
51 
52 namespace mxnet {
53 // forward declarations
54 // class Integer;
55 // class Expr;
56 
57 namespace runtime {
58 
64 inline DLDataType String2DLDataType(std::string s);
65 
66 // forward declarations
67 class MXNetArgs;
68 class MXNetArgValue;
69 class MXNetRetValue;
70 class MXNetArgsSetter;
71 
80 class PackedFunc {
81  public:
100  using FType = std::function<void(MXNetArgs args, MXNetRetValue* rv)>;
104  PackedFunc(std::nullptr_t null) {} // NOLINT(*)
109  explicit PackedFunc(FType body) : body_(body) {}
124  template <typename... Args>
125  inline MXNetRetValue operator()(Args&&... args) const;
131  inline void CallPacked(MXNetArgs args, MXNetRetValue* rv) const;
133  inline FType body() const;
135  bool operator==(std::nullptr_t null) const {
136  return body_ == nullptr;
137  }
139  bool operator!=(std::nullptr_t null) const {
140  return body_ != nullptr;
141  }
142 
143  private:
145  FType body_;
146 };
147 
151 template <typename FType>
153 
186 template <typename R, typename... Args>
187 class TypedPackedFunc<R(Args...)> {
188  public:
190  using TSelf = TypedPackedFunc<R(Args...)>;
194  TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*)
212  inline TypedPackedFunc(PackedFunc packed); // NOLINT(*)
217  inline TypedPackedFunc(const MXNetRetValue& value); // NOLINT(*)
222  inline TypedPackedFunc(const MXNetArgValue& value); // NOLINT(*)
238  template <typename FLambda,
239  typename = typename std::enable_if<
240  std::is_convertible<FLambda,
241  std::function<R(Args...)>>::value>::type>
242  TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
243  this->AssignTypedLambda(typed_lambda);
244  }
261  template <typename FLambda,
262  typename = typename std::enable_if<
263  std::is_convertible<FLambda,
264  std::function<R(Args...)>>::value>::type>
265  TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
266  this->AssignTypedLambda(typed_lambda);
267  return *this;
268  }
275  packed_ = packed;
276  return *this;
277  }
283  inline R operator()(Args... args) const;
288  operator PackedFunc() const {
289  return packed();
290  }
294  const PackedFunc& packed() const {
295  return packed_;
296  }
298  bool operator==(std::nullptr_t null) const {
299  return packed_ == nullptr;
300  }
302  bool operator!=(std::nullptr_t null) const {
303  return packed_ != nullptr;
304  }
305 
306  private:
307  friend class MXNetRetValue;
309  PackedFunc packed_;
317  template <typename FLambda>
318  inline void AssignTypedLambda(FLambda flambda);
319 };
320 
322 class MXNetArgs {
323  public:
325  const int* type_codes;
326  int num_args;
333  MXNetArgs(const MXNetValue* values, const int* type_codes, int num_args)
336  inline int size() const;
342  inline MXNetArgValue operator[](int i) const;
343 };
344 
350 inline const char* TypeCode2Str(int type_code);
351 
357 // inline TVMType String2TVMType(std::string s);
358 
359 // macro to check type code.
360 #define MXNET_CHECK_TYPE_CODE(CODE, T) \
361  CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE)
362 
374 template <typename T>
376  static const int code = 0;
377 };
378 
383 template <typename T>
385  static bool Check(const Object* ptr) {
386  using ContainerType = typename T::ContainerType;
387  if (ptr == nullptr)
388  return T::_type_is_nullable;
389  return ptr->IsInstance<ContainerType>();
390  }
391  static std::string TypeName() {
392  using ContainerType = typename T::ContainerType;
393  return ContainerType::_type_key;
394  }
395 };
396 
402  public:
403  operator double() const {
404  // Allow automatic conversion from int to float
405  // This avoids errors when user pass in int from
406  // the frontend while the API expects a float.
407  if (type_code_ == kDLInt) {
408  return static_cast<double>(value_.v_int64);
409  }
411  return value_.v_float64;
412  }
413  operator int64_t() const {
415  return value_.v_int64;
416  }
417  operator uint64_t() const {
419  return value_.v_uint64;
420  }
421  operator int() const {
423  CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
424  return static_cast<int>(value_.v_int64);
425  }
426  operator bool() const {
428  return value_.v_int64 != 0;
429  }
430  operator void*() const {
431  if (type_code_ == kNull)
432  return nullptr;
434  return value_.v_handle;
435  }
436  operator ObjectRef() const {
437  if (type_code_ == kNull) {
438  return ObjectRef(ObjectPtr<Object>(nullptr));
439  }
441  return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
442  }
443  template <typename TObjectRef,
444  typename = typename std::enable_if<std::is_class<TObjectRef>::value>::type>
445  inline bool IsObjectRef() const;
446  template <typename TObjectRef>
447  inline TObjectRef AsObjectRef() const;
448  int type_code() const {
449  return type_code_;
450  }
451 
457  template <typename T>
458  T* ptr() const {
459  return static_cast<T*>(value_.v_handle);
460  }
461 
462  protected:
463  friend class MXNetArgsSetter;
464  friend class MXNetRetValue;
467 
472 };
473 
481  public:
490  // reuse converter from parent
491  using MXNetPODValue_::operator double;
492  using MXNetPODValue_::operator int64_t;
493  using MXNetPODValue_::operator uint64_t;
494  using MXNetPODValue_::operator int;
495  using MXNetPODValue_::operator bool;
496  using MXNetPODValue_::operator void*;
497  using MXNetPODValue_::operator ObjectRef;
500 
501  // conversion operator.
502  operator std::string() const {
503  if (type_code_ == kBytes) {
504  MXNetByteArray* arr = static_cast<MXNetByteArray*>(value_.v_handle);
505  return std::string(arr->data, arr->size);
506  } else {
508  return std::string(value_.v_str);
509  }
510  }
511  operator DLDataType() const {
512  if (type_code_ == kStr) {
513  return String2DLDataType(operator std::string());
514  }
515  // None type
516  if (type_code_ == kNull) {
517  DLDataType t;
518  t.code = kHandle;
519  t.bits = 0;
520  t.lanes = 0;
521  return t;
522  }
524  return value_.v_type;
525  }
526  operator MXNetDataType() const {
527  return MXNetDataType(operator DLDataType());
528  }
529  operator ::mxnet::NDArray*() const {
530  if (type_code_ == kNull) {
531  return nullptr;
532  }
534  return reinterpret_cast<::mxnet::NDArray*>(value_.v_handle);
535  }
536  template <typename FType>
537  operator TypedPackedFunc<FType>() const {
538  return TypedPackedFunc<FType>(operator PackedFunc());
539  }
540  const MXNetValue& value() const {
541  return value_;
542  }
543  template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
544  inline operator T() const;
545 };
546 
556  public:
564  other.value_.v_handle = nullptr;
565  other.type_code_ = kNull;
566  }
569  this->Clear();
570  }
571  // reuse converter from parent
572  using MXNetPODValue_::operator double;
573  using MXNetPODValue_::operator int64_t;
574  using MXNetPODValue_::operator uint64_t;
575  using MXNetPODValue_::operator int;
576  using MXNetPODValue_::operator bool;
577  using MXNetPODValue_::operator void*;
578  using MXNetPODValue_::operator ObjectRef;
581 
583  this->Assign(other);
584  }
585  // conversion operators
586  operator std::string() const {
587  if (type_code_ == kBytes) {
588  return *ptr<std::string>();
589  }
591  return *ptr<std::string>();
592  }
593  operator DLDataType() const {
594  if (type_code_ == kStr) {
595  return String2DLDataType(operator std::string());
596  }
598  return value_.v_type;
599  }
600  operator MXNetDataType() const {
601  return MXNetDataType(operator DLDataType());
602  }
603  template <typename FType>
604  operator TypedPackedFunc<FType>() const {
605  return TypedPackedFunc<FType>(operator PackedFunc());
606  }
607  // Assign operators
609  this->Clear();
610  value_ = other.value_;
611  type_code_ = other.type_code_;
612  other.type_code_ = kNull;
613  return *this;
614  }
616  this->SwitchToPOD(kDLFloat);
618  return *this;
619  }
620  MXNetRetValue& operator=(std::nullptr_t value) {
621  this->SwitchToPOD(kNull);
623  return *this;
624  }
626  this->SwitchToPOD(kHandle);
628  return *this;
629  }
631  this->SwitchToPOD(kDLInt);
632  value_.v_int64 = value;
633  return *this;
634  }
636  this->SwitchToPOD(kDLInt);
637  value_.v_int64 = value;
638  return *this;
639  }
641  this->SwitchToPOD(kDLInt);
642  value_.v_int64 = value;
643  return *this;
644  }
645  MXNetRetValue& operator=(std::string value) {
646  this->SwitchToClass(kStr, value);
647  return *this;
648  }
650  this->SwitchToPOD(kMXNetType);
651  value_.v_type = t;
652  return *this;
653  }
655  return operator=(other.operator DLDataType());
656  }
658  this->SwitchToClass(kBytes, std::string(value.data, value.size));
659  return *this;
660  }
662  if (other.as<NDArrayHandleObj>()) {
663  return operator=(Downcast<NDArrayHandle, ObjectRef>(other));
664  }
665  return operator=(std::move(other.data_));
666  }
667  template <typename T>
669  SwitchToObject(kObjectHandle, std::move(other));
670  return *this;
671  }
672  template <typename FType>
674  return operator=(f.packed());
675  }
676  MXNetRetValue& operator=(const MXNetRetValue& other) { // NOLINT(*0
677  this->Assign(other);
678  return *this;
679  }
681  this->Assign(other);
682  return *this;
683  }
685  this->SwitchToPOD(kNDArrayHandle);
686  value_.v_handle = reinterpret_cast<void*>(value);
687  return *this;
688  }
690  this->SwitchToPOD(kNDArrayHandle);
691  NDArray* arr = new NDArray(value->value);
692  value_.v_handle = reinterpret_cast<void*>(arr);
693  return *this;
694  }
696  this->SwitchToPOD(kPyArg);
697  value_.v_int64 = value.offset();
698  return *this;
699  }
700  template <typename T, typename = typename std::enable_if<extension_type_info<T>::code != 0>::type>
701  MXNetRetValue& operator=(const T& other) {
702  this->SwitchToClass<T>(extension_type_info<T>::code, other);
703  return *this;
704  }
714  void MoveToCHost(MXNetValue* ret_value, int* ret_type_code) {
715  // cannot move str; need specially handle.
716  CHECK(type_code_ != kStr && type_code_ != kBytes);
717  *ret_value = value_;
718  *ret_type_code = type_code_;
719  type_code_ = kNull;
720  }
722  const MXNetValue& value() const {
723  CHECK(type_code_ != kObjectHandle && type_code_ != kStr)
724  << "MXNetRetValue.value can only be used for POD data";
725  return value_;
726  }
727  // ObjectRef related extensions: in tvm/packed_func_ext.h
728  template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
729  inline operator T() const;
730 
731  private:
732  template <typename T>
733  void Assign(const T& other) {
734  switch (other.type_code()) {
735  case kStr: {
736  SwitchToClass<std::string>(kStr, other);
737  break;
738  }
739  case kBytes: {
740  SwitchToClass<std::string>(kBytes, other);
741  break;
742  }
743  case kObjectHandle: {
744  *this = other.operator ObjectRef();
745  break;
746  }
747  default: {
748  if (other.type_code() < kExtBegin) {
749  SwitchToPOD(other.type_code());
750  value_ = other.value_;
751  } else {
752  LOG(FATAL) << "Does not support ext type";
753  }
754  break;
755  }
756  }
757  }
758  // get the internal container.
759  void SwitchToPOD(int type_code) {
760  if (type_code_ != type_code) {
761  this->Clear();
763  }
764  }
765  template <typename T>
766  void SwitchToClass(int type_code, T v) {
767  if (type_code_ != type_code) {
768  this->Clear();
770  value_.v_handle = new T(v);
771  } else {
772  *static_cast<T*>(value_.v_handle) = v;
773  }
774  }
775  void SwitchToObject(int type_code, ObjectPtr<Object> other) {
776  if (other.data_ != nullptr) {
777  this->Clear();
779  // move the handle out
780  value_.v_handle = other.data_;
781  other.data_ = nullptr;
782  } else {
783  SwitchToPOD(kNull);
784  }
785  }
786  void Clear() {
787  if (type_code_ == kNull)
788  return;
789  switch (type_code_) {
790  case kStr:
791  delete ptr<std::string>();
792  break;
793  case kObjectHandle: {
794  static_cast<Object*>(value_.v_handle)->DecRef();
795  break;
796  }
797  }
798  if (type_code_ > kExtBegin) {
799  LOG(FATAL) << "Does not support ext type";
800  }
801  type_code_ = kNull;
802  }
803 };
804 
805 inline DLDataType String2DLDataType(std::string s) {
806  DLDataType t;
807  // handle None type
808  if (s.length() == 0) {
809  t.bits = 0;
810  t.lanes = 0;
811  t.code = kHandle;
812  return t;
813  }
814  t.bits = 32;
815  t.lanes = 1;
816  const char* scan = nullptr;
817  if (s.substr(0, 3) == "int") {
818  t.code = kDLInt;
819  scan = s.c_str() + 3;
820  } else if (s.substr(0, 4) == "uint") {
821  t.code = kDLUInt;
822  scan = s.c_str() + 4;
823  } else if (s.substr(0, 5) == "float") {
824  t.code = kDLFloat;
825  scan = s.c_str() + 5;
826  } else if (s.substr(0, 6) == "handle") {
827  t.code = kHandle;
828  t.bits = 64; // handle uses 64 bit by default.
829  scan = s.c_str() + 6;
830  } else if (s == "bool") {
831  t.code = kDLUInt;
832  t.bits = 1;
833  t.lanes = 1;
834  return t;
835  } else if (s.substr(0, 6) == "custom") {
836  LOG(FATAL) << "custom MXNetDataType is not supported";
837  // t.code = ParseCustomDatatype(s, &scan);
838  } else {
839  scan = s.c_str();
840  LOG(FATAL) << "unknown type " << s;
841  }
842  char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
843  uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
844  if (bits != 0)
845  t.bits = bits;
846  char* endpt = xdelim;
847  if (*xdelim == 'x') {
848  t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
849  }
850  CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
851  return t;
852 }
853 
854 // implementation details
855 inline const char* TypeCode2Str(int type_code) {
856  switch (type_code) {
857  case kDLInt:
858  return "int";
859  case kDLUInt:
860  return "uint";
861  case kDLFloat:
862  return "float";
863  case kStr:
864  return "str";
865  case kBytes:
866  return "bytes";
867  case kHandle:
868  return "handle";
869  case kNull:
870  return "NULL";
871  case kObjectHandle:
872  return "ObjectCell";
873  case kNDArrayHandle:
874  return "NDArray";
875  default:
876  LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
877  return "";
878  }
879 }
880 
881 inline int String2MXNetTypeWithBool(const std::string& s) {
882  if (s == "float32") {
883  return mshadow::kFloat32;
884  } else if (s == "float64") {
885  return mshadow::kFloat64;
886  } else if (s == "float16") {
887  return mshadow::kFloat16;
888  } else if (s == "bfloat16") {
889  return mshadow::kBfloat16;
890  } else if (s == "uint8") {
891  return mshadow::kUint8;
892  } else if (s == "int8") {
893  return mshadow::kInt8;
894  } else if (s == "int32") {
895  return mshadow::kInt32;
896  } else if (s == "int64") {
897  return mshadow::kInt64;
898  } else if (s == "bool") {
899  return mshadow::kBool;
900  } else if (s == "int16") {
901  return mshadow::kInt16;
902  } else if (s == "uint16") {
903  return mshadow::kUint16;
904  } else if (s == "uint32") {
905  return mshadow::kUint32;
906  } else if (s == "uint64") {
907  return mshadow::kUint64;
908  } else {
909  LOG(FATAL) << "unknown type " << s;
910  }
911  LOG(FATAL) << "should not reach here ";
912  return 0;
913 }
914 
915 inline int String2MXNetType(const std::string& s) {
916  if (s == "float32") {
917  return mshadow::kFloat32;
918  } else if (s == "float64") {
919  return mshadow::kFloat64;
920  } else if (s == "float16") {
921  return mshadow::kFloat16;
922  } else if (s == "bfloat16") {
923  return mshadow::kBfloat16;
924  } else if (s == "uint8") {
925  return mshadow::kUint8;
926  } else if (s == "int8") {
927  return mshadow::kInt8;
928  } else if (s == "int32") {
929  return mshadow::kInt32;
930  } else if (s == "int64") {
931  return mshadow::kInt64;
932  } else if (s == "int16") {
933  return mshadow::kInt16;
934  } else if (s == "uint16") {
935  return mshadow::kUint16;
936  } else if (s == "uint32") {
937  return mshadow::kUint32;
938  } else if (s == "uint64") {
939  return mshadow::kUint64;
940  } else {
941  LOG(FATAL) << "unknown type " << s;
942  }
943  LOG(FATAL) << "should not reach here ";
944  return 0;
945 }
946 
947 inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
948  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
949  os << "bool";
950  return os;
951  }
952  if (t.code < kCustomBegin) {
953  os << TypeCode2Str(t.code);
954  } else {
955  LOG(FATAL) << "custom MXNetDataType is not supported";
956  // os << "custom[" << GetCustomTypeName(t.code) << "]";
957  }
958  if (t.code == kHandle)
959  return os;
960  os << static_cast<int>(t.bits);
961  if (t.lanes != 1) {
962  os << 'x' << static_cast<int>(t.lanes);
963  }
964  return os;
965 }
966 
967 inline std::ostream& operator<<(std::ostream& os, const MXNetDataType& dtype) { // NOLINT(*)
968  return os << dtype.operator DLDataType();
969 }
970 
972  CHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed"
973  << " but request arg[" << i << "].";
974  return MXNetArgValue(values[i], type_codes[i]);
975 }
976 
977 inline int MXNetArgs::size() const {
978  return num_args;
979 }
980 
981 inline void PackedFunc::CallPacked(MXNetArgs args, MXNetRetValue* rv) const {
982  body_(args, rv);
983 }
984 
986  return body_;
987 }
988 
989 // internal namespace
990 namespace detail {
991 
992 template <bool stop, std::size_t I, typename F>
994  template <typename T, typename... Args>
995  static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*)
996  f(I, std::forward<T>(value));
997  for_each_dispatcher<sizeof...(Args) == 0, (I + 1), F>::run(f, std::forward<Args>(args)...);
998  }
999 };
1000 
1001 template <std::size_t I, typename F>
1002 struct for_each_dispatcher<true, I, F> {
1003  static void run(const F& f) {} // NOLINT(*)
1004 };
1005 
1006 template <typename F, typename... Args>
1007 inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
1008  for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(f, std::forward<Args>(args)...);
1009 }
1010 } // namespace detail
1011 
1012 /* \brief argument setter to PackedFunc */
1014  public:
1015  MXNetArgsSetter(MXNetValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {}
1016  // setters for POD types
1017  template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
1018  void operator()(size_t i, T value) const {
1019  values_[i].v_int64 = static_cast<int64_t>(value);
1020  type_codes_[i] = kDLInt;
1021  }
1022  void operator()(size_t i, uint64_t value) const {
1023  values_[i].v_int64 = static_cast<int64_t>(value);
1024  CHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
1025  type_codes_[i] = kDLInt;
1026  }
1027  void operator()(size_t i, double value) const {
1028  values_[i].v_float64 = value;
1029  type_codes_[i] = kDLFloat;
1030  }
1031  void operator()(size_t i, std::nullptr_t value) const {
1032  values_[i].v_handle = value;
1033  type_codes_[i] = kNull;
1034  }
1035  void operator()(size_t i, const MXNetArgValue& value) const {
1036  values_[i] = value.value_;
1037  type_codes_[i] = value.type_code_;
1038  }
1039  void operator()(size_t i, void* value) const {
1040  values_[i].v_handle = value;
1041  type_codes_[i] = kHandle;
1042  }
1043  void operator()(size_t i, const char* value) const {
1044  values_[i].v_str = value;
1045  type_codes_[i] = kStr;
1046  }
1047  // setters for container type
1048  // They must be reference(instead of const ref)
1049  // to make sure they are alive in the tuple(instead of getting converted)
1050  void operator()(size_t i, const std::string& value) const { // NOLINT(*)
1051  values_[i].v_str = value.c_str();
1052  type_codes_[i] = kStr;
1053  }
1054  void operator()(size_t i, DLDataType value) const {
1055  values_[i].v_type = value;
1056  type_codes_[i] = kMXNetType;
1057  }
1058  void operator()(size_t i, MXNetDataType dtype) const {
1059  operator()(i, dtype.operator DLDataType());
1060  }
1061  void operator()(size_t i, const MXNetByteArray& value) const { // NOLINT(*)
1062  values_[i].v_handle = const_cast<MXNetByteArray*>(&value);
1063  type_codes_[i] = kBytes;
1064  }
1065  template <typename FType>
1066  void operator()(size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*)
1067  operator()(i, value.packed());
1068  }
1069  void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
1070  if (value.defined()) {
1071  values_[i].v_handle = value.data_.data_;
1072  type_codes_[i] = kObjectHandle;
1073  } else {
1074  type_codes_[i] = kNull;
1075  }
1076  }
1077  void operator()(size_t i, const MXNetRetValue& value) const { // NOLINT(*)
1078  if (value.type_code() == kStr) {
1079  values_[i].v_str = value.ptr<std::string>()->c_str();
1080  type_codes_[i] = kStr;
1081  } else {
1082  CHECK_NE(value.type_code(), kBytes) << "not handled.";
1083  values_[i] = value.value_;
1084  type_codes_[i] = value.type_code();
1085  }
1086  }
1087 
1088  private:
1090  MXNetValue* values_;
1092  int* type_codes_;
1093 };
1094 
1095 template <typename... Args>
1096 inline MXNetRetValue PackedFunc::operator()(Args&&... args) const {
1097  const int kNumArgs = sizeof...(Args);
1098  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
1099  MXNetValue values[kArraySize];
1100  int type_codes[kArraySize];
1101  detail::for_each(MXNetArgsSetter(values, type_codes), std::forward<Args>(args)...);
1102  MXNetRetValue rv;
1103  body_(MXNetArgs(values, type_codes, kNumArgs), &rv);
1104  return rv;
1105 }
1106 
1107 namespace detail {
1108 template <typename R, int nleft, int index, typename F>
1110  template <typename... Args>
1111  static void run(const F& f,
1112  const MXNetArgs& args_pack,
1113  MXNetRetValue* rv,
1114  Args&&... unpacked_args) {
1116  f, args_pack, rv, std::forward<Args>(unpacked_args)..., args_pack[index]);
1117  }
1118 };
1119 
1120 template <typename R, int index, typename F>
1121 struct unpack_call_dispatcher<R, 0, index, F> {
1122  template <typename... Args>
1123  static void run(const F& f,
1124  const MXNetArgs& args_pack,
1125  MXNetRetValue* rv,
1126  Args&&... unpacked_args) {
1127  *rv = R(f(std::forward<Args>(unpacked_args)...));
1128  }
1129 };
1130 
1131 template <int index, typename F>
1132 struct unpack_call_dispatcher<void, 0, index, F> {
1133  template <typename... Args>
1134  static void run(const F& f,
1135  const MXNetArgs& args_pack,
1136  MXNetRetValue* rv,
1137  Args&&... unpacked_args) {
1138  f(std::forward<Args>(unpacked_args)...);
1139  }
1140 };
1141 
1142 template <typename R, int nargs, typename F>
1143 inline void unpack_call(const F& f, const MXNetArgs& args, MXNetRetValue* rv) {
1145 }
1146 
1147 template <typename R, typename... Args>
1148 inline R call_packed(const PackedFunc& pf, Args&&... args) {
1149  return R(pf(std::forward<Args>(args)...));
1150 }
1151 
1152 template <typename R>
1154  template <typename... Args>
1155  static inline R run(const PackedFunc& pf, Args&&... args) {
1156  return pf(std::forward<Args>(args)...);
1157  }
1158 };
1159 
1160 template <>
1162  template <typename... Args>
1163  static inline void run(const PackedFunc& pf, Args&&... args) {
1164  pf(std::forward<Args>(args)...);
1165  }
1166 };
1167 } // namespace detail
1168 
1169 template <typename R, typename... Args>
1170 TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed) : packed_(packed) {}
1171 
1172 template <typename R, typename... Args>
1174  : packed_(value.operator PackedFunc()) {}
1175 
1176 template <typename R, typename... Args>
1178  : packed_(value.operator PackedFunc()) {}
1179 
1180 template <typename R, typename... Args>
1181 template <typename FType>
1182 inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
1183  packed_ = PackedFunc([flambda](const MXNetArgs& args, MXNetRetValue* rv) {
1184  detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
1185  });
1186 }
1187 
1188 template <typename R, typename... Args>
1189 inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
1190  return detail::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...);
1191 }
1192 
1193 // extension and node type handling
1194 namespace detail {
1195 template <typename T, typename TSrc, bool is_ext, bool is_nd>
1197  static T Apply(const TSrc* self) {
1198  static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
1199  return self->template AsObjectRef<T>();
1200  }
1201 };
1202 
1203 } // namespace detail
1204 
1214 template <typename TObjectRef>
1221  static TObjectRef From(const MXNetArgValue& val) {
1222  return val.AsObjectRef<TObjectRef>();
1223  }
1229  static TObjectRef From(const MXNetRetValue& val) {
1230  return val.AsObjectRef<TObjectRef>();
1231  }
1232 };
1233 
1234 template <>
1236  static String From(const MXNetArgValue& val) {
1237  if (val.IsObjectRef<mxnet::runtime::String>()) {
1238  return val.AsObjectRef<mxnet::runtime::String>();
1239  } else {
1240  return mxnet::runtime::String(val.operator std::string());
1241  }
1242  }
1243 
1244  static String From(const MXNetRetValue& val) {
1245  if (val.IsObjectRef<mxnet::runtime::String>()) {
1246  return val.AsObjectRef<mxnet::runtime::String>();
1247  } else {
1248  return mxnet::runtime::String(val.operator std::string());
1249  }
1250  }
1251 };
1252 
1253 template <typename TObjectRef>
1254 inline TObjectRef MXNetPODValue_::AsObjectRef() const {
1255  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
1256  "Conversion only works for ObjectRef");
1257  using ContainerType = typename TObjectRef::ContainerType;
1258 
1259  if (type_code_ == kNull) {
1260  CHECK(TObjectRef::_type_is_nullable)
1261  << "Expect a not null value of " << ContainerType::_type_key;
1262  return TObjectRef(ObjectPtr<Object>(nullptr));
1263  }
1264  if (type_code_ == kObjectHandle) {
1265  // normal object type check.
1266  Object* ptr = static_cast<Object*>(value_.v_handle);
1268  << "Expect " << ObjectTypeChecker<TObjectRef>::TypeName() << " but get "
1269  << ptr->GetTypeKey();
1270  return TObjectRef(GetObjectPtr<Object>(ptr));
1271  } else {
1273  return TObjectRef(ObjectPtr<Object>(nullptr));
1274  }
1275 }
1276 
1277 template <typename T, typename>
1278 inline MXNetArgValue::operator T() const {
1279  return PackedFuncValueConverter<T>::From(*this);
1280 }
1281 
1282 template <typename TObjectRef, typename>
1283 inline bool MXNetPODValue_::IsObjectRef() const {
1284  using ContainerType = typename TObjectRef::ContainerType;
1285  return type_code_ == kObjectHandle &&
1287 }
1288 
1289 inline bool String::CanConvertFrom(const MXNetArgValue& val) {
1290  return val.type_code() == kStr || val.IsObjectRef<mxnet::runtime::String>();
1291 }
1292 
1293 } // namespace runtime
1294 } // namespace mxnet
1295 #endif // MXNET_RUNTIME_PACKED_FUNC_H_
mxnet::runtime::extension_type_info::code
static const int code
Definition: packed_func.h:376
mxnet
namespace of mxnet
Definition: api_registry.h:33
kDLFloat
@ kDLFloat
Definition: dlpack.h:82
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(NDArrayHandle value)
Definition: packed_func.h:689
DLDataType
The data type the tensor can hold.
Definition: dlpack.h:94
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, MXNetDataType dtype) const
Definition: packed_func.h:1058
mxnet::runtime::MXNetRetValue::AsObjectRef
TObjectRef AsObjectRef() const
Definition: packed_func.h:1254
MXNetValue
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:72
mxnet::runtime::Object
base class of all object containers.
Definition: object.h:151
mxnet::runtime::PackedFunc::operator!=
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:139
c_runtime_api.h
container.h
Common POD(plain old data) container types.
mshadow::kUint16
@ kUint16
Definition: base.h:361
mshadow::kUint64
@ kUint64
Definition: base.h:363
mxnet::runtime::detail::MXNetValueCast::Apply
static T Apply(const TSrc *self)
Definition: packed_func.h:1197
mxnet::runtime::detail::unpack_call
void unpack_call(const F &f, const MXNetArgs &args, MXNetRetValue *rv)
Definition: packed_func.h:1143
mxnet::runtime::PackedFuncValueConverter::From
static TObjectRef From(const MXNetRetValue &val)
Convert a TObjectRef from a return value.
Definition: packed_func.h:1229
ffi_helper.h
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(ObjectPtr< T > other)
Definition: packed_func.h:668
mxnet::runtime::detail::MXNetValueCast
Definition: packed_func.h:1196
mxnet::runtime::detail::typed_packed_call_dispatcher
Definition: packed_func.h:1153
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(const TypedPackedFunc< FType > &f)
Definition: packed_func.h:673
mxnet::runtime::extension_type_info
Type traits to mark if a class is tvm extension type.
Definition: packed_func.h:375
MXNetValue::v_handle
void * v_handle
Definition: c_runtime_api.h:75
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, const std::string &value) const
Definition: packed_func.h:1050
mxnet::runtime::ObjectPtr
A custom smart pointer for Object.
Definition: object.h:346
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, DLDataType value) const
Definition: packed_func.h:1054
kStr
@ kStr
Definition: c_runtime_api.h:50
mxnet::runtime::MXNetArgs::size
int size() const
Definition: packed_func.h:977
mshadow::kInt8
@ kInt8
Definition: base.h:357
mxnet::runtime::MXNetArgValue
A single argument value to PackedFunc. Containing both type_code and MXNetValue.
Definition: packed_func.h:480
mshadow::kUint32
@ kUint32
Definition: base.h:362
MXNetByteArray::size
size_t size
Definition: c_runtime_api.h:87
mxnet::runtime::MXNetRetValue::~MXNetRetValue
~MXNetRetValue()
destructor
Definition: packed_func.h:568
mxnet::runtime::detail::for_each_dispatcher::run
static void run(const F &f, T &&value, Args &&... args)
Definition: packed_func.h:995
mxnet::runtime::MXNetPODValue_::ptr
T * ptr() const
return handle as specific pointer type.
Definition: packed_func.h:458
mxnet::runtime::PackedFunc::operator==
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:135
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, uint64_t value) const
Definition: packed_func.h:1022
mxnet::runtime::PackedFuncValueConverter
Type trait to specify special value conversion rules from MXNetArgValue and MXNetRetValue.
Definition: packed_func.h:1215
MXNET_CHECK_TYPE_CODE
#define MXNET_CHECK_TYPE_CODE(CODE, T)
convert a string to TVM type.
Definition: packed_func.h:360
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(std::nullptr_t value)
Definition: packed_func.h:620
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, const MXNetRetValue &value) const
Definition: packed_func.h:1077
mxnet::runtime::MXNetPODValue_
Internal base class to handle conversion to POD values.
Definition: packed_func.h:401
kNDArrayHandle
@ kNDArrayHandle
Definition: c_runtime_api.h:53
kCustomBegin
@ kCustomBegin
Definition: c_runtime_api.h:65
MXNetValue::v_str
const char * v_str
Definition: c_runtime_api.h:76
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(double value)
Definition: packed_func.h:615
mxnet::runtime::String2MXNetTypeWithBool
int String2MXNetTypeWithBool(const std::string &s)
Definition: packed_func.h:881
mxnet::runtime::PackedFunc::PackedFunc
PackedFunc(FType body)
constructing a packed function from a std::function.
Definition: packed_func.h:109
mshadow::expr::F
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:71
mxnet::runtime::MXNetArgValue::AsObjectRef
TObjectRef AsObjectRef() const
Definition: packed_func.h:1254
kExtBegin
@ kExtBegin
Definition: c_runtime_api.h:58
mshadow::kBool
@ kBool
Definition: base.h:359
mxnet::runtime::detail::unpack_call_dispatcher< R, 0, index, F >::run
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&... unpacked_args)
Definition: packed_func.h:1123
mxnet::runtime::PackedFunc::CallPacked
void CallPacked(MXNetArgs args, MXNetRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:981
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(int value)
Definition: packed_func.h:635
expr.h
Base expr nodes in MXNet.
mxnet::runtime::MXNetRetValue::MXNetRetValue
MXNetRetValue(const MXNetRetValue &other)
Definition: packed_func.h:582
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, const TypedPackedFunc< FType > &value) const
Definition: packed_func.h:1066
mxnet::runtime::MXNetArgs::type_codes
const int * type_codes
Definition: packed_func.h:325
mxnet::runtime::detail::for_each_dispatcher< true, I, F >::run
static void run(const F &f)
Definition: packed_func.h:1003
mxnet::runtime::String2DLDataType
DLDataType String2DLDataType(std::string s)
convert a string to TVM type.
Definition: packed_func.h:805
mshadow::kFloat64
@ kFloat64
Definition: base.h:353
kNull
@ kNull
Definition: c_runtime_api.h:46
mxnet::runtime::MXNetPODValue_::IsObjectRef
bool IsObjectRef() const
Definition: packed_func.h:1283
mxnet::runtime::Object::IsInstance
bool IsInstance() const
Definition: object.h:765
mxnet::runtime::MXNetPODValue_::value_
MXNetValue value_
The value.
Definition: packed_func.h:469
mxnet::runtime::MXNetPODValue_::type_code
int type_code() const
Definition: packed_func.h:448
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(int64_t value)
Definition: packed_func.h:630
mxnet::runtime::MXNetRetValue
Return value container. Unlike MXNetArgValue, which only holds a reference and does not delete the underl...
Definition: packed_func.h:555
mxnet::runtime::TypedPackedFunc< R(Args...)>::operator==
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:298
mxnet::runtime::TypedPackedFunc< R(Args...)>::TypedPackedFunc
TypedPackedFunc()
default constructor
Definition: packed_func.h:192
mxnet::runtime::MXNetArgsSetter::MXNetArgsSetter
MXNetArgsSetter(MXNetValue *values, int *type_codes)
Definition: packed_func.h:1015
mxnet::runtime::MXNetPODValue_::type_code_
int type_code_
the type code
Definition: packed_func.h:471
mxnet::runtime::TypedPackedFunc< R(Args...)>::TypedPackedFunc
TypedPackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:194
mxnet::runtime::MXNetArgs
Arguments into TVM functions.
Definition: packed_func.h:322
mshadow::kInt16
@ kInt16
Definition: base.h:360
mxnet::runtime::MXNetPODValue_::AsObjectRef
TObjectRef AsObjectRef() const
Definition: packed_func.h:1254
mxnet::runtime::TypedPackedFunc
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:152
mxnet::runtime::detail::unpack_call_dispatcher
Definition: packed_func.h:1109
py_arg.h
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(bool value)
Definition: packed_func.h:640
MXNetByteArray::data
const char * data
Definition: c_runtime_api.h:86
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(const MXNetDataType &other)
Definition: packed_func.h:654
mxnet::runtime::ObjectTypeChecker::TypeName
static std::string TypeName()
Definition: packed_func.h:391
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(std::string value)
Definition: packed_func.h:645
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(ObjectRef other)
Definition: packed_func.h:661
data_type.h
ndarray.h
A device-independent managed NDArray abstraction.
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(const PythonArg &value)
Definition: packed_func.h:695
kObjectHandle
@ kObjectHandle
Definition: c_runtime_api.h:49
mxnet::runtime::detail::unpack_call_dispatcher< void, 0, index, F >::run
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&... unpacked_args)
Definition: packed_func.h:1134
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(const MXNetArgValue &other)
Definition: packed_func.h:680
container_ext.h
Common POD(plain old data) container types extension.
DLDataType::code
uint8_t code
Type code of base types. We keep it uint8_t instead of DLDataTypeCode for minimal memory footprint,...
Definition: dlpack.h:100
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, const char *value) const
Definition: packed_func.h:1043
mxnet::NDArrayHandleObj
Definition: ndarray_handle.h:31
MXNetValue::v_float64
double v_float64
Definition: c_runtime_api.h:74
kMXNetType
@ kMXNetType
Definition: c_runtime_api.h:47
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, const MXNetArgValue &value) const
Definition: packed_func.h:1035
mxnet::NDArray
ndarray interface
Definition: ndarray.h:82
mxnet::runtime::detail::typed_packed_call_dispatcher::run
static R run(const PackedFunc &pf, Args &&... args)
Definition: packed_func.h:1155
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(const MXNetRetValue &other)
Definition: packed_func.h:676
mxnet::runtime::ObjectRef::as
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:804
mshadow::kInt64
@ kInt64
Definition: base.h:358
mxnet::runtime::MXNetRetValue::MXNetRetValue
MXNetRetValue()
default constructor
Definition: packed_func.h:558
DLDataType::bits
uint8_t bits
Number of bits, common choices are 8, 16, 32.
Definition: dlpack.h:104
kDLInt
@ kDLInt
Definition: dlpack.h:80
mxnet::runtime::ObjectTypeChecker::Check
static bool Check(const Object *ptr)
Definition: packed_func.h:385
MXNetValue::v_type
DLDataType v_type
Definition: c_runtime_api.h:78
mxnet::runtime::MXNetArgValue::IsObjectRef
bool IsObjectRef() const
Definition: packed_func.h:1283
mshadow::kInt32
@ kInt32
Definition: base.h:356
mxnet::runtime::TypedPackedFunc< R(Args...)>::packed
const PackedFunc & packed() const
Definition: packed_func.h:294
mxnet::runtime::MXNetArgs::num_args
int num_args
Definition: packed_func.h:326
mxnet::runtime::PackedFunc::body
FType body() const
Definition: packed_func.h:985
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(NDArray *value)
Definition: packed_func.h:684
mxnet::runtime::MXNetArgValue::value
const MXNetValue & value() const
Definition: packed_func.h:540
mxnet::runtime::MXNetPODValue_::MXNetPODValue_
MXNetPODValue_()
Definition: packed_func.h:465
mxnet::runtime::TypedPackedFunc< R(Args...)>::operator!=
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:302
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, const ObjectRef &value) const
Definition: packed_func.h:1069
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(DLDataType t)
Definition: packed_func.h:649
mxnet::runtime::ObjectTypeChecker
Type traits for runtime type check during FFI conversion.
Definition: packed_func.h:384
kDLUInt
@ kDLUInt
Definition: dlpack.h:81
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(const T &other)
Definition: packed_func.h:701
mxnet::runtime::ObjectRef::data_
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:575
kHandle
@ kHandle
Definition: c_runtime_api.h:45
mxnet::runtime::PackedFunc::FType
std::function< void(MXNetArgs args, MXNetRetValue *rv)> FType
The internal std::function.
Definition: packed_func.h:100
mxnet::runtime::TypedPackedFunc< R(Args...)>
A PackedFunc wrapper to provide typed function signature. It is backed by a PackedFunc internally.
Definition: packed_func.h:187
mxnet::runtime::MXNetArgValue::MXNetArgValue
MXNetArgValue()
default constructor
Definition: packed_func.h:483
mxnet::runtime::PackedFunc::PackedFunc
PackedFunc()
default constructor
Definition: packed_func.h:102
mxnet::MXNetDataType
runtime::MXNetDataType MXNetDataType
Definition: data_type.h:210
mxnet::NDArrayHandle
Definition: ndarray_handle.h:40
mxnet::runtime::detail::for_each_dispatcher
Definition: packed_func.h:993
mxnet::runtime::MXNetArgValue::MXNetArgValue
MXNetArgValue(MXNetValue value, int type_code)
constructor
Definition: packed_func.h:489
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, const MXNetByteArray &value) const
Definition: packed_func.h:1061
mxnet::runtime::String::CanConvertFrom
static bool CanConvertFrom(const MXNetArgValue &val)
Check if a MXNetArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:1289
mxnet::runtime::MXNetRetValue::MoveToCHost
void MoveToCHost(MXNetValue *ret_value, int *ret_type_code)
Move the value back to front-end via C API. This marks the current container as null....
Definition: packed_func.h:714
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, void *value) const
Definition: packed_func.h:1039
mxnet::runtime::PackedFuncValueConverter<::mxnet::runtime::String >::From
static String From(const MXNetArgValue &val)
Definition: packed_func.h:1236
mxnet::runtime::ObjectRef
Base class of all object reference.
Definition: object.h:500
mxnet::runtime::MXNetArgs::operator[]
MXNetArgValue operator[](int i) const
Get i-th argument.
Definition: packed_func.h:971
kPyArg
@ kPyArg
Definition: c_runtime_api.h:52
mxnet::runtime::MXNetRetValue::MXNetRetValue
MXNetRetValue(MXNetRetValue &&other)
move constructor from another return value.
Definition: packed_func.h:563
mshadow::kUint8
@ kUint8
Definition: base.h:355
mxnet::runtime::MXNetArgs::MXNetArgs
MXNetArgs(const MXNetValue *values, const int *type_codes, int num_args)
constructor
Definition: packed_func.h:333
mshadow::kBfloat16
@ kBfloat16
Definition: base.h:364
ndarray_handle.h
NDArray handle types.
mxnet::runtime::PackedFuncValueConverter::From
static TObjectRef From(const MXNetArgValue &val)
Convert a TObjectRef from an argument value.
Definition: packed_func.h:1221
mxnet::runtime::PythonArg
Definition: py_arg.h:29
mxnet::runtime::MXNetDataType
Runtime primitive data type.
Definition: data_type.h:40
mxnet::runtime::detail::call_packed
R call_packed(const PackedFunc &pf, Args &&... args)
Definition: packed_func.h:1148
mxnet::runtime::detail::for_each
void for_each(const F &f, Args &&... args)
Definition: packed_func.h:1007
mxnet::runtime::operator<<
std::ostream & operator<<(std::ostream &out, const String &input)
Definition: container_ext.h:873
mxnet::runtime::MXNetRetValue::value
const MXNetValue & value() const
Definition: packed_func.h:722
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(MXNetRetValue &&other)
Definition: packed_func.h:608
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, std::nullptr_t value) const
Definition: packed_func.h:1031
DLDataType::lanes
uint16_t lanes
Number of lanes in the type, used for vector types.
Definition: dlpack.h:106
mxnet::runtime::ObjectRef::defined
bool defined() const
Definition: object.h:539
mxnet::runtime::PackedFunc
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:80
mxnet::runtime::MXNetRetValue::IsObjectRef
bool IsObjectRef() const
Definition: packed_func.h:1283
ndarray.h
NDArray interface that handles array arithmetic.
mxnet::runtime::TypedPackedFunc< R(Args...)>::operator=
TSelf & operator=(PackedFunc packed)
copy assignment operator from PackedFunc.
Definition: packed_func.h:274
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, double value) const
Definition: packed_func.h:1027
mxnet::runtime::TypedPackedFunc< R(Args...)>::TypedPackedFunc
TypedPackedFunc(const FLambda &typed_lambda)
construct from a lambda function with the same signature.
Definition: packed_func.h:242
mxnet::runtime::MXNetPODValue_::MXNetPODValue_
MXNetPODValue_(MXNetValue value, int type_code)
Definition: packed_func.h:466
mxnet::runtime::TypeCode2Str
const char * TypeCode2Str(int type_code)
Convert type code to its name.
Definition: packed_func.h:855
mxnet::runtime::detail::typed_packed_call_dispatcher< void >::run
static void run(const PackedFunc &pf, Args &&... args)
Definition: packed_func.h:1163
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(MXNetByteArray value)
Definition: packed_func.h:657
kBytes
@ kBytes
Definition: c_runtime_api.h:51
mxnet::runtime::String
Reference to string objects.
Definition: container_ext.h:490
mshadow::kFloat16
@ kFloat16
Definition: base.h:354
mxnet::runtime::PackedFunc::PackedFunc
PackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:104
mxnet::runtime::MXNetArgsSetter
Definition: packed_func.h:1013
MXNetValue::v_uint64
uint64_t v_uint64
Definition: c_runtime_api.h:77
mxnet::runtime::detail::unpack_call_dispatcher::run
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&... unpacked_args)
Definition: packed_func.h:1111
mxnet::runtime::MXNetArgs::values
const MXNetValue * values
Definition: packed_func.h:324
base.h
configuration of MXNet as well as basic data structure.
mxnet::runtime::PackedFunc::operator()
MXNetRetValue operator()(Args &&... args) const
Call packed function by directly passing in unpacked format.
Definition: packed_func.h:1096
MXNetValue::v_int64
int64_t v_int64
Definition: c_runtime_api.h:73
mxnet::runtime::PackedFuncValueConverter<::mxnet::runtime::String >::From
static String From(const MXNetRetValue &val)
Definition: packed_func.h:1244
MXNetByteArray
Byte array type used to pass in a byte array when kBytes is used as the data type.
Definition: c_runtime_api.h:85
object.h
A managed object in MXNet runtime.
container.h
Array container.
mxnet::runtime::MXNetRetValue::operator=
MXNetRetValue & operator=(void *value)
Definition: packed_func.h:625
mshadow::kFloat32
@ kFloat32
Definition: base.h:352
mxnet::runtime::String2MXNetType
int String2MXNetType(const std::string &s)
Definition: packed_func.h:915
mxnet::runtime::MXNetArgsSetter::operator()
void operator()(size_t i, T value) const
Definition: packed_func.h:1018
mxnet::runtime::TypedPackedFunc< R(Args...)>::operator=
TSelf & operator=(FLambda typed_lambda)
copy assignment operator from typed lambda
Definition: packed_func.h:265