/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Hybrid computation rule.
 * \file hybrid_op.cc
 */
#include "hybrid_op.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "op_utils.h"

namespace tvm {
namespace te {
using namespace tir;
// HybridOpNode
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<HybridOpNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const HybridOpNode*>(node.get());
      p->stream << "hybrid(" << op->name << ", " << op << ")";
    });

TVM_REGISTER_NODE_TYPE(HybridOpNode);

int HybridOpNode::num_outputs() const { return static_cast<int>(outputs.size()); }

Array<IterVar> HybridOpNode::root_iter_vars() const { return this->axis; }

DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; }

Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; }

HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                   Array<Tensor> inputs, Array<Tensor> outputs, Stmt body) {
  if (!attrs.defined()) {
    attrs = Map<String, ObjectRef>();
  }
  auto n = make_object<HybridOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->attrs = std::move(attrs);
  n->inputs = std::move(inputs);
  n->outputs = std::move(outputs);
  n->axis = te::GatherLoopVars(body);
  n->body = std::move(body);
  data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("te.HybridOp")
    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
                       Array<Tensor> inputs, Array<Tensor> outputs,
                       Stmt body) { return HybridOp(name, tag, attrs, inputs, outputs, body); });

Array<Tensor> HybridOpNode::InputTensors() const {
  // Because input tensors may be inlined into the hybrid script, only
  // return the declared inputs that are actually referenced in the body.
  std::unordered_set<Tensor> orig_inputs;
  for (auto t : inputs) {
    orig_inputs.insert(t);
  }
  std::unordered_set<Tensor> visited;
  Array<Tensor> curr_inputs;
  tir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
    if (auto* pload = n.as<tir::ProducerLoadNode>()) {
      Tensor t = Downcast<Tensor>(pload->producer);
      if (orig_inputs.count(t) && !visited.count(t)) {
        curr_inputs.push_back(t);
        visited.insert(t);
      }
    }
  });
  return curr_inputs;
}

Operation HybridOpNode::ReplaceInputs(const Operation& self,
                                      const std::unordered_map<Tensor, Tensor>& rmap) const {
  ICHECK_EQ(self.operator->(), this);
  auto n = make_object<HybridOpNode>(*this);
  n->body = te::ReplaceTensor(this->body, rmap);
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    Tensor t = n->inputs[i];
    if (rmap.count(t)) {
      n->inputs.Set(i, rmap.at(t));
    }
  }

  if (body.same_as(n->body) && inputs.same_as(n->inputs)) {
    return self;
  } else {
    return Operation(n);
  }
}
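// The access pattern of a hybrid body is not analyzed, so conservatively
// claim the full region of every input tensor that is actually read.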
void HybridOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                                     const std::unordered_map<const VarNode*, IntSet>& dom_map,
                                     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  auto curr_inputs = InputTensors();
  for (Tensor t : curr_inputs) {
    auto it = out_dom_map->find(t);
    if (it == out_dom_map->end()) continue;
    TensorDom& dom = it->second;
    for (size_t i = 0; i < t->shape.size(); ++i) {
      dom.data[i].emplace_back(
          IntSet::FromRange(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])));
    }
  }
}

void HybridOpNode::GatherBound(const Operation& self,
                               const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                               std::unordered_map<IterVar, Range>* out_dom_map) const {
  for (auto iter_var : axis) {
    ICHECK(!out_dom_map->count(iter_var));
    out_dom_map->operator[](iter_var) = iter_var->dom;
  }
}

Stmt HybridOpNode::BuildRealize(const Stage& stage,
                                const std::unordered_map<IterVar, Range>& realize_map,
                                const Stmt& body, String storage_scope) const {
  // TODO(@were): Add attribute injection here and remove it from the hybrid parser.
  ICHECK_EQ(stage->op.get(), this);
  Stmt realize_body = body;
  for (int k = 0; k < num_outputs(); ++k) {
    Tensor t = stage->op.output(k);
    Region bounds;
    for (size_t i = 0; i < t->shape.size(); ++i) {
      bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
    }
    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope);
  }
  return realize_body;
}

Stmt HybridOpNode::BuildProvide(const Stage& stage,
                                const std::unordered_map<IterVar, Range>& dom_map,
                                bool debug_keep_trivial_loop) const {
  ICHECK_EQ(stage->op.operator->(), this);
  Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
  std::unordered_map<Tensor, Tensor> rmap;
  for (int i = 0; i < this->num_outputs(); ++i) {
    rmap[outputs[i]] = stage->op.output(i);
  }
  auto n = make_object<HybridOpNode>(*this);
  /* The story here is a little complicated.
   * The following two lines replace every use of the output tensors.
   * This is the simplest way I (@were) could come up with to glue a
   * hybrid operation node into the TVM op system.
   * In a hybrid script all tensors, especially the output tensors,
   * carry names chosen by the user. Conventional TVM ops, however,
   * work differently:
   * 1. Output tensors refer to the op node that produces them, so they
   *    share the name of the operation.
   * 2. Once an OpNode is wrapped in an Operation node, it is finalized;
   *    later access goes through a const OpNode*.
   * This is a chicken-and-egg paradox: the output tensors cannot be put
   * into the function body before the op node is formed, yet the body
   * is immutable once the node is formed.
   *
   * The issue is therefore resolved "lazily". This point in the
   * compilation pipeline is very early, technically before Phase 0,
   * and the actual output tensors are substituted here.
   * Thus, the operation body differs slightly from the Phase 0 body.
   * This is a major way in which HybridOpNode is NOT the same as
   * ExternOpNode.
   */
  ret = te::ReplaceTensor(ret, rmap);
  ret = te::ReplaceProvideTensor(ret, rmap);
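  // For example (a sketch): if the script declares an output named "c",
  // the parsed body reads and writes a placeholder tensor "c"; rmap maps
  // that placeholder to stage->op.output(0), so after the two
  // substitutions above the body targets the op's own output tensor.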

  ret = te::ApplySchedule(stage, dom_map, ret);
  return ret;
}
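/*!
 * \brief Replay the split/fuse relations recorded in the schedule on the
 *        loop nest parsed from the hybrid script.
 */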
Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     Stmt stmt) {
  class LoopSplitter : public StmtExprMutator {
    PrimExpr factor;
    const VarNode* parent;
    IterVar inner, outer;

   public:
    bool split_done;
    LoopSplitter(const SplitNode* split, const std::unordered_map<IterVar, Range>& dom_map)
        : factor(split->factor), split_done(false) {
      parent = split->parent->var.get();

      auto& inner_ = split->inner;
      ICHECK(dom_map.count(inner_));
      auto& inner_dom = dom_map.find(inner_)->second;
      ICHECK(is_const_int(inner_dom->min, 0));

      auto& outer_ = split->outer;
      ICHECK(dom_map.count(outer_));
      auto& outer_dom = dom_map.find(outer_)->second;
      ICHECK(is_const_int(outer_dom->min, 0));

      inner = IterVar(inner_dom, inner_->var, inner_->iter_type);
      outer = IterVar(outer_dom, outer_->var, outer_->iter_type);
    }

    Stmt VisitStmt_(const ForNode* op) final {
      if (op->loop_var.get() == parent) {
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = inner + outer * factor;
        Stmt ret = tir::Substitute(op->body, rmap);
        PrimExpr cond = likely(outer * factor < (op->extent - inner));
        ret = IfThenElse(cond, ret);
        ret = For(inner->var, PrimExpr(0), inner->dom->extent,
                  IterVarTypeToForKind(inner->iter_type), ret);
        ret = For(outer->var, PrimExpr(0), outer->dom->extent,
                  IterVarTypeToForKind(outer->iter_type), ret);
        split_done = true;
        return ret;
      }
      return StmtExprMutator::VisitStmt_(op);
    }
  };
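
  // A rough sketch of the rewrite above for a split of i by factor f
  // (the loop extents actually come from dom_map):
  //   for (i, 0, n) body(i)
  // becomes
  //   for (i.outer, 0, outer_extent)
  //     for (i.inner, 0, f)
  //       if (likely(i.outer * f < n - i.inner))
  //         body(i.outer * f + i.inner)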

  class LoopFuser : public StmtExprMutator {
    const IterVar& parent;
    const VarNode* inner;
    const VarNode* outer;
    bool under_outer;
    PrimExpr extent;

   public:
    bool fused;
    explicit LoopFuser(const FuseNode* fuse_)
        : parent(fuse_->fused),
          inner(fuse_->inner->var.get()),
          outer(fuse_->outer->var.get()),
          under_outer(false),
          extent(0),
          fused(false) {}

    // TODO(@were): Handle imperfect loops
    Stmt VisitStmt_(const ForNode* op) final {
      if (op->loop_var.get() == inner) {
        ICHECK(under_outer);
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = indexmod(parent, op->extent);
        extent = op->extent;
        fused = true;
        return tir::Substitute(op->body, rmap);
      } else if (op->loop_var.get() == outer) {
        under_outer = true;
        Stmt body = this->VisitStmt(op->body);
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = indexdiv(parent, extent);
        body = tir::Substitute(body, rmap);
        under_outer = false;
        return For(parent->var, PrimExpr(0), extent * op->extent, op->kind, body,
                   op->thread_binding, op->annotations);
      } else if (under_outer) {
        Stmt body = this->VisitStmt(op->body);
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
        body = tir::Substitute(body, rmap);
        extent = extent * op->extent;
        return body;
      }
      return StmtExprMutator::VisitStmt_(op);
    }
  };
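
  // A rough sketch of the rewrite above for fusing o (extent m) and
  // i (extent n) into a single loop:
  //   for (o, 0, m)
  //     for (i, 0, n)
  //       body(o, i)
  // becomes
  //   for (o.i.fused, 0, m * n)
  //     body(indexdiv(o.i.fused, n), indexmod(o.i.fused, n))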

  for (auto& rel : stage->relations) {
    if (const SplitNode* split = rel.as<SplitNode>()) {
      LoopSplitter splitter(split, dom_map);
      stmt = splitter(stmt);
      ICHECK(splitter.split_done);
    } else if (const FuseNode* fuse = rel.as<FuseNode>()) {
      LoopFuser fuser(fuse);
      stmt = fuser(stmt);
      ICHECK(fuser.fused);
    }
  }

  return stmt;
}
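/*!
 * \brief Apply the iter var attributes recorded in the schedule, i.e.
 *        thread bindings and changed loop kinds, to the loop nest.
 */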
Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map<IterVar, IterVar>& rebased,
                          Stmt stmt) {
  class LoopAnnotator : public StmtMutator {
    const VarNode* var;
    const IterVarAttr& attr;

   public:
    LoopAnnotator(const VarNode* var_, const IterVarAttr& attr_) : var(var_), attr(attr_) {}

    Stmt VisitStmt_(const ForNode* op) final {
      tir::ExprDeepEqual expr_equal;

      if (op->loop_var.get() == var) {
        if (attr->bind_thread.defined()) {
          const auto& iter_var = attr->bind_thread;
          if (iter_var->dom.defined()) {
            ICHECK(is_const_int(iter_var->dom->min, 0));
            ICHECK(expr_equal(iter_var->dom->extent, op->extent))
                << "Thread extent and loop extent mismatch!\n";
          }
          std::unordered_map<const VarNode*, PrimExpr> rmap;
          rmap[op->loop_var.get()] = iter_var;
          Stmt body = tir::Substitute(op->body, rmap);
          return AttrStmt(iter_var, "thread_extent", op->extent, body);
        } else {
          return For(op->loop_var, op->min, op->extent, IterVarTypeToForKind(attr->iter_type),
                     op->body, op->thread_binding, op->annotations);
        }
      }
      return StmtMutator::VisitStmt_(op);
    }
  };
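
  // A sketch of the thread-binding rewrite above, assuming a binding to
  // threadIdx.x with matching extent n:
  //   for (i, 0, n) body(i)
  // becomes
  //   attr [iter_var(threadIdx.x)] thread_extent = n
  //     body(threadIdx.x)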

  for (auto& iter_var : stage->leaf_iter_vars) {
    bool need_change = false;
    int found = 0;

    const IterVar& actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    const VarNode* var = actual->var.get();
    ForKind expected = IterVarTypeToForKind(iter_var->iter_type);
    IterVarAttr attr;
    if (stage->iter_var_attrs.count(iter_var)) {
      attr = stage->iter_var_attrs[iter_var];
      expected = IterVarTypeToForKind(attr->iter_type);
    }

    PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
      if (const ForNode* op = node.as<ForNode>()) {
        if (op->loop_var.get() == var) {
          ++found;
          need_change = expected != op->kind || (attr.defined() && attr->bind_thread.defined());
        }
      }
    });

    ICHECK_EQ(found, 1) << "Iter var should be found exactly once!";
    if (need_change) {
      stmt = LoopAnnotator(var, attr)(std::move(stmt));
    }
  }
  return stmt;
}
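/*!
 * \brief Reorder the loop nest so that it matches the order of the
 *        stage's leaf iter vars.
 */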
Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt) {
  std::vector<const VarNode*> current_order;
  PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
    if (const ForNode* op = node.as<ForNode>()) current_order.push_back(op->loop_var.get());
  });
  std::reverse(current_order.begin(), current_order.end());
  auto& required_ord = stage->leaf_iter_vars;
  ICHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
  std::unordered_map<const VarNode*, IterVar> reorder;
  bool need_reorder = false;
  for (size_t i = 0; i < current_order.size(); ++i) {
    auto& current = current_order[i];
    const IterVar& iter_var = required_ord[i];
    const IterVar& required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    ICHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
    reorder[current] = required;
    if (current != required->var.get()) {
      need_reorder = true;
    }
  }

  class LoopReorder : public StmtMutator {
    const Stage& stage;
    const std::unordered_map<IterVar, Range>& dom_map;
    const std::unordered_map<const VarNode*, IterVar>& reorder;

   public:
    LoopReorder(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                const std::unordered_map<const VarNode*, IterVar>& reorder)
        : stage(stage), dom_map(dom_map), reorder(reorder) {}

    Stmt VisitStmt_(const ForNode* op) final {
      // Reorder from the innermost loop outwards.
      Stmt body_ = this->VisitStmt(op->body);
      ICHECK(reorder.count(op->loop_var.get()));
      auto target = reorder.find(op->loop_var.get())->second;
      if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
        return GetRef<Stmt>(op);
      const Stmt& body = op->body.same_as(body_) ? op->body : body_;
      ForKind kind = IterVarTypeToForKind(target->iter_type);
      if (stage->iter_var_attrs.count(target)) {
        kind = IterVarTypeToForKind(stage->iter_var_attrs[target]->iter_type);
      }
      const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
      return For(target->var, range->min, range->extent, kind, body, op->thread_binding,
                 op->annotations);
    }
  };

  if (need_reorder) return LoopReorder(stage, dom_map, reorder)(stmt);

  return stmt;
}
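/*!
 * \brief Apply a schedule to the body of a hybrid op: first loop shapes
 *        (split/fuse), then loop order, then loop annotations.
 */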
Stmt ApplySchedule(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                   Stmt stmt) {
  // TODO(@were): Eliminate loop rebase in the script parser and move the burden here
  // Gather rebased variables
  std::unordered_map<IterVar, IterVar> rebased;
  for (auto rel : stage->relations) {
    if (const auto* rebase = rel.as<RebaseNode>()) {
      rebased[rebase->rebased] = rebase->parent;
      ICHECK(rebase->parent->dom.defined());
      ICHECK(dom_map.count(rebase->rebased));
    }
  }
  stmt = ApplyLoopShapes(stage, dom_map, stmt);
  stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt);
  stmt = ApplyLoopAnnotations(stage, rebased, stmt);
  return stmt;
}
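/*!
 * \brief Collect the loop variables of a statement as iter vars, ordered
 *        from outermost to innermost.
 */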
std::vector<IterVar> GatherLoopVars(Stmt stmt) {
  // TODO(@were): Write a comprehensive pass to analyze iter var types
  std::vector<IterVar> res_;
  PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
    if (const ForNode* op = node.as<ForNode>()) {
      Var loop_var(op->loop_var);
      Range dom = Range::FromMinExtent(op->min, cast(loop_var.dtype(), op->extent));
      res_.push_back(IterVar(dom, loop_var, ForKindToIterVarType(op->kind)));
    }
  });
  std::reverse(res_.begin(), res_.end());
  return res_;
}

// Replaces the tensors written by ProducerStore (Provide) statements.
class ProviderReplacer : public tir::StmtMutator {
 public:
  explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap) : vmap_(vmap) {}

  Stmt VisitStmt_(const tir::ProducerStoreNode* op) final {
    Tensor t = Downcast<Tensor>(op->producer);
    auto it = vmap_.find(t);
    if (it != vmap_.end()) {
      Stmt ret = tir::ProducerStore(it->second, op->value, op->indices);
      found = true;
      return this->VisitStmt(ret);
    }
    return StmtMutator::VisitStmt_(op);
  }

  // Whether any replacement was performed.
  bool found{false};

 private:
  const std::unordered_map<Tensor, Tensor>& vmap_;
};

Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace) {
  ProviderReplacer repl(replace);
  Stmt ret = repl(stmt);
  return repl.found ? ret : stmt;
}
}  // namespace te
}  // namespace tvm