/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <bitset>
#include <cassert>

#include "common/bfloat16.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/resampling_utils.hpp"

#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_uni_resampling.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace resampling_utils;

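// Returns the highest ISA available on the current machine that this
// implementation can use. The avx512_core_fp16 variant is selected only for
// plain (non-blocked) source layouts.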
static cpu_isa_t get_supported_isa(bool is_plain) {
    if (is_plain && mayiuse(avx512_core_fp16)) return avx512_core_fp16;
    if (mayiuse(avx512_core_bf16)) return avx512_core_bf16;
    if (mayiuse(avx512_core)) return avx512_core;
    if (mayiuse(avx2)) return avx2;
    if (mayiuse(avx)) return avx;
    if (mayiuse(sse41)) return sse41;

    return isa_undef;
}

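// bf16 requires at least avx512_core and f16 requires avx512_core_fp16;
// f32 and the integer types are supported on all ISAs handled by this
// implementation.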
static bool impl_supports_datatype(data_type_t data_type) {
    switch (data_type) {
        case data_type::bf16: return x64::mayiuse(x64::avx512_core);
        case data_type::f16: return x64::mayiuse(x64::avx512_core_fp16);
        case data_type::f32:
        case data_type::s32:
        case data_type::s8:
        case data_type::u8: return true;
        default: return false;
    }
}

status_t jit_uni_resampling_fwd_t::pd_t::init(engine_t *engine) {
    using namespace data_type;
    using sm = primitive_attr_t::skip_mask_t;

    const memory_desc_wrapper src_d(src_md());
    const memory_desc_wrapper dst_d(dst_md());

    conf_.src_data_type = src_md()->data_type;
    conf_.dst_data_type = dst_md()->data_type;

    fill_format_tag_info();
    conf_.isa = get_supported_isa(src_d.is_plain());

    const bool ok = is_fwd() && !has_zero_dim_memory()
            && conf_.src_tag != format_tag::undef
            && set_default_params(conf_.src_tag) == status::success
            && impl_supports_datatype(conf_.src_data_type)
            && impl_supports_datatype(conf_.dst_data_type)
            && IMPLICATION(conf_.src_data_type == f16, src_d.is_plain())
            && attr()->has_default_values(sm::post_ops, conf_.dst_data_type)
            && attr_.set_default_formats(dst_md(0)) == status::success;
    if (!ok) return status::unimplemented;

    if (!memory_desc_matches_tag(*dst_md(), conf_.src_tag))
        return status::unimplemented;

    conf_.alg = desc()->alg_kind;
    conf_.c = C();
    conf_.od = OD();
    conf_.oh = OH();
    conf_.ow = OW();
    conf_.id = ID();
    conf_.ih = IH();
    conf_.iw = IW();
    conf_.ndims = ndims();

    if (conf_.alg == alg_kind::resampling_linear)
        conf_.number_of_corners = pow(2, conf_.ndims - 2);

    conf_.src_dt_size = types::data_type_size(conf_.src_data_type);
    conf_.dst_dt_size = types::data_type_size(conf_.dst_data_type);

    conf_.is_saturation_needed
            = utils::one_of(conf_.dst_data_type, s32, s8, u8);

    const size_t L3_size = static_cast<size_t>(dnnl_get_current_num_threads())
            * platform::get_per_core_cache_size(3);
    const size_t input_data_size = src_d.nelems(true) * conf_.src_dt_size;
    const size_t output_data_size = dst_d.nelems(true) * conf_.dst_dt_size;
    const size_t whole_data_size = input_data_size + output_data_size;
    conf_.output_data_size = output_data_size;
    conf_.is_data_size_bigger_than_L3
            = L3_size > 0 ? whole_data_size > L3_size : false;

    conf_.el_size_of_indices = sizeof(unsigned);

    conf_.inner_stride = src_d.blocking_desc().strides[ndims() - 1];
    conf_.stride_d = IH() * IW() * conf_.inner_stride * conf_.src_dt_size;
    conf_.stride_h = IW() * conf_.inner_stride * conf_.src_dt_size;
    conf_.stride_w = conf_.inner_stride * conf_.src_dt_size;

    const std::vector<injector::post_op_type> accepted_post_ops
            = {injector::sum, injector::eltwise, injector::binary};
    static constexpr bool sum_at_0_pos_only = false;
    static constexpr bool sum_requires_scale_one = false;
    static constexpr bool sum_requires_zp_zero = true;
    const bcast_set_t accepted_broadcasts
            = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::per_oc_spatial};
    injector::post_ops_ok_args_t post_ops_args(conf_.isa, accepted_post_ops,
            attr()->post_ops_, &dst_d, sum_at_0_pos_only,
            sum_requires_scale_one, sum_requires_zp_zero, accepted_broadcasts);
    if (!post_ops_ok(post_ops_args)) return status::unimplemented;

    conf_.post_ops = attr()->post_ops_;

    static constexpr bool require_scale_one = false;
    conf_.with_eltwise = conf_.with_binary = conf_.with_sum = false;
    for (const auto &entry : conf_.post_ops.entry_) {
        if (entry.is_eltwise()) {
            conf_.with_eltwise = true;
        } else if (entry.is_binary()) {
            conf_.with_binary = true;
        } else if (entry.is_sum(require_scale_one) && entry.sum.scale != 0.f) {
            conf_.with_sum = true;
            conf_.sum_scales.push(entry.sum.scale);
        }
    }
    conf_.with_postops
            = conf_.with_eltwise || conf_.with_binary || conf_.with_sum;

    return status::success;
}

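// Classifies the source memory format into one of the kinds handled by the
// kernel (16-channel blocked, 8-channel blocked, nspc, ncsp) and stores the
// matching format tag in conf_.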
void jit_uni_resampling_fwd_t::pd_t::fill_format_tag_info() {
    using namespace format_tag;

    const format_tag_t blocked_16_format = memory_desc_matches_one_of_tag(
            *src_md(), nCw16c, nChw16c, nCdhw16c);
    const format_tag_t blocked_8_format
            = memory_desc_matches_one_of_tag(*src_md(), nCw8c, nChw8c, nCdhw8c);
    const format_tag_t nspc_format
            = memory_desc_matches_one_of_tag(*src_md(), nwc, nhwc, ndhwc);
    const format_tag_t ncsp_format
            = memory_desc_matches_one_of_tag(*src_md(), ncw, nchw, ncdhw);

    if (blocked_16_format != undef) {
        conf_.tag_kind = jit_memory_tag_kind_t::blocked;
        conf_.src_tag = blocked_16_format;
    } else if (blocked_8_format != undef) {
        conf_.is_blocked_8_format = true;
        conf_.tag_kind = jit_memory_tag_kind_t::blocked;
        conf_.src_tag = blocked_8_format;
    } else if (nspc_format != undef) {
        conf_.tag_kind = jit_memory_tag_kind_t::nspc;
        conf_.src_tag = nspc_format;
    } else if (ncsp_format != undef) {
        conf_.tag_kind = jit_memory_tag_kind_t::ncsp;
        conf_.src_tag = ncsp_format;
    } else {
        conf_.tag_kind = jit_memory_tag_kind_t::undef;
        conf_.src_tag = undef;
    }
}

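// Selects the kernel specialization for AVX-512 capable machines. ISAs that
// include avx512_core_fp16 use a Zmm-based kernel; otherwise 8-channel
// blocked layouts use a Ymm-based kernel and all remaining cases use a
// Zmm-based kernel.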
status_t jit_uni_resampling_fwd_t::get_proper_kernel_for_avx512(
        const memory_desc_t *dst_md, const jit_resampling_conf_t &conf) {
    const format_tag_t blocked_8_tag = utils::pick(conf.ndims - 3,
            format_tag::nCw8c, format_tag::nChw8c, format_tag::nCdhw8c);
    if (is_superset(conf.isa, avx512_core_fp16))
        return safe_ptr_assign(kernel_,
                new jit_uni_resampling_kernel_t<avx512_core_fp16, Xbyak::Zmm>(
                        conf, dst_md));

    if (memory_desc_matches_tag(*pd()->src_md(), blocked_8_tag)) {
        return safe_ptr_assign(kernel_,
                new jit_uni_resampling_kernel_t<avx512_core, Xbyak::Ymm>(
                        conf, dst_md));
    }

    return safe_ptr_assign(kernel_,
            new jit_uni_resampling_kernel_t<avx512_core, Xbyak::Zmm>(
                    conf, dst_md));
}

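// Selects the AVX kernel: int8 sources or destinations use an Xmm-based
// kernel, all other data types use a Ymm-based kernel.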
status_t jit_uni_resampling_fwd_t::get_proper_kernel_for_avx(
        const memory_desc_t *dst_md, const jit_resampling_conf_t &conf) {
    using namespace data_type;

    const bool is_src_i8 = utils::one_of(conf.src_data_type, s8, u8);
    const bool is_dst_i8 = utils::one_of(conf.dst_data_type, s8, u8);
    if (is_src_i8 || is_dst_i8)
        return safe_ptr_assign(kernel_,
                new jit_uni_resampling_kernel_t<avx, Xbyak::Xmm>(conf, dst_md));

    return safe_ptr_assign(kernel_,
            new jit_uni_resampling_kernel_t<avx, Xbyak::Ymm>(conf, dst_md));
}

status_t jit_uni_resampling_fwd_t::get_proper_kernel_for_sse(
        const memory_desc_t *dst_md, const jit_resampling_conf_t &conf) {
    return safe_ptr_assign(kernel_,
            new jit_uni_resampling_kernel_t<sse41, Xbyak::Xmm>(conf, dst_md));
}

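// Creates the JIT kernel that matches the ISA chosen at pd creation time and
// precomputes the interpolation indices (and, for the linear algorithm, the
// weights).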
status_t jit_uni_resampling_fwd_t::init(engine_t *engine) {
    using namespace format_tag;

    const memory_desc_t *dst_md = pd()->dst_md();
    const jit_resampling_conf_t &conf = pd()->get_conf();

    if (is_superset(conf.isa, avx512_core))
        CHECK(get_proper_kernel_for_avx512(dst_md, conf));
    else if (is_superset(conf.isa, avx))
        CHECK(get_proper_kernel_for_avx(dst_md, conf));
    else if (conf.isa == sse41) {
        CHECK(get_proper_kernel_for_sse(dst_md, conf));
    } else {
        assert(!"Unsupported isa.");
        return status::runtime_error;
    }

    CHECK(kernel_->create_kernel());

    return fill_data_for_interpolation();
}

status_t jit_uni_resampling_fwd_t::fill_data_for_interpolation() {
    switch (pd()->desc()->alg_kind) {
        case alg_kind::resampling_nearest: return fill_data_for_nearest();
        case alg_kind::resampling_linear: return fill_data_for_linear();
        default:
            assert(!"Invalid resampling algorithm.");
            return status::invalid_arguments;
    }
}

status_t jit_uni_resampling_fwd_t::fill_data_for_nearest() {
    // The kernel uses vmovdqu to load the indices. This instruction has no
    // tail-processing support on sse41 and avx. To avoid problems with that,
    // OW is rounded up to the simd width, because the ow indices are read
    // in the kernel.
    indices_.reserve(pd()->OD() + pd()->OH()
            + utils::rnd_up(pd()->OW(), kernel_->get_simd_w()));

    for (dim_t od = 0; od < pd()->OD(); od++) {
        const int offset_id = nearest_idx(od, pd()->OD(), pd()->ID())
                * pd()->get_conf().stride_d;
        indices_.emplace_back(offset_id);
    }
    for (dim_t oh = 0; oh < pd()->OH(); oh++) {
        const int offset_ih = nearest_idx(oh, pd()->OH(), pd()->IH())
                * pd()->get_conf().stride_h;
        indices_.emplace_back(offset_ih);
    }
    for (dim_t ow = 0; ow < pd()->OW(); ow++) {
        const int offset_iw = nearest_idx(ow, pd()->OW(), pd()->IW())
                * pd()->get_conf().stride_w;
        indices_.emplace_back(offset_iw);
    }

    return status::success;
}

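// Precomputes the source offsets and interpolation weights for the linear
// algorithm. For the ncsp layout a full index/weight table is built per
// interpolation corner; for the nspc and blocked layouts separate
// per-dimension tables are built and combined inside the kernel.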
status_t jit_uni_resampling_fwd_t::fill_data_for_linear() {
    using namespace resampling_utils;

    const unsigned number_of_corners = pd()->get_conf().number_of_corners;
    const unsigned stride_w = pd()->get_conf().stride_w;
    const unsigned stride_h = pd()->get_conf().stride_h;
    const unsigned stride_d = pd()->get_conf().stride_d;

    unsigned num_of_elements = 0;
    if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::ncsp) {
        // The kernel uses vmovdqu to load the indices. This instruction has
        // no tail-processing support on sse41 and avx. To avoid problems with
        // that, the number of spatial points is rounded up to the simd width,
        // because all of them are read in the kernel.
        num_of_elements = number_of_corners
                * utils::rnd_up(pd()->OD() * pd()->OH() * pd()->OW(),
                        kernel_->get_simd_w());

        indices_.resize(num_of_elements);
        weights_.resize(num_of_elements);

        const size_t indices_stride = pd()->OW() * pd()->OH() * pd()->OD();
        const size_t weights_stride = pd()->OW() * pd()->OH() * pd()->OD();

        parallel_nd(pd()->OD(), pd()->OH(), [&](dim_t od, dim_t oh) {
            const linear_coeffs_t coeffs_id(od, pd()->OD(), pd()->ID());
            const linear_coeffs_t coeffs_ih(oh, pd()->OH(), pd()->IH());

            for (dim_t ow = 0; ow < pd()->OW(); ow++) {
                const size_t offset
                        = od * pd()->OH() * pd()->OW() + oh * pd()->OW() + ow;

                const linear_coeffs_t coeffs_iw(ow, pd()->OW(), pd()->IW());

                for (unsigned i = 0; i < number_of_corners; i++) {
                    std::bitset<3> corners(i);
                    indices_[i * indices_stride + offset]
                            = coeffs_id.idx[corners.test(2)] * stride_d
                            + coeffs_ih.idx[corners.test(1)] * stride_h
                            + coeffs_iw.idx[corners.test(0)] * stride_w;
                    weights_[i * weights_stride + offset]
                            = coeffs_id.wei[corners.test(2)]
                            * coeffs_ih.wei[corners.test(1)]
                            * coeffs_iw.wei[corners.test(0)];
                }
            }
        });
    } else if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::nspc
            || pd()->get_conf().tag_kind == jit_memory_tag_kind_t::blocked) {
        num_of_elements = 2 * (pd()->OD() + pd()->OH() + pd()->OW());

        indices_.resize(num_of_elements);
        weights_.resize(num_of_elements);

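        // Layout of the per-dimension tables: the first 2 * OW entries hold
        // the left/right W corners, the next 2 * OH entries hold the
        // top/bottom H corners, and the last 2 * OD entries hold the
        // front/back D corners.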
        unsigned *indices_w = &indices_[0];
        unsigned *indices_h = &indices_[2 * pd()->OW()];
        unsigned *indices_d = &indices_[2 * (pd()->OW() + pd()->OH())];
        float *weights_w = &weights_[0];
        float *weights_h = &weights_[2 * pd()->OW()];
        float *weights_d = &weights_[2 * (pd()->OW() + pd()->OH())];

        for (dim_t ow = 0; ow < pd()->OW(); ow++) {
            const linear_coeffs_t coeffs_iw(ow, pd()->OW(), pd()->IW());

            // The left and right corners are stored one after the other
            // because the kernel reads these values in pairs, which keeps
            // the loads simple and makes the operation faster.
            weights_w[2 * ow] = coeffs_iw.wei[0];
            weights_w[2 * ow + 1] = coeffs_iw.wei[1];
            indices_w[2 * ow] = coeffs_iw.idx[0] * stride_w;
            indices_w[2 * ow + 1] = coeffs_iw.idx[1] * stride_w;
        }

        for (dim_t oh = 0; oh < pd()->OH(); oh++) {
            const linear_coeffs_t coeffs_ih(oh, pd()->OH(), pd()->IH());

            weights_h[oh] = coeffs_ih.wei[0];
            weights_h[pd()->OH() + oh] = coeffs_ih.wei[1];
            indices_h[oh] = coeffs_ih.idx[0] * stride_h;
            indices_h[pd()->OH() + oh] = coeffs_ih.idx[1] * stride_h;
        }

        for (dim_t od = 0; od < pd()->OD(); od++) {
            const linear_coeffs_t coeffs_id(od, pd()->OD(), pd()->ID());

            weights_d[od] = coeffs_id.wei[0];
            weights_d[pd()->OD() + od] = coeffs_id.wei[1];
            indices_d[od] = coeffs_id.idx[0] * stride_d;
            indices_d[pd()->OD() + od] = coeffs_id.idx[1] * stride_d;
        }
    } else {
        assert(!"Invalid memory format kind.");
        return status::invalid_arguments;
    }

    return status::success;
}

status_t jit_uni_resampling_fwd_t::execute(const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const uint8_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(uint8_t *, DNNL_ARG_DST);

    const std::vector<const void *> post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(
                    pd()->get_conf().post_ops, ctx);

    switch (pd()->desc()->alg_kind) {
        case alg_kind::resampling_nearest:
            return interpolate_nearest(src, dst, post_ops_binary_rhs_arg_vec);
        case alg_kind::resampling_linear:
            return interpolate_linear(src, dst, post_ops_binary_rhs_arg_vec);
        default:
            assert(!"Invalid resampling algorithm.");
            return status::invalid_arguments;
    }
}

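// Nearest-neighbor interpolation driver. The indices_ table holds the
// precomputed source offsets: the first OD entries are depth offsets, the
// next OH entries are height offsets, and the remaining entries are width
// offsets.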
status_t jit_uni_resampling_fwd_t::interpolate_nearest(const uint8_t *src,
        uint8_t *dst, const std::vector<const void *> &post_ops_args) const {
    const size_t src_dt_size = pd()->get_conf().src_dt_size;
    const size_t dst_dt_size = pd()->get_conf().dst_dt_size;
    const size_t inner_stride = pd()->get_conf().inner_stride;

    const dim_t MB = pd()->MB();
    const dim_t C = pd()->C();
    const dim_t CB = utils::div_up(C, inner_stride);
    const dim_t nsp_outer = MB * CB;
    const dim_t OD = pd()->OD();
    const dim_t OH = pd()->OH();
    const dim_t OW = pd()->OW();
    const dim_t ID = pd()->ID();
    const dim_t IH = pd()->IH();
    const dim_t IW = pd()->IW();

    const unsigned *indices_d = &indices_[0];
    const unsigned *indices_h = &indices_[OD];
    const unsigned *indices_w = &indices_[OD + OH];

    if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::ncsp) {
        parallel_nd(MB, C, OD, [&](dim_t mb, dim_t c, dim_t od) {
            const dim_t src_off
                    = (mb * C + c) * ID * IH * IW * src_dt_size + indices_d[od];
            const dim_t dst_off = ((mb * C + c) * OD * OH * OW + od * OH * OW)
                    * dst_dt_size;

            jit_resampling_call_s args = jit_resampling_call_s();
            args.src = src + src_off;
            args.dst = dst + dst_off;
            args.dst_orig = dst;
            args.indices = &indices_h[0];
            args.post_ops_binary_rhs_arg_vec = post_ops_args.data();
            args.c_offset = static_cast<size_t>(c);

            (*kernel_)(&args);
        });
    } else if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::nspc
            || pd()->get_conf().tag_kind == jit_memory_tag_kind_t::blocked) {
        parallel_nd(nsp_outer, OD, OH, [&](dim_t nsp, dim_t od, dim_t oh) {
            const dim_t src_off
                    = nsp * ID * IH * IW * inner_stride * src_dt_size
                    + indices_d[od] + indices_h[oh];
            const dim_t dst_off = ((nsp * OD + od) * OH + oh) * OW
                    * inner_stride * dst_dt_size;

            const size_t cb = std::div(nsp, CB).rem;

            jit_resampling_call_s args = jit_resampling_call_s();
            args.batch_of_sp_points_to_process = OW;
            args.src = src + src_off;
            args.dst = dst + dst_off;
            args.dst_orig = dst;
            args.indices = &indices_w[0];
            args.post_ops_binary_rhs_arg_vec = post_ops_args.data();
            args.c_offset = static_cast<size_t>(cb * inner_stride);

            (*kernel_)(&args);
        });
    } else {
        assert(!"Invalid memory format kind.");
        return status::invalid_arguments;
    }

    return status::success;
}

status_t jit_uni_resampling_fwd_t::interpolate_linear(const uint8_t *src,
        uint8_t *dst, const std::vector<const void *> &post_ops_args) const {
    const size_t src_dt_size = pd()->get_conf().src_dt_size;
    const size_t dst_dt_size = pd()->get_conf().dst_dt_size;
    const size_t inner_stride = pd()->get_conf().inner_stride;

    const dim_t MB = pd()->MB();
    const dim_t C = pd()->C();
    const dim_t CB = utils::div_up(C, inner_stride);
    const dim_t nsp_outer = MB * CB;
    const dim_t OD = pd()->OD();
    const dim_t OH = pd()->OH();
    const dim_t OW = pd()->OW();
    const dim_t ID = pd()->ID();
    const dim_t IH = pd()->IH();
    const dim_t IW = pd()->IW();

    if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::ncsp) {
        parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
            const dim_t src_off = (mb * C + c) * ID * IH * IW * src_dt_size;
            const dim_t dst_off = (mb * C + c) * OD * OH * OW * dst_dt_size;

            jit_resampling_call_s args = jit_resampling_call_s();
            args.batch_of_sp_points_to_process = OW * OH * OD;
            args.src = src + src_off;
            args.dst = dst + dst_off;
            args.dst_orig = dst;
            args.indices = &indices_[0];
            args.weights = &weights_[0];
            args.post_ops_binary_rhs_arg_vec = post_ops_args.data();
            args.c_offset = static_cast<size_t>(c);

            (*kernel_)(&args);
        });
    } else if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::nspc
            || pd()->get_conf().tag_kind == jit_memory_tag_kind_t::blocked) {
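        // Offsets into the tables built in fill_data_for_linear(): the first
        // 2 * OW entries (the interleaved W corners) are read by the kernel
        // through args.indices/args.weights, while the H and D corner offsets
        // and weights are read here and passed to the kernel as scalars.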
        const unsigned *indices_top = &indices_[2 * OW];
        const unsigned *indices_bottom = &indices_[2 * OW + OH];
        const unsigned *indices_front = &indices_[2 * (OW + OH)];
        const unsigned *indices_back = &indices_[2 * (OW + OH) + OD];
        const float *weights_top = &weights_[2 * OW];
        const float *weights_bottom = &weights_[2 * OW + OH];
        const float *weights_front = &weights_[2 * (OW + OH)];
        const float *weights_back = &weights_[2 * (OW + OH) + OD];

        parallel_nd(nsp_outer, OD, OH, [&](dim_t nsp, dim_t od, dim_t oh) {
            const dim_t src_off
                    = nsp * ID * IH * IW * inner_stride * src_dt_size;
            const dim_t dst_off = (((nsp * OD + od) * OH + oh) * OW)
                    * inner_stride * dst_dt_size;

            const size_t cb = std::div(nsp, CB).rem;

            jit_resampling_call_s args = jit_resampling_call_s();
            args.batch_of_sp_points_to_process = OW;
            args.src = src + src_off;
            args.dst = dst + dst_off;
            args.dst_orig = dst;
            args.indices = &indices_[0];
            args.weights = &weights_[0];
            args.post_ops_binary_rhs_arg_vec = post_ops_args.data();
            args.c_offset = static_cast<size_t>(cb * inner_stride);

            args.src_offset_front = indices_front[od];
            args.src_offset_back = indices_back[od];
            args.src_offset_top = indices_top[oh];
            args.src_offset_bottom = indices_bottom[oh];
            args.weight_front = weights_front[od];
            args.weight_back = weights_back[od];
            args.weight_top = weights_top[oh];
            args.weight_bottom = weights_bottom[oh];

            (*kernel_)(&args);
        });
    } else {
        assert(!"Invalid memory format kind.");
        return status::invalid_arguments;
    }

    return status::success;
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl