mxnet
src
common
cuda
rtc
reducer-inl.h
Go to the documentation of this file.
1
/*
2
* Licensed to the Apache Software Foundation (ASF) under one
3
* or more contributor license agreements. See the NOTICE file
4
* distributed with this work for additional information
5
* regarding copyright ownership. The ASF licenses this file
6
* to you under the Apache License, Version 2.0 (the
7
* "License"); you may not use this file except in compliance
8
* with the License. You may obtain a copy of the License at
9
*
10
* http://www.apache.org/licenses/LICENSE-2.0
11
*
12
* Unless required by applicable law or agreed to in writing,
13
* software distributed under the License is distributed on an
14
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
* KIND, either express or implied. See the License for the
16
* specific language governing permissions and limitations
17
* under the License.
18
*/
19
20
#ifndef MXNET_COMMON_CUDA_RTC_REDUCER_INL_H_
21
#define MXNET_COMMON_CUDA_RTC_REDUCER_INL_H_
22
23
#if MXNET_USE_CUDA
24
25
namespace
mxnet
{
26
namespace
common {
27
namespace
cuda {
28
namespace
rtc {
29
30
/*!
 * \brief CUDA C++ source, compiled at runtime, that defines the arithmetic reducers
 *        sum, product, nansum, nanprod, nrm2 and nrmlp in namespace red.
 *
 * Every reducer offers SetInitValue / Reduce / Merge / Finalize in two flavours: a plain
 * one and one taking an auxiliary accumulator (a compensation residual for the sums, a
 * running scale for the norms, ignored by product/nanprod).
 * The embedded code calls op:: and util:: helpers that are not defined here; presumably
 * they come from other RTC source fragments compiled together with this one.
 */
const char reducer[] = R"code(
namespace red {

/*!
 * \brief sum reducer.
 *
 * The residual overloads implement compensated (Kahan) summation. Reduce and Merge share
 * one convention: the exact running total is approximately value - residual.
 */
struct sum {
  /*! \brief dst += src */
  template<typename DType, typename DType2>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
    dst = op::add(dst, src);
  }

  /*!
   * \brief compensated dst += src.
   *
   * residual carries the rounding error of the previous additions. It is reset to 0 once
   * the running sum becomes infinite, so that an inf - inf error term cannot produce NaN.
   */
  template<typename DType, typename DType2>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
                                       volatile DType& residual) {
    DType y = op::sub(src, residual);
    DType t = dst + y;
    if (util::isinf(t)) {
      residual = 0;
    } else {
      residual = (t - dst) - y;
    }
    dst = t;
  }

  /*! \brief combine two partial sums */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
    Reduce(dst_val, src_val);
  }

  /*!
   * \brief combine two compensated partial sums into dst.
   *
   * The two values are added with an error-free two-sum (Knuth). Both residuals are then
   * subtracted, because Reduce stores them with the sign "exact = value - residual".
   * The new residual is written with that same sign, so the result can be merged again
   * or continue to be used by Reduce.
   */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
                                      volatile DType& src_val, volatile DType& src_residual) {
    DType t1 = dst_val + src_val;
    if (util::isinf(t1)) {
      dst_val = t1;
      dst_residual = 0;
    } else {
      DType e = t1 - dst_val;
      // Rounding error of t1, corrected by the error already carried by each input.
      DType t2 = ((src_val - e) + (dst_val - (t1 - e))) - dst_residual - src_residual;
      dst_val = t1 + t2;
      dst_residual = (dst_val - t1) - t2;
    }
  }

  /*! \brief no post-processing needed */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst) {}

  /*! \brief no post-processing needed; the residual is discarded */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}

  /*! \brief the identity of addition */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv) {
    initv = 0;
  }

  /*! \brief zero value and zero residual */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv, DType &residual) {
    SetInitValue(initv);
    residual = 0;
  }
};

/*! \brief product reducer; the auxiliary accumulator is unused */
struct product {
  /*! \brief dst *= src */
  template<typename DType, typename DType2>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
    dst = op::mul(dst, src);
  }

  /*! \brief dst *= src; the third argument is ignored */
  template<typename DType, typename DType2>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
                                       volatile DType& none) {
    Reduce(dst, src);
  }

  /*! \brief combine two partial products */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
    Reduce(dst_val, src_val);
  }

  /*! \brief combine two partial products; the auxiliary values are ignored */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
                                      volatile DType& src_val, volatile DType& src_residual) {
    Reduce(dst_val, src_val);
  }

  /*! \brief no post-processing needed */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst) {}

  /*! \brief no post-processing needed */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}

  /*! \brief the identity of multiplication */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv) {
    initv = 1;
  }

  /*! \brief same as the one-argument form; the auxiliary value is left untouched */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv, DType &none) {
    SetInitValue(initv);
  }
};

/*!
 * \brief sum reducer that skips NaN inputs.
 *
 * Uses the same compensation scheme and residual sign convention as sum
 * (exact total is approximately value - residual).
 */
struct nansum {
  /*! \brief dst += src unless src is NaN */
  template<typename DType, typename DType2>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
    if (util::isnan(src)) return;
    dst = op::add(dst, src);
  }

  /*!
   * \brief compensated dst += src unless src is NaN.
   *
   * As in sum, the residual is reset once the running sum is infinite. Otherwise the
   * error term would become NaN and corrupt every later addition, turning an infinite
   * total into NaN.
   */
  template<typename DType>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType src,
                                       volatile DType& residual) {
    if (util::isnan(src)) return;
    DType y = src - residual;
    DType t = dst + y;
    if (util::isinf(t)) {
      residual = 0;
    } else {
      residual = (t - dst) - y;
    }
    dst = t;
  }

  /*! \brief combine two partial sums (neither contains NaN inputs) */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
    Reduce(dst_val, src_val);
  }

  /*!
   * \brief combine two compensated partial sums; identical to sum's compensated merge.
   *
   * Error-free two-sum of the values, corrected by both residuals, with the residual
   * written back using the "exact = value - residual" convention.
   */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
                                      volatile DType& src_val, volatile DType& src_residual) {
    DType t1 = dst_val + src_val;
    if (util::isinf(t1)) {
      dst_val = t1;
      dst_residual = 0;
    } else {
      DType e = t1 - dst_val;
      DType t2 = ((src_val - e) + (dst_val - (t1 - e))) - dst_residual - src_residual;
      dst_val = t1 + t2;
      dst_residual = (dst_val - t1) - t2;
    }
  }

  /*! \brief no post-processing needed */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst) {}

  /*! \brief no post-processing needed; the residual is discarded */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}

  /*! \brief the identity of addition */
  template<typename DType>
  __device__ inline static void SetInitValue(DType & initv) {
    initv = 0;
  }

  /*! \brief zero value and zero residual */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv, DType &residual) {
    SetInitValue(initv);
    residual = 0;
  }
};

/*! \brief product reducer that skips NaN inputs; the auxiliary accumulator is unused */
struct nanprod {
  /*! \brief dst *= src unless src is NaN */
  template<typename DType, typename DType2>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
    if (util::isnan(src)) return;
    dst = op::mul(dst, src);
  }

  /*! \brief dst *= src unless src is NaN; the third argument is ignored */
  template<typename DType>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType src,
                                       volatile DType& none) {
    Reduce(dst, src);
  }

  /*! \brief combine two partial products (a NaN partial product is skipped) */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
    Reduce(dst_val, src_val);
  }

  /*! \brief combine two partial products; the auxiliary values are ignored */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
                                      volatile DType& src_val, volatile DType& src_residual) {
    Reduce(dst_val, src_val);
  }

  /*! \brief no post-processing needed */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst) {}

  /*! \brief no post-processing needed */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}

  /*! \brief the identity of multiplication */
  template<typename DType>
  __device__ inline static void SetInitValue(DType & initv) {
    initv = 1;
  }

  /*! \brief same as the one-argument form; the auxiliary value is left untouched */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv, DType &none) {
    SetInitValue(initv);
  }
};

/*!
 * \brief Euclidean (L2) norm reducer.
 *
 * Plain overloads accumulate src * src and take the square root in Finalize.
 * Scaled overloads instead track scale = largest |src| seen so far and a sum of squares
 * of src / scale, so every squared ratio is <= 1. Finalize returns scale * sqrt(sum).
 */
struct nrm2 {
  /*! \brief sum_of_squares += src * src */
  template<typename AType, typename DType>
  __device__ inline static void Reduce(volatile AType& sum_of_squares, volatile DType src) {
    sum_of_squares = op::add(sum_of_squares, src * src);
  }

