1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/dnnl_thread.hpp"
18#include "common/utils.hpp"
19
20#include "cpu/x64/jit_avx512_core_amx_convolution.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace cpu {
25namespace x64 {
26namespace amx_utils {
27
28using namespace dnnl::impl::memory_tracking::names;
29using namespace dnnl::impl::utils;
30
31#define wht_blk_off(d, g, ...) \
32 (with_groups ? (d).blk_off((g), __VA_ARGS__) : (d).blk_off(__VA_ARGS__))
33
34struct spatial_features_3d {
35
36 spatial_features_3d(const jit_conv_conf_t &jcp)
37 : input_size_(jcp.id)
38 , filter_size_(jcp.kd)
39 , dilate_(jcp.dilate_d + 1)
40 , stride_(jcp.stride_d)
41 , init_pad_(jcp.f_pad)
42 , end_pad_(jcp.back_pad)
43 , is_fast_path_(dilate_ == 1 && stride_ == 1)
44 , compute_extended_features_(!(is_fast_path_ || dilate_ != 1))
45 , filter_(0)
46 , lower_offset_(0)
47 , output_offset_(0)
48 , init_overflow_(0)
49 , end_overflow_(0) {}
50
51 inline int get_init_overflow(const int in) {
52 if (is_fast_path_)
53 return nstl::max(0, filter_size_ - 1 - in - init_pad_);
54 if (dilate_ != 1)
55 return div_up(
56 nstl::max(0, (filter_size_ - 1) * dilate_ - in - init_pad_),
57 dilate_);
58 return nstl::max(0, (filter_size_ - 1 - in - init_pad_) / stride_);
59 }
60
61 inline int get_end_overflow(const int in) {
62 if (is_fast_path_)
63 return nstl::max(0, filter_size_ - input_size_ + in - end_pad_);
64 if (dilate_ != 1)
65 return div_up(nstl::max(0,
66 (filter_size_ - 1) * dilate_ + 1 - input_size_
67 + in - end_pad_),
68 dilate_);
69 return nstl::max(
70 0, (filter_size_ - input_size_ + in - end_pad_) / stride_);
71 }
72
73 void update_params(const int in) {
74
75 init_overflow_ = get_init_overflow(in);
76 end_overflow_ = get_end_overflow(in);
77
78 // overflow_kd_hi
79 const int overflow_filter_hi_ = compute_extended_features_
80 ? filter_size_ - 1
81 - nstl::modulo(input_size_ - 1 + end_pad_ - in, stride_)
82 : 0;
83 // overflow_kd_lo
84 const int overflow_filter_lo_
85 = compute_extended_features_ ? (in + init_pad_) % stride_ : 0;
86
87 filter_ = compute_extended_features_
88 ? (overflow_filter_hi_ - overflow_filter_lo_) / stride_ + 1
89 : filter_size_;
90
91 lower_offset_ = compute_extended_features_
92 ? overflow_filter_lo_ + end_overflow_ * stride_
93 : end_overflow_;
94
95 output_offset_ = compute_extended_features_
96 ? (in + init_pad_ - lower_offset_) / stride_
97 : in + init_pad_ - end_overflow_ * dilate_;
98 }
99
100 inline int get_filter_padding() {
101 return filter_ - init_overflow_ - end_overflow_;
102 }
103
104 inline int get_lower_offset() { return lower_offset_; }
105
106 inline int get_output_offset() { return output_offset_; }
107
108private:
109 const int input_size_;
110 const int filter_size_;
111 const int dilate_;
112 const int stride_;
113 const int init_pad_; // f_pad
114 const int end_pad_; // back_pad
115 const bool is_fast_path_; // 'dilate_ == 1 && stride_ == 1'
116 const bool
117 compute_extended_features_; // eq. '(!is_fast_path_) && dilate_ == 1'
118
119 int filter_;
120 int lower_offset_; // d_lo
121 int output_offset_; // d_oj
122
123 int init_overflow_; // d_t_overflow
124 int end_overflow_; // d_b_overflow
125};
126
// Shared driver for AMX backward-data convolution and deconvolution.
// Splits the (mb, group, id, ih-chunk, iw-block, ic-chunk) iteration space
// across threads; each work item first stages the needed `diff_dst` rows into
// a per-thread buffer via the copy kernel (skipped when the previously staged
// region is still valid), then runs the compute kernel per ih step.
inline void execute_backward_convolution_body(const exec_ctx_t &ctx,
        const jit_conv_conf_t &jcp,
        const std::unique_ptr<jit_avx512_core_amx_bwd_data_kernel_t> &kernel,
        const char *diff_dst, const char *weights, const char *bias,
        const float *oscales, const float *dst_scales, char *diff_src,
        const memory_desc_wrapper &diff_dst_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &bias_d,
        const memory_desc_wrapper &diff_src_d) {
    assert(jcp.nb_ic % jcp.nb_ic_blocking == 0);

    // Deconvolution reuses this body; it only changes the weights ic-shift
    // layout below (one fewer logical dimension in the blk_off call).
    const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
    // Grouped weights have one extra (groups) dimension vs. the data tensor.
    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;

    const size_t diff_dst_dt_size = jcp.typesize_in;
    const size_t wei_dt_size = jcp.typesize_in;
    const size_t bia_dt_size = jcp.typesize_bia;
    const size_t diff_src_dt_size = jcp.typesize_out;

    // Weight strides (in elements) along groups, ic-chunk, and kd.
    const dim_t wei_g_shift = wht_blk_off(weights_d, 1, 0);
    const dim_t wei_ic_shift = is_deconv
            ? wht_blk_off(weights_d, 0, jcp.nb_ic_blocking)
            : wht_blk_off(weights_d, 0, 0, jcp.nb_ic_blocking);
    const size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

    // Scratchpad: per-thread staging buffer for diff_dst, per-thread int32
    // accumulator workspace, and the AMX tile configuration palette.
    auto inp_p_buffer = ctx.get_scratchpad_grantor().template get<char>(
            key_conv_amx_inp_buffer);
    auto wsp = ctx.get_scratchpad_grantor().template get<int32_t>(
            key_conv_amx_wsp_buffer);
    auto tcfg = ctx.get_scratchpad_grantor().template get<char>(
            key_conv_amx_tilecfg);

    const int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
    const int ih_chunks = utils::div_up(jcp.ih, jcp.ih_blk_size);
    const int work_amount
            = jcp.mb * jcp.ngroups * jcp.id * ih_chunks * jcp.nb_iw * ic_chunks;

    // Initialize the tile configuration in memory, so that each thread can
    // load this configuration from memory via `amx_tile_configure(tcfg)`.
    if (tcfg) kernel->tile_configure(tcfg);
    const bool is_1d = jcp.ndims == 3;
    const bool is_3d = jcp.ndims == 5;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();
        amx_tile_configure(tcfg);
        spatial_features_3d sfd(jcp);

        int mb {0}, g {0}, id_s {0}, ihc {0}, iwb {0}, icc {0};
        nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc,
                ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks);
        // Track the last staged region so consecutive work items that differ
        // only in the ic-chunk can reuse the staging buffer.
        int last_copied_mb = -1;
        int last_copied_id = -1;
        int last_copied_ihc = -1;
        int last_copied_iwb = -1;
        int last_copied_g = -1;
        while (start < end) {
            char *inp_buffer = inp_p_buffer
                    + ithr * jcp.inp_buffer_size * diff_dst_dt_size;

            assert(IMPLICATION(
                    jcp.ngroups > 1, jcp.ic == jcp.ic_without_padding));
            int ic = g * jcp.ic + icc * jcp.nb_ic_blocking * jcp.ic_block;
            int icb = jcp.is_nspc ? ic : ic / jcp.ic_block;
            assert(IMPLICATION(
                    jcp.ngroups > 1, jcp.oc == jcp.oc_without_padding));
            const int ocb = g * (jcp.is_nspc ? jcp.oc : jcp.nb_oc);
            auto bias_w = bias ? bias + (bias_d.blk_off(ic) * bia_dt_size)
                               : nullptr;

            const int ih_b = ihc * jcp.ih_blk_size;
            const int ih_e = nstl::min(jcp.ih, ih_b + jcp.ih_blk_size);
            const int iw = iwb * jcp.iw_block;
            // Staging buffer is reusable iff all spatial/group/mb coordinates
            // match the previous iteration (ic-chunk may differ).
            bool is_inp_buffer_relevant = true && last_copied_mb == mb
                    && last_copied_id == id_s && last_copied_ihc == ihc
                    && last_copied_iwb == iwb && last_copied_g == g;

            // Depth-axis features for the current id: effective kd padding,
            // filter offset (d_lo) and output depth offset (d_oj).
            sfd.update_params(id_s);
            p.kd_padding = sfd.get_filter_padding();
            const int d_lo = sfd.get_lower_offset();
            const int d_oj = sfd.get_output_offset();

            int ih_step = jcp.nb_ih_blocking;
            for (int ih = ih_b; ih < ih_e; ih += ih_step) {
                if (!is_inp_buffer_relevant) {
                    // Generalized (dilated) filter extents.
                    const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
                    const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
                    // dox: x-index dilated by strides (dox = ox * stride_x)
                    const int doh = ih + jcp.t_pad - (gen_kh - 1);
                    const int dow = iw + jcp.l_pad - (gen_kw - 1);
                    const int doh_b = ih_b + jcp.t_pad - (gen_kh - 1);
                    const int doh_l = (jcp.oh - 1) * jcp.stride_h; // last oh
                    const int dow_l = (jcp.ow - 1) * jcp.stride_w; // last ow

                    // dox_{s,f}: start and finish indices for copy kernel
                    const int doh_s = doh + (ih == ih_b ? 0 : gen_kh - 1);
                    const int doh_f = doh + (ih_step - 1) + (gen_kh - 1);
                    const int delta_h = doh_f - doh_s + 1;
                    // In-range indices contribute stride-phase zero-rows;
                    // out-of-range indices contribute padding rows.
                    const int doh_t_overflow = 0 < doh_s && doh_s < doh_l
                            ? nstl::additive_inverse_modulo(doh_s, jcp.stride_h)
                            : nstl::max(0, -doh_s);
                    const int doh_b_overflow = 0 < doh_f && doh_f < doh_l
                            ? nstl::modulo(doh_f, jcp.stride_h)
                            : nstl::max(0, nstl::min(delta_h, doh_f - doh_l));
                    int dow_s = dow;
                    int dow_f = dow + jcp.owp - 1;
                    const int delta_w = dow_f - dow_s + 1;
                    const int dow_l_overflow = 0 < dow_s && dow_s < dow_l
                            ? nstl::additive_inverse_modulo(dow_s, jcp.stride_w)
                            : nstl::max(0, -dow_s);
                    const int dow_r_overflow = 0 < dow_f && dow_f < dow_l
                            ? nstl::modulo(dow_f, jcp.stride_w)
                            : nstl::max(0, nstl::min(delta_w, dow_f - dow_l));
                    // First valid oh/ow source indices for the copy.
                    const int oh_s
                            = nstl::max(0, utils::div_up(doh_s, jcp.stride_h));
                    const int ow_s
                            = nstl::max(0, utils::div_up(dow_s, jcp.stride_w));
                    // how many real data rows to copy (including padding)
                    p.t_overflow = nstl::min(delta_h, doh_t_overflow);
                    p.b_overflow = nstl::min<size_t>(
                            delta_h - p.t_overflow, doh_b_overflow);
                    p.kh_padding = nstl::max<size_t>(
                            0, delta_h - p.t_overflow - p.b_overflow);
                    p.l_overflow = nstl::min(delta_w, dow_l_overflow);
                    p.kw_padding = nstl::max<size_t>(
                            0, delta_w - dow_l_overflow - dow_r_overflow);
                    p.r_overflow = nstl::min<size_t>(
                            delta_w - dow_l_overflow, dow_r_overflow);
                    size_t inp_offset = is_1d
                            ? diff_dst_d.blk_off(mb, ocb, ow_s)
                            : is_3d ? diff_dst_d.blk_off(
                                      mb, ocb, d_oj, oh_s, ow_s)
                                    : diff_dst_d.blk_off(mb, ocb, oh_s, ow_s);
                    p.src = diff_dst + diff_dst_dt_size * inp_offset;
                    // Destination: row (doh_s - doh_b) of this thread's
                    // staging buffer.
                    p.dst = inp_buffer
                            + (size_t)(doh_s - doh_b) * jcp.owp
                                    * jcp.oc_block_int * diff_dst_dt_size;

                    kernel->bwd_data_copy_kernel()(&p);
                }

                size_t diff_src_offset = is_1d
                        ? diff_src_d.blk_off(mb, icb, iw)
                        : is_3d ? diff_src_d.blk_off(mb, icb, id_s, ih, iw)
                                : diff_src_d.blk_off(mb, icb, ih, iw);
                // Compute kernel reads from the staging buffer and writes to
                // diff_src; weights are offset by group, ic-chunk and d_lo.
                p.dst = inp_buffer
                        + (size_t)(ih - ih_b) * jcp.owp * jcp.oc_block_int
                                * diff_dst_dt_size;
                p.src = diff_src + diff_src_dt_size * diff_src_offset;
                p.filt = weights
                        + wei_dt_size
                                * (g * wei_g_shift + icc * wei_ic_shift
                                        + d_lo * wht_d_stride);
                p.bias = bias_w;
                p.scales = &oscales[jcp.is_ic_scale * ic];
                p.dst_scale = &dst_scales[0];
                p.acc_s32 = wsp + ithr * jcp.wsp_buffer_size;
                // NOTE(review): true for every full-sized ih step; only the
                // truncated tail step gets false — confirm intended meaning
                // against the kernel's use of last_h.
                p.last_h = (ih + ih_step <= ih_e);
                p.iwb = iwb;
                p.ic_blocks = icc * jcp.nb_ic_blocking;

                (*kernel)(&p);
            }
            last_copied_mb = mb;
            last_copied_id = id_s;
            last_copied_ihc = ihc;
            last_copied_iwb = iwb;
            last_copied_g = g;
            ++start;
            nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc,
                    ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks);
        }
        amx_tile_release();
    });
}
304
305#undef wht_blk_off
306
307} // namespace amx_utils
308} // namespace x64
309} // namespace cpu
310} // namespace impl
311} // namespace dnnl
312
313// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
314