// Apache MXNet — src/common/cuda/rtc/backward_functions-inl.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef MXNET_COMMON_CUDA_RTC_BACKWARD_FUNCTIONS_INL_H_
#define MXNET_COMMON_CUDA_RTC_BACKWARD_FUNCTIONS_INL_H_

#if MXNET_USE_CUDA

namespace mxnet {
namespace common {
namespace cuda {
namespace rtc {
// NVRTC source fragment: backward (gradient) functions for unary operators.
// This array holds CUDA C++ device code as a string; it is NOT compiled as
// part of this translation unit.  It is concatenated with the other rtc
// fragments and compiled at runtime by NVRTC, so the host compiler only sees
// it as character data.
//
// Conventions inside the string (as visible in the code below):
//  - `grad` is the incoming gradient; functions taking `val` receive the
//    forward input, while functions taking `out` receive the forward output
//    (those cases carry an explicit comment, e.g. backward_tan/backward_tanh).
//  - Return type is mixed_type<DTypeGrad, DType>, the promoted type of the
//    gradient and value types (declared elsewhere in the rtc headers).
const char backward_function_definitions[] = R"code(

namespace op {

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_relu(const DTypeGrad grad, const DType val) {
  if (isnan(val)) return val;
  return val > 0 ? grad : 0;
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_sigmoid(const DTypeGrad grad, const DType val) {
  return grad * val * (1 - val);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log_sigmoid(const DTypeGrad grad, const DType val) {
  return grad * (1 - op::exp(val));
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_softrelu(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return grad * sigmoid(v);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_softsign(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  const auto ap1 = 1 + op::abs(v);
  return grad / (ap1 * ap1);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_abs(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return grad * op::sign(v);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_exp(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return grad * op::exp(v);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_expm1(const DTypeGrad grad, const DType val) {
  return backward_exp(grad, val);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log(const DTypeGrad grad, const DType val) {
  return grad / val;
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log10(const DTypeGrad grad, const DType val) {
  return grad / (val * op::log(static_cast<DTypeGrad>(10)));
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log2(const DTypeGrad grad, const DType val) {
  return grad / (val * op::log(static_cast<DTypeGrad>(2)));
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log1p(const DTypeGrad grad, const DType val) {
  return grad / (1 + val);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_sin(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return grad * op::cos(v);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_cos(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return -grad * op::sin(v);
}

// Uses output from tan
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_tan(const DTypeGrad grad, const DType out) {
  return grad * (out * out + 1);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arcsin(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return grad / op::sqrt(1 - v*v);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arccos(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return -grad / op::sqrt(1 - v*v);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arctan(const DTypeGrad grad, const DType val) {
  return grad / (1 + val*val);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_degrees(const DTypeGrad grad, const DType /* val */) {
  return op::degrees(grad);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_radians(const DTypeGrad grad, const DType /* val */) {
  return op::radians(grad);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_sinh(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return grad * op::cosh(v);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_cosh(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return grad * op::sinh(v);
}

// Uses tanh output
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_tanh(const DTypeGrad grad, const DType out) {
  return grad * (1 - out * out);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arcsinh(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return grad / op::sqrt(v * v + 1);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arccosh(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  return grad / op::sqrt(v * v - 1);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arctanh(const DTypeGrad grad, const DType val) {
  return grad / (1 - val * val);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_mish(const DTypeGrad grad, const DType val) {
  const auto softrelu = op::softrelu(val);
  const auto tanh_sr = op::tanh(softrelu);
  return grad * (tanh_sr + val * sigmoid(val) * (1 - tanh_sr * tanh_sr));
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_sqrt(const DTypeGrad grad, const DType out) {
  return 0.5 * grad / out;
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_rsqrt(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  const auto inv = 1 / v;
  return -0.5 * grad * op::sqrt(inv) * inv;
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_cbrt(const DTypeGrad grad, const DType out) {
  return grad / (3.0f * out * out);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_rcbrt(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  const auto inv = 1 / v;
  return -1.f/3.f * grad * op::cbrt(inv) * inv;
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_square(const DTypeGrad grad, const DType val) {
  return 2 * val * grad;
}

template <typename DType, typename DType2>
__device__ inline DType div_rgrad(const DType val,
                                  const DType2 val2) {
  return -val / (val2 * val2);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_clip(const DTypeGrad grad, const DType val,
              const float a_min, const float a_max) {
  if (val > a_max || val < a_min) {
    return 0;
  } else {
    return grad;
  }
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_reciprocal(const DTypeGrad grad, const DType val) {
  return -grad / (val * val);
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_erf(const DTypeGrad grad, const DType val) {
  using type = mixed_type<DTypeGrad, DType>;
  const type v = val;
  constexpr type my_pi = pi;
  return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_erfinv(const DTypeGrad grad, const DType val) {
  using type = mixed_type<DTypeGrad, DType>;
  constexpr type my_pi = pi;
  const type g = grad;
  const type v = val;
  return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gamma(const DTypeGrad grad, const DType val) {
  using type = mixed_type<DTypeGrad, DType>;
  const type v = val;
  if (type_util::is_same<DTypeGrad, double>::value) {
    return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
  } else {
    return grad * op::gamma(v) * op::special_functions::cephes::psi<float>(v);
  }
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gammaln(const DTypeGrad grad, const DType val) {
  using type = mixed_type<DTypeGrad, DType>;
  const type v = val;
  if (type_util::is_same<DTypeGrad, double>::value) {
    return grad * op::special_functions::cephes::psi<double>(v);
  } else {
    return grad * op::special_functions::cephes::psi<float>(v);
  }
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_digamma(const DTypeGrad grad, const DType val) {
  using type = mixed_type<DTypeGrad, DType>;
  const type v = val;
  if (type_util::is_same<DTypeGrad, double>::value) {
    return grad * op::special_functions::trigamma<double>(v);
  } else {
    return grad * op::special_functions::trigamma<float>(v);
  }
}

template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gelu_erf(const DTypeGrad grad, const DType val) {
  return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
                 val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f));
}

}  // namespace op

)code";
// NVRTC source fragment: gradient functions for binary operators.
// Like backward_function_definitions above, this is CUDA C++ device code
// stored as a string and compiled at runtime by NVRTC.
//
// Naming conventions visible in the code below:
//  - `*_grad`  : gradient w.r.t. the first (left) operand.
//  - `*_rgrad` : gradient w.r.t. the second (right) operand.
//  - `r*_grad` : gradient for the "reversed" operator variant (scalar on the
//                left), e.g. rdiv_grad for rdiv.
const char grad_function_definitions[] = R"code(
namespace op {

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rdiv_grad(const DType val,
          const DType2 val2) {
  return -val2 / (val * val);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
div_grad(const DType val,
         const DType2 val2) {
  const mixed_type<DType, DType2> temp = val2;
  return op::reciprocal(temp);
}

template <typename DType, typename DType2>
__device__ inline DType mod_grad(const DType val,
                                 const DType2 val2) {
  if (type_util::is_integral<DType>::value) {
    return 0;
  } else {
    return 1;
  }
}

template <typename DType, typename DType2>
__device__ inline DType mod_rgrad(const DType val,
                                  const DType2 val2) {
  if (type_util::is_integral<DType>::value) {
    return 0;
  } else {
    return -op::floor(val / val2);
  }
}

template <typename DType, typename DType2>
__device__ inline DType rmod_grad(const DType val,
                                  const DType2 val2) {
  if (type_util::is_integral<DType>::value) {
    return 0;
  } else {
    return -op::floor(val2 / val);
  }
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
power_grad(const DType val,
           const DType2 val2) {
  return op::power(val, val2 - 1.f) * val2;
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
power_rgrad(const DType val,
            const DType2 val2) {
  const mixed_type<DType, DType2> temp = val;
  return op::power(val, val2) * op::log(temp);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rpower_grad(const DType val,
            const DType2 val2) {
  const mixed_type<DType, DType2> temp = val2;
  return val * op::log(temp);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
hypot_grad_left(const DType val,
                const DType2 val2) {
  return val / op::hypot(val, val2);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
hypot_grad_right(const DType val,
                 const DType2 val2) {
  return val2 / op::hypot(val, val2);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
copysign_grad(const DType val,
              const DType2 val2) {
  return (val >= 0 && val2 >= 0) || (val < 0 && val2 < 0) ? 1 : -1;
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_left_shift_grad(const DType val,
                        const DType2 val2) {
  return op::power(static_cast<DType>(2), val2);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_left_shift_rgrad(const DType val,
                         const DType2 val2) {
  using type = mixed_type<DType, DType2>;
  return val * op::power(static_cast<DType>(2), val2) * op::log(static_cast<type>(2));
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rbitwise_left_shift_grad(const DType val,
                         const DType2 val2) {
  using type = mixed_type<DType, DType2>;
  return val2 * op::power(static_cast<DType>(2), val) * op::log(static_cast<type>(2));
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_right_shift_grad(const DType val,
                         const DType2 val2) {
  return op::power(0.5f, val2);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_right_shift_rgrad(const DType val,
                          const DType2 val2) {
  return val * op::power(0.5f, val2) * op::log(0.5f);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rbitwise_right_shift_grad(const DType val,
                          const DType2 val2) {
  return val2 * op::power(0.5f, val) * op::log(0.5f);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
arctan2_grad(const DType val,
             const DType2 val2) {
  return val2 / (val * val + val2 * val2);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rarctan2_grad(const DType val,
              const DType2 val2) {
  return val / (val * val + val2 * val2);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
arctan2_rgrad(const DType val,
              const DType2 val2) {
  return -rarctan2_grad(val, val2);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
ldexp_grad(const DType val,
           const DType2 val2) {
  return op::power(static_cast<DType>(2), val2);
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rldexp_grad(const DType val,
            const DType2 val2) {
  using type = mixed_type<DType, DType2>;
  return val2 * op::power(static_cast<type>(2), val) * op::log(static_cast<type>(2));
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
logaddexp_grad(const DType val,
               const DType2 val2) {
  return op::exp(val) / (op::exp(val) + op::exp(val2));
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
logaddexp_rgrad(const DType val,
                const DType2 val2) {
  return op::exp(val2) / (op::exp(val) + op::exp(val2));
}

template <typename DType, typename DType2>
__device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) {
  auto bsq = scalar * scalar;
  auto ibsq = 1.0f / bsq;
  if (val > ibsq) {
    return 1;
  } else if (val < -ibsq) {
    return -1;
  } else {
    return bsq * val;
  }
}

template <typename DType, typename DType2>
__device__ inline DType2 xelu_grad(const DType val,
                                   const DType2 val2) {
  return (val > 0) ? 1 : val2;
}

template <typename DType, typename DType2>
__device__ inline DType prelu_grad(const DType val,
                                   const DType2 val2) {
  return (val > 0) ? 0 : val;
}

template <typename DType, typename DType2>
__device__ inline mixed_type<DType2, DType>
gamma_implicit_grad(const DType a_in, const DType2 x_in) {
  using OType = mixed_type<DType2, DType>;
  const OType a = a_in;
  const OType x = x_in;
  if (x < 0.8f) {
    OType numer = 1;
    OType denom = a;
    OType series1 = numer / denom;
    OType series2 = numer / (denom * denom);
#pragma unroll
    for (int i = 1; i <= 5; i++) {
      numer *= -x / static_cast<DType>(i);
      denom += 1;
      series1 += numer / denom;
      series2 += numer / (denom * denom);
    }
    OType pow_x_alpha = op::power(x, a);
    OType gamma_pdf = op::power(x, a - 1) * op::exp(-x);
    OType gamma_cdf = pow_x_alpha * series1;
    OType gamma_cdf_alpha =
        (op::log(x) - OType(special_functions::cephes::psi<float>(a))) *
            gamma_cdf -
        pow_x_alpha * series2;
    OType result = -gamma_cdf_alpha / gamma_pdf;
    return op::isnan(result) ? 0.f : result;
  }
  if (a > 8.0f) {
    if (0.9f * a <= x && x <= 1.1f * a) {
      OType numer_1 = 1 + 24 * a * (1 + 12 * a);
      OType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) -
                      65 * x * x / a + a * (107 + 3600 * x);
      OType denom = 1244160 * (a * a) * (a * a);
      return numer_1 * numer_2 / denom;
    }
    OType denom = op::sqrt(8 * a);
    OType term2 = denom / (a - x);
    OType term3 =
        op::power(x - a - a * op::log(x / a), static_cast<OType>(-1.5));
    OType term23 = (x < a) ? term2 - term3 : term2 + term3;
    OType term1 = op::log(x / a) * term23 -
                  op::sqrt(2 / a) * (a + x) / ((a - x) * (a - x));
    OType stirling = 1.f + 1.f / (12.f * a) * (1.f + 1.f / (24.f * a));
    OType numer = x * term1;
    return -stirling * numer / denom;
  }
  OType u = op::log(x / a);
  OType v = op::log(a);
  OType coef_uv[3][8] = {
      {0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115,
       0.10406089, 0.0014179084},
      {0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465,
       0.020070113, -0.0035938915, -0.00058392623},
      {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642,
       -0.0021309326, 0.00085092367, -1.5247877e-07},
  };
  OType coef_v[8];
#pragma unroll
  for (int i = 0; i < 8; i++) {
    coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
  }
  OType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
  OType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
  return op::exp(p / q);
}

}  // namespace op
)code";
}  // namespace rtc
}  // namespace cuda
}  // namespace common
}  // namespace mxnet

#endif  // MXNET_USE_CUDA

#endif  // MXNET_COMMON_CUDA_RTC_BACKWARD_FUNCTIONS_INL_H_
// (Doxygen cross-reference metadata, non-code)
// mxnet::common::cuda::rtc::backward_function_definitions : const char[]
// mxnet::common::cuda::rtc::grad_function_definitions : const char[]