1 | /******************************************************************************* |
2 | * Copyright 2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_DIFF_WEIGHTS_PEEPHOLE_HPP |
18 | #define CPU_X64_RNN_JIT_DIFF_WEIGHTS_PEEPHOLE_HPP |
19 | |
20 | #include "cpu/x64/jit_generator.hpp" |
21 | #include "cpu/x64/utils/jit_io_helper.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
// Forward declaration only: this header needs rnn_conf_t solely as a
// reference parameter, so the complete type is not required here.
namespace rnn_utils {
struct rnn_conf_t;
} // namespace rnn_utils
29 | namespace x64 { |
30 | |
31 | class jit_diff_weights_peephole_t : public jit_generator { |
32 | public: |
33 | jit_diff_weights_peephole_t( |
34 | const rnn_utils::rnn_conf_t &rnn, const dim_t dhc_block); |
35 | |
36 | struct call_params_t { |
37 | const void *c_states = nullptr; |
38 | const void *scratch_gates = nullptr; |
39 | void *dst = nullptr; |
40 | }; |
41 | |
42 | void operator()(jit_diff_weights_peephole_t::call_params_t *params) const { |
43 | jit_generator::operator()(params); |
44 | } |
45 | |
46 | private: |
47 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_diff_weights_peephole_t); |
48 | DNNL_DISALLOW_COPY_AND_ASSIGN(jit_diff_weights_peephole_t); |
49 | |
50 | void generate() override; |
51 | void load_addresses(); |
52 | void init(); |
53 | void compute_loop(); |
54 | void compute_dst(size_t unrolling, bool tail); |
55 | |
56 | static constexpr dim_t simd_w_ = 16; |
57 | static constexpr dim_t max_unrolling = 10; |
58 | |
59 | const data_type_t c_states_dt_; |
60 | const data_type_t scratch_dt_; |
61 | const data_type_t dst_dt_; |
62 | |
63 | const Xbyak::Reg64 &loop_cnt_ = rax; |
64 | const Xbyak::Reg64 ®_c_states_ = r8; |
65 | const Xbyak::Reg64 ®_scratch_gates_ = r9; |
66 | const Xbyak::Reg64 ®_dst_ = r10; |
67 | const Xbyak::Reg64 ®_tmp_ = r11; |
68 | const Xbyak::Reg64 ®_offset_ = r12; |
69 | |
70 | const Xbyak::Opmask &tail_opmask_ = k3; |
71 | |
72 | const dim_t compute_block_size_; |
73 | const dim_t tail_size_; |
74 | |
75 | io::jit_io_multi_dt_helper_t<Xbyak::Zmm> io_; |
76 | }; |
77 | |
78 | } // namespace x64 |
79 | } // namespace cpu |
80 | } // namespace impl |
81 | } // namespace dnnl |
82 | |
83 | #endif |
84 | |