1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_REORDER_SIMPLE_REORDER_HPP |
18 | #define CPU_REORDER_SIMPLE_REORDER_HPP |
19 | |
20 | #include <algorithm> |
21 | #include <assert.h> |
22 | |
23 | #include "common/bfloat16.hpp" |
24 | #include "common/c_types_map.hpp" |
25 | #include "common/dnnl_thread.hpp" |
26 | #include "common/math_utils.hpp" |
27 | #include "common/primitive.hpp" |
28 | #include "common/primitive_attr.hpp" |
29 | #include "common/tag_traits.hpp" |
30 | #include "common/type_helpers.hpp" |
31 | #include "common/utils.hpp" |
32 | |
33 | #include "cpu/cpu_primitive.hpp" |
34 | #include "cpu/reorder/cpu_reorder_pd.hpp" |
35 | |
36 | #include "cpu/simple_q10n.hpp" |
37 | |
38 | namespace dnnl { |
39 | namespace impl { |
40 | namespace cpu { |
41 | |
// Short aliases for the block-dimension and inner-block enumerations used
// with tag_traits<> lookups throughout this file.
using bd = block_dim_t;
using ib = inner_blk_t;

// Maps a data_type_t enum value to its C++ storage type.
template <impl::data_type_t type>
using data_t = typename prec_traits<type>::type;

// Quantization functor aliases specialized on the input/output data types:
// _qz_a1b0 is the alpha == 1, beta == 0 fast path (round + saturate only),
// _qz applies full alpha/beta scaling.
template <impl::data_type_t type_i, impl::data_type_t type_o>
using _qz_a1b0 = qz_a1b0<data_t<type_i>, data_t<type_o>>;

template <impl::data_type_t type_i, impl::data_type_t type_o>
using _qz = qz<data_t<type_i>, data_t<type_o>>;
53 | |
namespace fmt_order {
// Direction of the reorder relative to the template's <tag_i, tag_o> pair:
// `keep` reorders tag_i -> tag_o, `reverse` reorders tag_o -> tag_i, and
// `any` is used when the direction is irrelevant (aliases `keep`).
// `constexpr` (instead of plain `const`) makes the compile-time intent
// explicit for these header-scope constants; behavior is unchanged.
constexpr bool keep = true;
constexpr bool reverse = false;
constexpr bool any = keep;
} // namespace fmt_order
59 | |
namespace spec {
// Empty tag types used as the `spec` template parameter to select between
// simple_reorder_impl partial specializations.
struct direct_copy {};
struct direct_copy_except_dim_0 {};
struct reference {};
struct conv_req_comp {}; // {s8, u8: asymmetric quantization}
} // namespace spec
66 | |
// Template parameter list shared by every simple_reorder_impl
// specialization (in/out data types, in/out format tags, direction flag)
// and the matching argument list to forward it.
#define SIMPLE_REORDER_TEMPL_DECL \
    impl::data_type_t type_i, impl::format_tag_t tag_i, \
            impl::data_type_t type_o, impl::format_tag_t tag_o, \
            bool order_keep
#define SIMPLE_REORDER_TEMPL_CALL type_i, tag_i, type_o, tag_o, order_keep

// Common execute() prologue. Declares (for use by the kernels below):
// input/output buffers; input_d/output_d descriptor wrappers; src_scales
// (raw buffer) and dst_scales (precomputed over D_mask via the scratchpad);
// src/dst scale masks and scales_mask = max of the two; D_start/D_mask/
// D_rest; src_zp/dst_zp zero points; alpha = src_scale * dst_scale; and the
// sum scale beta from the primitive descriptor.
#define DECLARE_COMMON_PARAMS() \
    auto input = CTX_IN_MEM(const data_t<type_i> *, DNNL_ARG_FROM); \
    auto output = CTX_OUT_MEM(data_t<type_o> *, DNNL_ARG_TO); \
    const auto &scratchpad = ctx.get_scratchpad_grantor(); \
    MAYBE_UNUSED(scratchpad); \
    const auto input_d = ctx.memory_mdw(DNNL_ARG_FROM, pd->src_md()); \
    const auto output_d = ctx.memory_mdw(DNNL_ARG_TO, pd->dst_md()); \
    DEFINE_ARG_SCALES_BUFFER_ATTR(pd->attr(), src_scales, DNNL_ARG_FROM); \
    DEFINE_ARG_SCALES_BUFFER_ATTR(pd->attr(), dst_scales_, DNNL_ARG_TO); \
    int src_scales_mask, dst_scales_mask; \
    CHECK(get_scales_mask(pd->attr(), &src_scales_mask, &dst_scales_mask)); \
    int scales_mask = std::max(src_scales_mask, dst_scales_mask); \
    MAYBE_UNUSED(scales_mask); \
    dim_t D_start, D_mask, D_rest; \
    pd->get_D_values(input_d, scales_mask, &D_start, &D_mask, &D_rest); \
    const float *dst_scales = pd->precompute_scales( \
            scratchpad, pd->attr(), D_mask, dst_scales_); \
    MAYBE_UNUSED(dst_scales); \
    DEFINE_ZERO_POINT_VALUE_ATTR(pd->attr(), src_zp, DNNL_ARG_FROM); \
    DEFINE_ZERO_POINT_VALUE_ATTR(pd->attr(), dst_zp, DNNL_ARG_TO); \
    const float alpha = src_scales[0] * dst_scales[0]; \
    MAYBE_UNUSED(alpha); \
    const float beta = pd->beta(); \
    MAYBE_UNUSED(beta);

// Defines a get_scratchpad_size() that returns 0, for reorders that need no
// temporary memory.
#define GET_SCRATCHPAD_SIZE_ZERO() \
    static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, \
            const memory_desc_wrapper &output_d) { \
        return 0; \
    }
103 | |
/* specific reorders: common template */
// Primary template is intentionally empty: concrete reorders are provided
// only by the enable_if-guarded partial specializations below, so an
// unsupported <type, tag, spec> combination has no is_applicable()/execute()
// members and fails to compile if used.
template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
struct simple_reorder_impl {};
107 | |
108 | namespace { |
109 | inline bool simple_fmt_check(bool order_keep, impl::format_tag_t tag_i, |
110 | impl::format_tag_t tag_o, const memory_desc_wrapper &input_d, |
111 | const memory_desc_wrapper &output_d) { |
112 | if (input_d.has_runtime_dims_or_strides()) return false; |
113 | return input_d.matches_tag(order_keep ? tag_i : tag_o) |
114 | && output_d.matches_tag(order_keep ? tag_o : tag_i); |
115 | } |
116 | inline bool simple_po_check(const primitive_attr_t *attr) { |
117 | const auto &po = attr->post_ops_; |
118 | return po.len() == 0 || (po.len() == 1 && po.entry_[0].is_sum(false)); |
119 | } |
120 | inline status_t get_scales_mask( |
121 | const primitive_attr_t *attr, int *src_mask, int *dst_mask) { |
122 | const auto &s = attr->scales_; |
123 | if (src_mask) { |
124 | *src_mask = 0; |
125 | if (!s.get(DNNL_ARG_SRC).has_default_values()) |
126 | *src_mask = s.get(DNNL_ARG_SRC).mask_; |
127 | } |
128 | if (dst_mask) { |
129 | *dst_mask = 0; |
130 | if (!s.get(DNNL_ARG_DST).has_default_values()) |
131 | *dst_mask = s.get(DNNL_ARG_DST).mask_; |
132 | } |
133 | |
134 | // This is used in a check function. |
135 | if (*src_mask > 0 && *dst_mask > 0 && *dst_mask != *src_mask) |
136 | return status::invalid_arguments; |
137 | return status::success; |
138 | } |
139 | inline bool simple_attr_check(const primitive_attr_t *attr, |
140 | bool many_scales_support, bool sum_support) { |
141 | using smask_t = primitive_attr_t::skip_mask_t; |
142 | smask_t skip_mask = smask_t::scales_runtime; |
143 | if (sum_support) skip_mask = skip_mask | smask_t::post_ops; |
144 | if (!attr->has_default_values(skip_mask)) return false; |
145 | if (sum_support) simple_po_check(attr); |
146 | if (many_scales_support) return true; |
147 | int src_mask, dst_mask; |
148 | if (get_scales_mask(attr, &src_mask, &dst_mask) != status::success) |
149 | return false; |
150 | return src_mask == 0 && dst_mask == 0; |
151 | } |
152 | } // namespace |
153 | |
/* specific reorders: implementation */
// Weights reorder from any plain layout to a plain (g)(d)(h)w-io layout with
// s8 output, additionally filling the compensation buffers that int8
// convolution kernels expect to find appended after the weights data:
// s8s8 compensation (128 * negated per-oc weight sums) and/or
// asymmetric-src compensation (negated per-oc weight sums).
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                        && utils::one_of(tag_o, format_tag::wio,
                                format_tag::wigo, format_tag::hwio,
                                format_tag::hwigo, format_tag::dhwio,
                                format_tag::dhwigo),
                spec::conv_req_comp>::type> {
    // Applicability: plain input, output matching tag_o, at least one
    // compensation kind requested, compensation masks covering exactly
    // (g,)oc, scales either common or per-(g,)oc, f32/s8/bf16 -> s8.
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace data_type;
        using namespace utils;

        if (input_d.has_runtime_dims_or_strides()) return false;

        int src_scales_mask, dst_scales_mask;
        auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask);
        if (status != status::success) return false;
        int scales_mask = std::max(src_scales_mask, dst_scales_mask);

        static constexpr bool w_groups = one_of(
                tag_o, format_tag::wigo, format_tag::hwigo, format_tag::dhwigo);

        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool req_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        // A compensation mask must cover groups + oc (0x3) or oc only (0x1).
        auto mask_ok = [&](bool check, int mask) {
            return IMPLICATION(check, mask == (w_groups ? 0x3 : 0x1));
        };

        return simple_attr_check(attr, true, false)
                && output_d.matches_tag(tag_o) && input_d.is_plain()
                && (req_comp || req_asymmetric_comp)
                && mask_ok(req_comp, output_d.extra().compensation_mask)
                && mask_ok(req_asymmetric_comp,
                        output_d.extra().asymm_compensation_mask)
                && IMPLICATION(!w_groups, one_of(scales_mask, 0, 0x1))
                && IMPLICATION(w_groups, one_of(scales_mask, 0, 0x3))
                && one_of(input_d.data_type(), f32, s8, bf16)
                && output_d.data_type() == s8;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        // Which logical dimensions tag_o carries: groups, height, depth.
        static constexpr bool w_groups = utils::one_of(
                tag_o, format_tag::wigo, format_tag::hwigo, format_tag::dhwigo);
        static constexpr bool w_height
                = !utils::one_of(tag_o, format_tag::wio, format_tag::wigo);
        static constexpr bool w_depth
                = utils::one_of(tag_o, format_tag::dhwio, format_tag::dhwigo);

        const auto &dims = input_d.dims();

        // Absent dims collapse to 1 so a single loop nest handles 1d/2d/3d.
        const dim_t G = w_groups ? dims[0] : 1;
        const dim_t OC = dims[w_groups + 0];
        const dim_t IC = dims[w_groups + 1];
        const dim_t D = w_depth ? dims[w_groups + 2] : 1;
        const dim_t H = w_height ? dims[w_groups + w_depth + 2] : 1;
        const dim_t W = dims[w_groups + w_depth + w_height + 2];

        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool has_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        assert(req_comp || has_asymmetric_comp);

        // Extra multiplier some s8 kernels require on top of user scales.
        float adj_scale
                = (output_d.extra().flags & memory_extra_flags::scale_adjust)
                ? output_d.extra().scale_adjust
                : 1.f;

        // Compensation buffers live past the weights payload: first the
        // (optional) s8s8 compensation, then the (optional) asymmetric one.
        size_t offset = output_d.size() - output_d.additional_buffer_size();
        size_t comp_size = output_d.additional_buffer_size(
                memory_extra_flags::compensation_conv_s8s8);
        size_t zp_offset = offset + (req_comp ? comp_size : 0);
        int32_t *cp = req_comp ? reinterpret_cast<int32_t *>(output + offset)
                               : nullptr;
        int32_t *zp = has_asymmetric_comp
                ? reinterpret_cast<int32_t *>(output + zp_offset)
                : nullptr;

        // Stride pattern for indexing the (flattened) scales arrays:
        // 0 stride means "broadcast the single common scale".
        const bool per_oc = scales_mask & (1 << (w_groups + 0));
        const bool per_ic = scales_mask & (1 << (w_groups + 1));
        const size_t ic_stride = per_ic ? 1 : 0;
        const size_t oc_stride = per_oc ? per_ic ? IC : 1 : 0;

        parallel_nd(G, OC, [&](dim_t g, dim_t oc) {
            // Each (g, oc) owns one compensation cell; accumulate -sum(w).
            if (req_comp) cp[g * OC + oc] = 0;
            if (has_asymmetric_comp) zp[g * OC + oc] = 0;
            for_(dim_t ic = 0; ic < IC; ic++)
            for_(dim_t d = 0; d < D; d++)
            for_(dim_t h = 0; h < H; h++)
            for (dim_t w = 0; w < W; w++) {
                auto i = w_depth
                        ? input[input_d.blk_off<!w_groups>(g, oc, ic, d, h, w)]
                        : w_height ? input[input_d.blk_off<!w_groups>(
                                  g, oc, ic, h, w)]
                                   : input[input_d.blk_off<!w_groups>(
                                           g, oc, ic, w)];
                auto &o = w_depth
                        ? output[output_d.blk_off<!w_groups>(
                                g, oc, ic, d, h, w)]
                        : w_height ? output[output_d.blk_off<!w_groups>(
                                  g, oc, ic, h, w)]
                                   : output[output_d.blk_off<!w_groups>(
                                           g, oc, ic, w)];
                const size_t os_off
                        = (g * OC + oc) * oc_stride + ic * ic_stride;
                const float s = src_scales[src_scales_mask == 0 ? 0 : os_off];
                // NOTE: this `d` shadows the depth index above; the offsets
                // using the depth index were already computed at this point.
                const float d = dst_scales[dst_scales_mask == 0 ? 0 : os_off];

                o = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        i, s * adj_scale * d);
                if (req_comp) cp[g * OC + oc] -= (int32_t)o;
                if (has_asymmetric_comp) zp[g * OC + oc] -= (int32_t)o;
            }
            // s8s8 compensation carries the extra factor of 128.
            if (req_comp) cp[g * OC + oc] *= 128;
        });
        return status::success;
    }
};
282 | |
// Weights reorder from plain layouts into the 4o4i / 4i16o4i / 4i32o4i /
// 4i64o4i / 2i8o4i blocked families (0d..3d, grouped or not) used by int8
// convolution kernels, filling the s8s8 and/or asymmetric-src compensation
// buffers appended after the weights.
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<
                (utils::one_of(tag_i, format_tag::iwo, format_tag::oiw,
                         format_tag::wio)
                        && utils::one_of(tag_o, format_tag::OIw4i16o4i,
                                format_tag::OIw4i32o4i, format_tag::OIw4i64o4i,
                                format_tag::OIw2i8o4i, format_tag::OIw4o4i))
                        || (utils::one_of(tag_i, format_tag::oi, format_tag::io)
                                && utils::one_of(tag_o, format_tag::OI4i16o4i,
                                        format_tag::OI4i32o4i,
                                        format_tag::OI4i64o4i))
                        || (utils::one_of(
                                    tag_i, format_tag::goiw, format_tag::wigo)
                                && utils::one_of(tag_o, format_tag::gOIw4i16o4i,
                                        format_tag::gOIw2i8o4i,
                                        format_tag::gOIw4o4i))
                        || (utils::one_of(tag_i, format_tag::ihwo,
                                    format_tag::hwio, format_tag::oihw)
                                && utils::one_of(tag_o, format_tag::OIhw4i16o4i,
                                        format_tag::OIhw4i32o4i,
                                        format_tag::OIhw4i64o4i,
                                        format_tag::OIhw2i8o4i,
                                        format_tag::OIhw4o4i))
                        || (utils::one_of(tag_i, format_tag::idhwo,
                                    format_tag::dhwio, format_tag::oidhw)
                                && utils::one_of(tag_o,
                                        format_tag::OIdhw4i16o4i,
                                        format_tag::OIdhw4i32o4i,
                                        format_tag::OIdhw4i64o4i,
                                        format_tag::OIdhw2i8o4i,
                                        format_tag::OIdhw4o4i))
                        || (utils::one_of(
                                    tag_i, format_tag::goihw, format_tag::hwigo)
                                && utils::one_of(tag_o, format_tag::gOIhw4o4i,
                                        format_tag::gOIhw2i8o4i,
                                        format_tag::gOIhw4i16o4i))
                        || (utils::one_of(tag_i, format_tag::goidhw)
                                && (utils::one_of(tag_o,
                                        format_tag::gOIdhw4i16o4i,
                                        format_tag::gOIdhw2i8o4i,
                                        format_tag::gOIdhw4o4i))),
                spec::conv_req_comp>::type> {
    // Applicability mirrors the plain-output specialization above, but both
    // tags must match exactly (no format_tag::any on input).
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace format_tag;
        using namespace data_type;
        using namespace utils;

        if (input_d.has_runtime_dims_or_strides()) return false;

        int src_scales_mask, dst_scales_mask;
        auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask);
        if (status != status::success) return false;
        int scales_mask = std::max(src_scales_mask, dst_scales_mask);

        // Grouped iff tag_o is not one of the group-less output tags.
        const bool w_groups = !one_of(tag_o, OIw4i16o4i, OIw2i8o4i, OIw4o4i,
                OIhw4i16o4i, OIhw2i8o4i, OIhw4o4i, OIdhw4i16o4i, OIdhw2i8o4i,
                OIdhw4o4i, OI4i16o4i, OI4i32o4i, OI4i64o4i, OIw4i32o4i,
                OIw4i64o4i, OIhw4i32o4i, OIhw4i64o4i, OIdhw4i32o4i,
                OIdhw4i64o4i);

        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool req_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        auto mask_ok = [&](bool check, int mask) {
            return IMPLICATION(check, mask == (w_groups ? 0x3 : 0x1));
        };

        return simple_attr_check(attr, true, false)
                && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
                && (req_comp || req_asymmetric_comp)
                && mask_ok(req_comp, output_d.extra().compensation_mask)
                && mask_ok(req_asymmetric_comp,
                        output_d.extra().asymm_compensation_mask)
                && IMPLICATION(!w_groups, one_of(scales_mask, 0, 0x1))
                && IMPLICATION(w_groups, one_of(scales_mask, 0, 0x3))
                && one_of(input_d.data_type(), f32, s8, bf16)
                && output_d.data_type() == s8;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace format_tag;

        static constexpr bool w_groups = !utils::one_of(tag_o, OIw4o4i,
                OIw4i16o4i, OIhw4i16o4i, OIdhw4i16o4i, OIhw4o4i, OIw2i8o4i,
                OIhw2i8o4i, OIdhw2i8o4i, OIdhw4o4i, OI4i16o4i, OI4i32o4i,
                OI4i64o4i, OIw4i32o4i, OIw4i64o4i, OIhw4i32o4i, OIhw4i64o4i,
                OIdhw4i32o4i, OIdhw4i64o4i);

        // Spatial rank flags (int so they can be used in index arithmetic).
        constexpr int is_0d
                = utils::one_of(tag_o, OI4i16o4i, OI4i32o4i, OI4i64o4i);
        constexpr int is_1d
                = utils::one_of(tag_o, gOIw4i16o4i, OIw4i16o4i, gOIw2i8o4i,
                        OIw2i8o4i, gOIw4o4i, OIw4o4i, OIw4i32o4i, OIw4i64o4i);
        constexpr int is_3d = utils::one_of(tag_o, gOIdhw4i16o4i, OIdhw4i16o4i,
                gOIdhw2i8o4i, OIdhw2i8o4i, gOIdhw4o4i, OIdhw4o4i, OIdhw4i32o4i,
                OIdhw4i64o4i);
        // Block sizes are derived from the inner-block encoding of tag_o.
        constexpr dim_t icblksize = utils::one_of(tag_traits<tag_o>::inner_blks,
                                            ib::_4a4b, ib::_4b4c)
                ? 4
                : utils::one_of(tag_traits<tag_o>::inner_blks, ib::_2c8b4c,
                          ib::_2b8a4b)
                        ? 8
                        : 16;
        constexpr dim_t ocblksize
                = tag_traits<tag_o>::inner_blks == ib::_4b32a4b
                ? 32
                : tag_traits<tag_o>::inner_blks == ib::_4b64a4b ? 64
                                                                : icblksize;

        const auto &plain_d = order_keep ? input_d : output_d;
        const auto &dims = input_d.dims();
        const int ndims = input_d.ndims();
        const auto &pdims
                = order_keep ? output_d.padded_dims() : input_d.padded_dims();

        const dim_t G = w_groups ? dims[0] : 1;
        const dim_t OC = dims[w_groups + 0];
        const dim_t PADDED_OC = pdims[w_groups + 0];
        const dim_t NB_OC = pdims[w_groups + 0] / ocblksize;
        const dim_t IC = dims[w_groups + 1];
        const dim_t NB_IC = pdims[w_groups + 1] / icblksize;
        const dim_t D = is_3d ? dims[2 + w_groups] : 1;
        const dim_t H = is_1d || is_0d ? 1 : dims[2 + w_groups + is_3d];
        const dim_t W = is_0d ? 1 : dims[w_groups + is_3d + 3 - is_1d];

        // XXX: Currently user can pass a mask that has non-zero values in
        // dimensions that do not exist in a md. Since attributes are created
        // separately mask can't be validated.
        // This line truncates a given mask in range [0, 1 << ndims - 1]
        // TODO: Such masks can be either prohibited at pd creation step at
        // API level or checked by each implementation that relies on it.
        scales_mask &= (1 << ndims) - 1;

        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool has_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        assert(req_comp || has_asymmetric_comp);

        // Extra multiplier some s8 kernels require on top of user scales.
        float adj_scale
                = (output_d.extra().flags & memory_extra_flags::scale_adjust)
                ? output_d.extra().scale_adjust
                : 1.f;

        // Scale-array strides; 0 stride broadcasts a single common scale.
        // The nb_* variants step by whole blocks for the outer loops.
        const bool per_oc = scales_mask & (1 << (w_groups + 0));
        const bool per_ic = scales_mask & (1 << (w_groups + 1));
        const size_t ic_stride = per_ic ? 1 : 0;
        const size_t oc_stride = per_oc ? per_ic ? IC : 1 : 0;
        const size_t nb_ic_stride = (per_ic ? 1 : 0) * icblksize;
        const size_t nb_oc_stride = (per_oc ? per_ic ? IC : 1 : 0) * ocblksize;

        // This kernel is used primarily for tensors with multiple inner
        // blocks for which generic zero padding must be used.
        // TODO: apply zero padding inside parallel_nd()
        ctx.zero_pad_output(DNNL_ARG_TO);

        // Converts one ocblksize x icblksize tile, accumulating the negated
        // per-oc sums into the compensation cells `c` (x128 applied here)
        // and `zp` when requested.
        auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
                           int32_t *c, int32_t *zp, const float *s,
                           const float *d, const dim_t oc_block,
                           const dim_t ic_block) {
#define index AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>
            for_(dim_t ic = 0; ic < ic_block; ++ic)
            for (dim_t oc = 0; oc < oc_block; ++oc) {
                const auto plain_off
                        = oc * plain_d.blocking_desc().strides[w_groups + 0]
                        + ic * plain_d.blocking_desc().strides[w_groups + 1];
                const size_t os_off = oc * oc_stride + ic * ic_stride;
                const float src_scale = s[src_scales_mask == 0 ? 0 : os_off];
                const float dst_scale = d[dst_scales_mask == 0 ? 0 : os_off];
                out[index(oc, ic)] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        inp[plain_off], src_scale * adj_scale * dst_scale);
                if (req_comp) c[oc] -= (128 * (int32_t)(out[index(oc, ic)]));
                if (has_asymmetric_comp)
                    zp[oc] -= (int32_t)(out[index(oc, ic)]);
            }
