1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file ir_utils.h
22 * \brief Helper functions to construct and compose IR nodes.
23 */
24#ifndef TVM_TIR_TRANSFORMS_IR_UTILS_H_
25#define TVM_TIR_TRANSFORMS_IR_UTILS_H_
26
27#include <tvm/arith/int_set.h>
28#include <tvm/runtime/device_api.h>
29#include <tvm/support/with.h>
30#include <tvm/tir/builtin.h>
31#include <tvm/tir/expr.h>
32#include <tvm/tir/function.h>
33#include <tvm/tir/op.h>
34
35#include <limits>
36#include <string>
37#include <unordered_map>
38#include <utility>
39#include <vector>
40
41namespace tvm {
42namespace tir {
/*!
 * \brief Combine a nest of statements whose bodies are not yet defined.
 * \param nest A list of For and LetStmt, whose body is not defined.
 * \param body The body statement to combine with the nest.
 * \return The combined Stmt
 */
Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body);
50
/*!
 * \brief Combine a list of nests of statements whose bodies are not yet defined.
 * \param nest A list of nests; each inner list holds For and LetStmt whose body is not defined.
 * \param body The body statement to combine with the nests.
 * \return The combined Stmt
 */
Stmt MergeNest(const std::vector<std::vector<Stmt>>& nest, Stmt body);
58
59/*!
60 * \brief update array with an unary function
61 * \param arr array
62 * \param fupdate an unary function
63 * \tparam T type of array element
64 * \tparam F type of the unary function
65 * \return if update happens, return the new array, else return the
66 * original array
67 */
68template <typename T, typename F>
69inline Array<T> UpdateArray(Array<T> arr, F fupdate) {
70 std::vector<T> new_arr(arr.size());
71 bool changed = false;
72 for (size_t i = 0; i < arr.size(); ++i) {
73 T old_elem = arr[i];
74 T new_elem = fupdate(old_elem);
75 if (!new_elem.same_as(old_elem)) changed = true;
76 new_arr[i] = new_elem;
77 }
78 if (!changed) {
79 return arr;
80 } else {
81 return Array<T>(new_arr);
82 }
83}
84
85/*!
86 * \brief Get construct from struct
87 * \param dtype The data type.
88 * \param handle the struct handle.
89 * \param index the offset index.
90 * \param kind The data kind.
91 * \return the get expression.
92 */
93inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index,
94 builtin::TVMStructFieldKind kind) {
95 Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
96 make_const(DataType::Int(32), static_cast<int>(kind))};
97 return Call(dtype, builtin::tvm_struct_get(), args);
98}
99
100/*!
101 * \brief Address of handle + offset
102 * \param handle the array handle.
103 * \param dtype The data type.
104 * \param offset the offset index.
105 */
106inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) {
107 PrimExpr offset_expr = make_const(DataType::Int(32), offset * dtype.lanes());
108 Buffer dummy_buf(handle, dtype, {offset_expr + 1}, {}, 0, handle->name_hint, 0, 0, kDefault);
109 BufferLoad buf_load(dummy_buf, {offset_expr});
110
111 return Call(DataType::Handle(), builtin::address_of(), {buf_load});
112}
113
114/*!
115 * \brief Address of handle + offset
116 * \param handle the array handle.
117 * \param dtype The data type.
118 * \param offset the offset index.
119 */
120inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) {
121 if (dtype.lanes() != 1) {
122 offset = offset * make_const(offset.dtype(), dtype.lanes());
123 offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes());
124 }
125
126 Buffer dummy_buf(handle, dtype.element_of(), {offset + 1}, {}, 0, handle->name_hint, 0, 0,
127 kDefault);
128 BufferLoad buf_load(dummy_buf, {offset});
129
130 return Call(DataType::Handle(), builtin::address_of(), {buf_load});
131}
132
133/*!
134 * \brief Set value into struct.
135 * \param handle the struct handle.
136 * \param index the offset index.
137 * \param kind The data kind.
138 * \param value The value to be set.
139 * \return the set stmt.
140 */
141inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind, PrimExpr value) {
142 Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
143 make_const(DataType::Int(32), static_cast<int>(kind)), value};
144 return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args));
145}
146
147/*!
148 * \brief Get the type that is passed around TVM PackedFunc API.
149 * \param t The original type.
150 * \return The corresponding API type.
151 */
152inline DataType APIType(DataType t) {
153 if (t.is_handle()) return t;
154 ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API.";
155 if (t.is_uint() || t.is_int()) return DataType::Int(64);
156 ICHECK(t.is_float());
157 return DataType::Float(64);
158}
159
160/*!
161 * \brief Rule to get allocation alignment requirement for a given const array.
162 * \param type The type of allocation.
163 * \param const_size The constant size of the array.
164 * \return the alignment
165 */
166inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
167 int align = runtime::kTempAllocaAlignment;
168 if (const_size > 0) {
169 int64_t const_s = static_cast<int64_t>(const_size) * type.bits() * type.lanes() / 8;
170 while (align > const_s) {
171 align = align / 2;
172 }
173 }
174 return align;
175}
176
177/*!
178 * \brief Create an int32 constant
179 * \param index the value of the constant
180 * \return the PrimExpr that represents the constant
181 */
182inline PrimExpr ConstInt32(size_t index) {
183 ICHECK_LE(index, std::numeric_limits<int>::max());
184 return make_const(DataType::Int(32), static_cast<int>(index));
185}
186
187/*!
188 * \brief Allocate TVMValues on the stack
189 * \param type type of allocation
190 * \param num number of TVMValues to allocate
191 * \return PrimExpr representing the TVMValue
192 */
193inline PrimExpr StackAlloca(std::string type, size_t num) {
194 Array<PrimExpr> args = {StringImm(type), ConstInt32(num)};
195 return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args);
196}
197
/*!
 * \brief Convert an IR node to SSA form.
 * \param stmt The source statement to be converted.
 * \return The converted form.
 */
