mxnet
backward_functions-inl.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_COMMON_CUDA_RTC_BACKWARD_FUNCTIONS_INL_H_
21 #define MXNET_COMMON_CUDA_RTC_BACKWARD_FUNCTIONS_INL_H_
22 
23 #if MXNET_USE_CUDA
24 
25 namespace mxnet {
26 namespace common {
27 namespace cuda {
28 namespace rtc {
29 
30 const char backward_function_definitions[] = R"code(
31 
32 namespace op {
33 
// Gradient of relu with respect to its input `val`: `grad` flows through
// where val > 0 and is zeroed elsewhere. A NaN input is returned unchanged,
// so NaN propagates into the gradient instead of being masked to 0.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_relu(const DTypeGrad grad, const DType val) {
  if (!isnan(val)) {
    return val > 0 ? grad : 0;
  }
  return val;
}

// Gradient of sigmoid written in terms of the forward output y = sigmoid(x):
// dy/dx = y * (1 - y). The formula is only the derivative when `val` is y.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_sigmoid(const DTypeGrad grad, const DType val) {
  const auto one_minus_y = 1 - val;
  return grad * val * one_minus_y;
}

// Gradient of log_sigmoid written in terms of the forward output
// y = log(sigmoid(x)): exp(y) recovers sigmoid(x), and dy/dx = 1 - sigmoid(x).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log_sigmoid(const DTypeGrad grad, const DType val) {
  const auto sig = op::exp(val);
  return grad * (1 - sig);
}

// Gradient of softrelu(x) = log(1 + exp(x)) with respect to its input `val`:
// d/dx softrelu(x) = sigmoid(x), evaluated in the mixed compute type.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_softrelu(const DTypeGrad grad, const DType val) {
  using compute_t = mixed_type<DTypeGrad, DType>;
  const compute_t x = val;
  return grad * sigmoid(x);
}
59 
// Gradient of softsign(x) = x / (1 + |x|) with respect to its input `val`:
// d/dx softsign(x) = 1 / (1 + |x|)^2.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_softsign(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> x = val;
  const auto base = 1 + op::abs(x);
  return grad / (base * base);
}

// Gradient of |x| with respect to its input `val`: op::sign(x).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_abs(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> x = val;
  const auto slope = op::sign(x);
  return grad * slope;
}

// Gradient of exp(x) with respect to its input `val`: exp(x).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_exp(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> x = val;
  const auto e_x = op::exp(x);
  return grad * e_x;
}

// expm1(x) = exp(x) - 1 has the same derivative as exp(x).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_expm1(const DTypeGrad grad, const DType val) {
  return backward_exp<DType, DTypeGrad>(grad, val);
}
87 
// Gradient of log(x) with respect to its input `val`: 1 / x.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log(const DTypeGrad grad, const DType val) {
  return grad / val;
}

// Gradient of log10(x) with respect to its input `val`: 1 / (x * ln(10)).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log10(const DTypeGrad grad, const DType val) {
  const auto ln10 = op::log(static_cast<DTypeGrad>(10));
  return grad / (val * ln10);
}

// Gradient of log2(x) with respect to its input `val`: 1 / (x * ln(2)).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log2(const DTypeGrad grad, const DType val) {
  const auto ln2 = op::log(static_cast<DTypeGrad>(2));
  return grad / (val * ln2);
}

// Gradient of log1p(x) = log(1 + x) with respect to its input `val`:
// 1 / (1 + x).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log1p(const DTypeGrad grad, const DType val) {
  const auto one_plus_x = 1 + val;
  return grad / one_plus_x;
}
111 
// Gradient of sin(x) with respect to its input `val`: cos(x).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_sin(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> x = val;
  const auto c = op::cos(x);
  return grad * c;
}

// Gradient of cos(x) with respect to its input `val`: -sin(x).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_cos(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> x = val;
  const auto s = op::sin(x);
  return -grad * s;
}

// Uses the forward output y = tan(x): d/dx tan(x) = 1 + tan(x)^2 = y^2 + 1.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_tan(const DTypeGrad grad, const DType out) {
  const auto sec_sq = out * out + 1;
  return grad * sec_sq;
}

// Gradient of arcsin(x) with respect to its input `val`: 1 / sqrt(1 - x^2).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arcsin(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> x = val;
  const auto root = op::sqrt(1 - x * x);
  return grad / root;
}

// Gradient of arccos(x) with respect to its input `val`: -1 / sqrt(1 - x^2).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arccos(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> x = val;
  const auto root = op::sqrt(1 - x * x);
  return -grad / root;
}

// Gradient of arctan(x) with respect to its input `val`: 1 / (1 + x^2).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arctan(const DTypeGrad grad, const DType val) {
  const auto denom = 1 + val * val;
  return grad / denom;
}

// degrees() is a linear rescaling, so the gradient is that same rescaling
// applied to the incoming gradient; the input value is not needed.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_degrees(const DTypeGrad grad, const DType /* val */) {
  return op::degrees(grad);
}

// radians() is a linear rescaling, so the gradient is that same rescaling
// applied to the incoming gradient; the input value is not needed.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_radians(const DTypeGrad grad, const DType /* val */) {
  return op::radians(grad);
}
164 
// Gradient of sinh(x) with respect to its input `val`: cosh(x).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_sinh(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> x = val;
  const auto ch = op::cosh(x);
  return grad * ch;
}

// Gradient of cosh(x) with respect to its input `val`: sinh(x).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_cosh(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> x = val;
  const auto sh = op::sinh(x);
  return grad * sh;
}

// Uses the forward output y = tanh(x): d/dx tanh(x) = 1 - y^2.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_tanh(const DTypeGrad grad, const DType out) {
  const auto sech_sq = 1 - out * out;
  return grad * sech_sq;
}

// Gradient of arcsinh(x) with respect to its input `val`: 1 / sqrt(x^2 + 1).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arcsinh(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> x = val;
  const auto root = op::sqrt(x * x + 1);
  return grad / root;
}

// Gradient of arccosh(x) with respect to its input `val`: 1 / sqrt(x^2 - 1).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arccosh(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> x = val;
  const auto root = op::sqrt(x * x - 1);
  return grad / root;
}

// Gradient of arctanh(x) with respect to its input `val`: 1 / (1 - x^2).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_arctanh(const DTypeGrad grad, const DType val) {
  const auto denom = 1 - val * val;
  return grad / denom;
}
205 
// Gradient of mish(x) = x * tanh(softrelu(x)) with respect to its input `val`:
//   d/dx mish(x) = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2),  sp = softrelu(x)
// (using d/dx softrelu(x) = sigmoid(x)).
// `val` is promoted to the mixed compute type first, as the other backward_*
// functions here do, so a low-precision input is not evaluated in its own
// narrower type when the incoming gradient is wider.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_mish(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  const auto sp = op::softrelu(v);
  const auto tanh_sp = op::tanh(sp);
  return grad * (tanh_sp + v * sigmoid(v) * (1 - tanh_sp * tanh_sp));
}
213 
// Uses the forward output y = sqrt(x): d/dx sqrt(x) = 1 / (2 * sqrt(x)) = 0.5 / y.
// The constant is a float literal (0.5f, not 0.5) so float gradients are not
// silently promoted to double-precision arithmetic.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_sqrt(const DTypeGrad grad, const DType out) {
  return 0.5f * grad / out;
}

// Gradient of rsqrt(x) = x^(-1/2) with respect to its input `val`:
// -0.5 * x^(-3/2) = -0.5 * sqrt(1/x) * (1/x).
// Float literal for the same reason as in backward_sqrt (no double promotion).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_rsqrt(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  const auto inv = 1 / v;
  return -0.5f * grad * op::sqrt(inv) * inv;
}

