1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/x64/rnn/jit_gates_reduction.hpp"
18
19#include <cmath>
20#include "cpu/rnn/rnn_utils.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace cpu {
25namespace x64 {
26
// Kernel reducing the gates over the k dimension into dst, in blocks of
// n columns. A separate kernel instance handles the n-tail.
jit_gates_reduction_t::jit_gates_reduction_t(
        const rnn_utils::rnn_conf_t &rnn, bool is_n_tail)
    : jit_generator(jit_name())
    , rnn_(rnn)
    , is_n_tail_(is_n_tail)
    // Tail kernels cover the leftover n columns, full kernels a whole block.
    , n_block_(is_n_tail_ ? rnn_.diff_wei_brgemm.n_tail
                          : rnn_.diff_wei_brgemm.n_block)
    , n_simd_w_blks_(n_block_ / simd_w_)
    , n_tail_(n_block_ % simd_w_)
    // bf16 configs reserve one vmm for a broadcast bf16 1.0 (see init());
    // for non-bf16 configs the index-0 placeholder is never used as such.
    , bf16_ones_(rnn_.is_bf16_conf() ? reserve_vmm() : 0)
    , acc_regs_(reserve_acc_regs()) {}
38
// Top-level code emission: the order of these phases defines the kernel.
void jit_gates_reduction_t::generate() {
    preamble();
    load_addresses(); // read src/dst pointers out of call_params_t
    init(); // seed accumulators from dst, set tail mask / bf16 constant
    compute_loop(); // accumulate the reduction over k
    store_data(); // write accumulators back to dst
    postamble();
}
47
48#define PARAM_OFF(x) offsetof(jit_gates_reduction_t::call_params_t, x)
49
50size_t jit_gates_reduction_t::reserve_vmm() {
51 return number_reserved_vmms_++;
52}
53
54std::vector<Xbyak::Zmm> jit_gates_reduction_t::reserve_acc_regs() {
55 std::vector<Xbyak::Zmm> acc_regs;
56 acc_regs.reserve(n_simd_w_blks_ + n_tail_);
57
58 for (int i = 0; i < n_simd_w_blks_; ++i)
59 acc_regs.emplace_back(Xbyak::Zmm(reserve_vmm()));
60
61 if (n_tail_) acc_regs.emplace_back(Xbyak::Zmm(reserve_vmm()));
62
63 return acc_regs;
64}
65
66void jit_gates_reduction_t::load_addresses() {
67 mov(reg_src_, ptr[abi_param1 + PARAM_OFF(src)]);
68 mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst)]);
69}
70
// Seeds the accumulators with the current dst values (the reduction adds on
// top of dst), prepares the tail mask, and materializes the bf16 constant.
void jit_gates_reduction_t::init() {
    static constexpr auto off_step = simd_w_ * sizeof(float);

    for (int i = 0; i < n_simd_w_blks_; ++i)
        uni_vmovups(acc_regs_[i], ptr[reg_dst_ + (i * off_step)]);

    if (n_tail_) {
        // k-mask with the low n_tail_ bits set, selecting the valid lanes
        // of the partial simd block.
        const int mask_f32 = (1 << n_tail_) - 1;
        const Xbyak::Reg32 regw_tmp = reg_tmp_.cvt32();
        mov(regw_tmp, mask_f32);
        kmovw(tail_mask_, regw_tmp);

        // T_z zeroes the masked-out lanes of the tail accumulator.
        uni_vmovups(acc_regs_.back() | tail_mask_ | T_z,
                ptr[reg_dst_ + (n_simd_w_blks_ * off_step)]);
    }

    if (rnn_.is_bf16_conf()) {
        // Broadcast bf16 1.0 across bf16_ones_; compute_step() uses it so
        // that vdpbf16ps reduces pairs of bf16 values into the accumulator.
        xor_(reg_tmp_, reg_tmp_);
        mov(reg_tmp_.cvt16(), bfloat16_t(1.0f).raw_bits_);
        const Xbyak::Xmm xmm_tmp(bf16_ones_.getIdx());
        vmovd(xmm_tmp, reg_tmp_.cvt32());
        vpbroadcastw(bf16_ones_, xmm_tmp);
    }
}
95
// Emits one accumulation step: acc += src[addr] (masked for the tail block).
void jit_gates_reduction_t::compute_step(
        const Xbyak::Zmm &acc, const Xbyak::Address &addr, bool tail) {

    // Write under the tail mask when this accumulator covers a partial block.
    const auto dst = tail ? (acc | tail_mask_) : acc;

    if (rnn_.is_bf16_conf())
        // Dot product against a vector of bf16 ones: sums adjacent bf16
        // pairs from memory into the f32 accumulator.
        vdpbf16ps(dst, bf16_ones_, addr);
    else
        uni_vaddps(dst, acc, addr);
}
106
// Emits `unrolling` accumulation steps over all n blocks. Addressing is done
// with negative offsets from reg_src_ + reg_loop_, since reg_loop_ holds the
// number of bytes still to be processed (see compute_loop()).
void jit_gates_reduction_t::compute(dim_t unrolling) {

    // Byte stride per step. sizeof(float) also holds for bf16: each step
    // consumes a pair of bf16 rows (k_pack == 2 in compute_loop()), and
    // 2 * sizeof(bfloat16_t) == sizeof(float).
    const int n_block_off = rnn_.diff_wei_brgemm.n_block * sizeof(float);

    for (dim_t k = 0; k < unrolling; ++k) {
        const int k_offset = -1 * (k + 1) * n_block_off;
        const int first_reversed_block = acc_regs_.size() - 1;

        for (int n_block = first_reversed_block; n_block >= 0; --n_block) {
            // Only the last accumulator (if present) is a partial simd block.
            const bool tail = static_cast<bool>(n_tail_)
                    && n_block == first_reversed_block;
            const auto &acc_zmm = acc_regs_[n_block];
            const int nk_offset = k_offset + n_block * simd_w_ * sizeof(float);
            compute_step(acc_zmm, ptr[reg_src_ + reg_loop_ + nk_offset], tail);
        }
    }
}
124
125void jit_gates_reduction_t::compute_loop() {
126 const dim_t k_block = 32;
127 const dim_t k_pack = rnn_.is_bf16_conf() ? 2 : 1;
128 const dim_t k = rnn_.diff_wei_brgemm.Kpadded;
129 const auto res = std::div(k, k_block);
130 const int n_block_off = rnn_.diff_wei_brgemm.n_block
131 * (rnn_.is_bf16_conf() ? sizeof(bfloat16_t) : sizeof(float));
132 const auto &num_k_blks = res.quot;
133 const auto &k_tail = res.rem;
134
135 Xbyak::Label unroll_loop, unroll_loop_tail, end;
136
137 mov(reg_loop_, k * n_block_off);
138
139 const dim_t tail_bytes = k_tail * n_block_off;
140 const dim_t block_bytes = k_block * n_block_off;
141
142 L(unroll_loop);
143 {
144 if (num_k_blks) {
145 cmp(reg_loop_, tail_bytes);
146 jle(unroll_loop_tail, T_NEAR);
147 compute(k_block / k_pack);
148
149 sub(reg_loop_, block_bytes);
150 jmp(unroll_loop);
151 }
152 }
153
154 L(unroll_loop_tail);
155 {
156 if (tail_bytes) { compute(res.rem / k_pack); }
157 }
158
159 L(end);
160}
161
162void jit_gates_reduction_t::store_data() {
163 static constexpr auto off_step = simd_w_ * sizeof(float);
164
165 for (int i = 0; i < n_simd_w_blks_; ++i)
166 uni_vmovups(ptr[reg_dst_ + (i * off_step)], acc_regs_[i]);
167
168 if (n_tail_)
169 uni_vmovups(ptr[reg_dst_ + (n_simd_w_blks_ * off_step)] | tail_mask_,
170 acc_regs_.back());
171}
172
173#undef PARAM_OFF
174
175} // namespace x64
176} // namespace cpu
177} // namespace impl
178} // namespace dnnl
179