mxnet
object_pool.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
22 #ifndef MXNET_COMMON_OBJECT_POOL_H_
23 #define MXNET_COMMON_OBJECT_POOL_H_
24 #include <dmlc/logging.h>
25 #include <cstdlib>
26 #include <mutex>
27 #include <utility>
28 #include <vector>
29 
30 namespace mxnet {
31 namespace common {
35 template <typename T>
36 class ObjectPool {
37  public:
41  ~ObjectPool();
46  template <typename... Args>
47  T* New(Args&&... args);
54  void Delete(T* ptr);
55 
60  static ObjectPool* Get();
61 
66  static std::shared_ptr<ObjectPool> _GetSharedRef();
67 
68  private:
72  struct LinkedList {
73 #if defined(_MSC_VER)
74  T t;
75  LinkedList* next{nullptr};
76 #else
77  union {
78  T t;
79  LinkedList* next{nullptr};
80  };
81 #endif
82  };
88  constexpr static std::size_t kPageSize = 1 << 12;
90  std::mutex m_;
94  LinkedList* head_{nullptr};
98  std::vector<void*> allocated_;
102  ObjectPool();
108  void AllocateChunk();
109  DISALLOW_COPY_AND_ASSIGN(ObjectPool);
110 }; // class ObjectPool
111 
/*!
 * \brief Helper trait class for easy allocation and deallocation.
 *
 * Exposes static New()/Delete() entry points for objects of type T; the
 * out-of-class definitions route both calls through ObjectPool<T>::Get().
 *
 * \tparam T Type of the objects to allocate.
 */
template <typename T>
struct ObjectPoolAllocatable {
  /*!
   * \brief Create new object.
   * \param args Arguments forwarded to the constructor of T.
   * \return Pointer to the newly created object.
   */
  template <typename... Args>
  static T* New(Args&&... args);
  /*!
   * \brief Delete an existing object.
   * \param ptr Object to destroy.
   */
  static void Delete(T* ptr);
};  // struct ObjectPoolAllocatable
131 
132 template <typename T>
134  // TODO(hotpxl): mind destruction order
135  // for (auto i : allocated_) {
136  // free(i);
137  // }
138 }
139 
140 template <typename T>
141 template <typename... Args>
142 T* ObjectPool<T>::New(Args&&... args) {
143  LinkedList* ret;
144  {
145  std::lock_guard<std::mutex> lock{m_};
146  if (head_->next == nullptr) {
147  AllocateChunk();
148  }
149  ret = head_;
150  head_ = head_->next;
151  }
152  return new (static_cast<void*>(ret)) T(std::forward<Args>(args)...);
153 }
154 
155 template <typename T>
156 void ObjectPool<T>::Delete(T* ptr) {
157  ptr->~T();
158  auto linked_list_ptr = reinterpret_cast<LinkedList*>(ptr);
159  {
160  std::lock_guard<std::mutex> lock{m_};
161  linked_list_ptr->next = head_;
162  head_ = linked_list_ptr;
163  }
164 }
165 
166 template <typename T>
168  return _GetSharedRef().get();
169 }
170 
171 template <typename T>
172 std::shared_ptr<ObjectPool<T> > ObjectPool<T>::_GetSharedRef() {
173  static std::shared_ptr<ObjectPool<T> > inst_ptr(new ObjectPool<T>());
174  return inst_ptr;
175 }
176 
177 template <typename T>
179  AllocateChunk();
180 }
181 
182 template <typename T>
184  static_assert(sizeof(LinkedList) <= kPageSize, "Object too big.");
185  static_assert(sizeof(LinkedList) % alignof(LinkedList) == 0, "ObjectPooll Invariant");
186  static_assert(alignof(LinkedList) % alignof(T) == 0, "ObjectPooll Invariant");
187  static_assert(kPageSize % alignof(LinkedList) == 0, "ObjectPooll Invariant");
188  void* new_chunk_ptr;
189 #ifdef _MSC_VER
190  new_chunk_ptr = _aligned_malloc(kPageSize, kPageSize);
191  CHECK(new_chunk_ptr != NULL) << "Allocation failed";
192 #else
193  int ret = posix_memalign(&new_chunk_ptr, kPageSize, kPageSize);
194  CHECK_EQ(ret, 0) << "Allocation failed";
195 #endif
196  allocated_.emplace_back(new_chunk_ptr);
197  auto new_chunk = static_cast<LinkedList*>(new_chunk_ptr);
198  auto size = kPageSize / sizeof(LinkedList);
199  for (std::size_t i = 0; i < size - 1; ++i) {
200  new_chunk[i].next = &new_chunk[i + 1];
201  }
202  new_chunk[size - 1].next = head_;
203  head_ = new_chunk;
204 }
205 
206 template <typename T>
207 template <typename... Args>
208 T* ObjectPoolAllocatable<T>::New(Args&&... args) {
209  return ObjectPool<T>::Get()->New(std::forward<Args>(args)...);
210 }
211 
212 template <typename T>
214  ObjectPool<T>::Get()->Delete(ptr);
215 }
216 
217 } // namespace common
218 } // namespace mxnet
219 #endif // MXNET_COMMON_OBJECT_POOL_H_
static void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:213
static T * New(Args &&...args)
Create new object.
Definition: object_pool.h:208
namespace of mxnet
Definition: base.h:126
T * New(Args &&...args)
Create new object.
Definition: object_pool.h:142
static ObjectPool * Get()
Get singleton instance of pool.
Definition: object_pool.h:167
static std::shared_ptr< ObjectPool > _GetSharedRef()
Get a shared ptr of the singleton instance of pool.
Definition: object_pool.h:172
void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:156
Helper trait class for easy allocation and deallocation.
Definition: object_pool.h:116
~ObjectPool()
Destructor.
Definition: object_pool.h:133
Object pool for fast allocation and deallocation.
Definition: object_pool.h:36