mxnet
forward_functions-inl.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_COMMON_CUDA_RTC_FORWARD_FUNCTIONS_INL_H_
21 #define MXNET_COMMON_CUDA_RTC_FORWARD_FUNCTIONS_INL_H_
22 
23 #if MXNET_USE_CUDA
24 
25 namespace mxnet {
26 namespace common {
27 namespace cuda {
28 namespace rtc {
29 
30 const char function_definitions_util[] = R"code(
31 
32 #define INT_MAX (2147483647)
33 
34 namespace op {
35 using type_util::mixed_type;
36 
// Type trait giving the type that load() returns for a stored DType.
// The primary template maps every type to itself.
template <typename DType>
struct LoadType {
  typedef DType Type;
};

// half is the one exception: it maps to float.
template <>
struct LoadType<half> {
  typedef float Type;
};
46 
// Returns `input` as LoadType<DType>::Type; for the generic case this is a
// plain copy of the value.
template <typename DType>
__device__ inline typename LoadType<DType>::Type load(const DType input) {
  typename LoadType<DType>::Type loaded = input;
  return loaded;
}

// half specialization: the value is converted to float with __half2float.
template <>
__device__ inline float load(const half input) {
  const float widened = __half2float(input);
  return widened;
}
56 
// Converts `input` to the destination element type DType1.
// `ref` is never dereferenced; it only supplies the destination type.
template <typename DType1, typename DType2>
__device__ inline DType1 store(const DType2 input, DType1* ref) {
  const DType1 converted = input;
  return converted;
}

// Overload for a half destination: the value is converted with __float2half.
template <typename DType>
__device__ inline half store(const DType input, half* ref) {
  const half narrowed = __float2half(input);
  return narrowed;
}
66 
// Shape with a compile-time number of dimensions.
//   x    - one int per dimension
//   size - element count, stored separately (nothing here derives it from x)
template <int ndim>
struct Shape {
  int x[ndim];
  size_t size;

  // Read-only access to the extent of dimension i (no bounds check).
  __device__ inline const int& operator [](const int i) const {
    return *(x + i);
  }

  // Writable access to the extent of dimension i (no bounds check).
  __device__ inline int& operator [](const int i) {
    return *(x + i);
  }

  // Writes `def` into every entry of x; `size` is not modified.
  __device__ inline void set(const int def) {
    #pragma unroll
    for (int d = ndim - 1; d >= 0; --d) {
      x[d] = def;
    }
  }
};

// Zero-dimensional specialization: only the element count is kept.
template <>
struct Shape<0> {
  size_t size;
};
89 
// Reads one VectorizedStorage<DType, nvec> starting at element `i` of `input`.
// When `i` is not below shape.size nothing is read and a storage built from
// {0} is returned instead.
template <int nvec, typename DType, int ndim>
__device__ inline vector::VectorizedStorage<DType, nvec> load_index(const DType * input, int i,
                                                                    const Shape<ndim> &shape) {
  using V = vector::VectorizedStorage<DType, nvec>;
  using LType = typename V::LType;
  if (!(i < shape.size)) {
    return V({0});
  }
  // Reinterpret the element pointer as the vector load type and read it whole.
  const LType* vector_input = reinterpret_cast<const LType *>(input + i);
  return V(*vector_input);
}
101 
// Same contract as a guarded vector read at element `i`, but the read goes
// through __ldg. When `i` is not below shape.size a storage built from {0}
// is returned and `input` is not touched.
template <int nvec, typename DType, int ndim>
__device__ inline vector::VectorizedStorage<DType, nvec> global_load_index(const DType * input,
                                                                           int i,
                                                                           const Shape<ndim> &shape) {
  using V = vector::VectorizedStorage<DType, nvec>;
  using LType = typename V::LType;
  if (!(i < shape.size)) {
    return V({0});
  }
  const LType* vector_input = reinterpret_cast<const LType *>(input + i);
  return V(__ldg(vector_input));
}
113 
// Private helper for load_slice / fast_load_slice.
// Resolves the slice bounds in place against `shape`:
//   - a negative begin/end is counted from the end of the dimension,
//   - an end equal to INT_MAX means "up to the end of the dimension".
// It then fills the row-major strides of the slice (`ref_strides`) and of the
// full tensor (`strides`); the innermost stride of both is 1.
template <int ndim>
__device__ inline void slice_resolve_strides(const Shape<ndim>& shape,
                                             Shape<ndim>& begin,
                                             Shape<ndim>& end,
                                             Shape<ndim>& ref_strides,
                                             Shape<ndim>& strides) {
  ref_strides[ndim-1] = 1;
  strides[ndim-1] = 1;
  #pragma unroll
  for (int dim = ndim-1; dim >= 0; dim--) {
    if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim];
    if (end[dim] < 0) end[dim] = shape[dim] + end[dim];
    if (end[dim] == INT_MAX) end[dim] = shape[dim];
    if (dim > 0) {
      ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]);
      strides[dim-1] = strides[dim] * shape[dim];
    }
  }
}

// Private helper for load_slice / fast_load_slice.
// Converts a flat index inside the slice (`ref_idx`) into a flat index inside
// the full tensor. Dimensions whose extent in `shape` is 1 contribute nothing
// to the result.
template <int ndim>
__device__ inline int slice_source_index(const Shape<ndim>& shape,
                                         const Shape<ndim>& begin,
                                         const Shape<ndim>& ref_strides,
                                         const Shape<ndim>& strides,
                                         int ref_idx) {
  int idx = 0;
  #pragma unroll
  for (int dim = 0; dim < ndim; dim++) {
    const int stride = ref_strides[dim];
    if (shape[dim] > 1) {
      idx += (ref_idx / stride + begin[dim]) * strides[dim];
    }
    ref_idx = ref_idx % stride;
  }
  return idx;
}

// Gathers nvec consecutive elements of the slice [begin, end) of `input`,
// starting at flat slice index `offset`. Each element is looked up
// individually, so the elements need not be contiguous in `input`.
// An element whose source index is not below shape.size is returned as DType{}.
template <int nvec, typename DType, int ndim>
__device__ inline vector::VectorizedStorage<DType, nvec> load_slice(const DType * input,
                                                                    const Shape<ndim>& shape,
                                                                    Shape<ndim> begin,
                                                                    Shape<ndim> end,
                                                                    int offset) {
  Shape<ndim> ref_strides;
  Shape<ndim> strides;
  slice_resolve_strides(shape, begin, end, ref_strides, strides);

  vector::VectorizedStorage<DType, nvec> ret;
  #pragma unroll
  for (int j = 0; j < nvec; j++) {
    const int idx = slice_source_index(shape, begin, ref_strides, strides, offset + j);
    ret.scratch_.separate[j] = idx < shape.size ? *(input + idx) : DType {};
  }
  return ret;
}

// Variant of load_slice that maps only the first element (`offset`) into
// `input` and then reads the whole vector from that position with
// global_load_index.
// NOTE(review): this presumes the nvec elements are contiguous and suitably
// aligned in `input`; that precondition is enforced by the caller, not here.
template <int nvec, typename DType, int ndim>
__device__ inline vector::VectorizedStorage<DType, nvec> fast_load_slice(const DType * input,
                                                                         const Shape<ndim>& shape,
                                                                         Shape<ndim> begin,
                                                                         Shape<ndim> end,
                                                                         int offset) {
  Shape<ndim> ref_strides;
  Shape<ndim> strides;
  slice_resolve_strides(shape, begin, end, ref_strides, strides);

  const int idx = slice_source_index(shape, begin, ref_strides, strides, offset);
  return global_load_index<nvec>(input, idx, shape);
}
190 
// Writes `value` as the i-th nvec-wide vector of `output`.
// The write is skipped when `i` is not below ceil(shape.size / nvec).
template <int nvec, typename DType, int ndim>
__device__ inline void store_index(const vector::VectorizedStorage<DType, nvec> value, int i,
                                   DType * output, const Shape<ndim>& shape) {
  using LType = typename vector::VectorizedStorage<DType, nvec>::LType;
  // Number of vectors needed to cover shape.size elements, rounded up.
  const auto num_vectors = (shape.size + nvec - 1) / nvec;
  if (i >= num_vectors) {
    return;
  }
  LType* vector_output = reinterpret_cast<LType *>(output);
  vector_output[i] = value.scratch_.aligned;
}
200 
// Adds `value` to the i-th nvec-wide vector of `output` (read, +=, write back).
// Nothing happens when `i` is not below ceil(shape.size / nvec).
// The read-modify-write is not atomic.
template <int nvec, typename DType, int ndim>
__device__ inline void store_add_index(const vector::VectorizedStorage<DType, nvec> value, int i,
                                       DType * output, const Shape<ndim>& shape) {
  using V = vector::VectorizedStorage<DType, nvec>;
  using LType = typename V::LType;
  const auto num_vectors = (shape.size + nvec - 1) / nvec;
  if (i >= num_vectors) {
    return;
  }
  LType* vector_output = reinterpret_cast<LType *>(output);
  V accumulated(vector_output[i]);
  accumulated += value;
  vector_output[i] = accumulated.scratch_.aligned;
}
212 
213 } // namespace op
214 )code";
215 
216 const char function_definitions_binary[] = R"code(
217 namespace op {
218 
// True when util::isnan reports `val` as NaN.
template <typename DType>
__device__ inline bool isnan(const DType val) {
  const bool result = util::isnan(val);
  return result;
}

// True when util::isinf reports `val` as infinite (either sign).
template <typename DType>
__device__ inline bool_t isinf(const DType val) {
  const bool_t result = util::isinf(val);
  return result;
}

// True only for an infinite value greater than zero.
template <typename DType>
__device__ inline bool_t isposinf(const DType val) {
  if (!util::isinf(val)) {
    return false;
  }
  return val > 0;
}

// True only for an infinite value less than zero.
template <typename DType>
__device__ inline bool_t isneginf(const DType val) {
  if (!util::isinf(val)) {
    return false;
  }
  return val < 0;
}

// True when `val` is neither NaN nor infinite.
template <typename DType>
__device__ inline bool_t isfinite(const DType val) {
  return !(op::isnan(val) || op::isinf(val));
}
243 
// a + b, returned in the mixed type of the two operands.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
add(const DType a, const DType2 b) {
  const mixed_type<DType, DType2> sum = a + b;
  return sum;
}

// a - b, returned in the mixed type of the two operands.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
sub(const DType a, const DType2 b) {
  const mixed_type<DType, DType2> difference = a - b;
  return difference;
}

// Reversed subtraction: b - a.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rsub(const DType a, const DType2 b) {
  const mixed_type<DType, DType2> difference = b - a;
  return difference;
}
261 
// floor(a / b). The quotient is formed in double when either operand is
// double or integral, and in float otherwise.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
floor_divide(const DType a, const DType2 b) {
  if (!type_util::has_double_or_integral<DType, DType2>::value) {
    const float quotient = static_cast<float>(a) / static_cast<float>(b);
    return ::floorf(quotient);
  }
  const double quotient = static_cast<double>(a) / static_cast<double>(b);
  return ::floor(quotient);
}

