1 | #include <lower_bank_conflict.h> |
2 | |
3 | #include <dynamic_type.h> |
4 | #include <expr_evaluator.h> |
5 | #include <kernel_ir.h> |
6 | #include <kernel_ir_dispatch.h> |
7 | #include <type.h> |
8 | |
9 | #include <unordered_set> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
16 | namespace { |
17 | |
18 | bool isSmemTensorIndex(Val* value) { |
19 | return value->isA<kir::TensorIndex>() && |
20 | value->as<kir::TensorIndex>()->view()->getMemoryType() == |
21 | MemoryType::Shared; |
22 | } |
23 | |
24 | int64_t getVectorizeSize(kir::TensorIndex* ti) { |
25 | for (auto id : ti->view()->domain()->domain()) { |
26 | if (!isParallelTypeVectorize(id->getParallelType())) { |
27 | continue; |
28 | } |
29 | |
30 | ExpressionEvaluator expr_eval(id->fusion()); |
31 | auto vector_size_optional = expr_eval.evaluate(id->extent()); |
32 | |
33 | TORCH_INTERNAL_ASSERT( |
34 | vector_size_optional.has_value(), |
35 | "Could not evaluate constant value bound to vectorized dim." ); |
36 | |
37 | return vector_size_optional->as<int64_t>(); |
38 | } |
39 | return 1; |
40 | } |
41 | |
// Number of threads per memory "phase" for a given access width in
// bytes: 16-byte accesses are served 8 threads at a time, 8-byte
// accesses 16 at a time, and anything else a full 32-thread warp.
inline int64_t getPhaseSize(int64_t word_size_bytes) {
  switch (word_size_bytes) {
    case 16:
      return 8;
    case 8:
      return 16;
    default:
      return 32;
  }
}
51 | |
// True for exactly the three CUDA thread-index variable names.
bool isThreadIdx(const std::string& name) {
  for (const char* candidate : {"threadIdx.x", "threadIdx.y", "threadIdx.z"}) {
    if (name == candidate) {
      return true;
    }
  }
  return false;
}
56 | |
// True for exactly the three CUDA block-index variable names.
bool isBlockIdx(const std::string& name) {
  for (const char* candidate : {"blockIdx.x", "blockIdx.y", "blockIdx.z"}) {
    if (name == candidate) {
      return true;
    }
  }
  return false;
}
60 | |
// True for exactly the three CUDA block-dimension variable names.
// Bug fix: the original combined the comparisons with &&, which can
// never be true (a string cannot equal three different literals at
// once), so the corresponding TORCH_CHECK in getBankConflictInfo was
// silently disabled. Use || like isThreadIdx/isBlockIdx.
bool isBlockDim(const std::string& name) {
  return name == "blockDim.x" || name == "blockDim.y" ||
      name == "blockDim.z";
}
64 | |
// True for exactly the three CUDA grid-dimension variable names.
// Bug fix: the original combined the comparisons with &&, which can
// never be true (a string cannot equal three different literals at
// once), so the corresponding TORCH_CHECK in getBankConflictInfo was
// silently disabled. Use || like isThreadIdx/isBlockIdx.
bool isGridDim(const std::string& name) {
  return name == "gridDim.x" || name == "gridDim.y" ||
      name == "gridDim.z";
}
68 | |
69 | ParallelType getParallelType(const std::string& name) { |
70 | if (name == "threadIdx.x" ) { |
71 | return ParallelType::TIDx; |
72 | } else if (name == "threadIdx.y" ) { |
73 | return ParallelType::TIDy; |
74 | } else if (name == "threadIdx.z" ) { |
75 | return ParallelType::TIDz; |
76 | } else if (name == "blockIdx.x" ) { |
77 | return ParallelType::BIDx; |
78 | } else if (name == "blockIdx.y" ) { |
79 | return ParallelType::BIDy; |
80 | } else if (name == "blockIdx.z" ) { |
81 | return ParallelType::BIDz; |
82 | } |
83 | TORCH_INTERNAL_ASSERT(false, "Not a parallel type" ); |
84 | } |
85 | |
86 | std::vector<int64_t> evaluateAddressesOnFirstPhase( |
87 | kir::TensorIndex* ti, |
88 | const std::vector<kir::ForLoop*>& for_loops, |
89 | c10::optional<LaunchParams> launch_params, |
90 | const ExpressionEvaluator& expr_eval_common) { |
91 | std::vector<int64_t> addresses; |
92 | const auto word_size_bytes = |
93 | dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti); |
94 | int64_t phase_size = getPhaseSize(word_size_bytes); |
95 | |
96 | if (launch_params.has_value()) { |
97 | phase_size = std::min<int64_t>(phase_size, launch_params->nThreads()); |
98 | } |
99 | |
100 | for (int64_t linear_tidx : c10::irange(phase_size)) { |
101 | int64_t tidx = linear_tidx; |
102 | int64_t tidy = 0; |
103 | int64_t tidz = 0; |
104 | if (launch_params.has_value()) { |
105 | tidy = tidx / launch_params->bdimx(); |
106 | tidx = tidx % launch_params->bdimx(); |
107 | tidz = tidy / launch_params->bdimy(); |
108 | tidy = tidy % launch_params->bdimy(); |
109 | } |
110 | int64_t index = 0; |
111 | // make a copy of the expression evaluator |
112 | ExpressionEvaluator expr_eval = expr_eval_common; |
113 | expr_eval.bind("threadIdx.x" , tidx); |
114 | expr_eval.bind("threadIdx.y" , tidy); |
115 | expr_eval.bind("threadIdx.z" , tidz); |
116 | for (auto fl : for_loops) { |
117 | if (fl->index()->isA<NamedScalar>()) { |
118 | auto name = fl->index()->as<NamedScalar>()->name(); |
119 | TORCH_INTERNAL_ASSERT( |
120 | isThreadIdx(name) || isBlockIdx(name), "unknow loop index" ); |
121 | } else { |
122 | auto start = expr_eval.evaluate(fl->start())->as<int64_t>(); |
123 | expr_eval.bind(fl->index(), start); |
124 | } |
125 | } |
126 | for (auto ind : ti->indices()) { |
127 | index += expr_eval.evaluate(ind)->as<int64_t>(); |
128 | } |
129 | addresses.emplace_back(index * word_size_bytes); |
130 | } |
131 | return addresses; |
132 | } |
133 | |
// Computes the worst-case bank-conflict ways for a set of byte
// addresses: shared memory has 32 banks of 4-byte words, and accesses
// to distinct words in the same bank serialize. Accesses to the same
// word are broadcast and do not conflict, hence the set per bank.
int getConflictWays(const std::vector<int64_t>& addresses) {
  std::unordered_set<int64_t> distinct_words[32];
  for (auto address : addresses) {
    const int64_t word_index = address / 4;
    distinct_words[word_index % 32].insert(word_index);
  }
  // The number of ways is the largest count of distinct words mapped
  // to any single bank (1 means conflict-free).
  int ways = 1;
  for (const auto& bank : distinct_words) {
    if (static_cast<int>(bank.size()) > ways) {
      ways = static_cast<int>(bank.size());
    }
  }
  return ways;
}
147 | |
// Infers launch parameters (block/grid dimensions) by walking the
// kernel IR and reading the stop values of for-loops whose index is a
// threadIdx/blockIdx NamedScalar. Returns c10::nullopt when no such
// loop with an evaluable extent is found.
class InferLaunchParams : public kir::IrVisitor {
 public:
  static c10::optional<LaunchParams> get(
      const std::vector<Expr*>& exprs,
      const std::unordered_map<std::string, IntOrDouble>& known_values) {
    if (exprs.empty()) {
      return c10::nullopt;
    }
    return InferLaunchParams(exprs, known_values).launch_params_;
  }

