1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * Copyright 2018 YANDEX LLC |
4 | * |
5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
6 | * you may not use this file except in compliance with the License. |
7 | * You may obtain a copy of the License at |
8 | * |
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
10 | * |
11 | * Unless required by applicable law or agreed to in writing, software |
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | * See the License for the specific language governing permissions and |
15 | * limitations under the License. |
16 | *******************************************************************************/ |
17 | |
18 | #include "common/c_types_map.hpp" |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/memory.hpp" |
21 | #include "common/nstl.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "common/utils.hpp" |
24 | |
25 | #include "cpu/platform.hpp" |
26 | #include "cpu/x64/injectors/injector_utils.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
28 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
29 | #include "cpu/x64/jit_avx2_conv_kernel_f32.hpp" |
30 | |
31 | #define GET_OFF(field) offsetof(jit_conv_call_s, field) |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | namespace x64 { |
37 | |
38 | using namespace dnnl::impl::prop_kind; |
39 | using namespace dnnl::impl::format_tag; |
40 | using namespace dnnl::impl::memory_tracking::names; |
41 | using namespace dnnl::impl::utils; |
42 | |
43 | using namespace Xbyak; |
44 | |
45 | jit_avx2_conv_fwd_kernel_f32::jit_avx2_conv_fwd_kernel_f32( |
46 | const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, |
47 | const memory_desc_t &dst_md) |
48 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, avx2) |
49 | , jcp(ajcp) |
50 | , attr_(attr) { |
51 | if (jcp.with_eltwise || jcp.with_binary) { |
52 | using namespace binary_injector; |
53 | static constexpr bool preserve_gpr = true; |
54 | static constexpr bool preserve_vmm = false; |
55 | static constexpr size_t helper_vmm_idx = 15; |
56 | static constexpr bool use_exact_tail_scalar_bcast = false; |
57 | const size_t tail_size = jcp.oc_without_padding % isa_simd_width_; |
58 | |
59 | rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r13, r14, |
60 | r15, preserve_gpr, preserve_vmm, |
61 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), |
62 | memory_desc_wrapper(dst_md), tail_size, |
63 | use_exact_tail_scalar_bcast}; |
64 | static_params_t static_params {this->param1, rhs_arg_static_params}; |
65 | |
66 | postops_injector_ = utils::make_unique< |
67 | injector::jit_uni_postops_injector_t<avx2>>( |
68 | this, jcp.post_ops, static_params); |
69 | } |
70 | } |
71 | |
72 | void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw( |
73 | int ur_w, int pad_l, int pad_r, int oc_blocks) { |
74 | int kw = jcp.kw; |
75 | int stride_w = jcp.stride_w; |
76 | int dilate_w = jcp.dilate_w + 1; |
77 | int ic_block = jcp.ic_block; |
78 | int ic_tail = jcp.ic_tail; |
79 | |
80 | for (int ki = 0; ki < kw; ki++) { |
81 | int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); |
82 | int jj_end = ur_w |
83 | - nstl::max(0, |
84 | div_up(ki * dilate_w + pad_r - (kw - 1) * dilate_w, |
85 | stride_w)); |
86 | |
87 | auto compute = [=](int cur_ic_blk) { |
88 | for (int ifm2 = 0; ifm2 < cur_ic_blk; ifm2++) { |
89 | for (int jj = jj_start; jj < jj_end; jj++) { |
90 | size_t inp_off = get_input_offset( |
91 | ifm2, filter_w_to_input(ki, jj, pad_l)); |
92 | vbroadcastss(Ymm(oc_blocks * ur_w + jj), |
93 | make_safe_addr( |
94 | aux_reg_input, inp_off, reg_long_offt)); |
95 | } |
96 | |
97 | for (int ii = 0; ii < oc_blocks; ii++) { |
98 | vmovups(ymm15, |
99 | make_safe_addr(aux_reg_kernel, |
100 | get_kernel_offset(ii, ki, ifm2), |
101 | reg_long_offt)); |
102 | for (int jj = jj_start; jj < jj_end; jj++) |
103 | if (mayiuse(avx2)) |
104 | vfmadd231ps(Ymm(ur_w * ii + jj), |
105 | Ymm(oc_blocks * ur_w + jj), ymm15); |
106 | else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support |
107 | vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); |
108 | vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), |
109 | ytmp); |
110 | } |
111 | } |
112 | } |
113 | }; |
114 | |
115 | if (ic_tail) { |
116 | if (jcp.ic == ic_tail) |
117 | compute(ic_tail); |
118 | else { |
119 | Label ic_blk_tail, ic_blk_done; |
120 | cmp(reg_channel, ic_block); |
121 | jl(ic_blk_tail, T_NEAR); |
122 | |
123 | compute(ic_block); |
124 | jmp(ic_blk_done, T_NEAR); |
125 | |
126 | L(ic_blk_tail); |
127 | compute(ic_tail); |
128 | |
129 | L(ic_blk_done); |
130 | } |
131 | } else { |
132 | compute(ic_block); |
133 | } |
134 | } |
135 | } |
136 | |
137 | void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad( |
138 | int ur_w, int pad_l, int pad_r, int oc_blocks) { |
139 | Label kw_loop; |
140 | |
141 | int kw = jcp.kw; |
142 | int ic_blk = jcp.ic_block; |
143 | |
144 | xor_(ki_iter, ki_iter); |
145 | L(kw_loop); |
146 | { |
147 | int jj_start = 0; |
148 | int jj_end = ur_w; |
149 | for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { |
150 | for (int jj = jj_start; jj < jj_end; jj++) { |
151 | size_t inp_off = get_input_offset( |
152 | ifm2, filter_w_to_input(0, jj, pad_l)); |
153 | vbroadcastss(Ymm(oc_blocks * ur_w + jj), |
154 | make_safe_addr(aux_reg_input, inp_off, reg_long_offt)); |
155 | } |
156 | for (int ii = 0; ii < oc_blocks; ii++) { |
157 | vmovups(ymm15, |
158 | make_safe_addr(aux_reg_kernel, |
159 | get_kernel_offset(ii, 0, ifm2), reg_long_offt)); |
160 | for (int jj = jj_start; jj < jj_end; jj++) |
161 | if (mayiuse(avx2)) |
162 | vfmadd231ps(Ymm(ur_w * ii + jj), |
163 | Ymm(oc_blocks * ur_w + jj), ymm15); |
164 | else { // Intel AVX support |
165 | vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); |
166 | vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp); |
167 | } |
168 | } |
169 | } |
170 | safe_add(aux_reg_kernel, get_kernel_offset(0, 1, 0), reg_long_offt); |
171 | safe_add(aux_reg_input, get_input_offset(0, filter_w_to_input(1)), |
172 | reg_long_offt); |
173 | |
174 | inc(ki_iter); |
175 | cmp(ki_iter, kw); |
176 | jl(kw_loop, T_NEAR); |
177 | } |
178 | } |
179 | |
// Returns the accumulator register index for output-channel block
// `oc_block_idx` and output column `ur_w_idx`, with `ur_w` registers per block.
static int get_ymm_idx(
        const int ur_w, const int oc_block_idx, const int ur_w_idx) {
    const int block_base = ur_w * oc_block_idx;
    return block_base + ur_w_idx;
}
184 | |
185 | static Ymm get_ymm(const int ur_w, const int oc_block_idx, const int ur_w_idx) { |
186 | return Ymm(get_ymm_idx(ur_w, oc_block_idx, ur_w_idx)); |
187 | } |
188 | |
// Calls f(mask_flag, i, j) for every i in [0, load_loop_blk) and every j in
// [0, ur), with i as the outer index. `mask_flag` is true only for the last
// value of i, and only when `load_dim_tail` is positive.
template <typename F>
void iterate(const int load_loop_blk, const int ur, const int load_dim_tail,
        const F &f) {
    const bool has_tail = load_dim_tail > 0;
    for (int blk = 0; blk < load_loop_blk; ++blk) {
        const bool is_last_blk = (blk + 1 == load_loop_blk);
        const bool mask_flag = has_tail && is_last_blk;
        for (int u = 0; u < ur; ++u) {
            f(mask_flag, blk, u);
        }
    }
}
// Overload without a tail: forwards to the four-argument iterate() with a
// tail size of zero, so `f` always receives mask_flag == false.
template <typename F>
void iterate(const int load_loop_blk, const int ur, const F &f) {
    constexpr int no_tail = 0;
    iterate(load_loop_blk, ur, no_tail, f);
}
202 | |
203 | void jit_avx2_conv_fwd_kernel_f32::apply_postops( |
204 | const int oc_blocks, const int ur_w, const int oc_tail) { |
205 | if (jcp.with_eltwise || jcp.with_binary) { |
206 | Label regular_store; |
207 | test(reg_ci_flag, FLAG_IC_LAST); |
208 | je(regular_store, T_NEAR); |
209 | |
210 | injector_utils::vmm_index_set_t vmm_idxs; |
211 | if (jcp.with_binary) { |
212 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, |
213 | rhs_arg_params_tail; |
214 | iterate(oc_blocks, ur_w, oc_tail, |
215 | [&](const bool mask_flag, const int i, const int j) { |
216 | const size_t aux_output_offset |
217 | = get_output_offset(i, j); |
218 | const auto vmm_idx = get_ymm_idx(ur_w, i, j); |
219 | vmm_idxs.emplace(vmm_idx); |
220 | |
221 | rhs_arg_params_tail.vmm_idx_to_out_reg.emplace( |
222 | vmm_idx, reg_output); |
223 | rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace( |
224 | vmm_idx, aux_output_offset); |
225 | if (mask_flag) |
226 | rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx); |
227 | }); |
228 | rhs_arg_params = rhs_arg_params_tail; |
229 | rhs_arg_params.vmm_tail_idx_.clear(); |
230 | |
231 | Label postops_done; |
232 | if (oc_tail) { |
233 | Label postops_no_tail; |
234 | test(reg_oc_flag, FLAG_OC_LAST); |
235 | je(postops_no_tail, T_NEAR); |
236 | postops_injector_->compute_vector_range( |
237 | vmm_idxs, rhs_arg_params_tail); |
238 | jmp(postops_done, T_NEAR); |
239 | L(postops_no_tail); |
240 | } |
241 | postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); |
242 | L(postops_done); |
243 | |
244 | } else { |
245 | iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) { |
246 | vmm_idxs.emplace(get_ymm_idx(ur_w, i, j)); |
247 | }); |
248 | postops_injector_->compute_vector_range(vmm_idxs); |
249 | } |
250 | L(regular_store); |
251 | } |
252 | } |
253 | |
// Emits the code that computes one block of `ur_w` output columns for
// `oc_blocks` output-channel blocks:
//   1. initialize the accumulators (from dst, from bias, or with zeros),
//   2. run the kd (5D only) / kh reduction loops, delegating each kernel row
//      to oh_step_nopad() or oh_step_unroll_kw(),
//   3. apply post-ops,
//   4. store the accumulators to dst.
// `pad_l` / `pad_r` are forwarded to the kernel-row emitters and also decide
// which of the two emitters is used.
void jit_avx2_conv_fwd_kernel_f32::width_blk_step(
        int ur_w, int pad_l, int pad_r, int oc_blocks) {
    int kw = jcp.kw;
    int oc_blk = jcp.oc_block;
    int oc_tail = jcp.oc_tail;

    // With an oc tail, reg_oc_blocks is saved on the stack (restored at the
    // end of this function) and the oc flag is loaded from the call params.
    // NOTE(review): presumably reg_oc_flag aliases reg_oc_blocks, which is why
    // the latter is saved — confirm against the register declarations.
    if (oc_tail) {
        push(reg_oc_blocks);
        mov(reg_oc_flag, ptr[param1 + GET_OFF(oc_flag)]);
    }

    // Emits accumulator initialization. When `is_tail` is set, the last oc
    // block is loaded with load_bytes() for only `oc_tail` floats.
    auto load_output_bias_and_add_bias = [=](bool is_tail) {
        Label init_done, init_first;

        // Without a sum post-op, the first input-channel pass (FLAG_IC_FIRST
        // set) jumps to init_first and does not read dst.
        if (!jcp.with_sum) {
            test(reg_ci_flag, FLAG_IC_FIRST);
            jne(init_first, T_NEAR);
        }

        // Load the accumulators from dst.
        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++) {
                const auto ymm = get_ymm(ur_w, ii, jj);
                if (is_tail && ii == oc_blocks - 1)
                    load_bytes(ymm, reg_output, get_output_offset(ii, jj),
                            oc_tail * sizeof(float));
                else
                    vmovups(ymm,
                            make_safe_addr(reg_output,
                                    get_output_offset(ii, jj), reg_long_offt));
            }

        // With sum and bias, the bias is added on top of the loaded dst
        // values, but only when FLAG_IC_FIRST is set.
        if (jcp.with_sum && jcp.with_bias) {
            test(reg_ci_flag, FLAG_IC_FIRST);
            je(init_done, T_NEAR);

            for (int ii = 0; ii < oc_blocks; ii++)
                for (int jj = 0; jj < ur_w; jj++) {
                    const Ymm ymm = get_ymm(ur_w, ii, jj);
                    if (is_tail && ii == oc_blocks - 1) {
                        load_bytes(ytmp, reg_bias, sizeof(float) * ii * oc_blk,
                                oc_tail * sizeof(float));
                        vaddps(ymm, ymm, ytmp);
                    } else {
                        vaddps(ymm, ymm,
                                yword[reg_bias + sizeof(float) * ii * oc_blk]);
                    }
                }
        }
        jmp(init_done, T_NEAR);

        // First-pass initialization: bias values if present, zeros otherwise.
        // When jcp.with_sum is set no jump targets this label, so the code
        // below is emitted but never reached at run time.
        L(init_first);

        if (jcp.with_bias) {
            for (int ii = 0; ii < oc_blocks; ii++)
                for (int jj = 0; jj < ur_w; jj++) {
                    const Ymm ymm = get_ymm(ur_w, ii, jj);
                    if (is_tail && ii == oc_blocks - 1)
                        load_bytes(ymm, reg_bias, sizeof(float) * ii * oc_blk,
                                oc_tail * sizeof(float));
                    else
                        vmovups(ymm,
                                yword[reg_bias + sizeof(float) * ii * oc_blk]);
                }
        } else {
            for (int ii = 0; ii < oc_blocks; ii++)
                for (int jj = 0; jj < ur_w; jj++) {
                    const Ymm ymm = get_ymm(ur_w, ii, jj);
                    uni_vpxor(ymm, ymm, ymm);
                }
        }
        L(init_done);
    };

    // Pick the initialization variant. When there are more oc blocks than fit
    // in one call (nb_oc > nb_oc_blocking), both variants are emitted and
    // FLAG_OC_LAST selects the tail one at run time.
    if (oc_tail) {
        if (jcp.nb_oc > jcp.nb_oc_blocking) {
            Label load_tail, load_done;
            test(reg_oc_flag, FLAG_OC_LAST);
            jne(load_tail, T_NEAR);

            load_output_bias_and_add_bias(false);
            jmp(load_done, T_NEAR);

            L(load_tail);
            load_output_bias_and_add_bias(true);

            L(load_done);
        } else {
            load_output_bias_and_add_bias(true);
        }
    } else {
        load_output_bias_and_add_bias(false);
    }

    // 1D/2D: the working src/weights pointers start at the block pointers.
    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
    }

    Label skip_kh_loop, skip_kd_loop, kd_loop;
    if (jcp.ndims == 5) {
        // 3D: reg_output and oi_iter are saved around the depth loop and
        // restored after it.
        push(reg_output);
        push(oi_iter);

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
        mov(aux_reg_inp_d, reg_input);

        // The run-time "zero depth iterations" check is emitted only for
        // configurations matching this dilation/padding condition.
        if ((jcp.dilate_d >= jcp.id)
                || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
            cmp(reg_ki, 0);
            je(skip_kd_loop, T_NEAR);
        }
        L(kd_loop);
        mov(kj, ptr[param1 + GET_OFF(kh_padding)]);
    } else {
        mov(kj, reg_kh);
    }

    // 3D: restart the row pointers from the current depth slice.
    if (jcp.ndims == 5) {
        mov(aux_reg_input, aux_reg_inp_d);
        mov(aux_reg_kernel, aux_reg_ker_d);
    }

    // Same idea as the kd check above, for the kernel-height loop.
    if ((jcp.dilate_h >= jcp.ih)
            || (jcp.kh - 1) * (jcp.dilate_h + 1)
                    < nstl::max(jcp.t_pad, jcp.b_pad)) {
        cmp(kj, 0);
        je(skip_kh_loop, T_NEAR);
    }
    Label kh_loop;
    L(kh_loop);
    {
        // The run-time kw loop is used only when there is no ic tail, the
        // kernel is at least 5 wide and this block needs no padding;
        // otherwise the kw loop is unrolled.
        if ((jcp.ic % jcp.ic_block == 0) && jcp.kw >= 5 && pad_l == 0
                && pad_r == 0) {
            oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks);
            // oh_step_nopad already advanced aux_reg_input by kw columns (and
            // aux_reg_kernel by kw taps); subtract that from the row step.
            add(aux_reg_input,
                    get_input_offset(0, filter_h_to_input(1))
                            - get_input_offset(0, filter_w_to_input(kw)));
        } else {
            oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
            // The unrolled variant leaves the pointers untouched: advance the
            // weights by one kernel row and the source by one row step.
            safe_add(
                    aux_reg_kernel, get_kernel_offset(0, kw, 0), reg_long_offt);
            safe_add(aux_reg_input, get_input_offset(0, filter_h_to_input(1)),
                    reg_long_offt);
        }

        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }

    L(skip_kh_loop);

    if (jcp.ndims == 5) {
        // Advance to the next depth slice of the source and the weights.
        safe_add(aux_reg_inp_d, get_input_offset(0, filter_d_to_input(1)),
                reg_long_offt);
        safe_add(aux_reg_ker_d, get_kernel_offset(0, jcp.kw * jcp.kh, 0),
                reg_long_offt);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_loop, T_NEAR);
        L(skip_kd_loop);

        pop(oi_iter);
        pop(reg_output);
    }

    apply_postops(oc_blocks, ur_w, oc_tail);

    // Emits the stores of all accumulators. When `is_tail` is set, the last
    // oc block stores only `tail` floats via store_bytes(); if the channels
    // are padded and binary post-ops are present, a full zero vector is
    // written to that location first.
    auto store_output = [=](bool is_tail, int tail) {
        const auto is_padding = jcp.oc_without_padding != jcp.oc;
        if (is_padding) uni_vxorps(ytmp, ytmp, ytmp);
        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++) {
                Ymm reg_out = get_ymm(ur_w, ii, jj);
                if (is_tail && ii == oc_blocks - 1) {
                    if (is_padding && jcp.with_binary) {
                        vmovups(make_safe_addr(reg_output,
                                        get_output_offset(ii, jj),
                                        reg_long_offt),
                                ytmp);
                    }
                    store_bytes(reg_out, reg_output, get_output_offset(ii, jj),
                            tail * sizeof(float));
                } else
                    vmovups(make_safe_addr(reg_output,
                                    get_output_offset(ii, jj), reg_long_offt),
                            reg_out);
            }
    };

    if (oc_tail) {
        // Mirrors the load dispatch above: FLAG_OC_LAST selects the tail
        // store at run time when several oc chunks exist.
        if (jcp.nb_oc > jcp.nb_oc_blocking) {
            Label store_tail, store_done;
            test(reg_oc_flag, FLAG_OC_LAST);
            jne(store_tail, T_NEAR);

            store_output(false, oc_tail);
            jmp(store_done, T_NEAR);

            L(store_tail);
            store_output(true, oc_tail);

            L(store_done);
        } else {
            store_output(true, oc_tail);
        }
    } else {
        Label regular_store;
        Label store_done;
        const int tail = jcp.oc_without_padding % jcp.oc_block;
        // Binary post-ops with a partially filled last oc block: the tail
        // store is used only when both FLAG_IC_LAST and FLAG_OC_LAST are set.
        if (jcp.with_binary && tail) {
            test(reg_ci_flag, FLAG_IC_LAST);
            je(regular_store, T_NEAR);
            // oc_tail is zero on this path, so the flag was not loaded at the
            // top of the function and this condition always holds here.
            if (!oc_tail) mov(reg_oc_flag, ptr[param1 + GET_OFF(oc_flag)]);
            test(reg_oc_flag, FLAG_OC_LAST);
            je(regular_store, T_NEAR);
            store_output(true, tail);
            jmp(store_done, T_NEAR);
        }

        L(regular_store);
        // is_tail is false, so the second argument is not used.
        store_output(false, oc_tail);

        L(store_done);
    }

    // Matches the push at the top of the function.
    if (oc_tail) pop(reg_oc_blocks);
}
484 | |
485 | inline void jit_avx2_conv_fwd_kernel_f32::solve_common(int oc_blocks) { |
486 | int ur_w = jcp.ur_w; |
487 | int ur_w_tail = jcp.ur_w_tail; |
488 | int n_oi = jcp.ow / ur_w; |
489 | int iw = jcp.iw; |
490 | int kw = jcp.kw; |
491 | int str_w = jcp.stride_w; |
492 | |
493 | int l_pad = jcp.l_pad; |
494 | int r_pad = nstl::max(0, jcp.r_pad); |
495 | int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, str_w, |
496 | calculate_extended_filter_size(kw, jcp.dilate_w)); |
497 | if (r_pad1 > 0) n_oi--; |
498 | |
499 | if (l_pad > 0) { |
500 | n_oi--; |
501 | if (n_oi < 0 && r_pad1 > 0) |
502 | width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad" |
503 | else |
504 | width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad" |
505 | add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w, l_pad))); |
506 | add(reg_output, get_output_offset(0, ur_w)); |
507 | } |
508 | |
509 | Label ow_loop; |
510 | xor_(oi_iter, oi_iter); |
511 | |
512 | if (n_oi > 0) { |
513 | L(ow_loop); |
514 | |
515 | width_blk_step(ur_w, 0, 0, oc_blocks); // "middle" |
516 | add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w))); |
517 | add(reg_output, get_output_offset(0, ur_w)); |
518 | |
519 | inc(oi_iter); |
520 | cmp(oi_iter, n_oi); |
521 | jl(ow_loop, T_NEAR); |
522 | } |
523 | |
524 | if (r_pad1 > 0 && n_oi >= 0) { |
525 | width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad" |
526 | add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w))); |
527 | add(reg_output, get_output_offset(0, ur_w)); |
528 | } |
529 | |
530 | if (ur_w_tail != 0) |
531 | width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail" |
532 | } |
533 | |
534 | void jit_avx2_conv_fwd_kernel_f32::generate() { |
535 | this->preamble(); |
536 | |
537 | mov(reg_input, ptr[this->param1 + GET_OFF(src)]); |
538 | mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); |
539 | mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); |
540 | if (jcp.with_bias) mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); |
541 | mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); |
542 | mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); |
543 | mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); |
544 | |
545 | if (is_src_layout_nxc()) |
546 | mov(reg_channel, ptr[param1 + GET_OFF(reduce_work)]); |
547 | |
548 | int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; |
549 | |
550 | Label tail, exit; |
551 | |
552 | if (jcp.nb_oc > jcp.nb_oc_blocking) { |
553 | cmp(reg_oc_blocks, jcp.nb_oc_blocking); |
554 | jne(nb_oc_tail ? tail : exit, T_NEAR); |
555 | |
556 | solve_common(jcp.nb_oc_blocking); |
557 | jmp(exit, T_NEAR); |
558 | |
559 | if (nb_oc_tail) { |
560 | L(tail); |
561 | cmp(reg_oc_blocks, nb_oc_tail); |
562 | jne(exit, T_NEAR); |
563 | solve_common(nb_oc_tail); |
564 | } |
565 | |
566 | L(exit); |
567 | } else if (jcp.nb_oc == jcp.nb_oc_blocking) { |
568 | solve_common(jcp.nb_oc_blocking); |
569 | } else { |
570 | solve_common(nb_oc_tail); |
571 | } |
572 | |
573 | this->postamble(); |
574 | |
575 | if (jcp.with_eltwise) postops_injector_->prepare_table(); |
576 | } |
577 | |
// Fills `jcp` with the kernel configuration derived from the convolution
// descriptor `cd`, the memory descriptors `src_d` / `weights_d` / `dst_d` and
// the primitive attributes `attr`.
//
// Returns status::success when this kernel can handle the problem and
// status::unimplemented otherwise (missing ISA support, unsupported layouts,
// padding, post-ops, or blocking constraints). On an unimplemented return,
// `jcp` may already be partially filled.
status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr) {
    // AVX is the minimum requirement; AVX2 is recorded when available.
    if (!mayiuse(avx)) return status::unimplemented;
    jcp.isa = mayiuse(avx2) ? avx2 : avx;

    jcp.nthr = dnnl_get_max_threads();

    jcp.prop_kind = cd.prop_kind;

    // Grouped weights carry one more dimension than the source.
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    // Channel counts are per group.
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    // Spatial sizes; dimensions missing for 1D/2D problems are set to 1.
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    // Leading padding, strides and dilations; missing dimensions use the
    // neutral value (0 padding, stride 1, 0 dilation).
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    // Trailing padding is derived from the sizes and the dilated kernel
    // extent rather than read from the descriptor.
    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
    // Reject problems where the dilated kernel is no larger than the padding
    // on some side, i.e. it could lie entirely outside the source.
    bool kernel_outside_src = false || ext_kw <= jcp.l_pad
            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
            || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
    if (kernel_outside_src) return status::unimplemented;

    // Candidate layouts for the current dimensionality.
    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
    const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
    auto wei_tag_OIxio = with_groups
            ? pick(ndims - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o)
            : pick(ndims - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o);
    auto wei_tag_Oxio = with_groups ? pick(ndims - 3, gOwi8o, gOhwi8o, gOdhwi8o)
                                    : pick(ndims - 3, Owi8o, Ohwi8o, Odhwi8o);

    jcp.src_tag
            = src_d.matches_one_of_tag(dat_tag_ncx, dat_tag_nxc, dat_tag_nCx8c);
    jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag_OIxio, wei_tag_Oxio);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());

    // True only when both source and destination use the nxc layout.
    bool is_data_layout_nxc
            = everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);

    // Disable this kernel on high width 1d object as gemm performs better until
    // optimizations can be made to fix it.
    if (is_data_layout_nxc && ndims == 3 && jcp.ow > 11 * 1024)
        return status::unimplemented;

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    const auto &post_ops = attr.post_ops_;

    // Record which kinds of post-ops are present.
    jcp.with_sum = post_ops.find(primitive_kind::sum) != -1;
    const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    const int binary_ind = post_ops.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;

    jcp.post_ops = post_ops;

    // "flat": fewer input channels than one SIMD vector; "mimo": the rest.
    const int simd_w = 8;
    const bool flat = jcp.ic < simd_w;
    const bool mimo = !flat;

    /* Grouped channel offset to support 'non-blocked data' format for
     * convolution sizes with '(input_channel / ngroups) < simd' */
    jcp.nonblk_group_off
            = one_of(jcp.src_tag, ncw, nchw, ncdhw) && jcp.ngroups > 1 ? jcp.ic
                                                                       : 1;

    // Channels are rounded up to the SIMD width only for non-nxc layouts
    // without groups; jcp.oc_without_padding keeps the original oc.
    bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        if (mimo) jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    // Eltwise and binary post-ops are supported with AVX2 only.
    if (jcp.with_eltwise || jcp.with_binary)
        if (!mayiuse(avx2)) return status::unimplemented;

    // Validate the post-op chain: eltwise, binary and sum are accepted, with
    // the restrictions on sum given by the three flags below.
    using namespace injector;
    static constexpr bool sum_at_pos_0_only = true;
    static constexpr bool sum_requires_scale_one = true;
    static constexpr bool sum_requires_zp_zero = true;
    const bool post_ops_ok_ = post_ops_ok({avx2, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
            sum_requires_zp_zero});
    if (!post_ops_ok_) return status::unimplemented;

    // Accepted layout combinations:
    //   flat: Oxio weights with (ncx src, nCx8c dst) or (nxc src, nxc dst);
    //   mimo: OIxio weights with (nCx8c src, nCx8c dst) or (nxc src, nxc dst).
    // The (possibly padded) channel counts must fit the padded dims.
    bool args_ok = true
            && IMPLICATION(flat,
                    jcp.wei_tag == wei_tag_Oxio
                            && ((jcp.src_tag == dat_tag_ncx
                                        && jcp.dst_tag == dat_tag_nCx8c)
                                    || (jcp.src_tag == dat_tag_nxc
                                            && jcp.dst_tag == dat_tag_nxc)))
            && IMPLICATION(mimo,
                    jcp.wei_tag == wei_tag_OIxio
                            && ((jcp.src_tag == dat_tag_nCx8c
                                        && jcp.dst_tag == dat_tag_nCx8c)
                                    || (jcp.src_tag == dat_tag_nxc
                                            && jcp.dst_tag == dat_tag_nxc)))
            && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= dst_d.padded_dims()[1];
    if (!args_ok) return status::unimplemented;

    jcp.ur_h = 1; /* no code-unrolling by h so far */
    jcp.ur_w = 3;

    jcp.oc_block = simd_w;
    jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);

    jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */

    // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively
    // Thus, we can only assign 14 or 15 YMMs for data storage
    const int num_avail_regs = mayiuse(avx2) ? 15 : 14;
    if (!mayiuse(avx2)) {
        // The kernel uses (nb_oc_blocking + 1) * ur_w registers: ur_w
        // accumulators per oc block plus ur_w broadcast source registers.
        if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) {
            // current register assignment requires more YMMs than available
            // adjust one of nb_oc_block, ur_w preserving to ur_w >= l_pad
            if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1)
                jcp.ur_w -= 1;
            else {
                // Try the largest of 3, 2 that divides nb_oc.
                for (int b = 3; b > 1; b--) {
                    if (jcp.nb_oc % b == 0) {
                        jcp.nb_oc_blocking = b;
                        break;
                    }
                }
                if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) {
                    // No optimal size for 'nb_oc_blocking' with regards to
                    // 'nb_oc', default to only unroll by 'ur_w'.
                    jcp.nb_oc_blocking = 1;
                }
            }
        }
    }

    // The unroll cannot exceed the output width.
    if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;

    // Remaining constraints: channels are multiples of the SIMD width for
    // non-nxc layouts, the left padding fits in one unrolled block, and
    // kernels wider than 7 need either no top/left padding or unit strides.
    args_ok = true && IMPLICATION(!is_data_layout_nxc, jcp.oc % simd_w == 0)
            && jcp.l_pad <= jcp.ur_w
            && IMPLICATION(jcp.kw > 7,
                    (jcp.t_pad == 0 && jcp.l_pad == 0)
                            || (jcp.stride_w == 1 && jcp.stride_h == 1))
            && IMPLICATION(mimo && !is_data_layout_nxc, jcp.ic % simd_w == 0);
    if (!args_ok) return status::unimplemented;

    // Channel tails exist for the nxc layout; otherwise an oc tail is kept
    // only for binary post-ops, based on the unpadded channel count.
    jcp.ic_tail = is_data_layout_nxc ? jcp.ic % simd_w : 0;
    jcp.oc_tail = is_data_layout_nxc
            ? jcp.oc % simd_w
            : (jcp.with_binary ? jcp.oc_without_padding % simd_w : 0);

    // Right padding seen by the last full ur_w block (the tail excluded).
    int r_pad_no_tail = nstl::max(0,
            calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
                    jcp.stride_w, ext_kw));

    // If that padding spans more than one unrolled block, enlarge ur_w (and
    // re-derive the oc blocking from the register budget).
    if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
        /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
        jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
                nstl::min(jcp.ow, num_avail_regs / 2));
        jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
        jcp.ur_w_tail = jcp.ow % jcp.ur_w;
        /* check again ... */
        r_pad_no_tail = nstl::max(0,
                calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
                        jcp.stride_w, ext_kw));
        if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
            return status::unimplemented;
    }
    assert(jcp.nb_oc_blocking > 0);
    assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);

    // Flat problems keep all input channels in a single block.
    jcp.ic_block = flat ? jcp.ic : simd_w;
    jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);

    jcp.nb_ic_blocking = 12;
    jcp.nb_ic_blocking_max = 16;

    /* adjust the thread decomposition
     * to improve the perf for small problem size
     * the threshold L1_cache_size is empirical
     * simply set the thread as 4 for now
     * TODO: Add get_thr_eff func to get the optimal thread number*/
    size_t wei_size = (size_t)sizeof(float) * jcp.ic * jcp.oc * jcp.kh * jcp.kw
            * jcp.kd;
    size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih
            * jcp.iw * jcp.id;
    size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh
            * jcp.ow * jcp.od;
    size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size);

    const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);

    // Cap the thread count at 4 when there are fewer groups than threads and
    // the whole working set is below the cache-size threshold.
    if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size) {
        jcp.nthr = nstl::min(jcp.nthr, 4);
    }

    return status::success;
}
814 | |
815 | void jit_avx2_conv_fwd_kernel_f32::init_scratchpad( |
816 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { |
817 | if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) |
818 | scratchpad.book<float>(key_conv_padded_bias, jcp.oc); |
819 | } |
820 | |
821 | void jit_avx2_conv_bwd_data_kernel_f32::compute_loop( |
822 | int ur_w, int l_overflow, int r_overflow) { |
823 | int kw = jcp.kw; |
824 | int ow = jcp.ow; |
825 | |
826 | int oc_block = jcp.oc_block; |
827 | int nb_ic_block = jcp.nb_ic_blocking; |
828 | int stride_w = jcp.stride_w; |
829 | int stride_h = jcp.stride_h; |
830 | int oc_tail = jcp.oc_tail; |
831 | int ic_tail = jcp.ic_tail; |
832 | |
833 | Label kd_loop, skip_kd_loop; |
834 | Label oc_loop, skip_oc_loop; |
835 | |
836 | for (int ii = 0; ii < nb_ic_block; ii++) |
837 | for (int jj = 0; jj < ur_w; jj++) { |
838 | uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), |
839 | Ymm(ur_w * ii + jj)); |
840 | } |
841 | |
842 | if (oc_tail) { |
843 | push(reg_long_offt); |
844 | mov(reg_reduce_work, ptr[param1 + GET_OFF(reduce_work)]); |
845 | } |
846 | |
847 | if (one_of(jcp.ndims, 3, 4)) { |
848 | cmp(reg_channel_work, 0); |
849 | jle(skip_oc_loop, T_NEAR); |
850 | xor_(reg_channel, reg_channel); |
851 | |
852 | mov(aux_reg_ddst_oc_loop, reg_ddst); |
853 | mov(aux_reg_kernel_oc_loop, reg_kernel); |
854 | |
855 | L(oc_loop); |
856 | mov(aux_reg_ddst, aux_reg_ddst_oc_loop); |
857 | mov(aux_reg_kernel, aux_reg_kernel_oc_loop); |
858 | } |
859 | |
860 | if (jcp.ndims == 5) { |
861 | assert(jcp.nb_oc_blocking == 1); |
862 | push(oi_iter); |
863 | |
864 | mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]); |
865 | cmp(reg_ki, 0); |
866 | jle(skip_kd_loop, T_NEAR); |
867 | |
868 | mov(aux_reg_dst_d, reg_ddst); |
869 | mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]); |
870 | |
871 | L(kd_loop); |
872 | mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]); |
873 | } else { |
874 | mov(kj, reg_kh); |
875 | } |
876 | |
877 | if (jcp.ndims == 5) { |
878 | mov(aux_reg_ddst, aux_reg_dst_d); |
879 | mov(aux_reg_kernel, aux_reg_ker_d); |
880 | } |
881 | |
882 | Label kh_loop, skip_kh_loop; |
883 | cmp(kj, 0); |
884 | jle(skip_kh_loop, T_NEAR); |
885 | |
886 | L(kh_loop); |
887 | { |
888 | for (int ki = 0; ki < kw; ki++) { |
889 | int jj_start = get_iw_start(ki, l_overflow); // 0; |
890 | int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w; |
891 | |
892 | auto compute = [=](int cur_oc_blk) { |
893 | for (int ofm2 = 0; ofm2 < cur_oc_blk; ofm2++) { |
894 | for (int jj = jj_start; jj < jj_end; jj += stride_w) { |
895 | int aux_output_offset = get_ddst_offset( |
896 | 0, filter_w_to_ddst(ki, jj, jcp.l_pad), ofm2); |
897 | vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w), |
898 | ptr[aux_reg_ddst + aux_output_offset]); |
899 | } |
900 | |
901 | for (int ii = 0; ii < nb_ic_block; ii++) { |
902 | vmovups(ymm15, |
903 | ptr[aux_reg_kernel |
904 | + get_kernel_offset(0, ii, ki, ofm2)]); |
905 | for (int jj = jj_start; jj < jj_end; jj += stride_w) |
906 | vfmadd231ps(Ymm(ur_w * ii + jj), |
907 | Ymm(nb_ic_block * ur_w + jj / stride_w), |
908 | ymm15); |
909 | } |
910 | } |
911 | }; |
912 | |
913 | if (oc_tail) { |
914 | if (jcp.oc == oc_tail) |
915 | compute(oc_tail); |
916 | else { |
917 | Label oc_blk_tail, oc_blk_done; |
918 | cmp(reg_reduce_work, oc_block); |
919 | jl(oc_blk_tail, T_NEAR); |
920 | compute(oc_block); |
921 | jmp(oc_blk_done, T_NEAR); |
922 | |
923 | L(oc_blk_tail); |
924 | compute(oc_tail); |
925 | |
926 | L(oc_blk_done); |
927 | } |
928 | } else { |
929 | compute(oc_block); |
930 | } |
931 | } |
932 | |
933 | add(aux_reg_kernel, get_kernel_offset(0, 0, stride_h * kw, 0)); |
934 | sub(aux_reg_ddst, get_ddst_offset(0, (jcp.dilate_h + 1) * ow, 0)); |
935 | |
936 | dec(kj); |
937 | cmp(kj, 0); |
938 | jg(kh_loop, T_NEAR); |
939 | } |
940 | L(skip_kh_loop); |
941 | |
942 | if (jcp.ndims == 5) { |
943 | sub(aux_reg_dst_d, |
944 | get_ddst_offset(0, (jcp.dilate_d + 1) * jcp.oh * ow, 0)); |
945 | add(aux_reg_ker_d, get_kernel_offset(0, 0, jcp.kw * jcp.kh, 0)); |
946 | |
947 | dec(reg_ki); |
948 | cmp(reg_ki, 0); |
949 | jg(kd_loop, T_NEAR); |
950 | L(skip_kd_loop); |
951 | |
952 | pop(oi_iter); |
953 | } |
954 | |
955 | if (one_of(jcp.ndims, 3, 4)) { |
956 | int ddst_oc_shift = get_ddst_offset(1, 0, 0); |
957 | int kernel_oc_shift = get_kernel_offset(1, 0, 0, 0); |
958 | |
959 | add(aux_reg_ddst_oc_loop, ddst_oc_shift); |
960 | add(aux_reg_kernel_oc_loop, kernel_oc_shift); |
961 | |
962 | if (oc_tail) sub(reg_reduce_work, jcp.oc_block); |
963 | inc(reg_channel); |
964 | cmp(reg_channel, reg_channel_work); |
965 | jl(oc_loop, T_NEAR); |
966 | |
967 | L(skip_oc_loop); |
968 | mov(reg_channel, ptr[param1 + GET_OFF(channel)]); |
969 | } |
970 | |
971 | if (oc_tail) pop(reg_long_offt); |
972 | |
973 | auto load_store_dsrc = [=](bool is_tail) { |
974 | mov(reg_channel, ptr[param1 + GET_OFF(channel)]); |
975 | Label no_update_label; |
976 | cmp(reg_channel, 0); |
977 | je(no_update_label, T_NEAR); |
978 | |
979 | for (int ii = 0; ii < nb_ic_block; ii++) |
980 | for (int jj = 0; jj < ur_w; jj++) { |
981 | if (is_tail && ii == nb_ic_block - 1) |
982 | load_bytes(Ymm(15), reg_dsrc, get_dsrc_offset(ii, jj), |
983 | ic_tail * sizeof(float)); |
984 | else |
985 | vmovups(Ymm(15), |
986 | make_safe_addr(reg_dsrc, get_dsrc_offset(ii, jj), |
987 | reg_long_offt)); |
988 | vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(15)); |
989 | } |
990 | |
991 | L(no_update_label); |
992 | |
993 | for (int ii = 0; ii < nb_ic_block; ii++) |
994 | for (int jj = 0; jj < ur_w; jj++) { |
995 | if (is_tail && ii == nb_ic_block - 1) |
996 | store_bytes(Ymm(ur_w * ii + jj), reg_dsrc, |
997 | get_dsrc_offset(ii, jj), ic_tail * sizeof(float)); |
998 | else |
999 | vmovups(make_safe_addr(reg_dsrc, get_dsrc_offset(ii, jj), |
1000 | reg_long_offt), |
1001 | Ymm(ur_w * ii + jj)); |
1002 | } |
1003 | }; |
1004 | |
1005 | if (ic_tail) { |
1006 | Label load_store_tail, load_store_done; |
1007 | mov(reg_ci_flag, ptr[param1 + GET_OFF(flags)]); |
1008 | test(reg_ci_flag, FLAG_IC_LAST); |
1009 | jne(load_store_tail, T_NEAR); |
1010 | |
1011 | load_store_dsrc(false); |
1012 | jmp(load_store_done, T_NEAR); |
1013 | |
1014 | L(load_store_tail); |
1015 | load_store_dsrc(true); |
1016 | |
1017 | L(load_store_done); |
1018 | } else { |
1019 | load_store_dsrc(false); |
1020 | } |
1021 | } |
1022 | |
1023 | void jit_avx2_conv_bwd_data_kernel_f32::generate() { |
1024 | preamble(); |
1025 | |
1026 | mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); |
1027 | mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); |
1028 | mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); |
1029 | mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); |
1030 | mov(reg_channel, ptr[param1 + GET_OFF(channel)]); |
1031 | mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]); |
1032 | |
1033 | int ddst_shift = get_ddst_offset(0, filter_w_to_ddst(0, jcp.ur_w), 0); |
1034 | int dsrc_shift = get_dsrc_offset(0, jcp.ur_w); |
1035 | |
1036 | const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
1037 | |
1038 | int l_overflow = nstl::max(0, (ext_kw - 1 - jcp.l_pad) / jcp.stride_w); |
1039 | int r_overflow = nstl::max( |
1040 | 0, (ext_kw - 1 - nstl::max(0, jcp.r_pad)) / jcp.stride_w); |
1041 | int r_overflow1 = nstl::max( |
1042 | 0, (ext_kw - 1 - jcp.r_pad - jcp.ur_w_tail) / jcp.stride_w); |
1043 | |
1044 | int n_oi = jcp.iw / jcp.ur_w; |
1045 | if (r_overflow1 > 0) n_oi--; |
1046 | |
1047 | if (jcp.ur_w == jcp.iw) { |
1048 | compute_loop(jcp.ur_w, l_overflow, r_overflow); |
1049 | } else if (n_oi == 0) { |
1050 | compute_loop(jcp.ur_w, l_overflow, r_overflow1); |
1051 | add(reg_dsrc, dsrc_shift); |
1052 | add(reg_ddst, ddst_shift); |
1053 | if (jcp.ur_w_tail != 0) compute_loop(jcp.ur_w_tail, 0, r_overflow); |
1054 | } else { |
1055 | xor_(oi_iter, oi_iter); |
1056 | if (l_overflow > 0) { |
1057 | compute_loop(jcp.ur_w, l_overflow, 0); |
1058 | add(reg_dsrc, dsrc_shift); |
1059 | add(reg_ddst, ddst_shift); |
1060 | inc(oi_iter); |
1061 | } |
1062 | |
1063 | if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) { |
1064 | Label ow_loop; |
1065 | L(ow_loop); |
1066 | { |
1067 | compute_loop(jcp.ur_w, 0, 0); |
1068 | add(reg_dsrc, dsrc_shift); |
1069 | add(reg_ddst, ddst_shift); |
1070 | inc(oi_iter); |
1071 | cmp(oi_iter, n_oi); |
1072 | jl(ow_loop, T_NEAR); |
1073 | } |
1074 | } |
1075 | |
1076 | if (r_overflow1 > 0) { |
1077 | compute_loop(jcp.ur_w, 0, r_overflow1); |
1078 | add(reg_dsrc, dsrc_shift); |
1079 | add(reg_ddst, ddst_shift); |
1080 | } |
1081 | |
1082 | if (jcp.ur_w_tail != 0) compute_loop(jcp.ur_w_tail, 0, r_overflow); |
1083 | } |
1084 | |
1085 | this->postamble(); |
1086 | } |
1087 | |
/* Fills `jcp` for the AVX2 f32 backward-data kernel from the convolution
 * descriptor `cd` and the diff_src / weights / diff_dst memory descriptors.
 *
 * Returns status::unimplemented as soon as a configuration that this kernel
 * does not handle is detected (jcp may be partially written at that point),
 * and status::success once all fields, including the register blocking
 * (ur_w, nb_ic_blocking) and the thread count, have been chosen. */
