mxnet
layout.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
30 #ifndef NNVM_LAYOUT_H_
31 #define NNVM_LAYOUT_H_
32 
33 #include <dmlc/parameter.h>
34 
35 #include <algorithm>
36 #include <sstream>
37 #include <string>
38 #include <utility>
39 #include <vector>
40 
41 namespace nnvm {
42 
43 class Layout {
44  public:
45  using LayoutDim = char;
46 
48  Layout() : name_("__undef__") {} // NOLINT(*)
49 
58  inline Layout(const std::string& layout) { // NOLINT(*)
59  parse(layout);
60  }
65  inline Layout(const Layout& s) { // NOLINT(*)
66  this->parse(s.name_);
67  }
72  inline Layout(Layout&& src) { // NOLINT(*)
73  this->swap(src);
74  }
80  inline Layout& operator=(const Layout& src) {
81  this->parse(src.name_);
82  return *this;
83  }
89  inline Layout& operator=(Layout&& src) {
90  Layout(std::move(src)).swap(*this); // NOLINT(*)
91  return *this;
92  }
98  inline Layout& operator=(const std::string& src) {
99  this->parse(src);
100  return *this;
101  }
106  inline bool operator==(const Layout& s) const { return name_ == s.name_; }
111  inline bool operator!=(const Layout& s) const { return !(*this == s); }
112 
118  inline Layout operator+(const Layout& other) const {
119  if (!this->defined() && !other.defined()) {
120  return Layout::Undef();
121  } else if (!this->defined()) {
122  return other;
123  } else if (!other.defined()) {
124  return *this;
125  }
126  return Layout(this->name_ + other.name_);
127  }
128 
134  static inline bool is_superdim(LayoutDim dim) { return dim >= 'A' && dim <= 'Z'; }
135 
141  static inline bool is_subdim(LayoutDim dim) { return dim >= 'a' && dim <= 'z'; }
142 
148  static inline LayoutDim to_superdim(LayoutDim dim) {
149  if (is_subdim(dim)) {
150  return dim - 'a' + 'A';
151  }
152  return dim;
153  }
154 
160  static inline LayoutDim to_subdim(LayoutDim dim) {
161  if (is_superdim(dim)) {
162  return dim - 'A' + 'a';
163  }
164  return dim;
165  }
166 
171  static inline const Layout& Undef() {
172  static Layout undef;
173  return undef;
174  }
175 
180  inline void swap(Layout& other) { // NOLINT(*)
181  std::swap(name_, other.name_);
182  std::swap(superdim_pos_, other.superdim_pos_);
183  std::swap(subdim_pos_, other.subdim_pos_);
184  std::swap(subdim_size_, other.subdim_size_);
185  std::swap(layout_simplified_, other.layout_simplified_);
186  }
187 
196  inline bool convertible(const Layout& dst) const {
197  if (!this->defined() || !dst.defined()) return false;
198  for (size_t i = 0; i < kUniqueDim; ++i) {
199  if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) ||
200  (superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) {
201  return false;
202  }
203  }
204  return true;
205  }
206 
215  inline Layout sublayout(size_t pos, size_t len) const {
216  if (pos > ndim()) return Layout::Undef();
217  if (pos + len > ndim()) len = ndim() - pos;
218  if (len == 0) return Layout::Undef();
219  std::ostringstream new_layout;
220  for (size_t i = pos; i < pos + len; ++i) {
221  if (is_subdim(layout_simplified_[i])) {
222  auto block_size = this->subsizeof(layout_simplified_[i]);
223  CHECK_GT(block_size, 0);
224  new_layout << block_size;
225  }
226  new_layout << layout_simplified_[i];
227  }
228  return Layout(new_layout.str());
229  }
230 
232  inline Layout reverse() const {
233  if (!this->defined()) return Layout::Undef();
234  std::ostringstream new_layout;
235  for (int64_t i = this->ndim() - 1; i >= 0; --i) {
236  if (is_subdim(layout_simplified_[i])) {
237  auto block_size = this->subsizeof(layout_simplified_[i]);
238  CHECK_GT(block_size, 0);
239  new_layout << block_size;
240  }
241  new_layout << layout_simplified_[i];
242  }
243  return Layout(new_layout.str());
244  }
245 
253  inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const {
254  CHECK(target_pos <= this->ndim())
255  << "Invalid split position " << target_pos << " for layout " << name_;
256  CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim;
257  CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_;
258  CHECK(!this->contains(to_subdim(dim)))
259  << "Dimension " << dim << " has already been split in " << name_;
260  CHECK(size > 0) << "Invalid split size " << size;
261  std::ostringstream new_layout;
262  for (size_t i = 0; i <= this->ndim(); ++i) {
263  if (i == target_pos) {
264  new_layout << size << Layout::to_subdim(dim);
265  }
266  if (i == this->ndim()) break;
267  new_layout << this->at(i);
268  }
269  Layout x(new_layout.str());
270  return x;
271  }
272 
273  using iterator = std::vector<LayoutDim>::const_iterator;
274  using reverse_iterator = std::vector<LayoutDim>::const_reverse_iterator;
275 
277  inline iterator begin() const { return layout_simplified_.begin(); }
279  inline iterator end() const { return layout_simplified_.end(); }
281  inline reverse_iterator rbegin() const { return layout_simplified_.rbegin(); }
283  inline reverse_iterator rend() const { return layout_simplified_.rend(); }
284 
286  inline size_t ndim() const { return layout_simplified_.size(); }
287 
295  inline std::string at(size_t i) const {
296  CHECK_LT(i, this->ndim()) << "position " << i << " exceeds ndim=" << this->ndim();
297  std::ostringstream repr;
298  if (is_subdim(layout_simplified_[i])) {
299  auto factor = subsizeof(layout_simplified_[i]);
300  CHECK_GT(factor, 0);
301  repr << factor;
302  }
303  repr << layout_simplified_[i];
304  return repr.str();
305  }
306 
314  inline int32_t indexof(LayoutDim dim) const {
315  if (!this->defined())
316  return -1;
317  else if (is_superdim(dim))
318  return superdim_pos_[dim - 'A'];
319  else if (is_subdim(dim))
320  return subdim_pos_[dim - 'a'];
321  return -1;
322  }
323 
330  inline int64_t subsizeof(LayoutDim dim) const {
331  CHECK(is_superdim(dim) || is_subdim(dim)) << "Invalid dim " << dim;
332  if (!this->defined() || !this->contains(to_subdim(dim))) {
333  return -1;
334  }
335  int idx = to_subdim(dim) - 'a';
336  return subdim_size_[idx];
337  }
338 
344  inline bool contains(LayoutDim dim) const {
345  if (is_superdim(dim)) {
346  return superdim_pos_[dim - 'A'] >= 0;
347  } else if (is_subdim(dim)) {
348  return subdim_pos_[dim - 'a'] >= 0;
349  }
350  return false;
351  }
352 
353  inline LayoutDim operator[](size_t i) const { return layout_simplified_[i]; }
354 
356  inline bool defined() const { return name_ != "__undef__"; }
357 
359  inline const std::string& name() const { return name_; }
360 
365  inline void Save(dmlc::JSONWriter* writer) const { writer->Write(name_); }
366 
371  inline void Load(dmlc::JSONReader* reader) {
372  std::string tmp;
373  reader->Read(&tmp);
374  this->parse(tmp);
375  }
376 
383  friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
384  os << l.name_;
385  return os;
386  }
387 
388  private:
389  static const uint32_t kUniqueDim = 26;
390 
391  std::string name_;
392  int32_t superdim_pos_[kUniqueDim];
393  int32_t subdim_pos_[kUniqueDim];
394  int64_t subdim_size_[kUniqueDim];
395  std::vector<LayoutDim> layout_simplified_;
396 
397  void parse(const std::string& layout) {
398  name_ = layout;
399  std::fill_n(superdim_pos_, kUniqueDim, -1);
400  std::fill_n(subdim_pos_, kUniqueDim, -1);
401  std::fill_n(subdim_size_, kUniqueDim, -1);
402  layout_simplified_.clear();
403 
404  if (layout == "__undef__") return;
405 
406  int32_t factor = 0;
407  uint32_t curr = 0;
408  for (size_t i = 0; i < layout.size(); ++i) {
409  const LayoutDim c = layout.at(i);
410  if (is_superdim(c)) {
411  int pos = c - 'A';
412  CHECK_EQ(factor, 0) << "Invalid layout " << layout << ": invalid factor size " << factor
413  << " before dimension " << c;
414  CHECK_EQ(superdim_pos_[pos], -1)
415  << "Invalid layout " << layout << ": duplicate dimension " << c;
416  superdim_pos_[pos] = curr++;
417  layout_simplified_.push_back(c);
418  } else if (is_subdim(c)) {
419  int pos = c - 'a';
420  CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " << factor
421  << " for dimension " << c;
422  CHECK_EQ(subdim_pos_[pos], -1)
423  << "Invalid layout " << layout << ": duplicate dimension " << c;
424  CHECK_EQ(subdim_size_[pos], -1)
425  << "Invalid layout " << layout << ": duplicate dimension " << c;
426  subdim_pos_[pos] = curr++;
427  subdim_size_[pos] = factor;
428  layout_simplified_.push_back(c);
429  factor = 0;
430  } else if (c >= '0' && c <= '9') {
431  CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number.";
432  factor = factor * 10 + c - '0';
433  } else {
434  LOG(FATAL) << "Invalid layout " << layout;
435  }
436  }
437  CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout;
438  for (LayoutDim dim : layout_simplified_) {
439  CHECK(is_superdim(dim) || superdim_pos_[dim - 'a'] >= 0)
440  << "Invalid layout " << layout << ": missing axis " << static_cast<char>(dim - 'a' + 'A');
441  }
442  }
443 };
444 
445 } // namespace nnvm
446 
447 #endif // NNVM_LAYOUT_H_
dmlc::JSONReader::Read
void Read(ValueType *out_value)
Read next ValueType.
nnvm::Layout::is_subdim
static bool is_subdim(LayoutDim dim)
Check whether a given dimension is a sub-dimension.
Definition: layout.h:141
nnvm::Layout::operator<<
friend std::ostream & operator<<(std::ostream &os, const Layout &l)
allow output string of layout to ostream
Definition: layout.h:383
nnvm::Layout::operator!=
bool operator!=(const Layout &s) const
Definition: layout.h:111
nnvm::Layout::convertible
bool convertible(const Layout &dst) const
Two layouts are convertible only if they have the same set of super-dimensions. e.g., NCHW,...
Definition: layout.h:196
nnvm::Layout::LayoutDim
char LayoutDim
Definition: layout.h:45
nnvm::Layout::end
iterator end() const
Definition: layout.h:279
nnvm::Layout::operator=
Layout & operator=(Layout &&src)
assignment from rvalue of another layout.
Definition: layout.h:89
dmlc::JSONWriter::Write
void Write(const ValueType &value)
Write value to json.
nnvm::Layout::at
std::string at(size_t i) const
The description of the i-th dimension. If it is a sub-dimension, the size will be returned as well,...
Definition: layout.h:295
nnvm::Layout::operator=
Layout & operator=(const std::string &src)
assignment from string.
Definition: layout.h:98
nnvm::Layout::ndim
size_t ndim() const
Definition: layout.h:286
parameter.h
Provide lightweight util to do parameter setup and checking.
nnvm::Layout::sublayout
Layout sublayout(size_t pos, size_t len) const
Returns a sublayout which is the portion of the object that starts at dimension pos and spans len dim...
Definition: layout.h:215
nnvm::Layout::split
Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const
Split dim by size and put the sub-dimension to position target_pos.
Definition: layout.h:253
nnvm::Layout::defined
bool defined() const
Definition: layout.h:356
nnvm::Layout::rend
reverse_iterator rend() const
Definition: layout.h:283
nnvm::Layout::name
const std::string & name() const
Definition: layout.h:359
nnvm::Layout
Definition: layout.h:43
nnvm::Layout::reverse
Layout reverse() const
Definition: layout.h:232
nnvm::Layout::to_superdim
static LayoutDim to_superdim(LayoutDim dim)
Convert a given dimension to super-dimension.
Definition: layout.h:148
nnvm::Layout::operator[]
LayoutDim operator[](size_t i) const
Definition: layout.h:353
nnvm::Layout::Save
void Save(dmlc::JSONWriter *writer) const
Write layout in JSON format.
Definition: layout.h:365
nnvm::Layout::swap
void swap(Layout &other)
Swap current object with other.
Definition: layout.h:180
dmlc::JSONWriter
Lightweight json to write any STL compositions.
Definition: json.h:190
nnvm::Layout::Layout
Layout(Layout &&src)
move constructor from Layout
Definition: layout.h:72
nnvm::Layout::Load
void Load(dmlc::JSONReader *reader)
Load layout from JSON.
Definition: layout.h:371
nnvm::Layout::Layout
Layout()
default constructor
Definition: layout.h:48
nnvm::Layout::Layout
Layout(const Layout &s)
copy constructor from another layout
Definition: layout.h:65
nnvm::Layout::contains
bool contains(LayoutDim dim) const
Whether the layout contains a dimension.
Definition: layout.h:344
nnvm::Layout::to_subdim
static LayoutDim to_subdim(LayoutDim dim)
Convert a given dimension to sub-dimension.
Definition: layout.h:160
nnvm::Layout::rbegin
reverse_iterator rbegin() const
Definition: layout.h:281
nnvm::Layout::subsizeof
int64_t subsizeof(LayoutDim dim) const
Definition: layout.h:330
nnvm::Layout::Layout
Layout(const std::string &layout)
construct from a string.
Definition: layout.h:58
nnvm::Layout::reverse_iterator
std::vector< LayoutDim >::const_reverse_iterator reverse_iterator
Definition: layout.h:274
nnvm::Layout::operator=
Layout & operator=(const Layout &src)
assignment from another layout.
Definition: layout.h:80
nnvm::Layout::operator==
bool operator==(const Layout &s) const
Definition: layout.h:106
nnvm::Layout::Undef
static const Layout & Undef()
Return an undefined layout.
Definition: layout.h:171
nnvm::Layout::begin
iterator begin() const
Definition: layout.h:277
dmlc::JSONReader
Lightweight JSON Reader to read any STL compositions and structs. The user needs to know the schema of...
Definition: json.h:44
nnvm::Layout::iterator
std::vector< LayoutDim >::const_iterator iterator
Definition: layout.h:273
nnvm::Layout::operator+
Layout operator+(const Layout &other) const
Append another layout to the current layout.
Definition: layout.h:118
nnvm::Layout::indexof
int32_t indexof(LayoutDim dim) const
Return the index of the input dimension. If it is not found in the layout or the layout is undefined,...
Definition: layout.h:314
nnvm::Layout::is_superdim
static bool is_superdim(LayoutDim dim)
Check whether a given dimension is a super-dimension.
Definition: layout.h:134
nnvm
Definition: base.h:35