/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/type_helpers.hpp"

#include "jit_uni_reduction_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;
#define GET_OFF(field) offsetof(jit_reduction_call_s, field)

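// Broadcasting strategies accepted by the binary post-op injector for this
// kernel: scalar, per_oc, per_oc_spatial and no_broadcast.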
static const bcast_set_t &get_supported_postops_bcast_strategies() {
    static const bcast_set_t supported_strategies
            = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::per_oc_spatial,
                    broadcasting_strategy_t::no_broadcast};
    return supported_strategies;
}

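// The constructor wires up the io helpers: io_load_ reads simd_w_ elements of
// src_type per iteration (with a tail of reduce_size % simd_w_ elements),
// io_store_ writes the reduced output in dst_type using the store-tail mask.
// Both share the bf16-emulation and f32-saturation registers.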
template <cpu_isa_t isa, typename Vmm>
jit_uni_reduction_kernel_t<isa, Vmm>::jit_uni_reduction_kernel_t(
        const jit_reduction_conf_t &conf, const memory_desc_t *dst_md)
    : jit_uni_reduction_kernel_base_t(conf)
    , load_tail_size_(conf.reduce_size % simd_w_)
    , io_load_(this, isa, conf_.src_type, {false},
              io::io_tail_conf_t {simd_w_, load_tail_size_, k_tail_load_mask_,
                      vmm_tail_load_mask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {vmm_bf16_emu_1_, vmm_bf16_emu_2_,
                      vmm_bf16_emu_3_, reg_tmp_, vmm_bf16_emu_4_},
              io::io_saturation_conf_t {vmm_zero_saturation_.getIdx(),
                      vmm_saturation_ubound_.getIdx(), reg_tmp_})
    , io_store_(this, isa, conf_.dst_type, {false},
              io::io_tail_conf_t {simd_w_, store_tail_size_, k_tail_store_mask_,
                      vmm_tail_store_mask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {vmm_bf16_emu_1_, vmm_bf16_emu_2_,
                      vmm_bf16_emu_3_, reg_tmp_, vmm_bf16_emu_4_},
              io::io_saturation_conf_t {vmm_zero_saturation_.getIdx(),
                      vmm_saturation_ubound_.getIdx(), reg_tmp_}) {
    init_compute_op();
    init_compute_scalar_op();
    if (conf_.with_postops) init_post_ops_injector(dst_md);
}

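// Initialize the accumulator with the neutral value of the reduction:
// -FLT_MAX for max, FLT_MAX for min, 0 for sum/mean and 1 for mul. The value
// is materialized through a GPR and broadcast across the whole vmm.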
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::init_acc() {
    using namespace alg_kind;
    using namespace nstl;

    const Xmm xmm_tmp_(vmm_tmp1_.getIdx());
    float starting_val = 0;

    switch (conf_.alg) {
        case reduction_max:
            starting_val = numeric_limits<float>::lowest();
            break;
        case reduction_min: starting_val = numeric_limits<float>::max(); break;
        case reduction_mean:
        case reduction_sum: starting_val = 0.f; break;
        case reduction_mul: starting_val = 1.f; break;
        default: assert(!"unknown alg");
    }

    mov(reg_tmp_.cvt32(), float2int(starting_val));
    uni_vmovd(xmm_tmp_, reg_tmp_.cvt32());
    uni_vbroadcastss(vmm_acc_, xmm_tmp_);
}

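// Select the packed accumulation instruction (max/min/add/mul on full
// vectors) that reduce() applies to each loaded chunk.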
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::init_compute_op() {
    using namespace alg_kind;
    switch (conf_.alg) {
        case reduction_max:
            compute_op_ = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                uni_vmaxps(acc, acc, to_acc);
            };
            break;
        case reduction_min:
            compute_op_ = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                uni_vminps(acc, acc, to_acc);
            };
            break;
        case reduction_mean:
        case reduction_sum:
            compute_op_ = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                uni_vaddps(acc, acc, to_acc);
            };
            break;
        case reduction_mul:
            compute_op_ = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                uni_vmulps(acc, acc, to_acc);
            };
            break;
        default: assert(!"unsupported alg.");
    }
}

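// Scalar (ss) counterpart of compute_op_, used when combining single lanes
// during the horizontal reduction and for the tail elements.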
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::init_compute_scalar_op() {
    using namespace alg_kind;

    switch (conf_.alg) {
        case reduction_max:
            compute_scalar_op_
                    = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                          maxss(acc, to_acc);
                      };
            break;
        case reduction_min:
            compute_scalar_op_
                    = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                          minss(acc, to_acc);
                      };
            break;
        case reduction_mean:
        case reduction_sum:
            compute_scalar_op_
                    = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                          addss(acc, to_acc);
                      };
            break;
        case reduction_mul:
            compute_scalar_op_
                    = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                          mulss(acc, to_acc);
                      };
            break;
        default: assert(!"unsupported alg.");
    }
}

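// Build the post-ops injector: eltwise injector state plus the static binary
// injector parameters (rhs helper vmm index, helper GPRs, offsets of the rhs
// argument vector and original dst pointer in jit_reduction_call_s, and the
// store-tail mask used for partially filled outputs).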
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::init_post_ops_injector(
        const memory_desc_t *dst_md) {
    const memory_desc_wrapper dst_d(*dst_md);

    const eltwise_injector::static_params_t esp(true /*save_state*/,
            reg_po_injector_helper_1_, elt_inj_opmask_, true /*is_fwd*/,
            false /*use_dst*/);
    const binary_injector::rhs_arg_static_params_t rhs_arg_bsp {
            static_cast<size_t>(rhs_dt_helper_vmm_.getIdx()),
            reg_po_injector_helper_1_, reg_po_injector_helper_2_,
            reg_po_injector_helper_3_, true /*preserve gpr*/,
            true /*preserve vmm*/, GET_OFF(post_ops_binary_rhs_arg_vec),
            GET_OFF(dst_orig), dst_d, store_tail_size_, k_tail_store_mask_,
            false /*use_exact_tail_scalar_bcast*/};
    const binary_injector::static_params_t bsp(
            reg_param_, get_supported_postops_bcast_strategies(), rhs_arg_bsp);

    postops_injector_ = utils::make_unique<
            injector::jit_uni_postops_injector_t<inject_isa_, Vmm>>(
            this, conf_.post_ops, bsp, esp);
}

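// Horizontal reduction helpers: fold the upper half of a zmm (ymm) into its
// lower ymm (xmm) half with the packed compute_op_.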
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce_zmm_to_ymm(
        const Xmm &acc, const Xmm &tmp) {
    const Zmm zmm_acc(acc.getIdx());
    const Ymm ymm_acc(acc.getIdx());
    const Ymm ymm_to_acc(tmp.getIdx());
    vextractf64x4(ymm_to_acc, zmm_acc, 1);
    compute_op_(ymm_acc, ymm_to_acc);
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce_ymm_to_xmm(
        const Xmm &acc, const Xmm &tmp) {
    const Ymm ymm_acc(acc.getIdx());
    const Xmm xmm_acc(acc.getIdx());
    const Xmm xmm_to_acc(tmp.getIdx());
    vextractf128(xmm_to_acc, ymm_acc, 1);
    compute_op_(xmm_acc, xmm_to_acc);
}

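// Reduce up to 4 f32 lanes of an xmm into lane 0. Each insertps immediate
// copies source lane i+1 (imm bits 7:6) into lane 0 of the helper register
// (imm bits 5:4) and zeroes its remaining lanes (zero mask 0b1110), so the
// scalar op only ever combines lane-0 values.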
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce_xmm_to_scalar(const Xmm &acc,
        const Xmm &tmp, const std::size_t number_of_values_to_reduce) {
    assert(number_of_values_to_reduce <= number_of_f32_in_xmm_);

    const Xmm xmm_acc(acc.getIdx());
    const Xmm xmm_to_acc(tmp.getIdx());

    static constexpr int number_of_f32_to_move = number_of_f32_in_xmm_ - 1;
    static constexpr uint8_t insertps_configuration[number_of_f32_to_move]
            = {0b01001110, 0b10001110, 0b11001110};

    for (std::size_t i = 0; i < number_of_values_to_reduce - 1; i++) {
        insertps(xmm_to_acc, xmm_acc, insertps_configuration[i]);
        compute_scalar_op_(xmm_acc, xmm_to_acc);
    }
}

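// Reduce up to 8 f32 values held in a ymm to a scalar in lane 0 of its xmm
// part; values beyond the lower xmm are first extracted and reduced
// separately, then combined with the scalar op.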
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce_ymm_to_scalar(
        const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, const Xbyak::Xmm &tmp2,
        const std::size_t number_of_values_to_reduce) {
    assert(number_of_values_to_reduce <= number_of_f32_in_ymm_);

    const Ymm ymm_acc(acc.getIdx());
    const Xmm xmm_acc(acc.getIdx());
    const Xmm xmm_tmp(tmp1.getIdx());
    const Xmm xmm_acc_upper_half(tmp2.getIdx());

    if (number_of_values_to_reduce == number_of_f32_in_ymm_) {
        reduce_ymm_to_xmm(ymm_acc, xmm_tmp);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
    } else if (number_of_values_to_reduce > number_of_f32_in_xmm_) {
        vextractf128(xmm_acc_upper_half, ymm_acc, 1);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
        reduce_xmm_to_scalar(xmm_acc_upper_half, xmm_tmp,
                number_of_values_to_reduce - number_of_f32_in_xmm_);
        compute_scalar_op_(xmm_acc, xmm_acc_upper_half);
    } else if (number_of_values_to_reduce <= number_of_f32_in_xmm_) {
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp, number_of_values_to_reduce);
    }
}

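// Reduce up to a full zmm's worth of f32 values held in the accumulator to a
// scalar in lane 0, dispatching on how many lanes actually carry data.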
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce_vmm_to_scalar(
        const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, const Xbyak::Xmm &tmp2,
        const Xbyak::Xmm &tmp3, const std::size_t number_of_values_to_reduce) {
    assert(number_of_values_to_reduce <= number_of_f32_in_zmm_);

    const Zmm zmm_acc(acc.getIdx());
    const Ymm ymm_acc(acc.getIdx());
    const Xmm xmm_acc(acc.getIdx());
    const Ymm ymm_acc_upper_half(tmp1.getIdx());
    const Xmm xmm_acc_upper_half(tmp1.getIdx());
    const Ymm ymm_tmp(tmp2.getIdx());
    const Xmm xmm_tmp1(tmp2.getIdx());
    const Xmm xmm_tmp2(tmp3.getIdx());

    if (number_of_values_to_reduce == number_of_f32_in_zmm_) {
        reduce_zmm_to_ymm(zmm_acc, ymm_tmp);
        reduce_ymm_to_xmm(ymm_acc, xmm_tmp1);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp1);
    } else if (number_of_values_to_reduce > number_of_f32_in_ymm_) {
        vextractf64x4(ymm_acc_upper_half, zmm_acc, 1);
        reduce_ymm_to_scalar(ymm_acc, xmm_tmp1, xmm_tmp2);
        reduce_ymm_to_scalar(ymm_acc_upper_half, xmm_tmp1, xmm_tmp2,
                number_of_values_to_reduce - number_of_f32_in_ymm_);
        compute_scalar_op_(xmm_acc, xmm_acc_upper_half);
    } else if (number_of_values_to_reduce <= number_of_f32_in_ymm_) {
        reduce_ymm_to_scalar(
                ymm_acc, xmm_tmp1, xmm_tmp2, number_of_values_to_reduce);
    }
}

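// Main reduction loop: consume simd_w_ elements of src per iteration with the
// packed op, then load the remaining load_tail_size_ elements with the tail
// mask, collapse them to a scalar and fold that scalar into the accumulator.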
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce() {
    Label label_work_begin, label_work_end;

    L(label_work_begin);
    {
        cmp(reg_work_, 0);
        je(label_work_end);
        io_load_.load(ptr[reg_src_], vmm_tmp1_, false);
        compute_op_(vmm_acc_, vmm_tmp1_);

        add(reg_src_, simd_w_ * conf_.src_dt_size);

        dec(reg_work_);
        jmp(label_work_begin);
    }
    L(label_work_end);

    if (load_tail_size_) {
        io_load_.load(ptr[reg_src_], vmm_tmp1_, true);
        reduce_vmm_to_scalar(
                vmm_tmp1_, vmm_tmp2_, vmm_tmp3_, vmm_tmp4_, load_tail_size_);
        compute_scalar_op_(Xmm(vmm_acc_.getIdx()), Xmm(vmm_tmp1_.getIdx()));
    }
}

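// Fetch the src/dst pointers from the call arguments; the number of full
// simd_w_ iterations (reduce_size / simd_w_) is baked into the generated code
// as an immediate.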
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::load_params() {
    mov(reg_src_, ptr[reg_param_ + GET_OFF(src)]);
    mov(reg_dst_, ptr[reg_param_ + GET_OFF(dst)]);
    mov(reg_work_, conf_.reduce_size / simd_w_);
}

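// Register the sum post-op as a lambda injector: reload the current dst
// value and either add it directly (scale == 1) or fuse it in with
// vfmadd231ps using the broadcast sum scale. The scale is pushed back and
// popped so the queue of scales is rotated by one per invocation.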
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::apply_sum(const int data_idx) {
    if (conf_.with_sum) {
        assert(!conf_.sum_scales.empty()
                && "No scales for sum post operation.");
        const auto sum_injector = [this, data_idx]() {
            const Vmm vmm_prev_dst(vmm_tmp1_.getIdx());
            const Vmm vmm_dst(data_idx);

            io_store_.load(ptr[reg_dst_], vmm_prev_dst, true);
            const float sum_scale = sum_scales_.front();
            if (sum_scale == 1.f)
                uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
            else {
                const Xmm xmm_sum_scale = Xmm(vmm_sum_scale_.getIdx());
                mov(reg_tmp1_.cvt32(), float2int(sum_scale));
                uni_vmovd(xmm_sum_scale, reg_tmp1_.cvt32());
                uni_vbroadcastss(vmm_sum_scale_, xmm_sum_scale);
                uni_vfmadd231ps(vmm_dst, vmm_prev_dst, vmm_sum_scale_);
            }
            sum_scales_.push(sum_scale);
            sum_scales_.pop();
        };
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }
}

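// Run the post-ops chain on the accumulator register; for binary post-ops the
// output register and tail index are passed so rhs operands can be addressed
// relative to dst.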
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::apply_postops(const int data_idx) {
    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;

    if (conf_.with_sum) apply_sum(data_idx);

    if (conf_.with_binary) {
        rhs_arg_params.vmm_idx_to_out_reg.emplace(data_idx, reg_dst_);
        rhs_arg_params.vmm_tail_idx_.emplace(data_idx);
    }

    postops_injector_->compute_vector(data_idx, rhs_arg_params);
}

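// Finalize: if any full vectors were accumulated, reduce vmm_acc_ to a single
// lane; divide by reduce_size for the mean algorithm; apply post-ops; store
// the result with the tail mask so only valid output elements are written.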
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::finalize() {
    if (static_cast<std::size_t>(conf_.reduce_size) > load_tail_size_) {
        reduce_vmm_to_scalar(
                vmm_acc_, vmm_tmp1_, vmm_tmp2_, vmm_tmp3_, simd_w_);
    }

    if (conf_.alg == alg_kind::reduction_mean) {
        const Xmm xmm_acc(vmm_acc_.getIdx());
        const Xmm xmm_reduce_size(vmm_tmp1_.getIdx());
        mov(reg_tmp_.cvt32(), float2int(static_cast<float>(conf_.reduce_size)));
        uni_vmovd(xmm_reduce_size, reg_tmp_.cvt32());
        uni_vdivss(xmm_acc, xmm_acc, xmm_reduce_size);
    }

    if (conf_.with_postops) apply_postops(vmm_acc_.getIdx());

    io_store_.store(vmm_acc_, ptr[reg_dst_], true);
}

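// Kernel entry point: emit the prologue, prepare bf16/saturation state and
// tail masks, then run the load -> reduce -> finalize sequence. The eltwise
// lookup table, if needed, is emitted after the epilogue.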
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::generate() {
    preamble();

    io_store_.init_bf16();
    if (conf_.is_saturation_needed) io_store_.init_saturate_f32();

    if (load_tail_size_ > 0) io_load_.prepare_tail_mask();
    io_store_.prepare_tail_mask();

    load_params();
    init_acc();
    reduce();
    finalize();

    postamble();

    if (conf_.with_eltwise && postops_injector_)
        postops_injector_->prepare_table();
}

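// Explicit template instantiations for the supported ISA / vector-register
// combinations.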
template struct jit_uni_reduction_kernel_t<avx512_core_fp16>;
template struct jit_uni_reduction_kernel_t<avx512_core_bf16>;
template struct jit_uni_reduction_kernel_t<avx512_core>;
template struct jit_uni_reduction_kernel_t<avx2>;
template struct jit_uni_reduction_kernel_t<avx2, Xbyak::Xmm>;
template struct jit_uni_reduction_kernel_t<avx>;
template struct jit_uni_reduction_kernel_t<avx, Xbyak::Xmm>;
template struct jit_uni_reduction_kernel_t<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl