mxnet
src
common
cuda
rtc
half-inl.h
Go to the documentation of this file.
1
/*
2
* Licensed to the Apache Software Foundation (ASF) under one
3
* or more contributor license agreements. See the NOTICE file
4
* distributed with this work for additional information
5
* regarding copyright ownership. The ASF licenses this file
6
* to you under the Apache License, Version 2.0 (the
7
* "License"); you may not use this file except in compliance
8
* with the License. You may obtain a copy of the License at
9
*
10
* http://www.apache.org/licenses/LICENSE-2.0
11
*
12
* Unless required by applicable law or agreed to in writing,
13
* software distributed under the License is distributed on an
14
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
* KIND, either express or implied. See the License for the
16
* specific language governing permissions and limitations
17
* under the License.
18
*/
19
20
#ifndef MXNET_COMMON_CUDA_RTC_HALF_INL_H_
21
#define MXNET_COMMON_CUDA_RTC_HALF_INL_H_
22
23
#if MXNET_USE_CUDA
24
25
namespace mxnet {
namespace common {
namespace cuda {
namespace rtc {

// CUDA source text, stored as a raw string literal, that supplies a minimal
// half-precision (fp16) layer:
//   * `__half`        - 2-byte-aligned struct wrapping an `unsigned short`,
//                       zero-initialized by its default constructor.
//   * `__float2half`  - float -> half via the PTX `cvt.rn.f16.f32` instruction.
//   * `__half2float`  - half -> float via the PTX `cvt.f32.f16` instruction.
//   * `half`          - typedef alias for `__half`.
//   * `AccType<T>`    - trait whose `type` is T itself with pass-through
//                       `from`/`to`; the `half` specialization uses `float`
//                       as `type` and converts with the two helpers above.
//
// The text between `R"code(` and `)code"` is data: every character in it
// (including the literal backslash-n inside the asm strings, which only becomes
// a newline escape when the embedded source is itself compiled) is part of the
// value. No commentary is added inside the literal for that reason.
//
// NOTE(review): presumably this text is prepended to runtime-compiled (RTC)
// kernel source so kernels can use `half` without cuda_fp16.h - confirm
// against the users of this symbol.
const char fp16_support_string[] = R"code(
struct __align__(2) __half {
  __host__ __device__ __half() : __x(0) { }
  unsigned short __x;
};
/* Definitions of intrinsics */
__device__ inline __half __float2half(const float f) {
  __half val;
  asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(val.__x) : "f"(f));
  return val;
}
__device__ inline float __half2float(const __half h) {
  float val;
  asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(h.__x));
  return val;
}

typedef __half half;

template <typename DType>
struct AccType {
  using type = DType;

  __device__ static inline type from(const DType& val) {
    return val;
  }

  __device__ static inline DType to(type val) {
    return val;
  }

};

template<>
struct AccType<half> {
  using type = float;

  __device__ static inline type from(const half& val) {
    return __half2float(val);
  }

  __device__ static inline half to(type val) {
    return __float2half(val);
  }
};
)code";

}  // namespace rtc
}  // namespace cuda
}  // namespace common
}  // namespace mxnet
81
82
#endif // MXNET_USE_CUDA
83
84
#endif // MXNET_COMMON_CUDA_RTC_HALF_INL_H_
mxnet
namespace of mxnet
Definition:
api_registry.h:33
mxnet::common::cuda::rtc::fp16_support_string
const char fp16_support_string[]
Definition:
half-inl.h:30
Generated on Thu Jan 5 2023 03:47:40 for mxnet by
1.8.17