1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_BRGEMM_CONV_BWD_STRIDED_HPP
18#define CPU_X64_JIT_BRGEMM_CONV_BWD_STRIDED_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/memory_tracking.hpp"
23#include "common/primitive.hpp"
24#include "common/utils.hpp"
25
26#include "cpu/cpu_convolution_pd.hpp"
27#include "cpu/platform.hpp"
28
29#include "cpu/x64/amx_tile_configure.hpp"
30#include "cpu/x64/brgemm/brgemm.hpp"
31#include "cpu/x64/jit_brgemm_conv_bwd_trans_kernel.hpp"
32#include "cpu/x64/jit_brgemm_post_ops.hpp"
33
34namespace dnnl {
35namespace impl {
36namespace cpu {
37namespace x64 {
38
39template <cpu_isa_t isa, bool enable_postops = false>
40struct brgemm_convolution_bwd_strided_t : public primitive_t {
41
42 struct pd_t : public cpu_convolution_bwd_data_pd_t {
43 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
44 const typename pd_t::hint_class *hint_fwd_pd)
45 : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}
46
47 ~pd_t() = default;
48
49 // ------- DECLARE_COMMON_PD_t -------
50 pd_t *clone() const override {
51 auto new_pd = utils::make_unique<pd_t>(*this);
52 if (!new_pd->is_initialized()) return nullptr;
53 new_pd->brgs_.resize(brgs_sz_);
54 for (int i = 0; i < brgs_sz_; i++) {
55 new_pd->brgs_[i] = brgs_[i];
56 new_pd->bd_masks[i] = bd_masks[i];
57 }
58 return new_pd.release();
59 }
60
61 status_t create_primitive(
62 std::pair<std::shared_ptr<primitive_t>, bool> &primitive,
63 engine_t *engine,
64 const cache_blob_t &cache_blob) const override {
65 return primitive_t::create_primitive_common<
66 brgemm_convolution_bwd_strided_t, pd_t>(
67 primitive, this, engine, false, cache_blob);
68 }
69
70 const char *name() const override {
71 return JIT_IMPL_NAME_HELPER("brgconv_strided:", isa, "");
72 }
73 // ---------------------------------
74
75 status_t init(engine_t *engine);
76
77 int brgs_sz_;
78 std::vector<std::shared_ptr<brgemm_t>> brgs_;
79 std::vector<std::shared_ptr<std::vector<char>>> bd_masks;
80 jit_brgemm_conv_conf_t jcp_;
81 // batch sizes info for unrolled kernels
82 int bs_c, first_bs;
83 std::vector<int> batchsizes;
84 int get_brg_idx(int bs, int m, bool do_initialization, bool is_N_tail,
85 bool is_K_tail) const {
86 auto bs_idx = 0;
87 return (((m * bs_c + bs_idx) * 2
88 + static_cast<int>(do_initialization))
89 * 2
90 + static_cast<int>(is_N_tail))
91 * 2
92 + static_cast<int>(is_K_tail);
93 }
94 };
95
96 brgemm_convolution_bwd_strided_t(const pd_t *apd)
97 : primitive_t(apd), bias_d(pd()->weights_md(1)) {}
98
99 ~brgemm_convolution_bwd_strided_t() = default;
100
101 status_t execute(const exec_ctx_t &ctx) const override;
102
103protected:
104 status_t init(engine_t *engine) override;
105
106private:
107 struct S_t {
108 char a[AMX_PALETTE_SIZE];
109 };
110
111 // brgemm convolution execution context
112 struct brgemm_bwd_exec_ctx_t {
113 brgemm_bwd_exec_ctx_t(const exec_ctx_t &ctx, const pd_t *pd)
114 : diff_dst(CTX_IN_MEM(const char *, DNNL_ARG_DIFF_DST))
115 , weights(CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS))
116 , bias(CTX_IN_MEM(const char *, DNNL_ARG_BIAS))
117 , dst(CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC))
118 , post_ops_binary_rhs_arg_vec(binary_injector::prepare_binary_args(
119 pd->attr()->post_ops_, ctx)) {}
120 const char *const __restrict diff_dst;
121 const char *const __restrict weights;
122 const char *const __restrict bias;
123 char *const __restrict dst;
124 const std::vector<const void *> post_ops_binary_rhs_arg_vec;
125 };
126
127 struct brgemm_bwd_thread_ctx_t {
128 brgemm_bwd_thread_ctx_t(brgemm_bwd_exec_ctx_t &brgemm_ctx_, int ithr_,
129 brgemm_batch_element_t *__restrict brg_batch_, char *c_buffer_,
130 char *wsp_tile_)
131 : brgemm_ctx(brgemm_ctx_)
132 , ithr(ithr_)
133 , brg_batch(brg_batch_)
134 , c_buffer(c_buffer_)
135 , wsp_tile(wsp_tile_) {}
136
137 brgemm_bwd_exec_ctx_t &brgemm_ctx;
138 int ithr;
139 brgemm_batch_element_t *__restrict brg_batch;
140 char *c_buffer;
141 char *wsp_tile;
142 S_t cur_palette;
143 int g, n, icb;
144 int id, idb, ih, ihb, iwb;
145 int occ;
146 int sw;
147 const float *oscales {nullptr};
148 };
149
150 void ker_trans(brgemm_bwd_thread_ctx_t &btc, char *inp_buffer) const;
151
152 void call_brgemm_kernel(brgemm_bwd_thread_ctx_t &btc, int brg_idx,
153 int batch_size, char *ptr_C, char *ptr_D, const char *bias_w,
154 int g_ic, bool do_postops, const void *binary_post_ops_rhs,
155 int32_t src_zp_vals, int32_t *src_zp_ptr, int32_t *dst_zp_ptr,
156 int32_t *s8s8_comp, bool do_only_comp,
157 bool is_first_call_postops) const;
158
159 void maybe_trans_inp(int ithr, const char *__restrict input,
160 char *__restrict inp_buffer, uint8_t *__restrict inp_buffer_mask,
161 int g, int n, int icc, int odb, int ohb, int owb, int last_g,
162 int last_n, int last_icc, int last_odb, int last_ohb,
163 int last_owb) const;
164
165 status_t add_brg_kernel(int bs, int M, int i_N, int i_K, int i_init);
166 const pd_t *pd() const {
167 return static_cast<const pd_t *>(primitive_t::pd().get());
168 }
169
170 std::vector<std::unique_ptr<brgemm_kernel_t>> brg_kernels_;
171 std::unique_ptr<jit_avx512_core_brgemm_conv_bwd_trans_kernel::
172 jit_avx512_core_brgemm_conv_bwd_trans_kernel_t>
173 copy_to_pbuffer_;
174 std::vector<S_t> brg_kernel_palettes_;
175
176 size_t acc_dsz, bia_dsz, src_dsz, wei_dsz, dst_dsz;
177
178 const memory_desc_wrapper bias_d;
179
180 int KD, KH, KW, EXT_KD, EXT_KH, EXT_KW, KS, KD_BLOCK, KH_BLOCK, KW_BLOCK,
181 KD_BLOCK_PAD, KH_BLOCK_PAD, ID, IH, IW, ODP, OHP, OWP, OD, OH, OW,
182 SD, SH, SW, FP, TP, LP, DD, DH, DW;
183 dim_t src_w_sz, src_h_sz, src_d_sz, dst_w_sz, dst_h_sz, dst_d_sz, wei_oc_sz,
184 wei_kw_sz, wei_kh_sz, wei_kd_sz, wei_icb_sz;
185 dim_t pbuf_w_sz, pbuf_h_sz, pbuf_d_sz;
186
187 int oc_chunks;
188 bool need_postwork;
189 bool is_amx;
190};
191
192} // namespace x64
193} // namespace cpu
194} // namespace impl
195} // namespace dnnl
196
197#endif
198
199// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
200