1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "./utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace script { |
23 | namespace printer { |
24 | |
25 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
26 | .set_dispatch<tir::For>("" , [](tir::For loop, ObjectPath loop_p, IRDocsifier d) -> Doc { |
27 | // Step 1. Check syntactic sugar: `T.grid` |
28 | std::vector<const tir::ForNode*> grid; |
29 | std::unordered_set<const tir::VarNode*> grid_loop_vars; |
30 | auto f_var_dep = [&grid_loop_vars](const PrimExpr& e) -> bool { |
31 | return tir::UsesVar(e, [&grid_loop_vars](const tir::VarNode* v) -> bool { // |
32 | return grid_loop_vars.count(v); |
33 | }); |
34 | }; |
35 | if (d->cfg->syntax_sugar) { |
36 | for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as<tir::ForNode>()) { |
37 | ICHECK(l->loop_var->dtype == l->min->dtype); |
38 | ICHECK(l->loop_var->dtype == l->extent->dtype); |
39 | if (l->kind != tir::ForKind::kSerial || // |
40 | !tir::is_zero(l->min) || // |
41 | !l->annotations.empty() || // |
42 | f_var_dep(l->extent)) { |
43 | break; |
44 | } |
45 | grid.push_back(l); |
46 | grid_loop_vars.insert(l->loop_var.get()); |
47 | } |
48 | } |
49 | With<TIRFrame> f(d, loop); |
50 | // Step 2. Construct `T.grid` |
51 | if (grid.size() > 1) { |
52 | int n = grid.size(); |
53 | Array<ExprDoc> lhs; |
54 | Array<ExprDoc> rhs; |
55 | lhs.reserve(n); |
56 | rhs.reserve(n); |
57 | for (int i = 0; i < n; ++i) { |
58 | const tir::ForNode* loop = grid[i]; |
59 | lhs.push_back(DefineVar(loop->loop_var, *f, d)); |
60 | rhs.push_back(d->AsDoc<ExprDoc>(loop->extent, loop_p->Attr("extent" ))); |
61 | loop_p = loop_p->Attr("body" ); |
62 | } |
63 | AsDocBody(grid.back()->body, loop_p, (*f).get(), d); |
64 | return ForDoc(TupleDoc(lhs), TIR(d, "grid" )->Call(rhs), (*f)->stmts); |
65 | } |
66 | // Step 3. If not `T.grid`, print loop kind accordingly |
67 | ExprDoc lhs = DefineVar(loop->loop_var, *f, d); |
68 | Optional<ExprDoc> min = NullOpt; |
69 | Optional<ExprDoc> max = NullOpt; |
70 | Optional<ExprDoc> annotations = NullOpt; |
71 | Optional<ExprDoc> thread = NullOpt; |
72 | if (tir::is_zero(loop->min)) { |
73 | max = d->AsDoc<ExprDoc>(loop->extent, loop_p->Attr("extent" )); |
74 | } else { |
75 | min = d->AsDoc<ExprDoc>(loop->min, loop_p->Attr("min" )); |
76 | max = d->AsDoc<ExprDoc>(loop->min + loop->extent, loop_p->Attr("extent" )); |
77 | } |
78 | if (!loop->annotations.empty()) { |
79 | annotations = d->AsDoc<ExprDoc>(loop->annotations, loop_p->Attr("annotations" )); |
80 | } |
81 | ExprDoc prefix{nullptr}; |
82 | if (loop->kind == tir::ForKind::kSerial) { |
83 | if (loop->annotations.empty()) { |
84 | prefix = IdDoc("range" ); |
85 | } else { |
86 | prefix = TIR(d, "serial" ); |
87 | } |
88 | } else if (loop->kind == tir::ForKind::kParallel) { |
89 | prefix = TIR(d, "parallel" ); |
90 | } else if (loop->kind == tir::ForKind::kUnrolled) { |
91 | prefix = TIR(d, "unroll" ); |
92 | } else if (loop->kind == tir::ForKind::kVectorized) { |
93 | prefix = TIR(d, "vectorized" ); |
94 | } else if (loop->kind == tir::ForKind::kThreadBinding) { |
95 | prefix = TIR(d, "thread_binding" ); |
96 | thread = LiteralDoc::Str(loop->thread_binding.value()->thread_tag, |
97 | loop_p->Attr("thread_binding" )); |
98 | } else { |
99 | LOG(FATAL) << "ValueError: Unknown ForKind: " << tir::ForKind2String(loop->kind); |
100 | } |
101 | Array<ExprDoc> args; |
102 | Array<String> kwargs_keys; |
103 | Array<ExprDoc> kwargs_values; |
104 | if (min.defined()) { |
105 | args.push_back(min.value()); |
106 | } |
107 | if (max.defined()) { |
108 | args.push_back(max.value()); |
109 | } |
110 | if (thread.defined()) { |
111 | kwargs_keys.push_back("thread" ); |
112 | kwargs_values.push_back(thread.value()); |
113 | } |
114 | if (annotations.defined()) { |
115 | kwargs_keys.push_back("annotations" ); |
116 | kwargs_values.push_back(annotations.value()); |
117 | } |
118 | ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values); |
119 | AsDocBody(loop->body, loop_p->Attr("body" ), (*f).get(), d); |
120 | return ForDoc(lhs, rhs, (*f)->stmts); |
121 | }); |
122 | |
123 | TVM_SCRIPT_REPR(tir::ForNode, ReprPrintTIR); |
124 | |
125 | } // namespace printer |
126 | } // namespace script |
127 | } // namespace tvm |
128 | |