mxnet
expr_scalar-inl.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
28 // A real macro guard would be harmful here (this file must be includable multiple times); it exists only to pass cpplint
29 #ifndef MSHADOW_EXPR_SCALAR_INL_H_
30 #define MSHADOW_EXPR_SCALAR_INL_H_
31 // undef the guard so it can be included multiple times
32 #undef MSHADOW_EXPR_SCALAR_INL_H_
33 
34 namespace mshadow {
35 namespace expr {
36 // DotExp
// Scales a matrix-multiplication expression by a scalar on the right (DotExp * scalar).
// The body only constructs a new DotExp that reuses lhs's operands (lhs.lhs_, lhs.rhs_)
// and carries the combined scale lhs.scale_ * rhs.
// NOTE(review): listing line 40 (the `operator*(const DotExp<...> &lhs,` declarator) is
// missing from this scrape; the full signature appears in the Doxygen tooltip for
// mshadow::expr::operator*.
38 template<typename TA, typename TB, bool ltrans, bool rtrans>
39 inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_>
41  MSHADOW_SCALAR_ rhs) {
42  return DotExp<TA, TB, ltrans, rtrans,
43  MSHADOW_SCALAR_>(lhs.lhs_, lhs.rhs_, lhs.scale_ * rhs);
44 }
// Scalar-on-the-left counterpart of the overload above (scalar * DotExp).
// Returns a new DotExp that reuses rhs's operands (rhs.lhs_, rhs.rhs_) with scale
// rhs.scale_ * lhs.
// NOTE(review): listing lines 48-49 (declarator and parameter list) are missing from this
// scrape; the body implies lhs is the scalar and rhs the DotExp. Confirm against the source.
46 template<typename TA, typename TB, bool ltrans, bool rtrans>
47 inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_>
50  return DotExp<TA, TB, ltrans, rtrans,
51  MSHADOW_SCALAR_>(rhs.lhs_, rhs.rhs_, rhs.scale_ * lhs);
52 }
53 
// Scales a ReduceTo1DExp by a scalar. It returns a ReduceTo1DExp over the same
// source (e.src_) whose scale_ is e.scale_ * scale.
// NOTE(review): listing line 57 (the declarator) is missing from this scrape. This is presumably the
// expression-first overload (e * scale), mirroring the DotExp pair above. TODO: confirm.
55 template<typename E, typename DType, typename R, int d>
56 inline ReduceTo1DExp<E, DType, R, d>
58  return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale);
59 }
// Second ReduceTo1DExp scaling overload. Its body is identical to the one above: a
// ReduceTo1DExp over e.src_ with scale_ = e.scale_ * scale.
// NOTE(review): listing line 63 (the declarator) is missing from this scrape. This is presumably the
// commuted, scalar-first overload (scale * e). TODO: confirm.
61 template<typename E, typename DType, typename R, int d>
62 inline ReduceTo1DExp<E, DType, R, d>
64  return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale);
65 }
66 
// Generic binary map F<OP>(expression, scalar constant). It builds the map expression via
// MakeExp<OP>(lhs, rhs).
// Per the Doxygen tooltip, the signature is
//   F(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs)
// and the result type is BinaryMapExp<OP, TA, ScalarExp<MSHADOW_SCALAR_>, MSHADOW_SCALAR_,
// (ta|type::kMapper)>, i.e. the lhs expression's type flags are OR-ed with kMapper.
// NOTE(review): listing lines 70-71 are missing from this scrape.
68 template<typename OP, typename TA, int ta>
69 inline BinaryMapExp<OP, TA, ScalarExp<MSHADOW_SCALAR_>,
72  return MakeExp<OP>(lhs, rhs);
73 }
// Generic binary map F<OP>(scalar constant, expression). The return type places
// ScalarExp<MSHADOW_SCALAR_> on the left and expression type TB on the right. The body
// forwards both operands to MakeExp<OP>.
// NOTE(review): listing lines 78-79 (rest of the return type and the parameter list) are
// missing from this scrape. The parameters are presumably (const ScalarExp<...> &lhs,
// const Exp<TB, MSHADOW_SCALAR_, tb> &rhs). Confirm.
75 template<typename OP, typename TB, int tb>
76 inline BinaryMapExp<OP, ScalarExp<MSHADOW_SCALAR_>, TB,
79  return MakeExp<OP>(lhs, rhs);
80 }
// Generic binary map F<OP>(scalar constant, scalar constant). Both operand types in the
// return type are ScalarExp<MSHADOW_SCALAR_>, and the body forwards to MakeExp<OP>.
// NOTE(review): listing lines 84-85 (rest of the return type and the parameter list) are
// missing from this scrape.
82 template<typename OP>
83 inline BinaryMapExp<OP, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
86  return MakeExp<OP>(lhs, rhs);
87 }
88 // constant operators: expression (lhs) op scalar constant (rhs)
// expression + scalar constant. It builds BinaryMapExp<op::plus, TA, ScalarExp<...>, ...>
// via MakeExp<op::plus>.
// NOTE(review): listing line 92 (remaining template arguments of the return type) is
// missing from this scrape.
90 template<typename TA, int ta>
91 inline BinaryMapExp<op::plus, TA, ScalarExp<MSHADOW_SCALAR_>,
93 operator+(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
94  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
95  return MakeExp<op::plus>(lhs, rhs);
96 }
// expression - scalar constant. It builds BinaryMapExp<op::minus, TA, ScalarExp<...>, ...>
// via MakeExp<op::minus>.
// NOTE(review): listing line 100 (remaining template arguments of the return type) is
// missing from this scrape.
98 template<typename TA, int ta>
99 inline BinaryMapExp<op::minus, TA, ScalarExp<MSHADOW_SCALAR_>,
101 operator-(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
102  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
103  return MakeExp<op::minus>(lhs, rhs);
104 }
// expression * scalar constant. It builds BinaryMapExp<op::mul, TA, ScalarExp<...>, ...>
// via MakeExp<op::mul>.
// NOTE(review): listing line 108 (remaining template arguments of the return type) is
// missing from this scrape.
106 template<typename TA, int ta>
107 inline BinaryMapExp<op::mul, TA, ScalarExp<MSHADOW_SCALAR_>,
109 operator*(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
110  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
111  return MakeExp<op::mul>(lhs, rhs);
112 }
// expression / scalar constant. It builds BinaryMapExp<op::div, TA, ScalarExp<...>, ...>
// via MakeExp<op::div>.
// NOTE(review): listing line 116 (remaining template arguments of the return type) is
// missing from this scrape.
114 template<typename TA, int ta>
115 inline BinaryMapExp<op::div, TA, ScalarExp<MSHADOW_SCALAR_>,
117 operator/(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
118  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
119  return MakeExp<op::div>(lhs, rhs);
120 }
121 // constant operators 2: scalar constant (lhs) op expression (rhs)
// scalar constant + expression. It builds BinaryMapExp<op::plus, ScalarExp<...>, TB, ...>
// via MakeExp<op::plus>.
// NOTE(review): listing line 125 (remaining template arguments of the return type) is
// missing from this scrape.
123 template<typename TB, int tb>
124 inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, TB,
126 operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs,
127  const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
128  return MakeExp<op::plus>(lhs, rhs);
129 }
// scalar constant - expression (the scalar is the minuend). It builds
// BinaryMapExp<op::minus, ScalarExp<...>, TB, ...> via MakeExp<op::minus>.
// NOTE(review): listing line 133 (remaining template arguments of the return type) is
// missing from this scrape.
131 template<typename TB, int tb>
132 inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, TB,
134 operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs,
135  const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
136  return MakeExp<op::minus>(lhs, rhs);
137 }
// scalar constant * expression. It builds BinaryMapExp<op::mul, ScalarExp<...>, TB, ...>
// via MakeExp<op::mul>.
// NOTE(review): listing line 141 (remaining template arguments of the return type) is
// missing from this scrape.
139 template<typename TB, int tb>
140 inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, TB,
142 operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs,
143  const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
144  return MakeExp<op::mul>(lhs, rhs);
145 }
// scalar constant / expression (the scalar is the dividend). It builds
// BinaryMapExp<op::div, ScalarExp<...>, TB, ...> via MakeExp<op::div>.
// NOTE(review): listing line 149 (remaining template arguments of the return type) is
// missing from this scrape.
147 template<typename TB, int tb>
148 inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, TB,
150 operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
151  return MakeExp<op::div>(lhs, rhs);
152 }
153 // constant operators 3: scalar constant op scalar constant
// scalar constant + scalar constant. It builds
// BinaryMapExp<op::plus, ScalarExp<...>, ScalarExp<...>, ...> via MakeExp<op::plus>.
// This is a non-template inline function. It presumably gets a distinct overload per
// inclusion of this file, because MSHADOW_SCALAR_ is redefined each time (see the guard
// comment at the top of the file). TODO: confirm.
// NOTE(review): listing line 156 (remaining template arguments of the return type) is
// missing from this scrape.
155 inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
157 operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs,
158  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
159  return MakeExp<op::plus>(lhs, rhs);
160 }
// scalar constant - scalar constant. It builds
// BinaryMapExp<op::minus, ScalarExp<...>, ScalarExp<...>, ...> via MakeExp<op::minus>.
// NOTE(review): listing line 163 (remaining template arguments of the return type) is
// missing from this scrape.
162 inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
164 operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs,
165  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
166  return MakeExp<op::minus>(lhs, rhs);
167 }
// scalar constant * scalar constant. It builds
// BinaryMapExp<op::mul, ScalarExp<...>, ScalarExp<...>, ...> via MakeExp<op::mul>.
// NOTE(review): listing line 170 (remaining template arguments of the return type) is
// missing from this scrape.
169 inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
171 operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs,
172  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
173  return MakeExp<op::mul>(lhs, rhs);
174 }
// scalar constant / scalar constant. It builds
// BinaryMapExp<op::div, ScalarExp<...>, ScalarExp<...>, ...> via MakeExp<op::div>.
// NOTE(review): listing line 177 (remaining template arguments of the return type) is
// missing from this scrape.
176 inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
178 operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
179  return MakeExp<op::div>(lhs, rhs);
180 }
181 } // namespace expr
182 } // namespace mshadow
183 #endif // MSHADOW_EXPR_SCALAR_INL_H_
mshadow::expr::ReduceTo1DExp::src_
const SrcExp & src_
source operand
Definition: reduceto1d.h:45
mshadow::expr::ReduceTo1DExp::scale_
DType scale_
scale factor applied to the reduction result (multiplied by scalar operands in operator*)
Definition: reduceto1d.h:47
mshadow::expr::operator*
DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > operator*(const DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > &lhs, MSHADOW_SCALAR_ rhs)
dot operator def: scales a DotExp by a scalar (DotExp * scalar)
Definition: expr_scalar-inl.h:40
mshadow::expr::DotExp::lhs_
const TA & lhs_
left operand
Definition: expression.h:227
mshadow::expr::F
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for a binary map between an expression and a scalar constant
Definition: expr_scalar-inl.h:71
mshadow::expr::type::kMapper
const int kMapper
the expression contains element-wise tensor operations that map an expression to one of the same shape
Definition: expression.h:50
mshadow::expr::DotExp::scale_
DType scale_
scale over result
Definition: expression.h:231
mshadow::expr::DotExp::rhs_
const TB & rhs_
right operand
Definition: expression.h:229
mshadow::expr::ReduceTo1DExp
reduction to 1 dimension tensor input: Tensor<Device,k>: ishape output: Tensor<Device,...
Definition: reduceto1d.h:41
mshadow::expr::Exp
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
MSHADOW_SCALAR_
#define MSHADOW_SCALAR_
Definition: tensor.h:1227
mshadow::expr::ScalarExp
scalar expression
Definition: expression.h:95
mshadow::expr::DotExp
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:224