mxnet
container.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: This file originates from incubator-tvm
24 #ifndef MXNET_NODE_CONTAINER_H_
25 #define MXNET_NODE_CONTAINER_H_
26 
#include <mxnet/node/node.h>

#include <initializer_list>
#include <iterator>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
35 
36 namespace mxnet {
37 
39 class ArrayNode : public Object {
40  public:
42  std::vector<ObjectRef> data;
43 
44  static constexpr const char* _type_key = "Array";
46 };
47 
/*!
 * \brief iterator adapter that adapts TIter to return another type.
 *
 *  Every dereference yields Converter::convert(*iter), returned by value.
 *  The adapter advertises TIter's iterator_category, so it also forwards
 *  the increment/decrement operators that category promises; members that
 *  TIter cannot support (e.g. -- on a forward-only iterator) are simply
 *  never instantiated unless used.
 *
 * \tparam Converter provides `ResultType` and a static `convert` that maps
 *         an element of TIter to ResultType.
 * \tparam TIter the underlying iterator type.
 * \note `reference` is declared as ResultType&, but operator* returns a
 *       ResultType by value.
 */
template<typename Converter,
         typename TIter>
class IterAdapter {
 public:
  using difference_type = typename std::iterator_traits<TIter>::difference_type;
  using value_type = typename Converter::ResultType;
  using pointer = typename Converter::ResultType*;
  using reference = typename Converter::ResultType&;  // NOLINT(*)
  using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

  explicit IterAdapter(TIter iter) : iter_(iter) {}
  /*! \brief pre-increment: advance the underlying iterator. */
  inline IterAdapter& operator++() {
    ++iter_;
    return *this;
  }
  /*! \brief post-increment: advance, returning the previous position. */
  inline IterAdapter operator++(int) {
    IterAdapter previous(*this);
    ++iter_;
    return previous;
  }
  /*! \brief pre-decrement; requires TIter to be at least bidirectional. */
  inline IterAdapter& operator--() {
    --iter_;
    return *this;
  }
  /*! \brief post-decrement; requires TIter to be at least bidirectional. */
  inline IterAdapter operator--(int) {
    IterAdapter previous(*this);
    --iter_;
    return previous;
  }
  /*! \brief return an adapter advanced by \p offset positions. */
  inline IterAdapter operator+(difference_type offset) const {
    return IterAdapter(iter_ + offset);
  }

  /*!
   * \brief distance between two adapters; only available when the
   *        underlying iterator is random access (SFINAE on the return type).
   */
  template<typename T = IterAdapter>
  typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
                          typename T::difference_type>::type
  inline operator-(const IterAdapter& rhs) const {
    return iter_ - rhs.iter_;
  }

  inline bool operator==(IterAdapter other) const {
    return iter_ == other.iter_;
  }
  inline bool operator!=(IterAdapter other) const {
    return !(*this == other);
  }
  /*!
   * \brief convert the current element.
   *  Returned as a non-const value so callers can move from the result.
   */
  inline value_type operator*() const {
    return Converter::convert(*iter_);
  }

 private:
  /*! \brief the wrapped iterator */
  TIter iter_;
};
92 
101 template<typename T,
102  typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type >
103 class Array : public ObjectRef {
104  public:
108  Array() {
109  data_ = make_object<ArrayNode>();
110  }
115  Array(Array<T> && other) { // NOLINT(*)
116  data_ = std::move(other.data_);
117  }
122  Array(const Array<T> &other) { // NOLINT(*)
123  data_ = std::move(other.data_);
124  }
129  explicit Array(runtime::ObjectPtr<Object> n) : ObjectRef(n) {}
136  template<typename IterType>
137  Array(IterType begin, IterType end) {
138  assign(begin, end);
139  }
144  Array(std::initializer_list<T> init) { // NOLINT(*)
145  assign(init.begin(), init.end());
146  }
151  Array(const std::vector<T>& init) { // NOLINT(*)
152  assign(init.begin(), init.end());
153  }
159  explicit Array(size_t n, const T& val) {
160  auto tmp_node = make_object<ArrayNode>();
161  for (size_t i = 0; i < n; ++i) {
162  tmp_node->data.push_back(val);
163  }
164  data_ = std::move(tmp_node);
165  }
172  data_ = std::move(other.data_);
173  return *this;
174  }
180  Array<T>& operator=(const Array<T> & other) {
181  data_ = other.data_;
182  return *this;
183  }
190  template<typename IterType>
191  void assign(IterType begin, IterType end) {
192  auto n = make_object<ArrayNode>();
193  for (IterType it = begin; it != end; ++it) {
194  n->data.push_back(T(*it));
195  }
196  data_ = std::move(n);
197  }
203  inline const T operator[](size_t i) const {
204  return DowncastNoCheck<T>(
205  static_cast<const ArrayNode*>(data_.get())->data[i]);
206  }
208  inline size_t size() const {
209  if (data_.get() == nullptr) return 0;
210  return static_cast<const ArrayNode*>(data_.get())->data.size();
211  }
220  inline ArrayNode* CopyOnWrite() {
221  if (data_.get() == nullptr || !data_.unique()) {
222  runtime::ObjectPtr<ArrayNode> n = make_object<ArrayNode>();
223  n->data = static_cast<ArrayNode*>(data_.get())->data;
224  runtime::ObjectPtr<Object>(std::move(n)).swap(data_);
225  }
226  return static_cast<ArrayNode*>(data_.get());
227  }
232  inline void push_back(const T& item) {
233  ArrayNode* n = this->CopyOnWrite();
234  n->data.push_back(item);
235  }
240  inline void resize(size_t size) {
241  ArrayNode* n = this->CopyOnWrite();
242  n->data.resize(size);
243  }
249  inline void Set(size_t i, const T& value) {
250  ArrayNode* n = this->CopyOnWrite();
251  n->data[i] = value;
252  }
254  inline bool empty() const {
255  return size() == 0;
256  }
263  template<typename F>
264  inline void MutateByApply(F fmutate) {
265  ArrayNode* ptr = static_cast<ArrayNode*>(data_.get());
266  if (ptr == nullptr) return;
267  if (data_.unique()) {
268  // Copy on write optimization.
269  // Perform inplace update because this is an unique copy.
270  for (size_t i = 0; i < ptr->data.size(); ++i) {
271  // It is important to use move here
272  // to make prevent the element's ref count from increasing
273  // so fmutate itself can perform copy-on-write optimization
274  T old_elem = DowncastNoCheck<T>(std::move(ptr->data[i]));
275  T new_elem = fmutate(std::move(old_elem));
276  ptr->data[i] = std::move(new_elem);
277  }
278  } else {
279  // lazily trigger copy if there is element change.
281  for (size_t i = 0; i < ptr->data.size(); ++i) {
282  T old_elem = DowncastNoCheck<T>(ptr->data[i]);
283  T new_elem = fmutate(old_elem);
284  if (!new_elem.same_as(ptr->data[i])) {
285  // copy the old array
286  if (copy == nullptr) {
287  copy = runtime::make_object<ArrayNode>(*ptr);
288  }
289  copy->data[i] = std::move(new_elem);
290  }
291  }
292  // replace the data with the new copy.
293  if (copy != nullptr) {
294  data_ = std::move(copy);
295  }
296  }
297  }
298 
301 
302  struct ValueConverter {
303  using ResultType = T;
304  static inline T convert(const ObjectRef& n) {
305  return DowncastNoCheck<T>(n);
306  }
307  };
309  std::vector<ObjectRef>::const_iterator>;
310 
312  ValueConverter,
313  std::vector<ObjectRef>::const_reverse_iterator>;
314 
316  inline iterator begin() const {
317  return iterator(static_cast<const ArrayNode*>(data_.get())->data.begin());
318  }
320  inline iterator end() const {
321  return iterator(static_cast<const ArrayNode*>(data_.get())->data.end());
322  }
324  inline reverse_iterator rbegin() const {
325  return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rbegin());
326  }
328  inline reverse_iterator rend() const {
329  return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rend());
330  }
331 };
332 
333 } // namespace mxnet
334 #endif // MXNET_NODE_CONTAINER_H_
bool operator==(IterAdapter other) const
Definition: container.h:79
IterAdapter(TIter iter)
Definition: container.h:63
Array< T > & operator=(const Array< T > &other)
copy assign operator
Definition: container.h:180
ArrayNode * CopyOnWrite()
copy on write semantics: do nothing if the current handle is the unique copy of the array. Otherwise make a new copy of the array to ensure the current handle holds a unique copy.
Definition: container.h:220
Array(std::initializer_list< T > init)
constructor from initializer list
Definition: container.h:144
typename Converter::ResultType & reference
Definition: container.h:60
bool operator!=(IterAdapter other) const
Definition: container.h:82
std::vector< ObjectRef > data
the data content
Definition: container.h:42
namespace of mxnet
Definition: api_registry.h:33
iterator end() const
Definition: container.h:320
typename Converter::ResultType * pointer
Definition: container.h:59
static T convert(const ObjectRef &n)
Definition: container.h:304
MXNET_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object)
reverse_iterator rend() const
Definition: container.h:328
A custom smart pointer for Object.
Definition: object.h:345
bool empty() const
Definition: container.h:254
IterAdapter & operator++()
Definition: container.h:64
Definition: container.h:302
reverse_iterator rbegin() const
Definition: container.h:324
void Set(size_t i, const T &value)
set i-th element of the array.
Definition: container.h:249
typename Converter::ResultType value_type
Definition: container.h:58
Array(size_t n, const T &val)
Constructs a container with n elements. Each element is a copy of val.
Definition: container.h:159
IterAdapter operator+(difference_type offset) const
Definition: container.h:68
std::enable_if< std::is_same< iterator_category, std::random_access_iterator_tag >::value, typename T::difference_type >::type operator-(const IterAdapter &rhs) const
Definition: container.h:75
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:72
Array< T > & operator=(Array< T > &&other)
move assign operator
Definition: container.h:171
const value_type operator*() const
Definition: container.h:85
void resize(size_t size)
Resize the array.
Definition: container.h:240
void push_back(const T &item)
push a new item to the back of the list
Definition: container.h:232
array node content in array
Definition: container.h:39
Array(const Array< T > &other)
copy constructor
Definition: container.h:122
typename std::iterator_traits< TIter >::difference_type difference_type
Definition: container.h:57
T ResultType
Definition: container.h:303
void MutateByApply(F fmutate)
Helper function to apply fmutate to mutate an array.
Definition: container.h:264
iterator begin() const
Definition: container.h:316
Array container of ObjectRef in DSL graph. Array implements copy-on-write semantics, which means the array is mutable, but a copy will happen when the array is referenced in more than one place.
Definition: container.h:103
typename std::iterator_traits< TIter >::iterator_category iterator_category
Definition: container.h:61
const T operator[](size_t i) const
Read i-th element from array.
Definition: container.h:203
Array(Array< T > &&other)
move constructor
Definition: container.h:115
Array()
default constructor
Definition: container.h:108
size_t size() const
Definition: container.h:208
static constexpr const char * _type_key
Definition: container.h:44
iterator adapter that adapts TIter to return another type.
Definition: container.h:55
Array(const std::vector< T > &init)
constructor from vector
Definition: container.h:151
Array(runtime::ObjectPtr< Object > n)
constructor from pointer
Definition: container.h:129
Array(IterType begin, IterType end)
constructor from iterator
Definition: container.h:137
void assign(IterType begin, IterType end)
reset the array to content from iterator.
Definition: container.h:191