1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/memory.hpp" |
19 | #include "common/nstl.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | #include "common/utils.hpp" |
22 | |
23 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
24 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
25 | #include "cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp" |
26 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
27 | |
28 | #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | namespace x64 { |
34 | |
35 | using namespace dnnl::impl::format_tag; |
36 | using namespace dnnl::impl::prop_kind; |
37 | using namespace dnnl::impl::utils; |
38 | |
39 | using namespace Xbyak; |
40 | |
// Kernel constructor: stores the conv configuration and attributes, and,
// when eltwise/binary post-ops are present, builds the post-ops injector
// that will emit their code inside the generated kernel.
jit_sse41_1x1_conv_kernel_f32::jit_sse41_1x1_conv_kernel_f32(
        const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, sse41)
    , jcp(ajcp)
    , attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary) {
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        // xmm15 is handed to the injector as a helper/scratch register.
        static constexpr size_t helper_vmm_idx = 15;
        // Number of valid lanes in the last, partially-filled oc vector.
        const size_t tail_size = jcp.oc_without_padding % simd_w_;
        static constexpr bool use_exact_tail_scalar_bcast = false;

        // r13/r14/r15 are given to the binary injector as scratch GPRs;
        // the GET_OFF fields locate the rhs-arg vector and the original
        // dst pointer inside jit_1x1_conv_call_s.
        const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
                helper_vmm_idx, r13, r14, r15, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size,
                use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t static_params {
                this->param1, rhs_arg_static_params};
        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<sse41>>(
                this, jcp.post_ops, static_params);
    }
}
66 | |
// Emits the loop over the broadcast (spatial) dimension: full
// jcp.bcast_block chunks first, each split into jcp.ur-row substeps,
// followed by a single tail pass of jcp.ur_tail rows if one remains.
void jit_sse41_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) {
    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_output_data, reg_output_data);
    mov(bcast_loop_iter, reg_bcast_loop_work);

    Label bcast_loop;
    Label bcast_loop_tail;

    // Less than one full ur step of work -> go straight to the tail.
    cmp(bcast_loop_iter, jcp.ur);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop);
    {
        assert(jcp.bcast_block % jcp.ur == 0);
        int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            generate_reduce_loop(load_loop_blk, jcp.ur);
            if (i < num_substeps - 1) {
                // Advance to the next substep within the current block.
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_substep);
            } else {
                // Last substep: step to the next block, compensating for
                // the substep increments already applied above.
                add(aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data,
                        jcp.bcast_loop_output_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_output_substep);
            }
        }
        sub(bcast_loop_iter, jcp.bcast_block);
        cmp(bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        cmp(bcast_loop_iter, 0);
        jz(bcast_loop_tail_out, T_NEAR);
        generate_reduce_loop(load_loop_blk, jcp.ur_tail);
        L(bcast_loop_tail_out);
    }
}
113 | |
114 | size_t jit_sse41_1x1_conv_kernel_f32::get_fwd_output_ptr_l_off( |
115 | int i, int j, int n) const { |
116 | return i * get_output_i_offset(jcp) + j * get_output_j_offset(jcp) + n * 4; |
117 | } |
118 | |
// Maps an accumulator's (load block i, spatial row j, xmm half n)
// coordinates to its xmm register index. Each (j, i) pair owns two
// consecutive registers (n = 0/1); index 0 is reserved, so results
// start at 1.
static int reg_accum_idx(
        const int load_loop_blk, const int i, const int j, const int n) {
    const int pair = j * load_loop_blk + i;
    return 2 * pair + n + 1;
}
123 | |
// Invokes f(i, j, n) for every accumulator coordinate: rows (j) outermost,
// load blocks (i) next, and the two xmm halves (n = 0, 1) innermost.
template <typename F>
static void iterate(const int load_loop_blk, const int ur, const F &f) {
    for (int j = 0; j < ur; ++j) {
        for (int i = 0; i < load_loop_blk; ++i) {
            f(i, j, 0);
            f(i, j, 1);
        }
    }
}
// Emits eltwise/binary post-op code over all live accumulator registers.
// For binary post-ops, each vmm index is mapped to its output register
// and element offset so the injector can address the rhs tensor.
void jit_sse41_1x1_conv_kernel_f32::apply_postops(
        const int load_loop_blk, const int ur) {
    injector_utils::vmm_index_set_t vmm_idxs;
    if (jcp.with_binary) {
        binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
        iterate(load_loop_blk, ur, [&](const int i, const int j, const int n) {
            // NOTE(review): the tail mask is set when the half index
            // 2*i + n equals load_loop_blk - 1, not 2*load_loop_blk - 1;
            // looks asymmetric — confirm this matches the blocking scheme.
            const bool mask_flag = (2 * i + n) == load_loop_blk - 1;
            const size_t aux_output_offset
                    = get_fwd_output_ptr_l_off(i, j, n) * sizeof(float);
            const auto vmm_idx = reg_accum_idx(load_loop_blk, i, j, n);
            vmm_idxs.emplace(vmm_idx);

            rhs_arg_params.vmm_idx_to_out_reg.emplace(
                    vmm_idx, aux_reg_output_data);
            rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                    vmm_idx, aux_output_offset);
            if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
        });
        // abi_param1 may have been clobbered; reload the value backed up
        // on the stack in generate(), accounting for the extra stack space
        // the preserve guard just pushed.
        const injector_utils::register_preserve_guard_t register_guard(
                this, {abi_param1});
        const size_t reg_guard_stack_occupied
                = register_guard.stack_space_occupied();
        mov(abi_param1,
                ptr[rsp + reg_abi_param1_backup + reg_guard_stack_occupied]);

        postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
    } else {
        // Eltwise-only: no per-register addressing info needed.
        iterate(load_loop_blk, ur, [&](const int i, const int j, const int n) {
            vmm_idxs.emplace(reg_accum_idx(load_loop_blk, i, j, n));
        });
        postops_injector_->compute_vector_range(vmm_idxs);
    }
}
164 | |
// Emits the innermost reduction loop for one (load_loop_blk x ur) tile:
// initializes accumulators (bias or zero), runs the unrolled FMA body over
// the reduce dimension, then applies post-ops and stores results.
// Register budget: xmm0 broadcasts the input scalar, xmm1..2*ur*load_loop_blk
// hold accumulators, and the next 2*load_loop_blk registers hold weights
// (each 8-wide block split into two 4-lane xmm halves, selected by 'n').
void jit_sse41_1x1_conv_kernel_f32::generate_reduce_loop(
        int load_loop_blk, int ur) {
    // Weight register for load block i, half n.
    auto reg_load = [=](int i, int n) {
        return Xmm(2 * ur * load_loop_blk + 2 * i + n + 1);
    };

    // Accumulator register for (load block i, row j, half n).
    auto reg_accum = [=](int i, int j, int n) {
        return Xmm(reg_accum_idx(load_loop_blk, i, j, n));
    };

    auto bias_ptr = [=](int i, int n) {
        return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i
                + n * 4 * sizeof(float)];
    };

    // Address of the scalar to broadcast: reduce element u, row j.
    auto bcast_ptr = [=](int u, int j) {
        assert(j < jcp.ur);
        assert(u <= jcp.reduce_loop_unroll);
        size_t offt;
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            assert(jcp.reduce_loop_unroll == jcp.reduce_block);
            offt = get_bcast_offset(jcp, u, j);
        } else
            offt = u * jcp.ic_block + j;
        return ptr[aux_reg_bcast_data + offt];
    };

    // Address of the weight vector: reduce element u, load block i, half n.
    // u is split into the position inside the unroll (u0) and the number of
    // whole unroll steps (u1, scaled by reduce_loop_load_step).
    auto load_ptr = [=](int u, int i, int n) {
        size_t offt;
        size_t u0 = u % jcp.reduce_loop_unroll;
        size_t u1 = u / jcp.reduce_loop_unroll;
        switch (jcp.prop_kind) {
            case backward_data:
                offt = (i * jcp.oc_block + u0) * jcp.ic_block;
                break;
            case backward_weights:
                offt = (i * jcp.os + u0) * jcp.oc_block;
                break;
            default: offt = (i * jcp.ic + u0) * jcp.oc_block;
        }
        return ptr[aux_reg_load_data + u1 * jcp.reduce_loop_load_step
                + sizeof(float) * offt + n * 4 * sizeof(float)];
    };

    // Output address for (load block i, row j, half n), per prop_kind.
    auto output_ptr = [=](int i, int j, int n) {
        switch (jcp.prop_kind) {
            case backward_data:
                return ptr[aux_reg_output_data
                        + (i * jcp.is + j) * jcp.ic_block * sizeof(float)
                        + n * 4 * sizeof(float)];
            case backward_weights:
                return ptr[aux_reg_output_data
                        + (i ? reg_output_stride * i
                             : 0) // TODO: Xbyak should allow 0 scale
                        + sizeof(float) * jcp.oc_block * j
                        + n * 4 * sizeof(float)];
            default:
                return ptr[aux_reg_output_data
                        + get_fwd_output_ptr_l_off(i, j, n) * sizeof(float)];
        }
    };

    // Accumulator init: load bias on the first reduce chunk of a forward
    // pass, otherwise zero; then prime the weight and broadcast registers.
    auto init = [=]() {
        Label init_done;
        Label init_zero;

        if (jcp.with_bias
                && one_of(jcp.prop_kind, forward_training, forward_inference)) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jz(init_zero);

            for (int i = 0; i < load_loop_blk; i++)
                for (int j = 0; j < ur; ++j) {
                    movups(reg_accum(i, j, 0), bias_ptr(i, 0));
                    movups(reg_accum(i, j, 1), bias_ptr(i, 1));
                }
            jmp(init_done);
        }

        L(init_zero);
        for (int i = 0; i < load_loop_blk; ++i)
            for (int j = 0; j < ur; ++j) {
                auto r0 = reg_accum(i, j, 0);
                auto r1 = reg_accum(i, j, 1);
                xorps(r0, r0);
                xorps(r1, r1);
            }

        L(init_done);

        // load weights
        for (int i = 0; i < load_loop_blk; ++i) {
            movups(reg_load(i, 0), load_ptr(0, i, 0));
            movups(reg_load(i, 1), load_ptr(0, i, 1));
        }

        // Broadcast the first input scalar to all 4 lanes of xmm0.
        movss(reg_bcast, bcast_ptr(0, 0));
        shufps(reg_bcast, reg_bcast, 0);
    }; // init()

    // Store: optionally accumulate into existing output (unless this is the
    // first reduce chunk or sum handles it), apply post-ops on the last
    // chunk, then write accumulators back.
    auto store = [=]() {
        Label store_noadd;

        if (!jcp.with_sum) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jnz(store_noadd, T_NEAR);
        }

        for (int j = 0; j < ur; ++j)
            for (int i = 0; i < load_loop_blk; ++i) {
                auto r0 = reg_accum(i, j, 0);
                auto r1 = reg_accum(i, j, 1);
                addps(r0, output_ptr(i, j, 0));
                addps(r1, output_ptr(i, j, 1));
            }

        L(store_noadd);

        if (jcp.with_eltwise || jcp.with_binary) {
            // Accumulators must fit below the injector's helper register.
            assert(ur * load_loop_blk < 14);

            Label store_nopostops;
            test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
            jz(store_nopostops, T_NEAR);

            apply_postops(load_loop_blk, ur);

            L(store_nopostops);
        }

        for (int j = 0; j < ur; ++j)
            for (int i = 0; i < load_loop_blk; ++i) {
                movups(output_ptr(i, j, 0), reg_accum(i, j, 0));
                movups(output_ptr(i, j, 1), reg_accum(i, j, 1));
            }
    };

    // Unrolled multiply-accumulate body. SSE4.1 has no FMA, so the weight
    // registers are destroyed by mulps and reloaded (with the next reduce
    // element's weights) on the last row of each u-iteration; the broadcast
    // register is refilled with the next scalar ahead of its use.
    auto fma_block = [=](bool last_block) {
        for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
            for (int j = 0; j < ur; ++j) {
                for (int i = 0; i < load_loop_blk; ++i) {
                    mulps(reg_load(i, 0), reg_bcast);
                    mulps(reg_load(i, 1), reg_bcast);
                    addps(reg_accum(i, j, 0), reg_load(i, 0));
                    addps(reg_accum(i, j, 1), reg_load(i, 1));

                    if (j == ur - 1
                            && !(last_block
                                    && u == jcp.reduce_loop_unroll - 1)) {
                        movups(reg_load(i, 0), load_ptr(u + 1, i, 0));
                        movups(reg_load(i, 1), load_ptr(u + 1, i, 1));
                    }
                }
                if (j < ur - 1) {
                    movss(reg_bcast, bcast_ptr(u, j + 1));
                    shufps(reg_bcast, reg_bcast, 0);
                }
            } // for ur
            if (!last_block || u < jcp.reduce_loop_unroll - 1) {
                movss(reg_bcast, bcast_ptr(u + 1, 0));
                shufps(reg_bcast, reg_bcast, 0);
            }
        } // for reduce_loop_unroll
    };

    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);
    mov(aux_reg_bcast_data, aux1_reg_bcast_data);

    init();

    // Loop over the reduce dimension in reduce_loop_unroll steps; the last
    // step runs through the tail path so fma_block skips the final prefetch.
    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop);
    {
        fma_block(false);
        add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    fma_block(true);

    store();
} // reduce_loop()
357 | |
// Emits the diff-bias reduction for backward-weights: sums the diff_dst
// values over the reduce (spatial) dimension into per-oc-block bias
// accumulators. No-op unless bias is present and prop_kind is
// backward_weights. The running diff-bias pointer lives on the stack.
void jit_sse41_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) {
    if (!jcp.with_bias || jcp.prop_kind != backward_weights) return;

    Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
    Label diff_bias_load;

    auto diff_bias_ptr = [=](int i, int n) {
        return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)
                + 4 * n * sizeof(float)];
    };

    auto load_ptr = [=](int u, int i, int n) {
        return ptr[aux_reg_load_data
                + (i * jcp.os + u) * jcp.oc_block * sizeof(float)
                + 4 * n * sizeof(float)];
    };

    // Two xmm accumulators (halves of one 8-wide oc block) per load block.
    auto diff_bias_reg = [=](int i, int n) { return Xmm(2 * i + n + 1); };

    // A null diff-bias pointer means bias accumulation is skipped entirely.
    mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
    cmp(reg_diff_bias_data, 0);
    je(diff_bias_loop_out, T_NEAR);

    // First reduce chunk starts from zero; later chunks continue from the
    // partial sums already stored in memory.
    test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
    jz(diff_bias_load, T_NEAR);

    for (int i = 0; i < load_loop_blk; ++i) {
        auto r0 = diff_bias_reg(i, 0);
        auto r1 = diff_bias_reg(i, 1);
        xorps(r0, r0);
        xorps(r1, r1);
    }
    jmp(diff_bias_init_out, T_NEAR);

    L(diff_bias_load);
    for (int i = 0; i < load_loop_blk; ++i) {
        movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0));
        movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1));
    }

    L(diff_bias_init_out);
    mov(aux_reg_load_data, reg_load_data);
    mov(reduce_loop_iter, reg_reduce_loop_work);
    L(diff_bias_loop);
    {
        for (int u = 0; u < jcp.reduce_loop_unroll; ++u)
            for (int i = 0; i < load_loop_blk; ++i) {
                addps(diff_bias_reg(i, 0), load_ptr(u, i, 0));
                addps(diff_bias_reg(i, 1), load_ptr(u, i, 1));
            }
        // The loop relies on the reduce dim dividing evenly by the unroll.
        assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jnz(diff_bias_loop, T_NEAR);
    }

    for (int i = 0; i < load_loop_blk; i++) {
        movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0));
        movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1));
    }

    // Advance the stored diff-bias pointer past the blocks just processed.
    add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
    mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);

    L(diff_bias_loop_out);
}
424 | |
// Kernel entry point: sets up the stack frame, loads call arguments from
// jit_1x1_conv_call_s, then dispatches the load (oc/ic) dimension into
// blocks of 3, 2, or 1 oc-blocks, emitting the bcast/reduce loops for each.
void jit_sse41_1x1_conv_kernel_f32::generate() {
    preamble();

    sub(rsp, stack_space_needed);
    if (jcp.with_binary) {
        // backup abi_param1 for usage in post_ops processing
        mov(ptr[rsp + reg_abi_param1_backup], abi_param1);

        // zero initialize binary post_ops offset accumulator (store on stack)
        const auto zeroed_reg = r15;
        xor_(zeroed_reg, zeroed_reg);
        mov(ptr[rsp + reg_binary_post_op_acc_off], zeroed_reg);
    }

    // Load kernel arguments from the call struct (param1 = abi_param1).
    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
    if (jcp.with_bias) {
        if (jcp.prop_kind == backward_weights) {
            // diff-bias pointer is kept on the stack so the diff-bias loop
            // can advance it between load blocks.
            mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
            mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
        } else
            mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
    }

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
    if (jcp.prop_kind == backward_weights)
        mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);

    // Emits one pass over the bcast loop for load_loop_blk oc-blocks, then
    // advances the data/bias/output pointers and decrements the work left.
    auto generate_load_loop_body = [=](int load_loop_blk) {
        generate_bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        switch (jcp.prop_kind) {
            case forward_training:
            case forward_inference:
                add(reg_bias_data,
                        load_loop_blk * jcp.oc_block * sizeof(float));
                add(reg_output_data,
                        get_load_loop_output_fwd_offset(jcp, load_loop_blk));
                if (jcp.with_binary) {
                    // Bump the stack-held binary post-op channel offset;
                    // aux_reg_load_data is free for use as scratch here.
                    // NOTE(review): EVEX_compress_addr in an SSE4.1 kernel —
                    // presumably it just forms a plain [rsp + off] address
                    // at this scale; confirm against jit_generator.
                    mov(aux_reg_load_data,
                            EVEX_compress_addr(
                                    rsp, reg_binary_post_op_acc_off));
                    add(aux_reg_load_data, jcp.load_block * load_loop_blk);
                    mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off),
                            aux_reg_load_data);
                }
                break;
            case backward_data:
                add(reg_output_data,
                        load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
                break;
            case backward_weights:
                for (int i = 0; i < load_loop_blk; i++)
                    add(reg_output_data, reg_output_stride);
                break;
            default: assert(!"invalid prop_kind" );
        }
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    };

    Label load_loop_blk_8;
    Label load_loop_blk_16;
    Label load_loop_blk_24;
    Label load_loop_blk_end;

    // Dispatch on the remaining load work (in oc units): <=8 -> 1-block
    // body, <=16 or exactly 32 -> 2-block body, otherwise 3-block body.
    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    cmp(reg_load_loop_work, 32);
    je(load_loop_blk_16, T_NEAR);

    cmp(reg_load_loop_work, 16);
    jle(load_loop_blk_16, T_NEAR);

    L(load_loop_blk_24);
    {
        generate_diff_bias_loop(3);
        generate_load_loop_body(3);
        // 32 left would decompose badly as 24+8; prefer 16+16.
        cmp(reg_load_loop_work, 32);
        je(load_loop_blk_16);
        cmp(reg_load_loop_work, 24);
        jge(load_loop_blk_24);
    }

    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    L(load_loop_blk_16);
    {
        generate_diff_bias_loop(2);
        generate_load_loop_body(2);
        cmp(reg_load_loop_work, 16);
        jge(load_loop_blk_16);
    }

    L(load_loop_blk_8);
    {
        cmp(reg_load_loop_work, 0);
        je(load_loop_blk_end, T_NEAR);
        generate_diff_bias_loop(1);
        generate_load_loop_body(1);
    }

    L(load_loop_blk_end);

    add(rsp, stack_space_needed);

    postamble();

    // Eltwise LUT data must be emitted after the code body.
    if (jcp.with_eltwise) postops_injector_->prepare_table();
}
540 | |
// Validates the convolution descriptor against this kernel's constraints
// and fills jcp with problem sizes, post-op flags, memory-format tags, and
// blocking parameters for the load/bcast/reduce loops.
// Returns status::unimplemented for any unsupported configuration.
status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr, int nthreads) {
    if (!mayiuse(sse41)) return status::unimplemented;

    // TODO (Roma): this code is duplicated from the generic kernel; maybe the
    // configuration struct could do some stuff below
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    const int ndims = src_d.ndims();

    jcp.nthr = nthreads;

    jcp.prop_kind = cd.prop_kind;

    // ---- problem sizes (ndims == 3 means 1D conv: height dims collapse) --
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
    jcp.ow = dst_d.dims()[ndims - 1];

    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    jcp.os = jcp.oh * jcp.ow;
    jcp.is = jcp.ih * jcp.iw;

    jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type);
    jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);

    // ---- post-ops -------------------------------------------------------
    const auto &post_ops = attr.post_ops_;

    const int dw_conv_ind = post_ops.find(primitive_kind::convolution);
    jcp.with_dw_conv = dw_conv_ind != -1;
    // Using dw_conv_ind as upper-bound below, as post-ops after it will be
    // handled in depthwise convolution.
    jcp.with_sum = post_ops.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
    const int eltwise_ind
            = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind);
    jcp.with_eltwise = eltwise_ind != -1;
    const int binary_ind
            = post_ops.find(primitive_kind::binary, 0, dw_conv_ind);
    jcp.with_binary = binary_ind != -1;

    if (dw_conv_ind >= 0) {
        // dw_conv and post_ops after it are handled externally, so skip them
        jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(),
                post_ops.entry_.cbegin() + dw_conv_ind);
    } else {
        jcp.post_ops = post_ops;
    }

    using namespace injector;
    static constexpr bool sum_at_pos_0_only = true;
    static constexpr bool sum_requires_scale_one = true;
    static constexpr bool sum_requires_zp_zero = true;
    const bool post_ops_ok_ = post_ops_ok({sse41, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
            sum_requires_zp_zero});
    if (!post_ops_ok_) return status::unimplemented;

    // ---- memory formats: plain nxc or 8c-blocked layouts ----------------
    const auto dat_tag_nxc = utils::pick(ndims - 3, nwc, nhwc);
    const auto dat_tag_blocked = utils::pick(ndims - 3, nCw8c, nChw8c);
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked);
    const bool is_data_layout_nxc
            = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
    const auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_blocked;

    // Weights: 8i8o for forward, 8o8i for backward-data (transposed role).
    const int is_bwd_d = jcp.prop_kind == backward_data;
    format_tag_t wei_tag = with_groups
            ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i,
                    gOIhw8i8o, gOIhw8o8i)
            : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o,
                    OIhw8o8i);
    jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);

    bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == dat_tag
            && jcp.wei_tag == wei_tag && jcp.dst_tag == dat_tag;
    if (!args_ok) return status::unimplemented;

    const int simd_w = 4;

    // Channel block is two SSE vectors wide (8 floats).
    jcp.ic_block = jcp.oc_block = simd_w * 2;

    // Kernel restrictions: full channel blocks, unit stride, no padding,
    // and genuinely 1x1 spatial extent.
    args_ok = true && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
            && jcp.t_pad == 0 && jcp.l_pad == 0 && jcp.stride_w == 1
            && jcp.stride_h == 1 // TODO: support some strides
            && jcp.ow == jcp.iw && jcp.oh == jcp.ih // enforce rpad=0
            && jcp.kh == 1 && jcp.kw == 1;
    if (!args_ok) return status::unimplemented;

    jcp.ur = 1;
    if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur);

    int load_blocking {0};
    int load_blocking_max {0};
    int bcast_blocking {0};
    int bcast_blocking_max {0};
    int reduce_blocking {0};

    // ---- per-prop_kind loop-dimension mapping and step sizes ------------
    if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
        // Forward: reduce over ic, load over oc, broadcast over spatial.
        jcp.reduce_dim = jcp.ic;
        jcp.reduce_block = jcp.ic_block;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.is;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll
                * (is_data_layout_nxc ? 1 : jcp.is) * sizeof(float);
        jcp.reduce_loop_load_step
                = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur
                * (is_data_layout_nxc ? jcp.oc : jcp.oc_block) * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur
                * (is_data_layout_nxc ? jcp.ic : jcp.ic_block) * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        load_blocking = 120; // assumes the kernel is jcp.ur x 3
        load_blocking_max = 144;
        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 192;
        reduce_blocking = 128; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_data) {
        // Backward-data: reduce over oc, load over ic.
        jcp.reduce_dim = jcp.oc;
        jcp.reduce_block = jcp.oc_block;

        jcp.load_dim = jcp.ic;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.os;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
                = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
        jcp.reduce_loop_load_step
                = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.ic_block;

        load_blocking = 96; // assumes the kernel is jcp.ur x 3
        load_blocking_max = 144;
        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 196;
        reduce_blocking = 64; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_weights) {
        // Backward-weights: reduce over spatial, load over oc, bcast over ic.
        jcp.reduce_dim = jcp.os;
        jcp.reduce_block = 1;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.ic;
        jcp.bcast_block = jcp.ic_block;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
                = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
        jcp.reduce_loop_load_step
                = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);

        jcp.bcast_loop_output_step
                = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
        jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
        jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);

        jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        /* --- */

        // Halve/third the load blocking until it is at most 32 blocks while
        // keeping it an exact divisor of load_dim.
        load_blocking = div_up(jcp.load_dim, jcp.load_block);
        while (true) {
            if (load_blocking <= 32)
                break;
            else if (load_blocking % 2 == 0)
                load_blocking /= 2;
            else if (load_blocking % 3 == 0)
                load_blocking /= 3;
            else
                break;
        }
        load_blocking *= jcp.load_block;
        load_blocking_max = load_blocking;
        assert(jcp.load_dim % load_blocking == 0);

        // Same reduction for the bcast blocking, capped at 9 blocks.
        bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        while (true) {
            if (bcast_blocking <= 9)
                break;
            else if (bcast_blocking % 2 == 0)
                bcast_blocking /= 2;
            else if (bcast_blocking % 3 == 0)
                bcast_blocking /= 3;
            else
                break;
        }
        bcast_blocking *= jcp.bcast_block;
        bcast_blocking_max = bcast_blocking;
        assert(jcp.bcast_dim % bcast_blocking == 0);

        reduce_blocking = 128; // affects L1$ utilization
    } else
        return status::unimplemented;

    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);

    assert(jcp.bcast_block % jcp.ur == 0);
    jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur;

    // Convert element-based blockings into block counts for the driver.
    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = load_blocking / jcp.load_block;
    jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
    jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;

    jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    return status::success;
}
797 | |
798 | } // namespace x64 |
799 | } // namespace cpu |
800 | } // namespace impl |
801 | } // namespace dnnl |
802 | |