1// Generated from "/code/pytorch/third_party/nvfuser/runtime/swizzle.cu"
2// 2023-02-12 08:01:26
3
4namespace nvfuser_resources {
5
6constexpr const char* swizzle_cu = R"(
// File-local shorthand: every helper below is an inline device function.
// The macro is removed again with #undef at the end of this file.
#define DEVICE_INLINE inline __device__
9
// Immutable (x, y) index pair consumed and produced by the 2D swizzle
// helpers in this file. index_t selects the integer type of both components.
template <typename index_t>
struct IndexGeneric {
  // Both components are const: a swizzle returns a fresh pair rather than
  // modifying its input.
  const index_t x = 0;
  const index_t y = 0;

  // Builds a pair from explicit x and y components.
  DEVICE_INLINE IndexGeneric(index_t x_, index_t y_) : x(x_), y(y_) {}
};
16
// Narrow pair built on plain int, used for unit (tile) size computation.
using Index2DInt = IndexGeneric<int>;

// Default pair for integration with generated kernels; its component type
// follows nvfuser_index_t.
using Index2D = IndexGeneric<nvfuser_index_t>;
22
23// ------------------------------------------------------------
24// Swizzle Definitions
// For each swizzle <Name>, un<Name> is its inverse,
// e.g. unZShape is the inverse of ZShape.
// (Unswizzle is needed for inlining and is currently not actively used.)
28// ------------------------------------------------------------
29
// Unit Z swizzle:
//  Rows with an odd x index have their y index mirrored within the unit,
//  rows with an even x index are left as they are:
//    1 2 3      1 2 3
//    4 5 6  =>  6 5 4
//    7 8 9      7 8 9
//  Only unit_dim.y (the extent along y) is read; unit_dim.x is unused.
DEVICE_INLINE Index2D ZShape(Index2D in, Index2D unit_dim) {
  // Even rows pass through unchanged.
  if (in.x % 2 == 0) {
    return Index2D(in.x, in.y);
  }
  // Odd rows run in the opposite direction along y.
  return Index2D(in.x, unit_dim.y - in.y - 1);
}
38
// Inverse of ZShape.
//  Mirroring y on the same odd row twice restores the original y, and even
//  rows are never touched, so the forward swizzle is its own inverse.
DEVICE_INLINE Index2D unZShape(Index2D in, Index2D unit_dim) {
  const Index2D restored = ZShape(in, unit_dim);
  return restored;
}
43
// Block cyclic Xor swizzle (bank conflict removal):
//  Within a block, the y index of every element is xor-ed with its x index.
//  Example:
//     1  2  3  4        1  2  3  4
//     5  6  7  8        6  5  8  7
//     9 10 11 12   =>  11 12  9 10
//    13 14 15 16       16 15 14 13
//  Note: unit_dim is accepted but not read by this function.
DEVICE_INLINE Index2D Xor(Index2D in, Index2DInt unit_dim) {
  // The swizzle configuration still needs to validate that
  //  unit_dim.x == unit_dim.y; nothing here checks it.
  const auto swizzled_y = in.y ^ in.x;
  return Index2D(in.x, swizzled_y);
}
57
// Inverse of Xor.
//  x is left unchanged by Xor and (y ^ x) ^ x == y, so applying the forward
//  swizzle a second time restores the original index.
DEVICE_INLINE Index2D unXor(Index2D in, Index2DInt unit_dim) {
  const Index2D restored = Xor(in, unit_dim);
  return restored;
}
62
// Scatter swizzle:
//   Corresponds to the data layout out of ldmatrix intrinsic.
//   Supported dimensions are: 8x4, 16x4, 32x4.
//
//   The input (x, y) is linearized as y * row_size + x and the linear offset
//   is regrouped in fours:
//     out.x = (y * row_size + x) / 4
//     out.y = x % 4
//
//   row_size: template parameter, must be 8, 16 or 32 (checked at compile
//   time).
template <int row_size>
DEVICE_INLINE Index2D Scatter(Index2D in) {
  // An explicit message keeps the assertion valid before C++17 and makes the
  // failure readable when this source is compiled at runtime.
  static_assert(
      row_size == 8 || row_size == 16 || row_size == 32,
      "Scatter: row_size must be 8, 16 or 32");
  return Index2D((in.y * row_size + in.x) / 4, in.x % 4);
}
71
// Inverse of the Scatter swizzle:
//   With groups = row_size / 4, the input (x, y) maps to
//     out.x = y + (x % groups) * 4
//     out.y = x / groups
//   which undoes Scatter<row_size> for Scatter inputs whose x lies in
//   [0, row_size).
//
//   row_size: template parameter, must be 8, 16 or 32 (checked at compile
//   time).
template <int row_size>
DEVICE_INLINE Index2D unScatter(Index2D in) {
  // An explicit message keeps the assertion valid before C++17 and makes the
  // failure readable when this source is compiled at runtime.
  static_assert(
      row_size == 8 || row_size == 16 || row_size == 32,
      "unScatter: row_size must be 8, 16 or 32");
  return Index2D(in.y + (in.x % (row_size / 4)) * 4, in.x / (row_size / 4));
}
77
78#undef DEVICE_INLINE
79)";
80
81} // namespace nvfuser_resources
82