24 #ifndef MXNET_RUNTIME_OBJECT_H_ 25 #define MXNET_RUNTIME_OBJECT_H_ 27 #include <dmlc/logging.h> 28 #include <type_traits> 39 #ifndef MXNET_OBJECT_ATOMIC_REF_COUNTER 40 #define MXNET_OBJECT_ATOMIC_REF_COUNTER 1 43 #if MXNET_OBJECT_ATOMIC_REF_COUNTER 45 #endif // MXNET_OBJECT_ATOMIC_REF_COUNTER 178 template<
typename TargetType>
200 #if MXNET_OBJECT_ATOMIC_REF_COUNTER 257 "RefCounter ABI check.");
277 const std::string& key,
278 uint32_t static_tindex,
279 uint32_t parent_tindex,
280 uint32_t type_child_slots,
281 bool type_child_slots_can_overflow);
297 inline int use_count()
const;
303 MXNET_DLL bool DerivedFrom(uint32_t parent_tindex)
const;
325 template <
typename RefType,
typename ObjectType>
326 inline RefType
GetRef(
const ObjectType* ptr);
336 template <
typename SubRef,
typename BaseRef>
337 inline SubRef
Downcast(BaseRef ref);
344 template <
typename T>
361 template <
typename U>
364 static_assert(std::is_base_of<T, U>::value,
365 "can only assign of child class ObjectPtr to parent");
372 : data_(other.data_) {
373 other.data_ =
nullptr;
379 template <
typename Y>
381 : data_(other.data_) {
382 static_assert(std::is_base_of<T, Y>::value,
383 "can only assign of child class ObjectPtr to parent");
384 other.data_ =
nullptr;
395 std::swap(data_, other.data_);
401 return static_cast<T*
>(data_);
438 if (data_ !=
nullptr) {
445 return data_ !=
nullptr ? data_->use_count() : 0;
449 return data_ !=
nullptr && data_->use_count() == 1;
453 return data_ == other.data_;
457 return data_ != other.data_;
461 return data_ ==
nullptr;
465 return data_ !=
nullptr;
476 if (data !=
nullptr) {
492 template <
typename RefType,
typename ObjType>
493 friend RefType
GetRef(
const ObjType* ptr);
494 template <
typename BaseType,
typename ObjType>
511 return data_ == other.
data_;
519 return data_ == other.
data_;
527 return data_ != other.
data_;
535 return data_.get() < other.
data_.get();
539 return data_ !=
nullptr;
551 return data_.unique();
564 template <
typename ObjectType>
565 inline const ObjectType* as()
const;
585 return T(std::move(ref.
data_));
593 template<
typename ObjectType>
601 template <
typename SubRef,
typename BaseRef>
602 friend SubRef
Downcast(BaseRef ref);
613 template <
typename BaseType,
typename ObjectType>
619 return operator()(a.
data_);
624 return std::hash<Object*>()(a.
get());
647 #define MXNET_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ 648 static const uint32_t RuntimeTypeIndex() { \ 649 if (TypeName::_type_index != ::mxnet::runtime::TypeIndex::kDynamic) { \ 650 return TypeName::_type_index; \ 652 return _GetOrAllocRuntimeTypeIndex(); \ 654 static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ 655 static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \ 656 TypeName::_type_key, \ 657 TypeName::_type_index, \ 658 ParentType::_GetOrAllocRuntimeTypeIndex(), \ 659 TypeName::_type_child_slots, \ 660 TypeName::_type_child_slots_can_overflow); \ 669 #define MXNET_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ 670 static const constexpr bool _type_final = true; \ 671 static const constexpr int _type_child_slots = 0; \ 672 MXNET_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ 681 #define MXNET_REGISTER_OBJECT_TYPE(TypeName) \ 682 static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \ 683 TypeName::_GetOrAllocRuntimeTypeIndex() 686 #define MXNET_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ 689 ::mxnet::runtime::ObjectPtr<::mxnet::runtime::Object> n) \ 691 const ObjectName* operator->() const { \ 692 return static_cast<const ObjectName*>(data_.get()); \ 694 operator bool() const { return data_ != nullptr; } \ 695 using ContainerType = ObjectName; 697 #define MXNET_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \ 700 ::mxnet::runtime::ObjectPtr<::mxnet::runtime::Object> n) \ 702 ObjectName* operator->() { \ 703 return static_cast<ObjectName*>(data_.get()); \ 705 operator bool() const { return data_ != nullptr; } \ 706 using ContainerType = ObjectName; 710 #if MXNET_OBJECT_ATOMIC_REF_COUNTER 717 if (
ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
718 std::atomic_thread_fence(std::memory_order_acquire);
719 if (this->deleter_ !=
nullptr) {
725 inline int Object::use_count()
const {
736 if (--ref_counter == 0) {
737 if (this->deleter_ !=
nullptr) {
743 inline int Object::use_count()
const {
747 #endif // MXNET_OBJECT_ATOMIC_REF_COUNTER 749 template<
typename TargetType>
751 const Object*
self =
this;
754 if (
self !=
nullptr) {
756 if (std::is_same<TargetType, Object>::value)
return true;
757 if (TargetType::_type_final) {
760 return self->
type_index_ == TargetType::RuntimeTypeIndex();
764 uint32_t begin = TargetType::RuntimeTypeIndex();
766 if (TargetType::_type_child_slots != 0) {
767 uint32_t end = begin + TargetType::_type_child_slots;
768 if (self->type_index_ >= begin && self->type_index_ < end)
return true;
770 if (self->type_index_ == begin)
return true;
772 if (!TargetType::_type_child_slots_can_overflow)
return false;
774 if (self->type_index_ < TargetType::RuntimeTypeIndex())
return false;
776 return self->DerivedFrom(TargetType::RuntimeTypeIndex());
784 template <
typename ObjectType>
786 if (data_ !=
nullptr &&
787 data_->IsInstance<ObjectType>()) {
788 return static_cast<ObjectType*
>(data_.get());
794 template <
typename RefType,
typename ObjType>
795 inline RefType
GetRef(
const ObjType* ptr) {
796 static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
797 "Can only cast to the ref of same container type");
798 return RefType(
ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
801 template <
typename BaseType,
typename ObjType>
803 static_assert(std::is_base_of<BaseType, ObjType>::value,
804 "Can only cast to the ref of same container type");
808 template <
typename SubRef,
typename BaseRef>
810 CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
811 <<
"Downcast from " << ref->GetTypeKey() <<
" to " 812 << SubRef::ContainerType::_type_key <<
" failed.";
813 return SubRef(std::move(ref.data_));
823 #endif // MXNET_RUNTIME_OBJECT_H_ bool operator!=(const ObjectPtr< T > &other) const
Definition: object.h:456
ObjectPtr(const ObjectPtr< T > &other)
copy constructor
Definition: object.h:355
static T DowncastNoCheck(ObjectRef ref)
Internal helper function downcast a ref without check.
Definition: object.h:584
T * get() const
Definition: object.h:400
Object(const Object &other)
Definition: object.h:231
Object(Object &&other)
Definition: object.h:233
bool operator==(const ObjectPtr< T > &other) const
Definition: object.h:452
static MXNET_DLL std::string TypeIndex2Key(uint32_t tindex)
Get the type key of the corresponding index from runtime.
Object & operator=(const Object &other)
Definition: object.h:235
Object * get_mutable() const
Definition: object.h:574
bool same_as(const ObjectRef &other) const
Comparator.
Definition: object.h:510
namespace of mxnet
Definition: api_registry.h:33
bool operator()(const ObjectRef &a, const ObjectRef &b) const
Definition: object.h:631
static constexpr const char * _type_key
Definition: object.h:206
ObjectRef hash functor.
Definition: object.h:617
bool operator!=(const ObjectRef &other) const
Comparator.
Definition: object.h:526
const Object * operator->() const
Definition: object.h:546
size_t GetTypeKeyHash() const
Definition: object.h:170
size_t operator()(const ObjectPtr< T > &a, const ObjectPtr< T > &b) const
Definition: object.h:636
Object & operator=(Object &&other)
Definition: object.h:238
A custom smart pointer for Object.
Definition: object.h:345
friend class ObjectPtr
We always use ObjectPtr as the reference pointer to a node, so this alias can be changed in case...
Definition: object.h:308
size_t operator()(const ObjectPtr< T > &a) const
Definition: object.h:623
ObjectRef(ObjectPtr< Object > data)
Constructor from existing object ptr.
Definition: object.h:504
int use_count() const
Definition: object.h:444
bool IsInstance() const
Definition: object.h:750
static uint32_t RuntimeTypeIndex()
Definition: object.h:211
RefCounterType ref_counter_
The internal reference counter.
Definition: object.h:247
bool unique() const
Definition: object.h:550
ObjectPtr(const ObjectPtr< U > &other)
copy constructor
Definition: object.h:362
Base class of all object reference.
Definition: object.h:499
ObjectPtr(ObjectPtr< Y > &&other)
move constructor
Definition: object.h:380
void swap(ObjectPtr< T > &other)
Swap the object held by this ObjectPtr with the one held by another ObjectPtr.
Definition: object.h:394
Object()
Definition: object.h:225
bool operator!=(std::nullptr_t null) const
Definition: object.h:464
static constexpr bool _type_child_slots_can_overflow
Definition: object.h:218
T * operator->() const
Definition: object.h:406
Root object type.
Definition: object.h:53
Definition: packed_func.h:981
A single argument value to PackedFunc, containing both type_code and MXNetValue.
Definition: packed_func.h:468
void DecRef()
developer function, decreases reference counter.
Definition: object.h:716
friend class ObjectInternal
Definition: object.h:310
ObjectPtr(ObjectPtr< T > &&other)
move constructor
Definition: object.h:371
std::string GetTypeKey() const
Definition: object.h:164
base class of all object containers.
Definition: object.h:149
static ObjectPtr< ObjectType > GetDataPtr(const ObjectRef &ref)
Internal helper function get data_ as ObjectPtr of ObjectType.
Definition: object.h:594
ObjectPtr< T > & operator=(ObjectPtr< T > &&other)
move assignment
Definition: object.h:431
static constexpr uint32_t _type_index
Definition: object.h:222
bool unique() const
Definition: object.h:448
uint32_t type_index_
Type index(tag) that indicates the type of the object.
Definition: object.h:245
ObjectPtr()
default constructor
Definition: object.h:348
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:572
static uint32_t _GetOrAllocRuntimeTypeIndex()
Definition: object.h:208
static constexpr uint32_t _type_child_slots
Definition: object.h:217
void reset()
reset the content of ptr to be nullptr
Definition: object.h:437
std::atomic< int32_t > RefCounterType
Definition: object.h:201
~ObjectPtr()
destructor
Definition: object.h:387
SubRef Downcast(BaseRef ref)
Downcast a base reference type to a more specific type.
Definition: object.h:809
Base class of object allocators that implements make. Use curiously recurring template pattern...
Definition: memory.h:60
bool operator==(std::nullptr_t null) const
Definition: object.h:460
bool defined() const
Definition: object.h:538
ObjectPtr< BaseType > GetObjectPtr(ObjectType *ptr)
Get an object ptr type from a raw object ptr.
#define MXNET_DLL
MXNET_DLL prefix for windows.
Definition: c_api.h:54
TypeIndex
list of the type index.
Definition: object.h:51
ObjectPtr(std::nullptr_t)
default constructor
Definition: object.h:350
size_t operator()(const ObjectRef &a) const
Definition: object.h:618
static constexpr bool _type_final
Definition: object.h:216
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:785
uint32_t type_index() const
Definition: object.h:157
Type index is allocated during runtime.
Definition: object.h:63
Internal base class to handle conversion to POD values.
Definition: packed_func.h:387
T & operator*() const
Definition: object.h:412
void(* FDeleter)(Object *self)
Object deleter.
Definition: object.h:155
static MXNET_DLL uint32_t TypeKey2Index(const std::string &key)
Get the type index of the corresponding key from runtime.
ObjectPtr< T > & operator=(const ObjectPtr< T > &other)
copy assignment
Definition: object.h:420
bool operator<(const ObjectRef &other) const
Comparator.
Definition: object.h:534
FDeleter deleter_
deleter of this object to enable customized allocation. If the deleter is nullptr, no deletion will be performed. The creator of the object must always set the deleter field properly.
Definition: object.h:253
Return value container. Unlike MXNetArgValue, which only holds a reference and does not delete the underl...
Definition: packed_func.h:552
static MXNET_DLL size_t TypeIndex2KeyHash(uint32_t tindex)
Get the type key hash of the corresponding index from runtime.
void IncRef()
developer function, increases reference counter.
Definition: object.h:712
static MXNET_DLL uint32_t GetOrAllocRuntimeTypeIndex(const std::string &key, uint32_t static_tindex, uint32_t parent_tindex, uint32_t type_child_slots, bool type_child_slots_can_overflow)
Get the type index using type key.
RefType GetRef(const ObjectType *ptr)
Get a reference type from a raw object ptr type.
ObjectRef equal functor.
Definition: object.h:630
bool operator==(const ObjectRef &other) const
Comparator.
Definition: object.h:518