/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_RNN_COMMON_POSTGEMM_HPP
#define CPU_X64_RNN_JIT_UNI_RNN_COMMON_POSTGEMM_HPP

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/rnn_pd.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"

#include "cpu/rnn/rnn_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

struct jit_uni_rnn_postgemm : public jit_generator {

    jit_uni_rnn_postgemm(const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd,
            const char *name)
        : jit_generator(name)
        , rnn_(rnn)
        , pd_(pd)
        , projection_(false)
        , bias_dt_size_(types::data_type_size(rnn.bias_dt))
        , cstate_dt_size_(types::data_type_size(rnn.src_iter_c_dt))
        , is_avx512(mayiuse(avx512_core))
        , is_avx2(mayiuse(avx2))
        , dscale_off_addr(0)
        , dshift_off_addr(0)
        , ymm_perm_mask_addr(0)
        , zmm_perm_mask_addr(0)
        , zero_addr(0)
        , u8_saturation_addr(0)
        , weights_scales_reg(r13)
        , qtable(r14)
        // The implementation avoids preserving Vmm(0) because of a potential
        // conflict with the injectors, which require it for masks on sse4.1,
        // so it can be used as a common temporary vector register.
        , tmp_vector_register_idx(0)
        , qd_reg_idx(tmp_vector_register_idx)
        , bf16_reg1(zmm31)
        , bf16_reg2(zmm30)
        , bf16_reg3(zmm29)
        , bf16_reg4(r13)
        , bf16_reg5(zmm28)
        , bf16_k_mask(k2)
        , tmp_reg(bf16_reg4)
        , zmm_tail_k_mask(k3)
        , bf16_dq_reg_idx(tmp_vector_register_idx) {}

    ~jit_uni_rnn_postgemm() {
        if (bf16_emu_) delete bf16_emu_;
    }

    bool is_projection() const { return projection_; };

    virtual status_t init(data_type_t src_data_t) {
        // no need to check here as bf16 is guarded for avx512 and above in the rnn primitive
        using namespace Xbyak;
        if (src_data_t == data_type::bf16 && !mayiuse(avx512_core_bf16)) {
            bf16_emu_ = new bf16_emulation_t(this, bf16_reg1, bf16_reg2,
                    bf16_reg3, bf16_reg4, bf16_reg5);

        } else
            bf16_emu_ = nullptr;
        return status::success;
    }

    template <typename dst_layer_t, typename dst_iter_t, typename src_iter_t,
            typename gemm_acc_t, typename gates_t, typename scratch_t>
    rnn_postgemm_sig(execute) {
        if (pd_->desc()->prop_kind == prop_kind::backward)
            execute_bwd(rnn, cell_position, ws_gates_, scratch_gates_,
                    augru_attention_, dst_layer_, dst_iter_c_, src_iter_,
                    src_iter_c_, diff_src_layer_, diff_augru_attention_,
                    diff_src_iter_, diff_src_iter_c_, diff_dst_layer_,
                    diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_,
                    ws_grid_, scratch_cell_, dst_iter_, weights_scales_,
                    block_step);
        else
            execute_fwd(rnn, cell_position, ws_gates_, scratch_gates_,
                    augru_attention_, dst_layer_, dst_iter_c_, src_iter_,
                    src_iter_c_, diff_src_layer_, diff_augru_attention_,
                    diff_src_iter_, diff_src_iter_c_, diff_dst_layer_,
                    diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_,
                    ws_grid_, scratch_cell_, dst_iter_, weights_scales_,
                    block_step);
    }

    template <typename dst_layer_t, typename dst_iter_t, typename src_iter_t,
            typename gemm_acc_t, typename gates_t, typename scratch_t>
    rnn_postgemm_sig(execute_fwd) {
        using namespace rnn_utils;
        if (rnn.is_brgemm && !rnn_.unfused_post_gemm) {
            for (int i = 0; i < rnn.m_block; i++)
                postgemm_fwd_call(i, rnn, cell_position, ws_gates_,
                        scratch_gates_, augru_attention_, dst_layer_,
                        dst_iter_c_, src_iter_, src_iter_c_, weights_peephole_,
                        bias_, ws_grid_, scratch_cell_, dst_iter_,
                        weights_scales_, block_step);
        } else {
            // Todo: add parallelization on dhc for the batch 1 case
            // Assumption: the kernel runs a loop on dhc elements
            parallel_nd(rnn.mb, [&](dim_t i) {
                postgemm_fwd_call(i, rnn, cell_position, ws_gates_,
                        scratch_gates_, augru_attention_, dst_layer_,
                        dst_iter_c_, src_iter_, src_iter_c_, weights_peephole_,
                        bias_, ws_grid_, scratch_cell_, dst_iter_,
                        weights_scales_, 0);
            });
        }
    }

    template <typename dst_layer_t, typename dst_iter_t, typename src_iter_t,
            typename gates_t, typename scratch_t>
    inline void postgemm_fwd_call(int m, const rnn_utils::rnn_conf_t &rnn,
            rnn_utils::cell_position_t cell_position, gates_t *ws_gates_,
            scratch_t *scratch_gates_, const dst_layer_t *augru_attention_,
            dst_layer_t *dst_layer_, void *dst_iter_c_,
            const src_iter_t *src_iter_, const void *src_iter_c_,
            const float *weights_peephole_, const void *bias_,
            gates_t *ws_grid_, scratch_t *scratch_cell_, dst_iter_t *dst_iter_,
            float *weights_scales_, int block_step) const {
        const rnn_utils::ws_gates_aoc<gates_t> ws_gates(rnn, ws_gates_);
        const rnn_utils::scratch_gates_aoc<scratch_t> scratch_gates(
                rnn, scratch_gates_);
        const rnn_utils::weights_peephole_aoc_t<const float> weights_peephole(
                rnn, weights_peephole_);
        const auto bias = rnn_utils::make_raw_aoc(
                bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);

        const auto src_iter_ld = rnn.src_iter_ld(cell_position);
        const int dst_iter_c_ld = rnn.dst_iter_c_ld(cell_position);
        const auto dst_layer_ld
                = rnn.dst_layer_ld(cell_position, is_projection());
        const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
        const int src_iter_c_ld = rnn.src_iter_c_ld(cell_position);

        const rnn_utils::ws_states_layer_aoc<dst_layer_t> dst_layer(
                rnn, dst_layer_, dst_layer_ld);
        const rnn_utils::ws_states_iter_aoc<dst_iter_t> dst_iter(
                rnn, dst_iter_, dst_iter_ld);
        const rnn_utils::ws_states_iter_aoc<const src_iter_t> src_iter(
                rnn, src_iter_, src_iter_ld);
        const rnn_utils::augru_attention_aoc<const dst_layer_t> augru_attention(
                rnn, augru_attention_);
        const auto dst_iter_c = rnn_utils::make_raw_aoc(dst_iter_c_,
                types::data_type_size(rnn.dst_iter_c_dt),
                rnn.ws_states_iter_c_nld, dst_iter_c_ld);
        const auto src_iter_c = rnn_utils::make_raw_aoc(src_iter_c_,
                types::data_type_size(rnn.src_iter_c_dt),
                rnn.ws_states_iter_c_nld, src_iter_c_ld);
        const rnn_utils::ws_gates_aoc<scratch_t> scratch_cell(
                rnn, scratch_cell_);
        const utils::array_offset_calculator<gates_t, 2> ws_Wh_b(
                ws_grid_, rnn.mb, rnn.dhc);

// Since the function F(...) returns by reference, an exception has to be
// made for a nullptr argument
#define SAFE_PTR(F, ...) (CONCAT2(F, _) ? &(F(__VA_ARGS__)) : nullptr)

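        // The param slots below form the jit kernel's argument list: gates and
        // scratch gates, bias, dst layer/iter states, the cell-kind dependent
        // state, peephole and attention pointers set in the switch, the
        // weights scales, and the block step.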
        void *param1_ = SAFE_PTR(ws_gates, m, 0, 0); // RNN, LSTM, GRU
        void *param2_ = SAFE_PTR(scratch_gates, m, 0, 0); // RNN, LSTM, GRU
        const void *param3_ = bias(0, 0); // RNN, LSTM, GRU
        void *param4_ = SAFE_PTR(dst_layer, m, 0); // RNN, LSTM, GRU
        void *param5_ = SAFE_PTR(dst_iter, m, 0); // RNN, LSTM, GRU
        const void *param6_;
        void *param7_, *param8_;
        void *param9_ = (void *)weights_scales_;
        const size_t param10_ = block_step;
        const void *param11_ = nullptr;

        switch (pd_->cell_kind()) {
            case alg_kind::vanilla_lstm:
                param6_ = is_projection() ? src_iter_c_ : src_iter_c(m, 0);
                param7_ = const_cast<void *>(dst_iter_c(m, 0));
                param8_ = (void *)SAFE_PTR(weights_peephole, 0, 0);
                break;
            case alg_kind::lbr_gru:
                param6_ = SAFE_PTR(src_iter, m, 0);
                param7_ = SAFE_PTR(scratch_cell, m, 0, 0);
                param8_ = ws_grid_ ? &ws_Wh_b(m, 0) : nullptr;
                break;
            case alg_kind::vanilla_gru:
                param6_ = SAFE_PTR(src_iter, m, 0);
                param7_ = nullptr;
                param8_ = nullptr;
                break;
            case alg_kind::lbr_augru:
                param6_ = SAFE_PTR(src_iter, m, 0);
                param7_ = SAFE_PTR(scratch_cell, m, 0, 0);
                param8_ = ws_grid_ ? &ws_Wh_b(m, 0) : nullptr;
                param11_ = SAFE_PTR(augru_attention, m);
                break;
            case alg_kind::vanilla_augru:
                param6_ = SAFE_PTR(src_iter, m, 0);
                param7_ = nullptr;
                param8_ = nullptr;
                param11_ = SAFE_PTR(augru_attention, m);
                break;
            default:
                param6_ = nullptr;
                param7_ = nullptr;
                param8_ = nullptr;
                param11_ = nullptr;
                break;
        }
        this->operator()(param1_, param2_, param3_, param4_, param5_, param6_,
                param7_, param8_, param9_, param10_, param11_);
#undef SAFE_PTR
    }

    template <typename dst_layer_t, typename dst_iter_t, typename src_iter_t,
            typename gemm_acc_t, typename gates_t, typename scratch_t>
    rnn_postgemm_sig(execute_bwd) {
        using namespace rnn_utils;
        const int dst_iter_c_ld = rnn.dst_iter_c_ld(cell_position);
        const int src_iter_c_ld = rnn.src_iter_c_ld(cell_position);
        const auto src_iter_ld = rnn.src_iter_ld(cell_position);

        const rnn_utils::weights_peephole_aoc_t<const float> weights_peephole(
                rnn, weights_peephole_);
        const rnn_utils::ws_gates_aoc<gates_t> ws_gates(rnn, ws_gates_);
        const rnn_utils::ws_gates_aoc<scratch_t> scratch_gates(
                rnn, scratch_gates_);
        const rnn_utils::ws_diff_states_layer_aoc<gemm_acc_t> diff_src_layer(
                rnn, diff_src_layer_);
        const rnn_utils::ws_diff_states_iter_aoc<gemm_acc_t> diff_src_iter(
                rnn, diff_src_iter_);
        const rnn_utils::ws_diff_states_iter_c_aoc<gemm_acc_t> diff_src_iter_c(
                rnn, diff_src_iter_c_);
        const rnn_utils::augru_attention_aoc<gemm_acc_t> diff_augru_attention(
                rnn, diff_augru_attention_);
        const rnn_utils::ws_diff_states_layer_aoc<gemm_acc_t> diff_dst_layer(
                rnn, diff_dst_layer_);
        const rnn_utils::ws_diff_states_iter_aoc<gemm_acc_t> diff_dst_iter(
                rnn, diff_dst_iter_);
        const rnn_utils::ws_diff_states_iter_c_aoc<gemm_acc_t> diff_dst_iter_c(
                rnn, diff_dst_iter_c_);
        const auto dst_iter_c = rnn_utils::make_raw_aoc(dst_iter_c_,
                types::data_type_size(rnn.dst_iter_c_dt),
                rnn.ws_states_iter_c_nld, dst_iter_c_ld);
        const auto src_iter_c = rnn_utils::make_raw_aoc(src_iter_c_,
                types::data_type_size(rnn.src_iter_c_dt),
                rnn.ws_states_iter_c_nld, src_iter_c_ld);
        const rnn_utils::augru_attention_aoc<const dst_layer_t> augru_attention(
                rnn, augru_attention_);
        const ws_states_iter_aoc<const src_iter_t> src_iter(
                rnn, src_iter_, src_iter_ld);
        const ws_gates_aoc<scratch_t> scratch_cell(rnn, scratch_cell_);
        const utils::array_offset_calculator<scratch_t, 2> hG1(
                scratch_cell_, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld);
        const utils::array_offset_calculator<gates_t, 2> ws_grid(
                ws_grid_, rnn.mb, rnn.dhc);
// Since the function F(...) returns by reference, an exception has to be
// made for a nullptr argument
#define SAFE_PTR(F, ...) (CONCAT2(F, _) ? &(F(__VA_ARGS__)) : nullptr)
        // Todo: add parallelization on dhc for the batch 1 case
        // Assumption: the kernel runs a loop on dhc elements
        parallel_nd(rnn.mb, [&](dim_t i) {
            void *param1_, *param2_, *param4_, *param5_, *param7_, *param8_,
                    *param9_;
            const void *param3_, *param6_;
            static constexpr size_t param10_ = 0;
            const void *param11_ = nullptr;
            void *param12_ = nullptr;
            switch (pd_->cell_kind()) {
                case alg_kind::vanilla_lstm:
                    param1_ = SAFE_PTR(ws_gates, i, 0, 0);
                    param2_ = SAFE_PTR(scratch_gates, i, 0, 0); //RNN, LSTM, GRU
                    param3_ = SAFE_PTR(diff_dst_layer, i, 0);
                    param4_ = SAFE_PTR(diff_dst_iter, i, 0);
                    param5_ = SAFE_PTR(diff_src_iter_c, i, 0);
                    param6_ = SAFE_PTR(diff_dst_iter_c, i, 0);
                    param7_ = const_cast<void *>(src_iter_c(i, 0));
                    param8_ = const_cast<void *>(dst_iter_c(i, 0));
                    param9_ = (void *)SAFE_PTR(weights_peephole, 0, 0);
                    break;
                case alg_kind::lbr_gru:
                    param1_ = SAFE_PTR(ws_gates, i, 0, 0);
                    param2_ = SAFE_PTR(scratch_gates, i, 0, 0);
                    param3_ = SAFE_PTR(diff_dst_layer, i, 0);
                    param4_ = SAFE_PTR(diff_dst_iter, i, 0);
                    param5_ = SAFE_PTR(diff_src_iter, i, 0);
                    param6_ = SAFE_PTR(src_iter, i, 0);
                    param7_ = SAFE_PTR(scratch_cell, i, 0, 0);
                    param8_ = SAFE_PTR(ws_grid, i, 0);
                    param9_ = nullptr;
                    break;
                case alg_kind::vanilla_gru:
                    // TODO: split part 1 and part2 APIs/ABIs
                    param1_ = SAFE_PTR(ws_gates, i, 0, 0);
                    param2_ = SAFE_PTR(scratch_gates, i, 0, 0); //RNN, LSTM, GRU
                    param3_ = SAFE_PTR(diff_dst_layer, i, 0); // non part2
                    param4_ = SAFE_PTR(diff_dst_iter, i, 0); // non part2
                    param5_ = SAFE_PTR(diff_src_iter, i, 0);
                    param6_ = SAFE_PTR(src_iter, i, 0);
                    param7_ = scratch_cell_ ? &hG1(i, 0) : nullptr; // non part1
                    param8_ = SAFE_PTR(ws_grid, i, 0); // non part1
                    param9_ = SAFE_PTR(diff_src_layer, i, 0); // non part1
                    break;
                case alg_kind::lbr_augru:
                    param1_ = SAFE_PTR(ws_gates, i, 0, 0);
                    param2_ = SAFE_PTR(scratch_gates, i, 0, 0);
                    param3_ = SAFE_PTR(diff_dst_layer, i, 0);
                    param4_ = SAFE_PTR(diff_dst_iter, i, 0);
                    param5_ = SAFE_PTR(diff_src_iter, i, 0);
                    param6_ = SAFE_PTR(src_iter, i, 0);
                    param7_ = SAFE_PTR(scratch_cell, i, 0, 0);
                    param8_ = SAFE_PTR(ws_grid, i, 0);
                    param9_ = nullptr;
                    param11_ = SAFE_PTR(augru_attention, i);
                    param12_ = SAFE_PTR(diff_augru_attention, i);
                    break;
                case alg_kind::vanilla_augru:
                    // TODO: split part 1 and part2 APIs/ABIs
                    param1_ = SAFE_PTR(ws_gates, i, 0, 0);
                    param2_ = SAFE_PTR(scratch_gates, i, 0, 0); //RNN, LSTM, GRU
                    param3_ = SAFE_PTR(diff_dst_layer, i, 0); // non part2
                    param4_ = SAFE_PTR(diff_dst_iter, i, 0); // non part2
                    param5_ = SAFE_PTR(diff_src_iter, i, 0);
                    param6_ = SAFE_PTR(src_iter, i, 0);
                    param7_ = scratch_cell_ ? &hG1(i, 0) : nullptr; // non part1
                    param8_ = SAFE_PTR(ws_grid, i, 0); // non part1
                    param9_ = SAFE_PTR(diff_src_layer, i, 0); // non part1
                    param11_ = SAFE_PTR(augru_attention, i);
                    param12_ = SAFE_PTR(diff_augru_attention, i);
                    break;
                case alg_kind::vanilla_rnn:
                    param1_ = SAFE_PTR(ws_gates, i, 0, 0);
                    param2_ = SAFE_PTR(scratch_gates, i, 0, 0);
                    param3_ = SAFE_PTR(diff_dst_layer, i, 0);
                    param4_ = SAFE_PTR(diff_dst_iter, i, 0);
                    param5_ = nullptr;
                    param6_ = nullptr;
                    param7_ = nullptr;
                    param8_ = nullptr;
                    param9_ = nullptr;
                    break;
                default:
                    assert(!"unsupported");
                    param1_ = nullptr;
                    param2_ = nullptr;
                    param3_ = nullptr;
                    param4_ = nullptr;
                    param5_ = nullptr;
                    param6_ = nullptr;
                    param7_ = nullptr;
                    param8_ = nullptr;
                    param9_ = nullptr;
                    break;
            }
            this->operator()(param1_, param2_, param3_, param4_, param5_,
                    param6_, param7_, param8_, param9_, param10_, param11_,
                    param12_);
        });
#undef SAFE_PTR
    }

protected:
    void init_regs(
            float *weights_scales, size_t vlen, size_t tail_elements = 0) {
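        // When a tail is present on avx512, build a k-mask with the low
        // tail_elements bits set ((1 << tail_elements) - 1) so that the masked
        // zmm loads/stores used elsewhere only touch the tail elements.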
        if (is_avx512 && tail_elements > 0) {
            mov(tmp_reg, size_t((1 << tail_elements) - 1));
            kmovq(zmm_tail_k_mask, tmp_reg);
            is_zmm_mask_initialized = true;
        }
        switch (pd_->weights_md()->data_type) {
            case data_type::bf16: {
                /* bfloat downconvert init */
                if (bf16_emu_) bf16_emu_->init_vcvtneps2bf16();
                /* init mask for upconvert */
                const auto tmp_reg32 = tmp_reg.cvt32();
                mov(tmp_reg32, 1);
                kmovd(bf16_k_mask, tmp_reg32);
                break;
            }
            case data_type::s8: {
                /* int8 (de)quantization init*/
                mov(qtable, qlabel);
                if (rnn_.is_brgemm && !rnn_.unfused_post_gemm) {
                    auto base_args = get_stack_params_address();
                    // Read param #9
#ifdef _WIN32
                    mov(weights_scales_reg, ptr[base_args + 32]);
#else
                    mov(weights_scales_reg, ptr[base_args + 16]);
#endif
                } else {
                    float *weights_scales
                            = pd_->attr()->rnn_weights_qparams_.scales_;
                    mov(weights_scales_reg, size_t(weights_scales));
                }

                zero_addr = ptr[qtable];
                u8_saturation_addr = ptr[qtable + vlen];
                dscale_off_addr = ptr[qtable + 2 * vlen];
                dshift_off_addr = ptr[qtable + 3 * vlen];
                ymm_perm_mask_addr = ptr[qtable + 4 * vlen];
                zmm_perm_mask_addr
                        = ptr[qtable + 4 * vlen + cpu_isa_traits<avx>::vlen];
                break;
            }
            case data_type::f32: {
                break;
            }
            default: assert(!"not supported");
        }
    }

    void init_regs(size_t vlen, size_t tail_elements = 0) {
        assert(pd_->weights_md()->data_type != data_type::s8);
        return init_regs(nullptr, vlen, tail_elements);
    };

    void init_table(size_t vlen) {
        if (pd_->weights_md()->data_type != data_type::s8) return;
        /* int8 (de)quantization init*/
        const primitive_attr_t *attr = pd_->attr();
        const float data_scale = attr->rnn_data_qparams_.scale_;
        const float data_shift = attr->rnn_data_qparams_.shift_;

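        // Table layout at qlabel, matching the addresses set in init_regs():
        // vlen bytes of 0.0f (lower saturation bound), vlen bytes of 255.0f
        // (u8 upper saturation bound), vlen bytes of the data scale, vlen
        // bytes of the data shift, then the ymm/zmm permutation masks used by
        // q_d() to gather the valid lanes after packing.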
        L(qlabel);
        {
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(0.0f));
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(255.0f));
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(data_scale));
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(data_shift));
            // perm mask for ymm
            dd(0);
            dd(4);
            dd(2);
            dd(3);
            dd(1);
            dd(5);
            dd(6);
            dd(7);
            // perm mask for zmm
            dd(0);
            dd(4);
            dd(8);
            dd(12);
            dd(1);
            dd(5);
            dd(6);
            dd(7);
            dd(2);
            dd(9);
            dd(10);
            dd(11);
            dd(3);
            dd(12);
            dd(13);
            dd(14);
        }
    }

    void inc_regs(int mask, size_t vlen) {
        if (pd_->weights_md()->data_type == data_type::s8) {
            if (mask != 0) add(weights_scales_reg, vlen);
        }
    }
    void inc_regs(size_t vlen) {
        assert(pd_->weights_md()->data_type != data_type::s8);
        inc_regs(0, vlen);
    }

#ifdef DNNL_ENABLE_FAST_RCP
    template <typename Vmm>
    void fast_recip(Vmm s, Vmm tmp, int vlen_bytes) {
        if (can_do_zmm_masked_tail_processing(s, vlen_bytes)) {
            // write the reciprocal estimate into tmp with tail masking
            Xbyak::Zmm tmp_masked
                    = Xbyak::Zmm(tmp.getIdx()) | zmm_tail_k_mask | T_z;
            uni_vrcpps(tmp_masked, s);
        } else if (vlen_bytes == (int)s.getBit() / 8) {
            // no tail processing
            uni_vrcpps(tmp, s);
        } else if (4 == vlen_bytes) {
            // special case for scalar-based tail processing to prevent divide by zero
            uni_vrcpss(tmp, s);
        } else
            assert(!"unsupported case");

        // we add one Newton iteration
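        // given tmp ~= 1/s from the rcp estimate above, the sequence below
        // computes tmp * (2 - s * tmp), which refines the approximation of 1/s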
        uni_vmulps(s, s, tmp);
        uni_vmulps(s, s, tmp); // s <- s * tmp^2
        uni_vaddps(tmp, tmp, tmp);
        uni_vsubps(tmp, tmp, s);
        uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2
    }
#endif

    // quantize from float to u8
    // Assumption: write_only = true assumes that the quantized value
    // to write is in src
    template <typename Vmm>
    void q_d(data_type_t src_data_t, Xbyak::Address dst, Vmm src, int in_len,
            bool write_only = false) {
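        // Quantization pipeline: src <- src * data_scale + data_shift,
        // clamped to [0, 255] while still in f32, converted to s32, then
        // packed down to s16 and finally to u8/s8 with saturation; the
        // permute afterwards gathers the valid lanes before the store.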
        Vmm qd_vmm(qd_reg_idx);
        if (!write_only) {
            uni_vpxor(qd_vmm, qd_vmm, qd_vmm);
            uni_vmulps(src, src, dscale_off_addr); // apply scale
            uni_vaddps(src, src, dshift_off_addr); // apply shift
            // To saturate properly, we use min/max on the float value
            uni_vmaxps(src, src, zero_addr);
            uni_vminps(src, src, u8_saturation_addr);
            uni_vcvtps2dq(src, src); // convert to int32
            uni_vpackssdw(src, src, qd_vmm); // convert from s32 to s16
            // convert from s16 to u8/s8 with saturation
            if (src_data_t == data_type::u8)
                uni_vpackuswb(src, src, qd_vmm);
            else
                uni_vpacksswb(src, src, qd_vmm);
        }

        if (can_do_zmm_masked_tail_processing(src, in_len)) {
            if (!write_only) {
                Xbyak::Zmm srcz(src.getIdx()), tmpz(qd_vmm.getIdx());
                uni_vmovups(tmpz, zmm_perm_mask_addr);
                vpermd(srcz, tmpz, srcz);
            }

            Xbyak::Zmm src_masked = Xbyak::Zmm(src.getIdx()) | zmm_tail_k_mask;
            vmovdqu8(dst, src_masked);
            return;
        }

        // Note that the results are interleaved by 128 bit chunks, so we need to merge them together
        switch (in_len) {
            case 64: { // Intel AVX-512
                if (!write_only) {
                    Xbyak::Zmm srcz(src.getIdx()), tmpz(qd_vmm.getIdx());
                    uni_vmovups(tmpz, zmm_perm_mask_addr);
                    vpermd(srcz, tmpz, srcz);
                }
                uni_vmovups(dst, Xbyak::Xmm(src.getIdx()));
                break;
            }
            case 32: { // Intel AVX
                if (!write_only) {
                    Xbyak::Ymm srcy(src.getIdx()), tmpy(qd_vmm.getIdx());
                    uni_vmovups(tmpy, ymm_perm_mask_addr);
                    vpermd(srcy, tmpy, srcy);
                }
                uni_vmovsd(dst, Xbyak::Xmm(src.getIdx()));
                break;
            }
            case 16: // sse: nothing to do
                uni_vmovss(dst, Xbyak::Xmm(src.getIdx()));
                break;
            case 4: uni_vpextrb(dst, Xbyak::Xmm(src.getIdx()), 0x0); break;

            default: assert(!"unsupported case");
        };
    }

    // dequantize from s32 to float
    template <typename Vmm>
    void deq_w(data_type_t src_data_t, Vmm s, Vmm tmp1, Vmm tmp2,
            dim_t scale_off, int mask, int vlen_bytes,
            Xbyak::Reg64 *comp = nullptr) {
        // nothing to do if not int8
        if (!utils::one_of(src_data_t, data_type::u8, data_type::s8)) return;
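        // The s32 accumulator is dequantized as
        //     f32 = (acc - compensation) / (weights_scale * data_scale)
        // where the weights scale is a single broadcast value when mask == 0
        // and a per-element value read at scale_off otherwise.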

        size_t qscale_dt_size = sizeof(float);

        // TODO: if mask is 0 precompute mul and inverse
        if (mask == 0)
            uni_vbroadcastss(tmp1, ptr[weights_scales_reg]);
        else {
            auto scales_ptr
                    = ptr[weights_scales_reg + scale_off * qscale_dt_size];
            load(tmp1, scales_ptr, data_type::f32, vlen_bytes);
        }
        uni_vcvtdq2ps(s, s);
        // Here we subtract a compensation if need be
        if (comp) { uni_vsubps(s, s, ptr[*comp]); }
        uni_vmulps(tmp1, tmp1, dscale_off_addr);
#ifdef DNNL_ENABLE_FAST_RCP
        fast_recip(tmp1, tmp2, vlen_bytes);
        uni_vmulps(s, s, tmp1);
#else
        if (can_do_zmm_masked_tail_processing(s, vlen_bytes)) {
            Xbyak::Zmm s_masked
                    = Xbyak::Zmm(s.getIdx()) | zmm_tail_k_mask | T_z;
            uni_vdivps(s_masked, s, tmp1);
        } else
            uni_vdivps(s, s, tmp1);
#endif
    }

    // dequantize from u8 to float
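    // x_f32 = (x_u8 - data_shift) / data_scale, i.e. the inverse of the
    // quantization applied in q_d()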
    template <typename Vmm>
    void deq_h(Vmm dst, Xbyak::Address src, int in_len) {
        if (can_do_zmm_masked_tail_processing(dst, in_len)) {
            Xbyak::Zmm dst_masked
                    = Xbyak::Zmm(dst.getIdx()) | zmm_tail_k_mask | T_z;
            uni_vpmovzxbd(dst_masked, src);
        } else if (4 == in_len) {
            // special case for scalar-based tail processing
            Xbyak::Xmm dst_xmm = Xbyak::Xmm(dst.getIdx());
            uni_vpinsrb(dst_xmm, dst_xmm, src, 0x0);
            uni_vpmovzxbd(dst_xmm, dst_xmm);
        } else if (in_len == (int)dst.getBit() / 8) {
            // no tail processing
            uni_vpmovzxbd(dst, src);
        } else {
            assert(!"unsupported case");
        }
        uni_vcvtdq2ps(dst, dst);
        uni_vsubps(dst, dst, dshift_off_addr);
        uni_vdivps(dst, dst, dscale_off_addr);
    }

    // upconvert from bf16 to float
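    // a bf16 value holds the upper 16 bits of the corresponding f32, so the
    // upconvert is a zero-extend of each 16-bit element followed by a 16-bit
    // left shift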
    template <typename Vmm>
    void bf16_uc(Vmm dst, Xbyak::Address src, int in_len) {
        switch (in_len) {
            case 64: vpmovzxwd(dst, src); break;
            case 4: vpmovzxwd(dst | bf16_k_mask | T_z, src); break;
            default:
                assert(is_zmm_mask_initialized);
                vpmovzxwd(dst | zmm_tail_k_mask | T_z, src);
        }

        vpslld(dst, dst, 0x10);
    }

    // downconvert from float to bf16
    // Assumption: write_only = true assumes that we want to
    // immediately rewrite the downconverted result that is still in
    // bf16_dq_reg_idx
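    // the native vcvtneps2bf16 instruction is used when avx512_core_bf16 is
    // available; otherwise the bf16_emulation_t instance created in init()
    // performs the downconvert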
    template <typename Vmm>
    void bf16_dc(
            Xbyak::Address dst, Vmm src, int in_len, bool write_only = false) {
        Xbyak::Zmm srcz(src.getIdx());
        Xbyak::Ymm bf16_reg_dc(bf16_dq_reg_idx);
        if (!write_only) {
            if (bf16_emu_)
                bf16_emu_->vcvtneps2bf16(bf16_reg_dc, srcz);
            else
                vcvtneps2bf16(bf16_reg_dc, srcz);
        }
        switch (in_len) {
            case 64: uni_vmovups(dst, bf16_reg_dc); break;
            case 4:
                uni_vpextrw(dst, Xbyak::Xmm(bf16_reg_dc.getIdx()), 0x0);
                break;
            default:
                assert(is_zmm_mask_initialized);
                vmovdqu16(dst, Xbyak::Zmm(bf16_dq_reg_idx) | zmm_tail_k_mask);
        }
    }

    // handles quantization/conversion and write to memory
    // Note: values in src register might be modified
    // Assumption: write_only = true assumes that
    // 1. to_src was already called with the same source and with
    //    write_only = false.
    // 2. the src register and the temporary registers for
    //    quantization/downconvert were not overwritten in between the two
    //    calls
    template <typename Vmm>
    void to_src(const Xbyak::Address &dst, const Vmm &src, data_type_t src_dt,
            int in_len, bool write_only = false) {
        switch (src_dt) {
            case data_type::f32: store(dst, src, src_dt, in_len); break;
            case data_type::bf16: bf16_dc(dst, src, in_len, write_only); break;
            case data_type::u8:
            case data_type::s8:
                q_d(src_dt, dst, src, in_len, write_only);
                break;
            default: assert(!"unsupported");
        }
    }

    template <typename Vmm>
    void to_float(const Vmm &dst, const Xbyak::Address &src, data_type_t src_dt,
            int in_len) {
        switch (src_dt) {
            case data_type::f32: load(dst, src, src_dt, in_len); break;
            case data_type::bf16: bf16_uc(dst, src, in_len); break;
            case data_type::u8:
            case data_type::s8: deq_h(dst, src, in_len); break;
            default: assert(!"unsupported");
        }
    }

    template <typename Vmm>
    void load(const Vmm &dst, const Xbyak::Address &src, data_type_t dt,
            int vlen_bytes) {
        if (can_do_zmm_masked_tail_processing(dst, vlen_bytes)) {
            load_zmm_masked(dst, src, dt);
            return;
        }

        if (((int)dst.getBit() / 8) == vlen_bytes)
            uni_vmovups(dst, src);
        else if (4 == vlen_bytes)
            // special case for scalar-based tail processing
            uni_vmovss(dst, src);
        else
            assert(!"unsupported case");
    }

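    // The compute_v*ps helpers below fall back to the scalar *ss form when the
    // remaining length is a single f32 element (vlen_bytes == 4), so that only
    // the tail element is computed.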
    template <typename Vmm>
    void compute_vaddps(
            const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) {
        if (vlen_bytes == 4)
            // special case for scalar-based tail processing
            uni_vaddss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()),
                    Xbyak::Xmm(v3.getIdx()));
        else
            uni_vaddps(v1, v2, v3);
    }

    template <typename Vmm>
    void compute_vsubps(
            const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) {
        if (vlen_bytes == 4)
            // special case for scalar-based tail processing
            uni_vsubss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()),
                    Xbyak::Xmm(v3.getIdx()));
        else
            uni_vsubps(v1, v2, v3);
    }

    template <typename Vmm>
    void compute_vsubps(const Vmm &v1, const Vmm &v2, const Vmm &v3,
            const Vmm &buf, int vlen_bytes) {
        if (vlen_bytes == 4)
            // special case for scalar-based tail processing
            uni_vsubss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()),
                    Xbyak::Xmm(v3.getIdx()), Xbyak::Xmm(buf.getIdx()));
        else
            uni_vsubps(v1, v2, v3, buf);
    }

    template <typename Vmm>
    void compute_vmulps(
            const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) {
        if (vlen_bytes == 4)
            // special case for scalar-based tail processing
            uni_vmulss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()),
                    Xbyak::Xmm(v3.getIdx()));
        else
            uni_vmulps(v1, v2, v3);
    }

    template <typename Vmm>
    void compute_vmulps(const Vmm &v1, const Vmm &v2, const Vmm &v3,
            const Vmm &buf, int vlen_bytes) {
        if (vlen_bytes == 4)
            // special case for scalar-based tail processing
            uni_vmulss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()),
                    Xbyak::Xmm(v3.getIdx()), Xbyak::Xmm(buf.getIdx()));
        else
            uni_vmulps(v1, v2, v3, buf);
    }

    template <typename Vmm>
    void compute_vfmadd231ps(
            const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) {
        if (vlen_bytes == 4)
            // special case for scalar-based tail processing
            uni_vfmadd231ss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()),
                    Xbyak::Xmm(v3.getIdx()));
        else
            uni_vfmadd231ps(v1, v2, v3);
    }

    template <typename Vmm>
    void compute_vfmadd231ps(const Vmm &v1, const Vmm &v2,
            const Xbyak::Address &addr, int vlen_bytes,
            /* required for isa below avx2 only */
            const Vmm &tmp_vmm_for_address_load) {
        if (!is_avx2) {
            // To avoid issues with memory that is not 16-byte aligned on
            // sse4.1, and to avoid overriding the v2 values on avx, load the
            // values from memory into the provided tmp_vmm_for_address_load
            // and use the vmm-only variant.
            load(tmp_vmm_for_address_load, addr, data_type::f32, vlen_bytes);
            compute_vfmadd231ps(v1, tmp_vmm_for_address_load, v2, vlen_bytes);
            return;
        }

        if (can_do_zmm_masked_tail_processing(v1, vlen_bytes)) {
            Xbyak::Zmm dst_masked
                    = Xbyak::Zmm(v1.getIdx()) | zmm_tail_k_mask | T_z;
            uni_vfmadd231ps(dst_masked, Xbyak::Zmm(v2.getIdx()), addr);
            return;
        }

        if (vlen_bytes == 4)
            // special case for scalar-based tail processing
            uni_vfmadd231ss(
                    Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()), addr);
        else
            uni_vfmadd231ps(v1, v2, addr);
    }

    template <typename Vmm>
    void compute_vfmadd213ps(
            const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) {
        if (vlen_bytes == 4)
            // special case for scalar-based tail processing
            uni_vfmadd213ss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()),
                    Xbyak::Xmm(v3.getIdx()));
        else
            uni_vfmadd213ps(v1, v2, v3);
    }

    template <typename Vmm>
    void store(const Xbyak::Address &dst, const Vmm &src, data_type_t dt,
            int vlen_bytes) {
        if (can_do_zmm_masked_tail_processing(src, vlen_bytes)) {
            store_zmm_masked(dst, src, dt);
            return;
        }

        MAYBE_UNUSED(dt);
        if (((int)src.getBit() / 8) == vlen_bytes)
            uni_vmovups(dst, src);
        else if (4 == vlen_bytes)
            // special case for scalar-based tail processing
            uni_vmovss(dst, src);
        else
            assert(!"unsupported case");
    }

    const rnn_utils::rnn_conf_t &rnn_;
    const rnn_pd_t *pd_;
    bool projection_;
    bf16_emulation_t *bf16_emu_ = nullptr;
    const size_t bias_dt_size_;
    const size_t cstate_dt_size_;
    const bool is_avx512;
    const bool is_avx2;

private:
    // registers/Labels used for int8 quantization and conversions
    Xbyak::Address dscale_off_addr;
    Xbyak::Address dshift_off_addr;
    Xbyak::Address ymm_perm_mask_addr;
    Xbyak::Address zmm_perm_mask_addr;
    Xbyak::Address zero_addr;
    Xbyak::Address u8_saturation_addr;
    Xbyak::Reg64 weights_scales_reg;
    Xbyak::Reg64 qtable;
    Xbyak::Label qlabel;
    int tmp_vector_register_idx;
    int qd_reg_idx;

    // registers used for bf16 conversions
    Xbyak::Zmm bf16_reg1;
    Xbyak::Zmm bf16_reg2;
    Xbyak::Zmm bf16_reg3;
    Xbyak::Reg64 bf16_reg4;
    Xbyak::Zmm bf16_reg5;
    Xbyak::Reg64 bf16_reg_mask;
    Xbyak::Opmask bf16_k_mask;
    Xbyak::Reg64 tmp_reg;
    Xbyak::Opmask zmm_tail_k_mask;

    int bf16_dq_reg_idx;
    bool is_zmm_mask_initialized = false;

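    // Masked tail processing only applies to full zmm registers: the tail
    // k-mask must have been initialized in init_regs() and the requested
    // length must be shorter than a full 64-byte vector.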
    template <typename Vmm>
    bool can_do_zmm_masked_tail_processing(Vmm vmm_reg, int in_len_bytes) {
        const int vmm_bytes = vmm_reg.getBit() / 8;
        return is_zmm_mask_initialized && vmm_bytes == 64
                && in_len_bytes < vmm_bytes;
    }

    template <typename Vmm>
    void load_zmm_masked(
            const Vmm &dst, const Xbyak::Address &src, data_type_t dt) {
        Xbyak::Zmm dst_masked
                = Xbyak::Zmm(dst.getIdx()) | zmm_tail_k_mask | T_z;
        switch (dt) {
            case data_type::bf16: vmovdqu16(dst_masked, src); break;
            case data_type::s8:
            case data_type::u8: vmovdqu8(dst_masked, src); break;
            default: vmovups(dst_masked, src);
        }
    }

    template <typename Vmm>
    void store_zmm_masked(
            const Xbyak::Address &dst, const Vmm &src, data_type_t dt) {
        const Xbyak::Zmm src_masked
                = Xbyak::Zmm(src.getIdx()) | zmm_tail_k_mask;
        switch (dt) {
            case data_type::bf16: vmovdqu16(dst, src_masked); break;
            case data_type::s8:
            case data_type::u8: vmovdqu8(dst, src_masked); break;
            default: vmovups(dst, src_masked);
        }
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif