1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_IR_MESSAGE_HPP |
18 | #define GPU_JIT_IR_MESSAGE_HPP |
19 | |
20 | #include "gpu/jit/ir/fma.hpp" |
21 | #include "gpu/jit/ir/gemm_schedule.hpp" |
22 | #include "gpu/jit/ir/ir.hpp" |
23 | #include "gpu/jit/ir/tensor.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace gpu { |
28 | namespace jit { |
29 | |
// Send operation kind.
enum class send_op_t {
    atomic_fadd, // Atomic floating-point add.
    load, // Regular (block or scattered) load.
    load_2d, // 2D block load.
    prefetch, // Regular prefetch.
    prefetch_2d, // 2D block prefetch.
    store, // Regular (block or scattered) store.
    store_2d, // 2D block store.
};
40 | |
41 | std::ostream &operator<<(std::ostream &out, const send_op_t value); |
42 | |
// Send address model.
enum class send_address_t {
    a64, // 64-bit flat (stateless) addressing.
    bts, // Binding-table / surface-based addressing.
    slm, // Shared local memory addressing.
};
49 | |
// Cache behavior hint attached to a send message (see get_cache_settings()).
enum class send_cache_hint_t {
    undef, // No specific cache policy requested.
    load_once, // Loaded data is not expected to be reused.
};
54 | |
55 | inline std::string to_string(send_cache_hint_t hint) { |
56 | switch (hint) { |
57 | case send_cache_hint_t::undef: return "cache:undef" ; |
58 | case send_cache_hint_t::load_once: return "cache:load_once" ; |
59 | default: return "cache:error" ; |
60 | } |
61 | } |
62 | |
63 | inline std::ostream &operator<<( |
64 | std::ostream &out, const send_cache_hint_t hint) { |
65 | out << to_string(hint); |
66 | return out; |
67 | } |
68 | |
69 | struct block_2d_info_t { |
70 | bool is_empty() const { return surface_width == 0; } |
71 | |
72 | bool operator==(const block_2d_info_t &other) const { |
73 | if (is_empty() != other.is_empty()) return false; |
74 | if (is_empty()) return true; |
75 | return (surface_width == other.surface_width) |
76 | && (surface_height == other.surface_height) |
77 | && (surface_pitch == other.surface_pitch) |
78 | && (width == other.width) && (height == other.height) |
79 | && (count == other.count) && (vnni == other.vnni) |
80 | && (transpose == other.transpose); |
81 | } |
82 | |
83 | size_t get_hash() const { |
84 | if (is_empty()) return 0; |
85 | return ir_utils::get_hash(surface_width, surface_height, surface_pitch, |
86 | width, height, count, vnni, transpose); |
87 | } |
88 | |
89 | std::string str() const { |
90 | std::ostringstream oss; |
91 | oss << count << "x" ; |
92 | oss << height << "x" ; |
93 | oss << width; |
94 | if (vnni || transpose) { |
95 | oss << "." ; |
96 | if (vnni) oss << "v" ; |
97 | if (transpose) oss << "t" ; |
98 | } |
99 | return oss.str(); |
100 | } |
101 | |
102 | // Encoded in header. |
103 | int surface_width = 0; |
104 | int surface_height = 0; |
105 | int surface_pitch = 0; |
106 | int width = 0; |
107 | int height = 0; |
108 | int count = 0; |
109 | // Part of descriptor. |
110 | bool vnni = false; |
111 | bool transpose = false; |
112 | }; |
113 | |
114 | // Function representing send messages. |
115 | class send_t : public func_impl_t { |
116 | public: |
117 | IR_DECL_DERIVED_TYPE_ID(send_t, func_impl_t) |
118 | |
119 | static func_t make(ngen::HW hw, send_op_t op, send_address_t address, |
120 | const type_t &type, int slots, |
121 | send_cache_hint_t cache_hint = send_cache_hint_t::undef) { |
122 | return make(hw, op, address, type, slots, hw >= ngen::HW::XeHPC, |
123 | cache_hint); |
124 | } |
125 | |
126 | static func_t make(ngen::HW hw, send_op_t op, send_address_t address, |
127 | const type_t &type, int slots, bool is_lsc, |
128 | send_cache_hint_t cache_hint = send_cache_hint_t::undef) { |
129 | return func_t( |
130 | new send_t(hw, op, address, type, slots, is_lsc, cache_hint)); |
131 | } |
132 | |
133 | static func_t make_2d(ngen::HW hw, send_op_t op, const type_t &type, |
134 | int surface_width, int surface_height, int surface_pitch, int width, |
135 | int height, int count, bool vnni, bool transpose, |
136 | send_cache_hint_t cache_hint = send_cache_hint_t::undef) { |
137 | block_2d_info_t info; |
138 | info.surface_width = surface_width; |
139 | info.surface_height = surface_height; |
140 | info.surface_pitch = surface_pitch; |
141 | info.width = width; |
142 | info.height = height; |
143 | info.count = count; |
144 | info.vnni = vnni; |
145 | info.transpose = transpose; |
146 | return func_t(new send_t(hw, op, type, info, cache_hint)); |
147 | } |
148 | |
149 | bool is_equal(const object_impl_t &obj) const override { |
150 | if (!obj.is<self_type>()) return false; |
151 | auto &other = obj.as<self_type>(); |
152 | |
153 | return (hw == other.hw) && (op == other.op) |
154 | && (address == other.address) && (type == other.type) |
155 | && (slots == other.slots) && (is_lsc == other.is_lsc) |
156 | && (block_2d_info == other.block_2d_info); |
157 | } |
158 | |
159 | size_t get_hash() const override { |
160 | return ir_utils::get_hash( |
161 | hw, op, address, type, slots, is_lsc, block_2d_info); |
162 | } |
163 | |
164 | std::string str() const override { |
165 | std::ostringstream oss; |
166 | oss << op; |
167 | oss << "." ; |
168 | if (is_scattered()) oss << slots << "x" ; |
169 | oss << type.str(); |
170 | if (is_2d()) oss << "." << block_2d_info.str(); |
171 | if (cache_hint != send_cache_hint_t::undef) oss << "." << cache_hint; |
172 | return oss.str(); |
173 | } |
174 | |
175 | IR_DEFINE_ARG_GET(mem_buf, 0) |
176 | IR_DEFINE_ARG_GET(mem_off, 1) |
177 | IR_DEFINE_ARG_GET(header_buf, 1) |
178 | IR_DEFINE_ARG_GET(reg_buf, 2) |
179 | IR_DEFINE_ARG_GET(mask, 3) |
180 | IR_DEFINE_ARG_GET(x, 4) |
181 | IR_DEFINE_ARG_GET(y, 5) |
182 | |
183 | // Header offsets in bytes for 2D block messages. |
184 | static int () { return 0; } |
185 | static int () { return 8; } |
186 | static int () { return 12; } |
187 | static int () { return 16; } |
188 | static int () { return 20; } |
189 | static int () { return 24; } |
190 | static int () { return 28; } |
191 | |
192 | stmt_t operator()(const expr_t &mem_buf, const expr_t &mem_off, |
193 | const expr_t ®_buf, const expr_t &mask, |
194 | const expr_t &x = expr_t(), const expr_t &y = expr_t()) const { |
195 | return call({mem_buf, mem_off, reg_buf, mask, x, y}); |
196 | } |
197 | |
198 | bool is_atomic() const { return op == send_op_t::atomic_fadd; } |
199 | bool is_load() const { return op == send_op_t::load; } |
200 | bool is_load_2d() const { return op == send_op_t::load_2d; } |
201 | bool is_prefetch() const { return op == send_op_t::prefetch; } |
202 | bool is_prefetch_2d() const { return op == send_op_t::prefetch_2d; } |
203 | bool is_store() const { return op == send_op_t::store; } |
204 | bool is_store_2d() const { return op == send_op_t::store_2d; } |
205 | bool is_2d() const { |
206 | return is_load_2d() || is_store_2d() || is_prefetch_2d(); |
207 | } |
208 | bool is_a64() const { return address == send_address_t::a64; } |
209 | bool is_bts() const { return address == send_address_t::bts; } |
210 | bool is_slm() const { return address == send_address_t::slm; } |
211 | |
212 | bool is_block() const { |
213 | return utils::one_of( |
214 | type.kind(), type_kind_t::oword, type_kind_t::hword); |
215 | } |
216 | |
217 | bool is_scattered() const { return !is_block() && !is_2d(); } |
218 | |
219 | // Size of memory (global memory or SLM) to access. |
220 | int access_size() const { |
221 | if (is_2d()) { |
222 | auto &info = block_2d_info; |
223 | return type.size() * info.width * info.height * info.count; |
224 | } |
225 | return type.size() * slots; |
226 | } |
227 | |
228 | int payload_type_stride() const { |
229 | ir_assert(!is_2d()); |
230 | if (type.kind() == type_kind_t::byte) return 4; |
231 | return type.size(); |
232 | } |
233 | |
234 | // Full size of payload GRF buffer for this message. Buffer may be strided |
235 | // and/or require GRF boundary round-up. |
236 | int payload_size() const { |
237 | if (is_2d()) return utils::rnd_up(access_size(), grf_size()); |
238 | int sz = payload_type_stride() * slots; |
239 | return utils::rnd_up(sz, grf_size()); |
240 | } |
241 | |
242 | int alignment() const { |
243 | if (is_2d()) return 128; |
244 | if (is_block()) return type.scalar().size(); |
245 | return 1; |
246 | } |
247 | |
248 | int mask_size() const { |
249 | if (is_2d()) return access_size(); |
250 | if (is_block()) { |
251 | // LSC messages use SIMT1 execution mask (one mask per message). |
252 | if (is_lsc) return type.size(); |
253 | return 4; |
254 | } |
255 | |
256 | if (is_scattered()) return type.size(); |
257 | |
258 | ir_error_not_expected(); |
259 | return 0; |
260 | } |
261 | |
262 | int nmasks() const { |
263 | if (is_2d()) return 1; |
264 | int masks = ir_utils::safe_divide(type.size() * slots, mask_size()); |
265 | if (masks > 16) { |
266 | ir_assert(is_block()) |
267 | << "Round-robin masking applies to block messages only." ; |
268 | ir_assert(masks % 16 == 0); |
269 | masks = 16; |
270 | } |
271 | return masks; |
272 | } |
273 | |
274 | int address_size() const { return is_a64() ? 8 : 4; } |
275 | |
276 | type_t address_type(bool is_signed = false, int elems = 1) const { |
277 | int bits = address_size() * 8; |
278 | return is_signed ? type_t::s(bits, elems) : type_t::u(bits, elems); |
279 | } |
280 | |
281 | // Size of header in bytes. |
282 | int () const { |
283 | if (is_2d()) return grf_size(); |
284 | return utils::rnd_up(address_size() * slots, grf_size()); |
285 | } |
286 | |
287 | // Generates a statement to store (and maybe convert) the offset to the |
288 | // message header according to the message description. |
289 | stmt_t create_offset_store(const expr_t &, const expr_t &mem_buf, |
290 | const expr_t &mem_off, bool is_signed_offset = false) const; |
291 | |
292 | bool is_supported() const; |
293 | |
294 | static std::vector<func_t> get_all(ngen::HW hw, send_op_t op, |
295 | send_address_t address, const type_t &mem_type, |
296 | send_cache_hint_t cache_hint); |
297 | |
298 | ngen::HW hw; |
299 | send_op_t op; |
300 | send_address_t address; |
301 | type_t type; |
302 | int slots; |
303 | bool is_lsc; |
304 | |
305 | block_2d_info_t block_2d_info; |
306 | send_cache_hint_t cache_hint; |
307 | |
308 | private: |
309 | int grf_size() const { return ngen::GRF::bytes(hw); } |
310 | |
311 | bool is_xe_hp_plus() const { return hw >= ngen::HW::XeHP; } |
312 | |
313 | bool is_xe_hpc_plus() const { return hw >= ngen::HW::XeHPC; } |
314 | |
315 | send_t(ngen::HW hw, send_op_t op, send_address_t address, |
316 | const type_t &type, int slots, bool is_lsc, |
317 | send_cache_hint_t cache_hint) |
318 | : func_impl_t(_type_info()) |
319 | , hw(hw) |
320 | , op(op) |
321 | , address(address) |
322 | , type(type) |
323 | , slots(slots) |
324 | , is_lsc(is_lsc) |
325 | , cache_hint(cache_hint) {} |
326 | |
327 | send_t(ngen::HW hw, send_op_t op, const type_t &type, |
328 | const block_2d_info_t &block_2d_info, send_cache_hint_t cache_hint) |
329 | : func_impl_t(_type_info()) |
330 | , hw(hw) |
331 | , op(op) |
332 | , address(send_address_t::a64) |
333 | , type(type) |
334 | , slots(1) |
335 | , is_lsc(true) |
336 | , block_2d_info(block_2d_info) |
337 | , cache_hint(cache_hint) { |
338 | ir_assert(utils::one_of(op, send_op_t::load_2d, send_op_t::store_2d, |
339 | send_op_t::prefetch_2d)); |
340 | if (is_store_2d()) { |
341 | ir_assert(!block_2d_info.vnni); |
342 | ir_assert(!block_2d_info.transpose); |
343 | } |
344 | } |
345 | }; |
346 | |
347 | ngen::CacheSettingsLSC get_cache_settings(const send_t &send); |
348 | |
349 | class memory_walker_t; |
350 | class layout_walker_t; |
351 | |
// Parameters describing a candidate 2D block-message decomposition.
struct send_2d_hint_t {
    type_t type; // Element type to use for the 2D message.
    bool enable = false; // Whether a 2D message should be used at all.
    bool vnni = false; // Request the VNNI transform.
    bool transpose = false; // Request the transpose transform.
    // NOTE(review): consumed by fixup_send_2d_params() in the access
    // builder; exact permute semantics are defined there -- confirm.
    int vnni_permute_factor = 0;
    int width = 0; // Block width.
    int height = 0; // Block height.
};
361 | |
362 | struct send_hint_t { |
363 | send_op_t convert(const send_op_t &op) const { |
364 | if (hint_2d.enable) { |
365 | if (op == send_op_t::load) return send_op_t::load_2d; |
366 | if (op == send_op_t::store) return send_op_t::store_2d; |
367 | if (op == send_op_t::prefetch) return send_op_t::prefetch_2d; |
368 | } |
369 | return op; |
370 | } |
371 | |
372 | send_2d_hint_t hint_2d; |
373 | }; |
374 | |
// Generates loads or stores to move data between memory (global or SLM) and
// GRF. Memory view is a parameter. GRF payload layout is deduced
// automatically, according to the decomposition into messages.
class access_builder_t {
public:
    // Builds the access for mem_view between mem_buf and reg_buf.
    // send_hint is taken by reference and forwarded to the 2D build path
    // (see try_build_2d()).
    access_builder_t(ir_context_t &ir_ctx, const view_t &mem_view,
            const expr_t &mem_buf, const expr_t &reg_buf, send_op_t send_op,
            send_address_t send_address, send_cache_hint_t send_cache_hint,
            send_hint_t &send_hint);
    access_builder_t(access_builder_t &&);
    ~access_builder_t();

    // Deduced GRF layout of the register payload.
    const layout_t &reg_layout() const { return reg_layout_; }
    // Payload buffer size, rounded up to a whole number of GRFs.
    int reg_buf_size() const {
        return utils::rnd_up(reg_layout_.size(), grf_size());
    }
    // The generated load/store statement.
    const stmt_t &stmt() const { return stmt_; }

    // Multi-line debug dump: view, layout, buffer size and the statement.
    std::string str() const {
        std::ostringstream oss;
        oss << "Memory view: " << mem_view_ << std::endl;
        oss << "Register layout: " << reg_layout_ << std::endl;
        oss << "Register buffer: " << reg_buf_ << std::endl;
        oss << "Register buffer size: " << reg_buf_size() << " ("
            << reg_buf_size() / grf_size() << " regs)" << std::endl;
        oss << "Statement: " << std::endl << stmt_;
        return oss.str();
    }

private:
    // Implementations live in the corresponding .cpp file.
    void build();
    bool try_build(const layout_t &try_layout);
    bool try_build_2d(send_hint_t &send_hint);
    bool fixup_send_2d_params(const type_t &send_type, bool vnni,
            bool transpose, bool use_xy, int &W, int &H, int &P, int &w, int &h,
            int &c, int &vnni_permute_factor);

    bool check_2d_mask(const tensor_t &tile, bool use_virtual_surface,
            int w_idx, int h_idx, expr_t &mask) const;

    std::vector<layout_t> candidate_payload_layouts() const;
    stmt_t create_send_stmt(const send_t &send);
    int grf_size() const { return ngen::GRF::bytes(ir_ctx_->hw_cfg().hw()); }

    ir_context_t *ir_ctx_ = nullptr; // Non-owning; set by the constructor.
    view_t mem_view_; // Memory view to access.
    expr_t mem_buf_; // Memory base buffer.
    expr_t reg_buf_; // Register buffer.
    send_op_t send_op_;
    send_address_t send_address_;
    send_cache_hint_t send_cache_hint_;

    type_t mem_type_;

    std::unique_ptr<memory_walker_t> mem_walker_;
    std::unique_ptr<layout_walker_t> reg_layout_walker_;

    layout_t reg_layout_; // Deduced payload layout (see reg_layout()).
    stmt_t stmt_; // Generated statement (see stmt()).
};
435 | |
436 | inline access_builder_t make_access_builder(ir_context_t &ir_ctx, |
437 | const view_t &mem_view, const expr_t &mem_buf, const expr_t ®_buf, |
438 | send_op_t send_op, send_address_t send_address, |
439 | send_hint_t &send_hint) { |
440 | return access_builder_t(ir_ctx, mem_view, mem_buf, reg_buf, send_op, |
441 | send_address, send_cache_hint_t::undef, send_hint); |
442 | } |
443 | |
444 | inline access_builder_t make_access_builder(ir_context_t &ir_ctx, |
445 | const view_t &mem_view, const expr_t &mem_buf, const expr_t ®_buf, |
446 | send_op_t send_op, send_address_t send_address, |
447 | send_cache_hint_t send_cache_hint = send_cache_hint_t::undef) { |
448 | send_hint_t send_hint; |
449 | return access_builder_t(ir_ctx, mem_view, mem_buf, reg_buf, send_op, |
450 | send_address, send_cache_hint, send_hint); |
451 | } |
452 | |
453 | send_hint_t get_send_hint(const exec_config_t &exec_cfg, send_op_t send_op, |
454 | fma_kind_t fma_kind, abc_kind_t abc_kind, const view_t &view, |
455 | const gemm_schedule_t &gemm_schedule, bool allow_2d = true); |
456 | |
457 | inline send_hint_t get_send_hint(const exec_config_t &exec_cfg, |
458 | send_op_t send_op, abc_kind_t abc_kind, const view_t &view, |
459 | const gemm_schedule_t &gemm_schedule, bool allow_2d = true) { |
460 | return get_send_hint(exec_cfg, send_op, fma_kind_t::unknown, abc_kind, view, |
461 | gemm_schedule, allow_2d); |
462 | } |
463 | |
464 | } // namespace jit |
465 | } // namespace gpu |
466 | } // namespace impl |
467 | } // namespace dnnl |
468 | |
469 | #endif |
470 | |