7 #ifndef MSHADOW_EXTENSION_COMPLEX_H_ 8 #define MSHADOW_EXTENSION_COMPLEX_H_ 10 #include "../extension.h" 19 template<
typename DType>
21 DType b_real, DType b_imag) {
22 return a_real * b_real - a_imag * b_imag;
24 template<
typename DType>
26 DType b_real, DType b_imag) {
27 return a_real * b_imag + b_real * a_imag;
33 template<
typename DType>
35 DType b_real, DType b_imag) {
36 return (a_real * b_real + a_imag * b_imag) / (b_real * b_real + b_imag * b_imag);
38 template<
typename DType>
40 DType b_real, DType b_imag) {
41 return (b_real * a_imag - a_real * b_imag) / (b_real * b_real + b_imag * b_imag);
46 template<
typename TA,
typename DType>
49 return src_.
Eval(real_i, real_j);
51 template<
typename TA,
typename DType>
54 return -src_.
Eval(imag_i, imag_j);
59 template<
typename TA,
typename DType>
62 return src_.
Eval(imag_i, imag_j);
64 template<
typename TA,
typename DType>
67 return src_.
Eval(real_i, real_j);
73 template<
typename TA,
typename DType>
76 return src_.
Eval(real_i, real_j);
78 template<
typename TA,
typename DType>
87 template<
typename TA,
typename DType>
90 DType real_val = src_.
Eval(real_i, real_j);
96 template<
typename TA,
typename DType>
99 DType real_val = src_.
Eval(real_i, real_j);
100 DType image_val = src_.
Eval(imag_i, imag_j);
101 return real_val * real_val + image_val * image_val;
106 template<
typename TA,
typename DType>
109 DType real_val = src_.
Eval(real_i, real_j);
110 DType image_val = src_.
Eval(imag_i, imag_j);
111 return real_val + image_val;
129 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
138 :lhs_(lhs), rhs_(rhs) {}
149 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
160 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int ta,
int tb>
172 template<
int calctype,
typename OP,
typename SrcExp,
typename DType,
int e1>
181 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
185 return ComplexF<op::complex::kBinaryCC, op::complex::mul>(lhs, rhs);
191 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
195 return ComplexF<op::complex::kBinaryCR, op::complex::mul>(lhs, rhs);
201 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
205 return ComplexF<op::complex::kBinaryRC, op::complex::mul>(lhs, rhs);
211 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
215 return ComplexF<op::complex::kBinaryCC, op::complex::div>(lhs, rhs);
221 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
225 return ComplexF<op::complex::kBinaryCR, op::complex::div>(lhs, rhs);
231 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
235 return ComplexF<op::complex::kBinaryRC, op::complex::div>(lhs, rhs);
243 template<
typename SrcExp,
typename DType,
int e1>
247 return ComplexF<op::complex::kUnitaryC2C, op::complex::conjugate>(src);
255 template<
typename SrcExp,
typename DType,
int e1>
259 return ComplexF<op::complex::kUnitaryC2C, op::complex::exchange>(src);
267 template<
typename SrcExp,
typename DType,
int e1>
271 return ComplexF<op::complex::kUnitaryR2C, op::complex::pad_imag>(src);
279 template<
typename SrcExp,
typename DType,
int e1>
283 return ComplexF<op::complex::kUnitaryC2R, op::complex::toreal>(src);
291 template<
typename SrcExp,
typename DType,
int e1>
295 return ComplexF<op::complex::kUnitaryC2R, op::complex::abs_square>(src);
298 template<
typename SrcExp,
typename DType,
int e1>
302 return ComplexF<op::complex::kUnitaryC2R, op::complex::sum_real_imag>(src);
305 template<
int dim,
int calctype,
typename OP,
typename TA,
typename TB,
306 typename DType,
int etype>
312 if (shape1[0] == 0)
return shape2;
313 if (shape2[0] == 0)
return shape1;
314 if (calctype == op::complex::kBinaryCC) {
315 CHECK_EQ(shape1, shape2) <<
"ComplexBinaryMapExp (CC): Shapes of operands are not the same.";
316 CHECK_EQ(shape1[dim - 1] % 2, 0) <<
317 "ComplexBinaryMapExp (CC): Shape of the last dimension is not even. " 318 "We must have real part + imaginary part.";
320 }
else if (calctype == op::complex::kBinaryCR) {
321 for (
int i = 0; i < dim - 1; ++i) {
323 "ComplexBinaryMapExp (CR): Shapes of operands are not the same.";
325 CHECK_EQ(shape1[dim - 1], shape2[dim - 1] * 2) <<
326 "ComplexBinaryMapExp (CR): Shapes of operands do not match.";
328 }
else if (calctype == op::complex::kBinaryRC) {
329 for (
int i = 0; i < dim - 1; ++i) {
331 "ComplexBinaryMapExp (RC): Shapes of operands are not the same.";
333 CHECK_EQ(shape2[dim - 1], shape1[dim - 1] * 2) <<
334 "ComplexBinaryMapExp (RC): Shapes of operands do not match.";
337 LOG(FATAL) <<
"ComplexBinaryMapExp: Unexpected Calculation Type!";
343 template<
int dim,
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
347 CHECK_EQ(s[dim - 1] % 2, 0) <<
"ComplexUnitaryExp: Shape of the last dimension is not even. " 348 "We must have real + imaginary.";
349 if (calctype == op::complex::kUnitaryC2C) {
351 }
else if (calctype == op::complex::kUnitaryC2R) {
355 }
else if (calctype == op::complex::kUnitaryR2C) {
360 LOG(FATAL) <<
"ComplexUnitaryExp: Unexpected Calculation Type!";
369 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
373 : lhs_(lhs), rhs_(rhs) {}
377 return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
378 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
380 return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
381 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
391 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
395 : lhs_(lhs), rhs_(rhs) {}
399 return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
400 rhs_.Eval(y, base_x / 2),
static_cast<DType
>(0));
402 return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
403 rhs_.Eval(y, base_x / 2),
static_cast<DType
>(0));
414 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
418 : lhs_(lhs), rhs_(rhs) {}
422 return OP::RealMap(lhs_.Eval(y, base_x / 2),
static_cast<DType
>(0),
423 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
425 return OP::ImagMap(lhs_.Eval(y, base_x / 2),
static_cast<DType
>(0),
426 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
437 template<
typename OP,
typename TA,
int etype,
typename DType>
444 return OP::RealMap(src_, y, base_x, y, base_x + 1);
446 return OP::ImagMap(src_, y, base_x, y, base_x + 1);
455 template<
typename OP,
typename TA,
int etype,
typename DType>
465 return OP::RealMap(src_, y, real_x);
467 return OP::ImagMap(src_, y, real_x);
476 template<
typename OP,
typename TA,
int etype,
typename DType>
481 return OP::RealMap(src_, y, x * 2, y, x * 2 + 1);
490 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
493 return Plan<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>,
497 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
500 return Plan<ComplexUnitaryExp<calctype, OP, TA, DType, etype>,
506 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
510 static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ? \
513 ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
517 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
525 #endif // MSHADOW_EXTENSION_COMPLEX_H_ ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_rc Complex multipilication of a real tensor B and a complex tensor A
Definition: complex.h:204
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to result using defined operation
Definition: complex.h:34
Plan< ComplexUnitaryExp< calctype, OP, TA, DType, etype >, DType > MakePlan(const ComplexUnitaryExp< calctype, OP, TA, DType, etype > &e)
Definition: complex.h:499
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:65
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:52
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:374
const int kMapper
expression contains element-wise tensor operations, map a expression to same shape ...
Definition: expression.h:32
ComplexUnitaryExp(const TA &src)
constructor
Definition: complex.h:155
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::abs_square, SrcExp, DType,(e1|type::kMapper)> complex_abs_square(const Exp< SrcExp, DType, e1 > &src)
complex_abs_square calculate the square of the modulus of A where A is a complex tensor ...
Definition: complex.h:294
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:107
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
evaluate the expression at index [y][x] to be implemented by SubType, for RValue, the return type wil...
UnitaryCalculationType
Definition: complex.h:16
const TB & rhs_
right operand
Definition: complex.h:135
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:480
ComplexUnitaryExp< calctype, OP, SrcExp, DType,(e1|type::kMapper)> ComplexF(const Exp< SrcExp, DType, e1 > &src)
conj Negation the imaginary part of A where A is a complex tensor
Definition: complex.h:174
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:88
ComplexUnitaryExp< op::complex::kUnitaryR2C, op::complex::pad_imag, SrcExp, DType,(e1|type::kMapper)> complex_pad_imag(const Exp< SrcExp, DType, e1 > &src)
complex_pad_imag Transform real matrix into complex matrix
Definition: complex.h:270
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:79
static Shape< dim > Check(const ComplexBinaryMapExp< calctype, OP, TA, TB, DType, etype > &t)
Definition: complex.h:309
BinaryCalculationType
Definition: complex.h:15
Definition: complex.h:105
const TA & lhs_
left operand
Definition: complex.h:133
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cc Complex multipilication two complex tensors, A * B
Definition: complex.h:184
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::conjugate, SrcExp, DType,(e1|type::kMapper)> conj(const Exp< SrcExp, DType, e1 > &src)
conj Negation the imaginary part of A where A is a complex tensor
Definition: complex.h:246
#define MSHADOW_XINLINE
Definition: base.h:204
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:419
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:441
const TA & src_
source expression
Definition: complex.h:153
int32_t index_t
type that will be used for index
Definition: base.h:291
static Shape< dim > Check(const ComplexUnitaryExp< calctype, OP, TA, DType, etype > &t)
Definition: complex.h:345
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:74
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cc Complex multipilication two complex tensors, A * B
Definition: complex.h:214
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to result using defined operation
Definition: complex.h:20
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:459
Plan(const Plan< TA, DType > &src)
Definition: complex.h:479
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:417
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:57
runtime shape checking template get the shape of an expression, report error if shape mismatch ...
Definition: expr_engine-inl.h:346
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:47
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:97
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::exchange, SrcExp, DType,(e1|type::kMapper)> complex_exchange(const Exp< SrcExp, DType, e1 > &src)
complex_exchange Exchange the real and imaginary part of A where A is a complex tensor ...
Definition: complex.h:258
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:39
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:372
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cr Complex multipilication a complex tensor A and a real tensor B
Definition: complex.h:194
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:394
ComplexBinaryMapExp(const TA &lhs, const TB &rhs)
constructor
Definition: complex.h:137
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::sum_real_imag, SrcExp, DType,(e1|type::kMapper)> complex_sum_real_imag(const Exp< SrcExp, DType, e1 > &src)
Definition: complex.h:301
const SubType & self(void) const
Definition: expression.h:64
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:25
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::toreal, SrcExp, DType,(e1|type::kMapper)> complex_toreal(const Exp< SrcExp, DType, e1 > &src)
complex_toreal convert complex matrix to real matrix, keep only real part
Definition: complex.h:282
namespace for mshadow
Definition: base.h:282
Plan(const Plan< TA, DType > &src)
Definition: complex.h:458
Plan(const Plan< TA, DType > &src)
Definition: complex.h:440
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:396
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:60
ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_rc Complex multipilication of a real tensor A and a complex tensor B
Definition: complex.h:234
compute conj(src) where src is a complex tensor
Definition: complex.h:150
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cr Complex multipilication a complex tensor A and a real tensor B
Definition: complex.h:224
binary map expression lhs [op] rhs where lhs and rhs are complex tensors
Definition: complex.h:130