1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/conv/pipeline.hpp" |
18 | |
19 | #include "gpu/jit/ir/message.hpp" |
20 | #include "gpu/jit/ir/reorder.hpp" |
21 | #include "gpu/jit/utils/trace.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace gpu { |
26 | namespace jit { |
27 | |
28 | // Helper structure for for_t. |
29 | struct loop_info_t { |
30 | loop_info_t() = default; |
31 | |
32 | loop_info_t(const stmt_t &s) { |
33 | ir_assert(s.is<for_t>()) << s; |
34 | auto &loop = s.as<for_t>(); |
35 | stmt = s; |
36 | var = loop.var; |
37 | init_ = loop.init; |
38 | bound_ = loop.bound; |
39 | |
40 | auto e_size = simplify(bound_ - init_); |
41 | ir_assert(is_const(e_size)); |
42 | size_ = to_cpp<int>(e_size); |
43 | } |
44 | |
45 | int init() const { |
46 | ir_assert(is_const(init_)); |
47 | return to_cpp<int>(init_); |
48 | } |
49 | |
50 | int bound() const { |
51 | ir_assert(is_const(bound_)); |
52 | return to_cpp<int>(bound_); |
53 | } |
54 | |
55 | int size() const { return size_; } |
56 | const stmt_t &body() const { return stmt.as<for_t>().body; } |
57 | int unroll() const { return stmt.as<for_t>().unroll; } |
58 | |
59 | stmt_t stmt; |
60 | expr_t var; |
61 | |
62 | private: |
63 | expr_t init_; |
64 | expr_t bound_; |
65 | int size_; |
66 | }; |
67 | |
68 | // Iterates through multiple nested loops with fixed bounds. Used to unroll |
69 | // such nested loops. |
class multi_loop_iterator_t {
public:
    // Ordered from innermost to outermost.
    multi_loop_iterator_t(const std::vector<loop_info_t> &loops)
        : loops_(loops) {
        // Start every loop variable at its lower bound.
        for (auto &l : loops)
            var_values_.push_back(l.init());
    }

    // Returns the current value of the given loop variable. The variable must
    // belong to one of the tracked loops.
    int var_value(const expr_t &var) const {
        for (size_t i = 0; i < loops_.size(); i++) {
            if (loops_[i].var.is_same(var)) return var_values_[i];
        }
        ir_error_not_expected();
        return 0;
    }

    // Advances the nested-loop iteration by n steps, carrying overflow from
    // inner loops into outer ones (odometer-style, innermost moves fastest).
    void advance(int n = 1) {
        if (loops_.empty()) return;
        for (int i_n = 0; i_n < n; i_n++) {
            for (size_t i = 0; i < loops_.size(); i++) {
                auto &l = loops_[i];
                // Stop carrying once a loop has room to increment.
                if (++var_values_[i] < l.bound()) break;
                var_values_[i] = l.init();
            }
            ir_assert(var_values_.back() < loops_.back().bound());
        }
    }

    // Returns true when all loops except the outermost are at their last
    // iteration (i.e. the next advance() steps the outermost loop).
    bool is_outer_loop_end() const {
        if (loops_.empty()) return true;
        for (size_t i = 0; i < loops_.size() - 1; i++) {
            auto &l = loops_[i];
            if (var_values_[i] != l.bound() - 1) return false;
        }
        return true;
    }

    std::string str() const {
        std::ostringstream oss;
        oss << "multi_loop_iterator_t(";
        for (size_t i = 0; i < loops_.size(); i++) {
            oss << (i != 0 ? ", " : "");
            oss << loops_[i].var << " = " << var_values_[i];
        }
        oss << ")";
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    std::vector<loop_info_t> loops_;
    // Current value of each loop variable, parallel to loops_.
    std::vector<int> var_values_;
};
125 | |
126 | // Extracts different parts of the compute iteration and verifies the loop nest |
127 | // is properly formed and can be further injected with SLM buffering. |
class compute_step_visitor_t : public ir_visitor_t {
public:
    // Returns the unique statement group with the given label, or an empty
    // statement if none was found. Asserts if the label is not unique.
    stmt_t find_stmt_group(const stmt_label_t &label) const {
        auto groups = find_stmt_groups(label);
        if (groups.empty()) return stmt_t();
        ir_assert(groups.size() == 1);
        return groups[0];
    }

    // Returns all collected statement groups carrying the given label.
    std::vector<stmt_t> find_stmt_groups(const stmt_label_t &label) const {
        std::vector<stmt_t> ret;
        for (auto &_g : stmt_groups_) {
            auto &g = _g.as<stmt_group_t>();
            if (g.label == label) ret.push_back(_g);
        }
        return ret;
    }

    // Let statements found in the innermost loop, with their bodies stripped.
    const std::vector<stmt_t> &inner_let_stmts() const {
        return inner_let_stmts_;
    }

#define HANDLE_IR_OBJECT(type) \
    void _visit(const type &obj) override { visit_stmt(obj); }

    HANDLE_STMT_IR_OBJECTS()

#undef HANDLE_IR_OBJECT

    // Common handler for all statement types: validates the loop-nest shape
    // and records statement groups and innermost-loop let statements.
    template <typename T>
    void visit_stmt(const T &obj) {
        bool is_for = obj.template is<for_t>();
        bool is_stmt_group = obj.template is<stmt_group_t>();
        bool is_let = obj.template is<let_t>();
        bool is_stmt_seq = obj.template is<stmt_seq_t>();

        // Loop may contain:
        // - Another loop
        // - Container statement (stmt_seq_t or stmt_group_t)
        // - Let statement (in the innermost loop only)
        // - Barrier
        if (loop_level_ > 0) {
            bool ok = false;
            if (is_for || is_let || is_stmt_group || is_stmt_seq) {
                ok = true;
            } else if (obj.template is<func_call_t>()) {
                auto &call = obj.template as<func_call_t>();
                ok = call.func.is_equal(funcs::barrier_func());
            }

            if (!ok) {
                ir_error_not_expected()
                        << "Found unexpected statement inside loop.\n"
                        << stmt_t(obj);
            }
        }

        bool is_compute_loop = false;
        if (is_stmt_group) {
            auto label = obj.template as<stmt_group_t>().label;
            stmt_groups_.push_back(obj);
            if (utils::one_of(label, stmt_label_t::g2s_load(),
                        stmt_label_t::g2s_store(), stmt_label_t::g2r_load(),
                        stmt_label_t::s2r_load(), stmt_label_t::prefetch(),
                        stmt_label_t::mul())) {
                // Leaf labels, do not visit them.
                return;
            }
            if (label == stmt_label_t::compute_loop()) {
                is_compute_loop = true;
                in_compute_loop_ = true;
            }
        }

        if (is_for && in_compute_loop_) loop_level_++;
        // Reset before recursing: found_loop_ tells whether the subtree just
        // visited contained a loop, which is how a non-innermost let is
        // detected below.
        found_loop_ = false;
        ir_visitor_t::_visit(obj);
        if (in_compute_loop_ && is_let) {
            if (found_loop_)
                ir_error_not_expected()
                        << "Let is allowed in the innermost loop only.";

            // Store the let with an empty body; only var/value are needed.
            inner_let_stmts_.push_back(replace_stmt_body(obj, stmt_t()));
        }
        if (is_for && in_compute_loop_) {
            loop_level_--;
            // Signal to the enclosing statement that its subtree has a loop.
            found_loop_ = true;
        }

        if (is_compute_loop) in_compute_loop_ = false;
    }

private:
    // True when the most recently finished subtree contained a for_t.
    bool found_loop_ = false;
    // True while visiting inside the compute_loop statement group.
    bool in_compute_loop_ = false;
    // Current for_t nesting depth inside the compute loop.
    int loop_level_ = 0;

    std::vector<stmt_t> stmt_groups_;
    std::vector<stmt_t> inner_let_stmts_;
};
228 | |
229 | // Provides access to different parts of the inner compute iteration. |
class compute_step_t {
public:
    // Extracts the labeled parts of the compute iteration from `parent` and
    // classifies every innermost let statement as used by preload (g2s load /
    // prefetch), by multiplication (g2r load / mul), or both. Lets used by
    // both are duplicated so each context gets its own variable.
    compute_step_t(const stmt_t &parent) {
        compute_step_visitor_t v;
        v.visit(parent);

        compute_loop_ = v.find_stmt_group(stmt_label_t::compute_loop());
        g2s_load_ = v.find_stmt_group(stmt_label_t::g2s_load());
        g2s_store_ = v.find_stmt_group(stmt_label_t::g2s_store());
        prefetch_ = v.find_stmt_group(stmt_label_t::prefetch());
        g2r_load_ = v.find_stmt_groups(stmt_label_t::g2r_load());
        s2r_load_ = v.find_stmt_groups(stmt_label_t::s2r_load());
        mul_ = v.find_stmt_groups(stmt_label_t::mul());
        c_zero_out_ = v.find_stmt_group(stmt_label_t::c_zero_out());
        inner_let_stmts_ = v.inner_let_stmts();

        // Each multiplication must have a matching load from GMEM and SLM.
        ir_assert(g2r_load_.size() == mul_.size());
        ir_assert(s2r_load_.size() == mul_.size());

        // Assign preload/mul tags to let statements.
        for (auto &_let : inner_let_stmts_) {
            auto &var = _let.as<let_t>().var;
            bool is_preload = (count_object(g2s_load_, var) > 0)
                    || (count_object(prefetch_, var) > 0);
            bool is_mul = count_object(g2r_load_, var) > 0
                    || count_object(mul_, var) > 0;
            if (is_preload) preload_lets_.insert(_let);
            if (is_mul) mul_lets_.insert(_let);
        }

        // Propagate preload/mul tags up based on dependencies between let
        // statements.
        std::vector<let_info_t> let_infos;
        object_set_t<stmt_t> seen;
        std::function<void(const stmt_t &)> propagate;
        // Depth-first walk: a let inherits the tags of every let whose value
        // references its variable. `seen` guards against revisiting and
        // ensures one let_info_t per let.
        propagate = [&](const stmt_t &_let) {
            if (seen.count(_let) > 0) return;
            auto &let = _let.as<let_t>();
            for (auto &_child : inner_let_stmts_) {
                auto &child = _child.as<let_t>();
                if (_child.is_same(_let)) continue;
                if (contains_object(child.value, let.var)) {
                    // Visit child let statements first.
                    propagate(_child);
                    // Propagate child preload/mul values to this let statement.
                    if (is_preload_let(_child)) preload_lets_.insert(_let);
                    if (is_mul_let(_child)) mul_lets_.insert(_let);
                }
            }
            auto let_info = create_let_info(
                    let, is_preload_let(_let), is_mul_let(_let));
            let_infos.push_back(let_info);
            seen.insert(_let);
        };
        for (auto &_let : inner_let_stmts_)
            propagate(_let);

        // Duplicate lets that are used in both preload and mul contexts.
        duplicate_lets(let_infos);
    }

    // See ir_core.hpp for the description.
    const stmt_t &compute_loop() const { return compute_loop_; }
    const stmt_t &g2s_load() const { return g2s_load_; }
    const stmt_t &g2s_store() const { return g2s_store_; }
    const stmt_t &prefetch() const { return prefetch_; }
    const std::vector<stmt_t> &g2r_load() const { return g2r_load_; }
    const std::vector<stmt_t> &s2r_load() const { return s2r_load_; }
    const std::vector<stmt_t> &mul() const { return mul_; }
    const stmt_t &c_zero_out() const { return c_zero_out_; }
    const std::vector<stmt_t> &inner_let_stmts() const {
        return inner_let_stmts_;
    }

    // True if the let statement is used (directly or transitively) by the
    // preload part of the iteration.
    bool is_preload_let(const stmt_t &s) const {
        return preload_lets_.count(s) > 0;
    }
    // True if the let statement is used (directly or transitively) by the
    // multiplication part of the iteration.
    bool is_mul_let(const stmt_t &s) const { return mul_lets_.count(s) > 0; }

private:
    // Per-let bookkeeping for the duplication step: which variable serves the
    // preload context and which serves the mul context.
    struct let_info_t {
        let_info_t(const expr_t &var) : var(var) {}

        expr_t var;
        expr_t preload_var;
        expr_t mul_var;

        bool is_preload() const { return !preload_var.is_empty(); }
        bool is_mul() const { return !mul_var.is_empty(); }

        // A let used in both contexts must be split into two variables.
        bool needs_update() const { return is_preload() && is_mul(); }
    };

    // Creates the let_info_t; when the let is used by both contexts, two new
    // suffixed variables ("_p"/"_m") are created to replace the original.
    let_info_t create_let_info(const let_t &let, bool is_preload, bool is_mul) {
        let_info_t info(let.var);
        if (is_preload && !is_mul) {
            info.preload_var = let.var;
        } else if (!is_preload && is_mul) {
            info.mul_var = let.var;
        } else if (is_preload && is_mul) {
            info.preload_var = create_var_with_suffix(let.var, "p");
            info.mul_var = create_var_with_suffix(let.var, "m");
        }
        return info;
    }

    // Rewrites inner_let_stmts_ so that every let needing both contexts is
    // replaced by a preload copy and a mul copy, and updates all extracted
    // statements to reference the appropriate copy. Iterates from the last
    // let backwards so uses are rewritten before definitions.
    void duplicate_lets(const std::vector<let_info_t> &let_infos) {
        int nlets = int(inner_let_stmts_.size());
        ir_assert(int(let_infos.size()) == nlets);

        std::vector<stmt_t> new_lets;
        for (int i = nlets - 1; i >= 0; i--) {
            auto &info = let_infos[i];
            auto &old_let = inner_let_stmts_[i].as<let_t>();
            if (!info.needs_update()) {
                // Single-context let: only its value may need substitutions
                // of other (duplicated) variables.
                auto new_value = update_var(old_let.value, let_infos,
                        info.is_preload(), info.is_mul());
                auto new_let = inner_let_stmts_[i];
                if (!new_value.is_same(old_let.value)) {
                    new_let = let_t::make(old_let.var, new_value, old_let.body);
                }
                new_lets.push_back(new_let);
                continue;
            }

            preload_lets_.erase(&old_let);
            mul_lets_.erase(&old_let);

            auto preload_value
                    = update_var(old_let.value, let_infos, true, false);
            auto preload_let = let_t::make(
                    info.preload_var, preload_value, old_let.body);

            auto mul_value = update_var(old_let.value, let_infos, false, true);
            auto mul_let = let_t::make(info.mul_var, mul_value, old_let.body);

            preload_lets_.insert(preload_let);
            new_lets.push_back(preload_let);

            mul_lets_.insert(mul_let);
            new_lets.push_back(mul_let);

            // Update statements. update_var() is a no-op for variables that
            // were already substituted on a previous iteration, so repeating
            // it inside the loop is safe.
            g2s_load_ = update_var(g2s_load_, let_infos, true, false);
            g2s_store_ = update_var(g2s_store_, let_infos, true, false);
            prefetch_ = update_var(prefetch_, let_infos, true, false);
            g2r_load_ = update_var(g2r_load_, let_infos, false, true);
            s2r_load_ = update_var(s2r_load_, let_infos, false, true);
            mul_ = update_var(mul_, let_infos, false, true);
        }

        // Restore original (definition-before-use) order.
        std::reverse(new_lets.begin(), new_lets.end());
        inner_let_stmts_ = new_lets;
    }

    // Element-wise overload of update_var() for statement vectors.
    template <typename T>
    static std::vector<T> update_var(const std::vector<T> &vec,
            const std::vector<let_info_t> &let_infos, bool is_preload,
            bool is_mul) {
        std::vector<T> ret;
        for (auto &v : vec)
            ret.push_back(update_var(v, let_infos, is_preload, is_mul));
        return ret;
    }

    // Substitutes every duplicated variable in `obj` with its preload or mul
    // replacement, depending on the requested context.
    static object_t update_var(const object_t &obj,
            const std::vector<let_info_t> &let_infos, bool is_preload,
            bool is_mul) {
        auto ret = obj;
        for (auto &info : let_infos) {
            if (!info.needs_update()) continue;
            if (!contains_object(ret, info.var)) continue;
            if (is_preload) {
                ir_assert(info.is_preload());
                ret = substitute(ret, info.var, info.preload_var);
            } else if (is_mul) {
                ir_assert(info.is_mul());
                ret = substitute(ret, info.var, info.mul_var);
            }
        }
        return ret;
    }

    // Returns a fresh variable named "<var>_<suffix>" with the same type.
    static expr_t create_var_with_suffix(
            const expr_t &_var, const std::string &suffix) {
        auto &var = _var.as<var_t>();
        auto new_name = var.name + "_" + suffix;
        return var_t::make(var.type, new_name);
    }

    stmt_t compute_loop_;
    stmt_t g2s_load_;
    stmt_t g2s_store_;
    stmt_t prefetch_;
    std::vector<stmt_t> g2r_load_;
    std::vector<stmt_t> s2r_load_;
    std::vector<stmt_t> mul_;
    stmt_t c_zero_out_;

    std::vector<stmt_t> inner_let_stmts_;

    // Due to loop unrolling the inner let statements may depend on different
    // indices of the outer loops. There are two contexts:
    // - "preload" loop iteration, e.g. index I
    // - "multiplication" loop iteration, e.g. index (I + nbuf)
    // Preloads (either via SLM or via prefetches) for the corresponding
    // multiplication are executed several iterations before the real
    // multiplication. That's why we need to know exactly in which context the
    // given let statement is used. It might be that the same variable is used
    // from two different contexts. In this case it is duplicated and
    // initialized with different values for each case.
    object_set_t<stmt_t> preload_lets_;
    object_set_t<stmt_t> mul_lets_;
};
444 | |
445 | // Helper class to access the outer loop index after pipelining. Pipelining |
446 | // in general requires tracking two versions of a loop index: |
447 | // - Multiplication version - corresponding to the iteration that is currently |
448 | // used for multiplication |
449 | // - Preload version - corresponding to the iteration that is currently used |
450 | // for preload for one of the next multiplications |
451 | // The multiplication version is a few steps behind the preload version. |
class outer_loop_info_t : public loop_info_t {
public:
    outer_loop_info_t() = default;

    outer_loop_info_t(const stmt_t &s, ir_context_t &ir_ctx) : loop_info_t(s) {
        // Outer loop may not be used for unrolling hence loop iterations must
        // not use its index. If this doesn't hold, introduce a GRF buffer to
        // represent that variable and apply post-increment updates after each
        // outer loop iteration.
        if (count_object(s.as<for_t>().body, var) != 0) {
            has_var_refs_ = true;
            // Two buffered copies of the loop index: one tracking the
            // multiplication iteration, one tracking the preload iteration.
            mul_var_buf_ = ir_ctx.create_tmp_var(
                    type_t::byte_ptr(), var.as<var_t>().name + "_mul_buf");
            preload_var_buf_ = ir_ctx.create_tmp_var(
                    type_t::byte_ptr(), var.as<var_t>().name + "_preload_buf");

            auto mul_alloc = alloc_t::make(
                    mul_var_buf_, var.type().size(), alloc_kind_t::grf);
            auto preload_alloc = alloc_t::make(
                    preload_var_buf_, var.type().size(), alloc_kind_t::grf);
            allocs_.push_back(mul_alloc);
            allocs_.push_back(preload_alloc);

            // Both copies start at the loop's lower bound.
            auto mul_init = store_t::make(mul_var_buf_, 0, init());
            auto preload_init = store_t::make(preload_var_buf_, 0, init());
            init_stmt_ = mul_init.append(preload_init);

            // Post-increment statements emitted after each iteration of the
            // corresponding pipeline stage.
            mul_post_inc_stmt_
                    = store_t::make(mul_var_buf_, 0, mul_var_load() + 1);
            preload_post_inc_stmt_ = store_t::make(
                    preload_var_buf_, 0, preload_var_load() + 1);
        }
    }

    // True when the loop body references the loop index and the buffered
    // index machinery (allocs/init/post-inc) was created.
    bool has_var_refs() const { return has_var_refs_; }

    // Load of the current multiplication-iteration index.
    expr_t mul_var_load() const {
        return load_t::make(var.type(), mul_var_buf_, 0);
    }
    // Load of the current preload-iteration index.
    expr_t preload_var_load() const {
        return load_t::make(var.type(), preload_var_buf_, 0);
    }

    // Wraps `stmt` with the GRF allocations for the two index buffers.
    stmt_t inject_alloc_stmts(const stmt_t &stmt) const {
        return jit::inject_alloc_stmts(stmt, allocs_);
    }

    const stmt_t &init_stmt() const { return init_stmt_; }

    const stmt_t &mul_post_inc_stmt() const { return mul_post_inc_stmt_; }
    const stmt_t &preload_post_inc_stmt() const {
        return preload_post_inc_stmt_;
    }

private:
    bool has_var_refs_ = false;

    // Helper expressions/statements to partially unroll the loop.
    expr_t mul_var_buf_;
    expr_t preload_var_buf_;
    std::vector<stmt_t> allocs_;
    stmt_t init_stmt_;
    stmt_t mul_post_inc_stmt_;
    stmt_t preload_post_inc_stmt_;
};
517 | |
518 | class compute_loop_nest_visitor_t : public ir_visitor_t { |
519 | public: |
520 | int compute_loop_level() const { return compute_loop_level_; } |
521 | |
522 | const std::vector<loop_info_t> &loops() const { return loops_; } |
523 | |
524 | void _visit(const stmt_group_t &obj) override { |
525 | bool is_compute_loop = (obj.label == stmt_label_t::compute_loop()); |
526 | if (is_compute_loop) { |
527 | in_compute_loop_ = true; |
528 | compute_loop_level_ = level_; |
529 | } |
530 | ir_visitor_t::_visit(obj); |
531 | if (is_compute_loop) in_compute_loop_ = false; |
532 | } |
533 | |
534 | void _visit(const for_t &obj) override { |
535 | level_++; |
536 | ir_visitor_t::_visit(obj); |
537 | if (in_compute_loop_) loops_.emplace_back(obj); |
538 | level_--; |
539 | } |
540 | |
541 | private: |
542 | bool in_compute_loop_ = false; |
543 | int compute_loop_level_ = -1; |
544 | std::vector<loop_info_t> loops_; |
545 | int level_ = 0; |
546 | }; |
547 | |
548 | // Helper class to work with loop nest of the compute loop. |
549 | class compute_loop_nest_t { |
550 | public: |
551 | compute_loop_nest_t() = default; |
552 | |
553 | compute_loop_nest_t(const stmt_t &root, ir_context_t &ir_ctx) |
554 | : root_(root) { |
555 | compute_loop_nest_visitor_t visitor; |
556 | visitor.visit(root); |
557 | |
558 | compute_loop_level_ = visitor.compute_loop_level(); |
559 | loops_ = visitor.loops(); |
560 | |
561 | if (loops_.empty()) { |
562 | outer_loop_size_ = 1; |
563 | return; |
564 | } |
565 | |
566 | outer_loop_ = outer_loop_info_t(loops_.back().stmt, ir_ctx); |
567 | outer_loop_size_ = outer_loop_.size(); |
568 | } |
569 | |
570 | // Returns the loop level of the compute_loop statement group corresponding |
571 | // to the number of outer loops. |
572 | int compute_loop_level() const { return compute_loop_level_; } |
573 | |
574 | // Returns loops inside compute_loop statement group. |
575 | const std::vector<loop_info_t> &loops() const { return loops_; } |
576 | |
577 | // Number of iterations of all loops. |
578 | int size() const { |
579 | int ret = 1; |
580 | for (auto &l : loops_) |
581 | ret *= l.size(); |
582 | return ret; |
583 | } |
584 | |
585 | // Number of iterations in the outermost loop (see comments in ctor). |
586 | int outer_loop_size() const { return outer_loop_size_; } |
587 | |
588 | const outer_loop_info_t &outer_loop_info() const { return outer_loop_; } |
589 | |
590 | template <typename F> |
591 | void for_each_loop_var(const F &f) const { |
592 | for (auto &l : loops_) |
593 | f(l.var); |
594 | } |
595 | |
596 | // Number of iterations of all loops except the outermost. |
597 | int inner_loops_size() const { return size() / outer_loop_size(); } |
598 | |
599 | private: |
600 | stmt_t root_; |
601 | int compute_loop_level_ = -1; |
602 | std::vector<loop_info_t> loops_; |
603 | |
604 | int outer_loop_size_; |
605 | outer_loop_info_t outer_loop_; |
606 | }; |
607 | |
608 | struct compute_params_t { |
609 | compute_params_t() = default; |
610 | |
611 | compute_params_t(int slm_bufs, int gmem_bufs, int slm_buf_size, |
612 | int prefetch_bufs, int inner_loops_iters) |
613 | : slm_bufs(slm_bufs) |
614 | , gmem_bufs(gmem_bufs) |
615 | , slm_buf_size(slm_buf_size) |
616 | , prefetch_bufs(prefetch_bufs) { |
617 | use_slm = (slm_buf_size > 0); |
618 | use_prefetch = (prefetch_bufs > 0); |
619 | ir_assert(!use_slm || !use_prefetch) |
620 | << "Can't have both SLM buffering and prefetch enabled." ; |
621 | if (use_slm) { |
622 | ir_assert(utils::one_of(slm_bufs, 1, 2, 3)); |
623 | ir_assert(utils::one_of(gmem_bufs, 1, 2)); |
624 | preload_bufs = slm_bufs; |
625 | unroll = math::lcm(slm_bufs * gmem_bufs, inner_loops_iters); |
626 | } else if (use_prefetch) { |
627 | preload_bufs = prefetch_bufs; |
628 | ir_assert(slm_bufs == 0); |
629 | ir_assert(gmem_bufs == 0); |
630 | unroll = math::lcm(prefetch_bufs, inner_loops_iters); |
631 | } else { |
632 | preload_bufs = 0; |
633 | ir_assert(slm_bufs == 0); |
634 | ir_assert(gmem_bufs == 0); |
635 | unroll = inner_loops_iters; |
636 | } |
637 | } |
638 | |
639 | int slm_bufs; |
640 | int gmem_bufs; |
641 | int slm_buf_size; |
642 | int prefetch_bufs; |
643 | int preload_bufs; |
644 | int unroll; |
645 | |
646 | bool use_slm; |
647 | bool use_prefetch; |
648 | }; |
649 | |
650 | // Helper class to implement SLM buffering. |
class compute_iterator_t {
public:
    // Builds the pipelined iteration schedule: the total iteration count is
    // extended by the pipeline depth so the last preloads have matching
    // multiplications, then split into ramp-up (preloads only starting),
    // steady-state body (multiple of the unroll factor) and ramp-down
    // (draining the remaining multiplications).
    compute_iterator_t(const compute_params_t &params,
            const compute_loop_nest_t &loop_nest)
        : params(params)
        , preload_loop_it(loop_nest.loops())
        , mul_loop_it(loop_nest.loops()) {

        int compute_iters = loop_nest.size();
        iters = compute_iters;
        ir_assert(iters >= 1) << "Empty loop is not expected.";

        // Pipeline depth: (preload_bufs - 1) SLM/prefetch stages plus
        // (gmem_bufs - 1) GMEM staging steps.
        iters += std::max(0, preload_bufs() - 1) + std::max(0, gmem_bufs() - 1);
        ramp_up_iters
                = std::max(1, preload_bufs() + std::max(0, gmem_bufs() - 1));
        ramp_down_iters = std::min(
                std::max(0, preload_bufs() - 1) + std::max(0, gmem_bufs() - 1),
                iters - ramp_up_iters);
        body_iters = iters - ramp_up_iters - ramp_down_iters;
        // The body must be a whole number of unrolled chunks; the remainder
        // is pushed into the ramp-down.
        body_iters = utils::rnd_dn(body_iters, params.unroll);
        ramp_down_iters = iters - ramp_up_iters - body_iters;

        ir_assert(ramp_up_iters + body_iters + ramp_down_iters == iters);

        iter = 0;
        linear_id = 0;
        riter = iters - 1;
    }

    int unroll() const { return params.unroll; }

    int preload_bufs() const { return params.preload_bufs; }

    int slm_bufs() const { return params.slm_bufs; }

    int gmem_bufs() const { return params.gmem_bufs; }

    // Steps one pipeline iteration; the preload and mul loop iterators only
    // move while their respective stage is active.
    compute_iterator_t &operator++() {
        if (do_preload()) preload_loop_it.advance();
        if (do_mul()) mul_loop_it.advance();
        ++iter;
        ++linear_id;
        --riter;
        return *this;
    }

    // Skips n iterations at once; only valid in the steady state where both
    // preload and mul are active and n is a multiple of the unroll factor.
    void advance(int n) {
        if (n == 0) return;

        ir_assert(n % params.unroll == 0);
        ir_assert(iter + n <= iters);

        if (preload_bufs() > 0) ir_assert(do_preload());
        ir_assert(do_mul());

        iter += n;
        riter -= n;

        if (preload_bufs() > 0) preload_loop_it.advance(n);
        mul_loop_it.advance(n);
    }

    // Multiplication runs once the pipeline has ramped up (enough data has
    // been preloaded/staged).
    bool do_mul() const {
        return iter >= std::max(0, preload_bufs() - 1)
                + std::max(0, gmem_bufs() - 1);
    }

    bool is_first_mul() const {
        return iter
                == std::max(0, preload_bufs() - 1)
                + std::max(0, gmem_bufs() - 1);
    }
    bool is_last_mul() const { return riter == 0; }

    bool is_last_g2s_store() const {
        if (!do_g2s_store()) return false;
        return riter == slm_bufs() - 1;
    }

    bool is_last_preload() const {
        if (!do_preload()) return false;
        return riter == (preload_bufs() - 1) + std::max(0, gmem_bufs() - 1);
    }

    bool is_last_g2s_load() const {
        if (!do_g2s_load()) return false;
        return is_last_preload();
    }

    bool is_last_prefetch() const {
        if (!do_prefetch()) return false;
        return is_last_preload();
    }

    // Preloading stops near the end: the last iterations only consume data
    // that is already in flight.
    bool do_preload() const {
        if (preload_bufs() == 0) return false;
        return riter >= (preload_bufs() - 1) + std::max(0, gmem_bufs() - 1);
    }

    bool do_prefetch() const {
        if (!params.use_prefetch) return false;
        return do_preload();
    }

    bool do_g2s_load() const {
        if (!params.use_slm) return false;
        return do_preload();
    }

    // GMEM->SLM store lags the load by (gmem_bufs - 1) iterations and stops
    // (slm_bufs - 1) iterations before the end.
    bool do_g2s_store() const {
        if (!params.use_slm) return false;
        ir_assert(gmem_bufs() >= 1);
        return iter >= (gmem_bufs() - 1) && riter >= (slm_bufs() - 1);
    }

    // Round-robin index of the GMEM buffer currently being written by the
    // g2s load.
    int gmem_write_buf_index() const {
        ir_assert(do_g2s_load());
        return iter % gmem_bufs();
    }

    // Round-robin index of the GMEM buffer currently being read by the
    // g2s store (lags the write index by gmem_bufs - 1).
    int gmem_read_buf_index() const {
        ir_assert(do_g2s_store());
        return (iter - (gmem_bufs() - 1)) % gmem_bufs();
    }

    // Byte delta to apply to the SLM read offset when moving to the next
    // iteration (rotates through slm_bufs buffers, may be negative on wrap).
    int slm_read_offset_update() const {
        ir_assert(params.use_slm);
        ir_assert(do_mul());

        int slm_iter = iter - (gmem_bufs() - 1) - (slm_bufs() - 1);
        int cur_slm_idx = slm_iter % slm_bufs();
        int next_slm_idx = (slm_iter + 1) % slm_bufs();
        int ret = next_slm_idx * params.slm_buf_size
                - cur_slm_idx * params.slm_buf_size;
        return ret;
    }

    // Byte delta to apply to the SLM write offset when moving to the next
    // iteration.
    int slm_write_offset_update() const {
        ir_assert(params.use_slm);
        ir_assert(do_g2s_store());

        int slm_iter = iter - (gmem_bufs() - 1);
        int cur_slm_idx = slm_iter % slm_bufs();
        int next_slm_idx = (slm_iter + 1) % slm_bufs();
        int ret = next_slm_idx * params.slm_buf_size
                - cur_slm_idx * params.slm_buf_size;
        return ret;
    }

    compute_params_t params;
    multi_loop_iterator_t preload_loop_it;
    multi_loop_iterator_t mul_loop_it;

    // ramp_up_iters + body_iters + ramp_down_iters == iters
    int iters;
    int ramp_up_iters;
    int body_iters;
    int ramp_down_iters;

    // Invariant: iter + riter = iters - 1
    int iter;
    int riter;

    int linear_id;
};
816 | |
817 | // Basic LRU SBID allocator, tries to use the same SBIDs for the same GRF |
818 | // buffers. |
819 | class sbid_manager_t { |
820 | public: |
821 | sbid_manager_t(ngen::HW hw = ngen::HW::Unknown) |
822 | : sbid_count_(hw >= ngen::HW::XeHPC ? 32 : 16) |
823 | , tuple_func_(builtin_t::make("tuple" )) {} |
824 | |
825 | ngen_proxy::SBID get_sbid(const expr_t &buf, int index = 0) { |
826 | auto key = tuple_func_.call({buf, expr_t(index)}); |
827 | |
828 | int free_idx = -1; |
829 | for (int i = 0; i < sbid_count_; i++) { |
830 | auto &e = entries_[i]; |
831 | if (key.is_equal(e.key)) { |
832 | e.time = cur_time_++; |
833 | return ngen_proxy::SBID(i); |
834 | } |
835 | if (free_idx == -1 && e.key.is_empty()) free_idx = i; |
836 | } |
837 | |
838 | // Not found but there is a free SBID. |
839 | if (free_idx != -1) { |
840 | entries_[free_idx] = {key, cur_time_++}; |
841 | return ngen_proxy::SBID(free_idx); |
842 | } |
843 | |
844 | // Find the oldest SBID and use it. |
845 | int old_idx = 0; |
846 | int old_time = entries_[0].time; |
847 | for (int i = 1; i < sbid_count_; i++) { |
848 | if (entries_[i].time < old_time) { |
849 | old_idx = i; |
850 | old_time = entries_[i].time; |
851 | } |
852 | } |
853 | |
854 | entries_[old_idx] = entry_t({key, cur_time_++}); |
855 | return ngen_proxy::SBID(old_idx); |
856 | } |
857 | |
858 | private: |
859 | struct entry_t { |
860 | stmt_t key; |
861 | int time; |
862 | }; |
863 | |
864 | static const int max_sbid_count = 32; |
865 | std::array<entry_t, max_sbid_count> entries_; |
866 | |
867 | int sbid_count_ = 0; |
868 | func_t tuple_func_; |
869 | int cur_time_ = 0; |
870 | }; |
871 | |
872 | // Helper to assign SBIDs to IR function calls. |
class sbid_assigner_t {
public:
    sbid_assigner_t(ngen::HW hw) : local_sbid_mgr_(hw) {}

    // Uses an externally owned SBID manager so assignments can be shared
    // across several statements.
    sbid_assigner_t(sbid_manager_t &external_sbid_mgr)
        : external_sbid_mgr_(&external_sbid_mgr) {}

    // Walks all function calls in `stmt` and attaches an SBID instruction
    // modifier to sends, non-atomic dpas tails, and synchronization calls.
    stmt_t assign(const stmt_t &stmt) {
        auto stmt_vec = flatten_statements(stmt);
        stmt_t ret = stmt;
        int prefetch_idx = 0;
        for (auto &_s : stmt_vec) {
            if (!_s.is<func_call_t>()) continue;
            auto s = _s;
            if (is_func_call<send_t>(s)) {
                auto &send = s.as<func_call_t>().func.as<send_t>();
                // Give each prefetch its own index so prefetches to the same
                // buffer get distinct SBIDs; regular sends share index 0.
                int idx = (send.is_prefetch() || send.is_prefetch_2d()
                                ? prefetch_idx++
                                : 0);
                auto sbid = get_sbid(send_t::arg_reg_buf(s), idx);
                s = update_call_with_sbid(s, sbid);
            } else if (is_func_call<dpas_t>(s)) {
                auto &c = s.as<func_call_t>();
                auto *mod_attr = c.attr.as_ptr<instruction_modifier_attr_t>();
                if (!c.func.as<dpas_t>().is_dp4a() && // dp4a-s do not need SBID
                        (!mod_attr || !mod_attr->mod.is_atomic)) {
                    // Last dpas in Atomic chain.
                    auto sbid = get_sbid(dpas_t::arg_src1(s));
                    s = update_call_with_sbid(s, sbid);
                }
            } else if (s.is<func_call_t>()) {
                auto &c = s.as<func_call_t>();
                if (c.func.is_equal(funcs::signal_func())
                        || c.func.is_equal(funcs::slm_fence_func())
                        || c.func.is_equal(funcs::barrier_func())) {
                    // Use 0 as the key for signals and SLM fences.
                    auto sbid = get_sbid(expr_t(0));
                    s = update_call_with_sbid(s, sbid);
                }
            } else {
                // NOTE(review): unreachable - the `continue` above already
                // filtered out non-func_call_t statements, so the previous
                // `else if` always matches.
                ir_error_not_expected() << s;
            }
            ret = substitute(ret, _s, s);
        }
        return ret;
    }

private:
    // Dispatches to the external manager when one was supplied, otherwise
    // uses the local one.
    ngen_proxy::SBID get_sbid(const expr_t &ptr, int index = 0) {
        auto &sbid_mgr
                = (external_sbid_mgr_ ? *external_sbid_mgr_ : local_sbid_mgr_);
        return sbid_mgr.get_sbid(ptr, index);
    }

    // Returns `s` with an instruction modifier attribute carrying the SBID.
    static stmt_t update_call_with_sbid(
            const stmt_t &s, const ngen_proxy::SBID &sbid) {
        return instruction_modifier_attr_t::make(
                ngen_proxy::InstructionModifier().with_sbid(sbid))
                .apply_to(s);
    }

    sbid_manager_t local_sbid_mgr_;
    sbid_manager_t *external_sbid_mgr_ = nullptr;
};
937 | |
938 | // Work around due to limited scoping functionality in current generator |
939 | // Prepends all newly created var_t names with given prefix. |
940 | class var_prepender_t : public ir_mutator_t { |
941 | public: |
942 | var_prepender_t(const std::string &prefix) : prefix_(prefix) {} |
943 | object_t _mutate(const for_t &obj) override { |
944 | auto new_obj = ir_mutator_t::_mutate(obj); |
945 | auto new_var = var_t::make( |
946 | obj.var.type(), prefix_ + obj.var.as<var_t>().name); |
947 | new_obj = substitute(new_obj, obj.var, new_var); |
948 | return new_obj; |
949 | } |
950 | object_t _mutate(const let_t &obj) override { |
951 | auto new_obj = ir_mutator_t::_mutate(obj); |
952 | auto new_var = var_t::make( |
953 | obj.var.type(), prefix_ + obj.var.as<var_t>().name); |
954 | new_obj = substitute(new_obj, obj.var, new_var); |
955 | return new_obj; |
956 | } |
957 | |
958 | private: |
959 | std::string prefix_; |
960 | }; |
961 | |
962 | object_t prepend_new_vars(const object_t &root, const std::string &prefix) { |
963 | var_prepender_t mutator(prefix); |
964 | return mutator.mutate(root); |
965 | } |
966 | |
967 | // Perform pipelining operation. The goal is to transform |
968 | // the loop structure from: |
969 | // |
970 | // for i in range(init, bound): |
971 | // A_block(i); |
972 | // B_block(i); |
973 | // |
974 | // to the following |
975 | // |
976 | // for i in range(init, init + length): |
977 | // A_block(i); |
978 | // for i in range(init, bound): |
979 | // if (i < bound - length): |
980 | // A_block(i + length); |
981 | // B_block(i); |
982 | // |
983 | // Since A_block and B_block have to be independent to maintain correctness, |
984 | // this transform ignores the operations within the for_loop and relies on a |
985 | // correct substitution for A_block and B_block. |
986 | |
987 | struct pipeline_ctx_t { |
988 | pipeline_ctx_t(const stmt_t &prologue, const stmt_t &body) |
989 | : prologue_(prologue), body_(body) {} |
990 | stmt_t stmt() const { return prologue_.append(body_); } |
991 | stmt_t prologue() { return prologue_; } |
992 | stmt_t body() { return body_; } |
993 | |
994 | private: |
995 | stmt_t prologue_; |
996 | stmt_t body_; |
997 | }; |
998 | |
// Splits `loop` into a prologue running the first iterations of A_block and
// a steady-state loop where A_block runs `length` iterations ahead of
// B_block (see the scheme in the comment above).
pipeline_ctx_t pipeline(
        int length, const loop_info_t &loop, stmt_t A_block, stmt_t B_block) {

    expr_t idx = loop.var;
    int bound = loop.bound();
    int init = loop.init();

    // Clamp the pipeline distance so the prologue never runs past the loop
    // bound (short loops).
    int pipe_len = std::min(init + length, bound);

    // Prologue: A_block for iterations [init, pipe_len). Its variables are
    // prefixed with "prefetch_" to avoid clashes with the main loop (see
    // var_prepender_t).
    stmt_t prologue = prepend_new_vars(
            for_t::make(idx, init, pipe_len, A_block,
                    pipe_len <= loop.unroll() ? pipe_len : 1),
            "prefetch_" );

    // Main loop: A_block runs pipe_len iterations ahead of B_block, guarded
    // so it never executes beyond the loop bound.
    expr_t A_idx = idx + pipe_len;
    stmt_t body = if_t::make(
            idx < (bound - pipe_len), substitute(A_block, idx, A_idx));
    body = body.append(B_block);
    body = for_t::make(idx, init, bound, body, loop.unroll());

    return pipeline_ctx_t(prologue, body);
}
1021 | |
// Rewrites the compute loop so that prefetch statements (A block) are issued
// a configured number of iterations ahead of the remaining loop body
// (B block), across up to two nested loop levels.
class prefetch_pipeliner_t {
public:
    prefetch_pipeliner_t(
            const stmt_t &root, const conv_config_t &cfg, ir_context_t &ir_ctx)
        : root_(root), cfg_(cfg), ir_ctx_(ir_ctx) {}
    // Returns the transformed root, or the original root when there is
    // nothing to pipeline (no compute loop, no loops or no prefetch group).
    stmt_t inject() {
        auto compute_loop_stmt
                = find_stmt_group(root_, stmt_label_t::compute_loop());
        if (!compute_loop_stmt.has_value()) return root_;
        auto compute_loop = compute_loop_stmt.value();
        auto loop_nest = compute_loop_nest_t(compute_loop, ir_ctx_);
        auto &loops = loop_nest.loops();

        // No loops to pipeline
        if (loops.size() == 0) return root_;
        auto &loop_body = loops[0].body();

        // A block: the prefetch group; B block: everything else in the body.
        auto A_block_stmt
                = find_stmt_group(loop_body, stmt_label_t::prefetch());
        if (!A_block_stmt.has_value()) return root_;
        auto A_block = A_block_stmt.value();
        auto B_block = remove_stmt_group(loop_body, stmt_label_t::prefetch());
        size_t prefetch_count = 0;
        // Pipeline at most this many nested loop levels.
        size_t max_nested_prefetch = 2;
        for (size_t i = 0; i < loops.size(); i++) {
            if (prefetch_count < max_nested_prefetch) {
                if (!contains_object(A_block, loops[i].var)) {
                    // No point in prefetching a constant in a loop
                    B_block = for_t::make(loops[i].var, loops[i].init(),
                            loops[i].bound(), B_block, loops[i].unroll());
                    continue;
                }

                // Pipeline this loop level; the prologue becomes the new
                // A block that the next (outer) level may pipeline again.
                auto next = pipeline(
                        cfg_.prefetch().bufs(), loops[i], A_block, B_block);
                A_block = next.prologue();
                B_block = next.body();
                prefetch_count++;

            } else {
                // Prefetch budget exhausted: merge A back into B and emit
                // the remaining levels as plain loops.
                B_block = for_t::make(loops[i].var, loops[i].init(),
                        loops[i].bound(), A_block.append(B_block),
                        loops[i].unroll());
                A_block = stmt_t();
            }
        }
        return substitute(root_, compute_loop, A_block.append(B_block));
    }

private:
    stmt_t root_;
    const conv_config_t &cfg_;
    ir_context_t &ir_ctx_;
};
1076 | |
1077 | stmt_t inject_prefetch_pipeline( |
1078 | const stmt_t &s, ir_context_t &ir_ctx, const conv_config_t &cfg) { |
1079 | trace_start(); |
1080 | auto ret = prefetch_pipeliner_t(s, cfg, ir_ctx).inject(); |
1081 | trace_pass("inject_prefetch_pipeline" , ret, ir_ctx); |
1082 | return ret; |
1083 | } |
1084 | |
1085 | // Helper class to handle synchronization between threads for cooperative SLM |
1086 | // load and stores for double and triple buffering. Name conventions: |
1087 | // - Lx step - load from global memory to GRF (to be stored in SLM buffer x) |
1088 | // - Mx step - load from SLM buffer x to GRF and multiplication |
1089 | // - Sx step - store from GRF to SLM buffer x |
1090 | // - Rx event - SLM buffer x is available for reading |
1091 | // - Wx event - SLM buffer x is available for writing |
1092 | // Scheme for single buffering: |
1093 | // L0 |
1094 | // barrier |
1095 | // S0 |
1096 | // barrier |
1097 | // M0 |
1098 | // Schemes for double and triple buffering are below. |
1099 | class slm_sync_manager_t { |
1100 | public: |
1101 | slm_sync_manager_t(const conv_config_t &cfg, bool with_unroll) |
1102 | : slm_bufs_(cfg.slm().bufs()) |
1103 | , gmem_bufs_(cfg.slm().gmem_bufs()) |
1104 | , with_unroll_(with_unroll) { |
1105 | switch (slm_bufs_) { |
1106 | case 2: ver_ = version_t::x2; break; |
1107 | case 3: ver_ = version_t::x3_v3; break; |
1108 | default: ver_ = version_t::undef; |
1109 | } |
1110 | if (cfg.slm().sync_version() != -1) { |
1111 | ver_ = (version_t)cfg.slm().sync_version(); |
1112 | } |
1113 | switch (slm_bufs_) { |
1114 | case 2: ir_assert(ver_ == version_t::x2); break; |
1115 | case 3: |
1116 | ir_assert(utils::one_of(ver_, version_t::x3_v1, |
1117 | version_t::x3_v2, version_t::x3_v3)); |
1118 | break; |
1119 | default: ir_assert(ver_ == version_t::undef); |
1120 | } |
1121 | } |
1122 | |
1123 | stmt_t before_loop_prepend(const stmt_t &_s) const { |
1124 | if (with_unroll_) return _s; |
1125 | auto s = _s; |
1126 | if (is_x3_v1() || is_x3_v2() || is_x3_v3()) { |
1127 | // Emit initial signal, to match wait-signal pairs in the loop. |
1128 | s = funcs::signal().append(s); |
1129 | } |
1130 | return s; |
1131 | } |
1132 | |
1133 | stmt_t after_loop(const stmt_t &_s) const { |
1134 | auto s = _s; |
1135 | if (slm_bufs_ == 3) { |
1136 | s = s.append(funcs::barrier_wait()); |
1137 | // Wait with V3 guarantees that all SLM writes are synced, other |
1138 | // versions need additional synchronization. |
1139 | if (!is_x3_v3()) s = s.append(funcs::barrier()); |
1140 | } |
1141 | return s; |
1142 | } |
1143 | |
1144 | stmt_t before_L(const stmt_t &_s, bool do_mul) const { |
1145 | bool emit = false; |
1146 | if (!with_unroll_) emit = true; |
1147 | if (with_unroll_ && do_mul) emit = true; |
1148 | |
1149 | auto s = _s; |
1150 | if (is_x3_v2() && emit) { s = s.append(funcs::barrier_wait()); } |
1151 | |
1152 | return s; |
1153 | } |
1154 | |
1155 | stmt_t before_L_prepend(const stmt_t &_s, bool do_mul) const { |
1156 | return before_L(stmt_t(), do_mul).append(_s); |
1157 | } |
1158 | |
1159 | stmt_t after_L(const stmt_t &_s, bool do_mul) const { |
1160 | bool emit = false; |
1161 | if (!with_unroll_) emit = true; |
1162 | if (with_unroll_ && do_mul) emit = true; |
1163 | |
1164 | auto s = _s; |
1165 | if (is_x3_v1() && emit) s = s.append(funcs::barrier_wait()); |
1166 | return s; |
1167 | } |
1168 | |
1169 | stmt_t after_L_prepend(const stmt_t &_s, bool do_mul) const { |
1170 | return after_L(stmt_t(), do_mul).append(_s); |
1171 | } |
1172 | |
1173 | stmt_t before_S(const stmt_t &_s, bool do_mul, bool is_last_mul = false, |
1174 | int iter = -1) const { |
1175 | bool emit = false; |
1176 | if (!with_unroll_) emit = true; |
1177 | if (with_unroll_ && iter != -1 |
1178 | && iter >= (slm_bufs_ - 1) + (gmem_bufs_ - 1) - 1) |
1179 | emit = true; |
1180 | |
1181 | auto s = _s; |
1182 | if (is_x3_v3() && emit) { |
1183 | s = s.append(funcs::barrier_wait()); |
1184 | } else if ((is_x3_v1() || is_x3_v2()) && emit) { |
1185 | // In general we have to use SLM fence before signal to flush all |
1186 | // previous SLM stores. However any SLM load behaves as implicit |
1187 | // SLM fence for all previous SLM stores. This means we don't need |
1188 | // explicit SLM fence when we perform SLM load/multiplication |
1189 | // before signal. |
1190 | if (!do_mul) s = s.append(funcs::slm_fence()); |
1191 | if (!is_last_mul) s = s.append(funcs::signal()); |
1192 | } |
1193 | return s; |
1194 | } |
1195 | |
1196 | stmt_t after_S( |
1197 | const stmt_t &_s, bool is_last_mul = false, int iter = -1) const { |
1198 | auto s = _s; |
1199 | if (is_x2()) { |
1200 | s = s.append(funcs::barrier()); |
1201 | } else if (is_x3_v3()) { |
1202 | bool emit = false; |
1203 | if (!with_unroll_) emit = true; |
1204 | if (with_unroll_ && !is_last_mul && iter != -1 |
1205 | && iter >= (slm_bufs_ - 1) + (gmem_bufs_ - 1) - 2) |
1206 | emit = true; |
1207 | if (emit) { |
1208 | s = s.append(funcs::slm_fence()); |
1209 | s = s.append(funcs::signal()); |
1210 | } |
1211 | } |
1212 | return s; |
1213 | } |
1214 | |
1215 | bool is_x2() const { return ver_ == version_t::x2; } |
1216 | bool is_x3_v1() const { return ver_ == version_t::x3_v1; } |
1217 | bool is_x3_v2() const { return ver_ == version_t::x3_v2; } |
1218 | bool is_x3_v3() const { return ver_ == version_t::x3_v3; } |
1219 | |
1220 | private: |
1221 | enum class version_t { |
1222 | undef, |
1223 | // Double buffering scheme: |
1224 | // L0 |
1225 | // M1 |
1226 | // S0 |
1227 | // barrier |
1228 | // L1 |
1229 | // M0 |
1230 | // S1 |
1231 | // barrier |
1232 | x2, |
1233 | // Triple buffering scheme V1 (wait before M) |
1234 | // L0 |
1235 | // wait R1/W0 |
1236 | // M1 |
1237 | // signal R2/W1 |
1238 | // S0 |
1239 | // L1 |
1240 | // wait R2/W1 |
1241 | // M2 |
1242 | // signal R0/W2 |
1243 | // S1 |
1244 | // L2 |
1245 | // wait R0/W2 |
1246 | // M0 |
1247 | // signal R1/W0 |
1248 | // S2 |
1249 | x3_v1, |
1250 | // Triple buffering scheme V2 (wait before L) |
1251 | // wait R1/W0 |
1252 | // L0 |
1253 | // M1 |
1254 | // signal R2/W1 |
1255 | // S0 |
1256 | // wait R2/W1 |
1257 | // L1 |
1258 | // M2 |
1259 | // signal R0/W2 |
1260 | // S1 |
1261 | // wait R0/W2 |
1262 | // L2 |
1263 | // M0 |
1264 | // signal R1/W0 |
1265 | // S2 |
1266 | x3_v2, |
1267 | // Triple buffering scheme V3 (signal after store) |
1268 | // There are no SLM loads between S and signal so explicit fence is |
1269 | // required. |
1270 | // L0 |
1271 | // M1 |
1272 | // wait R2/W0 |
1273 | // S0 |
1274 | // fence and signal R0/W1 |
1275 | // L1 |
1276 | // M2 |
1277 | // wait R0/W1 |
1278 | // S1 |
1279 | // fence and signal R1/W2 |
1280 | // L2 |
1281 | // M0 |
1282 | // wait R1/W2 |
1283 | // S2 |
1284 | // fence and signal R2/W0 |
1285 | x3_v3 |
1286 | }; |
1287 | |
1288 | int slm_bufs_; |
1289 | int gmem_bufs_; |
1290 | bool with_unroll_; |
1291 | version_t ver_; |
1292 | }; |
1293 | |
1294 | class : public ir_visitor_t { |
1295 | public: |
1296 | ( |
1297 | std::vector<stmt_t> &retn, object_eq_set_t<expr_t> &bufs) |
1298 | : retn_(retn), bufs_(bufs), outer_(true) {} |
1299 | |
1300 | void (const store_t &obj) override { |
1301 | if (obj.buf.str().find("zp_mask" ) == 0) { |
1302 | if (outer_) retn_.emplace_back(obj); |
1303 | bufs_.insert(obj.buf); |
1304 | } |
1305 | } |
1306 | |
1307 | void (const let_t &obj) override { |
1308 | if ((obj.var.str().find("zp_mask" ) == 0)) { |
1309 | if (outer_) retn_.emplace_back(obj); |
1310 | auto outer_prev = outer_; |
1311 | outer_ = false; |
1312 | visit(obj.body); |
1313 | outer_ = outer_prev; |
1314 | } |
1315 | } |
1316 | |
1317 | private: |
1318 | std::vector<stmt_t> &; |
1319 | object_eq_set_t<expr_t> &; |
1320 | bool ; |
1321 | }; |
1322 | |
// Injects double/triple SLM buffering into the compute loop without
// unrolling: buffer indices are tracked at runtime in a small GRF buffer
// ("slm_idx") and SLM sends are offset by them.
class simple_slm_buffering_injector_t {
public:
    simple_slm_buffering_injector_t(const stmt_t &root, ir_context_t &ir_ctx,
            const conv_config_t &cfg, int ab_slm_size)
        : ir_ctx_(ir_ctx)
        , cfg_(cfg)
        , ab_slm_size_(ab_slm_size)
        , root_(root)
        , alloc_mgr_(root_)
        , step_(root)
        , loop_nest_(root, ir_ctx)
        , slm_sync_mgr_(cfg, /*with_unroll=*/false) {}

    // Returns the transformed root, or the original when no multi-buffering
    // is requested.
    stmt_t inject() {
        ir_assert(cfg_.slm().gmem_bufs() == 1)
                << "GRF buffering is not supported." ;
        if (utils::one_of(cfg_.slm().bufs(), 0, 1)) return root_;

        ir_assert(cfg_.slm().a() == cfg_.slm().b())
                << "Mixed SLM/GMEM loads are not supported." ;

        auto loop = step_.compute_loop();

        // SLM indices are allocated as follows:
        // slm_idx[0] -> slm_buf_store
        // slm_idx[1] -> slm_buf_compute
        // slm_idx[2] -> slm_counter
        auto slm_idx_buf
                = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "slm_idx" );
        int slm_idx_size = type_t::s32().size();

        // Loads `elems` consecutive s32 values starting at slot `off`.
        auto slm_idx_load = [&](int off, int elems) {
            return load_t::make(
                    type_t::s32(elems), slm_idx_buf, slm_idx_size * off);
        };

        // Initialize slm_idx.
        int off = 0;
        auto store0 = store_t::make(slm_idx_buf, off, 0);
        off += slm_idx_size;

        auto store1 = store_t::make(slm_idx_buf, off, 1);
        off += slm_idx_size;

        auto store2 = store_t::make(
                slm_idx_buf, off, int_imm_t::make(0, type_t::s32()));

        auto slm_idx_init = store0.append(store1).append(store2);

        auto slm_idx_load2 = slm_idx_load(0, 2);
        auto slm_idx_load4 = slm_idx_load(0, 4);
        // Advance all indices by one with a single 4-wide add.
        auto slm_idx_store = store_t::make(slm_idx_buf, 0,
                slm_idx_load4 + shuffle_t::make_broadcast(1, 4));

        // Update slm_idx. Wrap the two buffer indices back to zero once they
        // reach the buffer count (masked store over the first two slots).
        auto mask = (slm_idx_load2
                == shuffle_t::make_broadcast(cfg_.slm().bufs(), 2));
        auto slm_idx_store_fix = store_t::make(slm_idx_buf, 0,
                shuffle_t::make_broadcast(int_imm_t::make(0, type_t::s32()), 2),
                store_t::default_stride, mask);

        auto slm_idx_update = slm_idx_store.append(slm_idx_store_fix);

        loop = slm_idx_init.append(loop);

        auto &g2s_load_orig = step_.g2s_load();
        auto &g2s_store_orig = step_.g2s_store();
        auto &s2r_load = step_.s2r_load();
        auto &mul = step_.mul();

        auto g2s_load = g2s_load_orig;
        auto g2s_store = g2s_store_orig;

        ir_assert(s2r_load.size() == mul.size());

        // Gather SLM->GRF loads and multiplications into one statement and
        // strip them from the loop; they are re-inserted below with proper
        // guarding and synchronization.
        stmt_t s2r_mul;
        for (int i = 0; i < int(mul.size()); i++) {
            s2r_mul = s2r_mul.append(s2r_load[i]);
            loop = substitute(loop, s2r_load[i], stmt_t(), 1);
            s2r_mul = s2r_mul.append(mul[i]);
            loop = substitute(loop, mul[i], stmt_t(), 1);
        }

        // Existing sync calls are dropped; slm_sync_mgr_ re-emits them.
        loop = remove_synchronization(loop);

        object_eq_set_t<expr_t> mask_bufs;
        std::vector<stmt_t> masks;

        // Extract zero-point mask statements so they can be multi-buffered.
        slm_zp_mask_extractor_t(masks, mask_bufs).visit(s2r_mul);
        if (!mask_bufs.empty())
            for (auto &m : masks)
                s2r_mul = substitute(s2r_mul, m, stmt_t());

        // Offset SLM accesses by the runtime buffer indices.
        s2r_mul = sub_slm_bufs(s2r_mul, slm_idx_load(1, 1));
        g2s_store = sub_slm_bufs(g2s_store, slm_idx_load(0, 1));
        g2s_store = g2s_store.append(slm_idx_update);

        auto s2r_mul_body = s2r_mul;
        auto s2r_mul_tail = s2r_mul;
        auto slm_counter = slm_idx_load(2, 1);
        // Multiplication may start only once enough buffers were filled.
        auto cond = (slm_counter >= cfg_.slm().bufs() - 1);

        if (cfg_.slm().bufs() == 2) {
            s2r_mul_body = if_t::make(cond, s2r_mul_body);
        } else {
            // In general we have to use SLM fence before signal to flush all
            // previous SLM stores. However any SLM load behaves as implicit
            // SLM fence for all previous SLM stores. This means we don't need
            // explicit SLM fence when we perform SLM load/multiplication
            // before signal.
            auto with_mul = slm_sync_mgr_.before_S(s2r_mul_body, true);
            auto without_mul = slm_sync_mgr_.before_S(stmt_t(), false);
            s2r_mul_body = if_t::make(cond, with_mul, without_mul);
        }

        g2s_store = slm_sync_mgr_.after_S(g2s_store);
        g2s_load = slm_sync_mgr_.before_L_prepend(g2s_load, true);
        g2s_load = slm_sync_mgr_.after_L(g2s_load, true);

        if (!g2s_load.is_same(g2s_load_orig)) {
            loop = substitute(loop, g2s_load_orig, g2s_load, 1);
        }

        alloc_updater_t alloc_updater;

        // Multi-buffer the zero-point mask buffers and rotate their contents
        // one slot per iteration.
        int slm_bufs = cfg_.slm().bufs();
        for (auto &mbuf : mask_bufs) {
            auto sz = alloc_mgr_.alloc_size(mbuf);
            alloc_updater.resize(mbuf, sz * slm_bufs);
            for (auto &m : masks)
                m = substitute(m, mbuf, mbuf[sz * (slm_bufs - 1)]);
            layout_t comp_layout(type_t::u8(), 0, std::vector<dim_t> {sz});
            for (int b = 1; b < slm_bufs; b++) {
                auto reorder = create_reorder_stmt(comp_layout, comp_layout,
                        mbuf + b * sz, mbuf + (b - 1) * sz);
                s2r_mul_body = s2r_mul_body.append(reorder);
                if ((slm_bufs == 3) && (b == 1))
                    s2r_mul_tail = s2r_mul_tail.append(reorder);
            }
        }
        if (!mask_bufs.empty()) {
            stmt_t all_masks;
            for (auto &m : masks)
                all_masks = all_masks.append(m);
            s2r_mul_body = all_masks.append(s2r_mul_body);
        }
        loop = substitute(
                loop, g2s_store_orig, s2r_mul_body.append(g2s_store), 1);

        loop = slm_sync_mgr_.before_loop_prepend(loop);

        // Complete the remaining iterations.
        int rem_iters = slm_bufs - 1;
        int mul_start = std::max(0, rem_iters - loop_nest_.size());
        multi_loop_iterator_t multi(loop_nest_.loops());
        multi.advance(loop_nest_.size() - rem_iters + mul_start);

        loop = slm_sync_mgr_.after_loop(loop);
        // Drain the pipeline: emit the trailing multiplications with the
        // loop variables substituted by their final values.
        for (int i = 0; i < rem_iters; i++) {
            if (i >= mul_start) {
                auto tmp_mul_tail = s2r_mul_tail;
                loop_nest_.for_each_loop_var([&](const expr_t &v) {
                    expr_t iter(multi.var_value(v));
                    tmp_mul_tail = substitute(tmp_mul_tail, v, iter);
                });
                // SLM load/multiplication works as implicit SLM fence.
                loop = loop.append(tmp_mul_tail);
                multi.advance();
            }
            loop = loop.append(slm_idx_update);
        }

        if (cfg_.assign_sbids())
            loop = sbid_assigner_t(ir_ctx_.hw_cfg().hw()).assign(loop);

        const auto grf_size = ir_ctx_.hw_cfg().grf_size();
        loop = alloc_t::make(slm_idx_buf, grf_size, alloc_kind_t::grf, loop);

        // Grow the single SLM allocation to hold all A/B buffer copies.
        auto slm_buffers = alloc_mgr_.find_buffers(alloc_kind_t::slm);
        ir_assert(slm_buffers.size() == 1);
        auto &slm_buf = slm_buffers[0];
        int non_ab_slm_size = alloc_mgr_.alloc_size(slm_buf) - ab_slm_size_;
        alloc_updater.resize(
                slm_buf, non_ab_slm_size + ab_slm_size_ * slm_bufs);

        auto ret = substitute(root_, step_.compute_loop(), loop, 1);
        ret = alloc_updater.update(ret);
        return ret;
    }

    // Removes all signal/SLM-fence/barrier calls from `s`.
    static stmt_t remove_synchronization(const stmt_t &s) {
        auto ret = s;
        for (auto &_c : find_objects<func_call_t>(s)) {
            auto &c = _c.as<func_call_t>();
            if (c.func.is_equal(funcs::signal_func())
                    || c.func.is_equal(funcs::slm_fence_func())
                    || c.func.is_equal(funcs::barrier_func())) {
                ret = substitute(ret, _c, stmt_t(), 1);
            }
        }
        return ret;
    }

    // Offsets every SLM send in `stmt` by `slm_idx` buffer slots of
    // ab_slm_size_ each.
    stmt_t sub_slm_bufs(const stmt_t &stmt, const expr_t &slm_idx) const {
        auto stmt_vec = flatten_statements(stmt);

        stmt_t ret = stmt;
        for (auto &s : stmt_vec) {
            if (!is_func_call<send_t>(s)) continue;

            auto &send = s.as<func_call_t>().func.as<send_t>();

            // This is not send to SLM, skip.
            if (!send.is_slm()) continue;

            auto new_args = s.as<func_call_t>().args;
            send_t::arg_mem_off(new_args) += ab_slm_size_ * slm_idx;
            auto new_send = send.call(new_args);
            ret = substitute(ret, s, new_send, 1);
        }

        return ret;
    }

    ir_context_t &ir_ctx_;
    const conv_config_t &cfg_;
    int ab_slm_size_;

    stmt_t root_;
    alloc_manager_t alloc_mgr_;
    compute_step_t step_;
    compute_loop_nest_t loop_nest_;
    slm_sync_manager_t slm_sync_mgr_;
};
1557 | |
1558 | stmt_t inject_simple_slm_buffering(const stmt_t &s, ir_context_t &ir_ctx, |
1559 | const conv_config_t &cfg, int ab_slm_size) { |
1560 | trace_start(); |
1561 | auto ret = simple_slm_buffering_injector_t(s, ir_ctx, cfg, ab_slm_size) |
1562 | .inject(); |
1563 | trace_pass("inject_simple_slm_buffering" , ret, ir_ctx); |
1564 | return ret; |
1565 | } |
1566 | |
1567 | class unrolling_injector_t { |
1568 | public: |
    // Prepares unrolled SLM-buffering/prefetch injection: computes the
    // pipeline parameters and collects the GRF buffers used for
    // global->SLM staging.
    unrolling_injector_t(const stmt_t &root, const conv_config_t &cfg,
            ir_context_t &ir_ctx, int ab_slm_size)
        : cfg_(cfg)
        , ir_ctx_(ir_ctx)
        , ab_slm_size_(ab_slm_size)
        , root_(root)
        , alloc_mgr_(root_)
        , step_(root)
        , loop_nest_(root, ir_ctx)
        , slm_sync_mgr_(cfg, /*with_unroll=*/true) {
        int inner_iters = loop_nest_.inner_loops_size();
        params_ = compute_params_t(cfg_.slm().bufs(), cfg_.slm().gmem_bufs(),
                ab_slm_size, cfg_.prefetch().bufs(), inner_iters);
        if (params_.use_slm) {
            // Register (non-memory) buffers written by the G2S load; they
            // are resized later for GRF double buffering.
            for (auto &b :
                    find_send_buffers(step_.g2s_load(), /*is_mem=*/false)) {
                g2s_reg_bufs_.emplace_back(b, alloc_mgr_.alloc_size(b));
            }
        }

        // Can't fuse top-level zero-out statement unless the compute loop is
        // top-level as well.
        fuse_zero_out_with_fma_ = (loop_nest_.compute_loop_level() == 0);
    }
1593 | |
    // Emits the fully unrolled compute loop: ramp-up iterations that fill
    // the pipeline, an (optionally looped) steady-state body, and ramp-down
    // iterations that drain it; then updates buffer allocations.
    stmt_t inject() {
        compute_iterator_t it(params_, loop_nest_);
        stmt_t body;

        sbid_manager_t sbid_mgr(cfg_.hw());

        auto &outer_loop_info = loop_nest_.outer_loop_info();

        // Appends the outer-loop variable post-increments that become due
        // after this iteration (tracked separately for the multiplication
        // and the preload streams).
        auto append_outer_post_inc = [&](const stmt_t &_s) {
            auto &mul = outer_loop_info.mul_post_inc_stmt();
            auto &preload = outer_loop_info.preload_post_inc_stmt();
            auto s = _s;
            if (it.mul_loop_it.is_outer_loop_end() && it.do_mul()) {
                s = s.append(mul);
            }
            if (it.preload_loop_it.is_outer_loop_end() && it.do_preload()) {
                s = s.append(preload);
            }
            return s;
        };

        bmnk_dim_helper_t h(cfg_);
        int k_iter_blk = h.iter_dim('k');
        int reduce_iter_bytes = k_iter_blk * cfg_.prb().a_data_type_size;
        // Add periodic signal-wait thread group synchronization in some cases.
        // This is to ensure threads access close reduction blocks and able to
        // reuse their common data from L1.
        bool do_sync
                = (cfg_.hw() >= ngen::HW::XeHPC) && (reduce_iter_bytes > 32);
        if (cfg_.slm()) do_sync = false;
        // Distance in iterations between signal and wait.
        int sync_dist = 3;

        // Ramp-up.
        for (int i = 0; i < it.ramp_up_iters; i++) {
            body = stmt_seq_t::make(body, create_iteration(it, sbid_mgr));
            body = append_outer_post_inc(body);
            ++it;
        }

        // Body.
        if (it.body_iters > 0) {
            // With extent > 1, the unrolled body is wrapped into a loop.
            int extent = it.body_iters / it.unroll();
            bool has_loop = (extent > 1);

            stmt_t loop_body;
            bool do_sync_wait = false;
            for (int i = 0; i < it.unroll(); i++) {
                // Alternate signal/wait every sync_dist iterations.
                if (do_sync && i % sync_dist == 0) {
                    loop_body = loop_body.append(do_sync_wait
                                    ? funcs::barrier_wait()
                                    : funcs::signal());
                    do_sync_wait = !do_sync_wait;
                }
                loop_body = loop_body.append(create_iteration(
                        it, sbid_mgr, /*in_loop_body=*/has_loop));
                ir_assert(it.do_mul());
                loop_body = append_outer_post_inc(loop_body);
                ++it;
            }
            // Balance a trailing signal with a final wait.
            if (do_sync && do_sync_wait)
                loop_body = loop_body.append(funcs::barrier_wait());
            if (!has_loop) {
                body = body.append(loop_body);
            } else {
                ir_assert(extent > 0);
                auto for_var = ir_ctx_.create_tmp_var(type_t::s32(), "i" );
                body = body.append(for_t::make(for_var, 0, extent, loop_body));
            }
            // Skip the iterations covered by the emitted loop.
            it.advance(it.body_iters - it.unroll());
        }

        // Ramp-down.
        for (int i = 0; i < it.ramp_down_iters; i++) {
            ir_assert(it.do_mul());
            body = body.append(create_iteration(it, sbid_mgr));
            body = append_outer_post_inc(body);
            ++it;
        }

        if (outer_loop_info.has_var_refs()) {
            body = outer_loop_info.init_stmt().append(body);
            body = outer_loop_info.inject_alloc_stmts(body);
        }

        // When compute loop is part of outer loop and SLM buffering is used
        // then synchronization is required between outer iterations.
        if (loop_nest_.compute_loop_level() != 0 && params_.use_slm) {
            body = funcs::barrier().append(body);
        }

        body = stmt_group_t::make(stmt_label_t::compute_loop(), body);
        auto ret = substitute(root_, step_.compute_loop(), body, 1);

        if (params_.use_slm) {
            alloc_updater_t alloc_updater;

            // Update buffer sizes.
            for (auto &b : g2s_reg_bufs_) {
                alloc_updater.resize(b.buf,
                        alloc_mgr_.alloc_size(b.buf) * cfg_.slm().gmem_bufs());
            }

            auto slm_buffers = alloc_mgr_.find_buffers(alloc_kind_t::slm);
            if (!slm_buffers.empty()) {
                ir_assert(slm_buffers.size() == 1);

                auto &slm_buf = slm_buffers[0];
                int non_ab_slm_size
                        = alloc_mgr_.alloc_size(slm_buf) - ab_slm_size_;
                alloc_updater.resize(slm_buf,
                        non_ab_slm_size + ab_slm_size_ * cfg_.slm().bufs());
            }

            ret = alloc_updater.update(ret);
        }

        // Remove zero-out statement for C (handled by sub_fma_acc_with_zero).
        if (fuse_zero_out_with_fma_)
            ret = substitute(ret, step_.c_zero_out(), stmt_t(), 1);

        return ret;
    }
1717 | |
1718 | private: |
    // Describes a G2S staging register buffer: its base expression and the
    // allocation size of a single buffer copy (as reported by the alloc
    // manager).
    struct buffer_info_t {
        buffer_info_t(const expr_t &buf, int size) : buf(buf), size(size) {}

        expr_t buf; // Buffer base.
        int size; // Size of one buffer copy.
    };
1725 | |
    // Builds one pipeline iteration: substitutes the current loop-variable
    // values into the preload/multiplication statements, redirects them to
    // the right GRF/SLM buffer slots, and stitches the pieces together with
    // the required synchronization and SBID assignment.
    stmt_t create_iteration(const compute_iterator_t &it,
            sbid_manager_t &sbid_mgr, bool in_loop_body = false) const {
        auto g2s_load = step_.g2s_load();
        auto g2s_store = step_.g2s_store();
        auto prefetch = step_.prefetch();
        auto g2r_load = step_.g2r_load();
        auto s2r_load = step_.s2r_load();
        auto mul = step_.mul();
        auto lets = step_.inner_let_stmts();
        auto &outer_loop_info = loop_nest_.outer_loop_info();

        loop_nest_.for_each_loop_var([&](const expr_t &v) {
            // Inside the emitted loop, the outer variable is read from its
            // runtime storage; otherwise its compile-time value is used.
            expr_t mul_var_value;
            expr_t preload_var_value;
            if (v.is_same(outer_loop_info.var) && in_loop_body
                    && outer_loop_info.has_var_refs()) {
                mul_var_value = outer_loop_info.mul_var_load();
                preload_var_value = outer_loop_info.preload_var_load();
            } else {
                mul_var_value = it.mul_loop_it.var_value(v);
                preload_var_value = it.preload_loop_it.var_value(v);
            }
            // Preload-stream statements get the preload value; mul-stream
            // statements get the mul value.
            g2s_load = const_fold(substitute(g2s_load, v, preload_var_value));
            g2s_store = const_fold(substitute(g2s_store, v, preload_var_value));
            prefetch = const_fold(substitute(prefetch, v, preload_var_value));
            for (auto &m : mul) {
                m = const_fold(substitute(m, v, mul_var_value));
            }
            for (auto &s : g2r_load) {
                s = const_fold(substitute(s, v, mul_var_value));
            }
            for (auto &s : s2r_load) {
                if (count_object(s, v) > 0) ir_error_not_expected();
                s = const_fold(substitute(s, v, preload_var_value));
            }
            // Lets are classified by which stream references them; a let
            // used by both streams must not reference loop variables.
            for (int i = 0; i < int(lets.size()); i++) {
                auto &let = lets[i];
                auto &orig_let = step_.inner_let_stmts()[i];
                expr_t var_value;
                bool is_preload_let = step_.is_preload_let(orig_let);
                bool is_mul_let = step_.is_mul_let(orig_let);
                if (is_preload_let && !is_mul_let) {
                    var_value = preload_var_value;
                } else if (is_mul_let && !is_preload_let) {
                    var_value = mul_var_value;
                } else {
                    ir_assert(count_object(let.as<let_t>().value, v) == 0)
                            << "Unexpected reference to variable " << v
                            << " from " << let;
                    continue;
                }
                let = const_fold(substitute(let, v, var_value));
            }
        });

        if (params_.use_slm) {
            // Redirect to the GRF and SLM buffer slots of this iteration.
            g2s_load = sub_gmem_bufs(g2s_load, it, /*is_read=*/false);
            g2s_store = sub_gmem_bufs(g2s_store, it, /*is_read=*/true);

            g2s_store = sub_slm_bufs(g2s_store, it, /*is_read=*/false);
            for (auto &s : s2r_load) {
                s = sub_slm_bufs(s, it, /*is_read=*/true);
            }
        }

        // First multiplication: fold the C zero-out into the FMA.
        if (it.is_first_mul() && fuse_zero_out_with_fma_) {
            mul = sub_fma_acc_with_zero(
                    mul, cfg_.fma_kind() == fma_kind_t::mad);
        }

        // Past the last use of each stream, drop its pointer increments.
        if (it.is_last_g2s_store())
            g2s_store = remove_post_inc_stores(g2s_store);
        if (it.is_last_g2s_load()) g2s_load = remove_post_inc_stores(g2s_load);
        if (it.is_last_prefetch()) prefetch = remove_post_inc_stores(prefetch);
        if (it.is_last_mul()) {
            for (auto &s : s2r_load)
                s = remove_post_inc_stores(s);
            for (auto &s : g2r_load)
                s = remove_post_inc_stores(s);
        }

        // Assemble the iteration: L step with its synchronization, ...
        stmt_t iter_stmt;

        iter_stmt = slm_sync_mgr_.before_L(iter_stmt, it.do_mul());
        if (it.do_g2s_load()) iter_stmt = iter_stmt.append(g2s_load);
        iter_stmt = slm_sync_mgr_.after_L(iter_stmt, it.do_mul());

        // ... single-buffer SLM store (barriers on both sides), ...
        if (it.do_g2s_store() && it.slm_bufs() == 1) {
            iter_stmt = iter_stmt.append(funcs::barrier());
            iter_stmt = iter_stmt.append(g2s_store);
            iter_stmt = iter_stmt.append(funcs::barrier());
        }

        if (it.do_prefetch()) iter_stmt = iter_stmt.append(prefetch);

        // ... M step (loads + multiplications), ...
        if (it.do_mul()) {
            for (size_t i = 0; i < mul.size(); i++) {
                iter_stmt = iter_stmt.append(g2r_load[i]);
                iter_stmt = iter_stmt.append(s2r_load[i]);
                iter_stmt = iter_stmt.append(mul[i]);
            }
        }
        // ... and S step with its synchronization.
        iter_stmt = slm_sync_mgr_.before_S(
                iter_stmt, it.do_mul(), it.is_last_mul(), it.iter);

        if (it.do_g2s_store() && it.slm_bufs() >= 2) {
            iter_stmt = iter_stmt.append(g2s_store);
        }

        iter_stmt = slm_sync_mgr_.after_S(iter_stmt, it.is_last_mul(), it.iter);

        if (cfg_.assign_sbids())
            iter_stmt = sbid_assigner_t(sbid_mgr).assign(iter_stmt);

        iter_stmt = inject_local_let(iter_stmt, lets, it.linear_id);

        return iter_stmt;
    }
1844 | |
1845 | stmt_t sub_gmem_bufs(const stmt_t &stmt, const compute_iterator_t &it, |
1846 | bool is_read) const { |
1847 | if (it.slm_bufs() == 0) return stmt; |
1848 | if (is_read && !it.do_g2s_store()) return stmt; |
1849 | if (!is_read && !it.do_g2s_load()) return stmt; |
1850 | |
1851 | int buf_idx = (is_read ? it.gmem_read_buf_index() |
1852 | : it.gmem_write_buf_index()); |
1853 | if (buf_idx == 0) return stmt; |
1854 | |
1855 | auto ret = stmt; |
1856 | for (auto &b : g2s_reg_bufs_) { |
1857 | ret = substitute(ret, b.buf, b.buf[buf_idx * b.size]); |
1858 | } |
1859 | return ret; |
1860 | } |
1861 | |
1862 | stmt_t sub_slm_bufs(const stmt_t &stmt, const compute_iterator_t &it, |
1863 | bool is_read) const { |
1864 | if (it.slm_bufs() <= 1) return stmt; |
1865 | if (is_read && !it.do_mul()) return stmt; |
1866 | if (!is_read && !it.do_g2s_store()) return stmt; |
1867 | |
1868 | int upd = (is_read ? it.slm_read_offset_update() |
1869 | : it.slm_write_offset_update()); |
1870 | |
1871 | auto stmt_vec = flatten_statements(stmt); |
1872 | |
1873 | stmt_t ret = stmt; |
1874 | for (auto &s : stmt_vec) { |
1875 | auto *call = s.as_ptr<func_call_t>(); |
1876 | if (!call) continue; |
1877 | auto *func = call->func.as_ptr<send_t>(); |
1878 | if (!func) continue; |
1879 | |
1880 | auto &send = call->func.as<send_t>(); |
1881 | auto &args = call->args; |
1882 | auto &mem_buf = send_t::arg_mem_buf(args); |
1883 | auto &header_buf = send_t::arg_mem_off(args); |
1884 | |
1885 | // This is not send to SLM, skip. |
1886 | if (!send.is_slm()) continue; |
1887 | |
1888 | // May have signed offset. |
1889 | auto store_obj = send.create_offset_store( |
1890 | header_buf, mem_buf, upd, /*is_signed_offset=*/true); |
1891 | auto &store = store_obj.as<store_t>(); |
1892 | expr_t old_value |
1893 | = load_t::make(send.address_type(), store.buf, store.off); |
1894 | auto post_inc_store = store_t::make( |
1895 | store.buf, store.off, old_value + store.value); |
1896 | ret = substitute(ret, s, stmt_seq_t::make(s, post_inc_store), 1); |
1897 | } |
1898 | |
1899 | return ret; |
1900 | } |
1901 | |
    // Rewrites the first accumulation into each FMA destination so that the
    // separate zero-initialization of the accumulator becomes unnecessary:
    // the first dpas per destination gets src0 = 0 (translated to a null
    // register, see below), and mad-like statements in the first block are
    // turned into plain overwriting stores (mul or negate). Input is a
    // vector of stmt_group_t statements wrapping the multiply stage; returns
    // the rewritten groups in the same order.
    static std::vector<stmt_t> sub_fma_acc_with_zero(
            const std::vector<stmt_t> &stmt, const bool is_mad) {
        // Returns true if `curr` continues a contiguous run of fma-like
        // statements started at vec[bgn]: either a mad call accumulating
        // into its own dst (dst == src0), or a store whose value is a
        // mad/sub expression reading back the stored location. Runs are
        // never formed when is_mad is set (that path processes every
        // statement via process_block instead).
        auto is_from_block = [is_mad](const stmt_t &curr,
                                    const std::vector<stmt_t> &vec, int bgn) {
            if (is_mad) return false;
            if (is_func_call<mad_t>(curr)) {
                // Accumulating mad; require the previous statement (if any)
                // to also be a mad so the run stays homogeneous.
                return (mad_t::arg_dst(curr).is_equal(mad_t::arg_src0(curr)))
                        && (!bgn || is_func_call<mad_t>(vec[bgn - 1]));
            } else if (const auto *s = curr.as_ptr<store_t>()) {
                const auto *bs
                        = (bgn) ? vec[bgn - 1].as_ptr<store_t>() : nullptr;
                if (const auto *t = s->value.as_ptr<ternary_op_t>()) {
                    // dst = mad(dst, b, c) expressed as a store.
                    const auto *a = t->a.as_ptr<load_t>();
                    return (t->op_kind == op_kind_t::_mad) && a
                            && (a->buf[a->off].is_equal(s->buf[s->off]))
                            && (!bgn || (bs && bs->value.is<ternary_op_t>()));
                } else if (const auto *b = s->value.as_ptr<binary_op_t>()) {
                    // dst = dst - b expressed as a store.
                    const auto *a = b->a.as_ptr<load_t>();
                    return (b->op_kind == op_kind_t::_sub) && a
                            && (a->buf[a->off].is_equal(s->buf[s->off]))
                            && (!bgn || (bs && bs->value.is<binary_op_t>()));
                }
            }
            return false;
        };
        // Rewrites the run vec[bgn-1+is_mad, end+is_mad) inside `root`:
        // pass 1 verifies each destination appears at most once across all
        // processed runs (tracked in `seen`); only then pass 2 replaces the
        // accumulating form with an overwriting store (mul, or negation for
        // the sub case). Returns the next run start for the caller: `end`
        // when is_mad (advance one-by-one), 0 otherwise (reset the run).
        auto process_block = [is_mad](stmt_t &root,
                                    object_eq_set_t<expr_t> &seen,
                                    std::vector<stmt_t> &vec, int bgn,
                                    int end) {
            bgn += is_mad;
            end += is_mad;
            bool never_seen = (bgn > 0);
            for (int i = bgn - 1; never_seen && (i < end); i++) {
                if (is_func_call<mad_t>(vec[i])) {
                    never_seen &= seen.insert(mad_t::arg_dst(vec[i])).second;
                } else if (const auto *s = vec[i].as_ptr<store_t>()) {
                    never_seen &= seen.insert(s->buf[s->off]).second;
                }
            }
            for (int i = bgn - 1; never_seen && (i < end); i++) {
                if (is_func_call<mad_t>(vec[i])) {
                    auto &call = vec[i].as<func_call_t>();
                    auto &mad = call.func.as<mad_t>();
                    // Rewriting is only valid for plain, broadcast-src1 mads
                    // without atomic/SBID modifiers.
                    auto *m = call.attr.as_ptr<instruction_modifier_attr_t>();
                    ir_assert(!m
                            || (!m->mod.is_atomic && m->mod.sbid.is_empty()));
                    ir_assert(mad.src1_stride == 0);
                    // dst = bcast(src1) * src2 (no accumulation).
                    auto a_load = load_t::make(
                            mad.src1_type.kind(), mad_t::arg_src1(vec[i]), 0);
                    auto a = shuffle_t::make_broadcast(a_load, mad.exec_size);
                    auto b_stride = mad.src2_stride * mad.src2_type.size();
                    auto b = load_t::make(
                            type_t(mad.src2_type.kind(), mad.exec_size),
                            mad_t::arg_src2(vec[i]), 0, b_stride);
                    auto mul = binary_op_t::make(op_kind_t::_mul, a, b,
                            type_t(mad.dst_type.kind(), mad.exec_size));
                    auto store = store_t::make(mad_t::arg_dst(vec[i]), 0, mul);
                    root = substitute(root, vec[i], store, 1);
                } else if (const auto *s = vec[i].as_ptr<store_t>()) {
                    if (const auto *t = s->value.as_ptr<ternary_op_t>()) {
                        // dst = mad(dst, b, c) -> dst = b * c.
                        ir_assert(s->mask.is_empty());
                        auto mul = binary_op_t::make(
                                op_kind_t::_mul, t->b, t->c, t->type);
                        root = substitute(root, vec[i],
                                store_t::make(s->buf, s->off, mul), 1);
                    } else if (const auto *b = s->value.as_ptr<binary_op_t>()) {
                        // dst = dst - b -> dst = -b.
                        auto store = store_t::make(s->buf, s->off, -b->b,
                                s->stride, s->mask, true);
                        root = substitute(root, vec[i], store, 1);
                    } else {
                        ir_error_not_expected();
                    }
                } else {
                    ir_error_not_expected();
                }
            }
            return (is_mad) ? end : 0;
        };
        std::vector<stmt_t> retn;

        for (const auto &s : stmt) {
            ir_assert(s.is<stmt_group_t>());
            const auto &group = s.as<stmt_group_t>();
            auto body = group.body;
            auto stmt_vec = flatten_statements(body);
            // Destinations already rewritten; shared between the dpas path
            // below and process_block so each dst is zero-folded once.
            object_eq_set_t<expr_t> seen_dst;

            int bgn = 0;
            for (int i = 0; i < int(stmt_vec.size()); i++) {
                stmt_t curr = stmt_vec[i];
                // Extend the current run, open a new one, or flush it.
                const bool ifb = is_from_block(curr, stmt_vec, bgn);
                bgn = (ifb) ? (!bgn) ? i + 1 : bgn
                            : process_block(body, seen_dst, stmt_vec, bgn, i);
                if (is_func_call<dpas_t>(curr) && !dpas_t::is_dp4a_call(curr)) {
                    auto &call = curr.as<func_call_t>();

                    auto &dst = dpas_t::arg_dst(curr);
                    auto src0 = expr_t(0); // Will be translated to null reg
                    auto &src1 = dpas_t::arg_src1(curr);
                    auto &src2 = dpas_t::arg_src2(curr);

                    // Only the first dpas writing each dst loses its src0
                    // accumulator; later ones keep accumulating.
                    if (seen_dst.insert(dst).second)
                        body = substitute(body, curr,
                                func_call_t::make(call.func,
                                        {dst, src0, src1, src2}, call.attr),
                                1);
                }
            }
            // Flush the trailing run, if any.
            process_block(body, seen_dst, stmt_vec, bgn,
                    (int)stmt_vec.size() - is_mad);
            retn.emplace_back(stmt_group_t::make(group.label, body));
        }
        return retn;
    }
2016 | |
2017 | // Returns memory buffers if is_mem is true and register buffers otherwise. |
2018 | static object_set_t<expr_t> find_send_buffers( |
2019 | const stmt_t &s, bool is_mem) { |
2020 | object_set_t<expr_t> ret; |
2021 | auto calls = find_objects<func_call_t>(s); |
2022 | for (auto &_c : calls) { |
2023 | auto &c = _c.as<func_call_t>(); |
2024 | if (!c.func.is<send_t>()) continue; |
2025 | auto &buf = (is_mem ? send_t::arg_mem_buf(_c) |
2026 | : send_t::arg_reg_buf(_c)); |
2027 | ret.insert(buf.as<ptr_t>().base); |
2028 | } |
2029 | return ret; |
2030 | } |
2031 | |
2032 | static stmt_t inject_local_let(const stmt_t &_s, |
2033 | const std::vector<stmt_t> &enclosed_lets, int id) { |
2034 | auto s = _s; |
2035 | |
2036 | // Inject let statements from the innermost loop. |
2037 | for (auto &_let : enclosed_lets) { |
2038 | auto &let = _let.as<let_t>(); |
2039 | s = let_t::make(let.var, let.value, s); |
2040 | } |
2041 | |
2042 | // Substitute variables to avoid clashing. |
2043 | auto lets = find_objects<let_t>(s); |
2044 | for (auto &_let : lets) { |
2045 | auto &let = _let.as<let_t>(); |
2046 | auto &var = let.var.as<var_t>(); |
2047 | auto local_var = var_t::make( |
2048 | var.type, var.name + "_" + std::to_string(id)); |
2049 | s = substitute(s, let.var, local_var); |
2050 | } |
2051 | return s; |
2052 | } |
2053 | |
2054 | static stmt_t remove_post_inc_stores(const stmt_t &_s) { |
2055 | auto stores = find_objects<store_t>(_s); |
2056 | auto s = _s; |
2057 | for (auto &_store : stores) { |
2058 | auto &store = _store.as<store_t>(); |
2059 | if (!contains_object(store.value, store.buf)) continue; |
2060 | s = substitute(s, store, stmt_t()); |
2061 | } |
2062 | return s; |
2063 | } |
2064 | |
    const conv_config_t &cfg_; // Kernel config (FMA kind, SBID assignment, ...).
    ir_context_t &ir_ctx_; // IR context used by the transformation passes.
    // SLM size reserved for A/B buffering; NOTE(review): presumably bytes —
    // confirm against the constructor/callers (not visible here).
    int ab_slm_size_;

    stmt_t root_; // Root statement this injector transforms.
    alloc_manager_t alloc_mgr_; // Buffer allocation manager.
    compute_step_t step_; // Classifies lets/statements (preload vs mul).
    compute_loop_nest_t loop_nest_; // Description of the compute loop nest.
    compute_params_t params_; // Pipeline parameters (e.g. use_slm).
    slm_sync_manager_t slm_sync_mgr_; // Emits SLM fences/barriers per iteration.

    std::vector<buffer_info_t> g2s_reg_bufs_; // For SLM buffering.
    // When set, the first multiplication overwrites the accumulator instead
    // of accumulating, making explicit zero-out redundant
    // (see sub_fma_acc_with_zero()).
    bool fuse_zero_out_with_fma_ = false;
2078 | }; |
2079 | |
2080 | stmt_t inject_unrolling(const stmt_t &s, ir_context_t &ir_ctx, |
2081 | const conv_config_t &cfg, int ab_slm_size) { |
2082 | trace_start(); |
2083 | auto ret = unrolling_injector_t(s, cfg, ir_ctx, ab_slm_size).inject(); |
2084 | trace_pass("inject_unrolling" , ret, ir_ctx); |
2085 | return ret; |
2086 | } |
2087 | |
2088 | } // namespace jit |
2089 | } // namespace gpu |
2090 | } // namespace impl |
2091 | } // namespace dnnl |
2092 | |