/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/type_helpers.hpp"

#include "jit_uni_reduction_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;
#define GET_OFF(field) offsetof(jit_reduction_call_s, field)

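// Broadcasting strategies the binary post-ops injector may use for its
// right-hand-side arguments.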
static const bcast_set_t &get_supported_postops_bcast_strategies() {
    static const bcast_set_t supported_strategies
            = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::per_oc_spatial,
                    broadcasting_strategy_t::no_broadcast};
    return supported_strategies;
}

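// The io helpers wrap data-type-aware vector loads/stores and carry the
// tail-mask, bf16-emulation and f32-saturation setup shared by the kernel.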
template <cpu_isa_t isa, typename Vmm>
jit_uni_reduction_kernel_t<isa, Vmm>::jit_uni_reduction_kernel_t(
        const jit_reduction_conf_t &conf, const memory_desc_t *dst_md)
    : jit_uni_reduction_kernel_base_t(conf)
    , load_tail_size_(conf.reduce_size % simd_w_)
    , io_load_(this, isa, conf_.src_type, {false},
              io::io_tail_conf_t {simd_w_, load_tail_size_, k_tail_load_mask_,
                      vmm_tail_load_mask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {vmm_bf16_emu_1_, vmm_bf16_emu_2_,
                      vmm_bf16_emu_3_, reg_tmp_, vmm_bf16_emu_4_},
              io::io_saturation_conf_t {vmm_zero_saturation_.getIdx(),
                      vmm_saturation_ubound_.getIdx(), reg_tmp_})
    , io_store_(this, isa, conf_.dst_type, {false},
              io::io_tail_conf_t {simd_w_, store_tail_size_, k_tail_store_mask_,
                      vmm_tail_store_mask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {vmm_bf16_emu_1_, vmm_bf16_emu_2_,
                      vmm_bf16_emu_3_, reg_tmp_, vmm_bf16_emu_4_},
              io::io_saturation_conf_t {vmm_zero_saturation_.getIdx(),
                      vmm_saturation_ubound_.getIdx(), reg_tmp_}) {
    init_compute_op();
    init_compute_scalar_op();
    if (conf_.with_postops) init_post_ops_injector(dst_md);
}

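// Broadcast the neutral element of the reduction algorithm into the
// accumulator: lowest float for max, largest float for min, 0 for sum/mean
// and 1 for mul.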
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::init_acc() {
    using namespace alg_kind;
    using namespace nstl;

    const Xmm xmm_tmp_(vmm_tmp1_.getIdx());
    float starting_val = 0;

    switch (conf_.alg) {
        case reduction_max:
            starting_val = numeric_limits<float>::lowest();
            break;
        case reduction_min: starting_val = numeric_limits<float>::max(); break;
        case reduction_mean:
        case reduction_sum: starting_val = 0.f; break;
        case reduction_mul: starting_val = 1.f; break;
        default: assert(!"unknown alg");
    }

    mov(reg_tmp_.cvt32(), float2int(starting_val));
    uni_vmovd(xmm_tmp_, reg_tmp_.cvt32());
    uni_vbroadcastss(vmm_acc_, xmm_tmp_);
}

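// Select the packed (vector) operation used to combine a freshly loaded
// vector with the accumulator.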
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::init_compute_op() {
    using namespace alg_kind;
    switch (conf_.alg) {
        case reduction_max:
            compute_op_ = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                uni_vmaxps(acc, acc, to_acc);
            };
            break;
        case reduction_min:
            compute_op_ = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                uni_vminps(acc, acc, to_acc);
            };
            break;
        case reduction_mean:
        case reduction_sum:
            compute_op_ = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                uni_vaddps(acc, acc, to_acc);
            };
            break;
        case reduction_mul:
            compute_op_ = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                uni_vmulps(acc, acc, to_acc);
            };
            break;
        default: assert(!"unsupported alg.");
    }
}

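// Select the scalar counterpart of the compute op; it is used when folding
// vector lanes and the load tail into a single value.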
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::init_compute_scalar_op() {
    using namespace alg_kind;

    switch (conf_.alg) {
        case reduction_max:
            compute_scalar_op_
                    = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                          maxss(acc, to_acc);
                      };
            break;
        case reduction_min:
            compute_scalar_op_
                    = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                          minss(acc, to_acc);
                      };
            break;
        case reduction_mean:
        case reduction_sum:
            compute_scalar_op_
                    = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                          addss(acc, to_acc);
                      };
            break;
        case reduction_mul:
            compute_scalar_op_
                    = [&](const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc) {
                          mulss(acc, to_acc);
                      };
            break;
        default: assert(!"unsupported alg.");
    }
}

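// Configure the post-ops injector with the static parameters required by
// the eltwise and binary injectors: helper registers, the rhs argument
// layout, and the supported broadcasting strategies.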
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::init_post_ops_injector(
        const memory_desc_t *dst_md) {
    const memory_desc_wrapper dst_d(*dst_md);

    const eltwise_injector::static_params_t esp(true /*save_state*/,
            reg_po_injector_helper_1_, elt_inj_opmask_, true /*is_fwd*/,
            false /*use_dst*/);
    const binary_injector::rhs_arg_static_params_t rhs_arg_bsp {
            static_cast<size_t>(rhs_dt_helper_vmm_.getIdx()),
            reg_po_injector_helper_1_, reg_po_injector_helper_2_,
            reg_po_injector_helper_3_, true /*preserve gpr*/,
            true /*preserve vmm*/, GET_OFF(post_ops_binary_rhs_arg_vec),
            GET_OFF(dst_orig), dst_d, store_tail_size_, k_tail_store_mask_,
            false /*use_exact_tail_scalar_bcast*/};
    const binary_injector::static_params_t bsp(
            reg_param_, get_supported_postops_bcast_strategies(), rhs_arg_bsp);

    postops_injector_ = utils::make_unique<
            injector::jit_uni_postops_injector_t<inject_isa_, Vmm>>(
            this, conf_.post_ops, bsp, esp);
}

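// The reduce_*_to_* helpers below form a horizontal reduction tree,
// zmm -> ymm -> xmm -> scalar: each step extracts the upper half of the
// accumulator and combines it with the lower half via the compute op.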
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce_zmm_to_ymm(
        const Xmm &acc, const Xmm &tmp) {
    const Zmm zmm_acc(acc.getIdx());
    const Ymm ymm_acc(acc.getIdx());
    const Ymm ymm_to_acc(tmp.getIdx());
    vextractf64x4(ymm_to_acc, zmm_acc, 1);
    compute_op_(ymm_acc, ymm_to_acc);
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce_ymm_to_xmm(
        const Xmm &acc, const Xmm &tmp) {
    const Ymm ymm_acc(acc.getIdx());
    const Xmm xmm_acc(acc.getIdx());
    const Xmm xmm_to_acc(tmp.getIdx());
    vextractf128(xmm_to_acc, ymm_acc, 1);
    compute_op_(xmm_acc, xmm_to_acc);
}

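// Fold up to four f32 lanes of an xmm into element 0 of the accumulator.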
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce_xmm_to_scalar(const Xmm &acc,
        const Xmm &tmp, const std::size_t number_of_values_to_reduce) {
    assert(number_of_values_to_reduce <= number_of_f32_in_xmm_);

    const Xmm xmm_acc(acc.getIdx());
    const Xmm xmm_to_acc(tmp.getIdx());

    static constexpr int number_of_f32_to_move = number_of_f32_in_xmm_ - 1;
    static constexpr uint8_t insertps_configuration[number_of_f32_to_move]
            = {0b01001110, 0b10001110, 0b11001110};

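    // Each insertps immediate copies lane i + 1 of the accumulator into
    // lane 0 of the temporary register (zeroing its upper lanes), so the
    // scalar op accumulates the lanes into element 0 one at a time.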
    for (std::size_t i = 0; i < number_of_values_to_reduce - 1; i++) {
        insertps(xmm_to_acc, xmm_acc, insertps_configuration[i]);
        compute_scalar_op_(xmm_acc, xmm_to_acc);
    }
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce_ymm_to_scalar(
        const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, const Xbyak::Xmm &tmp2,
        const std::size_t number_of_values_to_reduce) {
    assert(number_of_values_to_reduce <= number_of_f32_in_ymm_);

    const Ymm ymm_acc(acc.getIdx());
    const Xmm xmm_acc(acc.getIdx());
    const Xmm xmm_tmp(tmp1.getIdx());
    const Xmm xmm_acc_upper_half(tmp2.getIdx());

    if (number_of_values_to_reduce == number_of_f32_in_ymm_) {
        reduce_ymm_to_xmm(ymm_acc, xmm_tmp);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
    } else if (number_of_values_to_reduce > number_of_f32_in_xmm_) {
        vextractf128(xmm_acc_upper_half, ymm_acc, 1);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
        reduce_xmm_to_scalar(xmm_acc_upper_half, xmm_tmp,
                number_of_values_to_reduce - number_of_f32_in_xmm_);
        compute_scalar_op_(xmm_acc, xmm_acc_upper_half);
    } else if (number_of_values_to_reduce <= number_of_f32_in_xmm_) {
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp, number_of_values_to_reduce);
    }
}

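// Reduce up to a full zmm of f32 lanes into element 0 of the accumulator,
// dispatching to the ymm/xmm helpers when fewer lanes are populated.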
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce_vmm_to_scalar(
        const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, const Xbyak::Xmm &tmp2,
        const Xbyak::Xmm &tmp3, const std::size_t number_of_values_to_reduce) {
    assert(number_of_values_to_reduce <= number_of_f32_in_zmm_);

    const Zmm zmm_acc(acc.getIdx());
    const Ymm ymm_acc(acc.getIdx());
    const Xmm xmm_acc(acc.getIdx());
    const Ymm ymm_acc_upper_half(tmp1.getIdx());
    const Xmm xmm_acc_upper_half(tmp1.getIdx());
    const Ymm ymm_tmp(tmp2.getIdx());
    const Xmm xmm_tmp1(tmp2.getIdx());
    const Xmm xmm_tmp2(tmp3.getIdx());

    if (number_of_values_to_reduce == number_of_f32_in_zmm_) {
        reduce_zmm_to_ymm(zmm_acc, ymm_tmp);
        reduce_ymm_to_xmm(ymm_acc, xmm_tmp1);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp1);
    } else if (number_of_values_to_reduce > number_of_f32_in_ymm_) {
        vextractf64x4(ymm_acc_upper_half, zmm_acc, 1);
        reduce_ymm_to_scalar(ymm_acc, xmm_tmp1, xmm_tmp2);
        reduce_ymm_to_scalar(ymm_acc_upper_half, xmm_tmp1, xmm_tmp2,
                number_of_values_to_reduce - number_of_f32_in_ymm_);
        compute_scalar_op_(xmm_acc, xmm_acc_upper_half);
    } else if (number_of_values_to_reduce <= number_of_f32_in_ymm_) {
        reduce_ymm_to_scalar(
                ymm_acc, xmm_tmp1, xmm_tmp2, number_of_values_to_reduce);
    }
}

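// Main loop: consume simd_w_ elements per iteration and accumulate them
// vertically; the remainder is brought in with a masked tail load, reduced
// horizontally, and folded into lane 0 of the accumulator.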
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::reduce() {
    Label label_work_begin, label_work_end;

    L(label_work_begin);
    {
        cmp(reg_work_, 0);
        je(label_work_end);
        io_load_.load(ptr[reg_src_], vmm_tmp1_, false);
        compute_op_(vmm_acc_, vmm_tmp1_);

        add(reg_src_, simd_w_ * conf_.src_dt_size);

        dec(reg_work_);
        jmp(label_work_begin);
    }
    L(label_work_end);

    if (load_tail_size_) {
        io_load_.load(ptr[reg_src_], vmm_tmp1_, true);
        reduce_vmm_to_scalar(
                vmm_tmp1_, vmm_tmp2_, vmm_tmp3_, vmm_tmp4_, load_tail_size_);
        compute_scalar_op_(Xmm(vmm_acc_.getIdx()), Xmm(vmm_tmp1_.getIdx()));
    }
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::load_params() {
    mov(reg_src_, ptr[reg_param_ + GET_OFF(src)]);
    mov(reg_dst_, ptr[reg_param_ + GET_OFF(dst)]);
    mov(reg_work_, conf_.reduce_size / simd_w_);
}

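// Register a lambda with the post-ops injector that implements the sum
// post-op: reload the value currently stored at dst and add it, optionally
// scaled, to the result. The front scale is re-pushed and then popped to
// rotate the scales queue.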
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::apply_sum(const int data_idx) {
    if (conf_.with_sum) {
        assert(!conf_.sum_scales.empty()
                && "No scales for sum post operation.");
        const auto sum_injector = [this, data_idx]() {
            const Vmm vmm_prev_dst(vmm_tmp1_.getIdx());
            const Vmm vmm_dst(data_idx);

            io_store_.load(ptr[reg_dst_], vmm_prev_dst, true);
            const float sum_scale = sum_scales_.front();
            if (sum_scale == 1.f)
                uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
            else {
                const Xmm xmm_sum_scale = Xmm(vmm_sum_scale_.getIdx());
                mov(reg_tmp1_.cvt32(), float2int(sum_scale));
                uni_vmovd(xmm_sum_scale, reg_tmp1_.cvt32());
                uni_vbroadcastss(vmm_sum_scale_, xmm_sum_scale);
                uni_vfmadd231ps(vmm_dst, vmm_prev_dst, vmm_sum_scale_);
            }
            sum_scales_.push(sum_scale);
            sum_scales_.pop();
        };
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::apply_postops(const int data_idx) {
    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;

    if (conf_.with_sum) apply_sum(data_idx);

    if (conf_.with_binary) {
        rhs_arg_params.vmm_idx_to_out_reg.emplace(data_idx, reg_dst_);
        rhs_arg_params.vmm_tail_idx_.emplace(data_idx);
    }

    postops_injector_->compute_vector(data_idx, rhs_arg_params);
}

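// Fold the vector accumulator into a scalar (skipped when all the work was
// done in the tail path), divide by reduce_size for mean, apply post-ops,
// and store the result through the masked tail store.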
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::finalize() {
    if (static_cast<std::size_t>(conf_.reduce_size) > load_tail_size_) {
        reduce_vmm_to_scalar(
                vmm_acc_, vmm_tmp1_, vmm_tmp2_, vmm_tmp3_, simd_w_);
    }

    if (conf_.alg == alg_kind::reduction_mean) {
        const Xmm xmm_acc(vmm_acc_.getIdx());
        const Xmm xmm_reduce_size(vmm_tmp1_.getIdx());
        mov(reg_tmp_.cvt32(), float2int(static_cast<float>(conf_.reduce_size)));
        uni_vmovd(xmm_reduce_size, reg_tmp_.cvt32());
        uni_vdivss(xmm_acc, xmm_acc, xmm_reduce_size);
    }

    if (conf_.with_postops) apply_postops(vmm_acc_.getIdx());

    io_store_.store(vmm_acc_, ptr[reg_dst_], true);
}

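// Kernel entry point: emit the prologue, bf16/saturation setup and tail
// masks, then the load -> accumulate -> finalize sequence.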
template <cpu_isa_t isa, typename Vmm>
void jit_uni_reduction_kernel_t<isa, Vmm>::generate() {
    preamble();

    io_store_.init_bf16();
    if (conf_.is_saturation_needed) io_store_.init_saturate_f32();

    if (load_tail_size_ > 0) io_load_.prepare_tail_mask();
    io_store_.prepare_tail_mask();

    load_params();
    init_acc();
    reduce();
    finalize();

    postamble();

    if (conf_.with_eltwise && postops_injector_)
        postops_injector_->prepare_table();
}

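// Explicit instantiations for every supported ISA / vector register
// combination.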
template struct jit_uni_reduction_kernel_t<avx512_core_fp16>;
template struct jit_uni_reduction_kernel_t<avx512_core_bf16>;
template struct jit_uni_reduction_kernel_t<avx512_core>;
template struct jit_uni_reduction_kernel_t<avx2>;
template struct jit_uni_reduction_kernel_t<avx2, Xbyak::Xmm>;
template struct jit_uni_reduction_kernel_t<avx>;
template struct jit_uni_reduction_kernel_t<avx, Xbyak::Xmm>;
template struct jit_uni_reduction_kernel_t<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl