1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/lang/data_layout.cc
22 * \brief Data Layout expression.
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/data_layout.h>
27#include <tvm/tir/stmt_functor.h>
28
29#include <cctype>
30
31namespace tvm {
32namespace tir {
// Local aliases for the IR types used below. The file is already inside
// namespace tir, so these using-declarations are only for readability.
using tir::IterVar;
using tir::IterVarNode;
using tir::Var;

// Register the layout node types via TVM_REGISTER_NODE_TYPE.
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode);
39
// Canonical LayoutAxis objects for the upper-case axis letters 'A'..'Z'.
// LayoutAxis::Get(char) returns references into this table, indexed by
// (name - 'A'), so the element order must stay alphabetical.
const LayoutAxis LayoutAxis::UPPER_CASE[] = {
    LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'),
    LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'),
    LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'),
    LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'),
    LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'),
    LayoutAxis('Z')};

// Canonical LayoutAxis objects for the lower-case axis letters 'a'..'z'
// (the letters that carry a numeric factor in layout strings such as
// "NCHW16c"). Indexed by (name - 'a') in LayoutAxis::Get(char).
const LayoutAxis LayoutAxis::LOWER_CASE[] = {
    LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'),
    LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'),
    LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'),
    LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'),
    LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'),
    LayoutAxis('z')};
55
56const LayoutAxis& LayoutAxis::Get(const char name) {
57 ICHECK((name >= 'A' && name <= 'Z') || (name >= 'a' && name <= 'z'))
58 << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z.";
59 return (name >= 'A' && name <= 'Z') ? LayoutAxis::UPPER_CASE[name - 'A']
60 : LayoutAxis::LOWER_CASE[name - 'a'];
61}
62
63const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) {
64 const std::string axis = itvar->var.get()->name_hint;
65 ICHECK_EQ(axis.size(), 1) << "Invalid layout axis " << axis;
66 return LayoutAxis::Get(axis[0]);
67}
68
69const LayoutAxis& LayoutAxis::Get(const std::string& name) {
70 ICHECK_EQ(name.length(), 1) << "Invalid axis " << name;
71 return LayoutAxis::Get(name[0]);
72}
73
74Layout::Layout(const Array<IterVar>& axes) {
75 auto node = make_object<LayoutNode>();
76 node->axes = axes;
77 std::ostringstream repr;
78 for (const IterVar& axis : axes) {
79 if (const auto* factor = axis->dom->extent.as<IntImmNode>()) {
80 ICHECK_GT(factor->value, 0);
81 repr << factor->value;
82 }
83 ICHECK_EQ(axis->var.get()->name_hint.size(), 1)
84 << "Invalid layout axis " << axis->var.get()->name_hint;
85 char c = axis->var.get()->name_hint.operator std::string()[0];
86 ICHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c;
87 repr << axis->var.get()->name_hint;
88 }
89 node->name = repr.str();
90 data_ = std::move(node);
91}
92
93Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*)
94 CHECK(dtype.is_int()) << "TypeError: The input dtype should be integer type";
95 if (name == "__undef__") return;
96
97 auto node = make_object<LayoutNode>();
98 node->name = name;
99
100 if (name.empty()) return; // scalar
101
102 // parse layout string
103 int32_t factor = 0;
104 for (char c : name) {
105 if (c >= 'A' && c <= 'Z') {
106 ICHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor
107 << " before dimension " << c;
108 std::string shape_name("_shape");
109 shape_name.insert(0, 1, c);
110 IterVar axis(Range(IntImm(dtype, 0), Var(shape_name, dtype)), Var(std::string(1, c), dtype),
111 tir::kDataPar);
112 node->axes.push_back(axis);
113 } else if (c >= 'a' && c <= 'z') {
114 ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor
115 << " for dimension " << c;
116 IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(std::string(1, c), dtype),
117 tir::kDataPar);
118 node->axes.push_back(axis);
119 factor = 0;
120 } else if (c >= '0' && c <= '9') {
121 ICHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number.";
122 factor = factor * 10 + c - '0';
123 } else {
124 LOG(FATAL) << "Invalid layout " << name;
125 }
126 }
127
128 // validate layout
129 std::vector<bool> exist_axis(256, false);
130 for (const IterVar& v : node->axes) {
131 auto axis_str = v->var.get()->name_hint.operator std::string();
132 ICHECK_EQ(axis_str.size(), 1);
133 char axis = axis_str[0];
134 ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z'));
135 exist_axis[axis] = true;
136 }
137 for (const IterVar& v : node->axes) {
138 char axis = v->var.get()->name_hint.operator std::string()[0];
139 if (axis >= 'a' && axis <= 'z') {
140 ICHECK(exist_axis[axis - 'a' + 'A'])
141 << "Invalid layout " << name << ": missing axis " << std::toupper(axis);
142 }
143 }
144 data_ = std::move(node);
145}
146
147Layout Layout::SubLayout(size_t pos, size_t len) const {
148 if (!defined() || pos > ndim()) return Layout::Undef();
149 if (len == 0) return Layout(Array<IterVar>());
150 if (pos + len > ndim()) len = ndim() - pos;
151 Array<IterVar> new_layout;
152 const auto axes = operator->()->axes;
153 for (size_t i = pos; i < pos + len; ++i) {
154 new_layout.push_back(axes[i]);
155 }
156 return Layout(new_layout);
157}
158
159Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const {
160 if (!defined()) return Layout::Undef();
161 const std::string& name = operator->()->name;
162 const auto axes = operator->()->axes;
163 ICHECK(target_pos <= this->ndim())
164 << "Invalid split position " << target_pos << " for layout " << name;
165 ICHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis;
166 ICHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name;
167 ICHECK(!this->Contains(axis.ToSubordinate()))
168 << "Axis " << axis << " has already been split in " << name;
169 ICHECK(factor > 0) << "Invalid split size " << factor;
170 Array<IterVar> new_layout;
171 for (size_t i = 0; i <= this->ndim(); ++i) {
172 if (i == target_pos) {
173 new_layout.push_back(IterVar(Range(PrimExpr(0), PrimExpr(factor)),
174 Var(axis.ToSubordinate().name()), tir::kDataPar));
175 }
176 if (i == this->ndim()) break;
177 new_layout.push_back(axes[i]);
178 }
179 return Layout(new_layout);
180}
181
182int32_t Layout::FactorOf(const LayoutAxis& axis) const {
183 if (!defined()) return -1;
184 const LayoutAxis& sub = axis.ToSubordinate();
185
186 int32_t factor = 1;
187 bool has_sub = false;
188 for (const IterVar& itvar : operator->()->axes) {
189 if (sub == LayoutAxis::Get(itvar)) {
190 has_sub = true;
191 int32_t val = itvar->dom->extent.as<IntImmNode>()->value;
192 ICHECK(val);
193 factor *= val;
194 }
195 }
196 factor = has_sub ? factor : -1;
197
198 return factor;
199}
200
// Text representation of a Layout: "Layout(<name>)", e.g. "Layout(NCHW16c)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<LayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* l = static_cast<const LayoutNode*>(node.get());
      p->stream << "Layout(" << l->name << ")";
    });
