1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "./utils.h"
20
21namespace tvm {
22namespace script {
23namespace printer {
24
25Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //
26 Optional<tir::BlockRealize> opt_realize, Optional<ObjectPath> opt_realize_p) {
27 With<TIRFrame> frame(d, block);
28 ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined());
29 const tir::BlockRealizeNode* realize =
30 opt_realize.defined() ? opt_realize.value().get() : nullptr;
31 const ObjectPathNode* realize_p = opt_realize_p.defined() ? opt_realize_p.get() : nullptr;
32 // Step 1. Handle block var and block bindings
33 // Step 1.1. Obtain all loop var defined along path
34 std::unordered_map<const tir::VarNode*, tir::For> loop_vars;
35 for (Frame f : d->frames) {
36 if (const auto* tir_f = f.as<TIRFrameNode>()) {
37 if (const auto* for_loop = tir_f->tir.as<tir::ForNode>()) {
38 for (const tir::ForNode* l = for_loop; l != nullptr; l = l->body.as<tir::ForNode>()) {
39 loop_vars.insert(std::make_pair(l->loop_var.get(), GetRef<tir::For>(l)));
40 }
41 }
42 }
43 }
44
45 std::vector<int> remap_vars_indices;
46 auto add_remapped_iter_var = [&](int i) -> bool {
47 if (realize && d->cfg->syntax_sugar) {
48 tir::ExprDeepEqual expr_equal;
49 tir::IterVar iter_var = block->iter_vars[i];
50 PrimExpr value = realize->iter_values[i];
51 if (iter_var->iter_type == tir::IterVarType::kDataPar ||
52 iter_var->iter_type == tir::IterVarType::kCommReduce) {
53 if (const auto* var = value.as<tir::VarNode>()) {
54 if (loop_vars.count(var)) {
55 tir::For for_loop = loop_vars.at(var);
56 if (expr_equal(for_loop->min, iter_var->dom->min) &&
57 expr_equal(for_loop->extent, iter_var->dom->extent)) {
58 remap_vars_indices.push_back(i);
59 return true;
60 }
61 }
62 }
63 }
64 }
65 return false;
66 };
67
68 auto print_single_iter_var = [&](int i) {
69 tir::IterVar iter_var = block->iter_vars[i];
70 ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i);
71 ExprDoc rhs = TIR(d, "axis");
72 if (iter_var->iter_type == tir::IterVarType::kDataPar) {
73 rhs = rhs->Attr("spatial");
74 } else if (iter_var->iter_type == tir::IterVarType::kCommReduce) {
75 rhs = rhs->Attr("reduce");
76 } else if (iter_var->iter_type == tir::IterVarType::kOrdered) {
77 rhs = rhs->Attr("scan");
78 } else if (iter_var->iter_type == tir::IterVarType::kOpaque) {
79 rhs = rhs->Attr("opaque");
80 } else {
81 LOG(FATAL) << "ValueError: Unknown IterVarType in block signature: "
82 << tir::IterVarType2String(iter_var->iter_type);
83 }
84 ExprDoc dom{nullptr};
85 if (tir::is_zero(iter_var->dom->min)) {
86 ExprDoc extent = d->AsDoc<ExprDoc>(iter_var->dom->extent, //
87 iter_var_p->Attr("dom")->Attr("extent"));
88 dom = extent;
89 } else {
90 ExprDoc min = d->AsDoc<ExprDoc>(iter_var->dom->min, iter_var_p->Attr("dom")->Attr("min"));
91 ExprDoc max = d->AsDoc<ExprDoc>(iter_var->dom->min + iter_var->dom->extent,
92 iter_var_p->Attr("dom")->Attr("extent"));
93 dom = TupleDoc({min, max});
94 }
95 if (realize) {
96 ExprDoc binding = d->AsDoc<ExprDoc>(realize->iter_values[i], //
97 realize_p->Attr("iter_values")->ArrayIndex(i));
98 rhs = rhs->Call({dom, binding});
99 } else {
100 rhs = rhs->Call({dom});
101 }
102 (*frame)->stmts.push_back(AssignDoc(DefineVar(iter_var->var, *frame, d), rhs, NullOpt));
103 };
104
105 auto print_remapped_iter_var = [&]() {
106 if (remap_vars_indices.size()) {
107 int m = remap_vars_indices.size();
108 if (!m) {
109 return;
110 }
111 if (m == 1) {
112 print_single_iter_var(remap_vars_indices[0]);
113 remap_vars_indices.clear();
114 return;
115 }
116 Array<ExprDoc> lhs;
117 Array<ExprDoc> loop_var_doc;
118 lhs.reserve(m);
119 loop_var_doc.reserve(m);
120 std::string binding_type = "";
121 Array<ObjectPath> binding_paths;
122 for (int i : remap_vars_indices) {
123 tir::IterVar iter_var = block->iter_vars[i];
124 ObjectPath iter_var_p = block_p->Attr("iter_vars")->ArrayIndex(i);
125 lhs.push_back(DefineVar(iter_var->var, *frame, d));
126 loop_var_doc.push_back(d->AsDoc<ExprDoc>(realize->iter_values[i],
127 realize_p->Attr("iter_values")->ArrayIndex(i)));
128 binding_paths.push_back(iter_var_p->Attr("iter_type"));
129 binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R";
130 }
131 ExprDoc rhs = TIR(d, "axis")->Attr("remap");
132 ExprDoc binding_str = LiteralDoc::Str(binding_type, NullOpt);
133 binding_str->source_paths = std::move(binding_paths);
134 rhs = rhs->Call({binding_str, ListDoc(loop_var_doc)});
135 (*frame)->stmts.push_back(AssignDoc(TupleDoc(lhs), rhs, NullOpt));
136 remap_vars_indices.clear();
137 }
138 };
139
140 // Step 1.2. Construct all block var bindings
141 int n_vars = block->iter_vars.size();
142 for (int i = 0; i < n_vars; ++i) {
143 if (!add_remapped_iter_var(i)) {
144 print_remapped_iter_var();
145 print_single_iter_var(i);
146 }
147 }
148 print_remapped_iter_var();
149
150 // Step 2. Handle block predicate
151 if (realize) {
152 ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool());
153 if (!tir::is_one(realize->predicate)) {
154 (*frame)->stmts.push_back(ExprStmtDoc(
155 TIR(d, "where")
156 ->Call({d->AsDoc<ExprDoc>(realize->predicate, realize_p->Attr("predicate"))})));
157 }
158 }
159 // Step 3. Handle block read/write regions
160 {
161 Array<ExprDoc> reads;
162 for (int i = 0, n = block->reads.size(); i < n; ++i) {
163 reads.push_back(d->AsDoc<ExprDoc>(block->reads[i], block_p->Attr("reads")->ArrayIndex(i)));
164 }
165 (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "reads")->Call(reads)));
166 Array<ExprDoc> writes;
167 for (int i = 0, n = block->writes.size(); i < n; ++i) {
168 writes.push_back(d->AsDoc<ExprDoc>(block->writes[i], block_p->Attr("writes")->ArrayIndex(i)));
169 }
170 (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "writes")->Call(writes)));
171 }
172 // Step 4. Handle block attributes
173 if (!block->annotations.empty()) {
174 (*frame)->stmts.push_back(ExprStmtDoc(
175 TIR(d, "block_attr")
176 ->Call({d->AsDoc<ExprDoc>(block->annotations, block_p->Attr("annotations"))})));
177 }
178 // Step 5. Handle `alloc_buffer`
179 for (int i = 0, n = block->alloc_buffers.size(); i < n; ++i) {
180 tir::Buffer buffer = block->alloc_buffers[i];
181 ObjectPath buffer_p = block_p->Attr("alloc_buffers")->ArrayIndex(i);
182 IdDoc lhs = DefineBuffer(buffer, *frame, d);
183 ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *frame, d);
184 (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
185 }
186 // Step 6. Handle `match_buffer`
187 for (int i = 0, n = block->match_buffers.size(); i < n; ++i) {
188 tir::MatchBufferRegion buffer_region = block->match_buffers[i];
189 ObjectPath buffer_region_p = block_p->Attr("match_buffers")->ArrayIndex(i);
190 StmtDoc doc = d->AsDoc<StmtDoc>(buffer_region, buffer_region_p);
191 (*frame)->stmts.push_back(doc);
192 }
193 // Step 7. Handle init block
194 if (block->init.defined()) {
195 tir::Stmt init = block->init.value();
196 With<TIRFrame> init_frame(d, init);
197 AsDocBody(init, block_p->Attr("init"), init_frame->get(), d);
198 (*frame)->stmts.push_back(ScopeDoc(NullOpt, TIR(d, "init")->Call({}), (*init_frame)->stmts));
199 }
200 // Step 8. Handle block body
201 AsDocBody(block->body, block_p->Attr("body"), frame->get(), d);
202 Array<String> kwargs_keys;
203 Array<ExprDoc> kwargs_values;
204 if (!realize) {
205 kwargs_keys.push_back("no_realize");
206 kwargs_values.push_back(LiteralDoc::Boolean(true, NullOpt));
207 }
208 return ScopeDoc(NullOpt,
209 TIR(d, "block") //
210 ->Call({LiteralDoc::Str(block->name_hint, block_p->Attr("name_hint"))},
211 kwargs_keys, kwargs_values),
212 (*frame)->stmts);
213}
214
215TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
216 .set_dispatch<tir::BlockRealize>(
217 "", [](tir::BlockRealize realize, ObjectPath p, IRDocsifier d) -> Doc {
218 Doc doc = PrintBlock(d, realize->block, p->Attr("block"), realize, p);
219 // since we do not have d->AsDoc for realize->block,
220 // we should add possible doc decoration manually.
221 AddDocDecoration<ScopeDoc>(doc, realize->block, p->Attr("block"), d->cfg);
222 return doc;
223 });
224
225TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
226 .set_dispatch<tir::Block>("", [](tir::Block block, ObjectPath p, IRDocsifier d) -> Doc {
227 return PrintBlock(d, block, p, NullOpt, NullOpt);
228 });
229
230TVM_SCRIPT_REPR(tir::BlockNode, ReprPrintTIR);
231TVM_SCRIPT_REPR(tir::BlockRealizeNode, ReprPrintTIR);
232
233} // namespace printer
234} // namespace script
235} // namespace tvm
236