1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/cpu_primitive.hpp" |
23 | |
24 | #include "cpu/x64/jit_uni_x8s8s32x_convolution.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | namespace x64 { |
30 | |
31 | using namespace dnnl::impl::status; |
32 | using namespace dnnl::impl::memory_tracking::names; |
33 | using namespace dnnl::impl::utils; |
34 | using namespace data_type; |
35 | |
36 | using namespace nstl; |
37 | |
// Offset into the weights memory descriptor `d`: prepends the group index
// `g` only when the primitive descriptor reports grouped weights.
#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))
41 | |
42 | template <cpu_isa_t isa> |
43 | const float *jit_uni_x8s8s32x_convolution_fwd_t<isa>::adjust_oscales( |
44 | const memory_tracking::grantor_t &scratchpad, const float *src_scales, |
45 | const float *wei_scales) const { |
46 | auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales); |
47 | int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; |
48 | float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni)) |
49 | ? 1.f / pd()->jcp_.wei_adj_scale |
50 | : 1.0f; |
51 | if (wei_mask == 0) { |
52 | utils::array_set(loc_scales, src_scales[0] * wei_scales[0] * factor, 8); |
53 | } else { |
54 | for (dim_t c = 0; c < pd()->OC(); c++) |
55 | loc_scales[c] = src_scales[0] * wei_scales[c] * factor; |
56 | } |
57 | return loc_scales; |
58 | } |
59 | |
// Forward 2D int8 convolution driver: partitions (mb, groups, oc-chunks,
// oh, ow-blocks) across threads and invokes the JIT kernel per output row.
template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    // Raw byte pointers; actual data types are resolved through the memory
    // descriptor wrappers below.
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    // This (non-depthwise) path expects channel blocking to be trivial.
    assert(jcp.ch_block == 1);
    assert(jcp.nb_ch_blocking == 1);
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Combined src*wei scales, adjusted for non-VNNI signed input.
    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Compensation data lives in the extra buffer appended after the
    // reordered weights: first (optionally) the signed-input compensation,
    // then (optionally) the src-zero-point compensation.
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;
    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
    int nb_groups = jcp.nb_ch;
    int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();

        // Strides (in elements) for stepping one row in src/dst/weights.
        size_t src_h_stride = src_d.blk_off(0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

        // Decompose the linear work index per the configured loop order.
        int n {0}, g {0}, occ {0}, oh_s {0}, owb {0};
        switch (jcp.loop_order) {
            case loop_cwgn:
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
                        nb_groups, n, jcp.mb, oh_s, jcp.oh);
                break;
            case loop_ngcw:
                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                        owb, jcp.nb_ow, oh_s, jcp.oh);
                break;
            case loop_nhwcg:
                nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                        occ, oc_chunks, g, nb_groups);
                break;
            default: assert(!"unsupported loop order" );
        }
        while (start < end) {
            // Iterate oc-chunk-internal blocks sequentially per work item.
            for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
                    occ1 += jcp.nb_oc_blocking) {
                int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
                int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;

                int g_ic = g * jcp.nb_ic * jcp.ic_block;

                int work_rem = end - start;
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    oh_e = oh_s + 1; // step instead
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;

                auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                   : nullptr;
                const int32_t *compensation_w
                        = (jcp.signed_input) ? compensation + g_oc : nullptr;

                auto dst_w = dst
                        + dst_dt_size * dst_d.blk_off(n, g_oc, oh_s, ow_s);
                auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);

                auto scales = &oscales[jcp.is_oc_scale * g_oc];

                // Process output rows [oh_s, oh_e); ij tracks the matching
                // (possibly negative, i.e. padded) input row.
                for (int oj = oh_s, ij = ih_s; oj < oh_e;
                        ++oj, ij += jcp.stride_h) {
                    int dilate_h = jcp.dilate_h + 1;
                    // Number of kernel rows falling into top/bottom padding.
                    int i_t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ij), dilate_h));
                    int i_b_overflow = nstl::min(jcp.kh,
                            div_up(max(0,
                                           ij - jcp.ih + (jcp.kh - 1) * dilate_h
                                                   + 1),
                                    dilate_h));
                    int kh_padding = nstl::max(
                            0, jcp.kh - i_t_overflow - i_b_overflow);

                    // With signed input / src zero point the kernel gets the
                    // full filter and the overflow counts; otherwise skip the
                    // padded filter rows here.
                    const size_t wei_stride
                            = (jcp.signed_input || jcp.src_zero_point)
                            ? 0
                            : i_t_overflow * wht_h_stride;
                    p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                    p.dst = dst_w;
                    p.filt = wht_w + wei_stride;
                    p.bias = bias_w;
                    p.compensation = compensation_w;
                    p.zp_compensation = jcp.src_zero_point
                            ? zp_compensation + g_oc
                            : nullptr;
                    p.src_zero_point
                            = jcp.src_zero_point ? src_zero_point : nullptr;
                    p.dst_zero_point
                            = jcp.dst_zero_point ? dst_zero_point : nullptr;
                    p.oc_blocks = ocb;
                    p.kh_padding = kh_padding;
                    p.scales = scales;
                    p.dst_scale = dst_scales;
                    p.t_overflow = i_t_overflow;
                    p.b_overflow = i_b_overflow;
                    p.owb = owb;

                    p.oc_l_off = g_oc;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;

                    (*kernel_)(&p);
                    // Advance to the next output row / strided input row.
                    src_w += src_h_stride * jcp.stride_h;
                    dst_w += dst_dt_size * dst_h_stride;
                }
            }
            // Advance the work-sharing iterator past the rows just done.
            switch (jcp.loop_order) {
                case loop_cwgn:
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            g, nb_groups, n, jcp.mb, oh_s, jcp.oh);
                    break;
                case loop_ngcw:
                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                            oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
                    break;
                case loop_nhwcg:
                    ++start;
                    nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                            occ, oc_chunks, g, nb_groups);
                    break;
                default: assert(!"unsupported loop order" );
            }
        }
    });
    return status::success;
}
229 | |
230 | template <cpu_isa_t isa> |
231 | status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d( |
232 | const exec_ctx_t &ctx) const { |
233 | const auto &jcp = pd()->jcp_; |
234 | auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
235 | auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS); |
236 | auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
237 | auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
238 | const auto post_ops_binary_rhs_arg_vec |
239 | = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); |
240 | |
241 | DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); |
242 | DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); |
243 | |
244 | const memory_desc_wrapper src_d(pd()->src_md()); |
245 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
246 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
247 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
248 | |
249 | const size_t bia_dt_size |
250 | = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0; |
251 | const size_t dst_dt_size = types::data_type_size(dst_d.data_type()); |
252 | |
253 | assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); |
254 | assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); |
255 | |
256 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
257 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
258 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
259 | |
260 | const float *oscales = adjust_oscales( |
261 | ctx.get_scratchpad_grantor(), src_scales, wei_scales); |
262 | |
263 | size_t |
264 | = weights_d.size() - weights_d.additional_buffer_size(); |
265 | size_t ch_offset = jcp.is_depthwise ? jcp.nb_ch * jcp.ch_block |
266 | : jcp.ngroups * jcp.oc; |
267 | auto w = const_cast<char *>(weights); |
268 | const int32_t *compensation = (jcp.signed_input) |
269 | ? reinterpret_cast<int32_t *>(&w[extra_data_offset]) |
270 | : nullptr; |
271 | const int32_t *zp_compensation = jcp.src_zero_point |
272 | ? reinterpret_cast<int32_t *>(&w[extra_data_offset]) |
273 | + (jcp.signed_input ? ch_offset : 0) |
274 | : nullptr; |
275 | |
276 | int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; |
277 | int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; |
278 | int group_block = jcp.ch_block; |
279 | int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow; |
280 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
281 | int start {0}, end {0}; |
282 | balance211(work_amount, nthr, ithr, start, end); |
283 | |
284 | auto p = jit_conv_call_s(); |
285 | |
286 | int n {0}, gg {0}, occ {0}, owb {0}; |
287 | switch (jcp.loop_order) { |
288 | case loop_cwgn: |
289 | nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg, |
290 | nb_groups, n, jcp.mb); |
291 | break; |
292 | case loop_gncw: |
293 | nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, |
294 | oc_chunks, owb, jcp.nb_ow); |
295 | break; |
296 | case loop_ngcw: |
297 | nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, |
298 | oc_chunks, owb, jcp.nb_ow); |
299 | break; |
300 | case loop_nwcg: |
301 | nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, |
302 | oc_chunks, gg, nb_groups); |
303 | break; |
304 | default: assert(!"unsupported loop order" ); |
305 | } |
306 | while (start < end) { |
307 | int ocb = occ * jcp.nb_oc_blocking; |
308 | int gb = gg * jcp.nb_ch_blocking; |
309 | int g = gb * group_block; |
310 | int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; |
311 | int g_ic = g * jcp.nb_ic * jcp.ic_block; |
312 | int ow_s = owb * jcp.ow_block; |
313 | int iw_s = ow_s * jcp.stride_w; |
314 | |
315 | p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) |
316 | : nullptr; |
317 | p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr; |
318 | p.zp_compensation |
319 | = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; |
320 | p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr; |
321 | p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr; |
322 | p.dst = dst + dst_dt_size * dst_d.blk_off(n, g_oc, ow_s); |
323 | p.src = src + src_d.blk_off(n, g_ic, iw_s); |
324 | p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0); |
325 | p.scales = &oscales[jcp.is_oc_scale * g_oc]; |
326 | p.dst_scale = dst_scales; |
327 | p.oc_blocks = jcp.is_depthwise ? gb : ocb; |
328 | p.kh_padding = jcp.kh; |
329 | p.t_overflow = 0; |
330 | p.b_overflow = 0; |
331 | p.owb = owb; |
332 | |
333 | p.oc_l_off = g_oc; |
334 | p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); |
335 | p.dst_orig = dst; |
336 | |
337 | (*kernel_)(&p); |
338 | |
339 | ++start; |
340 | switch (jcp.loop_order) { |
341 | case loop_cwgn: |
342 | nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg, |
343 | nb_groups, n, jcp.mb); |
344 | break; |
345 | case loop_gncw: |
346 | nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks, |
347 | owb, jcp.nb_ow); |
348 | break; |
349 | case loop_ngcw: |
350 | nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks, |
351 | owb, jcp.nb_ow); |
352 | break; |
353 | case loop_nwcg: |
354 | nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, |
355 | gg, nb_groups); |
356 | break; |
357 | default: assert(!"unsupported loop order" ); |
358 | } |
359 | } |
360 | }); |
361 | return status::success; |
362 | } |
363 | |
// Forward 2D depthwise int8 convolution driver: one kernel call per
// (mb, oh, ow-block, group-chunk) point via parallel_nd.
template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d_dw(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    // Raw byte pointers; actual data types are resolved through the memory
    // descriptor wrappers below.
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    // Depthwise layout: channel blocking only, no separate ic/oc blocking.
    assert(jcp.ic_block == 1);
    assert(jcp.oc_block == 1);
    assert(jcp.nb_ic == 1);
    assert(jcp.nb_oc == 1);
    assert(jcp.nb_oc_blocking == 1);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Combined src*wei scales, adjusted for non-VNNI signed input.
    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Compensation data lives in the extra buffer appended after the
    // reordered weights: first (optionally) the signed-input compensation,
    // then (optionally) the src-zero-point compensation.
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.nb_ch * jcp.ch_block : 0)
            : nullptr;
    int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
    int group_block = jcp.ch_block;

    parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups,
            [&](dim_t n, dim_t oh_s, dim_t owb, dim_t gg) {
                auto p = jit_conv_call_s();

                // Strides (in elements) for one row of src / weights.
                size_t src_h_stride = src_d.blk_off(0, 0, 1);
                size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

                int gb = gg * jcp.nb_ch_blocking;
                int g = gb * group_block;

                // Input coordinates matching this output point (may land in
                // the padded region for ih_s).
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;

                auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size)
                                   : nullptr;
                const int32_t *compensation_w
                        = jcp.signed_input ? compensation + g : nullptr;

                auto dst_w
                        = dst + dst_dt_size * dst_d.blk_off(n, g, oh_s, ow_s);
                auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s);
                auto wht_w = weights + wht_blk_off(weights_d, gb, 0);

                auto scales = &oscales[jcp.is_oc_scale * g];

                // Number of kernel rows falling into top/bottom padding.
                int dilate_h = jcp.dilate_h + 1;
                int i_t_overflow
                        = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h));
                int i_b_overflow = nstl::min(jcp.kh,
                        div_up(max(0,
                                       ih_s - jcp.ih + (jcp.kh - 1) * dilate_h
                                               + 1),
                                dilate_h));
                int kh_padding
                        = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);

                // With signed input / src zero point the kernel gets the full
                // filter and the overflow counts; otherwise skip the padded
                // filter rows here.
                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
                        ? 0
                        : i_t_overflow * wht_h_stride;
                p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                p.dst = dst_w;
                p.filt = wht_w + wei_stride;
                p.bias = bias_w;
                p.compensation = compensation_w;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g : nullptr;
                p.src_zero_point
                        = jcp.src_zero_point ? src_zero_point : nullptr;
                p.dst_zero_point
                        = jcp.dst_zero_point ? dst_zero_point : nullptr;
                p.oc_blocks = gb;
                p.kh_padding = kh_padding;
                p.scales = scales;
                p.dst_scale = dst_scales;
                p.t_overflow = i_t_overflow;
                p.b_overflow = i_b_overflow;
                p.owb = owb;

                p.oc_l_off = g;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.dst_orig = dst;

                (*kernel_)(&p);
            });
    return status::success;
}
481 | |
// Forward 3D int8 convolution driver: like the 2D path but with an extra
// depth dimension, handling front/back padding per depth position.
template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    // Raw byte pointers; actual data types are resolved through the memory
    // descriptor wrappers below.
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    // This (non-depthwise) path expects channel blocking to be trivial.
    assert(jcp.ch_block == 1);
    assert(jcp.nb_ch_blocking == 1);
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Combined src*wei scales, adjusted for non-VNNI signed input.
    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Compensation data lives in the extra buffer appended after the
    // reordered weights: first (optionally) the signed-input compensation,
    // then (optionally) the src-zero-point compensation.
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;
    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
    int nb_groups = jcp.nb_ch;
    int work_amount
            = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();

        // Strides (in elements) for one depth slice / one row.
        size_t src_d_stride = src_d.blk_off(0, 0, 1);
        size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

        // Decompose the linear work index per the configured loop order.
        int n {0}, g {0}, occ {0}, od_s {0}, oh_s {0}, owb {0};
        switch (jcp.loop_order) {
            case loop_cwgn:
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
                        nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
                break;
            case loop_ngcw:
                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                        owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
                break;
            case loop_nhwcg:
                nd_iterator_init(start, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh,
                        owb, jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                break;
            default: assert(!"unsupported loop order" );
        }
        while (start < end) {
            // Iterate oc-chunk-internal blocks sequentially per work item.
            for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
                    occ1 += jcp.nb_oc_blocking) {
                int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
                int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;

                int g_ic = g * jcp.nb_ic * jcp.ic_block;

                int work_rem = end - start;
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    oh_e = oh_s + 1; // step instead
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;
                // Depth handling: kernel slices falling into front/back
                // padding for this output depth position.
                int id_s = -jcp.f_pad + od_s * jcp.stride_d;
                int dilate_d = jcp.dilate_d + 1;
                int d_f_overflow
                        = nstl::min(jcp.kd, div_up(max(0, -id_s), dilate_d));
                int d_back_overflow = nstl::min(jcp.kd,
                        div_up(max(0,
                                       id_s - jcp.id + (jcp.kd - 1) * dilate_d
                                               + 1),
                                dilate_d));

                int kd_padding
                        = nstl::max(0, jcp.kd - d_f_overflow - d_back_overflow);

                auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                   : nullptr;
                const int32_t *compensation_w
                        = (jcp.signed_input) ? compensation + g_oc : nullptr;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
                p.src_zero_point
                        = jcp.src_zero_point ? src_zero_point : nullptr;
                p.dst_zero_point
                        = jcp.dst_zero_point ? dst_zero_point : nullptr;

                auto dst_w = dst
                        + dst_dt_size
                                * dst_d.blk_off(n, g_oc, od_s, oh_s, ow_s);
                // Skip the front-padded input depth slices up front.
                auto src_w = src + src_d.blk_off(n, g_ic, id_s, ih_s, iw_s)
                        + d_f_overflow * dilate_d * src_d_stride;
                // With signed input / src zero point the kernel gets the full
                // filter; otherwise skip the front-padded filter slices.
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
                        + ((jcp.signed_input || jcp.src_zero_point)
                                          ? 0
                                          : d_f_overflow)
                                * wht_d_stride;

                auto scales = &oscales[jcp.is_oc_scale * g_oc];

                // Process output rows [oh_s, oh_e); ij tracks the matching
                // (possibly negative, i.e. padded) input row.
                for (int oj = oh_s, ij = ih_s; oj < oh_e;
                        ++oj, ij += jcp.stride_h) {
                    int dilate_h = jcp.dilate_h + 1;
                    // Number of kernel rows falling into top/bottom padding.
                    int i_t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ij), dilate_h));
                    int i_b_overflow = nstl::min(jcp.kh,
                            div_up(max(0,
                                           ij - jcp.ih + (jcp.kh - 1) * dilate_h
                                                   + 1),
                                    dilate_h));
                    int kh_padding = nstl::max(
                            0, jcp.kh - i_t_overflow - i_b_overflow);

                    size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
                            ? 0
                            : wht_h_stride * i_t_overflow;
                    p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                    p.dst = dst_w;
                    p.filt = wht_w + wei_stride;
                    p.bias = bias_w;
                    p.compensation = compensation_w;
                    p.oc_blocks = ocb;
                    p.kh_padding = kh_padding;
                    p.kd_padding = kd_padding;
                    p.scales = scales;
                    p.dst_scale = dst_scales;
                    p.t_overflow = i_t_overflow;
                    p.b_overflow = i_b_overflow;
                    p.f_overflow = d_f_overflow;
                    p.back_overflow = d_back_overflow;
                    p.owb = owb;

                    p.oc_l_off = g_oc;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;

                    (*kernel_)(&p);
                    // Advance to the next output row / strided input row.
                    src_w += src_h_stride * jcp.stride_h;
                    dst_w += dst_dt_size * dst_h_stride;
                }
            }
            // Advance the work-sharing iterator past the rows just done.
            switch (jcp.loop_order) {
                case loop_cwgn:
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            g, nb_groups, n, jcp.mb, od_s, jcp.od, oh_s,
                            jcp.oh);
                    break;
                case loop_ngcw:
                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                            oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s,
                            jcp.oh);
                    break;
                case loop_nhwcg:
                    ++start;
                    nd_iterator_step(n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
                            jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                    break;
                default: assert(!"unsupported loop order" );
            }
        }
    });
    return status::success;
}
675 | |
// Explicit instantiations for the ISAs this implementation supports.
template struct jit_uni_x8s8s32x_convolution_fwd_t<sse41>;
template struct jit_uni_x8s8s32x_convolution_fwd_t<avx2>;
678 | |
679 | } // namespace x64 |
680 | } // namespace cpu |
681 | } // namespace impl |
682 | } // namespace dnnl |
683 | |