206
/*!
 * \brief Derive, for each axis of \p dst_layout, an index expression and a
 *        shape expression written in terms of the vars of \p src_layout's axes.
 *
 * \param index_rule Output; one index expression is appended per dst axis.
 * \param shape_rule Output; one shape expression is appended per dst axis.
 * \param src_layout The layout being converted from.
 * \param dst_layout The layout being converted to.
 * \return false (after logging a warning) if either layout is undefined or has
 *         an empty name, or if some dst axis has no same-letter axis in src;
 *         true otherwise. On a false return the output arrays may already
 *         hold the rules for the dst axes processed before the failure.
 */
inline bool GetStoreRule(Array<PrimExpr>* index_rule, Array<PrimExpr>* shape_rule,
                         const Layout& src_layout, const Layout& dst_layout) {
  if (!src_layout.defined() || src_layout.name().empty()) {
    LOG(WARNING) << "src layout '" << src_layout.name() << "' is invalid.";
    return false;
  }
  if (!dst_layout.defined() || dst_layout.name().empty()) {
    LOG(WARNING) << "dst layout '" << dst_layout.name() << "' is invalid.";
    return false;
  }

  for (size_t i = 0; i < dst_layout.ndim(); ++i) {
    const auto& store_axis = dst_layout[i];
    const IterVar& store_axis_impl = dst_layout->axes[i];
    PrimExpr index_store(0);

    // Sum the contributions of every src axis that shares this dst axis'
    // primal letter, giving the combined (unsplit) coordinate along it.
    for (size_t j = 0; j < src_layout.ndim(); ++j) {
      const auto& orig_axis = src_layout[j];
      const IterVar& orig_axis_impl = src_layout->axes[j];
      if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
        if (orig_axis.IsPrimal()) {
          // Primal src axis: scale by the total factor it was split by (if any).
          PrimExpr orig_var = orig_axis_impl->var;
          const int32_t factor = src_layout.FactorOf(orig_axis);
          if (factor > 0) {
            orig_var = orig_var * factor;
          }
          index_store = index_store + orig_var;
        } else {
          // Sub-axis in src: scale by the extents of the same-letter axes
          // that appear after it in src.
          PrimExpr factor(1);
          for (size_t k = j + 1; k < src_layout.ndim(); ++k) {
            if (LayoutAxis::Get(orig_axis_impl) == LayoutAxis::Get(src_layout->axes[k])) {
              factor = factor * src_layout->axes[k]->dom->extent;
            }
          }
          index_store = index_store + orig_axis_impl->var * factor;
        }
      }
    }
    // Still the initial 0: no src axis matched this dst axis.
    if (tir::is_zero(index_store)) {
      LOG(WARNING) << "layout '" << src_layout.name() << "'-->'" << dst_layout.name()
                   << "' is not convertible.";
      return false;
    }

    PrimExpr shape_store = index_store;
    if (store_axis.IsPrimal()) {
      // Primal dst axis: divide out the factor it is split by in dst.
      // shapediv is used for the shape and indexdiv for the index.
      // NOTE(review): presumably shapediv rounds up so partial tiles are
      // counted; confirm against its definition.
      const int32_t factor = dst_layout.FactorOf(store_axis);
      if (factor > 0) {
        shape_store = shapediv(index_store, PrimExpr(factor));
        index_store = indexdiv(index_store, PrimExpr(factor));
      }
    } else {
      // Sub-axis in dst: stride = product of extents of this and all later
      // same-letter dst axes; factor = product over the later ones only.
      // The coordinate is (combined % stride) / factor.
      PrimExpr stride(1);
      PrimExpr factor(1);
      for (size_t j = i; j < dst_layout.ndim(); ++j) {
        if (LayoutAxis::Get(store_axis_impl) == LayoutAxis::Get(dst_layout->axes[j])) {
          stride = stride * dst_layout->axes[j]->dom->extent;
          if (j > i) {
            factor = factor * dst_layout->axes[j]->dom->extent;
          }
        }
      }
      // The shape rule mirrors the index rule here; TransformShape uses the
      // axis extent directly for non-primal target axes instead.
      shape_store = indexdiv(indexmod(index_store, stride), factor);
      index_store = indexdiv(indexmod(index_store, stride), factor);
    }

    index_rule->push_back(index_store);
    shape_rule->push_back(shape_store);
  }

  // Debug dump of the derived rules. Emitted only at VLOG level 1, although
  // the string itself is built unconditionally.
  std::stringstream ss;
  ss << "index rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ ";
  for (const auto& r : *index_rule) {
    ss << r << ", ";
  }
  ss << "]" << std::endl;

  ss << "shape rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ ";
  for (const auto& r : *shape_rule) {
    ss << r << ", ";
  }
  ss << "]" << std::endl;
  VLOG(1) << std::endl << ss.str();

  return true;
}
293
294inline Array<PrimExpr> TransformIndex(const Array<PrimExpr>& src_index,
295 const Array<IterVar>& src_axis,
296 const Array<PrimExpr>& transform_rule) {
297 arith::Analyzer ana;
298 Array<PrimExpr> result;
299 std::unordered_map<const tir::VarNode*, PrimExpr> bind_map;
300 for (size_t i = 0; i < src_index.size(); ++i) {
301 bind_map[src_axis[i]->var.get()] = src_index[i];
302 }
303 for (PrimExpr rule : transform_rule) {
304 result.push_back(ana.Simplify(tir::Substitute(rule, bind_map)));
305 }
306 return result;
307}
308
309Array<PrimExpr> BijectiveLayout::ForwardIndex(const Array<PrimExpr>& src_index) const {
310 ICHECK(defined()) << "Cannot operate on an undefined bijective layout.";
311 const BijectiveLayoutNode* self = operator->();
312 ICHECK_EQ(src_index.size(), self->src_layout->axes.size())
313 << "Input mismatch with layout " << self->src_layout;
314 return TransformIndex(src_index, self->src_layout->axes, self->index_forward_rule);
315}
316
317Array<PrimExpr> BijectiveLayout::BackwardIndex(const Array<PrimExpr>& dst_index) const {
318 ICHECK(defined()) << "Cannot operate on an undefined bijective layout.";
319 const BijectiveLayoutNode* self = operator->();
320 ICHECK_EQ(dst_index.size(), self->dst_layout->axes.size())
321 << "Output mismatch with layout " << self->dst_layout;
322 return TransformIndex(dst_index, self->dst_layout->axes, self->index_backward_rule);
323}
324
325inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
326 const Array<IterVar>& src_axis,
327 const Array<IterVar>& target_axis,
328 const Array<PrimExpr>& transform_rule) {
329 arith::Analyzer ana;
330 ICHECK_EQ(src_shape.size(), src_axis.size())
331 << "Input shape size " << src_shape.size() << " mismatch with the exepected shape size "
332 << src_axis.size();
333 // bind variables for original axes
334 // for major-axis, bind the corresponding size
335 // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
336 // e.g., (C * 16 + c) / 32
337 std::unordered_map<const tir::VarNode*, PrimExpr> bind_map;
338 for (size_t i = 0; i < src_shape.size(); ++i) {
339 PrimExpr orig_shape = src_shape[i];
340 IterVar orig_axis = src_axis[i];
341 if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
342 if (orig_shape.defined()) {
343 const auto* orig_shape_const = orig_shape.as<IntImmNode>();
344 const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImmNode>();
345 if (orig_shape_const) {
346 ICHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
347 << "Input shape mismatch at index " << i << ". Expected " << orig_axis->dom->extent
348 << ", get " << orig_shape;
349 }
350 }
351 bind_map[orig_axis->var.get()] = IntImm(orig_axis->var->dtype, 0);
352 } else {
353 bind_map[orig_axis->var.get()] = orig_axis->var->dtype == orig_shape->dtype
354 ? orig_shape
355 : cast(orig_axis->var->dtype, orig_shape);
356 }
357 }
358 // infer the target shape,
359 // for major-axis, use the forward/backward_rule directly,
360 // for minor-axis, simply use the extent.
361 Array<PrimExpr> result;
362 ICHECK_EQ(transform_rule.size(), target_axis.size());
363 for (size_t i = 0; i < transform_rule.size(); ++i) {
364 PrimExpr rule = transform_rule[i];
365 IterVar axis = target_axis[i];
366 if (!LayoutAxis::Get(axis).IsPrimal()) {
367 result.push_back(axis->dom->extent);
368 } else {
369 result.push_back(ana.Simplify(tir::Substitute(rule, bind_map)));
370 }
371 }
372
373 std::stringstream ss;
374 ss << "shape rule for " << Layout(src_axis).name() << "-->" << Layout(target_axis).name()
375 << ": [ ";
376 for (const auto& r : transform_rule) {
377 ss << r << ", ";
378 }
379 ss << "]" << std::endl;
380
381 ss << "shape transform: [ ";
382 for (const auto& s : src_shape) {
383 ss << s << ", ";
384 }
385 ss << "] --> [ ";
386 for (const auto& r : result) {
387 ss << r << ", ";
388 }
389 ss << "]" << std::endl;
390 VLOG(1) << std::endl << ss.str();
391
392 return result;
393}
394
395Array<PrimExpr> BijectiveLayout::ForwardShape(const Array<PrimExpr>& shape) const {
396 ICHECK(defined()) << "Cannot operate on an undefined bijective layout.";
397 const BijectiveLayoutNode* self = operator->();
398 return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes,
399 self->shape_forward_rule);
400}
401
402Array<PrimExpr> BijectiveLayout::BackwardShape(const Array<PrimExpr>& shape) const {
403 ICHECK(defined()) << "Cannot operate on an undefined bijective layout.";
404 const BijectiveLayoutNode* self = operator->();
405 return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes,
406 self->shape_backward_rule);
407}
408
409BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) {
410 auto n = make_object<BijectiveLayoutNode>();
411
412 n->src_layout = std::move(src_layout);
413 n->dst_layout = std::move(dst_layout);
414 // To be consistent with previous behavior, a nullptr layout is created
415 // when argument is invalid.
416 if (GetStoreRule(&n->index_forward_rule, &n->shape_forward_rule, n->src_layout, n->dst_layout)) {
417 ICHECK(GetStoreRule(&n->index_backward_rule, &n->shape_backward_rule, n->dst_layout,
418 n->src_layout));
419 data_ = std::move(n);
420 }
421}
422
// Text representation of a BijectiveLayout: "BijectiveLayout(<src>-><dst>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* b = static_cast<const BijectiveLayoutNode*>(node.get());
      p->stream << "BijectiveLayout(" << b->src_layout.name() << "->" << b->dst_layout.name()
                << ")";
    });
