1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_UNI_REDUCTION_KERNEL_HPP |
18 | #define CPU_X64_UNI_REDUCTION_KERNEL_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | #include "common/utils.hpp" |
23 | |
24 | #include "cpu/cpu_resampling_pd.hpp" |
25 | |
26 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
27 | #include "cpu/x64/jit_generator.hpp" |
28 | #include "cpu/x64/jit_primitive_conf.hpp" |
29 | #include "cpu/x64/utils/jit_io_helper.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | |
36 | struct jit_uni_reduction_kernel_base_t : public jit_generator { |
37 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reduction) |
38 | |
39 | jit_uni_reduction_kernel_base_t(const jit_reduction_conf_t &conf) |
40 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, conf.isa) |
41 | , conf_(conf) |
42 | , sum_scales_(conf_.sum_scales) {} |
43 | virtual ~jit_uni_reduction_kernel_base_t() = default; |
44 | |
45 | virtual std::size_t get_simd_w() = 0; |
46 | |
47 | protected: |
48 | const jit_reduction_conf_t &conf_; |
49 | std::queue<float> sum_scales_; |
50 | }; |
51 | |
52 | template <cpu_isa_t isa, typename Vmm = typename cpu_isa_traits<isa>::Vmm> |
53 | struct jit_uni_reduction_kernel_t : public jit_uni_reduction_kernel_base_t { |
54 | jit_uni_reduction_kernel_t( |
55 | const jit_reduction_conf_t &conf, const memory_desc_t *dst_md); |
56 | |
57 | virtual ~jit_uni_reduction_kernel_t() = default; |
58 | |
59 | std::size_t get_simd_w() override { return simd_w_; } |
60 | |
61 | private: |
62 | using compute_fn_t = std::function<void( |
63 | const Xbyak::Xmm &acc, const Xbyak::Xmm &to_acc)>; |
64 | |
65 | void init_acc(); |
66 | void init_compute_op(); |
67 | void init_compute_scalar_op(); |
68 | void init_post_ops_injector(const memory_desc_t *dst_md); |
69 | |
70 | void reduce_ymm_to_xmm(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp); |
71 | void reduce_xmm_to_scalar(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp, |
72 | const std::size_t number_of_values_to_reduce |
73 | = number_of_f32_in_xmm_); |
74 | void reduce_zmm_to_ymm(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp); |
75 | void reduce_ymm_to_scalar(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, |
76 | const Xbyak::Xmm &tmp2, |
77 | const std::size_t number_of_values_to_reduce |
78 | = number_of_f32_in_ymm_); |
79 | void reduce_vmm_to_scalar(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, |
80 | const Xbyak::Xmm &tmp2, const Xbyak::Xmm &tmp3, |
81 | const std::size_t number_of_values_to_reduce |
82 | = number_of_f32_in_zmm_); |
83 | |
84 | void reduce(); |
85 | |
86 | void load_params(); |
87 | void apply_sum(const int data_idx); |
88 | void apply_postops(const int data_idx); |
89 | void finalize(); |
90 | void generate() override; |
91 | |
92 | const Vmm vmm_tail_load_mask_ = Vmm(0); |
93 | const Vmm vmm_tail_store_mask_ = Vmm(1); |
94 | const Vmm vmm_zero_saturation_ = Vmm(2); |
95 | const Vmm vmm_saturation_ubound_ = Vmm(3); |
96 | const Vmm vmm_acc_ = Vmm(4); |
97 | const Vmm vmm_tmp1_ = Vmm(5); |
98 | const Vmm vmm_tmp2_ = Vmm(6); |
99 | const Vmm vmm_tmp3_ = Vmm(7); |
100 | const Vmm vmm_tmp4_ = Vmm(8); |
101 | const Vmm vmm_sum_scale_ = Vmm(9); |
102 | const Vmm rhs_dt_helper_vmm_ = Vmm(10); |
103 | const Xbyak::Zmm vmm_bf16_emu_1_ = Xbyak::Zmm(28); |
104 | const Xbyak::Zmm vmm_bf16_emu_2_ = Xbyak::Zmm(29); |
105 | const Xbyak::Zmm vmm_bf16_emu_3_ = Xbyak::Zmm(30); |
106 | const Xbyak::Zmm vmm_bf16_emu_4_ = Xbyak::Zmm(31); |
107 | |
108 | const Xbyak::Opmask k_tail_load_mask_ = k3; |
109 | const Xbyak::Opmask k_tail_store_mask_ = k4; |
110 | |
111 | const Xbyak::Reg64 reg_work_ = rax; |
112 | const Xbyak::Reg64 reg_src_ = rbx; |
113 | const Xbyak::Reg64 reg_dst_ = rdx; |
114 | const Xbyak::Reg64 reg_param_ = abi_param1; |
115 | const Xbyak::Reg64 reg_tmp_ = abi_not_param1; |
116 | const Xbyak::Reg64 reg_tmp1_ = r13; |
117 | |
118 | static constexpr bool is_zmm_ = std::is_same<Vmm, Xbyak::Zmm>::value; |
119 | static constexpr bool is_ymm_ = std::is_same<Vmm, Xbyak::Ymm>::value; |
120 | static constexpr bool is_xmm_ = std::is_same<Vmm, Xbyak::Ymm>::value; |
121 | static constexpr std::size_t vlen_ = is_zmm_ ? 64 : is_ymm_ ? 32 : 16; |
122 | static constexpr std::size_t simd_w_ = vlen_ / sizeof(float); |
123 | static constexpr std::size_t number_of_f32_in_xmm_ = 4; |
124 | static constexpr std::size_t number_of_f32_in_ymm_ = 8; |
125 | static constexpr std::size_t number_of_f32_in_zmm_ = 16; |
126 | const std::size_t load_tail_size_; |
127 | static constexpr std::size_t store_tail_size_ = 1; |
128 | |
129 | io::jit_io_helper_t<Vmm> io_load_; |
130 | io::jit_io_helper_t<Vmm> io_store_; |
131 | |
132 | compute_fn_t compute_op_; |
133 | compute_fn_t compute_scalar_op_; |
134 | |
135 | const Xbyak::Opmask elt_inj_opmask_ = k1; |
136 | const Xbyak::Reg64 reg_po_injector_helper_1_ = r14; |
137 | const Xbyak::Reg64 reg_po_injector_helper_2_ = r15; |
138 | const Xbyak::Reg64 reg_po_injector_helper_3_ = r12; |
139 | |
140 | // post-ops injector does not use avx512_core_bf16 instructions |
141 | static constexpr cpu_isa_t inject_isa_ |
142 | = isa == avx512_core_bf16 ? avx512_core : isa; |
143 | std::unique_ptr<injector::jit_uni_postops_injector_t<inject_isa_, Vmm>> |
144 | postops_injector_; |
145 | }; |
146 | |
147 | } // namespace x64 |
148 | } // namespace cpu |
149 | } // namespace impl |
150 | } // namespace dnnl |
151 | |
152 | #endif |
153 | |