1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/conv/slm_reduce_builder.hpp" |
18 | |
19 | #include <algorithm> |
20 | |
21 | #include "gpu/jit/ir/message.hpp" |
22 | #include "gpu/jit/ir/mul_add.hpp" |
23 | #include "gpu/jit/ir/reduce.hpp" |
24 | #include "gpu/jit/utils/trace.hpp" |
25 | #include "gpu/jit/utils/utils.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace gpu { |
30 | namespace jit { |
31 | |
32 | slm_reduce_builder_t::slm_reduce_builder_t(ir_context_t &ir_ctx, |
33 | const grid_info_t &tg_grid, const expr_t ®_buf, |
34 | const layout_t ®_layout, const tensor_t &thr_tile, int dim) |
35 | : ir_ctx_(&ir_ctx) |
36 | , tg_grid_(tg_grid) |
37 | , reg_buf_(reg_buf) |
38 | , reg_layout_(reg_layout) |
39 | , thr_tile_(thr_tile) |
40 | , dim_(dim) { |
41 | ir_assert((dim_ >= 0) && (dim_ <= 2)); |
42 | ir_assert(tg_grid_.dim(dim_) > 1); |
43 | |
44 | tmp_reg_buf_ = ir_ctx.create_tmp_var(type_t::byte_ptr()); |
45 | slm_buf_ = ir_ctx.create_tmp_var(type_t::byte_ptr(), "reduce_slm" ); |
46 | tg_ndims_ = (dim_ != 2) ? dim_ + 1 : tg_grid_.ndims(); |
47 | |
48 | build(); |
49 | } |
50 | |
// Builds the store/load/reduce statements and buffer allocations for a
// thread-group-wide reduction staged through SLM: every thread writes its
// register tile to SLM, then each thread loads the tiles spanning the
// reduced grid dimension and accumulates them into its register buffer.
void slm_reduce_builder_t::build() {
    int ndims = reg_layout_.ndims();

    // Create SLM layout to store all intermediate buffers from the thread
    // group. The register layout is extended by tg_ndims_ outer dimensions
    // (one per participating grid dimension) so each thread owns a distinct
    // slot in SLM.
    layout_t slm_layout(reg_layout_.type(), ndims + tg_ndims_,
            reg_layout_.offset(), reg_layout_.blocks());
    for (int i = tg_ndims_ - 1; i >= 0; i--) {
        slm_layout = slm_layout.add_outer_block(ndims + i, tg_grid_.dim(i));
    }

    slm_buf_size_ = slm_layout.size();

    // Write thread tile to SLM. The write tile covers the full register
    // layout in the first ndims dimensions and selects this thread's slot
    // via its grid indices in the appended dimensions.
    std::vector<dim_t> write_dims = reg_layout_.dims();
    std::vector<expr_t> write_start(ndims + tg_ndims_, 0);
    write_dims.resize(ndims + tg_ndims_, 1);
    for (int i = tg_ndims_ - 1; i >= 0; i--) {
        write_start[ndims + i] = tg_grid_.idx(i);
    }
    auto write_tile = tensor_t(write_dims, write_start);
    auto write
            = make_access_builder(*ir_ctx_, view_t(slm_layout.map(write_tile)),
                    slm_buf_, reg_buf_, send_op_t::store, send_address_t::slm);
    store_stmt_ = write.stmt();

    // The store must be able to consume the register layout as-is; a
    // mismatch here would silently store from the wrong offsets.
    auto &write_layout = write.reg_layout();
    ir_assert(write_layout == reg_layout_) << "Incompatible layouts." ;

    // Redistribute the layout to read/reduce all k-axis tiles from every
    // thread: split the register layout across the reduced grid dimension
    // so each thread is responsible for one sub-tile of the result.
    auto local_thr_tile = reg_layout_.split(tg_grid_.sub_grid({dim_}));
    reg_layout_ = reg_layout_.map(tensor_t(local_thr_tile.dims()));

    std::vector<dim_t> read_dims(ndims + tg_ndims_, 1);
    std::vector<expr_t> read_start(ndims + tg_ndims_);
    for (int i = 0; i < ndims; i++) {
        read_dims[i] = local_thr_tile(i);
        read_start[i] = local_thr_tile.start(i);
        // NOTE(review): this looks like a bounds guard for threads whose
        // sub-tile start falls outside the layout when the split does not
        // divide evenly — such threads skip the load/reduce (see the
        // if_t::make below). Confirm against layout_t::split semantics.
        auto cond = read_start[i] < slm_layout.dims()[i];
        if (reduce_cond_.is_empty())
            reduce_cond_ = cond;
        else
            reduce_cond_ &= cond;
    }
    // Read the whole extent of the reduced grid dimension; the other grid
    // dimensions stay pinned to this thread's own indices.
    read_dims[ndims + dim_] = tg_grid_.dim(dim_);
    for (int i = 0; i < tg_ndims_; i++) {
        read_start[ndims + i] = (i == dim_) ? 0 : tg_grid_.idx(i);
    }
    tensor_t read_tile(read_dims, read_start);
    auto read = make_access_builder(*ir_ctx_, view_t(slm_layout.map(read_tile)),
            slm_buf_, tmp_reg_buf_, send_op_t::load, send_address_t::slm);

    // Zero the accumulator (now holding the remapped reg_layout_) before
    // the loaded tiles are reduced into it.
    load_stmt_ = load_stmt_.append(
            create_zero_out_stmt(*ir_ctx_, reg_buf_, reg_layout_.size()));
    load_stmt_ = load_stmt_.append(read.stmt());

    // The temporary buffer must hold whatever the load materializes.
    tmp_reg_buf_size_ = std::max(tmp_reg_buf_size_, read.reg_buf_size());

    // Accumulate the loaded tiles into reg_buf_; reduction_mask() selects
    // the dimensions being collapsed.
    auto read_layout = read.reg_layout();
    load_stmt_ = load_stmt_.append(create_reduce_stmt(read_layout, reg_layout_,
            tmp_reg_buf_, reg_buf_, tensor_t(), reduction_mask()));

    // Register the SLM scratch and temporary GRF allocations for the caller
    // to wrap around the emitted statements.
    allocs_.push_back(
            alloc_t::make(slm_buf_, slm_buf_size_, alloc_kind_t::slm));
    allocs_.push_back(
            alloc_t::make(tmp_reg_buf_, tmp_reg_buf_size_, alloc_kind_t::grf));

    // Predicate the whole load/zero/reduce sequence on the bounds condition
    // accumulated above (if any).
    if (!reduce_cond_.is_empty())
        load_stmt_ = if_t::make(reduce_cond_, load_stmt_);
    // Narrow the caller-provided thread tile to the post-split sub-tensor
    // so it reflects the portion this thread now owns.
    if (!thr_tile_.is_empty()) {
        thr_tile_ = thr_tile_.create_sub_tensor(local_thr_tile);
    }
}
125 | |
126 | } // namespace jit |
127 | } // namespace gpu |
128 | } // namespace impl |
129 | } // namespace dnnl |
130 | |