mxnet
object_pool.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
23 #ifndef MXNET_COMMON_OBJECT_POOL_H_
24 #define MXNET_COMMON_OBJECT_POOL_H_
#include <dmlc/logging.h>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
30 
31 namespace mxnet {
32 namespace common {
/*!
 * \brief Object pool for fast allocation and deallocation.
 *
 * Storage is handed out from fixed-size chunks of kPageSize bytes. Free slots
 * are threaded into an intrusive singly linked list whose head is `head_`.
 * `New` and `Delete` take `m_` while they manipulate that list.
 *
 * \tparam T type of the pooled objects.
 */
template <typename T>
class ObjectPool {
 public:
  /*!
   * \brief Destructor. Releases every chunk recorded in `allocated_`.
   */
  ~ObjectPool();
  /*!
   * \brief The pool owns raw chunks that the destructor releases, so copying
   *        it would release them twice. Copy operations are disabled.
   */
  ObjectPool(const ObjectPool&) = delete;
  ObjectPool& operator=(const ObjectPool&) = delete;
  /*!
   * \brief Create new object.
   * \param args arguments forwarded to the constructor of T.
   * \return pointer to the newly constructed object.
   */
  template <typename... Args>
  T* New(Args&&... args);
  /*!
   * \brief Delete an existing object.
   * \param ptr the object to destroy; its slot goes back onto the free list.
   */
  void Delete(T* ptr);

  /*!
   * \brief Get singleton instance of pool.
   * \return raw pointer to the singleton.
   */
  static ObjectPool* Get();

  /*!
   * \brief Get a shared ptr of the singleton instance of pool.
   * \return shared pointer to the singleton.
   */
  static std::shared_ptr<ObjectPool> _GetSharedRef();

 private:
  /*!
   * \brief Slot type: either a live T or a link to the next free slot.
   *
   * Outside MSVC the two share storage through an anonymous union; under
   * MSVC they are laid out as two separate members.
   */
  struct LinkedList {
#if defined(_MSC_VER)
    T t;
    LinkedList* next{nullptr};
#else
    union {
      T t;
      LinkedList* next{nullptr};
    };
#endif
  };
  /*! \brief Size in bytes (and alignment) of each chunk: 4096. */
  constexpr static std::size_t kPageSize = 1 << 12;
  /*! \brief Guards `head_` in New and Delete. */
  std::mutex m_;
  /*! \brief First free slot of the free list. */
  LinkedList* head_{nullptr};
  /*! \brief Every chunk obtained so far; released in the destructor. */
  std::vector<void*> allocated_;
  /*! \brief Private constructor; instances are obtained via _GetSharedRef(). */
  ObjectPool();
  /*! \brief Obtain one more chunk and splice its slots in front of `head_`. */
  void AllocateChunk();
};  // class ObjectPool
112 
/*!
 * \brief Helper trait class for easy allocation and deallocation.
 * \tparam T type of the objects created and destroyed through the helper.
 */
template <typename T>
struct ObjectPoolAllocatable {
  /*!
   * \brief Create new object.
   * \param args arguments forwarded to the constructor of T.
   * \return pointer to the newly constructed object.
   */
  template <typename... Args>
  static T* New(Args&&... args);
  /*!
   * \brief Delete an existing object.
   * \param ptr the object to destroy.
   */
  static void Delete(T* ptr);
};  // struct ObjectPoolAllocatable
132 
133 template <typename T>
135  for (auto i : allocated_) {
136  free(i);
137  }
138 }
139 
140 template <typename T>
141 template <typename... Args>
142 T* ObjectPool<T>::New(Args&&... args) {
143  LinkedList* ret;
144  {
145  std::lock_guard<std::mutex> lock{m_};
146  if (head_->next == nullptr) {
147  AllocateChunk();
148  }
149  ret = head_;
150  head_ = head_->next;
151  }
152  return new (static_cast<void*>(ret)) T(std::forward<Args>(args)...);
153 }
154 
155 template <typename T>
156 void ObjectPool<T>::Delete(T* ptr) {
157  ptr->~T();
158  auto linked_list_ptr = reinterpret_cast<LinkedList*>(ptr);
159  {
160  std::lock_guard<std::mutex> lock{m_};
161  linked_list_ptr->next = head_;
162  head_ = linked_list_ptr;
163  }
164 }
165 
166 template <typename T>
168  return _GetSharedRef().get();
169 }
170 
171 template <typename T>
172 std::shared_ptr<ObjectPool<T> > ObjectPool<T>::_GetSharedRef() {
173  static std::shared_ptr<ObjectPool<T> > inst_ptr(new ObjectPool<T>());
174  return inst_ptr;
175 }
176 
177 template <typename T>
179  AllocateChunk();
180 }
181 
182 template <typename T>
184  static_assert(sizeof(LinkedList) <= kPageSize, "Object too big.");
185  static_assert(sizeof(LinkedList) % alignof(LinkedList) == 0, "ObjectPooll Invariant");
186  static_assert(alignof(LinkedList) % alignof(T) == 0, "ObjectPooll Invariant");
187  static_assert(kPageSize % alignof(LinkedList) == 0, "ObjectPooll Invariant");
188  void* new_chunk_ptr;
189 #ifdef _MSC_VER
190  new_chunk_ptr = _aligned_malloc(kPageSize, kPageSize);
191  CHECK(new_chunk_ptr != NULL) << "Allocation failed";
192 #else
193  int ret = posix_memalign(&new_chunk_ptr, kPageSize, kPageSize);
194  CHECK_EQ(ret, 0) << "Allocation failed";
195 #endif
196  allocated_.emplace_back(new_chunk_ptr);
197  auto new_chunk = static_cast<LinkedList*>(new_chunk_ptr);
198  auto size = kPageSize / sizeof(LinkedList);
199  for (std::size_t i = 0; i < size - 1; ++i) {
200  new_chunk[i].next = &new_chunk[i + 1];
201  }
202  new_chunk[size - 1].next = head_;
203  head_ = new_chunk;
204 }
205 
206 template <typename T>
207 template <typename... Args>
208 T* ObjectPoolAllocatable<T>::New(Args&&... args) {
209  return ObjectPool<T>::Get()->New(std::forward<Args>(args)...);
210 }
211 
212 template <typename T>
214  ObjectPool<T>::Get()->Delete(ptr);
215 }
216 
217 } // namespace common
218 } // namespace mxnet
219 #endif // MXNET_COMMON_OBJECT_POOL_H_
static void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:213
static T * New(Args &&...args)
Create new object.
Definition: object_pool.h:208
namespace of mxnet
Definition: base.h:89
T * New(Args &&...args)
Create new object.
Definition: object_pool.h:142
static ObjectPool * Get()
Get singleton instance of pool.
Definition: object_pool.h:167
static std::shared_ptr< ObjectPool > _GetSharedRef()
Get a shared ptr of the singleton instance of pool.
Definition: object_pool.h:172
void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:156
#define DISALLOW_COPY_AND_ASSIGN(T)
Disable copy constructor and assignment operator.
Definition: base.h:165
Helper trait class for easy allocation and deallocation.
Definition: object_pool.h:117
~ObjectPool()
Destructor.
Definition: object_pool.h:134
Object pool for fast allocation and deallocation.
Definition: object_pool.h:37