/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/conv/epilogue.hpp"

#include "gpu/jit/ir/message.hpp"
#include "gpu/jit/ir/mul_add.hpp"
#include "gpu/jit/ir/reduce.hpp"
#include "gpu/jit/ir/reorder.hpp"
#include "gpu/jit/utils/trace.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

// Zero pads a register buffer of f32 type.
class zero_pad_builder_t {
public:
    zero_pad_builder_t() = default;

    zero_pad_builder_t(ir_context_t &ir_ctx, const view_t &full_mem_view,
            const view_t &mem_view)
        : ir_ctx_(&ir_ctx)
        , full_mem_view_(full_mem_view)
        , mem_view_(mem_view) {}

    bool is_empty() const { return mem_view_.is_empty(); }

    expr_t create_mask(const layout_t &reg_layout, const tensor_t &tile) const {
        ir_assert(!is_empty());
        auto layout = reg_layout.map(tile);
        auto view = mem_view_.create_sub_view(tile);
        mask_tensor_t mask_tensor(layout);
        std::vector<dim_t> args(layout.ndims());
        fill_mask_impl(mask_tensor, 0, args, view, layout);
        mask_tensor.simplify(ir_ctx_->cset());
        return mask_tensor.to_expr(tile.elems());
    }

    stmt_t build_stmt(const layout_t &reg_layout, const expr_t &reg_buf) const {
        ir_assert(mem_view_.nvdims() == reg_layout.ndims())
                << "Incompatible view/layout.";
        int max_step = std::min(
                16, 2 * ir_ctx_->grf_size() / reg_layout.type().size());
        auto base_tile = reg_layout.split_into_max_tile(
                max_step, /*is_dense_tile=*/true);
        stmt_t stmt;
        reg_layout.for_each_tile(
                base_tile, [&](const std::vector<dim_t> &start) {
                    tensor_t tile(base_tile.dims(), start);
                    int off = reg_layout(start) * reg_layout.type().size();
                    auto mask = create_mask(reg_layout, tile);
                    auto zero = to_expr(0, reg_layout.type());
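                    // The negated mask selects the elements that fall outside
                    // the memory view, so zeros are written only to the
                    // padded part of the buffer.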
                    auto store = store_t::make(reg_buf, off,
                            shuffle_t::make_broadcast(zero, tile.elems()),
                            store_t::default_stride, -mask);
                    stmt = stmt.append(store);
                });
        return stmt;
    }

private:
    void fill_mask_impl(mask_tensor_t &mask_tensor, int idx,
            std::vector<dim_t> &args, const view_t &view,
            const layout_t &layout) const {
        if (idx == layout.ndims()) {
            std::vector<expr_t> vargs;
            for (int i = 0; i < layout.ndims(); i++)
                vargs.push_back(view.vstart(i) + args[i]);
            expr_t mask = full_mem_view_.vmask(vargs);
            auto off = layout.offset(args, /*ignore_offset=*/true);
            mask_tensor.set_mask(off, mask);
            return;
        }

        for (int i = 0; i < int(layout.dims()[idx]); i++) {
            args[idx] = i;
            fill_mask_impl(mask_tensor, idx + 1, args, view, layout);
        }
    }

    ir_context_t *ir_ctx_;

    view_t full_mem_view_;
    view_t mem_view_;

    stmt_t stmt_;
};

// Represents the state of a post-op tensor.
//
// There are three kinds of tensors:
// - C tensor converted to f32
//   - Never loaded or stored to global memory
// - Input tensor
//   - No store, only load
// - Output tensor
//   - No load, only store
//
// Post-op tensors that are both input and output are not expected/supported
// as they don't occur in convolution. Post-op tensors with global reduction
// (like lhs += rhs) are treated as output-only and handled via atomic stores.
//
// A post-op tensor optionally requires:
// - Conversion to f32 (post-ops are done in f32)
// - Reduction
//   - For output tensors with broadcast dimensions
// - Masking during post-ops
//   - When a post-op is not zero preserving
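//
// For example (illustrative): a convolution bias or a binary post-op source
// acts as an input tensor (loaded and, if needed, converted to f32), while a
// tensor computed by reduction (e.g. a fused bias reduction) acts as an
// output tensor and is stored with atomic adds.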
class post_op_tensor_t {
public:
    post_op_tensor_t(ir_context_t &ir_ctx, const post_op_tensor_info_t &info)
        : ir_ctx_(&ir_ctx), info_(info) {
        if (!mem_buf().is_empty()) {
            auto &type = mem_buf().type();
            if (!type.is_ptr()) {
                ir_assert(type.is_f32()) << "Expected f32: " << mem_buf();
                reg_buf_ = mem_buf();
                reg_layout_ = layout_t(
                        type, 0, std::vector<dim_t>(mem_view().nvdims(), 1));
            }
        }
    }

    const view_t &mem_view() const { return info_.view(); }

    const expr_t &mem_buf() const { return info_.buf(); }

    // Bitmask with broadcast information for the tensor:
    // - (mask() & (1 << idx)) == 0 -> idx is a broadcast dimension (equal to 1)
    // - (mask() & (1 << idx)) != 0 -> idx dimension matches the C dimension
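    // For example, a per-channel tensor (such as bias) sets only the bit of
    // the channel dimension and is broadcast along all other C dimensions.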
    uint32_t mask() const { return info_.mask(); }

    // Placeholder variable to represent the tensor in post-op expressions.
    const expr_t &op_var() const { return info_.op_var(); }

    const layout_t &reg_layout() const { return reg_layout_; }

    const expr_t &reg_buf() const { return reg_buf_; }

    post_op_tensor_t create_sub_tensor(const tensor_t &_tile) const {
        auto ret = *this;
        auto tile = apply_mask(_tile);
        ret.info_ = ret.info_.create_sub_tensor(tile);
        if (!reg_layout_.is_empty()) {
            if (needs_reduction()) {
                tensor_t reduce_tile(_tile.dims(), tile.start());
                ret.reg_layout_ = ret.reg_layout_.map(reduce_tile);
            } else {
                ret.reg_layout_ = ret.reg_layout_.map(tile);
            }
        }
        ret.allocs_.clear();
        return ret;
    }

    bool needs_load() const {
        if (!info_.is_input()) return false;
        if (!mem_buf().type().is_ptr()) return false;
        return true;
    }

    bool needs_store() const { return info_.is_output(); }

    bool needs_masked_update() const { return info_.needs_masked_update(); }

    bool needs_f32_convert() const {
        return !mem_view().type().is_f32() && !mem_view().type().is_f64();
    }

    bool needs_reduction() const {
        if (!info_.is_output()) return false;

        for (int i = 0; i < mem_view().nvdims(); i++) {
            if (is_broadcast_dim(i)) {
                if (reg_layout_.dims()[i] != 1) return true;
            }
        }
        return false;
    }

    bool is_broadcast_dim(int dim_idx) const {
        ir_assert(dim_idx >= 0 && dim_idx < mem_view().nvdims());
        return (mask() & (1 << dim_idx)) == 0;
    }

    int estimate_grf_consumption() const {
        int elems = int(mem_view().create_dense_vlayout().elems());

        int ret = 0;
        ret += elems * mem_view().type().size();
        if (needs_f32_convert()) ret += elems * type_t::f32().size();
        return ret;
    }

    void set_reg_layout(const layout_t &layout) { reg_layout_ = layout; }

    void set_reg_buf(const expr_t &buf) { reg_buf_ = buf; }

    void set_preload(bool value = true) { do_preload_ = value; }

    bool do_preload() const { return do_preload_; }

    tensor_t apply_mask(const tensor_t &tile) const {
        ir_assert(mem_view().nvdims() == tile.ndims());

        auto start = tile.start();
        auto dims = tile.dims();

        for (int i = 0; i < tile.ndims(); i++) {
            if (!is_broadcast_dim(i)) continue;
            start[i] = expr_t(0);
            dims[i] = 1;
        }
        return tensor_t(dims, start);
    }

    void init_output_buffer(const tensor_t &tile) {
        ir_assert(needs_store());

        ir_assert(reg_layout_.is_empty());
        ir_assert(reg_buf_.is_empty());

        reg_buf_ = make_tmp_reg_buffer();

        reg_layout_ = mem_view().create_dense_vlayout();
        reg_layout_ = reg_layout_.retype(type_t::f32());

        // If this is an output tensor and there are masked (broadcast)
        // dimensions then this buffer is computed via reduction. Extend the
        // layout to cover the full tile and apply the final reduction after
        // all tiles are handled.
        auto masked_tile = apply_mask(tile);
        for (int i = 0; i < masked_tile.ndims(); i++) {
            if (masked_tile(i) >= tile(i)) continue;
            ir_assert(masked_tile(i) == 1) << "Unexpected output tensor shape.";
            reg_layout_ = reg_layout_.add_outer_block(i, tile(i));
        }
        register_buffer(reg_buf_, reg_layout_.size());
    }

    stmt_t build_load_stmt(const view_t &c_view) {
        ir_assert(needs_load());
        ir_assert(reg_buf_.is_empty());

        reg_buf_ = make_tmp_reg_buffer();
        auto read = make_access_builder(*ir_ctx_, mem_view(), mem_buf(),
                reg_buf_, send_op_t::load, send_address_t::a64,
                get_cache_hint(c_view));
        reg_layout_ = read.reg_layout();
        register_buffer(reg_buf_, read.reg_buf_size());
        return read.stmt();
    }

    stmt_t build_prefetch_stmt(const view_t &c_view) const {
        ir_assert(needs_load());

        auto prefetch = make_access_builder(*ir_ctx_, mem_view(), mem_buf(),
                expr_t(), send_op_t::prefetch, send_address_t::a64,
                get_cache_hint(c_view));
        return prefetch.stmt();
    }

    stmt_t build_convert_stmt() {
        if (!needs_load() || !needs_f32_convert()) return stmt_t();

        auto f32_buf = make_tmp_reg_buffer();
        auto f32_layout = reg_layout_.retype(type_t::f32()).make_dense();

        register_buffer(f32_buf, f32_layout.size());

        // Reorder to f32.
        auto ret = create_reorder_stmt(
                reg_layout_, f32_layout, reg_buf_, f32_buf);

        // Assign new f32 layout and buffer.
        reg_layout_ = f32_layout;
        reg_buf_ = f32_buf;

        return ret;
    }

    stmt_t build_zero_out_stmt() const {
        ir_assert(needs_store());
        return create_zero_out_stmt(*ir_ctx_, reg_buf_, reg_layout_.size());
    }

    stmt_t build_reduce_stmt() {
        ir_assert(needs_store());

        stmt_t stmt;

        if (needs_reduction()) {
            auto reduced_layout = mem_view().create_dense_vlayout();
            ir_assert(reduced_layout.size() <= reg_layout_.size());

            stmt = stmt.append(
                    create_reduce_stmt(reg_layout_, reduced_layout, reg_buf_,
                            reg_buf_, tensor_t(), mask(), /*drop_dims=*/false));
            reg_layout_ = reduced_layout;
        }

        // Apply optional scaling.
        stmt = stmt.append(create_mul_add_stmt(*ir_ctx_, reg_buf_,
                reg_layout_.size(), reg_layout_.type(), info_.scale(), 0));

        return stmt;
    }

    stmt_t build_slm_store_stmt(const grid_info_t &tg_grid) {
        ir_assert(needs_store());
        tensor_t tile(mem_view().vdims());
        slm_reduce_builder_ = slm_reduce_builder_t(
                *ir_ctx_, tg_grid, reg_buf_, reg_layout_, tile, 1);
        return slm_reduce_builder_.store_stmt();
    }

    stmt_t build_slm_load_stmt() {
        ir_assert(needs_store());
        ir_assert(!slm_reduce_builder_.is_empty());

        reg_layout_ = slm_reduce_builder_.reg_layout();

        auto new_tile = slm_reduce_builder_.thr_tile();
        info_ = info_.create_sub_tensor(new_tile);

        auto &slm_allocs = slm_reduce_builder_.allocs();
        allocs_.insert(allocs_.end(), slm_allocs.begin(), slm_allocs.end());

        return slm_reduce_builder_.load_stmt();
    }

    stmt_t build_store_stmt() const {
        ir_assert(needs_store());

        auto write = make_access_builder(*ir_ctx_, mem_view(), mem_buf(),
                reg_buf(), send_op_t::atomic_fadd, send_address_t::a64);
        ir_assert(write.reg_layout() == reg_layout());

        return write.stmt();
    }

    expr_t load_expr(const tensor_t &tile, int dim_idx) const {
        auto &type = reg_layout_.type();
        int elems = is_broadcast_dim(dim_idx) ? 1 : tile.elems();
        int off = reg_layout_.offset_in_bytes(expr_cast<dim_t>(tile.start()));
        auto ret = (reg_buf_.type().is_ptr()
                        ? load_t::make(type.with_elems(elems), reg_buf_, off)
                        : reg_buf_);
        if (elems != tile.elems())
            ret = shuffle_t::make_broadcast(ret, tile.elems());
        return ret;
    }

    stmt_t store_stmt(const tensor_t &tile, int dim_idx, const expr_t &_value,
            const expr_t &mask = expr_t()) const {
        auto value = _value;
        ir_assert(!is_broadcast_dim(dim_idx));
        ir_assert(value.type().elems() == tile.elems());
        // Add cast for booleans for comparison ops.
        if (value.type().is_bool()) {
            value = cast(value, reg_layout_.type().with_elems(tile.elems()));
        }
        int off = reg_layout_.offset_in_bytes(expr_cast<dim_t>(tile.start()));
        auto ret = store_t::make(
                reg_buf_, off, value, store_t::default_stride, mask);
        return ret;
    }

    const std::vector<stmt_t> &allocs() const { return allocs_; }

private:
    expr_t make_tmp_reg_buffer() {
        auto *var = mem_buf().as_ptr<var_t>();
        if (!var) {
            auto *ptr = mem_buf().as_ptr<ptr_t>();
            if (ptr) var = ptr->base.as_ptr<var_t>();
        }
        ir_assert(var) << "Can't extract variable from buffer: " << mem_buf();
        auto &name = var->name;
        return ir_ctx_->create_tmp_var(type_t::byte_ptr(), "tmp_" + name);
    }

    void register_buffer(const expr_t &buf, int size) {
        for (auto &_a : allocs_) {
            auto &a = _a.as<alloc_t>();
            if (a.buf.is_same(buf)) {
                if (size > a.size) {
                    _a = alloc_t::make(a.buf, size, a.kind, a.attrs);
                }
                return;
            }
        }
        allocs_.push_back(alloc_t::make(buf, size, alloc_kind_t::grf));
    }

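    // The load-once cache hint is used when each loaded element maps to
    // exactly one element of the C tile (no effective broadcast), so cached
    // lines would not be reused.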
    send_cache_hint_t get_cache_hint(const view_t &c_view) const {
        ir_assert(mem_view().nvdims() == c_view.nvdims());
        bool per_tensor = true;
        for (int i = 0; i < mem_view().nvdims(); i++) {
            if ((mask() & (1 << i)) != 0) continue;
            if (c_view.vdims()[i] == 1) continue;
            per_tensor = false;
            break;
        }
        if (per_tensor) return send_cache_hint_t::load_once;
        return send_cache_hint_t::undef;
    }

    ir_context_t *ir_ctx_ = nullptr;

    post_op_tensor_info_t info_;

    layout_t reg_layout_;
    expr_t reg_buf_;

    bool do_preload_ = false;

    std::vector<stmt_t> allocs_;

    slm_reduce_builder_t slm_reduce_builder_;
};

// Applies substitutions and broadcasts to generate the final post-op
// expression.
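// For example (illustrative), for an 8-element tile the placeholder variable
// of each post-op tensor is substituted with its register load, and scalar
// immediates and 1-element loads are broadcast to 8 elements so that the
// resulting expression is fully vectorized.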
class post_op_bcast_mutator_t : public ir_mutator_t {
public:
    post_op_bcast_mutator_t(
            int elems, const object_map_t<object_t, object_t> &from2to)
        : elems_(elems), from2to_(from2to) {}

    object_t _mutate(const float_imm_t &obj) override {
        return make_bcast(obj);
    }

    object_t _mutate(const int_imm_t &obj) override {
        return make_bcast(float_imm_t::make(obj.value));
    }

    object_t _mutate(const var_t &obj) override {
        auto it = from2to_.find(obj);
        if (it != from2to_.end()) return make_bcast(it->second);

        ir_error_not_expected() << "Unknown variable.";
        return obj;
    }

private:
    object_t make_bcast(const expr_t &e) const {
        if (e.type().elems() == elems_) return e;
        ir_assert(e.type().elems() == 1);
        return shuffle_t::make_broadcast(e, elems_);
    }

    int elems_;
    object_map_t<object_t, object_t> from2to_;
};

// Builds statements to apply a post-op for a given tile.
class post_op_builder_t {
public:
    post_op_builder_t(ngen::HW hw, const post_op_t &post_op)
        : hw_(hw), post_op_(post_op) {}

    const post_op_t &post_op() const { return post_op_; }

    // Applies post-op for a single tile.
    stmt_t build_tile_stmt(const tensor_t &tile,
            const object_map_t<expr_t, post_op_tensor_t *> &args,
            const zero_pad_builder_t &zero_pad_builder) const {
        auto &lhs_tensor = *args.at(post_op_.lhs());
        if (!post_op_.eltwise().is_empty()) {
            // Apply eltwise post-op.
            ir_assert(post_op_.lhs().is_equal(post_op_.rhs()))
                    << "Only supported form is lhs = eltwise(lhs).";
            int lhs_size = lhs_tensor.reg_layout().size();
            int lhs_elems = lhs_size / int(sizeof(float));
            return post_op_.eltwise().call(
                    {expr_t(lhs_elems), lhs_tensor.reg_buf()});
        }

        int inner_dim_idx = -1;
        auto base_inner_tile = find_1d_tile(
                lhs_tensor.reg_layout().type(), args, inner_dim_idx);
        auto inner_layout = lhs_tensor.reg_layout().map(base_inner_tile);
        ir_assert(inner_dim_idx != -1);

        // All post-ops are performed in f32 except f64 bias.
        for (auto &kv : args) {
            ir_assert(kv.second->reg_layout().type().is_f32()
                    || kv.second->reg_layout().type().is_f64());
        }

        // Handle one inner tile at a time. Inner tile covers a single block
        // within a single dimension.
        stmt_t stmt;
        lhs_tensor.reg_layout().for_each_tile(
                base_inner_tile, [&](const std::vector<dim_t> &lhs_start) {
                    tensor_t inner_tile(base_inner_tile.dims(), lhs_start);
                    auto rhs_value = compute_post_op_expr(
                            post_op_.rhs(), inner_tile, inner_dim_idx, args);
                    auto &t = *args.at(post_op_.lhs());
                    expr_t store_mask;
                    if (lhs_tensor.needs_masked_update()) {
                        store_mask = zero_pad_builder.create_mask(
                                inner_layout, inner_tile);
                    }
                    auto inner_stmt = t.store_stmt(
                            inner_tile, inner_dim_idx, rhs_value, store_mask);
                    stmt = stmt.append(inner_stmt);
                });

        return stmt;
    }

private:
    // Returns a 1D tile corresponding to a single instruction used to
    // partially apply the post-op.
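    // The tile is derived from the innermost dense blocks of the lhs and of
    // all non-broadcast argument layouts (via GCD), taking into account the
    // maximum number of elements a single instruction can process.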
    tensor_t find_1d_tile(const type_t &lhs_type,
            const object_map_t<expr_t, post_op_tensor_t *> &args,
            int &inner_dim_idx) const {
        auto &lhs_tensor = *args.at(post_op_.lhs());

        int ndims = lhs_tensor.mem_view().nvdims();

        ir_assert(!lhs_tensor.reg_layout().is_empty());
        ir_assert(!lhs_tensor.reg_layout().blocks().empty());
        auto &b0 = lhs_tensor.reg_layout().blocks()[0];
        ir_assert(dim_t(b0.stride) == 1);
        inner_dim_idx = b0.dim_idx;

        int inner_block = b0.block;
        int max_step = 2 * ngen::GRF::bytes(hw_) / lhs_type.size();
        inner_block = std::max(8, math::gcd(inner_block, max_step));

        for (auto &kv : args) {
            auto &t = *kv.second;
            if (t.is_broadcast_dim(b0.dim_idx)) continue;

            auto &l = t.reg_layout();
            ir_assert(!l.is_empty());
            ir_assert(!l.blocks().empty());
            auto &lb0 = l.blocks()[0];
            ir_assert(lb0.dim_idx == b0.dim_idx);
            ir_assert(dim_t(lb0.stride) == 1);
            inner_block = math::gcd(int(lb0.block), inner_block);
        }

        std::vector<dim_t> dims(ndims, 1);
        dims[b0.dim_idx] = inner_block;

        return tensor_t(dims);
    }

    expr_t compute_post_op_expr(const expr_t &expr, const tensor_t &tile,
            int dim_idx,
            const object_map_t<expr_t, post_op_tensor_t *> &args) const {
        object_map_t<object_t, object_t> sub_map;
        for (auto &kv : args) {
            auto &t = *kv.second;
            auto te = t.load_expr(tile, dim_idx);
            sub_map.insert({t.op_var(), te});
        }
        post_op_bcast_mutator_t bcast_mutator(tile.elems(), sub_map);
        return bcast_mutator.mutate(expr);
    }

    ngen::HW hw_;
    post_op_t post_op_;
};

// Epilogue consists of the following steps after the main computation
// (C += A * B):
// - C GRF reorder to match the memory layout for global memory store
// - C conversion to f32 (if there are post-ops)
// - Applying post-ops (if there are any)
// - C conversion to the memory layout data type
// - C store to global memory
// - Reduction and storing output post-op tensors
//
// In general, the C tensor is updated/transformed following the C stages
// described below. Each C stage is associated with a GRF buffer and its
// layout.
// Multiplication ->
//     M_x -> [R_f32] -> [P0_f32] -> ... -> [Pn_f32] -> [Z_f32] -> S_y ->
//     GMEM
//
// Where:
// - x is data type after multiplication
// - y is destination data type
// - M_x is the stage after multiplication
// - R_f32 is the stage after reordering from M_x to f32 (optional)
// - Pi_f32 is the stage after applying Pi post-op (optional)
// - Z_f32 is the stage after restoring zero padding (optional)
// - S_y is the stage before storing C to global memory
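//
// For example (illustrative), an int8 convolution with one post-op, zero
// padding restore and an s8 destination goes through
//     M_s32 -> R_f32 -> P0_f32 -> Z_f32 -> S_s8 -> GMEM
// while an f32 convolution without post-ops reduces to M_f32 -> S_f32 -> GMEM.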
class epilogue_builder_t {
public:
    epilogue_builder_t(ir_context_t &ir_ctx, const conv_config_t &cfg,
            const gemm_schedule_t &gemm_schedule,
            const post_op_context_t &post_op_ctx, const tensor_t &thr_tile,
            const view_t &c_mem_view, const layout_t &c_reg_layout,
            const expr_t &c_mem_buf, const expr_t &c_reg_buf, int tile_size,
            int preload_max_size, int post_op_blk)
        : ir_ctx_(ir_ctx)
        , cfg_(cfg)
        , gemm_schedule_(gemm_schedule)
        , post_op_ctx_(post_op_ctx)
        , c_mem_view_(c_mem_view)
        , c_mem_buf_(c_mem_buf)
        , tg_grid_(gemm_schedule.tg_grid())
        , tile_size_(tile_size)
        , preload_max_size_(preload_max_size)
        , post_op_blk_(post_op_blk) {

        int tensor_idx = 0;
        for (auto &po_tensor_info : post_op_ctx_.post_op_tensor_infos()) {
            post_op_tensor_t po_tensor(ir_ctx_, po_tensor_info);
            po_tensor = po_tensor.create_sub_tensor(thr_tile);
            if (po_tensor_info.buf().is_empty()) {
                // C tensor.
                ir_assert(c_po_idx_ == -1);
                c_po_idx_ = tensor_idx;
            }
            post_op_tensors_.push_back(po_tensor);
            tensor_idx++;
        }

        restore_zero_padding_ = post_op_ctx_.need_to_restore_zero_padding();

        for (auto &po : post_op_ctx_.post_ops()) {
            post_op_builders_.emplace_back(ir_ctx_.hw(), po);
        }

        // Estimate the buffer sizes required to load the full tensors; do not
        // preload a tensor if it requires too much GRF memory.
        int available_size = preload_max_size_;
        for (auto &t : post_op_tensors_) {
            if (!t.needs_load()) continue;
            int required_size = t.estimate_grf_consumption();
            if (required_size > available_size) continue;
            available_size -= required_size;
            t.set_preload();
        }

        build(c_reg_layout, c_reg_buf);
    }

    const stmt_t &stmt() const { return stmt_; }

private:
    void register_buffer(const expr_t &buf, int size) {
        buf_sizes_[buf] = std::max(buf_sizes_[buf], size);
    }

    expr_t make_c_tmp_buffer() const {
        return ir_ctx_.create_tmp_var(type_t::byte_ptr(), "c_tmp");
    }

    // Represents a GRF buffer and layout to store C tensor.
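    // Consecutive stages are linked via set_next(): if the layouts differ (or
    // a reorder is forced), a reorder statement is generated between the two
    // buffers; otherwise the next stage reuses the same GRF buffer.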
    struct c_stage_t {
        c_stage_t(const layout_t &layout, const expr_t &buf,
                const stmt_t &stmt = stmt_t())
            : layout(layout), buf(buf), stmt(stmt) {}

        void set_next(
                ir_context_t &ir_ctx, c_stage_t *next, bool force_reorder) {
            if (!next) return;
            bool do_reorder
                    = !layout.is_equal(next->layout, /*compare_offset=*/false);
            if (force_reorder) do_reorder = true;
            if (do_reorder) {
                ir_assert(stmt.is_empty());
                // Generate reorder between stages.
                stmt = create_reorder_stmt(
                        layout, next->layout, buf, next->buf);
            } else {
                // Reuse the same GRF buffer for the next stage.
                int this_off = to_cpp<int>(layout.offset_in_bytes());
                int next_off = to_cpp<int>(next->layout.offset_in_bytes());
                ir_assert(next_off == 0);
                MAYBE_UNUSED(next_off);
                next->set_buf(buf[this_off]);
            }
        }

        void set_buf(const expr_t &buf) {
            // Replace old buffer if there is an assigned statement.
            if (!stmt.is_empty()) { stmt = substitute(stmt, this->buf, buf); }
            this->buf = buf;
        }

        const expr_t &buf_base() const {
            if (buf.is<var_t>()) return buf;
            return buf.as<ptr_t>().base;
        }

        int buf_size() const {
            ir_assert(buf.is_same(buf_base()))
                    << "Size must be queried from another stage.";
            return int(layout.size());
        }

        void prepend_stmt(const stmt_t &stmt) {
            this->stmt = stmt.append(this->stmt);
        }

        layout_t layout;
        expr_t buf;
        stmt_t stmt; // Statement to emit after the stage.
    };

    void build(const layout_t &c_reg_layout, const expr_t &c_reg_buf) {
        auto tmp_type = (post_op_builders_.empty() ? c_mem_view_.type()
                                                   : type_t::f32());
        int tmp_buf_elems = tile_size_ / tmp_type.size();
        auto base_tile = c_mem_view_.split_into_max_tile(
                tmp_buf_elems, /*is_dense=*/false);

        // Generate preload statements.
        for (auto &t : post_op_tensors_) {
            if (!t.do_preload()) continue;
            stmt_ = stmt_.append(t.build_load_stmt(c_mem_view_));
        }

        // Generate prefetch statements.
        if (ir_ctx_.hw() >= ngen::HW::XeHPC) {
            for (auto &t : post_op_tensors_) {
                if (!t.needs_load()) continue;
                if (t.do_preload()) continue;
                stmt_ = stmt_.append(t.build_prefetch_stmt(c_mem_view_));
            }
        }

        // Generate f32 convert statements.
        for (auto &t : post_op_tensors_) {
            if (!t.do_preload()) continue;
            if (!t.needs_f32_convert()) continue;
            stmt_ = stmt_.append(t.build_convert_stmt());
        }

        // Initialize buffers for output post-op tensors.
        for (auto &t : post_op_tensors_) {
            if (!t.needs_store()) continue;
            t.init_output_buffer(base_tile);
        }

        // Generate zero-out statements for output post-op tensors.
        for (auto &t : post_op_tensors_) {
            if (!t.needs_store()) continue;
            stmt_ = stmt_.append(t.build_zero_out_stmt());
        }

        // Iterate by tiles and apply post-ops.
        c_mem_view_.for_each_tile(
                base_tile, [&](const std::vector<dim_t> &start) {
                    tensor_t tile(base_tile.dims(), start);
                    auto c_tile_layout = c_reg_layout.map(tile);
                    build_tile(tile, c_tile_layout, c_reg_buf);
                });

        // TODO: Generalize the condition. Iterate through output tensor masks
        // and ensure C is distributed accordingly in thread group.
        bool use_slm_reduction = (tg_grid_.dim(1) > 1);

        // Generate reduce and store statements for output post-op tensors.
        stmt_t thr_reduce_stmt;
        stmt_t slm_store_stmt;
        stmt_t slm_load_stmt;
        stmt_t mem_store_stmt;
        for (auto &t : post_op_tensors_) {
            if (!t.needs_store()) continue;

            thr_reduce_stmt = thr_reduce_stmt.append(t.build_reduce_stmt());
            if (use_slm_reduction) {
                auto store_stmt = t.build_slm_store_stmt(tg_grid_);
                auto load_stmt = t.build_slm_load_stmt();
                slm_store_stmt = slm_store_stmt.append(store_stmt);
                slm_load_stmt = slm_load_stmt.append(load_stmt);
            }
            mem_store_stmt = mem_store_stmt.append(t.build_store_stmt());
        }

        stmt_ = stmt_.append(thr_reduce_stmt);
        if (!slm_store_stmt.is_empty()) {
            stmt_ = stmt_.append(funcs::barrier());
            stmt_ = stmt_.append(slm_store_stmt);
            stmt_ = stmt_.append(funcs::barrier());
            stmt_ = stmt_.append(slm_load_stmt);
        }

        stmt_ = stmt_.append(mem_store_stmt);

        // Generate alloc statements for post-op tensors.
        std::vector<stmt_t> allocs;
        for (auto &t : post_op_tensors_) {
            auto t_allocs = t.allocs();
            allocs.insert(allocs.end(), t_allocs.begin(), t_allocs.end());
        }
        stmt_ = jit::inject_alloc_stmts(stmt_, allocs, /*put_innermost=*/true);
    }

    // Builds statements for a tile iterating through all post-ops.
    void build_tile(const tensor_t &tile, const layout_t &c_tile_layout,
            const expr_t &c_reg_buf) {
        auto c_mem_tile_view = c_mem_view_.create_sub_view(tile);
        auto tmp_reg_buf = make_c_tmp_buffer();

        type_t post_op_type
                = c_tile_layout.type().is_f64() ? type_t::f64() : type_t::f32();
        bool create_zero_pad_builder = restore_zero_padding_;
        for (auto &t : post_op_tensors_) {
            if (t.needs_masked_update()) {
                create_zero_pad_builder = true;
                break;
            }
        }
        if (create_zero_pad_builder) {
            zero_pad_builder_ = zero_pad_builder_t(
                    ir_ctx_, post_op_ctx_.cp_view(), c_mem_tile_view);
        }

        // S_y -> GMEM.
        auto send_op = gemm_schedule_.with_kernel_grid_k_slicing()
                ? send_op_t::atomic_fadd
                : send_op_t::store;
        auto send_hint = get_send_hint(ir_ctx_.exec_cfg(), send_op,
                abc_kind_t::c, c_mem_tile_view, gemm_schedule_);
        auto r2g = make_access_builder(ir_ctx_, c_mem_tile_view, c_mem_buf_,
                tmp_reg_buf, send_op, send_address_t::a64, send_hint);

        // Initialize C stages.
        std::vector<c_stage_t> c_stages;

        auto c_fx_layout = r2g.reg_layout().retype(post_op_type).make_dense();
        bool with_post_ops = !post_op_builders_.empty();
        int npost_ops = int(post_op_builders_.size());

        int c_f32_stage_idx = -1;
        int c_zero_pad_stage_idx = -1;

        c_stages.emplace_back(c_tile_layout, c_reg_buf); // M_x
        if (with_post_ops) {
            c_f32_stage_idx = int(c_stages.size());
            c_stages.emplace_back(c_fx_layout, make_c_tmp_buffer()); // R_f32
        }
        if (restore_zero_padding_) {
            c_zero_pad_stage_idx = int(c_stages.size());
            c_stages.emplace_back(c_fx_layout, make_c_tmp_buffer()); // Z_f32
        }
        c_stages.emplace_back(r2g.reg_layout(), tmp_reg_buf, r2g.stmt()); // S_y

        int nstages = int(c_stages.size());
        bool is_dpasw = (cfg_.fma_kind() == fma_kind_t::dpasw);

        // Generate reorders between C stages if needed.
        for (int i = 0; i < nstages; i++) {
            auto *next_stage = (i + 1 < nstages ? &c_stages[i + 1] : nullptr);
            // Always perform reorder when dpasw is used. This is to ensure
            // that C is properly restored and permuted after dpasw.
            c_stages[i].set_next(ir_ctx_, next_stage,
                    /*force_reorder=*/i == 0 && is_dpasw);
        }

        // Restore zero padding if needed.
        if (c_zero_pad_stage_idx != -1) {
            auto &s = c_stages[c_zero_pad_stage_idx];
            s.prepend_stmt(zero_pad_builder_.build_stmt(s.layout, s.buf));
        }

        // Create sub-tensors for post-ops.
        std::vector<post_op_tensor_t> sub_po_tensors;
        for (auto &t : post_op_tensors_)
            sub_po_tensors.push_back(t.create_sub_tensor(tile));

        // Set C tensor layout and buffer to use for post-ops.
        if (c_f32_stage_idx != -1) {
            auto &s = c_stages[c_f32_stage_idx];
            sub_po_tensors[c_po_idx_].set_reg_layout(s.layout);
            sub_po_tensors[c_po_idx_].set_reg_buf(s.buf);
        }

        stmt_t tile_stmt;

        // Add C stage statements and post-op statements.
        for (int i = 0; i < nstages; i++) {
            if (with_post_ops && i == c_f32_stage_idx) {
                // Emit post-ops in blocks to reduce GRF consumption.
                for (int j = 0; j < npost_ops; j += post_op_blk_) {
                    int k_beg = j;
                    int k_end = std::min(npost_ops, j + post_op_blk_);
                    auto blk_stmt = build_post_op_block_stmt(
                            tile, sub_po_tensors, k_beg, k_end);
                    tile_stmt = tile_stmt.append(blk_stmt);
                }
            }
            tile_stmt = tile_stmt.append(c_stages[i].stmt);
        }

        // Generate alloc statements for C stage buffers.
        object_set_t<expr_t> seen;
        for (int i = 0; i < nstages; i++) {
            auto &s = c_stages[i];
            auto &buf = s.buf_base();
            auto ret = seen.insert(buf);
            if (i == 0 || !ret.second) continue;
            int size = utils::rnd_up(s.buf_size(), ir_ctx_.grf_size());
            tile_stmt = alloc_t::make(buf, size, alloc_kind_t::grf, tile_stmt);
        }

        stmt_ = stmt_.append(tile_stmt);
    }

    stmt_t build_post_op_block_stmt(const tensor_t &tile,
            std::vector<post_op_tensor_t> &sub_po_tensors, int po_beg,
            int po_end) const {
        // Collect post-op inputs/outputs.
        object_map_t<expr_t, post_op_tensor_t *> args;
        for (int i = po_beg; i < po_end; i++) {
            auto &po_builder = post_op_builders_[i];
            for (auto &t : sub_po_tensors) {
                if (po_builder.post_op().uses(t.op_var())) {
                    args.insert({t.op_var(), &t});
                }
            }
        }

        // Generate load and convert statements for the post-op.
        stmt_t load_stmt;
        stmt_t convert_stmt;
        for (auto &kv : args) {
            auto &t = *kv.second;
            if (!t.needs_load()) continue;
            if (t.do_preload()) continue;
            load_stmt = load_stmt.append(t.build_load_stmt(c_mem_view_));
            if (t.needs_f32_convert()) {
                convert_stmt = convert_stmt.append(t.build_convert_stmt());
            }
        }

        stmt_t stmt;
        stmt = stmt.append(load_stmt);
        stmt = stmt.append(convert_stmt);

        for (int i = po_beg; i < po_end; i++) {
            auto &po_builder = post_op_builders_[i];
            auto po_stmt
                    = po_builder.build_tile_stmt(tile, args, zero_pad_builder_);
            stmt = stmt.append(po_stmt);
        }

        // Generate alloc statements for post-op tensors.
        std::vector<stmt_t> allocs;
        for (auto &kv : args) {
            auto &t = *kv.second;
            if (!t.needs_load()) continue;
            if (t.do_preload()) continue;
            auto t_allocs = t.allocs();
            allocs.insert(allocs.end(), t_allocs.begin(), t_allocs.end());
        }
        stmt = jit::inject_alloc_stmts(stmt, allocs);

        return stmt;
    }

    ir_context_t &ir_ctx_;
    const conv_config_t &cfg_;
    const gemm_schedule_t &gemm_schedule_;
    const post_op_context_t &post_op_ctx_;

    // C view in global memory.
    view_t c_mem_view_;
    expr_t c_mem_buf_;

    // C layout after the main loop.
    layout_t c_reg_layout_;
    expr_t c_reg_buf_;

    const grid_info_t &tg_grid_;

    bool restore_zero_padding_;

    zero_pad_builder_t zero_pad_builder_;

    // Tile size in bytes. The tile data type is:
    // - the destination data type without post-ops
    // - f32 with post-ops
    int tile_size_;
    int preload_max_size_;
    int post_op_blk_;

    std::vector<post_op_builder_t> post_op_builders_;
    std::vector<post_op_tensor_t> post_op_tensors_;
    int c_po_idx_ = -1;

    object_map_t<expr_t, int> buf_sizes_;

    stmt_t stmt_;
};

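// Estimates the GRF usage (load buffer plus an optional f32 conversion
// buffer) of one post-op tensor for a C tile of c_elems elements.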
int get_post_op_mem_usage(const post_op_tensor_info_t &info, int c_elems,
        const view_t &c_mem_view, int max_elems_per_dim = 64) {
    int po_elems = 1;
    for (int i = 0; i < info.view().nvdims(); i++) {
        if ((info.mask() & (1 << i)) == 0) continue;
        po_elems *= std::min(max_elems_per_dim, (int)c_mem_view.vdims()[i]);
    }
    po_elems = std::min(po_elems, c_elems);
    int type_size = info.view().type().size();
    int load_size = po_elems * type_size;
    int cvt_size = info.view().type().is_f32() ? 0 : po_elems * sizeof(float);
    return load_size + cvt_size;
}

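// Chooses the largest power-of-two tile size (in bytes, starting from 1024)
// whose estimated GRF footprint (C buffers, preloaded post-op tensors and the
// largest post-op block) fits within ~70% of the register space left after
// the C accumulator.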
int find_tile_size(const exec_config_t &exec_cfg,
        const post_op_context_t &post_op_ctx, const view_t &c_mem_view,
        const layout_t &c_reg_layout, int preload_max_size, int post_op_blk) {
    bool with_post_ops = !post_op_ctx.post_ops().empty();
    for (int tile_size = 1024; tile_size >= 1; tile_size /= 2) {
        int c_type_size = c_mem_view.type().size();
        int elems = tile_size / (with_post_ops ? sizeof(float) : c_type_size);
        int c_mul_size = elems * c_type_size;
        int c_f32_size = with_post_ops && !c_mem_view.type().is_f32()
                ? elems * sizeof(float)
                : 0;
        int c_size = c_mul_size + c_f32_size;
        int po_size = 0;

        auto &infos = post_op_ctx.post_op_tensor_infos();
        int npost_ops = int(infos.size());
        for (int i = 0; i < npost_ops; i += post_op_blk) {
            int po_batch_size = 0;
            for (int j = i; j < std::min(npost_ops, i + post_op_blk); j++) {
                auto &t = infos[j];
                if (!t.is_input() || !t.buf().type().is_ptr()) continue;
                po_batch_size += get_post_op_mem_usage(t, elems, c_mem_view);
            }
            po_size = std::max(po_size, po_batch_size);
        }

        int total_size = c_size + preload_max_size + po_size;
        int available_size = exec_cfg.regs() * exec_cfg.grf_size()
                - (int)c_reg_layout.size();
        if (total_size <= available_size * 0.7) return tile_size;
    }
    ir_error_not_expected();
    return -1;
}

stmt_t create_epilogue_stmt(const conv_config_t &cfg, ir_context_t &ir_ctx,
        const gemm_schedule_t &gemm_schedule,
        const post_op_context_t &post_op_ctx, const tensor_t &thr_tile,
        const view_t &c_mem_view, const layout_t &c_reg_layout,
        const expr_t &c_mem_buf, const expr_t &c_reg_buf) {
    // Max size of post-op tensor buffers to preload and reuse for all tiles.
    int preload_max_size = 512;
    // Block size to apply post-ops within a tile. A post-op may have
    // associated loads/conversions; a larger block size helps hide more of
    // their latency across multiple post-ops.
    int post_op_blk = 8;
    // Tile size in bytes. All post-ops are applied to a single tile, then to
    // the next tile, etc.
    int tile_size = find_tile_size(cfg.exec_cfg(), post_op_ctx, c_mem_view,
            c_reg_layout, preload_max_size, post_op_blk);

    ir_trace() << "Creating epilogue with parameters"
               << ": tile_size = " << tile_size
               << ", preload_max_size = " << preload_max_size
               << ", post_op_blk = " << post_op_blk << std::endl;
    epilogue_builder_t builder(ir_ctx, cfg, gemm_schedule, post_op_ctx,
            thr_tile, c_mem_view, c_reg_layout, c_mem_buf, c_reg_buf, tile_size,
            preload_max_size, post_op_blk);
    return builder.stmt();
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl