1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/ref_binary.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace ocl { |
23 | |
24 | status_t ref_binary_t::pd_t::init_conf(engine_t *engine) { |
25 | const memory_desc_wrapper src0_d(src_md(0)); |
26 | const memory_desc_wrapper src1_d(src_md(1)); |
27 | const memory_desc_wrapper dst_d(dst_md()); |
28 | |
29 | alg_kind_t alg = desc()->alg_kind; |
30 | |
31 | const int ndims = src0_d.ndims(); |
32 | conf.src0_md_info = memory_desc_info_t::create(src0_d); |
33 | conf.src1_md_info = memory_desc_info_t::create(src1_d); |
34 | conf.dst_md_info = memory_desc_info_t::create(dst_d); |
35 | conf.src0_data_type = src0_d.data_type(); |
36 | conf.src1_data_type = src1_d.data_type(); |
37 | conf.dst_data_type = dst_d.data_type(); |
38 | conf.ndims = ndims; |
39 | bool is_src0_bcasted = false; |
40 | for (int i = 0; i < MAX_NDIMS; ++i) { |
41 | conf.src0_bcast_dims[i] = i < ndims |
42 | ? src0_d.dims()[i] == 1 && src0_d.dims()[i] != src1_d.dims()[i] |
43 | : 0; |
44 | is_src0_bcasted = is_src0_bcasted || conf.src0_bcast_dims[i]; |
45 | conf.src1_bcast_dims[i] = i < ndims |
46 | ? src1_d.dims()[i] == 1 && src0_d.dims()[i] != src1_d.dims()[i] |
47 | : 0; |
48 | } |
49 | conf.is_add = (alg == alg_kind::binary_add); |
50 | conf.is_mul = (alg == alg_kind::binary_mul); |
51 | conf.is_max = (alg == alg_kind::binary_max); |
52 | conf.is_min = (alg == alg_kind::binary_min); |
53 | conf.is_div = (alg == alg_kind::binary_div); |
54 | conf.is_sub = (alg == alg_kind::binary_sub); |
55 | conf.is_ge = (alg == alg_kind::binary_ge); |
56 | conf.is_gt = (alg == alg_kind::binary_gt); |
57 | conf.is_le = (alg == alg_kind::binary_le); |
58 | conf.is_lt = (alg == alg_kind::binary_lt); |
59 | conf.is_eq = (alg == alg_kind::binary_eq); |
60 | conf.is_ne = (alg == alg_kind::binary_ne); |
61 | conf.is_tensor_op = is_tensor_op(); |
62 | conf.is_dense = dst_d.is_dense(); |
63 | conf.same_src_dt = (src0_d.data_type() == src1_d.data_type()); |
64 | conf.is_same_md = (src0_d == dst_d) && (src1_d == dst_d); |
65 | conf.attr_info = attr_info_t::create(attr()); |
66 | conf.with_binary_post_op |
67 | = attr()->post_ops_.find(primitive_kind::binary) != -1; |
68 | int ic_block_sz = 1; |
69 | conf.use_unroll_16b = false; |
70 | conf.src0_unroll_16b = false; |
71 | |
72 | auto &blk0 = src0_d.blocking_desc(); |
73 | auto &blk1 = src1_d.blocking_desc(); |
74 | auto &blkd = dst_d.blocking_desc(); |
75 | bool is_16b_blk0 = (blk0.inner_nblks >= 1) |
76 | && (blk0.inner_idxs[blk0.inner_nblks - 1] == 1) |
77 | && (blk0.inner_blks[blk0.inner_nblks - 1] == 16); |
78 | bool is_16b_blk1 = (blk1.inner_nblks >= 1) |
79 | && (blk1.inner_idxs[blk1.inner_nblks - 1] == 1) |
80 | && (blk1.inner_blks[blk1.inner_nblks - 1] == 16); |
81 | bool is_16b_blkd = (blkd.inner_nblks >= 1) |
82 | && (blkd.inner_idxs[blkd.inner_nblks - 1] == 1) |
83 | && (blkd.inner_blks[blkd.inner_nblks - 1] == 16); |
84 | |
85 | if (is_16b_blkd && !conf.is_tensor_op && !is_src0_bcasted) { |
86 | // If: in case when both are blocked |
87 | // Else: only src0 is blocked |
88 | if (is_16b_blk0 && is_16b_blk1) { |
89 | ic_block_sz = 16; |
90 | conf.use_unroll_16b = true; |
91 | } else if (is_16b_blk0 && blk1.inner_nblks == 0) { |
92 | ic_block_sz = 16; |
93 | conf.src0_unroll_16b = true; |
94 | } |
95 | } |
96 | |
97 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
98 | conf.dispatch = compute_engine->create_dispatch(dst_d.md_); |
99 | if (conf.is_tensor_op && conf.is_dense && conf.is_same_md |
100 | && !conf.with_binary_post_op) { |
101 | conf.dispatch.define_dim("IDX" , 0, dst_d.nelems()); |
102 | } else { |
103 | for (int i = 0; i < MAX_NDIMS; ++i) { |
104 | if (i == 1 && (conf.use_unroll_16b || conf.src0_unroll_16b)) { |
105 | // changing value for broadcasting offsets |
106 | // division by IC for enabling blocking within kernel |
107 | conf.dispatch.define_dim(utils::format("D%d" , i), |
108 | nstl::min(i, ndims - 1), |
109 | i < ndims ? dst_d.padded_dims()[i] : 1, ic_block_sz); |
110 | } else { |
111 | conf.dispatch.define_dim(utils::format("D%d" , i), |
112 | nstl::min(i, ndims - 1), |
113 | i < ndims ? dst_d.padded_dims()[i] : 1); |
114 | } |
115 | } |
116 | } |
117 | conf.dispatch.generate(); |
118 | return status::success; |
119 | } |
120 | |
121 | status_t ref_binary_t::pd_t::init_kernel_ctx( |
122 | compute::kernel_ctx_t &kernel_ctx) const { |
123 | kernel_ctx.set_data_type(conf.src0_data_type); |
124 | kernel_ctx.set_data_type(conf.src1_data_type); |
125 | kernel_ctx.set_data_type(conf.dst_data_type); |
126 | kernel_ctx.define_int("NDIMS" , conf.ndims); |
127 | kernel_ctx.define_int("IS_MUL" , conf.is_mul); |
128 | kernel_ctx.define_int("IS_ADD" , conf.is_add); |
129 | kernel_ctx.define_int("IS_MAX" , conf.is_max); |
130 | kernel_ctx.define_int("IS_MIN" , conf.is_min); |
131 | kernel_ctx.define_int("IS_DIV" , conf.is_div); |
132 | kernel_ctx.define_int("IS_SUB" , conf.is_sub); |
133 | kernel_ctx.define_int("IS_GE" , conf.is_ge); |
134 | kernel_ctx.define_int("IS_GT" , conf.is_gt); |
135 | kernel_ctx.define_int("IS_LE" , conf.is_le); |
136 | kernel_ctx.define_int("IS_LT" , conf.is_lt); |
137 | kernel_ctx.define_int("IS_EQ" , conf.is_eq); |
138 | kernel_ctx.define_int("IS_NE" , conf.is_ne); |
139 | kernel_ctx.define_int("IS_TENSOR_OP" , conf.is_tensor_op); |
140 | kernel_ctx.define_int("IS_DENSE" , conf.is_dense); |
141 | kernel_ctx.define_int("IS_SAME_MD" , conf.is_same_md); |
142 | kernel_ctx.define_int("WITH_BINARY_POST_OP" , conf.with_binary_post_op); |
143 | kernel_ctx.define_int("SAME_SRC_DT" , conf.same_src_dt); |
144 | |
145 | kernel_ctx.define_int("SRC0_BCAST_DIM0" , conf.src0_bcast_dims[0]); |
146 | kernel_ctx.define_int("SRC0_BCAST_DIM1" , conf.src0_bcast_dims[1]); |
147 | kernel_ctx.define_int("SRC0_BCAST_DIM2" , conf.src0_bcast_dims[2]); |
148 | kernel_ctx.define_int("SRC0_BCAST_DIM3" , conf.src0_bcast_dims[3]); |
149 | kernel_ctx.define_int("SRC0_BCAST_DIM4" , conf.src0_bcast_dims[4]); |
150 | kernel_ctx.define_int("SRC0_BCAST_DIM5" , conf.src0_bcast_dims[5]); |
151 | |
152 | kernel_ctx.define_int("SRC1_BCAST_DIM0" , conf.src1_bcast_dims[0]); |
153 | kernel_ctx.define_int("SRC1_BCAST_DIM1" , conf.src1_bcast_dims[1]); |
154 | kernel_ctx.define_int("SRC1_BCAST_DIM2" , conf.src1_bcast_dims[2]); |
155 | kernel_ctx.define_int("SRC1_BCAST_DIM3" , conf.src1_bcast_dims[3]); |
156 | kernel_ctx.define_int("SRC1_BCAST_DIM4" , conf.src1_bcast_dims[4]); |
157 | kernel_ctx.define_int("SRC1_BCAST_DIM5" , conf.src1_bcast_dims[5]); |
158 | |
159 | kernel_ctx.define_int("USE_UNROLL_16B" , conf.use_unroll_16b); |
160 | kernel_ctx.define_int("SRC0_UNROLL_16B" , conf.src0_unroll_16b); |
161 | kernel_ctx.define_int("SUB_GROUP_SIZE" , 1); |
162 | |
163 | def_memory_desc_info(kernel_ctx, conf.src0_md_info, "SRC0" ); |
164 | def_memory_desc_info(kernel_ctx, conf.src1_md_info, "SRC1" ); |
165 | def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST" ); |
166 | |
167 | def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_); |
168 | |
169 | def_dispatch(kernel_ctx, conf.dispatch); |
170 | |
171 | return status::success; |
172 | } |
173 | |
174 | status_t ref_binary_t::execute_ref(const exec_ctx_t &ctx) const { |
175 | |
176 | status_t status = status::success; |
177 | |
178 | auto &src0 = CTX_IN_STORAGE(DNNL_ARG_SRC_0); |
179 | auto &src1 = CTX_IN_STORAGE(DNNL_ARG_SRC_1); |
180 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
181 | CHECK(status); |
182 | |
183 | const auto &conf = pd()->conf; |
184 | |
185 | auto &src0_scale = CTX_IN_STORAGE(DNNL_ARG_SRC_0 | DNNL_ARG_ATTR_SCALES); |
186 | |
187 | auto &src1_scale = CTX_IN_STORAGE(DNNL_ARG_SRC_1 | DNNL_ARG_ATTR_SCALES); |
188 | |
189 | compute::kernel_arg_list_t arg_list; |
190 | arg_list.set(0, src0); |
191 | arg_list.set(1, src1); |
192 | arg_list.set(2, dst); |
193 | |
194 | unsigned arg_idx = append_post_ops_to_arg_list( |
195 | ctx, arg_list, 3, pd()->attr()->post_ops_); |
196 | |
197 | arg_list.set(arg_idx++, src0_scale); |
198 | arg_list.set(arg_idx, src1_scale); |
199 | |
200 | auto nd_range = conf.dispatch.nd_range(); |
201 | |
202 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
203 | return status; |
204 | } |
205 | |
206 | } // namespace ocl |
207 | } // namespace gpu |
208 | } // namespace impl |
209 | } // namespace dnnl |
210 | |
211 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
212 | |