mxnet
container.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 // Acknowledgement: This file originates from incubator-tvm
25 #ifndef MXNET_RUNTIME_CONTAINER_H_
26 #define MXNET_RUNTIME_CONTAINER_H_
#include <dmlc/logging.h>
#include <mxnet/runtime/memory.h>
#include <mxnet/runtime/object.h>

#include <initializer_list>
#include <iterator>
#include <type_traits>
#include <utility>
#include <vector>
35 
36 namespace mxnet {
37 namespace runtime {
38 
39 class ADTBuilder;
79 template <typename ArrayType, typename ElemType>
81  public:
87  const ElemType& operator[](size_t idx) const {
88  size_t size = Self()->GetSize();
89  CHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n";
90  return *(reinterpret_cast<ElemType*>(AddressOf(idx)));
91  }
92 
98  ElemType& operator[](size_t idx) {
99  size_t size = Self()->GetSize();
100  CHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n";
101  return *(reinterpret_cast<ElemType*>(AddressOf(idx)));
102  }
103 
108  if (!(std::is_standard_layout<ElemType>::value && std::is_trivial<ElemType>::value)) {
109  size_t size = Self()->GetSize();
110  for (size_t i = 0; i < size; ++i) {
111  ElemType* fp = reinterpret_cast<ElemType*>(AddressOf(i));
112  fp->ElemType::~ElemType();
113  }
114  }
115  }
116 
117  protected:
118  friend class ADTBuilder;
129  template <typename... Args>
130  void EmplaceInit(size_t idx, Args&&... args) {
131  void* field_ptr = AddressOf(idx);
132  new (field_ptr) ElemType(std::forward<Args>(args)...);
133  }
134 
135  private:
141  inline ArrayType* Self() const {
142  return static_cast<ArrayType*>(const_cast<InplaceArrayBase*>(this));
143  }
144 
151  void* AddressOf(size_t idx) const {
152  static_assert(
153  alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0,
154  "The size and alignment of ArrayType should respect "
155  "ElemType's alignment.");
156 
157  size_t kDataStart = sizeof(ArrayType);
158  ArrayType* self = Self();
159  char* data_start = reinterpret_cast<char*>(self) + kDataStart;
160  return data_start + idx * sizeof(ElemType);
161  }
162 };
163 
165 class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
166  public:
168  uint32_t tag;
170  uint32_t size{0};
171  // The fields of the structure follows directly in memory.
172 
173  static constexpr const char* _type_key = "MXNet.ADT";
174  static constexpr const uint32_t _type_index = TypeIndex::kMXNetADT;
176 
177  private:
181  size_t GetSize() const {
182  return size;
183  }
184 
192  template <typename Iterator>
193  void Init(Iterator begin, Iterator end) {
194  size_t num_elems = std::distance(begin, end);
195  this->size = 0;
196  auto it = begin;
197  for (size_t i = 0; i < num_elems; ++i) {
199  // Only increment size after the initialization succeeds
200  this->size++;
201  }
202  }
203 
204  friend class ADT;
206 };
207 
209 class ADT : public ObjectRef {
210  public:
217  ADT(uint32_t tag, std::vector<ObjectRef> fields) : ADT(tag, fields.begin(), fields.end()){};
218 
226  template <typename Iterator>
227  ADT(uint32_t tag, Iterator begin, Iterator end) {
228  size_t num_elems = std::distance(begin, end);
229  auto ptr = make_inplace_array_object<ADTObj, ObjectRef>(num_elems);
230  ptr->tag = tag;
231  ptr->Init(begin, end);
232  data_ = std::move(ptr);
233  }
234 
241  ADT(uint32_t tag, std::initializer_list<ObjectRef> init) : ADT(tag, init.begin(), init.end()){};
242 
249  const ObjectRef& operator[](size_t idx) const {
250  return operator->()->operator[](idx);
251  }
252 
256  size_t tag() const {
257  return operator->()->tag;
258  }
259 
263  size_t size() const {
264  return operator->()->size;
265  }
266 
274  template <typename... Args>
275  static ADT Tuple(Args&&... args) {
276  return ADT(0, std::forward<Args>(args)...);
277  }
278 
280 };
281 
282 } // namespace runtime
283 } // namespace mxnet
284 
285 #endif // MXNET_RUNTIME_CONTAINER_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::runtime::ADT::operator[]
const ObjectRef & operator[](size_t idx) const
Access element at index.
Definition: container.h:249
mxnet::runtime::Object
base class of all object containers.
Definition: object.h:151
mxnet::runtime::InplaceArrayBase::operator[]
const ElemType & operator[](size_t idx) const
Access element at index.
Definition: container.h:87
mxnet::runtime::ADT::tag
size_t tag() const
Return the ADT tag.
Definition: container.h:256
mxnet::runtime::ADTObj::tag
uint32_t tag
The tag representing the constructor used.
Definition: container.h:168
mxnet::runtime::ADT::Tuple
static ADT Tuple(Args &&... args)
Construct a tuple object.
Definition: container.h:275
mxnet::runtime::ADT::size
size_t size() const
Return the number of fields.
Definition: container.h:263
mxnet::runtime::kMXNetADT
@ kMXNetADT
Definition: object.h:56
mxnet::runtime::InplaceArrayBase::~InplaceArrayBase
~InplaceArrayBase()
Destroy the Inplace Array Base object.
Definition: container.h:107
mxnet::runtime::ADT::ADT
ADT(uint32_t tag, std::vector< ObjectRef > fields)
construct an ADT object reference.
Definition: container.h:217
MXNET_DECLARE_FINAL_OBJECT_INFO
#define MXNET_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:668
mxnet::runtime::InplaceArrayBase
Base template for classes with array like memory layout.
Definition: container.h:80
mxnet::runtime::ADT
reference to algebraic data type objects.
Definition: container.h:209
mxnet::runtime::InplaceArrayBase::EmplaceInit
void EmplaceInit(size_t idx, Args &&... args)
Construct a value in place with the arguments.
Definition: container.h:130
mxnet::runtime::ADT::ADT
ADT(uint32_t tag, std::initializer_list< ObjectRef > init)
construct an ADT object reference.
Definition: container.h:241
mxnet::runtime::InplaceArrayBase::operator[]
ElemType & operator[](size_t idx)
Access element at index.
Definition: container.h:98
memory.h
Runtime memory management.
mxnet::runtime::ADT::ADT
ADT(uint32_t tag, Iterator begin, Iterator end)
construct an ADT object reference.
Definition: container.h:227
mxnet::runtime::ObjectRef::data_
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:575
mxnet::runtime::ADTObj::_type_index
static constexpr const uint32_t _type_index
Definition: container.h:174
mxnet::runtime::ObjectRef::operator->
const Object * operator->() const
Definition: object.h:547
mxnet::runtime::ADTObj
An object representing a structure or enumeration.
Definition: container.h:165
mxnet::runtime::ADTObj::size
uint32_t size
Number of fields in the ADT object.
Definition: container.h:170
mxnet::runtime::ADTObj::_type_key
static constexpr const char * _type_key
Definition: container.h:173
mxnet::runtime::ObjectRef
Base class of all object reference.
Definition: object.h:500
MXNET_DEFINE_OBJECT_REF_METHODS
#define MXNET_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:689
mxnet::runtime::ADTBuilder
A builder class that helps to incrementally build ADT.
Definition: ffi_helper.h:122
object.h
A managed object in MXNet runtime.