mxnet
complex.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_EXTENSION_COMPLEX_H_
27 #define MSHADOW_EXTENSION_COMPLEX_H_
28 #include <algorithm>
29 #include "../extension.h"
30 
31 namespace mshadow {
32 namespace op {
33 namespace complex {
36 struct mul {
38  template<typename DType>
39  MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag,
40  DType b_real, DType b_imag) {
41  return a_real * b_real - a_imag * b_imag;
42  }
43  template<typename DType>
44  MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag,
45  DType b_real, DType b_imag) {
46  return a_real * b_imag + b_real * a_imag;
47  }
48 };
49 
50 struct div {
52  template<typename DType>
53  MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag,
54  DType b_real, DType b_imag) {
55  return (a_real * b_real + a_imag * b_imag) / (b_real * b_real + b_imag * b_imag);
56  }
57  template<typename DType>
58  MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag,
59  DType b_real, DType b_imag) {
60  return (b_real * a_imag - a_real * b_imag) / (b_real * b_real + b_imag * b_imag);
61  }
62 };
63 
64 struct conjugate {
65  template<typename TA, typename DType>
66  MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_,
67  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
68  return src_.Eval(real_i, real_j);
69  }
70  template<typename TA, typename DType>
71  MSHADOW_XINLINE static DType ImagMap(const expr::Plan<TA, DType> &src_,
72  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
73  return -src_.Eval(imag_i, imag_j);
74  }
75 };
76 
77 struct exchange {
78  template<typename TA, typename DType>
79  MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_,
80  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
81  return src_.Eval(imag_i, imag_j);
82  }
83  template<typename TA, typename DType>
84  MSHADOW_XINLINE static DType ImagMap(const expr::Plan<TA, DType> &src_,
85  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
86  return src_.Eval(real_i, real_j);
87  }
88 };
89 
90 // r2c operator
91 struct pad_imag {
92  template<typename TA, typename DType>
93  MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_,
94  index_t real_i, index_t real_j) {
95  return src_.Eval(real_i, real_j);
96  }
97  template<typename TA, typename DType>
98  MSHADOW_XINLINE static DType ImagMap(const expr::Plan<TA, DType> &src_,
99  index_t real_i, index_t real_j) {
100  return 0;
101  }
102 };
103 
104 // c2r operator
105 struct toreal {
106  template<typename TA, typename DType>
108  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
109  DType real_val = src_.Eval(real_i, real_j);
110  return real_val;
111  }
112 };
113 
114 struct abs_square {
115  template<typename TA, typename DType>
117  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
118  DType real_val = src_.Eval(real_i, real_j);
119  DType image_val = src_.Eval(imag_i, imag_j);
120  return real_val * real_val + image_val * image_val;
121  }
122 };
123 
125  template<typename TA, typename DType>
127  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
128  DType real_val = src_.Eval(real_i, real_j);
129  DType image_val = src_.Eval(imag_i, imag_j);
130  return real_val + image_val;
131  }
132 };
133 } // namespace complex
134 } // namespace op
135 
136 namespace expr {
137 //--------------------
138 // ComplexBinaryMapExp
139 //--------------------
148 template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype>
149 struct ComplexBinaryMapExp : public Exp<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>,
150  DType, etype> {
152  const TA &lhs_;
154  const TB &rhs_;
156  explicit ComplexBinaryMapExp(const TA &lhs, const TB &rhs)
157  :lhs_(lhs), rhs_(rhs) {}
158 };
159 
160 //-------------------
161 // ComplexUnitaryExp
162 //-------------------
168 template<int calctype, typename OP, typename TA, typename DType, int etype>
169 struct ComplexUnitaryExp : public Exp<ComplexUnitaryExp<calctype, OP, TA, DType, etype>,
170  DType, etype> {
172  const TA &src_;
174  explicit ComplexUnitaryExp(const TA &src) : src_(src) {}
175 };
176 
177 
178 
179 template<int calctype, typename OP, typename TA, typename TB, typename DType, int ta, int tb>
182  return ComplexBinaryMapExp<calctype, OP, TA, TB, DType,
183  (ta | tb | type::kMapper)>(lhs.self(), rhs.self());
184 }
185 
191 template<int calctype, typename OP, typename SrcExp, typename DType, int e1>
195 }
196 
200 template<typename TA, typename TB, typename DType, int ta, int tb>
202  TA, TB, DType, (ta | tb | type::kMapper)>
204  return ComplexF<op::complex::kBinaryCC, op::complex::mul>(lhs, rhs);
205 }
206 
210 template<typename TA, typename TB, typename DType, int ta, int tb>
211 inline ComplexBinaryMapExp<op::complex::kBinaryCR, op::complex::mul,
212  TA, TB, DType, (ta | tb | type::kMapper)>
214  return ComplexF<op::complex::kBinaryCR, op::complex::mul>(lhs, rhs);
215 }
216 
220 template<typename TA, typename TB, typename DType, int ta, int tb>
221 inline ComplexBinaryMapExp<op::complex::kBinaryRC, op::complex::mul,
222  TA, TB, DType, (ta | tb | type::kMapper)>
224  return ComplexF<op::complex::kBinaryRC, op::complex::mul>(lhs, rhs);
225 }
226 
230 template<typename TA, typename TB, typename DType, int ta, int tb>
232  TA, TB, DType, (ta | tb | type::kMapper)>
234  return ComplexF<op::complex::kBinaryCC, op::complex::div>(lhs, rhs);
235 }
236 
240 template<typename TA, typename TB, typename DType, int ta, int tb>
241 inline ComplexBinaryMapExp<op::complex::kBinaryCR, op::complex::div,
242  TA, TB, DType, (ta | tb | type::kMapper)>
244  return ComplexF<op::complex::kBinaryCR, op::complex::div>(lhs, rhs);
245 }
246 
250 template<typename TA, typename TB, typename DType, int ta, int tb>
251 inline ComplexBinaryMapExp<op::complex::kBinaryRC, op::complex::div,
252  TA, TB, DType, (ta | tb | type::kMapper)>
254  return ComplexF<op::complex::kBinaryRC, op::complex::div>(lhs, rhs);
255 }
256 
262 template<typename SrcExp, typename DType, int e1>
264  SrcExp, DType, (e1|type::kMapper)>
266  return ComplexF<op::complex::kUnitaryC2C, op::complex::conjugate>(src);
267 }
268 
274 template<typename SrcExp, typename DType, int e1>
276  SrcExp, DType, (e1|type::kMapper)>
278  return ComplexF<op::complex::kUnitaryC2C, op::complex::exchange>(src);
279 }
280 
286 template<typename SrcExp, typename DType, int e1>
288  SrcExp, DType, (e1|type::kMapper)>
290  return ComplexF<op::complex::kUnitaryR2C, op::complex::pad_imag>(src);
291 }
292 
298 template<typename SrcExp, typename DType, int e1>
300  SrcExp, DType, (e1 | type::kMapper)>
302  return ComplexF<op::complex::kUnitaryC2R, op::complex::toreal>(src);
303 }
304 
310 template<typename SrcExp, typename DType, int e1>
312  SrcExp, DType, (e1 | type::kMapper)>
314  return ComplexF<op::complex::kUnitaryC2R, op::complex::abs_square>(src);
315 }
316 
317 template<typename SrcExp, typename DType, int e1>
319  SrcExp, DType, (e1 | type::kMapper)>
321  return ComplexF<op::complex::kUnitaryC2R, op::complex::sum_real_imag>(src);
322 }
323 
324 template<int dim, int calctype, typename OP, typename TA, typename TB,
325  typename DType, int etype>
326 struct ShapeCheck<dim, ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> > {
327  inline static Shape<dim>
331  if (shape1[0] == 0) return shape2;
332  if (shape2[0] == 0) return shape1;
333  if (calctype == op::complex::kBinaryCC) {
334  CHECK_EQ(shape1, shape2) << "ComplexBinaryMapExp (CC): Shapes of operands are not the same.";
335  CHECK_EQ(shape1[dim - 1] % 2, 0) <<
336  "ComplexBinaryMapExp (CC): Shape of the last dimension is not even. "
337  "We must have real part + imaginary part.";
338  return shape1;
339  } else if (calctype == op::complex::kBinaryCR) {
340  for (int i = 0; i < dim - 1; ++i) {
341  CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) <<
342  "ComplexBinaryMapExp (CR): Shapes of operands are not the same.";
343  }
344  CHECK_EQ(shape1[dim - 1], shape2[dim - 1] * 2) <<
345  "ComplexBinaryMapExp (CR): Shapes of operands do not match.";
346  return shape1;
347  } else if (calctype == op::complex::kBinaryRC) {
348  for (int i = 0; i < dim - 1; ++i) {
349  CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) <<
350  "ComplexBinaryMapExp (RC): Shapes of operands are not the same.";
351  }
352  CHECK_EQ(shape2[dim - 1], shape1[dim - 1] * 2) <<
353  "ComplexBinaryMapExp (RC): Shapes of operands do not match.";
354  return shape2;
355  } else {
356  LOG(FATAL) << "ComplexBinaryMapExp: Unexpected Calculation Type!";
357  return shape1;
358  }
359  }
360 };
361 
362 template<int dim, int calctype, typename OP, typename TA, typename DType, int etype>
363 struct ShapeCheck<dim, ComplexUnitaryExp<calctype, OP, TA, DType, etype> > {
366  CHECK_EQ(s[dim - 1] % 2, 0) << "ComplexUnitaryExp: Shape of the last dimension is not even. "
367  "We must have real + imaginary.";
368  if (calctype == op::complex::kUnitaryC2C) {
369  return s;
370  } else if (calctype == op::complex::kUnitaryC2R) {
371  Shape<dim> s_ret = s;
372  s_ret[dim - 1] /= 2;
373  return s_ret;
374  } else if (calctype == op::complex::kUnitaryR2C) {
375  Shape<dim> s_ret = s;
376  s_ret[dim-1] *= 2;
377  return s_ret;
378  } else {
379  LOG(FATAL) << "ComplexUnitaryExp: Unexpected Calculation Type!";
380  return s;
381  }
382  }
383 };
384 
385 
386 
387 // complex binary expression (cc)
388 template<typename OP, typename TA, typename TB, int etype, typename DType>
389 class Plan<ComplexBinaryMapExp<op::complex::kBinaryCC, OP, TA, TB, DType, etype>, DType> {
390  public:
391  explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
392  : lhs_(lhs), rhs_(rhs) {}
393  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
394  const index_t base_x = static_cast<index_t>(x / 2) * 2;
395  if (x % 2 == 0) {
396  return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
397  rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
398  } else {
399  return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
400  rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
401  }
402  }
403 
404  private:
405  Plan<TA, DType> lhs_;
406  Plan<TB, DType> rhs_;
407 };
408 
409 // complex binary expression (cr)
410 template<typename OP, typename TA, typename TB, int etype, typename DType>
411 class Plan<ComplexBinaryMapExp<op::complex::kBinaryCR, OP, TA, TB, DType, etype>, DType> {
412  public:
413  explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
414  : lhs_(lhs), rhs_(rhs) {}
415  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
416  const index_t base_x = static_cast<index_t>(x / 2) * 2;
417  if (x % 2 == 0) {
418  return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
419  rhs_.Eval(y, base_x / 2), static_cast<DType>(0));
420  } else {
421  return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
422  rhs_.Eval(y, base_x / 2), static_cast<DType>(0));
423  }
424  }
425 
426  private:
427  Plan<TA, DType> lhs_;
428  Plan<TB, DType> rhs_;
429 };
430 
431 
432 // complex binary expression (rc)
433 template<typename OP, typename TA, typename TB, int etype, typename DType>
434 class Plan<ComplexBinaryMapExp<op::complex::kBinaryRC, OP, TA, TB, DType, etype>, DType> {
435  public:
436  explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
437  : lhs_(lhs), rhs_(rhs) {}
438  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
439  const index_t base_x = static_cast<index_t>(x / 2) * 2;
440  if (x % 2 == 0) {
441  return OP::RealMap(lhs_.Eval(y, base_x / 2), static_cast<DType>(0),
442  rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
443  } else {
444  return OP::ImagMap(lhs_.Eval(y, base_x / 2), static_cast<DType>(0),
445  rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
446  }
447  }
448 
449  private:
450  Plan<TA, DType> lhs_;
451  Plan<TB, DType> rhs_;
452 };
453 
454 
455 // complex unitary expression (c2c)
456 template<typename OP, typename TA, int etype, typename DType>
457 class Plan<ComplexUnitaryExp<op::complex::kUnitaryC2C, OP, TA, DType, etype>, DType> {
458  public:
459  explicit Plan(const Plan<TA, DType> &src) : src_(src) {}
460  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
461  const index_t base_x = static_cast<index_t>(x / 2) * 2;
462  if (0 == x % 2) {
463  return OP::RealMap(src_, y, base_x, y, base_x + 1);
464  } else {
465  return OP::ImagMap(src_, y, base_x, y, base_x + 1);
466  }
467  }
468 
469  private:
470  Plan<TA, DType> src_;
471 };
472 
473 // complex unitary expression (r2c)
474 template<typename OP, typename TA, int etype, typename DType>
475 class Plan<ComplexUnitaryExp<op::complex::kUnitaryR2C, OP, TA, DType, etype>, DType> {
476  public:
477  explicit Plan(const Plan<TA, DType> &src) : src_(src) {}
478  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
479  const index_t real_x = static_cast<index_t>(x / 2);
480  if (0 == x%2) {
481  // x,y should be coordinates in the complex matrix
482  // this defines how we will give value to the real part from the real matrix src_,
483  // thus the index has only 2 dimensions
484  return OP::RealMap(src_, y, real_x);
485  } else {
486  return OP::ImagMap(src_, y, real_x);
487  }
488  }
489 
490  private:
491  Plan<TA, DType> src_;
492 };
493 
494 // complex unitary expression (c2r)
495 template<typename OP, typename TA, int etype, typename DType>
496 class Plan<ComplexUnitaryExp<op::complex::kUnitaryC2R, OP, TA, DType, etype>, DType> {
497  public:
498  explicit Plan(const Plan<TA, DType> &src) : src_(src) {}
499  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
500  return OP::RealMap(src_, y, x * 2, y, x * 2 + 1);
501  }
502 
503  private:
504  Plan<TA, DType> src_;
505 };
506 
507 
508 
509 template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype>
512  return Plan<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>,
513  DType>(MakePlan(e.lhs_), MakePlan(e.rhs_));
514 }
515 
516 template<int calctype, typename OP, typename TA, typename DType, int etype>
519  return Plan<ComplexUnitaryExp<calctype, OP, TA, DType, etype>,
520  DType>(MakePlan(e.src_));
521 }
522 
523 
524 
525 template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype>
526 struct ExpInfo<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> > {
527  static const int kDimLhs = ExpInfo<TA>::kDim;
528  static const int kDimRhs = ExpInfo<TB>::kDim;
529  static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ? \
530  (kDimLhs == 0 ? \
531  kDimRhs : \
532  ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
533  static const int kDevMask = ExpInfo<TA>::kDevMask & ExpInfo<TB>::kDevMask;
534 };
535 
536 template<int calctype, typename OP, typename TA, typename DType, int etype>
537 struct ExpInfo<ComplexUnitaryExp<calctype, OP, TA, DType, etype> > {
538  static const int kDim = ExpInfo<TA>::kDim;
539  static const int kDevMask = ExpInfo<TA>::kDevMask;
540 };
541 
542 } // namespace expr
543 } // namespace mshadow
544 #endif // MSHADOW_EXTENSION_COMPLEX_H_
ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_rc Complex multiplication of a real tensor A and a complex tensor B
Definition: complex.h:223
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to result using defined operation
Definition: complex.h:53
Plan< ComplexUnitaryExp< calctype, OP, TA, DType, etype >, DType > MakePlan(const ComplexUnitaryExp< calctype, OP, TA, DType, etype > &e)
Definition: complex.h:518
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:84
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:71
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:393
Definition: complex.h:34
const int kMapper
expression contains element-wise tensor operations, mapping an expression to one of the same shape ...
Definition: expression.h:51
ComplexUnitaryExp(const TA &src)
constructor
Definition: complex.h:174
Definition: complex.h:34
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::abs_square, SrcExp, DType,(e1|type::kMapper)> complex_abs_square(const Exp< SrcExp, DType, e1 > &src)
complex_abs_square calculates the squared modulus of A, where A is a complex tensor ...
Definition: complex.h:313
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:126
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
evaluate the expression at index [y][x] to be implemented by SubType, for RValue, the return type wil...
UnitaryCalculationType
Definition: complex.h:35
const TB & rhs_
right operand
Definition: complex.h:154
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:499
ComplexUnitaryExp< calctype, OP, SrcExp, DType,(e1|type::kMapper)> ComplexF(const Exp< SrcExp, DType, e1 > &src)
ComplexF builds a unary complex expression that applies operator OP to src with calculation type calctype
Definition: complex.h:193
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:107
ComplexUnitaryExp< op::complex::kUnitaryR2C, op::complex::pad_imag, SrcExp, DType,(e1|type::kMapper)> complex_pad_imag(const Exp< SrcExp, DType, e1 > &src)
complex_pad_imag Transform real matrix into complex matrix
Definition: complex.h:289
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:98
Definition: complex.h:36
static Shape< dim > Check(const ComplexBinaryMapExp< calctype, OP, TA, TB, DType, etype > &t)
Definition: complex.h:328
Definition: complex.h:34
BinaryCalculationType
Definition: complex.h:34
Definition: complex.h:124
const TA & lhs_
left operand
Definition: complex.h:152
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cc Complex multiplication of two complex tensors, A * B
Definition: complex.h:203
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::conjugate, SrcExp, DType,(e1|type::kMapper)> conj(const Exp< SrcExp, DType, e1 > &src)
conj Negates the imaginary part of A, where A is a complex tensor
Definition: complex.h:265
#define MSHADOW_XINLINE
Definition: base.h:223
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:263
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:438
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:460
const TA & src_
source expression
Definition: complex.h:172
int32_t index_t
type that will be used for index
Definition: base.h:336
static Shape< dim > Check(const ComplexUnitaryExp< calctype, OP, TA, DType, etype > &t)
Definition: complex.h:364
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:93
Definition: complex.h:35
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_cc Complex division of two complex tensors, A / B
Definition: complex.h:233
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to result using defined operation
Definition: complex.h:39
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:478
Definition: complex.h:35
Definition: complex.h:91
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:436
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:76
runtime shape checking template get the shape of an expression, report error if shape mismatch ...
Definition: expr_engine-inl.h:365
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:66
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:116
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::exchange, SrcExp, DType,(e1|type::kMapper)> complex_exchange(const Exp< SrcExp, DType, e1 > &src)
complex_exchange Exchange the real and imaginary part of A where A is a complex tensor ...
Definition: complex.h:277
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:58
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:391
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cr Complex multiplication of a complex tensor A and a real tensor B
Definition: complex.h:213
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:413
Definition: complex.h:105
ComplexBinaryMapExp(const TA &lhs, const TB &rhs)
constructor
Definition: complex.h:156
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::sum_real_imag, SrcExp, DType,(e1|type::kMapper)> complex_sum_real_imag(const Exp< SrcExp, DType, e1 > &src)
Definition: complex.h:320
const SubType & self(void) const
Definition: expression.h:83
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:44
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::toreal, SrcExp, DType,(e1|type::kMapper)> complex_toreal(const Exp< SrcExp, DType, e1 > &src)
complex_toreal convert complex matrix to real matrix, keep only real part
Definition: complex.h:301
Definition: complex.h:64
overloaded + operator between half_t and bf16_t
Definition: base.h:327
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:415
Definition: complex.h:77
Definition: complex.h:35
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:79
ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_rc Complex division of a real tensor A by a complex tensor B
Definition: complex.h:253
Definition: complex.h:114
unary expression that applies a complex operator OP to src (C2C, C2R or R2C, selected by calctype)
Definition: complex.h:169
Definition: complex.h:50
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_cr Complex division of a complex tensor A by a real tensor B
Definition: complex.h:243
binary map expression lhs [op] rhs where lhs and/or rhs are complex tensors (CC, CR or RC, selected by calctype)
Definition: complex.h:149