Stmt ConvertSSA(Stmt stmt);
204
/*!
 * \brief Return the storage scope associated with a buffer variable.
 * \param buffer_var The input buffer variable.
 * \return A string representing the storage scope of this buffer variable.
 */
String GetPtrStorageScope(Var buffer_var);
211
/*!
 * \brief Convert access indices on a match buffer's target buffer to indices on the
 *        original (source) buffer.
 * \param match_buffer The match buffer region relating the target buffer to its source.
 * \param indices The indices of the target buffer.
 * \return The indices of the source buffer.
 */
Array<PrimExpr> ConvertIndices(const MatchBufferRegion& match_buffer,
                               const Array<PrimExpr>& indices);
219
/*!
 * \brief Convert a region of a match buffer's target buffer to the corresponding region
 *        of the original (source) buffer.
 * \param match_buffer The match buffer region relating the target buffer to its source.
 * \param region The sub-region of the target buffer.
 * \return The region of the source buffer.
 */
Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region);
226
/*!
 * \brief Check if a given PrimFunc originated from a TE schedule.
 *
 * Internally this checks for the `from_legacy_te_schedule` attr of the PrimFunc.
 *
 * \param f PrimFunc to check
 * \return Whether or not the PrimFunc was created from a TE schedule
 */
Bool IsFromLegacyTESchedule(PrimFunc f);
236
237/*!
238 *\brief Context helper to update domain map within conditional scope.
239 *
240 * Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is
241 * [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &relax_map, &hint_map, true)` step
242 *into scope where dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(condition,
243 *&relax_map, &hint_map, false)` step into scope where dom_map[i] is [9, 20]
244 */
245class ConditionalBoundsContext {
246 private:
247 friend class With<ConditionalBoundsContext>;
248 /*!
249 * \brief Construct a condition bounds context.
250 * \param condition The condition holds on true branch.
251 * \param relax_map The domain map for relaxed vars to update.
252 * \param hint_map The domain map for free vars to update.
253 * \param is_true_branch Whether step into the branch where condition bounds holds.
254 */
255 ConditionalBoundsContext(const PrimExpr& condition,
256 std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
257 std::unordered_map<const VarNode*, arith::IntSet>* hint_map,
258 bool is_true_branch);
259 void EnterWithScope();
260 void ExitWithScope();
261
262 /*! \brief Helper to solve related variable's bound within conditional scope.*/
263 Map<Var, Range> GetVarBoundsFromCondition();
264
265 /*! \brief the condition holds on true branch. */
266 const PrimExpr& condition_;
267 /*! \brief domain map for relaxed vars to update */
268 std::unordered_map<const VarNode*, arith::IntSet>* relax_map_;
269 /*! \brief domain map for free vars to update */
270 std::unordered_map<const VarNode*, arith::IntSet>* hint_map_;
271 /*! \brief whether is on true branch */
272 bool is_true_branch_;
273 /*! \brief used to record and restore original var bounds */
274 std::unordered_map<const VarNode*, arith::IntSet> origin_map_;
275};
276
277// Information of tensor core fragment.
278struct FragmentInfo {
279 // fragment shape
280 int m, n, k;
281 // fragment layout (row-major or column-major)
282 std::string layout;
283 // scope of the fragment (wmma.matrix_a, wmma.matrix_b, or wmma.accumulator)
284 std::string scope;
285 FragmentInfo() = default;
286 FragmentInfo(int _m, int _n, int _k, const std::string& _layout, const std::string& _scope)
287 : m(_m), n(_n), k(_k), layout(_layout), scope(_scope) {}
288
289 int GetSize() const {
290 if (scope == "wmma.matrix_a") {
291 return m * k;
292 } else if (scope == "wmma.matrix_b") {
293 return n * k;
294 } else if (scope == "wmma.accumulator") {
295 return m * n;
296 } else {
297 ICHECK(0);
298 throw;
299 }
300 }
301};
302
/*!
 * \brief Extract information of tensor core fragments from the IR.
 * \param stmt The stmt to visit.
 * \return Map from buffer variables to the fragment info.
 */
std::unordered_map<const VarNode*, FragmentInfo> GetTensorCoreFragmentInfo(const Stmt& stmt);
309
/*!
 * \brief Return the queue id and the in-flight count associated with the given
 *        attr::async_wait_queue_scope annotation.
 * \param op The attribute statement holding the annotation.
 * \return A pair of (queue id, in-flight count).
 */
std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op);
313
/*!
 * \brief Bind a subset of parameter tensors to constants, replacing them by AllocateConst nodes.
 * \param f The function to bind constants to.
 * \param constants Raw constant data. If the size of this array is N, the last N parameter
 * tensors will be removed from the signature, and AllocateConst nodes will be introduced in
 * the function body instead.
 * \return The updated function.
 */
PrimFunc BindParams(PrimFunc f, const Array<runtime::NDArray>& constants);
323
324} // namespace tir
325} // namespace tvm
326#endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_
327