mxnet
layout.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
30 #ifndef NNVM_LAYOUT_H_
31 #define NNVM_LAYOUT_H_
32 
33 #include <dmlc/parameter.h>
34 #include <string>
35 #include <sstream>
36 #include <vector>
37 #include <utility>
38 #include <algorithm>
39 
40 namespace nnvm {
41 
42 class Layout {
43  public:
44  using LayoutDim = char;
45 
47  Layout() : name_("__undef__") {} // NOLINT(*)
48 
57  inline Layout(const std::string& layout) { // NOLINT(*)
58  parse(layout);
59  }
64  inline Layout(const Layout& s) { // NOLINT(*)
65  this->parse(s.name_);
66  }
71  inline Layout(Layout&& src) { // NOLINT(*)
72  this->swap(src);
73  }
79  inline Layout& operator=(const Layout& src) {
80  this->parse(src.name_);
81  return *this;
82  }
88  inline Layout& operator=(Layout&& src) {
89  Layout(std::move(src)).swap(*this); // NOLINT(*)
90  return *this;
91  }
97  inline Layout& operator=(const std::string& src) {
98  this->parse(src);
99  return *this;
100  }
105  inline bool operator==(const Layout& s) const {
106  return name_ == s.name_;
107  }
112  inline bool operator!=(const Layout& s) const {
113  return !(*this == s);
114  }
115 
121  inline Layout operator+(const Layout& other) const {
122  if (!this->defined() && !other.defined()) {
123  return Layout::Undef();
124  } else if (!this->defined()) {
125  return other;
126  } else if (!other.defined()) {
127  return *this;
128  }
129  return Layout(this->name_ + other.name_);
130  }
131 
137  static inline bool is_superdim(LayoutDim dim) {
138  return dim >= 'A' && dim <= 'Z';
139  }
140 
146  static inline bool is_subdim(LayoutDim dim) {
147  return dim >= 'a' && dim <= 'z';
148  }
149 
155  static inline LayoutDim to_superdim(LayoutDim dim) {
156  if (is_subdim(dim)) {
157  return dim - 'a' + 'A';
158  }
159  return dim;
160  }
161 
167  static inline LayoutDim to_subdim(LayoutDim dim) {
168  if (is_superdim(dim)) {
169  return dim - 'A' + 'a';
170  }
171  return dim;
172  }
173 
178  static inline const Layout& Undef() {
179  static Layout undef;
180  return undef;
181  }
182 
187  inline void swap(Layout& other) { // NOLINT(*)
188  std::swap(name_, other.name_);
189  std::swap(superdim_pos_, other.superdim_pos_);
190  std::swap(subdim_pos_, other.subdim_pos_);
191  std::swap(subdim_size_, other.subdim_size_);
192  std::swap(layout_simplified_, other.layout_simplified_);
193  }
194 
203  inline bool convertible(const Layout &dst) const {
204  if (!this->defined() || !dst.defined()) return false;
205  for (size_t i = 0; i < kUniqueDim; ++i) {
206  if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) ||
207  (superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) {
208  return false;
209  }
210  }
211  return true;
212  }
213 
222  inline Layout sublayout(size_t pos, size_t len) const {
223  if (pos > ndim()) return Layout::Undef();
224  if (pos + len > ndim()) len = ndim() - pos;
225  if (len == 0) return Layout::Undef();
226  std::ostringstream new_layout;
227  for (size_t i = pos; i < pos + len; ++i) {
228  if (is_subdim(layout_simplified_[i])) {
229  auto block_size = this->subsizeof(layout_simplified_[i]);
230  CHECK_GT(block_size, 0);
231  new_layout << block_size;
232  }
233  new_layout << layout_simplified_[i];
234  }
235  return Layout(new_layout.str());
236  }
237 
239  inline Layout reverse() const {
240  if (!this->defined()) return Layout::Undef();
241  std::ostringstream new_layout;
242  for (int64_t i = this->ndim() - 1; i >= 0; --i) {
243  if (is_subdim(layout_simplified_[i])) {
244  auto block_size = this->subsizeof(layout_simplified_[i]);
245  CHECK_GT(block_size, 0);
246  new_layout << block_size;
247  }
248  new_layout << layout_simplified_[i];
249  }
250  return Layout(new_layout.str());
251  }
252 
260  inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const {
261  CHECK(target_pos <= this->ndim()) << "Invalid split position "
262  << target_pos << " for layout " << name_;
263  CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim;
264  CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_;
265  CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim
266  << " has already been split in "
267  << name_;
268  CHECK(size > 0) << "Invalid split size " << size;
269  std::ostringstream new_layout;
270  for (size_t i = 0; i <= this->ndim(); ++i) {
271  if (i == target_pos) {
272  new_layout << size << Layout::to_subdim(dim);
273  }
274  if (i == this->ndim()) break;
275  new_layout << this->at(i);
276  }
277  Layout x(new_layout.str());
278  return x;
279  }
280 
281  using iterator = std::vector<LayoutDim>::const_iterator;
282  using reverse_iterator = std::vector<LayoutDim>::const_reverse_iterator;
283 
285  inline iterator begin() const {
286  return layout_simplified_.begin();
287  }
289  inline iterator end() const {
290  return layout_simplified_.end();
291  }
293  inline reverse_iterator rbegin() const {
294  return layout_simplified_.rbegin();
295  }
297  inline reverse_iterator rend() const {
298  return layout_simplified_.rend();
299  }
300 
302  inline size_t ndim() const {
303  return layout_simplified_.size();
304  }
305 
313  inline std::string at(size_t i) const {
314  CHECK_LT(i, this->ndim()) << "position " << i
315  << " exceeds ndim=" << this->ndim();
316  std::ostringstream repr;
317  if (is_subdim(layout_simplified_[i])) {
318  auto factor = subsizeof(layout_simplified_[i]);
319  CHECK_GT(factor, 0);
320  repr << factor;
321  }
322  repr << layout_simplified_[i];
323  return repr.str();
324  }
325 
333  inline int32_t indexof(LayoutDim dim) const {
334  if (!this->defined()) return -1;
335  else if (is_superdim(dim)) return superdim_pos_[dim - 'A'];
336  else if (is_subdim(dim)) return subdim_pos_[dim - 'a'];
337  return -1;
338  }
339 
346  inline int64_t subsizeof(LayoutDim dim) const {
347  CHECK(is_superdim(dim) || is_subdim(dim)) << "Invalid dim " << dim;
348  if (!this->defined() || !this->contains(to_subdim(dim))) {
349  return -1;
350  }
351  int idx = to_subdim(dim) - 'a';
352  return subdim_size_[idx];
353  }
354 
360  inline bool contains(LayoutDim dim) const {
361  if (is_superdim(dim)) {
362  return superdim_pos_[dim-'A'] >= 0;
363  } else if (is_subdim(dim)) {
364  return subdim_pos_[dim-'a'] >= 0;
365  }
366  return false;
367  }
368 
369  inline LayoutDim operator[](size_t i) const {
370  return layout_simplified_[i];
371  }
372 
374  inline bool defined() const {
375  return name_ != "__undef__";
376  }
377 
379  inline const std::string& name() const {
380  return name_;
381  }
382 
387  inline void Save(dmlc::JSONWriter* writer) const {
388  writer->Write(name_);
389  }
390 
395  inline void Load(dmlc::JSONReader* reader) {
396  std::string tmp;
397  reader->Read(&tmp);
398  this->parse(tmp);
399  }
400 
407  friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
408  os << l.name_;
409  return os;
410  }
411 
412  private:
413  static const uint32_t kUniqueDim = 26;
414 
415  std::string name_;
416  int32_t superdim_pos_[kUniqueDim];
417  int32_t subdim_pos_[kUniqueDim];
418  int64_t subdim_size_[kUniqueDim];
419  std::vector<LayoutDim> layout_simplified_;
420 
421  void parse(const std::string& layout) {
422  name_ = layout;
423  std::fill_n(superdim_pos_, kUniqueDim, -1);
424  std::fill_n(subdim_pos_, kUniqueDim, -1);
425  std::fill_n(subdim_size_, kUniqueDim, -1);
426  layout_simplified_.clear();
427 
428  if (layout == "__undef__") return;
429 
430  int32_t factor = 0;
431  uint32_t curr = 0;
432  for (size_t i = 0; i < layout.size(); ++i) {
433  const LayoutDim c = layout.at(i);
434  if (is_superdim(c)) {
435  int pos = c - 'A';
436  CHECK_EQ(factor, 0) << "Invalid layout " << layout
437  << ": invalid factor size " << factor
438  << " before dimension " << c;
439  CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout
440  << ": duplicate dimension " << c;
441  superdim_pos_[pos] = curr++;
442  layout_simplified_.push_back(c);
443  } else if (is_subdim(c)) {
444  int pos = c - 'a';
445  CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size "
446  << factor << " for dimension " << c;
447  CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout
448  << ": duplicate dimension " << c;
449  CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout
450  << ": duplicate dimension " << c;
451  subdim_pos_[pos] = curr++;
452  subdim_size_[pos] = factor;
453  layout_simplified_.push_back(c);
454  factor = 0;
455  } else if (c >= '0' && c <= '9') {
456  CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number.";
457  factor = factor * 10 + c - '0';
458  } else {
459  LOG(FATAL) << "Invalid layout " << layout;
460  }
461  }
462  CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout;
463  for (LayoutDim dim : layout_simplified_) {
464  CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0)
465  << "Invalid layout " << layout << ": missing axis "
466  << static_cast<char>(dim - 'a' + 'A');
467  }
468  }
469 };
470 
471 } // namespace nnvm
472 
473 #endif // NNVM_LAYOUT_H_
void swap(Layout &other)
Swap current object with other.
Definition: layout.h:187
Definition: base.h:35
Layout reverse() const
Definition: layout.h:239
Layout(Layout &&src)
move constructor from Layout
Definition: layout.h:71
Layout(const std::string &layout)
construct from a string.
Definition: layout.h:57
Layout()
default constructor
Definition: layout.h:47
bool operator==(const Layout &s) const
Definition: layout.h:105
static LayoutDim to_subdim(LayoutDim dim)
Convert a given dimension to sub-dimension.
Definition: layout.h:167
Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const
Split dim by size and put the sub-dimension to position target_pos.
Definition: layout.h:260
iterator begin() const
Definition: layout.h:285
void Load(dmlc::JSONReader *reader)
Load layout from JSON.
Definition: layout.h:395
static bool is_superdim(LayoutDim dim)
Check whether a given dimension is a super-dimension.
Definition: layout.h:137
void Save(dmlc::JSONWriter *writer) const
Write layout in JSON format.
Definition: layout.h:387
int32_t indexof(LayoutDim dim) const
return the index of the input dimension. If it is not found in the layout or the layout is undefined...
Definition: layout.h:333
Layout operator+(const Layout &other) const
Append another layout to the current one.
Definition: layout.h:121
Layout(const Layout &s)
copy constructor from another layout
Definition: layout.h:64
static const Layout & Undef()
Return an undefined layout.
Definition: layout.h:178
std::vector< LayoutDim >::const_iterator iterator
Definition: layout.h:281
Lightweight JSON Reader to read any STL compositions and structs. The user needs to know the schema of...
Definition: json.h:44
reverse_iterator rend() const
Definition: layout.h:297
std::vector< LayoutDim >::const_reverse_iterator reverse_iterator
Definition: layout.h:282
friend std::ostream & operator<<(std::ostream &os, const Layout &l)
allow output string of layout to ostream
Definition: layout.h:407
void Write(const ValueType &value)
Write value to json.
Layout & operator=(Layout &&src)
assignment from rvalue of another layout.
Definition: layout.h:88
LayoutDim operator[](size_t i) const
Definition: layout.h:369
Layout & operator=(const std::string &src)
assignment from string.
Definition: layout.h:97
bool defined() const
Definition: layout.h:374
std::string at(size_t i) const
The description of the i-th dimension. If it is a sub-dimension, the size will be returned as well...
Definition: layout.h:313
Layout & operator=(const Layout &src)
assignment from another layout.
Definition: layout.h:79
void Read(ValueType *out_value)
Read next ValueType.
iterator end() const
Definition: layout.h:289
Layout sublayout(size_t pos, size_t len) const
Returns a sublayout which is the portion of the object that starts at dimension pos and spans len dim...
Definition: layout.h:222
static bool is_subdim(LayoutDim dim)
Check whether a given dimension is a sub-dimension.
Definition: layout.h:146
size_t ndim() const
Definition: layout.h:302
bool contains(LayoutDim dim) const
Whether the layout contains a dimension.
Definition: layout.h:360
char LayoutDim
Definition: layout.h:44
bool convertible(const Layout &dst) const
Two layouts are convertible only if they have the same set of super-dimensions, e.g., NCHW...
Definition: layout.h:203
Definition: layout.h:42
bool operator!=(const Layout &s) const
Definition: layout.h:112
reverse_iterator rbegin() const
Definition: layout.h:293
int64_t subsizeof(LayoutDim dim) const
Definition: layout.h:346
static LayoutDim to_superdim(LayoutDim dim)
Convert a given dimension to super-dimension.
Definition: layout.h:155
Provide lightweight util to do parameter setup and checking.
const std::string & name() const
Definition: layout.h:379
Lightweight json to write any STL compositions.
Definition: json.h:189