/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_RNN_RNN_REORDERS_HPP
#define CPU_RNN_RNN_REORDERS_HPP

#include <assert.h>

#include "common/bfloat16.hpp"
#include "common/dnnl_thread.hpp"
#include "common/primitive.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/reorder/cpu_reorder_pd.hpp"
#include "cpu/simple_q10n.hpp"

#include "cpu/gemm/gemm_pack.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

static inline void init_dims(dim_t &L, dim_t &D, dim_t &I, dim_t &G, dim_t &O,
        const memory_desc_wrapper &mdw) {
    const auto dims = mdw.dims();
    const auto ndims = mdw.ndims();
    L = dims[0];
    D = dims[1];
    I = dims[2];
    G = 0;
    O = 0;
    // weights_layer/weights_iter case
    if (ndims == 5) {
        G = dims[3];
        O = dims[4];
    }
    // projection weights case
    if (ndims == 4) {
        G = 1;
        O = dims[3];
    }
    assert(G != 0 && O != 0);
}
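
// The (L, D, I, G, O) dims extracted above stand for layers, directions,
// input channels, gates and output channels. As an illustrative example (not
// taken from a specific test), an LSTM weights_layer tensor with dims
// {1, 2, 128, 4, 256} yields L = 1, D = 2, I = 128, G = 4 (LSTM has four
// gates) and O = 256, while a 4D projection weights tensor {1, 2, 256, 128}
// yields G = 1 and O = 128.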

template <data_type_t type_i>
static inline void quantize_igo(int8_t *scratch_quantized,
        const memory_desc_wrapper &src_d, const float *src, int mask,
        float *scales) {
    typedef typename prec_traits<type_i>::type in_data_t;

    // TODO: trivial strides assumed here.
    // Use proper strides where appropriate
    dim_t L, D, I, G, O;
    init_dims(L, D, I, G, O, src_d);

    assert(scales != nullptr);
    parallel(0, [&](const int ithr, const int nthr) {
        dim_t start {0}, end {0};
        balance211(L * D * I, nthr, ithr, start, end);
        for (int ldi = start; ldi < end; ldi++) {
            for (int go = 0; go < G * O; go++) {
                const float s = scales[(mask == 0) ? 0 : go];
                scratch_quantized[ldi * G * O + go]
                        = qz_b0<in_data_t, int8_t>()(src[ldi * G * O + go], s);
            }
        }
    });
}
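
// Note on the quantization above: qz_b0<in_t, int8_t>()(x, s) rounds x * s to
// the nearest integer and saturates it to the int8 range; there is no
// zero-point term for weights. For example (illustrative numbers only), with
// s = 127.f a weight of 0.5f quantizes to 64, while 2.0f saturates to 127.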

template <data_type_t type_i>
static inline void quantize_goi(int8_t *scratch_quantized,
        const memory_desc_wrapper &src_d, const float *src, int mask,
        float *scales) {
    typedef typename prec_traits<type_i>::type in_data_t;

    // TODO: trivial strides assumed here.
    // Use proper strides where appropriate
    dim_t L, D, I, G, O;
    init_dims(L, D, I, G, O, src_d);

    assert(scales != nullptr);
    parallel_nd(L * D, G * O, [&](dim_t ld, dim_t go) {
        const float s = scales[(mask == 0) ? 0 : go];
        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < I; i++) {
            scratch_quantized[ld * I * G * O + i * G * O + go]
                    = qz_b0<in_data_t, int8_t>()(
                            src[ld * G * O * I + go * I + i], s);
        }
    });
}
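
// Both quantization helpers write the scratch buffer in (l, d, i, g, o)
// order; they differ only in the layout they read from: quantize_igo reads
// ldigo/ldio sources contiguously, while quantize_goi reads ldgoi/ldoi
// sources and transposes the (i, g, o) indices on the fly.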

static inline void compensate_igo(float *compensation,
        const memory_desc_wrapper &src_d, int8_t *scratch_quantized,
        int32_t *scratch_compensation, size_t scratch_comp_sz, int nthr) {
    // TODO: trivial strides assumed here.
    // Use proper strides where appropriate
    dim_t L, D, I, G, O;
    init_dims(L, D, I, G, O, src_d);

    // We parallelize on LD and GO
    // TODO: maybe restrict parallelism as we might have large
    // parallelization overhead if dimensions are small
    const int LD_nthr = nstl::min(L * D, dim_t(nthr));
    const int GO_nthr = nstl::min(G * O, dim_t(nthr / LD_nthr));
    parallel(nthr, [&](const int ithr, const int nthr) {
        int LD_ithr = -1;
        int GO_ithr = -1;
        dim_t LD_s = -1, LD_e = -1;
        dim_t GO_s = -1, GO_e = -1;
        if (ithr < LD_nthr * GO_nthr) {
            LD_ithr = ithr % LD_nthr;
            GO_ithr = ithr / LD_nthr;
            balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e);
            balance211(G * O, GO_nthr, GO_ithr, GO_s, GO_e);
        }
        int32_t *compensation_s32
                = scratch_compensation + ithr * scratch_comp_sz;
        for (int ld = LD_s; ld < LD_e; ld++) {
            if (I == 1) {
                PRAGMA_OMP_SIMD()
                for (int go = GO_s; go < GO_e; go++)
                    compensation[ld * G * O + go] = saturate<float>(
                            scratch_quantized[ld * I * G * O + go]);
            } else {
                // We split the loop over I into three parts to avoid
                // conditionals and zero-initialization of the compensation
                // buffer
                // i = 0
                int i = 0;
                PRAGMA_OMP_SIMD()
                for (int go = GO_s; go < GO_e; go++)
                    compensation_s32[go]
                            = scratch_quantized[go + G * O * (i + I * (ld))];
                // 1 <= i < I-1
                for (i = 1; i < I - 1; i++) {
                    PRAGMA_OMP_SIMD()
                    for (int go = GO_s; go < GO_e; go++)
                        compensation_s32[go] += scratch_quantized[go
                                + G * O * (i + I * (ld))];
                }
                // i = I-1
                PRAGMA_OMP_SIMD()
                for (int go = GO_s; go < GO_e; go++)
                    compensation[ld * G * O + go] = saturate<float>(
                            compensation_s32[go]
                            + scratch_quantized[go + G * O * (i + I * (ld))]);
            }
        }
    });
}

static inline void compensate_goi(float *compensation,
        const memory_desc_wrapper &src_d, int8_t *scratch_quantized) {
    // TODO: trivial strides assumed here.
    // Use proper strides where appropriate
    dim_t L, D, I, G, O;
    init_dims(L, D, I, G, O, src_d);

    parallel_nd(L * D, G * O, [&](dim_t ld, dim_t go) {
        int32_t compensation_s32 = 0;
        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < I; i++) {
            compensation_s32
                    += scratch_quantized[ld * I * G * O + i * G * O + go];
        }
        // TODO: do not convert to f32 if this compensation is not
        // going to be added to a bias (e.g. like in lstm
        // projection where it is directly added to the s32
        // accumulators)
        compensation[ld * G * O + go] = saturate<float>(compensation_s32);
    });
}
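
// Both compensate_* helpers reduce the quantized weights over the input
// channel dimension: comp[l, d, g, o] = sum_i w_s8[l, d, i, g, o]. This
// column sum is what lets the int8 execution kernels cancel the contribution
// of the data zero-point (shift), since
// sum_i (x_u8[i] - shift) * w[i, o] = sum_i x_u8[i] * w[i, o] - shift * comp[o].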

template <data_type_t type_i, data_type_t type_o>
struct rnn_data_reorder_t : public primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t);

    private:
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md) {
            using namespace format_tag;
            using namespace status;
            const memory_desc_wrapper id(src_md), od(dst_md);

            bool args_ok = true;
#define PD_CHECK_ARG(x) args_ok = args_ok && (x)
            PD_CHECK_ARG(id.data_type() == type_i);
            PD_CHECK_ARG(od.data_type() == type_o);
            PD_CHECK_ARG(utils::one_of(id.ndims(), 3, 4));
            PD_CHECK_ARG(!id.has_runtime_dims_or_strides());
            auto skip_mask = primitive_attr_t::skip_mask_t::rnn_data_qparams
                    | primitive_attr_t::skip_mask_t::rnn_weights_qparams
                    | primitive_attr_t::skip_mask_t::
                            rnn_weights_projection_qparams;
            PD_CHECK_ARG(attr->has_default_values(skip_mask));
            PD_CHECK_ARG(IMPLICATION(id.ndims() == 3,
                    id.matches_tag(tnc) && od.matches_tag(tnc)));
            PD_CHECK_ARG(IMPLICATION(id.ndims() == 4,
                    id.matches_tag(ldnc) && od.matches_tag(ldnc)));
#undef PD_CHECK_ARG
            if (!args_ok) return invalid_arguments;

            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
                    dst_engine->kind(), dst_md);
            if (_pd == nullptr) return out_of_memory;
            if (_pd->init(engine, src_engine, dst_engine) != success) {
                delete _pd;
                return unimplemented;
            }
            _pd->init_scratchpad_md();
            return safe_ptr_assign(*reorder_pd, _pd);
        }
        friend dnnl::impl::impl_list_item_t;
    };

    rnn_data_reorder_t(const pd_t *apd) : primitive_t(apd) {}

private:
    typedef typename prec_traits<type_i>::type in_data_t;
    typedef typename prec_traits<type_o>::type out_data_t;

    bool is_dense() const {
        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        return utils::everyone_is(1,
                input_d.blocking_desc().strides[input_d.ndims() - 1],
                output_d.blocking_desc().strides[output_d.ndims() - 1]);
    }

    /* This function assumes that only the innermost dimension (C) is
       dense (that is, its stride is 1). This is enough to get good
       performance while still allowing non-trivial strides on the other
       dimensions (e.g. to provide an "optimized" path for views).
    */
    status_t execute_dense(out_data_t *output, const in_data_t *input,
            const float scale, const float shift) const {
        assert(type_i == data_type::f32);
        assert(type_o == data_type::u8 || type_o == data_type::s8);

        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        const dim_t outer_dim
                = utils::array_product(input_d.dims(), input_d.ndims() - 1);
        const dim_t inner_dim = input_d.dims()[input_d.ndims() - 1];

        parallel(0, [&](const int ithr, const int nthr) {
            dim_t start {0}, end {0};
            balance211(outer_dim, nthr, ithr, start, end);
            for (int i = start; i < end; ++i) {
                const dim_t off_in = input_d.off_l(i * inner_dim);
                const dim_t off_out = output_d.off_l(i * inner_dim);
                const in_data_t *__restrict i_ = input + off_in;
                out_data_t *__restrict o_ = output + off_out;
                PRAGMA_OMP_SIMD()
                for (int j = 0; j < inner_dim; ++j) {
                    const float in = (float)i_[j] * scale + shift;
                    o_[j] = qz_a1b0<float, out_data_t>()(in);
                }
            }
        });
        return status::success;
    }
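
    /* Illustrative math (numbers are made up, not a recommended setup): with
       rnn_data_qparams scale = 64.f and shift = 128.f, an f32 activation of
       0.25f maps to round(0.25 * 64 + 128) = 144 in u8. qz_a1b0 only rounds
       and saturates; the affine part is computed explicitly above. */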

    status_t execute_generic(out_data_t *output, const in_data_t *input,
            float scale, float shift) const {
        assert(type_i == data_type::f32);
        assert(type_o == data_type::u8 || type_o == data_type::s8);

        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        const size_t nelems = input_d.nelems();
        parallel_nd(nelems, [&](size_t i) {
            const float in = (float)input[input_d.off_l(i)] * scale + shift;
            output[output_d.off_l(i)] = qz_a1b0<float, out_data_t>()(in);
        });
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        auto input = CTX_IN_MEM(const in_data_t *, DNNL_ARG_FROM);
        auto output = CTX_OUT_MEM(out_data_t *, DNNL_ARG_TO);
        const float scale = pd()->attr()->rnn_data_qparams_.scale_;
        const float shift = pd()->attr()->rnn_data_qparams_.shift_;

        if (is_dense())
            return execute_dense(output, input, scale, shift);
        else
            return execute_generic(output, input, scale, shift);
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};
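
// For reference, a minimal user-level sketch (using the public dnnl.hpp API;
// shapes and quantization values are arbitrary) of a reorder that would
// dispatch to rnn_data_reorder_t: an f32 tnc source, a u8 tnc destination and
// an attr carrying rnn_data_qparams.
//
//     using namespace dnnl;
//     engine eng(engine::kind::cpu, 0);
//     stream s(eng);
//     memory::desc src_md({10, 16, 64}, memory::data_type::f32,
//             memory::format_tag::tnc);
//     memory::desc dst_md({10, 16, 64}, memory::data_type::u8,
//             memory::format_tag::tnc);
//     primitive_attr attr;
//     attr.set_rnn_data_qparams(64.f, 128.f); // scale, shift
//     memory src_mem(src_md, eng), dst_mem(dst_md, eng);
//     reorder::primitive_desc rpd(eng, src_md, eng, dst_md, attr);
//     reorder(rpd).execute(s, src_mem, dst_mem);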

template <data_type_t type_i>
struct rnn_weights_reorder_s8_t : public primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;
        typedef dnnl_status_t (*gemm_pack_f)(const char *identifier,
                const char *transa, const char *transb, const dim_t *M,
                const dim_t *N, const dim_t *K, const dim_t *lda,
                const dim_t *ldb, const void *src, void *dst);

        DECLARE_COMMON_PD_T("rnn_weights_reorder_s8", rnn_weights_reorder_s8_t);

        status_t init(
                engine_t *engine, engine_t *src_engine, engine_t *dst_engine) {
            status_t status
                    = cpu_reorder_pd_t::init(engine, src_engine, dst_engine);
            if (status != status::success) return status;

            nthr_ = dnnl_get_max_threads();
            init_scratchpad();

            return status::success;
        }

        format_tag_t itag_ = format_tag::undef;
        format_tag_t otag_ = format_tag::undef;
        size_t thr_scratch_comp_sz_ = 0;
        // Captured at init time so execute() does not use more threads than
        // the scratchpad was sized for.
        int nthr_;
        gemm_pack_f gemm_pack;

    private:
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md) {
            using namespace format_tag;
            using namespace rnn_packed_format;
            using namespace status;
            const memory_desc_wrapper id(src_md), od(dst_md);

            bool args_ok = true;
#define PD_CHECK_ARG(x) args_ok = args_ok && (x)
            // Fast checks
            PD_CHECK_ARG(id.data_type() == type_i);
            PD_CHECK_ARG(od.data_type() == data_type::s8);
            PD_CHECK_ARG(od.format_kind() == format_kind::rnn_packed);
            PD_CHECK_ARG(utils::one_of(
                    od.rnn_packed_desc().format, ldigo_p, ldio_p));
            PD_CHECK_ARG(od.ndims() == id.ndims());
            // TODO: we have to skip projection qparam even for regular lstm
            // as we use the same attr for regular weights and projection
            auto skip_mask = primitive_attr_t::skip_mask_t::rnn_data_qparams
                    | primitive_attr_t::skip_mask_t::rnn_weights_qparams
                    | primitive_attr_t::skip_mask_t::
                            rnn_weights_projection_qparams;
            PD_CHECK_ARG(attr->has_default_values(skip_mask));
            if (!args_ok) return invalid_arguments;

            // Slower checks
            PD_CHECK_ARG(id.is_dense());
            if (!args_ok) return invalid_arguments;

            format_tag_t itag = id.matches_one_of_tag(ldigo, ldgoi, ldio, ldoi);
            if (itag == format_tag::undef) return invalid_arguments;

            // TODO: add support for layer and direction dimensions
            // weights_layer and weights_iter
            if (id.ndims() == 5
                    && !utils::one_of(attr->rnn_weights_qparams_.mask_, 0, 24))
                return unimplemented;
            // weights_projection
            if (id.ndims() == 4
                    && !utils::one_of(
                            attr->rnn_weights_projection_qparams_.mask_, 0, 8))
                return unimplemented;

            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
                    dst_engine->kind(), dst_md);
            if (_pd == nullptr) return out_of_memory;
            _pd->itag_ = itag;
            if (_pd->init(engine, src_engine, dst_engine) != success) {
                delete _pd;
                return unimplemented;
            }
            _pd->init_scratchpad_md();
            const bool is_s8s8 = dst_md->extra.flags
                    & memory_extra_flags::rnn_s8s8_compensation;
            _pd->gemm_pack = is_s8s8 ? &gemm_s8s8s32_pack : &gemm_s8u8s32_pack;

            return safe_ptr_assign(*reorder_pd, _pd);
#undef PD_CHECK_ARG
        }

        void init_scratchpad() {
            using namespace format_tag;

            const memory_desc_wrapper id(src_md());
            const size_t nelems = id.nelems();
            const auto &dims = id.dims();

            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            const size_t quantization_size = nelems;
            // we do not use GO directly, as this can cause false
            // sharing when parallelizing on I (2 threads writing to
            // the same cache line)
            thr_scratch_comp_sz_ = itag_ == ldigo ? dims[3] * dims[4] : dims[3];
            thr_scratch_comp_sz_ = utils::rnd_up(thr_scratch_comp_sz_, 16);
            size_t reduction_size = 0;
            if (utils::one_of(itag_, ldigo, ldio))
                reduction_size = nthr_ * thr_scratch_comp_sz_;

            scratchpad.template book<int8_t>(
                    key_reorder_rnn_weights_quantization, quantization_size);
            scratchpad.template book<int32_t>(
                    key_reorder_rnn_weights_reduction, reduction_size);
        }

        friend dnnl::impl::impl_list_item_t;
    };

    rnn_weights_reorder_s8_t(const pd_t *apd) : primitive_t(apd) {}

private:
    typedef typename prec_traits<type_i>::type in_data_t;

    status_t execute(const exec_ctx_t &ctx) const override {
        // TODO: trivial strides assumed here.
        // Use proper strides where appropriate

        using namespace format_tag;

        auto src = CTX_IN_MEM(const in_data_t *, DNNL_ARG_FROM);
        auto dst = CTX_OUT_MEM(char *, DNNL_ARG_TO);
        const memory_desc_wrapper &src_d = pd()->src_md();
        const memory_desc_wrapper &dst_d = pd()->dst_md();
        if (src_d.has_zero_dim()) {
            assert(dst_d.has_zero_dim());
            return status::success;
        }

        dim_t L, D, I, G, O;
        init_dims(L, D, I, G, O, src_d);

        /* Quantize src & compute compensation */
        auto scratch_quantized
                = (int8_t * __restrict) ctx.get_scratchpad_grantor()
                          .template get<void>(memory_tracking::names::
                                          key_reorder_rnn_weights_quantization);
        auto scratch_compensation
                = (int32_t * __restrict) ctx.get_scratchpad_grantor()
                          .template get<void>(memory_tracking::names::
                                          key_reorder_rnn_weights_reduction);
        float *comp = reinterpret_cast<float *>(
                dst + dst_d.rnn_packed_desc().offset_compensation);
        float *scales = nullptr;
        int mask = 0;
        if (src_d.ndims() == 5) {
            scales = pd()->attr()->rnn_weights_qparams_.scales_;
            mask = pd()->attr()->rnn_weights_qparams_.mask_;
        }
        if (src_d.ndims() == 4) {
            scales = pd()->attr()->rnn_weights_projection_qparams_.scales_;
            mask = pd()->attr()->rnn_weights_projection_qparams_.mask_;
        }
        /* Step 1: we quantize if we need to */
        if (type_i == data_type::f32) {
            switch (pd()->itag_) {
                case ldigo:
                case ldio:
                    quantize_igo<type_i>(scratch_quantized, src_d, (float *)src,
                            mask, scales);
                    break;
                case ldgoi:
                case ldoi:
                    quantize_goi<type_i>(scratch_quantized, src_d, (float *)src,
                            mask, scales);
                    break;
                default: assert(!"Unsupported reorder");
            }
        } else
            scratch_quantized = (int8_t * __restrict) src;

        /* Step 2: we pre-compute the compensation */
        switch (pd()->itag_) {
            case ldigo:
            case ldio:
                compensate_igo(comp, src_d, scratch_quantized,
                        scratch_compensation, pd()->thr_scratch_comp_sz_,
                        pd()->nthr_);
                break;
            case ldgoi:
            case ldoi: compensate_goi(comp, src_d, scratch_quantized); break;
            default: assert(!"Unsupported reorder");
        }

        /* Step 3: we pack the matrix */
        const auto off_igo = [&](dim_t l, dim_t d, dim_t i, dim_t g, dim_t o) {
            return o + O * (g + G * (i + I * (d + D * l)));
        };
        const int n_parts = dst_d.rnn_packed_desc().n_parts;
        const size_t *size_packed_cell = dst_d.rnn_packed_desc().part_pack_size;
        const int *parts = dst_d.rnn_packed_desc().parts;
        const dim_t n = dst_d.rnn_packed_desc().n;
        const dim_t ldb = dst_d.rnn_packed_desc().ldb;
        char *to_pack = dst;

        for (dim_t l = 0; l < L; l++) {
            for (dim_t d = 0; d < D; d++) {
                for (dim_t p = 0; p < n_parts; p++) {
                    dim_t g = (p > 0) ? parts[p - 1] : 0;
                    dim_t m_p = parts[p] * O;
                    dim_t k_p = I;
                    dim_t lda = (dim_t)G * O;
                    CHECK(pd()->gemm_pack("A", "N", "N", &m_p, &n, &k_p, &lda,
                            &ldb, scratch_quantized + off_igo(l, d, 0, g, 0),
                            to_pack));
                    to_pack += size_packed_cell[p];
                }
            }
        }
        return status::success;
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

template <data_type_t type_i, data_type_t type_o>
struct rnn_weights_reorder_t : public primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);

        format_tag_t itag_;

        status_t init(
                engine_t *engine, engine_t *src_engine, engine_t *dst_engine) {
            status_t status
                    = cpu_reorder_pd_t::init(engine, src_engine, dst_engine);
            if (status != status::success) return status;

            init_scratchpad();

            return status::success;
        }

    private:
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md) {
            using namespace format_tag;
            using namespace rnn_packed_format;
            using namespace status;

            const memory_desc_wrapper id(src_md), od(dst_md);
            bool args_ok = true;
#define PD_CHECK_ARG(x) args_ok = args_ok && (x)
            PD_CHECK_ARG(id.data_type() == type_i);
            PD_CHECK_ARG(od.data_type() == type_o);
            PD_CHECK_ARG(od.format_kind() == format_kind::rnn_packed);
            PD_CHECK_ARG(utils::one_of(
                    od.rnn_packed_desc().format, ldigo_p, ldgoi_p, ldio_p));
            PD_CHECK_ARG(attr->has_default_values());
#undef PD_CHECK_ARG
            if (!args_ok) return invalid_arguments;

            format_tag_t itag = id.matches_one_of_tag(ldigo, ldgoi, ldio, ldoi);
            if (itag == format_tag::undef) return invalid_arguments;

            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
                    dst_engine->kind(), dst_md);
            if (_pd == nullptr) return out_of_memory;
            if (_pd->init(engine, src_engine, dst_engine) != success) {
                delete _pd;
                return unimplemented;
            }
            _pd->itag_ = itag;
            _pd->init_scratchpad_md();
            return safe_ptr_assign(*reorder_pd, _pd);
        }

        void init_scratchpad() {
            using namespace format_tag;
            using namespace rnn_packed_format;

            const memory_desc_wrapper id(src_md());
            const memory_desc_wrapper od(dst_md());
            const rnn_packed_desc_t &rnn_pdata = od.rnn_packed_desc();

            format_tag_t itag = id.matches_one_of_tag(ldigo, ldgoi, ldio);
            const bool layout_cross_case
                    = (itag == ldigo && rnn_pdata.format == ldgoi_p)
                    || (itag == ldgoi && rnn_pdata.format == ldigo_p)
                    || (itag == ldio && rnn_pdata.format == ldio_p),
                    dt_cross_case
                    = type_i == data_type::f32 && type_o == data_type::bf16;
            const size_t sz = id.nelems();

            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.template book<out_data_t>(
                    key_reorder_rnn_weights_transposition,
                    layout_cross_case ? sz : 0);
            scratchpad.template book<out_data_t>(
                    key_reorder_rnn_weights_bf16_cvt, dt_cross_case ? sz : 0);
        }
        friend dnnl::impl::impl_list_item_t;
    };

    rnn_weights_reorder_t(const pd_t *apd) : primitive_t(apd) {}

private:
    typedef typename prec_traits<type_i>::type in_data_t;
    typedef typename prec_traits<type_o>::type out_data_t;

    status_t execute(const exec_ctx_t &ctx) const override {
        // TODO: trivial strides assumed here.
        // Use proper strides where appropriate

        using namespace format_tag;
        using namespace rnn_packed_format;

        auto input = CTX_IN_MEM(const in_data_t *, DNNL_ARG_FROM);
        auto output = CTX_OUT_MEM(out_data_t *, DNNL_ARG_TO);
        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        if (input_d.has_zero_dim()) {
            assert(output_d.has_zero_dim());
            return status::success;
        }

        const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc();
        dim_t L, D, I, G, O;
        init_dims(L, D, I, G, O, input_d);

        /* Pack */
        const bool from_igo = utils::one_of(pd()->itag_, ldigo, ldio);
        const bool to_igo = utils::one_of(rnn_pdata.format, ldigo_p, ldio_p);
        const int n_parts = rnn_pdata.n_parts;
        const size_t *size_packed_cell = rnn_pdata.part_pack_size;
        const int *parts = rnn_pdata.parts;
        const dim_t n = rnn_pdata.n;

        /* Convert fp32 input to bf16 */
        out_data_t *input_cvt = (out_data_t *)input;
        if (type_i == data_type::f32 && type_o == data_type::bf16) {
            input_cvt
                    = (out_data_t *)ctx.get_scratchpad_grantor()
                              .template get<void>(memory_tracking::names::
                                              key_reorder_rnn_weights_bf16_cvt);
            parallel_nd(L * D, [&](dim_t ld) {
                cvt_float_to_bfloat16((bfloat16_t *)input_cvt + ld * G * O * I,
                        (float *)input + ld * G * O * I, G * O * I);
            });
        }

        /* Transpose weights prior to packing to ensure that packed GEMM
         * algorithm will be dispatched */
        out_data_t *input_tr = input_cvt;
        if (from_igo != to_igo) {
            input_tr
                    = (out_data_t *)ctx.get_scratchpad_grantor()
                              .template get<void>(memory_tracking::names::
                                      key_reorder_rnn_weights_transposition);
            const dim_t M = to_igo ? G * O : I;
            const dim_t N = to_igo ? I : G * O;
            parallel_nd(L * D, N, [&](dim_t ld, dim_t i) {
                for (dim_t j = 0; j < M; j++) {
                    input_tr[ld * M * N + i * M + j]
                            = input_cvt[ld * M * N + j * N + i];
                }
            });
        }

        const auto off_igo = [&](dim_t l, dim_t d, dim_t i, dim_t g, dim_t o) {
            return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
        };
        const auto off_goi = [&](dim_t l, dim_t d, dim_t i, dim_t g, dim_t o) {
            return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
        };
        const dim_t lda = to_igo ? G * O : I;
        const dim_t ldb = rnn_pdata.ldb;
        for (dim_t l = 0; l < L; l++) {
            for (dim_t d = 0; d < D; d++) {
                for (dim_t p = 0; p < n_parts; p++) {
                    const dim_t g = (p > 0) ? parts[p - 1] : 0;
                    const dim_t m_p = to_igo ? parts[p] * O : I;
                    const dim_t k_p = to_igo ? I : parts[p] * O;
                    if (type_o == data_type::bf16) {
                        CHECK(gemm_bf16bf16f32_pack("A", "N", "N", &m_p, &n,
                                &k_p, &lda, &ldb,
                                (bfloat16_t *)&input_tr[to_igo
                                                ? off_igo(l, d, 0, g, 0)
                                                : off_goi(l, d, 0, g, 0)],
                                (bfloat16_t *)output));
                    } else {
                        CHECK(sgemm_pack("A", "N", "N", &m_p, &n, &k_p, &lda,
                                &ldb,
                                (float *)&input_tr[to_igo
                                                ? off_igo(l, d, 0, g, 0)
                                                : off_goi(l, d, 0, g, 0)],
                                (float *)output));
                    }
                    output += size_packed_cell[p] / sizeof(out_data_t);
                }
            }
        }
        return status::success;
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

template <data_type_t type_i, data_type_t type_o>
struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("rnn_brgemm_weights_reorder_s8_t",
                rnn_brgemm_weights_reorder_s8_t);

        format_tag_t itag_;
        format_tag_t otag_;
        // Captured at init time so execute() does not use more threads than
        // the scratchpad was sized for.
        int nthr_;
        size_t thr_scratch_comp_sz_ = 0;

        status_t init(
                engine_t *engine, engine_t *src_engine, engine_t *dst_engine) {
            status_t status
                    = cpu_reorder_pd_t::init(engine, src_engine, dst_engine);
            if (status != status::success) return status;

            nthr_ = dnnl_get_max_threads();
            init_scratchpad();

            return status::success;
        }

    private:
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md) {
            using namespace status;
            using namespace format_tag;
            using namespace memory_extra_flags;

            const memory_desc_wrapper id(src_md), od(dst_md);

            const bool args_ok = true && id.data_type() == type_i
                    && od.data_type() == data_type::s8 && id.is_dense();
            if (!args_ok) return invalid_arguments;

            const auto skip_mask
                    = primitive_attr_t::skip_mask_t::rnn_data_qparams
                    | primitive_attr_t::skip_mask_t::rnn_weights_qparams
                    | primitive_attr_t::skip_mask_t::
                            rnn_weights_projection_qparams;
            if (!attr->has_default_values(skip_mask)) return invalid_arguments;

            // TODO: add support for layer and direction dimensions
            // weights_layer and weights_iter
            if (id.ndims() == 5
                    && !utils::one_of(attr->rnn_weights_qparams_.mask_, 0, 24))
                return unimplemented;
            // weights_projection
            if (id.ndims() == 4
                    && !utils::one_of(
                            attr->rnn_weights_projection_qparams_.mask_, 0, 8))
                return unimplemented;

            // Check that the proper memory desc has been passed for the u8s8
            // and s8s8 cases.
            // Note: currently rnn_u8s8_compensation and rnn_s8s8_compensation
            // share a common bit, so we have to perform additional checks to
            // tell these two cases apart
            const bool check_u8s8 = (od.extra().flags & rnn_u8s8_compensation)
                    && !types::extra_flag_rnn_s8s8_compensation_is_set(
                            od.extra().flags)
                    && od.extra().compensation_mask
                            == ((id.ndims() == 5) ? 27 /* 11011 */
                                                  : 13 /* 1101 */);
            const bool check_s8s8 = od.extra().flags & rnn_s8s8_compensation
                    && od.extra().compensation_mask == 0;
            if (!(check_u8s8 || check_s8s8)) return invalid_arguments;

            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
                    dst_engine->kind(), dst_md);
            if (_pd == nullptr) return out_of_memory;
            if (_pd->init(engine, src_engine, dst_engine) != success) {
                delete _pd;
                return unimplemented;
            }

            _pd->itag_ = format_tag::undef;

            format_tag_t otag, itag;

            itag = id.matches_one_of_tag(ldigo, ldio);
            otag = od.matches_one_of_tag(ldgOI64o4i, ldgOI32o4i, ldOI32o4i);
            if (itag != format_tag::undef && otag != format_tag::undef) {
                _pd->itag_ = itag;
                _pd->otag_ = otag;
            } else {
                delete _pd;
                return invalid_arguments;
            }
            _pd->init_scratchpad_md();
            return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
        }

        void init_scratchpad() {
            using namespace format_tag;

            const memory_desc_wrapper id(src_md());
            const size_t nelems = id.nelems();
            const auto &dims = id.dims();
            const auto ndims = id.ndims();

            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            const size_t quantization_size = nelems;
            // we do not use GO directly, as this can cause false
            // sharing when parallelizing on I (2 threads writing to
            // the same cache line)
            thr_scratch_comp_sz_ = (ndims == 5) ? dims[3] * dims[4] : dims[3];
            thr_scratch_comp_sz_ = utils::rnd_up(thr_scratch_comp_sz_, 16);
            const size_t reduction_size = nthr_ * thr_scratch_comp_sz_;

            scratchpad.template book<int8_t>(
                    key_reorder_rnn_weights_quantization, quantization_size);
            scratchpad.template book<int32_t>(
                    key_reorder_rnn_weights_reduction, reduction_size);
        }
        friend dnnl::impl::impl_list_item_t;
    };

    rnn_brgemm_weights_reorder_s8_t(const pd_t *apd) : primitive_t(apd) {}

private:
    typedef typename prec_traits<type_i>::type in_data_t;
    typedef typename prec_traits<type_o>::type out_data_t;

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace format_tag;
        using namespace data_type;
        using namespace utils;
        using namespace memory_extra_flags;

        auto src = CTX_IN_MEM(const in_data_t *, DNNL_ARG_FROM);
        auto dst = CTX_OUT_MEM(out_data_t *, DNNL_ARG_TO);
        const memory_desc_wrapper &src_d = pd()->src_md();
        const memory_desc_wrapper &dst_d = pd()->dst_md();
        if (src_d.has_zero_dim()) {
            assert(dst_d.has_zero_dim());
            return status::success;
        }

        const auto &blocked_d = dst_d;
        const auto &pdims = blocked_d.padded_dims();

        const int o_block = pd()->otag_ == ldgOI64o4i ? 64 : 32;
        static constexpr int i_block = 4;

        dim_t L, D, I, G, O;
        init_dims(L, D, I, G, O, src_d);

        const dim_t pI = pdims[2];
        const dim_t pO = (src_d.ndims() == 5) ? pdims[4] : pdims[3];
        const dim_t IB = pI / i_block;
        const dim_t OB = pO / o_block;

        const size_t compensation_offset = (size_t)L * D * G * pI * pO;

        /* Quantize src & compute compensation */
        auto scratch_quantized
                = (int8_t * __restrict) ctx.get_scratchpad_grantor()
                          .template get<void>(memory_tracking::names::
                                          key_reorder_rnn_weights_quantization);
        auto scratch_compensation
                = (int32_t * __restrict) ctx.get_scratchpad_grantor()
                          .template get<void>(memory_tracking::names::
                                          key_reorder_rnn_weights_reduction);
        float *comp = reinterpret_cast<float *>(dst + compensation_offset);
        const bool req_s8s8_comp = (dst_d.extra().flags & rnn_u8s8_compensation)
                && !types::extra_flag_rnn_s8s8_compensation_is_set(
                        dst_d.extra().flags);
        const auto mask_ok = [&](int mask) {
            return mask
                    == ((src_d.ndims() == 5) ? 27 /* 11011 */
                                             : 13 /* 1101 */);
        };

        float *scales = nullptr;
        int mask = 0;
        if (src_d.ndims() == 5) {
            scales = pd()->attr()->rnn_weights_qparams_.scales_;
            mask = pd()->attr()->rnn_weights_qparams_.mask_;
        }
        if (src_d.ndims() == 4) {
            scales = pd()->attr()->rnn_weights_projection_qparams_.scales_;
            mask = pd()->attr()->rnn_weights_projection_qparams_.mask_;
        }
        if (type_i == data_type::f32) {
            quantize_igo<type_i>(
                    scratch_quantized, src_d, (float *)src, mask, scales);
        } else
            scratch_quantized = (int8_t * __restrict) src;

        if (req_s8s8_comp && mask_ok(dst_d.extra().compensation_mask))
            compensate_igo(comp, src_d, scratch_quantized, scratch_compensation,
                    pd()->thr_scratch_comp_sz_, pd()->nthr_);

        const auto off_plain
                = [&](dim_t l, dim_t d, dim_t i, dim_t g, dim_t o) {
                      return ((((dim_t)l * D + d) * I + i) * G + g) * O + o;
                  };

        const auto off_blk = [&](dim_t l, dim_t d, dim_t g, dim_t ob,
                                     dim_t ib) {
            return (((((dim_t)l * D + d) * G + g) * OB + ob) * IB + ib)
                    * i_block * o_block;
        };
        const auto off_inner_blk = [&](int xdim, int y, int x,
                                           int folding_factor) {
            const int row = (xdim) * (y / folding_factor) * folding_factor;
            const int col = x * folding_factor + (y % folding_factor);
            return row + col;
        };
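
        // Within one (o_block x i_block) tile the destination is o-major with
        // i_block = 4 consecutive input channels per output channel (the
        // "o4i" part of ldgOI32o4i / ldgOI64o4i). Illustrative values: with
        // o_block = 32, element (i = 2, o = 3) lands at
        // off_inner_blk(32, 2, 3, 4) = 32 * (2 / 4) * 4 + 3 * 4 + 2 % 4 = 14.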
        const auto kernel_plain_to_blocked
                = [&](const out_data_t *inp, out_data_t *out, int ib, int ob) {
                      PRAGMA_OMP_SIMD()
                      for (int i = 0; i < i_block * o_block; i++)
                          out[i] = 0;

                      for_(int i = 0; i < i_block; i++)
                      for (int o = 0; o < o_block; o++) {
                          if ((i + ib * i_block < I) && (o + ob * o_block < O))
                              out[off_inner_blk(o_block, i, o, i_block)]
                                      = inp[i * G * O + o];
                      }
                  };

        parallel_nd(L, D, G, OB, IB,
                [&](dim_t l, dim_t d, dim_t g, dim_t ob, dim_t ib) {
                    auto inp = &scratch_quantized[off_plain(
                            l, d, ib * i_block, g, ob * o_block)];
                    auto out = &dst[off_blk(l, d, g, ob, ib)];

                    kernel_plain_to_blocked(inp, out, ib, ob);
                });

        return status::success;
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif