1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "oneapi/dnnl/dnnl_types.h"
18
19#include "common/c_types_map.hpp"
20#include "common/dnnl_thread.hpp"
21#include "common/type_helpers.hpp"
22#include "common/utils.hpp"
23#include "cpu/x64/jit_sse41_1x1_convolution.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
30#define data_blk_off(f, n, c, h, w) \
31 ((ndims == 3) ? (f).blk_off(n, c, w) : (f).blk_off(n, c, h, w))
32
33using namespace dnnl::impl::status;
34using namespace dnnl::impl::utils;
35
36void jit_sse41_1x1_convolution_fwd_t::execute_forward(
37 const exec_ctx_t &ctx) const {
38 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
39 auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
40 auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
41 auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
42 auto weights_dw = CTX_IN_MEM(
43 const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
44 auto bias_dw = CTX_IN_MEM(
45 const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
46 const auto post_ops_binary_rhs_arg_vec
47 = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
48 const auto post_ops_binary_rhs_arg_vec_dw = pd()->dw_conv_pd_ != nullptr
49 ? binary_injector::prepare_binary_args(
50 pd()->dw_conv_pd_->jcp_.post_ops, ctx,
51 pd()->jcp_.post_ops.entry_.size() + 1)
52 : std::vector<const void *> {};
53
54 auto scratchpad = ctx.get_scratchpad_grantor();
55 parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) {
56 execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
57 dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),
58 post_ops_binary_rhs_arg_vec_dw.data());
59 });
60
61 if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
62}
63
// Per-thread body of the forward 1x1 convolution. Partitions the work among
// `nthr` threads and drives the JIT 1x1 kernel (and, when configured, the
// fused depthwise kernel) over that thread's share of the output.
void jit_sse41_1x1_convolution_fwd_t::execute_forward_thr(const int ithr,
        const int nthr, const data_t *src, const data_t *weights,
        const data_t *bias, const data_t *weights_dw, const data_t *bias_dw,
        data_t *dst, const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    // Descriptors of the fused depthwise conv's weights/bias (post-op args).
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));
    const memory_desc_wrapper dw_bias_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS));

    const auto &jcp = kernel_->jcp;
    const int ndims = src_d.ndims();

    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1);

    // Kernel argument structure; the lambdas below fill in fields
    // incrementally, so it must outlive all of them.
    auto par_conv = jit_1x1_conv_call_s();

    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    const int nb_ic_blocking = jcp.nb_reduce_blocking;

    // override some constants for fused dw_conv
    // With a fused dw conv the 1x1 conv is computed row-by-row (blocks of
    // jcp.ow) so each produced row can be fed to the dw kernel.
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;
    const bool is_dst_layout_nxc = utils::one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const bool is_src_layout_nxc = utils::one_of(
            jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    // Begin: declare Variables needed for dw conv.
    // pbuf: this thread's slice of the intermediate row buffer between the
    // 1x1 and dw kernels; row_offset: stride between rows in that buffer.
    data_t *pbuf {nullptr};
    size_t row_offset {};
    const int nb_buffer = jcp.nb_load_blocking;
    std::vector<data_t *> addrs;

    // Pick a blocking step: full `default_step` unless fewer than `tail_step`
    // iterations remain, in which case take all remaining at once.
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    // Decode flat iteration index `iwork` into (mb, group, spatial block) and
    // derive the output/input coordinates plus the bcast step for this block.
    auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
                              int bcast_end, int &oh, int &ow, int &ih,
                              int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);

        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        // Do not run past this thread's assigned range.
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        ow = os % jcp.ow;
        oh = os / jcp.ow;

        ih = oh * jcp.stride_h;
        iw = ow * jcp.stride_w;

        par_conv.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
    };

    // Choose how many oc blocks to process at once and set the load dim.
    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        par_conv.load_dim = this_block_size(
                ocb * jcp.oc_block, jcp.oc, load_step * jcp.oc_block);
    };

    // Fill the remaining kernel arguments for one (ocb, icb) tile and invoke
    // the JIT 1x1 kernel.
    auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow, int ih,
                             int iw) {
        const size_t _ocb = g * nb_oc + ocb;
        const size_t _icb = g * nb_ic + icb;

        const int oc_off_idx = (is_dst_layout_nxc ? jcp.oc_block : 1) * _ocb;

        // With fused dw conv, write into the circular row buffer (kh rows)
        // instead of the real dst.
        par_conv.output_data = jcp.with_dw_conv
                ? pbuf + (oh % pd()->dw_conv_pd_->jcp_.kh) * row_offset
                : &dst[data_blk_off(dst_d, n, oc_off_idx, oh, ow)];

        par_conv.bias_data = &bias[_ocb * jcp.oc_block];

        // Flag the first/last reduction chunk so the kernel knows when to
        // initialize the accumulators and when to apply post-ops.
        par_conv.first_last_flag = 0 | (icb == 0) * FLAG_REDUCE_FIRST
                | (icb + nb_ic_blocking >= nb_ic) * FLAG_REDUCE_LAST;

        par_conv.reduce_dim = this_block_size(
                icb * jcp.ic_block, jcp.ic, nb_ic_blocking * jcp.ic_block);

        const int ic_off_idx = (is_src_layout_nxc ? jcp.ic_block : 1) * _icb;

        const size_t src_off = data_blk_off(src_d, n, ic_off_idx, ih, iw);
        par_conv.bcast_data = &src[src_off];

        par_conv.load_data
                = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb)
                                               : weights_d.blk_off(ocb, icb)];

        par_conv.oc_l_off = _ocb * jcp.oc_block;
        par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        par_conv.dst_orig = jcp.with_dw_conv ? pbuf : dst;

        (*kernel_)(&par_conv);
    };

    // Run the 1x1 conv over [bcast_start, bcast_end) x [ocb_start, ocb_end).
    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;

        int iwork = bcast_start;
        while (iwork < bcast_end) {
            int n {0}, g {0}, bcast_step, oh, ow, ih, iw;
            init_bcast(iwork, n, g, bcast_step, bcast_end, oh, ow, ih, iw);
            int ocb = 0;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                // Full reduction over input channels for this tile.
                for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                    inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
                }
                ocb += load_step;
            }
            iwork += bcast_step;
        }
    };

    // Apply the fused depthwise kernel to produce one dw output row `dw_oh`,
    // consuming up to kh previously computed 1x1 rows from the row buffer.
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        auto &jcp_dw = pd()->dw_conv_pd_->jcp_;
        int oh_1x1 = nstl::max(dw_oh * jcp_dw.stride_h - jcp_dw.t_pad, 0);

        // Gather the kh source-row addresses from the circular buffer.
        for (int i = 0; i < jcp_dw.kh; ++i)
            addrs[i] = pbuf + ((oh_1x1++) % jcp_dw.kh) * row_offset;

        const auto ocb_end = ocb_start + load_step;
        const auto wch_stride = (is_src_layout_nxc ? 1 : jcp_dw.iw)
                * jcp_dw.nb_ch_blocking * jcp_dw.ch_block;
        const int dil_h = jcp_dw.dilate_h + 1;
        const int str_h = jcp_dw.stride_h;
        const int ch_num = jcp_dw.nb_ch_blocking;

        for (int ch = ocb_start; ch < ocb_end; ch += jcp_dw.nb_ch_blocking) {

            // Rows of the kernel window falling above/below the image.
            const int i_t_overflow
                    = nstl::max(0, (int)(jcp_dw.t_pad - dw_oh * str_h));
            const int i_b_overflow
                    = nstl::max(jcp_dw.ih,
                              (int)(dw_oh * str_h + (jcp_dw.kh - 1) * dil_h
                                      - jcp_dw.t_pad + 1))
                    - jcp_dw.ih;

            // First valid kernel row and number of valid rows after clipping.
            const int kh = div_up(i_t_overflow, dil_h);
            const int kh_padding = jcp_dw.kh - div_up(i_t_overflow, dil_h)
                    - div_up(i_b_overflow, dil_h);

            const int ow = 0;
            const int kw = 0;
            jit_conv_call_s par_conv_dw;

            par_conv_dw.src = addrs.data();

            const size_t ch_step = is_dst_layout_nxc
                    ? jcp_dw.ch_block
                    : dst_d.blk_off(0, 1, 0, 0);
            par_conv_dw.dst
                    = &dst[dst_d.blk_off(n, 0, dw_oh, ow) + ch * ch_step];
            par_conv_dw.filt
                    = &weights_dw[dw_weights_d.blk_off(ch, 0, 0, kh, kw)];
            // NOTE(review): the guard tests the 1x1 conv's `bias` but indexes
            // `bias_dw` — looks like it should test `bias_dw`; confirm intent.
            if (bias)
                par_conv_dw.bias
                        = &bias_dw[dw_bias_d.blk_off(ch * jcp_dw.ch_block)];

            par_conv_dw.kh_padding = (size_t)nstl::max(0, kh_padding);

            par_conv_dw.load_work = (nstl::min(ch + ch_num, jcp_dw.nb_ch) - ch)
                    * jcp_dw.ch_block;

            par_conv_dw.oc_l_off = ch;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*kernel_dw_)(&par_conv_dw);

            // Advance the source-row addresses to the next channel block.
            for (int i = 0; i < jcp_dw.kh; ++i)
                addrs[i] += wch_stride;
        }
    };

    // Fused path: interleave row-wise 1x1 computation with the dw kernel,
    // buffering intermediate rows in the scratchpad.
    auto conv_dw = [&]() {
        // Set variables
        auto &jcp_dw = pd()->dw_conv_pd_->jcp_;
        memory_tracking::grantor_t dw_scratchpad(
                scratchpad, memory_tracking::names::prefix_fusion);
        const auto dw_conv_buffer = dw_scratchpad.get<data_t>(
                memory_tracking::names::key_fusion_inout_buffer);

        // Per-thread buffer: kh rows of (ow x nb_buffer x oc_block) values.
        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw.kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw.kh;
        addrs.resize(jcp_dw.kh);

        // 2D partition: dw output rows x oc blocks.
        int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end;
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw.oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, 1);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n, g, oh_dw;
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw.oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                // Range of 1x1 output rows needed by this dw output row.
                const int oh_1x1_range = oh_dw * jcp_dw.stride_h - jcp_dw.t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw.kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // dw_spatial to 1x1 spatial conversion. if jcp.oh != jcp_dw.oh
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                // Produce the missing 1x1 rows, then consume them with dw.
                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        // Plain path: 1D balance over (mb x groups x spatial blocks); each
        // thread covers all oc blocks for its share.
        const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        conv_1x1(start, end, 0, jcp.nb_load);
    }
}
322
323} // namespace x64
324} // namespace cpu
325} // namespace impl
326} // namespace dnnl
327