1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/ref_binary.hpp"
18
19namespace dnnl {
20namespace impl {
21namespace gpu {
22namespace ocl {
23
24status_t ref_binary_t::pd_t::init_conf(engine_t *engine) {
25 const memory_desc_wrapper src0_d(src_md(0));
26 const memory_desc_wrapper src1_d(src_md(1));
27 const memory_desc_wrapper dst_d(dst_md());
28
29 alg_kind_t alg = desc()->alg_kind;
30
31 const int ndims = src0_d.ndims();
32 conf.src0_md_info = memory_desc_info_t::create(src0_d);
33 conf.src1_md_info = memory_desc_info_t::create(src1_d);
34 conf.dst_md_info = memory_desc_info_t::create(dst_d);
35 conf.src0_data_type = src0_d.data_type();
36 conf.src1_data_type = src1_d.data_type();
37 conf.dst_data_type = dst_d.data_type();
38 conf.ndims = ndims;
39 bool is_src0_bcasted = false;
40 for (int i = 0; i < MAX_NDIMS; ++i) {
41 conf.src0_bcast_dims[i] = i < ndims
42 ? src0_d.dims()[i] == 1 && src0_d.dims()[i] != src1_d.dims()[i]
43 : 0;
44 is_src0_bcasted = is_src0_bcasted || conf.src0_bcast_dims[i];
45 conf.src1_bcast_dims[i] = i < ndims
46 ? src1_d.dims()[i] == 1 && src0_d.dims()[i] != src1_d.dims()[i]
47 : 0;
48 }
49 conf.is_add = (alg == alg_kind::binary_add);
50 conf.is_mul = (alg == alg_kind::binary_mul);
51 conf.is_max = (alg == alg_kind::binary_max);
52 conf.is_min = (alg == alg_kind::binary_min);
53 conf.is_div = (alg == alg_kind::binary_div);
54 conf.is_sub = (alg == alg_kind::binary_sub);
55 conf.is_ge = (alg == alg_kind::binary_ge);
56 conf.is_gt = (alg == alg_kind::binary_gt);
57 conf.is_le = (alg == alg_kind::binary_le);
58 conf.is_lt = (alg == alg_kind::binary_lt);
59 conf.is_eq = (alg == alg_kind::binary_eq);
60 conf.is_ne = (alg == alg_kind::binary_ne);
61 conf.is_tensor_op = is_tensor_op();
62 conf.is_dense = dst_d.is_dense();
63 conf.same_src_dt = (src0_d.data_type() == src1_d.data_type());
64 conf.is_same_md = (src0_d == dst_d) && (src1_d == dst_d);
65 conf.attr_info = attr_info_t::create(attr());
66 conf.with_binary_post_op
67 = attr()->post_ops_.find(primitive_kind::binary) != -1;
68 int ic_block_sz = 1;
69 conf.use_unroll_16b = false;
70 conf.src0_unroll_16b = false;
71
72 auto &blk0 = src0_d.blocking_desc();
73 auto &blk1 = src1_d.blocking_desc();
74 auto &blkd = dst_d.blocking_desc();
75 bool is_16b_blk0 = (blk0.inner_nblks >= 1)
76 && (blk0.inner_idxs[blk0.inner_nblks - 1] == 1)
77 && (blk0.inner_blks[blk0.inner_nblks - 1] == 16);
78 bool is_16b_blk1 = (blk1.inner_nblks >= 1)
79 && (blk1.inner_idxs[blk1.inner_nblks - 1] == 1)
80 && (blk1.inner_blks[blk1.inner_nblks - 1] == 16);
81 bool is_16b_blkd = (blkd.inner_nblks >= 1)
82 && (blkd.inner_idxs[blkd.inner_nblks - 1] == 1)
83 && (blkd.inner_blks[blkd.inner_nblks - 1] == 16);
84
85 if (is_16b_blkd && !conf.is_tensor_op && !is_src0_bcasted) {
86 // If: in case when both are blocked
87 // Else: only src0 is blocked
88 if (is_16b_blk0 && is_16b_blk1) {
89 ic_block_sz = 16;
90 conf.use_unroll_16b = true;
91 } else if (is_16b_blk0 && blk1.inner_nblks == 0) {
92 ic_block_sz = 16;
93 conf.src0_unroll_16b = true;
94 }
95 }
96
97 auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
98 conf.dispatch = compute_engine->create_dispatch(dst_d.md_);
99 if (conf.is_tensor_op && conf.is_dense && conf.is_same_md
100 && !conf.with_binary_post_op) {
101 conf.dispatch.define_dim("IDX", 0, dst_d.nelems());
102 } else {
103 for (int i = 0; i < MAX_NDIMS; ++i) {
104 if (i == 1 && (conf.use_unroll_16b || conf.src0_unroll_16b)) {
105 // changing value for broadcasting offsets
106 // division by IC for enabling blocking within kernel
107 conf.dispatch.define_dim(utils::format("D%d", i),
108 nstl::min(i, ndims - 1),
109 i < ndims ? dst_d.padded_dims()[i] : 1, ic_block_sz);
110 } else {
111 conf.dispatch.define_dim(utils::format("D%d", i),
112 nstl::min(i, ndims - 1),
113 i < ndims ? dst_d.padded_dims()[i] : 1);
114 }
115 }
116 }
117 conf.dispatch.generate();
118 return status::success;
119}
120
121status_t ref_binary_t::pd_t::init_kernel_ctx(
122 compute::kernel_ctx_t &kernel_ctx) const {
123 kernel_ctx.set_data_type(conf.src0_data_type);
124 kernel_ctx.set_data_type(conf.src1_data_type);
125 kernel_ctx.set_data_type(conf.dst_data_type);
126 kernel_ctx.define_int("NDIMS", conf.ndims);
127 kernel_ctx.define_int("IS_MUL", conf.is_mul);
128 kernel_ctx.define_int("IS_ADD", conf.is_add);
129 kernel_ctx.define_int("IS_MAX", conf.is_max);
130 kernel_ctx.define_int("IS_MIN", conf.is_min);
131 kernel_ctx.define_int("IS_DIV", conf.is_div);
132 kernel_ctx.define_int("IS_SUB", conf.is_sub);
133 kernel_ctx.define_int("IS_GE", conf.is_ge);
134 kernel_ctx.define_int("IS_GT", conf.is_gt);
135 kernel_ctx.define_int("IS_LE", conf.is_le);
136 kernel_ctx.define_int("IS_LT", conf.is_lt);
137 kernel_ctx.define_int("IS_EQ", conf.is_eq);
138 kernel_ctx.define_int("IS_NE", conf.is_ne);
139 kernel_ctx.define_int("IS_TENSOR_OP", conf.is_tensor_op);
140 kernel_ctx.define_int("IS_DENSE", conf.is_dense);
141 kernel_ctx.define_int("IS_SAME_MD", conf.is_same_md);
142 kernel_ctx.define_int("WITH_BINARY_POST_OP", conf.with_binary_post_op);
143 kernel_ctx.define_int("SAME_SRC_DT", conf.same_src_dt);
144
145 kernel_ctx.define_int("SRC0_BCAST_DIM0", conf.src0_bcast_dims[0]);
146 kernel_ctx.define_int("SRC0_BCAST_DIM1", conf.src0_bcast_dims[1]);
147 kernel_ctx.define_int("SRC0_BCAST_DIM2", conf.src0_bcast_dims[2]);
148 kernel_ctx.define_int("SRC0_BCAST_DIM3", conf.src0_bcast_dims[3]);
149 kernel_ctx.define_int("SRC0_BCAST_DIM4", conf.src0_bcast_dims[4]);
150 kernel_ctx.define_int("SRC0_BCAST_DIM5", conf.src0_bcast_dims[5]);
151
152 kernel_ctx.define_int("SRC1_BCAST_DIM0", conf.src1_bcast_dims[0]);
153 kernel_ctx.define_int("SRC1_BCAST_DIM1", conf.src1_bcast_dims[1]);
154 kernel_ctx.define_int("SRC1_BCAST_DIM2", conf.src1_bcast_dims[2]);
155 kernel_ctx.define_int("SRC1_BCAST_DIM3", conf.src1_bcast_dims[3]);
156 kernel_ctx.define_int("SRC1_BCAST_DIM4", conf.src1_bcast_dims[4]);
157 kernel_ctx.define_int("SRC1_BCAST_DIM5", conf.src1_bcast_dims[5]);
158
159 kernel_ctx.define_int("USE_UNROLL_16B", conf.use_unroll_16b);
160 kernel_ctx.define_int("SRC0_UNROLL_16B", conf.src0_unroll_16b);
161 kernel_ctx.define_int("SUB_GROUP_SIZE", 1);
162
163 def_memory_desc_info(kernel_ctx, conf.src0_md_info, "SRC0");
164 def_memory_desc_info(kernel_ctx, conf.src1_md_info, "SRC1");
165 def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");
166
167 def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_);
168
169 def_dispatch(kernel_ctx, conf.dispatch);
170
171 return status::success;
172}
173
174status_t ref_binary_t::execute_ref(const exec_ctx_t &ctx) const {
175
176 status_t status = status::success;
177
178 auto &src0 = CTX_IN_STORAGE(DNNL_ARG_SRC_0);
179 auto &src1 = CTX_IN_STORAGE(DNNL_ARG_SRC_1);
180 auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
181 CHECK(status);
182
183 const auto &conf = pd()->conf;
184
185 auto &src0_scale = CTX_IN_STORAGE(DNNL_ARG_SRC_0 | DNNL_ARG_ATTR_SCALES);
186
187 auto &src1_scale = CTX_IN_STORAGE(DNNL_ARG_SRC_1 | DNNL_ARG_ATTR_SCALES);
188
189 compute::kernel_arg_list_t arg_list;
190 arg_list.set(0, src0);
191 arg_list.set(1, src1);
192 arg_list.set(2, dst);
193
194 unsigned arg_idx = append_post_ops_to_arg_list(
195 ctx, arg_list, 3, pd()->attr()->post_ops_);
196
197 arg_list.set(arg_idx++, src0_scale);
198 arg_list.set(arg_idx, src1_scale);
199
200 auto nd_range = conf.dispatch.nd_range();
201
202 status = parallel_for(ctx, nd_range, kernel_, arg_list);
203 return status;
204}
205
206} // namespace ocl
207} // namespace gpu
208} // namespace impl
209} // namespace dnnl
210
211// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
212