1/*******************************************************************************
2* Copyright 2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include <cmath>
17#include <vector>
18#include "cpu/gemm_x8s8s32x_conv_zp_src_pad_comp.hpp"
19#if DNNL_X64
20#include "cpu/x64/jit_primitive_conf.hpp"
21#endif
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26
27static dim_t zp_src_comp_pad_offset(const conv_gemm_conf_t &jcp,
28 const dim_t zp_pad_com_d, const dim_t zp_pad_com_h,
29 const dim_t zp_pad_com_w, dim_t oc, dim_t g) {
30 return ((zp_pad_com_d * jcp.zp.src_pad_comp.h + zp_pad_com_h)
31 * jcp.zp.src_pad_comp.w
32 + zp_pad_com_w)
33 * jcp.oc * jcp.ngroups
34 + (g * jcp.oc + oc);
35}
36
37static dim_t get_weights_offset(const memory_desc_wrapper &weights_md,
38 const bool with_groups, const dim_t kd, const dim_t kh,
39 const dim_t kw) {
40 auto ndims = weights_md.ndims();
41 if (with_groups) ndims -= 1;
42
43 switch (ndims) {
44 case 5:
45 return with_groups ? weights_md.blk_off(0, 0, 0, kd, kh, kw)
46 : weights_md.blk_off(0, 0, kd, kh, kw);
47 case 4:
48 return with_groups ? weights_md.blk_off(0, 0, 0, kh, kw)
49 : weights_md.blk_off(0, 0, kh, kw);
50 case 3:
51 return with_groups ? weights_md.blk_off(0, 0, 0, kw)
52 : weights_md.blk_off(0, 0, kw);
53 default: assert(!"unsupported ndims"); return dim_t(0);
54 }
55}
56
57static dim_t calculate_blk_size(const conv_gemm_conf_t &jcp) {
58 const auto number_of_threads = dnnl_get_max_threads();
59 const auto number_of_tasks = jcp.zp.src_pad_comp.d * jcp.zp.src_pad_comp.h
60 * jcp.zp.src_pad_comp.w;
61 auto scaling_factor = number_of_threads / number_of_tasks;
62 const auto output_channels = jcp.oc * jcp.ngroups;
63 static constexpr dim_t min_blk_size
64 = platform::get_cache_line_size() / sizeof(int32_t);
65
66 if (output_channels <= min_blk_size || scaling_factor <= 1)
67 return output_channels;
68
69 const auto scaling_factor_threashold
70 = nstl::max(output_channels / (2 * min_blk_size), dim_t(1));
71 if (scaling_factor > scaling_factor_threashold) {
72 scaling_factor = scaling_factor_threashold;
73 }
74
75 if (const auto blk_size
76 = utils::rnd_up(output_channels / scaling_factor, min_blk_size)) {
77 return blk_size;
78 }
79
80 return output_channels;
81}
82
83static void append_weights_to_comp_pad_buf(const conv_gemm_conf_t &jcp,
84 int32_t *const __restrict zp_src_pad_comp,
85 const int8_t *__restrict weights, dim_t weights_offset,
86 const dim_t start_oc_blk, const dim_t end_oc_blk) {
87 const auto output_channels = jcp.oc * jcp.ngroups;
88
89 for (dim_t it_ic = 0; it_ic < jcp.ic; ++it_ic) {
90 for (dim_t oc_off = start_oc_blk; oc_off < end_oc_blk; ++oc_off) {
91 zp_src_pad_comp[oc_off]
92 += static_cast<int32_t>(weights[weights_offset + oc_off]);
93 }
94
95 weights_offset += output_channels;
96 }
97}
98
99static dim_t calc_filter_corner_dim(const dim_t it_zp_buf_dim,
100 const dim_t &dim_size, const dim_t &input_begin_pad,
101 const dim_t &stride_dim, const dim_t &begin_comp_pad,
102 const bool &mid_comp_pad, const dim_t &end_comp_pad) {
103
104 if (it_zp_buf_dim < begin_comp_pad)
105 return it_zp_buf_dim * stride_dim - input_begin_pad;
106 else if (mid_comp_pad && it_zp_buf_dim == begin_comp_pad)
107 return 0;
108 else
109 return (dim_size - 1) * stride_dim - input_begin_pad
110 - (end_comp_pad - 1) * stride_dim
111 + (it_zp_buf_dim - (begin_comp_pad + mid_comp_pad))
112 * stride_dim;
113}
114
115void compute_zp_src_comp_pad(const conv_gemm_conf_t &jcp,
116 int32_t *const zp_src_pad_buf, const int32_t *const zp_src,
117 const int8_t *weights, const memory_desc_wrapper &weights_md,
118 const bool with_groups) {
119
120 const dim_t blk_size = calculate_blk_size(jcp);
121 const dim_t output_channels = jcp.oc * jcp.ngroups;
122 const dim_t oc_blks = utils::div_up(output_channels, blk_size);
123
124 const auto compute_zp_src_pad_buf = [&](const dim_t zp_pad_com_d,
125 const dim_t zp_pad_com_h,
126 const dim_t zp_pad_com_w,
127 const dim_t filter_corner_src_d,
128 const dim_t filter_corner_src_h,
129 const dim_t filter_corner_src_w,
130 const dim_t oc_blk) {
131 const auto start_blk = oc_blk * blk_size;
132 const auto end_blk = nstl::min(start_blk + blk_size, output_channels);
133 const auto size = end_blk - start_blk;
134 const auto zp_pad_offset = zp_src_comp_pad_offset(
135 jcp, zp_pad_com_d, zp_pad_com_h, zp_pad_com_w, 0, 0);
136 int32_t *const __restrict zp_src_pad_comp
137 = zp_src_pad_buf + zp_pad_offset;
138
139 std::memset(zp_src_pad_comp + start_blk, 0, size * sizeof(int32_t));
140
141 const auto dilate_scale_d = jcp.dilate_d + 1;
142 const auto dilate_scale_h = jcp.dilate_h + 1;
143 const auto dilate_scale_w = jcp.dilate_w + 1;
144
145 for (int it_kd = 0; it_kd < jcp.kd; it_kd++) {
146 const int filter_point_d = it_kd * dilate_scale_d;
147 const int filter_point_src_d = filter_corner_src_d + filter_point_d;
148 const bool filter_point_srd_d_pad
149 = filter_point_src_d < 0 || filter_point_src_d >= jcp.id;
150
151 for (int it_kh = 0; it_kh < jcp.kh; it_kh++) {
152 const int filter_point_h = it_kh * dilate_scale_h;
153 const int filter_point_src_h
154 = filter_corner_src_h + filter_point_h;
155 const bool filter_point_srd_h_pad = filter_point_src_h < 0
156 || filter_point_src_h >= jcp.ih;
157
158 for (int it_kw = 0; it_kw < jcp.kw; it_kw++) {
159 const int filter_point_w = it_kw * dilate_scale_w;
160 const int filter_point_src_w
161 = filter_corner_src_w + filter_point_w;
162
163 if (filter_point_srd_d_pad || filter_point_srd_h_pad
164 || filter_point_src_w < 0
165 || filter_point_src_w >= jcp.iw) {
166 const auto weights_offset = get_weights_offset(
167 weights_md, with_groups, it_kd, it_kh, it_kw);
168 append_weights_to_comp_pad_buf(jcp, zp_src_pad_comp,
169 weights, weights_offset, start_blk, end_blk);
170 }
171 }
172 }
173 }
174
175 if (jcp.zp.src_is_common) {
176 const int32_t zp_src_val = *zp_src;
177 for (auto oc_off = start_blk; oc_off < end_blk; ++oc_off)
178 zp_src_pad_comp[oc_off] *= zp_src_val;
179 } else {
180 for (auto oc_off = start_blk; oc_off < end_blk; ++oc_off)
181 zp_src_pad_comp[oc_off] *= zp_src[oc_off];
182 }
183 };
184
185 const auto compute_zp_buf_w = [&](dim_t it_zp_buf_d, dim_t it_zp_buf_h,
186 dim_t it_zp_buf_w,
187 dim_t filter_corner_src_d,
188 dim_t filter_corner_src_h,
189 const dim_t oc_blk) {
190 const int filter_corner_src_w = calc_filter_corner_dim(it_zp_buf_w,
191 jcp.ow, jcp.l_pad, jcp.stride_w, jcp.zp.src_pad_comp.left_pad,
192 jcp.zp.src_pad_comp.mid_w, jcp.zp.src_pad_comp.right_pad);
193 compute_zp_src_pad_buf(it_zp_buf_d, it_zp_buf_h, it_zp_buf_w,
194 filter_corner_src_d, filter_corner_src_h, filter_corner_src_w,
195 oc_blk);
196 };
197
198 const auto compute_zp_buf_h = [&](dim_t it_zp_buf_d, dim_t it_zp_buf_h,
199 dim_t it_zp_buf_w,
200 dim_t filter_corner_src_d,
201 const dim_t oc_blk) {
202 const auto filter_corner_src_h = calc_filter_corner_dim(it_zp_buf_h,
203 jcp.oh, jcp.t_pad, jcp.stride_h, jcp.zp.src_pad_comp.top_pad,
204 jcp.zp.src_pad_comp.mid_h, jcp.zp.src_pad_comp.bottom_pad);
205
206 compute_zp_buf_w(it_zp_buf_d, it_zp_buf_h, it_zp_buf_w,
207 filter_corner_src_d, filter_corner_src_h, oc_blk);
208 };
209
210 parallel_nd(jcp.zp.src_pad_comp.d, jcp.zp.src_pad_comp.h,
211 jcp.zp.src_pad_comp.w, oc_blks,
212 [&](const dim_t it_zp_buf_d, const dim_t it_zp_buf_h,
213 const dim_t it_zp_buf_w, const dim_t oc_blk) {
214 const int filter_corner_src_d
215 = calc_filter_corner_dim(it_zp_buf_d, jcp.od, jcp.f_pad,
216 jcp.stride_d, jcp.zp.src_pad_comp.front_pad,
217 jcp.zp.src_pad_comp.mid_d,
218 jcp.zp.src_pad_comp.back_pad);
219
220 compute_zp_buf_h(it_zp_buf_d, it_zp_buf_h, it_zp_buf_w,
221 filter_corner_src_d, oc_blk);
222 });
223}
224
225static dim_t zp_src_comp_pad_offset(const conv_gemm_conf_t &jcp,
226 const dim_t zp_pad_com_d, const dim_t zp_pad_com_h,
227 const dim_t zp_pad_com_w, const dim_t g) {
228 return zp_src_comp_pad_offset(
229 jcp, zp_pad_com_d, zp_pad_com_h, zp_pad_com_w, 0, g);
230}
231
232static dim_t gemm_conv_result_offset(
233 const conv_gemm_conf_t &jcp, const dim_t h, const dim_t w) {
234 return (h * jcp.ow + w) * jcp.oc;
235}
236
237static void append_zp_src_comp_pad(const conv_gemm_conf_t &jcp,
238 const int32_t *__restrict zp_src_pad_comp,
239 const dim_t zp_src_comp_pad_offset,
240 int32_t *__restrict gemm_conv_result,
241 const dim_t gemm_conv_result_offset) {
242
243 const int32_t *const __restrict zp_src_pad_comp_h_w
244 = zp_src_pad_comp + zp_src_comp_pad_offset;
245 int32_t *const __restrict gemm_conv_result_h_w
246 = gemm_conv_result + gemm_conv_result_offset;
247 const std::ptrdiff_t oc = jcp.oc;
248
249 for (std::ptrdiff_t oc_off = 0; oc_off < oc; ++oc_off)
250 gemm_conv_result_h_w[oc_off] += zp_src_pad_comp_h_w[oc_off];
251}
252
253static dim_t get_zp_pad_com_dim(const bool dim_under_lower_bound,
254 const bool dim_over_eq_upper_bound, const dim_t begin_pad, bool mid_pad,
255 const dim_t end_pad, const dim_t out_dim_size,
256 const dim_t out_point_dim) {
257
258 if (dim_under_lower_bound) {
259 return out_point_dim;
260 } else if (dim_over_eq_upper_bound) {
261 return begin_pad + mid_pad + (end_pad - (out_dim_size - out_point_dim));
262 }
263
264 return begin_pad;
265}
266
267dim_t calculate_lower_bound_dim(
268 const dim_t dim_offset, const dim_t begin_comp_pad) {
269 return dim_offset < begin_comp_pad ? begin_comp_pad - dim_offset : 0u;
270}
271
272dim_t calculate_upper_bound_dim(const dim_t output_dim_size,
273 const dim_t dim_size, const dim_t dim_offset,
274 const dim_t end_comp_pad) {
275
276 const dim_t distance_to_ouput_end
277 = output_dim_size - (dim_offset + dim_size);
278
279 const dim_t output_created_from_pad = distance_to_ouput_end < end_comp_pad
280 ? end_comp_pad - distance_to_ouput_end
281 : 0u;
282
283 return dim_size - output_created_from_pad;
284}
285
286void apply_zp_src_comp_pad(const conv_gemm_conf_t &jcp, const dim_t g,
287 const dim_t d_offset, const dim_t h_offset, const dim_t w_offset,
288 const dim_t h_size, const dim_t w_size,
289 int32_t *__restrict gemm_conv_result,
290 const int32_t *__restrict zp_src_pad_buf) {
291
292 const auto &comp_pad = jcp.zp.src_pad_comp;
293 const dim_t lower_d_bound
294 = calculate_lower_bound_dim(0, comp_pad.front_pad);
295 const dim_t upper_d_bound
296 = calculate_upper_bound_dim(jcp.od, jcp.od, 0, comp_pad.back_pad);
297
298 const bool d_under_lower_bound = d_offset < lower_d_bound;
299 const bool d_over_eq_upper_bound = d_offset >= upper_d_bound;
300 const bool should_apply_zp_src_pad_comp_d
301 = d_under_lower_bound || d_over_eq_upper_bound;
302 const dim_t zp_pad_com_d = get_zp_pad_com_dim(d_under_lower_bound,
303 d_over_eq_upper_bound, comp_pad.front_pad, comp_pad.mid_d,
304 comp_pad.back_pad, jcp.od, d_offset);
305
306 const dim_t lower_h_bound
307 = calculate_lower_bound_dim(h_offset, comp_pad.top_pad);
308 const dim_t upper_h_bound = calculate_upper_bound_dim(
309 jcp.oh, h_size, h_offset, comp_pad.bottom_pad);
310 const dim_t lower_w_bound
311 = calculate_lower_bound_dim(w_offset, comp_pad.left_pad);
312 const dim_t upper_w_bound = calculate_upper_bound_dim(
313 jcp.ow, w_size, w_offset, comp_pad.right_pad);
314
315 parallel_nd(h_size, w_size, [=](const dim_t h, const dim_t w) {
316 const bool h_under_lower_bound = h < lower_h_bound;
317 const bool h_over_eq_upper_bound = h >= upper_h_bound;
318 const bool w_under_lower_bound = w < lower_w_bound;
319 const bool w_over_eq_upper_bound = w >= upper_w_bound;
320
321 const bool should_apply_zp_src_pad_comp = should_apply_zp_src_pad_comp_d
322 || w_under_lower_bound || w_over_eq_upper_bound
323 || h_under_lower_bound || h_over_eq_upper_bound;
324
325 if (!should_apply_zp_src_pad_comp) return;
326
327 const auto out_point_h = h_offset + h;
328 const auto out_point_w = w_offset + w;
329
330 const dim_t zp_pad_com_h = get_zp_pad_com_dim(h_under_lower_bound,
331 h_over_eq_upper_bound, comp_pad.top_pad, comp_pad.mid_h,
332 comp_pad.bottom_pad, jcp.oh, out_point_h);
333
334 const dim_t zp_pad_com_w = get_zp_pad_com_dim(w_under_lower_bound,
335 w_over_eq_upper_bound, comp_pad.left_pad, comp_pad.mid_w,
336 comp_pad.right_pad, jcp.ow, out_point_w);
337
338 const auto zp_src_comp_pad_off = zp_src_comp_pad_offset(
339 jcp, zp_pad_com_d, zp_pad_com_h, zp_pad_com_w, g);
340 const auto gemm_result_off = gemm_conv_result_offset(jcp, h, w);
341
342 append_zp_src_comp_pad(jcp, zp_src_pad_buf, zp_src_comp_pad_off,
343 gemm_conv_result, gemm_result_off);
344 });
345}
346
347} // namespace cpu
348} // namespace impl
349} // namespace dnnl
350