1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/memory.hpp" |
19 | #include "common/nstl.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | |
22 | #include "cpu/x64/injectors/injector_utils.hpp" |
23 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
24 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
25 | #include "cpu/x64/jit_sse41_conv_kernel_f32.hpp" |
26 | |
27 | #define GET_OFF(field) offsetof(jit_conv_call_s, field) |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | namespace x64 { |
33 | |
34 | using namespace dnnl::impl::format_tag; |
35 | using namespace dnnl::impl::prop_kind; |
36 | using namespace dnnl::impl::utils; |
37 | |
38 | using namespace Xbyak; |
39 | |
40 | jit_sse41_conv_fwd_kernel_f32::jit_sse41_conv_fwd_kernel_f32( |
41 | const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, |
42 | const memory_desc_t &dst_md) |
43 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, sse41) |
44 | , jcp(ajcp) |
45 | , attr_(attr) { |
46 | if (jcp.with_eltwise || jcp.with_binary) { |
47 | static constexpr bool preserve_gpr = true; |
48 | static constexpr bool preserve_vmm = false; |
49 | static constexpr size_t helper_vmm_idx = 15; |
50 | const size_t tail_size = jcp.oc_without_padding % simd_w_; |
51 | static constexpr bool use_exact_tail_scalar_bcast = false; |
52 | |
53 | const binary_injector::rhs_arg_static_params_t rhs_arg_static_params { |
54 | helper_vmm_idx, r14, r15, r12, preserve_gpr, preserve_vmm, |
55 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), |
56 | memory_desc_wrapper(dst_md), tail_size, |
57 | use_exact_tail_scalar_bcast}; |
58 | const binary_injector::static_params_t static_params { |
59 | this->param1, rhs_arg_static_params}; |
60 | |
61 | postops_injector_ = utils::make_unique< |
62 | injector::jit_uni_postops_injector_t<sse41>>( |
63 | this, jcp.post_ops, static_params); |
64 | } |
65 | } |
66 | |
67 | void jit_sse41_conv_fwd_kernel_f32::oh_step_unroll_kw( |
68 | int ur_w, int pad_l, int pad_r, int oc_blocks) { |
69 | int kw = jcp.kw; |
70 | int stride_w = jcp.stride_w; |
71 | int dilate_w = jcp.dilate_w + 1; |
72 | int ic_blk = jcp.ic_block; |
73 | |
74 | for (int ki = 0; ki < kw; ki++) { |
75 | int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); |
76 | int jj_end = ur_w |
77 | - nstl::max(0, |
78 | div_up(ki * dilate_w + pad_r - (kw - 1) * dilate_w, |
79 | stride_w)); |
80 | for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { |
81 | for (int jj = jj_start; jj < jj_end; jj++) { |
82 | size_t inp_off = get_input_offset( |
83 | ifm2, filter_w_to_input(ki, jj, pad_l)); |
84 | movss(Xmm(oc_blocks * ur_w + jj + 1), |
85 | ptr[aux_reg_input + inp_off]); |
86 | shufps(Xmm(oc_blocks * ur_w + jj + 1), |
87 | Xmm(oc_blocks * ur_w + jj + 1), 0x0); |
88 | } |
89 | |
90 | for (int ii = 0; ii < oc_blocks; ii++) { |
91 | for (int jj = jj_start; jj < jj_end; jj++) { |
92 | movups(xmm0, |
93 | ptr[aux_reg_kernel |
94 | + get_kernel_offset(ii, ki, ifm2)]); |
95 | mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); |
96 | addps(Xmm(ur_w * ii + jj + 1), xmm0); |
97 | } |
98 | } |
99 | } |
100 | } |
101 | } |
102 | |
103 | void jit_sse41_conv_fwd_kernel_f32::oh_step_nopad( |
104 | int ur_w, int pad_l, int pad_r, int oc_blocks) { |
105 | Label kw_loop; |
106 | |
107 | int kw = jcp.kw; |
108 | int ic_blk = jcp.ic_block; |
109 | |
110 | xor_(ki_iter, ki_iter); |
111 | L(kw_loop); |
112 | { |
113 | int jj_start = 0; |
114 | int jj_end = ur_w; |
115 | for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { |
116 | for (int jj = jj_start; jj < jj_end; jj++) { |
117 | size_t inp_off = get_input_offset( |
118 | ifm2, filter_w_to_input(0, jj, pad_l)); |
119 | movss(Xmm(oc_blocks * ur_w + jj + 1), |
120 | ptr[aux_reg_input + inp_off]); |
121 | shufps(Xmm(oc_blocks * ur_w + jj + 1), |
122 | Xmm(oc_blocks * ur_w + jj + 1), 0x0); |
123 | } |
124 | for (int ii = 0; ii < oc_blocks; ii++) { |
125 | for (int jj = jj_start; jj < jj_end; jj++) { |
126 | movups(xmm0, |
127 | ptr[aux_reg_kernel |
128 | + get_kernel_offset(ii, 0, ifm2)]); |
129 | mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); |
130 | addps(Xmm(ur_w * ii + jj + 1), xmm0); |
131 | } |
132 | } |
133 | } |
134 | add(aux_reg_kernel, get_kernel_offset(0, 1, 0)); |
135 | add(aux_reg_input, get_input_offset(0, filter_w_to_input(1))); |
136 | |
137 | inc(ki_iter); |
138 | cmp(ki_iter, kw); |
139 | jl(kw_loop, T_NEAR); |
140 | } |
141 | } |
142 | |
/// Returns the index of the XMM accumulator for oc block `oc_block_idx` and
/// output column `ur_w_idx` in a tile that is `ur_w` columns wide.
/// Index 0 is skipped because xmm0 is the kernel's temporary register.
/// `static` gives this file-local helper internal linkage, so it cannot
/// clash with a same-named helper in another translation unit.
static int get_xmm_idx(
        const int ur_w, const int oc_block_idx, const int ur_w_idx) {
    return ur_w * oc_block_idx + ur_w_idx + 1;
}
146 | |
147 | Xmm get_xmm(const int ur_w, const int oc_block_idx, const int ur_w_idx) { |
148 | return Xmm(get_xmm_idx(ur_w, oc_block_idx, ur_w_idx)); |
149 | } |
150 | |
/// Calls `f(is_last_oc_block, oc_block_idx, ur_w_idx)` for every accumulator
/// of an (oc_blocks x ur_w) tile, with the oc-block index in the outer loop.
/// The flag is true only for the last oc block, which apply_postops() marks
/// as the tail block for the binary injector.
template <typename F>
static void iterate(const int oc_blocks, const int ur_w, const F &f) {
    const int last_block = oc_blocks - 1;
    for (int blk = 0; blk < oc_blocks; ++blk) {
        const bool is_last = (blk == last_block);
        for (int pos = 0; pos < ur_w; ++pos)
            f(is_last, blk, pos);
    }
}
159 | void jit_sse41_conv_fwd_kernel_f32::apply_postops( |
160 | const int oc_blocks, const int ur_w) { |
161 | injector_utils::vmm_index_set_t vmm_idxs; |
162 | if (jcp.with_binary) { |
163 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
164 | iterate(oc_blocks, ur_w, |
165 | [&](const bool mask_flag, const int i, const int j) { |
166 | const size_t o_off = get_output_offset(i, j); |
167 | const auto vmm_idx = get_xmm_idx(ur_w, i, j); |
168 | vmm_idxs.emplace(vmm_idx); |
169 | |
170 | rhs_arg_params.vmm_idx_to_out_reg.emplace( |
171 | vmm_idx, reg_output); |
172 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace( |
173 | vmm_idx, o_off); |
174 | if (mask_flag) |
175 | rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); |
176 | }); |
177 | |
178 | postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); |
179 | } else { |
180 | iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) { |
181 | vmm_idxs.emplace(get_xmm_idx(ur_w, i, j)); |
182 | }); |
183 | postops_injector_->compute_vector_range(vmm_idxs); |
184 | } |
185 | } |
186 | |
187 | void jit_sse41_conv_fwd_kernel_f32::width_blk_step( |
188 | int ur_w, int pad_l, int pad_r, int oc_blocks) { |
189 | int kw = jcp.kw; |
190 | int oc_blk = jcp.oc_block; |
191 | |
192 | xor_(simd_iter, simd_iter); |
193 | |
194 | mov(aux_reg_input, reg_input); |
195 | mov(aux_reg_kernel, reg_kernel); |
196 | |
197 | Label init_simd_iter_loop; |
198 | Label init_done; |
199 | Label init_first; |
200 | |
201 | L(init_simd_iter_loop); |
202 | |
203 | if (!jcp.with_sum) { |
204 | test(reg_ci_flag, FLAG_IC_FIRST); |
205 | jne(init_first, T_NEAR); |
206 | } |
207 | |
208 | for (int ii = 0; ii < oc_blocks; ii++) |
209 | for (int jj = 0; jj < ur_w; jj++) |
210 | movups(get_xmm(ur_w, ii, jj), |
211 | xword[reg_output + get_output_offset(ii, jj)]); |
212 | |
213 | if (jcp.with_sum && jcp.with_bias) { |
214 | test(reg_ci_flag, FLAG_IC_FIRST); |
215 | je(init_done, T_NEAR); |
216 | |
217 | for (int ii = 0; ii < oc_blocks; ii++) |
218 | for (int jj = 0; jj < ur_w; jj++) |
219 | addps(get_xmm(ur_w, ii, jj), |
220 | xword[reg_bias + sizeof(float) * ii * oc_blk]); |
221 | } |
222 | |
223 | jmp(init_done); |
224 | |
225 | L(init_first); |
226 | if (this->jcp.with_bias) { |
227 | for (int ii = 0; ii < oc_blocks; ii++) |
228 | for (int jj = 0; jj < ur_w; jj++) |
229 | movups(get_xmm(ur_w, ii, jj), |
230 | xword[reg_bias + sizeof(float) * ii * oc_blk]); |
231 | } else { |
232 | for (int ii = 0; ii < oc_blocks; ii++) |
233 | for (int jj = 0; jj < ur_w; jj++) { |
234 | const auto xmm = get_xmm(ur_w, ii, jj); |
235 | pxor(xmm, xmm); |
236 | } |
237 | } |
238 | |
239 | L(init_done); |
240 | |
241 | Label skip_kh_loop; |
242 | mov(kj, reg_kh); |
243 | if ((jcp.dilate_h >= jcp.ih) |
244 | || (jcp.kh - 1) * (jcp.dilate_h + 1) |
245 | < nstl::max(jcp.t_pad, jcp.b_pad)) { |
246 | cmp(kj, 0); |
247 | je(skip_kh_loop, T_NEAR); |
248 | } |
249 | Label kh_loop; |
250 | L(kh_loop); |
251 | { |
252 | if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) { |
253 | oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks); |
254 | sub(aux_reg_input, get_input_offset(0, filter_w_to_input(kw))); |
255 | add(aux_reg_input, get_input_offset(0, filter_h_to_input(1))); |
256 | } else { |
257 | oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks); |
258 | add(aux_reg_kernel, get_kernel_offset(0, kw, 0)); |
259 | add(aux_reg_input, get_input_offset(0, filter_h_to_input(1))); |
260 | } |
261 | |
262 | dec(kj); |
263 | cmp(kj, 0); |
264 | jg(kh_loop, T_NEAR); |
265 | } |
266 | |
267 | L(skip_kh_loop); |
268 | |
269 | if (jcp.with_eltwise || jcp.with_binary) { |
270 | Label regular_store; |
271 | test(reg_ci_flag, FLAG_IC_LAST); |
272 | je(regular_store, T_NEAR); |
273 | |
274 | apply_postops(oc_blocks, ur_w); |
275 | |
276 | L(regular_store); |
277 | } |
278 | |
279 | for (int ii = 0; ii < oc_blocks; ii++) { |
280 | for (int jj = 0; jj < ur_w; jj++) { |
281 | const Xmm reg_out = get_xmm(ur_w, ii, jj); |
282 | movups(xword[reg_output + get_output_offset(ii, jj)], reg_out); |
283 | } |
284 | } |
285 | |
286 | mov(aux_reg_kernel, reg_kernel); |
287 | mov(aux_reg_input, reg_input); |
288 | add(aux_reg_kernel, sizeof(float) * 4); |
289 | add(reg_output, sizeof(float) * 4); |
290 | add(reg_bias, sizeof(float) * 4); |
291 | inc(simd_iter); |
292 | cmp(simd_iter, 2); |
293 | jl(init_simd_iter_loop, T_NEAR); |
294 | |
295 | sub(reg_output, sizeof(float) * 8); |
296 | sub(reg_bias, sizeof(float) * 8); |
297 | } |
298 | |
/// Emits the width (ow) traversal for one output row and `oc_blocks` oc
/// blocks, split into width_blk_step() calls of `ur_w` columns each:
///   - "lpad" / "lrpad": a first step when there is left padding. It is
///     "lrpad" when that single step also needs the right padding.
///   - "middle": n_oi padding-free steps, emitted once inside a runtime loop
///     counted by oi_iter.
///   - "rpad": one step that touches the right padding.
///   - "tail": the remaining ow % ur_w columns.
/// reg_input and reg_output are advanced past every step except the tail.
inline void jit_sse41_conv_fwd_kernel_f32::solve_common(int oc_blocks) {
    int ur_w = jcp.ur_w;
    int ur_w_tail = jcp.ur_w_tail;
    int n_oi = jcp.ow / ur_w; // number of full ur_w-wide steps
    int iw = jcp.iw;
    int kw = jcp.kw;
    int str_w = jcp.stride_w;

    int l_pad = jcp.l_pad;
    int r_pad = nstl::max(0, jcp.r_pad);
    // Right padding seen by the last full step (tail columns excluded).
    int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, str_w,
            calculate_extended_filter_size(kw, jcp.dilate_w));
    // The last full step is emitted separately as "rpad".
    if (r_pad1 > 0) n_oi--;

    if (l_pad > 0) {
        // The first full step is emitted separately as "lpad"/"lrpad".
        n_oi--;
        if (n_oi < 0 && r_pad1 > 0)
            width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad"
        else
            width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad"
        add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w, l_pad)));
        add(reg_output, get_output_offset(0, ur_w));
    }

    Label ow_loop;
    xor_(oi_iter, oi_iter);

    if (n_oi > 0) {
        L(ow_loop);

        width_blk_step(ur_w, 0, 0, oc_blocks); // "middle"
        add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w)));
        add(reg_output, get_output_offset(0, ur_w));

        inc(oi_iter);
        cmp(oi_iter, n_oi);
        jl(ow_loop, T_NEAR);
    }

    // Skipped when the right padding was already handled by "lrpad"
    // (n_oi < 0 in that case).
    if (r_pad1 > 0 && n_oi >= 0) {
        width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad"
        add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w)));
        add(reg_output, get_output_offset(0, ur_w));
    }

    if (ur_w_tail != 0)
        width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail"
}
347 | |
348 | void jit_sse41_conv_fwd_kernel_f32::generate() { |
349 | this->preamble(); |
350 | |
351 | mov(reg_input, ptr[this->param1 + GET_OFF(src)]); |
352 | mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); |
353 | mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); |
354 | if (jcp.with_bias) mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); |
355 | mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); |
356 | mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); |
357 | mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); |
358 | |
359 | int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; |
360 | Label tail, exit; |
361 | |
362 | cmp(reg_oc_blocks, jcp.nb_oc_blocking); |
363 | jne(nb_oc_tail ? tail : exit, T_NEAR); |
364 | |
365 | solve_common(jcp.nb_oc_blocking); |
366 | jmp(exit, T_NEAR); |
367 | |
368 | if (nb_oc_tail) { |
369 | L(tail); |
370 | cmp(reg_oc_blocks, nb_oc_tail); |
371 | jne(exit, T_NEAR); |
372 | solve_common(nb_oc_tail); |
373 | } |
374 | |
375 | L(exit); |
376 | |
377 | this->postamble(); |
378 | |
379 | if (jcp.with_eltwise) postops_injector_->prepare_table(); |
380 | } |
381 | |
382 | status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, |
383 | const convolution_desc_t &cd, const memory_desc_wrapper &src_d, |
384 | const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, |
385 | const primitive_attr_t &attr, int nthreads) { |
386 | if (!mayiuse(sse41)) return status::unimplemented; |
387 | |
388 | jcp.nthr = nthreads; |
389 | |
390 | jcp.prop_kind = cd.prop_kind; |
391 | |
392 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
393 | const int ndims = src_d.ndims(); |
394 | jcp.ndims = ndims; |
395 | |
396 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
397 | jcp.mb = src_d.dims()[0]; |
398 | |
399 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
400 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
401 | |
402 | jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; |
403 | jcp.iw = src_d.dims()[ndims - 1]; |
404 | jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; |
405 | jcp.ow = dst_d.dims()[ndims - 1]; |
406 | |
407 | jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; |
408 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
409 | |
410 | jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; |
411 | jcp.l_pad = cd.padding[0][ndims - 3]; |
412 | |
413 | jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; |
414 | jcp.stride_w = cd.strides[ndims - 3]; |
415 | |
416 | jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[0]; |
417 | jcp.dilate_w = cd.dilates[ndims - 3]; |
418 | |
419 | int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
420 | int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
421 | jcp.r_pad = calculate_end_padding( |
422 | jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw); |
423 | jcp.b_pad = calculate_end_padding( |
424 | jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh); |
425 | bool kernel_outside_src = false || ext_kw <= jcp.l_pad |
426 | || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad |
427 | || ext_kh <= jcp.b_pad; |
428 | if (kernel_outside_src) return status::unimplemented; |
429 | |
430 | const auto dat_tag_nxc = (ndims == 3 ? nwc : nhwc); |
431 | const auto dat_tag_ncx = (ndims == 3 ? ncw : nchw); |
432 | const auto dat_tag_nCx8c = (ndims == 3 ? nCw8c : nChw8c); |
433 | const auto wei_tag_OIxio = with_groups |
434 | ? pick(ndims - 3, gOIw8i8o, gOIhw8i8o) |
435 | : pick(ndims - 3, OIw8i8o, OIhw8i8o); |
436 | const auto wei_tag_Oxio = with_groups ? pick(ndims - 3, gOwi8o, gOhwi8o) |
437 | : pick(ndims - 3, Owi8o, Ohwi8o); |
438 | |
439 | jcp.src_tag |
440 | = src_d.matches_one_of_tag(dat_tag_ncx, dat_tag_nxc, dat_tag_nCx8c); |
441 | jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag_OIxio, wei_tag_Oxio); |
442 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); |
443 | |
444 | const bool is_data_layout_nxc |
445 | = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); |
446 | |
447 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
448 | |
449 | const auto &post_ops = attr.post_ops_; |
450 | jcp.with_sum = post_ops.find(primitive_kind::sum) != -1; |
451 | const int eltwise_ind = post_ops.find(primitive_kind::eltwise); |
452 | jcp.with_eltwise = eltwise_ind != -1; |
453 | |
454 | const int binary_ind = post_ops.find(primitive_kind::binary); |
455 | jcp.with_binary = binary_ind != -1; |
456 | |
457 | jcp.post_ops = post_ops; |
458 | |
459 | using namespace injector; |
460 | static constexpr bool sum_at_pos_0_only = true; |
461 | static constexpr bool sum_requires_scale_one = true; |
462 | static constexpr bool sum_requires_zp_zero = true; |
463 | const bool post_ops_ok_ = post_ops_ok({sse41, {eltwise, binary, sum}, |
464 | jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, |
465 | sum_requires_zp_zero}); |
466 | if (!post_ops_ok_) return status::unimplemented; |
467 | |
468 | const bool flat = jcp.ic == 3; |
469 | const bool mimo = !flat; |
470 | |
471 | bool args_ok = true |
472 | && IMPLICATION(flat, |
473 | jcp.wei_tag == wei_tag_Oxio |
474 | && ((jcp.src_tag == dat_tag_ncx |
475 | && jcp.dst_tag == dat_tag_nCx8c) |
476 | || (jcp.src_tag == dat_tag_nxc |
477 | && jcp.dst_tag == dat_tag_nxc))) |
478 | && IMPLICATION(mimo, |
479 | jcp.wei_tag == wei_tag_OIxio |
480 | && ((jcp.src_tag == dat_tag_nCx8c |
481 | && jcp.dst_tag == dat_tag_nCx8c) |
482 | || (jcp.src_tag == dat_tag_nxc |
483 | && jcp.dst_tag == dat_tag_nxc))) |
484 | && jcp.ic <= src_d.padded_dims()[1] |
485 | && jcp.oc <= dst_d.padded_dims()[1]; |
486 | if (!args_ok) return status::unimplemented; |
487 | |
488 | const int simd_w = 8; // 2 SSE vectors processing at once |
489 | |
490 | jcp.ur_h = 1; /* no code-unrolling by h so far */ |
491 | jcp.ur_w = 3; |
492 | if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; |
493 | jcp.ur_w_tail = jcp.ow % jcp.ur_w; |
494 | |
495 | jcp.nb_oc_blocking |
496 | = is_data_layout_nxc ? 1 : 4; /* the optimal value for the kernel */ |
497 | |
498 | args_ok = true && jcp.oc % simd_w == 0 && jcp.l_pad <= jcp.ur_w |
499 | && IMPLICATION(jcp.kw > 7, |
500 | (jcp.t_pad == 0 && jcp.l_pad == 0) |
501 | || (jcp.stride_w == 1 && jcp.stride_h == 1)) |
502 | && IMPLICATION(mimo, jcp.ic % simd_w == 0); |
503 | if (!args_ok) return status::unimplemented; |
504 | |
505 | int r_pad_no_tail = nstl::max(0, |
506 | calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw, |
507 | jcp.stride_w, ext_kw)); |
508 | |
509 | // kernel needs 1 temporary YMM register |
510 | const int num_avail_regs = 15; |
511 | if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) { |
512 | /* recalculate ur_w, nb_oc_blocking and ur_w_tail */ |
513 | jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail, |
514 | nstl::min(jcp.ow, num_avail_regs / 2)); |
515 | jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w; |
516 | jcp.ur_w_tail = jcp.ow % jcp.ur_w; |
517 | /* check again ... */ |
518 | r_pad_no_tail = nstl::max(0, |
519 | calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw, |
520 | jcp.stride_w, ext_kw)); |
521 | |
522 | if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail)) |
523 | return status::unimplemented; |
524 | } |
525 | assert(jcp.nb_oc_blocking > 0); |
526 | assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs); |
527 | |
528 | jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; |
529 | jcp.nb_ic = jcp.ic / jcp.ic_block; |
530 | |
531 | jcp.oc_block = simd_w; |
532 | jcp.nb_oc = jcp.oc / jcp.oc_block; |
533 | |
534 | if (one_of(jcp.prop_kind, forward_training, forward_inference)) { |
535 | jcp.nb_ic_blocking = 12; |
536 | jcp.nb_ic_blocking_max = 16; |
537 | } else { |
538 | jcp.nb_ic_blocking = 1; |
539 | jcp.nb_ic_blocking_max = jcp.nb_ic_blocking; |
540 | } |
541 | |
542 | return status::success; |
543 | } |
544 | |
545 | } // namespace x64 |
546 | } // namespace cpu |
547 | } // namespace impl |
548 | } // namespace dnnl |
549 | |