/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/memory.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"

#include "cpu/x64/injectors/injector_utils.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_sse41_conv_kernel_f32.hpp"

#define GET_OFF(field) offsetof(jit_conv_call_s, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::prop_kind;
using namespace dnnl::impl::utils;

using namespace Xbyak;

jit_sse41_conv_fwd_kernel_f32::jit_sse41_conv_fwd_kernel_f32(
        const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, sse41)
    , jcp(ajcp)
    , attr_(attr) {
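    // Set up the post-ops injector only when eltwise or binary post-ops are
    // present: xmm15 is lent to the binary injector as a helper register and
    // r12/r14/r15 serve as scratch GPRs; the tail size covers the oc remainder
    // that does not fill a full SIMD word.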
    if (jcp.with_eltwise || jcp.with_binary) {
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static constexpr size_t helper_vmm_idx = 15;
        const size_t tail_size = jcp.oc_without_padding % simd_w_;
        static constexpr bool use_exact_tail_scalar_bcast = false;

        const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
                helper_vmm_idx, r14, r15, r12, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size,
                use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<sse41>>(
                this, jcp.post_ops, static_params);
    }
}

void jit_sse41_conv_fwd_kernel_f32::oh_step_unroll_kw(
        int ur_w, int pad_l, int pad_r, int oc_blocks) {
    int kw = jcp.kw;
    int stride_w = jcp.stride_w;
    int dilate_w = jcp.dilate_w + 1;
    int ic_blk = jcp.ic_block;

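    // XMM register layout: xmm0 is a scratch register for the weights,
    // Xmm(oc_blocks * ur_w + jj + 1) holds the broadcast input value for
    // output column jj, and Xmm(ur_w * ii + jj + 1) accumulates the result
    // for output-channel block ii and output column jj.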
    for (int ki = 0; ki < kw; ki++) {
        int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
        int jj_end = ur_w
                - nstl::max(0,
                        div_up(ki * dilate_w + pad_r - (kw - 1) * dilate_w,
                                stride_w));
        for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
            for (int jj = jj_start; jj < jj_end; jj++) {
                size_t inp_off = get_input_offset(
                        ifm2, filter_w_to_input(ki, jj, pad_l));
                movss(Xmm(oc_blocks * ur_w + jj + 1),
                        ptr[aux_reg_input + inp_off]);
                shufps(Xmm(oc_blocks * ur_w + jj + 1),
                        Xmm(oc_blocks * ur_w + jj + 1), 0x0);
            }

            for (int ii = 0; ii < oc_blocks; ii++) {
                for (int jj = jj_start; jj < jj_end; jj++) {
                    movups(xmm0,
                            ptr[aux_reg_kernel
                                    + get_kernel_offset(ii, ki, ifm2)]);
                    mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
                    addps(Xmm(ur_w * ii + jj + 1), xmm0);
                }
            }
        }
    }
}

void jit_sse41_conv_fwd_kernel_f32::oh_step_nopad(
        int ur_w, int pad_l, int pad_r, int oc_blocks) {
    Label kw_loop;

    int kw = jcp.kw;
    int ic_blk = jcp.ic_block;

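    // With no horizontal padding every output column uses the full filter
    // width, so the kw dimension is handled with a runtime loop here instead
    // of being unrolled as in oh_step_unroll_kw.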
    xor_(ki_iter, ki_iter);
    L(kw_loop);
    {
        int jj_start = 0;
        int jj_end = ur_w;
        for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
            for (int jj = jj_start; jj < jj_end; jj++) {
                size_t inp_off = get_input_offset(
                        ifm2, filter_w_to_input(0, jj, pad_l));
                movss(Xmm(oc_blocks * ur_w + jj + 1),
                        ptr[aux_reg_input + inp_off]);
                shufps(Xmm(oc_blocks * ur_w + jj + 1),
                        Xmm(oc_blocks * ur_w + jj + 1), 0x0);
            }
            for (int ii = 0; ii < oc_blocks; ii++) {
                for (int jj = jj_start; jj < jj_end; jj++) {
                    movups(xmm0,
                            ptr[aux_reg_kernel
                                    + get_kernel_offset(ii, 0, ifm2)]);
                    mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
                    addps(Xmm(ur_w * ii + jj + 1), xmm0);
                }
            }
        }
        add(aux_reg_kernel, get_kernel_offset(0, 1, 0));
        add(aux_reg_input, get_input_offset(0, filter_w_to_input(1)));

        inc(ki_iter);
        cmp(ki_iter, kw);
        jl(kw_loop, T_NEAR);
    }
}

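// Accumulators are numbered starting at 1 because xmm0 is reserved as a
// scratch register. For example, with ur_w = 3 the accumulator for
// oc_block_idx = 1 and ur_w_idx = 2 is xmm6 (3 * 1 + 2 + 1).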
int get_xmm_idx(const int ur_w, const int oc_block_idx, const int ur_w_idx) {
    return ur_w * oc_block_idx + ur_w_idx + 1;
}

Xmm get_xmm(const int ur_w, const int oc_block_idx, const int ur_w_idx) {
    return Xmm(get_xmm_idx(ur_w, oc_block_idx, ur_w_idx));
}

template <typename F>
static void iterate(const int oc_blocks, const int ur_w, const F &f) {
    for (int i = 0; i < oc_blocks; i++) {
        const bool mask_flag = i == oc_blocks - 1;
        for (int j = 0; j < ur_w; j++)
            f(mask_flag, i, j);
    }
}
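
// Apply the post-ops chain to all live accumulator registers. For binary
// post-ops the injector also needs, per register, the output pointer and
// element offset plus a tail flag for the last (possibly partial) oc block.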
void jit_sse41_conv_fwd_kernel_f32::apply_postops(
        const int oc_blocks, const int ur_w) {
    injector_utils::vmm_index_set_t vmm_idxs;
    if (jcp.with_binary) {
        binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
        iterate(oc_blocks, ur_w,
                [&](const bool mask_flag, const int i, const int j) {
                    const size_t o_off = get_output_offset(i, j);
                    const auto vmm_idx = get_xmm_idx(ur_w, i, j);
                    vmm_idxs.emplace(vmm_idx);

                    rhs_arg_params.vmm_idx_to_out_reg.emplace(
                            vmm_idx, reg_output);
                    rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                            vmm_idx, o_off);
                    if (mask_flag)
                        rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
                });

        postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
    } else {
        iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) {
            vmm_idxs.emplace(get_xmm_idx(ur_w, i, j));
        });
        postops_injector_->compute_vector_range(vmm_idxs);
    }
}

void jit_sse41_conv_fwd_kernel_f32::width_blk_step(
        int ur_w, int pad_l, int pad_r, int oc_blocks) {
    int kw = jcp.kw;
    int oc_blk = jcp.oc_block;

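    // The oc block is 8 floats wide but an SSE register holds only 4, so the
    // whole width block is computed twice (simd_iter = 0, 1), the second pass
    // working on the upper 4 floats of each 8-wide block.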
    xor_(simd_iter, simd_iter);

    mov(aux_reg_input, reg_input);
    mov(aux_reg_kernel, reg_kernel);

    Label init_simd_iter_loop;
    Label init_done;
    Label init_first;

    L(init_simd_iter_loop);

    if (!jcp.with_sum) {
        test(reg_ci_flag, FLAG_IC_FIRST);
        jne(init_first, T_NEAR);
    }

    for (int ii = 0; ii < oc_blocks; ii++)
        for (int jj = 0; jj < ur_w; jj++)
            movups(get_xmm(ur_w, ii, jj),
                    xword[reg_output + get_output_offset(ii, jj)]);

    if (jcp.with_sum && jcp.with_bias) {
        test(reg_ci_flag, FLAG_IC_FIRST);
        je(init_done, T_NEAR);

        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++)
                addps(get_xmm(ur_w, ii, jj),
                        xword[reg_bias + sizeof(float) * ii * oc_blk]);
    }

    jmp(init_done);

    L(init_first);
    if (this->jcp.with_bias) {
        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++)
                movups(get_xmm(ur_w, ii, jj),
                        xword[reg_bias + sizeof(float) * ii * oc_blk]);
    } else {
        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++) {
                const auto xmm = get_xmm(ur_w, ii, jj);
                pxor(xmm, xmm);
            }
    }

    L(init_done);

    Label skip_kh_loop;
    mov(kj, reg_kh);
    if ((jcp.dilate_h >= jcp.ih)
            || (jcp.kh - 1) * (jcp.dilate_h + 1)
                    < nstl::max(jcp.t_pad, jcp.b_pad)) {
        cmp(kj, 0);
        je(skip_kh_loop, T_NEAR);
    }
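    // reg_kh holds kh_padding from the call args; it can be zero when the
    // filter lies entirely in the vertical padding area, hence the optional
    // skip emitted above for such configurations.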
    Label kh_loop;
    L(kh_loop);
    {
        if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
            oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks);
            sub(aux_reg_input, get_input_offset(0, filter_w_to_input(kw)));
            add(aux_reg_input, get_input_offset(0, filter_h_to_input(1)));
        } else {
            oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
            add(aux_reg_kernel, get_kernel_offset(0, kw, 0));
            add(aux_reg_input, get_input_offset(0, filter_h_to_input(1)));
        }

        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }

    L(skip_kh_loop);

    if (jcp.with_eltwise || jcp.with_binary) {
        Label regular_store;
        test(reg_ci_flag, FLAG_IC_LAST);
        je(regular_store, T_NEAR);

        apply_postops(oc_blocks, ur_w);

        L(regular_store);
    }

    for (int ii = 0; ii < oc_blocks; ii++) {
        for (int jj = 0; jj < ur_w; jj++) {
            const Xmm reg_out = get_xmm(ur_w, ii, jj);
            movups(xword[reg_output + get_output_offset(ii, jj)], reg_out);
        }
    }

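    // Advance to the upper half of the 8-wide oc block: aux_reg_kernel,
    // reg_output and reg_bias are bumped by 4 floats after each pass, and
    // reg_output/reg_bias are pulled back by 8 floats once both passes are
    // done.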
    mov(aux_reg_kernel, reg_kernel);
    mov(aux_reg_input, reg_input);
    add(aux_reg_kernel, sizeof(float) * 4);
    add(reg_output, sizeof(float) * 4);
    add(reg_bias, sizeof(float) * 4);
    inc(simd_iter);
    cmp(simd_iter, 2);
    jl(init_simd_iter_loop, T_NEAR);

    sub(reg_output, sizeof(float) * 8);
    sub(reg_bias, sizeof(float) * 8);
}

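// Generate the full output-row sweep: an optional left-padded block, n_oi
// "middle" blocks without padding, an optional right-padded block, and a
// tail block of ur_w_tail output columns.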
inline void jit_sse41_conv_fwd_kernel_f32::solve_common(int oc_blocks) {
    int ur_w = jcp.ur_w;
    int ur_w_tail = jcp.ur_w_tail;
    int n_oi = jcp.ow / ur_w;
    int iw = jcp.iw;
    int kw = jcp.kw;
    int str_w = jcp.stride_w;

    int l_pad = jcp.l_pad;
    int r_pad = nstl::max(0, jcp.r_pad);
    int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, str_w,
            calculate_extended_filter_size(kw, jcp.dilate_w));
    if (r_pad1 > 0) n_oi--;

    if (l_pad > 0) {
        n_oi--;
        if (n_oi < 0 && r_pad1 > 0)
            width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad"
        else
            width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad"
        add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w, l_pad)));
        add(reg_output, get_output_offset(0, ur_w));
    }

    Label ow_loop;
    xor_(oi_iter, oi_iter);

    if (n_oi > 0) {
        L(ow_loop);

        width_blk_step(ur_w, 0, 0, oc_blocks); // "middle"
        add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w)));
        add(reg_output, get_output_offset(0, ur_w));

        inc(oi_iter);
        cmp(oi_iter, n_oi);
        jl(ow_loop, T_NEAR);
    }

    if (r_pad1 > 0 && n_oi >= 0) {
        width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad"
        add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w)));
        add(reg_output, get_output_offset(0, ur_w));
    }

    if (ur_w_tail != 0)
        width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail"
}

void jit_sse41_conv_fwd_kernel_f32::generate() {
    this->preamble();

    mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    if (jcp.with_bias) mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
    mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);

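    // Emit two code paths: one for a full nb_oc_blocking group of output
    // channel blocks and, when nb_oc does not divide evenly, one for the
    // remaining tail; the runtime oc_blocks argument selects which path runs.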
    int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;
    Label tail, exit;

    cmp(reg_oc_blocks, jcp.nb_oc_blocking);
    jne(nb_oc_tail ? tail : exit, T_NEAR);

    solve_common(jcp.nb_oc_blocking);
    jmp(exit, T_NEAR);

    if (nb_oc_tail) {
        L(tail);
        cmp(reg_oc_blocks, nb_oc_tail);
        jne(exit, T_NEAR);
        solve_common(nb_oc_tail);
    }

    L(exit);

    this->postamble();

    if (jcp.with_eltwise) postops_injector_->prepare_table();
}

status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr, int nthreads) {
    if (!mayiuse(sse41)) return status::unimplemented;

    jcp.nthr = nthreads;

    jcp.prop_kind = cd.prop_kind;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    const int ndims = src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
    jcp.ow = dst_d.dims()[ndims - 1];

    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[0];
    jcp.dilate_w = cd.dilates[ndims - 3];

    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
    bool kernel_outside_src = false || ext_kw <= jcp.l_pad
            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad
            || ext_kh <= jcp.b_pad;
    if (kernel_outside_src) return status::unimplemented;

    const auto dat_tag_nxc = (ndims == 3 ? nwc : nhwc);
    const auto dat_tag_ncx = (ndims == 3 ? ncw : nchw);
    const auto dat_tag_nCx8c = (ndims == 3 ? nCw8c : nChw8c);
    const auto wei_tag_OIxio = with_groups
            ? pick(ndims - 3, gOIw8i8o, gOIhw8i8o)
            : pick(ndims - 3, OIw8i8o, OIhw8i8o);
    const auto wei_tag_Oxio = with_groups ? pick(ndims - 3, gOwi8o, gOhwi8o)
                                          : pick(ndims - 3, Owi8o, Ohwi8o);

    jcp.src_tag
            = src_d.matches_one_of_tag(dat_tag_ncx, dat_tag_nxc, dat_tag_nCx8c);
    jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag_OIxio, wei_tag_Oxio);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);

    const bool is_data_layout_nxc
            = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    const auto &post_ops = attr.post_ops_;
    jcp.with_sum = post_ops.find(primitive_kind::sum) != -1;
    const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;

    const int binary_ind = post_ops.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;

    jcp.post_ops = post_ops;

    using namespace injector;
    static constexpr bool sum_at_pos_0_only = true;
    static constexpr bool sum_requires_scale_one = true;
    static constexpr bool sum_requires_zp_zero = true;
    const bool post_ops_ok_ = post_ops_ok({sse41, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
            sum_requires_zp_zero});
    if (!post_ops_ok_) return status::unimplemented;

    const bool flat = jcp.ic == 3;
    const bool mimo = !flat;

    bool args_ok = true
            && IMPLICATION(flat,
                    jcp.wei_tag == wei_tag_Oxio
                            && ((jcp.src_tag == dat_tag_ncx
                                        && jcp.dst_tag == dat_tag_nCx8c)
                                    || (jcp.src_tag == dat_tag_nxc
                                            && jcp.dst_tag == dat_tag_nxc)))
            && IMPLICATION(mimo,
                    jcp.wei_tag == wei_tag_OIxio
                            && ((jcp.src_tag == dat_tag_nCx8c
                                        && jcp.dst_tag == dat_tag_nCx8c)
                                    || (jcp.src_tag == dat_tag_nxc
                                            && jcp.dst_tag == dat_tag_nxc)))
            && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= dst_d.padded_dims()[1];
    if (!args_ok) return status::unimplemented;

    const int simd_w = 8; // two SSE vectors processed at once

    jcp.ur_h = 1; /* no code-unrolling by h so far */
    jcp.ur_w = 3;
    if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;

    jcp.nb_oc_blocking
            = is_data_layout_nxc ? 1 : 4; /* the optimal value for the kernel */

    args_ok = true && jcp.oc % simd_w == 0 && jcp.l_pad <= jcp.ur_w
            && IMPLICATION(jcp.kw > 7,
                    (jcp.t_pad == 0 && jcp.l_pad == 0)
                            || (jcp.stride_w == 1 && jcp.stride_h == 1))
            && IMPLICATION(mimo, jcp.ic % simd_w == 0);
    if (!args_ok) return status::unimplemented;

    int r_pad_no_tail = nstl::max(0,
            calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
                    jcp.stride_w, ext_kw));

    // the kernel needs one temporary XMM register (xmm0)
    const int num_avail_regs = 15;
    if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
        /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
        jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
                nstl::min(jcp.ow, num_avail_regs / 2));
        jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
        jcp.ur_w_tail = jcp.ow % jcp.ur_w;
        /* check again ... */
        r_pad_no_tail = nstl::max(0,
                calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
                        jcp.stride_w, ext_kw));

        if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
            return status::unimplemented;
    }
    assert(jcp.nb_oc_blocking > 0);
    assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);

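    // When ic is not a multiple of simd_w (the flat ic == 3 case), the whole
    // input-channel count forms a single block; otherwise channels are
    // blocked by simd_w.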
    jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
    jcp.nb_ic = jcp.ic / jcp.ic_block;

    jcp.oc_block = simd_w;
    jcp.nb_oc = jcp.oc / jcp.oc_block;

    if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
        jcp.nb_ic_blocking = 12;
        jcp.nb_ic_blocking_max = 16;
    } else {
        jcp.nb_ic_blocking = 1;
        jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
    }

    return status::success;
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl