1 | #include <executor_launch_params.h> |
2 | |
3 | #include <ATen/cuda/CUDAContext.h> |
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | namespace fuser { |
8 | namespace cuda { |
9 | |
10 | void LaunchParams::assertValid() { |
11 | TORCH_INTERNAL_ASSERT( |
12 | bdimx() * bdimy() * bdimz() > 0 && |
13 | bdimx() * bdimy() * bdimz() <= |
14 | (int64_t)at::cuda::getCurrentDeviceProperties() |
15 | ->maxThreadsPerMultiProcessor, |
16 | "Selected invalid number of threads for cuda: " , |
17 | bdimx() * bdimy() * bdimz()); |
18 | TORCH_INTERNAL_ASSERT( |
19 | gdimx() > 0 && gdimx() < (std::int64_t(1) << 32) - 1, |
20 | "Invalid number of blocks in x direction: " , |
21 | gdimx()); |
22 | TORCH_INTERNAL_ASSERT( |
23 | gdimy() > 0 && gdimy() <= 65535, |
24 | "Invalid number of blocks in y direction: " , |
25 | gdimy()); |
26 | TORCH_INTERNAL_ASSERT( |
27 | gdimz() > 0 && gdimz() <= 65535, |
28 | "Invalid number of blocks in z direction: " , |
29 | gdimz()); |
30 | } |
31 | |
32 | void LaunchParams::bind(int64_t val, ParallelType p_type) { |
33 | switch (p_type) { |
34 | case ParallelType::TIDx: |
35 | checkAndSet(val, bdimx_, "blockDim.x" ); |
36 | break; |
37 | case ParallelType::BIDx: |
38 | checkAndSet(val, gdimx_, "gridDim.x" ); |
39 | break; |
40 | case ParallelType::TIDy: |
41 | checkAndSet(val, bdimy_, "blockDim.y" ); |
42 | break; |
43 | case ParallelType::BIDy: |
44 | checkAndSet(val, gdimy_, "gridDim.y" ); |
45 | break; |
46 | case ParallelType::TIDz: |
47 | checkAndSet(val, bdimz_, "blockdim.z" ); |
48 | break; |
49 | case ParallelType::BIDz: |
50 | checkAndSet(val, gdimz_, "gridDim.z" ); |
51 | break; |
52 | default: |
53 | TORCH_INTERNAL_ASSERT( |
54 | false, |
55 | "Tried to bind invalid parallel type in launch config: " , |
56 | p_type); |
57 | } |
58 | assertValid(); |
59 | } |
60 | |
61 | int64_t LaunchParams::getDim(ParallelType p_type) const { |
62 | switch (p_type) { |
63 | case ParallelType::TIDx: |
64 | return bdimx(); |
65 | case ParallelType::BIDx: |
66 | return gdimx(); |
67 | case ParallelType::TIDy: |
68 | return bdimy(); |
69 | case ParallelType::BIDy: |
70 | return gdimy(); |
71 | case ParallelType::TIDz: |
72 | return bdimz(); |
73 | case ParallelType::BIDz: |
74 | return gdimz(); |
75 | default: |
76 | TORCH_INTERNAL_ASSERT( |
77 | false, |
78 | "Tried to get with invalid parallel type in launch config: " , |
79 | p_type); |
80 | } |
81 | } |
82 | |
83 | bool LaunchParams::hasDim(ParallelType p_type) const { |
84 | return getRawVal(p_type) != UNINITIALIZED_VAL; |
85 | } |
86 | |
87 | const int64_t& LaunchParams::getRawVal(ParallelType p_type) const { |
88 | switch (p_type) { |
89 | case ParallelType::TIDx: |
90 | return bdimx_; |
91 | case ParallelType::BIDx: |
92 | return gdimx_; |
93 | case ParallelType::TIDy: |
94 | return bdimy_; |
95 | case ParallelType::BIDy: |
96 | return gdimy_; |
97 | case ParallelType::TIDz: |
98 | return bdimz_; |
99 | case ParallelType::BIDz: |
100 | return gdimz_; |
101 | default: |
102 | TORCH_INTERNAL_ASSERT( |
103 | false, |
104 | "Tried to get with invalid parallel type in launch config: " , |
105 | p_type); |
106 | } |
107 | } |
108 | |
109 | bool LaunchParams::operator==(const LaunchParams& other) const { |
110 | return gdimx_ == other.gdimx_ && gdimy_ == other.gdimy_ && |
111 | bdimx_ == other.bdimx_ && bdimy_ == other.bdimy_ && smem_ == other.smem_; |
112 | } |
113 | |
114 | void LaunchParams::print() const { |
115 | std::cout << toString(); |
116 | } |
117 | |
118 | std::string LaunchParams::toString() const { |
119 | std::stringstream ss; |
120 | ss << "Launch Parameters: " |
121 | << "BlockDim.x = " << (bdimx_ == UNINITIALIZED_VAL ? -1 : bdimx_) << ", " |
122 | << "BlockDim.y = " << (bdimy_ == UNINITIALIZED_VAL ? -1 : bdimy_) << ", " |
123 | << "BlockDim.z = " << (bdimz_ == UNINITIALIZED_VAL ? -1 : bdimz_) << ", " |
124 | << "GridDim.x = " << (gdimx_ == UNINITIALIZED_VAL ? -1 : gdimx_) << ", " |
125 | << "GridDim.y = " << (gdimy_ == UNINITIALIZED_VAL ? -1 : gdimy_) << ", " |
126 | << "GridDim.z = " << (gdimz_ == UNINITIALIZED_VAL ? -1 : gdimz_) << ", " |
127 | << "Smem Size = " << smem() << "\n" ; |
128 | return ss.str(); |
129 | } |
130 | |
131 | } // namespace cuda |
132 | } // namespace fuser |
133 | } // namespace jit |
134 | } // namespace torch |
135 | |