1 | /******************************************************************************* |
2 | * Copyright 2017-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifdef __INTEL_COMPILER |
18 | #include <immintrin.h> |
19 | #endif |
20 | |
21 | #include "oneapi/dnnl/dnnl_types.h" |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/dnnl_thread.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | #include "cpu/x64/jit_avx512_core_f32_wino_conv_4x3.hpp" |
29 | |
30 | #ifndef _MSC_VER |
31 | #define pragma_unroll _Pragma("unroll") |
32 | #else |
33 | #define pragma_unroll |
34 | #endif |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | namespace cpu { |
39 | namespace x64 { |
40 | |
41 | using namespace dnnl::impl::status; |
42 | using namespace dnnl::impl::memory_tracking::names; |
43 | using namespace dnnl::impl::utils; |
44 | |
45 | template <bool is_fwd> |
46 | void _jit_avx512_core_f32_wino_conv_4x3_t<is_fwd>::weight_transform_data( |
47 | const jit_conv_winograd_conf_t &jcp, float *wp, float *twp) const { |
48 | float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f, |
49 | 1.13777777777778f, 0.430252100840336f, 0.179271708683473f}; |
50 | const int kh = 3; |
51 | const int kw = 3; |
52 | float Fw[alpha][alpha][simd_w][simd_w]; |
53 | float F[kh][kw][simd_w][simd_w]; |
54 | float T[alpha][3][simd_w]; |
55 | auto p = jit_wino_transform_call_s(); |
56 | |
57 | p.src = wp; |
58 | p.dst = twp; |
59 | p.G = G; |
60 | p.M = F; |
61 | p.Mw = Fw; |
62 | p.T = T; |
63 | |
64 | kernel_->weights_transform_data_ker(&p); |
65 | } |
66 | |
67 | template <bool is_fwd> |
68 | void _jit_avx512_core_f32_wino_conv_4x3_t<is_fwd>::output_transform_data( |
69 | int image, const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, |
70 | float *toutp, float *pout_b, float *bias) const { |
71 | |
72 | float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; |
73 | float Ow[alpha][alpha][simd_w]; |
74 | float O[tile_size][tile_size][simd_w]; |
75 | float T[tile_size][alpha][simd_w]; |
76 | |
77 | auto p = jit_wino_transform_call_s(); |
78 | p.src = toutp; |
79 | p.dst = pout_b; |
80 | p.G = G; |
81 | p.M = O; |
82 | p.Mw = Ow; |
83 | p.T = T; |
84 | p.bias = bias; |
85 | |
86 | int tile_base_index = image * jcp.itiles * jcp.jtiles; |
87 | int tile_block_ur = tile_base_index % jcp.tile_block_ur; |
88 | int nb_tile_block_ur |
89 | = (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; |
90 | int tile_block |
91 | = (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; |
92 | |
93 | for (int tj = 0; tj < jcp.jtiles; tj++) { |
94 | for (int ti = 0; ti < jcp.itiles; ti++) { |
95 | |
96 | p.tile_block_ur = tile_block_ur; |
97 | p.nb_tile_block_ur = nb_tile_block_ur; |
98 | p.tile_block = tile_block; |
99 | p.tj = tj; |
100 | p.ti = ti; |
101 | |
102 | kernel_->output_transform_data_ker(&p); |
103 | |
104 | tile_block_ur++; |
105 | if (tile_block_ur >= jcp.tile_block_ur) { |
106 | tile_block_ur = 0; |
107 | nb_tile_block_ur++; |
108 | } |
109 | if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { |
110 | nb_tile_block_ur = 0; |
111 | tile_block++; |
112 | } |
113 | } |
114 | } |
115 | } |
116 | |
117 | template <bool is_fwd> |
118 | void _jit_avx512_core_f32_wino_conv_4x3_t< |
119 | is_fwd>::output_transform_tileblock_data(int tile_block, |
120 | const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, |
121 | float *toutp, float *outp, float *bias) const { |
122 | |
123 | float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; |
124 | float Ow[alpha][alpha][simd_w]; |
125 | float O[tile_size][tile_size][simd_w]; |
126 | float T[tile_size][alpha][simd_w]; |
127 | |
128 | auto p = jit_wino_transform_call_s(); |
129 | p.src = toutp; |
130 | p.dst = outp; |
131 | p.G = G; |
132 | p.M = O; |
133 | p.Mw = Ow; |
134 | p.T = T; |
135 | p.bias = bias; |
136 | |
137 | int outw = is_fwd ? jcp.ow : jcp.iw; |
138 | int outh = is_fwd ? jcp.oh : jcp.ih; |
139 | |
140 | int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; |
141 | |
142 | for (int nb_tile_block_ur = 0; nb_tile_block_ur < jcp.nb_tile_block_ur; |
143 | nb_tile_block_ur++) { |
144 | |
145 | for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; |
146 | tile_block_ur++) { |
147 | int img = tile_index / (jcp.jtiles * jcp.itiles); |
148 | int ti = tile_index % jcp.itiles; |
149 | int tj = (tile_index / jcp.itiles) % jcp.jtiles; |
150 | |
151 | p.tile_block_ur = tile_block_ur; |
152 | p.nb_tile_block_ur = nb_tile_block_ur; |
153 | p.tile_block = tile_block; |
154 | p.tj = tj; |
155 | p.ti = ti; |
156 | p.dst = outp |
157 | + img * (jcp.dimM / jcp.dimM_simd_block) * outh * outw |
158 | * jcp.dimM_simd_block; |
159 | |
160 | kernel_->output_transform_data_ker(&p); |
161 | |
162 | tile_index++; |
163 | } |
164 | } |
165 | } |
166 | |
167 | template <bool is_fwd> |
168 | void _jit_avx512_core_f32_wino_conv_4x3_t<is_fwd>::input_transform_data( |
169 | int image, const jit_conv_winograd_conf_t &jcp, float *inp, |
170 | float *tinp) const { |
171 | float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, 0.625f, -0.625f, |
172 | 1.5f, -1.5f, -2.640625f}; |
173 | |
174 | float Iw[alpha][alpha][simd_w]; |
175 | float I[alpha][alpha][simd_w]; |
176 | float T[alpha][alpha][simd_w]; |
177 | |
178 | auto p = jit_wino_transform_call_s(); |
179 | |
180 | p.src = inp; |
181 | p.dst = tinp; |
182 | p.G = G; |
183 | p.M = I; |
184 | p.Mw = Iw; |
185 | p.T = T; |
186 | |
187 | int tile_base_index = image * jcp.itiles * jcp.jtiles; |
188 | int tile_block_ur = tile_base_index % jcp.tile_block_ur; |
189 | int nb_tile_block_ur |
190 | = (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; |
191 | int tile_block |
192 | = (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; |
193 | |
194 | for (int tj = 0; tj < jcp.jtiles; tj++) { |
195 | for (int ti = 0; ti < jcp.itiles; ti++) { |
196 | |
197 | p.tile_block_ur = tile_block_ur; |
198 | p.nb_tile_block_ur = nb_tile_block_ur; |
199 | p.tile_block = tile_block; |
200 | p.tj = tj; |
201 | p.ti = ti; |
202 | |
203 | kernel_->input_transform_data_ker(&p); |
204 | |
205 | tile_block_ur++; |
206 | if (tile_block_ur >= jcp.tile_block_ur) { |
207 | tile_block_ur = 0; |
208 | nb_tile_block_ur++; |
209 | } |
210 | if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { |
211 | nb_tile_block_ur = 0; |
212 | tile_block++; |
213 | } |
214 | } |
215 | } |
216 | } |
217 | |
218 | template <bool is_fwd> |
219 | void _jit_avx512_core_f32_wino_conv_4x3_t< |
220 | is_fwd>::input_transform_tileblock_data(int tile_block, |
221 | const jit_conv_winograd_conf_t &jcp, float *inp, float *tinp) const { |
222 | float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, 0.625f, -0.625f, |
223 | 1.5f, -1.5f, -2.640625f}; |
224 | float Iw[alpha][alpha][simd_w]; |
225 | float I[alpha][alpha][simd_w]; |
226 | float T[alpha][alpha][simd_w]; |
227 | |
228 | const int inph = is_fwd ? jcp.ih : jcp.oh; |
229 | const int inpw = is_fwd ? jcp.iw : jcp.ow; |
230 | |
231 | array_offset_calculator<float, 5> input( |
232 | inp, jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w); |
233 | array_offset_calculator<float, 7> output(tinp, alpha, alpha, jcp.dimN_block, |
234 | jcp.dimK_nb_block, jcp.dimK_block, jcp.dimN_reg_block, |
235 | jcp.dimK_reg_block); |
236 | |
237 | auto p = jit_wino_transform_call_s(); |
238 | |
239 | p.dst = tinp; |
240 | p.G = G; |
241 | p.M = I; |
242 | p.Mw = Iw; |
243 | p.T = T; |
244 | |
245 | int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; |
246 | |
247 | for (int nb_tile_block_ur = 0; nb_tile_block_ur < jcp.nb_tile_block_ur; |
248 | nb_tile_block_ur++) { |
249 | |
250 | for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; |
251 | tile_block_ur++) { |
252 | |
253 | int img = tile_index / (jcp.jtiles * jcp.itiles); |
254 | int ti = tile_index % jcp.itiles; |
255 | int tj = (tile_index / jcp.itiles) % jcp.jtiles; |
256 | float *pinp_b = &(input(img, 0, 0, 0, 0)); |
257 | |
258 | p.src = pinp_b; |
259 | p.tile_block_ur = tile_block_ur; |
260 | p.nb_tile_block_ur = nb_tile_block_ur; |
261 | p.tj = tj; |
262 | p.ti = ti; |
263 | |
264 | kernel_->input_transform_data_ker(&p); |
265 | |
266 | tile_index++; |
267 | } |
268 | } |
269 | } |
270 | |
template <bool is_fwd>
void _jit_avx512_core_f32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
        float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
        const memory_tracking::grantor_t &scratchpad) const {
    // W_S_G_D schedule: each stage -- Weight transform, Source transform,
    // Gemm, Destination transform -- is run over the whole batch before the
    // next stage starts, each stage parallelized independently.
    const auto &jcp = kernel_->jcp;
    const auto &p_ops = attr_->post_ops_;

    // Backward-by-data swaps the src/dst roles of the spatial dims.
    const int inph = is_fwd ? jcp.ih : jcp.oh;
    const int inpw = is_fwd ? jcp.iw : jcp.ow;
    const int outh = is_fwd ? jcp.oh : jcp.ih;
    const int outw = is_fwd ? jcp.ow : jcp.iw;

    /* Notation:
       FWD: dimM:oc, dimN:ntiles, dimK:ic,
       BWD: dimM:ic, dimN:ntiles, dimK:oc,
       FWD/BWD: V: src/diff_dst transform, U:weight transform,
       M:dst/diff_src transform */
    array_offset_calculator<float, 5> input(inp_ptr, jcp.mb,
            jcp.dimK / jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
    array_offset_calculator<float, 5> output(out_ptr, jcp.mb,
            jcp.dimM / jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
    array_offset_calculator<float, 6> weights(wei_ptr,
            jcp.oc / jcp.oc_simd_block, jcp.ic / jcp.ic_simd_block, jcp.kh,
            jcp.kw, jcp.ic_simd_block, jcp.oc_simd_block);
    array_offset_calculator<float, 2> bias(
            bias_ptr, jcp.dimM / jcp.dimM_simd_block, jcp.dimM_simd_block);

    // The M/V scratchpad keys are swapped between fwd and bwd passes so the
    // same scratchpad allocation serves both directions.
    array_offset_calculator<float, 8> M(is_fwd
                    ? scratchpad.template get<float>(key_wino_M)
                    : scratchpad.template get<float>(key_wino_V),
            jcp.dimN_nb_block, jcp.dimM_nb_block, alpha, alpha, jcp.dimN_block,
            jcp.dimM_block * jcp.dimM_reg_block, jcp.dimN_reg_block,
            jcp.dimM_simd_block);

    // For forward inference the user's weights are already in the Winograd
    // domain; otherwise they are transformed into the U scratch buffer below.
    auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
            ? wei_ptr
            : scratchpad.template get<float>(key_wino_U);

    array_offset_calculator<float, 8> U(wino_wei, jcp.dimM_nb_block, alpha,
            alpha, jcp.dimK_nb_block, jcp.dimM_block * jcp.dimM_reg_block,
            jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_simd_block);
    array_offset_calculator<float, 8> V(is_fwd
                    ? scratchpad.template get<float>(key_wino_V)
                    : scratchpad.template get<float>(key_wino_M),
            jcp.dimN_nb_block, alpha, alpha, jcp.dimN_block, jcp.dimK_nb_block,
            jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);

    // When oc was padded up to a multiple of the simd width, build a
    // zero-padded copy of the last bias slice so the kernel never reads
    // past the user's bias buffer.
    const bool wants_padded_bias
            = jcp.with_bias && jcp.oc_without_padding != jcp.oc;
    float last_slice_bias[simd_w] = {0};
    if (wants_padded_bias) {
        for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
            last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
    }

    // Stage: transform src (fwd) / diff_dst (bwd) tiles into V.
    parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block,
            [&](dim_t img, dim_t K_blk1, dim_t K_blk2) {
                input_transform_data(img, jcp,
                        &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0,
                                0)),
                        &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
            });

    // Stage: transform weights into U (skipped for forward inference, where
    // wino_wei already points at pre-transformed weights).
    if (jcp.prop_kind != prop_kind::forward_inference) {
        parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block),
                (jcp.ic_block * jcp.ic_reg_block),
                [&](dim_t ofm1, dim_t ifm1, dim_t ofm2, dim_t ifm2) {
                    // For bwd the U indexing swaps the oc/ic roles.
                    float *U_base_ptr = is_fwd
                            ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
                            : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
                    weight_transform_data(jcp,
                            &(weights(ofm1 * jcp.oc_block * jcp.oc_reg_block
                                            + ofm2,
                                    ifm1 * jcp.ic_block * jcp.ic_reg_block
                                            + ifm2,
                                    0, 0, 0, 0)),
                            U_base_ptr);
                });
    }

    // Stage: Winograd-domain GEMMs, M = U * V per (oj, oi) point. The
    // trailing K_blk1 argument presumably lets the kernel distinguish the
    // first reduction block (initialize) from later ones (accumulate) --
    // confirm against gemm_loop_ker.
    parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block,
            [&](dim_t N_blk1, dim_t oj, dim_t oi, dim_t M_blk1) {
                for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++)
                    for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++)
                        kernel_->gemm_loop_ker((float *)&(M(N_blk1, M_blk1, oj,
                                                       oi, N_blk2, 0, 0, 0)),
                                (const float *)&(
                                        U(M_blk1, oj, oi, K_blk1, 0, 0, 0, 0)),
                                (const float *)&(V(N_blk1, oj, oi, N_blk2,
                                        K_blk1, 0, 0, 0)),
                                K_blk1);
            });

    // Stage: inverse-transform M into the user's dst/diff_src layout,
    // applying bias (and post-ops inside the kernel).
    parallel_nd(jcp.mb, jcp.dimM_nb_block,
            (jcp.dimM_block * jcp.dimM_reg_block),
            [&](dim_t img, dim_t M_blk1, dim_t M_blk2) {
                const int M_blk
                        = M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;

                // Last simd slice uses the zero-padded bias copy if needed.
                float *bias_ptr = wants_padded_bias
                                && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
                        ? last_slice_bias
                        : jcp.with_bias ? &bias(M_blk, 0) : nullptr;
                output_transform_data(img, jcp, p_ops,
                        &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
                        &(output(img, M_blk, 0, 0, 0)), bias_ptr);
            });
}
379 | |
template <bool is_fwd>
void _jit_avx512_core_f32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(
        float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
        const memory_tracking::grantor_t &scratchpad) const {
    // W_SGD schedule: weights are transformed once up front, then each
    // thread runs the fused Source-transform / Gemm / Destination-transform
    // pipeline over whole tile blocks using per-thread M/V buffers.
    const auto &jcp = kernel_->jcp;
    const auto &p_ops = attr_->post_ops_;

    // Backward-by-data swaps the src/dst roles of the spatial dims.
    const int inph = is_fwd ? jcp.ih : jcp.oh;
    const int inpw = is_fwd ? jcp.iw : jcp.ow;
    const int outh = is_fwd ? jcp.oh : jcp.ih;
    const int outw = is_fwd ? jcp.ow : jcp.iw;

    array_offset_calculator<float, 5> input(inp_ptr, jcp.mb,
            jcp.dimK / jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
    array_offset_calculator<float, 5> output(out_ptr, jcp.mb,
            jcp.dimM / jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
    array_offset_calculator<float, 6> weights(wei_ptr,
            jcp.oc / jcp.oc_simd_block, jcp.ic / jcp.ic_simd_block, jcp.kh,
            jcp.kw, jcp.ic_simd_block, jcp.oc_simd_block);
    array_offset_calculator<float, 2> bias(
            bias_ptr, jcp.oc / jcp.oc_simd_block, jcp.oc_simd_block);

    // For forward inference the user's weights are already in the Winograd
    // domain; otherwise they are transformed into the U scratch buffer below.
    auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
            ? wei_ptr
            : scratchpad.template get<float>(key_wino_U);

    array_offset_calculator<float, 8> U(wino_wei, jcp.dimM_nb_block, alpha,
            alpha, jcp.dimK_nb_block, jcp.dimM_block * jcp.dimM_reg_block,
            jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_simd_block);

    // M and V are per-thread here: the leading dimension (declared 0, i.e.
    // unbounded) is indexed by the thread id inside parallel_nd_ext. The
    // scratchpad keys are swapped between fwd and bwd passes.
    array_offset_calculator<float, 8> M(is_fwd
                    ? scratchpad.template get<float>(key_wino_M)
                    : scratchpad.template get<float>(key_wino_V),
            0, jcp.dimM_nb_block, alpha, alpha, jcp.dimN_block,
            jcp.dimM_block * jcp.dimM_reg_block, jcp.dimN_reg_block,
            jcp.dimM_simd_block);
    array_offset_calculator<float, 8> V(is_fwd
                    ? scratchpad.template get<float>(key_wino_V)
                    : scratchpad.template get<float>(key_wino_M),
            0, alpha, alpha, jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
            jcp.dimN_reg_block, jcp.dimK_reg_block);

    // When oc was padded up to a multiple of the simd width, build a
    // zero-padded copy of the last bias slice so the kernel never reads
    // past the user's bias buffer.
    const bool wants_padded_bias
            = jcp.with_bias && jcp.oc_without_padding != jcp.oc;
    float last_slice_bias[simd_w] = {0};
    if (wants_padded_bias) {
        for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
            last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
    }

    // Up-front weight transform (skipped for forward inference, where
    // wino_wei already points at pre-transformed weights).
    if (jcp.prop_kind != prop_kind::forward_inference) {

        parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block),
                (jcp.ic_block * jcp.ic_reg_block),
                [&](dim_t ofm1, dim_t ifm1, dim_t ofm2, dim_t ifm2) {
                    // For bwd the U indexing swaps the oc/ic roles.
                    float *U_base_ptr = is_fwd
                            ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
                            : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
                    weight_transform_data(jcp,
                            &(weights(ofm1 * jcp.oc_block * jcp.oc_reg_block
                                            + ofm2,
                                    ifm1 * jcp.ic_block * jcp.ic_reg_block
                                            + ifm2,
                                    0, 0, 0, 0)),
                            U_base_ptr);
                });
    }

    // Fused per-tile-block pipeline; tile blocks are distributed over
    // threads, each using its own slice of M and V.
    parallel_nd_ext(jcp.nthr, jcp.tile_block,
            [&](int ithr, int nthr, dim_t tile_block) {
                assert(nthr <= jcp.nthr);
                MAYBE_UNUSED(nthr);

                // Source transform for all K blocks of this tile block.
                for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
                    for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {

                        input_transform_tileblock_data(tile_block, jcp,
                                &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0,
                                        0, 0)),
                                &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
                    }
                }

                // Winograd-domain GEMMs for every (oj, oi) point.
                for (int oj = 0; oj < alpha; oj++) {
                    for (int oi = 0; oi < alpha; oi++) {
                        for_(int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block;
                                M_blk1++)
                        for_(int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block;
                                K_blk1++)
                        for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++)
                            kernel_->gemm_loop_ker(
                                    (float *)&(M(ithr, M_blk1, oj, oi, N_blk, 0,
                                            0, 0)),
                                    (const float *)&(U(M_blk1, oj, oi, K_blk1,
                                            0, 0, 0, 0)),
                                    (const float *)&(V(ithr, oj, oi, N_blk,
                                            K_blk1, 0, 0, 0)),
                                    K_blk1);
                    }
                }

                // Destination transform for this tile block, with bias.
                for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
                    for (int M_blk2 = 0;
                            M_blk2 < jcp.dimM_block * jcp.dimM_reg_block;
                            M_blk2++) {
                        const int M_blk
                                = M_blk1 * jcp.dimM_block * jcp.dimM_reg_block
                                + M_blk2;

                        // Last simd slice uses the zero-padded bias copy.
                        float *bias_ptr = wants_padded_bias
                                        && M_blk
                                                == jcp.dimM / jcp.dimM_simd_block
                                                        - 1
                                ? last_slice_bias
                                : jcp.with_bias ? &bias(M_blk, 0) : nullptr;

                        output_transform_tileblock_data(tile_block, jcp, p_ops,
                                &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
                                &(output(0, M_blk, 0, 0, 0)), bias_ptr);
                    }
                }
            });
}
503 | |
// Explicitly instantiate the forward (is_fwd = true) and backward-by-data
// (is_fwd = false) flavors of the driver.
template struct _jit_avx512_core_f32_wino_conv_4x3_t<true>;
template struct _jit_avx512_core_f32_wino_conv_4x3_t<false>;
506 | |
507 | namespace { |
508 | |
509 | void subarray_sum(size_t num_arrs, float *output, size_t nelems, |
510 | float *input_ptrs[], size_t input_starts[], size_t input_ends[]) { |
511 | using namespace nstl; |
512 | const size_t block_size = 16 * 1024 / sizeof(float); |
513 | const size_t blocks_number = nelems / block_size; |
514 | const size_t tail = nelems % block_size; |
515 | |
516 | PRAGMA_OMP(parallel) |
517 | { |
518 | const int ithr = OMP_GET_THREAD_NUM(); |
519 | const int nthr = OMP_GET_NUM_THREADS(); |
520 | size_t start {0}, end {0}; |
521 | balance211(blocks_number, nthr, ithr, start, end); |
522 | |
523 | for (size_t nb = start; nb < end; ++nb) { |
524 | size_t start_e = nb * block_size; |
525 | size_t end_e = start_e + block_size; |
526 | size_t input_start = max(start_e, min(input_starts[0], end_e)); |
527 | size_t input_end = max(start_e, min(input_ends[0], end_e)); |
528 | |
529 | PRAGMA_OMP_SIMD() |
530 | for (size_t e = start_e; e < input_start; e++) { |
531 | output[e] = 0.f; |
532 | } |
533 | |
534 | PRAGMA_OMP_SIMD() |
535 | for (size_t e = input_start; e < input_end; e++) { |
536 | output[e] = input_ptrs[0][e]; |
537 | } |
538 | |
539 | PRAGMA_OMP_SIMD() |
540 | for (size_t e = input_end; e < end_e; e++) { |
541 | output[e] = 0.f; |
542 | } |
543 | |
544 | for (size_t a = 1; a < num_arrs; a++) { |
545 | input_start = max(start_e, input_starts[a]); |
546 | input_end = min(input_ends[a], end_e); |
547 | |
548 | PRAGMA_OMP_SIMD() |
549 | for (size_t e = input_start; e < input_end; e++) { |
550 | output[e] += input_ptrs[a][e]; |
551 | } |
552 | } |
553 | } |
554 | |
555 | if (tail != 0 && ithr == nthr - 1) { |
556 | size_t start_e = nelems - tail; |
557 | size_t end_e = nelems; |
558 | size_t input_start = max(start_e, min(input_starts[0], end_e)); |
559 | size_t input_end = max(start_e, min(input_ends[0], end_e)); |
560 | |
561 | PRAGMA_OMP_SIMD() |
562 | for (size_t e = start_e; e < input_start; e++) { |
563 | output[e] = 0.f; |
564 | } |
565 | |
566 | PRAGMA_OMP_SIMD() |
567 | for (size_t e = input_start; e < input_end; e++) { |
568 | output[e] = input_ptrs[0][e]; |
569 | } |
570 | |
571 | PRAGMA_OMP_SIMD() |
572 | for (size_t e = input_end; e < end_e; e++) { |
573 | output[e] = 0.f; |
574 | } |
575 | |
576 | for (size_t a = 1; a < num_arrs; a++) { |
577 | input_start = max(start_e, input_starts[a]); |
578 | input_end = min(input_ends[a], end_e); |
579 | |
580 | PRAGMA_OMP_SIMD() |
581 | for (size_t e = input_start; e < input_end; e++) { |
582 | output[e] += input_ptrs[a][e]; |
583 | } |
584 | } |
585 | } |
586 | } |
587 | } |
588 | |
// Upper bound on OMP threads; sizes the stack-allocated per-thread pointer
// arrays used by the reduction paths below.
const int max_threads_number = 1024;
590 | |
// Sums arrays 1..num_arrs-1 into 'output'. With reduce_to_first == true the
// caller passes output == input_ptrs[0] (in-place accumulation); otherwise
// 'output' is first seeded from input_ptrs[0].
592 | void array_sum(size_t num_arrs, float *output, size_t nelems, |
593 | float *input_ptrs[], bool reduce_to_first = true) { |
594 | const size_t block_size = 16 * 1024 / sizeof(float); |
595 | const size_t blocks_number = nelems / block_size; |
596 | const size_t tail = nelems % block_size; |
597 | |
598 | PRAGMA_OMP(parallel) |
599 | { |
600 | const size_t ithr = OMP_GET_THREAD_NUM(); |
601 | const size_t nthr = OMP_GET_NUM_THREADS(); |
602 | size_t start {0}, end {0}; |
603 | balance211(blocks_number, nthr, ithr, start, end); |
604 | |
605 | for (size_t nb = start; nb < end; ++nb) { |
606 | size_t start_e = nb * block_size; |
607 | size_t end_e = start_e + block_size; |
608 | if (!reduce_to_first) { |
609 | PRAGMA_OMP_SIMD() |
610 | for (size_t e = start_e; e < end_e; e++) { |
611 | output[e] = input_ptrs[0][e]; |
612 | } |
613 | } |
614 | for (size_t a = 1; a < num_arrs; a++) { |
615 | PRAGMA_OMP_SIMD() |
616 | for (size_t e = start_e; e < end_e; e++) { |
617 | output[e] += input_ptrs[a][e]; |
618 | } |
619 | } |
620 | } |
621 | |
622 | if (tail != 0 && ithr == nthr - 1) { |
623 | size_t start_e = nelems - tail; |
624 | size_t end_e = nelems; |
625 | if (!reduce_to_first) { |
626 | PRAGMA_OMP_SIMD() |
627 | for (size_t e = start_e; e < end_e; e++) { |
628 | output[e] = input_ptrs[0][e]; |
629 | } |
630 | } |
631 | for (size_t a = 1; a < num_arrs; a++) { |
632 | PRAGMA_OMP_SIMD() |
633 | for (size_t e = start_e; e < end_e; e++) { |
634 | output[e] += input_ptrs[a][e]; |
635 | } |
636 | } |
637 | } |
638 | } |
639 | } |
640 | } // namespace |
641 | |
void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t::
        _execute_backward_weights_SDGtWo(const float *ptr_src,
                const float *ptr_diff_dst, float *ptr_diff_weights,
                float *ptr_diff_bias,
                const memory_tracking::grantor_t &scratchpad) const {
    // SDGtWo schedule: each thread fuses src transform, diff_dst transform,
    // Winograd-domain GEMM, and weight inverse-transform per tile block,
    // accumulating into per-thread diff_weights/diff_bias buffers that are
    // reduced across threads at the end.
    const auto &jcp = kernel_->jcp;
    const int nthreads = jcp.nthr;

    array_offset_calculator<float, 5> src(
            (float *)ptr_src, jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
    array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst, jcp.mb,
            jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
    array_offset_calculator<float, 6> diff_weights(ptr_diff_weights,
            jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);

    // Per-thread GEMM accumulator in the Winograd domain (leading dim 0 is
    // indexed by the thread id).
    array_offset_calculator<float, 8> Us(scratchpad.get<float>(key_wino_U), 0,
            alpha, alpha, jcp.oc_block, jcp.ic_block, jcp.ic_simd_block,
            jcp.oc_reg_block, jcp.oc_simd_block);

    // Size of the per-thread Us region; the per-thread diff_weights
    // accumulators live in the same scratchpad buffer right after it.
    const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc * jcp.ic
            / jcp.nb_ic;
    array_offset_calculator<float, 7> diff_weights_prv(
            scratchpad.get<float>(key_wino_U) + U_sz, 0, jcp.oc / simd_w,
            jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);

    // Per-thread transformed diff_dst (M) and src (V) buffers.
    array_offset_calculator<float, 8> M(scratchpad.get<float>(key_wino_M), 0,
            alpha, alpha, jcp.oc_block, jcp.nb_tile_block_ur, jcp.tile_block_ur,
            jcp.oc_reg_block, jcp.oc_simd_block);

    array_offset_calculator<float, 7> V(scratchpad.get<float>(key_wino_V), 0,
            alpha, alpha, jcp.ic_block, jcp.nb_tile_block_ur, jcp.tile_block_ur,
            jcp.ic_simd_block);

    // Per-thread diff_bias accumulators, reduced at the end.
    array_offset_calculator<float, 2> diff_bias_prv(
            scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);

    auto trans_ker_p = jit_wino_transform_call_s();
    float I[alpha][alpha][simd_w];
    float T[alpha][alpha][simd_w];
    // Winograd F(4x4, 3x3) coefficient sets for the src (I), diff_dst-as-
    // weight (W), and diff_weights output (O) transforms.
    float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, 0.625f,
            -0.625f, 1.5f, -1.5f, -2.640625f};
    float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
            0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
            0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
    float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};

    // trans_ker_p, I, T are firstprivate: each thread gets its own copy of
    // the call-args struct and scratch tiles.
    PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T))
    {
        // Zero the per-thread bias accumulators before any accumulation.
        if (jcp.with_bias) {
            parallel_nd_in_omp(
                    nthreads, jcp.oc / simd_w, [&](dim_t ithr, dim_t ofm) {
                        float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w));
                        PRAGMA_OMP_SIMD()
                        for (int v = 0; v < simd_w; v++) {
                            pdbias[v] = 0.0f;
                        }
                    });
        }

        int ithr = OMP_GET_THREAD_NUM();
        for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
            // first_tblk == 0 marks this thread's first tile block for the
            // current ifm1, selecting store vs. accumulate below.
            int first_tblk = 0;
            PRAGMA_OMP(for)
            for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
                int tile_index
                        = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur;
                int img = tile_index / (jcp.itiles * jcp.jtiles);
                trans_ker_p.ti = tile_index % jcp.itiles;
                trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles;
                trans_ker_p.M = I;
                trans_ker_p.T = T;
                trans_ker_p.G = G_I_3x3_4x4;
                // Transform this tile block's src channels into V.
                for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
                    int ifm = ifm1 * jcp.ic_block + ifm2;
                    trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
                    trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0));
                    kernel_->src_transform(&trans_ker_p);
                }

                for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
                    // Transform diff_dst channels into M; fold the bias
                    // reduction into the transform on the first ic block
                    // only, so bias is accumulated exactly once.
                    trans_ker_p.G = G_W_3x3_4x4;
                    for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
                        int ofm = (ofm1 * jcp.oc_block + ofm2)
                                * jcp.oc_reg_block;
                        trans_ker_p.src
                                = (float *)&(diff_dst(img, ofm, 0, 0, 0));
                        trans_ker_p.dst
                                = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0));
                        if (jcp.with_bias && ifm1 == 0) {
                            trans_ker_p.bias = (float *)&(
                                    diff_bias_prv(ithr, ofm * simd_w));
                            kernel_->diff_dst_transform_wbias(&trans_ker_p);
                        } else {
                            kernel_->diff_dst_transform(&trans_ker_p);
                        }
                    }

                    // Winograd-domain GEMM per (oj, oi): Us = M * V.
                    for (int oj = 0; oj < alpha; ++oj) {
                        for (int oi = 0; oi < alpha; ++oi) {
                            kernel_->gemm_loop_ker_first_iter(
                                    &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)),
                                    &(M(ithr, oj, oi, 0, 0, 0, 0, 0)),
                                    &(V(ithr, oj, oi, 0, 0, 0, 0)));
                        }
                    }
                    // Inverse-transform Us into this thread's private
                    // diff_weights accumulator: store on the thread's first
                    // tile block, accumulate afterwards.
                    trans_ker_p.G = G_O_3x3_4x4;
                    for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
                        for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) {
                            int ofm = (ofm1 * jcp.oc_block + ofm2)
                                            * jcp.oc_reg_block
                                    + ofm3;
                            for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
                                int ifm = ifm1 * jcp.ic_block + ifm2;
                                trans_ker_p.src = (float *)&(
                                        Us(ithr, 0, 0, ofm2, ifm2, 0, ofm3, 0));
                                trans_ker_p.dst = (float *)&(diff_weights_prv(
                                        ithr, ofm, ifm, 0, 0, 0, 0));
                                if (first_tblk == 0) {
                                    kernel_->diff_weights_transform(
                                            &trans_ker_p);
                                } else {
                                    kernel_->diff_weights_transform_accum(
                                            &trans_ker_p);
                                }
                            }
                        }
                    }
                }
                ++first_tblk;
            }
        }
    }

    // Reduce diff-weights
    {
        // Sum the per-thread diff_weights accumulators into the user's
        // buffer (reduce_to_first = false: output is seeded, not aliased).
        float *output = ptr_diff_weights;
        float *input_base = scratchpad.get<float>(key_wino_U) + U_sz;
        int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
        float *input_ptrs[max_threads_number];
        for (int i = 0; i < nthreads; ++i) {
            input_ptrs[i] = input_base + nelems * i;
        }
        array_sum(nthreads, output, nelems, input_ptrs, false);

        // Likewise for diff_bias; only the unpadded oc range is written out.
        if (jcp.with_bias) {
            output = ptr_diff_bias;
            input_base = scratchpad.get<float>(key_conv_bia_reduction);
            for (int i = 0; i < nthreads; ++i) {
                input_ptrs[i] = input_base + jcp.oc * i;
            }
            array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs,
                    false);
        }
    }
}
797 | |
798 | void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t:: |
799 | _execute_backward_weights_S_D_Giot_W(const float *ptr_src, |
800 | const float *ptr_diff_dst, float *ptr_diff_weights, |
801 | float *ptr_diff_bias, |
802 | const memory_tracking::grantor_t &scratchpad) const { |
803 | const auto &jcp = kernel_->jcp; |
804 | const int nthreads = jcp.nthr; |
805 | |
806 | array_offset_calculator<float, 5> src( |
807 | (float *)ptr_src, jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w); |
808 | array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst, jcp.mb, |
809 | jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w); |
    // NOTE(review): this span is the tail of the backward-weights execute
    // routine; its signature and the src/diff_dst accessors it uses are
    // defined above this chunk.

    // Accessors over user memory: diff_weights is blocked as
    // (O, I, kh, kw, ic_simd, oc_simd); diff_bias is a flat oc vector.
    array_offset_calculator<float, 6> diff_weights((float *)ptr_diff_weights,
            jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
    array_offset_calculator<float, 1> diff_bias((float *)ptr_diff_bias, jcp.oc);

    // U: master copy of the transformed-weight accumulator at the base of
    // the wino_U scratchpad (filled by the cross-thread reduction below).
    array_offset_calculator<float, 9> U(scratchpad.get<float>(key_wino_U),
            jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.oc_block, jcp.ic_block,
            jcp.ic_simd_block, jcp.oc_reg_block, jcp.oc_simd_block);

    // Us: per-thread private copies of U, laid out back-to-back starting
    // U_size floats past the master copy; Us(ithr, ...) selects thread
    // ithr's copy (the leading extent 0 does not affect offset strides of
    // the trailing dimensions).
    const int U_size = jcp.oc * jcp.ic * alpha * alpha;
    array_offset_calculator<float, 10> Us(
            scratchpad.get<float>(key_wino_U) + U_size, 0, jcp.nb_ic, jcp.nb_oc,
            alpha, alpha, jcp.oc_block, jcp.ic_block, jcp.ic_simd_block,
            jcp.oc_reg_block, jcp.oc_simd_block);

    // M: Winograd-transformed diff_dst tiles; V: transformed src tiles.
    array_offset_calculator<float, 9> M(scratchpad.get<float>(key_wino_M),
            jcp.nb_oc, jcp.tile_block, alpha, alpha, jcp.oc_block,
            jcp.nb_tile_block_ur, jcp.tile_block_ur, jcp.oc_reg_block,
            jcp.oc_simd_block);

    array_offset_calculator<float, 8> V(scratchpad.get<float>(key_wino_V),
            jcp.nb_ic, jcp.tile_block, alpha, alpha, jcp.ic_block,
            jcp.nb_tile_block_ur, jcp.tile_block_ur, jcp.ic_simd_block);

    // Per-thread partial bias sums, reduced into diff_bias at the end.
    array_offset_calculator<float, 2> diff_bias_prv(
            scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);

    // Per-thread [start, end) float ranges (offsets within a thread's Us
    // copy) actually written during the GEMM stage, so subarray_sum can
    // restrict the reduction to the data each thread touched.
    size_t input_starts[max_threads_number] = {0};
    size_t input_ends[max_threads_number] = {0};
    size_t first_tblk = 0;

    auto trans_ker_p = jit_wino_transform_call_s();
    // Coefficient tables for the F(4x4, 3x3) Winograd transforms; each is
    // installed into trans_ker_p.G right before the matching JIT kernel is
    // invoked: G_I -> src transform, G_W -> diff_dst transform,
    // G_O -> final diff_weights (output) transform.
    float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, 0.625f,
            -0.625f, 1.5f, -1.5f, -2.640625f};
    float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
            0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
            0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
    float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
    // Per-thread transform work buffers (firstprivate below).
    float I[alpha][alpha][simd_w];
    float T[alpha][alpha][simd_w];

    PRAGMA_OMP(parallel num_threads(nthreads)
            firstprivate(first_tblk, trans_ker_p, I, T))
    {
        // Zero the per-thread bias scratch before any accumulation.
        if (jcp.with_bias) {
            parallel_nd_in_omp(nthreads, jcp.oc, [&](dim_t ithr, dim_t ofm) {
                diff_bias_prv(ithr, ofm) = 0.0f;
            });
        }

        // Stage 1: transform src tiles into V.
        trans_ker_p.G = G_I_3x3_4x4;
        trans_ker_p.M = I;
        trans_ker_p.T = T;

        parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb,
                [&](dim_t ifm1, dim_t ifm2, dim_t img) {
                    size_t ifm = ifm1 * jcp.ic_block + ifm2;
                    // Decompose the image's first tile index into
                    // (tile_block, nb_tile_block_ur, tile_block_ur) coords.
                    size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
                    size_t tblk3 = tile_base_index % jcp.tile_block_ur;
                    size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
                            % jcp.nb_tile_block_ur;
                    size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
                            / jcp.nb_tile_block_ur;
                    trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
                    trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
                    trans_ker_p.dst
                            = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0));
                    kernel_->src_transform(&trans_ker_p);
                });

        // Stage 2: transform diff_dst tiles into M, optionally folding the
        // per-thread bias partials into diff_bias_prv along the way.
        int ithr = OMP_GET_THREAD_NUM();
        trans_ker_p.G = G_W_3x3_4x4;
        parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb,
                [&](dim_t ofm1, dim_t ofm2, dim_t img) {
                    int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
                    size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
                    size_t tblk3 = tile_base_index % jcp.tile_block_ur;
                    size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
                            % jcp.nb_tile_block_ur;
                    size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
                            / jcp.nb_tile_block_ur;
                    trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
                    trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
                    trans_ker_p.dst = (float *)&(
                            M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0));
                    if (jcp.with_bias) {
                        trans_ker_p.bias
                                = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
                        kernel_->diff_dst_transform_wbias(&trans_ker_p);
                    } else {
                        kernel_->diff_dst_transform(&trans_ker_p);
                    }
                });

        // All of M and V must be written before any thread starts the GEMM
        // stage that consumes them.
        PRAGMA_OMP(barrier)

        // Stage 3: per-thread GEMMs combine M and V into this thread's Us
        // copy, while recording the written range for the reduction below.
        parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block,
                [&](dim_t ifm1, dim_t ofm1, dim_t oj, dim_t oi, dim_t tblk1) {
                    if (first_tblk == 0) {
                        // Very first work item of this thread: record where
                        // its writes into Us begin (and one slice past it).
                        input_starts[ithr] = (float *)&(Us(ithr, ifm1, ofm1, oj,
                                                     oi, 0, 0, 0, 0, 0))
                                - (float *)&(
                                        Us(ithr, 0, 0, 0, 0, 0, 0, 0, 0, 0));
                        input_ends[ithr] = input_starts[ithr]
                                + jcp.oc_block * jcp.ic_block
                                        * jcp.ic_simd_block * jcp.oc_reg_block
                                        * jcp.oc_simd_block;
                    } else if (tblk1 == 0) {
                        // New (ifm1, ofm1, oj, oi) slice: extend the written
                        // range by one slice worth of floats.
                        input_ends[ithr] += jcp.oc_block * jcp.ic_block
                                * jcp.ic_simd_block * jcp.oc_reg_block
                                * jcp.oc_simd_block;
                    }

                    if (first_tblk == 0 || tblk1 == 0) {
                        // First tile block for this Us slice; the
                        // "_first_iter" variant presumably initializes the
                        // slice instead of accumulating into it -- per its
                        // name; actual behavior lives in the JIT generator.
                        kernel_->gemm_loop_ker_first_iter(
                                &(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, 0, 0)),
                                &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
                                &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
                    } else {
                        kernel_->gemm_loop_ker(
                                &(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, 0, 0)),
                                &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
                                &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
                    }
                    ++first_tblk;
                });
    }

    // Reduce diff-weights
    // Sum every thread's Us copy (located at output + nelems * (i + 1),
    // i.e. directly after the master U) into U, restricted per thread to
    // the [input_starts, input_ends) range it actually wrote.
    {
        float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0));
        size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
        float *input_ptrs[max_threads_number];
        for (int i = 0; i < nthreads; ++i)
            input_ptrs[i] = output + nelems * (i + 1);
        subarray_sum(
                nthreads, output, nelems, input_ptrs, input_starts, input_ends);
    }

    // Stage 4: output-transform the reduced U back into the user's
    // diff_weights layout.
    trans_ker_p.G = G_O_3x3_4x4;
    PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p))
    {
        parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
                jcp.oc_reg_block,
                [&](dim_t ifm1, dim_t ofm1, dim_t ofm2, dim_t ifm2,
                        dim_t ofm3) {
                    int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block
                            + ofm3;
                    int ifm = ifm1 * jcp.ic_block + ifm2;
                    trans_ker_p.src = (float *)&(
                            U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, ofm3, 0));
                    trans_ker_p.dst
                            = (float *)&(diff_weights(ofm, ifm, 0, 0, 0, 0));
                    kernel_->diff_weights_transform(&trans_ker_p);
                });
    }

    // Stage 5: reduce the per-thread bias partials into diff_bias.
    if (jcp.with_bias) {
        parallel_nd(jcp.oc / simd_w, [&](dim_t ofm1) {
            float *pbias = &(diff_bias(ofm1 * simd_w));
            float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w));

            // The last simd block may be zero-padded; copy/accumulate only
            // the channels that really exist (oc_without_padding).
            const int blk_sz = ofm1 == jcp.oc / simd_w - 1
                    ? jcp.oc_without_padding - ofm1 * simd_w
                    : simd_w;

            // Thread 0's partials initialize the output ...
            PRAGMA_OMP_SIMD()
            for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
                pbias[ofm2] = pbias_prv[ofm2];
            }

            // ... and the remaining threads' partials are added in.
            for (int ithr = 1; ithr < nthreads; ++ithr) {
                pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w));
                PRAGMA_OMP_SIMD()
                for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
                    pbias[ofm2] += pbias_prv[ofm2];
                }
            }
        });
    }
}
990 | |
991 | } // namespace x64 |
992 | } // namespace cpu |
993 | } // namespace impl |
994 | } // namespace dnnl |
995 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
996 | |