1 | /******************************************************************************* |
2 | * Copyright 2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <cmath> |
17 | #include <vector> |
18 | #include "cpu/gemm_x8s8s32x_conv_zp_src_pad_comp.hpp" |
19 | #if DNNL_X64 |
20 | #include "cpu/x64/jit_primitive_conf.hpp" |
21 | #endif |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | |
27 | static dim_t zp_src_comp_pad_offset(const conv_gemm_conf_t &jcp, |
28 | const dim_t zp_pad_com_d, const dim_t zp_pad_com_h, |
29 | const dim_t zp_pad_com_w, dim_t oc, dim_t g) { |
30 | return ((zp_pad_com_d * jcp.zp.src_pad_comp.h + zp_pad_com_h) |
31 | * jcp.zp.src_pad_comp.w |
32 | + zp_pad_com_w) |
33 | * jcp.oc * jcp.ngroups |
34 | + (g * jcp.oc + oc); |
35 | } |
36 | |
37 | static dim_t get_weights_offset(const memory_desc_wrapper &weights_md, |
38 | const bool with_groups, const dim_t kd, const dim_t kh, |
39 | const dim_t kw) { |
40 | auto ndims = weights_md.ndims(); |
41 | if (with_groups) ndims -= 1; |
42 | |
43 | switch (ndims) { |
44 | case 5: |
45 | return with_groups ? weights_md.blk_off(0, 0, 0, kd, kh, kw) |
46 | : weights_md.blk_off(0, 0, kd, kh, kw); |
47 | case 4: |
48 | return with_groups ? weights_md.blk_off(0, 0, 0, kh, kw) |
49 | : weights_md.blk_off(0, 0, kh, kw); |
50 | case 3: |
51 | return with_groups ? weights_md.blk_off(0, 0, 0, kw) |
52 | : weights_md.blk_off(0, 0, kw); |
53 | default: assert(!"unsupported ndims" ); return dim_t(0); |
54 | } |
55 | } |
56 | |
57 | static dim_t calculate_blk_size(const conv_gemm_conf_t &jcp) { |
58 | const auto number_of_threads = dnnl_get_max_threads(); |
59 | const auto number_of_tasks = jcp.zp.src_pad_comp.d * jcp.zp.src_pad_comp.h |
60 | * jcp.zp.src_pad_comp.w; |
61 | auto scaling_factor = number_of_threads / number_of_tasks; |
62 | const auto output_channels = jcp.oc * jcp.ngroups; |
63 | static constexpr dim_t min_blk_size |
64 | = platform::get_cache_line_size() / sizeof(int32_t); |
65 | |
66 | if (output_channels <= min_blk_size || scaling_factor <= 1) |
67 | return output_channels; |
68 | |
69 | const auto scaling_factor_threashold |
70 | = nstl::max(output_channels / (2 * min_blk_size), dim_t(1)); |
71 | if (scaling_factor > scaling_factor_threashold) { |
72 | scaling_factor = scaling_factor_threashold; |
73 | } |
74 | |
75 | if (const auto blk_size |
76 | = utils::rnd_up(output_channels / scaling_factor, min_blk_size)) { |
77 | return blk_size; |
78 | } |
79 | |
80 | return output_channels; |
81 | } |
82 | |
83 | static void append_weights_to_comp_pad_buf(const conv_gemm_conf_t &jcp, |
84 | int32_t *const __restrict zp_src_pad_comp, |
85 | const int8_t *__restrict weights, dim_t weights_offset, |
86 | const dim_t start_oc_blk, const dim_t end_oc_blk) { |
87 | const auto output_channels = jcp.oc * jcp.ngroups; |
88 | |
89 | for (dim_t it_ic = 0; it_ic < jcp.ic; ++it_ic) { |
90 | for (dim_t oc_off = start_oc_blk; oc_off < end_oc_blk; ++oc_off) { |
91 | zp_src_pad_comp[oc_off] |
92 | += static_cast<int32_t>(weights[weights_offset + oc_off]); |
93 | } |
94 | |
95 | weights_offset += output_channels; |
96 | } |
97 | } |
98 | |
// Maps an index of the compensation buffer along one dimension to the source
// coordinate of the filter's first tap (its "corner") along that dimension.
//
// The buffer is laid out as `begin_comp_pad` leading slots, an optional
// middle slot (present when `mid_comp_pad` is true) and the trailing slots:
//  - a leading slot `i` corresponds to output position `i`;
//  - the middle slot is given corner 0;
//  - trailing slots correspond to the last `end_comp_pad` output positions.
// For an output position `o` the corner is `o * stride_dim - input_begin_pad`.
static dim_t calc_filter_corner_dim(const dim_t it_zp_buf_dim,
        const dim_t &dim_size, const dim_t &input_begin_pad,
        const dim_t &stride_dim, const dim_t &begin_comp_pad,
        const bool &mid_comp_pad, const dim_t &end_comp_pad) {

    const bool is_leading_slot = it_zp_buf_dim < begin_comp_pad;
    if (is_leading_slot) return it_zp_buf_dim * stride_dim - input_begin_pad;

    const bool is_middle_slot
            = mid_comp_pad && it_zp_buf_dim == begin_comp_pad;
    if (is_middle_slot) return 0;

    // Trailing slot: find its rank among the trailing slots and translate it
    // to the matching output position counted from the end of the dimension.
    const dim_t slots_before_trailing
            = begin_comp_pad + (mid_comp_pad ? 1 : 0);
    const dim_t trailing_rank = it_zp_buf_dim - slots_before_trailing;
    const dim_t out_position = (dim_size - end_comp_pad) + trailing_rank;

    return out_position * stride_dim - input_begin_pad;
}
114 | |
115 | void compute_zp_src_comp_pad(const conv_gemm_conf_t &jcp, |
116 | int32_t *const zp_src_pad_buf, const int32_t *const zp_src, |
117 | const int8_t *weights, const memory_desc_wrapper &weights_md, |
118 | const bool with_groups) { |
119 | |
120 | const dim_t blk_size = calculate_blk_size(jcp); |
121 | const dim_t output_channels = jcp.oc * jcp.ngroups; |
122 | const dim_t oc_blks = utils::div_up(output_channels, blk_size); |
123 | |
124 | const auto compute_zp_src_pad_buf = [&](const dim_t zp_pad_com_d, |
125 | const dim_t zp_pad_com_h, |
126 | const dim_t zp_pad_com_w, |
127 | const dim_t filter_corner_src_d, |
128 | const dim_t filter_corner_src_h, |
129 | const dim_t filter_corner_src_w, |
130 | const dim_t oc_blk) { |
131 | const auto start_blk = oc_blk * blk_size; |
132 | const auto end_blk = nstl::min(start_blk + blk_size, output_channels); |
133 | const auto size = end_blk - start_blk; |
134 | const auto zp_pad_offset = zp_src_comp_pad_offset( |
135 | jcp, zp_pad_com_d, zp_pad_com_h, zp_pad_com_w, 0, 0); |
136 | int32_t *const __restrict zp_src_pad_comp |
137 | = zp_src_pad_buf + zp_pad_offset; |
138 | |
139 | std::memset(zp_src_pad_comp + start_blk, 0, size * sizeof(int32_t)); |
140 | |
141 | const auto dilate_scale_d = jcp.dilate_d + 1; |
142 | const auto dilate_scale_h = jcp.dilate_h + 1; |
143 | const auto dilate_scale_w = jcp.dilate_w + 1; |
144 | |
145 | for (int it_kd = 0; it_kd < jcp.kd; it_kd++) { |
146 | const int filter_point_d = it_kd * dilate_scale_d; |
147 | const int filter_point_src_d = filter_corner_src_d + filter_point_d; |
148 | const bool filter_point_srd_d_pad |
149 | = filter_point_src_d < 0 || filter_point_src_d >= jcp.id; |
150 | |
151 | for (int it_kh = 0; it_kh < jcp.kh; it_kh++) { |
152 | const int filter_point_h = it_kh * dilate_scale_h; |
153 | const int filter_point_src_h |
154 | = filter_corner_src_h + filter_point_h; |
155 | const bool filter_point_srd_h_pad = filter_point_src_h < 0 |
156 | || filter_point_src_h >= jcp.ih; |
157 | |
158 | for (int it_kw = 0; it_kw < jcp.kw; it_kw++) { |
159 | const int filter_point_w = it_kw * dilate_scale_w; |
160 | const int filter_point_src_w |
161 | = filter_corner_src_w + filter_point_w; |
162 | |
163 | if (filter_point_srd_d_pad || filter_point_srd_h_pad |
164 | || filter_point_src_w < 0 |
165 | || filter_point_src_w >= jcp.iw) { |
166 | const auto weights_offset = get_weights_offset( |
167 | weights_md, with_groups, it_kd, it_kh, it_kw); |
168 | append_weights_to_comp_pad_buf(jcp, zp_src_pad_comp, |
169 | weights, weights_offset, start_blk, end_blk); |
170 | } |
171 | } |
172 | } |
173 | } |
174 | |
175 | if (jcp.zp.src_is_common) { |
176 | const int32_t zp_src_val = *zp_src; |
177 | for (auto oc_off = start_blk; oc_off < end_blk; ++oc_off) |
178 | zp_src_pad_comp[oc_off] *= zp_src_val; |
179 | } else { |
180 | for (auto oc_off = start_blk; oc_off < end_blk; ++oc_off) |
181 | zp_src_pad_comp[oc_off] *= zp_src[oc_off]; |
182 | } |
183 | }; |
184 | |
185 | const auto compute_zp_buf_w = [&](dim_t it_zp_buf_d, dim_t it_zp_buf_h, |
186 | dim_t it_zp_buf_w, |
187 | dim_t filter_corner_src_d, |
188 | dim_t filter_corner_src_h, |
189 | const dim_t oc_blk) { |
190 | const int filter_corner_src_w = calc_filter_corner_dim(it_zp_buf_w, |
191 | jcp.ow, jcp.l_pad, jcp.stride_w, jcp.zp.src_pad_comp.left_pad, |
192 | jcp.zp.src_pad_comp.mid_w, jcp.zp.src_pad_comp.right_pad); |
193 | compute_zp_src_pad_buf(it_zp_buf_d, it_zp_buf_h, it_zp_buf_w, |
194 | filter_corner_src_d, filter_corner_src_h, filter_corner_src_w, |
195 | oc_blk); |
196 | }; |
197 | |
198 | const auto compute_zp_buf_h = [&](dim_t it_zp_buf_d, dim_t it_zp_buf_h, |
199 | dim_t it_zp_buf_w, |
200 | dim_t filter_corner_src_d, |
201 | const dim_t oc_blk) { |
202 | const auto filter_corner_src_h = calc_filter_corner_dim(it_zp_buf_h, |
203 | jcp.oh, jcp.t_pad, jcp.stride_h, jcp.zp.src_pad_comp.top_pad, |
204 | jcp.zp.src_pad_comp.mid_h, jcp.zp.src_pad_comp.bottom_pad); |
205 | |
206 | compute_zp_buf_w(it_zp_buf_d, it_zp_buf_h, it_zp_buf_w, |
207 | filter_corner_src_d, filter_corner_src_h, oc_blk); |
208 | }; |
209 | |
210 | parallel_nd(jcp.zp.src_pad_comp.d, jcp.zp.src_pad_comp.h, |
211 | jcp.zp.src_pad_comp.w, oc_blks, |
212 | [&](const dim_t it_zp_buf_d, const dim_t it_zp_buf_h, |
213 | const dim_t it_zp_buf_w, const dim_t oc_blk) { |
214 | const int filter_corner_src_d |
215 | = calc_filter_corner_dim(it_zp_buf_d, jcp.od, jcp.f_pad, |
216 | jcp.stride_d, jcp.zp.src_pad_comp.front_pad, |
217 | jcp.zp.src_pad_comp.mid_d, |
218 | jcp.zp.src_pad_comp.back_pad); |
219 | |
220 | compute_zp_buf_h(it_zp_buf_d, it_zp_buf_h, it_zp_buf_w, |
221 | filter_corner_src_d, oc_blk); |
222 | }); |
223 | } |
224 | |
225 | static dim_t zp_src_comp_pad_offset(const conv_gemm_conf_t &jcp, |
226 | const dim_t zp_pad_com_d, const dim_t zp_pad_com_h, |
227 | const dim_t zp_pad_com_w, const dim_t g) { |
228 | return zp_src_comp_pad_offset( |
229 | jcp, zp_pad_com_d, zp_pad_com_h, zp_pad_com_w, 0, g); |
230 | } |
231 | |
232 | static dim_t gemm_conv_result_offset( |
233 | const conv_gemm_conf_t &jcp, const dim_t h, const dim_t w) { |
234 | return (h * jcp.ow + w) * jcp.oc; |
235 | } |
236 | |
237 | static void append_zp_src_comp_pad(const conv_gemm_conf_t &jcp, |
238 | const int32_t *__restrict zp_src_pad_comp, |
239 | const dim_t zp_src_comp_pad_offset, |
240 | int32_t *__restrict gemm_conv_result, |
241 | const dim_t gemm_conv_result_offset) { |
242 | |
243 | const int32_t *const __restrict zp_src_pad_comp_h_w |
244 | = zp_src_pad_comp + zp_src_comp_pad_offset; |
245 | int32_t *const __restrict gemm_conv_result_h_w |
246 | = gemm_conv_result + gemm_conv_result_offset; |
247 | const std::ptrdiff_t oc = jcp.oc; |
248 | |
249 | for (std::ptrdiff_t oc_off = 0; oc_off < oc; ++oc_off) |
250 | gemm_conv_result_h_w[oc_off] += zp_src_pad_comp_h_w[oc_off]; |
251 | } |
252 | |
// Selects, along one dimension, the compensation-buffer index that belongs to
// output position `out_point_dim`.
//
//  - below the lower bound: the output position itself is the index;
//  - at/over the upper bound: an index among the trailing slots, located
//    after the `begin_pad` leading slots and the optional middle slot;
//  - otherwise: index `begin_pad`.
// The lower-bound case wins if both flags are set.
static dim_t get_zp_pad_com_dim(const bool dim_under_lower_bound,
        const bool dim_over_eq_upper_bound, const dim_t begin_pad, bool mid_pad,
        const dim_t end_pad, const dim_t out_dim_size,
        const dim_t out_point_dim) {

    if (dim_under_lower_bound) return out_point_dim;
    if (!dim_over_eq_upper_bound) return begin_pad;

    // How far the point is from the end of the output dimension.
    const dim_t distance_from_end = out_dim_size - out_point_dim;
    return begin_pad + mid_pad + (end_pad - distance_from_end);
}
266 | |
// Returns how many positions at the start of a chunk beginning at
// `dim_offset` still fall inside the first `begin_comp_pad` positions of the
// dimension; 0 when the chunk starts at or past `begin_comp_pad`.
dim_t calculate_lower_bound_dim(
        const dim_t dim_offset, const dim_t begin_comp_pad) {
    if (dim_offset >= begin_comp_pad) return 0;
    return begin_comp_pad - dim_offset;
}
271 | |
// Returns the chunk-local index from which positions of a chunk
// [dim_offset, dim_offset + dim_size) fall inside the last `end_comp_pad`
// positions of an output dimension of size `output_dim_size`.
// Equals `dim_size` when the chunk ends before that trailing region.
dim_t calculate_upper_bound_dim(const dim_t output_dim_size,
        const dim_t dim_size, const dim_t dim_offset,
        const dim_t end_comp_pad) {

    const dim_t chunk_end = dim_offset + dim_size;
    const dim_t positions_after_chunk = output_dim_size - chunk_end;

    // The chunk stops short of the trailing region.
    if (positions_after_chunk >= end_comp_pad) return dim_size;

    const dim_t positions_in_trailing_region
            = end_comp_pad - positions_after_chunk;
    return dim_size - positions_in_trailing_region;
}
285 | |
286 | void apply_zp_src_comp_pad(const conv_gemm_conf_t &jcp, const dim_t g, |
287 | const dim_t d_offset, const dim_t h_offset, const dim_t w_offset, |
288 | const dim_t h_size, const dim_t w_size, |
289 | int32_t *__restrict gemm_conv_result, |
290 | const int32_t *__restrict zp_src_pad_buf) { |
291 | |
292 | const auto &comp_pad = jcp.zp.src_pad_comp; |
293 | const dim_t lower_d_bound |
294 | = calculate_lower_bound_dim(0, comp_pad.front_pad); |
295 | const dim_t upper_d_bound |
296 | = calculate_upper_bound_dim(jcp.od, jcp.od, 0, comp_pad.back_pad); |
297 | |
298 | const bool d_under_lower_bound = d_offset < lower_d_bound; |
299 | const bool d_over_eq_upper_bound = d_offset >= upper_d_bound; |
300 | const bool should_apply_zp_src_pad_comp_d |
301 | = d_under_lower_bound || d_over_eq_upper_bound; |
302 | const dim_t zp_pad_com_d = get_zp_pad_com_dim(d_under_lower_bound, |
303 | d_over_eq_upper_bound, comp_pad.front_pad, comp_pad.mid_d, |
304 | comp_pad.back_pad, jcp.od, d_offset); |
305 | |
306 | const dim_t lower_h_bound |
307 | = calculate_lower_bound_dim(h_offset, comp_pad.top_pad); |
308 | const dim_t upper_h_bound = calculate_upper_bound_dim( |
309 | jcp.oh, h_size, h_offset, comp_pad.bottom_pad); |
310 | const dim_t lower_w_bound |
311 | = calculate_lower_bound_dim(w_offset, comp_pad.left_pad); |
312 | const dim_t upper_w_bound = calculate_upper_bound_dim( |
313 | jcp.ow, w_size, w_offset, comp_pad.right_pad); |
314 | |
315 | parallel_nd(h_size, w_size, [=](const dim_t h, const dim_t w) { |
316 | const bool h_under_lower_bound = h < lower_h_bound; |
317 | const bool h_over_eq_upper_bound = h >= upper_h_bound; |
318 | const bool w_under_lower_bound = w < lower_w_bound; |
319 | const bool w_over_eq_upper_bound = w >= upper_w_bound; |
320 | |
321 | const bool should_apply_zp_src_pad_comp = should_apply_zp_src_pad_comp_d |
322 | || w_under_lower_bound || w_over_eq_upper_bound |
323 | || h_under_lower_bound || h_over_eq_upper_bound; |
324 | |
325 | if (!should_apply_zp_src_pad_comp) return; |
326 | |
327 | const auto out_point_h = h_offset + h; |
328 | const auto out_point_w = w_offset + w; |
329 | |
330 | const dim_t zp_pad_com_h = get_zp_pad_com_dim(h_under_lower_bound, |
331 | h_over_eq_upper_bound, comp_pad.top_pad, comp_pad.mid_h, |
332 | comp_pad.bottom_pad, jcp.oh, out_point_h); |
333 | |
334 | const dim_t zp_pad_com_w = get_zp_pad_com_dim(w_under_lower_bound, |
335 | w_over_eq_upper_bound, comp_pad.left_pad, comp_pad.mid_w, |
336 | comp_pad.right_pad, jcp.ow, out_point_w); |
337 | |
338 | const auto zp_src_comp_pad_off = zp_src_comp_pad_offset( |
339 | jcp, zp_pad_com_d, zp_pad_com_h, zp_pad_com_w, g); |
340 | const auto gemm_result_off = gemm_conv_result_offset(jcp, h, w); |
341 | |
342 | append_zp_src_comp_pad(jcp, zp_src_pad_buf, zp_src_comp_pad_off, |
343 | gemm_conv_result, gemm_result_off); |
344 | }); |
345 | } |
346 | |
347 | } // namespace cpu |
348 | } // namespace impl |
349 | } // namespace dnnl |
350 | |