1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/cpu_primitive.hpp" |
23 | |
24 | #include "cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | namespace x64 { |
30 | |
31 | using namespace dnnl::impl::status; |
32 | using namespace dnnl::impl::memory_tracking::names; |
33 | using namespace dnnl::impl::utils; |
34 | |
35 | using namespace nstl; |
36 | |
37 | #define wht_blk_off(d, g, ...) \ |
38 | (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \ |
39 | : (d).blk_off(__VA_ARGS__)) |
40 | |
41 | const float *jit_avx512_core_x8s8s32x_convolution_fwd_t::adjust_oscales( |
42 | const memory_tracking::grantor_t &scratchpad, const float *src_scales, |
43 | const float *wei_scales) const { |
44 | auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales); |
45 | const float src_scale = src_scales[0]; |
46 | const int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; |
47 | float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni)) |
48 | ? 1.f / pd()->jcp_.wei_adj_scale |
49 | : 1.f; |
50 | switch (wei_mask) { |
51 | case 0: |
52 | utils::array_set(loc_scales, src_scale * wei_scales[0] * factor, |
53 | pd()->jcp_.simd_w); |
54 | break; |
55 | default: |
56 | for (dim_t c = 0; c < pd()->OC(); c++) |
57 | loc_scales[c] = src_scale * wei_scales[c] * factor; |
58 | } |
59 | return loc_scales; |
60 | } |
61 | |
62 | status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d( |
63 | const exec_ctx_t &ctx) const { |
64 | const auto &jcp = pd()->jcp_; |
65 | |
66 | auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
67 | auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS); |
68 | auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
69 | auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
70 | const auto post_ops_binary_rhs_arg_vec |
71 | = binary_injector::prepare_binary_args(jcp.post_ops, ctx); |
72 | |
73 | DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); |
74 | DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); |
75 | |
76 | const memory_desc_wrapper src_d(pd()->src_md()); |
77 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
78 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
79 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
80 | |
81 | const size_t bia_dt_size |
82 | = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0; |
83 | const size_t dst_dt_size = types::data_type_size(dst_d.data_type()); |
84 | |
85 | assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); |
86 | assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); |
87 | |
88 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
89 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
90 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
91 | |
92 | const float *oscales = adjust_oscales( |
93 | ctx.get_scratchpad_grantor(), src_scales, wei_scales); |
94 | |
95 | size_t |
96 | = weights_d.size() - weights_d.additional_buffer_size(); |
97 | size_t ch_offset = jcp.is_depthwise ? jcp.nb_ch * jcp.ch_block |
98 | : jcp.ngroups * jcp.oc; |
99 | auto w = const_cast<char *>(weights); |
100 | int32_t *compensation = (jcp.signed_input) |
101 | ? reinterpret_cast<int32_t *>(&w[extra_data_offset]) |
102 | : nullptr; |
103 | int32_t *zp_compensation = jcp.src_zero_point |
104 | ? reinterpret_cast<int32_t *>(&w[extra_data_offset]) |
105 | + (jcp.signed_input ? ch_offset : 0) |
106 | : nullptr; |
107 | |
108 | int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; |
109 | int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; |
110 | int group_block = jcp.ch_block; |
111 | int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow; |
112 | |
113 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
114 | int start {0}, end {0}; |
115 | balance211(work_amount, nthr, ithr, start, end); |
116 | |
117 | auto p = jit_conv_call_s(); |
118 | |
119 | int n {0}, gg {0}, occ {0}, owb {0}; |
120 | switch (jcp.loop_order) { |
121 | case loop_cwgn: |
122 | nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg, |
123 | nb_groups, n, jcp.mb); |
124 | break; |
125 | case loop_gncw: |
126 | nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, |
127 | oc_chunks, owb, jcp.nb_ow); |
128 | break; |
129 | case loop_ngcw: |
130 | nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, |
131 | oc_chunks, owb, jcp.nb_ow); |
132 | break; |
133 | case loop_nwcg: |
134 | nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, |
135 | oc_chunks, gg, nb_groups); |
136 | break; |
137 | default: assert(!"unsupported loop order" ); |
138 | } |
139 | while (start < end) { |
140 | int ocb = occ * jcp.nb_oc_blocking; |
141 | int gb = gg * jcp.nb_ch_blocking; |
142 | int g = gb * group_block; |
143 | int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; |
144 | int g_ic = g * jcp.nb_ic * jcp.ic_block; |
145 | int ow_s = owb * jcp.ow_block; |
146 | int iw_s = ow_s * jcp.stride_w; |
147 | |
148 | p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) |
149 | : nullptr; |
150 | p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr; |
151 | p.zp_compensation |
152 | = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; |
153 | p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr; |
154 | p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr; |
155 | p.dst_scale = dst_scales; |
156 | p.dst = dst + dst_dt_size * dst_d.blk_off(n, g_oc, ow_s); |
157 | p.src = src + src_d.blk_off(n, g_ic, iw_s); |
158 | p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0); |
159 | p.scales = &oscales[jcp.is_oc_scale * g_oc]; |
160 | p.oc_blocks = jcp.is_depthwise ? gb : ocb; |
161 | p.kh_padding = jcp.kh; |
162 | p.t_overflow = 0; |
163 | p.b_overflow = 0; |
164 | p.owb = owb; |
165 | |
166 | p.oc_l_off = (g * jcp.nb_oc + ocb) * jcp.oc_block; |
167 | p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); |
168 | p.dst_orig = dst; |
169 | (*kernel_)(&p); |
170 | |
171 | ++start; |
172 | switch (jcp.loop_order) { |
173 | case loop_cwgn: |
174 | nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg, |
175 | nb_groups, n, jcp.mb); |
176 | break; |
177 | case loop_gncw: |
178 | nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks, |
179 | owb, jcp.nb_ow); |
180 | break; |
181 | case loop_ngcw: |
182 | nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks, |
183 | owb, jcp.nb_ow); |
184 | break; |
185 | case loop_nwcg: |
186 | nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, |
187 | gg, nb_groups); |
188 | break; |
189 | default: assert(!"unsupported loop order" ); |
190 | } |
191 | } |
192 | }); |
193 | return status::success; |
194 | } |
195 | |
/* Forward 2D int8 convolution driver (non-depthwise: asserts ch_block == 1).
 * Threads split work over (mb, groups, oc chunks, oh, ow blocks); each item
 * computes top/bottom kernel overflow from padding and calls the JIT kernel
 * row by row. Returns status::success. */
