1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cctype> |
18 | #include <sstream> |
19 | #include <thread> |
20 | |
21 | #include "gpu/jit/ir/tensor.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace gpu { |
26 | namespace jit { |
27 | |
// Builds a layout from (dim_idx, block) parts listed from outermost to
// innermost. A part with block == 0 stands for the "full" remainder of that
// dimension and is deduced from dims[dim_idx] divided by the product of the
// dimension's other (inner) blocks.
layout_t::layout_t(const type_t &type, const expr_t &offset,
        const std::vector<std::pair<int, dim_t>> &parts,
        const std::vector<dim_t> &dims, bool do_normalize)
    : type_(type), offset_(offset) {
    // Number of dimensions is one past the maximum dimension index mentioned.
    ndims_ = 0;
    for (auto &p : parts) {
        int dim_idx = p.first;
        dim_t block = p.second;
        ndims_ = std::max(ndims_, dim_idx + 1);
        // A zero (deduced) block requires explicit dims to deduce from.
        if (block == 0 && dims.empty())
            ir_error_not_expected()
                    << "Dimensions are missing. Can't deduce them from "
                       "the format." ;
    }
    if (!dims.empty() && ndims_ != int(dims.size())) {
        ir_error_not_expected() << "Format and dimensions do not match." ;
    }

    dim_t stride = 1;
    // Iterate from right to left (innermost to outermost).
    for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
        int dim_idx = it->first;
        dim_t block = it->second;
        if (block == 0) {
            // Deduce the full block: remaining size of this dimension after
            // accounting for the inner blocks already collected.
            dim_t full_block = 1;
            for (auto &b : blocks_)
                if (b.dim_idx == dim_idx) full_block *= b.block;

            block = utils::div_up(dims[dim_idx], full_block);
        }

        // Blocks are stored innermost-first; strides accumulate densely.
        blocks_.emplace_back(dim_idx, block, stride);
        stride = block * stride;
    }

    if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
    sanity_check();
}
66 | |
// Builds a layout from a DNNL memory descriptor (blocking format only).
layout_t::layout_t(const memory_desc_wrapper &mdw, bool do_normalize)
    : type_(mdw.data_type()), offset_(mdw.offset0()) {
    ir_assert(mdw.is_blocking_desc()) << "Expected blocking memory descriptor." ;

    ndims_ = mdw.ndims();
    auto &blocking = mdw.blocking_desc();
    auto *padded_dims = mdw.padded_dims();

    // Inner blocks are dense and listed in inner_blks from outer to inner,
    // so walk them in reverse while accumulating the dense stride.
    dim_t stride = 1;
    std::vector<dim_t> full_blocks(ndims_, 1);
    for (int i = blocking.inner_nblks - 1; i >= 0; i--) {
        int dim_idx = blocking.inner_idxs[i];
        dim_t block = blocking.inner_blks[i];
        blocks_.emplace_back(dim_idx, block, stride);
        stride *= block;
        full_blocks[dim_idx] *= block;
    }

    // Outer blocks: remaining (padded) size of each dimension with the
    // stride taken directly from the descriptor.
    for (int i = 0; i < ndims_; i++) {
        dim_t block = padded_dims[i] / full_blocks[i];
        blocks_.emplace_back(i, block, blocking.strides[i]);
    }

    // Sort outer blocks by their stride.
    std::sort(blocks_.begin() + blocking.inner_nblks, blocks_.end(),
            [](const block_t &a, const block_t &b) {
                return a.stride < b.stride
                        || (a.stride == b.stride && a.block < b.block);
            });

    if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
    sanity_check();
}
100 | |
// Converts this layout back to a DNNL blocking memory descriptor.
// `dims_hint` supplies the logical (unpadded) dims; padded dims are
// recomputed from the blocks. The outermost block of each dimension becomes
// the descriptor stride; all subsequent (inner) blocks must be dense.
memory_desc_t layout_t::to_dnnl(const dim_t *dims_hint) const {
    memory_desc_t md = {};
    md.ndims = ndims();
    std::copy(dims_hint, dims_hint + ndims(), md.dims);
    md.data_type = jit::to_dnnl(type_);
    md.offset0 = to_cpp<dim_t>(offset_);
    md.format_kind = format_kind::blocked;

    auto &blk = md.format_desc.blocking;
    // Tracks which dimensions already produced their outer block.
    bool seen[DNNL_MAX_NDIMS] = {};

    bool in_inner_block = false;
    dim_t prev_stride = 0;

    // Walk blocks outermost-first: the first block seen for a dimension is
    // its outer block, any later one is an inner block.
    for (auto it = blocks_.rbegin(); it != blocks_.rend(); ++it) {
        auto &b = *it;
        if (!seen[b.dim_idx]) {
            // Outer block.
            ir_assert(!in_inner_block);
            MAYBE_UNUSED(in_inner_block);
            blk.strides[b.dim_idx] = b.stride;
            md.padded_dims[b.dim_idx] = b.block;
        } else {
            // Inner block.
            md.padded_dims[b.dim_idx] *= b.block;
            blk.inner_idxs[blk.inner_nblks] = b.dim_idx;
            blk.inner_blks[blk.inner_nblks] = b.block;
            blk.inner_nblks++;
            if (prev_stride > 0) {
                // Inner block must be dense.
                ir_assert(prev_stride == b.block * dim_t(b.stride));
            }
            prev_stride = b.stride;
            in_inner_block = true;
        }
        seen[b.dim_idx] = true;
    }

    return md;
}
141 | |
// Restricts this layout to `tensor` (same rank): block sizes are clipped to
// the tensor dims and the layout offset is shifted to the tensor start.
// When a tensor dim falls in the middle of a block, the block is split and
// mapping restarts recursively.
layout_t layout_t::map(const tensor_t &tensor) const {
    if (ndims() != tensor.ndims())
        ir_error_not_expected() << "Dimensions do not match." ;

    std::vector<dim_t> remaining_dims = tensor.dims();
    std::vector<block_t> mapped_blocks;

    for (auto &eb : enumerated_blocks()) {
        block_t &b = eb.second;
        bool b_is_outermost = is_outermost(eb);

        dim_t block = b.block;
        dim_t &rem_dim = remaining_dims[b.dim_idx];
        if (rem_dim == 1) {
            if (b_is_outermost) {
                // This is to have similarity between the current and
                // mapped layouts.
                mapped_blocks.emplace_back(b.dim_idx, 1, b.stride);
            }
            continue;
        }
        if (b_is_outermost) {
            // The outermost block absorbs everything left of this dim.
            block = rem_dim;
        } else if (rem_dim % block != 0) {
            // Try to split the current block and start mapping from
            // scratch.
            if (block % rem_dim == 0)
                return split_block(eb, rem_dim, block / rem_dim).map(tensor);

            ir_error_not_expected() << "Can't map tensor layout." ;
        }
        rem_dim /= block;
        mapped_blocks.emplace_back(b.dim_idx, block, b.stride);
    }

    // Every tensor dimension must be fully consumed by the blocks above.
    for (auto &d : remaining_dims) {
        ir_assert(d == 1) << "Can't map tensor layout." ;
        MAYBE_UNUSED(d);
    }

    return layout_t(type(), ndims(), operator()(tensor.start()), mapped_blocks);
}
184 | |
// Returns the same memory viewed as `new_type`. Only the innermost (unit
// stride) block is resized; all other strides and the element offset are
// rescaled by the type-size ratio. Returns an empty layout when the
// reinterpretation is impossible.
layout_t layout_t::reinterpret(
        const type_t &new_type, bool do_normalize) const {
    int old_size = type().size();
    int new_size = new_type.size();
    if (new_size == old_size) return *this;

    // Convert the element offset to units of the new element size.
    expr_t new_offset = 0;
    if (!has_zero_offset()) {
        ir_assert(is_const(offset_)) << "Expected constant offset." ;
        int64_t off = to_cpp<int64_t>(offset_) * old_size;
        ir_assert(off % new_size == 0);
        new_offset = off / new_size;
    }

    // One size must evenly divide the other.
    if (old_size % new_size != 0 && new_size % old_size != 0) {
        ir_error_not_expected();
        return layout_t();
    }

    auto new_blocks = blocks_;
    if (new_blocks.empty()) {
        ir_error_not_expected() << "Can't reinterpret." ;
        return layout_t();
    }

    // The innermost block must be dense (unit stride) to be resized.
    auto &b0 = new_blocks.front();
    if (dim_t(b0.stride) != 1) {
        ir_error_not_expected();
        return layout_t();
    }

    if (new_size < old_size) {
        // Smaller elements: the innermost block and outer strides grow.
        int factor = (old_size / new_size);
        b0.block *= factor;
        // Recompute strides.
        for (auto &b : new_blocks) {
            if (&b == &b0) continue;
            b.stride *= factor;
        }
    } else {
        // Larger elements: the innermost block and outer strides shrink;
        // both must be divisible by the ratio.
        int factor = (new_size / old_size);
        if (b0.block % factor != 0) {
            ir_error_not_expected();
            return layout_t();
        }
        b0.block /= factor;
        // Recompute strides.
        for (auto &b : new_blocks) {
            if (&b == &b0) continue;
            if (b.stride % factor != 0) {
                ir_error_not_expected();
                return layout_t();
            }
            b.stride /= factor;
        }
    }

    return layout_t(new_type, ndims(), new_offset, new_blocks, do_normalize);
}
244 | |
245 | layout_t layout_t::split_block( |
246 | const std::pair<int, block_t> &eb, dim_t block0, dim_t block1) const { |
247 | int block_idx = eb.first; |
248 | auto &b = eb.second; |
249 | ir_assert(b.block == block0 * block1) << "Incompatible block sizes." ; |
250 | MAYBE_UNUSED(b); |
251 | |
252 | auto new_blocks = blocks_; |
253 | |
254 | block_t &b0 = new_blocks[block_idx]; |
255 | block_t b1 = b0; |
256 | |
257 | b0.block = block0; |
258 | b1.block = block1; |
259 | b1.stride = b0.stride * block0; |
260 | |
261 | new_blocks.insert(new_blocks.begin() + block_idx + 1, b1); |
262 | |
263 | return layout_t( |
264 | type(), ndims(), offset(), new_blocks, /*do_normalize=*/false); |
265 | } |
266 | |
// Splits the layout (innermost-first) into consecutive tiles containing
// multi_blocks[i] elements each. Existing blocks are split when they
// straddle a tile boundary. Returns an empty layout when the requested
// tiling is impossible.
layout_t layout_t::split_into_multi_blocks(
        const std::vector<dim_t> &multi_blocks) const {
    if (is_empty()) return *this;

    layout_t tmp(*this);
    // rem_elems[i]: elements still needed for tile i; cur_elems[i]: elements
    // collected for tile i so far.
    std::vector<dim_t> rem_elems = multi_blocks;
    std::vector<dim_t> cur_elems(rem_elems.size(), 1);
    for (auto &eb : tmp.enumerated_blocks()) {
        auto &b = eb.second;
        for (int i = 0; i < int(rem_elems.size()); i++) {
            auto &e = rem_elems[i];
            if (e == 1) continue;
            if (b.block > e) {
                // Try to split this block.
                int next_block = utils::max_div(b.block, e);
                if (next_block == 1) return layout_t();
                return tmp.split_block(eb, next_block, b.block / next_block)
                        .split_into_multi_blocks(multi_blocks);
            }
            // The block must evenly divide the remaining tile size.
            if (e % b.block != 0) return layout_t();
            e /= b.block;
            cur_elems[i] *= b.block;
            break;
        }
    }
    // Verify every requested tile was fully collected.
    for (int i = 0; i < int(cur_elems.size()); i++) {
        if (cur_elems[i] != multi_blocks[i]) { return layout_t(); }
    }
    return tmp;
}
297 | |
// Returns the largest innermost tile containing at most max_tile_elems
// elements. With is_dense_tile set, accumulation stops at the first block
// whose stride is unknown or breaks contiguity, so the tile stays dense.
// Blocks are split when only part of a block fits.
tensor_t layout_t::split_into_max_tile(
        dim_t max_tile_elems, bool is_dense_tile) const {
    // Expected stride of the next block for the tile to remain dense.
    stride_t dense_stride = 1;
    std::vector<dim_t> tile_dims(ndims(), 1);
    dim_t cur_elems = 1;
    for (auto &eb : enumerated_blocks()) {
        auto &b = eb.second;
        if (b.block == 1) continue;
        if (b.block * cur_elems <= max_tile_elems) {
            // The whole block fits into the tile.
            if (is_dense_tile) {
                if (b.stride.is_unknown()) break;
                if (dense_stride != b.stride) break;
                dense_stride = b.block * b.stride;
            }
            cur_elems *= b.block;
            tile_dims[b.dim_idx] *= b.block;
            continue;
        }
        // Only part of the block fits: take the largest divisor that does,
        // split the block there and retry on the split layout.
        dim_t max_block = utils::max_div(b.block, max_tile_elems / cur_elems);
        if (max_block == 1) break;
        auto tmp_layout = split_block(eb, max_block, b.block / max_block);
        return tmp_layout.split_into_max_tile(max_tile_elems, is_dense_tile);
    }
    return tensor_t(tile_dims);
}
323 | |
// Makes two layouts compatible for per-block traversal: for each dimension,
// the first mismatching pair of blocks in `a` and `b` is split down to their
// GCD. Only one split per dimension is done per pass; the split layouts are
// re-read on the next dimension iteration.
void layout_t::align_layouts(layout_t &a, layout_t &b) {
    for (int i = 0; i < a.ndims(); i++) {
        auto a_blocks = a.blocks();
        auto b_blocks = b.blocks();

        int a_max = int(a_blocks.size());
        int b_max = int(b_blocks.size());
        int a_idx = 0;
        int b_idx = 0;

        for (;;) {
            // Advance both cursors to the next block of dimension i.
            while (a_idx < a_max && a_blocks[a_idx].dim_idx != i)
                a_idx++;
            while (b_idx < b_max && b_blocks[b_idx].dim_idx != i)
                b_idx++;

            if (a_idx >= a_max || b_idx >= b_max) break;

            auto &ab = a_blocks[a_idx];
            auto &bb = b_blocks[b_idx];
            dim_t common_block = math::gcd(ab.block, bb.block);
            if (ab.block == common_block && bb.block == common_block) {
                // Blocks already match; move on to the next pair.
                a_idx++;
                b_idx++;
                continue;
            }

            // Split the larger block(s) at the common size. The local
            // block vectors are now stale, so stop this dimension's pass.
            if (ab.block != common_block) {
                a = a.split_block(
                        {a_idx, ab}, common_block, ab.block / common_block);
            }
            if (bb.block != common_block) {
                b = b.split_block(
                        {b_idx, bb}, common_block, bb.block / common_block);
            }
            break;
        }
    }
}
363 | |
364 | std::vector<std::pair<char, dim_t>> layout_t::parse_letter_blocks( |
365 | const std::string &format) { |
366 | std::vector<std::pair<char, dim_t>> ret; |
367 | |
368 | std::stringstream ss(format); |
369 | while (!ss.eof()) { |
370 | int next = ss.peek(); |
371 | if (ss.eof()) break; |
372 | dim_t block = 0; |
373 | while (std::isdigit(next)) { |
374 | block = 10 * block + (next - '0'); |
375 | ss.ignore(1); |
376 | next = ss.peek(); |
377 | } |
378 | char letter = char(ss.peek()); |
379 | ir_assert(!ss.eof()) << "EOF is unexpected." ; |
380 | ss.ignore(1); |
381 | ret.emplace_back(letter, block); |
382 | } |
383 | return ret; |
384 | } |
385 | |
386 | std::vector<std::pair<int, dim_t>> layout_t::parse_format( |
387 | const std::string &format, int ndims_hint) { |
388 | bool seen_letters[DNNL_MAX_NDIMS] = {}; |
389 | int letter_ndims = 0; |
390 | for (char c = 'a'; c < 'a' + DNNL_MAX_NDIMS; c++) { |
391 | if (format.find(c) != std::string::npos) { |
392 | seen_letters[c - 'a'] = true; |
393 | MAYBE_UNUSED(seen_letters); |
394 | letter_ndims++; |
395 | } |
396 | } |
397 | |
398 | for (int i = 0; i < DNNL_MAX_NDIMS; i++) { |
399 | ir_assert(seen_letters[i] == (i < letter_ndims)); |
400 | } |
401 | |
402 | auto letter_blocks = parse_letter_blocks(format); |
403 | |
404 | std::vector<std::pair<int, dim_t>> parts; |
405 | for (auto &p : letter_blocks) { |
406 | char letter = p.first; |
407 | dim_t block = p.second; |
408 | if (letter != 'x') { |
409 | int dim_idx = std::tolower(letter) - 'a'; |
410 | parts.emplace_back(dim_idx, block); |
411 | } else { |
412 | ir_assert(ndims_hint >= letter_ndims); |
413 | for (int i = letter_ndims; i < ndims_hint; i++) { |
414 | parts.emplace_back(i, 0); |
415 | } |
416 | } |
417 | } |
418 | |
419 | return parts; |
420 | } |
421 | |
422 | void layout_t::sanity_check() const { |
423 | #if !defined(NDEBUG) || defined(GEN_CONV_DEBUG) |
424 | return; |
425 | #endif |
426 | if (is_empty()) return; |
427 | |
428 | for (auto &b : blocks_) { |
429 | ir_assert(b.block > 0) << "Incorrect block size." ; |
430 | MAYBE_UNUSED(b); |
431 | } |
432 | ir_assert(ndims_ <= max_ndims); |
433 | } |
434 | |
435 | expr_t grid_splitter_t::pop_block(int size) { |
436 | ir_assert(size > 1); |
437 | ir_assert(can_pop_block(size)); |
438 | |
439 | int new_stride = cur_stride_ * size; |
440 | |
441 | auto idx_expr = grid_.idx(cur_idx_); |
442 | if (cur_stride_ != 1) idx_expr /= cur_stride_; |
443 | if (new_stride != grid_.dim(cur_idx_)) idx_expr %= size; |
444 | |
445 | cur_stride_ = new_stride; |
446 | if (cur_stride_ == grid_.dim(cur_idx_)) { |
447 | // Move to the next dimension. |
448 | cur_idx_--; |
449 | skip_size_1_dims(); |
450 | cur_stride_ = 1; |
451 | } |
452 | return idx_expr; |
453 | } |
454 | |
455 | stride_t tdim_info_t::compute_stride( |
456 | const expr_t &e, int idx, const expr_t &var) { |
457 | // e == var -> fixed stride. |
458 | if (e.is_same(var)) return stride_t(1); |
459 | |
460 | auto vars = find_objects<var_t>(e); |
461 | |
462 | auto e0 = e; |
463 | auto e1 = substitute(e, var, var + 1); |
464 | auto e_stride = simplify(e1 - e0); |
465 | |
466 | if (is_const(e_stride)) return stride_t(to_cpp<dim_t>(e_stride)); |
467 | |
468 | // Stride is not a constant. |
469 | return stride_t::unknown(); |
470 | } |
471 | |
472 | view_t view_t::create_sub_view(const tensor_t &sub_tensor) const { |
473 | ir_assert(sub_tensor.ndims() == nvdims()) << "Dimensions don't match." ; |
474 | |
475 | auto ret = *this; |
476 | ret.vdims_ = sub_tensor.dims(); |
477 | for (int i = 0; i < nvdims(); i++) { |
478 | auto &i_start = sub_tensor.start()[i]; |
479 | if (is_zero(i_start)) continue; |
480 | auto &s = ret.vstart_[i]; |
481 | s += i_start; |
482 | s = simplify(s); |
483 | } |
484 | return ret; |
485 | } |
486 | |
487 | view_t view_t::substitute(const expr_t &from, const expr_t &to) const { |
488 | view_t ret = *this; |
489 | for (int i = 0; i < nvdims(); i++) { |
490 | ret.vstart_[i] = jit::substitute(ret.vstart_[i], from, to); |
491 | ret.vstart_[i] = simplify(ret.vstart_[i]); |
492 | } |
493 | return ret; |
494 | } |
495 | |
496 | std::vector<expr_t> view_t::create_vvars(int nvdims) { |
497 | static const int max_nvdims = 128; |
498 | static thread_local std::vector<expr_t> _vvars([] { |
499 | std::vector<expr_t> ret; |
500 | ret.reserve(max_nvdims); |
501 | for (int i = 0; i < max_nvdims; i++) |
502 | ret.push_back(var_t::make(type_t::s32(), "_" + std::to_string(i))); |
503 | return ret; |
504 | }()); |
505 | |
506 | ir_assert(nvdims <= max_nvdims) << "Too many dimensions: " << nvdims; |
507 | return std::vector<expr_t>(_vvars.begin(), _vvars.begin() + nvdims); |
508 | } |
509 | |
// Converts a tensor layout into a layout over the view dimensions, using
// tdims_ to map each tensor dimension to its view variables. Outermost
// tensor blocks may cover several view variables; variables whose stride
// cannot be derived get stride_t::unknown(). Non-outermost blocks require
// an identity tdim mapping.
layout_t view_t::create_pseudo_vlayout(const layout_t &tlayout) const {
    ir_assert(!tlayout.is_empty());

    // rem_vdims[i]: portion of view dim i not yet covered by a block.
    std::vector<dim_t> rem_vdims = vdims_;
    std::vector<block_t> blocks;

    for (auto &teb : tlayout.enumerated_blocks()) {
        block_t &tb = teb.second;
        bool tb_is_outermost = tlayout.is_outermost(teb);
        dim_t tblock = tb.block;

        auto &tinfo = tdims_[tb.dim_idx];
        if (tb_is_outermost) {
            // Use innermost dimension with maximum remaining size for first
            // block
            int max_idx = tinfo.nvargs() - 1;
            int max_vidx = tinfo.vidx(max_idx);
            int max_vdim = rem_vdims[max_vidx];
            for (int i = tinfo.nvargs() - 2; i >= 0; i--) {
                int vidx = tinfo.vidx(i);
                if (rem_vdims[vidx] > max_vdim) {
                    max_idx = i;
                    max_vidx = vidx;
                    max_vdim = rem_vdims[vidx];
                }
            }

            // That variable keeps its known stride (scaled by the block's).
            if (max_vdim > 1) {
                stride_t stride = tinfo.vstride(max_idx);
                blocks.emplace_back(
                        max_vidx, max_vdim, stride * stride_t(tb.stride));
                rem_vdims[max_vidx] = 1;
            }

            for (int i = tinfo.nvargs() - 1; i >= 0; i--) {
                int vidx = tinfo.vidx(i);
                if (rem_vdims[vidx] == 1) continue;

                // When expression contains 2+ variables, use unknown stride for
                // the remaining view variables.
                blocks.emplace_back(vidx, rem_vdims[vidx], stride_t::unknown());
                rem_vdims[vidx] = 1;
            }
            continue;
        }

        // Non-outermost blocks: the tensor dim must map 1:1 to a view dim.
        ir_assert(tinfo.is_identity()) << "Can't create pseudo-layout." ;

        int vidx = tinfo.vidx(0);
        dim_t &rem_vdim = rem_vdims[vidx];
        if (rem_vdim == 1) continue;

        // NOTE(review): this branch looks unreachable - the tb_is_outermost
        // case always ends with `continue` above; kept as-is.
        if (tb_is_outermost) {
            tblock = rem_vdim;
            rem_vdim = 1;
        } else if (rem_vdim % tblock == 0) {
            rem_vdim /= tblock;
        } else if (rem_vdim % tblock != 0) {
            // Try to split the current block and start from scratch.
            if (tblock % rem_vdim == 0) {
                auto tmp_layout
                        = tlayout.split_block(teb, rem_vdim, tblock / rem_vdim);
                return create_pseudo_vlayout(tmp_layout);
            }

            ir_error_not_expected() << "Can't create pseudo-layout." ;
        }
        blocks.emplace_back(tb.dim_idx, tblock, tb.stride);
    }

    // All view dimensions must be fully covered.
    for (auto &d : rem_vdims) {
        ir_assert(d == 1) << "Can't create pseudo-layout." ;
        MAYBE_UNUSED(d);
    }

    return layout_t(tlayout.type(), nvdims(), 0, blocks);
}
587 | |
588 | layout_t dim_assignment_t::map(const layout_t &layout) const { |
589 | std::vector<block_t> new_blocks; |
590 | for (auto &b : layout.blocks()) { |
591 | int new_idx = assignments_[b.dim_idx]; |
592 | if (new_idx == -1) continue; // Drop this block. |
593 | auto new_b = b; |
594 | new_b.dim_idx = new_idx; |
595 | new_blocks.push_back(new_b); |
596 | } |
597 | new_blocks = layout_t::normalize_blocks(new_ndims(), new_blocks, |
598 | /*remove_size_1_blocks=*/false); |
599 | auto ret = layout_t(layout.type(), new_ndims(), layout.offset(), new_blocks, |
600 | /*do_normalize=*/false); |
601 | ir_assert(layout.elems() == ret.elems()) |
602 | << "Assignment doesn't preserve number of elements." ; |
603 | return ret; |
604 | } |
605 | |
606 | } // namespace jit |
607 | } // namespace gpu |
608 | } // namespace impl |
609 | } // namespace dnnl |
610 | |