1 | /******************************************************************************* |
2 | * Copyright 2020-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_UNI_RESAMPLING_HPP |
18 | #define CPU_X64_UNI_RESAMPLING_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/primitive.hpp" |
22 | |
23 | #include "cpu/cpu_resampling_pd.hpp" |
24 | |
25 | #include "cpu/x64/cpu_isa_traits.hpp" |
26 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
27 | #include "cpu/x64/jit_primitive_conf.hpp" |
28 | #include "cpu/x64/jit_uni_resampling_kernel.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | namespace x64 { |
34 | |
35 | struct jit_uni_resampling_fwd_t : public primitive_t { |
36 | struct pd_t : public cpu_resampling_fwd_pd_t { |
37 | using cpu_resampling_fwd_pd_t::cpu_resampling_fwd_pd_t; |
38 | |
39 | DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:" , conf_.isa, "" ), |
40 | jit_uni_resampling_fwd_t); |
41 | |
42 | status_t init(engine_t *engine); |
43 | |
44 | const jit_resampling_conf_t &get_conf() const { return conf_; } |
45 | |
46 | private: |
47 | void fill_format_tag_info(); |
48 | |
49 | jit_resampling_conf_t conf_; |
50 | }; |
51 | |
52 | jit_uni_resampling_fwd_t(const pd_t *apd) : primitive_t(apd) {} |
53 | virtual ~jit_uni_resampling_fwd_t() = default; |
54 | |
55 | status_t init(engine_t *engine) override; |
56 | status_t execute(const exec_ctx_t &ctx) const override; |
57 | |
58 | private: |
59 | status_t fill_data_for_interpolation(); |
60 | /* |
61 | * Fills indices_ with the data that contains the corresponding |
62 | * input point for each output point. |
63 | * The data is arranged as follows: |
64 | * od_0 = id_0 * stride_w |
65 | * od_1 = id_1 * stride_w |
66 | * od_2 = id_2 * stride_w |
67 | * ... |
68 | * ih_0 = ih_0 * stride_h |
69 | * ih_1 = ih_1 * stride_h |
70 | * ... |
71 | * iw_0 = iw_1 * stride_w |
72 | * ... |
73 | */ |
74 | status_t fill_data_for_nearest(); |
75 | /* |
76 | * Fills indices_ with the data that contains the corresponding |
77 | * input point for each output point. |
78 | * The data is arranged as follows: |
79 | * od_0 = id_0 |
80 | * od_1 = id_1 |
81 | * od_2 = id_2 |
82 | * ... |
83 | * oh_0 = ih_0 |
84 | * oh_1 = ih_1 |
85 | * ... |
86 | * ow_0 = iw_0 |
87 | * ... |
88 | */ |
89 | status_t fill_data_for_linear(); |
90 | /* |
91 | * Fills indices_ with the data that contains the corresponding |
92 | * corners from input tensor for each output point and fills |
93 | * weights_ with with the data that contains weights for |
94 | * corners from input tensor for each output point. |
95 | * The data is arranged as follows: |
96 | * NSPC and BLOCKED: |
97 | * |
98 | * indices_: |
99 | * ow_0 = iw_0_left |
100 | * ow_0 = iw_0_right |
101 | * ow_1 = iw_1_left |
102 | * ow_1 = iw_1_right |
103 | * ... |
104 | * oh_0 = ih_0_top |
105 | * oh_1 = ih_1_top |
106 | * ... |
107 | * oh_0 = ih_0_bottom |
108 | * oh_1 = ih_1_bottom |
109 | * ... |
110 | * od_0 = id_0_front |
111 | * od_1 = id_1_front |
112 | * ... |
113 | * od_0 = id_0_back |
114 | * od_1 = id_1_back |
115 | * ... |
116 | * |
117 | * weights_: |
118 | * ow_0 = weight_0_left |
119 | * ow_0 = weight_0_right |
120 | * ow_1 = weight_1_left |
121 | * ow_1 = weight_1_right |
122 | * ... |
123 | * oh_0 = weight_0_top |
124 | * oh_1 = weight_1_top |
125 | * ... |
126 | * oh_0 = weight_0_bottom |
127 | * oh_1 = weight_1_bottom |
128 | * ... |
129 | * od_0 = weight_0_front |
130 | * od_1 = weight_1_front |
131 | * ... |
132 | * od_0 = weight_0_back |
133 | * od_1 = weight_1_back |
134 | * ... |
135 | * |
136 | * NCSP: |
137 | * |
138 | * indices_: |
139 | * sp_0 = id_0_front + ih_0_top + iw_0_left |
140 | * sp_0 = id_0_front + ih_0_top + iw_0_right |
141 | * sp_0 = id_0_front + ih_0_bottom + iw_0_left |
142 | * sp_0 = id_0_front + ih_0_bottom + iw_0_right |
143 | * sp_0 = id_0_back + ih_0_top + iw_0_left |
144 | * sp_0 = id_0_back + ih_0_top + iw_0_right |
145 | * sp_0 = id_0_back + ih_0_bottom + iw_0_left |
146 | * sp_0 = id_0_back + ih_0_bottom + iw_0_right |
147 | * sp_1 = id_1_front + ih_1_top + iw_1_left |
148 | * sp_1 = id_1_front + ih_1_top + iw_1_right |
149 | * sp_1 = id_1_front + ih_1_bottom + iw_1_left |
150 | * ... |
151 | * |
152 | * weights_: |
153 | * sp_0 = weight_0_front * weight_0_top * weight_0_left |
154 | * sp_0 = weight_0_front * weight_0_top * weight_0_right |
155 | * sp_0 = weight_0_front * weight_0_bottom * weight_0_left |
156 | * sp_0 = weight_0_front * weight_0_bottom * weight_0_right |
157 | * sp_0 = weight_0_back * weight_0_top * weight_0_left |
158 | * sp_0 = weight_0_back * weight_0_top * weight_0_right |
159 | * sp_0 = weight_0_back * weight_0_bottom * weight_0_left |
160 | * sp_0 = weight_0_back * weight_0_bottom * weight_0_right |
161 | * sp_1 = weight_1_front * weight_1_top * weight_1_left |
162 | * sp_1 = weight_1_front * weight_1_top * weight_1_right |
163 | * sp_1 = weight_1_front * weight_1_bottom * weight_1_left |
164 | * ... |
165 | */ |
166 | |
167 | status_t interpolate_nearest(const uint8_t *src, uint8_t *dst, |
168 | const std::vector<const void *> &post_ops_args) const; |
169 | status_t interpolate_linear(const uint8_t *src, uint8_t *dst, |
170 | const std::vector<const void *> &post_ops_args) const; |
171 | |
172 | status_t get_proper_kernel_for_avx512( |
173 | const memory_desc_t *dst_md, const jit_resampling_conf_t &conf); |
174 | status_t get_proper_kernel_for_avx( |
175 | const memory_desc_t *dst_md, const jit_resampling_conf_t &conf); |
176 | status_t get_proper_kernel_for_sse( |
177 | const memory_desc_t *dst_md, const jit_resampling_conf_t &conf); |
178 | |
179 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
180 | |
181 | std::unique_ptr<jit_uni_resampling_kernel_base_t> kernel_; |
182 | |
183 | std::vector<unsigned> indices_; |
184 | std::vector<float> weights_; |
185 | }; |
186 | |
187 | } // namespace x64 |
188 | } // namespace cpu |
189 | } // namespace impl |
190 | } // namespace dnnl |
191 | |
192 | #endif |
193 | |