mxnet
shape.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_SHAPE_H_
27 #define MXNET_CPP_SHAPE_H_
28 
29 #include <istream>
30 #include <ostream>
31 #include <algorithm>
32 #include <vector>
33 #include "mxnet-cpp/base.h"
34 
35 namespace mxnet {
36 namespace cpp {
37 
42 struct Shape {
43  public:
45  Shape() : ndim_(0), num_heap_allocated_(0), data_heap_(nullptr) {}
50  explicit Shape(const std::vector<index_t>& v) : ndim_(v.size()) {
51  if (ndim_ <= kStackCache) {
52  data_heap_ = nullptr;
53  num_heap_allocated_ = 0;
54  std::copy(v.begin(), v.end(), data_stack_);
55  } else {
56  data_heap_ = new index_t[ndim_];
57  num_heap_allocated_ = ndim_;
58  std::copy(v.begin(), v.end(), data_heap_);
59  }
60  }
65  explicit Shape(index_t s1) : ndim_(1) {
66  if (ndim_ <= kStackCache) {
67  data_heap_ = nullptr;
68  num_heap_allocated_ = 0;
69  data_stack_[0] = s1;
70  } else {
71  data_heap_ = new index_t[ndim_];
72  num_heap_allocated_ = ndim_;
73  data_heap_[0] = s1;
74  }
75  }
81  Shape(index_t s1, index_t s2) : ndim_(2) {
82  if (ndim_ <= kStackCache) {
83  data_heap_ = nullptr;
84  num_heap_allocated_ = 0;
85  data_stack_[0] = s1;
86  data_stack_[1] = s2;
87  } else {
88  data_heap_ = new index_t[ndim_];
89  num_heap_allocated_ = ndim_;
90  data_heap_[0] = s1;
91  data_heap_[1] = s2;
92  }
93  }
100  Shape(index_t s1, index_t s2, index_t s3) : ndim_(3) {
101  if (ndim_ <= kStackCache) {
102  data_heap_ = nullptr;
103  num_heap_allocated_ = 0;
104  data_stack_[0] = s1;
105  data_stack_[1] = s2;
106  data_stack_[2] = s3;
107  } else {
108  data_heap_ = new index_t[ndim_];
109  num_heap_allocated_ = ndim_;
110  data_heap_[0] = s1;
111  data_heap_[1] = s2;
112  data_heap_[2] = s3;
113  }
114  }
122  Shape(index_t s1, index_t s2, index_t s3, index_t s4) : ndim_(4) {
123  if (ndim_ <= kStackCache) {
124  data_heap_ = nullptr;
125  num_heap_allocated_ = 0;
126  data_stack_[0] = s1;
127  data_stack_[1] = s2;
128  data_stack_[2] = s3;
129  data_stack_[3] = s4;
130  } else {
131  data_heap_ = new index_t[ndim_];
132  num_heap_allocated_ = ndim_;
133  data_heap_[0] = s1;
134  data_heap_[1] = s2;
135  data_heap_[2] = s3;
136  data_heap_[3] = s4;
137  }
138  }
147  Shape(index_t s1, index_t s2, index_t s3, index_t s4, index_t s5) : ndim_(5) {
148  if (ndim_ <= kStackCache) {
149  data_heap_ = nullptr;
150  num_heap_allocated_ = 0;
151  data_stack_[0] = s1;
152  data_stack_[1] = s2;
153  data_stack_[2] = s3;
154  data_stack_[3] = s4;
155  data_stack_[4] = s5;
156  } else {
157  data_heap_ = new index_t[ndim_];
158  num_heap_allocated_ = ndim_;
159  data_heap_[0] = s1;
160  data_heap_[1] = s2;
161  data_heap_[2] = s3;
162  data_heap_[3] = s4;
163  data_heap_[4] = s5;
164  }
165  }
170  Shape(const Shape& s) : ndim_(s.ndim_) {
171  if (ndim_ <= kStackCache) {
172  data_heap_ = nullptr;
173  num_heap_allocated_ = 0;
174  std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
175  } else {
176  data_heap_ = new index_t[ndim_];
177  num_heap_allocated_ = ndim_;
178  std::copy(s.data_heap_, s.data_heap_ + ndim_, data_heap_);
179  }
180  }
181 #if MSHADOW_IN_CXX11
182 
186  Shape(Shape&& s)
187  : ndim_(s.ndim_), num_heap_allocated_(s.num_heap_allocated_), data_heap_(s.data_heap_) {
188  if (ndim_ <= kStackCache) {
189  std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
190  }
191  // remove data heap space from s
192  s.data_heap_ = nullptr;
193  }
194 #endif
195 
196  ~Shape() {
197  // data_heap_ can be nullptr
198  delete[] data_heap_;
199  }
206  template <typename RandomAccessIterator>
207  inline void CopyFrom(RandomAccessIterator begin, RandomAccessIterator end) {
208  this->SetDim(end - begin);
209  std::copy(begin, end, data());
210  }
216  inline Shape& operator=(const Shape& shape) {
217  this->SetDim(shape.ndim_);
218  const index_t* src = shape.data();
219  std::copy(src, src + ndim_, data());
220  return *this;
221  }
227  inline Shape& operator=(const std::vector<index_t>& shape) {
228  this->CopyFrom(shape.begin(), shape.end());
229  return *this;
230  }
232  inline const index_t* data() const {
233  return ndim_ <= kStackCache ? data_stack_ : data_heap_;
234  }
236  inline index_t* data() {
237  return ndim_ <= kStackCache ? data_stack_ : data_heap_;
238  }
240  inline index_t ndim(void) const {
241  return ndim_;
242  }
249  return data()[i];
250  }
256  inline const index_t& operator[](index_t i) const {
257  return data()[i];
258  }
260  inline size_t Size(void) const {
261  size_t size = 1;
262  const index_t* d = this->data();
263  for (index_t i = 0; i < ndim_; ++i) {
264  size *= d[i];
265  }
266  return size;
267  }
272  inline bool operator==(const Shape& s) const {
273  if (ndim_ != s.ndim_)
274  return false;
275  if (ndim_ <= kStackCache) {
276  for (index_t i = 0; i < ndim_; ++i) {
277  if (data_stack_[i] != s.data_stack_[i])
278  return false;
279  }
280  } else {
281  for (index_t i = 0; i < ndim_; ++i) {
282  if (data_heap_[i] != s.data_heap_[i])
283  return false;
284  }
285  }
286  return true;
287  }
292  inline bool operator!=(const Shape& s) const {
293  return !(*this == s);
294  }
295 
296  friend std::ostream& operator<<(std::ostream& os, const Shape& shape);
297  friend std::istream& operator>>(std::istream& is, Shape& shape);
298 
299  private:
300  // the shape will be stored in data_stack_
301  // when dimension is smaller than kStackCache
302  // when it is bigger, it will be stored in data_heap_;
304  static const index_t kStackCache = 5;
306  index_t ndim_;
308  index_t num_heap_allocated_;
310  index_t data_stack_[kStackCache];
312  index_t* data_heap_;
317  inline void SetDim(index_t dim) {
318  if (dim > kStackCache && dim > num_heap_allocated_) {
319  // data_heap_ can be nullptr
320  delete[] data_heap_;
321  data_heap_ = new index_t[dim];
322  num_heap_allocated_ = dim;
323  }
324  ndim_ = dim;
325  }
326 };
327 
334 inline std::ostream& operator<<(std::ostream& os, const Shape& shape) {
335  os << '(';
336  for (index_t i = 0; i < shape.ndim(); ++i) {
337  if (i != 0)
338  os << ',';
339  os << static_cast<int>(shape[i]); // Supports negative Shape 'special codes' for inferring
340  }
341  // python style tuple
342  if (shape.ndim() == 1)
343  os << ',';
344  os << ')';
345  return os;
346 }
347 
354 inline std::istream& operator>>(std::istream& is, Shape& shape) {
355  // get (
356  while (true) {
357  char ch = is.get();
358  if (ch == '(')
359  break;
360  if (!isspace(ch)) {
361  is.setstate(std::ios::failbit);
362  return is;
363  }
364  }
365  index_t idx;
366  std::vector<index_t> tmp;
367  while (is >> idx) {
368  tmp.push_back(idx);
369  char ch;
370  do {
371  ch = is.get();
372  } while (isspace(ch));
373  if (ch == ',') {
374  while (true) {
375  ch = is.peek();
376  if (isspace(ch)) {
377  is.get();
378  continue;
379  }
380  if (ch == ')') {
381  is.get();
382  break;
383  }
384  break;
385  }
386  if (ch == ')')
387  break;
388  } else if (ch == ')') {
389  break;
390  } else {
391  is.setstate(std::ios::failbit);
392  return is;
393  }
394  }
395  shape.CopyFrom(tmp.begin(), tmp.end());
396  return is;
397 }
398 
399 } // namespace cpp
400 } // namespace mxnet
401 
402 #endif // MXNET_CPP_SHAPE_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::cpp::Shape::Shape
Shape(index_t s1, index_t s2, index_t s3)
constructor of a three-dimensional shape
Definition: shape.h:100
mxnet::cpp::Shape::CopyFrom
void CopyFrom(RandomAccessIterator begin, RandomAccessIterator end)
copy shape from the content between two iterators
Definition: shape.h:207
dmlc::isspace
bool isspace(char c)
Inline implementation of isspace(). Tests whether the given character is a whitespace letter.
Definition: strtonum.h:26
mxnet::cpp::operator>>
std::istream & operator>>(std::istream &is, Shape &shape)
read shape from the istream
Definition: shape.h:354
mxnet::cpp::Shape::Shape
Shape(index_t s1, index_t s2, index_t s3, index_t s4)
constructor of a four-dimensional shape
Definition: shape.h:122
mxnet::cpp::Shape::operator=
Shape & operator=(const std::vector< index_t > &shape)
assignment from vector
Definition: shape.h:227
mxnet::cpp::Shape::operator=
Shape & operator=(const Shape &shape)
assignment from shape
Definition: shape.h:216
mxnet::cpp::operator<<
std::ostream & operator<<(std::ostream &out, const NDArray &ndarray)
mxnet::cpp::Shape::Shape
Shape(index_t s1)
constructor of a one-dimensional shape
Definition: shape.h:65
mxnet::cpp::Shape::operator!=
bool operator!=(const Shape &s) const
Definition: shape.h:292
mxnet::cpp::Shape::Shape
Shape(const std::vector< index_t > &v)
constructor from a vector of index_t
Definition: shape.h:50
mxnet::cpp::Shape::Shape
Shape(index_t s1, index_t s2)
constructor of a two-dimensional shape
Definition: shape.h:81
mxnet::cpp::index_t
unsigned index_t
Definition: base.h:36
mxnet::cpp::Shape::operator==
bool operator==(const Shape &s) const
Definition: shape.h:272
mxnet::cpp::Shape::ndim
index_t ndim(void) const
return number of dimension of the tensor inside
Definition: shape.h:240
mxnet::cpp::Shape::Shape
Shape(const Shape &s)
constructor from Shape
Definition: shape.h:170
mxnet::cpp::Shape::Shape
Shape(index_t s1, index_t s2, index_t s3, index_t s4, index_t s5)
constructor of a five-dimensional shape
Definition: shape.h:147
mxnet::cpp::Shape::operator<<
friend std::ostream & operator<<(std::ostream &os, const Shape &shape)
allow string printing of the shape
Definition: shape.h:334
mxnet::cpp::Shape
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:42
mxnet::cpp::Shape::Shape
Shape()
constructor
Definition: shape.h:45
mxnet::cpp::Shape::operator[]
index_t & operator[](index_t i)
get corresponding index
Definition: shape.h:248
mxnet::cpp::Shape::operator[]
const index_t & operator[](index_t i) const
get corresponding index
Definition: shape.h:256
base.h
base definitions for mxnetcpp
mxnet::cpp::Shape::Size
size_t Size(void) const
total number of elements in the tensor
Definition: shape.h:260
mxnet::cpp::Shape::data
index_t * data()
Definition: shape.h:236
mxnet::cpp::Shape::operator>>
friend std::istream & operator>>(std::istream &is, Shape &shape)
read shape from the istream
Definition: shape.h:354
mxnet::cpp::Shape::~Shape
~Shape()
destructor
Definition: shape.h:196
mxnet::cpp::Shape::data
const index_t * data() const
Definition: shape.h:232