mxnet
object.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: This file originates from incubator-tvm
24 #ifndef MXNET_RUNTIME_OBJECT_H_
25 #define MXNET_RUNTIME_OBJECT_H_
26 
27 #include <dmlc/logging.h>
28 #include <type_traits>
29 #include <string>
30 #include <utility>
31 #include "c_runtime_api.h"
32 
39 #ifndef MXNET_OBJECT_ATOMIC_REF_COUNTER
40 #define MXNET_OBJECT_ATOMIC_REF_COUNTER 1
41 #endif
42 
43 #if MXNET_OBJECT_ATOMIC_REF_COUNTER
44 #include <atomic>
45 #endif // MXNET_OBJECT_ATOMIC_REF_COUNTER
46 
47 namespace mxnet {
48 namespace runtime {
49 
// List of the type index.
// NOTE(review): this excerpt appears to be missing several enumerator lines:
// the index at the end of this page lists TypeIndex entries at object.h:54,
// 55, 57, 61 and 63 that do not appear here. In particular
// TypeIndex::kDynamic, which Object::_type_index and
// MXNET_DECLARE_BASE_OBJECT_INFO below rely on, is not visible; presumably it
// is the object.h:63 entry ("Type index is allocated during runtime") --
// confirm against the real header.
51 enum TypeIndex {
// Root object type: Object::RuntimeTypeIndex() returns this value.
53  kRoot = 0,
56  kMXNetADT = 3,
58  kEllipsis = 5,
59  kSlice = 6,
60  kInteger = 7,
64 };
65 
// Base class of all object containers.
// Every object carries its runtime type index (type_index_), an optional
// deleter (deleter_) that DecRef invokes when the last reference is dropped,
// and a reference counter manipulated through IncRef/DecRef by ObjectPtr.
// The copy/move constructors and assignments below deliberately copy none of
// these bookkeeping fields.
149 class Object {
150  public:
// Object deleter: called by DecRef (with this object) when the reference
// count drops to zero. A null deleter means nothing is freed.
155  typedef void (*FDeleter)(Object* self);
// Runtime type index (tag) of this object.
157  uint32_t type_index() const {
158  return type_index_;
159  }
// Type key string of this object's type, looked up from the runtime
// via TypeIndex2Key.
164  std::string GetTypeKey() const {
165  return TypeIndex2Key(type_index_);
166  }
// Hash of this object's type key.
// NOTE(review): the body line (object.h:171) is missing from this excerpt;
// presumably it returns TypeIndex2KeyHash(type_index_) -- confirm.
170  size_t GetTypeKeyHash() const {
172  }
// Whether this object is an instance of TargetType (or of a subclass of it).
// Defined out of line further down in this header.
178  template<typename TargetType>
179  inline bool IsInstance() const;
180 
// Get the type key of the corresponding index from runtime.
186  MXNET_DLL static std::string TypeIndex2Key(uint32_t tindex);
// Get the type key hash of the corresponding index from runtime.
192  MXNET_DLL static size_t TypeIndex2KeyHash(uint32_t tindex);
// Get the type index of the corresponding key from runtime.
198  MXNET_DLL static uint32_t TypeKey2Index(const std::string& key);
199 
// Reference counter type: std::atomic<int32_t> when
// MXNET_OBJECT_ATOMIC_REF_COUNTER is non-zero (the default set at the top of
// this header), plain int32_t otherwise.
200 #if MXNET_OBJECT_ATOMIC_REF_COUNTER
201  using RefCounterType = std::atomic<int32_t>;
202 #else
203  using RefCounterType = int32_t;
204 #endif
205 
206  static constexpr const char* _type_key = "Object";
207 
// Object's own type index is fixed at kRoot; nothing is allocated.
208  static uint32_t _GetOrAllocRuntimeTypeIndex() {
209  return TypeIndex::kRoot;
210  }
211  static uint32_t RuntimeTypeIndex() {
212  return TypeIndex::kRoot;
213  }
214 
215  // Default object type properties for sub-classes
216  static constexpr bool _type_final = false;
217  static constexpr uint32_t _type_child_slots = 0;
218  static constexpr bool _type_child_slots_can_overflow = true;
219  // NOTE: the following field is not type index of Object
220  // but was intended to be used by sub-classes as default value.
221  // The type index of Object is TypeIndex::kRoot
222  static constexpr uint32_t _type_index = TypeIndex::kDynamic;
223 
224  // Default constructor and copy constructor
225  Object() {}
226  // Override the copy and assign constructors to do nothing.
227  // This is to make sure only contents, but not deleter and ref_counter
228  // are copied when a child class copies itself.
229  // This will enable us to use make_object<ObjectClass>(*obj_ptr)
230  // to copy an existing object.
231  Object(const Object& other) { // NOLINT(*)
232  }
233  Object(Object&& other) { // NOLINT(*)
234  }
235  Object& operator=(const Object& other) { //NOLINT(*)
236  return *this;
237  }
238  Object& operator=(Object&& other) { //NOLINT(*)
239  return *this;
240  }
241 
242  protected:
243  // The fields of the base object cell.
// Type index (tag) that indicates the type of the object.
245  uint32_t type_index_{0};
// NOTE(review): the declaration of ref_counter_ (object.h:247, "The internal
// reference counter.", of type RefCounterType) is missing from this excerpt;
// IncRef/DecRef/use_count below operate on it.
// Deleter used to free this object once its count reaches zero; the creator
// of the object must set it (a nullptr deleter means no deletion).
253  FDeleter deleter_ = nullptr;
254  // Invariant checks.
// NOTE(review): the second condition compares alignof(int32_t) with
// sizeof(RefCounterType); the intended ABI check looks like
// alignof(int32_t) == alignof(RefCounterType). It passes today only because
// both sides happen to be 4 -- confirm and fix.
255  static_assert(sizeof(int32_t) == sizeof(RefCounterType) &&
256  alignof(int32_t) == sizeof(RefCounterType),
257  "RefCounter ABI check.");
258 
// Get (allocating on first use) the runtime type index for a type key.
// MXNET_DECLARE_BASE_OBJECT_INFO calls this with the subclass's static
// index, its parent's index and its child-slot settings.
276  MXNET_DLL static uint32_t GetOrAllocRuntimeTypeIndex(
277  const std::string& key,
278  uint32_t static_tindex,
279  uint32_t parent_tindex,
280  uint32_t type_child_slots,
281  bool type_child_slots_can_overflow);
282 
283  // reference counter related operations
// Developer function: increase the reference counter by one.
285  inline void IncRef();
// Developer function: decrease the reference counter; runs deleter_ when
// the count reaches zero.
290  inline void DecRef();
291 
292  private:
// Current reference count; exposed through ObjectPtr::use_count (friend).
297  inline int use_count() const;
// Slow-path hierarchy check used by IsInstance when a subclass index did
// not fit into the parent's reserved child slots.
303  MXNET_DLL bool DerivedFrom(uint32_t parent_tindex) const;
304  // friend classes
305  template<typename>
306  friend class ObjAllocatorBase;
307  template<typename>
308  friend class ObjectPtr;
309  friend class MXNetRetValue;
310  friend class ObjectInternal;
311 };
312 
// Get a reference type from a raw object ptr type.
// The definition further down static_asserts that RefType::ContainerType is
// a base of ObjectType, then wraps ptr in a new ObjectPtr<Object>. The
// returned reference therefore shares ownership of the existing object: its
// reference count is incremented and nothing is copied.
325 template <typename RefType, typename ObjectType>
326 inline RefType GetRef(const ObjectType* ptr);
327 
// Downcast a base reference type to a more specific type.
// The definition further down CHECKs that the referenced object is an
// instance of SubRef::ContainerType (failing with a "Downcast from ... to
// ... failed." message), then moves ref's pointer into the returned SubRef.
336 template <typename SubRef, typename BaseRef>
337 inline SubRef Downcast(BaseRef ref);
338 
// A custom smart pointer for Object.
// Holds a raw Object* (data_) and keeps the object's intrusive reference
// count in sync: taking a reference calls IncRef, and releasing one
// (reset / destructor) calls DecRef.
// NOTE(review): the default-constructor line (object.h:348) and the
// destructor header (object.h:387) are missing from this excerpt; see the
// index at the end of this page.
344 template <typename T>
345 class ObjectPtr {
346  public:
// Construct a null pointer from nullptr.
350  ObjectPtr(std::nullptr_t) {} // NOLINT(*)
// Copy constructor: shares other's object, incrementing its count via the
// private raw-pointer constructor.
355  ObjectPtr(const ObjectPtr<T>& other) // NOLINT(*)
356  : ObjectPtr(other.data_) {}
// Converting copy constructor from ObjectPtr<U>; only allowed when U
// derives from T.
361  template <typename U>
362  ObjectPtr(const ObjectPtr<U>& other) // NOLINT(*)
363  : ObjectPtr(other.data_) {
364  static_assert(std::is_base_of<T, U>::value,
365  "can only assign of child class ObjectPtr to parent");
366  }
// Move constructor: steals other's pointer without touching the count and
// leaves other null.
371  ObjectPtr(ObjectPtr<T>&& other) // NOLINT(*)
372  : data_(other.data_) {
373  other.data_ = nullptr;
374  }
// Converting move constructor from ObjectPtr<Y>; only allowed when Y
// derives from T. Leaves other null.
379  template <typename Y>
380  ObjectPtr(ObjectPtr<Y>&& other) // NOLINT(*)
381  : data_(other.data_) {
382  static_assert(std::is_base_of<T, Y>::value,
383  "can only assign of child class ObjectPtr to parent");
384  other.data_ = nullptr;
385  }
// Destructor body (header line missing from this excerpt): releases the
// held reference.
388  this->reset();
389  }
// Swap the held pointers of two ObjectPtrs; reference counts are unchanged.
394  void swap(ObjectPtr<T>& other) { // NOLINT(*)
395  std::swap(data_, other.data_);
396  }
// Raw pointer to the held object, statically cast to T* (may be null).
400  T* get() const {
401  return static_cast<T*>(data_);
402  }
406  T* operator->() const {
407  return get();
408  }
412  T& operator*() const { // NOLINT(*)
413  return *get();
414  }
// Copy assignment.
420  ObjectPtr<T>& operator=(const ObjectPtr<T>& other) { // NOLINT(*)
421  // takes in a plain operator to enable copy elision.
422  // copy-and-swap idiom
423  ObjectPtr(other).swap(*this); // NOLINT(*)
424  return *this;
425  }
// Move assignment: the previously held object is released when the
// temporary that receives it by swap is destroyed.
431  ObjectPtr<T>& operator=(ObjectPtr<T>&& other) { // NOLINT(*)
432  // copy-and-swap idiom
433  ObjectPtr(std::move(other)).swap(*this); // NOLINT(*)
434  return *this;
435  }
// Reset the content of ptr to be nullptr, dropping the held reference
// (which may run the object's deleter).
437  void reset() {
438  if (data_ != nullptr) {
439  data_->DecRef();
440  data_ = nullptr;
441  }
442  }
// Reference count of the held object, or 0 when null.
444  int use_count() const {
445  return data_ != nullptr ? data_->use_count() : 0;
446  }
// True when non-null and this is the only reference.
448  bool unique() const {
449  return data_ != nullptr && data_->use_count() == 1;
450  }
// Identity comparisons: equal iff both hold the same object address.
452  bool operator==(const ObjectPtr<T>& other) const {
453  return data_ == other.data_;
454  }
456  bool operator!=(const ObjectPtr<T>& other) const {
457  return data_ != other.data_;
458  }
460  bool operator==(std::nullptr_t null) const {
461  return data_ == nullptr;
462  }
464  bool operator!=(std::nullptr_t null) const {
465  return data_ != nullptr;
466  }
467 
468  private:
// The held object (nullptr when empty).
470  Object* data_{nullptr};
// Take a new reference to data (IncRef when non-null). Private: only the
// friends listed below may wrap a raw Object*.
475  explicit ObjectPtr(Object* data) : data_(data) {
476  if (data != nullptr) {
477  data_->IncRef();
478  }
479  }
480  // friend classes
481  friend class Object;
482  friend class ObjectRef;
483  friend struct ObjectHash;
484  template<typename>
485  friend class ObjectPtr;
486  template<typename>
487  friend class ObjAllocatorBase;
488  friend class MXNetPODValue_;
489  friend class MXNetArgsSetter;
490  friend class MXNetRetValue;
491  friend class MXNetArgValue;
492  template <typename RefType, typename ObjType>
493  friend RefType GetRef(const ObjType* ptr);
494  template <typename BaseType, typename ObjType>
495  friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
496 };
497 
// Base class of all object references.
// Wraps an ObjectPtr<Object> (data_); all comparisons are by object
// identity (address), not by contents.
// NOTE(review): the declaration of data_ (object.h:572, "Internal pointer
// that backs the reference") and the signature line of GetDataPtr
// (object.h:594, `static ObjectPtr<ObjectType> GetDataPtr(const ObjectRef&
// ref)` per the index at the end of this page) are missing from this excerpt.
499 class ObjectRef {
500  public:
// Default constructor: a null reference.
502  ObjectRef() = default;
// Constructor from existing object ptr (shares ownership).
504  explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}
// True iff both references point to the same object.
510  bool same_as(const ObjectRef& other) const {
511  return data_ == other.data_;
512  }
518  bool operator==(const ObjectRef& other) const {
519  return data_ == other.data_;
520  }
526  bool operator!=(const ObjectRef& other) const {
527  return data_ != other.data_;
528  }
// Orders references by object address.
// NOTE(review): built-in `<` between pointers to unrelated objects yields an
// unspecified result; std::less<const Object*> guarantees a strict total
// order and may be the safer choice for use as a map key.
534  bool operator<(const ObjectRef& other) const {
535  return data_.get() < other.data_.get();
536  }
// True when the reference is non-null.
538  bool defined() const {
539  return data_ != nullptr;
540  }
// Raw const pointer to the object (may be null).
542  const Object* get() const {
543  return data_.get();
544  }
546  const Object* operator->() const {
547  return get();
548  }
// True when non-null and this is the only reference to the object.
550  bool unique() const {
551  return data_.unique();
552  }
// Try to downcast the internal Object to a raw pointer of ObjectType;
// yields nullptr when null or when the object is not an ObjectType.
564  template <typename ObjectType>
565  inline const ObjectType* as() const;
566 
569 
570  protected:
// Non-const access to the object for subclasses.
574  Object* get_mutable() const {
575  return data_.get();
576  }
// Internal helper: build T from ref's pointer (moved out) without any type
// check.
583  template<typename T>
584  static T DowncastNoCheck(ObjectRef ref) {
585  return T(std::move(ref.data_));
586  }
// Internal helper (signature line missing from this excerpt): rewrap
// ref's raw object pointer as ObjectPtr<ObjectType>, taking a new reference
// and performing no type check.
593  template<typename ObjectType>
595  return ObjectPtr<ObjectType>(ref.data_.data_);
596  }
597  // friend classes.
598  friend struct ObjectHash;
599  friend class MXNetRetValue;
600  friend class MXNetArgsSetter;
601  template <typename SubRef, typename BaseRef>
602  friend SubRef Downcast(BaseRef ref);
603 };
604 
// Get an object ptr type from a raw object ptr.
// The definition further down static_asserts that BaseType is a base of
// ObjectType and returns an ObjectPtr<BaseType> holding a new reference to
// *ptr (the count is incremented when ptr is non-null).
613 template <typename BaseType, typename ObjectType>
614 inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
615 
617 struct ObjectHash {
618  size_t operator()(const ObjectRef& a) const {
619  return operator()(a.data_);
620  }
621 
622  template<typename T>
623  size_t operator()(const ObjectPtr<T>& a) const {
624  return std::hash<Object*>()(a.get());
625  }
626 };
627 
628 
630 struct ObjectEqual {
631  bool operator()(const ObjectRef& a, const ObjectRef& b) const {
632  return a.same_as(b);
633  }
634 
635  template<typename T>
636  size_t operator()(const ObjectPtr<T>& a, const ObjectPtr<T>& b) const {
637  return a == b;
638  }
639 };
640 
641 
/*!
 * \brief Declare the runtime type information of an object class that may
 *  itself be subclassed.
 * \param TypeName The name of the current object type.
 * \param ParentType The object type it directly derives from.
 *
 * Expands to:
 *  - a static_assert rejecting a final ParentType. Object::IsInstance only
 *    does an exact index match for final targets, so a subclass of a final
 *    type would silently fail instance checks against its parent.
 *  - RuntimeTypeIndex(): TypeName::_type_index when it is statically
 *    assigned (not TypeIndex::kDynamic), otherwise the lazily allocated
 *    index.
 *  - _GetOrAllocRuntimeTypeIndex(): asks the runtime for the index (via
 *    GetOrAllocRuntimeTypeIndex, registering TypeName under ParentType) and
 *    caches it in a function-local static, so the call happens once.
 * Return types are plain uint32_t: a top-level const on a by-value return is
 * ignored by the language and triggers -Wignored-qualifiers.
 */
