1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_REORDER_REORDER_KERNEL_HPP
18#define GPU_JIT_REORDER_REORDER_KERNEL_HPP
19
20#include "gpu/jit/codegen/codegen.hpp"
21#include "gpu/jit/codegen/kernel.hpp"
22#include "gpu/jit/codegen/ngen_helpers.hpp"
23#include "gpu/jit/codegen/register_scope.hpp"
24#include "gpu/jit/ir/ir_builder.hpp"
25#include "gpu/jit/ir/message.hpp"
26#include "gpu/jit/ir/reorder.hpp"
27#include "gpu/jit/ir/tensor.hpp"
28#include "gpu/jit/reorder/ir_builder.hpp"
29#include "gpu/jit/utils/ngen_type_bridge.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace gpu {
34namespace jit {
35
36template <ngen::HW hw = ngen::HW::Unknown>
37class reorder_kernel_t : public ir_kernel_t<hw> {
38public:
39 IR_KERNEL_FORWARD(hw)
40
41 reorder_kernel_t(const exec_config_t &exec_cfg,
42 const std::string &kernel_name, const kernel_info_t &kernel_info,
43 const layout_t &src_layout, const layout_t &dst_layout,
44 bool require_dpas, grf_mode_t grf_mode)
45 : ir_kernel_t<hw>(
46 kernel_name, exec_cfg, kernel_info, require_dpas, grf_mode) {
47
48 if (reorder_kernel_t<>::is_ir_based_reorder(src_layout, dst_layout)) {
49 reorder_ir_builder_t builder(
50 exec_cfg, kernel_info, src_layout, dst_layout);
51 stmt_t body = builder.stmt();
52 setup_interface(body);
53 generate_prologue();
54 expr_binding_t expr_binding(hw);
55 bind_external_vars(body, builder.kernel_grid(), builder.local_id(),
56 expr_binding);
57
58 // Generate assembly from IR.
59 ir_to_ngen_t<hw> visitor(this, expr_binding);
60 visitor.visit(body);
61
62 generate_epilogue();
63 return;
64 }
65
66 // Handle specific reorder versions.
67 setup_interface();
68 generate_prologue();
69
70 std::vector<std::string> arg_names(kernel_info.nargs());
71 for (int i = 0; i < kernel_info.nargs(); i++) {
72 arg_names[i] = kernel_info.arg_name(i);
73 }
74 src_ptr_ = getArgument(arg_names[0]);
75 dst_ptr_ = getArgument(arg_names[1]);
76 src_surf_ = Surface(getArgumentSurface(arg_names[0]));
77 dst_surf_ = Surface(getArgumentSurface(arg_names[1]));
78 elems_ = getArgument(arg_names[2]);
79
80 global_id_ = ra_.template alloc_sub<uint32_t>();
81
82 mul(1, global_id_, r0.ud(1), getLocalSize(0).uw());
83 add(1, global_id_, global_id_, getLocalID(0));
84
85 int elems_per_thr;
86 if (is_2d_reorder(src_layout, dst_layout, elems_per_thr)) {
87 emit_2d_reorder(src_layout, dst_layout, elems_per_thr);
88 } else {
89 ir_error_not_expected();
90 }
91
92 generate_epilogue();
93 }
94
95 void emit_2d_reorder(
96 const layout_t &_src, const layout_t &_dst, int elems_per_thr) {
97 auto tile = _src.split_into_max_tile(elems_per_thr, /*is_dense=*/true);
98 ir_assert(!tile.is_empty()) << "Can't split " << _src;
99
100 int simd_size = getSIMD();
101 auto src = _src.map(tile);
102 auto dst = _dst.map(tile);
103
104 int src_size = src.type().size();
105 int dst_size = dst.type().size();
106 int src_tile_bytes = src_size * elems_per_thr;
107 int dst_tile_bytes = dst_size * elems_per_thr;
108
109 int grf_size = ngen::GRF::bytes(hw);
110
111 auto S = ra_.alloc_range(utils::div_up(src_tile_bytes, grf_size));
112 auto D = ra_.alloc_range(utils::div_up(dst_tile_bytes, grf_size));
113
114 auto src_header = ra_.alloc();
115 auto dst_header = ra_.alloc();
116
117 // Prepare headers for loads and stores.
118 eshl(1, src_header.uq(0), global_id_,
119 math::ilog2q(src_tile_bytes / simd_size));
120 eshl(1, dst_header.uq(0), global_id_,
121 math::ilog2q(dst_tile_bytes / simd_size));
122 eadd(1, src_header.uq(0), src_header.uq(0), src_ptr_);
123 eadd(1, dst_header.uq(0), dst_header.uq(0), dst_ptr_);
124
125 int oword_bytes = 16;
126
127 // Load source tile.
128 int src_off = 0;
129 while (src_tile_bytes > 0) {
130 for (int i = 3; i >= 0; i--) {
131 int size = (1 << i) * oword_bytes;
132 if (src_tile_bytes >= size) {
133 load(16, S[src_off / grf_size],
134 ngen::block_oword(size / oword_bytes), A64,
135 src_header);
136 eadd(1, src_header.uq(0), src_header.uq(0), size);
137 src_tile_bytes -= size;
138 src_off += size;
139 break;
140 }
141 }
142 }
143
144 // Reorder source tile to destination tile.
145 ngen_register_scope_t scope(ra_);
146 reorder_2d_impl_t r(hw, src, dst);
147 reg_buf_t S_buf(hw, S);
148 reg_buf_t D_buf(hw, D);
149 r.emit(this, scope, S_buf, D_buf);
150
151 // Store destination tile.
152 int dst_off = 0;
153 while (dst_tile_bytes > 0) {
154 for (int i = 3; i >= 0; i--) {
155 int size = (1 << i) * oword_bytes;
156 if (dst_tile_bytes >= size) {
157 store(16, ngen::block_oword(size / oword_bytes), A64,
158 dst_header, D[dst_off / grf_size]);
159 eadd(1, dst_header.uq(0), dst_header.uq(0), size);
160 dst_tile_bytes -= size;
161 dst_off += size;
162 break;
163 }
164 }
165 }
166 }
167
168 static bool is_ir_based_reorder(const layout_t &src, const layout_t &dst) {
169 int dummy;
170 if (is_2d_reorder(src, dst, dummy)) return false;
171 return true;
172 }
173
174 static compute::nd_range_t nd_range(const exec_config_t &exec_cfg,
175 const layout_t &src, const layout_t &dst) {
176 const int simd = exec_cfg.simd();
177
178 int elems_per_thr;
179 if (is_2d_reorder(src, dst, elems_per_thr)) {
180 ir_assert(src.elems() == dst.elems());
181 return compute::nd_range_t(
182 {(int)utils::div_up(src.elems(), elems_per_thr) * simd, 1,
183 1});
184 }
185
186 // Handle IR-based reorder.
187 ir_assert(reorder_kernel_t<>::is_ir_based_reorder(src, dst));
188
189 return reorder_ir_builder_t::nd_range(exec_cfg, src, dst);
190 }
191
192private:
193 static bool is_2d_reorder(
194 const layout_t &src, const layout_t &dst, int &elems_per_thr) {
195 if (!src.type().is_bitwise_compatible(dst.type())) return false;
196 if (src.is_equal(dst)) return false;
197
198 const int hword_bytes = 32;
199 const int min_bytes_per_thr = hword_bytes;
200 const int max_bytes_per_thr = 32 * hword_bytes;
201
202 int type_size = src.type().size();
203 int max_elems_per_thr = max_bytes_per_thr / type_size;
204
205 auto tile = reorder_2d_impl_t::find_2d_tile(
206 src, dst, max_elems_per_thr, /*match_outer=*/true);
207
208 if (tile.is_empty()) return false;
209 if (tile.ndims() < 2) return false;
210
211 elems_per_thr = tile.elems();
212 if (!math::is_pow2(elems_per_thr)) return false;
213
214 int bytes_per_thr = elems_per_thr * type_size;
215 if (bytes_per_thr % hword_bytes != 0) return false;
216 if (bytes_per_thr < min_bytes_per_thr) return false;
217 if (bytes_per_thr > max_bytes_per_thr) return false;
218
219 return true;
220 }
221
222 ngen::Subregister src_ptr_;
223 ngen::Subregister dst_ptr_;
224 ngen::AddressBase src_surf_;
225 ngen::AddressBase dst_surf_;
226 ngen::Subregister elems_;
227 ngen::Subregister global_id_;
228};
229
230} // namespace jit
231} // namespace gpu
232} // namespace impl
233} // namespace dnnl
234
235#endif
236