1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_UNI_RESAMPLING_KERNEL_HPP
18#define CPU_X64_UNI_RESAMPLING_KERNEL_HPP
19
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <type_traits>

#include "common/c_types_map.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_resampling_pd.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
#include "cpu/x64/utils/jit_io_helper.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35namespace x64 {
36
37struct jit_uni_resampling_kernel_base_t : public jit_generator {
38 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_resampling)
39
40 jit_uni_resampling_kernel_base_t(const jit_resampling_conf_t &conf)
41 : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, conf.isa)
42 , conf_(conf)
43 , sum_scales_(conf_.sum_scales) {}
44
45 virtual ~jit_uni_resampling_kernel_base_t() = default;
46
47 virtual std::size_t get_simd_w() = 0;
48
49protected:
50 const jit_resampling_conf_t &conf_;
51 std::queue<float> sum_scales_;
52};
53
54template <cpu_isa_t isa, typename Vmm>
55struct jit_uni_resampling_kernel_t : public jit_uni_resampling_kernel_base_t {
56
57 jit_uni_resampling_kernel_t(
58 const jit_resampling_conf_t &conf, const memory_desc_t *dst_md);
59
60 virtual ~jit_uni_resampling_kernel_t() = default;
61
62 std::size_t get_simd_w() override { return simd_w_; }
63
64private:
65 using Xmm = Xbyak::Xmm;
66 using Ymm = Xbyak::Ymm;
67 using Zmm = Xbyak::Zmm;
68 using Opmask = Xbyak::Opmask;
69 using Reg64 = Xbyak::Reg64;
70 using c_oriented_generation_fn_t = std::function<void(const bool)>;
71
72 constexpr int vmm_idx(int idx) const {
73 return (cpu_isa_traits<isa>::n_vregs - 1) - idx;
74 }
75
76 bool can_movntps_be_used() const;
77 std::size_t calculate_tail_size() const;
78 int get_channels_to_compute_without_tail(
79 bool is_tail_in_blocked_format) const;
80
81 std::map<data_type_t, io::io_saturation_conf_t>
82 create_saturation_vmm_map() const;
83
84 void get_params_for_linear_in_c_oriented_format();
85
86 void preserve_zero_padding_in_post_ops(int data_idx);
87 void apply_sum(const int data_idx, const bool is_tail);
88 void apply_postops(const int data_idx, const bool is_tail,
89 const Reg64 *reg_c = nullptr);
90
91 void preserve_zero_padding(
92 int c_to_compute_without_tail, const bool is_tail);
93
94 void interpolate_c_oriented_format(
95 const c_oriented_generation_fn_t &generation_fn);
96 void nearest_ncsp_format();
97 void nearest_c_oriented_format(const bool is_tail_in_blocked_format);
98 void linear_ncsp_format();
99 void linear_c_oriented_format(const bool is_tail_in_blocked_format);
100
101 void generate() override;
102
103 // Used only for avx and if c tail is present.
104 const Vmm vmm_tail_mask_ = Vmm(0);
105 // Used only for avx2 and if ncsp format is present.
106 // Vgatherdps always gets data using a conditional mask.
107 // This register contains all bits set to 1, allowing
108 // to get the maximum number of values available to the register
109 const Vmm vmm_full_mask_ = Vmm(1);
110 const Vmm vmm_src_ = Vmm(2);
111 const Vmm vmm_weights_ = Vmm(3);
112 const Vmm vmm_indices_ = Vmm(4);
113 const Vmm vmm_tmp_gather_ = Vmm(5);
114 const Vmm vmm_sum_scale_ = Vmm(7);
115 const Vmm vmm_tmp_ = Vmm(8);
116 const Vmm vmm_post_op_helper_ = Vmm(9);
117 const Vmm vmm_zero_saturation_ = isa == avx512_core ? Vmm(18) : Vmm(10);
118 const Vmm vmm_saturation_ubound_ = isa == avx512_core ? Vmm(19) : Vmm(11);
119
120 const Zmm vmm_bf16_emu_1_ = Zmm(20);
121 const Zmm vmm_bf16_emu_2_ = Zmm(21);
122 const Zmm vmm_bf16_emu_3_ = Zmm(22);
123 const Zmm vmm_bf16_emu_4_ = Zmm(23);
124
125 const Opmask k_tail_mask_ = k3;
126 const Opmask k_full_mask_ = k4;
127
128 const Reg64 reg_tmp_ = rax;
129 const Reg64 reg_dst_ = rbx;
130 const Reg64 reg_work_ = rdx;
131 const Reg64 reg_indices_ = rsi;
132 const Reg64 reg_c_offset = rbp;
133 const Reg64 reg_param = abi_param1;
134 const Reg64 reg_weights = abi_not_param1;
135 const Reg64 reg_src_ = r8;
136 const Reg64 reg_aux_src_0_ = r9;
137 const Reg64 reg_aux_src_1_ = r10;
138 const Reg64 reg_aux_src_2_ = r11;
139 const Reg64 reg_tmp1_ = r15;
140
141 // Registers which are used only for linear algorithm
142 // and for channel oriented formats.
143 // Meaning of shortcuts:
144 // f - front, b - back
145 // t - top, b - bottom
146 // l - left, r - right
147 // Example:
148 // src_ftl_ - source tensor data for the front top left corner
149 // reg_src_ftl_ - register which contains address of source
150 // tensor data for the front top left corner
151 const Vmm weight_left_ = Vmm(1);
152 const Vmm weight_right_ = Vmm(2);
153 const Vmm weight_top_ = Vmm(3);
154 const Vmm weight_bottom_ = Vmm(4);
155 const Vmm weight_front_ = Vmm(5);
156 const Vmm weight_back_ = Vmm(6);
157 const Vmm src_ftl_ = Vmm(vmm_idx(0));
158 const Vmm src_ftr_ = Vmm(vmm_idx(1));
159 const Vmm src_fbl_ = Vmm(vmm_idx(2));
160 const Vmm src_fbr_ = Vmm(vmm_idx(3));
161 const Vmm src_btl_ = Vmm(vmm_idx(4));
162 const Vmm src_btr_ = Vmm(vmm_idx(5));
163 const Vmm src_bbl_ = Vmm(vmm_idx(6));
164 const Vmm src_bbr_ = Vmm(vmm_idx(7));
165
166 const Reg64 reg_src_ftl_ = reg_src_;
167 const Reg64 reg_src_ftr_ = reg_aux_src_0_;
168 const Reg64 reg_src_fbl_ = reg_aux_src_1_;
169 const Reg64 reg_src_fbr_ = reg_aux_src_2_;
170 const Reg64 reg_src_btl_ = r12;
171 const Reg64 reg_src_btr_ = r13;
172 const Reg64 reg_src_bbl_ = r14;
173 const Reg64 reg_src_bbr_ = r15;
174
175 static constexpr bool is_zmm_ = std::is_same<Vmm, Xbyak::Zmm>::value;
176 static constexpr bool is_ymm_ = std::is_same<Vmm, Xbyak::Ymm>::value;
177 static constexpr bool is_xmm_ = std::is_same<Vmm, Xbyak::Xmm>::value;
178 static constexpr std::size_t vlen_ = is_zmm_ ? 64 : is_ymm_ ? 32 : 16;
179 static constexpr std::size_t simd_w_ = vlen_ / sizeof(float);
180 const std::size_t tail_size_;
181
182 bool any_binary_postop_is_per_oc_bcast_type_ = false;
183 bool any_binary_postop_is_per_oc_sp_bcast_type_ = false;
184
185 io::jit_io_multi_dt_helper_t<Vmm> io_;
186 std::unique_ptr<injector::jit_uni_postops_injector_t<isa, Vmm>>
187 postops_injector_;
188};
189} // namespace x64
190} // namespace cpu
191} // namespace impl
192} // namespace dnnl
193
194#endif
195