Go to the documentation of this file.
24 #ifndef MXNET_RUNTIME_OBJECT_H_
25 #define MXNET_RUNTIME_OBJECT_H_
27 #include <dmlc/logging.h>
28 #include <type_traits>
39 #ifndef MXNET_OBJECT_ATOMIC_REF_COUNTER
40 #define MXNET_OBJECT_ATOMIC_REF_COUNTER 1
43 #if MXNET_OBJECT_ATOMIC_REF_COUNTER
45 #endif // MXNET_OBJECT_ATOMIC_REF_COUNTER
180 template <
typename TargetType>
202 #if MXNET_OBJECT_ATOMIC_REF_COUNTER
259 "RefCounter ABI check.");
279 uint32_t static_tindex,
280 uint32_t parent_tindex,
281 uint32_t type_child_slots,
282 bool type_child_slots_can_overflow);
298 inline int use_count()
const;
304 MXNET_DLL bool DerivedFrom(uint32_t parent_tindex)
const;
326 template <
typename RefType,
typename ObjectType>
327 inline RefType
GetRef(
const ObjectType* ptr);
337 template <
typename SubRef,
typename BaseRef>
338 inline SubRef
Downcast(BaseRef ref);
345 template <
typename T>
362 template <
typename U>
365 static_assert(std::is_base_of<T, U>::value,
366 "can only assign of child class ObjectPtr to parent");
373 : data_(other.data_) {
374 other.data_ =
nullptr;
380 template <
typename Y>
382 : data_(other.data_) {
383 static_assert(std::is_base_of<T, Y>::value,
384 "can only assign of child class ObjectPtr to parent");
385 other.data_ =
nullptr;
396 std::swap(data_, other.data_);
402 return static_cast<T*
>(data_);
439 if (data_ !=
nullptr) {
446 return data_ !=
nullptr ? data_->use_count() : 0;
450 return data_ !=
nullptr && data_->use_count() == 1;
454 return data_ == other.data_;
458 return data_ != other.data_;
462 return data_ ==
nullptr;
466 return data_ !=
nullptr;
477 if (data !=
nullptr) {
493 template <
typename RefType,
typename ObjType>
494 friend RefType
GetRef(
const ObjType* ptr);
495 template <
typename BaseType,
typename ObjType>
540 return data_ !=
nullptr;
552 return data_.unique();
565 template <
typename ObjectType>
566 inline const ObjectType*
as()
const;
586 template <
typename T>
588 return T(std::move(ref.
data_));
596 template <
typename ObjectType>
604 template <
typename SubRef,
typename BaseRef>
605 friend SubRef
Downcast(BaseRef ref);
616 template <
typename BaseType,
typename ObjectType>
625 template <
typename T>
627 return std::hash<Object*>()(a.
get());
637 template <
typename T>
648 #define MXNET_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
649 static uint32_t RuntimeTypeIndex() { \
650 return TypeName::_type_index != ::mxnet::runtime::TypeIndex::kDynamic ? \
651 TypeName::_type_index : \
652 _GetOrAllocRuntimeTypeIndex(); \
654 static uint32_t _GetOrAllocRuntimeTypeIndex() { \
655 static uint32_t tidx = GetOrAllocRuntimeTypeIndex(TypeName::_type_key, \
656 TypeName::_type_index, \
657 ParentType::_GetOrAllocRuntimeTypeIndex(), \
658 TypeName::_type_child_slots, \
659 TypeName::_type_child_slots_can_overflow); \
// Declares runtime type information for a *final* object type (one that has no
// subclasses): sets _type_final to true, reserves zero child type slots, then
// expands MXNET_DECLARE_BASE_OBJECT_INFO to provide RuntimeTypeIndex() and
// _GetOrAllocRuntimeTypeIndex() for TypeName.
// NOTE(review): _type_child_slots is declared as int here, while the Object
// base declares it as uint32_t; the value is 0 so this is harmless, but confirm
// the type mismatch is intentional.
668 #define MXNET_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
669 static const constexpr bool _type_final = true; \
670 static const constexpr int _type_child_slots = 0; \
671 MXNET_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
// Registers TypeName with the runtime type system. It defines a static
// (DMLC_ATTRIBUTE_UNUSED) uint32_t whose initializer calls
// TypeName::_GetOrAllocRuntimeTypeIndex(), so the type index is allocated when
// that static is initialized. This assumes the macro is used at namespace scope
// in a .cc file — confirm at call sites.
// NOTE(review): the generated variable name contains "__", which C++ reserves
// for the implementation; consider a non-reserved name if this is ever touched.
679 #define MXNET_REGISTER_OBJECT_TYPE(TypeName) \
680 static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx##_##TypeName##__ = \
681 TypeName::_GetOrAllocRuntimeTypeIndex()
// Explicitly defaults the copy constructor, move constructor, copy assignment
// and move assignment operator of TypeName.
// The expansion already ends in ';', so a call site that writes a trailing ';'
// (as MXNET_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS does) just adds an empty
// member declaration.
683 #define MXNET_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
684 TypeName(const TypeName& other) = default; \
685 TypeName(TypeName&& other) = default; \
686 TypeName& operator=(const TypeName& other) = default; \
687 TypeName& operator=(TypeName&& other) = default;
689 #define MXNET_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
691 explicit TypeName(::mxnet::runtime::ObjectPtr<::mxnet::runtime::Object> n) : ParentType(n) {} \
692 const ObjectName* operator->() const { \
693 return static_cast<const ObjectName*>(data_.get()); \
695 operator bool() const { \
696 return data_ != nullptr; \
698 using ContainerType = ObjectName;
700 #define MXNET_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
702 explicit TypeName(::mxnet::runtime::ObjectPtr<::mxnet::runtime::Object> n) : ParentType(n) {} \
703 ObjectName* operator->() { \
704 return static_cast<ObjectName*>(data_.get()); \
706 operator bool() const { \
707 return data_ != nullptr; \
709 using ContainerType = ObjectName;
711 #define MXNET_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
712 explicit TypeName(::mxnet::runtime::ObjectPtr<::mxnet::runtime::Object> n) : ParentType(n) {} \
713 MXNET_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
714 const ObjectName* operator->() const { \
715 return static_cast<const ObjectName*>(data_.get()); \
717 const ObjectName* get() const { \
718 return operator->(); \
720 static constexpr bool _type_is_nullable = false; \
721 using ContainerType = ObjectName;
725 #if MXNET_OBJECT_ATOMIC_REF_COUNTER
732 if (
ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
733 std::atomic_thread_fence(std::memory_order_acquire);
734 if (this->deleter_ !=
nullptr) {
740 inline int Object::use_count()
const {
751 if (--ref_counter == 0) {
752 if (this->deleter_ !=
nullptr) {
758 inline int Object::use_count()
const {
762 #endif // MXNET_OBJECT_ATOMIC_REF_COUNTER
764 template <
typename TargetType>
766 const Object*
self =
this;
769 if (
self !=
nullptr) {
771 if (std::is_same<TargetType, Object>::value)
773 if (TargetType::_type_final) {
776 return self->
type_index_ == TargetType::RuntimeTypeIndex();
780 uint32_t begin = TargetType::RuntimeTypeIndex();
782 if (TargetType::_type_child_slots != 0) {
783 uint32_t end = begin + TargetType::_type_child_slots;
784 if (self->type_index_ >= begin && self->type_index_ < end)
787 if (self->type_index_ == begin)
790 if (!TargetType::_type_child_slots_can_overflow)
793 if (self->type_index_ < TargetType::RuntimeTypeIndex())
796 return self->DerivedFrom(TargetType::RuntimeTypeIndex());
803 template <
typename ObjectType>
805 if (
data_ !=
nullptr &&
data_->IsInstance<ObjectType>()) {
806 return static_cast<ObjectType*
>(
data_.get());
812 template <
typename RefType,
typename ObjType>
813 inline RefType
GetRef(
const ObjType* ptr) {
814 static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
815 "Can only cast to the ref of same container type");
816 if (!RefType::_type_is_nullable) {
817 CHECK(ptr !=
nullptr);
822 template <
typename BaseType,
typename ObjType>
824 static_assert(std::is_base_of<BaseType, ObjType>::value,
825 "Can only cast to the ref of same container type");
829 template <
typename SubRef,
typename BaseRef>
832 CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
833 <<
"Downcast from " << ref->GetTypeKey() <<
" to " << SubRef::ContainerType::_type_key
836 CHECK(SubRef::_type_is_nullable) <<
"Downcast from nullptr to not nullable reference of "
837 << SubRef::ContainerType::_type_key;
839 return SubRef(std::move(ref.data_));
844 template <
typename T>
849 #endif // MXNET_RUNTIME_OBJECT_H_
static MXNET_DLL uint32_t GetOrAllocRuntimeTypeIndex(const std::string &key, uint32_t static_tindex, uint32_t parent_tindex, uint32_t type_child_slots, bool type_child_slots_can_overflow)
Get the type index using type key.
namespace of mxnet
Definition: api_registry.h:33
ObjectPtr(ObjectPtr< T > &&other)
move constructor
Definition: object.h:372
TypeIndex
List of the built-in type indices.
Definition: object.h:51
static uint32_t RuntimeTypeIndex()
Definition: object.h:213
size_t operator()(const ObjectPtr< T > &a, const ObjectPtr< T > &b) const
Definition: object.h:638
const Object * get() const
Definition: object.h:543
base class of all object containers.
Definition: object.h:151
void IncRef()
developer function, increases reference counter.
Definition: object.h:727
ObjectPtr(std::nullptr_t)
constructor from nullptr (creates an empty pointer)
Definition: object.h:351
void(* FDeleter)(Object *self)
Object deleter.
Definition: object.h:157
ObjectPtr< BaseType > GetObjectPtr(ObjectType *ptr)
Get an object ptr type from a raw object ptr.
ObjectPtr(ObjectPtr< Y > &&other)
move constructor
Definition: object.h:381
size_t operator()(const ObjectPtr< T > &a) const
Definition: object.h:626
static MXNET_DLL std::string TypeIndex2Key(uint32_t tindex)
Get the type key of the corresponding index from runtime.
A custom smart pointer for Object.
Definition: object.h:346
A single argument value to PackedFunc. Containing both type_code and MXNetValue.
Definition: packed_func.h:480
@ kMXNetADT
Definition: object.h:56
bool operator!=(std::nullptr_t null) const
Definition: object.h:465
std::atomic< int32_t > RefCounterType
Definition: object.h:203
Internal base class to handle conversion to POD values.
Definition: packed_func.h:401
static ObjectPtr< ObjectType > GetDataPtr(const ObjectRef &ref)
Internal helper function get data_ as ObjectPtr of ObjectType.
Definition: object.h:597
RefType GetRef(const ObjectType *ptr)
Get a reference type from a raw object ptr type.
ObjectPtr< T > & operator=(const ObjectPtr< T > &other)
copy assignment
Definition: object.h:421
T * get() const
Definition: object.h:401
friend RefType GetRef(const ObjType *ptr)
Definition: object.h:813
bool same_as(const ObjectRef &other) const
Comparator.
Definition: object.h:511
friend ObjectPtr< BaseType > GetObjectPtr(ObjType *ptr)
Definition: object.h:823
SubRef Downcast(BaseRef ref)
Downcast a base reference type to a more specific type.
Definition: object.h:830
static constexpr bool _type_child_slots_can_overflow
Definition: object.h:220
static uint32_t _GetOrAllocRuntimeTypeIndex()
Definition: object.h:210
Object & operator=(Object &&other)
Definition: object.h:240
static MXNET_DLL size_t TypeIndex2KeyHash(uint32_t tindex)
Get the type key hash of the corresponding index from runtime.
bool IsInstance() const
Definition: object.h:765
@ kRoot
Root object type.
Definition: object.h:53
Return value container. Unlike MXNetArgValue, which only holds a reference and does not delete the underl...
Definition: packed_func.h:555
static MXNET_DLL uint32_t TypeKey2Index(const std::string &key)
Get the type index of the corresponding key from runtime.
@ kMXNetClosure
Definition: object.h:55
friend class ObjectPtr
We always use ObjectPtr as the reference pointer to a node, so this alias can be changed if needed.
Definition: object.h:486
T & operator*() const
Definition: object.h:413
@ kSlice
Definition: object.h:60
@ kEllipsis
Definition: object.h:59
static T DowncastNoCheck(ObjectRef ref)
Internal helper function downcast a ref without check.
Definition: object.h:587
bool operator==(std::nullptr_t null) const
Definition: object.h:461
size_t GetTypeKeyHash() const
Definition: object.h:172
static constexpr uint32_t _type_index
Definition: object.h:224
static constexpr uint32_t _type_child_slots
Definition: object.h:219
bool unique() const
Definition: object.h:449
int use_count() const
Definition: object.h:445
@ kFloat
Definition: object.h:62
ObjectPtr(const ObjectPtr< U > &other)
copy constructor
Definition: object.h:363
@ kStaticIndexEnd
Definition: object.h:63
@ kInteger
Definition: object.h:61
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:804
Object(const Object &other)
Definition: object.h:233
#define MXNET_DLL
MXNET_DLL prefix for windows.
Definition: c_api.h:53
bool operator()(const ObjectRef &a, const ObjectRef &b) const
Definition: object.h:633
ObjectPtr< T > & operator=(ObjectPtr< T > &&other)
move assignment
Definition: object.h:432
ObjectRef()=default
default constructor
void reset()
reset the content of ptr to be nullptr
Definition: object.h:438
bool unique() const
Definition: object.h:551
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:575
Object & operator=(const Object &other)
Definition: object.h:237
ObjectRef hash functor.
Definition: object.h:620
T * operator->() const
Definition: object.h:407
const Object * operator->() const
Definition: object.h:547
ObjectRef(ObjectPtr< Object > data)
Constructor from existing object ptr.
Definition: object.h:505
Object()
Definition: object.h:227
static constexpr const char * _type_key
Definition: object.h:208
Base class of object allocators that implements make. Use curiously recurring template pattern.
Definition: memory.h:60
bool operator<(const ObjectRef &other) const
Comparator.
Definition: object.h:535
bool operator!=(const ObjectRef &other) const
Comparator.
Definition: object.h:527
@ kMXNetMap
Definition: object.h:57
Object(Object &&other)
Definition: object.h:235
Base class of all object reference.
Definition: object.h:500
size_t operator()(const ObjectRef &a) const
Definition: object.h:621
RefCounterType ref_counter_
The internal reference counter.
Definition: object.h:249
static constexpr bool _type_final
Definition: object.h:218
Object * get_mutable() const
Definition: object.h:577
FDeleter deleter_
deleter of this object to enable customized allocation. If the deleter is nullptr,...
Definition: object.h:255
void DecRef()
developer function, decreases reference counter.
Definition: object.h:731
ObjectPtr(const ObjectPtr< T > &other)
copy constructor
Definition: object.h:356
ObjectRef equal functor.
Definition: object.h:632
std::string GetTypeKey() const
Definition: object.h:166
@ kDynamic
Type index is allocated during runtime.
Definition: object.h:65
~ObjectPtr()
destructor
Definition: object.h:388
uint32_t type_index_
Type index(tag) that indicates the type of the object.
Definition: object.h:247
bool defined() const
Definition: object.h:539
@ kMXNetTensor
Definition: object.h:54
bool operator!=(const ObjectPtr< T > &other) const
Definition: object.h:457
bool operator==(const ObjectPtr< T > &other) const
Definition: object.h:453
friend SubRef Downcast(BaseRef ref)
Downcast a base reference type to a more specific type.
Definition: object.h:830
uint32_t type_index() const
Definition: object.h:159
Definition: packed_func.h:1013
void swap(ObjectPtr< T > &other)
Swap the managed object pointer with that of another ObjectPtr.
Definition: object.h:395
static constexpr bool _type_is_nullable
Definition: object.h:571
ObjectPtr()
default constructor
Definition: object.h:349
friend class ObjectInternal
Definition: object.h:311
@ kMXNetString
Definition: object.h:58
bool operator==(const ObjectRef &other) const
Comparator.
Definition: object.h:519