1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/dnnl_thread.hpp" |
18 | #include "common/utils.hpp" |
19 | |
20 | #include "cpu/x64/jit_avx512_core_amx_convolution.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | namespace amx_utils { |
27 | |
28 | using namespace dnnl::impl::memory_tracking::names; |
29 | using namespace dnnl::impl::utils; |
30 | |
// Computes an offset into the weights descriptor `d` via `blk_off`.
// The group index `g` is forwarded as the leading argument only when
// `with_groups` is true; otherwise it is dropped and only the remaining
// arguments are passed.
// NOTE: the expansion refers to a variable named `with_groups`, which must be
// in scope wherever the macro is used. The macro is #undef'ed at the end of
// this file.
#define wht_blk_off(d, g, ...) \
    (with_groups ? (d).blk_off((g), __VA_ARGS__) : (d).blk_off(__VA_ARGS__))
33 | |
34 | struct spatial_features_3d { |
35 | |
36 | spatial_features_3d(const jit_conv_conf_t &jcp) |
37 | : input_size_(jcp.id) |
38 | , filter_size_(jcp.kd) |
39 | , dilate_(jcp.dilate_d + 1) |
40 | , stride_(jcp.stride_d) |
41 | , init_pad_(jcp.f_pad) |
42 | , end_pad_(jcp.back_pad) |
43 | , is_fast_path_(dilate_ == 1 && stride_ == 1) |
44 | , compute_extended_features_(!(is_fast_path_ || dilate_ != 1)) |
45 | , filter_(0) |
46 | , lower_offset_(0) |
47 | , output_offset_(0) |
48 | , init_overflow_(0) |
49 | , end_overflow_(0) {} |
50 | |
51 | inline int get_init_overflow(const int in) { |
52 | if (is_fast_path_) |
53 | return nstl::max(0, filter_size_ - 1 - in - init_pad_); |
54 | if (dilate_ != 1) |
55 | return div_up( |
56 | nstl::max(0, (filter_size_ - 1) * dilate_ - in - init_pad_), |
57 | dilate_); |
58 | return nstl::max(0, (filter_size_ - 1 - in - init_pad_) / stride_); |
59 | } |
60 | |
61 | inline int get_end_overflow(const int in) { |
62 | if (is_fast_path_) |
63 | return nstl::max(0, filter_size_ - input_size_ + in - end_pad_); |
64 | if (dilate_ != 1) |
65 | return div_up(nstl::max(0, |
66 | (filter_size_ - 1) * dilate_ + 1 - input_size_ |
67 | + in - end_pad_), |
68 | dilate_); |
69 | return nstl::max( |
70 | 0, (filter_size_ - input_size_ + in - end_pad_) / stride_); |
71 | } |
72 | |
73 | void update_params(const int in) { |
74 | |
75 | init_overflow_ = get_init_overflow(in); |
76 | end_overflow_ = get_end_overflow(in); |
77 | |
78 | // overflow_kd_hi |
79 | const int overflow_filter_hi_ = compute_extended_features_ |
80 | ? filter_size_ - 1 |
81 | - nstl::modulo(input_size_ - 1 + end_pad_ - in, stride_) |
82 | : 0; |
83 | // overflow_kd_lo |
84 | const int overflow_filter_lo_ |
85 | = compute_extended_features_ ? (in + init_pad_) % stride_ : 0; |
86 | |
87 | filter_ = compute_extended_features_ |
88 | ? (overflow_filter_hi_ - overflow_filter_lo_) / stride_ + 1 |
89 | : filter_size_; |
90 | |
91 | lower_offset_ = compute_extended_features_ |
92 | ? overflow_filter_lo_ + end_overflow_ * stride_ |
93 | : end_overflow_; |
94 | |
95 | output_offset_ = compute_extended_features_ |
96 | ? (in + init_pad_ - lower_offset_) / stride_ |
97 | : in + init_pad_ - end_overflow_ * dilate_; |
98 | } |
99 | |
100 | inline int get_filter_padding() { |
101 | return filter_ - init_overflow_ - end_overflow_; |
102 | } |
103 | |
104 | inline int get_lower_offset() { return lower_offset_; } |
105 | |
106 | inline int get_output_offset() { return output_offset_; } |
107 | |
108 | private: |
109 | const int input_size_; |
110 | const int filter_size_; |
111 | const int dilate_; |
112 | const int stride_; |
113 | const int init_pad_; // f_pad |
114 | const int end_pad_; // back_pad |
115 | const bool is_fast_path_; // 'dilate_ == 1 && stride_ == 1' |
116 | const bool |
117 | compute_extended_features_; // eq. '(!is_fast_path_) && dilate_ == 1' |
118 | |
119 | int filter_; |
120 | int lower_offset_; // d_lo |
121 | int output_offset_; // d_oj |
122 | |
123 | int init_overflow_; // d_t_overflow |
124 | int end_overflow_; // d_b_overflow |
125 | }; |
126 | |
// Runs the AMX backward-data kernel over the whole problem described by
// `jcp`, distributing the work across threads with `parallel`.
//
// A work item is one tuple (mb, g, id, ih-chunk, iw-block, ic-chunk). For
// each item a thread:
//   1. invokes `kernel->bwd_data_copy_kernel()` with `p.src` pointing into
//      `diff_dst` and `p.dst` pointing into its slice of the scratchpad input
//      buffer -- unless the previously processed item had the same
//      (mb, g, id, ih-chunk, iw-block), in which case the buffer is reused;
//   2. invokes the compute kernel `(*kernel)` with pointers into that buffer,
//      `weights`, `diff_src` and the per-thread s32 workspace.
//
// Parameters:
//   ctx        - execution context; only its scratchpad grantor is used here.
//   jcp        - convolution configuration (sizes, blocking, strides, pads).
//   kernel     - JIT kernel object providing tile_configure(), the copy
//                kernel and the compute kernel.
//   diff_dst, weights, bias, diff_src
//              - raw byte pointers; element offsets come from the matching
//                memory descriptor and are scaled by the element sizes taken
//                from `jcp`. `bias` may be null.
//   oscales, dst_scales
//              - forwarded to the kernel through p.scales / p.dst_scale.
//   diff_dst_d, weights_d, bias_d, diff_src_d
//              - memory descriptor wrappers used for blk_off() computations.
inline void execute_backward_convolution_body(const exec_ctx_t &ctx,
        const jit_conv_conf_t &jcp,
        const std::unique_ptr<jit_avx512_core_amx_bwd_data_kernel_t> &kernel,
        const char *diff_dst, const char *weights, const char *bias,
        const float *oscales, const float *dst_scales, char *diff_src,
        const memory_desc_wrapper &diff_dst_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &bias_d,
        const memory_desc_wrapper &diff_src_d) {
    // `ic_chunks` below relies on nb_ic being an exact multiple of the
    // ic blocking factor.
    assert(jcp.nb_ic % jcp.nb_ic_blocking == 0);

    // Any prop_kind other than backward_data is treated as deconvolution;
    // this only affects which weights dimension the ic-chunk shift is taken
    // along (see wei_ic_shift).
    const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
    // Grouped weights carry one more dimension than diff_src.
    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;

    // Element sizes in bytes; used to turn element offsets into byte offsets.
    const size_t diff_dst_dt_size = jcp.typesize_in;
    const size_t wei_dt_size = jcp.typesize_in;
    const size_t bia_dt_size = jcp.typesize_bia;
    const size_t diff_src_dt_size = jcp.typesize_out;

    // Weights strides (in elements) for one group, one ic chunk and one step
    // along the 3rd non-group weights dimension (used with d_lo below).
    const dim_t wei_g_shift = wht_blk_off(weights_d, 1, 0);
    const dim_t wei_ic_shift = is_deconv
            ? wht_blk_off(weights_d, 0, jcp.nb_ic_blocking)
            : wht_blk_off(weights_d, 0, 0, jcp.nb_ic_blocking);
    const size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

    // Scratchpad areas: input buffer and s32 workspace are indexed by `ithr`
    // below; the tile configuration is shared by all threads.
    auto inp_p_buffer = ctx.get_scratchpad_grantor().template get<char>(
            key_conv_amx_inp_buffer);
    auto wsp = ctx.get_scratchpad_grantor().template get<int32_t>(
            key_conv_amx_wsp_buffer);
    auto tcfg = ctx.get_scratchpad_grantor().template get<char>(
            key_conv_amx_tilecfg);

    const int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
    const int ih_chunks = utils::div_up(jcp.ih, jcp.ih_blk_size);
    // Total number of work items to be balanced across threads.
    const int work_amount
            = jcp.mb * jcp.ngroups * jcp.id * ih_chunks * jcp.nb_iw * ic_chunks;

    // Initialize the tile configuration in memory, so that each thread can
    // load this configuration from memory via `amx_tile_configure(tcfg)`.
    if (tcfg) kernel->tile_configure(tcfg);
    const bool is_1d = jcp.ndims == 3;
    const bool is_3d = jcp.ndims == 5;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        // [start, end) is this thread's share of the work items.
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();
        // NOTE(review): called unconditionally, whereas tile_configure()
        // above is guarded by `if (tcfg)` -- presumably tcfg is never null
        // here; confirm.
        amx_tile_configure(tcfg);
        spatial_features_3d sfd(jcp);

        // Decompose the linear index `start` into the per-dimension indices.
        int mb {0}, g {0}, id_s {0}, ihc {0}, iwb {0}, icc {0};
        nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc,
                ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks);
        // Indices of the work item processed most recently by this thread;
        // -1 means "nothing processed yet".
        int last_copied_mb = -1;
        int last_copied_id = -1;
        int last_copied_ihc = -1;
        int last_copied_iwb = -1;
        int last_copied_g = -1;
        while (start < end) {
            // This thread's private slice of the input buffer.
            char *inp_buffer = inp_p_buffer
                    + ithr * jcp.inp_buffer_size * diff_dst_dt_size;

            assert(IMPLICATION(
                    jcp.ngroups > 1, jcp.ic == jcp.ic_without_padding));
            // First input channel of this (group, ic chunk) pair.
            int ic = g * jcp.ic + icc * jcp.nb_ic_blocking * jcp.ic_block;
            // Channel index passed to diff_src_d.blk_off(): raw channel for
            // nspc layouts, channel-block index otherwise.
            int icb = jcp.is_nspc ? ic : ic / jcp.ic_block;
            assert(IMPLICATION(
                    jcp.ngroups > 1, jcp.oc == jcp.oc_without_padding));
            // Channel index passed to diff_dst_d.blk_off() for group `g`.
            const int ocb = g * (jcp.is_nspc ? jcp.oc : jcp.nb_oc);
            auto bias_w = bias ? bias + (bias_d.blk_off(ic) * bia_dt_size)
                               : nullptr;

            // Row range [ih_b, ih_e) of this ih chunk and first column of
            // this iw block.
            const int ih_b = ihc * jcp.ih_blk_size;
            const int ih_e = nstl::min(jcp.ih, ih_b + jcp.ih_blk_size);
            const int iw = iwb * jcp.iw_block;
            // `icc` is deliberately not part of the comparison: items that
            // differ only in the ic chunk reuse the already filled buffer.
            bool is_inp_buffer_relevant = true && last_copied_mb == mb
                    && last_copied_id == id_s && last_copied_ihc == ihc
                    && last_copied_iwb == iwb && last_copied_g == g;

            // Depth-axis parameters for the current input depth index.
            sfd.update_params(id_s);
            p.kd_padding = sfd.get_filter_padding();
            const int d_lo = sfd.get_lower_offset();
            const int d_oj = sfd.get_output_offset();

            int ih_step = jcp.nb_ih_blocking;
            for (int ih = ih_b; ih < ih_e; ih += ih_step) {
                if (!is_inp_buffer_relevant) {
                    // Effective (dilated) filter extents.
                    const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
                    const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
                    // dox: x-index dilated by strides (dox = ox * stride_x)
                    const int doh = ih + jcp.t_pad - (gen_kh - 1);
                    const int dow = iw + jcp.l_pad - (gen_kw - 1);
                    const int doh_b = ih_b + jcp.t_pad - (gen_kh - 1);
                    const int doh_l = (jcp.oh - 1) * jcp.stride_h; // last oh
                    const int dow_l = (jcp.ow - 1) * jcp.stride_w; // last ow

                    // dox_{s,f}: start and finish indices for copy kernel
                    // For every step after the first one of the chunk the
                    // start is advanced by gen_kh - 1, so only the rows not
                    // covered by the previous step's range are addressed.
                    const int doh_s = doh + (ih == ih_b ? 0 : gen_kh - 1);
                    const int doh_f = doh + (ih_step - 1) + (gen_kh - 1);
                    const int delta_h = doh_f - doh_s + 1;
                    // Inside (0, doh_l) the overflow is the distance to a
                    // stride-aligned position; outside it is the distance
                    // to the valid range.
                    const int doh_t_overflow = 0 < doh_s && doh_s < doh_l
                            ? nstl::additive_inverse_modulo(doh_s, jcp.stride_h)
                            : nstl::max(0, -doh_s);
                    const int doh_b_overflow = 0 < doh_f && doh_f < doh_l
                            ? nstl::modulo(doh_f, jcp.stride_h)
                            : nstl::max(0, nstl::min(delta_h, doh_f - doh_l));
                    int dow_s = dow;
                    int dow_f = dow + jcp.owp - 1;
                    const int delta_w = dow_f - dow_s + 1;
                    const int dow_l_overflow = 0 < dow_s && dow_s < dow_l
                            ? nstl::additive_inverse_modulo(dow_s, jcp.stride_w)
                            : nstl::max(0, -dow_s);
                    const int dow_r_overflow = 0 < dow_f && dow_f < dow_l
                            ? nstl::modulo(dow_f, jcp.stride_w)
                            : nstl::max(0, nstl::min(delta_w, dow_f - dow_l));
                    // First diff_dst row/column to read from (clamped at 0).
                    const int oh_s
                            = nstl::max(0, utils::div_up(doh_s, jcp.stride_h));
                    const int ow_s
                            = nstl::max(0, utils::div_up(dow_s, jcp.stride_w));
                    // how many real data rows to copy (including padding)
                    // NOTE(review): with T = size_t a negative difference
                    // in the max<size_t>(0, ...) calls below would wrap
                    // instead of clamping to 0 -- presumably the operands
                    // guarantee non-negative values; confirm.
                    p.t_overflow = nstl::min(delta_h, doh_t_overflow);
                    p.b_overflow = nstl::min<size_t>(
                            delta_h - p.t_overflow, doh_b_overflow);
                    p.kh_padding = nstl::max<size_t>(
                            0, delta_h - p.t_overflow - p.b_overflow);
                    p.l_overflow = nstl::min(delta_w, dow_l_overflow);
                    p.kw_padding = nstl::max<size_t>(
                            0, delta_w - dow_l_overflow - dow_r_overflow);
                    p.r_overflow = nstl::min<size_t>(
                            delta_w - dow_l_overflow, dow_r_overflow);
                    size_t inp_offset = is_1d
                            ? diff_dst_d.blk_off(mb, ocb, ow_s)
                            : is_3d ? diff_dst_d.blk_off(
                                      mb, ocb, d_oj, oh_s, ow_s)
                                    : diff_dst_d.blk_off(mb, ocb, oh_s, ow_s);
                    // For the copy kernel: src points into diff_dst, dst at
                    // row (doh_s - doh_b) of the thread's input buffer.
                    p.src = diff_dst + diff_dst_dt_size * inp_offset;
                    p.dst = inp_buffer
                            + (size_t)(doh_s - doh_b) * jcp.owp
                                    * jcp.oc_block_int * diff_dst_dt_size;

                    kernel->bwd_data_copy_kernel()(&p);
                }

                size_t diff_src_offset = is_1d
                        ? diff_src_d.blk_off(mb, icb, iw)
                        : is_3d ? diff_src_d.blk_off(mb, icb, id_s, ih, iw)
                                : diff_src_d.blk_off(mb, icb, ih, iw);
                // For the compute kernel the fields are reassigned: p.dst
                // points at row (ih - ih_b) of the input buffer and p.src
                // into diff_src. Which one the kernel reads or writes is
                // decided inside the kernel (not visible here).
                p.dst = inp_buffer
                        + (size_t)(ih - ih_b) * jcp.owp * jcp.oc_block_int
                                * diff_dst_dt_size;
                p.src = diff_src + diff_src_dt_size * diff_src_offset;
                p.filt = weights
                        + wei_dt_size
                                * (g * wei_g_shift + icc * wei_ic_shift
                                        + d_lo * wht_d_stride);
                p.bias = bias_w;
                // Index is `ic` when jcp.is_ic_scale is 1 and 0 when it is 0.
                p.scales = &oscales[jcp.is_ic_scale * ic];
                p.dst_scale = &dst_scales[0];
                // This thread's private slice of the s32 workspace.
                p.acc_s32 = wsp + ithr * jcp.wsp_buffer_size;
                // True when a full ih_step of rows fits before ih_e; how the
                // kernel interprets the flag is not visible here.
                p.last_h = (ih + ih_step <= ih_e);
                p.iwb = iwb;
                p.ic_blocks = icc * jcp.nb_ic_blocking;

                (*kernel)(&p);
            }
            // Remember this item so that the next one can detect whether
            // the input buffer is still valid.
            last_copied_mb = mb;
            last_copied_id = id_s;
            last_copied_ihc = ihc;
            last_copied_iwb = iwb;
            last_copied_g = g;
            ++start;
            nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc,
                    ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks);
        }
        amx_tile_release();
    });
}
304 | |
305 | #undef wht_blk_off |
306 | |
307 | } // namespace amx_utils |
308 | } // namespace x64 |
309 | } // namespace cpu |
310 | } // namespace impl |
311 | } // namespace dnnl |
312 | |
313 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
314 | |