1/*******************************************************************************
2* Copyright 2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_RNN_JIT_DIFF_WEIGHTS_PEEPHOLE_HPP
18#define CPU_X64_RNN_JIT_DIFF_WEIGHTS_PEEPHOLE_HPP
19
20#include "cpu/x64/jit_generator.hpp"
21#include "cpu/x64/utils/jit_io_helper.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
namespace rnn_utils {
// Forward declaration only: this header uses rnn_conf_t solely by reference
// (constructor parameter), so the full definition is not needed here.
struct rnn_conf_t;
} // namespace rnn_utils
29namespace x64 {
30
31class jit_diff_weights_peephole_t : public jit_generator {
32public:
33 jit_diff_weights_peephole_t(
34 const rnn_utils::rnn_conf_t &rnn, const dim_t dhc_block);
35
36 struct call_params_t {
37 const void *c_states = nullptr;
38 const void *scratch_gates = nullptr;
39 void *dst = nullptr;
40 };
41
42 void operator()(jit_diff_weights_peephole_t::call_params_t *params) const {
43 jit_generator::operator()(params);
44 }
45
46private:
47 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_diff_weights_peephole_t);
48 DNNL_DISALLOW_COPY_AND_ASSIGN(jit_diff_weights_peephole_t);
49
50 void generate() override;
51 void load_addresses();
52 void init();
53 void compute_loop();
54 void compute_dst(size_t unrolling, bool tail);
55
56 static constexpr dim_t simd_w_ = 16;
57 static constexpr dim_t max_unrolling = 10;
58
59 const data_type_t c_states_dt_;
60 const data_type_t scratch_dt_;
61 const data_type_t dst_dt_;
62
63 const Xbyak::Reg64 &loop_cnt_ = rax;
64 const Xbyak::Reg64 &reg_c_states_ = r8;
65 const Xbyak::Reg64 &reg_scratch_gates_ = r9;
66 const Xbyak::Reg64 &reg_dst_ = r10;
67 const Xbyak::Reg64 &reg_tmp_ = r11;
68 const Xbyak::Reg64 &reg_offset_ = r12;
69
70 const Xbyak::Opmask &tail_opmask_ = k3;
71
72 const dim_t compute_block_size_;
73 const dim_t tail_size_;
74
75 io::jit_io_multi_dt_helper_t<Xbyak::Zmm> io_;
76};
77
78} // namespace x64
79} // namespace cpu
80} // namespace impl
81} // namespace dnnl
82
83#endif
84