  /*! \brief scaled accumulation; zeros are skipped so that scale never divides 0 by 0 */
  template<typename AType, typename DType>
  __device__ inline static void Reduce(volatile AType& sum_of_squares,
                                       volatile DType src, volatile DType& scale) {
    if (src != 0) {
      DType abs = op::abs(src);
      if (scale < abs) {
        // New maximum: rescale the existing sum to the new scale, then add (abs/abs)^2 = 1.
        sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs);
        scale = abs;
      } else {
        sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale);
      }
    }
  }

  /*! \brief combine two unscaled partial sums of squares */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
    dst_val = op::add(dst_val, src_val);
  }

  /*! \brief combine two scaled partial results, rescaling to the larger scale */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale,
                                      volatile DType& src_ssq, volatile DType& src_scale) {
    if (dst_scale != 0 && dst_scale >= src_scale) {
      dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale);
    } else if (src_scale != 0 && dst_scale < src_scale) {
      dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale);
      dst_scale = src_scale;
    }
  }

  /*! \brief norm = sqrt(sum of squares) */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& sum_of_squares) {
    sum_of_squares = op::sqrt(sum_of_squares);
  }

  /*! \brief norm = scale * sqrt(scaled sum of squares) */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) {
    sum_of_squares = scale * op::sqrt(sum_of_squares);
  }

  /*! \brief empty sum of squares */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &sum_of_squares) {
    sum_of_squares = 0;
  }

  /*! \brief empty sum of squares with zero scale */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &sum_of_squares, DType &scale) {
    SetInitValue(sum_of_squares);
    scale = 0;
  }
};

/*!
 * \brief Lp norm reducer with exponent given by the member lp (lp == 0 counts non-zeros).
 *
 * Reduce and Finalize read lp and are therefore non-static member functions.
 */
struct nrmlp {
  double lp;

  /*! \brief src^p; a zero src is returned unchanged; for p == 0 returns (src != 0) */
  __device__ inline static double lp_power(volatile double src, volatile double p) {
    if (p != 0.0) {
      if (src == 0.0) {
        return src;
      } else {
        return op::power(src, p);
      }
    } else {  // 0-norm, sparsity
      return static_cast<double>(src != 0);
    }
  }

  /*!
   * \brief sum_of_powers += src^lp for non-zero src.
   *
   * NOTE(review): src is raised to lp without taking its absolute value; presumably the
   * caller already passes magnitudes -- confirm.
   */
  template<typename AType, typename DType>
  __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType src) {
    if (src != 0) {
      sum_of_powers += AType(lp_power(static_cast<double>(src), lp));
    }
  }

  /*! \brief scaled accumulation: sum_of_powers holds the sum of (|src| / scale)^lp */
  template<typename AType, typename DType>
  __device__ inline void Reduce(volatile AType& sum_of_powers, volatile DType src,
                                volatile DType& scale) {
    if (src != 0) {
      DType src_abs = op::abs(src);
      if (scale < src_abs) {
        // New maximum: rescale the existing sum, then add (src_abs/src_abs)^lp = 1.
        sum_of_powers = sum_of_powers * AType(lp_power(static_cast<double>(scale / src_abs), lp));
        sum_of_powers = sum_of_powers + 1;
        scale = src_abs;
      } else {
        sum_of_powers = sum_of_powers + AType(lp_power(static_cast<double>(src_abs / scale), lp));
      }
    }
  }

  /*! \brief combine two unscaled partial sums of powers */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
    dst_val = dst_val + src_val;
  }

  /*!
   * \brief combine two scaled partial results, rescaling to the larger scale.
   *
   * NOTE(review): the rescaling exponent is fixed at 2, which matches the scaled Reduce
   * only when lp == 2. This function is static and cannot read lp. Confirm that scaled
   * partials are merged only for the L2 case, or turn this into a member function.
   */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale,
                                      volatile DType& src_ssq, volatile DType& src_scale) {
    if (dst_scale != 0 && dst_scale >= src_scale) {
      dst_ssq = dst_ssq + src_ssq * DType(lp_power(static_cast<double>(src_scale / dst_scale), 2));
    } else if (src_scale != 0 && dst_scale < src_scale) {
      dst_ssq = src_ssq + dst_ssq * DType(lp_power(static_cast<double>(dst_scale / src_scale), 2));
      dst_scale = src_scale;
    }
  }

  /*! \brief norm = sum_of_powers^(1/lp); for lp == 0 the count is returned as is */
  template<typename DType>
  __device__ inline void Finalize(volatile DType& sum_of_powers) {
    if (lp != 0.0) {
      sum_of_powers = DType(lp_power(static_cast<double>(sum_of_powers), 1.0 / lp));
    }
  }

  /*! \brief norm = scale * sum_of_powers^(1/lp); for lp == 0 the count is returned as is */
  template<typename DType>
  __device__ inline void Finalize(volatile DType& sum_of_powers, volatile DType& scale) {
    if (lp != 0.0) {
      sum_of_powers = scale * DType(lp_power(static_cast<double>(sum_of_powers), 1.0 / lp));
    }
  }

  /*! \brief empty sum of powers */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &sum_of_powers) {
    sum_of_powers = 0;
  }

  /*! \brief empty sum of powers with zero scale */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &sum_of_powers, DType &scale) {
    SetInitValue(sum_of_powers);
    scale = 0;
  }
};

}  // namespace red
)code";
398
399
/*!
 * \brief CUDA C++ source, compiled at runtime, that defines the comparison reducers
 *        maximum, minimum, argmax and argmin in namespace red.
 *
 * They expose the same SetInitValue / Reduce / Merge / Finalize interface as the
 * arithmetic reducers. The overloads taking an auxiliary accumulator exist only to match
 * that interface and ignore the extra argument. util:: and limits:: are not defined here;
 * presumably they come from other RTC source fragments.
 */
