mxnet
tuple.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef NNVM_TUPLE_H_
25 #define NNVM_TUPLE_H_
26 
27 #include <vector>
28 #include <type_traits>
29 #include <algorithm>
30 #include <utility>
31 #include <iostream>
32 #include <string>
33 #include "base.h"
34 
35 namespace nnvm {
36 
/*! \brief data type used to store a single dimension size (signed 64-bit) */
using dim_t = int64_t;
39 
50 template<typename ValueType>
51 class Tuple {
52  public:
54  Tuple() = default;
56  inline ~Tuple() {
57  delete [] data_heap_;
58  }
63  inline Tuple(const Tuple<ValueType>& s) {
64  this->assign(s.begin(), s.end());
65  }
70  inline Tuple(std::initializer_list<ValueType> init) {
71  this->assign(init.begin(), init.end());
72  }
77  inline Tuple(std::vector<ValueType> init) { // NOLINT(runtime/explicit)
78  this->assign(init.begin(), init.end());
79  }
85  inline Tuple(Tuple<ValueType>&& src) { // NOLINT(runtime/explicit)
86  this->swap(src);
87  }
94  template<typename RandomAccessIterator>
95  inline Tuple(RandomAccessIterator begin,
96  RandomAccessIterator end) {
97  this->assign(begin, end);
98  }
105  template<typename RandomAccessIterator>
106  inline void assign(RandomAccessIterator begin,
107  RandomAccessIterator end) {
108  this->SetDim(end - begin);
109  std::copy(begin, end, this->begin());
110  }
115  inline void swap(Tuple<ValueType>& other) { // NOLINT(*)
116  std::swap(ndim_, other.ndim_);
117  std::swap(num_heap_allocated_, other.num_heap_allocated_);
118  std::swap(data_stack_, other.data_stack_);
119  std::swap(data_heap_, other.data_heap_);
120  }
127  this->assign(src.begin(), src.end());
128  return *this;
129  }
136  Tuple<ValueType>(std::move(src)).swap(*this);
137  return *this;
138  }
144  inline Tuple<ValueType> &operator=(std::initializer_list<ValueType> init) {
145  this->assign(init.begin(), init.end());
146  return *this;
147  }
152  inline bool operator==(const Tuple<ValueType> &s) const {
153  if (ndim_ != s.ndim_) return false;
154  return std::equal(begin(), end(), s.begin());
155  }
160  inline bool operator!=(const Tuple<ValueType> &s) const {
161  return !(*this == s);
162  }
164  inline const ValueType *begin() const {
165  return ndim_ <= kStackCache ? data_stack_ : data_heap_;
166  }
168  inline ValueType *begin() {
169  return ndim_ <= kStackCache ? data_stack_ : data_heap_;
170  }
172  inline const ValueType* end() const {
173  return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
174  }
176  inline ValueType* end() {
177  return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
178  }
180  inline uint32_t ndim() const {
181  return ndim_;
182  }
188  inline ValueType& operator[](size_t i) {
189  return begin()[i];
190  }
196  inline const ValueType& operator[](size_t i) const {
197  return begin()[i];
198  }
203  inline void Save(dmlc::JSONWriter* writer) const {
204  std::vector<ValueType> tmp(begin(), end());
205  writer->Write(tmp);
206  }
211  inline void Load(dmlc::JSONReader* reader) {
212  std::vector<ValueType> tmp;
213  reader->Read(&tmp);
214  this->assign(tmp.begin(), tmp.end());
215  }
222  friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
223  os << '[';
224  const ValueType* begin = t.begin();
225  const ValueType* end = t.end();
226  for (const ValueType* it = begin; it != end; ++it) {
227  if (it != begin) os << ',';
228  os << *it;
229  }
230  os << ']';
231  return os;
232  }
239  friend std::istream &operator>>(std::istream &is, Tuple<ValueType> &t) {
240  // get (
241  while (true) {
242  char ch = is.peek();
243  if (isdigit(ch) || ch == '-') {
244  ValueType idx;
245  if (is >> idx) {
246  t.assign(&idx, &idx + 1);
247  }
248  return is;
249  }
250  is.get();
251  if (ch == '(' || ch == '[') break;
252  if (!isspace(ch)) {
253  is.setstate(std::ios::failbit);
254  return is;
255  }
256  }
257  // Handle empty tuple
258  while (isspace(is.peek())) {
259  is.get();
260  }
261  if (is.peek() == ')' || is.peek() == ']') {
262  is.get();
263  return is;
264  }
265  // Handle non-empty tuple
266  ValueType idx;
267  std::vector<ValueType> tmp;
268  while (is >> idx) {
269  tmp.push_back(idx);
270  char ch;
271  do {
272  ch = is.get();
273  } while (isspace(ch));
274  if (std::is_integral<ValueType>::value && ch == 'L') {
275  ch = is.get();
276  }
277  if (ch == ',') {
278  while (true) {
279  ch = is.peek();
280  if (isspace(ch)) {
281  is.get(); continue;
282  }
283  if (ch == ')' || ch == ']') {
284  is.get(); break;
285  }
286  break;
287  }
288  if (ch == ')' || ch == ']') break;
289  } else if (ch == ')' || ch == ']') {
290  break;
291  } else {
292  is.setstate(std::ios::failbit);
293  return is;
294  }
295  }
296  t.assign(tmp.begin(), tmp.end());
297  return is;
298  }
305  template<typename DType = ValueType, typename TStream>
306  inline void Save(TStream *strm) const;
314  template<typename DType = ValueType, typename TStream>
315  inline bool Load(TStream *strm);
316 
317  protected:
318  // stack cache size
319  static const uint32_t kStackCache = 4;
321  uint32_t ndim_{0};
323  uint32_t num_heap_allocated_{0};
327  ValueType* data_heap_{nullptr};
328  // internal function to change the dimension
329  inline void SetDim(uint32_t ndim) {
330  if (ndim > kStackCache &&
331  ndim > num_heap_allocated_) {
332  delete [] data_heap_;
333  data_heap_ = new ValueType[ndim];
335  }
336  ndim_ = ndim;
337  }
338 };
339 
343 class TShape : public Tuple<dim_t> {
344  public:
346  TShape() = default;
351  inline TShape(uint32_t ndim) { // NOLINT(*)
352  this->SetDim(ndim);
353  std::fill_n(begin(), ndim, 1);
354  }
359  inline TShape(const Tuple<dim_t>& s) { // NOLINT(*)
360  this->assign(s.begin(), s.end());
361  }
366  inline TShape(std::initializer_list<dim_t> init) {
367  this->assign(init.begin(), init.end());
368  }
373  inline TShape(Tuple<dim_t>&& s) { // NOLINT(*)
374  this->swap(s);
375  }
382  template<typename RandomAccessIterator>
383  inline TShape(RandomAccessIterator begin,
384  RandomAccessIterator end) {
385  this->assign(begin, end);
386  }
392  inline TShape& operator=(const Tuple<dim_t>& src) {
393  this->assign(src.begin(), src.end());
394  return *this;
395  }
401  inline TShape& operator=(Tuple<dim_t>&& src) { // NOLINT(*)
402  TShape(std::move(src)).swap(*this); // NOLINT(*)
403  return *this;
404  }
406  inline size_t Size() const {
407  dim_t size = 1;
408  const dim_t* start = begin(), *fin = end();
409  for (const dim_t* it = start; it != fin; ++it) {
410  size *= *it;
411  }
412  return size;
413  }
419  inline size_t ProdShape(int dimstart, int dimend) const {
420  dim_t num = 1;
421  const dim_t *d = this->data();
422  for (int i = dimstart; i < dimend; ++i) {
423  num *= d[i];
424  }
425  return num;
426  }
428  inline const dim_t *data() const {
429  return begin();
430  }
432  inline dim_t *data() {
433  return begin();
434  }
435 #ifdef MSHADOW_XINLINE
436  template<int dim>
437  inline TShape(const mshadow::Shape<dim> &s) {// NOLINT(*)
438  this->assign(s.shape_, s.shape_ + dim);
439  }
440 
441  template<int dim>
442  inline TShape(mshadow::Shape<dim> &&s) {// NOLINT(*)
443  this->assign(s.shape_, s.shape_ + dim);
444  }
451  template<int dim>
452  inline TShape &operator=(const mshadow::Shape<dim> &shape) {
453  this->assign(shape.shape_, shape.shape_ + dim);
454  return *this;
455  }
461  template<int dim>
462  inline mshadow::Shape<dim> get() const {
463  CHECK_EQ(dim, static_cast<int>(ndim()))
464  << "dimension do not match target dimension " << dim << " vs " << ndim();
465  const dim_t *d = this->data();
467  for (int i = 0; i < dim; ++i) {
468  s[i] = d[i];
469  }
470  return s;
471  }
476  inline mshadow::Shape<2> FlatTo2D(void) const {
478  if (ndim() == 0) return mshadow::Shape2(0, 0);
479  const dim_t *d = this->data();
480  s.shape_[1] = d[ndim() - 1];
481  dim_t ymax = 1;
482  for (size_t i = 1; i < ndim(); ++i) {
483  ymax *= d[i - 1];
484  }
485  s.shape_[0] = ymax;
486  return s;
487  }
494  inline mshadow::Shape<3> FlatTo3D(size_t axis_begin, size_t axis_end) const {
495  CHECK(axis_end >= axis_begin);
497  if (ndim() == 0) return mshadow::Shape3(0, 0, 0);
498  const dim_t *d = this->data();
499  s.shape_[0] = 1;
500  s.shape_[1] = 1;
501  s.shape_[2] = 1;
502 
503  for (size_t i = 0; i < axis_begin; ++i) {
504  s.shape_[0] *= d[i];
505  }
506  for (size_t i = axis_begin; i <= axis_end; ++i) {
507  s.shape_[1] *= d[i];
508  }
509  for (size_t i = axis_end + 1; i < ndim(); ++i) {
510  s.shape_[2] *= d[i];
511  }
512  return s;
513  }
519  inline mshadow::Shape<3> FlatTo3D(size_t axis) const {
520  return FlatTo3D(axis, axis);
521  }
522  inline bool operator==(const TShape &s) const {
523  if (ndim() != s.ndim()) return false;
524  return std::equal(begin(), end(), s.begin());
525  }
526  inline bool operator!=(const TShape &s) const {
527  return !(*this == s);
528  }
534  template<int dim>
535  inline bool operator==(const mshadow::Shape<dim> &s) const {
536  if (ndim_ != dim) return false;
537  const dim_t *d = dim <= kStackCache ? data_stack_ : data_heap_;
538  for (size_t i = 0; i < dim; ++i) {
539  if (d[i] != s.shape_[i]) return false;
540  }
541  return true;
542  }
548  template<int dim>
549  inline bool operator!=(const mshadow::Shape<dim> &s) const {
550  return !(*this == s);
551  }
552 #endif
553 };
554 
/*!
 * \brief helper function to cast type of container elements
 * \param begin start of the source range
 * \param end end of the source range
 * \param dst_begin start of the destination range (must have room for the whole source)
 * \return iterator one past the last written destination element
 */
