1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_REORDER_HPP |
18 | #define CPU_X64_JIT_UNI_REORDER_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | |
25 | #include "cpu/reorder/cpu_reorder_pd.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | namespace tr { |
33 | |
// Upper bound on the number of nodes a reorder problem may carry
// (mirrors the library-wide DNNL_MAX_NDIMS limit).
constexpr int max_ndims = DNNL_MAX_NDIMS;
35 | |
// One dimension ("node") of a reorder problem: its size, tail information,
// split/parent bookkeeping, and the strides by which this dimension advances
// in each buffer involved in the reorder.
struct node_t {
    // Sentinel meaning "not assigned" for dim_id / parent_node_id.
    static constexpr int64_t empty_field = -1;

    size_t n = 0; // full size of this dimension
    size_t tail_size = 0; // non-zero when the dimension has a tail
    int dim_id = empty_field; // logical dimension this node belongs to
    int parent_node_id = empty_field; // node this one was split from
    bool is_zero_pad_needed = false;
    ptrdiff_t is = 0; // input stride
    ptrdiff_t os = 0; // output stride
    ptrdiff_t ss = 0; // scale stride
    ptrdiff_t cs = 0; // compensation stride

    // True while no logical dimension has been assigned to this node.
    bool is_dim_id_empty() const { return empty_field == dim_id; }
    // True while this node has not been linked to a parent node.
    bool is_parent_empty() const { return empty_field == parent_node_id; }
};
52 | |
// How scales are applied: not at all, one common scale, or per-element (many).
enum class scale_type_t { NONE, COMMON, MANY };
54 | |
55 | struct prb_t { |
56 | /* The compensation mask value indicates how big an additional buffer should be. |
57 | * Possible values for reorder: |
58 | * 1) standard compensation = 1 = 0b01 |
59 | * 2) asymmetric compensation = 2 = 0b10 |
60 | * 3) compensation if tensor contains group = 3 = 0b11 */ |
61 | static constexpr int invalid_comp_mask = 0; |
62 | static constexpr int standard_comp_mask = 0b1; |
63 | static constexpr int asymmetric_comp_mask = 0b10; |
64 | static constexpr int comp_mask_with_groups |
65 | = standard_comp_mask + asymmetric_comp_mask; |
66 | |
67 | bool is_tail_in_one_of_child_nodes(int parent_node_id) const { |
68 | for (int i = parent_node_id; i >= 0; i--) { |
69 | if (nodes[i].parent_node_id == parent_node_id) { |
70 | if (nodes[i].tail_size != 0) |
71 | return true; |
72 | else |
73 | parent_node_id = i; |
74 | } |
75 | } |
76 | |
77 | return false; |
78 | } |
79 | |
80 | int tail(int d) const { |
81 | assert(d < ndims); |
82 | return static_cast<int>(nodes[d].tail_size); |
83 | } |
84 | |
85 | int n(int d) const { |
86 | assert(d < ndims); |
87 | return static_cast<int>(nodes[d].n); |
88 | } |
89 | int is(int d) const { |
90 | assert(d < ndims); |
91 | return static_cast<int>(nodes[d].is); |
92 | } |
93 | int os(int d) const { |
94 | assert(d < ndims); |
95 | return static_cast<int>(nodes[d].os); |
96 | } |
97 | int ss(int d) const { |
98 | assert(d < ndims); |
99 | return static_cast<int>(nodes[d].ss); |
100 | } |
101 | |
102 | int cs(int d) const { |
103 | assert(d < ndims); |
104 | return static_cast<int>(nodes[d].cs); |
105 | } |
106 | |
107 | data_type_t itype; |
108 | data_type_t otype; |
109 | int ndims; |
110 | node_t nodes[max_ndims]; |
111 | ptrdiff_t ioff; |
112 | ptrdiff_t ooff; |
113 | scale_type_t src_scale_type; |
114 | scale_type_t dst_scale_type; |
115 | float beta; |
116 | int full_ndims; |
117 | bool is_tail_present = false; |
118 | float scale_adjust = 1.f; |
119 | int compensation_mask = invalid_comp_mask; |
120 | bool req_s8s8_comp = false; |
121 | bool req_asymmetric_comp = false; |
122 | bool req_src_zp = false; |
123 | bool req_dst_zp = false; |
124 | }; |
125 | |
/** initializes the problem prb from the input/output memory descriptors
 * and the primitive attributes */
