1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file lower_opaque_block.cc
22 */
23
#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ir_utils.h"
28
29namespace tvm {
30namespace tir {
31
32/*!
33 * \brief Remove Block to ensure that the TIR can not be scheduled again.
34 */
35class OpaqueBlockLower : public StmtExprMutator {
36 private:
37 Stmt VisitStmt_(const BlockRealizeNode* op) final {
38 // We have convert blocks into opaque blocks in previous passes.
39 ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please "
40 "call pass ConvertBlocksToOpaque before.";
41 // Step 1. Visit the body
42 Block new_block = Downcast<Block>(this->VisitStmt(op->block));
43 PrimExpr predicate = this->VisitExpr(op->predicate);
44 // Step 2. Transform the `predicate` to if-then-else
45 Stmt body = new_block->body;
46 if (!is_one(predicate)) {
47 body = IfThenElse(predicate, std::move(body));
48 }
49 // Step 3. Handle allocations in reverse order
50 for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
51 const Buffer& buffer = new_block->alloc_buffers[i - 1];
52 Array<PrimExpr> new_shape = buffer->shape;
53 if (buffer->strides.size()) {
54 ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
55 for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
56 ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));
57 new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
58 }
59 }
60 body = DeclBuffer(buffer, std::move(body));
61 body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), std::move(body));
62 }
63 // Step 4. Handle annotations, block annotations are not preserved by default.
64 std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
65 HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true);
66 for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
67 body = AttrStmt(Integer(0), it->first, it->second, std::move(body));
68 }
69 return body;
70 }
71
72 Stmt VisitStmt_(const ForNode* op) final {
73 // Step 1. Update unit loop info.
74 PrimExpr min = this->VisitExpr(op->min);
75 PrimExpr extent = this->VisitExpr(op->extent);
76 if (is_one(extent) && op->annotations.empty()) {
77 // handling unit loop
78 unit_loop_vars_[op->loop_var] = min;
79 }
80 // Step 2. Visit recursively
81 Stmt body = this->VisitStmt(op->body);
82 // Step 3. Handle annotations
83 std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
84 Map<String, ObjectRef> new_annotations =
85 HandleAnnotations(op->annotations, &pragma_attrs, /*is_block=*/false);
86 // Step 4. Create new For loop accordingly
87 if (op->kind == ForKind::kThreadBinding) {
88 // Case 1. Thread binding
89 ICHECK(op->thread_binding.defined());
90 String thread_tag = op->thread_binding.value()->thread_tag;
91 body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
92 } else if (is_one(extent) && op->annotations.empty()) {
93 // Case 2. Unit loop
94 return body;
95 } else {
96 // Case 3. An ordinary loop
97 body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body),
98 NullOpt, new_annotations);
99 }
100 // Step 5. Insert nested attrs
101 for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
102 body = AttrStmt(op->loop_var, it->first, it->second, std::move(body));
103 }
104 return body;
105 }
106
107 PrimExpr VisitExpr_(const VarNode* op) final {
108 Var var = GetRef<Var>(op);
109 auto it = unit_loop_vars_.find(var);
110 if (it == unit_loop_vars_.end()) {
111 return std::move(var);
112 } else {
113 PrimExpr expr = it->second;
114 if (expr.dtype() != var.dtype()) {
115 expr = tvm::cast(var.dtype(), std::move(expr));
116 }
117 return expr;
118 }
119 }
120
121 static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag,
122 Stmt body) {
123 IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent),
124 /*var=*/std::move(var),
125 /*iter_type=*/IterVarType::kThreadIndex,
126 /*thread_tag=*/thread_tag);
127 String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" ||
128 thread_tag == "vthread.y" || thread_tag == "vthread.z")
129 ? attr::virtual_thread
130 : attr::thread_extent;
131 return AttrStmt(/*node=*/std::move(iter_var),
132 /*attr_key=*/std::move(attr_key),
133 /*value=*/std::move(extent),
134 /*body=*/std::move(body));
135 }
136
137 /*! \brief Convert attr value from annotation map into PrimExpr. */
138 PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) {
139 if (!obj.defined()) {
140 return PrimExpr();
141 } else if (const PrimExprNode* expr = obj.as<PrimExprNode>()) {
142 return GetRef<PrimExpr>(expr);
143 } else if (const StringObj* str = obj.as<StringObj>()) {
144 return std::move(StringImm(str->data));
145 } else {
146 LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj->GetTypeKey()
147 << " not supported";
148 return PrimExpr();
149 }
150 }
151
152 /*!
153 * \brief Helper to handle annotation dict.
154 * (1) if the attr key is prefixed by `pragma_`, move to ordered kv list. They
155 * are lowered to `AttrStmt` by legacy TE schedule convention.
156 * (2) the non-pragma loop annotations are preserved
157 * (3) the non-pragma block annotations are dropped
158 * \return New annotation dict with preserved keys. Also update pragma attr pairs ordered by key.
159 */
160 Map<String, ObjectRef> HandleAnnotations(
161 const Map<String, ObjectRef>& annotations,
162 std::vector<std::pair<std::string, PrimExpr>>* pragma_attrs, bool is_block) {
163 Map<String, ObjectRef> preserved_annotations;
164 pragma_attrs->clear();
165 for (const auto& kv : annotations) {
166 const String& key = kv.first;
167 if (attr::IsPragmaKey(key)) {
168 pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second));
169 } else if (!is_block) {
170 // the loop annotation is preserved
171 preserved_annotations.Set(key, kv.second);
172 }
173 }
174 std::sort(pragma_attrs->begin(), pragma_attrs->end(),
175 [](const auto& p1, const auto& p2) { return p1.first < p2.first; });
176 return preserved_annotations;
177 }
178
179 /*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */
180 std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> unit_loop_vars_;
181
182 /*! \brief Attr keys to preserve into loop annotations. */
183 std::unordered_set<std::string> preserved_annotations_;
184};
185
186PrimFunc LowerOpaqueBlock(PrimFunc f) {
187 // Only apply this pass to TIR that is not from TE schedules
188 if (!IsFromLegacyTESchedule(f)) {
189 auto fptr = f.CopyOnWrite();
190 fptr->body = OpaqueBlockLower()(std::move(fptr->body));
191 return f;
192 } else {
193 return f;
194 }
195}
196
197namespace transform {
198
199Pass LowerOpaqueBlock() {
200 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
201 return LowerOpaqueBlock(std::move(f));
202 };
203 return CreatePrimFuncPass(pass_func, 0, "tir.LowerOpaqueBlock", {});
204}
205
206TVM_REGISTER_GLOBAL("tir.transform.LowerOpaqueBlock").set_body_typed(LowerOpaqueBlock);
207} // namespace transform
208
209} // namespace tir
210} // namespace tvm
211