mxnet
tuple.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef NNVM_TUPLE_H_
26 #define NNVM_TUPLE_H_
27 
28 #include <vector>
29 #include <type_traits>
30 #include <algorithm>
31 #include <utility>
32 #include <iostream>
33 #include <string>
34 #include "base.h"
35 
36 namespace nnvm {
37 
/*! \brief data type to store dim size */
typedef int64_t dim_t;
40 
51 template<typename ValueType>
52 class Tuple {
53  public:
55  Tuple() = default;
57  inline ~Tuple() {
58  delete [] data_heap_;
59  }
64  inline Tuple(const Tuple<ValueType>& s) {
65  this->assign(s.begin(), s.end());
66  }
71  inline Tuple(std::initializer_list<ValueType> init) {
72  this->assign(init.begin(), init.end());
73  }
78  inline Tuple(std::vector<ValueType> init) { // NOLINT(runtime/explicit)
79  this->assign(init.begin(), init.end());
80  }
86  inline Tuple(Tuple<ValueType>&& src) { // NOLINT(runtime/explicit)
87  this->swap(src);
88  }
95  template<typename RandomAccessIterator>
96  inline Tuple(RandomAccessIterator begin,
97  RandomAccessIterator end) {
98  this->assign(begin, end);
99  }
106  template<typename RandomAccessIterator>
107  inline void assign(RandomAccessIterator begin,
108  RandomAccessIterator end) {
109  this->SetDim(end - begin);
110  std::copy(begin, end, this->begin());
111  }
116  inline void swap(Tuple<ValueType>& other) { // NOLINT(*)
117  std::swap(ndim_, other.ndim_);
118  std::swap(num_heap_allocated_, other.num_heap_allocated_);
119  std::swap(data_stack_, other.data_stack_);
120  std::swap(data_heap_, other.data_heap_);
121  }
128  this->assign(src.begin(), src.end());
129  return *this;
130  }
137  Tuple<ValueType>(std::move(src)).swap(*this);
138  return *this;
139  }
145  inline Tuple<ValueType> &operator=(std::initializer_list<ValueType> init) {
146  this->assign(init.begin(), init.end());
147  return *this;
148  }
153  inline bool operator==(const Tuple<ValueType> &s) const {
154  if (ndim_ != s.ndim_) return false;
155  return std::equal(begin(), end(), s.begin());
156  }
161  inline bool operator!=(const Tuple<ValueType> &s) const {
162  return !(*this == s);
163  }
165  inline const ValueType *begin() const {
166  return ndim_ <= kStackCache ? data_stack_ : data_heap_;
167  }
169  inline ValueType *begin() {
170  return ndim_ <= kStackCache ? data_stack_ : data_heap_;
171  }
173  inline const ValueType* end() const {
174  return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
175  }
177  inline ValueType* end() {
178  return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
179  }
181  inline uint32_t ndim() const {
182  return ndim_;
183  }
189  inline ValueType& operator[](size_t i) {
190  return begin()[i];
191  }
197  inline const ValueType& operator[](size_t i) const {
198  return begin()[i];
199  }
204  inline void Save(dmlc::JSONWriter* writer) const {
205  std::vector<ValueType> tmp(begin(), end());
206  writer->Write(tmp);
207  }
212  inline void Load(dmlc::JSONReader* reader) {
213  std::vector<ValueType> tmp;
214  reader->Read(&tmp);
215  this->assign(tmp.begin(), tmp.end());
216  }
223  friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
224  os << '[';
225  const ValueType* begin = t.begin();
226  const ValueType* end = t.end();
227  for (const ValueType* it = begin; it != end; ++it) {
228  if (it != begin) os << ',';
229  os << *it;
230  }
231  os << ']';
232  return os;
233  }
240  friend std::istream &operator>>(std::istream &is, Tuple<ValueType> &t) {
241  // get (
242  while (true) {
243  char ch = is.peek();
244  if (isdigit(ch) || ch == '-') {
245  ValueType idx;
246  if (is >> idx) {
247  t.assign(&idx, &idx + 1);
248  }
249  return is;
250  }
251  is.get();
252  if (ch == '(' || ch == '[') break;
253  if (!isspace(ch)) {
254  is.setstate(std::ios::failbit);
255  return is;
256  }
257  }
258  // Handle empty tuple
259  while (isspace(is.peek())) {
260  is.get();
261  }
262  if (is.peek() == ')' || is.peek() == ']') {
263  is.get();
264  return is;
265  }
266  // Handle non-empty tuple
267  ValueType idx;
268  std::vector<ValueType> tmp;
269  while (is >> idx) {
270  tmp.push_back(idx);
271  char ch;
272  do {
273  ch = is.get();
274  } while (isspace(ch));
275  if (std::is_integral<ValueType>::value && ch == 'L') {
276  ch = is.get();
277  }
278  if (ch == ',') {
279  while (true) {
280  ch = is.peek();
281  if (isspace(ch)) {
282  is.get(); continue;
283  }
284  if (ch == ')' || ch == ']') {
285  is.get(); break;
286  }
287  break;
288  }
289  if (ch == ')' || ch == ']') break;
290  } else if (ch == ')' || ch == ']') {
291  break;
292  } else {
293  is.setstate(std::ios::failbit);
294  return is;
295  }
296  }
297  t.assign(tmp.begin(), tmp.end());
298  return is;
299  }
306  template<typename DType = ValueType, typename TStream>
307  inline void Save(TStream *strm) const;
315  template<typename DType = ValueType, typename TStream>
316  inline bool Load(TStream *strm);
317 
318  protected:
319  // stack cache size
320  static const uint32_t kStackCache = 4;
322  uint32_t ndim_{0};
324  uint32_t num_heap_allocated_{0};
328  ValueType* data_heap_{nullptr};
329  // internal function to change the dimension
330  inline void SetDim(uint32_t ndim) {
331  if (ndim > kStackCache &&
332  ndim > num_heap_allocated_) {
333  delete [] data_heap_;
334  data_heap_ = new ValueType[ndim];
336  }
337  ndim_ = ndim;
338  }
339 };
340 
344 class TShape : public Tuple<dim_t> {
345  public:
347  TShape() = default;
352  inline TShape(uint32_t ndim) { // NOLINT(*)
353  this->SetDim(ndim);
354  std::fill_n(begin(), ndim, 1);
355  }
360  inline TShape(const Tuple<dim_t>& s) { // NOLINT(*)
361  this->assign(s.begin(), s.end());
362  }
367  inline TShape(std::initializer_list<dim_t> init) {
368  this->assign(init.begin(), init.end());
369  }
374  inline TShape(Tuple<dim_t>&& s) { // NOLINT(*)
375  this->swap(s);
376  }
383  template<typename RandomAccessIterator>
384  inline TShape(RandomAccessIterator begin,
385  RandomAccessIterator end) {
386  this->assign(begin, end);
387  }
393  inline TShape& operator=(const Tuple<dim_t>& src) {
394  this->assign(src.begin(), src.end());
395  return *this;
396  }
402  inline TShape& operator=(Tuple<dim_t>&& src) { // NOLINT(*)
403  TShape(std::move(src)).swap(*this); // NOLINT(*)
404  return *this;
405  }
407  inline size_t Size() const {
408  dim_t size = 1;
409  const dim_t* start = begin(), *fin = end();
410  for (const dim_t* it = start; it != fin; ++it) {
411  size *= *it;
412  }
413  return size;
414  }
420  inline size_t ProdShape(int dimstart, int dimend) const {
421  dim_t num = 1;
422  const dim_t *d = this->data();
423  for (int i = dimstart; i < dimend; ++i) {
424  num *= d[i];
425  }
426  return num;
427  }
429  inline const dim_t *data() const {
430  return begin();
431  }
433  inline dim_t *data() {
434  return begin();
435  }
436 #ifdef MSHADOW_XINLINE
437  template<int dim>
438  inline TShape(const mshadow::Shape<dim> &s) {// NOLINT(*)
439  this->assign(s.shape_, s.shape_ + dim);
440  }
441 
442  template<int dim>
443  inline TShape(mshadow::Shape<dim> &&s) {// NOLINT(*)
444  this->assign(s.shape_, s.shape_ + dim);
445  }
452  template<int dim>
453  inline TShape &operator=(const mshadow::Shape<dim> &shape) {
454  this->assign(shape.shape_, shape.shape_ + dim);
455  return *this;
456  }
462  template<int dim>
463  inline mshadow::Shape<dim> get() const {
464  CHECK_EQ(dim, static_cast<int>(ndim()))
465  << "dimension do not match target dimension " << dim << " vs " << ndim();
466  const dim_t *d = this->data();
468  for (int i = 0; i < dim; ++i) {
469  s[i] = d[i];
470  }
471  return s;
472  }
477  inline mshadow::Shape<2> FlatTo2D(void) const {
479  if (ndim() == 0) return mshadow::Shape2(0, 0);
480  const dim_t *d = this->data();
481  s.shape_[1] = d[ndim() - 1];
482  dim_t ymax = 1;
483  for (size_t i = 1; i < ndim(); ++i) {
484  ymax *= d[i - 1];
485  }
486  s.shape_[0] = ymax;
487  return s;
488  }
495  inline mshadow::Shape<3> FlatTo3D(size_t axis_begin, size_t axis_end) const {
496  CHECK(axis_end >= axis_begin);
498  if (ndim() == 0) return mshadow::Shape3(0, 0, 0);
499  const dim_t *d = this->data();
500  s.shape_[0] = 1;
501  s.shape_[1] = 1;
502  s.shape_[2] = 1;
503 
504  for (size_t i = 0; i < axis_begin; ++i) {
505  s.shape_[0] *= d[i];
506  }
507  for (size_t i = axis_begin; i <= axis_end; ++i) {
508  s.shape_[1] *= d[i];
509  }
510  for (size_t i = axis_end + 1; i < ndim(); ++i) {
511  s.shape_[2] *= d[i];
512  }
513  return s;
514  }
520  inline mshadow::Shape<3> FlatTo3D(size_t axis) const {
521  return FlatTo3D(axis, axis);
522  }
523  inline bool operator==(const TShape &s) const {
524  if (ndim() != s.ndim()) return false;
525  return std::equal(begin(), end(), s.begin());
526  }
527  inline bool operator!=(const TShape &s) const {
528  return !(*this == s);
529  }
535  template<int dim>
536  inline bool operator==(const mshadow::Shape<dim> &s) const {
537  if (ndim_ != dim) return false;
538  const dim_t *d = dim <= kStackCache ? data_stack_ : data_heap_;
539  for (size_t i = 0; i < dim; ++i) {
540  if (d[i] != s.shape_[i]) return false;
541  }
542  return true;
543  }
549  template<int dim>
550  inline bool operator!=(const mshadow::Shape<dim> &s) const {
551  return !(*this == s);
552  }
553 #endif
554 };
555 
/*!
 * \brief helper function to cast type of container elements
 *
 *  Each element of [begin, end) is converted with static_cast to the value
 *  type of DstIter and written starting at dst_begin.
 *
 * \param begin start of the source range
 * \param end end of the source range
 * \param dst_begin start of the destination range; must have room for
 *        `end - begin` elements
 * \return iterator one past the last element written
 */
template<typename SrcIter, typename DstIter>
inline DstIter ShapeTypeCast(const SrcIter begin,
                             const SrcIter end,
                             DstIter dst_begin) {
  typedef typename std::iterator_traits<SrcIter>::value_type SrcDType;
  typedef typename std::iterator_traits<DstIter>::value_type DstDType;
  auto cast = [](const SrcDType& dim) { return static_cast<DstDType>(dim); };
  return std::transform(begin, end, dst_begin, cast);
}
566 
568 template<typename SrcIter>
569 inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) {
570  size_t ndim = std::distance(begin, end);
571  TShape res(ndim);
572  ShapeTypeCast(begin, end, res.begin());
573  return res;
574 }
575 
577 template<typename ValueType>
578 template<typename DType, typename TStream>
579 inline void Tuple<ValueType>::Save(TStream *strm) const {
580  strm->Write(&ndim_, sizeof(ndim_));
581  if (typeid(DType) == typeid(ValueType)) {
582  strm->Write(begin(), sizeof(ValueType) * ndim_);
583  } else {
584  std::vector<DType> buffer(ndim_);
585  ShapeTypeCast(begin(), end(), buffer.data());
586  strm->Write(buffer.data(), sizeof(DType) * ndim_);
587  }
588 }
589 
591 template<typename ValueType>
592 template<typename DType, typename TStream>
593 inline bool Tuple<ValueType>::Load(TStream *strm) {
594  if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false;
595  this->SetDim(ndim_);
596  size_t nread = sizeof(DType) * ndim_;
597  if (typeid(DType) == typeid(ValueType)) {
598  if (strm->Read(begin(), nread) != nread) return false;
599  } else {
600  std::vector<DType> buffer(ndim_);
601  if (strm->Read(buffer.data(), nread) != nread) return false;
602  ShapeTypeCast(buffer.begin(), buffer.end(), begin());
603  }
604  return true;
605 }
606 
607 } // namespace nnvm
608 
609 namespace std {
611 template<typename T>
612 struct hash<nnvm::Tuple<T> > {
614  size_t operator()(const nnvm::Tuple<T>& val) const {
615  std::hash<uint32_t> hash_uint;
616  size_t res = hash_uint(val.ndim());
617  for (uint32_t i = 0; i < val.ndim(); ++i) {
618  res = dmlc::HashCombine(res, val[i]);
619  }
620  return res;
621  }
622 };
623 
625 template<>
626 struct hash<nnvm::TShape> {
628  size_t operator()(const nnvm::TShape& val) const {
629  std::hash<uint32_t> hash_uint;
630  size_t res = hash_uint(val.ndim());
631  for (uint32_t i = 0; i < val.ndim(); ++i) {
632  res = dmlc::HashCombine(res, val[i]);
633  }
634  return res;
635  }
636 };
637 } // namespace std
638 
639 namespace dmlc {
641 DMLC_DECLARE_TYPE_NAME(optional<nnvm::TShape>, "Shape or None");
642 // avoid low version of MSVC
643 #if !defined(_MSC_VER)
644 template<typename T>
646  static inline std::string value() {
647  return "tuple of <" + type_name<T>() + ">";
648  }
649 };
650 #endif
651 } // namespace dmlc
652 #endif // NNVM_TUPLE_H_
#define DMLC_DECLARE_TYPE_NAME(Type, Name)
macro to quickly declare traits information
Definition: type_traits.h:133
Definition: base.h:36
helper class to construct a string that represents type name
Definition: type_traits.h:86
Tuple< ValueType > & operator=(Tuple< ValueType > &&src)
assignment from rvalue of another tuple.
Definition: tuple.h:136
A dynamic sized array data structure that is optimized for storing small number of elements with same...
Definition: tuple.h:52
ValueType * data_heap_
space to store shape when dimension is big
Definition: tuple.h:328
uint32_t ndim_
number of dimension of the tuple
Definition: tuple.h:322
const ValueType & operator[](size_t i) const
get corresponding index
Definition: tuple.h:197
ValueType data_stack_[kStackCache]
in stack space used to store shape when it is small
Definition: tuple.h:367
dim_t * data()
Definition: tuple.h:433
bool operator==(const Tuple< ValueType > &s) const
Definition: tuple.h:153
ValueType data_stack_[kStackCache]
in stack space used to store shape when it is small
Definition: tuple.h:326
int64_t dim_t
data type to store dim size
Definition: tuple.h:39
size_t operator()(const nnvm::TShape &val) const
hash a TShape into unsigned int
Definition: tuple.h:628
Definition: optional.h:241
Tuple()=default
default constructor
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:344
TShape(uint32_t ndim)
Definition: tuple.h:352
TShape(std::initializer_list< dim_t > init)
constructor from initializer list
Definition: tuple.h:367
void assign(RandomAccessIterator begin, RandomAccessIterator end)
Assign content to tuple from iterator.
Definition: tuple.h:113
ValueType & operator[](size_t i)
get corresponding index
Definition: tuple.h:189
size_t Size() const
Definition: tuple.h:407
void SetDim(uint32_t ndim)
Definition: tuple.h:330
ValueType * end()
Definition: tuple.h:177
const dim_t * data() const
Definition: tuple.h:429
size_t HashCombine(size_t key, const T &value)
hash an object and combines the key with previous keys
Definition: common.h:37
bool operator!=(const Tuple< ValueType > &s) const
Definition: tuple.h:161
Lightweight JSON Reader to read any STL compositions and structs. The user need to know the schema of...
Definition: json.h:44
bool isspace(char c)
Inline implementation of isspace(). Tests whether the given character is a whitespace letter...
Definition: strtonum.h:26
TShape & operator=(Tuple< dim_t > &&src)
move assignment function from tshape
Definition: tuple.h:402
static const uint32_t kStackCache
Definition: tuple.h:320
DstIter ShapeTypeCast(const SrcIter begin, const SrcIter end, DstIter dst_begin)
helper function to cast type of container elements
Definition: tuple.h:558
namespace for dmlc
Definition: array_view.h:12
size_t operator()(const nnvm::Tuple< T > &val) const
hash a Tuple into unsigned int
Definition: tuple.h:614
TShape(Tuple< dim_t > &&s)
move constructor.
Definition: tuple.h:374
Tuple< ValueType > & operator=(std::initializer_list< ValueType > init)
assignment from initializer list
Definition: tuple.h:145
void Write(const ValueType &value)
Write value to json.
ValueType * begin()
Definition: tuple.h:169
friend std::istream & operator>>(std::istream &is, Tuple< ValueType > &t)
read tuple from the istream
Definition: tuple.h:240
void Save(dmlc::JSONWriter *writer) const
Save Tuple to JSON.
Definition: tuple.h:204
int num_heap_allocated_
number of cells allocated in data_heap_
Definition: tuple.h:365
TShape & operator=(const Tuple< dim_t > &src)
assignment function from tshape
Definition: tuple.h:393
TShape(const Tuple< dim_t > &s)
copy constructor of TShape
Definition: tuple.h:360
const ValueType * end() const
Definition: tuple.h:185
const ValueType * begin() const
Definition: tuple.h:165
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:57
Tuple(const Tuple< ValueType > &s)
copy constructor from another tuple
Definition: tuple.h:64
void Read(ValueType *out_value)
Read next ValueType.
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two dimension shape, stride will equal s0
Definition: tensor.h:198
A dynamic sized array data structure that is optimized for storing small number of elements with same...
Definition: tuple.h:54
Tuple(RandomAccessIterator begin, RandomAccessIterator end)
construct the Tuple from content of iterator
Definition: tuple.h:96
Tuple(std::vector< ValueType > init)
constructor from vector
Definition: tuple.h:78
const ValueType * end() const
Definition: tuple.h:173
const ValueType * begin() const
Definition: tuple.h:177
void assign(RandomAccessIterator begin, RandomAccessIterator end)
Assign content to tuple from iterator.
Definition: tuple.h:107
int ndim_
number of dimension of the tuple
Definition: tuple.h:363
Tuple< ValueType > & operator=(const Tuple< ValueType > &src)
assignment from another tuple.
Definition: tuple.h:127
Tuple(std::initializer_list< ValueType > init)
constructor from initializer list
Definition: tuple.h:71
~Tuple()
destructor
Definition: tuple.h:57
static std::string value()
Definition: tuple.h:646
MSHADOW_XINLINE Shape< 3 > Shape3(index_t s0, index_t s1, index_t s2)
construct a three dimension shape, stride will equal s0
Definition: tensor.h:209
size_t ProdShape(int dimstart, int dimend) const
Definition: tuple.h:420
ValueType * data_heap_
space to store shape when dimension is big
Definition: tuple.h:369
void Load(dmlc::JSONReader *reader)
Load Tuple from JSON.
Definition: tuple.h:212
Tuple(Tuple< ValueType > &&src)
move constructor from Tuple
Definition: tuple.h:86
uint32_t ndim() const
Definition: tuple.h:181
uint32_t num_heap_allocated_
number of cells allocated in data_heap_
Definition: tuple.h:324
bool isdigit(char c)
Inline implementation of isdigit(). Tests whether the given character is a decimal digit...
Definition: strtonum.h:46
void swap(Tuple< ValueType > &other)
Swap current object with other.
Definition: tuple.h:116
Configuration of nnvm as well as basic data structure.
TShape(RandomAccessIterator begin, RandomAccessIterator end)
construct the Tuple from content of iterator
Definition: tuple.h:384
Lightweight json to write any STL compositions.
Definition: json.h:189