/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 */

#pragma once

#include <cstdint>
#include <cstring>
#include <iostream>

#ifdef __CUDA_ARCH__
#include <cuda.h>
// Disable strict aliasing errors for CUDA 9.
#if CUDA_VERSION >= 9000
#ifdef __GNUC__
#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
#pragma GCC diagnostic push
#endif
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // __GNUC__
#endif // CUDA_VERSION >= 9000
#include <cuda_fp16.h>
#if CUDA_VERSION >= 9000
#ifdef __GNUC__
#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
#pragma GCC diagnostic pop
#endif
#endif // __GNUC__
#endif // CUDA_VERSION >= 9000
#endif // __CUDA_ARCH__

#include "gloo/common/common.h"

#ifdef _WIN32
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

namespace gloo {

// Unlike old style collectives that are class instances that hold
// some state, the new style collectives do not need initialization
// before they can run. Instead of asking the context for a series of
// slots and storing them for later use and reuse, the new style
// collectives take a slot (or tag) argument that allows for
// concurrent execution of multiple collectives on the same context.
//
// This tag is what determines the slot numbers for the send and recv
// operations that the collectives end up executing. A single
// collective may have many send and recv operations running in
// parallel, so instead of using the specified tag verbatim, we use it
// as a prefix. Also, to avoid conflicts between collectives with the
// same tag, we have another tag prefix per collective type. Out of
// the 64 bits we can use for a slot, we use 8 of them to identify a
// collective, 32 to identify the collective tag, another 8 for use by
// the collective operation itself (allowing for 256 independent send
// and recv operations against the same point-to-point pair), and
// leave 16 bits unused.
//
// Below are constexprs for the prefix per collective type, as well as
// a way to compute slots when executing a collective. The Slot class
// below captures both a prefix and a delta on that prefix to support
// addition with bounds checking. It is usable as a uint64_t, but one
// that cannot overflow beyond the bits allocated for use within a
// collective.
//

constexpr uint8_t kGatherSlotPrefix = 0x01;
constexpr uint8_t kAllgatherSlotPrefix = 0x02;
constexpr uint8_t kReduceSlotPrefix = 0x03;
constexpr uint8_t kAllreduceSlotPrefix = 0x04;
constexpr uint8_t kScatterSlotPrefix = 0x05;
constexpr uint8_t kBroadcastSlotPrefix = 0x06;
constexpr uint8_t kBarrierSlotPrefix = 0x07;
constexpr uint8_t kAlltoallSlotPrefix = 0x08;

class Slot {
 public:
  static Slot build(uint8_t prefix, uint32_t tag);

  operator uint64_t() const {
    return base_ + delta_;
  }

  Slot operator+(uint8_t i) const;

 protected:
  explicit Slot(uint64_t base, uint64_t delta) : base_(base), delta_(delta) {}

  const uint64_t base_;
  const uint64_t delta_;
};

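// Usage sketch (illustrative; Slot::build and Slot::operator+ are declared
// above and defined out of line, and `tag` and `i` below are hypothetical
// values):
//
//   const auto slot = Slot::build(kAllreduceSlotPrefix, tag);
//   const uint64_t s = slot + i; // slot for the i-th send/recv pair
//
// The addition is bounds checked so that a collective cannot spill outside
// the 8 bits reserved for its own send and recv operations.
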
struct float16;
float16 cpu_float2half_rn(float f);
float cpu_half2float(float16 h);

struct alignas(2) float16 {
  uint16_t x;

  float16() : x(0) {}

  float16(const float16&) = default;

  explicit float16(int val) {
    float16 res = cpu_float2half_rn(static_cast<float>(val));
    x = res.x;
  }

  explicit float16(unsigned long val) {
    float16 res = cpu_float2half_rn(static_cast<float>(val));
    x = res.x;
  }

  explicit float16(unsigned long long val) {
    float16 res = cpu_float2half_rn(static_cast<float>(val));
    x = res.x;
  }

  explicit float16(double val) {
    float16 res = cpu_float2half_rn(static_cast<float>(val));
    x = res.x;
  }

  float16& operator=(const int& rhs) {
    float16 res = cpu_float2half_rn(static_cast<float>(rhs));
    x = res.x;
    return *this;
  }

  float16& operator=(const float16& rhs) {
    if (rhs != *this) {
      x = rhs.x;
    }
    return *this;
  }

  bool operator==(const float16& rhs) const {
    return x == rhs.x;
  }

  bool operator!=(const float16& rhs) const {
    return !(*this == rhs);
  }

  bool operator==(const int& rhs) const {
    float16 res = cpu_float2half_rn(static_cast<float>(rhs));
    return x == res.x;
  }

  bool operator==(const unsigned long& rhs) const {
    float16 res = cpu_float2half_rn(static_cast<float>(rhs));
    return x == res.x;
  }

  bool operator==(const double& rhs) const {
    float16 res = cpu_float2half_rn(static_cast<float>(rhs));
    return x == res.x;
  }
#ifdef __CUDA_ARCH__
  float16(half h) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(&h)->x;
#else
    x = h.x;
#endif // CUDA_VERSION
  }

  // half and float16 are supposed to have identical representation so
  // converting between them by copying the raw bits should be fine.
  /* implicit */
  operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw hr;
    hr.x = this->x;
    return half(hr);
#else
    half ret;
    ret.x = this->x;
    return ret;
#endif // CUDA_VERSION
  }
#endif // __CUDA_ARCH__

  float16& operator+=(const float16& rhs) {
    float r = cpu_half2float(*this) + cpu_half2float(rhs);
    *this = cpu_float2half_rn(r);
    return *this;
  }

  float16& operator-=(const float16& rhs) {
    float r = cpu_half2float(*this) - cpu_half2float(rhs);
    *this = cpu_float2half_rn(r);
    return *this;
  }

  float16& operator*=(const float16& rhs) {
    float r = cpu_half2float(*this) * cpu_half2float(rhs);
    *this = cpu_float2half_rn(r);
    return *this;
  }

  float16& operator/=(const float16& rhs) {
    float r = cpu_half2float(*this) / cpu_half2float(rhs);
    *this = cpu_float2half_rn(r);
    return *this;
  }
};
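
// Usage sketch (illustrative): arithmetic on float16 round-trips through
// float via cpu_half2float/cpu_float2half_rn, so values that are exactly
// representable in half precision compare bit-for-bit equal afterwards.
//
//   float16 a(1.5); // double constructor
//   float16 b(2.25);
//   assert(a + b == float16(3.75));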

inline std::ostream& operator<<(std::ostream& stream, const float16& val) {
  stream << cpu_half2float(val);
  return stream;
}

inline float16 operator+(const float16& lhs, const float16& rhs) {
  float16 result = lhs;
  result += rhs;
  return result;
}

inline float16 operator-(const float16& lhs, const float16& rhs) {
  float16 result = lhs;
  result -= rhs;
  return result;
}

inline float16 operator*(const float16& lhs, const float16& rhs) {
  float16 result = lhs;
  result *= rhs;
  return result;
}

inline float16 operator/(const float16& lhs, const float16& rhs) {
  float16 result = lhs;
  result /= rhs;
  return result;
}

inline bool operator<(const float16& lhs, const float16& rhs) {
  return cpu_half2float(lhs) < cpu_half2float(rhs);
}

inline bool operator<=(const float16& lhs, const float16& rhs) {
  return cpu_half2float(lhs) <= cpu_half2float(rhs);
}

inline bool operator>(const float16& lhs, const float16& rhs) {
  return cpu_half2float(lhs) > cpu_half2float(rhs);
}

inline bool operator>=(const float16& lhs, const float16& rhs) {
  return cpu_half2float(lhs) >= cpu_half2float(rhs);
}

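// Round-to-nearest-even float to half conversion on the CPU. For example,
// 1.0f (bits 0x3f800000) maps to 0x3c00 and -2.0f maps to 0xc000; any NaN
// input collapses to the canonical half NaN 0x7fff.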
inline float16 cpu_float2half_rn(float f) {
  float16 ret;

  static_assert(
      sizeof(unsigned int) == sizeof(float),
      "Programming error sizeof(unsigned int) != sizeof(float)");

  // Copy the bits of the float into an integer to avoid strict-aliasing
  // issues with pointer-based type punning.
  unsigned x;
  std::memcpy(&x, &f, sizeof(x));
  unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
  unsigned sign, exponent, mantissa;

  // Get rid of +NaN/-NaN case first.
  if (u > 0x7f800000) {
    ret.x = 0x7fffU;
    return ret;
  }

  sign = ((x >> 16) & 0x8000);

  // Get rid of +Inf/-Inf, +0/-0.
  if (u > 0x477fefff) {
    ret.x = sign | 0x7c00U;
    return ret;
  }
  if (u < 0x33000001) {
    ret.x = (sign | 0x0000);
    return ret;
  }

  exponent = ((u >> 23) & 0xff);
  mantissa = (u & 0x7fffff);

  if (exponent > 0x70) {
    shift = 13;
    exponent -= 0x70;
  } else {
    shift = 0x7e - exponent;
    exponent = 0;
    mantissa |= 0x800000;
  }
  lsb = (1 << shift);
  lsb_s1 = (lsb >> 1);
  lsb_m1 = (lsb - 1);

  // Round to nearest even.
  remainder = (mantissa & lsb_m1);
  mantissa >>= shift;
  if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
    ++mantissa;
    if (!(mantissa & 0x3ff)) {
      ++exponent;
      mantissa = 0;
    }
  }

  ret.x = (sign | (exponent << 10) | mantissa);

  return ret;
}

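// Half to float conversion on the CPU. The conversion is exact since every
// finite half value is representable as a float; for example,
// cpu_half2float(float16(1)) == 1.0f, and half NaN/Inf map to float NaN/Inf.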
inline float cpu_half2float(float16 h) {
  unsigned sign = ((h.x >> 15) & 1);
  unsigned exponent = ((h.x >> 10) & 0x1f);
  unsigned mantissa = ((h.x & 0x3ff) << 13);

  if (exponent == 0x1f) { /* NaN or Inf */
    mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
    exponent = 0xff;
  } else if (!exponent) { /* Denorm or Zero */
    if (mantissa) {
      unsigned int msb;
      exponent = 0x71;
      do {
        msb = (mantissa & 0x400000);
        mantissa <<= 1; /* normalize */
        --exponent;
      } while (!msb);
      mantissa &= 0x7fffff; /* 1.mantissa is implicit */
    }
  } else {
    exponent += 0x70;
  }

  unsigned temp = ((sign << 31) | (exponent << 23) | mantissa);

  // Copy the bits back into a float to avoid strict-aliasing issues with
  // pointer-based type punning.
  float ret;
  std::memcpy(&ret, &temp, sizeof(ret));
  return ret;
}

} // namespace gloo