1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file cuda_half_t.h |
22 | * \brief half_t (fp16) definition for cuda codegen. |
23 | */ |
24 | #ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_ |
25 | #define TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_ |
26 | |
// CUDA source literal that defines a software `half` (IEEE fp16) type for
// generated kernels which cannot rely on cuda_fp16.h.  The class stores the
// raw 16-bit pattern (`half_`) and converts to/from float in software:
// float2half() handles the zero / denormal / normal / inf / nan ranges with
// explicit branches and round-to-nearest-even adjustments, while half2float()
// reconstructs the float bit pattern with branchless mask/select arithmetic.
// Arithmetic, comparison, and compound-assignment operators are generated by
// the TVM_HALF_OPERATOR / TVM_HALF_ASSIGNOP macros and all round-trip through
// float.  The typedefs at the top are emitted because the generated
// translation unit includes no standard headers.
// NOTE(review): the stray '\' line continuations after the operator float()
// definitions look like leftovers from an earlier macro; they only splice
// lines in the emitted source and do not change its meaning.
static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;

#define TVM_FORCE_INLINE inline __attribute__((always_inline))
#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))
#define TVM_HALF_OPERATOR(RTYPE, OP) \
TVM_XINLINE RTYPE operator OP (half a, half b) { \
return RTYPE(float(a) OP float(b)); \
} \
template<typename T> \
TVM_XINLINE RTYPE operator OP (half a, T b) { \
return RTYPE(float(a) OP float(b)); \
} \
template<typename T> \
TVM_XINLINE RTYPE operator OP (T a, half b) { \
return RTYPE(float(a) OP float(b)); \
}

#define TVM_HALF_ASSIGNOP(AOP, OP) \
template<typename T> \
TVM_XINLINE half operator AOP (const T& a) { \
return *this = half(float(*this) OP float(a)); \
} \
template<typename T> \
TVM_XINLINE half operator AOP (const volatile T& a) volatile { \
return *this = half(float(*this) OP float(a)); \
}

class TVM_ALIGNED(2) half {
public:
uint16_t half_;

static TVM_XINLINE half Binary(uint16_t value) {
half res;
res.half_ = value;
return res;
}

TVM_XINLINE half() {}

TVM_XINLINE half(const float& value) { constructor(value); }
TVM_XINLINE explicit half(const double& value) { constructor(value); }
TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const long long& value) { constructor(value); }
TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }

TVM_XINLINE operator float() const { \
return float(half2float(half_)); \
} \
TVM_XINLINE operator float() const volatile { \
return float(half2float(half_)); \
}


TVM_HALF_ASSIGNOP(+=, +)
TVM_HALF_ASSIGNOP(-=, -)
TVM_HALF_ASSIGNOP(*=, *)
TVM_HALF_ASSIGNOP(/=, /)

TVM_XINLINE half operator+() {
return *this;
}

TVM_XINLINE half operator-() {
return half(-float(*this));
}

TVM_XINLINE half operator=(const half& a) {
half_ = a.half_;
return a;
}

template<typename T>
TVM_XINLINE half operator=(const T& a) {
return *this = half(a);
}

TVM_XINLINE half operator=(const half& a) volatile {
half_ = a.half_;
return a;
}

template<typename T>
TVM_XINLINE half operator=(const T& a) volatile {
return *this = half(a);
}

private:
union Bits {
float f;
int32_t si;
uint32_t ui;
};

static int const fp16FractionBits = 10;
static int const fp32FractionBits = 23;
static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff
static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000
static int const shift = fp32FractionBits - fp16FractionBits; // == 13
static int const shiftSign = 16;
static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)

static int32_t const infN = 0x7F800000; // flt32 infinity
static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift
static int32_t const minN = 0x38800000; // min flt16 normal as a flt32
static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16
static int32_t const signN = 0x80000000; // flt32 sign bit

static int32_t const infC = infN >> shift;
static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32
static int32_t const maxC = maxN >> shift;
static int32_t const minC = minN >> shift;
static int32_t const signC = signN >> shiftSign; // flt16 sign bit

static int32_t const mulN = 0x52000000; // (1 << 23) / minN
static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))

static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted
static int32_t const norC = 0x00400; // min flt32 normal down shifted

static int32_t const maxD = infC - maxC - 1;
static int32_t const minD = minC - subC - 1;

TVM_XINLINE uint16_t float2half(const float& value) const {
Bits v;
v.f = value;
uint32_t sign = v.si & signN; // grab sign bit
v.si ^= sign; // clear sign bit from v
sign >>= shiftSign; // logical shift sign to fp16 position

if (v.si <= maxZ) {
// Handle eventual zeros here to ensure
// vshift will not exceed 32 below.
v.ui = 0;
} else if (v.si < minN) {
// Handle denorms
uint32_t exp32 = v.ui >> fp32FractionBits;
int32_t exp16 = exp32 - expAdjust;
// If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
// Smaller (so negative) exp16 values should result in greater right shifts.
uint32_t vshift = 1 - exp16;
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
v.ui = significand >> vshift;
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
} else if (v.si <= maxN) {
// Handle norms
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
v.ui -= expAdjust << fp32FractionBits;
} else if (v.si <= infN) {
v.si = infN;
} else if (v.si < nanN) {
v.si = nanN;
}

v.ui >>= shift;
return sign | (v.ui & 0x7fff);
}

