mxnet
reducer-inl.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_COMMON_CUDA_RTC_REDUCER_INL_H_
21 #define MXNET_COMMON_CUDA_RTC_REDUCER_INL_H_
22 
23 #if MXNET_USE_CUDA
24 
25 namespace mxnet {
26 namespace common {
27 namespace cuda {
28 namespace rtc {
29 
30 const char reducer[] = R"code(
31 namespace red {
32 
33 struct sum {
35  template<typename DType, typename DType2>
36  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
37  dst = op::add(dst, src);
38  }
39 
41  template<typename DType, typename DType2>
42  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
43  volatile DType& residual) {
44  DType y = op::sub(src, residual);
45  DType t = dst + y;
46  if (util::isinf(t)) {
47  residual = 0;
48  } else {
49  residual = (t - dst) - y;
50  }
51  dst = t;
52  }
54  template<typename DType>
55  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
56  Reduce(dst_val, src_val);
57  }
59  template<typename DType>
60  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
61  volatile DType& src_val, volatile DType& src_residual) {
62  DType t1 = dst_val + src_val;
63  if (util::isinf(t1)) {
64  dst_val = t1;
65  dst_residual = 0;
66  } else {
67  DType e = t1 - dst_val;
68  DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
69  dst_val = t1 + t2;
70  dst_residual = t2 - (dst_val - t1);
71  }
72  }
74  template<typename DType>
75  __device__ inline static void Finalize(volatile DType& dst) {}
77  template<typename DType>
78  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
82  template<typename DType>
83  __device__ inline static void SetInitValue(DType &initv) {
84  initv = 0;
85  }
89  template<typename DType>
90  __device__ inline static void SetInitValue(DType &initv, DType &residual) {
91  SetInitValue(initv);
92  residual = 0;
93  }
94 };
95 
96 struct product {
98  template<typename DType, typename DType2>
99  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
100  dst = op::mul(dst, src);
101  }
103  template<typename DType, typename DType2>
104  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
105  volatile DType& none) {
106  Reduce(dst, src);
107  }
109  template<typename DType>
110  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
111  Reduce(dst_val, src_val);
112  }
114  template<typename DType>
115  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
116  volatile DType& src_val, volatile DType& src_residual) {
117  Reduce(dst_val, src_val);
118  }
120  template<typename DType>
121  __device__ inline static void Finalize(volatile DType& dst) {}
123  template<typename DType>
124  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
128  template<typename DType>
129  __device__ inline static void SetInitValue(DType &initv) {
130  initv = 1;
131  }
135  template<typename DType>
136  __device__ inline static void SetInitValue(DType &initv, DType &none) {
137  SetInitValue(initv);
138  }
139 };
140 
141 struct nansum {
143  template<typename DType, typename DType2>
144  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
145  if (util::isnan(src)) return;
146  dst = op::add(dst, src);
147  }
149  template<typename DType>
150  __device__ inline static void Reduce(volatile DType& dst, volatile DType src,
151  volatile DType& residual) {
152  if (util::isnan(src)) return;
153  DType y = src - residual;
154  DType t = dst + y;
155  residual = (t - dst) - y;
156  dst = t;
157  }
159  template<typename DType>
160  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
161  Reduce(dst_val, src_val);
162  }
164  template<typename DType>
165  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
166  volatile DType& src_val, volatile DType& src_residual) {
167  DType t1 = dst_val + src_val;
168  DType e = t1 - src_val;
169  DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
170  dst_val = t1 + t2;
171  dst_residual = t2 - (dst_val - t1);
172  }
174  template<typename DType>
175  __device__ inline static void Finalize(volatile DType& dst) {}
177  template<typename DType>
178  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
182  template<typename DType>
183  __device__ inline static void SetInitValue(DType & initv) {
184  initv = 0;
185  }
189  template<typename DType>
190  __device__ inline static void SetInitValue(DType &initv, DType &residual) {
191  SetInitValue(initv);
192  residual = 0;
193  }
194 };
195 
196 struct nanprod {
198  template<typename DType, typename DType2>
199  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
200  if (util::isnan(src)) return;
201  dst = op::mul(dst, src);
202  }
204  template<typename DType>
205  __device__ inline static void Reduce(volatile DType& dst, volatile DType src,
206  volatile DType& none) {
207  Reduce(dst, src);
208  }
210  template<typename DType>
211  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
212  Reduce(dst_val, src_val);
213  }
215  template<typename DType>
216  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
217  volatile DType& src_val, volatile DType& src_residual) {
218  Reduce(dst_val, src_val);
219  }
221  template<typename DType>
222  __device__ inline static void Finalize(volatile DType& dst) {}
224  template<typename DType>
225  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
229  template<typename DType>
230  __device__ inline static void SetInitValue(DType & initv) {
231  initv = 1;
232  }
236  template<typename DType>
237  __device__ inline static void SetInitValue(DType &initv, DType &none) {
238  SetInitValue(initv);
239  }
240 };
241 
242 struct nrm2 {
244  template<typename AType, typename DType>
245  __device__ inline static void Reduce(volatile AType& sum_of_squares, volatile DType src) {
246  sum_of_squares = op::add(sum_of_square, src * src);
247  }
249  template<typename AType, typename DType>
250  __device__ inline static void Reduce(volatile AType& sum_of_squares,
251  volatile DType src, volatile DType& scale) {
252  if (src != 0) {
253  DType abs = op::abs(src);
254  if (scale < abs) {
255  sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs);
256  scale = abs;
257  } else {
258  sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale);
259  }
260  }
261  }
263  template<typename DType>
264  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
265  dst_val = op::add(dst_val, src_val);
266  }
268  template<typename DType>
269  __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale,
270  volatile DType& src_ssq, volatile DType& src_scale) {
271  if (dst_scale != 0 && dst_scale >= src_scale) {
272  dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale);
273  } else if (src_scale != 0 && dst_scale < src_scale) {
274  dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale);
275  dst_scale = src_scale;
276  }
277  }
279  template<typename DType>
280  __device__ inline static void Finalize(volatile DType& sum_of_squares) {
281  sum_of_squares = op::sqrt(sum_of_squares);
282  }
284  template<typename DType>
285  __device__ inline static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) {
286  sum_of_squares = scale * op::sqrt(sum_of_squares);
287  }
291  template<typename DType>
292  __device__ inline static void SetInitValue(DType &sum_of_squares) {
293  sum_of_squares = 0;
294  }
298  template<typename DType>
299  __device__ inline static void SetInitValue(DType &sum_of_squares, DType &scale) {
300  SetInitValue(sum_of_squares);
301  scale = 0;
302  }
303 };
304 
305 struct nrmlp {
306  double lp;
307  /* \brief power for Lp norm */
308  __device__ inline static double lp_power(volatile double src, volatile double p) {
309  if (p != 0.0) {
310  if (src == 0.0) {
311  return src;
312  } else {
313  return op::power(src, p);
314  }
315  } else { // 0-norm, sparsity
316  return static_cast<double>(src != 0);
317  }
318  }
319 
321  template<typename AType, typename DType>
322  __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType src) {
323  if (src != 0) {
324  sum_of_powers += AType(lp_power(static_cast<double>(src), lp));
325  }
326  }
327 
329  template<typename AType, typename DType>
330  __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType src,
331  volatile DType& scale) {
332  if (src != 0) {
333  DType src_abs = op::abs(src);
334  if (scale < src_abs) {
335  sum_of_powers = sum_of_powers * AType(lp_power(static_cast<double>(scale / src_abs), lp));
336  sum_of_powers = sum_of_powers + 1;
337  scale = src_abs;
338  } else {
339  sum_of_powers = sum_of_powers + AType(lp_power(static_cast<double>(src_abs / scale), lp));
340  }
341  }
342  }
343 
345  template<typename DType>
346  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
347  dst_val = dst_val + src_val;
348  }
349 
351  template<typename DType>
352  __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale,
353  volatile DType& src_ssq, volatile DType& src_scale) {
354  if (dst_scale != 0 && dst_scale >= src_scale) {
355  dst_ssq = dst_ssq + src_ssq * DType(lp_power(static_cast<double>(src_scale / dst_scale), 2));
356  } else if (src_scale != 0 && dst_scale < src_scale) {
357  dst_ssq = src_ssq + dst_ssq * DType(lp_power(static_cast<double>(dst_scale / src_scale), 2));
358  dst_scale = src_scale;
359  }
360  }
361 
363  template<typename DType>
364  __device__ inline void Finalize(volatile DType& sum_of_powers) {
365  if (lp != 0.0) {
366  sum_of_powers = DType(lp_power(static_cast<double>(sum_of_powers), 1.0 / lp));
367  }
368  }
369 
371  template<typename DType>
372  __device__ inline void Finalize(volatile DType& sum_of_powers, volatile DType& scale) {
373  if (lp != 0.0) {
374  sum_of_powers = scale * DType(lp_power(static_cast<double>(sum_of_powers), 1.0 / lp));
375  }
376  }
377 
381  template<typename DType>
382  __device__ inline static void SetInitValue(DType &sum_of_powers) {
383  sum_of_powers = 0;
384  }
385 
389  template<typename DType>
390  __device__ inline static void SetInitValue(DType &sum_of_powers, DType &scale) {
391  SetInitValue(sum_of_powers);
392  scale = 0;
393  }
394 };
395 
396 } // namespace red
397 )code";
398 
399 const char logic_reducer[] = R"code(
400 namespace red {
401 
402 struct maximum {
404  template<typename DType, typename DType2>
405  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { // NOLINT(*)
406  if (!util::isnan(dst)) {
407  if (!(dst >= src)) dst = src;
408  }
409  }
411  template<typename DType, typename DType2>
412  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
413  volatile DType& none) {
414  Reduce(dst, src);
415  }
417  template<typename DType>
418  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
419  Reduce(dst_val, src_val);
420  }
422  template<typename DType>
423  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
424  volatile DType& src_val, volatile DType& src_residual) {
425  Reduce(dst_val, src_val);
426  }
428  template<typename DType>
429  __device__ inline static void Finalize(volatile DType& dst) {}
431  template<typename DType>
432  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
436  template<typename DType>
437  __device__ inline static void SetInitValue(DType &initv) {
438  initv = limits::NegInfValue<DType>();
439  }
443  template<typename DType>
444  __device__ inline static void SetInitValue(DType &initv, DType &none) {
445  SetInitValue(initv);
446  }
447 };
448 
449 struct minimum {
451  template<typename DType, typename DType2>
452  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
453  if (!util::isnan(dst)) {
454  if (!(dst <= src)) dst = src;
455  }
456  }
458  template<typename DType, typename DType2>
459  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
460  volatile DType& none) {
461  Reduce(dst, src);
462  }
464  template<typename DType>
465  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
466  Reduce(dst_val, src_val);
467  }
469  template<typename DType>
470  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
471  volatile DType& src_val, volatile DType& src_residual) {
472  Reduce(dst_val, src_val);
473  }
475  template<typename DType>
476  __device__ inline static void Finalize(volatile DType& dst) {}
478  template<typename DType>
479  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
483  template<typename DType>
484  __device__ inline static void SetInitValue(DType &initv) {
485  initv = limits::PosInfValue<DType>();
486  }
490  template<typename DType>
491  __device__ inline static void SetInitValue(DType &initv, DType &none) {
492  SetInitValue(initv);
493  }
494 };
495 
496 struct argmax {
498  template<typename AType, typename DType>
499  __device__ inline static void Reduce(volatile AType& dst, volatile DType src) {
500  if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
501  dst.num = src.num;
502  dst.idx = src.idx;
503  }
504  }
506  template<typename AType, typename DType>
507  __device__ inline static void Reduce(volatile AType& dst, volatile DType src,
508  volatile DType&) {
509  if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
510  dst.num = src.num;
511  dst.idx = src.idx;
512  }
513  }
515  template<typename DType>
516  __device__ inline static void Merge(volatile DType& dst, volatile DType& src) {
517  if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
518  dst.num = src.num;
519  dst.idx = src.idx;
520  }
521  }
523  template<typename DType>
524  __device__ inline static void Merge(volatile DType& dst, volatile DType&,
525  volatile DType& src, volatile DType&) {
526  if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
527  dst.num = src.num;
528  dst.idx = src.idx;
529  }
530  }
532  template<typename DType>
533  __device__ inline static void Finalize(volatile DType& dst) {}
535  template<typename DType>
536  __device__ inline static void Finalize(volatile DType& dst, volatile DType&) {}
540  template<typename DType>
541  __device__ inline static void SetInitValue(DType &initv) {
542  initv.num = limits::NegInfValue<decltype(initv.num)>();
543  }
547  template<typename DType>
548  __device__ inline static void SetInitValue(DType &initv, DType &) {
549  initv.num = limits::NegInfValue<decltype(initv.num)>();
550  }
551 };
552 
553 struct argmin {
555  template<typename AType, typename DType>
556  __device__ inline static void Reduce(volatile AType& dst, volatile DType src) {
557  if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
558  dst.num = src.num;
559  dst.idx = src.idx;
560  }
561  }
563  template<typename AType, typename DType>
564  __device__ inline static void Reduce(volatile AType& dst, volatile DType src,
565  volatile DType& residual) {
566  if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
567  dst.num = src.num;
568  dst.idx = src.idx;
569  }
570  }
572  template<typename DType>
573  __device__ inline static void Merge(volatile DType& dst, volatile DType& src) {
574  if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
575  dst.num = src.num;
576  dst.idx = src.idx;
577  }
578  }
580  template<typename DType>
581  __device__ inline static void Merge(volatile DType& dst, volatile DType&,
582  volatile DType& src, volatile DType&) {
583  if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
584  dst.num = src.num;
585  dst.idx = src.idx;
586  }
587  }
589  template<typename DType>
590  __device__ inline static void Finalize(volatile DType& dst) {}
592  template<typename DType>
593  __device__ inline static void Finalize(volatile DType& dst, volatile DType& residual) {}
597  template<typename DType>
598  __device__ inline static void SetInitValue(DType &initv) {
599  initv.num = limits::PosInfValue<decltype(initv.num)>();
600  }
604  template<typename DType>
605  __device__ inline static void SetInitValue(DType &initv, DType &residual) {
606  initv.num = limits::PosInfValue<decltype(initv.num)>();
607  }
608 };
609 } // namespace red
610 )code";
611 } // namespace rtc
612 } // namespace cuda
613 } // namespace common
614 } // namespace mxnet
615 
616 #endif // MXNET_USE_CUDA
617 
618 #endif // MXNET_COMMON_CUDA_RTC_REDUCER_INL_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::common::cuda::rtc::logic_reducer
const char logic_reducer[]
Definition: reducer-inl.h:399
mxnet::common::cuda::rtc::reducer
const char reducer[]
Definition: reducer-inl.h:30