1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_BRGEMM_CONV_BWD_W_HPP |
18 | #define CPU_X64_JIT_BRGEMM_CONV_BWD_W_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/cpu_convolution_pd.hpp" |
27 | #include "cpu/platform.hpp" |
28 | |
29 | #include "cpu/x64/amx_tile_configure.hpp" |
30 | #include "cpu/x64/brgemm/brgemm.hpp" |
31 | #include "cpu/x64/cpu_barrier.hpp" |
32 | #include "cpu/x64/cpu_reducer.hpp" |
33 | #include "cpu/x64/jit_brgemm_conv_comp_pad_kernel.hpp" |
34 | #include "cpu/x64/jit_brgemm_conv_trans_kernel.hpp" |
35 | #include "cpu/x64/jit_brgemm_conv_utils.hpp" |
36 | #include "cpu/x64/jit_brgemm_post_ops.hpp" |
37 | |
38 | #include "cpu/x64/jit_avx512_core_amx_conv_kernel.hpp" |
39 | #include "cpu/x64/jit_transpose_utils.hpp" |
40 | |
41 | namespace dnnl { |
42 | namespace impl { |
43 | namespace cpu { |
44 | namespace x64 { |
45 | |
46 | struct brgemm_convolution_bwd_weights_t : public primitive_t { |
47 | struct pd_t : public cpu_convolution_bwd_weights_pd_t { |
48 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
49 | const convolution_fwd_pd_t *hint_fwd_pd) |
50 | : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) |
51 | , jcp_() |
52 | , brgs_sz_(0) |
53 | , bs_c(0) {} |
54 | |
55 | ~pd_t() = default; |
56 | |
57 | // ------- DECLARE_COMMON_PD_t ----- |
58 | pd_t *clone() const override { |
59 | auto new_pd = utils::make_unique<pd_t>(*this); |
60 | if (!new_pd->is_initialized()) return nullptr; |
61 | new_pd->brgs_.resize(brgs_sz_); |
62 | for (int i = 0; i < brgs_sz_; i++) { |
63 | new_pd->brgs_[i] = brgs_[i]; |
64 | new_pd->bd_masks[i] = bd_masks[i]; |
65 | } |
66 | return new_pd.release(); |
67 | } |
68 | |
69 | status_t create_primitive( |
70 | std::pair<std::shared_ptr<primitive_t>, bool> &primitive, |
71 | engine_t *engine, |
72 | const cache_blob_t &cache_blob) const override { |
73 | return primitive_t::create_primitive_common< |
74 | brgemm_convolution_bwd_weights_t, pd_t>( |
75 | primitive, this, engine, false, cache_blob); |
76 | } |
77 | |
78 | const char *name() const override { |
79 | return JIT_IMPL_NAME_HELPER("brgconv_bwd_w:" , jcp_.isa, "" ); |
80 | } |
81 | // --------------------------------- |
82 | |
83 | status_t init(engine_t *engine); |
84 | |
85 | jit_brgemm_conv_conf_t jcp_; |
86 | jit_conv_conf_t jit_jcp_; |
87 | void copy2jit_jcp(); |
88 | |
89 | int brgs_sz_; |
90 | std::vector<std::shared_ptr<brgemm_t>> brgs_; |
91 | std::vector<std::shared_ptr<std::vector<char>>> bd_masks; |
92 | |
93 | int bs_c; |
94 | std::vector<int> batchsizes; |
95 | bool are_empty_bs {false}; |
96 | |
97 | int get_brg_idx(int bs, int m, bool do_initialization, bool is_N_tail, |
98 | bool is_K_tail) const { |
99 | auto my_bs = jcp_.var_bs ? 1 : bs; |
100 | auto bs_idx = jcp_.use_uker ? batchsizes[my_bs] : 0; |
101 | assert(bs_idx >= 0); |
102 | return (((m * bs_c + bs_idx) * 2 |
103 | + static_cast<int>(do_initialization)) |
104 | * 2 |
105 | + static_cast<int>(is_N_tail)) |
106 | * 2 |
107 | + static_cast<int>(is_K_tail); |
108 | } |
109 | inline int filter_w_to_src(int kw) const { |
110 | return kw * (jcp_.dilate_w + 1); |
111 | } |
112 | inline int filter_h_to_src(int kh) const { |
113 | return kh * (jcp_.dilate_h + 1) - jcp_.t_pad; |
114 | } |
115 | inline int filter_d_to_src(int kd) const { |
116 | return kd * (jcp_.dilate_d + 1) - jcp_.f_pad; |
117 | } |
118 | inline int get_start_ih(int kh, int oh_s) const { |
119 | const auto real_ih = filter_h_to_src(kh) + oh_s * jcp_.stride_h; |
120 | return utils::saturate(0, jcp_.ih, |
121 | real_ih |
122 | + utils::rnd_up( |
123 | nstl::max(0, -real_ih), jcp_.stride_h)); |
124 | } |
125 | inline int get_finish_ih(int kh, int oh_e) const { |
126 | return utils::saturate(0, jcp_.ih, |
127 | filter_h_to_src(kh) + (oh_e - 1) * jcp_.stride_h + 1); |
128 | } |
129 | inline int get_start_id(int kd, int od_s) const { |
130 | const auto real_id = filter_d_to_src(kd) + od_s * jcp_.stride_d; |
131 | return utils::saturate(0, jcp_.id, |
132 | real_id |
133 | + utils::rnd_up( |
134 | nstl::max(0, -real_id), jcp_.stride_d)); |
135 | } |
136 | inline int get_finish_id(int kd, int od_e) const { |
137 | return utils::saturate(0, jcp_.id, |
138 | filter_d_to_src(kd) + (od_e - 1) * jcp_.stride_d + 1); |
139 | } |
140 | |
141 | inline int get_finish_oh(int oh_s, int start, int end) const { |
142 | int work_rem = end - start; |
143 | return (oh_s + work_rem > jcp_.oh ? jcp_.oh : oh_s + work_rem); |
144 | } |
145 | inline int get_finish_od(int od_s, int start, int end) const { |
146 | int work_rem = end - start; |
147 | return (od_s + work_rem > jcp_.od ? jcp_.od : od_s + work_rem); |
148 | } |
149 | }; |
150 | |
151 | brgemm_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {} |
152 | |
153 | typedef typename prec_traits<data_type::bf16>::type src_data_t; |
154 | typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t; |
155 | |
156 | status_t init(engine_t *engine) override; |
157 | |
158 | status_t execute(const exec_ctx_t &ctx) const override { |
159 | execute_backward_weights(ctx); |
160 | return status::success; |
161 | } |
162 | |
163 | private: |
164 | struct thread_info_t; |
165 | |
166 | void execute_backward_weights(const exec_ctx_t &ctx) const; |
167 | void prepare_scratchpad_data(const exec_ctx_t &ctx) const; |
168 | void compute_diff_weights_2d(thread_info_t *) const; |
169 | void compute_diff_weights_3d(thread_info_t *) const; |
170 | void reduce_and_convert_diff_weights_and_bias(thread_info_t *) const; |
171 | void store_in_vnni_format(thread_info_t *) const; |
172 | |
173 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
174 | |
175 | std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_; |
176 | |
177 | std::unique_ptr<jit_diff_wei_trans_to_vnni_t> diff_wei_trans_kernel_; |
178 | std::unique_ptr<jit_trans_src_t> trans_kernel_; |
179 | std::unique_ptr<jit_trans_dst_t> trans_dst_kernel_; |
180 | std::unique_ptr<jit_avx512_core_amx_bwd_bias_kernel_t> diff_bias_kernel_; |
181 | |
182 | std::vector<std::unique_ptr<brgemm_kernel_t>> brg_kernels_; |
183 | struct S_t { |
184 | char a[AMX_PALETTE_SIZE]; |
185 | }; |
186 | std::vector<S_t> brg_kernel_palettes_; |
187 | |
188 | status_t add_brg_kernel(int bs, int M, int i_N, int i_K, int i_init); |
189 | void call_brgemm_kernel( |
190 | thread_info_t &btc, int brg_idx, int batch_size, void *ptr_C) const; |
191 | inline dim_t wei_offset_int( |
192 | int g, int oc_b, int ic_b, int kd, int kh, int kw) const { |
193 | const auto &jcp = pd()->jcp_; |
194 | const dim_t = jcp.ic_block * jcp.oc_block; |
195 | dim_t |
196 | = ((kd * jcp.kh + kh) * jcp.kw + kw) * const_extra_offset; |
197 | return (dim_t)((g * jcp.nb_oc + oc_b) * jcp.nb_ic + ic_b) * jcp.kd |
198 | * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block |
199 | + extra_offset; |
200 | } |
201 | |
202 | inline dim_t wei_offset_int(int g, int oc_b, int ic_b, int kX) const { |
203 | const auto &jcp = pd()->jcp_; |
204 | const dim_t = jcp.kw * jcp.ic_block * jcp.oc_block; |
205 | dim_t = (jcp.ndims == 5) ? kX * jcp.kh * const_extra_offset |
206 | : kX * const_extra_offset; |
207 | return (dim_t)((g * jcp.nb_oc + oc_b) * jcp.nb_ic + ic_b) * jcp.kd |
208 | * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block |
209 | + extra_offset; |
210 | } |
211 | |
212 | inline dim_t wei_offset_ext(int g, int oc_b, int ic_b, int kX) const { |
213 | const auto &jcp = pd()->jcp_; |
214 | const int nb_ic = utils::div_up(jcp.ic, 2 * jcp.ic_block); |
215 | const dim_t |
216 | = jcp.kw * jcp.ic_block * jcp.oc_block * 2; |
217 | dim_t = (jcp.ndims == 5) ? kX * jcp.kh * const_extra_offset |
218 | : kX * const_extra_offset; |
219 | return (dim_t)((g * jcp.nb_oc + oc_b) * nb_ic + ic_b) * jcp.kd * jcp.kh |
220 | * jcp.kw * jcp.ic_block * jcp.oc_block * 2 |
221 | + extra_offset; |
222 | } |
223 | |
224 | inline int get_end(int start, int step, int limit) const { |
225 | return nstl::min(start + step, limit); |
226 | } |
227 | }; |
228 | |
229 | } // namespace x64 |
230 | } // namespace cpu |
231 | } // namespace impl |
232 | } // namespace dnnl |
233 | |
234 | #endif |
235 | |
236 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
237 | |