template<typename SrcIter, typename DstIter>
inline DstIter ShapeTypeCast(const SrcIter begin,
                             const SrcIter end,
                             DstIter dst_begin) {
  using DstValue = typename std::iterator_traits<DstIter>::value_type;
  DstIter out = dst_begin;
  for (SrcIter it = begin; it != end; ++it, ++out) {
    *out = static_cast<DstValue>(*it);
  }
  return out;
}
565 
567 template<typename SrcIter>
568 inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) {
569  size_t ndim = std::distance(begin, end);
570  TShape res(ndim);
571  ShapeTypeCast(begin, end, res.begin());
572  return res;
573 }
574 
576 template<typename ValueType>
577 template<typename DType, typename TStream>
578 inline void Tuple<ValueType>::Save(TStream *strm) const {
579  strm->Write(&ndim_, sizeof(ndim_));
580  if (typeid(DType) == typeid(ValueType)) {
581  strm->Write(begin(), sizeof(ValueType) * ndim_);
582  } else {
583  std::vector<DType> buffer(ndim_);
584  ShapeTypeCast(begin(), end(), buffer.data());
585  strm->Write(buffer.data(), sizeof(DType) * ndim_);
586  }
587 }
588 
590 template<typename ValueType>
591 template<typename DType, typename TStream>
592 inline bool Tuple<ValueType>::Load(TStream *strm) {
593  if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false;
594  this->SetDim(ndim_);
595  size_t nread = sizeof(DType) * ndim_;
596  if (typeid(DType) == typeid(ValueType)) {
597  if (strm->Read(begin(), nread) != nread) return false;
598  } else {
599  std::vector<DType> buffer(ndim_);
600  if (strm->Read(buffer.data(), nread) != nread) return false;
601  ShapeTypeCast(buffer.begin(), buffer.end(), begin());
602  }
603  return true;
604 }
605 
606 } // namespace nnvm
607 
608 namespace std {
610 template<typename T>
611 struct hash<nnvm::Tuple<T> > {
613  size_t operator()(const nnvm::Tuple<T>& val) const {
614  std::hash<uint32_t> hash_uint;
615  size_t res = hash_uint(val.ndim());
616  for (uint32_t i = 0; i < val.ndim(); ++i) {
617  res = dmlc::HashCombine(res, val[i]);
618  }
619  return res;
620  }
621 };
622 
624 template<>
625 struct hash<nnvm::TShape> {
627  size_t operator()(const nnvm::TShape& val) const {
628  std::hash<uint32_t> hash_uint;
629  size_t res = hash_uint(val.ndim());
630  for (uint32_t i = 0; i < val.ndim(); ++i) {
631  res = dmlc::HashCombine(res, val[i]);
632  }
633  return res;
634  }
635 };
636 } // namespace std
637 
638 namespace dmlc {
640 DMLC_DECLARE_TYPE_NAME(optional<nnvm::TShape>, "Shape or None");
641 // avoid low version of MSVC
642 #if !defined(_MSC_VER)
643 template<typename T>
645  static inline std::string value() {
646  return "tuple of <" + type_name<T>() + ">";
647  }
648 };
649 #endif
650 } // namespace dmlc
651 #endif // NNVM_TUPLE_H_
#define DMLC_DECLARE_TYPE_NAME(Type, Name)
macro to quickly declare traits information
Definition: type_traits.h:133
Definition: base.h:35
helper class to construct a string that represents type name
Definition: type_traits.h:86
Tuple< ValueType > & operator=(Tuple< ValueType > &&src)
assignment from rvalue of another tuple.
Definition: tuple.h:135
A dynamic sized array data structure that is optimized for storing small number of elements with same...
Definition: tuple.h:51
ValueType * data_heap_
space to store shape when dimension is big
Definition: tuple.h:327
uint32_t ndim_
number of dimension of the tuple
Definition: tuple.h:321
const ValueType & operator[](size_t i) const
get corresponding index
Definition: tuple.h:196
ValueType data_stack_[kStackCache]
in stack space used to store shape when it is small
Definition: tuple.h:392
dim_t * data()
Definition: tuple.h:432
bool operator==(const Tuple< ValueType > &s) const
Definition: tuple.h:152
ValueType data_stack_[kStackCache]
in stack space used to store shape when it is small
Definition: tuple.h:325
int64_t dim_t
data type to store dim size
Definition: tuple.h:38
size_t operator()(const nnvm::TShape &val) const
hash a TShape into unsigned int
Definition: tuple.h:627
Definition: optional.h:241
Tuple()=default
default constructor
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:343
TShape(uint32_t ndim)
Definition: tuple.h:351
TShape(std::initializer_list< dim_t > init)
constructor from initializer list
Definition: tuple.h:366
void assign(RandomAccessIterator begin, RandomAccessIterator end)
Assign content to tuple from iterator.
Definition: tuple.h:138
ValueType & operator[](size_t i)
get corresponding index
Definition: tuple.h:188
size_t Size() const
Definition: tuple.h:406
void SetDim(uint32_t ndim)
Definition: tuple.h:329
ValueType * end()
Definition: tuple.h:176
const dim_t * data() const
Definition: tuple.h:428
size_t HashCombine(size_t key, const T &value)
hash an object and combines the key with previous keys
Definition: common.h:37
bool operator!=(const Tuple< ValueType > &s) const
Definition: tuple.h:160
Lightweight JSON Reader to read any STL compositions and structs. The user need to know the schema of...
Definition: json.h:44
bool isspace(char c)
Inline implementation of isspace(). Tests whether the given character is a whitespace letter...
Definition: strtonum.h:26
TShape & operator=(Tuple< dim_t > &&src)
move assignment function from tshape
Definition: tuple.h:401
static const uint32_t kStackCache
Definition: tuple.h:319
DstIter ShapeTypeCast(const SrcIter begin, const SrcIter end, DstIter dst_begin)
helper function to cast type of container elements
Definition: tuple.h:557
namespace for dmlc
Definition: array_view.h:12
size_t operator()(const nnvm::Tuple< T > &val) const
hash a Tuple into unsigned int
Definition: tuple.h:613
TShape(Tuple< dim_t > &&s)
move constructor.
Definition: tuple.h:373
Tuple< ValueType > & operator=(std::initializer_list< ValueType > init)
assignment from initializer list
Definition: tuple.h:144
void Write(const ValueType &value)
Write value to json.
ValueType * begin()
Definition: tuple.h:168
friend std::istream & operator>>(std::istream &is, Tuple< ValueType > &t)
read tuple from the istream
Definition: tuple.h:239
void Save(dmlc::JSONWriter *writer) const
Save Tuple to JSON.
Definition: tuple.h:203
int num_heap_allocated_
number of cells allocated in data_heap_
Definition: tuple.h:390
TShape & operator=(const Tuple< dim_t > &src)
assignment function from tshape
Definition: tuple.h:392
TShape(const Tuple< dim_t > &s)
copy constructor of TShape
Definition: tuple.h:359
const ValueType * end() const
Definition: tuple.h:210
const ValueType * begin() const
Definition: tuple.h:164
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:76
Tuple(const Tuple< ValueType > &s)
copy constructor from another tuple
Definition: tuple.h:63
void Read(ValueType *out_value)
Read next ValueType.
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two dimension shape, stride will equal s0
Definition: tensor.h:217
A dynamic sized array data structure that is optimized for storing small number of elements with same...
Definition: tuple.h:58
Tuple(RandomAccessIterator begin, RandomAccessIterator end)
construct the Tuple from content of iterator
Definition: tuple.h:95
Tuple(std::vector< ValueType > init)
constructor from vector
Definition: tuple.h:77
const ValueType * end() const
Definition: tuple.h:172
const ValueType * begin() const
Definition: tuple.h:202
void assign(RandomAccessIterator begin, RandomAccessIterator end)
Assign content to tuple from iterator.
Definition: tuple.h:106
int ndim_
number of dimension of the tuple
Definition: tuple.h:388
Tuple< ValueType > & operator=(const Tuple< ValueType > &src)
assignment from another tuple.
Definition: tuple.h:126
Tuple(std::initializer_list< ValueType > init)
constructor from initializer list
Definition: tuple.h:70
~Tuple()
destructor
Definition: tuple.h:56
static std::string value()
Definition: tuple.h:645
MSHADOW_XINLINE Shape< 3 > Shape3(index_t s0, index_t s1, index_t s2)
construct a three dimension shape, stride will equal s0
Definition: tensor.h:228
size_t ProdShape(int dimstart, int dimend) const
Definition: tuple.h:419
ValueType * data_heap_
space to store shape when dimension is big
Definition: tuple.h:394
void Load(dmlc::JSONReader *reader)
Load Tuple from JSON.
Definition: tuple.h:211
Tuple(Tuple< ValueType > &&src)
move constructor from Tuple
Definition: tuple.h:85
uint32_t ndim() const
Definition: tuple.h:180
uint32_t num_heap_allocated_
number of cells allocated in data_heap_
Definition: tuple.h:323
bool isdigit(char c)
Inline implementation of isdigit(). Tests whether the given character is a decimal digit...
Definition: strtonum.h:46
void swap(Tuple< ValueType > &other)
Swap current object with other.
Definition: tuple.h:115
Configuration of nnvm as well as basic data structure.
TShape(RandomAccessIterator begin, RandomAccessIterator end)
construct the Tuple from content of iterator
Definition: tuple.h:383
Lightweight json to write any STL compositions.
Definition: json.h:189