#undef index
        };

        // Input advances by whole blocks, output by blocked units.
        constexpr dim_t i_mult_ic = icblksize;
        constexpr dim_t i_mult_oc = ocblksize;
        constexpr dim_t o_mult = 1;

        // Compensation buffers live past the weights payload: first the
        // (optional) s8s8 compensation, then the (optional) asymmetric one.
        size_t offset = output_d.size() - output_d.additional_buffer_size();
        size_t comp_size = output_d.additional_buffer_size(
                memory_extra_flags::compensation_conv_s8s8);
        size_t zp_offset = offset + (req_comp ? comp_size : 0);
        int32_t *cp = req_comp ? reinterpret_cast<int32_t *>(output + offset)
                               : nullptr;
        int32_t *zp = has_asymmetric_comp
                ? reinterpret_cast<int32_t *>(output + zp_offset)
                : nullptr;

        // Zero the compensation cells before the accumulation pass.
        parallel_nd(G * PADDED_OC, [&](dim_t i) {
            if (req_comp) cp[i] = 0;
            if (has_asymmetric_comp) zp[i] = 0;
        });

// Dispatches blk_off() by spatial rank.
#define wei_blk_off(md, g, o, i, d, h, w) \
    (is_0d ? (md).blk_off<!w_groups>(g, o, i) \
            : is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
                    : is_3d ? (md).blk_off<!w_groups>(g, o, i, d, h, w) \
                            : (md).blk_off<!w_groups>(g, o, i, h, w))
        parallel_nd(G, NB_OC, [&](dim_t g, dim_t O) {
            for_(dim_t I = 0; I < NB_IC; I++)
            for_(dim_t d = 0; d < D; d++)
            for_(dim_t h = 0; h < H; h++)
            for (dim_t w = 0; w < W; w++) {
                auto i = &input[wei_blk_off(
                        input_d, g, i_mult_oc * O, i_mult_ic * I, d, h, w)];
                auto o = &output[wei_blk_off(
                        output_d, g, o_mult * O, o_mult * I, d, h, w)];
                // Tail blocks may be partial in both oc and ic.
                const dim_t oc_block = nstl::min(ocblksize, OC - O * ocblksize);
                const dim_t ic_block = nstl::min(icblksize, IC - I * icblksize);
                dim_t _offset = (g * NB_OC + O) * ocblksize;
                dim_t os_nb_off
                        = (g * NB_OC + O) * nb_oc_stride + I * nb_ic_stride;
                const float *src_scales_ptr
                        = &src_scales[src_scales_mask == 0 ? 0 : os_nb_off];
                const float *dst_scales_ptr
                        = &dst_scales[dst_scales_mask == 0 ? 0 : os_nb_off];
                ker(i, o, (order_keep && req_comp) ? &cp[_offset] : nullptr,
                        (order_keep && has_asymmetric_comp) ? &zp[_offset]
                                                            : nullptr,
                        src_scales_ptr, dst_scales_ptr, oc_block, ic_block);
            }
        });

#undef wei_blk_off

        return status::success;
    }
};
523 | |
/* Asymmetric Blocking */
// Weights reorder from plain layouts into the Owi16o / Owhi16o blocked
// families (with or without groups). These formats are consumed by jit
// kernels with native s8 support, so only the asymmetric-src compensation
// buffer is produced; s8s8 compensation must not be requested.
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<(utils::one_of(tag_i, format_tag::iwo,
                                           format_tag::oiw, format_tag::wio)
                                          && utils::one_of(
                                                  tag_o, format_tag::Owi16o))
                        || (utils::one_of(
                                    tag_i, format_tag::goiw, format_tag::wigo)
                                && utils::one_of(tag_o, format_tag::gOwi16o))
                        || (utils::one_of(tag_i, format_tag::ihwo,
                                    format_tag::hwio, format_tag::oihw)
                                && utils::one_of(tag_o, format_tag::Owhi16o))
                        || (utils::one_of(
                                    tag_i, format_tag::goihw, format_tag::hwigo)
                                && utils::one_of(tag_o, format_tag::gOwhi16o)),
                spec::conv_req_comp>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace format_tag;
        using namespace data_type;
        using namespace utils;

        if (input_d.has_runtime_dims_or_strides()) return false;

        const bool w_groups = !one_of(tag_o, Owi16o, Owhi16o);

        // Current formats are only used in jit kernels that natively
        // support s8 instructions, hence, there is no need for signed
        // compensation.
        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;

        const bool req_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        auto mask_ok = [&](bool check, int mask) {
            const int c_mask = 0x1,
                      g_mask = 0x3; // mask for i/o-channel and ngroups
            return IMPLICATION(check, mask == (w_groups ? g_mask : c_mask));
        };

        return simple_attr_check(attr, true, false)
                && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
                && mask_ok(req_asymmetric_comp,
                        output_d.extra().asymm_compensation_mask)
                && one_of(input_d.data_type(), f32, s8, bf16)
                && output_d.data_type() == s8 && !req_comp;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace format_tag;

        static constexpr bool w_groups = !utils::one_of(tag_o, Owi16o, Owhi16o);
        constexpr int is_1d = utils::one_of(tag_o, Owi16o, gOwi16o);
        const bool is_3d = false; // TODO once enabled

        // Only the output channel is blocked, by 16.
        constexpr dim_t oc_blksize = 16;

        const auto &plain_d = order_keep ? input_d : output_d;
        const auto &dims = input_d.dims();
        const auto &pdims
                = order_keep ? output_d.padded_dims() : input_d.padded_dims();

        const dim_t G = w_groups ? dims[0] : 1;
        const dim_t OC = dims[w_groups + 0];
        const dim_t NB_OC = pdims[w_groups + 0] / oc_blksize;
        const dim_t IC = dims[w_groups + 1];

        const dim_t D = is_3d ? dims[2 + w_groups] : 1;
        const dim_t H = is_1d ? 1 : dims[2 + w_groups + is_3d];
        const dim_t W = dims[w_groups + is_3d + 3 - is_1d];

        const bool has_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        // Extra multiplier some s8 kernels require on top of user scales.
        float adj_scale
                = (output_d.extra().flags & memory_extra_flags::scale_adjust)
                ? output_d.extra().scale_adjust
                : 1.f;

        // Converts one 16-wide oc block, zero-filling the padded tail and
        // accumulating the negated sums into the asymmetric compensation.
        // NOTE(review): s/d are indexed per-oc unconditionally here, unlike
        // the other kernels which collapse the index when the corresponding
        // scales mask is 0 — presumably callers guarantee per-oc scale
        // buffers when a mask is set; verify against precompute_scales.
        auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
                           int32_t *zp, const float *s, const float *d,
                           const dim_t oc_block) {
            for (dim_t oc = 0; oc < oc_block; ++oc) {
                const auto plain_off
                        = oc * plain_d.blocking_desc().strides[w_groups + 0];
                out[oc] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        inp[plain_off], s[oc] * adj_scale * d[oc]);
                if (has_asymmetric_comp) zp[oc] -= (int32_t)(out[oc]);
            }
            // fill memory with '0' in case of padded channel dimensions
            for (dim_t oc = oc_block; oc < oc_blksize; ++oc) {
                out[oc] = 0;
            }
        };

        // Asymmetric compensation lives past the weights payload.
        size_t offset = output_d.size() - output_d.additional_buffer_size();
        int32_t *zp = has_asymmetric_comp
                ? reinterpret_cast<int32_t *>(output + offset)
                : nullptr;

        if (has_asymmetric_comp) {
            parallel_nd(G * NB_OC * oc_blksize, [&](dim_t i) { zp[i] = 0; });
        }

// Dispatches blk_off() by spatial rank.
#define wei_blk_off(md, g, o, i, d, h, w) \
    (is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
            : is_3d ? (md).blk_off<!w_groups>(g, o, i, d, h, w) \
                    : (md).blk_off<!w_groups>(g, o, i, h, w))

        parallel_nd(G, NB_OC, [&](dim_t g, dim_t O) {
            for_(dim_t I = 0; I < IC; I++)
            for_(dim_t d = 0; d < D; d++)
            for_(dim_t h = 0; h < H; h++)
            for (dim_t w = 0; w < W; w++) {
                auto i = &input[wei_blk_off(
                        input_d, g, oc_blksize * O, I, d, h, w)];
                auto o = &output[wei_blk_off(output_d, g, O, I, d, h, w)];
                // The last oc block may be partial.
                const dim_t oc_block
                        = nstl::min(oc_blksize, OC - O * oc_blksize);
                dim_t _offset = (g * NB_OC + O) * oc_blksize;
                int32_t *zp_ptr = (order_keep && has_asymmetric_comp)
                        ? &zp[_offset]
                        : nullptr;
                const float *src_scales_ptr
                        = &src_scales[src_scales_mask == 0 ? 0 : _offset];
                const float *dst_scales_ptr
                        = &dst_scales[dst_scales_mask == 0 ? 0 : _offset];
                ker(i, o, zp_ptr, src_scales_ptr, dst_scales_ptr, oc_block);
            }
        });

#undef wei_blk_off

        return status::success;
    }
};
665 | |
666 | /* Asymmetric Blocking */ |
667 | template <SIMPLE_REORDER_TEMPL_DECL> |
668 | struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, |
669 | typename utils::enable_if<(utils::one_of(tag_i, format_tag::iwo, |
670 | format_tag::oiw, format_tag::wio) |
671 | && utils::one_of(tag_o, |
672 | format_tag::OwI16o4i, |
673 | format_tag::OIw16i16o4i)) |
674 | || (utils::one_of( |
675 | tag_i, format_tag::goiw, format_tag::wigo) |
676 | && utils::one_of(tag_o, format_tag::gOwI16o4i, |
677 | format_tag::gOIw16i16o4i)) |
678 | || (utils::one_of(tag_i, format_tag::ihwo, |
679 | format_tag::hwio, format_tag::oihw) |
680 | && utils::one_of(tag_o, format_tag::OhwI16o4i, |
681 | format_tag::OIhw16i16o4i)) |
682 | || (utils::one_of( |
683 | tag_i, format_tag::goihw, format_tag::hwigo) |
684 | && utils::one_of(tag_o, format_tag::gOhwI16o4i, |
685 | format_tag::gOIhw16i16o4i)) |
686 | || (utils::one_of(tag_i, format_tag::idhwo, |
687 | format_tag::dhwio, format_tag::oidhw) |
688 | && utils::one_of(tag_o, format_tag::OdhwI16o4i, |
689 | format_tag::OIdhw16i16o4i)) |
690 | || (utils::one_of(tag_i, format_tag::goidhw) |
691 | && utils::one_of(tag_o, format_tag::gOdhwI16o4i, |
692 | format_tag::gOIdhw16i16o4i)), |
693 | spec::conv_req_comp>::type> { |
694 | static bool is_applicable(const memory_desc_wrapper &input_d, |
695 | const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { |
696 | using namespace format_tag; |
697 | using namespace data_type; |
698 | using namespace utils; |
699 | |
700 | if (input_d.has_runtime_dims_or_strides()) return false; |
701 | |
702 | int src_scales_mask, dst_scales_mask; |
703 | auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask); |
704 | if (status != status::success) return false; |
705 | int scales_mask = std::max(src_scales_mask, dst_scales_mask); |
706 | |
707 | const bool w_groups = !one_of(tag_o, OwI16o4i, OIw16i16o4i, OhwI16o4i, |
708 | OIhw16i16o4i, OdhwI16o4i, OIdhw16i16o4i); |
709 | |
710 | // Current formats are only used in jit kernels that natively |
711 | // support s8 instructions, hence, there is no need for signed |
712 | // compensation. |
713 | const bool req_comp = output_d.extra().flags |
714 | & memory_extra_flags::compensation_conv_s8s8; |
715 | |
716 | const bool req_asymmetric_comp = output_d.extra().flags |
717 | & memory_extra_flags::compensation_conv_asymmetric_src; |
718 | |
719 | auto mask_ok = [&](bool check, int mask) { |
720 | const int c_mask = 0x1, |
721 | g_mask = 0x3; // mask for o-channel and ngroups |
722 | return IMPLICATION(check, mask == (w_groups ? g_mask : c_mask)); |
723 | }; |
724 | |
725 | return simple_attr_check(attr, true, false) |
726 | && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o) |
727 | && mask_ok(req_asymmetric_comp, |
728 | output_d.extra().asymm_compensation_mask) |
729 | && one_of(input_d.data_type(), f32, s8, bf16) |
730 | && IMPLICATION(!w_groups, one_of(scales_mask, 0, 0x1)) |
731 | && IMPLICATION(w_groups, one_of(scales_mask, 0, 0x3)) |
732 | && output_d.data_type() == s8 && !req_comp; |
733 | } |
734 | |
735 | GET_SCRATCHPAD_SIZE_ZERO(); |
736 | |
// Execute a plain -> blocked s8 weights reorder into a layout with 16-wide
// oc blocks and a 4-wide innermost ic block (e.g. OwI16o4i, OIhw16i16o4i and
// their grouped counterparts). Applies per-oc src/dst scales (and the
// optional scale_adjust factor) via qz_b0 and, when the output carries
// compensation_conv_asymmetric_src, accumulates per-oc zero-point
// compensation into the extra buffer appended after the weights data.
static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
    DECLARE_COMMON_PARAMS();
    using namespace format_tag;

    // The tag_o list below holds the non-grouped tags, so w_groups is true
    // for the grouped variants (leading 'g' dimension present).
    static constexpr bool w_groups
            = !utils::one_of(tag_o, OwI16o4i, OIw16i16o4i, OhwI16o4i,
                    OIhw16i16o4i, OdhwI16o4i, OIdhw16i16o4i);
    constexpr int is_1d = utils::one_of(
            tag_o, OwI16o4i, gOwI16o4i, OIw16i16o4i, gOIw16i16o4i);
    const bool is_3d = utils::one_of(
            tag_o, OdhwI16o4i, gOdhwI16o4i, OIdhw16i16o4i, gOIdhw16i16o4i);

    // oc is always blocked by 16; the total ic blocking is 64 for the
    // double-blocked ...16i16o4i... tags and 4 when only the innermost
    // 4i block is present, derived from the output tag's inner blocks.
    constexpr dim_t oc_blksize = 16;
    constexpr dim_t ic_blksize
            = utils::one_of(tag_traits<tag_o>::inner_blks, ib::_16b16a4b,
                      ib::_16c16b4c)
            ? 64
            : utils::one_of(
                      tag_traits<tag_o>::inner_blks, ib::_16a4b, ib::_16b4c)
                    ? 4
                    : 1;
    assert(ic_blksize != 1);

    const auto &plain_d = order_keep ? input_d : output_d;
    const auto &dims = input_d.dims();
    // Padded dims of the blocked side define how many blocks exist.
    const auto &pdims
            = order_keep ? output_d.padded_dims() : input_d.padded_dims();

    const dim_t G = w_groups ? dims[0] : 1;
    const dim_t OC = dims[w_groups + 0];
    const dim_t NB_OC = pdims[w_groups + 0] / oc_blksize;
    const dim_t IC = dims[w_groups + 1];
    const dim_t NB_IC = pdims[w_groups + 1] / ic_blksize;

    const dim_t D = is_3d ? dims[2 + w_groups] : 1;
    const dim_t H = is_1d ? 1 : dims[2 + w_groups + is_3d];
    const dim_t W = dims[w_groups + is_3d + 3 - is_1d];

    const bool has_asymmetric_comp = output_d.extra().flags
            & memory_extra_flags::compensation_conv_asymmetric_src;

    float adj_scale
            = (output_d.extra().flags & memory_extra_flags::scale_adjust)
            ? output_d.extra().scale_adjust
            : 1.f;

    // This kernel is used primarily for tensors with multiple inner
    // blocks for which generic zero padding must be used.
    // TODO: apply zero padding inside parallel_nd()
    ctx.zero_pad_output(DNNL_ARG_TO);

    // Quantize one (ic_block x oc_block) tile from the plain layout into
    // the blocked one. When zp is non-null, subtract each quantized value
    // from the per-oc-lane zero-point compensation.
    auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
                       int32_t *zp, const float *s, const float *d,
                       const dim_t oc_block, const dim_t ic_block) {
        for_(dim_t ic = 0; ic < ic_block; ++ic)
        for (dim_t oc = 0; oc < oc_block; ++oc) {
            const auto plain_off
                    = oc * plain_d.blocking_desc().strides[w_groups + 0]
                    + ic * plain_d.blocking_desc().strides[w_groups + 1];
            auto index = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(
                    oc, ic);
            out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                    inp[plain_off], s[oc] * adj_scale * d[oc]);

            if (has_asymmetric_comp) zp[oc] -= (int32_t)(out[index]);
        }
    };

    // Zero-point compensation lives in the "additional buffer" appended
    // after the weights data of the output memory.
    size_t offset = output_d.size() - output_d.additional_buffer_size();
    int32_t *zp = has_asymmetric_comp
            ? reinterpret_cast<int32_t *>(output + offset)
            : nullptr;

    // Compensation is accumulated with -=, so it must start from zero.
    if (has_asymmetric_comp) {
        parallel_nd(G * NB_OC * oc_blksize, [&](dim_t i) { zp[i] = 0; });
    }

// Dispatch blk_off by spatial rank (1d / 2d / 3d).
#define wei_blk_off(md, g, o, i, d, h, w) \
    (is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
            : is_3d ? (md).blk_off<!w_groups>(g, o, i, d, h, w) \
                    : (md).blk_off<!w_groups>(g, o, i, h, w))

    parallel_nd(G, NB_OC, [&](dim_t g, dim_t O) {
        for_(dim_t I = 0; I < NB_IC; I++)
        for_(dim_t d = 0; d < D; d++)
        for_(dim_t h = 0; h < H; h++)
        for (dim_t w = 0; w < W; w++) {
            // Plain side is addressed in elements, blocked side in blocks.
            auto i = &input[wei_blk_off(
                    input_d, g, oc_blksize * O, ic_blksize * I, d, h, w)];
            auto o = &output[wei_blk_off(output_d, g, O, I, d, h, w)];
            // Tail sizes of the current oc/ic block.
            const dim_t oc_block
                    = nstl::min(oc_blksize, OC - O * oc_blksize);
            const dim_t ic_block
                    = nstl::min(ic_blksize, IC - I * ic_blksize);
            // Per-oc offset into scales and compensation arrays.
            dim_t _offset = (g * NB_OC + O) * oc_blksize;
            int32_t *zp_ptr = (order_keep && has_asymmetric_comp)
                    ? &zp[_offset]
                    : nullptr;
            const float *src_scales_ptr
                    = &src_scales[src_scales_mask == 0 ? 0 : _offset];
            const float *dst_scales_ptr
                    = &dst_scales[dst_scales_mask == 0 ? 0 : _offset];
            ker(i, o, zp_ptr, src_scales_ptr, dst_scales_ptr, oc_block,
                    ic_block);
        }
    });

#undef wei_blk_off

    return status::success;
}
848 | }; |
849 | |
// Plain 2D/3D (ab/ba, abc/acb) -> double-blocked s8 layouts
// BA16a{16,32,48,64}b4a / aCB16b{16,32,48,64}c4b. Supports s8s8
// compensation (-128 * column sum) and asymmetric-src zero-point
// compensation; both are stored in extra buffers appended after the data.
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<
                (utils::one_of(tag_i, format_tag::ab, format_tag::ba,
                         format_tag::abc, format_tag::acb)
                        && utils::one_of(tag_o, format_tag::BA16a16b4a,
                                format_tag::BA16a32b4a, format_tag::BA16a48b4a,
                                format_tag::BA16a64b4a, format_tag::aCB16b16c4b,
                                format_tag::aCB16b32c4b,
                                format_tag::aCB16b48c4b,
                                format_tag::aCB16b64c4b)),
                spec::conv_req_comp>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace format_tag;
        using namespace data_type;
        using namespace utils;

        if (input_d.has_runtime_dims_or_strides()) return false;

        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool req_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        const auto ndims = input_d.ndims();
        // A compensation mask must cover every dimension except the
        // reduced one (ndims - 2): all-ones minus bit (ndims - 2).
        auto mask_ok = [&](bool check, int mask) {
            return IMPLICATION(
                    check, mask == (1 << ndims) - 1 - (1 << (ndims - 2)));
        };

        int src_scales_mask, dst_scales_mask;
        auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask);
        if (status != status::success) return false;
        int scales_mask = std::max(src_scales_mask, dst_scales_mask);
        // Number of elements addressed by the scales mask; only a single
        // common scale (D_mask == 1) is supported by this kernel.
        const size_t D_mask
                = array_product(input_d.dims(), math::ilog2q(scales_mask + 1));

        return simple_attr_check(attr, true, false)
                && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
                && mask_ok(req_comp, output_d.extra().compensation_mask)
                && mask_ok(req_asymmetric_comp,
                        output_d.extra().asymm_compensation_mask)
                && one_of(input_d.data_type(), f32, s8, bf16, f16)
                && output_d.data_type() == s8 && D_mask == 1;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace format_tag;

        // {[batch_dim][d0][d1], [batch_dim][d1][d0]} -> [batch_dim][D1][D0][16][D1_blksize][4]
        // 2D: batch_dim - none, d0 <-> a, d1 <-> b
        // 3D: batch_dim <-> a, d0 <-> b, d1 <-> c
        // D0 is always blocked by 16x4 = 64; D1 blocking (16/32/48/64) is
        // derived from the output tag's inner blocks.
        constexpr dim_t D0_blksize = 64;
        constexpr dim_t D1_blksize
                = (utils::one_of(tag_traits<tag_o>::inner_blks, ib::_16a64b4a,
                          ib::_16b64c4b))
                ? 64
                : (utils::one_of(tag_traits<tag_o>::inner_blks, ib::_16a48b4a,
                          ib::_16b48c4b))
                        ? 48
                        : (utils::one_of(tag_traits<tag_o>::inner_blks,
                                  ib::_16a32b4a, ib::_16b32c4b))
                                ? 32
                                : (utils::one_of(tag_traits<tag_o>::inner_blks,
                                          ib::_16a16b4a, ib::_16b16c4b))
                                        ? 16
                                        : 1;
        assert(D1_blksize != 1);

        const auto &plain_d = order_keep ? input_d : output_d;
        const auto &dims = input_d.dims();
        const auto ndims = input_d.ndims();
        // Padded dims of the blocked side define how many blocks exist.
        const auto &pdims
                = order_keep ? output_d.padded_dims() : input_d.padded_dims();

        const dim_t batch_dim = ndims > 2 ? dims[ndims - 3] : 1;
        const dim_t D0dim = dims[ndims - 2];
        const dim_t NB_D0dim = pdims[ndims - 2] / D0_blksize;
        const dim_t D1dim = dims[ndims - 1];
        const dim_t NB_D1dim = pdims[ndims - 1] / D1_blksize;
        assert(pdims[ndims - 1] == NB_D1dim * D1_blksize);

        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool has_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        float adj_scale
                = (output_d.extra().flags & memory_extra_flags::scale_adjust)
                ? output_d.extra().scale_adjust
                : 1.f;

        // Quantize one (d0_block x d1_block) tile with a single common
        // scale (s[0], d[0] -- is_applicable() guarantees D_mask == 1),
        // accumulate compensation per d1 lane, and zero-fill the padded
        // remainder of the (D0_blksize x D1_blksize) block.
        auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
                           int32_t *cp, int32_t *zp, const float *s,
                           const float *d, const int d0_block,
                           const int d1_block) {
            for (int d0 = 0; d0 < d0_block; ++d0) {
                for (int d1 = 0; d1 < d1_block; ++d1) {
                    const auto plain_off
                            = d0 * plain_d.blocking_desc().strides[ndims - 2]
                            + d1 * plain_d.blocking_desc().strides[ndims - 1];
                    auto index
                            = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(
                                    d0, d1);
                    out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                            inp[plain_off], s[0] * adj_scale * d[0]);

                    auto o = static_cast<int32_t>(out[index]);
                    if (req_comp) cp[d1] -= (128 * o);
                    if (has_asymmetric_comp) zp[d1] -= o;
                }
                // Pad the d1 tail of this row (quantized zeros).
                for (int d1 = d1_block; d1 < D1_blksize; ++d1) {
                    auto index
                            = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(
                                    d0, d1);
                    out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                            0, s[0] * adj_scale * d[0]);
                }
            }

            // Pad the remaining d0 rows of the block.
            for_(int d0 = d0_block; d0 < D0_blksize; ++d0)
            for (int d1 = 0; d1 < D1_blksize; ++d1) {
                auto index = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(
                        d0, d1);
                out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        0, s[0] * adj_scale * d[0]);
            }
        };

        // Compensation buffers follow the weights data: first the s8s8
        // compensation (if requested), then the zero-point compensation.
        const auto w_d = order_keep ? output_d : input_d;
        size_t offset = w_d.size() - w_d.additional_buffer_size();
        size_t comp_size = output_d.additional_buffer_size(
                memory_extra_flags::compensation_conv_s8s8);
        size_t zp_offset = offset + (req_comp ? comp_size : 0);
        int32_t *cp = req_comp ? reinterpret_cast<int32_t *>(output + offset)
                               : nullptr;
        int32_t *zp = has_asymmetric_comp
                ? reinterpret_cast<int32_t *>(output + zp_offset)
                : nullptr;

        // Compensation is accumulated with -=, so it must start from zero.
        if (has_asymmetric_comp || req_comp) {
            parallel_nd(batch_dim * NB_D1dim * D1_blksize, [&](dim_t i) {
                if (req_comp) cp[i] = 0;
                if (has_asymmetric_comp) zp[i] = 0;
            });
        }

// 3D layouts carry a leading batch dimension; 2D ones do not.
#define get_blk_off(md, batch, d0, d1) \
    (ndims == 3 ? (md).blk_off((batch), (d0), (d1)) : (md).blk_off((d0), (d1)))

        parallel_nd(batch_dim, NB_D1dim, [&](dim_t batch, dim_t D1) {
            for (int D0 = 0; D0 < NB_D0dim; D0++) {
                // Plain side addressed in elements, blocked side in blocks.
                auto i = &input[get_blk_off(
                        input_d, batch, D0_blksize * D0, D1_blksize * D1)];
                auto o = &output[get_blk_off(output_d, batch, D0, D1)];
                const dim_t d0_block
                        = nstl::min(D0_blksize, D0dim - D0 * D0_blksize);
                const dim_t d1_block
                        = nstl::min(D1_blksize, D1dim - D1 * D1_blksize);
                // Per-d1-lane offset into the compensation buffers.
                dim_t _offset = batch * NB_D1dim * D1_blksize + D1 * D1_blksize;
                int32_t *zp_ptr = (order_keep && has_asymmetric_comp)
                        ? &zp[_offset]
                        : nullptr;
                const float *src_scales_ptr
                        = &src_scales[src_scales_mask == 0 ? 0 : _offset];
                const float *dst_scales_ptr
                        = &dst_scales[dst_scales_mask == 0 ? 0 : _offset];
                ker(i, o, (order_keep && req_comp) ? &cp[_offset] : nullptr,
                        zp_ptr, src_scales_ptr, dst_scales_ptr, d0_block,
                        d1_block);
            }
        });

#undef get_blk_off

        return status::success;
    }
};
1032 | |
// Depthwise s8 weights reorder: goiw/wigo and goihw/hwigo with
// oc == ic == 1, into group-blocked Goiw{4,8,16}g / Goihw{4,8,16}g.
// Only instantiated when s8s8 and/or asymmetric-src compensation is
// required; compensation is stored in extra buffers after the data.
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<
                (utils::one_of(tag_i, format_tag::goiw, format_tag::wigo)
                        && utils::one_of(tag_o, format_tag::Goiw16g,
                                format_tag::Goiw8g, format_tag::Goiw4g))
                        || (utils::one_of(
                                    tag_i, format_tag::goihw, format_tag::hwigo)
                                && utils::one_of(tag_o, format_tag::Goihw16g,
                                        format_tag::Goihw8g,
                                        format_tag::Goihw4g)),
                spec::conv_req_comp>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace data_type;
        using namespace utils;

        if (input_d.has_runtime_dims_or_strides()) return false;

        int src_scales_mask, dst_scales_mask;
        auto status = get_scales_mask(attr, &src_scales_mask, &dst_scales_mask);
        if (status != status::success) return false;
        int scales_mask = std::max(src_scales_mask, dst_scales_mask);

        const dim_t g = input_d.dims()[0];
        const dim_t oc = input_d.dims()[1];
        const dim_t ic = input_d.dims()[2];

        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool req_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;
        int s8s8_comp_mask = output_d.extra().compensation_mask;
        int zp_comp_mask = output_d.extra().asymm_compensation_mask;
        int comp_mask = std::max(s8s8_comp_mask, zp_comp_mask);

        // Number of elements addressed by the compensation mask.
        const size_t D_mask
                = array_product(input_d.dims(), math::ilog2q(comp_mask + 1));

        return order_keep && oc == 1 && ic == 1 // depth-wise case
                && simple_attr_check(attr, true, false)
                // this specialization exists solely to produce compensation
                && (req_comp || req_asymmetric_comp)
                && IMPLICATION(req_comp && req_asymmetric_comp,
                        output_d.extra().compensation_mask
                                == output_d.extra().asymm_compensation_mask)
                && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
                && IMPLICATION(
                        req_comp, one_of(D_mask, (size_t)1, (size_t)g * oc))
                && one_of(scales_mask, 0, 0x3)
                && one_of(input_d.data_type(), f32, s8, bf16)
                && output_d.data_type() == s8;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        constexpr bool is_1d
                = utils::one_of(tag_i, format_tag::goiw, format_tag::wigo);
        // Group-block size is encoded in the output tag (4g/8g/16g).
        constexpr dim_t blksize
                = utils::one_of(tag_o, format_tag::Goihw4g, format_tag::Goiw4g)
                ? 4
                : utils::one_of(tag_o, format_tag::Goihw8g, format_tag::Goiw8g)
                        ? 8
                        : 16;

        const auto &dims = input_d.dims();
        const auto &pdims = output_d.padded_dims();
        const dim_t G = dims[0];
        const dim_t Gp = pdims[0];
        const dim_t OC = dims[1];
        const dim_t IC = dims[2];
        const dim_t H = is_1d ? 1 : dims[3];
        const dim_t W = dims[4 - is_1d];
        // If the output has padding/extra space, the group tail of each
        // block must be explicitly zeroed.
        const bool zero_padding_needed = !output_d.is_dense();

        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool has_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        assert(req_comp || has_asymmetric_comp);

        float adj_scale
                = (output_d.extra().flags & memory_extra_flags::scale_adjust)
                ? output_d.extra().scale_adjust
                : 1.f;

        // Quantize g_block consecutive groups into one inner block;
        // scales are indexed per group (times OC; OC == 1 here).
        auto ker_out = [&](const data_t<type_i> *inp, data_t<type_o> *out,
                               const float *src_scales, const float *dst_scales,
                               const dim_t g_block) {
            PRAGMA_OMP_SIMD()
            for (dim_t g = 0; g < g_block; g++) {
                const auto i_off = g * input_d.blocking_desc().strides[0];
                const float src_scale
                        = src_scales[src_scales_mask == 0 ? 0 : g * OC];
                const float dst_scale
                        = dst_scales[dst_scales_mask == 0 ? 0 : g * OC];
                out[g] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        inp[i_off], src_scale * adj_scale * dst_scale);
            }
        };

        /* Note: having separate kernels for s8 and zero-point fixes a
         * compiler-generated bug which results in seg-fault. */
        // s8s8 compensation: subtract 128 * quantized value per group.
        auto ker_s8 = [&](const data_t<type_o> *out, int32_t *cp,
                              const dim_t g_block) {
            PRAGMA_OMP_SIMD()
            for (dim_t g = 0; g < g_block; g++) {
                cp[g * OC] -= 128 * (int32_t)(out[g]);
            }
        };
        // Zero-point compensation: subtract the quantized value per group.
        auto ker_zp = [&](const data_t<type_o> *out, int32_t *zp,
                              const dim_t g_block) {
            PRAGMA_OMP_SIMD()
            for (dim_t g = 0; g < g_block; g++) {
                zp[g * OC] -= (int32_t)(out[g]);
            }
        };

        // Compensation buffers follow the weights data: first the s8s8
        // compensation (if requested), then the zero-point compensation.
        size_t offset = output_d.size() - output_d.additional_buffer_size();
        size_t comp_size = output_d.additional_buffer_size(
                memory_extra_flags::compensation_conv_s8s8);
        size_t zp_offset = offset + (req_comp ? comp_size : 0);
        int32_t *cp = req_comp ? reinterpret_cast<int32_t *>(output + offset)
                               : nullptr;
        int32_t *zp = has_asymmetric_comp
                ? reinterpret_cast<int32_t *>(output + zp_offset)
                : nullptr;

        // Compensation is accumulated with -=, so it must start from zero.
        parallel_nd((Gp / blksize) * OC, [&](dim_t ib) {
            PRAGMA_OMP_SIMD()
            for (dim_t i = 0; i < blksize; i++) {
                if (req_comp) cp[ib * blksize + i] = 0;
                if (has_asymmetric_comp) zp[ib * blksize + i] = 0;
            }
        });

#define wei_blk_off(md, g, o, i, h, w) \
    (is_1d ? (md).blk_off(g, o, i, w) : (md).blk_off(g, o, i, h, w))

        parallel_nd(Gp / blksize, OC, [&](dim_t gb, dim_t O) {
            for (dim_t I = 0; I < IC; I++) {
                for_(dim_t h = 0; h < H; h++)
                for (dim_t w = 0; w < W; w++) {
                    // Tail size of the current group block.
                    const dim_t g_block = nstl::min(G - gb * blksize, blksize);
                    auto inp = &input[wei_blk_off(
                            input_d, gb * blksize, O, I, h, w)];
                    const auto out
                            = &output[wei_blk_off(output_d, gb, O, I, h, w)];
                    // Offset into scales / compensation arrays.
                    dim_t offset = gb * blksize + O;
                    const float *src_scales_ptr
                            = &src_scales[src_scales_mask == 0 ? 0 : offset];
                    const float *dst_scales_ptr
                            = &dst_scales[dst_scales_mask == 0 ? 0 : offset];

                    ker_out(inp, out, src_scales_ptr, dst_scales_ptr, g_block);
                    if (req_comp) ker_s8(out, &cp[offset], g_block);
                    if (has_asymmetric_comp) ker_zp(out, &zp[offset], g_block);

                    if (zero_padding_needed) {
                        PRAGMA_OMP_SIMD()
                        for (int off = g_block; off < blksize; off++)
                            out[off] = 0;
                    }
                }
            }
        });

