1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_UNI_RESAMPLING_KERNEL_HPP |
18 | #define CPU_X64_UNI_RESAMPLING_KERNEL_HPP |
19 | |
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <type_traits>

#include "common/c_types_map.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_resampling_pd.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
#include "cpu/x64/utils/jit_io_helper.hpp"
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | namespace x64 { |
36 | |
37 | struct jit_uni_resampling_kernel_base_t : public jit_generator { |
38 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_resampling) |
39 | |
40 | jit_uni_resampling_kernel_base_t(const jit_resampling_conf_t &conf) |
41 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, conf.isa) |
42 | , conf_(conf) |
43 | , sum_scales_(conf_.sum_scales) {} |
44 | |
45 | virtual ~jit_uni_resampling_kernel_base_t() = default; |
46 | |
47 | virtual std::size_t get_simd_w() = 0; |
48 | |
49 | protected: |
50 | const jit_resampling_conf_t &conf_; |
51 | std::queue<float> sum_scales_; |
52 | }; |
53 | |
54 | template <cpu_isa_t isa, typename Vmm> |
55 | struct jit_uni_resampling_kernel_t : public jit_uni_resampling_kernel_base_t { |
56 | |
57 | jit_uni_resampling_kernel_t( |
58 | const jit_resampling_conf_t &conf, const memory_desc_t *dst_md); |
59 | |
60 | virtual ~jit_uni_resampling_kernel_t() = default; |
61 | |
62 | std::size_t get_simd_w() override { return simd_w_; } |
63 | |
64 | private: |
65 | using Xmm = Xbyak::Xmm; |
66 | using Ymm = Xbyak::Ymm; |
67 | using Zmm = Xbyak::Zmm; |
68 | using Opmask = Xbyak::Opmask; |
69 | using Reg64 = Xbyak::Reg64; |
70 | using c_oriented_generation_fn_t = std::function<void(const bool)>; |
71 | |
72 | constexpr int vmm_idx(int idx) const { |
73 | return (cpu_isa_traits<isa>::n_vregs - 1) - idx; |
74 | } |
75 | |
76 | bool can_movntps_be_used() const; |
77 | std::size_t calculate_tail_size() const; |
78 | int get_channels_to_compute_without_tail( |
79 | bool is_tail_in_blocked_format) const; |
80 | |
81 | std::map<data_type_t, io::io_saturation_conf_t> |
82 | create_saturation_vmm_map() const; |
83 | |
84 | void get_params_for_linear_in_c_oriented_format(); |
85 | |
86 | void preserve_zero_padding_in_post_ops(int data_idx); |
87 | void apply_sum(const int data_idx, const bool is_tail); |
88 | void apply_postops(const int data_idx, const bool is_tail, |
89 | const Reg64 *reg_c = nullptr); |
90 | |
91 | void preserve_zero_padding( |
92 | int c_to_compute_without_tail, const bool is_tail); |
93 | |
94 | void interpolate_c_oriented_format( |
95 | const c_oriented_generation_fn_t &generation_fn); |
96 | void nearest_ncsp_format(); |
97 | void nearest_c_oriented_format(const bool is_tail_in_blocked_format); |
98 | void linear_ncsp_format(); |
99 | void linear_c_oriented_format(const bool is_tail_in_blocked_format); |
100 | |
101 | void generate() override; |
102 | |
103 | // Used only for avx and if c tail is present. |
104 | const Vmm vmm_tail_mask_ = Vmm(0); |
105 | // Used only for avx2 and if ncsp format is present. |
106 | // Vgatherdps always gets data using a conditional mask. |
107 | // This register contains all bits set to 1, allowing |
108 | // to get the maximum number of values available to the register |
109 | const Vmm vmm_full_mask_ = Vmm(1); |
110 | const Vmm vmm_src_ = Vmm(2); |
111 | const Vmm vmm_weights_ = Vmm(3); |
112 | const Vmm vmm_indices_ = Vmm(4); |
113 | const Vmm vmm_tmp_gather_ = Vmm(5); |
114 | const Vmm vmm_sum_scale_ = Vmm(7); |
115 | const Vmm vmm_tmp_ = Vmm(8); |
116 | const Vmm vmm_post_op_helper_ = Vmm(9); |
117 | const Vmm vmm_zero_saturation_ = isa == avx512_core ? Vmm(18) : Vmm(10); |
118 | const Vmm vmm_saturation_ubound_ = isa == avx512_core ? Vmm(19) : Vmm(11); |
119 | |
120 | const Zmm vmm_bf16_emu_1_ = Zmm(20); |
121 | const Zmm vmm_bf16_emu_2_ = Zmm(21); |
122 | const Zmm vmm_bf16_emu_3_ = Zmm(22); |
123 | const Zmm vmm_bf16_emu_4_ = Zmm(23); |
124 | |
125 | const Opmask k_tail_mask_ = k3; |
126 | const Opmask k_full_mask_ = k4; |
127 | |
128 | const Reg64 reg_tmp_ = rax; |
129 | const Reg64 reg_dst_ = rbx; |
130 | const Reg64 reg_work_ = rdx; |
131 | const Reg64 reg_indices_ = rsi; |
132 | const Reg64 reg_c_offset = rbp; |
133 | const Reg64 reg_param = abi_param1; |
134 | const Reg64 reg_weights = abi_not_param1; |
135 | const Reg64 reg_src_ = r8; |
136 | const Reg64 reg_aux_src_0_ = r9; |
137 | const Reg64 reg_aux_src_1_ = r10; |
138 | const Reg64 reg_aux_src_2_ = r11; |
139 | const Reg64 reg_tmp1_ = r15; |
140 | |
141 | // Registers which are used only for linear algorithm |
142 | // and for channel oriented formats. |
143 | // Meaning of shortcuts: |
144 | // f - front, b - back |
145 | // t - top, b - bottom |
146 | // l - left, r - right |
147 | // Example: |
148 | // src_ftl_ - source tensor data for the front top left corner |
149 | // reg_src_ftl_ - register which contains address of source |
150 | // tensor data for the front top left corner |
151 | const Vmm weight_left_ = Vmm(1); |
152 | const Vmm weight_right_ = Vmm(2); |
153 | const Vmm weight_top_ = Vmm(3); |
154 | const Vmm weight_bottom_ = Vmm(4); |
155 | const Vmm weight_front_ = Vmm(5); |
156 | const Vmm weight_back_ = Vmm(6); |
157 | const Vmm src_ftl_ = Vmm(vmm_idx(0)); |
158 | const Vmm src_ftr_ = Vmm(vmm_idx(1)); |
159 | const Vmm src_fbl_ = Vmm(vmm_idx(2)); |
160 | const Vmm src_fbr_ = Vmm(vmm_idx(3)); |
161 | const Vmm src_btl_ = Vmm(vmm_idx(4)); |
162 | const Vmm src_btr_ = Vmm(vmm_idx(5)); |
163 | const Vmm src_bbl_ = Vmm(vmm_idx(6)); |
164 | const Vmm src_bbr_ = Vmm(vmm_idx(7)); |
165 | |
166 | const Reg64 reg_src_ftl_ = reg_src_; |
167 | const Reg64 reg_src_ftr_ = reg_aux_src_0_; |
168 | const Reg64 reg_src_fbl_ = reg_aux_src_1_; |
169 | const Reg64 reg_src_fbr_ = reg_aux_src_2_; |
170 | const Reg64 reg_src_btl_ = r12; |
171 | const Reg64 reg_src_btr_ = r13; |
172 | const Reg64 reg_src_bbl_ = r14; |
173 | const Reg64 reg_src_bbr_ = r15; |
174 | |
175 | static constexpr bool is_zmm_ = std::is_same<Vmm, Xbyak::Zmm>::value; |
176 | static constexpr bool is_ymm_ = std::is_same<Vmm, Xbyak::Ymm>::value; |
177 | static constexpr bool is_xmm_ = std::is_same<Vmm, Xbyak::Xmm>::value; |
178 | static constexpr std::size_t vlen_ = is_zmm_ ? 64 : is_ymm_ ? 32 : 16; |
179 | static constexpr std::size_t simd_w_ = vlen_ / sizeof(float); |
180 | const std::size_t tail_size_; |
181 | |
182 | bool any_binary_postop_is_per_oc_bcast_type_ = false; |
183 | bool any_binary_postop_is_per_oc_sp_bcast_type_ = false; |
184 | |
185 | io::jit_io_multi_dt_helper_t<Vmm> io_; |
186 | std::unique_ptr<injector::jit_uni_postops_injector_t<isa, Vmm>> |
187 | postops_injector_; |
188 | }; |
189 | } // namespace x64 |
190 | } // namespace cpu |
191 | } // namespace impl |
192 | } // namespace dnnl |
193 | |
194 | #endif |
195 | |