// Uses the forward output y = cbrt(x): d/dx x^(1/3) = 1 / (3 * y^2).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_cbrt(const DTypeGrad grad, const DType out) {
  return grad / (3.0f * out * out);
}

// Gradient of rcbrt(x) = x^(-1/3) with respect to its input `val`:
// -1/3 * x^(-4/3) = -1/3 * cbrt(1/x) * (1/x).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_rcbrt(const DTypeGrad grad, const DType val) {
  const mixed_type<DTypeGrad, DType> v = val;
  const auto inv = 1 / v;
  return -1.f/3.f * grad * op::cbrt(inv) * inv;
}

// Gradient of x^2 with respect to its input `val`: 2 * x.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_square(const DTypeGrad grad, const DType val) {
  return 2 * val * grad;
}
247 
// Gradient of val / val2 with respect to the divisor val2: -val / val2^2.
// The result is returned as DType (the dividend's type).
template <typename DType, typename DType2>
__device__ inline DType div_rgrad(const DType val,
                                  const DType2 val2) {
  const auto divisor_sq = val2 * val2;
  return -val / divisor_sq;
}
253 
// Gradient of clip(x, a_min, a_max) with respect to its input `val`:
// `grad` passes through when val lies within [a_min, a_max], else 0.
// A NaN `val` fails both comparisons and therefore also passes `grad` through.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_clip(const DTypeGrad grad, const DType val,
              const float a_min, const float a_max) {
  const bool inside = !(val > a_max) && !(val < a_min);
  if (inside) {
    return grad;
  }
  return 0;
}

// Gradient of 1 / x with respect to its input `val`: -1 / x^2.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_reciprocal(const DTypeGrad grad, const DType val) {
  const auto val_sq = val * val;
  return -grad / val_sq;
}
270 
// Gradient of erf(x) with respect to its input `val`:
// d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2), computed in the mixed type.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_erf(const DTypeGrad grad, const DType val) {
  using type = mixed_type<DTypeGrad, DType>;
  const type x = val;
  constexpr type pi_t = pi;
  const auto two_over_sqrt_pi = 2.0f / op::sqrt(pi_t);
  return two_over_sqrt_pi * op::exp(-(x * x)) * grad;
}

// Gradient of erfinv: d/dx erfinv(x) = sqrt(pi) / 2 * exp(erfinv(x)^2).
// The formula is written in terms of y = erfinv(x), so `val` is expected to be
// the forward output (NOTE(review): confirm callers pass the output).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_erfinv(const DTypeGrad grad, const DType val) {
  using type = mixed_type<DTypeGrad, DType>;
  constexpr type pi_t = pi;
  const type g = grad;
  const type y = val;
  const auto half_sqrt_pi = 0.5f * op::sqrt(pi_t);
  return half_sqrt_pi * op::exp(y * y) * g;
}
289 
// Gradient of gamma(x) with respect to its input `val`:
// d/dx gamma(x) = gamma(x) * psi(x), where psi is the digamma function.
// The psi precision is selected from the mixed compute type `type` rather than
// from DTypeGrad alone, so a double input paired with a float incoming gradient
// is not evaluated through the float psi.
// NOTE(review): assumes mixed_type<DTypeGrad, DType> is double whenever either
// operand type is double.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gamma(const DTypeGrad grad, const DType val) {
  using type = mixed_type<DTypeGrad, DType>;
  const type v = val;
  if (type_util::is_same<type, double>::value) {
    return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
  } else {
    return grad * op::gamma(v) * op::special_functions::cephes::psi<float>(v);
  }
}

// Gradient of gammaln(x) = log(gamma(x)) with respect to its input `val`:
// psi(x). As in backward_gamma, precision follows the mixed compute type.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gammaln(const DTypeGrad grad, const DType val) {
  using type = mixed_type<DTypeGrad, DType>;
  const type v = val;
  if (type_util::is_same<type, double>::value) {
    return grad * op::special_functions::cephes::psi<double>(v);
  } else {
    return grad * op::special_functions::cephes::psi<float>(v);
  }
}