#define MXNET_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                  \
  static_assert(!ParentType::_type_final, "ParentObj marked as final");       \
  static uint32_t RuntimeTypeIndex() {                                        \
    if (TypeName::_type_index != ::mxnet::runtime::TypeIndex::kDynamic) {     \
      return TypeName::_type_index;                                           \
    }                                                                         \
    return _GetOrAllocRuntimeTypeIndex();                                     \
  }                                                                           \
  static uint32_t _GetOrAllocRuntimeTypeIndex() {                             \
    static uint32_t tidx = GetOrAllocRuntimeTypeIndex(                        \
        TypeName::_type_key,                                                  \
        TypeName::_type_index,                                                \
        ParentType::_GetOrAllocRuntimeTypeIndex(),                            \
        TypeName::_type_child_slots,                                          \
        TypeName::_type_child_slots_can_overflow);                            \
    return tidx;                                                              \
  }
664 
/*!
 * \brief Declare the runtime type information of a final (leaf) object class.
 * \param TypeName The name of the current object type.
 * \param ParentType The object type it directly derives from.
 *
 * Marks the type final and reserves no child slots, then expands
 * MXNET_DECLARE_BASE_OBJECT_INFO. _type_child_slots is uint32_t, matching
 * the default declared in Object.
 */
#define MXNET_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
  static constexpr bool _type_final = true;                   \
  static constexpr uint32_t _type_child_slots = 0;            \
  MXNET_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