 private:
  InferLaunchParams(
      const std::vector<Expr*>& exprs,
      const std::unordered_map<std::string, IntOrDouble>& known_values)
      : expr_eval_(exprs[0]->fusion()) {
    // Seed the evaluator with caller-provided scalar bindings so that
    // loop extents referencing them can be evaluated below.
    for (auto pair : known_values) {
      expr_eval_.bind(pair.first, pair.second);
    }
    handle(exprs);
  }

  using kir::IrVisitor::handle;

  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      // Recurse into nested scopes; leaf expressions fall through to
      // the inspection of the enclosing loop nest below.
      kir::IrVisitor::handle(expr);
      return;
    }

    // For each enclosing parallelized loop, record its extent as the
    // corresponding launch dimension.
    for (auto fl : for_loops_) {
      if (fl->index()->isA<NamedScalar>()) {
        auto name = fl->index()->as<NamedScalar>()->name();
        if (isThreadIdx(name) || isBlockIdx(name)) {
          auto ptype = getParallelType(name);
          auto stop = expr_eval_.evaluate(fl->stop());
          if (stop.has_value()) {
            if (!launch_params_.has_value()) {
              launch_params_ = LaunchParams();
            }
            if (launch_params_->getRawVal(ptype) ==
                LaunchParams::UNINITIALIZED_VAL) {
              // First time this dimension is seen: bind it.
              launch_params_->bind(stop->as<int64_t>(), ptype);
            } else {
              // Seen before: all loops on the same parallel type must
              // agree on the extent, otherwise inference is ambiguous.
              TORCH_INTERNAL_ASSERT(
                  launch_params_->getDim(ptype) == stop,
                  "Unable to infer launch parameters" );
            }
          }
        }
      }
    }
  }

  ExpressionEvaluator expr_eval_;
  c10::optional<LaunchParams> launch_params_;
};
205 | |
206 | class BankConflictInfo : public kir::IrVisitor { |
207 | public: |
208 | static std::unordered_map<const Expr*, std::pair<int, int>> get( |
209 | const std::vector<Expr*>& exprs, |
210 | c10::optional<LaunchParams> launch_params, |
211 | const std::unordered_map<std::string, IntOrDouble>& known_values) { |
212 | if (exprs.empty()) { |
213 | return {}; |
214 | } |
215 | return BankConflictInfo(exprs, launch_params, known_values) |
216 | .bank_conflict_info_; |
217 | } |
218 | |
219 | private: |
220 | BankConflictInfo( |
221 | const std::vector<Expr*>& exprs, |
222 | c10::optional<LaunchParams> launch_params, |
223 | const std::unordered_map<std::string, IntOrDouble>& known_values) |
224 | : launch_params_(launch_params), expr_eval_common_(exprs[0]->fusion()) { |
225 | expr_eval_common_.bind("blockIdx.x" , 0); |
226 | expr_eval_common_.bind("blockIdx.y" , 0); |
227 | expr_eval_common_.bind("blockIdx.z" , 0); |
228 | if (launch_params.has_value()) { |
229 | expr_eval_common_.bind("blockDim.x" , launch_params->bdimx()); |
230 | expr_eval_common_.bind("blockDim.y" , launch_params->bdimy()); |
231 | expr_eval_common_.bind("blockDim.z" , launch_params->bdimz()); |
232 | expr_eval_common_.bind("gridDim.x" , launch_params->gdimx()); |
233 | expr_eval_common_.bind("gridDim.y" , launch_params->gdimy()); |
234 | expr_eval_common_.bind("gridDim.z" , launch_params->gdimz()); |
235 | } |
236 | for (auto pair : known_values) { |
237 | expr_eval_common_.bind(pair.first, pair.second); |
238 | } |
239 | handle(exprs); |
240 | } |
241 | |
242 | using kir::IrVisitor::handle; |
243 | |
244 | void handle(Expr* expr) final { |
245 | if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) { |
246 | kir::IrVisitor::handle(expr); |
247 | return; |
248 | } |
249 | |
250 | if (expr->isA<UnaryOp>()) { |
251 | auto uop = expr->as<UnaryOp>(); |
252 | if (uop->getUnaryOpType() != UnaryOpType::Set) { |
253 | return; |
254 | } |
255 | std::pair<int, int> conflict_ways{0, 0}; |
256 | if (isSmemTensorIndex(uop->in())) { |
257 | conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( |
258 | uop->in()->as<kir::TensorIndex>(), |
259 | for_loops_, |
260 | launch_params_, |
261 | expr_eval_common_)); |
262 | } |
263 | if (isSmemTensorIndex(uop->out())) { |
264 | conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( |
265 | uop->out()->as<kir::TensorIndex>(), |
266 | for_loops_, |
267 | launch_params_, |
268 | expr_eval_common_)); |
269 | } |
270 | if (conflict_ways.first > 1 || conflict_ways.second > 1) { |
271 | bank_conflict_info_[expr] = conflict_ways; |
272 | } |
273 | } else if (expr->isA<LoadStoreOp>()) { |
274 | auto ldst = expr->as<LoadStoreOp>(); |
275 | std::pair<int, int> conflict_ways{0, 0}; |
276 | if (isSmemTensorIndex(ldst->in())) { |
277 | conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( |
278 | ldst->in()->as<kir::TensorIndex>(), |
279 | for_loops_, |
280 | launch_params_, |
281 | expr_eval_common_)); |
282 | } |
283 | if (isSmemTensorIndex(ldst->out())) { |
284 | conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( |
285 | ldst->out()->as<kir::TensorIndex>(), |
286 | for_loops_, |
287 | launch_params_, |
288 | expr_eval_common_)); |
289 | } |
290 | if (conflict_ways.first > 1 || conflict_ways.second > 1) { |
291 | bank_conflict_info_[expr] = conflict_ways; |
292 | } |
293 | } |
294 | } |
295 | |
296 | std::unordered_map<const Expr*, std::pair<int, int>> bank_conflict_info_; |
297 | c10::optional<LaunchParams> launch_params_; |
298 | ExpressionEvaluator expr_eval_common_; |
299 | }; |
300 | |
301 | } // namespace |
302 | |
303 | std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo( |
304 | kir::Kernel* kernel, |
305 | c10::optional<LaunchParams> launch_params, |
306 | const std::unordered_map<std::string, IntOrDouble>& known_values) { |
307 | for (auto pair : known_values) { |
308 | TORCH_CHECK( |
309 | !isThreadIdx(pair.first), |
310 | "threadIdx.{x,y,z} should be computed instead of provided" ); |
311 | TORCH_CHECK( |
312 | !isBlockIdx(pair.first), |
313 | "blockIdx.{x,y,z} should not be provided (they are always zero)" ); |
314 | TORCH_CHECK( |
315 | !isBlockDim(pair.first), |
316 | "blockDim.{x,y,z} should be provided by launch_params" ); |
317 | TORCH_CHECK( |
318 | !isGridDim(pair.first), |
319 | "gridDim.{x,y,z} should be provided by launch_params" ); |
320 | } |
321 | if (!launch_params.has_value()) { |
322 | launch_params = |
323 | InferLaunchParams::get(kernel->topLevelExprs(), known_values); |
324 | } |
325 | return BankConflictInfo::get( |
326 | kernel->topLevelExprs(), launch_params, known_values); |
327 | } |
328 | |
329 | } // namespace cuda |
330 | } // namespace fuser |
331 | } // namespace jit |
332 | } // namespace torch |
333 | |