status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
        const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &diff_dst_d) {
    if (!mayiuse(avx2)) return status::unimplemented;

    jcp.nthr = dnnl_get_max_threads();

    /* grouped weights carry one extra leading dimension */
    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;

    int ndims = diff_src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = diff_src_d.dims()[0];

    /* channel counts are per group */
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;

    /* spatial sizes; dimensions absent for the given ndims default to 1 */
    jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims - 2];
    jcp.iw = diff_src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
    jcp.ow = diff_dst_d.dims()[ndims - 1];

    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    /* leading paddings, strides and dilations; absent dimensions get the
     * neutral value (0 for padding/dilation, 1 for stride) */
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    /* dilation combined with a non-unit stride along the same axis is
     * rejected */
    if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
            || (jcp.dilate_d != 0 && jcp.stride_d != 1)
            || (jcp.dilate_h != 0 && jcp.stride_h != 1))
        return status::unimplemented;

    const int simd_w = 8;

    /* derivatives */
    jcp.idp = jcp.id + 2 * jcp.f_pad;
    jcp.ihp = jcp.ih + 2 * jcp.t_pad;
    jcp.iwp = jcp.iw + 2 * jcp.l_pad;
    jcp.ohp = jcp.oh; /* do we really need */
    jcp.owp = jcp.ow; /* padded output ??? */

    /* candidate layouts: channels-last (nxc) or 8-channel blocked data,
     * and 8o8i blocked weights */
    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
    auto wei_tag = with_groups
            ? pick(ndims - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i)
            : pick(ndims - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i);

    jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
    jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
    jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);

    jcp.typesize_in = types::data_type_size(diff_src_d.data_type());
    jcp.typesize_out = types::data_type_size(diff_dst_d.data_type());

    /* true only when both diff_src and diff_dst are channels-last */
    bool is_data_layout_nxc
            = everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
    /* channel counts are rounded up only for blocked, non-grouped data */
    bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1;

    /* gemm-based convolution performs better in these cases */
    if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1)
        return status::unimplemented;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    jcp.ic_block = (!is_data_layout_nxc && jcp.ic % simd_w) ? 1 : simd_w;
    jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);

    /* channel tails are only tracked for the channels-last layout */
    jcp.ic_tail = is_data_layout_nxc ? jcp.ic % simd_w : 0;
    jcp.oc_tail = is_data_layout_nxc ? jcp.oc % simd_w : 0;

    jcp.oc_block = simd_w;
    jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);

    /* defaults; ur_w and nb_ic_blocking are refined by the search below */
    jcp.ur_h = 1; /* no code-unrolling by h so far */
    jcp.nb_ic_blocking = 1;
    jcp.nb_oc_blocking = 1;
    jcp.ur_w = 1;

    /* 1D/2D problems with a narrow output row use a larger oc blocking */
    if (one_of(ndims, 3, 4) && jcp.ow < 40)
        jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2;

    auto required_dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;

    /* equal h/w strides, unit depth stride, SIMD-aligned channels for the
     * blocked layout, channel counts within the padded dims, and matching
     * layouts are all required */
    bool args_ok = true && jcp.stride_w == jcp.stride_h && jcp.stride_d == 1
            && IMPLICATION(!is_data_layout_nxc,
                    jcp.ic % simd_w == 0 && jcp.oc % simd_w == 0)
            && jcp.ic <= diff_src_d.padded_dims()[1]
            && jcp.oc <= diff_dst_d.padded_dims()[1]
            && jcp.dst_tag == required_dat_tag
            && jcp.src_tag == required_dat_tag && jcp.wei_tag == wei_tag;
    if (!args_ok) return status::unimplemented;

    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    const int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);

    /* trailing paddings; unlike the backward-weights init_conf these are
     * not clamped at zero, so they may be negative */
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);

    /* reject paddings that are at least as large as the extended filter */
    bool kernel_outside_src = false || ext_kw <= jcp.l_pad
            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
            || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
    if (kernel_outside_src) return status::unimplemented;

    int l_overflow = nstl::max(0, (ext_kw - 1 - jcp.l_pad) / jcp.stride_w);

    const int max_regs = 15; /* Maximum number of registers available for
                                result accumulation and delta dst data.
                                One additional register is reserved for weights
                                data. */

    /* Find the best blocking with maximum number of fma instructions
       per ur_w * nb_ic_blocking compute loops. Number of required registers
       is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
       ur_w must be divisible by stride_w */
    if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers
                                         distribution exceeds max_regs */
        return status::unimplemented;

    /* exhaustive search over ic blocking b in [1, 4] and unroll widths u
     * that are multiples of stride_w; ties on the fma count are broken in
     * favour of the larger ur_w (compared against the current jcp.ur_w) */
    int best_nfmas = 0;
    for (int b = 1; b <= 4; b++) {
        if (jcp.nb_ic % b != 0) continue;

        for (int u = jcp.stride_w; u * b + u / jcp.stride_w <= max_regs
                && u < jcp.iw + jcp.stride_w;
                u += jcp.stride_w) {
            int ur_w = nstl::min(u, jcp.iw);
            /* maximum 1 step with l_overflow so far */
            if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw) continue;
            int nfmas = div_up(ur_w, jcp.stride_w) * b;
            if (nfmas > best_nfmas
                    || (nfmas == best_nfmas && jcp.ur_w < ur_w)) {
                jcp.ur_w = ur_w;
                jcp.nb_ic_blocking = b;
                best_nfmas = nfmas;
            }
        }
    }
    if (best_nfmas == 0) /* can't find appropriate blocking */
        return status::unimplemented;

    jcp.ur_w_tail = jcp.iw % jcp.ur_w;

    int r_overflow_no_tail = nstl::max(
            0, (ext_kw - 1 - jcp.r_pad - jcp.ur_w_tail) / jcp.stride_w);

    bool tails_not_ok = false
            /* maximum 1 ur_w block with r_overflow so far */
            || r_overflow_no_tail * jcp.stride_w > jcp.ur_w
            /* ur_w must be a multiple of stride */
            || ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
            /* r_pad must not extend beyond ur_w_tail */
            || ((jcp.iw > jcp.ur_w) && (jcp.r_pad + jcp.ur_w_tail < 0));
    if (tails_not_ok) return status::unimplemented;

    /* adjust the thread decomposition
     * to improve the perf for small problem size
     * the threshold L1_cache_size is empirical
     * simply set the thread to 4 for now
     * TODO: Add get_thr_eff func to get optimal thread number */
    size_t wei_size = (size_t)sizeof(float) * jcp.ic * jcp.oc * jcp.kh * jcp.kw
            * jcp.kd;
    size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih
            * jcp.iw * jcp.id;
    size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh
            * jcp.ow * jcp.od;
    size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size);
    const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);

    /* cap the thread count at 4 when there are fewer groups than threads
     * and the whole problem is smaller than the per-core L1 size */
    if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size) {
        jcp.nthr = nstl::min(jcp.nthr, 4);
    }

    return status::success;
}
1286 | |
1287 | void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad( |
1288 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { |
1289 | UNUSED(scratchpad); |
1290 | UNUSED(jcp); |
1291 | } |
1292 | |
1293 | void jit_avx2_conv_bwd_weights_kernel_f32::generate() { |
1294 | this->preamble(); |
1295 | |
1296 | mov(reg_input, ptr[this->param1 + GET_OFF(src)]); |
1297 | mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); |
1298 | mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); |
1299 | compute_oh_loop_common(); |
1300 | this->postamble(); |
1301 | } |
1302 | |
/* Fills `jcp` for the AVX2 f32 backward-weights kernel from the convolution
 * descriptor `cd` and the src / diff_weights / diff_dst memory descriptors.
 *
 * Returns status::unimplemented when the CPU lacks AVX2 or when the shapes,
 * paddings, dilations or layouts fall outside what the checks below accept
 * (jcp may be partially written at that point); status::success otherwise. */
