#include <lower_bank_conflict.h>

#include <dynamic_type.h>
#include <expr_evaluator.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <type.h>

#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

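// Returns true if `value` is a TensorIndex into a tensor that lives in
// shared memory.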
bool isSmemTensorIndex(Val* value) {
  return value->isA<kir::TensorIndex>() &&
      value->as<kir::TensorIndex>()->view()->getMemoryType() ==
      MemoryType::Shared;
}

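// Returns the vectorize width of `ti`: the constant extent of its vectorized
// IterDomain, or 1 if no dimension is parallelized on Vectorize.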
int64_t getVectorizeSize(kir::TensorIndex* ti) {
  for (auto id : ti->view()->domain()->domain()) {
    if (!isParallelTypeVectorize(id->getParallelType())) {
      continue;
    }

    ExpressionEvaluator expr_eval(id->fusion());
    auto vector_size_optional = expr_eval.evaluate(id->extent());

    TORCH_INTERNAL_ASSERT(
        vector_size_optional.has_value(),
        "Could not evaluate constant value bound to vectorized dim.");

    return vector_size_optional->as<int64_t>();
  }
  return 1;
}

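// A warp's shared memory access is served in phases: with 16-byte words the
// hardware services 8 threads at a time, with 8-byte words 16 threads, and
// with 4-byte (or smaller) words the full warp of 32 threads in one phase.
// Bank conflicts can only occur between threads of the same phase.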
inline int64_t getPhaseSize(int64_t word_size_bytes) {
  if (word_size_bytes == 16) {
    return 8;
  }
  if (word_size_bytes == 8) {
    return 16;
  }
  return 32;
}

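// Helpers for recognizing the NamedScalars that correspond to the CUDA
// built-in variables.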
bool isThreadIdx(const std::string& name) {
  return name == "threadIdx.x" || name == "threadIdx.y" ||
      name == "threadIdx.z";
}

bool isBlockIdx(const std::string& name) {
  return name == "blockIdx.x" || name == "blockIdx.y" || name == "blockIdx.z";
}

bool isBlockDim(const std::string& name) {
  return name == "blockDim.x" || name == "blockDim.y" || name == "blockDim.z";
}

bool isGridDim(const std::string& name) {
  return name == "gridDim.x" || name == "gridDim.y" || name == "gridDim.z";
}

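// Maps a threadIdx/blockIdx NamedScalar name to its ParallelType.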
ParallelType getParallelType(const std::string& name) {
  if (name == "threadIdx.x") {
    return ParallelType::TIDx;
  } else if (name == "threadIdx.y") {
    return ParallelType::TIDy;
  } else if (name == "threadIdx.z") {
    return ParallelType::TIDz;
  } else if (name == "blockIdx.x") {
    return ParallelType::BIDx;
  } else if (name == "blockIdx.y") {
    return ParallelType::BIDy;
  } else if (name == "blockIdx.z") {
    return ParallelType::BIDz;
  }
  TORCH_INTERNAL_ASSERT(false, "Not a parallel type");
}

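// Evaluates the byte address accessed through `ti` by each thread in the
// first phase of a warp. Thread indices are bound explicitly; every serial
// loop index is bound to its start value, so only the first iteration of
// each enclosing loop is inspected.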
std::vector<int64_t> evaluateAddressesOnFirstPhase(
    kir::TensorIndex* ti,
    const std::vector<kir::ForLoop*>& for_loops,
    c10::optional<LaunchParams> launch_params,
    const ExpressionEvaluator& expr_eval_common) {
  std::vector<int64_t> addresses;
  const auto word_size_bytes =
      dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti);
  int64_t phase_size = getPhaseSize(word_size_bytes);

  if (launch_params.has_value()) {
    phase_size = std::min<int64_t>(phase_size, launch_params->nThreads());
  }

  for (int64_t linear_tidx : c10::irange(phase_size)) {
    // Convert the linear index within the phase into (tidx, tidy, tidz).
    int64_t tidx = linear_tidx;
    int64_t tidy = 0;
    int64_t tidz = 0;
    if (launch_params.has_value()) {
      tidy = tidx / launch_params->bdimx();
      tidx = tidx % launch_params->bdimx();
      tidz = tidy / launch_params->bdimy();
      tidy = tidy % launch_params->bdimy();
    }
    int64_t index = 0;
    // Make a copy of the expression evaluator so the per-thread bindings
    // below do not leak into the shared one.
    ExpressionEvaluator expr_eval = expr_eval_common;
    expr_eval.bind("threadIdx.x", tidx);
    expr_eval.bind("threadIdx.y", tidy);
    expr_eval.bind("threadIdx.z", tidz);
    for (auto fl : for_loops) {
      if (fl->index()->isA<NamedScalar>()) {
        // Parallel loop indices are already bound, either above or through
        // launch_params.
        auto name = fl->index()->as<NamedScalar>()->name();
        TORCH_INTERNAL_ASSERT(
            isThreadIdx(name) || isBlockIdx(name), "unknown loop index");
      } else {
        // Serial loops are evaluated at their first iteration only.
        auto start = expr_eval.evaluate(fl->start())->as<int64_t>();
        expr_eval.bind(fl->index(), start);
      }
    }
    for (auto ind : ti->indices()) {
      index += expr_eval.evaluate(ind)->as<int64_t>();
    }
    addresses.emplace_back(index * word_size_bytes);
  }
  return addresses;
}

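// Computes the number of ways the given addresses conflict: shared memory is
// organized in 32 banks of 4-byte words, and the conflict is the maximum
// number of distinct words that fall into the same bank.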
int getConflictWays(const std::vector<int64_t>& addresses) {
  std::unordered_set<int64_t> words_by_bank[32];
  for (auto addr : addresses) {
    int64_t word = addr / 4;
    int64_t bank = word % 32;
    words_by_bank[bank].insert(word);
  }
  int conflict = 1;
  for (const auto& words : words_by_bank) {
    conflict = std::max<int>(conflict, words.size());
  }
  return conflict;
}

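// Infers the launch configuration by walking the kernel: for every loop
// parallelized on a thread or block dimension whose stop value can be
// evaluated, that value is recorded as the extent of the corresponding
// launch dimension.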
class InferLaunchParams : public kir::IrVisitor {
 public:
  static c10::optional<LaunchParams> get(
      const std::vector<Expr*>& exprs,
      const std::unordered_map<std::string, IntOrDouble>& known_values) {
    if (exprs.empty()) {
      return c10::nullopt;
    }
    return InferLaunchParams(exprs, known_values).launch_params_;
  }

 private:
  InferLaunchParams(
      const std::vector<Expr*>& exprs,
      const std::unordered_map<std::string, IntOrDouble>& known_values)
      : expr_eval_(exprs[0]->fusion()) {
    for (auto pair : known_values) {
      expr_eval_.bind(pair.first, pair.second);
    }
    handle(exprs);
  }

  using kir::IrVisitor::handle;

  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      kir::IrVisitor::handle(expr);
      return;
    }

    for (auto fl : for_loops_) {
      if (fl->index()->isA<NamedScalar>()) {
        auto name = fl->index()->as<NamedScalar>()->name();
        if (isThreadIdx(name) || isBlockIdx(name)) {
          auto ptype = getParallelType(name);
          auto stop = expr_eval_.evaluate(fl->stop());
          if (stop.has_value()) {
            if (!launch_params_.has_value()) {
              launch_params_ = LaunchParams();
            }
            if (launch_params_->getRawVal(ptype) ==
                LaunchParams::UNINITIALIZED_VAL) {
              launch_params_->bind(stop->as<int64_t>(), ptype);
            } else {
              TORCH_INTERNAL_ASSERT(
                  launch_params_->getDim(ptype) == stop,
                  "Unable to infer launch parameters");
            }
          }
        }
      }
    }
  }

  ExpressionEvaluator expr_eval_;
  c10::optional<LaunchParams> launch_params_;
};

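// Walks the kernel and, for every Set or LoadStoreOp whose input and/or
// output indexes shared memory, records how many ways each side conflicts.
// A value of 0 means that side is not a shared memory access; an entry is
// recorded only when at least one side has a conflict.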
class BankConflictInfo : public kir::IrVisitor {
 public:
  static std::unordered_map<const Expr*, std::pair<int, int>> get(
      const std::vector<Expr*>& exprs,
      c10::optional<LaunchParams> launch_params,
      const std::unordered_map<std::string, IntOrDouble>& known_values) {
    if (exprs.empty()) {
      return {};
    }
    return BankConflictInfo(exprs, launch_params, known_values)
        .bank_conflict_info_;
  }

 private:
  BankConflictInfo(
      const std::vector<Expr*>& exprs,
      c10::optional<LaunchParams> launch_params,
      const std::unordered_map<std::string, IntOrDouble>& known_values)
      : launch_params_(launch_params), expr_eval_common_(exprs[0]->fusion()) {
    // Only a single block is evaluated, so blockIdx is pinned to zero.
    expr_eval_common_.bind("blockIdx.x", 0);
    expr_eval_common_.bind("blockIdx.y", 0);
    expr_eval_common_.bind("blockIdx.z", 0);
    if (launch_params.has_value()) {
      expr_eval_common_.bind("blockDim.x", launch_params->bdimx());
      expr_eval_common_.bind("blockDim.y", launch_params->bdimy());
      expr_eval_common_.bind("blockDim.z", launch_params->bdimz());
      expr_eval_common_.bind("gridDim.x", launch_params->gdimx());
      expr_eval_common_.bind("gridDim.y", launch_params->gdimy());
      expr_eval_common_.bind("gridDim.z", launch_params->gdimz());
    }
    for (auto pair : known_values) {
      expr_eval_common_.bind(pair.first, pair.second);
    }
    handle(exprs);
  }

  using kir::IrVisitor::handle;

  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      kir::IrVisitor::handle(expr);
      return;
    }

    if (expr->isA<UnaryOp>()) {
      auto uop = expr->as<UnaryOp>();
      if (uop->getUnaryOpType() != UnaryOpType::Set) {
        return;
      }
      std::pair<int, int> conflict_ways{0, 0};
      if (isSmemTensorIndex(uop->in())) {
        conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
            uop->in()->as<kir::TensorIndex>(),
            for_loops_,
            launch_params_,
            expr_eval_common_));
      }
      if (isSmemTensorIndex(uop->out())) {
        conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
            uop->out()->as<kir::TensorIndex>(),
            for_loops_,
            launch_params_,
            expr_eval_common_));
      }
      if (conflict_ways.first > 1 || conflict_ways.second > 1) {
        bank_conflict_info_[expr] = conflict_ways;
      }
    } else if (expr->isA<LoadStoreOp>()) {
      auto ldst = expr->as<LoadStoreOp>();
      std::pair<int, int> conflict_ways{0, 0};
      if (isSmemTensorIndex(ldst->in())) {
        conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
            ldst->in()->as<kir::TensorIndex>(),
            for_loops_,
            launch_params_,
            expr_eval_common_));
      }
      if (isSmemTensorIndex(ldst->out())) {
        conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
            ldst->out()->as<kir::TensorIndex>(),
            for_loops_,
            launch_params_,
            expr_eval_common_));
      }
      if (conflict_ways.first > 1 || conflict_ways.second > 1) {
        bank_conflict_info_[expr] = conflict_ways;
      }
    }
  }

  std::unordered_map<const Expr*, std::pair<int, int>> bank_conflict_info_;
  c10::optional<LaunchParams> launch_params_;
  ExpressionEvaluator expr_eval_common_;
};

} // namespace

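// Entry point: validates `known_values`, infers the launch configuration if
// it was not provided, and returns, for each conflicting expression, the
// pair (input conflict ways, output conflict ways).
//
// A minimal usage sketch (assuming `kernel` is an already lowered
// kir::Kernel*):
//
//   auto info = getBankConflictInfo(kernel, c10::nullopt, {});
//   for (const auto& pair : info) {
//     // pair.second.first / pair.second.second are the input/output
//     // conflict ways of pair.first.
//   }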
std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo(
    kir::Kernel* kernel,
    c10::optional<LaunchParams> launch_params,
    const std::unordered_map<std::string, IntOrDouble>& known_values) {
  for (auto pair : known_values) {
    TORCH_CHECK(
        !isThreadIdx(pair.first),
        "threadIdx.{x,y,z} should be computed instead of provided");
    TORCH_CHECK(
        !isBlockIdx(pair.first),
        "blockIdx.{x,y,z} should not be provided (they are always zero)");
    TORCH_CHECK(
        !isBlockDim(pair.first),
        "blockDim.{x,y,z} should be provided by launch_params");
    TORCH_CHECK(
        !isGridDim(pair.first),
        "gridDim.{x,y,z} should be provided by launch_params");
  }
  if (!launch_params.has_value()) {
    launch_params =
        InferLaunchParams::get(kernel->topLevelExprs(), known_values);
  }
  return BankConflictInfo::get(
      kernel->topLevelExprs(), launch_params, known_values);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
333