1/*******************************************************************************
2* Copyright 2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_RNN_JIT_GATES_REDUCTION_HPP
18#define CPU_X64_RNN_JIT_GATES_REDUCTION_HPP
19
20#include <vector>
21#include "cpu/x64/jit_generator.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
// Forward declaration only: this header refers to rnn_conf_t solely through
// const references, so the complete type is not required here.
namespace rnn_utils {
struct rnn_conf_t;
} // namespace rnn_utils
29
30namespace x64 {
31
32/*
33 * Used in gates reduction phase during backward rnn/lstm calculations.
34 * Fused into diff weights calculations. Performing diff_bias calculations.
35 *
36 * diff_bias = scratch_blocked reduction over mb
37 *
38 * Data formats
39 * scratch_blocked Oi32o(f32)/OI32o2i(bf16) (n_gates * rnn.dhc, mb)
40 * diff_bias = o(n_gates * rnn.dhc)
41 */
42class jit_gates_reduction_t : public jit_generator {
43public:
44 jit_gates_reduction_t(const rnn_utils::rnn_conf_t &rnn, bool is_n_tail);
45
46 struct call_params_t {
47 const void *src = nullptr;
48 void *dst = nullptr;
49 };
50
51 void operator()(jit_gates_reduction_t::call_params_t *params) const {
52 jit_generator::operator()(params);
53 }
54
55private:
56 std::vector<Xbyak::Zmm> reserve_acc_regs();
57 void generate() override;
58 void load_addresses();
59 void init();
60 void store_data();
61 void compute_loop();
62 void compute(dim_t unrolling);
63 void compute_step(
64 const Xbyak::Zmm &acc, const Xbyak::Address &addr, bool tail);
65 size_t reserve_vmm();
66
67 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_gates_reduction_t)
68 DNNL_DISALLOW_COPY_AND_ASSIGN(jit_gates_reduction_t);
69
70 static constexpr dim_t simd_w_ = 16;
71
72 size_t number_reserved_vmms_ = 0;
73 const rnn_utils::rnn_conf_t &rnn_;
74 const bool is_n_tail_;
75 const dim_t n_block_;
76 const dim_t n_simd_w_blks_;
77 const dim_t n_tail_;
78
79 const Xbyak::Reg64 &reg_src_ = r8;
80 const Xbyak::Reg64 &reg_dst_ = r9;
81 const Xbyak::Reg64 &reg_tmp_ = r10;
82 const Xbyak::Reg64 &reg_loop_ = r11;
83 const Xbyak::Opmask &tail_mask_ = k3;
84 const Xbyak::Zmm bf16_ones_;
85 std::vector<Xbyak::Zmm> acc_regs_;
86};
87
88} // namespace x64
89} // namespace cpu
90} // namespace impl
91} // namespace dnnl
92
93#endif
94