1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
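/*!
 * \file lower_opaque_block.cc
 * \brief Lower opaque blocks, together with unit loops and thread-binding loops, into plain TIR
 *        statements (Allocate, DeclBuffer, IfThenElse, AttrStmt and For).
 */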
23 | |
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ir_utils.h"
28 | |
29 | namespace tvm { |
30 | namespace tir { |
31 | |
32 | /*! |
33 | * \brief Remove Block to ensure that the TIR can not be scheduled again. |
34 | */ |
35 | class OpaqueBlockLower : public StmtExprMutator { |
36 | private: |
37 | Stmt VisitStmt_(const BlockRealizeNode* op) final { |
38 | // We have convert blocks into opaque blocks in previous passes. |
39 | ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please " |
40 | "call pass ConvertBlocksToOpaque before." ; |
41 | // Step 1. Visit the body |
42 | Block new_block = Downcast<Block>(this->VisitStmt(op->block)); |
43 | PrimExpr predicate = this->VisitExpr(op->predicate); |
44 | // Step 2. Transform the `predicate` to if-then-else |
45 | Stmt body = new_block->body; |
46 | if (!is_one(predicate)) { |
47 | body = IfThenElse(predicate, std::move(body)); |
48 | } |
49 | // Step 3. Handle allocations in reverse order |
50 | for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { |
51 | const Buffer& buffer = new_block->alloc_buffers[i - 1]; |
52 | Array<PrimExpr> new_shape = buffer->shape; |
53 | if (buffer->strides.size()) { |
54 | ICHECK_EQ(buffer->shape.size(), buffer->strides.size()); |
55 | for (size_t i = buffer->strides.size() - 1; i > 0; --i) { |
56 | ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i]))); |
57 | new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]); |
58 | } |
59 | } |
60 | body = DeclBuffer(buffer, std::move(body)); |
61 | body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), std::move(body)); |
62 | } |
63 | // Step 4. Handle annotations, block annotations are not preserved by default. |
64 | std::vector<std::pair<std::string, PrimExpr>> pragma_attrs; |
65 | HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true); |
66 | for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) { |
67 | body = AttrStmt(Integer(0), it->first, it->second, std::move(body)); |
68 | } |
69 | return body; |
70 | } |
71 | |
72 | Stmt VisitStmt_(const ForNode* op) final { |
73 | // Step 1. Update unit loop info. |
74 | PrimExpr min = this->VisitExpr(op->min); |
75 | PrimExpr extent = this->VisitExpr(op->extent); |
76 | if (is_one(extent) && op->annotations.empty()) { |
77 | // handling unit loop |
78 | unit_loop_vars_[op->loop_var] = min; |
79 | } |
80 | // Step 2. Visit recursively |
81 | Stmt body = this->VisitStmt(op->body); |
82 | // Step 3. Handle annotations |
83 | std::vector<std::pair<std::string, PrimExpr>> pragma_attrs; |
84 | Map<String, ObjectRef> new_annotations = |
85 | HandleAnnotations(op->annotations, &pragma_attrs, /*is_block=*/false); |
86 | // Step 4. Create new For loop accordingly |
87 | if (op->kind == ForKind::kThreadBinding) { |
88 | // Case 1. Thread binding |
89 | ICHECK(op->thread_binding.defined()); |
90 | String thread_tag = op->thread_binding.value()->thread_tag; |
91 | body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); |
92 | } else if (is_one(extent) && op->annotations.empty()) { |
93 | // Case 2. Unit loop |
94 | return body; |
95 | } else { |
96 | // Case 3. An ordinary loop |
97 | body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body), |
98 | NullOpt, new_annotations); |
99 | } |
100 | // Step 5. Insert nested attrs |
101 | for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) { |
102 | body = AttrStmt(op->loop_var, it->first, it->second, std::move(body)); |
103 | } |
104 | return body; |
105 | } |
106 | |
107 | PrimExpr VisitExpr_(const VarNode* op) final { |
108 | Var var = GetRef<Var>(op); |
109 | auto it = unit_loop_vars_.find(var); |
110 | if (it == unit_loop_vars_.end()) { |
111 | return std::move(var); |
112 | } else { |
113 | PrimExpr expr = it->second; |
114 | if (expr.dtype() != var.dtype()) { |
115 | expr = tvm::cast(var.dtype(), std::move(expr)); |
116 | } |
117 | return expr; |
118 | } |
119 | } |
120 | |
121 | static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag, |
122 | Stmt body) { |
123 | IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent), |
124 | /*var=*/std::move(var), |
125 | /*iter_type=*/IterVarType::kThreadIndex, |
126 | /*thread_tag=*/thread_tag); |
127 | String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || |
128 | thread_tag == "vthread.y" || thread_tag == "vthread.z" ) |
129 | ? attr::virtual_thread |
130 | : attr::thread_extent; |
131 | return AttrStmt(/*node=*/std::move(iter_var), |
132 | /*attr_key=*/std::move(attr_key), |
133 | /*value=*/std::move(extent), |
134 | /*body=*/std::move(body)); |
135 | } |
136 | |
137 | /*! \brief Convert attr value from annotation map into PrimExpr. */ |
138 | PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) { |
139 | if (!obj.defined()) { |
140 | return PrimExpr(); |
141 | } else if (const PrimExprNode* expr = obj.as<PrimExprNode>()) { |
142 | return GetRef<PrimExpr>(expr); |
143 | } else if (const StringObj* str = obj.as<StringObj>()) { |
144 | return std::move(StringImm(str->data)); |
145 | } else { |
146 | LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj->GetTypeKey() |
147 | << " not supported" ; |
148 | return PrimExpr(); |
149 | } |
150 | } |
151 | |
152 | /*! |
153 | * \brief Helper to handle annotation dict. |
154 | * (1) if the attr key is prefixed by `pragma_`, move to ordered kv list. They |
155 | * are lowered to `AttrStmt` by legacy TE schedule convention. |
156 | * (2) the non-pragma loop annotations are preserved |
157 | * (3) the non-pragma block annotations are dropped |
158 | * \return New annotation dict with preserved keys. Also update pragma attr pairs ordered by key. |
159 | */ |
160 | Map<String, ObjectRef> HandleAnnotations( |
161 | const Map<String, ObjectRef>& annotations, |
162 | std::vector<std::pair<std::string, PrimExpr>>* pragma_attrs, bool is_block) { |
163 | Map<String, ObjectRef> preserved_annotations; |
164 | pragma_attrs->clear(); |
165 | for (const auto& kv : annotations) { |
166 | const String& key = kv.first; |
167 | if (attr::IsPragmaKey(key)) { |
168 | pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second)); |
169 | } else if (!is_block) { |
170 | // the loop annotation is preserved |
171 | preserved_annotations.Set(key, kv.second); |
172 | } |
173 | } |
174 | std::sort(pragma_attrs->begin(), pragma_attrs->end(), |
175 | [](const auto& p1, const auto& p2) { return p1.first < p2.first; }); |
176 | return preserved_annotations; |
177 | } |
178 | |
179 | /*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */ |
180 | std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> unit_loop_vars_; |
181 | |
182 | /*! \brief Attr keys to preserve into loop annotations. */ |
183 | std::unordered_set<std::string> preserved_annotations_; |
184 | }; |
185 | |
186 | PrimFunc LowerOpaqueBlock(PrimFunc f) { |
187 | // Only apply this pass to TIR that is not from TE schedules |
188 | if (!IsFromLegacyTESchedule(f)) { |
189 | auto fptr = f.CopyOnWrite(); |
190 | fptr->body = OpaqueBlockLower()(std::move(fptr->body)); |
191 | return f; |
192 | } else { |
193 | return f; |
194 | } |
195 | } |
196 | |
197 | namespace transform { |
198 | |
199 | Pass LowerOpaqueBlock() { |
200 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
201 | return LowerOpaqueBlock(std::move(f)); |
202 | }; |
203 | return CreatePrimFuncPass(pass_func, 0, "tir.LowerOpaqueBlock" , {}); |
204 | } |
205 | |
206 | TVM_REGISTER_GLOBAL("tir.transform.LowerOpaqueBlock" ).set_body_typed(LowerOpaqueBlock); |
207 | } // namespace transform |
208 | |
209 | } // namespace tir |
210 | } // namespace tvm |
211 | |