Go to the documentation of this file.
25 #ifndef MSHADOW_EXTENSION_COMPLEX_H_
26 #define MSHADOW_EXTENSION_COMPLEX_H_
28 #include "../extension.h"
37 template<
typename DType>
39 DType b_real, DType b_imag) {
40 return a_real * b_real - a_imag * b_imag;
42 template<
typename DType>
44 DType b_real, DType b_imag) {
45 return a_real * b_imag + b_real * a_imag;
51 template<
typename DType>
53 DType b_real, DType b_imag) {
54 return (a_real * b_real + a_imag * b_imag) / (b_real * b_real + b_imag * b_imag);
56 template<
typename DType>
58 DType b_real, DType b_imag) {
59 return (b_real * a_imag - a_real * b_imag) / (b_real * b_real + b_imag * b_imag);
64 template<
typename TA,
typename DType>
67 return src_.
Eval(real_i, real_j);
69 template<
typename TA,
typename DType>
72 return -src_.
Eval(imag_i, imag_j);
77 template<
typename TA,
typename DType>
80 return src_.
Eval(imag_i, imag_j);
82 template<
typename TA,
typename DType>
85 return src_.
Eval(real_i, real_j);
91 template<
typename TA,
typename DType>
94 return src_.
Eval(real_i, real_j);
96 template<
typename TA,
typename DType>
105 template<
typename TA,
typename DType>
108 DType real_val = src_.
Eval(real_i, real_j);
114 template<
typename TA,
typename DType>
117 DType real_val = src_.
Eval(real_i, real_j);
118 DType image_val = src_.
Eval(imag_i, imag_j);
119 return real_val * real_val + image_val * image_val;
124 template<
typename TA,
typename DType>
127 DType real_val = src_.
Eval(real_i, real_j);
128 DType image_val = src_.
Eval(imag_i, imag_j);
129 return real_val + image_val;
147 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
167 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
178 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int ta,
int tb>
179 inline ComplexBinaryMapExp<calctype, OP, TA, TB, DType, (ta | tb |
type::kMapper)>
190 template<
int calctype,
typename OP,
typename SrcExp,
typename DType,
int e1>
191 inline ComplexUnitaryExp<calctype, OP, SrcExp, DType, (e1 |
type::kMapper)>
199 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
203 return ComplexF<op::complex::kBinaryCC, op::complex::mul>(lhs, rhs);
209 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
213 return ComplexF<op::complex::kBinaryCR, op::complex::mul>(lhs, rhs);
219 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
223 return ComplexF<op::complex::kBinaryRC, op::complex::mul>(lhs, rhs);
229 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
233 return ComplexF<op::complex::kBinaryCC, op::complex::div>(lhs, rhs);
239 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
243 return ComplexF<op::complex::kBinaryCR, op::complex::div>(lhs, rhs);
249 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
253 return ComplexF<op::complex::kBinaryRC, op::complex::div>(lhs, rhs);
261 template<
typename SrcExp,
typename DType,
int e1>
265 return ComplexF<op::complex::kUnitaryC2C, op::complex::conjugate>(src);
273 template<
typename SrcExp,
typename DType,
int e1>
277 return ComplexF<op::complex::kUnitaryC2C, op::complex::exchange>(src);
285 template<
typename SrcExp,
typename DType,
int e1>
289 return ComplexF<op::complex::kUnitaryR2C, op::complex::pad_imag>(src);
297 template<
typename SrcExp,
typename DType,
int e1>
301 return ComplexF<op::complex::kUnitaryC2R, op::complex::toreal>(src);
309 template<
typename SrcExp,
typename DType,
int e1>
313 return ComplexF<op::complex::kUnitaryC2R, op::complex::abs_square>(src);
316 template<
typename SrcExp,
typename DType,
int e1>
320 return ComplexF<op::complex::kUnitaryC2R, op::complex::sum_real_imag>(src);
323 template<
int dim,
int calctype,
typename OP,
typename TA,
typename TB,
324 typename DType,
int etype>
330 if (shape1[0] == 0)
return shape2;
331 if (shape2[0] == 0)
return shape1;
333 CHECK_EQ(shape1, shape2) <<
"ComplexBinaryMapExp (CC): Shapes of operands are not the same.";
334 CHECK_EQ(shape1[dim - 1] % 2, 0) <<
335 "ComplexBinaryMapExp (CC): Shape of the last dimension is not even. "
336 "We must have real part + imaginary part.";
339 for (
int i = 0; i < dim - 1; ++i) {
341 "ComplexBinaryMapExp (CR): Shapes of operands are not the same.";
343 CHECK_EQ(shape1[dim - 1], shape2[dim - 1] * 2) <<
344 "ComplexBinaryMapExp (CR): Shapes of operands do not match.";
347 for (
int i = 0; i < dim - 1; ++i) {
349 "ComplexBinaryMapExp (RC): Shapes of operands are not the same.";
351 CHECK_EQ(shape2[dim - 1], shape1[dim - 1] * 2) <<
352 "ComplexBinaryMapExp (RC): Shapes of operands do not match.";
355 LOG(FATAL) <<
"ComplexBinaryMapExp: Unexpected Calculation Type!";
361 template<
int dim,
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
365 CHECK_EQ(s[dim - 1] % 2, 0) <<
"ComplexUnitaryExp: Shape of the last dimension is not even. "
366 "We must have real + imaginary.";
378 LOG(FATAL) <<
"ComplexUnitaryExp: Unexpected Calculation Type!";
387 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
391 : lhs_(lhs), rhs_(rhs) {}
395 return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
396 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
398 return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
399 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
409 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
413 : lhs_(lhs), rhs_(rhs) {}
417 return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
418 rhs_.Eval(y, base_x / 2),
static_cast<DType
>(0));
420 return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
421 rhs_.Eval(y, base_x / 2),
static_cast<DType
>(0));
432 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
436 : lhs_(lhs), rhs_(rhs) {}
440 return OP::RealMap(lhs_.Eval(y, base_x / 2),
static_cast<DType
>(0),
441 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
443 return OP::ImagMap(lhs_.Eval(y, base_x / 2),
static_cast<DType
>(0),
444 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
455 template<
typename OP,
typename TA,
int etype,
typename DType>
462 return OP::RealMap(src_, y, base_x, y, base_x + 1);
464 return OP::ImagMap(src_, y, base_x, y, base_x + 1);
473 template<
typename OP,
typename TA,
int etype,
typename DType>
483 return OP::RealMap(src_, y, real_x);
485 return OP::ImagMap(src_, y, real_x);
494 template<
typename OP,
typename TA,
int etype,
typename DType>
499 return OP::RealMap(src_, y, x * 2, y, x * 2 + 1);
508 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
509 inline Plan<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>, DType>
515 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
516 inline Plan<ComplexUnitaryExp<calctype, OP, TA, DType, etype>, DType>
524 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
528 static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ? \
531 ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
535 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
543 #endif // MSHADOW_EXTENSION_COMPLEX_H_
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:115
Plan(const Plan< TA, DType > &src)
Definition: complex.h:458
static const int kDevMask
Definition: expr_engine-inl.h:264
@ kBinaryCC
Definition: complex.h:33
const SubType & self(void) const
Definition: expression.h:82
@ kUnitaryC2R
Definition: complex.h:34
Definition: complex.h:113
@ kUnitaryC2C
Definition: complex.h:34
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:92
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cr Complex multiplication of a complex tensor A and a real tensor B, A * B
Definition: complex.h:212
ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_rc Complex division of a real tensor A by a complex tensor B, A / B
Definition: complex.h:252
binary map expression lhs [op] rhs where lhs and rhs are complex tensors
Definition: complex.h:148
ComplexBinaryMapExp< calctype, OP, TA, TB, DType,(ta|tb|type::kMapper)> ComplexF(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
Definition: complex.h:180
const TA & lhs_
left operand
Definition: complex.h:151
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_cc Complex division of two complex tensors, A / B
Definition: complex.h:232
Plan(const Plan< TA, DType > &src)
Definition: complex.h:497
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:498
#define MSHADOW_XINLINE
Definition: base.h:228
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:412
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:106
compute conj(src) where src is a complex tensor
Definition: complex.h:168
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::toreal, SrcExp, DType,(e1|type::kMapper)> complex_toreal(const Exp< SrcExp, DType, e1 > &src)
complex_toreal Convert a complex matrix to a real matrix, keeping only the real part
Definition: complex.h:300
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:65
runtime shape checking template: get the shape of an expression, reporting an error if the shapes mismatch
Definition: expr_engine-inl.h:364
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:125
const int kMapper
expression contains element-wise tensor operations, map a expression to same shape
Definition: expression.h:50
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:477
static Shape< dim > Check(const E &t)
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:262
Definition: complex.h:104
static Shape< dim > Check(const ComplexBinaryMapExp< calctype, OP, TA, TB, DType, etype > &t)
Definition: complex.h:327
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:70
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:86
@ kBinaryRC
Definition: complex.h:33
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:437
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::conjugate, SrcExp, DType,(e1|type::kMapper)> conj(const Exp< SrcExp, DType, e1 > &src)
conj Negate the imaginary part of A, where A is a complex tensor
Definition: complex.h:264
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:459
static const int kDim
Definition: expr_engine-inl.h:263
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_cr Complex division of a complex tensor A by a real tensor B, A / B
Definition: complex.h:242
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
evaluate the expression at index [y][x] to be implemented by SubType, for RValue, the return type wil...
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:435
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cc Complex multiplication of two complex tensors, A * B
Definition: complex.h:202
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to result using defined operation
Definition: complex.h:38
int32_t index_t
type that will be used for index
Definition: base.h:328
const TB & rhs_
right operand
Definition: complex.h:153
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::sum_real_imag, SrcExp, DType,(e1|type::kMapper)> complex_sum_real_imag(const Exp< SrcExp, DType, e1 > &src)
Definition: complex.h:319
static Shape< dim > Check(const ComplexUnitaryExp< calctype, OP, TA, DType, etype > &t)
Definition: complex.h:363
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:392
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:57
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to result using defined operation
Definition: complex.h:52
@ kBinaryCR
Definition: complex.h:33
ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_rc Complex multiplication of a real tensor A and a complex tensor B, A * B
Definition: complex.h:222
overloaded + operator between half_t and bf16_t
Definition: base.h:319
UnitaryCalculationType
Definition: complex.h:34
ComplexUnitaryExp< op::complex::kUnitaryR2C, op::complex::pad_imag, SrcExp, DType,(e1|type::kMapper)> complex_pad_imag(const Exp< SrcExp, DType, e1 > &src)
complex_pad_imag Transform a real matrix into a complex matrix with a zero imaginary part
Definition: complex.h:288
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:78
Plan(const Plan< TA, DType > &src)
Definition: complex.h:476
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:390
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:97
Definition: complex.h:123
BinaryCalculationType
Definition: complex.h:33
ComplexBinaryMapExp(const TA &lhs, const TB &rhs)
constructor
Definition: complex.h:155
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::exchange, SrcExp, DType,(e1|type::kMapper)> complex_exchange(const Exp< SrcExp, DType, e1 > &src)
complex_exchange Exchange the real and imaginary part of A where A is a complex tensor
Definition: complex.h:276
const TA & src_
source expression
Definition: complex.h:171
ComplexUnitaryExp(const TA &src)
constructor
Definition: complex.h:173
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:43
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:83
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:414
@ kUnitaryR2C
Definition: complex.h:34
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::abs_square, SrcExp, DType,(e1|type::kMapper)> complex_abs_square(const Exp< SrcExp, DType, e1 > &src)
complex_abs_square Calculate the square of the modulus of A, where A is a complex tensor
Definition: complex.h:312