mxnet
half-inl.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_COMMON_CUDA_RTC_HALF_INL_H_
21 #define MXNET_COMMON_CUDA_RTC_HALF_INL_H_
22 
23 #if MXNET_USE_CUDA
24 
25 namespace mxnet {
26 namespace common {
27 namespace cuda {
28 namespace rtc {
29 
// fp16_support_string: CUDA device source kept as text in a raw string literal.
//
// NOTE(review): presumably this text is prepended to kernels that are compiled
// at runtime (the enclosing namespace is `rtc`) -- confirm against the users of
// this symbol; no consumer is visible in this header.
//
// The embedded source declares, in this order:
//   * struct __half   - 2-byte aligned wrapper around one `unsigned short __x`;
//                       the default constructor zero-initializes `__x`.
//   * __float2half()  - float -> __half through the inline PTX instruction
//                       `cvt.rn.f16.f32`; the result lands in `val.__x`.
//   * __half2float()  - __half -> float through the inline PTX instruction
//                       `cvt.f32.f16`, reading `h.__x`.
//   * typedef half    - alias for __half.
//   * AccType<DType>  - primary template: `type` is DType itself and
//                       from()/to() return their argument unchanged.
//   * AccType<half>   - specialization: `type` is float; from() calls
//                       __half2float() and to() calls __float2half().
//
// Everything between the raw-string delimiters is string data, not host code:
// changing any character there (including the embedded comment and whitespace)
// changes the bytes stored in this array.
30 const char fp16_support_string[] = R"code(
31 struct __align__(2) __half {
32  __host__ __device__ __half() : __x(0) { }
33  unsigned short __x;
34 };
35 /* Definitions of intrinsics */
36 __device__ inline __half __float2half(const float f) {
37  __half val;
38  asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(val.__x) : "f"(f));
39  return val;
40 }
41 __device__ inline float __half2float(const __half h) {
42  float val;
43  asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(h.__x));
44  return val;
45 }
46 
47 typedef __half half;
48 
49 template <typename DType>
50 struct AccType {
51  using type = DType;
52 
53  __device__ static inline type from(const DType& val) {
54  return val;
55  }
56 
57  __device__ static inline DType to(type val) {
58  return val;
59  }
60 
61 };
62 
63 template<>
64 struct AccType<half> {
65  using type = float;
66 
67  __device__ static inline type from(const half& val) {
68  return __half2float(val);
69  }
70 
71  __device__ static inline half to(type val) {
72  return __float2half(val);
73  }
74 };
75 )code";
76 
77 } // namespace rtc
78 } // namespace cuda
79 } // namespace common
80 } // namespace mxnet
81 
82 #endif // MXNET_USE_CUDA
83 
84 #endif // MXNET_COMMON_CUDA_RTC_HALF_INL_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::common::cuda::rtc::fp16_support_string
const char fp16_support_string[]
Definition: half-inl.h:30