mxnet
complex.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_EXTENSION_COMPLEX_H_
26 #define MSHADOW_EXTENSION_COMPLEX_H_
27 #include <algorithm>
28 #include "../extension.h"
29 
30 namespace mshadow {
31 namespace op {
32 namespace complex {
35 struct mul {
37  template<typename DType>
38  MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag,
39  DType b_real, DType b_imag) {
40  return a_real * b_real - a_imag * b_imag;
41  }
42  template<typename DType>
43  MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag,
44  DType b_real, DType b_imag) {
45  return a_real * b_imag + b_real * a_imag;
46  }
47 };
48 
49 struct div {
51  template<typename DType>
52  MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag,
53  DType b_real, DType b_imag) {
54  return (a_real * b_real + a_imag * b_imag) / (b_real * b_real + b_imag * b_imag);
55  }
56  template<typename DType>
57  MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag,
58  DType b_real, DType b_imag) {
59  return (b_real * a_imag - a_real * b_imag) / (b_real * b_real + b_imag * b_imag);
60  }
61 };
62 
63 struct conjugate {
64  template<typename TA, typename DType>
65  MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_,
66  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
67  return src_.Eval(real_i, real_j);
68  }
69  template<typename TA, typename DType>
70  MSHADOW_XINLINE static DType ImagMap(const expr::Plan<TA, DType> &src_,
71  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
72  return -src_.Eval(imag_i, imag_j);
73  }
74 };
75 
76 struct exchange {
77  template<typename TA, typename DType>
78  MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_,
79  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
80  return src_.Eval(imag_i, imag_j);
81  }
82  template<typename TA, typename DType>
83  MSHADOW_XINLINE static DType ImagMap(const expr::Plan<TA, DType> &src_,
84  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
85  return src_.Eval(real_i, real_j);
86  }
87 };
88 
89 // r2c operator
90 struct pad_imag {
91  template<typename TA, typename DType>
92  MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_,
93  index_t real_i, index_t real_j) {
94  return src_.Eval(real_i, real_j);
95  }
96  template<typename TA, typename DType>
97  MSHADOW_XINLINE static DType ImagMap(const expr::Plan<TA, DType> &src_,
98  index_t real_i, index_t real_j) {
99  return 0;
100  }
101 };
102 
103 // c2r operator
104 struct toreal {
105  template<typename TA, typename DType>
107  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
108  DType real_val = src_.Eval(real_i, real_j);
109  return real_val;
110  }
111 };
112 
113 struct abs_square {
114  template<typename TA, typename DType>
116  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
117  DType real_val = src_.Eval(real_i, real_j);
118  DType image_val = src_.Eval(imag_i, imag_j);
119  return real_val * real_val + image_val * image_val;
120  }
121 };
122 
124  template<typename TA, typename DType>
126  index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
127  DType real_val = src_.Eval(real_i, real_j);
128  DType image_val = src_.Eval(imag_i, imag_j);
129  return real_val + image_val;
130  }
131 };
132 } // namespace complex
133 } // namespace op
134 
135 namespace expr {
136 //--------------------
137 // ComplexBinaryMapExp
138 //--------------------
147 template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype>
148 struct ComplexBinaryMapExp : public Exp<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>,
149  DType, etype> {
151  const TA &lhs_;
153  const TB &rhs_;
155  explicit ComplexBinaryMapExp(const TA &lhs, const TB &rhs)
156  :lhs_(lhs), rhs_(rhs) {}
157 };
158 
159 //-------------------
160 // ComplexConjExp
161 //-------------------
167 template<int calctype, typename OP, typename TA, typename DType, int etype>
168 struct ComplexUnitaryExp : public Exp<ComplexUnitaryExp<calctype, OP, TA, DType, etype>,
169  DType, etype> {
171  const TA &src_;
173  explicit ComplexUnitaryExp(const TA &src) : src_(src) {}
174 };
175 
176 
177 
178 template<int calctype, typename OP, typename TA, typename TB, typename DType, int ta, int tb>
179 inline ComplexBinaryMapExp<calctype, OP, TA, TB, DType, (ta | tb | type::kMapper)>
181  return ComplexBinaryMapExp<calctype, OP, TA, TB, DType,
182  (ta | tb | type::kMapper)>(lhs.self(), rhs.self());
183 }
184 
190 template<int calctype, typename OP, typename SrcExp, typename DType, int e1>
191 inline ComplexUnitaryExp<calctype, OP, SrcExp, DType, (e1 | type::kMapper)>
194 }
195 
199 template<typename TA, typename TB, typename DType, int ta, int tb>
200 inline ComplexBinaryMapExp<op::complex::kBinaryCC, op::complex::mul,
201  TA, TB, DType, (ta | tb | type::kMapper)>
203  return ComplexF<op::complex::kBinaryCC, op::complex::mul>(lhs, rhs);
204 }
205 
209 template<typename TA, typename TB, typename DType, int ta, int tb>
210 inline ComplexBinaryMapExp<op::complex::kBinaryCR, op::complex::mul,
211  TA, TB, DType, (ta | tb | type::kMapper)>
213  return ComplexF<op::complex::kBinaryCR, op::complex::mul>(lhs, rhs);
214 }
215 
219 template<typename TA, typename TB, typename DType, int ta, int tb>
220 inline ComplexBinaryMapExp<op::complex::kBinaryRC, op::complex::mul,
221  TA, TB, DType, (ta | tb | type::kMapper)>
223  return ComplexF<op::complex::kBinaryRC, op::complex::mul>(lhs, rhs);
224 }
225 
229 template<typename TA, typename TB, typename DType, int ta, int tb>
230 inline ComplexBinaryMapExp<op::complex::kBinaryCC, op::complex::div,
231  TA, TB, DType, (ta | tb | type::kMapper)>
233  return ComplexF<op::complex::kBinaryCC, op::complex::div>(lhs, rhs);
234 }
235 
239 template<typename TA, typename TB, typename DType, int ta, int tb>
240 inline ComplexBinaryMapExp<op::complex::kBinaryCR, op::complex::div,
241  TA, TB, DType, (ta | tb | type::kMapper)>
243  return ComplexF<op::complex::kBinaryCR, op::complex::div>(lhs, rhs);
244 }
245 
249 template<typename TA, typename TB, typename DType, int ta, int tb>
250 inline ComplexBinaryMapExp<op::complex::kBinaryRC, op::complex::div,
251  TA, TB, DType, (ta | tb | type::kMapper)>
253  return ComplexF<op::complex::kBinaryRC, op::complex::div>(lhs, rhs);
254 }
255 
261 template<typename SrcExp, typename DType, int e1>
262 inline ComplexUnitaryExp<op::complex::kUnitaryC2C, op::complex::conjugate,
263  SrcExp, DType, (e1|type::kMapper)>
265  return ComplexF<op::complex::kUnitaryC2C, op::complex::conjugate>(src);
266 }
267 
273 template<typename SrcExp, typename DType, int e1>
274 inline ComplexUnitaryExp<op::complex::kUnitaryC2C, op::complex::exchange,
275  SrcExp, DType, (e1|type::kMapper)>
277  return ComplexF<op::complex::kUnitaryC2C, op::complex::exchange>(src);
278 }
279 
285 template<typename SrcExp, typename DType, int e1>
286 inline ComplexUnitaryExp<op::complex::kUnitaryR2C, op::complex::pad_imag,
287  SrcExp, DType, (e1|type::kMapper)>
289  return ComplexF<op::complex::kUnitaryR2C, op::complex::pad_imag>(src);
290 }
291 
297 template<typename SrcExp, typename DType, int e1>
298 inline ComplexUnitaryExp<op::complex::kUnitaryC2R, op::complex::toreal,
299  SrcExp, DType, (e1 | type::kMapper)>
301  return ComplexF<op::complex::kUnitaryC2R, op::complex::toreal>(src);
302 }
303 
309 template<typename SrcExp, typename DType, int e1>
310 inline ComplexUnitaryExp<op::complex::kUnitaryC2R, op::complex::abs_square,
311  SrcExp, DType, (e1 | type::kMapper)>
313  return ComplexF<op::complex::kUnitaryC2R, op::complex::abs_square>(src);
314 }
315 
316 template<typename SrcExp, typename DType, int e1>
317 inline ComplexUnitaryExp<op::complex::kUnitaryC2R, op::complex::sum_real_imag,
318  SrcExp, DType, (e1 | type::kMapper)>
320  return ComplexF<op::complex::kUnitaryC2R, op::complex::sum_real_imag>(src);
321 }
322 
323 template<int dim, int calctype, typename OP, typename TA, typename TB,
324  typename DType, int etype>
325 struct ShapeCheck<dim, ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> > {
326  inline static Shape<dim>
330  if (shape1[0] == 0) return shape2;
331  if (shape2[0] == 0) return shape1;
332  if (calctype == op::complex::kBinaryCC) {
333  CHECK_EQ(shape1, shape2) << "ComplexBinaryMapExp (CC): Shapes of operands are not the same.";
334  CHECK_EQ(shape1[dim - 1] % 2, 0) <<
335  "ComplexBinaryMapExp (CC): Shape of the last dimension is not even. "
336  "We must have real part + imaginary part.";
337  return shape1;
338  } else if (calctype == op::complex::kBinaryCR) {
339  for (int i = 0; i < dim - 1; ++i) {
340  CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) <<
341  "ComplexBinaryMapExp (CR): Shapes of operands are not the same.";
342  }
343  CHECK_EQ(shape1[dim - 1], shape2[dim - 1] * 2) <<
344  "ComplexBinaryMapExp (CR): Shapes of operands do not match.";
345  return shape1;
346  } else if (calctype == op::complex::kBinaryRC) {
347  for (int i = 0; i < dim - 1; ++i) {
348  CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) <<
349  "ComplexBinaryMapExp (RC): Shapes of operands are not the same.";
350  }
351  CHECK_EQ(shape2[dim - 1], shape1[dim - 1] * 2) <<
352  "ComplexBinaryMapExp (RC): Shapes of operands do not match.";
353  return shape2;
354  } else {
355  LOG(FATAL) << "ComplexBinaryMapExp: Unexpected Calculation Type!";
356  return shape1;
357  }
358  }
359 };
360 
361 template<int dim, int calctype, typename OP, typename TA, typename DType, int etype>
362 struct ShapeCheck<dim, ComplexUnitaryExp<calctype, OP, TA, DType, etype> > {
365  CHECK_EQ(s[dim - 1] % 2, 0) << "ComplexUnitaryExp: Shape of the last dimension is not even. "
366  "We must have real + imaginary.";
367  if (calctype == op::complex::kUnitaryC2C) {
368  return s;
369  } else if (calctype == op::complex::kUnitaryC2R) {
370  Shape<dim> s_ret = s;
371  s_ret[dim - 1] /= 2;
372  return s_ret;
373  } else if (calctype == op::complex::kUnitaryR2C) {
374  Shape<dim> s_ret = s;
375  s_ret[dim-1] *= 2;
376  return s_ret;
377  } else {
378  LOG(FATAL) << "ComplexUnitaryExp: Unexpected Calculation Type!";
379  return s;
380  }
381  }
382 };
383 
384 
385 
386 // complex binary expression (cc)
387 template<typename OP, typename TA, typename TB, int etype, typename DType>
388 class Plan<ComplexBinaryMapExp<op::complex::kBinaryCC, OP, TA, TB, DType, etype>, DType> {
389  public:
390  explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
391  : lhs_(lhs), rhs_(rhs) {}
392  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
393  const index_t base_x = static_cast<index_t>(x / 2) * 2;
394  if (x % 2 == 0) {
395  return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
396  rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
397  } else {
398  return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
399  rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
400  }
401  }
402 
403  private:
404  Plan<TA, DType> lhs_;
405  Plan<TB, DType> rhs_;
406 };
407 
408 // complex binary expression (cr)
409 template<typename OP, typename TA, typename TB, int etype, typename DType>
410 class Plan<ComplexBinaryMapExp<op::complex::kBinaryCR, OP, TA, TB, DType, etype>, DType> {
411  public:
412  explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
413  : lhs_(lhs), rhs_(rhs) {}
414  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
415  const index_t base_x = static_cast<index_t>(x / 2) * 2;
416  if (x % 2 == 0) {
417  return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
418  rhs_.Eval(y, base_x / 2), static_cast<DType>(0));
419  } else {
420  return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
421  rhs_.Eval(y, base_x / 2), static_cast<DType>(0));
422  }
423  }
424 
425  private:
426  Plan<TA, DType> lhs_;
427  Plan<TB, DType> rhs_;
428 };
429 
430 
431 // complex binary expression (rc)
432 template<typename OP, typename TA, typename TB, int etype, typename DType>
433 class Plan<ComplexBinaryMapExp<op::complex::kBinaryRC, OP, TA, TB, DType, etype>, DType> {
434  public:
435  explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
436  : lhs_(lhs), rhs_(rhs) {}
437  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
438  const index_t base_x = static_cast<index_t>(x / 2) * 2;
439  if (x % 2 == 0) {
440  return OP::RealMap(lhs_.Eval(y, base_x / 2), static_cast<DType>(0),
441  rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
442  } else {
443  return OP::ImagMap(lhs_.Eval(y, base_x / 2), static_cast<DType>(0),
444  rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
445  }
446  }
447 
448  private:
449  Plan<TA, DType> lhs_;
450  Plan<TB, DType> rhs_;
451 };
452 
453 
454 // complex unitary expression (c2c)
455 template<typename OP, typename TA, int etype, typename DType>
456 class Plan<ComplexUnitaryExp<op::complex::kUnitaryC2C, OP, TA, DType, etype>, DType> {
457  public:
458  explicit Plan(const Plan<TA, DType> &src) : src_(src) {}
459  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
460  const index_t base_x = static_cast<index_t>(x / 2) * 2;
461  if (0 == x % 2) {
462  return OP::RealMap(src_, y, base_x, y, base_x + 1);
463  } else {
464  return OP::ImagMap(src_, y, base_x, y, base_x + 1);
465  }
466  }
467 
468  private:
469  Plan<TA, DType> src_;
470 };
471 
472 // complex unitary expression (r2c)
473 template<typename OP, typename TA, int etype, typename DType>
474 class Plan<ComplexUnitaryExp<op::complex::kUnitaryR2C, OP, TA, DType, etype>, DType> {
475  public:
476  explicit Plan(const Plan<TA, DType> &src) : src_(src) {}
477  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
478  const index_t real_x = static_cast<index_t>(x / 2);
479  if (0 == x%2) {
480  // x,y should be coordinates in the complex matrix
481  // this defines how we will give value to the real part from the real matrix src_,
482  // thus the index has only 2 dimensions
483  return OP::RealMap(src_, y, real_x);
484  } else {
485  return OP::ImagMap(src_, y, real_x);
486  }
487  }
488 
489  private:
490  Plan<TA, DType> src_;
491 };
492 
493 // complex unitary expression (c2r)
494 template<typename OP, typename TA, int etype, typename DType>
495 class Plan<ComplexUnitaryExp<op::complex::kUnitaryC2R, OP, TA, DType, etype>, DType> {
496  public:
497  explicit Plan(const Plan<TA, DType> &src) : src_(src) {}
498  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
499  return OP::RealMap(src_, y, x * 2, y, x * 2 + 1);
500  }
501 
502  private:
503  Plan<TA, DType> src_;
504 };
505 
506 
507 
508 template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype>
509 inline Plan<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>, DType>
512  DType>(MakePlan(e.lhs_), MakePlan(e.rhs_));
513 }
514 
515 template<int calctype, typename OP, typename TA, typename DType, int etype>
516 inline Plan<ComplexUnitaryExp<calctype, OP, TA, DType, etype>, DType>
519  DType>(MakePlan(e.src_));
520 }
521 
522 
523 
524 template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype>
525 struct ExpInfo<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> > {
526  static const int kDimLhs = ExpInfo<TA>::kDim;
527  static const int kDimRhs = ExpInfo<TB>::kDim;
528  static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ? \
529  (kDimLhs == 0 ? \
530  kDimRhs : \
531  ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
533 };
534 
535 template<int calctype, typename OP, typename TA, typename DType, int etype>
536 struct ExpInfo<ComplexUnitaryExp<calctype, OP, TA, DType, etype> > {
537  static const int kDim = ExpInfo<TA>::kDim;
538  static const int kDevMask = ExpInfo<TA>::kDevMask;
539 };
540 
541 } // namespace expr
542 } // namespace mshadow
543 #endif // MSHADOW_EXTENSION_COMPLEX_H_
mshadow::op::complex::abs_square::RealMap
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:115
mshadow::expr::Plan< ComplexUnitaryExp< op::complex::kUnitaryC2C, OP, TA, DType, etype >, DType >::Plan
Plan(const Plan< TA, DType > &src)
Definition: complex.h:458
mshadow::expr::ExpInfo::kDevMask
static const int kDevMask
Definition: expr_engine-inl.h:264
mshadow::op::complex::kBinaryCC
@ kBinaryCC
Definition: complex.h:33
mshadow::expr::Exp::self
const SubType & self(void) const
Definition: expression.h:82
mshadow::op::complex::kUnitaryC2R
@ kUnitaryC2R
Definition: complex.h:34
mshadow::op::complex::abs_square
Definition: complex.h:113
mshadow::op::complex::conjugate
Definition: complex.h:63
mshadow::op::complex::kUnitaryC2C
@ kUnitaryC2C
Definition: complex.h:34
mshadow::op::complex::pad_imag::RealMap
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:92
mshadow::op::complex::pad_imag
Definition: complex.h:90
mshadow::expr::complex_mul_cr
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cr Complex multiplication of a complex tensor A and a real tensor B
Definition: complex.h:212
mshadow::expr::complex_div_rc
ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_rc Complex division of a real tensor A by a complex tensor B, A / B
Definition: complex.h:252
mshadow::expr::ComplexBinaryMapExp
binary map expression lhs [op] rhs where at least one of lhs and rhs is a complex tensor (layout chosen by calctype: CC, CR or RC)
Definition: complex.h:148
mshadow::expr::ComplexF
ComplexBinaryMapExp< calctype, OP, TA, TB, DType,(ta|tb|type::kMapper)> ComplexF(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
Definition: complex.h:180
mshadow::expr::ComplexBinaryMapExp::lhs_
const TA & lhs_
left operand
Definition: complex.h:151
mshadow::expr::complex_div_cc
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_cc Complex division of two complex tensors, A / B
Definition: complex.h:232
mshadow::expr::Plan< ComplexUnitaryExp< op::complex::kUnitaryC2R, OP, TA, DType, etype >, DType >::Plan
Plan(const Plan< TA, DType > &src)
Definition: complex.h:497
mshadow::expr::Plan< ComplexUnitaryExp< op::complex::kUnitaryC2R, OP, TA, DType, etype >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:498
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mshadow::expr::Plan< ComplexBinaryMapExp< op::complex::kBinaryCR, OP, TA, TB, DType, etype >, DType >::Plan
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:412
mshadow::op::complex::toreal::RealMap
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:106
mshadow::expr::ComplexUnitaryExp
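unary complex map expression OP(src), used by conj, complex_exchange, complex_pad_imag, complex_toreal, complex_abs_square and complex_sum_real_imag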
Definition: complex.h:168
mshadow::expr::complex_toreal
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::toreal, SrcExp, DType,(e1|type::kMapper)> complex_toreal(const Exp< SrcExp, DType, e1 > &src)
complex_toreal convert complex matrix to real matrix, keep only real part
Definition: complex.h:300
mshadow::op::complex::conjugate::RealMap
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:65
mshadow::expr::ShapeCheck
runtime shape checking template get the shape of an expression, report error if shape mismatch
Definition: expr_engine-inl.h:364
mshadow::op::complex::sum_real_imag::RealMap
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:125
mshadow::expr::type::kMapper
const int kMapper
expression contains element-wise tensor operations, map an expression to the same shape
Definition: expression.h:50
mshadow::expr::Plan< ComplexUnitaryExp< op::complex::kUnitaryR2C, OP, TA, DType, etype >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:477
mshadow::expr::ShapeCheck::Check
static Shape< dim > Check(const E &t)
mshadow::expr::ExpInfo
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:262
mshadow::op::complex::toreal
Definition: complex.h:104
mshadow::expr::ShapeCheck< dim, ComplexBinaryMapExp< calctype, OP, TA, TB, DType, etype > >::Check
static Shape< dim > Check(const ComplexBinaryMapExp< calctype, OP, TA, TB, DType, etype > &t)
Definition: complex.h:327
mshadow::expr::MakePlan
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
mshadow::op::complex::div
Definition: complex.h:49
mshadow::op::complex::conjugate::ImagMap
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:70
mshadow::Shape::shape_
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:86
mshadow::op::complex::kBinaryRC
@ kBinaryRC
Definition: complex.h:33
mshadow::expr::Plan< ComplexBinaryMapExp< op::complex::kBinaryRC, OP, TA, TB, DType, etype >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:437
mshadow::expr::conj
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::conjugate, SrcExp, DType,(e1|type::kMapper)> conj(const Exp< SrcExp, DType, e1 > &src)
conj Negate the imaginary part of A where A is a complex tensor
Definition: complex.h:264
mshadow::expr::Plan< ComplexUnitaryExp< op::complex::kUnitaryC2C, OP, TA, DType, etype >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:459
mshadow::expr::ExpInfo::kDim
static const int kDim
Definition: expr_engine-inl.h:263
mshadow::expr::complex_div_cr
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_cr Complex division of a complex tensor A by a real tensor B, A / B
Definition: complex.h:242
mshadow::expr::Plan::Eval
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
evaluate the expression at index [y][x] to be implemented by SubType, for RValue, the return type wil...
mshadow::expr::Plan< ComplexBinaryMapExp< op::complex::kBinaryRC, OP, TA, TB, DType, etype >, DType >::Plan
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:435
mshadow::op::complex::exchange
Definition: complex.h:76
mshadow::expr::complex_mul_cc
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cc Complex multiplication of two complex tensors, A * B
Definition: complex.h:202
mshadow::op::complex::mul::RealMap
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to result using defined operation
Definition: complex.h:38
mshadow::index_t
int32_t index_t
type that will be used for index
Definition: base.h:328
mshadow::expr::ComplexBinaryMapExp::rhs_
const TB & rhs_
right operand
Definition: complex.h:153
mshadow::expr::Plan< TA, DType >
mshadow::expr::complex_sum_real_imag
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::sum_real_imag, SrcExp, DType,(e1|type::kMapper)> complex_sum_real_imag(const Exp< SrcExp, DType, e1 > &src)
Definition: complex.h:319
mshadow::expr::ShapeCheck< dim, ComplexUnitaryExp< calctype, OP, TA, DType, etype > >::Check
static Shape< dim > Check(const ComplexUnitaryExp< calctype, OP, TA, DType, etype > &t)
Definition: complex.h:363
mshadow::expr::Plan< ComplexBinaryMapExp< op::complex::kBinaryCC, OP, TA, TB, DType, etype >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:392
mshadow::expr::Exp
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
mshadow::op::complex::div::ImagMap
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:57
mshadow::op::complex::div::RealMap
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to result using defined operation
Definition: complex.h:52
mshadow::op::complex::kBinaryCR
@ kBinaryCR
Definition: complex.h:33
mshadow::expr::complex_mul_rc
ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_rc Complex multiplication of a real tensor A and a complex tensor B
Definition: complex.h:222
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::op::complex::UnitaryCalculationType
UnitaryCalculationType
Definition: complex.h:34
mshadow::expr::complex_pad_imag
ComplexUnitaryExp< op::complex::kUnitaryR2C, op::complex::pad_imag, SrcExp, DType,(e1|type::kMapper)> complex_pad_imag(const Exp< SrcExp, DType, e1 > &src)
complex_pad_imag Transform real matrix into complex matrix
Definition: complex.h:288
mshadow::op::complex::exchange::RealMap
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:78
mshadow::Shape< dim >
mshadow::expr::Plan< ComplexUnitaryExp< op::complex::kUnitaryR2C, OP, TA, DType, etype >, DType >::Plan
Plan(const Plan< TA, DType > &src)
Definition: complex.h:476
mshadow::expr::Plan< ComplexBinaryMapExp< op::complex::kBinaryCC, OP, TA, TB, DType, etype >, DType >::Plan
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:390
mshadow::op::complex::pad_imag::ImagMap
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:97
mshadow::op::complex::mul
Definition: complex.h:35
mshadow::op::complex::sum_real_imag
Definition: complex.h:123
mshadow::op::complex::BinaryCalculationType
BinaryCalculationType
Definition: complex.h:33
mshadow::expr::ComplexBinaryMapExp::ComplexBinaryMapExp
ComplexBinaryMapExp(const TA &lhs, const TB &rhs)
constructor
Definition: complex.h:155
mshadow::expr::complex_exchange
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::exchange, SrcExp, DType,(e1|type::kMapper)> complex_exchange(const Exp< SrcExp, DType, e1 > &src)
complex_exchange Exchange the real and imaginary part of A where A is a complex tensor
Definition: complex.h:276
mshadow::expr::ComplexUnitaryExp::src_
const TA & src_
source expression
Definition: complex.h:171
mshadow::expr::ComplexUnitaryExp::ComplexUnitaryExp
ComplexUnitaryExp(const TA &src)
constructor
Definition: complex.h:173
mshadow::op::complex::mul::ImagMap
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:43
mshadow::op::complex::exchange::ImagMap
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:83
mshadow::expr::Plan< ComplexBinaryMapExp< op::complex::kBinaryCR, OP, TA, TB, DType, etype >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:414
mshadow::op::complex::kUnitaryR2C
@ kUnitaryR2C
Definition: complex.h:34
mshadow::expr::complex_abs_square
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::abs_square, SrcExp, DType,(e1|type::kMapper)> complex_abs_square(const Exp< SrcExp, DType, e1 > &src)
complex_abs_square calculate the square of the modulus of A where A is a complex tensor
Definition: complex.h:312