1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "./utils.h"
20
21namespace tvm {
22namespace script {
23namespace printer {
24
25TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
26 .set_dispatch<tir::For>("", [](tir::For loop, ObjectPath loop_p, IRDocsifier d) -> Doc {
27 // Step 1. Check syntactic sugar: `T.grid`
28 std::vector<const tir::ForNode*> grid;
29 std::unordered_set<const tir::VarNode*> grid_loop_vars;
30 auto f_var_dep = [&grid_loop_vars](const PrimExpr& e) -> bool {
31 return tir::UsesVar(e, [&grid_loop_vars](const tir::VarNode* v) -> bool { //
32 return grid_loop_vars.count(v);
33 });
34 };
35 if (d->cfg->syntax_sugar) {
36 for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as<tir::ForNode>()) {
37 ICHECK(l->loop_var->dtype == l->min->dtype);
38 ICHECK(l->loop_var->dtype == l->extent->dtype);
39 if (l->kind != tir::ForKind::kSerial || //
40 !tir::is_zero(l->min) || //
41 !l->annotations.empty() || //
42 f_var_dep(l->extent)) {
43 break;
44 }
45 grid.push_back(l);
46 grid_loop_vars.insert(l->loop_var.get());
47 }
48 }
49 With<TIRFrame> f(d, loop);
50 // Step 2. Construct `T.grid`
51 if (grid.size() > 1) {
52 int n = grid.size();
53 Array<ExprDoc> lhs;
54 Array<ExprDoc> rhs;
55 lhs.reserve(n);
56 rhs.reserve(n);
57 for (int i = 0; i < n; ++i) {
58 const tir::ForNode* loop = grid[i];
59 lhs.push_back(DefineVar(loop->loop_var, *f, d));
60 rhs.push_back(d->AsDoc<ExprDoc>(loop->extent, loop_p->Attr("extent")));
61 loop_p = loop_p->Attr("body");
62 }
63 AsDocBody(grid.back()->body, loop_p, (*f).get(), d);
64 return ForDoc(TupleDoc(lhs), TIR(d, "grid")->Call(rhs), (*f)->stmts);
65 }
66 // Step 3. If not `T.grid`, print loop kind accordingly
67 ExprDoc lhs = DefineVar(loop->loop_var, *f, d);
68 Optional<ExprDoc> min = NullOpt;
69 Optional<ExprDoc> max = NullOpt;
70 Optional<ExprDoc> annotations = NullOpt;
71 Optional<ExprDoc> thread = NullOpt;
72 if (tir::is_zero(loop->min)) {
73 max = d->AsDoc<ExprDoc>(loop->extent, loop_p->Attr("extent"));
74 } else {
75 min = d->AsDoc<ExprDoc>(loop->min, loop_p->Attr("min"));
76 max = d->AsDoc<ExprDoc>(loop->min + loop->extent, loop_p->Attr("extent"));
77 }
78 if (!loop->annotations.empty()) {
79 annotations = d->AsDoc<ExprDoc>(loop->annotations, loop_p->Attr("annotations"));
80 }
81 ExprDoc prefix{nullptr};
82 if (loop->kind == tir::ForKind::kSerial) {
83 if (loop->annotations.empty()) {
84 prefix = IdDoc("range");
85 } else {
86 prefix = TIR(d, "serial");
87 }
88 } else if (loop->kind == tir::ForKind::kParallel) {
89 prefix = TIR(d, "parallel");
90 } else if (loop->kind == tir::ForKind::kUnrolled) {
91 prefix = TIR(d, "unroll");
92 } else if (loop->kind == tir::ForKind::kVectorized) {
93 prefix = TIR(d, "vectorized");
94 } else if (loop->kind == tir::ForKind::kThreadBinding) {
95 prefix = TIR(d, "thread_binding");
96 thread = LiteralDoc::Str(loop->thread_binding.value()->thread_tag,
97 loop_p->Attr("thread_binding"));
98 } else {
99 LOG(FATAL) << "ValueError: Unknown ForKind: " << tir::ForKind2String(loop->kind);
100 }
101 Array<ExprDoc> args;
102 Array<String> kwargs_keys;
103 Array<ExprDoc> kwargs_values;
104 if (min.defined()) {
105 args.push_back(min.value());
106 }
107 if (max.defined()) {
108 args.push_back(max.value());
109 }
110 if (thread.defined()) {
111 kwargs_keys.push_back("thread");
112 kwargs_values.push_back(thread.value());
113 }
114 if (annotations.defined()) {
115 kwargs_keys.push_back("annotations");
116 kwargs_values.push_back(annotations.value());
117 }
118 ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values);
119 AsDocBody(loop->body, loop_p->Attr("body"), (*f).get(), d);
120 return ForDoc(lhs, rhs, (*f)->stmts);
121 });
122
123TVM_SCRIPT_REPR(tir::ForNode, ReprPrintTIR);
124
125} // namespace printer
126} // namespace script
127} // namespace tvm
128