674 
/*!
 * \brief Register an object type at static-initialization time.
 * \param TypeName The object type; it must provide
 *  _GetOrAllocRuntimeTypeIndex() (e.g. via MXNET_DECLARE_*_OBJECT_INFO).
 *
 * Defines a file-local, otherwise unused variable initialized with the
 * type's runtime index, which forces the index to be allocated when the
 * translation unit is initialized. The variable name avoids double
 * underscores: identifiers containing "__" are reserved to the
 * implementation.
 */
#define MXNET_REGISTER_OBJECT_TYPE(TypeName)                                  \
  static DMLC_ATTRIBUTE_UNUSED uint32_t mxnet_make_object_tidx_##TypeName##_ = \
      TypeName::_GetOrAllocRuntimeTypeIndex()
684 
685 
/*!
 * \brief Define the standard members of an ObjectRef subclass.
 * \param TypeName The reference type being defined.
 * \param ParentType Its parent reference type.
 * \param ObjectName The object (container) type it refers to.
 *
 * Provides the ContainerType alias, a default (null) constructor, an
 * explicit constructor from ObjectPtr<Object>, a conversion to bool that is
 * true when an object is held, and a const operator-> viewing the held
 * object as ObjectName.
 */
#define MXNET_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)     \
  using ContainerType = ObjectName;                                          \
  TypeName() {}                                                              \
  explicit TypeName(::mxnet::runtime::ObjectPtr<::mxnet::runtime::Object> n) \
      : ParentType(n) {}                                                     \
  operator bool() const { return data_ != nullptr; }                         \
  const ContainerType* operator->() const {                                  \
    return static_cast<const ContainerType*>(data_.get());                   \
  }
