1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file ir_utils.h |
22 | * \brief Helper functions to construct and compose IR nodes. |
23 | */ |
24 | #ifndef TVM_TIR_TRANSFORMS_IR_UTILS_H_ |
25 | #define TVM_TIR_TRANSFORMS_IR_UTILS_H_ |
26 | |
27 | #include <tvm/arith/int_set.h> |
28 | #include <tvm/runtime/device_api.h> |
29 | #include <tvm/support/with.h> |
30 | #include <tvm/tir/builtin.h> |
31 | #include <tvm/tir/expr.h> |
32 | #include <tvm/tir/function.h> |
33 | #include <tvm/tir/op.h> |
34 | |
35 | #include <limits> |
36 | #include <string> |
37 | #include <unordered_map> |
38 | #include <utility> |
39 | #include <vector> |
40 | |
41 | namespace tvm { |
42 | namespace tir { |
/*!
 * \brief Combine a nest of statements whose bodies are not yet defined.
 * \param nest A list of For and LetStmt nodes whose bodies are not defined.
 * \param body The body to attach to the nest (presumably placed at the innermost level —
 *        confirm against the definition in ir_utils.cc).
 * \return The combined Stmt.
 */
Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body);

/*!
 * \brief Combine a nest of statements whose bodies are not yet defined.
 *
 * Overload taking the nest as a list of lists of statements.
 *
 * \param nest A list of lists of For and LetStmt nodes whose bodies are not defined.
 * \param body The body to attach to the nest (presumably placed at the innermost level —
 *        confirm against the definition in ir_utils.cc).
 * \return The combined Stmt.
 */
Stmt MergeNest(const std::vector<std::vector<Stmt>>& nest, Stmt body);
58 | |
59 | /*! |
60 | * \brief update array with an unary function |
61 | * \param arr array |
62 | * \param fupdate an unary function |
63 | * \tparam T type of array element |
64 | * \tparam F type of the unary function |
65 | * \return if update happens, return the new array, else return the |
66 | * original array |
67 | */ |
68 | template <typename T, typename F> |
69 | inline Array<T> UpdateArray(Array<T> arr, F fupdate) { |
70 | std::vector<T> new_arr(arr.size()); |
71 | bool changed = false; |
72 | for (size_t i = 0; i < arr.size(); ++i) { |
73 | T old_elem = arr[i]; |
74 | T new_elem = fupdate(old_elem); |
75 | if (!new_elem.same_as(old_elem)) changed = true; |
76 | new_arr[i] = new_elem; |
77 | } |
78 | if (!changed) { |
79 | return arr; |
80 | } else { |
81 | return Array<T>(new_arr); |
82 | } |
83 | } |
84 | |
85 | /*! |
86 | * \brief Get construct from struct |
87 | * \param dtype The data type. |
88 | * \param handle the struct handle. |
89 | * \param index the offset index. |
90 | * \param kind The data kind. |
91 | * \return the get expression. |
92 | */ |
93 | inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, |
94 | builtin::TVMStructFieldKind kind) { |
95 | Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index), |
96 | make_const(DataType::Int(32), static_cast<int>(kind))}; |
97 | return Call(dtype, builtin::tvm_struct_get(), args); |
98 | } |
99 | |
100 | /*! |
101 | * \brief Address of handle + offset |
102 | * \param handle the array handle. |
103 | * \param dtype The data type. |
104 | * \param offset the offset index. |
105 | */ |
106 | inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { |
107 | PrimExpr offset_expr = make_const(DataType::Int(32), offset * dtype.lanes()); |
108 | Buffer dummy_buf(handle, dtype, {offset_expr + 1}, {}, 0, handle->name_hint, 0, 0, kDefault); |
109 | BufferLoad buf_load(dummy_buf, {offset_expr}); |
110 | |
111 | return Call(DataType::Handle(), builtin::address_of(), {buf_load}); |
112 | } |
113 | |
114 | /*! |
115 | * \brief Address of handle + offset |
116 | * \param handle the array handle. |
117 | * \param dtype The data type. |
118 | * \param offset the offset index. |
119 | */ |
120 | inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { |
121 | if (dtype.lanes() != 1) { |
122 | offset = offset * make_const(offset.dtype(), dtype.lanes()); |
123 | offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); |
124 | } |
125 | |
126 | Buffer dummy_buf(handle, dtype.element_of(), {offset + 1}, {}, 0, handle->name_hint, 0, 0, |
127 | kDefault); |
128 | BufferLoad buf_load(dummy_buf, {offset}); |
129 | |
130 | return Call(DataType::Handle(), builtin::address_of(), {buf_load}); |
131 | } |
132 | |
133 | /*! |
134 | * \brief Set value into struct. |
135 | * \param handle the struct handle. |
136 | * \param index the offset index. |
137 | * \param kind The data kind. |
138 | * \param value The value to be set. |
139 | * \return the set stmt. |
140 | */ |
141 | inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind, PrimExpr value) { |
142 | Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index), |
143 | make_const(DataType::Int(32), static_cast<int>(kind)), value}; |
144 | return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args)); |
145 | } |
146 | |
147 | /*! |
148 | * \brief Get the type that is passed around TVM PackedFunc API. |
149 | * \param t The original type. |
150 | * \return The corresponding API type. |
151 | */ |
152 | inline DataType APIType(DataType t) { |
153 | if (t.is_handle()) return t; |
154 | ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API." ; |
155 | if (t.is_uint() || t.is_int()) return DataType::Int(64); |
156 | ICHECK(t.is_float()); |
157 | return DataType::Float(64); |
158 | } |
159 | |
160 | /*! |
161 | * \brief Rule to get allocation alignment requirement for a given const array. |
162 | * \param type The type of allocation. |
163 | * \param const_size The constant size of the array. |
164 | * \return the alignment |
165 | */ |
166 | inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { |
167 | int align = runtime::kTempAllocaAlignment; |
168 | if (const_size > 0) { |
169 | int64_t const_s = static_cast<int64_t>(const_size) * type.bits() * type.lanes() / 8; |
170 | while (align > const_s) { |
171 | align = align / 2; |
172 | } |
173 | } |
174 | return align; |
175 | } |
176 | |
177 | /*! |
178 | * \brief Create an int32 constant |
179 | * \param index the value of the constant |
180 | * \return the PrimExpr that represents the constant |
181 | */ |
182 | inline PrimExpr ConstInt32(size_t index) { |
183 | ICHECK_LE(index, std::numeric_limits<int>::max()); |
184 | return make_const(DataType::Int(32), static_cast<int>(index)); |
185 | } |
186 | |
187 | /*! |
188 | * \brief Allocate TVMValues on the stack |
189 | * \param type type of allocation |
190 | * \param num number of TVMValues to allocate |
191 | * \return PrimExpr representing the TVMValue |
192 | */ |
193 | inline PrimExpr StackAlloca(std::string type, size_t num) { |
194 | Array<PrimExpr> args = {StringImm(type), ConstInt32(num)}; |
195 | return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args); |
196 | } |
197 | |
/*!
 * \brief Convert an IR statement into SSA form.
 * \param stmt The source statement to be converted.
 * \return The converted statement.
 */
Stmt ConvertSSA(Stmt stmt);
204 | |
/*!
 * \brief Return the storage scope associated with a buffer variable.
 * \param buffer_var The input buffer variable.
 * \return A string representing the storage scope of this buffer variable.
 */
String GetPtrStorageScope(Var buffer_var);
211 | |
/*!
 * \brief Convert access indices on a match buffer's target buffer into indices on the
 *        original (source) buffer.
 * \param match_buffer The MatchBufferRegion relating the target buffer to the source buffer.
 * \param indices The indices into the target buffer.
 * \return The corresponding indices into the source buffer.
 */
Array<PrimExpr> ConvertIndices(const MatchBufferRegion& match_buffer,
                               const Array<PrimExpr>& indices);
219 | |
/*!
 * \brief Convert a region of a match buffer's target buffer into the corresponding region
 *        of the original (source) buffer.
 * \param match_buffer The MatchBufferRegion relating the target buffer to the source buffer.
 * \param region The sub-region of the target buffer.
 * \return The corresponding region of the source buffer.
 */
Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region);
226 | |
/*!
 * \brief Check whether a given PrimFunc originated from a TE schedule.
 *
 * Internally this checks for the `from_legacy_te_schedule` attr of the PrimFunc.
 *
 * \param f The PrimFunc to check.
 * \return Whether or not the PrimFunc was created from a TE schedule.
 */
