#include <executor_launch_params.h>

#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

void LaunchParams::assertValid() {
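  // The requested block size is checked against the active device's
  // maxThreadsPerMultiProcessor; grid dimensions are checked against fixed
  // upper bounds.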
  TORCH_INTERNAL_ASSERT(
      bdimx() * bdimy() * bdimz() > 0 &&
          bdimx() * bdimy() * bdimz() <=
              (int64_t)at::cuda::getCurrentDeviceProperties()
                  ->maxThreadsPerMultiProcessor,
      "Selected invalid number of threads for cuda: ",
      bdimx() * bdimy() * bdimz());
  TORCH_INTERNAL_ASSERT(
      gdimx() > 0 && gdimx() < (std::int64_t(1) << 32) - 1,
      "Invalid number of blocks in x direction: ",
      gdimx());
  TORCH_INTERNAL_ASSERT(
      gdimy() > 0 && gdimy() <= 65535,
      "Invalid number of blocks in y direction: ",
      gdimy());
  TORCH_INTERNAL_ASSERT(
      gdimz() > 0 && gdimz() <= 65535,
      "Invalid number of blocks in z direction: ",
      gdimz());
}

void LaunchParams::bind(int64_t val, ParallelType p_type) {
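  // Route the value to the block/grid dimension named by the parallel type;
  // the whole launch configuration is re-validated after every bind.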
  switch (p_type) {
    case ParallelType::TIDx:
      checkAndSet(val, bdimx_, "blockDim.x");
      break;
    case ParallelType::BIDx:
      checkAndSet(val, gdimx_, "gridDim.x");
      break;
    case ParallelType::TIDy:
      checkAndSet(val, bdimy_, "blockDim.y");
      break;
    case ParallelType::BIDy:
      checkAndSet(val, gdimy_, "gridDim.y");
      break;
    case ParallelType::TIDz:
      checkAndSet(val, bdimz_, "blockDim.z");
      break;
    case ParallelType::BIDz:
      checkAndSet(val, gdimz_, "gridDim.z");
      break;
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Tried to bind invalid parallel type in launch config: ",
          p_type);
  }
  assertValid();
}

int64_t LaunchParams::getDim(ParallelType p_type) const {
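  // Returns the dimension through the accessors (bdimx(), gdimx(), ...) rather
  // than the raw stored fields; see getRawVal() below for the latter.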
  switch (p_type) {
    case ParallelType::TIDx:
      return bdimx();
    case ParallelType::BIDx:
      return gdimx();
    case ParallelType::TIDy:
      return bdimy();
    case ParallelType::BIDy:
      return gdimy();
    case ParallelType::TIDz:
      return bdimz();
    case ParallelType::BIDz:
      return gdimz();
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Tried to get with invalid parallel type in launch config: ",
          p_type);
  }
}

bool LaunchParams::hasDim(ParallelType p_type) const {
  return getRawVal(p_type) != UNINITIALIZED_VAL;
}

const int64_t& LaunchParams::getRawVal(ParallelType p_type) const {
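  // Unlike getDim(), this exposes the stored field directly, so callers such
  // as hasDim() can distinguish UNINITIALIZED_VAL from a bound value.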
  switch (p_type) {
    case ParallelType::TIDx:
      return bdimx_;
    case ParallelType::BIDx:
      return gdimx_;
    case ParallelType::TIDy:
      return bdimy_;
    case ParallelType::BIDy:
      return gdimy_;
    case ParallelType::TIDz:
      return bdimz_;
    case ParallelType::BIDz:
      return gdimz_;
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Tried to get with invalid parallel type in launch config: ",
          p_type);
  }
}

bool LaunchParams::operator==(const LaunchParams& other) const {
  // Compare all six launch dimensions as well as the shared memory size.
  return gdimx_ == other.gdimx_ && gdimy_ == other.gdimy_ &&
      gdimz_ == other.gdimz_ && bdimx_ == other.bdimx_ &&
      bdimy_ == other.bdimy_ && bdimz_ == other.bdimz_ && smem_ == other.smem_;
}

void LaunchParams::print() const {
  std::cout << toString();
}

std::string LaunchParams::toString() const {
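  // Dimensions that were never bound are reported as -1.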
  std::stringstream ss;
  ss << "Launch Parameters: "
     << "BlockDim.x = " << (bdimx_ == UNINITIALIZED_VAL ? -1 : bdimx_) << ", "
     << "BlockDim.y = " << (bdimy_ == UNINITIALIZED_VAL ? -1 : bdimy_) << ", "
     << "BlockDim.z = " << (bdimz_ == UNINITIALIZED_VAL ? -1 : bdimz_) << ", "
     << "GridDim.x = " << (gdimx_ == UNINITIALIZED_VAL ? -1 : gdimx_) << ", "
     << "GridDim.y = " << (gdimy_ == UNINITIALIZED_VAL ? -1 : gdimy_) << ", "
     << "GridDim.z = " << (gdimz_ == UNINITIALIZED_VAL ? -1 : gdimz_) << ", "
     << "Smem Size = " << smem() << "\n";
  return ss.str();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch