1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_thread.hpp" |
21 | #include "common/math_utils.hpp" |
22 | #include "common/nstl.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/x64/jit_generator.hpp" |
27 | |
28 | #include "cpu/x64/jit_uni_batch_normalization_s8.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | namespace x64 { |
34 | |
35 | using namespace Xbyak; |
36 | |
37 | using data_t = int8_t; |
38 | |
39 | struct jit_uni_bnorm_s8_call_params_t { |
40 | // keep int sizes at 8 bytes -- jit code expects this |
41 | size_t channel_offt_count, spat_offt_count; |
42 | float eps; |
43 | const float *scale, *shift, *mean, *var; |
44 | const data_t *src, *dst; |
45 | }; |
46 | |
47 | template <cpu_isa_t isa> |
48 | struct jit_bnorm_base_t : public jit_generator { |
49 | |
50 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_s8_t) |
51 | |
52 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
53 | const AddressFrame &vmmword |
54 | = (isa == sse41) ? xword : ((isa == avx2) ? yword : zword); |
55 | const int vlen = cpu_isa_traits<isa>::vlen; |
56 | |
57 | const batch_normalization_pd_t *pd_; |
58 | |
59 | Reg64 reg_param = abi_param1; |
60 | |
61 | Reg64 reg_scale = rbx; |
62 | Reg64 reg_shift = rdx; |
63 | Reg64 reg_mean = rbp; |
64 | |
65 | Reg64 reg_channel_offt_count = r8; |
66 | Reg64 reg_spat_offt = r9; |
67 | Reg64 reg_spat_offt_count = r10; |
68 | Reg64 reg_tmp = r11; |
69 | Reg64 reg_src = r12; |
70 | Reg64 reg_dst = r13; |
71 | Reg64 reg_var = r14; |
72 | Reg64 reg_channel_offt_1byte = r15; |
73 | Reg64 reg_channel_offt_4byte = rax; |
74 | Reg64 reg_relu_alpha = abi_not_param1; |
75 | |
76 | Opmask kstore_mask = Opmask(1); |
77 | |
78 | Vmm vzero = Vmm(isa == avx512_core ? 29 : 13); |
79 | Xmm xone = Xmm(14); |
80 | Vmm vone = Vmm(isa == avx512_core ? 30 : 14); |
81 | Vmm veps = Vmm(isa == avx512_core ? 31 : 15); |
82 | Vmm vmm_aux = Vmm(isa == avx512_core ? 28 : 10); // shared with 'veps' |
83 | Vmm vmm_mask = Vmm(0); // used for AVX2 and SSE41 |
84 | |
85 | size_t simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float); |
86 | size_t c_in_xmm_ = (isa == sse41) ? 8 : 16; |
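    // simd_w_ is the number of f32 lanes per vector register; c_in_xmm_ is
    // the number of s8 channels processed per compute_dst iteration.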
87 | size_t chan_data_offt_; |
88 | size_t num_c_blocks_; |
89 | size_t c_tail_; |
90 | bool with_relu_; |
91 | bool has_alpha_value_; |
92 | |
93 | void compute_predefined_variables() { |
94 | chan_data_offt_ = pd_->C() * sizeof(float); |
95 | num_c_blocks_ = pd_->C() / c_in_xmm_; |
96 | c_tail_ = pd_->C() % c_in_xmm_; |
97 | with_relu_ = (pd_->with_relu_post_op(false) || pd_->fuse_norm_relu()) |
98 | && pd_->is_fwd(); |
99 | has_alpha_value_ = with_relu_ && pd_->with_relu_post_op(false) |
100 | && pd_->alpha() != 0; |
101 | } |
102 | |
103 | void load_common_params() { |
104 | mov(reg_tmp, float2int(1.0f)); |
105 | uni_vmovq(xone, reg_tmp); |
106 | uni_vbroadcastss(vone, xone); |
107 | |
108 | #define PARAM_OFF(x) offsetof(jit_uni_bnorm_s8_call_params_t, x) |
109 | uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]); |
110 | uni_vpxor(vzero, vzero, vzero); |
111 | |
112 | mov(reg_channel_offt_count, |
113 | ptr[reg_param + PARAM_OFF(channel_offt_count)]); |
114 | mov(reg_spat_offt_count, ptr[reg_param + PARAM_OFF(spat_offt_count)]); |
115 | mov(reg_src, ptr[reg_param + PARAM_OFF(src)]); |
116 | mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); |
117 | mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]); |
118 | mov(reg_scale, ptr[reg_param + PARAM_OFF(scale)]); |
119 | mov(reg_shift, ptr[reg_param + PARAM_OFF(shift)]); |
120 | mov(reg_var, ptr[reg_param + PARAM_OFF(var)]); |
121 | #undef PARAM_OFF |
122 | |
123 | if (has_alpha_value_) { mov(reg_relu_alpha, float2int(pd_->alpha())); } |
124 | } |
125 | |
126 | Address mean_ptr(size_t offt = 0) { |
127 | return vmmword[reg_mean + reg_channel_offt_4byte + offt]; |
128 | } |
129 | |
130 | Address var_ptr(size_t offt = 0) { |
131 | return vmmword[reg_var + reg_channel_offt_4byte + offt]; |
132 | } |
133 | |
134 | Address scale_ptr(size_t offt = 0) { |
135 | return vmmword[reg_scale + reg_channel_offt_4byte + offt |
136 | + 0 * chan_data_offt_]; |
137 | } |
138 | |
139 | Address shift_ptr(size_t offt = 0) { |
140 | return vmmword[reg_shift + reg_channel_offt_4byte + offt |
141 | + 0 * chan_data_offt_]; |
142 | } |
143 | |
144 | Address src_ptr(size_t offt = 0) { |
145 | return vmmword[reg_src + reg_spat_offt + offt]; |
146 | } |
147 | |
148 | Address dst_ptr(size_t offt = 0) { |
149 | return vmmword[reg_dst + reg_spat_offt + offt]; |
150 | } |
151 | |
152 | virtual void prepare_tail_mask() {} |
153 | virtual void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, |
154 | size_t offt, bool need_tail) {} |
155 | virtual void load_scale(const Vmm &vscale, size_t offt, bool need_tail) {} |
156 | virtual void load_shift(const Vmm &vshift, size_t offt, bool need_tail) {} |
157 | virtual void compute_dst(bool need_tail) {} |
158 | |
    // Precomputes vscale and vshift for the subsequent
    // `vdst = vscale * vsrc + vshift` computation.
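    //
    // With the batch-norm statistics folded in, the per-channel factors are:
    //   vscale = scale / sqrt(var + eps)  (scale == 1 when not used)
    //   vshift = shift - mean * vscale    (shift == 0 when not used)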
161 | void compute_vscaleshift(const Vmm &vscale, const Vmm &vshift, |
162 | const Vmm &vmean, const Vmm &vsqrtvar, size_t offt, |
163 | bool need_tail) { |
164 | load_mean_and_var(vmean, vsqrtvar, offt, need_tail); |
165 | uni_vaddps(vsqrtvar, vsqrtvar, veps); |
166 | uni_vsqrtps(vsqrtvar, vsqrtvar); |
167 | |
168 | if (pd_->use_scale() && pd_->use_shift()) { |
169 | load_scale(vscale, offt, need_tail); |
170 | uni_vdivps(vscale, vscale, vsqrtvar); |
171 | load_shift(vshift, offt, need_tail); |
172 | uni_vfnmadd231ps(vshift, vmean, vscale); |
173 | } else if (pd_->use_scale()) { |
174 | load_scale(vscale, offt, need_tail); |
175 | uni_vdivps(vscale, vscale, vsqrtvar); |
176 | uni_vmulps(vmean, vmean, vscale); |
177 | uni_vsubps(vshift, vzero, vmean, vshift); |
178 | } else if (pd_->use_shift()) { |
179 | uni_vdivps(vscale, vone, vsqrtvar, vscale); |
180 | load_shift(vshift, offt, need_tail); |
181 | uni_vfnmadd231ps(vshift, vmean, vscale); |
182 | } else { |
183 | uni_vdivps(vscale, vone, vsqrtvar, vscale); |
184 | uni_vmulps(vmean, vmean, vscale); |
185 | uni_vsubps(vshift, vzero, vmean, vshift); |
186 | } |
187 | } |
188 | |
189 | void forward() { |
190 | xor_(reg_channel_offt_1byte, reg_channel_offt_1byte); |
191 | xor_(reg_channel_offt_4byte, reg_channel_offt_4byte); |
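        // reg_tmp tracks the channel position one chunk ahead; compute_dst
        // compares it against reg_channel_offt_count to stop before the tail.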
192 | mov(reg_tmp, sizeof(data_t) * c_in_xmm_); |
193 | |
194 | if (num_c_blocks_) compute_dst(false); |
195 | if (c_tail_) compute_dst(true); |
196 | } |
197 | |
    // Either this stub or duplication in every derived jit_bnorm_s8_t ctor is
    // required, because the virtual methods used here are not yet defined at
    // the point the base class ctor runs.
201 | void generate() override { |
202 | preamble(); |
203 | compute_predefined_variables(); |
204 | load_common_params(); |
205 | prepare_tail_mask(); |
206 | forward(); |
207 | postamble(); |
208 | } |
209 | |
210 | jit_bnorm_base_t(const batch_normalization_pd_t *pd) |
211 | : jit_generator(jit_name()), pd_(pd) {} |
212 | }; |
213 | |
214 | template <cpu_isa_t isa> |
215 | struct jit_bnorm_s8_t; |
216 | |
217 | template <> |
218 | struct jit_bnorm_s8_t<avx512_core> : public jit_bnorm_base_t<avx512_core> { |
219 | Opmask tail_opmask = Opmask(1); // f32 mask for channel math |
220 | |
221 | void prepare_tail_mask() override { |
222 | if (!c_tail_) return; |
223 | |
224 | const int mask_f32 = (1 << c_tail_) - 1; |
225 | |
226 | Reg32 regw_tmp = reg_tmp.cvt32(); |
227 | mov(regw_tmp, mask_f32); |
228 | kmovw(tail_opmask, regw_tmp); |
229 | } |
230 | |
231 | void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt, |
232 | bool need_tail) override { |
233 | if (need_tail) { |
234 | uni_vmovups_tail(vmean, tail_opmask, mean_ptr(offt)); |
235 | uni_vmovups_tail(vsqrtvar, tail_opmask, var_ptr(offt)); |
236 | } else { |
237 | uni_vmovups(vmean, mean_ptr(offt)); |
238 | uni_vmovups(vsqrtvar, var_ptr(offt)); |
239 | } |
240 | } |
241 | |
242 | void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override { |
243 | if (need_tail) { |
244 | uni_vmovups_tail(vscale, tail_opmask, scale_ptr(offt)); |
245 | } else { |
246 | uni_vmovups(vscale, scale_ptr(offt)); |
247 | } |
248 | } |
249 | |
250 | void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override { |
251 | if (need_tail) { |
252 | uni_vmovups_tail(vshift, tail_opmask, shift_ptr(offt)); |
253 | } else { |
254 | uni_vmovups(vshift, shift_ptr(offt)); |
255 | } |
256 | } |
257 | |
258 | void process_relu_alpha(Vmm vmm_dst) { |
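        // Leaky ReLU with the post-op alpha: dst = dst > 0 ? dst : alpha * dst.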
259 | const Xmm xmm_aux = Xmm(vmm_aux.getIdx()); |
260 | vmovq(xmm_aux, reg_relu_alpha); |
261 | vbroadcastss(vmm_aux, xmm_aux); |
262 | vcmpps(kstore_mask, vzero, vmm_dst, _cmp_lt_os); |
263 | vmulps(vmm_aux, vmm_dst, vmm_aux); |
264 | vblendmps(vmm_dst | kstore_mask, vmm_aux, vmm_dst); |
265 | } |
266 | |
267 | void compute_dst(bool need_tail) override { |
268 | Label c_loop; |
269 | L(c_loop); |
270 | { |
271 | Xmm x = Xmm(0); |
272 | Vmm v = Vmm(0); |
273 | Vmm vscale = Vmm(1); |
274 | Vmm vshift = Vmm(2); |
275 | Vmm vmean = Vmm(3); |
276 | Vmm vsqrtvar = Vmm(4); |
277 | |
            // compute the vscale and vshift vectors...
279 | compute_vscaleshift(vscale, vshift, vmean, vsqrtvar, 0, need_tail); |
280 | |
            // ... then run the whole spatial loop with them and move on to
            // the next channel chunk
283 | mov(reg_spat_offt, reg_channel_offt_1byte); |
284 | Label mb_sp_loop; |
285 | L(mb_sp_loop); |
286 | { |
287 | if (need_tail) { |
288 | for (size_t tl = 0; tl < c_tail_; tl++) |
289 | vpinsrb(x, x, src_ptr(tl), tl); |
290 | vpmovsxbd(v, x); |
291 | } else |
292 | vpmovsxbd(v, src_ptr()); |
293 | |
294 | vcvtdq2ps(v, v); |
295 | |
296 | uni_vfmadd213ps(v, vscale, vshift); |
297 | if (with_relu_) { |
298 | if (has_alpha_value_) |
299 | process_relu_alpha(v); |
300 | else |
301 | uni_vmaxps(v, v, vzero); |
302 | } |
303 | |
304 | vcvtps2dq(v, v); |
305 | if (need_tail) { |
306 | vpmovsdb(x, v); |
307 | for (size_t tl = 0; tl < c_tail_; tl++) |
308 | vpextrb(dst_ptr(tl), x, tl); |
309 | } else |
310 | vpmovsdb(dst_ptr(), v); |
311 | |
312 | add(reg_spat_offt, reg_channel_offt_count); |
313 | cmp(reg_spat_offt, reg_spat_offt_count); |
314 | jl(mb_sp_loop); |
315 | } |
316 | |
            // reg_tmp looks c_in_xmm_ channels ahead so that the remaining
            // tail, if any, is processed separately
318 | add(reg_tmp, sizeof(data_t) * c_in_xmm_); |
319 | add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_); |
320 | add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_); |
321 | cmp(reg_tmp, reg_channel_offt_count); |
322 | jle(c_loop); |
323 | } |
324 | } |
325 | |
326 | jit_bnorm_s8_t(const batch_normalization_pd_t *pd) |
327 | : jit_bnorm_base_t<avx512_core>(pd) {} |
328 | }; |
329 | |
330 | template <> |
331 | struct jit_bnorm_s8_t<avx2> : public jit_bnorm_base_t<avx2> { |
332 | Vmm tail_vmask = Vmm(11); |
333 | Vmm body_vmask = Vmm(12); |
334 | |
335 | void prepare_tail_mask() override { |
        // the tail is always < 16, so process it in two parts
337 | static const uint32_t mask_half_ymm[8] |
338 | = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0, 0, 0, 0}; |
339 | mov(reg_tmp, reinterpret_cast<size_t>(&mask_half_ymm[0])); |
340 | vmovups(body_vmask, ptr[reg_tmp]); |
341 | |
342 | if (!c_tail_) return; |
343 | |
344 | static const uint32_t mask_f32[14] |
345 | = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, |
346 | 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0}; |
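        // Reading 8 dwords starting at &mask_f32[7 - c_tail_ % simd_w_] gives
        // a mask with exactly (c_tail_ % simd_w_) leading all-ones lanes.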
347 | |
348 | mov(reg_tmp, |
349 | reinterpret_cast<size_t>(&mask_f32[7 - c_tail_ % simd_w_])); |
350 | vmovups(tail_vmask, ptr[reg_tmp]); |
351 | } |
352 | |
353 | void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt, |
354 | bool need_tail) override { |
355 | if (need_tail) { |
356 | uni_vmovups_tail(vmean, tail_vmask, mean_ptr(offt)); |
357 | uni_vmovups_tail(vsqrtvar, tail_vmask, var_ptr(offt)); |
358 | } else { |
359 | uni_vmovups(vmean, mean_ptr(offt)); |
360 | uni_vmovups(vsqrtvar, var_ptr(offt)); |
361 | } |
362 | } |
363 | |
364 | void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override { |
365 | if (need_tail) { |
366 | uni_vmovups_tail(vscale, tail_vmask, scale_ptr(offt)); |
367 | } else { |
368 | uni_vmovups(vscale, scale_ptr(offt)); |
369 | } |
370 | } |
371 | |
372 | void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override { |
373 | if (need_tail) { |
374 | uni_vmovups_tail(vshift, tail_vmask, shift_ptr(offt)); |
375 | } else { |
376 | uni_vmovups(vshift, shift_ptr(offt)); |
377 | } |
378 | } |
379 | |
380 | void process_relu_alpha(Vmm vmm_dst) { |
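        // Same leaky ReLU as the avx512 path, implemented with vblendvps:
        // dst = dst >= 0 ? dst : alpha * dst.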
381 | const Xmm xmm_aux = Xmm(vmm_aux.getIdx()); |
382 | uni_vpxor(vmm_mask, vmm_mask, vmm_mask); |
383 | vmovq(xmm_aux, reg_relu_alpha); |
384 | uni_vbroadcastss(vmm_aux, xmm_aux); |
385 | uni_vcmpps(vmm_mask, vmm_dst, vzero, _cmp_lt_os); |
386 | uni_vmulps(vmm_aux, vmm_aux, vmm_dst); |
387 | uni_vblendvps( |
                vmm_dst, vmm_dst, vmm_aux, vmm_mask); // swapped aux and dst
389 | } |
390 | |
391 | void compute_dst(bool need_tail) override { |
392 | Label c_loop; |
393 | L(c_loop); |
394 | { |
395 | |
396 | Xmm x0 = Xmm(0); |
397 | Vmm v0 = Vmm(0); |
398 | Xmm x1 = Xmm(1); |
399 | Vmm v1 = Vmm(1); |
400 | Vmm vscale0 = Vmm(2); |
401 | Vmm vshift0 = Vmm(3); |
402 | Vmm vmean0 = Vmm(4); |
403 | Vmm vsqrtvar0 = Vmm(5); |
404 | Vmm vscale1 = Vmm(6); |
405 | Vmm vshift1 = Vmm(7); |
406 | Vmm vmean1 = Vmm(8); |
407 | Vmm vsqrtvar1 = Vmm(9); |
408 | |
            // compute a pair of vscale and vshift vectors, each covering 8
            // channels...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    c_tail_ < simd_w_ && need_tail);
412 | if (!need_tail || c_tail_ > simd_w_) { |
413 | compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1, |
414 | simd_w_ * sizeof(float), need_tail); |
415 | } |
416 | |
            // ... then run the whole spatial loop with them and move on to
            // the next channel chunk
419 | mov(reg_spat_offt, reg_channel_offt_1byte); |
420 | Label mb_sp_loop; |
421 | L(mb_sp_loop); |
422 | { |
423 | |
424 | if (need_tail) { |
425 | for (size_t tl = 0; tl < c_tail_; tl++) { |
426 | if (tl < simd_w_) { |
427 | vpinsrb(x0, x0, src_ptr(tl), tl); |
428 | } else { |
429 | vpinsrb(x1, x1, src_ptr(tl), tl - simd_w_); |
430 | } |
431 | } |
432 | vpmovsxbd(v0, x0); |
433 | vpmovsxbd(v1, x1); |
434 | } else { |
435 | vpmovsxbd(v0, src_ptr()); |
436 | vpmovsxbd(v1, src_ptr(simd_w_)); |
437 | } |
438 | |
439 | vcvtdq2ps(v0, v0); |
440 | vcvtdq2ps(v1, v1); |
441 | |
442 | uni_vfmadd213ps(v0, vscale0, vshift0); |
443 | uni_vfmadd213ps(v1, vscale1, vshift1); |
444 | if (with_relu_) { |
445 | if (has_alpha_value_) { |
446 | Vmm vmm_dst_0 = Vmm(5); |
447 | Vmm vmm_dst_1 = Vmm(9); |
448 | uni_vmovups(vmm_dst_0, v0); |
449 | uni_vmovups(vmm_dst_1, v1); |
450 | |
451 | process_relu_alpha(vmm_dst_0); |
452 | process_relu_alpha(vmm_dst_1); |
453 | |
454 | uni_vmovups(v0, vmm_dst_0); |
455 | uni_vmovups(v1, vmm_dst_1); |
456 | } else { |
457 | uni_vmaxps(v0, v0, vzero); |
458 | uni_vmaxps(v1, v1, vzero); |
459 | } |
460 | } |
461 | |
462 | vcvtps2dq(v0, v0); // BA |
463 | vcvtps2dq(v1, v1); // DC |
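                // pack the two ymm of s32 down to 16 s8 values held in the
                // low 128 bits of v0: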
464 | vpackssdw(v0, v0, v1); // BA + DC -> DBCA |
465 | vpermq(v0, v0, 0xD8); // DBCA -> DCBA |
466 | vperm2i128(v1, v0, v0, 0x1); // DCBA -> BADC |
467 | vpacksswb(v0, v0, v1); // DCBA + BADC -> badcDCBA |
468 | |
469 | if (need_tail) { |
470 | for (size_t tl = 0; tl < c_tail_; tl++) { |
471 | vpextrb(dst_ptr(tl), x0, tl); |
472 | } |
473 | } else { |
                    // vpacksswb produces 32 s8 values in the ymm and the top
                    // half of them is garbage, so do a 128-bit masked store
476 | vmaskmovps(dst_ptr(), body_vmask, v0); |
477 | } |
478 | |
479 | add(reg_spat_offt, reg_channel_offt_count); |
480 | cmp(reg_spat_offt, reg_spat_offt_count); |
481 | jl(mb_sp_loop); |
482 | } |
483 | |
            // reg_tmp looks c_in_xmm_ channels ahead so that the remaining
            // tail, if any, is processed separately
485 | add(reg_tmp, sizeof(data_t) * c_in_xmm_); |
486 | add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_); |
487 | add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_); |
488 | cmp(reg_tmp, reg_channel_offt_count); |
489 | jle(c_loop); |
490 | } |
491 | } |
492 | |
493 | jit_bnorm_s8_t(const batch_normalization_pd_t *pd) |
494 | : jit_bnorm_base_t<avx2>(pd) {} |
495 | }; |
496 | |
497 | template <> |
498 | struct jit_bnorm_s8_t<sse41> : public jit_bnorm_base_t<sse41> { |
499 | void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt, |
500 | bool need_tail) override { |
501 | if (need_tail) { |
502 | for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) { |
503 | pinsrd(vmean, mean_ptr(offt + tl * sizeof(float)), tl); |
504 | pinsrd(vsqrtvar, var_ptr(offt + tl * sizeof(float)), tl); |
505 | } |
506 | } else { |
507 | movups(vmean, mean_ptr(offt)); |
508 | movups(vsqrtvar, var_ptr(offt)); |
509 | } |
510 | } |
511 | |
512 | void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override { |
513 | if (need_tail) { |
514 | for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) { |
515 | pinsrd(vscale, scale_ptr(offt + tl * sizeof(float)), tl); |
516 | } |
517 | } else { |
518 | movups(vscale, scale_ptr(offt)); |
519 | } |
520 | } |
521 | |
522 | void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override { |
523 | if (need_tail) { |
524 | for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) { |
525 | pinsrd(vshift, shift_ptr(offt + tl * sizeof(float)), tl); |
526 | } |
527 | } else { |
528 | movups(vshift, shift_ptr(offt)); |
529 | } |
530 | } |
531 | |
532 | void process_relu_alpha(Vmm vmm_dst) { |
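        // Same leaky ReLU as the avx2 path: dst = dst >= 0 ? dst : alpha * dst.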
533 | const Xmm xmm_aux = Xmm(vmm_aux.getIdx()); |
534 | uni_vpxor(vmm_mask, vmm_mask, vmm_mask); |
535 | vmovq(xmm_aux, reg_relu_alpha); |
536 | uni_vbroadcastss(vmm_aux, xmm_aux); |
537 | uni_vcmpps(vmm_mask, vmm_dst, vzero, _cmp_lt_os); |
538 | uni_vmulps(vmm_aux, vmm_aux, vmm_dst); |
539 | uni_vblendvps( |
                vmm_dst, vmm_dst, vmm_aux, vmm_mask); // swapped aux and dst
541 | } |
542 | |
543 | void compute_dst(bool need_tail) override { |
544 | const size_t copy_range = need_tail ? c_tail_ : c_in_xmm_; |
545 | Label c_loop; |
546 | L(c_loop); |
547 | { |
548 | |
549 | Vmm v0 = Vmm(0); |
550 | Vmm v1 = Vmm(1); |
551 | Vmm vscale0 = Vmm(2); |
552 | Vmm vshift0 = Vmm(3); |
553 | Vmm vmean0 = Vmm(4); |
554 | Vmm vsqrtvar0 = Vmm(5); |
555 | Vmm vscale1 = Vmm(6); |
556 | Vmm vshift1 = Vmm(7); |
557 | Vmm vmean1 = Vmm(8); |
558 | Vmm vsqrtvar1 = Vmm(9); |
559 | |
            // compute a pair of vscale and vshift vectors, each covering
            // simd_w_ channels...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    c_tail_ < simd_w_ && need_tail);
563 | if (!need_tail || c_tail_ > simd_w_) { |
564 | compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1, |
565 | simd_w_ * sizeof(float), need_tail); |
566 | } |
567 | |
            // ... then run the whole spatial loop with them and move on to
            // the next channel chunk
570 | mov(reg_spat_offt, reg_channel_offt_1byte); |
571 | Label mb_sp_loop; |
572 | L(mb_sp_loop); |
573 | { |
574 | if (need_tail) { |
575 | for (size_t tl = 0; tl < copy_range; tl++) { |
576 | if (tl < simd_w_) { |
577 | pinsrb(v0, src_ptr(tl), tl); |
578 | } else { |
579 | pinsrb(v1, src_ptr(tl), (tl - simd_w_)); |
580 | } |
581 | } |
582 | pmovsxbd(v0, v0); |
583 | pmovsxbd(v1, v1); |
584 | } else { |
585 | pmovsxbd(v0, src_ptr()); |
586 | pmovsxbd(v1, src_ptr(simd_w_)); |
587 | } |
588 | |
589 | cvtdq2ps(v0, v0); |
590 | cvtdq2ps(v1, v1); |
591 | |
592 | uni_vfmadd213ps(v0, vscale0, vshift0); |
593 | uni_vfmadd213ps(v1, vscale1, vshift1); |
594 | if (with_relu_) { |
595 | if (has_alpha_value_) { |
596 | Vmm vmm_dst_0 = Vmm(5); |
597 | Vmm vmm_dst_1 = Vmm(9); |
598 | movups(vmm_dst_0, v0); |
599 | movups(vmm_dst_1, v1); |
600 | |
601 | process_relu_alpha(vmm_dst_0); |
602 | process_relu_alpha(vmm_dst_1); |
603 | |
604 | movups(v0, vmm_dst_0); |
605 | movups(v1, vmm_dst_1); |
606 | } else { |
607 | maxps(v0, vzero); |
608 | maxps(v1, vzero); |
609 | } |
610 | } |
611 | |
612 | cvtps2dq(v0, v0); |
613 | cvtps2dq(v1, v1); |
614 | packssdw(v0, v1); |
615 | movups(v1, v0); |
616 | packsswb(v0, v1); |
617 | |
                // A potential perf gain is possible by combining the two
                // halves into a single vector register and using movups
                // instead of byte stores.
621 | for (size_t tl = 0; tl < copy_range; tl++) { |
622 | pextrb(dst_ptr(tl), v0, tl); |
623 | } |
624 | |
625 | add(reg_spat_offt, reg_channel_offt_count); |
626 | cmp(reg_spat_offt, reg_spat_offt_count); |
627 | jl(mb_sp_loop); |
628 | } |
629 | |
            // reg_tmp looks c_in_xmm_ channels ahead so that the remaining
            // tail, if any, is processed separately
631 | add(reg_tmp, sizeof(data_t) * c_in_xmm_); |
632 | add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_); |
633 | add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_); |
634 | cmp(reg_tmp, reg_channel_offt_count); |
635 | jle(c_loop); |
636 | } |
637 | } |
638 | |
639 | jit_bnorm_s8_t(const batch_normalization_pd_t *pd) |
640 | : jit_bnorm_base_t<sse41>(pd) {} |
641 | }; |
642 | |
643 | namespace bnorm_s8_impl { |
644 | |
645 | template <cpu_isa_t isa> |
646 | struct driver_t : public c_compatible { |
647 | driver_t(const batch_normalization_pd_t *pd) : pd_(pd), ker_(pd_) {} |
648 | ~driver_t() = default; |
649 | |
    // TODO: for problems where the per-thread pieces don't fit the L2 cache,
    // add spatial re-balancing using fewer pieces.
652 | void exec(int ithr, int nthr, const data_t *src, data_t *dst, |
653 | const float *scale, const float *shift, const float *mean, |
654 | const float *var) { |
655 | dim_t N = pd_->MB(); |
656 | dim_t C = pd_->C(); |
657 | dim_t D = pd_->D(); |
658 | dim_t H = pd_->H(); |
659 | dim_t W = pd_->W(); |
660 | dim_t SP = D * H * W; |
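        // The layout is nhwc/ndhwc: the tensor is N * SP rows of C contiguous
        // s8 channels, and those rows are split evenly across threads.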
661 | |
662 | jit_uni_bnorm_s8_call_params_t p; |
663 | |
664 | p.eps = pd_->desc()->batch_norm_epsilon; |
665 | |
666 | p.scale = scale; |
667 | p.shift = shift; |
668 | p.mean = mean; |
669 | p.var = var; |
670 | |
671 | dim_t work_amount {N * SP}, start {0}, end {0}; |
672 | balance211(work_amount, nthr, ithr, start, end); |
673 | |
674 | p.channel_offt_count = C; |
675 | p.spat_offt_count = (end - start) * p.channel_offt_count; |
676 | p.src = src + start * p.channel_offt_count; |
677 | p.dst = dst + start * p.channel_offt_count; |
678 | |
679 | if (p.spat_offt_count != 0) ker_(&p); |
680 | } |
681 | |
682 | status_t create_kernel() { return ker_.create_kernel(); } |
683 | |
684 | private: |
685 | const batch_normalization_pd_t *pd_; |
686 | |
687 | jit_bnorm_s8_t<isa> ker_; |
688 | }; |
689 | |
690 | } // namespace bnorm_s8_impl |
691 | |
692 | using namespace data_type; |
693 | using namespace format_tag; |
694 | using namespace utils; |
695 | |
696 | /* fwd */ |
697 | |
698 | template <cpu_isa_t isa> |
699 | status_t jit_uni_batch_normalization_s8_fwd_t<isa>::pd_t::init( |
700 | engine_t *engine) { |
701 | auto desired_fmt_tag = (ndims() == 4) ? nhwc : ndhwc; |
702 | |
703 | bool ok = true && mayiuse(isa) && is_fwd() && !has_zero_dim_memory() |
704 | && one_of(ndims(), 4, 5) && stats_is_src() |
705 | && src_md()->data_type == s8 && check_scale_shift_data_type() |
706 | && memory_desc_matches_tag(*src_md(), desired_fmt_tag) |
707 | && (attr()->has_default_values() || this->with_relu_post_op(false)); |
708 | if (!ok) return status::unimplemented; |
709 | |
710 | // BN+Add+Relu fusion is not currently implemented |
711 | if (fuse_norm_add_relu()) return status::unimplemented; |
712 | |
713 | return status::success; |
714 | } |
715 | |
716 | template <cpu_isa_t isa> |
717 | jit_uni_batch_normalization_s8_fwd_t<isa>::jit_uni_batch_normalization_s8_fwd_t( |
718 | const pd_t *apd) |
719 | : primitive_t(apd) {} |
720 | |
721 | template <cpu_isa_t isa> |
722 | status_t jit_uni_batch_normalization_s8_fwd_t<isa>::init(engine_t *engine) { |
723 | CHECK(safe_ptr_assign( |
724 | bnorm_driver_, new bnorm_s8_impl::driver_t<isa>(pd()))); |
725 | return bnorm_driver_->create_kernel(); |
726 | } |
727 | |
728 | template <cpu_isa_t isa> |
729 | status_t jit_uni_batch_normalization_s8_fwd_t<isa>::execute( |
730 | const exec_ctx_t &ctx) const { |
731 | |
732 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
733 | auto scale = CTX_IN_MEM(const float *, DNNL_ARG_SCALE); |
734 | auto shift = CTX_IN_MEM(const float *, DNNL_ARG_SHIFT); |
735 | auto mean = const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN)); |
736 | auto var |
737 | = const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE)); |
738 | auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); |
739 | |
    // run sequentially if the problem is smaller than one 4K memory page
741 | const bool force_sequential |
742 | = pd()->MB() * pd()->C() * pd()->D() * pd()->H() * pd()->W() |
743 | <= 4096; |
744 | |
745 | parallel(force_sequential ? 1 : 0, [&](const int ithr, const int nthr) { |
746 | bnorm_driver_->exec(ithr, nthr, src, dst, scale, shift, mean, var); |
747 | }); |
748 | |
749 | return status::success; |
750 | } |
751 | |
752 | template <cpu_isa_t isa> |
753 | jit_uni_batch_normalization_s8_fwd_t< |
754 | isa>::~jit_uni_batch_normalization_s8_fwd_t() { |
755 | delete bnorm_driver_; |
756 | } |
757 | |
758 | /* struct instantiation */ |
759 | template struct jit_uni_batch_normalization_s8_fwd_t<avx512_core>; |
760 | template struct jit_uni_batch_normalization_s8_fwd_t<avx2>; |
761 | template struct jit_uni_batch_normalization_s8_fwd_t<sse41>; |
762 | |
763 | } // namespace x64 |
764 | } // namespace cpu |
765 | } // namespace impl |
766 | } // namespace dnnl |
767 | |