1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "./utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace script { |
23 | namespace printer { |
24 | |
25 | Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // |
26 | Optional<tir::BlockRealize> opt_realize, Optional<ObjectPath> opt_realize_p) { |
27 | With<TIRFrame> frame(d, block); |
28 | ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined()); |
29 | const tir::BlockRealizeNode* realize = |
30 | opt_realize.defined() ? opt_realize.value().get() : nullptr; |
31 | const ObjectPathNode* realize_p = opt_realize_p.defined() ? opt_realize_p.get() : nullptr; |
32 | // Step 1. Handle block var and block bindings |
33 | // Step 1.1. Obtain all loop var defined along path |
34 | std::unordered_map<const tir::VarNode*, tir::For> loop_vars; |
35 | for (Frame f : d->frames) { |
36 | if (const auto* tir_f = f.as<TIRFrameNode>()) { |
37 | if (const auto* for_loop = tir_f->tir.as<tir::ForNode>()) { |
38 | for (const tir::ForNode* l = for_loop; l != nullptr; l = l->body.as<tir::ForNode>()) { |
39 | loop_vars.insert(std::make_pair(l->loop_var.get(), GetRef<tir::For>(l))); |
40 | } |
41 | } |
42 | } |
43 | } |
44 | |
45 | std::vector<int> remap_vars_indices; |
46 | auto add_remapped_iter_var = [&](int i) -> bool { |
47 | if (realize && d->cfg->syntax_sugar) { |
48 | tir::ExprDeepEqual expr_equal; |
49 | tir::IterVar iter_var = block->iter_vars[i]; |
50 | PrimExpr value = realize->iter_values[i]; |
51 | if (iter_var->iter_type == tir::IterVarType::kDataPar || |
52 | iter_var->iter_type == tir::IterVarType::kCommReduce) { |
53 | if (const auto* var = value.as<tir::VarNode>()) { |
54 | if (loop_vars.count(var)) { |
55 | tir::For for_loop = loop_vars.at(var); |
56 | if (expr_equal(for_loop->min, iter_var->dom->min) && |
57 | expr_equal(for_loop->extent, iter_var->dom->extent)) { |
58 | remap_vars_indices.push_back(i); |
59 | return true; |
60 | } |
61 | } |
62 | } |
63 | } |
64 | } |
65 | return false; |
66 | }; |
67 | |
68 | auto print_single_iter_var = [&](int i) { |
69 | tir::IterVar iter_var = block->iter_vars[i]; |
70 | ObjectPath iter_var_p = block_p->Attr("iter_var" )->ArrayIndex(i); |
71 | ExprDoc rhs = TIR(d, "axis" ); |
72 | if (iter_var->iter_type == tir::IterVarType::kDataPar) { |
73 | rhs = rhs->Attr("spatial" ); |
74 | } else if (iter_var->iter_type == tir::IterVarType::kCommReduce) { |
75 | rhs = rhs->Attr("reduce" ); |
76 | } else if (iter_var->iter_type == tir::IterVarType::kOrdered) { |
77 | rhs = rhs->Attr("scan" ); |
78 | } else if (iter_var->iter_type == tir::IterVarType::kOpaque) { |
79 | rhs = rhs->Attr("opaque" ); |
80 | } else { |
81 | LOG(FATAL) << "ValueError: Unknown IterVarType in block signature: " |
82 | << tir::IterVarType2String(iter_var->iter_type); |
83 | } |
84 | ExprDoc dom{nullptr}; |
85 | if (tir::is_zero(iter_var->dom->min)) { |
86 | ExprDoc extent = d->AsDoc<ExprDoc>(iter_var->dom->extent, // |
87 | iter_var_p->Attr("dom" )->Attr("extent" )); |
88 | dom = extent; |
89 | } else { |
90 | ExprDoc min = d->AsDoc<ExprDoc>(iter_var->dom->min, iter_var_p->Attr("dom" )->Attr("min" )); |
91 | ExprDoc max = d->AsDoc<ExprDoc>(iter_var->dom->min + iter_var->dom->extent, |
92 | iter_var_p->Attr("dom" )->Attr("extent" )); |
93 | dom = TupleDoc({min, max}); |
94 | } |
95 | if (realize) { |
96 | ExprDoc binding = d->AsDoc<ExprDoc>(realize->iter_values[i], // |
97 | realize_p->Attr("iter_values" )->ArrayIndex(i)); |
98 | rhs = rhs->Call({dom, binding}); |
99 | } else { |
100 | rhs = rhs->Call({dom}); |
101 | } |
102 | (*frame)->stmts.push_back(AssignDoc(DefineVar(iter_var->var, *frame, d), rhs, NullOpt)); |
103 | }; |
104 | |
105 | auto print_remapped_iter_var = [&]() { |
106 | if (remap_vars_indices.size()) { |
107 | int m = remap_vars_indices.size(); |
108 | if (!m) { |
109 | return; |
110 | } |
111 | if (m == 1) { |
112 | print_single_iter_var(remap_vars_indices[0]); |
113 | remap_vars_indices.clear(); |
114 | return; |
115 | } |
116 | Array<ExprDoc> lhs; |
117 | Array<ExprDoc> loop_var_doc; |
118 | lhs.reserve(m); |
119 | loop_var_doc.reserve(m); |
120 | std::string binding_type = "" ; |
121 | Array<ObjectPath> binding_paths; |
122 | for (int i : remap_vars_indices) { |
123 | tir::IterVar iter_var = block->iter_vars[i]; |
124 | ObjectPath iter_var_p = block_p->Attr("iter_vars" )->ArrayIndex(i); |
125 | lhs.push_back(DefineVar(iter_var->var, *frame, d)); |
126 | loop_var_doc.push_back(d->AsDoc<ExprDoc>(realize->iter_values[i], |
127 | realize_p->Attr("iter_values" )->ArrayIndex(i))); |
128 | binding_paths.push_back(iter_var_p->Attr("iter_type" )); |
129 | binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R" ; |
130 | } |
131 | ExprDoc rhs = TIR(d, "axis" )->Attr("remap" ); |
132 | ExprDoc binding_str = LiteralDoc::Str(binding_type, NullOpt); |
133 | binding_str->source_paths = std::move(binding_paths); |
134 | rhs = rhs->Call({binding_str, ListDoc(loop_var_doc)}); |
135 | (*frame)->stmts.push_back(AssignDoc(TupleDoc(lhs), rhs, NullOpt)); |
136 | remap_vars_indices.clear(); |
137 | } |
138 | }; |
139 | |
140 | // Step 1.2. Construct all block var bindings |
141 | int n_vars = block->iter_vars.size(); |
142 | for (int i = 0; i < n_vars; ++i) { |
143 | if (!add_remapped_iter_var(i)) { |
144 | print_remapped_iter_var(); |
145 | print_single_iter_var(i); |
146 | } |
147 | } |
148 | print_remapped_iter_var(); |
149 | |
150 | // Step 2. Handle block predicate |
151 | if (realize) { |
152 | ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool()); |
153 | if (!tir::is_one(realize->predicate)) { |
154 | (*frame)->stmts.push_back(ExprStmtDoc( |
155 | TIR(d, "where" ) |
156 | ->Call({d->AsDoc<ExprDoc>(realize->predicate, realize_p->Attr("predicate" ))}))); |
157 | } |
158 | } |
159 | // Step 3. Handle block read/write regions |
160 | { |
161 | Array<ExprDoc> reads; |
162 | for (int i = 0, n = block->reads.size(); i < n; ++i) { |
163 | reads.push_back(d->AsDoc<ExprDoc>(block->reads[i], block_p->Attr("reads" )->ArrayIndex(i))); |
164 | } |
165 | (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "reads" )->Call(reads))); |
166 | Array<ExprDoc> writes; |
167 | for (int i = 0, n = block->writes.size(); i < n; ++i) { |
168 | writes.push_back(d->AsDoc<ExprDoc>(block->writes[i], block_p->Attr("writes" )->ArrayIndex(i))); |
169 | } |
170 | (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "writes" )->Call(writes))); |
171 | } |
172 | // Step 4. Handle block attributes |
173 | if (!block->annotations.empty()) { |
174 | (*frame)->stmts.push_back(ExprStmtDoc( |
175 | TIR(d, "block_attr" ) |
176 | ->Call({d->AsDoc<ExprDoc>(block->annotations, block_p->Attr("annotations" ))}))); |
177 | } |
178 | // Step 5. Handle `alloc_buffer` |
179 | for (int i = 0, n = block->alloc_buffers.size(); i < n; ++i) { |
180 | tir::Buffer buffer = block->alloc_buffers[i]; |
181 | ObjectPath buffer_p = block_p->Attr("alloc_buffers" )->ArrayIndex(i); |
182 | IdDoc lhs = DefineBuffer(buffer, *frame, d); |
183 | ExprDoc rhs = BufferDecl(buffer, "alloc_buffer" , {}, buffer_p, *frame, d); |
184 | (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); |
185 | } |
186 | // Step 6. Handle `match_buffer` |
187 | for (int i = 0, n = block->match_buffers.size(); i < n; ++i) { |
188 | tir::MatchBufferRegion buffer_region = block->match_buffers[i]; |
189 | ObjectPath buffer_region_p = block_p->Attr("match_buffers" )->ArrayIndex(i); |
190 | StmtDoc doc = d->AsDoc<StmtDoc>(buffer_region, buffer_region_p); |
191 | (*frame)->stmts.push_back(doc); |
192 | } |
193 | // Step 7. Handle init block |
194 | if (block->init.defined()) { |
195 | tir::Stmt init = block->init.value(); |
196 | With<TIRFrame> init_frame(d, init); |
197 | AsDocBody(init, block_p->Attr("init" ), init_frame->get(), d); |
198 | (*frame)->stmts.push_back(ScopeDoc(NullOpt, TIR(d, "init" )->Call({}), (*init_frame)->stmts)); |
199 | } |
200 | // Step 8. Handle block body |
201 | AsDocBody(block->body, block_p->Attr("body" ), frame->get(), d); |
202 | Array<String> kwargs_keys; |
203 | Array<ExprDoc> kwargs_values; |
204 | if (!realize) { |
205 | kwargs_keys.push_back("no_realize" ); |
206 | kwargs_values.push_back(LiteralDoc::Boolean(true, NullOpt)); |
207 | } |
208 | return ScopeDoc(NullOpt, |
209 | TIR(d, "block" ) // |
210 | ->Call({LiteralDoc::Str(block->name_hint, block_p->Attr("name_hint" ))}, |
211 | kwargs_keys, kwargs_values), |
212 | (*frame)->stmts); |
213 | } |
214 | |
215 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
216 | .set_dispatch<tir::BlockRealize>( |
217 | "" , [](tir::BlockRealize realize, ObjectPath p, IRDocsifier d) -> Doc { |
218 | Doc doc = PrintBlock(d, realize->block, p->Attr("block" ), realize, p); |
219 | // since we do not have d->AsDoc for realize->block, |
220 | // we should add possible doc decoration manually. |
221 | AddDocDecoration<ScopeDoc>(doc, realize->block, p->Attr("block" ), d->cfg); |
222 | return doc; |
223 | }); |
224 | |
225 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
226 | .set_dispatch<tir::Block>("" , [](tir::Block block, ObjectPath p, IRDocsifier d) -> Doc { |
227 | return PrintBlock(d, block, p, NullOpt, NullOpt); |
228 | }); |
229 | |
// Register ReprPrintTIR as the repr printer for Block and BlockRealize nodes
// (presumably so their repr is produced by the TIR script printer above — confirm against the
// TVM_SCRIPT_REPR definition).
TVM_SCRIPT_REPR(tir::BlockNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::BlockRealizeNode, ReprPrintTIR);
232 | |
233 | } // namespace printer |
234 | } // namespace script |
235 | } // namespace tvm |
236 | |