mxnet
object_pool.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
23 #ifndef MXNET_COMMON_OBJECT_POOL_H_
24 #define MXNET_COMMON_OBJECT_POOL_H_
25 #include <dmlc/logging.h>
26 #include <cstdlib>
27 #include <mutex>
28 #include <utility>
29 #include <vector>
30 
31 namespace mxnet {
32 namespace common {
36 template <typename T>
37 class ObjectPool {
38  public:
42  ~ObjectPool();
47  template <typename... Args>
48  T* New(Args&&... args);
55  void Delete(T* ptr);
56 
61  static ObjectPool* Get();
62 
67  static std::shared_ptr<ObjectPool> _GetSharedRef();
68 
69  private:
73  struct LinkedList {
74 #if defined(_MSC_VER)
75  T t;
76  LinkedList* next{nullptr};
77 #else
78  union {
79  T t;
80  LinkedList* next{nullptr};
81  };
82 #endif
83  };
89  constexpr static std::size_t kPageSize = 1 << 12;
91  std::mutex m_;
95  LinkedList* head_{nullptr};
99  std::vector<void*> allocated_;
103  ObjectPool();
109  void AllocateChunk();
111 }; // class ObjectPool
112 
116 template <typename T>
122  template <typename... Args>
123  static T* New(Args&&... args);
130  static void Delete(T* ptr);
131 }; // struct ObjectPoolAllocatable
132 
133 template <typename T>
135  for (auto i : allocated_) {
136 #ifdef _MSC_VER
137  _aligned_free(i);
138 #else
139  free(i);
140 #endif
141  }
142 }
143 
144 template <typename T>
145 template <typename... Args>
146 T* ObjectPool<T>::New(Args&&... args) {
147  LinkedList* ret;
148  {
149  std::lock_guard<std::mutex> lock{m_};
150  if (head_->next == nullptr) {
151  AllocateChunk();
152  }
153  ret = head_;
154  head_ = head_->next;
155  }
156  return new (static_cast<void*>(ret)) T(std::forward<Args>(args)...);
157 }
158 
159 template <typename T>
160 void ObjectPool<T>::Delete(T* ptr) {
161  ptr->~T();
162  auto linked_list_ptr = reinterpret_cast<LinkedList*>(ptr);
163  {
164  std::lock_guard<std::mutex> lock{m_};
165  linked_list_ptr->next = head_;
166  head_ = linked_list_ptr;
167  }
168 }
169 
170 template <typename T>
172  return _GetSharedRef().get();
173 }
174 
175 template <typename T>
176 std::shared_ptr<ObjectPool<T> > ObjectPool<T>::_GetSharedRef() {
177  static std::shared_ptr<ObjectPool<T> > inst_ptr(new ObjectPool<T>());
178  return inst_ptr;
179 }
180 
181 template <typename T>
183  AllocateChunk();
184 }
185 
186 template <typename T>
188  static_assert(sizeof(LinkedList) <= kPageSize, "Object too big.");
189  static_assert(sizeof(LinkedList) % alignof(LinkedList) == 0, "ObjectPooll Invariant");
190  static_assert(alignof(LinkedList) % alignof(T) == 0, "ObjectPooll Invariant");
191  static_assert(kPageSize % alignof(LinkedList) == 0, "ObjectPooll Invariant");
192  void* new_chunk_ptr;
193 #ifdef _MSC_VER
194  new_chunk_ptr = _aligned_malloc(kPageSize, kPageSize);
195  CHECK(new_chunk_ptr != nullptr) << "Allocation failed";
196 #else
197  int ret = posix_memalign(&new_chunk_ptr, kPageSize, kPageSize);
198  CHECK_EQ(ret, 0) << "Allocation failed";
199 #endif
200  allocated_.emplace_back(new_chunk_ptr);
201  auto new_chunk = static_cast<LinkedList*>(new_chunk_ptr);
202  auto size = kPageSize / sizeof(LinkedList);
203  for (std::size_t i = 0; i < size - 1; ++i) {
204  new_chunk[i].next = &new_chunk[i + 1];
205  }
206  new_chunk[size - 1].next = head_;
207  head_ = new_chunk;
208 }
209 
210 template <typename T>
211 template <typename... Args>
212 T* ObjectPoolAllocatable<T>::New(Args&&... args) {
213  return ObjectPool<T>::Get()->New(std::forward<Args>(args)...);
214 }
215 
216 template <typename T>
218  ObjectPool<T>::Get()->Delete(ptr);
219 }
220 
221 } // namespace common
222 } // namespace mxnet
223 #endif // MXNET_COMMON_OBJECT_POOL_H_
static void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:217
static T * New(Args &&...args)
Create new object.
Definition: object_pool.h:212
namespace of mxnet
Definition: api_registry.h:33
T * New(Args &&...args)
Create new object.
Definition: object_pool.h:146
static ObjectPool * Get()
Get singleton instance of pool.
Definition: object_pool.h:171
static std::shared_ptr< ObjectPool > _GetSharedRef()
Get a shared ptr of the singleton instance of pool.
Definition: object_pool.h:176
void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:160
#define DISALLOW_COPY_AND_ASSIGN(T)
Disable copy constructor and assignment operator.
Definition: base.h:165
Helper trait class for easy allocation and deallocation.
Definition: object_pool.h:117
~ObjectPool()
Destructor.
Definition: object_pool.h:134
Object pool for fast allocation and deallocation.
Definition: object_pool.h:37