1// Generated from "/code/pytorch/third_party/nvfuser/runtime/helpers.cu"
2// 2023-02-12 08:01:26
3
4namespace nvfuser_resources {
5
6constexpr const char* helpers_cu = R"(
// Declares a block-shared int `nvfuser_zero_s` and a per-thread int
// `nvfuser_zero`. Threads with threadIdx.x == 0 store 0 into the shared
// variable, the block synchronizes, every thread then applies atomicMin with
// its own threadIdx.x, and finally copies the shared value into
// `nvfuser_zero`.
// The expansion contains __syncthreads(), so it must be placed where all
// threads of the block reach it.
// NOTE(review): presumably the shared-memory round trip exists so the compiler
// cannot fold `nvfuser_zero` into a compile-time constant -- confirm.
#define NVFUSER_DEFINE_MAGIC_ZERO \
 __shared__ int nvfuser_zero_s; \
 if (threadIdx.x == 0) \
 nvfuser_zero_s = 0; \
 __syncthreads(); \
 atomicMin(&nvfuser_zero_s, threadIdx.x); \
 int nvfuser_zero = nvfuser_zero_s;
14
// Shifts the per-thread `nvfuser_zero` left by one bit; when the value is 0
// the shift leaves it at 0.
// NOTE(review): the expansion already ends with a semicolon, so a call site
// that appends its own `;` yields an extra empty statement, which would break
// an unbraced if/else wrapped around the macro -- confirm how generated code
// invokes it before changing this.
#define NVFUSER_UPDATE_MAGIC_ZERO \
 do { \
 nvfuser_zero <<= 1; \
 } while (0);
19
// Integer ceiling division: evaluates (a + b - 1) / b with truncating integer
// division. This equals ceil(a / b) when a >= 0 and b > 0 and the sum does not
// overflow; b == 0 is not checked.
__device__ constexpr int ceilDiv(int a, int b) {
  const int biased = a + b - 1;
  return biased / b;
}

// 64-bit form of the integer ceiling division.
__device__ constexpr int64_t ceilDiv(int64_t a, int64_t b) {
  const int64_t biased = a + b - 1;
  return biased / b;
}

// Mixed widths: b is widened to int64_t, then the 64-bit overload is used.
__device__ constexpr int64_t ceilDiv(int64_t a, int b) {
  const int64_t wide_b = (int64_t)b;
  return ceilDiv(a, wide_b);
}

// Mixed widths: a is widened to int64_t, then the 64-bit overload is used.
__device__ constexpr int64_t ceilDiv(int a, int64_t b) {
  const int64_t wide_a = (int64_t)a;
  return ceilDiv(wide_a, b);
}

// Floating-point form: the quotient a / b rounded up with std::ceil.
__device__ constexpr double ceilDiv(double a, double b) {
  const double quotient = a / b;
  return std::ceil(quotient);
}

// Floating-point dividend, integer divisor (the division is done in double).
__device__ constexpr double ceilDiv(double a, int64_t b) {
  const double quotient = a / b;
  return std::ceil(quotient);
}

// Integer dividend, floating-point divisor (the division is done in double).
__device__ constexpr double ceilDiv(int64_t a, double b) {
  const double quotient = a / b;
  return std::ceil(quotient);
}
47
// Monotonic and precise lerp is described here:
// https://math.stackexchange.com/a/1798323
//
// Linear interpolation between start and end. For weight < 0.5 the result is
// computed relative to start, otherwise relative to end.
__device__ double lerp(double start, double end, double weight) {
  const bool near_start = weight < 0.5;
  return near_start ? start + weight * (end - start)
                    : end - (end - start) * (1.0 - weight);
}

// Single-precision linear interpolation; same branch rule as the double form.
__device__ float lerp(float start, float end, float weight) {
  const bool near_start = weight < 0.5f;
  return near_start ? start + weight * (end - start)
                    : end - (end - start) * (1.0f - weight);
}

// Complex (double) interpolation; the formula is selected by the magnitude of
// the weight, abs(weight) < 0.5.
__device__ std::complex<double> lerp(
    std::complex<double> start,
    std::complex<double> end,
    std::complex<double> weight) {
  const bool near_start = abs(weight) < 0.5;
  return near_start ? start + weight * (end - start)
                    : end - (end - start) * (1.0 - weight);
}

// Complex (float) interpolation; the formula is selected by the magnitude of
// the weight, abs(weight) < 0.5f.
__device__ std::complex<float> lerp(
    std::complex<float> start,
    std::complex<float> end,
    std::complex<float> weight) {
  const bool near_start = abs(weight) < 0.5f;
  return near_start ? start + weight * (end - start)
                    : end - (end - start) * (1.0f - weight);
}

// Float endpoints with a double weight: the weight is narrowed to float and
// the float overload does the work.
__device__ float lerp(float start, float end, double weight) {
  const float narrowed = static_cast<float>(weight);
  return lerp(start, end, narrowed);
}
91
// Larger of two ints; b is returned when the values are equal.
__device__ constexpr int max(int a, int b) {
  if (a > b) {
    return a;
  }
  return b;
}

// Mixed widths: b is widened to int64_t before the comparison.
__device__ constexpr int64_t max(int64_t a, int b) {
  const int64_t wide_b = (int64_t)b;
  if (a > wide_b) {
    return a;
  }
  return wide_b;
}

// Mixed widths: a is widened to int64_t before the comparison.
__device__ constexpr int64_t max(int a, int64_t b) {
  const int64_t wide_a = (int64_t)a;
  if (wide_a > b) {
    return wide_a;
  }
  return b;
}

// Larger of two 64-bit integers.
__device__ constexpr int64_t max(int64_t a, int64_t b) {
  if (a > b) {
    return a;
  }
  return b;
}
107
// Maximum of two doubles that propagates NaN: if a is NaN, a is returned;
// otherwise if b is NaN, b is returned; otherwise the larger value.
__device__ double fmax(double a, double b) {
  // A value compares unequal to itself only when it is NaN.
  const bool a_is_nan = (a != a);
  if (a_is_nan) {
    return a;
  }
  const bool b_is_nan = (b != b);
  if (b_is_nan) {
    return b;
  }
  return (a > b) ? a : b;
}

// Single-precision maximum with the same NaN-propagating rule.
__device__ float fmax(float a, float b) {
  // A value compares unequal to itself only when it is NaN.
  const bool a_is_nan = (a != a);
  if (a_is_nan) {
    return a;
  }
  const bool b_is_nan = (b != b);
  if (b_is_nan) {
    return b;
  }
  return (a > b) ? a : b;
}
129
// Smaller of two ints; a is returned when the values are equal.
__device__ constexpr int min(int a, int b) {
  if (a > b) {
    return b;
  }
  return a;
}

// Mixed widths: the comparison and the result are 64-bit.
__device__ constexpr int64_t min(int64_t a, int b) {
  const int64_t wide_b = b;
  if (a > wide_b) {
    return wide_b;
  }
  return a;
}

// Mixed widths: the comparison and the result are 64-bit.
__device__ constexpr int64_t min(int a, int64_t b) {
  const int64_t wide_a = a;
  if (wide_a > b) {
    return b;
  }
  return wide_a;
}

// Smaller of two 64-bit integers.
__device__ constexpr int64_t min(int64_t a, int64_t b) {
  if (a > b) {
    return b;
  }
  return a;
}
145
// Minimum of two doubles that propagates NaN: if a is NaN, a is returned;
// otherwise if b is NaN, b is returned; otherwise the smaller value.
__device__ double fmin(double a, double b) {
  // A value compares unequal to itself only when it is NaN.
  const bool a_is_nan = (a != a);
  if (a_is_nan) {
    return a;
  }
  const bool b_is_nan = (b != b);
  if (b_is_nan) {
    return b;
  }
  return (a > b) ? b : a;
}

// Single-precision minimum with the same NaN-propagating rule.
__device__ float fmin(float a, float b) {
  // A value compares unequal to itself only when it is NaN.
  const bool a_is_nan = (a != a);
  if (a_is_nan) {
    return a;
  }
  const bool b_is_nan = (b != b);
  if (b_is_nan) {
    return b;
  }
  return (a > b) ? b : a;
}
167
// Rounds `buffer` up using the mask (size - 1): adds the mask, then clears the
// masked bits. The result is the next multiple of `size` when `size` is a
// power of two; other sizes are not checked.
__device__ constexpr int alignBufferSize(int buffer, int size) {
  const int mask = size - 1;
  const int bumped = buffer + mask;
  return bumped & ~mask;
}
171
// Clamps x into [minv, maxv]: first raised to at least minv with fmax, then
// lowered to at most maxv with fmin.
__device__ double clamp(double x, double minv, double maxv) {
  const double floored = fmax(x, minv);
  return fmin(floored, maxv);
}

// Float input with double bounds: the clamp is evaluated in double and the
// result is converted back to float on return.
__device__ float clamp(float x, double minv, double maxv) {
  const double widened = (double)x;
  const double floored = fmax(widened, minv);
  return fmin(floored, maxv);
}

// Int input with 64-bit bounds: the clamp is evaluated in int64_t and the
// result is converted back to int on return.
__device__ int clamp(int x, int64_t minv, int64_t maxv) {
  const int64_t widened = (int64_t)x;
  const int64_t floored = max(widened, minv);
  return min(floored, maxv);
}

// 64-bit integer clamp.
__device__ int64_t clamp(int64_t x, int64_t minv, int64_t maxv) {
  const int64_t floored = max(x, minv);
  return min(floored, maxv);
}
187
// Fractional part of x: x minus its truncation toward zero, so the result
// carries the sign of x.
__device__ double frac(double x) {
  const auto whole = trunc(x);
  return x - whole;
}

// Single-precision fractional part.
__device__ float frac(float x) {
  const auto whole = trunc(x);
  return x - whole;
}
195
// Multiplicative inverse 1 / x in double precision.
__device__ double reciprocal(double x) {
  const double inverse = 1 / x;
  return inverse;
}

// Multiplicative inverse 1 / x in single precision.
__device__ float reciprocal(float x) {
  const float inverse = 1 / x;
  return inverse;
}

// Multiplicative inverse of a double-precision complex number.
__device__ std::complex<double> reciprocal(std::complex<double> x) {
  const std::complex<double> inverse = 1.0 / x;
  return inverse;
}

// Multiplicative inverse of a single-precision complex number.
__device__ std::complex<float> reciprocal(std::complex<float> x) {
  const std::complex<float> inverse = 1.0f / x;
  return inverse;
}
211
// Rectified linear unit: 0 when x <= 0, otherwise x. A NaN input fails the
// comparison and is returned unchanged.
__device__ double relu(double x) {
  if (x <= 0) {
    return 0;
  }
  return x;
}

// Single-precision rectified linear unit.
__device__ float relu(float x) {
  if (x <= 0) {
    return 0;
  }
  return x;
}

// 64-bit integer input; the result is converted to float on return.
// NOTE(review): float cannot represent every int64_t value exactly -- confirm
// the float return type is intended.
__device__ float relu(int64_t x) {
  if (x <= 0) {
    return 0;
  }
  return x;
}

// Int input; the result is converted to float on return.
// NOTE(review): confirm the float return type is intended for integer input.
__device__ float relu(int x) {
  if (x <= 0) {
    return 0;
  }
  return x;
}
227
// Remainder whose sign follows the divisor b: starts from ::fmod(a, b) and,
// when that value is nonzero and its sign differs from b's, adds b once.
__device__ double remainder(double a, double b) {
  auto result = ::fmod(a, b);
  const bool nonzero = (result != 0);
  const bool sign_differs = ((b < 0) != (result < 0));
  if (nonzero && sign_differs) {
    result += b;
  }
  return result;
}

// Single-precision remainder with the same sign-follows-divisor rule.
__device__ float remainder(float a, float b) {
  auto result = ::fmod(a, b);
  const bool nonzero = (result != 0);
  const bool sign_differs = ((b < 0) != (result < 0));
  if (nonzero && sign_differs) {
    result += b;
  }
  return result;
}
241
// Logistic sigmoid 1 / (1 + exp(-x)) in double precision.
__device__ double sigmoid(double x) {
  const auto denominator = 1.0 + exp(-x);
  return 1.0 / denominator;
}

// Logistic sigmoid in single precision.
__device__ float sigmoid(float x) {
  const auto denominator = 1.0f + exp(-x);
  return 1.0f / denominator;
}

// Logistic sigmoid of a double-precision complex argument.
__device__ std::complex<double> sigmoid(std::complex<double> x) {
  const auto denominator = 1.0 + exp(-x);
  return 1.0 / denominator;
}

// Logistic sigmoid of a single-precision complex argument.
__device__ std::complex<float> sigmoid(std::complex<float> x) {
  const auto denominator = 1.0f + exp(-x);
  return 1.0f / denominator;
}
257
// SiLU activation: x multiplied by sigmoid(x), in double precision.
__device__ double silu(double x) {
  const double gate = sigmoid(x);
  return x * gate;
}

// SiLU activation in single precision.
__device__ float silu(float x) {
  const float gate = sigmoid(x);
  return x * gate;
}
265
// Returns v when x <= t, otherwise x.
__device__ double threshold(double x, double t, double v) {
  if (x <= t) {
    return v;
  }
  return x;
}

// Float input with double threshold and replacement: the comparison is done
// in double, and v is narrowed to float when it is returned.
__device__ float threshold(float x, double t, double v) {
  if (x <= t) {
    return v;
  }
  return x;
}
273
// Selects a when c is true and b otherwise (double-precision complex).
__device__ std::complex<double> where(
    bool c,
    std::complex<double> a,
    std::complex<double> b) {
  if (c) {
    return a;
  }
  return b;
}

// Selects a when c is true and b otherwise (single-precision complex).
__device__ std::complex<float> where(
    bool c,
    std::complex<float> a,
    std::complex<float> b) {
  if (c) {
    return a;
  }
  return b;
}
287
// Int input with 64-bit threshold and replacement: returns v when x <= t,
// otherwise x. A returned v is converted from int64_t to int.
__device__ int threshold(int x, int64_t t, int64_t v) {
  if (x <= t) {
    return v;
  }
  return x;
}

// 64-bit integer form: returns v when x <= t, otherwise x.
__device__ int64_t threshold(int64_t x, int64_t t, int64_t v) {
  if (x <= t) {
    return v;
  }
  return x;
}
295
// Selects a when c is true and b otherwise (double).
__device__ double where(bool c, double a, double b) {
  if (c) {
    return a;
  }
  return b;
}

// Selects a when c is true and b otherwise (float).
__device__ float where(bool c, float a, float b) {
  if (c) {
    return a;
  }
  return b;
}

// Selects a when c is true and b otherwise (int64_t).
__device__ int64_t where(bool c, int64_t a, int64_t b) {
  if (c) {
    return a;
  }
  return b;
}

// Selects a when c is true and b otherwise (int).
__device__ int where(bool c, int a, int b) {
  if (c) {
    return a;
  }
  return b;
}

// Mixed widths: the selected value is returned as int64_t.
__device__ int64_t where(bool c, int64_t a, int b) {
  if (c) {
    return a;
  }
  return b;
}

// Mixed widths: the selected value is returned as int64_t.
__device__ int64_t where(bool c, int a, int64_t b) {
  if (c) {
    return a;
  }
  return b;
}
319
// Integer remainder whose sign follows the divisor b: starts from a % b and,
// when that value is nonzero and its sign differs from b's, adds b once.
// b == 0 is not checked.
__device__ constexpr int64_t remainder(int64_t a, int64_t b) {
  auto result = a % b;
  const bool nonzero = (result != 0);
  const bool sign_differs = ((b < 0) != (result < 0));
  if (nonzero && sign_differs) {
    result += b;
  }
  return result;
}

// Int form of the sign-follows-divisor remainder.
__device__ constexpr int remainder(int a, int b) {
  auto result = a % b;
  const bool nonzero = (result != 0);
  const bool sign_differs = ((b < 0) != (result < 0));
  if (nonzero && sign_differs) {
    result += b;
  }
  return result;
}
333
// Integer fmod: the built-in % operator, so a nonzero result carries the sign
// of the dividend a. b == 0 is not checked.
__device__ constexpr int64_t fmod(int64_t a, int64_t b) {
  const int64_t rest = a % b;
  return rest;
}

// Int form of the truncating remainder.
__device__ constexpr int fmod(int a, int b) {
  const int rest = a % b;
  return rest;
}
341
// Floating-point fmod: forwards to the global-namespace ::fmod.
// NOTE(review): assumes this code is compiled inside a namespace so that
// ::fmod names the built-in rather than this overload -- confirm.
__device__ constexpr double fmod(double a, double b) {
  const auto rest = ::fmod(a, b);
  return rest;
}

// Single-precision form; forwards to the global-namespace ::fmod as well.
__device__ constexpr float fmod(float a, float b) {
  const auto rest = ::fmod(a, b);
  return rest;
}
349
// Integer power a^b by repeated squaring.
//
// For a negative exponent the mathematically exact result is a fraction, so
// only the bases that still give an integer are handled specially:
//   a ==  1 -> 1
//   a == -1 -> -1 for odd b, 1 for even b
//   any other base -> 0
// For b >= 0 the result is the product of a^(2^k) over the set bits of b;
// overflow of the result itself is not detected.
//
// This primary template uses % / & style integer operations and is meant for
// integral T; float and double are handled by explicit specializations.
template <typename T>
__device__ T pow(T a, T b) {
  if (b < 0) {
    if (a == 1) {
      return 1;
    }
    if (a == -1) {
      // The low bit of b gives the parity of |b| in two's complement, which
      // avoids negating b (negation overflows for the minimum value of T).
      return (b & 1) ? -1 : 1;
    }
    return 0;
  }

  T result = 1;
  while (b) {
    if (b & 1) {
      result *= a;
    }
    b /= 2;
    // Square only while exponent bits remain: the final squaring is never
    // used and could overflow even when the result itself fits.
    if (b) {
      a *= a;
    }
  }
  return result;
}

template __device__ int pow<int>(int a, int b);
template __device__ int64_t pow<int64_t>(int64_t a, int64_t b);
376
// Single-precision specialization: forwards to the global-namespace ::pow
// instead of the integer algorithm of the primary template.
// NOTE(review): assumes this code is compiled inside a namespace so that
// ::pow names the built-in -- confirm.
template <>
__device__ float pow<float>(float a, float b) {
  const auto raised = ::pow(a, b);
  return raised;
}

// Double-precision specialization: forwards to the global-namespace ::pow.
template <>
__device__ double pow<double>(double a, double b) {
  const auto raised = ::pow(a, b);
  return raised;
}
386
// Float base with int exponent: the exponent is converted to float and the
// (float, float) form of pow is used.
__device__ float pow(float a, int b) {
  const float exponent = (float)b;
  return pow(a, exponent);
}

// Double base with int exponent: the exponent is converted to double.
__device__ double pow(double a, int b) {
  const double exponent = (double)b;
  return pow(a, exponent);
}

// Float base with 64-bit exponent: the exponent is converted to float.
__device__ float pow(float a, int64_t b) {
  const float exponent = (float)b;
  return pow(a, exponent);
}

// Double base with 64-bit exponent: the exponent is converted to double.
__device__ double pow(double a, int64_t b) {
  const double exponent = (double)b;
  return pow(a, exponent);
}
402
// 64-bit base with int exponent: the exponent is widened to int64_t and the
// (int64_t, int64_t) form of pow is used.
// Marked __device__ like every other pow overload in this file: it calls the
// __device__ pow template, so it can only be device code, and the explicit
// qualifier removes the dependence on the compiler treating unannotated
// functions as device functions.
__device__ int64_t pow(int64_t a, int b) {
  return pow(a, (int64_t)b);
}

// Int base with 64-bit exponent: the base is widened to int64_t and the
// (int64_t, int64_t) form of pow is used. Marked __device__ for the same
// reason as the overload above.
__device__ int64_t pow(int a, int64_t b) {
  return pow((int64_t)a, b);
}
410
// Raw storage of `size` bytes aligned to `align` (defaults to `size`) that can
// be built from, and converted back to, any type T with sizeof(T) == size.
// Both directions reinterpret the byte array as a T; no type information is
// kept, so reading back as a different T of the same size is not checked.
template <int size, int align = size>
struct alignas(align) TypelessData {
  int8_t data[size];

  // Stores x by writing it through the byte array viewed as a T.
  template <typename T, std::enable_if_t<sizeof(T) == size, int> _ = 0>
  TypelessData(T x) {
    T* slot = reinterpret_cast<T*>(data);
    *slot = x;
  }

  // Reads the stored bytes back as a T.
  template <typename T, std::enable_if_t<sizeof(T) == size, int> _ = 0>
  operator T() {
    T* slot = reinterpret_cast<T*>(data);
    return *slot;
  }
};
425
// Wraps x in a TypelessData with the size and alignment of T, using
// TypelessData's converting constructor.
template <typename T>
TypelessData<sizeof(T), alignof(T)> erase_type(T x) {
  using Erased = TypelessData<sizeof(T), alignof(T)>;
  Erased erased = x;
  return erased;
}
430
// Generic finiteness test: forwards to the global-namespace ::isfinite.
template <typename T>
bool isfinite(T x) {
  const bool finite = ::isfinite(x);
  return finite;
}

// Complex finiteness test: true only when both the real and the imaginary
// part are finite. The imaginary part is not examined if the real part is
// already non-finite.
template <typename T>
bool isfinite(std::complex<T> x) {
  if (!::isfinite(std::real(x))) {
    return false;
  }
  return ::isfinite(std::imag(x));
}
440
// Generic infinity test: forwards to the global-namespace ::isinf.
template <typename T>
bool isinf(T x) {
  const bool infinite = ::isinf(x);
  return infinite;
}

// Complex infinity test: true when the real or the imaginary part is
// infinite. The imaginary part is not examined if the real part is infinite.
template <typename T>
bool isinf(std::complex<T> x) {
  if (::isinf(std::real(x))) {
    return true;
  }
  return ::isinf(std::imag(x));
}
450
451////////////////////////////////////////////////////////////
// TODO: the following overloads are only needed for CUDA //
// 10.2. Please remove when CUDA 10.2 support is dropped. //
454////////////////////////////////////////////////////////////
455
// Integer and bool overloads of isinf: these types have no infinity, so every
// overload ignores its argument and reports false.
bool isinf(int64_t x) {
  static_cast<void>(x);
  return false;
}

bool isinf(int x) {
  static_cast<void>(x);
  return false;
}

bool isinf(short x) {
  static_cast<void>(x);
  return false;
}

bool isinf(char x) {
  static_cast<void>(x);
  return false;
}

bool isinf(unsigned char x) {
  static_cast<void>(x);
  return false;
}

bool isinf(bool x) {
  static_cast<void>(x);
  return false;
}
479
// Integer and bool overloads of isfinite: every value of these types is
// finite, so every overload ignores its argument and reports true.
bool isfinite(int64_t x) {
  static_cast<void>(x);
  return true;
}

bool isfinite(int x) {
  static_cast<void>(x);
  return true;
}

bool isfinite(short x) {
  static_cast<void>(x);
  return true;
}

bool isfinite(char x) {
  static_cast<void>(x);
  return true;
}

bool isfinite(unsigned char x) {
  static_cast<void>(x);
  return true;
}

bool isfinite(bool x) {
  static_cast<void>(x);
  return true;
}
503
504////////////////////////////////////////////////////////////
505// End TODO //
506////////////////////////////////////////////////////////////
507
// NaN test by self-comparison: returns whether x compares unequal to itself,
// which for floating-point values is true only for NaN.
template <typename T>
bool isnan(T x) {
  const bool differs_from_itself = (x != x);
  return differs_from_itself;
}
512
// True when x is negative and isinf(x) holds. isinf is only consulted for
// negative x.
template <typename T>
bool isneginf(T x) {
  if (!(x < 0)) {
    return false;
  }
  return isinf(x);
}
517
// True when x is positive and isinf(x) holds. isinf is only consulted for
// positive x.
template <typename T>
bool isposinf(T x) {
  if (!(x > 0)) {
    return false;
  }
  return isinf(x);
}
522
// Generic overload: any non-complex argument is reported as real; the value
// itself is not inspected.
template <typename T>
bool isreal(T x) {
  static_cast<void>(x);
  return true;
}

// Complex overload: real exactly when the imaginary part compares equal to 0.
template <typename T>
bool isreal(std::complex<T> x) {
  const auto imaginary = std::imag(x);
  return imaginary == 0;
}
532
// Returns the current value of the cycle counter as read by clock64().
__device__ inline int64_t readCycleCounter() {
  // Issue a memory fence first so that preceding memory operations are
  // completed before the counter is sampled; this is what makes the value
  // meaningful for timing a region enclosed by two calls.
  __threadfence();
  const int64_t cycles = clock64();
  return cycles;
}
541
// Debug print of a float: writes "name = value" together with the calling
// thread's threadIdx and blockIdx, then returns value unchanged so the call
// can be wrapped around an expression.
__device__ float print_impl(const char* name, float value) {
  const int tx = (int)threadIdx.x;
  const int ty = (int)threadIdx.y;
  const int tz = (int)threadIdx.z;
  const int bx = (int)blockIdx.x;
  const int by = (int)blockIdx.y;
  const int bz = (int)blockIdx.z;
  printf(
      "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
      name, value, tx, ty, tz, bx, by, bz);
  return value;
}

// Debug print of a double; same output layout, value returned unchanged.
__device__ double print_impl(const char* name, double value) {
  const int tx = (int)threadIdx.x;
  const int ty = (int)threadIdx.y;
  const int tz = (int)threadIdx.z;
  const int bx = (int)blockIdx.x;
  const int by = (int)blockIdx.y;
  const int bz = (int)blockIdx.z;
  printf(
      "%s = %lf @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
      name, value, tx, ty, tz, bx, by, bz);
  return value;
}

// Debug print of an int; same output layout, value returned unchanged.
__device__ int print_impl(const char* name, int value) {
  const int tx = (int)threadIdx.x;
  const int ty = (int)threadIdx.y;
  const int tz = (int)threadIdx.z;
  const int bx = (int)blockIdx.x;
  const int by = (int)blockIdx.y;
  const int bz = (int)blockIdx.z;
  printf(
      "%s = %d @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
      name, value, tx, ty, tz, bx, by, bz);
  return value;
}
583
// Debug print of a 64-bit integer: writes "name = value" together with the
// calling thread's threadIdx and blockIdx, then returns value unchanged.
//
// The value is cast to long long and printed with %lld so the format matches
// the argument width on every platform; %ld expects a `long`, which is only
// 32 bits wide on LLP64 platforms such as Windows.
__device__ int64_t print_impl(const char* name, int64_t value) {
  printf(
      "%s = %lld @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
      name,
      (long long)value,
      (int)threadIdx.x,
      (int)threadIdx.y,
      (int)threadIdx.z,
      (int)blockIdx.x,
      (int)blockIdx.y,
      (int)blockIdx.z);
  return value;
}
597
// Debug print of a bool as the text "true" or "false", together with the
// calling thread's threadIdx and blockIdx; returns value unchanged.
__device__ bool print_impl(const char* name, bool value) {
  const char* text = value ? "true" : "false";
  const int tx = (int)threadIdx.x;
  const int ty = (int)threadIdx.y;
  const int tz = (int)threadIdx.z;
  const int bx = (int)blockIdx.x;
  const int by = (int)blockIdx.y;
  const int bz = (int)blockIdx.z;
  printf(
      "%s = %s @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
      name, text, tx, ty, tz, bx, by, bz);
  return value;
}

// Debug print of a __half: the value is converted with __half2float for
// printing and returned unchanged.
__device__ __half print_impl(const char* name, __half value) {
  const auto as_float = __half2float(value);
  const int tx = (int)threadIdx.x;
  const int ty = (int)threadIdx.y;
  const int tz = (int)threadIdx.z;
  const int bx = (int)blockIdx.x;
  const int by = (int)blockIdx.y;
  const int bz = (int)blockIdx.z;
  printf(
      "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
      name, as_float, tx, ty, tz, bx, by, bz);
  return value;
}

// Debug print of a __bfloat: the value is converted with __bfloat2float for
// printing and returned unchanged.
__device__ __bfloat print_impl(const char* name, __bfloat value) {
  const auto as_float = __bfloat2float(value);
  const int tx = (int)threadIdx.x;
  const int ty = (int)threadIdx.y;
  const int tz = (int)threadIdx.z;
  const int bx = (int)blockIdx.x;
  const int by = (int)blockIdx.y;
  const int bz = (int)blockIdx.z;
  printf(
      "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
      name, as_float, tx, ty, tz, bx, by, bz);
  return value;
}
639
// Debug helper: print(expr) expands to print_impl with the stringized source
// text of the argument as the name and the parenthesized argument as the
// value, so the printed label is the expression exactly as written.
// The expansion evaluates to whatever print_impl returns.
#define print(...) print_impl(#__VA_ARGS__, (__VA_ARGS__))
641
// Element `index` of an arithmetic sequence: start + step * index, converted
// to OutT on return. The arithmetic is done in the type that results from
// combining InputT and IndexT.
template <typename OutT, typename IndexT, typename InputT>
__device__ OutT arange(IndexT index, InputT start, InputT step) {
  const auto offset = step * index;
  return start + offset;
}
646)";
647
648} // namespace nvfuser_resources
649