1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/lang/data_layout.cc |
22 | * \brief Data Layout expression. |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/data_layout.h> |
27 | #include <tvm/tir/stmt_functor.h> |
28 | |
29 | #include <cctype> |
30 | |
31 | namespace tvm { |
32 | namespace tir { |
// Make the TIR types used throughout this file available by their short names.
using tir::IterVar;
using tir::IterVarNode;
using tir::Var;

// Register the two layout node classes defined for this file via TVM's node type macro.
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode);
39 | |
/*!
 * \brief Table of axis objects for the upper-case names 'A' through 'Z', in alphabetical order.
 *        LayoutAxis::Get(char) indexes this table with (name - 'A').
 */
const LayoutAxis LayoutAxis::UPPER_CASE[] = {
    LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'),
    LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'),
    LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'),
    LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'),
    LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'),
    LayoutAxis('Z')};
47 | |
/*!
 * \brief Table of axis objects for the lower-case names 'a' through 'z', in alphabetical order.
 *        LayoutAxis::Get(char) indexes this table with (name - 'a').
 */
const LayoutAxis LayoutAxis::LOWER_CASE[] = {
    LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'),
    LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'),
    LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'),
    LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'),
    LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'),
    LayoutAxis('z')};
55 | |
56 | const LayoutAxis& LayoutAxis::Get(const char name) { |
57 | ICHECK((name >= 'A' && name <= 'Z') || (name >= 'a' && name <= 'z')) |
58 | << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z." ; |
59 | return (name >= 'A' && name <= 'Z') ? LayoutAxis::UPPER_CASE[name - 'A'] |
60 | : LayoutAxis::LOWER_CASE[name - 'a']; |
61 | } |
62 | |
63 | const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) { |
64 | const std::string axis = itvar->var.get()->name_hint; |
65 | ICHECK_EQ(axis.size(), 1) << "Invalid layout axis " << axis; |
66 | return LayoutAxis::Get(axis[0]); |
67 | } |
68 | |
69 | const LayoutAxis& LayoutAxis::Get(const std::string& name) { |
70 | ICHECK_EQ(name.length(), 1) << "Invalid axis " << name; |
71 | return LayoutAxis::Get(name[0]); |
72 | } |
73 | |
74 | Layout::Layout(const Array<IterVar>& axes) { |
75 | auto node = make_object<LayoutNode>(); |
76 | node->axes = axes; |
77 | std::ostringstream repr; |
78 | for (const IterVar& axis : axes) { |
79 | if (const auto* factor = axis->dom->extent.as<IntImmNode>()) { |
80 | ICHECK_GT(factor->value, 0); |
81 | repr << factor->value; |
82 | } |
83 | ICHECK_EQ(axis->var.get()->name_hint.size(), 1) |
84 | << "Invalid layout axis " << axis->var.get()->name_hint; |
85 | char c = axis->var.get()->name_hint.operator std::string()[0]; |
86 | ICHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c; |
87 | repr << axis->var.get()->name_hint; |
88 | } |
89 | node->name = repr.str(); |
90 | data_ = std::move(node); |
91 | } |
92 | |
93 | Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) |
94 | CHECK(dtype.is_int()) << "TypeError: The input dtype should be integer type" ; |
95 | if (name == "__undef__" ) return; |
96 | |
97 | auto node = make_object<LayoutNode>(); |
98 | node->name = name; |
99 | |
100 | if (name.empty()) return; // scalar |
101 | |
102 | // parse layout string |
103 | int32_t factor = 0; |
104 | for (char c : name) { |
105 | if (c >= 'A' && c <= 'Z') { |
106 | ICHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor |
107 | << " before dimension " << c; |
108 | std::string shape_name("_shape" ); |
109 | shape_name.insert(0, 1, c); |
110 | IterVar axis(Range(IntImm(dtype, 0), Var(shape_name, dtype)), Var(std::string(1, c), dtype), |
111 | tir::kDataPar); |
112 | node->axes.push_back(axis); |
113 | } else if (c >= 'a' && c <= 'z') { |
114 | ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor |
115 | << " for dimension " << c; |
116 | IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(std::string(1, c), dtype), |
117 | tir::kDataPar); |
118 | node->axes.push_back(axis); |
119 | factor = 0; |
120 | } else if (c >= '0' && c <= '9') { |
121 | ICHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number." ; |
122 | factor = factor * 10 + c - '0'; |
123 | } else { |
124 | LOG(FATAL) << "Invalid layout " << name; |
125 | } |
126 | } |
127 | |
128 | // validate layout |
129 | std::vector<bool> exist_axis(256, false); |
130 | for (const IterVar& v : node->axes) { |
131 | auto axis_str = v->var.get()->name_hint.operator std::string(); |
132 | ICHECK_EQ(axis_str.size(), 1); |
133 | char axis = axis_str[0]; |
134 | ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z')); |
135 | exist_axis[axis] = true; |
136 | } |
137 | for (const IterVar& v : node->axes) { |
138 | char axis = v->var.get()->name_hint.operator std::string()[0]; |
139 | if (axis >= 'a' && axis <= 'z') { |
140 | ICHECK(exist_axis[axis - 'a' + 'A']) |
141 | << "Invalid layout " << name << ": missing axis " << std::toupper(axis); |
142 | } |
143 | } |
144 | data_ = std::move(node); |
145 | } |
146 | |
147 | Layout Layout::SubLayout(size_t pos, size_t len) const { |
148 | if (!defined() || pos > ndim()) return Layout::Undef(); |
149 | if (len == 0) return Layout(Array<IterVar>()); |
150 | if (pos + len > ndim()) len = ndim() - pos; |
151 | Array<IterVar> new_layout; |
152 | const auto axes = operator->()->axes; |
153 | for (size_t i = pos; i < pos + len; ++i) { |
154 | new_layout.push_back(axes[i]); |
155 | } |
156 | return Layout(new_layout); |
157 | } |
158 | |
159 | Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const { |
160 | if (!defined()) return Layout::Undef(); |
161 | const std::string& name = operator->()->name; |
162 | const auto axes = operator->()->axes; |
163 | ICHECK(target_pos <= this->ndim()) |
164 | << "Invalid split position " << target_pos << " for layout " << name; |
165 | ICHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis; |
166 | ICHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name; |
167 | ICHECK(!this->Contains(axis.ToSubordinate())) |
168 | << "Axis " << axis << " has already been split in " << name; |
169 | ICHECK(factor > 0) << "Invalid split size " << factor; |
170 | Array<IterVar> new_layout; |
171 | for (size_t i = 0; i <= this->ndim(); ++i) { |
172 | if (i == target_pos) { |
173 | new_layout.push_back(IterVar(Range(PrimExpr(0), PrimExpr(factor)), |
174 | Var(axis.ToSubordinate().name()), tir::kDataPar)); |
175 | } |
176 | if (i == this->ndim()) break; |
177 | new_layout.push_back(axes[i]); |
178 | } |
179 | return Layout(new_layout); |
180 | } |
181 | |
182 | int32_t Layout::FactorOf(const LayoutAxis& axis) const { |
183 | if (!defined()) return -1; |
184 | const LayoutAxis& sub = axis.ToSubordinate(); |
185 | |
186 | int32_t factor = 1; |
187 | bool has_sub = false; |
188 | for (const IterVar& itvar : operator->()->axes) { |
189 | if (sub == LayoutAxis::Get(itvar)) { |
190 | has_sub = true; |
191 | int32_t val = itvar->dom->extent.as<IntImmNode>()->value; |
192 | ICHECK(val); |
193 | factor *= val; |
194 | } |
195 | } |
196 | factor = has_sub ? factor : -1; |
197 | |
198 | return factor; |
199 | } |
200 | |
201 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
202 | .set_dispatch<LayoutNode>([](const ObjectRef& node, ReprPrinter* p) { |
203 | auto* l = static_cast<const LayoutNode*>(node.get()); |
204 | p->stream << "Layout(" << l->name << ")" ; |
205 | }); |
206 | |
207 | inline bool GetStoreRule(Array<PrimExpr>* index_rule, Array<PrimExpr>* shape_rule, |
208 | const Layout& src_layout, const Layout& dst_layout) { |
209 | if (!src_layout.defined() || src_layout.name().empty()) { |
210 | LOG(WARNING) << "src layout '" << src_layout.name() << "' is invalid." ; |
211 | return false; |
212 | } |
213 | if (!dst_layout.defined() || dst_layout.name().empty()) { |
214 | LOG(WARNING) << "dst layout '" << dst_layout.name() << "' is invalid." ; |
215 | return false; |
216 | } |
217 | |
218 | for (size_t i = 0; i < dst_layout.ndim(); ++i) { |
219 | const auto& store_axis = dst_layout[i]; |
220 | const IterVar& store_axis_impl = dst_layout->axes[i]; |
221 | PrimExpr index_store(0); |
222 | |
223 | for (size_t j = 0; j < src_layout.ndim(); ++j) { |
224 | const auto& orig_axis = src_layout[j]; |
225 | const IterVar& orig_axis_impl = src_layout->axes[j]; |
226 | if (store_axis.ToPrimal() == orig_axis.ToPrimal()) { |
227 | if (orig_axis.IsPrimal()) { |
228 | PrimExpr orig_var = orig_axis_impl->var; |
229 | const int32_t factor = src_layout.FactorOf(orig_axis); |
230 | if (factor > 0) { |
231 | orig_var = orig_var * factor; |
232 | } |
233 | index_store = index_store + orig_var; |
234 | } else { |
235 | PrimExpr factor(1); |
236 | for (size_t k = j + 1; k < src_layout.ndim(); ++k) { |
237 | if (LayoutAxis::Get(orig_axis_impl) == LayoutAxis::Get(src_layout->axes[k])) { |
238 | factor = factor * src_layout->axes[k]->dom->extent; |
239 | } |
240 | } |
241 | index_store = index_store + orig_axis_impl->var * factor; |
242 | } |
243 | } |
244 | } |
245 | if (tir::is_zero(index_store)) { |
246 | LOG(WARNING) << "layout '" << src_layout.name() << "'-->'" << dst_layout.name() |
247 | << "' is not convertible." ; |
248 | return false; |
249 | } |
250 | |
251 | PrimExpr shape_store = index_store; |
252 | if (store_axis.IsPrimal()) { |
253 | const int32_t factor = dst_layout.FactorOf(store_axis); |
254 | if (factor > 0) { |
255 | shape_store = shapediv(index_store, PrimExpr(factor)); |
256 | index_store = indexdiv(index_store, PrimExpr(factor)); |
257 | } |
258 | } else { |
259 | PrimExpr stride(1); |
260 | PrimExpr factor(1); |
261 | for (size_t j = i; j < dst_layout.ndim(); ++j) { |
262 | if (LayoutAxis::Get(store_axis_impl) == LayoutAxis::Get(dst_layout->axes[j])) { |
263 | stride = stride * dst_layout->axes[j]->dom->extent; |
264 | if (j > i) { |
265 | factor = factor * dst_layout->axes[j]->dom->extent; |
266 | } |
267 | } |
268 | } |
269 | shape_store = indexdiv(indexmod(index_store, stride), factor); |
270 | index_store = indexdiv(indexmod(index_store, stride), factor); |
271 | } |
272 | |
273 | index_rule->push_back(index_store); |
274 | shape_rule->push_back(shape_store); |
275 | } |
276 | |
277 | std::stringstream ss; |
278 | ss << "index rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ " ; |
279 | for (const auto& r : *index_rule) { |
280 | ss << r << ", " ; |
281 | } |
282 | ss << "]" << std::endl; |
283 | |
284 | ss << "shape rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ " ; |
285 | for (const auto& r : *shape_rule) { |
286 | ss << r << ", " ; |
287 | } |
288 | ss << "]" << std::endl; |
289 | VLOG(1) << std::endl << ss.str(); |
290 | |
291 | return true; |
292 | } |
293 | |
294 | inline Array<PrimExpr> TransformIndex(const Array<PrimExpr>& src_index, |
295 | const Array<IterVar>& src_axis, |
296 | const Array<PrimExpr>& transform_rule) { |
297 | arith::Analyzer ana; |
298 | Array<PrimExpr> result; |
299 | std::unordered_map<const tir::VarNode*, PrimExpr> bind_map; |
300 | for (size_t i = 0; i < src_index.size(); ++i) { |
301 | bind_map[src_axis[i]->var.get()] = src_index[i]; |
302 | } |
303 | for (PrimExpr rule : transform_rule) { |
304 | result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); |
305 | } |
306 | return result; |
307 | } |
308 | |
309 | Array<PrimExpr> BijectiveLayout::ForwardIndex(const Array<PrimExpr>& src_index) const { |
310 | ICHECK(defined()) << "Cannot operate on an undefined bijective layout." ; |
311 | const BijectiveLayoutNode* self = operator->(); |
312 | ICHECK_EQ(src_index.size(), self->src_layout->axes.size()) |
313 | << "Input mismatch with layout " << self->src_layout; |
314 | return TransformIndex(src_index, self->src_layout->axes, self->index_forward_rule); |
315 | } |
316 | |
317 | Array<PrimExpr> BijectiveLayout::BackwardIndex(const Array<PrimExpr>& dst_index) const { |
318 | ICHECK(defined()) << "Cannot operate on an undefined bijective layout." ; |
319 | const BijectiveLayoutNode* self = operator->(); |
320 | ICHECK_EQ(dst_index.size(), self->dst_layout->axes.size()) |
321 | << "Output mismatch with layout " << self->dst_layout; |
322 | return TransformIndex(dst_index, self->dst_layout->axes, self->index_backward_rule); |
323 | } |
324 | |
325 | inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape, |
326 | const Array<IterVar>& src_axis, |
327 | const Array<IterVar>& target_axis, |
328 | const Array<PrimExpr>& transform_rule) { |
329 | arith::Analyzer ana; |
330 | ICHECK_EQ(src_shape.size(), src_axis.size()) |
331 | << "Input shape size " << src_shape.size() << " mismatch with the exepected shape size " |
332 | << src_axis.size(); |
333 | // bind variables for original axes |
334 | // for major-axis, bind the corresponding size |
335 | // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule, |
336 | // e.g., (C * 16 + c) / 32 |
337 | std::unordered_map<const tir::VarNode*, PrimExpr> bind_map; |
338 | for (size_t i = 0; i < src_shape.size(); ++i) { |
339 | PrimExpr orig_shape = src_shape[i]; |
340 | IterVar orig_axis = src_axis[i]; |
341 | if (!LayoutAxis::Get(orig_axis).IsPrimal()) { |
342 | if (orig_shape.defined()) { |
343 | const auto* orig_shape_const = orig_shape.as<IntImmNode>(); |
344 | const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImmNode>(); |
345 | if (orig_shape_const) { |
346 | ICHECK_EQ(orig_shape_const->value, orig_axis_extent->value) |
347 | << "Input shape mismatch at index " << i << ". Expected " << orig_axis->dom->extent |
348 | << ", get " << orig_shape; |
349 | } |
350 | } |
351 | bind_map[orig_axis->var.get()] = IntImm(orig_axis->var->dtype, 0); |
352 | } else { |
353 | bind_map[orig_axis->var.get()] = orig_axis->var->dtype == orig_shape->dtype |
354 | ? orig_shape |
355 | : cast(orig_axis->var->dtype, orig_shape); |
356 | } |
357 | } |
358 | // infer the target shape, |
359 | // for major-axis, use the forward/backward_rule directly, |
360 | // for minor-axis, simply use the extent. |
361 | Array<PrimExpr> result; |
362 | ICHECK_EQ(transform_rule.size(), target_axis.size()); |
363 | for (size_t i = 0; i < transform_rule.size(); ++i) { |
364 | PrimExpr rule = transform_rule[i]; |
365 | IterVar axis = target_axis[i]; |
366 | if (!LayoutAxis::Get(axis).IsPrimal()) { |
367 | result.push_back(axis->dom->extent); |
368 | } else { |
369 | result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); |
370 | } |
371 | } |
372 | |
373 | std::stringstream ss; |
374 | ss << "shape rule for " << Layout(src_axis).name() << "-->" << Layout(target_axis).name() |
375 | << ": [ " ; |
376 | for (const auto& r : transform_rule) { |
377 | ss << r << ", " ; |
378 | } |
379 | ss << "]" << std::endl; |
380 | |
381 | ss << "shape transform: [ " ; |
382 | for (const auto& s : src_shape) { |
383 | ss << s << ", " ; |
384 | } |
385 | ss << "] --> [ " ; |
386 | for (const auto& r : result) { |
387 | ss << r << ", " ; |
388 | } |
389 | ss << "]" << std::endl; |
390 | VLOG(1) << std::endl << ss.str(); |
391 | |
392 | return result; |
393 | } |
394 | |
395 | Array<PrimExpr> BijectiveLayout::ForwardShape(const Array<PrimExpr>& shape) const { |
396 | ICHECK(defined()) << "Cannot operate on an undefined bijective layout." ; |
397 | const BijectiveLayoutNode* self = operator->(); |
398 | return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, |
399 | self->shape_forward_rule); |
400 | } |
401 | |
402 | Array<PrimExpr> BijectiveLayout::BackwardShape(const Array<PrimExpr>& shape) const { |
403 | ICHECK(defined()) << "Cannot operate on an undefined bijective layout." ; |
404 | const BijectiveLayoutNode* self = operator->(); |
405 | return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, |
406 | self->shape_backward_rule); |
407 | } |
408 | |
409 | BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { |
410 | auto n = make_object<BijectiveLayoutNode>(); |
411 | |
412 | n->src_layout = std::move(src_layout); |
413 | n->dst_layout = std::move(dst_layout); |
414 | // To be consistent with previous behavior, a nullptr layout is created |
415 | // when argument is invalid. |
416 | if (GetStoreRule(&n->index_forward_rule, &n->shape_forward_rule, n->src_layout, n->dst_layout)) { |
417 | ICHECK(GetStoreRule(&n->index_backward_rule, &n->shape_backward_rule, n->dst_layout, |
418 | n->src_layout)); |
419 | data_ = std::move(n); |
420 | } |
421 | } |
422 | |
423 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
424 | .set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, ReprPrinter* p) { |
425 | auto* b = static_cast<const BijectiveLayoutNode*>(node.get()); |
426 | p->stream << "BijectiveLayout(" << b->src_layout.name() << "->" << b->dst_layout.name() |
427 | << ")" ; |
428 | }); |
429 | |
430 | TVM_REGISTER_GLOBAL("tir.Layout" ).set_body_typed([](std::string name, DataType dtype) { |
431 | return Layout(name, dtype); |
432 | }); |
433 | |
434 | TVM_REGISTER_GLOBAL("tir.LayoutIndexOf" ).set_body_typed([](Layout layout, std::string axis) -> int { |
435 | return layout.IndexOf(LayoutAxis::Get(axis)); |
436 | }); |
437 | |
438 | TVM_REGISTER_GLOBAL("tir.LayoutFactorOf" ) |
439 | .set_body_typed([](Layout layout, std::string axis) -> int { |
440 | return layout.FactorOf(LayoutAxis::Get(axis)); |
441 | }); |
442 | |
443 | TVM_REGISTER_GLOBAL("tir.LayoutNdim" ).set_body_typed([](Layout layout) -> int { |
444 | return layout.ndim(); |
445 | }); |
446 | |
447 | TVM_REGISTER_GLOBAL("tir.LayoutGetItem" ).set_body_typed([](Layout layout, int idx) -> std::string { |
448 | const LayoutAxis& axis = layout[idx]; |
449 | return axis.name(); |
450 | }); |
451 | |
452 | TVM_REGISTER_GLOBAL("tir.BijectiveLayout" ) |
453 | .set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout { |
454 | return BijectiveLayout(src_layout, dst_layout); |
455 | }); |
456 | |
// Register the four BijectiveLayout transformation methods as global functions,
// each bound directly to the member function of the same name.
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex")
    .set_body_method(&BijectiveLayout::ForwardIndex);

TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex")
    .set_body_method(&BijectiveLayout::BackwardIndex);

TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape")
    .set_body_method(&BijectiveLayout::ForwardShape);

TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape")
    .set_body_method(&BijectiveLayout::BackwardShape);
468 | } // namespace tir |
469 | } // namespace tvm |
470 | |