429
// Global function registrations for Layout. The explicit lambda parameter and
// return types determine the exposed signatures, so keep them as they are.

// Construct a Layout from a layout string and an integer dtype.
TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name, DataType dtype) {
  return Layout(name, dtype);
});

// Position of a single-letter axis within the layout (delegates to Layout::IndexOf).
TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int {
  return layout.IndexOf(LayoutAxis::Get(axis));
});

// Split factor of a single-letter axis; Layout::FactorOf returns -1 when the
// layout is undefined or the axis has no sub-axis.
TVM_REGISTER_GLOBAL("tir.LayoutFactorOf")
    .set_body_typed([](Layout layout, std::string axis) -> int {
      return layout.FactorOf(LayoutAxis::Get(axis));
    });

// Number of axes in the layout.
TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int {
  return layout.ndim();
});

// Name of the axis at position idx.
TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> std::string {
  const LayoutAxis& axis = layout[idx];
  return axis.name();
});
451
// Global function registrations for BijectiveLayout.

// Construct a BijectiveLayout; the result is undefined (null) when the two
// layouts are not convertible (see the BijectiveLayout constructor).
TVM_REGISTER_GLOBAL("tir.BijectiveLayout")
    .set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout {
      return BijectiveLayout(src_layout, dst_layout);
    });

// Method bindings: src -> dst index mapping.
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex")
    .set_body_method(&BijectiveLayout::ForwardIndex);

// dst -> src index mapping.
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex")
    .set_body_method(&BijectiveLayout::BackwardIndex);

// src -> dst shape conversion.
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape")
    .set_body_method(&BijectiveLayout::ForwardShape);

// dst -> src shape conversion.
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape")
    .set_body_method(&BijectiveLayout::BackwardShape);
468} // namespace tir
469} // namespace tvm
470