1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_
20#define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_
21
22#include <tvm/ir/module.h>
23#include <tvm/node/node.h>
24#include <tvm/script/printer/doc.h>
25#include <tvm/script/printer/ir_docsifier_functor.h>
26
27#include <string>
28#include <unordered_map>
29#include <unordered_set>
30#include <utility>
31#include <vector>
32
33namespace tvm {
34namespace script {
35namespace printer {
36
37//////////////////////// Frame ////////////////////////
38
39class IRDocsifier;
40class IRDocsifierNode;
41
42/*!
43 * Frame is the core data structure for semantic information
44 * when printing IR graph into TVMScript code.
45 */
46class FrameNode : public Object {
47 public:
48 /*! The docs generated in the frame */
49 Array<StmtDoc> stmts;
50 /*! The corresponding IRDocsifier */
51 IRDocsifierNode* d;
52 /*! The callbacks that are going to be invoked when the frame exits */
53 std::vector<std::function<void()>> callbacks;
54
55 void VisitAttrs(tvm::AttrVisitor* v) {
56 v->Visit("stmts", &stmts);
57 // `d` is not visited
58 // `callbacks` is not visited
59 }
60
61 static constexpr const char* _type_key = "script.printer.Frame";
62 TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object);
63
64 public:
65 virtual ~FrameNode() = default;
66
67 /*!
68 * \brief Add a callback function to be called when this frame exits.
69 * \param cb The callback function. It should have signature void().
70 */
71 template <typename TCallback>
72 void AddExitCallback(TCallback&& cb) {
73 callbacks.emplace_back(std::forward<TCallback>(cb));
74 }
75 /*!
76 * \brief Add a dispatch token to the docsifier, and a callback that pops the token when this
77 * frame exits.
78 * \param d The docsifier.
79 * \param token The token to be added.
80 */
81 void AddDispatchToken(const IRDocsifier& d, const String& token);
82 /*!
83 * \brief Method that's called when Frame enters the scope.
84 */
85 virtual void EnterWithScope();
86 /*!
87 * \brief Method that's called when Frame exits the scope.
88 */
89 virtual void ExitWithScope();
90};
91
92/*!
93 * \brief Reference type of FrameNode
94 */
95class Frame : public ObjectRef {
96 protected:
97 Frame() = default;
98
99 public:
100 virtual ~Frame() = default;
101
102 /*! \brief Method that's called when Frame enters the scope. */
103 void EnterWithScope() { get()->EnterWithScope(); }
104
105 /*! \brief Method that's called when Frame exits the scope. */
106 void ExitWithScope() { get()->ExitWithScope(); }
107
108 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode);
109};
110
111//////////////////////// IRDocsifier ////////////////////////
112
113/*!
114 * \brief IRDocsifier is the top-level interface in the IR->Doc process.
115 *
116 * It provides methods to convert IR node object to Doc, operate on Frame
117 * objects and change dispatch tokens.
118 */
119class IRDocsifierNode : public Object {
120 public:
121 /*! \brief A function that creates the doc for a variable */
122 using DocCreator = std::function<ExprDoc()>;
123 /*! \brief Information about a variable, including its optional name and its doc creator */
124 struct VariableInfo {
125 /*! \brief The creator */
126 DocCreator creator;
127 /*! \brief The name of the variable */
128 Optional<String> name;
129 };
130 /*! \brief The configuration of the printer */
131 PrinterConfig cfg{nullptr};
132 /*!
133 * \brief The stack of frames.
134 * \sa FrameNode
135 */
136 Array<Frame> frames;
137 /*!
138 * \brief The stack of dispatch tokens.
139 *
140 * The dispatch token on the top decides which dispatch function to use
141 * when converting IR node object to Doc.
142 */
143 Array<String> dispatch_tokens;
144 /*! \brief Mapping from a var to its info */
145 std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info;
146 /*! \brief Metadata printing */
147 std::unordered_map<String, Array<ObjectRef>> metadata;
148 /*! \brief The variable names used already */
149 std::unordered_set<String> defined_names;
150 /*! \brief Common prefixes of variable usages */
151 std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;
152 /*! \brief The IR usages for headers printing */
153 std::unordered_set<std::string> ir_usage;
154
155 void VisitAttrs(tvm::AttrVisitor* v) {
156 v->Visit("frames", &frames);
157 v->Visit("dispatch_tokens", &dispatch_tokens);
158 // `obj2info` is not visited
159 // `metadata` is not visited
160 // `defined_names` is not visited
161 // `common_prefix` is not visited
162 // `ir_usage` is not visited
163 }
164
165 static constexpr const char* _type_key = "script.printer.IRDocsifier";
166 TVM_DECLARE_FINAL_OBJECT_INFO(IRDocsifierNode, Object);
167
168 public:
169 /*!
170 * \brief Define variable by name.
171 * \param obj The variable object.
172 * \param frame The frame that this variable is defined in.
173 * \param name_hint The hint for variable name.
174 *
175 * \return The id doc for this variable.
176 *
177 * This function will rename the variable to avoid name conflict with other variables
178 * in the table.
179 */
180 IdDoc Define(const ObjectRef& obj, const Frame& frame, const String& name_hint);
181
182 /*!
183 * \brief Define variable by doc factory.
184 * \param obj The variable object.
185 * \param frame The frame that this variable is defined in.
186 * \param doc_factory The function to return an ExprDoc object for this variable.
187 *
188 * This function is a special form of `Define`. Variable is mapped to ExprDoc rather
189 * than IdDoc. It's useful when a variable is implicitly defined without a name, like
190 * the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc("<buffer_name>"), "data")`.
191 *
192 * This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to
193 * return a new Doc object every time it's called, as the returned doc will have
194 * different `source_path`. Currently there isn't a good way to deep copy a TVMObject
195 * so VarTable needs to call a factory function to get a freshly-constructed Doc object
196 * every time GetVarDoc is called.
197 */
198 void Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory);
199
200 /*!
201 * \brief Get the doc for variable.
202 * \param obj The variable object.
203 *
204 * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
205 */
206 Optional<ExprDoc> GetVarDoc(const ObjectRef& obj) const;
207 /*! \brief Add a TVM object to the metadata section*/
208 ExprDoc AddMetadata(const ObjectRef& obj);
209 /*!
210 * \brief Check if a variable exists in the table.
211 * \param obj The variable object.
212 *
213 * \return a boolean for whether variable exists.
214 */
215 bool IsVarDefined(const ObjectRef& obj) const;
216 /*! \brief Remove the variable defined */
217 void RemoveVar(const ObjectRef& obj);
218 /*!
219 * \brief Set the common prefix information of variable usage.
220 * \param root The root of the AST.
221 * \param is_var A function that returns true if the given object is considered a variable.
222 */
223 void SetCommonPrefix(const ObjectRef& root, runtime::TypedPackedFunc<bool(ObjectRef)> is_var);
224 /*!
225 * \brief Transform the input object into TDoc.
226 * \param obj The object to be transformed.
227 * \param path The path to this object.
228 *
229 * \return The Doc object.
230 */
231 template <class TDoc = Doc>
232 inline TDoc AsDoc(const ObjectRef& obj, const ObjectPath& path) const;
233};
234
235/*!
236 * \brief Reference type of IRDocsifierNode.
237 */
238class IRDocsifier : public ObjectRef {
239 public:
240 using FType = IRDocsifierFunctor<printer::Doc, ObjectPath, IRDocsifier>;
241 /*! \brief Create a IRDocsifier. */
242 explicit IRDocsifier(const PrinterConfig& cfg);
243 /*! \brief The registration table for IRDocsifier. */
244 TVM_DLL static FType& vtable();
245
246 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode);
247};
248
249//////////////////////// Implementation ////////////////////////
250
251inline void FrameNode::EnterWithScope() {
252 if (d != nullptr) {
253 d->frames.push_back(GetRef<Frame>(this));
254 }
255}
256
257inline void FrameNode::ExitWithScope() {
258 for (const std::function<void()>& callback : callbacks) {
259 callback();
260 }
261 callbacks.clear();
262 if (d != nullptr) {
263 d->frames.pop_back();
264 }
265}
266
267template <class TDoc>
268inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const ObjectPath& path,
269 const PrinterConfig& cfg) {
270 if (cfg->obj_to_annotate.count(obj)) {
271 if (const auto* stmt = d.as<StmtDocNode>()) {
272 if (stmt->comment.defined()) {
273 stmt->comment = stmt->comment.value() + "\n" + cfg->obj_to_annotate.at(obj);
274 } else {
275 stmt->comment = cfg->obj_to_annotate.at(obj);
276 }
277 } else {
278 LOG(WARNING) << "Expect StmtDoc to be annotated for object " << obj << ", but got "
279 << Downcast<TDoc>(d)->_type_key;
280 }
281 }
282 for (const ObjectRef& o : cfg->obj_to_underline) {
283 if (o.same_as(obj)) {
284 cfg->path_to_underline.push_back(path);
285 }
286 }
287 for (const auto& pair : cfg->path_to_annotate) {
288 ObjectPath p = pair.first;
289 String attn = pair.second;
290 if (p->IsPrefixOf(path) && path->IsPrefixOf(p)) {
291 if (const auto* stmt = d.as<StmtDocNode>()) {
292 if (stmt->comment.defined()) {
293 stmt->comment = stmt->comment.value() + "\n" + attn;
294 } else {
295 stmt->comment = attn;
296 }
297 } else {
298 LOG(WARNING) << "Expect StmtDoc to be annotated at object path " << p << ", but got "
299 << Downcast<TDoc>(d)->_type_key;
300 }
301 }
302 }
303}
304
305template <class TDoc>
306inline TDoc IRDocsifierNode::AsDoc(const ObjectRef& obj, const ObjectPath& path) const {
307 if (obj.defined()) {
308 Doc d = IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef<IRDocsifier>(this));
309 d->source_paths.push_back(path);
310 AddDocDecoration<TDoc>(d, obj, path, cfg);
311 return Downcast<TDoc>(d);
312 }
313 return Downcast<TDoc>(LiteralDoc::None(path));
314}
315
316inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const String& token) {
317 d->dispatch_tokens.push_back(token);
318 this->AddExitCallback([doc = d.get()]() { doc->dispatch_tokens.pop_back(); });
319}
320
321} // namespace printer
322} // namespace script
323} // namespace tvm
324
325#endif // TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_
326