Go to the documentation of this file.
25 #ifndef MXNET_RUNTIME_PACKED_FUNC_H_
26 #define MXNET_RUNTIME_PACKED_FUNC_H_
28 #include <dmlc/logging.h>
49 #include <type_traits>
70 class MXNetArgsSetter;
124 template <
typename... Args>
136 return body_ ==
nullptr;
140 return body_ !=
nullptr;
151 template <
typename FType>
186 template <
typename R,
typename... Args>
238 template <
typename FLambda,
239 typename =
typename std::enable_if<
240 std::is_convertible<FLambda,
241 std::function<R(Args...)>>::value>::type>
243 this->AssignTypedLambda(typed_lambda);
261 template <
typename FLambda,
262 typename =
typename std::enable_if<
263 std::is_convertible<FLambda,
264 std::function<R(Args...)>>::value>::type>
266 this->AssignTypedLambda(typed_lambda);
283 inline R operator()(Args... args)
const;
299 return packed_ ==
nullptr;
303 return packed_ !=
nullptr;
317 template <
typename FLambda>
318 inline void AssignTypedLambda(FLambda flambda);
336 inline int size()
const;
360 #define MXNET_CHECK_TYPE_CODE(CODE, T) \
361 CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE)
374 template <
typename T>
383 template <
typename T>
386 using ContainerType =
typename T::ContainerType;
388 return T::_type_is_nullable;
392 using ContainerType =
typename T::ContainerType;
393 return ContainerType::_type_key;
403 operator double()
const {
413 operator int64_t()
const {
417 operator uint64_t()
const {
421 operator int()
const {
426 operator bool()
const {
430 operator void*()
const {
443 template <
typename TObjectRef,
444 typename =
typename std::enable_if<std::is_class<TObjectRef>::value>::type>
446 template <
typename TObjectRef>
457 template <
typename T>
491 using MXNetPODValue_::operator double;
492 using MXNetPODValue_::operator int64_t;
493 using MXNetPODValue_::operator uint64_t;
494 using MXNetPODValue_::operator int;
495 using MXNetPODValue_::operator bool;
496 using MXNetPODValue_::operator
void*;
497 using MXNetPODValue_::operator
ObjectRef;
502 operator std::string()
const {
505 return std::string(arr->
data, arr->
size);
529 operator ::mxnet::NDArray*()
const {
536 template <
typename FType>
543 template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
544 inline operator T()
const;
565 other.type_code_ =
kNull;
572 using MXNetPODValue_::operator double;
573 using MXNetPODValue_::operator int64_t;
574 using MXNetPODValue_::operator uint64_t;
575 using MXNetPODValue_::operator int;
576 using MXNetPODValue_::operator bool;
577 using MXNetPODValue_::operator
void*;
578 using MXNetPODValue_::operator
ObjectRef;
586 operator std::string()
const {
588 return *ptr<std::string>();
591 return *ptr<std::string>();
603 template <
typename FType>
612 other.type_code_ =
kNull;
621 this->SwitchToPOD(
kNull);
631 this->SwitchToPOD(
kDLInt);
636 this->SwitchToPOD(
kDLInt);
641 this->SwitchToPOD(
kDLInt);
663 return operator=(Downcast<NDArrayHandle, ObjectRef>(other));
667 template <
typename T>
672 template <
typename FType>
696 this->SwitchToPOD(
kPyArg);
700 template <typename T, typename = typename std::enable_if<extension_type_info<T>::code != 0>::type>
724 <<
"MXNetRetValue.value can only be used for POD data";
728 template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
729 inline operator T()
const;
732 template <
typename T>
733 void Assign(
const T& other) {
734 switch (other.type_code()) {
736 SwitchToClass<std::string>(
kStr, other);
740 SwitchToClass<std::string>(
kBytes, other);
744 *
this = other.operator ObjectRef();
749 SwitchToPOD(other.type_code());
752 LOG(FATAL) <<
"Does not support ext type";
765 template <
typename T>
775 void SwitchToObject(
int type_code, ObjectPtr<Object> other) {
776 if (other.data_ !=
nullptr) {
781 other.data_ =
nullptr;
791 delete ptr<std::string>();
799 LOG(FATAL) <<
"Does not support ext type";
808 if (s.length() == 0) {
816 const char* scan =
nullptr;
817 if (s.substr(0, 3) ==
"int") {
819 scan = s.c_str() + 3;
820 }
else if (s.substr(0, 4) ==
"uint") {
822 scan = s.c_str() + 4;
823 }
else if (s.substr(0, 5) ==
"float") {
825 scan = s.c_str() + 5;
826 }
else if (s.substr(0, 6) ==
"handle") {
829 scan = s.c_str() + 6;
830 }
else if (s ==
"bool") {
835 }
else if (s.substr(0, 6) ==
"custom") {
836 LOG(FATAL) <<
"custom MXNetDataType is not supported";
840 LOG(FATAL) <<
"unknown type " << s;
843 uint8_t bits =
static_cast<uint8_t
>(strtoul(scan, &xdelim, 10));
846 char* endpt = xdelim;
847 if (*xdelim ==
'x') {
848 t.
lanes =
static_cast<uint16_t
>(strtoul(xdelim + 1, &endpt, 10));
850 CHECK(endpt == s.c_str() + s.length()) <<
"unknown type " << s;
876 LOG(FATAL) <<
"unknown type_code=" <<
static_cast<int>(type_code);
882 if (s ==
"float32") {
884 }
else if (s ==
"float64") {
886 }
else if (s ==
"float16") {
888 }
else if (s ==
"bfloat16") {
890 }
else if (s ==
"uint8") {
892 }
else if (s ==
"int8") {
894 }
else if (s ==
"int32") {
896 }
else if (s ==
"int64") {
898 }
else if (s ==
"bool") {
900 }
else if (s ==
"int16") {
902 }
else if (s ==
"uint16") {
904 }
else if (s ==
"uint32") {
906 }
else if (s ==
"uint64") {
909 LOG(FATAL) <<
"unknown type " << s;
911 LOG(FATAL) <<
"should not reach here ";
916 if (s ==
"float32") {
918 }
else if (s ==
"float64") {
920 }
else if (s ==
"float16") {
922 }
else if (s ==
"bfloat16") {
924 }
else if (s ==
"uint8") {
926 }
else if (s ==
"int8") {
928 }
else if (s ==
"int32") {
930 }
else if (s ==
"int64") {
932 }
else if (s ==
"int16") {
934 }
else if (s ==
"uint16") {
936 }
else if (s ==
"uint32") {
938 }
else if (s ==
"uint64") {
941 LOG(FATAL) <<
"unknown type " << s;
943 LOG(FATAL) <<
"should not reach here ";
955 LOG(FATAL) <<
"custom MXNetDataType is not supported";
960 os << static_cast<int>(t.
bits);
962 os << 'x' << static_cast<int>(t.
lanes);
972 CHECK_LT(i,
num_args) <<
"not enough argument passed, " <<
num_args <<
" passed"
973 <<
" but request arg[" << i <<
"].";
992 template <
bool stop, std::
size_t I,
typename F>
994 template <
typename T,
typename... Args>
995 static void run(
const F& f, T&& value, Args&&... args) {
996 f(I, std::forward<T>(value));
1001 template <std::
size_t I,
typename F>
1006 template <
typename F,
typename... Args>
1017 template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
1019 values_[i].
v_int64 =
static_cast<int64_t
>(value);
1023 values_[i].
v_int64 =
static_cast<int64_t
>(value);
1024 CHECK_LE(value,
static_cast<uint64_t
>(std::numeric_limits<int64_t>::max()));
1033 type_codes_[i] =
kNull;
1036 values_[i] = value.
value_;
1044 values_[i].
v_str = value;
1045 type_codes_[i] =
kStr;
1051 values_[i].
v_str = value.c_str();
1052 type_codes_[i] =
kStr;
1055 values_[i].
v_type = value;
1065 template <
typename FType>
1074 type_codes_[i] =
kNull;
1079 values_[i].
v_str = value.
ptr<std::string>()->c_str();
1080 type_codes_[i] =
kStr;
1083 values_[i] = value.
value_;
1095 template <
typename... Args>
1097 const int kNumArgs =
sizeof...(Args);
1098 const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
1100 int type_codes[kArraySize];
1103 body_(
MXNetArgs(values, type_codes, kNumArgs), &rv);
1108 template <
typename R,
int nleft,
int index,
typename F>
1110 template <
typename... Args>
1114 Args&&... unpacked_args) {
1116 f, args_pack, rv, std::forward<Args>(unpacked_args)..., args_pack[index]);
1120 template <
typename R,
int index,
typename F>
1122 template <
typename... Args>
1126 Args&&... unpacked_args) {
1127 *rv = R(f(std::forward<Args>(unpacked_args)...));
1131 template <
int index,
typename F>
1133 template <
typename... Args>
1137 Args&&... unpacked_args) {
1138 f(std::forward<Args>(unpacked_args)...);
1142 template <
typename R,
int nargs,
typename F>
1147 template <
typename R,
typename... Args>
1149 return R(pf(std::forward<Args>(args)...));
1152 template <
typename R>
1154 template <
typename... Args>
1156 return pf(std::forward<Args>(args)...);
1162 template <
typename... Args>
1164 pf(std::forward<Args>(args)...);
1169 template <
typename R,
typename... Args>
1172 template <
typename R,
typename... Args>
1176 template <
typename R,
typename... Args>
1180 template <
typename R,
typename... Args>
1181 template <
typename FType>
1188 template <
typename R,
typename... Args>
1195 template <
typename T,
typename TSrc,
bool is_ext,
bool is_nd>
1198 static_assert(!is_ext && !is_nd,
"The default case accepts only non-extensions");
1199 return self->template AsObjectRef<T>();
1214 template <
typename TObjectRef>
1253 template <
typename TObjectRef>
1255 static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
1256 "Conversion only works for ObjectRef");
1257 using ContainerType =
typename TObjectRef::ContainerType;
1260 CHECK(TObjectRef::_type_is_nullable)
1261 <<
"Expect a not null value of " << ContainerType::_type_key;
1269 <<
ptr->GetTypeKey();
1270 return TObjectRef(GetObjectPtr<Object>(
ptr));
1277 template <
typename T,
typename>
1278 inline MXNetArgValue::operator T()
const {
1282 template <
typename TObjectRef,
typename>
1284 using ContainerType =
typename TObjectRef::ContainerType;
1295 #endif // MXNET_RUNTIME_PACKED_FUNC_H_
static const int code
Definition: packed_func.h:376
namespace of mxnet
Definition: api_registry.h:33
@ kDLFloat
Definition: dlpack.h:82
MXNetRetValue & operator=(NDArrayHandle value)
Definition: packed_func.h:689
The data type the tensor can hold.
Definition: dlpack.h:94
void operator()(size_t i, MXNetDataType dtype) const
Definition: packed_func.h:1058
TObjectRef AsObjectRef() const
Definition: packed_func.h:1254
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:72
base class of all object containers.
Definition: object.h:151
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:139
Common POD(plain old data) container types.
@ kUint16
Definition: base.h:361
@ kUint64
Definition: base.h:363
static T Apply(const TSrc *self)
Definition: packed_func.h:1197
void unpack_call(const F &f, const MXNetArgs &args, MXNetRetValue *rv)
Definition: packed_func.h:1143
static TObjectRef From(const MXNetRetValue &val)
Convert a TObjectRef from a return value.
Definition: packed_func.h:1229
MXNetRetValue & operator=(ObjectPtr< T > other)
Definition: packed_func.h:668
Definition: packed_func.h:1196
Definition: packed_func.h:1153
MXNetRetValue & operator=(const TypedPackedFunc< FType > &f)
Definition: packed_func.h:673
Type traits to mark whether a class is a TVM extension type.
Definition: packed_func.h:375
void * v_handle
Definition: c_runtime_api.h:75
void operator()(size_t i, const std::string &value) const
Definition: packed_func.h:1050
A custom smart pointer for Object.
Definition: object.h:346
void operator()(size_t i, DLDataType value) const
Definition: packed_func.h:1054
@ kStr
Definition: c_runtime_api.h:50
int size() const
Definition: packed_func.h:977
@ kInt8
Definition: base.h:357
A single argument value to PackedFunc, containing both type_code and MXNetValue.
Definition: packed_func.h:480
@ kUint32
Definition: base.h:362
size_t size
Definition: c_runtime_api.h:87
~MXNetRetValue()
destructor
Definition: packed_func.h:568
static void run(const F &f, T &&value, Args &&... args)
Definition: packed_func.h:995
T * ptr() const
return handle as specific pointer type.
Definition: packed_func.h:458
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:135
void operator()(size_t i, uint64_t value) const
Definition: packed_func.h:1022
Type trait to specify special value conversion rules from MXNetArgValue and MXNetRetValue.
Definition: packed_func.h:1215
#define MXNET_CHECK_TYPE_CODE(CODE, T)
convert a string to TVM type.
Definition: packed_func.h:360
MXNetRetValue & operator=(std::nullptr_t value)
Definition: packed_func.h:620
void operator()(size_t i, const MXNetRetValue &value) const
Definition: packed_func.h:1077
Internal base class to handle conversion to POD values.
Definition: packed_func.h:401
@ kNDArrayHandle
Definition: c_runtime_api.h:53
@ kCustomBegin
Definition: c_runtime_api.h:65
const char * v_str
Definition: c_runtime_api.h:76
MXNetRetValue & operator=(double value)
Definition: packed_func.h:615
int String2MXNetTypeWithBool(const std::string &s)
Definition: packed_func.h:881
PackedFunc(FType body)
constructing a packed function from a std::function.
Definition: packed_func.h:109
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:71
TObjectRef AsObjectRef() const
Definition: packed_func.h:1254
@ kExtBegin
Definition: c_runtime_api.h:58
@ kBool
Definition: base.h:359
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&... unpacked_args)
Definition: packed_func.h:1123
void CallPacked(MXNetArgs args, MXNetRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:981
MXNetRetValue & operator=(int value)
Definition: packed_func.h:635
Base expr nodes in MXNet.
MXNetRetValue(const MXNetRetValue &other)
Definition: packed_func.h:582
void operator()(size_t i, const TypedPackedFunc< FType > &value) const
Definition: packed_func.h:1066
const int * type_codes
Definition: packed_func.h:325
static void run(const F &f)
Definition: packed_func.h:1003
DLDataType String2DLDataType(std::string s)
convert a string to TVM type.
Definition: packed_func.h:805
@ kFloat64
Definition: base.h:353
@ kNull
Definition: c_runtime_api.h:46
bool IsObjectRef() const
Definition: packed_func.h:1283
bool IsInstance() const
Definition: object.h:765
MXNetValue value_
The value.
Definition: packed_func.h:469
int type_code() const
Definition: packed_func.h:448
MXNetRetValue & operator=(int64_t value)
Definition: packed_func.h:630
Return value container. Unlike MXNetArgValue, which only holds a reference and does not delete the underl...
Definition: packed_func.h:555
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:298
TypedPackedFunc()
default constructor
Definition: packed_func.h:192
MXNetArgsSetter(MXNetValue *values, int *type_codes)
Definition: packed_func.h:1015
int type_code_
the type code
Definition: packed_func.h:471
TypedPackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:194
Arguments into TVM functions.
Definition: packed_func.h:322
@ kInt16
Definition: base.h:360
TObjectRef AsObjectRef() const
Definition: packed_func.h:1254
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:152
Definition: packed_func.h:1109
MXNetRetValue & operator=(bool value)
Definition: packed_func.h:640
const char * data
Definition: c_runtime_api.h:86
MXNetRetValue & operator=(const MXNetDataType &other)
Definition: packed_func.h:654
static std::string TypeName()
Definition: packed_func.h:391
MXNetRetValue & operator=(std::string value)
Definition: packed_func.h:645
MXNetRetValue & operator=(ObjectRef other)
Definition: packed_func.h:661
A device-independent managed NDArray abstraction.
MXNetRetValue & operator=(const PythonArg &value)
Definition: packed_func.h:695
@ kObjectHandle
Definition: c_runtime_api.h:49
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&... unpacked_args)
Definition: packed_func.h:1134
MXNetRetValue & operator=(const MXNetArgValue &other)
Definition: packed_func.h:680
Common POD(plain old data) container types extension.
uint8_t code
Type code of base types. We keep it uint8_t instead of DLDataTypeCode for minimal memory footprint,...
Definition: dlpack.h:100
void operator()(size_t i, const char *value) const
Definition: packed_func.h:1043
Definition: ndarray_handle.h:31
double v_float64
Definition: c_runtime_api.h:74
@ kMXNetType
Definition: c_runtime_api.h:47
void operator()(size_t i, const MXNetArgValue &value) const
Definition: packed_func.h:1035
ndarray interface
Definition: ndarray.h:82
static R run(const PackedFunc &pf, Args &&... args)
Definition: packed_func.h:1155
MXNetRetValue & operator=(const MXNetRetValue &other)
Definition: packed_func.h:676
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:804
@ kInt64
Definition: base.h:358
MXNetRetValue()
default constructor
Definition: packed_func.h:558
uint8_t bits
Number of bits, common choices are 8, 16, 32.
Definition: dlpack.h:104
@ kDLInt
Definition: dlpack.h:80
static bool Check(const Object *ptr)
Definition: packed_func.h:385
DLDataType v_type
Definition: c_runtime_api.h:78
bool IsObjectRef() const
Definition: packed_func.h:1283
@ kInt32
Definition: base.h:356
const PackedFunc & packed() const
Definition: packed_func.h:294
int num_args
Definition: packed_func.h:326
FType body() const
Definition: packed_func.h:985
MXNetRetValue & operator=(NDArray *value)
Definition: packed_func.h:684
const MXNetValue & value() const
Definition: packed_func.h:540
MXNetPODValue_()
Definition: packed_func.h:465
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:302
void operator()(size_t i, const ObjectRef &value) const
Definition: packed_func.h:1069
MXNetRetValue & operator=(DLDataType t)
Definition: packed_func.h:649
Type traits for runtime type check during FFI conversion.
Definition: packed_func.h:384
@ kDLUInt
Definition: dlpack.h:81
MXNetRetValue & operator=(const T &other)
Definition: packed_func.h:701
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:575
@ kHandle
Definition: c_runtime_api.h:45
std::function< void(MXNetArgs args, MXNetRetValue *rv)> FType
The internal std::function.
Definition: packed_func.h:100
A PackedFunc wrapper to provide typed function signature. It is backed by a PackedFunc internally.
Definition: packed_func.h:187
MXNetArgValue()
default constructor
Definition: packed_func.h:483
PackedFunc()
default constructor
Definition: packed_func.h:102
runtime::MXNetDataType MXNetDataType
Definition: data_type.h:210
Definition: ndarray_handle.h:40
Definition: packed_func.h:993
MXNetArgValue(MXNetValue value, int type_code)
constructor
Definition: packed_func.h:489
void operator()(size_t i, const MXNetByteArray &value) const
Definition: packed_func.h:1061
static bool CanConvertFrom(const MXNetArgValue &val)
Check whether an MXNetArgValue can be converted to String, i.e. it holds either a std::string or a String.
Definition: packed_func.h:1289
void MoveToCHost(MXNetValue *ret_value, int *ret_type_code)
Move the value back to front-end via C API. This marks the current container as null....
Definition: packed_func.h:714
void operator()(size_t i, void *value) const
Definition: packed_func.h:1039
static String From(const MXNetArgValue &val)
Definition: packed_func.h:1236
Base class of all object reference.
Definition: object.h:500
MXNetArgValue operator[](int i) const
Get i-th argument.
Definition: packed_func.h:971
@ kPyArg
Definition: c_runtime_api.h:52
MXNetRetValue(MXNetRetValue &&other)
move constructor from another return value.
Definition: packed_func.h:563
@ kUint8
Definition: base.h:355
MXNetArgs(const MXNetValue *values, const int *type_codes, int num_args)
constructor
Definition: packed_func.h:333
@ kBfloat16
Definition: base.h:364
static TObjectRef From(const MXNetArgValue &val)
Convert a TObjectRef from an argument value.
Definition: packed_func.h:1221
Runtime primitive data type.
Definition: data_type.h:40
R call_packed(const PackedFunc &pf, Args &&... args)
Definition: packed_func.h:1148
void for_each(const F &f, Args &&... args)
Definition: packed_func.h:1007
std::ostream & operator<<(std::ostream &out, const String &input)
Definition: container_ext.h:873
const MXNetValue & value() const
Definition: packed_func.h:722
MXNetRetValue & operator=(MXNetRetValue &&other)
Definition: packed_func.h:608
void operator()(size_t i, std::nullptr_t value) const
Definition: packed_func.h:1031
uint16_t lanes
Number of lanes in the type, used for vector types.
Definition: dlpack.h:106
bool defined() const
Definition: object.h:539
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:80
bool IsObjectRef() const
Definition: packed_func.h:1283
NDArray interface that handles array arithmetic.
TSelf & operator=(PackedFunc packed)
copy assignment operator from PackedFunc.
Definition: packed_func.h:274
void operator()(size_t i, double value) const
Definition: packed_func.h:1027
TypedPackedFunc(const FLambda &typed_lambda)
construct from a lambda function with the same signature.
Definition: packed_func.h:242
MXNetPODValue_(MXNetValue value, int type_code)
Definition: packed_func.h:466
const char * TypeCode2Str(int type_code)
Convert type code to its name.
Definition: packed_func.h:855
static void run(const PackedFunc &pf, Args &&... args)
Definition: packed_func.h:1163
MXNetRetValue & operator=(MXNetByteArray value)
Definition: packed_func.h:657
@ kBytes
Definition: c_runtime_api.h:51
Reference to string objects.
Definition: container_ext.h:490
@ kFloat16
Definition: base.h:354
PackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:104
Definition: packed_func.h:1013
uint64_t v_uint64
Definition: c_runtime_api.h:77
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&... unpacked_args)
Definition: packed_func.h:1111
const MXNetValue * values
Definition: packed_func.h:324
configuration of MXNet as well as basic data structure.
MXNetRetValue operator()(Args &&... args) const
Call packed function by directly passing in unpacked format.
Definition: packed_func.h:1096
int64_t v_int64
Definition: c_runtime_api.h:73
static String From(const MXNetRetValue &val)
Definition: packed_func.h:1244
Byte array type used to pass in a byte array when kBytes is used as the data type.
Definition: c_runtime_api.h:85
A managed object in MXNet runtime.
MXNetRetValue & operator=(void *value)
Definition: packed_func.h:625
@ kFloat32
Definition: base.h:352
int String2MXNetType(const std::string &s)
Definition: packed_func.h:915
void operator()(size_t i, T value) const
Definition: packed_func.h:1018
TSelf & operator=(FLambda typed_lambda)
copy assignment operator from typed lambda
Definition: packed_func.h:265