// Reversed floor division: floor(b / a), same precision rule as floor_divide.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rfloor_divide(const DType a, const DType2 b) {
  if (!type_util::has_double_or_integral<DType, DType2>::value) {
    const float quotient = static_cast<float>(b) / static_cast<float>(a);
    return ::floorf(quotient);
  }
  const double quotient = static_cast<double>(b) / static_cast<double>(a);
  return ::floor(quotient);
}
281 
// a * b, returned in the mixed type of the two operands.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
mul(const DType a, const DType2 b) {
  const mixed_type<DType, DType2> product = a * b;
  return product;
}

// a / b using the built-in division of the operand types.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
div(const DType a, const DType2 b) {
  const mixed_type<DType, DType2> quotient = a / b;
  return quotient;
}

// Reversed division: b / a.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rdiv(const DType a, const DType2 b) {
  const mixed_type<DType, DType2> quotient = b / a;
  return quotient;
}
299 
300 #define DEFINE_BINARY_MATH_FUNC(name, double_version, float_version) \
301 template <typename DType, typename DType2> \
302 __device__ inline mixed_type<DType, DType2> \
303 name (const DType a, const DType2 b) { \
304  if (type_util::has_double_or_integral<DType, DType2>::value) { \
305  return double_version ((double)a, (double)b); \
306  } else { \
307  return float_version ((float)a, (float)b); \
308  } \
309 }
310 
// a raised to the power b.
// Uses double-precision ::pow only when one of the operand types is double;
// every other combination (including integral operands) goes through ::powf.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
power (const DType a, const DType2 b) {
  if (type_util::has_double<DType, DType2>::value) {
    return ::pow ((double)a, (double)b);
  } else {
    return ::powf ((float)a, (float)b);
  }
}
320 
// Reversed power: b raised to the power a, delegated to power().
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rpow(const DType a, const DType2 b) {
  const auto result = power(b, a);
  return result;
}
326 
// Larger of a and b. A NaN in `a` is returned immediately; a NaN in `b`
// makes the comparison false, so `b` is returned.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
max(const DType a, const DType2 b) {
  const bool a_is_nan = isnan(a);
  if (a_is_nan) {
    return a;
  }
  const auto larger = a > b ? a : b;
  return larger;
}

// Larger of a and b, returning `a` whenever `b` is NaN.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
fmax(const DType a, const DType2 b) {
  const bool b_is_nan = isnan(b);
  if (b_is_nan) {
    return a;
  }
  const auto larger = a > b ? a : b;
  return larger;
}

// Smaller of a and b. A NaN in `a` is returned immediately; a NaN in `b`
// makes the comparison false, so `b` is returned.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
min(const DType a, const DType2 b) {
  const bool a_is_nan = isnan(a);
  if (a_is_nan) {
    return a;
  }
  const auto smaller = a < b ? a : b;
  return smaller;
}

// Smaller of a and b, returning `a` whenever `b` is NaN.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
fmin(const DType a, const DType2 b) {
  const bool b_is_nan = isnan(b);
  if (b_is_nan) {
    return a;
  }
  const auto smaller = a < b ? a : b;
  return smaller;
}
354 
355 DEFINE_BINARY_MATH_FUNC(hypot, ::hypot, ::hypotf)
356 
// Modulo whose non-zero result takes the sign of the divisor `b`.
// Returns 0 when b == 0. The computation is carried out in double.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
mod(const DType a, const DType2 b) {
  if (b == 0) {
    return 0;
  }
  const double ad = static_cast<double>(a);
  const double bd = static_cast<double>(b);
  if (bd < 0) {
    if (ad < 0) {
      // Both negative: remainder of the magnitudes, negated.
      return -::fmod(-ad, -bd);
    }
    // a >= 0, b < 0: shift a non-zero remainder down by |b|.
    const double rem = ::fmod(ad, -bd);
    return rem + (rem != 0 ? bd : 0);
  }
  if (ad < 0) {
    // a < 0, b > 0: shift a non-zero remainder up by b.
    const double rem = ::fmod(-ad, bd);
    return -rem + (rem != 0 ? bd : 0);
  }
  return ::fmod(ad, bd);
}
381 
// Remainder of a / b as computed by double-precision ::fmod.
// Returns 0 when b == 0.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
fmod(const DType a, const DType2 b) {
  if (b == 0) {
    return 0;
  }
  const double dividend = static_cast<double>(a);
  const double divisor = static_cast<double>(b);
  return ::fmod(dividend, divisor);
}