const char logic_reducer[] = R"code(
namespace red {

/*! \brief running maximum; once the accumulator is NaN it stays NaN */
struct maximum {
  /*! \brief start from negative infinity */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv) {
    initv = limits::NegInfValue<DType>();
  }

  /*! \brief same as the one-argument form; the auxiliary value is untouched */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv, DType &none) {
    SetInitValue(initv);
  }

  /*!
   * \brief dst = max(dst, src). Because the test is negated, a NaN src is also adopted
   *        (every comparison against NaN is false).
   */
  template<typename DType, typename DType2>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
    const bool keep_dst = util::isnan(dst) || dst >= src;
    if (!keep_dst) dst = src;
  }

  /*! \brief same as the two-argument form; the third argument is ignored */
  template<typename DType, typename DType2>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
                                       volatile DType& none) {
    Reduce(dst, src);
  }

  /*! \brief combine two partial maxima */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
    Reduce(dst_val, src_val);
  }

  /*! \brief combine two partial maxima; the auxiliary values are ignored */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
                                      volatile DType& src_val, volatile DType& src_residual) {
    Merge(dst_val, src_val);
  }

  /*! \brief nothing to post-process */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst) {}

  /*! \brief nothing to post-process */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
};

/*! \brief running minimum; once the accumulator is NaN it stays NaN */
struct minimum {
  /*! \brief start from positive infinity */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv) {
    initv = limits::PosInfValue<DType>();
  }

  /*! \brief same as the one-argument form; the auxiliary value is untouched */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv, DType &none) {
    SetInitValue(initv);
  }

  /*!
   * \brief dst = min(dst, src). Because the test is negated, a NaN src is also adopted
   *        (every comparison against NaN is false).
   */
  template<typename DType, typename DType2>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
    const bool keep_dst = util::isnan(dst) || dst <= src;
    if (!keep_dst) dst = src;
  }

  /*! \brief same as the two-argument form; the third argument is ignored */
  template<typename DType, typename DType2>
  __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
                                       volatile DType& none) {
    Reduce(dst, src);
  }

  /*! \brief combine two partial minima */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
    Reduce(dst_val, src_val);
  }

  /*! \brief combine two partial minima; the auxiliary values are ignored */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
                                      volatile DType& src_val, volatile DType& src_residual) {
    Merge(dst_val, src_val);
  }

  /*! \brief nothing to post-process */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst) {}

  /*! \brief nothing to post-process */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
};

/*!
 * \brief position of the maximum. The accumulated type has members num (value) and
 *        idx (position); ties are resolved in favour of the smaller idx.
 */
