1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file cuda_half_t.h
 * \brief half_t (fp16) definition plus fp16/bf16 math and warp-shuffle helper
 *  snippets, stored as CUDA source string literals for cuda codegen.
23 */
24#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
25#define TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
26
/*!
 * \brief CUDA source text defining a software fp16 type named `half`.
 *
 * The snippet declares the fixed-width integer typedefs it relies on, the
 * TVM_XINLINE / TVM_ALIGNED helper macros, and a 2-byte-aligned `half` class
 * that stores raw fp16 bits in `half_`. Conversions go through float:
 * float2half rounds to nearest-even (with subnormal, overflow-to-inf and NaN
 * handling), and half2float is a branchless bit-manipulation decode.
 * Arithmetic (+ - * /) and comparison (> < >= <=) operators are generated by
 * TVM_HALF_OPERATOR and evaluate in float. The snippet ends with a
 * `__float2half_rn` helper.
 * NOTE(review): presumably prepended to generated kernels when cuda_fp16.h is
 * not used; the emitting call site is not visible here.
 */
static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;

#define TVM_FORCE_INLINE inline __attribute__((always_inline))
#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))
#define TVM_HALF_OPERATOR(RTYPE, OP) \
  TVM_XINLINE RTYPE operator OP (half a, half b) { \
    return RTYPE(float(a) OP float(b)); \
  } \
  template<typename T> \
  TVM_XINLINE RTYPE operator OP (half a, T b) { \
    return RTYPE(float(a) OP float(b)); \
  } \
  template<typename T> \
  TVM_XINLINE RTYPE operator OP (T a, half b) { \
    return RTYPE(float(a) OP float(b)); \
  }

#define TVM_HALF_ASSIGNOP(AOP, OP) \
  template<typename T> \
  TVM_XINLINE half operator AOP (const T& a) { \
    return *this = half(float(*this) OP float(a)); \
  } \
  template<typename T> \
  TVM_XINLINE half operator AOP (const volatile T& a) volatile { \
    return *this = half(float(*this) OP float(a)); \
  }

class TVM_ALIGNED(2) half {
 public:
  uint16_t half_;

  static TVM_XINLINE half Binary(uint16_t value) {
    half res;
    res.half_ = value;
    return res;
  }

  TVM_XINLINE half() {}

  TVM_XINLINE half(const float& value) { constructor(value); }
  TVM_XINLINE explicit half(const double& value) { constructor(value); }
  TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const long long& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }

  TVM_XINLINE operator float() const {
    return float(half2float(half_));
  }
  TVM_XINLINE operator float() const volatile {
    return float(half2float(half_));
  }


  TVM_HALF_ASSIGNOP(+=, +)
  TVM_HALF_ASSIGNOP(-=, -)
  TVM_HALF_ASSIGNOP(*=, *)
  TVM_HALF_ASSIGNOP(/=, /)

  // Unary +/- are const so they also apply to const half operands
  // (otherwise -x on a const half silently decays to a float negation).
  TVM_XINLINE half operator+() const {
    return *this;
  }

  TVM_XINLINE half operator-() const {
    return half(-float(*this));
  }

  TVM_XINLINE half operator=(const half& a) {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  TVM_XINLINE half operator=(const T& a) {
    return *this = half(a);
  }

  TVM_XINLINE half operator=(const half& a) volatile {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  TVM_XINLINE half operator=(const T& a) volatile {
    return *this = half(a);
  }

 private:
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static int const fp16FractionBits = 10;
  static int const fp32FractionBits = 23;
  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);   // == 0x7fffff
  static int32_t const fp32HiddenBit = 1 << fp32FractionBits;   // == 0x800000
  static int const shift = fp32FractionBits - fp16FractionBits;   // == 13
  static int const shiftSign = 16;
  static int32_t const expAdjust = 127 - 15;   // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)

  static int32_t const infN = 0x7F800000;   // flt32 infinity
  static int32_t const maxN = 0x477FFFFF;   // max flt32 that's a flt16 normal after >> by shift
  static int32_t const minN = 0x38800000;   // min flt16 normal as a flt32
  static int32_t const maxZ = 0x33000000;   // max fp32 number that's still rounded to zero in fp16
  static int32_t const signN = 0x80000000;  // flt32 sign bit

  static int32_t const infC = infN >> shift;
  static int32_t const nanN = (infC + 1) << shift;   // minimum flt16 nan as a flt32
  static int32_t const maxC = maxN >> shift;
  static int32_t const minC = minN >> shift;
  static int32_t const signC = signN >> shiftSign;  // flt16 sign bit

  static int32_t const mulN = 0x52000000;  // (1 << 23) / minN
  static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift))

  static int32_t const subC = 0x003FF;  // max flt32 subnormal down shifted
  static int32_t const norC = 0x00400;  // min flt32 normal down shifted

  static int32_t const maxD = infC - maxC - 1;
  static int32_t const minD = minC - subC - 1;

  TVM_XINLINE uint16_t float2half(const float& value) const {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure
      // vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  // Same as above routine, except for addition of volatile keyword
  TVM_XINLINE uint16_t float2half(
    const volatile float& value) const volatile {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure
      // vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  TVM_XINLINE float half2float(const uint16_t& value) const {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  TVM_XINLINE float half2float(
    const volatile uint16_t& value) const volatile {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  template<typename T>
  TVM_XINLINE void constructor(const T& value) {
    half_ = float2half(float(value));
  }
};

TVM_HALF_OPERATOR(half, +)
TVM_HALF_OPERATOR(half, -)
TVM_HALF_OPERATOR(half, *)
TVM_HALF_OPERATOR(half, /)
TVM_HALF_OPERATOR(bool, >)
TVM_HALF_OPERATOR(bool, <)
TVM_HALF_OPERATOR(bool, >=)
TVM_HALF_OPERATOR(bool, <=)

TVM_XINLINE half __float2half_rn(const float a) {
  return half(a);
}
)";
287
/*!
 * \brief CUDA source text with fp16 helpers for generated code.
 *
 * - `__pack_half2(x, y)` packs the 16-bit patterns of two half values into
 *   one unsigned word, with x in the low 16 bits and y in the high 16 bits.
 * - When compiling device code for __CUDA_ARCH__ >= 530, it defines half
 *   fallbacks hpow, htanh, htan, hatan and herf. Each one converts to float
 *   with __half2float, calls the single-precision C math function (the `f`
 *   suffixed name), and converts back with __float2half.
 * The generator macros are #undef'd at the end of the snippet.
 */
