/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_IR_TENSOR_HPP
#define GPU_JIT_IR_TENSOR_HPP

#include <algorithm>
#include <array>
#include <iostream>
#include <sstream>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
#include <unordered_map>

#include "common/memory_desc_wrapper.hpp"
#include "gpu/jit/ir/ir.hpp"
#include "gpu/jit/pass/simplify.hpp"
#include "gpu/jit/utils/utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

class tensor_t {
public:
    tensor_t() = default;

    tensor_t(const std::vector<dim_t> &dims)
        : tensor_t(dims, std::vector<expr_t>()) {}

    tensor_t(const std::vector<dim_t> &dims, const std::vector<expr_t> &start)
        : dims_(dims), start_(start) {
        if (start_.empty()) start_.resize(dims.size(), 0);
    }

    tensor_t(const std::vector<dim_t> &dims, const std::vector<dim_t> &start)
        : tensor_t(dims) {
        start_.resize(start.size());
        for (size_t i = 0; i < start.size(); i++)
            start_[i] = start[i];
    }

    dim_t operator()(int idx) const { return dims_[idx]; }

    const expr_t &start(int idx) const { return start_[idx]; }

    int ndims() const { return int(dims_.size()); }

    dim_t elems() const {
        dim_t ret = 1;
        for (int i = 0; i < ndims(); i++)
            ret *= dims_[i];
        return ret;
    }

    const std::vector<dim_t> &dims() const { return dims_; }

    const std::vector<expr_t> &start() const { return start_; }

    bool is_empty() const { return dims_.empty(); }

    bool is_equal(const tensor_t &other) const {
        if (ndims() != other.ndims()) return false;
        for (int i = 0; i < ndims(); i++) {
            if (dims_[i] != other.dims_[i]) return false;
            if (!start_[i].is_equal(other.start_[i])) return false;
        }
        return true;
    }

    bool is_divisible(const tensor_t &other) const {
        if (ndims() != other.ndims()) return false;
        for (int i = 0; i < ndims(); i++) {
            if (dims_[i] % other.dims_[i] != 0) return false;
        }
        return true;
    }

    std::string str() const {
        using ir_utils::operator<<;

        if (is_empty()) return "(nil)";
        std::ostringstream oss;
        oss << ir_utils::make_seq_print_helper(dims_, "x");
        if (!has_zero_start()) oss << " start: [" << start_ << "]";
        return oss.str();
    }

    IR_DEFINE_DUMP()

    bool has_zero_start() const {
        for (int i = 0; i < ndims(); i++)
            if (!is_zero(start_[i])) return false;
        return true;
    }

    dim_t to_1d_offset(const std::vector<dim_t> &args) const {
        ir_assert(has_zero_start());

        dim_t off = 0;
        for (int i = 0; i < ndims(); i++) {
            off *= dims_[i];
            off += args[i];
        }
        return off;
    }

    tensor_t create_sub_tensor(const tensor_t &tile) const {
        ir_assert(ndims() == tile.ndims()) << "Incompatible sizes.";
        std::vector<expr_t> new_start = start_;
        for (int i = 0; i < ndims(); i++)
            new_start[i] += tile.start(i);
        return tensor_t(tile.dims(), new_start);
    }

    tensor_t substitute(const expr_t &from, const expr_t &to) const {
        tensor_t ret = *this;
        for (int i = 0; i < ndims(); i++) {
            ret.start_[i] = jit::substitute(ret.start_[i], from, to);
            ret.start_[i] = simplify(ret.start_[i]);
        }
        return ret;
    }

private:
    std::vector<dim_t> dims_;
    std::vector<expr_t> start_;
};

inline std::ostream &operator<<(std::ostream &out, const tensor_t &tensor) {
    out << tensor.str();
    return out;
}
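
// Usage sketch (illustrative only, based on the interface above):
//
//   tensor_t t({4, 8});         // 4x8 tensor with zero start
//   t.elems();                  // 32
//   t.to_1d_offset({1, 2});     // 1 * 8 + 2 = 10 (row-major)
//   std::vector<dim_t> s = {1, 4};
//   tensor_t sub = t.create_sub_tensor(tensor_t({2, 2}, s));
//   // sub has dims 2x2 and start (1, 4) relative to t's start.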

class grid_info_t {
public:
    grid_info_t() = default;
    grid_info_t(int ndims) : dims_(ndims), offs_(ndims), idxs_(ndims) {}
    grid_info_t(const std::vector<int> &dims, const std::vector<expr_t> &idxs)
        : grid_info_t(dims, {}, idxs) {}
    grid_info_t(const std::vector<int> &dims, const std::string &prefix)
        : grid_info_t(dims, make_idxs(prefix, (int)dims.size())) {}
    grid_info_t(const std::vector<int> &dims, const std::vector<int> &offs,
            const std::vector<expr_t> &idxs)
        : dims_(dims), offs_(offs), idxs_(idxs) {
        if (offs_.empty()) offs_.resize(dims.size());
        ir_assert(dims_.size() == offs_.size());
        ir_assert(dims_.size() == idxs_.size());
    }

    bool operator==(const grid_info_t &other) const {
        if (ndims() != other.ndims()) return false;
        for (int i = 0; i < ndims(); i++) {
            if (dim(i) != other.dim(i)) return false;
            if (off(i) != other.off(i)) return false;
            if (!idx(i).is_equal(other.idx(i))) return false;
        }
        return true;
    }

    bool is_empty() const { return dims_.empty(); }

    int &dim(int dim_idx) { return dims_[dim_idx]; }
    int &off(int dim_idx) { return offs_[dim_idx]; }
    expr_t &idx(int dim_idx) { return idxs_[dim_idx]; }
    int dim_idx(const expr_t &idx_var) const {
        for (int i = 0; i < ndims(); i++) {
            if (idx(i).is_same(idx_var)) return i;
        }
        ir_error_not_expected() << "Index not found: " << idx_var;
        return -1;
    }

    const int &dim(int dim_idx) const { return dims_[dim_idx]; }
    const int &dim(const expr_t &idx_var) const {
        return dims_[dim_idx(idx_var)];
    }
    const int &off(int dim_idx) const { return offs_[dim_idx]; }
    const expr_t &idx(int dim_idx) const { return idxs_[dim_idx]; }

    int &operator[](int dim_idx) { return dim(dim_idx); }
    const int &operator[](int dim_idx) const { return dim(dim_idx); }

    int ndims() const { return int(dims_.size()); }
    int elems() const {
        return utils::array_product(dims_.data(), dims_.size());
    }

    grid_info_t sub_grid(std::initializer_list<int> old_dim_idxs) const {
        grid_info_t ret(int(old_dim_idxs.size()));
        int new_dim_idx = 0;
        for (auto old_dim_idx : old_dim_idxs) {
            ret.dim(new_dim_idx) = dim(old_dim_idx);
            ret.off(new_dim_idx) = off(old_dim_idx);
            ret.idx(new_dim_idx) = idx(old_dim_idx);
            new_dim_idx++;
        }
        return ret;
    }

    grid_info_t resize(const std::vector<int> &new_dims) const {
        grid_info_t ret = *this;
        ret.dims_ = new_dims;
        return ret;
    }

    grid_info_t slice(int dim_idx, int new_off, int new_dim,
            const expr_t &new_idx, expr_t &new_idx_value) const {
        ir_assert(dim_idx >= 0 && dim_idx < ndims());
        ir_assert(new_dim > 0 && new_off >= 0);
        ir_assert(new_off + new_dim <= dims_[dim_idx]);

        grid_info_t ret = *this;
        ret.offs_[dim_idx] += new_off;
        ret.dims_[dim_idx] = new_dim;
        if (new_off > 0) {
            new_idx_value = ret.idxs_[dim_idx] - new_off;
            ret.idxs_[dim_idx] = new_idx;
        } else {
            new_idx_value = expr_t();
        }
        ret.parent_dims_ = (parent_dims_.empty() ? dims_ : parent_dims_);
        return ret;
    }

    grid_info_t halven(const expr_t &new_idx, int &dim_idx,
            expr_t &new_idx_value, bool first = true) const {
        for (int i = ndims() - 1; i >= 0; i--) {
            if (dim(i) == 1 || dim(i) % 2 != 0) continue;
            dim_idx = i;
            if (first) return slice(i, 0, dim(i) / 2, new_idx, new_idx_value);
            return slice(i, dim(i) / 2, dim(i) / 2, new_idx, new_idx_value);
        }
        return grid_info_t();
    }

    expr_t slice_condition() const {
        if (parent_dims_.empty()) return expr_t();
        expr_t ret(true);
        for (int i = 0; i < ndims(); i++) {
            auto &idx = idxs_[i];
            if (offs_[i] > 0) ret &= (idx >= 0);
            if (offs_[i] + dims_[i] < parent_dims_[i]) ret &= (idx < dims_[i]);
        }
        if (ret.is_equal(expr_t(true))) return expr_t();
        return ret;
    }

    std::string str() const {
        std::ostringstream oss;
        oss << ir_utils::make_seq_print_helper(dims_, " x ");
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    static std::vector<expr_t> make_idxs(const std::string &prefix, int n) {
        std::vector<expr_t> ret;
        for (int i = 0; i < n; i++)
            ret.push_back(
                    var_t::make(type_t::s32(), prefix + std::to_string(i)));
        return ret;
    }

    std::vector<int> dims_;
    std::vector<int> offs_;
    std::vector<expr_t> idxs_;

    std::vector<int> parent_dims_;
};

inline std::ostream &operator<<(
        std::ostream &out, const grid_info_t &grid_info) {
    out << grid_info.str();
    return out;
}
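
// Example (a sketch of the slicing semantics above; `y` and `v` are
// hypothetical variables):
//
//   grid_info_t g({4, 2}, "x"); // dims 4x2, index variables x0, x1
//   expr_t v;
//   auto s = g.slice(/*dim_idx=*/0, /*new_off=*/2, /*new_dim=*/2,
//           var_t::make(type_t::s32(), "y"), v);
//   // s.dim(0) == 2, s.off(0) == 2, s.idx(0) == y, and v == x0 - 2,
//   // i.e. the value the caller should bind y to. halven() applies
//   // slice() to the highest-indexed dimension of even size, taking
//   // either the first or the second half.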

class grid_splitter_t {
public:
    grid_splitter_t(const grid_info_t &grid)
        : grid_(grid), cur_idx_(grid.ndims() - 1), cur_stride_(1) {
        skip_size_1_dims();
        ir_assert(cur_idx_ >= 0);
    }

    int cur_block() const {
        if (is_empty()) return 1;

        return grid_.dim(cur_idx_) / cur_stride_;
    }

    bool is_empty() const { return cur_idx_ == -1; }

    bool can_pop_block(int size) const {
        if (is_empty()) return false;
        return cur_block() % size == 0;
    }

    expr_t pop_block(int size);

private:
    void skip_size_1_dims() {
        while (cur_idx_ >= 0 && grid_.dim(cur_idx_) == 1)
            cur_idx_--;
    }

    grid_info_t grid_;

    int cur_idx_;
    int cur_stride_;
};
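
// Usage sketch: pop_block() is defined in the implementation file; the
// returned index expressions below are illustrative.
//
//   grid_info_t grid({8, 4}, "idx"); // idx0 in [0, 8), idx1 in [0, 4)
//   grid_splitter_t s(grid);
//   s.cur_block();            // 4: splitting starts from the last dim
//   auto e0 = s.pop_block(2); // e.g. idx1 % 2
//   auto e1 = s.pop_block(2); // e.g. idx1 / 2
//   // Further pops continue with the next non-trivial dimension (idx0).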

enum class stride_kind_t {
    undef,
    fixed,
    unknown,
};

class stride_t {
public:
    stride_t() = default;

    stride_t(dim_t stride) : stride_t(stride_kind_t::fixed, stride) {}

    bool operator==(const stride_t &other) const {
        return (kind_ == other.kind_) && (stride_ == other.stride_);
    }

    bool operator!=(const stride_t &other) const { return !operator==(other); }

    size_t get_hash() const { return ir_utils::get_hash(kind_, stride_); }

    operator dim_t() const {
        ir_assert(kind_ == stride_kind_t::fixed);
        return stride_;
    }

    bool is_fixed() const { return kind_ == stride_kind_t::fixed; }

    bool is_unknown() const { return kind_ == stride_kind_t::unknown; }

    stride_t &operator*=(const stride_t &other) {
        if (is_fixed() && other.is_fixed()) {
            stride_ *= other.stride_;
        } else {
            set_unknown();
        }
        return *this;
    }

    stride_t &operator/=(const stride_t &other) {
        if (is_fixed() && other.is_fixed()) {
            stride_ /= other.stride_;
        } else {
            set_unknown();
        }
        return *this;
    }

    std::string str() const {
        std::ostringstream oss;
        if (is_fixed()) {
            oss << stride_;
        } else {
            oss << "(unknown)";
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

    static stride_t unknown() { return stride_t(stride_kind_t::unknown); }

private:
    stride_t(stride_kind_t kind, dim_t stride = 0)
        : kind_(kind), stride_(stride) {}

    void set_unknown() {
        kind_ = stride_kind_t::unknown;
        stride_ = 0;
    }

    stride_kind_t kind_ = stride_kind_t::undef;
    dim_t stride_ = 0;
};

inline std::ostream &operator<<(std::ostream &out, const stride_t &stride) {
    out << stride.str();
    return out;
}

inline stride_t operator*(const stride_t &a, const stride_t &b) {
    stride_t tmp = a;
    return tmp *= b;
}

inline stride_t operator*(const stride_t &a, dim_t b) {
    return a * stride_t(b);
}

inline stride_t operator*(dim_t a, const stride_t &b) {
    return stride_t(a) * b;
}
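
// Example: fixed strides combine arithmetically, while any operation
// involving an unknown stride yields an unknown stride.
//
//   stride_t(4) * 2;                   // fixed stride 8
//   stride_t(4) * stride_t::unknown(); // unknown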

struct block_t {
    block_t() = default;

    block_t(int dim_idx, dim_t block, const stride_t &stride)
        : dim_idx(dim_idx), block(block), stride(stride) {}

    bool is_equal(const block_t &other) const {
        return (dim_idx == other.dim_idx) && (block == other.block)
                && (stride == other.stride);
    }

    size_t get_hash() const {
        return ir_utils::get_hash(dim_idx, block, stride);
    }

    std::string str() const {
        std::ostringstream oss;
        oss << "block_t(dim_idx = " << dim_idx;
        oss << ", block = " << block;
        oss << ", stride = " << stride;
        oss << ")";
        return oss.str();
    }

    IR_DEFINE_DUMP()

    bool is_empty() const { return dim_idx == -1; }

    int dim_idx = -1; // Dimension index.
    dim_t block; // Block size.
    stride_t stride; // Stride between elements of the block.
};

inline std::ostream &operator<<(std::ostream &out, const block_t &b) {
    out << b.str();
    return out;
}
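
// Example: a plain row-major 4x8 layout decomposes into two blocks,
// ordered from innermost to outermost:
//
//   block_t(1, 8, 1) // dimension 1, 8 elements, stride 1
//   block_t(0, 4, 8) // dimension 0, 4 elements, stride 8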

class layout_t {
public:
    static const int max_ndims = 16;

    layout_t() : type_(type_t::undef()), ndims_(0), offset_(0) {
        sanity_check();
    }

    layout_t(const type_t &type, const expr_t &offset,
            const std::vector<std::pair<int, dim_t>> &parts,
            const std::vector<dim_t> &dims = {}, bool do_normalize = true);

    layout_t(const type_t &type, const expr_t &offset,
            const std::string &format, const std::vector<dim_t> &dims = {},
            bool do_normalize = true)
        : layout_t(type, offset, parse_format(format, int(dims.size())), dims,
                do_normalize) {}

    layout_t(const memory_desc_wrapper &mdw, const std::string &format,
            bool do_normalize = true)
        : layout_t(mdw.data_type(), mdw.offset0(), format,
                std::vector<dim_t>(mdw.dims(), mdw.dims() + mdw.ndims()),
                do_normalize) {}

    layout_t(const memory_desc_wrapper &mdw, const char *format,
            bool do_normalize = true)
        : layout_t(mdw, std::string(format), do_normalize) {}

    layout_t(const memory_desc_wrapper &mdw, bool do_normalize = true);

    layout_t(const type_t &type, const expr_t &offset,
            const std::vector<dim_t> &dims, bool do_normalize = true)
        : type_(type), ndims_(int(dims.size())), offset_(offset) {
        dim_t stride = 1;
        for (int i = ndims_ - 1; i >= 0; i--) {
            blocks_.emplace_back(i, dims[i], stride);
            stride *= dims[i];
        }
        if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
        sanity_check();
    }

    layout_t(const type_t &type, int ndims, const expr_t &offset,
            const std::vector<block_t> &blocks, bool do_normalize = true)
        : type_(type), ndims_(ndims), offset_(offset), blocks_(blocks) {
        if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
        sanity_check();
    }

    layout_t(const type_t &type, const expr_t &offset, const layout_t &other,
            bool do_normalize)
        : layout_t(type, other.ndims(), offset, other.blocks(), do_normalize) {}

    bool is_empty() const { return ndims_ == 0; }

    int ndims() const { return ndims_; }

    dim_t elems() const {
        dim_t ret = 1;
        for (auto &b : blocks_)
            ret *= b.block;
        return ret;
    }

    // Storage size in bytes.
    dim_t size() const {
        if (is_empty()) return 0;
        dim_t max_stride = 1;
        for (auto &b : blocks_) {
            max_stride = std::max(max_stride, dim_t(b.block * b.stride));
        }
        return max_stride * type().size();
    }

    template <typename T = expr_t>
    T offset(
            const std::vector<T> &args = {}, bool ignore_offset = false) const {
        if (args.empty()) return expr_cast<T>(offset_);

        ir_assert(int(args.size()) == ndims()) << "Dimensions do not match.";

        T off = 0;
        auto _args = args;
        for (auto &eb : enumerated_blocks()) {
            auto &b = eb.second;
            auto &idx = _args[b.dim_idx];
            if (ir_utils::is_equal(idx, T(0))) continue;

            // Do not use modulus for outermost blocks.
            auto i = is_outermost(eb) ? idx : (idx % b.block);
            off = i * dim_t(b.stride) + off;
            idx /= b.block;
        }
        if (ignore_offset) return off;

        T off0 = expr_cast<T>(offset_);
        return off0 + off;
    }
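
    // Worked example (a sketch): for a layout with format "4b4a2b"
    // (blocks from innermost to outermost: 2b stride 1, 4a stride 2,
    // 4b stride 8) the call offset({1, 5}) computes
    // (5 % 2) * 1 + 1 * 2 + (5 / 2) * 8 = 19; the modulus is skipped
    // for the outermost block of each dimension.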

    const type_t &type() const { return type_; }

    std::vector<dim_t> dims() const {
        std::vector<dim_t> dims(ndims(), 1);
        for (auto &b : blocks_) {
            dims[b.dim_idx] *= b.block;
        }
        return dims;
    }

    dim_t dim(int dim_idx) const {
        dim_t ret = 1;
        for (auto &b : blocks_) {
            if (b.dim_idx == dim_idx) ret *= b.block;
        }
        return ret;
    }

    const std::vector<block_t> &blocks() const { return blocks_; }

    dim_t inner_block(int dim_idx, bool skip_outer = true) const {
        dim_t block0 = -1;
        int nblocks = 0;
        for (auto &b : blocks_) {
            if (b.dim_idx == dim_idx) {
                nblocks++;
                if (!skip_outer) return b.block;
                if (block0 == -1) block0 = b.block;
            }
        }
        return nblocks > 1 ? block0 : 1;
    }

    void set_offset(const expr_t &offset) { offset_ = offset; }

    bool is_strictly_equal(const layout_t &other, bool compare_offset = true,
            bool compare_strides = true) const {
        if (!type_.is_equal(other.type_)) return false;
        if (compare_offset && !offset_.is_equal(other.offset_)) return false;
        if (blocks_.size() != other.blocks_.size()) return false;
        for (size_t i = 0; i < blocks_.size(); i++) {
            auto &b0 = blocks_[i];
            auto &b1 = other.blocks_[i];
            if (b0.dim_idx != b1.dim_idx) return false;
            if (b0.block != b1.block) return false;
            if (compare_strides && b0.stride != b1.stride) return false;
        }
        return true;
    }

    bool operator==(const layout_t &other) const { return is_equal(other); }

    bool operator!=(const layout_t &other) const { return !operator==(other); }
    bool operator<=(const layout_t &other) const {
        if (!type_.is_equal(other.type_)) return false;
        const auto other_blocks = other.normalize().blocks();
        const auto self_blocks = normalize().blocks();
        if (self_blocks.size() > other_blocks.size()) return false;
        if (self_blocks.size() == 0) return true;

        int i = 0;
        for (; i < (int)self_blocks.size() - 1; i++) {
            if (!self_blocks[i].is_equal(other_blocks[i])) return false;
        }
        return (self_blocks[i].dim_idx == other_blocks[i].dim_idx
                && self_blocks[i].stride == other_blocks[i].stride
                && other_blocks[i].block % self_blocks[i].block == 0);
    }
    bool operator>=(const layout_t &other) const { return other <= *this; }

    bool is_equal(const layout_t &other, bool compare_offset = true) const {
        return normalize().is_strictly_equal(other.normalize(), compare_offset);
    }

    size_t get_hash() const {
        return ir_utils::get_hash(type_, ndims_, offset_, blocks_);
    }

    template <typename T>
    T operator()(const std::vector<T> &args) const {
        return offset(args);
    }

    template <typename T = expr_t>
    T offset_in_bytes(
            const std::vector<T> &args = {}, bool ignore_offset = false) const {
        return offset(args, ignore_offset) * type().size();
    }

    std::string desc_str(bool dnnl_style = false) const {
        if (is_empty()) return "(nil)";
        if (!dnnl_style && blocks_.empty())
            return "(scalar:" + type().str() + ")";
        std::string ret;
        stride_t dense_stride(1);
        std::vector<bool> seen(ndims());
        for (auto &eb : enumerated_blocks()) {
            auto &b = eb.second;
            std::string b_str;
            if (dnnl_style && is_outermost(eb)) {
                b_str.append(1, (seen[b.dim_idx] ? 'A' : 'a') + b.dim_idx);
            } else {
                b_str = std::to_string(b.block);
                b_str.append(1, 'a' + b.dim_idx);
            }
            if (!dnnl_style) {
                if (b.stride.is_unknown()) {
                    b_str.append(1, '?');
                } else if (b.stride != dense_stride) {
                    b_str.append(1, '*');
                }
            }
            ret = b_str + ret;
            dense_stride = b.stride * b.block;
            seen[b.dim_idx] = true;
        }
        ret += ":" + type().str();
        return ret;
    }

    std::string str() const {
        if (is_empty()) return "(nil)";
        std::ostringstream oss;
        oss << desc_str();
        if (!has_zero_offset()) oss << " offset: " << offset_;
        return oss.str();
    }

    IR_DEFINE_DUMP()

    memory_desc_t to_dnnl(const dim_t *dims_hint) const;

    // Returns a vector of <block index, block> pairs.
    // The innermost block (first) has index 0.
    std::vector<std::pair<int, block_t>> enumerated_blocks() const {
        std::vector<std::pair<int, block_t>> ret;
        for (int i = 0; i < int(blocks_.size()); i++) {
            ret.emplace_back(i, blocks_[i]);
        }
        return ret;
    }

    std::vector<dim_t> strides(int dim_idx) const {
        std::vector<dim_t> ret;
        for (auto &b : blocks_)
            if (b.dim_idx == dim_idx) ret.push_back(b.stride);
        return ret;
    }

    // eb is <block index, block> pair, see enumerated_blocks().
    bool is_outermost(const std::pair<int, block_t> &eb) const {
        return is_outermost(eb, blocks_);
    }

    bool is_plain() const {
        std::vector<bool> seen(ndims());
        for (auto &b : blocks_) {
            if (seen[b.dim_idx]) return false;
            seen[b.dim_idx] = true;
        }
        return true;
    }

    bool has_zero_offset() const { return offset_.is_equal(expr_t(0)); }

    bool has_unknown_strides() const {
        for (auto &b : blocks_)
            if (b.stride.is_unknown()) return true;
        return false;
    }

    // Returns a canonical representation of the layout:
    // - Size one blocks are removed
    // - Consecutive dense blocks are merged
    layout_t normalize() const {
        auto blocks = normalize_blocks(ndims(), blocks_);
        return layout_t(type(), ndims(), offset(), blocks);
    }

    layout_t transpose() const {
        if (ndims() != 2) ir_error_not_expected();

        // Flip: 0 -> 1, 1 -> 0.
        auto blocks = blocks_;
        for (auto &b : blocks)
            b.dim_idx ^= 1;

        return layout_t(type(), ndims(), offset(), blocks);
    }

    // Returns a new (sub-)layout that fully contains the passed sub-tensor.
    // Strides are kept unchanged.
    // Assumption: the original layout can be tiled by the passed sub-tensor.
    // For example: XaYb4a2b can be tiled into 2x2 sub-tensors but it's not
    // possible to tile it into 3x2 sub-tensors.
    layout_t map(const tensor_t &tensor) const;

    layout_t reinterpret(
            const type_t &new_type, bool do_normalize = true) const;

    layout_t retype(const type_t &new_type) const {
        auto ret = *this;
        ret.type_ = new_type;
        return ret;
    }

    bool is_dense() const {
        stride_t stride = 1;
        for (auto &b : blocks_) {
            if (b.stride != stride) return false;
            stride *= b.block;
        }
        return true;
    }

    layout_t innermost_block_layout() const {
        int block_count[layout_t::max_ndims] = {0};
        for (auto &b : blocks_)
            block_count[b.dim_idx]++;

        std::vector<block_t> inner_blocks;

        stride_t stride = 1;
        for (auto &b : blocks_) {
            if (b.stride != stride) break; // Not dense anymore.
            if (block_count[b.dim_idx] == 1) break; // Outer block.
            stride *= b.block;
            ir_assert(block_count[b.dim_idx] > 0);
            block_count[b.dim_idx]--;
            inner_blocks.push_back(b);
        }
        return layout_t(type(), ndims(), 0, inner_blocks);
    }

    // Returns a packed layout where all blocks are contiguous, without gaps.
    layout_t make_dense() const {
        dim_t stride = 1;
        auto new_blocks = blocks_;
        for (auto &b : new_blocks) {
            b.stride = stride;
            stride *= b.block;
        }
        return layout_t(type(), ndims(), 0, new_blocks);
    }

    layout_t make_strided(int _stride, int block_idx = 0) const {
        dim_t cur_stride = 1;
        auto new_blocks = blocks_;
        for (int i = 0; i < (int)new_blocks.size(); i++) {
            auto &b = new_blocks[i];
            if (i >= block_idx) {
                b.stride = (i == block_idx ? _stride : cur_stride);
            }
            cur_stride = b.stride * b.block;
        }
        return layout_t(type(), ndims(), 0, new_blocks);
    }

    // Returns an equivalent layout where the specified block is split into two.
    // block0 - inner block size.
    // block1 - outer block size.
    layout_t split_block(const std::pair<int, block_t> &eb, dim_t block0,
            dim_t block1) const;

    // Splits blocks so that they can be used to form `multi_blocks` without
    // crossing the block boundaries. `multi_blocks` are ordered from innermost
    // to outermost. Returns an empty layout if such a split is not possible.
    // Example (all blocks are ordered from innermost to outermost):
    //     Input blocks:  [4, 4, 2]
    //     Multi-blocks:  [8, 2]
    //     Output blocks: [4, 2, 2, 2]
    layout_t split_into_multi_blocks(
            const std::vector<dim_t> &multi_blocks) const;

    layout_t add_outer_block(
            int dim_idx, dim_t block, dim_t stride = -1) const {
        if (stride == -1) stride = elems();
        ir_assert(stride >= elems());
        ir_assert(dim_idx < ndims());
        auto new_blocks = blocks();
        new_blocks.emplace_back(dim_idx, block, stride);
        return layout_t(type(), ndims(), offset(), new_blocks);
    }

    // Returns a tensor corresponding to the biggest innermost sub-layout
    // such that:
    // 1) It consists of consecutive blocks only.
    // 2) It contains at most max_tile_elems elements.
    // 3) It is dense if is_dense_tile is true.
    tensor_t split_into_max_tile(
            dim_t max_tile_elems, bool is_dense_tile) const;

    tensor_t split(const grid_info_t &grid_info) const {
        tensor_t min_tile;
        std::vector<int> cur_dims(grid_info.ndims(), 1);

        for (int iter = 0; iter < grid_info.elems(); iter++) {
            for (int i = 0; i < grid_info.ndims(); i++) {
                if (++cur_dims[i] <= grid_info.dim(i)) break;
                cur_dims[i] = 1;
            }
            auto sub_grid = grid_info.resize(cur_dims);
            auto tile = split_exact(sub_grid);
            if (tile.is_empty()) continue;
            if (min_tile.is_empty() || tile.elems() < min_tile.elems()) {
                min_tile = tile;
            }
        }
        return min_tile;
    }

    tensor_t split_exact(const grid_info_t &grid) const {
        std::vector<dim_t> tile_dims(ndims(), 1);
        if (elems() % grid.elems() != 0) return tensor_t();

        dim_t cur_elems_per_tile = 1;
        dim_t elems_per_tile = elems() / grid.elems();
        for (auto &b : blocks()) {
            dim_t block
                    = std::min(b.block, elems_per_tile / cur_elems_per_tile);
            tile_dims[b.dim_idx] *= block;
            cur_elems_per_tile *= block;
        }
        if (cur_elems_per_tile != elems_per_tile) return tensor_t();

        return split(tensor_t(tile_dims), grid);
    }

    tensor_t split(const tensor_t &tile, const grid_info_t &grid,
            std::vector<block_t> *outer_blocks = nullptr) const {
        ir_assert(ndims() == tile.ndims())
                << "Number of dimensions doesn't match.";
        ir_assert(tile.has_zero_start());

        if (outer_blocks) outer_blocks->resize(0);

        if (grid.elems() == 1) return tile;

        dim_t total_elems = elems();
        dim_t tile_elems = tile.elems();

        grid_splitter_t grid_splitter(grid);
        ir_assert(tile_elems * grid.elems() == total_elems)
                << "Tile/grid dimensions do not match.";
        MAYBE_UNUSED(total_elems);
        MAYBE_UNUSED(tile_elems);

        std::vector<dim_t> dims(tile.ndims(), 1);
        std::vector<expr_t> start(tile.ndims(), 0);
        std::vector<dim_t> rem_dims = tile.dims();
        for (auto &eb : enumerated_blocks()) {
            auto &b = eb.second;
            if (b.block == 1) continue;

            dim_t &e = rem_dims[b.dim_idx];
            if (e > 1) {
                if (e % b.block == 0) {
                    e /= b.block;
                } else if (b.block % e == 0) {
                    auto tmp_layout = split_block(eb, e, b.block / e);
                    return tmp_layout.split(tile, grid, outer_blocks);
                } else {
                    return tensor_t();
                }
            } else {
                dim_t next_chunk
                        = math::gcd(b.block, grid_splitter.cur_block());
                if (b.block == next_chunk) {
                    auto idx = grid_splitter.pop_block(next_chunk);
                    start[b.dim_idx] += idx * dims[b.dim_idx];
                    if (outer_blocks) outer_blocks->push_back(b);
                } else if (b.block % next_chunk == 0 && next_chunk != 1) {
                    auto tmp_layout
                            = split_block(eb, next_chunk, b.block / next_chunk);
                    return tmp_layout.split(tile, grid, outer_blocks);
                } else {
                    return tensor_t();
                }
            }
            dims[b.dim_idx] *= b.block;
        }
        return tensor_t(tile.dims(), start);
    }

    // Iterates through tiles of the layout, calling `f` with relative offsets
    // for each tile. The iteration order is defined by the layout blocks -
    // absolute 1D offsets are increasing between callback calls.
    template <typename F>
    void for_each_tile(const tensor_t &tile, const F &f) const {
        ir_assert(tile.ndims() == ndims());
        ir_assert(tile.has_zero_start());
        for (int i = 0; i < ndims(); i++) {
            ir_assert(dim(i) % tile.dims()[i] == 0);
        }

        int nblocks = int(blocks().size());
        std::vector<dim_t> sub_blocks(nblocks);
        for (int i = 0; i < nblocks; i++)
            sub_blocks[i] = blocks()[i].block;

        for (int i = 0; i < ndims(); i++) {
            dim_t dim = tile.dims()[i];
            for (auto &eb : enumerated_blocks()) {
                auto &b = eb.second;
                if (b.dim_idx != i) continue;
                int block_idx = eb.first;
                if (b.block >= dim) {
                    ir_assert(b.block % dim == 0);
                    sub_blocks[block_idx] = b.block / dim;
                    break;
                }
                sub_blocks[block_idx] = 1;
                ir_assert(dim % b.block == 0);
                dim /= b.block;
            }
        }

        int ntiles = int(elems() / tile.elems());

        std::vector<dim_t> sub_block_idxs(nblocks);
        for (int i = 0; i < ntiles; i++) {
            // Convert sub-block indices to dimension indices.
            std::vector<dim_t> dims(ndims(), 1);
            std::vector<dim_t> start(ndims());
            for (int j = 0; j < nblocks; j++) {
                auto &b = blocks()[j];
                dim_t k = sub_block_idxs[j]
                        * (blocks()[j].block / sub_blocks[j]);
                start[b.dim_idx] += dims[b.dim_idx] * k;
                dims[b.dim_idx] *= b.block;
            }

            // Pass dimension offsets to the callback.
            f(start);

            // Move to the next vector of indices.
            for (int j = 0; j < nblocks; j++) {
                auto &idx = sub_block_idxs[j];
                if (idx + 1 < sub_blocks[j]) {
                    idx++;
                    break;
                }
                idx = 0;
            }
        }
    }
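
    // Example: for a plain row-major 4x8 layout and a 2x2 tile, the
    // callback receives start offsets (0,0), (0,2), (0,4), (0,6),
    // (2,0), (2,2), ... - the corresponding absolute 1D offsets
    // 0, 2, 4, 6, 16, 18, ... increase monotonically.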

    // eb is <block index, block> pair, see enumerated_blocks().
    static bool is_outermost(const std::pair<int, block_t> &eb,
            const std::vector<block_t> &blocks) {
        int dim_idx = eb.second.dim_idx;
        for (int i = 0; i < int(blocks.size()); i++) {
            if (blocks[i].dim_idx == dim_idx && i > eb.first) return false;
        }
        return true;
    }

    // Assume that layouts are normalized.
    static void align_layouts(layout_t &a, layout_t &b);

    static std::vector<block_t> normalize_blocks(int ndims,
            const std::vector<block_t> &blocks,
            bool remove_size_1_blocks = true) {
        auto new_blocks = blocks;

        // Remove blocks of size 1.
        if (remove_size_1_blocks) {
            for (auto it = new_blocks.begin(); it != new_blocks.end();) {
                if (it->block == 1) {
                    it = new_blocks.erase(it);
                } else {
                    ++it;
                }
            }
        }

        // Merge same dimension blocks.
        block_t prev_b;
        prev_b.dim_idx = -1;
        for (auto it = new_blocks.begin(); it != new_blocks.end();) {
            if (it->dim_idx == prev_b.dim_idx
                    && it->stride == (prev_b.stride * prev_b.block)) {
                auto &b = *(it - 1);
                b.block *= it->block;
                prev_b = b;
                it = new_blocks.erase(it);
            } else {
                prev_b = *it;
                ++it;
            }
        }

        return new_blocks;
    }
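
    // Example: for blocks 4a (stride 1), 2a (stride 4), 1b (stride 8)
    // normalize_blocks() drops the size-1 block and merges the two
    // dense 'a' blocks, leaving a single 8a (stride 1) block.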

    // Reinterprets layouts to wider data type (up to 4 bytes).
    // Example: 16a16b (s8 type) -> 16a4b (s32 type)
    static bool try_reinterpret_to_wider_type(layout_t &src, layout_t &dst,
            const tensor_t &tile = {}, bool do_update = true,
            int *new_size_out = nullptr) {
        if (src.blocks().empty() || dst.blocks().empty()) return false;
        if (src.type() != dst.type()) return false;

        auto &s0 = src.blocks()[0];
        auto &d0 = dst.blocks()[0];
        if (s0.dim_idx != d0.dim_idx) return false;
        if (int(s0.stride) != 1) return false;
        if (int(d0.stride) != 1) return false;

        int old_size = src.type().size();
        int s0_old_size = int(s0.block) * old_size;
        int d0_old_size = int(d0.block) * old_size;

        int new_size = math::gcd(s0_old_size, d0_old_size);
        new_size = math::gcd(new_size, 4); // Try types up to 4 bytes.
        if (new_size <= old_size) return false;

        auto tile_ok = [&](const layout_t &l) {
            if (tile.is_empty()) return true;
            int factor = new_size / old_size;
            if (tile(l.blocks()[0].dim_idx) % factor != 0) return false;
            return true;
        };

        auto strides_ok = [&](const layout_t &l) {
            for (int i = 1; i < int(l.blocks().size()); i++) {
                auto &b = l.blocks()[i];
                if (int(b.stride) * old_size % new_size != 0) return false;
            }
            return true;
        };

        while (new_size > old_size) {
            bool ok = true;
            ok &= (tile_ok(src) && tile_ok(dst));
            ok &= (strides_ok(src) && strides_ok(dst));
            if (ok) {
                if (do_update) {
                    src = src.reinterpret(type_t::s(new_size * 8));
                    dst = dst.reinterpret(type_t::s(new_size * 8));
                }
                if (new_size_out) *new_size_out = new_size;
                return true;
            }
            new_size /= 2;
        }
        return false;
    }

private:
    // Returns vector of <dimension index, block size> pairs.
    static std::vector<std::pair<int, dim_t>> parse_format(
            const std::string &format, int ndims_hint);

    // Returns vector of <dimension letter, block size> pairs.
    static std::vector<std::pair<char, dim_t>> parse_letter_blocks(
            const std::string &format);

    void sanity_check() const;

    // Data type of the layout.
    type_t type_;

    // Number of dimensions.
    int ndims_;

    // Offset to the start of the layout (in elements of type).
    expr_t offset_;

    // Blocks ordered from innermost to outermost.
    std::vector<block_t> blocks_;
};

// Helper class to incrementally grow a sub-layout of the given layout. Each
// step adds the smallest factor of the next remaining block. Used to find
// the minimal tile between two layouts that is innermost for both layouts.
class layout_iterator_t {
public:
    layout_iterator_t(const layout_t &l) : l_(l), block_idx_(-1), block_(1) {}

    bool has_next() const {
        dim_t b = block_;
        int b_idx = block_idx_;
        while (b == 1) {
            b_idx++;
            if (b_idx >= int(l_.blocks().size())) return false;
            b = int(l_.blocks()[b_idx].block);
        }
        return true;
    }

    layout_iterator_t &operator++() {
        ir_assert(has_next());
        while (block_ == 1) {
            block_idx_++;
            block_ = int(l_.blocks()[block_idx_].block);
        }
        // Find smallest factor.
        for (int factor = 2; factor <= int(block_); factor++) {
            if (block_ % factor == 0) {
                block_ /= factor;
                return *this;
            }
        }

        ir_error_not_expected();
        return *this;
    }

    tensor_t tile() const {
        std::vector<dim_t> dims(l_.ndims(), 1);
        for (int i = 0; i <= block_idx_; i++) {
            auto &b = l_.blocks()[i];
            int b_block = b.block;
            if (i == block_idx_) b_block /= block_;
            dims[b.dim_idx] *= b_block;
        }
        return tensor_t(dims);
    }

    int nblocks() const { return block_idx_ + 1; }

    layout_t outer_layout() const {
        auto &blocks = l_.blocks();
        std::vector<block_t> outer_blocks;
        if (block_ > 1) {
            auto &b = blocks[block_idx_];
            outer_blocks.push_back(b);
            outer_blocks[0].block = block_;
            outer_blocks[0].stride = b.stride * (b.block / block_);
        }
        outer_blocks.insert(outer_blocks.end(),
                blocks.begin() + (block_idx_ + 1), blocks.end());
        return layout_t(l_.type(), l_.ndims(), l_.offset(), outer_blocks);
    }

private:
    const layout_t &l_;

    int block_idx_;
    dim_t block_;
};
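
// Usage sketch: for a 1D layout with a single 8a block the iterator
// yields tiles of 2, 4 and finally 8 elements (each step peels the
// smallest factor of the current block):
//
//   layout_iterator_t it(l);
//   while (it.has_next()) {
//       ++it;
//       it.tile(); // 2, then 4, then 8 elements
//   }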

inline std::ostream &operator<<(std::ostream &out, const layout_t &layout) {
    out << layout.str();
    return out;
}

class mask_tensor_t {
public:
    mask_tensor_t() = default;

    mask_tensor_t(const layout_t &layout)
        : layout_(layout), masks_(layout.elems(), -1) {
        ir_assert(layout.is_dense());
    }

    mask_tensor_t(const layout_t &layout, const std::vector<int> &masks,
            const object_eq_map_t<expr_t, int> &mask2ids,
            const std::vector<expr_t> &id2masks)
        : layout_(layout)
        , masks_(masks)
        , mask2ids_(mask2ids)
        , id2masks_(id2masks) {
        ir_assert(int(masks.size()) == elems()) << "Incompatible size.";
    }

    const type_t &type() const { return layout_.type(); }

    const layout_t &layout() const { return layout_; }

    dim_t elems() const { return layout_.elems(); }

    void set_mask(dim_t off, const expr_t &mask) {
        ir_assert(0 <= off && off < elems()) << "Incorrect offset.";
        if (mask.is_empty()) return;

        auto ret = mask2ids_.insert({mask, int(mask2ids_.size())});
        int id = ret.first->second;
        masks_[off] = id;

        if (ret.second) id2masks_.push_back(mask);
    }

    const expr_t &mask(dim_t off) const {
        ir_assert(0 <= off && off < elems());
        return id2masks_[masks_[off]];
    }

    void simplify(const constraint_set_t &cset) {
        for (auto &mask : id2masks_) {
            auto new_mask = jit::simplify(mask, cset);
            // Some complex expressions need more than one simplify() call.
            int max_tries = 5;
            for (int i = 0; i < max_tries; i++) {
                mask = new_mask;
                new_mask = jit::simplify(new_mask, cset);
                if (new_mask.is_equal(mask)) break;
            }
        }
        mask2ids_.clear();
        for (int i = 0; i < int(id2masks_.size()); i++) {
            auto ret = mask2ids_.insert({id2masks_[i], i});
            if (!ret.second) {
                for (auto &m : masks_)
                    if (m == i) m = ret.first->second;
            }
        }
    }

    mask_tensor_t map(const tensor_t &tile) const {
        auto tile_start = expr_cast<dim_t>(tile.start());
        auto sub_layout = layout_.map(tensor_t(tile.dims()));
        mask_tensor_t sub_mask(sub_layout);
        ir_utils::for_each(
                tile.dims(), [&](const std::vector<dim_t> &sub_start) {
                    dim_t sub_off = sub_layout(sub_start);
                    dim_t off = layout_(tile_start) + layout_(sub_start);
                    sub_mask.set_mask(sub_off, mask(off));
                });
        return sub_mask;
    }

    mask_tensor_t reinterpret(const type_t &new_type) const {
        ir_assert(!is_empty()) << "Can't reinterpret.";
        dim_t bytes = elems() * type().size();
        if (bytes % new_type.size() != 0 && bytes > new_type.size())
            return mask_tensor_t();
        int new_mask_size = std::max((int)(bytes / new_type.size()), 1);
        std::vector<int> new_masks(new_mask_size);
        for (dim_t i = 0; i < bytes; i += new_type.size()) {
            int mask_id = std::numeric_limits<int>::max();
            for (int j = 0; j < new_type.size() && j < bytes; j++) {
                int cur_mask_id = masks_[(i + j) / type().size()];
                if (mask_id >= int(masks_.size())) {
                    mask_id = cur_mask_id;
                } else if (mask_id != cur_mask_id) {
                    // Mask is not consistent, can't reinterpret.
                    return mask_tensor_t();
                }
            }
            ir_assert(0 <= mask_id && mask_id < int(masks_.size()));
            new_masks[i / new_type.size()] = mask_id;
        }
        dim_t new_elems = utils::div_up(bytes, new_type.size());
        layout_t _1d_layout(new_type, 0, std::vector<dim_t> {new_elems});
        return mask_tensor_t(_1d_layout, new_masks, mask2ids_, id2masks_);
    }

    expr_t to_expr(int nmasks) const {
        if (elems() % nmasks != 0) return expr_t();

        std::vector<expr_t> vec(nmasks);
        for (int i = 0; i < elems(); i++) {
            auto &channel_mask = vec[i % nmasks];
            auto &cur_mask = id2masks_[masks_[i]];
            if (channel_mask.is_empty()) {
                channel_mask = cur_mask;
                continue;
            }
            if (!channel_mask.is_equal(cur_mask)) return expr_t();
        }
        auto e = shuffle_t::make(vec);
        e = jit::simplify(e);
        e = jit::simplify_propagate_shuffle(e);
        return e;
    }

    bool is_empty() const { return layout_.is_empty(); }

    std::string str() const {
        std::ostringstream oss;
        for (int i = 0; i < int(elems()); i++) {
            if (i != 0) oss << std::endl;
            oss << "mask #" << i << ": ";
            if (masks_[i] == -1) {
                oss << "(nil)";
            } else {
                oss << id2masks_[masks_[i]];
            }
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    layout_t layout_;
    std::vector<int> masks_;

    object_eq_map_t<expr_t, int> mask2ids_;
    std::vector<expr_t> id2masks_;
};

inline std::ostream &operator<<(
        std::ostream &out, const mask_tensor_t &mask_tensor) {
    out << mask_tensor.str();
    return out;
}
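
// Note on to_expr(): for an 8-element mask tensor, to_expr(4) succeeds
// only when elements i and i + 4 carry the same mask for every i; the
// result is then a 4-wide shuffle of the per-channel masks.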

class tdim_info_t {
public:
    tdim_info_t() = default;

    tdim_info_t(const expr_t &expr, const expr_t &mask)
        : expr_(expr), mask_(mask) {}

    int nvargs() const { return nvargs_; }

    const expr_t &expr() const { return expr_; }

    const expr_t &mask() const { return mask_; }

    void set_mask(const expr_t &value) { mask_ = value; }

    expr_t mask(const expr_t &tvalue, const std::vector<expr_t> &vvars,
            const std::vector<expr_t> &vvalues) const {
        auto ret = substitute(mask_, placeholder_var(), tvalue);
        for (int i = 0; i < int(vvars.size()); i++) {
            if (contains_object(ret, vvars[i])) {
                ret = substitute(ret, vvars[i], vvalues[i]);
            }
        }
        return ret;
    }

    int vidx(int arg_idx) const {
        ir_assert(arg_idx < nvargs());
        return vidxs_[arg_idx];
    }

    stride_t vstride(int arg_idx) const {
        ir_assert(arg_idx < nvargs());
        return vstrides_[arg_idx];
    }

    bool is_empty() const { return expr_.is_empty(); }

    bool is_identity() const { return is_var(expr_); }

    bool is_fixed_stride(int arg_idx) const {
        ir_assert(arg_idx < nvargs());
        return vstrides_[arg_idx].is_fixed();
    }

    void add_vvar(int vidx, const expr_t &varg) {
        ir_assert(nvargs_ + 1 <= max_nvargs);
        vidxs_[nvargs_] = vidx;
        vstrides_[nvargs_] = compute_stride(expr_, nvargs_, varg);
        nvargs_++;
    }

    static const expr_t &placeholder_var() {
        static thread_local expr_t ph_var = var_t::make(type_t::s32(), "_ph");
        return ph_var;
    }

private:
    static const int max_nvargs = 2;

    static stride_t compute_stride(const expr_t &e, int idx, const expr_t &var);

    expr_t expr_;

    int nvargs_ = 0;
    std::array<stride_t, max_nvargs> vstrides_;
    std::array<int, max_nvargs> vidxs_;
    expr_t mask_;
};
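
// Example (a sketch; compute_stride() is defined in the implementation
// file and is assumed to derive the coefficient of each variable): a
// tensor dimension computed as x * 4 + y from view variables x and y,
// added via add_vvar() in that order, has nvargs() == 2 with fixed
// vstride() values 4 for x and 1 for y. is_identity() is true only when
// the expression is a bare view variable.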

class view_t {
public:
    view_t() = default;

    view_t(const std::vector<expr_t> &vvars, int ntdims)
        : vvars_(vvars)
        , vdims_(vvars.size())
        , vstart_(vvars.size())
        , tdims_(ntdims) {}

    // Constructs view from a layout.
    explicit view_t(const layout_t &layout,
            const std::vector<expr_t> &_vvars = {},
            uint32_t bound_check_mask = 0)
        : view_t(layout, _vvars, layout.dims(), bound_check_mask) {}

    view_t(const layout_t &layout, const std::vector<expr_t> &_vvars,
            const std::vector<dim_t> &_vdims, uint32_t bound_check_mask)
        : vvars_(_vvars)
        , vdims_(_vdims)
        , vstart_(layout.ndims(), 0)
        , tdims_(layout.ndims())
        , tlayout_(layout) {
        if (vvars_.empty()) vvars_ = create_vvars(layout.ndims());
        for (int i = 0; i < nvdims(); i++) {
            expr_t i_mask;
            if ((bound_check_mask & (1 << i)) != 0)
                i_mask = (placeholder_var() < layout.dim(i));
            set_tdim(i, vvars_[i], i_mask);
        }
    }

    const std::vector<expr_t> &vvars() const { return vvars_; }

    const std::vector<dim_t> &vdims() const { return vdims_; }

    std::vector<expr_t> vstart() const { return vstart_; }

    expr_t vstart(int vidx) const { return vstart_[vidx]; }

    const layout_t &tlayout() const { return tlayout_; }

    int nvdims() const { return int(vdims_.size()); }

    int ntdims() const { return int(tdims_.size()); }

    dim_t velems() const {
        dim_t ret = 1;
        for (int i = 0; i < nvdims(); i++)
            ret *= vdims_[i];
        return ret;
    }

    const expr_t &vvar(int idx) const {
        ir_assert(idx < nvdims());
        return vvars_[idx];
    }

    const tdim_info_t &tdim(int idx) const {
        ir_assert(idx < ntdims());
        return tdims_[idx];
    }

    void set_tdim(int tidx, const expr_t &_texpr, expr_t mask = {}) {
        ir_assert(tdims_[tidx].is_empty());

        auto texpr = simplify(_texpr);

        tdim_info_t tdim(texpr, mask);
        for (int i = 0; i < nvdims(); i++) {
            if (contains_object(texpr, vvars_[i])) tdim.add_vvar(i, vvars_[i]);
        }
        if (!is_const(texpr)) {
            ir_assert(tdim.nvargs() > 0)
                    << "Tensor dimension must have at least one view dimension "
                       "that maps to it.";
        }
        tdims_[tidx] = tdim;
    }

    void set_vdim(
            const expr_t &varg, dim_t vdim, const expr_t &vstart = expr_t(0)) {
        int vidx = vvar_index(varg);
        ir_assert(vstart_[vidx].is_empty());
        vstart_[vidx] = vstart;
        vdims_[vidx] = vdim;
    }

    void set_tlayout(const layout_t &tlayout) { tlayout_ = tlayout; }

    void set_tmasks(const std::unordered_map<std::string, int> &padded_dims,
            const std::unordered_map<std::string, int> &dim_blocks) {
        using namespace ir_utils;
        auto &x = placeholder_var();
        for (int i = 0; i < ntdims(); i++) {
            auto &tdim = tdims_[i];
            if (!tdim.is_identity() || !tdim.mask().is_empty()) continue;
            int vidx = tdim.vidx(0);
            int dim = tlayout_.dim(i);
            auto &dim_name = vvars_[vidx].as<var_t>().name;
            int padded_dim = get_or_default(padded_dims, dim_name, 1);
            if (dim >= padded_dim) continue;
            int inner_blk = ir_utils::max_pow2_divisor(dim);
            int dim_blk = get_or_default(dim_blocks, dim_name, 1);
            if (math::is_pow2(dim_blk)) {
                inner_blk = std::min(inner_blk, dim_blk);
            }
            auto tmask = (inner_blk == 1) ? (x < dim)
                                          : (x / inner_blk < dim / inner_blk);
            tdim.set_mask(tmask);
        }
    }

    std::string str() const {
        using ir_utils::operator<<;

        if (is_empty()) return "(nil)";
        std::ostringstream oss;
        oss << ir_utils::make_seq_print_helper(vdims_, "x");
        if (!has_zero_vstart()) oss << " vstart: [" << vstart_ << "]";
        oss << " tlayout: " << tlayout_;
        return oss.str();
    }

    IR_DEFINE_DUMP()

    bool is_empty() const { return vdims_.empty(); }

    bool has_zero_vstart() const {
        for (int i = 0; i < nvdims(); i++)
            if (!is_zero(vstart_[i])) return false;
        return true;
    }

    bool has_tmask(int tidx) const {
        ir_assert(tidx >= 0 && tidx < ntdims());
        return !tdims_[tidx].mask().is_empty();
    }

    const type_t &type() const { return tlayout_.type(); }

    expr_t offset(const std::vector<expr_t> &vargs = {},
            bool ignore_offset = false) const {
        auto targs = cvt_vargs_to_targs(vargs);
        return tlayout_.offset(targs, ignore_offset);
    }

    expr_t offset_in_bytes(const std::vector<expr_t> &vargs = {},
            bool ignore_offset = false) const {
        return offset(vargs, ignore_offset) * type().size();
    }

    int get_alignment(const constraint_set_t &cset) const {
        // Alignment must be a power of 2.
        const int base_alignment = 128;
        int64_t f = get_max_const_factor(this->offset_in_bytes(), cset);
        int alignment = f ? ir_utils::max_pow2_divisor(f) : base_alignment;
        return std::min(base_alignment, alignment);
    }

    int vvar_index(const expr_t &vvar) const {
        for (size_t i = 0; i < vvars_.size(); i++)
            if (vvar.is_same(vvars_[i])) return int(i);
        ir_error_not_expected() << "Can't find view dimension.";
        return -1;
    }

    template <typename T>
    T operator()(const std::vector<T> &vargs) const {
        auto targs = cvt_vargs_to_targs(vargs);
        return tlayout_(targs);
    }

    view_t create_sub_view(const tensor_t &sub_tensor) const;

    view_t retype(const type_t &new_type) const {
        auto ret = *this;
        ret.tlayout_ = tlayout_.retype(new_type);
        return ret;
    }

    view_t make_dense() const {
        auto ret = *this;
        ret.tlayout_ = tlayout_.make_dense();
        return ret;
    }

    bool is_masked_vdim(int vidx) const {
        ir_assert(vidx >= 0 && vidx < nvdims());
        ir_assert(has_zero_vstart())
                << "Can't be reliably determined when the view is a sub-view.";
        for (int i = 0; i < ntdims(); i++) {
            auto &tdim = tdims_[i];
            if (tdim.expr().is_equal(vvars_[vidx])) {
                if (vdims_[vidx] != tlayout_.dim(i)) return true;
            }
            if (has_tmask(i)) {
                for (int j = 0; j < tdim.nvargs(); j++) {
                    if (tdim.vidx(j) == vidx) return true;
                }
            }
        }
        return false;
    }

    // Returns the mask corresponding to `vargs` view indices. The mask is
    // based on:
    // 1) combined tensor masks for the given indices
    // 2) Bounds-based masks for those view dimensions that are used directly
    //    in the tensor
    //    - Example: 32a layout when 'a' dimension is A < 32. In general it's
    //      fine to load/store elements with indices in the range [A, 31]
    //      assuming the zero padding invariant. However in some cases we need
    //      to generate the exact bound condition based on the logical indices.
    expr_t vmask(const std::vector<expr_t> &vargs) const {
        ir_assert(int(vargs.size()) == nvdims()) << "Incompatible dimensions.";
        ir_assert(has_zero_vstart())
                << "Can't be reliably determined when the view is a sub-view.";
        auto targs = cvt_vargs_to_targs(vargs);
        auto mask = bool_imm_t::make(true);
        for (int i = 0; i < ntdims(); i++) {
            for (int j = 0; j < nvdims(); j++) {
                if (!tdims_[i].expr().is_equal(vvars_[j])) continue;
                if (vdims_[j] != tlayout_.dim(i)) {
                    mask &= (vargs[j] < vdims_[j]);
                }
            }
            if (has_tmask(i)) {
                auto &tdim = tdims_[i];
                mask &= tdim.mask(targs[i], vvars_, vargs);
            }
        }
        return mask;
    }

    bool can_convert_to_vlayout() const {
        if (nvdims() != ntdims()) return false;
        for (int i = 0; i < nvdims(); i++) {
            if (!tdims_[i].expr().is_same(vvars_[i])) return false;
            if (!tdims_[i].is_fixed_stride(0)) return false;
        }
        return true;
    }

    // FIXME: Offset of the returned layout is always 0.
    layout_t create_pseudo_vlayout() const {
        return create_pseudo_vlayout(normalized_tlayout());
    }

    layout_t normalized_tlayout() const {
        auto blocks = move_size_1_blocks_outer();
        blocks = layout_t::normalize_blocks(tlayout_.ndims(), blocks, false);
        auto layout
                = layout_t(type(), tlayout_.ndims(), offset(), blocks, false);
        return layout;
    }

    layout_t create_dense_vlayout() const {
        return create_pseudo_vlayout().make_dense();
    }

    layout_t create_vlayout(bool force_zero_offset = false) const {
        ir_assert(can_convert_to_vlayout()) << "Can't convert view to layout.";
        if (force_zero_offset) return tlayout_.map(tensor_t(vdims_));
        return tlayout_.map(tensor_t(vdims_, vstart_));
    }

    dim_t vlayout_size() const { return create_vlayout().size(); }

    bool has_same_vlayout(
            const view_t &other, bool compare_offset = true) const {
        return create_vlayout().is_equal(
                other.create_vlayout(), compare_offset);
    }

    view_t split(const grid_info_t &grid, tensor_t &vtile) const {
        auto vlayout = create_pseudo_vlayout();
        vtile = vlayout.split(grid);
        return create_sub_view(vtile);
    }

    view_t split(const grid_info_t &grid) const {
        tensor_t vtile;
        return split(grid, vtile);
    }

    // Returns a tensor corresponding to the biggest innermost sub-layout
    // such that:
    // 1) It consists of consecutive blocks only.
    // 2) It contains at most max_tile_elems elements.
    // 3) It is dense if is_dense_tile is true.
    tensor_t split_into_max_tile(
            dim_t max_tile_elems, bool is_dense_tile) const {
        auto vlayout = create_pseudo_vlayout();
        return vlayout.split_into_max_tile(max_tile_elems, is_dense_tile);
    }

    template <typename F>
    void for_each_tile(const tensor_t &tile, const F &f) const {
        auto vlayout = create_dense_vlayout();
        vlayout.for_each_tile(tile, f);
    }

    view_t substitute(const expr_t &from, const expr_t &to) const;

    mask_tensor_t create_mask_tensor(
            const constraint_set_t &cset, uint32_t tmask = 0xFFFFFFFF) const {
        auto _vlayout = create_dense_vlayout();
        mask_tensor_t mask_tensor(_vlayout);
        std::vector<dim_t> vargs(nvdims());
        create_mask_tensor(mask_tensor, _vlayout, 0, vargs, tmask);
        mask_tensor.simplify(cset);
        return mask_tensor;
    }

    void try_create_buffer_view(view_t &buf_view, view_t &inv_view) const {
        buf_view = view_t(create_vvars(ntdims()), ntdims());
        inv_view = view_t(vvars(), ntdims());
        for (int i = 0; i < nvdims(); i++) {
            inv_view.set_vdim(vvars()[i], vdims()[i]);
        }
        for (int i = 0; i < ntdims(); i++) {
            auto &tdim = tdims_[i];
            auto &buf_vvar = buf_view.vvars()[i];
            if (tdim.is_identity()) {
                int vidx = tdim.vidx(0);
                buf_view.set_vdim(buf_vvar, vdims()[vidx], vstart(vidx));
                buf_view.set_tdim(i, buf_vvar, tdim.mask());
                inv_view.set_tdim(i, tdim.expr());
                continue;
            }
            int buf_vdim = 0;
            bool ok = true;
            for (int j = 0; j < tdim.nvargs(); j++) {
                int vidx = tdim.vidx(j);
                auto &vvar = vvars()[vidx];
                int vdim = vdims()[vidx];
                if (vdim == 1) continue;
                auto A = tdim.expr();
                auto B = jit::substitute(A, vvar, vvar + 1);
                auto C = simplify(B - A);
                if (!is_const(C)) {
                    ok = false;
                    break;
                }
                buf_vdim += to_cpp<int>(C) * (vdim - 1);
            }
            buf_vdim++;

            if (!ok) {
                buf_view = view_t();
                inv_view = view_t();
                return;
            }

            auto buf_vstart = tdim.expr();
            auto inv_vstart = tdim.expr();
            for (int j = 0; j < tdim.nvargs(); j++) {
                int vidx = tdim.vidx(j);
                buf_vstart = jit::substitute(
                        buf_vstart, vvars()[vidx], vstart(vidx));
                inv_vstart
                        = jit::substitute(inv_vstart, vvars()[vidx], expr_t(0));
            }
            buf_vstart = simplify(buf_vstart);
            inv_vstart = simplify(inv_vstart);

            if (!is_const(inv_vstart)) {
                buf_view = view_t();
                inv_view = view_t();
                return;
            }

            buf_view.set_vdim(buf_vvar, buf_vdim, buf_vstart);

            // Check that mask doesn't contain vvars - they can't be accessed
            // in the buffered view.
            auto &tmask = tdim.mask();
            for (auto &vvar : vvars()) {
                if (contains_object(tmask, vvar)) {
                    buf_view = view_t();
                    inv_view = view_t();
                    return;
                }
            }

            buf_view.set_tdim(i, buf_vvar, tmask);
            inv_view.set_tdim(i, tdim.expr() - inv_vstart);
        }
        buf_view.set_tlayout(tlayout_);
    }

    static const expr_t &placeholder_var() {
        return tdim_info_t::placeholder_var();
    }

    static std::vector<expr_t> create_vvars(int nvdims);

    template <typename SrcT = expr_t, typename DstT = SrcT>
    std::vector<DstT> cvt_vargs_to_targs(const std::vector<SrcT> &_vargs = {},
            bool ignore_vstart = false) const {
        std::vector<expr_t> vargs = expr_cast<expr_t>(_vargs);
        if (vargs.empty()) vargs.resize(nvdims(), 0);

        if (!ignore_vstart) {
            for (int i = 0; i < nvdims(); i++) {
                if (!is_zero(vstart_[i])) vargs[i] += vstart_[i];
            }
        }

        std::vector<expr_t> targs(ntdims());
        for (int i = 0; i < ntdims(); i++) {
            targs[i] = tdims_[i].expr();
            for (int j = 0; j < nvdims(); j++) {
                targs[i] = jit::substitute(targs[i], vvars_[j], vargs[j]);
            }
        }
        for (int i = 0; i < ntdims(); i++) {
            targs[i] = const_fold(targs[i]);
        }
        return expr_cast<DstT>(targs);
    }

private:
    layout_t create_pseudo_vlayout(const layout_t &tlayout) const;

    void create_mask_tensor(mask_tensor_t &mask_tensor,
            const layout_t &_vlayout, int vidx, std::vector<dim_t> &vargs,
            uint32_t tmask) const {
        if (vidx == _vlayout.ndims()) {
            bool is_init = false;
            std::vector<expr_t> vvalues;
            std::vector<expr_t> targs;
            expr_t mask = bool_imm_t::make(true);
            for (int i = 0; i < ntdims(); i++) {
                auto &tdim = tdims_[i];
                if ((tmask & (1 << i)) == 0) continue;
                if (tdim.mask().is_empty()) continue;
                if (!is_init) {
                    // Lazily initialize values
                    vvalues = vstart_;
                    for (int i = 0; i < nvdims(); i++)
                        vvalues[i] += vargs[i];
                    targs = cvt_vargs_to_targs<dim_t, expr_t>(vargs);
                    is_init = true;
                }
                mask &= tdim.mask(targs[i], vvars_, vvalues);
            }
            mask_tensor.set_mask(_vlayout(vargs), mask);
            return;
        }

        for (int i = 0; i < vdims()[vidx]; i++) {
            vargs[vidx] = i;
            create_mask_tensor(mask_tensor, _vlayout, vidx + 1, vargs, tmask);
        }
    }

    std::vector<block_t> move_size_1_blocks_outer() const {
        std::vector<block_t> new_blocks;
        std::vector<block_t> size_1_blocks;
        for (auto &b : tlayout_.blocks()) {
            if (b.block == 1 && vdims_[b.dim_idx] == 1) {
                size_1_blocks.emplace_back(b);
            } else {
                new_blocks.emplace_back(b);
            }
        }
        stride_t stride = new_blocks.empty()
                ? stride_t(1)
                : new_blocks.back().block * new_blocks.back().stride;
        for (auto &b : size_1_blocks) {
            b.stride = stride;
            new_blocks.emplace_back(b);
        }
        return new_blocks;
    }

    std::vector<expr_t> vvars_;
    std::vector<dim_t> vdims_;
    std::vector<expr_t> vstart_;

    std::vector<tdim_info_t> tdims_;
    layout_t tlayout_;
};

inline std::ostream &operator<<(std::ostream &out, const view_t &view) {
    out << view.str();
    return out;
}
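
// Usage sketch (illustrative): wrapping a plain 2D layout in a view with
// a bound check on the first dimension.
//
//   layout_t l(type_t::f32(), 0, std::vector<dim_t> {30, 16});
//   view_t v(l, /*_vvars=*/{}, /*bound_check_mask=*/0x1);
//   // Tensor dimension 0 gets the mask (_ph < 30); vmask() instantiates
//   // it with the actual index expression for a given access.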

class dim_assignment_t {
public:
    dim_assignment_t() = default;

    dim_assignment_t(int old_ndims, int new_ndims)
        : old_ndims_(old_ndims)
        , new_ndims_(new_ndims)
        , assignments_(old_ndims, -1) {}

    void assign(int old_idx, int new_idx) {
        ir_assert(0 <= old_idx && old_idx < old_ndims_);
        ir_assert(0 <= new_idx && new_idx < new_ndims_);
        assignments_[old_idx] = new_idx;
    }

    void assign(const std::vector<int> &old_idxes, int new_idx) {
        for (auto old_idx : old_idxes) {
            assign(old_idx, new_idx);
        }
    }

    int operator[](int old_idx) const {
        ir_assert(old_idx >= 0 && old_idx < old_ndims());
        return assignments_[old_idx];
    }

    int old_ndims() const { return old_ndims_; }

    int new_ndims() const { return new_ndims_; }

    bool is_empty() const { return old_ndims_ == 0 && new_ndims_ == 0; }

    layout_t map(const layout_t &layout) const;

private:
    int old_ndims_ = 0;
    int new_ndims_ = 0;

    // assignments_[old_idx] = new_idx.
    std::vector<int> assignments_;
};
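
// Usage sketch (map() is defined in the implementation file): collapsing
// a 3D layout into 2D by sending old dimensions 0 and 1 to new dimension
// 0 and old dimension 2 to new dimension 1.
//
//   dim_assignment_t a(/*old_ndims=*/3, /*new_ndims=*/2);
//   a.assign({0, 1}, 0);
//   a.assign(2, 1);
//   layout_t mapped = a.map(l); // l is a 3-dimensional layout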

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif