mxnet
lazy_alloc_array.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_COMMON_LAZY_ALLOC_ARRAY_H_
26 #define MXNET_COMMON_LAZY_ALLOC_ARRAY_H_
27 
28 #include <dmlc/logging.h>
29 #include <memory>
30 #include <mutex>
31 #include <array>
32 #include <vector>
33 #include <atomic>
34 
35 namespace mxnet {
36 namespace common {
37 
// Lazily allocated, index-addressed array of std::shared_ptr<TElem>.
// Indices [0, kInitSize) live in the fixed-size head_ array; larger indices
// spill into the growable more_ vector. Elements are created on first access
// by Get() and released by Clear().
// NOTE(review): listing line 39 (the class head, presumably
// `class LazyAllocArray {`) and the in-class constructor declaration are not
// shown in this rendered listing -- confirm against the original header.
38 template<typename TElem>
40  public:
 // Returns the element at `index` (CHECK_GE enforces index >= 0), creating it
 // on first access by wrapping the result of `creator()` in a
 // std::shared_ptr<TElem> (so creator presumably returns a newly allocated
 // TElem* -- confirm at call sites). Once the exit flag has been set by
 // SignalForKill() or Clear(), no element is created and nullptr is returned
 // instead; for index >= kInitSize nullptr is returned even if the element
 // already exists.
48  template<typename FCreate>
49  inline std::shared_ptr<TElem> Get(int index, FCreate creator);
 // Calls fvisit(i, elem) with the logical index and raw pointer of every
 // non-null element, in index order, while holding create_mutex_. Calling
 // back into this array from fvisit would re-lock create_mutex_.
54  template<typename FVisit>
55  inline void ForEach(FVisit fvisit);
 // Sets the exit flag and releases every element held by the array.
57  inline void Clear();
58 
 // Sets the exit flag without releasing existing elements.
59  void SignalForKill();
60 
61  private:
 // Scope guard that is the inverse of std::unique_lock: it unlocks the given
 // lock on construction and re-locks it on destruction. A null pointer makes
 // it a no-op.
 // NOTE(review): the copy operations are implicitly generated, and a copy
 // would make both destructors call lock(). Consider deleting them.
62  template<typename SyncObject>
63  class unique_unlock {
64  public:
65  explicit unique_unlock(std::unique_lock<SyncObject> *lock)
66  : lock_(lock) {
67  if (lock_) {
68  lock_->unlock();
69  }
70  }
71  ~unique_unlock() {
72  if (lock_) {
73  lock_->lock();
74  }
75  }
76  private:
 // Lock being temporarily released; not owned.
77  std::unique_lock<SyncObject> *lock_;
78  };
79 
 // Number of slots in the fixed-size head_ array.
81  static constexpr std::size_t kInitSize = 16;
 // Held while elements are created (Get slow path), cleared (Clear), visited
 // (ForEach), and while the exit flag is set. Get() reads head_ WITHOUT it on
 // its fast path.
83  std::mutex create_mutex_;
 // Storage for indices [0, kInitSize).
85  std::array<std::shared_ptr<TElem>, kInitSize> head_;
 // Storage for indices >= kInitSize, at offset (index - kInitSize); grown on
 // demand by Get() and never shrunk by the code in this header.
87  std::vector<std::shared_ptr<TElem> > more_;
 // Set by SignalForKill() and Clear(); once true, Get() creates no elements.
89  std::atomic<bool> exit_now_;
90 };
91 
// Constructs an array with no elements (head_ slots are null, more_ is empty)
// and the exit flag cleared.
// NOTE(review): listing line 93 (the declarator, presumably
// `inline LazyAllocArray<TElem>::LazyAllocArray()`) is missing from this
// rendered listing.
92 template<typename TElem>
94  : exit_now_(false) {
95 }
96 
97 // implementations
98 template<typename TElem>
99 template<typename FCreate>
100 inline std::shared_ptr<TElem> LazyAllocArray<TElem>::Get(int index, FCreate creator) {
101  CHECK_GE(index, 0);
102  size_t idx = static_cast<size_t>(index);
103  if (idx < kInitSize) {
104  std::shared_ptr<TElem> ptr = head_[idx];
105  if (ptr) {
106  return ptr;
107  } else {
108  std::lock_guard<std::mutex> lock(create_mutex_);
109  if (!exit_now_.load()) {
110  std::shared_ptr<TElem> ptr = head_[idx];
111  if (ptr) {
112  return ptr;
113  }
114  ptr = head_[idx] = std::shared_ptr<TElem>(creator());
115  return ptr;
116  }
117  }
118  } else {
119  std::lock_guard<std::mutex> lock(create_mutex_);
120  if (!exit_now_.load()) {
121  idx -= kInitSize;
122  if (more_.size() <= idx) {
123  more_.reserve(idx + 1);
124  while (more_.size() <= idx) {
125  more_.push_back(std::shared_ptr<TElem>(nullptr));
126  }
127  }
128  std::shared_ptr<TElem> ptr = more_[idx];
129  if (ptr) {
130  return ptr;
131  }
132  ptr = more_[idx] = std::shared_ptr<TElem>(creator());
133  return ptr;
134  }
135  }
136  return nullptr;
137 }
138 
// Sets the exit flag and releases every element held by the array.
// Each slot is nulled while create_mutex_ is held, but the last local
// reference taken from it (`p`) is dropped with the mutex temporarily released
// through unique_unlock. Presumably this lets an element's destructor run
// without create_mutex_ held; confirm intent.
// NOTE(review): listing line 140 (the declarator, presumably
// `inline void LazyAllocArray<TElem>::Clear() {`) is missing from this
// rendered listing.
139 template<typename TElem>
141  std::unique_lock<std::mutex> lock(create_mutex_);
142  exit_now_.store(true);
143  // Currently, head_ and more_ never get smaller, so it's safe to
144  // iterate them outside of the lock. The loops should catch
145  // any growth which might happen when create_mutex_ is unlocked
146  for (size_t i = 0; i < head_.size(); ++i) {
147  std::shared_ptr<TElem> p = head_[i];
148  head_[i] = std::shared_ptr<TElem>(nullptr);
 // Release the lock only while dropping `p`; the unlocker's destructor
 // re-locks before the loop condition is evaluated again.
149  unique_unlock<std::mutex> unlocker(&lock);
150  p = std::shared_ptr<TElem>(nullptr);
151  }
152  for (size_t i = 0; i < more_.size(); ++i) {
153  std::shared_ptr<TElem> p = more_[i];
154  more_[i] = std::shared_ptr<TElem>(nullptr);
 // Same pattern as above: drop the element reference with the lock released.
155  unique_unlock<std::mutex> unlocker(&lock);
156  p = std::shared_ptr<TElem>(nullptr);
157  }
158 }
159 
160 template<typename TElem>
161 template<typename FVisit>
162 inline void LazyAllocArray<TElem>::ForEach(FVisit fvisit) {
163  std::lock_guard<std::mutex> lock(create_mutex_);
164  for (size_t i = 0; i < head_.size(); ++i) {
165  if (head_[i].get() != nullptr) {
166  fvisit(i, head_[i].get());
167  }
168  }
169  for (size_t i = 0; i < more_.size(); ++i) {
170  if (more_[i].get() != nullptr) {
171  fvisit(i + kInitSize, more_[i].get());
172  }
173  }
174 }
175 
// Sets the exit flag under create_mutex_ so that subsequent Get() calls stop
// creating elements. Unlike Clear(), existing elements are left in place.
// NOTE(review): listing line 177 (the declarator, presumably
// `inline void LazyAllocArray<TElem>::SignalForKill() {`) is missing from this
// rendered listing.
176 template<typename TElem>
178  std::lock_guard<std::mutex> lock(create_mutex_);
179  exit_now_.store(true);
180 }
181 
182 } // namespace common
183 } // namespace mxnet
184 #endif // MXNET_COMMON_LAZY_ALLOC_ARRAY_H_
std::shared_ptr< TElem > Get(int index, FCreate creator)
Get the element at the given index, creating it with creator if it does not exist yet.
Definition: lazy_alloc_array.h:100
namespace of mxnet
Definition: base.h:126
void Clear()
Clear all the allocated elements in the array.
Definition: lazy_alloc_array.h:140
void ForEach(FVisit fvisit)
Call fvisit for each non-null element of the array.
Definition: lazy_alloc_array.h:162
Definition: lazy_alloc_array.h:39
void SignalForKill()
Definition: lazy_alloc_array.h:177
LazyAllocArray()
Definition: lazy_alloc_array.h:93