1 | /** |
2 | * Copyright (c) Facebook, Inc. and its affiliates. |
3 | */ |
4 | |
5 | #pragma once |
6 | |
#include <cstring>
#include <iostream>
8 | |
9 | #ifdef __CUDA_ARCH__ |
10 | #include <cuda.h> |
11 | // Disable strict aliasing errors for CUDA 9. |
12 | #if CUDA_VERSION >= 9000 |
13 | #ifdef __GNUC__ |
14 | #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6) |
15 | #pragma GCC diagnostic push |
16 | #endif |
17 | #pragma GCC diagnostic ignored "-Wstrict-aliasing" |
18 | #endif // __GNUC__ |
19 | #endif // CUDA_VERSION >= 9000 |
20 | #include <cuda_fp16.h> |
21 | #if CUDA_VERSION >= 9000 |
22 | #ifdef __GNUC__ |
23 | #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6) |
24 | #pragma GCC diagnostic pop |
25 | #endif |
26 | #endif // __GNUC__ |
27 | #endif // CUDA_VERSION >= 9000 |
28 | #endif |
29 | |
30 | #include "gloo/common/common.h" |
31 | |
32 | #ifdef _WIN32 |
33 | #include <BaseTsd.h> |
34 | typedef SSIZE_T ssize_t; |
35 | #endif |
36 | |
37 | namespace gloo { |
38 | |
39 | // Unlike old style collectives that are class instances that hold |
40 | // some state, the new style collectives do not need initialization |
41 | // before they can run. Instead of asking the context for a series of |
42 | // slots and storing them for later use and reuse, the new style |
43 | // collectives take a slot (or tag) argument that allows for |
44 | // concurrent execution of multiple collectives on the same context. |
45 | // |
46 | // This tag is what determines the slot numbers for the send and recv |
47 | // operations that the collectives end up executing. A single |
48 | // collective may have many send and recv operations running in |
49 | // parallel, so instead of using the specified tag verbatim, we use it |
50 | // as a prefix. Also, to avoid conflicts between collectives with the |
51 | // same tag, we have another tag prefix per collective type. Out of |
52 | // the 64 bits we can use for a slot, we use 8 of them to identify a |
53 | // collective, 32 to identify the collective tag, another 8 for use by |
54 | // the collective operation itself (allowing for 256 independent send |
55 | // and recv operations against the same point to point pair), and |
56 | // leave 16 bits unused. |
57 | // |
58 | // Below, you find constexprs for the prefix per collective type, as |
59 | // well as a way to compute slots when executing a collective. The |
60 | // slot class below captures both a prefix and a delta on that prefix |
61 | // to support addition with bounds checking. It is usable as an |
62 | // uint64_t, but one that cannot overflow beyond the bits allocated |
63 | // for use within a collective. |
64 | // |
65 | |
// Per-collective-type tag prefixes (the 8 collective-identifier bits of
// the 64-bit slot layout described above). Each collective type gets a
// distinct value so collectives of different types never collide on a
// slot even when called with the same user tag.
constexpr uint8_t kGatherSlotPrefix = 0x01;
constexpr uint8_t kAllgatherSlotPrefix = 0x02;
constexpr uint8_t kReduceSlotPrefix = 0x03;
constexpr uint8_t kAllreduceSlotPrefix = 0x04;
constexpr uint8_t kScatterSlotPrefix = 0x05;
constexpr uint8_t kBroadcastSlotPrefix = 0x06;
constexpr uint8_t kBarrierSlotPrefix = 0x07;
constexpr uint8_t kAlltoallSlotPrefix = 0x08;
74 | |
// A slot number composed of a fixed base (derived from a collective-type
// prefix and a 32-bit tag; see the layout comment above) plus a delta
// used by the collective operation itself. It converts implicitly to
// uint64_t and supports addition with bounds checking so a collective
// cannot overflow beyond the bits allocated for its own use.
class Slot {
 public:
  // Builds the base slot for a collective identified by `prefix` (one of
  // the k*SlotPrefix constants above) and the caller-supplied `tag`.
  // Defined out of line.
  static Slot build(uint8_t prefix, uint32_t tag);

  // Usable wherever a plain 64-bit slot number is expected.
  operator uint64_t() const {
    return base_ + delta_;
  }

  // Returns a new Slot with `i` added to the delta (bounds-checked;
  // defined out of line).
  Slot operator+(uint8_t i) const;

 protected:
  // Instances are created via build() / operator+ only.
  explicit Slot(uint64_t base, uint64_t delta) : base_(base), delta_(delta) {}

  const uint64_t base_;
  const uint64_t delta_;
};
91 | |
// Forward declaration of float16 and the CPU-side conversion helpers
// (both defined near the bottom of this file) so the struct's members
// can call them.
struct float16;
float16 cpu_float2half_rn(float f);
float cpu_half2float(float16 h);
95 | |
96 | struct alignas(2) float16 { |
97 | uint16_t x; |
98 | |
99 | float16() : x(0) {} |
100 | |
101 | float16(const float16 &) = default; |
102 | |
103 | explicit float16(int val) { |
104 | float16 res = cpu_float2half_rn(static_cast<float>(val)); |
105 | x = res.x; |
106 | } |
107 | |
108 | explicit float16(unsigned long val) { |
109 | float16 res = cpu_float2half_rn(static_cast<float>(val)); |
110 | x = res.x; |
111 | } |
112 | |
113 | explicit float16(unsigned long long val) { |
114 | float16 res = cpu_float2half_rn(static_cast<float>(val)); |
115 | x = res.x; |
116 | } |
117 | |
118 | explicit float16(double val) { |
119 | float16 res = cpu_float2half_rn(static_cast<float>(val)); |
120 | x = res.x; |
121 | } |
122 | |
123 | float16& operator=(const int& rhs) { |
124 | float16 res = cpu_float2half_rn(static_cast<float>(rhs)); |
125 | x = res.x; |
126 | return *this; |
127 | } |
128 | |
129 | float16& operator=(const float16& rhs) { |
130 | if (rhs != *this) { |
131 | x = rhs.x; |
132 | } |
133 | return *this; |
134 | } |
135 | |
136 | bool operator==(const float16& rhs) const { |
137 | return x == rhs.x; |
138 | } |
139 | |
140 | bool operator!=(const float16& rhs) const { |
141 | return !(*this == rhs.x); |
142 | } |
143 | |
144 | bool operator==(const int& rhs) const { |
145 | float16 res = cpu_float2half_rn(static_cast<float>(rhs)); |
146 | return x == res.x; |
147 | } |
148 | |
149 | bool operator==(const unsigned long& rhs) const { |
150 | float16 res = cpu_float2half_rn(static_cast<float>(rhs)); |
151 | return x == res.x; |
152 | } |
153 | |
154 | bool operator==(const double& rhs) const { |
155 | float16 res = cpu_float2half_rn(static_cast<float>(rhs)); |
156 | return x == res.x; |
157 | } |
158 | #ifdef __CUDA_ARCH__ |
159 | float16(half h) { |
160 | #if CUDA_VERSION >= 9000 |
161 | x = reinterpret_cast<__half_raw*>(&h)->x; |
162 | #else |
163 | x = h.x; |
164 | #endif // CUDA_VERSION |
165 | } |
166 | |
167 | // half and float16 are supposed to have identical representation so implicit |
168 | // conversion should be fine |
169 | /* implicit */ |
170 | operator half() const { |
171 | #if CUDA_VERSION >= 9000 |
172 | __half_raw hr; |
173 | hr.x = this->x; |
174 | return half(hr); |
175 | #else |
176 | return (half) * this; |
177 | #endif // CUDA_VERSION |
178 | } |
179 | #endif // __CUDA_ARCH |
180 | |
181 | float16& operator+=(const float16& rhs) { |
182 | float r = cpu_half2float(*this) + cpu_half2float(rhs); |
183 | *this = cpu_float2half_rn(r); |
184 | return *this; |
185 | } |
186 | |
187 | float16& operator-=(const float16& rhs) { |
188 | float r = cpu_half2float(*this) - cpu_half2float(rhs); |
189 | *this = cpu_float2half_rn(r); |
190 | return *this; |
191 | } |
192 | |
193 | float16& operator*=(const float16& rhs) { |
194 | float r = cpu_half2float(*this) * cpu_half2float(rhs); |
195 | *this = cpu_float2half_rn(r); |
196 | return *this; |
197 | } |
198 | |
199 | float16& operator/=(const float16& rhs) { |
200 | float r = cpu_half2float(*this) / cpu_half2float(rhs); |
201 | *this = cpu_float2half_rn(r); |
202 | return *this; |
203 | } |
204 | }; |
205 | |
206 | inline std::ostream& operator<<(std::ostream& stream, const float16& val) { |
207 | stream << cpu_half2float(val); |
208 | return stream; |
209 | } |
210 | |
211 | inline float16 operator+(const float16& lhs, const float16& rhs) { |
212 | float16 result = lhs; |
213 | result += rhs; |
214 | return result; |
215 | } |
216 | |
217 | inline float16 operator-(const float16& lhs, const float16& rhs) { |
218 | float16 result = lhs; |
219 | result -= rhs; |
220 | return result; |
221 | } |
222 | |
223 | inline float16 operator*(const float16& lhs, const float16& rhs) { |
224 | float16 result = lhs; |
225 | result *= rhs; |
226 | return result; |
227 | } |
228 | |
229 | inline float16 operator/(const float16& lhs, const float16& rhs) { |
230 | float16 result = lhs; |
231 | result /= rhs; |
232 | return result; |
233 | } |
234 | |
235 | inline bool operator<(const float16& lhs, const float16& rhs) { |
236 | return cpu_half2float(lhs) < cpu_half2float(rhs); |
237 | } |
238 | |
239 | inline bool operator<=(const float16& lhs, const float16& rhs) { |
240 | return cpu_half2float(lhs) <= cpu_half2float(rhs); |
241 | } |
242 | |
243 | inline bool operator>(const float16& lhs, const float16& rhs) { |
244 | return cpu_half2float(lhs) > cpu_half2float(rhs); |
245 | } |
246 | |
247 | inline bool operator>=(const float16& lhs, const float16& rhs) { |
248 | return cpu_half2float(lhs) >= cpu_half2float(rhs); |
249 | } |
250 | |
251 | inline float16 cpu_float2half_rn(float f) { |
252 | float16 ret; |
253 | |
254 | static_assert( |
255 | sizeof(unsigned int) == sizeof(float), |
256 | "Programming error sizeof(unsigned int) != sizeof(float)" ); |
257 | |
258 | unsigned* xp = reinterpret_cast<unsigned int*>(&f); |
259 | unsigned x = *xp; |
260 | unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; |
261 | unsigned sign, exponent, mantissa; |
262 | |
263 | // Get rid of +NaN/-NaN case first. |
264 | if (u > 0x7f800000) { |
265 | ret.x = 0x7fffU; |
266 | return ret; |
267 | } |
268 | |
269 | sign = ((x >> 16) & 0x8000); |
270 | |
271 | // Get rid of +Inf/-Inf, +0/-0. |
272 | if (u > 0x477fefff) { |
273 | ret.x = sign | 0x7c00U; |
274 | return ret; |
275 | } |
276 | if (u < 0x33000001) { |
277 | ret.x = (sign | 0x0000); |
278 | return ret; |
279 | } |
280 | |
281 | exponent = ((u >> 23) & 0xff); |
282 | mantissa = (u & 0x7fffff); |
283 | |
284 | if (exponent > 0x70) { |
285 | shift = 13; |
286 | exponent -= 0x70; |
287 | } else { |
288 | shift = 0x7e - exponent; |
289 | exponent = 0; |
290 | mantissa |= 0x800000; |
291 | } |
292 | lsb = (1 << shift); |
293 | lsb_s1 = (lsb >> 1); |
294 | lsb_m1 = (lsb - 1); |
295 | |
296 | // Round to nearest even. |
297 | remainder = (mantissa & lsb_m1); |
298 | mantissa >>= shift; |
299 | if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { |
300 | ++mantissa; |
301 | if (!(mantissa & 0x3ff)) { |
302 | ++exponent; |
303 | mantissa = 0; |
304 | } |
305 | } |
306 | |
307 | ret.x = (sign | (exponent << 10) | mantissa); |
308 | |
309 | return ret; |
310 | } |
311 | |
312 | inline float cpu_half2float(float16 h) { |
313 | unsigned sign = ((h.x >> 15) & 1); |
314 | unsigned exponent = ((h.x >> 10) & 0x1f); |
315 | unsigned mantissa = ((h.x & 0x3ff) << 13); |
316 | |
317 | if (exponent == 0x1f) { /* NaN or Inf */ |
318 | mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); |
319 | exponent = 0xff; |
320 | } else if (!exponent) { /* Denorm or Zero */ |
321 | if (mantissa) { |
322 | unsigned int msb; |
323 | exponent = 0x71; |
324 | do { |
325 | msb = (mantissa & 0x400000); |
326 | mantissa <<= 1; /* normalize */ |
327 | --exponent; |
328 | } while (!msb); |
329 | mantissa &= 0x7fffff; /* 1.mantissa is implicit */ |
330 | } |
331 | } else { |
332 | exponent += 0x70; |
333 | } |
334 | |
335 | unsigned temp = ((sign << 31) | (exponent << 23) | mantissa); |
336 | |
337 | void* rp = &temp; |
338 | return *(float*)rp; |
339 | } |
340 | |
341 | } // namespace gloo |
342 | |