// Gradient of digamma(x) = psi(x) with respect to its input `val`:
// trigamma(x). As in backward_gamma, precision follows the mixed compute type.
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_digamma(const DTypeGrad grad, const DType val) {
  using type = mixed_type<DTypeGrad, DType>;
  const type v = val;
  if (type_util::is_same<type, double>::value) {
    return grad * op::special_functions::trigamma<double>(v);
  } else {
    return grad * op::special_functions::trigamma<float>(v);
  }
}
325 
// Gradient of gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) with respect to `val`:
//   0.5 * (1 + erf(x / sqrt(2))) + 0.5 * x * erf'(x / sqrt(2)) / sqrt(2),
// with erf' supplied by backward_erf (which also folds in `grad`).
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_gelu_erf(const DTypeGrad grad, const DType val) {
  const auto sqrt2 = op::sqrt(2.0f);
  const auto scaled = val / sqrt2;
  return 0.5f * (grad + grad * op::erf(scaled) +
                 val * backward_erf(grad, scaled) / sqrt2);
}
332 
333 } // namespace op
334 
335 )code";
336 
337 const char grad_function_definitions[] = R"code(
338 namespace op {
339 
// Gradient of the reversed division val2 / val with respect to val:
// -val2 / val^2.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rdiv_grad(const DType val,
          const DType2 val2) {
  const auto val_sq = val * val;
  return -val2 / val_sq;
}

// Gradient of val / val2 with respect to val: 1 / val2, in the mixed type.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
div_grad(const DType val,
         const DType2 val2) {
  const mixed_type<DType, DType2> divisor = val2;
  return op::reciprocal(divisor);
}
354 
// Gradient of mod(val, val2) with respect to val: 1 for non-integral DType,
// 0 for integral DType.
template <typename DType, typename DType2>
__device__ inline DType mod_grad(const DType val,
                                 const DType2 val2) {
  if (!type_util::is_integral<DType>::value) {
    return 1;
  }
  return 0;
}

// Gradient of mod(val, val2) with respect to val2: -floor(val / val2) for
// non-integral DType, 0 for integral DType.
template <typename DType, typename DType2>
__device__ inline DType mod_rgrad(const DType val,
                                  const DType2 val2) {
  if (!type_util::is_integral<DType>::value) {
    return -op::floor(val / val2);
  }
  return 0;
}

// Gradient of the reversed mod(val2, val) with respect to val:
// -floor(val2 / val) for non-integral DType, 0 for integral DType.
template <typename DType, typename DType2>
__device__ inline DType rmod_grad(const DType val,
                                  const DType2 val2) {
  if (!type_util::is_integral<DType>::value) {
    return -op::floor(val2 / val);
  }
  return 0;
}
384 
// Gradient of val^val2 with respect to the base val: val2 * val^(val2 - 1).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
power_grad(const DType val,
           const DType2 val2) {
  const auto reduced_exp = val2 - 1.f;
  return op::power(val, reduced_exp) * val2;
}

// Gradient of val^val2 with respect to the exponent val2: val^val2 * log(val).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
power_rgrad(const DType val,
            const DType2 val2) {
  const mixed_type<DType, DType2> base = val;
  const auto powered = op::power(val, val2);
  return powered * op::log(base);
}

// Returns val * log(val2). For the reversed power val2^x this equals
// d/dx val2^x when `val` is the forward output val2^x
// (NOTE(review): confirm callers pass the output as `val`).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rpower_grad(const DType val,
            const DType2 val2) {
  const mixed_type<DType, DType2> base = val2;
  return val * op::log(base);
}
407 
// Gradient of hypot(val, val2) with respect to val: val / hypot(val, val2).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
hypot_grad_left(const DType val,
                const DType2 val2) {
  const auto h = op::hypot(val, val2);
  return val / h;
}

// Gradient of hypot(val, val2) with respect to val2: val2 / hypot(val, val2).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
hypot_grad_right(const DType val,
                 const DType2 val2) {
  const auto h = op::hypot(val, val2);
  return val2 / h;
}

// Gradient of copysign(val, val2) with respect to val: +1 when val and val2
// have the same sign, -1 otherwise (any NaN operand makes every comparison
// false, which also yields -1).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
copysign_grad(const DType val,
              const DType2 val2) {
  const bool same_sign = (val >= 0 && val2 >= 0) || (val < 0 && val2 < 0);
  return same_sign ? 1 : -1;
}
428 
// Shifts are differentiated as scalings: val << val2 == val * 2^val2 and
// val >> val2 == val * 0.5^val2.

// d/d(val) of val * 2^val2: 2^val2.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_left_shift_grad(const DType val,
                        const DType2 val2) {
  const DType two = static_cast<DType>(2);
  return op::power(two, val2);
}

// d/d(val2) of val * 2^val2: val * 2^val2 * ln(2).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_left_shift_rgrad(const DType val,
                         const DType2 val2) {
  using type = mixed_type<DType, DType2>;
  const auto scale = op::power(static_cast<DType>(2), val2);
  const auto ln2 = op::log(static_cast<type>(2));
  return val * scale * ln2;
}

// Reversed left shift (val2 * 2^val), differentiated w.r.t. val:
// val2 * 2^val * ln(2).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rbitwise_left_shift_grad(const DType val,
                         const DType2 val2) {
  using type = mixed_type<DType, DType2>;
  const auto scale = op::power(static_cast<DType>(2), val);
  const auto ln2 = op::log(static_cast<type>(2));
  return val2 * scale * ln2;
}

// d/d(val) of val * 0.5^val2: 0.5^val2.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_right_shift_grad(const DType val,
                         const DType2 val2) {
  return op::power(0.5f, val2);
}

// d/d(val2) of val * 0.5^val2: val * 0.5^val2 * ln(0.5).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_right_shift_rgrad(const DType val,
                          const DType2 val2) {
  const auto scale = op::power(0.5f, val2);
  const auto ln_half = op::log(0.5f);
  return val * scale * ln_half;
}

// Reversed right shift (val2 * 0.5^val), differentiated w.r.t. val:
// val2 * 0.5^val * ln(0.5).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rbitwise_right_shift_grad(const DType val,
                          const DType2 val2) {
  const auto scale = op::power(0.5f, val);
  const auto ln_half = op::log(0.5f);
  return val2 * scale * ln_half;
}
472 
// Gradient of atan2(val, val2) with respect to val: val2 / (val^2 + val2^2).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
arctan2_grad(const DType val,
             const DType2 val2) {
  const auto norm_sq = val * val + val2 * val2;
  return val2 / norm_sq;
}

// Returns val / (val^2 + val2^2); also used (negated) by arctan2_rgrad.
// NOTE(review): confirm which argument order callers use for the reversed op.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rarctan2_grad(const DType val,
              const DType2 val2) {
  const auto norm_sq = val * val + val2 * val2;
  return val / norm_sq;
}

// Gradient of atan2(val, val2) with respect to val2: -val / (val^2 + val2^2).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
arctan2_rgrad(const DType val,
              const DType2 val2) {
  return -rarctan2_grad<DType, DType2>(val, val2);
}
493 
// Gradient of ldexp(val, val2) = val * 2^val2 with respect to val: 2^val2.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
ldexp_grad(const DType val,
           const DType2 val2) {
  const DType two = static_cast<DType>(2);
  return op::power(two, val2);
}

