26 #ifndef MSHADOW_EXTENSION_COMPLEX_H_ 27 #define MSHADOW_EXTENSION_COMPLEX_H_ 29 #include "../extension.h" 38 template<
typename DType>
40 DType b_real, DType b_imag) {
41 return a_real * b_real - a_imag * b_imag;
43 template<
typename DType>
45 DType b_real, DType b_imag) {
46 return a_real * b_imag + b_real * a_imag;
52 template<
typename DType>
54 DType b_real, DType b_imag) {
55 return (a_real * b_real + a_imag * b_imag) / (b_real * b_real + b_imag * b_imag);
57 template<
typename DType>
59 DType b_real, DType b_imag) {
60 return (b_real * a_imag - a_real * b_imag) / (b_real * b_real + b_imag * b_imag);
65 template<
typename TA,
typename DType>
68 return src_.
Eval(real_i, real_j);
70 template<
typename TA,
typename DType>
73 return -src_.
Eval(imag_i, imag_j);
78 template<
typename TA,
typename DType>
81 return src_.
Eval(imag_i, imag_j);
83 template<
typename TA,
typename DType>
86 return src_.
Eval(real_i, real_j);
92 template<
typename TA,
typename DType>
95 return src_.
Eval(real_i, real_j);
97 template<
typename TA,
typename DType>
106 template<
typename TA,
typename DType>
109 DType real_val = src_.
Eval(real_i, real_j);
115 template<
typename TA,
typename DType>
118 DType real_val = src_.
Eval(real_i, real_j);
119 DType image_val = src_.
Eval(imag_i, imag_j);
120 return real_val * real_val + image_val * image_val;
125 template<
typename TA,
typename DType>
128 DType real_val = src_.
Eval(real_i, real_j);
129 DType image_val = src_.
Eval(imag_i, imag_j);
130 return real_val + image_val;
148 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
157 :lhs_(lhs), rhs_(rhs) {}
168 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
179 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int ta,
int tb>
191 template<
int calctype,
typename OP,
typename SrcExp,
typename DType,
int e1>
200 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
204 return ComplexF<op::complex::kBinaryCC, op::complex::mul>(lhs, rhs);
210 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
214 return ComplexF<op::complex::kBinaryCR, op::complex::mul>(lhs, rhs);
220 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
224 return ComplexF<op::complex::kBinaryRC, op::complex::mul>(lhs, rhs);
230 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
234 return ComplexF<op::complex::kBinaryCC, op::complex::div>(lhs, rhs);
240 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
244 return ComplexF<op::complex::kBinaryCR, op::complex::div>(lhs, rhs);
250 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
254 return ComplexF<op::complex::kBinaryRC, op::complex::div>(lhs, rhs);
262 template<
typename SrcExp,
typename DType,
int e1>
266 return ComplexF<op::complex::kUnitaryC2C, op::complex::conjugate>(src);
274 template<
typename SrcExp,
typename DType,
int e1>
278 return ComplexF<op::complex::kUnitaryC2C, op::complex::exchange>(src);
286 template<
typename SrcExp,
typename DType,
int e1>
290 return ComplexF<op::complex::kUnitaryR2C, op::complex::pad_imag>(src);
298 template<
typename SrcExp,
typename DType,
int e1>
302 return ComplexF<op::complex::kUnitaryC2R, op::complex::toreal>(src);
310 template<
typename SrcExp,
typename DType,
int e1>
314 return ComplexF<op::complex::kUnitaryC2R, op::complex::abs_square>(src);
317 template<
typename SrcExp,
typename DType,
int e1>
321 return ComplexF<op::complex::kUnitaryC2R, op::complex::sum_real_imag>(src);
324 template<
int dim,
int calctype,
typename OP,
typename TA,
typename TB,
325 typename DType,
int etype>
331 if (shape1[0] == 0)
return shape2;
332 if (shape2[0] == 0)
return shape1;
333 if (calctype == op::complex::kBinaryCC) {
334 CHECK_EQ(shape1, shape2) <<
"ComplexBinaryMapExp (CC): Shapes of operands are not the same.";
335 CHECK_EQ(shape1[dim - 1] % 2, 0) <<
336 "ComplexBinaryMapExp (CC): Shape of the last dimension is not even. " 337 "We must have real part + imaginary part.";
339 }
else if (calctype == op::complex::kBinaryCR) {
340 for (
int i = 0; i < dim - 1; ++i) {
342 "ComplexBinaryMapExp (CR): Shapes of operands are not the same.";
344 CHECK_EQ(shape1[dim - 1], shape2[dim - 1] * 2) <<
345 "ComplexBinaryMapExp (CR): Shapes of operands do not match.";
347 }
else if (calctype == op::complex::kBinaryRC) {
348 for (
int i = 0; i < dim - 1; ++i) {
350 "ComplexBinaryMapExp (RC): Shapes of operands are not the same.";
352 CHECK_EQ(shape2[dim - 1], shape1[dim - 1] * 2) <<
353 "ComplexBinaryMapExp (RC): Shapes of operands do not match.";
356 LOG(FATAL) <<
"ComplexBinaryMapExp: Unexpected Calculation Type!";
362 template<
int dim,
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
366 CHECK_EQ(s[dim - 1] % 2, 0) <<
"ComplexUnitaryExp: Shape of the last dimension is not even. " 367 "We must have real + imaginary.";
368 if (calctype == op::complex::kUnitaryC2C) {
370 }
else if (calctype == op::complex::kUnitaryC2R) {
374 }
else if (calctype == op::complex::kUnitaryR2C) {
379 LOG(FATAL) <<
"ComplexUnitaryExp: Unexpected Calculation Type!";
388 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
392 : lhs_(lhs), rhs_(rhs) {}
396 return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
397 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
399 return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
400 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
410 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
414 : lhs_(lhs), rhs_(rhs) {}
418 return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
419 rhs_.Eval(y, base_x / 2),
static_cast<DType
>(0));
421 return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
422 rhs_.Eval(y, base_x / 2),
static_cast<DType
>(0));
433 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
437 : lhs_(lhs), rhs_(rhs) {}
441 return OP::RealMap(lhs_.Eval(y, base_x / 2),
static_cast<DType
>(0),
442 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
444 return OP::ImagMap(lhs_.Eval(y, base_x / 2),
static_cast<DType
>(0),
445 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
456 template<
typename OP,
typename TA,
int etype,
typename DType>
463 return OP::RealMap(src_, y, base_x, y, base_x + 1);
465 return OP::ImagMap(src_, y, base_x, y, base_x + 1);
474 template<
typename OP,
typename TA,
int etype,
typename DType>
484 return OP::RealMap(src_, y, real_x);
486 return OP::ImagMap(src_, y, real_x);
495 template<
typename OP,
typename TA,
int etype,
typename DType>
500 return OP::RealMap(src_, y, x * 2, y, x * 2 + 1);
509 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
512 return Plan<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>,
516 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
519 return Plan<ComplexUnitaryExp<calctype, OP, TA, DType, etype>,
525 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
529 static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ? \
532 ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
536 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
544 #endif // MSHADOW_EXTENSION_COMPLEX_H_ ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_rc Complex multiplication of a real tensor A and a complex tensor B
Definition: complex.h:223
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to result using defined operation
Definition: complex.h:53
Plan< ComplexUnitaryExp< calctype, OP, TA, DType, etype >, DType > MakePlan(const ComplexUnitaryExp< calctype, OP, TA, DType, etype > &e)
Definition: complex.h:518
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:84
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:71
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:393
const int kMapper
expression contains element-wise tensor operations, maps an expression to the same shape ...
Definition: expression.h:51
ComplexUnitaryExp(const TA &src)
constructor
Definition: complex.h:174
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::abs_square, SrcExp, DType,(e1|type::kMapper)> complex_abs_square(const Exp< SrcExp, DType, e1 > &src)
complex_abs_square calculate the square of the modulus of A where A is a complex tensor ...
Definition: complex.h:313
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:126
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
evaluate the expression at index [y][x] to be implemented by SubType, for RValue, the return type wil...
UnitaryCalculationType
Definition: complex.h:35
const TB & rhs_
right operand
Definition: complex.h:154
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:499
ComplexUnitaryExp< calctype, OP, SrcExp, DType,(e1|type::kMapper)> ComplexF(const Exp< SrcExp, DType, e1 > &src)
ComplexF Build a unary complex expression that applies OP, with calculation type calctype, to src
Definition: complex.h:193
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:107
ComplexUnitaryExp< op::complex::kUnitaryR2C, op::complex::pad_imag, SrcExp, DType,(e1|type::kMapper)> complex_pad_imag(const Exp< SrcExp, DType, e1 > &src)
complex_pad_imag Transform real matrix into complex matrix
Definition: complex.h:289
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:98
static Shape< dim > Check(const ComplexBinaryMapExp< calctype, OP, TA, TB, DType, etype > &t)
Definition: complex.h:328
BinaryCalculationType
Definition: complex.h:34
Definition: complex.h:124
const TA & lhs_
left operand
Definition: complex.h:152
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cc Complex multiplication of two complex tensors, A * B
Definition: complex.h:203
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::conjugate, SrcExp, DType,(e1|type::kMapper)> conj(const Exp< SrcExp, DType, e1 > &src)
conj Negate the imaginary part of A where A is a complex tensor
Definition: complex.h:265
#define MSHADOW_XINLINE
Definition: base.h:223
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:263
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:438
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:460
const TA & src_
source expression
Definition: complex.h:172
int32_t index_t
type that will be used for index
Definition: base.h:336
static Shape< dim > Check(const ComplexUnitaryExp< calctype, OP, TA, DType, etype > &t)
Definition: complex.h:364
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:93
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_cc Complex division of two complex tensors, A / B
Definition: complex.h:233
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to result using defined operation
Definition: complex.h:39
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:478
Plan(const Plan< TA, DType > &src)
Definition: complex.h:498
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:436
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:76
runtime shape checking template get the shape of an expression, report error if shape mismatch ...
Definition: expr_engine-inl.h:365
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:66
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:116
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::exchange, SrcExp, DType,(e1|type::kMapper)> complex_exchange(const Exp< SrcExp, DType, e1 > &src)
complex_exchange Exchange the real and imaginary part of A where A is a complex tensor ...
Definition: complex.h:277
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:58
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:391
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cr Complex multiplication of a complex tensor A and a real tensor B
Definition: complex.h:213
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:413
Definition: complex.h:105
ComplexBinaryMapExp(const TA &lhs, const TB &rhs)
constructor
Definition: complex.h:156
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::sum_real_imag, SrcExp, DType,(e1|type::kMapper)> complex_sum_real_imag(const Exp< SrcExp, DType, e1 > &src)
Definition: complex.h:320
const SubType & self(void) const
Definition: expression.h:83
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:44
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::toreal, SrcExp, DType,(e1|type::kMapper)> complex_toreal(const Exp< SrcExp, DType, e1 > &src)
complex_toreal convert complex matrix to real matrix, keep only real part
Definition: complex.h:301
overloaded + operator between half_t and bf16_t
Definition: base.h:327
Plan(const Plan< TA, DType > &src)
Definition: complex.h:477
Plan(const Plan< TA, DType > &src)
Definition: complex.h:459
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:415
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:79
ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_rc Complex division of a real tensor A by a complex tensor B
Definition: complex.h:253
Definition: complex.h:114
unary complex expression applying OP to src (e.g. conj(src) where src is a complex tensor)
Definition: complex.h:169
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_cr Complex division of a complex tensor A by a real tensor B
Definition: complex.h:243
binary map expression lhs [op] rhs where lhs and rhs are complex tensors, or one of them is a real tensor (CR/RC calculation types)
Definition: complex.h:149