1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/conv/epilogue.hpp" |
18 | |
19 | #include "gpu/jit/ir/message.hpp" |
20 | #include "gpu/jit/ir/mul_add.hpp" |
21 | #include "gpu/jit/ir/reduce.hpp" |
22 | #include "gpu/jit/ir/reorder.hpp" |
23 | #include "gpu/jit/utils/trace.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace gpu { |
28 | namespace jit { |
29 | |
30 | // Zero pads a register buffer of f32 type. |
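// The per-element mask is taken from the full memory view (the view covering
// the complete destination tensor), so elements that fall into the zero-padded
// area are overwritten with zeroes; build_stmt() stores zero under the
// inverted mask.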
31 | class zero_pad_builder_t { |
32 | public: |
33 | zero_pad_builder_t() = default; |
34 | |
35 | zero_pad_builder_t(ir_context_t &ir_ctx, const view_t &full_mem_view, |
36 | const view_t &mem_view) |
37 | : ir_ctx_(&ir_ctx) |
38 | , full_mem_view_(full_mem_view) |
39 | , mem_view_(mem_view) {} |
40 | |
41 | bool is_empty() const { return mem_view_.is_empty(); } |
42 | |
43 | expr_t create_mask(const layout_t ®_layout, const tensor_t &tile) const { |
44 | ir_assert(!is_empty()); |
45 | auto layout = reg_layout.map(tile); |
46 | auto view = mem_view_.create_sub_view(tile); |
47 | mask_tensor_t mask_tensor(layout); |
48 | std::vector<dim_t> args(layout.ndims()); |
49 | fill_mask_impl(mask_tensor, 0, args, view, layout); |
50 | mask_tensor.simplify(ir_ctx_->cset()); |
51 | return mask_tensor.to_expr(tile.elems()); |
52 | } |
53 | |
54 | stmt_t build_stmt(const layout_t ®_layout, const expr_t ®_buf) const { |
55 | ir_assert(mem_view_.nvdims() == reg_layout.ndims()) |
56 | << "Incompatible view/layout."; |
57 | int max_step = std::min( |
58 | 16, 2 * ir_ctx_->grf_size() / reg_layout.type().size()); |
59 | auto base_tile = reg_layout.split_into_max_tile( |
60 | max_step, /*is_dense_tile=*/true); |
61 | stmt_t stmt; |
62 | reg_layout.for_each_tile( |
63 | base_tile, [&](const std::vector<dim_t> &start) { |
64 | tensor_t tile(base_tile.dims(), start); |
65 | int off = reg_layout(start) * reg_layout.type().size(); |
66 | auto mask = create_mask(reg_layout, tile); |
67 | auto zero = to_expr(0, reg_layout.type()); |
68 | auto store = store_t::make(reg_buf, off, |
69 | shuffle_t::make_broadcast(zero, tile.elems()), |
70 | store_t::default_stride, -mask); |
71 | stmt = stmt.append(store); |
72 | }); |
73 | return stmt; |
74 | } |
75 | |
76 | private: |
77 | void fill_mask_impl(mask_tensor_t &mask_tensor, int idx, |
78 | std::vector<dim_t> &args, const view_t &view, |
79 | const layout_t &layout) const { |
80 | if (idx == layout.ndims()) { |
81 | std::vector<expr_t> vargs; |
82 | for (int i = 0; i < layout.ndims(); i++) |
83 | vargs.push_back(view.vstart(i) + args[i]); |
84 | expr_t mask = full_mem_view_.vmask(vargs); |
85 | auto off = layout.offset(args, /*ignore_offset=*/true); |
86 | mask_tensor.set_mask(off, mask); |
87 | return; |
88 | } |
89 | |
90 | for (int i = 0; i < int(layout.dims()[idx]); i++) { |
91 | args[idx] = i; |
92 | fill_mask_impl(mask_tensor, idx + 1, args, view, layout); |
93 | } |
94 | } |
95 | |
96 | ir_context_t *ir_ctx_; |
97 | |
98 | view_t full_mem_view_; |
99 | view_t mem_view_; |
100 | |
101 | stmt_t stmt_; |
102 | }; |
103 | |
104 | // Represents the state of a post-op tensor. |
105 | // |
106 | // There are three kinds of tensors: |
107 | // - C tensor converted to f32 |
108 | // - Never loaded or stored to global memory |
109 | // - Input tensor |
110 | // - No store, only load |
111 | // - Output tensor |
112 | // - No load, only store |
113 | // |
114 | // Post-op tensors that are both input/output are not expected/supported as |
115 | // they don't occur in convolution. Post-op tensors with global reduction |
116 | // (like lhs += rhs) are treated as output-only and handled via atomic stores. |
117 | // |
118 | // A post-op tensor optionally requires: |
119 | // - Conversion to f32 (post-ops are done in f32) |
120 | // - Reduction |
121 | // - For output tensors with broadcast dimensions |
122 | // - Masking during post-ops |
123 | // - When a post-op is not zero preserving |
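// For example (illustrative): the extra source of a binary post-op is an
// input tensor (load only), while a tensor accumulated by the kernel itself,
// such as a bias computed via reduction, is an output tensor (zero-initialized,
// reduced per tile and stored with atomic adds at the end).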
124 | class post_op_tensor_t { |
125 | public: |
126 | post_op_tensor_t(ir_context_t &ir_ctx, const post_op_tensor_info_t &info) |
127 | : ir_ctx_(&ir_ctx), info_(info) { |
128 | if (!mem_buf().is_empty()) { |
129 | auto &type = mem_buf().type(); |
130 | if (!type.is_ptr()) { |
131 | ir_assert(type.is_f32()) << "Expected f32: " << mem_buf(); |
132 | reg_buf_ = mem_buf(); |
133 | reg_layout_ = layout_t( |
134 | type, 0, std::vector<dim_t>(mem_view().nvdims(), 1)); |
135 | } |
136 | } |
137 | } |
138 | |
139 | const view_t &mem_view() const { return info_.view(); } |
140 | |
141 | const expr_t &mem_buf() const { return info_.buf(); } |
142 | |
143 | // Bitmask with broadcast information for the tensor: |
144 | //   - (mask() & (1 << idx)) == 0 -> idx is a broadcast dimension (equal to 1) |
145 | // - (mask() & (1 << idx)) != 0 -> idx dimension matches the C dimension |
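// For example (illustrative): a per-OC tensor over (mb, oc, oh, ow) C
// dimensions, with oc at index 1, has mask() == 0x2: only the oc bit is set
// and all other dimensions are broadcast.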
146 | uint32_t mask() const { return info_.mask(); } |
147 | |
148 | // Placeholder variable to represent the tensor in post-op expressions. |
149 | const expr_t &op_var() const { return info_.op_var(); } |
150 | |
151 | const layout_t ®_layout() const { return reg_layout_; } |
152 | |
153 | const expr_t ®_buf() const { return reg_buf_; } |
154 | |
155 | post_op_tensor_t create_sub_tensor(const tensor_t &_tile) const { |
156 | auto ret = *this; |
157 | auto tile = apply_mask(_tile); |
158 | ret.info_ = ret.info_.create_sub_tensor(tile); |
159 | if (!reg_layout_.is_empty()) { |
160 | if (needs_reduction()) { |
161 | tensor_t reduce_tile(_tile.dims(), tile.start()); |
162 | ret.reg_layout_ = ret.reg_layout_.map(reduce_tile); |
163 | } else { |
164 | ret.reg_layout_ = ret.reg_layout_.map(tile); |
165 | } |
166 | } |
167 | ret.allocs_.clear(); |
168 | return ret; |
169 | } |
170 | |
171 | bool needs_load() const { |
172 | if (!info_.is_input()) return false; |
173 | if (!mem_buf().type().is_ptr()) return false; |
174 | return true; |
175 | } |
176 | |
177 | bool needs_store() const { return info_.is_output(); } |
178 | |
179 | bool needs_masked_update() const { return info_.needs_masked_update(); } |
180 | |
181 | bool needs_f32_convert() const { |
182 | return !mem_view().type().is_f32() && !mem_view().type().is_f64(); |
183 | } |
184 | |
185 | bool needs_reduction() const { |
186 | if (!info_.is_output()) return false; |
187 | |
188 | for (int i = 0; i < mem_view().nvdims(); i++) { |
189 | if (is_broadcast_dim(i)) { |
190 | if (reg_layout_.dims()[i] != 1) return true; |
191 | } |
192 | } |
193 | return false; |
194 | } |
195 | |
196 | bool is_broadcast_dim(int dim_idx) const { |
197 | ir_assert(dim_idx >= 0 && dim_idx < mem_view().nvdims()); |
198 | return (mask() & (1 << dim_idx)) == 0; |
199 | } |
200 | |
201 | int estimate_grf_consumption() const { |
202 | int elems = int(mem_view().create_dense_vlayout().elems()); |
203 | |
204 | int ret = 0; |
205 | ret += elems * mem_view().type().size(); |
206 | if (needs_f32_convert()) ret += elems * type_t::f32().size(); |
207 | return ret; |
208 | } |
209 | |
210 | void set_reg_layout(const layout_t &layout) { reg_layout_ = layout; } |
211 | |
212 | void set_reg_buf(const expr_t &buf) { reg_buf_ = buf; } |
213 | |
214 | void set_preload(bool value = true) { do_preload_ = value; } |
215 | |
216 | bool do_preload() const { return do_preload_; } |
217 | |
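// Collapses broadcast dimensions of the tile to size 1 (with start 0),
// keeping the remaining dimensions intact. For example (illustrative): an
// (8, 16) tile for a tensor that broadcasts dimension 1 becomes an (8, 1) tile.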
218 | tensor_t apply_mask(const tensor_t &tile) const { |
219 | ir_assert(mem_view().nvdims() == tile.ndims()); |
220 | |
221 | auto start = tile.start(); |
222 | auto dims = tile.dims(); |
223 | |
224 | for (int i = 0; i < tile.ndims(); i++) { |
225 | if (!is_broadcast_dim(i)) continue; |
226 | start[i] = expr_t(0); |
227 | dims[i] = 1; |
228 | } |
229 | return tensor_t(dims, start); |
230 | } |
231 | |
232 | void init_output_buffer(const tensor_t &tile) { |
233 | ir_assert(needs_store()); |
234 | |
235 | ir_assert(reg_layout_.is_empty()); |
236 | ir_assert(reg_buf_.is_empty()); |
237 | |
238 | reg_buf_ = make_tmp_reg_buffer(); |
239 | |
240 | reg_layout_ = mem_view().create_dense_vlayout(); |
241 | reg_layout_ = reg_layout_.retype(type_t::f32()); |
242 | |
243 | // If this is output and there are masked dimensions then this buffer |
244 | // is computed via reduction. Extend layout to cover full masked_tile |
245 | // and apply the final reduction after all tiles. |
246 | auto masked_tile = apply_mask(tile); |
247 | for (int i = 0; i < masked_tile.ndims(); i++) { |
248 | if (masked_tile(i) >= tile(i)) continue; |
249 | ir_assert(masked_tile(i) == 1) << "Unexpected output tensor shape."; |
250 | reg_layout_ = reg_layout_.add_outer_block(i, tile(i)); |
251 | } |
252 | register_buffer(reg_buf_, reg_layout_.size()); |
253 | } |
254 | |
255 | stmt_t build_load_stmt(const view_t &c_view) { |
256 | ir_assert(needs_load()); |
257 | ir_assert(reg_buf_.is_empty()); |
258 | |
259 | reg_buf_ = make_tmp_reg_buffer(); |
260 | auto read = make_access_builder(*ir_ctx_, mem_view(), mem_buf(), |
261 | reg_buf_, send_op_t::load, send_address_t::a64, |
262 | get_cache_hint(c_view)); |
263 | reg_layout_ = read.reg_layout(); |
264 | register_buffer(reg_buf_, read.reg_buf_size()); |
265 | return read.stmt(); |
266 | } |
267 | |
268 | stmt_t build_prefetch_stmt(const view_t &c_view) const { |
269 | ir_assert(needs_load()); |
270 | |
271 | auto prefetch = make_access_builder(*ir_ctx_, mem_view(), mem_buf(), |
272 | expr_t(), send_op_t::prefetch, send_address_t::a64, |
273 | get_cache_hint(c_view)); |
274 | return prefetch.stmt(); |
275 | } |
276 | |
277 | stmt_t build_convert_stmt() { |
278 | if (!needs_load() || !needs_f32_convert()) return stmt_t(); |
279 | |
280 | auto f32_buf = make_tmp_reg_buffer(); |
281 | auto f32_layout = reg_layout_.retype(type_t::f32()).make_dense(); |
282 | |
283 | register_buffer(f32_buf, f32_layout.size()); |
284 | |
285 | // Reorder to f32. |
286 | auto ret = create_reorder_stmt( |
287 | reg_layout_, f32_layout, reg_buf_, f32_buf); |
288 | |
289 | // Assign new f32 layout and buffer. |
290 | reg_layout_ = f32_layout; |
291 | reg_buf_ = f32_buf; |
292 | |
293 | return ret; |
294 | } |
295 | |
296 | stmt_t build_zero_out_stmt() const { |
297 | ir_assert(needs_store()); |
298 | return create_zero_out_stmt(*ir_ctx_, reg_buf_, reg_layout_.size()); |
299 | } |
300 | |
301 | stmt_t build_reduce_stmt() { |
302 | ir_assert(needs_store()); |
303 | |
304 | stmt_t stmt; |
305 | |
306 | if (needs_reduction()) { |
307 | auto reduced_layout = mem_view().create_dense_vlayout(); |
308 | ir_assert(reduced_layout.size() <= reg_layout_.size()); |
309 | |
310 | stmt = stmt.append( |
311 | create_reduce_stmt(reg_layout_, reduced_layout, reg_buf_, |
312 | reg_buf_, tensor_t(), mask(), /*drop_dims=*/false)); |
313 | reg_layout_ = reduced_layout; |
314 | } |
315 | |
316 | // Apply optional scaling. |
317 | stmt = stmt.append(create_mul_add_stmt(*ir_ctx_, reg_buf_, |
318 | reg_layout_.size(), reg_layout_.type(), info_.scale(), 0)); |
319 | |
320 | return stmt; |
321 | } |
322 | |
323 | stmt_t build_slm_store_stmt(const grid_info_t &tg_grid) { |
324 | ir_assert(needs_store()); |
325 | tensor_t tile(mem_view().vdims()); |
326 | slm_reduce_builder_ = slm_reduce_builder_t( |
327 | *ir_ctx_, tg_grid, reg_buf_, reg_layout_, tile, 1); |
328 | return slm_reduce_builder_.store_stmt(); |
329 | } |
330 | |
331 | stmt_t build_slm_load_stmt() { |
332 | ir_assert(needs_store()); |
333 | ir_assert(!slm_reduce_builder_.is_empty()); |
334 | |
335 | reg_layout_ = slm_reduce_builder_.reg_layout(); |
336 | |
337 | auto new_tile = slm_reduce_builder_.thr_tile(); |
338 | info_ = info_.create_sub_tensor(new_tile); |
339 | |
340 | auto &slm_allocs = slm_reduce_builder_.allocs(); |
341 | allocs_.insert(allocs_.end(), slm_allocs.begin(), slm_allocs.end()); |
342 | |
343 | return slm_reduce_builder_.load_stmt(); |
344 | } |
345 | |
346 | stmt_t build_store_stmt() const { |
347 | ir_assert(needs_store()); |
348 | |
349 | auto write = make_access_builder(*ir_ctx_, mem_view(), mem_buf(), |
350 | reg_buf(), send_op_t::atomic_fadd, send_address_t::a64); |
351 | ir_assert(write.reg_layout() == reg_layout()); |
352 | |
353 | return write.stmt(); |
354 | } |
355 | |
356 | expr_t load_expr(const tensor_t &tile, int dim_idx) const { |
357 | auto &type = reg_layout_.type(); |
358 | int elems = is_broadcast_dim(dim_idx) ? 1 : tile.elems(); |
359 | int off = reg_layout_.offset_in_bytes(expr_cast<dim_t>(tile.start())); |
360 | auto ret = (reg_buf_.type().is_ptr() |
361 | ? load_t::make(type.with_elems(elems), reg_buf_, off) |
362 | : reg_buf_); |
363 | if (elems != tile.elems()) |
364 | ret = shuffle_t::make_broadcast(ret, tile.elems()); |
365 | return ret; |
366 | } |
367 | |
368 | stmt_t store_stmt(const tensor_t &tile, int dim_idx, const expr_t &_value, |
369 | const expr_t &mask = expr_t()) const { |
370 | auto value = _value; |
371 | ir_assert(!is_broadcast_dim(dim_idx)); |
372 | ir_assert(value.type().elems() == tile.elems()); |
373 | // Add cast for booleans for comparison ops. |
374 | if (value.type().is_bool()) { |
375 | value = cast(value, reg_layout_.type().with_elems(tile.elems())); |
376 | } |
377 | int off = reg_layout_.offset_in_bytes(expr_cast<dim_t>(tile.start())); |
378 | auto ret = store_t::make( |
379 | reg_buf_, off, value, store_t::default_stride, mask); |
380 | return ret; |
381 | } |
382 | |
383 | const std::vector<stmt_t> &allocs() const { return allocs_; } |
384 | |
385 | private: |
386 | expr_t make_tmp_reg_buffer() { |
387 | auto *var = mem_buf().as_ptr<var_t>(); |
388 | if (!var) { |
389 | auto *ptr = mem_buf().as_ptr<ptr_t>(); |
390 | if (ptr) var = ptr->base.as_ptr<var_t>(); |
391 | } |
392 | ir_assert(var) << "Can't extract variable from buffer: " << mem_buf(); |
393 | auto &name = var->name; |
394 | return ir_ctx_->create_tmp_var(type_t::byte_ptr(), "tmp_" + name); |
395 | } |
396 | |
397 | void register_buffer(const expr_t &buf, int size) { |
398 | for (auto &_a : allocs_) { |
399 | auto &a = _a.as<alloc_t>(); |
400 | if (a.buf.is_same(buf)) { |
401 | if (size > a.size) { |
402 | _a = alloc_t::make(a.buf, size, a.kind, a.attrs); |
403 | } |
404 | return; |
405 | } |
406 | } |
407 | allocs_.push_back(alloc_t::make(buf, size, alloc_kind_t::grf)); |
408 | } |
409 | |
410 | send_cache_hint_t get_cache_hint(const view_t &c_view) const { |
411 | ir_assert(mem_view().nvdims() == c_view.nvdims()); |
412 | bool per_tensor = true; |
413 | for (int i = 0; i < mem_view().nvdims(); i++) { |
414 | if ((mask() & (1 << i)) != 0) continue; |
415 | if (c_view.vdims()[i] == 1) continue; |
416 | per_tensor = false; |
417 | break; |
418 | } |
419 | if (per_tensor) return send_cache_hint_t::load_once; |
420 | return send_cache_hint_t::undef; |
421 | } |
422 | |
423 | ir_context_t *ir_ctx_ = nullptr; |
424 | |
425 | post_op_tensor_info_t info_; |
426 | |
427 | layout_t reg_layout_; |
428 | expr_t reg_buf_; |
429 | |
430 | bool do_preload_ = false; |
431 | |
432 | std::vector<stmt_t> allocs_; |
433 | |
434 | slm_reduce_builder_t slm_reduce_builder_; |
435 | }; |
436 | |
437 | // Applies substitutions and broadcasts to generate the final post-op |
438 | // expression. |
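// For example (illustrative): for an 8-element tile, a binary expression like
// (c_var + src1_var * 0.5f), where c_var and src1_var are hypothetical
// placeholder variables mapped to 8-wide register loads in from2to_, is
// rewritten so that the loads are substituted in and the scalar 0.5f immediate
// is broadcast to 8 lanes with shuffle_t::make_broadcast().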
439 | class post_op_bcast_mutator_t : public ir_mutator_t { |
440 | public: |
441 | post_op_bcast_mutator_t( |
442 | int elems, const object_map_t<object_t, object_t> &from2to) |
443 | : elems_(elems), from2to_(from2to) {} |
444 | |
445 | object_t _mutate(const float_imm_t &obj) override { |
446 | return make_bcast(obj); |
447 | } |
448 | |
449 | object_t _mutate(const int_imm_t &obj) override { |
450 | return make_bcast(float_imm_t::make(obj.value)); |
451 | } |
452 | |
453 | object_t _mutate(const var_t &obj) override { |
454 | auto it = from2to_.find(obj); |
455 | if (it != from2to_.end()) return make_bcast(it->second); |
456 | |
457 | ir_error_not_expected() << "Unknown variable."; |
458 | return obj; |
459 | } |
460 | |
461 | private: |
462 | object_t make_bcast(const expr_t &e) const { |
463 | if (e.type().elems() == elems_) return e; |
464 | ir_assert(e.type().elems() == 1); |
465 | return shuffle_t::make_broadcast(e, elems_); |
466 | } |
467 | |
468 | int elems_; |
469 | object_map_t<object_t, object_t> from2to_; |
470 | }; |
471 | |
472 | // Builds statements to apply a post-op for a given tile. |
473 | class post_op_builder_t { |
474 | public: |
475 | post_op_builder_t(ngen::HW hw, const post_op_t &post_op) |
476 | : hw_(hw), post_op_(post_op) {} |
477 | |
478 | const post_op_t &post_op() const { return post_op_; } |
479 | |
480 | // Applies post-op for a single tile. |
481 | stmt_t build_tile_stmt(const tensor_t &tile, |
482 | const object_map_t<expr_t, post_op_tensor_t *> &args, |
483 | const zero_pad_builder_t &zero_pad_builder) const { |
484 | auto &lhs_tensor = *args.at(post_op_.lhs()); |
485 | if (!post_op_.eltwise().is_empty()) { |
486 | // Apply eltwise post-op. |
487 | ir_assert(post_op_.lhs().is_equal(post_op_.rhs())) |
488 | << "Only supported form is lhs = eltwise(lhs)."; |
489 | int lhs_size = lhs_tensor.reg_layout().size(); |
490 | int lhs_elems = lhs_size / int(sizeof(float)); |
491 | return post_op_.eltwise().call( |
492 | {expr_t(lhs_elems), lhs_tensor.reg_buf()}); |
493 | } |
494 | |
495 | int inner_dim_idx = -1; |
496 | auto base_inner_tile = find_1d_tile( |
497 | lhs_tensor.reg_layout().type(), args, inner_dim_idx); |
498 | auto inner_layout = lhs_tensor.reg_layout().map(base_inner_tile); |
499 | ir_assert(inner_dim_idx != -1); |
500 | |
501 | // All post-ops are performed in f32 except f64 bias. |
502 | for (auto &kv : args) { |
503 | ir_assert(kv.second->reg_layout().type().is_f32() |
504 | || kv.second->reg_layout().type().is_f64()); |
505 | } |
506 | |
507 | // Handle one inner tile at a time. Inner tile covers a single block |
508 | // within a single dimension. |
509 | stmt_t stmt; |
510 | lhs_tensor.reg_layout().for_each_tile( |
511 | base_inner_tile, [&](const std::vector<dim_t> &lhs_start) { |
512 | tensor_t inner_tile(base_inner_tile.dims(), lhs_start); |
513 | auto rhs_value = compute_post_op_expr( |
514 | post_op_.rhs(), inner_tile, inner_dim_idx, args); |
515 | auto &t = *args.at(post_op_.lhs()); |
516 | expr_t store_mask; |
517 | if (lhs_tensor.needs_masked_update()) { |
518 | store_mask = zero_pad_builder.create_mask( |
519 | inner_layout, inner_tile); |
520 | } |
521 | auto inner_stmt = t.store_stmt( |
522 | inner_tile, inner_dim_idx, rhs_value, store_mask); |
523 | stmt = stmt.append(inner_stmt); |
524 | }); |
525 | |
526 | return stmt; |
527 | } |
528 | |
529 | private: |
530 | // Returns a 1D tile corresponding to an instruction to partially apply the |
531 | // post-op. |
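// For example (illustrative): if the innermost dense C block spans 32
// elements but another non-broadcast post-op tensor only has an innermost
// dense block of 16 elements, the gcd logic below shrinks the 1D tile to 16
// elements (further capped by the two-GRF max_step) so that every argument
// can be read with a dense 1D access.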
532 | tensor_t find_1d_tile(const type_t &lhs_type, |
533 | const object_map_t<expr_t, post_op_tensor_t *> &args, |
534 | int &inner_dim_idx) const { |
535 | auto &lhs_tensor = *args.at(post_op_.lhs()); |
536 | |
537 | int ndims = lhs_tensor.mem_view().nvdims(); |
538 | |
539 | ir_assert(!lhs_tensor.reg_layout().is_empty()); |
540 | ir_assert(!lhs_tensor.reg_layout().blocks().empty()); |
541 | auto &b0 = lhs_tensor.reg_layout().blocks()[0]; |
542 | ir_assert(dim_t(b0.stride) == 1); |
543 | inner_dim_idx = b0.dim_idx; |
544 | |
545 | int inner_block = b0.block; |
546 | int max_step = 2 * ngen::GRF::bytes(hw_) / lhs_type.size(); |
547 | inner_block = std::max(8, math::gcd(inner_block, max_step)); |
548 | |
549 | for (auto &kv : args) { |
550 | auto &t = *kv.second; |
551 | if (t.is_broadcast_dim(b0.dim_idx)) continue; |
552 | |
553 | auto &l = t.reg_layout(); |
554 | ir_assert(!l.is_empty()); |
555 | ir_assert(!l.blocks().empty()); |
556 | auto &lb0 = l.blocks()[0]; |
557 | ir_assert(lb0.dim_idx == b0.dim_idx); |
558 | ir_assert(dim_t(lb0.stride) == 1); |
559 | inner_block = math::gcd(int(lb0.block), inner_block); |
560 | } |
561 | |
562 | std::vector<dim_t> dims(ndims, 1); |
563 | dims[b0.dim_idx] = inner_block; |
564 | |
565 | return tensor_t(dims); |
566 | } |
567 | |
568 | expr_t compute_post_op_expr(const expr_t &expr, const tensor_t &tile, |
569 | int dim_idx, |
570 | const object_map_t<expr_t, post_op_tensor_t *> &args) const { |
571 | object_map_t<object_t, object_t> sub_map; |
572 | for (auto &kv : args) { |
573 | auto &t = *kv.second; |
574 | auto te = t.load_expr(tile, dim_idx); |
575 | sub_map.insert({t.op_var(), te}); |
576 | } |
577 | post_op_bcast_mutator_t bcast_mutator(tile.elems(), sub_map); |
578 | return bcast_mutator.mutate(expr); |
579 | } |
580 | |
581 | ngen::HW hw_; |
582 | post_op_t post_op_; |
583 | }; |
584 | |
585 | // The epilogue consists of the following steps after the main computation (C += A * B): |
586 | // - C GRF reorder to match the memory layout for global memory store |
587 | // - C conversion to f32 (if there are post-ops) |
588 | // - Applying post-ops (if there are any) |
589 | // - C conversion to the memory layout data type |
590 | // - C store to global memory |
591 | // - Reduction and storing output post-op tensors |
592 | // |
593 | // In general, the C tensor is updated/transformed following the C stages |
594 | // described below. Each C stage is associated with a GRF buffer and its layout. |
595 | // Multiplication -> |
596 | // M_x -> [R_f32] -> [P0_f32] -> ... -> [Pn_f32] -> [Z_f32] -> S_y -> |
597 | // GMEM |
598 | // |
599 | // Where: |
600 | // - x is data type after multiplication |
601 | // - y is destination data type |
602 | // - M_x is the stage after multiplication |
603 | // - R_f32 is the stage after reordering from M_x to f32 (optional) |
604 | // - Pi_f32 is the stage after applying Pi post-op (optional) |
605 | // - Z_f32 is the stage after restoring zero padding (optional) |
606 | // - S_y is the stage before storing C to global memory |
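// For example (illustrative): an int8 convolution with a single eltwise
// post-op and an s8 destination goes through
//     M_s32 -> R_f32 -> P0_f32 -> [Z_f32] -> S_s8 -> GMEM
// where the Z_f32 stage is present only when zero padding has to be restored.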
607 | class epilogue_builder_t { |
608 | public: |
609 | epilogue_builder_t(ir_context_t &ir_ctx, const conv_config_t &cfg, |
610 | const gemm_schedule_t &gemm_schedule, |
611 | const post_op_context_t &post_op_ctx, const tensor_t &thr_tile, |
612 | const view_t &c_mem_view, const layout_t &c_reg_layout, |
613 | const expr_t &c_mem_buf, const expr_t &c_reg_buf, int tile_size, |
614 | int preload_max_size, int post_op_blk) |
615 | : ir_ctx_(ir_ctx) |
616 | , cfg_(cfg) |
617 | , gemm_schedule_(gemm_schedule) |
618 | , post_op_ctx_(post_op_ctx) |
619 | , c_mem_view_(c_mem_view) |
620 | , c_mem_buf_(c_mem_buf) |
621 | , tg_grid_(gemm_schedule.tg_grid()) |
622 | , tile_size_(tile_size) |
623 | , preload_max_size_(preload_max_size) |
624 | , post_op_blk_(post_op_blk) { |
625 | |
626 | int tensor_idx = 0; |
627 | for (auto &po_tensor_info : post_op_ctx_.post_op_tensor_infos()) { |
628 | post_op_tensor_t po_tensor(ir_ctx_, po_tensor_info); |
629 | po_tensor = po_tensor.create_sub_tensor(thr_tile); |
630 | if (po_tensor_info.buf().is_empty()) { |
631 | // C tensor. |
632 | ir_assert(c_po_idx_ == -1); |
633 | c_po_idx_ = tensor_idx; |
634 | } |
635 | post_op_tensors_.push_back(po_tensor); |
636 | tensor_idx++; |
637 | } |
638 | |
639 | restore_zero_padding_ = post_op_ctx_.need_to_restore_zero_padding(); |
640 | |
641 | for (auto &po : post_op_ctx_.post_ops()) { |
642 | post_op_builders_.emplace_back(ir_ctx_.hw(), po); |
643 | } |
644 | |
645 | // Estimate the buffer size required to load each full tensor; do not |
646 | // preload a tensor if it requires too much GRF memory. |
647 | int available_size = preload_max_size_; |
648 | for (auto &t : post_op_tensors_) { |
649 | if (!t.needs_load()) continue; |
650 | int required_size = t.estimate_grf_consumption(); |
651 | if (required_size > available_size) continue; |
652 | available_size -= required_size; |
653 | t.set_preload(); |
654 | } |
655 | |
656 | build(c_reg_layout, c_reg_buf); |
657 | } |
658 | |
659 | const stmt_t &stmt() const { return stmt_; } |
660 | |
661 | private: |
662 | void register_buffer(const expr_t &buf, int size) { |
663 | buf_sizes_[buf] = std::max(buf_sizes_[buf], size); |
664 | } |
665 | |
666 | expr_t make_c_tmp_buffer() const { |
667 | return ir_ctx_.create_tmp_var(type_t::byte_ptr(), "c_tmp"); |
668 | } |
669 | |
670 | // Represents a GRF buffer and layout to store C tensor. |
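// When two consecutive stages share the same layout, set_next() reuses the
// current GRF buffer for the next stage instead of emitting a reorder (unless
// a reorder is forced, e.g. after dpasw).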
671 | struct c_stage_t { |
672 | c_stage_t(const layout_t &layout, const expr_t &buf, |
673 | const stmt_t &stmt = stmt_t()) |
674 | : layout(layout), buf(buf), stmt(stmt) {} |
675 | |
676 | void set_next( |
677 | ir_context_t &ir_ctx, c_stage_t *next, bool force_reorder) { |
678 | if (!next) return; |
679 | bool do_reorder |
680 | = !layout.is_equal(next->layout, /*compare_offset=*/false); |
681 | if (force_reorder) do_reorder = true; |
682 | if (do_reorder) { |
683 | ir_assert(stmt.is_empty()); |
684 | // Generate reorder between stages. |
685 | stmt = create_reorder_stmt( |
686 | layout, next->layout, buf, next->buf); |
687 | } else { |
688 | // Reuse the same GRF buffer for the next stage. |
689 | int this_off = to_cpp<int>(layout.offset_in_bytes()); |
690 | int next_off = to_cpp<int>(next->layout.offset_in_bytes()); |
691 | ir_assert(next_off == 0); |
692 | MAYBE_UNUSED(next_off); |
693 | next->set_buf(buf[this_off]); |
694 | } |
695 | } |
696 | |
697 | void set_buf(const expr_t &buf) { |
698 | // Replace old buffer if there is an assigned statement. |
699 | if (!stmt.is_empty()) { stmt = substitute(stmt, this->buf, buf); } |
700 | this->buf = buf; |
701 | } |
702 | |
703 | const expr_t &buf_base() const { |
704 | if (buf.is<var_t>()) return buf; |
705 | return buf.as<ptr_t>().base; |
706 | } |
707 | |
708 | int buf_size() const { |
709 | ir_assert(buf.is_same(buf_base())) |
710 | << "Size must be queried from another stage."; |
711 | return int(layout.size()); |
712 | } |
713 | |
714 | void prepend_stmt(const stmt_t &stmt) { |
715 | this->stmt = stmt.append(this->stmt); |
716 | } |
717 | |
718 | layout_t layout; |
719 | expr_t buf; |
720 | stmt_t stmt; // Statement to emit after the stage. |
721 | }; |
722 | |
723 | void build(const layout_t &c_reg_layout, const expr_t &c_reg_buf) { |
724 | auto tmp_type = (post_op_builders_.empty() ? c_mem_view_.type() |
725 | : type_t::f32()); |
726 | int tmp_buf_elems = tile_size_ / tmp_type.size(); |
727 | auto base_tile = c_mem_view_.split_into_max_tile( |
728 | tmp_buf_elems, /*is_dense=*/false); |
729 | |
730 | // Generate preload statements. |
731 | for (auto &t : post_op_tensors_) { |
732 | if (!t.do_preload()) continue; |
733 | stmt_ = stmt_.append(t.build_load_stmt(c_mem_view_)); |
734 | } |
735 | |
736 | // Generate prefetch statements. |
737 | if (ir_ctx_.hw() >= ngen::HW::XeHPC) { |
738 | for (auto &t : post_op_tensors_) { |
739 | if (!t.needs_load()) continue; |
740 | if (t.do_preload()) continue; |
741 | stmt_ = stmt_.append(t.build_prefetch_stmt(c_mem_view_)); |
742 | } |
743 | } |
744 | |
745 | // Generate f32 convert statements. |
746 | for (auto &t : post_op_tensors_) { |
747 | if (!t.do_preload()) continue; |
748 | if (!t.needs_f32_convert()) continue; |
749 | stmt_ = stmt_.append(t.build_convert_stmt()); |
750 | } |
751 | |
752 | // Initialize buffers for output post-op tensors. |
753 | for (auto &t : post_op_tensors_) { |
754 | if (!t.needs_store()) continue; |
755 | t.init_output_buffer(base_tile); |
756 | } |
757 | |
758 | // Generate zero-out statements for output post-op tensors. |
759 | for (auto &t : post_op_tensors_) { |
760 | if (!t.needs_store()) continue; |
761 | stmt_ = stmt_.append(t.build_zero_out_stmt()); |
762 | } |
763 | |
764 | // Iterate by tiles and apply post-ops. |
765 | c_mem_view_.for_each_tile( |
766 | base_tile, [&](const std::vector<dim_t> &start) { |
767 | tensor_t tile(base_tile.dims(), start); |
768 | auto c_tile_layout = c_reg_layout.map(tile); |
769 | build_tile(tile, c_tile_layout, c_reg_buf); |
770 | }); |
771 | |
772 | // TODO: Generalize the condition. Iterate through output tensor masks |
773 | // and ensure C is distributed accordingly in thread group. |
774 | bool use_slm_reduction = (tg_grid_.dim(1) > 1); |
775 | |
776 | // Generate reduce and store statements for output post-op tensors. |
777 | stmt_t thr_reduce_stmt; |
778 | stmt_t slm_store_stmt; |
779 | stmt_t slm_load_stmt; |
780 | stmt_t mem_store_stmt; |
781 | for (auto &t : post_op_tensors_) { |
782 | if (!t.needs_store()) continue; |
783 | |
784 | thr_reduce_stmt = thr_reduce_stmt.append(t.build_reduce_stmt()); |
785 | if (use_slm_reduction) { |
786 | auto store_stmt = t.build_slm_store_stmt(tg_grid_); |
787 | auto load_stmt = t.build_slm_load_stmt(); |
788 | slm_store_stmt = slm_store_stmt.append(store_stmt); |
789 | slm_load_stmt = slm_load_stmt.append(load_stmt); |
790 | } |
791 | mem_store_stmt = mem_store_stmt.append(t.build_store_stmt()); |
792 | } |
793 | |
794 | stmt_ = stmt_.append(thr_reduce_stmt); |
795 | if (!slm_store_stmt.is_empty()) { |
796 | stmt_ = stmt_.append(funcs::barrier()); |
797 | stmt_ = stmt_.append(slm_store_stmt); |
798 | stmt_ = stmt_.append(funcs::barrier()); |
799 | stmt_ = stmt_.append(slm_load_stmt); |
800 | } |
801 | |
802 | stmt_ = stmt_.append(mem_store_stmt); |
803 | |
804 | // Generate alloc statements for post-op tensors. |
805 | std::vector<stmt_t> allocs; |
806 | for (auto &t : post_op_tensors_) { |
807 | auto t_allocs = t.allocs(); |
808 | allocs.insert(allocs.end(), t_allocs.begin(), t_allocs.end()); |
809 | } |
810 | stmt_ = jit::inject_alloc_stmts(stmt_, allocs, /*put_innermost=*/true); |
811 | } |
812 | |
813 | // Builds statements for a tile iterating through all post-ops. |
814 | void build_tile(const tensor_t &tile, const layout_t &c_tile_layout, |
815 | const expr_t &c_reg_buf) { |
816 | auto c_mem_tile_view = c_mem_view_.create_sub_view(tile); |
817 | auto tmp_reg_buf = make_c_tmp_buffer(); |
818 | |
819 | type_t post_op_type |
820 | = c_tile_layout.type().is_f64() ? type_t::f64() : type_t::f32(); |
821 | bool create_zero_pad_builder = restore_zero_padding_; |
822 | for (auto &t : post_op_tensors_) { |
823 | if (t.needs_masked_update()) { |
824 | create_zero_pad_builder = true; |
825 | break; |
826 | } |
827 | } |
828 | if (create_zero_pad_builder) { |
829 | zero_pad_builder_ = zero_pad_builder_t( |
830 | ir_ctx_, post_op_ctx_.cp_view(), c_mem_tile_view); |
831 | } |
832 | |
833 | // S_y -> GMEM. |
834 | auto send_op = gemm_schedule_.with_kernel_grid_k_slicing() |
835 | ? send_op_t::atomic_fadd |
836 | : send_op_t::store; |
837 | auto send_hint = get_send_hint(ir_ctx_.exec_cfg(), send_op, |
838 | abc_kind_t::c, c_mem_tile_view, gemm_schedule_); |
839 | auto r2g = make_access_builder(ir_ctx_, c_mem_tile_view, c_mem_buf_, |
840 | tmp_reg_buf, send_op, send_address_t::a64, send_hint); |
841 | |
842 | // Initialize C stages. |
843 | std::vector<c_stage_t> c_stages; |
844 | |
845 | auto c_fx_layout = r2g.reg_layout().retype(post_op_type).make_dense(); |
846 | bool with_post_ops = !post_op_builders_.empty(); |
847 | int npost_ops = int(post_op_builders_.size()); |
848 | |
849 | int c_f32_stage_idx = -1; |
850 | int c_zero_pad_stage_idx = -1; |
851 | |
852 | c_stages.emplace_back(c_tile_layout, c_reg_buf); // M_x |
853 | if (with_post_ops) { |
854 | c_f32_stage_idx = int(c_stages.size()); |
855 | c_stages.emplace_back(c_fx_layout, make_c_tmp_buffer()); // R_f32 |
856 | } |
857 | if (restore_zero_padding_) { |
858 | c_zero_pad_stage_idx = int(c_stages.size()); |
859 | c_stages.emplace_back(c_fx_layout, make_c_tmp_buffer()); // Z_f32 |
860 | } |
861 | c_stages.emplace_back(r2g.reg_layout(), tmp_reg_buf, r2g.stmt()); // S_y |
862 | |
863 | int nstages = int(c_stages.size()); |
864 | bool is_dpasw = (cfg_.fma_kind() == fma_kind_t::dpasw); |
865 | |
866 | // Generate reorders between C stages if needed. |
867 | for (int i = 0; i < nstages; i++) { |
868 | auto *next_stage = (i + 1 < nstages ? &c_stages[i + 1] : nullptr); |
869 | // Always perform reorder when dpasw is used. This is to ensure |
870 | // that C is properly restored and permuted after dpasw. |
871 | c_stages[i].set_next(ir_ctx_, next_stage, |
872 | /*force_reorder=*/i == 0 && is_dpasw); |
873 | } |
874 | |
875 | // Restore zero padding if needed. |
876 | if (c_zero_pad_stage_idx != -1) { |
877 | auto &s = c_stages[c_zero_pad_stage_idx]; |
878 | s.prepend_stmt(zero_pad_builder_.build_stmt(s.layout, s.buf)); |
879 | } |
880 | |
881 | // Create sub-tensors for post-ops. |
882 | std::vector<post_op_tensor_t> sub_po_tensors; |
883 | for (auto &t : post_op_tensors_) |
884 | sub_po_tensors.push_back(t.create_sub_tensor(tile)); |
885 | |
886 | // Set C tensor layout and buffer to use for post-ops. |
887 | if (c_f32_stage_idx != -1) { |
888 | auto &s = c_stages[c_f32_stage_idx]; |
889 | sub_po_tensors[c_po_idx_].set_reg_layout(s.layout); |
890 | sub_po_tensors[c_po_idx_].set_reg_buf(s.buf); |
891 | } |
892 | |
893 | stmt_t tile_stmt; |
894 | |
895 | // Add C stage statements and post-op statements. |
896 | for (int i = 0; i < nstages; i++) { |
897 | if (with_post_ops && i == c_f32_stage_idx) { |
898 | // Emit post-ops in blocks to reduce GRF consumption. |
899 | for (int j = 0; j < npost_ops; j += post_op_blk_) { |
900 | int k_beg = j; |
901 | int k_end = std::min(npost_ops, j + post_op_blk_); |
902 | auto blk_stmt = build_post_op_block_stmt( |
903 | tile, sub_po_tensors, k_beg, k_end); |
904 | tile_stmt = tile_stmt.append(blk_stmt); |
905 | } |
906 | } |
907 | tile_stmt = tile_stmt.append(c_stages[i].stmt); |
908 | } |
909 | |
910 | // Generate alloc statements for C stage buffers. |
911 | object_set_t<expr_t> seen; |
912 | for (int i = 0; i < nstages; i++) { |
913 | auto &s = c_stages[i]; |
914 | auto &buf = s.buf_base(); |
915 | auto ret = seen.insert(buf); |
916 | if (i == 0 || !ret.second) continue; |
917 | int size = utils::rnd_up(s.buf_size(), ir_ctx_.grf_size()); |
918 | tile_stmt = alloc_t::make(buf, size, alloc_kind_t::grf, tile_stmt); |
919 | } |
920 | |
921 | stmt_ = stmt_.append(tile_stmt); |
922 | } |
923 | |
924 | stmt_t build_post_op_block_stmt(const tensor_t &tile, |
925 | std::vector<post_op_tensor_t> &sub_po_tensors, int po_beg, |
926 | int po_end) const { |
927 | // Collect post-op inputs/outputs. |
928 | object_map_t<expr_t, post_op_tensor_t *> args; |
929 | for (int i = po_beg; i < po_end; i++) { |
930 | auto &po_builder = post_op_builders_[i]; |
931 | for (auto &t : sub_po_tensors) { |
932 | if (po_builder.post_op().uses(t.op_var())) { |
933 | args.insert({t.op_var(), &t}); |
934 | } |
935 | } |
936 | } |
937 | |
938 | // Generate load and convert statements for the post-op. |
939 | stmt_t load_stmt; |
940 | stmt_t convert_stmt; |
941 | for (auto &kv : args) { |
942 | auto &t = *kv.second; |
943 | if (!t.needs_load()) continue; |
944 | if (t.do_preload()) continue; |
945 | load_stmt = load_stmt.append(t.build_load_stmt(c_mem_view_)); |
946 | if (t.needs_f32_convert()) { |
947 | convert_stmt = convert_stmt.append(t.build_convert_stmt()); |
948 | } |
949 | } |
950 | |
951 | stmt_t stmt; |
952 | stmt = stmt.append(load_stmt); |
953 | stmt = stmt.append(convert_stmt); |
954 | |
955 | for (int i = po_beg; i < po_end; i++) { |
956 | auto &po_builder = post_op_builders_[i]; |
957 | auto po_stmt |
958 | = po_builder.build_tile_stmt(tile, args, zero_pad_builder_); |
959 | stmt = stmt.append(po_stmt); |
960 | } |
961 | |
962 | // Generate alloc statements for post-op tensors. |
963 | std::vector<stmt_t> allocs; |
964 | for (auto &kv : args) { |
965 | auto &t = *kv.second; |
966 | if (!t.needs_load()) continue; |
967 | if (t.do_preload()) continue; |
968 | auto t_allocs = t.allocs(); |
969 | allocs.insert(allocs.end(), t_allocs.begin(), t_allocs.end()); |
970 | } |
971 | stmt = jit::inject_alloc_stmts(stmt, allocs); |
972 | |
973 | return stmt; |
974 | } |
975 | |
976 | ir_context_t &ir_ctx_; |
977 | const conv_config_t &cfg_; |
978 | const gemm_schedule_t &gemm_schedule_; |
979 | const post_op_context_t &post_op_ctx_; |
980 | |
981 | // C view in global memory. |
982 | view_t c_mem_view_; |
983 | expr_t c_mem_buf_; |
984 | |
985 | // C layout after the main loop. |
986 | layout_t c_reg_layout_; |
987 | expr_t c_reg_buf_; |
988 | |
989 | const grid_info_t &tg_grid_; |
990 | |
991 | bool restore_zero_padding_; |
992 | |
993 | zero_pad_builder_t zero_pad_builder_; |
994 | |
995 | // Tile size in bytes. The tile data type is: |
996 | // - the destination data type without post-ops |
997 | // - f32 with post-ops |
998 | int tile_size_; |
999 | int preload_max_size_; |
1000 | int post_op_blk_; |
1001 | |
1002 | std::vector<post_op_builder_t> post_op_builders_; |
1003 | std::vector<post_op_tensor_t> post_op_tensors_; |
1004 | int c_po_idx_ = -1; |
1005 | |
1006 | object_map_t<expr_t, int> buf_sizes_; |
1007 | |
1008 | stmt_t stmt_; |
1009 | }; |
1010 | |
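// Estimates the GRF footprint (in bytes) of loading and, if needed, converting
// one post-op tensor for a tile of c_elems C elements. For example
// (illustrative): a per-OC f16 tensor with 32 output channels and c_elems >= 32
// takes 32 * 2 = 64 bytes for the load plus 32 * 4 = 128 bytes for the f32
// copy, 192 bytes in total.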
1011 | int get_post_op_mem_usage(const post_op_tensor_info_t &info, int c_elems, |
1012 | const view_t &c_mem_view, int max_elems_per_dim = 64) { |
1013 | int po_elems = 1; |
1014 | for (int i = 0; i < info.view().nvdims(); i++) { |
1015 | if ((info.mask() & (1 << i)) == 0) continue; |
1016 | po_elems *= std::min(max_elems_per_dim, (int)c_mem_view.vdims()[i]); |
1017 | } |
1018 | po_elems = std::min(po_elems, c_elems); |
1019 | int type_size = info.view().type().size(); |
1020 | int load_size = po_elems * type_size; |
1021 | int cvt_size = info.view().type().is_f32() ? 0 : po_elems * sizeof(float); |
1022 | return load_size + cvt_size; |
1023 | } |
1024 | |
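// Picks the largest power-of-two tile size (up to 1024 bytes) whose estimated
// GRF footprint fits into ~70% of the register space left after the C
// accumulator. For example (illustrative): with 256 registers of 32 bytes
// (8192 bytes) and a 2048-byte C accumulator, available_size is 6144 bytes;
// for an f16 destination with post-ops a 1024-byte tile holds 256 f32
// elements, so c_size = 256 * 2 + 256 * 4 = 1536 bytes, and with
// preload_max_size = 512 and no extra per-block loads the total of 2048 bytes
// fits under the budget, so 1024 is returned.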
1025 | int find_tile_size(const exec_config_t &exec_cfg, |
1026 | const post_op_context_t &post_op_ctx, const view_t &c_mem_view, |
1027 | const layout_t &c_reg_layout, int preload_max_size, int post_op_blk) { |
1028 | bool with_post_ops = !post_op_ctx.post_ops().empty(); |
1029 | for (int tile_size = 1024; tile_size >= 1; tile_size /= 2) { |
1030 | int c_type_size = c_mem_view.type().size(); |
1031 | int elems = tile_size / (with_post_ops ? sizeof(float) : c_type_size); |
1032 | int c_mul_size = elems * c_type_size; |
1033 | int c_f32_size = with_post_ops && !c_mem_view.type().is_f32() |
1034 | ? elems * sizeof(float) |
1035 | : 0; |
1036 | int c_size = c_mul_size + c_f32_size; |
1037 | int po_size = 0; |
1038 | |
1039 | auto &infos = post_op_ctx.post_op_tensor_infos(); |
1040 | int npost_ops = int(infos.size()); |
1041 | for (int i = 0; i < npost_ops; i += post_op_blk) { |
1042 | int po_batch_size = 0; |
1043 | for (int j = i; j < std::min(npost_ops, i + post_op_blk); j++) { |
1044 | auto &t = infos[j]; |
1045 | if (!t.is_input() || !t.buf().type().is_ptr()) continue; |
1046 | po_batch_size += get_post_op_mem_usage(t, elems, c_mem_view); |
1047 | } |
1048 | po_size = std::max(po_size, po_batch_size); |
1049 | } |
1050 | |
1051 | int total_size = c_size + preload_max_size + po_size; |
1052 | int available_size = exec_cfg.regs() * exec_cfg.grf_size() |
1053 | - (int)c_reg_layout.size(); |
1054 | if (total_size <= available_size * 0.7) return tile_size; |
1055 | } |
1056 | ir_error_not_expected(); |
1057 | return -1; |
1058 | } |
1059 | |
1060 | stmt_t create_epilogue_stmt(const conv_config_t &cfg, ir_context_t &ir_ctx, |
1061 | const gemm_schedule_t &gemm_schedule, |
1062 | const post_op_context_t &post_op_ctx, const tensor_t &thr_tile, |
1063 | const view_t &c_mem_view, const layout_t &c_reg_layout, |
1064 | const expr_t &c_mem_buf, const expr_t &c_reg_buf) { |
1065 | // Max size of post-op tensor buffers to preload and reuse for all tiles. |
1066 | int preload_max_size = 512; |
1067 | // Block size to apply post-ops within a tile. A post-op may have associated |
1068 | // loads/conversions; a larger block size helps hide more of their latency |
1069 | // across multiple post-ops. |
1070 | int post_op_blk = 8; |
1071 | // Tile size in bytes. All post-ops are applied to a single tile, then to |
1072 | // the next tile, etc. |
1073 | int tile_size = find_tile_size(cfg.exec_cfg(), post_op_ctx, c_mem_view, |
1074 | c_reg_layout, preload_max_size, post_op_blk); |
1075 | |
1076 | ir_trace() << "Creating epilogue with parameters" |
1077 | << ": tile_size = " << tile_size |
1078 | << ", preload_max_size = " << preload_max_size |
1079 | << ", post_op_blk = " << post_op_blk << std::endl; |
1080 | epilogue_builder_t builder(ir_ctx, cfg, gemm_schedule, post_op_ctx, |
1081 | thr_tile, c_mem_view, c_reg_layout, c_mem_buf, c_reg_buf, tile_size, |
1082 | preload_max_size, post_op_blk); |
1083 | return builder.stmt(); |
1084 | } |
1085 | |
1086 | } // namespace jit |
1087 | } // namespace gpu |
1088 | } // namespace impl |
1089 | } // namespace dnnl |
1090 | |