// Reversed mod: op::mod with the operands swapped.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rmod(const DType a, const DType2 b) {
  const auto result = op::mod(b, a);
  return result;
}

// Reversed fmod: op::fmod with the operands swapped.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rfmod(const DType a, const DType2 b) {
  const auto result = op::fmod(b, a);
  return result;
}
402 
// Comparison operators returning 1 or 0 in the type of the first operand.
// Both operands are converted to their mixed type before being compared.

// 1 when a == b, else 0.
template <typename DType, typename DType2>
__device__ inline DType equal(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  const bool holds = static_cast<T>(a) == static_cast<T>(b);
  return holds ? 1 : 0;
}

// 1 when a != b, else 0.
template <typename DType, typename DType2>
__device__ inline DType not_equal(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  const bool holds = static_cast<T>(a) != static_cast<T>(b);
  return holds ? 1 : 0;
}

// 1 when a > b, else 0.
template <typename DType, typename DType2>
__device__ inline DType greater(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  const bool holds = static_cast<T>(a) > static_cast<T>(b);
  return holds ? 1 : 0;
}

// 1 when a >= b, else 0.
template <typename DType, typename DType2>
__device__ inline DType greater_equal(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  const bool holds = static_cast<T>(a) >= static_cast<T>(b);
  return holds ? 1 : 0;
}

// 1 when a < b, else 0.
template <typename DType, typename DType2>
__device__ inline DType less(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  const bool holds = static_cast<T>(a) < static_cast<T>(b);
  return holds ? 1 : 0;
}

// 1 when a <= b, else 0.
template <typename DType, typename DType2>
__device__ inline DType less_equal(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  const bool holds = static_cast<T>(a) <= static_cast<T>(b);
  return holds ? 1 : 0;
}
444 
// Comparison operators returning bool_t.
// Both operands are converted to their mixed type before being compared.

// a == b
template <typename DType, typename DType2>
__device__ inline bool_t np_equal(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(a) == static_cast<T>(b);
}

// a != b
template <typename DType, typename DType2>
__device__ inline bool_t np_not_equal(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(a) != static_cast<T>(b);
}

// a > b
template <typename DType, typename DType2>
__device__ inline bool_t np_greater(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(a) > static_cast<T>(b);
}

// a >= b
template <typename DType, typename DType2>
__device__ inline bool_t np_greater_equal(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(a) >= static_cast<T>(b);
}

// a < b
template <typename DType, typename DType2>
__device__ inline bool_t np_less(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(a) < static_cast<T>(b);
}

// a <= b
template <typename DType, typename DType2>
__device__ inline bool_t np_less_equal(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(a) <= static_cast<T>(b);
}
486 
// Logical AND of the truth values of a and b, as 1 or 0 in DType.
template <typename DType, typename DType2>
__device__ inline DType logical_and(const DType a, const DType2 b) {
  if (a && b) {
    return 1;
  }
  return 0;
}

// Logical OR of the truth values of a and b, as 1 or 0 in DType.
template <typename DType, typename DType2>
__device__ inline DType logical_or(const DType a, const DType2 b) {
  if (a || b) {
    return 1;
  }
  return 0;
}

// Logical XOR: 1 when exactly one of a, b is truthy, otherwise 0.
template <typename DType, typename DType2>
__device__ inline DType logical_xor(const DType a, const DType2 b) {
  const bool any_set = a || b;
  const bool both_set = a && b;
  if (any_set && !both_set) {
    return 1;
  }
  return 0;
}
501 
// Returns `a` when a and b fall on the same side of the "< 0" test,
// and -a otherwise. Signs are judged by comparison with 0 only.
template <typename DType, typename DType2>
__device__ inline DType copysign(const DType a, const DType2 b) {
  const bool same_side = (a >= 0 && b >= 0) || (a < 0 && b < 0);
  return same_side ? a : -a;
}

// Reversed copysign: copysign with the operands swapped.
template <typename DType, typename DType2>
__device__ inline DType2 rcopysign(const DType a, const DType2 b) {
  const DType2 result = copysign(b, a);
  return result;
}
511 
// Least common multiple of |a| and |b| for integral operands; 0 when either
// operand is 0. Returns 0 whenever one of the operand types is not integral.
//
// All intermediate values live in the mixed type of the operands, so an
// operand of the wider type is never truncated into the narrower one while
// the Euclidean steps swap the two values.
// NOTE(review): assumes mixed_type<DType, DType2> can represent every value of
// both operand types - confirm against type_util.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
lcm(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  if (type_util::is_integral<DType>::value &&
      type_util::is_integral<DType2>::value) {
    // handle zero-valued cases.
    if (a == 0 || b == 0) {
      return 0;
    }
    // minus cases: work on magnitudes, negating in the common type.
    T A = a;
    T B = b;
    if (a < 0) {
      A = -A;
    }
    if (b < 0) {
      B = -B;
    }
    const T abs_a = A;
    const T abs_b = B;
    if (A < B) {
      const T tmp = A;
      A = B;
      B = tmp;
    }
    // Euclid: on exit B holds gcd(|a|, |b|).
    while (A % B != 0) {
      const T remainder = A % B;
      A = B;
      B = remainder;
    }
    // Divide first to keep the intermediate value small.
    return abs_a / B * abs_b;
  } else {
    return 0;
  }
}
552 
// Greatest common divisor of |a| and |b| for integral operands.
// gcd(0, b) = |b|, gcd(a, 0) = |a|, gcd(0, 0) = 0.
// Returns 0 whenever one of the operand types is not integral.
//
// All intermediate values live in the mixed type of the operands, so an
// operand of the wider type is never truncated into the narrower one.
// NOTE(review): assumes mixed_type<DType, DType2> can represent every value of
// both operand types - confirm against type_util.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
gcd(const DType a, const DType2 b) {
  using T = mixed_type<DType, DType2>;
  if (type_util::is_integral<DType>::value &&
      type_util::is_integral<DType2>::value) {
    // minus cases: work on magnitudes, negating in the common type.
    T A = a;
    T B = b;
    if (a < 0) {
      A = -A;
    }
    if (b < 0) {
      B = -B;
    }
    // handle zero-valued cases.
    if (a == 0) {
      return B;  // also yields 0 when b == 0
    }
    if (b == 0) {
      return A;
    }
    if (A < B) {
      const T tmp = A;
      A = B;
      B = tmp;
    }
    // Euclid: on exit B holds the gcd.
    while (A % B != 0) {
      const T remainder = A % B;
      A = B;
      B = remainder;
    }
    return B;
  } else {
    return 0;
  }
}
595 
// Bitwise operators. Both operands are first converted to their mixed type
// and the operator is applied to the converted values.

// a ^ b
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> bitwise_xor(const DType a,
                                                        const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(a) ^ static_cast<T>(b);
}

// a | b
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> bitwise_or(const DType a,
                                                       const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(a) | static_cast<T>(b);
}

// a & b
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> bitwise_and(const DType a,
                                                        const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(a) & static_cast<T>(b);
}

// a << b
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> bitwise_left_shift(const DType a,
                                                               const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(a) << static_cast<T>(b);
}

// Reversed left shift: b << a
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> rbitwise_left_shift(const DType a,
                                                                const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(b) << static_cast<T>(a);
}

// a >> b
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> bitwise_right_shift(const DType a,
                                                                const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(a) >> static_cast<T>(b);
}

// Reversed right shift: b >> a
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> rbitwise_right_shift(const DType a,
                                                                 const DType2 b) {
  using T = mixed_type<DType, DType2>;
  return static_cast<T>(b) >> static_cast<T>(a);
}
651 
652 DEFINE_BINARY_MATH_FUNC(arctan2, ::atan2, ::atan2f)
653 
// Reversed arctan2: arctan2 with the operands swapped.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rarctan2(const DType a, const DType2 b) {
  const auto result = arctan2(b, a);
  return result;
}
659 
// a * 2^b. The power of two is evaluated with ::pow in double when either
// operand is double or integral, and with ::powf in float otherwise.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
ldexp(const DType a, const DType2 b) {
  if (!type_util::has_double_or_integral<DType, DType2>::value) {
    const float scale = ::powf(2.0f, static_cast<float>(b));
    return a * scale;
  }
  const double scale = ::pow(2.0, static_cast<double>(b));
  return a * scale;
}

// Reversed ldexp: b * 2^a.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rldexp(const DType a, const DType2 b) {
  const auto result = ldexp(b, a);
  return result;
}
675 
// log(exp(a) + exp(b)).
// Evaluated in double when either operand is double or integral, otherwise
// entirely in single precision (::expf / ::logf).
// The sum of exponentials is formed directly, so large inputs overflow to
// +inf exactly as exp() itself does.
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
logaddexp(const DType a, const DType2 b) {
  if (type_util::has_double_or_integral<DType, DType2>::value) {
    return ::log(::exp(static_cast<double>(a)) + ::exp(static_cast<double>(b)));
  } else {
    // Use the explicit single-precision logarithm, matching the other
    // float paths in this file.
    return ::logf(::expf(static_cast<float>(a)) + ::expf(static_cast<float>(b)));
  }
}
685 
686 #undef DEFINE_BINARY_MATH_FUNC
687 
// Logical AND of the truth values of the two operands.
template <typename DType, typename DType2>
__device__ inline bool np_logical_and(const DType val, const DType2 val2) {
  if (val && val2) {
    return true;
  }
  return false;
}

// Logical OR of the truth values of the two operands.
template <typename DType, typename DType2>
__device__ inline bool np_logical_or(const DType val, const DType2 val2) {
  if (val || val2) {
    return true;
  }
  return false;
}

// Logical XOR: true when exactly one operand is truthy.
template <typename DType, typename DType2>
__device__ inline bool np_logical_xor(const DType val, const DType2 val2) {
  const bool any_set = val || val2;
  const bool both_set = val && val2;
  return any_set && !both_set;
}
702 
// Selects the first operand; the second is ignored.
template <typename DType, typename DType2>
__device__ inline DType left(const DType left_val, const DType2 right_val) {
  const DType chosen = left_val;
  return chosen;
}

// Selects the second operand; the first is ignored.
template <typename DType, typename DType2>
__device__ inline DType2 right(const DType left_val, const DType2 right_val) {
  const DType2 chosen = right_val;
  return chosen;
}
712 
713 } // namespace op
714 )code";
715 
716 const char function_definitions_unary[] = R"code(
717 namespace op {
718 
// Returns the argument unchanged.
template <typename DType>
__device__ inline DType identity(const DType val) {
  const DType same = val;
  return same;
}

// Arithmetic negation: -val.
template <typename DType>
__device__ inline DType negation(const DType val) {
  const DType negated = -val;
  return negated;
}

// Converts `val` to the load type associated with OutType
// (LoadType<OutType>::Type) using static_cast.
template <typename OutType, typename DType>
__device__ inline typename LoadType<OutType>::Type cast(const DType val) {
  using Target = typename LoadType<OutType>::Type;
  return static_cast<Target>(val);
}
733 
734 // activations
735 
// Rectified linear unit: positive values and NaN pass through, anything
// else becomes 0.
template <typename DType>
__device__ inline DType relu(const DType val) {
  const bool keep = isnan(val) || val > 0;
  return keep ? val : 0;
}

// Logistic sigmoid 1 / (1 + exp(-val)); double math for double or integral
// DType, float math otherwise.
template <typename DType>
__device__ inline DType sigmoid(const DType val) {
  if (!type_util::has_double_or_integral<DType>::value) {
    const auto denominator = 1 + expf(-val);
    return 1.f / denominator;
  }
  const auto denominator = 1 + ::exp(-val);
  return 1. / denominator;
}

// Logarithm of the logistic sigmoid, computed literally as
// log(1 / (1 + exp(-val))).
template <typename DType>
__device__ inline DType log_sigmoid(const DType val) {
  if (!type_util::has_double_or_integral<DType>::value) {
    const auto sig = 1.f / (1 + expf(-val));
    return ::logf(sig);
  }
  const auto sig = 1. / (1 + ::exp(-val));
  return ::log(sig);
}

// Softplus: log(1 + exp(val)).
template <typename DType>
__device__ inline DType softrelu(const DType val) {
  // For inputs above 20 the input itself is returned: this avoids overflow
  // in exp, and at that magnitude softrelu(a) equals a in floating precision.
  if (val > 20) {
    return val;
  }
  if (!type_util::has_double_or_integral<DType>::value) {
    const auto shifted = 1 + expf(val);
    return logf(shifted);
  }
  const auto shifted = 1 + ::exp(val);
  return ::log(shifted);
}

// Softsign: val / (1 + |val|).
template <typename DType>
__device__ inline DType softsign(const DType val) {
  if (!type_util::has_double_or_integral<DType>::value) {
    const auto denominator = 1 + fabsf(val);
    return val / denominator;
  }
  const auto denominator = 1 + fabs(val);
  return val / denominator;
}
780 
781 // exp and log
782 
783 #define DEFINE_UNARY_MATH_FUNC(name, double_version, float_version) \
784 template <typename DType> \
785 __device__ inline DType name (const DType a) { \
786  if (type_util::has_double_or_integral<DType>::value) { \
787  return double_version ((double)a); \
788  } else { \
789  return float_version (a); \
790  } \
791 }
792 
793 DEFINE_UNARY_MATH_FUNC(exp, ::exp, ::expf)
794 DEFINE_UNARY_MATH_FUNC(expm1, ::expm1, ::expm1f)
795 DEFINE_UNARY_MATH_FUNC(log, ::log, ::logf)
796 DEFINE_UNARY_MATH_FUNC(log10, ::log10, ::log10f)
797 DEFINE_UNARY_MATH_FUNC(log2, ::log2, ::log2f)
798 DEFINE_UNARY_MATH_FUNC(log1p, ::log1p, ::log1pf)
799 
800 // trigonometric
801 
802 constexpr double pi = 3.14159265358979323846;
803 
// Radians to degrees: (val / pi) * 180. Double arithmetic for double or
// integral DType, float arithmetic otherwise.
template <typename DType>
__device__ inline DType degrees(const DType val) {
  if (!type_util::has_double_or_integral<DType>::value) {
    const auto half_turns = val / static_cast<float>(pi);
    return half_turns * 180.f;
  }
  const auto half_turns = val / pi;
  return half_turns * 180;
}

// Degrees to radians: (val / 180) * pi. Double arithmetic for double or
// integral DType, float arithmetic otherwise.
template <typename DType>
__device__ inline DType radians(const DType val) {
  if (!type_util::has_double_or_integral<DType>::value) {
    const auto half_turns = val / 180.0f;
    return half_turns * static_cast<float>(pi);
  }
  const auto half_turns = val / 180.0;
  return half_turns * pi;
}
821 
822 DEFINE_UNARY_MATH_FUNC(sin, ::sin, ::sinf)
823 DEFINE_UNARY_MATH_FUNC(cos, ::cos, ::cosf)
824 DEFINE_UNARY_MATH_FUNC(tan, ::tan, ::tanf)
825 DEFINE_UNARY_MATH_FUNC(arcsin, ::asin, ::asinf)
826 DEFINE_UNARY_MATH_FUNC(arccos, ::acos, ::acosf)
827 DEFINE_UNARY_MATH_FUNC(arctan, ::atan, ::atanf)
828 
829 DEFINE_UNARY_MATH_FUNC(sinh, ::sinh, ::sinhf)
830 DEFINE_UNARY_MATH_FUNC(cosh, ::cosh, ::coshf)
831 DEFINE_UNARY_MATH_FUNC(tanh, ::tanh, ::tanhf)
832 DEFINE_UNARY_MATH_FUNC(arcsinh, ::asinh, ::asinhf)
833 DEFINE_UNARY_MATH_FUNC(arccosh, ::acosh, ::acoshf)
834 DEFINE_UNARY_MATH_FUNC(arctanh, ::atanh, ::atanhf)
835 
// Mish activation: val * tanh(softrelu(val)).
template <typename DType>
__device__ inline DType mish(const DType val) {
  const auto gate = op::tanh(op::softrelu(val));
  return val * gate;
}
840 
841 // sqrt
842 
843 DEFINE_UNARY_MATH_FUNC(sqrt, ::sqrt, ::sqrtf)
844 DEFINE_UNARY_MATH_FUNC(rsqrt, ::rsqrt, ::rsqrtf)
845 DEFINE_UNARY_MATH_FUNC(cbrt, ::cbrt, ::cbrtf)
846 DEFINE_UNARY_MATH_FUNC(rcbrt, ::rcbrt, ::rcbrtf)
847 
// val multiplied by itself.
template <typename DType>
__device__ inline DType square(const DType val) {
  const DType squared = val * val;
  return squared;
}
852 
// Constant generators. Each comes in two forms: one that accepts (and
// ignores) any number of arguments, deducing DType from the first, and one
// that takes DType explicitly with no arguments. The result type is
// LoadType<DType>::Type.

// Always 0; arguments are ignored.
template <typename DType, typename... DTypes>
__device__ inline typename LoadType<DType>::Type zero(const DType val, const DTypes... args) {
  const typename LoadType<DType>::Type constant = 0;
  return constant;
}

// Always 0.
template <typename DType>
__device__ inline typename LoadType<DType>::Type zero() {
  const typename LoadType<DType>::Type constant = 0;
  return constant;
}

// Always 1; arguments are ignored.
template <typename DType, typename... DTypes>
__device__ inline typename LoadType<DType>::Type one(const DType val, const DTypes... args) {
  const typename LoadType<DType>::Type constant = 1;
  return constant;
}

// Always 1.
template <typename DType>
__device__ inline typename LoadType<DType>::Type one() {
  const typename LoadType<DType>::Type constant = 1;
  return constant;
}

// Always -1; arguments are ignored.
template <typename DType, typename... DTypes>
__device__ inline typename LoadType<DType>::Type negone(const DType val, const DTypes... args) {
  const typename LoadType<DType>::Type constant = -1;
  return constant;
}

// Always -1.
template <typename DType>
__device__ inline typename LoadType<DType>::Type negone() {
  const typename LoadType<DType>::Type constant = -1;
  return constant;
}
882 
// Rounding family. Each function uses the double-precision routine for a
// double DType, returns integral values unchanged, and otherwise calls the
// single-precision routine.

// ::round / ::roundf
template <typename DType>
__device__ inline DType round(const DType val) {
  if (type_util::has_double<DType>::value) {
    return ::round(static_cast<double>(val));
  }
  if (type_util::is_integral<DType>::value) {
    return val;
  }
  return ::roundf(val);
}

// ::floor / ::floorf
template <typename DType>
__device__ inline DType floor(const DType val) {
  if (type_util::has_double<DType>::value) {
    return ::floor(static_cast<double>(val));
  }
  if (type_util::is_integral<DType>::value) {
    return val;
  }
  return ::floorf(val);
}

// ::ceil / ::ceilf
template <typename DType>
__device__ inline DType ceil(const DType val) {
  if (type_util::has_double<DType>::value) {
    return ::ceil(static_cast<double>(val));
  }
  if (type_util::is_integral<DType>::value) {
    return val;
  }
  return ::ceilf(val);
}

// ::rint / ::rintf
template <typename DType>
__device__ inline DType rint(const DType val) {
  if (type_util::has_double<DType>::value) {
    return ::rint(static_cast<double>(val));
  }
  if (type_util::is_integral<DType>::value) {
    return val;
  }
  return ::rintf(val);
}

// Rounds toward zero by picking whichever of floor(val) and ceil(val) has
// the smaller magnitude (ceil when the magnitudes are equal).
template <typename DType>
__device__ inline DType fix(const DType val) {
  const auto lower = floor(val);
  const auto upper = ceil(val);
  const auto lower_mag = lower > 0 ? lower : -lower;
  const auto upper_mag = upper > 0 ? upper : -upper;
  return lower_mag < upper_mag ? lower : upper;
}

// ::trunc / ::truncf
template <typename DType>
__device__ inline DType trunc(const DType val) {
  if (type_util::has_double<DType>::value) {
    return ::trunc(static_cast<double>(val));
  }
  if (type_util::is_integral<DType>::value) {
    return val;
  }
  return ::truncf(val);
}
944 
// Limits `val` to [a_min, a_max]. The upper bound is tested first, so it
// wins if the bounds are inverted. A value that fails both comparisons
// is returned unchanged.
template <typename DType>
__device__ inline DType clip(const DType val, const float a_min, const float a_max) {
  if (val > a_max) {
    return a_max;
  }
  if (val < a_min) {
    return a_min;
  }
  return val;
}

// -1 for negative input, 1 for positive input, 0 otherwise.
template <typename DType>
__device__ inline DType sign(const DType val) {
  if (val < 0) {
    return -1;
  }
  if (val > 0) {
    return 1;
  }
  return 0;
}

// 1 / val.
template <typename DType>
__device__ inline DType reciprocal(const DType val) {
  const auto inverse = 1.0f / val;
  return inverse;
}
966 
967 DEFINE_UNARY_MATH_FUNC(abs, ::fabs, ::fabsf)
968 DEFINE_UNARY_MATH_FUNC(gamma, ::tgamma, ::tgammaf)
969 DEFINE_UNARY_MATH_FUNC(gammaln, ::lgamma, ::lgammaf)
970 DEFINE_UNARY_MATH_FUNC(erf, ::erf, ::erff)
971 DEFINE_UNARY_MATH_FUNC(erfinv, ::erfinv, ::erfinvf)
972 
// GELU in its erf form: 0.5 * val * (1 + erf(val / sqrt(2))).
template <typename DType>
__device__ inline DType gelu_erf(const DType val) {
  const auto scaled = val / op::sqrt(2.0f);
  const auto cdf_term = 1.0f + op::erf(scaled);
  return 0.5f * val * cdf_term;
}
977 
// Smooth L1 with scale `scalar` (s):
//   |val| > 1/s^2  ->  |val| - 0.5 / s^2
//   otherwise      ->  0.5 * val^2 * s^2
template <typename DType1, typename DType2>
__device__ inline DType1 smooth_l1(const DType1 val, const DType2 scalar) {
  const auto scale_sq = scalar * scalar;
  const auto threshold = 1.0f / scale_sq;
  if (val > threshold) {
    return val - 0.5f * threshold;
  }
  if (val < -threshold) {
    return -val - 0.5f * threshold;
  }
  return 0.5f * val * val * scale_sq;
}
990 
// Digamma function via special_functions::cephes::psi, instantiated for
// double when DType is double or integral and for float otherwise.
template <typename DType>
__device__ inline DType digamma(const DType val) {
  if (!type_util::has_double_or_integral<DType>::value) {
    return special_functions::cephes::psi<float>(val);
  }
  return special_functions::cephes::psi<double>(val);
}
999 
// Logical NOT in DType: DType(1) when val equals DType(0), else DType(0).
template <typename DType>
__device__ inline DType logical_not(const DType val) {
  if (val != DType(0)) {
    return DType(0);
  }
  return DType(1);
}

// Logical NOT as bool_t: the negated truth value of `val`.
template <typename DType>
__device__ inline bool_t np_logical_not(const DType val) {
  const bool truth = static_cast<bool>(val);
  return !truth;
}

// True when `val` differs from 0.
template <typename DType>
__device__ inline bool_t NonZero(const DType val) {
  const bool differs = val != 0;
  return differs;
}
1014 
1015 #undef DEFINE_UNARY_MATH_FUNC
1016 
// Bitwise complement. For bool_t the logical negation is returned; every
// other type is widened to int64, complemented, and converted back to DType.
template <typename DType>
__device__ inline DType bitwise_not(const DType a) {
  if (!type_util::is_same<DType, bool_t>::value) {
    return ~static_cast<int64>(a);
  }
  return !a;
}
1025 
1026 } // namespace op
1027 
1028 )code";
1029 
1030 } // namespace rtc
1031 } // namespace cuda
1032 } // namespace common
1033 } // namespace mxnet
1034 
1035 #endif // MXNET_USE_CUDA
1036 
1037 #endif // MXNET_COMMON_CUDA_RTC_FORWARD_FUNCTIONS_INL_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::common::cuda::rtc::function_definitions_util
const char function_definitions_util[]
Definition: forward_functions-inl.h:30
mxnet::common::cuda::rtc::function_definitions_binary
const char function_definitions_binary[]
Definition: forward_functions-inl.h:216
mxnet::common::cuda::rtc::function_definitions_unary
const char function_definitions_unary[]
Definition: forward_functions-inl.h:716