1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/pass/dpasw.hpp"
18
19#include "gpu/jit/ir/fma.hpp"
20#include "gpu/jit/ir/grf_permutation.hpp"
21#include "gpu/jit/ir/message.hpp"
22#include "gpu/jit/utils/trace.hpp"
23
24namespace dnnl {
25namespace impl {
26namespace gpu {
27namespace jit {
28
29class dpasw_injector_t {
30public:
31 dpasw_injector_t(ngen::HW hw, const stmt_t &load_mul_stmt,
32 const expr_t &c_buf, const stmt_t &c_store_stmt,
33 alloc_updater_t &alloc_updater, const expr_t &tg_idx0)
34 : hw_(hw)
35 , load_mul_stmt_(load_mul_stmt)
36 , c_buf_(c_buf)
37 , c_store_stmt_(c_store_stmt)
38 , alloc_updater_(alloc_updater)
39 , tg_idx0_(tg_idx0) {}
40
41 const stmt_t &load_mul_stmt() const { return load_mul_stmt_; }
42
43 const stmt_t &c_store_stmt() const { return c_store_stmt_; }
44
45 void inject() {
46 expr_t src2_base;
47 if (!extract_dpas_calls(src2_base)) return;
48
49 grf_permutation_t grf_perm;
50
51 bool was_injected = false;
52 int dpas_count = int(dpas_infos_.size());
53 for (int i = 0; i < dpas_count;) {
54 if (i + 1 < dpas_count) {
55 auto &a = dpas_infos_[i];
56 auto &b = dpas_infos_[i + 1];
57 if (try_convert_to_dpasw(a, b, grf_perm)) {
58 was_injected = true;
59 i += 2;
60 continue;
61 }
62 }
63 if (try_convert_to_dpasw(dpas_infos_[i], grf_perm)) {
64 was_injected = true;
65 }
66 ++i;
67 }
68 // Nothing to update, no dpas -> dpasw transformation.
69 if (!was_injected) return;
70
71 int src2_size = 0;
72 object_map_t<stmt_t, int> send2off;
73 std::function<int(const stmt_t &)> get_src2_off;
74 get_src2_off = [&](const stmt_t &s) {
75 auto &si = find_send_info(s);
76 if (!si.base_call.is_empty()) return get_src2_off(si.base_call);
77 if (!si.prev_send.is_empty()) return get_src2_off(si.prev_send);
78
79 auto it = send2off.find(s);
80 if (it != send2off.end()) return it->second;
81
82 auto ret = send2off.insert({s, src2_size});
83 if (!ret.second) return ret.first->second;
84
85 int new_size = si.new_reg_buf_size();
86 src2_size += new_size;
87 return ret.first->second;
88 };
89 for (auto &si : send_infos_) {
90 if (!si.reg_buf_base().is_equal(src2_base)) continue;
91
92 int src2_off = get_src2_off(si.call);
93 auto src2_sub = src2_base[src2_off];
94 auto new_call = si.new_call;
95 if (!new_call.is_empty()) {
96 new_call = substitute(
97 new_call, send_t::arg_reg_buf(new_call), src2_sub, 1);
98 }
99
100 load_mul_stmt_ = substitute(load_mul_stmt_, si.call, new_call, 1);
101 for (auto &d : si.dpas_consumers) {
102 auto &di = find_dpas_info(d);
103 ir_assert(si.promote_to_dpasw == di.promote_to_dpasw)
104 << "Both send and dpas must be updated.";
105 if (di.update_applied) {
106 ir_error_not_expected() << "Can it happen?";
107 continue;
108 }
109 auto new_call = di.new_call;
110 new_call = substitute(new_call, dpas_t::arg_src2(new_call),
111 src2_sub[di.src2_relative_off], 1);
112 load_mul_stmt_
113 = substitute(load_mul_stmt_, di.call, new_call, 1);
114 di.update_applied = true;
115 }
116 }
117
118 // Update src2 size after applying send updates.
119 alloc_updater_.resize(src2_base, src2_size);
120
121 // Apply permutation to C buffer.
122 alloc_updater_.add_attr(c_buf_,
123 grf_permute_attr_t::make(
124 std::make_shared<grf_permutation_t>(grf_perm)));
125 }
126
127private:
128 struct send_info_t {
129 send_info_t() = default;
130
131 send_info_t(const stmt_t &call) : call(call), new_call(call) {}
132
133 const send_t &send() const {
134 return call.as<func_call_t>().func.as<send_t>();
135 }
136
137 const send_t &new_send() const {
138 ir_assert(!new_call.is_same(call));
139 return new_call.as<func_call_t>().func.as<send_t>();
140 }
141
142 const std::vector<expr_t> &args() const {
143 return call.as<func_call_t>().args;
144 }
145
146 const expr_t &reg_buf() const { return send_t::arg_reg_buf(call); }
147
148 const expr_t &reg_buf_base() const {
149 return reg_buf().as<ptr_t>().base;
150 }
151
152 int reg_buf_size() const { return send().payload_size(); }
153
154 int new_reg_buf_size() const {
155 if (new_call.is_same(call)) return reg_buf_size();
156 return new_send().payload_size();
157 }
158
159 void set_new_call(const stmt_t &s, const stmt_t &base = stmt_t()) {
160 if (!promote_to_dpasw) {
161 promote_to_dpasw = true;
162 new_call = s;
163 base_call = base;
164 return;
165 }
166 ir_assert(new_call.is_equal(s));
167 ir_assert(base_call.is_equal(base));
168 }
169
170 void set_prev_send(const stmt_t &s) {
171 int prev_size
172 = s.as<func_call_t>().func.as<send_t>().payload_size();
173 if (reg_buf_size() != prev_size) return;
174 prev_send = s;
175 }
176
177 stmt_t call;
178 std::vector<stmt_t> dpas_consumers;
179
180 bool promote_to_dpasw = false;
181 stmt_t new_call;
182 stmt_t base_call;
183 stmt_t prev_send;
184 };
185
186 struct dpas_info_t {
187 dpas_info_t() = default;
188
189 dpas_info_t(const stmt_t &call) : call(call), new_call(call) {}
190
191 const dpas_t &dpas() const {
192 return call.as<func_call_t>().func.as<dpas_t>();
193 }
194
195 const std::vector<expr_t> &args() const {
196 return call.as<func_call_t>().args;
197 }
198
199 const expr_t &src1_buf() const { return dpas_t::arg_src1(call); }
200
201 const expr_t &src2_buf() const { return dpas_t::arg_src2(call); }
202
203 int src2_size() const { return dpas().src2_size(); }
204
205 void set_new_call(const stmt_t &s, int src2_relative_off) {
206 if (!promote_to_dpasw) {
207 promote_to_dpasw = true;
208 this->src2_relative_off = src2_relative_off;
209 new_call = s;
210 return;
211 }
212 ir_assert(this->src2_relative_off == src2_relative_off);
213 ir_assert(new_call.is_equal(s));
214 }
215
216 stmt_t call;
217 stmt_t send_producer;
218
219 bool promote_to_dpasw = false;
220 bool update_applied = false;
221 int src2_relative_off = 0;
222 stmt_t new_call;
223 };
224
225 send_info_t &find_send_info(const stmt_t &s) {
226 for (auto &si : send_infos_)
227 if (si.call.is_same(s)) return si;
228 ir_error_not_expected();
229 return send_infos_.front();
230 }
231
232 dpas_info_t &find_dpas_info(const stmt_t &s) {
233 for (auto &si : dpas_infos_)
234 if (si.call.is_same(s)) return si;
235 ir_error_not_expected();
236 return dpas_infos_.front();
237 }
238 static bool is_send(const stmt_t &s, send_info_t &info) {
239 if (!is_func_call<send_t>(s)) return false;
240 info = send_info_t(s);
241 return true;
242 }
243
244 static bool is_dpas(const stmt_t &s, dpas_info_t &info) {
245 if (!is_func_call<dpas_t>(s)) return false;
246 if (dpas_t::is_dp4a_call(s)) return false;
247 info = dpas_info_t(s);
248 return true;
249 }
250
251 bool extract_dpas_calls(expr_t &src2_base) {
252 object_eq_map_t<expr_t, stmt_t> buf2send;
253
254 auto set_src2_base = [&](const expr_t &ptr) {
255 auto &ptr_base = ptr.as<ptr_t>().base;
256 if (src2_base.is_empty()) {
257 src2_base = ptr_base;
258 return;
259 }
260 ir_assert(src2_base.is_same(ptr_base));
261 };
262
263 // Iterate through dpas and send calls.
264 auto stmt_vec = flatten_statements(load_mul_stmt_);
265 for (auto &s : stmt_vec) {
266 send_info_t send_info;
267 if (is_send(s, send_info)) {
268 auto &buf = send_info.reg_buf();
269 stmt_t prev_send;
270 auto it = buf2send.find(buf);
271 if (it != buf2send.end()) prev_send = it->second;
272 buf2send[buf] = s;
273 send_infos_.push_back(send_info);
274 if (!prev_send.is_empty()) {
275 send_infos_.back().set_prev_send(prev_send);
276 }
277 continue;
278 }
279 dpas_info_t dpas_info;
280 if (is_dpas(s, dpas_info)) {
281 set_src2_base(dpas_info.src2_buf());
282 auto &buf = dpas_info.src2_buf();
283 auto it = buf2send.find(buf);
284 if (it == buf2send.end()) continue;
285 auto &send_info = find_send_info(it->second);
286 // For simplicity require full size match between load and dpas
287 // instructions. That is dpas src2 buffer should be fully
288 // loaded by the corresponding send message.
289 if (send_info.reg_buf_size() != dpas_info.src2_size()) {
290 ir_warning() << "Can't inject dpasw: different register "
291 "sizes in send and dpas."
292 << std::endl;
293 return false;
294 }
295 dpas_info.send_producer = send_info.call;
296 send_info.dpas_consumers.push_back(s);
297 dpas_infos_.push_back(dpas_info);
298 }
299 }
300 return true;
301 }
302
303 // Checks for the following pattern:
304 // dpas.sxr(a_dst, a_src0, src1, src2)
305 // dpas.sxr(b_dst, b_src0, src1, src2 + s * r * 4)
306 static bool can_convert_to_dpasw(
307 const dpas_info_t &a, const dpas_info_t &b) {
308 if (!a.dpas().is_equal(b.dpas())) return false;
309 if (!a.src1_buf().is_equal(b.src1_buf())) return false;
310
311 auto src2_off0 = to_cpp<int>(a.src2_buf().as<ptr_t>().off);
312 auto src2_off1 = to_cpp<int>(b.src2_buf().as<ptr_t>().off);
313
314 if (src2_off1 - src2_off0 != a.src2_size()) return false;
315
316 return true;
317 }
318
319 bool try_convert_to_dpasw(
320 dpas_info_t &a, dpas_info_t &b, grf_permutation_t &grf_perm) {
321 if (hw_ >= ngen::HW::XeHPC) return false;
322
323 // Check if DPAS -> DPASW transformation is possible.
324 if (!can_convert_to_dpasw(a, b)) return false;
325
326 // Perform the transformation:
327 // Before:
328 // send(slm, a_off, src2[0])
329 // send(slm, b_off, src2[s * r * 4])
330 // dpas.sxr(a_dst, a_src0, src1, src2[0])
331 // dpas.sxr(b_dst, b_src0, src1, src2[s * r * 4])
332 // After:
333 // send(slm, a_off + (tg_idx0 % 2) * (b_off - a_off), src2)
334 // dpasw.sxr(p_a_dst, p_a_src0, src1, src2[0])
335 // dpasw.sxr(p_b_dst, p_b_src0, src1, src2[s * r * 4 / 2])
336 // Where:
337 // p_a_dst[:] = a_dst[0:rcount / 2] + b_dst[0:rcount / 2]
338 // p_b_dst[:] = a_dst[rcount / 2:rcount] + b_dst[rcount / 2:rcount]
339 ir_assert(a.dpas().is_equal(b.dpas()));
340 auto _dpasw = dpas_t::make_dpasw(a.dpas());
341 auto &dpasw = _dpasw.as<dpas_t>();
342
343 auto a_args = a.args();
344 auto b_args = b.args();
345 dpas_t::arg_src2(b_args) -= dpasw.src2_size();
346
347 a.set_new_call(dpasw.call(a.args()), 0);
348 b.set_new_call(dpasw.call(b_args), dpasw.src2_size());
349
350 // Record permutation for registers to apply it for the destination
351 // store later.
352 const auto grf_size = ngen::GRF::bytes(hw_);
353 const auto rcount = a.dpas().rcount;
354 for (int j = 0; j < rcount; j++) {
355 int k = j % (rcount / 2);
356 auto a_old = dpas_t::arg_dst(a_args) + grf_size * j;
357 auto b_old = dpas_t::arg_dst(b_args) + grf_size * j;
358 expr_t grf_new;
359 if (j < rcount / 2) {
360 grf_new = dpas_t::arg_dst(a_args)[grf_size * k];
361 } else {
362 grf_new = dpas_t::arg_dst(b_args)[grf_size * k];
363 }
364 set_grf_permute(grf_perm, a_old, grf_new);
365 set_grf_permute(grf_perm, b_old, grf_new + grf_size * rcount / 2);
366 }
367
368 auto &a_send = find_send_info(a.send_producer);
369 auto &b_send = find_send_info(b.send_producer);
370
371 auto &a_mem_off = send_t::arg_mem_off(a_send.call);
372 auto &b_mem_off = send_t::arg_mem_off(b_send.call);
373 auto ab_addr_diff = simplify(b_mem_off - a_mem_off);
374 ir_assert(is_const(ab_addr_diff));
375
376 auto new_send_args = a_send.args();
377 send_t::arg_mem_off(new_send_args)
378 += (tg_idx0_ % 2) * to_cpp<int64_t>(ab_addr_diff);
379
380 a_send.set_new_call(a_send.send().call(new_send_args));
381 b_send.set_new_call(stmt_t(), a_send.call);
382
383 return true;
384 }
385
386 void set_grf_permute(grf_permutation_t &grf_perm, const expr_t &old_grf,
387 const expr_t &new_grf) {
388 int old_off = to_cpp<int>(old_grf.as<ptr_t>().off);
389 int new_off = to_cpp<int>(new_grf.as<ptr_t>().off);
390
391 const int grf_size = ngen::GRF::bytes(hw_);
392
393 ir_assert(old_off % grf_size == 0)
394 << "Must be aligned to GRF boundary.";
395 ir_assert(new_off % grf_size == 0)
396 << "Must be aligned to GRF boundary.";
397
398 old_off /= grf_size;
399 new_off /= grf_size;
400
401 grf_perm.set_permute(old_off, new_off);
402 }
403
404 static bool can_convert_to_dpasw(const dpas_info_t &a_dpas,
405 const send_info_t &a_send, const expr_t &tg_idx0) {
406 if (contains_object(a_send.call, tg_idx0)) return false;
407 return a_dpas.dpas().rcount % 2 == 0;
408 }
409
410 static func_t create_half_send(const send_t &send) {
411 ir_assert(send.type.elems() % 2 == 0) << "Can't create half-send.";
412 auto _s = send_t::make(send.hw, send.op, send.address,
413 send.type.with_elems(send.type.elems() / 2), send.slots,
414 send.is_lsc, send.cache_hint);
415 auto &s = _s.as<send_t>();
416 ir_assert(s.is_supported())
417 << "Can't find send reading half of the original send.";
418 MAYBE_UNUSED(s);
419 return _s;
420 }
421
422 bool try_convert_to_dpasw(dpas_info_t &a, grf_permutation_t &grf_perm) {
423 if (hw_ >= ngen::HW::XeHPC) return false;
424 if (!can_convert_to_dpasw(a, find_send_info(a.send_producer), tg_idx0_))
425 return false;
426
427 // Perform the transformation:
428 // Before:
429 // send(slm, a_off, src2[0])
430 // dpas.sxr(a_dst, a_src0, src1, src2[0])
431 // After:
432 // send(slm, a_off + (tg_idx0 % 2) * (s * r * 4 / 2), src2)
433 // dpasw.sxr(a_dst, a_src0, src1, src2[0])
434
435 auto _dpasw = dpas_t::make_dpasw(a.dpas());
436 auto &dpasw = _dpasw.as<dpas_t>();
437
438 a.set_new_call(dpasw.call(a.args()), 0);
439
440 auto &a_send = find_send_info(a.send_producer);
441 auto new_send_args = a_send.args();
442 send_t::arg_mem_off(new_send_args)
443 += (tg_idx0_ % 2) * (a.src2_size() / 2);
444 a_send.set_new_call(
445 create_half_send(a_send.send()).call(new_send_args));
446
447 return true;
448 }
449
450 ngen::HW hw_;
451 stmt_t load_mul_stmt_;
452 expr_t c_buf_;
453 stmt_t c_store_stmt_;
454 alloc_updater_t &alloc_updater_;
455 expr_t tg_idx0_;
456
457 std::vector<dpas_info_t> dpas_infos_;
458 std::vector<send_info_t> send_infos_;
459};
460
461void inject_dpasw(ngen::HW hw, stmt_t &load_mul_stmt, const expr_t &c_buf,
462 stmt_t &c_store_stmt, alloc_updater_t &alloc_updater,
463 const expr_t &tg_idx0) {
464 dpasw_injector_t injector(
465 hw, load_mul_stmt, c_buf, c_store_stmt, alloc_updater, tg_idx0);
466
467 injector.inject();
468 load_mul_stmt = injector.load_mul_stmt();
469 c_store_stmt = injector.c_store_stmt();
470}
471
472} // namespace jit
473} // namespace gpu
474} // namespace impl
475} // namespace dnnl
476