/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/scale_utils.hpp"

#include "cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace nstl;

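// Weights offset helper: grouped convolutions carry an extra leading group
// dimension in the weights descriptor; plain convolutions do not.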
#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))

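// Rank-aware offset into an activation tensor: pick the 3D, 4D, or 5D
// blk_off() overload based on the number of spatial dimensions.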
#define md_blk_off(md, n, c, d, h, w) \
    (pd()->ndims() == 3 \
                    ? (md).blk_off((n), (c), (w)) \
                    : (pd()->ndims() == 4 \
                                    ? (md).blk_off((n), (c), (h), (w)) \
                                    : (md).blk_off((n), (c), (d), (h), (w))))

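// When the OC dimension is padded, copy the user bias into the scratchpad
// and zero the padded tail so the kernel can safely read a full OC block.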
void jit_avx512_core_amx_1x1_convolution_fwd_t::prepare_padded_bias(
        const char *&bias, const memory_tracking::grantor_t &scratchpad) const {
    if (!pd()->wants_padded_bias()) return;

    const size_t bia_dt_size = pd()->jcp_.typesize_bia;
    auto padded_bias = scratchpad.template get<char>(
            memory_tracking::names::key_conv_padded_bias);
    utils::array_copy(
            padded_bias, bias, bia_dt_size * pd()->jcp_.oc_without_padding);
    utils::array_set(padded_bias + bia_dt_size * pd()->jcp_.oc_without_padding,
            0.f, bia_dt_size * (pd()->jcp_.oc - pd()->jcp_.oc_without_padding));
    bias = padded_bias;
}

status_t jit_avx512_core_amx_1x1_convolution_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {

    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

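    // Combine the source and weights scales into a single per-OC scale array.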
    const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
            src_scales, wei_scales, pd()->OC(), pd()->attr());

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t bia_dt_size = pd()->with_bias()
            ? types::data_type_size(pd()->desc()->bias_desc.data_type)
            : 0;
    const size_t dst_dt_size
            = types::data_type_size(pd()->desc()->dst_desc.data_type);
    const size_t src_dt_size
            = types::data_type_size(pd()->desc()->src_desc.data_type);
    const size_t wei_dt_size
            = types::data_type_size(pd()->desc()->weights_desc.data_type);

    prepare_padded_bias(bias, ctx.get_scratchpad_grantor());

    const auto &jcp = pd()->jcp_;
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

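    // Source zero-point compensation is stored in the additional buffer
    // appended after the weights payload; locate it at the end of the buffer.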
    const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<const int32_t *>(&weights[offset])
            : nullptr;

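    // Scratchpad buffers: `wsp` backs the per-thread s32 accumulators
    // (p.acc_s32 below); the extra tile buffer is needed only when the input
    // channels have a tail relative to the interleaved IC block.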
    const bool is_ic_tail = jcp.ic_without_padding % jcp.ic_block_int_np;
    auto wsp = ctx.get_scratchpad_grantor().template get<int32_t>(
            key_conv_amx_wsp_buffer);
    int32_t *wsp_tile = (is_ic_tail)
            ? ctx.get_scratchpad_grantor().template get<int32_t>(
                    key_conv_amx_tile_buffer)
            : nullptr;
    auto tcfg = ctx.get_scratchpad_grantor().template get<char>(
            key_conv_amx_tilecfg);

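    // Element stride between consecutive chunks of nb_oc_blocking OC blocks
    // in the blocked weights layout; IC is rounded up to the interleaved
    // block size.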
    const size_t wei_oc_shift = static_cast<size_t>(
            utils::rnd_up(jcp.ic_without_padding, jcp.ic_block_int)
            * jcp.oc_block * jcp.nb_oc_blocking);

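    // Partition the work into spatial (os) and output-channel chunks; a
    // trailing partial tile adds one extra os block.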
    int nb_os = (jcp.tile_tail) ? jcp.nb_os + 1 : jcp.nb_os;
    int os_step = jcp.nb_os2_blocking * jcp.nb_os_blocking;
    int os_chunks = div_up(nb_os, os_step);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;

    const size_t work_amount
            = (size_t)jcp.mb * jcp.ngroups * os_chunks * oc_chunks;
    kernel_->tile_configure(tcfg);

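    // Each thread processes its [start, end) share of the flattened
    // (mb, g, os-chunk, oc-chunk) iteration space.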
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();
        p.tile_cfg = tcfg;
        p.tile_cfg_tail = tcfg + 64;

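        // Load the AMX tile configuration once per thread before any kernel
        // call; it stays loaded until amx_tile_release() at the end.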
        amx_tile_configure(tcfg);

        int mb {0}, g {0}, _osb {0}, _ocb {0};
        nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, _osb, os_chunks,
                _ocb, oc_chunks);

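        // Main loop: fill in the kernel call parameters for one
        // (mb, g, os-chunk, oc-chunk) unit of work per iteration.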
        while (start < end) {
            int osb = _osb * os_step;
            int ocb = _ocb * jcp.nb_oc_blocking;
            auto bias_w = bias
                    ? bias + (bias_d.blk_off(ocb * jcp.oc_block) * bia_dt_size)
                    : nullptr;

            int oc = g * jcp.oc_without_padding + ocb * jcp.oc_block;
            int ic = g * jcp.ic_without_padding;

            p.acc_s32 = wsp + ithr * jcp.wsp_buffer_size;
            p.src_prf = wsp_tile + ithr * (jcp.wsp_buffer_size / 2);
            p.filt = weights + wei_dt_size * _ocb * wei_oc_shift;
            p.bias = bias_w;
            p.scales = &oscales[jcp.is_oc_scale * oc];
            p.dst_scale = &dst_scales[0];
            p.oc_blocks = ocb;

            p.zp_compensation
                    = jcp.src_zero_point ? zp_compensation + oc : nullptr;
            p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
            p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;

            p.oc_l_off = oc;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;

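            // Overflow path: the trailing os chunk (and the IC-tail case) is
            // walked in smaller steps with last_h flagged on the final rows;
            // otherwise the whole chunk is handled by a single kernel call.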
            const bool check_last_sp = is_ic_tail && !(nb_os % 2);
            const bool is_overflow = (osb + os_step >= nb_os);
            if (is_overflow
                    && (os_chunks > 1 || (os_chunks == 1 && is_ic_tail))) {
                int step = (check_last_sp) ? 1 : jcp.nb_os_blocking;
                for (int osi = 0; osi < nb_os - osb; osi += step) {
                    int osb_i = osi + osb;
                    int od {0}, oh {0}, ow {0};
                    nd_iterator_init(osb_i * jcp.tile_width, od, jcp.od, oh,
                            jcp.oh, ow, jcp.ow);
                    size_t dst_offset = md_blk_off(dst_d, mb, oc, od, oh, ow);
                    p.dst = dst + dst_dt_size * dst_offset;

                    int id = od * jcp.stride_d;
                    int ih = oh * jcp.stride_h;
                    int iw = ow * jcp.stride_w;
                    size_t inp_offset = md_blk_off(src_d, mb, ic, id, ih, iw);
                    p.src = src + src_dt_size * inp_offset;

                    bool l_overflow = osb_i + jcp.nb_os_blocking >= nb_os;
                    p.last_h = (check_last_sp || (nb_os % 2 && l_overflow))
                            ? 1
                            : 0;
                    p.is_osb = 0;
                    (*kernel_)(&p);
                }
            } else {
                int od {0}, oh {0}, ow {0};
                nd_iterator_init(osb * jcp.tile_width, od, jcp.od, oh, jcp.oh,
                        ow, jcp.ow);
                size_t dst_offset = md_blk_off(dst_d, mb, oc, od, oh, ow);
                p.dst = dst + dst_dt_size * dst_offset;

                int id = od * jcp.stride_d;
                int ih = oh * jcp.stride_h;
                int iw = ow * jcp.stride_w;
                size_t inp_offset = md_blk_off(src_d, mb, ic, id, ih, iw);
                p.src = src + src_dt_size * inp_offset;

                p.last_h = 0;
                p.is_osb = 1;

                (*kernel_)(&p);
            }
            ++start;
            nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, _osb, os_chunks, _ocb,
                    oc_chunks);
        }

        amx_tile_release();
    });
    return status::success;
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s