mxnet
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
object_pool.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
23 #ifndef MXNET_COMMON_OBJECT_POOL_H_
24 #define MXNET_COMMON_OBJECT_POOL_H_
25 #include <dmlc/logging.h>
26 #include <cstdlib>
27 #include <mutex>
28 #include <utility>
29 #include <vector>
30 
31 namespace mxnet {
32 namespace common {
36 template <typename T>
37 class ObjectPool {
38  public:
42  ~ObjectPool();
47  template <typename... Args>
48  T* New(Args&&... args);
55  void Delete(T* ptr);
56 
61  static ObjectPool* Get();
62 
67  static std::shared_ptr<ObjectPool> _GetSharedRef();
68 
69  private:
73  struct LinkedList {
74 #if defined(_MSC_VER)
75  T t;
76  LinkedList* next{nullptr};
77 #else
78  union {
79  T t;
80  LinkedList* next{nullptr};
81  };
82 #endif
83  };
89  constexpr static std::size_t kPageSize = 1 << 12;
91  std::mutex m_;
95  LinkedList* head_{nullptr};
99  std::vector<void*> allocated_;
103  ObjectPool();
109  void AllocateChunk();
110  DISALLOW_COPY_AND_ASSIGN(ObjectPool);
111 }; // class ObjectPool
112 
/*!
 * \brief Helper trait class for easy allocation and deallocation.
 * \tparam T type of the objects to create and destroy.
 */
template <typename T>
struct ObjectPoolAllocatable {
  /*!
   * \brief Create new object.
   * \param args arguments forwarded to the constructor of T.
   * \return pointer to the newly constructed object.
   */
  template <typename... Args>
  static T* New(Args&&... args);
  /*!
   * \brief Delete an existing object.
   * \param ptr pointer to an object previously obtained from New.
   */
  static void Delete(T* ptr);
};  // struct ObjectPoolAllocatable
132 
133 template <typename T>
135  // TODO(hotpxl): mind destruction order
136  // for (auto i : allocated_) {
137  // free(i);
138  // }
139 }
140 
141 template <typename T>
142 template <typename... Args>
143 T* ObjectPool<T>::New(Args&&... args) {
144  LinkedList* ret;
145  {
146  std::lock_guard<std::mutex> lock{m_};
147  if (head_->next == nullptr) {
148  AllocateChunk();
149  }
150  ret = head_;
151  head_ = head_->next;
152  }
153  return new (static_cast<void*>(ret)) T(std::forward<Args>(args)...);
154 }
155 
156 template <typename T>
157 void ObjectPool<T>::Delete(T* ptr) {
158  ptr->~T();
159  auto linked_list_ptr = reinterpret_cast<LinkedList*>(ptr);
160  {
161  std::lock_guard<std::mutex> lock{m_};
162  linked_list_ptr->next = head_;
163  head_ = linked_list_ptr;
164  }
165 }
166 
167 template <typename T>
169  return _GetSharedRef().get();
170 }
171 
172 template <typename T>
173 std::shared_ptr<ObjectPool<T> > ObjectPool<T>::_GetSharedRef() {
174  static std::shared_ptr<ObjectPool<T> > inst_ptr(new ObjectPool<T>());
175  return inst_ptr;
176 }
177 
178 template <typename T>
180  AllocateChunk();
181 }
182 
183 template <typename T>
184 void ObjectPool<T>::AllocateChunk() {
185  static_assert(sizeof(LinkedList) <= kPageSize, "Object too big.");
186  static_assert(sizeof(LinkedList) % alignof(LinkedList) == 0, "ObjectPooll Invariant");
187  static_assert(alignof(LinkedList) % alignof(T) == 0, "ObjectPooll Invariant");
188  static_assert(kPageSize % alignof(LinkedList) == 0, "ObjectPooll Invariant");
189  void* new_chunk_ptr;
190 #ifdef _MSC_VER
191  new_chunk_ptr = _aligned_malloc(kPageSize, kPageSize);
192  CHECK(new_chunk_ptr != NULL) << "Allocation failed";
193 #else
194  int ret = posix_memalign(&new_chunk_ptr, kPageSize, kPageSize);
195  CHECK_EQ(ret, 0) << "Allocation failed";
196 #endif
197  allocated_.emplace_back(new_chunk_ptr);
198  auto new_chunk = static_cast<LinkedList*>(new_chunk_ptr);
199  auto size = kPageSize / sizeof(LinkedList);
200  for (std::size_t i = 0; i < size - 1; ++i) {
201  new_chunk[i].next = &new_chunk[i + 1];
202  }
203  new_chunk[size - 1].next = head_;
204  head_ = new_chunk;
205 }
206 
207 template <typename T>
208 template <typename... Args>
209 T* ObjectPoolAllocatable<T>::New(Args&&... args) {
210  return ObjectPool<T>::Get()->New(std::forward<Args>(args)...);
211 }
212 
213 template <typename T>
215  ObjectPool<T>::Get()->Delete(ptr);
216 }
217 
218 } // namespace common
219 } // namespace mxnet
220 #endif // MXNET_COMMON_OBJECT_POOL_H_
static void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:214
static T * New(Args &&...args)
Create new object.
Definition: object_pool.h:209
T * New(Args &&...args)
Create new object.
Definition: object_pool.h:143
static ObjectPool * Get()
Get singleton instance of pool.
Definition: object_pool.h:168
static std::shared_ptr< ObjectPool > _GetSharedRef()
Get a shared ptr of the singleton instance of pool.
Definition: object_pool.h:173
void Delete(T *ptr)
Delete an existing object.
Definition: object_pool.h:157
struct ObjectPoolAllocatable
Helper trait class for easy allocation and deallocation.
Definition: object_pool.h:117
~ObjectPool()
Destructor.
Definition: object_pool.h:134
class ObjectPool
Object pool for fast allocation and deallocation.
Definition: object_pool.h:37