1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Hybrid computation rule. |
22 | * \file hybrid_op.cc |
23 | */ |
24 | #include "hybrid_op.h" |
25 | |
26 | #include <tvm/arith/analyzer.h> |
27 | #include <tvm/runtime/registry.h> |
28 | #include <tvm/te/operation.h> |
29 | #include <tvm/tir/analysis.h> |
30 | #include <tvm/tir/expr.h> |
31 | #include <tvm/tir/op.h> |
32 | #include <tvm/tir/stmt_functor.h> |
33 | |
34 | #include <string> |
35 | #include <unordered_set> |
36 | #include <utility> |
37 | |
38 | #include "op_utils.h" |
39 | |
40 | namespace tvm { |
41 | namespace te { |
42 | using namespace tir; |
43 | // HybridOpNode |
// Pretty-printer hook: renders a hybrid op as "hybrid(<name>, <node pointer>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<HybridOpNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const HybridOpNode*>(node.get());
      p->stream << "hybrid(" << op->name << ", " << op << ")";
    });

// Register the node type so it participates in TVM's object/reflection system.
TVM_REGISTER_NODE_TYPE(HybridOpNode);
51 | |
52 | int HybridOpNode::num_outputs() const { return static_cast<int>(outputs.size()); } |
53 | |
54 | Array<IterVar> HybridOpNode::root_iter_vars() const { return this->axis; } |
55 | |
56 | DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; } |
57 | |
58 | Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; } |
59 | |
60 | HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, |
61 | Array<Tensor> inputs, Array<Tensor> outputs, Stmt body) { |
62 | if (!attrs.defined()) { |
63 | attrs = Map<String, ObjectRef>(); |
64 | } |
65 | auto n = make_object<HybridOpNode>(); |
66 | n->name = std::move(name); |
67 | n->tag = std::move(tag); |
68 | n->attrs = std::move(attrs); |
69 | n->inputs = std::move(inputs); |
70 | n->outputs = std::move(outputs); |
71 | n->axis = te::GatherLoopVars(body); |
72 | n->body = std::move(body); |
73 | data_ = std::move(n); |
74 | } |
75 | |
// FFI entry point: exposes the HybridOp constructor to the Python frontend.
TVM_REGISTER_GLOBAL("te.HybridOp")
    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
                       Array<Tensor> inputs, Array<Tensor> outputs,
                       Stmt body) { return HybridOp(name, tag, attrs, inputs, outputs, body); });
