1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ |
20 | #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ |
21 | |
#include <tvm/ir/module.h>
#include <tvm/node/node.h>
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/ir_docsifier_functor.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
32 | |
33 | namespace tvm { |
34 | namespace script { |
35 | namespace printer { |
36 | |
//////////////////////// Frame ////////////////////////

// Forward declarations: FrameNode holds an IRDocsifierNode* and takes an
// IRDocsifier& before either type is defined further down in this header.
class IRDocsifierNode;
class IRDocsifier;
41 | |
42 | /*! |
43 | * Frame is the core data structure for semantic information |
44 | * when printing IR graph into TVMScript code. |
45 | */ |
46 | class FrameNode : public Object { |
47 | public: |
48 | /*! The docs generated in the frame */ |
49 | Array<StmtDoc> stmts; |
50 | /*! The corresponding IRDocsifier */ |
51 | IRDocsifierNode* d; |
52 | /*! The callbacks that are going to be invoked when the frame exits */ |
53 | std::vector<std::function<void()>> callbacks; |
54 | |
55 | void VisitAttrs(tvm::AttrVisitor* v) { |
56 | v->Visit("stmts" , &stmts); |
57 | // `d` is not visited |
58 | // `callbacks` is not visited |
59 | } |
60 | |
61 | static constexpr const char* _type_key = "script.printer.Frame" ; |
62 | TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object); |
63 | |
64 | public: |
65 | virtual ~FrameNode() = default; |
66 | |
67 | /*! |
68 | * \brief Add a callback function to be called when this frame exits. |
69 | * \param cb The callback function. It should have signature void(). |
70 | */ |
71 | template <typename TCallback> |
72 | void AddExitCallback(TCallback&& cb) { |
73 | callbacks.emplace_back(std::forward<TCallback>(cb)); |
74 | } |
75 | /*! |
76 | * \brief Add a dispatch token to the docsifier, and a callback that pops the token when this |
77 | * frame exits. |
78 | * \param d The docsifier. |
79 | * \param token The token to be added. |
80 | */ |
81 | void AddDispatchToken(const IRDocsifier& d, const String& token); |
82 | /*! |
83 | * \brief Method that's called when Frame enters the scope. |
84 | */ |
85 | virtual void EnterWithScope(); |
86 | /*! |
87 | * \brief Method that's called when Frame exits the scope. |
88 | */ |
89 | virtual void ExitWithScope(); |
90 | }; |
91 | |
92 | /*! |
93 | * \brief Reference type of FrameNode |
94 | */ |
95 | class Frame : public ObjectRef { |
96 | protected: |
97 | Frame() = default; |
98 | |
99 | public: |
100 | virtual ~Frame() = default; |
101 | |
102 | /*! \brief Method that's called when Frame enters the scope. */ |
103 | void EnterWithScope() { get()->EnterWithScope(); } |
104 | |
105 | /*! \brief Method that's called when Frame exits the scope. */ |
106 | void ExitWithScope() { get()->ExitWithScope(); } |
107 | |
108 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); |
109 | }; |
110 | |
111 | //////////////////////// IRDocsifier //////////////////////// |
112 | |
113 | /*! |
114 | * \brief IRDocsifier is the top-level interface in the IR->Doc process. |
115 | * |
116 | * It provides methods to convert IR node object to Doc, operate on Frame |
117 | * objects and change dispatch tokens. |
118 | */ |
119 | class IRDocsifierNode : public Object { |
120 | public: |
121 | /*! \brief A function that creates the doc for a variable */ |
122 | using DocCreator = std::function<ExprDoc()>; |
123 | /*! \brief Information about a variable, including its optional name and its doc creator */ |
124 | struct VariableInfo { |
125 | /*! \brief The creator */ |
126 | DocCreator creator; |
127 | /*! \brief The name of the variable */ |
128 | Optional<String> name; |
129 | }; |
130 | /*! \brief The configuration of the printer */ |
131 | PrinterConfig cfg{nullptr}; |
132 | /*! |
133 | * \brief The stack of frames. |
134 | * \sa FrameNode |
135 | */ |
136 | Array<Frame> frames; |
137 | /*! |
138 | * \brief The stack of dispatch tokens. |
139 | * |
140 | * The dispatch token on the top decides which dispatch function to use |
141 | * when converting IR node object to Doc. |
142 | */ |
143 | Array<String> dispatch_tokens; |
144 | /*! \brief Mapping from a var to its info */ |
145 | std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info; |
146 | /*! \brief Metadata printing */ |
147 | std::unordered_map<String, Array<ObjectRef>> metadata; |
148 | /*! \brief The variable names used already */ |
149 | std::unordered_set<String> defined_names; |
150 | /*! \brief Common prefixes of variable usages */ |
151 | std::unordered_map<const Object*, std::vector<const Object*>> common_prefix; |
152 | /*! \brief The IR usages for headers printing */ |
153 | std::unordered_set<std::string> ir_usage; |
154 | |
155 | void VisitAttrs(tvm::AttrVisitor* v) { |
156 | v->Visit("frames" , &frames); |
157 | v->Visit("dispatch_tokens" , &dispatch_tokens); |
158 | // `obj2info` is not visited |
159 | // `metadata` is not visited |
160 | // `defined_names` is not visited |
161 | // `common_prefix` is not visited |
162 | // `ir_usage` is not visited |
163 | } |
164 | |
165 | static constexpr const char* _type_key = "script.printer.IRDocsifier" ; |
166 | TVM_DECLARE_FINAL_OBJECT_INFO(IRDocsifierNode, Object); |
167 | |
168 | public: |
169 | /*! |
170 | * \brief Define variable by name. |
171 | * \param obj The variable object. |
172 | * \param frame The frame that this variable is defined in. |
173 | * \param name_hint The hint for variable name. |
174 | * |
175 | * \return The id doc for this variable. |
176 | * |
177 | * This function will rename the variable to avoid name conflict with other variables |
178 | * in the table. |
179 | */ |
180 | IdDoc Define(const ObjectRef& obj, const Frame& frame, const String& name_hint); |
181 | |
182 | /*! |
183 | * \brief Define variable by doc factory. |
184 | * \param obj The variable object. |
185 | * \param frame The frame that this variable is defined in. |
186 | * \param doc_factory The function to return an ExprDoc object for this variable. |
187 | * |
188 | * This function is a special form of `Define`. Variable is mapped to ExprDoc rather |
189 | * than IdDoc. It's useful when a variable is implicitly defined without a name, like |
190 | * the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc("<buffer_name>"), "data")`. |
191 | * |
192 | * This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to |
193 | * return a new Doc object every time it's called, as the returned doc will have |
194 | * different `source_path`. Currently there isn't a good way to deep copy a TVMObject |
195 | * so VarTable needs to call a factory function to get a freshly-constructed Doc object |
196 | * every time GetVarDoc is called. |
197 | */ |
198 | void Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory); |
199 | |
200 | /*! |
201 | * \brief Get the doc for variable. |
202 | * \param obj The variable object. |
203 | * |
204 | * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. |
205 | */ |
206 | Optional<ExprDoc> GetVarDoc(const ObjectRef& obj) const; |
207 | /*! \brief Add a TVM object to the metadata section*/ |
208 | ExprDoc AddMetadata(const ObjectRef& obj); |
209 | /*! |
210 | * \brief Check if a variable exists in the table. |
211 | * \param obj The variable object. |
212 | * |
213 | * \return a boolean for whether variable exists. |
214 | */ |
215 | bool IsVarDefined(const ObjectRef& obj) const; |
216 | /*! \brief Remove the variable defined */ |
217 | void RemoveVar(const ObjectRef& obj); |
218 | /*! |
219 | * \brief Set the common prefix information of variable usage. |
220 | * \param root The root of the AST. |
221 | * \param is_var A function that returns true if the given object is considered a variable. |
222 | */ |
223 | void SetCommonPrefix(const ObjectRef& root, runtime::TypedPackedFunc<bool(ObjectRef)> is_var); |
224 | /*! |
225 | * \brief Transform the input object into TDoc. |
226 | * \param obj The object to be transformed. |
227 | * \param path The path to this object. |
228 | * |
229 | * \return The Doc object. |
230 | */ |
231 | template <class TDoc = Doc> |
232 | inline TDoc AsDoc(const ObjectRef& obj, const ObjectPath& path) const; |
233 | }; |
234 | |
235 | /*! |
236 | * \brief Reference type of IRDocsifierNode. |
237 | */ |
238 | class IRDocsifier : public ObjectRef { |
239 | public: |
240 | using FType = IRDocsifierFunctor<printer::Doc, ObjectPath, IRDocsifier>; |
241 | /*! \brief Create a IRDocsifier. */ |
242 | explicit IRDocsifier(const PrinterConfig& cfg); |
243 | /*! \brief The registration table for IRDocsifier. */ |
244 | TVM_DLL static FType& vtable(); |
245 | |
246 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode); |
247 | }; |
248 | |
249 | //////////////////////// Implementation //////////////////////// |
250 | |
251 | inline void FrameNode::EnterWithScope() { |
252 | if (d != nullptr) { |
253 | d->frames.push_back(GetRef<Frame>(this)); |
254 | } |
255 | } |
256 | |
257 | inline void FrameNode::ExitWithScope() { |
258 | for (const std::function<void()>& callback : callbacks) { |
259 | callback(); |
260 | } |
261 | callbacks.clear(); |
262 | if (d != nullptr) { |
263 | d->frames.pop_back(); |
264 | } |
265 | } |
266 | |
267 | template <class TDoc> |
268 | inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const ObjectPath& path, |
269 | const PrinterConfig& cfg) { |
270 | if (cfg->obj_to_annotate.count(obj)) { |
271 | if (const auto* stmt = d.as<StmtDocNode>()) { |
272 | if (stmt->comment.defined()) { |
273 | stmt->comment = stmt->comment.value() + "\n" + cfg->obj_to_annotate.at(obj); |
274 | } else { |
275 | stmt->comment = cfg->obj_to_annotate.at(obj); |
276 | } |
277 | } else { |
278 | LOG(WARNING) << "Expect StmtDoc to be annotated for object " << obj << ", but got " |
279 | << Downcast<TDoc>(d)->_type_key; |
280 | } |
281 | } |
282 | for (const ObjectRef& o : cfg->obj_to_underline) { |
283 | if (o.same_as(obj)) { |
284 | cfg->path_to_underline.push_back(path); |
285 | } |
286 | } |
287 | for (const auto& pair : cfg->path_to_annotate) { |
288 | ObjectPath p = pair.first; |
289 | String attn = pair.second; |
290 | if (p->IsPrefixOf(path) && path->IsPrefixOf(p)) { |
291 | if (const auto* stmt = d.as<StmtDocNode>()) { |
292 | if (stmt->comment.defined()) { |
293 | stmt->comment = stmt->comment.value() + "\n" + attn; |
294 | } else { |
295 | stmt->comment = attn; |
296 | } |
297 | } else { |
298 | LOG(WARNING) << "Expect StmtDoc to be annotated at object path " << p << ", but got " |
299 | << Downcast<TDoc>(d)->_type_key; |
300 | } |
301 | } |
302 | } |
303 | } |
304 | |
305 | template <class TDoc> |
306 | inline TDoc IRDocsifierNode::AsDoc(const ObjectRef& obj, const ObjectPath& path) const { |
307 | if (obj.defined()) { |
308 | Doc d = IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef<IRDocsifier>(this)); |
309 | d->source_paths.push_back(path); |
310 | AddDocDecoration<TDoc>(d, obj, path, cfg); |
311 | return Downcast<TDoc>(d); |
312 | } |
313 | return Downcast<TDoc>(LiteralDoc::None(path)); |
314 | } |
315 | |
316 | inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const String& token) { |
317 | d->dispatch_tokens.push_back(token); |
318 | this->AddExitCallback([doc = d.get()]() { doc->dispatch_tokens.pop_back(); }); |
319 | } |
320 | |
321 | } // namespace printer |
322 | } // namespace script |
323 | } // namespace tvm |
324 | |
325 | #endif // TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ |
326 | |