1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_BRGEMM_CELL_COMMON_BWD_HPP |
18 | #define CPU_X64_RNN_BRGEMM_CELL_COMMON_BWD_HPP |
19 | |
20 | #include "common/bfloat16.hpp" |
21 | #include "cpu/rnn/rnn_utils.hpp" |
22 | #include "cpu/x64/rnn/brgemm_cell_common_reorders.hpp" |
23 | #include "cpu/x64/rnn/brgemm_cell_common_utils.hpp" |
24 | #include "cpu/x64/rnn/jit_diff_weights_peephole.hpp" |
25 | #include "cpu/x64/rnn/jit_gates_reduction.hpp" |
26 | #include "cpu/x64/rnn/rnn_brgemm_utils.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
/*
 * Calculations:
 * scratch * w_iter = diff_src_iter
 * A (mb x rnn.n_gates * rnn.dhc) * B (rnn.n_gates * rnn.dhc, rnn.sic) =
 * C (mb x rnn.sic)
 *
 * scratch * w_layer = diff_src_layer
 * A (mb x rnn.n_gates * rnn.dhc) * B (rnn.n_gates * rnn.dhc, rnn.slc) =
 * C (mb x rnn.slc)
 *
 * Data formats:
 * scratch = igo (mb, n_gates, rnn.dhc)
 * w_iter = gIo32i(f32)/gIO32i2o(bf16) (n_gates, rnn.sic, rnn.dhc)
 * w_layer = gIo32i(f32)/gIO32i2o(bf16) (n_gates, rnn.slc, rnn.dhc)
 * diff_src_layer = io (mb, rnn.slc)
 * diff_src_iter = io (mb, rnn.sic)
 */
50 | template <typename weights_t, typename scratch_t, typename gemm_acc_t> |
51 | class brgemm_diff_src_layer_iter_t { |
52 | public: |
53 | using ref_rnn_brgemm_t |
54 | = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::backward>; |
55 | |
56 | brgemm_diff_src_layer_iter_t(const ref_rnn_brgemm_t &rnn_brgemm_, |
57 | const rnn_utils::rnn_conf_t &rnn, |
58 | rnn_utils::cell_position_t cell_position, scratch_t *scratch_gates, |
59 | weights_t *w_iter, weights_t *w_layer, gemm_acc_t *diff_src_iter, |
60 | gemm_acc_t *diff_src_layer, gemm_acc_t *amx_scratchpad, |
61 | x64::brgemm_batch_element_t *addr_batch_global); |
62 | |
63 | void execute() const; |
64 | |
65 | private: |
66 | struct thread_exec_ctx_t { |
67 | x64::brgemm_batch_element_t *addr_batch; |
68 | gemm_acc_t *amx_buffer; |
69 | amx_tile_configuration_loader_t tile_configure_if_needed; |
70 | }; |
71 | |
72 | void kernel_amx(const int ithr, const int nthr) const; |
73 | void kernel_amx_compute_iter(const int m_block_id, const int n_block_id, |
74 | const int gates_start, const int gates_end, |
75 | thread_exec_ctx_t &ctx) const; |
76 | void kernel(const int ithr, const int nthr) const; |
77 | |
78 | const ref_rnn_brgemm_t &rnn_brgemm_; |
79 | const rnn_utils::rnn_conf_t &rnn_; |
80 | const scratch_t *const A_; |
81 | const weights_t *const B_wei_iter_; |
82 | const weights_t *const B_wei_layer_; |
83 | gemm_acc_t *const C_diff_iter_; |
84 | gemm_acc_t *const C_diff_layer_; |
85 | const dim_t k_blocks_n_gates_; |
86 | const dim_t k_blocks_; |
87 | const dim_t k_tail_; |
88 | const dim_t k_block_; |
89 | const dim_t A_k_tail_offset_; |
90 | const dim_t B_k_tail_offset_; |
91 | const dim_t B_nb_offset_; |
92 | const dim_t B_kb_offset_; |
93 | const dim_t B_gb_iter_offset_; |
94 | const dim_t B_gb_layer_offset_; |
95 | const dim_t LDA_; |
96 | const dim_t LDC_; |
97 | const dim_t max_nthr_; |
98 | const dim_t n_blocking_; |
99 | const dim_t m_blocking_; |
100 | const int work_amount_; |
101 | const dim_t max_n_layer_blocks_; |
102 | const dim_t max_n_iter_blocks_; |
103 | const bool gemm_layer_needed_; |
104 | const brgemm_kernel_t *const kernel_iter_full_blocks_b0_; |
105 | const brgemm_kernel_t *const kernel_iter_full_blocks_b1_; |
106 | const brgemm_kernel_t *const kernel_iter_n_tail_b0_; |
107 | const brgemm_kernel_t *const kernel_iter_n_tail_b1_; |
108 | const brgemm_kernel_t *const kernel_iter_k_tail_; |
109 | const brgemm_kernel_t *const kernel_iter_nk_tail_; |
110 | const brgemm_kernel_t *const kernel_layer_full_blocks_b0_; |
111 | const brgemm_kernel_t *const kernel_layer_full_blocks_b1_; |
112 | const brgemm_kernel_t *const kernel_layer_n_tail_b0_; |
113 | const brgemm_kernel_t *const kernel_layer_n_tail_b1_; |
114 | const brgemm_kernel_t *const kernel_layer_k_tail_; |
115 | const brgemm_kernel_t *const kernel_layer_nk_tail_; |
116 | gemm_acc_t *const amx_scratchpad_; |
117 | brgemm_batch_element_t *const addr_batch_global_; |
118 | }; |
119 | |
/*
 * Calculations:
 * src_layer^T * scratch = diff_weights_layer
 * A before transpose (rnn.mb, rnn.slc) - layout in memory
 * A (rnn.slc, rnn.mb) * B (rnn.mb, rnn.n_gates * rnn.dhc) =
 * C (rnn.slc, rnn.n_gates * rnn.dhc)
 * src_iter^T * scratch = diff_weights_iter
 * A (rnn.sic, rnn.mb) * B (rnn.mb, rnn.n_gates * rnn.dhc) =
 * C (rnn.sic, rnn.n_gates * rnn.dhc)
 *
 * Performing gates reduction
 * diff_bias = scratch_blocked reduction over mb
 *
 * Data formats:
 * src_iter = io (mb, rnn.sic) -> transposed oi (rnn.sic, mb)
 * src_layer = io (mb, rnn.slc) -> transposed oi (rnn.slc, mb)
 *
 * scratch = igo (mb, n_gates, rnn.dhc)
 * Note:
 * For calculation purposes scratch is transformed locally to blocked
 * (in case of bf16 vnni friendly) format Oi32o(f32)/OI32o2i(bf16)
 *
 * diff_weights_iter = igo (rnn.sic, rnn.n_gates, rnn.dhc)
 * diff_weights_layer = igo (rnn.slc, rnn.n_gates, rnn.dhc)
 * diff_bias = go(n_gates, rnn.dhc)
 */
146 | template <typename src_layer_t, typename src_iter_t, typename scratch_t, |
147 | typename gemm_acc_t> |
148 | class brgemm_diff_weights_layer_iter_t { |
149 | |
150 | public: |
151 | using ref_rnn_brgemm_t |
152 | = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::backward>; |
153 | |
154 | brgemm_diff_weights_layer_iter_t(const ref_rnn_brgemm_t &rnn_brgemm_, |
155 | const rnn_utils::rnn_conf_t &rnn, |
156 | rnn_utils::cell_position_t cell_position, |
157 | const src_layer_t *src_iter, |
158 | scratch_t *const A_iter_transposed_scratch, |
159 | const src_iter_t *src_layer, |
160 | scratch_t *const A_layer_transposed_scratch, |
161 | const scratch_t *scratch, scratch_t *scratch_gates_blocked, |
162 | gemm_acc_t *diff_weights_iter, gemm_acc_t *diff_weights_layer, |
163 | gemm_acc_t *diff_bias, gemm_acc_t *amx_scratchpad, |
164 | x64::brgemm_batch_element_t *addr_batch_global); |
165 | |
166 | void execute() const; |
167 | |
168 | private: |
169 | const ref_rnn_brgemm_t &rnn_brgemm_; |
170 | const rnn_utils::rnn_conf_t &rnn_; |
171 | const bool is_amx_; |
172 | const src_iter_t *const A_iter_; |
173 | scratch_t *const A_iter_transposed_scratch_; |
174 | const src_layer_t *const A_layer_; |
175 | scratch_t *const A_layer_transposed_scratch_; |
176 | const scratch_t *const B_; |
177 | scratch_t *const B_blocked_scratch_; |
178 | gemm_acc_t *const C_iter_; |
179 | gemm_acc_t *const C_layer_; |
180 | gemm_acc_t *const diff_bias_; |
181 | const dim_t LDA_iter_; |
182 | const dim_t LDA_layer_; |
183 | const dim_t LDC_iter_; |
184 | const dim_t LDC_layer_; |
185 | const dim_t max_nthr_; |
186 | const dim_t n_blocking_; |
187 | const dim_t m_blocking_; |
188 | const dim_t k_blocks_; |
189 | const dim_t k_tail_; |
190 | const dim_t k_block_; |
191 | const dim_t m_iter_block_; |
192 | const dim_t m_layer_block_; |
193 | const dim_t A_k_iter_tail_offset_; |
194 | const dim_t A_k_layer_tail_offset_; |
195 | const dim_t B_kb_offset_; |
196 | const dim_t B_k_tail_offset_; |
197 | const dim_t B_k_tail_offset_blocked_; |
198 | const int work_amount_; |
199 | const brgemm_kernel_t *const kernel_iter_full_blocks_; |
200 | const brgemm_kernel_t *const kernel_iter_n_tail_; |
201 | const brgemm_kernel_t *const kernel_iter_k_tail_; |
202 | const brgemm_kernel_t *const kernel_iter_nk_tail_; |
203 | const brgemm_kernel_t *const kernel_layer_full_blocks_; |
204 | const brgemm_kernel_t *const kernel_layer_n_tail_; |
205 | const brgemm_kernel_t *const kernel_layer_k_tail_; |
206 | const brgemm_kernel_t *const kernel_layer_nk_tail_; |
207 | const rnn_utils::cell_position_t cell_position_; |
208 | const jit_gates_reduction_t *const kernel_gates_reduction_; |
209 | const jit_gates_reduction_t *const kernel_gates_reduction_tail_; |
210 | const jit_brgemm_transpose_single_row_t *const kernel_transpose_iter_; |
211 | const jit_brgemm_transpose_single_row_t *const kernel_transpose_layer_; |
212 | gemm_acc_t *const amx_scratchpad_; |
213 | brgemm_batch_element_t *const addr_batch_global_; |
214 | |
215 | void kernel_amx(const int ithr, const int nthr) const; |
216 | void kernel(const int ithr, const int nthr) const; |
217 | void reorder_scratch_gates( |
218 | const scratch_t *src, scratch_t *dst, const bool do_n_tail) const; |
219 | }; |
220 | |
221 | template <typename scratch_t> |
222 | class brgemm_diff_wei_peep_t { |
223 | public: |
224 | using ref_rnn_brgemm_t |
225 | = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::backward>; |
226 | |
227 | brgemm_diff_wei_peep_t(const ref_rnn_brgemm_t &rnn_brgemm, |
228 | const rnn_utils::rnn_conf_t &rnn, |
229 | rnn_utils::cell_position_t cell_position, |
230 | const scratch_t *scratch_gates, const void *src_iter_c, |
231 | const void *dst_iter_c, float *diff_weights_peephole); |
232 | |
233 | void execute() const; |
234 | |
235 | private: |
236 | void kernel(const int ithr, const int nthr) const; |
237 | |
238 | const int n_gates_ = 3; |
239 | const rnn_utils::rnn_conf_t &rnn_; |
240 | const scratch_t *scratch_gates_; |
241 | const void *src_iter_c_; |
242 | const void *dst_iter_c_; |
243 | float *diff_weights_peephole_; |
244 | const int work_amount_; |
245 | const int dst_iter_c_ld_; |
246 | const int src_iter_c_ld_; |
247 | const jit_diff_weights_peephole_t *const kernel_; |
248 | const jit_diff_weights_peephole_t *const kernel_tail_; |
249 | }; |
250 | |
251 | } // namespace x64 |
252 | } // namespace cpu |
253 | } // namespace impl |
254 | } // namespace dnnl |
255 | |
256 | #endif |
257 | |