mxnet
layout.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
31 #ifndef NNVM_LAYOUT_H_
32 #define NNVM_LAYOUT_H_
33 
34 #include <dmlc/parameter.h>
35 #include <string>
36 #include <sstream>
37 #include <vector>
38 #include <utility>
39 #include <algorithm>
40 
41 namespace nnvm {
42 
43 class Layout {
44  public:
45  using LayoutDim = char;
46 
48  Layout() : name_("__undef__") {} // NOLINT(*)
49 
58  inline Layout(const std::string& layout) { // NOLINT(*)
59  parse(layout);
60  }
65  inline Layout(const Layout& s) { // NOLINT(*)
66  this->parse(s.name_);
67  }
72  inline Layout(Layout&& src) { // NOLINT(*)
73  this->swap(src);
74  }
80  inline Layout& operator=(const Layout& src) {
81  this->parse(src.name_);
82  return *this;
83  }
89  inline Layout& operator=(Layout&& src) {
90  Layout(std::move(src)).swap(*this); // NOLINT(*)
91  return *this;
92  }
98  inline Layout& operator=(const std::string& src) {
99  this->parse(src);
100  return *this;
101  }
106  inline bool operator==(const Layout& s) const {
107  return name_ == s.name_;
108  }
113  inline bool operator!=(const Layout& s) const {
114  return !(*this == s);
115  }
116 
122  inline Layout operator+(const Layout& other) const {
123  if (!this->defined() && !other.defined()) {
124  return Layout::Undef();
125  } else if (!this->defined()) {
126  return other;
127  } else if (!other.defined()) {
128  return *this;
129  }
130  return Layout(this->name_ + other.name_);
131  }
132 
138  static inline bool is_superdim(LayoutDim dim) {
139  return dim >= 'A' && dim <= 'Z';
140  }
141 
147  static inline bool is_subdim(LayoutDim dim) {
148  return dim >= 'a' && dim <= 'z';
149  }
150 
156  static inline LayoutDim to_superdim(LayoutDim dim) {
157  if (is_subdim(dim)) {
158  return dim - 'a' + 'A';
159  }
160  return dim;
161  }
162 
168  static inline LayoutDim to_subdim(LayoutDim dim) {
169  if (is_superdim(dim)) {
170  return dim - 'A' + 'a';
171  }
172  return dim;
173  }
174 
179  static inline const Layout& Undef() {
180  static Layout undef;
181  return undef;
182  }
183 
188  inline void swap(Layout& other) { // NOLINT(*)
189  std::swap(name_, other.name_);
190  std::swap(superdim_pos_, other.superdim_pos_);
191  std::swap(subdim_pos_, other.subdim_pos_);
192  std::swap(subdim_size_, other.subdim_size_);
193  std::swap(layout_simplified_, other.layout_simplified_);
194  }
195 
204  inline bool convertible(const Layout &dst) const {
205  if (!this->defined() || !dst.defined()) return false;
206  for (size_t i = 0; i < kUniqueDim; ++i) {
207  if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) ||
208  (superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) {
209  return false;
210  }
211  }
212  return true;
213  }
214 
223  inline Layout sublayout(size_t pos, size_t len) const {
224  if (pos > ndim()) return Layout::Undef();
225  if (pos + len > ndim()) len = ndim() - pos;
226  if (len == 0) return Layout::Undef();
227  std::ostringstream new_layout;
228  for (size_t i = pos; i < pos + len; ++i) {
229  if (is_subdim(layout_simplified_[i])) {
230  auto block_size = this->subsizeof(layout_simplified_[i]);
231  CHECK_GT(block_size, 0);
232  new_layout << block_size;
233  }
234  new_layout << layout_simplified_[i];
235  }
236  return Layout(new_layout.str());
237  }
238 
240  inline Layout reverse() const {
241  if (!this->defined()) return Layout::Undef();
242  std::ostringstream new_layout;
243  for (int64_t i = this->ndim() - 1; i >= 0; --i) {
244  if (is_subdim(layout_simplified_[i])) {
245  auto block_size = this->subsizeof(layout_simplified_[i]);
246  CHECK_GT(block_size, 0);
247  new_layout << block_size;
248  }
249  new_layout << layout_simplified_[i];
250  }
251  return Layout(new_layout.str());
252  }
253 
261  inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const {
262  CHECK(target_pos <= this->ndim()) << "Invalid split position "
263  << target_pos << " for layout " << name_;
264  CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim;
265  CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_;
266  CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim
267  << " has already been split in "
268  << name_;
269  CHECK(size > 0) << "Invalid split size " << size;
270  std::ostringstream new_layout;
271  for (size_t i = 0; i <= this->ndim(); ++i) {
272  if (i == target_pos) {
273  new_layout << size << Layout::to_subdim(dim);
274  }
275  if (i == this->ndim()) break;
276  new_layout << this->at(i);
277  }
278  Layout x(new_layout.str());
279  return x;
280  }
281 
282  using iterator = std::vector<LayoutDim>::const_iterator;
283  using reverse_iterator = std::vector<LayoutDim>::const_reverse_iterator;
284 
286  inline iterator begin() const {
287  return layout_simplified_.begin();
288  }
290  inline iterator end() const {
291  return layout_simplified_.end();
292  }
294  inline reverse_iterator rbegin() const {
295  return layout_simplified_.rbegin();
296  }
298  inline reverse_iterator rend() const {
299  return layout_simplified_.rend();
300  }
301 
303  inline size_t ndim() const {
304  return layout_simplified_.size();
305  }
306 
314  inline std::string at(size_t i) const {
315  CHECK_LT(i, this->ndim()) << "position " << i
316  << " exceeds ndim=" << this->ndim();
317  std::ostringstream repr;
318  if (is_subdim(layout_simplified_[i])) {
319  auto factor = subsizeof(layout_simplified_[i]);
320  CHECK_GT(factor, 0);
321  repr << factor;
322  }
323  repr << layout_simplified_[i];
324  return repr.str();
325  }
326 
334  inline int32_t indexof(LayoutDim dim) const {
335  if (!this->defined()) return -1;
336  else if (is_superdim(dim)) return superdim_pos_[dim - 'A'];
337  else if (is_subdim(dim)) return subdim_pos_[dim - 'a'];
338  return -1;
339  }
340 
347  inline int64_t subsizeof(LayoutDim dim) const {
348  CHECK(is_superdim(dim) || is_subdim(dim)) << "Invalid dim " << dim;
349  if (!this->defined() || !this->contains(to_subdim(dim))) {
350  return -1;
351  }
352  int idx = to_subdim(dim) - 'a';
353  return subdim_size_[idx];
354  }
355 
361  inline bool contains(LayoutDim dim) const {
362  if (is_superdim(dim)) {
363  return superdim_pos_[dim-'A'] >= 0;
364  } else if (is_subdim(dim)) {
365  return subdim_pos_[dim-'a'] >= 0;
366  }
367  return false;
368  }
369 
370  inline LayoutDim operator[](size_t i) const {
371  return layout_simplified_[i];
372  }
373 
375  inline bool defined() const {
376  return name_ != "__undef__";
377  }
378 
380  inline const std::string& name() const {
381  return name_;
382  }
383 
388  inline void Save(dmlc::JSONWriter* writer) const {
389  writer->Write(name_);
390  }
391 
396  inline void Load(dmlc::JSONReader* reader) {
397  std::string tmp;
398  reader->Read(&tmp);
399  this->parse(tmp);
400  }
401 
408  friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
409  os << l.name_;
410  return os;
411  }
412 
413  private:
414  static const uint32_t kUniqueDim = 26;
415 
416  std::string name_;
417  int32_t superdim_pos_[kUniqueDim];
418  int32_t subdim_pos_[kUniqueDim];
419  int64_t subdim_size_[kUniqueDim];
420  std::vector<LayoutDim> layout_simplified_;
421 
422  void parse(const std::string& layout) {
423  name_ = layout;
424  std::fill_n(superdim_pos_, kUniqueDim, -1);
425  std::fill_n(subdim_pos_, kUniqueDim, -1);
426  std::fill_n(subdim_size_, kUniqueDim, -1);
427  layout_simplified_.clear();
428 
429  if (layout == "__undef__") return;
430 
431  int32_t factor = 0;
432  uint32_t curr = 0;
433  for (size_t i = 0; i < layout.size(); ++i) {
434  const LayoutDim c = layout.at(i);
435  if (is_superdim(c)) {
436  int pos = c - 'A';
437  CHECK_EQ(factor, 0) << "Invalid layout " << layout
438  << ": invalid factor size " << factor
439  << " before dimension " << c;
440  CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout
441  << ": duplicate dimension " << c;
442  superdim_pos_[pos] = curr++;
443  layout_simplified_.push_back(c);
444  } else if (is_subdim(c)) {
445  int pos = c - 'a';
446  CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size "
447  << factor << " for dimension " << c;
448  CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout
449  << ": duplicate dimension " << c;
450  CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout
451  << ": duplicate dimension " << c;
452  subdim_pos_[pos] = curr++;
453  subdim_size_[pos] = factor;
454  layout_simplified_.push_back(c);
455  factor = 0;
456  } else if (c >= '0' && c <= '9') {
457  CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number.";
458  factor = factor * 10 + c - '0';
459  } else {
460  LOG(FATAL) << "Invalid layout " << layout;
461  }
462  }
463  CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout;
464  for (LayoutDim dim : layout_simplified_) {
465  CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0)
466  << "Invalid layout " << layout << ": missing axis "
467  << static_cast<char>(dim - 'a' + 'A');
468  }
469  }
470 };
471 
472 } // namespace nnvm
473 
474 #endif // NNVM_LAYOUT_H_
void swap(Layout &other)
Swap current object with other.
Definition: layout.h:188
Definition: base.h:36
Layout reverse() const
Definition: layout.h:240
Layout(Layout &&src)
move constructor from Layout
Definition: layout.h:72
Layout(const std::string &layout)
construct from a string.
Definition: layout.h:58
Layout()
default constructor
Definition: layout.h:48
bool operator==(const Layout &s) const
Definition: layout.h:106
static LayoutDim to_subdim(LayoutDim dim)
Convert a given dimension to sub-dimension.
Definition: layout.h:168
Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const
Split dim by size and put the sub-dimension to position target_pos.
Definition: layout.h:261
iterator begin() const
Definition: layout.h:286
void Load(dmlc::JSONReader *reader)
Load layout from JSON.
Definition: layout.h:396
static bool is_superdim(LayoutDim dim)
Check whether a given dimension is a super-dimension.
Definition: layout.h:138
void Save(dmlc::JSONWriter *writer) const
Write layout in JSON format.
Definition: layout.h:388
int32_t indexof(LayoutDim dim) const
Return the index of the input dimension. If it is not found in the layout or the layout is undefined, return -1.
Definition: layout.h:334
Layout operator+(const Layout &other) const
Append the current layout by another.
Definition: layout.h:122
Layout(const Layout &s)
copy constructor from another layout
Definition: layout.h:65
static const Layout & Undef()
Return an undefined layout.
Definition: layout.h:179
std::vector< LayoutDim >::const_iterator iterator
Definition: layout.h:282
Lightweight JSON Reader to read any STL compositions and structs. The user need to know the schema of...
Definition: json.h:44
reverse_iterator rend() const
Definition: layout.h:298
std::vector< LayoutDim >::const_reverse_iterator reverse_iterator
Definition: layout.h:283
friend std::ostream & operator<<(std::ostream &os, const Layout &l)
allow output string of layout to ostream
Definition: layout.h:408
void Write(const ValueType &value)
Write value to json.
Layout & operator=(Layout &&src)
assignment from rvalue of another layout.
Definition: layout.h:89
LayoutDim operator[](size_t i) const
Definition: layout.h:370
Layout & operator=(const std::string &src)
assignment from string.
Definition: layout.h:98
bool defined() const
Definition: layout.h:375
std::string at(size_t i) const
The description of the i-th dimension. If it is a sub-dimension, the size will be returned as well, e.g., 16c. Otherwise only the dimension letter is returned, e.g., C.
Definition: layout.h:314
Layout & operator=(const Layout &src)
assignment from another layout.
Definition: layout.h:80
void Read(ValueType *out_value)
Read next ValueType.
iterator end() const
Definition: layout.h:290
Layout sublayout(size_t pos, size_t len) const
Returns a sublayout which is the portion of the object that starts at dimension pos and spans len dimensions (or until the end of the layout, whichever comes first).
Definition: layout.h:223
static bool is_subdim(LayoutDim dim)
Check whether a given dimension is a sub-dimension.
Definition: layout.h:147
size_t ndim() const
Definition: layout.h:303
bool contains(LayoutDim dim) const
Whether the layout contains a dimension.
Definition: layout.h:361
char LayoutDim
Definition: layout.h:45
bool convertible(const Layout &dst) const
Two layouts are convertible only if they have the same set of super-dimensions, e.g., NCHW, NCHW16c and NHWC are convertible to each other, but NCHW and CHW are not.
Definition: layout.h:204
Definition: layout.h:43
bool operator!=(const Layout &s) const
Definition: layout.h:113
reverse_iterator rbegin() const
Definition: layout.h:294
int64_t subsizeof(LayoutDim dim) const
Definition: layout.h:347
static LayoutDim to_superdim(LayoutDim dim)
Convert a given dimension to super-dimension.
Definition: layout.h:156
Provide lightweight util to do parameter setup and checking.
const std::string & name() const
Definition: layout.h:380
Lightweight json to write any STL compositions.
Definition: json.h:189