1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cassert> |
18 | #include <set> |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_desc_wrapper.hpp" |
23 | #include "common/nstl.hpp" |
24 | #include "common/type_helpers.hpp" |
25 | #include "common/utils.hpp" |
26 | #include "oneapi/dnnl/dnnl_debug.h" |
27 | |
28 | #include "cpu/x64/jit_uni_reorder.hpp" |
29 | |
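// Define TR_DEBUG to enable tracing of the reorder problem: prb_init() then
// dumps the problem (prb_dump) after each transformation step.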
30 | // #define TR_DEBUG |
31 | #if defined(TR_DEBUG) |
32 | #define DEBUg(...) \ |
33 | do { \ |
34 | __VA_ARGS__ \ |
35 | } while (0) |
36 | #else |
37 | #define DEBUg(...) |
38 | #endif |
39 | #define DEBUG(...) DEBUg(__VA_ARGS__) |
40 | |
41 | using namespace dnnl::impl::types; |
42 | using namespace dnnl::impl::status; |
43 | |
44 | namespace dnnl { |
45 | namespace impl { |
46 | namespace cpu { |
47 | namespace x64 { |
48 | |
49 | namespace tr { |
50 | |
51 | /** ad-hoc structure to describe blocked memory layout */ |
52 | struct layout_desc_t { |
53 | layout_desc_t() |
54 | : dt(dnnl_data_type_undef) |
55 | , ndims(0) |
56 | , id {-1} |
57 | , dims {0} |
58 | , tails {0} |
59 | , is_blk {false} |
60 | , strides {0} {} |
61 | data_type_t dt; |
62 | int ndims; |
63 | dims_t id; |
64 | dims_t dims; |
65 | dims_t tails; |
66 | bool is_blk[DNNL_MAX_NDIMS]; |
67 | strides_t strides; |
68 | }; |
69 | |
70 | status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, |
71 | layout_desc_t &ld, const dims_t &blocks, const dims_t &external_padding, |
72 | const dims_t &tails) { |
73 | static constexpr bool it_is_blk = true; |
74 | |
75 | const auto md = memory_desc_wrapper(md_); |
76 | |
77 | if (!md.is_blocking_desc()) return invalid_arguments; |
78 | |
79 | const auto &bd = md.blocking_desc(); |
80 | |
81 | ld.ndims = 0; |
82 | ld.dt = md.data_type(); |
83 | |
84 | auto add_dim = [&ld](int id, dim_t dim, dim_t tail, bool is_blk, |
85 | ptrdiff_t stride) { |
86 | assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0])); |
87 | ld.id[ld.ndims] = id; |
88 | ld.dims[ld.ndims] = dim; |
89 | ld.strides[ld.ndims] = stride; |
90 | ld.tails[ld.ndims] = tail; |
91 | ld.is_blk[ld.ndims] = is_blk; |
92 | ++ld.ndims; |
93 | }; |
94 | |
95 | for (int d = 0; d < md.ndims(); ++d) { |
96 | const int ld_ndims_start = ld.ndims; |
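        // Emit the inner (blocked) parts of dimension 'd' first, walking the
        // inner blocks from the innermost level outwards. The within-block
        // stride is accumulated along the way and the tail is split across
        // the block levels.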
97 | if (blocks[d] != 1) { |
98 | stride_t stride = 1; |
99 | int tail = tails[d]; |
100 | for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) { |
101 | if (bd.inner_idxs[iblk] == d) { |
102 | const dim_t inner_tail = tail % bd.inner_blks[iblk]; |
103 | add_dim(d, bd.inner_blks[iblk], inner_tail, it_is_blk, |
104 | stride); |
105 | tail = utils::div_up(tail, bd.inner_blks[iblk]); |
106 | } |
107 | stride *= bd.inner_blks[iblk]; |
108 | } |
109 | } |
110 | |
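        // Then emit the outer part of dimension 'd': the padded size divided
        // by the combined block size, grown by any externally requested
        // padding. When external padding is present, the tail is the outer
        // size without that padding.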
111 | const dim_t dim_with_external_padding |
112 | = (md.padded_dims()[d] + external_padding[d]) / blocks[d]; |
113 | const dim_t padded_dim = md.padded_dims()[d] / blocks[d]; |
        const dim_t tail
                = dim_with_external_padding != padded_dim ? padded_dim : 0;
118 | |
119 | add_dim(d, dim_with_external_padding, tail, !it_is_blk, bd.strides[d]); |
120 | |
121 | // TODO: NOW: revisit, do we need a reverse? |
122 | // TODO: NOW: consider using strides instead of block sizes in md |
        // reverse the order of the entries just added for this logical
        // dimension, so the outer part comes before the inner blocks
124 | for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) { |
125 | const int idx0 = ld_ndims_start + ld_d; |
126 | const int idx1 = ld.ndims - 1 - ld_d; |
127 | nstl::swap(ld.dims[idx0], ld.dims[idx1]); |
128 | nstl::swap(ld.strides[idx0], ld.strides[idx1]); |
129 | nstl::swap(ld.tails[idx0], ld.tails[idx1]); |
130 | nstl::swap(ld.is_blk[idx0], ld.is_blk[idx1]); |
131 | } |
132 | } |
133 | |
134 | return success; |
135 | } |
136 | |
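// Checks whether the destination carries s8s8 or asymmetric-src compensation
// whose mask has bit 1 set; this is how a groups dimension in the weights is
// detected here.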
137 | static bool is_with_groups(const memory_desc_t &dst_md) { |
138 | using namespace memory_extra_flags; |
139 | auto dst_d = memory_desc_wrapper(dst_md); |
140 | const int grp_bit = 1 << 1; |
141 | auto check_flag_and_mask = [&](int flag, int mask) { |
142 | return (dst_d.extra().flags & flag) && (mask & grp_bit); |
143 | }; |
144 | |
145 | return check_flag_and_mask( |
146 | compensation_conv_s8s8, dst_d.extra().compensation_mask) |
147 | || check_flag_and_mask(compensation_conv_asymmetric_src, |
148 | dst_d.extra().asymm_compensation_mask); |
149 | } |
150 | |
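// Returns the index of the next node (after 'cur_node') that refers to the
// same logical dimension, or -1 if there is none.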
151 | static inline int get_next_parent_node(node_t *nodes, int ndims, int cur_node) { |
152 | const int cur_id = nodes[cur_node].dim_id; |
153 | for (int d = cur_node + 1; d < ndims; ++d) { |
154 | if (nodes[d].dim_id == cur_id) return d; |
155 | } |
156 | return -1; |
157 | } |
158 | |
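// Assigns compensation strides ('cs') to every node whose logical dimension
// is covered by the compensation mask. Expects the nodes to be already
// normalized (sorted by increasing output stride), see prb_normalize().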
159 | static void prb_set_compensation_strides(prb_t &p) { |
160 | |
161 | auto require_n_stride = [&](int cur_node) -> bool { |
162 | const int parent = get_next_parent_node(p.nodes, p.ndims, cur_node); |
163 | if (parent < 0) return false; |
164 | |
165 | const size_t p_n = p.nodes[parent].n; |
166 | |
        // If the parent node still iterates over this dimension (its 'n' is
        // greater than 1), the compensation stride must advance by the full
        // 'cur_node.n' rather than by its tail size.
169 | return p_n > size_t(1); |
170 | }; |
171 | |
172 | const auto compensation_needed = p.req_s8s8_comp || p.req_asymmetric_comp; |
173 | if (!compensation_needed) return; |
174 | int mask = p.compensation_mask; |
175 | ptrdiff_t cs = 1; |
176 | for (int d = 0; d < p.ndims; ++d) { |
177 | if (mask & (1 << p.nodes[d].dim_id)) { |
178 | |
            // clamp 'cs' so it never exceeds the node's output stride
180 | if (cs > p.nodes[d].os) cs = p.nodes[d].os; |
181 | |
182 | p.nodes[d].cs = cs; |
183 | const bool n_stride = require_n_stride(d); |
184 | if (p.nodes[d].tail_size > 0 && (!p.nodes[d].is_zero_pad_needed) |
185 | && (!n_stride)) |
186 | cs *= p.nodes[d].tail_size; |
187 | else |
188 | cs *= p.nodes[d].n; |
189 | } |
190 | } |
191 | } |
192 | |
193 | status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, |
194 | const primitive_attr_t *attr) { |
195 | auto im_d = memory_desc_wrapper(imd); |
196 | auto om_d = memory_desc_wrapper(omd); |
197 | |
198 | auto check_post_ops = [](const primitive_attr_t *attr) { |
199 | const auto &po = attr->post_ops_; |
200 | return po.len() == 0 || (po.len() == 1 && po.entry_[0].is_sum(false)); |
201 | }; |
202 | |
203 | bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc() |
204 | && !im_d.has_runtime_dims_or_strides() && !im_d.has_zero_dim() |
205 | && !om_d.has_runtime_dims_or_strides() && !om_d.has_zero_dim() |
206 | && attr->has_default_values( |
207 | primitive_attr_t::skip_mask_t::scales_runtime |
208 | | primitive_attr_t::skip_mask_t::zero_points_runtime |
209 | | primitive_attr_t::skip_mask_t::post_ops) |
210 | && check_post_ops(attr); |
211 | if (!ok) return unimplemented; |
212 | |
213 | bool is_tail_present = false; |
214 | dims_t iblocks, oblocks, i_tails, o_tails, i_paddings, o_paddings; |
215 | im_d.compute_blocks(iblocks); |
216 | om_d.compute_blocks(oblocks); |
217 | |
218 | for (int d = 0; d < om_d.ndims(); ++d) { |
219 | const auto dim = om_d.dims()[d]; |
220 | const auto pdim = om_d.padded_dims()[d]; |
221 | const auto cblock = oblocks[d]; |
        // Do not allow padded dims beyond what is required to round 'dim' up
        // to the block size.
223 | if (utils::rnd_up(dim, cblock) != pdim) return unimplemented; |
224 | } |
225 | |
226 | utils::array_set(i_tails, 0, im_d.ndims()); |
227 | utils::array_set(o_tails, 0, om_d.ndims()); |
228 | utils::array_set(i_paddings, 0, im_d.ndims()); |
229 | utils::array_set(o_paddings, 0, om_d.ndims()); |
230 | |
231 | for (int d = 0; d < im_d.ndims(); ++d) { |
232 | const dim_t i_dim = im_d.dims()[d]; |
233 | const dim_t o_dim = om_d.dims()[d]; |
234 | const dim_t i_tail = i_dim % iblocks[d]; |
235 | const dim_t o_tail = o_dim % oblocks[d]; |
236 | |
237 | if (o_tail > 0) { |
238 | is_tail_present = true; |
239 | o_tails[d] = o_tail; |
240 | o_paddings[d] = oblocks[d] - o_tail; |
241 | } |
242 | |
243 | if (i_tail > 0) { |
244 | is_tail_present = true; |
245 | i_tails[d] = i_tail; |
246 | i_paddings[d] = iblocks[d] - i_tail; |
247 | } |
248 | } |
249 | |
    // To compute the input layout description we pass the output paddings,
    // which are used to round the input dims up to a multiple of the output
    // blocks. The same applies symmetrically to the output layout description.
    // This is required by the node-creation algorithm below.
254 | // Example: |
255 | // input: |
256 | // format: abc |
257 | // size: 77, 15, 3 |
258 | // o_padding: 3, 17, 0 |
259 | // returns ild: 80, 32, 3 |
260 | // output: |
261 | // format: ABc16b16a2b |
262 | // size: 77, 15, 3 |
263 | // i_padding: 0, 0, 0 |
264 | // returns old: 5, 16, 1, 16, 2, 3 |
265 | layout_desc_t ild, old; |
266 | CHECK(cvt_mem_desc_to_layout_desc(imd, ild, iblocks, o_paddings, i_tails)); |
267 | CHECK(cvt_mem_desc_to_layout_desc(omd, old, oblocks, i_paddings, o_tails)); |
268 | |
269 | p.itype = ild.dt; |
270 | p.otype = old.dt; |
271 | p.is_tail_present = is_tail_present; |
272 | p.req_src_zp = !attr->zero_points_.has_default_values(DNNL_ARG_SRC); |
273 | p.req_dst_zp = !attr->zero_points_.has_default_values(DNNL_ARG_DST); |
274 | |
275 | p.src_scale_type = scale_type_t::NONE; |
276 | int src_mask = 0; |
277 | bool is_src_set = false; |
278 | CHECK(attr->scales_.get(DNNL_ARG_SRC, &src_mask, &is_src_set)); |
279 | if (is_src_set) { |
280 | p.src_scale_type |
281 | = src_mask == 0 ? scale_type_t::COMMON : scale_type_t::MANY; |
282 | } |
283 | |
284 | p.dst_scale_type = scale_type_t::NONE; |
285 | int dst_mask = 0; |
286 | bool is_dst_set = false; |
287 | CHECK(attr->scales_.get(DNNL_ARG_DST, &dst_mask, &is_dst_set)); |
288 | if (is_dst_set) { |
289 | p.dst_scale_type |
290 | = dst_mask == 0 ? scale_type_t::COMMON : scale_type_t::MANY; |
291 | } |
292 | |
293 | if (is_src_set && is_dst_set && src_mask != dst_mask) |
294 | return status::unimplemented; |
295 | |
296 | p.scale_adjust = (om_d.extra().flags & memory_extra_flags::scale_adjust) |
297 | ? om_d.extra().scale_adjust |
298 | : 1.f; |
299 | p.req_s8s8_comp |
300 | = om_d.extra().flags & memory_extra_flags::compensation_conv_s8s8; |
301 | p.req_asymmetric_comp = om_d.extra().flags |
302 | & memory_extra_flags::compensation_conv_asymmetric_src; |
303 | |
304 | const bool with_groups = is_with_groups(omd); |
305 | |
306 | auto mask_ok = [&](bool check, int mask) { |
307 | return IMPLICATION(check, mask == (with_groups ? 0x3 : 0x1)); |
308 | }; |
309 | |
310 | if (!mask_ok(p.req_s8s8_comp, om_d.extra().compensation_mask) |
311 | || !mask_ok(p.req_asymmetric_comp, |
312 | om_d.extra().asymm_compensation_mask)) |
313 | return status::unimplemented; |
314 | |
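    // Per-node scale strides: when per-dim scales are used, walk the output
    // layout description from its last (innermost) entry to the first and
    // assign a stride only to the nodes whose logical dimension is selected
    // by the scale mask.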
315 | ptrdiff_t ss[max_ndims] = {0}; // scales strides |
316 | if (p.src_scale_type == scale_type_t::MANY |
317 | || p.dst_scale_type == scale_type_t::MANY) { |
318 | const int mask = nstl::max(src_mask, dst_mask); |
319 | ptrdiff_t dense_stride = 1; |
320 | ptrdiff_t last_stride = 1; |
321 | for (int d = old.ndims - 1; d >= 0; --d) { |
            assert((d == 0 || old.id[d - 1] <= old.id[d])
                    && "logical dimensions should be in ascending order");
324 | if (mask & (1 << old.id[d])) { |
325 | if ((d + 1) < old.ndims && old.id[d + 1] != old.id[d] |
326 | && (mask & (1 << old.id[d + 1]))) { |
327 | dense_stride = dense_stride * imd.dims[old.id[d + 1]]; |
328 | last_stride = dense_stride; |
329 | } |
330 | ss[d] = last_stride; |
331 | last_stride *= old.dims[d]; |
332 | } |
333 | } |
334 | } |
335 | |
336 | const auto compensation_needed = p.req_s8s8_comp || p.req_asymmetric_comp; |
337 | if (compensation_needed) { |
338 | p.compensation_mask = p.req_s8s8_comp |
339 | ? om_d.extra().compensation_mask |
340 | : (p.req_asymmetric_comp ? om_d.extra().asymm_compensation_mask |
341 | : tr::prb_t::invalid_comp_mask); |
342 | |
343 | if (p.compensation_mask == tr::prb_t::asymmetric_comp_mask) |
344 | return unimplemented; |
345 | |
346 | assert(p.compensation_mask == tr::prb_t::standard_comp_mask |
347 | || p.compensation_mask == tr::prb_t::comp_mask_with_groups); |
348 | } |
349 | |
350 | int ndims = 0; |
351 | |
352 | int i_pos = 0; /* state for input -- current dimension */ |
353 | int o_pos = 0; /* state for output -- current dimension */ |
354 | |
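    // Create the nodes by walking the input and output layout descriptions in
    // lockstep. When the current sizes match, a single node covers both sides;
    // otherwise the larger side is split by the common factor so that both
    // sides advance with matching sizes.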
355 | while (i_pos < ild.ndims && o_pos < old.ndims) { |
356 | assert(ild.id[i_pos] == old.id[o_pos]); |
357 | |
358 | assert(ndims < max_ndims); |
359 | if (ndims == max_ndims) return runtime_error; |
360 | |
361 | if (ild.dims[i_pos] == old.dims[o_pos]) { |
362 | p.nodes[ndims].n = ild.dims[i_pos]; |
363 | p.nodes[ndims].dim_id = old.id[o_pos]; |
364 | p.nodes[ndims].tail_size = old.tails[o_pos]; |
365 | p.nodes[ndims].is_zero_pad_needed |
366 | = old.is_blk[o_pos] && old.tails[o_pos] > 0; |
367 | p.nodes[ndims].is = ild.strides[i_pos]; |
368 | p.nodes[ndims].os = old.strides[o_pos]; |
369 | p.nodes[ndims].ss = ss[o_pos]; |
370 | ++ndims; |
371 | ++i_pos; |
372 | ++o_pos; |
373 | } else if (ild.dims[i_pos] < old.dims[o_pos]) { |
374 | // old must be divisible by ild or we will not be |
375 | // able to create valid nodes. The problem appears |
376 | // when stag=Acdb48a and dtag=Acdb32a for example. |
377 | if (ild.dims[i_pos] == 0 || old.dims[o_pos] % ild.dims[i_pos] != 0) |
378 | return status::unimplemented; |
379 | |
380 | dim_t factor = old.dims[o_pos] / ild.dims[i_pos]; |
381 | |
382 | const size_t tail_of_upper_dim |
383 | = utils::div_up(old.tails[o_pos], factor) == ild.dims[i_pos] |
384 | ? 0 |
385 | : utils::div_up(old.tails[o_pos], factor); |
386 | const size_t tail_of_lower_dim = old.tails[o_pos] % factor; |
387 | |
388 | p.nodes[ndims].n = ild.dims[i_pos]; |
389 | p.nodes[ndims].dim_id = old.id[o_pos]; |
390 | p.nodes[ndims].tail_size = tail_of_upper_dim; |
391 | p.nodes[ndims].is_zero_pad_needed |
392 | = old.is_blk[o_pos] && tail_of_upper_dim > 0; |
393 | p.nodes[ndims].is = ild.strides[i_pos]; |
394 | p.nodes[ndims].os = old.strides[o_pos] * factor; |
395 | p.nodes[ndims].ss = ss[o_pos] * factor; |
396 | ++ndims; |
397 | ++i_pos; |
398 | old.dims[o_pos] = factor; |
399 | old.tails[o_pos] = tail_of_lower_dim; |
400 | } else if (ild.dims[i_pos] > old.dims[o_pos]) { |
401 | // ild must be divisible by old or we will not be |
402 | // able to create valid nodes. The problem appears |
403 | // when stag=Acdb32a and dtag=Acdb48a for example. |
404 | if (old.dims[o_pos] == 0 || ild.dims[i_pos] % old.dims[o_pos] != 0) |
405 | return status::unimplemented; |
406 | |
407 | dim_t factor = ild.dims[i_pos] / old.dims[o_pos]; |
408 | p.nodes[ndims].n = old.dims[o_pos]; |
409 | p.nodes[ndims].dim_id = old.id[o_pos]; |
410 | p.nodes[ndims].tail_size = old.tails[o_pos]; |
411 | p.nodes[ndims].is_zero_pad_needed |
412 | = old.is_blk[o_pos] && old.tails[o_pos] > 0; |
413 | p.nodes[ndims].is = ild.strides[i_pos] * factor; |
414 | p.nodes[ndims].os = old.strides[o_pos]; |
415 | p.nodes[ndims].ss = ss[o_pos]; |
416 | ++ndims; |
417 | ++o_pos; |
418 | ild.dims[i_pos] = factor; |
419 | } |
420 | } |
421 | |
422 | p.ndims = ndims; |
423 | p.full_ndims = ndims; |
424 | |
425 | p.ioff = memory_desc_wrapper(imd).offset0(); |
426 | p.ooff = memory_desc_wrapper(omd).offset0(); |
427 | |
428 | const int sum_idx = attr->post_ops_.find(primitive_kind::sum); |
429 | p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale; |
430 | |
    DEBUG({
        printf("init : ");
        prb_dump(p);
    });
    // Sort the nodes in increasing order of their output strides.
436 | prb_normalize(p); |
    DEBUG({
        printf("norm : ");
        prb_dump(p);
    });
441 | |
442 | // compensation strides require prb_normalized |
443 | prb_set_compensation_strides(p); |
444 | |
    /* Combine adjacent dimensions that form a single dense run on both
     * sides of the reorder */
447 | prb_simplify(p); |
    DEBUG({
        printf("smpl : ");
        prb_dump(p);
    });
452 | |
453 | return success; |
454 | } |
455 | |
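// Selection sort of the nodes by increasing output stride; ties are broken by
// the smaller node size 'n' going first.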
456 | void prb_normalize(prb_t &p) { |
457 | for (int d = 0; d < p.ndims; ++d) { |
458 | int min_pos = d; |
459 | for (int j = d + 1; j < p.ndims; ++j) { |
460 | bool new_min = false || p.nodes[j].os < p.nodes[min_pos].os |
461 | || (true && p.nodes[j].os == p.nodes[min_pos].os |
462 | && p.nodes[j].n < p.nodes[min_pos].n); |
463 | if (new_min) min_pos = j; |
464 | } |
465 | if (min_pos != d) { nstl::swap(p.nodes[d], p.nodes[min_pos]); } |
466 | } |
467 | } |
468 | |
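// For every node, records the index of the closest following node that shares
// its dim_id (its "parent"), or node_t::empty_field if there is none.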
469 | void prb_node_dependency(prb_t &prb) { |
470 | for (int i = 0; i < prb.ndims; i++) { |
471 | tr::node_t &node = prb.nodes[i]; |
472 | node.parent_node_id = node_t::empty_field; |
473 | for (int j = i + 1; j < prb.ndims; j++) { |
474 | const tr::node_t &potential_parent_node = prb.nodes[j]; |
475 | if (!potential_parent_node.is_dim_id_empty() |
476 | && potential_parent_node.dim_id == node.dim_id) { |
477 | node.parent_node_id = j; |
478 | break; |
479 | } |
480 | } |
481 | } |
482 | } |
483 | |
484 | void prb_simplify(prb_t &p) { |
485 | #if defined(__GNUC__) && __GNUC__ >= 4 |
    /* GCC produces a bogus "array subscript is above array bounds" warning
     * for the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for
     * now. */
488 | #pragma GCC diagnostic push |
489 | #pragma GCC diagnostic ignored "-Warray-bounds" |
490 | #endif |
491 | |
492 | const auto skip_dim_combining = [&p](const int node_id) -> bool { |
493 | return (p.is_tail_in_one_of_child_nodes(node_id) |
494 | && p.nodes[node_id].n > 1) |
495 | || p.nodes[node_id].tail_size > 0; |
496 | }; |
497 | |
498 | if (p.is_tail_present) prb_node_dependency(p); |
499 | |
500 | for (int d = 0; d < p.ndims - 1; ++d) { |
501 | auto &this_node = p.nodes[d + 0]; |
502 | auto &next_node = p.nodes[d + 1]; |
503 | const bool skip_dims_combining |
504 | = skip_dim_combining(d) || skip_dim_combining(d + 1); |
505 | const bool fold = false |
506 | || (next_node.n == static_cast<size_t>(1) |
507 | && !skip_dims_combining) // trivial case, just drop next node |
508 | || (true // or real folding if possible |
509 | && !skip_dims_combining |
510 | && next_node.is |
511 | == static_cast<ptrdiff_t>( |
512 | this_node.n * this_node.is) |
513 | && next_node.os |
514 | == static_cast<ptrdiff_t>( |
515 | this_node.n * this_node.os) |
516 | && next_node.ss |
517 | == static_cast<ptrdiff_t>( |
518 | this_node.n * this_node.ss) |
519 | && next_node.cs |
520 | == static_cast<ptrdiff_t>( |
521 | this_node.n * this_node.cs)); |
522 | if (fold) { |
523 | this_node.n *= next_node.n; |
524 | this_node.dim_id = node_t::empty_field; |
525 | this_node.is_zero_pad_needed = false; |
526 | for (int j = d + 2; j < p.ndims; ++j) |
527 | p.nodes[j - 1] = p.nodes[j]; |
528 | --p.ndims; |
529 | --p.full_ndims; |
530 | --d; // make another try |
531 | if (p.is_tail_present) prb_node_dependency(p); |
532 | } |
533 | } |
534 | #if defined(__GNUC__) && __GNUC__ >= 4 |
535 | #pragma GCC diagnostic pop |
536 | #endif |
537 | } |
538 | |
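// Splits node 'dim' into two adjacent nodes: the lower node keeps
// 'new_node_size' elements and the original strides, the upper node takes the
// remaining factor with strides scaled by 'new_node_size'. Tail sizes and
// zero-padding flags are distributed between the two accordingly.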
539 | void prb_node_split(prb_t &p, int dim, size_t new_node_size) { |
540 | assert(dim < p.ndims); |
541 | assert(p.ndims < max_ndims); |
542 | assert(p.nodes[dim].n % new_node_size == 0); |
543 | |
544 | p.ndims += 1; |
545 | p.full_ndims += 1; |
546 | |
547 | for (int d = p.ndims; d > dim + 1; --d) |
548 | p.nodes[d] = p.nodes[d - 1]; |
549 | |
550 | const size_t upper_node_size = p.nodes[dim].n / new_node_size; |
551 | const size_t lower_node_size = new_node_size; |
552 | p.nodes[dim + 1].n = upper_node_size; |
553 | p.nodes[dim].n = lower_node_size; |
554 | |
555 | const bool is_tail = p.nodes[dim].tail_size > 0; |
556 | const size_t upper_node_tail |
557 | = utils::div_up(p.nodes[dim].tail_size, lower_node_size) |
558 | == upper_node_size |
559 | ? 0 |
560 | : utils::div_up(p.nodes[dim].tail_size, lower_node_size); |
561 | const size_t lower_node_tail = p.nodes[dim].tail_size % lower_node_size; |
562 | p.nodes[dim].tail_size = is_tail ? lower_node_tail : 0; |
563 | p.nodes[dim + 1].tail_size = is_tail ? upper_node_tail : 0; |
564 | |
565 | p.nodes[dim + 1].is_zero_pad_needed |
566 | = p.nodes[dim].is_zero_pad_needed && p.nodes[dim + 1].tail_size > 0; |
567 | p.nodes[dim].is_zero_pad_needed |
568 | = p.nodes[dim].is_zero_pad_needed && p.nodes[dim].tail_size > 0; |
569 | |
570 | p.nodes[dim + 1].dim_id = p.nodes[dim].dim_id; |
571 | p.nodes[dim + 1].is = p.nodes[dim].is * lower_node_size; |
572 | p.nodes[dim + 1].os = p.nodes[dim].os * lower_node_size; |
573 | p.nodes[dim + 1].ss = p.nodes[dim].ss * lower_node_size; |
574 | p.nodes[dim + 1].cs = p.nodes[dim].cs * lower_node_size; |
575 | } |
576 | |
577 | void prb_node_swap(prb_t &p, int d0, int d1) { |
578 | assert(d0 < p.ndims); |
579 | assert(d1 < p.ndims); |
580 | assert(p.ndims < max_ndims); |
581 | |
582 | if (d0 == d1) return; |
583 | |
584 | nstl::swap(p.nodes[d0], p.nodes[d1]); |
585 | } |
586 | |
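// Moves the node at position 'd0' to position 'd1', shifting the nodes in
// between by one position.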
587 | void prb_node_move(prb_t &p, int d0, int d1) { |
588 | assert(d0 < p.ndims); |
589 | assert(d1 < p.ndims); |
590 | assert(p.ndims < max_ndims); |
591 | |
592 | if (d0 == d1) return; |
593 | |
594 | node_t node = p.nodes[d0]; |
595 | |
596 | if (d0 < d1) |
597 | for (int d = d0; d < d1; ++d) |
598 | p.nodes[d] = p.nodes[d + 1]; |
599 | else |
600 | for (int d = d0; d > d1; --d) |
601 | p.nodes[d] = p.nodes[d - 1]; |
602 | |
603 | p.nodes[d1] = node; |
604 | } |
605 | |
606 | void prb_dump(const prb_t &p) { |
607 | printf("@@@ type:%s:%s ndims:%d " , dnnl_dt2str(p.itype), |
608 | dnnl_dt2str(p.otype), p.ndims); |
609 | for (int d = 0; d < p.ndims; ++d) |
610 | printf("[%zu:%zu:%d:%d:%s:%td:%td:%td:%td]" , p.nodes[d].n, |
611 | p.nodes[d].tail_size, p.nodes[d].dim_id, |
612 | p.nodes[d].parent_node_id, |
613 | p.nodes[d].is_zero_pad_needed ? "true" : "false" , p.nodes[d].is, |
614 | p.nodes[d].os, p.nodes[d].ss, p.nodes[d].cs); |
615 | printf(" off:%zu:%zu\n" , p.ioff, p.ooff); |
616 | } |
617 | |
618 | } // namespace tr |
619 | |
620 | } // namespace x64 |
621 | } // namespace cpu |
622 | } // namespace impl |
623 | } // namespace dnnl |
624 | |