struct argmax {
  /*!
   * \brief start from num = negative infinity.
   *
   * NOTE(review): only num is set here. If the value type does not initialise idx itself,
   * an all-NaN or all -inf input returns an indeterminate idx -- confirm.
   */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv) {
    initv.num = limits::NegInfValue<decltype(initv.num)>();
  }

  /*! \brief same as the one-argument form; the auxiliary value is untouched */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv, DType &) {
    SetInitValue(initv);
  }

  /*!
   * \brief copy src into dst when src has a larger num, or an equal num at a smaller idx.
   *        Takes src by reference so volatile structs are never copied.
   */
  template<typename AType, typename DType>
  __device__ inline static void TakeIfBetter(volatile AType& dst, volatile DType& src) {
    const bool better = dst.num < src.num || (dst.num == src.num && dst.idx > src.idx);
    if (better) {
      dst.num = src.num;
      dst.idx = src.idx;
    }
  }

  /*! \brief fold one candidate into dst */
  template<typename AType, typename DType>
  __device__ inline static void Reduce(volatile AType& dst, volatile DType src) {
    TakeIfBetter(dst, src);
  }

  /*! \brief fold one candidate into dst; the third argument is ignored */
  template<typename AType, typename DType>
  __device__ inline static void Reduce(volatile AType& dst, volatile DType src,
                                       volatile DType&) {
    TakeIfBetter(dst, src);
  }

  /*! \brief combine two partial results */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst, volatile DType& src) {
    TakeIfBetter(dst, src);
  }

  /*! \brief combine two partial results; the auxiliary values are ignored */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst, volatile DType&,
                                      volatile DType& src, volatile DType&) {
    TakeIfBetter(dst, src);
  }

  /*! \brief nothing to post-process */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst) {}

  /*! \brief nothing to post-process */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst, volatile DType&) {}
};

/*!
 * \brief position of the minimum. The accumulated type has members num (value) and
 *        idx (position); ties are resolved in favour of the smaller idx.
 */
struct argmin {
  /*!
   * \brief start from num = positive infinity.
   *
   * NOTE(review): only num is set here. If the value type does not initialise idx itself,
   * an all-NaN or all +inf input returns an indeterminate idx -- confirm.
   */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv) {
    initv.num = limits::PosInfValue<decltype(initv.num)>();
  }

  /*! \brief same as the one-argument form; the auxiliary value is untouched */
  template<typename DType>
  __device__ inline static void SetInitValue(DType &initv, DType &residual) {
    SetInitValue(initv);
  }

  /*!
   * \brief copy src into dst when src has a smaller num, or an equal num at a smaller idx.
   *        Takes src by reference so volatile structs are never copied.
   */
  template<typename AType, typename DType>
  __device__ inline static void TakeIfBetter(volatile AType& dst, volatile DType& src) {
    const bool better = dst.num > src.num || (dst.num == src.num && dst.idx > src.idx);
    if (better) {
      dst.num = src.num;
      dst.idx = src.idx;
    }
  }

  /*! \brief fold one candidate into dst */
  template<typename AType, typename DType>
  __device__ inline static void Reduce(volatile AType& dst, volatile DType src) {
    TakeIfBetter(dst, src);
  }

  /*! \brief fold one candidate into dst; the third argument is ignored */
  template<typename AType, typename DType>
  __device__ inline static void Reduce(volatile AType& dst, volatile DType src,
                                       volatile DType& residual) {
    TakeIfBetter(dst, src);
  }

  /*! \brief combine two partial results */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst, volatile DType& src) {
    TakeIfBetter(dst, src);
  }

  /*! \brief combine two partial results; the auxiliary values are ignored */
  template<typename DType>
  __device__ inline static void Merge(volatile DType& dst, volatile DType&,
                                      volatile DType& src, volatile DType&) {
    TakeIfBetter(dst, src);
  }

  /*! \brief nothing to post-process */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst) {}

  /*! \brief nothing to post-process */
  template<typename DType>
  __device__ inline static void Finalize(volatile DType& dst, volatile DType& residual) {}
};

}  // namespace red
)code";
611
}
// namespace rtc
612
}
// namespace cuda
613
}
// namespace common
614
}
// namespace mxnet
615
616
#endif // MXNET_USE_CUDA
617
618
#endif // MXNET_COMMON_CUDA_RTC_REDUCER_INL_H_
mxnet
namespace of mxnet
Definition:
api_registry.h:33
mxnet::common::cuda::rtc::logic_reducer
const char logic_reducer[]
Definition:
reducer-inl.h:399
mxnet::common::cuda::rtc::reducer
const char reducer[]
Definition:
reducer-inl.h:30
Generated on Thu Jan 5 2023 03:47:40 for mxnet by doxygen 1.8.17