1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * Copyright 2018 YANDEX LLC |
4 | * |
5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
6 | * you may not use this file except in compliance with the License. |
7 | * You may obtain a copy of the License at |
8 | * |
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
10 | * |
11 | * Unless required by applicable law or agreed to in writing, software |
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | * See the License for the specific language governing permissions and |
15 | * limitations under the License. |
16 | *******************************************************************************/ |
17 | |
18 | #include <assert.h> |
19 | #include <limits> |
20 | |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/memory.hpp" |
23 | #include "common/memory_tracking.hpp" |
24 | #include "common/nstl.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | #include "cpu/x64/injectors/injector_utils.hpp" |
29 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
30 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
31 | #include "cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp" |
32 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
33 | |
34 | #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | namespace cpu { |
39 | namespace x64 { |
40 | |
41 | using namespace dnnl::impl::prop_kind; |
42 | using namespace dnnl::impl::format_tag; |
43 | using namespace dnnl::impl::utils; |
44 | |
45 | using namespace Xbyak; |
46 | |
// Constructor: stores the convolution configuration and primitive attributes
// and, when eltwise/binary post-ops are present, instantiates the post-ops
// injector that apply_postops() uses during code generation.
jit_avx2_1x1_conv_kernel_f32::jit_avx2_1x1_conv_kernel_f32(
        const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, avx2)
    , jcp(ajcp)
    , attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        // ymm15 is handed to the injector as a scratch register; the kernel
        // itself keeps accumulator/load registers below this index.
        static constexpr size_t helper_vmm_idx = 15;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        // Number of valid elements in the last (partial) oc vector, if any.
        const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;

        // r13/r14/r15 are loaned to the injector for addressing rhs tensors.
        rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r13, r14,
                r15, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size,
                use_exact_tail_scalar_bcast};
        static_params_t static_params {this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx2>>(
                this, jcp.post_ops, static_params);
    }
}
73 | |
// Emits the loop over the broadcast ("bcast") dimension: spatial points for
// fwd/bwd_d, input channels for bwd_w. Each main-loop iteration covers
// jcp.bcast_block elements in substeps of jcp.ur; a tail path handles the
// remaining jcp.ur_tail elements.
void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) {
    // Reload the bcast base pointer (spilled to the stack in generate()) and
    // reset the per-loop output pointer and the iteration counter.
    mov(aux1_reg_bcast_data, ptr[rsp + reg_bcast_data_off]);
    mov(aux_reg_output_data, reg_output_data);
    mov(bcast_loop_iter, reg_bcast_loop_work);

    Label bcast_loop, bcast_loop_tail, large_tail;

    // If less than a full block remains, go straight to the tail handling.
    cmp(bcast_loop_iter, jcp.bcast_block);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop);
    {
        assert(jcp.bcast_block % jcp.ur == 0);
        const int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            // The last substep doubles as the "large tail" entry point used
            // below when ur <= remaining work < bcast_block.
            if (i == num_substeps - 1) L(large_tail);
            generate_reduce_loop(load_loop_blk, jcp.ur);
            if (i < num_substeps - 1) {
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_substep);
            } else {
                // Last substep: advance by a whole block step, compensating
                // for the substep increments already applied above.
                add(aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data,
                        jcp.bcast_loop_output_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_output_substep);
            }
            sub(bcast_loop_iter, jcp.ur);
        }
        cmp(bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        // A tail of at least ur elements reuses the main loop's last substep.
        if (jcp.ur_tail >= jcp.ur) {
            cmp(bcast_loop_iter, jcp.ur);
            jge(large_tail, T_NEAR);
        }
        // Handle the final sub-ur remainder, if any.
        if (jcp.ur_tail % jcp.ur > 0) {
            cmp(bcast_loop_iter, 0);
            jle(bcast_loop_tail_out, T_NEAR);
            generate_reduce_loop(load_loop_blk, jcp.ur_tail % jcp.ur);
            L(bcast_loop_tail_out);
        }
    }
}
126 | |
// Index of the accumulator vector register holding output row j of load
// block i. Accumulators are laid out row-major: all load blocks belonging to
// one ur-row occupy consecutive register indices.
static int vreg_accum_idx(const int load_loop_blk, int i, int j) {
    return i + j * load_loop_blk;
}
130 | |
131 | static Ymm vreg_accum(const int load_loop_blk, int i, int j) { |
132 | return Ymm(vreg_accum_idx(load_loop_blk, i, j)); |
133 | } |
134 | |
// Invokes f(mask_flag, i, j) for every (load block, ur row) pair, in
// (i-outer, j-inner) order. mask_flag is true only for the last load block
// when a load-dimension tail exists, signalling masked/tail handling.
template <typename F>
void iterate(const int load_loop_blk, const int ur, const int load_dim_tail,
        const F &f) {
    const int last_blk = load_loop_blk - 1;
    for (int blk = 0; blk <= last_blk; ++blk) {
        const bool tail_mask = load_dim_tail > 0 && blk == last_blk;
        for (int row = 0; row < ur; ++row)
            f(tail_mask, blk, row);
    }
}
// Convenience overload: iterate with no load-dimension tail, so mask_flag is
// always false.
template <typename F>
void iterate(const int load_loop_blk, const int ur, const F &f) {
    iterate(load_loop_blk, ur, 0, f);
}
148 | |
149 | void jit_avx2_1x1_conv_kernel_f32::apply_postops( |
150 | const int load_loop_blk, const int ur, const int load_dim_tail) { |
151 | if (jcp.with_eltwise || jcp.with_binary) { |
152 | assert(ur * load_loop_blk < 14); |
153 | |
154 | Label store_nopost_ops; |
155 | test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); |
156 | jz(store_nopost_ops, T_NEAR); |
157 | |
158 | injector_utils::vmm_index_set_t vmm_idxs; |
159 | if (jcp.with_binary) { |
160 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, |
161 | rhs_arg_params_tail; |
162 | |
163 | iterate(load_loop_blk, ur, load_dim_tail, |
164 | [&](const bool mask_flag, const int i, const int j) { |
165 | const size_t aux_output_offset |
166 | = (i * get_output_i_offset(jcp) |
167 | + j * get_output_j_offset(jcp)) |
168 | * sizeof(float); |
169 | const auto vmm_idx |
170 | = vreg_accum_idx(load_loop_blk, i, j); |
171 | vmm_idxs.emplace(vmm_idx); |
172 | |
173 | rhs_arg_params_tail.vmm_idx_to_out_reg.emplace( |
174 | vmm_idx, aux_reg_output_data); |
175 | rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace( |
176 | vmm_idx, aux_output_offset); |
177 | if (mask_flag) |
178 | rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx); |
179 | }); |
180 | rhs_arg_params = rhs_arg_params_tail; |
181 | rhs_arg_params.vmm_tail_idx_.clear(); |
182 | |
183 | const injector_utils::register_preserve_guard_t register_guard( |
184 | this, {abi_param1}); |
185 | const size_t reg_guard_stack_occupied |
186 | = register_guard.stack_space_occupied(); |
187 | mov(abi_param1, |
188 | ptr[rsp + reg_abi_param1_backup |
189 | + reg_guard_stack_occupied]); |
190 | |
191 | Label postops_done; |
192 | if (load_dim_tail) { |
193 | Label postops_no_tail; |
194 | cmp(reg_load_loop_work, |
195 | load_loop_blk * jcp.load_loop_iter_step); |
196 | jge(postops_no_tail, T_NEAR); |
197 | postops_injector_->compute_vector_range( |
198 | vmm_idxs, rhs_arg_params_tail); |
199 | jmp(postops_done, T_NEAR); |
200 | L(postops_no_tail); |
201 | } |
202 | postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); |
203 | L(postops_done); |
204 | } else { |
205 | iterate(load_loop_blk, ur, load_dim_tail, |
206 | [&](const bool, const int i, const int j) { |
207 | vmm_idxs.emplace(vreg_accum_idx(load_loop_blk, i, j)); |
208 | }); |
209 | postops_injector_->compute_vector_range(vmm_idxs); |
210 | } |
211 | L(store_nopost_ops); |
212 | } |
213 | }; |
214 | |
// Emits one full pass over the reduce dimension for a (load_loop_blk x ur)
// output tile: accumulator init (bias or zero), the unrolled FMA reduce
// loop, and the store phase (optional sum, post-ops, load-dim tail masking).
void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
        int load_loop_blk, int ur) {
    // For fwd with binary post-ops the tail is measured against the unpadded
    // oc; otherwise against the generic load dimension.
    const int load_dim_tail
            = ((jcp.with_binary
                       && one_of(jcp.prop_kind, forward_training,
                               forward_inference))
                              ? jcp.oc_without_padding
                              : jcp.load_dim)
            % jcp.load_block;
    const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block;

    // Registers [ur*load_loop_blk ...] hold loaded weights/data; lower
    // indices hold accumulators (see vreg_accum_idx).
    auto vreg_load = [=](int i) { return Ymm(ur * load_loop_blk + i); };

    // Bias element address for load block i.
    auto bias_ptr = [=](int i) {
        return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i];
    };

    // Broadcast element (u, j); uses a safe form for offsets > 32 bits.
    auto bcast_ptr = [=](int u, int j) {
        assert(j < jcp.ur);
        assert(u <= jcp.reduce_loop_unroll);
        const size_t offset = get_bcast_offset(jcp, u, j);
        return make_safe_addr(aux_reg_bcast_data, offset, reg_long_offt);
    };

    // Byte offset of load element (u, i) for the bwd_w layout; u is split
    // into the position inside the unroll (u0) and the unroll count (u1).
    auto get_load_offset_bwd_w = [=](int u, int i) {
        size_t u0 = u % jcp.reduce_loop_unroll;
        size_t u1 = u / jcp.reduce_loop_unroll;
        return u1 * jcp.reduce_loop_load_step
                + sizeof(float) * get_load_bwd_w_offset(jcp, i, u0);
    };

    // Load element (u, i); the element layout depends on propagation kind.
    auto load_ptr = [=](int u, int i) {
        size_t offt;
        size_t u0 = u % jcp.reduce_loop_unroll;
        size_t u1 = u / jcp.reduce_loop_unroll;
        switch (jcp.prop_kind) {
            case backward_data:
                offt = (i * jcp.oc_block + u0) * jcp.ic_block;
                break;
            case backward_weights:
                offt = get_load_bwd_w_offset(jcp, i, u0);
                break;
            default:
                offt = (i * rnd_up(jcp.ic, jcp.ic_block) + u0) * jcp.oc_block;
        }
        return ptr[aux_reg_load_data + u1 * jcp.reduce_loop_load_step
                + sizeof(float) * offt];
    };

    // Output byte offset of tile element (i, j).
    auto get_output_offset = [=](int i, int j) {
        switch (jcp.prop_kind) {
            case backward_weights: return sizeof(float) * jcp.oc_block * j;
            default:
                return (i * get_output_i_offset(jcp)
                               + j * get_output_j_offset(jcp))
                        * sizeof(float);
        }
    };

    // Output element address; bwd_w scales a runtime stride by i, other
    // kinds use a compile-time offset (safe form for 64-bit offsets).
    auto output_ptr = [=](int i, int j) {
        switch (jcp.prop_kind) {
            case backward_weights:
                return ptr[aux_reg_output_data
                        + (i ? reg_output_stride * i
                             : 0) // TODO: Xbyak should allow 0 scale
                        + sizeof(float) * jcp.oc_block * j];
            default:
                const size_t off = get_output_offset(i, j);
                return make_safe_addr(aux_reg_output_data, off, reg_long_offt);
        }
    };

    // Emits accumulator initialization: bias on the first reduce pass (fwd
    // with bias), zeros otherwise; then preloads the first load vectors and
    // the first broadcast value.
    auto init = [=]() {
        Label init_done, init_zero;

        if (jcp.with_bias
                && one_of(jcp.prop_kind, forward_training, forward_inference)) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jz(init_zero, T_NEAR);

            for (int i = 0; i < load_loop_blk; i++) {
                for (int j = 0; j < ur; ++j) {
                    if (load_dim_tail > 0 && i == load_loop_blk - 1) {
                        // Runtime check: partial last block loads only the
                        // valid tail bytes of the bias.
                        Label load_bias_tail, load_bias_done;
                        cmp(reg_load_loop_work,
                                load_loop_blk * jcp.load_loop_iter_step);
                        jl(load_bias_tail);
                        vmovups(vreg_accum(load_loop_blk, i, j), bias_ptr(i));
                        jmp(load_bias_done);

                        L(load_bias_tail);
                        load_bytes(vreg_accum(load_loop_blk, i, j),
                                reg_bias_data, i * jcp.oc_block * sizeof(float),
                                load_dim_tail * sizeof(float));
                        L(load_bias_done);
                    } else {
                        vmovups(vreg_accum(load_loop_blk, i, j), bias_ptr(i));
                    }
                }
            }
            jmp(init_done);
        }

        L(init_zero);
        for (int i = 0; i < load_loop_blk; ++i)
            for (int j = 0; j < ur; ++j) {
                auto r = vreg_accum(load_loop_blk, i, j);
                vxorps(r, r, r);
            }

        L(init_done);
        // Preload the first load vector per block; for bwd_w with a tail the
        // register is zeroed first so the unread lanes stay zero.
        for (int i = 0; i < load_loop_blk; ++i) {
            if (jcp.prop_kind == backward_weights && load_dim_tail > 0
                    && i == load_loop_blk - 1) {
                Label load_init_tail, load_init_done;
                cmp(reg_load_loop_work,
                        load_loop_blk * jcp.load_loop_iter_step);
                jl(load_init_tail);
                vmovups(vreg_load(i), load_ptr(0, i));
                jmp(load_init_done);

                L(load_init_tail);
                vxorps(vreg_load(i), vreg_load(i), vreg_load(i));
                load_bytes(vreg_load(i), aux_reg_load_data,
                        get_load_offset_bwd_w(0, i),
                        load_dim_tail * sizeof(float));
                L(load_init_done);
            } else {
                vmovups(vreg_load(i), load_ptr(0, i));
            }
        }
        vbroadcastss(vreg_bcast, bcast_ptr(0, 0));
    };

    // Emits the store phase: optional accumulation of previous output (sum /
    // non-first reduce pass), post-ops, then stores with tail masking.
    auto store = [=]() {
        Label store_noadd;

        // On the first reduce pass (and without sum post-op) nothing is
        // added to the accumulators.
        if (!jcp.with_sum) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jnz(store_noadd, T_NEAR);
        }

        for (int j = 0; j < ur; ++j)
            for (int i = 0; i < load_loop_blk; ++i) {
                auto r = vreg_accum(load_loop_blk, i, j);
                if (jcp.with_sum && load_dim_tail > 0
                        && i == load_loop_blk - 1) {
                    // Tail path: read only the valid bytes via vtmp before
                    // adding, to avoid touching memory past the tensor.
                    Label sum_tail, sum_done;
                    cmp(reg_load_loop_work,
                            load_loop_blk * jcp.load_loop_iter_step);
                    jl(sum_tail);
                    vaddps(r, r, output_ptr(i, j));
                    jmp(sum_done);

                    L(sum_tail);
                    load_bytes(vtmp, aux_reg_output_data,
                            get_output_offset(i, j),
                            load_dim_tail * sizeof(float));
                    vaddps(r, r, vtmp);
                    L(sum_done);
                } else {
                    vaddps(r, r, output_ptr(i, j));
                }
            }

        L(store_noadd);

        apply_postops(load_loop_blk, ur, load_dim_tail);

        // The bwd_w tail path below clobbers registers that alias
        // aux_reg_bcast_data; preserve it across the stores.
        if (jcp.prop_kind == backward_weights && load_dim_tail > 0) {
            push(aux_reg_bcast_data);
        }

        const auto is_padding = jcp.oc_without_padding != jcp.oc;
        if (is_padding) uni_vxorps(vtmp, vtmp, vtmp);
        for (int j = 0; j < ur; ++j)
            for (int i = 0; i < load_loop_blk; ++i) {
                if (load_dim_tail > 0 && i == load_loop_blk - 1) {
                    Label store_tail, store_done;
                    cmp(reg_load_loop_work,
                            load_loop_blk * jcp.load_loop_iter_step);
                    jl(store_tail);
                    vmovups(output_ptr(i, j), vreg_accum(load_loop_blk, i, j));
                    jmp(store_done);

                    L(store_tail);
                    if (jcp.prop_kind == backward_weights) {
                        // Materialize aux_reg_output_data + i*stride in
                        // reg_tmp via a widening multiply (reg_tmp is rdx,
                        // reg_tmp_output_stride is rax — imul operands).
                        if (i) {
                            xor_(reg_tmp, reg_tmp); // rdx
                            mov(reg_tmp_output_stride,
                                    reg_output_stride); // rax
                            mov(reg_output_stride_scale, i);
                            imul(reg_output_stride_scale);
                        } else {
                            xor_(reg_tmp_output_stride, reg_tmp_output_stride);
                        }
                        lea(reg_tmp,
                                ptr[aux_reg_output_data
                                        + reg_tmp_output_stride]);
                        vmovups(output_ptr(i, j),
                                vreg_accum(load_loop_blk, i, j));
                    } else {
                        // With channel padding and binary post-ops the padded
                        // region is zero-filled first so the full vector
                        // store below the partial store stays consistent.
                        if (is_padding && jcp.with_binary) {
                            vmovups(ptr[aux_reg_output_data
                                            + get_output_offset(i, j)],
                                    vtmp);
                        }
                        store_bytes(vreg_accum(load_loop_blk, i, j),
                                aux_reg_output_data, get_output_offset(i, j),
                                load_dim_tail * sizeof(float));
                    }
                    L(store_done);
                } else {
                    vmovups(output_ptr(i, j), vreg_accum(load_loop_blk, i, j));
                }
            }

        if (jcp.prop_kind == backward_weights && load_dim_tail > 0) {
            pop(aux_reg_bcast_data);
        }
    };

    // Emits one unrolled FMA block. Loads for step u+1 are issued while
    // computing step u (software pipelining); broadcast values likewise.
    auto fma_block = [=](bool last_block) {
        const bool is_tail = reduce_dim_tail && last_block;
        const int u_end = is_tail ? reduce_dim_tail : jcp.reduce_loop_unroll;
        for (int u = 0; u < u_end; ++u) {
            for (int j = 0; j < ur; ++j) {
                for (int i = 0; i < load_loop_blk; ++i) {
                    if (jcp.isa == avx2)
                        vfmadd231ps(vreg_accum(load_loop_blk, i, j),
                                vreg_load(i), vreg_bcast);
                    else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
                        // AVX (no FMA): emulate with mul + add through vtmp.
                        vmulps(vtmp, vreg_bcast, vreg_load(i));
                        vaddps(vreg_accum(load_loop_blk, i, j),
                                vreg_accum(load_loop_blk, i, j), vtmp);
                    }
                    // After the last j of step u, prefetch step u+1's load
                    // vector (unless this is the very last step overall).
                    if (j == ur - 1 && !(last_block && u == u_end - 1)) {
                        if (jcp.prop_kind == backward_weights
                                && load_dim_tail > 0
                                && i == load_loop_blk - 1) {
                            Label fma_load_tail, fma_load_done;
                            cmp(reg_load_loop_work,
                                    load_loop_blk * jcp.load_loop_iter_step);
                            jl(fma_load_tail);
                            vmovups(vreg_load(i), load_ptr(u + 1, i));
                            jmp(fma_load_done);

                            L(fma_load_tail);
                            vxorps(vreg_load(i), vreg_load(i), vreg_load(i));
                            load_bytes(vreg_load(i), aux_reg_load_data,
                                    get_load_offset_bwd_w(u + 1, i),
                                    load_dim_tail * sizeof(float));
                            L(fma_load_done);
                        } else {
                            vmovups(vreg_load(i), load_ptr(u + 1, i));
                        }
                    }
                }
                if (j < ur - 1) vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1));
            }
            if (!last_block || u < u_end - 1)
                vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0));
        }
    };

    Label reduce_loop, reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);
    mov(aux_reg_bcast_data, aux1_reg_bcast_data);

    init();

    // Main reduce loop runs while more than one unroll's worth of work
    // remains; the final (possibly partial) unroll goes through the tail.
    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop);
    {
        fma_block(false);
        safe_add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step, reg_long_offt);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    fma_block(true);

    store();
}
505 | |
// Emits the diff-bias reduction for bwd_w: sums diff_dst over the reduce
// dimension into load_loop_blk ymm accumulators and writes them back,
// advancing the stack-stored diff-bias pointer. No-op unless bias is enabled
// and the propagation kind is backward_weights.
void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) {
    if (!jcp.with_bias || jcp.prop_kind != backward_weights) return;

    Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
    Label diff_bias_load;

    // Diff-bias element address for load block i.
    auto diff_bias_ptr = [=](int i) {
        return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)];
    };

    // diff_dst element (u, i) in the blocked layout.
    auto load_ptr = [=](int u, int i) {
        return ptr[aux_reg_load_data
                + (i * jcp.os + u) * jcp.oc_block * sizeof(float)];
    };

    // Accumulator register for load block i (ymm0..).
    auto diff_bias_reg = [=](int i) { return Ymm(i); };

    // A null pointer on the stack means diff-bias was already fully handled
    // for this chunk; skip everything.
    mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
    cmp(reg_diff_bias_data, 0);
    je(diff_bias_loop_out, T_NEAR);

    // First reduce pass starts from zero; later passes resume from memory.
    test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
    jz(diff_bias_load, T_NEAR);

    for (int i = 0; i < load_loop_blk; ++i) {
        auto r = diff_bias_reg(i);
        vxorps(r, r, r);
    }
    jmp(diff_bias_init_out, T_NEAR);

    L(diff_bias_load);
    for (int i = 0; i < load_loop_blk; ++i)
        vmovups(diff_bias_reg(i), diff_bias_ptr(i));

    L(diff_bias_init_out);
    mov(aux_reg_load_data, reg_load_data);
    mov(reduce_loop_iter, reg_reduce_loop_work);
    L(diff_bias_loop);
    {
        for (int u = 0; u < jcp.reduce_loop_unroll; ++u)
            for (int i = 0; i < load_loop_blk; ++i)
                vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i));
        // The loop relies on the reduce dim being a multiple of the unroll,
        // so the counter hits exactly zero.
        assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jnz(diff_bias_loop, T_NEAR);
    }

    // Write back the partial sums and advance the stored pointer to the next
    // load block range.
    for (int i = 0; i < load_loop_blk; i++)
        vmovups(diff_bias_ptr(i), diff_bias_reg(i));
    add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
    mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);

    L(diff_bias_loop_out);
}
561 | |
// Top-level code generation: sets up the stack frame, loads kernel arguments
// from the jit_1x1_conv_call_s parameter block, then emits the load-dim loop
// dispatched at runtime into 3-, 2- and 1-block bodies.
void jit_avx2_1x1_conv_kernel_f32::generate() {
    preamble();

    sub(rsp, stack_space_needed);

    if (jcp.with_binary) {
        // Zero the binary post-op channel accumulator and back up abi_param1
        // (restored inside apply_postops for the injector).
        const auto zeroed_reg = r15;
        xor_(zeroed_reg, zeroed_reg);
        mov(ptr[rsp + reg_binary_post_op_acc_off], zeroed_reg);
        mov(ptr[rsp + reg_abi_param1_backup], abi_param1);
    }

    // Load argument pointers; the bcast pointer is also spilled to the stack
    // because generate_bcast_loop() re-reads it per load-loop iteration.
    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
    if (jcp.with_bias) {
        if (jcp.prop_kind == backward_weights) {
            mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
            mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
        } else
            mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
    }

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
    if (jcp.prop_kind == backward_weights)
        mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);

    // Emits one load-loop iteration: the bcast loop plus pointer/counter
    // advancement that depends on the propagation kind.
    auto generate_load_loop_body = [=](int load_loop_blk) {
        generate_bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        switch (jcp.prop_kind) {
            case forward_training:
            case forward_inference:
                add(reg_bias_data,
                        load_loop_blk * jcp.oc_block * sizeof(float));
                safe_add(reg_output_data,
                        get_load_loop_output_fwd_offset(jcp, load_loop_blk),
                        reg_long_offt);
                if (jcp.with_binary) {
                    // Advance the per-channel accumulator used by the binary
                    // injector (aux_reg_load_data is free as scratch here).
                    mov(aux_reg_load_data,
                            ptr[rsp + reg_binary_post_op_acc_off]);
                    add(aux_reg_load_data, jcp.load_block * load_loop_blk);
                    mov(ptr[rsp + reg_binary_post_op_acc_off],
                            aux_reg_load_data);
                }
                break;
            case backward_data:
                safe_add(reg_output_data,
                        get_load_loop_output_bwd_d_offset(jcp, load_loop_blk),
                        reg_long_offt);
                break;
            case backward_weights:
                for (int i = 0; i < load_loop_blk; i++)
                    add(reg_output_data, reg_output_stride);
                break;
            default: assert(!"invalid prop_kind" );
        }
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    };

    // Runtime dispatch over the remaining load dimension (in oc/ic blocks of
    // 8): prefer 3-block bodies, then 2, then 1.
    Label load_loop_blk_8;
    Label load_loop_blk_16;
    Label load_loop_blk_24;
    Label load_loop_blk_end;

    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    // Exactly 32 is split 2+2 rather than 3+1 for better balance.
    cmp(reg_load_loop_work, 32);
    je(load_loop_blk_16, T_NEAR);

    cmp(reg_load_loop_work, 16);
    jle(load_loop_blk_16, T_NEAR);

    L(load_loop_blk_24);
    {
        generate_diff_bias_loop(3);
        generate_load_loop_body(3);
        cmp(reg_load_loop_work, 32);
        je(load_loop_blk_16);
        cmp(reg_load_loop_work, 24);
        jge(load_loop_blk_24);
    }

    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    L(load_loop_blk_16);
    {
        generate_diff_bias_loop(2);
        generate_load_loop_body(2);
        cmp(reg_load_loop_work, 16);
        jge(load_loop_blk_16);
    }

    L(load_loop_blk_8);
    {
        cmp(reg_load_loop_work, 0);
        jle(load_loop_blk_end, T_NEAR);
        generate_diff_bias_loop(1);
        generate_load_loop_body(1);
    }

    L(load_loop_blk_end);

    add(rsp, stack_space_needed);

    postamble();

    // Eltwise lookup tables must be emitted after the code body.
    if (jcp.with_eltwise) postops_injector_->prepare_table();
}
677 | |
678 | status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, |
679 | const convolution_desc_t &cd, const memory_desc_wrapper &src_d, |
680 | const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, |
681 | const primitive_attr_t &attr) { |
682 | if (!mayiuse(avx)) return status::unimplemented; |
683 | jcp.isa = mayiuse(avx2) ? avx2 : avx; |
684 | |
685 | // TODO (Roma): this code is duplicated from the generic kernel; maybe the |
686 | // configuration struct could do some stuff below |
687 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
688 | const int ndims = src_d.ndims(); |
689 | |
690 | jcp.nthr = dnnl_get_max_threads(); |
691 | |
692 | jcp.prop_kind = cd.prop_kind; |
693 | |
694 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
695 | jcp.mb = src_d.dims()[0]; |
696 | |
697 | jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; |
698 | jcp.oc = jcp.oc_without_padding; |
699 | jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; |
700 | jcp.ic = jcp.ic_without_padding; |
701 | |
702 | jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; |
703 | jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2]; |
704 | jcp.iw = src_d.dims()[ndims - 1]; |
705 | jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; |
706 | jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2]; |
707 | jcp.ow = dst_d.dims()[ndims - 1]; |
708 | |
709 | jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; |
710 | jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
711 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
712 | |
713 | jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; |
714 | jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4]; |
715 | jcp.l_pad = cd.padding[0][ndims - 3]; |
716 | |
717 | jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; |
718 | jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4]; |
719 | jcp.stride_w = cd.strides[ndims - 3]; |
720 | |
721 | jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind, |
722 | format_kind::undef, cd.diff_bias_desc.format_kind) |
723 | != format_kind::undef; |
724 | |
725 | jcp.os = static_cast<dim_t>(jcp.od) * jcp.oh * jcp.ow; |
726 | jcp.is = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw; |
727 | |
728 | jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type); |
729 | jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type); |
730 | |
731 | const auto &post_ops = attr.post_ops_; |
732 | const int dw_conv_ind = post_ops.find(primitive_kind::convolution); |
733 | jcp.with_dw_conv = dw_conv_ind != -1; |
734 | |
735 | // Using dw_conv_ind as upper-bound below, as post-ops after it will be |
736 | // handled in depthwise convolution. |
737 | const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind); |
738 | jcp.with_sum = sum_ind != -1; |
739 | const int eltwise_ind |
740 | = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind); |
741 | jcp.with_eltwise = eltwise_ind != -1; |
742 | const int binary_ind |
743 | = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); |
744 | jcp.with_binary = binary_ind != -1; |
745 | |
746 | if (dw_conv_ind >= 0) { |
747 | // dw_conv and post_ops after it are handled externally, so skip them |
748 | jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(), |
749 | post_ops.entry_.cbegin() + dw_conv_ind); |
750 | } else { |
751 | jcp.post_ops = post_ops; |
752 | } |
753 | |
754 | const auto dat_tag_nxc = utils::pick(ndims - 3, nwc, nhwc, ndhwc); |
755 | const auto dat_tag_nCx8c = utils::pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); |
756 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); |
757 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); |
758 | const bool is_data_layout_nxc |
759 | = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); |
760 | const auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c; |
761 | |
762 | const int is_bwd_d = jcp.prop_kind == backward_data; |
763 | format_tag_t wei_tag = with_groups |
764 | ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, |
765 | gOIhw8i8o, gOIdhw8o8i, gOIhw8i8o, gOIdhw8o8i) |
766 | : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, |
767 | OIhw8o8i, OIdhw8i8o, OIdhw8o8i); |
768 | jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); |
769 | |
770 | const int simd_w = 8; |
771 | |
772 | bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1; |
773 | if (ok_to_pad_channels) { |
774 | jcp.oc = rnd_up(jcp.oc, simd_w); |
775 | jcp.ic = rnd_up(jcp.ic, simd_w); |
776 | } |
777 | |
778 | if (jcp.with_eltwise || jcp.with_binary) |
779 | if (jcp.isa < avx2) return status::unimplemented; |
780 | |
781 | using namespace injector; |
782 | static constexpr bool sum_at_pos_0_only = true; |
783 | static constexpr bool sum_requires_scale_one = true; |
784 | static constexpr bool sum_requires_zp_zero = true; |
785 | const bool post_ops_ok_ = post_ops_ok({avx2, {eltwise, binary, sum}, |
786 | jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, |
787 | sum_requires_zp_zero}); |
788 | if (!post_ops_ok_) return status::unimplemented; |
789 | |
790 | bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == dat_tag |
791 | && jcp.wei_tag == wei_tag && jcp.dst_tag == dat_tag; |
792 | if (!args_ok) return status::unimplemented; |
793 | |
794 | args_ok = true && jcp.id == jcp.od && jcp.ih == jcp.oh && jcp.iw == jcp.ow |
795 | && IMPLICATION(!is_data_layout_nxc, |
796 | jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0) |
797 | && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0 |
798 | && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1 |
799 | && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1; |
800 | if (!args_ok) return status::unimplemented; |
801 | |
802 | // TODO: remove this restriction |
803 | // optimized 1x1 bwd_w does not support Intel AVX |
804 | if (jcp.prop_kind == backward_weights && jcp.isa != avx2) |
805 | return status::unimplemented; |
806 | |
807 | jcp.ic_block = jcp.oc_block = simd_w; |
808 | |
809 | jcp.ur = jcp.isa == avx2 ? 4 : 3; // Intel AVX support |
810 | if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur); |
811 | |
812 | int load_blocking {0}; |
813 | int load_blocking_max {0}; |
814 | int bcast_blocking {0}; |
815 | int bcast_blocking_max {0}; |
816 | int reduce_blocking {0}; |
817 | int reduce_blocking_max {0}; |
818 | |
819 | if (one_of(jcp.prop_kind, forward_training, forward_inference)) { |
820 | jcp.reduce_dim = jcp.ic; |
821 | jcp.reduce_block = jcp.ic_block; |
822 | |
823 | jcp.load_dim = jcp.oc; |
824 | jcp.load_block = jcp.oc_block; |
825 | |
826 | jcp.bcast_dim = jcp.is; |
827 | jcp.bcast_block = jcp.ur; |
828 | |
829 | jcp.reduce_loop_unroll = jcp.reduce_block; |
830 | jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll |
831 | * (is_data_layout_nxc ? 1 : jcp.is) * sizeof(float); |
832 | jcp.reduce_loop_load_step |
833 | = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); |
834 | |
835 | jcp.bcast_loop_output_step = jcp.ur |
836 | * (is_data_layout_nxc ? jcp.oc : jcp.oc_block) * sizeof(float); |
837 | jcp.bcast_loop_output_substep = -1; // unused |
838 | jcp.bcast_loop_bcast_step = jcp.ur |
839 | * (is_data_layout_nxc ? jcp.ic : jcp.ic_block) * sizeof(float); |
840 | jcp.bcast_loop_bcast_substep = -1; // unused |
841 | |
842 | jcp.load_loop_load_step |
843 | = rnd_up(jcp.ic, jcp.ic_block) * jcp.oc_block * sizeof(float); |
844 | jcp.load_loop_iter_step = jcp.oc_block; |
845 | |
846 | load_blocking = is_data_layout_nxc |
847 | ? jcp.load_dim |
848 | : 120; // assumes the kernel is jcp.ur x 3 |
849 | load_blocking_max = is_data_layout_nxc ? jcp.load_dim : 144; |
850 | bcast_blocking = 128; // affects load balancing across threads |
851 | bcast_blocking_max = 192; |
852 | reduce_blocking = is_data_layout_nxc ? jcp.reduce_dim |
853 | : 128; // affects L1$ utilization |
854 | } else if (jcp.prop_kind == backward_data) { |
855 | jcp.reduce_dim = jcp.oc; |
856 | jcp.reduce_block = jcp.oc_block; |
857 | |
858 | jcp.load_dim = jcp.ic; |
859 | jcp.load_block = jcp.ic_block; |
860 | |
861 | jcp.bcast_dim = jcp.os; |
862 | jcp.bcast_block = jcp.ur; |
863 | |
864 | jcp.reduce_loop_unroll = jcp.reduce_block; |
865 | jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll |
866 | * (is_data_layout_nxc ? 1 : jcp.os) * sizeof(float); |
867 | jcp.reduce_loop_load_step = jcp.reduce_loop_unroll |
868 | * rnd_up(jcp.ic, jcp.ic_block) * sizeof(float); |
869 | |
870 | jcp.bcast_loop_output_step = jcp.ur |
871 | * (is_data_layout_nxc ? jcp.ic : jcp.ic_block) * sizeof(float); |
872 | jcp.bcast_loop_output_substep = -1; // unused |
873 | jcp.bcast_loop_bcast_step = jcp.ur |
874 | * (is_data_layout_nxc ? jcp.oc : jcp.oc_block) * sizeof(float); |
875 | jcp.bcast_loop_bcast_substep = -1; // unused |
876 | |
877 | jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); |
878 | jcp.load_loop_iter_step = jcp.ic_block; |
879 | |
880 | load_blocking = is_data_layout_nxc |
881 | ? jcp.load_dim |
882 | : 96; // assumes the kernel is jcp.ur x 3 |
883 | load_blocking_max = is_data_layout_nxc ? jcp.load_dim : 144; |
884 | |
885 | bcast_blocking = 128; // affects load balancing across threads |
886 | bcast_blocking_max = 196; |
887 | reduce_blocking = is_data_layout_nxc ? jcp.reduce_dim |
888 | : 64; // affects L1$ utilization |
889 | } else if (jcp.prop_kind == backward_weights) { |
890 | jcp.reduce_dim = jcp.os; |
891 | jcp.reduce_block = 1; |
892 | |
893 | jcp.load_dim = jcp.oc; |
894 | jcp.load_block = jcp.oc_block; |
895 | |
896 | jcp.bcast_dim = jcp.ic; |
897 | jcp.bcast_block = jcp.ic_block; |
898 | |
899 | jcp.reduce_loop_unroll = jcp.reduce_block; |
900 | jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll |
901 | * (is_data_layout_nxc ? jcp.ic : jcp.ic_block) * sizeof(float); |
902 | jcp.reduce_loop_load_step = jcp.reduce_loop_unroll |
903 | * (is_data_layout_nxc ? jcp.oc : jcp.oc_block) * sizeof(float); |
904 | |
905 | jcp.bcast_loop_output_step |
906 | = jcp.oc_block * jcp.ic_block * sizeof(float); |
907 | jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); |
908 | jcp.bcast_loop_bcast_step = jcp.ic_block |
909 | * (is_data_layout_nxc ? 1 : jcp.is) * sizeof(float); |
910 | jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); |
911 | |
912 | jcp.load_loop_load_step = jcp.oc_block |
913 | * (is_data_layout_nxc ? 1 : jcp.os) * sizeof(float); |
914 | jcp.load_loop_iter_step = jcp.oc_block; |
915 | |
916 | /* --- */ |
917 | |
918 | load_blocking = div_up(jcp.load_dim, jcp.load_block); |
919 | const bool no_load_tail = jcp.load_dim % jcp.load_block == 0; |
920 | const bool modify_load_blocking |
921 | = IMPLICATION(is_data_layout_nxc, no_load_tail); |
922 | while (modify_load_blocking) { |
923 | if (load_blocking <= 32) |
924 | break; |
925 | else if (load_blocking % 2 == 0) |
926 | load_blocking /= 2; |
927 | else if (load_blocking % 3 == 0) |
928 | load_blocking /= 3; |
929 | else |
930 | break; |
931 | } |
932 | load_blocking *= jcp.load_block; |
933 | load_blocking_max = load_blocking; |
934 | assert(IMPLICATION( |
935 | !is_data_layout_nxc, jcp.load_dim % load_blocking == 0)); |
936 | |
937 | bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); |
938 | const int bcast_blocking_lim = is_data_layout_nxc ? 17 : 9; |
939 | const bool no_bcast_tail = jcp.bcast_dim % jcp.bcast_block == 0; |
940 | const bool small_size_for_bcast |
941 | = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw <= 1024; |
942 | |
943 | // TODO Verify if the size limitation helps for blocked format as well |
944 | const bool modify_bcast_blocking = IMPLICATION( |
945 | is_data_layout_nxc, no_bcast_tail && small_size_for_bcast); |
946 | |
947 | while (modify_bcast_blocking) { |
948 | if (bcast_blocking <= bcast_blocking_lim) |
949 | break; |
950 | else if (bcast_blocking % 2 == 0) |
951 | bcast_blocking /= 2; |
952 | else if (bcast_blocking % 3 == 0) |
953 | bcast_blocking /= 3; |
954 | else |
955 | break; |
956 | } |
957 | bcast_blocking *= jcp.bcast_block; |
958 | bcast_blocking_max = bcast_blocking; |
959 | assert(IMPLICATION( |
960 | !is_data_layout_nxc, jcp.bcast_dim % bcast_blocking == 0)); |
961 | |
962 | reduce_blocking = is_data_layout_nxc |
963 | ? rnd_up(nstl::min(jcp.ow, 128), jcp.reduce_block) |
964 | : 128; // affects L1$ utilization |
965 | reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block); |
966 | } else |
967 | return status::unimplemented; |
968 | |
969 | assert(load_blocking); |
970 | assert(load_blocking_max); |
971 | assert(bcast_blocking); |
972 | assert(bcast_blocking_max); |
973 | assert(reduce_blocking); |
974 | |
975 | assert(jcp.bcast_block % jcp.ur == 0); |
976 | jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.bcast_block; |
977 | |
978 | jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; |
979 | jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; |
980 | jcp.nb_load_blocking = div_up(load_blocking, jcp.load_block); |
981 | jcp.nb_load_blocking_max = div_up(load_blocking_max, jcp.load_block); |
982 | jcp.nb_reduce_blocking = div_up(reduce_blocking, jcp.reduce_block); |
983 | jcp.nb_reduce_blocking_max = div_up(reduce_blocking_max, jcp.reduce_block); |
984 | |
985 | jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); |
986 | jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); |
987 | jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); |
988 | |
989 | if (jcp.prop_kind == backward_weights) { |
990 | const auto mb_with_nb_reduce |
991 | = static_cast<dim_t>(jcp.mb) * jcp.nb_reduce; |
992 | // prevent too large argument to cpu reducer |
993 | if (mb_with_nb_reduce > std::numeric_limits<int>::max()) |
994 | return status::unimplemented; |
995 | } |
996 | |
997 | return status::success; |
998 | } |
999 | |
1000 | void jit_avx2_1x1_conv_kernel_f32::init_scratchpad( |
1001 | memory_tracking::registrar_t &scratchpad, |
1002 | const jit_1x1_conv_conf_t &jcp) { |
1003 | using namespace dnnl::impl::memory_tracking::names; |
1004 | |
1005 | if (jcp.with_bias && jcp.prop_kind != backward_data |
1006 | && (jcp.oc != jcp.oc_without_padding // blocked format |
1007 | || (jcp.prop_kind == backward_weights // nxc format |
1008 | && jcp.oc % jcp.oc_block != 0))) { |
1009 | const size_t nelems_padded_bias |
1010 | = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block); |
1011 | scratchpad.book<float>(key_conv_padded_bias, nelems_padded_bias); |
1012 | } |
1013 | } |
1014 | |
1015 | } // namespace x64 |
1016 | } // namespace cpu |
1017 | } // namespace impl |
1018 | } // namespace dnnl |
1019 | |