1/*******************************************************************************
2* Copyright 2017-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "oneapi/dnnl/dnnl_types.h"
18
19#include "common/c_types_map.hpp"
20#include "common/dnnl_thread.hpp"
21#include "cpu/x64/jit_sse41_convolution.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27
28using namespace dnnl::impl::status;
29using namespace dnnl::impl::utils;
30
// Offset returned by `f.blk_off()` for point (n, c, h, w) of an activation
// tensor described by the memory_desc_wrapper `f`; the `h` coordinate is
// dropped for 1D problems (pd()->ndims() == 3). Expects `pd()` to be callable
// at the expansion site. The whole conditional is parenthesized so the macro
// can safely be used as an operand (e.g. `base + src_blk_off(...)`) without
// the ?: being re-associated by operator precedence.
#define src_blk_off(f, n, c, h, w) \
    ((pd()->ndims() == 3) ? (f).blk_off(n, c, w) : (f).blk_off(n, c, h, w))
33
// Offset returned by `f.blk_off()` for a weights tensor described by `f`:
// the leading group index `g` is passed only when pd()->with_groups() is true.
// Expects `pd()` to be callable at the expansion site.
#define wht_blk_off_(f, g, ...) \
    (pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) \
                         : (f).blk_off(__VA_ARGS__))
// Weights offset for (g, oc, ic, kh, kw); `kh` is dropped for 1D problems
// (pd()->ndims() == 3). Both macros are fully parenthesized so they can be
// used as operands of larger expressions without precedence surprises.
#define wht_blk_off(f, g, oc, ic, kh, kw) \
    (pd()->ndims() == 3 ? wht_blk_off_(f, g, oc, ic, kw) \
                        : wht_blk_off_(f, g, oc, ic, kh, kw))
39
40void jit_sse41_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
41 const auto &jcp = kernel_->jcp;
42
43 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
44 auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
45 auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
46 auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
47 const auto post_ops_binary_rhs_arg_vec
48 = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
49
50 const memory_desc_wrapper src_d(pd()->src_md());
51 const memory_desc_wrapper dst_d(pd()->dst_md());
52 const memory_desc_wrapper weights_d(pd()->weights_md(0));
53 const memory_desc_wrapper bias_d(pd()->weights_md(1));
54
55 int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
56 const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh;
57
58 const bool is_src_layout_nxc
59 = one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc);
60 const bool is_dst_layout_nxc
61 = one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc);
62
63 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
64 assert(nthr == jcp.nthr);
65
66 size_t start {0}, end {0};
67 balance211(work_amount, nthr, ithr, start, end);
68
69 int icbb = 0;
70 while (icbb < jcp.nb_ic) {
71 int icb_step = jcp.nb_ic_blocking;
72 int icb_step_rem = jcp.nb_ic - icbb;
73 if (icb_step_rem < jcp.nb_ic_blocking_max) icb_step = icb_step_rem;
74
75 size_t n {0}, g {0}, ocbb {0}, oh {0};
76 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
77 oh, jcp.oh);
78 for (size_t iwork = start; iwork < end; ++iwork) {
79 int ocb = ocbb * jcp.nb_oc_blocking;
80 int ocb_num = jcp.nb_oc_blocking;
81
82 for (int icb = icbb; icb < icbb + icb_step; ++icb) {
83 auto par_conv = jit_conv_call_s();
84
85 const int ij = oh * jcp.stride_h;
86 const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
87 const int i_b_overflow
88 = nstl::max(jcp.ih,
89 ij + (jcp.kh - 1) * (jcp.dilate_h + 1)
90 - jcp.t_pad + 1)
91 - jcp.ih;
92
93 const size_t _oc
94 = g * (is_dst_layout_nxc ? jcp.oc : jcp.nb_oc)
95 + ocb * (is_dst_layout_nxc ? jcp.oc_block : 1);
96 const size_t _ic
97 = g * (is_src_layout_nxc ? jcp.ic : jcp.nb_ic)
98 + icb * (is_src_layout_nxc ? jcp.ic_block : 1);
99
100 const int ih = nstl::max(ij - jcp.t_pad
101 + div_up(i_t_overflow, (jcp.dilate_h + 1))
102 * (jcp.dilate_h + 1),
103 0);
104 par_conv.src = &src[src_blk_off(src_d, n, _ic, ih, 0)];
105
106 par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, oh, 0)];
107
108 const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
109 par_conv.filt = &weights[wht_blk_off(
110 weights_d, g, ocb, icb, wh, 0)];
111
112 if (icb == 0) {
113 if (bias)
114 par_conv.bias = &bias[bias_d.blk_off(_oc
115 * (is_dst_layout_nxc ? 1 : jcp.oc_block))];
116 par_conv.flags |= FLAG_IC_FIRST;
117 }
118
119 if ((jcp.with_eltwise || jcp.with_binary)
120 && icb + 1 == jcp.nb_ic) {
121 par_conv.flags |= FLAG_IC_LAST;
122 }
123
124 par_conv.oc_blocks
125 = nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
126
127 par_conv.kw_padding = 0;
128 const int kh_padding = jcp.kh
129 - div_up(i_t_overflow, (jcp.dilate_h + 1))
130 - div_up(i_b_overflow, (jcp.dilate_h + 1));
131 par_conv.kh_padding = nstl::max(0, kh_padding);
132
133 par_conv.oc_l_off = (g * jcp.nb_oc + ocb) * jcp.oc_block;
134 par_conv.post_ops_binary_rhs_arg_vec
135 = post_ops_binary_rhs_arg_vec.data();
136 par_conv.dst_orig = dst;
137
138 (*kernel_)(&par_conv);
139 }
140 nd_iterator_step(
141 n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
142 }
143 icbb += icb_step;
144 }
145 });
146
147 if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
148}
149
150} // namespace x64
151} // namespace cpu
152} // namespace impl
153} // namespace dnnl
154