30 #ifndef NNVM_LAYOUT_H_
31 #define NNVM_LAYOUT_H_
58 inline Layout(
const std::string& layout) {
81 this->parse(src.name_);
90 Layout(std::move(src)).swap(*
this);
126 return Layout(this->name_ + other.name_);
150 return dim -
'a' +
'A';
162 return dim -
'A' +
'a';
181 std::swap(name_, other.name_);
182 std::swap(superdim_pos_, other.superdim_pos_);
183 std::swap(subdim_pos_, other.subdim_pos_);
184 std::swap(subdim_size_, other.subdim_size_);
185 std::swap(layout_simplified_, other.layout_simplified_);
198 for (
size_t i = 0; i < kUniqueDim; ++i) {
199 if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) ||
200 (superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) {
217 if (pos + len >
ndim()) len =
ndim() - pos;
219 std::ostringstream new_layout;
220 for (
size_t i = pos; i < pos + len; ++i) {
222 auto block_size = this->
subsizeof(layout_simplified_[i]);
223 CHECK_GT(block_size, 0);
224 new_layout << block_size;
226 new_layout << layout_simplified_[i];
228 return Layout(new_layout.str());
234 std::ostringstream new_layout;
235 for (int64_t i = this->
ndim() - 1; i >= 0; --i) {
237 auto block_size = this->
subsizeof(layout_simplified_[i]);
238 CHECK_GT(block_size, 0);
239 new_layout << block_size;
241 new_layout << layout_simplified_[i];
243 return Layout(new_layout.str());
254 CHECK(target_pos <= this->
ndim())
255 <<
"Invalid split position " << target_pos <<
" for layout " << name_;
256 CHECK(
is_superdim(dim)) <<
"Cannot split a sub-dimension " << dim;
257 CHECK(this->
contains(dim)) <<
"Axis " << dim <<
" does not exist in " << name_;
259 <<
"Dimension " << dim <<
" has already been split in " << name_;
260 CHECK(size > 0) <<
"Invalid split size " << size;
261 std::ostringstream new_layout;
262 for (
size_t i = 0; i <= this->
ndim(); ++i) {
263 if (i == target_pos) {
266 if (i == this->
ndim())
break;
267 new_layout << this->
at(i);
269 Layout x(new_layout.str());
273 using iterator = std::vector<LayoutDim>::const_iterator;
286 inline size_t ndim()
const {
return layout_simplified_.size(); }
295 inline std::string
at(
size_t i)
const {
296 CHECK_LT(i, this->
ndim()) <<
"position " << i <<
" exceeds ndim=" << this->
ndim();
297 std::ostringstream repr;
299 auto factor =
subsizeof(layout_simplified_[i]);
303 repr << layout_simplified_[i];
318 return superdim_pos_[dim -
'A'];
320 return subdim_pos_[dim -
'a'];
336 return subdim_size_[idx];
346 return superdim_pos_[dim -
'A'] >= 0;
348 return subdim_pos_[dim -
'a'] >= 0;
356 inline bool defined()
const {
return name_ !=
"__undef__"; }
359 inline const std::string&
name()
const {
return name_; }
// Number of per-letter slots in the lookup tables below. The accessors index
// them with `dim - 'A'` (super-dimensions) and `dim - 'a'` (sub-dimensions),
// so this gives one slot per letter of the alphabet.
389 static const uint32_t kUniqueDim = 26;
// Position of each super-dimension within layout_simplified_, indexed by
// `dim - 'A'`; parse() resets every slot to -1, meaning "absent".
392 int32_t superdim_pos_[kUniqueDim];
// Position of each sub-dimension within layout_simplified_, indexed by
// `dim - 'a'`; -1 means "absent".
393 int32_t subdim_pos_[kUniqueDim];
// Split factor of each sub-dimension, i.e. the number parsed from the digits
// preceding it in the layout string; -1 means "absent".
// NOTE(review): presumably indexed like subdim_pos_ (`dim - 'a'`) — confirm
// against the missing index computation in parse().
394 int64_t subdim_size_[kUniqueDim];
// Dimension letters (both kinds) in the order parse() encountered them; digit
// characters are folded into split factors and never stored here.
395 std::vector<LayoutDim> layout_simplified_;
397 void parse(
const std::string& layout) {
399 std::fill_n(superdim_pos_, kUniqueDim, -1);
400 std::fill_n(subdim_pos_, kUniqueDim, -1);
401 std::fill_n(subdim_size_, kUniqueDim, -1);
402 layout_simplified_.clear();
404 if (layout ==
"__undef__")
return;
408 for (
size_t i = 0; i < layout.size(); ++i) {
412 CHECK_EQ(factor, 0) <<
"Invalid layout " << layout <<
": invalid factor size " << factor
413 <<
" before dimension " << c;
414 CHECK_EQ(superdim_pos_[pos], -1)
415 <<
"Invalid layout " << layout <<
": duplicate dimension " << c;
416 superdim_pos_[pos] = curr++;
417 layout_simplified_.push_back(c);
420 CHECK_GT(factor, 0) <<
"Invalid layout " << layout <<
": invalid factor size " << factor
421 <<
" for dimension " << c;
422 CHECK_EQ(subdim_pos_[pos], -1)
423 <<
"Invalid layout " << layout <<
": duplicate dimension " << c;
424 CHECK_EQ(subdim_size_[pos], -1)
425 <<
"Invalid layout " << layout <<
": duplicate dimension " << c;
426 subdim_pos_[pos] = curr++;
427 subdim_size_[pos] = factor;
428 layout_simplified_.push_back(c);
430 }
else if (c >=
'0' && c <=
'9') {
431 CHECK(factor >= 0) <<
"Invalid layout " << layout <<
": _ is adjacent to a number.";
432 factor = factor * 10 + c -
'0';
434 LOG(FATAL) <<
"Invalid layout " << layout;
437 CHECK(!layout_simplified_.empty()) <<
"Invalid layout " << layout;
438 for (
LayoutDim dim : layout_simplified_) {
439 CHECK(
is_superdim(dim) || superdim_pos_[dim -
'a'] >= 0)
440 <<
"Invalid layout " << layout <<
": missing axis " <<
static_cast<char>(dim -
'a' +
'A');
447 #endif // NNVM_LAYOUT_H_