mxnet
implicit_gemm.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
8 #define MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
9 
10 #include "../extension.h"
11 #include "../packet-inl.h"
12 
13 namespace mshadow {
14 namespace expr {
21 template<typename LhsExp, typename RhsExp, typename DType>
23  public Exp<ImplicitGEMMExp<LhsExp, RhsExp, DType>,
24  DType, type::kChainer> {
26  const LhsExp &lhs_;
28  const RhsExp &rhs_;
34  ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs)
35  : lhs_(lhs), rhs_(rhs) {
38  this->shape_ = mshadow::Shape2(slhs[0], srhs[1]);
39  prod_size_ = slhs[1];
40  }
41 };
42 
43 
44 template<typename LhsExp, typename RhsExp, typename DType, int e1, int e2>
47  const Exp<RhsExp, DType, e2> &rhs) {
49  ::Error_Expression_Does_Not_Meet_Dimension_Req();
51 }
52 
53 //----------------------
54 // Execution plan
55 //----------------------
56 template<typename LhsExp, typename RhsExp, typename DType>
57 struct Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType> {
58  public:
60  : lhs_(MakePlan(e.lhs_)),
61  rhs_(MakePlan(e.rhs_)),
63  prod_size_lower_align_(packet::LowerAlign<DType, MSHADOW_DEFAULT_PACKET>(e.prod_size_)) {
64  }
65 
66  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {  // Evaluates element (y, x) of the product: accumulates lhs_.Eval(y, i) * rhs_.Eval(i, x) over the product dimension i.
67  typedef packet::Packet<DType> Packet;
68  Packet sum = Packet::Fill(0);  // packet accumulator, filled with 0 before the main loop
69  // Scratch buffers that hold one packet's worth of scalars gathered from each operand.
70  const size_t packetSize = Packet::size;
71  DType lhs_temp[packetSize], rhs_temp[packetSize];
72  // Main loop: walk the product dimension packetSize indices at a time, stopping at prod_size_lower_align_.
73  for (index_t i = 0; i < prod_size_lower_align_; i += packetSize) {
74  // Gather scalars one by one through the operand plans: row y of lhs, then column x of rhs.
75  for (index_t j = 0; j < packetSize; ++j) {
76  lhs_temp[j] = lhs_.Eval(y, i + j);
77  }
78  for (index_t j = 0; j < packetSize; ++j) {
79  rhs_temp[j] = rhs_.Eval(i + j, x);
80  }
81  sum = sum + Packet::LoadUnAligned(lhs_temp) * Packet::LoadUnAligned(rhs_temp);  // multiply-accumulate the two gathered buffers as whole packets
82  }
83  DType ret_result = sum.Sum();  // reduce the packet accumulator to a single scalar (presumably the sum of its lanes -- confirm in packet-inl.h)
84  // Tail loop: scalar multiply-accumulate for indices in [prod_size_lower_align_, prod_size_); assumes prod_size_lower_align_ <= prod_size_ -- TODO confirm against LowerAlign.
85  for (index_t i = prod_size_lower_align_; i < prod_size_; ++i) {
86  ret_result += lhs_.Eval(y, i) * rhs_.Eval(i, x);
87  }
88  return ret_result;
89  }
90 
91  private:
94  const index_t prod_size_;
95  const index_t prod_size_lower_align_;
96 };
97 
98 template<typename LhsExp, typename RhsExp, typename DType>
102 }
103 
104 
105 template<int dim, typename LhsExp, typename RhsExp, typename DType>
106 struct ShapeCheck<dim, ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
107  inline static Shape<dim>
109  CHECK(dim == 2)
110  << "ImplicitGEMMExp only support 2 dimension";
113  CHECK_EQ(shape1[1], shape2[0])
114  << "implicit_dot The matrix shape do not match";
115  return t.shape_;
116  }
117 };
118 
119 template<typename LhsExp, typename RhsExp, typename DType>
120 struct ExpInfo<ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
121  static const int kDim = 2;
123 };
124 
125 } // namespace expr
126 } // namespace mshadow
127 #endif // MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
128 
ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs)
constructor
Definition: implicit_gemm.h:34
ImplicitGEMMExp< LhsExp, RhsExp, DType > implicit_dot(const Exp< LhsExp, DType, e1 > &lhs, const Exp< RhsExp, DType, e2 > &rhs)
Definition: implicit_gemm.h:46
static Shape< dim > Check(const ImplicitGEMMExp< LhsExp, RhsExp, DType > &t)
Definition: implicit_gemm.h:108
index_t prod_size_
internal product size (the shared inner dimension of lhs and rhs)
Definition: implicit_gemm.h:30
Definition: expr_engine-inl.h:40
used to help static type check
Definition: expr_engine-inl.h:312
const RhsExp & rhs_
rhs operand
Definition: implicit_gemm.h:28
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:204
Shape< 2 > shape_
the shape of this expression
Definition: implicit_gemm.h:32
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
int32_t index_t
type that will be used for index
Definition: base.h:291
runtime shape-checking template: gets the shape of an expression and reports an error on shape mismatch ...
Definition: expr_engine-inl.h:346
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two-dimensional shape; stride will equal s0
Definition: tensor.h:198
index_t LowerAlign(index_t size)
get lower bound of aligned index of size
Definition: packet-inl.h:124
const LhsExp & lhs_
lhs operand
Definition: implicit_gemm.h:26
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
namespace for mshadow
Definition: base.h:282
Plan(const ImplicitGEMMExp< LhsExp, RhsExp, DType > &e)
Definition: implicit_gemm.h:59
Matrix multiplication.
Definition: implicit_gemm.h:22
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: implicit_gemm.h:66
#define MSHADOW_DEFAULT_PACKET
Definition: packet-inl.h:29
Generic packet type.
Definition: packet-inl.h:41