Go to the documentation of this file.
20 #ifndef MXNET_COMMON_CUDA_RTC_FORWARD_FUNCTIONS_INL_H_
21 #define MXNET_COMMON_CUDA_RTC_FORWARD_FUNCTIONS_INL_H_
32 #define INT_MAX (2147483647)
35 using type_util::mixed_type;
37 template <typename DType>
43 struct LoadType<half> {
47 template <typename DType>
48 __device__ inline typename LoadType<DType>::Type load(const DType input) {
53 __device__ inline float load(const half input) {
54 return __half2float(input);
57 template <typename DType1, typename DType2>
58 __device__ inline DType1 store(const DType2 input, DType1* ref) {
62 template <typename DType>
63 __device__ inline half store(const DType input, half* ref) {
64 return __float2half(input);
71 __device__ inline const int& operator [](const int i) const {
74 __device__ inline int& operator [](const int i) {
77 __device__ inline void set(const int def) {
79 for (int i = 0; i < ndim; i++) {
90 template <int nvec, typename DType, int ndim>
91 __device__ inline vector::VectorizedStorage<DType, nvec> load_index(const DType * input, int i,
92 const Shape<ndim> &shape) {
93 using V = vector::VectorizedStorage<DType, nvec>;
95 const auto* vector_input = reinterpret_cast<const typename V::LType *>(input + i);
96 return V(*vector_input);
102 template <int nvec, typename DType, int ndim>
103 __device__ inline vector::VectorizedStorage<DType, nvec> global_load_index(const DType * input,
104 int i, const Shape<ndim> &shape) {
105 using V = vector::VectorizedStorage<DType, nvec>;
106 if (i < shape.size) {
107 const auto* vector_input = reinterpret_cast<const typename V::LType *>(input + i);
108 return V(__ldg(vector_input));
114 template <int nvec, typename DType, int ndim>
115 __device__ inline vector::VectorizedStorage<DType, nvec> load_slice(const DType * input,
116 const Shape<ndim>& shape,
122 Shape<ndim> ref_strides;
124 ref_strides[ndim-1] = 1;
127 for (int dim = ndim-1; dim >=0; dim--) {
128 if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim];
129 if (end[dim] < 0) end[dim] = shape[dim] + end[dim];
130 if (end[dim] == INT_MAX) end[dim] = shape[dim];
132 ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]);
133 strides[dim-1] = strides[dim] * shape[dim];
137 for (int j = 0; j < nvec; j++) {
139 int ref_idx = offset + j;
141 for (int dim = 0; dim < ndim; dim++) {
142 int stride = ref_strides[dim];
143 if (shape[dim] > 1) {
144 idx[j] += (ref_idx / stride + begin[dim]) * strides[dim];
146 ref_idx = ref_idx % stride;
149 vector::VectorizedStorage<DType, nvec> ret;
151 for (int j = 0; j < nvec; j++) {
152 ret.scratch_.separate[j] = idx[j] < shape.size ? *(input + idx[j]) : DType {};
157 template <int nvec, typename DType, int ndim>
158 __device__ inline vector::VectorizedStorage<DType, nvec> fast_load_slice(const DType * input,
159 const Shape<ndim>& shape,
165 Shape<ndim> ref_strides;
167 ref_strides[ndim-1] = 1;
170 for (int dim = ndim-1; dim >=0; dim--) {
171 if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim];
172 if (end[dim] < 0) end[dim] = shape[dim] + end[dim];
173 if (end[dim] == INT_MAX) end[dim] = shape[dim];
175 ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]);
176 strides[dim-1] = strides[dim] * shape[dim];
179 int ref_idx = offset;
181 for (int dim = 0; dim < ndim; dim++) {
182 int stride = ref_strides[dim];
183 if (shape[dim] > 1) {
184 idx += (ref_idx / stride + begin[dim]) * strides[dim];
186 ref_idx = ref_idx % stride;
188 return global_load_index<nvec>(input, idx, shape);
191 template <int nvec, typename DType, int ndim>
192 __device__ inline void store_index(const vector::VectorizedStorage<DType, nvec> value, int i,
193 DType * output, const Shape<ndim>& shape) {
194 if (i < (shape.size + nvec - 1) / nvec) {
195 auto vector_output = reinterpret_cast<
196 typename vector::VectorizedStorage<DType, nvec>::LType *>(output);
197 vector_output[i] = value.scratch_.aligned;
201 template <int nvec, typename DType, int ndim>
202 __device__ inline void store_add_index(const vector::VectorizedStorage<DType, nvec> value, int i,
203 DType * output, const Shape<ndim>& shape) {
204 if (i < (shape.size + nvec - 1) / nvec) {
205 auto vector_output = reinterpret_cast<
206 typename vector::VectorizedStorage<DType, nvec>::LType *>(output);
207 vector::VectorizedStorage<DType, nvec> ret(vector_output[i]);
209 vector_output[i] = ret.scratch_.aligned;
219 template <typename DType>
220 __device__ inline bool isnan(const DType val) {
221 return util::isnan(val);
224 template <typename DType>
225 __device__ inline bool_t isinf(const DType val) {
226 return util::isinf(val);
229 template <typename DType>
230 __device__ inline bool_t isposinf(const DType val) {
231 return util::isinf(val) && (val > 0);
234 template <typename DType>
235 __device__ inline bool_t isneginf(const DType val) {
236 return util::isinf(val) && (val < 0);
239 template <typename DType>
240 __device__ inline bool_t isfinite(const DType val) {
241 return !op::isnan(val) && !op::isinf(val);
244 template <typename DType, typename DType2>
245 __device__ inline mixed_type<DType, DType2>
246 add(const DType a, const DType2 b) {
250 template <typename DType, typename DType2>
251 __device__ inline mixed_type<DType, DType2>
252 sub(const DType a, const DType2 b) {
256 template <typename DType, typename DType2>
257 __device__ inline mixed_type<DType, DType2>
258 rsub(const DType a, const DType2 b) {
262 template <typename DType, typename DType2>
263 __device__ inline mixed_type<DType, DType2>
264 floor_divide(const DType a, const DType2 b) {
265 if (type_util::has_double_or_integral<DType, DType2>::value) {
266 return ::floor((double)a / (double)b);
268 return ::floorf((float)a / (float)b);
272 template <typename DType, typename DType2>
273 __device__ inline mixed_type<DType, DType2>
274 rfloor_divide(const DType a, const DType2 b) {
275 if (type_util::has_double_or_integral<DType, DType2>::value) {
276 return ::floor((double)b / (double)a);
278 return ::floorf((float)b / (float)a);
282 template <typename DType, typename DType2>
283 __device__ inline mixed_type<DType, DType2>
284 mul(const DType a, const DType2 b) {
288 template <typename DType, typename DType2>
289 __device__ inline mixed_type<DType, DType2>
290 div(const DType a, const DType2 b) {
294 template <typename DType, typename DType2>
295 __device__ inline mixed_type<DType, DType2>
296 rdiv(const DType a, const DType2 b) {
300 #define DEFINE_BINARY_MATH_FUNC(name, double_version, float_version) \
301 template <typename DType, typename DType2> \
302 __device__ inline mixed_type<DType, DType2> \
303 name (const DType a, const DType2 b) { \
304 if (type_util::has_double_or_integral<DType, DType2>::value) { \
305 return double_version ((double)a, (double)b); \
307 return float_version ((float)a, (float)b); \
311 template <typename DType, typename DType2>
312 __device__ inline mixed_type<DType, DType2>
313 power (const DType a, const DType2 b) {
314 if (type_util::has_double<DType, DType2>::value) {
315 return ::pow ((double)a, (double)b); \
317 return ::powf ((float)a, (float)b);
321 template <typename DType, typename DType2>
322 __device__ inline mixed_type<DType, DType2>
323 rpow(const DType a, const DType2 b) {
327 template <typename DType, typename DType2>
328 __device__ inline mixed_type<DType, DType2>
329 max(const DType a, const DType2 b) {
330 if (isnan(a)) return a;
331 return a > b ? a : b;
334 template <typename DType, typename DType2>
335 __device__ inline mixed_type<DType, DType2>
336 fmax(const DType a, const DType2 b) {
337 if (isnan(b)) return a;
338 return a > b ? a : b;
341 template <typename DType, typename DType2>
342 __device__ inline mixed_type<DType, DType2>
343 min(const DType a, const DType2 b) {
344 if (isnan(a)) return a;
345 return a < b ? a : b;
348 template <typename DType, typename DType2>
349 __device__ inline mixed_type<DType, DType2>
350 fmin(const DType a, const DType2 b) {
351 if (isnan(b)) return a;
352 return a < b ? a : b;
355 DEFINE_BINARY_MATH_FUNC(hypot, ::hypot, ::hypotf)
357 template <typename DType, typename DType2>
358 __device__ inline mixed_type<DType, DType2>
359 mod(const DType a, const DType2 b) {
363 const double ad = static_cast<double>(a);
364 const double bd = static_cast<double>(b);
367 return -::fmod(-ad, -bd);
369 return ::fmod(ad, -bd) +
370 (::fmod(ad, -bd) != 0 ? bd : 0);
374 return -::fmod(-ad, bd) +
375 (::fmod(-ad, bd) != 0 ? bd : 0);
377 return ::fmod(ad, bd);
382 template <typename DType, typename DType2>
383 __device__ inline mixed_type<DType, DType2>
384 fmod(const DType a, const DType2 b) {
388 return ::fmod(static_cast<double>(a), static_cast<double>(b));
391 template <typename DType, typename DType2>
392 __device__ inline mixed_type<DType, DType2>
393 rmod(const DType a, const DType2 b) {
394 return op::mod(b, a);
397 template <typename DType, typename DType2>
398 __device__ inline mixed_type<DType, DType2>
399 rfmod(const DType a, const DType2 b) {
400 return op::fmod(b, a);
403 template <typename DType, typename DType2>
404 __device__ inline DType equal(const DType a, const DType2 b) {
405 const mixed_type<DType, DType2> real_a = a;
406 const mixed_type<DType, DType2> real_b = b;
407 return real_a == real_b ? 1 : 0;
410 template <typename DType, typename DType2>
411 __device__ inline DType not_equal(const DType a, const DType2 b) {
412 const mixed_type<DType, DType2> real_a = a;
413 const mixed_type<DType, DType2> real_b = b;
414 return real_a != real_b ? 1 : 0;
417 template <typename DType, typename DType2>
418 __device__ inline DType greater(const DType a, const DType2 b) {
419 const mixed_type<DType, DType2> real_a = a;
420 const mixed_type<DType, DType2> real_b = b;
421 return real_a > real_b ? 1 : 0;
424 template <typename DType, typename DType2>
425 __device__ inline DType greater_equal(const DType a, const DType2 b) {
426 const mixed_type<DType, DType2> real_a = a;
427 const mixed_type<DType, DType2> real_b = b;
428 return real_a >= real_b ? 1 : 0;
431 template <typename DType, typename DType2>
432 __device__ inline DType less(const DType a, const DType2 b) {
433 const mixed_type<DType, DType2> real_a = a;
434 const mixed_type<DType, DType2> real_b = b;
435 return real_a < real_b ? 1 : 0;
438 template <typename DType, typename DType2>
439 __device__ inline DType less_equal(const DType a, const DType2 b) {
440 const mixed_type<DType, DType2> real_a = a;
441 const mixed_type<DType, DType2> real_b = b;
442 return real_a <= real_b ? 1 : 0;
445 template <typename DType, typename DType2>
446 __device__ inline bool_t np_equal(const DType a, const DType2 b) {
447 const mixed_type<DType, DType2> real_a = a;
448 const mixed_type<DType, DType2> real_b = b;
449 return real_a == real_b ? true : false;
452 template <typename DType, typename DType2>
453 __device__ inline bool_t np_not_equal(const DType a, const DType2 b) {
454 const mixed_type<DType, DType2> real_a = a;
455 const mixed_type<DType, DType2> real_b = b;
456 return real_a != real_b ? true : false;
459 template <typename DType, typename DType2>
460 __device__ inline bool_t np_greater(const DType a, const DType2 b) {
461 const mixed_type<DType, DType2> real_a = a;
462 const mixed_type<DType, DType2> real_b = b;
463 return real_a > real_b ? true : false;
466 template <typename DType, typename DType2>
467 __device__ inline bool_t np_greater_equal(const DType a, const DType2 b) {
468 const mixed_type<DType, DType2> real_a = a;
469 const mixed_type<DType, DType2> real_b = b;
470 return real_a >= real_b ? true : false;
473 template <typename DType, typename DType2>
474 __device__ inline bool_t np_less(const DType a, const DType2 b) {
475 const mixed_type<DType, DType2> real_a = a;
476 const mixed_type<DType, DType2> real_b = b;
477 return real_a < real_b ? true : false;
480 template <typename DType, typename DType2>
481 __device__ inline bool_t np_less_equal(const DType a, const DType2 b) {
482 const mixed_type<DType, DType2> real_a = a;
483 const mixed_type<DType, DType2> real_b = b;
484 return real_a <= real_b ? true : false;
487 template <typename DType, typename DType2>
488 __device__ inline DType logical_and(const DType a, const DType2 b) {
489 return a && b ? 1 : 0;
492 template <typename DType, typename DType2>
493 __device__ inline DType logical_or(const DType a, const DType2 b) {
494 return a || b ? 1 : 0;
497 template <typename DType, typename DType2>
498 __device__ inline DType logical_xor(const DType a, const DType2 b) {
499 return ((a || b) && !(a && b)) ? 1 : 0;
502 template <typename DType, typename DType2>
503 __device__ inline DType copysign(const DType a, const DType2 b) {
504 return (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a;
507 template <typename DType, typename DType2>
508 __device__ inline DType2 rcopysign(const DType a, const DType2 b) {
509 return copysign(b, a);
512 template <typename DType, typename DType2>
513 __device__ inline mixed_type<DType, DType2>
514 lcm(const DType a, const DType2 b) {
515 if (type_util::is_integral<DType>::value &&
516 type_util::is_integral<DType2>::value) {
526 // handle zero-valued cases.
528 if (a == 0 || b == 0) {
545 c = tmp_a / B * tmp_b;
553 template <typename DType, typename DType2>
554 __device__ inline mixed_type<DType, DType2>
555 gcd(const DType a, const DType2 b) {
556 if (type_util::is_integral<DType>::value &&
557 type_util::is_integral<DType2>::value) {
567 // handle zero-valued cases.
569 if (a == 0 && b != 0) {
571 } else if (b == 0 && a != 0) {
573 } else if (a == 0 && b == 0) {
596 template <typename DType, typename DType2>
597 __device__ inline mixed_type<DType, DType2> bitwise_xor(const DType a,
599 const mixed_type<DType, DType2> real_a = a;
600 const mixed_type<DType, DType2> real_b = b;
601 return real_a ^ real_b;
604 template <typename DType, typename DType2>
605 __device__ inline mixed_type<DType, DType2> bitwise_or(const DType a,
607 const mixed_type<DType, DType2> real_a = a;
608 const mixed_type<DType, DType2> real_b = b;
609 return real_a | real_b;
612 template <typename DType, typename DType2>
613 __device__ inline mixed_type<DType, DType2> bitwise_and(const DType a,
615 const mixed_type<DType, DType2> real_a = a;
616 const mixed_type<DType, DType2> real_b = b;
617 return real_a & real_b;
620 template <typename DType, typename DType2>
621 __device__ inline mixed_type<DType, DType2> bitwise_left_shift(const DType a,
623 const mixed_type<DType, DType2> real_a = a;
624 const mixed_type<DType, DType2> real_b = b;
625 return real_a << real_b;
628 template <typename DType, typename DType2>
629 __device__ inline mixed_type<DType, DType2> rbitwise_left_shift(const DType a,
631 const mixed_type<DType, DType2> real_a = a;
632 const mixed_type<DType, DType2> real_b = b;
633 return real_b << real_a;
636 template <typename DType, typename DType2>
637 __device__ inline mixed_type<DType, DType2> bitwise_right_shift(const DType a,
639 const mixed_type<DType, DType2> real_a = a;
640 const mixed_type<DType, DType2> real_b = b;
641 return real_a >> real_b;
644 template <typename DType, typename DType2>
645 __device__ inline mixed_type<DType, DType2> rbitwise_right_shift(const DType a,
647 const mixed_type<DType, DType2> real_a = a;
648 const mixed_type<DType, DType2> real_b = b;
649 return real_b >> real_a;
652 DEFINE_BINARY_MATH_FUNC(arctan2, ::atan2, ::atan2f)
654 template <typename DType, typename DType2>
655 __device__ inline mixed_type<DType, DType2>
656 rarctan2(const DType a, const DType2 b) {
657 return arctan2(b, a);
660 template <typename DType, typename DType2>
661 __device__ inline mixed_type<DType, DType2>
662 ldexp(const DType a, const DType2 b) {
663 if (type_util::has_double_or_integral<DType, DType2>::value) {
664 return a * ::pow(2.0, static_cast<double>(b));
666 return a * ::powf(2.0f, static_cast<float>(b));
670 template <typename DType, typename DType2>
671 __device__ inline mixed_type<DType, DType2>
672 rldexp(const DType a, const DType2 b) {
676 template <typename DType, typename DType2>
677 __device__ inline mixed_type<DType, DType2>
678 logaddexp(const DType a, const DType2 b) {
679 if (type_util::has_double_or_integral<DType, DType2>::value) {
680 return ::log(::exp(static_cast<double>(a)) + ::exp(static_cast<double>(b)));
682 return ::log(::expf(static_cast<float>(a)) + ::expf(static_cast<float>(b)));
686 #undef DEFINE_BINARY_MATH_FUNC
688 template <typename DType, typename DType2>
689 __device__ inline bool np_logical_and(const DType val, const DType2 val2) {
690 return (val && val2) ? true : false;
693 template <typename DType, typename DType2>
694 __device__ inline bool np_logical_or(const DType val, const DType2 val2) {
695 return (val || val2) ? true : false;
698 template <typename DType, typename DType2>
699 __device__ inline bool np_logical_xor(const DType val, const DType2 val2) {
700 return ((val || val2) && !(val && val2)) ? true : false;
703 template <typename DType, typename DType2>
704 __device__ inline DType left(const DType left_val, const DType2 right_val) {
708 template <typename DType, typename DType2>
709 __device__ inline DType2 right(const DType left_val, const DType2 right_val) {
719 template <typename DType>
720 __device__ inline DType identity(const DType val) {
724 template <typename DType>
725 __device__ inline DType negation(const DType val) {
729 template <typename OutType, typename DType>
730 __device__ inline typename LoadType<OutType>::Type cast(const DType val) {
731 return static_cast<typename LoadType<OutType>::Type>(val);
736 template <typename DType>
737 __device__ inline DType relu(const DType val) {
738 return (isnan(val) || val > 0) ? val : 0;
741 template <typename DType>
742 __device__ inline DType sigmoid(const DType val) {
743 if (type_util::has_double_or_integral<DType>::value) {
744 return 1./(1 + ::exp(-val));
746 return 1.f/(1 + expf(-val));
750 template <typename DType>
751 __device__ inline DType log_sigmoid(const DType val) {
752 if (type_util::has_double_or_integral<DType>::value) {
753 return ::log(1./(1 + ::exp(-val)));
755 return ::logf(1.f/(1 + expf(-val)));
759 template <typename DType>
760 __device__ inline DType softrelu(const DType val) {
761 // Avoid overflow of exp for large inputs.
762 // The threshold 20 is chosen such that softrelu(a) = a
763 // for a > 20 using floating precision.
764 if (val > 20) return val;
765 if (type_util::has_double_or_integral<DType>::value) {
766 return ::log(1 + ::exp(val));
768 return logf(1 + expf(val));
772 template <typename DType>
773 __device__ inline DType softsign(const DType val) {
774 if (type_util::has_double_or_integral<DType>::value) {
775 return val / (1 + fabs(val));
777 return val / (1 + fabsf(val));
783 #define DEFINE_UNARY_MATH_FUNC(name, double_version, float_version) \
784 template <typename DType> \
785 __device__ inline DType name (const DType a) { \
786 if (type_util::has_double_or_integral<DType>::value) { \
787 return double_version ((double)a); \
789 return float_version (a); \
793 DEFINE_UNARY_MATH_FUNC(exp, ::exp, ::expf)
794 DEFINE_UNARY_MATH_FUNC(expm1, ::expm1, ::expm1f)
795 DEFINE_UNARY_MATH_FUNC(log, ::log, ::logf)
796 DEFINE_UNARY_MATH_FUNC(log10, ::log10, ::log10f)
797 DEFINE_UNARY_MATH_FUNC(log2, ::log2, ::log2f)
798 DEFINE_UNARY_MATH_FUNC(log1p, ::log1p, ::log1pf)
802 constexpr double pi = 3.14159265358979323846;
804 template <typename DType>
805 __device__ inline DType degrees(const DType val) {
806 if (type_util::has_double_or_integral<DType>::value) {
807 return (val / pi) * 180;
809 return (val / static_cast<float>(pi)) * 180.f;
813 template <typename DType>
814 __device__ inline DType radians(const DType val) {
815 if (type_util::has_double_or_integral<DType>::value) {
816 return (val / 180.0) * pi;
818 return (val / 180.0f) * static_cast<float>(pi);
822 DEFINE_UNARY_MATH_FUNC(sin, ::sin, ::sinf)
823 DEFINE_UNARY_MATH_FUNC(cos, ::cos, ::cosf)
824 DEFINE_UNARY_MATH_FUNC(tan, ::tan, ::tanf)
825 DEFINE_UNARY_MATH_FUNC(arcsin, ::asin, ::asinf)
826 DEFINE_UNARY_MATH_FUNC(arccos, ::acos, ::acosf)
827 DEFINE_UNARY_MATH_FUNC(arctan, ::atan, ::atanf)
829 DEFINE_UNARY_MATH_FUNC(sinh, ::sinh, ::sinhf)
830 DEFINE_UNARY_MATH_FUNC(cosh, ::cosh, ::coshf)
831 DEFINE_UNARY_MATH_FUNC(tanh, ::tanh, ::tanhf)
832 DEFINE_UNARY_MATH_FUNC(arcsinh, ::asinh, ::asinhf)
833 DEFINE_UNARY_MATH_FUNC(arccosh, ::acosh, ::acoshf)
834 DEFINE_UNARY_MATH_FUNC(arctanh, ::atanh, ::atanhf)
836 template <typename DType>
837 __device__ inline DType mish(const DType val) {
838 return val * op::tanh(op::softrelu(val));
843 DEFINE_UNARY_MATH_FUNC(sqrt, ::sqrt, ::sqrtf)
844 DEFINE_UNARY_MATH_FUNC(rsqrt, ::rsqrt, ::rsqrtf)
845 DEFINE_UNARY_MATH_FUNC(cbrt, ::cbrt, ::cbrtf)
846 DEFINE_UNARY_MATH_FUNC(rcbrt, ::rcbrt, ::rcbrtf)
848 template <typename DType>
849 __device__ inline DType square(const DType val) {
853 template <typename DType, typename... DTypes>
854 __device__ inline typename LoadType<DType>::Type zero(const DType val, const DTypes... args) {
858 template <typename DType>
859 __device__ inline typename LoadType<DType>::Type zero() {
863 template <typename DType, typename... DTypes>
864 __device__ inline typename LoadType<DType>::Type one(const DType val, const DTypes... args) {
868 template <typename DType>
869 __device__ inline typename LoadType<DType>::Type one() {
873 template <typename DType, typename... DTypes>
874 __device__ inline typename LoadType<DType>::Type negone(const DType val, const DTypes... args) {
878 template <typename DType>
879 __device__ inline typename LoadType<DType>::Type negone() {
883 template <typename DType>
884 __device__ inline DType round(const DType val) {
885 if (type_util::has_double<DType>::value) {
886 return ::round((double)val);
887 } else if (type_util::is_integral<DType>::value) {
890 return ::roundf(val);
894 template <typename DType>
895 __device__ inline DType floor(const DType val) {
896 if (type_util::has_double<DType>::value) {
897 return ::floor((double)val);
898 } else if (type_util::is_integral<DType>::value) {
901 return ::floorf(val);
905 template <typename DType>
906 __device__ inline DType ceil(const DType val) {
907 if (type_util::has_double<DType>::value) {
908 return ::ceil((double)val);
909 } else if (type_util::is_integral<DType>::value) {
916 template <typename DType>
917 __device__ inline DType rint(const DType val) {
918 if (type_util::has_double<DType>::value) {
919 return ::rint((double)val);
920 } else if (type_util::is_integral<DType>::value) {
927 template <typename DType>
928 __device__ inline DType fix(const DType val) {
929 const auto f = floor(val);
930 const auto c = ceil(val);
931 return (f > 0 ? f : -f) < (c > 0 ? c : -c) ? f : c;
934 template <typename DType>
935 __device__ inline DType trunc(const DType val) {
936 if (type_util::has_double<DType>::value) {
937 return ::trunc((double)val);
938 } else if (type_util::is_integral<DType>::value) {
941 return ::truncf(val);
945 template <typename DType>
946 __device__ inline DType clip(const DType val, const float a_min, const float a_max) {
949 } else if (val < a_min) {
956 template <typename DType>
957 __device__ inline DType sign(const DType val) {
958 if (val < 0) return -1;
959 return val > 0 ? 1 : 0;
962 template <typename DType>
963 __device__ inline DType reciprocal(const DType val) {
967 DEFINE_UNARY_MATH_FUNC(abs, ::fabs, ::fabsf)
968 DEFINE_UNARY_MATH_FUNC(gamma, ::tgamma, ::tgammaf)
969 DEFINE_UNARY_MATH_FUNC(gammaln, ::lgamma, ::lgammaf)
970 DEFINE_UNARY_MATH_FUNC(erf, ::erf, ::erff)
971 DEFINE_UNARY_MATH_FUNC(erfinv, ::erfinv, ::erfinvf)
973 template <typename DType>
974 __device__ inline DType gelu_erf(const DType val) {
975 return 0.5f * val * (1.0f + op::erf(val / op::sqrt(2.0f)));
978 template <typename DType1, typename DType2>
979 __device__ inline DType1 smooth_l1(const DType1 val, const DType2 scalar) {
980 const auto bsq = scalar * scalar;
981 const auto ibsq = 1.0f / bsq;
983 return val - 0.5f * ibsq;
984 } else if (val < -ibsq) {
985 return -val - 0.5f * ibsq;
987 return 0.5f * val * val * bsq;
991 template <typename DType>
992 __device__ inline DType digamma(const DType val) {
993 if (type_util::has_double_or_integral<DType>::value) {
994 return special_functions::cephes::psi<double>(val);
996 return special_functions::cephes::psi<float>(val);
1000 template <typename DType>
1001 __device__ inline DType logical_not(const DType val) {
1002 return val != DType(0) ? DType(0) : DType(1);
1005 template <typename DType>
1006 __device__ inline bool_t np_logical_not(const DType val) {
1007 return !static_cast<bool>(val);
1010 template <typename DType>
1011 __device__ inline bool_t NonZero(const DType val) {
1015 #undef DEFINE_UNARY_MATH_FUNC
1017 template <typename DType>
1018 __device__ inline DType bitwise_not(const DType a) {
1019 if (type_util::is_same<DType, bool_t>::value) {
1022 return ~static_cast<int64>(a);
1035 #endif // MXNET_USE_CUDA
1037 #endif // MXNET_COMMON_CUDA_RTC_FORWARD_FUNCTIONS_INL_H_
namespace of mxnet
Definition: api_registry.h:33
const char function_definitions_util[]
Definition: forward_functions-inl.h:30
const char function_definitions_binary[]
Definition: forward_functions-inl.h:216
const char function_definitions_unary[]
Definition: forward_functions-inl.h:716