1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_UNI_REDUCTION_KERNEL_HPP
18#define CPU_X64_UNI_REDUCTION_KERNEL_HPP
19
#include <functional>
#include <memory>
#include <queue>
#include <type_traits>

#include "common/c_types_map.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_resampling_pd.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
#include "cpu/x64/utils/jit_io_helper.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34namespace x64 {
35
36struct jit_uni_reduction_kernel_base_t : public jit_generator {
37 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reduction)
38
39 jit_uni_reduction_kernel_base_t(const jit_reduction_conf_t &conf)
40 : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, conf.isa)
41 , conf_(conf)
42 , sum_scales_(conf_.sum_scales) {}
43 virtual ~jit_uni_reduction_kernel_base_t() = default;
44
45 virtual std::size_t get_simd_w() = 0;
46
47protected:
48 const jit_reduction_conf_t &conf_;
49 std::queue<float> sum_scales_;
50};
51
52template <cpu_isa_t isa, typename Vmm = typename cpu_isa_traits<isa>::Vmm>
53struct jit_uni_reduction_kernel_t : public jit_uni_reduction_kernel_base_t {
54 jit_uni_reduction_kernel_t(
55 const jit_reduction_conf_t &conf, const memory_desc_t *dst_md);
56
57 virtual ~jit_uni_reduction_kernel_t() = default;
58
59 std::size_t get_simd_w() override { return simd_w_; }
60
61private:
62 using compute_fn_t = std::function<void(
63 const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc)>;
64
65 void init_acc();
66 void init_compute_op();
67 void init_compute_scalar_op();
68 void init_post_ops_injector(const memory_desc_t *dst_md);
69
70 void reduce_ymm_to_xmm(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp);
71 void reduce_xmm_to_scalar(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp,
72 const std::size_t number_of_values_to_reduce
73 = number_of_f32_in_xmm_);
74 void reduce_zmm_to_ymm(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp);
75 void reduce_ymm_to_scalar(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1,
76 const Xbyak::Xmm &tmp2,
77 const std::size_t number_of_values_to_reduce
78 = number_of_f32_in_ymm_);
79 void reduce_vmm_to_scalar(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1,
80 const Xbyak::Xmm &tmp2, const Xbyak::Xmm &tmp3,
81 const std::size_t number_of_values_to_reduce
82 = number_of_f32_in_zmm_);
83
84 void reduce();
85
86 void load_params();
87 void apply_sum(const int data_idx);
88 void apply_postops(const int data_idx);
89 void finalize();
90 void generate() override;
91
92 const Vmm vmm_tail_load_mask_ = Vmm(0);
93 const Vmm vmm_tail_store_mask_ = Vmm(1);
94 const Vmm vmm_zero_saturation_ = Vmm(2);
95 const Vmm vmm_saturation_ubound_ = Vmm(3);
96 const Vmm vmm_acc_ = Vmm(4);
97 const Vmm vmm_tmp1_ = Vmm(5);
98 const Vmm vmm_tmp2_ = Vmm(6);
99 const Vmm vmm_tmp3_ = Vmm(7);
100 const Vmm vmm_tmp4_ = Vmm(8);
101 const Vmm vmm_sum_scale_ = Vmm(9);
102 const Vmm rhs_dt_helper_vmm_ = Vmm(10);
103 const Xbyak::Zmm vmm_bf16_emu_1_ = Xbyak::Zmm(28);
104 const Xbyak::Zmm vmm_bf16_emu_2_ = Xbyak::Zmm(29);
105 const Xbyak::Zmm vmm_bf16_emu_3_ = Xbyak::Zmm(30);
106 const Xbyak::Zmm vmm_bf16_emu_4_ = Xbyak::Zmm(31);
107
108 const Xbyak::Opmask k_tail_load_mask_ = k3;
109 const Xbyak::Opmask k_tail_store_mask_ = k4;
110
111 const Xbyak::Reg64 reg_work_ = rax;
112 const Xbyak::Reg64 reg_src_ = rbx;
113 const Xbyak::Reg64 reg_dst_ = rdx;
114 const Xbyak::Reg64 reg_param_ = abi_param1;
115 const Xbyak::Reg64 reg_tmp_ = abi_not_param1;
116 const Xbyak::Reg64 reg_tmp1_ = r13;
117
118 static constexpr bool is_zmm_ = std::is_same<Vmm, Xbyak::Zmm>::value;
119 static constexpr bool is_ymm_ = std::is_same<Vmm, Xbyak::Ymm>::value;
120 static constexpr bool is_xmm_ = std::is_same<Vmm, Xbyak::Ymm>::value;
121 static constexpr std::size_t vlen_ = is_zmm_ ? 64 : is_ymm_ ? 32 : 16;
122 static constexpr std::size_t simd_w_ = vlen_ / sizeof(float);
123 static constexpr std::size_t number_of_f32_in_xmm_ = 4;
124 static constexpr std::size_t number_of_f32_in_ymm_ = 8;
125 static constexpr std::size_t number_of_f32_in_zmm_ = 16;
126 const std::size_t load_tail_size_;
127 static constexpr std::size_t store_tail_size_ = 1;
128
129 io::jit_io_helper_t<Vmm> io_load_;
130 io::jit_io_helper_t<Vmm> io_store_;
131
132 compute_fn_t compute_op_;
133 compute_fn_t compute_scalar_op_;
134
135 const Xbyak::Opmask elt_inj_opmask_ = k1;
136 const Xbyak::Reg64 reg_po_injector_helper_1_ = r14;
137 const Xbyak::Reg64 reg_po_injector_helper_2_ = r15;
138 const Xbyak::Reg64 reg_po_injector_helper_3_ = r12;
139
140 // post-ops injector does not use avx512_core_bf16 instructions
141 static constexpr cpu_isa_t inject_isa_
142 = isa == avx512_core_bf16 ? avx512_core : isa;
143 std::unique_ptr<injector::jit_uni_postops_injector_t<inject_isa_, Vmm>>
144 postops_injector_;
145};
146
147} // namespace x64
148} // namespace cpu
149} // namespace impl
150} // namespace dnnl
151
152#endif
153