// Reversed ldexp (val2 * 2^val), differentiated w.r.t. val:
// val2 * 2^val * ln(2).
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rldexp_grad(const DType val,
            const DType2 val2) {
  using type = mixed_type<DType, DType2>;
  const auto scale = op::power(static_cast<type>(2), val);
  const auto ln2 = op::log(static_cast<type>(2));
  return val2 * scale * ln2;
}
508 
// logaddexp(a, b) = log(exp(a) + exp(b)).
// d/da = exp(a) / (exp(a) + exp(b)) = 1 / (1 + exp(b - a)).
// The difference form is used because evaluating exp(a) and exp(b) directly
// overflows to inf for moderately large inputs (about 88 in float), giving
// inf / inf = NaN even though the true gradient lies in [0, 1].
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
logaddexp_grad(const DType val,
               const DType2 val2) {
  using type = mixed_type<DType, DType2>;
  const type a = val;
  const type b = val2;
  return 1 / (1 + op::exp(b - a));
}

// d/db logaddexp(a, b) = exp(b) / (exp(a) + exp(b)) = 1 / (1 + exp(a - b)),
// using the same overflow-safe difference form as logaddexp_grad.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
logaddexp_rgrad(const DType val,
                const DType2 val2) {
  using type = mixed_type<DType, DType2>;
  const type a = val;
  const type b = val2;
  return 1 / (1 + op::exp(a - b));
}
522 
// Gradient of the smooth L1 loss with parameter sigma = `scalar`:
// sign(val) where |val| exceeds 1 / sigma^2, and sigma^2 * val inside that band.
template <typename DType, typename DType2>
__device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) {
  const auto sigma_sq = scalar * scalar;
  const auto threshold = 1.0f / sigma_sq;
  if (val > threshold) {
    return 1;
  }
  if (val < -threshold) {
    return -1;
  }
  return sigma_sq * val;
}
535 
// Gradient of a leaky-relu-style op with respect to its input: 1 where
// val > 0, otherwise val2 (presumably the negative-side slope — confirm).
template <typename DType, typename DType2>
__device__ inline DType2 xelu_grad(const DType val,
                                   const DType2 val2) {
  const bool positive = val > 0;
  return positive ? 1 : val2;
}

// Gradient of prelu with respect to its slope parameter: 0 where val > 0,
// otherwise the input val itself.
template <typename DType, typename DType2>
__device__ inline DType prelu_grad(const DType val,
                                   const DType2 val2) {
  const bool positive = val > 0;
  return positive ? 0 : val;
}
547 
// Approximate derivative of a Gamma sample x with respect to its shape a.
// NOTE(review): presumably the implicit-reparameterization gradient dx/da for
// x ~ Gamma(a, 1), i.e. -(dP/da) / p(x) with P the regularized Gamma CDF and
// p the density. The x < 0.8 branch computes exactly that ratio, up to the
// common Gamma(a) factor that cancels. Confirm against the caller.
//
// Three regimes are used:
//   * x < 0.8: a 6-term series for the lower incomplete gamma function and its
//     a-derivative. NaN results are replaced by 0.
//   * a > 8: an asymptotic expansion with a Stirling-series correction.
//     A separate polynomial form is used when x lies in [0.9a, 1.1a].
//   * otherwise: exp(p / q), where p and q are cubics in v = log(a) whose
//     coefficients are quadratics in u = log(x / a).
// NOTE(review): the constants appear to follow Figurnov et al., "Implicit
// Reparameterization Gradients" (2018) — confirm the source.
// All arithmetic is done in OType = mixed_type<DType2, DType>. psi is always
// evaluated in float precision here.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType2, DType>
gamma_implicit_grad(const DType a_in, const DType2 x_in) {
  using OType = mixed_type<DType2, DType>;
  const OType a = a_in;
  const OType x = x_in;
  if (x < 0.8f) {
    // Term i of the series: numer = (-x)^i / i!, denom = a + i.
    // series1 = sum numer / denom, series2 = sum numer / denom^2.
    OType numer = 1;
    OType denom = a;
    OType series1 = numer / denom;
    OType series2 = numer / (denom * denom);
#pragma unroll
    for (int i = 1; i <= 5; i++) {
      numer *= -x / static_cast<DType>(i);
      denom += 1;
      series1 += numer / denom;
      series2 += numer / (denom * denom);
    }
    OType pow_x_alpha = op::power(x, a);
    // Unnormalized density x^(a-1) * exp(-x), i.e. Gamma(a) * p(x).
    OType gamma_pdf = op::power(x, a - 1) * op::exp(-x);
    // Series value of the lower incomplete gamma function, i.e. Gamma(a) * P.
    OType gamma_cdf = pow_x_alpha * series1;
    // Gamma(a) * dP/da:
    //   d/da[x^a * series1] = log(x) * gamma_cdf - x^a * series2,
    //   minus the psi(a) * gamma_cdf term from the 1 / Gamma(a) normalizer.
    OType gamma_cdf_alpha =
        (op::log(x) - OType(special_functions::cephes::psi<float>(a))) *
            gamma_cdf -
        pow_x_alpha * series2;
    OType result = -gamma_cdf_alpha / gamma_pdf;
    // Degenerate inputs (e.g. 0 / 0) give NaN; report a zero gradient instead.
    return op::isnan(result) ? 0.f : result;
  }
  if (a > 8.0f) {
    // x close to a: closed-form polynomial approximation.
    if (0.9f * a <= x && x <= 1.1f * a) {
      OType numer_1 = 1 + 24 * a * (1 + 12 * a);
      OType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) -
                      65 * x * x / a + a * (107 + 3600 * x);
      OType denom = 1244160 * (a * a) * (a * a);
      return numer_1 * numer_2 / denom;
    }
    OType denom = op::sqrt(8 * a);
    OType term2 = denom / (a - x);
    OType term3 =
        op::power(x - a - a * op::log(x / a), static_cast<OType>(-1.5));
    // The sign with which term3 combines depends on which side of a x lies.
    OType term23 = (x < a) ? term2 - term3 : term2 + term3;
    OType term1 = op::log(x / a) * term23 -
                  op::sqrt(2 / a) * (a + x) / ((a - x) * (a - x));
    // Stirling-series factor 1 + 1/(12a) + 1/(288a^2).
    OType stirling = 1.f + 1.f / (12.f * a) * (1.f + 1.f / (24.f * a));
    OType numer = x * term1;
    return -stirling * numer / denom;
  }
  // Rational approximation: each of the 8 coefficients is a quadratic in u
  // (rows of coef_uv hold the u^0, u^1, u^2 parts). p and q are cubics in v.
  OType u = op::log(x / a);
  OType v = op::log(a);
  OType coef_uv[3][8] = {
      {0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115,
       0.10406089, 0.0014179084},
      {0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465,
       0.020070113, -0.0035938915, -0.00058392623},
      {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642,
       -0.0021309326, 0.00085092367, -1.5247877e-07},
  };
  OType coef_v[8];
#pragma unroll
  for (int i = 0; i < 8; i++) {
    coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
  }
  // Horner evaluation of the numerator (coef_v[0..3]) and denominator
  // (coef_v[4..7]) cubics in v.
  OType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
  OType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
  return op::exp(p / q);
}
614 
615 } // namespace op
616 )code";
617 
618 } // namespace rtc
619 } // namespace cuda
620 } // namespace common
621 } // namespace mxnet
622 
623 #endif // MXNET_USE_CUDA
624 
625 #endif // MXNET_COMMON_CUDA_RTC_BACKWARD_FUNCTIONS_INL_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::common::cuda::rtc::grad_function_definitions
const char grad_function_definitions[]
Definition: backward_functions-inl.h:337
mxnet::common::cuda::rtc::backward_function_definitions
const char backward_function_definitions[]
Definition: backward_functions-inl.h:30