80 | |
81 | Array<Tensor> HybridOpNode::InputTensors() const { |
82 | // Because input tensors could be potentially inlined into hybrid scripts, |
83 | // we need to check if all input tensors are used in the body. |
84 | std::unordered_set<Tensor> orig_inputs; |
85 | for (auto t : inputs) { |
86 | orig_inputs.insert(t); |
87 | } |
88 | std::unordered_set<Tensor> visited; |
89 | Array<Tensor> curr_inputs; |
90 | tir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) { |
91 | if (auto* pload = n.as<tir::ProducerLoadNode>()) { |
92 | Tensor t = Downcast<Tensor>(pload->producer); |
93 | if (orig_inputs.count(t) && !visited.count(t)) { |
94 | curr_inputs.push_back(t); |
95 | visited.insert(t); |
96 | } |
97 | } |
98 | }); |
99 | return curr_inputs; |
100 | } |
101 | |
102 | Operation HybridOpNode::ReplaceInputs(const Operation& self, |
103 | const std::unordered_map<Tensor, Tensor>& rmap) const { |
104 | ICHECK_EQ(self.operator->(), this); |
105 | auto n = make_object<HybridOpNode>(*this); |
106 | n->body = te::ReplaceTensor(this->body, rmap); |
107 | for (size_t i = 0; i < n->inputs.size(); ++i) { |
108 | Tensor t = n->inputs[i]; |
109 | if (rmap.count(t)) { |
110 | n->inputs.Set(i, rmap.at(t)); |
111 | } |
112 | } |
113 | |
114 | if (body.same_as(n->body) && inputs.same_as(n->inputs)) { |
115 | return self; |
116 | } else { |
117 | return Operation(n); |
118 | } |
119 | } |
120 | |
121 | void HybridOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, |
122 | const std::unordered_map<const VarNode*, IntSet>& dom_map, |
123 | std::unordered_map<Tensor, TensorDom>* out_dom_map) const { |
124 | auto curr_inputs = InputTensors(); |
125 | for (Tensor t : curr_inputs) { |
126 | auto it = out_dom_map->find(t); |
127 | if (it == out_dom_map->end()) continue; |
128 | TensorDom& dom = it->second; |
129 | for (size_t i = 0; i < t->shape.size(); ++i) { |
130 | dom.data[i].emplace_back( |
131 | IntSet::FromRange(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i]))); |
132 | } |
133 | } |
134 | } |
135 | |
136 | void HybridOpNode::GatherBound(const Operation& self, |
137 | const std::unordered_map<Tensor, TensorDom>& tensor_dom, |
138 | std::unordered_map<IterVar, Range>* out_dom_map) const { |
139 | for (auto iter_var : axis) { |
140 | ICHECK(!out_dom_map->count(iter_var)); |
141 | out_dom_map->operator[](iter_var) = iter_var->dom; |
142 | } |
143 | } |
144 | |
145 | Stmt HybridOpNode::BuildRealize(const Stage& stage, |
146 | const std::unordered_map<IterVar, Range>& realize_map, |
147 | const Stmt& body, String storage_scope) const { |
148 | // TODO(@were): Add attribute inject here and remove it from hybrid parser. |
149 | ICHECK_EQ(stage->op.get(), this); |
150 | Stmt realize_body = body; |
151 | for (int k = 0; k < num_outputs(); ++k) { |
152 | Tensor t = stage->op.output(k); |
153 | Region bounds; |
154 | for (size_t i = 0; i < t->shape.size(); ++i) { |
155 | bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); |
156 | } |
157 | realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope); |
158 | } |
159 | return realize_body; |
160 | } |
161 | |
162 | Stmt HybridOpNode::BuildProvide(const Stage& stage, |
163 | const std::unordered_map<IterVar, Range>& dom_map, |
164 | bool debug_keep_trivial_loop) const { |
165 | ICHECK_EQ(stage->op.operator->(), this); |
166 | Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); |
167 | std::unordered_map<Tensor, Tensor> rmap; |
168 | for (int i = 0; i < this->num_outputs(); ++i) { |
169 | rmap[outputs[i]] = stage->op.output(i); |
170 | } |
171 | auto n = make_object<HybridOpNode>(*this); |
172 | /* This is a story little bit complicated. |
173 | * The following two lines of codes replace output tensors' usage. |
174 | * This is the simplest way I (@were) can come up with to glue |
175 | * hybrid operation node to TVM op system. |
176 | * In hybrid script all the tensors, especially the output tensors, |
177 | * have their own names defined by the users. However, In TVM |
178 | * conventional ops: |
179 | * 1. Output tensors refer the corresponding op node so that the output |
180 | * tensors have the same names as the operation produces them. |
181 | * 2. Once OpNode is wrapped up by an Operation node, it is finalized. |
182 | * Later access will be from a const OpNode*. |
183 | * This is a chicken-egg paradox. It is impossible to put the output |
184 | * tensors into the function body without forming the op node. The |
185 | * function body is immutable after the node is formed. |
186 | * |
187 | * Finally, I decided to resolve this issue "lazily". During the |
188 | * pipeline of compilation, this stage is a very preliminary stage. |
189 | * Technically, it is before Phase 0. The actual tensors will be replaced |
190 | * here. |
191 | * Thus, the operation body is slightly different from the Phase 0 body. |
192 | * This is a major difference that HybridOpNode is NOT the same as |
193 | * ExternOpNode. |
194 | * */ |
195 | ret = te::ReplaceTensor(ret, rmap); |
196 | ret = te::ReplaceProvideTensor(ret, rmap); |
197 | |
198 | ret = te::ApplySchedule(stage, dom_map, ret); |
199 | return ret; |
200 | } |
201 | |
/*!
 * \brief Rewrite the loop nest of \p stmt so it matches the split/fuse
 *        relations recorded on the schedule stage.
 */
Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     Stmt stmt) {
  // Replaces a single For loop over `parent` with an outer/inner loop pair.
  class LoopSpliter : public StmtExprMutator {
    PrimExpr factor;        // split factor (extent of the inner loop)
    const VarNode* parent;  // loop var being split
    IterVar inner, outer;

   public:
    bool splitted;  // set once the parent loop has been rewritten
    LoopSpliter(const SplitNode* split, const std::unordered_map<IterVar, Range>& dom_map)
        : factor(split->factor), splitted(false) {
      parent = split->parent->var.get();

      auto& inner_ = split->inner;
      ICHECK(dom_map.count(inner_));
      auto& inner_dom = dom_map.find(inner_)->second;
      ICHECK(is_const_int(inner_dom->min, 0));

      auto& outer_ = split->outer;
      ICHECK(dom_map.count(outer_));
      auto& outer_dom = dom_map.find(outer_)->second;
      ICHECK(is_const_int(outer_dom->min, 0));

      // Rebind the split iter vars to the domains inferred for this stage.
      inner = IterVar(inner_dom, inner_->var, inner_->iter_type);
      outer = IterVar(outer_dom, outer_->var, outer_->iter_type);
    }

    Stmt VisitStmt_(const ForNode* op) final {
      if (op->loop_var.get() == parent) {
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        // parent = outer * factor + inner
        rmap[op->loop_var.get()] = inner + outer * factor;
        Stmt ret = tir::Substitute(op->body, rmap);
        // Guard the tail iterations when factor does not evenly divide
        // the original extent.
        PrimExpr cond = likely(outer * factor < (op->extent - inner));
        ret = IfThenElse(cond, ret);
        ret = For(inner->var, PrimExpr(0), inner->dom->extent,
                  IterVarTypeToForKind(inner->iter_type), ret);
        ret = For(outer->var, PrimExpr(0), outer->dom->extent,
                  IterVarTypeToForKind(outer->iter_type), ret);
        splitted = true;
        return ret;
      }
      return StmtExprMutator::VisitStmt_(op);
    }
  };

  // Collapses an outer/inner loop pair (and any loops between them) into a
  // single loop over the fused iter var, recovering the original indices
  // with indexdiv/indexmod.
  class LoopFuser : public StmtExprMutator {
    const IterVar& parent;  // the fused iter var
    const VarNode* inner;
    const VarNode* outer;
    bool under_outer;  // true while visiting loops nested inside `outer`
    PrimExpr extent;   // accumulated extent of the loops folded so far

   public:
    bool fused;  // set once the inner loop has been absorbed
    explicit LoopFuser(const FuseNode* fuse_)
        : parent(fuse_->fused),
          inner(fuse_->inner->var.get()),
          outer(fuse_->outer->var.get()),
          under_outer(false),
          extent(0),
          fused(false) {}

    // TODO(@were): Handle imperfect loops
    Stmt VisitStmt_(const ForNode* op) final {
      if (op->loop_var.get() == inner) {
        ICHECK(under_outer);
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        // innermost index = fused % inner_extent
        rmap[op->loop_var.get()] = indexmod(parent, op->extent);
        extent = op->extent;
        fused = true;
        return tir::Substitute(op->body, rmap);
      } else if (op->loop_var.get() == outer) {
        under_outer = true;
        Stmt body = this->VisitStmt(op->body);
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        // outermost index = fused / (product of all inner extents)
        rmap[op->loop_var.get()] = indexdiv(parent, extent);
        body = tir::Substitute(body, rmap);
        under_outer = false;
        // Emit one loop over the fused var covering the combined extent.
        return For(parent->var, PrimExpr(0), extent * op->extent, op->kind, body,
                   op->thread_binding, op->annotations);
      } else if (under_outer) {
        // An intermediate loop between outer and inner: fold it into the
        // fused index as well.
        Stmt body = this->VisitStmt(op->body);
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
        body = tir::Substitute(body, rmap);
        extent = extent * op->extent;
        return body;
      }
      return StmtExprMutator::VisitStmt_(op);
    }
  };

  // Apply each split/fuse relation in order; each must take effect.
  for (auto& rel : stage->relations) {
    if (const SplitNode* split = rel.as<SplitNode>()) {
      LoopSpliter Spliter(split, dom_map);
      stmt = Spliter(stmt);
      ICHECK(Spliter.splitted);
    } else if (const FuseNode* fuse = rel.as<FuseNode>()) {
      LoopFuser Fuser(fuse);
      stmt = Fuser(stmt);
      ICHECK(Fuser.fused);
    }
  }

  return stmt;
}
308 | |
/*!
 * \brief Apply per-iter-var scheduling attributes (thread binding, loop
 *        kind) to the matching loops of \p stmt.
 */
Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map<IterVar, IterVar>& rebased,
                          Stmt stmt) {
  // Rewrites the single loop whose var matches `var` according to `attr`.
  class LoopAnnotator : public StmtMutator {
    const VarNode* var;       // loop var to annotate
    const IterVarAttr& attr;  // attribute to apply to that loop

   public:
    LoopAnnotator(const VarNode* var_, const IterVarAttr& attr_) : var(var_), attr(attr_) {}

    Stmt VisitStmt_(const ForNode* op) final {
      tir::ExprDeepEqual expr_equal;

      if (op->loop_var.get() == var) {
        if (attr->bind_thread.defined()) {
          // Bind the loop to a thread: substitute the loop var with the
          // thread iter var and emit a thread_extent attribute instead of
          // the For node.
          const auto& iter_var = attr->bind_thread;
          if (iter_var->dom.defined()) {
            ICHECK(is_const_int(iter_var->dom->min, 0));
            ICHECK(expr_equal(iter_var->dom->extent, op->extent))
                << "Thread extent and loop extent mismatch!\n";
          }
          std::unordered_map<const VarNode*, PrimExpr> rmap;
          rmap[op->loop_var.get()] = iter_var;
          Stmt body = tir::Substitute(op->body, rmap);
          return AttrStmt(iter_var, "thread_extent", op->extent, body);
        } else {
          // Only the loop kind changes (e.g. serial -> parallel/unrolled).
          return For(op->loop_var, op->min, op->extent, IterVarTypeToForKind(attr->iter_type),
                     op->body, op->thread_binding, op->annotations);
        }
      }
      return StmtMutator::VisitStmt_(op);
    }
  };

  for (auto& iter_var : stage->leaf_iter_vars) {
    bool need_change = false;
    int found = 0;

    // If the iter var was rebased, the loop in the body uses the rebased var.
    const IterVar& actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    const VarNode* var = actual->var.get();
    ForKind expected = IterVarTypeToForKind(iter_var->iter_type);
    IterVarAttr attr;
    if (stage->iter_var_attrs.count(iter_var)) {
      attr = stage->iter_var_attrs[iter_var];
      expected = IterVarTypeToForKind(attr->iter_type);
    }

    // Locate the loop for this iter var and decide whether it must change.
    PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
      if (const ForNode* op = node.as<ForNode>()) {
        if (op->loop_var.get() == var) {
          ++found;
          need_change = expected != op->kind || (attr.defined() && attr->bind_thread.defined());
        }
      }
    });

    // Each leaf iter var must correspond to exactly one loop in the body.
    ICHECK_EQ(found, 1) << " iter var should be found exactly once!";
    if (need_change) {
      stmt = LoopAnnotator(var, attr)(std::move(stmt));
    }
  }
  return stmt;
}
371 | |
/*!
 * \brief Reorder the loops of \p stmt so they match the order of
 *        stage->leaf_iter_vars.
 */
Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt) {
  // Collect the current loop order from the body.
  std::vector<const VarNode*> current_order;
  PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
    if (const ForNode* op = node.as<ForNode>()) current_order.push_back(op->loop_var.get());
  });
  // Post-order pushes inner loops first; reverse for outermost-first order.
  std::reverse(current_order.begin(), current_order.end());
  auto& required_ord = stage->leaf_iter_vars;
  ICHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
  // Map each current loop var to the iter var that should occupy its position.
  std::unordered_map<const VarNode*, IterVar> reorder;
  bool need_reorder = false;
  for (size_t i = 0; i < current_order.size(); ++i) {
    auto& current = current_order[i];
    const IterVar& iter_var = required_ord[i];
    // Rebased iter vars appear in the body under their rebased var.
    const IterVar& required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    ICHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
    reorder[current] = required;
    if (current != required->var.get()) {
      need_reorder = true;
    }
  }

  // Rebuilds each For node in place with the iter var that position requires.
  class LoopReorder : public StmtMutator {
    const Stage& stage;
    const std::unordered_map<IterVar, Range>& dom_map;
    const std::unordered_map<const VarNode*, IterVar>& reorder;  // current var -> required iter var

   public:
    LoopReorder(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                const std::unordered_map<const VarNode*, IterVar>& reorder)
        : stage(stage), dom_map(dom_map), reorder(reorder) {}

    Stmt VisitStmt_(const ForNode* op) final {
      // Reorder from in to out
      Stmt body_ = this->VisitStmt(op->body);
      ICHECK(reorder.count(op->loop_var.get()));
      auto target = reorder.find(op->loop_var.get())->second;
      // This position already holds the right loop and nothing below changed.
      if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
        return GetRef<Stmt>(op);
      const Stmt& body = op->body.same_as(body_) ? op->body : body_;
      // The loop kind may be overridden by the target's iter var attrs.
      ForKind kind = IterVarTypeToForKind(target->iter_type);
      if (stage->iter_var_attrs.count(target)) {
        kind = IterVarTypeToForKind(stage->iter_var_attrs[target]->iter_type);
      }
      const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
      return For(target->var, range->min, range->extent, kind, body, op->thread_binding,
                 op->annotations);
    }
  };

  if (need_reorder) return LoopReorder(stage, dom_map, reorder)(stmt);

  return stmt;
}
426 | |
427 | Stmt ApplySchedule(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map, |
428 | Stmt stmt) { |
429 | // TODO(@were): Eliminate loop rebase in script parser and move the burden here |
430 | // Gather rebased variables |
431 | std::unordered_map<IterVar, IterVar> rebased; |
432 | for (auto rel : stage->relations) { |
433 | if (const auto* rebase = rel.as<RebaseNode>()) { |
434 | rebased[rebase->rebased] = rebase->parent; |
435 | ICHECK(rebase->parent->dom.defined()); |
436 | ICHECK(dom_map.count(rebase->rebased)); |
437 | } |
438 | } |
439 | stmt = ApplyLoopShapes(stage, dom_map, stmt); |
440 | stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt); |
441 | stmt = ApplyLoopAnnotations(stage, rebased, stmt); |
442 | return stmt; |
443 | } |
444 | |
445 | std::vector<IterVar> GatherLoopVars(Stmt stmt) { |
446 | // TODO(@were): Write a comprehensive pass to analyze iter var types |
447 | std::vector<IterVar> res_; |
448 | PostOrderVisit(stmt, [&res_](const ObjectRef& node) { |
449 | if (const ForNode* op = node.as<ForNode>()) { |
450 | Var loop_var(op->loop_var); |
451 | Range dom = Range::FromMinExtent(op->min, cast(loop_var.dtype(), op->extent)); |
452 | res_.push_back(IterVar(dom, loop_var, ForKindToIterVarType(op->kind))); |
453 | } |
454 | }); |
455 | std::reverse(res_.begin(), res_.end()); |
456 | return res_; |
457 | } |
458 | |
459 | // replacer to replace tensors' usage in Provide |
460 | class ProviderReplacer : public tir::StmtMutator { |
461 | public: |
462 | explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap) : vmap_(vmap) {} |
463 | |
464 | Stmt VisitStmt_(const tir::ProducerStoreNode* op) final { |
465 | Tensor t = Downcast<Tensor>(op->producer); |
466 | auto it = vmap_.find(t); |
467 | if (it != vmap_.end()) { |
468 | Stmt ret = tir::ProducerStore(it->second, op->value, op->indices); |
469 | found = true; |
470 | return this->VisitStmt(ret); |
471 | } |
472 | return StmtMutator::VisitStmt_(op); |
473 | } |
474 | |
475 | // whether it is found. |
476 | bool found{false}; |
477 | |
478 | private: |
479 | const std::unordered_map<Tensor, Tensor>& vmap_; |
480 | }; |
481 | |
482 | Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace) { |
483 | ProviderReplacer repl(replace); |
484 | Stmt ret = repl(stmt); |
485 | return repl.found ? ret : stmt; |
486 | } |
487 | } // namespace te |
488 | } // namespace tvm |
489 | |