status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    assert(jcp.ch_block == 1);
    assert(jcp.nb_ch_blocking == 1);
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Compensation data is appended after the actual weights payload.
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    // Signed-input compensation first; zero-point compensation follows it
    // when both are present.
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
    int nb_groups = jcp.nb_ch;
    int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();

        // Element strides (in memory-format units) along the height dim.
        size_t src_h_stride = src_d.blk_off(0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

        // Initialize the 5D iterator in the order chosen at pd() time.
        int n {0}, g {0}, occ {0}, oh_s {0}, owb {0};
        switch (jcp.loop_order) {
            case loop_cwgn:
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
                        nb_groups, n, jcp.mb, oh_s, jcp.oh);
                break;
            case loop_ngcw:
                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                        owb, jcp.nb_ow, oh_s, jcp.oh);
                break;
            case loop_nhwcg:
                nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                        occ, oc_chunks, g, nb_groups);
                break;
            default: assert(!"unsupported loop order" );
        }
        while (start < end) {
            // A thr-chunk may cover several oc blocking steps.
            for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
                    occ1 += jcp.nb_oc_blocking) {
                int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
                int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;

                int g_ic = g * jcp.nb_ic * jcp.ic_block;

                // Process up to work_rem output rows in one sweep (capped
                // at jcp.oh); loop_nhwcg advances one row at a time.
                int work_rem = end - start;
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    oh_e = oh_s + 1; // step instead
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;

                auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                   : nullptr;
                int32_t *compensation_w
                        = (jcp.signed_input) ? compensation + g_oc : nullptr;

                auto dst_w = dst
                        + dst_dt_size * dst_d.blk_off(n, g_oc, oh_s, ow_s);
                auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);

                auto scales = &oscales[jcp.is_oc_scale * g_oc];

                for (int oj = oh_s, ij = ih_s; oj < oh_e;
                        ++oj, ij += jcp.stride_h) {
                    // How many kernel rows fall above/below the input due
                    // to top/bottom padding for this output row.
                    int dilate_h = jcp.dilate_h + 1;
                    int i_t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ij), dilate_h));
                    int i_b_overflow = nstl::min(jcp.kh,
                            div_up(max(0,
                                           ij - jcp.ih + (jcp.kh - 1) * dilate_h
                                                   + 1),
                                    dilate_h));
                    int kh_padding = nstl::max(
                            0, jcp.kh - i_t_overflow - i_b_overflow);

                    // With signed input or src zero point the kernel walks
                    // all kh rows itself, so no weights offset is applied.
                    size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
                            ? 0
                            : i_t_overflow * wht_h_stride;
                    p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                    p.dst = dst_w;
                    p.filt = wht_w + wei_stride;
                    p.bias = bias_w;
                    p.compensation = compensation_w;
                    p.zp_compensation = jcp.src_zero_point
                            ? zp_compensation + g_oc
                            : nullptr;
                    p.src_zero_point
                            = jcp.src_zero_point ? src_zero_point : nullptr;
                    p.dst_zero_point
                            = jcp.dst_zero_point ? dst_zero_point : nullptr;
                    p.dst_scale = dst_scales;
                    p.oc_blocks = ocb;
                    p.kh_padding = kh_padding;
                    p.scales = scales;
                    p.t_overflow = i_t_overflow;
                    p.b_overflow = i_b_overflow;
                    p.owb = owb;

                    p.oc_l_off = (g * jcp.nb_oc + ocb) * jcp.oc_block;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;
                    (*kernel_)(&p);

                    // Advance pointers to the next output row.
                    src_w += src_h_stride * jcp.stride_h;
                    dst_w += dst_dt_size * dst_h_stride;
                }
            }
            // Jump past the rows consumed above (or single-step for nhwcg).
            switch (jcp.loop_order) {
                case loop_cwgn:
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            g, nb_groups, n, jcp.mb, oh_s, jcp.oh);
                    break;
                case loop_ngcw:
                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                            oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
                    break;
                case loop_nhwcg:
                    ++start;
                    nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                            occ, oc_chunks, g, nb_groups);
                    break;
                default: assert(!"unsupported loop order" );
            }
        }
    });
    return status::success;
}
364 | |
/* Forward 2D depthwise int8 convolution driver (asserts ic/oc blocks == 1).
 * Parallelizes over (mb, oh, ow blocks, group blocks) with parallel_nd;
 * each item computes padding overflow and calls the JIT kernel once.
 * Returns status::success. */
