1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_IR_TENSOR_HPP |
18 | #define GPU_JIT_IR_TENSOR_HPP |
19 | |
20 | #include <algorithm> |
21 | #include <array> |
22 | #include <iostream> |
23 | #include <sstream> |
24 | #include <string> |
25 | #include <thread> |
26 | #include <tuple> |
27 | #include <utility> |
28 | #include <vector> |
29 | #include <unordered_map> |
30 | |
31 | #include "common/memory_desc_wrapper.hpp" |
32 | #include "gpu/jit/ir/ir.hpp" |
33 | #include "gpu/jit/pass/simplify.hpp" |
34 | #include "gpu/jit/utils/utils.hpp" |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | namespace gpu { |
39 | namespace jit { |
40 | |
41 | class tensor_t { |
42 | public: |
43 | tensor_t() = default; |
44 | |
45 | tensor_t(const std::vector<dim_t> &dims) |
46 | : tensor_t(dims, std::vector<expr_t>()) {} |
47 | |
48 | tensor_t(const std::vector<dim_t> &dims, const std::vector<expr_t> &start) |
49 | : dims_(dims), start_(start) { |
50 | if (start_.empty()) start_.resize(dims.size(), 0); |
51 | } |
52 | |
53 | tensor_t(const std::vector<dim_t> &dims, const std::vector<dim_t> &start) |
54 | : tensor_t(dims) { |
55 | start_.resize(start.size()); |
56 | for (size_t i = 0; i < start.size(); i++) |
57 | start_[i] = start[i]; |
58 | } |
59 | |
60 | dim_t operator()(int idx) const { return dims_[idx]; } |
61 | |
62 | const expr_t &start(int idx) const { return start_[idx]; } |
63 | |
64 | int ndims() const { return int(dims_.size()); } |
65 | |
66 | dim_t elems() const { |
67 | dim_t ret = 1; |
68 | for (int i = 0; i < ndims(); i++) |
69 | ret *= dims_[i]; |
70 | return ret; |
71 | } |
72 | |
73 | const std::vector<dim_t> &dims() const { return dims_; } |
74 | |
75 | const std::vector<expr_t> &start() const { return start_; } |
76 | |
77 | bool is_empty() const { return dims_.empty(); } |
78 | |
79 | bool is_equal(const tensor_t &other) const { |
80 | if (ndims() != other.ndims()) return false; |
81 | for (int i = 0; i < ndims(); i++) { |
82 | if (dims_[i] != other.dims_[i]) return false; |
83 | if (!start_[i].is_equal(other.start_[i])) return false; |
84 | } |
85 | return true; |
86 | } |
87 | |
88 | bool is_divisible(const tensor_t &other) const { |
89 | if (ndims() != other.ndims()) return false; |
90 | for (int i = 0; i < ndims(); i++) { |
91 | if (dims_[i] % other.dims_[i] != 0) return false; |
92 | } |
93 | return true; |
94 | } |
95 | |
96 | std::string str() const { |
97 | using ir_utils::operator<<; |
98 | |
99 | if (is_empty()) return "(nil)" ; |
100 | std::ostringstream oss; |
101 | oss << ir_utils::make_seq_print_helper(dims_, "x" ); |
102 | if (!has_zero_start()) oss << " start: [" << start_ << "]" ; |
103 | return oss.str(); |
104 | } |
105 | |
106 | IR_DEFINE_DUMP() |
107 | |
108 | bool has_zero_start() const { |
109 | for (int i = 0; i < ndims(); i++) |
110 | if (!is_zero(start_[i])) return false; |
111 | return true; |
112 | } |
113 | |
114 | dim_t to_1d_offset(const std::vector<dim_t> &args) const { |
115 | ir_assert(has_zero_start()); |
116 | |
117 | dim_t off = 0; |
118 | for (int i = 0; i < ndims(); i++) { |
119 | off *= dims_[i]; |
120 | off += args[i]; |
121 | } |
122 | return off; |
123 | } |
124 | |
125 | tensor_t create_sub_tensor(const tensor_t &tile) const { |
126 | ir_assert(ndims() == tile.ndims()) << "Incompatible sizes." ; |
127 | std::vector<expr_t> new_start = start_; |
128 | for (int i = 0; i < ndims(); i++) |
129 | new_start[i] += tile.start(i); |
130 | return tensor_t(tile.dims(), new_start); |
131 | } |
132 | |
133 | tensor_t substitute(const expr_t &from, const expr_t &to) const { |
134 | tensor_t ret = *this; |
135 | for (int i = 0; i < ndims(); i++) { |
136 | ret.start_[i] = jit::substitute(ret.start_[i], from, to); |
137 | ret.start_[i] = simplify(ret.start_[i]); |
138 | } |
139 | return ret; |
140 | } |
141 | |
142 | private: |
143 | std::vector<dim_t> dims_; |
144 | std::vector<expr_t> start_; |
145 | }; |
146 | |
147 | inline std::ostream &operator<<(std::ostream &out, const tensor_t &tensor) { |
148 | out << tensor.str(); |
149 | return out; |
150 | } |
151 | |
// A multi-dimensional grid: per-dimension integer sizes and offsets plus
// symbolic index variables (expr_t), one per dimension.
class grid_info_t {
public:
    grid_info_t() = default;
    grid_info_t(int ndims) : dims_(ndims), offs_(ndims), idxs_(ndims) {}
    grid_info_t(const std::vector<int> &dims, const std::vector<expr_t> &idxs)
        : grid_info_t(dims, {}, idxs) {}
    // Index variables are auto-generated as prefix0, prefix1, ...
    grid_info_t(const std::vector<int> &dims, const std::string &prefix)
        : grid_info_t(dims, make_idxs(prefix, (int)dims.size())) {}
    grid_info_t(const std::vector<int> &dims, const std::vector<int> &offs,
            const std::vector<expr_t> &idxs)
        : dims_(dims), offs_(offs), idxs_(idxs) {
        // Empty offsets are treated as all-zero offsets.
        if (offs_.empty()) offs_.resize(dims.size());
        ir_assert(dims_.size() == offs_.size());
        ir_assert(dims_.size() == idxs_.size());
    }

    bool operator==(const grid_info_t &other) const {
        if (ndims() != other.ndims()) return false;
        for (int i = 0; i < ndims(); i++) {
            if (dim(i) != other.dim(i)) return false;
            if (off(i) != other.off(i)) return false;
            if (!idx(i).is_equal(other.idx(i))) return false;
        }
        return true;
    }

    bool is_empty() const { return dims_.empty(); }

    int &dim(int dim_idx) { return dims_[dim_idx]; }
    int &off(int dim_idx) { return offs_[dim_idx]; }
    expr_t &idx(int dim_idx) { return idxs_[dim_idx]; }
    // Returns the dimension whose index variable is `idx_var`; errors out
    // if no dimension uses that variable.
    int dim_idx(const expr_t &idx_var) const {
        for (int i = 0; i < ndims(); i++) {
            if (idx(i).is_same(idx_var)) return i;
        }
        ir_error_not_expected() << "Index not found: " << idx_var;
        return -1;
    }

    const int &dim(int dim_idx) const { return dims_[dim_idx]; }
    const int &dim(const expr_t &idx_var) const {
        return dims_[dim_idx(idx_var)];
    }
    const int &off(int dim_idx) const { return offs_[dim_idx]; }
    const expr_t &idx(int dim_idx) const { return idxs_[dim_idx]; }

    int &operator[](int dim_idx) { return dim(dim_idx); }
    const int &operator[](int dim_idx) const { return dim(dim_idx); }

    int ndims() const { return int(dims_.size()); }
    // Total number of grid points (product of all dimensions).
    int elems() const {
        return utils::array_product(dims_.data(), dims_.size());
    }

    // Projection of the grid onto a subset of its dimensions, in the
    // order given by `old_dim_idxs`.
    grid_info_t sub_grid(std::initializer_list<int> old_dim_idxs) const {
        grid_info_t ret(int(old_dim_idxs.size()));
        int new_dim_idx = 0;
        for (auto old_dim_idx : old_dim_idxs) {
            ret.dim(new_dim_idx) = dim(old_dim_idx);
            ret.off(new_dim_idx) = off(old_dim_idx);
            ret.idx(new_dim_idx) = idx(old_dim_idx);
            new_dim_idx++;
        }
        return ret;
    }

    // Same offsets/index variables with new dimension sizes.
    grid_info_t resize(const std::vector<int> &new_dims) const {
        grid_info_t ret = *this;
        ret.dims_ = new_dims;
        return ret;
    }

    // Takes the [new_off, new_off + new_dim) slice of dimension
    // `dim_idx`. When the slice does not start at 0, the dimension's
    // index variable is replaced with `new_idx` and `new_idx_value`
    // receives the expression (old index - new_off) to bind it to;
    // otherwise `new_idx_value` is left empty. The pre-slice dimensions
    // are remembered in parent_dims_ (used by slice_condition()).
    grid_info_t slice(int dim_idx, int new_off, int new_dim,
            const expr_t &new_idx, expr_t &new_idx_value) const {
        ir_assert(dim_idx >= 0 && dim_idx < ndims());
        ir_assert(new_dim > 0 && new_off >= 0);
        ir_assert(new_off + new_dim <= dims_[dim_idx]);

        grid_info_t ret = *this;
        ret.offs_[dim_idx] += new_off;
        ret.dims_[dim_idx] = new_dim;
        if (new_off > 0) {
            new_idx_value = ret.idxs_[dim_idx] - new_off;
            ret.idxs_[dim_idx] = new_idx;
        } else {
            new_idx_value = expr_t();
        }
        ret.parent_dims_ = (parent_dims_.empty() ? dims_ : parent_dims_);
        return ret;
    }

    // Halves the last (highest-index) evenly divisible dimension,
    // returning the first or second half depending on `first`. Returns an
    // empty grid when no dimension can be halved. `dim_idx` receives the
    // halved dimension.
    grid_info_t halven(const expr_t &new_idx, int &dim_idx,
            expr_t &new_idx_value, bool first = true) const {
        for (int i = ndims() - 1; i >= 0; i--) {
            if (dim(i) == 1 || dim(i) % 2 != 0) continue;
            dim_idx = i;
            if (first) return slice(i, 0, dim(i) / 2, new_idx, new_idx_value);
            return slice(i, dim(i) / 2, dim(i) / 2, new_idx, new_idx_value);
        }
        return grid_info_t();
    }

    // Boolean expression selecting this (sliced) sub-grid within its
    // parent grid. Empty when the grid was never sliced or the condition
    // is trivially true.
    expr_t slice_condition() const {
        if (parent_dims_.empty()) return expr_t();
        expr_t ret(true);
        for (int i = 0; i < ndims(); i++) {
            auto &idx = idxs_[i];
            if (offs_[i] > 0) ret &= (idx >= 0);
            if (offs_[i] + dims_[i] < parent_dims_[i]) ret &= (idx < dims_[i]);
        }
        if (ret.is_equal(expr_t(true))) return expr_t();
        return ret;
    }

    std::string str() const {
        std::ostringstream oss;
        oss << ir_utils::make_seq_print_helper(dims_, " x " );
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    // Creates s32 variables named prefix0 .. prefix(n-1).
    static std::vector<expr_t> make_idxs(const std::string &prefix, int n) {
        std::vector<expr_t> ret;
        for (int i = 0; i < n; i++)
            ret.push_back(
                    var_t::make(type_t::s32(), prefix + std::to_string(i)));
        return ret;
    }

    std::vector<int> dims_;
    std::vector<int> offs_;
    std::vector<expr_t> idxs_;

    // Dimensions before the first slice() call; empty if never sliced.
    std::vector<int> parent_dims_;
};
289 | |
290 | inline std::ostream &operator<<( |
291 | std::ostream &out, const grid_info_t &grid_info) { |
292 | out << grid_info.str(); |
293 | return out; |
294 | } |
295 | |
296 | class grid_splitter_t { |
297 | public: |
298 | grid_splitter_t(const grid_info_t &grid) |
299 | : grid_(grid), cur_idx_(grid.ndims() - 1), cur_stride_(1) { |
300 | skip_size_1_dims(); |
301 | ir_assert(cur_idx_ >= 0); |
302 | } |
303 | |
304 | int cur_block() const { |
305 | if (is_empty()) return 1; |
306 | |
307 | return grid_.dim(cur_idx_) / cur_stride_; |
308 | } |
309 | |
310 | bool is_empty() const { return cur_idx_ == -1; } |
311 | |
312 | bool can_pop_block(int size) const { |
313 | if (is_empty()) return false; |
314 | return cur_block() % size == 0; |
315 | } |
316 | |
317 | expr_t pop_block(int size); |
318 | |
319 | private: |
320 | void skip_size_1_dims() { |
321 | while (cur_idx_ >= 0 && grid_.dim(cur_idx_) == 1) |
322 | cur_idx_--; |
323 | } |
324 | |
325 | grid_info_t grid_; |
326 | |
327 | int cur_idx_; |
328 | int cur_stride_; |
329 | }; |
330 | |
// Discriminator for stride_t: a stride is either undefined (default),
// a fixed integer value, or unknown (not representable as a constant).
enum class stride_kind_t {
    undef,
    fixed,
    unknown,
};
336 | |
337 | class stride_t { |
338 | public: |
339 | stride_t() = default; |
340 | |
341 | stride_t(dim_t stride) : stride_t(stride_kind_t::fixed, stride) {} |
342 | |
343 | bool operator==(const stride_t &other) const { |
344 | return (kind_ == other.kind_) && (stride_ == other.stride_); |
345 | } |
346 | |
347 | bool operator!=(const stride_t &other) const { return !operator==(other); } |
348 | |
349 | size_t get_hash() const { return ir_utils::get_hash(kind_, stride_); } |
350 | |
351 | operator dim_t() const { |
352 | ir_assert(kind_ == stride_kind_t::fixed); |
353 | return stride_; |
354 | } |
355 | |
356 | bool is_fixed() const { return kind_ == stride_kind_t::fixed; } |
357 | |
358 | bool is_unknown() const { return kind_ == stride_kind_t::unknown; } |
359 | |
360 | stride_t &operator*=(const stride_t &other) { |
361 | if (is_fixed() && other.is_fixed()) { |
362 | stride_ *= other.stride_; |
363 | } else { |
364 | set_unknown(); |
365 | } |
366 | return *this; |
367 | } |
368 | |
369 | stride_t &operator/=(const stride_t &other) { |
370 | if (is_fixed() && other.is_fixed()) { |
371 | stride_ /= other.stride_; |
372 | } else { |
373 | set_unknown(); |
374 | } |
375 | return *this; |
376 | } |
377 | |
378 | std::string str() const { |
379 | std::ostringstream oss; |
380 | if (is_fixed()) { |
381 | oss << stride_; |
382 | } else { |
383 | oss << "(unknown)" ; |
384 | } |
385 | return oss.str(); |
386 | } |
387 | |
388 | IR_DEFINE_DUMP() |
389 | |
390 | static stride_t unknown() { return stride_t(stride_kind_t::unknown); } |
391 | |
392 | private: |
393 | stride_t(stride_kind_t kind, dim_t stride = 0) |
394 | : kind_(kind), stride_(stride) {} |
395 | |
396 | void set_unknown() { |
397 | kind_ = stride_kind_t::unknown; |
398 | stride_ = 0; |
399 | } |
400 | |
401 | stride_kind_t kind_ = stride_kind_t::undef; |
402 | dim_t stride_ = 0; |
403 | }; |
404 | |
405 | inline std::ostream &operator<<(std::ostream &out, const stride_t &stride) { |
406 | out << stride.str(); |
407 | return out; |
408 | } |
409 | |
410 | inline stride_t operator*(const stride_t &a, const stride_t &b) { |
411 | stride_t tmp = a; |
412 | return tmp *= b; |
413 | } |
414 | |
415 | inline stride_t operator*(const stride_t &a, dim_t b) { |
416 | return a * stride_t(b); |
417 | } |
418 | |
419 | inline stride_t operator*(dim_t a, const stride_t &b) { |
420 | return stride_t(a) * b; |
421 | } |
422 | |
423 | struct block_t { |
424 | block_t() = default; |
425 | |
426 | block_t(int dim_idx, dim_t block, const stride_t &stride) |
427 | : dim_idx(dim_idx), block(block), stride(stride) {} |
428 | |
429 | bool is_equal(const block_t &other) const { |
430 | return (dim_idx == other.dim_idx) && (block == other.block) |
431 | && (stride == other.stride); |
432 | } |
433 | |
434 | size_t get_hash() const { |
435 | return ir_utils::get_hash(dim_idx, block, stride); |
436 | } |
437 | |
438 | std::string str() const { |
439 | std::ostringstream oss; |
440 | oss << "block_t(dim_idx = " << dim_idx; |
441 | oss << ", block = " << block; |
442 | oss << ", stride = " << stride; |
443 | oss << ")" ; |
444 | return oss.str(); |
445 | } |
446 | |
447 | IR_DEFINE_DUMP() |
448 | |
449 | bool is_empty() const { return dim_idx == -1; } |
450 | |
451 | int dim_idx = -1; // Dimension index. |
452 | dim_t block; // Block size. |
453 | stride_t stride; // Stride between elements of the block. |
454 | }; |
455 | |
456 | inline std::ostream &operator<<(std::ostream &out, const block_t &b) { |
457 | out << b.str(); |
458 | return out; |
459 | } |
460 | |
461 | class layout_t { |
462 | public: |
463 | static const int max_ndims = 16; |
464 | |
    // Empty layout: undefined type, zero dimensions, zero offset.
    layout_t() : type_(type_t::undef()), ndims_(0), offset_(0) {
        sanity_check();
    }
468 | |
    // Core constructor (defined out of line): `parts` is a list of
    // (dimension index, block size) pairs.
    layout_t(const type_t &type, const expr_t &offset,
            const std::vector<std::pair<int, dim_t>> &parts,
            const std::vector<dim_t> &dims = {}, bool do_normalize = true);

    // Constructs the layout from a format tag string via parse_format().
    layout_t(const type_t &type, const expr_t &offset,
            const std::string &format, const std::vector<dim_t> &dims = {},
            bool do_normalize = true)
        : layout_t(type, offset, parse_format(format, int(dims.size())), dims,
                do_normalize) {}

    // Constructs the layout from a DNNL memory descriptor and an explicit
    // format string; type, offset and dims are taken from the descriptor.
    layout_t(const memory_desc_wrapper &mdw, const std::string &format,
            bool do_normalize = true)
        : layout_t(mdw.data_type(), mdw.offset0(), format,
                std::vector<dim_t>(mdw.dims(), mdw.dims() + mdw.ndims()),
                do_normalize) {}

    layout_t(const memory_desc_wrapper &mdw, const char *format,
            bool do_normalize = true)
        : layout_t(mdw, std::string(format), do_normalize) {}

    // Constructs the layout directly from a DNNL memory descriptor
    // (defined out of line).
    layout_t(const memory_desc_wrapper &mdw, bool do_normalize = true);
490 | |
    // Constructs a dense, row-major layout from plain dimensions (the
    // last dimension gets unit stride).
    layout_t(const type_t &type, const expr_t &offset,
            const std::vector<dim_t> &dims, bool do_normalize = true)
        : type_(type), ndims_(int(dims.size())), offset_(offset) {
        dim_t stride = 1;
        // Blocks are stored innermost first.
        for (int i = ndims_ - 1; i >= 0; i--) {
            blocks_.emplace_back(i, dims[i], stride);
            stride *= dims[i];
        }
        if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
        sanity_check();
    }
502 | |
    // Constructs the layout from an explicit block list (innermost
    // first).
    layout_t(const type_t &type, int ndims, const expr_t &offset,
            const std::vector<block_t> &blocks, bool do_normalize = true)
        : type_(type), ndims_(ndims), offset_(offset), blocks_(blocks) {
        if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
        sanity_check();
    }
509 | |
    // Copies the blocking of `other` with a new type and offset.
    layout_t(const type_t &type, const expr_t &offset, const layout_t &other,
            bool do_normalize)
        : layout_t(type, other.ndims(), offset, other.blocks(), do_normalize) {}
513 | |
    bool is_empty() const { return ndims_ == 0; }

    int ndims() const { return ndims_; }

    // Total number of elements (product of all block sizes).
    dim_t elems() const {
        dim_t ret = 1;
        for (auto &b : blocks_)
            ret *= b.block;
        return ret;
    }
524 | |
525 | // Storage size in bytes. |
526 | dim_t size() const { |
527 | if (is_empty()) return 0; |
528 | dim_t max_stride = 1; |
529 | for (auto &b : blocks_) { |
530 | max_stride = std::max(max_stride, dim_t(b.block * b.stride)); |
531 | } |
532 | return max_stride * type().size(); |
533 | } |
534 | |
    // Returns the offset (in elements) of the point addressed by `args`
    // (one index per dimension). With empty `args` returns the layout's
    // base offset; with `ignore_offset` the base offset is not added.
    template <typename T = expr_t>
    T offset(
            const std::vector<T> &args = {}, bool ignore_offset = false) const {
        if (args.empty()) return expr_cast<T>(offset_);

        ir_assert(int(args.size()) == ndims()) << "Dimensions do not match." ;

        T off = 0;
        // Peel each block's contribution off the per-dimension indices.
        auto _args = args;
        for (auto &eb : enumerated_blocks()) {
            auto &b = eb.second;
            auto &idx = _args[b.dim_idx];
            if (ir_utils::is_equal(idx, T(0))) continue;

            // Do not use modulus for outermost blocks.
            auto i = is_outermost(eb) ? idx : (idx % b.block);
            off = i * dim_t(b.stride) + off;
            idx /= b.block;
        }
        if (ignore_offset) return off;

        T off0 = expr_cast<T>(offset_);
        return off0 + off;
    }
559 | |
    const type_t &type() const { return type_; }

    // Per-dimension sizes (product of all blocks of each dimension).
    std::vector<dim_t> dims() const {
        std::vector<dim_t> dims(ndims(), 1);
        for (auto &b : blocks_) {
            dims[b.dim_idx] *= b.block;
        }
        return dims;
    }

    // Total size of a single dimension.
    dim_t dim(int dim_idx) const {
        dim_t ret = 1;
        for (auto &b : blocks_) {
            if (b.dim_idx == dim_idx) ret *= b.block;
        }
        return ret;
    }

    const std::vector<block_t> &blocks() const { return blocks_; }
579 | |
    // Returns the innermost block size of dimension `dim_idx`. With
    // skip_outer = true (default) a single-block dimension yields 1 (its
    // only block is considered outer); with skip_outer = false the first
    // (innermost) block of the dimension is returned as-is.
    dim_t inner_block(int dim_idx, bool skip_outer = true) const {
        dim_t block0 = -1;
        int nblocks = 0;
        for (auto &b : blocks_) {
            if (b.dim_idx == dim_idx) {
                nblocks++;
                if (!skip_outer) return b.block;
                if (block0 == -1) block0 = b.block;
            }
        }
        return nblocks > 1 ? block0 : 1;
    }

    void set_offset(const expr_t &offset) { offset_ = offset; }
594 | |
595 | bool is_strictly_equal(const layout_t &other, bool compare_offset = true, |
596 | bool compare_strides = true) const { |
597 | if (!type_.is_equal(other.type_)) return false; |
598 | if (compare_offset && !offset_.is_equal(other.offset_)) return false; |
599 | if (blocks_.size() != other.blocks_.size()) return false; |
600 | for (size_t i = 0; i < blocks_.size(); i++) { |
601 | auto &b0 = blocks_[i]; |
602 | auto &b1 = other.blocks_[i]; |
603 | if (b0.dim_idx != b1.dim_idx) return false; |
604 | if (b0.block != b1.block) return false; |
605 | if (compare_strides && b0.stride != b1.stride) return false; |
606 | } |
607 | return true; |
608 | } |
609 | |
    bool operator==(const layout_t &other) const { return is_equal(other); }

    bool operator!=(const layout_t &other) const { return !operator==(other); }
    // Partial order: *this <= other when this layout's normalized blocks
    // are a prefix of other's, where the last common block of `other`
    // may be a multiple of this layout's last block.
    bool operator<=(const layout_t &other) const {
        if (!type_.is_equal(other.type_)) return false;
        const auto other_blocks = other.normalize().blocks();
        const auto self_blocks = normalize().blocks();
        if (self_blocks.size() > other_blocks.size()) return false;
        if (self_blocks.size() == 0) return true;

        int i = 0;
        for (; i < (int)self_blocks.size() - 1; i++) {
            if (!self_blocks[i].is_equal(other_blocks[i])) return false;
        }
        return (self_blocks[i].dim_idx == other_blocks[i].dim_idx
                && self_blocks[i].stride == other_blocks[i].stride
                && other_blocks[i].block % self_blocks[i].block == 0);
    }
    bool operator>=(const layout_t &other) const { return other <= *this; }
629 | |
    // Equality up to normalization (size-1 blocks dropped, consecutive
    // dense blocks merged).
    bool is_equal(const layout_t &other, bool compare_offset = true) const {
        return normalize().is_strictly_equal(other.normalize(), compare_offset);
    }

    size_t get_hash() const {
        return ir_utils::get_hash(type_, ndims_, offset_, blocks_);
    }

    // Shorthand for offset(args).
    template <typename T>
    T operator()(const std::vector<T> &args) const {
        return offset(args);
    }

    // Same as offset() but scaled to bytes by the element type size.
    template <typename T = expr_t>
    T offset_in_bytes(
            const std::vector<T> &args = {}, bool ignore_offset = false) const {
        return offset(args, ignore_offset) * type().size();
    }
648 | |
    // Renders a compact layout tag (e.g. "4a2b:f32"), outermost block
    // first. With dnnl_style the outermost block of each dimension is
    // printed as a letter (upper-case when the dimension has inner
    // blocks), mimicking DNNL format tags.
    std::string desc_str(bool dnnl_style = false) const {
        if (is_empty()) return "(nil)" ;
        if (!dnnl_style && blocks_.empty())
            return "(scalar:" + type().str() + ")" ;
        std::string ret;
        stride_t dense_stride(1);
        std::vector<bool> seen(ndims());
        for (auto &eb : enumerated_blocks()) {
            auto &b = eb.second;
            std::string b_str;
            if (dnnl_style && is_outermost(eb)) {
                b_str.append(1, (seen[b.dim_idx] ? 'A' : 'a') + b.dim_idx);
            } else {
                b_str = std::to_string(b.block);
                b_str.append(1, 'a' + b.dim_idx);
            }
            if (!dnnl_style) {
                // '?' marks an unknown stride, '*' a non-dense one.
                if (b.stride.is_unknown()) {
                    b_str.append(1, '?');
                } else if (b.stride != dense_stride) {
                    b_str.append(1, '*');
                }
            }
            // Blocks are stored innermost first; prepend so the
            // outermost block ends up leftmost.
            ret = b_str + ret;
            dense_stride = b.stride * b.block;
            seen[b.dim_idx] = true;
        }
        ret += ":" + type().str();
        return ret;
    }
679 | |
680 | std::string str() const { |
681 | if (is_empty()) return "(nil)" ; |
682 | std::ostringstream oss; |
683 | oss << desc_str(); |
684 | if (!has_zero_offset()) oss << " offset: " << offset_; |
685 | return oss.str(); |
686 | } |
687 | |
    IR_DEFINE_DUMP()

    // Converts to a DNNL memory descriptor (defined out of line).
    memory_desc_t to_dnnl(const dim_t *dims_hint) const;
691 | |
692 | // Returns a vector of <block index, block> pairs. |
693 | // The innermost block (first) has index 0. |
694 | std::vector<std::pair<int, block_t>> enumerated_blocks() const { |
695 | std::vector<std::pair<int, block_t>> ret; |
696 | for (int i = 0; i < int(blocks_.size()); i++) { |
697 | ret.emplace_back(i, blocks_[i]); |
698 | } |
699 | return ret; |
700 | } |
701 | |
702 | std::vector<dim_t> strides(int dim_idx) const { |
703 | std::vector<dim_t> ret; |
704 | for (auto &b : blocks_) |
705 | if (b.dim_idx == dim_idx) ret.push_back(b.stride); |
706 | return ret; |
707 | } |
708 | |
    // eb is <block index, block> pair, see enumerated_blocks().
    // Delegates to the static overload (declared elsewhere in the class)
    // with this layout's own blocks.
    bool is_outermost(const std::pair<int, block_t> &eb) const {
        return is_outermost(eb, blocks_);
    }
713 | |
714 | bool is_plain() const { |
715 | std::vector<bool> seen(ndims()); |
716 | for (auto &b : blocks_) { |
717 | if (seen[b.dim_idx]) return false; |
718 | seen[b.dim_idx] = true; |
719 | } |
720 | return true; |
721 | } |
722 | |
723 | bool has_zero_offset() const { return offset_.is_equal(expr_t(0)); } |
724 | |
725 | bool has_unknown_strides() const { |
726 | for (auto &b : blocks_) |
727 | if (b.stride.is_unknown()) return true; |
728 | return false; |
729 | } |
730 | |
    // Returns a canonical representation of the layout:
    // - Size one blocks are removed
    // - Consecutive dense blocks are merged
    layout_t normalize() const {
        auto blocks = normalize_blocks(ndims(), blocks_);
        return layout_t(type(), ndims(), offset(), blocks);
    }
738 | |
739 | layout_t transpose() const { |
740 | if (ndims() != 2) ir_error_not_expected(); |
741 | |
742 | // Flip: 0 -> 1, 1 -> 0. |
743 | auto blocks = blocks_; |
744 | for (auto &b : blocks) |
745 | b.dim_idx ^= 1; |
746 | |
747 | return layout_t(type(), ndims(), offset(), blocks); |
748 | } |
749 | |
    // Returns a new (sub-)layout that fully contains the passed sub-tensor.
    // Strides are kept unchanged.
    // Assumption: the original layout can be tiled by the passed sub-tensor.
    // For example: XaYb4a2b can be tiled into 2x2 sub-tensors but it's not
    // possible to tile it into 3x2 sub-tensors.
    layout_t map(const tensor_t &tensor) const;

    // Reinterprets the layout with a different element type (defined out
    // of line).
    layout_t reinterpret(
            const type_t &new_type, bool do_normalize = true) const;
759 | |
760 | layout_t retype(const type_t &new_type) const { |
761 | auto ret = *this; |
762 | ret.type_ = new_type; |
763 | return ret; |
764 | } |
765 | |
766 | bool is_dense() const { |
767 | stride_t stride = 1; |
768 | for (auto &b : blocks_) { |
769 | if (b.stride != stride) return false; |
770 | stride *= b.block; |
771 | } |
772 | return true; |
773 | } |
774 | |
    // Returns the dense innermost prefix of the layout, stopping at the
    // first non-dense block or at the first block that is the outermost
    // (only remaining) block of its dimension.
    layout_t innermost_block_layout() const {
        int block_count[layout_t::max_ndims] = {0};
        for (auto &b : blocks_)
            block_count[b.dim_idx]++;

        std::vector<block_t> inner_blocks;

        stride_t stride = 1;
        for (auto &b : blocks_) {
            if (b.stride != stride) break; // Not dense anymore.
            if (block_count[b.dim_idx] == 1) break; // Outer block.
            stride *= b.block;
            ir_assert(block_count[b.dim_idx] > 0);
            block_count[b.dim_idx]--;
            inner_blocks.push_back(b);
        }
        return layout_t(type(), ndims(), 0, inner_blocks);
    }
793 | |
794 | // Returns a packed layout where all blocks are contiguous, without gaps. |
795 | layout_t make_dense() const { |
796 | dim_t stride = 1; |
797 | auto new_blocks = blocks_; |
798 | for (auto &b : new_blocks) { |
799 | b.stride = stride; |
800 | stride *= b.block; |
801 | } |
802 | return layout_t(type(), ndims(), 0, new_blocks); |
803 | } |
804 | |
    // Overrides the stride of the block at `block_idx` with `_stride` and
    // repacks all blocks outside it densely on top; blocks inner to
    // `block_idx` keep their original strides.
    layout_t make_strided(int _stride, int block_idx = 0) const {
        dim_t cur_stride = 1;
        auto new_blocks = blocks_;
        for (int i = 0; i < (int)new_blocks.size(); i++) {
            auto &b = new_blocks[i];
            if (i >= block_idx) {
                b.stride = (i == block_idx ? _stride : cur_stride);
            }
            cur_stride = b.stride * b.block;
        }
        return layout_t(type(), ndims(), 0, new_blocks);
    }
817 | |
    // Returns an equivalent layout where the specified block is split into two.
    // block0 - inner block size.
    // block1 - outer block size.
    layout_t split_block(const std::pair<int, block_t> &eb, dim_t block0,
            dim_t block1) const;

    // Splits blocks so that they can be used to form `multi_blocks` without
    // crossing the block boundaries. `multi_blocks` are ordered from innermost
    // to outermost. Returns an empty layout if such a split is not possible.
    // Example (all blocks are ordered from innermost to outermost):
    //     Input blocks: [4, 4, 2]
    //     Multi-blocks: [8, 2]
    //     Output blocks: [4, 2, 2, 2]
    layout_t split_into_multi_blocks(
            const std::vector<dim_t> &multi_blocks) const;

    // Appends an outer block of dimension `dim_idx`. The default
    // stride == -1 places the block densely right after all existing
    // elements.
    layout_t add_outer_block(
            int dim_idx, dim_t block, dim_t stride = -1) const {
        if (stride == -1) stride = elems();
        ir_assert(stride >= elems());
        ir_assert(dim_idx < ndims());
        auto new_blocks = blocks();
        new_blocks.emplace_back(dim_idx, block, stride);
        return layout_t(type(), ndims(), offset(), new_blocks);
    }

    // Returns a tensor corresponding to the biggest innermost sub-layout so that
    // 1) It consists of consecutive blocks only.
    // 2) It contains less or equal than max_tile_elems elements.
    // 3) It is dense if is_dense_tile is true.
    tensor_t split_into_max_tile(
            dim_t max_tile_elems, bool is_dense_tile) const;
850 | |
    // Finds the smallest tile (by element count) that exactly splits this
    // layout across some sub-grid of `grid_info`, enumerating all
    // sub-grid shapes whose dimensions do not exceed the grid's.
    tensor_t split(const grid_info_t &grid_info) const {
        tensor_t min_tile;
        std::vector<int> cur_dims(grid_info.ndims(), 1);

        for (int iter = 0; iter < grid_info.elems(); iter++) {
            // Odometer-style advance of cur_dims.
            for (int i = 0; i < grid_info.ndims(); i++) {
                if (++cur_dims[i] <= grid_info.dim(i)) break;
                cur_dims[i] = 1;
            }
            auto sub_grid = grid_info.resize(cur_dims);
            auto tile = split_exact(sub_grid);
            if (tile.is_empty()) continue;
            if (min_tile.is_empty() || tile.elems() < min_tile.elems()) {
                min_tile = tile;
            }
        }
        return min_tile;
    }
869 | |
    // Splits the layout evenly across `grid`: each grid point receives
    // elems() / grid.elems() elements. Returns an empty tensor when such
    // an even split does not fit the block structure.
    tensor_t split_exact(const grid_info_t &grid) const {
        std::vector<dim_t> tile_dims(ndims(), 1);
        if (elems() % grid.elems() != 0) return tensor_t();

        // Greedily absorb innermost blocks until the tile is full.
        dim_t cur_elems_per_tile = 1;
        dim_t elems_per_tile = elems() / grid.elems();
        for (auto &b : blocks()) {
            dim_t block
                    = std::min(b.block, elems_per_tile / cur_elems_per_tile);
            tile_dims[b.dim_idx] *= block;
            cur_elems_per_tile *= block;
        }
        if (cur_elems_per_tile != elems_per_tile) return tensor_t();

        return split(tensor_t(tile_dims), grid);
    }
886 | |
    // Distributes the layout as tiles of shape `tile` across `grid` and
    // returns this grid point's tile start (expressed via the grid index
    // variables). Returns an empty tensor when the split is impossible.
    // When `outer_blocks` is non-null it collects the blocks that were
    // distributed across the grid.
    tensor_t split(const tensor_t &tile, const grid_info_t &grid,
            std::vector<block_t> *outer_blocks = nullptr) const {
        ir_assert(ndims() == tile.ndims())
                << "Number of dimensions doesn't match." ;
        ir_assert(tile.has_zero_start());

        if (outer_blocks) outer_blocks->resize(0);

        if (grid.elems() == 1) return tile;

        dim_t total_elems = elems();
        dim_t tile_elems = tile.elems();

        grid_splitter_t grid_splitter(grid);
        ir_assert(tile_elems * grid.elems() == total_elems)
                << "Tile/grid dimensions do not match." ;
        MAYBE_UNUSED(total_elems);
        MAYBE_UNUSED(tile_elems);

        std::vector<dim_t> dims(tile.ndims(), 1);
        std::vector<expr_t> start(tile.ndims(), 0);
        std::vector<dim_t> rem_dims = tile.dims();
        for (auto &eb : enumerated_blocks()) {
            auto &b = eb.second;
            if (b.block == 1) continue;

            dim_t &e = rem_dims[b.dim_idx];
            if (e > 1) {
                // Block still belongs to the tile interior.
                if (e % b.block == 0) {
                    e /= b.block;
                } else if (b.block % e == 0) {
                    // Partially consumed block: split it and retry.
                    auto tmp_layout = split_block(eb, e, b.block / e);
                    return tmp_layout.split(tile, grid, outer_blocks);
                } else {
                    return tensor_t();
                }
            } else {
                // Tile exhausted in this dimension: distribute the block
                // across the grid.
                dim_t next_chunk
                        = math::gcd(b.block, grid_splitter.cur_block());
                if (b.block == next_chunk) {
                    auto idx = grid_splitter.pop_block(next_chunk);
                    start[b.dim_idx] += idx * dims[b.dim_idx];
                    if (outer_blocks) outer_blocks->push_back(b);
                } else if (b.block % next_chunk == 0 && next_chunk != 1) {
                    auto tmp_layout
                            = split_block(eb, next_chunk, b.block / next_chunk);
                    return tmp_layout.split(tile, grid, outer_blocks);
                } else {
                    return tensor_t();
                }
            }
            dims[b.dim_idx] *= b.block;
        }
        return tensor_t(tile.dims(), start);
    }
942 | |
943 | // Iterates through tiles of the layout, calling `f` with relative offsets |
944 | // for each tile. The iteration order is defined by the layout blocks - |
945 | // absolute 1D offsets are increasing between callback calls. |
    template <typename F>
    void for_each_tile(const tensor_t &tile, const F &f) const {
        ir_assert(tile.ndims() == ndims());
        ir_assert(tile.has_zero_start());
        // Every dimension must be divisible into whole tiles.
        for (int i = 0; i < ndims(); i++) {
            ir_assert(dim(i) % tile.dims()[i] == 0);
        }

        // sub_blocks[i] will hold the number of tile-steps spanned by the
        // i-th layout block (i.e. the block size with the per-tile part
        // factored out). Start with the full block sizes.
        int nblocks = int(blocks().size());
        std::vector<dim_t> sub_blocks(nblocks);
        for (int i = 0; i < nblocks; i++)
            sub_blocks[i] = blocks()[i].block;

        // Factor the tile size of each dimension out of the innermost
        // blocks of that dimension.
        for (int i = 0; i < ndims(); i++) {
            dim_t dim = tile.dims()[i];
            for (auto &eb : enumerated_blocks()) {
                auto &b = eb.second;
                if (b.dim_idx != i) continue;
                int block_idx = eb.first;
                if (b.block >= dim) {
                    // Remaining tile size fits inside this block; keep only
                    // the outer (per-tile) factor.
                    ir_assert(b.block % dim == 0);
                    sub_blocks[block_idx] = b.block / dim;
                    break;
                }
                // Block lies fully inside one tile.
                sub_blocks[block_idx] = 1;
                ir_assert(dim % b.block == 0);
                dim /= b.block;
            }
        }

        int ntiles = int(elems() / tile.elems());

        // Mixed-radix counter over the per-block tile counts; one increment
        // per callback keeps absolute 1D offsets increasing.
        std::vector<dim_t> sub_block_idxs(nblocks);
        for (int i = 0; i < ntiles; i++) {
            // Convert sub-block indices to dimension indices.
            std::vector<dim_t> dims(ndims(), 1);
            std::vector<dim_t> start(ndims());
            for (int j = 0; j < nblocks; j++) {
                auto &b = blocks()[j];
                dim_t k = sub_block_idxs[j]
                        * (blocks()[j].block / sub_blocks[j]);
                start[b.dim_idx] += dims[b.dim_idx] * k;
                dims[b.dim_idx] *= b.block;
            }

            // Pass dimension offsets to the callback.
            f(start);

            // Move to the next vector of indices (with carry propagation).
            for (int j = 0; j < nblocks; j++) {
                auto &idx = sub_block_idxs[j];
                if (idx + 1 < sub_blocks[j]) {
                    idx++;
                    break;
                }
                idx = 0;
            }
        }
    }
1005 | |
1006 | // eb is <block index, block> pair, see enumerated_blocks(). |
1007 | static bool is_outermost(const std::pair<int, block_t> &eb, |
1008 | const std::vector<block_t> &blocks) { |
1009 | int dim_idx = eb.second.dim_idx; |
1010 | for (int i = 0; i < int(blocks.size()); i++) { |
1011 | if (blocks[i].dim_idx == dim_idx && i > eb.first) return false; |
1012 | } |
1013 | return true; |
1014 | } |
1015 | |
1016 | // Assume that layouts are normalized. |
1017 | static void align_layouts(layout_t &a, layout_t &b); |
1018 | |
1019 | static std::vector<block_t> normalize_blocks(int ndims, |
1020 | const std::vector<block_t> &blocks, |
1021 | bool remove_size_1_blocks = true) { |
1022 | auto new_blocks = blocks; |
1023 | |
1024 | // Remove blocks of size 1. |
1025 | if (remove_size_1_blocks) { |
1026 | for (auto it = new_blocks.begin(); it != new_blocks.end();) { |
1027 | if (it->block == 1) { |
1028 | it = new_blocks.erase(it); |
1029 | } else { |
1030 | ++it; |
1031 | } |
1032 | } |
1033 | } |
1034 | |
1035 | // Merge same dimension blocks. |
1036 | block_t prev_b; |
1037 | prev_b.dim_idx = -1; |
1038 | for (auto it = new_blocks.begin(); it != new_blocks.end();) { |
1039 | if (it->dim_idx == prev_b.dim_idx |
1040 | && it->stride == (prev_b.stride * prev_b.block)) { |
1041 | auto &b = *(it - 1); |
1042 | b.block *= it->block; |
1043 | prev_b = b; |
1044 | it = new_blocks.erase(it); |
1045 | } else { |
1046 | prev_b = *it; |
1047 | ++it; |
1048 | } |
1049 | } |
1050 | |
1051 | return new_blocks; |
1052 | } |
1053 | |
    // Reinterprets layouts to wider data type (up to 4 bytes).
    // Example: 16a16b (s8 type) -> 16a4b (s32 type)
    // The innermost blocks of `src` and `dst` must be over the same
    // dimension and have unit stride. The candidate type size is the common
    // divisor of both innermost block byte sizes, capped at 4 bytes, and is
    // halved until both tile divisibility and outer-stride divisibility
    // hold. When `do_update` is set the layouts are rewritten in place;
    // `new_size_out`, if non-null, receives the chosen type size in bytes.
    // Returns true iff a strictly wider reinterpretation was found.
    static bool try_reinterpret_to_wider_type(layout_t &src, layout_t &dst,
            const tensor_t &tile = {}, bool do_update = true,
            int *new_size_out = nullptr) {
        if (src.blocks().empty() || dst.blocks().empty()) return false;
        if (src.type() != dst.type()) return false;

        auto &s0 = src.blocks()[0];
        auto &d0 = dst.blocks()[0];
        // Innermost blocks must match in dimension and be dense.
        if (s0.dim_idx != d0.dim_idx) return false;
        if (int(s0.stride) != 1) return false;
        if (int(d0.stride) != 1) return false;

        int old_size = src.type().size();
        int s0_old_size = int(s0.block) * old_size;
        int d0_old_size = int(d0.block) * old_size;

        int new_size = math::gcd(s0_old_size, d0_old_size);
        new_size = math::gcd(new_size, 4); // Try types up to 4 bytes.
        if (new_size <= old_size) return false;

        // The tile must keep an integral number of new-type elements along
        // the innermost dimension.
        auto tile_ok = [&](const layout_t &l) {
            if (tile.is_empty()) return true;
            int factor = new_size / old_size;
            if (tile(l.blocks()[0].dim_idx) % factor != 0) return false;
            return true;
        };

        // All outer strides must stay integral in units of the new type.
        auto strides_ok = [&](const layout_t &l) {
            for (int i = 1; i < int(l.blocks().size()); i++) {
                auto &b = l.blocks()[i];
                if (int(b.stride) * old_size % new_size != 0) return false;
            }
            return true;
        };

        // Try progressively narrower candidate types until one fits.
        while (new_size > old_size) {
            bool ok = true;
            ok &= (tile_ok(src) && tile_ok(dst));
            ok &= (strides_ok(src) && strides_ok(dst));
            if (ok) {
                if (do_update) {
                    src = src.reinterpret(type_t::s(new_size * 8));
                    dst = dst.reinterpret(type_t::s(new_size * 8));
                }
                if (new_size_out) *new_size_out = new_size;
                return true;
            }
            new_size /= 2;
        }
        return false;
    }
1107 | |
1108 | private: |
1109 | // Returns vector of <dimension index, block size> pairs. |
1110 | static std::vector<std::pair<int, dim_t>> parse_format( |
1111 | const std::string &format, int ndims_hint); |
1112 | |
1113 | // Returns vector of <dimension letter, block size> pairs. |
1114 | static std::vector<std::pair<char, dim_t>> parse_letter_blocks( |
1115 | const std::string &format); |
1116 | |
1117 | void sanity_check() const; |
1118 | |
1119 | // Data type of the layout. |
1120 | type_t type_; |
1121 | |
1122 | // Number of dimensions. |
1123 | int ndims_; |
1124 | |
1125 | // Offset to the start of the layout (in elements of type). |
1126 | expr_t offset_; |
1127 | |
1128 | // Blocks ordered from innermost to outermost. |
1129 | std::vector<block_t> blocks_; |
1130 | }; |
1131 | |
// Helper class to incrementally increase a sub-layout of the given layout.
// One step - adding the minimal factor of the next remaining block. Used
// to find the minimal tile between two layouts that is innermost for both
// layouts.
class layout_iterator_t {
public:
    // block_idx_ points at the block currently being consumed; block_ is the
    // not-yet-consumed factor of that block (1 means fully consumed).
    layout_iterator_t(const layout_t &l) : l_(l), block_idx_(-1), block_(1) {}

    // Returns true if some factor of some block is still left to consume.
    bool has_next() const {
        dim_t b = block_;
        int b_idx = block_idx_;
        // Skip over fully consumed blocks.
        while (b == 1) {
            b_idx++;
            if (b_idx >= int(l_.blocks().size())) return false;
            b = int(l_.blocks()[b_idx].block);
        }
        return true;
    }

    // Consumes the smallest (prime) factor of the current block, advancing
    // to the next non-trivial block first if needed.
    layout_iterator_t &operator++() {
        ir_assert(has_next());
        while (block_ == 1) {
            block_idx_++;
            block_ = int(l_.blocks()[block_idx_].block);
        }
        // Find smallest factor.
        for (int factor = 2; factor <= int(block_); factor++) {
            if (block_ % factor == 0) {
                block_ /= factor;
                return *this;
            }
        }

        ir_error_not_expected();
        return *this;
    }

    // Returns the tile covered by the factors consumed so far.
    tensor_t tile() const {
        std::vector<dim_t> dims(l_.ndims(), 1);
        for (int i = 0; i <= block_idx_; i++) {
            auto &b = l_.blocks()[i];
            int b_block = b.block;
            // The current block may be only partially consumed.
            if (i == block_idx_) b_block /= block_;
            dims[b.dim_idx] *= b_block;
        }
        return tensor_t(dims);
    }

    // Number of blocks (fully or partially) consumed so far.
    int nblocks() const { return block_idx_ + 1; }

    // Returns the layout formed by the not-yet-consumed factors and blocks.
    layout_t outer_layout() const {
        auto &blocks = l_.blocks();
        std::vector<block_t> outer_blocks;
        if (block_ > 1) {
            // Partially consumed block: keep its remaining factor with an
            // adjusted stride.
            auto &b = blocks[block_idx_];
            outer_blocks.push_back(b);
            outer_blocks[0].block = block_;
            outer_blocks[0].stride = b.stride * (b.block / block_);
        }
        outer_blocks.insert(outer_blocks.end(),
                blocks.begin() + (block_idx_ + 1), blocks.end());
        return layout_t(l_.type(), l_.ndims(), l_.offset(), outer_blocks);
    }

private:
    const layout_t &l_;

    int block_idx_;
    dim_t block_;
};
1202 | |
1203 | inline std::ostream &operator<<(std::ostream &out, const layout_t &layout) { |
1204 | out << layout.str(); |
1205 | return out; |
1206 | } |
1207 | |
// Tensor of per-element boolean mask expressions over a dense layout.
// Mask expressions are interned: each unique expression gets an integer id
// (mask2ids_ / id2masks_), and masks_ stores one id per element (-1 = unset).
class mask_tensor_t {
public:
    mask_tensor_t() = default;

    // Creates a mask tensor over `layout` with all elements unset (id -1).
    mask_tensor_t(const layout_t &layout)
        : layout_(layout), masks_(layout.elems(), -1) {
        ir_assert(layout.is_dense());
    }

    // Creates a mask tensor from precomputed per-element ids and the
    // mask <-> id mappings.
    mask_tensor_t(const layout_t &layout, const std::vector<int> &masks,
            const object_eq_map_t<expr_t, int> &mask2ids,
            const std::vector<expr_t> &id2masks)
        : layout_(layout)
        , masks_(masks)
        , mask2ids_(mask2ids)
        , id2masks_(id2masks) {
        ir_assert(int(masks.size()) == elems()) << "Incompatible size." ;
    }

    const type_t &type() const { return layout_.type(); }

    const layout_t &layout() const { return layout_; }

    dim_t elems() const { return layout_.elems(); }

    // Assigns `mask` to the element at 1D offset `off`, interning the
    // expression so that equal masks share one id.
    void set_mask(dim_t off, const expr_t &mask) {
        ir_assert(0 <= off && off < elems()) << "Incorrect offset." ;
        if (mask.is_empty()) return;

        auto ret = mask2ids_.insert({mask, int(mask2ids_.size())});
        int id = ret.first->second;
        masks_[off] = id;

        // New unique mask: remember it for id -> mask lookups.
        if (ret.second) id2masks_.push_back(mask);
    }

    // Returns the mask expression of the element at 1D offset `off`.
    const expr_t &mask(dim_t off) const {
        ir_assert(0 <= off && off < elems());
        return id2masks_[masks_[off]];
    }

    // Simplifies all unique mask expressions under `cset`, then
    // re-deduplicates ids (simplification may make previously distinct
    // masks equal).
    void simplify(const constraint_set_t &cset) {
        for (auto &mask : id2masks_) {
            auto new_mask = jit::simplify(mask, cset);
            // Some complex expressions need more than one simplify() call.
            int max_tries = 5;
            for (int i = 0; i < max_tries; i++) {
                mask = new_mask;
                new_mask = jit::simplify(new_mask, cset);
                if (new_mask.is_equal(mask)) break;
            }
        }
        // Rebuild the mask -> id map; remap elements whose mask collapsed
        // into an already-seen expression.
        mask2ids_.clear();
        for (int i = 0; i < int(id2masks_.size()); i++) {
            auto ret = mask2ids_.insert({id2masks_[i], i});
            if (!ret.second) {
                for (auto &m : masks_)
                    if (m == i) m = ret.first->second;
            }
        }
    }

    // Returns the sub-tensor of masks corresponding to `tile`; offsets in
    // the result are relative to the tile start.
    mask_tensor_t map(const tensor_t &tile) const {
        auto tile_start = expr_cast<dim_t>(tile.start());
        auto sub_layout = layout_.map(tensor_t(tile.dims()));
        mask_tensor_t sub_mask(sub_layout);
        ir_utils::for_each(
                tile.dims(), [&](const std::vector<dim_t> &sub_start) {
                    dim_t sub_off = sub_layout(sub_start);
                    dim_t off = layout_(tile_start) + layout_(sub_start);
                    sub_mask.set_mask(sub_off, mask(off));
                });
        return sub_mask;
    }

    // Reinterprets the mask tensor to `new_type`, collapsing each group of
    // old-type elements that forms one new-type element into a single mask.
    // Returns an empty tensor if elements within one group have different
    // masks (the reinterpretation would lose information).
    mask_tensor_t reinterpret(const type_t &new_type) const {
        ir_assert(!is_empty()) << "Can't reinterpret." ;
        dim_t bytes = elems() * type().size();
        if (bytes % new_type.size() != 0 && bytes > new_type.size())
            return mask_tensor_t();
        int new_mask_size = std::max((int)(bytes / new_type.size()), 1);
        std::vector<int> new_masks(new_mask_size);
        for (dim_t i = 0; i < bytes; i += new_type.size()) {
            // INT_MAX serves as the "not yet set" sentinel below.
            int mask_id = std::numeric_limits<int>::max();
            for (int j = 0; j < new_type.size() && j < bytes; j++) {
                int cur_mask_id = masks_[(i + j) / type().size()];
                // NOTE(review): the sentinel check compares against
                // masks_.size() rather than id2masks_.size(); it works as
                // an "unset" test because INT_MAX always exceeds the
                // element count, but the intended invariant is worth
                // confirming.
                if (mask_id >= int(masks_.size())) {
                    mask_id = cur_mask_id;
                } else if (mask_id != cur_mask_id) {
                    // Mask is not consistent, can't reinterpret.
                    return mask_tensor_t();
                }
            }
            ir_assert(0 <= mask_id && mask_id < int(masks_.size()));
            new_masks[i / new_type.size()] = mask_id;
        }
        dim_t new_elems = utils::div_up(bytes, new_type.size());
        layout_t _1d_layout(new_type, 0, std::vector<dim_t> {new_elems});
        return mask_tensor_t(_1d_layout, new_masks, mask2ids_, id2masks_);
    }

    // Folds the per-element masks into one shuffle expression of `nmasks`
    // channels; element i maps to channel i % nmasks. Returns an empty
    // expression if elements mapped to the same channel disagree.
    expr_t to_expr(int nmasks) const {
        if (elems() % nmasks != 0) return expr_t();

        std::vector<expr_t> vec(nmasks);
        for (int i = 0; i < elems(); i++) {
            auto &channel_mask = vec[i % nmasks];
            auto &cur_mask = id2masks_[masks_[i]];
            if (channel_mask.is_empty()) {
                channel_mask = cur_mask;
                continue;
            }
            if (!channel_mask.is_equal(cur_mask)) return expr_t();
        }
        auto e = shuffle_t::make(vec);
        e = jit::simplify(e);
        e = jit::simplify_propagate_shuffle(e);
        return e;
    }

    bool is_empty() const { return layout_.is_empty(); }

    // Returns a one-line-per-element text dump of the mask tensor.
    std::string str() const {
        std::ostringstream oss;
        for (int i = 0; i < int(elems()); i++) {
            if (i != 0) oss << std::endl;
            oss << "mask #" << i << ": " ;
            if (masks_[i] == -1) {
                oss << "(nil)" ;
            } else {
                oss << id2masks_[masks_[i]];
            }
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    layout_t layout_;
    // Per-element mask ids; -1 means no mask was set for that element.
    std::vector<int> masks_;

    // Interning maps between mask expressions and their integer ids.
    object_eq_map_t<expr_t, int> mask2ids_;
    std::vector<expr_t> id2masks_;
};
1353 | |
1354 | inline std::ostream &operator<<( |
1355 | std::ostream &out, const mask_tensor_t &mask_tensor) { |
1356 | out << mask_tensor.str(); |
1357 | return out; |
1358 | } |
1359 | |
// Describes how a single tensor dimension of a view is computed: an
// expression over view variables (expr_), an optional bound mask (mask_,
// where placeholder_var() stands for the tensor index), and, for up to
// max_nvargs participating view variables, their indices and strides.
class tdim_info_t {
public:
    tdim_info_t() = default;

    tdim_info_t(const expr_t &expr, const expr_t &mask)
        : expr_(expr), mask_(mask) {}

    // Number of view variables participating in the expression.
    int nvargs() const { return nvargs_; }

    const expr_t &expr() const { return expr_; }

    const expr_t &mask() const { return mask_; }

    void set_mask(const expr_t &value) { mask_ = value; }

    // Instantiates the mask: substitutes `tvalue` for the placeholder and
    // `vvalues` for any view variables the mask references.
    expr_t mask(const expr_t &tvalue, const std::vector<expr_t> &vvars,
            const std::vector<expr_t> &vvalues) const {
        auto ret = substitute(mask_, placeholder_var(), tvalue);
        for (int i = 0; i < int(vvars.size()); i++) {
            if (contains_object(ret, vvars[i])) {
                ret = substitute(ret, vvars[i], vvalues[i]);
            }
        }
        return ret;
    }

    // View dimension index of the arg_idx-th participating variable.
    int vidx(int arg_idx) const {
        ir_assert(arg_idx < nvargs());
        return vidxs_[arg_idx];
    }

    // Stride of the arg_idx-th participating variable in the expression.
    stride_t vstride(int arg_idx) const {
        ir_assert(arg_idx < nvargs());
        return vstrides_[arg_idx];
    }

    bool is_empty() const { return expr_.is_empty(); }

    // True when the dimension is a plain view variable (no arithmetic).
    bool is_identity() const { return is_var(expr_); }

    bool is_fixed_stride(int arg_idx) const {
        ir_assert(arg_idx < nvargs());
        return vstrides_[arg_idx].is_fixed();
    }

    // Registers view variable `varg` (view dimension `vidx`) as a
    // participant of this tensor dimension and computes its stride.
    void add_vvar(int vidx, const expr_t &varg) {
        ir_assert(nvargs_ + 1 <= max_nvargs);
        vidxs_[nvargs_] = vidx;
        vstrides_[nvargs_] = compute_stride(expr_, nvargs_, varg);
        nvargs_++;
    }

    // Placeholder variable used inside mask expressions to stand for the
    // tensor index; thread_local so expression identity is per thread.
    static const expr_t &placeholder_var() {
        static thread_local expr_t ph_var = var_t::make(type_t::s32(), "_ph" );
        return ph_var;
    }

private:
    // At most this many view variables may map to one tensor dimension.
    static const int max_nvargs = 2;

    static stride_t compute_stride(const expr_t &e, int idx, const expr_t &var);

    expr_t expr_;

    int nvargs_ = 0;
    std::array<stride_t, max_nvargs> vstrides_;
    std::array<int, max_nvargs> vidxs_;
    expr_t mask_;
};
1429 | |
1430 | class view_t { |
1431 | public: |
1432 | view_t() = default; |
1433 | |
1434 | view_t(const std::vector<expr_t> &vvars, int ntdims) |
1435 | : vvars_(vvars) |
1436 | , vdims_(vvars.size()) |
1437 | , vstart_(vvars.size()) |
1438 | , tdims_(ntdims) {} |
1439 | |
1440 | // Constructs view from a layout. |
1441 | explicit view_t(const layout_t &layout, |
1442 | const std::vector<expr_t> &_vvars = {}, |
1443 | uint32_t bound_check_mask = 0) |
1444 | : view_t(layout, _vvars, layout.dims(), bound_check_mask) {} |
1445 | |
1446 | view_t(const layout_t &layout, const std::vector<expr_t> &_vvars, |
1447 | const std::vector<dim_t> &_vdims, uint32_t bound_check_mask) |
1448 | : vvars_(_vvars) |
1449 | , vdims_(_vdims) |
1450 | , vstart_(layout.ndims(), 0) |
1451 | , tdims_(layout.ndims()) |
1452 | , tlayout_(layout) { |
1453 | if (vvars_.empty()) vvars_ = create_vvars(layout.ndims()); |
1454 | for (int i = 0; i < nvdims(); i++) { |
1455 | expr_t i_mask; |
1456 | if ((bound_check_mask & (1 << i)) != 0) |
1457 | i_mask = (placeholder_var() < layout.dim(i)); |
1458 | set_tdim(i, vvars_[i], i_mask); |
1459 | } |
1460 | } |
1461 | |
1462 | const std::vector<expr_t> &vvars() const { return vvars_; } |
1463 | |
1464 | const std::vector<dim_t> &vdims() const { return vdims_; } |
1465 | |
1466 | std::vector<expr_t> vstart() const { return vstart_; } |
1467 | |
1468 | expr_t vstart(int vidx) const { return vstart_[vidx]; } |
1469 | |
1470 | const layout_t &tlayout() const { return tlayout_; } |
1471 | |
1472 | int nvdims() const { return int(vdims_.size()); } |
1473 | |
1474 | int ntdims() const { return int(tdims_.size()); } |
1475 | |
1476 | dim_t velems() const { |
1477 | dim_t ret = 1; |
1478 | for (int i = 0; i < nvdims(); i++) |
1479 | ret *= vdims_[i]; |
1480 | return ret; |
1481 | } |
1482 | |
1483 | const expr_t &vvar(int idx) const { |
1484 | ir_assert(idx < nvdims()); |
1485 | return vvars_[idx]; |
1486 | } |
1487 | |
1488 | const tdim_info_t &tdim(int idx) const { |
1489 | ir_assert(idx < ntdims()); |
1490 | return tdims_[idx]; |
1491 | } |
1492 | |
1493 | void set_tdim(int tidx, const expr_t &_texpr, expr_t mask = {}) { |
1494 | ir_assert(tdims_[tidx].is_empty()); |
1495 | |
1496 | auto texpr = simplify(_texpr); |
1497 | |
1498 | tdim_info_t tdim(texpr, mask); |
1499 | for (int i = 0; i < nvdims(); i++) { |
1500 | if (contains_object(texpr, vvars_[i])) tdim.add_vvar(i, vvars_[i]); |
1501 | } |
1502 | if (!is_const(texpr)) { |
1503 | ir_assert(tdim.nvargs() > 0) |
1504 | << "Tensor dimension must have at least one view dimension " |
1505 | "that maps to it." ; |
1506 | } |
1507 | tdims_[tidx] = tdim; |
1508 | } |
1509 | |
1510 | void set_vdim( |
1511 | const expr_t &varg, dim_t vdim, const expr_t &vstart = expr_t(0)) { |
1512 | int vidx = vvar_index(varg); |
1513 | ir_assert(vstart_[vidx].is_empty()); |
1514 | vstart_[vidx] = vstart; |
1515 | vdims_[vidx] = vdim; |
1516 | } |
1517 | |
1518 | void set_tlayout(const layout_t &tlayout) { tlayout_ = tlayout; } |
1519 | |
1520 | void set_tmasks(const std::unordered_map<std::string, int> &padded_dims, |
1521 | const std::unordered_map<std::string, int> &dim_blocks) { |
1522 | using namespace ir_utils; |
1523 | auto &x = placeholder_var(); |
1524 | for (int i = 0; i < ntdims(); i++) { |
1525 | auto &tdim = tdims_[i]; |
1526 | if (!tdim.is_identity() || !tdim.mask().is_empty()) continue; |
1527 | int vidx = tdim.vidx(0); |
1528 | int dim = tlayout_.dim(i); |
1529 | auto &dim_name = vvars_[vidx].as<var_t>().name; |
1530 | int padded_dim = get_or_default(padded_dims, dim_name, 1); |
1531 | if (dim >= padded_dim) continue; |
1532 | int inner_blk = ir_utils::max_pow2_divisor(dim); |
1533 | int dim_blk = get_or_default(dim_blocks, dim_name, 1); |
1534 | if (math::is_pow2(dim_blk)) { |
1535 | inner_blk = std::min(inner_blk, dim_blk); |
1536 | } |
1537 | auto tmask = (inner_blk == 1) ? (x < dim) |
1538 | : (x / inner_blk < dim / inner_blk); |
1539 | tdim.set_mask(tmask); |
1540 | } |
1541 | } |
1542 | |
1543 | std::string str() const { |
1544 | using ir_utils::operator<<; |
1545 | |
1546 | if (is_empty()) return "(nil)" ; |
1547 | std::ostringstream oss; |
1548 | oss << ir_utils::make_seq_print_helper(vdims_, "x" ); |
1549 | if (!has_zero_vstart()) oss << " vstart: [" << vstart_ << "]" ; |
1550 | oss << " tlayout: " << tlayout_; |
1551 | return oss.str(); |
1552 | } |
1553 | |
1554 | IR_DEFINE_DUMP() |
1555 | |
1556 | bool is_empty() const { return vdims_.empty(); } |
1557 | |
1558 | bool has_zero_vstart() const { |
1559 | for (int i = 0; i < nvdims(); i++) |
1560 | if (!is_zero(vstart_[i])) return false; |
1561 | return true; |
1562 | } |
1563 | |
1564 | bool has_tmask(int tidx) const { |
1565 | ir_assert(tidx >= 0 && tidx < ntdims()); |
1566 | return !tdims_[tidx].mask().is_empty(); |
1567 | } |
1568 | |
1569 | const type_t &type() const { return tlayout_.type(); } |
1570 | |
1571 | expr_t offset(const std::vector<expr_t> &vargs = {}, |
1572 | bool ignore_offset = false) const { |
1573 | auto targs = cvt_vargs_to_targs(vargs); |
1574 | return tlayout_.offset(targs, ignore_offset); |
1575 | } |
1576 | |
1577 | expr_t offset_in_bytes(const std::vector<expr_t> &vargs = {}, |
1578 | bool ignore_offset = false) const { |
1579 | return offset(vargs, ignore_offset) * type().size(); |
1580 | } |
1581 | |
1582 | int get_alignment(const constraint_set_t &cset) const { |
1583 | // Alignment must be a power of 2. |
1584 | const int base_alignment = 128; |
1585 | int64_t f = get_max_const_factor(this->offset_in_bytes(), cset); |
1586 | int alignment = f ? ir_utils::max_pow2_divisor(f) : base_alignment; |
1587 | return std::min(base_alignment, alignment); |
1588 | } |
1589 | |
1590 | int vvar_index(const expr_t &vvar) const { |
1591 | for (size_t i = 0; i < vvars_.size(); i++) |
1592 | if (vvar.is_same(vvars_[i])) return int(i); |
1593 | ir_error_not_expected() << "Can't find view dimension." ; |
1594 | return -1; |
1595 | } |
1596 | |
1597 | template <typename T> |
1598 | T operator()(const std::vector<T> &vargs) const { |
1599 | auto targs = cvt_vargs_to_targs(vargs); |
1600 | return tlayout_(targs); |
1601 | } |
1602 | |
1603 | view_t create_sub_view(const tensor_t &sub_tensor) const; |
1604 | |
1605 | view_t retype(const type_t &new_type) const { |
1606 | auto ret = *this; |
1607 | ret.tlayout_ = tlayout_.retype(new_type); |
1608 | return ret; |
1609 | } |
1610 | |
1611 | view_t make_dense() const { |
1612 | auto ret = *this; |
1613 | ret.tlayout_ = tlayout_.make_dense(); |
1614 | return ret; |
1615 | } |
1616 | |
1617 | bool is_masked_vdim(int vidx) const { |
1618 | ir_assert(vidx >= 0 && vidx < nvdims()); |
1619 | ir_assert(has_zero_vstart()) |
1620 | << "Can't be reliably determined if the view is a sub-view." ; |
1621 | for (int i = 0; i < ntdims(); i++) { |
1622 | auto &tdim = tdims_[i]; |
1623 | if (tdim.expr().is_equal(vvars_[vidx])) { |
1624 | if (vdims_[vidx] != tlayout_.dim(i)) return true; |
1625 | } |
1626 | if (has_tmask(i)) { |
1627 | for (int j = 0; j < tdim.nvargs(); j++) { |
1628 | if (tdim.vidx(j) == vidx) return true; |
1629 | } |
1630 | } |
1631 | } |
1632 | return false; |
1633 | } |
1634 | |
1635 | // Returns the mask corresponding to `vargs` view indices. The mask is |
1636 | // based on: |
1637 | // 1) combined tensor masks for the given indices |
1638 | // 2) Bounds-based masks for those view dimensions that are used directly |
1639 | // in the tensor |
1640 | // - Example: 32a layout when 'a' dimension is A < 32. In general it's |
1641 | // fine to load/store elements with indices in the range [A, 31] |
1642 | // assuming the zero padding invariant. However in some cases we need |
1643 | // to generate the exact bound condition based on the logical indices. |
1644 | expr_t vmask(const std::vector<expr_t> &vargs) const { |
1645 | ir_assert(int(vargs.size()) == nvdims()) << "Incompatible dimensions." ; |
1646 | ir_assert(has_zero_vstart()) |
1647 | << "Can't be reliably determined if the view is a sub-view." ; |
1648 | auto targs = cvt_vargs_to_targs(vargs); |
1649 | auto mask = bool_imm_t::make(true); |
1650 | for (int i = 0; i < ntdims(); i++) { |
1651 | for (int j = 0; j < nvdims(); j++) { |
1652 | if (!tdims_[i].expr().is_equal(vvars_[j])) continue; |
1653 | if (vdims_[j] != tlayout_.dim(i)) { |
1654 | mask &= (vargs[j] < vdims_[j]); |
1655 | } |
1656 | } |
1657 | if (has_tmask(i)) { |
1658 | auto &tdim = tdims_[i]; |
1659 | mask &= tdim.mask(targs[i], vvars_, vargs); |
1660 | } |
1661 | } |
1662 | return mask; |
1663 | } |
1664 | |
1665 | bool can_convert_to_vlayout() const { |
1666 | if (nvdims() != ntdims()) return false; |
1667 | for (int i = 0; i < nvdims(); i++) { |
1668 | if (!tdims_[i].expr().is_same(vvars_[i])) return false; |
1669 | if (!tdims_[i].is_fixed_stride(0)) return false; |
1670 | } |
1671 | return true; |
1672 | } |
1673 | |
1674 | // FIXME: Offset of the returned layout is always 0. |
1675 | layout_t create_pseudo_vlayout() const { |
1676 | return create_pseudo_vlayout(normalized_tlayout()); |
1677 | } |
1678 | |
1679 | layout_t normalized_tlayout() const { |
1680 | auto blocks = move_size_1_blocks_outer(); |
1681 | blocks = layout_t::normalize_blocks(tlayout_.ndims(), blocks, false); |
1682 | auto layout |
1683 | = layout_t(type(), tlayout_.ndims(), offset(), blocks, false); |
1684 | return layout; |
1685 | } |
1686 | |
1687 | layout_t create_dense_vlayout() const { |
1688 | return create_pseudo_vlayout().make_dense(); |
1689 | } |
1690 | |
1691 | layout_t create_vlayout(bool force_zero_offset = false) const { |
1692 | ir_assert(can_convert_to_vlayout()) << "Can't convert view to layout." ; |
1693 | if (force_zero_offset) return tlayout_.map(tensor_t(vdims_)); |
1694 | return tlayout_.map(tensor_t(vdims_, vstart_)); |
1695 | } |
1696 | |
1697 | dim_t vlayout_size() const { return create_vlayout().size(); } |
1698 | |
1699 | bool has_same_vlayout( |
1700 | const view_t &other, bool compare_offset = true) const { |
1701 | return create_vlayout().is_equal( |
1702 | other.create_vlayout(), compare_offset); |
1703 | } |
1704 | |
1705 | view_t split(const grid_info_t &grid, tensor_t &vtile) const { |
1706 | auto vlayout = create_pseudo_vlayout(); |
1707 | vtile = vlayout.split(grid); |
1708 | return create_sub_view(vtile); |
1709 | } |
1710 | |
1711 | view_t split(const grid_info_t &grid) const { |
1712 | tensor_t vtile; |
1713 | return split(grid, vtile); |
1714 | } |
1715 | |
1716 | // Returns a tensor corresponding to the biggest innermost sub-layout so that |
1717 | // 1) It consists of consecutive blocks only. |
1718 | // 2) It contains less or equal than max_tile_elems elements. |
1719 | // 3) It is dense if is_dense_tile is true. |
1720 | tensor_t split_into_max_tile( |
1721 | dim_t max_tile_elems, bool is_dense_tile) const { |
1722 | auto vlayout = create_pseudo_vlayout(); |
1723 | return vlayout.split_into_max_tile(max_tile_elems, is_dense_tile); |
1724 | } |
1725 | |
1726 | template <typename F> |
1727 | void for_each_tile(const tensor_t &tile, const F &f) const { |
1728 | auto vlayout = create_dense_vlayout(); |
1729 | vlayout.for_each_tile(tile, f); |
1730 | } |
1731 | |
1732 | view_t substitute(const expr_t &from, const expr_t &to) const; |
1733 | |
1734 | mask_tensor_t create_mask_tensor( |
1735 | const constraint_set_t &cset, uint32_t tmask = 0xFFFFFFFF) const { |
1736 | auto _vlayout = create_dense_vlayout(); |
1737 | mask_tensor_t mask_tensor(_vlayout); |
1738 | std::vector<dim_t> vargs(nvdims()); |
1739 | create_mask_tensor(mask_tensor, _vlayout, 0, vargs, tmask); |
1740 | mask_tensor.simplify(cset); |
1741 | return mask_tensor; |
1742 | } |
1743 | |
1744 | void try_create_buffer_view(view_t &buf_view, view_t &inv_view) const { |
1745 | buf_view = view_t(create_vvars(ntdims()), ntdims()); |
1746 | inv_view = view_t(vvars(), ntdims()); |
1747 | for (int i = 0; i < nvdims(); i++) { |
1748 | inv_view.set_vdim(vvars()[i], vdims()[i]); |
1749 | } |
1750 | for (int i = 0; i < ntdims(); i++) { |
1751 | auto &tdim = tdims_[i]; |
1752 | auto &buf_vvar = buf_view.vvars()[i]; |
1753 | if (tdim.is_identity()) { |
1754 | int vidx = tdim.vidx(0); |
1755 | buf_view.set_vdim(buf_vvar, vdims()[vidx], vstart(vidx)); |
1756 | buf_view.set_tdim(i, buf_vvar, tdim.mask()); |
1757 | inv_view.set_tdim(i, tdim.expr()); |
1758 | continue; |
1759 | } |
1760 | int buf_vdim = 0; |
1761 | bool ok = true; |
1762 | for (int j = 0; j < tdim.nvargs(); j++) { |
1763 | int vidx = tdim.vidx(j); |
1764 | auto &vvar = vvars()[vidx]; |
1765 | int vdim = vdims()[vidx]; |
1766 | if (vdim == 1) continue; |
1767 | auto A = tdim.expr(); |
1768 | auto B = jit::substitute(A, vvar, vvar + 1); |
1769 | auto C = simplify(B - A); |
1770 | if (!is_const(C)) { |
1771 | ok = false; |
1772 | break; |
1773 | } |
1774 | buf_vdim += to_cpp<int>(C) * (vdim - 1); |
1775 | } |
1776 | buf_vdim++; |
1777 | |
1778 | if (!ok) { |
1779 | buf_view = view_t(); |
1780 | inv_view = view_t(); |
1781 | return; |
1782 | } |
1783 | |
1784 | auto buf_vstart = tdim.expr(); |
1785 | auto inv_vstart = tdim.expr(); |
1786 | for (int j = 0; j < tdim.nvargs(); j++) { |
1787 | int vidx = tdim.vidx(j); |
1788 | buf_vstart = jit::substitute( |
1789 | buf_vstart, vvars()[vidx], vstart(vidx)); |
1790 | inv_vstart |
1791 | = jit::substitute(inv_vstart, vvars()[vidx], expr_t(0)); |
1792 | } |
1793 | buf_vstart = simplify(buf_vstart); |
1794 | inv_vstart = simplify(inv_vstart); |
1795 | |
1796 | if (!is_const(inv_vstart)) { |
1797 | buf_view = view_t(); |
1798 | inv_view = view_t(); |
1799 | return; |
1800 | } |
1801 | |
1802 | buf_view.set_vdim(buf_vvar, buf_vdim, buf_vstart); |
1803 | |
1804 | // Check that mask doesn't contain vvars - they can't be accessed |
1805 | // in the buffered view. |
1806 | auto &tmask = tdim.mask(); |
1807 | for (auto &vvar : vvars()) { |
1808 | if (contains_object(tmask, vvar)) { |
1809 | buf_view = view_t(); |
1810 | inv_view = view_t(); |
1811 | return; |
1812 | } |
1813 | } |
1814 | |
1815 | buf_view.set_tdim(i, buf_vvar, tmask); |
1816 | inv_view.set_tdim(i, tdim.expr() - inv_vstart); |
1817 | } |
1818 | buf_view.set_tlayout(tlayout_); |
1819 | } |
1820 | |
1821 | static const expr_t &placeholder_var() { |
1822 | return tdim_info_t::placeholder_var(); |
1823 | } |
1824 | |
1825 | static std::vector<expr_t> create_vvars(int nvdims); |
1826 | |
1827 | template <typename SrcT = expr_t, typename DstT = SrcT> |
1828 | std::vector<DstT> cvt_vargs_to_targs(const std::vector<SrcT> &_vargs = {}, |
1829 | bool ignore_vstart = false) const { |
1830 | std::vector<expr_t> vargs = expr_cast<expr_t>(_vargs); |
1831 | if (vargs.empty()) vargs.resize(nvdims(), 0); |
1832 | |
1833 | if (!ignore_vstart) { |
1834 | for (int i = 0; i < nvdims(); i++) { |
1835 | if (!is_zero(vstart_[i])) vargs[i] += vstart_[i]; |
1836 | } |
1837 | } |
1838 | |
1839 | std::vector<expr_t> targs(ntdims()); |
1840 | for (int i = 0; i < ntdims(); i++) { |
1841 | targs[i] = tdims_[i].expr(); |
1842 | for (int j = 0; j < nvdims(); j++) { |
1843 | targs[i] = jit::substitute(targs[i], vvars_[j], vargs[j]); |
1844 | } |
1845 | } |
1846 | for (int i = 0; i < ntdims(); i++) { |
1847 | targs[i] = const_fold(targs[i]); |
1848 | } |
1849 | return expr_cast<DstT>(targs); |
1850 | } |
1851 | |
1852 | private: |
1853 | layout_t create_pseudo_vlayout(const layout_t &tlayout) const; |
1854 | |
1855 | void create_mask_tensor(mask_tensor_t &mask_tensor, |
1856 | const layout_t &_vlayout, int vidx, std::vector<dim_t> &vargs, |
1857 | uint32_t tmask) const { |
1858 | if (vidx == _vlayout.ndims()) { |
1859 | bool is_init = false; |
1860 | std::vector<expr_t> vvalues; |
1861 | std::vector<expr_t> targs; |
1862 | expr_t mask = bool_imm_t::make(true); |
1863 | for (int i = 0; i < ntdims(); i++) { |
1864 | auto &tdim = tdims_[i]; |
1865 | if ((tmask & (1 << i)) == 0) continue; |
1866 | if (tdim.mask().is_empty()) continue; |
1867 | if (!is_init) { |
1868 | // Lazily initialize values |
1869 | vvalues = vstart_; |
1870 | for (int i = 0; i < nvdims(); i++) |
1871 | vvalues[i] += vargs[i]; |
1872 | targs = cvt_vargs_to_targs<dim_t, expr_t>(vargs); |
1873 | is_init = true; |
1874 | } |
1875 | mask &= tdim.mask(targs[i], vvars_, vvalues); |
1876 | } |
1877 | mask_tensor.set_mask(_vlayout(vargs), mask); |
1878 | return; |
1879 | } |
1880 | |
1881 | for (int i = 0; i < vdims()[vidx]; i++) { |
1882 | vargs[vidx] = i; |
1883 | create_mask_tensor(mask_tensor, _vlayout, vidx + 1, vargs, tmask); |
1884 | } |
1885 | } |
1886 | |
1887 | std::vector<block_t> move_size_1_blocks_outer() const { |
1888 | std::vector<block_t> new_blocks; |
1889 | std::vector<block_t> size_1_blocks; |
1890 | for (auto &b : tlayout_.blocks()) { |
1891 | if (b.block == 1 && vdims_[b.dim_idx] == 1) { |
1892 | size_1_blocks.emplace_back(b); |
1893 | } else { |
1894 | new_blocks.emplace_back(b); |
1895 | } |
1896 | } |
1897 | stride_t stride = new_blocks.empty() |
1898 | ? stride_t(1) |
1899 | : new_blocks.back().block * new_blocks.back().stride; |
1900 | for (auto &b : size_1_blocks) { |
1901 | b.stride = stride; |
1902 | new_blocks.emplace_back(b); |
1903 | } |
1904 | return new_blocks; |
1905 | } |
1906 | |
    std::vector<expr_t> vvars_; // Per-vdim index variables.
    std::vector<dim_t> vdims_; // View (virtual) dimension sizes.
    std::vector<expr_t> vstart_; // Per-vdim start offsets of the view.

    std::vector<tdim_info_t> tdims_; // Tensor dims expressed via vvars_.
    layout_t tlayout_; // Underlying tensor layout.
1913 | }; |
1914 | |
1915 | inline std::ostream &operator<<(std::ostream &out, const view_t &view) { |
1916 | out << view.str(); |
1917 | return out; |
1918 | } |
1919 | |
1920 | class dim_assignment_t { |
1921 | public: |
1922 | dim_assignment_t() = default; |
1923 | |
1924 | dim_assignment_t(int old_ndims, int new_ndims) |
1925 | : old_ndims_(old_ndims) |
1926 | , new_ndims_(new_ndims) |
1927 | , assignments_(old_ndims, -1) {} |
1928 | |
1929 | void assign(int old_idx, int new_idx) { |
1930 | ir_assert(0 <= old_idx && old_idx < old_ndims_); |
1931 | ir_assert(0 <= new_idx && new_idx < new_ndims_); |
1932 | assignments_[old_idx] = new_idx; |
1933 | } |
1934 | |
1935 | void assign(const std::vector<int> &old_idxes, int new_idx) { |
1936 | for (auto old_idx : old_idxes) { |
1937 | assign(old_idx, new_idx); |
1938 | } |
1939 | } |
1940 | |
1941 | int operator[](int old_idx) const { |
1942 | ir_assert(old_idx >= 0 && old_idx < old_ndims()); |
1943 | return assignments_[old_idx]; |
1944 | } |
1945 | |
1946 | int old_ndims() const { return old_ndims_; } |
1947 | |
1948 | int new_ndims() const { return new_ndims_; } |
1949 | |
1950 | bool is_empty() const { return old_ndims_ == 0 && new_ndims_ == 0; } |
1951 | |
1952 | layout_t map(const layout_t &layout) const; |
1953 | |
1954 | private: |
1955 | int old_ndims_ = 0; |
1956 | int new_ndims_ = 0; |
1957 | |
1958 | // assignments_[old_idx] = new_idx. |
1959 | std::vector<int> assignments_; |
1960 | }; |
1961 | |
1962 | } // namespace jit |
1963 | } // namespace gpu |
1964 | } // namespace impl |
1965 | } // namespace dnnl |
1966 | |
1967 | #endif |
1968 | |