1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_BRGEMM_CONV_BWD_W_HPP
18#define CPU_X64_JIT_BRGEMM_CONV_BWD_W_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/memory_tracking.hpp"
23#include "common/primitive.hpp"
24#include "common/utils.hpp"
25
26#include "cpu/cpu_convolution_pd.hpp"
27#include "cpu/platform.hpp"
28
29#include "cpu/x64/amx_tile_configure.hpp"
30#include "cpu/x64/brgemm/brgemm.hpp"
31#include "cpu/x64/cpu_barrier.hpp"
32#include "cpu/x64/cpu_reducer.hpp"
33#include "cpu/x64/jit_brgemm_conv_comp_pad_kernel.hpp"
34#include "cpu/x64/jit_brgemm_conv_trans_kernel.hpp"
35#include "cpu/x64/jit_brgemm_conv_utils.hpp"
36#include "cpu/x64/jit_brgemm_post_ops.hpp"
37
38#include "cpu/x64/jit_avx512_core_amx_conv_kernel.hpp"
39#include "cpu/x64/jit_transpose_utils.hpp"
40
41namespace dnnl {
42namespace impl {
43namespace cpu {
44namespace x64 {
45
46struct brgemm_convolution_bwd_weights_t : public primitive_t {
47 struct pd_t : public cpu_convolution_bwd_weights_pd_t {
48 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
49 const convolution_fwd_pd_t *hint_fwd_pd)
50 : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
51 , jcp_()
52 , brgs_sz_(0)
53 , bs_c(0) {}
54
55 ~pd_t() = default;
56
57 // ------- DECLARE_COMMON_PD_t -----
58 pd_t *clone() const override {
59 auto new_pd = utils::make_unique<pd_t>(*this);
60 if (!new_pd->is_initialized()) return nullptr;
61 new_pd->brgs_.resize(brgs_sz_);
62 for (int i = 0; i < brgs_sz_; i++) {
63 new_pd->brgs_[i] = brgs_[i];
64 new_pd->bd_masks[i] = bd_masks[i];
65 }
66 return new_pd.release();
67 }
68
69 status_t create_primitive(
70 std::pair<std::shared_ptr<primitive_t>, bool> &primitive,
71 engine_t *engine,
72 const cache_blob_t &cache_blob) const override {
73 return primitive_t::create_primitive_common<
74 brgemm_convolution_bwd_weights_t, pd_t>(
75 primitive, this, engine, false, cache_blob);
76 }
77
78 const char *name() const override {
79 return JIT_IMPL_NAME_HELPER("brgconv_bwd_w:", jcp_.isa, "");
80 }
81 // ---------------------------------
82
83 status_t init(engine_t *engine);
84
85 jit_brgemm_conv_conf_t jcp_;
86 jit_conv_conf_t jit_jcp_;
87 void copy2jit_jcp();
88
89 int brgs_sz_;
90 std::vector<std::shared_ptr<brgemm_t>> brgs_;
91 std::vector<std::shared_ptr<std::vector<char>>> bd_masks;
92
93 int bs_c;
94 std::vector<int> batchsizes;
95 bool are_empty_bs {false};
96
97 int get_brg_idx(int bs, int m, bool do_initialization, bool is_N_tail,
98 bool is_K_tail) const {
99 auto my_bs = jcp_.var_bs ? 1 : bs;
100 auto bs_idx = jcp_.use_uker ? batchsizes[my_bs] : 0;
101 assert(bs_idx >= 0);
102 return (((m * bs_c + bs_idx) * 2
103 + static_cast<int>(do_initialization))
104 * 2
105 + static_cast<int>(is_N_tail))
106 * 2
107 + static_cast<int>(is_K_tail);
108 }
109 inline int filter_w_to_src(int kw) const {
110 return kw * (jcp_.dilate_w + 1);
111 }
112 inline int filter_h_to_src(int kh) const {
113 return kh * (jcp_.dilate_h + 1) - jcp_.t_pad;
114 }
115 inline int filter_d_to_src(int kd) const {
116 return kd * (jcp_.dilate_d + 1) - jcp_.f_pad;
117 }
118 inline int get_start_ih(int kh, int oh_s) const {
119 const auto real_ih = filter_h_to_src(kh) + oh_s * jcp_.stride_h;
120 return utils::saturate(0, jcp_.ih,
121 real_ih
122 + utils::rnd_up(
123 nstl::max(0, -real_ih), jcp_.stride_h));
124 }
125 inline int get_finish_ih(int kh, int oh_e) const {
126 return utils::saturate(0, jcp_.ih,
127 filter_h_to_src(kh) + (oh_e - 1) * jcp_.stride_h + 1);
128 }
129 inline int get_start_id(int kd, int od_s) const {
130 const auto real_id = filter_d_to_src(kd) + od_s * jcp_.stride_d;
131 return utils::saturate(0, jcp_.id,
132 real_id
133 + utils::rnd_up(
134 nstl::max(0, -real_id), jcp_.stride_d));
135 }
136 inline int get_finish_id(int kd, int od_e) const {
137 return utils::saturate(0, jcp_.id,
138 filter_d_to_src(kd) + (od_e - 1) * jcp_.stride_d + 1);
139 }
140
141 inline int get_finish_oh(int oh_s, int start, int end) const {
142 int work_rem = end - start;
143 return (oh_s + work_rem > jcp_.oh ? jcp_.oh : oh_s + work_rem);
144 }
145 inline int get_finish_od(int od_s, int start, int end) const {
146 int work_rem = end - start;
147 return (od_s + work_rem > jcp_.od ? jcp_.od : od_s + work_rem);
148 }
149 };
150
151 brgemm_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}
152
153 typedef typename prec_traits<data_type::bf16>::type src_data_t;
154 typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
155
156 status_t init(engine_t *engine) override;
157
158 status_t execute(const exec_ctx_t &ctx) const override {
159 execute_backward_weights(ctx);
160 return status::success;
161 }
162
163private:
164 struct thread_info_t;
165
166 void execute_backward_weights(const exec_ctx_t &ctx) const;
167 void prepare_scratchpad_data(const exec_ctx_t &ctx) const;
168 void compute_diff_weights_2d(thread_info_t *) const;
169 void compute_diff_weights_3d(thread_info_t *) const;
170 void reduce_and_convert_diff_weights_and_bias(thread_info_t *) const;
171 void store_in_vnni_format(thread_info_t *) const;
172
173 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
174
175 std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;
176
177 std::unique_ptr<jit_diff_wei_trans_to_vnni_t> diff_wei_trans_kernel_;
178 std::unique_ptr<jit_trans_src_t> trans_kernel_;
179 std::unique_ptr<jit_trans_dst_t> trans_dst_kernel_;
180 std::unique_ptr<jit_avx512_core_amx_bwd_bias_kernel_t> diff_bias_kernel_;
181
182 std::vector<std::unique_ptr<brgemm_kernel_t>> brg_kernels_;
183 struct S_t {
184 char a[AMX_PALETTE_SIZE];
185 };
186 std::vector<S_t> brg_kernel_palettes_;
187
188 status_t add_brg_kernel(int bs, int M, int i_N, int i_K, int i_init);
189 void call_brgemm_kernel(
190 thread_info_t &btc, int brg_idx, int batch_size, void *ptr_C) const;
191 inline dim_t wei_offset_int(
192 int g, int oc_b, int ic_b, int kd, int kh, int kw) const {
193 const auto &jcp = pd()->jcp_;
194 const dim_t const_extra_offset = jcp.ic_block * jcp.oc_block;
195 dim_t extra_offset
196 = ((kd * jcp.kh + kh) * jcp.kw + kw) * const_extra_offset;
197 return (dim_t)((g * jcp.nb_oc + oc_b) * jcp.nb_ic + ic_b) * jcp.kd
198 * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block
199 + extra_offset;
200 }
201
202 inline dim_t wei_offset_int(int g, int oc_b, int ic_b, int kX) const {
203 const auto &jcp = pd()->jcp_;
204 const dim_t const_extra_offset = jcp.kw * jcp.ic_block * jcp.oc_block;
205 dim_t extra_offset = (jcp.ndims == 5) ? kX * jcp.kh * const_extra_offset
206 : kX * const_extra_offset;
207 return (dim_t)((g * jcp.nb_oc + oc_b) * jcp.nb_ic + ic_b) * jcp.kd
208 * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block
209 + extra_offset;
210 }
211
212 inline dim_t wei_offset_ext(int g, int oc_b, int ic_b, int kX) const {
213 const auto &jcp = pd()->jcp_;
214 const int nb_ic = utils::div_up(jcp.ic, 2 * jcp.ic_block);
215 const dim_t const_extra_offset
216 = jcp.kw * jcp.ic_block * jcp.oc_block * 2;
217 dim_t extra_offset = (jcp.ndims == 5) ? kX * jcp.kh * const_extra_offset
218 : kX * const_extra_offset;
219 return (dim_t)((g * jcp.nb_oc + oc_b) * nb_ic + ic_b) * jcp.kd * jcp.kh
220 * jcp.kw * jcp.ic_block * jcp.oc_block * 2
221 + extra_offset;
222 }
223
224 inline int get_end(int start, int step, int limit) const {
225 return nstl::min(start + step, limit);
226 }
227};
228
229} // namespace x64
230} // namespace cpu
231} // namespace impl
232} // namespace dnnl
233
234#endif
235
236// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
237