static constexpr const char* _cuda_half_util = R"(
// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}

// Some fp16 math functions are not supported in cuda_fp16.h,
// so we define them here to make sure the generated CUDA code
// is valid.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \
  float tmp_x = __half2float(x); \
  float tmp_y = __half2float(y); \
  float result = FP32_MATH_NAME(tmp_x, tmp_y); \
  return __float2half(result); \
}

#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ half HALF_MATH_NAME(half x) { \
  float tmp_x = __half2float(x); \
  float result = FP32_MATH_NAME(tmp_x); \
  return __float2half(result); \
}

CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erff)

#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY

#endif
)";
327
/*!
 * \brief CUDA source text with bfloat16 helpers for generated code.
 *
 * - `__pack_nv_bfloat162(x, y)` packs the 16-bit patterns of two nv_bfloat16
 *   values into one unsigned word, with x in the low 16 bits and y in the
 *   high 16 bits.
 * - It defines nv_bfloat16 fallbacks hpow, htanh, htan, hatan and herf. Each
 *   one converts to float with __bfloat162float, calls the single-precision
 *   C math function (the `f` suffixed name), and converts back with
 *   __float2bfloat16.
 * These definitions are not wrapped in a __CUDA_ARCH__ guard. The generator
 * macros reuse the CUDA_UNSUPPORTED_HALF_MATH_* names and are #undef'd at the
 * end of the snippet.
 */
static constexpr const char* _cuda_bfloat16_util = R"(
// Pack two bfloat16 values.
static inline __device__ __host__ unsigned
__pack_nv_bfloat162(const nv_bfloat16 x, const nv_bfloat16 y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}

// Some bfp16 math functions are not supported in cuda_bf16.h,
// so we define them here to make sure the generated CUDA code
// is valid.
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x, nv_bfloat16 y) { \
  float tmp_x = __bfloat162float(x); \
  float tmp_y = __bfloat162float(y); \
  float result = FP32_MATH_NAME(tmp_x, tmp_y); \
  return __float2bfloat16(result); \
}

#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x) { \
  float tmp_x = __bfloat162float(x); \
  float result = FP32_MATH_NAME(tmp_x); \
  return __float2bfloat16(result); \
}

CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erff)

#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY
)";
364
/*!
 * \brief CUDA source text with warp-shuffle compatibility macros.
 *
 * When compiling device code for __CUDA_ARCH__ < 700, the snippet defines
 * __shfl_sync, __shfl_down_sync and __shfl_up_sync as macros that forward to
 * the mask-less __shfl, __shfl_down and __shfl_up intrinsics. The `mask`
 * argument is discarded, and all four arguments (including `width`) must be
 * supplied at every call site because the macros take exactly four parameters.
 * On __CUDA_ARCH__ >= 700, or on the host pass, nothing is defined.
 * NOTE(review): presumably a shim for older toolchains/architectures; the
 * mask-less intrinsics are legacy (unavailable on Volta+), so confirm it is
 * still needed.
 */
static constexpr const char* _cuda_warp_intrinsic_util = R"(
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
#define __shfl_sync(mask, var, lane, width) \
  __shfl((var), (lane), (width))

#define __shfl_down_sync(mask, var, offset, width) \
  __shfl_down((var), (offset), (width))

#define __shfl_up_sync(mask, var, offset, width) \
  __shfl_up((var), (offset), (width))
#endif

)";
378
379#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
380