1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/tir/function.h
22 * \brief TIR Function.
23 */
24#ifndef TVM_TIR_FUNCTION_H_
25#define TVM_TIR_FUNCTION_H_
26
27#include <tvm/ir/function.h>
28#include <tvm/runtime/ndarray.h>
29#include <tvm/tir/buffer.h>
30#include <tvm/tir/expr.h>
31#include <tvm/tir/stmt.h>
32
33#include <string>
34
35namespace tvm {
36namespace tir {
37
/*!
 * \brief Primitive functions that contain TIR statements.
 *
 * The PrimFunc provides a low-level code representation: its function
 * type is derived directly from its parameter Vars (without type
 * inference), and buffer/shape constraints on parameters are expressed
 * explicitly through buffer_map rather than being managed automatically.
 *
 * \sa PrimFunc
 */
class PrimFuncNode : public BaseFuncNode {
 public:
  /*! \brief Function parameters */
  Array<tir::Var> params;
  /*! \brief The body of the function */
  tir::Stmt body;
  /*! \brief The return type of the function. */
  Type ret_type;
  /*!
   * \brief Maps some parameters to specific Buffer data structures.
   *
   * buffer_map provides a way to express data structure's field and shape
   * constraints. The provided information is used in the program analysis
   * and the code generation.
   *
   * - It defines the vars in the Buffer (m, n) in the cases below when
   *   they appear in the buffer_map for the first time.
   * - When a var appears multiple times, it translates into a runtime
   *   assertion to check the field constraint.
   *
   * \code
   *
   *  # The corresponding fields of f are as follows
   *  #
   *  # - f.params = [a, b]
   *  # - f.buffer_map = {a: A, b: B}
   *  # - A = decl_buffer(shape=[m, n])
   *  # - B = decl_buffer(shape=[m, n])
   *
   *  def f(a, b):
   *      m, n = var(), var()
   *      A = bind_buffer(a, shape=[m, n])
   *      B = bind_buffer(b, shape=[m, n])
   *      # body
   *
   * \endcode
   *
   * buffer_map is a sugar to express:
   * - Parameter unpacking: e.g. we can load a.shape[0] to get the value of m.
   * - Constraint checking: a.shape[0] must equal b.shape[0] because they
   *   both correspond to m.
   *
   * While we could have expressed parameter unpacking and constraints using
   * normal statements, making buffer_map a first-class citizen of PrimFunc
   * will make program analysis much easier.
   *
   * Prior to buffer flattening, which is performed either in
   * StorageFlatten for TE-based schedules or in FlattenBuffer for
   * TIR-based schedules, these buffer objects are used directly in
   * the body of the function.  After buffer flattening, these buffer
   * objects remain unflattened for use in argument validation, but
   * all usage in the body of the function is done through a
   * flattened alias of the buffer.
   */
  Map<tir::Var, Buffer> buffer_map;

  // Visit the node's fields for reflection.
  // NOTE(review): field order here presumably defines the serialized
  // layout — keep it stable; confirm against the reflection machinery.
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("params", &params);
    v->Visit("body", &body);
    v->Visit("ret_type", &ret_type);
    v->Visit("buffer_map", &buffer_map);
    v->Visit("attrs", &attrs);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  // Structural equality: compare all semantic fields of two PrimFuncNodes.
  bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
    // visit params and buffer_map first as they contain defs.
    return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) &&
           equal(ret_type, other->ret_type) && equal(body, other->body) &&
           equal(attrs, other->attrs);
  }

  // Structural hash: must fold the same fields, in the same def-first
  // order, as SEqualReduce above so equal functions hash equally.
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce.DefHash(params);
    hash_reduce(buffer_map);
    hash_reduce(ret_type);
    hash_reduce(body);
    hash_reduce(attrs);
  }
  /*!
   * \brief Return the derived function annotation of this function.
   *
   * \return The function type annotation.
   * \note The function type annotation of a PrimFunc is
   *  directly derived from the Vars without the need of type inference.
   */
  TVM_DLL FuncType func_type_annotation() const;

  TVM_OBJECT_ENABLE_SCRIPT_PRINTER();

  static constexpr const char* _type_key = "tir.PrimFunc";
  TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
};
140
141/*!
142 * \brief Managed reference to PrimFuncNode.
143 * \sa PrimFuncNode
144 */
class PrimFunc : public BaseFunc {
 public:
  /*!
   * \brief Constructor
   *
   * \param params The parameters of the function.
   *
   * \param body The body of the function.
   *
   * \param ret_type The return type of the function. Defaults to void.
   *
   * \param buffer_map The buffer map for parameter buffer unpacking.
   * This contains buffer objects as they appear in the body of the
   * PrimFunc.  (e.g. a buffer of shape ``[1024]`` originally
   * generated as a tensor of shape ``[32, 32]``)
   *
   * \param attrs Additional function attributes.
   *
   * \param span The location of this object in the source code.
   */
  TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
                   Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
                   DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
  // Copy-on-write access to the underlying PrimFuncNode.
  TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
};
172
173/*!
174 * \brief Tensor intrinsics for tensorization
175 */
class TensorIntrinNode : public Object {
 public:
  /*! \brief The function to describe the computation. */
  PrimFunc desc;
  /*! \brief The function of the implementation for the execution. */
  PrimFunc impl;

