1/*******************************************************************************
2* Copyright 2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/x64/rnn/jit_brgemm_transpose_single_row.hpp"
18
19#include <cmath>
20#include "cpu/rnn/rnn_utils.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace cpu {
25namespace x64 {
26
27jit_brgemm_transpose_single_row_t::jit_brgemm_transpose_single_row_t(
28 const int m_block)
29 : jit_generator(jit_name())
30 , m_block_(m_block)
31 , full_loop_iters_(m_block_ / (vmms_available_ * simd_w_))
32 , tail_(m_block_ % simd_w_)
33 , k_blocks_nb_(m_block_ / simd_w_) {}
34
35void jit_brgemm_transpose_single_row_t::generate() {
36 preamble();
37 load_addresses();
38 compute_loop();
39 postamble();
40}
41
42#define PARAM_OFF(x) \
43 offsetof(jit_brgemm_transpose_single_row_t::call_params_t, x)
44void jit_brgemm_transpose_single_row_t::load_addresses() {
45 mov(reg_src_, ptr[abi_param1 + PARAM_OFF(src)]);
46 mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst)]);
47}
48#undef PARAM
49
50void jit_brgemm_transpose_single_row_t::compute(
51 const dim_t unrolling, const bool is_tail) {
52
53 if (is_tail) {
54 mov(reg_tmp_.cvt32(), std::pow(2, tail_) - 1);
55 kmovd(tail_mask_, reg_tmp_.cvt32());
56 }
57
58 for (int k = unrolling - 1; k >= 0; k--) {
59 const auto read_vmm
60 = is_tail ? Xbyak::Zmm(k) | tail_mask_ | T_z : Xbyak::Zmm(k);
61 const auto src_off = k * simd_w_ * sizeof(bfloat16_t);
62 vpmovzxwd(read_vmm, ptr[reg_src_ + src_off]);
63 }
64
65 for (int k = unrolling - 1; k >= 0; k--) {
66 const auto store_vmm
67 = is_tail ? Xbyak::Zmm(k) | tail_mask_ : Xbyak::Zmm(k);
68 const auto dst_off = k * simd_w_ * sizeof(float);
69 uni_vmovups(ptr[reg_dst_ + dst_off], store_vmm);
70 }
71}
72
73void jit_brgemm_transpose_single_row_t::compute_loop() {
74 Xbyak::Label unroll_full_loop, loop_end;
75
76 if (full_loop_iters_ > 0) {
77 const auto loop_l_off = vmms_available_ * simd_w_;
78 const auto loop_src_off = loop_l_off * sizeof(bfloat16_t);
79 const auto loop_dst_off = loop_l_off * sizeof(float);
80
81 mov(reg_full_loop_, full_loop_iters_);
82 L(unroll_full_loop);
83 {
84 cmp(reg_full_loop_, 0);
85 je(loop_end, T_NEAR);
86
87 compute(vmms_available_, false);
88
89 add(reg_src_, loop_src_off);
90 add(reg_dst_, loop_dst_off);
91
92 dec(reg_full_loop_);
93 jmp(unroll_full_loop);
94 }
95 L(loop_end);
96 }
97
98 const int k_blocks_left = k_blocks_nb_ - full_loop_iters_ * vmms_available_;
99 if (k_blocks_left > 0) {
100 const auto off = k_blocks_left * simd_w_;
101 const auto src_off = off * sizeof(bfloat16_t);
102 const auto dst_off = off * sizeof(float);
103
104 compute(k_blocks_left, false);
105 add(reg_src_, src_off);
106 add(reg_dst_, dst_off);
107 }
108
109 if (tail_ > 0) compute(1, true);
110}
111
112#undef PARAM_OFF
113
114} // namespace x64
115} // namespace cpu
116} // namespace impl
117} // namespace dnnl
118