mxnet
object_pool.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_COMMON_OBJECT_POOL_H_
21 #define MXNET_COMMON_OBJECT_POOL_H_
#include <dmlc/logging.h>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <new>
#include <utility>
#include <vector>
27 
28 namespace mxnet {
29 namespace common {
33 template <typename T>
34 class ObjectPool {
35  public:
39  ~ObjectPool();
44  template <typename... Args>
45  T* New(Args&&... args);
52  void Delete(T* ptr);
53 
58  static ObjectPool* Get();
59 
64  static const std::shared_ptr<ObjectPool>& _GetSharedRef();
65 
66  private:
70  struct LinkedList {
71 #if defined(_MSC_VER)
72  T t;
73  LinkedList* next{nullptr};
74 #else
75  union {
76  T t;
77  LinkedList* next{nullptr};
78  };
79 #endif
80  };
86  constexpr static std::size_t kPageSize = 1 << 12;
88  std::mutex m_;
92  LinkedList* head_{nullptr};
96  std::vector<void*> allocated_;
100  ObjectPool();
106  void AllocateChunk();
107  DISALLOW_COPY_AND_ASSIGN(ObjectPool);
108 }; // class ObjectPool
109 
/*!
 * \brief Helper trait class for easy allocation and deallocation.
 *
 * Both members forward to the singleton ObjectPool<T>.
 */
template <typename T>
struct ObjectPoolAllocatable {
  /*!
   * \brief Create new object.
   * \param args Arguments forwarded to the constructor of T.
   * \return Pointer to the new object, taken from ObjectPool<T>.
   */
  template <typename... Args>
  static T* New(Args&&... args);
  /*!
   * \brief Delete an existing object.
   * \param ptr Object previously created by New().
   */
  static void Delete(T* ptr);
};  // struct ObjectPoolAllocatable
129 
130 template <typename T>
132  for (auto i : allocated_) {
133 #ifdef _MSC_VER
134  _aligned_free(i);
135 #else
136  free(i);
137 #endif
138  }
139 }
140 
141 template <typename T>
142 template <typename... Args>
143 T* ObjectPool<T>::New(Args&&... args) {
144  LinkedList* ret;
145  {
146  std::lock_guard<std::mutex> lock{m_};
147  if (head_->next == nullptr) {
148  AllocateChunk();
149  }
150  ret = head_;
151  head_ = head_->next;
152  }
153  return new (static_cast<void*>(ret)) T(std::forward<Args>(args)...);
154 }
155 
156 template <typename T>
157 void ObjectPool<T>::Delete(T* ptr) {
158  ptr->~T();
159  auto linked_list_ptr = reinterpret_cast<LinkedList*>(ptr);
160  {
161  std::lock_guard<std::mutex> lock{m_};
162  linked_list_ptr->next = head_;
163  head_ = linked_list_ptr;
164  }
165 }
166 
167 template <typename T>
169  return _GetSharedRef().get();
170 }
171 
172 template <typename T>
173 const std::shared_ptr<ObjectPool<T> >& ObjectPool<T>::_GetSharedRef() {
174  static std::shared_ptr<ObjectPool<T> > inst_ptr(new ObjectPool<T>());
175  return inst_ptr;
176 }
177 
178 template <typename T>
180  AllocateChunk();
181 }
182 
183 template <typename T>
184 void ObjectPool<T>::AllocateChunk() {
185  static_assert(sizeof(LinkedList) <= kPageSize, "Object too big.");
186  static_assert(sizeof(LinkedList) % alignof(LinkedList) == 0, "ObjectPooll Invariant");
187  static_assert(alignof(LinkedList) % alignof(T) == 0, "ObjectPooll Invariant");
188  static_assert(kPageSize % alignof(LinkedList) == 0, "ObjectPooll Invariant");
189  void* new_chunk_ptr;
190 #ifdef _MSC_VER
191  new_chunk_ptr = _aligned_malloc(kPageSize, kPageSize);
192  CHECK(new_chunk_ptr != nullptr) << "Allocation failed";
193 #else
194  int ret = posix_memalign(&new_chunk_ptr, kPageSize, kPageSize);
195  CHECK_EQ(ret, 0) << "Allocation failed";
196 #endif
197  allocated_.emplace_back(new_chunk_ptr);
198  auto new_chunk = static_cast<LinkedList*>(new_chunk_ptr);
199  auto size = kPageSize / sizeof(LinkedList);
200  for (std::size_t i = 0; i < size - 1; ++i) {
201  new_chunk[i].next = &new_chunk[i + 1];
202  }
203  new_chunk[size - 1].next = head_;
204  head_ = new_chunk;
205 }
206 
207 template <typename T>
208 template <typename... Args>
209 T* ObjectPoolAllocatable<T>::New(Args&&... args) {
210  return ObjectPool<T>::Get()->New(std::forward<Args>(args)...);
211 }
212 
213 template <typename T>
215  ObjectPool<T>::Get()->Delete(ptr);
216 }
217 
218 } // namespace common
219 } // namespace mxnet
220 #endif // MXNET_COMMON_OBJECT_POOL_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::common::ObjectPool::Delete
void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:157
mxnet::common::ObjectPoolAllocatable::Delete
static void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:214
mxnet::common::cuda::rtc::lock
std::mutex lock
mxnet::common::ObjectPoolAllocatable::New
static T * New(Args &&... args)
Create a new object.
Definition: object_pool.h:209
mxnet::common::ObjectPool::_GetSharedRef
static const std::shared_ptr< ObjectPool > & _GetSharedRef()
Get a shared pointer to the singleton instance of the pool.
Definition: object_pool.h:173
mxnet::common::ObjectPool::Get
static ObjectPool * Get()
Get the singleton instance of the pool.
Definition: object_pool.h:168
mxnet::common::ObjectPool::~ObjectPool
~ObjectPool()
Destructor.
Definition: object_pool.h:131
mxnet::common::ObjectPool::New
T * New(Args &&... args)
Create a new object.
Definition: object_pool.h:143
mxnet::common::ObjectPoolAllocatable
Helper trait class for easy allocation and deallocation.
Definition: object_pool.h:114
mxnet::common::ObjectPool
Object pool for fast allocation and deallocation.
Definition: object_pool.h:34