Bool IsFromLegacyTESchedule(PrimFunc f);
236 | |
237 | /*! |
238 | *\brief Context helper to update domain map within conditional scope. |
239 | * |
240 | * Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is |
241 | * [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &relax_map, &hint_map, true)` step |
242 | *into scope where dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(condition, |
243 | *&relax_map, &hint_map, false)` step into scope where dom_map[i] is [9, 20] |
244 | */ |
245 | class ConditionalBoundsContext { |
246 | private: |
247 | friend class With<ConditionalBoundsContext>; |
248 | /*! |
249 | * \brief Construct a condition bounds context. |
250 | * \param condition The condition holds on true branch. |
251 | * \param relax_map The domain map for relaxed vars to update. |
252 | * \param hint_map The domain map for free vars to update. |
253 | * \param is_true_branch Whether step into the branch where condition bounds holds. |
254 | */ |
255 | ConditionalBoundsContext(const PrimExpr& condition, |
256 | std::unordered_map<const VarNode*, arith::IntSet>* relax_map, |
257 | std::unordered_map<const VarNode*, arith::IntSet>* hint_map, |
258 | bool is_true_branch); |
259 | void EnterWithScope(); |
260 | void ExitWithScope(); |
261 | |
262 | /*! \brief Helper to solve related variable's bound within conditional scope.*/ |
263 | Map<Var, Range> GetVarBoundsFromCondition(); |
264 | |
265 | /*! \brief the condition holds on true branch. */ |
266 | const PrimExpr& condition_; |
267 | /*! \brief domain map for relaxed vars to update */ |
268 | std::unordered_map<const VarNode*, arith::IntSet>* relax_map_; |
269 | /*! \brief domain map for free vars to update */ |
270 | std::unordered_map<const VarNode*, arith::IntSet>* hint_map_; |
271 | /*! \brief whether is on true branch */ |
272 | bool is_true_branch_; |
273 | /*! \brief used to record and restore original var bounds */ |
274 | std::unordered_map<const VarNode*, arith::IntSet> origin_map_; |
275 | }; |
276 | |
277 | // Information of tensor core fragment. |
278 | struct FragmentInfo { |
279 | // fragment shape |
280 | int m, n, k; |
281 | // fragment layout (row-major or column-major) |
282 | std::string layout; |
283 | // scope of the fragment (wmma.matrix_a, wmma.matrix_b, or wmma.accumulator) |
284 | std::string scope; |
285 | FragmentInfo() = default; |
286 | FragmentInfo(int _m, int _n, int _k, const std::string& _layout, const std::string& _scope) |
287 | : m(_m), n(_n), k(_k), layout(_layout), scope(_scope) {} |
288 | |
289 | int GetSize() const { |
290 | if (scope == "wmma.matrix_a" ) { |
291 | return m * k; |
292 | } else if (scope == "wmma.matrix_b" ) { |
293 | return n * k; |
294 | } else if (scope == "wmma.accumulator" ) { |
295 | return m * n; |
296 | } else { |
297 | ICHECK(0); |
298 | throw; |
299 | } |
300 | } |
301 | }; |
302 | |
/*!
 * \brief Extract information about tensor core fragments from the IR.
 * \param stmt The stmt to visit.
 * \return A map from buffer variables to their fragment info.
 */
std::unordered_map<const VarNode*, FragmentInfo> GetTensorCoreFragmentInfo(const Stmt& stmt);
309 | |
/*!
 * \brief Return the queue id and the in-flight count associated with the given
 *        attr::async_wait_queue_scope annotation.
 * \param op The AttrStmt carrying the async_wait_queue_scope annotation.
 * \return The (queue id, in-flight count) pair. Presumably `first` is the queue id and
 *         `second` the in-flight count, matching the order above — confirm in ir_utils.cc.
 */
std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op);
313 | |
/*!
 * \brief Bind a subset of parameter tensors to constants, replacing them with AllocateConst
 *        nodes.
 * \param f The function to bind constants to.
 * \param constants Raw constant data. If this array has N elements, the last N parameter
 *        tensors are removed from the signature, and AllocateConst nodes are introduced in the
 *        function body instead.
 * \return The updated function.
 */
PrimFunc BindParams(PrimFunc f, const Array<runtime::NDArray>& constants);
323 | |
324 | } // namespace tir |
325 | } // namespace tvm |
326 | #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ |
327 | |