  // Visit the node's fields for reflection.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("desc", &desc);
    v->Visit("impl", &impl);
  }

  static constexpr const char* _type_key = "tir.TensorIntrin";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
};
191
192/*!
193 * \brief Managed reference to TensorIntrinNode.
194 */
class TensorIntrin : public ObjectRef {
 public:
  /*!
   * \brief Constructor
   * \param desc The function to describe the computation.
   * \param impl The function of the implementation for the execution.
   */
  TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl);

  /*!
   * \brief Create and register a TensorIntrin. After registration, the TensorIntrin can be looked
   * up with its name.
   * \param name The name of the TensorIntrin to register.
   * \param intrin The TensorIntrin to register.
   * \param override Whether to override an existing intrinsic with the same name.
   * \throws This method throws an exception if a TensorIntrin with the specified name already
   * exists and \p override is false.
   */
  TVM_DLL static void Register(String name, TensorIntrin intrin, bool override = false);

  /*!
   * \brief Look up a TensorIntrin by name.
   * \param name The name of the TensorIntrin.
   * \param allow_missing Whether to allow a missing tensor intrin. If false, an exception is
   * raised if the tensor intrin is not found.
   * \return The TensorIntrin with the specified name, or NullOpt if it is missing and
   * \p allow_missing is true.
   * \throws This method throws an exception if the TensorIntrin does not exist and allow_missing
   * is false.
   */
  TVM_DLL static Optional<TensorIntrin> Get(String name, bool allow_missing = false);

  TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode)
};
228
/*!
 * \brief Specialize parameters of a PrimFunc.
 * \param func The PrimFunc to be specialized.
 * \param param_map The mapping from function params to the instances
 * (concrete buffers or constant values) to substitute.
 * \return The new function with parameters specialized.
 * \note We can define a Meta TIR function with symbolic shape:
 *
 * \code{.py}
 *  @T.prim_func
 *  def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
 *      A = T.match_buffer(a, (m, n), "float32")
 *      B = T.match_buffer(b, (m, n), "float32")
 *      for i, j in T.grid(m, n):
 *          with T.block():
 *              vi, vj = T.axis.remap("SS", [i, j])
 *              B[vi, vj] = A[vi, vj]
 * \endcode
 *
 * Then we can make it specialized with given shapes or buffers.
 *
 * \code{.py}
 *  a, _, m, n = mem_copy.params
 *  func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
 *  # or
 *  func = mem_copy.specialize({n: 16, m: 16})
 * \endcode
 *
 * The resulting specialized function is:
 *
 * \code{.py}
 *  @T.prim_func
 *  def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
 *      A = T.match_buffer(a, (16, 16), "float32")
 *      B = T.match_buffer(b, (16, 16), "float32")
 *      for i, j in T.grid(16, 16):
 *          with T.block():
 *              vi, vj = T.axis.remap("SS", [i, j])
 *              B[vi, vj] = A[vi, vj]
 * \endcode
 */
PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
268
269/*!
270 * \brief PrimFunc specific attribute names.
271 *
272 * \sa tvm::attr
273 */
namespace attr {
/*!
 * \brief List of thread IterVar that a DeviceLaunch function corresponds to.
 *
 * Type: Array<tir::IterVar>
 *
 * We call a device kernel launch function f using the following convention:
 *
 * Call(f,
 *      [arg1, arg2, ..., arg_n,
 *       work_size_1, work_size_2, ... work_size_m, dyn_shmem_size])
 *
 * Here n = len(arg), m = len(work_size) = len(device_thread_axis).
 *
 * When kDeviceUseDynSharedMemory is not set, the dyn_shmem_size argument is omitted.
 *
 * The list of device_thread_axis indicates how to bind the
 * work_size arguments to the corresponding threads.
 *
 * \sa tvm::CallingConv::kDeviceKernelLaunch
 */
constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";

/*!
 * \brief Whether or not to use dynamic shared memory.
 *
 * Type: Integer
 */
constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory";

/*!
 * \brief Whether to set the noalias rule on the function arguments.
 *
 * Type: Integer
 */
constexpr const char* kNoAlias = "tir.noalias";

/*!
 * \brief Mark the function as the entry function of
 *        the final generated runtime module.
 *
 * Type: Integer
 *
 * \note There can only be one entry function per module.
 */
constexpr const char* kIsEntryFunc = "tir.is_entry_func";

/*!
 * \brief Mark the function as a global function called from the host.
 *
 * Type: Integer
 */
constexpr const char* kIsGlobalFunc = "tir.is_global_func";

}  // namespace attr
329} // namespace tir
330} // namespace tvm
331#endif // TVM_TIR_FUNCTION_H_
332