1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_REORDER_HPP
18#define CPU_X64_JIT_UNI_REORDER_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/type_helpers.hpp"
24
25#include "cpu/reorder/cpu_reorder_pd.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32namespace tr {
33
// Maximum number of dimensions a transposition problem can carry; mirrors the
// library-wide limit so that prb_t::nodes can be a fixed-size array.
constexpr int max_ndims = DNNL_MAX_NDIMS;
35
/** Describes one (possibly blocked) dimension of a transposition problem:
 * its size, tail information, and the stride along this dimension in every
 * tensor the kernel walks (input, output, scales, compensation).
 * NOTE(review): the member layout is presumably consumed by generated code /
 * tight loops elsewhere — do not reorder fields without checking users. */
struct node_t {
    // Sentinel meaning "not set" for dim_id / parent_node_id.
    static constexpr int64_t empty_field = -1;

    size_t n = 0; // number of elements along this node
    size_t tail_size = 0; // non-zero when this node carries a tail chunk
    int dim_id = empty_field; // logical dimension this node was split from
    int parent_node_id = empty_field; // id of the parent (blocked) node
    bool is_zero_pad_needed = false; // whether zero padding must be written
    ptrdiff_t is = 0; // input stride
    ptrdiff_t os = 0; // output stride
    ptrdiff_t ss = 0; // scale stride
    ptrdiff_t cs = 0; // compensation stride

    bool is_dim_id_empty() const { return dim_id == empty_field; }
    bool is_parent_empty() const { return parent_node_id == empty_field; }
};
52
// How scales are applied: NONE - no scaling, COMMON - a single scale for the
// whole tensor, MANY - multiple scales (applied via the scale stride `ss`).
enum class scale_type_t { NONE, COMMON, MANY };
54
/** Full description of a transposition (reorder) problem: data types, the
 * per-dimension nodes, start offsets, and the scale / zero-point /
 * compensation requirements the kernel has to honor. */
struct prb_t {
    /* The compensation mask value indicates how big an additional buffer should be.
     * Possible values for reorder:
     * 1) standard compensation = 1 = 0b01
     * 2) asymmetric compensation = 2 = 0b10
     * 3) compensation if tensor contains group = 3 = 0b11 */
    static constexpr int invalid_comp_mask = 0;
    static constexpr int standard_comp_mask = 0b1;
    static constexpr int asymmetric_comp_mask = 0b10;
    static constexpr int comp_mask_with_groups
            = standard_comp_mask + asymmetric_comp_mask;

    /** Returns true if any (transitive) child of the node with id
     * `parent_node_id` has a non-zero tail. Child nodes are assumed to have
     * smaller ids than their parent (the scan only moves downwards); when a
     * tail-less child is found, the search continues with that child's own
     * children by re-seating `parent_node_id`. */
    bool is_tail_in_one_of_child_nodes(int parent_node_id) const {
        for (int i = parent_node_id; i >= 0; i--) {
            if (nodes[i].parent_node_id == parent_node_id) {
                if (nodes[i].tail_size != 0)
                    return true;
                else
                    parent_node_id = i;
            }
        }

        return false;
    }

    // Convenience accessors for node `d`. They narrow the stored 64-bit
    // values to int — callers rely on the values fitting (asserted range
    // checks cover only the upper bound on `d`).
    int tail(int d) const {
        assert(d < ndims);
        return static_cast<int>(nodes[d].tail_size);
    }

    int n(int d) const {
        assert(d < ndims);
        return static_cast<int>(nodes[d].n);
    }
    int is(int d) const {
        assert(d < ndims);
        return static_cast<int>(nodes[d].is);
    }
    int os(int d) const {
        assert(d < ndims);
        return static_cast<int>(nodes[d].os);
    }
    int ss(int d) const {
        assert(d < ndims);
        return static_cast<int>(nodes[d].ss);
    }

    int cs(int d) const {
        assert(d < ndims);
        return static_cast<int>(nodes[d].cs);
    }

    data_type_t itype; // input data type
    data_type_t otype; // output data type
    int ndims; // number of active entries in `nodes`
    node_t nodes[max_ndims]; // output strides ascend after prb_normalize()
    ptrdiff_t ioff; // input start offset
    ptrdiff_t ooff; // output start offset
    scale_type_t src_scale_type;
    scale_type_t dst_scale_type;
    float beta; // blend factor applied to existing output values
    int full_ndims; // ndims before any folding/splitting — TODO confirm
    bool is_tail_present = false;
    float scale_adjust = 1.f; // extra factor folded into the scales
    int compensation_mask = invalid_comp_mask; // see mask constants above
    bool req_s8s8_comp = false; // s8s8 (standard) compensation required
    bool req_asymmetric_comp = false; // asymmetric compensation required
    bool req_src_zp = false; // source zero-point handling required
    bool req_dst_zp = false; // destination zero-point handling required
};
125
/** initializes the problem from the input/output memory descriptors and the
 * primitive attributes; returns a non-success status when the descriptors
 * cannot be expressed as a transposition problem */
status_t prb_init(prb_t &prb, const memory_desc_t &imd,
        const memory_desc_t &omd, const primitive_attr_t *attr);

/** sorts the problem nodes so that output strides come in ascending order */
void prb_normalize(prb_t &p);

/** fill parent node info for blocked nodes */
void prb_node_dependency(prb_t &p);

/** folds nodes together if possible */
void prb_simplify(prb_t &p);

/** splits the node dim into two of sizes n1 and n / n1
 * @warning n must be multiple of n1 */
void prb_node_split(prb_t &p, int dim, size_t n1);

/** swaps d0 and d1 nodes */
void prb_node_swap(prb_t &p, int d0, int d1);

/** moves node d0 to the d1 position.
 * nodes (d0, d1] are shifted to the left if d0 < d1 or
 * to the right if d0 > d1 */
void prb_node_move(prb_t &p, int d0, int d1);

/** dumps the problem to stdout */
void prb_dump(const prb_t &p);
152
/** Runtime arguments passed to a reorder kernel for the regular (no-tail)
 * case. NOTE(review): the generated kernels presumably access these fields
 * by byte offset — keep the member layout stable. */
struct call_param_t {
    const void *in = nullptr; // source buffer
    void *out = nullptr; // destination buffer
    const float *src_scales = nullptr; // per-problem source scales
    const float *dst_scales = nullptr; // per-problem destination scales
    int32_t src_zp = 0; // source zero point
    int32_t dst_zp = 0; // destination zero point
    int32_t *compensation_scratch = nullptr; // scratch for compensation sums
};
162
// A separate structure is needed because using a single structure that also
// carries tail-processing data for the non-tail case reduces kernel
// performance: too much data would have to be transferred to the kernel.
struct tail_call_param_t {
    call_param_t base_params; // the regular arguments, see above
    // Position of the currently processed chunk per dimension.
    // NOTE(review): `{-1}` sets only element 0 to -1, the rest to 0; the
    // array is presumably filled by fill_curr_data_chunks() before each
    // kernel call — confirm against the driver.
    int64_t curr_data_chunks[DNNL_MAX_NDIMS] = {-1};
    // int64_t-typed booleans so the kernel can load them uniformly.
    int64_t zeroing_data = static_cast<int64_t>(false);
    int64_t skip_kernel_execution = static_cast<int64_t>(false);
};
174
175struct kernel_t {
176 struct desc_t {
177 int id;
178 prb_t prb;
179 };
180
181 kernel_t(const desc_t &desc)
182 : desc_(desc)
183 , compensation_needed_(
184 desc.prb.req_s8s8_comp || desc.prb.req_asymmetric_comp) {}
185 virtual void operator()(const call_param_t *c) const = 0;
186 virtual void operator()(const tail_call_param_t *c) const = 0;
187 virtual status_t create_kernel() = 0;
188 virtual ~kernel_t() {}
189
190 /** inits kernel descriptor:
191 * desc -- kernel descriptor (output)
192 * prb -- transposition problem (input)
193 * ndims_ker_max -- limit the maximum number of dimensions kernel
194 * will process (optional, 0 -- no limitation) */
195 static status_t desc_init(
196 desc_t &desc, const prb_t &prb, int ndims_ker_max = 0);
197
198 /** creates kernel for the problem described in desc */
199 static kernel_t *create(const desc_t &desc);
200
201protected:
202 const desc_t desc_;
203 const prb_t &prb_ = desc_.prb;
204 bool compensation_needed_ = false;
205};
206
/* TODO: add trans_t class */

