1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/runtime/device_api.h> // For `kAllocAlignment`
20
21#include "./utils.h"
22
23namespace tvm {
24namespace script {
25namespace printer {
26
27Map<String, ExprDoc> BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
28 const IRDocsifier& d) {
29 Map<String, ExprDoc> kwargs;
30 auto implicit_var_def = [&](const PrimExpr& e, const ObjectPath& p, const String& key) {
31 if (Optional<ExprDoc> doc = d->GetVarDoc(e)) {
32 kwargs.Set(key, doc.value());
33 return false;
34 }
35 if (e->IsInstance<tir::VarNode>()) {
36 d->Define(e, frame, [=]() { return d->AsDoc<IdDoc>(buffer, p)->Attr(key); });
37 return true;
38 }
39 kwargs.Set(key, d->AsDoc<ExprDoc>(e, p));
40 return false;
41 };
42 auto array_out_line_var_def = [&](const Array<PrimExpr>& array, const ObjectPath& p,
43 const String& key) {
44 int n = array.size();
45 Array<ExprDoc> results;
46 results.reserve(n);
47 for (int i = 0; i < n; ++i) {
48 PrimExpr s = array[i];
49 ObjectPath s_path = p->ArrayIndex(i);
50 // Add out-of-line definition for a new Var in shape
51 results.push_back(d->AsDoc<ExprDoc>(s, s_path));
52 }
53 kwargs.Set(key, TupleDoc(results));
54 };
55 // Step 1. Handle `buffer.shape`
56 array_out_line_var_def(buffer->shape, p->Attr("shape"), "shape");
57 // Step 2. Handle `buffer.dtype`
58 if (buffer->dtype != d->cfg->buffer_dtype) {
59 kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, p->Attr("dtype")));
60 }
61 // Step 3. Handle `buffer.data`
62 implicit_var_def(buffer->data, p->Attr("data"), "data");
63 // Step 4. Handle `buffer.strides`
64 if (!buffer->strides.empty()) {
65 array_out_line_var_def(buffer->strides, p->Attr("strides"), "strides");
66 }
67 // Step 5. Handle `buffer.elem_offset`
68 bool needs_print_factor = false;
69 if (const auto* int_imm = buffer->elem_offset.as<IntImmNode>()) {
70 if (int_imm->value != 0) {
71 kwargs.Set("elem_offset", d->AsDoc<ExprDoc>(buffer->elem_offset, p->Attr("elem_offset")));
72 }
73 } else {
74 needs_print_factor =
75 implicit_var_def(buffer->elem_offset, p->Attr("elem_offset"), "elem_offset");
76 }
77 // Step 6. Handle `buffer.scope`
78 {
79 String scope = buffer.scope();
80 if (scope != "global") {
81 kwargs.Set(
82 "scope",
83 LiteralDoc::Str(scope, p->Attr("data")->Attr("type_annotation")->Attr("storage_scope")));
84 }
85 }
86 // Step 7. Handle `buffer.data_alignment`
87 if (buffer->data_alignment != runtime::kAllocAlignment) {
88 kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment, p->Attr("data_alignment")));
89 }
90 // Step 8. Handle `buffer.offset_factor`
91 if (needs_print_factor || buffer->offset_factor != 1) {
92 kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor, p->Attr("offset_factor")));
93 }
94 // Step 9. Handle `buffer.buffer_type`
95 if (buffer->buffer_type != tir::BufferType::kDefault) {
96 kwargs.Set("type", LiteralDoc::Str("auto", p->Attr("buffer_type")));
97 }
98 // Step 10. Handle `buffer.axis_separator`
99 if (!buffer->axis_separators.empty()) {
100 kwargs.Set("axis_separators",
101 d->AsDoc<ExprDoc>(buffer->axis_separators, p->Attr("axis_separators")));
102 }
103 return kwargs;
104}
105
106ExprDoc BufferCall(const ExprDoc& prefix, const Map<String, ExprDoc>& attrs, Array<ExprDoc> args) {
107 Array<String> kwargs_keys;
108 Array<ExprDoc> kwargs_values;
109 for (String s : {"shape", "dtype"}) {
110 if (Optional<ExprDoc> doc = attrs.Get(s)) {
111 args.push_back(doc.value());
112 }
113 }
114 for (String s : {"data", "strides", "elem_offset", "scope", "align", "offset_factor", "type",
115 "axis_separators"}) {
116 if (Optional<ExprDoc> doc = attrs.Get(s)) {
117 kwargs_keys.push_back(s);
118 kwargs_values.push_back(doc.value());
119 }
120 }
121 return prefix->Call(args, kwargs_keys, kwargs_values);
122}
123
124ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<ExprDoc>& args,
125 const ObjectPath& p, const Frame& frame, const IRDocsifier& d) {
126 return BufferCall(/*prefix=*/TIR(d, method),
127 /*attrs=*/BufferAttrs(buffer, p, frame, d),
128 /*args=*/args);
129}
130
131ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
132 const IRDocsifier& d) {
133 Map<String, ExprDoc> attrs = BufferAttrs(buffer, p, frame, d);
134 ExprDoc shape = attrs.Get("shape").value();
135 ExprDoc dtype =
136 attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype, p->Attr("dtype")));
137 return TIR(d, "Buffer")->Call({shape, dtype}, {}, {});
138}
139
140Array<Doc> BufferIndices(const Array<PrimExpr>& indices, const ObjectPath& p,
141 const IRDocsifier& d) {
142 int n = indices.size();
143 Array<Doc> indices_doc;
144 indices_doc.reserve(n);
145 for (int i = 0; i < n; ++i) {
146 if (const auto* ramp = indices[i].as<tir::RampNode>()) {
147 if (const auto* stride = ramp->stride.as<IntImmNode>()) {
148 ObjectPath ramp_p = p->Attr("indices")->ArrayIndex(i);
149 ObjectPath stride_p = ramp_p->Attr("stride");
150 ExprDoc start = d->AsDoc<ExprDoc>(ramp->base, //
151 ramp_p->Attr("base"));
152 ExprDoc stop = d->AsDoc<ExprDoc>(ramp->base + ramp->lanes * ramp->stride, //
153 ramp_p->Attr("lanes"));
154 Optional<ExprDoc> step = NullOpt;
155 if (stride->value != 1) {
156 step = d->AsDoc<ExprDoc>(ramp->stride, ramp_p->Attr("stride"));
157 }
158 indices_doc.push_back(SliceDoc(start, stop, step));
159 continue;
160 }
161 }
162 indices_doc.push_back(d->AsDoc<ExprDoc>(indices[i], p->Attr("indices")->ArrayIndex(i)));
163 }
164 return indices_doc;
165}
166
167Array<Doc> BufferSlices(const Array<Range>& region, const ObjectPath& p, const IRDocsifier& d) {
168 int n = region.size();
169 Array<Doc> indices;
170 indices.reserve(n);
171 for (int i = 0; i < n; ++i) {
172 Range range = region[i];
173 ObjectPath range_p = p->ArrayIndex(i);
174 ExprDoc min = d->AsDoc<ExprDoc>(range->min, range_p->Attr("min"));
175 if (tir::is_one(range->extent)) {
176 indices.push_back(min);
177 } else {
178 ExprDoc max = d->AsDoc<ExprDoc>(range->min + range->extent, range_p->Attr("extent"));
179 indices.push_back(SliceDoc(min, max, NullOpt));
180 }
181 }
182 return indices;
183}
184
185TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
186 .set_dispatch<tir::BufferRegion>(
187 "", [](tir::BufferRegion buffer_region, ObjectPath p, IRDocsifier d) -> Doc {
188 ExprDoc prefix = d->AsDoc<ExprDoc>(buffer_region->buffer, p->Attr("buffer"));
189 return prefix[BufferSlices(buffer_region->region, p->Attr("region"), d)];
190 });
191
192TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
193 .set_dispatch<tir::BufferStore>( //
194 "", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc {
195 ExprDoc buffer = d->AsDoc<ExprDoc>(store->buffer, p->Attr("buffer"));
196 return AssignDoc(/*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)],
197 /*rhs=*/d->AsDoc<ExprDoc>(store->value, p->Attr("value")), NullOpt);
198 });
199
200TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
201 .set_dispatch<tir::BufferLoad>( //
202 "", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc {
203 ExprDoc buffer = d->AsDoc<ExprDoc>(load->buffer, p->Attr("buffer"));
204 return buffer[BufferIndices(load->indices, p->Attr("indices"), d)];
205 });
206
207TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) //
208 .set_dispatch<tir::Buffer>("", [](tir::Buffer buffer, ObjectPath p, IRDocsifier d) -> Doc {
209 if (!d->IsVarDefined(buffer)) {
210 if (Optional<Frame> opt_f = FindLowestVarDef(buffer, d)) {
211 ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d);
212 ExprDoc rhs = BufferDecl(buffer, "Buffer", {}, p, opt_f.value(), d);
213 opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
214 }
215 }
216 if (Optional<ExprDoc> doc = d->GetVarDoc(buffer)) {
217 return doc.value();
218 }
219 LOG(FATAL) << "IndexError: Buffer is not defined in the environment: " << buffer;
220 });
221
222TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
223 .set_dispatch<tir::MatchBufferRegion>(
224 "", [](tir::MatchBufferRegion stmt, ObjectPath p, IRDocsifier d) -> Doc {
225 Frame frame = d->frames.back();
226 ExprDoc lhs = DefineBuffer(stmt->buffer, frame, d);
227 ExprDoc src_buffer = d->AsDoc<ExprDoc>(stmt->source, p->Attr("source"));
228 ExprDoc rhs = BufferDecl(stmt->buffer, "match_buffer", {src_buffer}, p->Attr("buffer"),
229 d->frames.back(), d);
230 return AssignDoc(lhs, rhs, NullOpt);
231 });
232
233TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
234 .set_dispatch<tir::ProducerLoad>( //
235 "", [](tir::ProducerLoad load, ObjectPath p, IRDocsifier d) -> Doc {
236 ExprDoc prefix = IdDoc(load->producer->GetNameHint());
237 return prefix[BufferIndices(load->indices, p->Attr("indices"), d)];
238 });
239
240TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
241 .set_dispatch<tir::ProducerStore>( //
242 "", [](tir::ProducerStore store, ObjectPath p, IRDocsifier d) -> Doc {
243 ExprDoc prefix = IdDoc(store->producer->GetNameHint());
244 prefix = prefix[BufferIndices(store->indices, p->Attr("indices"), d)];
245 return AssignDoc(prefix, d->AsDoc<ExprDoc>(store->value, p->Attr("value")), NullOpt);
246 });
247
248TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
249 .set_dispatch<tir::ProducerRealize>( //
250 "", [](tir::ProducerRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
251 ExprDoc prefix = IdDoc(stmt->producer->GetNameHint());
252 prefix = prefix[BufferSlices(stmt->bounds, p->Attr("bounds"), d)];
253 prefix = TIR(d, "ProducerRealize")
254 ->Call({prefix, d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"))});
255 With<TIRFrame> f(d, stmt);
256 AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
257 return ScopeDoc(NullOpt, prefix, (*f)->stmts);
258 });
259
260TVM_SCRIPT_REPR(tir::BufferRegionNode, ReprPrintTIR);
261TVM_SCRIPT_REPR(tir::BufferLoadNode, ReprPrintTIR);
262TVM_SCRIPT_REPR(tir::BufferStoreNode, ReprPrintTIR);
263TVM_SCRIPT_REPR(tir::BufferNode, ReprPrintTIR);
264TVM_SCRIPT_REPR(tir::MatchBufferRegionNode, ReprPrintTIR);
265TVM_SCRIPT_REPR(tir::ProducerLoadNode, ReprPrintTIR);
266TVM_SCRIPT_REPR(tir::ProducerStoreNode, ReprPrintTIR);
267TVM_SCRIPT_REPR(tir::ProducerRealizeNode, ReprPrintTIR);
268
269} // namespace printer
270} // namespace script
271} // namespace tvm
272