1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "oneapi/dnnl/dnnl_types.h" |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_thread.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | #include "common/utils.hpp" |
23 | #include "cpu/x64/jit_sse41_1x1_convolution.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
// Block offset into descriptor `f` for element (n, c, [h,] w).
// Relies on a variable named `ndims` being visible at the expansion site:
// when it equals 3 the `h` argument is dropped and the 3-argument blk_off()
// overload is used; otherwise the 4-argument overload is called.
#define data_blk_off(f, n, c, h, w) \
    ((ndims != 3) ? (f).blk_off(n, c, h, w) : (f).blk_off(n, c, w))
32 | |
33 | using namespace dnnl::impl::status; |
34 | using namespace dnnl::impl::utils; |
35 | |
36 | void jit_sse41_1x1_convolution_fwd_t::execute_forward( |
37 | const exec_ctx_t &ctx) const { |
38 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
39 | auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); |
40 | auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS); |
41 | auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); |
42 | auto weights_dw = CTX_IN_MEM( |
43 | const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); |
44 | auto bias_dw = CTX_IN_MEM( |
45 | const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS); |
46 | const auto post_ops_binary_rhs_arg_vec |
47 | = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); |
48 | const auto post_ops_binary_rhs_arg_vec_dw = pd()->dw_conv_pd_ != nullptr |
49 | ? binary_injector::prepare_binary_args( |
50 | pd()->dw_conv_pd_->jcp_.post_ops, ctx, |
51 | pd()->jcp_.post_ops.entry_.size() + 1) |
52 | : std::vector<const void *> {}; |
53 | |
54 | auto scratchpad = ctx.get_scratchpad_grantor(); |
55 | parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) { |
56 | execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw, |
57 | dst, scratchpad, post_ops_binary_rhs_arg_vec.data(), |
58 | post_ops_binary_rhs_arg_vec_dw.data()); |
59 | }); |
60 | |
61 | if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST); |
62 | } |
63 | |
64 | void jit_sse41_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, |
65 | const int nthr, const data_t *src, const data_t *weights, |
66 | const data_t *bias, const data_t *weights_dw, const data_t *bias_dw, |
67 | data_t *dst, const memory_tracking::grantor_t &scratchpad, |
68 | const void *post_ops_binary_rhs_arg_vec, |
69 | const void *post_ops_binary_rhs_arg_vec_dw) const { |
70 | |
71 | const memory_desc_wrapper src_d(pd()->src_md()); |
72 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
73 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
74 | const memory_desc_wrapper dw_weights_d( |
75 | pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)); |
76 | const memory_desc_wrapper dw_bias_d( |
77 | pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)); |
78 | |
79 | const auto &jcp = kernel_->jcp; |
80 | const int ndims = src_d.ndims(); |
81 | |
82 | // TODO (Roma): remove this restriction |
83 | assert(jcp.stride_w == 1 && jcp.stride_h == 1); |
84 | |
85 | auto par_conv = jit_1x1_conv_call_s(); |
86 | |
87 | const int nb_oc = jcp.nb_load; |
88 | const int nb_ic = jcp.nb_reduce; |
89 | const int nb_ic_blocking = jcp.nb_reduce_blocking; |
90 | |
91 | // override some constants for fused dw_conv |
92 | const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block; |
93 | const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast; |
94 | const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking; |
95 | const int nb_bcast_blocking_max |
96 | = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max; |
97 | const int nb_load_blocking = jcp.nb_load_blocking; |
98 | const int nb_load_blocking_max = jcp.with_dw_conv |
99 | ? jcp.nb_load_blocking |
100 | : jcp.nb_load_blocking_max; |
101 | const bool is_dst_layout_nxc = utils::one_of( |
102 | jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
103 | const bool is_src_layout_nxc = utils::one_of( |
104 | jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
105 | |
106 | // Begin: declare Variables needed for dw conv. |
107 | data_t *pbuf {nullptr}; |
108 | size_t row_offset {}; |
109 | const int nb_buffer = jcp.nb_load_blocking; |
110 | std::vector<data_t *> addrs; |
111 | |
112 | auto step = [](int default_step, int remaining, int tail_step) { |
113 | assert(default_step <= tail_step); |
114 | return remaining < tail_step ? remaining : default_step; |
115 | }; |
116 | |
117 | auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, |
118 | int bcast_end, int &oh, int &ow, int &ih, |
119 | int &iw) { |
120 | int osb {0}; |
121 | nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast); |
122 | |
123 | bcast_step = step( |
124 | nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max); |
125 | bcast_step = nstl::min(bcast_step, bcast_end - iwork); |
126 | |
127 | const int os = osb * os_block; |
128 | ow = os % jcp.ow; |
129 | oh = os / jcp.ow; |
130 | |
131 | ih = oh * jcp.stride_h; |
132 | iw = ow * jcp.stride_w; |
133 | |
134 | par_conv.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); |
135 | }; |
136 | |
137 | auto init_load = [&](int ocb, int ocb_end, int &load_step) { |
138 | load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max); |
139 | par_conv.load_dim = this_block_size( |
140 | ocb * jcp.oc_block, jcp.oc, load_step * jcp.oc_block); |
141 | }; |
142 | |
143 | auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow, int ih, |
144 | int iw) { |
145 | const size_t _ocb = g * nb_oc + ocb; |
146 | const size_t _icb = g * nb_ic + icb; |
147 | |
148 | const int oc_off_idx = (is_dst_layout_nxc ? jcp.oc_block : 1) * _ocb; |
149 | |
150 | par_conv.output_data = jcp.with_dw_conv |
151 | ? pbuf + (oh % pd()->dw_conv_pd_->jcp_.kh) * row_offset |
152 | : &dst[data_blk_off(dst_d, n, oc_off_idx, oh, ow)]; |
153 | |
154 | par_conv.bias_data = &bias[_ocb * jcp.oc_block]; |
155 | |
156 | par_conv.first_last_flag = 0 | (icb == 0) * FLAG_REDUCE_FIRST |
157 | | (icb + nb_ic_blocking >= nb_ic) * FLAG_REDUCE_LAST; |
158 | |
159 | par_conv.reduce_dim = this_block_size( |
160 | icb * jcp.ic_block, jcp.ic, nb_ic_blocking * jcp.ic_block); |
161 | |
162 | const int ic_off_idx = (is_src_layout_nxc ? jcp.ic_block : 1) * _icb; |
163 | |
164 | const size_t src_off = data_blk_off(src_d, n, ic_off_idx, ih, iw); |
165 | par_conv.bcast_data = &src[src_off]; |
166 | |
167 | par_conv.load_data |
168 | = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb) |
169 | : weights_d.blk_off(ocb, icb)]; |
170 | |
171 | par_conv.oc_l_off = _ocb * jcp.oc_block; |
172 | par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; |
173 | par_conv.dst_orig = jcp.with_dw_conv ? pbuf : dst; |
174 | |
175 | (*kernel_)(&par_conv); |
176 | }; |
177 | |
178 | auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start, |
179 | int ocb_end) { |
180 | if (bcast_start >= bcast_end || ocb_start >= ocb_end) return; |
181 | |
182 | int iwork = bcast_start; |
183 | while (iwork < bcast_end) { |
184 | int n {0}, g {0}, bcast_step, oh, ow, ih, iw; |
185 | init_bcast(iwork, n, g, bcast_step, bcast_end, oh, ow, ih, iw); |
186 | int ocb = 0; |
187 | while (ocb < ocb_end) { |
188 | int load_step; |
189 | init_load(ocb, ocb_end, load_step); |
190 | for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { |
191 | inner_ker(ocb, icb, n, g, oh, ow, ih, iw); |
192 | } |
193 | ocb += load_step; |
194 | } |
195 | iwork += bcast_step; |
196 | } |
197 | }; |
198 | |
199 | auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) { |
200 | auto &jcp_dw = pd()->dw_conv_pd_->jcp_; |
201 | int oh_1x1 = nstl::max(dw_oh * jcp_dw.stride_h - jcp_dw.t_pad, 0); |
202 | |
203 | for (int i = 0; i < jcp_dw.kh; ++i) |
204 | addrs[i] = pbuf + ((oh_1x1++) % jcp_dw.kh) * row_offset; |
205 | |
206 | const auto ocb_end = ocb_start + load_step; |
207 | const auto wch_stride = (is_src_layout_nxc ? 1 : jcp_dw.iw) |
208 | * jcp_dw.nb_ch_blocking * jcp_dw.ch_block; |
209 | const int dil_h = jcp_dw.dilate_h + 1; |
210 | const int str_h = jcp_dw.stride_h; |
211 | const int ch_num = jcp_dw.nb_ch_blocking; |
212 | |
213 | for (int ch = ocb_start; ch < ocb_end; ch += jcp_dw.nb_ch_blocking) { |
214 | |
215 | const int i_t_overflow |
216 | = nstl::max(0, (int)(jcp_dw.t_pad - dw_oh * str_h)); |
217 | const int i_b_overflow |
218 | = nstl::max(jcp_dw.ih, |
219 | (int)(dw_oh * str_h + (jcp_dw.kh - 1) * dil_h |
220 | - jcp_dw.t_pad + 1)) |
221 | - jcp_dw.ih; |
222 | |
223 | const int kh = div_up(i_t_overflow, dil_h); |
224 | const int kh_padding = jcp_dw.kh - div_up(i_t_overflow, dil_h) |
225 | - div_up(i_b_overflow, dil_h); |
226 | |
227 | const int ow = 0; |
228 | const int kw = 0; |
229 | jit_conv_call_s par_conv_dw; |
230 | |
231 | par_conv_dw.src = addrs.data(); |
232 | |
233 | const size_t ch_step = is_dst_layout_nxc |
234 | ? jcp_dw.ch_block |
235 | : dst_d.blk_off(0, 1, 0, 0); |
236 | par_conv_dw.dst |
237 | = &dst[dst_d.blk_off(n, 0, dw_oh, ow) + ch * ch_step]; |
238 | par_conv_dw.filt |
239 | = &weights_dw[dw_weights_d.blk_off(ch, 0, 0, kh, kw)]; |
240 | if (bias) |
241 | par_conv_dw.bias |
242 | = &bias_dw[dw_bias_d.blk_off(ch * jcp_dw.ch_block)]; |
243 | |
244 | par_conv_dw.kh_padding = (size_t)nstl::max(0, kh_padding); |
245 | |
246 | par_conv_dw.load_work = (nstl::min(ch + ch_num, jcp_dw.nb_ch) - ch) |
247 | * jcp_dw.ch_block; |
248 | |
249 | par_conv_dw.oc_l_off = ch; |
250 | par_conv_dw.post_ops_binary_rhs_arg_vec |
251 | = post_ops_binary_rhs_arg_vec_dw; |
252 | par_conv_dw.dst_orig = dst; |
253 | |
254 | (*kernel_dw_)(&par_conv_dw); |
255 | |
256 | for (int i = 0; i < jcp_dw.kh; ++i) |
257 | addrs[i] += wch_stride; |
258 | } |
259 | }; |
260 | |
261 | auto conv_dw = [&]() { |
262 | // Set variables |
263 | auto &jcp_dw = pd()->dw_conv_pd_->jcp_; |
264 | memory_tracking::grantor_t dw_scratchpad( |
265 | scratchpad, memory_tracking::names::prefix_fusion); |
266 | const auto dw_conv_buffer = dw_scratchpad.get<data_t>( |
267 | memory_tracking::names::key_fusion_inout_buffer); |
268 | |
269 | const auto dw_conv_buffer_size_ |
270 | = (size_t)jcp_dw.kh * jcp.ow * nb_buffer * jcp.oc_block; |
271 | pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_; |
272 | row_offset = dw_conv_buffer_size_ / jcp_dw.kh; |
273 | addrs.resize(jcp_dw.kh); |
274 | |
275 | int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end; |
276 | balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw.oh, bcast_start, |
277 | bcast_end, nb_oc, ocb_start, ocb_end, 1); |
278 | |
279 | while (ocb_start < ocb_end) { |
280 | int load_step; |
281 | init_load(ocb_start, ocb_end, load_step); |
282 | |
283 | int oh_1x1 = 0; |
284 | auto bcast_iter = bcast_start; |
285 | while (bcast_iter < bcast_end) { |
286 | int n, g, oh_dw; |
287 | nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw, |
288 | jcp_dw.oh); |
289 | if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary |
290 | const int oh_1x1_range = oh_dw * jcp_dw.stride_h - jcp_dw.t_pad; |
291 | const int oh_1x1_begin = nstl::max(oh_1x1_range, 0); |
292 | const int oh_1x1_end |
293 | = nstl::min(oh_1x1_range + jcp_dw.kh, jcp.oh); |
294 | oh_1x1 = nstl::max( |
295 | oh_1x1_begin, oh_1x1); // Skip rows computed previously |
296 | |
297 | // dw_spatial to 1x1 spatial conversion. if jcp.oh != jcp_dw.oh |
298 | const int bcast_start_1x1 |
299 | = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1; |
300 | const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end; |
301 | |
302 | conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start, |
303 | ocb_start + load_step); |
304 | oh_1x1 = oh_1x1_end; |
305 | ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw); |
306 | |
307 | bcast_iter += nb_bcast_blocking; |
308 | } |
309 | ocb_start += load_step; |
310 | } |
311 | }; |
312 | |
313 | if (jcp.with_dw_conv) { |
314 | conv_dw(); |
315 | } else { |
316 | const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; |
317 | int start {0}, end {0}; |
318 | balance211(work_amount, nthr, ithr, start, end); |
319 | conv_1x1(start, end, 0, jcp.nb_load); |
320 | } |
321 | } |
322 | |
323 | } // namespace x64 |
324 | } // namespace cpu |
325 | } // namespace impl |
326 | } // namespace dnnl |
327 | |