// Same as above routine, except for addition of volatile keyword
TVM_XINLINE uint16_t float2half(
const volatile float& value) const volatile {
Bits v;
v.f = value;
uint32_t sign = v.si & signN; // grab sign bit
v.si ^= sign; // clear sign bit from v
sign >>= shiftSign; // logical shift sign to fp16 position

if (v.si <= maxZ) {
// Handle eventual zeros here to ensure
// vshift will not exceed 32 below.
v.ui = 0;
} else if (v.si < minN) {
// Handle denorms
uint32_t exp32 = v.ui >> fp32FractionBits;
int32_t exp16 = exp32 - expAdjust;
// If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
// Smaller (so negative) exp16 values should result in greater right shifts.
uint32_t vshift = 1 - exp16;
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
v.ui = significand >> vshift;
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
} else if (v.si <= maxN) {
// Handle norms
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
v.ui -= expAdjust << fp32FractionBits;
} else if (v.si <= infN) {
v.si = infN;
} else if (v.si < nanN) {
v.si = nanN;
}

v.ui >>= shift;
return sign | (v.ui & 0x7fff);
}

TVM_XINLINE float half2float(const uint16_t& value) const {
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
v.si ^= sign;
sign <<= shiftSign;
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
Bits s;
s.si = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
return v.f;
}

TVM_XINLINE float half2float(
const volatile uint16_t& value) const volatile {
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
v.si ^= sign;
sign <<= shiftSign;
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
Bits s;
s.si = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
return v.f;
}

template<typename T>
TVM_XINLINE void constructor(const T& value) {
half_ = float2half(float(value));
}
};

TVM_HALF_OPERATOR(half, +)
TVM_HALF_OPERATOR(half, -)
TVM_HALF_OPERATOR(half, *)
TVM_HALF_OPERATOR(half, /)
TVM_HALF_OPERATOR(bool, >)
TVM_HALF_OPERATOR(bool, <)
TVM_HALF_OPERATOR(bool, >=)
TVM_HALF_OPERATOR(bool, <=)

TVM_XINLINE half __float2half_rn(const float a) {
return half(a);
}
)" ;
287 | |
// CUDA source literal with fp16 helpers for generated kernels:
// __pack_half2() packs two half values into one 32-bit word (y in the high
// 16 bits, x in the low 16 bits), and a set of float-fallback wrappers
// supplies half math functions that cuda_fp16.h does not provide.  The
// fallbacks are guarded to __CUDA_ARCH__ >= 530, matching the guard below.
// Fix vs. previous revision: herf now forwards to the single-precision
// `erff` instead of the double-precision `erf`, keeping the emitted code in
// float like every other fallback (powf/tanhf/tanf/atanf) and avoiding a
// silent float->double->float round trip on the device.
static constexpr const char* _cuda_half_util = R"(
// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}

// Some fp16 math functions are not supported in cuda_fp16.h,
// so we define them here to make sure the generated CUDA code
// is valid.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \
float tmp_x = __half2float(x); \
float tmp_y = __half2float(y); \
float result = FP32_MATH_NAME(tmp_x, tmp_y); \
return __float2half(result); \
}

#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ half HALF_MATH_NAME(half x) { \
float tmp_x = __half2float(x); \
float result = FP32_MATH_NAME(tmp_x); \
return __float2half(result); \
}

CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erff)

#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY

#endif
)" ;
327 | |
// CUDA source literal with bfloat16 helpers for generated kernels:
// __pack_nv_bfloat162() packs two nv_bfloat16 values into one 32-bit word
// (y in the high 16 bits, x in the low 16 bits), and float-fallback wrappers
// supply bfloat16 math functions missing from cuda_bf16.h.
// Fixes vs. previous revision: herf now forwards to the single-precision
// `erff` instead of the double-precision `erf` (matching powf/tanhf/tanf/
// atanf and avoiding a float->double->float round trip), and the emitted
// comment names the real header cuda_bf16.h rather than "cuda_bfp16.h".
static constexpr const char* _cuda_bfloat16_util = R"(
// Pack two bfloat16 values.
static inline __device__ __host__ unsigned
__pack_nv_bfloat162(const nv_bfloat16 x, const nv_bfloat16 y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}

// Some bfloat16 math functions are not supported in cuda_bf16.h,
// so we define them here to make sure the generated CUDA code
// is valid.
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x, nv_bfloat16 y) { \
float tmp_x = __bfloat162float(x); \
float tmp_y = __bfloat162float(y); \
float result = FP32_MATH_NAME(tmp_x, tmp_y); \
return __float2bfloat16(result); \
}

#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x) { \
float tmp_x = __bfloat162float(x); \
float result = FP32_MATH_NAME(tmp_x); \
return __float2bfloat16(result); \
}

CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erff)

#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY
)" ;
364 | |
// CUDA source literal providing a pre-Volta compatibility shim: for targets
// with __CUDA_ARCH__ < 700 it maps the __shfl_sync / __shfl_down_sync /
// __shfl_up_sync warp shuffle intrinsics onto the legacy mask-less
// __shfl / __shfl_down / __shfl_up forms, dropping the mask argument.
// On SM70+ the #if is false and the toolkit's real *_sync intrinsics are
// used unchanged, so generated code can always call the *_sync names.
static constexpr const char* _cuda_warp_intrinsic_util = R"(
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
#define __shfl_sync(mask, var, lane, width) \
__shfl((var), (lane), (width))

#define __shfl_down_sync(mask, var, offset, width) \
__shfl_down((var), (offset), (width))

#define __shfl_up_sync(mask, var, offset, width) \
__shfl_up((var), (offset), (width))
#endif

)" ;
378 | |
379 | #endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_ |
380 | |