1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/memory.hpp" |
21 | #include "common/memory_tracking.hpp" |
22 | #include "common/nstl.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/platform.hpp" |
27 | |
28 | #include "cpu/x64/injectors/injector_utils.hpp" |
29 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
30 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
31 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
32 | #include "cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp" |
33 | |
34 | #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | namespace cpu { |
39 | namespace x64 { |
40 | |
41 | using namespace dnnl::impl::utils; |
42 | using namespace dnnl::impl::data_type; |
43 | using namespace Xbyak; |
44 | using namespace injector_utils; |
45 | |
46 | template <cpu_isa_t isa, typename Vmm> |
47 | _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::_jit_uni_x8s8s32x_1x1_conv_kernel( |
48 | const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, |
49 | const memory_desc_t &dst_md) |
50 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa) |
51 | , jcp(ajcp) |
52 | , attr_(attr) { |
53 | if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { |
54 | using namespace binary_injector; |
55 | static constexpr bool preserve_gpr = true; |
56 | static constexpr bool preserve_vmm = true; |
57 | rhs_arg_static_params_t rhs_arg_static_params {15, r13, r14, r15, |
58 | preserve_gpr, preserve_vmm, |
59 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), |
60 | memory_desc_wrapper(dst_md)}; |
61 | static_params_t static_params {this->param1, rhs_arg_static_params}; |
62 | |
63 | postops_injector_ |
64 | = utils::make_unique<injector::jit_uni_postops_injector_t<isa>>( |
65 | this, jcp.post_ops, static_params); |
66 | } |
67 | } |
68 | |
69 | template <cpu_isa_t isa, typename Vmm> |
70 | void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::cvt2ps(data_type_t type_in, |
71 | const Vmm &vmm_in, const Reg64 ®, int offset, int load_size) { |
72 | load_data(type_in, vmm_in, reg, offset, load_size); |
73 | if (type_in != data_type::f32) uni_vcvtdq2ps(vmm_in, vmm_in); |
74 | } |
75 | |
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::bcast_loop(
        int load_loop_blk) {
    // Emit the loop over the broadcast (spatial) dimension: process jcp.ur
    // rows per iteration via reduce_loop(), then a masked tail of
    // jcp.ur_tail rows if the row count does not divide evenly.
    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_bcast_data, reg_bcast_data);

    mov(aux_reg_output_data, reg_output_data);
    // The bcast trip count was spilled to the stack in generate().
    mov(reg_bcast_loop_iter, ptr[rsp + bcast_loop_work_off]);

    Label bcast_loop;
    Label bcast_loop_tail;

    // Skip the main loop entirely when fewer than jcp.ur rows remain.
    cmp(reg_bcast_loop_iter, jcp.ur);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop);
    {
        assert(jcp.bcast_block == jcp.ur);
        reduce_loop(load_loop_blk, jcp.ur, false);
        // Advance source and destination pointers by one full bcast block.
        add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step);
        add(aux_reg_output_data, jcp.bcast_loop_output_step);

        sub(reg_bcast_loop_iter, jcp.bcast_block);
        cmp(reg_bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        // The tail body is always emitted at compile time but is skipped at
        // run time when no rows are left.
        cmp(reg_bcast_loop_iter, 0);
        jz(bcast_loop_tail_out, T_NEAR);
        reduce_loop(load_loop_blk, jcp.ur_tail, true);
        L(bcast_loop_tail_out);
    }
}
112 | |
113 | template <cpu_isa_t isa, typename Vmm> |
114 | int _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::output_ptr( |
115 | const int i_load, const int i_ur) { |
116 | const size_t ur_stride = jcp.with_dw_conv |
117 | ? jcp.nb_load_blocking * jcp.oc_block * i_ur |
118 | : jcp.oc_without_padding * i_ur; |
119 | |
120 | return jcp.typesize_out * (ur_stride + i_load * jcp.load_block); |
121 | }; |
122 | |
123 | template <cpu_isa_t isa, typename Vmm> |
124 | int _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::vreg_accum_idx( |
125 | const int load_loop_blk, const int i_load, const int i_ur) { |
126 | const int vmm_idx = i_ur * load_loop_blk + i_load; |
127 | assert(vmm_idx < ker_max_reg_idx); |
128 | return (15 - vmm_idx); |
129 | }; |
130 | |
131 | template <cpu_isa_t isa, typename Vmm> |
132 | Vmm _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::vreg_accum( |
133 | const int load_loop_blk, const int i_load, const int i_ur) { |
134 | return Vmm(vreg_accum_idx(load_loop_blk, i_load, i_ur)); |
135 | }; |
136 | |
template <typename F>
void iterate(const int ur, const int load_loop_blk, const F &f) {
    // Invoke f(i_ur, i_load) over the full (ur x load_loop_blk) grid in
    // row-major order: the unroll index is the slow-varying dimension.
    const int total = ur * load_loop_blk;
    for (int idx = 0; idx < total; ++idx)
        f(idx / load_loop_blk, idx % load_loop_blk);
}
143 | |
144 | template <cpu_isa_t isa, typename Vmm> |
145 | void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::apply_sum(const int ur, |
146 | const int load_loop_blk, const bool mask_flag_in, |
147 | const float *p_sum_scale, const int32_t *p_sum_zp) { |
148 | |
149 | if (jcp.with_sum) { |
150 | assert(!utils::any_null(p_sum_scale, p_sum_zp) |
151 | && "p_sum_scale or p_sum_zp = nullptr" ); |
152 | const float sum_scale = *p_sum_scale; |
153 | const int32_t sum_zp = *p_sum_zp; |
154 | const auto sum_injector_lam = [this, mask_flag_in, load_loop_blk, |
155 | sum_scale, sum_zp](const int i_ur, |
156 | const int i_load) { |
157 | const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1; |
158 | const auto ymm_prev_dst = vmm_zero; |
159 | |
160 | const auto r = vreg_accum(load_loop_blk, i_load, i_ur); |
161 | cvt2ps(jcp.sum_dt, ymm_prev_dst, aux_reg_output_data, |
162 | output_ptr(i_load, i_ur), |
163 | mask_flag ? get_tail_size() : simd_w); |
164 | |
165 | if (sum_zp != 0) { |
166 | uni_vbroadcastss(vmm_tmp, ptr[reg_ptr_sum_zp]); |
167 | uni_vcvtdq2ps(vmm_tmp, vmm_tmp); |
168 | uni_vsubps(vmm_prev_dst, vmm_prev_dst, vmm_tmp); |
169 | } |
170 | if (sum_scale == 1.f) |
171 | uni_vaddps(r, r, ymm_prev_dst); |
172 | else { |
173 | uni_vbroadcastss(vmm_tmp, ptr[reg_ptr_sum_scale]); |
174 | uni_vfmadd231ps(r, ymm_prev_dst, vmm_tmp); |
175 | } |
176 | }; |
177 | const auto sum_injector |
178 | = [=]() { iterate(ur, load_loop_blk, sum_injector_lam); }; |
179 | if (sum_zp != 0) |
180 | mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp)); |
181 | postops_injector_->set_lambda_injector( |
182 | primitive_kind::sum, sum_injector); |
183 | } |
184 | } |
185 | |
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::apply_postops(const int ur,
        const int load_loop_blk, const bool mask_flag_in,
        const float *p_sum_scale, const int32_t *p_sum_zp) {
    // Apply all configured post-ops (sum/eltwise/binary) to the accumulator
    // registers of the current ur x load_loop_blk tile via the injector.

    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
        // Spill reg_ptr_sum_zp: it reuses a GPR whose stack slot
        // (reg_bcast_loop_iter_off) is free during this phase, and it must
        // survive the injector call which runs the sum lambda.
        if (jcp.with_sum && *p_sum_zp != 0)
            mov(ptr[rsp + reg_bcast_loop_iter_off], reg_ptr_sum_zp);
        apply_sum(ur, load_loop_blk, mask_flag_in, p_sum_scale, p_sum_zp);

        binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
        vmm_index_set_t vmm_idxs;
        if (jcp.with_binary) {
            // Binary post-ops need, per accumulator vmm, the output register
            // and element offset so the injector can address the rhs tensor.
            iterate(ur, load_loop_blk, [&](const int i_ur, const int i_load) {
                const int ur_stride = jcp.with_dw_conv
                        ? jcp.nb_load_blocking * jcp.oc_block * i_ur
                        : jcp.oc_without_padding * jcp.ngroups * i_ur;
                const size_t aux_output_offset = jcp.typesize_out
                        * (ur_stride + i_load * jcp.load_block);
                const auto vmm_idx
                        = vreg_accum_idx(load_loop_blk, i_load, i_ur);
                vmm_idxs.emplace(vmm_idx);

                rhs_arg_params.vmm_idx_to_out_reg.emplace(
                        vmm_idx, aux_reg_output_data);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, aux_output_offset);
            });

            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
        } else {
            // No binary post-op: only the set of accumulator indices is
            // needed; the dynamic params stay empty.
            iterate(ur, load_loop_blk, [&](const int i_ur, const int i_load) {
                vmm_idxs.emplace(vreg_accum_idx(load_loop_blk, i_load, i_ur));
            });
            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
        }
        // Restore the zero-point pointer spilled above.
        if (jcp.with_sum && *p_sum_zp != 0)
            mov(reg_ptr_sum_zp, ptr[rsp + reg_bcast_loop_iter_off]);
    }
}
226 | |
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::reduce_loop(
        int load_loop_blk, int ur, bool wraparound) {
    // Emit the innermost kernel: accumulate int8 dot-products over the
    // reduce (input-channel) dimension for an ur x load_loop_blk output
    // tile, then dequantize, apply post-ops, saturate and store.
    // NOTE(review): `wraparound` is not referenced in this body — presumably
    // kept for interface parity with sibling kernels; confirm with callers.

    // use 0x10001 to represent 2 words of 0x1
    // and avoid using uni_vpbroadcastb that is missing in jit generator
    const auto xmm_one = Xmm(vmm_one.getIdx());
    mov(reg_init_bcast, 0x10001);
    uni_vmovq(xmm_one, reg_init_bcast);
    uni_vpbroadcastd(vmm_one, xmm_one);

    // Weight registers occupy the indices just below the accumulators in
    // the same top-down numbering as vreg_accum_idx.
    auto vreg_load = [&](int i_load) {
        const int vmm_idx = ur * load_loop_blk + i_load;
        assert(vmm_idx < ker_max_reg_idx);
        /* remap the register indices to
         * avoid passing xmm0 to eltwise injector */
        return Vmm(15 - vmm_idx);
    };

    // Address of the broadcast (source) element for a given unroll row and
    // position along the reduce dimension.
    auto bcast_ptr = [&](int i_reduce, int i_ur) {
        assert(i_ur < jcp.ur);
        assert(i_reduce <= jcp.reduce_loop_unroll);
        assert(jcp.reduce_loop_unroll == jcp.reduce_block);

        int offt = (jcp.ic_without_padding * i_ur + i_reduce);

        return ptr[aux_reg_bcast_data + jcp.typesize_in * offt];
    };

    // Address of a weight block: i_reduce is split into the position inside
    // the current unroll chunk (u0) and the chunk number (u1).
    auto load_ptr = [&](int i_reduce, int i_load) {
        int u0 = i_reduce % jcp.reduce_loop_unroll;
        int u1 = i_reduce / jcp.reduce_loop_unroll;

        int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;

        return ptr[aux_reg_load_data + u1 * jcp.reduce_loop_load_step
                + jcp.typesize_in * offt];
    };

    // Zero all accumulators; for signed input additionally materialize the
    // +128 byte-shift constant used to bias s8 sources into u8 range.
    auto init = [&]() {
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                uni_vpxor(r, r, r);
            }
        if (jcp.signed_input) {
            // Used 0x80808080 to represents 2 words of 128
            // to avoid using uni_vpbroadcastb that is missing in jit generator
            auto xmm_shift = Xbyak::Xmm(vmm_shift.getIdx());
            auto _t32 = reg_init_bcast.cvt32();
            mov(_t32, 0x80808080);
            uni_vpinsrd(xmm_shift, xmm_shift, _t32, 0);
            uni_vpbroadcastd(vmm_shift, xmm_shift);
        }
    };

    // Post-accumulation phase: dequantize (compensations, scales, bias),
    // run post-ops, apply dst scale/zero-point, saturate and write out.
    auto store = [&](const bool mask_flag_in) {
        const auto &p = attr_.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const float *p_sum_scale
                = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
        const int32_t *p_sum_zp
                = (sum_idx != -1) ? &p.entry_[sum_idx].sum.zero_point : nullptr;
        // Spill registers that get repurposed during the store phase.
        mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data);
        mov(reg_ptr_scales, ptr[rsp + reg_ptr_sum_scale_off]);
        if (p_sum_scale && *p_sum_scale != 1.f) {
            mov(ptr[rsp + reg_load_data_off], reg_load_data);
            mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
        }
        if (jcp.src_zero_point) {
            mov(reg_zp_compensation, ptr[rsp + reg_zp_compensation_off]);
            mov(reg_src_zero_point, ptr[rsp + reg_src_zero_point_off]);
        }
        for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
            if (jcp.src_zero_point) {
                uni_vpbroadcastd(vmm_zp, ptr[reg_src_zero_point]);
            }
            // Only the last load block can be a partial (masked) one.
            const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1;
            const int load_size = mask_flag ? get_tail_size() : simd_w;
            const auto ptr_scales_offset
                    = jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load);
            if (jcp.with_bias) {
                if (jcp.signed_input || jcp.dst_scale)
                    mov(reg_bias_data, ptr[rsp + reg_bias_data_off]);
                cvt2ps(jcp.bia_dt, vmm_bias, reg_bias_data,
                        jcp.typesize_bia * jcp.oc_block * i_load, load_size);
            }
            if (jcp.signed_input) {
                // Load the precomputed compensation that undoes the +128
                // input shift applied during accumulation.
                mov(reg_comp_data, ptr[rsp + reg_comp_data_off]);
                cvt2ps(data_type::s32, vmm_comp, reg_comp_data,
                        sizeof(int32_t) * jcp.oc_block * i_load, load_size);
            }
            if (jcp.src_zero_point) {
                // zp_comp = zp_compensation * src_zero_point
                const int zp_offset = sizeof(int32_t) * i_load * jcp.oc_block;
                load_data(data_type::s32, vmm_zp_comp, reg_zp_compensation,
                        zp_offset, load_size);
                uni_vpmulld(vmm_zp_comp, vmm_zp_comp, vmm_zp);

                // upscale to f32
                uni_vcvtdq2ps(vmm_zp_comp, vmm_zp_comp);
            }

            if (mask_flag) {
                // Zero first so lanes beyond the tail hold 0 scale.
                uni_vpxor(vmm_scale, vmm_scale, vmm_scale);
                cvt2ps(data_type::f32, vmm_scale, reg_ptr_scales,
                        ptr_scales_offset, get_tail_size());
            } else {
                uni_vmovups(vmm_scale, ptr[reg_ptr_scales + ptr_scales_offset]);
            }

            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                // acc_f32 = (acc_s32 + compensations) * scale + bias
                const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                uni_vcvtdq2ps(r, r);
                if (jcp.signed_input) uni_vaddps(r, r, vmm_comp);
                if (jcp.src_zero_point) uni_vaddps(r, r, vmm_zp_comp);

                uni_vmulps(r, r, vmm_scale);

                if (jcp.with_bias) uni_vaddps(r, r, vmm_bias);
            }
        }

        apply_postops(ur, load_loop_blk, mask_flag_in, p_sum_scale, p_sum_zp);

        if (jcp.dst_scale) {
            mov(reg_ptr_dst_scale, ptr[rsp + reg_dst_scale_off]);
            uni_vmovups(vmm_dst_scale, ptr[reg_ptr_dst_scale]);

            /* Apply dst scale to accumulator */
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    uni_vmulps(r, r, vmm_dst_scale);
                }
            }
        }

        if (jcp.dst_zero_point) {
            mov(reg_dst_zero_point, ptr[rsp + reg_dst_zero_point_off]);
            uni_vpbroadcastd(vmm_zp, ptr[reg_dst_zero_point]);
            uni_vcvtdq2ps(vmm_zp, vmm_zp);

            /* Add dst zero_point to accumulator */
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    uni_vaddps(r, r, vmm_zp);
                }
            }
        }

        // Properly saturate the accumulators for integer datatypes
        if (utils::one_of(jcp.dst_dt, u8, s8, s32)) {
            init_saturate_f32(vmm_zero, vmm_saturation, aux_reg_saturation, f32,
                    jcp.dst_dt);

            for (int i_ur = 0; i_ur < ur; ++i_ur)
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    saturate_f32(r, vmm_zero, vmm_saturation, jcp.dst_dt);
                    uni_vcvtps2dq(r, r);
                }
        }

        /* write out register to output_addr */
        for (int i_ur = 0; i_ur < ur; ++i_ur) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                const bool mask_flag
                        = mask_flag_in && i_load == load_loop_blk - 1;
                auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                store_data(jcp.dst_dt, r, aux_reg_output_data,
                        output_ptr(i_load, i_ur),
                        mask_flag ? get_tail_size() : simd_w);
            }
        }
        // Restore registers spilled at the top of store().
        mov(reg_bcast_data, ptr[rsp + reg_bcast_data_off]);
        if (p_sum_scale && *p_sum_scale != 1.f)
            mov(reg_load_data, ptr[rsp + reg_load_data_off]);
    };

    // One int8 dot-product step: VNNI does it in a single vpdpbusd;
    // otherwise emulate via pmaddubsw + pmaddwd(*1) + paddd.
    auto compute = [&](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
        if (jcp.has_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei, VexEncoding);
        } else {
            uni_vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
            uni_vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
            uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp);
        }
    };

    // Unrolled accumulation over one reduce chunk. `last_block` enables the
    // input-channel tail path (partial 4-byte broadcast via load_bytes).
    auto fma_block = [&](bool last_block) {
        int reduce_step = 4;
        int ic_tail_size = jcp.ic_without_padding % reduce_step;
        int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding
                ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step)
                : jcp.reduce_loop_unroll;
        for (int i_reduce = 0; i_reduce < loop_unroll;
                i_reduce += reduce_step) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load)
                uni_vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load));
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                if (last_block && ic_tail_size != 0
                        && i_reduce == loop_unroll - reduce_step) {
                    // Tail: load only the valid bytes, then broadcast.
                    load_bytes(vmm_bcast, aux_reg_bcast_data,
                            jcp.ic_without_padding * i_ur + i_reduce,
                            ic_tail_size);
                    uni_vpbroadcastd(vmm_bcast, Xmm(vmm_bcast.getIdx()));
                } else {
                    uni_vpbroadcastd(vmm_bcast, bcast_ptr(i_reduce, i_ur));
                }
                // Shift s8 input into u8 range for the u8*s8 multiply.
                if (jcp.signed_input)
                    uni_vpsubb(vmm_bcast, vmm_bcast, vmm_shift);
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    compute(vreg_accum(load_loop_blk, i_load, i_ur),
                            vreg_load(i_load), vmm_bcast);
                }
            }
        }
    };

    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);

    mov(aux_reg_bcast_data, aux1_reg_bcast_data);
    init();

    // Main reduce loop over full unroll chunks; the last chunk is emitted
    // separately so it can take the input-channel tail path.
    mov(reg_reduce_loop_iter, reg_reduce_loop_work);
    sub(reg_reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop);
    {
        fma_block(false);
        add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reg_reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    fma_block(jcp.ic != jcp.ic_without_padding);

    if (jcp.oc_without_padding != jcp.oc) {
        // Padded output channels: only the very last ocb of the last load
        // block may store a masked tail; all other stores are full-width.
        Label end_store, common_store;
        mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data);

        /*Check if it is the last load_loop_blk*/
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
        cmp(reg_load_loop_work, 0);
        jg(common_store, T_NEAR);

        /*Check if it is the last ocb*/
        test(reg_reduce_pos_flag, FLAG_OC_LAST);
        jz(common_store, T_NEAR);

        store(true);
        jmp(end_store, T_NEAR);

        L(common_store);
        store(false);

        L(end_store);

        add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    } else {
        store(false);
    }
}
497 | |
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::generate() {
    // Kernel entry point: set up the stack frame, spill call arguments to
    // stack slots, then dispatch to a specialized load-loop body depending
    // on how many output-channel blocks remain at run time.
    preamble();

    sub(rsp, stack_space_needed);
    if (jcp.with_binary) {
        // zero initialize binary post_ops offset accumulator (store on stack)
        const auto binary_post_op_acc_off_reg = r15;
        xor_(binary_post_op_acc_off_reg, binary_post_op_acc_off_reg);
        mov(ptr[rsp + reg_binary_post_op_acc_off], binary_post_op_acc_off_reg);
    }

    // Load kernel arguments from the jit_1x1_conv_call_s structure and
    // spill the ones that must survive register reuse to the stack.
    if (jcp.with_bias) mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
    if (jcp.signed_input) {
        mov(ptr[rsp + reg_bias_data_off], reg_bias_data);
        mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]);
        mov(ptr[rsp + reg_comp_data_off], reg_comp_data);
    }
    if (jcp.src_zero_point) {
        mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
        mov(ptr[rsp + reg_zp_compensation_off], reg_zp_compensation);
        mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
        mov(ptr[rsp + reg_src_zero_point_off], reg_src_zero_point);
    }
    if (jcp.dst_scale) {
        if (!jcp.signed_input) mov(ptr[rsp + reg_bias_data_off], reg_bias_data);
        mov(reg_ptr_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
        mov(ptr[rsp + reg_dst_scale_off], reg_ptr_dst_scale);
    }
    if (jcp.dst_zero_point) {
        mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
        mov(ptr[rsp + reg_dst_zero_point_off], reg_dst_zero_point);
    }
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
    mov(ptr[rsp + reg_ptr_sum_scale_off], reg_ptr_scales);
    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(ptr[rsp + bcast_loop_work_off], reg_bcast_loop_work);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);

    // One load-loop iteration: run the bcast loop for load_loop_blk output
    // channel blocks, then advance all per-ocb pointers (bias, compensation,
    // zero-point compensation, scales, output) by one block group.
    auto load_loop_body = [&](int load_loop_blk) {
        bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        if (jcp.with_bias) {
            if (jcp.signed_input || jcp.dst_scale)
                mov(reg_bias_data, ptr[rsp + reg_bias_data_off]);
            add(reg_bias_data,
                    load_loop_blk * jcp.load_block * jcp.typesize_bia);
            if (jcp.signed_input || jcp.dst_scale)
                mov(ptr[rsp + reg_bias_data_off], reg_bias_data);
        }
        if (jcp.with_binary) {
            // Advance the binary post-op channel offset accumulator kept on
            // the stack (aux_reg_load_data is free as scratch here).
            mov(aux_reg_load_data,
                    EVEX_compress_addr(rsp, reg_binary_post_op_acc_off));
            add(aux_reg_load_data, jcp.load_block * load_loop_blk);
            mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off),
                    aux_reg_load_data);
        }
        if (jcp.signed_input) {
            mov(reg_comp_data, ptr[rsp + reg_comp_data_off]);
            add(reg_comp_data,
                    load_loop_blk * jcp.load_block * sizeof(int32_t));
            mov(ptr[rsp + reg_comp_data_off], reg_comp_data);
        }
        if (jcp.src_zero_point) {
            mov(reg_zp_compensation, ptr[rsp + reg_zp_compensation_off]);
            add(reg_zp_compensation,
                    load_loop_blk * jcp.load_block * sizeof(int32_t));
            mov(ptr[rsp + reg_zp_compensation_off], reg_zp_compensation);
        }
        mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data);
        mov(reg_ptr_scales, ptr[rsp + reg_ptr_sum_scale_off]);
        add(reg_ptr_scales,
                jcp.is_oc_scale * load_loop_blk * jcp.load_block
                        * sizeof(float));
        mov(ptr[rsp + reg_ptr_sum_scale_off], reg_ptr_scales);
        mov(reg_bcast_data, ptr[rsp + reg_bcast_data_off]);
        add(reg_output_data, load_loop_blk * jcp.load_block * jcp.typesize_out);
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    };

    // Dispatch table: ur_cases[i] is the largest jcp.ur for which a body
    // with (num_ur_cases - i) ocb blocks per iteration fits the registers.
    static const int ur_cases[] = {2, 3, 5, 12};
    constexpr int num_ur_cases = sizeof(ur_cases) / sizeof(*ur_cases);
    Label load_loop_blk[num_ur_cases + 1];

    // Jump to the variant matching the remaining load (ocb) work.
    for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.ur <= ur_cases[ur_idx]) {
            cmp(reg_load_loop_work, simd_w * (label_idx + 1));
            jle(load_loop_blk[label_idx], T_NEAR);
        }
    }

    // Emit each applicable load-loop variant; after a variant finishes its
    // full-size iterations, fall through / jump to smaller variants to
    // handle the remaining blocks.
    for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
        if (jcp.ur <= ur_cases[ur_idx]) {
            int label_idx = num_ur_cases - ur_idx - 1;
            L(load_loop_blk[label_idx]);
            {
                if (label_idx == 0) {
                    cmp(reg_load_loop_work, 0);
                    je(load_loop_blk[num_ur_cases], T_NEAR);
                }

                load_loop_body(label_idx + 1);
                if (label_idx - 1 > 0) {
                    cmp(reg_load_loop_work, 2 * label_idx * simd_w);
                    je(load_loop_blk[label_idx - 1], T_NEAR);
                }
                cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
                jge(load_loop_blk[label_idx]);
            }
            for (int idx = label_idx - 1; idx > 0; --idx) {
                cmp(reg_load_loop_work, simd_w * (idx + 1));
                je(load_loop_blk[idx], T_NEAR);
            }
            if (ur_idx < num_ur_cases - 2) {
                cmp(reg_load_loop_work, simd_w);
                jle(load_loop_blk[0], T_NEAR);
            }
        }
    }
    L(load_loop_blk[num_ur_cases]);
    add(rsp, stack_space_needed);
    postamble();

    // Eltwise post-ops may need a constants table appended after the code.
    if (jcp.with_eltwise) postops_injector_->prepare_table();
}
630 | |
631 | template <cpu_isa_t isa> |
632 | status_t jit_uni_x8s8s32x_1x1_conv_kernel<isa>::init_conf( |
633 | jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd, |
634 | const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, |
635 | const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d, |
636 | primitive_attr_t &attr, int nthreads, bool reduce_src) { |
637 | if (!mayiuse(isa)) return status::unimplemented; |
638 | |
639 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
640 | if (!one_of(src_d.data_type(), data_type::u8, data_type::s8) |
641 | || weights_d.data_type() != data_type::s8 |
642 | || !one_of(dst_d.data_type(), data_type::f32, data_type::s32, |
643 | data_type::s8, data_type::u8)) |
644 | return status::unimplemented; |
645 | |
646 | const int ndims = src_d.ndims(); |
647 | jcp.nthr = nthreads; |
648 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
649 | jcp.mb = src_d.dims()[0]; |
650 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
651 | jcp.oc_without_padding = jcp.oc; |
652 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
653 | jcp.ic_without_padding = jcp.ic; |
654 | jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; |
655 | jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2]; |
656 | jcp.iw = src_d.dims()[ndims - 1]; |
657 | jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; |
658 | jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2]; |
659 | jcp.ow = dst_d.dims()[ndims - 1]; |
660 | jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; |
661 | jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
662 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
663 | jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; |
664 | jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4]; |
665 | jcp.l_pad = cd.padding[0][ndims - 3]; |
666 | jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; |
667 | jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4]; |
668 | jcp.stride_w = cd.strides[ndims - 3]; |
669 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
670 | |
671 | jcp.signed_input = (src_d.data_type() == data_type::s8); |
672 | |
673 | jcp.os = static_cast<dim_t>(jcp.od) * jcp.oh * jcp.ow; |
674 | jcp.is = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw; |
675 | |
676 | const auto &post_ops = attr.post_ops_; |
677 | const int dw_conv_ind = post_ops.find(primitive_kind::convolution); |
678 | jcp.with_dw_conv = dw_conv_ind != -1; |
679 | // Using dw_conv_ind as upper-bound below, as post-ops after it will be |
680 | // handled in depthwise convolution. |
681 | const int eltwise_ind |
682 | = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind); |
683 | jcp.with_eltwise = eltwise_ind != -1; |
684 | |
685 | const int binary_ind |
686 | = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); |
687 | jcp.with_binary = binary_ind != -1; |
688 | |
689 | const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind); |
690 | jcp.with_sum = sum_ind != -1; |
691 | |
692 | const auto zp = attr.zero_points_; |
693 | jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST); |
694 | jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC); |
695 | jcp.zp_src_is_common |
696 | = zp.common(DNNL_ARG_SRC); // otherwise, it's per-channel |
697 | assert(IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common)); |
698 | |
699 | if ((jcp.dst_zero_point || jcp.src_zero_point) && jcp.with_dw_conv) |
700 | return status::unimplemented; |
701 | |
702 | format_tag_t dat_tag = utils::pick( |
703 | ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
704 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag); |
705 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); |
706 | |
707 | bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == dat_tag |
708 | && jcp.dst_tag == dat_tag; |
709 | if (!args_ok) return status::unimplemented; |
710 | |
711 | jcp.has_vnni = mayiuse(avx2_vnni); |
712 | |
713 | jcp.oc = rnd_up(jcp.oc, simd_w); |
714 | jcp.ic = rnd_up(jcp.ic, simd_w); |
715 | |
716 | if (dw_conv_ind >= 0) { |
717 | // dw_conv and post_ops after it are handled externally, so skip them |
718 | jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(), |
719 | post_ops.entry_.cbegin() + dw_conv_ind); |
720 | } else { |
721 | jcp.post_ops = post_ops; |
722 | } |
723 | |
724 | for (auto &post_op : jcp.post_ops.entry_) |
725 | if (post_op.is_binary() && post_op.binary.src1_desc.dims[1] != 1) { |
726 | post_op.binary.src1_desc.dims[1] = jcp.oc; |
727 | } |
728 | |
729 | using namespace injector; |
730 | const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum}, |
731 | jcp.post_ops, &dst_d, false, false, false}); |
732 | if (!post_ops_ok_) return status::unimplemented; |
733 | |
734 | args_ok = true && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 |
735 | && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0 |
736 | && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1 |
737 | && jcp.ow == jcp.iw && jcp.oh == jcp.ih && jcp.od == jcp.id |
738 | && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1; |
739 | if (!args_ok) return status::unimplemented; |
740 | |
741 | jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; |
742 | jcp.dst_dt = cd.dst_desc.data_type; |
743 | jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); |
744 | |
745 | jcp.ic_block = jcp.oc_block = simd_w; |
746 | |
747 | jcp.typesize_in = types::data_type_size(src_d.data_type()); |
748 | jcp.typesize_out = types::data_type_size(dst_d.data_type()); |
749 | jcp.typesize_bia |
750 | = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; |
751 | |
752 | const int SMALL_SPATIAL = 7 * 7; |
753 | const int BIG_REDUCE_DIM = 512; |
754 | |
755 | int load_blocking = 0; |
756 | int load_blocking_max = 0; |
757 | int bcast_blocking = 0; |
758 | int bcast_blocking_max = 0; |
759 | int reduce_blocking = 0; |
760 | int reduce_blocking_max = 0; |
761 | jcp.load_grp_count = 1; |
762 | |
763 | const int L2_size |
764 | = platform::get_per_core_cache_size(2) / sizeof(jcp.typesize_in); |
765 | const int L2_capacity = (L2_size * 3) / 4; |
766 | |
767 | int size_threshold = 28; |
768 | |
769 | int min_regs = 3; |
770 | int max_regs = 5; |
771 | |
772 | if (jcp.mb == 1 && jcp.ic > 128 |
773 | && (jcp.oh <= size_threshold && jcp.ow <= size_threshold)) { |
774 | if (jcp.os <= SMALL_SPATIAL && jcp.oc * jcp.ic < L2_size) |
775 | max_regs = min_regs = 3; |
776 | jcp.ur = nstl::min<dim_t>(max_regs, jcp.os); |
777 | } else { |
778 | const int spatial = jcp.od * jcp.oh; |
779 | jcp.ur = 1; |
780 | for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) { |
781 | if ((spatial >= size_threshold && spatial % ur_w == 0) |
782 | || (spatial < size_threshold && jcp.os % ur_w == 0)) { |
783 | jcp.ur = ur_w; |
784 | break; |
785 | } |
786 | } |
787 | if (jcp.ur == 1) { |
788 | jcp.ur = nstl::min<dim_t>(max_regs, jcp.os); |
789 | int os_tail = jcp.os % max_regs; |
790 | for (int i = max_regs; i >= min_regs; i--) { |
791 | int i_tail = jcp.os % i; |
792 | if (i_tail > os_tail || i_tail == 0) { |
793 | jcp.ur = i; |
794 | os_tail = i_tail; |
795 | if (i_tail == 0) break; |
796 | } |
797 | } |
798 | } |
799 | } |
800 | |
801 | if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur); |
802 | jcp.reduce_dim = jcp.ic; |
803 | jcp.reduce_block = jcp.ic_block; |
804 | |
805 | jcp.load_dim = jcp.oc; |
806 | jcp.load_block = jcp.oc_block; |
807 | |
808 | jcp.bcast_dim = jcp.is; |
809 | |
810 | jcp.bcast_block = jcp.ur; |
811 | |
812 | jcp.reduce_loop_unroll = jcp.reduce_block; |
813 | jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll * jcp.typesize_in; |
814 | |
815 | jcp.reduce_loop_load_step |
816 | = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; |
817 | |
818 | jcp.bcast_loop_output_step |
819 | = jcp.ur * jcp.oc_without_padding * jcp.typesize_out; |
820 | jcp.bcast_loop_bcast_step |
821 | = jcp.ur * jcp.ic_without_padding * jcp.typesize_in; |
822 | |
823 | jcp.load_loop_load_step = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; |
824 | |
825 | jcp.load_loop_iter_step = jcp.load_block; |
826 | |
827 | jcp.loop_order = reduce_src ? loop_blr : loop_lbr; |
828 | |
829 | int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); |
830 | int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); |
831 | |
832 | reduce_blocking = nb_reduce; |
833 | if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) |
834 | reduce_blocking = 64; |
835 | else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) |
836 | reduce_blocking = 16; |
837 | |
838 | reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); |
839 | reduce_blocking *= jcp.reduce_block; |
840 | |
841 | bool cmp_reduce = reduce_blocking <= jcp.reduce_dim; |
842 | if (cmp_reduce) jcp.loop_order = reduce_src ? loop_rbl : loop_rlb; |
843 | load_blocking = jcp.load_dim; |
844 | |
845 | jcp.load_grp_count = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast); |
846 | jcp.load_grp_count = best_divider( |
847 | jcp.nthr, jcp.load_grp_count, 2 * jcp.load_grp_count, false); |
848 | |
849 | if (jcp.bcast_dim <= SMALL_SPATIAL |
850 | && jcp.load_dim * jcp.reduce_dim >= L2_size) { |
851 | jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); |
852 | } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= jcp.nthr |
853 | && jcp.load_dim > 256 && jcp.load_dim / jcp.reduce_dim >= 4) { |
854 | jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); |
855 | load_blocking = jcp.load_block; |
856 | } |
857 | |
858 | bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, |
859 | div_up(jcp.nthr, jcp.load_grp_count)) |
860 | * jcp.bcast_block; |
861 | bcast_blocking = nstl::min<dim_t>(jcp.bcast_dim, bcast_blocking); |
862 | bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); |
863 | |
864 | int space_for_bcast = (L2_capacity - /* kernel_size - */ |
865 | 2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking |
866 | - 3 * 1024); |
867 | if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) space_for_bcast /= 2; |
868 | |
869 | int bcast_in_cache |
870 | = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); |
871 | bcast_blocking = nstl::min( |
872 | bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); |
873 | |
874 | load_blocking_max = load_blocking; |
875 | bcast_blocking_max = bcast_blocking * 3 / 2; |
876 | reduce_blocking_max = reduce_blocking; |
877 | |
878 | const bool params_ok = true && load_blocking > 0 && load_blocking_max > 0 |
879 | && bcast_blocking > 0 && bcast_blocking_max > 0 |
880 | && reduce_blocking > 0 && reduce_blocking_max > 0 |
881 | && load_blocking % jcp.load_block == 0 |
882 | && reduce_blocking % jcp.reduce_block == 0 |
883 | && load_blocking_max % jcp.load_block == 0 |
884 | && reduce_blocking_max % jcp.reduce_block == 0 |
885 | && jcp.reduce_loop_unroll % 4 == 0 |
886 | && jcp.reduce_dim % jcp.reduce_loop_unroll == 0 |
887 | && jcp.bcast_block % jcp.ur == 0 |
888 | && jcp.reduce_dim % jcp.reduce_block == 0; |
889 | |
890 | assert(params_ok && "parameter values are inconsistent" ); |
891 | if (!params_ok) return status::unimplemented; |
892 | |
893 | jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur; |
894 | |
895 | jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; |
896 | jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; |
897 | jcp.nb_load_blocking = load_blocking / jcp.load_block; |
898 | jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; |
899 | jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; |
900 | jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; |
901 | |
902 | jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); |
903 | jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); |
904 | jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); |
905 | |
    // minimum size of load dim chunk for work distribution within threads
907 | jcp.nb_load_chunk = 1; |
908 | |
909 | const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); |
910 | const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); |
911 | const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); |
912 | const int wei_mask_per_oc = 1 << (int)with_groups; |
913 | jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc; |
914 | jcp.dst_scale = !dst_scales.has_default_values(); |
915 | |
916 | const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc) |
917 | && everyone_is(src_scales.mask_, dst_scales.mask_, 0); |
918 | if (!scales_ok) return status::unimplemented; |
919 | |
920 | jcp.wei_adj_scale |
921 | = (weights_d.extra().flags & memory_extra_flags::scale_adjust) |
922 | ? weights_d.extra().scale_adjust |
923 | : 1.f; |
924 | |
925 | return status::success; |
926 | } |
927 | |
928 | template <cpu_isa_t isa> |
929 | void jit_uni_x8s8s32x_1x1_conv_kernel<isa>::init_scratchpad( |
930 | memory_tracking::registrar_t &scratchpad, |
931 | const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { |
932 | using namespace dnnl::impl::memory_tracking::names; |
933 | |
934 | const int wei_mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; |
935 | const dim_t scales_count = wei_mask == 0 ? 1 : jcp.oc * jcp.ngroups; |
936 | const dim_t count = nstl::max<dim_t>(scales_count, 8); |
937 | scratchpad.book<float>(key_conv_adjusted_scales, count); |
938 | } |
939 | |
// Explicit instantiations for every supported ISA / vector-register pairing,
// so the out-of-line template definitions above get emitted in this TU.
template struct _jit_uni_x8s8s32x_1x1_conv_kernel<avx2, Ymm>;
template struct _jit_uni_x8s8s32x_1x1_conv_kernel<sse41, Xmm>;
template struct jit_uni_x8s8s32x_1x1_conv_kernel<avx2>;
template struct jit_uni_x8s8s32x_1x1_conv_kernel<sse41>;
944 | } // namespace x64 |
945 | } // namespace cpu |
946 | } // namespace impl |
947 | } // namespace dnnl |
948 | |