// Generated from "/code/pytorch/third_party/nvfuser/runtime/bf16_support.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

// CUDA device source for bfloat16 support, embedded as a string constant.
//
// NOTE(review): the header above marks this file as generated, so lasting
// changes presumably belong in the referenced bf16_support.cu -- confirm
// before hand-editing. Everything between R"( and )" is the text of the
// constant; the documentation is kept out here so that text is not altered.
//
// The embedded source declares:
//   * __NVFUSER_BFLOAT_TO_US(var) / __NVFUSER_BFLOAT_TO_CUS(var):
//       view the storage of `var` as a (const) unsigned short lvalue via
//       reinterpret_cast. `var` must be an lvalue because its address is
//       taken.
//   * struct __bfloat:
//       2-byte-aligned wrapper around one `unsigned short __x` (protected).
//       Has a defaulted default constructor and a __device__ converting
//       constructor from float that delegates to __float2bfloat.
//   * __device__ __bfloat __float2bfloat(const float):
//       converts with the inline PTX instruction `cvt.rn.bf16.f32`, writing
//       the result into the 16-bit storage of a __bfloat ("h" constraint).
//       NOTE(review): `.rn` should be round-to-nearest-even and the
//       instruction presumably needs sm_80 or newer -- confirm against the
//       PTX ISA documentation.
//   * __device__ float __bfloat2float(const __bfloat):
//       builds a 32-bit value with `mov.b32 %0, {0,%1}` from a zero 16-bit
//       element and the 16 stored bits, and returns it as float.
//       NOTE(review): PTX vector packing should place the first element in
//       the low half, i.e. the stored bits become the upper 16 bits of the
//       float -- confirm against the PTX ISA documentation.
constexpr const char* bf16_support_cu = R"(

#define __NVFUSER_BFLOAT_TO_US(var) *(reinterpret_cast<unsigned short*>(&(var)))
#define __NVFUSER_BFLOAT_TO_CUS(var) \
  *(reinterpret_cast<const unsigned short*>(&(var)))

struct __bfloat;
__device__ __bfloat __float2bfloat(const float);

struct __align__(2) __bfloat {
  __bfloat() = default;

  __device__ __bfloat(const float f) {
    __x = __float2bfloat(f).__x;
  }

 protected:
  unsigned short __x;
};

__device__ __bfloat __float2bfloat(const float f) {
  __bfloat val;
  asm("{ cvt.rn.bf16.f32 %0, %1;}\n"
      : "=h"(__NVFUSER_BFLOAT_TO_US(val))
      : "f"(f));
  return val;
}

__device__ float __bfloat2float(const __bfloat h) {
  float val;
  asm("{ mov.b32 %0, {0,%1};}\n"
      : "=f"(val)
      : "h"(__NVFUSER_BFLOAT_TO_CUS(h)));
  return val;
}
)";

} // namespace nvfuser_resources