1 | /******************************************************************************* |
2 | * Copyright 2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/rnn/jit_brgemm_transpose_single_row.hpp" |
18 | |
19 | #include <cmath> |
20 | #include "cpu/rnn/rnn_utils.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | |
// Kernel generator for widening a single row of m_block 16-bit elements
// into 32-bit destination lanes (see compute()).
jit_brgemm_transpose_single_row_t::jit_brgemm_transpose_single_row_t(
        const int m_block)
    : jit_generator(jit_name())
    , m_block_(m_block)
    // Number of fully-unrolled loop iterations; each iteration consumes
    // vmms_available_ vector registers worth of simd_w_ elements.
    , full_loop_iters_(m_block_ / (vmms_available_ * simd_w_))
    // Leftover elements that do not fill a whole SIMD register; handled
    // with a masked (tail) iteration.
    , tail_(m_block_ % simd_w_)
    // Total count of full simd_w_-wide blocks in m_block_.
    , k_blocks_nb_(m_block_ / simd_w_) {}
34 | |
// Emits the whole kernel: ABI prologue, argument loading, the main
// widening-copy loop, and the epilogue. The emission order is fixed.
void jit_brgemm_transpose_single_row_t::generate() {
    preamble();
    load_addresses(); // fetch src/dst pointers from call_params_t
    compute_loop(); // main loop incl. remainder and masked tail
    postamble();
}
41 | |
42 | #define PARAM_OFF(x) \ |
43 | offsetof(jit_brgemm_transpose_single_row_t::call_params_t, x) |
44 | void jit_brgemm_transpose_single_row_t::load_addresses() { |
45 | mov(reg_src_, ptr[abi_param1 + PARAM_OFF(src)]); |
46 | mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst)]); |
47 | } |
48 | #undef PARAM |
49 | |
50 | void jit_brgemm_transpose_single_row_t::compute( |
51 | const dim_t unrolling, const bool is_tail) { |
52 | |
53 | if (is_tail) { |
54 | mov(reg_tmp_.cvt32(), std::pow(2, tail_) - 1); |
55 | kmovd(tail_mask_, reg_tmp_.cvt32()); |
56 | } |
57 | |
58 | for (int k = unrolling - 1; k >= 0; k--) { |
59 | const auto read_vmm |
60 | = is_tail ? Xbyak::Zmm(k) | tail_mask_ | T_z : Xbyak::Zmm(k); |
61 | const auto src_off = k * simd_w_ * sizeof(bfloat16_t); |
62 | vpmovzxwd(read_vmm, ptr[reg_src_ + src_off]); |
63 | } |
64 | |
65 | for (int k = unrolling - 1; k >= 0; k--) { |
66 | const auto store_vmm |
67 | = is_tail ? Xbyak::Zmm(k) | tail_mask_ : Xbyak::Zmm(k); |
68 | const auto dst_off = k * simd_w_ * sizeof(float); |
69 | uni_vmovups(ptr[reg_dst_ + dst_off], store_vmm); |
70 | } |
71 | } |
72 | |
73 | void jit_brgemm_transpose_single_row_t::compute_loop() { |
74 | Xbyak::Label unroll_full_loop, loop_end; |
75 | |
76 | if (full_loop_iters_ > 0) { |
77 | const auto loop_l_off = vmms_available_ * simd_w_; |
78 | const auto loop_src_off = loop_l_off * sizeof(bfloat16_t); |
79 | const auto loop_dst_off = loop_l_off * sizeof(float); |
80 | |
81 | mov(reg_full_loop_, full_loop_iters_); |
82 | L(unroll_full_loop); |
83 | { |
84 | cmp(reg_full_loop_, 0); |
85 | je(loop_end, T_NEAR); |
86 | |
87 | compute(vmms_available_, false); |
88 | |
89 | add(reg_src_, loop_src_off); |
90 | add(reg_dst_, loop_dst_off); |
91 | |
92 | dec(reg_full_loop_); |
93 | jmp(unroll_full_loop); |
94 | } |
95 | L(loop_end); |
96 | } |
97 | |
98 | const int k_blocks_left = k_blocks_nb_ - full_loop_iters_ * vmms_available_; |
99 | if (k_blocks_left > 0) { |
100 | const auto off = k_blocks_left * simd_w_; |
101 | const auto src_off = off * sizeof(bfloat16_t); |
102 | const auto dst_off = off * sizeof(float); |
103 | |
104 | compute(k_blocks_left, false); |
105 | add(reg_src_, src_off); |
106 | add(reg_dst_, dst_off); |
107 | } |
108 | |
109 | if (tail_ > 0) compute(1, true); |
110 | } |
111 | |
112 | #undef PARAM_OFF |
113 | |
114 | } // namespace x64 |
115 | } // namespace cpu |
116 | } // namespace impl |
117 | } // namespace dnnl |
118 | |