// Forward declaration; the definition lives in the implementation files.
// Used below by jit_blk_reorder_t.
struct jit_single_blk_kernel_t;
210
211} // namespace tr
212
213struct jit_uni_reorder_t : public primitive_t {
214 using primitive_t::primitive_t;
215 struct pd_t : public cpu_reorder_pd_t {
216 using cpu_reorder_pd_t::cpu_reorder_pd_t;
217
218 DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t);
219
220 tr::prb_t prb_;
221 tr::kernel_t::desc_t ker_desc_;
222 int nthr_;
223 bool with_groups_ = false;
224 dim_t D_mask_ = 0;
225
226 status_t init(
227 engine_t *engine, engine_t *src_engine, engine_t *dst_engine);
228
229 private:
230 status_t init_scratchpad();
231 static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
232 const primitive_attr_t *attr, engine_t *src_engine,
233 const memory_desc_t *src_md, engine_t *dst_engine,
234 const memory_desc_t *dst_md);
235
236 friend dnnl::impl::impl_list_item_t;
237 };
238
239 status_t init(engine_t *engine) override;
240 status_t execute(const exec_ctx_t &ctx) const override;
241
242 enum { ndims_driver_max = 4 };
243
244private:
245 void omp_driver_0d(int off, const char *in, char *out,
246 const float *src_scales, const float *dst_scales, int src_zp,
247 int dst_zp, int32_t *compensation_scratch) const;
248 void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out,
249 const float *src_scales, const float *dst_scales, int src_zp,
250 int dst_zp, int32_t *compensation_scratch) const;
251 void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out,
252 const float *src_scales, const float *dst_scales, int src_zp,
253 int dst_zp, int32_t *compensation_scratch) const;
254 void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out,
255 const float *src_scales, const float *dst_scales, int src_zp,
256 int dst_zp, int32_t *compensation_scratch) const;
257 void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out,
258 const float *src_scales, const float *dst_scales, int src_zp,
259 int dst_zp, int32_t *compensation_scratch) const;
260
261 void omp_driver(const char *in, char *out, const float *src_scales,
262 const float *dst_scales, int src_zp, int dst_zp,
263 const memory_tracking::grantor_t &scratchpad) const;
264
265 void fill_curr_data_chunks(const tr::prb_t &prb, const int off,
266 const ptrdiff_t *omp_data_chunks, const int omp_ndims,
267 tr::tail_call_param_t &c) const;
268
269 void reduce_compensation(char *out,
270 const int32_t *compensation_reduce_scratch, const int nthr,
271 const dim_t wspace_per_thr_size) const;
272
273 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
274 std::unique_ptr<tr::kernel_t> kernel_;
275};
276
/** JIT reorder specialized for single-block (tiled) layouts; executes via a
 * tr::jit_single_blk_kernel_t instead of the generic driver. */
struct jit_blk_reorder_t : public primitive_t {
    using primitive_t::primitive_t;
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;
        DECLARE_COMMON_PD_T("jit:blk", jit_blk_reorder_t);

        tr::prb_t prb_; // transposition problem this primitive executes

    private:
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md);

        // Swap last two nodes, put block 4, 8, 16 nodes to first
        static void prb_tile_normalize(tr::prb_t &p);
        friend dnnl::impl::impl_list_item_t;
    };

    status_t init(engine_t *engine) override;
    status_t execute(const exec_ctx_t &ctx) const override;

    jit_blk_reorder_t(const pd_t *apd);
    ~jit_blk_reorder_t();

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::unique_ptr<tr::jit_single_blk_kernel_t> kernel_;
};
306
307} // namespace x64
308} // namespace cpu
309} // namespace impl
310} // namespace dnnl
311
312#endif
313