1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_IR_MESSAGE_HPP
18#define GPU_JIT_IR_MESSAGE_HPP
19
20#include "gpu/jit/ir/fma.hpp"
21#include "gpu/jit/ir/gemm_schedule.hpp"
22#include "gpu/jit/ir/ir.hpp"
23#include "gpu/jit/ir/tensor.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace gpu {
28namespace jit {
29
// Send operation kind.
enum class send_op_t {
    atomic_fadd, // Atomic floating-point add.
    load,
    load_2d, // 2D block variant of load (created via send_t::make_2d()).
    prefetch,
    prefetch_2d, // 2D block variant of prefetch.
    store,
    store_2d, // 2D block variant of store.
};
40
41std::ostream &operator<<(std::ostream &out, const send_op_t value);
42
// Send address model.
enum class send_address_t {
    a64, // Stateless 64-bit (A64) addressing; uses 8-byte addresses.
    bts, // Binding-table based (stateful) addressing.
    slm, // Shared local memory addressing.
};
49
// Cache behavior hint attached to a send message.
enum class send_cache_hint_t {
    undef, // No hint; default cache behavior.
    load_once, // Data is loaded once (affects LSC cache settings; see
               // get_cache_settings()).
};
54
55inline std::string to_string(send_cache_hint_t hint) {
56 switch (hint) {
57 case send_cache_hint_t::undef: return "cache:undef";
58 case send_cache_hint_t::load_once: return "cache:load_once";
59 default: return "cache:error";
60 }
61}
62
63inline std::ostream &operator<<(
64 std::ostream &out, const send_cache_hint_t hint) {
65 out << to_string(hint);
66 return out;
67}
68
69struct block_2d_info_t {
70 bool is_empty() const { return surface_width == 0; }
71
72 bool operator==(const block_2d_info_t &other) const {
73 if (is_empty() != other.is_empty()) return false;
74 if (is_empty()) return true;
75 return (surface_width == other.surface_width)
76 && (surface_height == other.surface_height)
77 && (surface_pitch == other.surface_pitch)
78 && (width == other.width) && (height == other.height)
79 && (count == other.count) && (vnni == other.vnni)
80 && (transpose == other.transpose);
81 }
82
83 size_t get_hash() const {
84 if (is_empty()) return 0;
85 return ir_utils::get_hash(surface_width, surface_height, surface_pitch,
86 width, height, count, vnni, transpose);
87 }
88
89 std::string str() const {
90 std::ostringstream oss;
91 oss << count << "x";
92 oss << height << "x";
93 oss << width;
94 if (vnni || transpose) {
95 oss << ".";
96 if (vnni) oss << "v";
97 if (transpose) oss << "t";
98 }
99 return oss.str();
100 }
101
102 // Encoded in header.
103 int surface_width = 0;
104 int surface_height = 0;
105 int surface_pitch = 0;
106 int width = 0;
107 int height = 0;
108 int count = 0;
109 // Part of descriptor.
110 bool vnni = false;
111 bool transpose = false;
112};
113
// Function representing send messages.
//
// Describes a single GPU send: operation kind, address model, data type,
// number of SIMD slots and, for 2D block messages, the block parameters.
// Instances are IR function objects; call arguments are defined by the
// IR_DEFINE_ARG_GET list below.
class send_t : public func_impl_t {
public:
    IR_DECL_DERIVED_TYPE_ID(send_t, func_impl_t)

    // Creates a regular (non-2D) send. LSC messages are selected
    // automatically for XeHPC and newer hardware.
    static func_t make(ngen::HW hw, send_op_t op, send_address_t address,
            const type_t &type, int slots,
            send_cache_hint_t cache_hint = send_cache_hint_t::undef) {
        return make(hw, op, address, type, slots, hw >= ngen::HW::XeHPC,
                cache_hint);
    }

    // Same as above, with explicit LSC vs legacy message selection.
    static func_t make(ngen::HW hw, send_op_t op, send_address_t address,
            const type_t &type, int slots, bool is_lsc,
            send_cache_hint_t cache_hint = send_cache_hint_t::undef) {
        return func_t(
                new send_t(hw, op, address, type, slots, is_lsc, cache_hint));
    }

    // Creates a 2D block send (load_2d, store_2d or prefetch_2d). Surface
    // dimensions/pitch and the block shape are packed into block_2d_info_t.
    static func_t make_2d(ngen::HW hw, send_op_t op, const type_t &type,
            int surface_width, int surface_height, int surface_pitch, int width,
            int height, int count, bool vnni, bool transpose,
            send_cache_hint_t cache_hint = send_cache_hint_t::undef) {
        block_2d_info_t info;
        info.surface_width = surface_width;
        info.surface_height = surface_height;
        info.surface_pitch = surface_pitch;
        info.width = width;
        info.height = height;
        info.count = count;
        info.vnni = vnni;
        info.transpose = transpose;
        return func_t(new send_t(hw, op, type, info, cache_hint));
    }

    // NOTE(review): cache_hint does not participate in equality or hashing —
    // two sends differing only in cache hint compare equal. Confirm this is
    // intentional.
    bool is_equal(const object_impl_t &obj) const override {
        if (!obj.is<self_type>()) return false;
        auto &other = obj.as<self_type>();

        return (hw == other.hw) && (op == other.op)
                && (address == other.address) && (type == other.type)
                && (slots == other.slots) && (is_lsc == other.is_lsc)
                && (block_2d_info == other.block_2d_info);
    }

    size_t get_hash() const override {
        return ir_utils::get_hash(
                hw, op, address, type, slots, is_lsc, block_2d_info);
    }

    // Debug string, e.g. "load.8xd" or "load_2d.d.8x16x16.v".
    std::string str() const override {
        std::ostringstream oss;
        oss << op;
        oss << ".";
        if (is_scattered()) oss << slots << "x";
        oss << type.str();
        if (is_2d()) oss << "." << block_2d_info.str();
        if (cache_hint != send_cache_hint_t::undef) oss << "." << cache_hint;
        return oss.str();
    }

    // Call argument accessors. mem_off and header_buf alias the same
    // argument slot (index 1). x and y are optional (defaulted to empty in
    // operator() below); presumably the 2D block coordinates — compare
    // header_2d_off_x()/header_2d_off_y().
    IR_DEFINE_ARG_GET(mem_buf, 0)
    IR_DEFINE_ARG_GET(mem_off, 1)
    IR_DEFINE_ARG_GET(header_buf, 1)
    IR_DEFINE_ARG_GET(reg_buf, 2)
    IR_DEFINE_ARG_GET(mask, 3)
    IR_DEFINE_ARG_GET(x, 4)
    IR_DEFINE_ARG_GET(y, 5)

    // Header offsets in bytes for 2D block messages.
    static int header_2d_off_base() { return 0; }
    static int header_2d_off_surface_width() { return 8; }
    static int header_2d_off_surface_height() { return 12; }
    static int header_2d_off_surface_pitch() { return 16; }
    static int header_2d_off_x() { return 20; }
    static int header_2d_off_y() { return 24; }
    static int header_2d_off_whc() { return 28; }

    // Builds an IR call statement for this send with the given arguments.
    stmt_t operator()(const expr_t &mem_buf, const expr_t &mem_off,
            const expr_t &reg_buf, const expr_t &mask,
            const expr_t &x = expr_t(), const expr_t &y = expr_t()) const {
        return call({mem_buf, mem_off, reg_buf, mask, x, y});
    }

    bool is_atomic() const { return op == send_op_t::atomic_fadd; }
    bool is_load() const { return op == send_op_t::load; }
    bool is_load_2d() const { return op == send_op_t::load_2d; }
    bool is_prefetch() const { return op == send_op_t::prefetch; }
    bool is_prefetch_2d() const { return op == send_op_t::prefetch_2d; }
    bool is_store() const { return op == send_op_t::store; }
    bool is_store_2d() const { return op == send_op_t::store_2d; }
    bool is_2d() const {
        return is_load_2d() || is_store_2d() || is_prefetch_2d();
    }
    bool is_a64() const { return address == send_address_t::a64; }
    bool is_bts() const { return address == send_address_t::bts; }
    bool is_slm() const { return address == send_address_t::slm; }

    // Block messages operate on oword/hword chunks; any other non-2D
    // message is treated as scattered.
    bool is_block() const {
        return utils::one_of(
                type.kind(), type_kind_t::oword, type_kind_t::hword);
    }

    bool is_scattered() const { return !is_block() && !is_2d(); }

    // Size of memory (global memory or SLM) to access.
    int access_size() const {
        if (is_2d()) {
            auto &info = block_2d_info;
            return type.size() * info.width * info.height * info.count;
        }
        return type.size() * slots;
    }

    // Per-element stride in the GRF payload; byte elements occupy 4 bytes
    // each (only defined for non-2D messages).
    int payload_type_stride() const {
        ir_assert(!is_2d());
        if (type.kind() == type_kind_t::byte) return 4;
        return type.size();
    }

    // Full size of payload GRF buffer for this message. Buffer may be strided
    // and/or require GRF boundary round-up.
    int payload_size() const {
        if (is_2d()) return utils::rnd_up(access_size(), grf_size());
        int sz = payload_type_stride() * slots;
        return utils::rnd_up(sz, grf_size());
    }

    // Required address alignment in bytes: 128 for 2D block messages, the
    // scalar type size for block messages, no restriction otherwise.
    int alignment() const {
        if (is_2d()) return 128;
        if (is_block()) return type.scalar().size();
        return 1;
    }

    // Number of bytes covered by a single mask element.
    int mask_size() const {
        if (is_2d()) return access_size();
        if (is_block()) {
            // LSC messages use SIMT1 execution mask (one mask per message).
            if (is_lsc) return type.size();
            return 4;
        }

        if (is_scattered()) return type.size();

        ir_error_not_expected();
        return 0;
    }

    // Number of mask elements, capped at 16 (beyond that, block messages
    // reuse masks round-robin).
    int nmasks() const {
        if (is_2d()) return 1;
        int masks = ir_utils::safe_divide(type.size() * slots, mask_size());
        if (masks > 16) {
            ir_assert(is_block())
                    << "Round-robin masking applies to block messages only.";
            ir_assert(masks % 16 == 0);
            masks = 16;
        }
        return masks;
    }

    // Size of a single address element in bytes (8 for A64, 4 otherwise).
    int address_size() const { return is_a64() ? 8 : 4; }

    type_t address_type(bool is_signed = false, int elems = 1) const {
        int bits = address_size() * 8;
        return is_signed ? type_t::s(bits, elems) : type_t::u(bits, elems);
    }

    // Size of header in bytes.
    int header_size() const {
        if (is_2d()) return grf_size();
        return utils::rnd_up(address_size() * slots, grf_size());
    }

    // Generates a statement to store (and maybe convert) the offset to the
    // message header according to the message description.
    stmt_t create_offset_store(const expr_t &header_buf, const expr_t &mem_buf,
            const expr_t &mem_off, bool is_signed_offset = false) const;

    bool is_supported() const;

    // Returns all supported send variants for the given operation/address
    // model, to be matched against the access pattern.
    static std::vector<func_t> get_all(ngen::HW hw, send_op_t op,
            send_address_t address, const type_t &mem_type,
            send_cache_hint_t cache_hint);

    ngen::HW hw;
    send_op_t op;
    send_address_t address;
    type_t type; // Data type of one element/chunk.
    int slots; // Number of SIMD slots (1 for 2D messages).
    bool is_lsc; // Whether this is an LSC (vs legacy) message.

    block_2d_info_t block_2d_info; // Empty unless is_2d().
    send_cache_hint_t cache_hint;

private:
    int grf_size() const { return ngen::GRF::bytes(hw); }

    bool is_xe_hp_plus() const { return hw >= ngen::HW::XeHP; }

    bool is_xe_hpc_plus() const { return hw >= ngen::HW::XeHPC; }

    // Regular (non-2D) send constructor; use make().
    send_t(ngen::HW hw, send_op_t op, send_address_t address,
            const type_t &type, int slots, bool is_lsc,
            send_cache_hint_t cache_hint)
        : func_impl_t(_type_info())
        , hw(hw)
        , op(op)
        , address(address)
        , type(type)
        , slots(slots)
        , is_lsc(is_lsc)
        , cache_hint(cache_hint) {}

    // 2D block send constructor; use make_2d(). Always A64 + LSC, one slot.
    send_t(ngen::HW hw, send_op_t op, const type_t &type,
            const block_2d_info_t &block_2d_info, send_cache_hint_t cache_hint)
        : func_impl_t(_type_info())
        , hw(hw)
        , op(op)
        , address(send_address_t::a64)
        , type(type)
        , slots(1)
        , is_lsc(true)
        , block_2d_info(block_2d_info)
        , cache_hint(cache_hint) {
        ir_assert(utils::one_of(op, send_op_t::load_2d, send_op_t::store_2d,
                send_op_t::prefetch_2d));
        if (is_store_2d()) {
            // VNNI/transpose are load-only transforms.
            ir_assert(!block_2d_info.vnni);
            ir_assert(!block_2d_info.transpose);
        }
    }
};
346
347ngen::CacheSettingsLSC get_cache_settings(const send_t &send);
348
349class memory_walker_t;
350class layout_walker_t;
351
// Hint describing a candidate 2D block access, produced by get_send_hint()
// and consumed by access_builder_t.
struct send_2d_hint_t {
    type_t type; // Data type to use for the 2D messages.
    bool enable = false; // Whether 2D messages should be attempted at all.
    bool vnni = false; // Request the VNNI transform.
    bool transpose = false; // Request transposed access.
    int vnni_permute_factor = 0;
    int width = 0; // Suggested 2D block width.
    int height = 0; // Suggested 2D block height.
};
361
362struct send_hint_t {
363 send_op_t convert(const send_op_t &op) const {
364 if (hint_2d.enable) {
365 if (op == send_op_t::load) return send_op_t::load_2d;
366 if (op == send_op_t::store) return send_op_t::store_2d;
367 if (op == send_op_t::prefetch) return send_op_t::prefetch_2d;
368 }
369 return op;
370 }
371
372 send_2d_hint_t hint_2d;
373};
374
// Generates loads or stores to move data between memory (global or SLM) and
// GRF. Memory view is a parameter. GRF payload layout is deduced
// automatically, according to the decomposition into messages.
class access_builder_t {
public:
    // ir_ctx          - IR context (provides the HW config).
    // mem_view        - view of the memory tensor to access.
    // mem_buf         - memory base buffer expression.
    // reg_buf         - register buffer expression.
    // send_op         - operation kind (load/store/prefetch/...).
    // send_address    - address model (a64/bts/slm).
    // send_cache_hint - cache hint attached to generated messages.
    // send_hint       - 2D access hint; non-const, may be updated while
    //                   building (see try_build_2d()).
    access_builder_t(ir_context_t &ir_ctx, const view_t &mem_view,
            const expr_t &mem_buf, const expr_t &reg_buf, send_op_t send_op,
            send_address_t send_address, send_cache_hint_t send_cache_hint,
            send_hint_t &send_hint);
    access_builder_t(access_builder_t &&);
    ~access_builder_t();

    // Deduced GRF payload layout.
    const layout_t &reg_layout() const { return reg_layout_; }
    // Register buffer size, rounded up to a whole number of GRFs.
    int reg_buf_size() const {
        return utils::rnd_up(reg_layout_.size(), grf_size());
    }
    // Generated access statement.
    const stmt_t &stmt() const { return stmt_; }

    std::string str() const {
        std::ostringstream oss;
        oss << "Memory view: " << mem_view_ << std::endl;
        oss << "Register layout: " << reg_layout_ << std::endl;
        oss << "Register buffer: " << reg_buf_ << std::endl;
        oss << "Register buffer size: " << reg_buf_size() << " ("
            << reg_buf_size() / grf_size() << " regs)" << std::endl;
        oss << "Statement: " << std::endl << stmt_;
        return oss.str();
    }

private:
    void build();
    // Attempts to decompose the access using the given payload layout.
    bool try_build(const layout_t &try_layout);
    // Attempts to use 2D block messages as described by the hint.
    bool try_build_2d(send_hint_t &send_hint);
    bool fixup_send_2d_params(const type_t &send_type, bool vnni,
            bool transpose, bool use_xy, int &W, int &H, int &P, int &w, int &h,
            int &c, int &vnni_permute_factor);

    bool check_2d_mask(const tensor_t &tile, bool use_virtual_surface,
            int w_idx, int h_idx, expr_t &mask) const;

    std::vector<layout_t> candidate_payload_layouts() const;
    stmt_t create_send_stmt(const send_t &send);
    int grf_size() const { return ngen::GRF::bytes(ir_ctx_->hw_cfg().hw()); }

    ir_context_t *ir_ctx_ = nullptr; // Non-owning; set by the constructor.
    view_t mem_view_;
    expr_t mem_buf_;
    expr_t reg_buf_;
    send_op_t send_op_;
    send_address_t send_address_;
    send_cache_hint_t send_cache_hint_;

    type_t mem_type_;

    std::unique_ptr<memory_walker_t> mem_walker_;
    std::unique_ptr<layout_walker_t> reg_layout_walker_;

    layout_t reg_layout_; // Result: deduced payload layout.
    stmt_t stmt_; // Result: generated access statement.
};
435
436inline access_builder_t make_access_builder(ir_context_t &ir_ctx,
437 const view_t &mem_view, const expr_t &mem_buf, const expr_t &reg_buf,
438 send_op_t send_op, send_address_t send_address,
439 send_hint_t &send_hint) {
440 return access_builder_t(ir_ctx, mem_view, mem_buf, reg_buf, send_op,
441 send_address, send_cache_hint_t::undef, send_hint);
442}
443
444inline access_builder_t make_access_builder(ir_context_t &ir_ctx,
445 const view_t &mem_view, const expr_t &mem_buf, const expr_t &reg_buf,
446 send_op_t send_op, send_address_t send_address,
447 send_cache_hint_t send_cache_hint = send_cache_hint_t::undef) {
448 send_hint_t send_hint;
449 return access_builder_t(ir_ctx, mem_view, mem_buf, reg_buf, send_op,
450 send_address, send_cache_hint, send_hint);
451}
452
453send_hint_t get_send_hint(const exec_config_t &exec_cfg, send_op_t send_op,
454 fma_kind_t fma_kind, abc_kind_t abc_kind, const view_t &view,
455 const gemm_schedule_t &gemm_schedule, bool allow_2d = true);
456
457inline send_hint_t get_send_hint(const exec_config_t &exec_cfg,
458 send_op_t send_op, abc_kind_t abc_kind, const view_t &view,
459 const gemm_schedule_t &gemm_schedule, bool allow_2d = true) {
460 return get_send_hint(exec_cfg, send_op, fma_kind_t::unknown, abc_kind, view,
461 gemm_schedule, allow_2d);
462}
463
464} // namespace jit
465} // namespace gpu
466} // namespace impl
467} // namespace dnnl
468
469#endif
470