1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_UNI_RNN_COMMON_POSTGEMM_HPP |
18 | #define CPU_X64_RNN_JIT_UNI_RNN_COMMON_POSTGEMM_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/rnn_pd.hpp" |
23 | #include "common/utils.hpp" |
24 | |
25 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
26 | #include "cpu/x64/jit_generator.hpp" |
27 | |
28 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
29 | |
30 | #include "cpu/rnn/rnn_utils.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | namespace x64 { |
36 | |
37 | struct jit_uni_rnn_postgemm : public jit_generator { |
38 | |
39 | jit_uni_rnn_postgemm(const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd, |
40 | const char *name) |
41 | : jit_generator(name) |
42 | , rnn_(rnn) |
43 | , pd_(pd) |
44 | , projection_(false) |
45 | , bias_dt_size_(types::data_type_size(rnn.bias_dt)) |
46 | , cstate_dt_size_(types::data_type_size(rnn.src_iter_c_dt)) |
47 | , is_avx512(mayiuse(avx512_core)) |
48 | , is_avx2(mayiuse(avx2)) |
49 | , dscale_off_addr(0) |
50 | , dshift_off_addr(0) |
51 | , ymm_perm_mask_addr(0) |
52 | , zmm_perm_mask_addr(0) |
53 | , zero_addr(0) |
54 | , u8_saturation_addr(0) |
55 | , weights_scales_reg(r13) |
56 | , qtable(r14) |
        // implementations avoid preserving Vmm(0) because of a potential
        // conflict with the mask usage required by injectors on sse4.1,
        // so it can be used as a common temporary vector register
60 | , tmp_vector_register_idx(0) |
61 | , qd_reg_idx(tmp_vector_register_idx) |
62 | , bf16_reg1(zmm31) |
63 | , bf16_reg2(zmm30) |
64 | , bf16_reg3(zmm29) |
65 | , bf16_reg4(r13) |
66 | , bf16_reg5(zmm28) |
67 | , bf16_k_mask(k2) |
68 | , tmp_reg(bf16_reg4) |
69 | , zmm_tail_k_mask(k3) |
70 | , bf16_dq_reg_idx(tmp_vector_register_idx) {} |
71 | |
    ~jit_uni_rnn_postgemm() { delete bf16_emu_; }
75 | |
    bool is_projection() const { return projection_; }
77 | |
78 | virtual status_t init(data_type_t src_data_t) { |
        // no need to check the ISA here as bf16 is guarded for avx512 and
        // above in the RNN primitive
80 | using namespace Xbyak; |
81 | if (src_data_t == data_type::bf16 && !mayiuse(avx512_core_bf16)) { |
82 | bf16_emu_ = new bf16_emulation_t(this, bf16_reg1, bf16_reg2, |
83 | bf16_reg3, bf16_reg4, bf16_reg5); |
84 | |
85 | } else |
86 | bf16_emu_ = nullptr; |
87 | return status::success; |
88 | } |
89 | |
90 | template <typename dst_layer_t, typename dst_iter_t, typename src_iter_t, |
91 | typename gemm_acc_t, typename gates_t, typename scratch_t> |
92 | rnn_postgemm_sig(execute) { |
93 | if (pd_->desc()->prop_kind == prop_kind::backward) |
94 | execute_bwd(rnn, cell_position, ws_gates_, scratch_gates_, |
95 | augru_attention_, dst_layer_, dst_iter_c_, src_iter_, |
96 | src_iter_c_, diff_src_layer_, diff_augru_attention_, |
97 | diff_src_iter_, diff_src_iter_c_, diff_dst_layer_, |
98 | diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_, |
99 | ws_grid_, scratch_cell_, dst_iter_, weights_scales_, |
100 | block_step); |
101 | else |
102 | execute_fwd(rnn, cell_position, ws_gates_, scratch_gates_, |
103 | augru_attention_, dst_layer_, dst_iter_c_, src_iter_, |
104 | src_iter_c_, diff_src_layer_, diff_augru_attention_, |
105 | diff_src_iter_, diff_src_iter_c_, diff_dst_layer_, |
106 | diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_, |
107 | ws_grid_, scratch_cell_, dst_iter_, weights_scales_, |
108 | block_step); |
109 | } |
110 | |
111 | template <typename dst_layer_t, typename dst_iter_t, typename src_iter_t, |
112 | typename gemm_acc_t, typename gates_t, typename scratch_t> |
113 | rnn_postgemm_sig(execute_fwd) { |
114 | using namespace rnn_utils; |
115 | if (rnn.is_brgemm && !rnn_.unfused_post_gemm) { |
116 | for (int i = 0; i < rnn.m_block; i++) |
117 | postgemm_fwd_call(i, rnn, cell_position, ws_gates_, |
118 | scratch_gates_, augru_attention_, dst_layer_, |
119 | dst_iter_c_, src_iter_, src_iter_c_, weights_peephole_, |
120 | bias_, ws_grid_, scratch_cell_, dst_iter_, |
121 | weights_scales_, block_step); |
122 | } else { |
            // TODO: add parallelization on dhc for the batch 1 case
124 | // Assumption: the kernel runs a loop on dhc elements |
125 | parallel_nd(rnn.mb, [&](dim_t i) { |
126 | postgemm_fwd_call(i, rnn, cell_position, ws_gates_, |
127 | scratch_gates_, augru_attention_, dst_layer_, |
128 | dst_iter_c_, src_iter_, src_iter_c_, weights_peephole_, |
129 | bias_, ws_grid_, scratch_cell_, dst_iter_, |
130 | weights_scales_, 0); |
131 | }); |
132 | } |
133 | } |
134 | |
135 | template <typename dst_layer_t, typename dst_iter_t, typename src_iter_t, |
136 | typename gates_t, typename scratch_t> |
137 | inline void postgemm_fwd_call(int m, const rnn_utils::rnn_conf_t &rnn, |
138 | rnn_utils::cell_position_t cell_position, gates_t *ws_gates_, |
139 | scratch_t *scratch_gates_, const dst_layer_t *augru_attention_, |
140 | dst_layer_t *dst_layer_, void *dst_iter_c_, |
141 | const src_iter_t *src_iter_, const void *src_iter_c_, |
142 | const float *weights_peephole_, const void *bias_, |
143 | gates_t *ws_grid_, scratch_t *scratch_cell_, dst_iter_t *dst_iter_, |
144 | float *weights_scales_, int block_step) const { |
145 | const rnn_utils::ws_gates_aoc<gates_t> ws_gates(rnn, ws_gates_); |
146 | const rnn_utils::scratch_gates_aoc<scratch_t> scratch_gates( |
147 | rnn, scratch_gates_); |
148 | const rnn_utils::weights_peephole_aoc_t<const float> weights_peephole( |
149 | rnn, weights_peephole_); |
150 | const auto bias = rnn_utils::make_raw_aoc( |
151 | bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc); |
152 | |
153 | const auto src_iter_ld = rnn.src_iter_ld(cell_position); |
154 | const int dst_iter_c_ld = rnn.dst_iter_c_ld(cell_position); |
155 | const auto dst_layer_ld |
156 | = rnn.dst_layer_ld(cell_position, is_projection()); |
157 | const auto dst_iter_ld = rnn.dst_iter_ld(cell_position); |
158 | const int src_iter_c_ld = rnn.src_iter_c_ld(cell_position); |
159 | |
160 | const rnn_utils::ws_states_layer_aoc<dst_layer_t> dst_layer( |
161 | rnn, dst_layer_, dst_layer_ld); |
162 | const rnn_utils::ws_states_iter_aoc<dst_iter_t> dst_iter( |
163 | rnn, dst_iter_, dst_iter_ld); |
164 | const rnn_utils::ws_states_iter_aoc<const src_iter_t> src_iter( |
165 | rnn, src_iter_, src_iter_ld); |
166 | const rnn_utils::augru_attention_aoc<const dst_layer_t> augru_attention( |
167 | rnn, augru_attention_); |
168 | const auto dst_iter_c = rnn_utils::make_raw_aoc(dst_iter_c_, |
169 | types::data_type_size(rnn.dst_iter_c_dt), |
170 | rnn.ws_states_iter_c_nld, dst_iter_c_ld); |
171 | const auto src_iter_c = rnn_utils::make_raw_aoc(src_iter_c_, |
172 | types::data_type_size(rnn.src_iter_c_dt), |
173 | rnn.ws_states_iter_c_nld, src_iter_c_ld); |
174 | const rnn_utils::ws_gates_aoc<scratch_t> scratch_cell( |
175 | rnn, scratch_cell_); |
176 | const utils::array_offset_calculator<gates_t, 2> ws_Wh_b( |
177 | ws_grid_, rnn.mb, rnn.dhc); |
178 | |
// Since the function F(...) returns by reference, an exception has to be
// made for a nullptr argument
181 | #define SAFE_PTR(F, ...) (CONCAT2(F, _) ? &(F(__VA_ARGS__)) : nullptr) |
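// e.g. SAFE_PTR(ws_gates, m, 0, 0) expands to
// (ws_gates_ ? &(ws_gates(m, 0, 0)) : nullptr)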
182 | |
183 | void *param1_ = SAFE_PTR(ws_gates, m, 0, 0); // RNN, LSTM, GRU |
184 | void *param2_ = SAFE_PTR(scratch_gates, m, 0, 0); // RNN, LSTM, GRU |
185 | const void *param3_ = bias(0, 0); // RNN, LSTM, GRU |
186 | void *param4_ = SAFE_PTR(dst_layer, m, 0); // RNN, LSTM, GRU |
187 | void *param5_ = SAFE_PTR(dst_iter, m, 0); // RNN, LSTM, GRU |
188 | const void *param6_; |
189 | void *param7_, *param8_; |
190 | void *param9_ = (void *)weights_scales_; |
191 | const size_t param10_ = block_step; |
192 | const void *param11_ = nullptr; |
193 | |
194 | switch (pd_->cell_kind()) { |
195 | case alg_kind::vanilla_lstm: |
196 | param6_ = is_projection() ? src_iter_c_ : src_iter_c(m, 0); |
197 | param7_ = const_cast<void *>(dst_iter_c(m, 0)); |
198 | param8_ = (void *)SAFE_PTR(weights_peephole, 0, 0); |
199 | break; |
200 | case alg_kind::lbr_gru: |
201 | param6_ = SAFE_PTR(src_iter, m, 0); |
202 | param7_ = SAFE_PTR(scratch_cell, m, 0, 0); |
203 | param8_ = ws_grid_ ? &ws_Wh_b(m, 0) : nullptr; |
204 | break; |
205 | case alg_kind::vanilla_gru: |
206 | param6_ = SAFE_PTR(src_iter, m, 0); |
207 | param7_ = nullptr; |
208 | param8_ = nullptr; |
209 | break; |
210 | case alg_kind::lbr_augru: |
211 | param6_ = SAFE_PTR(src_iter, m, 0); |
212 | param7_ = SAFE_PTR(scratch_cell, m, 0, 0); |
213 | param8_ = ws_grid_ ? &ws_Wh_b(m, 0) : nullptr; |
214 | param11_ = SAFE_PTR(augru_attention, m); |
215 | break; |
216 | case alg_kind::vanilla_augru: |
217 | param6_ = SAFE_PTR(src_iter, m, 0); |
218 | param7_ = nullptr; |
219 | param8_ = nullptr; |
220 | param11_ = SAFE_PTR(augru_attention, m); |
221 | break; |
222 | default: |
223 | param6_ = nullptr; |
224 | param7_ = nullptr; |
225 | param8_ = nullptr; |
226 | param11_ = nullptr; |
227 | break; |
228 | } |
229 | this->operator()(param1_, param2_, param3_, param4_, param5_, param6_, |
230 | param7_, param8_, param9_, param10_, param11_); |
231 | #undef SAFE_PTR |
232 | } |
233 | |
234 | template <typename dst_layer_t, typename dst_iter_t, typename src_iter_t, |
235 | typename gemm_acc_t, typename gates_t, typename scratch_t> |
236 | rnn_postgemm_sig(execute_bwd) { |
237 | using namespace rnn_utils; |
238 | const int dst_iter_c_ld = rnn.dst_iter_c_ld(cell_position); |
239 | const int src_iter_c_ld = rnn.src_iter_c_ld(cell_position); |
240 | const auto src_iter_ld = rnn.src_iter_ld(cell_position); |
241 | |
242 | const rnn_utils::weights_peephole_aoc_t<const float> weights_peephole( |
243 | rnn, weights_peephole_); |
244 | const rnn_utils::ws_gates_aoc<gates_t> ws_gates(rnn, ws_gates_); |
245 | const rnn_utils::ws_gates_aoc<scratch_t> scratch_gates( |
246 | rnn, scratch_gates_); |
247 | const rnn_utils::ws_diff_states_layer_aoc<gemm_acc_t> diff_src_layer( |
248 | rnn, diff_src_layer_); |
249 | const rnn_utils::ws_diff_states_iter_aoc<gemm_acc_t> diff_src_iter( |
250 | rnn, diff_src_iter_); |
251 | const rnn_utils::ws_diff_states_iter_c_aoc<gemm_acc_t> diff_src_iter_c( |
252 | rnn, diff_src_iter_c_); |
253 | const rnn_utils::augru_attention_aoc<gemm_acc_t> diff_augru_attention( |
254 | rnn, diff_augru_attention_); |
255 | const rnn_utils::ws_diff_states_layer_aoc<gemm_acc_t> diff_dst_layer( |
256 | rnn, diff_dst_layer_); |
257 | const rnn_utils::ws_diff_states_iter_aoc<gemm_acc_t> diff_dst_iter( |
258 | rnn, diff_dst_iter_); |
259 | const rnn_utils::ws_diff_states_iter_c_aoc<gemm_acc_t> diff_dst_iter_c( |
260 | rnn, diff_dst_iter_c_); |
261 | const auto dst_iter_c = rnn_utils::make_raw_aoc(dst_iter_c_, |
262 | types::data_type_size(rnn.dst_iter_c_dt), |
263 | rnn.ws_states_iter_c_nld, dst_iter_c_ld); |
264 | const auto src_iter_c = rnn_utils::make_raw_aoc(src_iter_c_, |
265 | types::data_type_size(rnn.src_iter_c_dt), |
266 | rnn.ws_states_iter_c_nld, src_iter_c_ld); |
267 | const rnn_utils::augru_attention_aoc<const dst_layer_t> augru_attention( |
268 | rnn, augru_attention_); |
269 | const ws_states_iter_aoc<const src_iter_t> src_iter( |
270 | rnn, src_iter_, src_iter_ld); |
271 | const ws_gates_aoc<scratch_t> scratch_cell(rnn, scratch_cell_); |
272 | const utils::array_offset_calculator<scratch_t, 2> hG1( |
273 | scratch_cell_, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld); |
274 | const utils::array_offset_calculator<gates_t, 2> ws_grid( |
275 | ws_grid_, rnn.mb, rnn.dhc); |
// Since the function F(...) returns by reference, an exception has to be
// made for a nullptr argument
278 | #define SAFE_PTR(F, ...) (CONCAT2(F, _) ? &(F(__VA_ARGS__)) : nullptr) |
        // TODO: add parallelization on dhc for the batch 1 case
280 | // Assumption: the kernel runs a loop on dhc elements |
281 | parallel_nd(rnn.mb, [&](dim_t i) { |
282 | void *param1_, *param2_, *param4_, *param5_, *param7_, *param8_, |
283 | *param9_; |
284 | const void *param3_, *param6_; |
285 | static constexpr size_t param10_ = 0; |
286 | const void *param11_ = nullptr; |
287 | void *param12_ = nullptr; |
288 | switch (pd_->cell_kind()) { |
289 | case alg_kind::vanilla_lstm: |
290 | param1_ = SAFE_PTR(ws_gates, i, 0, 0); |
                    param2_ = SAFE_PTR(scratch_gates, i, 0, 0); // RNN, LSTM, GRU
292 | param3_ = SAFE_PTR(diff_dst_layer, i, 0); |
293 | param4_ = SAFE_PTR(diff_dst_iter, i, 0); |
294 | param5_ = SAFE_PTR(diff_src_iter_c, i, 0); |
295 | param6_ = SAFE_PTR(diff_dst_iter_c, i, 0); |
296 | param7_ = const_cast<void *>(src_iter_c(i, 0)); |
297 | param8_ = const_cast<void *>(dst_iter_c(i, 0)); |
298 | param9_ = (void *)SAFE_PTR(weights_peephole, 0, 0); |
299 | break; |
300 | case alg_kind::lbr_gru: |
301 | param1_ = SAFE_PTR(ws_gates, i, 0, 0); |
302 | param2_ = SAFE_PTR(scratch_gates, i, 0, 0); |
303 | param3_ = SAFE_PTR(diff_dst_layer, i, 0); |
304 | param4_ = SAFE_PTR(diff_dst_iter, i, 0); |
305 | param5_ = SAFE_PTR(diff_src_iter, i, 0); |
306 | param6_ = SAFE_PTR(src_iter, i, 0); |
307 | param7_ = SAFE_PTR(scratch_cell, i, 0, 0); |
308 | param8_ = SAFE_PTR(ws_grid, i, 0); |
309 | param9_ = nullptr; |
310 | break; |
311 | case alg_kind::vanilla_gru: |
                    // TODO: split part 1 and part 2 APIs/ABIs
                    param1_ = SAFE_PTR(ws_gates, i, 0, 0);
                    param2_ = SAFE_PTR(scratch_gates, i, 0, 0); // RNN, LSTM, GRU
315 | param3_ = SAFE_PTR(diff_dst_layer, i, 0); // non part2 |
316 | param4_ = SAFE_PTR(diff_dst_iter, i, 0); // non part2 |
317 | param5_ = SAFE_PTR(diff_src_iter, i, 0); |
318 | param6_ = SAFE_PTR(src_iter, i, 0); |
319 | param7_ = scratch_cell_ ? &hG1(i, 0) : nullptr; // non part1 |
320 | param8_ = SAFE_PTR(ws_grid, i, 0); // non part1 |
321 | param9_ = SAFE_PTR(diff_src_layer, i, 0); // non part1 |
322 | break; |
323 | case alg_kind::lbr_augru: |
324 | param1_ = SAFE_PTR(ws_gates, i, 0, 0); |
325 | param2_ = SAFE_PTR(scratch_gates, i, 0, 0); |
326 | param3_ = SAFE_PTR(diff_dst_layer, i, 0); |
327 | param4_ = SAFE_PTR(diff_dst_iter, i, 0); |
328 | param5_ = SAFE_PTR(diff_src_iter, i, 0); |
329 | param6_ = SAFE_PTR(src_iter, i, 0); |
330 | param7_ = SAFE_PTR(scratch_cell, i, 0, 0); |
331 | param8_ = SAFE_PTR(ws_grid, i, 0); |
332 | param9_ = nullptr; |
333 | param11_ = SAFE_PTR(augru_attention, i); |
334 | param12_ = SAFE_PTR(diff_augru_attention, i); |
335 | break; |
336 | case alg_kind::vanilla_augru: |
                    // TODO: split part 1 and part 2 APIs/ABIs
                    param1_ = SAFE_PTR(ws_gates, i, 0, 0);
                    param2_ = SAFE_PTR(scratch_gates, i, 0, 0); // RNN, LSTM, GRU
340 | param3_ = SAFE_PTR(diff_dst_layer, i, 0); // non part2 |
341 | param4_ = SAFE_PTR(diff_dst_iter, i, 0); // non part2 |
342 | param5_ = SAFE_PTR(diff_src_iter, i, 0); |
343 | param6_ = SAFE_PTR(src_iter, i, 0); |
344 | param7_ = scratch_cell_ ? &hG1(i, 0) : nullptr; // non part1 |
345 | param8_ = SAFE_PTR(ws_grid, i, 0); // non part1 |
346 | param9_ = SAFE_PTR(diff_src_layer, i, 0); // non part1 |
347 | param11_ = SAFE_PTR(augru_attention, i); |
348 | param12_ = SAFE_PTR(diff_augru_attention, i); |
349 | break; |
350 | case alg_kind::vanilla_rnn: |
351 | param1_ = SAFE_PTR(ws_gates, i, 0, 0); |
352 | param2_ = SAFE_PTR(scratch_gates, i, 0, 0); |
353 | param3_ = SAFE_PTR(diff_dst_layer, i, 0); |
354 | param4_ = SAFE_PTR(diff_dst_iter, i, 0); |
355 | param5_ = nullptr; |
356 | param6_ = nullptr; |
357 | param7_ = nullptr; |
358 | param8_ = nullptr; |
359 | param9_ = nullptr; |
360 | break; |
361 | default: |
362 | assert(!"unsupported" ); |
363 | param1_ = nullptr; |
364 | param2_ = nullptr; |
365 | param3_ = nullptr; |
366 | param4_ = nullptr; |
367 | param5_ = nullptr; |
368 | param6_ = nullptr; |
369 | param7_ = nullptr; |
370 | param8_ = nullptr; |
371 | param9_ = nullptr; |
372 | break; |
373 | } |
374 | this->operator()(param1_, param2_, param3_, param4_, param5_, |
375 | param6_, param7_, param8_, param9_, param10_, param11_, |
376 | param12_); |
377 | }); |
378 | #undef SAFE_PTR |
379 | } |
380 | |
381 | protected: |
382 | void init_regs( |
383 | float *weights_scales, size_t vlen, size_t tail_elements = 0) { |
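        // build a k-mask with the low tail_elements bits set so that zmm
        // tail loads/stores can be done with a single masked instruction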
384 | if (is_avx512 && tail_elements > 0) { |
            mov(tmp_reg, (size_t(1) << tail_elements) - 1);
386 | kmovq(zmm_tail_k_mask, tmp_reg); |
387 | is_zmm_mask_initialized = true; |
388 | } |
389 | switch (pd_->weights_md()->data_type) { |
390 | case data_type::bf16: { |
391 | /* bfloat downconvert init */ |
392 | if (bf16_emu_) bf16_emu_->init_vcvtneps2bf16(); |
393 | /* init mask for upconvert */ |
394 | const auto tmp_reg32 = tmp_reg.cvt32(); |
395 | mov(tmp_reg32, 1); |
396 | kmovd(bf16_k_mask, tmp_reg32); |
397 | break; |
398 | } |
399 | case data_type::s8: { |
                /* int8 (de)quantization init */
401 | mov(qtable, qlabel); |
402 | if (rnn_.is_brgemm && !rnn_.unfused_post_gemm) { |
403 | auto base_args = get_stack_params_address(); |
404 | // Read param #9 |
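                // Win64 passes the first 4 integer args in registers, so
                // param #9 is the 5th stack-passed argument (offset 32),
                // while System V passes 6 in registers, making it the 3rd
                // stack-passed argument (offset 16)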
405 | #ifdef _WIN32 |
406 | mov(weights_scales_reg, ptr[base_args + 32]); |
407 | #else |
408 | mov(weights_scales_reg, ptr[base_args + 16]); |
409 | #endif |
410 | } else { |
411 | float *weights_scales |
412 | = pd_->attr()->rnn_weights_qparams_.scales_; |
413 | mov(weights_scales_reg, size_t(weights_scales)); |
414 | } |
415 | |
416 | zero_addr = ptr[qtable]; |
417 | u8_saturation_addr = ptr[qtable + vlen]; |
418 | dscale_off_addr = ptr[qtable + 2 * vlen]; |
419 | dshift_off_addr = ptr[qtable + 3 * vlen]; |
420 | ymm_perm_mask_addr = ptr[qtable + 4 * vlen]; |
421 | zmm_perm_mask_addr |
422 | = ptr[qtable + 4 * vlen + cpu_isa_traits<avx>::vlen]; |
423 | break; |
424 | } |
425 | case data_type::f32: { |
426 | break; |
427 | } |
428 | default: assert(!"not supported" ); |
429 | } |
430 | } |
431 | |
432 | void init_regs(size_t vlen, size_t tail_elements = 0) { |
433 | assert(pd_->weights_md()->data_type != data_type::s8); |
434 | return init_regs(nullptr, vlen, tail_elements); |
435 | }; |
436 | |
437 | void init_table(size_t vlen) { |
438 | if (pd_->weights_md()->data_type != data_type::s8) return; |
        /* int8 (de)quantization init */
440 | const primitive_attr_t *attr = pd_->attr(); |
441 | const float data_scale = attr->rnn_data_qparams_.scale_; |
442 | const float data_shift = attr->rnn_data_qparams_.shift_; |
443 | |
444 | L(qlabel); |
445 | { |
446 | for (size_t i = 0; i < vlen / sizeof(float); i++) |
447 | dd(float2int(0.0f)); |
448 | for (size_t i = 0; i < vlen / sizeof(float); i++) |
449 | dd(float2int(255.0f)); |
450 | for (size_t i = 0; i < vlen / sizeof(float); i++) |
451 | dd(float2int(data_scale)); |
452 | for (size_t i = 0; i < vlen / sizeof(float); i++) |
453 | dd(float2int(data_shift)); |
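            // only the leading indices of the permute masks below are
            // significant: after packing, the quantized bytes sit in the
            // low dword of each 128-bit lane, and q_d() stores only the
            // low 16 (zmm) or 8 (ymm) bytes after the permute, so the
            // remaining entries are don't-care values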
454 | // perm mask for ymm |
455 | dd(0); |
456 | dd(4); |
457 | dd(2); |
458 | dd(3); |
459 | dd(1); |
460 | dd(5); |
461 | dd(6); |
462 | dd(7); |
463 | // perm mask for zmm |
464 | dd(0); |
465 | dd(4); |
466 | dd(8); |
467 | dd(12); |
468 | dd(1); |
469 | dd(5); |
470 | dd(6); |
471 | dd(7); |
472 | dd(2); |
473 | dd(9); |
474 | dd(10); |
475 | dd(11); |
476 | dd(3); |
477 | dd(12); |
478 | dd(13); |
479 | dd(14); |
480 | } |
481 | } |
482 | |
483 | void inc_regs(int mask, size_t vlen) { |
484 | if (pd_->weights_md()->data_type == data_type::s8) { |
485 | if (mask != 0) add(weights_scales_reg, vlen); |
486 | } |
487 | } |
488 | void inc_regs(size_t vlen) { |
489 | assert(pd_->weights_md()->data_type != data_type::s8); |
490 | inc_regs(0, vlen); |
491 | } |
492 | |
493 | #ifdef DNNL_ENABLE_FAST_RCP |
494 | template <typename Vmm> |
495 | void fast_recip(Vmm s, Vmm tmp, int vlen_bytes) { |
        if (can_do_zmm_masked_tail_processing(s, vlen_bytes)) {
            Xbyak::Zmm tmp_masked
                    = Xbyak::Zmm(tmp.getIdx()) | zmm_tail_k_mask | T_z;
            uni_vrcpps(tmp_masked, Xbyak::Zmm(s.getIdx()));
500 | } else if (vlen_bytes == (int)s.getBit() / 8) { |
501 | // no tail processing |
502 | uni_vrcpps(tmp, s); |
503 | } else if (4 == vlen_bytes) { |
504 | // special case for scalar-based tail processing to prevent divide by zero |
505 | uni_vrcpss(tmp, s); |
506 | } else |
507 | assert(!"unsupported case" ); |
508 | |
        // apply one Newton-Raphson refinement step: with tmp ~= 1/s, the
        // update 2 * tmp - s * tmp^2 = tmp * (2 - s * tmp) roughly doubles
        // the number of accurate bits of the estimate
510 | uni_vmulps(s, s, tmp); |
511 | uni_vmulps(s, s, tmp); // s <- s * tmp^2 |
512 | uni_vaddps(tmp, tmp, tmp); |
513 | uni_vsubps(tmp, tmp, s); |
514 | uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2 |
515 | } |
516 | #endif |
517 | |
    // quantize from float to u8/s8
519 | // Assumption: write_only = true assumes that the quantized value |
520 | // to write is in src |
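    // the quantization applied is q = x * data_scale + data_shift, clamped
    // to [0, 255] as float before the integer conversion; the final pack
    // saturates to the u8/s8 destination type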
521 | template <typename Vmm> |
522 | void q_d(data_type_t src_data_t, Xbyak::Address dst, Vmm src, int in_len, |
523 | bool write_only = false) { |
524 | Vmm qd_vmm(qd_reg_idx); |
525 | if (!write_only) { |
526 | uni_vpxor(qd_vmm, qd_vmm, qd_vmm); |
527 | uni_vmulps(src, src, dscale_off_addr); // apply scale |
528 | uni_vaddps(src, src, dshift_off_addr); // apply shift |
529 | // To saturate properly, we use min/max on the float value |
530 | uni_vmaxps(src, src, zero_addr); |
531 | uni_vminps(src, src, u8_saturation_addr); |
532 | uni_vcvtps2dq(src, src); // convert to int32 |
533 | uni_vpackssdw(src, src, qd_vmm); // convert from s32 to s16 |
534 | // convert from s16 to u8/s8 with saturation |
535 | if (src_data_t == data_type::u8) |
536 | uni_vpackuswb(src, src, qd_vmm); |
537 | else |
538 | uni_vpacksswb(src, src, qd_vmm); |
539 | } |
540 | |
541 | if (can_do_zmm_masked_tail_processing(src, in_len)) { |
542 | if (!write_only) { |
543 | Xbyak::Zmm srcz(src.getIdx()), tmpz(qd_vmm.getIdx()); |
544 | uni_vmovups(tmpz, zmm_perm_mask_addr); |
545 | vpermd(srcz, tmpz, srcz); |
546 | } |
547 | |
548 | Xbyak::Zmm src_masked = Xbyak::Zmm(src.getIdx()) | zmm_tail_k_mask; |
549 | vmovdqu8(dst, src_masked); |
550 | return; |
551 | } |
552 | |
        // Note that the pack results are interleaved in 128-bit lanes, so
        // we need to merge them together
554 | switch (in_len) { |
555 | case 64: { // Intel AVX-512 |
556 | if (!write_only) { |
557 | Xbyak::Zmm srcz(src.getIdx()), tmpz(qd_vmm.getIdx()); |
558 | uni_vmovups(tmpz, zmm_perm_mask_addr); |
559 | vpermd(srcz, tmpz, srcz); |
560 | } |
561 | uni_vmovups(dst, Xbyak::Xmm(src.getIdx())); |
562 | break; |
563 | } |
564 | case 32: { // Intel AVX |
565 | if (!write_only) { |
566 | Xbyak::Ymm srcy(src.getIdx()), tmpy(qd_vmm.getIdx()); |
567 | uni_vmovups(tmpy, ymm_perm_mask_addr); |
568 | vpermd(srcy, tmpy, srcy); |
569 | } |
570 | uni_vmovsd(dst, Xbyak::Xmm(src.getIdx())); |
571 | break; |
572 | } |
            case 16: // sse: no cross-lane permute needed
574 | uni_vmovss(dst, Xbyak::Xmm(src.getIdx())); |
575 | break; |
576 | case 4: uni_vpextrb(dst, Xbyak::Xmm(src.getIdx()), 0x0); break; |
577 | |
578 | default: assert(!"unsupported case" ); |
579 | }; |
580 | } |
581 | |
582 | // dequantize from s32 to float |
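    // computes s = (s - comp) / (weights_scale * data_scale), i.e. the
    // inverse of the combined weights and data quantization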
583 | template <typename Vmm> |
584 | void deq_w(data_type_t src_data_t, Vmm s, Vmm tmp1, Vmm tmp2, |
585 | dim_t scale_off, int mask, int vlen_bytes, |
586 | Xbyak::Reg64 *comp = nullptr) { |
587 | // nothing to do if not int8 |
588 | if (!utils::one_of(src_data_t, data_type::u8, data_type::s8)) return; |
589 | |
590 | size_t qscale_dt_size = sizeof(float); |
591 | |
592 | // TODO: if mask is 0 precompute mul and inverse |
593 | if (mask == 0) |
594 | uni_vbroadcastss(tmp1, ptr[weights_scales_reg]); |
595 | else { |
596 | auto scales_ptr |
597 | = ptr[weights_scales_reg + scale_off * qscale_dt_size]; |
598 | load(tmp1, scales_ptr, data_type::f32, vlen_bytes); |
599 | } |
600 | uni_vcvtdq2ps(s, s); |
        // subtract the compensation term if one is provided
602 | if (comp) { uni_vsubps(s, s, ptr[*comp]); } |
603 | uni_vmulps(tmp1, tmp1, dscale_off_addr); |
604 | #ifdef DNNL_ENABLE_FAST_RCP |
605 | fast_recip(tmp1, tmp2, vlen_bytes); |
606 | uni_vmulps(s, s, tmp1); |
607 | #else |
608 | if (can_do_zmm_masked_tail_processing(s, vlen_bytes)) { |
609 | Xbyak::Zmm s_masked |
610 | = Xbyak::Zmm(s.getIdx()) | zmm_tail_k_mask | T_z; |
611 | uni_vdivps(s_masked, s, tmp1); |
612 | } else |
613 | uni_vdivps(s, s, tmp1); |
614 | #endif |
615 | } |
616 | |
617 | // dequantize from u8 to float |
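    // inverse of the data quantization applied in q_d():
    // f = (q - data_shift) / data_scale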
618 | template <typename Vmm> |
619 | void deq_h(Vmm dst, Xbyak::Address src, int in_len) { |
620 | if (can_do_zmm_masked_tail_processing(dst, in_len)) { |
621 | Xbyak::Zmm dst_masked |
622 | = Xbyak::Zmm(dst.getIdx()) | zmm_tail_k_mask | T_z; |
623 | uni_vpmovzxbd(dst_masked, src); |
624 | } else if (4 == in_len) { |
625 | // special case for scalar-based tail processing |
626 | Xbyak::Xmm dst_xmm = Xbyak::Xmm(dst.getIdx()); |
627 | uni_vpinsrb(dst_xmm, dst_xmm, src, 0x0); |
628 | uni_vpmovzxbd(dst_xmm, dst_xmm); |
629 | } else if (in_len == (int)dst.getBit() / 8) { |
630 | // no tail processing |
631 | uni_vpmovzxbd(dst, src); |
632 | } else { |
633 | assert(!"unsupported case" ); |
634 | } |
635 | uni_vcvtdq2ps(dst, dst); |
636 | uni_vsubps(dst, dst, dshift_off_addr); |
637 | uni_vdivps(dst, dst, dscale_off_addr); |
638 | } |
639 | |
640 | // upconvert from bf16 to float |
641 | template <typename Vmm> |
642 | void bf16_uc(Vmm dst, Xbyak::Address src, int in_len) { |
643 | switch (in_len) { |
644 | case 64: vpmovzxwd(dst, src); break; |
645 | case 4: vpmovzxwd(dst | bf16_k_mask | T_z, src); break; |
646 | default: |
647 | assert(is_zmm_mask_initialized); |
648 | vpmovzxwd(dst | zmm_tail_k_mask | T_z, src); |
649 | } |
650 | |
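        // a bf16 value is the upper 16 bits of the corresponding f32 bit
        // pattern, so shifting the zero-extended words left by 16 completes
        // the upconversion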
651 | vpslld(dst, dst, 0x10); |
652 | } |
653 | |
    // downconvert from float to bf16
    // Assumption: write_only = true assumes that the downconverted result
    // is still in bf16_dq_reg_idx and only needs to be stored again
658 | template <typename Vmm> |
659 | void bf16_dc( |
660 | Xbyak::Address dst, Vmm src, int in_len, bool write_only = false) { |
661 | Xbyak::Zmm srcz(src.getIdx()); |
662 | Xbyak::Ymm bf16_reg_dc(bf16_dq_reg_idx); |
663 | if (!write_only) { |
664 | if (bf16_emu_) |
665 | bf16_emu_->vcvtneps2bf16(bf16_reg_dc, srcz); |
666 | else |
667 | vcvtneps2bf16(bf16_reg_dc, srcz); |
668 | } |
669 | switch (in_len) { |
670 | case 64: uni_vmovups(dst, bf16_reg_dc); break; |
671 | case 4: |
672 | uni_vpextrw(dst, Xbyak::Xmm(bf16_reg_dc.getIdx()), 0x0); |
673 | break; |
674 | default: |
675 | assert(is_zmm_mask_initialized); |
676 | vmovdqu16(dst, Xbyak::Zmm(bf16_dq_reg_idx) | zmm_tail_k_mask); |
677 | } |
678 | } |
679 | |
680 | // handles quantization/conversion and write to memory |
681 | // Note: values in src register might be modified |
682 | // Assumption: write_only = true assumes that |
683 | // 1. to_src was already called with the same source and with |
684 | // write_only = false. |
685 | // 2. the src register and the temporary registers for |
    // quantization/downconvert were not overwritten in between the two
687 | // calls |
688 | template <typename Vmm> |
689 | void to_src(const Xbyak::Address &dst, const Vmm &src, data_type_t src_dt, |
690 | int in_len, bool write_only = false) { |
691 | switch (src_dt) { |
692 | case data_type::f32: store(dst, src, src_dt, in_len); break; |
693 | case data_type::bf16: bf16_dc(dst, src, in_len, write_only); break; |
694 | case data_type::u8: |
695 | case data_type::s8: |
696 | q_d(src_dt, dst, src, in_len, write_only); |
697 | break; |
698 | default: assert(!"unsupported" ); |
699 | } |
700 | } |
701 | |
702 | template <typename Vmm> |
703 | void to_float(const Vmm &dst, const Xbyak::Address &src, data_type_t src_dt, |
704 | int in_len) { |
705 | switch (src_dt) { |
706 | case data_type::f32: load(dst, src, src_dt, in_len); break; |
707 | case data_type::bf16: bf16_uc(dst, src, in_len); break; |
708 | case data_type::u8: |
709 | case data_type::s8: deq_h(dst, src, in_len); break; |
710 | default: assert(!"unsupported" ); |
711 | } |
712 | } |
713 | |
714 | template <typename Vmm> |
715 | void load(const Vmm &dst, const Xbyak::Address &src, data_type_t dt, |
716 | int vlen_bytes) { |
717 | if (can_do_zmm_masked_tail_processing(dst, vlen_bytes)) { |
718 | load_zmm_masked(dst, src, dt); |
719 | return; |
720 | } |
721 | |
722 | if (((int)dst.getBit() / 8) == vlen_bytes) |
723 | uni_vmovups(dst, src); |
724 | else if (4 == vlen_bytes) |
725 | // special case for scalar-based tail processing |
726 | uni_vmovss(dst, src); |
727 | else |
728 | assert(!"unsupported case" ); |
729 | } |
730 | |
731 | template <typename Vmm> |
732 | void compute_vaddps( |
733 | const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) { |
734 | if (vlen_bytes == 4) |
735 | // special case for scalar-based tail processing |
736 | uni_vaddss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()), |
737 | Xbyak::Xmm(v3.getIdx())); |
738 | else |
739 | uni_vaddps(v1, v2, v3); |
740 | } |
741 | |
742 | template <typename Vmm> |
743 | void compute_vsubps( |
744 | const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) { |
745 | if (vlen_bytes == 4) |
746 | // special case for scalar-based tail processing |
747 | uni_vsubss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()), |
748 | Xbyak::Xmm(v3.getIdx())); |
749 | else |
750 | uni_vsubps(v1, v2, v3); |
751 | } |
752 | |
753 | template <typename Vmm> |
754 | void compute_vsubps(const Vmm &v1, const Vmm &v2, const Vmm &v3, |
755 | const Vmm &buf, int vlen_bytes) { |
756 | if (vlen_bytes == 4) |
757 | // special case for scalar-based tail processing |
758 | uni_vsubss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()), |
759 | Xbyak::Xmm(v3.getIdx()), Xbyak::Xmm(buf.getIdx())); |
760 | else |
761 | uni_vsubps(v1, v2, v3, buf); |
762 | } |
763 | |
764 | template <typename Vmm> |
765 | void compute_vmulps( |
766 | const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) { |
767 | if (vlen_bytes == 4) |
768 | // special case for scalar-based tail processing |
769 | uni_vmulss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()), |
770 | Xbyak::Xmm(v3.getIdx())); |
771 | else |
772 | uni_vmulps(v1, v2, v3); |
773 | } |
774 | |
775 | template <typename Vmm> |
776 | void compute_vmulps(const Vmm &v1, const Vmm &v2, const Vmm &v3, |
777 | const Vmm &buf, int vlen_bytes) { |
778 | if (vlen_bytes == 4) |
779 | // special case for scalar-based tail processing |
780 | uni_vmulss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()), |
781 | Xbyak::Xmm(v3.getIdx()), Xbyak::Xmm(buf.getIdx())); |
782 | else |
783 | uni_vmulps(v1, v2, v3, buf); |
784 | } |
785 | |
786 | template <typename Vmm> |
787 | void compute_vfmadd231ps( |
788 | const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) { |
789 | if (vlen_bytes == 4) |
790 | // special case for scalar-based tail processing |
791 | uni_vfmadd231ss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()), |
792 | Xbyak::Xmm(v3.getIdx())); |
793 | else |
794 | uni_vfmadd231ps(v1, v2, v3); |
795 | } |
796 | |
797 | template <typename Vmm> |
798 | void compute_vfmadd231ps(const Vmm &v1, const Vmm &v2, |
799 | const Xbyak::Address &addr, int vlen_bytes, |
800 | /* required for isa below avx2 only */ |
801 | const Vmm &tmp_vmm_for_address_load) { |
802 | if (!is_avx2) { |
            // to avoid issues with memory that is not 16-byte aligned on
            // sse4.1, and to avoid overwriting the values of v2 on avx,
            // load the values from memory into the provided
            // tmp_vmm_for_address_load and use the variant with vmm
            // arguments only
806 | load(tmp_vmm_for_address_load, addr, data_type::f32, vlen_bytes); |
807 | compute_vfmadd231ps(v1, tmp_vmm_for_address_load, v2, vlen_bytes); |
808 | return; |
809 | } |
810 | |
811 | if (can_do_zmm_masked_tail_processing(v1, vlen_bytes)) { |
812 | Xbyak::Zmm dst_masked |
813 | = Xbyak::Zmm(v1.getIdx()) | zmm_tail_k_mask | T_z; |
814 | uni_vfmadd231ps(dst_masked, Xbyak::Zmm(v2.getIdx()), addr); |
815 | return; |
816 | } |
817 | |
818 | if (vlen_bytes == 4) |
819 | // special case for scalar-based tail processing |
820 | uni_vfmadd231ss( |
821 | Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()), addr); |
822 | else |
823 | uni_vfmadd231ps(v1, v2, addr); |
824 | } |
825 | |
826 | template <typename Vmm> |
827 | void compute_vfmadd213ps( |
828 | const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) { |
829 | if (vlen_bytes == 4) |
830 | // special case for scalar-based tail processing |
831 | uni_vfmadd213ss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()), |
832 | Xbyak::Xmm(v3.getIdx())); |
833 | else |
834 | uni_vfmadd213ps(v1, v2, v3); |
835 | } |
836 | |
837 | template <typename Vmm> |
838 | void store(const Xbyak::Address &dst, const Vmm &src, data_type_t dt, |
839 | int vlen_bytes) { |
840 | if (can_do_zmm_masked_tail_processing(src, vlen_bytes)) { |
841 | store_zmm_masked(dst, src, dt); |
842 | return; |
843 | } |
844 | |
845 | MAYBE_UNUSED(dt); |
846 | if (((int)src.getBit() / 8) == vlen_bytes) |
847 | uni_vmovups(dst, src); |
848 | else if (4 == vlen_bytes) |
849 | // special case for scalar-based tail processing |
850 | uni_vmovss(dst, src); |
851 | else |
852 | assert(!"unsupported case" ); |
853 | } |
854 | |
855 | const rnn_utils::rnn_conf_t &rnn_; |
856 | const rnn_pd_t *pd_; |
857 | bool projection_; |
858 | bf16_emulation_t *bf16_emu_ = nullptr; |
859 | const size_t bias_dt_size_; |
860 | const size_t cstate_dt_size_; |
861 | const bool is_avx512; |
862 | const bool is_avx2; |
863 | |
864 | private: |
865 | // registers/Labels used for int8 quantization and conversions |
866 | Xbyak::Address dscale_off_addr; |
867 | Xbyak::Address dshift_off_addr; |
868 | Xbyak::Address ymm_perm_mask_addr; |
869 | Xbyak::Address zmm_perm_mask_addr; |
870 | Xbyak::Address zero_addr; |
871 | Xbyak::Address u8_saturation_addr; |
872 | Xbyak::Reg64 weights_scales_reg; |
873 | Xbyak::Reg64 qtable; |
874 | Xbyak::Label qlabel; |
875 | int tmp_vector_register_idx; |
876 | int qd_reg_idx; |
877 | |
878 | // registers used for bf16 conversions |
879 | Xbyak::Zmm bf16_reg1; |
880 | Xbyak::Zmm bf16_reg2; |
881 | Xbyak::Zmm bf16_reg3; |
882 | Xbyak::Reg64 bf16_reg4; |
883 | Xbyak::Zmm bf16_reg5; |
884 | Xbyak::Reg64 bf16_reg_mask; |
885 | Xbyak::Opmask bf16_k_mask; |
886 | Xbyak::Reg64 tmp_reg; |
887 | Xbyak::Opmask zmm_tail_k_mask; |
888 | |
889 | int bf16_dq_reg_idx; |
890 | bool is_zmm_mask_initialized = false; |
891 | |
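    // the precomputed zmm tail k-mask applies only when the mask has been
    // initialized and the register is a full zmm with fewer bytes left to
    // process than its width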
892 | template <typename Vmm> |
893 | bool can_do_zmm_masked_tail_processing(Vmm vmm_reg, int in_len_bytes) { |
894 | const int vmm_bytes = vmm_reg.getBit() / 8; |
895 | return is_zmm_mask_initialized && vmm_bytes == 64 |
896 | && in_len_bytes < vmm_bytes; |
897 | } |
898 | |
899 | template <typename Vmm> |
900 | void load_zmm_masked( |
901 | const Vmm &dst, const Xbyak::Address &src, data_type_t dt) { |
902 | Xbyak::Zmm dst_masked |
903 | = Xbyak::Zmm(dst.getIdx()) | zmm_tail_k_mask | T_z; |
904 | switch (dt) { |
905 | case data_type::bf16: vmovdqu16(dst_masked, src); break; |
906 | case data_type::s8: |
907 | case data_type::u8: vmovdqu8(dst_masked, src); break; |
908 | default: vmovups(dst_masked, src); |
909 | } |
910 | } |
911 | |
912 | template <typename Vmm> |
913 | void store_zmm_masked( |
914 | const Xbyak::Address &dst, const Vmm &src, data_type_t dt) { |
915 | const Xbyak::Zmm src_masked |
916 | = Xbyak::Zmm(src.getIdx()) | zmm_tail_k_mask; |
917 | switch (dt) { |
918 | case data_type::bf16: vmovdqu16(dst, src_masked); break; |
919 | case data_type::s8: |
920 | case data_type::u8: vmovdqu8(dst, src_masked); break; |
921 | default: vmovups(dst, src_masked); |
922 | } |
923 | } |
924 | }; |
925 | |
926 | } // namespace x64 |
927 | } // namespace cpu |
928 | } // namespace impl |
929 | } // namespace dnnl |
930 | |
931 | #endif |
932 | |