1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/runtime/device_api.h> // For `kAllocAlignment` |
20 | |
21 | #include "./utils.h" |
22 | |
23 | namespace tvm { |
24 | namespace script { |
25 | namespace printer { |
26 | |
27 | Map<String, ExprDoc> BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, |
28 | const IRDocsifier& d) { |
29 | Map<String, ExprDoc> kwargs; |
30 | auto implicit_var_def = [&](const PrimExpr& e, const ObjectPath& p, const String& key) { |
31 | if (Optional<ExprDoc> doc = d->GetVarDoc(e)) { |
32 | kwargs.Set(key, doc.value()); |
33 | return false; |
34 | } |
35 | if (e->IsInstance<tir::VarNode>()) { |
36 | d->Define(e, frame, [=]() { return d->AsDoc<IdDoc>(buffer, p)->Attr(key); }); |
37 | return true; |
38 | } |
39 | kwargs.Set(key, d->AsDoc<ExprDoc>(e, p)); |
40 | return false; |
41 | }; |
42 | auto array_out_line_var_def = [&](const Array<PrimExpr>& array, const ObjectPath& p, |
43 | const String& key) { |
44 | int n = array.size(); |
45 | Array<ExprDoc> results; |
46 | results.reserve(n); |
47 | for (int i = 0; i < n; ++i) { |
48 | PrimExpr s = array[i]; |
49 | ObjectPath s_path = p->ArrayIndex(i); |
50 | // Add out-of-line definition for a new Var in shape |
51 | results.push_back(d->AsDoc<ExprDoc>(s, s_path)); |
52 | } |
53 | kwargs.Set(key, TupleDoc(results)); |
54 | }; |
55 | // Step 1. Handle `buffer.shape` |
56 | array_out_line_var_def(buffer->shape, p->Attr("shape" ), "shape" ); |
57 | // Step 2. Handle `buffer.dtype` |
58 | if (buffer->dtype != d->cfg->buffer_dtype) { |
59 | kwargs.Set("dtype" , LiteralDoc::DataType(buffer->dtype, p->Attr("dtype" ))); |
60 | } |
61 | // Step 3. Handle `buffer.data` |
62 | implicit_var_def(buffer->data, p->Attr("data" ), "data" ); |
63 | // Step 4. Handle `buffer.strides` |
64 | if (!buffer->strides.empty()) { |
65 | array_out_line_var_def(buffer->strides, p->Attr("strides" ), "strides" ); |
66 | } |
67 | // Step 5. Handle `buffer.elem_offset` |
68 | bool needs_print_factor = false; |
69 | if (const auto* int_imm = buffer->elem_offset.as<IntImmNode>()) { |
70 | if (int_imm->value != 0) { |
71 | kwargs.Set("elem_offset" , d->AsDoc<ExprDoc>(buffer->elem_offset, p->Attr("elem_offset" ))); |
72 | } |
73 | } else { |
74 | needs_print_factor = |
75 | implicit_var_def(buffer->elem_offset, p->Attr("elem_offset" ), "elem_offset" ); |
76 | } |
77 | // Step 6. Handle `buffer.scope` |
78 | { |
79 | String scope = buffer.scope(); |
80 | if (scope != "global" ) { |
81 | kwargs.Set( |
82 | "scope" , |
83 | LiteralDoc::Str(scope, p->Attr("data" )->Attr("type_annotation" )->Attr("storage_scope" ))); |
84 | } |
85 | } |
86 | // Step 7. Handle `buffer.data_alignment` |
87 | if (buffer->data_alignment != runtime::kAllocAlignment) { |
88 | kwargs.Set("align" , LiteralDoc::Int(buffer->data_alignment, p->Attr("data_alignment" ))); |
89 | } |
90 | // Step 8. Handle `buffer.offset_factor` |
91 | if (needs_print_factor || buffer->offset_factor != 1) { |
92 | kwargs.Set("offset_factor" , LiteralDoc::Int(buffer->offset_factor, p->Attr("offset_factor" ))); |
93 | } |
94 | // Step 9. Handle `buffer.buffer_type` |
95 | if (buffer->buffer_type != tir::BufferType::kDefault) { |
96 | kwargs.Set("type" , LiteralDoc::Str("auto" , p->Attr("buffer_type" ))); |
97 | } |
98 | // Step 10. Handle `buffer.axis_separator` |
99 | if (!buffer->axis_separators.empty()) { |
100 | kwargs.Set("axis_separators" , |
101 | d->AsDoc<ExprDoc>(buffer->axis_separators, p->Attr("axis_separators" ))); |
102 | } |
103 | return kwargs; |
104 | } |
105 | |
106 | ExprDoc BufferCall(const ExprDoc& prefix, const Map<String, ExprDoc>& attrs, Array<ExprDoc> args) { |
107 | Array<String> kwargs_keys; |
108 | Array<ExprDoc> kwargs_values; |
109 | for (String s : {"shape" , "dtype" }) { |
110 | if (Optional<ExprDoc> doc = attrs.Get(s)) { |
111 | args.push_back(doc.value()); |
112 | } |
113 | } |
114 | for (String s : {"data" , "strides" , "elem_offset" , "scope" , "align" , "offset_factor" , "type" , |
115 | "axis_separators" }) { |
116 | if (Optional<ExprDoc> doc = attrs.Get(s)) { |
117 | kwargs_keys.push_back(s); |
118 | kwargs_values.push_back(doc.value()); |
119 | } |
120 | } |
121 | return prefix->Call(args, kwargs_keys, kwargs_values); |
122 | } |
123 | |
124 | ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<ExprDoc>& args, |
125 | const ObjectPath& p, const Frame& frame, const IRDocsifier& d) { |
126 | return BufferCall(/*prefix=*/TIR(d, method), |
127 | /*attrs=*/BufferAttrs(buffer, p, frame, d), |
128 | /*args=*/args); |
129 | } |
130 | |
131 | ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, |
132 | const IRDocsifier& d) { |
133 | Map<String, ExprDoc> attrs = BufferAttrs(buffer, p, frame, d); |
134 | ExprDoc shape = attrs.Get("shape" ).value(); |
135 | ExprDoc dtype = |
136 | attrs.Get("dtype" ).value_or(LiteralDoc::DataType(buffer->dtype, p->Attr("dtype" ))); |
137 | return TIR(d, "Buffer" )->Call({shape, dtype}, {}, {}); |
138 | } |
139 | |
140 | Array<Doc> BufferIndices(const Array<PrimExpr>& indices, const ObjectPath& p, |
141 | const IRDocsifier& d) { |
142 | int n = indices.size(); |
143 | Array<Doc> indices_doc; |
144 | indices_doc.reserve(n); |
145 | for (int i = 0; i < n; ++i) { |
146 | if (const auto* ramp = indices[i].as<tir::RampNode>()) { |
147 | if (const auto* stride = ramp->stride.as<IntImmNode>()) { |
148 | ObjectPath ramp_p = p->Attr("indices" )->ArrayIndex(i); |
149 | ObjectPath stride_p = ramp_p->Attr("stride" ); |
150 | ExprDoc start = d->AsDoc<ExprDoc>(ramp->base, // |
151 | ramp_p->Attr("base" )); |
152 | ExprDoc stop = d->AsDoc<ExprDoc>(ramp->base + ramp->lanes * ramp->stride, // |
153 | ramp_p->Attr("lanes" )); |
154 | Optional<ExprDoc> step = NullOpt; |
155 | if (stride->value != 1) { |
156 | step = d->AsDoc<ExprDoc>(ramp->stride, ramp_p->Attr("stride" )); |
157 | } |
158 | indices_doc.push_back(SliceDoc(start, stop, step)); |
159 | continue; |
160 | } |
161 | } |
162 | indices_doc.push_back(d->AsDoc<ExprDoc>(indices[i], p->Attr("indices" )->ArrayIndex(i))); |
163 | } |
164 | return indices_doc; |
165 | } |
166 | |
167 | Array<Doc> BufferSlices(const Array<Range>& region, const ObjectPath& p, const IRDocsifier& d) { |
168 | int n = region.size(); |
169 | Array<Doc> indices; |
170 | indices.reserve(n); |
171 | for (int i = 0; i < n; ++i) { |
172 | Range range = region[i]; |
173 | ObjectPath range_p = p->ArrayIndex(i); |
174 | ExprDoc min = d->AsDoc<ExprDoc>(range->min, range_p->Attr("min" )); |
175 | if (tir::is_one(range->extent)) { |
176 | indices.push_back(min); |
177 | } else { |
178 | ExprDoc max = d->AsDoc<ExprDoc>(range->min + range->extent, range_p->Attr("extent" )); |
179 | indices.push_back(SliceDoc(min, max, NullOpt)); |
180 | } |
181 | } |
182 | return indices; |
183 | } |
184 | |
185 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
186 | .set_dispatch<tir::BufferRegion>( |
187 | "" , [](tir::BufferRegion buffer_region, ObjectPath p, IRDocsifier d) -> Doc { |
188 | ExprDoc prefix = d->AsDoc<ExprDoc>(buffer_region->buffer, p->Attr("buffer" )); |
189 | return prefix[BufferSlices(buffer_region->region, p->Attr("region" ), d)]; |
190 | }); |
191 | |
192 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
193 | .set_dispatch<tir::BufferStore>( // |
194 | "" , [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc { |
195 | ExprDoc buffer = d->AsDoc<ExprDoc>(store->buffer, p->Attr("buffer" )); |
196 | return AssignDoc(/*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices" ), d)], |
197 | /*rhs=*/d->AsDoc<ExprDoc>(store->value, p->Attr("value" )), NullOpt); |
198 | }); |
199 | |
200 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
201 | .set_dispatch<tir::BufferLoad>( // |
202 | "" , [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc { |
203 | ExprDoc buffer = d->AsDoc<ExprDoc>(load->buffer, p->Attr("buffer" )); |
204 | return buffer[BufferIndices(load->indices, p->Attr("indices" ), d)]; |
205 | }); |
206 | |
207 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // |
208 | .set_dispatch<tir::Buffer>("" , [](tir::Buffer buffer, ObjectPath p, IRDocsifier d) -> Doc { |
209 | if (!d->IsVarDefined(buffer)) { |
210 | if (Optional<Frame> opt_f = FindLowestVarDef(buffer, d)) { |
211 | ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d); |
212 | ExprDoc rhs = BufferDecl(buffer, "Buffer" , {}, p, opt_f.value(), d); |
213 | opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); |
214 | } |
215 | } |
216 | if (Optional<ExprDoc> doc = d->GetVarDoc(buffer)) { |
217 | return doc.value(); |
218 | } |
219 | LOG(FATAL) << "IndexError: Buffer is not defined in the environment: " << buffer; |
220 | }); |
221 | |
222 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
223 | .set_dispatch<tir::MatchBufferRegion>( |
224 | "" , [](tir::MatchBufferRegion stmt, ObjectPath p, IRDocsifier d) -> Doc { |
225 | Frame frame = d->frames.back(); |
226 | ExprDoc lhs = DefineBuffer(stmt->buffer, frame, d); |
227 | ExprDoc src_buffer = d->AsDoc<ExprDoc>(stmt->source, p->Attr("source" )); |
228 | ExprDoc rhs = BufferDecl(stmt->buffer, "match_buffer" , {src_buffer}, p->Attr("buffer" ), |
229 | d->frames.back(), d); |
230 | return AssignDoc(lhs, rhs, NullOpt); |
231 | }); |
232 | |
233 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
234 | .set_dispatch<tir::ProducerLoad>( // |
235 | "" , [](tir::ProducerLoad load, ObjectPath p, IRDocsifier d) -> Doc { |
236 | ExprDoc prefix = IdDoc(load->producer->GetNameHint()); |
237 | return prefix[BufferIndices(load->indices, p->Attr("indices" ), d)]; |
238 | }); |
239 | |
240 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
241 | .set_dispatch<tir::ProducerStore>( // |
242 | "" , [](tir::ProducerStore store, ObjectPath p, IRDocsifier d) -> Doc { |
243 | ExprDoc prefix = IdDoc(store->producer->GetNameHint()); |
244 | prefix = prefix[BufferIndices(store->indices, p->Attr("indices" ), d)]; |
245 | return AssignDoc(prefix, d->AsDoc<ExprDoc>(store->value, p->Attr("value" )), NullOpt); |
246 | }); |
247 | |
248 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
249 | .set_dispatch<tir::ProducerRealize>( // |
250 | "" , [](tir::ProducerRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { |
251 | ExprDoc prefix = IdDoc(stmt->producer->GetNameHint()); |
252 | prefix = prefix[BufferSlices(stmt->bounds, p->Attr("bounds" ), d)]; |
253 | prefix = TIR(d, "ProducerRealize" ) |
254 | ->Call({prefix, d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition" ))}); |
255 | With<TIRFrame> f(d, stmt); |
256 | AsDocBody(stmt->body, p->Attr("body" ), f->get(), d); |
257 | return ScopeDoc(NullOpt, prefix, (*f)->stmts); |
258 | }); |
259 | |
// Register ReprPrintTIR as the script-printer-based repr for the buffer- and producer-related
// TIR node types whose printing is defined in this file.
TVM_SCRIPT_REPR(tir::BufferRegionNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::BufferLoadNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::BufferStoreNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::BufferNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::MatchBufferRegionNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::ProducerLoadNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::ProducerStoreNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::ProducerRealizeNode, ReprPrintTIR);
268 | |
269 | } // namespace printer |
270 | } // namespace script |
271 | } // namespace tvm |
272 | |