/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/conv/pipeline.hpp"

#include "gpu/jit/ir/message.hpp"
#include "gpu/jit/ir/reorder.hpp"
#include "gpu/jit/utils/trace.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

28// Helper structure for for_t.
29struct loop_info_t {
30 loop_info_t() = default;
31
32 loop_info_t(const stmt_t &s) {
33 ir_assert(s.is<for_t>()) << s;
34 auto &loop = s.as<for_t>();
35 stmt = s;
36 var = loop.var;
37 init_ = loop.init;
38 bound_ = loop.bound;
39
40 auto e_size = simplify(bound_ - init_);
41 ir_assert(is_const(e_size));
42 size_ = to_cpp<int>(e_size);
43 }
44
45 int init() const {
46 ir_assert(is_const(init_));
47 return to_cpp<int>(init_);
48 }
49
50 int bound() const {
51 ir_assert(is_const(bound_));
52 return to_cpp<int>(bound_);
53 }
54
55 int size() const { return size_; }
56 const stmt_t &body() const { return stmt.as<for_t>().body; }
57 int unroll() const { return stmt.as<for_t>().unroll; }
58
59 stmt_t stmt;
60 expr_t var;
61
62private:
63 expr_t init_;
64 expr_t bound_;
65 int size_;
66};
67
// Iterates through multiple nested loops with fixed bounds. Used to unroll
// such nested loops.
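//
// For example, given an innermost loop i in [0, 3) and an outermost loop
// j in [0, 2) (passed in that order), the iterator starts at (i, j) = (0, 0)
// and successive advance() calls step through (1, 0), (2, 0), (0, 1), (1, 1),
// (2, 1): the innermost index is incremented first and wraps around to its
// init value before the next index is incremented.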
70class multi_loop_iterator_t {
71public:
72 // Ordered from innermost to outermost.
73 multi_loop_iterator_t(const std::vector<loop_info_t> &loops)
74 : loops_(loops) {
75 for (auto &l : loops)
76 var_values_.push_back(l.init());
77 }
78
79 int var_value(const expr_t &var) const {
80 for (size_t i = 0; i < loops_.size(); i++) {
81 if (loops_[i].var.is_same(var)) return var_values_[i];
82 }
83 ir_error_not_expected();
84 return 0;
85 }
86
87 void advance(int n = 1) {
88 if (loops_.empty()) return;
89 for (int i_n = 0; i_n < n; i_n++) {
90 for (size_t i = 0; i < loops_.size(); i++) {
91 auto &l = loops_[i];
92 if (++var_values_[i] < l.bound()) break;
93 var_values_[i] = l.init();
94 }
95 ir_assert(var_values_.back() < loops_.back().bound());
96 }
97 }
98
99 bool is_outer_loop_end() const {
100 if (loops_.empty()) return true;
101 for (size_t i = 0; i < loops_.size() - 1; i++) {
102 auto &l = loops_[i];
103 if (var_values_[i] != l.bound() - 1) return false;
104 }
105 return true;
106 }
107
108 std::string str() const {
109 std::ostringstream oss;
110 oss << "multi_loop_iterator_t(";
111 for (size_t i = 0; i < loops_.size(); i++) {
112 oss << (i != 0 ? ", " : "");
113 oss << loops_[i].var << " = " << var_values_[i];
114 }
115 oss << ")";
116 return oss.str();
117 }
118
119 IR_DEFINE_DUMP()
120
121private:
122 std::vector<loop_info_t> loops_;
123 std::vector<int> var_values_;
124};
125
126// Extracts different parts of the compute iteration and verifies the loop nest
127// is properly formed and can be further injected with SLM buffering.
128class compute_step_visitor_t : public ir_visitor_t {
129public:
130 stmt_t find_stmt_group(const stmt_label_t &label) const {
131 auto groups = find_stmt_groups(label);
132 if (groups.empty()) return stmt_t();
133 ir_assert(groups.size() == 1);
134 return groups[0];
135 }
136
137 std::vector<stmt_t> find_stmt_groups(const stmt_label_t &label) const {
138 std::vector<stmt_t> ret;
139 for (auto &_g : stmt_groups_) {
140 auto &g = _g.as<stmt_group_t>();
141 if (g.label == label) ret.push_back(_g);
142 }
143 return ret;
144 }
145
146 const std::vector<stmt_t> &inner_let_stmts() const {
147 return inner_let_stmts_;
148 }
149
150#define HANDLE_IR_OBJECT(type) \
151 void _visit(const type &obj) override { visit_stmt(obj); }
152
153 HANDLE_STMT_IR_OBJECTS()
154
155#undef HANDLE_IR_OBJECT
156
157 template <typename T>
158 void visit_stmt(const T &obj) {
159 bool is_for = obj.template is<for_t>();
160 bool is_stmt_group = obj.template is<stmt_group_t>();
161 bool is_let = obj.template is<let_t>();
162 bool is_stmt_seq = obj.template is<stmt_seq_t>();
163
164 // Loop may contain:
165 // - Another loop
166 // - Container statement (stmt_seq_t or stmt_group_t)
167 // - Let statement (in the innermost loop only)
168 // - Barrier
169 if (loop_level_ > 0) {
170 bool ok = false;
171 if (is_for || is_let || is_stmt_group || is_stmt_seq) {
172 ok = true;
173 } else if (obj.template is<func_call_t>()) {
174 auto &call = obj.template as<func_call_t>();
175 ok = call.func.is_equal(funcs::barrier_func());
176 }
177
178 if (!ok) {
179 ir_error_not_expected()
180 << "Found unexpected statement inside loop.\n"
181 << stmt_t(obj);
182 }
183 }
184
185 bool is_compute_loop = false;
186 if (is_stmt_group) {
187 auto label = obj.template as<stmt_group_t>().label;
188 stmt_groups_.push_back(obj);
189 if (utils::one_of(label, stmt_label_t::g2s_load(),
190 stmt_label_t::g2s_store(), stmt_label_t::g2r_load(),
191 stmt_label_t::s2r_load(), stmt_label_t::prefetch(),
192 stmt_label_t::mul())) {
193 // Leaf labels, do not visit them.
194 return;
195 }
196 if (label == stmt_label_t::compute_loop()) {
197 is_compute_loop = true;
198 in_compute_loop_ = true;
199 }
200 }
201
202 if (is_for && in_compute_loop_) loop_level_++;
203 found_loop_ = false;
204 ir_visitor_t::_visit(obj);
205 if (in_compute_loop_ && is_let) {
206 if (found_loop_)
207 ir_error_not_expected()
208 << "Let is allowed in the innermost loop only.";
209
210 inner_let_stmts_.push_back(replace_stmt_body(obj, stmt_t()));
211 }
212 if (is_for && in_compute_loop_) {
213 loop_level_--;
214 found_loop_ = true;
215 }
216
217 if (is_compute_loop) in_compute_loop_ = false;
218 }
219
220private:
221 bool found_loop_ = false;
222 bool in_compute_loop_ = false;
223 int loop_level_ = 0;
224
225 std::vector<stmt_t> stmt_groups_;
226 std::vector<stmt_t> inner_let_stmts_;
227};
228
229// Provides access to different parts of the inner compute iteration.
230class compute_step_t {
231public:
232 compute_step_t(const stmt_t &parent) {
233 compute_step_visitor_t v;
234 v.visit(parent);
235
236 compute_loop_ = v.find_stmt_group(stmt_label_t::compute_loop());
237 g2s_load_ = v.find_stmt_group(stmt_label_t::g2s_load());
238 g2s_store_ = v.find_stmt_group(stmt_label_t::g2s_store());
239 prefetch_ = v.find_stmt_group(stmt_label_t::prefetch());
240 g2r_load_ = v.find_stmt_groups(stmt_label_t::g2r_load());
241 s2r_load_ = v.find_stmt_groups(stmt_label_t::s2r_load());
242 mul_ = v.find_stmt_groups(stmt_label_t::mul());
243 c_zero_out_ = v.find_stmt_group(stmt_label_t::c_zero_out());
244 inner_let_stmts_ = v.inner_let_stmts();
245
246 ir_assert(g2r_load_.size() == mul_.size());
247 ir_assert(s2r_load_.size() == mul_.size());
248
249 // Assign preload/mul tags to let statements.
250 for (auto &_let : inner_let_stmts_) {
251 auto &var = _let.as<let_t>().var;
252 bool is_preload = (count_object(g2s_load_, var) > 0)
253 || (count_object(prefetch_, var) > 0);
254 bool is_mul = count_object(g2r_load_, var) > 0
255 || count_object(mul_, var) > 0;
256 if (is_preload) preload_lets_.insert(_let);
257 if (is_mul) mul_lets_.insert(_let);
258 }
259
        // Propagate preload/mul tags up based on dependencies between let
        // statements.
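        // For example, if `let b = a * 2` is referenced from a mul statement
        // group and `let a = ...` is referenced only through `b`, then `a` is
        // tagged as a mul let as well.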
262 std::vector<let_info_t> let_infos;
263 object_set_t<stmt_t> seen;
264 std::function<void(const stmt_t &)> propagate;
265 propagate = [&](const stmt_t &_let) {
266 if (seen.count(_let) > 0) return;
267 auto &let = _let.as<let_t>();
268 for (auto &_child : inner_let_stmts_) {
269 auto &child = _child.as<let_t>();
270 if (_child.is_same(_let)) continue;
271 if (contains_object(child.value, let.var)) {
272 // Visit child let statements first.
273 propagate(_child);
274 // Propagate child preload/mul values to this let statement.
275 if (is_preload_let(_child)) preload_lets_.insert(_let);
276 if (is_mul_let(_child)) mul_lets_.insert(_let);
277 }
278 }
279 auto let_info = create_let_info(
280 let, is_preload_let(_let), is_mul_let(_let));
281 let_infos.push_back(let_info);
282 seen.insert(_let);
283 };
284 for (auto &_let : inner_let_stmts_)
285 propagate(_let);
286
287 // Duplicate lets that are used in both preload and mul contexts.
288 duplicate_lets(let_infos);
289 }
290
291 // See ir_core.hpp for the description.
292 const stmt_t &compute_loop() const { return compute_loop_; }
293 const stmt_t &g2s_load() const { return g2s_load_; }
294 const stmt_t &g2s_store() const { return g2s_store_; }
295 const stmt_t &prefetch() const { return prefetch_; }
296 const std::vector<stmt_t> &g2r_load() const { return g2r_load_; }
297 const std::vector<stmt_t> &s2r_load() const { return s2r_load_; }
298 const std::vector<stmt_t> &mul() const { return mul_; }
299 const stmt_t &c_zero_out() const { return c_zero_out_; }
300 const std::vector<stmt_t> &inner_let_stmts() const {
301 return inner_let_stmts_;
302 }
303
304 bool is_preload_let(const stmt_t &s) const {
305 return preload_lets_.count(s) > 0;
306 }
307 bool is_mul_let(const stmt_t &s) const { return mul_lets_.count(s) > 0; }
308
309private:
310 struct let_info_t {
311 let_info_t(const expr_t &var) : var(var) {}
312
313 expr_t var;
314 expr_t preload_var;
315 expr_t mul_var;
316
317 bool is_preload() const { return !preload_var.is_empty(); }
318 bool is_mul() const { return !mul_var.is_empty(); }
319
320 bool needs_update() const { return is_preload() && is_mul(); }
321 };
322
323 let_info_t create_let_info(const let_t &let, bool is_preload, bool is_mul) {
324 let_info_t info(let.var);
325 if (is_preload && !is_mul) {
326 info.preload_var = let.var;
327 } else if (!is_preload && is_mul) {
328 info.mul_var = let.var;
329 } else if (is_preload && is_mul) {
330 info.preload_var = create_var_with_suffix(let.var, "p");
331 info.mul_var = create_var_with_suffix(let.var, "m");
332 }
333 return info;
334 }
335
336 void duplicate_lets(const std::vector<let_info_t> &let_infos) {
337 int nlets = int(inner_let_stmts_.size());
338 ir_assert(int(let_infos.size()) == nlets);
339
340 std::vector<stmt_t> new_lets;
341 for (int i = nlets - 1; i >= 0; i--) {
342 auto &info = let_infos[i];
343 auto &old_let = inner_let_stmts_[i].as<let_t>();
344 if (!info.needs_update()) {
345 auto new_value = update_var(old_let.value, let_infos,
346 info.is_preload(), info.is_mul());
347 auto new_let = inner_let_stmts_[i];
348 if (!new_value.is_same(old_let.value)) {
349 new_let = let_t::make(old_let.var, new_value, old_let.body);
350 }
351 new_lets.push_back(new_let);
352 continue;
353 }
354
355 preload_lets_.erase(&old_let);
356 mul_lets_.erase(&old_let);
357
358 auto preload_value
359 = update_var(old_let.value, let_infos, true, false);
360 auto preload_let = let_t::make(
361 info.preload_var, preload_value, old_let.body);
362
363 auto mul_value = update_var(old_let.value, let_infos, false, true);
364 auto mul_let = let_t::make(info.mul_var, mul_value, old_let.body);
365
366 preload_lets_.insert(preload_let);
367 new_lets.push_back(preload_let);
368
369 mul_lets_.insert(mul_let);
370 new_lets.push_back(mul_let);
371
372 // Update statements.
373 g2s_load_ = update_var(g2s_load_, let_infos, true, false);
374 g2s_store_ = update_var(g2s_store_, let_infos, true, false);
375 prefetch_ = update_var(prefetch_, let_infos, true, false);
376 g2r_load_ = update_var(g2r_load_, let_infos, false, true);
377 s2r_load_ = update_var(s2r_load_, let_infos, false, true);
378 mul_ = update_var(mul_, let_infos, false, true);
379 }
380
381 std::reverse(new_lets.begin(), new_lets.end());
382 inner_let_stmts_ = new_lets;
383 }
384
385 template <typename T>
386 static std::vector<T> update_var(const std::vector<T> &vec,
387 const std::vector<let_info_t> &let_infos, bool is_preload,
388 bool is_mul) {
389 std::vector<T> ret;
390 for (auto &v : vec)
391 ret.push_back(update_var(v, let_infos, is_preload, is_mul));
392 return ret;
393 }
394
395 static object_t update_var(const object_t &obj,
396 const std::vector<let_info_t> &let_infos, bool is_preload,
397 bool is_mul) {
398 auto ret = obj;
399 for (auto &info : let_infos) {
400 if (!info.needs_update()) continue;
401 if (!contains_object(ret, info.var)) continue;
402 if (is_preload) {
403 ir_assert(info.is_preload());
404 ret = substitute(ret, info.var, info.preload_var);
405 } else if (is_mul) {
406 ir_assert(info.is_mul());
407 ret = substitute(ret, info.var, info.mul_var);
408 }
409 }
410 return ret;
411 }
412
413 static expr_t create_var_with_suffix(
414 const expr_t &_var, const std::string &suffix) {
415 auto &var = _var.as<var_t>();
416 auto new_name = var.name + "_" + suffix;
417 return var_t::make(var.type, new_name);
418 }
419
420 stmt_t compute_loop_;
421 stmt_t g2s_load_;
422 stmt_t g2s_store_;
423 stmt_t prefetch_;
424 std::vector<stmt_t> g2r_load_;
425 std::vector<stmt_t> s2r_load_;
426 std::vector<stmt_t> mul_;
427 stmt_t c_zero_out_;
428
429 std::vector<stmt_t> inner_let_stmts_;
430
    // Due to loop unrolling the inner let statements may depend on different
    // indices of the outer loops. There are two contexts:
    // - "preload" loop iteration, e.g. index I
    // - "multiplication" loop iteration, e.g. index (I + nbuf)
    // Preloads (either via SLM or via prefetches) for the corresponding
    // multiplication are executed several iterations before the real
    // multiplication. That's why we need to know exactly in which context a
    // given let statement is used. The same variable may be used in two
    // different contexts; in that case it is duplicated and initialized with
    // different values for each context.
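    //
    // For example, a let variable `a_off` (a hypothetical name) used by both
    // a prefetch statement and a mul statement is split into `a_off_p`
    // (preload context) and `a_off_m` (mul context), and the dependent
    // statements are rewritten to use the corresponding copy (see
    // duplicate_lets()).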
441 object_set_t<stmt_t> preload_lets_;
442 object_set_t<stmt_t> mul_lets_;
443};
444
// Helper class to access the outer loop index after pipelining. Pipelining
// in general requires tracking two versions of a loop index:
// - Multiplication version - corresponding to the iteration that is currently
//   used for multiplication
// - Preload version - corresponding to the iteration that is currently used
//   for preload for one of the next multiplications
// The multiplication version is a few steps behind the preload version.
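//
// With the unrolled pipeline below, the preload index runs
// max(0, preload_bufs - 1) + max(0, gmem_bufs - 1) iterations ahead of the
// multiplication index (see compute_iterator_t).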
452class outer_loop_info_t : public loop_info_t {
453public:
454 outer_loop_info_t() = default;
455
456 outer_loop_info_t(const stmt_t &s, ir_context_t &ir_ctx) : loop_info_t(s) {
        // The outer loop might not be unrolled, hence the unrolled loop
        // iterations must not reference its index directly. If this doesn't
        // hold, introduce GRF buffers to represent that variable and apply
        // post-increment updates after each outer loop iteration.
461 if (count_object(s.as<for_t>().body, var) != 0) {
462 has_var_refs_ = true;
463 mul_var_buf_ = ir_ctx.create_tmp_var(
464 type_t::byte_ptr(), var.as<var_t>().name + "_mul_buf");
465 preload_var_buf_ = ir_ctx.create_tmp_var(
466 type_t::byte_ptr(), var.as<var_t>().name + "_preload_buf");
467
468 auto mul_alloc = alloc_t::make(
469 mul_var_buf_, var.type().size(), alloc_kind_t::grf);
470 auto preload_alloc = alloc_t::make(
471 preload_var_buf_, var.type().size(), alloc_kind_t::grf);
472 allocs_.push_back(mul_alloc);
473 allocs_.push_back(preload_alloc);
474
475 auto mul_init = store_t::make(mul_var_buf_, 0, init());
476 auto preload_init = store_t::make(preload_var_buf_, 0, init());
477 init_stmt_ = mul_init.append(preload_init);
478
479 mul_post_inc_stmt_
480 = store_t::make(mul_var_buf_, 0, mul_var_load() + 1);
481 preload_post_inc_stmt_ = store_t::make(
482 preload_var_buf_, 0, preload_var_load() + 1);
483 }
484 }
485
486 bool has_var_refs() const { return has_var_refs_; }
487
488 expr_t mul_var_load() const {
489 return load_t::make(var.type(), mul_var_buf_, 0);
490 }
491 expr_t preload_var_load() const {
492 return load_t::make(var.type(), preload_var_buf_, 0);
493 }
494
495 stmt_t inject_alloc_stmts(const stmt_t &stmt) const {
496 return jit::inject_alloc_stmts(stmt, allocs_);
497 }
498
499 const stmt_t &init_stmt() const { return init_stmt_; }
500
501 const stmt_t &mul_post_inc_stmt() const { return mul_post_inc_stmt_; }
502 const stmt_t &preload_post_inc_stmt() const {
503 return preload_post_inc_stmt_;
504 }
505
506private:
507 bool has_var_refs_ = false;
508
509 // Helper expressions/statements to partially unroll the loop.
510 expr_t mul_var_buf_;
511 expr_t preload_var_buf_;
512 std::vector<stmt_t> allocs_;
513 stmt_t init_stmt_;
514 stmt_t mul_post_inc_stmt_;
515 stmt_t preload_post_inc_stmt_;
516};
517
518class compute_loop_nest_visitor_t : public ir_visitor_t {
519public:
520 int compute_loop_level() const { return compute_loop_level_; }
521
522 const std::vector<loop_info_t> &loops() const { return loops_; }
523
524 void _visit(const stmt_group_t &obj) override {
525 bool is_compute_loop = (obj.label == stmt_label_t::compute_loop());
526 if (is_compute_loop) {
527 in_compute_loop_ = true;
528 compute_loop_level_ = level_;
529 }
530 ir_visitor_t::_visit(obj);
531 if (is_compute_loop) in_compute_loop_ = false;
532 }
533
534 void _visit(const for_t &obj) override {
535 level_++;
536 ir_visitor_t::_visit(obj);
537 if (in_compute_loop_) loops_.emplace_back(obj);
538 level_--;
539 }
540
541private:
542 bool in_compute_loop_ = false;
543 int compute_loop_level_ = -1;
544 std::vector<loop_info_t> loops_;
545 int level_ = 0;
546};
547
548// Helper class to work with loop nest of the compute loop.
549class compute_loop_nest_t {
550public:
551 compute_loop_nest_t() = default;
552
553 compute_loop_nest_t(const stmt_t &root, ir_context_t &ir_ctx)
554 : root_(root) {
555 compute_loop_nest_visitor_t visitor;
556 visitor.visit(root);
557
558 compute_loop_level_ = visitor.compute_loop_level();
559 loops_ = visitor.loops();
560
561 if (loops_.empty()) {
562 outer_loop_size_ = 1;
563 return;
564 }
565
566 outer_loop_ = outer_loop_info_t(loops_.back().stmt, ir_ctx);
567 outer_loop_size_ = outer_loop_.size();
568 }
569
570 // Returns the loop level of the compute_loop statement group corresponding
571 // to the number of outer loops.
572 int compute_loop_level() const { return compute_loop_level_; }
573
574 // Returns loops inside compute_loop statement group.
575 const std::vector<loop_info_t> &loops() const { return loops_; }
576
577 // Number of iterations of all loops.
578 int size() const {
579 int ret = 1;
580 for (auto &l : loops_)
581 ret *= l.size();
582 return ret;
583 }
584
585 // Number of iterations in the outermost loop (see comments in ctor).
586 int outer_loop_size() const { return outer_loop_size_; }
587
588 const outer_loop_info_t &outer_loop_info() const { return outer_loop_; }
589
590 template <typename F>
591 void for_each_loop_var(const F &f) const {
592 for (auto &l : loops_)
593 f(l.var);
594 }
595
596 // Number of iterations of all loops except the outermost.
597 int inner_loops_size() const { return size() / outer_loop_size(); }
598
599private:
600 stmt_t root_;
601 int compute_loop_level_ = -1;
602 std::vector<loop_info_t> loops_;
603
604 int outer_loop_size_;
605 outer_loop_info_t outer_loop_;
606};
607
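// Parameters of the pipelined compute loop. `unroll` is the least common
// multiple of the buffer rotation period (slm_bufs * gmem_bufs with SLM
// buffering, prefetch_bufs with prefetch) and the number of inner loop
// iterations; e.g. slm_bufs = 3, gmem_bufs = 2 and 4 inner loop iterations
// give unroll = lcm(6, 4) = 12.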
608struct compute_params_t {
609 compute_params_t() = default;
610
611 compute_params_t(int slm_bufs, int gmem_bufs, int slm_buf_size,
612 int prefetch_bufs, int inner_loops_iters)
613 : slm_bufs(slm_bufs)
614 , gmem_bufs(gmem_bufs)
615 , slm_buf_size(slm_buf_size)
616 , prefetch_bufs(prefetch_bufs) {
617 use_slm = (slm_buf_size > 0);
618 use_prefetch = (prefetch_bufs > 0);
619 ir_assert(!use_slm || !use_prefetch)
620 << "Can't have both SLM buffering and prefetch enabled.";
621 if (use_slm) {
622 ir_assert(utils::one_of(slm_bufs, 1, 2, 3));
623 ir_assert(utils::one_of(gmem_bufs, 1, 2));
624 preload_bufs = slm_bufs;
625 unroll = math::lcm(slm_bufs * gmem_bufs, inner_loops_iters);
626 } else if (use_prefetch) {
627 preload_bufs = prefetch_bufs;
628 ir_assert(slm_bufs == 0);
629 ir_assert(gmem_bufs == 0);
630 unroll = math::lcm(prefetch_bufs, inner_loops_iters);
631 } else {
632 preload_bufs = 0;
633 ir_assert(slm_bufs == 0);
634 ir_assert(gmem_bufs == 0);
635 unroll = inner_loops_iters;
636 }
637 }
638
639 int slm_bufs;
640 int gmem_bufs;
641 int slm_buf_size;
642 int prefetch_bufs;
643 int preload_bufs;
644 int unroll;
645
646 bool use_slm;
647 bool use_prefetch;
648};
649
// Helper class to implement the compute iteration schedule for SLM buffering
// and prefetch pipelining.
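//
// For example, with slm_bufs = 3 and gmem_bufs = 2 (so preload_bufs = 3 and
// unroll = 6 for 6 inner loop iterations) and 12 compute iterations:
//   iters           = 12 + (3 - 1) + (2 - 1)         = 15
//   ramp_up_iters   = max(1, 3 + (2 - 1))            = 4
//   ramp_down_iters = min((3 - 1) + (2 - 1), 15 - 4) = 3
//   body_iters      = rnd_dn(15 - 4 - 3, 6)          = 6
//   ramp_down_iters = 15 - 4 - 6                     = 5 (recomputed)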
651class compute_iterator_t {
652public:
653 compute_iterator_t(const compute_params_t &params,
654 const compute_loop_nest_t &loop_nest)
655 : params(params)
656 , preload_loop_it(loop_nest.loops())
657 , mul_loop_it(loop_nest.loops()) {
658
659 int compute_iters = loop_nest.size();
660 iters = compute_iters;
661 ir_assert(iters >= 1) << "Empty loop is not expected.";
662
663 iters += std::max(0, preload_bufs() - 1) + std::max(0, gmem_bufs() - 1);
664 ramp_up_iters
665 = std::max(1, preload_bufs() + std::max(0, gmem_bufs() - 1));
666 ramp_down_iters = std::min(
667 std::max(0, preload_bufs() - 1) + std::max(0, gmem_bufs() - 1),
668 iters - ramp_up_iters);
669 body_iters = iters - ramp_up_iters - ramp_down_iters;
670 body_iters = utils::rnd_dn(body_iters, params.unroll);
671 ramp_down_iters = iters - ramp_up_iters - body_iters;
672
673 ir_assert(ramp_up_iters + body_iters + ramp_down_iters == iters);
674
675 iter = 0;
676 linear_id = 0;
677 riter = iters - 1;
678 }
679
680 int unroll() const { return params.unroll; }
681
682 int preload_bufs() const { return params.preload_bufs; }
683
684 int slm_bufs() const { return params.slm_bufs; }
685
686 int gmem_bufs() const { return params.gmem_bufs; }
687
688 compute_iterator_t &operator++() {
689 if (do_preload()) preload_loop_it.advance();
690 if (do_mul()) mul_loop_it.advance();
691 ++iter;
692 ++linear_id;
693 --riter;
694 return *this;
695 }
696
697 void advance(int n) {
698 if (n == 0) return;
699
700 ir_assert(n % params.unroll == 0);
701 ir_assert(iter + n <= iters);
702
703 if (preload_bufs() > 0) ir_assert(do_preload());
704 ir_assert(do_mul());
705
706 iter += n;
707 riter -= n;
708
709 if (preload_bufs() > 0) preload_loop_it.advance(n);
710 mul_loop_it.advance(n);
711 }
712
713 bool do_mul() const {
714 return iter >= std::max(0, preload_bufs() - 1)
715 + std::max(0, gmem_bufs() - 1);
716 }
717
718 bool is_first_mul() const {
719 return iter
720 == std::max(0, preload_bufs() - 1)
721 + std::max(0, gmem_bufs() - 1);
722 }
723 bool is_last_mul() const { return riter == 0; }
724
725 bool is_last_g2s_store() const {
726 if (!do_g2s_store()) return false;
727 return riter == slm_bufs() - 1;
728 }
729
730 bool is_last_preload() const {
731 if (!do_preload()) return false;
732 return riter == (preload_bufs() - 1) + std::max(0, gmem_bufs() - 1);
733 }
734
735 bool is_last_g2s_load() const {
736 if (!do_g2s_load()) return false;
737 return is_last_preload();
738 }
739
740 bool is_last_prefetch() const {
741 if (!do_prefetch()) return false;
742 return is_last_preload();
743 }
744
745 bool do_preload() const {
746 if (preload_bufs() == 0) return false;
747 return riter >= (preload_bufs() - 1) + std::max(0, gmem_bufs() - 1);
748 }
749
750 bool do_prefetch() const {
751 if (!params.use_prefetch) return false;
752 return do_preload();
753 }
754
755 bool do_g2s_load() const {
756 if (!params.use_slm) return false;
757 return do_preload();
758 }
759
760 bool do_g2s_store() const {
761 if (!params.use_slm) return false;
762 ir_assert(gmem_bufs() >= 1);
763 return iter >= (gmem_bufs() - 1) && riter >= (slm_bufs() - 1);
764 }
765
766 int gmem_write_buf_index() const {
767 ir_assert(do_g2s_load());
768 return iter % gmem_bufs();
769 }
770
771 int gmem_read_buf_index() const {
772 ir_assert(do_g2s_store());
773 return (iter - (gmem_bufs() - 1)) % gmem_bufs();
774 }
775
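    // The SLM read/write offsets cycle through the SLM buffers: e.g. with
    // slm_bufs = 3 and slm_buf_size = S the returned updates follow the
    // pattern +S, +S, -2*S, repeating every three iterations.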
776 int slm_read_offset_update() const {
777 ir_assert(params.use_slm);
778 ir_assert(do_mul());
779
780 int slm_iter = iter - (gmem_bufs() - 1) - (slm_bufs() - 1);
781 int cur_slm_idx = slm_iter % slm_bufs();
782 int next_slm_idx = (slm_iter + 1) % slm_bufs();
783 int ret = next_slm_idx * params.slm_buf_size
784 - cur_slm_idx * params.slm_buf_size;
785 return ret;
786 }
787
788 int slm_write_offset_update() const {
789 ir_assert(params.use_slm);
790 ir_assert(do_g2s_store());
791
792 int slm_iter = iter - (gmem_bufs() - 1);
793 int cur_slm_idx = slm_iter % slm_bufs();
794 int next_slm_idx = (slm_iter + 1) % slm_bufs();
795 int ret = next_slm_idx * params.slm_buf_size
796 - cur_slm_idx * params.slm_buf_size;
797 return ret;
798 }
799
800 compute_params_t params;
801 multi_loop_iterator_t preload_loop_it;
802 multi_loop_iterator_t mul_loop_it;
803
804 // ramp_up_iters + body_iters + ramp_down_iters == iters
805 int iters;
806 int ramp_up_iters;
807 int body_iters;
808 int ramp_down_iters;
809
810 // Invariant: iter + riter = iters - 1
811 int iter;
812 int riter;
813
814 int linear_id;
815};
816
// Basic LRU SBID allocator that tries to use the same SBIDs for the same GRF
// buffers.
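//
// get_sbid() returns a stable SBID for a repeated (buffer, index) key; when
// no free SBID is left, the least recently used entry is evicted and its SBID
// is reused.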
819class sbid_manager_t {
820public:
821 sbid_manager_t(ngen::HW hw = ngen::HW::Unknown)
822 : sbid_count_(hw >= ngen::HW::XeHPC ? 32 : 16)
823 , tuple_func_(builtin_t::make("tuple")) {}
824
825 ngen_proxy::SBID get_sbid(const expr_t &buf, int index = 0) {
826 auto key = tuple_func_.call({buf, expr_t(index)});
827
828 int free_idx = -1;
829 for (int i = 0; i < sbid_count_; i++) {
830 auto &e = entries_[i];
831 if (key.is_equal(e.key)) {
832 e.time = cur_time_++;
833 return ngen_proxy::SBID(i);
834 }
835 if (free_idx == -1 && e.key.is_empty()) free_idx = i;
836 }
837
838 // Not found but there is a free SBID.
839 if (free_idx != -1) {
840 entries_[free_idx] = {key, cur_time_++};
841 return ngen_proxy::SBID(free_idx);
842 }
843
844 // Find the oldest SBID and use it.
845 int old_idx = 0;
846 int old_time = entries_[0].time;
847 for (int i = 1; i < sbid_count_; i++) {
848 if (entries_[i].time < old_time) {
849 old_idx = i;
850 old_time = entries_[i].time;
851 }
852 }
853
854 entries_[old_idx] = entry_t({key, cur_time_++});
855 return ngen_proxy::SBID(old_idx);
856 }
857
858private:
859 struct entry_t {
860 stmt_t key;
861 int time;
862 };
863
864 static const int max_sbid_count = 32;
865 std::array<entry_t, max_sbid_count> entries_;
866
867 int sbid_count_ = 0;
868 func_t tuple_func_;
869 int cur_time_ = 0;
870};
871
872// Helper to assign SBIDs to IR function calls.
873class sbid_assigner_t {
874public:
875 sbid_assigner_t(ngen::HW hw) : local_sbid_mgr_(hw) {}
876
877 sbid_assigner_t(sbid_manager_t &external_sbid_mgr)
878 : external_sbid_mgr_(&external_sbid_mgr) {}
879
880 stmt_t assign(const stmt_t &stmt) {
881 auto stmt_vec = flatten_statements(stmt);
882 stmt_t ret = stmt;
883 int prefetch_idx = 0;
884 for (auto &_s : stmt_vec) {
885 if (!_s.is<func_call_t>()) continue;
886 auto s = _s;
887 if (is_func_call<send_t>(s)) {
888 auto &send = s.as<func_call_t>().func.as<send_t>();
889 int idx = (send.is_prefetch() || send.is_prefetch_2d()
890 ? prefetch_idx++
891 : 0);
892 auto sbid = get_sbid(send_t::arg_reg_buf(s), idx);
893 s = update_call_with_sbid(s, sbid);
894 } else if (is_func_call<dpas_t>(s)) {
895 auto &c = s.as<func_call_t>();
896 auto *mod_attr = c.attr.as_ptr<instruction_modifier_attr_t>();
897 if (!c.func.as<dpas_t>().is_dp4a() && // dp4a-s do not need SBID
898 (!mod_attr || !mod_attr->mod.is_atomic)) {
899 // Last dpas in Atomic chain.
900 auto sbid = get_sbid(dpas_t::arg_src1(s));
901 s = update_call_with_sbid(s, sbid);
902 }
903 } else if (s.is<func_call_t>()) {
904 auto &c = s.as<func_call_t>();
905 if (c.func.is_equal(funcs::signal_func())
906 || c.func.is_equal(funcs::slm_fence_func())
907 || c.func.is_equal(funcs::barrier_func())) {
908 // Use 0 as the key for signals and SLM fences.
909 auto sbid = get_sbid(expr_t(0));
910 s = update_call_with_sbid(s, sbid);
911 }
912 } else {
913 ir_error_not_expected() << s;
914 }
915 ret = substitute(ret, _s, s);
916 }
917 return ret;
918 }
919
920private:
921 ngen_proxy::SBID get_sbid(const expr_t &ptr, int index = 0) {
922 auto &sbid_mgr
923 = (external_sbid_mgr_ ? *external_sbid_mgr_ : local_sbid_mgr_);
924 return sbid_mgr.get_sbid(ptr, index);
925 }
926
927 static stmt_t update_call_with_sbid(
928 const stmt_t &s, const ngen_proxy::SBID &sbid) {
929 return instruction_modifier_attr_t::make(
930 ngen_proxy::InstructionModifier().with_sbid(sbid))
931 .apply_to(s);
932 }
933
934 sbid_manager_t local_sbid_mgr_;
935 sbid_manager_t *external_sbid_mgr_ = nullptr;
936};
937
// Workaround for the limited scoping functionality in the current generator:
// prepends a given prefix to all newly created var_t names.
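//
// For example, with prefix "prefetch_" a loop variable `i` in the cloned
// prologue loop is replaced with a fresh variable `prefetch_i` (see the
// prepend_new_vars() call in pipeline() below).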
940class var_prepender_t : public ir_mutator_t {
941public:
942 var_prepender_t(const std::string &prefix) : prefix_(prefix) {}
943 object_t _mutate(const for_t &obj) override {
944 auto new_obj = ir_mutator_t::_mutate(obj);
945 auto new_var = var_t::make(
946 obj.var.type(), prefix_ + obj.var.as<var_t>().name);
947 new_obj = substitute(new_obj, obj.var, new_var);
948 return new_obj;
949 }
950 object_t _mutate(const let_t &obj) override {
951 auto new_obj = ir_mutator_t::_mutate(obj);
952 auto new_var = var_t::make(
953 obj.var.type(), prefix_ + obj.var.as<var_t>().name);
954 new_obj = substitute(new_obj, obj.var, new_var);
955 return new_obj;
956 }
957
958private:
959 std::string prefix_;
960};
961
962object_t prepend_new_vars(const object_t &root, const std::string &prefix) {
963 var_prepender_t mutator(prefix);
964 return mutator.mutate(root);
965}
966
// Perform the pipelining transformation. The goal is to transform the loop
// structure from:
//
// for i in range(init, bound):
//     A_block(i);
//     B_block(i);
//
// to the following:
//
// for i in range(init, init + length):
//     A_block(i);
// for i in range(init, bound):
//     if (i < bound - length):
//         A_block(i + length);
//     B_block(i);
//
// Since A_block and B_block have to be independent to maintain correctness,
// this transform does not inspect the operations inside the loop and relies
// on a correct substitution for A_block and B_block.
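//
// For example, with init = 0, bound = 4 and length = 2 the result is:
//
// A_block(0); A_block(1);
// for i in range(0, 4):
//     if (i < 2):
//         A_block(i + 2);
//     B_block(i);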
986
987struct pipeline_ctx_t {
988 pipeline_ctx_t(const stmt_t &prologue, const stmt_t &body)
989 : prologue_(prologue), body_(body) {}
990 stmt_t stmt() const { return prologue_.append(body_); }
991 stmt_t prologue() { return prologue_; }
992 stmt_t body() { return body_; }
993
994private:
995 stmt_t prologue_;
996 stmt_t body_;
997};
998
999pipeline_ctx_t pipeline(
1000 int length, const loop_info_t &loop, stmt_t A_block, stmt_t B_block) {
1001
1002 expr_t idx = loop.var;
1003 int bound = loop.bound();
1004 int init = loop.init();
1005
1006 int pipe_len = std::min(init + length, bound);
1007
1008 stmt_t prologue = prepend_new_vars(
1009 for_t::make(idx, init, pipe_len, A_block,
1010 pipe_len <= loop.unroll() ? pipe_len : 1),
1011 "prefetch_");
1012
1013 expr_t A_idx = idx + pipe_len;
1014 stmt_t body = if_t::make(
1015 idx < (bound - pipe_len), substitute(A_block, idx, A_idx));
1016 body = body.append(B_block);
1017 body = for_t::make(idx, init, bound, body, loop.unroll());
1018
1019 return pipeline_ctx_t(prologue, body);
1020}
1021
1022class prefetch_pipeliner_t {
1023public:
1024 prefetch_pipeliner_t(
1025 const stmt_t &root, const conv_config_t &cfg, ir_context_t &ir_ctx)
1026 : root_(root), cfg_(cfg), ir_ctx_(ir_ctx) {}
1027 stmt_t inject() {
1028 auto compute_loop_stmt
1029 = find_stmt_group(root_, stmt_label_t::compute_loop());
1030 if (!compute_loop_stmt.has_value()) return root_;
1031 auto compute_loop = compute_loop_stmt.value();
1032 auto loop_nest = compute_loop_nest_t(compute_loop, ir_ctx_);
1033 auto &loops = loop_nest.loops();
1034
1035 // No loops to pipeline
1036 if (loops.size() == 0) return root_;
1037 auto &loop_body = loops[0].body();
1038
1039 auto A_block_stmt
1040 = find_stmt_group(loop_body, stmt_label_t::prefetch());
1041 if (!A_block_stmt.has_value()) return root_;
1042 auto A_block = A_block_stmt.value();
1043 auto B_block = remove_stmt_group(loop_body, stmt_label_t::prefetch());
1044 size_t prefetch_count = 0;
1045 size_t max_nested_prefetch = 2;
1046 for (size_t i = 0; i < loops.size(); i++) {
1047 if (prefetch_count < max_nested_prefetch) {
1048 if (!contains_object(A_block, loops[i].var)) {
1049 // No point in prefetching a constant in a loop
1050 B_block = for_t::make(loops[i].var, loops[i].init(),
1051 loops[i].bound(), B_block, loops[i].unroll());
1052 continue;
1053 }
1054
1055 auto next = pipeline(
1056 cfg_.prefetch().bufs(), loops[i], A_block, B_block);
1057 A_block = next.prologue();
1058 B_block = next.body();
1059 prefetch_count++;
1060
1061 } else {
1062 B_block = for_t::make(loops[i].var, loops[i].init(),
1063 loops[i].bound(), A_block.append(B_block),
1064 loops[i].unroll());
1065 A_block = stmt_t();
1066 }
1067 }
1068 return substitute(root_, compute_loop, A_block.append(B_block));
1069 }
1070
1071private:
1072 stmt_t root_;
1073 const conv_config_t &cfg_;
1074 ir_context_t &ir_ctx_;
1075};
1076
1077stmt_t inject_prefetch_pipeline(
1078 const stmt_t &s, ir_context_t &ir_ctx, const conv_config_t &cfg) {
1079 trace_start();
1080 auto ret = prefetch_pipeliner_t(s, cfg, ir_ctx).inject();
1081 trace_pass("inject_prefetch_pipeline", ret, ir_ctx);
1082 return ret;
1083}
1084
// Helper class to handle synchronization between threads for cooperative SLM
// loads and stores with double and triple buffering. Naming conventions:
1087// - Lx step - load from global memory to GRF (to be stored in SLM buffer x)
1088// - Mx step - load from SLM buffer x to GRF and multiplication
1089// - Sx step - store from GRF to SLM buffer x
1090// - Rx event - SLM buffer x is available for reading
1091// - Wx event - SLM buffer x is available for writing
1092// Scheme for single buffering:
1093// L0
1094// barrier
1095// S0
1096// barrier
1097// M0
1098// Schemes for double and triple buffering are below.
1099class slm_sync_manager_t {
1100public:
1101 slm_sync_manager_t(const conv_config_t &cfg, bool with_unroll)
1102 : slm_bufs_(cfg.slm().bufs())
1103 , gmem_bufs_(cfg.slm().gmem_bufs())
1104 , with_unroll_(with_unroll) {
1105 switch (slm_bufs_) {
1106 case 2: ver_ = version_t::x2; break;
1107 case 3: ver_ = version_t::x3_v3; break;
1108 default: ver_ = version_t::undef;
1109 }
1110 if (cfg.slm().sync_version() != -1) {
1111 ver_ = (version_t)cfg.slm().sync_version();
1112 }
1113 switch (slm_bufs_) {
1114 case 2: ir_assert(ver_ == version_t::x2); break;
1115 case 3:
1116 ir_assert(utils::one_of(ver_, version_t::x3_v1,
1117 version_t::x3_v2, version_t::x3_v3));
1118 break;
1119 default: ir_assert(ver_ == version_t::undef);
1120 }
1121 }
1122
1123 stmt_t before_loop_prepend(const stmt_t &_s) const {
1124 if (with_unroll_) return _s;
1125 auto s = _s;
1126 if (is_x3_v1() || is_x3_v2() || is_x3_v3()) {
1127 // Emit initial signal, to match wait-signal pairs in the loop.
1128 s = funcs::signal().append(s);
1129 }
1130 return s;
1131 }
1132
1133 stmt_t after_loop(const stmt_t &_s) const {
1134 auto s = _s;
1135 if (slm_bufs_ == 3) {
1136 s = s.append(funcs::barrier_wait());
            // With V3 the wait guarantees that all SLM writes are synced;
            // other versions need additional synchronization.
1139 if (!is_x3_v3()) s = s.append(funcs::barrier());
1140 }
1141 return s;
1142 }
1143
1144 stmt_t before_L(const stmt_t &_s, bool do_mul) const {
1145 bool emit = false;
1146 if (!with_unroll_) emit = true;
1147 if (with_unroll_ && do_mul) emit = true;
1148
1149 auto s = _s;
1150 if (is_x3_v2() && emit) { s = s.append(funcs::barrier_wait()); }
1151
1152 return s;
1153 }
1154
1155 stmt_t before_L_prepend(const stmt_t &_s, bool do_mul) const {
1156 return before_L(stmt_t(), do_mul).append(_s);
1157 }
1158
1159 stmt_t after_L(const stmt_t &_s, bool do_mul) const {
1160 bool emit = false;
1161 if (!with_unroll_) emit = true;
1162 if (with_unroll_ && do_mul) emit = true;
1163
1164 auto s = _s;
1165 if (is_x3_v1() && emit) s = s.append(funcs::barrier_wait());
1166 return s;
1167 }
1168
1169 stmt_t after_L_prepend(const stmt_t &_s, bool do_mul) const {
1170 return after_L(stmt_t(), do_mul).append(_s);
1171 }
1172
1173 stmt_t before_S(const stmt_t &_s, bool do_mul, bool is_last_mul = false,
1174 int iter = -1) const {
1175 bool emit = false;
1176 if (!with_unroll_) emit = true;
1177 if (with_unroll_ && iter != -1
1178 && iter >= (slm_bufs_ - 1) + (gmem_bufs_ - 1) - 1)
1179 emit = true;
1180
1181 auto s = _s;
1182 if (is_x3_v3() && emit) {
1183 s = s.append(funcs::barrier_wait());
1184 } else if ((is_x3_v1() || is_x3_v2()) && emit) {
            // In general we have to use an SLM fence before the signal to
            // flush all previous SLM stores. However, any SLM load behaves as
            // an implicit SLM fence for all previous SLM stores, so we don't
            // need an explicit SLM fence when the SLM load/multiplication is
            // performed before the signal.
1190 if (!do_mul) s = s.append(funcs::slm_fence());
1191 if (!is_last_mul) s = s.append(funcs::signal());
1192 }
1193 return s;
1194 }
1195
1196 stmt_t after_S(
1197 const stmt_t &_s, bool is_last_mul = false, int iter = -1) const {
1198 auto s = _s;
1199 if (is_x2()) {
1200 s = s.append(funcs::barrier());
1201 } else if (is_x3_v3()) {
1202 bool emit = false;
1203 if (!with_unroll_) emit = true;
1204 if (with_unroll_ && !is_last_mul && iter != -1
1205 && iter >= (slm_bufs_ - 1) + (gmem_bufs_ - 1) - 2)
1206 emit = true;
1207 if (emit) {
1208 s = s.append(funcs::slm_fence());
1209 s = s.append(funcs::signal());
1210 }
1211 }
1212 return s;
1213 }
1214
1215 bool is_x2() const { return ver_ == version_t::x2; }
1216 bool is_x3_v1() const { return ver_ == version_t::x3_v1; }
1217 bool is_x3_v2() const { return ver_ == version_t::x3_v2; }
1218 bool is_x3_v3() const { return ver_ == version_t::x3_v3; }
1219
1220private:
1221 enum class version_t {
1222 undef,
1223 // Double buffering scheme:
1224 // L0
1225 // M1
1226 // S0
1227 // barrier
1228 // L1
1229 // M0
1230 // S1
1231 // barrier
1232 x2,
1233 // Triple buffering scheme V1 (wait before M)
1234 // L0
1235 // wait R1/W0
1236 // M1
1237 // signal R2/W1
1238 // S0
1239 // L1
1240 // wait R2/W1
1241 // M2
1242 // signal R0/W2
1243 // S1
1244 // L2
1245 // wait R0/W2
1246 // M0
1247 // signal R1/W0
1248 // S2
1249 x3_v1,
1250 // Triple buffering scheme V2 (wait before L)
1251 // wait R1/W0
1252 // L0
1253 // M1
1254 // signal R2/W1
1255 // S0
1256 // wait R2/W1
1257 // L1
1258 // M2
1259 // signal R0/W2
1260 // S1
1261 // wait R0/W2
1262 // L2
1263 // M0
1264 // signal R1/W0
1265 // S2
1266 x3_v2,
1267 // Triple buffering scheme V3 (signal after store)
        // There are no SLM loads between S and signal, so an explicit fence
        // is required.
1270 // L0
1271 // M1
1272 // wait R2/W0
1273 // S0
1274 // fence and signal R0/W1
1275 // L1
1276 // M2
1277 // wait R0/W1
1278 // S1
1279 // fence and signal R1/W2
1280 // L2
1281 // M0
1282 // wait R1/W2
1283 // S2
1284 // fence and signal R2/W0
1285 x3_v3
1286 };
1287
1288 int slm_bufs_;
1289 int gmem_bufs_;
1290 bool with_unroll_;
1291 version_t ver_;
1292};
1293
1294class slm_zp_mask_extractor_t : public ir_visitor_t {
1295public:
1296 slm_zp_mask_extractor_t(
1297 std::vector<stmt_t> &retn, object_eq_set_t<expr_t> &bufs)
1298 : retn_(retn), bufs_(bufs), outer_(true) {}
1299
1300 void _visit(const store_t &obj) override {
1301 if (obj.buf.str().find("zp_mask") == 0) {
1302 if (outer_) retn_.emplace_back(obj);
1303 bufs_.insert(obj.buf);
1304 }
1305 }
1306
1307 void _visit(const let_t &obj) override {
1308 if ((obj.var.str().find("zp_mask") == 0)) {
1309 if (outer_) retn_.emplace_back(obj);
1310 auto outer_prev = outer_;
1311 outer_ = false;
1312 visit(obj.body);
1313 outer_ = outer_prev;
1314 }
1315 }
1316
1317private:
1318 std::vector<stmt_t> &retn_;
1319 object_eq_set_t<expr_t> &bufs_;
1320 bool outer_;
1321};
1322
1323class simple_slm_buffering_injector_t {
1324public:
1325 simple_slm_buffering_injector_t(const stmt_t &root, ir_context_t &ir_ctx,
1326 const conv_config_t &cfg, int ab_slm_size)
1327 : ir_ctx_(ir_ctx)
1328 , cfg_(cfg)
1329 , ab_slm_size_(ab_slm_size)
1330 , root_(root)
1331 , alloc_mgr_(root_)
1332 , step_(root)
1333 , loop_nest_(root, ir_ctx)
1334 , slm_sync_mgr_(cfg, /*with_unroll=*/false) {}
1335
1336 stmt_t inject() {
1337 ir_assert(cfg_.slm().gmem_bufs() == 1)
1338 << "GRF buffering is not supported.";
1339 if (utils::one_of(cfg_.slm().bufs(), 0, 1)) return root_;
1340
1341 ir_assert(cfg_.slm().a() == cfg_.slm().b())
1342 << "Mixed SLM/GMEM loads are not supported.";
1343
1344 auto loop = step_.compute_loop();
1345
        // SLM indices are allocated as follows:
        // slm_idx[0] -> slm_buf_store
        // slm_idx[1] -> slm_buf_compute
        // slm_idx[2] -> slm_counter
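        // The slm_idx_update statement below increments all three values by
        // one; the first two wrap back to 0 once they reach cfg_.slm().bufs(),
        // so with 3 buffers the store index cycles 0, 1, 2, 0, ... while the
        // compute index cycles 1, 2, 0, 1, ...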
1350 auto slm_idx_buf
1351 = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "slm_idx");
1352 int slm_idx_size = type_t::s32().size();
1353
1354 auto slm_idx_load = [&](int off, int elems) {
1355 return load_t::make(
1356 type_t::s32(elems), slm_idx_buf, slm_idx_size * off);
1357 };
1358
1359 // Initialize slm_idx.
1360 int off = 0;
1361 auto store0 = store_t::make(slm_idx_buf, off, 0);
1362 off += slm_idx_size;
1363
1364 auto store1 = store_t::make(slm_idx_buf, off, 1);
1365 off += slm_idx_size;
1366
1367 auto store2 = store_t::make(
1368 slm_idx_buf, off, int_imm_t::make(0, type_t::s32()));
1369
1370 auto slm_idx_init = store0.append(store1).append(store2);
1371
1372 auto slm_idx_load2 = slm_idx_load(0, 2);
1373 auto slm_idx_load4 = slm_idx_load(0, 4);
1374 auto slm_idx_store = store_t::make(slm_idx_buf, 0,
1375 slm_idx_load4 + shuffle_t::make_broadcast(1, 4));
1376
1377 // Update slm_idx.
1378 auto mask = (slm_idx_load2
1379 == shuffle_t::make_broadcast(cfg_.slm().bufs(), 2));
1380 auto slm_idx_store_fix = store_t::make(slm_idx_buf, 0,
1381 shuffle_t::make_broadcast(int_imm_t::make(0, type_t::s32()), 2),
1382 store_t::default_stride, mask);
1383
1384 auto slm_idx_update = slm_idx_store.append(slm_idx_store_fix);
1385
1386 loop = slm_idx_init.append(loop);
1387
1388 auto &g2s_load_orig = step_.g2s_load();
1389 auto &g2s_store_orig = step_.g2s_store();
1390 auto &s2r_load = step_.s2r_load();
1391 auto &mul = step_.mul();
1392
1393 auto g2s_load = g2s_load_orig;
1394 auto g2s_store = g2s_store_orig;
1395
1396 ir_assert(s2r_load.size() == mul.size());
1397
1398 stmt_t s2r_mul;
1399 for (int i = 0; i < int(mul.size()); i++) {
1400 s2r_mul = s2r_mul.append(s2r_load[i]);
1401 loop = substitute(loop, s2r_load[i], stmt_t(), 1);
1402 s2r_mul = s2r_mul.append(mul[i]);
1403 loop = substitute(loop, mul[i], stmt_t(), 1);
1404 }
1405
1406 loop = remove_synchronization(loop);
1407
1408 object_eq_set_t<expr_t> mask_bufs;
1409 std::vector<stmt_t> masks;
1410
1411 slm_zp_mask_extractor_t(masks, mask_bufs).visit(s2r_mul);
1412 if (!mask_bufs.empty())
1413 for (auto &m : masks)
1414 s2r_mul = substitute(s2r_mul, m, stmt_t());
1415
1416 s2r_mul = sub_slm_bufs(s2r_mul, slm_idx_load(1, 1));
1417 g2s_store = sub_slm_bufs(g2s_store, slm_idx_load(0, 1));
1418 g2s_store = g2s_store.append(slm_idx_update);
1419
1420 auto s2r_mul_body = s2r_mul;
1421 auto s2r_mul_tail = s2r_mul;
1422 auto slm_counter = slm_idx_load(2, 1);
1423 auto cond = (slm_counter >= cfg_.slm().bufs() - 1);
1424
1425 if (cfg_.slm().bufs() == 2) {
1426 s2r_mul_body = if_t::make(cond, s2r_mul_body);
1427 } else {
            // In general we have to use an SLM fence before the signal to
            // flush all previous SLM stores. However, any SLM load behaves as
            // an implicit SLM fence for all previous SLM stores, so we don't
            // need an explicit SLM fence when the SLM load/multiplication is
            // performed before the signal.
1433 auto with_mul = slm_sync_mgr_.before_S(s2r_mul_body, true);
1434 auto without_mul = slm_sync_mgr_.before_S(stmt_t(), false);
1435 s2r_mul_body = if_t::make(cond, with_mul, without_mul);
1436 }
1437
1438 g2s_store = slm_sync_mgr_.after_S(g2s_store);
1439 g2s_load = slm_sync_mgr_.before_L_prepend(g2s_load, true);
1440 g2s_load = slm_sync_mgr_.after_L(g2s_load, true);
1441
1442 if (!g2s_load.is_same(g2s_load_orig)) {
1443 loop = substitute(loop, g2s_load_orig, g2s_load, 1);
1444 }
1445
1446 alloc_updater_t alloc_updater;
1447
1448 int slm_bufs = cfg_.slm().bufs();
1449 for (auto &mbuf : mask_bufs) {
1450 auto sz = alloc_mgr_.alloc_size(mbuf);
1451 alloc_updater.resize(mbuf, sz * slm_bufs);
1452 for (auto &m : masks)
1453 m = substitute(m, mbuf, mbuf[sz * (slm_bufs - 1)]);
1454 layout_t comp_layout(type_t::u8(), 0, std::vector<dim_t> {sz});
1455 for (int b = 1; b < slm_bufs; b++) {
1456 auto reorder = create_reorder_stmt(comp_layout, comp_layout,
1457 mbuf + b * sz, mbuf + (b - 1) * sz);
1458 s2r_mul_body = s2r_mul_body.append(reorder);
1459 if ((slm_bufs == 3) && (b == 1))
1460 s2r_mul_tail = s2r_mul_tail.append(reorder);
1461 }
1462 }
1463 if (!mask_bufs.empty()) {
1464 stmt_t all_masks;
1465 for (auto &m : masks)
1466 all_masks = all_masks.append(m);
1467 s2r_mul_body = all_masks.append(s2r_mul_body);
1468 }
1469 loop = substitute(
1470 loop, g2s_store_orig, s2r_mul_body.append(g2s_store), 1);
1471
1472 loop = slm_sync_mgr_.before_loop_prepend(loop);
1473
1474 // Complete the remaining iterations.
1475 int rem_iters = slm_bufs - 1;
1476 int mul_start = std::max(0, rem_iters - loop_nest_.size());
1477 multi_loop_iterator_t multi(loop_nest_.loops());
1478 multi.advance(loop_nest_.size() - rem_iters + mul_start);
1479
1480 loop = slm_sync_mgr_.after_loop(loop);
1481 for (int i = 0; i < rem_iters; i++) {
1482 if (i >= mul_start) {
1483 auto tmp_mul_tail = s2r_mul_tail;
1484 loop_nest_.for_each_loop_var([&](const expr_t &v) {
1485 expr_t iter(multi.var_value(v));
1486 tmp_mul_tail = substitute(tmp_mul_tail, v, iter);
1487 });
                // The SLM load/multiplication works as an implicit SLM fence.
1489 loop = loop.append(tmp_mul_tail);
1490 multi.advance();
1491 }
1492 loop = loop.append(slm_idx_update);
1493 }
1494
1495 if (cfg_.assign_sbids())
1496 loop = sbid_assigner_t(ir_ctx_.hw_cfg().hw()).assign(loop);
1497
1498 const auto grf_size = ir_ctx_.hw_cfg().grf_size();
1499 loop = alloc_t::make(slm_idx_buf, grf_size, alloc_kind_t::grf, loop);
1500
1501 auto slm_buffers = alloc_mgr_.find_buffers(alloc_kind_t::slm);
1502 ir_assert(slm_buffers.size() == 1);
1503 auto &slm_buf = slm_buffers[0];
1504 int non_ab_slm_size = alloc_mgr_.alloc_size(slm_buf) - ab_slm_size_;
1505 alloc_updater.resize(
1506 slm_buf, non_ab_slm_size + ab_slm_size_ * slm_bufs);
1507
1508 auto ret = substitute(root_, step_.compute_loop(), loop, 1);
1509 ret = alloc_updater.update(ret);
1510 return ret;
1511 }
1512
1513 static stmt_t remove_synchronization(const stmt_t &s) {
1514 auto ret = s;
1515 for (auto &_c : find_objects<func_call_t>(s)) {
1516 auto &c = _c.as<func_call_t>();
1517 if (c.func.is_equal(funcs::signal_func())
1518 || c.func.is_equal(funcs::slm_fence_func())
1519 || c.func.is_equal(funcs::barrier_func())) {
1520 ret = substitute(ret, _c, stmt_t(), 1);
1521 }
1522 }
1523 return ret;
1524 }
1525
1526 stmt_t sub_slm_bufs(const stmt_t &stmt, const expr_t &slm_idx) const {
1527 auto stmt_vec = flatten_statements(stmt);
1528
1529 stmt_t ret = stmt;
1530 for (auto &s : stmt_vec) {
1531 if (!is_func_call<send_t>(s)) continue;
1532
1533 auto &send = s.as<func_call_t>().func.as<send_t>();
1534
            // This is not a send to SLM, skip.
1536 if (!send.is_slm()) continue;
1537
1538 auto new_args = s.as<func_call_t>().args;
1539 send_t::arg_mem_off(new_args) += ab_slm_size_ * slm_idx;
1540 auto new_send = send.call(new_args);
1541 ret = substitute(ret, s, new_send, 1);
1542 }
1543
1544 return ret;
1545 }
1546
1547 ir_context_t &ir_ctx_;
1548 const conv_config_t &cfg_;
1549 int ab_slm_size_;
1550
1551 stmt_t root_;
1552 alloc_manager_t alloc_mgr_;
1553 compute_step_t step_;
1554 compute_loop_nest_t loop_nest_;
1555 slm_sync_manager_t slm_sync_mgr_;
1556};
1557
1558stmt_t inject_simple_slm_buffering(const stmt_t &s, ir_context_t &ir_ctx,
1559 const conv_config_t &cfg, int ab_slm_size) {
1560 trace_start();
1561 auto ret = simple_slm_buffering_injector_t(s, ir_ctx, cfg, ab_slm_size)
1562 .inject();
1563 trace_pass("inject_simple_slm_buffering", ret, ir_ctx);
1564 return ret;
1565}
1566
1567class unrolling_injector_t {
1568public:
1569 unrolling_injector_t(const stmt_t &root, const conv_config_t &cfg,
1570 ir_context_t &ir_ctx, int ab_slm_size)
1571 : cfg_(cfg)
1572 , ir_ctx_(ir_ctx)
1573 , ab_slm_size_(ab_slm_size)
1574 , root_(root)
1575 , alloc_mgr_(root_)
1576 , step_(root)
1577 , loop_nest_(root, ir_ctx)
1578 , slm_sync_mgr_(cfg, /*with_unroll=*/true) {
1579 int inner_iters = loop_nest_.inner_loops_size();
1580 params_ = compute_params_t(cfg_.slm().bufs(), cfg_.slm().gmem_bufs(),
1581 ab_slm_size, cfg_.prefetch().bufs(), inner_iters);
1582 if (params_.use_slm) {
1583 for (auto &b :
1584 find_send_buffers(step_.g2s_load(), /*is_mem=*/false)) {
1585 g2s_reg_bufs_.emplace_back(b, alloc_mgr_.alloc_size(b));
1586 }
1587 }
1588
1589 // Can't fuse top-level zero-out statement unless the compute loop is
1590 // top-level as well.
1591 fuse_zero_out_with_fma_ = (loop_nest_.compute_loop_level() == 0);
1592 }
1593
1594 stmt_t inject() {
1595 compute_iterator_t it(params_, loop_nest_);
1596 stmt_t body;
1597
1598 sbid_manager_t sbid_mgr(cfg_.hw());
1599
1600 auto &outer_loop_info = loop_nest_.outer_loop_info();
1601
1602 auto append_outer_post_inc = [&](const stmt_t &_s) {
1603 auto &mul = outer_loop_info.mul_post_inc_stmt();
1604 auto &preload = outer_loop_info.preload_post_inc_stmt();
1605 auto s = _s;
1606 if (it.mul_loop_it.is_outer_loop_end() && it.do_mul()) {
1607 s = s.append(mul);
1608 }
1609 if (it.preload_loop_it.is_outer_loop_end() && it.do_preload()) {
1610 s = s.append(preload);
1611 }
1612 return s;
1613 };
1614
1615 bmnk_dim_helper_t h(cfg_);
1616 int k_iter_blk = h.iter_dim('k');
1617 int reduce_iter_bytes = k_iter_blk * cfg_.prb().a_data_type_size;
        // Add periodic signal-wait thread group synchronization in some cases.
        // This is to ensure that threads access nearby reduction blocks and
        // are able to reuse their common data from L1.
1621 bool do_sync
1622 = (cfg_.hw() >= ngen::HW::XeHPC) && (reduce_iter_bytes > 32);
1623 if (cfg_.slm()) do_sync = false;
1624 // Distance in iterations between signal and wait.
1625 int sync_dist = 3;
1626
1627 // Ramp-up.
1628 for (int i = 0; i < it.ramp_up_iters; i++) {
1629 body = stmt_seq_t::make(body, create_iteration(it, sbid_mgr));
1630 body = append_outer_post_inc(body);
1631 ++it;
1632 }
1633
1634 // Body.
1635 if (it.body_iters > 0) {
1636 int extent = it.body_iters / it.unroll();
1637 bool has_loop = (extent > 1);
1638
1639 stmt_t loop_body;
1640 bool do_sync_wait = false;
1641 for (int i = 0; i < it.unroll(); i++) {
1642 if (do_sync && i % sync_dist == 0) {
1643 loop_body = loop_body.append(do_sync_wait
1644 ? funcs::barrier_wait()
1645 : funcs::signal());
1646 do_sync_wait = !do_sync_wait;
1647 }
1648 loop_body = loop_body.append(create_iteration(
1649 it, sbid_mgr, /*in_loop_body=*/has_loop));
1650 ir_assert(it.do_mul());
1651 loop_body = append_outer_post_inc(loop_body);
1652 ++it;
1653 }
1654 if (do_sync && do_sync_wait)
1655 loop_body = loop_body.append(funcs::barrier_wait());
1656 if (!has_loop) {
1657 body = body.append(loop_body);
1658 } else {
1659 ir_assert(extent > 0);
1660 auto for_var = ir_ctx_.create_tmp_var(type_t::s32(), "i");
1661 body = body.append(for_t::make(for_var, 0, extent, loop_body));
1662 }
1663 it.advance(it.body_iters - it.unroll());
1664 }
1665
1666 // Ramp-down.
1667 for (int i = 0; i < it.ramp_down_iters; i++) {
1668 ir_assert(it.do_mul());
1669 body = body.append(create_iteration(it, sbid_mgr));
1670 body = append_outer_post_inc(body);
1671 ++it;
1672 }
1673
1674 if (outer_loop_info.has_var_refs()) {
1675 body = outer_loop_info.init_stmt().append(body);
1676 body = outer_loop_info.inject_alloc_stmts(body);
1677 }
1678
        // When the compute loop is part of an outer loop and SLM buffering is
        // used, synchronization is required between outer iterations.
1681 if (loop_nest_.compute_loop_level() != 0 && params_.use_slm) {
1682 body = funcs::barrier().append(body);
1683 }
1684
1685 body = stmt_group_t::make(stmt_label_t::compute_loop(), body);
1686 auto ret = substitute(root_, step_.compute_loop(), body, 1);
1687
1688 if (params_.use_slm) {
1689 alloc_updater_t alloc_updater;
1690
1691 // Update buffer sizes.
1692 for (auto &b : g2s_reg_bufs_) {
1693 alloc_updater.resize(b.buf,
1694 alloc_mgr_.alloc_size(b.buf) * cfg_.slm().gmem_bufs());
1695 }
1696
1697 auto slm_buffers = alloc_mgr_.find_buffers(alloc_kind_t::slm);
1698 if (!slm_buffers.empty()) {
1699 ir_assert(slm_buffers.size() == 1);
1700
1701 auto &slm_buf = slm_buffers[0];
1702 int non_ab_slm_size
1703 = alloc_mgr_.alloc_size(slm_buf) - ab_slm_size_;
1704 alloc_updater.resize(slm_buf,
1705 non_ab_slm_size + ab_slm_size_ * cfg_.slm().bufs());
1706 }
1707
1708 ret = alloc_updater.update(ret);
1709 }
1710
1711 // Remove zero-out statement for C (handled by sub_fma_acc_with_zero).
1712 if (fuse_zero_out_with_fma_)
1713 ret = substitute(ret, step_.c_zero_out(), stmt_t(), 1);
1714
1715 return ret;
1716 }
1717
1718private:
1719 struct buffer_info_t {
1720 buffer_info_t(const expr_t &buf, int size) : buf(buf), size(size) {}
1721
1722 expr_t buf;
1723 int size;
1724 };
1725
1726 stmt_t create_iteration(const compute_iterator_t &it,
1727 sbid_manager_t &sbid_mgr, bool in_loop_body = false) const {
1728 auto g2s_load = step_.g2s_load();
1729 auto g2s_store = step_.g2s_store();
1730 auto prefetch = step_.prefetch();
1731 auto g2r_load = step_.g2r_load();
1732 auto s2r_load = step_.s2r_load();
1733 auto mul = step_.mul();
1734 auto lets = step_.inner_let_stmts();
1735 auto &outer_loop_info = loop_nest_.outer_loop_info();
1736
1737 loop_nest_.for_each_loop_var([&](const expr_t &v) {
1738 expr_t mul_var_value;
1739 expr_t preload_var_value;
1740 if (v.is_same(outer_loop_info.var) && in_loop_body
1741 && outer_loop_info.has_var_refs()) {
1742 mul_var_value = outer_loop_info.mul_var_load();
1743 preload_var_value = outer_loop_info.preload_var_load();
1744 } else {
1745 mul_var_value = it.mul_loop_it.var_value(v);
1746 preload_var_value = it.preload_loop_it.var_value(v);
1747 }
1748 g2s_load = const_fold(substitute(g2s_load, v, preload_var_value));
1749 g2s_store = const_fold(substitute(g2s_store, v, preload_var_value));
1750 prefetch = const_fold(substitute(prefetch, v, preload_var_value));
1751 for (auto &m : mul) {
1752 m = const_fold(substitute(m, v, mul_var_value));
1753 }
1754 for (auto &s : g2r_load) {
1755 s = const_fold(substitute(s, v, mul_var_value));
1756 }
1757 for (auto &s : s2r_load) {
1758 if (count_object(s, v) > 0) ir_error_not_expected();
1759 s = const_fold(substitute(s, v, preload_var_value));
1760 }
1761 for (int i = 0; i < int(lets.size()); i++) {
1762 auto &let = lets[i];
1763 auto &orig_let = step_.inner_let_stmts()[i];
1764 expr_t var_value;
1765 bool is_preload_let = step_.is_preload_let(orig_let);
1766 bool is_mul_let = step_.is_mul_let(orig_let);
1767 if (is_preload_let && !is_mul_let) {
1768 var_value = preload_var_value;
1769 } else if (is_mul_let && !is_preload_let) {
1770 var_value = mul_var_value;
1771 } else {
1772 ir_assert(count_object(let.as<let_t>().value, v) == 0)
1773 << "Unexpected reference to variable " << v
1774 << " from " << let;
1775 continue;
1776 }
1777 let = const_fold(substitute(let, v, var_value));
1778 }
1779 });
1780
1781 if (params_.use_slm) {
1782 g2s_load = sub_gmem_bufs(g2s_load, it, /*is_read=*/false);
1783 g2s_store = sub_gmem_bufs(g2s_store, it, /*is_read=*/true);
1784
1785 g2s_store = sub_slm_bufs(g2s_store, it, /*is_read=*/false);
1786 for (auto &s : s2r_load) {
1787 s = sub_slm_bufs(s, it, /*is_read=*/true);
1788 }
1789 }
1790
1791 if (it.is_first_mul() && fuse_zero_out_with_fma_) {
1792 mul = sub_fma_acc_with_zero(
1793 mul, cfg_.fma_kind() == fma_kind_t::mad);
1794 }
1795
1796 if (it.is_last_g2s_store())
1797 g2s_store = remove_post_inc_stores(g2s_store);
1798 if (it.is_last_g2s_load()) g2s_load = remove_post_inc_stores(g2s_load);
1799 if (it.is_last_prefetch()) prefetch = remove_post_inc_stores(prefetch);
1800 if (it.is_last_mul()) {
1801 for (auto &s : s2r_load)
1802 s = remove_post_inc_stores(s);
1803 for (auto &s : g2r_load)
1804 s = remove_post_inc_stores(s);
1805 }
1806
        stmt_t iter_stmt;

        iter_stmt = slm_sync_mgr_.before_L(iter_stmt, it.do_mul());
        if (it.do_g2s_load()) iter_stmt = iter_stmt.append(g2s_load);
        iter_stmt = slm_sync_mgr_.after_L(iter_stmt, it.do_mul());

        if (it.do_g2s_store() && it.slm_bufs() == 1) {
            iter_stmt = iter_stmt.append(funcs::barrier());
            iter_stmt = iter_stmt.append(g2s_store);
            iter_stmt = iter_stmt.append(funcs::barrier());
        }

        if (it.do_prefetch()) iter_stmt = iter_stmt.append(prefetch);

        if (it.do_mul()) {
            for (size_t i = 0; i < mul.size(); i++) {
                iter_stmt = iter_stmt.append(g2r_load[i]);
                iter_stmt = iter_stmt.append(s2r_load[i]);
                iter_stmt = iter_stmt.append(mul[i]);
            }
        }
        iter_stmt = slm_sync_mgr_.before_S(
                iter_stmt, it.do_mul(), it.is_last_mul(), it.iter);

        if (it.do_g2s_store() && it.slm_bufs() >= 2) {
            iter_stmt = iter_stmt.append(g2s_store);
        }

        iter_stmt = slm_sync_mgr_.after_S(iter_stmt, it.is_last_mul(), it.iter);

        if (cfg_.assign_sbids())
            iter_stmt = sbid_assigner_t(sbid_mgr).assign(iter_stmt);

        iter_stmt = inject_local_let(iter_stmt, lets, it.linear_id);

        return iter_stmt;
    }

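    // Rebases the GRF staging buffers used for the GMEM -> SLM copy to the
    // register buffer slot of the current iteration. is_read selects between
    // the SLM store reading from the buffers and the global load writing to
    // them.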
    stmt_t sub_gmem_bufs(const stmt_t &stmt, const compute_iterator_t &it,
            bool is_read) const {
        if (it.slm_bufs() == 0) return stmt;
        if (is_read && !it.do_g2s_store()) return stmt;
        if (!is_read && !it.do_g2s_load()) return stmt;

        int buf_idx = (is_read ? it.gmem_read_buf_index()
                               : it.gmem_write_buf_index());
        if (buf_idx == 0) return stmt;

        auto ret = stmt;
        for (auto &b : g2s_reg_bufs_) {
            ret = substitute(ret, b.buf, b.buf[buf_idx * b.size]);
        }
        return ret;
    }

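    // Adds post-increment updates to the headers of SLM sends so that
    // consecutive iterations access rotating SLM buffers. is_read selects
    // between the SLM reads feeding the multiplications and the GMEM -> SLM
    // store.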
    stmt_t sub_slm_bufs(const stmt_t &stmt, const compute_iterator_t &it,
            bool is_read) const {
        if (it.slm_bufs() <= 1) return stmt;
        if (is_read && !it.do_mul()) return stmt;
        if (!is_read && !it.do_g2s_store()) return stmt;

        int upd = (is_read ? it.slm_read_offset_update()
                           : it.slm_write_offset_update());

        auto stmt_vec = flatten_statements(stmt);

        stmt_t ret = stmt;
        for (auto &s : stmt_vec) {
            auto *call = s.as_ptr<func_call_t>();
            if (!call) continue;
            auto *func = call->func.as_ptr<send_t>();
            if (!func) continue;

            auto &send = call->func.as<send_t>();
            auto &args = call->args;
            auto &mem_buf = send_t::arg_mem_buf(args);
            auto &header_buf = send_t::arg_mem_off(args);

            // Skip sends that do not target SLM.
            if (!send.is_slm()) continue;

            // The offset update may be negative, so use a signed offset.
            auto store_obj = send.create_offset_store(
                    header_buf, mem_buf, upd, /*is_signed_offset=*/true);
            auto &store = store_obj.as<store_t>();
            expr_t old_value
                    = load_t::make(send.address_type(), store.buf, store.off);
            auto post_inc_store = store_t::make(
                    store.buf, store.off, old_value + store.value);
            ret = substitute(ret, s, stmt_seq_t::make(s, post_inc_store), 1);
        }

        return ret;
    }

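    // For every accumulation destination, rewrites its first update into a
    // non-accumulating form (a plain multiplication, or a dpas with a null
    // src0) so that the accumulator does not have to be zeroed out
    // beforehand.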
    static std::vector<stmt_t> sub_fma_acc_with_zero(
            const std::vector<stmt_t> &stmt, const bool is_mad) {
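        // Checks whether `curr` extends a contiguous block of statements
        // that accumulate into their own destination: dst-accumulating mad
        // calls, or stores of the form dst = mad(dst, ...) / dst = dst - x.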
        auto is_from_block = [is_mad](const stmt_t &curr,
                                      const std::vector<stmt_t> &vec, int bgn) {
            if (is_mad) return false;
            if (is_func_call<mad_t>(curr)) {
                return (mad_t::arg_dst(curr).is_equal(mad_t::arg_src0(curr)))
                        && (!bgn || is_func_call<mad_t>(vec[bgn - 1]));
            } else if (const auto *s = curr.as_ptr<store_t>()) {
                const auto *bs
                        = (bgn) ? vec[bgn - 1].as_ptr<store_t>() : nullptr;
                if (const auto *t = s->value.as_ptr<ternary_op_t>()) {
                    const auto *a = t->a.as_ptr<load_t>();
                    return (t->op_kind == op_kind_t::_mad) && a
                            && (a->buf[a->off].is_equal(s->buf[s->off]))
                            && (!bgn || (bs && bs->value.is<ternary_op_t>()));
                } else if (const auto *b = s->value.as_ptr<binary_op_t>()) {
                    const auto *a = b->a.as_ptr<load_t>();
                    return (b->op_kind == op_kind_t::_sub) && a
                            && (a->buf[a->off].is_equal(s->buf[s->off]))
                            && (!bgn || (bs && bs->value.is<binary_op_t>()));
                }
            }
            return false;
        };
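        // Rewrites a block of accumulating statements into non-accumulating
        // ones, provided none of the destinations was updated earlier: a mad
        // call becomes a plain multiplication, dst = mad(dst, a, b) becomes
        // dst = a * b, and dst = dst - x becomes dst = -x.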
        auto process_block = [is_mad](stmt_t &root,
                                      object_eq_set_t<expr_t> &seen,
                                      std::vector<stmt_t> &vec, int bgn,
                                      int end) {
            bgn += is_mad;
            end += is_mad;
            bool never_seen = (bgn > 0);
            for (int i = bgn - 1; never_seen && (i < end); i++) {
                if (is_func_call<mad_t>(vec[i])) {
                    never_seen &= seen.insert(mad_t::arg_dst(vec[i])).second;
                } else if (const auto *s = vec[i].as_ptr<store_t>()) {
                    never_seen &= seen.insert(s->buf[s->off]).second;
                }
            }
            for (int i = bgn - 1; never_seen && (i < end); i++) {
                if (is_func_call<mad_t>(vec[i])) {
                    auto &call = vec[i].as<func_call_t>();
                    auto &mad = call.func.as<mad_t>();
                    auto *m = call.attr.as_ptr<instruction_modifier_attr_t>();
                    ir_assert(!m
                            || (!m->mod.is_atomic && m->mod.sbid.is_empty()));
                    ir_assert(mad.src1_stride == 0);
                    auto a_load = load_t::make(
                            mad.src1_type.kind(), mad_t::arg_src1(vec[i]), 0);
                    auto a = shuffle_t::make_broadcast(a_load, mad.exec_size);
                    auto b_stride = mad.src2_stride * mad.src2_type.size();
                    auto b = load_t::make(
                            type_t(mad.src2_type.kind(), mad.exec_size),
                            mad_t::arg_src2(vec[i]), 0, b_stride);
                    auto mul = binary_op_t::make(op_kind_t::_mul, a, b,
                            type_t(mad.dst_type.kind(), mad.exec_size));
                    auto store = store_t::make(mad_t::arg_dst(vec[i]), 0, mul);
                    root = substitute(root, vec[i], store, 1);
                } else if (const auto *s = vec[i].as_ptr<store_t>()) {
                    if (const auto *t = s->value.as_ptr<ternary_op_t>()) {
                        ir_assert(s->mask.is_empty());
                        auto mul = binary_op_t::make(
                                op_kind_t::_mul, t->b, t->c, t->type);
                        root = substitute(root, vec[i],
                                store_t::make(s->buf, s->off, mul), 1);
                    } else if (const auto *b = s->value.as_ptr<binary_op_t>()) {
                        auto store = store_t::make(s->buf, s->off, -b->b,
                                s->stride, s->mask, true);
                        root = substitute(root, vec[i], store, 1);
                    } else {
                        ir_error_not_expected();
                    }
                } else {
                    ir_error_not_expected();
                }
            }
            return (is_mad) ? end : 0;
        };
        std::vector<stmt_t> retn;

        for (const auto &s : stmt) {
            ir_assert(s.is<stmt_group_t>());
            const auto &group = s.as<stmt_group_t>();
            auto body = group.body;
            auto stmt_vec = flatten_statements(body);
            object_eq_set_t<expr_t> seen_dst;

            int bgn = 0;
            for (int i = 0; i < int(stmt_vec.size()); i++) {
                stmt_t curr = stmt_vec[i];
                const bool ifb = is_from_block(curr, stmt_vec, bgn);
                bgn = (ifb) ? (!bgn) ? i + 1 : bgn
                            : process_block(body, seen_dst, stmt_vec, bgn, i);
                if (is_func_call<dpas_t>(curr) && !dpas_t::is_dp4a_call(curr)) {
                    auto &call = curr.as<func_call_t>();

                    auto &dst = dpas_t::arg_dst(curr);
                    auto src0 = expr_t(0); // Will be translated to the null register.
                    auto &src1 = dpas_t::arg_src1(curr);
                    auto &src2 = dpas_t::arg_src2(curr);

                    if (seen_dst.insert(dst).second)
                        body = substitute(body, curr,
                                func_call_t::make(call.func,
                                        {dst, src0, src1, src2}, call.attr),
                                1);
                }
            }
            process_block(body, seen_dst, stmt_vec, bgn,
                    (int)stmt_vec.size() - is_mad);
            retn.emplace_back(stmt_group_t::make(group.label, body));
        }
        return retn;
    }

    // Returns memory buffers if is_mem is true and register buffers otherwise.
    static object_set_t<expr_t> find_send_buffers(
            const stmt_t &s, bool is_mem) {
        object_set_t<expr_t> ret;
        auto calls = find_objects<func_call_t>(s);
        for (auto &_c : calls) {
            auto &c = _c.as<func_call_t>();
            if (!c.func.is<send_t>()) continue;
            auto &buf = (is_mem ? send_t::arg_mem_buf(_c)
                                : send_t::arg_reg_buf(_c));
            ret.insert(buf.as<ptr_t>().base);
        }
        return ret;
    }

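    // Wraps the statement into the given let statements and renames their
    // variables with an iteration-specific suffix so that unrolled
    // iterations do not clash.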
    static stmt_t inject_local_let(const stmt_t &_s,
            const std::vector<stmt_t> &enclosed_lets, int id) {
        auto s = _s;

        // Inject let statements from the innermost loop.
        for (auto &_let : enclosed_lets) {
            auto &let = _let.as<let_t>();
            s = let_t::make(let.var, let.value, s);
        }

        // Rename let variables with an iteration-specific suffix to avoid
        // clashes between unrolled iterations.
        auto lets = find_objects<let_t>(s);
        for (auto &_let : lets) {
            auto &let = _let.as<let_t>();
            auto &var = let.var.as<var_t>();
            auto local_var = var_t::make(
                    var.type, var.name + "_" + std::to_string(id));
            s = substitute(s, let.var, local_var);
        }
        return s;
    }

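    // Removes post-increment pointer updates, i.e. stores whose value reads
    // from the same buffer they write to.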
    static stmt_t remove_post_inc_stores(const stmt_t &_s) {
        auto stores = find_objects<store_t>(_s);
        auto s = _s;
        for (auto &_store : stores) {
            auto &store = _store.as<store_t>();
            if (!contains_object(store.value, store.buf)) continue;
            s = substitute(s, store, stmt_t());
        }
        return s;
    }

    const conv_config_t &cfg_;
    ir_context_t &ir_ctx_;
    int ab_slm_size_;

    stmt_t root_;
    alloc_manager_t alloc_mgr_;
    compute_step_t step_;
    compute_loop_nest_t loop_nest_;
    compute_params_t params_;
    slm_sync_manager_t slm_sync_mgr_;

    std::vector<buffer_info_t> g2s_reg_bufs_; // For SLM buffering.
    bool fuse_zero_out_with_fma_ = false;
};

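// Unrolls the compute loop nest and pipelines it with SLM buffering and
// prefetch according to the kernel configuration.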
stmt_t inject_unrolling(const stmt_t &s, ir_context_t &ir_ctx,
        const conv_config_t &cfg, int ab_slm_size) {
    trace_start();
    auto ret = unrolling_injector_t(s, cfg, ir_ctx, ab_slm_size).inject();
    trace_pass("inject_unrolling", ret, ir_ctx);
    return ret;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl