mxnet
container.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
23 // Acknowledgement: This file originates from incubator-tvm
24 #ifndef MXNET_NODE_CONTAINER_H_
25 #define MXNET_NODE_CONTAINER_H_
26 
27 #include <mxnet/node/node.h>
28 
29 #include <type_traits>
30 #include <vector>
31 #include <initializer_list>
32 #include <unordered_map>
33 #include <utility>
34 #include <string>
35 
36 namespace mxnet {
37 
39 class ArrayNode : public Object {
40  public:
42  std::vector<ObjectRef> data;
43 
44  static constexpr const char* _type_key = "Array";
46 };
47 
/*!
 * \brief Iterator adapter that wraps an iterator of type TIter and converts
 *  every dereferenced element through Converter::convert.
 *
 * \tparam Converter Provides `ResultType` and a static `convert(...)` that
 *  maps `*TIter` to `ResultType`.
 * \tparam TIter The wrapped iterator type.
 *
 * The adapter forwards TIter's iterator_category. So that standard algorithms
 * relying on that category (std::advance / std::next / std::prev /
 * std::distance) work, the bidirectional and random-access operations
 * (--, +=, -=, -) are provided too. They are ordinary members of a class
 * template, so they are only instantiated when used and are harmless for
 * weaker TIter categories.
 */
template <typename Converter, typename TIter>
class IterAdapter {
 public:
  using difference_type = typename std::iterator_traits<TIter>::difference_type;
  using value_type = typename Converter::ResultType;
  using pointer = typename Converter::ResultType*;
  using reference = typename Converter::ResultType&;  // NOLINT(*)
  using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

  explicit IterAdapter(TIter iter) : iter_(iter) {}

  /*! \brief Pre-increment. */
  inline IterAdapter& operator++() {
    ++iter_;
    return *this;
  }
  /*! \brief Post-increment: returns the position before advancing. */
  inline IterAdapter operator++(int) {
    IterAdapter before(*this);
    ++iter_;
    return before;
  }
  /*! \brief Pre-decrement (requires a bidirectional TIter). */
  inline IterAdapter& operator--() {
    --iter_;
    return *this;
  }
  /*! \brief Post-decrement (requires a bidirectional TIter). */
  inline IterAdapter operator--(int) {
    IterAdapter before(*this);
    --iter_;
    return before;
  }
  /*! \brief Advance by offset in place (requires a random-access TIter). */
  inline IterAdapter& operator+=(difference_type offset) {
    iter_ += offset;
    return *this;
  }
  /*! \brief Move back by offset in place (requires a random-access TIter). */
  inline IterAdapter& operator-=(difference_type offset) {
    iter_ -= offset;
    return *this;
  }
  inline IterAdapter operator+(difference_type offset) const {
    return IterAdapter(iter_ + offset);
  }
  inline IterAdapter operator-(difference_type offset) const {
    return IterAdapter(iter_ - offset);
  }

  /*! \brief Distance between two adapters; only enabled for random-access TIter. */
  template <typename T = IterAdapter>
  typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
                          typename T::difference_type>::type inline
  operator-(const IterAdapter& rhs) const {
    return iter_ - rhs.iter_;
  }

  inline bool operator==(IterAdapter other) const {
    return iter_ == other.iter_;
  }
  inline bool operator!=(IterAdapter other) const {
    return !(*this == other);
  }

  /*!
   * \brief Convert the current element. Returned by non-const value so the
   *  result can be moved from by the caller.
   */
  inline value_type operator*() const {
    return Converter::convert(*iter_);
  }

 private:
  TIter iter_;
};
91 
100 template <typename T,
101  typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
102 class Array : public ObjectRef {
103  public:
107  Array() {
108  data_ = make_object<ArrayNode>();
109  }
114  Array(Array<T>&& other) { // NOLINT(*)
115  data_ = std::move(other.data_);
116  }
121  Array(const Array<T>& other) { // NOLINT(*)
122  data_ = std::move(other.data_);
123  }
135  template <typename IterType>
136  Array(IterType begin, IterType end) {
137  assign(begin, end);
138  }
143  Array(std::initializer_list<T> init) { // NOLINT(*)
144  assign(init.begin(), init.end());
145  }
150  Array(const std::vector<T>& init) { // NOLINT(*)
151  assign(init.begin(), init.end());
152  }
158  explicit Array(size_t n, const T& val) {
159  auto tmp_node = make_object<ArrayNode>();
160  for (size_t i = 0; i < n; ++i) {
161  tmp_node->data.push_back(val);
162  }
163  data_ = std::move(tmp_node);
164  }
171  data_ = std::move(other.data_);
172  return *this;
173  }
179  Array<T>& operator=(const Array<T>& other) {
180  data_ = other.data_;
181  return *this;
182  }
189  template <typename IterType>
190  void assign(IterType begin, IterType end) {
191  auto n = make_object<ArrayNode>();
192  for (IterType it = begin; it != end; ++it) {
193  n->data.push_back(T(*it));
194  }
195  data_ = std::move(n);
196  }
202  inline const T operator[](size_t i) const {
203  return DowncastNoCheck<T>(static_cast<const ArrayNode*>(data_.get())->data[i]);
204  }
206  inline size_t size() const {
207  if (data_.get() == nullptr)
208  return 0;
209  return static_cast<const ArrayNode*>(data_.get())->data.size();
210  }
219  inline ArrayNode* CopyOnWrite() {
220  if (data_.get() == nullptr || !data_.unique()) {
221  runtime::ObjectPtr<ArrayNode> n = make_object<ArrayNode>();
222  n->data = static_cast<ArrayNode*>(data_.get())->data;
223  runtime::ObjectPtr<Object>(std::move(n)).swap(data_);
224  }
225  return static_cast<ArrayNode*>(data_.get());
226  }
231  inline void push_back(const T& item) {
232  ArrayNode* n = this->CopyOnWrite();
233  n->data.push_back(item);
234  }
239  inline void resize(size_t size) {
240  ArrayNode* n = this->CopyOnWrite();
241  n->data.resize(size);
242  }
248  inline void Set(size_t i, const T& value) {
249  ArrayNode* n = this->CopyOnWrite();
250  n->data[i] = value;
251  }
253  inline bool empty() const {
254  return size() == 0;
255  }
262  template <typename F>
263  inline void MutateByApply(F fmutate) {
264  ArrayNode* ptr = static_cast<ArrayNode*>(data_.get());
265  if (ptr == nullptr)
266  return;
267  if (data_.unique()) {
268  // Copy on write optimization.
269  // Perform inplace update because this is an unique copy.
270  for (size_t i = 0; i < ptr->data.size(); ++i) {
271  // It is important to use move here
272  // to make prevent the element's ref count from increasing
273  // so fmutate itself can perform copy-on-write optimization
274  T old_elem = DowncastNoCheck<T>(std::move(ptr->data[i]));
275  T new_elem = fmutate(std::move(old_elem));
276  ptr->data[i] = std::move(new_elem);
277  }
278  } else {
279  // lazily trigger copy if there is element change.
281  for (size_t i = 0; i < ptr->data.size(); ++i) {
282  T old_elem = DowncastNoCheck<T>(ptr->data[i]);
283  T new_elem = fmutate(old_elem);
284  if (!new_elem.same_as(ptr->data[i])) {
285  // copy the old array
286  if (copy == nullptr) {
287  copy = runtime::make_object<ArrayNode>(*ptr);
288  }
289  copy->data[i] = std::move(new_elem);
290  }
291  }
292  // replace the data with the new copy.
293  if (copy != nullptr) {
294  data_ = std::move(copy);
295  }
296  }
297  }
298 
301 
302  struct ValueConverter {
303  using ResultType = T;
304  static inline T convert(const ObjectRef& n) {
305  return DowncastNoCheck<T>(n);
306  }
307  };
309 
310  using reverse_iterator =
312 
314  inline iterator begin() const {
315  return iterator(static_cast<const ArrayNode*>(data_.get())->data.begin());
316  }
318  inline iterator end() const {
319  return iterator(static_cast<const ArrayNode*>(data_.get())->data.end());
320  }
322  inline reverse_iterator rbegin() const {
323  return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rbegin());
324  }
326  inline reverse_iterator rend() const {
327  return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rend());
328  }
329 };
330 
331 } // namespace mxnet
332 #endif // MXNET_NODE_CONTAINER_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::IterAdapter::operator!=
bool operator!=(IterAdapter other) const
Definition: container.h:81
mxnet::runtime::Object
base class of all object containers.
Definition: object.h:151
mxnet::Array
Array container of ObjectRef in DSL graph. Array implements copy-on-write semantics,...
Definition: container.h:102
mxnet::Array::ValueConverter::ResultType
T ResultType
Definition: container.h:303
mxnet::Array::Array
Array(Array< T > &&other)
move constructor
Definition: container.h:114
mxnet::Array::operator[]
const T operator[](size_t i) const
Read i-th element from array.
Definition: container.h:202
mxnet::runtime::ObjectPtr
A custom smart pointer for Object.
Definition: object.h:346
mxnet::IterAdapter::difference_type
typename std::iterator_traits< TIter >::difference_type difference_type
Definition: container.h:56
mxnet::Array::operator=
Array< T > & operator=(Array< T > &&other)
move assign operator
Definition: container.h:170
mxnet::IterAdapter::value_type
typename Converter::ResultType value_type
Definition: container.h:57
mxnet::Array::reverse_iterator
IterAdapter< ValueConverter, std::vector< ObjectRef >::const_reverse_iterator > reverse_iterator
Definition: container.h:311
mxnet::Array::Array
Array(runtime::ObjectPtr< Object > n)
constructor from pointer
Definition: container.h:128
mxnet::ArrayNode::data
std::vector< ObjectRef > data
the data content
Definition: container.h:42
mxnet::IterAdapter
iterator adapter that adapts TIter to return another type.
Definition: container.h:54
MXNET_DECLARE_FINAL_OBJECT_INFO
#define MXNET_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:668
mshadow::expr::F
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:71
mxnet::Array::ValueConverter
Definition: container.h:302
mxnet::Array::ValueConverter::convert
static T convert(const ObjectRef &n)
Definition: container.h:304
mxnet::Array::CopyOnWrite
ArrayNode * CopyOnWrite()
Copy-on-write semantics. Do nothing if the current handle is the unique copy of the array....
Definition: container.h:219
mxnet::IterAdapter::operator-
std::enable_if< std::is_same< iterator_category, std::random_access_iterator_tag >::value, typename T::difference_type >::type operator-(const IterAdapter &rhs) const
Definition: container.h:74
mxnet::Array::iterator
IterAdapter< ValueConverter, std::vector< ObjectRef >::const_iterator > iterator
Definition: container.h:308
mxnet::IterAdapter::operator*
const value_type operator*() const
Definition: container.h:84
mxnet::Array::Array
Array(const std::vector< T > &init)
constructor from vector
Definition: container.h:150
mxnet::Array::end
iterator end() const
Definition: container.h:318
mxnet::Array::Array
Array()
default constructor
Definition: container.h:107
mxnet::IterAdapter::operator+
IterAdapter operator+(difference_type offset) const
Definition: container.h:67
mxnet::IterAdapter::iterator_category
typename std::iterator_traits< TIter >::iterator_category iterator_category
Definition: container.h:60
mxnet::ArrayNode
array node content in array
Definition: container.h:39
mxnet::IterAdapter::pointer
typename Converter::ResultType * pointer
Definition: container.h:58
mxnet::Array::resize
void resize(size_t size)
Resize the array.
Definition: container.h:239
mxnet::Array::Array
Array(size_t n, const T &val)
Constructs a container with n elements. Each element is a copy of val.
Definition: container.h:158
mxnet::Array::assign
void assign(IterType begin, IterType end)
Reset the array to the contents of the iterator range.
Definition: container.h:190
mxnet::runtime::ObjectRef::data_
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:575
mxnet::IterAdapter::reference
typename Converter::ResultType & reference
Definition: container.h:59
mxnet::Array::MutateByApply
void MutateByApply(F fmutate)
Helper function to apply fmutate to mutate an array.
Definition: container.h:263
mxnet::Array::size
size_t size() const
Definition: container.h:206
mxnet::ArrayNode::_type_key
static constexpr const char * _type_key
Definition: container.h:44
mxnet::Array::Array
Array(std::initializer_list< T > init)
constructor from initializer list
Definition: container.h:143
mxnet::IterAdapter::operator==
bool operator==(IterAdapter other) const
Definition: container.h:78
mxnet::Array::empty
bool empty() const
Definition: container.h:253
mxnet::runtime::ObjectRef
Base class of all object reference.
Definition: object.h:500
mxnet::Array::rend
reverse_iterator rend() const
Definition: container.h:326
mxnet::Array::begin
iterator begin() const
Definition: container.h:314
mxnet::IterAdapter::IterAdapter
IterAdapter(TIter iter)
Definition: container.h:62
mxnet::Array::push_back
void push_back(const T &item)
Push a new item to the back of the array.
Definition: container.h:231
mxnet::Array::rbegin
reverse_iterator rbegin() const
Definition: container.h:322
mxnet::Array::Array
Array(const Array< T > &other)
copy constructor
Definition: container.h:121
mxnet::IterAdapter::operator++
IterAdapter & operator++()
Definition: container.h:63
mxnet::runtime::ObjectPtr::swap
void swap(ObjectPtr< T > &other)
Swap this pointer with another ObjectPtr.
Definition: object.h:395
mxnet::Array::operator=
Array< T > & operator=(const Array< T > &other)
copy assign operator
Definition: container.h:179
mxnet::Array::Array
Array(IterType begin, IterType end)
constructor from iterator
Definition: container.h:136
mxnet::Array::Set
void Set(size_t i, const T &value)
set i-th element of the array.
Definition: container.h:248
node.h
Definitions and helper macros for IR/AST nodes.