status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    // Depthwise specialization: one channel per ic/oc block.
    assert(jcp.ic_block == 1);
    assert(jcp.oc_block == 1);
    assert(jcp.nb_ic == 1);
    assert(jcp.nb_oc == 1);
    assert(jcp.nb_oc_blocking == 1);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Compensation data is appended after the actual weights payload;
    // signed-input compensation precedes zero-point compensation.
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.nb_ch * jcp.ch_block : 0)
            : nullptr;
    int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
    int group_block = jcp.ch_block;

    parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups,
            [&](dim_t n, dim_t oh_s, dim_t owb, dim_t gg) {
                auto p = jit_conv_call_s();

                // Element strides along the height dim.
                size_t src_h_stride = src_d.blk_off(0, 0, 1);
                size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

                int gb = gg * jcp.nb_ch_blocking;
                int g = gb * group_block;

                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;

                auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size)
                                   : nullptr;
                int32_t *compensation_w
                        = jcp.signed_input ? compensation + g : nullptr;

                auto dst_w
                        = dst + dst_dt_size * dst_d.blk_off(n, g, oh_s, ow_s);
                auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s);
                auto wht_w = weights + wht_blk_off(weights_d, gb, 0);

                auto scales = &oscales[jcp.is_oc_scale * g];

                // Kernel rows clipped away by top/bottom padding for this
                // output row.
                int dilate_h = jcp.dilate_h + 1;
                int i_t_overflow
                        = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h));
                int i_b_overflow = nstl::min(jcp.kh,
                        div_up(max(0,
                                       ih_s - jcp.ih + (jcp.kh - 1) * dilate_h
                                               + 1),
                                dilate_h));
                int kh_padding
                        = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);

                // With signed input or src zero point the kernel walks all
                // kh rows itself, so no weights offset is applied.
                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
                        ? 0
                        : i_t_overflow * wht_h_stride;
                p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                p.dst = dst_w;
                p.filt = wht_w + wei_stride;
                p.bias = bias_w;
                p.compensation = compensation_w;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g : nullptr;
                p.src_zero_point
                        = jcp.src_zero_point ? src_zero_point : nullptr;
                p.dst_zero_point
                        = jcp.dst_zero_point ? dst_zero_point : nullptr;
                p.dst_scale = dst_scales;
                p.oc_blocks = gb;
                p.kh_padding = kh_padding;
                p.scales = scales;
                p.t_overflow = i_t_overflow;
                p.b_overflow = i_b_overflow;
                p.owb = owb;

                p.oc_l_off = g * jcp.oc;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.dst_orig = dst;

                (*kernel_)(&p);
            });
    return status::success;
}
481 | |
/* Forward 3D int8 convolution driver (non-depthwise: asserts ch_block == 1).
 * Like the 2D path but additionally handles the depth dimension: front/back
 * kernel overflow is computed per output depth position, height overflow per
 * output row. Returns status::success. */
