/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_RNN_RNN_REORDERS_HPP
#define CPU_RNN_RNN_REORDERS_HPP

#include <assert.h>

#include "common/bfloat16.hpp"
#include "common/dnnl_thread.hpp"
#include "common/primitive.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/reorder/cpu_reorder_pd.hpp"
#include "cpu/simple_q10n.hpp"

#include "cpu/gemm/gemm_pack.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

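// Extracts the RNN weights dimensions from a memory descriptor. Naming
// follows the oneDNN RNN convention: L = layers, D = directions, I = input
// channels, G = gates, O = output channels. 5D tensors hold weights_layer /
// weights_iter (l, d, i, g, o); 4D tensors hold projection weights
// (l, d, i, o), which have a single implicit gate.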
static inline void init_dims(dim_t &L, dim_t &D, dim_t &I, dim_t &G, dim_t &O,
        const memory_desc_wrapper &mdw) {
    const auto dims = mdw.dims();
    const auto ndims = mdw.ndims();
    L = dims[0];
    D = dims[1];
    I = dims[2];
    G = 0;
    O = 0;
    // weights_layer/weights_iter case
    if (ndims == 5) {
        G = dims[3];
        O = dims[4];
    }
    // projection weights case
    if (ndims == 4) {
        G = 1;
        O = dims[3];
    }
    assert(G != 0 && O != 0);
}

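// Quantizes f32 weights stored in (l, d, i, g, o) order to int8. With
// mask == 0 a single common scale applies to the whole tensor; otherwise
// scales are indexed per (g, o) pair.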
template <data_type_t type_i>
static inline void quantize_igo(int8_t *scratch_quantized,
        const memory_desc_wrapper &src_d, const float *src, int mask,
        float *scales) {
    typedef typename prec_traits<type_i>::type in_data_t;

    // TODO: trivial strides assumed here.
    // Use proper strides where appropriate
    dim_t L, D, I, G, O;
    init_dims(L, D, I, G, O, src_d);

    assert(scales != nullptr);
    parallel(0, [&](const int ithr, const int nthr) {
        dim_t start {0}, end {0};
        balance211(L * D * I, nthr, ithr, start, end);
        for (dim_t ldi = start; ldi < end; ldi++) {
            for (dim_t go = 0; go < G * O; go++) {
                const float s = scales[(mask == 0) ? 0 : go];
                scratch_quantized[ldi * G * O + go]
                        = qz_b0<in_data_t, int8_t>()(src[ldi * G * O + go], s);
            }
        }
    });
}
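// A minimal sketch of the scalar operation applied above, assuming qz_b0
// rounds to nearest and saturates to the destination range:
//   scratch_quantized[idx] = saturate<int8_t>(nearbyintf(src[idx] * s));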

template <data_type_t type_i>
static inline void quantize_goi(int8_t *scratch_quantized,
        const memory_desc_wrapper &src_d, const float *src, int mask,
        float *scales) {
    typedef typename prec_traits<type_i>::type in_data_t;

    // TODO: trivial strides assumed here.
    // Use proper strides where appropriate
    dim_t L, D, I, G, O;
    init_dims(L, D, I, G, O, src_d);

    assert(scales != nullptr);
    parallel_nd(L * D, G * O, [&](dim_t ld, dim_t go) {
        const float s = scales[(mask == 0) ? 0 : go];
        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < I; i++) {
            scratch_quantized[ld * I * G * O + i * G * O + go]
                    = qz_b0<in_data_t, int8_t>()(
                            src[ld * G * O * I + go * I + i], s);
        }
    });
}

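// Computes the compensation terms needed by int8 GEMM: for every (l, d, g, o)
// column, the quantized weights are summed over I so that the shift applied
// during data quantization can later be subtracted from the s32 accumulators.
// This igo variant accumulates into per-thread int32 scratch to keep the
// inner loops vectorizable.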
static inline void compensate_igo(float *compensation,
        const memory_desc_wrapper &src_d, int8_t *scratch_quantized,
        int32_t *scratch_compensation, size_t scratch_comp_sz, int nthr) {
    // TODO: trivial strides assumed here.
    // Use proper strides where appropriate
    dim_t L, D, I, G, O;
    init_dims(L, D, I, G, O, src_d);

    // We parallelize on LD and GO
    // TODO: maybe restrict parallelism as we might have large
    // parallelization overhead if dimensions are small
    const int LD_nthr = nstl::min(L * D, dim_t(nthr));
    const int GO_nthr = nstl::min(G * O, dim_t(nthr / LD_nthr));
    parallel(nthr, [&](const int ithr, const int nthr) {
        int LD_ithr = -1;
        int GO_ithr = -1;
        dim_t LD_s = -1, LD_e = -1;
        dim_t GO_s = -1, GO_e = -1;
        if (ithr < LD_nthr * GO_nthr) {
            LD_ithr = ithr % LD_nthr;
            GO_ithr = ithr / LD_nthr;
            balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e);
            balance211(G * O, GO_nthr, GO_ithr, GO_s, GO_e);
        }
        int32_t *compensation_s32
                = scratch_compensation + ithr * scratch_comp_sz;
        for (dim_t ld = LD_s; ld < LD_e; ld++) {
            if (I == 1) {
                PRAGMA_OMP_SIMD()
                for (dim_t go = GO_s; go < GO_e; go++)
                    compensation[ld * G * O + go] = saturate<float>(
                            scratch_quantized[ld * I * G * O + go]);
            } else {
                // We split the loop over I in three parts (first, middle,
                // last iteration) to avoid per-iteration conditionals and
                // zero-initialization of the compensation buffer
                dim_t i = 0;
                PRAGMA_OMP_SIMD()
                for (dim_t go = GO_s; go < GO_e; go++)
                    compensation_s32[go]
                            = scratch_quantized[go + G * O * (i + I * (ld))];
                // 1 <= i < I-1
                for (i = 1; i < I - 1; i++) {
                    PRAGMA_OMP_SIMD()
                    for (dim_t go = GO_s; go < GO_e; go++)
                        compensation_s32[go] += scratch_quantized[go
                                + G * O * (i + I * (ld))];
                }
                // i = I-1
                PRAGMA_OMP_SIMD()
                for (dim_t go = GO_s; go < GO_e; go++)
                    compensation[ld * G * O + go] = saturate<float>(
                            compensation_s32[go]
                            + scratch_quantized[go + G * O * (i + I * (ld))]);
            }
        }
    });
}

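// Same reduction as compensate_igo, but for (l, d, g, o, i) sources, where
// the I dimension is innermost and can be summed directly per (ld, go) pair.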
static inline void compensate_goi(float *compensation,
        const memory_desc_wrapper &src_d, int8_t *scratch_quantized) {
    // TODO: trivial strides assumed here.
    // Use proper strides where appropriate
    dim_t L, D, I, G, O;
    init_dims(L, D, I, G, O, src_d);

    parallel_nd(L * D, G * O, [&](dim_t ld, dim_t go) {
        int32_t compensation_s32 = 0;
        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < I; i++) {
            compensation_s32
                    += scratch_quantized[ld * I * G * O + i * G * O + go];
        }
        // TODO: do not convert to f32 if this compensation is not
        // going to be added to a bias (e.g. like in lstm
        // projection where it is directly added to the s32
        // accumulators)
        compensation[ld * G * O + go] = saturate<float>(compensation_s32);
    });
}

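// Reorders RNN data tensors (tnc or ldnc) from f32 to u8/s8 by applying the
// affine quantization dst = q(src * scale + shift) element-wise.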
template <data_type_t type_i, data_type_t type_o>
struct rnn_data_reorder_t : public primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t);

    private:
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md) {
            using namespace format_tag;
            using namespace status;
            const memory_desc_wrapper id(src_md), od(dst_md);

            bool args_ok = true;
#define PD_CHECK_ARG(x) args_ok = args_ok && (x)
            PD_CHECK_ARG(id.data_type() == type_i);
            PD_CHECK_ARG(od.data_type() == type_o);
            PD_CHECK_ARG(utils::one_of(id.ndims(), 3, 4));
            PD_CHECK_ARG(!id.has_runtime_dims_or_strides());
            auto skip_mask = primitive_attr_t::skip_mask_t::rnn_data_qparams
                    | primitive_attr_t::skip_mask_t::rnn_weights_qparams
                    | primitive_attr_t::skip_mask_t::
                            rnn_weights_projection_qparams;
            PD_CHECK_ARG(attr->has_default_values(skip_mask));
            PD_CHECK_ARG(IMPLICATION(id.ndims() == 3,
                    id.matches_tag(tnc) && od.matches_tag(tnc)));
            PD_CHECK_ARG(IMPLICATION(id.ndims() == 4,
                    id.matches_tag(ldnc) && od.matches_tag(ldnc)));
#undef PD_CHECK_ARG
            if (!args_ok) return invalid_arguments;

            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
                    dst_engine->kind(), dst_md);
            if (_pd == nullptr) return out_of_memory;
            if (_pd->init(engine, src_engine, dst_engine) != success) {
                delete _pd;
                return unimplemented;
            }
            _pd->init_scratchpad_md();
            return safe_ptr_assign(*reorder_pd, _pd);
        }
        friend dnnl::impl::impl_list_item_t;
    };

    rnn_data_reorder_t(const pd_t *apd) : primitive_t(apd) {}

private:
    typedef typename prec_traits<type_i>::type in_data_t;
    typedef typename prec_traits<type_o>::type out_data_t;

    bool is_dense() const {
        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        return utils::everyone_is(1,
                input_d.blocking_desc().strides[input_d.ndims() - 1],
                output_d.blocking_desc().strides[output_d.ndims() - 1]);
    }

    /* This function assumes that only the innermost dimension (C) is
       dense (that is to say, stride is 1). This is enough to get good
       performance while allowing non-trivial strides on the other
       dimensions (e.g. to provide an "optimized" path for views).
    */
    status_t execute_dense(out_data_t *output, const in_data_t *input,
            const float scale, const float shift) const {
        assert(type_i == data_type::f32);
        assert(type_o == data_type::u8 || type_o == data_type::s8);

        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        const dim_t outer_dim
                = utils::array_product(input_d.dims(), input_d.ndims() - 1);
        const dim_t inner_dim = input_d.dims()[input_d.ndims() - 1];

        parallel(0, [&](const int ithr, const int nthr) {
            dim_t start {0}, end {0};
            balance211(outer_dim, nthr, ithr, start, end);
            for (dim_t i = start; i < end; ++i) {
                const dim_t off_in = input_d.off_l(i * inner_dim);
                const dim_t off_out = output_d.off_l(i * inner_dim);
                const in_data_t *__restrict i_ = input + off_in;
                out_data_t *__restrict o_ = output + off_out;
                PRAGMA_OMP_SIMD()
                for (dim_t j = 0; j < inner_dim; ++j) {
                    const float in = (float)i_[j] * scale + shift;
                    o_[j] = qz_a1b0<float, out_data_t>()(in);
                }
            }
        });
        return status::success;
    }

    status_t execute_generic(out_data_t *output, const in_data_t *input,
            float scale, float shift) const {
        assert(type_i == data_type::f32);
        assert(type_o == data_type::u8 || type_o == data_type::s8);

        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        const size_t nelems = input_d.nelems();
        parallel_nd(nelems, [&](size_t i) {
            const float in = (float)input[input_d.off_l(i)] * scale + shift;
            output[output_d.off_l(i)] = qz_a1b0<float, out_data_t>()(in);
        });
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        auto input = CTX_IN_MEM(const in_data_t *, DNNL_ARG_FROM);
        auto output = CTX_OUT_MEM(out_data_t *, DNNL_ARG_TO);
        const float scale = pd()->attr()->rnn_data_qparams_.scale_;
        const float shift = pd()->attr()->rnn_data_qparams_.shift_;

        if (is_dense())
            return execute_dense(output, input, scale, shift);
        else
            return execute_generic(output, input, scale, shift);
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

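// Reorders f32 or s8 RNN weights into the packed int8 GEMM layout: the
// weights are quantized if needed, per-column compensation is precomputed,
// and each (l, d, part) cell is packed with the gemm_pack routine selected
// at pd creation (s8s8 or s8u8 variant).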
template <data_type_t type_i>
struct rnn_weights_reorder_s8_t : public primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;
        typedef dnnl_status_t (*gemm_pack_f)(const char *identifier,
                const char *transa, const char *transb, const dim_t *M,
                const dim_t *N, const dim_t *K, const dim_t *lda,
                const dim_t *ldb, const void *src, void *dst);

        DECLARE_COMMON_PD_T("rnn_weights_reorder_s8", rnn_weights_reorder_s8_t);

        status_t init(
                engine_t *engine, engine_t *src_engine, engine_t *dst_engine) {
            status_t status
                    = cpu_reorder_pd_t::init(engine, src_engine, dst_engine);
            if (status != status::success) return status;

            nthr_ = dnnl_get_max_threads();
            init_scratchpad();

            return status::success;
        }

        format_tag_t itag_ = format_tag::undef;
        format_tag_t otag_ = format_tag::undef;
        size_t thr_scratch_comp_sz_ = 0;
        int nthr_; // Captured at setup; execute() must not exceed this limit.
        gemm_pack_f gemm_pack = nullptr;

    private:
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md) {
            using namespace format_tag;
            using namespace rnn_packed_format;
            using namespace status;
            const memory_desc_wrapper id(src_md), od(dst_md);

            bool args_ok = true;
#define PD_CHECK_ARG(x) args_ok = args_ok && (x)
            // Fast checks
            PD_CHECK_ARG(id.data_type() == type_i);
            PD_CHECK_ARG(od.data_type() == data_type::s8);
            PD_CHECK_ARG(od.format_kind() == format_kind::rnn_packed);
            PD_CHECK_ARG(utils::one_of(
                    od.rnn_packed_desc().format, ldigo_p, ldio_p));
            PD_CHECK_ARG(od.ndims() == id.ndims());
            // TODO: we have to skip projection qparam even for regular lstm
            // as we use the same attr for regular weights and projection
            auto skip_mask = primitive_attr_t::skip_mask_t::rnn_data_qparams
                    | primitive_attr_t::skip_mask_t::rnn_weights_qparams
                    | primitive_attr_t::skip_mask_t::
                            rnn_weights_projection_qparams;
            PD_CHECK_ARG(attr->has_default_values(skip_mask));
            if (!args_ok) return invalid_arguments;

            // Slower checks
            PD_CHECK_ARG(id.is_dense());
            if (!args_ok) return invalid_arguments;

            format_tag_t itag = id.matches_one_of_tag(ldigo, ldgoi, ldio, ldoi);
            if (itag == format_tag::undef) return invalid_arguments;

            // TODO: add support for layer and direction dimensions
            // weights_layer and weights_iter
            if (id.ndims() == 5
                    && !utils::one_of(attr->rnn_weights_qparams_.mask_, 0, 24))
                return unimplemented;
            // weights_projection
            if (id.ndims() == 4
                    && !utils::one_of(
                            attr->rnn_weights_projection_qparams_.mask_, 0, 8))
                return unimplemented;

            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
                    dst_engine->kind(), dst_md);
            if (_pd == nullptr) return out_of_memory;
            _pd->itag_ = itag;
            if (_pd->init(engine, src_engine, dst_engine) != success) {
                delete _pd;
                return unimplemented;
            }
            _pd->init_scratchpad_md();
            const bool is_s8s8 = dst_md->extra.flags
                    & memory_extra_flags::rnn_s8s8_compensation;
            _pd->gemm_pack = is_s8s8 ? &gemm_s8s8s32_pack : &gemm_s8u8s32_pack;

            return safe_ptr_assign(*reorder_pd, _pd);
#undef PD_CHECK_ARG
        }

        void init_scratchpad() {
            using namespace format_tag;

            const memory_desc_wrapper id(src_md());
            const size_t nelems = id.nelems();
            const auto &dims = id.dims();

            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            const size_t quantization_size = nelems;
            // we do not use GO directly, as this can cause false
            // sharing when parallelizing on I (2 threads writing to
            // the same cache line)
            thr_scratch_comp_sz_ = itag_ == ldigo ? dims[3] * dims[4] : dims[3];
            thr_scratch_comp_sz_ = utils::rnd_up(thr_scratch_comp_sz_, 16);
            size_t reduction_size = 0;
            if (utils::one_of(itag_, ldigo, ldio))
                reduction_size = nthr_ * thr_scratch_comp_sz_;

            scratchpad.template book<int8_t>(
                    key_reorder_rnn_weights_quantization, quantization_size);
            scratchpad.template book<int32_t>(
                    key_reorder_rnn_weights_reduction, reduction_size);
        }

        friend dnnl::impl::impl_list_item_t;
    };

    rnn_weights_reorder_s8_t(const pd_t *apd) : primitive_t(apd) {}

private:
    typedef typename prec_traits<type_i>::type in_data_t;

    status_t execute(const exec_ctx_t &ctx) const override {
        // TODO: trivial strides assumed here.
        // Use proper strides where appropriate

        using namespace format_tag;

        auto src = CTX_IN_MEM(const in_data_t *, DNNL_ARG_FROM);
        auto dst = CTX_OUT_MEM(char *, DNNL_ARG_TO);
        const memory_desc_wrapper &src_d = pd()->src_md();
        const memory_desc_wrapper &dst_d = pd()->dst_md();
        if (src_d.has_zero_dim()) {
            assert(dst_d.has_zero_dim());
            return status::success;
        }

        dim_t L, D, I, G, O;
        init_dims(L, D, I, G, O, src_d);

        /* Quantize src & compute compensation */
        auto scratch_quantized
                = (int8_t * __restrict) ctx.get_scratchpad_grantor()
                          .template get<void>(memory_tracking::names::
                                          key_reorder_rnn_weights_quantization);
        auto scratch_compensation
                = (int32_t * __restrict) ctx.get_scratchpad_grantor()
                          .template get<void>(memory_tracking::names::
                                          key_reorder_rnn_weights_reduction);
        float *comp = reinterpret_cast<float *>(
                dst + dst_d.rnn_packed_desc().offset_compensation);
        float *scales = nullptr;
        int mask = 0;
        if (src_d.ndims() == 5) {
            scales = pd()->attr()->rnn_weights_qparams_.scales_;
            mask = pd()->attr()->rnn_weights_qparams_.mask_;
        }
        if (src_d.ndims() == 4) {
            scales = pd()->attr()->rnn_weights_projection_qparams_.scales_;
            mask = pd()->attr()->rnn_weights_projection_qparams_.mask_;
        }
        /* Step 1: we quantize if we need to */
        if (type_i == data_type::f32) {
            switch (pd()->itag_) {
                case ldigo:
                case ldio:
                    quantize_igo<type_i>(scratch_quantized, src_d, (float *)src,
                            mask, scales);
                    break;
                case ldgoi:
                case ldoi:
                    quantize_goi<type_i>(scratch_quantized, src_d, (float *)src,
                            mask, scales);
                    break;
                default: assert(!"Unsupported reorder");
            }
        } else
            scratch_quantized = (int8_t * __restrict) src;

        /* Step 2: we pre-compute the compensation */
        switch (pd()->itag_) {
            case ldigo:
            case ldio:
                compensate_igo(comp, src_d, scratch_quantized,
                        scratch_compensation, pd()->thr_scratch_comp_sz_,
                        pd()->nthr_);
                break;
            case ldgoi:
            case ldoi: compensate_goi(comp, src_d, scratch_quantized); break;
            default: assert(!"Unsupported reorder");
        }

        /* Step 3: we pack the matrix */
        const auto off_igo = [&](dim_t l, dim_t d, dim_t i, dim_t g, dim_t o) {
            return o + O * (g + G * (i + I * (d + D * l)));
        };
        const int n_parts = dst_d.rnn_packed_desc().n_parts;
        const size_t *size_packed_cell = dst_d.rnn_packed_desc().part_pack_size;
        const int *parts = dst_d.rnn_packed_desc().parts;
        const dim_t n = dst_d.rnn_packed_desc().n;
        const dim_t ldb = dst_d.rnn_packed_desc().ldb;
        char *to_pack = dst;

        for (dim_t l = 0; l < L; l++) {
            for (dim_t d = 0; d < D; d++) {
                for (dim_t p = 0; p < n_parts; p++) {
                    dim_t g = (p > 0) ? parts[p - 1] : 0;
                    dim_t m_p = parts[p] * O;
                    dim_t k_p = I;
                    dim_t lda = (dim_t)G * O;
                    CHECK(pd()->gemm_pack("A", "N", "N", &m_p, &n, &k_p, &lda,
                            &ldb, scratch_quantized + off_igo(l, d, 0, g, 0),
                            to_pack));
                    to_pack += size_packed_cell[p];
                }
            }
        }
        return status::success;
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

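// Reorders f32/bf16 RNN weights into the packed GEMM layout, converting
// f32 to bf16 and/or transposing between igo and goi orders in scratchpad
// memory where required, then packing with sgemm_pack or
// gemm_bf16bf16f32_pack.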
template <data_type_t type_i, data_type_t type_o>
struct rnn_weights_reorder_t : public primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);

        format_tag_t itag_ = format_tag::undef;

        status_t init(
                engine_t *engine, engine_t *src_engine, engine_t *dst_engine) {
            status_t status
                    = cpu_reorder_pd_t::init(engine, src_engine, dst_engine);
            if (status != status::success) return status;

            init_scratchpad();

            return status::success;
        }

    private:
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md) {
            using namespace format_tag;
            using namespace rnn_packed_format;
            using namespace status;

            const memory_desc_wrapper id(src_md), od(dst_md);
            bool args_ok = true;
#define PD_CHECK_ARG(x) args_ok = args_ok && (x)
            PD_CHECK_ARG(id.data_type() == type_i);
            PD_CHECK_ARG(od.data_type() == type_o);
            PD_CHECK_ARG(od.format_kind() == format_kind::rnn_packed);
            PD_CHECK_ARG(utils::one_of(
                    od.rnn_packed_desc().format, ldigo_p, ldgoi_p, ldio_p));
            PD_CHECK_ARG(attr->has_default_values());
#undef PD_CHECK_ARG
            if (!args_ok) return invalid_arguments;

            format_tag_t itag = id.matches_one_of_tag(ldigo, ldgoi, ldio, ldoi);
            if (itag == format_tag::undef) return invalid_arguments;

            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
                    dst_engine->kind(), dst_md);
            if (_pd == nullptr) return out_of_memory;
            if (_pd->init(engine, src_engine, dst_engine) != success) {
                delete _pd;
                return unimplemented;
            }
            _pd->itag_ = itag;
            _pd->init_scratchpad_md();
            return safe_ptr_assign(*reorder_pd, _pd);
        }

        void init_scratchpad() {
            using namespace format_tag;
            using namespace rnn_packed_format;

            const memory_desc_wrapper id(src_md());
            const memory_desc_wrapper od(dst_md());
            const rnn_packed_desc_t &rnn_pdata = od.rnn_packed_desc();

            const format_tag_t itag
                    = id.matches_one_of_tag(ldigo, ldgoi, ldio);
            const bool layout_cross_case
                    = (itag == ldigo && rnn_pdata.format == ldgoi_p)
                    || (itag == ldgoi && rnn_pdata.format == ldigo_p)
                    || (itag == ldio && rnn_pdata.format == ldio_p);
            const bool dt_cross_case
                    = type_i == data_type::f32 && type_o == data_type::bf16;
            const size_t sz = id.nelems();

            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.template book<out_data_t>(
                    key_reorder_rnn_weights_transposition,
                    layout_cross_case ? sz : 0);
            scratchpad.template book<out_data_t>(
                    key_reorder_rnn_weights_bf16_cvt, dt_cross_case ? sz : 0);
        }
        friend dnnl::impl::impl_list_item_t;
    };

    rnn_weights_reorder_t(const pd_t *apd) : primitive_t(apd) {}

private:
    typedef typename prec_traits<type_i>::type in_data_t;
    typedef typename prec_traits<type_o>::type out_data_t;

    status_t execute(const exec_ctx_t &ctx) const override {
        // TODO: trivial strides assumed here.
        // Use proper strides where appropriate

        using namespace format_tag;
        using namespace rnn_packed_format;

        auto input = CTX_IN_MEM(const in_data_t *, DNNL_ARG_FROM);
        auto output = CTX_OUT_MEM(out_data_t *, DNNL_ARG_TO);
        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        if (input_d.has_zero_dim()) {
            assert(output_d.has_zero_dim());
            return status::success;
        }

        const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc();
        dim_t L, D, I, G, O;
        init_dims(L, D, I, G, O, input_d);

        /* Pack */
        const bool from_igo = utils::one_of(pd()->itag_, ldigo, ldio);
        const bool to_igo = utils::one_of(rnn_pdata.format, ldigo_p, ldio_p);
        const int n_parts = rnn_pdata.n_parts;
        const size_t *size_packed_cell = rnn_pdata.part_pack_size;
        const int *parts = rnn_pdata.parts;
        const dim_t n = rnn_pdata.n;

        /* Convert fp32 input to bf16 */
        out_data_t *input_cvt = (out_data_t *)input;
        if (type_i == data_type::f32 && type_o == data_type::bf16) {
            input_cvt
                    = (out_data_t *)ctx.get_scratchpad_grantor()
                              .template get<void>(memory_tracking::names::
                                              key_reorder_rnn_weights_bf16_cvt);
            parallel_nd(L * D, [&](dim_t ld) {
                cvt_float_to_bfloat16((bfloat16_t *)input_cvt + ld * G * O * I,
                        (float *)input + ld * G * O * I, G * O * I);
            });
        }

        /* Transpose weights prior to packing to ensure that packed GEMM
         * algorithm will be dispatched */
        out_data_t *input_tr = input_cvt;
        if (from_igo != to_igo) {
            input_tr = (out_data_t *)ctx.get_scratchpad_grantor()
                               .template get<void>(memory_tracking::names::
                                       key_reorder_rnn_weights_transposition);
            const dim_t M = to_igo ? G * O : I;
            const dim_t N = to_igo ? I : G * O;
            parallel_nd(L * D, N, [&](dim_t ld, dim_t i) {
                for (dim_t j = 0; j < M; j++) {
                    input_tr[ld * M * N + i * M + j]
                            = input_cvt[ld * M * N + j * N + i];
                }
            });
        }

        const auto off_igo = [&](dim_t l, dim_t d, dim_t i, dim_t g, dim_t o) {
            return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
        };
        const auto off_goi = [&](dim_t l, dim_t d, dim_t i, dim_t g, dim_t o) {
            return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
        };
        const dim_t lda = to_igo ? G * O : I;
        const dim_t ldb = rnn_pdata.ldb;
        for (dim_t l = 0; l < L; l++) {
            for (dim_t d = 0; d < D; d++) {
                for (dim_t p = 0; p < n_parts; p++) {
                    const dim_t g = (p > 0) ? parts[p - 1] : 0;
                    const dim_t m_p = to_igo ? parts[p] * O : I;
                    const dim_t k_p = to_igo ? I : parts[p] * O;
                    if (type_o == data_type::bf16) {
                        CHECK(gemm_bf16bf16f32_pack("A", "N", "N", &m_p, &n,
                                &k_p, &lda, &ldb,
                                (bfloat16_t *)&input_tr[to_igo
                                        ? off_igo(l, d, 0, g, 0)
                                        : off_goi(l, d, 0, g, 0)],
                                (bfloat16_t *)output));
                    } else {
                        CHECK(sgemm_pack("A", "N", "N", &m_p, &n, &k_p, &lda,
                                &ldb,
                                (float *)&input_tr[to_igo
                                        ? off_igo(l, d, 0, g, 0)
                                        : off_goi(l, d, 0, g, 0)],
                                (float *)output));
                    }
                    output += size_packed_cell[p] / sizeof(out_data_t);
                }
            }
        }
        return status::success;
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

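// Reorders RNN weights into the blocked int8 layouts used by the brgemm-based
// RNN implementation (ldgOI64o4i, ldgOI32o4i, ldOI32o4i): quantize if needed,
// compute compensation when the destination requests it, then scatter the
// quantized weights into zero-padded o_block x i_block tiles.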
template <data_type_t type_i, data_type_t type_o>
struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("rnn_brgemm_weights_reorder_s8_t",
                rnn_brgemm_weights_reorder_s8_t);

        format_tag_t itag_ = format_tag::undef;
        format_tag_t otag_ = format_tag::undef;
        int nthr_; // Captured at setup; execute() must not exceed this limit.
        size_t thr_scratch_comp_sz_ = 0;

        status_t init(
                engine_t *engine, engine_t *src_engine, engine_t *dst_engine) {
            status_t status
                    = cpu_reorder_pd_t::init(engine, src_engine, dst_engine);
            if (status != status::success) return status;

            nthr_ = dnnl_get_max_threads();
            init_scratchpad();

            return status::success;
        }

    private:
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md) {
            using namespace status;
            using namespace format_tag;
            using namespace memory_extra_flags;

            const memory_desc_wrapper id(src_md), od(dst_md);

            const bool args_ok = id.data_type() == type_i
                    && od.data_type() == data_type::s8 && id.is_dense();
            if (!args_ok) return invalid_arguments;

            const auto skip_mask
                    = primitive_attr_t::skip_mask_t::rnn_data_qparams
                    | primitive_attr_t::skip_mask_t::rnn_weights_qparams
                    | primitive_attr_t::skip_mask_t::
                            rnn_weights_projection_qparams;
            if (!attr->has_default_values(skip_mask)) return invalid_arguments;

            // TODO: add support for layer and direction dimensions
            // weights_layer and weights_iter
            if (id.ndims() == 5
                    && !utils::one_of(attr->rnn_weights_qparams_.mask_, 0, 24))
                return unimplemented;
            // weights_projection
            if (id.ndims() == 4
                    && !utils::one_of(
                            attr->rnn_weights_projection_qparams_.mask_, 0, 8))
                return unimplemented;

            // Check the proper memory desc has been passed to u8s8 and s8s8
            // Note: currently rnn_u8s8_compensation and rnn_s8s8_compensation
            // have common bit so we have to perform additional checks to
            // separate these two cases
            const bool check_u8s8 = (od.extra().flags & rnn_u8s8_compensation)
                    && !types::extra_flag_rnn_s8s8_compensation_is_set(
                            od.extra().flags)
                    && od.extra().compensation_mask
                            == ((id.ndims() == 5) ? 27 /* 11011 */
                                                  : 13 /* 1101 */);
            const bool check_s8s8 = od.extra().flags & rnn_s8s8_compensation
                    && od.extra().compensation_mask == 0;
            if (!(check_u8s8 || check_s8s8)) return invalid_arguments;

            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
                    dst_engine->kind(), dst_md);
            if (_pd == nullptr) return out_of_memory;
            if (_pd->init(engine, src_engine, dst_engine) != success) {
                delete _pd;
                return unimplemented;
            }

            _pd->itag_ = format_tag::undef;

            format_tag_t otag, itag;

            itag = id.matches_one_of_tag(ldigo, ldio);
            otag = od.matches_one_of_tag(ldgOI64o4i, ldgOI32o4i, ldOI32o4i);
            if (itag != format_tag::undef && otag != format_tag::undef) {
                _pd->itag_ = itag;
                _pd->otag_ = otag;
            } else {
                delete _pd;
                return invalid_arguments;
            }
            _pd->init_scratchpad_md();
            return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
        }

        void init_scratchpad() {
            using namespace format_tag;

            const memory_desc_wrapper id(src_md());
            const size_t nelems = id.nelems();
            const auto &dims = id.dims();
            const auto ndims = id.ndims();

            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            const size_t quantization_size = nelems;
            // we do not use GO directly, as this can cause false
            // sharing when parallelizing on I (2 threads writing to
            // the same cache line)
            thr_scratch_comp_sz_ = (ndims == 5) ? dims[3] * dims[4] : dims[3];
            thr_scratch_comp_sz_ = utils::rnd_up(thr_scratch_comp_sz_, 16);
            const size_t reduction_size = nthr_ * thr_scratch_comp_sz_;

            scratchpad.template book<int8_t>(
                    key_reorder_rnn_weights_quantization, quantization_size);
            scratchpad.template book<int32_t>(
                    key_reorder_rnn_weights_reduction, reduction_size);
        }
        friend dnnl::impl::impl_list_item_t;
    };

    rnn_brgemm_weights_reorder_s8_t(const pd_t *apd) : primitive_t(apd) {}

private:
    typedef typename prec_traits<type_i>::type in_data_t;
    typedef typename prec_traits<type_o>::type out_data_t;

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace format_tag;
        using namespace data_type;
        using namespace utils;
        using namespace memory_extra_flags;

        auto src = CTX_IN_MEM(const in_data_t *, DNNL_ARG_FROM);
        auto dst = CTX_OUT_MEM(out_data_t *, DNNL_ARG_TO);
        const memory_desc_wrapper &src_d = pd()->src_md();
        const memory_desc_wrapper &dst_d = pd()->dst_md();
        if (src_d.has_zero_dim()) {
            assert(dst_d.has_zero_dim());
            return status::success;
        }

        const auto &blocked_d = dst_d;
        const auto &pdims = blocked_d.padded_dims();

        const int o_block = pd()->otag_ == ldgOI64o4i ? 64 : 32;
        static constexpr int i_block = 4;

        dim_t L, D, I, G, O;
        init_dims(L, D, I, G, O, src_d);

        const dim_t pI = pdims[2];
        const dim_t pO = (src_d.ndims() == 5) ? pdims[4] : pdims[3];
        const dim_t IB = pI / i_block;
        const dim_t OB = pO / o_block;

        const size_t compensation_offset = (size_t)L * D * G * pI * pO;

        /* Quantize src & compute compensation */
        auto scratch_quantized
                = (int8_t * __restrict) ctx.get_scratchpad_grantor()
                          .template get<void>(memory_tracking::names::
                                          key_reorder_rnn_weights_quantization);
        auto scratch_compensation
                = (int32_t * __restrict) ctx.get_scratchpad_grantor()
                          .template get<void>(memory_tracking::names::
                                          key_reorder_rnn_weights_reduction);
        float *comp = reinterpret_cast<float *>(dst + compensation_offset);
        const bool req_s8s8_comp = (dst_d.extra().flags & rnn_u8s8_compensation)
                && !types::extra_flag_rnn_s8s8_compensation_is_set(
                        dst_d.extra().flags);
        const auto mask_ok = [&](int mask) {
            return mask
                    == ((src_d.ndims() == 5) ? 27 /* 11011 */
                                             : 13 /* 1101 */);
        };

        float *scales = nullptr;
        int mask = 0;
        if (src_d.ndims() == 5) {
            scales = pd()->attr()->rnn_weights_qparams_.scales_;
            mask = pd()->attr()->rnn_weights_qparams_.mask_;
        }
        if (src_d.ndims() == 4) {
            scales = pd()->attr()->rnn_weights_projection_qparams_.scales_;
            mask = pd()->attr()->rnn_weights_projection_qparams_.mask_;
        }
        if (type_i == data_type::f32) {
            quantize_igo<type_i>(
                    scratch_quantized, src_d, (float *)src, mask, scales);
        } else
            scratch_quantized = (int8_t * __restrict) src;

        if (req_s8s8_comp && mask_ok(dst_d.extra().compensation_mask))
            compensate_igo(comp, src_d, scratch_quantized, scratch_compensation,
                    pd()->thr_scratch_comp_sz_, pd()->nthr_);

        const auto off_plain
                = [&](dim_t l, dim_t d, dim_t i, dim_t g, dim_t o) {
                      return ((((dim_t)l * D + d) * I + i) * G + g) * O + o;
                  };

        const auto off_blk = [&](dim_t l, dim_t d, dim_t g, dim_t ob,
                                     dim_t ib) {
            return (((((dim_t)l * D + d) * G + g) * OB + ob) * IB + ib)
                    * i_block * o_block;
        };
        const auto off_inner_blk = [&](int xdim, int y, int x,
                                           int folding_factor) {
            const int row = (xdim) * (y / folding_factor) * folding_factor;
            const int col = x * folding_factor + (y % folding_factor);
            return row + col;
        };
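        // Example: with o_block = 32 and i_block = 4, off_inner_blk maps tile
        // element (i, o) to 32 * (i / 4) * 4 + o * 4 + (i % 4), so each output
        // channel o owns four consecutive i values within a 128-element row.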
        const auto kernel_plain_to_blocked
                = [&](const out_data_t *inp, out_data_t *out, int ib, int ob) {
                      PRAGMA_OMP_SIMD()
                      for (int i = 0; i < i_block * o_block; i++)
                          out[i] = 0;

                      for_(int i = 0; i < i_block; i++)
                      for (int o = 0; o < o_block; o++) {
                          if ((i + ib * i_block < I) && (o + ob * o_block < O))
                              out[off_inner_blk(o_block, i, o, i_block)]
                                      = inp[i * G * O + o];
                      }
                  };

        parallel_nd(L, D, G, OB, IB,
                [&](dim_t l, dim_t d, dim_t g, dim_t ob, dim_t ib) {
                    auto inp = &scratch_quantized[off_plain(
                            l, d, ib * i_block, g, ob * o_block)];
                    auto out = &dst[off_blk(l, d, g, ob, ib)];

                    kernel_plain_to_blocked(inp, out, ib, ob);
                });

        return status::success;
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif