1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <functional> |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/math_utils.hpp" |
23 | #include "common/memory_tracking.hpp" |
24 | #include "common/nstl.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | #include "cpu/cpu_batch_normalization_utils.hpp" |
29 | #include "cpu/platform.hpp" |
30 | #include "cpu/x64/cpu_barrier.hpp" |
31 | #include "cpu/x64/jit_generator.hpp" |
32 | |
33 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
34 | #include "cpu/x64/jit_uni_batch_normalization.hpp" |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | namespace cpu { |
39 | namespace x64 { |
40 | |
41 | using namespace memory_tracking::names; |
42 | |
43 | using namespace Xbyak; |
44 | namespace barrier = simple_barrier; |
45 | |
46 | using acc_data_t = float; |
47 | |
48 | namespace { |
49 | dim_t get_c_padded(const batch_normalization_pd_t *pd) { |
50 | return pd->src_md()->padded_dims[1]; |
51 | } |
52 | |
53 | bool is_nspc(const memory_desc_wrapper &d) { |
54 | using namespace format_tag; |
55 | const bool is_nspc = d.matches_one_of_tag(nc, nwc, nhwc, ndhwc); |
56 | return is_nspc; |
57 | } |
58 | } // namespace |
59 | |
60 | struct jit_bnorm_conf_t { |
61 | // TODO: put all needed info here to avoid duplicate work and potentially |
62 | // diverging definitions of derived parameters |
63 | const batch_normalization_pd_t *pd_; |
64 | |
65 | int simd_w_ {0}; |
66 | size_t dt_size_ {0}; |
67 | bool is_nspc_ {false}; |
68 | |
69 | // thread partition info |
70 | bool do_blocking_ {false}; |
71 | bool is_spatial_thr_ {false}; |
72 | dim_t C_blks_per_iter_ {0}; |
73 | int C_nthr_ {0}; |
74 | int N_nthr_ {0}; |
75 | int S_nthr_ {0}; |
76 | int64_t iters_ {0}; |
    // C_blks and the thread partition can change for the last iteration
78 | dim_t C_blks_last_iter_ {0}; |
79 | int C_nthr_last_iter_ {0}; |
80 | int N_nthr_last_iter_ {0}; |
81 | int S_nthr_last_iter_ {0}; |
82 | |
83 | jit_bnorm_conf_t(const batch_normalization_pd_t *pd, int nthr, int simd_w) |
84 | : pd_(pd), simd_w_(simd_w) { |
85 | |
86 | const dim_t N = pd_->MB(); |
87 | const dim_t C_PADDED = get_c_padded(pd_); |
88 | const dim_t D = pd_->D(); |
89 | const dim_t H = pd_->H(); |
90 | const dim_t W = pd_->W(); |
91 | const dim_t SP = D * H * W; |
92 | |
93 | const memory_desc_wrapper src_d(pd_->src_md()); |
94 | is_nspc_ = is_nspc(src_d); |
95 | |
96 | dt_size_ = types::data_type_size(pd_->src_md()->data_type); |
97 | size_t data_size = dt_size_ * N * C_PADDED * SP; |
98 | const size_t l3_size = platform::get_per_core_cache_size(3) * nthr; |
99 | // TODO: cache balancing for nspc |
100 | const size_t l3_filling_factor = 4; |
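        // Heuristic: block over channels only when the tensors do not fit
        // in ~1/4 of the aggregated L3; otherwise a single pass over all
        // channel blocks is assumed to be cache-friendly enough.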
101 | do_blocking_ = !is_nspc_ && data_size >= l3_size / l3_filling_factor; |
102 | |
103 | // find thread partition over N, C_blks and SP |
104 | |
105 | const dim_t C_blks = C_PADDED / simd_w_; |
106 | |
107 | if (do_blocking_) { |
108 | const int num_tensors = pd_->is_fwd() ? 1 : 2; |
109 | const size_t working_set_size |
110 | = dt_size_ * (N * SP * simd_w_) * num_tensors; |
111 | bnorm_utils::cache_balance(working_set_size, C_blks, N, nthr, |
112 | C_blks_per_iter_, iters_); |
113 | C_blks_last_iter_ = C_blks - (iters_ - 1) * C_blks_per_iter_; |
114 | } else { |
115 | C_blks_per_iter_ = C_blks; |
116 | iters_ = 1; |
117 | } |
118 | |
119 | is_spatial_thr_ |
120 | = this->thread_partition(/* is_spatial_thr_ = */ true, nthr, |
121 | /* dimensions */ |
122 | N, C_blks_per_iter_, SP, |
123 | /* outputs */ |
124 | C_nthr_, N_nthr_, S_nthr_); |
125 | |
126 | if (iters_ > 1) |
127 | this->thread_partition(is_spatial_thr_, nthr, |
128 | /* dimensions */ |
129 | N, C_blks_last_iter_, SP, |
130 | /* outputs */ |
131 | C_nthr_last_iter_, N_nthr_last_iter_, S_nthr_last_iter_); |
132 | } |
133 | |
    // given nthr and the problem shape, choose the thread partition
    // to use (i.e., set C_nthr, N_nthr, and S_nthr)
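    // e.g. (hypothetical sizes, syncable threading runtime, no cache
    // blocking): blocked layout with nthr = 32, N = 8, C_blks = 16, SP = 196
    // gives C_nthr = gcd(32, 16) = 16, N_nthr = min(8, 32 / 16) = 2,
    // S_nthr = min(196, 32 / (16 * 2)) = 1, and the routine returns false
    // since no spatial split was used.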
136 | bool thread_partition(bool spatial_thr_allowed, int nthr, dim_t N, |
137 | dim_t C_blks, dim_t SP, int &C_nthr, int &N_nthr, int &S_nthr) { |
138 | if (((nthr <= C_blks) && IMPLICATION(is_nspc_, N == 1)) |
139 | || !dnnl_thr_syncable()) { |
140 | C_nthr = nthr; |
141 | N_nthr = 1; |
142 | S_nthr = 1; |
143 | } else { |
144 | if (is_nspc_) { |
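                // nspc heuristic: avoid splitting small channel ranges,
                // cap C_nthr at 8 for mid-size ones, otherwise use the
                // gcd so every thread owns an equal number of blocks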
145 | if (C_blks <= 8) |
146 | C_nthr = 1; |
147 | else if (nthr >= 8 && C_blks <= 32) |
148 | C_nthr = 8; |
149 | else { |
150 | C_nthr = (int)math::gcd((dim_t)nthr, C_blks); |
                    // if the gcd degenerates, fall back to C_nthr = 1 and
                    // unroll by channels in the JIT kernel instead
                    if ((C_nthr == C_blks) || (C_nthr == nthr)) C_nthr = 1;
153 | } |
154 | N_nthr = (int)nstl::min<dim_t>(N, nthr / C_nthr); |
155 | // heuristic for training on avx512_core_amx |
156 | // TODO: test heuristic when global stats flag is set |
157 | if (!pd_->use_global_stats() && 0 < dt_size_ && 0 < simd_w_ |
158 | && 1 < C_nthr && nthr <= N |
159 | && mayiuse(avx512_core_amx)) { |
160 | const size_t data_size |
161 | = dt_size_ * N * SP * C_blks * simd_w_; |
162 | const size_t C_split_data_size |
163 | = utils::div_up(data_size, N_nthr); |
164 | const size_t N_split_data_size |
165 | = utils::div_up(data_size, nthr); |
166 | const size_t l2_size_per_core |
167 | = platform::get_per_core_cache_size(2); |
168 | const size_t l3_size_per_core |
169 | = platform::get_per_core_cache_size(3); |
170 | const size_t cache_size_per_core |
171 | = l2_size_per_core + l3_size_per_core; |
                    // if the current split is too big for the cache, it is
                    // better to split by N
173 | const bool condition1 |
174 | = cache_size_per_core < C_split_data_size; |
                    // if the split by N is also too big for the cache, bwd
                    // is better off as it was
176 | const bool condition2 = pd_->is_fwd() |
177 | || cache_size_per_core >= N_split_data_size; |
178 | if (condition1 && condition2) { |
179 | C_nthr = 1; |
180 | N_nthr = nthr; |
181 | } |
182 | } |
183 | S_nthr = (int)nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr)); |
184 | } else { |
185 | if (do_blocking_) { |
186 | N_nthr = (int)nstl::min<dim_t>(N, nthr); |
187 | C_nthr = (int)nstl::min<dim_t>(C_blks, nthr / N_nthr); |
188 | S_nthr = (int)nstl::min<dim_t>( |
189 | SP, nthr / (C_nthr * N_nthr)); |
190 | } else { |
191 | C_nthr = (int)math::gcd((dim_t)nthr, C_blks); |
192 | N_nthr = (int)nstl::min<dim_t>(N, nthr / C_nthr); |
193 | S_nthr = (int)nstl::min<dim_t>( |
194 | SP, nthr / (C_nthr * N_nthr)); |
195 | } |
196 | } |
197 | |
198 | if (!spatial_thr_allowed) S_nthr = 1; |
199 | |
200 | if (S_nthr < 1) S_nthr = 1; |
201 | } |
202 | |
        // spatial_thr_allowed is meant to help maintain
        // consistent decisions about spatial threading
        // across multiple invocations of this routine.
        // It is the caller's responsibility to check the
        // return value and pass it as a flag to the
        // next call if needed.
209 | if (S_nthr == 1) spatial_thr_allowed = false; |
210 | |
211 | return spatial_thr_allowed; |
212 | } |
213 | }; |
214 | |
215 | template <cpu_isa_t isa> |
216 | struct jit_bnorm_t : public jit_generator { |
217 | struct call_params_t { |
218 | // keep all sizes at 8 bytes -- jit code expects this |
219 | size_t N_ithr, N_nthr; |
220 | size_t coff_max, soff_max; |
221 | size_t mb_stride_Bc, spat_size, spat_size_loc; |
222 | size_t S_s, S_tail; |
223 | size_t is_cblk_tail; |
224 | acc_data_t chan_size, eps, one; |
225 | const acc_data_t *scale; |
226 | const acc_data_t *shift; |
227 | const acc_data_t *mean, *var; |
228 | const acc_data_t *diff_scale; |
229 | const acc_data_t *diff_shift; |
230 | const void *src, *dst; |
231 | const void *diff_src, *diff_dst; |
232 | const acc_data_t *rbuf1, *rbuf2; |
233 | const uint8_t *ws; |
234 | barrier::ctx_64_t *barrier; |
235 | }; |
236 | |
237 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t) |
238 | |
239 | /* cpu specific part */ |
240 | using Vmm = typename utils::conditional3<isa == sse41, Xmm, isa == avx2, |
241 | Ymm, Zmm>::type; |
242 | const AddressFrame &vmmword |
243 | = (isa == sse41) ? xword : (isa == avx2) ? yword : zword; |
244 | |
245 | const int vlen = isa == sse41 ? 32 : cpu_isa_traits<isa>::vlen; |
    int vlen_spat_data_
            = 0; // set by ctor depending on data type (xf16 or f32)
248 | |
249 | const batch_normalization_pd_t *pd_ = nullptr; |
250 | const jit_bnorm_conf_t *jbp_ = nullptr; |
251 | bool is_bf16_ = false; |
252 | bool is_f16_ = false; |
253 | bool is_avx2_ne_xf16_ = false; |
254 | |
255 | Reg64 reg_param = abi_param1; |
256 | |
257 | Reg64 reg_scale = rbx; |
258 | Reg64 reg_rbuf1 = abi_not_param1; |
259 | Reg64 reg_rbuf2 = rdx; |
260 | Reg64 reg_coff_max_fwd_copy = reg_rbuf2; |
261 | |
262 | Reg64 reg_mean = rbp; |
263 | Reg64 reg_var = reg_param; |
264 | Reg64 reg_diff_scale = rax; |
265 | Reg64 reg_coff_max_bwd_copy = reg_diff_scale; |
266 | Reg64 reg_shift = reg_rbuf1; |
267 | |
268 | Reg64 reg_coff = r8; |
269 | Reg64 reg_coff_max = r9; |
270 | Reg64 reg_soff = r10; |
271 | Reg64 reg_soff_max = r11; |
272 | Reg64 reg_diff_shift = reg_soff_max; |
273 | Reg64 reg_ctr = r12; |
274 | Reg64 reg_roff = r13; |
275 | |
276 | Reg64 reg_mb_stride_Bc = r14; |
277 | Reg64 reg_soff_nspc = reg_mb_stride_Bc; |
278 | |
279 | Reg64 reg_src = r15; |
280 | Reg64 reg_diff_src = reg_rbuf1; |
281 | Reg64 reg_dst = rsi; |
282 | Reg64 reg_diff_dst = reg_dst; |
283 | |
284 | Reg64 reg_tmp_off = reg_roff; |
285 | |
286 | // Reuse loop counters |
287 | Reg64 reg_bar = reg_coff; |
288 | Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff |
289 | Reg64 reg_tmp = reg_ctr; |
290 | |
291 | // Relu section |
292 | bool with_relu = false, with_relu_inf_only = false; |
293 | Reg64 reg_ws = reg_roff; |
294 | Reg64 reg_tmp_alpha = reg_diff_scale; // required in sse41 |
295 | Label l_relu_mask_avx2; |
296 | Opmask kstore_mask = Opmask(1); |
297 | |
298 | // channel tail processing |
299 | Opmask ktail_mask = Opmask(2); |
300 | |
301 | // FP32->BF16 emulation |
302 | bf16_emulation_t *bf16_emu_ {nullptr}; |
303 | Reg64 reg_bf16_tmp = reg_tmp; |
304 | Zmm bf16_emu_reserved_1 = Zmm(17); |
305 | Zmm bf16_emu_reserved_2 = Zmm(18); |
306 | Zmm bf16_emu_reserved_3 = Zmm(19); |
307 | Zmm bf16_emu_reserved_4 = Zmm(20); |
308 | |
309 | size_t unroll_blocks; |
310 | size_t unroll_regs; |
311 | Vmm vdiff_beta = Vmm(isa == avx512_core ? 21 : 6); |
312 | Vmm vdiff_gamma = Vmm(isa == avx512_core ? 22 : 7); |
313 | Vmm vsqrtvar = Vmm(isa == avx512_core ? 23 : 8); |
314 | Vmm vone = Vmm(isa == avx512_core ? 24 : 9); |
315 | Vmm vmean = Vmm(isa == avx512_core ? 25 : 10); |
316 | Vmm vgamma = Vmm(isa == avx512_core ? 26 : 11); |
317 | Vmm vbeta = Vmm(isa == avx512_core ? 27 : 12); |
318 | Vmm veps = Vmm(isa == avx512_core ? 28 : 13); |
319 | Vmm vchan_size = Vmm(isa == avx512_core ? 29 : 14); |
320 | Vmm vtail_mask = Vmm(isa == avx512_core ? 30 : 15); |
321 | Vmm vtmp = Vmm(isa == avx512_core ? 31 : 5); |
    Vmm vsrc_aux = vdiff_gamma; // used for xf16 with nspc on avx2
    Vmm vdst_aux = vdiff_gamma; // used for ReLU on avx2 and sse41
324 | Vmm vmask = Vmm(0); |
325 | Vmm vzero; // is_fwd() ? vdiff_beta : vbeta |
326 | |
327 | size_t spat_size; |
328 | size_t chan_data_offt; |
329 | size_t spat_step; |
330 | size_t mb_offt; |
331 | size_t ws_mb_offt; |
332 | |
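    // Layout of the spill area on the stack; every slot is 8 bytes wide
    // to match the jit load/store code below.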
333 | enum { |
334 | stack_off_N_nthr = 0, |
335 | stack_off_N_ithr = 8, |
336 | stack_off_src = 16, |
337 | stack_off_dst = 24, |
338 | stack_off_diff_src = 32, |
339 | stack_off_diff_dst = 40, |
340 | stack_off_diff_scale = 48, |
341 | stack_off_ws = 56, |
342 | stack_off_barrier = 64, |
343 | stack_off_spat_size_loc = 72, |
344 | stack_off_s_s = 80, |
345 | stack_off_s_tail = 88, |
346 | stack_off_is_cblk_tail = 96, |
347 | stack_off_ws_off_copy = 104, |
348 | stack_off_shift = 112, |
349 | stack_off_diff_shift = 120, |
350 | stack_off_soff_max = 128, |
351 | stack_off_relu_alpha = 136, |
352 | stack_size_required = 144, |
353 | }; |
354 | |
355 | bool is_xf16() { return is_bf16_ || is_f16_; } |
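    // log2 of the ratio of source bytes to ReLU workspace bytes: the ws
    // keeps one bit per element, so one ws byte covers 8 elements, i.e.
    // 32 source bytes for f32 (shift 5) and 16 for xf16 (shift 4)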
356 | int bit_shift() { return 5 - is_xf16(); } |
357 | |
358 | bool use_bf16_emulation() { |
359 | return is_bf16_ && isa == avx512_core && !mayiuse(avx512_core_bf16); |
360 | } |
361 | |
362 | bool stream_store_supported() { |
363 | // keep original behavior for f32 |
364 | if (!is_xf16()) return true; |
365 | // TODO: check performance of heuristic for other cases, such as: |
366 | // blocked layout, pre-avx512_core_amx machines, and f32 datatype. |
367 | const bool is_applicable = jbp_->is_nspc_ && mayiuse(avx512_core_amx); |
368 | if (!is_applicable) return false; |
369 | const size_t l2_size_per_core = platform::get_per_core_cache_size(2); |
370 | const size_t l3_size_per_core = platform::get_per_core_cache_size(3); |
371 | const size_t cache_size_per_core = l2_size_per_core + l3_size_per_core; |
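        // fwd streams through src and dst; bwd through src, diff_dst, and
        // diff_src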
372 | const size_t buffer_count = pd_->is_fwd() ? 2 : 3; |
373 | const size_t data_size = buffer_count * jbp_->dt_size_ * pd_->MB() |
374 | * pd_->C() * pd_->D() * pd_->H() * pd_->W(); |
375 | // do not divide by C_nthr for nspc layout |
376 | const size_t data_size_per_core |
377 | = data_size / (jbp_->N_nthr_ * jbp_->S_nthr_); |
378 | return cache_size_per_core < data_size_per_core; |
379 | } |
380 | |
381 | bool is_c_padded() const { |
382 | const memory_desc_wrapper data_d(pd_->src_md()); |
383 | return pd_->C() != data_d.padded_dims()[1]; |
384 | } |
385 | |
386 | void compute_static_strides() { |
387 | spat_size = pd_->D() * pd_->W() * pd_->H(); |
388 | chan_data_offt = pd_->C() * sizeof(acc_data_t); |
389 | spat_step = jbp_->is_nspc_ ? chan_data_offt / (1 + is_xf16()) |
390 | : vlen_spat_data_; |
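        // per-image offsets used on the nspc path; the ReLU workspace keeps
        // one bit per element, hence 32 (f32) or 16 (xf16) source bytes map
        // to a single ws byte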
391 | mb_offt = spat_step * spat_size; |
392 | ws_mb_offt = (spat_step / (is_xf16() ? 16 : 32)) * spat_size; |
393 | } |
394 | |
395 | void load_common_params() { |
396 | #define PARAM_OFF(x) offsetof(call_params_t, x) |
397 | mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]); |
398 | if (!pd_->is_fwd()) mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]); |
399 | mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]); |
400 | mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]); |
401 | mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]); |
402 | shl(reg_coff_max, 2); |
403 | |
404 | mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]); |
405 | mov(reg_scale, ptr[reg_param + PARAM_OFF(scale)]); |
406 | |
407 | uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]); |
408 | uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]); |
409 | uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]); |
410 | |
411 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]); |
412 | mov(ptr[rsp + stack_off_N_nthr], reg_tmp); |
413 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]); |
414 | mov(ptr[rsp + stack_off_N_ithr], reg_tmp); |
415 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]); |
416 | mov(ptr[rsp + stack_off_src], reg_tmp); |
417 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]); |
418 | mov(ptr[rsp + stack_off_dst], reg_tmp); |
419 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]); |
420 | mov(ptr[rsp + stack_off_diff_src], reg_tmp); |
421 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]); |
422 | mov(ptr[rsp + stack_off_diff_dst], reg_tmp); |
423 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]); |
424 | mov(ptr[rsp + stack_off_ws], reg_tmp); |
425 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]); |
426 | mov(ptr[rsp + stack_off_barrier], reg_tmp); |
427 | if (jbp_->is_spatial_thr_) { |
428 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]); |
429 | mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp); |
430 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]); |
431 | mov(ptr[rsp + stack_off_s_s], reg_tmp); |
432 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]); |
433 | mov(ptr[rsp + stack_off_s_tail], reg_tmp); |
434 | } |
435 | if (is_c_padded()) { |
436 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]); |
437 | mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp); |
438 | } |
439 | |
440 | if (pd_->is_fwd()) { |
441 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(shift)]); |
442 | mov(ptr[rsp + stack_off_shift], reg_tmp); |
443 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); |
444 | mov(reg_var, reg_tmp); |
445 | } else { |
446 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale)]); |
447 | mov(ptr[rsp + stack_off_diff_scale], reg_tmp); |
448 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_shift)]); |
449 | mov(ptr[rsp + stack_off_diff_shift], reg_tmp); |
450 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(soff_max)]); |
451 | mov(ptr[rsp + stack_off_soff_max], reg_tmp); |
452 | mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); |
453 | mov(reg_var, reg_tmp); |
454 | } |
455 | if (with_relu_inf_only && pd_->alpha() != 0.f) { |
456 | mov(reg_tmp, float2int(pd_->alpha())); |
457 | mov(ptr[rsp + stack_off_relu_alpha], reg_tmp); |
458 | } |
459 | #undef PARAM_OFF |
460 | } |
461 | |
462 | void prepare_tail_mask_avx512_common() { |
463 | if (!is_c_padded()) return; |
464 | |
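        // set the lowest `tail` bits of the k-mask: those lanes hold real
        // channels, the remaining lanes are padding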
465 | const int tail = pd_->C() % (int)(vlen / sizeof(float)); |
466 | const int mask = (1 << tail) - 1; |
467 | |
468 | Reg32 regw_tmp = reg_tmp.cvt32(); |
469 | mov(regw_tmp, mask); |
470 | kmovw(ktail_mask, regw_tmp); |
471 | } |
472 | |
473 | void prepare_tail_mask_avx2_common() { |
474 | if (!is_c_padded()) return; |
475 | |
476 | const int tail = pd_->C() % (int)(vlen / sizeof(float)); |
477 | static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff, |
478 | 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0, |
479 | 0, 0, 0, 0, 0, 0, 0}; |
480 | |
481 | mov(reg_tmp, reinterpret_cast<size_t>(&mask[8 - tail])); |
482 | vmovups(vtail_mask, ptr[reg_tmp]); |
483 | } |
484 | |
485 | void prepare_relu() { |
486 | with_relu = pd_->is_fwd() ? pd_->with_relu_post_op(pd_->is_training()) |
487 | || pd_->fuse_norm_relu() |
488 | : pd_->fuse_norm_relu(); |
489 | with_relu_inf_only = with_relu && pd_->is_fwd() |
490 | && !(pd_->fuse_norm_relu() && pd_->is_training()); |
491 | |
492 | vzero = pd_->is_fwd() ? vdiff_beta : vbeta; |
493 | if (with_relu) { |
494 | uni_vpxor(vzero, vzero, vzero); |
495 | if (!pd_->is_fwd() && isa == avx2) prepare_l_relu_mask_avx2(); |
496 | } |
497 | } |
498 | |
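    // Emits an in-code table of per-lane bit flags (1 << i). The backward
    // pass broadcasts the packed workspace byte and uses this table (vpand +
    // vpcmpeqd) to re-expand it into a full per-lane mask.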
499 | void prepare_l_relu_mask_avx2() { |
500 | Label l_mask_after; |
501 | jmp(l_mask_after); |
502 | align(32); |
503 | L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */ |
504 | for (int i = 0; i < 8; ++i) |
505 | dd(1 << i); |
506 | L(l_mask_after); |
507 | } |
508 | |
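    // Forward ReLU on avx2: vmovmskps packs the per-lane (0 < dst) mask
    // into 8 bits stored to the workspace (one byte per Ymm), then the
    // non-positive lanes are zeroed via vblendvps.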
509 | void fwd_process_relu_avx2(Vmm vdst, int offt) { |
510 | Reg64 reg_store_mask = reg_diff_scale; |
511 | Reg64 reg_soff_loc = jbp_->is_nspc_ ? reg_soff_nspc : reg_soff; |
512 | shr(reg_soff_loc, bit_shift()); |
513 | vcmpps(vtmp, vzero, vdst, _cmp_lt_os); |
514 | vmovmskps(reg_store_mask, vtmp); |
515 | mov(ptr[reg_ws + reg_soff_loc + offt / (1 << bit_shift())], |
516 | reg_store_mask.cvt8()); |
517 | vblendvps(vdst, vzero, vdst, vtmp); |
518 | shl(reg_soff_loc, bit_shift()); |
519 | } |
520 | |
521 | void fwd_process_relu_avx512_common(Vmm vdst, int offt = 0) { |
522 | Reg64 reg_soff_loc = jbp_->is_nspc_ ? reg_soff_nspc : reg_soff; |
523 | shr(reg_soff_loc, bit_shift()); |
524 | vcmpps(kstore_mask, vzero, vdst, _cmp_lt_os); |
525 | kmovw(ptr[reg_ws + reg_soff_loc + offt / (1 << bit_shift())], |
526 | kstore_mask); |
527 | vblendmps(vdst | kstore_mask, vzero, vdst); |
528 | shl(reg_soff_loc, bit_shift()); |
529 | } |
530 | |
531 | void fwd_process_relu_alpha(Vmm vmm_dst) { |
532 | if (isa == avx512_core) |
533 | fwd_process_relu_alpha_avx512_common(vmm_dst); |
534 | else { |
535 | assert(utils::one_of(isa, avx2, sse41)); |
536 | if (vmm_dst.getIdx() == 0) { |
537 | uni_vmovups(vdst_aux, vmm_dst); |
538 | fwd_process_relu_alpha_avx2(vdst_aux); |
539 | uni_vmovups(Vmm(0), vdst_aux); |
540 | } else |
541 | fwd_process_relu_alpha_avx2(vmm_dst); |
542 | } |
543 | } |
544 | void fwd_process_relu_alpha_avx512_common(Vmm vmm_dst) { |
545 | const Xmm xmm_tmp = Xmm(vtmp.getIdx()); |
546 | vmovq(xmm_tmp, ptr[rsp + stack_off_relu_alpha]); |
547 | vbroadcastss(vtmp, xmm_tmp); |
548 | vcmpps(kstore_mask, vzero, vmm_dst, _cmp_lt_os); |
549 | vmulps(vtmp, vmm_dst, vtmp); |
550 | vblendmps(vmm_dst | kstore_mask, vtmp, vmm_dst); |
551 | } |
552 | |
553 | void fwd_process_relu_alpha_avx2(Vmm vmm_dst) { |
554 | const Xmm xmm_tmp = Xmm(vtmp.getIdx()); |
555 | uni_vpxor(vmask, vmask, vmask); |
556 | if (isa == sse41) { |
557 | mov(reg_tmp_alpha, ptr[rsp + stack_off_relu_alpha]); |
558 | uni_vmovq(xmm_tmp, reg_tmp_alpha); |
559 | } else |
560 | vmovq(xmm_tmp, ptr[rsp + stack_off_relu_alpha]); |
561 | uni_vbroadcastss(vtmp, xmm_tmp); |
562 | uni_vcmpps(vmask, vmm_dst, vzero, _cmp_lt_os); |
563 | uni_vmulps(vtmp, vtmp, vmm_dst); |
564 | uni_vblendvps(vmm_dst, vmm_dst, vtmp, vmask); |
565 | } |
566 | |
567 | void bwd_process_relu_avx2(Vmm vdiff_dst, int offt) { |
568 | shr(reg_soff, bit_shift()); |
569 | vpbroadcastb(vtmp, ptr[reg_ws + reg_soff + offt / (1 << bit_shift())]); |
570 | vpand(vtmp, vtmp, ptr[rip + l_relu_mask_avx2]); |
571 | vpcmpeqd(vtmp, vtmp, ptr[rip + l_relu_mask_avx2]); |
572 | vblendvps(vdiff_dst, vzero, vdiff_dst, vtmp); |
573 | shl(reg_soff, bit_shift()); |
574 | } |
575 | |
576 | void bwd_process_relu_avx512_common(Vmm vdiff_dst, int offt = 0) { |
577 | shr(jbp_->is_nspc_ ? reg_soff_nspc : reg_soff, bit_shift()); |
578 | kmovw(kstore_mask, |
579 | ptr[reg_ws + (jbp_->is_nspc_ ? reg_soff_nspc : reg_soff) |
580 | + offt / (1 << bit_shift())]); |
581 | vmovups(vdiff_dst | kstore_mask | T_z, vdiff_dst); |
582 | shl(jbp_->is_nspc_ ? reg_soff_nspc : reg_soff, bit_shift()); |
583 | } |
584 | |
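    // The avx2 even/odd xf16 loads leave one register with even-indexed and
    // one with odd-indexed channels; interleave dwords within each 128-bit
    // lane, then recombine the lanes, so each register ends up with a
    // contiguous simd_w-channel block in plain order.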
585 | void merge_interleaved_to_plain( |
586 | const Vmm &vmm_even, const Vmm &vmm_odd, const Vmm &vmm_aux0) { |
587 | Ymm ymm_even = Ymm(vmm_even.getIdx()); |
588 | Ymm ymm_odd = Ymm(vmm_odd.getIdx()); |
589 | Ymm ymm_aux0 = Ymm(vmm_aux0.getIdx()); |
590 | Ymm ymm_aux1 = ymm_odd; |
591 | |
592 | vpunpckldq(ymm_aux0, ymm_even, ymm_odd); |
593 | vpunpckhdq(ymm_aux1, ymm_even, ymm_odd); |
594 | vperm2i128(ymm_even, ymm_aux0, ymm_aux1, 0x20); |
595 | vperm2i128(ymm_odd, ymm_aux0, ymm_aux1, 0x31); |
596 | } |
597 | void uni_vmovups_spat_data( |
598 | const Vmm &vmm_even, const Vmm &vmm_odd, const Address &addr) { |
        // load two simd_w-wide chunks of data from addr into two registers
600 | if (is_bf16_) { |
601 | // convert bf16 input to f32 |
602 | vcvtneebf162ps(vmm_even, addr); |
603 | vcvtneobf162ps(vmm_odd, addr); |
604 | } else if (is_f16_) { |
605 | vcvtneeph2ps(vmm_even, addr); |
606 | vcvtneoph2ps(vmm_odd, addr); |
607 | } else |
            assert(!"unsupported data type!");
609 | } |
610 | |
611 | void uni_vmovups_spat_data( |
612 | const Operand &dst, const Operand &src, bool is_nt_store = false) { |
613 | if (dst.isMEM()) { |
614 | if (is_bf16_) { |
615 | constexpr bool isAvx2 = isa == avx2; |
616 | const typename std::conditional<isAvx2, Xmm, Ymm>::type |
617 | dst_reg {src.getIdx()}; |
618 | const typename std::conditional<isAvx2, Ymm, Zmm>::type |
619 | src_reg {src.getIdx()}; |
620 | |
621 | // convert f32 output to bf16 |
622 | if (!use_bf16_emulation()) |
623 | vcvtneps2bf16(dst_reg, src_reg, |
624 | mayiuse(avx512_core) ? Xbyak::EvexEncoding |
625 | : Xbyak::VexEncoding); |
626 | else |
627 | bf16_emu_->vcvtneps2bf16(dst_reg, src_reg); |
628 | |
629 | // store to memory |
630 | if (is_nt_store) |
631 | uni_vmovntps(dst.getAddress(), dst_reg); |
632 | else |
633 | uni_vmovups(dst.getAddress(), dst_reg); |
634 | } else if (is_f16_) { |
635 | auto src_reg = Vmm(src.getIdx()); |
636 | auto dst_reg = |
637 | typename vreg_traits<Vmm>::Vmm_lower_t(src.getIdx()); |
638 | if (is_nt_store) { |
639 | if (mayiuse(avx512_core_fp16)) |
640 | vcvtps2phx(dst_reg, src_reg); |
641 | else |
642 | vcvtps2ph(dst_reg, src_reg, _op_mxcsr); |
643 | uni_vmovntps(dst.getAddress(), dst_reg); |
644 | } else { |
645 | vcvtps2ph(dst.getAddress(), src_reg, _op_mxcsr); |
646 | } |
647 | } else { |
648 | if (is_nt_store) |
649 | uni_vmovntps(dst.getAddress(), Vmm(src.getIdx())); |
650 | else |
651 | uni_vmovups(dst.getAddress(), Vmm(src.getIdx())); |
652 | } |
653 | } else { |
654 | if (is_bf16_) { |
655 | // convert bf16 input to f32 |
656 | vpmovzxwd(Vmm(dst.getIdx()), src.getAddress()); |
657 | vpslld(Vmm(dst.getIdx()), Vmm(dst.getIdx()), 0x10); |
658 | } else if (is_f16_) { |
659 | if (mayiuse(avx512_core_fp16)) |
660 | vcvtph2psx(Vmm(dst.getIdx()), src.getAddress()); |
661 | else |
662 | vcvtph2ps(Vmm(dst.getIdx()), src.getAddress()); |
663 | } else { |
664 | uni_vmovups(Vmm(dst.getIdx()), src.getAddress()); |
665 | } |
666 | } |
667 | } |
668 | |
669 | void uni_vmovups_tail_avx2_common( |
670 | const Operand &dst, const Operand &src, Label &l_ret) { |
671 | if (dst.isMEM()) { |
672 | vmaskmovps(dst.getAddress(), vtail_mask, Vmm(src.getIdx())); |
673 | } else { |
674 | vmaskmovps(Vmm(dst.getIdx()), vtail_mask, src.getAddress()); |
675 | } |
676 | jmp(l_ret); |
677 | } |
678 | |
679 | void uni_vmovups_tail_avx512_common( |
680 | const Operand &dst, const Operand &src, Label &l_ret) { |
681 | if (dst.isMEM()) |
682 | uni_vmovups(dst.getAddress() | ktail_mask | T_z, Vmm(src.getIdx())); |
683 | else |
684 | uni_vmovups(Vmm(dst.getIdx()) | ktail_mask | T_z, src.getAddress()); |
685 | |
686 | jmp(l_ret); |
687 | } |
688 | |
689 | void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) { |
690 | Label l_no_mask, l_ret; |
691 | |
692 | if (is_c_padded()) { |
693 | mov(reg_tmp, ptr[rsp + stack_off_is_cblk_tail]); |
694 | cmp(reg_tmp, 0); |
695 | jz(l_no_mask); |
696 | |
697 | lea(reg_tmp, ptr[reg_coff + vlen]); |
698 | cmp(reg_tmp, reg_coff_max); |
699 | jl(l_no_mask); |
700 | assert(isa == avx512_core || isa == avx2); |
701 | if (isa == avx512_core) |
702 | uni_vmovups_tail_avx512_common(dst, src, l_ret); |
703 | else if (isa == avx2) |
704 | uni_vmovups_tail_avx2_common(dst, src, l_ret); |
705 | } |
706 | L(l_no_mask); |
707 | if (dst.isMEM()) |
708 | uni_vmovups(dst.getAddress(), Vmm(src.getIdx())); |
709 | else |
710 | uni_vmovups(Vmm(dst.getIdx()), src.getAddress()); |
711 | |
712 | L(l_ret); |
713 | } |
714 | |
715 | void barrier() { |
716 | mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); |
717 | mov(reg_bar, ptr[rsp + stack_off_barrier]); |
718 | simple_barrier::generate(*this, reg_bar, reg_nnthr); |
719 | } |
720 | |
721 | Address mean_ptr(size_t offt = 0) { |
722 | return vmmword[reg_mean + reg_coff + offt]; |
723 | } |
724 | |
725 | Address var_ptr(size_t offt = 0) { |
726 | return vmmword[reg_var + reg_coff + offt]; |
727 | } |
728 | |
729 | Address diff_gamma_ptr(size_t offt = 0) { |
730 | return vmmword[reg_diff_scale + reg_coff + offt]; |
731 | } |
732 | |
733 | Address diff_beta_ptr(size_t offt = 0) { |
734 | return vmmword[reg_diff_shift + reg_coff + offt]; |
735 | } |
736 | |
737 | Address gamma_ptr(size_t offt = 0) { |
738 | return vmmword[reg_scale + reg_coff + offt]; |
739 | } |
740 | |
741 | Address beta_ptr(size_t offt = 0) { |
742 | return vmmword[reg_shift + reg_coff + offt]; |
743 | } |
744 | |
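    // Generic unrolled driver over the spatial dimension: init(r) runs once
    // per accumulator register, body(r, i) once per unrolled iteration, and
    // fini(r) once per register to reduce partial results. Iterations that
    // do not fill a whole blocks * regs unroll factor are emitted as
    // straight-line tail code; with spatial threading the trip count comes
    // from the per-thread slice passed on the stack.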
745 | template <typename init_t, typename body_t, typename fini_t> |
746 | void spat_loop(size_t len, size_t blocks, size_t regs, init_t init, |
747 | body_t body, fini_t fini) { |
748 | size_t factor = regs * blocks; |
749 | size_t loop_unroll = len / factor * factor; |
750 | size_t loop_tail = len - loop_unroll; |
751 | size_t num_active_regs = (len < regs) ? len : regs; |
752 | for (size_t i = 0; i < num_active_regs; i++) |
753 | init(i); |
754 | if (loop_unroll) { |
755 | if (jbp_->is_spatial_thr_) { |
756 | mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]); |
757 | add(reg_soff, ptr[rsp + stack_off_s_s]); |
758 | } else { |
759 | mov(reg_ctr, loop_unroll); |
760 | } |
761 | Label label; |
762 | L(label); |
763 | { |
764 | for (size_t i = 0; i < factor; i++) { |
765 | size_t base_reg = i % regs; |
766 | body(base_reg, i); |
767 | } |
768 | add(reg_soff, factor * spat_step); |
769 | sub(reg_ctr, factor); |
770 | jnz(label); |
771 | } |
772 | if (jbp_->is_spatial_thr_) { |
773 | add(reg_soff, ptr[rsp + stack_off_s_tail]); |
774 | } |
775 | } |
776 | |
777 | for (size_t i = 0; i < loop_tail; i++) { |
778 | size_t base_reg = i % regs; |
779 | body(base_reg, i); |
780 | } |
781 | if (loop_tail) add(reg_soff, loop_tail * spat_step); |
782 | |
783 | for (size_t i = 0; i < num_active_regs; i++) |
784 | fini(i); |
785 | } |
786 | |
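    // Accumulate per-channel-block sums over this thread's mb/spatial slice
    // into rbuf1 (reduced across threads later); even registers hold partial
    // sums, odd registers stage the loaded data.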
787 | void mean_channels() { |
788 | Label ch_label; |
789 | L(ch_label); |
790 | { |
791 | uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); |
792 | spat_loop( |
793 | spat_size, unroll_blocks, unroll_regs, |
794 | [=](size_t base_reg) { |
795 | Vmm v = Vmm(base_reg * 2); |
796 | if (base_reg) uni_vpxor(v, v, v); |
797 | }, |
798 | [=](size_t base_reg, size_t i) { |
799 | Vmm v0 = Vmm(base_reg * 2 + 0); |
800 | Vmm v1 = Vmm(base_reg * 2 + 1); |
801 | size_t offt = i * vlen_spat_data_; |
802 | uni_vmovups_spat_data( |
803 | v1, vmmword[reg_src + reg_soff + offt]); |
804 | uni_vaddps(v0, v0, v1); |
805 | }, |
806 | [=](size_t base_reg) { |
807 | Vmm b = Vmm(0); |
808 | Vmm v = Vmm(base_reg * 2); |
809 | if (base_reg) uni_vaddps(b, b, v); |
810 | }); |
811 | uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); |
812 | |
813 | add(reg_coff, vlen); |
814 | cmp(reg_coff, reg_coff_max); |
815 | jl(ch_label); |
816 | } |
817 | } |
818 | |
819 | void mean_variance_nspc( |
820 | const int num_ch_blks, int num_spat_pts, bool compute_mean) { |
821 | |
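        // Register budget: Vmm(0 .. num_ch_blks-1) hold the per-channel
        // accumulators; the variance pass additionally keeps the pre-loaded
        // means in Vmm(num_ch_blks .. 2*num_ch_blks-1).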
822 | auto mean_compute_avx2_ne_xf16 = [=](int num_ch_blks, |
823 | int num_spat_pts) { |
824 | for (int spat_pt = 0; spat_pt < num_spat_pts; ++spat_pt) { |
825 | for (int ch_idx = 0; ch_idx < num_ch_blks; ch_idx += 2) { |
826 | const int offt = ch_idx * vlen_spat_data_; |
827 | const bool is_ch_blks_tail = num_ch_blks - ch_idx < 2; |
828 | const Vmm vsrc_even = vtmp; |
829 | const Vmm vsrc_odd = vsrc_aux; |
830 | if (is_ch_blks_tail) |
831 | uni_vmovups_spat_data(vsrc_even, |
832 | vmmword[reg_src + reg_soff_nspc + offt]); |
833 | else |
834 | uni_vmovups_spat_data(vsrc_even, vsrc_odd, |
835 | vmmword[reg_src + reg_soff_nspc + offt]); |
836 | |
837 | uni_vaddps(Vmm(ch_idx), Vmm(ch_idx), vsrc_even); |
838 | if (!is_ch_blks_tail) |
839 | uni_vaddps(Vmm(ch_idx + 1), Vmm(ch_idx + 1), vsrc_odd); |
840 | } |
841 | add(reg_soff_nspc, spat_step); |
842 | } |
843 | }; |
844 | |
845 | auto variance_compute_avx2_ne_xf16 = [=](int num_ch_blks, |
846 | int num_spat_pts) { |
847 | for (int spat_pt = 0; spat_pt < num_spat_pts; ++spat_pt) { |
848 | for (int ch_idx = 0; ch_idx < num_ch_blks; ch_idx += 2) { |
849 | const int offt = ch_idx * vlen_spat_data_; |
850 | const bool is_ch_blks_tail = num_ch_blks - ch_idx < 2; |
851 | const Vmm vsrc_even = vtmp; |
852 | const Vmm vsrc_odd = vsrc_aux; |
853 | const Vmm vmean_ch_even = Vmm(ch_idx + num_ch_blks); |
854 | const Vmm vmean_ch_odd = Vmm(ch_idx + 1 + num_ch_blks); |
855 | if (is_ch_blks_tail) |
856 | uni_vmovups_spat_data(vsrc_even, |
857 | vmmword[reg_src + reg_soff_nspc + offt]); |
858 | else |
859 | uni_vmovups_spat_data(vsrc_even, vsrc_odd, |
860 | vmmword[reg_src + reg_soff_nspc + offt]); |
861 | uni_vsubps(vsrc_even, vsrc_even, vmean_ch_even); |
862 | uni_vfmadd231ps(Vmm(ch_idx), vsrc_even, vsrc_even); |
863 | if (!is_ch_blks_tail) { |
864 | uni_vsubps(vsrc_odd, vsrc_odd, vmean_ch_odd); |
865 | uni_vfmadd231ps(Vmm(ch_idx + 1), vsrc_odd, vsrc_odd); |
866 | } |
867 | } |
868 | add(reg_soff_nspc, spat_step); |
869 | } |
870 | }; |
871 | |
872 | auto mean_compute = [=](int num_ch_blks, int num_spat_pts) { |
873 | for (int spat_pt = 0; spat_pt < num_spat_pts; ++spat_pt) { |
874 | for (int ch_idx = 0; ch_idx < num_ch_blks; ++ch_idx) { |
875 | const int offt = ch_idx * vlen_spat_data_; |
876 | const Vmm vsrc = vtmp; |
877 | uni_vmovups_spat_data( |
878 | vsrc, vmmword[reg_src + reg_soff_nspc + offt]); |
879 | uni_vaddps(Vmm(ch_idx), Vmm(ch_idx), vsrc); |
880 | } |
881 | add(reg_soff_nspc, spat_step); |
882 | } |
883 | }; |
884 | |
885 | auto variance_compute = [=](int num_ch_blks, int num_spat_pts) { |
886 | for (int spat_pt = 0; spat_pt < num_spat_pts; ++spat_pt) { |
887 | for (int ch_idx = 0; ch_idx < num_ch_blks; ++ch_idx) { |
888 | const int offt = ch_idx * vlen_spat_data_; |
889 | const Vmm vsrc = vtmp; |
890 | const Vmm vmean_ch = Vmm(ch_idx + num_ch_blks); |
891 | uni_vmovups_spat_data( |
892 | vsrc, vmmword[reg_src + reg_soff_nspc + offt]); |
893 | uni_vsubps(vsrc, vsrc, vmean_ch); |
894 | uni_vfmadd231ps(Vmm(ch_idx), vsrc, vsrc); |
895 | } |
896 | add(reg_soff_nspc, spat_step); |
897 | } |
898 | }; |
899 | |
900 | for (int idx = 0; idx < num_ch_blks; ++idx) { |
901 | const int coff = idx * vlen; |
902 | uni_vmovups(Vmm(idx), vmmword[reg_rbuf1 + reg_coff + coff]); |
903 | if (!compute_mean) { |
904 | // pre-load mean to avoid extra data movement during variance |
905 | const Vmm vmean_ch = Vmm(idx + num_ch_blks); |
906 | uni_vmovups_maybe_tail(vmean_ch, mean_ptr(coff)); |
907 | } |
908 | } |
909 | |
910 | xor_(reg_soff_nspc, reg_soff_nspc); |
911 | |
912 | if (jbp_->is_spatial_thr_) { |
913 | mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]); |
914 | add(reg_soff_nspc, ptr[rsp + stack_off_s_s]); |
915 | // TODO: need a better heuristic for num_spat_pts |
916 | num_spat_pts = 1; |
917 | } else { |
918 | mov(reg_ctr, spat_size); |
919 | num_spat_pts = nstl::min((size_t)num_spat_pts, spat_size); |
920 | // TODO: unroll by spatial |
921 | if (spat_size % num_spat_pts != 0) num_spat_pts = 1; |
922 | } |
923 | |
924 | Label spatial; |
925 | L(spatial); |
926 | { |
927 | if (is_avx2_ne_xf16_) |
928 | compute_mean |
929 | ? mean_compute_avx2_ne_xf16(num_ch_blks, num_spat_pts) |
930 | : variance_compute_avx2_ne_xf16( |
931 | num_ch_blks, num_spat_pts); |
932 | else |
933 | compute_mean ? mean_compute(num_ch_blks, num_spat_pts) |
934 | : variance_compute(num_ch_blks, num_spat_pts); |
935 | sub(reg_ctr, num_spat_pts); |
936 | jnz(spatial, T_NEAR); |
937 | } |
938 | |
939 | for (int idx = 0; idx < num_ch_blks; ++idx) { |
940 | const int coff = idx * vlen; |
941 | uni_vmovups(vmmword[reg_rbuf1 + reg_coff + coff], Vmm(idx)); |
942 | } |
943 | } |
944 | |
945 | void forward_channels_nspc_compute(const int num_ch_blks) { |
946 | auto compute = [=](bool stream_store_allowed) { |
            // vzero was overwritten during the mean and variance computation
948 | uni_vpxor(vzero, vzero, vzero); |
949 | |
950 | xor_(reg_soff_nspc, reg_soff_nspc); |
951 | |
952 | if (jbp_->is_spatial_thr_) { |
953 | mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]); |
954 | add(reg_soff_nspc, ptr[rsp + stack_off_s_s]); |
955 | } else { |
956 | mov(reg_ctr, spat_size); |
957 | } |
958 | |
959 | // TODO: spatial blocking |
960 | const int num_spat_pts = 1; |
961 | |
            // pre-compute the scale for each channel to avoid costly div and
            // sqrt; merge variances from interleaved to plain layout if needed
964 | for (int idx = 0; idx < num_ch_blks; idx += 2) { |
965 | const int coff_base = idx * vlen; |
966 | const bool is_ch_blks_tail = num_ch_blks - idx < 2; |
967 | const Vmm vvar_even = Vmm(idx); |
968 | const Vmm vvar_odd = Vmm(idx + 1); |
969 | if (!is_ch_blks_tail) { |
970 | uni_vmovups_maybe_tail(vvar_even, var_ptr(coff_base)); |
971 | uni_vmovups_maybe_tail(vvar_odd, var_ptr(coff_base + vlen)); |
972 | if (is_avx2_ne_xf16_ && !pd_->stats_is_src()) |
973 | merge_interleaved_to_plain(vvar_even, vvar_odd, vtmp); |
974 | } else |
975 | uni_vmovups_maybe_tail(vvar_even, var_ptr(coff_base)); |
976 | |
977 | for (int i_odd = 0; i_odd < 2 && idx + i_odd < num_ch_blks; |
978 | ++i_odd) { |
979 | const int coff = coff_base + i_odd * vlen; |
980 | const Vmm vscale = Vmm(idx + i_odd + num_ch_blks); |
981 | const Vmm vvar = i_odd ? vvar_odd : vvar_even; |
982 | uni_vmovups(vsqrtvar, vvar); |
983 | uni_vaddps(vsqrtvar, vsqrtvar, veps); |
984 | uni_vsqrtps(vsqrtvar, vsqrtvar); |
985 | |
986 | if (pd_->use_scale()) { |
987 | uni_vmovups_maybe_tail(vgamma, gamma_ptr(coff)); |
988 | uni_vdivps(vscale, vgamma, vsqrtvar, vtmp); |
989 | } else { |
990 | uni_vdivps(vscale, vone, vsqrtvar, vtmp); |
991 | } |
992 | } |
993 | } |
994 | |
995 | Label spatial; |
996 | L(spatial); |
997 | { |
998 | if (is_avx2_ne_xf16_) { |
999 | for (int idx = 0; idx < num_ch_blks; idx += 2) { |
1000 | const int offt = idx * vlen_spat_data_; |
1001 | const int coff = idx * vlen; |
1002 | const bool is_ch_blks_tail = num_ch_blks - idx < 2; |
1003 | Vmm vdata_even = Vmm(idx); |
1004 | Vmm vdata_odd = Vmm(idx + 1); |
1005 | if (is_ch_blks_tail) { |
1006 | uni_vmovups_spat_data(vdata_even, |
1007 | vmmword[reg_src + reg_soff_nspc + offt]); |
1008 | if (!pd_->stats_is_src()) |
1009 | uni_vsubps( |
1010 | vdata_even, vdata_even, mean_ptr(coff)); |
1011 | } else { |
1012 | uni_vmovups_spat_data(vdata_even, vdata_odd, |
1013 | vmmword[reg_src + reg_soff_nspc + offt]); |
                            // subtract the interleaved mean from the
                            // interleaved data before merging to the
                            // plain layout, when needed
1016 | if (!pd_->stats_is_src()) { |
1017 | uni_vsubps( |
1018 | vdata_even, vdata_even, mean_ptr(coff)); |
1019 | uni_vsubps(vdata_odd, vdata_odd, |
1020 | mean_ptr(coff + vlen)); |
1021 | } |
1022 | merge_interleaved_to_plain( |
1023 | vdata_even, vdata_odd, vtmp); |
1024 | } |
1025 | } |
1026 | } |
1027 | |
1028 | for (int idx = 0; idx < num_ch_blks; ++idx) { |
1029 | const int coff = idx * vlen; |
1030 | const int offt = idx * vlen_spat_data_; |
1031 | const Vmm vdata = Vmm(idx); |
1032 | const Vmm vscale = Vmm(idx + num_ch_blks); |
1033 | uni_vmovups_maybe_tail(vmean, mean_ptr(coff)); |
1034 | |
1035 | if (pd_->use_shift()) { |
1036 | uni_vmovups_maybe_tail(vbeta, beta_ptr(coff)); |
1037 | } |
1038 | |
1039 | if (!is_avx2_ne_xf16_) |
1040 | uni_vmovups_spat_data( |
1041 | vdata, vmmword[reg_src + reg_soff_nspc + offt]); |
1042 | if (IMPLICATION(is_avx2_ne_xf16_, pd_->stats_is_src())) |
1043 | uni_vsubps(vdata, vdata, vmean); |
1044 | |
1045 | if (pd_->use_shift()) { |
1046 | // --flags=S,CH,H |
1047 | uni_vfmadd213ps(vdata, vscale, vbeta); |
1048 | } else { |
1049 | // --flags=,C |
1050 | uni_vmulps(vdata, vdata, vscale); |
1051 | } |
1052 | |
1053 | if (with_relu_inf_only) { // --attr=post_ops='relu' |
1054 | if (pd_->alpha() != 0.f) |
1055 | fwd_process_relu_alpha(vdata); |
1056 | else |
1057 | uni_vmaxps(vdata, vdata, vzero); |
1058 | } else if (with_relu) { // --flags=R |
1059 | if (isa == avx512_core) |
1060 | fwd_process_relu_avx512_common(vdata, offt); |
1061 | else if (isa == avx2) |
1062 | fwd_process_relu_avx2(vdata, offt); |
1063 | else |
1064 | assert(false); |
1065 | } |
1066 | uni_vmovups_spat_data( |
1067 | vmmword[reg_dst + reg_soff_nspc + offt], vdata, |
1068 | stream_store_allowed); |
1069 | } |
1070 | add(reg_soff_nspc, spat_step); |
1071 | sub(reg_ctr, num_spat_pts); |
1072 | jnz(spatial, T_NEAR); |
1073 | } |
1074 | }; |
1075 | |
1076 | if (stream_store_supported()) { |
1077 | Label normal_store, end_store; |
1078 | test(reg_dst, vlen_spat_data_ - 1); |
1079 | jnz(normal_store, T_NEAR); |
1080 | compute(true); |
1081 | jmp(end_store, T_NEAR); |
1082 | L(normal_store); |
1083 | { compute(false); } |
1084 | L(end_store); |
1085 | } else { |
1086 | compute(false); // disabled for bf16 when data fits in cache |
1087 | } |
1088 | } |
1089 | |
1090 | void compute_mean_variance_nspc(bool compute_mean = true) { |
1091 | xor_(reg_coff, reg_coff); |
1092 | mov(reg_coff_max_fwd_copy, reg_coff_max); |
1093 | |
1094 | Label ch_unroll_label[5]; |
1095 | const int max_ch_unroll = isa == avx512_core ? 4 : 2; |
1096 | |
1097 | // TODO: Spatial and channel unrolling decisions should be made during |
1098 | // initialization depending on the problem size |
1099 | for (int ch_idx = max_ch_unroll, sp_idx = 1; ch_idx > 0; |
1100 | --ch_idx, ++sp_idx) { |
1101 | L(ch_unroll_label[ch_idx]); |
1102 | { |
1103 | const int ch_blk_size = (1 << (ch_idx - 1)); // 8, 4, 2, 1 |
1104 | cmp(reg_coff_max, vlen * ch_blk_size); |
1105 | jl(ch_unroll_label[ch_idx - 1], T_NEAR); |
1106 | |
1107 | const int spat_blk_size = (1 << sp_idx); |
1108 | mean_variance_nspc(ch_blk_size, spat_blk_size, compute_mean); |
1109 | |
1110 | add(reg_src, vlen_spat_data_ * ch_blk_size); |
1111 | add(reg_coff, vlen * ch_blk_size); |
1112 | |
1113 | sub(reg_coff_max, vlen * ch_blk_size); |
1114 | jmp(ch_unroll_label[ch_idx], T_NEAR); |
1115 | } |
1116 | } |
1117 | L(ch_unroll_label[0]); |
1118 | |
1119 | // comeback |
1120 | mov(reg_coff_max, reg_coff_max_fwd_copy); |
1121 | |
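        // reg_coff_max counts f32 bytes (4 per channel) while reg_src was
        // advanced in source-element bytes, which is half that for xf16;
        // temporarily halve it to rewind reg_src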
1122 | if (is_xf16()) shr(reg_coff_max, 1); |
1123 | sub(reg_src, reg_coff_max); |
1124 | if (is_xf16()) shl(reg_coff_max, 1); |
1125 | } |
1126 | |
1127 | void var_channels() { |
1128 | Label ch_label; |
1129 | L(ch_label); |
1130 | { |
1131 | uni_vmovups_maybe_tail(vmean, mean_ptr()); |
1132 | uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); |
1133 | spat_loop( |
1134 | spat_size, unroll_blocks, unroll_regs, |
1135 | [=](size_t base_reg) { |
1136 | Vmm v = Vmm(base_reg * 3); |
1137 | if (base_reg > 0) uni_vpxor(v, v, v); |
1138 | }, |
1139 | [=](size_t base_reg, size_t i) { |
1140 | Vmm v = Vmm(3 * base_reg); |
1141 | Vmm vtmp0 = Vmm(3 * base_reg + 1); |
1142 | Vmm vtmp1 = Vmm(3 * base_reg + 2); |
1143 | size_t offt = i * vlen_spat_data_; |
1144 | uni_vmovups_spat_data( |
1145 | vtmp0, vmmword[reg_src + reg_soff + offt]); |
1146 | if (isa == sse41) { |
1147 | movups(vtmp1, vmean); |
1148 | subps(vtmp1, vtmp0); |
1149 | } else { |
1150 | vsubps(vtmp1, vmean, vtmp0); |
1151 | } |
1152 | uni_vfmadd231ps(v, vtmp1, vtmp1); |
1153 | }, |
1154 | [=](size_t base_reg) { |
1155 | Vmm b = Vmm(0); |
1156 | Vmm v = Vmm(base_reg * 3); |
1157 | if (base_reg) uni_vaddps(b, b, v); |
1158 | }); |
1159 | uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); |
1160 | add(reg_coff, vlen); |
1161 | cmp(reg_coff, reg_coff_max); |
1162 | jl(ch_label); |
1163 | } |
1164 | } |
1165 | |
1166 | void compute_mean_variance() { |
1167 | uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); |
1168 | xor_(reg_coff, reg_coff); |
1169 | Label zero_rbuf; |
1170 | L(zero_rbuf); |
1171 | { |
1172 | uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); |
1173 | add(reg_coff, isa == sse41 ? vlen / 2 : vlen); |
1174 | cmp(reg_coff, reg_coff_max); |
1175 | jne(zero_rbuf); |
1176 | } |
1177 | |
1178 | mov(reg_src, ptr[rsp + stack_off_src]); |
1179 | |
1180 | xor_(reg_soff, reg_soff); |
1181 | Label mean_spatial; |
1182 | L(mean_spatial); |
1183 | { |
1184 | xor_(reg_coff, reg_coff); |
1185 | |
1186 | if (isa == sse41) mov(reg_tmp_off, reg_soff); |
1187 | |
1188 | jbp_->is_nspc_ ? compute_mean_variance_nspc() : mean_channels(); |
1189 | |
1190 | if (isa == sse41) { |
1191 | mov(reg_soff, reg_tmp_off); |
1192 | add(reg_src, vlen / 2); |
1193 | mov(reg_coff, vlen / 2); |
1194 | |
1195 | mean_channels(); |
1196 | |
1197 | sub(reg_src, vlen / 2); |
1198 | } |
1199 | |
1200 | // Process next image |
1201 | if (jbp_->is_nspc_) { |
                // Can use a static offset since we come back after the
                // spatial loop
1203 | add(reg_src, mb_offt); |
1204 | add(reg_soff, mb_offt); |
1205 | } else { |
1206 | add(reg_soff, reg_mb_stride_Bc); |
1207 | } |
1208 | |
1209 | cmp(reg_soff, reg_soff_max); |
1210 | jl(mean_spatial); |
1211 | } |
1212 | |
1213 | if (jbp_->is_nspc_) mov(reg_src, ptr[rsp + stack_off_src]); // comeback |
1214 | |
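        // Cross-thread reduction: after the barrier only the thread with
        // N_ithr == 0 sums the per-thread partial buffers (rbuf1) over
        // N_nthr threads and stores the mean; the second barrier releases
        // the rest.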
1215 | Label no_mean_reduction; |
1216 | barrier(); |
1217 | { |
1218 | mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); |
1219 | cmp(reg_tmp, 0); |
1220 | jne(no_mean_reduction); |
1221 | mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); |
1222 | xor_(reg_coff, reg_coff); |
1223 | Label mean_reduction_channels; |
1224 | L(mean_reduction_channels); |
1225 | { |
1226 | mov(reg_roff, reg_coff); |
1227 | uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); |
1228 | uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); |
1229 | mov(reg_ctr, reg_nnthr); |
1230 | Label mean_reduction_thrs; |
1231 | L(mean_reduction_thrs); |
1232 | { |
1233 | uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); |
1234 | uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0)); |
1235 | add(reg_roff, reg_coff_max); |
1236 | sub(reg_ctr, 1); |
1237 | jnz(mean_reduction_thrs); |
1238 | } |
1239 | uni_vdivps(Vmm(1), Vmm(1), vchan_size); |
1240 | uni_vmovups_maybe_tail(mean_ptr(), Vmm(1)); |
1241 | |
1242 | add(reg_coff, isa == sse41 ? vlen / 2 : vlen); |
1243 | |
1244 | cmp(reg_coff, reg_coff_max); |
1245 | jl(mean_reduction_channels); |
1246 | } |
1247 | } |
1248 | L(no_mean_reduction); |
1249 | barrier(); |
1250 | |
1251 | xor_(reg_soff, reg_soff); |
1252 | Label var_spatial; |
1253 | L(var_spatial); |
1254 | { |
1255 | xor_(reg_coff, reg_coff); |
1256 | |
1257 | if (isa == sse41) mov(reg_tmp_off, reg_soff); |
1258 | |
1259 | jbp_->is_nspc_ ? compute_mean_variance_nspc(false) : var_channels(); |
1260 | |
1261 | if (isa == sse41) { |
1262 | mov(reg_soff, reg_tmp_off); |
1263 | add(reg_src, vlen / 2); |
1264 | mov(reg_coff, vlen / 2); |
1265 | |
1266 | var_channels(); |
1267 | |
1268 | sub(reg_src, vlen / 2); |
1269 | } |
1270 | |
1271 | // Process next image |
1272 | if (jbp_->is_nspc_) { |
                // Can use a static offset since we come back after the
                // spatial loop
1274 | add(reg_src, mb_offt); |
1275 | add(reg_soff, mb_offt); |
1276 | } else { |
1277 | add(reg_soff, reg_mb_stride_Bc); |
1278 | } |
1279 | |
1280 | cmp(reg_soff, reg_soff_max); |
1281 | jl(var_spatial); |
1282 | } |
1283 | |
1284 | if (jbp_->is_nspc_) mov(reg_src, ptr[rsp + stack_off_src]); // comeback |
1285 | |
1286 | Label no_var_reduction; |
1287 | barrier(); |
1288 | { |
1289 | mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); |
1290 | cmp(reg_tmp, 0); |
1291 | jne(no_var_reduction); |
1292 | |
1293 | mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); |
1294 | xor_(reg_coff, reg_coff); |
1295 | Label var_reduction_channels; |
1296 | L(var_reduction_channels); |
1297 | { |
1298 | mov(reg_roff, reg_coff); |
1299 | uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); |
1300 | mov(reg_ctr, reg_nnthr); |
1301 | Label var_reduction_thrs; |
1302 | L(var_reduction_thrs); |
1303 | { // TODO: unroll (?) |
1304 | uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); |
1305 | add(reg_roff, reg_coff_max); |
1306 | sub(reg_ctr, 1); |
1307 | jnz(var_reduction_thrs); |
1308 | } |
1309 | uni_vdivps(Vmm(1), Vmm(1), vchan_size); |
1310 | uni_vmovups_maybe_tail(var_ptr(), Vmm(1)); |
1311 | add(reg_coff, isa == sse41 ? vlen / 2 : vlen); |
1312 | |
1313 | cmp(reg_coff, reg_coff_max); |
1314 | jne(var_reduction_channels); |
1315 | } |
1316 | } |
1317 | L(no_var_reduction); |
1318 | barrier(); |
1319 | } |
1320 | |
1321 | void forward_channels() { |
1322 | Label ch_label; |
1323 | L(ch_label); |
1324 | { |
1325 | uni_vmovups_maybe_tail(vmean, mean_ptr()); |
1326 | uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); |
1327 | uni_vaddps(vsqrtvar, vsqrtvar, veps); |
1328 | uni_vsqrtps(vsqrtvar, vsqrtvar); |
1329 | |
1330 | if (pd_->use_scale()) { |
1331 | uni_vmovups_maybe_tail(vgamma, gamma_ptr()); |
1332 | } |
1333 | if (pd_->use_shift()) { uni_vmovups_maybe_tail(vbeta, beta_ptr()); } |
1334 | |
1335 | Vmm vscale = (pd_->use_scale()) ? vgamma : vone; |
1336 | Vmm vdiv = (pd_->use_scale()) ? vgamma : vsqrtvar; |
1337 | |
1338 | if (isa == sse41) { |
1339 | movups(vtmp, vscale); |
1340 | divps(vtmp, vsqrtvar); |
1341 | movups(vdiv, vtmp); |
1342 | } else { |
1343 | vdivps(vdiv, vscale, vsqrtvar); |
1344 | } |
1345 | |
1346 | const auto spat_loop_init_fin |
1347 | = [](size_t base_reg) { UNUSED(base_reg); }; |
1348 | |
1349 | const auto spat_loop_body = [=](size_t base_reg, size_t i, |
1350 | bool stream_store_allowed) { |
1351 | const Vmm v = Vmm(base_reg); |
1352 | const size_t offt = i * vlen_spat_data_; |
1353 | uni_vmovups_spat_data(v, vmmword[reg_src + reg_soff + offt]); |
1354 | uni_vsubps(v, v, vmean); |
1355 | if ((pd_->use_scale() && pd_->use_shift())) { |
1356 | // --flags=CH |
1357 | uni_vfmadd213ps(v, vgamma, vbeta); |
1358 | } else if (pd_->use_scale()) { |
1359 | // --flags=C |
1360 | uni_vmulps(v, v, vgamma); |
1361 | } else if (pd_->use_shift()) { |
1362 | // --flags=H |
1363 | uni_vfmadd213ps(v, vsqrtvar, vbeta); |
1364 | } else { |
1365 | uni_vmulps(v, v, vsqrtvar); |
1366 | } |
1367 | if (with_relu_inf_only) { // --attr=post_ops='relu' |
1368 | if (pd_->alpha() != 0.f) { |
1369 | fwd_process_relu_alpha(v); |
1370 | } else |
1371 | uni_vmaxps(v, v, vzero); |
1372 | } else if (with_relu) { // --flags=R |
1373 | if (isa == avx512_core) |
1374 | fwd_process_relu_avx512_common(v, offt); |
1375 | else |
1376 | fwd_process_relu_avx2(v, offt); |
1377 | } |
1378 | if (stream_store_allowed) { |
1379 | uni_vmovntps(vmmword[reg_dst + reg_soff + offt], v); |
1380 | } else { |
1381 | uni_vmovups_spat_data( |
1382 | vmmword[reg_dst + reg_soff + offt], v); |
1383 | } |
1384 | }; |
1385 | |
1386 | const auto compute = [=](bool stream_store_allowed) { |
1387 | using namespace std::placeholders; |
1388 | spat_loop(spat_size, unroll_blocks, unroll_regs, |
1389 | spat_loop_init_fin, |
1390 | std::bind(spat_loop_body, _1, _2, stream_store_allowed), |
1391 | spat_loop_init_fin); |
1392 | }; |
1393 | |
1394 | if (stream_store_supported()) { |
1395 | Label normal_store, end_store; |
1396 | test(reg_dst, vlen - 1); |
1397 | jnz(normal_store, T_NEAR); |
1398 | compute(true); |
1399 | jmp(end_store, T_NEAR); |
1400 | L(normal_store); |
1401 | { compute(false); } |
1402 | L(end_store); |
1403 | } else { |
1404 | compute(false); // no NT store for BF16 |
1405 | } |
1406 | |
1407 | add(reg_coff, vlen); |
1408 | cmp(reg_coff, reg_coff_max); |
1409 | jl(ch_label); |
1410 | } |
1411 | } |
1412 | |
1413 | void forward_channels_nspc() { |
1414 | xor_(reg_coff, reg_coff); |
1415 | mov(reg_coff_max_fwd_copy, reg_coff_max); |
1416 | |
1417 | Label ch_unroll_label[5]; |
1418 | const int max_ch_unroll |
1419 | = isa == avx512_core ? 4 - use_bf16_emulation() : 2; |
1420 | |
1421 | // TODO: Spatial and channel unrolling decisions should be made during |
1422 | // initialization depending on the problem size |
1423 | for (int ch_idx = max_ch_unroll; ch_idx > 0; --ch_idx) { |
1424 | L(ch_unroll_label[ch_idx]); |
1425 | { |
1426 | const int ch_blk_size = (1 << (ch_idx - 1)); // 8, 4, 2, 1 |
1427 | cmp(reg_coff_max, vlen * ch_blk_size); |
1428 | jl(ch_unroll_label[ch_idx - 1], T_NEAR); |
1429 | |
1430 | forward_channels_nspc_compute(ch_blk_size); |
1431 | |
1432 | add(reg_src, vlen_spat_data_ * ch_blk_size); |
1433 | add(reg_dst, vlen_spat_data_ * ch_blk_size); |
1434 | |
1435 | // advance mean_ptr() and var_ptr() |
1436 | add(reg_coff, vlen * ch_blk_size); |
1437 | |
1438 | add(reg_ws, (vlen / 32) * ch_blk_size); |
1439 | |
1440 | sub(reg_coff_max, vlen * ch_blk_size); |
1441 | jmp(ch_unroll_label[ch_idx], T_NEAR); |
1442 | } |
1443 | } |
1444 | L(ch_unroll_label[0]); |
1445 | |
1446 | // comeback |
1447 | mov(reg_coff_max, reg_coff_max_fwd_copy); |
1448 | |
1449 | if (is_xf16()) shr(reg_coff_max, 1); |
1450 | sub(reg_src, reg_coff_max); |
1451 | sub(reg_dst, reg_coff_max); |
1452 | if (is_xf16()) shl(reg_coff_max, 1); |
1453 | |
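        // the workspace keeps one bit per element: 32 f32 bytes per ws byte,
        // so rewind reg_ws by coff_max / 32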
1454 | shr(reg_coff_max, 5); |
1455 | sub(reg_ws, reg_coff_max); |
1456 | shl(reg_coff_max, 5); |
1457 | } |
1458 | |
1459 | void forward() { |
1460 | mov(reg_src, ptr[rsp + stack_off_src]); |
1461 | mov(reg_dst, ptr[rsp + stack_off_dst]); |
1462 | mov(reg_ws, ptr[rsp + stack_off_ws]); |
1463 | mov(reg_shift, ptr[rsp + stack_off_shift]); |
1464 | |
1465 | xor_(reg_soff, reg_soff); |
1466 | Label dst_spatial; |
1467 | L(dst_spatial); |
1468 | { |
1469 | xor_(reg_coff, reg_coff); |
1470 | if (isa == sse41) mov(reg_tmp_off, reg_soff); |
1471 | |
1472 | jbp_->is_nspc_ ? forward_channels_nspc() : forward_channels(); |
1473 | |
1474 | if (isa == sse41) { |
1475 | mov(reg_soff, reg_tmp_off); |
1476 | add(reg_src, vlen / 2); |
1477 | add(reg_dst, vlen / 2); |
1478 | mov(reg_coff, vlen / 2); |
1479 | |
1480 | forward_channels(); |
1481 | |
1482 | sub(reg_src, vlen / 2); |
1483 | sub(reg_dst, vlen / 2); |
1484 | } |
1485 | |
1486 | // Process next image |
1487 | if (jbp_->is_nspc_) { |
                // Can use a static offset since we come back after the
                // spatial loop
1489 | add(reg_src, mb_offt); |
1490 | add(reg_dst, mb_offt); |
1491 | add(reg_soff, mb_offt); |
1492 | add(reg_ws, ws_mb_offt); |
1493 | } else { |
1494 | add(reg_soff, reg_mb_stride_Bc); |
1495 | } |
1496 | |
1497 | cmp(reg_soff, reg_soff_max); |
1498 | jl(dst_spatial); |
1499 | } |
1500 | |
1501 | if (jbp_->is_nspc_) { |
1502 | // comeback |
1503 | mov(reg_src, ptr[rsp + stack_off_src]); |
1504 | mov(reg_dst, ptr[rsp + stack_off_dst]); |
1505 | mov(reg_ws, ptr[rsp + stack_off_ws]); |
1506 | } |
1507 | } |
1508 | |
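    // Accumulate the per-channel reductions for bwd: diff_gamma partial
    // sums of (src - mean) * diff_dst into rbuf1 and diff_beta partial
    // sums of diff_dst into rbuf2.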
1509 | void backward_sh_channels() { |
1510 | Label sh_channels; |
1511 | L(sh_channels); |
1512 | { |
1513 | uni_vmovups_maybe_tail(vmean, mean_ptr()); |
1514 | uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); |
1515 | uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]); |
1516 | spat_loop( |
1517 | spat_size, 1, 1, |
1518 | [=](size_t base_reg) { |
1519 | if (base_reg > 0) { |
1520 | for (int i = 0; i < 2; i++) { |
1521 | Vmm v(base_reg * 5 + i); |
1522 | uni_vpxor(v, v, v); |
1523 | } |
1524 | } |
1525 | }, |
1526 | [=](size_t base_reg, size_t i) { |
                    // TODO: use a single set of tmp regs and let the ROB
                    // handle the rest
1528 | Vmm o0 = Vmm(base_reg * 5 + 0); |
1529 | Vmm o1 = Vmm(base_reg * 5 + 1); |
1530 | Vmm t1 = Vmm(base_reg * 5 + 2); |
1531 | Vmm t2 = Vmm(base_reg * 5 + 3); |
1532 | Vmm t3 = Vmm(base_reg * 5 + 4); |
1533 | size_t offt = i * vlen_spat_data_; |
1534 | uni_vmovups_spat_data( |
1535 | t1, vmmword[reg_src + reg_soff + offt]); |
1536 | uni_vmovups_spat_data( |
1537 | t2, vmmword[reg_diff_dst + reg_soff + offt]); |
1538 | if (with_relu) { |
1539 | if (isa == avx512_core) |
1540 | bwd_process_relu_avx512_common(t2, offt); |
1541 | else if (isa == avx2) |
1542 | bwd_process_relu_avx2(t2, offt); |
1543 | else |
1544 | assert(false); |
1545 | } |
1546 | uni_vsubps(t3, vmean, t1, t3); |
1547 | if (isa == sse41) { |
1548 | mulps(t3, t2); |
1549 | subps(o0, t3); |
1550 | } else { |
1551 | vfnmadd231ps(o0, t3, t2); |
1552 | } |
1553 | uni_vaddps(o1, o1, t2); |
1554 | }, |
1555 | [=](size_t base_reg) { |
1556 | Vmm b0 = Vmm(0); |
1557 | Vmm b1 = Vmm(1); |
1558 | if (base_reg) { |
1559 | uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0)); |
1560 | uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1)); |
1561 | } |
1562 | }); |
1563 | uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); |
1564 | uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1)); |
1565 | add(reg_coff, vlen); |
1566 | cmp(reg_coff, reg_coff_max); |
1567 | jl(sh_channels); |
1568 | } |
1569 | } |
1570 | |
1571 | void backward_sh_channels_nspc_compute(const int num_ch_blks) { |
1572 | for (int idx = 0; idx < num_ch_blks; ++idx) { |
1573 | const int offt = idx * vlen; |
1574 | const Vmm vdiff_gamma_ch = Vmm(idx); |
1575 | const Vmm vdiff_beta_ch = Vmm(idx + num_ch_blks); |
1576 | uni_vmovups(vdiff_gamma_ch, vmmword[reg_rbuf1 + reg_coff + offt]); |
1577 | uni_vmovups(vdiff_beta_ch, vmmword[reg_rbuf2 + reg_coff + offt]); |
1578 | } |
1579 | |
1580 | xor_(reg_soff_nspc, reg_soff_nspc); |
1581 | |
1582 | if (jbp_->is_spatial_thr_) { |
1583 | mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]); |
1584 | add(reg_soff_nspc, ptr[rsp + stack_off_s_s]); |
1585 | } else { |
1586 | mov(reg_ctr, spat_size); |
1587 | } |
1588 | |
1589 | // TODO: spatial blocking |
1590 | const int num_spat_pts = 1; |
1591 | |
1592 | Label spatial; |
1593 | L(spatial); |
1594 | { |
1595 | for (int ch_idx = 0; ch_idx < num_ch_blks; ++ch_idx) { |
1596 | const int coff = ch_idx * vlen; |
1597 | const int offt = ch_idx * vlen_spat_data_; |
1598 | const Vmm vdiff_gamma_ch = Vmm(ch_idx); |
1599 | const Vmm vdiff_beta_ch = Vmm(ch_idx + num_ch_blks); |
1600 | // vdiff_beta and vdiff_gamma are free registers for nspc |
1601 | const Vmm vsrc = vdiff_gamma; |
1602 | const Vmm vdiff_dst = vdiff_beta; |
1603 | uni_vmovups_maybe_tail(vmean, mean_ptr(coff)); |
1604 | |
1605 | uni_vmovups_spat_data( |
1606 | vsrc, vmmword[reg_src + reg_soff_nspc + offt]); |
1607 | uni_vmovups_spat_data(vdiff_dst, |
1608 | vmmword[reg_diff_dst + reg_soff_nspc + offt]); |
1609 | |
1610 | if (with_relu) { |
1611 | if (isa == avx512_core) |
1612 | bwd_process_relu_avx512_common(vdiff_dst, offt); |
1613 | else |
1614 | assert(false); |
1615 | } |
1616 | |
1617 | uni_vsubps(vsrc, vsrc, vmean); |
1618 | uni_vfmadd231ps(vdiff_gamma_ch, vsrc, vdiff_dst); |
1619 | uni_vaddps(vdiff_beta_ch, vdiff_beta_ch, vdiff_dst); |
1620 | } |
1621 | add(reg_soff_nspc, spat_step); |
1622 | sub(reg_ctr, num_spat_pts); |
1623 | jnz(spatial, T_NEAR); |
1624 | } |
1625 | |
1626 | for (int idx = 0; idx < num_ch_blks; ++idx) { |
1627 | const Vmm vdiff_gamma_ch = Vmm(idx); |
1628 | const Vmm vdiff_beta_ch = Vmm(idx + num_ch_blks); |
1629 | const int offt = idx * vlen; |
1630 | uni_vmovups(vmmword[reg_rbuf1 + reg_coff + offt], vdiff_gamma_ch); |
1631 | uni_vmovups(vmmword[reg_rbuf2 + reg_coff + offt], vdiff_beta_ch); |
1632 | } |
1633 | } |
1634 | |
1635 | void backward_sh_channels_nspc() { |
1636 | xor_(reg_coff, reg_coff); |
1637 | mov(reg_coff_max_bwd_copy, reg_coff_max); |
1638 | |
1639 | Label ch_unroll_label[5]; |
1640 | const int max_ch_unroll = 4; |
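        // Greedy unrolling: process as many of the largest channel blocks as
        // fit, then fall through to progressively smaller blocks so a tail
        // of any size is covered by the single-block case.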
1641 | |
1642 | // TODO: Spatial and channel unrolling decisions should be made during |
1643 | // initialization depending on the problem size |
1644 | for (int ch_idx = max_ch_unroll; ch_idx > 0; --ch_idx) { |
1645 | L(ch_unroll_label[ch_idx]); |
1646 | { |
1647 | const int ch_blk_size = (1 << (ch_idx - 1)); // 8, 4, 2, 1 |
1648 | cmp(reg_coff_max, vlen * ch_blk_size); |
1649 | jl(ch_unroll_label[ch_idx - 1], T_NEAR); |
1650 | |
1651 | backward_sh_channels_nspc_compute(ch_blk_size); |
1652 | |
1653 | add(reg_src, vlen_spat_data_ * ch_blk_size); |
1654 | add(reg_diff_dst, vlen_spat_data_ * ch_blk_size); |
1655 | |
1656 | // advance mean_ptr() and var_ptr() |
1657 | add(reg_coff, vlen * ch_blk_size); |
1658 | |
1659 | add(reg_ws, 2 * ch_blk_size); |
1660 | |
1661 | sub(reg_coff_max, vlen * ch_blk_size); |
1662 | jmp(ch_unroll_label[ch_idx], T_NEAR); |
1663 | } |
1664 | } |
1665 | L(ch_unroll_label[0]); |
1666 | |
1667 | // comeback |
1668 | mov(reg_coff_max, reg_coff_max_bwd_copy); |
1669 | mov(reg_diff_scale, ptr[rsp + stack_off_diff_scale]); |
1670 | |
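        // reg_coff_max counts f32 bytes, while the xf16 src/diff_dst
        // pointers advanced by half as much, hence the temporary shr/shl.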
1671 | if (is_xf16()) shr(reg_coff_max, 1); |
1672 | sub(reg_src, reg_coff_max); |
1673 | sub(reg_diff_dst, reg_coff_max); |
1674 | if (is_xf16()) shl(reg_coff_max, 1); |
1675 | |
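        // The ReLU workspace packs one bit per element, so its offset is
        // (f32 bytes) / 32; shift coff_max down to rewind reg_ws.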
1676 | if (with_relu) { |
1677 | shr(reg_coff_max, 5); |
1678 | sub(reg_ws, reg_coff_max); |
1679 | shl(reg_coff_max, 5); |
1680 | } |
1681 | } |
1682 | |
1683 | void backward_diff_channels() { |
1684 | Label diff_channels; |
1685 | L(diff_channels); |
1686 | { |
1687 | uni_vmovups_maybe_tail(vmean, mean_ptr()); |
1688 | uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); |
1689 | uni_vaddps(vsqrtvar, vsqrtvar, veps); |
1690 | uni_vsqrtps(vsqrtvar, vsqrtvar); |
1691 | uni_vdivps(vsqrtvar, vone, vsqrtvar, vtmp); |
1692 | if (pd_->use_scale()) uni_vmovups_maybe_tail(vgamma, gamma_ptr()); |
1693 | uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr()); |
1694 | uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr()); |
1695 | uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar); |
1696 | uni_vdivps(vdiff_beta, vdiff_beta, vchan_size); |
1697 | uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size); |
1698 | |
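            // With inv = 1 / sqrt(var + eps), the spatial loop below
            // computes, per element:
            //   diff_src = [gamma *] inv * (diff_dst - diff_beta / NHW
            //           - (src - mean) * inv * diff_gamma / NHW)
            // where diff_gamma/diff_beta are the reduced statistics, NHW is
            // chan_size, and the two correction terms are dropped when
            // global stats are used.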
1699 | const auto spat_loop_init_fin |
1700 | = [=](size_t base_reg) { UNUSED(base_reg); }; |
1701 | const auto spat_loop_body = [=](size_t base_reg, size_t i, |
1702 | bool stream_store_allowed) { |
1703 | const Vmm v(base_reg * 2 + 0); |
1704 | const Vmm t(base_reg * 2 + 1); |
1705 | const Vmm t1(base_reg * 2 + 2); |
1706 | const size_t offt = i * vlen_spat_data_; |
1707 | uni_vmovups_spat_data( |
1708 | v, vmmword[reg_diff_dst + reg_soff + offt]); |
1709 | if (with_relu) { |
1710 | if (isa == avx512_core) |
1711 | bwd_process_relu_avx512_common(v, offt); |
1712 | else if (isa == avx2) |
1713 | bwd_process_relu_avx2(v, offt); |
1714 | else |
1715 | assert(false); |
1716 | } |
1717 | if (!pd_->use_global_stats()) { |
1718 | uni_vsubps(v, v, vdiff_beta); |
1719 | uni_vmovups_spat_data( |
1720 | t, vmmword[reg_src + reg_soff + offt]); |
1721 | uni_vsubps(t, vmean, t, t1); |
1722 | uni_vmulps(t, t, vdiff_gamma); |
1723 | uni_vaddps(v, v, t); |
1724 | } |
1725 | uni_vmulps(v, v, vsqrtvar); |
1726 | if (pd_->use_scale()) { uni_vmulps(v, v, vgamma); } |
1727 | if (stream_store_allowed) { |
1728 | uni_vmovntps(vmmword[reg_diff_src + reg_soff + offt], v); |
1729 | } else { |
1730 | uni_vmovups_spat_data( |
1731 | vmmword[reg_diff_src + reg_soff + offt], v); |
1732 | } |
1733 | }; |
1734 | |
1735 | const auto compute = [=](bool stream_store_allowed) { |
1736 | using namespace std::placeholders; |
1737 | spat_loop(spat_size, unroll_blocks, unroll_regs, |
1738 | spat_loop_init_fin, |
1739 | std::bind(spat_loop_body, _1, _2, stream_store_allowed), |
1740 | spat_loop_init_fin); |
1741 | }; |
1742 | |
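            // Non-temporal stores require a vlen-byte-aligned destination;
            // test the low bits of reg_diff_src and fall back to regular
            // stores when unaligned.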
1743 | if (stream_store_supported()) { |
1744 | Label normal_store, end_store; |
1745 | test(reg_diff_src, vlen - 1); |
1746 | jnz(normal_store, T_NEAR); |
1747 | compute(true); |
1748 | jmp(end_store, T_NEAR); |
1749 | L(normal_store); |
1750 | { compute(false); } |
1751 | L(end_store); |
1752 | } else { |
1753 | compute(false); // no NT store for BF16 |
1754 | } |
1755 | |
1756 | add(reg_coff, vlen); |
1757 | cmp(reg_coff, reg_coff_max); |
1758 | jl(diff_channels); |
1759 | } |
1760 | } |
1761 | |
1762 | void backward_diff_channels_nspc_compute(const int num_ch_blks) { |
1763 | auto compute = [=](bool stream_store_allowed) { |
1764 | xor_(reg_soff_nspc, reg_soff_nspc); |
1765 | if (jbp_->is_spatial_thr_) { |
1766 | mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]); |
1767 | add(reg_soff_nspc, ptr[rsp + stack_off_s_s]); |
1768 | } else { |
1769 | mov(reg_ctr, spat_size); |
1770 | } |
1771 | |
1772 | // TODO: spatial blocking |
1773 | const int num_spat_pts = 1; |
1774 | |
1775 | // pre-compute scale for each channel to avoid costly div and sqrt |
1776 | if (!pd_->use_global_stats()) { |
1777 | mov(ptr[rsp + stack_off_ws_off_copy], reg_ws); |
1778 | mov(reg_ws, ptr[rsp + stack_off_diff_scale]); |
1779 | } |
1780 | for (int idx = 0; idx < num_ch_blks; ++idx) { |
1781 | const int coff = idx * vlen; |
1782 | const Vmm vsqrtvar_ch = Vmm(idx); |
1783 | uni_vmovups_maybe_tail(vsqrtvar_ch, var_ptr(coff)); |
1784 | uni_vaddps(vsqrtvar_ch, vsqrtvar_ch, veps); |
1785 | uni_vsqrtps(vsqrtvar_ch, vsqrtvar_ch); |
1786 | uni_vdivps(vsqrtvar_ch, vone, vsqrtvar_ch, vtmp); |
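                // vsqrtvar_ch now holds 1 / sqrt(var + eps) for this block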
1787 | if (!pd_->use_global_stats()) { |
1788 | const Vmm vdiff_beta_ch = Vmm(idx + num_ch_blks); |
1789 | const Vmm vdiff_gamma_ch = Vmm(idx + 2 * num_ch_blks); |
1790 | uni_vmovups_maybe_tail(vdiff_beta_ch, |
1791 | vmmword[reg_diff_shift + reg_coff + coff]); |
1792 | uni_vmovups_maybe_tail( |
1793 | vdiff_gamma_ch, vmmword[reg_ws + reg_coff + coff]); |
1794 | uni_vdivps(vdiff_beta_ch, vdiff_beta_ch, vchan_size); |
1795 | uni_vmulps(vdiff_gamma_ch, vdiff_gamma_ch, vsqrtvar_ch); |
1796 | uni_vdivps(vdiff_gamma_ch, vdiff_gamma_ch, vchan_size); |
1797 | } |
1798 | } |
1799 | if (!pd_->use_global_stats()) { |
1800 | mov(reg_ws, ptr[rsp + stack_off_ws_off_copy]); |
1801 | } |
1802 | |
1803 | Label spatial; |
1804 | L(spatial); |
1805 | { |
1806 | for (int idx = 0; idx < num_ch_blks; ++idx) { |
1807 | const int coff = idx * vlen; |
1808 | const int offt = idx * vlen_spat_data_; |
1809 | // vdiff_beta and vdiff_gamma are free registers for nspc |
1810 | const Vmm vdiff_data = vdiff_beta; |
1811 | const Vmm vdata = vdiff_gamma; |
1812 | const Vmm vsqrtvar_ch = Vmm(idx); |
1813 | uni_vmovups_maybe_tail(vmean, mean_ptr(coff)); |
1814 | |
1815 | if (pd_->use_scale()) |
1816 | uni_vmovups_maybe_tail(vgamma, gamma_ptr(coff)); |
1817 | |
1818 | uni_vmovups_spat_data(vdiff_data, |
1819 | vmmword[reg_diff_dst + reg_soff_nspc + offt]); |
1820 | |
1821 | if (with_relu) { |
1822 | if (isa == avx512_core) |
1823 | bwd_process_relu_avx512_common(vdiff_data, offt); |
1824 | else |
1825 | assert(false); |
1826 | } |
1827 | |
1828 | if (!pd_->use_global_stats()) { |
1829 | const Vmm vdiff_beta_ch = Vmm(idx + num_ch_blks); |
1830 | const Vmm vdiff_gamma_ch = Vmm(idx + 2 * num_ch_blks); |
1831 | uni_vsubps(vdiff_data, vdiff_data, vdiff_beta_ch); |
1832 | uni_vmovups_spat_data( |
1833 | vdata, vmmword[reg_src + reg_soff_nspc + offt]); |
1834 | uni_vsubps(vdata, vmean, vdata, vtmp); |
1835 | uni_vmulps(vdata, vdata, vdiff_gamma_ch); |
1836 | uni_vaddps(vdiff_data, vdiff_data, vdata); |
1837 | } |
1838 | |
1839 | uni_vmulps(vdiff_data, vdiff_data, vsqrtvar_ch); |
1840 | |
1841 | if (pd_->use_scale()) { |
1842 | uni_vmulps(vdiff_data, vdiff_data, vgamma); |
1843 | } |
1844 | |
1845 | uni_vmovups_spat_data( |
1846 | vmmword[reg_diff_src + reg_soff_nspc + offt], |
1847 | vdiff_data, stream_store_allowed); |
1848 | } |
1849 | add(reg_soff_nspc, spat_step); |
1850 | sub(reg_ctr, num_spat_pts); |
1851 | jnz(spatial, T_NEAR); |
1852 | } |
1853 | }; |
1854 | |
1855 | if (stream_store_supported()) { |
1856 | Label normal_store, end_store; |
1857 | test(reg_diff_src, vlen - 1); |
1858 | jnz(normal_store, T_NEAR); |
1859 | compute(true); |
1860 | jmp(end_store, T_NEAR); |
1861 | L(normal_store); |
1862 | { compute(false); } |
1863 | L(end_store); |
1864 | } else { |
1865 | compute(false); // disabled for bf16 when data fits in cache |
1866 | } |
1867 | } |
1868 | |
1869 | void backward_diff_channels_nspc() { |
1870 | xor_(reg_coff, reg_coff); |
1871 | mov(reg_coff_max_bwd_copy, reg_coff_max); |
1872 | |
1873 | Label ch_unroll_label[5]; |
1874 | const int max_ch_unroll = 3; |
1875 | |
1876 | // TODO: Spatial and channel unrolling decisions should be made during |
1877 | // initialization depending on the problem size |
1878 | for (int ch_idx = max_ch_unroll; ch_idx > 0; --ch_idx) { |
1879 | L(ch_unroll_label[ch_idx]); |
1880 | { |
1881 | const int ch_blk_size = (1 << (ch_idx - 1)); // 4, 2, 1 |
1882 | cmp(reg_coff_max, vlen * ch_blk_size); |
1883 | jl(ch_unroll_label[ch_idx - 1], T_NEAR); |
1884 | |
1885 | backward_diff_channels_nspc_compute(ch_blk_size); |
1886 | |
1887 | add(reg_diff_dst, vlen_spat_data_ * ch_blk_size); |
1888 | if (!pd_->use_global_stats()) |
1889 | add(reg_src, vlen_spat_data_ * ch_blk_size); |
1890 | add(reg_diff_src, vlen_spat_data_ * ch_blk_size); |
1891 | |
1892 | // advance mean_ptr() and var_ptr() |
1893 | add(reg_coff, vlen * ch_blk_size); |
1894 | |
1895 | add(reg_ws, 2 * ch_blk_size); |
1896 | |
1897 | sub(reg_coff_max, vlen * ch_blk_size); |
1898 | jmp(ch_unroll_label[ch_idx], T_NEAR); |
1899 | } |
1900 | } |
1901 | L(ch_unroll_label[0]); |
1902 | |
1903 | // comeback |
1904 | mov(reg_coff_max, reg_coff_max_bwd_copy); |
1905 | mov(reg_diff_scale, ptr[rsp + stack_off_diff_scale]); |
1906 | |
1907 | if (is_xf16()) shr(reg_coff_max, 1); |
1908 | sub(reg_diff_dst, reg_coff_max); |
1909 | if (!pd_->use_global_stats()) sub(reg_src, reg_coff_max); |
1910 | sub(reg_diff_src, reg_coff_max); |
1911 | if (is_xf16()) shl(reg_coff_max, 1); |
1912 | |
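        // Rewind reg_ws: the workspace holds one bit per element, i.e.
        // coff_max / 32 bytes.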
1913 | shr(reg_coff_max, 5); |
1914 | sub(reg_ws, reg_coff_max); |
1915 | shl(reg_coff_max, 5); |
1916 | } |
1917 | |
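    // Backward driver: the first pass accumulates per-thread partial
    // diff_gamma/diff_beta sums into rbuf1/rbuf2, a barrier-guarded
    // reduction finalizes them, and the second pass re-reads src/diff_dst
    // to produce diff_src.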
1918 | void backward() { |
1919 | uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); |
1920 | xor_(reg_coff, reg_coff); |
1921 | Label zero_rbuf, sh_spatial; |
1922 | |
1923 | L(zero_rbuf); |
1924 | { |
1925 | uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); |
1926 | uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0)); |
1927 | add(reg_coff, isa == sse41 ? vlen / 2 : vlen); |
1928 | cmp(reg_coff, reg_coff_max); |
1929 | jne(zero_rbuf); |
1930 | } |
1931 | |
1932 | mov(reg_src, ptr[rsp + stack_off_src]); |
1933 | mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]); |
1934 | if (with_relu) { |
1935 | assert(isa == avx2 || isa == avx512_core); |
1936 | mov(reg_ws, ptr[rsp + stack_off_ws]); |
1937 | } |
1938 | |
1939 | xor_(reg_soff, reg_soff); |
1940 | L(sh_spatial); |
1941 | { |
1942 | xor_(reg_coff, reg_coff); |
1943 | if (isa == sse41) { mov(reg_tmp_off, reg_soff); } |
1944 | jbp_->is_nspc_ ? backward_sh_channels_nspc() |
1945 | : backward_sh_channels(); |
1946 | if (isa == sse41) { |
1947 | mov(reg_soff, reg_tmp_off); |
1948 | add(reg_diff_dst, vlen / 2); |
1949 | add(reg_src, vlen / 2); |
1950 | mov(reg_coff, vlen / 2); |
1951 | backward_sh_channels(); |
1952 | sub(reg_diff_dst, vlen / 2); |
1953 | sub(reg_src, vlen / 2); |
1954 | } |
1955 | // Process next image |
1956 | if (jbp_->is_nspc_) { |
                // Can use a static offset since we come back after the
                // spatial loop
1958 | add(reg_src, mb_offt); |
1959 | add(reg_diff_dst, mb_offt); |
1960 | add(reg_soff, mb_offt); |
1961 | add(reg_ws, ws_mb_offt); |
1962 | } else { |
1963 | add(reg_soff, reg_mb_stride_Bc); |
1964 | } |
1965 | cmp(reg_soff, reg_soff_max); |
1966 | jl(sh_spatial); |
1967 | } |
1968 | |
1969 | if (jbp_->is_nspc_) { |
1970 | // comeback |
1971 | mov(reg_src, ptr[rsp + stack_off_src]); |
1972 | mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]); |
1973 | } |
1974 | |
1975 | mov(reg_diff_scale, ptr[rsp + stack_off_diff_scale]); |
1976 | mov(reg_diff_shift, ptr[rsp + stack_off_diff_shift]); |
1977 | |
1978 | Label no_sh_reduction; |
1979 | barrier(); |
1980 | { |
1981 | mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); |
1982 | cmp(reg_tmp, 0); |
1983 | Label sh_reduction_channels; |
1984 | jne(no_sh_reduction, T_NEAR); |
1985 | |
1986 | mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); |
1987 | xor_(reg_coff, reg_coff); |
1988 | L(sh_reduction_channels); |
1989 | { |
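            // Each thread's partial sums live reg_coff_max bytes apart in
            // rbuf1/rbuf2; walk reg_roff in reg_coff_max strides across all
            // reg_nnthr threads to reduce them.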
1990 | mov(reg_roff, reg_coff); |
1991 | uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); |
1992 | uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); |
1993 | uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); |
1994 | uni_vaddps(vsqrtvar, vsqrtvar, veps); |
1995 | uni_vsqrtps(vsqrtvar, vsqrtvar); |
1996 | uni_vdivps(vsqrtvar, vone, vsqrtvar, vtmp); |
1997 | mov(reg_ctr, reg_nnthr); |
1998 | Label sh_reduction_thrs; |
1999 | L(sh_reduction_thrs); |
2000 | { // TODO: unroll (?) |
2001 | uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]); |
2002 | uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]); |
2003 | add(reg_roff, reg_coff_max); |
2004 | sub(reg_ctr, 1); |
2005 | jnz(sh_reduction_thrs); |
2006 | } |
2007 | uni_vmulps(Vmm(0), Vmm(0), vsqrtvar); |
2008 | uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0)); |
2009 | uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1)); |
2010 | add(reg_coff, isa == sse41 ? vlen / 2 : vlen); |
2011 | cmp(reg_coff, reg_coff_max); |
2012 | jne(sh_reduction_channels); |
2013 | } |
2014 | } |
2015 | L(no_sh_reduction); |
2016 | barrier(); |
2017 | |
2018 | mov(reg_diff_src, ptr[rsp + stack_off_diff_src]); |
2019 | if (with_relu) { |
2020 | assert(isa == avx2 || isa == avx512_core); |
2021 | mov(reg_ws, ptr[rsp + stack_off_ws]); |
2022 | } |
2023 | |
2024 | xor_(reg_soff, reg_soff); |
2025 | Label diff_spatial; |
2026 | L(diff_spatial); |
2027 | { |
2028 | xor_(reg_coff, reg_coff); |
            // reg_diff_shift aliases reg_soff_max, so reload it every
            // iteration.
2030 | mov(reg_diff_shift, ptr[rsp + stack_off_diff_shift]); |
2031 | if (isa == sse41) { mov(reg_tmp_off, reg_soff); } |
2032 | jbp_->is_nspc_ ? backward_diff_channels_nspc() |
2033 | : backward_diff_channels(); |
2034 | if (isa == sse41) { |
2035 | mov(reg_soff, reg_tmp_off); |
2036 | add(reg_diff_dst, vlen / 2); |
2037 | add(reg_diff_src, vlen / 2); |
2038 | add(reg_src, vlen / 2); |
2039 | mov(reg_coff, vlen / 2); |
2040 | backward_diff_channels(); |
2041 | sub(reg_diff_dst, vlen / 2); |
2042 | sub(reg_diff_src, vlen / 2); |
2043 | sub(reg_src, vlen / 2); |
2044 | } |
2045 | // Process next image |
2046 | if (jbp_->is_nspc_) { |
                // Can use a static offset since we come back after the
                // spatial loop
2048 | if (!pd_->use_global_stats()) add(reg_src, mb_offt); |
2049 | add(reg_diff_dst, mb_offt); |
2050 | add(reg_diff_src, mb_offt); |
2051 | add(reg_soff, mb_offt); |
2052 | add(reg_ws, ws_mb_offt); |
2053 | } else { |
2054 | add(reg_soff, reg_mb_stride_Bc); |
2055 | } |
2056 | |
            // restore soff_max, whose register was clobbered by diff_shift
2058 | mov(reg_soff_max, ptr[rsp + stack_off_soff_max]); |
2059 | cmp(reg_soff, reg_soff_max); |
2060 | jl(diff_spatial); |
2061 | } |
2062 | if (jbp_->is_nspc_) { |
2063 | // comeback |
2064 | if (!pd_->use_global_stats()) |
2065 | mov(reg_src, ptr[rsp + stack_off_src]); |
2066 | mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]); |
2067 | mov(reg_diff_src, ptr[rsp + stack_off_diff_src]); |
2068 | if (with_relu) mov(reg_ws, ptr[rsp + stack_off_ws]); |
2069 | } |
2070 | } |
2071 | |
2072 | jit_bnorm_t(const batch_normalization_pd_t *pd, const jit_bnorm_conf_t *jbp) |
2073 | : jit_generator(jit_name()), pd_(pd), jbp_(jbp) { |
2074 | static_assert(isa == sse41 || isa == avx2 || isa == avx512_core, |
                "unsupported isa");
2076 | |
2077 | is_bf16_ = pd_->src_md()->data_type == data_type::bf16; |
2078 | is_f16_ = pd_->src_md()->data_type == data_type::f16; |
2079 | is_avx2_ne_xf16_ |
2080 | = isa == avx2 && mayiuse(avx2_vnni_2) && (is_bf16_ || is_f16_); |
2081 | vlen_spat_data_ = vlen / (1 + is_xf16()); // 32B of xF16 -> 64B of FP32 |
2082 | |
2083 | unroll_blocks = isa == avx512_core && !jbp_->is_spatial_thr_ ? 4 : 1; |
2084 | unroll_regs = isa == avx512_core && !jbp_->is_spatial_thr_ ? 4 : 1; |
2085 | } |
2086 | |
2087 | void generate() override { |
2088 | preamble(); |
2089 | |
2090 | if (use_bf16_emulation()) { |
2091 | // init emulation of bfloat16 operations |
2092 | bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserved_1, |
2093 | bf16_emu_reserved_2, bf16_emu_reserved_3, reg_bf16_tmp, |
2094 | bf16_emu_reserved_4, bf16_emu_reserved_4); |
2095 | bf16_emu_->init_vcvtneps2bf16(); |
2096 | } |
2097 | |
2098 | if (isa == avx512_core) |
2099 | prepare_tail_mask_avx512_common(); |
2100 | else if (isa == avx2) |
2101 | prepare_tail_mask_avx2_common(); |
2102 | |
2103 | compute_static_strides(); |
2104 | |
2105 | prepare_relu(); |
2106 | |
2107 | sub(rsp, stack_size_required); |
2108 | load_common_params(); |
2109 | |
2110 | if (pd_->is_fwd()) { |
2111 | if (!pd_->stats_is_src()) { compute_mean_variance(); } |
2112 | forward(); |
2113 | } else { |
2114 | backward(); |
2115 | } |
2116 | add(rsp, stack_size_required); |
2117 | postamble(); |
2118 | } |
2119 | |
2120 | void operator()(const call_params_t *p) { jit_generator::operator()(p); } |
2121 | |
2122 | ~jit_bnorm_t() override { delete bf16_emu_; } |
2123 | }; |
2124 | |
2125 | namespace bnorm_impl { |
2126 | |
2127 | template <cpu_isa_t isa> |
2128 | struct driver_t : public c_compatible { |
2129 | driver_t(const batch_normalization_pd_t *pd, int nthr) |
2130 | : pd_(pd), jbp_(pd_, nthr, simd_w), ker_(pd_, &jbp_) {} |
2131 | |
2132 | ~driver_t() = default; |
2133 | |
2134 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
2135 | const batch_normalization_pd_t *pd, int nthr) { |
2136 | dim_t C_PADDED = get_c_padded(pd); |
2137 | |
2138 | auto sbuf_sz = use_tmp_stats(pd) * 2 * C_PADDED; |
2139 | auto pbuf_sz |
2140 | = (use_tmp_diff_scale(pd) + use_tmp_diff_shift(pd)) * C_PADDED; |
2141 | auto rbuf_sz = (pd->is_fwd() ? 1 : 2) * C_PADDED * nthr; |
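        // sbuf: temporary mean/variance for inference without user-provided
        // stats; pbuf: temporary diff_scale/diff_shift; rbuf: one reduction
        // row per thread (two on backward: diff_gamma and diff_beta).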
2142 | |
2143 | scratchpad.book<acc_data_t>(key_bnorm_tmp_stats, sbuf_sz); |
2144 | scratchpad.book<acc_data_t>(key_bnorm_tmp_diff_ss, pbuf_sz); |
2145 | scratchpad.book<acc_data_t>(key_bnorm_reduction, rbuf_sz); |
2146 | |
2147 | if (dnnl_thr_syncable()) { |
2148 | auto n_barriers = C_PADDED / simd_w; |
2149 | scratchpad.book<barrier::ctx_64_t>(key_barrier, n_barriers); |
2150 | } |
2151 | } |
2152 | |
    // given nthr, the problem shape, and the chosen thread partition,
    // balance the work among the threads
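    // For example (hypothetical partition): nthr = 8 split as C_nthr = 2,
    // N_nthr = 4, S_nthr = 1 maps ithr = 5 to (C_ithr, N_ithr, S_ithr)
    // = (1, 1, 0); each dimension's range is then split by balance211.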
2155 | void thread_balance(int ithr, int nthr, dim_t N, dim_t C_blks, dim_t SP, |
2156 | int &C_ithr, int C_nthr, dim_t &C_blk_s, dim_t &C_blk_e, |
2157 | int &N_ithr, int N_nthr, dim_t &N_s, dim_t &N_e, int &S_ithr, |
2158 | int S_nthr, dim_t &S_s, dim_t &S_e) { |
2159 | if (ithr < C_nthr * N_nthr * S_nthr) { |
2160 | utils::nd_iterator_init( |
2161 | ithr, C_ithr, C_nthr, N_ithr, N_nthr, S_ithr, S_nthr); |
2162 | balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e); |
2163 | balance211(N, N_nthr, N_ithr, N_s, N_e); |
2164 | balance211(SP, S_nthr, S_ithr, S_s, S_e); |
2165 | } else { |
2166 | S_ithr = N_ithr = C_ithr = -ithr; |
2167 | S_s = S_e = N_s = N_e = C_blk_s = C_blk_e = -1; |
2168 | } |
2169 | } |
2170 | |
2171 | void exec(int ithr, int nthr, const void *src, void *diff_src, void *dst, |
2172 | const void *diff_dst, const acc_data_t *scale, |
2173 | acc_data_t *diff_scale, const acc_data_t *shift, |
2174 | acc_data_t *diff_shift, const acc_data_t *mean, |
2175 | const acc_data_t *var, const uint8_t *ws, |
2176 | const memory_tracking::grantor_t &scratchpad) { |
2177 | auto sbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_stats); |
2178 | auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss); |
2179 | auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction); |
2180 | auto barriers = scratchpad.get<barrier::ctx_64_t>(key_barrier); |
2181 | |
2182 | dim_t N = pd_->MB(); |
2183 | dim_t C = pd_->C(); |
2184 | dim_t C_PADDED = get_c_padded(pd_); |
2185 | dim_t D = pd_->D(); |
2186 | dim_t H = pd_->H(); |
2187 | dim_t W = pd_->W(); |
2188 | dim_t SP = D * H * W; |
2189 | dim_t img_size = C_PADDED * SP; |
2190 | const int vlen_spat_data = ker_.spat_step; |
2191 | |
2192 | typename jit_bnorm_t<isa>::call_params_t p; |
2193 | |
2194 | p.eps = pd_->desc()->batch_norm_epsilon; |
2195 | p.one = 1.0f; |
2196 | p.spat_size = SP; |
2197 | p.chan_size = 1.0f * N * p.spat_size; |
2198 | |
2199 | int C_ithr {0}, N_ithr {0}, S_ithr {0}; |
2200 | dim_t C_blk_s {0}, C_blk_e {0}, N_s {0}, N_e {0}, S_s {0}, S_e {0}; |
2201 | |
2202 | this->thread_balance(ithr, nthr, N, jbp_.C_blks_per_iter_, SP, C_ithr, |
2203 | jbp_.C_nthr_, C_blk_s, C_blk_e, N_ithr, jbp_.N_nthr_, N_s, N_e, |
2204 | S_ithr, jbp_.S_nthr_, S_s, S_e); |
2205 | |
2206 | int SP_N_ithr = N_ithr * jbp_.S_nthr_ + S_ithr; |
2207 | int SP_N_nthr = jbp_.N_nthr_ * jbp_.S_nthr_; |
2208 | assert(IMPLICATION(!dnnl_thr_syncable(), SP_N_nthr == 1)); |
2209 | |
2210 | p.N_ithr = SP_N_ithr; |
2211 | p.N_nthr = SP_N_nthr; |
2212 | |
2213 | int global_C_blk_s; |
2214 | int global_barriers_per_iter = jbp_.C_nthr_; |
2215 | |
2216 | for (int64_t it = 0; it < jbp_.iters_; it++) { |
2217 | if (it == jbp_.iters_ - 1 && jbp_.iters_ > 1) { |
2218 | C_blk_s = C_blk_e = N_s = N_e = 0; |
2219 | this->thread_balance(ithr, nthr, N, jbp_.C_blks_last_iter_, SP, |
2220 | C_ithr, jbp_.C_nthr_last_iter_, C_blk_s, C_blk_e, |
2221 | N_ithr, jbp_.N_nthr_last_iter_, N_s, N_e, S_ithr, |
2222 | jbp_.S_nthr_last_iter_, S_s, S_e); |
2223 | |
2224 | // Update call parameters for JIT, last iteration |
2225 | p.N_ithr = N_ithr * jbp_.S_nthr_last_iter_ + S_ithr; |
2226 | p.N_nthr = jbp_.N_nthr_last_iter_ * jbp_.S_nthr_last_iter_; |
2227 | } |
2228 | |
            global_C_blk_s = jbp_.do_blocking_
                    ? (C_blk_s == -1 ? -1
                                     : it * jbp_.C_blks_per_iter_ + C_blk_s)
                    : C_blk_s;
2233 | |
2234 | int C_blks_thr = C_blk_e - C_blk_s; |
2235 | int N_thr = N_e - N_s; |
2236 | |
2237 | if (C_blks_thr == 0 || N_thr == 0) continue; |
2238 | |
2239 | size_t coff_base = global_C_blk_s * simd_w; |
2240 | size_t soff_base = jbp_.is_nspc_ |
2241 | ? coff_base + N_s * img_size |
2242 | : global_C_blk_s * p.spat_size * simd_w + N_s * img_size; |
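            // coff_base is the element offset of this thread's first channel
            // block; soff_base additionally skips N_s images. For nspc the
            // channel offset applies directly, while blocked layouts stride
            // by spat_size * simd_w elements per channel block.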
2243 | size_t shift_off = use_tmp_diff_scale(pd_) ? pd_->C() : 0; |
2244 | |
2245 | p.spat_size_loc = S_e - S_s; |
2246 | p.S_s = S_s * vlen_spat_data; |
2247 | p.S_tail = (p.spat_size - S_e) * vlen_spat_data; |
2248 | p.coff_max = C_blks_thr * simd_w; |
2249 | const auto tmp_mean = use_tmp_stats(pd_) ? sbuf : mean; |
2250 | if (tmp_mean != nullptr) p.mean = tmp_mean + coff_base; |
2251 | const auto tmp_var = use_tmp_stats(pd_) ? sbuf + C_PADDED : var; |
2252 | if (tmp_var != nullptr) p.var = tmp_var + coff_base; |
2253 | if (scale != nullptr) p.scale = scale + coff_base; |
2254 | if (shift != nullptr) p.shift = shift + coff_base; |
2255 | const auto tmp_diff_scale |
2256 | = use_tmp_diff_scale(pd_) ? pbuf : diff_scale; |
2257 | if (tmp_diff_scale != nullptr) |
2258 | p.diff_scale = tmp_diff_scale + coff_base; |
2259 | const auto tmp_diff_shift |
2260 | = use_tmp_diff_shift(pd_) ? &pbuf[shift_off] : diff_shift; |
2261 | if (tmp_diff_shift != nullptr) |
2262 | p.diff_shift = tmp_diff_shift + coff_base; |
2263 | |
2264 | p.soff_max = jbp_.dt_size_ * N_thr * img_size; |
2265 | if (src != nullptr) |
2266 | p.src = (void *)((char *)src + soff_base * jbp_.dt_size_); |
2267 | if (dst != nullptr) |
2268 | p.dst = (void *)((char *)dst + soff_base * jbp_.dt_size_); |
2269 | if (diff_src != nullptr) |
2270 | p.diff_src = (void *)((char *)diff_src |
2271 | + soff_base * jbp_.dt_size_); |
2272 | if (diff_dst != nullptr) |
2273 | p.diff_dst = (void *)((char *)diff_dst |
2274 | + soff_base * jbp_.dt_size_); |
            // ws packs one bit per element, hence the / 8
            if (ws != nullptr) p.ws = ws + soff_base / 8;
2276 | |
2277 | p.mb_stride_Bc |
2278 | = jbp_.dt_size_ * (img_size - p.coff_max * p.spat_size); |
2279 | |
2280 | // use SP_N_nthr which is the same as p.N_nthr except maybe for |
2281 | // the last iteration. |
2282 | p.rbuf1 = rbuf |
2283 | + ((it * jbp_.C_blks_per_iter_) * SP_N_nthr |
2284 | + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) |
2285 | * simd_w; |
2286 | // rbuf1 and rbuf2 have to be disjoint |
2287 | p.rbuf2 = p.rbuf1 + C_PADDED * nthr; |
2288 | p.is_cblk_tail |
2289 | = (it * jbp_.C_blks_per_iter_ + C_blk_e) * simd_w > C; |
2290 | |
2291 | size_t iter_barriers |
2292 | = jbp_.do_blocking_ ? it * global_barriers_per_iter : 0; |
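            // With blocking, each iteration uses its own group of C_nthr_
            // barriers so consecutive iterations cannot interfere.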
2293 | p.barrier = barriers + C_ithr + iter_barriers; |
2294 | if (p.soff_max != 0 && p.coff_max != 0) ker_(&p); |
2295 | } |
2296 | } |
2297 | |
2298 | void init_barriers(const memory_tracking::grantor_t &scratchpad) { |
2299 | auto barriers = scratchpad.get<barrier::ctx_64_t>(key_barrier); |
2300 | if (barriers) { |
2301 | const int n_barriers = get_c_padded(pd_) / simd_w; |
2302 | for (int i = 0; i < n_barriers; ++i) |
2303 | barrier::ctx_init(&barriers[i]); |
2304 | } |
2305 | } |
2306 | |
2307 | status_t create_kernel() { return ker_.create_kernel(); } |
2308 | |
2309 | private: |
2310 | enum { |
2311 | simd_w = isa == sse41 ? 8 |
2312 | : cpu_isa_traits<isa>::vlen |
2313 | / sizeof(acc_data_t) // BF16 will expand to FP32 |
2314 | }; |
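    // Note: sse41 reports simd_w = 8 even though an xmm register holds four
    // floats; the kernel covers each 8-channel block in two vlen / 2 passes
    // (see the sse41 branches in backward()).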
2315 | |
2316 | static bool use_tmp_stats(const batch_normalization_pd_t *pd) { |
2317 | return !pd->stats_is_src() |
2318 | && pd->desc()->prop_kind == prop_kind::forward_inference; |
2319 | } |
2320 | |
2321 | static bool use_tmp_diff_scale(const batch_normalization_pd_t *pd) { |
2322 | return (!pd->is_fwd() && !pd->use_scale()) |
2323 | || pd->desc()->prop_kind == prop_kind::backward_data; |
2324 | } |
2325 | |
2326 | static bool use_tmp_diff_shift(const batch_normalization_pd_t *pd) { |
2327 | return (!pd->is_fwd() && !pd->use_shift()) |
2328 | || pd->desc()->prop_kind == prop_kind::backward_data; |
2329 | } |
2330 | |
2331 | const batch_normalization_pd_t *pd_; |
2332 | jit_bnorm_conf_t jbp_; |
2333 | jit_bnorm_t<isa> ker_; |
2334 | }; |
2335 | } // namespace bnorm_impl |
2336 | |
2337 | using namespace data_type; |
2338 | using namespace format_tag; |
2339 | using namespace utils; |
2340 | |
2341 | /* fwd */ |
2342 | |
2343 | template <cpu_isa_t isa> |
2344 | status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init(engine_t *engine) { |
2345 | bool ok = is_fwd() && mayiuse(isa) |
2346 | && !has_zero_dim_memory() |
            // The algorithm requires barriers for best performance;
            // TBB builds use the jit_uni_tbb_batch_normalization
            // implementation instead.
2349 | && dnnl_thr_syncable() |
2350 | && one_of(src_md()->data_type, f32, bf16, f16) |
2351 | && src_md()->data_type == dst_md()->data_type |
2352 | && IMPLICATION(src_md()->data_type == bf16, |
2353 | is_superset(isa, avx512_core) |
2354 | || (isa == avx2 && mayiuse(avx2_vnni_2))) |
            // Note: re-using the avx512_core/avx2 implementation for f16.
            // This is okay since this primitive currently does not support
            // binary post-ops.
2358 | && IMPLICATION(src_md()->data_type == f16, |
2359 | (is_superset(isa, avx512_core) && mayiuse(avx512_core_fp16)) |
2360 | || (isa == avx2 && mayiuse(avx2_vnni_2))) |
2361 | && check_scale_shift_data_type() |
2362 | && (attr()->has_default_values() |
2363 | || with_relu_post_op(is_training())) |
2364 | && set_default_formats_common() |
2365 | && memory_desc_wrapper(src_md()) == memory_desc_wrapper(dst_md()); |
2366 | if (!ok) return status::unimplemented; |
2367 | |
2368 | // BN+Add+Relu fusion is not currently implemented |
2369 | if (fuse_norm_add_relu()) return status::unimplemented; |
2370 | |
2371 | const memory_desc_wrapper src_d(src_md()); |
2372 | if (isa == avx512_core) { |
2373 | if (!src_d.matches_one_of_tag( |
2374 | nCw16c, nChw16c, nCdhw16c, nc, nwc, nhwc, ndhwc)) |
2375 | return status::unimplemented; |
2376 | } else if (isa == avx2 && one_of(src_md()->data_type, bf16, f16)) { |
        // avx2_vnni_2 xf16 supports neither training nor blocked layouts
        if (is_training() || !src_d.matches_one_of_tag(nc, nwc, nhwc, ndhwc))
            return status::unimplemented;
2380 | } else if (isa == avx2) { |
2381 | // full support |
2382 | if (!src_d.matches_one_of_tag( |
2383 | nCw8c, nChw8c, nCdhw8c, nc, nwc, nhwc, ndhwc)) |
2384 | return status::unimplemented; |
2385 | } else { |
2386 | if (!src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c)) |
2387 | return status::unimplemented; |
2388 | } |
2389 | |
2390 | const bool isa_supports_avx2 = is_superset(isa, avx2); |
2391 | if (is_training() && fuse_norm_relu()) { |
2392 | if (!isa_supports_avx2) return status::unimplemented; |
2393 | init_default_ws(1); |
2394 | } |
2395 | |
2396 | if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() |
2397 | && !isa_supports_avx2) |
2398 | return status::unimplemented; |
2399 | |
2400 | // Only IC % simd_w == 0 is supported for now |
2401 | const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(acc_data_t); |
2402 | if (src_d.matches_one_of_tag(nc, nwc, nhwc, ndhwc) |
2403 | && src_d.padded_dims()[1] % simd_w != 0) { |
2404 | return status::unimplemented; |
2405 | } |
2406 | |
2407 | nthr_ = dnnl_get_max_threads(); |
2408 | auto scratchpad = scratchpad_registry().registrar(); |
2409 | bnorm_impl::driver_t<isa>::init_scratchpad(scratchpad, this, nthr_); |
2410 | |
2411 | return status::success; |
2412 | } |
2413 | |
2414 | template <cpu_isa_t isa> |
2415 | jit_uni_batch_normalization_fwd_t<isa>::jit_uni_batch_normalization_fwd_t( |
2416 | const pd_t *apd) |
2417 | : primitive_t(apd) {} |
2418 | |
2419 | template <cpu_isa_t isa> |
2420 | status_t jit_uni_batch_normalization_fwd_t<isa>::init(engine_t *engine) { |
2421 | CHECK(safe_ptr_assign( |
2422 | bnorm_driver_, new bnorm_impl::driver_t<isa>(pd(), pd()->nthr_))); |
2423 | return bnorm_driver_->create_kernel(); |
2424 | } |
2425 | |
2426 | template <cpu_isa_t isa> |
2427 | status_t jit_uni_batch_normalization_fwd_t<isa>::execute( |
2428 | const exec_ctx_t &ctx) const { |
2429 | |
2430 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
2431 | auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE); |
2432 | auto shift = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SHIFT); |
2433 | |
2434 | auto mean = pd()->stats_is_src() ? const_cast<acc_data_t *>( |
2435 | CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN)) |
2436 | : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_MEAN); |
2437 | auto var = pd()->stats_is_src() |
2438 | ? const_cast<acc_data_t *>( |
2439 | CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE)) |
2440 | : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_VARIANCE); |
2441 | auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST); |
2442 | auto ws = CTX_OUT_MEM(uint8_t *, DNNL_ARG_WORKSPACE); |
2443 | |
2444 | auto scratchpad = ctx.get_scratchpad_grantor(); |
2445 | |
2446 | bnorm_driver_->init_barriers(scratchpad); |
2447 | const int nthr = pd()->nthr_; |
2448 | |
2449 | parallel(nthr, [&](const int ithr, const int nthr) { |
2450 | bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr, scale, |
2451 | nullptr, shift, nullptr, mean, var, ws, scratchpad); |
2452 | }); |
2453 | |
2454 | return status::success; |
2455 | } |
2456 | |
2457 | template <cpu_isa_t isa> |
2458 | jit_uni_batch_normalization_fwd_t<isa>::~jit_uni_batch_normalization_fwd_t() { |
2459 | delete bnorm_driver_; |
2460 | } |
2461 | |
2462 | template <cpu_isa_t isa> |
2463 | status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init(engine_t *engine) { |
2464 | bool ok = !is_fwd() && mayiuse(isa) |
2465 | && !has_zero_dim_memory() |
            // The algorithm requires barriers for best performance;
            // TBB builds use the jit_uni_tbb_batch_normalization
            // implementation instead.
2468 | && dnnl_thr_syncable() |
2469 | && one_of(src_md()->data_type, f32, bf16, f16) |
2470 | && src_md()->data_type == diff_src_md()->data_type |
2471 | && diff_src_md()->data_type == diff_dst_md()->data_type |
2472 | && IMPLICATION( |
2473 | src_md()->data_type == bf16, is_superset(isa, avx512_core)) |
            // Note: re-using the avx512_core implementation for f16. This is
            // okay since this primitive currently does not support binary
            // post-ops.
2477 | && IMPLICATION(src_md()->data_type == f16, |
2478 | is_superset(isa, avx512_core) && mayiuse(avx512_core_fp16)) |
2479 | && check_scale_shift_data_type() && attr()->has_default_values() |
2480 | && set_default_formats_common() |
2481 | && memory_desc_wrapper(diff_src_md()) |
2482 | == memory_desc_wrapper(diff_dst_md()); |
2483 | if (!ok) return status::unimplemented; |
2484 | |
2485 | // BN+Add+Relu fusion is not currently implemented |
2486 | if (fuse_norm_add_relu()) return status::unimplemented; |
2487 | |
2488 | const memory_desc_wrapper src_d(src_md()); |
2489 | const memory_desc_wrapper diff_src_d(diff_src_md()); |
2490 | |
2491 | format_tag_t src_tag, diff_src_tag; |
2492 | if (isa == avx512_core) { |
2493 | src_tag = src_d.matches_one_of_tag( |
2494 | nc, nwc, nCw16c, nhwc, nChw16c, ndhwc, nCdhw16c); |
2495 | diff_src_tag = diff_src_d.matches_one_of_tag( |
2496 | nc, nwc, nCw16c, nhwc, nChw16c, ndhwc, nCdhw16c); |
2497 | } else { |
2498 | src_tag = src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c); |
2499 | diff_src_tag = diff_src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c); |
2500 | } |
2501 | ok = (src_tag != format_tag::undef && diff_src_tag != format_tag::undef |
2502 | && src_tag == diff_src_tag); |
2503 | if (!ok) return status::unimplemented; |
2504 | |
2505 | const bool isa_supports_avx2 = is_superset(isa, avx2); |
2506 | if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() |
2507 | && !isa_supports_avx2) |
2508 | return status::unimplemented; |
2509 | |
2510 | // Only IC % 16 == 0 is supported for now |
2511 | if (src_d.matches_one_of_tag(nc, nwc, nhwc, ndhwc) |
2512 | && src_d.padded_dims()[1] % 16 != 0) { |
2513 | return status::unimplemented; |
2514 | } |
2515 | |
2516 | if (fuse_norm_relu()) { |
2517 | if (!isa_supports_avx2) return status::unimplemented; |
2518 | init_default_ws(1); |
2519 | if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; |
2520 | } |
2521 | |
2522 | /* TODO: extra checks required */ |
2523 | |
2524 | nthr_ = dnnl_get_max_threads(); |
2525 | auto scratchpad = scratchpad_registry().registrar(); |
2526 | bnorm_impl::driver_t<isa>::init_scratchpad(scratchpad, this, nthr_); |
2527 | |
2528 | return status::success; |
2529 | } |
2530 | |
2531 | template <cpu_isa_t isa> |
2532 | jit_uni_batch_normalization_bwd_t<isa>::jit_uni_batch_normalization_bwd_t( |
2533 | const pd_t *apd) |
2534 | : primitive_t(apd) {} |
2535 | |
2536 | template <cpu_isa_t isa> |
2537 | status_t jit_uni_batch_normalization_bwd_t<isa>::init(engine_t *engine) { |
2538 | CHECK(safe_ptr_assign( |
2539 | bnorm_driver_, new bnorm_impl::driver_t<isa>(pd(), pd()->nthr_))); |
2540 | return bnorm_driver_->create_kernel(); |
2541 | } |
2542 | |
2543 | template <cpu_isa_t isa> |
2544 | status_t jit_uni_batch_normalization_bwd_t<isa>::execute( |
2545 | const exec_ctx_t &ctx) const { |
2546 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
2547 | auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN); |
2548 | auto var = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE); |
2549 | auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST); |
2550 | auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE); |
2551 | auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE); |
2552 | |
2553 | auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC); |
2554 | auto diff_scale = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SCALE); |
2555 | auto diff_shift = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT); |
2556 | |
2557 | auto scratchpad = ctx.get_scratchpad_grantor(); |
2558 | |
2559 | bnorm_driver_->init_barriers(scratchpad); |
2560 | const int nthr = pd()->nthr_; |
2561 | |
2562 | parallel(nthr, [&](const int ithr, const int nthr) { |
2563 | bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst, scale, |
2564 | diff_scale, nullptr, diff_shift, mean, var, ws, scratchpad); |
2565 | }); |
2566 | |
2567 | return status::success; |
2568 | } |
2569 | |
2570 | template <cpu_isa_t isa> |
2571 | jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t() { |
2572 | delete bnorm_driver_; |
2573 | } |
2574 | |
2575 | /* struct instantiation */ |
2576 | template struct jit_uni_batch_normalization_fwd_t<sse41>; |
2577 | template struct jit_uni_batch_normalization_bwd_t<sse41>; |
2578 | template struct jit_uni_batch_normalization_fwd_t<avx2>; |
2579 | template struct jit_uni_batch_normalization_bwd_t<avx2>; |
2580 | template struct jit_uni_batch_normalization_fwd_t<avx512_core>; |
2581 | template struct jit_uni_batch_normalization_bwd_t<avx512_core>; |
2582 | |
2583 | } // namespace x64 |
2584 | } // namespace cpu |
2585 | } // namespace impl |
2586 | } // namespace dnnl |
2587 | |