status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &diff_weights_d,
        const memory_desc_wrapper &diff_dst_d) {
    if (!mayiuse(avx2)) return status::unimplemented;

    /* grouped weights carry one extra leading dimension */
    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    /* channel counts are per group; the unpadded values are kept because
     * oc / ic may be rounded up further down */
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = jcp.ic;

    /* spatial sizes; dimensions absent for the given ndims default to 1 */
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
    jcp.ow = diff_dst_d.dims()[ndims - 1];

    jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];

    /* leading paddings, strides and dilations; absent dimensions get the
     * neutral value (0 for padding/dilation, 1 for stride) */
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    /* candidate data layouts (plain ncx, channels-last nxc, 8-channel
     * blocked) and weights layouts (8i8o blocked, or Oxio with 8o block) */
    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
    const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
    auto wei_tag_OIxio = with_groups
            ? pick(ndims - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o)
            : pick(ndims - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o);
    auto wei_tag_Oxio = with_groups ? pick(ndims - 3, gOwi8o, gOhwi8o, gOdhwi8o)
                                    : pick(ndims - 3, Owi8o, Ohwi8o, Odhwi8o);

    jcp.src_tag
            = src_d.matches_one_of_tag(dat_tag_ncx, dat_tag_nxc, dat_tag_nCx8c);
    jcp.wei_tag
            = diff_weights_d.matches_one_of_tag(wei_tag_OIxio, wei_tag_Oxio);
    jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);

    /* true only when both src and diff_dst are channels-last */
    bool is_data_layout_nxc
            = everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);

    jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;

    /* "flat" is the exactly-3-input-channels case; everything else is
     * treated as "mimo" */
    const bool flat = jcp.ic == 3;
    const bool mimo = !flat;

    const int simd_w = 8;

    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    /* trailing paddings, clamped so they are never negative */
    jcp.r_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
    jcp.b_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
    jcp.back_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));

    /* h/w paddings must be smaller than the extended filter size and no
     * depth padding is accepted at all */
    const int max_h_pad = ext_kh;
    const int max_w_pad = ext_kw;
    const bool boundaries_ok = true && jcp.t_pad < max_h_pad
            && jcp.b_pad < max_h_pad && jcp.l_pad < max_w_pad
            && jcp.r_pad < max_w_pad && jcp.f_pad == 0 && jcp.back_pad == 0;
    if (!boundaries_ok) return status::unimplemented;

    /* channel counts are rounded up only for non-channels-last,
     * non-grouped data; ic is left untouched on the flat path */
    bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        if (mimo) jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    /* channel tails are only tracked for the channels-last layout */
    jcp.ic_tail = is_data_layout_nxc ? jcp.ic % simd_w : 0;
    jcp.oc_tail = is_data_layout_nxc ? jcp.oc % simd_w : 0;

    /* layout requirements per path (flat: Oxio weights with ncx->nCx8c or
     * nxc->nxc data; mimo: OIxio weights with nCx8c->nCx8c or nxc->nxc
     * data), SIMD-aligned channels for the blocked layout, filter size and
     * padding limits, no dilation, and channel counts within padded dims */
    bool args_ok = true
            && IMPLICATION(flat,
                    jcp.wei_tag == wei_tag_Oxio
                            && ((jcp.src_tag == dat_tag_ncx
                                        && jcp.dst_tag == dat_tag_nCx8c)
                                    || (jcp.src_tag == dat_tag_nxc
                                            && jcp.dst_tag == dat_tag_nxc)))
            && IMPLICATION(mimo,
                    jcp.wei_tag == wei_tag_OIxio
                            && ((jcp.src_tag == dat_tag_nCx8c
                                        && jcp.dst_tag == dat_tag_nCx8c)
                                    || (jcp.src_tag == dat_tag_nxc
                                            && jcp.dst_tag == dat_tag_nxc)))
            && IMPLICATION(mimo && !is_data_layout_nxc, jcp.ic % simd_w == 0)
            && IMPLICATION(!is_data_layout_nxc, jcp.oc % simd_w == 0)
            && jcp.kw < 14 && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */
            && jcp.kh <= jcp.ih /* [bwd_w:r2] */
            && jcp.kd <= jcp.f_pad + jcp.id && jcp.kd <= jcp.id
            && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */
            && jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0
            && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= diff_dst_d.padded_dims()[1];
    if (!args_ok) return status::unimplemented;

    /* the flat path uses all input channels as a single block */
    jcp.ic_block = flat ? jcp.ic : simd_w;
    jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);

    jcp.oc_block = simd_w;
    jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
    jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;

    return status::success;
}
1432 | |
1433 | void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad( |
1434 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { |
1435 | if (jcp.with_bias && (jcp.oc_without_padding % jcp.oc_block != 0)) { |
1436 | const size_t nelems_padded_bias |
1437 | = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block); |
1438 | scratchpad.book<float>(key_conv_padded_bias, nelems_padded_bias); |
1439 | } |
1440 | } |
1441 | |
1442 | inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() { |
1443 | Label kd_comeback_loop; |
1444 | mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0 |
1445 | L(kd_comeback_loop); |
1446 | { |
1447 | sub(aux_reg_input, get_input_offset(0, jcp.iw * jcp.ih)); |
1448 | sub(aux_reg_kernel, get_kernel_offset(jcp.kw * jcp.kh, 0)); |
1449 | dec(kj); |
1450 | cmp(kj, 0); |
1451 | jg(kd_comeback_loop, T_NEAR); |
1452 | } |
1453 | } |
1454 | |
1455 | inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() { |
1456 | mov(kj, reg_kh); |
1457 | Label kh_comeback_loop; |
1458 | L(kh_comeback_loop); |
1459 | { |
1460 | sub(reg_input, get_input_offset(0, jcp.iw)); |
1461 | sub(reg_kernel, get_kernel_offset(jcp.kw, 0)); |
1462 | dec(kj); |
1463 | cmp(kj, 0); |
1464 | jg(kh_comeback_loop, T_NEAR); |
1465 | } |
1466 | } |
1467 | |
1468 | inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step( |
1469 | int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset, |
1470 | int kernel_offset, int output_offset) { |
1471 | |
1472 | if (ic_block_step <= 0) return; |
1473 | |
1474 | const int kw = jcp.kw; |
1475 | const int oc_tail = jcp.oc_tail; |
1476 | |
1477 | if (oc_tail) { |
1478 | push(reg_kh); |
1479 | mov(reg_ci_flag, ptr[param1 + GET_OFF(flags)]); |
1480 | } |
1481 | |
1482 | auto load_compute_store = [=](bool is_tail) { |
1483 | for (int i_kw = 0; i_kw < kw; i_kw++) |
1484 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
1485 | size_t off = get_kernel_offset(i_kw, i_ic) + kernel_offset; |
1486 | if (is_tail) |
1487 | load_bytes(Ymm(i_kw * ic_block_step + i_ic), reg_kernel, |
1488 | off, oc_tail * sizeof(float)); |
1489 | else |
1490 | vmovups(Ymm(i_kw * ic_block_step + i_ic), |
1491 | yword[reg_kernel + off]); |
1492 | } |
1493 | |
1494 | for (int i_ur = 0; i_ur < ur_w; i_ur++) { |
1495 | if (is_tail) |
1496 | load_bytes(Ymm(kw * ic_block_step + 0), reg_output, |
1497 | get_output_offset(0, i_ur) + output_offset, |
1498 | oc_tail * sizeof(float)); |
1499 | else |
1500 | vmovups(Ymm(kw * ic_block_step + 0), |
1501 | yword[reg_output + get_output_offset(0, i_ur) |
1502 | + output_offset]); |
1503 | |
1504 | for (int i_kw = 0; i_kw < kw; i_kw++) { |
1505 | int i_iw = i_ur * jcp.stride_w + i_kw; |
1506 | if (i_iw - pad_l < 0 |
1507 | || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r) |
1508 | continue; |
1509 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
1510 | size_t i_off = get_input_offset(i_ic, i_iw - pad_l); |
1511 | vbroadcastss(Ymm(kw * ic_block_step + 1), |
1512 | make_safe_addr(reg_input, i_off, reg_long_offt)); |
1513 | vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic), |
1514 | Ymm(kw * ic_block_step + 0), |
1515 | Ymm(kw * ic_block_step + 1)); |
1516 | } |
1517 | } |
1518 | } |
1519 | |
1520 | for (int i_kw = 0; i_kw < kw; i_kw++) |
1521 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
1522 | size_t off = get_kernel_offset(i_kw, i_ic) + kernel_offset; |
1523 | if (is_tail) |
1524 | store_bytes(Ymm(i_kw * ic_block_step + i_ic), reg_kernel, |
1525 | off, oc_tail * sizeof(float)); |
1526 | |
1527 | else |
1528 | vmovups(yword[reg_kernel + off], |
1529 | Ymm(i_kw * ic_block_step + i_ic)); |
1530 | } |
1531 | }; |
1532 | |
1533 | if (oc_tail) { |
1534 | Label load_tail, load_done; |
1535 | test(reg_ci_flag, FLAG_OC_LAST); |
1536 | jne(load_tail, T_NEAR); |
1537 | |
1538 | load_compute_store(false); |
1539 | jmp(load_done, T_NEAR); |
1540 | |
1541 | L(load_tail); |
1542 | load_compute_store(true); |
1543 | |
1544 | L(load_done); |
1545 | } else { |
1546 | load_compute_store(false); |
1547 | } |
1548 | |
1549 | if (oc_tail) pop(reg_kh); |
1550 | } |
1551 | |
1552 | inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp() { |
1553 | int ic_block_step; |
1554 | if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) { |
1555 | ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block; |
1556 | } else if (one_of(jcp.src_tag, nwc, nhwc, ndhwc)) { |
1557 | ic_block_step = jcp.kw > 7 ? 1 : jcp.kw > 3 ? 2 : jcp.kw > 1 ? 4 : 8; |
1558 | if (jcp.ic_block % ic_block_step != 0) { |
1559 | ic_block_step = jcp.ic_block < ic_block_step ? jcp.ic_block : 1; |
1560 | } |
1561 | if (jcp.ic < ic_block_step) ic_block_step = jcp.ic; |
1562 | } else { |
1563 | ic_block_step = jcp.kw > 7 ? 1 : jcp.kw > 3 ? 2 : jcp.kw > 1 ? 4 : 8; |
1564 | } |
1565 | |
1566 | const int max_ur_w = jcp.ow > 56 ? 14 : 28; |
1567 | |
1568 | if (jcp.ow <= max_ur_w || one_of(jcp.src_tag, nwc, nhwc, ndhwc)) |
1569 | compute_oh_step_unroll_ow(ic_block_step, max_ur_w); |
1570 | else |
1571 | compute_oh_step_common(ic_block_step, max_ur_w); |
1572 | |
1573 | if (jcp.ndims == 5) { |
1574 | od_step_comeback_pointers(); |
1575 | mov(reg_input, aux_reg_input); |
1576 | mov(reg_kernel, aux_reg_kernel); |
1577 | } else { |
1578 | oh_step_comeback_pointers(); |
1579 | } |
1580 | } |
1581 | |
/* Emits one output-row step in which the whole output row (jcp.ow columns)
 * is handled by a single compute_ic_block_step() call per input-channel
 * step.
 *
 * Loop nest emitted (outer to inner): kd (5D only) -> kh -> ic block.
 * When jcp.ic_tail is set, a second kh loop body is emitted for a partial
 * input-channel block; it is selected at run time when the `channel`
 * argument is smaller than ic_block.
 *
 * ic_block_step  number of input channels processed per inner call;
 * max_ur_w       unused here. */
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow(
        int ic_block_step, int max_ur_w) {
    UNUSED(max_ur_w);

    const int r_pad = jcp.r_pad;
    const int ic_tail = jcp.ic_tail;
    const int ic_block = jcp.ic_block;
    // remainder of ic that does not fill a whole ic_block_step
    const int ic_block_step_tail = jcp.ic % ic_block_step;
    const size_t inp_icblk_stride = get_input_offset(ic_block_step, 0);

    if (ic_tail) {
        // reg_ih_count is saved here and restored at the end, presumably
        // because reg_channel shares its register -- TODO confirm.
        push(reg_ih_count);
        mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
    }

    Label kd_loop;
    if (jcp.ndims == 5) {
        // aux_* keep the per-kd base pointers; reg_* are reset from them at
        // the top of every kd iteration.
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
        mov(ki, jcp.kd);
        L(kd_loop);
        mov(reg_input, aux_reg_input);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    Label kh_loop, kh_loop_ic_tail, kh_loop_done;
    if (ic_tail) {
        // take the tail path when fewer than ic_block channels are given
        cmp(reg_channel, ic_block);
        jl(kh_loop_ic_tail, T_NEAR);
    }

    // Full input-channel block path.
    L(kh_loop);
    {
        xor_(b_ic, b_ic);
        Label ic_block_loop;
        L(ic_block_loop);
        {
            compute_ic_block_step(
                    jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0, 0, 0);
            safe_add(reg_input, inp_icblk_stride, reg_long_offt);
            add(reg_kernel, get_kernel_offset(0, ic_block_step));
            add(b_ic, ic_block_step);
            cmp(b_ic, ic_block);
            jl(ic_block_loop, T_NEAR);
        }
        // undo the ic_block channel advance and move to the next input row
        add(reg_input,
                get_input_offset(0, jcp.iw) - get_input_offset(ic_block, 0));
        add(reg_kernel, get_kernel_offset((jcp.kw - 1), 0));
        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }
    jmp(kh_loop_done, T_NEAR);

    // Partial input-channel block path (only reachable when ic_tail != 0,
    // but emitted unconditionally).
    L(kh_loop_ic_tail);
    {
        Label ic_block_loop, ic_block_loop_done;

        // skip the full-step loop when not even one ic_block_step fits
        cmp(reg_channel, ic_block_step);
        jl(ic_block_loop_done, T_NEAR);

        // b_ic counts the remaining tail channels down
        mov(b_ic, ic_tail);
        L(ic_block_loop);
        {
            compute_ic_block_step(
                    jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0, 0, 0);
            safe_add(reg_input, inp_icblk_stride, reg_long_offt);
            add(reg_kernel, get_kernel_offset(0, ic_block_step));
            sub(b_ic, ic_block_step);
            cmp(b_ic, ic_block_step);
            jge(ic_block_loop, T_NEAR);
        }

        L(ic_block_loop_done);

        // leftover channels that do not fill a whole ic_block_step
        if (ic_block_step_tail) {
            compute_ic_block_step(
                    jcp.ow, jcp.l_pad, r_pad, ic_block_step_tail, 0, 0, 0);
            add(reg_input, get_input_offset(ic_block_step_tail, 0));
            add(reg_kernel, get_kernel_offset(0, ic_block_step_tail));
        }

        // undo the ic_tail channel advance and move to the next input row;
        // the kernel pointer additionally skips the channels of the block
        // that were not processed
        add(reg_input,
                get_input_offset(0, jcp.iw) - get_input_offset(ic_tail, 0));
        add(reg_kernel,
                get_kernel_offset(0, ic_block - ic_tail)
                        + get_kernel_offset((jcp.kw - 1), 0));
        dec(kj);
        cmp(kj, 0);
        jg(kh_loop_ic_tail, T_NEAR);
    }

    L(kh_loop_done);

    if (jcp.ndims == 5) {
        // advance the per-kd base pointers by one input / filter plane
        add(aux_reg_input, get_input_offset(0, jcp.ih * jcp.iw));
        add(aux_reg_kernel, get_kernel_offset(jcp.kh * jcp.kw, 0));
        dec(ki);
        cmp(ki, 0);
        jg(kd_loop, T_NEAR);
    }
    if (ic_tail) pop(reg_ih_count);
}
1686 | |
/* Emits one output-row step in which the output row is split into blocks of
 * ur_w columns: an optional first block that handles the left padding, a
 * run-time loop over the middle blocks, and an optional tail block that
 * handles the right padding.
 *
 * Loop nest emitted (outer to inner): kd (5D only) -> kh -> ic block ->
 * ow block.
 *
 * ic_block_step  number of input channels processed per inner call;
 * max_ur_w       upper bound for the ow block width. */
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common(
        int ic_block_step, int max_ur_w) {
    // TODO: support channel tails for nxc format

    const int ic_block = jcp.ic_block;
    const int stride_w = jcp.stride_w;
    Label kd_loop;

    const int r_pad = jcp.r_pad;

    int ur_w = nstl::min(jcp.ow, max_ur_w);
    int ur_w_trips = jcp.ow / ur_w;
    int ur_w_tail = jcp.ow % ur_w;
    // Make sure the tail block is wider than the right padding: either move
    // one full block into the tail, or (with a single trip) halve ur_w and
    // give the other half to the tail.
    if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) {
        if (ur_w_trips > 1) {
            ur_w_tail += ur_w;
            ur_w_trips--;
        } else {
            ur_w_tail += (ur_w - ur_w / 2);
            ur_w = ur_w / 2;
        }
    }

    // Total pointer advance made by the non-tail blocks; computed before
    // ur_w_trips is decremented for the left-padding block below.
    int input_comeback
            = get_input_offset(0, ur_w_trips * ur_w * stride_w - jcp.l_pad);
    int output_comeback = get_output_offset(0, ur_w_trips * ur_w);

    if (jcp.ndims == 5) {
        // aux_* keep the per-kd base pointers; reg_* are reset from them at
        // the top of every kd iteration.
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
        mov(ki, jcp.kd);
        L(kd_loop);
        mov(reg_input, aux_reg_input);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    Label kh_loop;
    L(kh_loop);
    {
        xor_(b_ic, b_ic);
        Label ic_block_loop;
        L(ic_block_loop);
        {
            if (jcp.l_pad != 0) {
                // First block handles the left padding. The decrement
                // happens once, at code-generation time.
                ur_w_trips--;
                compute_ic_block_step(
                        ur_w, jcp.l_pad, 0, ic_block_step, 0, 0, 0);
                add(reg_input,
                        get_input_offset(0, ur_w * stride_w - jcp.l_pad));
                add(reg_output, get_output_offset(0, ur_w));
            }

            if (ur_w_trips > 0) {
                // Run-time loop over the blocks that touch no padding.
                xor_(reg_ur_w_trips, reg_ur_w_trips);
                Label ow_block_loop;
                L(ow_block_loop);
                {
                    compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
                    add(reg_output, get_output_offset(0, ur_w));
                    add(reg_input, get_input_offset(0, ur_w * stride_w));

                    inc(reg_ur_w_trips);
                    cmp(reg_ur_w_trips, ur_w_trips);
                    jl(ow_block_loop, T_NEAR);
                }
            }

            // Tail block handles the right padding.
            if (ur_w_tail > 0)
                compute_ic_block_step(
                        ur_w_tail, 0, r_pad, ic_block_step, 0, 0, 0);

            // Return both pointers to the start of the row ...
            sub(reg_input, input_comeback);
            sub(reg_output, output_comeback);

            // ... and move on to the next group of input channels.
            size_t inp_icblk_stride = get_input_offset(ic_block_step, 0);
            safe_add(reg_input, inp_icblk_stride, reg_long_offt);
            add(reg_kernel, get_kernel_offset(0, ic_block_step));

            add(b_ic, ic_block_step);
            cmp(b_ic, jcp.ic_block);
            jl(ic_block_loop, T_NEAR);
        }
        // undo the ic_block channel advance and move to the next input row
        add(reg_input,
                get_input_offset(0, jcp.iw) - get_input_offset(ic_block, 0));
        add(reg_kernel, get_kernel_offset((jcp.kw - 1), 0));
        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }

    if (jcp.ndims == 5) {
        // advance the per-kd base pointers by one input / filter plane
        add(aux_reg_input, get_input_offset(0, jcp.ih * jcp.iw));
        add(aux_reg_kernel, get_kernel_offset(jcp.kh * jcp.kw, 0));
        dec(ki);
        cmp(ki, 0);
        jg(kd_loop, T_NEAR);
    }
}
1786 | |
1787 | inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common() { |
1788 | const int t_pad = jcp.t_pad; |
1789 | const int stride_h = jcp.stride_h; |
1790 | int b_pad = jcp.b_pad; |
1791 | |
1792 | Label oh_tpad_loop, oh_loop, oh_loop_end; |
1793 | |
1794 | mov(reg_kh, jcp.kh); |
1795 | xor_(reg_ih_count, reg_ih_count); |
1796 | xor_(reg_oj, reg_oj); |
1797 | if (t_pad > 0) { |
1798 | assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */ |
1799 | mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih); |
1800 | add(reg_kernel, get_kernel_offset(t_pad * jcp.kw, 0)); |
1801 | |
1802 | L(oh_tpad_loop); |
1803 | { |
1804 | compute_oh_step_disp(); |
1805 | add(reg_output, get_output_offset(0, jcp.ow)); |
1806 | sub(reg_kernel, get_kernel_offset(stride_h * jcp.kw, 0)); |
1807 | |
1808 | inc(reg_oj); |
1809 | add(reg_ih_count, stride_h); |
1810 | add(reg_kh, stride_h); |
1811 | |
1812 | /* the overlap between input and kernel may not reach kernel size. |
1813 | * so far we do not support that (until we put constant here) */ |
1814 | const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */ |
1815 | cmp(reg_kh, final_inp_ker_overlap); |
1816 | jl(oh_tpad_loop, T_NEAR); |
1817 | } |
1818 | |
1819 | if (t_pad % stride_h != 0) { |
1820 | int inp_corr = stride_h - t_pad % stride_h; |
1821 | add(reg_kernel, get_kernel_offset(inp_corr * jcp.kw, 0)); |
1822 | add(reg_input, get_input_offset(0, inp_corr * jcp.iw)); |
1823 | } |
1824 | } |
1825 | cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); |
1826 | jge(oh_loop_end, T_NEAR); |
1827 | cmp(reg_oj, jcp.oh); |
1828 | jge(oh_loop, T_NEAR); |
1829 | |
1830 | mov(reg_kh, jcp.kh); |
1831 | L(oh_loop); |
1832 | { |
1833 | compute_oh_step_disp(); |
1834 | add(reg_input, get_input_offset(0, stride_h * jcp.iw)); |
1835 | add(reg_output, get_output_offset(0, jcp.ow)); |
1836 | |
1837 | inc(reg_oj); |
1838 | add(reg_ih_count, stride_h); |
1839 | |
1840 | cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); |
1841 | jge(oh_loop_end, T_NEAR); |
1842 | |
1843 | cmp(reg_oj, jcp.oh); |
1844 | jl(oh_loop, T_NEAR); |
1845 | } |
1846 | L(oh_loop_end); |
1847 | if (b_pad > 0) { |
1848 | Label oh_bpad_loop, oh_bpad_loop_end; |
1849 | cmp(reg_oj, jcp.oh); |
1850 | jge(oh_bpad_loop_end, T_NEAR); |
1851 | |
1852 | mov(reg_kh, jcp.ih + t_pad); |
1853 | sub(reg_kh, reg_ih_count); |
1854 | L(oh_bpad_loop); |
1855 | { |
1856 | compute_oh_step_disp(); |
1857 | add(reg_input, get_input_offset(0, stride_h * jcp.iw)); |
1858 | add(reg_output, get_output_offset(0, jcp.ow)); |
1859 | |
1860 | sub(reg_kh, stride_h); |
1861 | cmp(reg_kh, 0); |
1862 | jle(oh_bpad_loop_end, T_NEAR); |
1863 | |
1864 | inc(reg_oj); |
1865 | cmp(reg_oj, jcp.oh); |
1866 | jl(oh_bpad_loop, T_NEAR); |
1867 | } |
1868 | L(oh_bpad_loop_end); |
1869 | } |
1870 | } |
1871 | |
1872 | } // namespace x64 |
1873 | } // namespace cpu |
1874 | } // namespace impl |
1875 | } // namespace dnnl |
1876 | |
1877 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
1878 | |