1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/gen9_binary.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace ocl { |
23 | |
24 | // Gen9_binary requires that dst and both src tensors have the same |
25 | // format, with one exception: it also works if src0 and dst are blocked, |
26 | // src1 is plain, src's D1 is divisible by 16 and src1 has broadcast on all |
27 | // dimensions except D0 and D1. This function checks for such circumstance. |
28 | bool perf_workaround(const memory_desc_t *md) { |
29 | if (md->ndims < 2) { return false; } |
30 | if (md->format_desc.blocking.inner_nblks != 0) { return false; } |
31 | if (md->format_desc.blocking.strides[1] != 1) { return false; } |
32 | if (md->dims[1] % 16 != 0) { return false; } |
33 | for (int i = 2; i < md->ndims; i++) { |
34 | if (md->dims[i] != 1) { return false; } |
35 | } |
36 | return true; |
37 | } |
38 | |
// Fills `conf` for the gen9 binary kernel: records the data types and layout
// info of src0/src1/dst, the algorithm kind, the src1 broadcast mask, and
// builds the OpenCL dispatch (dimension nesting + vectorization) for one of
// four supported layout paths:
//   1) plain nc/ncw/nchw/ncdhw layouts (vectorized over a flattened dim),
//   2) plain src inputs with an ABcd4a4b blocked dst,
//   3) Xa16b-style blocked layouts on all three tensors,
//   4) generic channel-16-blocked layouts (aBc16b and friends).
// Returns status::unimplemented for unsupported layout/broadcast combos.
status_t gen9_binary_t::pd_t::init_conf(engine_t *engine) {
    const memory_desc_wrapper src0_d(src_md(0));
    const memory_desc_wrapper src1_d(src_md(1));
    const memory_desc_wrapper dst_d(dst_md());

    alg_kind_t alg = desc()->alg_kind;

    const int ndims = src0_d.ndims();
    conf.src0_md_info = memory_desc_info_t::create(src0_d);
    conf.src1_md_info = memory_desc_info_t::create(src1_d);
    conf.dst_md_info = memory_desc_info_t::create(dst_d);
    conf.attr_info = attr_info_t::create(attr());
    conf.src0_data_type = src0_d.data_type();
    conf.src1_data_type = src1_d.data_type();
    conf.dst_data_type = dst_d.data_type();
    conf.ndims = ndims;
    // One boolean flag per supported binary algorithm; these become IS_* kernel
    // macros in init_kernel_ctx().
    conf.is_add = (alg == alg_kind::binary_add);
    conf.is_mul = (alg == alg_kind::binary_mul);
    conf.is_max = (alg == alg_kind::binary_max);
    conf.is_min = (alg == alg_kind::binary_min);
    conf.is_div = (alg == alg_kind::binary_div);
    conf.is_sub = (alg == alg_kind::binary_sub);
    conf.is_ge = (alg == alg_kind::binary_ge);
    conf.is_gt = (alg == alg_kind::binary_gt);
    conf.is_le = (alg == alg_kind::binary_le);
    conf.is_lt = (alg == alg_kind::binary_lt);
    conf.is_eq = (alg == alg_kind::binary_eq);
    conf.is_ne = (alg == alg_kind::binary_ne);
    conf.is_tensor_op = is_tensor_op();
    conf.is_dense = dst_d.is_dense();
    conf.same_src_dt = (src0_d.data_type() == src1_d.data_type());
    conf.is_same_md = (src0_d == dst_d) && (src1_d == dst_d);
    conf.plain_to_ABcd4a4b = false;
    conf.isXa16b = false;
    conf.mb_block = 0;

    for (int i = 0; i < MAX_NDIMS; ++i) {
        // Kernel doesn't support src0 broadcast
        if (i < ndims && src0_d.dims()[i] == 1
                && src0_d.dims()[i] != src1_d.dims()[i]) {
            return status::unimplemented;
        }
        // Dims beyond ndims are treated as broadcast (size 1).
        conf.src1_bcast_dims[i] = i < ndims ? broadcast_dims()[i] : 1;
    }

    // Default vector width: 1 when src1 broadcasts on channels but not on the
    // innermost dim; otherwise the largest power of two <= 8 that divides the
    // innermost dst dim (loop terminates at nvect == 1 since x % 1 == 0).
    if (conf.src1_bcast_dims[1] && !conf.src1_bcast_dims[ndims - 1]) {
        conf.nvect = 1;
    } else {
        conf.nvect = 8;
        while (dst_d.dims()[ndims - 1] % conf.nvect != 0) {
            conf.nvect /= 2;
        }
    }

    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    conf.dispatch = compute_engine->create_dispatch(dst_d.md_);

    using namespace dnnl::impl::format_tag;

    // Path 3 precondition: all three tensors use an Xa16b-style blocked tag.
    conf.isXa16b = src0_d.matches_one_of_tag(
                           ABcd32a16b, ABcde32a16b, ABcd16a16b, ABcde16a16b)
            && dst_d.matches_one_of_tag(
                    ABcd32a16b, ABcde32a16b, ABcd16a16b, ABcde16a16b)
            && src1_d.matches_one_of_tag(
                    ABcd32a16b, ABcde32a16b, ABcd16a16b, ABcde16a16b);
    // Non-undef tag here (used as a truthy value) selects the plain path.
    format_tag_t dst_tag = dst_d.matches_one_of_tag(nc, ncw, nchw, ncdhw);
    conf.is_plain_layout = dst_tag;
    if (!conf.is_plain_layout) {
        format_tag_t src_tag = src0_d.matches_one_of_tag(abcd, acdb);
        const auto &padded_dims = dst_d.padded_dims();
        // Path 2: plain srcs, ABcd4a4b dst, dense, 16-aligned W, non-f32 dst.
        if (src1_d.matches_tag(src_tag) && dst_d.matches_one_of_tag(ABcd4a4b)
                && src0_d.is_dense() && dst_d.is_dense(true)
                && padded_dims[3] % 16 == 0 && dst_d.data_type() != dnnl_f32) {
            dim_t blocks[MAX_NDIMS] = {1, 1, 1, 1, 1, 1};
            auto &blk = dst_d.blocking_desc();
            // Innermost block of dst drives the sub-group width: a 2-wide
            // inner block uses SIMD8, otherwise SIMD16.
            int b_block = blk.inner_blks[blk.inner_nblks - 1];
            int sub_group_size = (b_block == 2 ? 8 : 16);
            blocks[0] = 4;
            blocks[1] = b_block;
            int vect_dim = 3;
            conf.nvect = 8;
            for (int i = 0; i < MAX_NDIMS; ++i) {
                auto dim_str = utils::format("D%d" , i);
                if (i < dst_d.ndims()) {
                    conf.dispatch.define_dim(
                            dim_str, i, padded_dims[i], blocks[i]);
                } else {
                    // Dims past ndims collapse to size 1.
                    conf.dispatch.define_dim(dim_str, 1);
                }
            }

            // Vectorize over D3 (W) with the chosen sub-group size.
            auto dim_str = utils::format("D%d" , vect_dim);
            CHECK(conf.dispatch.vectorize_dim(dim_str, sub_group_size));
            conf.plain_to_ABcd4a4b = true;
        } else if (conf.isXa16b) {
            // Path 3: all tensors Xa16b-blocked.
            conf.nvect = 8;
            int channel_blk = 16;
            const int vect_dim_size = 16;
            const int padded_channels = padded_dims[1];
            // MB block size (16 or 32) comes from dst's outer inner-block.
            conf.mb_block = dst_d.md_->format_desc.blocking.inner_blks[0];
            // Shrink the channel block until (16 * channel_blk) divides the
            // padded channel count.
            while (padded_channels % (vect_dim_size * channel_blk) != 0) {
                channel_blk /= 2;
            }
            dim_t blocks[MAX_NDIMS] = {8, channel_blk, 1, 1, 1, 1};
            for (int i = 0; i < MAX_NDIMS; ++i) {
                auto dim_str = utils::format("D%d" , i);
                if (i < dst_d.ndims()) {
                    conf.dispatch.define_dim(
                            dim_str, i, padded_dims[i], blocks[i]);
                    if (i == 1) {
                        // Channels are the vectorized (sub-group) dimension.
                        CHECK(conf.dispatch.vectorize_dim(
                                dim_str, vect_dim_size));
                    }
                } else {
                    conf.dispatch.define_dim(dim_str, 1);
                }
            }
        } else {
            // Path 4: generic channel-16-blocked layouts. Accept only tensors
            // whose single inner block is 16 on the channel dim (or src1
            // matching the documented perf workaround).
            auto format_fits = [](const memory_desc_t &md) {
                if (md.format_kind != dnnl_blocked) { return false; }
                auto blocking = md.format_desc.blocking;
                return blocking.inner_nblks == 1 && blocking.inner_idxs[0] == 1
                        && blocking.inner_blks[0] == 16 && md.dims[1] % 16 == 0;
            };
            if (!(format_fits(*src_md(0)) && format_fits(*dst_md())
                        && (format_fits(*src_md(1))
                                || perf_workaround(src_md(1))))) {
                return status::unimplemented;
            }
            format_tag_t src_tag
                    = src0_d.matches_one_of_tag(aBc16b, aBcd16b, aBcde16b);
            // is16b: all three tensors share the same 16b channel-blocked tag.
            bool is16b
                    = src1_d.matches_tag(src_tag) && dst_d.matches_tag(src_tag);
            int idx = 0;
            if (!is16b) {
                idx = 1;
                // Setting the MB as the innermost dim for optimized performance
                // Hence starting i = 1, ignoring MB
                conf.dispatch.define_dim_with_nesting_level(
                        "D0" , ndims, dst_d.dims()[0], 1);
            }
            for (int i = idx; i < MAX_NDIMS; ++i) {
                int dim = i < ndims ? dst_d.dims()[i] : 1;
                if (i == 1) {
                    // Channel dim: sub-group-vectorized by 16.
                    conf.dispatch.define_dim(utils::format("D%d" , i),
                            nstl::min(i, ndims - 1), dim, 1);
                    CHECK(conf.dispatch.vectorize_dim("D1" , 16));
                } else if (i == ndims - 1) {
                    // Innermost dim: blocked by the nvect chosen above.
                    conf.dispatch.define_dim(utils::format("D%d" , i),
                            nstl::min(i, ndims - 1), dim, conf.nvect);
                } else {
                    conf.dispatch.define_dim(utils::format("D%d" , i),
                            nstl::min(i, ndims - 1), dim, 1);
                }
            }
        }
    } else {
        // Path 1: plain layouts. All three tensors must share the plain tag.
        if (!src0_d.matches_tag(dst_tag) || !src1_d.matches_tag(dst_tag)) {
            return status::unimplemented;
        }

        // Innermost dim must be 16-aligned for sub-group vectorization.
        if (dst_md()->dims[dst_md()->ndims - 1] % 16 != 0)
            return status::unimplemented;
        // Largest nvect <= 16 that evenly divides the innermost dim in units
        // of 16 (terminates at nvect == 1).
        conf.nvect = 16;
        while ((dst_d.dims()[ndims - 1] / 16) % conf.nvect != 0) {
            --conf.nvect;
        }

        // Flatten all dims into a single dispatch dimension.
        int mixed_dim = 1;
        for (int i = 0; i < ndims; ++i) {
            mixed_dim *= dst_d.dims()[i];
        }
        conf.dispatch.define_dim("MIXED_DIM" , 0, mixed_dim, conf.nvect);
        CHECK(conf.dispatch.vectorize_dim("MIXED_DIM" , 16));
    }

    conf.dispatch.generate();
    return status::success;
}
218 | |
// Translates the `conf` built by init_conf() into OpenCL kernel-compilation
// context: one preprocessor macro per flag/parameter, the memory descriptor
// macros for src0/src1/dst, post-op attribute macros, and the dispatch
// (global/local range) definitions.
status_t gen9_binary_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    // Kernel computes in src0's data type.
    kernel_ctx.set_data_type(conf.src0_data_type);
    kernel_ctx.define_int("SUB_GROUP_SIZE" , 16);
    kernel_ctx.define_int("NDIMS" , conf.ndims);
    // Layout-path selectors (see init_conf() for how these are chosen).
    kernel_ctx.define_int("IS_PLAIN_LAYOUT" , conf.is_plain_layout);
    kernel_ctx.define_int("PLAIN_TO_ABCD4AXB" , conf.plain_to_ABcd4a4b);
    kernel_ctx.define_int("IS_XA16B" , conf.isXa16b);
    // Exactly one IS_* algorithm macro is non-zero.
    kernel_ctx.define_int("IS_MUL" , conf.is_mul);
    kernel_ctx.define_int("IS_ADD" , conf.is_add);
    kernel_ctx.define_int("IS_MAX" , conf.is_max);
    kernel_ctx.define_int("IS_MIN" , conf.is_min);
    kernel_ctx.define_int("IS_DIV" , conf.is_div);
    kernel_ctx.define_int("IS_SUB" , conf.is_sub);
    kernel_ctx.define_int("IS_GE" , conf.is_ge);
    kernel_ctx.define_int("IS_GT" , conf.is_gt);
    kernel_ctx.define_int("IS_LE" , conf.is_le);
    kernel_ctx.define_int("IS_LT" , conf.is_lt);
    kernel_ctx.define_int("IS_EQ" , conf.is_eq);
    kernel_ctx.define_int("IS_NE" , conf.is_ne);
    kernel_ctx.define_int("MB_BLOCK" , conf.mb_block);
    kernel_ctx.define_int("SAME_SRC_DT" , conf.same_src_dt);
    // Per-dimension src1 broadcast mask (1 = src1 broadcasts on that dim).
    kernel_ctx.define_int("BCAST_DIM0" , conf.src1_bcast_dims[0]);
    kernel_ctx.define_int("BCAST_DIM1" , conf.src1_bcast_dims[1]);
    kernel_ctx.define_int("BCAST_DIM2" , conf.src1_bcast_dims[2]);
    kernel_ctx.define_int("BCAST_DIM3" , conf.src1_bcast_dims[3]);
    kernel_ctx.define_int("BCAST_DIM4" , conf.src1_bcast_dims[4]);
    kernel_ctx.define_int("BCAST_DIM5" , conf.src1_bcast_dims[5]);
    kernel_ctx.define_int(
            "BCAST_AT_INNERMOST_DIM" , conf.src1_bcast_dims[conf.ndims - 1]);
    kernel_ctx.define_int("NVECT" , conf.nvect);
    // Enable char/uchar sub-group extensions used by the kernel.
    kernel_ctx.add_option("-Dcl_intel_subgroups_char" );
    kernel_ctx.add_option("-Dcl_intel_subgroups_uchar" );

    def_memory_desc_info(kernel_ctx, conf.src0_md_info, "SRC0" );
    def_memory_desc_info(kernel_ctx, conf.src1_md_info, "SRC1" );
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST" );

    def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_);

    def_dispatch(kernel_ctx, conf.dispatch);

    return status::success;
}
263 | |
264 | } // namespace ocl |
265 | } // namespace gpu |
266 | } // namespace impl |
267 | } // namespace dnnl |
268 | |
269 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
270 | |