/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/gen9_binary.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// Gen9_binary requires that dst and both src tensors have the same format,
// with one exception: it also works if src0 and dst are blocked while src1 is
// plain, src1's D1 is divisible by 16, and src1 broadcasts on all dimensions
// except D0 and D1. This function checks whether src1 satisfies that exception.
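// Illustrative example (not from the library): a plain 4D src1 of shape
// {N, 32, 1, 1} with unit stride on D1 passes this check.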
bool perf_workaround(const memory_desc_t *md) {
    if (md->ndims < 2) { return false; }
    if (md->format_desc.blocking.inner_nblks != 0) { return false; }
    if (md->format_desc.blocking.strides[1] != 1) { return false; }
    if (md->dims[1] % 16 != 0) { return false; }
    for (int i = 2; i < md->ndims; i++) {
        if (md->dims[i] != 1) { return false; }
    }
    return true;
}

status_t gen9_binary_t::pd_t::init_conf(engine_t *engine) {
    const memory_desc_wrapper src0_d(src_md(0));
    const memory_desc_wrapper src1_d(src_md(1));
    const memory_desc_wrapper dst_d(dst_md());

    alg_kind_t alg = desc()->alg_kind;

    const int ndims = src0_d.ndims();
    conf.src0_md_info = memory_desc_info_t::create(src0_d);
    conf.src1_md_info = memory_desc_info_t::create(src1_d);
    conf.dst_md_info = memory_desc_info_t::create(dst_d);
    conf.attr_info = attr_info_t::create(attr());
    conf.src0_data_type = src0_d.data_type();
    conf.src1_data_type = src1_d.data_type();
    conf.dst_data_type = dst_d.data_type();
    conf.ndims = ndims;
    conf.is_add = (alg == alg_kind::binary_add);
    conf.is_mul = (alg == alg_kind::binary_mul);
    conf.is_max = (alg == alg_kind::binary_max);
    conf.is_min = (alg == alg_kind::binary_min);
    conf.is_div = (alg == alg_kind::binary_div);
    conf.is_sub = (alg == alg_kind::binary_sub);
    conf.is_ge = (alg == alg_kind::binary_ge);
    conf.is_gt = (alg == alg_kind::binary_gt);
    conf.is_le = (alg == alg_kind::binary_le);
    conf.is_lt = (alg == alg_kind::binary_lt);
    conf.is_eq = (alg == alg_kind::binary_eq);
    conf.is_ne = (alg == alg_kind::binary_ne);
    conf.is_tensor_op = is_tensor_op();
    conf.is_dense = dst_d.is_dense();
    conf.same_src_dt = (src0_d.data_type() == src1_d.data_type());
    conf.is_same_md = (src0_d == dst_d) && (src1_d == dst_d);
    conf.plain_to_ABcd4a4b = false;
    conf.isXa16b = false;
    conf.mb_block = 0;

    for (int i = 0; i < MAX_NDIMS; ++i) {
        // Kernel doesn't support src0 broadcast
        if (i < ndims && src0_d.dims()[i] == 1
                && src0_d.dims()[i] != src1_d.dims()[i]) {
            return status::unimplemented;
        }
        conf.src1_bcast_dims[i] = i < ndims ? broadcast_dims()[i] : 1;
    }

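    // Pick the per-work-item vector width: fall back to 1 when src1
    // broadcasts on D1 but not on the innermost dim; otherwise use the
    // largest power of two (up to 8) that divides the innermost dst dim.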
    if (conf.src1_bcast_dims[1] && !conf.src1_bcast_dims[ndims - 1]) {
        conf.nvect = 1;
    } else {
        conf.nvect = 8;
        while (dst_d.dims()[ndims - 1] % conf.nvect != 0) {
            conf.nvect /= 2;
        }
    }

    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    conf.dispatch = compute_engine->create_dispatch(dst_d.md_);

    using namespace dnnl::impl::format_tag;

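    // Classify the layouts: isXa16b is set when all three tensors use one of
    // the ABcd/ABcde *a16b blocked tags; is_plain_layout is set when dst
    // matches a plain nc/ncw/nchw/ncdhw tag.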
    conf.isXa16b = src0_d.matches_one_of_tag(
                           ABcd32a16b, ABcde32a16b, ABcd16a16b, ABcde16a16b)
            && dst_d.matches_one_of_tag(
                    ABcd32a16b, ABcde32a16b, ABcd16a16b, ABcde16a16b)
            && src1_d.matches_one_of_tag(
                    ABcd32a16b, ABcde32a16b, ABcd16a16b, ABcde16a16b);
    format_tag_t dst_tag = dst_d.matches_one_of_tag(nc, ncw, nchw, ncdhw);
    conf.is_plain_layout = dst_tag;
    if (!conf.is_plain_layout) {
        format_tag_t src_tag = src0_d.matches_one_of_tag(abcd, acdb);
        const auto &padded_dims = dst_d.padded_dims();
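        // Plain (abcd/acdb) sources writing into a blocked ABcd4a4b dst:
        // vectorize D3 with a sub-group whose size depends on the inner
        // b-block (8 for a 2-wide block, 16 otherwise).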
        if (src1_d.matches_tag(src_tag) && dst_d.matches_one_of_tag(ABcd4a4b)
                && src0_d.is_dense() && dst_d.is_dense(true)
                && padded_dims[3] % 16 == 0 && dst_d.data_type() != dnnl_f32) {
            dim_t blocks[MAX_NDIMS] = {1, 1, 1, 1, 1, 1};
            auto &blk = dst_d.blocking_desc();
            int b_block = blk.inner_blks[blk.inner_nblks - 1];
            int sub_group_size = (b_block == 2 ? 8 : 16);
            blocks[0] = 4;
            blocks[1] = b_block;
            int vect_dim = 3;
            conf.nvect = 8;
            for (int i = 0; i < MAX_NDIMS; ++i) {
                auto dim_str = utils::format("D%d", i);
                if (i < dst_d.ndims()) {
                    conf.dispatch.define_dim(
                            dim_str, i, padded_dims[i], blocks[i]);
                } else {
                    conf.dispatch.define_dim(dim_str, 1);
                }
            }

            auto dim_str = utils::format("D%d", vect_dim);
            CHECK(conf.dispatch.vectorize_dim(dim_str, sub_group_size));
            conf.plain_to_ABcd4a4b = true;
        } else if (conf.isXa16b) {
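            // All tensors are in an Xa16b layout: vectorize the channel dim
            // by 16 and block it further by the largest power of two (up to
            // 16) that keeps the padded channels evenly divisible.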
            conf.nvect = 8;
            int channel_blk = 16;
            const int vect_dim_size = 16;
            const int padded_channels = padded_dims[1];
            conf.mb_block = dst_d.md_->format_desc.blocking.inner_blks[0];
            while (padded_channels % (vect_dim_size * channel_blk) != 0) {
                channel_blk /= 2;
            }
            dim_t blocks[MAX_NDIMS] = {8, channel_blk, 1, 1, 1, 1};
            for (int i = 0; i < MAX_NDIMS; ++i) {
                auto dim_str = utils::format("D%d", i);
                if (i < dst_d.ndims()) {
                    conf.dispatch.define_dim(
                            dim_str, i, padded_dims[i], blocks[i]);
                    if (i == 1) {
                        CHECK(conf.dispatch.vectorize_dim(
                                dim_str, vect_dim_size));
                    }
                } else {
                    conf.dispatch.define_dim(dim_str, 1);
                }
            }
        } else {
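            // Fallback blocked path: src0 and dst must have a single 16-wide
            // inner block on D1 with D1 divisible by 16; src1 must either
            // match that format or satisfy the perf_workaround() exception.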
            auto format_fits = [](const memory_desc_t &md) {
                if (md.format_kind != dnnl_blocked) { return false; }
                auto blocking = md.format_desc.blocking;
                return blocking.inner_nblks == 1 && blocking.inner_idxs[0] == 1
                        && blocking.inner_blks[0] == 16 && md.dims[1] % 16 == 0;
            };
            if (!(format_fits(*src_md(0)) && format_fits(*dst_md())
                        && (format_fits(*src_md(1))
                                || perf_workaround(src_md(1))))) {
                return status::unimplemented;
            }
            format_tag_t src_tag
                    = src0_d.matches_one_of_tag(aBc16b, aBcd16b, aBcde16b);
            bool is16b
                    = src1_d.matches_tag(src_tag) && dst_d.matches_tag(src_tag);
            int idx = 0;
            if (!is16b) {
                idx = 1;
                // Set MB as the innermost dim for better performance; hence
                // the loop below starts at i = 1, skipping MB.
                conf.dispatch.define_dim_with_nesting_level(
                        "D0", ndims, dst_d.dims()[0], 1);
            }
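            // Define the remaining dispatch dims: D1 (channels) is vectorized
            // by 16, and the innermost dim advances by NVECT elements per
            // work-item.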
            for (int i = idx; i < MAX_NDIMS; ++i) {
                int dim = i < ndims ? dst_d.dims()[i] : 1;
                if (i == 1) {
                    conf.dispatch.define_dim(utils::format("D%d", i),
                            nstl::min(i, ndims - 1), dim, 1);
                    CHECK(conf.dispatch.vectorize_dim("D1", 16));
                } else if (i == ndims - 1) {
                    conf.dispatch.define_dim(utils::format("D%d", i),
                            nstl::min(i, ndims - 1), dim, conf.nvect);
                } else {
                    conf.dispatch.define_dim(utils::format("D%d", i),
                            nstl::min(i, ndims - 1), dim, 1);
                }
            }
        }
    } else {
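        // Plain nc/ncw/nchw/ncdhw layouts: all three tensors must share the
        // dst tag and the innermost dim must be divisible by 16. The whole
        // tensor is flattened into a single dispatch dim vectorized by 16.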
        if (!src0_d.matches_tag(dst_tag) || !src1_d.matches_tag(dst_tag)) {
            return status::unimplemented;
        }

        if (dst_md()->dims[dst_md()->ndims - 1] % 16 != 0)
            return status::unimplemented;
        conf.nvect = 16;
        while ((dst_d.dims()[ndims - 1] / 16) % conf.nvect != 0) {
            --conf.nvect;
        }

        int mixed_dim = 1;
        for (int i = 0; i < ndims; ++i) {
            mixed_dim *= dst_d.dims()[i];
        }
        conf.dispatch.define_dim("MIXED_DIM", 0, mixed_dim, conf.nvect);
        CHECK(conf.dispatch.vectorize_dim("MIXED_DIM", 16));
    }

    conf.dispatch.generate();
    return status::success;
}

status_t gen9_binary_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
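    // Forward the configuration to the OpenCL kernel as preprocessor defines.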
    kernel_ctx.set_data_type(conf.src0_data_type);
    kernel_ctx.define_int("SUB_GROUP_SIZE", 16);
    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("IS_PLAIN_LAYOUT", conf.is_plain_layout);
    kernel_ctx.define_int("PLAIN_TO_ABCD4AXB", conf.plain_to_ABcd4a4b);
    kernel_ctx.define_int("IS_XA16B", conf.isXa16b);
    kernel_ctx.define_int("IS_MUL", conf.is_mul);
    kernel_ctx.define_int("IS_ADD", conf.is_add);
    kernel_ctx.define_int("IS_MAX", conf.is_max);
    kernel_ctx.define_int("IS_MIN", conf.is_min);
    kernel_ctx.define_int("IS_DIV", conf.is_div);
    kernel_ctx.define_int("IS_SUB", conf.is_sub);
    kernel_ctx.define_int("IS_GE", conf.is_ge);
    kernel_ctx.define_int("IS_GT", conf.is_gt);
    kernel_ctx.define_int("IS_LE", conf.is_le);
    kernel_ctx.define_int("IS_LT", conf.is_lt);
    kernel_ctx.define_int("IS_EQ", conf.is_eq);
    kernel_ctx.define_int("IS_NE", conf.is_ne);
    kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
    kernel_ctx.define_int("SAME_SRC_DT", conf.same_src_dt);
    kernel_ctx.define_int("BCAST_DIM0", conf.src1_bcast_dims[0]);
    kernel_ctx.define_int("BCAST_DIM1", conf.src1_bcast_dims[1]);
    kernel_ctx.define_int("BCAST_DIM2", conf.src1_bcast_dims[2]);
    kernel_ctx.define_int("BCAST_DIM3", conf.src1_bcast_dims[3]);
    kernel_ctx.define_int("BCAST_DIM4", conf.src1_bcast_dims[4]);
    kernel_ctx.define_int("BCAST_DIM5", conf.src1_bcast_dims[5]);
    kernel_ctx.define_int(
            "BCAST_AT_INNERMOST_DIM", conf.src1_bcast_dims[conf.ndims - 1]);
    kernel_ctx.define_int("NVECT", conf.nvect);
    kernel_ctx.add_option("-Dcl_intel_subgroups_char");
    kernel_ctx.add_option("-Dcl_intel_subgroups_uchar");

    def_memory_desc_info(kernel_ctx, conf.src0_md_info, "SRC0");
    def_memory_desc_info(kernel_ctx, conf.src1_md_info, "SRC1");
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");

    def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_);

    def_dispatch(kernel_ctx, conf.dispatch);

    return status::success;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s