25 #ifndef MXNET_COMMON_LAZY_ALLOC_ARRAY_H_
26 #define MXNET_COMMON_LAZY_ALLOC_ARRAY_H_
28 #include <dmlc/logging.h>
38 template <
typename TElem>
// Declaration of the member template Get: takes an integer slot index and a
// creator callable, and returns a std::shared_ptr<TElem>.
// In the definition fragment visible later in this file, the result of
// `creator()` is wrapped in a std::shared_ptr<TElem> when a slot is filled.
48 template <
typename FCreate>
49 inline std::shared_ptr<TElem>
Get(
int index, FCreate creator);
// Declaration of the member template ForEach: takes a visitor callable by
// value and returns nothing.
// The definition fragment visible later in this file invokes the visitor as
// fvisit(index, raw_pointer) for slots whose pointer is not null.
54 template <
typename FVisit>
55 inline void ForEach(FVisit fvisit);
// Helper template parameterised on the mutex type (SyncObject). Its explicit
// constructor receives a pointer to a std::unique_lock<SyncObject> and stores
// it in lock_.
// NOTE(review): the constructor body and any destructor are not visible in
// this view; the name suggests the lock is released while the helper is alive
// and re-acquired afterwards — confirm against the full file.
60 template <
typename SyncObject>
63 explicit unique_unlock(std::unique_lock<SyncObject>*
lock) : lock_(
lock) {
// Non-owning pointer to the caller's lock object (initialised from the
// constructor argument above).
75 std::unique_lock<SyncObject>* lock_;
// Number of slots held in the fixed-size array head_; indices below this
// value are served from head_ in the Get fragment later in this file.
79 static constexpr std::size_t kInitSize = 16;
// Mutex locked by the Get, clearing and ForEach fragments later in this file.
81 std::mutex create_mutex_;
// Fixed-size storage for the first kInitSize slots.
83 std::array<std::shared_ptr<TElem>, kInitSize> head_;
// Growable storage used by Get when the index is not below kInitSize; it is
// extended with null entries on demand.
85 std::vector<std::shared_ptr<TElem> > more_;
// Flag stored as true at the start of the clearing fragment and false at its
// end; Get only creates new elements while it reads false.
87 std::atomic<bool> is_clearing_;
90 template <
typename TElem>
// Fragment of the out-of-class definition of the member template Get.
// The signature line, several body lines, the branch structure between the two
// locked sections, and all closing braces are not present in this view, so the
// comments below describe only what the visible statements do.
94 template <
typename TElem>
95 template <
typename FCreate>
// Convert the int index to size_t for comparison against container sizes.
98 size_t idx =
static_cast<size_t>(index);
// Indices below kInitSize are served from the fixed-size array head_.
99 if (idx < kInitSize) {
// NOTE(review): this first read of the slot precedes the lock_guard below,
// i.e. it appears to happen without holding create_mutex_ — confirm intent.
100 std::shared_ptr<TElem> ptr = head_[idx];
// Take create_mutex_ before touching the slot again.
104 std::lock_guard<std::mutex>
lock(create_mutex_);
// Only proceed while no clearing pass is flagged as in progress.
105 if (!is_clearing_.load()) {
// Re-read the slot now that the mutex is held.
106 std::shared_ptr<TElem> ptr = head_[idx];
// Populate the slot with a shared_ptr built from creator()'s result and keep
// a copy in ptr. (The condition guarding this assignment is not visible here.)
110 ptr = head_[idx] = std::shared_ptr<TElem>(creator());
// Second locked section, operating on the growable vector more_.
115 std::lock_guard<std::mutex>
lock(create_mutex_);
116 if (!is_clearing_.load()) {
// NOTE(review): original line 117 is not present in this view; any
// adjustment made to idx there is unknown from here.
// Grow more_ with null entries until idx is a valid position.
118 if (more_.size() <= idx) {
119 more_.reserve(idx + 1);
120 while (more_.size() <= idx) {
121 more_.push_back(std::shared_ptr<TElem>(
nullptr));
// Read the slot under the mutex.
124 std::shared_ptr<TElem> ptr = more_[idx];
// Populate the slot from creator() and keep a copy in ptr. (The condition
// guarding this assignment is not visible here.)
128 ptr = more_[idx] = std::shared_ptr<TElem>(creator());
// Fragment of an out-of-class member definition that empties both containers.
// The signature line and closing braces are not present in this view.
135 template <
typename TElem>
// Hold create_mutex_ through a unique_lock (rather than a lock_guard) so that
// its address can be handed to the unique_unlock helper inside the loops.
137 std::unique_lock<std::mutex>
lock(create_mutex_);
// Flag that clearing is in progress; Get checks this flag before creating.
138 is_clearing_.store(
true);
// Walk every slot of the fixed-size array.
142 for (
size_t i = 0; i < head_.size(); ++i) {
// Move ownership out of the container: copy the pointer to a local, then
// null the slot.
143 std::shared_ptr<TElem> p = head_[i];
144 head_[i] = std::shared_ptr<TElem>(
nullptr);
// Construct the unique_unlock helper over the held lock before dropping the
// local reference.
// NOTE(review): presumably this releases create_mutex_ while the element is
// released and re-locks afterwards — the helper's body is not visible; confirm.
145 unique_unlock<std::mutex> unlocker(&
lock);
// Drop the local reference held in p.
146 p = std::shared_ptr<TElem>(
nullptr);
// Same procedure for every slot of the growable vector.
148 for (
size_t i = 0; i < more_.size(); ++i) {
149 std::shared_ptr<TElem> p = more_[i];
150 more_[i] = std::shared_ptr<TElem>(
nullptr);
151 unique_unlock<std::mutex> unlocker(&
lock);
152 p = std::shared_ptr<TElem>(
nullptr);
// Clearing finished: reset the flag.
155 is_clearing_.store(
false);
// Fragment of the out-of-class definition of the member template ForEach.
// The signature line and closing braces are not present in this view.
158 template <
typename TElem>
159 template <
typename FVisit>
// Hold create_mutex_ for the whole traversal.
161 std::lock_guard<std::mutex>
lock(create_mutex_);
// Visit the populated slots of the fixed-size array; the visitor receives the
// slot index and the raw element pointer.
162 for (
size_t i = 0; i < head_.size(); ++i) {
163 if (head_[i].get() !=
nullptr) {
164 fvisit(i, head_[i].get());
// Visit the populated slots of the growable vector; the index passed to the
// visitor is offset by kInitSize.
167 for (
size_t i = 0; i < more_.size(); ++i) {
168 if (more_[i].get() !=
nullptr) {
169 fvisit(i + kInitSize, more_[i].get());
176 #endif // MXNET_COMMON_LAZY_ALLOC_ARRAY_H_