#undef wei_blk_off

        return status::success;
    }
};
1208 | |
1209 | /* bf16 reorders */ |
/* bf16 reorders */
// f32 -> bf16 blocked weights reorder ((g)oihw -> 16i16o / 8i16o2i /
// 8o16i2o families). Each 16x16 (oc x ic) tile is first assembled in a
// per-thread f32 scratchpad (with zero padding of tails) and then
// converted to bf16 in a single cvt_float_to_bfloat16() call.
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<(
                (tag_i == format_tag::goihw || tag_i == format_tag::oihw)
                && (tag_o == format_tag::gOIhw16i16o
                        || tag_o == format_tag::OIhw16i16o
                        || tag_o == format_tag::gOIhw8i16o2i
                        || tag_o == format_tag::OIhw8i16o2i
                        || tag_o == format_tag::gOIhw8o16i2o
                        || tag_o == format_tag::OIhw8o16i2o
                        || tag_o == format_tag::gIOhw8o16i2o
                        || tag_o == format_tag::IOhw8o16i2o)
                && type_i == data_type::f32
                && type_o == data_type::bf16)>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace data_type;

        if (input_d.has_runtime_dims_or_strides()) return false;

        return order_keep && input_d.matches_tag(tag_i)
                && output_d.matches_tag(tag_o) && input_d.data_type() == f32
                && output_d.data_type() == bf16 && attr->has_default_values();
    }

    static size_t get_scratchpad_size(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d) {
        // One blksize x blksize f32 tile per thread.
        const dim_t blksize = 16;
        return sizeof(float) * blksize * blksize * dnnl_get_max_threads();
    }

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace format_tag;

        static constexpr bool w_groups = tag_i == goihw;
        const dim_t blksize = 16;
        const int sblk = 2; // size of the innermost 2i / 2o sub-block

        const auto &plain_d = input_d;
        const auto &dims = input_d.dims();
        const auto &pdims = output_d.padded_dims();

        const dim_t G = w_groups ? dims[0] : 1;
        const dim_t OC = dims[w_groups + 0];
        const dim_t NB_OC = pdims[w_groups + 0] / blksize;
        const dim_t IC = dims[w_groups + 1];
        const dim_t NB_IC = pdims[w_groups + 1] / blksize;
        const dim_t H = dims[w_groups + 2];
        const dim_t W = dims[w_groups + 3];

        const size_t wsp_size = blksize * blksize;
        float *wspace = scratchpad.template get<float>(
                memory_tracking::names::key_reorder_space);

        // Position of element (ic, oc) inside a blocked tile, depending
        // on the output tag's inner-block order.
        auto index = [&](dim_t ic, dim_t oc) -> dim_t {
            if (utils::one_of(tag_o, gOIhw16i16o, OIhw16i16o))
                return (ic * blksize + oc);
            else if (utils::one_of(tag_o, gOIhw8i16o2i, OIhw8i16o2i))
                return ((ic / sblk) * blksize * sblk + sblk * oc + ic % sblk);
            else if (utils::one_of(tag_o, gOIhw8o16i2o, gIOhw8o16i2o,
                             OIhw8o16i2o, IOhw8o16i2o))
                return ((oc / sblk) * blksize * sblk + sblk * ic + oc % sblk);
            else
                assert(!"Invalid weight format" );
            return dim_t(0);
        };

        // Copy the valid (curr_ic_block x curr_oc_block) area from the
        // plain layout into the blocked tile; zero-fill the rest.
        auto ker = [&](const data_t<type_i> *inp, data_t<type_i> *out,
                           const dim_t curr_oc_block, const dim_t oc_block,
                           const dim_t curr_ic_block, const dim_t ic_block) {
            dim_t ic = 0;
            for (ic = 0; ic < curr_ic_block; ++ic) {
                dim_t oc = 0;
                for (oc = 0; oc < curr_oc_block; ++oc) {
                    const auto plain_off
                            = oc * plain_d.blocking_desc().strides[w_groups + 0]
                            + ic
                                    * plain_d.blocking_desc()
                                              .strides[w_groups + 1];
                    out[index(ic, oc)] = inp[plain_off];
                }
                for (/* continue */; oc < oc_block; ++oc) {
                    out[index(ic, oc)] = (data_t<type_i>)0;
                }
            }
            for (/* continue */; ic < ic_block; ++ic) {
                for (dim_t oc = 0; oc < oc_block; ++oc) {
                    out[index(ic, oc)] = (data_t<type_i>)0;
                }
            }
        };

        // Plain side is addressed in elements, blocked side in blocks.
        constexpr int i_mult = blksize;
        constexpr int o_mult = 1;

        parallel_nd_ext(0, G, NB_OC, NB_IC, H, W,
                [&](int ithr, int, dim_t g, dim_t O, dim_t I, dim_t h,
                        dim_t w) {
                    float *_wspace = wspace + wsp_size * ithr;
                    auto i = &input[input_d.blk_off<!w_groups>(
                            g, i_mult * O, i_mult * I, h, w)];
                    auto o = &output[output_d.blk_off<!w_groups>(
                            g, o_mult * O, o_mult * I, h, w)];
                    const dim_t oc_block = nstl::min(blksize, OC - O * blksize);
                    const dim_t ic_block = nstl::min(blksize, IC - I * blksize);
                    // Assemble the f32 tile, then convert it to bf16.
                    ker(i, _wspace, oc_block, blksize, ic_block, blksize);
                    cvt_float_to_bfloat16(o, _wspace, wsp_size);
                });

        return status::success;
    }
};
1323 | |
1324 | template <SIMPLE_REORDER_TEMPL_DECL> |
1325 | struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, |
1326 | typename utils::enable_if<(tag_i == format_tag::nchw |
1327 | && tag_o == format_tag::nChw16c) |
1328 | && type_i == data_type::f32 |
1329 | && type_o == data_type::bf16>::type> { |
1330 | static bool is_applicable(const memory_desc_wrapper &input_d, |
1331 | const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { |
1332 | using namespace data_type; |
1333 | |
1334 | if (input_d.has_runtime_dims_or_strides()) return false; |
1335 | |
1336 | return input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o) |
1337 | && input_d.data_type() == f32 && output_d.data_type() == bf16 |
1338 | && attr->has_default_values(); |
1339 | } |
1340 | |
1341 | static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, |
1342 | const memory_desc_wrapper &output_d) { |
1343 | const size_t blksize = 16; |
1344 | const size_t W = input_d.dims()[3]; |
1345 | return sizeof(float) * blksize * W * dnnl_get_max_threads(); |
1346 | } |
1347 | |
1348 | static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { |
1349 | DECLARE_COMMON_PARAMS(); |
1350 | |
1351 | const dim_t blksize = 16; |
1352 | |
1353 | const auto &flat_d = input_d; |
1354 | const auto &dims = input_d.dims(); |
1355 | const auto &pdims = output_d.padded_dims(); |
1356 | |
1357 | const dim_t C = dims[1]; |
1358 | const dim_t H = dims[2]; |
1359 | const dim_t W = dims[3]; |
1360 | |
1361 | const dim_t wsp_size = W * blksize; |
1362 | float *wspace = scratchpad.template get<float>( |
1363 | memory_tracking::names::key_reorder_space); |
1364 | |
1365 | auto ker = [&](const data_t<type_i> *i, data_t<type_i> *o, |
1366 | const dim_t curr_c_block, const dim_t c_block) { |
1367 | for (dim_t w = 0; w < W; ++w) { |
1368 | dim_t c = 0; |
1369 | for (c = 0; c < curr_c_block; ++c) { |
1370 | const ptrdiff_t flat_off = 0 |
1371 | + c * flat_d.blocking_desc().strides[1] |
1372 | + w * flat_d.blocking_desc().strides[3]; |
1373 | o[w * blksize + c] = i[flat_off]; |
1374 | } |
1375 | for (/* continue */; c < c_block; ++c) { |
1376 | o[w * blksize + c] = (data_t<type_i>)0; |
1377 | } |
1378 | } |
1379 | }; |
1380 | |
1381 | constexpr int i_c_mult = blksize; |
1382 | constexpr int o_c_mult = 1; |
1383 | |
1384 | parallel_nd_ext(0, dims[0], pdims[1] / blksize, H, |
1385 | [&](int ithr, int, dim_t n, dim_t nb_c, dim_t h) { |
1386 | float *_wspace = wspace + wsp_size * ithr; |
1387 | auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)]; |
1388 | auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)]; |
1389 | const dim_t c_block |
1390 | = nstl::min(blksize, C - nb_c * blksize); |
1391 | ker(i, _wspace, c_block, blksize); |
1392 | cvt_float_to_bfloat16(o, _wspace, wsp_size); |
1393 | }); |
1394 | |
1395 | return status::success; |
1396 | } |
1397 | }; |
1398 | |
1399 | /* reorders with tail support */ |
1400 | |
1401 | template <SIMPLE_REORDER_TEMPL_DECL> |
1402 | struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, |
1403 | typename utils::enable_if<false |
1404 | || (utils::one_of( |
1405 | tag_i, format_tag::nCdhw4c, format_tag::nCdhw8c) |
1406 | && tag_o == format_tag::nCdhw16c) |
1407 | || (utils::one_of(tag_i, format_tag::nChw4c, format_tag::nChw8c) |
1408 | && tag_o == format_tag::nChw16c) |
1409 | || (utils::one_of(tag_i, format_tag::nCw4c, format_tag::nCw8c) |
1410 | && tag_o == format_tag::nCw16c)>::type> { |
1411 | static bool is_applicable(const memory_desc_wrapper &input_d, |
1412 | const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { |
1413 | return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d) |
1414 | && simple_attr_check(attr, false, true); |
1415 | } |
1416 | |
1417 | GET_SCRATCHPAD_SIZE_ZERO(); |
1418 | |
1419 | static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) { |
1420 | DECLARE_COMMON_PARAMS(); |
1421 | using namespace format_tag; |
1422 | |
1423 | constexpr int is_1d = utils::one_of(tag_i, nCw4c, nCw8c); |
1424 | constexpr int is_3d = utils::one_of(tag_i, nCdhw4c, nCdhw8c); |
1425 | |
1426 | constexpr dim_t blksize_i |
1427 | = tag_traits<tag_i>::inner_blks == ib::_4b ? 4 : 8; |
1428 | constexpr dim_t blksize_16 = 16; |
1429 | |
1430 | constexpr dim_t ic_mult = order_keep ? blksize_16 / blksize_i : 1; |
1431 | constexpr dim_t oc_mult = order_keep ? 1 : blksize_16 / blksize_i; |
1432 | |
1433 | const auto &dims = input_d.dims(); |
1434 | const auto &pdims |
1435 | = order_keep ? output_d.padded_dims() : input_d.padded_dims(); |
1436 | |
1437 | const auto &d_i = order_keep ? input_d : output_d; |
1438 | const auto stride_C_in_blk_i = d_i.blocking_desc().strides[1]; |
1439 | |
1440 | const dim_t C = dims[1]; |
1441 | const dim_t D = is_3d ? dims[2] : 1; |
1442 | const dim_t H = is_1d ? 1 : dims[2 + is_3d]; |
1443 | const dim_t W = dims[3 + is_3d - is_1d]; |
1444 | |
1445 | auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o, |
1446 | const int block) { |
1447 | const int nb = utils::div_up(block, blksize_i); |
1448 | if (alpha == 1.0 && beta == 0.0) { |
1449 | for (int b = 0; b < nb; ++b) { |
1450 | const ptrdiff_t i_off |
1451 | = b * (order_keep ? stride_C_in_blk_i : blksize_i); |
1452 | const ptrdiff_t o_off |
1453 | = b * (order_keep ? blksize_i : stride_C_in_blk_i); |
1454 | const int block_i |
1455 | = nstl::min(blksize_i, block - b * blksize_i); |
1456 | for (int c = 0; c < block_i; ++c) { |
1457 | o[o_off + c] = _qz_a1b0<type_i, type_o>()(i[i_off + c]); |
1458 | } |
1459 | if (b + 1 == nb) { |
1460 | // zero padding |
1461 | const auto pad_size = order_keep |
1462 | ? blksize_16 - ((nb - 1) * blksize_i) |
1463 | : blksize_i; |
1464 | const auto pad_start = block_i + o_off; |
1465 | const auto pad_end = pad_size + o_off; |
1466 | PRAGMA_OMP_SIMD() |
1467 | for (int i = pad_start; i < pad_end; i++) { |
1468 | o[i] = 0; |
1469 | } |
1470 | } |
1471 | } |
1472 | } else { |
1473 | for (int b = 0; b < nb; ++b) { |
1474 | const ptrdiff_t i_off |
1475 | = b * (order_keep ? stride_C_in_blk_i : blksize_i); |
1476 | const ptrdiff_t o_off |
1477 | = b * (order_keep ? blksize_i : stride_C_in_blk_i); |
1478 | const int block_i |
1479 | = nstl::min(blksize_i, block - b * blksize_i); |
1480 | for (int c = 0; c < block_i; ++c) { |
1481 | o[o_off + c] = _qz<type_i, type_o>()( |
1482 | i[i_off + c], o[o_off + c], alpha, beta); |
1483 | } |
1484 | if (b + 1 == nb) { |
1485 | // zero padding |
1486 | const auto pad_size = order_keep |
1487 | ? blksize_16 - ((nb - 1) * blksize_i) |
1488 | : blksize_i; |
1489 | const auto pad_start = block_i + o_off; |
1490 | const auto pad_end = pad_size + o_off; |
1491 | PRAGMA_OMP_SIMD() |
1492 | for (int i = pad_start; i < pad_end; i++) { |
1493 | o[i] = 0; |
1494 | } |
1495 | } |
1496 | } |
1497 | } |
1498 | }; |
1499 | |
1500 | #define data_blk_off(md, n, c, d, h, w) \ |
1501 | (is_1d ? (md).blk_off(n, c, w) \ |
1502 | : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w)) |
1503 | |
1504 | parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W, |
1505 | [&](dim_t n, dim_t nb_c, dim_t d, dim_t h, dim_t w) { |
1506 | auto i = &input[data_blk_off( |
1507 | input_d, n, ic_mult * nb_c, d, h, w)]; |
1508 | auto o = &output[data_blk_off( |
1509 | output_d, n, oc_mult * nb_c, d, h, w)]; |
1510 | const int block |
1511 | = nstl::min(blksize_16, C - nb_c * blksize_16); |
1512 | ker(i, o, block); |
1513 | }); |
1514 | |
1515 | #undef data_blk_off |
1516 | |
1517 | return status::success; |
1518 | } |
1519 | }; |
1520 | |
// Shared is_applicable() for plain <-> blocked reorders: the blocked side
// must match tag_o, the opposite side may be any plain layout; order_keep
// selects which side is the blocked one. (Comments cannot go inside the
// macro body: a // comment would swallow the line-continuation backslash.)
#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \
    static bool is_applicable(const memory_desc_wrapper &input_d, \
            const memory_desc_wrapper &output_d, \
            const primitive_attr_t *attr) { \
        return !input_d.has_runtime_dims_or_strides() \
                && simple_attr_check(attr, false, true) \
                && (order_keep ? output_d.matches_tag(tag_o) \
                                && input_d.is_plain() \
                               : input_d.matches_tag(tag_o) \
                                && output_d.is_plain()); \
    }
1532 | |
/* Reorder between a plain layout and a layout blocked along exactly one
 * dimension (A or B) for tensors with 3..6 logical dims.
 * order_keep == true: plain -> blocked; false: blocked -> plain. */
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                && (tag_traits<tag_o>::block_dims == bd::_A
                        || tag_traits<tag_o>::block_dims == bd::_B)
                && tag_traits<tag_o>::ndims >= 3
                && tag_traits<tag_o>::ndims <= 6>::type> {
    PLAIN_TO_BLOCKED_IS_APPLICABLE();

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        // flat_d is the plain-layout side, block_d the blocked-layout side.
        const auto &flat_d = order_keep ? input_d : output_d;
        const auto &block_d = order_keep ? output_d : input_d;
        const dims_t &dims = input_d.dims();
        const dims_t &pdims = block_d.padded_dims();

        const int ndims = tag_traits<tag_o>::ndims;
        // Index of the blocked dimension: 0 for _A, 1 for _B.
        const int blk_idx = tag_traits<tag_o>::block_dims == bd::_A ? 0 : 1;

        // H0/H1: the two outermost dims (one of them is the blocked one);
        // M0..M2: middle dims, set to 1 when absent for smaller ndims;
        // L: the innermost dim, iterated contiguously by ker().
        const dim_t H0 = dims[0];
        const dim_t H1 = dims[1];
        const dim_t M0 = ndims == 6 ? dims[ndims - 4] : 1;
        const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1;
        const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1;
        const dim_t L = dims[ndims - 1];
        // Strides of the innermost dim on each side, plus the stride of the
        // blocked dim within the flat layout (used to gather/scatter one
        // inner block).
        const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1];
        const dim_t l_flat_stride = flat_d.blocking_desc().strides[ndims - 1];
        const dim_t blk_flat_stride = flat_d.blocking_desc().strides[blk_idx];
        using namespace data_type;
        using namespace utils;

        // Inner block size implied by tag_o. NOTE(review): every tag not
        // listed is assumed 16-wide; relies on is_applicable() admitting
        // only 4/8/16-wide single-blocked tags -- confirm against the
        // registered tag list.
        dim_t blksize = -1;
        switch (tag_traits<tag_o>::inner_blks) {
            case ib::_4a:
            case ib::_4b: blksize = 4; break;
            case ib::_8a:
            case ib::_8b: blksize = 8; break;
            default: blksize = 16;
        }

        // f32 <-> bf16 pairs bypass the quantization helpers and use plain
        // conversion/scaling instead.
        constexpr bool f32bf16
                = one_of(type_i, f32, bf16) && one_of(type_o, f32, bf16);

        auto wrap_qz_a1b0 = [=](data_t<type_o> &out, data_t<type_i> inp) {
            if (f32bf16)
                out = inp;
            else
                out = _qz_a1b0<type_i, type_o>()(inp);
        };

        auto wrap_qz = [=](data_t<type_o> &out, data_t<type_i> inp, float alpha,
                               float beta) {
            if (f32bf16)
                out = alpha * inp + (beta ? beta * out : 0);
            else
                out = _qz<type_i, type_o>()(inp, out, alpha, beta);
        };

        // Copies one inner block (`block` <= blksize valid elements of the
        // blocked dim) for all L innermost elements. When writing the
        // blocked layout it also zero-fills the padded tail
        // [block, blksize). Two branches: fast path (alpha==1, beta==0)
        // and the general scaled path.
        auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o, int block) {
            if (alpha == 1.0 && beta == 0.0) {
                for (int l = 0; l < L; ++l) {
                    for (int blk = 0; blk < block; ++blk) {
                        const dim_t flat_off
                                = blk * blk_flat_stride + l * l_flat_stride;
                        const dim_t blk_offset = l * l_blk_stride + blk;
                        if (order_keep) {
                            wrap_qz_a1b0(o[blk_offset], i[flat_off]);
                        } else {
                            wrap_qz_a1b0(o[flat_off], i[blk_offset]);
                        }
                    }
                    if (order_keep) {
                        // zero padding
                        // (NB: the loop index shadows the input pointer `i`;
                        // only `o` is written here so this is benign)
                        const auto pad_start = block + l * l_blk_stride;
                        const auto pad_end = blksize + l * l_blk_stride;
                        PRAGMA_OMP_SIMD()
                        for (int i = pad_start; i < pad_end; ++i) {
                            o[i] = 0;
                        }
                    }
                }
            } else {
                // Same traversal as above, but with alpha/beta scaling.
                for (int l = 0; l < L; ++l) {
                    for (int blk = 0; blk < block; ++blk) {
                        const dim_t flat_off
                                = blk * blk_flat_stride + l * l_flat_stride;
                        const dim_t blk_offset = l * l_blk_stride + blk;
                        if (order_keep)
                            wrap_qz(o[blk_offset], i[flat_off], alpha, beta);
                        else
                            wrap_qz(o[flat_off], i[blk_offset], alpha, beta);
                    }
                    if (order_keep) {
                        // zero padding
                        const auto pad_start = block + l * l_blk_stride;
                        const auto pad_end = blksize + l * l_blk_stride;
                        PRAGMA_OMP_SIMD()
                        for (int i = pad_start; i < pad_end; ++i) {
                            o[i] = 0;
                        }
                    }
                }
            }
        };

/* Element address for (h0, h1, m0, m1, m2), dropping the middle dims that
 * do not exist for smaller ndims. */
#define off(md, h0, h1, m0, m1, m2) \
    (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \
                : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \
                             : ndims >= 4 \
                                    ? (md).blk_off(h0, h1, m2) \
                                    : /* ndims >= 3 ? */ (md).blk_off(h0, h1))

        // The blocked side indexes the blocked dim in units of blocks; the
        // flat side in elements.
        const int i_mult = order_keep ? blksize : 1;
        const int o_mult = order_keep ? 1 : blksize;

        if (blk_idx == 0) {
            // Blocked along dim 0: parallelize over block count of H0.
            const dim_t BH0 = pdims[0] / blksize;
            parallel_nd(BH0, H1, M0, M1, M2,
                    [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) {
                        auto i = &input[off(
                                input_d, bh0 * i_mult, h1, m0, m1, m2)];
                        auto o = &output[off(
                                output_d, bh0 * o_mult, h1, m0, m1, m2)];
                        // Last block may be partial (tail of H0).
                        const int block
                                = nstl::min<int>(blksize, H0 - bh0 * blksize);
                        ker(i, o, block);
                    });
        } else if (blk_idx == 1) {
            // Blocked along dim 1: parallelize over block count of H1.
            const dim_t BH1 = pdims[1] / blksize;
            parallel_nd(H0, BH1, M0, M1, M2,
                    [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) {
                        auto i = &input[off(
                                input_d, h0, bh1 * i_mult, m0, m1, m2)];
                        auto o = &output[off(
                                output_d, h0, bh1 * o_mult, m0, m1, m2)];
                        const int block
                                = nstl::min<int>(blksize, H1 - bh1 * blksize);
                        ker(i, o, block);
                    });
        } else {
            assert(!"unimplemented" );
        }

#undef off

        return status::success;
    }
};
1684 | |
/* Reorder between a plain layout and a layout blocked along two adjacent
 * dimensions: _AB (dims 0 and 1 blocked) or _BC (dim 0 is groups G, dims 1
 * and 2 blocked). order_keep == true: plain -> blocked; false: reverse. */
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                && (tag_traits<tag_o>::block_dims == bd::_AB
                        || tag_traits<tag_o>::block_dims == bd::_BC)
                && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_AB,
                        tag_traits<tag_o>::ndims >= 3
                                && tag_traits<tag_o>::ndims <= 5)
                && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_BC,
                        tag_traits<tag_o>::ndims >= 4
                                && tag_traits<tag_o>::ndims <= 6)>::type> {
    PLAIN_TO_BLOCKED_IS_APPLICABLE();

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        // flat_d is the plain side; pdims are the padded dims of the
        // blocked side (needed to compute the number of full blocks).
        const auto &flat_d = order_keep ? input_d : output_d;
        const auto &dims = input_d.dims();
        const auto &pdims
                = order_keep ? output_d.padded_dims() : input_d.padded_dims();

        constexpr int ndims = tag_traits<tag_o>::ndims;

        // _BC tags carry a leading groups dim; all other indices shift by 1.
        static constexpr bool with_g = tag_traits<tag_o>::block_dims == bd::_BC;
        const dim_t G = with_g ? dims[0] : 1;

        // H0/H1: the two blocked dims; M0..M2: trailing spatial dims
        // (1 when absent for smaller ndims).
        const dim_t H0 = dims[0 + with_g];
        const dim_t H1 = dims[1 + with_g];

        const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1;
        const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1;
        const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1;

        // Strides of the two blocked dims within the flat layout.
        const dim_t h0_flat_stride = flat_d.blocking_desc().strides[with_g + 0];
        const dim_t h1_flat_stride = flat_d.blocking_desc().strides[with_g + 1];
        using namespace data_type;
        using namespace utils;

        // Block sizes implied by tag_o's inner blocking (the compound tags
        // like _8a16b2a still occupy a 16x16 tile overall).
        dim_t blksize_0 = -1;
        dim_t blksize_1 = -1;
        switch (tag_traits<tag_o>::inner_blks) {
            case ib::_4b4a:
            case ib::_4b4c:
            case ib::_4c4b:
                blksize_0 = 4;
                blksize_1 = 4;
                break;
            case ib::_8a8b:
            case ib::_8b8a:
            case ib::_8b8c:
            case ib::_8c8b:
            case ib::_2c8b4c:
                blksize_0 = 8;
                blksize_1 = 8;
                break;
            case ib::_16a16b:
            case ib::_16b16a:
            case ib::_16b16c:
            case ib::_16c16b:
            case ib::_8a16b2a:
            case ib::_4b16a4b:
            case ib::_8b16a2b:
            case ib::_8b16c2b:
            case ib::_4c16b4c:
            case ib::_8c16b2c:
                blksize_0 = 16;
                blksize_1 = 16;
                break;
            // NOTE(review): an unlisted tag leaves -1 here, making NB_H0 /
            // NB_H1 below negative and silently doing no work; presumably
            // is_applicable() guarantees only the tags above reach this
            // point -- confirm, or consider an assert.
            default: blksize_0 = -1; blksize_1 = -1;
        }

        // Number of (padded) blocks along each blocked dim.
        const dim_t NB_H0 = pdims[0 + with_g] / blksize_0;
        const dim_t NB_H1 = pdims[1 + with_g] / blksize_1;

        // f32 <-> bf16 pairs bypass the quantization helpers.
        constexpr bool f32bf16
                = one_of(type_i, f32, bf16) && one_of(type_o, f32, bf16);

        auto wrap_qz_a1b0 = [=](data_t<type_o> &out, data_t<type_i> inp) {
            if (f32bf16)
                out = inp;
            else
                out = _qz_a1b0<type_i, type_o>()(inp);
        };

        auto wrap_qz = [=](data_t<type_o> &out, data_t<type_i> inp, float alpha,
                               float beta) {
            if (f32bf16)
                out = alpha * inp + (beta ? beta * out : 0);
            else
                out = _qz<type_i, type_o>()(inp, out, alpha, beta);
        };

        // Copies one block_h0 x block_h1 tile (possibly a partial tail
        // tile) and, when writing the blocked layout, zero-fills the padded
        // remainder of the blksize_0 x blksize_1 tile. blk_off maps (h0,h1)
        // to the in-tile offset for the specific compound inner blocking.
        auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
                           const int block_h0, const int block_h1) {
#define blk_off AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>
            if (alpha == 1.0 && beta == 0.0) {
                for (int h0 = 0; h0 < block_h0; ++h0) {
                    for (int h1 = 0; h1 < block_h1; ++h1) {
                        const dim_t flat_off
                                = h0 * h0_flat_stride + h1 * h1_flat_stride;
                        if (order_keep)
                            wrap_qz_a1b0(o[blk_off(h0, h1)], i[flat_off]);
                        else
                            wrap_qz_a1b0(o[flat_off], i[blk_off(h0, h1)]);
                    }
                    if (order_keep && block_h1 < blksize_1) {
                        // zero padding
                        PRAGMA_OMP_SIMD()
                        for (int h1 = block_h1; h1 < blksize_1; h1++) {
                            o[blk_off(h0, h1)] = 0;
                        }
                    }
                }
                if (order_keep && block_h0 < blksize_0) {
                    // zero padding
                    for (int h0 = block_h0; h0 < blksize_0; h0++) {
                        PRAGMA_OMP_SIMD()
                        for (int h1 = 0; h1 < blksize_1; ++h1) {
                            o[blk_off(h0, h1)] = 0;
                        }
                    }
                }
            } else {
                // Same traversal as above, but with alpha/beta scaling.
                for (int h0 = 0; h0 < block_h0; ++h0) {
                    for (int h1 = 0; h1 < block_h1; ++h1) {
                        const dim_t flat_off
                                = h0 * h0_flat_stride + h1 * h1_flat_stride;
                        if (order_keep)
                            wrap_qz(o[blk_off(h0, h1)], i[flat_off], alpha,
                                    beta);
                        else
                            wrap_qz(o[flat_off], i[blk_off(h0, h1)], alpha,
                                    beta);
                    }
                    if (order_keep && block_h1 < blksize_1) {
                        // zero padding
                        PRAGMA_OMP_SIMD()
                        for (int h1 = block_h1; h1 < blksize_1; h1++) {
                            o[blk_off(h0, h1)] = 0;
                        }
                    }
                }
                if (order_keep && block_h0 < blksize_0) {
                    // zero padding
                    for (int h0 = block_h0; h0 < blksize_0; h0++) {
                        PRAGMA_OMP_SIMD()
                        for (int h1 = 0; h1 < blksize_1; ++h1) {
                            o[blk_off(h0, h1)] = 0;
                        }
                    }
                }
            }

#undef blk_off
        };

        // Blocked side indexes the blocked dims in blocks, flat side in
        // elements.
        const int i_mult_0 = order_keep ? blksize_0 : 1;
        const int o_mult_0 = order_keep ? 1 : blksize_0;

        const int i_mult_1 = order_keep ? blksize_1 : 1;
        const int o_mult_1 = order_keep ? 1 : blksize_1;

/* Element address for (g, h0, h1, m0, m1, m2); the groups dim is only
 * passed through when with_g, and absent middle dims are dropped. */