status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    assert(jcp.ch_block == 1);
    assert(jcp.nb_ch_blocking == 1);
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Compensation data is appended after the actual weights payload;
    // signed-input compensation precedes zero-point compensation.
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;
    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
    int nb_groups = jcp.nb_ch;
    int work_amount
            = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();

        // Element strides along depth and height dims.
        size_t src_d_stride = src_d.blk_off(0, 0, 1);
        size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

        // Initialize the 6D iterator in the order chosen at pd() time.
        int n {0}, g {0}, occ {0}, od_s {0}, oh_s {0}, owb {0};
        switch (jcp.loop_order) {
            case loop_cwgn:
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
                        nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
                break;
            case loop_ngcw:
                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                        owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
                break;
            case loop_nhwcg:
                nd_iterator_init(start, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh,
                        owb, jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                break;
            default: assert(!"unsupported loop order" );
        }
        while (start < end) {
            // A thr-chunk may cover several oc blocking steps.
            for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
                    occ1 += jcp.nb_oc_blocking) {
                int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
                int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;

                int g_ic = g * jcp.nb_ic * jcp.ic_block;

                // Process up to work_rem output rows in one sweep (capped
                // at jcp.oh); loop_nhwcg advances one row at a time.
                int work_rem = end - start;
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    oh_e = oh_s + 1; // step instead
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;
                // Kernel depth positions clipped by front/back padding for
                // this output depth index.
                int id_s = -jcp.f_pad + od_s * jcp.stride_d;
                int dilate_d = jcp.dilate_d + 1;
                int d_f_overflow
                        = nstl::min(jcp.kd, div_up(max(0, -id_s), dilate_d));
                int d_back_overflow = nstl::min(jcp.kd,
                        div_up(max(0,
                                       id_s - jcp.id + (jcp.kd - 1) * dilate_d
                                               + 1),
                                dilate_d));

                int kd_padding
                        = nstl::max(0, jcp.kd - d_f_overflow - d_back_overflow);

                auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                   : nullptr;
                int32_t *compensation_w
                        = (jcp.signed_input) ? compensation + g_oc : nullptr;

                auto dst_w = dst
                        + dst_dt_size
                                * dst_d.blk_off(n, g_oc, od_s, oh_s, ow_s);
                auto src_w = src + src_d.blk_off(n, g_ic, id_s, ih_s, iw_s)
                        + d_f_overflow * dilate_d * src_d_stride;
                // With signed input or src zero point the kernel walks all
                // kd positions itself, so no depth offset is applied.
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
                        + ((jcp.signed_input || jcp.src_zero_point)
                                          ? 0
                                          : d_f_overflow)
                                * wht_d_stride;

                auto scales = &oscales[jcp.is_oc_scale * g_oc];

                for (int oj = oh_s, ij = ih_s; oj < oh_e;
                        ++oj, ij += jcp.stride_h) {
                    // Kernel rows clipped by top/bottom padding for this
                    // output row.
                    int dilate_h = jcp.dilate_h + 1;
                    int i_t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ij), dilate_h));
                    int i_b_overflow = nstl::min(jcp.kh,
                            div_up(max(0,
                                           ij - jcp.ih + (jcp.kh - 1) * dilate_h
                                                   + 1),
                                    dilate_h));
                    int kh_padding = nstl::max(
                            0, jcp.kh - i_t_overflow - i_b_overflow);

                    size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
                            ? 0
                            : wht_h_stride * i_t_overflow;
                    p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                    p.dst = dst_w;
                    p.filt = wht_w + wei_stride;
                    p.bias = bias_w;
                    p.compensation = compensation_w;
                    p.zp_compensation = jcp.src_zero_point
                            ? zp_compensation + g_oc
                            : nullptr;
                    p.src_zero_point
                            = jcp.src_zero_point ? src_zero_point : nullptr;
                    p.dst_zero_point
                            = jcp.dst_zero_point ? dst_zero_point : nullptr;
                    p.dst_scale = dst_scales;
                    p.oc_blocks = ocb;
                    p.kh_padding = kh_padding;
                    p.kd_padding = kd_padding;
                    p.scales = scales;
                    p.t_overflow = i_t_overflow;
                    p.b_overflow = i_b_overflow;
                    p.f_overflow = d_f_overflow;
                    p.back_overflow = d_back_overflow;
                    p.owb = owb;

                    p.oc_l_off = (g * jcp.nb_oc + ocb) * jcp.oc_block;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;
                    (*kernel_)(&p);

                    // Advance pointers to the next output row.
                    src_w += src_h_stride * jcp.stride_h;
                    dst_w += dst_dt_size * dst_h_stride;
                }
            }
            // Jump past the rows consumed above (or single-step for nhwcg).
            switch (jcp.loop_order) {
                case loop_cwgn:
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            g, nb_groups, n, jcp.mb, od_s, jcp.od, oh_s,
                            jcp.oh);
                    break;
                case loop_ngcw:
                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                            oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s,
                            jcp.oh);
                    break;
                case loop_nhwcg:
                    ++start;
                    nd_iterator_step(n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
                            jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                    break;
                default: assert(!"unsupported loop order" );
            }
        }
    });
    return status::success;
}
675 | |
676 | } // namespace x64 |
677 | } // namespace cpu |
678 | } // namespace impl |
679 | } // namespace dnnl |
680 | |
681 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
682 | |