1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_SCRIPT_PRINTER_TIR_UTILS_H_ |
20 | #define TVM_SCRIPT_PRINTER_TIR_UTILS_H_ |
21 | |
22 | #include <tvm/script/printer/ir_docsifier.h> |
23 | #include <tvm/tir/analysis.h> |
24 | #include <tvm/tir/buffer.h> |
25 | #include <tvm/tir/expr.h> |
26 | #include <tvm/tir/function.h> |
27 | #include <tvm/tir/index_map.h> |
28 | #include <tvm/tir/op.h> |
29 | #include <tvm/tir/stmt.h> |
30 | #include <tvm/tir/stmt_functor.h> |
31 | |
32 | #include <string> |
33 | #include <unordered_map> |
34 | #include <utility> |
35 | #include <vector> |
36 | |
37 | #include "../utils.h" |
38 | |
39 | namespace tvm { |
40 | namespace script { |
41 | namespace printer { |
42 | |
43 | /*! \brief A printer frame for TIR fragment */ |
44 | class TIRFrameNode : public FrameNode { |
45 | public: |
46 | /*! \brief The TIR fragment the frame corresponds to */ |
47 | ObjectRef tir; |
48 | /*! \brief Whether or not the frame allows concise scoping */ |
49 | bool allow_concise_scoping{false}; |
50 | |
51 | void VisitAttrs(AttrVisitor* v) { |
52 | FrameNode::VisitAttrs(v); |
53 | v->Visit("tir" , &tir); |
54 | v->Visit("allow_concise_scoping" , &allow_concise_scoping); |
55 | } |
56 | |
57 | static constexpr const char* _type_key = "script.printer.TIRFrame" ; |
58 | TVM_DECLARE_FINAL_OBJECT_INFO(TIRFrameNode, FrameNode); |
59 | }; |
60 | |
61 | /*! \brief Managed reference to TIRFrameNode */ |
62 | class TIRFrame : public Frame { |
63 | public: |
64 | /*! \brief Constructor */ |
65 | explicit TIRFrame(const IRDocsifier& d, const ObjectRef& tir) { |
66 | ObjectPtr<TIRFrameNode> n = make_object<TIRFrameNode>(); |
67 | n->stmts.clear(); |
68 | n->d = d.get(); |
69 | n->tir = tir; |
70 | data_ = std::move(n); |
71 | } |
72 | |
73 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode); |
74 | }; |
75 | |
76 | /*! |
77 | * \brief Defines a variable in the IRDocsifier at the given frame, |
78 | * and returns the corresponding IdDoc |
79 | * \param var The variable to define |
80 | * \param d The IRDocsifier |
81 | * \param frame The frame to define the variable in |
82 | * \return The IdDoc corresponding to the variable |
83 | */ |
84 | inline ExprDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsifier& d) { |
85 | if (Optional<ExprDoc> doc = d->GetVarDoc(var)) { |
86 | return doc.value(); |
87 | } |
88 | return d->Define(var, frame, var->name_hint.empty() ? "v" : var->name_hint); |
89 | } |
90 | |
91 | /*! |
92 | * \brief Defines a buffer in the IRDocsifier at the given frame, |
93 | * and returns the corresponding IdDoc |
94 | * \param buffer The buffer to define |
95 | * \param frame The frame to define the buffer in |
96 | * \param d The IRDocsifier |
97 | * \return The IdDoc corresponding to the buffer |
98 | */ |
99 | inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const IRDocsifier& d) { |
100 | return d->Define(buffer, frame, buffer->name.empty() ? "buffer" : buffer->name); |
101 | } |
102 | |
103 | /*! |
104 | * \brief Recursively process the body statements of a TIR fragment represented by a frame |
105 | * \param stmt The body statement to process |
106 | * \param p The object path |
107 | * \param f The frame |
108 | * \param d The IRDocsifier |
109 | */ |
110 | inline void AsDocBody(const tir::Stmt& stmt, ObjectPath p, TIRFrameNode* f, const IRDocsifier& d) { |
111 | if (const auto* seq_stmt = stmt.as<tir::SeqStmtNode>()) { |
112 | Array<tir::Stmt> body = seq_stmt->seq; |
113 | for (int i = 0, n = body.size(); i < n; ++i) { |
114 | f->allow_concise_scoping = (i == n - 1); |
115 | Doc doc = d->AsDoc(body[i], p->Attr("seq" )->ArrayIndex(i)); |
116 | doc->source_paths.push_back(p); |
117 | if (const auto* block = doc.as<StmtBlockDocNode>()) { |
118 | f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end()); |
119 | } else { |
120 | f->stmts.push_back(Downcast<StmtDoc>(doc)); |
121 | } |
122 | } |
123 | } else { |
124 | f->allow_concise_scoping = true; |
125 | Doc doc = d->AsDoc(stmt, p); |
126 | if (const auto* block = doc.as<StmtBlockDocNode>()) { |
127 | f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end()); |
128 | } else { |
129 | f->stmts.push_back(Downcast<StmtDoc>(doc)); |
130 | } |
131 | } |
132 | } |
133 | |
134 | /*! |
135 | * \brief Find the top frame in the stack that could place a var definition |
136 | * \param var The var to be defined |
137 | * \param d The IRDocsifier |
138 | * \return The frame that could place the var definition |
139 | */ |
140 | inline Optional<Frame> FindLowestVarDef(const ObjectRef& var, const IRDocsifier& d) { |
141 | if (!d->common_prefix.count(var.get())) { |
142 | return NullOpt; |
143 | } |
144 | int n_frames = d->frames.size(); |
145 | std::unordered_map<const Object*, const FrameNode*> tir_to_frame; |
146 | const FrameNode* fallback_frame = nullptr; |
147 | tir_to_frame.reserve(n_frames); |
148 | for (int i = n_frames - 1; i >= 0; --i) { |
149 | if (const auto* f = d->frames[i].as<TIRFrameNode>()) { |
150 | if (f->tir.defined()) { |
151 | tir_to_frame[f->tir.get()] = f; |
152 | } else if (fallback_frame == nullptr) { |
153 | fallback_frame = f; |
154 | } |
155 | } |
156 | } |
157 | const std::vector<const Object*>& path = d->common_prefix.at(var.get()); |
158 | for (auto it = path.rbegin(); it != path.rend(); ++it) { |
159 | if (tir_to_frame.count(*it)) { |
160 | return GetRef<Frame>(tir_to_frame.at(*it)); |
161 | } |
162 | } |
163 | if (fallback_frame != nullptr) { |
164 | return GetRef<Frame>(fallback_frame); |
165 | } |
166 | return NullOpt; |
167 | } |
168 | |
169 | /*! \brief Redirected method for the ReprPrinter */ |
170 | inline std::string ReprPrintTIR(const ObjectRef& obj, const PrinterConfig& cfg) { |
171 | IRDocsifier d(cfg); |
172 | d->SetCommonPrefix(obj, [](const ObjectRef& obj) { |
173 | return obj->IsInstance<tir::VarNode>() || obj->IsInstance<tir::BufferNode>(); |
174 | }); |
175 | With<TIRFrame> f(d, ObjectRef{nullptr}); |
176 | (*f)->AddDispatchToken(d, "tir" ); |
177 | return Docsify(obj, d, *f, cfg); |
178 | } |
179 | |
/*!
 * \brief Declare and define a buffer
 * \note Declaration only; no definition appears in this header.
 * \param buffer The buffer to be defined
 * \param method The method used to declare the buffer
 * \param args The extra arguments used to declare the buffer
 * \param p The object path
 * \param frame The frame
 * \param d The IRDocsifier
 * \return The ExprDoc corresponding to the buffer declaration
 */
ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<ExprDoc>& args,
                   const ObjectPath& p, const Frame& frame, const IRDocsifier& d);
192 | |
/*!
 * \brief Declare and define a buffer as annotation
 * \note Declaration only; no definition appears in this header.
 * \param buffer The buffer to be defined
 * \param p The object path
 * \param frame The frame
 * \param d The IRDocsifier
 * \return The ExprDoc corresponding to the buffer declaration
 */
ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
                   const IRDocsifier& d);
203 | |
204 | /*! \brief A Var occurrence counter visitor */ |
205 | class OccurrenceCounter : public tir::StmtExprVisitor { |
206 | public: |
207 | /*! \brief The occurrence counter */ |
208 | int count = 0; |
209 | /*! \brief The Var to count occurrence */ |
210 | const tir::VarNode* v = nullptr; |
211 | |
212 | void VisitExpr_(const tir::VarNode* op) final { |
213 | if (op == v) { |
214 | ++count; |
215 | } |
216 | tir::StmtExprVisitor::VisitExpr_(op); |
217 | } |
218 | |
219 | void VisitStmt_(const tir::BufferStoreNode* op) final { |
220 | VisitBuffer(op->buffer.get()); |
221 | tir::StmtExprVisitor::VisitStmt_(op); |
222 | } |
223 | |
224 | void VisitExpr_(const tir::BufferLoadNode* op) final { |
225 | VisitBuffer(op->buffer.get()); |
226 | tir::StmtExprVisitor::VisitExpr_(op); |
227 | } |
228 | |
229 | void VisitStmt_(const tir::DeclBufferNode* op) final { |
230 | VisitBuffer(op->buffer.get()); |
231 | tir::StmtExprVisitor::VisitStmt_(op); |
232 | } |
233 | |
234 | void VisitBuffer(const tir::BufferNode* buffer) { |
235 | VisitExpr(buffer->data); |
236 | for (const PrimExpr& shape_i : buffer->shape) { |
237 | VisitExpr(shape_i); |
238 | } |
239 | for (const PrimExpr& stride_i : buffer->strides) { |
240 | VisitExpr(stride_i); |
241 | } |
242 | VisitExpr(buffer->elem_offset); |
243 | } |
244 | |
245 | explicit OccurrenceCounter(const tir::VarNode* var) { v = var; } |
246 | }; |
247 | |
248 | } // namespace printer |
249 | } // namespace script |
250 | } // namespace tvm |
251 | |
252 | #endif // TVM_SCRIPT_PRINTER_TIR_UTILS_H_ |
253 | |