1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_SCRIPT_PRINTER_TIR_UTILS_H_
20#define TVM_SCRIPT_PRINTER_TIR_UTILS_H_
21
22#include <tvm/script/printer/ir_docsifier.h>
23#include <tvm/tir/analysis.h>
24#include <tvm/tir/buffer.h>
25#include <tvm/tir/expr.h>
26#include <tvm/tir/function.h>
27#include <tvm/tir/index_map.h>
28#include <tvm/tir/op.h>
29#include <tvm/tir/stmt.h>
30#include <tvm/tir/stmt_functor.h>
31
32#include <string>
33#include <unordered_map>
34#include <utility>
35#include <vector>
36
37#include "../utils.h"
38
39namespace tvm {
40namespace script {
41namespace printer {
42
43/*! \brief A printer frame for TIR fragment */
44class TIRFrameNode : public FrameNode {
45 public:
46 /*! \brief The TIR fragment the frame corresponds to */
47 ObjectRef tir;
48 /*! \brief Whether or not the frame allows concise scoping */
49 bool allow_concise_scoping{false};
50
51 void VisitAttrs(AttrVisitor* v) {
52 FrameNode::VisitAttrs(v);
53 v->Visit("tir", &tir);
54 v->Visit("allow_concise_scoping", &allow_concise_scoping);
55 }
56
57 static constexpr const char* _type_key = "script.printer.TIRFrame";
58 TVM_DECLARE_FINAL_OBJECT_INFO(TIRFrameNode, FrameNode);
59};
60
61/*! \brief Managed reference to TIRFrameNode */
62class TIRFrame : public Frame {
63 public:
64 /*! \brief Constructor */
65 explicit TIRFrame(const IRDocsifier& d, const ObjectRef& tir) {
66 ObjectPtr<TIRFrameNode> n = make_object<TIRFrameNode>();
67 n->stmts.clear();
68 n->d = d.get();
69 n->tir = tir;
70 data_ = std::move(n);
71 }
72
73 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode);
74};
75
76/*!
77 * \brief Defines a variable in the IRDocsifier at the given frame,
78 * and returns the corresponding IdDoc
79 * \param var The variable to define
80 * \param d The IRDocsifier
81 * \param frame The frame to define the variable in
82 * \return The IdDoc corresponding to the variable
83 */
84inline ExprDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsifier& d) {
85 if (Optional<ExprDoc> doc = d->GetVarDoc(var)) {
86 return doc.value();
87 }
88 return d->Define(var, frame, var->name_hint.empty() ? "v" : var->name_hint);
89}
90
91/*!
92 * \brief Defines a buffer in the IRDocsifier at the given frame,
93 * and returns the corresponding IdDoc
94 * \param buffer The buffer to define
95 * \param frame The frame to define the buffer in
96 * \param d The IRDocsifier
97 * \return The IdDoc corresponding to the buffer
98 */
99inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const IRDocsifier& d) {
100 return d->Define(buffer, frame, buffer->name.empty() ? "buffer" : buffer->name);
101}
102
103/*!
104 * \brief Recursively process the body statements of a TIR fragment represented by a frame
105 * \param stmt The body statement to process
106 * \param p The object path
107 * \param f The frame
108 * \param d The IRDocsifier
109 */
110inline void AsDocBody(const tir::Stmt& stmt, ObjectPath p, TIRFrameNode* f, const IRDocsifier& d) {
111 if (const auto* seq_stmt = stmt.as<tir::SeqStmtNode>()) {
112 Array<tir::Stmt> body = seq_stmt->seq;
113 for (int i = 0, n = body.size(); i < n; ++i) {
114 f->allow_concise_scoping = (i == n - 1);
115 Doc doc = d->AsDoc(body[i], p->Attr("seq")->ArrayIndex(i));
116 doc->source_paths.push_back(p);
117 if (const auto* block = doc.as<StmtBlockDocNode>()) {
118 f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end());
119 } else {
120 f->stmts.push_back(Downcast<StmtDoc>(doc));
121 }
122 }
123 } else {
124 f->allow_concise_scoping = true;
125 Doc doc = d->AsDoc(stmt, p);
126 if (const auto* block = doc.as<StmtBlockDocNode>()) {
127 f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end());
128 } else {
129 f->stmts.push_back(Downcast<StmtDoc>(doc));
130 }
131 }
132}
133
134/*!
135 * \brief Find the top frame in the stack that could place a var definition
136 * \param var The var to be defined
137 * \param d The IRDocsifier
138 * \return The frame that could place the var definition
139 */
140inline Optional<Frame> FindLowestVarDef(const ObjectRef& var, const IRDocsifier& d) {
141 if (!d->common_prefix.count(var.get())) {
142 return NullOpt;
143 }
144 int n_frames = d->frames.size();
145 std::unordered_map<const Object*, const FrameNode*> tir_to_frame;
146 const FrameNode* fallback_frame = nullptr;
147 tir_to_frame.reserve(n_frames);
148 for (int i = n_frames - 1; i >= 0; --i) {
149 if (const auto* f = d->frames[i].as<TIRFrameNode>()) {
150 if (f->tir.defined()) {
151 tir_to_frame[f->tir.get()] = f;
152 } else if (fallback_frame == nullptr) {
153 fallback_frame = f;
154 }
155 }
156 }
157 const std::vector<const Object*>& path = d->common_prefix.at(var.get());
158 for (auto it = path.rbegin(); it != path.rend(); ++it) {
159 if (tir_to_frame.count(*it)) {
160 return GetRef<Frame>(tir_to_frame.at(*it));
161 }
162 }
163 if (fallback_frame != nullptr) {
164 return GetRef<Frame>(fallback_frame);
165 }
166 return NullOpt;
167}
168
169/*! \brief Redirected method for the ReprPrinter */
170inline std::string ReprPrintTIR(const ObjectRef& obj, const PrinterConfig& cfg) {
171 IRDocsifier d(cfg);
172 d->SetCommonPrefix(obj, [](const ObjectRef& obj) {
173 return obj->IsInstance<tir::VarNode>() || obj->IsInstance<tir::BufferNode>();
174 });
175 With<TIRFrame> f(d, ObjectRef{nullptr});
176 (*f)->AddDispatchToken(d, "tir");
177 return Docsify(obj, d, *f, cfg);
178}
179
180/*!
181 * \brief Declare and define a buffer
182 * \param buffer The buffer to be defined
183 * \param method The method used to declare the buffer
184 * \param args The extra arguments used to declare the buffer
185 * \param p The object path
186 * \param f The frame
187 * \param d The IRDocsifier
188 * \return The ExprDoc corresponding to the buffer declaration
189 */
190ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<ExprDoc>& args,
191 const ObjectPath& p, const Frame& frame, const IRDocsifier& d);
192
193/*!
194 * \brief Declare and define a buffer as annotation
195 * \param buffer The buffer to be defined
196 * \param p The object path
197 * \param f The frame
198 * \param d The IRDocsifier
199 * \return The ExprDoc corresponding to the buffer declaration
200 */
201ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
202 const IRDocsifier& d);
203
204/*! \brief A Var occurrence counter visitor */
205class OccurrenceCounter : public tir::StmtExprVisitor {
206 public:
207 /*! \brief The occurrence counter */
208 int count = 0;
209 /*! \brief The Var to count occurrence */
210 const tir::VarNode* v = nullptr;
211
212 void VisitExpr_(const tir::VarNode* op) final {
213 if (op == v) {
214 ++count;
215 }
216 tir::StmtExprVisitor::VisitExpr_(op);
217 }
218
219 void VisitStmt_(const tir::BufferStoreNode* op) final {
220 VisitBuffer(op->buffer.get());
221 tir::StmtExprVisitor::VisitStmt_(op);
222 }
223
224 void VisitExpr_(const tir::BufferLoadNode* op) final {
225 VisitBuffer(op->buffer.get());
226 tir::StmtExprVisitor::VisitExpr_(op);
227 }
228
229 void VisitStmt_(const tir::DeclBufferNode* op) final {
230 VisitBuffer(op->buffer.get());
231 tir::StmtExprVisitor::VisitStmt_(op);
232 }
233
234 void VisitBuffer(const tir::BufferNode* buffer) {
235 VisitExpr(buffer->data);
236 for (const PrimExpr& shape_i : buffer->shape) {
237 VisitExpr(shape_i);
238 }
239 for (const PrimExpr& stride_i : buffer->strides) {
240 VisitExpr(stride_i);
241 }
242 VisitExpr(buffer->elem_offset);
243 }
244
245 explicit OccurrenceCounter(const tir::VarNode* var) { v = var; }
246};
247
248} // namespace printer
249} // namespace script
250} // namespace tvm
251
252#endif // TVM_SCRIPT_PRINTER_TIR_UTILS_H_
253