status_t prb_init(prb_t &prb, const memory_desc_t &imd,
        const memory_desc_t &omd, const primitive_attr_t *attr);

/** sorts the problem nodes so that output strides come in ascending order */
void prb_normalize(prb_t &p);

/** fill parent node info for blocked nodes */
void prb_node_dependency(prb_t &p);

/** folds nodes together if possible */
void prb_simplify(prb_t &p);

/** splits the node dim into two of sizes n1 and n / n1
 * @warning n must be multiple of n1 */
void prb_node_split(prb_t &p, int dim, size_t n1);

/** swaps d0 and d1 nodes */
void prb_node_swap(prb_t &p, int d0, int d1);

/** moves node d0 to the d1 position.
 * nodes (d0, d1] are shifted to the left if d0 < d1 or
 * to the right if d0 > d1 */
void prb_node_move(prb_t &p, int d0, int d1);

/** dumps the problem to stdout */
void prb_dump(const prb_t &p);
152 | |
// Arguments passed to the jit kernel for the regular (no-tail) path.
struct call_param_t {
    const void *in = nullptr; // source buffer
    void *out = nullptr; // destination buffer
    const float *src_scales = nullptr;
    const float *dst_scales = nullptr;
    int32_t src_zp = 0; // source zero point
    int32_t dst_zp = 0; // destination zero point
    int32_t *compensation_scratch = nullptr; // s8s8/asymmetric comp. scratch
};
162 | |
// The additional structure is needed because
// using a data structure with tail processing
// data for non-tail cases reduces kernel
// performance. This is because there is too
// much data that has to be transferred to the kernel.
struct tail_call_param_t {
    call_param_t base_params;
    // Current data chunk index per dimension.
    // NOTE(review): "= {-1}" sets only element 0 to -1; the remaining
    // DNNL_MAX_NDIMS - 1 elements are value-initialized to 0. This looks
    // intentional since the driver fills the array per call -- confirm.
    int64_t curr_data_chunks[DNNL_MAX_NDIMS] = {-1};
    // Flags stored as int64_t rather than bool -- presumably so the jit
    // kernel can read them as full 64-bit words; verify against the kernel.
    int64_t zeroing_data = static_cast<int64_t>(false);
    int64_t skip_kernel_execution = static_cast<int64_t>(false);
};
174 | |
// Abstract base for generated reorder kernels: concrete implementations
// provide the call operators and create_kernel().
struct kernel_t {
    struct desc_t {
        int id; // kernel implementation id
        prb_t prb; // (sub)problem this kernel handles
    };

    kernel_t(const desc_t &desc)
        : desc_(desc)
        , compensation_needed_(
                desc.prb.req_s8s8_comp || desc.prb.req_asymmetric_comp) {}
    virtual void operator()(const call_param_t *c) const = 0;
    virtual void operator()(const tail_call_param_t *c) const = 0;
    virtual status_t create_kernel() = 0;
    virtual ~kernel_t() {}

    /** inits kernel descriptor:
     *      desc -- kernel descriptor (output)
     *      prb -- transposition problem (input)
     *      ndims_ker_max -- limit the maximum number of dimensions kernel
     *                       will process (optional, 0 -- no limitation) */
    static status_t desc_init(
            desc_t &desc, const prb_t &prb, int ndims_ker_max = 0);

    /** creates kernel for the problem described in desc */
    static kernel_t *create(const desc_t &desc);

protected:
    const desc_t desc_;
    // NOTE(review): prb_ is a reference into this object's own desc_; the
    // implicitly-generated copy constructor would leave a copy's prb_ bound
    // to the *source* object's desc_. Safe as long as kernels are never
    // copied -- confirm, or consider deleting copy operations.
    const prb_t &prb_ = desc_.prb;
    bool compensation_needed_ = false;
};
206 | |
207 | /* TODO: add trans_t class */ |
208 | |
209 | struct jit_single_blk_kernel_t; |
210 | |
211 | } // namespace tr |
212 | |
213 | struct jit_uni_reorder_t : public primitive_t { |
214 | using primitive_t::primitive_t; |
215 | struct pd_t : public cpu_reorder_pd_t { |
216 | using cpu_reorder_pd_t::cpu_reorder_pd_t; |
217 | |
218 | DECLARE_COMMON_PD_T("jit:uni" , jit_uni_reorder_t); |
219 | |
220 | tr::prb_t prb_; |
221 | tr::kernel_t::desc_t ker_desc_; |
222 | int nthr_; |
223 | bool with_groups_ = false; |
224 | dim_t D_mask_ = 0; |
225 | |
226 | status_t init( |
227 | engine_t *engine, engine_t *src_engine, engine_t *dst_engine); |
228 | |
229 | private: |
230 | status_t init_scratchpad(); |
231 | static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, |
232 | const primitive_attr_t *attr, engine_t *src_engine, |
233 | const memory_desc_t *src_md, engine_t *dst_engine, |
234 | const memory_desc_t *dst_md); |
235 | |
236 | friend dnnl::impl::impl_list_item_t; |
237 | }; |
238 | |
239 | status_t init(engine_t *engine) override; |
240 | status_t execute(const exec_ctx_t &ctx) const override; |
241 | |
242 | enum { ndims_driver_max = 4 }; |
243 | |
244 | private: |
245 | void omp_driver_0d(int off, const char *in, char *out, |
246 | const float *src_scales, const float *dst_scales, int src_zp, |
247 | int dst_zp, int32_t *compensation_scratch) const; |
248 | void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, |
249 | const float *src_scales, const float *dst_scales, int src_zp, |
250 | int dst_zp, int32_t *compensation_scratch) const; |
251 | void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, |
252 | const float *src_scales, const float *dst_scales, int src_zp, |
253 | int dst_zp, int32_t *compensation_scratch) const; |
254 | void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, |
255 | const float *src_scales, const float *dst_scales, int src_zp, |
256 | int dst_zp, int32_t *compensation_scratch) const; |
257 | void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, |
258 | const float *src_scales, const float *dst_scales, int src_zp, |
259 | int dst_zp, int32_t *compensation_scratch) const; |
260 | |
261 | void omp_driver(const char *in, char *out, const float *src_scales, |
262 | const float *dst_scales, int src_zp, int dst_zp, |
263 | const memory_tracking::grantor_t &scratchpad) const; |
264 | |
265 | void fill_curr_data_chunks(const tr::prb_t &prb, const int off, |
266 | const ptrdiff_t *omp_data_chunks, const int omp_ndims, |
267 | tr::tail_call_param_t &c) const; |
268 | |
269 | void reduce_compensation(char *out, |
270 | const int32_t *compensation_reduce_scratch, const int nthr, |
271 | const dim_t wspace_per_thr_size) const; |
272 | |
273 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
274 | std::unique_ptr<tr::kernel_t> kernel_; |
275 | }; |
276 | |
277 | struct jit_blk_reorder_t : public primitive_t { |
278 | using primitive_t::primitive_t; |
279 | struct pd_t : public cpu_reorder_pd_t { |
280 | using cpu_reorder_pd_t::cpu_reorder_pd_t; |
281 | DECLARE_COMMON_PD_T("jit:blk" , jit_blk_reorder_t); |
282 | |
283 | tr::prb_t prb_; |
284 | |
285 | private: |
286 | static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, |
287 | const primitive_attr_t *attr, engine_t *src_engine, |
288 | const memory_desc_t *src_md, engine_t *dst_engine, |
289 | const memory_desc_t *dst_md); |
290 | |
291 | // Swap last two nodes, put block 4, 8, 16 nodes to first |
292 | static void prb_tile_normalize(tr::prb_t &p); |
293 | friend dnnl::impl::impl_list_item_t; |
294 | }; |
295 | |
296 | status_t init(engine_t *engine) override; |
297 | status_t execute(const exec_ctx_t &ctx) const override; |
298 | |
299 | jit_blk_reorder_t(const pd_t *apd); |
300 | ~jit_blk_reorder_t(); |
301 | |
302 | private: |
303 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
304 | std::unique_ptr<tr::jit_single_blk_kernel_t> kernel_; |
305 | }; |
306 | |
307 | } // namespace x64 |
308 | } // namespace cpu |
309 | } // namespace impl |
310 | } // namespace dnnl |
311 | |
312 | #endif |
313 | |