#define off(md, g, h0, h1, m0, m1, m2) \
    (ndims >= 5 + with_g ? (md).blk_off<!with_g>(g, h0, h1, m0, m1, m2) \
                         : ndims >= 4 + with_g \
                            ? (md).blk_off<!with_g>(g, h0, h1, m1, m2) \
                            : /* ndims >= 3 + with_g ? */ (md) \
                                      .blk_off<!with_g>(g, h0, h1, m2))

        parallel_nd(G, NB_H0, NB_H1, M0, M1, M2,
                [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1,
                        dim_t m2) {
                    auto i = &input[off(input_d, g, i_mult_0 * nb_h0,
                            i_mult_1 * nb_h1, m0, m1, m2)];
                    auto o = &output[off(output_d, g, o_mult_0 * nb_h0,
                            o_mult_1 * nb_h1, m0, m1, m2)];
                    // Tail tiles along either blocked dim may be partial.
                    const int block_h0
                            = nstl::min<int>(blksize_0, H0 - nb_h0 * blksize_0);
                    const int block_h1
                            = nstl::min<int>(blksize_1, H1 - nb_h1 * blksize_1);
                    ker(i, o, block_h0, block_h1);
                });

#undef off

        return status::success;
    }
};
1875 | |
1876 | /* generic and direct-copy reorders */ |
1877 | |
/* Direct copy: both sides have the same (dense) layout, so the reorder
 * degenerates to an element-wise flat copy with optional type conversion
 * and alpha/beta scaling. */
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                        && tag_o == format_tag::any
                        && order_keep == fmt_order::any,
                spec::direct_copy>::type> {
    // Applicable when layouts are similar (same order/blocking; see
    // similar_to) and both are dense, so linear element indexing is valid.
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        return !input_d.has_runtime_dims_or_strides()
                && input_d.similar_to(output_d, true, false, 0)
                && input_d.is_dense() && output_d.is_dense()
                && simple_attr_check(attr, false, true);
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        // Skip any initial offset so index 0 is the first real element.
        input += input_d.blk_off(0);
        output += output_d.blk_off(0);

        const size_t nelems = input_d.nelems();

        // Work is split among threads in multiples of block_size elements;
        // the sub-block_size remainder is handled by the last thread.
        constexpr int block_size = 16;
        const auto num_blocks = nelems / block_size;
        const auto rem_elems = nelems % block_size;

        parallel(0, [&](const int ithr, const int nthr) {
            size_t start {0}, end {0};
            balance211(num_blocks, nthr, ithr, start, end);
            start = start * block_size;
            end = end * block_size;

            // Four specializations of the quantization functor, picked by
            // the runtime alpha/beta values (a1b0 is the common fast path).
            if (alpha == 1.0 && beta == 0.0) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()(
                            input[e]);
                }
            } else if (alpha == 1.0) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()(
                            input[e], output[e], beta);
                }
            } else if (beta == 0.0) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                            input[e], alpha);
                }
            } else {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz<data_t<type_i>, data_t<type_o>>()(
                            input[e], output[e], alpha, beta);
                }
            }

            // Tail elements that did not fill a whole block.
            if (rem_elems != 0 && ithr == nthr - 1) {
                if (alpha == 1.0 && beta == 0.0) {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()(
                                input[e]);
                    }
                } else if (alpha == 1.0) {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()(
                                input[e], output[e], beta);
                    }
                } else if (beta == 0.0) {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                                input[e], alpha);
                    }
                } else {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz<data_t<type_i>, data_t<type_o>>()(
                                input[e], output[e], alpha, beta);
                    }
                }
            }
        });
        return status::success;
    }
};
1969 | |
/* Direct copy for tensors that are dense in every dim except dim 0: each
 * "row" (all elements with a fixed dim-0 index) is copied as one flat
 * span, using each side's own dim-0 stride. */
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                        && tag_o == format_tag::any
                        && order_keep == fmt_order::any,
                spec::direct_copy_except_dim_0>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        // Dense-below-dim-0: the logical element count of dims 1..n equals
        // the physical footprint, i.e. no holes inside a row.
        auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) {
            return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d);
        };
        return !input_d.has_runtime_dims_or_strides()
                && input_d.similar_to(output_d, true, false, 1)
                && is_dense_no_0(input_d) && is_dense_no_0(output_d)
                && simple_attr_check(attr, false, true);
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace utils;

        // Skip any initial offset so indexing starts at the first element.
        input += input_d.blk_off(0);
        output += output_d.blk_off(0);

        // NOTE(review): dims()[0] is dim_t (64-bit); narrowing to int here
        // assumes dim 0 fits -- confirm upstream guarantees.
        const int N = input_d.dims()[0];
        const dim_t is = input_d.blocking_desc().strides[0];
        const dim_t os = output_d.blocking_desc().strides[0];
        const dim_t nelems_no_d0 = nelems_no_dim_0(input_d);
        const dim_t work_amount = N * nelems_no_d0;

        if (alpha == 1.0 && beta == 0.0) {
            // Fast path: plain converting copy, rows split across threads
            // via a 2D (n, intra-row) iterator.
            parallel(0, [&](const int ithr, const int nthr) {
                dim_t n {0}, dim1_s {0};
                dim_t start {0}, end {0};
                balance211(work_amount, nthr, ithr, start, end);
                nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
                while (start < end) {
                    dim_t work_rem = end - start;
                    dim_t dim1_e = std::min(dim1_s + work_rem, nelems_no_d0);
                    PRAGMA_OMP_SIMD()
                    for (dim_t e = dim1_s; e < dim1_e; ++e) {
                        output[os * n + e]
                                = _qz_a1b0<type_i, type_o>()(input[is * n + e]);
                    }
                    nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
                }
            });
        } else {
            // General path with alpha/beta scaling.
            parallel(0, [&](const int ithr, const int nthr) {
                dim_t n {0}, dim1_s {0};
                dim_t start {0}, end {0};
                balance211(work_amount, nthr, ithr, start, end);
                nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
                while (start < end) {
                    dim_t work_rem = end - start;
                    dim_t dim1_e = std::min(dim1_s + work_rem, nelems_no_d0);
                    PRAGMA_OMP_SIMD()
                    for (dim_t e = dim1_s; e < dim1_e; ++e) {
                        output[os * n + e]
                                = _qz<type_i, type_o>()(input[is * n + e],
                                        output[os * n + e], alpha, beta);
                    }
                    nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
                }
            });
        }

        return status::success;
    }

private:
    // Logical number of elements in dims 1..ndims-1 (1 for 0-/1-dim md).
    static dim_t nelems_no_dim_0(const memory_desc_wrapper &data_d) {
        const int ndims = data_d.ndims();
        if (ndims <= 1) return 1;
        return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1);
    }

    // Physical footprint (in elements) of one dim-0 slice: the maximum
    // over dims 1.. of padded blocks times their stride, at least one
    // full inner block.
    static dim_t _size_no_dim_0(const memory_desc_wrapper &data_d) {
        dims_t blocks;
        data_d.compute_blocks(blocks);

        const auto &blk = data_d.blocking_desc();

        dim_t blk_size = 1;
        for (int iblk = 0; iblk < blk.inner_nblks; ++iblk)
            blk_size *= blk.inner_blks[iblk];

        dim_t max_size = blk_size;
        for (int d = 1; d < data_d.ndims(); ++d) {
            max_size = nstl::max(max_size,
                    data_d.padded_dims()[d] / blocks[d] * blk.strides[d]);
        }

        return max_size;
    }
};
2068 | |
/* Reference (catch-all) reorder: per-element copy through logical offsets,
 * supporting per-dim scales and zero points. Slow but fully general. */
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                        && tag_o == format_tag::any
                        && order_keep == fmt_order::any,
                spec::reference>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        /* supported smask: 0x0...011..10...0,
         * i.e. 1 should be contiguous */
        int src_scales_mask = -1;
        int dst_scales_mask = -1;
        CHECK(get_scales_mask(attr, &src_scales_mask, &dst_scales_mask));

        // Strip trailing zeros, then the contiguous run of ones; anything
        // left means the set bits were not contiguous -> unsupported.
        for (auto smask : {src_scales_mask, dst_scales_mask}) {
            for (; smask > 0 && !(smask & 0x1); smask >>= 1)
                ;
            for (; smask > 0 && smask & 0x1; smask >>= 1)
                ;
            if (smask != 0) return false;
        }

        using skip_mask_t = dnnl_primitive_attr::skip_mask_t;
        return input_d.is_blocking_desc() && output_d.is_blocking_desc()
                && !output_d.is_additional_buffer()
                && !input_d.is_additional_buffer()
                && attr->has_default_values(skip_mask_t::scales_runtime
                        | skip_mask_t::zero_points_runtime
                        | skip_mask_t::post_ops)
                && simple_po_check(attr);
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        // This kernel is used also for tensors with multiple inner
        // blocks for which generic zero padding must be used.
        // TODO: apply zero padding inside parallel_nd()
        ctx.zero_pad_output(DNNL_ARG_TO);

        // D_start/D_mask/D_rest factor the logical space so that `dm`
        // indexes the scaled dims (presumably provided by
        // DECLARE_COMMON_PARAMS from the scales mask -- defined elsewhere).
        parallel_nd(D_start, D_mask, D_rest,
                [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) {
                    // Mask 0 means one common scale; otherwise per-dm scale.
                    const float src_scale
                            = src_scales[src_scales_mask == 0 ? 0 : dm];
                    const float dst_scale
                            = dst_scales[dst_scales_mask == 0 ? 0 : dm];

                    // Reconstruct the flat logical offset, then map it to
                    // physical offsets on each side.
                    const size_t e = (ds * D_mask + dm) * D_rest + dr;
                    const auto &i = input[input_d.off_l(e)];
                    auto &o = output[output_d.off_l(e)];

                    // dequantize -> optional beta-accumulate -> requantize.
                    float f = src_scale * ((float)i - src_zp);
                    if (beta) f += beta * o;
                    f = f * dst_scale + dst_zp;
                    o = _qz_a1b0<data_type::f32, type_o>()(f);
                });

        return status::success;
    }
};
2131 | |
2132 | /* high level class declaration */ |
2133 | |
/* The reorder primitive wrapper: pd creation validates types/attributes
 * and delegates applicability and execution to the matching
 * simple_reorder_impl specialization. */
template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
struct simple_reorder_t : public primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("simple:any" , simple_reorder_t);

    private:
        // Factory: validates arguments, builds and initializes the pd,
        // and books the scratchpad. Returns invalid_arguments /
        // unimplemented / out_of_memory on the respective failures.
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md) {
            using skip_mask_t = dnnl_primitive_attr::skip_mask_t;
            // Data types must match the template parameters, and the impl
            // specialization must accept these layouts/attributes.
            bool args_ok = src_md->data_type == type_i
                    && dst_md->data_type == type_o
                    && attr->has_default_values(skip_mask_t::scales_runtime
                            | skip_mask_t::zero_points
                            | skip_mask_t::zero_points_runtime
                            | skip_mask_t::post_ops)
                    && simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
                            spec>::is_applicable(src_md, dst_md, attr);
            if (!args_ok) return status::invalid_arguments;

            // Per-dim dst scales cannot be combined with runtime dims
            // (the scale count is unknown at pd-creation time).
            int mask = -1;
            bool is_set = false;
            CHECK(attr->scales_.get(DNNL_ARG_DST, &mask, &is_set));
            const memory_desc_wrapper input_d(src_md);
            if (input_d.has_runtime_dims_or_strides() && is_set && mask > 0)
                return status::unimplemented;

            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
                    dst_engine->kind(), dst_md);
            if (_pd == nullptr) return status::out_of_memory;
            // Delete on init failure to avoid leaking the pd.
            if (_pd->init(engine, src_engine, dst_engine) != status::success) {
                delete _pd;
                return status::unimplemented;
            }

            // Book impl-specific scratchpad (zero for most specializations).
            const size_t scratchpad_sz_
                    = simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
                            spec>::get_scratchpad_size(src_md, dst_md);
            auto scratchpad = _pd->scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_reorder_space,
                    scratchpad_sz_, 1, 16);

            // Per-dim dst scales are precomputed at execution time; reserve
            // room for one float per masked-dim combination.
            if (is_set && mask > 0) {
                dim_t D_mask;
                _pd->get_D_values(input_d, mask, nullptr, &D_mask, nullptr);
                scratchpad.template book<float>(
                        memory_tracking::names::
                                key_reorder_precomputed_dst_scales,
                        D_mask);
            }

            _pd->init_scratchpad_md();
            return safe_ptr_assign(*reorder_pd, _pd);
        }
        friend dnnl::impl::impl_list_item_t;
    };

    simple_reorder_t(const pd_t *apd) : primitive_t(apd) {}

    // Forwards execution to the matching simple_reorder_impl.
    status_t execute(const exec_ctx_t &ctx) const override {
        return simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::execute(
                pd(), ctx);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};
2204 | |
2205 | #undef SIMPLE_REORDER_TEMPL_DECL |
2206 | #undef SIMPLE_REORDER_TEMPL_CALL |
2207 | |
2208 | } // namespace cpu |
2209 | } // namespace impl |
2210 | } // namespace dnnl |
2211 | |
2212 | #endif |
2213 | |
2214 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
2215 | |