/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_uni_batch_normalization_s8.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;

using data_t = int8_t;

struct jit_uni_bnorm_s8_call_params_t {
    // keep int sizes at 8 bytes -- jit code expects this
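    // channel_offt_count holds C (for s8 nhwc/ndhwc data this is also the
    // byte stride between consecutive spatial points); spat_offt_count holds
    // the total number of s8 elements a single kernel call processes, i.e.
    // the caller's share of spatial points times C (see driver_t::exec below).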
    size_t channel_offt_count, spat_offt_count;
    float eps;
    const float *scale, *shift, *mean, *var;
    const data_t *src, *dst;
};

template <cpu_isa_t isa>
struct jit_bnorm_base_t : public jit_generator {

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_s8_t)

    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    const AddressFrame &vmmword
            = (isa == sse41) ? xword : ((isa == avx2) ? yword : zword);
    const int vlen = cpu_isa_traits<isa>::vlen;

    const batch_normalization_pd_t *pd_;

    Reg64 reg_param = abi_param1;

    Reg64 reg_scale = rbx;
    Reg64 reg_shift = rdx;
    Reg64 reg_mean = rbp;

    Reg64 reg_channel_offt_count = r8;
    Reg64 reg_spat_offt = r9;
    Reg64 reg_spat_offt_count = r10;
    Reg64 reg_tmp = r11;
    Reg64 reg_src = r12;
    Reg64 reg_dst = r13;
    Reg64 reg_var = r14;
    Reg64 reg_channel_offt_1byte = r15;
    Reg64 reg_channel_offt_4byte = rax;
    Reg64 reg_relu_alpha = abi_not_param1;

    Opmask kstore_mask = Opmask(1);

    Vmm vzero = Vmm(isa == avx512_core ? 29 : 13);
    Xmm xone = Xmm(14);
    Vmm vone = Vmm(isa == avx512_core ? 30 : 14);
    Vmm veps = Vmm(isa == avx512_core ? 31 : 15);
    Vmm vmm_aux = Vmm(isa == avx512_core ? 28 : 10); // shared with 'veps'
    Vmm vmm_mask = Vmm(0); // used for AVX2 and SSE41

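    // simd_w_ is the number of f32 lanes per vector register; c_in_xmm_ is
    // the number of s8 channels converted per iteration (8 for sse41, 16 for
    // avx2 and avx512_core).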
    size_t simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
    size_t c_in_xmm_ = (isa == sse41) ? 8 : 16;
    size_t chan_data_offt_;
    size_t num_c_blocks_;
    size_t c_tail_;
    bool with_relu_;
    bool has_alpha_value_;

    void compute_predefined_variables() {
        chan_data_offt_ = pd_->C() * sizeof(float);
        num_c_blocks_ = pd_->C() / c_in_xmm_;
        c_tail_ = pd_->C() % c_in_xmm_;
        with_relu_ = (pd_->with_relu_post_op(false) || pd_->fuse_norm_relu())
                && pd_->is_fwd();
        has_alpha_value_ = with_relu_ && pd_->with_relu_post_op(false)
                && pd_->alpha() != 0;
    }

    void load_common_params() {
        mov(reg_tmp, float2int(1.0f));
        uni_vmovq(xone, reg_tmp);
        uni_vbroadcastss(vone, xone);

#define PARAM_OFF(x) offsetof(jit_uni_bnorm_s8_call_params_t, x)
        uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);
        uni_vpxor(vzero, vzero, vzero);

        mov(reg_channel_offt_count,
                ptr[reg_param + PARAM_OFF(channel_offt_count)]);
        mov(reg_spat_offt_count, ptr[reg_param + PARAM_OFF(spat_offt_count)]);
        mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
        mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
        mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
        mov(reg_scale, ptr[reg_param + PARAM_OFF(scale)]);
        mov(reg_shift, ptr[reg_param + PARAM_OFF(shift)]);
        mov(reg_var, ptr[reg_param + PARAM_OFF(var)]);
#undef PARAM_OFF

        if (has_alpha_value_) { mov(reg_relu_alpha, float2int(pd_->alpha())); }
    }

    Address mean_ptr(size_t offt = 0) {
        return vmmword[reg_mean + reg_channel_offt_4byte + offt];
    }

    Address var_ptr(size_t offt = 0) {
        return vmmword[reg_var + reg_channel_offt_4byte + offt];
    }

    Address scale_ptr(size_t offt = 0) {
        return vmmword[reg_scale + reg_channel_offt_4byte + offt
                + 0 * chan_data_offt_];
    }

    Address shift_ptr(size_t offt = 0) {
        return vmmword[reg_shift + reg_channel_offt_4byte + offt
                + 0 * chan_data_offt_];
    }

    Address src_ptr(size_t offt = 0) {
        return vmmword[reg_src + reg_spat_offt + offt];
    }

    Address dst_ptr(size_t offt = 0) {
        return vmmword[reg_dst + reg_spat_offt + offt];
    }

    virtual void prepare_tail_mask() {}
    virtual void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar,
            size_t offt, bool need_tail) {}
    virtual void load_scale(const Vmm &vscale, size_t offt, bool need_tail) {}
    virtual void load_shift(const Vmm &vshift, size_t offt, bool need_tail) {}
    virtual void compute_dst(bool need_tail) {}

    // Precomputes vscale and vshift for the subsequent computation
    // `vdst = vscale * vsrc + vshift`
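    //
    // Reference scalar math per channel c (what the vector code below
    // implements):
    //   sqrtvar = sqrt(var[c] + eps)
    //   vscale  = (use_scale ? scale[c] : 1.f) / sqrtvar
    //   vshift  = (use_shift ? shift[c] : 0.f) - mean[c] * vscale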
    void compute_vscaleshift(const Vmm &vscale, const Vmm &vshift,
            const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) {
        load_mean_and_var(vmean, vsqrtvar, offt, need_tail);
        uni_vaddps(vsqrtvar, vsqrtvar, veps);
        uni_vsqrtps(vsqrtvar, vsqrtvar);

        if (pd_->use_scale() && pd_->use_shift()) {
            load_scale(vscale, offt, need_tail);
            uni_vdivps(vscale, vscale, vsqrtvar);
            load_shift(vshift, offt, need_tail);
            uni_vfnmadd231ps(vshift, vmean, vscale);
        } else if (pd_->use_scale()) {
            load_scale(vscale, offt, need_tail);
            uni_vdivps(vscale, vscale, vsqrtvar);
            uni_vmulps(vmean, vmean, vscale);
            uni_vsubps(vshift, vzero, vmean, vshift);
        } else if (pd_->use_shift()) {
            uni_vdivps(vscale, vone, vsqrtvar, vscale);
            load_shift(vshift, offt, need_tail);
            uni_vfnmadd231ps(vshift, vmean, vscale);
        } else {
            uni_vdivps(vscale, vone, vsqrtvar, vscale);
            uni_vmulps(vmean, vmean, vscale);
            uni_vsubps(vshift, vzero, vmean, vshift);
        }
    }

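    // Process full c_in_xmm_-channel blocks first, then the channel tail
    // (c_tail_ = C % c_in_xmm_) in a separate pass.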
    void forward() {
        xor_(reg_channel_offt_1byte, reg_channel_offt_1byte);
        xor_(reg_channel_offt_4byte, reg_channel_offt_4byte);
        mov(reg_tmp, sizeof(data_t) * c_in_xmm_);

        if (num_c_blocks_) compute_dst(false);
        if (c_tail_) compute_dst(true);
    }

    // Either this stub is needed, or generate() would have to be duplicated
    // in each derived ctor, because the virtual methods it calls are not yet
    // defined when the base class ctor runs.
    void generate() override {
        preamble();
        compute_predefined_variables();
        load_common_params();
        prepare_tail_mask();
        forward();
        postamble();
    }

    jit_bnorm_base_t(const batch_normalization_pd_t *pd)
        : jit_generator(jit_name()), pd_(pd) {}
};

template <cpu_isa_t isa>
struct jit_bnorm_s8_t;

template <>
struct jit_bnorm_s8_t<avx512_core> : public jit_bnorm_base_t<avx512_core> {
    Opmask tail_opmask = Opmask(1); // f32 mask for channel math

    void prepare_tail_mask() override {
        if (!c_tail_) return;

        const int mask_f32 = (1 << c_tail_) - 1;

        Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask_f32);
        kmovw(tail_opmask, regw_tmp);
    }

    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vmean, tail_opmask, mean_ptr(offt));
            uni_vmovups_tail(vsqrtvar, tail_opmask, var_ptr(offt));
        } else {
            uni_vmovups(vmean, mean_ptr(offt));
            uni_vmovups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vscale, tail_opmask, scale_ptr(offt));
        } else {
            uni_vmovups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vshift, tail_opmask, shift_ptr(offt));
        } else {
            uni_vmovups(vshift, shift_ptr(offt));
        }
    }

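    // Leaky ReLU: dst = dst > 0 ? dst : alpha * dst, selecting positive lanes
    // with an opmask blend.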
    void process_relu_alpha(Vmm vmm_dst) {
        const Xmm xmm_aux = Xmm(vmm_aux.getIdx());
        vmovq(xmm_aux, reg_relu_alpha);
        vbroadcastss(vmm_aux, xmm_aux);
        vcmpps(kstore_mask, vzero, vmm_dst, _cmp_lt_os);
        vmulps(vmm_aux, vmm_dst, vmm_aux);
        vblendmps(vmm_dst | kstore_mask, vmm_aux, vmm_dst);
    }

    void compute_dst(bool need_tail) override {
        Label c_loop;
        L(c_loop);
        {
            Xmm x = Xmm(0);
            Vmm v = Vmm(0);
            Vmm vscale = Vmm(1);
            Vmm vshift = Vmm(2);
            Vmm vmean = Vmm(3);
            Vmm vsqrtvar = Vmm(4);

            // compute a single pair of vscale and vshift vectors...
            compute_vscaleshift(vscale, vshift, vmean, vsqrtvar, 0, need_tail);

            // ... then run the whole spatial loop with them before moving to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {
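                // tail: gather c_tail_ s8 values byte by byte, then
                // sign-extend to s32; full block: sign-extend 16 s8 values
                // directly from memory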
                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++)
                        vpinsrb(x, x, src_ptr(tl), tl);
                    vpmovsxbd(v, x);
                } else
                    vpmovsxbd(v, src_ptr());

                vcvtdq2ps(v, v);

                uni_vfmadd213ps(v, vscale, vshift);
                if (with_relu_) {
                    if (has_alpha_value_)
                        process_relu_alpha(v);
                    else
                        uni_vmaxps(v, v, vzero);
                }

                vcvtps2dq(v, v);
                if (need_tail) {
                    vpmovsdb(x, v);
                    for (size_t tl = 0; tl < c_tail_; tl++)
                        vpextrb(dst_ptr(tl), x, tl);
                } else
                    vpmovsdb(dst_ptr(), v);

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp looks c_in_xmm_ channels ahead so that the channel
            // tail, if any, is processed by a separate pass
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_s8_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<avx512_core>(pd) {}
};

template <>
struct jit_bnorm_s8_t<avx2> : public jit_bnorm_base_t<avx2> {
    Vmm tail_vmask = Vmm(11);
    Vmm body_vmask = Vmm(12);

    void prepare_tail_mask() override {
        // the tail is always < 16, so process it in two 8-wide halves
        static const uint32_t mask_half_ymm[8]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0, 0, 0, 0};
        mov(reg_tmp, reinterpret_cast<size_t>(&mask_half_ymm[0]));
        vmovups(body_vmask, ptr[reg_tmp]);

        if (!c_tail_) return;

        static const uint32_t mask_f32[14]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                        0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0};

        mov(reg_tmp,
                reinterpret_cast<size_t>(&mask_f32[7 - c_tail_ % simd_w_]));
        vmovups(tail_vmask, ptr[reg_tmp]);
    }

    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vmean, tail_vmask, mean_ptr(offt));
            uni_vmovups_tail(vsqrtvar, tail_vmask, var_ptr(offt));
        } else {
            uni_vmovups(vmean, mean_ptr(offt));
            uni_vmovups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vscale, tail_vmask, scale_ptr(offt));
        } else {
            uni_vmovups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vshift, tail_vmask, shift_ptr(offt));
        } else {
            uni_vmovups(vshift, shift_ptr(offt));
        }
    }

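    // Leaky ReLU via blendv: dst = dst < 0 ? alpha * dst : dst.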
    void process_relu_alpha(Vmm vmm_dst) {
        const Xmm xmm_aux = Xmm(vmm_aux.getIdx());
        uni_vpxor(vmm_mask, vmm_mask, vmm_mask);
        vmovq(xmm_aux, reg_relu_alpha);
        uni_vbroadcastss(vmm_aux, xmm_aux);
        uni_vcmpps(vmm_mask, vmm_dst, vzero, _cmp_lt_os);
        uni_vmulps(vmm_aux, vmm_aux, vmm_dst);
        uni_vblendvps(
                vmm_dst, vmm_dst, vmm_aux, vmm_mask); // swapped aux and dst
    }

    void compute_dst(bool need_tail) override {
        Label c_loop;
        L(c_loop);
        {

            Xmm x0 = Xmm(0);
            Vmm v0 = Vmm(0);
            Xmm x1 = Xmm(1);
            Vmm v1 = Vmm(1);
            Vmm vscale0 = Vmm(2);
            Vmm vshift0 = Vmm(3);
            Vmm vmean0 = Vmm(4);
            Vmm vsqrtvar0 = Vmm(5);
            Vmm vscale1 = Vmm(6);
            Vmm vshift1 = Vmm(7);
            Vmm vmean1 = Vmm(8);
            Vmm vsqrtvar1 = Vmm(9);

            // compute a pair of vscale and vshift vectors, each covering
            // simd_w_ channels...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    (c_tail_ < simd_w_ && need_tail) ? true : false);
            if (!need_tail || c_tail_ > simd_w_) {
                compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1,
                        simd_w_ * sizeof(float), need_tail);
            }

            // ... then run the whole spatial loop with them before moving to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {

                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++) {
                        if (tl < simd_w_) {
                            vpinsrb(x0, x0, src_ptr(tl), tl);
                        } else {
                            vpinsrb(x1, x1, src_ptr(tl), tl - simd_w_);
                        }
                    }
                    vpmovsxbd(v0, x0);
                    vpmovsxbd(v1, x1);
                } else {
                    vpmovsxbd(v0, src_ptr());
                    vpmovsxbd(v1, src_ptr(simd_w_));
                }

                vcvtdq2ps(v0, v0);
                vcvtdq2ps(v1, v1);

                uni_vfmadd213ps(v0, vscale0, vshift0);
                uni_vfmadd213ps(v1, vscale1, vshift1);
                if (with_relu_) {
                    if (has_alpha_value_) {
                        Vmm vmm_dst_0 = Vmm(5);
                        Vmm vmm_dst_1 = Vmm(9);
                        uni_vmovups(vmm_dst_0, v0);
                        uni_vmovups(vmm_dst_1, v1);

                        process_relu_alpha(vmm_dst_0);
                        process_relu_alpha(vmm_dst_1);

                        uni_vmovups(v0, vmm_dst_0);
                        uni_vmovups(v1, vmm_dst_1);
                    } else {
                        uni_vmaxps(v0, v0, vzero);
                        uni_vmaxps(v1, v1, vzero);
                    }
                }

                vcvtps2dq(v0, v0); // BA
                vcvtps2dq(v1, v1); // DC
                vpackssdw(v0, v0, v1); // BA + DC -> DBCA
                vpermq(v0, v0, 0xD8); // DBCA -> DCBA
                vperm2i128(v1, v0, v0, 0x1); // DCBA -> BADC
                vpacksswb(v0, v0, v1); // DCBA + BADC -> badcDCBA

                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++) {
                        vpextrb(dst_ptr(tl), x0, tl);
                    }
                } else {
                    // vpacksswb produces 32 s8 values in the ymm, but the
                    // upper half is garbage, so do a 128-bit masked store
                    vmaskmovps(dst_ptr(), body_vmask, v0);
                }

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp looks c_in_xmm_ channels ahead so that the channel
            // tail, if any, is processed by a separate pass
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_s8_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<avx2>(pd) {}
};

template <>
struct jit_bnorm_s8_t<sse41> : public jit_bnorm_base_t<sse41> {
    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vmean, mean_ptr(offt + tl * sizeof(float)), tl);
                pinsrd(vsqrtvar, var_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vmean, mean_ptr(offt));
            movups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vscale, scale_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vshift, shift_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vshift, shift_ptr(offt));
        }
    }

    void process_relu_alpha(Vmm vmm_dst) {
        const Xmm xmm_aux = Xmm(vmm_aux.getIdx());
        uni_vpxor(vmm_mask, vmm_mask, vmm_mask);
        vmovq(xmm_aux, reg_relu_alpha);
        uni_vbroadcastss(vmm_aux, xmm_aux);
        uni_vcmpps(vmm_mask, vmm_dst, vzero, _cmp_lt_os);
        uni_vmulps(vmm_aux, vmm_aux, vmm_dst);
        uni_vblendvps(
                vmm_dst, vmm_dst, vmm_aux, vmm_mask); // swapped aux and dst
    }

    void compute_dst(bool need_tail) override {
        const size_t copy_range = need_tail ? c_tail_ : c_in_xmm_;
        Label c_loop;
        L(c_loop);
        {

            Vmm v0 = Vmm(0);
            Vmm v1 = Vmm(1);
            Vmm vscale0 = Vmm(2);
            Vmm vshift0 = Vmm(3);
            Vmm vmean0 = Vmm(4);
            Vmm vsqrtvar0 = Vmm(5);
            Vmm vscale1 = Vmm(6);
            Vmm vshift1 = Vmm(7);
            Vmm vmean1 = Vmm(8);
            Vmm vsqrtvar1 = Vmm(9);

            // compute a pair of vscale and vshift vectors, each covering
            // simd_w_ channels...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    (c_tail_ < simd_w_ && need_tail) ? true : false);
            if (!need_tail || c_tail_ > simd_w_) {
                compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1,
                        simd_w_ * sizeof(float), need_tail);
            }

            // ... then run the whole spatial loop with them before moving to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {
                if (need_tail) {
                    for (size_t tl = 0; tl < copy_range; tl++) {
                        if (tl < simd_w_) {
                            pinsrb(v0, src_ptr(tl), tl);
                        } else {
                            pinsrb(v1, src_ptr(tl), (tl - simd_w_));
                        }
                    }
                    pmovsxbd(v0, v0);
                    pmovsxbd(v1, v1);
                } else {
                    pmovsxbd(v0, src_ptr());
                    pmovsxbd(v1, src_ptr(simd_w_));
                }

                cvtdq2ps(v0, v0);
                cvtdq2ps(v1, v1);

                uni_vfmadd213ps(v0, vscale0, vshift0);
                uni_vfmadd213ps(v1, vscale1, vshift1);
                if (with_relu_) {
                    if (has_alpha_value_) {
                        Vmm vmm_dst_0 = Vmm(5);
                        Vmm vmm_dst_1 = Vmm(9);
                        movups(vmm_dst_0, v0);
                        movups(vmm_dst_1, v1);

                        process_relu_alpha(vmm_dst_0);
                        process_relu_alpha(vmm_dst_1);

                        movups(v0, vmm_dst_0);
                        movups(v1, vmm_dst_1);
                    } else {
                        maxps(v0, vzero);
                        maxps(v1, vzero);
                    }
                }

                cvtps2dq(v0, v0);
                cvtps2dq(v1, v1);
                packssdw(v0, v1);
                movups(v1, v0);
                packsswb(v0, v1);
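                // the 8 valid s8 results now sit in the low 8 bytes of v0;
                // the upper bytes are a duplicate copy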

                // A potential perf gain is possible by combining the two
                // halves into a single vector register and using movups
                // instead of byte stores.
                for (size_t tl = 0; tl < copy_range; tl++) {
                    pextrb(dst_ptr(tl), v0, tl);
                }

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp looks c_in_xmm_ channels ahead so that the channel
            // tail, if any, is processed by a separate pass
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_s8_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<sse41>(pd) {}
};

namespace bnorm_s8_impl {

template <cpu_isa_t isa>
struct driver_t : public c_compatible {
    driver_t(const batch_normalization_pd_t *pd) : pd_(pd), ker_(pd_) {}
    ~driver_t() = default;

    // TODO: for problems where thread pieces don't fit L2 cache, add spatial
    // re-balancing that uses fewer pieces.
    void exec(int ithr, int nthr, const data_t *src, data_t *dst,
            const float *scale, const float *shift, const float *mean,
            const float *var) {
        dim_t N = pd_->MB();
        dim_t C = pd_->C();
        dim_t D = pd_->D();
        dim_t H = pd_->H();
        dim_t W = pd_->W();
        dim_t SP = D * H * W;

        jit_uni_bnorm_s8_call_params_t p;

        p.eps = pd_->desc()->batch_norm_epsilon;

        p.scale = scale;
        p.shift = shift;
        p.mean = mean;
        p.var = var;

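        // split the N * SP spatial points evenly across threads; each thread
        // handles a contiguous range of points, each carrying all C channels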
        dim_t work_amount {N * SP}, start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        p.channel_offt_count = C;
        p.spat_offt_count = (end - start) * p.channel_offt_count;
        p.src = src + start * p.channel_offt_count;
        p.dst = dst + start * p.channel_offt_count;

        if (p.spat_offt_count != 0) ker_(&p);
    }

    status_t create_kernel() { return ker_.create_kernel(); }

private:
    const batch_normalization_pd_t *pd_;

    jit_bnorm_s8_t<isa> ker_;
};

} // namespace bnorm_s8_impl

using namespace data_type;
using namespace format_tag;
using namespace utils;

/* fwd */

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::pd_t::init(
        engine_t *engine) {
    auto desired_fmt_tag = (ndims() == 4) ? nhwc : ndhwc;

    bool ok = true && mayiuse(isa) && is_fwd() && !has_zero_dim_memory()
            && one_of(ndims(), 4, 5) && stats_is_src()
            && src_md()->data_type == s8 && check_scale_shift_data_type()
            && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
            && (attr()->has_default_values() || this->with_relu_post_op(false));
    if (!ok) return status::unimplemented;

    // BN+Add+Relu fusion is not currently implemented
    if (fuse_norm_add_relu()) return status::unimplemented;

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_s8_fwd_t<isa>::jit_uni_batch_normalization_s8_fwd_t(
        const pd_t *apd)
    : primitive_t(apd) {}

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(
            bnorm_driver_, new bnorm_s8_impl::driver_t<isa>(pd())));
    return bnorm_driver_->create_kernel();
}

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::execute(
        const exec_ctx_t &ctx) const {

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto scale = CTX_IN_MEM(const float *, DNNL_ARG_SCALE);
    auto shift = CTX_IN_MEM(const float *, DNNL_ARG_SHIFT);
    auto mean = const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN));
    auto var
            = const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE));
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    // run sequentially if the whole problem is no larger than one 4K memory
    // page
    const bool force_sequential
            = pd()->MB() * pd()->C() * pd()->D() * pd()->H() * pd()->W()
            <= 4096;

    parallel(force_sequential ? 1 : 0, [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, dst, scale, shift, mean, var);
    });

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_s8_fwd_t<
        isa>::~jit_uni_batch_normalization_s8_fwd_t() {
    delete bnorm_driver_;
}

/* struct instantiation */
template struct jit_uni_batch_normalization_s8_fwd_t<avx512_core>;
template struct jit_uni_batch_normalization_s8_fwd_t<avx2>;
template struct jit_uni_batch_normalization_s8_fwd_t<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl