1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_SCRIPT_IR_BUILDER_BASE_H_ |
20 | #define TVM_SCRIPT_IR_BUILDER_BASE_H_ |
21 | |
22 | #include <tvm/ir/expr.h> |
23 | #include <tvm/ir/function.h> |
24 | #include <tvm/node/node.h> |
25 | |
26 | #include <vector> |
27 | |
28 | namespace tvm { |
29 | namespace script { |
30 | namespace ir_builder { |
31 | |
32 | ////////////////////////////// IRBuilderFrame ////////////////////////////// |
33 | |
34 | /*! |
35 | * \brief A stack frame of the IRBuilder used to keep track of the current scope. |
36 | * Furthermore, the information stored in each stack frame can be useful for context-dependent |
37 | * IR construction. |
38 | * |
39 | * \example |
40 | * |
41 | * The `T::MatchBuffer` below adds an element in `PrimFuncNode::buffer_map`: |
42 | * |
43 | * \code {.cpp} |
44 | * |
45 | * using T = tvm::script::ir_builder::tir; |
46 | * With <PrimFuncFrame> _(...); |
47 | * Buffer buffer = T::MatchBuffer(...); |
48 | * |
49 | * \endcode |
50 | * |
51 | * The `T::MatchBuffer` below instead generates `MatchBufferRegion` in a TIR block: |
52 | * |
53 | * \code {.cpp} |
54 | * |
55 | * using T = tvm::script::ir_builder::tir; |
56 | * With <PrimFuncFrame> _(...); |
57 | * { |
58 | * With<BlockFrame> _2(...); |
59 | * Buffer buffer = T::MatchBuffer(...); |
60 | * } |
61 | * |
62 | * \endcode |
63 | */ |
64 | class IRBuilderFrameNode : public runtime::Object { |
65 | public: |
66 | /*! \brief A list of callbacks used when exiting the frame. */ |
67 | std::vector<runtime::TypedPackedFunc<void()>> callbacks; |
68 | |
69 | void VisitAttrs(tvm::AttrVisitor* v) { |
70 | // `callbacks` is not visited. |
71 | } |
72 | |
73 | static constexpr const char* _type_key = "script.ir_builder.IRBuilderFrame" ; |
74 | TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object); |
75 | |
76 | public: |
77 | /*! \brief Default destructor. */ |
78 | virtual ~IRBuilderFrameNode() = default; |
79 | /*! |
80 | * \brief The method called when entering RAII scope. |
81 | * \sa tvm::support::With |
82 | */ |
83 | virtual void EnterWithScope(); |
84 | /*! |
85 | * \brief The method called when exiting RAII scope. |
86 | * \sa tvm::support::With |
87 | */ |
88 | virtual void ExitWithScope(); |
89 | /*! |
90 | * \brief Add a callback method invoked when exiting the RAII scope. |
91 | * \param callback The callback to be added. |
92 | */ |
93 | void AddCallback(runtime::TypedPackedFunc<void()> callback); |
94 | }; |
95 | |
96 | /*! |
97 | * \brief Managed reference to an IRBuilderFrameNode. |
98 | * \sa IRBuilderFrameNode |
99 | */ |
100 | class IRBuilderFrame : public runtime::ObjectRef { |
101 | public: |
102 | /*! \brief Default destructor. */ |
103 | virtual ~IRBuilderFrame() = default; |
104 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode); |
105 | |
106 | protected: |
107 | /*! \brief Disallow direct construction of this object. */ |
108 | IRBuilderFrame() = default; |
109 | |
110 | public: |
111 | /*! |
112 | * \brief Redirected to `IRBuilderFrameNode::EnterWithScope`. |
113 | * \sa IRBuilderFrameNode::EnterWithScope |
114 | */ |
115 | inline void EnterWithScope() { |
116 | ICHECK(data_ != nullptr); |
117 | static_cast<IRBuilderFrameNode*>(data_.get())->EnterWithScope(); |
118 | } |
119 | /*! |
120 | * \brief Redirected to `IRBuilderFrameNode::ExitWithScope`. |
121 | * \sa IRBuilderFrameNode::ExitWithScope |
122 | */ |
123 | inline void ExitWithScope() { |
124 | ICHECK(data_ != nullptr); |
125 | static_cast<IRBuilderFrameNode*>(data_.get())->ExitWithScope(); |
126 | data_.reset(); |
127 | } |
128 | }; |
129 | |
130 | ////////////////////////////// IRBuilder ////////////////////////////// |
131 | |
132 | /*! |
133 | * \brief A dialect-agnostic IRBuilder that constructs any IR of TVM. |
 * An idiomatic use of this class is to put this inside the RAII with-scope and
 * call dialect-specific methods accordingly. Upon exiting the scope, the constructed
 * IR can be retrieved with `IRBuilderNode::Get`, as shown below.
136 | * |
137 | * \code |
138 | * |
139 | * PrimFunc ConstructPrimFunc() { |
140 | * using tvm::script::ir_builder::IRBuilder; |
141 | * using T = tvm::script::ir_builder::tir; |
142 | * IRBuilder builder; |
143 | * // Step 1. Place IRBuilder inside the with-scope. |
144 | * { |
145 | * With<IRBuilder> _(builder); |
146 | * // Step 2. Call dialect-specific methods. |
147 | * With<T::PrimFuncFrame> _2(...); |
148 | * T::MatchBuffer(...); |
149 | * } |
150 | * // Step 3. Return the constructed PrimFunc. |
151 | * return builder->Get<PrimFunc>(); |
152 | * } |
153 | * |
154 | * \endcode |
155 | */ |
156 | class IRBuilderNode : public runtime::Object { |
157 | public: |
158 | /*! \brief A stack of context frames in the IRBuilder */ |
159 | runtime::Array<IRBuilderFrame> frames; |
160 | /*! \brief The outcome of IR construction */ |
161 | Optional<ObjectRef> result; |
162 | |
163 | void VisitAttrs(tvm::AttrVisitor* v) { |
164 | v->Visit("frames" , &frames); |
165 | v->Visit("result" , &result); |
166 | } |
167 | |
168 | static constexpr const char* _type_key = "script.ir_builder.IRBuilder" ; |
169 | TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object); |
170 | |
171 | public: |
172 | /*! |
173 | * \brief Find a frame of the given type in the stack `this->frames` from top to bottom. |
174 | * \tparam T The type of the frame to find. |
175 | * \return The frame if found, otherwise NullOpt. |
176 | */ |
177 | template <typename TFrame> |
178 | inline Optional<TFrame> FindFrame() const; |
179 | /*! |
180 | * \brief Get the frame on top of the stack `this->frames` if its type is `TFrame`. |
181 | * \tparam TFrame The assumed type of the last frame on stack. |
182 | * \return The frame if the stack is non-empty and the top of the stack is of type `TFrame`. |
183 | * Otherwise NullOpt. |
184 | */ |
185 | template <typename TFrame> |
186 | inline Optional<TFrame> GetLastFrame() const; |
187 | /*! |
188 | * \brief Get the IR being constructed. |
189 | * \tparam TObjectRef The type of the IR being constructed. |
190 | * \return The resulting IR. Throw an exception if the IR is not constructed yet. |
191 | */ |
192 | template <typename TObjectRef> |
193 | inline TObjectRef Get() const; |
194 | }; |
195 | |
196 | /*! |
197 | * \brief Managed reference to an IRBuilderNode. |
198 | * \sa IRBuilderNode |
199 | */ |
200 | class IRBuilder : public runtime::ObjectRef { |
201 | public: |
202 | /*! \brief Creates an IRBuilder. */ |
203 | IRBuilder(); |
204 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode); |
205 | |
206 | public: |
207 | /*! |
208 | * \brief Puts the current IRBuilder into a thread-local scope, which can be retrieved using |
209 | * `IRBuilder::Current()`. |
210 | * |
211 | * \code {.cpp} |
212 | * IRBuilder builder; |
213 | * { |
214 | * With<IRBuilder> _(builder); |
215 | * // IRBuilder::Current() == builder |
216 | * } |
217 | * // IRBuilder::Current() == nullptr |
218 | * \endcode |
219 | * |
220 | * \sa IRBuilder::Current |
221 | * \sa IRBuilder::ExitWithScope |
222 | * \sa tvm::support::With |
223 | */ |
224 | void EnterWithScope(); |
225 | /*! |
226 | * \brief Exit the RAII scope. |
227 | * \sa IRBuilder::EnterWithScope |
228 | * \sa IRBuilder::Current |
229 | * \sa tvm::support::With |
230 | */ |
231 | void ExitWithScope(); |
232 | /*! |
233 | * \brief Get the current IRBuilder in the current thread-local scope. |
234 | * \return The current IRBuilder. |
235 | * \sa IRBuilder::EnterWithScope |
236 | * \sa IRBuilder::ExitWithScope |
237 | * \sa tvm::support::With |
238 | */ |
239 | static IRBuilder Current(); |
240 | /*! |
241 | * \brief Give a string name to the `obj` |
242 | * \tparam TObjectRef The type of the object to name. |
243 | * \param name The name to give to the object. |
244 | * \param obj The object to name. |
245 | */ |
246 | template <class TObjectRef> |
247 | inline static TObjectRef Name(String name, TObjectRef obj); |
248 | }; |
249 | |
250 | ////////////////////////////// Details ////////////////////////////// |
251 | |
252 | namespace details { |
253 | |
254 | class Namer { |
255 | public: |
256 | using FType = NodeFunctor<void(const ObjectRef&, String)>; |
257 | static FType& vtable(); |
258 | static void Name(ObjectRef node, String name); |
259 | }; |
260 | |
261 | } // namespace details |
262 | |
263 | template <class TObjectRef> |
264 | inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) { |
265 | details::Namer::Name(obj, name); |
266 | return Downcast<TObjectRef>(obj); |
267 | } |
268 | |
269 | template <typename TFrame> |
270 | inline Optional<TFrame> IRBuilderNode::FindFrame() const { |
271 | using TFrameNode = typename TFrame::ContainerType; |
272 | for (auto it = frames.rbegin(); it != frames.rend(); ++it) { |
273 | if (const TFrameNode* p = (*it).template as<TFrameNode>()) { |
274 | return GetRef<TFrame>(p); |
275 | } |
276 | } |
277 | return NullOpt; |
278 | } |
279 | |
280 | template <typename TFrame> |
281 | inline Optional<TFrame> IRBuilderNode::GetLastFrame() const { |
282 | using TFrameNode = typename TFrame::ContainerType; |
283 | if (!frames.empty() && frames.back()->IsInstance<TFrameNode>()) { |
284 | return Downcast<TFrame>(frames.back()); |
285 | } |
286 | return NullOpt; |
287 | } |
288 | |
289 | template <typename TObjectRef> |
290 | inline TObjectRef IRBuilderNode::Get() const { |
291 | using TObject = typename TObjectRef::ContainerType; |
292 | CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet" ; |
293 | const auto* n = result.as<TObject>(); |
294 | CHECK(n != nullptr) << "TypeError: IRBuilder result is not of type: " << TObject::_type_key; |
295 | return GetRef<TObjectRef>(n); |
296 | } |
297 | |
298 | } // namespace ir_builder |
299 | } // namespace script |
300 | } // namespace tvm |
301 | |
302 | #endif // TVM_SCRIPT_IR_BUILDER_BASE_H_ |
303 | |