696 
/*!
 * \brief Define the standard members of a mutable ObjectRef subclass.
 * \param TypeName The reference type being defined.
 * \param ParentType Its parent reference type.
 * \param ObjectName The object (container) type it refers to.
 *
 * Same as MXNET_DEFINE_OBJECT_REF_METHODS, except that operator-> is a
 * non-const member handing out a mutable ObjectName*.
 * NOTE(review): because only this non-const operator-> is declared, it hides
 * the parent's const operator->, so `->` is unusable on a const TypeName.
 * Consider adding a const overload once it is confirmed that no user class
 * declares its own.
 */
#define MXNET_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
  using ContainerType = ObjectName;                                          \
  TypeName() {}                                                              \
  explicit TypeName(::mxnet::runtime::ObjectPtr<::mxnet::runtime::Object> n) \
      : ParentType(n) {}                                                     \
  operator bool() const { return data_ != nullptr; }                         \
  ContainerType* operator->() {                                              \
    return static_cast<ContainerType*>(data_.get());                         \
  }
707 
708 // Implementations details below
709 // Object reference counting.
710 #if MXNET_OBJECT_ATOMIC_REF_COUNTER
711 
712 inline void Object::IncRef() {
713  ref_counter_.fetch_add(1, std::memory_order_relaxed);
714 }
715 
716 inline void Object::DecRef() {
717  if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
718  std::atomic_thread_fence(std::memory_order_acquire);
719  if (this->deleter_ != nullptr) {
720  (*this->deleter_)(this);
721  }
722  }
723 }
724 
725 inline int Object::use_count() const {
726  return ref_counter_.load(std::memory_order_relaxed);
727 }
728 
729 #else
730 
731 inline void Object::IncRef() {
732  ++ref_counter_;
733 }
734 
735 inline void Object::DecRef() {
736  if (--ref_counter == 0) {
737  if (this->deleter_ != nullptr) {
738  (*this->deleter_)(this);
739  }
740  }
741 }
742 
743 inline int Object::use_count() const {
744  return ref_counter_;
745 }
746 
747 #endif // MXNET_OBJECT_ATOMIC_REF_COUNTER
748 
749 template<typename TargetType>
750 inline bool Object::IsInstance() const {
751  const Object* self = this;
752  // NOTE: the following code can be optimized by
753  // compiler dead-code elimination for already known constants.
754  if (self != nullptr) {
755  // Everything is a subclass of object.
756  if (std::is_same<TargetType, Object>::value) return true;
757  if (TargetType::_type_final) {
758  // if the target type is a final type
759  // then we only need to check the equivalence.
760  return self->type_index_ == TargetType::RuntimeTypeIndex();
761  } else {
762  // if target type is a non-leaf type
763  // Check if type index falls into the range of reserved slots.
764  uint32_t begin = TargetType::RuntimeTypeIndex();
765  // The condition will be optimized by constant-folding.
766  if (TargetType::_type_child_slots != 0) {
767  uint32_t end = begin + TargetType::_type_child_slots;
768  if (self->type_index_ >= begin && self->type_index_ < end) return true;
769  } else {
770  if (self->type_index_ == begin) return true;
771  }
772  if (!TargetType::_type_child_slots_can_overflow) return false;
773  // Invariance: parent index is always smaller than the child.
774  if (self->type_index_ < TargetType::RuntimeTypeIndex()) return false;
775  // The rare slower-path, check type hierachy.
776  return self->DerivedFrom(TargetType::RuntimeTypeIndex());
777  }
778  } else {
779  return false;
780  }
781 }
782 
783 
784 template <typename ObjectType>
785 inline const ObjectType* ObjectRef::as() const {
786  if (data_ != nullptr &&
787  data_->IsInstance<ObjectType>()) {
788  return static_cast<ObjectType*>(data_.get());
789  } else {
790  return nullptr;
791  }
792 }
793 
794 template <typename RefType, typename ObjType>
795 inline RefType GetRef(const ObjType* ptr) {
796  static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
797  "Can only cast to the ref of same container type");
798  return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
799 }
800 
801 template <typename BaseType, typename ObjType>
802 inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
803  static_assert(std::is_base_of<BaseType, ObjType>::value,
804  "Can only cast to the ref of same container type");
805  return ObjectPtr<BaseType>(static_cast<Object*>(ptr));
806 }
807 
/*!
 * \brief Downcast a base reference type to a more specific type.
 * \tparam SubRef The target reference type (provides ContainerType).
 * \tparam BaseRef The source reference type (an ObjectRef).
 * \param ref The reference to convert; its pointer is moved into the result.
 * \return ref re-typed as SubRef.
 *
 * CHECK-fails when ref is null or does not refer to a
 * SubRef::ContainerType. Null is rejected first: the type check calls a
 * member function through the object pointer, and the failure message reads
 * the object's type key. Before this guard, a null ref caused undefined
 * behavior (in practice a crash) instead of a diagnostic.
 */
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
  CHECK(ref.defined())
      << "Downcast from nullptr to "
      << SubRef::ContainerType::_type_key << " failed.";
  CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
      << "Downcast from " << ref->GetTypeKey() << " to "
      << SubRef::ContainerType::_type_key << " failed.";
  return SubRef(std::move(ref.data_));
}
815 
816 } // namespace runtime
817 
818 template<typename T>
820 
821 } // namespace mxnet
822 
823 #endif // MXNET_RUNTIME_OBJECT_H_
bool operator!=(const ObjectPtr< T > &other) const
Definition: object.h:456
ObjectPtr(const ObjectPtr< T > &other)
copy constructor
Definition: object.h:355
static T DowncastNoCheck(ObjectRef ref)
Internal helper function to downcast a ref without a type check.
Definition: object.h:584
T * get() const
Definition: object.h:400
Object(const Object &other)
Definition: object.h:231
Object(Object &&other)
Definition: object.h:233
bool operator==(const ObjectPtr< T > &other) const
Definition: object.h:452
static MXNET_DLL std::string TypeIndex2Key(uint32_t tindex)
Get the type key of the corresponding index from runtime.
Object & operator=(const Object &other)
Definition: object.h:235
Object * get_mutable() const
Definition: object.h:574
bool same_as(const ObjectRef &other) const
Comparator.
Definition: object.h:510
namespace of mxnet
Definition: api_registry.h:33
bool operator()(const ObjectRef &a, const ObjectRef &b) const
Definition: object.h:631
static constexpr const char * _type_key
Definition: object.h:206
ObjectRef hash functor.
Definition: object.h:617
Definition: object.h:56
bool operator!=(const ObjectRef &other) const
Comparator.
Definition: object.h:526
const Object * operator->() const
Definition: object.h:546
size_t GetTypeKeyHash() const
Definition: object.h:170
size_t operator()(const ObjectPtr< T > &a, const ObjectPtr< T > &b) const
Definition: object.h:636
Object & operator=(Object &&other)
Definition: object.h:238
A custom smart pointer for Object.
Definition: object.h:345
friend class ObjectPtr
We always use ObjectPtr as the reference pointer to a node, so this alias can be changed in case...
Definition: object.h:308
size_t operator()(const ObjectPtr< T > &a) const
Definition: object.h:623
ObjectRef(ObjectPtr< Object > data)
Constructor from existing object ptr.
Definition: object.h:504
int use_count() const
Definition: object.h:444
bool IsInstance() const
Definition: object.h:750
static uint32_t RuntimeTypeIndex()
Definition: object.h:211
RefCounterType ref_counter_
The internal reference counter.
Definition: object.h:247
bool unique() const
Definition: object.h:550
ObjectPtr(const ObjectPtr< U > &other)
copy constructor
Definition: object.h:362
Base class of all object reference.
Definition: object.h:499
ObjectPtr(ObjectPtr< Y > &&other)
move constructor
Definition: object.h:380
void swap(ObjectPtr< T > &other)
Swap this ObjectPtr with another ObjectPtr.
Definition: object.h:394
Object()
Definition: object.h:225
bool operator!=(std::nullptr_t null) const
Definition: object.h:464
static constexpr bool _type_child_slots_can_overflow
Definition: object.h:218
Definition: object.h:55
T * operator->() const
Definition: object.h:406
Root object type.
Definition: object.h:53
Definition: object.h:58
Definition: packed_func.h:981
A single argument value to PackedFunc, containing both type_code and MXNetValue.
Definition: packed_func.h:468
void DecRef()
developer function, decrease reference counter.
Definition: object.h:716
Definition: object.h:59
friend class ObjectInternal
Definition: object.h:310
ObjectPtr(ObjectPtr< T > &&other)
move constructor
Definition: object.h:371
std::string GetTypeKey() const
Definition: object.h:164
base class of all object containers.
Definition: object.h:149
static ObjectPtr< ObjectType > GetDataPtr(const ObjectRef &ref)
Internal helper function to get data_ as an ObjectPtr of ObjectType.
Definition: object.h:594
ObjectPtr< T > & operator=(ObjectPtr< T > &&other)
move assignment
Definition: object.h:431
Definition: object.h:61
Definition: object.h:60
static constexpr uint32_t _type_index
Definition: object.h:222
bool unique() const
Definition: object.h:448
uint32_t type_index_
Type index(tag) that indicates the type of the object.
Definition: object.h:245
ObjectPtr()
default constructor
Definition: object.h:348
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:572
static uint32_t _GetOrAllocRuntimeTypeIndex()
Definition: object.h:208
static constexpr uint32_t _type_child_slots
Definition: object.h:217
void reset()
reset the content of ptr to be nullptr
Definition: object.h:437
std::atomic< int32_t > RefCounterType
Definition: object.h:201
~ObjectPtr()
destructor
Definition: object.h:387
SubRef Downcast(BaseRef ref)
Downcast a base reference type to a more specific type.
Definition: object.h:809
Base class of object allocators that implements make. Use curiously recurring template pattern...
Definition: memory.h:60
bool operator==(std::nullptr_t null) const
Definition: object.h:460
Definition: object.h:57
bool defined() const
Definition: object.h:538
ObjectPtr< BaseType > GetObjectPtr(ObjectType *ptr)
Get an object ptr type from a raw object ptr.
#define MXNET_DLL
MXNET_DLL prefix for windows.
Definition: c_api.h:54
TypeIndex
list of the type index.
Definition: object.h:51
ObjectPtr(std::nullptr_t)
default constructor
Definition: object.h:350
size_t operator()(const ObjectRef &a) const
Definition: object.h:618
static constexpr bool _type_final
Definition: object.h:216
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:785
uint32_t type_index() const
Definition: object.h:157
Type index is allocated during runtime.
Definition: object.h:63
Internal base class to handle conversion to POD values.
Definition: packed_func.h:387
T & operator*() const
Definition: object.h:412
void(* FDeleter)(Object *self)
Object deleter.
Definition: object.h:155
static MXNET_DLL uint32_t TypeKey2Index(const std::string &key)
Get the type index of the corresponding key from runtime.
Definition: object.h:54
ObjectPtr< T > & operator=(const ObjectPtr< T > &other)
copy assignment
Definition: object.h:420
bool operator<(const ObjectRef &other) const
Comparator.
Definition: object.h:534
FDeleter deleter_
deleter of this object to enable customized allocation. If the deleter is nullptr, no deletion will be performed. The creator of the object must always set the deleter field properly.
Definition: object.h:253
Return value container. Unlike MXNetArgValue, which only holds a reference and does not delete the underl...
Definition: packed_func.h:552
static MXNET_DLL size_t TypeIndex2KeyHash(uint32_t tindex)
Get the type key hash of the corresponding index from runtime.
void IncRef()
developer function, increases reference counter.
Definition: object.h:712
static MXNET_DLL uint32_t GetOrAllocRuntimeTypeIndex(const std::string &key, uint32_t static_tindex, uint32_t parent_tindex, uint32_t type_child_slots, bool type_child_slots_can_overflow)
Get the type index using type key.
RefType GetRef(const ObjectType *ptr)
Get a reference type from a raw object ptr type.
ObjectRef equal functor.
Definition: object.h:630
bool operator==(const ObjectRef &other) const
Comparator.
Definition: object.h:518