1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_SCRIPT_IR_BUILDER_BASE_H_
20#define TVM_SCRIPT_IR_BUILDER_BASE_H_
21
22#include <tvm/ir/expr.h>
23#include <tvm/ir/function.h>
24#include <tvm/node/node.h>
25
26#include <vector>
27
28namespace tvm {
29namespace script {
30namespace ir_builder {
31
32////////////////////////////// IRBuilderFrame //////////////////////////////
33
34/*!
35 * \brief A stack frame of the IRBuilder used to keep track of the current scope.
36 * Furthermore, the information stored in each stack frame can be useful for context-dependent
37 * IR construction.
38 *
39 * \example
40 *
41 * The `T::MatchBuffer` below adds an element in `PrimFuncNode::buffer_map`:
42 *
43 * \code {.cpp}
44 *
45 * using T = tvm::script::ir_builder::tir;
46 * With <PrimFuncFrame> _(...);
47 * Buffer buffer = T::MatchBuffer(...);
48 *
49 * \endcode
50 *
51 * The `T::MatchBuffer` below instead generates `MatchBufferRegion` in a TIR block:
52 *
53 * \code {.cpp}
54 *
55 * using T = tvm::script::ir_builder::tir;
56 * With <PrimFuncFrame> _(...);
57 * {
58 * With<BlockFrame> _2(...);
59 * Buffer buffer = T::MatchBuffer(...);
60 * }
61 *
62 * \endcode
63 */
64class IRBuilderFrameNode : public runtime::Object {
65 public:
66 /*! \brief A list of callbacks used when exiting the frame. */
67 std::vector<runtime::TypedPackedFunc<void()>> callbacks;
68
69 void VisitAttrs(tvm::AttrVisitor* v) {
70 // `callbacks` is not visited.
71 }
72
73 static constexpr const char* _type_key = "script.ir_builder.IRBuilderFrame";
74 TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object);
75
76 public:
77 /*! \brief Default destructor. */
78 virtual ~IRBuilderFrameNode() = default;
79 /*!
80 * \brief The method called when entering RAII scope.
81 * \sa tvm::support::With
82 */
83 virtual void EnterWithScope();
84 /*!
85 * \brief The method called when exiting RAII scope.
86 * \sa tvm::support::With
87 */
88 virtual void ExitWithScope();
89 /*!
90 * \brief Add a callback method invoked when exiting the RAII scope.
91 * \param callback The callback to be added.
92 */
93 void AddCallback(runtime::TypedPackedFunc<void()> callback);
94};
95
96/*!
97 * \brief Managed reference to an IRBuilderFrameNode.
98 * \sa IRBuilderFrameNode
99 */
100class IRBuilderFrame : public runtime::ObjectRef {
101 public:
102 /*! \brief Default destructor. */
103 virtual ~IRBuilderFrame() = default;
104 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode);
105
106 protected:
107 /*! \brief Disallow direct construction of this object. */
108 IRBuilderFrame() = default;
109
110 public:
111 /*!
112 * \brief Redirected to `IRBuilderFrameNode::EnterWithScope`.
113 * \sa IRBuilderFrameNode::EnterWithScope
114 */
115 inline void EnterWithScope() {
116 ICHECK(data_ != nullptr);
117 static_cast<IRBuilderFrameNode*>(data_.get())->EnterWithScope();
118 }
119 /*!
120 * \brief Redirected to `IRBuilderFrameNode::ExitWithScope`.
121 * \sa IRBuilderFrameNode::ExitWithScope
122 */
123 inline void ExitWithScope() {
124 ICHECK(data_ != nullptr);
125 static_cast<IRBuilderFrameNode*>(data_.get())->ExitWithScope();
126 data_.reset();
127 }
128};
129
130////////////////////////////// IRBuilder //////////////////////////////
131
132/*!
133 * \brief A dialect-agnostic IRBuilder that constructs any IR of TVM.
134 * An idiomatic use of this class is to put this inside the RAII with-scope,
135 * call dialect-specific methods accordingly. Upon exiting the scope.
136 *
137 * \code
138 *
139 * PrimFunc ConstructPrimFunc() {
140 * using tvm::script::ir_builder::IRBuilder;
141 * using T = tvm::script::ir_builder::tir;
142 * IRBuilder builder;
143 * // Step 1. Place IRBuilder inside the with-scope.
144 * {
145 * With<IRBuilder> _(builder);
146 * // Step 2. Call dialect-specific methods.
147 * With<T::PrimFuncFrame> _2(...);
148 * T::MatchBuffer(...);
149 * }
150 * // Step 3. Return the constructed PrimFunc.
151 * return builder->Get<PrimFunc>();
152 * }
153 *
154 * \endcode
155 */
156class IRBuilderNode : public runtime::Object {
157 public:
158 /*! \brief A stack of context frames in the IRBuilder */
159 runtime::Array<IRBuilderFrame> frames;
160 /*! \brief The outcome of IR construction */
161 Optional<ObjectRef> result;
162
163 void VisitAttrs(tvm::AttrVisitor* v) {
164 v->Visit("frames", &frames);
165 v->Visit("result", &result);
166 }
167
168 static constexpr const char* _type_key = "script.ir_builder.IRBuilder";
169 TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object);
170
171 public:
172 /*!
173 * \brief Find a frame of the given type in the stack `this->frames` from top to bottom.
174 * \tparam T The type of the frame to find.
175 * \return The frame if found, otherwise NullOpt.
176 */
177 template <typename TFrame>
178 inline Optional<TFrame> FindFrame() const;
179 /*!
180 * \brief Get the frame on top of the stack `this->frames` if its type is `TFrame`.
181 * \tparam TFrame The assumed type of the last frame on stack.
182 * \return The frame if the stack is non-empty and the top of the stack is of type `TFrame`.
183 * Otherwise NullOpt.
184 */
185 template <typename TFrame>
186 inline Optional<TFrame> GetLastFrame() const;
187 /*!
188 * \brief Get the IR being constructed.
189 * \tparam TObjectRef The type of the IR being constructed.
190 * \return The resulting IR. Throw an exception if the IR is not constructed yet.
191 */
192 template <typename TObjectRef>
193 inline TObjectRef Get() const;
194};
195
196/*!
197 * \brief Managed reference to an IRBuilderNode.
198 * \sa IRBuilderNode
199 */
200class IRBuilder : public runtime::ObjectRef {
201 public:
202 /*! \brief Creates an IRBuilder. */
203 IRBuilder();
204 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
205
206 public:
207 /*!
208 * \brief Puts the current IRBuilder into a thread-local scope, which can be retrieved using
209 * `IRBuilder::Current()`.
210 *
211 * \code {.cpp}
212 * IRBuilder builder;
213 * {
214 * With<IRBuilder> _(builder);
215 * // IRBuilder::Current() == builder
216 * }
217 * // IRBuilder::Current() == nullptr
218 * \endcode
219 *
220 * \sa IRBuilder::Current
221 * \sa IRBuilder::ExitWithScope
222 * \sa tvm::support::With
223 */
224 void EnterWithScope();
225 /*!
226 * \brief Exit the RAII scope.
227 * \sa IRBuilder::EnterWithScope
228 * \sa IRBuilder::Current
229 * \sa tvm::support::With
230 */
231 void ExitWithScope();
232 /*!
233 * \brief Get the current IRBuilder in the current thread-local scope.
234 * \return The current IRBuilder.
235 * \sa IRBuilder::EnterWithScope
236 * \sa IRBuilder::ExitWithScope
237 * \sa tvm::support::With
238 */
239 static IRBuilder Current();
240 /*!
241 * \brief Give a string name to the `obj`
242 * \tparam TObjectRef The type of the object to name.
243 * \param name The name to give to the object.
244 * \param obj The object to name.
245 */
246 template <class TObjectRef>
247 inline static TObjectRef Name(String name, TObjectRef obj);
248};
249
250////////////////////////////// Details //////////////////////////////
251
252namespace details {
253
254class Namer {
255 public:
256 using FType = NodeFunctor<void(const ObjectRef&, String)>;
257 static FType& vtable();
258 static void Name(ObjectRef node, String name);
259};
260
261} // namespace details
262
263template <class TObjectRef>
264inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) {
265 details::Namer::Name(obj, name);
266 return Downcast<TObjectRef>(obj);
267}
268
269template <typename TFrame>
270inline Optional<TFrame> IRBuilderNode::FindFrame() const {
271 using TFrameNode = typename TFrame::ContainerType;
272 for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
273 if (const TFrameNode* p = (*it).template as<TFrameNode>()) {
274 return GetRef<TFrame>(p);
275 }
276 }
277 return NullOpt;
278}
279
280template <typename TFrame>
281inline Optional<TFrame> IRBuilderNode::GetLastFrame() const {
282 using TFrameNode = typename TFrame::ContainerType;
283 if (!frames.empty() && frames.back()->IsInstance<TFrameNode>()) {
284 return Downcast<TFrame>(frames.back());
285 }
286 return NullOpt;
287}
288
289template <typename TObjectRef>
290inline TObjectRef IRBuilderNode::Get() const {
291 using TObject = typename TObjectRef::ContainerType;
292 CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet";
293 const auto* n = result.as<TObject>();
294 CHECK(n != nullptr) << "TypeError: IRBuilder result is not of type: " << TObject::_type_key;
295 return GetRef<TObjectRef>(n);
296}
297
298} // namespace ir_builder
299} // namespace script
300} // namespace tvm
301
302#endif // TVM_SCRIPT_IR_BUILDER_BASE_H_
303