1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_UNI_RESAMPLING_HPP
18#define CPU_X64_UNI_RESAMPLING_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22
23#include "cpu/cpu_resampling_pd.hpp"
24
25#include "cpu/x64/cpu_isa_traits.hpp"
26#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
27#include "cpu/x64/jit_primitive_conf.hpp"
28#include "cpu/x64/jit_uni_resampling_kernel.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace cpu {
33namespace x64 {
34
35struct jit_uni_resampling_fwd_t : public primitive_t {
36 struct pd_t : public cpu_resampling_fwd_pd_t {
37 using cpu_resampling_fwd_pd_t::cpu_resampling_fwd_pd_t;
38
39 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", conf_.isa, ""),
40 jit_uni_resampling_fwd_t);
41
42 status_t init(engine_t *engine);
43
44 const jit_resampling_conf_t &get_conf() const { return conf_; }
45
46 private:
47 void fill_format_tag_info();
48
49 jit_resampling_conf_t conf_;
50 };
51
52 jit_uni_resampling_fwd_t(const pd_t *apd) : primitive_t(apd) {}
53 virtual ~jit_uni_resampling_fwd_t() = default;
54
55 status_t init(engine_t *engine) override;
56 status_t execute(const exec_ctx_t &ctx) const override;
57
58private:
59 status_t fill_data_for_interpolation();
60 /*
61 * Fills indices_ with the data that contains the corresponding
62 * input point for each output point.
63 * The data is arranged as follows:
64 * od_0 = id_0 * stride_w
65 * od_1 = id_1 * stride_w
66 * od_2 = id_2 * stride_w
67 * ...
68 * ih_0 = ih_0 * stride_h
69 * ih_1 = ih_1 * stride_h
70 * ...
71 * iw_0 = iw_1 * stride_w
72 * ...
73 */
74 status_t fill_data_for_nearest();
75 /*
76 * Fills indices_ with the data that contains the corresponding
77 * input point for each output point.
78 * The data is arranged as follows:
79 * od_0 = id_0
80 * od_1 = id_1
81 * od_2 = id_2
82 * ...
83 * oh_0 = ih_0
84 * oh_1 = ih_1
85 * ...
86 * ow_0 = iw_0
87 * ...
88 */
89 status_t fill_data_for_linear();
90 /*
91 * Fills indices_ with the data that contains the corresponding
92 * corners from input tensor for each output point and fills
93 * weights_ with with the data that contains weights for
94 * corners from input tensor for each output point.
95 * The data is arranged as follows:
96 * NSPC and BLOCKED:
97 *
98 * indices_:
99 * ow_0 = iw_0_left
100 * ow_0 = iw_0_right
101 * ow_1 = iw_1_left
102 * ow_1 = iw_1_right
103 * ...
104 * oh_0 = ih_0_top
105 * oh_1 = ih_1_top
106 * ...
107 * oh_0 = ih_0_bottom
108 * oh_1 = ih_1_bottom
109 * ...
110 * od_0 = id_0_front
111 * od_1 = id_1_front
112 * ...
113 * od_0 = id_0_back
114 * od_1 = id_1_back
115 * ...
116 *
117 * weights_:
118 * ow_0 = weight_0_left
119 * ow_0 = weight_0_right
120 * ow_1 = weight_1_left
121 * ow_1 = weight_1_right
122 * ...
123 * oh_0 = weight_0_top
124 * oh_1 = weight_1_top
125 * ...
126 * oh_0 = weight_0_bottom
127 * oh_1 = weight_1_bottom
128 * ...
129 * od_0 = weight_0_front
130 * od_1 = weight_1_front
131 * ...
132 * od_0 = weight_0_back
133 * od_1 = weight_1_back
134 * ...
135 *
136 * NCSP:
137 *
138 * indices_:
139 * sp_0 = id_0_front + ih_0_top + iw_0_left
140 * sp_0 = id_0_front + ih_0_top + iw_0_right
141 * sp_0 = id_0_front + ih_0_bottom + iw_0_left
142 * sp_0 = id_0_front + ih_0_bottom + iw_0_right
143 * sp_0 = id_0_back + ih_0_top + iw_0_left
144 * sp_0 = id_0_back + ih_0_top + iw_0_right
145 * sp_0 = id_0_back + ih_0_bottom + iw_0_left
146 * sp_0 = id_0_back + ih_0_bottom + iw_0_right
147 * sp_1 = id_1_front + ih_1_top + iw_1_left
148 * sp_1 = id_1_front + ih_1_top + iw_1_right
149 * sp_1 = id_1_front + ih_1_bottom + iw_1_left
150 * ...
151 *
152 * weights_:
153 * sp_0 = weight_0_front * weight_0_top * weight_0_left
154 * sp_0 = weight_0_front * weight_0_top * weight_0_right
155 * sp_0 = weight_0_front * weight_0_bottom * weight_0_left
156 * sp_0 = weight_0_front * weight_0_bottom * weight_0_right
157 * sp_0 = weight_0_back * weight_0_top * weight_0_left
158 * sp_0 = weight_0_back * weight_0_top * weight_0_right
159 * sp_0 = weight_0_back * weight_0_bottom * weight_0_left
160 * sp_0 = weight_0_back * weight_0_bottom * weight_0_right
161 * sp_1 = weight_1_front * weight_1_top * weight_1_left
162 * sp_1 = weight_1_front * weight_1_top * weight_1_right
163 * sp_1 = weight_1_front * weight_1_bottom * weight_1_left
164 * ...
165 */
166
167 status_t interpolate_nearest(const uint8_t *src, uint8_t *dst,
168 const std::vector<const void *> &post_ops_args) const;
169 status_t interpolate_linear(const uint8_t *src, uint8_t *dst,
170 const std::vector<const void *> &post_ops_args) const;
171
172 status_t get_proper_kernel_for_avx512(
173 const memory_desc_t *dst_md, const jit_resampling_conf_t &conf);
174 status_t get_proper_kernel_for_avx(
175 const memory_desc_t *dst_md, const jit_resampling_conf_t &conf);
176 status_t get_proper_kernel_for_sse(
177 const memory_desc_t *dst_md, const jit_resampling_conf_t &conf);
178
179 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
180
181 std::unique_ptr<jit_uni_resampling_kernel_base_t> kernel_;
182
183 std::vector<unsigned> indices_;
184 std::vector<float> weights_;
185};
186
187} // namespace x64
188} // namespace cpu
189} // namespace impl
190} // namespace dnnl
191
192#endif
193