1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/rnn/jit_gates_reduction.hpp" |
18 | |
19 | #include <cmath> |
20 | #include "cpu/rnn/rnn_utils.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | |
jit_gates_reduction_t::jit_gates_reduction_t(
        const rnn_utils::rnn_conf_t &rnn, bool is_n_tail)
    : jit_generator(jit_name())
    , rnn_(rnn)
    , is_n_tail_(is_n_tail)
    // Number of output columns this kernel instance covers: either a full
    // n_block or the leftover n_tail block of the diff-weights brgemm tiling.
    , n_block_(is_n_tail_ ? rnn_.diff_wei_brgemm.n_tail
                          : rnn_.diff_wei_brgemm.n_block)
    // Split n_block_ into whole simd-width chunks plus a sub-simd remainder
    // (the remainder is handled with a k-mask).
    , n_simd_w_blks_(n_block_ / simd_w_)
    , n_tail_(n_block_ % simd_w_)
    // bf16 configs need a vmm holding broadcast bf16(1.0) for vdpbf16ps;
    // reserve it first. For f32 the index 0 value is an unused placeholder.
    , bf16_ones_(rnn_.is_bf16_conf() ? reserve_vmm() : 0)
    // Remaining vmms become accumulators, one per simd block (+ tail).
    , acc_regs_(reserve_acc_regs()) {}
38 | |
// Emit the full kernel: dst[0:n_block] += sum over k of src[k][0:n_block].
void jit_gates_reduction_t::generate() {
    preamble();
    // Fetch src/dst pointers from the call_params_t argument struct.
    load_addresses();
    // Seed accumulators from dst, build the tail mask, set up bf16 constants.
    init();
    // Reduce over the k dimension into the accumulator registers.
    compute_loop();
    // Write the accumulators back to dst.
    store_data();
    postamble();
}
47 | |
48 | #define PARAM_OFF(x) offsetof(jit_gates_reduction_t::call_params_t, x) |
49 | |
50 | size_t jit_gates_reduction_t::reserve_vmm() { |
51 | return number_reserved_vmms_++; |
52 | } |
53 | |
54 | std::vector<Xbyak::Zmm> jit_gates_reduction_t::reserve_acc_regs() { |
55 | std::vector<Xbyak::Zmm> acc_regs; |
56 | acc_regs.reserve(n_simd_w_blks_ + n_tail_); |
57 | |
58 | for (int i = 0; i < n_simd_w_blks_; ++i) |
59 | acc_regs.emplace_back(Xbyak::Zmm(reserve_vmm())); |
60 | |
61 | if (n_tail_) acc_regs.emplace_back(Xbyak::Zmm(reserve_vmm())); |
62 | |
63 | return acc_regs; |
64 | } |
65 | |
// Load the src and dst buffer pointers from the call_params_t struct whose
// address the JIT ABI passes in abi_param1.
void jit_gates_reduction_t::load_addresses() {
    mov(reg_src_, ptr[abi_param1 + PARAM_OFF(src)]);
    mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst)]);
}
70 | |
// Prepare accumulators and constants before the reduction loop.
void jit_gates_reduction_t::init() {
    // dst is always f32, so accumulator offsets advance in f32 simd chunks.
    static constexpr auto off_step = simd_w_ * sizeof(float);

    // Seed the accumulators with the current dst contents so the kernel
    // performs dst += reduce(src) rather than overwriting dst.
    for (int i = 0; i < n_simd_w_blks_; ++i)
        uni_vmovups(acc_regs_[i], ptr[reg_dst_ + (i * off_step)]);

    if (n_tail_) {
        // Build a k-mask with the low n_tail_ bits set, used for all partial
        // loads/stores of the tail block.
        const int mask_f32 = (1 << n_tail_) - 1;
        const Xbyak::Reg32 regw_tmp = reg_tmp_.cvt32();
        mov(regw_tmp, mask_f32);
        kmovw(tail_mask_, regw_tmp);

        // Masked load with zeroing (T_z) so lanes past the tail start at 0.
        uni_vmovups(acc_regs_.back() | tail_mask_ | T_z,
                ptr[reg_dst_ + (n_simd_w_blks_ * off_step)]);
    }

    if (rnn_.is_bf16_conf()) {
        // Broadcast bf16(1.0) across bf16_ones_. vdpbf16ps(acc, ones, src)
        // then multiplies src element pairs by 1.0 and adds them into the f32
        // accumulator, i.e. it reduces two adjacent bf16 k-values at once.
        xor_(reg_tmp_, reg_tmp_);
        mov(reg_tmp_.cvt16(), bfloat16_t(1.0f).raw_bits_);
        const Xbyak::Xmm xmm_tmp(bf16_ones_.getIdx());
        vmovd(xmm_tmp, reg_tmp_.cvt32());
        vpbroadcastw(bf16_ones_, xmm_tmp);
    }
}
95 | |
// Emit one accumulation step: acc += reduce(src at addr).
//
// bf16: vdpbf16ps with the broadcast 1.0 constants sums pairs of adjacent
// bf16 values from addr into the f32 accumulator. f32: a plain vector add.
// When `tail` is set the write is masked so only valid tail lanes change.
void jit_gates_reduction_t::compute_step(
        const Xbyak::Zmm &acc, const Xbyak::Address &addr, bool tail) {

    const auto dst = tail ? (acc | tail_mask_) : acc;

    if (rnn_.is_bf16_conf())
        vdpbf16ps(dst, bf16_ones_, addr);
    else
        uni_vaddps(dst, acc, addr);
}
106 | |
// Emit `unrolling` reduction steps over the k dimension.
//
// reg_loop_ is a byte counter pointing just past the data still to process,
// so rows are addressed with negative offsets from reg_src_ + reg_loop_.
// Each step consumes one logical row of n_block f32 values; for bf16 a step
// consumes two bf16 rows at once (via vdpbf16ps), which is why the per-step
// stride is n_block * sizeof(float) bytes for both data types.
void jit_gates_reduction_t::compute(dim_t unrolling) {

    const int n_block_off = rnn_.diff_wei_brgemm.n_block * sizeof(float);

    for (dim_t k = 0; k < unrolling; ++k) {
        // Walk backwards: -(k + 1) row strides from the loop pointer.
        const int k_offset = -1 * (k + 1) * n_block_off;
        const int first_reversed_block = acc_regs_.size() - 1;

        // Visit n blocks from last to first; the last accumulator is the
        // masked tail block when n_tail_ is non-zero.
        for (int n_block = first_reversed_block; n_block >= 0; --n_block) {
            const bool tail = static_cast<bool>(n_tail_)
                    && n_block == first_reversed_block;
            const auto &acc_zmm = acc_regs_[n_block];
            const int nk_offset = k_offset + n_block * simd_w_ * sizeof(float);
            compute_step(acc_zmm, ptr[reg_src_ + reg_loop_ + nk_offset], tail);
        }
    }
}
124 | |
// Emit the main loop over the k dimension.
//
// reg_loop_ counts remaining bytes, starting at k * n_block_off and
// decreasing by one unrolled chunk per iteration; compute() addresses src
// backwards from reg_src_ + reg_loop_. Full chunks of k_block rows go through
// the unrolled loop; the remaining k_tail rows are emitted once at the end.
void jit_gates_reduction_t::compute_loop() {
    const dim_t k_block = 32;
    // bf16 packs two k-values into each vdpbf16ps step.
    const dim_t k_pack = rnn_.is_bf16_conf() ? 2 : 1;
    const dim_t k = rnn_.diff_wei_brgemm.Kpadded;
    const auto res = std::div(k, k_block);
    // Row stride in bytes of the src layout (element size depends on dtype).
    const int n_block_off = rnn_.diff_wei_brgemm.n_block
            * (rnn_.is_bf16_conf() ? sizeof(bfloat16_t) : sizeof(float));
    const auto &num_k_blks = res.quot;
    const auto &k_tail = res.rem;

    Xbyak::Label unroll_loop, unroll_loop_tail, end;

    mov(reg_loop_, k * n_block_off);

    const dim_t tail_bytes = k_tail * n_block_off;
    const dim_t block_bytes = k_block * n_block_off;

    L(unroll_loop);
    {
        if (num_k_blks) {
            // Leave the unrolled loop once only the tail bytes remain.
            cmp(reg_loop_, tail_bytes);
            jle(unroll_loop_tail, T_NEAR);
            compute(k_block / k_pack);

            sub(reg_loop_, block_bytes);
            jmp(unroll_loop);
        }
    }

    L(unroll_loop_tail);
    {
        // Handle the final k_tail rows (k_tail < k_block), if any.
        if (tail_bytes) { compute(res.rem / k_pack); }
    }

    L(end);
}
161 | |
// Write the accumulators back to dst, mirroring the layout loaded in init().
void jit_gates_reduction_t::store_data() {
    // dst is always f32.
    static constexpr auto off_step = simd_w_ * sizeof(float);

    for (int i = 0; i < n_simd_w_blks_; ++i)
        uni_vmovups(ptr[reg_dst_ + (i * off_step)], acc_regs_[i]);

    // Masked store covers only the valid sub-simd tail lanes.
    if (n_tail_)
        uni_vmovups(ptr[reg_dst_ + (n_simd_w_blks_ * off_step)] | tail_mask_,
                acc_regs_.back());
}
172 | |
173 | #undef PARAM_OFF |
174 | |
175 | } // namespace x64 |
176 | } // namespace cpu |
177 | } // namespace impl |
178 | } // namespace dnnl |
179 | |