1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/function.h |
22 | * \brief TIR Function. |
23 | */ |
24 | #ifndef TVM_TIR_FUNCTION_H_ |
25 | #define TVM_TIR_FUNCTION_H_ |
26 | |
27 | #include <tvm/ir/function.h> |
28 | #include <tvm/runtime/ndarray.h> |
29 | #include <tvm/tir/buffer.h> |
30 | #include <tvm/tir/expr.h> |
31 | #include <tvm/tir/stmt.h> |
32 | |
33 | #include <string> |
34 | |
35 | namespace tvm { |
36 | namespace tir { |
37 | |
38 | /*! |
39 | * \brief Primitive function that contains TIR statements. |
40 | * |
41 | * The PrimFunc provides a low-level code representation. It does not |
42 | * automatically manage ... (TODO: the original sentence breaks off here). |
43 | * |
44 | * \sa PrimFunc |
45 | */ |
46 | class PrimFuncNode : public BaseFuncNode { |
47 | public: |
48 | /*! \brief Function parameters */ |
49 | Array<tir::Var> params; |
50 | /*! \brief The body of the function */ |
51 | tir::Stmt body; |
52 | /*! \brief The return type of the function. */ |
53 | Type ret_type; |
54 | /*! |
55 | * \brief Maps some parameters to specific Buffer data structures. |
56 | * |
57 | * buffer_map provides a way to express a data structure's field and shape |
58 | * constraints. The provided information is used in the program analysis |
59 | * and the code generation. |
60 | * |
61 | * - It defines the vars in the Buffer (m, n) in the cases below when |
62 | * they appear in the buffer_map for the first time. |
63 | * - When a var appears multiple times, it translates into a runtime |
64 | * assertion that checks the field constraint. |
65 | * |
66 | * \code |
67 | * |
68 | * # The corresponding fields of f are as follows |
69 | * # |
70 | * # - f.params = [a, b] |
71 | * # - f.buffer_map = {a: A, b: B} |
72 | * # - A = decl_buffer(shape=[m, n]) |
73 | * # - B = decl_buffer(shape=[m, n]) |
74 | * |
75 | * def f(a, b): |
76 | * m, n = var(), var() |
77 | * A = bind_buffer(a, shape=[m, n]) |
78 | * B = bind_buffer(b, shape=[m, n]) |
79 | * # body |
80 | * |
81 | * \endcode |
82 | * |
83 | * buffer_map is a sugar to express: |
84 | * - Parameter unpacking: e.g. a.shape[0] can be loaded to get the value of m |
85 | * - Constraint checking: a.shape[0] must equal b.shape[0] because they |
86 | * both correspond to m. |
87 | * |
88 | * While we could have expressed parameter unpacking and constraint checking |
89 | * using normal statements, making buffer_map a first-class citizen of PrimFunc |
90 | * makes program analysis much easier. |
91 | * |
92 | * Prior to buffer flattening, which is performed either in |
93 | * StorageFlatten for TE-based schedules or in FlattenBuffer for |
94 | * TIR-based schedules, these buffer objects are used directly in |
95 | * the body of the function. After buffer flattening, these buffer |
96 | * objects remain unflattened for use in argument validation, but |
97 | * all usage in the body of the function is done through a |
98 | * flattened alias of the buffer. |
99 | */ |
100 | Map<tir::Var, Buffer> buffer_map; |
101 | /*! \brief Passes every reflected field to the visitor together with its string key. */ |
102 | void VisitAttrs(tvm::AttrVisitor* v) { |
103 | v->Visit("params", &params); |
104 | v->Visit("body", &body); |
105 | v->Visit("ret_type", &ret_type); |
106 | v->Visit("buffer_map", &buffer_map); |
107 | v->Visit("attrs", &attrs); |
108 | v->Visit("span", &span); |
109 | v->Visit("_checked_type_", &checked_type_); |
110 | } |
111 | /*! \brief Structural equality with another PrimFuncNode; stops at the first field that differs. */ |
112 | bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { |
113 | // Visit params and buffer_map first, as they contain defs. |
114 | return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) && |
115 | equal(ret_type, other->ret_type) && equal(body, other->body) && |
116 | equal(attrs, other->attrs); |
117 | } |
118 | /*! \brief Structural hash; feeds the same fields, in the same order, as SEqualReduce. */ |
119 | void SHashReduce(SHashReducer hash_reduce) const { |
120 | hash_reduce.DefHash(params); |
121 | hash_reduce(buffer_map); |
122 | hash_reduce(ret_type); |
123 | hash_reduce(body); |
124 | hash_reduce(attrs); |
125 | } |
126 | /*! |
127 | * \brief Return the derived function annotation of this function. |
128 | * |
129 | * \return The function type annotation. |
130 | * \note The function type annotation of a PrimFunc is |
131 | * directly derived from the Vars without the need of type inference. |
132 | */ |
133 | TVM_DLL FuncType func_type_annotation() const; |
134 | |
135 | TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); |
136 | |
137 | static constexpr const char* _type_key = "tir.PrimFunc"; |
138 | TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode); |
139 | }; |
140 | |
141 | /*! |
142 | * \brief Managed reference to PrimFuncNode. |
143 | * \sa PrimFuncNode |
144 | */ |
145 | class PrimFunc : public BaseFunc { |
146 | public: |
147 | /*! |
148 | * \brief Constructor |
149 | * |
150 | * \param params The parameters of the function. |
151 | * |
152 | * \param body The body of the function. |
153 | * |
154 | * \param ret_type The return type of the function. Defaults to VoidType(). |
155 | * |
156 | * \param buffer_map The buffer map for parameter buffer unpacking. |
157 | * This contains buffer objects as they appear in the body of the |
158 | * PrimFunc (e.g. a buffer of shape ``[1024]`` originally |
159 | * generated as a tensor of shape ``[32, 32]``). Defaults to an empty map. |
160 | * |
161 | * \param attrs Additional function attributes. Defaults to NullValue<DictAttrs>(). |
162 | * |
163 | * \param span The location of this object in the source code. Defaults to Span(). |
164 | */ |
165 | TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(), |
166 | Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(), |
167 | DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span()); |
168 | /* The remaining members of the reference type are generated by the two macros below. */ |
169 | TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); |
170 | TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode); |
171 | }; |
172 | |
173 | /*! |
174 | * \brief Tensor intrinsic for tensorization: a description function paired with its implementation. |
175 | */ |
176 | class TensorIntrinNode : public Object { |
177 | public: |
178 | /*! \brief The function to describe the computation. */ |
179 | PrimFunc desc; |
180 | /*! \brief The function of the implementation for the execution. */ |
181 | PrimFunc impl; |
182 | /*! \brief Passes desc and impl to the visitor under the keys "desc" and "impl". */ |
183 | void VisitAttrs(AttrVisitor* v) { |
184 | v->Visit("desc" , &desc); |
185 | v->Visit("impl" , &impl); |
186 | } |
187 | |
188 | static constexpr const char* _type_key = "tir.TensorIntrin" ; |
189 | TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); |
190 | }; |
191 | |
192 | /*! |
193 | * \brief Managed reference to TensorIntrinNode. |
194 | */ |
195 | class TensorIntrin : public ObjectRef { |
196 | public: |
197 | /*! |
198 | * \brief Constructor |
199 | * \param desc The function to describe the computation. |
200 | * \param impl The function of the implementation for the execution. |
201 | */ |
202 | TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl); |
203 | |
204 | /*! |
205 | * \brief Create and register a TensorIntrin. After registration, the TensorIntrin can be looked |
206 | * up with its name. |
207 | * \param name The name of the TensorIntrin to register. |
208 | * \param intrin The TensorIntrin to register. |
209 | * \param override Whether to override an existing intrinsic of the same name. Defaults to false. |
210 | * \throws This method throws an exception if the TensorIntrin with the specified name already |
211 | * exists (presumably only when override is false; confirm against the implementation). |
212 | */ |
213 | TVM_DLL static void Register(String name, TensorIntrin intrin, bool override = false); |
214 | |
215 | /*! |
216 | * \brief Look up TensorIntrin by name. Raises an exception if not found, unless allow_missing is true. |
217 | * \param name The name of the TensorIntrin. |
218 | * \param allow_missing Whether to allow missing tensor intrin. If false, an exception is raised |
219 | * if the tensor intrin is not found. Defaults to false. |
220 | * \return The TensorIntrin with the specified name (presumably an empty Optional if it is missing and allow_missing is true). |
221 | * \throws This method throws an exception if the TensorIntrin does not exist and allow_missing is |
222 | * false. |
223 | */ |
224 | TVM_DLL static Optional<TensorIntrin> Get(String name, bool allow_missing = false); |
225 | |
226 | TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) |
227 | }; |
228 | |
229 | /*! |
230 | * \brief Specialize parameters of PrimFunc. |
231 | * \param func The PrimFunc to be specialized. |
232 | * \param param_map The mapping from function params to the instances (a buffer or a constant in the examples below). |
233 | * \return The new function with the parameters specialized. |
234 | * \note We can define a Meta TIR function with symbolic shape: |
235 | * |
236 | * \code{.py} |
237 | * @T.prim_func |
238 | * def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: |
239 | * A = T.match_buffer(a, (m, n), "float32") |
240 | * B = T.match_buffer(b, (m, n), "float32") |
241 | * for i, j in T.grid(m, n): |
242 | * with T.block(): |
243 | * vi, vj = T.axis.remap("SS", [i, j]) |
244 | * B[vi, vj] = A[vi, vj] |
245 | * \endcode |
246 | * |
247 | * Then we can specialize it with given shapes or buffers. |
248 | * |
249 | * \code{.py} |
250 | * a, _, m, n = mem_copy.params |
251 | * func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) |
252 | * # or |
253 | * func = mem_copy.specialize({n: 16, m: 16}) |
254 | * \endcode |
255 | * The specialized function is then expected to be equivalent to: |
256 | * \code{.py} |
257 | * @T.prim_func |
258 | * def mem_copy_16_16(a: T.handle, b: T.handle) -> None: |
259 | * A = T.match_buffer(a, (16, 16), "float32") |
260 | * B = T.match_buffer(b, (16, 16), "float32") |
261 | * for i, j in T.grid(16, 16): |
262 | * with T.block(): |
263 | * vi, vj = T.axis.remap("SS", [i, j]) |
264 | * B[vi, vj] = A[vi, vj] |
265 | * \endcode |
266 | */ |
267 | PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map); |
268 | |
269 | /*! |
270 | * \brief PrimFunc specific attribute names. |
271 | * |
272 | * \sa tvm::attr |
273 | */ |
274 | namespace attr { |
275 | /*! |
276 | * \brief List of thread IterVars that a DeviceLaunch function corresponds to. |
277 | * |
278 | * Type: Array<tir::IterVar> |
279 | * |
280 | * We call a device kernel launch function f using the following convention: |
281 | * |
282 | * Call(f, |
283 | * [arg1, arg2, ..., arg_n, |
284 | * work_size_1, work_size_2, ... work_size_m, dyn_shmem_size]) |
285 | * |
286 | * Here n = len(arg), m = len(work_size) = len(device_thread_axis). |
287 | * |
288 | * When kDeviceUseDynSharedMemory is not set, the dyn_shmem_size argument is omitted. |
289 | * |
290 | * The list of device_thread_axis indicates how the work_size |
291 | * arguments are bound to the corresponding threads. |
292 | * |
293 | * \sa tvm::CallingConv::kDeviceKernelLaunch |
294 | */ |
295 | constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis" ; |
296 | |
297 | /*! |
298 | * \brief Whether or not to use dynamic shared memory. |
299 | * |
300 | * Type: Integer |
301 | */ |
302 | constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory" ; |
303 | |
304 | /*! |
305 | * \brief Whether to set the noalias rule on the function arguments. |
306 | * |
307 | * Type: Integer |
308 | */ |
309 | constexpr const char* kNoAlias = "tir.noalias" ; |
310 | |
311 | /*! |
312 | * \brief Mark the function as the entry function of |
313 | * the final generated runtime module. |
314 | * |
315 | * Type: Integer |
316 | * |
317 | * \note There can only be one entry function per module. |
318 | */ |
319 | constexpr const char* kIsEntryFunc = "tir.is_entry_func" ; |
320 | |
321 | /*! |
322 | * \brief Mark the function as a global function called from the host. |
323 | * |
324 | * Type: Integer |
325 | */ |
326 | constexpr const char* kIsGlobalFunc = "tir.is_global_func" ; |
327 | |
328 | } // namespace attr |
329 | } // namespace tir |
330 | } // namespace tvm |
331 | #endif // TVM_TIR_FUNCTION_H_ |
332 | |