1/*******************************************************************************
2* Copyright 2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_RNN_JIT_BRGEMM_TRANSPOSE_SINGLE_ROW_HPP
18#define CPU_X64_RNN_JIT_BRGEMM_TRANSPOSE_SINGLE_ROW_HPP
19
20#include <vector>
21#include "cpu/x64/jit_generator.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
namespace rnn_utils {
// Forward declaration only: this header needs nothing beyond the type name.
struct rnn_conf_t;
} // namespace rnn_utils
29
30namespace x64 {
31
/*
 * Transpose generator for brgemm-based RNN, optimized for the case where the
 * number of input rows is 1.
 * In that case, for performance reasons, the number of output columns is
 * extended to 2.
 */
38class jit_brgemm_transpose_single_row_t : public jit_generator {
39public:
40 jit_brgemm_transpose_single_row_t(const int m_block);
41
42 struct call_params_t {
43 const void *src = nullptr;
44 void *dst = nullptr;
45 };
46
47 void operator()(
48 jit_brgemm_transpose_single_row_t::call_params_t *params) const {
49 jit_generator::operator()(params);
50 }
51
52private:
53 std::vector<Xbyak::Zmm> reserve_acc_regs();
54 void generate() override;
55 void load_addresses();
56 void compute_loop();
57 void compute(const dim_t unrolling, const bool is_tail);
58
59 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_transpose_single_row_t)
60 DNNL_DISALLOW_COPY_AND_ASSIGN(jit_brgemm_transpose_single_row_t);
61
62 static constexpr dim_t simd_w_ = 16;
63 static constexpr dim_t vmms_available_ = 32;
64
65 const int m_block_;
66 const int full_loop_iters_;
67 const int tail_;
68 const int k_blocks_nb_;
69
70 const Xbyak::Reg64 &reg_src_ = r8;
71 const Xbyak::Reg64 &reg_dst_ = r9;
72
73 const Xbyak::Reg64 &reg_tmp_ = r10;
74 const Xbyak::Reg64 &reg_full_loop_ = r11;
75 const Xbyak::Opmask &tail_mask_ = k1;
76};
77
78} // namespace x64
79} // namespace cpu
80} // namespace impl
81} // namespace dnnl
82
83#endif
84