1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_REORDER_SIMPLE_REORDER_HPP
18#define CPU_REORDER_SIMPLE_REORDER_HPP
19
20#include <algorithm>
21#include <assert.h>
22
23#include "common/bfloat16.hpp"
24#include "common/c_types_map.hpp"
25#include "common/dnnl_thread.hpp"
26#include "common/math_utils.hpp"
27#include "common/primitive.hpp"
28#include "common/primitive_attr.hpp"
29#include "common/tag_traits.hpp"
30#include "common/type_helpers.hpp"
31#include "common/utils.hpp"
32
33#include "cpu/cpu_primitive.hpp"
34#include "cpu/reorder/cpu_reorder_pd.hpp"
35
36#include "cpu/simple_q10n.hpp"
37
38namespace dnnl {
39namespace impl {
40namespace cpu {
41
42using bd = block_dim_t;
43using ib = inner_blk_t;
44
45template <impl::data_type_t type>
46using data_t = typename prec_traits<type>::type;
47
48template <impl::data_type_t type_i, impl::data_type_t type_o>
49using _qz_a1b0 = qz_a1b0<data_t<type_i>, data_t<type_o>>;
50
51template <impl::data_type_t type_i, impl::data_type_t type_o>
52using _qz = qz<data_t<type_i>, data_t<type_o>>;
53
54namespace fmt_order {
55const bool keep = true;
56const bool reverse = false;
57const bool any = keep;
58} // namespace fmt_order
59
60namespace spec {
61struct direct_copy {};
62struct direct_copy_except_dim_0 {};
63struct reference {};
64struct conv_req_comp {}; // {s8, u8: asymmetric quantization}
65} // namespace spec
66
67#define SIMPLE_REORDER_TEMPL_DECL \
68 impl::data_type_t type_i, impl::format_tag_t tag_i, \
69 impl::data_type_t type_o, impl::format_tag_t tag_o, \
70 bool order_keep
71#define SIMPLE_REORDER_TEMPL_CALL type_i, tag_i, type_o, tag_o, order_keep
72
73#define DECLARE_COMMON_PARAMS() \
74 auto input = CTX_IN_MEM(const data_t<type_i> *, DNNL_ARG_FROM); \
75 auto output = CTX_OUT_MEM(data_t<type_o> *, DNNL_ARG_TO); \
76 const auto &scratchpad = ctx.get_scratchpad_grantor(); \
77 MAYBE_UNUSED(scratchpad); \
78 const auto input_d = ctx.memory_mdw(DNNL_ARG_FROM, pd->src_md()); \
79 const auto output_d = ctx.memory_mdw(DNNL_ARG_TO, pd->dst_md()); \
80 DEFINE_ARG_SCALES_BUFFER_ATTR(pd->attr(), src_scales, DNNL_ARG_FROM); \
81 DEFINE_ARG_SCALES_BUFFER_ATTR(pd->attr(), dst_scales_, DNNL_ARG_TO); \
82 int src_scales_mask, dst_scales_mask; \
83 CHECK(get_scales_mask(pd->attr(), &src_scales_mask, &dst_scales_mask)); \
84 int scales_mask = std::max(src_scales_mask, dst_scales_mask); \
85 MAYBE_UNUSED(scales_mask); \
86 dim_t D_start, D_mask, D_rest; \
87 pd->get_D_values(input_d, scales_mask, &D_start, &D_mask, &D_rest); \
88 const float *dst_scales = pd->precompute_scales( \
89 scratchpad, pd->attr(), D_mask, dst_scales_); \
90 MAYBE_UNUSED(dst_scales); \
91 DEFINE_ZERO_POINT_VALUE_ATTR(pd->attr(), src_zp, DNNL_ARG_FROM); \
92 DEFINE_ZERO_POINT_VALUE_ATTR(pd->attr(), dst_zp, DNNL_ARG_TO); \
93 const float alpha = src_scales[0] * dst_scales[0]; \
94 MAYBE_UNUSED(alpha); \
95 const float beta = pd->beta(); \
96 MAYBE_UNUSED(beta);
97
98#define GET_SCRATCHPAD_SIZE_ZERO() \
99 static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, \
100 const memory_desc_wrapper &output_d) { \
101 return 0; \
102 }
103
104/* specific reorders: common template */
105template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
106struct simple_reorder_impl {};
107
108namespace {
109inline bool simple_fmt_check(bool order_keep, impl::format_tag_t tag_i,
110 impl::format_tag_t tag_o, const memory_desc_wrapper &input_d,
111 const memory_desc_wrapper &output_d) {
112 if (input_d.has_runtime_dims_or_strides()) return false;
113 return input_d.matches_tag(order_keep ? tag_i : tag_o)
114 && output_d.matches_tag(order_keep ? tag_o : tag_i);
115}
116inline bool simple_po_check(const primitive_attr_t *attr) {
117 const auto &po = attr->post_ops_;
118 return po.len() == 0 || (po.len() == 1 && po.entry_[0].is_sum(false));
119}
120inline status_t get_scales_mask(
121 const primitive_attr_t *attr, int *src_mask, int *dst_mask) {
122 const auto &s = attr->scales_;
123 if (src_mask) {
124 *src_mask = 0;
125 if (!s.get(DNNL_ARG_SRC).has_default_values())
126 *src_mask = s.get(DNNL_ARG_SRC).mask_;
127 }
128 if (dst_mask) {
129 *dst_mask = 0;
130 if (!s.get(DNNL_ARG_DST).has_default_values())
131 *dst_mask = s.get(DNNL_ARG_DST).mask_;
132 }
133
134 // This is used in a check function.
135 if (*src_mask > 0 && *dst_mask > 0 && *dst_mask != *src_mask)
136 return status::invalid_arguments;
137 return status::success;
138}
139inline bool simple_attr_check(const primitive_attr_t *attr,
140 bool many_scales_support, bool sum_support) {
141 using smask_t = primitive_attr_t::skip_mask_t;
142 smask_t skip_mask = smask_t::scales_runtime;
143 if (sum_support) skip_mask = skip_mask | smask_t::post_ops;
144 if (!attr->has_default_values(skip_mask)) return false;
145 if (sum_support) simple_po_check(attr);
146 if (many_scales_support) return true;
147 int src_mask, dst_mask;
148 if (get_scales_mask(attr, &src_mask, &dst_mask) != status::success)
149 return false;
150 return src_mask == 0 && dst_mask == 0;
151}
152} // namespace
153
154/* specific reorders: implementation */
155template <SIMPLE_REORDER_TEMPL_DECL>
156struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
157 typename utils::enable_if<tag_i == format_tag::any
158 && utils::one_of(tag_o, format_tag::wio,
159 format_tag::wigo, format_tag::hwio,
160 format_tag::hwigo, format_tag::dhwio,
161 format_tag::dhwigo),
162 spec::conv_req_comp>::type> {
163 static bool is_applicable(const memory_desc_wrapper &input_d,
164 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
165 using namespace data_type;
166 using namespace utils;
167
168 if (input_d.has_runtime_dims_or_strides()) return false;
169
170 int src_scales_mask, dst_scales_mask;
171 auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask);
172 if (status != status::success) return false;
173 int scales_mask = std::max(src_scales_mask, dst_scales_mask);
174
175 static constexpr bool w_groups = one_of(
176 tag_o, format_tag::wigo, format_tag::hwigo, format_tag::dhwigo);
177
178 const bool req_comp = output_d.extra().flags
179 & memory_extra_flags::compensation_conv_s8s8;
180 const bool req_asymmetric_comp = output_d.extra().flags
181 & memory_extra_flags::compensation_conv_asymmetric_src;
182
183 auto mask_ok = [&](bool check, int mask) {
184 return IMPLICATION(check, mask == (w_groups ? 0x3 : 0x1));
185 };
186
187 return simple_attr_check(attr, true, false)
188 && output_d.matches_tag(tag_o) && input_d.is_plain()
189 && (req_comp || req_asymmetric_comp)
190 && mask_ok(req_comp, output_d.extra().compensation_mask)
191 && mask_ok(req_asymmetric_comp,
192 output_d.extra().asymm_compensation_mask)
193 && IMPLICATION(!w_groups, one_of(scales_mask, 0, 0x1))
194 && IMPLICATION(w_groups, one_of(scales_mask, 0, 0x3))
195 && one_of(input_d.data_type(), f32, s8, bf16)
196 && output_d.data_type() == s8;
197 }
198
199 GET_SCRATCHPAD_SIZE_ZERO();
200
201 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
202 DECLARE_COMMON_PARAMS();
203
204 static constexpr bool w_groups = utils::one_of(
205 tag_o, format_tag::wigo, format_tag::hwigo, format_tag::dhwigo);
206 static constexpr bool w_height
207 = !utils::one_of(tag_o, format_tag::wio, format_tag::wigo);
208 static constexpr bool w_depth
209 = utils::one_of(tag_o, format_tag::dhwio, format_tag::dhwigo);
210
211 const auto &dims = input_d.dims();
212
213 const dim_t G = w_groups ? dims[0] : 1;
214 const dim_t OC = dims[w_groups + 0];
215 const dim_t IC = dims[w_groups + 1];
216 const dim_t D = w_depth ? dims[w_groups + 2] : 1;
217 const dim_t H = w_height ? dims[w_groups + w_depth + 2] : 1;
218 const dim_t W = dims[w_groups + w_depth + w_height + 2];
219
220 const bool req_comp = output_d.extra().flags
221 & memory_extra_flags::compensation_conv_s8s8;
222 const bool has_asymmetric_comp = output_d.extra().flags
223 & memory_extra_flags::compensation_conv_asymmetric_src;
224
225 assert(req_comp || has_asymmetric_comp);
226
227 float adj_scale
228 = (output_d.extra().flags & memory_extra_flags::scale_adjust)
229 ? output_d.extra().scale_adjust
230 : 1.f;
231
232 size_t offset = output_d.size() - output_d.additional_buffer_size();
233 size_t comp_size = output_d.additional_buffer_size(
234 memory_extra_flags::compensation_conv_s8s8);
235 size_t zp_offset = offset + (req_comp ? comp_size : 0);
236 int32_t *cp = req_comp ? reinterpret_cast<int32_t *>(output + offset)
237 : nullptr;
238 int32_t *zp = has_asymmetric_comp
239 ? reinterpret_cast<int32_t *>(output + zp_offset)
240 : nullptr;
241
242 const bool per_oc = scales_mask & (1 << (w_groups + 0));
243 const bool per_ic = scales_mask & (1 << (w_groups + 1));
244 const size_t ic_stride = per_ic ? 1 : 0;
245 const size_t oc_stride = per_oc ? per_ic ? IC : 1 : 0;
246
247 parallel_nd(G, OC, [&](dim_t g, dim_t oc) {
248 if (req_comp) cp[g * OC + oc] = 0;
249 if (has_asymmetric_comp) zp[g * OC + oc] = 0;
250 for_(dim_t ic = 0; ic < IC; ic++)
251 for_(dim_t d = 0; d < D; d++)
252 for_(dim_t h = 0; h < H; h++)
253 for (dim_t w = 0; w < W; w++) {
254 auto i = w_depth
255 ? input[input_d.blk_off<!w_groups>(g, oc, ic, d, h, w)]
256 : w_height ? input[input_d.blk_off<!w_groups>(
257 g, oc, ic, h, w)]
258 : input[input_d.blk_off<!w_groups>(
259 g, oc, ic, w)];
260 auto &o = w_depth
261 ? output[output_d.blk_off<!w_groups>(
262 g, oc, ic, d, h, w)]
263 : w_height ? output[output_d.blk_off<!w_groups>(
264 g, oc, ic, h, w)]
265 : output[output_d.blk_off<!w_groups>(
266 g, oc, ic, w)];
267 const size_t os_off
268 = (g * OC + oc) * oc_stride + ic * ic_stride;
269 const float s = src_scales[src_scales_mask == 0 ? 0 : os_off];
270 const float d = dst_scales[dst_scales_mask == 0 ? 0 : os_off];
271
272 o = qz_b0<data_t<type_i>, data_t<type_o>>()(
273 i, s * adj_scale * d);
274 if (req_comp) cp[g * OC + oc] -= (int32_t)o;
275 if (has_asymmetric_comp) zp[g * OC + oc] -= (int32_t)o;
276 }
277 if (req_comp) cp[g * OC + oc] *= 128;
278 });
279 return status::success;
280 }
281};
282
283template <SIMPLE_REORDER_TEMPL_DECL>
284struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
285 typename utils::enable_if<
286 (utils::one_of(tag_i, format_tag::iwo, format_tag::oiw,
287 format_tag::wio)
288 && utils::one_of(tag_o, format_tag::OIw4i16o4i,
289 format_tag::OIw4i32o4i, format_tag::OIw4i64o4i,
290 format_tag::OIw2i8o4i, format_tag::OIw4o4i))
291 || (utils::one_of(tag_i, format_tag::oi, format_tag::io)
292 && utils::one_of(tag_o, format_tag::OI4i16o4i,
293 format_tag::OI4i32o4i,
294 format_tag::OI4i64o4i))
295 || (utils::one_of(
296 tag_i, format_tag::goiw, format_tag::wigo)
297 && utils::one_of(tag_o, format_tag::gOIw4i16o4i,
298 format_tag::gOIw2i8o4i,
299 format_tag::gOIw4o4i))
300 || (utils::one_of(tag_i, format_tag::ihwo,
301 format_tag::hwio, format_tag::oihw)
302 && utils::one_of(tag_o, format_tag::OIhw4i16o4i,
303 format_tag::OIhw4i32o4i,
304 format_tag::OIhw4i64o4i,
305 format_tag::OIhw2i8o4i,
306 format_tag::OIhw4o4i))
307 || (utils::one_of(tag_i, format_tag::idhwo,
308 format_tag::dhwio, format_tag::oidhw)
309 && utils::one_of(tag_o,
310 format_tag::OIdhw4i16o4i,
311 format_tag::OIdhw4i32o4i,
312 format_tag::OIdhw4i64o4i,
313 format_tag::OIdhw2i8o4i,
314 format_tag::OIdhw4o4i))
315 || (utils::one_of(
316 tag_i, format_tag::goihw, format_tag::hwigo)
317 && utils::one_of(tag_o, format_tag::gOIhw4o4i,
318 format_tag::gOIhw2i8o4i,
319 format_tag::gOIhw4i16o4i))
320 || (utils::one_of(tag_i, format_tag::goidhw)
321 && (utils::one_of(tag_o,
322 format_tag::gOIdhw4i16o4i,
323 format_tag::gOIdhw2i8o4i,
324 format_tag::gOIdhw4o4i))),
325 spec::conv_req_comp>::type> {
326 static bool is_applicable(const memory_desc_wrapper &input_d,
327 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
328 using namespace format_tag;
329 using namespace data_type;
330 using namespace utils;
331
332 if (input_d.has_runtime_dims_or_strides()) return false;
333
334 int src_scales_mask, dst_scales_mask;
335 auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask);
336 if (status != status::success) return false;
337 int scales_mask = std::max(src_scales_mask, dst_scales_mask);
338
339 const bool w_groups = !one_of(tag_o, OIw4i16o4i, OIw2i8o4i, OIw4o4i,
340 OIhw4i16o4i, OIhw2i8o4i, OIhw4o4i, OIdhw4i16o4i, OIdhw2i8o4i,
341 OIdhw4o4i, OI4i16o4i, OI4i32o4i, OI4i64o4i, OIw4i32o4i,
342 OIw4i64o4i, OIhw4i32o4i, OIhw4i64o4i, OIdhw4i32o4i,
343 OIdhw4i64o4i);
344
345 const bool req_comp = output_d.extra().flags
346 & memory_extra_flags::compensation_conv_s8s8;
347 const bool req_asymmetric_comp = output_d.extra().flags
348 & memory_extra_flags::compensation_conv_asymmetric_src;
349
350 auto mask_ok = [&](bool check, int mask) {
351 return IMPLICATION(check, mask == (w_groups ? 0x3 : 0x1));
352 };
353
354 return simple_attr_check(attr, true, false)
355 && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
356 && (req_comp || req_asymmetric_comp)
357 && mask_ok(req_comp, output_d.extra().compensation_mask)
358 && mask_ok(req_asymmetric_comp,
359 output_d.extra().asymm_compensation_mask)
360 && IMPLICATION(!w_groups, one_of(scales_mask, 0, 0x1))
361 && IMPLICATION(w_groups, one_of(scales_mask, 0, 0x3))
362 && one_of(input_d.data_type(), f32, s8, bf16)
363 && output_d.data_type() == s8;
364 }
365
366 GET_SCRATCHPAD_SIZE_ZERO();
367
368 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
369 DECLARE_COMMON_PARAMS();
370 using namespace format_tag;
371
372 static constexpr bool w_groups = !utils::one_of(tag_o, OIw4o4i,
373 OIw4i16o4i, OIhw4i16o4i, OIdhw4i16o4i, OIhw4o4i, OIw2i8o4i,
374 OIhw2i8o4i, OIdhw2i8o4i, OIdhw4o4i, OI4i16o4i, OI4i32o4i,
375 OI4i64o4i, OIw4i32o4i, OIw4i64o4i, OIhw4i32o4i, OIhw4i64o4i,
376 OIdhw4i32o4i, OIdhw4i64o4i);
377
378 constexpr int is_0d
379 = utils::one_of(tag_o, OI4i16o4i, OI4i32o4i, OI4i64o4i);
380 constexpr int is_1d
381 = utils::one_of(tag_o, gOIw4i16o4i, OIw4i16o4i, gOIw2i8o4i,
382 OIw2i8o4i, gOIw4o4i, OIw4o4i, OIw4i32o4i, OIw4i64o4i);
383 constexpr int is_3d = utils::one_of(tag_o, gOIdhw4i16o4i, OIdhw4i16o4i,
384 gOIdhw2i8o4i, OIdhw2i8o4i, gOIdhw4o4i, OIdhw4o4i, OIdhw4i32o4i,
385 OIdhw4i64o4i);
386 constexpr dim_t icblksize = utils::one_of(tag_traits<tag_o>::inner_blks,
387 ib::_4a4b, ib::_4b4c)
388 ? 4
389 : utils::one_of(tag_traits<tag_o>::inner_blks, ib::_2c8b4c,
390 ib::_2b8a4b)
391 ? 8
392 : 16;
393 constexpr dim_t ocblksize
394 = tag_traits<tag_o>::inner_blks == ib::_4b32a4b
395 ? 32
396 : tag_traits<tag_o>::inner_blks == ib::_4b64a4b ? 64
397 : icblksize;
398
399 const auto &plain_d = order_keep ? input_d : output_d;
400 const auto &dims = input_d.dims();
401 const int ndims = input_d.ndims();
402 const auto &pdims
403 = order_keep ? output_d.padded_dims() : input_d.padded_dims();
404
405 const dim_t G = w_groups ? dims[0] : 1;
406 const dim_t OC = dims[w_groups + 0];
407 const dim_t PADDED_OC = pdims[w_groups + 0];
408 const dim_t NB_OC = pdims[w_groups + 0] / ocblksize;
409 const dim_t IC = dims[w_groups + 1];
410 const dim_t NB_IC = pdims[w_groups + 1] / icblksize;
411 const dim_t D = is_3d ? dims[2 + w_groups] : 1;
412 const dim_t H = is_1d || is_0d ? 1 : dims[2 + w_groups + is_3d];
413 const dim_t W = is_0d ? 1 : dims[w_groups + is_3d + 3 - is_1d];
414
415 // XXX: Currently user can pass a mask that has non-zero values in
416 // dimensions that do not exist in a md. Since attributes are created
417 // separately mask can't be validated.
418 // This line truncates a given mask in range [0, 1 << ndims - 1]
419 // TODO: Such masks can be either prohibited at pd creation step at
420 // API level or checked by each implementation that relies on it.
421 scales_mask &= (1 << ndims) - 1;
422
423 const bool req_comp = output_d.extra().flags
424 & memory_extra_flags::compensation_conv_s8s8;
425 const bool has_asymmetric_comp = output_d.extra().flags
426 & memory_extra_flags::compensation_conv_asymmetric_src;
427
428 assert(req_comp || has_asymmetric_comp);
429
430 float adj_scale
431 = (output_d.extra().flags & memory_extra_flags::scale_adjust)
432 ? output_d.extra().scale_adjust
433 : 1.f;
434
435 const bool per_oc = scales_mask & (1 << (w_groups + 0));
436 const bool per_ic = scales_mask & (1 << (w_groups + 1));
437 const size_t ic_stride = per_ic ? 1 : 0;
438 const size_t oc_stride = per_oc ? per_ic ? IC : 1 : 0;
439 const size_t nb_ic_stride = (per_ic ? 1 : 0) * icblksize;
440 const size_t nb_oc_stride = (per_oc ? per_ic ? IC : 1 : 0) * ocblksize;
441
442 // This kernel is used primarily for tensors with multiple inner
443 // blocks for which generic zero padding must be used.
444 // TODO: apply zero padding inside parallel_nd()
445 ctx.zero_pad_output(DNNL_ARG_TO);
446
447 auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
448 int32_t *c, int32_t *zp, const float *s,
449 const float *d, const dim_t oc_block,
450 const dim_t ic_block) {
451#define index AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>
452 for_(dim_t ic = 0; ic < ic_block; ++ic)
453 for (dim_t oc = 0; oc < oc_block; ++oc) {
454 const auto plain_off
455 = oc * plain_d.blocking_desc().strides[w_groups + 0]
456 + ic * plain_d.blocking_desc().strides[w_groups + 1];
457 const size_t os_off = oc * oc_stride + ic * ic_stride;
458 const float src_scale = s[src_scales_mask == 0 ? 0 : os_off];
459 const float dst_scale = d[dst_scales_mask == 0 ? 0 : os_off];
460 out[index(oc, ic)] = qz_b0<data_t<type_i>, data_t<type_o>>()(
461 inp[plain_off], src_scale * adj_scale * dst_scale);
462 if (req_comp) c[oc] -= (128 * (int32_t)(out[index(oc, ic)]));
463 if (has_asymmetric_comp)
464 zp[oc] -= (int32_t)(out[index(oc, ic)]);
465 }
466#undef index
467 };
468
469 constexpr dim_t i_mult_ic = icblksize;
470 constexpr dim_t i_mult_oc = ocblksize;
471 constexpr dim_t o_mult = 1;
472
473 size_t offset = output_d.size() - output_d.additional_buffer_size();
474 size_t comp_size = output_d.additional_buffer_size(
475 memory_extra_flags::compensation_conv_s8s8);
476 size_t zp_offset = offset + (req_comp ? comp_size : 0);
477 int32_t *cp = req_comp ? reinterpret_cast<int32_t *>(output + offset)
478 : nullptr;
479 int32_t *zp = has_asymmetric_comp
480 ? reinterpret_cast<int32_t *>(output + zp_offset)
481 : nullptr;
482
483 parallel_nd(G * PADDED_OC, [&](dim_t i) {
484 if (req_comp) cp[i] = 0;
485 if (has_asymmetric_comp) zp[i] = 0;
486 });
487
488#define wei_blk_off(md, g, o, i, d, h, w) \
489 (is_0d ? (md).blk_off<!w_groups>(g, o, i) \
490 : is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
491 : is_3d ? (md).blk_off<!w_groups>(g, o, i, d, h, w) \
492 : (md).blk_off<!w_groups>(g, o, i, h, w))
493 parallel_nd(G, NB_OC, [&](dim_t g, dim_t O) {
494 for_(dim_t I = 0; I < NB_IC; I++)
495 for_(dim_t d = 0; d < D; d++)
496 for_(dim_t h = 0; h < H; h++)
497 for (dim_t w = 0; w < W; w++) {
498 auto i = &input[wei_blk_off(
499 input_d, g, i_mult_oc * O, i_mult_ic * I, d, h, w)];
500 auto o = &output[wei_blk_off(
501 output_d, g, o_mult * O, o_mult * I, d, h, w)];
502 const dim_t oc_block = nstl::min(ocblksize, OC - O * ocblksize);
503 const dim_t ic_block = nstl::min(icblksize, IC - I * icblksize);
504 dim_t _offset = (g * NB_OC + O) * ocblksize;
505 dim_t os_nb_off
506 = (g * NB_OC + O) * nb_oc_stride + I * nb_ic_stride;
507 const float *src_scales_ptr
508 = &src_scales[src_scales_mask == 0 ? 0 : os_nb_off];
509 const float *dst_scales_ptr
510 = &dst_scales[dst_scales_mask == 0 ? 0 : os_nb_off];
511 ker(i, o, (order_keep && req_comp) ? &cp[_offset] : nullptr,
512 (order_keep && has_asymmetric_comp) ? &zp[_offset]
513 : nullptr,
514 src_scales_ptr, dst_scales_ptr, oc_block, ic_block);
515 }
516 });
517
518#undef wei_blk_off
519
520 return status::success;
521 }
522};
523
524/* Asymmetric Blocking */
525template <SIMPLE_REORDER_TEMPL_DECL>
526struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
527 typename utils::enable_if<(utils::one_of(tag_i, format_tag::iwo,
528 format_tag::oiw, format_tag::wio)
529 && utils::one_of(
530 tag_o, format_tag::Owi16o))
531 || (utils::one_of(
532 tag_i, format_tag::goiw, format_tag::wigo)
533 && utils::one_of(tag_o, format_tag::gOwi16o))
534 || (utils::one_of(tag_i, format_tag::ihwo,
535 format_tag::hwio, format_tag::oihw)
536 && utils::one_of(tag_o, format_tag::Owhi16o))
537 || (utils::one_of(
538 tag_i, format_tag::goihw, format_tag::hwigo)
539 && utils::one_of(tag_o, format_tag::gOwhi16o)),
540 spec::conv_req_comp>::type> {
541 static bool is_applicable(const memory_desc_wrapper &input_d,
542 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
543 using namespace format_tag;
544 using namespace data_type;
545 using namespace utils;
546
547 if (input_d.has_runtime_dims_or_strides()) return false;
548
549 const bool w_groups = !one_of(tag_o, Owi16o, Owhi16o);
550
551 // Current formats are only used in jit kernels that natively
552 // support s8 instructions, hence, there is no need for signed
553 // compensation.
554 const bool req_comp = output_d.extra().flags
555 & memory_extra_flags::compensation_conv_s8s8;
556
557 const bool req_asymmetric_comp = output_d.extra().flags
558 & memory_extra_flags::compensation_conv_asymmetric_src;
559
560 auto mask_ok = [&](bool check, int mask) {
561 const int c_mask = 0x1,
562 g_mask = 0x3; // mask for i/o-channel and ngroups
563 return IMPLICATION(check, mask == (w_groups ? g_mask : c_mask));
564 };
565
566 return simple_attr_check(attr, true, false)
567 && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
568 && mask_ok(req_asymmetric_comp,
569 output_d.extra().asymm_compensation_mask)
570 && one_of(input_d.data_type(), f32, s8, bf16)
571 && output_d.data_type() == s8 && !req_comp;
572 }
573
574 GET_SCRATCHPAD_SIZE_ZERO();
575
576 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
577 DECLARE_COMMON_PARAMS();
578 using namespace format_tag;
579
580 static constexpr bool w_groups = !utils::one_of(tag_o, Owi16o, Owhi16o);
581 constexpr int is_1d = utils::one_of(tag_o, Owi16o, gOwi16o);
582 const bool is_3d = false; // TODO once enabled
583
584 constexpr dim_t oc_blksize = 16;
585
586 const auto &plain_d = order_keep ? input_d : output_d;
587 const auto &dims = input_d.dims();
588 const auto &pdims
589 = order_keep ? output_d.padded_dims() : input_d.padded_dims();
590
591 const dim_t G = w_groups ? dims[0] : 1;
592 const dim_t OC = dims[w_groups + 0];
593 const dim_t NB_OC = pdims[w_groups + 0] / oc_blksize;
594 const dim_t IC = dims[w_groups + 1];
595
596 const dim_t D = is_3d ? dims[2 + w_groups] : 1;
597 const dim_t H = is_1d ? 1 : dims[2 + w_groups + is_3d];
598 const dim_t W = dims[w_groups + is_3d + 3 - is_1d];
599
600 const bool has_asymmetric_comp = output_d.extra().flags
601 & memory_extra_flags::compensation_conv_asymmetric_src;
602
603 float adj_scale
604 = (output_d.extra().flags & memory_extra_flags::scale_adjust)
605 ? output_d.extra().scale_adjust
606 : 1.f;
607
608 auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
609 int32_t *zp, const float *s, const float *d,
610 const dim_t oc_block) {
611 for (dim_t oc = 0; oc < oc_block; ++oc) {
612 const auto plain_off
613 = oc * plain_d.blocking_desc().strides[w_groups + 0];
614 out[oc] = qz_b0<data_t<type_i>, data_t<type_o>>()(
615 inp[plain_off], s[oc] * adj_scale * d[oc]);
616 if (has_asymmetric_comp) zp[oc] -= (int32_t)(out[oc]);
617 }
618 // fill memory with '0' in case of padded channel dimensions
619 for (dim_t oc = oc_block; oc < oc_blksize; ++oc) {
620 out[oc] = 0;
621 }
622 };
623
624 size_t offset = output_d.size() - output_d.additional_buffer_size();
625 int32_t *zp = has_asymmetric_comp
626 ? reinterpret_cast<int32_t *>(output + offset)
627 : nullptr;
628
629 if (has_asymmetric_comp) {
630 parallel_nd(G * NB_OC * oc_blksize, [&](dim_t i) { zp[i] = 0; });
631 }
632
633#define wei_blk_off(md, g, o, i, d, h, w) \
634 (is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
635 : is_3d ? (md).blk_off<!w_groups>(g, o, i, d, h, w) \
636 : (md).blk_off<!w_groups>(g, o, i, h, w))
637
638 parallel_nd(G, NB_OC, [&](dim_t g, dim_t O) {
639 for_(dim_t I = 0; I < IC; I++)
640 for_(dim_t d = 0; d < D; d++)
641 for_(dim_t h = 0; h < H; h++)
642 for (dim_t w = 0; w < W; w++) {
643 auto i = &input[wei_blk_off(
644 input_d, g, oc_blksize * O, I, d, h, w)];
645 auto o = &output[wei_blk_off(output_d, g, O, I, d, h, w)];
646 const dim_t oc_block
647 = nstl::min(oc_blksize, OC - O * oc_blksize);
648 dim_t _offset = (g * NB_OC + O) * oc_blksize;
649 int32_t *zp_ptr = (order_keep && has_asymmetric_comp)
650 ? &zp[_offset]
651 : nullptr;
652 const float *src_scales_ptr
653 = &src_scales[src_scales_mask == 0 ? 0 : _offset];
654 const float *dst_scales_ptr
655 = &dst_scales[dst_scales_mask == 0 ? 0 : _offset];
656 ker(i, o, zp_ptr, src_scales_ptr, dst_scales_ptr, oc_block);
657 }
658 });
659
660#undef wei_blk_off
661
662 return status::success;
663 }
664};
665
666/* Asymmetric Blocking */
667template <SIMPLE_REORDER_TEMPL_DECL>
668struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
669 typename utils::enable_if<(utils::one_of(tag_i, format_tag::iwo,
670 format_tag::oiw, format_tag::wio)
671 && utils::one_of(tag_o,
672 format_tag::OwI16o4i,
673 format_tag::OIw16i16o4i))
674 || (utils::one_of(
675 tag_i, format_tag::goiw, format_tag::wigo)
676 && utils::one_of(tag_o, format_tag::gOwI16o4i,
677 format_tag::gOIw16i16o4i))
678 || (utils::one_of(tag_i, format_tag::ihwo,
679 format_tag::hwio, format_tag::oihw)
680 && utils::one_of(tag_o, format_tag::OhwI16o4i,
681 format_tag::OIhw16i16o4i))
682 || (utils::one_of(
683 tag_i, format_tag::goihw, format_tag::hwigo)
684 && utils::one_of(tag_o, format_tag::gOhwI16o4i,
685 format_tag::gOIhw16i16o4i))
686 || (utils::one_of(tag_i, format_tag::idhwo,
687 format_tag::dhwio, format_tag::oidhw)
688 && utils::one_of(tag_o, format_tag::OdhwI16o4i,
689 format_tag::OIdhw16i16o4i))
690 || (utils::one_of(tag_i, format_tag::goidhw)
691 && utils::one_of(tag_o, format_tag::gOdhwI16o4i,
692 format_tag::gOIdhw16i16o4i)),
693 spec::conv_req_comp>::type> {
694 static bool is_applicable(const memory_desc_wrapper &input_d,
695 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
696 using namespace format_tag;
697 using namespace data_type;
698 using namespace utils;
699
700 if (input_d.has_runtime_dims_or_strides()) return false;
701
702 int src_scales_mask, dst_scales_mask;
703 auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask);
704 if (status != status::success) return false;
705 int scales_mask = std::max(src_scales_mask, dst_scales_mask);
706
707 const bool w_groups = !one_of(tag_o, OwI16o4i, OIw16i16o4i, OhwI16o4i,
708 OIhw16i16o4i, OdhwI16o4i, OIdhw16i16o4i);
709
710 // Current formats are only used in jit kernels that natively
711 // support s8 instructions, hence, there is no need for signed
712 // compensation.
713 const bool req_comp = output_d.extra().flags
714 & memory_extra_flags::compensation_conv_s8s8;
715
716 const bool req_asymmetric_comp = output_d.extra().flags
717 & memory_extra_flags::compensation_conv_asymmetric_src;
718
719 auto mask_ok = [&](bool check, int mask) {
720 const int c_mask = 0x1,
721 g_mask = 0x3; // mask for o-channel and ngroups
722 return IMPLICATION(check, mask == (w_groups ? g_mask : c_mask));
723 };
724
725 return simple_attr_check(attr, true, false)
726 && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
727 && mask_ok(req_asymmetric_comp,
728 output_d.extra().asymm_compensation_mask)
729 && one_of(input_d.data_type(), f32, s8, bf16)
730 && IMPLICATION(!w_groups, one_of(scales_mask, 0, 0x1))
731 && IMPLICATION(w_groups, one_of(scales_mask, 0, 0x3))
732 && output_d.data_type() == s8 && !req_comp;
733 }
734
735 GET_SCRATCHPAD_SIZE_ZERO();
736
737 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
738 DECLARE_COMMON_PARAMS();
739 using namespace format_tag;
740
741 static constexpr bool w_groups
742 = !utils::one_of(tag_o, OwI16o4i, OIw16i16o4i, OhwI16o4i,
743 OIhw16i16o4i, OdhwI16o4i, OIdhw16i16o4i);
744 constexpr int is_1d = utils::one_of(
745 tag_o, OwI16o4i, gOwI16o4i, OIw16i16o4i, gOIw16i16o4i);
746 const bool is_3d = utils::one_of(
747 tag_o, OdhwI16o4i, gOdhwI16o4i, OIdhw16i16o4i, gOIdhw16i16o4i);
748
749 constexpr dim_t oc_blksize = 16;
750 constexpr dim_t ic_blksize
751 = utils::one_of(tag_traits<tag_o>::inner_blks, ib::_16b16a4b,
752 ib::_16c16b4c)
753 ? 64
754 : utils::one_of(
755 tag_traits<tag_o>::inner_blks, ib::_16a4b, ib::_16b4c)
756 ? 4
757 : 1;
758 assert(ic_blksize != 1);
759
760 const auto &plain_d = order_keep ? input_d : output_d;
761 const auto &dims = input_d.dims();
762 const auto &pdims
763 = order_keep ? output_d.padded_dims() : input_d.padded_dims();
764
765 const dim_t G = w_groups ? dims[0] : 1;
766 const dim_t OC = dims[w_groups + 0];
767 const dim_t NB_OC = pdims[w_groups + 0] / oc_blksize;
768 const dim_t IC = dims[w_groups + 1];
769 const dim_t NB_IC = pdims[w_groups + 1] / ic_blksize;
770
771 const dim_t D = is_3d ? dims[2 + w_groups] : 1;
772 const dim_t H = is_1d ? 1 : dims[2 + w_groups + is_3d];
773 const dim_t W = dims[w_groups + is_3d + 3 - is_1d];
774
775 const bool has_asymmetric_comp = output_d.extra().flags
776 & memory_extra_flags::compensation_conv_asymmetric_src;
777
778 float adj_scale
779 = (output_d.extra().flags & memory_extra_flags::scale_adjust)
780 ? output_d.extra().scale_adjust
781 : 1.f;
782
783 // This kernel is used primarily for tensors with multiple inner
784 // blocks for which generic zero padding must be used.
785 // TODO: apply zero padding inside parallel_nd()
786 ctx.zero_pad_output(DNNL_ARG_TO);
787
788 auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
789 int32_t *zp, const float *s, const float *d,
790 const dim_t oc_block, const dim_t ic_block) {
791 for_(dim_t ic = 0; ic < ic_block; ++ic)
792 for (dim_t oc = 0; oc < oc_block; ++oc) {
793 const auto plain_off
794 = oc * plain_d.blocking_desc().strides[w_groups + 0]
795 + ic * plain_d.blocking_desc().strides[w_groups + 1];
796 auto index = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(
797 oc, ic);
798 out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
799 inp[plain_off], s[oc] * adj_scale * d[oc]);
800
801 if (has_asymmetric_comp) zp[oc] -= (int32_t)(out[index]);
802 }
803 };
804
805 size_t offset = output_d.size() - output_d.additional_buffer_size();
806 int32_t *zp = has_asymmetric_comp
807 ? reinterpret_cast<int32_t *>(output + offset)
808 : nullptr;
809
810 if (has_asymmetric_comp) {
811 parallel_nd(G * NB_OC * oc_blksize, [&](dim_t i) { zp[i] = 0; });
812 }
813
814#define wei_blk_off(md, g, o, i, d, h, w) \
815 (is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
816 : is_3d ? (md).blk_off<!w_groups>(g, o, i, d, h, w) \
817 : (md).blk_off<!w_groups>(g, o, i, h, w))
818
819 parallel_nd(G, NB_OC, [&](dim_t g, dim_t O) {
820 for_(dim_t I = 0; I < NB_IC; I++)
821 for_(dim_t d = 0; d < D; d++)
822 for_(dim_t h = 0; h < H; h++)
823 for (dim_t w = 0; w < W; w++) {
824 auto i = &input[wei_blk_off(
825 input_d, g, oc_blksize * O, ic_blksize * I, d, h, w)];
826 auto o = &output[wei_blk_off(output_d, g, O, I, d, h, w)];
827 const dim_t oc_block
828 = nstl::min(oc_blksize, OC - O * oc_blksize);
829 const dim_t ic_block
830 = nstl::min(ic_blksize, IC - I * ic_blksize);
831 dim_t _offset = (g * NB_OC + O) * oc_blksize;
832 int32_t *zp_ptr = (order_keep && has_asymmetric_comp)
833 ? &zp[_offset]
834 : nullptr;
835 const float *src_scales_ptr
836 = &src_scales[src_scales_mask == 0 ? 0 : _offset];
837 const float *dst_scales_ptr
838 = &dst_scales[dst_scales_mask == 0 ? 0 : _offset];
839 ker(i, o, zp_ptr, src_scales_ptr, dst_scales_ptr, oc_block,
840 ic_block);
841 }
842 });
843
844#undef wei_blk_off
845
846 return status::success;
847 }
848};
849
850template <SIMPLE_REORDER_TEMPL_DECL>
851struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
852 typename utils::enable_if<
853 (utils::one_of(tag_i, format_tag::ab, format_tag::ba,
854 format_tag::abc, format_tag::acb)
855 && utils::one_of(tag_o, format_tag::BA16a16b4a,
856 format_tag::BA16a32b4a, format_tag::BA16a48b4a,
857 format_tag::BA16a64b4a, format_tag::aCB16b16c4b,
858 format_tag::aCB16b32c4b,
859 format_tag::aCB16b48c4b,
860 format_tag::aCB16b64c4b)),
861 spec::conv_req_comp>::type> {
862 static bool is_applicable(const memory_desc_wrapper &input_d,
863 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
864 using namespace format_tag;
865 using namespace data_type;
866 using namespace utils;
867
868 if (input_d.has_runtime_dims_or_strides()) return false;
869
870 const bool req_comp = output_d.extra().flags
871 & memory_extra_flags::compensation_conv_s8s8;
872 const bool req_asymmetric_comp = output_d.extra().flags
873 & memory_extra_flags::compensation_conv_asymmetric_src;
874
875 const auto ndims = input_d.ndims();
876 auto mask_ok = [&](bool check, int mask) {
877 return IMPLICATION(
878 check, mask == (1 << ndims) - 1 - (1 << (ndims - 2)));
879 };
880
881 int src_scales_mask, dst_scales_mask;
882 auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask);
883 if (status != status::success) return false;
884 int scales_mask = std::max(src_scales_mask, dst_scales_mask);
885 const size_t D_mask
886 = array_product(input_d.dims(), math::ilog2q(scales_mask + 1));
887
888 return simple_attr_check(attr, true, false)
889 && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
890 && mask_ok(req_comp, output_d.extra().compensation_mask)
891 && mask_ok(req_asymmetric_comp,
892 output_d.extra().asymm_compensation_mask)
893 && one_of(input_d.data_type(), f32, s8, bf16, f16)
894 && output_d.data_type() == s8 && D_mask == 1;
895 }
896
897 GET_SCRATCHPAD_SIZE_ZERO();
898
899 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
900 DECLARE_COMMON_PARAMS();
901 using namespace format_tag;
902
903 // {[batch_dim][d0][d1], [batch_dim][d1][d0]} -> [batch_dim][D1][D0][16][D1_blksize][4]
904 // 2D: batch_dim - none, d0 <-> a, d1 <-> b
905 // 3D: batch_dim <-> a, d0 <-> b, d1 <-> c
906 constexpr dim_t D0_blksize = 64;
907 constexpr dim_t D1_blksize
908 = (utils::one_of(tag_traits<tag_o>::inner_blks, ib::_16a64b4a,
909 ib::_16b64c4b))
910 ? 64
911 : (utils::one_of(tag_traits<tag_o>::inner_blks, ib::_16a48b4a,
912 ib::_16b48c4b))
913 ? 48
914 : (utils::one_of(tag_traits<tag_o>::inner_blks,
915 ib::_16a32b4a, ib::_16b32c4b))
916 ? 32
917 : (utils::one_of(tag_traits<tag_o>::inner_blks,
918 ib::_16a16b4a, ib::_16b16c4b))
919 ? 16
920 : 1;
921 assert(D1_blksize != 1);
922
923 const auto &plain_d = order_keep ? input_d : output_d;
924 const auto &dims = input_d.dims();
925 const auto ndims = input_d.ndims();
926 const auto &pdims
927 = order_keep ? output_d.padded_dims() : input_d.padded_dims();
928
929 const dim_t batch_dim = ndims > 2 ? dims[ndims - 3] : 1;
930 const dim_t D0dim = dims[ndims - 2];
931 const dim_t NB_D0dim = pdims[ndims - 2] / D0_blksize;
932 const dim_t D1dim = dims[ndims - 1];
933 const dim_t NB_D1dim = pdims[ndims - 1] / D1_blksize;
934 assert(pdims[ndims - 1] == NB_D1dim * D1_blksize);
935
936 const bool req_comp = output_d.extra().flags
937 & memory_extra_flags::compensation_conv_s8s8;
938 const bool has_asymmetric_comp = output_d.extra().flags
939 & memory_extra_flags::compensation_conv_asymmetric_src;
940
941 float adj_scale
942 = (output_d.extra().flags & memory_extra_flags::scale_adjust)
943 ? output_d.extra().scale_adjust
944 : 1.f;
945
946 auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
947 int32_t *cp, int32_t *zp, const float *s,
948 const float *d, const int d0_block,
949 const int d1_block) {
950 for (int d0 = 0; d0 < d0_block; ++d0) {
951 for (int d1 = 0; d1 < d1_block; ++d1) {
952 const auto plain_off
953 = d0 * plain_d.blocking_desc().strides[ndims - 2]
954 + d1 * plain_d.blocking_desc().strides[ndims - 1];
955 auto index
956 = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(
957 d0, d1);
958 out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
959 inp[plain_off], s[0] * adj_scale * d[0]);
960
961 auto o = static_cast<int32_t>(out[index]);
962 if (req_comp) cp[d1] -= (128 * o);
963 if (has_asymmetric_comp) zp[d1] -= o;
964 }
965 for (int d1 = d1_block; d1 < D1_blksize; ++d1) {
966 auto index
967 = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(
968 d0, d1);
969 out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
970 0, s[0] * adj_scale * d[0]);
971 }
972 }
973
974 for_(int d0 = d0_block; d0 < D0_blksize; ++d0)
975 for (int d1 = 0; d1 < D1_blksize; ++d1) {
976 auto index = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(
977 d0, d1);
978 out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
979 0, s[0] * adj_scale * d[0]);
980 }
981 };
982
983 const auto w_d = order_keep ? output_d : input_d;
984 size_t offset = w_d.size() - w_d.additional_buffer_size();
985 size_t comp_size = output_d.additional_buffer_size(
986 memory_extra_flags::compensation_conv_s8s8);
987 size_t zp_offset = offset + (req_comp ? comp_size : 0);
988 int32_t *cp = req_comp ? reinterpret_cast<int32_t *>(output + offset)
989 : nullptr;
990 int32_t *zp = has_asymmetric_comp
991 ? reinterpret_cast<int32_t *>(output + zp_offset)
992 : nullptr;
993
994 if (has_asymmetric_comp || req_comp) {
995 parallel_nd(batch_dim * NB_D1dim * D1_blksize, [&](dim_t i) {
996 if (req_comp) cp[i] = 0;
997 if (has_asymmetric_comp) zp[i] = 0;
998 });
999 }
1000
1001#define get_blk_off(md, batch, d0, d1) \
1002 (ndims == 3 ? (md).blk_off((batch), (d0), (d1)) : (md).blk_off((d0), (d1)))
1003
1004 parallel_nd(batch_dim, NB_D1dim, [&](dim_t batch, dim_t D1) {
1005 for (int D0 = 0; D0 < NB_D0dim; D0++) {
1006 auto i = &input[get_blk_off(
1007 input_d, batch, D0_blksize * D0, D1_blksize * D1)];
1008 auto o = &output[get_blk_off(output_d, batch, D0, D1)];
1009 const dim_t d0_block
1010 = nstl::min(D0_blksize, D0dim - D0 * D0_blksize);
1011 const dim_t d1_block
1012 = nstl::min(D1_blksize, D1dim - D1 * D1_blksize);
1013 dim_t _offset = batch * NB_D1dim * D1_blksize + D1 * D1_blksize;
1014 int32_t *zp_ptr = (order_keep && has_asymmetric_comp)
1015 ? &zp[_offset]
1016 : nullptr;
1017 const float *src_scales_ptr
1018 = &src_scales[src_scales_mask == 0 ? 0 : _offset];
1019 const float *dst_scales_ptr
1020 = &dst_scales[dst_scales_mask == 0 ? 0 : _offset];
1021 ker(i, o, (order_keep && req_comp) ? &cp[_offset] : nullptr,
1022 zp_ptr, src_scales_ptr, dst_scales_ptr, d0_block,
1023 d1_block);
1024 }
1025 });
1026
1027#undef get_blk_off
1028
1029 return status::success;
1030 }
1031};
1032
1033template <SIMPLE_REORDER_TEMPL_DECL>
1034struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1035 typename utils::enable_if<
1036 (utils::one_of(tag_i, format_tag::goiw, format_tag::wigo)
1037 && utils::one_of(tag_o, format_tag::Goiw16g,
1038 format_tag::Goiw8g, format_tag::Goiw4g))
1039 || (utils::one_of(
1040 tag_i, format_tag::goihw, format_tag::hwigo)
1041 && utils::one_of(tag_o, format_tag::Goihw16g,
1042 format_tag::Goihw8g,
1043 format_tag::Goihw4g)),
1044 spec::conv_req_comp>::type> {
1045 static bool is_applicable(const memory_desc_wrapper &input_d,
1046 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1047 using namespace data_type;
1048 using namespace utils;
1049
1050 if (input_d.has_runtime_dims_or_strides()) return false;
1051
1052 int src_scales_mask, dst_scales_mask;
1053 auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask);
1054 if (status != status::success) return false;
1055 int scales_mask = std::max(src_scales_mask, dst_scales_mask);
1056
1057 const dim_t g = input_d.dims()[0];
1058 const dim_t oc = input_d.dims()[1];
1059 const dim_t ic = input_d.dims()[2];
1060
1061 const bool req_comp = output_d.extra().flags
1062 & memory_extra_flags::compensation_conv_s8s8;
1063 const bool req_asymmetric_comp = output_d.extra().flags
1064 & memory_extra_flags::compensation_conv_asymmetric_src;
1065 int s8s8_comp_mask = output_d.extra().compensation_mask;
1066 int zp_comp_mask = output_d.extra().asymm_compensation_mask;
1067 int comp_mask = std::max(s8s8_comp_mask, zp_comp_mask);
1068
1069 const size_t D_mask
1070 = array_product(input_d.dims(), math::ilog2q(comp_mask + 1));
1071
1072 return order_keep && oc == 1 && ic == 1 // depth-wise case
1073 && simple_attr_check(attr, true, false)
1074 && (req_comp || req_asymmetric_comp)
1075 && IMPLICATION(req_comp && req_asymmetric_comp,
1076 output_d.extra().compensation_mask
1077 == output_d.extra().asymm_compensation_mask)
1078 && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
1079 && IMPLICATION(
1080 req_comp, one_of(D_mask, (size_t)1, (size_t)g * oc))
1081 && one_of(scales_mask, 0, 0x3)
1082 && one_of(input_d.data_type(), f32, s8, bf16)
1083 && output_d.data_type() == s8;
1084 }
1085
1086 GET_SCRATCHPAD_SIZE_ZERO();
1087
1088 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
1089 DECLARE_COMMON_PARAMS();
1090
1091 constexpr bool is_1d
1092 = utils::one_of(tag_i, format_tag::goiw, format_tag::wigo);
1093 constexpr dim_t blksize
1094 = utils::one_of(tag_o, format_tag::Goihw4g, format_tag::Goiw4g)
1095 ? 4
1096 : utils::one_of(tag_o, format_tag::Goihw8g, format_tag::Goiw8g)
1097 ? 8
1098 : 16;
1099
1100 const auto &dims = input_d.dims();
1101 const auto &pdims = output_d.padded_dims();
1102 const dim_t G = dims[0];
1103 const dim_t Gp = pdims[0];
1104 const dim_t OC = dims[1];
1105 const dim_t IC = dims[2];
1106 const dim_t H = is_1d ? 1 : dims[3];
1107 const dim_t W = dims[4 - is_1d];
1108 const bool zero_padding_needed = !output_d.is_dense();
1109
1110 const bool req_comp = output_d.extra().flags
1111 & memory_extra_flags::compensation_conv_s8s8;
1112 const bool has_asymmetric_comp = output_d.extra().flags
1113 & memory_extra_flags::compensation_conv_asymmetric_src;
1114
1115 assert(req_comp || has_asymmetric_comp);
1116
1117 float adj_scale
1118 = (output_d.extra().flags & memory_extra_flags::scale_adjust)
1119 ? output_d.extra().scale_adjust
1120 : 1.f;
1121
1122 auto ker_out = [&](const data_t<type_i> *inp, data_t<type_o> *out,
1123 const float *src_scales, const float *dst_scales,
1124 const dim_t g_block) {
1125 PRAGMA_OMP_SIMD()
1126 for (dim_t g = 0; g < g_block; g++) {
1127 const auto i_off = g * input_d.blocking_desc().strides[0];
1128 const float src_scale
1129 = src_scales[src_scales_mask == 0 ? 0 : g * OC];
1130 const float dst_scale
1131 = dst_scales[dst_scales_mask == 0 ? 0 : g * OC];
1132 out[g] = qz_b0<data_t<type_i>, data_t<type_o>>()(
1133 inp[i_off], src_scale * adj_scale * dst_scale);
1134 }
1135 };
1136
1137 /* Note: having separate kernels for s8 and zero-point fixes a
1138 * compiler-generated bug which results in seg-fault. */
1139 auto ker_s8 = [&](const data_t<type_o> *out, int32_t *cp,
1140 const dim_t g_block) {
1141 PRAGMA_OMP_SIMD()
1142 for (dim_t g = 0; g < g_block; g++) {
1143 cp[g * OC] -= 128 * (int32_t)(out[g]);
1144 }
1145 };
1146 auto ker_zp = [&](const data_t<type_o> *out, int32_t *zp,
1147 const dim_t g_block) {
1148 PRAGMA_OMP_SIMD()
1149 for (dim_t g = 0; g < g_block; g++) {
1150 zp[g * OC] -= (int32_t)(out[g]);
1151 }
1152 };
1153
1154 size_t offset = output_d.size() - output_d.additional_buffer_size();
1155 size_t comp_size = output_d.additional_buffer_size(
1156 memory_extra_flags::compensation_conv_s8s8);
1157 size_t zp_offset = offset + (req_comp ? comp_size : 0);
1158 int32_t *cp = req_comp ? reinterpret_cast<int32_t *>(output + offset)
1159 : nullptr;
1160 int32_t *zp = has_asymmetric_comp
1161 ? reinterpret_cast<int32_t *>(output + zp_offset)
1162 : nullptr;
1163
1164 parallel_nd((Gp / blksize) * OC, [&](dim_t ib) {
1165 PRAGMA_OMP_SIMD()
1166 for (dim_t i = 0; i < blksize; i++) {
1167 if (req_comp) cp[ib * blksize + i] = 0;
1168 if (has_asymmetric_comp) zp[ib * blksize + i] = 0;
1169 }
1170 });
1171
1172#define wei_blk_off(md, g, o, i, h, w) \
1173 (is_1d ? (md).blk_off(g, o, i, w) : (md).blk_off(g, o, i, h, w))
1174
1175 parallel_nd(Gp / blksize, OC, [&](dim_t gb, dim_t O) {
1176 for (dim_t I = 0; I < IC; I++) {
1177 for_(dim_t h = 0; h < H; h++)
1178 for (dim_t w = 0; w < W; w++) {
1179 const dim_t g_block = nstl::min(G - gb * blksize, blksize);
1180 const auto inp = &input[wei_blk_off(
1181 input_d, gb * blksize, O, I, h, w)];
1182 const auto out
1183 = &output[wei_blk_off(output_d, gb, O, I, h, w)];
1184 dim_t offset = gb * blksize + O;
1185 const float *src_scales_ptr
1186 = &src_scales[src_scales_mask == 0 ? 0 : offset];
1187 const float *dst_scales_ptr
1188 = &dst_scales[dst_scales_mask == 0 ? 0 : offset];
1189
1190 ker_out(inp, out, src_scales_ptr, dst_scales_ptr, g_block);
1191 if (req_comp) ker_s8(out, &cp[offset], g_block);
1192 if (has_asymmetric_comp) ker_zp(out, &zp[offset], g_block);
1193
1194 if (zero_padding_needed) {
1195 PRAGMA_OMP_SIMD()
1196 for (int off = g_block; off < blksize; off++)
1197 out[off] = 0;
1198 }
1199 }
1200 }
1201 });
1202
1203#undef wei_blk_off
1204
1205 return status::success;
1206 }
1207};
1208
1209/* bf16 reorders */
1210template <SIMPLE_REORDER_TEMPL_DECL>
1211struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1212 typename utils::enable_if<(
1213 (tag_i == format_tag::goihw || tag_i == format_tag::oihw)
1214 && (tag_o == format_tag::gOIhw16i16o
1215 || tag_o == format_tag::OIhw16i16o
1216 || tag_o == format_tag::gOIhw8i16o2i
1217 || tag_o == format_tag::OIhw8i16o2i
1218 || tag_o == format_tag::gOIhw8o16i2o
1219 || tag_o == format_tag::OIhw8o16i2o
1220 || tag_o == format_tag::gIOhw8o16i2o
1221 || tag_o == format_tag::IOhw8o16i2o)
1222 && type_i == data_type::f32
1223 && type_o == data_type::bf16)>::type> {
1224 static bool is_applicable(const memory_desc_wrapper &input_d,
1225 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1226 using namespace data_type;
1227
1228 if (input_d.has_runtime_dims_or_strides()) return false;
1229
1230 return order_keep && input_d.matches_tag(tag_i)
1231 && output_d.matches_tag(tag_o) && input_d.data_type() == f32
1232 && output_d.data_type() == bf16 && attr->has_default_values();
1233 }
1234
1235 static size_t get_scratchpad_size(const memory_desc_wrapper &input_d,
1236 const memory_desc_wrapper &output_d) {
1237 const dim_t blksize = 16;
1238 return sizeof(float) * blksize * blksize * dnnl_get_max_threads();
1239 }
1240
1241 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
1242 DECLARE_COMMON_PARAMS();
1243 using namespace format_tag;
1244
1245 static constexpr bool w_groups = tag_i == goihw;
1246 const dim_t blksize = 16;
1247 const int sblk = 2;
1248
1249 const auto &plain_d = input_d;
1250 const auto &dims = input_d.dims();
1251 const auto &pdims = output_d.padded_dims();
1252
1253 const dim_t G = w_groups ? dims[0] : 1;
1254 const dim_t OC = dims[w_groups + 0];
1255 const dim_t NB_OC = pdims[w_groups + 0] / blksize;
1256 const dim_t IC = dims[w_groups + 1];
1257 const dim_t NB_IC = pdims[w_groups + 1] / blksize;
1258 const dim_t H = dims[w_groups + 2];
1259 const dim_t W = dims[w_groups + 3];
1260
1261 const size_t wsp_size = blksize * blksize;
1262 float *wspace = scratchpad.template get<float>(
1263 memory_tracking::names::key_reorder_space);
1264
1265 auto index = [&](dim_t ic, dim_t oc) -> dim_t {
1266 if (utils::one_of(tag_o, gOIhw16i16o, OIhw16i16o))
1267 return (ic * blksize + oc);
1268 else if (utils::one_of(tag_o, gOIhw8i16o2i, OIhw8i16o2i))
1269 return ((ic / sblk) * blksize * sblk + sblk * oc + ic % sblk);
1270 else if (utils::one_of(tag_o, gOIhw8o16i2o, gIOhw8o16i2o,
1271 OIhw8o16i2o, IOhw8o16i2o))
1272 return ((oc / sblk) * blksize * sblk + sblk * ic + oc % sblk);
1273 else
1274 assert(!"Invalid weight format");
1275 return dim_t(0);
1276 };
1277
1278 auto ker = [&](const data_t<type_i> *inp, data_t<type_i> *out,
1279 const dim_t curr_oc_block, const dim_t oc_block,
1280 const dim_t curr_ic_block, const dim_t ic_block) {
1281 dim_t ic = 0;
1282 for (ic = 0; ic < curr_ic_block; ++ic) {
1283 dim_t oc = 0;
1284 for (oc = 0; oc < curr_oc_block; ++oc) {
1285 const auto plain_off
1286 = oc * plain_d.blocking_desc().strides[w_groups + 0]
1287 + ic
1288 * plain_d.blocking_desc()
1289 .strides[w_groups + 1];
1290 out[index(ic, oc)] = inp[plain_off];
1291 }
1292 for (/* continue */; oc < oc_block; ++oc) {
1293 out[index(ic, oc)] = (data_t<type_i>)0;
1294 }
1295 }
1296 for (/* continue */; ic < ic_block; ++ic) {
1297 for (dim_t oc = 0; oc < oc_block; ++oc) {
1298 out[index(ic, oc)] = (data_t<type_i>)0;
1299 }
1300 }
1301 };
1302
1303 constexpr int i_mult = blksize;
1304 constexpr int o_mult = 1;
1305
1306 parallel_nd_ext(0, G, NB_OC, NB_IC, H, W,
1307 [&](int ithr, int, dim_t g, dim_t O, dim_t I, dim_t h,
1308 dim_t w) {
1309 float *_wspace = wspace + wsp_size * ithr;
1310 auto i = &input[input_d.blk_off<!w_groups>(
1311 g, i_mult * O, i_mult * I, h, w)];
1312 auto o = &output[output_d.blk_off<!w_groups>(
1313 g, o_mult * O, o_mult * I, h, w)];
1314 const dim_t oc_block = nstl::min(blksize, OC - O * blksize);
1315 const dim_t ic_block = nstl::min(blksize, IC - I * blksize);
1316 ker(i, _wspace, oc_block, blksize, ic_block, blksize);
1317 cvt_float_to_bfloat16(o, _wspace, wsp_size);
1318 });
1319
1320 return status::success;
1321 }
1322};
1323
1324template <SIMPLE_REORDER_TEMPL_DECL>
1325struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1326 typename utils::enable_if<(tag_i == format_tag::nchw
1327 && tag_o == format_tag::nChw16c)
1328 && type_i == data_type::f32
1329 && type_o == data_type::bf16>::type> {
1330 static bool is_applicable(const memory_desc_wrapper &input_d,
1331 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1332 using namespace data_type;
1333
1334 if (input_d.has_runtime_dims_or_strides()) return false;
1335
1336 return input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
1337 && input_d.data_type() == f32 && output_d.data_type() == bf16
1338 && attr->has_default_values();
1339 }
1340
1341 static size_t get_scratchpad_size(const memory_desc_wrapper &input_d,
1342 const memory_desc_wrapper &output_d) {
1343 const size_t blksize = 16;
1344 const size_t W = input_d.dims()[3];
1345 return sizeof(float) * blksize * W * dnnl_get_max_threads();
1346 }
1347
1348 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
1349 DECLARE_COMMON_PARAMS();
1350
1351 const dim_t blksize = 16;
1352
1353 const auto &flat_d = input_d;
1354 const auto &dims = input_d.dims();
1355 const auto &pdims = output_d.padded_dims();
1356
1357 const dim_t C = dims[1];
1358 const dim_t H = dims[2];
1359 const dim_t W = dims[3];
1360
1361 const dim_t wsp_size = W * blksize;
1362 float *wspace = scratchpad.template get<float>(
1363 memory_tracking::names::key_reorder_space);
1364
1365 auto ker = [&](const data_t<type_i> *i, data_t<type_i> *o,
1366 const dim_t curr_c_block, const dim_t c_block) {
1367 for (dim_t w = 0; w < W; ++w) {
1368 dim_t c = 0;
1369 for (c = 0; c < curr_c_block; ++c) {
1370 const ptrdiff_t flat_off = 0
1371 + c * flat_d.blocking_desc().strides[1]
1372 + w * flat_d.blocking_desc().strides[3];
1373 o[w * blksize + c] = i[flat_off];
1374 }
1375 for (/* continue */; c < c_block; ++c) {
1376 o[w * blksize + c] = (data_t<type_i>)0;
1377 }
1378 }
1379 };
1380
1381 constexpr int i_c_mult = blksize;
1382 constexpr int o_c_mult = 1;
1383
1384 parallel_nd_ext(0, dims[0], pdims[1] / blksize, H,
1385 [&](int ithr, int, dim_t n, dim_t nb_c, dim_t h) {
1386 float *_wspace = wspace + wsp_size * ithr;
1387 auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
1388 auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
1389 const dim_t c_block
1390 = nstl::min(blksize, C - nb_c * blksize);
1391 ker(i, _wspace, c_block, blksize);
1392 cvt_float_to_bfloat16(o, _wspace, wsp_size);
1393 });
1394
1395 return status::success;
1396 }
1397};
1398
1399/* reorders with tail support */
1400
1401template <SIMPLE_REORDER_TEMPL_DECL>
1402struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1403 typename utils::enable_if<false
1404 || (utils::one_of(
1405 tag_i, format_tag::nCdhw4c, format_tag::nCdhw8c)
1406 && tag_o == format_tag::nCdhw16c)
1407 || (utils::one_of(tag_i, format_tag::nChw4c, format_tag::nChw8c)
1408 && tag_o == format_tag::nChw16c)
1409 || (utils::one_of(tag_i, format_tag::nCw4c, format_tag::nCw8c)
1410 && tag_o == format_tag::nCw16c)>::type> {
1411 static bool is_applicable(const memory_desc_wrapper &input_d,
1412 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1413 return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d)
1414 && simple_attr_check(attr, false, true);
1415 }
1416
1417 GET_SCRATCHPAD_SIZE_ZERO();
1418
1419 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
1420 DECLARE_COMMON_PARAMS();
1421 using namespace format_tag;
1422
1423 constexpr int is_1d = utils::one_of(tag_i, nCw4c, nCw8c);
1424 constexpr int is_3d = utils::one_of(tag_i, nCdhw4c, nCdhw8c);
1425
1426 constexpr dim_t blksize_i
1427 = tag_traits<tag_i>::inner_blks == ib::_4b ? 4 : 8;
1428 constexpr dim_t blksize_16 = 16;
1429
1430 constexpr dim_t ic_mult = order_keep ? blksize_16 / blksize_i : 1;
1431 constexpr dim_t oc_mult = order_keep ? 1 : blksize_16 / blksize_i;
1432
1433 const auto &dims = input_d.dims();
1434 const auto &pdims
1435 = order_keep ? output_d.padded_dims() : input_d.padded_dims();
1436
1437 const auto &d_i = order_keep ? input_d : output_d;
1438 const auto stride_C_in_blk_i = d_i.blocking_desc().strides[1];
1439
1440 const dim_t C = dims[1];
1441 const dim_t D = is_3d ? dims[2] : 1;
1442 const dim_t H = is_1d ? 1 : dims[2 + is_3d];
1443 const dim_t W = dims[3 + is_3d - is_1d];
1444
1445 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1446 const int block) {
1447 const int nb = utils::div_up(block, blksize_i);
1448 if (alpha == 1.0 && beta == 0.0) {
1449 for (int b = 0; b < nb; ++b) {
1450 const ptrdiff_t i_off
1451 = b * (order_keep ? stride_C_in_blk_i : blksize_i);
1452 const ptrdiff_t o_off
1453 = b * (order_keep ? blksize_i : stride_C_in_blk_i);
1454 const int block_i
1455 = nstl::min(blksize_i, block - b * blksize_i);
1456 for (int c = 0; c < block_i; ++c) {
1457 o[o_off + c] = _qz_a1b0<type_i, type_o>()(i[i_off + c]);
1458 }
1459 if (b + 1 == nb) {
1460 // zero padding
1461 const auto pad_size = order_keep
1462 ? blksize_16 - ((nb - 1) * blksize_i)
1463 : blksize_i;
1464 const auto pad_start = block_i + o_off;
1465 const auto pad_end = pad_size + o_off;
1466 PRAGMA_OMP_SIMD()
1467 for (int i = pad_start; i < pad_end; i++) {
1468 o[i] = 0;
1469 }
1470 }
1471 }
1472 } else {
1473 for (int b = 0; b < nb; ++b) {
1474 const ptrdiff_t i_off
1475 = b * (order_keep ? stride_C_in_blk_i : blksize_i);
1476 const ptrdiff_t o_off
1477 = b * (order_keep ? blksize_i : stride_C_in_blk_i);
1478 const int block_i
1479 = nstl::min(blksize_i, block - b * blksize_i);
1480 for (int c = 0; c < block_i; ++c) {
1481 o[o_off + c] = _qz<type_i, type_o>()(
1482 i[i_off + c], o[o_off + c], alpha, beta);
1483 }
1484 if (b + 1 == nb) {
1485 // zero padding
1486 const auto pad_size = order_keep
1487 ? blksize_16 - ((nb - 1) * blksize_i)
1488 : blksize_i;
1489 const auto pad_start = block_i + o_off;
1490 const auto pad_end = pad_size + o_off;
1491 PRAGMA_OMP_SIMD()
1492 for (int i = pad_start; i < pad_end; i++) {
1493 o[i] = 0;
1494 }
1495 }
1496 }
1497 }
1498 };
1499
1500#define data_blk_off(md, n, c, d, h, w) \
1501 (is_1d ? (md).blk_off(n, c, w) \
1502 : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w))
1503
1504 parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W,
1505 [&](dim_t n, dim_t nb_c, dim_t d, dim_t h, dim_t w) {
1506 auto i = &input[data_blk_off(
1507 input_d, n, ic_mult * nb_c, d, h, w)];
1508 auto o = &output[data_blk_off(
1509 output_d, n, oc_mult * nb_c, d, h, w)];
1510 const int block
1511 = nstl::min(blksize_16, C - nb_c * blksize_16);
1512 ker(i, o, block);
1513 });
1514
1515#undef data_blk_off
1516
1517 return status::success;
1518 }
1519};
1520
1521#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \
1522 static bool is_applicable(const memory_desc_wrapper &input_d, \
1523 const memory_desc_wrapper &output_d, \
1524 const primitive_attr_t *attr) { \
1525 return !input_d.has_runtime_dims_or_strides() \
1526 && simple_attr_check(attr, false, true) \
1527 && (order_keep ? output_d.matches_tag(tag_o) \
1528 && input_d.is_plain() \
1529 : input_d.matches_tag(tag_o) \
1530 && output_d.is_plain()); \
1531 }
1532
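// Plain <-> blocked reorder where a single dimension (A or B, selected by
// blk_idx below) carries the inner block, for 3D..6D tensors. Inside
// execute() the logical dims are renamed: H0/H1 are the two outermost dims,
// M0..M2 the middle ones, and L is the innermost dim, which the kernel walks
// explicitly through the flat/blocked strides computed below.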
1533template <SIMPLE_REORDER_TEMPL_DECL>
1534struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1535 typename utils::enable_if<tag_i == format_tag::any
1536 && (tag_traits<tag_o>::block_dims == bd::_A
1537 || tag_traits<tag_o>::block_dims == bd::_B)
1538 && tag_traits<tag_o>::ndims >= 3
1539 && tag_traits<tag_o>::ndims <= 6>::type> {
1540 PLAIN_TO_BLOCKED_IS_APPLICABLE();
1541
1542 GET_SCRATCHPAD_SIZE_ZERO();
1543
1544 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
1545 DECLARE_COMMON_PARAMS();
1546
1547 const auto &flat_d = order_keep ? input_d : output_d;
1548 const auto &block_d = order_keep ? output_d : input_d;
1549 const dims_t &dims = input_d.dims();
1550 const dims_t &pdims = block_d.padded_dims();
1551
1552 const int ndims = tag_traits<tag_o>::ndims;
1553 const int blk_idx = tag_traits<tag_o>::block_dims == bd::_A ? 0 : 1;
1554
1555 const dim_t H0 = dims[0];
1556 const dim_t H1 = dims[1];
1557 const dim_t M0 = ndims == 6 ? dims[ndims - 4] : 1;
1558 const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1;
1559 const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1;
1560 const dim_t L = dims[ndims - 1];
1561 const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1];
1562 const dim_t l_flat_stride = flat_d.blocking_desc().strides[ndims - 1];
1563 const dim_t blk_flat_stride = flat_d.blocking_desc().strides[blk_idx];
1564 using namespace data_type;
1565 using namespace utils;
1566
1567 dim_t blksize = -1;
1568 switch (tag_traits<tag_o>::inner_blks) {
1569 case ib::_4a:
1570 case ib::_4b: blksize = 4; break;
1571 case ib::_8a:
1572 case ib::_8b: blksize = 8; break;
1573 default: blksize = 16;
1574 }
1575
1576 constexpr bool f32bf16
1577 = one_of(type_i, f32, bf16) && one_of(type_o, f32, bf16);
1578
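        // For f32 <-> bf16 pairs a plain conversion (or an alpha/beta blend
        // computed in float) is sufficient, so the generic quantization
        // functors are bypassed for those type combinations.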
1579 auto wrap_qz_a1b0 = [=](data_t<type_o> &out, data_t<type_i> inp) {
1580 if (f32bf16)
1581 out = inp;
1582 else
1583 out = _qz_a1b0<type_i, type_o>()(inp);
1584 };
1585
1586 auto wrap_qz = [=](data_t<type_o> &out, data_t<type_i> inp, float alpha,
1587 float beta) {
1588 if (f32bf16)
1589 out = alpha * inp + (beta ? beta * out : 0);
1590 else
1591 out = _qz<type_i, type_o>()(inp, out, alpha, beta);
1592 };
1593
1594 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o, int block) {
1595 if (alpha == 1.0 && beta == 0.0) {
1596 for (int l = 0; l < L; ++l) {
1597 for (int blk = 0; blk < block; ++blk) {
1598 const dim_t flat_off
1599 = blk * blk_flat_stride + l * l_flat_stride;
1600 const dim_t blk_offset = l * l_blk_stride + blk;
1601 if (order_keep) {
1602 wrap_qz_a1b0(o[blk_offset], i[flat_off]);
1603 } else {
1604 wrap_qz_a1b0(o[flat_off], i[blk_offset]);
1605 }
1606 }
1607 if (order_keep) {
1608 // zero padding
1609 const auto pad_start = block + l * l_blk_stride;
1610 const auto pad_end = blksize + l * l_blk_stride;
1611 PRAGMA_OMP_SIMD()
                        for (int e = pad_start; e < pad_end; ++e) {
                            o[e] = 0;
                        }
1615 }
1616 }
1617 } else {
1618 for (int l = 0; l < L; ++l) {
1619 for (int blk = 0; blk < block; ++blk) {
1620 const dim_t flat_off
1621 = blk * blk_flat_stride + l * l_flat_stride;
1622 const dim_t blk_offset = l * l_blk_stride + blk;
1623 if (order_keep)
1624 wrap_qz(o[blk_offset], i[flat_off], alpha, beta);
1625 else
1626 wrap_qz(o[flat_off], i[blk_offset], alpha, beta);
1627 }
1628 if (order_keep) {
1629 // zero padding
1630 const auto pad_start = block + l * l_blk_stride;
1631 const auto pad_end = blksize + l * l_blk_stride;
1632 PRAGMA_OMP_SIMD()
                        for (int e = pad_start; e < pad_end; ++e) {
                            o[e] = 0;
                        }
1636 }
1637 }
1638 }
1639 };
1640
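        // off() is a local helper that maps (h0, h1, m0, m1, m2) to an offset
        // in either memory descriptor, dropping the unused middle dims for
        // smaller ndims; the innermost dim L is advanced inside ker via the
        // strides above. The blocked dim (H0 or H1, per blk_idx) is walked in
        // steps of blksize, with i_mult / o_mult turning the block counter
        // into a logical index on the plain side and a block index on the
        // blocked side.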
1641#define off(md, h0, h1, m0, m1, m2) \
1642 (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \
1643 : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \
1644 : ndims >= 4 \
1645 ? (md).blk_off(h0, h1, m2) \
1646 : /* ndims >= 3 ? */ (md).blk_off(h0, h1))
1647
1648 const int i_mult = order_keep ? blksize : 1;
1649 const int o_mult = order_keep ? 1 : blksize;
1650
1651 if (blk_idx == 0) {
1652 const dim_t BH0 = pdims[0] / blksize;
1653 parallel_nd(BH0, H1, M0, M1, M2,
1654 [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) {
1655 auto i = &input[off(
1656 input_d, bh0 * i_mult, h1, m0, m1, m2)];
1657 auto o = &output[off(
1658 output_d, bh0 * o_mult, h1, m0, m1, m2)];
1659 const int block
1660 = nstl::min<int>(blksize, H0 - bh0 * blksize);
1661 ker(i, o, block);
1662 });
1663 } else if (blk_idx == 1) {
1664 const dim_t BH1 = pdims[1] / blksize;
1665 parallel_nd(H0, BH1, M0, M1, M2,
1666 [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) {
1667 auto i = &input[off(
1668 input_d, h0, bh1 * i_mult, m0, m1, m2)];
1669 auto o = &output[off(
1670 output_d, h0, bh1 * o_mult, m0, m1, m2)];
1671 const int block
1672 = nstl::min<int>(blksize, H1 - bh1 * blksize);
1673 ker(i, o, block);
1674 });
1675 } else {
1676 assert(!"unimplemented");
1677 }
1678
1679#undef off
1680
1681 return status::success;
1682 }
1683};
1684
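// Plain <-> double-blocked reorder where two adjacent dimensions carry the
// inner blocks: AB for 3D..5D tensors and BC for 4D..6D tensors that start
// with a groups-like dimension G. H0/H1 are the two blocked dims, M0..M2 the
// remaining (typically spatial) dims, and the kernel copies one
// blksize_0 x blksize_1 tile at a time.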
1685template <SIMPLE_REORDER_TEMPL_DECL>
1686struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1687 typename utils::enable_if<tag_i == format_tag::any
1688 && (tag_traits<tag_o>::block_dims == bd::_AB
1689 || tag_traits<tag_o>::block_dims == bd::_BC)
1690 && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_AB,
1691 tag_traits<tag_o>::ndims >= 3
1692 && tag_traits<tag_o>::ndims <= 5)
1693 && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_BC,
1694 tag_traits<tag_o>::ndims >= 4
1695 && tag_traits<tag_o>::ndims <= 6)>::type> {
1696 PLAIN_TO_BLOCKED_IS_APPLICABLE();
1697
1698 GET_SCRATCHPAD_SIZE_ZERO();
1699
1700 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
1701 DECLARE_COMMON_PARAMS();
1702
1703 const auto &flat_d = order_keep ? input_d : output_d;
1704 const auto &dims = input_d.dims();
1705 const auto &pdims
1706 = order_keep ? output_d.padded_dims() : input_d.padded_dims();
1707
1708 constexpr int ndims = tag_traits<tag_o>::ndims;
1709
1710 static constexpr bool with_g = tag_traits<tag_o>::block_dims == bd::_BC;
1711 const dim_t G = with_g ? dims[0] : 1;
1712
1713 const dim_t H0 = dims[0 + with_g];
1714 const dim_t H1 = dims[1 + with_g];
1715
1716 const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1;
1717 const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1;
1718 const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1;
1719
1720 const dim_t h0_flat_stride = flat_d.blocking_desc().strides[with_g + 0];
1721 const dim_t h1_flat_stride = flat_d.blocking_desc().strides[with_g + 1];
1722 using namespace data_type;
1723 using namespace utils;
1724
1725 dim_t blksize_0 = -1;
1726 dim_t blksize_1 = -1;
1727 switch (tag_traits<tag_o>::inner_blks) {
1728 case ib::_4b4a:
1729 case ib::_4b4c:
1730 case ib::_4c4b:
1731 blksize_0 = 4;
1732 blksize_1 = 4;
1733 break;
1734 case ib::_8a8b:
1735 case ib::_8b8a:
1736 case ib::_8b8c:
1737 case ib::_8c8b:
1738 case ib::_2c8b4c:
1739 blksize_0 = 8;
1740 blksize_1 = 8;
1741 break;
1742 case ib::_16a16b:
1743 case ib::_16b16a:
1744 case ib::_16b16c:
1745 case ib::_16c16b:
1746 case ib::_8a16b2a:
1747 case ib::_4b16a4b:
1748 case ib::_8b16a2b:
1749 case ib::_8b16c2b:
1750 case ib::_4c16b4c:
1751 case ib::_8c16b2c:
1752 blksize_0 = 16;
1753 blksize_1 = 16;
1754 break;
1755 default: blksize_0 = -1; blksize_1 = -1;
1756 }
1757
1758 const dim_t NB_H0 = pdims[0 + with_g] / blksize_0;
1759 const dim_t NB_H1 = pdims[1 + with_g] / blksize_1;
1760
1761 constexpr bool f32bf16
1762 = one_of(type_i, f32, bf16) && one_of(type_o, f32, bf16);
1763
1764 auto wrap_qz_a1b0 = [=](data_t<type_o> &out, data_t<type_i> inp) {
1765 if (f32bf16)
1766 out = inp;
1767 else
1768 out = _qz_a1b0<type_i, type_o>()(inp);
1769 };
1770
1771 auto wrap_qz = [=](data_t<type_o> &out, data_t<type_i> inp, float alpha,
1772 float beta) {
1773 if (f32bf16)
1774 out = alpha * inp + (beta ? beta * out : 0);
1775 else
1776 out = _qz<type_i, type_o>()(inp, out, alpha, beta);
1777 };
1778
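        // ker handles one blksize_0 x blksize_1 tile: the plain side is
        // addressed with the h0/h1 flat strides, the blocked side through
        // AB_or_BC_blk_off, which computes the in-tile offset for the given
        // inner_blks pattern (including mixed layouts such as 8a16b2a).
        // When writing the blocked side, elements past block_h1 in each row
        // and whole rows past block_h0 are zero-filled.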
1779 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1780 const int block_h0, const int block_h1) {
1781#define blk_off AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>
1782 if (alpha == 1.0 && beta == 0.0) {
1783 for (int h0 = 0; h0 < block_h0; ++h0) {
1784 for (int h1 = 0; h1 < block_h1; ++h1) {
1785 const dim_t flat_off
1786 = h0 * h0_flat_stride + h1 * h1_flat_stride;
1787 if (order_keep)
1788 wrap_qz_a1b0(o[blk_off(h0, h1)], i[flat_off]);
1789 else
1790 wrap_qz_a1b0(o[flat_off], i[blk_off(h0, h1)]);
1791 }
1792 if (order_keep && block_h1 < blksize_1) {
1793 // zero padding
1794 PRAGMA_OMP_SIMD()
1795 for (int h1 = block_h1; h1 < blksize_1; h1++) {
1796 o[blk_off(h0, h1)] = 0;
1797 }
1798 }
1799 }
1800 if (order_keep && block_h0 < blksize_0) {
1801 // zero padding
1802 for (int h0 = block_h0; h0 < blksize_0; h0++) {
1803 PRAGMA_OMP_SIMD()
1804 for (int h1 = 0; h1 < blksize_1; ++h1) {
1805 o[blk_off(h0, h1)] = 0;
1806 }
1807 }
1808 }
1809 } else {
1810 for (int h0 = 0; h0 < block_h0; ++h0) {
1811 for (int h1 = 0; h1 < block_h1; ++h1) {
1812 const dim_t flat_off
1813 = h0 * h0_flat_stride + h1 * h1_flat_stride;
1814 if (order_keep)
1815 wrap_qz(o[blk_off(h0, h1)], i[flat_off], alpha,
1816 beta);
1817 else
1818 wrap_qz(o[flat_off], i[blk_off(h0, h1)], alpha,
1819 beta);
1820 }
1821 if (order_keep && block_h1 < blksize_1) {
1822 // zero padding
1823 PRAGMA_OMP_SIMD()
1824 for (int h1 = block_h1; h1 < blksize_1; h1++) {
1825 o[blk_off(h0, h1)] = 0;
1826 }
1827 }
1828 }
1829 if (order_keep && block_h0 < blksize_0) {
1830 // zero padding
1831 for (int h0 = block_h0; h0 < blksize_0; h0++) {
1832 PRAGMA_OMP_SIMD()
1833 for (int h1 = 0; h1 < blksize_1; ++h1) {
1834 o[blk_off(h0, h1)] = 0;
1835 }
1836 }
1837 }
1838 }
1839
1840#undef blk_off
1841 };
1842
1843 const int i_mult_0 = order_keep ? blksize_0 : 1;
1844 const int o_mult_0 = order_keep ? 1 : blksize_0;
1845
1846 const int i_mult_1 = order_keep ? blksize_1 : 1;
1847 const int o_mult_1 = order_keep ? 1 : blksize_1;
1848
1849#define off(md, g, h0, h1, m0, m1, m2) \
1850 (ndims >= 5 + with_g ? (md).blk_off<!with_g>(g, h0, h1, m0, m1, m2) \
1851 : ndims >= 4 + with_g \
1852 ? (md).blk_off<!with_g>(g, h0, h1, m1, m2) \
1853 : /* ndims >= 3 + with_g ? */ (md) \
1854 .blk_off<!with_g>(g, h0, h1, m2))
1855
1856 parallel_nd(G, NB_H0, NB_H1, M0, M1, M2,
1857 [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1,
1858 dim_t m2) {
1859 auto i = &input[off(input_d, g, i_mult_0 * nb_h0,
1860 i_mult_1 * nb_h1, m0, m1, m2)];
1861 auto o = &output[off(output_d, g, o_mult_0 * nb_h0,
1862 o_mult_1 * nb_h1, m0, m1, m2)];
1863 const int block_h0
1864 = nstl::min<int>(blksize_0, H0 - nb_h0 * blksize_0);
1865 const int block_h1
1866 = nstl::min<int>(blksize_1, H1 - nb_h1 * blksize_1);
1867 ker(i, o, block_h0, block_h1);
1868 });
1869
1870#undef off
1871
1872 return status::success;
1873 }
1874};
1875
1876/* generic and direct-copy reorders */
1877
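// Direct copy: both tensors are dense and describe the same physical layout,
// so the reorder degenerates to a flat element-wise loop. Work is split into
// 16-element blocks for balance211(); the sub-16 remainder is handled by the
// last thread. The four loop variants cover the common alpha/beta
// combinations with the cheapest applicable quantization functor.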
1878template <SIMPLE_REORDER_TEMPL_DECL>
1879struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1880 typename utils::enable_if<tag_i == format_tag::any
1881 && tag_o == format_tag::any
1882 && order_keep == fmt_order::any,
1883 spec::direct_copy>::type> {
1884 static bool is_applicable(const memory_desc_wrapper &input_d,
1885 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1886 return !input_d.has_runtime_dims_or_strides()
1887 && input_d.similar_to(output_d, true, false, 0)
1888 && input_d.is_dense() && output_d.is_dense()
1889 && simple_attr_check(attr, false, true);
1890 }
1891
1892 GET_SCRATCHPAD_SIZE_ZERO();
1893
1894 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
1895 DECLARE_COMMON_PARAMS();
1896
1897 input += input_d.blk_off(0);
1898 output += output_d.blk_off(0);
1899
1900 const size_t nelems = input_d.nelems();
1901
1902 constexpr int block_size = 16;
1903 const auto num_blocks = nelems / block_size;
1904 const auto rem_elems = nelems % block_size;
1905
1906 parallel(0, [&](const int ithr, const int nthr) {
1907 size_t start {0}, end {0};
1908 balance211(num_blocks, nthr, ithr, start, end);
1909 start = start * block_size;
1910 end = end * block_size;
1911
1912 if (alpha == 1.0 && beta == 0.0) {
1913 PRAGMA_OMP_SIMD()
1914 for (size_t e = start; e < end; ++e) {
1915 output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()(
1916 input[e]);
1917 }
1918 } else if (alpha == 1.0) {
1919 PRAGMA_OMP_SIMD()
1920 for (size_t e = start; e < end; ++e) {
1921 output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()(
1922 input[e], output[e], beta);
1923 }
1924 } else if (beta == 0.0) {
1925 PRAGMA_OMP_SIMD()
1926 for (size_t e = start; e < end; ++e) {
1927 output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()(
1928 input[e], alpha);
1929 }
1930 } else {
1931 PRAGMA_OMP_SIMD()
1932 for (size_t e = start; e < end; ++e) {
1933 output[e] = qz<data_t<type_i>, data_t<type_o>>()(
1934 input[e], output[e], alpha, beta);
1935 }
1936 }
1937
1938 if (rem_elems != 0 && ithr == nthr - 1) {
1939 if (alpha == 1.0 && beta == 0.0) {
1940 PRAGMA_OMP_SIMD()
1941 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1942 output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()(
1943 input[e]);
1944 }
1945 } else if (alpha == 1.0) {
1946 PRAGMA_OMP_SIMD()
1947 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1948 output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()(
1949 input[e], output[e], beta);
1950 }
1951 } else if (beta == 0.0) {
1952 PRAGMA_OMP_SIMD()
1953 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1954 output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()(
1955 input[e], alpha);
1956 }
1957 } else {
1958 PRAGMA_OMP_SIMD()
1959 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1960 output[e] = qz<data_t<type_i>, data_t<type_o>>()(
1961 input[e], output[e], alpha, beta);
1962 }
1963 }
1964 }
1965 });
1966 return status::success;
1967 }
1968};
1969
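// Like direct copy, but the layouts only have to match (and be dense) past
// the outermost dimension; dim 0 may differ in stride or padding between the
// two sides. The copy therefore walks dim-0 slices with separate input/output
// strides (is / os) and a flat inner loop within each slice.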
1970template <SIMPLE_REORDER_TEMPL_DECL>
1971struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1972 typename utils::enable_if<tag_i == format_tag::any
1973 && tag_o == format_tag::any
1974 && order_keep == fmt_order::any,
1975 spec::direct_copy_except_dim_0>::type> {
1976 static bool is_applicable(const memory_desc_wrapper &input_d,
1977 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1978 auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) {
1979 return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d);
1980 };
1981 return !input_d.has_runtime_dims_or_strides()
1982 && input_d.similar_to(output_d, true, false, 1)
1983 && is_dense_no_0(input_d) && is_dense_no_0(output_d)
1984 && simple_attr_check(attr, false, true);
1985 }
1986
1987 GET_SCRATCHPAD_SIZE_ZERO();
1988
1989 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
1990 DECLARE_COMMON_PARAMS();
1991 using namespace utils;
1992
1993 input += input_d.blk_off(0);
1994 output += output_d.blk_off(0);
1995
1996 const int N = input_d.dims()[0];
1997 const dim_t is = input_d.blocking_desc().strides[0];
1998 const dim_t os = output_d.blocking_desc().strides[0];
1999 const dim_t nelems_no_d0 = nelems_no_dim_0(input_d);
2000 const dim_t work_amount = N * nelems_no_d0;
2001
2002 if (alpha == 1.0 && beta == 0.0) {
2003 parallel(0, [&](const int ithr, const int nthr) {
2004 dim_t n {0}, dim1_s {0};
2005 dim_t start {0}, end {0};
2006 balance211(work_amount, nthr, ithr, start, end);
2007 nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
2008 while (start < end) {
2009 dim_t work_rem = end - start;
2010 dim_t dim1_e = std::min(dim1_s + work_rem, nelems_no_d0);
2011 PRAGMA_OMP_SIMD()
2012 for (dim_t e = dim1_s; e < dim1_e; ++e) {
2013 output[os * n + e]
2014 = _qz_a1b0<type_i, type_o>()(input[is * n + e]);
2015 }
2016 nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
2017 }
2018 });
2019 } else {
2020 parallel(0, [&](const int ithr, const int nthr) {
2021 dim_t n {0}, dim1_s {0};
2022 dim_t start {0}, end {0};
2023 balance211(work_amount, nthr, ithr, start, end);
2024 nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
2025 while (start < end) {
2026 dim_t work_rem = end - start;
2027 dim_t dim1_e = std::min(dim1_s + work_rem, nelems_no_d0);
2028 PRAGMA_OMP_SIMD()
2029 for (dim_t e = dim1_s; e < dim1_e; ++e) {
2030 output[os * n + e]
2031 = _qz<type_i, type_o>()(input[is * n + e],
2032 output[os * n + e], alpha, beta);
2033 }
2034 nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
2035 }
2036 });
2037 }
2038
2039 return status::success;
2040 }
2041
2042private:
2043 static dim_t nelems_no_dim_0(const memory_desc_wrapper &data_d) {
2044 const int ndims = data_d.ndims();
2045 if (ndims <= 1) return 1;
2046 return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1);
2047 }
2048
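    // Physical extent, in elements, of a single dim-0 slice including padding
    // and blocking; is_applicable() compares it with nelems_no_dim_0() to
    // establish that the tensor is dense outside dim 0.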
2049 static dim_t _size_no_dim_0(const memory_desc_wrapper &data_d) {
2050 dims_t blocks;
2051 data_d.compute_blocks(blocks);
2052
2053 const auto &blk = data_d.blocking_desc();
2054
2055 dim_t blk_size = 1;
2056 for (int iblk = 0; iblk < blk.inner_nblks; ++iblk)
2057 blk_size *= blk.inner_blks[iblk];
2058
2059 dim_t max_size = blk_size;
2060 for (int d = 1; d < data_d.ndims(); ++d) {
2061 max_size = nstl::max(max_size,
2062 data_d.padded_dims()[d] / blocks[d] * blk.strides[d]);
2063 }
2064
2065 return max_size;
2066 }
2067};
2068
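// Generic reference reorder: no layout assumptions beyond plain blocking
// descriptors on both sides. The logical element space is split as
// D_start x D_mask x D_rest so that the middle index dm addresses per-channel
// scales when a scale mask is set. Each element is dequantized with the
// source scale and zero point, optionally blended with beta, then requantized
// with the destination scale and zero point. Padded areas are zeroed up front
// because the element loop only visits logical positions.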
2069template <SIMPLE_REORDER_TEMPL_DECL>
2070struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
2071 typename utils::enable_if<tag_i == format_tag::any
2072 && tag_o == format_tag::any
2073 && order_keep == fmt_order::any,
2074 spec::reference>::type> {
2075 static bool is_applicable(const memory_desc_wrapper &input_d,
2076 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        /* Supported smask values look like 0x0...011..10...0,
         * i.e. the set bits must form a single contiguous run. */
2079 int src_scales_mask = -1;
2080 int dst_scales_mask = -1;
2081 CHECK(get_scales_mask(attr, &src_scales_mask, &dst_scales_mask));
2082
2083 for (auto smask : {src_scales_mask, dst_scales_mask}) {
2084 for (; smask > 0 && !(smask & 0x1); smask >>= 1)
2085 ;
2086 for (; smask > 0 && smask & 0x1; smask >>= 1)
2087 ;
2088 if (smask != 0) return false;
2089 }
2090
2091 using skip_mask_t = dnnl_primitive_attr::skip_mask_t;
2092 return input_d.is_blocking_desc() && output_d.is_blocking_desc()
2093 && !output_d.is_additional_buffer()
2094 && !input_d.is_additional_buffer()
2095 && attr->has_default_values(skip_mask_t::scales_runtime
2096 | skip_mask_t::zero_points_runtime
2097 | skip_mask_t::post_ops)
2098 && simple_po_check(attr);
2099 }
2100
2101 GET_SCRATCHPAD_SIZE_ZERO();
2102
2103 static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
2104 DECLARE_COMMON_PARAMS();
2105
        // This kernel is also used for tensors with multiple inner blocks,
        // for which generic zero padding must be applied up front.
        // TODO: apply zero padding inside parallel_nd()
2109 ctx.zero_pad_output(DNNL_ARG_TO);
2110
2111 parallel_nd(D_start, D_mask, D_rest,
2112 [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) {
2113 const float src_scale
2114 = src_scales[src_scales_mask == 0 ? 0 : dm];
2115 const float dst_scale
2116 = dst_scales[dst_scales_mask == 0 ? 0 : dm];
2117
2118 const size_t e = (ds * D_mask + dm) * D_rest + dr;
2119 const auto &i = input[input_d.off_l(e)];
2120 auto &o = output[output_d.off_l(e)];
2121
2122 float f = src_scale * ((float)i - src_zp);
2123 if (beta) f += beta * o;
2124 f = f * dst_scale + dst_zp;
2125 o = _qz_a1b0<data_type::f32, type_o>()(f);
2126 });
2127
2128 return status::success;
2129 }
2130};
2131
2132/* high level class declaration */
2133
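// simple_reorder_t is a thin primitive_t wrapper: pd_t::create() checks the
// data types, the attribute skip masks and the specialization's
// is_applicable(), books scratchpad space (reorder workspace plus precomputed
// dst scales when a per-dim scale mask is set), and execute() forwards to the
// matching simple_reorder_impl. These implementations are normally reached
// through the public reorder API rather than instantiated directly; an
// illustrative sketch (names are placeholders, see the public dnnl.hpp for
// the exact signatures):
//
//     dnnl::memory src {src_md, eng, src_ptr}, dst {dst_md, eng, dst_ptr};
//     dnnl::reorder::primitive_desc rpd {src, dst, attr};
//     dnnl::reorder {rpd}.execute(strm, src, dst);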
2134template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
2135struct simple_reorder_t : public primitive_t {
2136 struct pd_t : public cpu_reorder_pd_t {
2137 using cpu_reorder_pd_t::cpu_reorder_pd_t;
2138
2139 DECLARE_COMMON_PD_T("simple:any", simple_reorder_t);
2140
2141 private:
2142 static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
2143 const primitive_attr_t *attr, engine_t *src_engine,
2144 const memory_desc_t *src_md, engine_t *dst_engine,
2145 const memory_desc_t *dst_md) {
2146 using skip_mask_t = dnnl_primitive_attr::skip_mask_t;
2147 bool args_ok = src_md->data_type == type_i
2148 && dst_md->data_type == type_o
2149 && attr->has_default_values(skip_mask_t::scales_runtime
2150 | skip_mask_t::zero_points
2151 | skip_mask_t::zero_points_runtime
2152 | skip_mask_t::post_ops)
2153 && simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
2154 spec>::is_applicable(src_md, dst_md, attr);
2155 if (!args_ok) return status::invalid_arguments;
2156
2157 int mask = -1;
2158 bool is_set = false;
2159 CHECK(attr->scales_.get(DNNL_ARG_DST, &mask, &is_set));
2160 const memory_desc_wrapper input_d(src_md);
2161 if (input_d.has_runtime_dims_or_strides() && is_set && mask > 0)
2162 return status::unimplemented;
2163
2164 auto _pd = new pd_t(attr, src_engine->kind(), src_md,
2165 dst_engine->kind(), dst_md);
2166 if (_pd == nullptr) return status::out_of_memory;
2167 if (_pd->init(engine, src_engine, dst_engine) != status::success) {
2168 delete _pd;
2169 return status::unimplemented;
2170 }
2171
2172 const size_t scratchpad_sz_
2173 = simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
2174 spec>::get_scratchpad_size(src_md, dst_md);
2175 auto scratchpad = _pd->scratchpad_registry().registrar();
2176 scratchpad.book(memory_tracking::names::key_reorder_space,
2177 scratchpad_sz_, 1, 16);
2178
2179 if (is_set && mask > 0) {
2180 dim_t D_mask;
2181 _pd->get_D_values(input_d, mask, nullptr, &D_mask, nullptr);
2182 scratchpad.template book<float>(
2183 memory_tracking::names::
2184 key_reorder_precomputed_dst_scales,
2185 D_mask);
2186 }
2187
2188 _pd->init_scratchpad_md();
2189 return safe_ptr_assign(*reorder_pd, _pd);
2190 }
2191 friend dnnl::impl::impl_list_item_t;
2192 };
2193
2194 simple_reorder_t(const pd_t *apd) : primitive_t(apd) {}
2195
2196 status_t execute(const exec_ctx_t &ctx) const override {
2197 return simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::execute(
2198 pd(), ctx);
2199 }
2200
2201private:
2202 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
2203};
2204
2205#undef SIMPLE_REORDER_TEMPL_DECL
2206#undef SIMPLE_REORDER_TEMPL_CALL
2207
2208} // namespace cpu
2209} // namespace impl
2210} // namespace dnnl
2211
2212#endif
2213
2214// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
2215