/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_BRGEMM_CELL_COMMON_BWD_HPP
#define CPU_X64_RNN_BRGEMM_CELL_COMMON_BWD_HPP

#include "common/bfloat16.hpp"
#include "cpu/rnn/rnn_utils.hpp"
#include "cpu/x64/rnn/brgemm_cell_common_reorders.hpp"
#include "cpu/x64/rnn/brgemm_cell_common_utils.hpp"
#include "cpu/x64/rnn/jit_diff_weights_peephole.hpp"
#include "cpu/x64/rnn/jit_gates_reduction.hpp"
#include "cpu/x64/rnn/rnn_brgemm_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

/*
 * Calculations:
 * scratch * w_iter = diff_src_iter
 * A (mb, rnn.n_gates * rnn.dhc) * B (rnn.n_gates * rnn.dhc, rnn.sic) =
 * C (mb, rnn.sic)
 *
 * scratch * w_layer = diff_src_layer
 * A (mb, rnn.n_gates * rnn.dhc) * B (rnn.n_gates * rnn.dhc, rnn.slc) =
 * C (mb, rnn.slc)
 *
 * Data formats:
 * scratch = igo (mb, n_gates, rnn.dhc)
 * w_iter = gIo32i(f32)/gIO32i2o(bf16) (n_gates, rnn.sic, rnn.dhc)
 * w_layer = gIo32i(f32)/gIO32i2o(bf16) (n_gates, rnn.slc, rnn.dhc)
 * diff_src_layer = io (mb, rnn.slc)
 * diff_src_iter = io (mb, rnn.sic)
 *
 * A naive reference sketch of the diff_src_iter GEMM follows this class.
 */
template <typename weights_t, typename scratch_t, typename gemm_acc_t>
class brgemm_diff_src_layer_iter_t {
public:
    using ref_rnn_brgemm_t
            = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::backward>;

    brgemm_diff_src_layer_iter_t(const ref_rnn_brgemm_t &rnn_brgemm_,
            const rnn_utils::rnn_conf_t &rnn,
            rnn_utils::cell_position_t cell_position, scratch_t *scratch_gates,
            weights_t *w_iter, weights_t *w_layer, gemm_acc_t *diff_src_iter,
            gemm_acc_t *diff_src_layer, gemm_acc_t *amx_scratchpad,
            x64::brgemm_batch_element_t *addr_batch_global);

    void execute() const;

private:
    struct thread_exec_ctx_t {
        x64::brgemm_batch_element_t *addr_batch;
        gemm_acc_t *amx_buffer;
        amx_tile_configuration_loader_t tile_configure_if_needed;
    };

    void kernel_amx(const int ithr, const int nthr) const;
    void kernel_amx_compute_iter(const int m_block_id, const int n_block_id,
            const int gates_start, const int gates_end,
            thread_exec_ctx_t &ctx) const;
    void kernel(const int ithr, const int nthr) const;

    const ref_rnn_brgemm_t &rnn_brgemm_;
    const rnn_utils::rnn_conf_t &rnn_;
    const scratch_t *const A_;
    const weights_t *const B_wei_iter_;
    const weights_t *const B_wei_layer_;
    gemm_acc_t *const C_diff_iter_;
    gemm_acc_t *const C_diff_layer_;
    const dim_t k_blocks_n_gates_;
    const dim_t k_blocks_;
    const dim_t k_tail_;
    const dim_t k_block_;
    const dim_t A_k_tail_offset_;
    const dim_t B_k_tail_offset_;
    const dim_t B_nb_offset_;
    const dim_t B_kb_offset_;
    const dim_t B_gb_iter_offset_;
    const dim_t B_gb_layer_offset_;
    const dim_t LDA_;
    const dim_t LDC_;
    const dim_t max_nthr_;
    const dim_t n_blocking_;
    const dim_t m_blocking_;
    const int work_amount_;
    const dim_t max_n_layer_blocks_;
    const dim_t max_n_iter_blocks_;
    const bool gemm_layer_needed_;
    const brgemm_kernel_t *const kernel_iter_full_blocks_b0_;
    const brgemm_kernel_t *const kernel_iter_full_blocks_b1_;
    const brgemm_kernel_t *const kernel_iter_n_tail_b0_;
    const brgemm_kernel_t *const kernel_iter_n_tail_b1_;
    const brgemm_kernel_t *const kernel_iter_k_tail_;
    const brgemm_kernel_t *const kernel_iter_nk_tail_;
    const brgemm_kernel_t *const kernel_layer_full_blocks_b0_;
    const brgemm_kernel_t *const kernel_layer_full_blocks_b1_;
    const brgemm_kernel_t *const kernel_layer_n_tail_b0_;
    const brgemm_kernel_t *const kernel_layer_n_tail_b1_;
    const brgemm_kernel_t *const kernel_layer_k_tail_;
    const brgemm_kernel_t *const kernel_layer_nk_tail_;
    gemm_acc_t *const amx_scratchpad_;
    brgemm_batch_element_t *const addr_batch_global_;
};

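/*
 * The following is a minimal reference sketch, not used by the
 * implementation above: the diff_src_iter GEMM from the comment preceding
 * brgemm_diff_src_layer_iter_t, written as naive loops. It assumes plain
 * row-major f32 buffers; the actual kernels consume weights in the blocked
 * gIo32i/gIO32i2o layouts and tile these loops into brgemm batches.
 */
inline void ref_diff_src_iter_sketch(
        const float *scratch, // A: (mb, n_gates * dhc), row-major
        const float *w_iter, // B: (n_gates * dhc, sic), row-major
        float *diff_src_iter, // C: (mb, sic)
        dim_t mb, dim_t n_gates_dhc, dim_t sic) {
    for (dim_t m = 0; m < mb; ++m)
        for (dim_t n = 0; n < sic; ++n) {
            float acc = 0.0f;
            for (dim_t k = 0; k < n_gates_dhc; ++k)
                acc += scratch[m * n_gates_dhc + k] * w_iter[k * sic + n];
            diff_src_iter[m * sic + n] = acc;
        }
}
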
/*
 * Calculations:
 * src_layer^T * scratch = diff_weights_layer
 * A before transpose (rnn.mb, rnn.slc) - layout in memory
 * A (rnn.slc, rnn.mb) * B (rnn.mb, rnn.n_gates * rnn.dhc) =
 * C (rnn.slc, rnn.n_gates * rnn.dhc)
 * src_iter^T * scratch = diff_weights_iter
 * A (rnn.sic, rnn.mb) * B (rnn.mb, rnn.n_gates * rnn.dhc) =
 * C (rnn.sic, rnn.n_gates * rnn.dhc)
 *
 * Gates reduction:
 * diff_bias = scratch_blocked reduced over mb
 *
 * Data formats:
 * src_iter = io (mb, rnn.sic) -> transposed oi (rnn.sic, mb)
 * src_layer = io (mb, rnn.slc) -> transposed oi (rnn.slc, mb)
 *
 * scratch = igo (mb, n_gates, rnn.dhc)
 * Note:
 * For calculation purposes scratch is transformed locally to a blocked
 * (VNNI-friendly in case of bf16) format: Oi32o (f32) / OI32o2i (bf16)
 *
 * diff_weights_iter = igo (rnn.sic, rnn.n_gates, rnn.dhc)
 * diff_weights_layer = igo (rnn.slc, rnn.n_gates, rnn.dhc)
 * diff_bias = go (n_gates, rnn.dhc)
 *
 * A naive reference sketch of the diff_weights_layer GEMM and the gates
 * reduction follows this class.
 */
template <typename src_layer_t, typename src_iter_t, typename scratch_t,
        typename gemm_acc_t>
class brgemm_diff_weights_layer_iter_t {

public:
    using ref_rnn_brgemm_t
            = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::backward>;

    brgemm_diff_weights_layer_iter_t(const ref_rnn_brgemm_t &rnn_brgemm_,
            const rnn_utils::rnn_conf_t &rnn,
            rnn_utils::cell_position_t cell_position,
            const src_iter_t *src_iter,
            scratch_t *const A_iter_transposed_scratch,
            const src_layer_t *src_layer,
            scratch_t *const A_layer_transposed_scratch,
            const scratch_t *scratch, scratch_t *scratch_gates_blocked,
            gemm_acc_t *diff_weights_iter, gemm_acc_t *diff_weights_layer,
            gemm_acc_t *diff_bias, gemm_acc_t *amx_scratchpad,
            x64::brgemm_batch_element_t *addr_batch_global);

    void execute() const;

private:
    const ref_rnn_brgemm_t &rnn_brgemm_;
    const rnn_utils::rnn_conf_t &rnn_;
    const bool is_amx_;
    const src_iter_t *const A_iter_;
    scratch_t *const A_iter_transposed_scratch_;
    const src_layer_t *const A_layer_;
    scratch_t *const A_layer_transposed_scratch_;
    const scratch_t *const B_;
    scratch_t *const B_blocked_scratch_;
    gemm_acc_t *const C_iter_;
    gemm_acc_t *const C_layer_;
    gemm_acc_t *const diff_bias_;
    const dim_t LDA_iter_;
    const dim_t LDA_layer_;
    const dim_t LDC_iter_;
    const dim_t LDC_layer_;
    const dim_t max_nthr_;
    const dim_t n_blocking_;
    const dim_t m_blocking_;
    const dim_t k_blocks_;
    const dim_t k_tail_;
    const dim_t k_block_;
    const dim_t m_iter_block_;
    const dim_t m_layer_block_;
    const dim_t A_k_iter_tail_offset_;
    const dim_t A_k_layer_tail_offset_;
    const dim_t B_kb_offset_;
    const dim_t B_k_tail_offset_;
    const dim_t B_k_tail_offset_blocked_;
    const int work_amount_;
    const brgemm_kernel_t *const kernel_iter_full_blocks_;
    const brgemm_kernel_t *const kernel_iter_n_tail_;
    const brgemm_kernel_t *const kernel_iter_k_tail_;
    const brgemm_kernel_t *const kernel_iter_nk_tail_;
    const brgemm_kernel_t *const kernel_layer_full_blocks_;
    const brgemm_kernel_t *const kernel_layer_n_tail_;
    const brgemm_kernel_t *const kernel_layer_k_tail_;
    const brgemm_kernel_t *const kernel_layer_nk_tail_;
    const rnn_utils::cell_position_t cell_position_;
    const jit_gates_reduction_t *const kernel_gates_reduction_;
    const jit_gates_reduction_t *const kernel_gates_reduction_tail_;
    const jit_brgemm_transpose_single_row_t *const kernel_transpose_iter_;
    const jit_brgemm_transpose_single_row_t *const kernel_transpose_layer_;
    gemm_acc_t *const amx_scratchpad_;
    brgemm_batch_element_t *const addr_batch_global_;

    void kernel_amx(const int ithr, const int nthr) const;
    void kernel(const int ithr, const int nthr) const;
    void reorder_scratch_gates(
            const scratch_t *src, scratch_t *dst, const bool do_n_tail) const;
};

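/*
 * The following is a minimal reference sketch, not used by the
 * implementation above: the diff_weights_layer GEMM and the diff_bias
 * gates reduction from the comment preceding
 * brgemm_diff_weights_layer_iter_t, written as naive loops. It assumes
 * plain row-major f32 buffers; the actual kernels work on the transposed
 * src and blocked Oi32o/OI32o2i scratch layouts and run the reduction in
 * a jit kernel.
 */
inline void ref_diff_weights_layer_sketch(
        const float *src_layer, // (mb, slc), row-major, before transpose
        const float *scratch, // (mb, n_gates * dhc), row-major
        float *diff_weights_layer, // (slc, n_gates * dhc)
        float *diff_bias, // (n_gates * dhc)
        dim_t mb, dim_t slc, dim_t n_gates_dhc) {
    // C = A^T * B; the transpose is implicit: src_layer is read column-wise.
    for (dim_t o = 0; o < slc; ++o)
        for (dim_t n = 0; n < n_gates_dhc; ++n) {
            float acc = 0.0f;
            for (dim_t m = 0; m < mb; ++m)
                acc += src_layer[m * slc + o] * scratch[m * n_gates_dhc + n];
            diff_weights_layer[o * n_gates_dhc + n] = acc;
        }
    // diff_bias: reduce the scratch gates over the minibatch dimension.
    for (dim_t n = 0; n < n_gates_dhc; ++n) {
        float acc = 0.0f;
        for (dim_t m = 0; m < mb; ++m)
            acc += scratch[m * n_gates_dhc + n];
        diff_bias[n] = acc;
    }
}
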
template <typename scratch_t>
class brgemm_diff_wei_peep_t {
public:
    using ref_rnn_brgemm_t
            = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::backward>;

    brgemm_diff_wei_peep_t(const ref_rnn_brgemm_t &rnn_brgemm,
            const rnn_utils::rnn_conf_t &rnn,
            rnn_utils::cell_position_t cell_position,
            const scratch_t *scratch_gates, const void *src_iter_c,
            const void *dst_iter_c, float *diff_weights_peephole);

    void execute() const;

private:
    void kernel(const int ithr, const int nthr) const;

    const int n_gates_ = 3;
    const rnn_utils::rnn_conf_t &rnn_;
    const scratch_t *scratch_gates_;
    const void *src_iter_c_;
    const void *dst_iter_c_;
    float *diff_weights_peephole_;
    const int work_amount_;
    const int dst_iter_c_ld_;
    const int src_iter_c_ld_;
    const jit_diff_weights_peephole_t *const kernel_;
    const jit_diff_weights_peephole_t *const kernel_tail_;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif