mxnet
implicit_gemm.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
27 #define MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
28 
29 #include "../extension.h"
30 #include "../packet-inl.h"
31 
32 namespace mshadow {
33 namespace expr {
40 template<typename LhsExp, typename RhsExp, typename DType>
42  public Exp<ImplicitGEMMExp<LhsExp, RhsExp, DType>,
43  DType, type::kChainer> {
45  const LhsExp &lhs_;
47  const RhsExp &rhs_;
53  ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs)
54  : lhs_(lhs), rhs_(rhs) {
57  this->shape_ = mshadow::Shape2(slhs[0], srhs[1]);
58  prod_size_ = slhs[1];
59  }
60 };
61 
62 
63 template<typename LhsExp, typename RhsExp, typename DType, int e1, int e2>
66  const Exp<RhsExp, DType, e2> &rhs) {
68  ::Error_Expression_Does_Not_Meet_Dimension_Req();
70 }
71 
72 //----------------------
73 // Execution plan
74 //----------------------
75 template<typename LhsExp, typename RhsExp, typename DType>
76 struct Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType> {
77  public:
79  : lhs_(MakePlan(e.lhs_)),
80  rhs_(MakePlan(e.rhs_)),
82  prod_size_lower_align_(packet::LowerAlign<DType, MSHADOW_DEFAULT_PACKET>(e.prod_size_)) {
83  }
84 
85  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
86  typedef packet::Packet<DType> Packet;
87  Packet sum = Packet::Fill(0);
88 
89  const size_t packetSize = Packet::size;
90  DType lhs_temp[packetSize], rhs_temp[packetSize];
91 
92  for (index_t i = 0; i < prod_size_lower_align_; i += packetSize) {
93  // unroll
94  for (index_t j = 0; j < packetSize; ++j) {
95  lhs_temp[j] = lhs_.Eval(y, i + j);
96  }
97  for (index_t j = 0; j < packetSize; ++j) {
98  rhs_temp[j] = rhs_.Eval(i + j, x);
99  }
100  sum = sum + Packet::LoadUnAligned(lhs_temp) * Packet::LoadUnAligned(rhs_temp);
101  }
102  DType ret_result = sum.Sum();
103 
104  for (index_t i = prod_size_lower_align_; i < prod_size_; ++i) {
105  ret_result += lhs_.Eval(y, i) * rhs_.Eval(i, x);
106  }
107  return ret_result;
108  }
109 
110  private:
113  const index_t prod_size_;
114  const index_t prod_size_lower_align_;
115 };
116 
117 template<typename LhsExp, typename RhsExp, typename DType>
121 }
122 
123 
124 template<int dim, typename LhsExp, typename RhsExp, typename DType>
125 struct ShapeCheck<dim, ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
126  inline static Shape<dim>
128  CHECK(dim == 2)
129  << "ImplicitGEMMExp only support 2 dimension";
132  CHECK_EQ(shape1[1], shape2[0])
133  << "implicit_dot The matrix shape do not match";
134  return t.shape_;
135  }
136 };
137 
138 template<typename LhsExp, typename RhsExp, typename DType>
139 struct ExpInfo<ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
140  static const int kDim = 2;
142 };
143 
144 } // namespace expr
145 } // namespace mshadow
146 #endif // MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
147 
ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs)
constructor
Definition: implicit_gemm.h:53
ImplicitGEMMExp< LhsExp, RhsExp, DType > implicit_dot(const Exp< LhsExp, DType, e1 > &lhs, const Exp< RhsExp, DType, e2 > &rhs)
Definition: implicit_gemm.h:65
static Shape< dim > Check(const ImplicitGEMMExp< LhsExp, RhsExp, DType > &t)
Definition: implicit_gemm.h:127
index_t prod_size_
size of the inner (shared) dimension of the product, i.e. the second dimension of lhs
Definition: implicit_gemm.h:49
Definition: expr_engine-inl.h:59
used to help static type check
Definition: expr_engine-inl.h:331
const RhsExp & rhs_
rhs operand
Definition: implicit_gemm.h:47
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:223
Shape< 2 > shape_
the shape of this expression
Definition: implicit_gemm.h:51
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:263
int32_t index_t
type that will be used for index
Definition: base.h:336
runtime shape-checking template: gets the shape of an expression and reports an error if the shapes mismatch ...
Definition: expr_engine-inl.h:365
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two-dimensional shape; stride will equal s0
Definition: tensor.h:217
index_t LowerAlign(index_t size)
get lower bound of aligned index of size
Definition: packet-inl.h:143
const LhsExp & lhs_
lhs operand
Definition: implicit_gemm.h:45
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const SubType & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
overloaded + operator between half_t and bf16_t
Definition: base.h:327
Plan(const ImplicitGEMMExp< LhsExp, RhsExp, DType > &e)
Definition: implicit_gemm.h:78
Matrix multiplication.
Definition: implicit_gemm.h:41
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: implicit_gemm.h:85
#define MSHADOW_DEFAULT_PACKET
Definition: packet-inl.h:48
Generic packet type.
Definition: packet-inl.h:60