1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cassert> |
18 | #include <cmath> |
19 | #include <memory> |
20 | |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/math_utils.hpp" |
24 | #include "common/memory_tracking.hpp" |
25 | #include "common/nstl.hpp" |
26 | #include "common/type_helpers.hpp" |
27 | #include "common/utils.hpp" |
28 | |
29 | #include "cpu/cpu_batch_normalization_utils.hpp" |
30 | #include "cpu/platform.hpp" |
31 | #include "cpu/x64/jit_generator.hpp" |
32 | |
33 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
34 | #include "cpu/x64/jit_uni_tbb_batch_normalization.hpp" |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | namespace cpu { |
39 | namespace x64 { |
40 | |
41 | using namespace memory_tracking::names; |
42 | using namespace Xbyak; |
43 | using acc_data_t = float; |
44 | |
45 | constexpr int bits_per_byte = 8; |
46 | |
47 | bool normalize_only(const batch_normalization_pd_t *pd) { |
48 | return pd->stats_is_src() && pd->is_fwd(); |
49 | } |
50 | |
51 | dim_t get_c_padded(const batch_normalization_pd_t *pd) { |
52 | return pd->src_md()->padded_dims[1]; |
53 | } |
54 | |
55 | template <cpu_isa_t isa> |
56 | int get_vlen(jit_memory_tag_kind_t tag_kind) { |
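    // sse41 handles the 8c blocked layout as two 16-byte halves, so report
    // 32 bytes here to keep simd_w equal to the 8c block width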
57 | return isa == sse41 && tag_kind == jit_memory_tag_kind_t::blocked |
58 | ? 32 |
59 | : cpu_isa_traits<isa>::vlen; |
60 | } |
61 | |
62 | template <cpu_isa_t isa> |
63 | int get_simd_w(jit_memory_tag_kind_t tag_kind) { |
64 | return get_vlen<isa>(tag_kind) / sizeof(acc_data_t); |
65 | } |
66 | |
67 | template <cpu_isa_t isa> |
68 | bool is_avx2_ne_xf16(const batch_normalization_pd_t *pd) { |
69 | return isa == avx2 && mayiuse(avx2_vnni_2) |
70 | && utils::one_of( |
71 | pd->src_md()->data_type, data_type::bf16, data_type::f16); |
72 | } |
73 | |
74 | template <cpu_isa_t isa> |
75 | std::tuple<dim_t, dim_t, dim_t> get_data_strides( |
76 | const batch_normalization_pd_t *pd, jit_memory_tag_kind_t tag_kind) { |
77 | const int simd_w = get_simd_w<isa>(tag_kind); |
78 | size_t stride_N, stride_S, stride_C; |
79 | |
80 | if (tag_kind == jit_memory_tag_kind_t::nspc) { |
81 | stride_C = static_cast<size_t>(simd_w); |
82 | stride_S = static_cast<size_t>(pd->C()); |
83 | stride_N = static_cast<size_t>(pd->D() * pd->H() * pd->W()) * stride_S; |
84 | } else { |
85 | const size_t C_blks = static_cast<size_t>(get_c_padded(pd) / simd_w); |
86 | |
87 | stride_C = static_cast<size_t>(pd->D() * pd->H() * pd->W() * simd_w); |
88 | stride_S = static_cast<size_t>(simd_w); |
89 | stride_N = C_blks * stride_C; |
90 | } |
91 | |
92 | return std::make_tuple(stride_N, stride_S, stride_C); |
93 | } |
94 | |
95 | #define PARAM_ADDR(x) (reg_param_ + offsetof(call_params_t, x)) |
96 | template <cpu_isa_t isa> |
97 | struct jit_bnorm_process_tail_t { |
98 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
99 | |
100 | jit_bnorm_process_tail_t(const batch_normalization_pd_t *pd, |
101 | jit_generator *host, Reg64 reg_tmp, Reg64 reg_blk_has_tail, |
102 | Reg64 reg_C, Vmm vtail_mask, Opmask ktail_mask) |
103 | : h_(host) |
104 | , reg_tmp_(reg_tmp) |
105 | , reg_blk_has_tail_(reg_blk_has_tail) |
106 | , reg_C_(reg_C) |
107 | , vtail_mask_(vtail_mask) |
108 | , ktail_mask_(ktail_mask) { |
109 | const memory_desc_wrapper data_d(pd->src_md()); |
110 | c_is_padded_ = pd->C() != data_d.padded_dims()[1]; |
111 | |
112 | const int vlen = isa == sse41 ? 32 : cpu_isa_traits<isa>::vlen; |
113 | tail_ = pd->C() % (int)(vlen / sizeof(float)); |
114 | } |
115 | |
116 | jit_generator *const h_; |
117 | const Reg64 reg_tmp_; |
118 | const Reg64 reg_blk_has_tail_; |
119 | const Reg64 reg_C_; |
120 | const Vmm vtail_mask_; |
121 | const Opmask ktail_mask_; |
122 | bool c_is_padded_; |
123 | int tail_; |
124 | |
125 | void prepare_tail_mask_avx512_common() { |
126 | if (!c_is_padded_) return; |
127 | |
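        // set the tail_ low bits, e.g. tail_ == 3 -> mask == 0b111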
128 | const int mask = (1 << tail_) - 1; |
129 | |
130 | Reg32 regw_tmp = reg_tmp_.cvt32(); |
131 | h_->mov(regw_tmp, mask); |
132 | h_->kmovw(ktail_mask_, regw_tmp); |
133 | } |
134 | |
135 | void prepare_tail_mask_avx2_common() { |
136 | if (!c_is_padded_) return; |
137 | |
138 | static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff, |
139 | 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0, |
140 | 0, 0, 0, 0, 0, 0, 0}; |
141 | |
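        // point at the window with tail_ leading 0xffffffff entries followed
        // by zeros so that vmaskmovps touches only the first tail_ lanes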
142 | h_->mov(reg_tmp_, reinterpret_cast<size_t>(&mask[8 - tail_])); |
143 | h_->vmovups(vtail_mask_, h_->ptr[reg_tmp_]); |
144 | } |
145 | |
146 | void prepare_tail() { |
147 | if (isa == avx512_core) |
148 | prepare_tail_mask_avx512_common(); |
149 | else if (isa == avx2) |
150 | prepare_tail_mask_avx2_common(); |
151 | } |
152 | |
153 | void uni_vmovups_tail_avx2_common( |
154 | const Operand &dst, const Operand &src, Label &l_ret) { |
155 | if (dst.isMEM()) { |
156 | h_->vmaskmovps(dst.getAddress(), vtail_mask_, Vmm(src.getIdx())); |
157 | } else { |
158 | h_->vmaskmovps(Vmm(dst.getIdx()), vtail_mask_, src.getAddress()); |
159 | } |
160 | h_->jmp(l_ret); |
161 | } |
162 | |
163 | void uni_vmovups_tail_avx512_common( |
164 | const Operand &dst, const Operand &src, Label &l_ret) { |
165 | if (dst.isMEM()) |
166 | h_->uni_vmovups(dst.getAddress() | ktail_mask_ | h_->T_z, |
167 | Vmm(src.getIdx())); |
168 | else |
169 | h_->uni_vmovups(Vmm(dst.getIdx()) | ktail_mask_ | h_->T_z, |
170 | src.getAddress()); |
171 | |
172 | h_->jmp(l_ret); |
173 | } |
174 | |
175 | void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) { |
176 | Label l_no_mask, l_ret; |
177 | if (c_is_padded_) { |
178 | h_->cmp(reg_blk_has_tail_, 0); |
179 | h_->jz(l_no_mask); |
180 | |
181 | h_->cmp(reg_C_, 1); |
182 | h_->jne(l_no_mask); |
183 | assert(isa == avx512_core || isa == avx2); |
184 | if (isa == avx512_core) |
185 | uni_vmovups_tail_avx512_common(dst, src, l_ret); |
186 | else if (isa == avx2) |
187 | uni_vmovups_tail_avx2_common(dst, src, l_ret); |
188 | } |
189 | h_->L(l_no_mask); |
190 | if (dst.isMEM()) |
191 | h_->uni_vmovups(dst.getAddress(), Vmm(src.getIdx())); |
192 | else |
193 | h_->uni_vmovups(Vmm(dst.getIdx()), src.getAddress()); |
194 | |
195 | h_->L(l_ret); |
196 | } |
197 | }; |
198 | |
199 | template <cpu_isa_t isa> |
200 | struct jit_bnorm_process_relu_t { |
201 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
202 | |
203 | jit_bnorm_process_relu_t(const batch_normalization_pd_t *pd, |
204 | jit_generator *host, Reg64 reg_off_dat, Reg64 reg_tmp, |
205 | Reg64 reg_ptr_ws, Vmm vzero, Vmm vstore_mask, Opmask kstore_mask, |
206 | Vmm valpha, Vmm vmask, Reg64 reg_alpha) |
207 | : h_(host) |
208 | , reg_off_dat_(reg_off_dat) |
209 | , reg_tmp_(reg_tmp) |
210 | , reg_ptr_ws_(reg_ptr_ws) |
211 | , reg_alpha(reg_alpha) |
212 | , vzero_(vzero) |
213 | , vstore_mask_(vstore_mask) |
214 | , kstore_mask_(kstore_mask) |
215 | , valpha(valpha) |
216 | , vmask(vmask) |
217 | , with_relu_(pd->with_relu_post_op(pd->is_training()) |
218 | || pd->fuse_norm_relu()) |
219 | , with_relu_inf_only_( |
220 | with_relu_ && !(pd->fuse_norm_relu() && pd->is_training())) |
221 | , bit_shift_(static_cast<int>(log2(bits_per_byte |
222 | * types::data_type_size(pd->src_md()->data_type)))) |
223 | , alpha(with_relu_inf_only_ && pd->with_relu_post_op(pd->is_training()) |
224 | ? pd->alpha() |
225 | : 0.f) {} |
226 | |
227 | jit_bnorm_process_relu_t(const batch_normalization_pd_t *pd, |
228 | jit_generator *host, Reg64 reg_off_dat, Reg64 reg_tmp, |
229 | Reg64 reg_ptr_ws, Vmm vzero, Vmm vstore_mask, Opmask kstore_mask) |
230 | : jit_bnorm_process_relu_t(pd, host, reg_off_dat, reg_tmp, reg_ptr_ws, |
231 | vzero, vstore_mask, kstore_mask, Vmm(), Vmm(), Reg64()) {} |
232 | |
233 | jit_generator *const h_; |
234 | const Reg64 reg_off_dat_; |
235 | const Reg64 reg_tmp_; |
236 | const Reg64 reg_ptr_ws_; |
237 | const Reg64 reg_alpha; |
238 | const Vmm vzero_, vstore_mask_; |
239 | const Opmask kstore_mask_; |
    // holds the broadcast alpha value for ReLU with negative slope
    const Vmm valpha;
    const Vmm vmask; // compare/blend mask, used on AVX2 and SSE41
243 | Label l_relu_mask_avx2_; |
244 | const bool with_relu_, with_relu_inf_only_; |
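    // log2 of bits per data element; shifting a data byte offset right by
    // this converts it into a byte offset into the ReLU workspace, which
    // stores one mask bit per element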
245 | const int bit_shift_; |
246 | const float alpha; |
247 | |
248 | bool with_relu() const { return with_relu_; } |
249 | |
250 | bool with_relu_inf_only() const { return with_relu_inf_only_; } |
251 | |
252 | void fwd_prepare_relu() { |
253 | if (with_relu_) { h_->uni_vpxor(vzero_, vzero_, vzero_); } |
254 | if (with_relu_inf_only_ && alpha != 0) |
255 | h_->mov(reg_alpha, float2int(alpha)); |
256 | } |
257 | |
258 | void bwd_prepare_relu() { |
259 | if (with_relu_) { |
260 | h_->uni_vpxor(vzero_, vzero_, vzero_); |
261 | if (isa == avx2) prepare_l_relu_mask_avx2(); |
262 | } |
263 | } |
264 | |
265 | void prepare_l_relu_mask_avx2() { |
266 | Label l_mask_after; |
267 | h_->jmp(l_mask_after); |
268 | h_->align(32); |
269 | h_->L(l_relu_mask_avx2_); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */ |
270 | for (int i = 0; i < 8; ++i) |
271 | h_->dd(1 << i); |
272 | h_->L(l_mask_after); |
273 | } |
274 | |
275 | void fwd_process_relu(Vmm v, const int off = 0) { |
276 | if (with_relu_inf_only_) { |
277 | if (alpha != 0.f) |
278 | fwd_process_relu_alpha(v); |
279 | else |
280 | h_->uni_vmaxps(v, v, vzero_); |
281 | } else if (with_relu_) { |
282 | if (isa == avx512_core) |
283 | fwd_process_relu_avx512_common(v, off); |
284 | else if (isa == avx2) |
285 | fwd_process_relu_avx2(v, off); |
286 | else |
287 | assert(false); |
288 | } |
289 | } |
290 | |
291 | void bwd_process_relu(Vmm v, const int off = 0) { |
292 | if (with_relu_) { |
293 | if (isa == avx512_core) |
294 | bwd_process_relu_avx512_common(v, off); |
295 | else if (isa == avx2) |
296 | bwd_process_relu_avx2(v, off); |
297 | else |
298 | assert(false); |
299 | } |
300 | } |
301 | |
302 | void fwd_process_relu_avx2(Vmm vdst, const int off = 0) { |
303 | Reg64 reg_store_mask = reg_tmp_; |
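        // convert the data byte offset into a workspace byte offset, record
        // one mask bit per lane (dst > 0) in the workspace, then zero out
        // the negative elements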
304 | h_->shr(reg_off_dat_, bit_shift_); |
305 | h_->vcmpps(vstore_mask_, vzero_, vdst, jit_generator::_cmp_lt_os); |
306 | h_->vmovmskps(reg_store_mask, vstore_mask_); |
307 | h_->mov(h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off], |
308 | reg_store_mask.cvt8()); |
309 | h_->vblendvps(vdst, vzero_, vdst, vstore_mask_); |
310 | h_->shl(reg_off_dat_, bit_shift_); |
311 | } |
312 | |
313 | void fwd_process_relu_avx512_common(Vmm vdst, const int off = 0) { |
314 | h_->shr(reg_off_dat_, bit_shift_); |
315 | h_->vcmpps(kstore_mask_, vzero_, vdst, jit_generator::_cmp_lt_os); |
316 | h_->kmovw(h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off], kstore_mask_); |
317 | h_->vblendmps(vdst | kstore_mask_, vzero_, vdst); |
318 | h_->shl(reg_off_dat_, bit_shift_); |
319 | } |
320 | |
321 | void fwd_process_relu_alpha(Vmm vmm_dst) { |
322 | if (isa == avx512_core) |
323 | fwd_process_relu_alpha_avx512_common(vmm_dst); |
324 | else { |
325 | assert(utils::one_of(isa, avx2, sse41)); |
326 | fwd_process_relu_alpha_avx2(vmm_dst); |
327 | } |
328 | } |
329 | |
330 | void fwd_process_relu_alpha_avx512_common(Vmm vmm_dst) { |
331 | const Xmm xmm_aux = Xmm(valpha.getIdx()); |
332 | h_->vmovq(xmm_aux, reg_alpha); |
333 | h_->vbroadcastss(valpha, xmm_aux); |
334 | h_->vcmpps(kstore_mask_, vzero_, vmm_dst, h_->_cmp_lt_os); |
335 | h_->vmulps(valpha, vmm_dst, valpha); |
336 | h_->vblendmps(vmm_dst | kstore_mask_, valpha, vmm_dst); |
337 | } |
338 | |
339 | void fwd_process_relu_alpha_avx2(Vmm vmm_dst) { |
340 | const Xmm xmm_aux = Xmm(valpha.getIdx()); |
341 | h_->uni_vpxor(vmask, vmask, vmask); |
342 | h_->uni_vmovq(xmm_aux, reg_alpha); |
343 | h_->uni_vbroadcastss(valpha, xmm_aux); |
344 | h_->uni_vcmpps(vmask, vmm_dst, vzero_, h_->_cmp_lt_os); |
345 | h_->uni_vmulps(valpha, valpha, vmm_dst); |
346 | h_->uni_vblendvps( |
                vmm_dst, vmm_dst, valpha, vmask); // swapped aux and dst
348 | } |
349 | |
350 | void bwd_process_relu_avx2(Vmm vdiff_dst, const int off = 0) { |
351 | h_->shr(reg_off_dat_, bit_shift_); |
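        // broadcast the workspace mask byte to every lane, isolate each
        // lane's bit and compare for equality to rebuild a per-lane
        // all-ones/all-zeros ReLU mask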
352 | h_->vpbroadcastb( |
353 | vstore_mask_, h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off]); |
354 | h_->vpand(vstore_mask_, vstore_mask_, |
355 | h_->ptr[Xbyak::util::rip + l_relu_mask_avx2_]); |
356 | h_->vpcmpeqd(vstore_mask_, vstore_mask_, |
357 | h_->ptr[Xbyak::util::rip + l_relu_mask_avx2_]); |
358 | h_->vblendvps(vdiff_dst, vzero_, vdiff_dst, vstore_mask_); |
359 | h_->shl(reg_off_dat_, bit_shift_); |
360 | } |
361 | |
362 | void bwd_process_relu_avx512_common(Vmm vdiff_dst, const int off = 0) { |
363 | h_->shr(reg_off_dat_, bit_shift_); |
364 | h_->kmovw(kstore_mask_, h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off]); |
365 | h_->vmovups(vdiff_dst | kstore_mask_ | h_->T_z, vdiff_dst); |
366 | h_->shl(reg_off_dat_, bit_shift_); |
367 | } |
368 | }; |
369 | |
370 | template <cpu_isa_t isa> |
371 | struct helper_vmovups_data_t { |
372 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
373 | |
374 | helper_vmovups_data_t(const batch_normalization_pd_t *pd, |
375 | jit_generator *host, Zmm zmm_reserved_1, Zmm zmm_reserved_2, |
376 | Zmm zmm_reserved_3, Zmm zmm_reserved_4, Reg64 reg_tmp) |
377 | : h_(host), bf16_emu_(nullptr) { |
378 | is_bf16_ = pd->src_md()->data_type == data_type::bf16; |
379 | is_f16_ = pd->src_md()->data_type == data_type::f16; |
380 | if (is_bf16_ && isa == avx512_core && !mayiuse(avx512_core_bf16)) { |
381 | bf16_emu_ = utils::make_unique<bf16_emulation_t>(h_, zmm_reserved_1, |
382 | zmm_reserved_2, zmm_reserved_3, reg_tmp, zmm_reserved_4, |
383 | zmm_reserved_4); |
384 | } |
385 | } |
386 | |
387 | jit_generator *const h_; |
388 | std::unique_ptr<bf16_emulation_t> bf16_emu_; |
389 | bool is_bf16_; |
390 | bool is_f16_; |
391 | |
392 | void merge_interleaved_to_plain(const Vmm &vmm_even, const Vmm &vmm_odd, |
393 | const Vmm &vmm_aux0) const { |
394 | Ymm ymm_even = Ymm(vmm_even.getIdx()); |
395 | Ymm ymm_odd = Ymm(vmm_odd.getIdx()); |
396 | Ymm ymm_aux0 = Ymm(vmm_aux0.getIdx()); |
397 | Ymm ymm_aux1 = ymm_odd; |
398 | |
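        // interleave the even/odd vectors dword-wise within each 128-bit
        // lane, then stitch the lanes back together to restore the plain
        // element order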
399 | h_->vpunpckldq(ymm_aux0, ymm_even, ymm_odd); |
400 | h_->vpunpckhdq(ymm_aux1, ymm_even, ymm_odd); |
401 | h_->vperm2i128(ymm_even, ymm_aux0, ymm_aux1, 0x20); |
402 | h_->vperm2i128(ymm_odd, ymm_aux0, ymm_aux1, 0x31); |
403 | } |
404 | |
405 | void operator()(const Vmm &vmm_even, const Vmm &vmm_odd, |
406 | const Address &addr) const { |
        // load two simd_w-wide vectors from addr: even-indexed elements go
        // to vmm_even, odd-indexed elements to vmm_odd
408 | if (is_bf16_) { |
409 | // convert bf16 input to f32 |
410 | h_->vcvtneebf162ps(vmm_even, addr); |
411 | h_->vcvtneobf162ps(vmm_odd, addr); |
412 | } else if (is_f16_) { |
413 | h_->vcvtneeph2ps(vmm_even, addr); |
414 | h_->vcvtneoph2ps(vmm_odd, addr); |
415 | } else |
            assert(!"unsupported data type");
417 | } |
418 | |
419 | void operator()(const Operand &dst, const Operand &src) const { |
420 | if (dst.isMEM()) { |
421 | if (is_bf16_) { |
422 | constexpr bool isAvx2 = isa == avx2; |
423 | const typename std::conditional<isAvx2, Xmm, Ymm>::type |
424 | dst_reg {src.getIdx()}; |
425 | const typename std::conditional<isAvx2, Ymm, Zmm>::type |
426 | src_reg {src.getIdx()}; |
427 | |
428 | // convert f32 output to bf16 |
429 | if (!bf16_emu_) |
430 | h_->vcvtneps2bf16(dst_reg, src_reg, |
431 | mayiuse(avx512_core) ? Xbyak::EvexEncoding |
432 | : Xbyak::VexEncoding); |
433 | else |
434 | bf16_emu_->vcvtneps2bf16(dst_reg, src_reg); |
435 | |
436 | h_->uni_vmovups(dst.getAddress(), dst_reg); |
437 | } else if (is_f16_) { |
438 | auto src_reg = Vmm(src.getIdx()); |
439 | h_->vcvtps2ph(dst.getAddress(), src_reg, h_->_op_mxcsr); |
440 | } else { |
441 | h_->uni_vmovups(dst.getAddress(), Vmm(src.getIdx())); |
442 | } |
443 | } else { |
444 | if (is_bf16_) { |
445 | // convert bf16 input to f32 |
446 | h_->vpmovzxwd(Vmm(dst.getIdx()), src.getAddress()); |
447 | h_->vpslld(Vmm(dst.getIdx()), Vmm(dst.getIdx()), 0x10); |
448 | } else if (is_f16_) { |
449 | if (mayiuse(avx512_core_fp16)) |
450 | h_->vcvtph2psx(Vmm(dst.getIdx()), src.getAddress()); |
451 | else |
452 | h_->vcvtph2ps(Vmm(dst.getIdx()), src.getAddress()); |
453 | } else { |
454 | h_->uni_vmovups(Vmm(dst.getIdx()), src.getAddress()); |
455 | } |
456 | } |
457 | } |
458 | |
459 | private: |
460 | DNNL_DISALLOW_COPY_AND_ASSIGN(helper_vmovups_data_t); |
461 | }; |
462 | |
463 | template <cpu_isa_t isa> |
464 | struct jit_bnorm_fwd_statistics_t : public jit_generator { |
465 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_fwd_statistics_t) |
466 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
467 | |
468 | const AddressFrame &vmmword |
469 | = (isa == sse41) ? xword : (isa == avx2) ? yword : zword; |
470 | |
471 | struct call_params_t { |
472 | size_t N, C, S; |
473 | const void *src; |
474 | const acc_data_t *mean; |
475 | const acc_data_t *var; |
476 | size_t blk_has_tail; |
477 | size_t do_normalise; |
478 | }; |
479 | |
480 | const Reg64 reg_param_ = abi_param1; |
481 | const Reg64 reg_tmp_ = abi_not_param1; |
482 | const Reg64 reg_N_ = rsi; |
483 | const Reg64 reg_S_ = rax; |
484 | const Reg64 reg_C_ = rdx; |
485 | const Reg64 reg_off_c_ = rbx; |
486 | const Reg64 reg_blk_has_tail_ = rbp; |
487 | |
488 | const Reg64 reg_off_dat_ = r8; |
489 | const Reg64 reg_off_dat_save_ = r9; |
490 | const Reg64 reg_ptr_mean_ = r10; |
491 | const Reg64 reg_ptr_var_ = r11; |
492 | const Reg64 reg_ptr_src_ = r12; |
493 | const Reg64 reg_do_normalise_ = r13; |
494 | const Reg64 reg_ptr_stat_ = r14; |
495 | |
496 | const Vmm v_ = Vmm(0); |
497 | const Vmm vtmp_ = Vmm(1); |
498 | const Vmm vtail_mask_ = Vmm(2); |
499 | const Vmm vNS_ = Vmm(3); |
500 | const Vmm vzero_ = Vmm(4); |
    const Vmm vsrc_aux = Vmm(2); // used for xf16 nspc on AVX2
    // When variance is computed, two vmms (one for variance and one for
    // mean) are needed to unroll one c block at any moment, so the number
    // of registers used for unrolling must be divisible by two.
506 | static constexpr int min_idx_to_unroll_ = 4; |
507 | static constexpr int max_idx_to_unroll_ = isa == avx512_core ? 28 : 16; |
508 | static constexpr int number_of_vmms_to_unrolling_variables_ |
509 | = max_idx_to_unroll_ - min_idx_to_unroll_; |
    static_assert(number_of_vmms_to_unrolling_variables_ % 2 == 0
                    && number_of_vmms_to_unrolling_variables_ != 0,
            "Number of registers used for unrolling must be divisible by 2.");
513 | |
514 | const Opmask ktail_mask_ = k2; |
515 | |
516 | const batch_normalization_pd_t *pd_; |
517 | const jit_memory_tag_kind_t tag_kind_; |
518 | const int vlen; |
519 | const int simd_w; |
520 | const bool is_avx2_ne_xf16_; |
521 | jit_bnorm_process_tail_t<isa> jit_tail_; |
522 | helper_vmovups_data_t<isa> helper_vmovups_; |
523 | int stride_N_, stride_S_, stride_C_; |
524 | size_t data_type_size_, acc_type_size_; |
525 | |
526 | void load_common_params() { |
527 | #define PARAM_PTR(x) ptr[PARAM_ADDR(x)] |
528 | mov(reg_ptr_src_, PARAM_PTR(src)); |
529 | mov(reg_ptr_mean_, PARAM_PTR(mean)); |
530 | mov(reg_ptr_var_, PARAM_PTR(var)); |
531 | #undef PARAM_PTR |
532 | mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]); |
533 | mov(reg_do_normalise_, dword[PARAM_ADDR(do_normalise)]); |
534 | } |
535 | |
536 | void zeroise() { |
537 | Label label_zeroise; |
538 | xor_(reg_off_c_, reg_off_c_); |
539 | uni_vpxor(vzero_, vzero_, vzero_); |
540 | mov(reg_C_, dword[PARAM_ADDR(C)]); |
541 | L(label_zeroise); |
542 | { |
543 | jit_tail_.uni_vmovups_maybe_tail( |
544 | vmmword[reg_ptr_stat_ + reg_off_c_], vzero_); |
545 | if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) { |
546 | jit_tail_.uni_vmovups_maybe_tail( |
547 | vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2], vzero_); |
548 | } |
549 | add(reg_off_c_, simd_w * acc_type_size_); |
550 | dec(reg_C_); |
551 | jnz(label_zeroise); |
552 | } |
553 | } |
554 | |
555 | void load_stat(bool compute_mean, const int c_blks_to_unroll = 1) { |
556 | int start_idx = min_idx_to_unroll_; |
557 | int end_idx = c_blks_to_unroll + min_idx_to_unroll_; |
558 | const int step = simd_w * acc_type_size_; |
559 | |
560 | // load mean or variance |
561 | for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) { |
562 | const Vmm vstat = Vmm(idx); |
563 | jit_tail_.uni_vmovups_maybe_tail( |
564 | vstat, vmmword[reg_ptr_stat_ + reg_off_c_ + off]); |
565 | } |
566 | |
        // if variance is computed, the mean is also needed
568 | if (!compute_mean) { |
569 | start_idx = min_idx_to_unroll_ + c_blks_to_unroll; |
570 | end_idx = min_idx_to_unroll_ + 2 * c_blks_to_unroll; |
571 | |
572 | for (int idx = start_idx, off = 0; idx < end_idx; |
573 | idx++, off += step) { |
574 | const Vmm vmean = Vmm(idx); |
575 | jit_tail_.uni_vmovups_maybe_tail( |
576 | vmean, vmmword[reg_ptr_mean_ + reg_off_c_ + off]); |
577 | } |
578 | } |
579 | } |
580 | |
581 | void compute_stat(bool compute_mean, const int c_blks_to_unroll = 1) { |
582 | const int start_idx = min_idx_to_unroll_; |
583 | const int end_idx = c_blks_to_unroll + min_idx_to_unroll_; |
584 | const int step = simd_w * data_type_size_; |
585 | |
586 | for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) { |
587 | const Vmm vstat = Vmm(idx); |
588 | |
589 | helper_vmovups_(v_, vmmword[reg_ptr_src_ + reg_off_dat_ + off]); |
590 | |
591 | if (compute_mean) { |
592 | uni_vaddps(vstat, vstat, v_); |
593 | } else { |
594 | const Vmm vmean = Vmm(idx + c_blks_to_unroll); |
595 | |
596 | // var += (src - mean)^2 |
597 | uni_vsubps(vtmp_, v_, vmean, vtmp_); |
598 | uni_vfmadd231ps(vstat, vtmp_, vtmp_); |
599 | } |
600 | } |
601 | } |
602 | |
603 | void compute_stat_avx2_ne_xf16( |
604 | bool compute_mean, const int c_blks_to_unroll = 1) { |
605 | const int start_idx = min_idx_to_unroll_; |
606 | const int end_idx = c_blks_to_unroll + min_idx_to_unroll_; |
607 | const int step = simd_w * data_type_size_; |
608 | |
609 | for (int idx = start_idx, off = 0; idx < end_idx; |
610 | idx += 2, off += 2 * step) { |
611 | const bool is_c_blks_tail = (end_idx - idx) < 2; |
612 | const Vmm vsrc_even = v_; |
613 | const Vmm vsrc_odd = vsrc_aux; |
614 | if (is_c_blks_tail) |
615 | helper_vmovups_( |
616 | vsrc_even, vmmword[reg_ptr_src_ + reg_off_dat_ + off]); |
617 | else |
618 | helper_vmovups_(vsrc_even, vsrc_odd, |
619 | vmmword[reg_ptr_src_ + reg_off_dat_ + off]); |
620 | for (int i_odd = 0; i_odd < 2 && idx + i_odd < end_idx; ++i_odd) { |
621 | const Vmm vstat = Vmm(idx + i_odd); |
622 | const Vmm vsrc = i_odd ? vsrc_odd : vsrc_even; |
623 | if (compute_mean) { |
624 | uni_vaddps(vstat, vstat, vsrc); |
625 | } else { |
626 | const Vmm vmean = Vmm(idx + i_odd + c_blks_to_unroll); |
627 | uni_vsubps(vtmp_, vsrc, vmean, vtmp_); |
628 | uni_vfmadd231ps(vstat, vtmp_, vtmp_); |
629 | } |
630 | } |
631 | } |
632 | } |
633 | |
634 | void store_stat(const int c_blks_to_unroll = 1) { |
635 | const int start_idx = min_idx_to_unroll_; |
636 | const int end_idx = c_blks_to_unroll + min_idx_to_unroll_; |
637 | const int step = simd_w * acc_type_size_; |
638 | |
639 | for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) { |
640 | const Vmm vstat = Vmm(idx); |
641 | |
642 | jit_tail_.uni_vmovups_maybe_tail( |
643 | vmmword[reg_ptr_stat_ + reg_off_c_ + off], vstat); |
644 | } |
645 | } |
646 | |
647 | void compute_blocked(bool compute_mean) { |
648 | Label label_C, label_S; |
649 | mov(reg_C_, dword[PARAM_ADDR(C)]); |
650 | L(label_C); |
651 | { |
652 | mov(reg_off_dat_, reg_off_dat_save_); |
653 | |
654 | load_stat(compute_mean); |
655 | |
656 | mov(reg_S_, dword[PARAM_ADDR(S)]); |
657 | L(label_S); |
658 | { |
659 | compute_stat(compute_mean); |
660 | |
661 | add(reg_off_dat_, stride_S_ * data_type_size_); |
662 | |
663 | dec(reg_S_); |
664 | jnz(label_S); |
665 | } |
666 | |
667 | store_stat(); |
668 | |
669 | add(reg_off_dat_save_, stride_C_ * data_type_size_); |
670 | add(reg_off_c_, simd_w * acc_type_size_); |
671 | |
672 | dec(reg_C_); |
673 | jnz(label_C); |
674 | } |
675 | } |
676 | |
677 | void compute_nspc(bool compute_mean) { |
678 | mov(reg_C_, dword[PARAM_ADDR(C)]); |
679 | |
        // When variance is computed, two values are unrolled per c block:
        // mean and variance, so number_of_vmms_to_unrolling_variables_ is
        // divided by 2.
682 | const int max_of_unrolled_c_blks = compute_mean |
683 | ? number_of_vmms_to_unrolling_variables_ |
684 | : number_of_vmms_to_unrolling_variables_ / 2; |
685 | std::vector<Label> c_unroll_label(max_of_unrolled_c_blks + 1); |
686 | |
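        // generate a chain of unrolled bodies from the largest unroll factor
        // down to 1; at run time each cmp/jl falls through to the next
        // smaller factor until the remaining C fits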
687 | for (int c_blks_to_unroll = max_of_unrolled_c_blks; |
688 | c_blks_to_unroll > 0; --c_blks_to_unroll) { |
689 | L(c_unroll_label[c_blks_to_unroll]); |
690 | { |
691 | cmp(reg_C_, c_blks_to_unroll); |
692 | jl(c_unroll_label[c_blks_to_unroll - 1], T_NEAR); |
693 | |
694 | mov(reg_off_dat_, reg_off_dat_save_); |
695 | |
696 | load_stat(compute_mean, c_blks_to_unroll); |
697 | |
698 | Label label_S; |
699 | mov(reg_S_, dword[PARAM_ADDR(S)]); |
700 | L(label_S); |
701 | { |
702 | is_avx2_ne_xf16_ |
703 | ? compute_stat_avx2_ne_xf16( |
704 | compute_mean, c_blks_to_unroll) |
705 | : compute_stat(compute_mean, c_blks_to_unroll); |
706 | |
707 | add(reg_off_dat_, stride_S_ * data_type_size_); |
708 | |
709 | dec(reg_S_); |
710 | jnz(label_S); |
711 | } |
712 | |
713 | store_stat(c_blks_to_unroll); |
714 | |
715 | add(reg_off_c_, c_blks_to_unroll * simd_w * acc_type_size_); |
716 | add(reg_off_dat_save_, |
717 | c_blks_to_unroll * stride_C_ * data_type_size_); |
718 | |
719 | sub(reg_C_, c_blks_to_unroll); |
720 | jmp(c_unroll_label[c_blks_to_unroll], T_NEAR); |
721 | } |
722 | } |
723 | L(c_unroll_label[0]); |
724 | } |
725 | |
726 | void compute(bool compute_mean) { |
727 | Label label_N; |
728 | mov(reg_N_, dword[PARAM_ADDR(N)]); |
729 | L(label_N); |
730 | { |
731 | xor_(reg_off_dat_save_, reg_off_dat_save_); |
732 | xor_(reg_off_c_, reg_off_c_); |
733 | |
734 | tag_kind_ == jit_memory_tag_kind_t::nspc |
735 | ? compute_nspc(compute_mean) |
736 | : compute_blocked(compute_mean); |
737 | |
738 | if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) { |
739 | xor_(reg_off_dat_save_, reg_off_dat_save_); |
740 | xor_(reg_off_c_, reg_off_c_); |
741 | add(reg_off_dat_save_, vlen / 2); |
742 | add(reg_off_c_, vlen / 2); |
743 | |
744 | compute_blocked(compute_mean); |
745 | } |
746 | |
747 | add(reg_ptr_src_, stride_N_ * data_type_size_); |
748 | dec(reg_N_); |
749 | jnz(label_N); |
750 | } |
751 | } |
752 | |
753 | void normalize() { |
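        // when do_normalise is set, divide the accumulated per-channel sums
        // by N * S to turn them into mean (or variance)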
754 | Label label_ret, label_normalise; |
755 | cmp(reg_do_normalise_, 0); |
756 | jz(label_ret); |
757 | |
758 | const int S = pd_->D() * pd_->H() * pd_->W(); |
759 | mov(reg_tmp_, float2int(pd_->MB() * S)); |
760 | Xmm xtmp = Xmm(vtmp_.getIdx()); |
761 | uni_vmovq(xtmp, reg_tmp_); |
762 | uni_vbroadcastss(vNS_, xtmp); |
763 | |
764 | xor_(reg_off_c_, reg_off_c_); |
765 | mov(reg_C_, dword[PARAM_ADDR(C)]); |
766 | L(label_normalise); |
767 | { |
768 | jit_tail_.uni_vmovups_maybe_tail( |
769 | v_, vmmword[reg_ptr_stat_ + reg_off_c_]); |
770 | uni_vdivps(v_, v_, vNS_); |
771 | jit_tail_.uni_vmovups_maybe_tail( |
772 | vmmword[reg_ptr_stat_ + reg_off_c_], v_); |
773 | |
774 | if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) { |
775 | jit_tail_.uni_vmovups_maybe_tail( |
776 | v_, vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2]); |
777 | uni_vdivps(v_, v_, vNS_); |
778 | jit_tail_.uni_vmovups_maybe_tail( |
779 | vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2], v_); |
780 | } |
781 | |
782 | add(reg_off_c_, simd_w * acc_type_size_); |
783 | dec(reg_C_); |
784 | jnz(label_normalise); |
785 | } |
786 | |
787 | L(label_ret); |
788 | } |
789 | |
790 | jit_bnorm_fwd_statistics_t(const batch_normalization_pd_t *pd, |
791 | const jit_memory_tag_kind_t tag_kind) |
792 | : jit_generator(jit_name()) |
793 | , pd_(pd) |
794 | , tag_kind_(tag_kind) |
795 | , vlen(get_vlen<isa>(tag_kind)) |
796 | , simd_w(get_simd_w<isa>(tag_kind)) |
797 | , is_avx2_ne_xf16_(is_avx2_ne_xf16<isa>(pd)) |
798 | , jit_tail_(pd, this, reg_tmp_, reg_blk_has_tail_, reg_C_, vtail_mask_, |
799 | ktail_mask_) |
800 | , helper_vmovups_(pd, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) { |
        static_assert(utils::one_of(isa, sse41, avx2, avx512_core),
                "unsupported isa");
803 | |
804 | std::tie(stride_N_, stride_S_, stride_C_) |
805 | = get_data_strides<isa>(pd_, tag_kind); |
806 | |
807 | data_type_size_ = types::data_type_size(pd->src_md()->data_type); |
808 | acc_type_size_ = sizeof(acc_data_t); |
809 | } |
810 | }; |
811 | |
812 | template <cpu_isa_t isa> |
813 | struct jit_bnorm_fwd_mean_t : jit_bnorm_fwd_statistics_t<isa> { |
814 | using call_params_t = |
815 | typename jit_bnorm_fwd_statistics_t<isa>::call_params_t; |
816 | |
817 | jit_bnorm_fwd_mean_t(const batch_normalization_pd_t *pd, |
818 | const jit_memory_tag_kind_t tag_kind) |
819 | : jit_bnorm_fwd_statistics_t<isa>(pd, tag_kind) {} |
820 | |
821 | void generate() override { |
822 | this->preamble(); |
823 | this->load_common_params(); |
824 | this->mov(this->reg_ptr_stat_, this->reg_ptr_mean_); |
825 | this->jit_tail_.prepare_tail(); |
826 | this->zeroise(); |
827 | this->compute(true); |
828 | this->normalize(); |
829 | this->postamble(); |
830 | } |
831 | }; |
832 | |
833 | template <cpu_isa_t isa> |
834 | struct jit_bnorm_fwd_var_t : jit_bnorm_fwd_statistics_t<isa> { |
835 | using call_params_t = |
836 | typename jit_bnorm_fwd_statistics_t<isa>::call_params_t; |
837 | |
838 | jit_bnorm_fwd_var_t(const batch_normalization_pd_t *pd, |
839 | const jit_memory_tag_kind_t tag_kind) |
840 | : jit_bnorm_fwd_statistics_t<isa>(pd, tag_kind) {} |
841 | |
842 | void generate() override { |
843 | this->preamble(); |
844 | this->load_common_params(); |
845 | this->mov(this->reg_ptr_stat_, this->reg_ptr_var_); |
846 | this->jit_tail_.prepare_tail(); |
847 | this->zeroise(); |
848 | this->compute(false); |
849 | this->normalize(); |
850 | this->postamble(); |
851 | } |
852 | }; |
853 | |
854 | template <cpu_isa_t isa> |
855 | struct jit_bnorm_fwd_t : public jit_generator { |
856 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_fwd_t) |
857 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
858 | |
859 | const AddressFrame &vmmword |
860 | = (isa == sse41) ? xword : (isa == avx2) ? yword : zword; |
861 | |
862 | struct call_params_t { |
863 | size_t N, C, S; |
864 | const void *src; |
865 | void *dst; |
866 | const uint8_t *ws; |
867 | const acc_data_t *mean, *var; |
868 | const acc_data_t *scale, *shift; |
869 | size_t blk_has_tail; |
870 | }; |
871 | |
872 | const Reg64 reg_param_ = abi_param1; |
873 | const Reg64 reg_tmp_ = abi_not_param1; |
874 | const Reg64 reg_N_ = rsi; |
875 | const Reg64 reg_S_ = rax; |
876 | const Reg64 reg_C_ = rdx; |
877 | const Reg64 reg_off_c_ = rbx; |
878 | const Reg64 reg_blk_has_tail_ = rbp; |
879 | |
880 | const Reg64 reg_off_dat_ = r8; |
881 | const Reg64 reg_off_dat_save_ = r9; |
882 | const Reg64 reg_ptr_ws_ = r10; |
883 | const Reg64 reg_ptr_scale_ = r11; |
884 | const Reg64 reg_ptr_shift_ = reg_N_; |
885 | const Reg64 reg_ptr_var_ = r12; |
886 | const Reg64 reg_ptr_mean_ = r13; |
887 | const Reg64 reg_ptr_dst_ = r14; |
888 | const Reg64 reg_ptr_src_ = r15; |
889 | const Reg64 reg_alpha_ = reg_ptr_ws_; |
890 | |
891 | const Vmm vmask = Vmm(0); // required for avx2 and sse41 ReLU computation |
892 | const Vmm vone_ = Vmm(1); |
893 | const Vmm vmean_ = Vmm(2); |
894 | const Vmm vvar_ = Vmm(3); |
895 | const Vmm vsqrtvar_ = Vmm(4); |
896 | const Vmm vgamma_ = Vmm(5); |
897 | const Vmm vbeta_ = Vmm(6); |
898 | const Vmm veps_ = Vmm(7); |
899 | const Vmm vtmp_ = Vmm(8); |
900 | const Vmm v_ = Vmm(9); |
901 | const Vmm vzero_ = Vmm(10); |
902 | const Vmm vtail_mask_ = Vmm(11); |
903 | const Vmm valpha = Vmm(12); |
904 | const Vmm vsrc_aux = Vmm(13); |
905 | const Vmm vstore_mask_ = vtmp_; |
906 | const Vmm vmean_even_ = vmean_; |
907 | const Vmm vmean_odd_ = Vmm(14); |
908 | const Vmm vsqrtvar_even_ = vsqrtvar_; |
909 | const Vmm vsqrtvar_odd_ = Vmm(15); |
910 | const Vmm vvar_even_ = vvar_; |
911 | const Vmm vvar_odd_ = vsrc_aux; |
912 | |
913 | const Opmask kstore_mask_ = k1; |
914 | const Opmask ktail_mask_ = k2; |
915 | |
916 | const batch_normalization_pd_t *pd_; |
917 | const jit_memory_tag_kind_t tag_kind_; |
918 | const int vlen; |
919 | const int simd_w; |
920 | const bool is_avx2_ne_xf16_; |
921 | jit_bnorm_process_tail_t<isa> jit_tail_; |
922 | jit_bnorm_process_relu_t<isa> jit_relu_; |
923 | helper_vmovups_data_t<isa> helper_vmovups_; |
924 | int stride_N_, stride_S_, stride_C_; |
925 | size_t data_type_size_, acc_type_size_; |
926 | |
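    // reg_N_ is shared with reg_ptr_shift_, so N and the shift pointer are
    // spilled to the stack and reloaded around each iteration of the N loop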
927 | enum { |
928 | stack_off_N = 0, |
929 | stack_off_shift = 8, |
930 | stack_size_required = 16, |
931 | }; |
932 | |
933 | void load_common_params() { |
934 | #define PARAM_PTR(x) ptr[PARAM_ADDR(x)] |
935 | mov(reg_ptr_src_, PARAM_PTR(src)); |
936 | mov(reg_ptr_dst_, PARAM_PTR(dst)); |
937 | mov(reg_ptr_mean_, PARAM_PTR(mean)); |
938 | mov(reg_ptr_var_, PARAM_PTR(var)); |
939 | mov(reg_ptr_scale_, PARAM_PTR(scale)); |
940 | if (jit_relu_.with_relu_ && !jit_relu_.with_relu_inf_only_) |
941 | mov(reg_ptr_ws_, PARAM_PTR(ws)); |
942 | |
943 | Xmm x = Xmm(v_.getIdx()); |
944 | |
945 | mov(reg_tmp_, float2int(pd_->desc()->batch_norm_epsilon)); |
946 | uni_vmovq(x, reg_tmp_); |
947 | uni_vbroadcastss(veps_, x); |
948 | |
949 | mov(reg_tmp_, float2int(1.f)); |
950 | uni_vmovq(x, reg_tmp_); |
951 | uni_vbroadcastss(vone_, x); |
952 | |
953 | mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]); |
954 | |
955 | mov(reg_tmp_, PARAM_PTR(shift)); |
956 | mov(ptr[rsp + stack_off_shift], reg_tmp_); |
957 | mov(reg_tmp_, PARAM_PTR(N)); |
958 | mov(ptr[rsp + stack_off_N], reg_tmp_); |
959 | #undef PARAM_PTR |
960 | } |
961 | |
962 | void load_c_specifics( |
963 | const bool has_load_mean_sqrtvar, const int offt = 0) { |
964 | if (!has_load_mean_sqrtvar) { |
965 | jit_tail_.uni_vmovups_maybe_tail( |
966 | vmean_, vmmword[reg_ptr_mean_ + reg_off_c_ + offt]); |
967 | jit_tail_.uni_vmovups_maybe_tail( |
968 | vvar_, vmmword[reg_ptr_var_ + reg_off_c_ + offt]); |
969 | |
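            // vsqrtvar_ = 1 / sqrt(var + eps)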
970 | uni_vmovups(vsqrtvar_, vvar_); |
971 | uni_vaddps(vsqrtvar_, vsqrtvar_, veps_); |
972 | uni_vsqrtps(vsqrtvar_, vsqrtvar_); |
973 | if (isa == sse41) { |
974 | movups(vtmp_, vone_); |
975 | divps(vtmp_, vsqrtvar_); |
976 | movups(vsqrtvar_, vtmp_); |
977 | } else |
978 | vdivps(vsqrtvar_, vone_, vsqrtvar_); |
979 | } |
980 | |
981 | if (pd_->use_scale()) |
982 | jit_tail_.uni_vmovups_maybe_tail( |
983 | vgamma_, vmmword[reg_ptr_scale_ + reg_off_c_ + offt]); |
984 | if (pd_->use_shift()) |
985 | jit_tail_.uni_vmovups_maybe_tail( |
986 | vbeta_, vmmword[reg_ptr_shift_ + reg_off_c_ + offt]); |
987 | } |
988 | |
989 | void compute_bnorm(const Vmm &v, const Vmm &vmean, const Vmm &vsqrtvar, |
990 | bool stream_store_allowed, bool has_load_src, const int offt = 0) { |
991 | if (!has_load_src) |
992 | helper_vmovups_(v, vmmword[reg_ptr_src_ + reg_off_dat_ + offt]); |
993 | uni_vsubps(v, v, vmean); |
994 | uni_vmulps(v, v, vsqrtvar); |
995 | |
996 | if (pd_->use_scale() && pd_->use_shift()) |
997 | uni_vfmadd213ps(v, vgamma_, vbeta_); |
998 | else if (pd_->use_scale()) |
999 | uni_vmulps(v, v, vgamma_); |
1000 | else if (pd_->use_shift()) |
1001 | uni_vaddps(v, v, vbeta_); |
1002 | |
1003 | jit_relu_.fwd_process_relu(v); |
1004 | |
1005 | if (stream_store_allowed) { |
1006 | uni_vmovntps(vmmword[reg_ptr_dst_ + reg_off_dat_ + offt], v); |
1007 | } else { |
1008 | helper_vmovups_(vmmword[reg_ptr_dst_ + reg_off_dat_ + offt], v); |
1009 | } |
1010 | } |
1011 | |
1012 | void load_two_c_mean_sqrtvar() { |
1013 | const int offt = simd_w * acc_type_size_; |
1014 | jit_tail_.uni_vmovups_maybe_tail( |
1015 | vmean_even_, vmmword[reg_ptr_mean_ + reg_off_c_]); |
1016 | jit_tail_.uni_vmovups_maybe_tail( |
1017 | vmean_odd_, vmmword[reg_ptr_mean_ + reg_off_c_ + offt]); |
1018 | jit_tail_.uni_vmovups_maybe_tail( |
1019 | vvar_even_, vmmword[reg_ptr_var_ + reg_off_c_]); |
1020 | jit_tail_.uni_vmovups_maybe_tail( |
1021 | vvar_odd_, vmmword[reg_ptr_var_ + reg_off_c_ + offt]); |
1022 | |
        // merge mean and variance from interleaved to plain layout when
        // needed
1024 | if (!pd_->stats_is_src()) { |
1025 | helper_vmovups_.merge_interleaved_to_plain( |
1026 | vmean_even_, vmean_odd_, vtmp_); |
1027 | helper_vmovups_.merge_interleaved_to_plain( |
1028 | vvar_even_, vvar_odd_, vtmp_); |
1029 | } |
1030 | uni_vmovups(vsqrtvar_even_, vvar_even_); |
1031 | uni_vaddps(vsqrtvar_even_, vsqrtvar_even_, veps_); |
1032 | uni_vsqrtps(vsqrtvar_even_, vsqrtvar_even_); |
1033 | vdivps(vsqrtvar_even_, vone_, vsqrtvar_even_); |
1034 | |
1035 | uni_vmovups(vsqrtvar_odd_, vvar_odd_); |
1036 | uni_vaddps(vsqrtvar_odd_, vsqrtvar_odd_, veps_); |
1037 | uni_vsqrtps(vsqrtvar_odd_, vsqrtvar_odd_); |
1038 | vdivps(vsqrtvar_odd_, vone_, vsqrtvar_odd_); |
1039 | } |
1040 | |
1041 | void compute_bnorm_avx2_ne_xf16( |
1042 | const bool is_c_blks_tail, bool stream_store_allowed) { |
1043 | const Vmm vsrc_even = v_; |
1044 | const Vmm vsrc_odd = vsrc_aux; |
1045 | if (is_c_blks_tail) { |
1046 | compute_bnorm( |
1047 | vsrc_even, vmean_, vsqrtvar_, stream_store_allowed, false); |
1048 | } else { |
1049 | helper_vmovups_( |
1050 | vsrc_even, vsrc_odd, vmmword[reg_ptr_src_ + reg_off_dat_]); |
1051 | helper_vmovups_.merge_interleaved_to_plain( |
1052 | vsrc_even, vsrc_odd, vtmp_); |
1053 | load_c_specifics(true); |
1054 | compute_bnorm(vsrc_even, vmean_even_, vsqrtvar_even_, |
1055 | stream_store_allowed, true); |
1056 | |
1057 | load_c_specifics(true, simd_w * acc_type_size_); |
1058 | compute_bnorm(vsrc_odd, vmean_odd_, vsqrtvar_odd_, |
1059 | stream_store_allowed, true, stride_C_ * data_type_size_); |
1060 | } |
1061 | } |
1062 | |
1063 | void compute_avx2_ne_xf16(bool stream_store_allowed) { |
1064 | Label label_C, label_S, label_C_tail, label_C_end, label_S_C_tail; |
1065 | mov(reg_C_, dword[PARAM_ADDR(C)]); |
1066 | L(label_C); |
1067 | { |
1068 | cmp(reg_C_, 1); |
1069 | jle(label_C_tail, T_NEAR); |
1070 | |
1071 | mov(reg_off_dat_, reg_off_dat_save_); |
1072 | load_two_c_mean_sqrtvar(); |
1073 | |
1074 | mov(reg_S_, dword[PARAM_ADDR(S)]); |
1075 | L(label_S); |
1076 | { |
1077 | compute_bnorm_avx2_ne_xf16(false, stream_store_allowed); |
1078 | |
1079 | add(reg_off_dat_, stride_S_ * data_type_size_); |
1080 | dec(reg_S_); |
1081 | jnz(label_S, T_NEAR); |
1082 | } |
1083 | add(reg_off_dat_save_, 2 * stride_C_ * data_type_size_); |
1084 | add(reg_off_c_, 2 * simd_w * acc_type_size_); |
1085 | |
1086 | sub(reg_C_, 2); |
1087 | jnz(label_C, T_NEAR); |
1088 | } |
1089 | |
1090 | L(label_C_tail); |
1091 | { |
1092 | cmp(reg_C_, 0); |
1093 | jz(label_C_end, T_NEAR); |
1094 | |
1095 | mov(reg_off_dat_, reg_off_dat_save_); |
1096 | load_c_specifics(false); |
1097 | |
1098 | mov(reg_S_, dword[PARAM_ADDR(S)]); |
1099 | L(label_S_C_tail); |
1100 | { |
1101 | compute_bnorm_avx2_ne_xf16(true, stream_store_allowed); |
1102 | |
1103 | add(reg_off_dat_, stride_S_ * data_type_size_); |
1104 | dec(reg_S_); |
1105 | jnz(label_S_C_tail, T_NEAR); |
1106 | } |
1107 | } |
1108 | L(label_C_end); |
1109 | } |
1110 | |
1111 | void compute_blocked(bool stream_store_allowed) { |
1112 | Label label_C, label_S; |
1113 | mov(reg_C_, dword[PARAM_ADDR(C)]); |
1114 | L(label_C); |
1115 | { |
1116 | mov(reg_off_dat_, reg_off_dat_save_); |
1117 | |
1118 | load_c_specifics(false); |
1119 | |
1120 | mov(reg_S_, dword[PARAM_ADDR(S)]); |
1121 | L(label_S); |
1122 | { |
1123 | compute_bnorm( |
1124 | v_, vmean_, vsqrtvar_, stream_store_allowed, false); |
1125 | |
1126 | add(reg_off_dat_, stride_S_ * data_type_size_); |
1127 | |
1128 | dec(reg_S_); |
1129 | jnz(label_S); |
1130 | } |
1131 | |
1132 | add(reg_off_dat_save_, stride_C_ * data_type_size_); |
1133 | add(reg_off_c_, simd_w * acc_type_size_); |
1134 | |
1135 | dec(reg_C_); |
1136 | jnz(label_C); |
1137 | } |
1138 | } |
1139 | |
1140 | void compute(bool stream_store_allowed) { |
1141 | Label label_N; |
1142 | mov(reg_N_, ptr[rsp + stack_off_N]); |
1143 | L(label_N); |
1144 | { |
1145 | // save reg_N_, because register is shared with reg_ptr_shift_ |
1146 | mov(ptr[rsp + stack_off_N], reg_N_); |
1147 | mov(reg_ptr_shift_, ptr[rsp + stack_off_shift]); |
1148 | |
1149 | xor_(reg_off_dat_save_, reg_off_dat_save_); |
1150 | xor_(reg_off_c_, reg_off_c_); |
1151 | |
1152 | is_avx2_ne_xf16_ ? compute_avx2_ne_xf16(stream_store_allowed) |
1153 | : compute_blocked(stream_store_allowed); |
1154 | |
1155 | if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) { |
1156 | xor_(reg_off_dat_save_, reg_off_dat_save_); |
1157 | xor_(reg_off_c_, reg_off_c_); |
1158 | add(reg_off_dat_save_, vlen / 2); |
1159 | add(reg_off_c_, vlen / 2); |
1160 | |
1161 | compute_blocked(stream_store_allowed); |
1162 | } |
1163 | |
1164 | add(reg_ptr_src_, stride_N_ * data_type_size_); |
1165 | add(reg_ptr_dst_, stride_N_ * data_type_size_); |
1166 | if (jit_relu_.with_relu_ && !jit_relu_.with_relu_inf_only_) |
1167 | add(reg_ptr_ws_, stride_N_ / bits_per_byte); |
1168 | |
1169 | // restore reg_N_, because register is shared with reg_ptr_shift_ |
1170 | mov(reg_N_, ptr[rsp + stack_off_N]); |
1171 | dec(reg_N_); |
1172 | jnz(label_N); |
1173 | } |
1174 | } |
1175 | |
1176 | jit_bnorm_fwd_t(const batch_normalization_pd_t *pd, |
1177 | const jit_memory_tag_kind_t tag_kind) |
1178 | : jit_generator(jit_name()) |
1179 | , pd_(pd) |
1180 | , tag_kind_(tag_kind) |
1181 | , vlen(get_vlen<isa>(tag_kind)) |
1182 | , simd_w(get_simd_w<isa>(tag_kind)) |
1183 | , is_avx2_ne_xf16_(is_avx2_ne_xf16<isa>(pd)) |
1184 | , jit_tail_(pd, this, reg_tmp_, reg_blk_has_tail_, reg_C_, vtail_mask_, |
1185 | ktail_mask_) |
1186 | , jit_relu_(pd, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_, |
1187 | vstore_mask_, kstore_mask_, valpha, vmask, reg_alpha_) |
1188 | , helper_vmovups_(pd, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) { |
        static_assert(utils::one_of(isa, sse41, avx2, avx512_core),
                "unsupported isa");
1191 | |
1192 | std::tie(stride_N_, stride_S_, stride_C_) |
1193 | = get_data_strides<isa>(pd_, tag_kind); |
1194 | |
1195 | data_type_size_ = types::data_type_size(pd->src_md()->data_type); |
1196 | acc_type_size_ = sizeof(acc_data_t); |
1197 | } |
1198 | |
1199 | void generate() override { |
1200 | bool is_xf16 = utils::one_of( |
1201 | pd_->src_md()->data_type, data_type::bf16, data_type::f16); |
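        // xf16 stores go through data-type conversion and an nspc tail needs
        // masked stores; neither works with non-temporal vmovntps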
1202 | const bool is_tail_in_nspc_format |
1203 | = tag_kind_ == jit_memory_tag_kind_t::nspc |
1204 | && jit_tail_.tail_ != 0; |
1205 | const bool stream_store_allowed = !is_xf16 && !is_tail_in_nspc_format; |
1206 | |
1207 | preamble(); |
1208 | if (helper_vmovups_.bf16_emu_) |
1209 | helper_vmovups_.bf16_emu_->init_vcvtneps2bf16(); |
1210 | sub(rsp, stack_size_required); |
1211 | load_common_params(); |
1212 | jit_relu_.fwd_prepare_relu(); |
1213 | jit_tail_.prepare_tail(); |
1214 | |
1215 | Label normal_store, end_store; |
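        // non-temporal stores require a vlen-aligned destination; fall back
        // to regular stores otherwise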
1216 | test(reg_ptr_dst_, vlen - 1); |
1217 | jnz(normal_store, T_NEAR); |
1218 | compute(stream_store_allowed); |
1219 | jmp(end_store, T_NEAR); |
1220 | L(normal_store); |
1221 | { compute(false); } |
1222 | L(end_store); |
1223 | |
1224 | add(rsp, stack_size_required); |
1225 | postamble(); |
1226 | } |
1227 | }; |
1228 | |
1229 | template <cpu_isa_t isa> |
1230 | struct jit_bnorm_bwd_t : public jit_generator { |
1231 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_bwd_t) |
1232 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
1233 | |
1234 | const AddressFrame &vmmword |
1235 | = (isa == sse41) ? xword : (isa == avx2) ? yword : zword; |
1236 | |
1237 | struct call_params_t { |
1238 | size_t N, C, S; |
1239 | const void *src, *diff_src, *diff_dst; |
1240 | const uint8_t *ws; |
1241 | const acc_data_t *mean, *var; |
1242 | const acc_data_t *scale, *diff_scale, *diff_shift; |
1243 | size_t blk_has_tail; |
1244 | }; |
1245 | |
1246 | const Reg64 reg_param_ = abi_param1; |
1247 | const Reg64 reg_tmp_ = abi_not_param1; |
1248 | const Reg64 reg_N_ = rsi; |
1249 | const Reg64 reg_S_ = rax; |
1250 | const Reg64 reg_C_ = rdx; |
1251 | const Reg64 reg_off_c_ = rbx; |
1252 | const Reg64 reg_blk_has_tail_ = rbp; |
1253 | |
1254 | const Reg64 reg_off_dat_ = r8; |
1255 | const Reg64 reg_off_dat_save_ = r9; |
1256 | const Reg64 reg_ptr_c_ = r10; |
1257 | const Reg64 reg_ptr_ws_ = r11; |
1258 | const Reg64 reg_ptr_diff_dst_ = r12; |
1259 | const Reg64 reg_ptr_diff_src_ = r13; |
1260 | const Reg64 reg_ptr_src_ = r14; |
1261 | |
1262 | const Vmm vzero_ = Vmm(0); |
1263 | const Vmm vone_ = Vmm(1); |
1264 | const Vmm vmean_ = Vmm(2); |
1265 | const Vmm vsqrtvar_ = Vmm(3); |
1266 | const Vmm vgamma_ = Vmm(4); |
1267 | const Vmm vdiff_gamma_ = Vmm(5); |
1268 | const Vmm vdiff_beta_ = Vmm(6); |
1269 | const Vmm veps_ = Vmm(7); |
1270 | const Vmm vNS_ = Vmm(8); |
1271 | const Vmm vtmp_ = Vmm(9); |
1272 | const Vmm v_ = Vmm(10); |
1273 | const Vmm vtail_mask_ = Vmm(11); |
1274 | const Vmm vstore_mask_ = vtmp_; |
1275 | |
1276 | const Opmask kstore_mask_ = k1; |
1277 | const Opmask ktail_mask_ = k2; |
1278 | |
1279 | const batch_normalization_pd_t *pd_; |
1280 | const jit_memory_tag_kind_t tag_kind_; |
1281 | const int vlen; |
1282 | const int simd_w; |
1283 | jit_bnorm_process_tail_t<isa> jit_tail_; |
1284 | jit_bnorm_process_relu_t<isa> jit_relu_; |
1285 | helper_vmovups_data_t<isa> helper_vmovups_; |
1286 | int stride_N_, stride_S_, stride_C_; |
1287 | size_t data_type_size_, acc_type_size_; |
1288 | |
1289 | void load_common_params() { |
1290 | #define PARAM_PTR(x) ptr[PARAM_ADDR(x)] |
1291 | mov(reg_ptr_src_, PARAM_PTR(src)); |
1292 | mov(reg_ptr_diff_src_, PARAM_PTR(diff_src)); |
1293 | mov(reg_ptr_diff_dst_, PARAM_PTR(diff_dst)); |
1294 | mov(reg_ptr_ws_, PARAM_PTR(ws)); |
1295 | #undef PARAM_PTR |
1296 | |
1297 | Xmm x = Xmm(v_.getIdx()); |
1298 | |
1299 | mov(reg_tmp_, float2int(pd_->desc()->batch_norm_epsilon)); |
1300 | uni_vmovq(x, reg_tmp_); |
1301 | uni_vbroadcastss(veps_, x); |
1302 | |
1303 | mov(reg_tmp_, float2int(1.f)); |
1304 | uni_vmovq(x, reg_tmp_); |
1305 | uni_vbroadcastss(vone_, x); |
1306 | |
1307 | const int S = pd_->D() * pd_->H() * pd_->W(); |
1308 | mov(reg_tmp_, float2int(pd_->MB() * S)); |
1309 | uni_vmovq(x, reg_tmp_); |
1310 | uni_vbroadcastss(vNS_, x); |
1311 | |
1312 | mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]); |
1313 | } |
1314 | |
1315 | void load_c_specifics() { |
1316 | mov(reg_ptr_c_, ptr[PARAM_ADDR(mean)]); |
1317 | jit_tail_.uni_vmovups_maybe_tail( |
1318 | vmean_, vmmword[reg_ptr_c_ + reg_off_c_]); |
1319 | |
1320 | mov(reg_ptr_c_, ptr[PARAM_ADDR(var)]); |
1321 | jit_tail_.uni_vmovups_maybe_tail( |
1322 | vsqrtvar_, vmmword[reg_ptr_c_ + reg_off_c_]); |
1323 | uni_vaddps(vsqrtvar_, vsqrtvar_, veps_); |
1324 | uni_vsqrtps(vsqrtvar_, vsqrtvar_); |
1325 | |
1326 | if (isa == sse41) { |
1327 | movups(vtmp_, vone_); |
1328 | divps(vtmp_, vsqrtvar_); |
1329 | movups(vsqrtvar_, vtmp_); |
1330 | } else |
1331 | vdivps(vsqrtvar_, vone_, vsqrtvar_); |
1332 | |
1333 | if (pd_->use_scale()) { |
1334 | mov(reg_ptr_c_, ptr[PARAM_ADDR(scale)]); |
1335 | jit_tail_.uni_vmovups_maybe_tail( |
1336 | vgamma_, vmmword[reg_ptr_c_ + reg_off_c_]); |
1337 | } |
1338 | |
1339 | if (calculate_diff_stats()) { |
1340 | mov(reg_ptr_c_, ptr[PARAM_ADDR(diff_scale)]); |
1341 | jit_tail_.uni_vmovups_maybe_tail( |
1342 | vdiff_gamma_, vmmword[reg_ptr_c_ + reg_off_c_]); |
1343 | uni_vmulps(vdiff_gamma_, vdiff_gamma_, vsqrtvar_); |
1344 | uni_vdivps(vdiff_gamma_, vdiff_gamma_, vNS_); |
1345 | mov(reg_ptr_c_, ptr[PARAM_ADDR(diff_shift)]); |
1346 | jit_tail_.uni_vmovups_maybe_tail( |
1347 | vdiff_beta_, vmmword[reg_ptr_c_ + reg_off_c_]); |
1348 | uni_vdivps(vdiff_beta_, vdiff_beta_, vNS_); |
1349 | } |
1350 | } |
1351 | |
1352 | void compute_bnorm(bool stream_store_allowed) { |
1353 | helper_vmovups_(v_, vmmword[reg_ptr_diff_dst_ + reg_off_dat_]); |
1354 | jit_relu_.bwd_process_relu(v_); |
1355 | |
1356 | if (calculate_diff_stats()) { |
1357 | uni_vsubps(v_, v_, vdiff_beta_); |
1358 | helper_vmovups_(vtmp_, vmmword[reg_ptr_src_ + reg_off_dat_]); |
1359 | uni_vsubps(vtmp_, vtmp_, vmean_); |
1360 | uni_vmulps(vtmp_, vtmp_, vdiff_gamma_); |
1361 | uni_vsubps(v_, v_, vtmp_); |
1362 | } |
1363 | |
1364 | if (pd_->use_scale()) uni_vmulps(v_, v_, vgamma_); |
1365 | uni_vmulps(v_, v_, vsqrtvar_); |
1366 | |
1367 | if (stream_store_allowed) { |
1368 | uni_vmovntps(vmmword[reg_ptr_diff_src_ + reg_off_dat_], v_); |
1369 | } else { |
1370 | helper_vmovups_(vmmword[reg_ptr_diff_src_ + reg_off_dat_], v_); |
1371 | } |
1372 | } |
1373 | |
1374 | void compute_blocked(bool stream_store_allowed) { |
1375 | Label label_C, label_S; |
1376 | mov(reg_C_, dword[PARAM_ADDR(C)]); |
1377 | L(label_C); |
1378 | { |
1379 | mov(reg_off_dat_, reg_off_dat_save_); |
1380 | |
1381 | load_c_specifics(); |
1382 | |
1383 | mov(reg_S_, dword[PARAM_ADDR(S)]); |
1384 | L(label_S); |
1385 | { |
1386 | compute_bnorm(stream_store_allowed); |
1387 | |
1388 | add(reg_off_dat_, stride_S_ * data_type_size_); |
1389 | |
1390 | dec(reg_S_); |
1391 | jnz(label_S); |
1392 | } |
1393 | |
1394 | add(reg_off_dat_save_, stride_C_ * data_type_size_); |
1395 | add(reg_off_c_, simd_w * acc_type_size_); |
1396 | |
1397 | dec(reg_C_); |
1398 | jnz(label_C); |
1399 | } |
1400 | } |
1401 | |
1402 | void compute_nspc(bool stream_store_allowed) { |
1403 | Label label_C, label_S; |
1404 | mov(reg_S_, dword[PARAM_ADDR(S)]); |
1405 | L(label_S); |
1406 | { |
1407 | mov(reg_off_dat_, reg_off_dat_save_); |
1408 | xor_(reg_off_c_, reg_off_c_); |
1409 | |
1410 | mov(reg_C_, dword[PARAM_ADDR(C)]); |
1411 | L(label_C); |
1412 | { |
1413 | load_c_specifics(); |
1414 | |
1415 | compute_bnorm(stream_store_allowed); |
1416 | |
1417 | add(reg_off_c_, simd_w * acc_type_size_); |
1418 | add(reg_off_dat_, stride_C_ * data_type_size_); |
1419 | |
1420 | dec(reg_C_); |
1421 | jnz(label_C); |
1422 | } |
1423 | |
1424 | add(reg_off_dat_save_, stride_S_ * data_type_size_); |
1425 | |
1426 | dec(reg_S_); |
1427 | jnz(label_S); |
1428 | } |
1429 | } |
1430 | |
1431 | void compute(bool stream_store_allowed) { |
1432 | Label label_N; |
1433 | mov(reg_N_, dword[PARAM_ADDR(N)]); |
1434 | L(label_N); |
1435 | { |
1436 | xor_(reg_off_dat_save_, reg_off_dat_save_); |
1437 | xor_(reg_off_c_, reg_off_c_); |
1438 | |
1439 | tag_kind_ == jit_memory_tag_kind_t::nspc |
1440 | ? compute_nspc(stream_store_allowed) |
1441 | : compute_blocked(stream_store_allowed); |
1442 | |
1443 | if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) { |
1444 | xor_(reg_off_dat_save_, reg_off_dat_save_); |
1445 | xor_(reg_off_c_, reg_off_c_); |
1446 | add(reg_off_dat_save_, vlen / 2); |
1447 | add(reg_off_c_, vlen / 2); |
1448 | |
1449 | compute_blocked(stream_store_allowed); |
1450 | } |
1451 | |
1452 | add(reg_ptr_src_, stride_N_ * data_type_size_); |
1453 | add(reg_ptr_diff_src_, stride_N_ * data_type_size_); |
1454 | add(reg_ptr_diff_dst_, stride_N_ * data_type_size_); |
1455 | add(reg_ptr_ws_, stride_N_ / bits_per_byte); |
1456 | |
1457 | dec(reg_N_); |
1458 | jnz(label_N); |
1459 | } |
1460 | } |
1461 | |
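    // with global statistics the mean and variance are external constants,
    // so the diff_gamma/diff_beta correction terms drop out of diff_src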
1462 | bool calculate_diff_stats() const { return !pd_->use_global_stats(); } |
1463 | |
1464 | jit_bnorm_bwd_t(const batch_normalization_pd_t *pd, |
1465 | const jit_memory_tag_kind_t tag_kind) |
1466 | : jit_generator(jit_name()) |
1467 | , pd_(pd) |
1468 | , tag_kind_(tag_kind) |
1469 | , vlen(get_vlen<isa>(tag_kind)) |
1470 | , simd_w(get_simd_w<isa>(tag_kind)) |
1471 | , jit_tail_(pd, this, reg_tmp_, reg_blk_has_tail_, reg_C_, vtail_mask_, |
1472 | ktail_mask_) |
1473 | , jit_relu_(pd, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_, |
1474 | vstore_mask_, kstore_mask_) |
1475 | , helper_vmovups_(pd, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) { |
        static_assert(utils::one_of(isa, sse41, avx2, avx512_core),
                "unsupported isa");
1478 | |
1479 | std::tie(stride_N_, stride_S_, stride_C_) |
1480 | = get_data_strides<isa>(pd_, tag_kind); |
1481 | |
1482 | data_type_size_ = types::data_type_size(pd->src_md()->data_type); |
1483 | acc_type_size_ = sizeof(acc_data_t); |
1484 | } |
1485 | |
1486 | void generate() override { |
1487 | bool is_bf16 = pd_->src_md()->data_type == data_type::bf16; |
1488 | bool is_f16 = pd_->src_md()->data_type == data_type::f16; |
1489 | const bool is_tail_in_nspc_format |
1490 | = tag_kind_ == jit_memory_tag_kind_t::nspc |
1491 | && jit_tail_.tail_ != 0; |
1492 | const bool stream_store_allowed |
1493 | = !is_bf16 && !is_f16 && !is_tail_in_nspc_format; |
1494 | |
1495 | preamble(); |
1496 | if (helper_vmovups_.bf16_emu_) |
1497 | helper_vmovups_.bf16_emu_->init_vcvtneps2bf16(); |
1498 | load_common_params(); |
1499 | jit_relu_.bwd_prepare_relu(); |
1500 | jit_tail_.prepare_tail(); |
1501 | |
1502 | Label normal_store, end_store; |
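        // same destination-alignment check as in the fwd kernel, applied to
        // diff_src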
1503 | test(reg_ptr_diff_src_, vlen - 1); |
1504 | jnz(normal_store, T_NEAR); |
1505 | compute(stream_store_allowed); |
1506 | jmp(end_store, T_NEAR); |
1507 | L(normal_store); |
1508 | { compute(false); } |
1509 | L(end_store); |
1510 | |
1511 | postamble(); |
1512 | } |
1513 | }; |
1514 | |
1515 | template <cpu_isa_t isa> |
1516 | struct jit_bnorm_bwd_diff_ss_t : public jit_generator { |
1517 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_bwd_diff_ss_t) |
1518 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
1519 | |
1520 | const AddressFrame &vmmword |
1521 | = (isa == sse41) ? xword : (isa == avx2) ? yword : zword; |
1522 | |
1523 | struct call_params_t { |
1524 | size_t N, C, S; |
1525 | const void *src, *diff_dst; |
1526 | const uint8_t *ws; |
1527 | const acc_data_t *mean, *var; |
1528 | const acc_data_t *diff_gamma, *diff_beta; |
1529 | size_t blk_has_tail; |
1530 | }; |
1531 | |
1532 | const Reg64 reg_param_ = abi_param1; |
1533 | const Reg64 reg_tmp_ = abi_not_param1; |
1534 | const Reg64 reg_N_ = rsi; |
1535 | const Reg64 reg_S_ = rax; |
1536 | const Reg64 reg_C_ = rdx; |
1537 | const Reg64 reg_off_c_ = rbx; |
1538 | const Reg64 reg_blk_has_tail_ = rbp; |
1539 | |
1540 | const Reg64 reg_off_dat_ = r8; |
1541 | const Reg64 reg_off_dat_save_ = r9; |
1542 | const Reg64 reg_ptr_c_ = r10; |
1543 | const Reg64 reg_ptr_diff_gamma_ = r11; |
1544 | const Reg64 reg_ptr_diff_beta_ = r12; |
1545 | const Reg64 reg_ptr_ws_ = r13; |
1546 | const Reg64 reg_ptr_diff_dst_ = r14; |
1547 | const Reg64 reg_ptr_src_ = r15; |
1548 | |
1549 | const Vmm vtail_mask_ = Vmm(0); |
1550 | const Vmm v_ = Vmm(1); |
1551 | const Vmm vtmp_ = Vmm(2); |
1552 | const Vmm vstore_mask_ = vtmp_; |
1553 | const Vmm vzero_ = Vmm(3); |
1554 | const Vmm veps_ = Vmm(4); |
1555 | const Vmm vone_ = Vmm(5); |
    // diff_beta, diff_gamma and one of the statistics (mean or sqrtvar) are
    // unrolled, i.e. three vmms are needed to unroll one c block at any
    // moment, so the number of registers used for unrolling must be
    // divisible by three.
1560 | static constexpr int min_idx_to_unroll_ = 6; |
1561 | static constexpr int max_idx_to_unroll_ = isa == avx512_core ? 27 : 15; |
1562 | static constexpr int number_of_unrolled_variables_ = 3; |
1563 | static constexpr int number_of_vmms_to_unrolling_variables_ |
1564 | = max_idx_to_unroll_ - min_idx_to_unroll_; |
    static_assert(number_of_vmms_to_unrolling_variables_
                            % number_of_unrolled_variables_
                            == 0
                    && number_of_vmms_to_unrolling_variables_ != 0,
            "Number of registers used for unrolling must be divisible by 3.");
1570 | |
1571 | const Opmask kstore_mask_ = k1; |
1572 | const Opmask ktail_mask_ = k2; |
1573 | |
1574 | const batch_normalization_pd_t *pd_; |
1575 | const jit_memory_tag_kind_t tag_kind_; |
1576 | const int vlen; |
1577 | const int simd_w; |
1578 | jit_bnorm_process_tail_t<isa> jit_tail_; |
1579 | jit_bnorm_process_relu_t<isa> jit_relu_; |
1580 | helper_vmovups_data_t<isa> helper_vmovups_; |
1581 | int stride_N_, stride_S_, stride_C_; |
1582 | size_t data_type_size_, acc_type_size_; |
1583 | |
1584 | void load_common_params() { |
1585 | #define PARAM_PTR(x) ptr[PARAM_ADDR(x)] |
1586 | mov(reg_ptr_src_, PARAM_PTR(src)); |
1587 | mov(reg_ptr_diff_dst_, PARAM_PTR(diff_dst)); |
1588 | mov(reg_ptr_ws_, PARAM_PTR(ws)); |
1589 | mov(reg_ptr_diff_gamma_, PARAM_PTR(diff_gamma)); |
1590 | mov(reg_ptr_diff_beta_, PARAM_PTR(diff_beta)); |
1591 | #undef PARAM_PTR |
1592 | |
        Xmm x = Xmm(v_.getIdx());

        // Broadcast scalar constants: move the IEEE-754 bit pattern of the
        // float through a GPR into an Xmm, then splat it across the vector.
        mov(reg_tmp_, float2int(pd_->desc()->batch_norm_epsilon));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(veps_, x);

        mov(reg_tmp_, float2int(1.f));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(vone_, x);
1602 | |
1603 | mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]); |
1604 | } |
1605 | |
1606 | void zeroise() { |
1607 | Label label_zeroise; |
1608 | xor_(reg_off_c_, reg_off_c_); |
1609 | uni_vpxor(vzero_, vzero_, vzero_); |
1610 | mov(reg_C_, dword[PARAM_ADDR(C)]); |
1611 | L(label_zeroise); |
1612 | { |
1613 | jit_tail_.uni_vmovups_maybe_tail( |
1614 | vmmword[reg_ptr_diff_gamma_ + reg_off_c_], vzero_); |
1615 | jit_tail_.uni_vmovups_maybe_tail( |
1616 | vmmword[reg_ptr_diff_beta_ + reg_off_c_], vzero_); |
1617 | if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) { |
1618 | jit_tail_.uni_vmovups_maybe_tail( |
1619 | vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + vlen / 2], |
1620 | vzero_); |
1621 | jit_tail_.uni_vmovups_maybe_tail( |
1622 | vmmword[reg_ptr_diff_beta_ + reg_off_c_ + vlen / 2], |
1623 | vzero_); |
1624 | } |
1625 | add(reg_off_c_, simd_w * acc_type_size_); |
1626 | dec(reg_C_); |
1627 | jnz(label_zeroise); |
1628 | } |
1629 | } |
1630 | |
1631 | void load_mean(const int c_blks_to_unroll = 1) { |
1632 | mov(reg_ptr_c_, ptr[PARAM_ADDR(mean)]); |
1633 | |
1634 | const int start_idx = min_idx_to_unroll_; |
1635 | const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll |
1636 | + min_idx_to_unroll_; |
1637 | const int step = simd_w * acc_type_size_; |
1638 | |
1639 | for (int idx = start_idx, off = 0; idx < end_idx; |
1640 | idx += number_of_unrolled_variables_, off += step) { |
1641 | const Vmm vmean = Vmm(idx); |
1642 | |
1643 | jit_tail_.uni_vmovups_maybe_tail( |
1644 | vmean, vmmword[reg_ptr_c_ + reg_off_c_ + off]); |
1645 | } |
1646 | } |
1647 | |
1648 | void zeroise_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) { |
1649 | const int start_idx = min_idx_to_unroll_; |
1650 | const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll |
1651 | + min_idx_to_unroll_; |
1652 | |
1653 | for (int idx = start_idx; idx < end_idx; |
1654 | idx += number_of_unrolled_variables_) { |
1655 | const Vmm vdiff_beta = Vmm(idx + 1); |
1656 | const Vmm vdiff_gamma = Vmm(idx + 2); |
1657 | |
1658 | uni_vpxor(vdiff_beta, vdiff_beta, vdiff_beta); |
1659 | uni_vpxor(vdiff_gamma, vdiff_gamma, vdiff_gamma); |
1660 | } |
1661 | } |
1662 | |
1663 | void load_and_prepare_sqrtvar(const int c_blks_to_unroll = 1) { |
1664 | mov(reg_ptr_c_, ptr[PARAM_ADDR(var)]); |
1665 | |
1666 | const int start_idx = min_idx_to_unroll_; |
1667 | const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll |
1668 | + min_idx_to_unroll_; |
1669 | const int step = simd_w * acc_type_size_; |
1670 | |
1671 | for (int idx = start_idx, off = 0; idx < end_idx; |
1672 | idx += number_of_unrolled_variables_, off += step) { |
1673 | const Vmm vsqrtvar = Vmm(idx); |
1674 | |
1675 | jit_tail_.uni_vmovups_maybe_tail( |
1676 | vsqrtvar, vmmword[reg_ptr_c_ + reg_off_c_ + off]); |
1677 | |
1678 | // 1.0 / sqrt(var + eps) |
1679 | uni_vaddps(vsqrtvar, vsqrtvar, veps_); |
1680 | uni_vsqrtps(vsqrtvar, vsqrtvar); |
1681 | |
1682 | if (isa == sse41) { |
1683 | movups(vtmp_, vone_); |
1684 | divps(vtmp_, vsqrtvar); |
1685 | movups(vsqrtvar, vtmp_); |
1686 | } else |
1687 | vdivps(vsqrtvar, vone_, vsqrtvar); |
1688 | } |
1689 | } |
1690 | |
1691 | void compute_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) { |
1692 | const int start_idx = min_idx_to_unroll_; |
1693 | const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll |
1694 | + min_idx_to_unroll_; |
1695 | const int step = simd_w * data_type_size_; |
1696 | |
1697 | for (int idx = start_idx, off = 0; idx < end_idx; |
1698 | idx += number_of_unrolled_variables_, off += step) { |
1699 | const Vmm vmean = Vmm(idx); |
1700 | const Vmm vdiff_beta = Vmm(idx + 1); |
1701 | const Vmm vdiff_gamma = Vmm(idx + 2); |
1702 | |
1703 | helper_vmovups_( |
1704 | v_, vmmword[reg_ptr_diff_dst_ + reg_off_dat_ + off]); |
1705 | |
1706 | jit_relu_.bwd_process_relu( |
1707 | v_, off / (bits_per_byte * data_type_size_)); |
1708 | |
1709 | // diff_beta |
1710 | uni_vaddps(vdiff_beta, vdiff_beta, v_); |
1711 | |
1712 | helper_vmovups_(vtmp_, vmmword[reg_ptr_src_ + reg_off_dat_ + off]); |
1713 | |
1714 | // diff_gamma, note that diff_gamma will be multiplied |
1715 | // by sqrtvar before store |
1716 | uni_vsubps(vtmp_, vtmp_, vmean); |
1717 | uni_vfmadd231ps(vdiff_gamma, vtmp_, v_); |
1718 | } |
1719 | } |
1720 | |
1721 | void store_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) { |
1722 | const int start_idx = min_idx_to_unroll_; |
1723 | const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll |
1724 | + min_idx_to_unroll_; |
1725 | const int step = simd_w * acc_type_size_; |
1726 | |
1727 | for (int idx = start_idx, off = 0; idx < end_idx; |
1728 | idx += number_of_unrolled_variables_, off += step) { |
1729 | const Vmm vdiff_beta = Vmm(idx + 1); |
1730 | |
1731 | jit_tail_.uni_vmovups_maybe_tail( |
1732 | vtmp_, vmmword[reg_ptr_diff_beta_ + reg_off_c_ + off]); |
1733 | uni_vaddps(vdiff_beta, vdiff_beta, vtmp_); |
1734 | jit_tail_.uni_vmovups_maybe_tail( |
1735 | vmmword[reg_ptr_diff_beta_ + reg_off_c_ + off], vdiff_beta); |
1736 | } |
1737 | |
1738 | for (int idx = start_idx, off = 0; idx < end_idx; |
1739 | idx += number_of_unrolled_variables_, off += step) { |
1740 | const Vmm vsqrtvar = Vmm(idx); |
1741 | const Vmm vdiff_gamma = Vmm(idx + 2); |
1742 | |
1743 | // multiply diff_gamma by 1.0/sqrt(var + eps) |
1744 | uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar); |
1745 | |
1746 | jit_tail_.uni_vmovups_maybe_tail( |
1747 | vtmp_, vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + off]); |
1748 | uni_vaddps(vdiff_gamma, vdiff_gamma, vtmp_); |
1749 | jit_tail_.uni_vmovups_maybe_tail( |
1750 | vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + off], |
1751 | vdiff_gamma); |
1752 | } |
1753 | } |
1754 | |
1755 | void compute_blocked() { |
1756 | Label label_C, label_S; |
1757 | mov(reg_C_, dword[PARAM_ADDR(C)]); |
1758 | L(label_C); |
1759 | { |
1760 | mov(reg_off_dat_, reg_off_dat_save_); |
1761 | |
1762 | load_mean(); |
1763 | zeroise_diff_beta_and_diff_gamma(); |
1764 | |
1765 | mov(reg_S_, dword[PARAM_ADDR(S)]); |
1766 | L(label_S); |
1767 | { |
1768 | compute_diff_beta_and_diff_gamma(); |
1769 | |
1770 | add(reg_off_dat_, stride_S_ * data_type_size_); |
1771 | |
1772 | dec(reg_S_); |
1773 | jnz(label_S); |
1774 | } |
1775 | |
1776 | load_and_prepare_sqrtvar(); |
1777 | store_diff_beta_and_diff_gamma(); |
1778 | |
1779 | add(reg_off_dat_save_, stride_C_ * data_type_size_); |
1780 | add(reg_off_c_, simd_w * acc_type_size_); |
1781 | |
1782 | dec(reg_C_); |
1783 | jnz(label_C); |
1784 | } |
1785 | } |
1786 | |
1787 | void compute_nspc() { |
1788 | mov(reg_C_, dword[PARAM_ADDR(C)]); |
1789 | |
1790 | constexpr int max_of_unrolled_c_blks |
1791 | = number_of_vmms_to_unrolling_variables_ |
1792 | / number_of_unrolled_variables_; |
1793 | std::vector<Label> c_unroll_label(max_of_unrolled_c_blks + 1); |
1794 | |
1795 | for (int c_blks_to_unroll = max_of_unrolled_c_blks; |
1796 | c_blks_to_unroll > 0; --c_blks_to_unroll) { |
1797 | L(c_unroll_label[c_blks_to_unroll]); |
1798 | { |
1799 | cmp(reg_C_, c_blks_to_unroll); |
1800 | jl(c_unroll_label[c_blks_to_unroll - 1], T_NEAR); |
1801 | |
1802 | mov(reg_off_dat_, reg_off_dat_save_); |
1803 | |
1804 | load_mean(c_blks_to_unroll); |
1805 | zeroise_diff_beta_and_diff_gamma(c_blks_to_unroll); |
1806 | |
1807 | Label label_S; |
1808 | mov(reg_S_, dword[PARAM_ADDR(S)]); |
1809 | L(label_S); |
1810 | { |
1811 | compute_diff_beta_and_diff_gamma(c_blks_to_unroll); |
1812 | |
1813 | add(reg_off_dat_, stride_S_ * data_type_size_); |
1814 | |
1815 | dec(reg_S_); |
1816 | jnz(label_S); |
1817 | } |
1818 | |
1819 | load_and_prepare_sqrtvar(c_blks_to_unroll); |
1820 | store_diff_beta_and_diff_gamma(c_blks_to_unroll); |
1821 | |
1822 | add(reg_off_c_, c_blks_to_unroll * simd_w * acc_type_size_); |
1823 | add(reg_off_dat_save_, |
1824 | c_blks_to_unroll * stride_C_ * data_type_size_); |
1825 | |
1826 | sub(reg_C_, c_blks_to_unroll); |
1827 | jmp(c_unroll_label[c_blks_to_unroll], T_NEAR); |
1828 | } |
1829 | } |
1830 | L(c_unroll_label[0]); |
1831 | } |
1832 | |
1833 | void compute() { |
1834 | Label label_N; |
1835 | mov(reg_N_, dword[PARAM_ADDR(N)]); |
1836 | L(label_N); |
1837 | { |
1838 | xor_(reg_off_dat_save_, reg_off_dat_save_); |
1839 | xor_(reg_off_c_, reg_off_c_); |
1840 | |
1841 | tag_kind_ == jit_memory_tag_kind_t::nspc ? compute_nspc() |
1842 | : compute_blocked(); |
1843 | |
1844 | if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) { |
1845 | xor_(reg_off_dat_save_, reg_off_dat_save_); |
1846 | xor_(reg_off_c_, reg_off_c_); |
1847 | add(reg_off_dat_save_, vlen / 2); |
1848 | add(reg_off_c_, vlen / 2); |
1849 | |
1850 | compute_blocked(); |
1851 | } |
1852 | |
1853 | add(reg_ptr_src_, stride_N_ * data_type_size_); |
1854 | add(reg_ptr_diff_dst_, stride_N_ * data_type_size_); |
1855 | add(reg_ptr_ws_, stride_N_ / bits_per_byte); |
1856 | |
1857 | dec(reg_N_); |
1858 | jnz(label_N); |
1859 | } |
1860 | } |
1861 | |
1862 | jit_bnorm_bwd_diff_ss_t(const batch_normalization_pd_t *pd, |
1863 | const jit_memory_tag_kind_t tag_kind) |
1864 | : jit_generator(jit_name()) |
1865 | , pd_(pd) |
1866 | , tag_kind_(tag_kind) |
1867 | , vlen(get_vlen<isa>(tag_kind)) |
1868 | , simd_w(get_simd_w<isa>(tag_kind)) |
1869 | , jit_tail_(pd, this, reg_tmp_, reg_blk_has_tail_, reg_C_, vtail_mask_, |
1870 | ktail_mask_) |
1871 | , jit_relu_(pd, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_, |
1872 | vstore_mask_, kstore_mask_) |
1873 | , helper_vmovups_(pd, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) { |
1874 | static_assert(utils::one_of(isa, sse41, avx2, avx512_core), |
1875 | "unsupported isa" ); |
1876 | |
1877 | std::tie(stride_N_, stride_S_, stride_C_) |
1878 | = get_data_strides<isa>(pd_, tag_kind); |
1879 | |
1880 | data_type_size_ = types::data_type_size(pd->src_md()->data_type); |
1881 | acc_type_size_ = sizeof(acc_data_t); |
1882 | } |
1883 | |
1884 | void generate() override { |
1885 | preamble(); |
1886 | load_common_params(); |
1887 | jit_relu_.bwd_prepare_relu(); |
1888 | jit_tail_.prepare_tail(); |
1889 | zeroise(); |
1890 | compute(); |
1891 | postamble(); |
1892 | } |
1893 | }; |
1894 | |
1895 | namespace bnorm_tbb_impl { |
1896 | |
1897 | template <cpu_isa_t isa> |
1898 | struct driver_t : public c_compatible { |
1899 | private: |
1900 | struct bnorm_dims_t { |
1901 | dim_t N, C, S; |
1902 | dim_t glob; |
1903 | }; |
1904 | |
1905 | DNNL_DISALLOW_COPY_AND_ASSIGN(driver_t); |
1906 | |
1907 | public: |
1908 | driver_t(const batch_normalization_pd_t *pd, |
1909 | const jit_memory_tag_kind_t tag_kind) |
1910 | : pd_(pd), tag_kind_(tag_kind), simd_w(get_simd_w<isa>(tag_kind)) { |
1911 | nthr_ = dnnl_get_max_threads(); |
1912 | N_ = pd_->MB(); |
1913 | S_ = pd_->D() * pd_->H() * pd_->W(); |
1914 | C_ = pd_->C(); |
1915 | C_blks_ = get_c_padded(pd_) / simd_w; |
1916 | |
1917 | const size_t l3_size = platform::get_per_core_cache_size(3) * nthr_ / 2; |
1918 | int num_tensors = pd_->is_fwd() ? 1 : 2; |
1919 | dt_size_ = types::data_type_size(pd_->src_md()->data_type); |
1920 | const size_t working_set_size |
1921 | = dt_size_ * N_ * S_ * simd_w * num_tensors; |
1922 | |
1923 | do_blocking_ = tag_kind_ == jit_memory_tag_kind_t::nspc |
1924 | ? false |
1925 | : working_set_size * C_blks_ >= l3_size / 2 && l3_size > 0; |
1926 | |
1927 | if (tag_kind_ == jit_memory_tag_kind_t::nspc) { |
1928 | if (normalize_only(pd_)) { |
                // Blocks have to fit in a fourth of L1 so that they do not
                // get evicted. There are at most 6 tensors touched: src,
                // dst, mean, var, scale, shift.
1931 | dim_t n_tensors = 2 + pd_->use_scale() + pd_->use_shift(); |
1932 | C_blk_step_ = utils::saturate<dim_t>(1, C_blks_, |
1933 | platform::get_per_core_cache_size(1) |
1934 | / get_vlen<isa>(jit_memory_tag_kind_t::nspc) |
1935 | / n_tensors); |
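                // For example (L1 size assumed for illustration): a 32 KiB
                // per-core L1 with avx512_core (vlen = 64 bytes) and both
                // scale and shift in use (n_tensors = 4) gives
                // C_blk_step_ = saturate(1, C_blks_, 32768 / 64 / 4)
                //             = saturate(1, C_blks_, 128).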
1936 | } else |
1937 | C_blk_step_ = C_blks_; |
1938 | } else { |
1939 | C_blk_step_ = utils::saturate<dim_t>( |
1940 | 1, C_blks_, l3_size / working_set_size); |
1941 | } |
1942 | } |
1943 | |
1944 | status_t create_kernel() { |
1945 | if (pd_->is_fwd()) { |
1946 | CHECK(safe_ptr_assign( |
1947 | ker_fwd_, new jit_bnorm_fwd_t<isa>(pd_, tag_kind_))); |
1948 | CHECK(ker_fwd_->create_kernel()); |
1949 | if (!pd_->stats_is_src()) { |
1950 | CHECK(safe_ptr_assign(ker_fwd_mean_, |
1951 | new jit_bnorm_fwd_mean_t<isa>(pd_, tag_kind_))); |
1952 | CHECK(safe_ptr_assign(ker_fwd_var_, |
1953 | new jit_bnorm_fwd_var_t<isa>(pd_, tag_kind_))); |
1954 | CHECK(ker_fwd_mean_->create_kernel()); |
1955 | CHECK(ker_fwd_var_->create_kernel()); |
1956 | } |
1957 | } else { |
1958 | CHECK(safe_ptr_assign( |
1959 | ker_bwd_, new jit_bnorm_bwd_t<isa>(pd_, tag_kind_))); |
1960 | CHECK(safe_ptr_assign(ker_bwd_diff_ss_, |
1961 | new jit_bnorm_bwd_diff_ss_t<isa>(pd_, tag_kind_))); |
1962 | CHECK(ker_bwd_->create_kernel()); |
1963 | CHECK(ker_bwd_diff_ss_->create_kernel()); |
1964 | } |
1965 | return status::success; |
1966 | } |
1967 | |
1968 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
1969 | const batch_normalization_pd_t *pd) { |
1970 | |
1971 | int nthrs = dnnl_get_max_threads(); |
1972 | int C_PADDED = get_c_padded(pd); |
1973 | |
1974 | auto sbuf_sz = use_tmp_stats(pd) * 2 * C_PADDED; |
1975 | auto pbuf_sz |
1976 | = (use_tmp_diff_scale(pd) + use_tmp_diff_shift(pd)) * C_PADDED; |
1977 | auto rbuf_sz = (pd->is_fwd() ? 1 : 2) * C_PADDED * nthrs; |
1978 | |
1979 | scratchpad.book<acc_data_t>(key_bnorm_tmp_stats, sbuf_sz); |
1980 | scratchpad.book<acc_data_t>(key_bnorm_tmp_diff_ss, pbuf_sz); |
1981 | scratchpad.book<acc_data_t>(key_bnorm_reduction, rbuf_sz); |
1982 | } |
1983 | |
1984 | void exec_fwd_step_stats(const dim_t C_blks, const bnorm_dims_t &nthr, |
1985 | const void *src, acc_data_t *mean, acc_data_t *var, |
1986 | acc_data_t *rbuf, bool blk_has_tail) { |
1987 | size_t stride_C, stride_N, stride_S; |
1988 | std::tie(stride_N, stride_S, stride_C) |
1989 | = get_data_strides<isa>(pd_, tag_kind_); |
1990 | |
1991 | const int nthr_NS = nthr.N * nthr.S; |
1992 | const bool need_reduction = nthr_NS > 1; |
1993 | const dim_t tail_size = blk_has_tail ? C_ % simd_w : simd_w; |
1994 | |
1995 | const dim_t size_C_stat = (C_blks - 1) * simd_w + tail_size; |
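        // reduce() below merges the per-thread partial sums (one
        // size_C_stat-sized slice per NS-thread in rbuf) into the output
        // buffer and then divides by the population size N_ * S_.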
1996 | |
1997 | auto reduce = [&](acc_data_t *stat, acc_data_t *r_stat) { |
1998 | if (!need_reduction) return; |
1999 | acc_data_t *loc_stat = r_stat; |
2000 | |
2001 | for (dim_t c = 0; c < size_C_stat; ++c) |
2002 | stat[c] = loc_stat[c]; |
2003 | |
2004 | for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) { |
2005 | loc_stat += size_C_stat; |
2006 | for (dim_t c = 0; c < size_C_stat; ++c) |
2007 | stat[c] += loc_stat[c]; |
2008 | } |
2009 | |
2010 | for (dim_t c = 0; c < size_C_stat; ++c) |
2011 | stat[c] /= N_ * S_; |
2012 | }; |
2013 | |
2014 | // find local mean |
2015 | acc_data_t *r_mean = need_reduction ? rbuf : mean; |
2016 | parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) { |
2017 | assert(nthr_glob == nthr.glob); |
2018 | const auto ithr = map_thread(ithr_glob, nthr); |
2019 | bnorm_dims_t start, stop; |
2020 | work_distribution(C_blks, ithr, nthr, start, stop); |
2021 | |
2022 | auto c = typename jit_bnorm_fwd_mean_t<isa>::call_params_t(); |
2023 | c.N = stop.N - start.N; |
2024 | c.C = stop.C - start.C; |
2025 | c.S = stop.S - start.S; |
2026 | |
2027 | const size_t d_off = start.N * stride_N + start.C * stride_C |
2028 | + start.S * stride_S; |
2029 | c.src = (void *)((char *)src + d_off * dt_size_); |
2030 | const int ithr_NS = ithr.N * nthr.S + ithr.S; |
2031 | c.mean = &r_mean[ithr_NS * size_C_stat + start.C * simd_w]; |
2032 | c.blk_has_tail = blk_has_tail && stop.C == C_blks; |
2033 | c.do_normalise = !need_reduction; |
2034 | (*ker_fwd_mean_)(&c); |
2035 | }); |
2036 | |
2037 | // mean reduction |
2038 | reduce(mean, r_mean); |
2039 | |
2040 | // find local var |
2041 | acc_data_t *r_var = need_reduction ? rbuf : var; |
2042 | parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) { |
2043 | assert(nthr_glob == nthr.glob); |
2044 | const auto ithr = map_thread(ithr_glob, nthr); |
2045 | bnorm_dims_t start, stop; |
2046 | work_distribution(C_blks, ithr, nthr, start, stop); |
2047 | |
2048 | auto c = typename jit_bnorm_fwd_var_t<isa>::call_params_t(); |
2049 | c.N = stop.N - start.N; |
2050 | c.C = stop.C - start.C; |
2051 | c.S = stop.S - start.S; |
2052 | |
2053 | const size_t d_off = start.N * stride_N + start.C * stride_C |
2054 | + start.S * stride_S; |
2055 | c.src = (void *)((char *)src + d_off * dt_size_); |
2056 | const int ithr_NS = ithr.N * nthr.S + ithr.S; |
2057 | c.mean = &mean[start.C * simd_w]; |
2058 | c.var = &r_var[ithr_NS * size_C_stat + start.C * simd_w]; |
2059 | c.blk_has_tail = blk_has_tail && stop.C == C_blks; |
2060 | c.do_normalise = !need_reduction; |
2061 | (*ker_fwd_var_)(&c); |
2062 | }); |
2063 | |
2064 | // var reduction |
2065 | reduce(var, r_var); |
2066 | } |
2067 | |
2068 | void exec_fwd_step_normalization(const dim_t C_blks, |
2069 | const bnorm_dims_t &nthr, const void *src, void *dst, |
2070 | const acc_data_t *scale, const acc_data_t *shift, |
2071 | const acc_data_t *mean, const acc_data_t *var, uint8_t *ws, |
2072 | bool blk_has_tail) { |
2073 | size_t stride_C, stride_N, stride_S; |
2074 | std::tie(stride_N, stride_S, stride_C) |
2075 | = get_data_strides<isa>(pd_, tag_kind_); |
2076 | |
2077 | parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) { |
2078 | assert(nthr_glob == nthr.glob); |
2079 | const auto ithr = map_thread(ithr_glob, nthr); |
2080 | bnorm_dims_t start, stop; |
2081 | work_distribution(C_blks, ithr, nthr, start, stop); |
2082 | |
2083 | auto c = typename jit_bnorm_fwd_t<isa>::call_params_t(); |
2084 | c.N = stop.N - start.N; |
2085 | c.C = stop.C - start.C; |
2086 | c.S = stop.S - start.S; |
2087 | const size_t d_off = start.N * stride_N + start.C * stride_C |
2088 | + start.S * stride_S; |
2089 | c.src = (void *)((char *)src + d_off * dt_size_); |
2090 | c.dst = (void *)((char *)dst + d_off * dt_size_); |
2091 | c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr; |
2092 | c.mean = &mean[start.C * simd_w]; |
2093 | c.var = &var[start.C * simd_w]; |
2094 | c.scale = scale ? &scale[start.C * simd_w] : nullptr; |
2095 | c.shift = shift ? &shift[start.C * simd_w] : nullptr; |
2096 | c.blk_has_tail = blk_has_tail && stop.C == C_blks; |
2097 | (*ker_fwd_)(&c); |
2098 | }); |
2099 | } |
2100 | |
2101 | void exec_fwd(const void *src, void *dst, const acc_data_t *scale, |
2102 | const acc_data_t *shift, acc_data_t *mean, acc_data_t *var, |
2103 | uint8_t *ws, const memory_tracking::grantor_t &scratchpad) { |
2104 | auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction); |
2105 | if (use_tmp_stats(pd_)) { |
2106 | auto sbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_stats); |
2107 | mean = sbuf; |
2108 | var = sbuf + C_blks_ * simd_w; |
2109 | } |
2110 | |
2111 | size_t stride_C; |
2112 | std::tie(std::ignore, std::ignore, stride_C) |
2113 | = get_data_strides<isa>(pd_, tag_kind_); |
2114 | |
2115 | dim_t C_blk_step = C_blk_step_; |
2116 | auto nthr = bnorm_dims_t(); |
2117 | |
2118 | thread_distribution(C_blk_step, nthr); |
2119 | |
2120 | for (dim_t C_blk_st = 0; C_blk_st < C_blks_; C_blk_st += C_blk_step) { |
2121 | if (C_blk_st + C_blk_step > C_blks_) { |
2122 | C_blk_step = C_blks_ - C_blk_st; |
2123 | thread_distribution(C_blk_step, nthr); |
2124 | } |
2125 | |
2126 | if (!pd_->stats_is_src()) { |
2127 | exec_fwd_step_stats(C_blk_step, nthr, |
2128 | (void *)((char *)src |
2129 | + (C_blk_st * stride_C) * dt_size_), |
2130 | mean + C_blk_st * simd_w, var + C_blk_st * simd_w, rbuf, |
2131 | (C_blk_st + C_blk_step) * simd_w > C_); |
2132 | } |
2133 | exec_fwd_step_normalization(C_blk_step, nthr, |
2134 | (void *)((char *)src + (C_blk_st * stride_C) * dt_size_), |
2135 | (void *)((char *)dst + (C_blk_st * stride_C) * dt_size_), |
2136 | scale + C_blk_st * simd_w, shift + C_blk_st * simd_w, |
2137 | mean + C_blk_st * simd_w, var + C_blk_st * simd_w, |
2138 | ws + C_blk_st * stride_C / bits_per_byte, |
2139 | (C_blk_st + C_blk_step) * simd_w > C_); |
2140 | } |
2141 | } |
2142 | |
2143 | void exec_bwd_step_diff_ss(const dim_t C_blks, const bnorm_dims_t &nthr, |
2144 | const void *src, const void *diff_dst, const acc_data_t *mean, |
2145 | const acc_data_t *var, const uint8_t *ws, acc_data_t *diff_scale, |
2146 | acc_data_t *diff_shift, acc_data_t *rbuf, bool blk_has_tail) { |
2147 | size_t stride_C, stride_N, stride_S; |
2148 | std::tie(stride_N, stride_S, stride_C) |
2149 | = get_data_strides<isa>(pd_, tag_kind_); |
2150 | |
2151 | const dim_t tail_size = blk_has_tail ? C_ % simd_w : simd_w; |
2152 | const dim_t size_C_stat = (C_blks - 1) * simd_w + tail_size; |
2153 | |
2154 | const int nthr_NS = nthr.N * nthr.S; |
2155 | const bool need_reduction = nthr_NS > 1; |
2156 | |
2157 | acc_data_t *diff_gamma = diff_scale; |
2158 | acc_data_t *diff_beta = diff_shift; |
2159 | |
2160 | acc_data_t *const r_diff_gamma = need_reduction ? rbuf : diff_gamma; |
2161 | acc_data_t *const r_diff_beta |
2162 | = need_reduction ? rbuf + nthr_NS * size_C_stat : diff_beta; |
2163 | |
2164 | auto reduce = [&]() { |
2165 | if (!need_reduction) return; |
2166 | |
2167 | // diff_gamma |
2168 | const acc_data_t *loc_diff_gamma = r_diff_gamma; |
2169 | for (dim_t c = 0; c < size_C_stat; ++c) |
2170 | diff_gamma[c] = loc_diff_gamma[c]; |
2171 | for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) { |
2172 | loc_diff_gamma += size_C_stat; |
2173 | for (dim_t c = 0; c < size_C_stat; ++c) |
2174 | diff_gamma[c] += loc_diff_gamma[c]; |
2175 | } |
2176 | |
2177 | // diff_beta |
2178 | const acc_data_t *loc_diff_beta = r_diff_beta; |
2179 | for (dim_t c = 0; c < size_C_stat; ++c) |
2180 | diff_beta[c] = loc_diff_beta[c]; |
2181 | for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) { |
2182 | loc_diff_beta += size_C_stat; |
2183 | for (dim_t c = 0; c < size_C_stat; ++c) |
2184 | diff_beta[c] += loc_diff_beta[c]; |
2185 | } |
2186 | }; |
2187 | |
2188 | parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) { |
2189 | assert(nthr_glob == nthr.glob); |
2190 | const auto ithr = map_thread(ithr_glob, nthr); |
2191 | bnorm_dims_t start, stop; |
2192 | work_distribution(C_blks, ithr, nthr, start, stop); |
2193 | |
2194 | const int ithr_NS = ithr.N * nthr.S + ithr.S; |
2195 | acc_data_t *loc_diff_gamma = &r_diff_gamma[ithr_NS * size_C_stat]; |
2196 | acc_data_t *loc_diff_beta = &r_diff_beta[ithr_NS * size_C_stat]; |
2197 | |
2198 | auto c = typename jit_bnorm_bwd_diff_ss_t<isa>::call_params_t(); |
2199 | c.N = stop.N - start.N; |
2200 | c.C = stop.C - start.C; |
2201 | c.S = stop.S - start.S; |
2202 | |
2203 | const size_t d_off = start.N * stride_N + start.C * stride_C |
2204 | + start.S * stride_S; |
2205 | c.src = (void *)((char *)src + d_off * dt_size_); |
2206 | c.diff_dst = (void *)((char *)diff_dst + d_off * dt_size_); |
2207 | c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr; |
2208 | c.mean = &mean[start.C * simd_w]; |
2209 | c.var = &var[start.C * simd_w]; |
2210 | c.diff_gamma = &loc_diff_gamma[start.C * simd_w]; |
2211 | c.diff_beta = &loc_diff_beta[start.C * simd_w]; |
2212 | c.blk_has_tail = blk_has_tail && stop.C == C_blks; |
2213 | |
2214 | (*ker_bwd_diff_ss_)(&c); |
2215 | }); |
2216 | |
2217 | reduce(); |
2218 | } |
2219 | |
2220 | void exec_bwd_step_normalization(const dim_t C_blks, |
2221 | const bnorm_dims_t &nthr, const void *src, void *diff_src, |
2222 | const void *diff_dst, const acc_data_t *mean, const acc_data_t *var, |
2223 | const uint8_t *ws, const acc_data_t *scale, |
2224 | const acc_data_t *diff_scale, const acc_data_t *diff_shift, |
2225 | bool blk_has_tail) { |
2226 | size_t stride_C, stride_N, stride_S; |
2227 | std::tie(stride_N, stride_S, stride_C) |
2228 | = get_data_strides<isa>(pd_, tag_kind_); |
2229 | |
2230 | parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) { |
2231 | assert(nthr_glob == nthr.glob); |
2232 | const auto ithr = map_thread(ithr_glob, nthr); |
2233 | bnorm_dims_t start, stop; |
2234 | work_distribution(C_blks, ithr, nthr, start, stop); |
2235 | |
2236 | auto c = typename jit_bnorm_bwd_t<isa>::call_params_t(); |
2237 | c.N = stop.N - start.N; |
2238 | c.C = stop.C - start.C; |
2239 | c.S = stop.S - start.S; |
2240 | |
2241 | const size_t d_off = start.N * stride_N + start.C * stride_C |
2242 | + start.S * stride_S; |
2243 | c.src = (void *)((char *)src + d_off * dt_size_); |
2244 | c.diff_src = (void *)((char *)diff_src + d_off * dt_size_); |
2245 | c.diff_dst = (void *)((char *)diff_dst + d_off * dt_size_); |
2246 | c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr; |
2247 | c.mean = &mean[start.C * simd_w]; |
2248 | c.var = &var[start.C * simd_w]; |
2249 | c.scale = scale ? &scale[start.C * simd_w] : nullptr; |
2250 | c.diff_scale = &diff_scale[start.C * simd_w]; |
2251 | c.diff_shift = &diff_shift[start.C * simd_w]; |
2252 | c.blk_has_tail = blk_has_tail && stop.C == C_blks; |
2253 | |
2254 | (*ker_bwd_)(&c); |
2255 | }); |
2256 | } |
2257 | |
2258 | void exec_bwd(const void *src, void *diff_src, const void *diff_dst, |
2259 | const acc_data_t *scale, acc_data_t *diff_scale, |
2260 | acc_data_t *diff_shift, const acc_data_t *mean, |
2261 | const acc_data_t *var, const uint8_t *ws, |
2262 | const memory_tracking::grantor_t &scratchpad) { |
2263 | auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction); |
2264 | if (use_tmp_diff_scale(pd_)) { |
2265 | auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss); |
2266 | diff_scale = pbuf; |
2267 | } |
2268 | if (use_tmp_diff_shift(pd_)) { |
2269 | auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss); |
2270 | size_t shift_off = use_tmp_diff_scale(pd_) ? pd_->C() : 0; |
2271 | diff_shift = &pbuf[shift_off]; |
2272 | } |
2273 | |
2274 | size_t stride_C; |
2275 | std::tie(std::ignore, std::ignore, stride_C) |
2276 | = get_data_strides<isa>(pd_, tag_kind_); |
2277 | |
2278 | dim_t C_blk_step = C_blk_step_; |
2279 | auto nthr = bnorm_dims_t(); |
2280 | |
2281 | thread_distribution(C_blk_step, nthr); |
2282 | |
2283 | for (dim_t C_blk_st = 0; C_blk_st < C_blks_; C_blk_st += C_blk_step) { |
2284 | if (C_blk_st + C_blk_step > C_blks_) { |
2285 | C_blk_step = C_blks_ - C_blk_st; |
2286 | thread_distribution(C_blk_step, nthr); |
2287 | } |
2288 | |
2289 | exec_bwd_step_diff_ss(C_blk_step, nthr, |
2290 | (void *)((char *)src + (C_blk_st * stride_C) * dt_size_), |
2291 | (void *)((char *)diff_dst |
2292 | + (C_blk_st * stride_C) * dt_size_), |
2293 | mean + C_blk_st * simd_w, var + C_blk_st * simd_w, |
2294 | ws + C_blk_st * stride_C / bits_per_byte, |
2295 | diff_scale + C_blk_st * simd_w, |
2296 | diff_shift + C_blk_st * simd_w, rbuf, |
2297 | (C_blk_st + C_blk_step) * simd_w > C_); |
2298 | |
2299 | exec_bwd_step_normalization(C_blk_step, nthr, |
2300 | (void *)((char *)src + (C_blk_st * stride_C) * dt_size_), |
2301 | (void *)((char *)diff_src |
2302 | + (C_blk_st * stride_C) * dt_size_), |
2303 | (void *)((char *)diff_dst |
2304 | + (C_blk_st * stride_C) * dt_size_), |
2305 | mean + C_blk_st * simd_w, var + C_blk_st * simd_w, |
2306 | ws + C_blk_st * stride_C / bits_per_byte, |
2307 | scale + C_blk_st * simd_w, diff_scale + C_blk_st * simd_w, |
2308 | diff_shift + C_blk_st * simd_w, |
2309 | (C_blk_st + C_blk_step) * simd_w > C_); |
2310 | } |
2311 | } |
2312 | |
2313 | private: |
2314 | static bool use_tmp_stats(const batch_normalization_pd_t *pd) { |
2315 | return !pd->stats_is_src() |
2316 | && pd->desc()->prop_kind == prop_kind::forward_inference; |
2317 | } |
2318 | |
2319 | static bool use_tmp_diff_scale(const batch_normalization_pd_t *pd) { |
2320 | return (!pd->is_fwd() && !pd->use_scale()) |
2321 | || pd->desc()->prop_kind == prop_kind::backward_data; |
2322 | } |
2323 | |
2324 | static bool use_tmp_diff_shift(const batch_normalization_pd_t *pd) { |
2325 | return (!pd->is_fwd() && !pd->use_shift()) |
2326 | || pd->desc()->prop_kind == prop_kind::backward_data; |
2327 | } |
2328 | |
2329 | void thread_distribution_nspc(dim_t C_blks, bnorm_dims_t &nthr) { |
2330 | if (normalize_only(pd_)) { |
            // We want to keep some granularity on S so that we can stay on
            // one socket if possible; this is why we divide the work into
            // chunks that fit in L2.
2334 | |
2335 | dim_t n_stats_ss_tensors = pd_->use_scale() + pd_->use_shift(); |
2336 | dim_t size_stats_ss_tensors = n_stats_ss_tensors * get_c_padded(pd_) |
2337 | * sizeof(acc_data_t); |
2338 | |
2339 | dim_t size_src_dst = 2 * N_ * S_ * get_c_padded(pd_) |
2340 | * types::data_type_size(pd_->src_md()->data_type); |
2341 | |
2342 | dim_t total_size = size_src_dst + size_stats_ss_tensors; |
2343 | |
2344 | // Try to create at least nthr_ chunks for realtime inference. Not |
2345 | // enabled for throughput inference to avoid potential regressions |
2346 | // for multi-socket runs with threadpool runtime. |
2347 | // TODO: Enable for throughput inference. |
2348 | const int n_chunks_min = nthr_ <= 8 ? nthr_ : 1; |
2349 | const size_t l2_per_core = platform::get_per_core_cache_size(2); |
2350 | dim_t n_chunks |
2351 | = nstl::max<dim_t>(n_chunks_min, total_size / l2_per_core); |
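
            // Illustrative figures (assumed, not measured): total_size
            // = 64 MiB with a 1 MiB per-core L2 gives
            // n_chunks = max(n_chunks_min, 64).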
2352 | |
2353 | // we prioritize parallelization on N, then S, and finally C |
2354 | nthr.N = utils::saturate<dim_t>(1, N_, n_chunks); |
2355 | nthr.S = utils::saturate<dim_t>(1, S_, n_chunks / nthr.N); |
2356 | nthr.C = utils::saturate<dim_t>( |
2357 | 1, C_blks, n_chunks / (nthr.N * nthr.S)); |
2358 | } else { |
2359 | if ((nthr_ <= C_blks && nthr_ == 1) || C_blks <= 8) |
2360 | nthr.C = 1; |
2361 | else if (nthr_ >= 8 && C_blks <= 32) |
2362 | nthr.C = 8; |
2363 | else { |
2364 | nthr.C = math::gcd((dim_t)nthr_, C_blks); |
2365 | // Unroll by channels in JIT kernel |
2366 | if ((nthr.C == C_blks) || (nthr.C == nthr_)) nthr.C = 1; |
2367 | } |
2368 | nthr.N = utils::saturate((dim_t)1, N_, nthr_ / nthr.C); |
2369 | nthr.S = utils::saturate((dim_t)1, S_, nthr_ / (nthr.C * nthr.N)); |
2370 | } |
2371 | } |
2372 | |
2373 | void thread_distribution(dim_t C_blks, bnorm_dims_t &nthr) { |
2374 | if (do_blocking_) { |
2375 | nthr.N = nstl::min<dim_t>(N_, nthr_); |
2376 | nthr.C = nstl::min<dim_t>(C_blks, nthr_ / nthr.N); |
2377 | nthr.S = utils::saturate((dim_t)1, S_, nthr_ / (nthr.C * nthr.N)); |
2378 | } else { |
2379 | if (tag_kind_ == jit_memory_tag_kind_t::nspc) { |
2380 | thread_distribution_nspc(C_blks, nthr); |
2381 | } else { |
2382 | nthr.C = math::gcd((dim_t)nthr_, C_blks); |
2383 | nthr.N = utils::saturate((dim_t)1, N_, nthr_ / nthr.C); |
2384 | nthr.S = utils::saturate( |
2385 | (dim_t)1, S_, nthr_ / (nthr.C * nthr.N)); |
2386 | } |
2387 | } |
2388 | nthr.glob = nthr.N * nthr.C * nthr.S; |
2389 | } |
2390 | |
2391 | int map_thread_c(int ithr_glob, const bnorm_dims_t &nthr) { |
2392 | return ithr_glob / nthr.N / nthr.S; |
2393 | } |
2394 | |
2395 | bnorm_dims_t map_thread(int ithr_glob, const bnorm_dims_t &nthr) { |
2396 | auto ithr = bnorm_dims_t(); |
2397 | ithr.glob = ithr_glob; |
2398 | ithr.C = map_thread_c(ithr.glob, nthr); |
2399 | ithr.N = ithr.glob / nthr.S % nthr.N; |
2400 | ithr.S = ithr.glob % nthr.S; |
2401 | return ithr; |
2402 | } |
2403 | |
2404 | void work_distribution_c(dim_t C_blks, int ithr_c, int nthr_c, |
2405 | dim_t &start_c, dim_t &stop_c) { |
2406 | balance211(C_blks, nthr_c, ithr_c, start_c, stop_c); |
2407 | } |
2408 | |
2409 | void work_distribution(dim_t C_blks, const bnorm_dims_t &ithr, |
2410 | const bnorm_dims_t &nthr, bnorm_dims_t &start, bnorm_dims_t &stop) { |
2411 | work_distribution_c(C_blks, ithr.C, nthr.C, start.C, stop.C); |
2412 | balance211(N_, nthr.N, ithr.N, start.N, stop.N); |
2413 | balance211(S_, nthr.S, ithr.S, start.S, stop.S); |
2414 | } |
2415 | |
2416 | const batch_normalization_pd_t *pd_; |
2417 | const jit_memory_tag_kind_t tag_kind_; |
2418 | const int simd_w; |
2419 | |
2420 | bool do_blocking_; |
2421 | |
2422 | int nthr_; |
2423 | |
    dim_t N_, S_; // MB, D * H * W
    dim_t C_, C_blks_; // C, C / simd_w
    dim_t C_blk_step_; // for C_blk_st = 0 .. C_blks_, += C_blk_step_
2427 | |
2428 | std::unique_ptr<jit_bnorm_fwd_t<isa>> ker_fwd_; |
2429 | std::unique_ptr<jit_bnorm_fwd_mean_t<isa>> ker_fwd_mean_; |
2430 | std::unique_ptr<jit_bnorm_fwd_var_t<isa>> ker_fwd_var_; |
2431 | std::unique_ptr<jit_bnorm_bwd_t<isa>> ker_bwd_; |
2432 | std::unique_ptr<jit_bnorm_bwd_diff_ss_t<isa>> ker_bwd_diff_ss_; |
2433 | |
2434 | size_t dt_size_; |
2435 | }; |
2436 | } // namespace bnorm_tbb_impl |
2437 | |
2438 | using namespace data_type; |
2439 | using namespace format_tag; |
2440 | using namespace utils; |
2441 | |
2442 | /* fwd */ |
2443 | template <cpu_isa_t isa> |
2444 | status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::pd_t::init( |
2445 | engine_t *engine) { |
2446 | const bool ok = is_fwd() && mayiuse(isa) && !has_zero_dim_memory() |
2447 | && one_of(src_md()->data_type, f32, bf16, f16) |
2448 | && src_md()->data_type == dst_md()->data_type |
2449 | && IMPLICATION(src_md()->data_type == bf16, |
2450 | is_superset(isa, avx512_core) |
2451 | || (isa == avx2 && mayiuse(avx2_vnni_2))) |
2452 | // Note: re-using avx512_core/avx2 implementation for f16. |
2453 | // This is okay as currently, we do not support binary post-ops |
2454 | // for this primitive. |
2455 | && IMPLICATION(src_md()->data_type == f16, |
2456 | (is_superset(isa, avx512_core) && mayiuse(avx512_core_fp16)) |
2457 | || (isa == avx2 && mayiuse(avx2_vnni_2))) |
2458 | && check_scale_shift_data_type() |
2459 | && (attr()->has_default_values() |
2460 | || with_relu_post_op(is_training())) |
2461 | && set_default_formats_common() |
2462 | && memory_desc_wrapper(src_md()) == memory_desc_wrapper(dst_md()); |
2463 | if (!ok) return status::unimplemented; |
2464 | |
2465 | // BN+Add+Relu fusion is not currently implemented |
2466 | if (fuse_norm_add_relu()) return status::unimplemented; |
2467 | |
2468 | const format_tag_t blocked_tag = is_superset(isa, avx512_core) |
2469 | ? utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c) |
2470 | : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); |
2471 | |
2472 | const format_tag_t blocked_format |
2473 | = memory_desc_matches_tag(*src_md(), blocked_tag) |
2474 | ? blocked_tag |
2475 | : format_tag::undef; |
2476 | const format_tag_t nspc_format |
2477 | = memory_desc_matches_one_of_tag(*src_md(), nc, nwc, nhwc, ndhwc); |
2478 | |
2479 | if (memory_desc_matches_tag(*dst_md(), blocked_format)) |
2480 | tag_kind_ = jit_memory_tag_kind_t::blocked; |
2481 | else if (memory_desc_matches_tag(*dst_md(), nspc_format)) { |
2482 | tag_kind_ = jit_memory_tag_kind_t::nspc; |
2483 | const int simd_w = get_simd_w<isa>(tag_kind_); |
2484 | if (C() % simd_w != 0) return status::unimplemented; |
2485 | } else |
2486 | return status::unimplemented; |
2487 | |
2488 | // AVX2 only supports xf16 on plain layout and inference |
2489 | if (utils::one_of(src_md()->data_type, bf16, f16) && isa == avx2 |
2490 | && (is_training() |
2491 | || !memory_desc_matches_tag(*dst_md(), nspc_format))) |
2492 | return status::unimplemented; |
2493 | |
2494 | const bool isa_supports_avx2 = is_superset(isa, avx2); |
2495 | if (is_training() && fuse_norm_relu()) { |
2496 | if (!isa_supports_avx2) return status::unimplemented; |
2497 | init_default_ws(1); |
2498 | } |
2499 | |
2500 | if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() |
2501 | && !isa_supports_avx2) |
2502 | return status::unimplemented; |
2503 | |
2504 | auto scratchpad = scratchpad_registry().registrar(); |
2505 | bnorm_tbb_impl::driver_t<isa>::init_scratchpad(scratchpad, this); |
2506 | |
2507 | return status::success; |
2508 | } |
2509 | |
2510 | template <cpu_isa_t isa> |
2511 | jit_uni_tbb_batch_normalization_fwd_t< |
2512 | isa>::jit_uni_tbb_batch_normalization_fwd_t(const pd_t *apd) |
2513 | : primitive_t(apd) {} |
2514 | |
2515 | template <cpu_isa_t isa> |
2516 | status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::init(engine_t *engine) { |
2517 | CHECK(safe_ptr_assign(bnorm_driver_, |
2518 | new bnorm_tbb_impl::driver_t<isa>(pd(), pd()->tag_kind_))); |
2519 | return bnorm_driver_->create_kernel(); |
2520 | } |
2521 | |
2522 | template <cpu_isa_t isa> |
2523 | status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::execute( |
2524 | const exec_ctx_t &ctx) const { |
2525 | |
2526 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
2527 | auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE); |
2528 | auto shift = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SHIFT); |
2529 | |
2530 | auto mean = pd()->stats_is_src() ? const_cast<acc_data_t *>( |
2531 | CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN)) |
2532 | : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_MEAN); |
2533 | auto var = pd()->stats_is_src() |
2534 | ? const_cast<acc_data_t *>( |
2535 | CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE)) |
2536 | : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_VARIANCE); |
2537 | |
2538 | auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST); |
2539 | auto ws = CTX_OUT_MEM(uint8_t *, DNNL_ARG_WORKSPACE); |
2540 | |
2541 | auto scratchpad = ctx.get_scratchpad_grantor(); |
2542 | |
2543 | bnorm_driver_->exec_fwd(src, dst, scale, shift, mean, var, ws, scratchpad); |
2544 | |
2545 | return status::success; |
2546 | } |
2547 | |
2548 | template <cpu_isa_t isa> |
2549 | jit_uni_tbb_batch_normalization_fwd_t< |
2550 | isa>::~jit_uni_tbb_batch_normalization_fwd_t() |
2551 | = default; |
2552 | |
2553 | template struct jit_uni_tbb_batch_normalization_fwd_t<sse41>; |
2554 | template struct jit_uni_tbb_batch_normalization_fwd_t<avx2>; |
2555 | template struct jit_uni_tbb_batch_normalization_fwd_t<avx512_core>; |
2556 | |
2557 | /* bwd */ |
2558 | template <cpu_isa_t isa> |
2559 | status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::pd_t::init( |
2560 | engine_t *engine) { |
2561 | bool ok = !is_fwd() && mayiuse(isa) && !has_zero_dim_memory() |
2562 | && one_of(src_md()->data_type, f32, bf16, f16) |
2563 | && src_md()->data_type == diff_src_md()->data_type |
2564 | && diff_src_md()->data_type == diff_dst_md()->data_type |
2565 | && IMPLICATION( |
2566 | src_md()->data_type == bf16, is_superset(isa, avx512_core)) |
2567 | // Note: re-using avx512_core implementation for f16. This is okay |
2568 | // as currently, we do not support binary post-ops for this |
2569 | // primitive. |
2570 | && IMPLICATION(src_md()->data_type == f16, |
2571 | is_superset(isa, avx512_core) && mayiuse(avx512_core_fp16)) |
2572 | && check_scale_shift_data_type() && attr()->has_default_values() |
2573 | && set_default_formats_common() |
2574 | && memory_desc_wrapper(diff_src_md()) |
2575 | == memory_desc_wrapper(diff_dst_md()); |
2576 | if (!ok) return status::unimplemented; |
2577 | |
2578 | // BN+Add+Relu fusion is not currently implemented |
2579 | if (fuse_norm_add_relu()) return status::unimplemented; |
2580 | |
2581 | const format_tag_t blocked_tag = is_superset(isa, avx512_core) |
2582 | ? utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c) |
2583 | : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); |
2584 | |
2585 | const format_tag_t blocked_format |
2586 | = memory_desc_matches_tag(*src_md(), blocked_tag) |
2587 | ? blocked_tag |
2588 | : format_tag::undef; |
2589 | const format_tag_t nspc_format |
2590 | = memory_desc_matches_one_of_tag(*src_md(), nc, nwc, nhwc, ndhwc); |
2591 | |
2592 | if (memory_desc_matches_tag(*diff_src_md(), blocked_format)) |
2593 | tag_kind_ = jit_memory_tag_kind_t::blocked; |
2594 | else if (memory_desc_matches_tag(*diff_src_md(), nspc_format)) { |
2595 | tag_kind_ = jit_memory_tag_kind_t::nspc; |
2596 | const int simd_w = get_simd_w<isa>(tag_kind_); |
2597 | if (C() % simd_w != 0) return status::unimplemented; |
2598 | } else |
2599 | return status::unimplemented; |
2600 | |
2601 | const bool isa_supports_avx2 = is_superset(isa, avx2); |
2602 | if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() |
2603 | && !isa_supports_avx2) |
2604 | return status::unimplemented; |
2605 | |
2606 | if (fuse_norm_relu()) { |
2607 | if (!isa_supports_avx2) return status::unimplemented; |
2608 | init_default_ws(1); |
2609 | if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; |
2610 | } |
2611 | |
2612 | auto scratchpad = scratchpad_registry().registrar(); |
2613 | bnorm_tbb_impl::driver_t<isa>::init_scratchpad(scratchpad, this); |
2614 | |
2615 | return status::success; |
2616 | } |
2617 | |
2618 | template <cpu_isa_t isa> |
2619 | jit_uni_tbb_batch_normalization_bwd_t< |
2620 | isa>::jit_uni_tbb_batch_normalization_bwd_t(const pd_t *apd) |
2621 | : primitive_t(apd) {} |
2622 | |
2623 | template <cpu_isa_t isa> |
2624 | status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::init(engine_t *engine) { |
2625 | CHECK(safe_ptr_assign(bnorm_driver_, |
2626 | new bnorm_tbb_impl::driver_t<isa>(pd(), pd()->tag_kind_))); |
2627 | return bnorm_driver_->create_kernel(); |
2628 | } |
2629 | |
2630 | template <cpu_isa_t isa> |
2631 | status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::execute( |
2632 | const exec_ctx_t &ctx) const { |
2633 | |
2634 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
2635 | auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN); |
2636 | auto var = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE); |
2637 | auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST); |
2638 | auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE); |
2639 | auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE); |
2640 | |
2641 | auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC); |
2642 | auto diff_scale = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SCALE); |
2643 | auto diff_shift = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT); |
2644 | |
2645 | auto scratchpad = ctx.get_scratchpad_grantor(); |
2646 | |
2647 | bnorm_driver_->exec_bwd(src, diff_src, diff_dst, scale, diff_scale, |
2648 | diff_shift, mean, var, ws, scratchpad); |
2649 | |
2650 | return status::success; |
2651 | } |
2652 | |
2653 | template <cpu_isa_t isa> |
2654 | jit_uni_tbb_batch_normalization_bwd_t< |
2655 | isa>::~jit_uni_tbb_batch_normalization_bwd_t() |
2656 | = default; |
2657 | |
2658 | template struct jit_uni_tbb_batch_normalization_bwd_t<sse41>; |
2659 | template struct jit_uni_tbb_batch_normalization_bwd_t<avx2>; |
2660 | template struct jit_uni_tbb_batch_normalization_bwd_t<avx512_core>; |
2661 | } // namespace x64 |
2662 | } // namespace cpu |
2663 | } // namespace impl |
2664 | } // namespace dnnl |
2665 | |