1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/dpasw.hpp" |
18 | |
19 | #include "gpu/jit/ir/fma.hpp" |
20 | #include "gpu/jit/ir/grf_permutation.hpp" |
21 | #include "gpu/jit/ir/message.hpp" |
22 | #include "gpu/jit/utils/trace.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace gpu { |
27 | namespace jit { |
28 | |
29 | class dpasw_injector_t { |
30 | public: |
31 | dpasw_injector_t(ngen::HW hw, const stmt_t &load_mul_stmt, |
32 | const expr_t &c_buf, const stmt_t &c_store_stmt, |
33 | alloc_updater_t &alloc_updater, const expr_t &tg_idx0) |
34 | : hw_(hw) |
35 | , load_mul_stmt_(load_mul_stmt) |
36 | , c_buf_(c_buf) |
37 | , c_store_stmt_(c_store_stmt) |
38 | , alloc_updater_(alloc_updater) |
39 | , tg_idx0_(tg_idx0) {} |
40 | |
41 | const stmt_t &load_mul_stmt() const { return load_mul_stmt_; } |
42 | |
43 | const stmt_t &c_store_stmt() const { return c_store_stmt_; } |
44 | |
45 | void inject() { |
46 | expr_t src2_base; |
47 | if (!extract_dpas_calls(src2_base)) return; |
48 | |
49 | grf_permutation_t grf_perm; |
50 | |
51 | bool was_injected = false; |
52 | int dpas_count = int(dpas_infos_.size()); |
53 | for (int i = 0; i < dpas_count;) { |
54 | if (i + 1 < dpas_count) { |
55 | auto &a = dpas_infos_[i]; |
56 | auto &b = dpas_infos_[i + 1]; |
57 | if (try_convert_to_dpasw(a, b, grf_perm)) { |
58 | was_injected = true; |
59 | i += 2; |
60 | continue; |
61 | } |
62 | } |
63 | if (try_convert_to_dpasw(dpas_infos_[i], grf_perm)) { |
64 | was_injected = true; |
65 | } |
66 | ++i; |
67 | } |
68 | // Nothing to update, no dpas -> dpasw transformation. |
69 | if (!was_injected) return; |
70 | |
71 | int src2_size = 0; |
72 | object_map_t<stmt_t, int> send2off; |
73 | std::function<int(const stmt_t &)> get_src2_off; |
74 | get_src2_off = [&](const stmt_t &s) { |
75 | auto &si = find_send_info(s); |
76 | if (!si.base_call.is_empty()) return get_src2_off(si.base_call); |
77 | if (!si.prev_send.is_empty()) return get_src2_off(si.prev_send); |
78 | |
79 | auto it = send2off.find(s); |
80 | if (it != send2off.end()) return it->second; |
81 | |
82 | auto ret = send2off.insert({s, src2_size}); |
83 | if (!ret.second) return ret.first->second; |
84 | |
85 | int new_size = si.new_reg_buf_size(); |
86 | src2_size += new_size; |
87 | return ret.first->second; |
88 | }; |
89 | for (auto &si : send_infos_) { |
90 | if (!si.reg_buf_base().is_equal(src2_base)) continue; |
91 | |
92 | int src2_off = get_src2_off(si.call); |
93 | auto src2_sub = src2_base[src2_off]; |
94 | auto new_call = si.new_call; |
95 | if (!new_call.is_empty()) { |
96 | new_call = substitute( |
97 | new_call, send_t::arg_reg_buf(new_call), src2_sub, 1); |
98 | } |
99 | |
100 | load_mul_stmt_ = substitute(load_mul_stmt_, si.call, new_call, 1); |
101 | for (auto &d : si.dpas_consumers) { |
102 | auto &di = find_dpas_info(d); |
103 | ir_assert(si.promote_to_dpasw == di.promote_to_dpasw) |
104 | << "Both send and dpas must be updated." ; |
105 | if (di.update_applied) { |
106 | ir_error_not_expected() << "Can it happen?" ; |
107 | continue; |
108 | } |
109 | auto new_call = di.new_call; |
110 | new_call = substitute(new_call, dpas_t::arg_src2(new_call), |
111 | src2_sub[di.src2_relative_off], 1); |
112 | load_mul_stmt_ |
113 | = substitute(load_mul_stmt_, di.call, new_call, 1); |
114 | di.update_applied = true; |
115 | } |
116 | } |
117 | |
118 | // Update src2 size after applying send updates. |
119 | alloc_updater_.resize(src2_base, src2_size); |
120 | |
121 | // Apply permutation to C buffer. |
122 | alloc_updater_.add_attr(c_buf_, |
123 | grf_permute_attr_t::make( |
124 | std::make_shared<grf_permutation_t>(grf_perm))); |
125 | } |
126 | |
127 | private: |
128 | struct send_info_t { |
129 | send_info_t() = default; |
130 | |
131 | send_info_t(const stmt_t &call) : call(call), new_call(call) {} |
132 | |
133 | const send_t &send() const { |
134 | return call.as<func_call_t>().func.as<send_t>(); |
135 | } |
136 | |
137 | const send_t &new_send() const { |
138 | ir_assert(!new_call.is_same(call)); |
139 | return new_call.as<func_call_t>().func.as<send_t>(); |
140 | } |
141 | |
142 | const std::vector<expr_t> &args() const { |
143 | return call.as<func_call_t>().args; |
144 | } |
145 | |
146 | const expr_t ®_buf() const { return send_t::arg_reg_buf(call); } |
147 | |
148 | const expr_t ®_buf_base() const { |
149 | return reg_buf().as<ptr_t>().base; |
150 | } |
151 | |
152 | int reg_buf_size() const { return send().payload_size(); } |
153 | |
154 | int new_reg_buf_size() const { |
155 | if (new_call.is_same(call)) return reg_buf_size(); |
156 | return new_send().payload_size(); |
157 | } |
158 | |
159 | void set_new_call(const stmt_t &s, const stmt_t &base = stmt_t()) { |
160 | if (!promote_to_dpasw) { |
161 | promote_to_dpasw = true; |
162 | new_call = s; |
163 | base_call = base; |
164 | return; |
165 | } |
166 | ir_assert(new_call.is_equal(s)); |
167 | ir_assert(base_call.is_equal(base)); |
168 | } |
169 | |
170 | void set_prev_send(const stmt_t &s) { |
171 | int prev_size |
172 | = s.as<func_call_t>().func.as<send_t>().payload_size(); |
173 | if (reg_buf_size() != prev_size) return; |
174 | prev_send = s; |
175 | } |
176 | |
177 | stmt_t call; |
178 | std::vector<stmt_t> dpas_consumers; |
179 | |
180 | bool promote_to_dpasw = false; |
181 | stmt_t new_call; |
182 | stmt_t base_call; |
183 | stmt_t prev_send; |
184 | }; |
185 | |
186 | struct dpas_info_t { |
187 | dpas_info_t() = default; |
188 | |
189 | dpas_info_t(const stmt_t &call) : call(call), new_call(call) {} |
190 | |
191 | const dpas_t &dpas() const { |
192 | return call.as<func_call_t>().func.as<dpas_t>(); |
193 | } |
194 | |
195 | const std::vector<expr_t> &args() const { |
196 | return call.as<func_call_t>().args; |
197 | } |
198 | |
199 | const expr_t &src1_buf() const { return dpas_t::arg_src1(call); } |
200 | |
201 | const expr_t &src2_buf() const { return dpas_t::arg_src2(call); } |
202 | |
203 | int src2_size() const { return dpas().src2_size(); } |
204 | |
205 | void set_new_call(const stmt_t &s, int src2_relative_off) { |
206 | if (!promote_to_dpasw) { |
207 | promote_to_dpasw = true; |
208 | this->src2_relative_off = src2_relative_off; |
209 | new_call = s; |
210 | return; |
211 | } |
212 | ir_assert(this->src2_relative_off == src2_relative_off); |
213 | ir_assert(new_call.is_equal(s)); |
214 | } |
215 | |
216 | stmt_t call; |
217 | stmt_t send_producer; |
218 | |
219 | bool promote_to_dpasw = false; |
220 | bool update_applied = false; |
221 | int src2_relative_off = 0; |
222 | stmt_t new_call; |
223 | }; |
224 | |
225 | send_info_t &find_send_info(const stmt_t &s) { |
226 | for (auto &si : send_infos_) |
227 | if (si.call.is_same(s)) return si; |
228 | ir_error_not_expected(); |
229 | return send_infos_.front(); |
230 | } |
231 | |
232 | dpas_info_t &find_dpas_info(const stmt_t &s) { |
233 | for (auto &si : dpas_infos_) |
234 | if (si.call.is_same(s)) return si; |
235 | ir_error_not_expected(); |
236 | return dpas_infos_.front(); |
237 | } |
238 | static bool is_send(const stmt_t &s, send_info_t &info) { |
239 | if (!is_func_call<send_t>(s)) return false; |
240 | info = send_info_t(s); |
241 | return true; |
242 | } |
243 | |
244 | static bool is_dpas(const stmt_t &s, dpas_info_t &info) { |
245 | if (!is_func_call<dpas_t>(s)) return false; |
246 | if (dpas_t::is_dp4a_call(s)) return false; |
247 | info = dpas_info_t(s); |
248 | return true; |
249 | } |
250 | |
251 | bool (expr_t &src2_base) { |
252 | object_eq_map_t<expr_t, stmt_t> buf2send; |
253 | |
254 | auto set_src2_base = [&](const expr_t &ptr) { |
255 | auto &ptr_base = ptr.as<ptr_t>().base; |
256 | if (src2_base.is_empty()) { |
257 | src2_base = ptr_base; |
258 | return; |
259 | } |
260 | ir_assert(src2_base.is_same(ptr_base)); |
261 | }; |
262 | |
263 | // Iterate through dpas and send calls. |
264 | auto stmt_vec = flatten_statements(load_mul_stmt_); |
265 | for (auto &s : stmt_vec) { |
266 | send_info_t send_info; |
267 | if (is_send(s, send_info)) { |
268 | auto &buf = send_info.reg_buf(); |
269 | stmt_t prev_send; |
270 | auto it = buf2send.find(buf); |
271 | if (it != buf2send.end()) prev_send = it->second; |
272 | buf2send[buf] = s; |
273 | send_infos_.push_back(send_info); |
274 | if (!prev_send.is_empty()) { |
275 | send_infos_.back().set_prev_send(prev_send); |
276 | } |
277 | continue; |
278 | } |
279 | dpas_info_t dpas_info; |
280 | if (is_dpas(s, dpas_info)) { |
281 | set_src2_base(dpas_info.src2_buf()); |
282 | auto &buf = dpas_info.src2_buf(); |
283 | auto it = buf2send.find(buf); |
284 | if (it == buf2send.end()) continue; |
285 | auto &send_info = find_send_info(it->second); |
286 | // For simplicity require full size match between load and dpas |
287 | // instructions. That is dpas src2 buffer should be fully |
288 | // loaded by the corresponding send message. |
289 | if (send_info.reg_buf_size() != dpas_info.src2_size()) { |
290 | ir_warning() << "Can't inject dpasw: different register " |
291 | "sizes in send and dpas." |
292 | << std::endl; |
293 | return false; |
294 | } |
295 | dpas_info.send_producer = send_info.call; |
296 | send_info.dpas_consumers.push_back(s); |
297 | dpas_infos_.push_back(dpas_info); |
298 | } |
299 | } |
300 | return true; |
301 | } |
302 | |
303 | // Checks for the following pattern: |
304 | // dpas.sxr(a_dst, a_src0, src1, src2) |
305 | // dpas.sxr(b_dst, b_src0, src1, src2 + s * r * 4) |
306 | static bool can_convert_to_dpasw( |
307 | const dpas_info_t &a, const dpas_info_t &b) { |
308 | if (!a.dpas().is_equal(b.dpas())) return false; |
309 | if (!a.src1_buf().is_equal(b.src1_buf())) return false; |
310 | |
311 | auto src2_off0 = to_cpp<int>(a.src2_buf().as<ptr_t>().off); |
312 | auto src2_off1 = to_cpp<int>(b.src2_buf().as<ptr_t>().off); |
313 | |
314 | if (src2_off1 - src2_off0 != a.src2_size()) return false; |
315 | |
316 | return true; |
317 | } |
318 | |
319 | bool try_convert_to_dpasw( |
320 | dpas_info_t &a, dpas_info_t &b, grf_permutation_t &grf_perm) { |
321 | if (hw_ >= ngen::HW::XeHPC) return false; |
322 | |
323 | // Check if DPAS -> DPASW transformation is possible. |
324 | if (!can_convert_to_dpasw(a, b)) return false; |
325 | |
326 | // Perform the transformation: |
327 | // Before: |
328 | // send(slm, a_off, src2[0]) |
329 | // send(slm, b_off, src2[s * r * 4]) |
330 | // dpas.sxr(a_dst, a_src0, src1, src2[0]) |
331 | // dpas.sxr(b_dst, b_src0, src1, src2[s * r * 4]) |
332 | // After: |
333 | // send(slm, a_off + (tg_idx0 % 2) * (b_off - a_off), src2) |
334 | // dpasw.sxr(p_a_dst, p_a_src0, src1, src2[0]) |
335 | // dpasw.sxr(p_b_dst, p_b_src0, src1, src2[s * r * 4 / 2]) |
336 | // Where: |
337 | // p_a_dst[:] = a_dst[0:rcount / 2] + b_dst[0:rcount / 2] |
338 | // p_b_dst[:] = a_dst[rcount / 2:rcount] + b_dst[rcount / 2:rcount] |
339 | ir_assert(a.dpas().is_equal(b.dpas())); |
340 | auto _dpasw = dpas_t::make_dpasw(a.dpas()); |
341 | auto &dpasw = _dpasw.as<dpas_t>(); |
342 | |
343 | auto a_args = a.args(); |
344 | auto b_args = b.args(); |
345 | dpas_t::arg_src2(b_args) -= dpasw.src2_size(); |
346 | |
347 | a.set_new_call(dpasw.call(a.args()), 0); |
348 | b.set_new_call(dpasw.call(b_args), dpasw.src2_size()); |
349 | |
350 | // Record permutation for registers to apply it for the destination |
351 | // store later. |
352 | const auto grf_size = ngen::GRF::bytes(hw_); |
353 | const auto rcount = a.dpas().rcount; |
354 | for (int j = 0; j < rcount; j++) { |
355 | int k = j % (rcount / 2); |
356 | auto a_old = dpas_t::arg_dst(a_args) + grf_size * j; |
357 | auto b_old = dpas_t::arg_dst(b_args) + grf_size * j; |
358 | expr_t grf_new; |
359 | if (j < rcount / 2) { |
360 | grf_new = dpas_t::arg_dst(a_args)[grf_size * k]; |
361 | } else { |
362 | grf_new = dpas_t::arg_dst(b_args)[grf_size * k]; |
363 | } |
364 | set_grf_permute(grf_perm, a_old, grf_new); |
365 | set_grf_permute(grf_perm, b_old, grf_new + grf_size * rcount / 2); |
366 | } |
367 | |
368 | auto &a_send = find_send_info(a.send_producer); |
369 | auto &b_send = find_send_info(b.send_producer); |
370 | |
371 | auto &a_mem_off = send_t::arg_mem_off(a_send.call); |
372 | auto &b_mem_off = send_t::arg_mem_off(b_send.call); |
373 | auto ab_addr_diff = simplify(b_mem_off - a_mem_off); |
374 | ir_assert(is_const(ab_addr_diff)); |
375 | |
376 | auto new_send_args = a_send.args(); |
377 | send_t::arg_mem_off(new_send_args) |
378 | += (tg_idx0_ % 2) * to_cpp<int64_t>(ab_addr_diff); |
379 | |
380 | a_send.set_new_call(a_send.send().call(new_send_args)); |
381 | b_send.set_new_call(stmt_t(), a_send.call); |
382 | |
383 | return true; |
384 | } |
385 | |
386 | void set_grf_permute(grf_permutation_t &grf_perm, const expr_t &old_grf, |
387 | const expr_t &new_grf) { |
388 | int old_off = to_cpp<int>(old_grf.as<ptr_t>().off); |
389 | int new_off = to_cpp<int>(new_grf.as<ptr_t>().off); |
390 | |
391 | const int grf_size = ngen::GRF::bytes(hw_); |
392 | |
393 | ir_assert(old_off % grf_size == 0) |
394 | << "Must be aligned to GRF boundary." ; |
395 | ir_assert(new_off % grf_size == 0) |
396 | << "Must be aligned to GRF boundary." ; |
397 | |
398 | old_off /= grf_size; |
399 | new_off /= grf_size; |
400 | |
401 | grf_perm.set_permute(old_off, new_off); |
402 | } |
403 | |
404 | static bool can_convert_to_dpasw(const dpas_info_t &a_dpas, |
405 | const send_info_t &a_send, const expr_t &tg_idx0) { |
406 | if (contains_object(a_send.call, tg_idx0)) return false; |
407 | return a_dpas.dpas().rcount % 2 == 0; |
408 | } |
409 | |
410 | static func_t create_half_send(const send_t &send) { |
411 | ir_assert(send.type.elems() % 2 == 0) << "Can't create half-send." ; |
412 | auto _s = send_t::make(send.hw, send.op, send.address, |
413 | send.type.with_elems(send.type.elems() / 2), send.slots, |
414 | send.is_lsc, send.cache_hint); |
415 | auto &s = _s.as<send_t>(); |
416 | ir_assert(s.is_supported()) |
417 | << "Can't find send reading half of the original send." ; |
418 | MAYBE_UNUSED(s); |
419 | return _s; |
420 | } |
421 | |
422 | bool try_convert_to_dpasw(dpas_info_t &a, grf_permutation_t &grf_perm) { |
423 | if (hw_ >= ngen::HW::XeHPC) return false; |
424 | if (!can_convert_to_dpasw(a, find_send_info(a.send_producer), tg_idx0_)) |
425 | return false; |
426 | |
427 | // Perform the transformation: |
428 | // Before: |
429 | // send(slm, a_off, src2[0]) |
430 | // dpas.sxr(a_dst, a_src0, src1, src2[0]) |
431 | // After: |
432 | // send(slm, a_off + (tg_idx0 % 2) * (s * r * 4 / 2), src2) |
433 | // dpasw.sxr(a_dst, a_src0, src1, src2[0]) |
434 | |
435 | auto _dpasw = dpas_t::make_dpasw(a.dpas()); |
436 | auto &dpasw = _dpasw.as<dpas_t>(); |
437 | |
438 | a.set_new_call(dpasw.call(a.args()), 0); |
439 | |
440 | auto &a_send = find_send_info(a.send_producer); |
441 | auto new_send_args = a_send.args(); |
442 | send_t::arg_mem_off(new_send_args) |
443 | += (tg_idx0_ % 2) * (a.src2_size() / 2); |
444 | a_send.set_new_call( |
445 | create_half_send(a_send.send()).call(new_send_args)); |
446 | |
447 | return true; |
448 | } |
449 | |
450 | ngen::HW hw_; |
451 | stmt_t load_mul_stmt_; |
452 | expr_t c_buf_; |
453 | stmt_t c_store_stmt_; |
454 | alloc_updater_t &alloc_updater_; |
455 | expr_t tg_idx0_; |
456 | |
457 | std::vector<dpas_info_t> dpas_infos_; |
458 | std::vector<send_info_t> send_infos_; |
459 | }; |
460 | |
461 | void inject_dpasw(ngen::HW hw, stmt_t &load_mul_stmt, const expr_t &c_buf, |
462 | stmt_t &c_store_stmt, alloc_updater_t &alloc_updater, |
463 | const expr_t &tg_idx0) { |
464 | dpasw_injector_t injector( |
465 | hw, load_mul_stmt, c_buf, c_store_stmt, alloc_updater, tg_idx0); |
466 | |
467 | injector.inject(); |
468 | load_mul_stmt = injector.load_mul_stmt(); |
469 | c_store_stmt = injector.c_store_stmt(); |
470 | } |
471 | |
472 | } // namespace jit |
473 | } // namespace gpu |
474 | } // namespace impl |
475 | } // namespace dnnl |
476 | |