1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_REORDER_REORDER_KERNEL_HPP |
18 | #define GPU_JIT_REORDER_REORDER_KERNEL_HPP |
19 | |
20 | #include "gpu/jit/codegen/codegen.hpp" |
21 | #include "gpu/jit/codegen/kernel.hpp" |
22 | #include "gpu/jit/codegen/ngen_helpers.hpp" |
23 | #include "gpu/jit/codegen/register_scope.hpp" |
24 | #include "gpu/jit/ir/ir_builder.hpp" |
25 | #include "gpu/jit/ir/message.hpp" |
26 | #include "gpu/jit/ir/reorder.hpp" |
27 | #include "gpu/jit/ir/tensor.hpp" |
28 | #include "gpu/jit/reorder/ir_builder.hpp" |
29 | #include "gpu/jit/utils/ngen_type_bridge.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace gpu { |
34 | namespace jit { |
35 | |
36 | template <ngen::HW hw = ngen::HW::Unknown> |
37 | class reorder_kernel_t : public ir_kernel_t<hw> { |
38 | public: |
39 | IR_KERNEL_FORWARD(hw) |
40 | |
41 | reorder_kernel_t(const exec_config_t &exec_cfg, |
42 | const std::string &kernel_name, const kernel_info_t &kernel_info, |
43 | const layout_t &src_layout, const layout_t &dst_layout, |
44 | bool require_dpas, grf_mode_t grf_mode) |
45 | : ir_kernel_t<hw>( |
46 | kernel_name, exec_cfg, kernel_info, require_dpas, grf_mode) { |
47 | |
48 | if (reorder_kernel_t<>::is_ir_based_reorder(src_layout, dst_layout)) { |
49 | reorder_ir_builder_t builder( |
50 | exec_cfg, kernel_info, src_layout, dst_layout); |
51 | stmt_t body = builder.stmt(); |
52 | setup_interface(body); |
53 | generate_prologue(); |
54 | expr_binding_t expr_binding(hw); |
55 | bind_external_vars(body, builder.kernel_grid(), builder.local_id(), |
56 | expr_binding); |
57 | |
58 | // Generate assembly from IR. |
59 | ir_to_ngen_t<hw> visitor(this, expr_binding); |
60 | visitor.visit(body); |
61 | |
62 | generate_epilogue(); |
63 | return; |
64 | } |
65 | |
66 | // Handle specific reorder versions. |
67 | setup_interface(); |
68 | generate_prologue(); |
69 | |
70 | std::vector<std::string> arg_names(kernel_info.nargs()); |
71 | for (int i = 0; i < kernel_info.nargs(); i++) { |
72 | arg_names[i] = kernel_info.arg_name(i); |
73 | } |
74 | src_ptr_ = getArgument(arg_names[0]); |
75 | dst_ptr_ = getArgument(arg_names[1]); |
76 | src_surf_ = Surface(getArgumentSurface(arg_names[0])); |
77 | dst_surf_ = Surface(getArgumentSurface(arg_names[1])); |
78 | elems_ = getArgument(arg_names[2]); |
79 | |
80 | global_id_ = ra_.template alloc_sub<uint32_t>(); |
81 | |
82 | mul(1, global_id_, r0.ud(1), getLocalSize(0).uw()); |
83 | add(1, global_id_, global_id_, getLocalID(0)); |
84 | |
85 | int elems_per_thr; |
86 | if (is_2d_reorder(src_layout, dst_layout, elems_per_thr)) { |
87 | emit_2d_reorder(src_layout, dst_layout, elems_per_thr); |
88 | } else { |
89 | ir_error_not_expected(); |
90 | } |
91 | |
92 | generate_epilogue(); |
93 | } |
94 | |
95 | void emit_2d_reorder( |
96 | const layout_t &_src, const layout_t &_dst, int elems_per_thr) { |
97 | auto tile = _src.split_into_max_tile(elems_per_thr, /*is_dense=*/true); |
98 | ir_assert(!tile.is_empty()) << "Can't split " << _src; |
99 | |
100 | int simd_size = getSIMD(); |
101 | auto src = _src.map(tile); |
102 | auto dst = _dst.map(tile); |
103 | |
104 | int src_size = src.type().size(); |
105 | int dst_size = dst.type().size(); |
106 | int src_tile_bytes = src_size * elems_per_thr; |
107 | int dst_tile_bytes = dst_size * elems_per_thr; |
108 | |
109 | int grf_size = ngen::GRF::bytes(hw); |
110 | |
111 | auto S = ra_.alloc_range(utils::div_up(src_tile_bytes, grf_size)); |
112 | auto D = ra_.alloc_range(utils::div_up(dst_tile_bytes, grf_size)); |
113 | |
114 | auto = ra_.alloc(); |
115 | auto = ra_.alloc(); |
116 | |
117 | // Prepare headers for loads and stores. |
118 | eshl(1, src_header.uq(0), global_id_, |
119 | math::ilog2q(src_tile_bytes / simd_size)); |
120 | eshl(1, dst_header.uq(0), global_id_, |
121 | math::ilog2q(dst_tile_bytes / simd_size)); |
122 | eadd(1, src_header.uq(0), src_header.uq(0), src_ptr_); |
123 | eadd(1, dst_header.uq(0), dst_header.uq(0), dst_ptr_); |
124 | |
125 | int oword_bytes = 16; |
126 | |
127 | // Load source tile. |
128 | int src_off = 0; |
129 | while (src_tile_bytes > 0) { |
130 | for (int i = 3; i >= 0; i--) { |
131 | int size = (1 << i) * oword_bytes; |
132 | if (src_tile_bytes >= size) { |
133 | load(16, S[src_off / grf_size], |
134 | ngen::block_oword(size / oword_bytes), A64, |
135 | src_header); |
136 | eadd(1, src_header.uq(0), src_header.uq(0), size); |
137 | src_tile_bytes -= size; |
138 | src_off += size; |
139 | break; |
140 | } |
141 | } |
142 | } |
143 | |
144 | // Reorder source tile to destination tile. |
145 | ngen_register_scope_t scope(ra_); |
146 | reorder_2d_impl_t r(hw, src, dst); |
147 | reg_buf_t S_buf(hw, S); |
148 | reg_buf_t D_buf(hw, D); |
149 | r.emit(this, scope, S_buf, D_buf); |
150 | |
151 | // Store destination tile. |
152 | int dst_off = 0; |
153 | while (dst_tile_bytes > 0) { |
154 | for (int i = 3; i >= 0; i--) { |
155 | int size = (1 << i) * oword_bytes; |
156 | if (dst_tile_bytes >= size) { |
157 | store(16, ngen::block_oword(size / oword_bytes), A64, |
158 | dst_header, D[dst_off / grf_size]); |
159 | eadd(1, dst_header.uq(0), dst_header.uq(0), size); |
160 | dst_tile_bytes -= size; |
161 | dst_off += size; |
162 | break; |
163 | } |
164 | } |
165 | } |
166 | } |
167 | |
168 | static bool is_ir_based_reorder(const layout_t &src, const layout_t &dst) { |
169 | int dummy; |
170 | if (is_2d_reorder(src, dst, dummy)) return false; |
171 | return true; |
172 | } |
173 | |
174 | static compute::nd_range_t nd_range(const exec_config_t &exec_cfg, |
175 | const layout_t &src, const layout_t &dst) { |
176 | const int simd = exec_cfg.simd(); |
177 | |
178 | int elems_per_thr; |
179 | if (is_2d_reorder(src, dst, elems_per_thr)) { |
180 | ir_assert(src.elems() == dst.elems()); |
181 | return compute::nd_range_t( |
182 | {(int)utils::div_up(src.elems(), elems_per_thr) * simd, 1, |
183 | 1}); |
184 | } |
185 | |
186 | // Handle IR-based reorder. |
187 | ir_assert(reorder_kernel_t<>::is_ir_based_reorder(src, dst)); |
188 | |
189 | return reorder_ir_builder_t::nd_range(exec_cfg, src, dst); |
190 | } |
191 | |
192 | private: |
193 | static bool is_2d_reorder( |
194 | const layout_t &src, const layout_t &dst, int &elems_per_thr) { |
195 | if (!src.type().is_bitwise_compatible(dst.type())) return false; |
196 | if (src.is_equal(dst)) return false; |
197 | |
198 | const int hword_bytes = 32; |
199 | const int min_bytes_per_thr = hword_bytes; |
200 | const int max_bytes_per_thr = 32 * hword_bytes; |
201 | |
202 | int type_size = src.type().size(); |
203 | int max_elems_per_thr = max_bytes_per_thr / type_size; |
204 | |
205 | auto tile = reorder_2d_impl_t::find_2d_tile( |
206 | src, dst, max_elems_per_thr, /*match_outer=*/true); |
207 | |
208 | if (tile.is_empty()) return false; |
209 | if (tile.ndims() < 2) return false; |
210 | |
211 | elems_per_thr = tile.elems(); |
212 | if (!math::is_pow2(elems_per_thr)) return false; |
213 | |
214 | int bytes_per_thr = elems_per_thr * type_size; |
215 | if (bytes_per_thr % hword_bytes != 0) return false; |
216 | if (bytes_per_thr < min_bytes_per_thr) return false; |
217 | if (bytes_per_thr > max_bytes_per_thr) return false; |
218 | |
219 | return true; |
220 | } |
221 | |
222 | ngen::Subregister src_ptr_; |
223 | ngen::Subregister dst_ptr_; |
224 | ngen::AddressBase src_surf_; |
225 | ngen::AddressBase dst_surf_; |
226 | ngen::Subregister elems_; |
227 | ngen::Subregister global_id_; |
228 | }; |
229 | |
230 | } // namespace jit |
231 | } // namespace gpu |
232 | } // namespace impl |
233 | } // namespace dnnl |
234 | |
235 | #endif |
236 | |