1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/var.h |
22 | * \brief Variables in the TIR. |
23 | */ |
24 | #ifndef TVM_TIR_VAR_H_ |
25 | #define TVM_TIR_VAR_H_ |
26 | |
27 | #include <tvm/ir/expr.h> |
28 | #include <tvm/node/node.h> |
29 | #include <tvm/runtime/data_type.h> |
30 | |
31 | #include <string> |
32 | |
33 | namespace tvm { |
34 | namespace tir { |
35 | |
36 | /*! |
37 | * \brief A variable node in the IR. |
38 | * |
39 | * A variable is uniquely identified by its address. |
40 | * |
41 | * Each variable is only bound once in the following nodes: |
42 | * - Allocate |
43 | * - For |
44 | * - Let |
45 | * - LetStmt |
46 | */ |
47 | class VarNode : public PrimExprNode { |
48 | public: |
49 | /*! |
50 | * \brief The hint to the variable name. |
51 | * \note Each variable is uniquely identified by its address. |
52 | */ |
53 | String name_hint; |
54 | /*! |
55 | * \brief type annotation of the variable. |
56 | * |
57 | * It is an optional field that provides a refined type of the variable than dtype. |
58 | * |
59 | * \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type. |
60 | */ |
61 | Type type_annotation; |
62 | |
63 | void VisitAttrs(AttrVisitor* v) { |
64 | v->Visit("dtype" , &dtype); |
65 | v->Visit("name" , &name_hint); |
66 | v->Visit("type_annotation" , &type_annotation); |
67 | v->Visit("span" , &span); |
68 | } |
69 | |
70 | bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { |
71 | if (!equal(dtype, other->dtype)) return false; |
72 | if (!equal(type_annotation, other->type_annotation)) return false; |
73 | return equal.FreeVarEqualImpl(this, other); |
74 | } |
75 | |
76 | void SHashReduce(SHashReducer hash_reduce) const { |
77 | hash_reduce(dtype); |
78 | hash_reduce(type_annotation); |
79 | hash_reduce.FreeVarHashImpl(this); |
80 | } |
81 | |
82 | static constexpr const char* _type_key = "tir.Var" ; |
83 | static constexpr const uint32_t _type_child_slots = 1; |
84 | TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); |
85 | }; |
86 | |
87 | /*! \brief a named variable in TIR */ |
88 | class Var : public PrimExpr { |
89 | public: |
90 | explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {} |
91 | /*! |
92 | * \brief Constructor |
93 | * \param name_hint variable name |
94 | * \param dtype data type |
95 | * \param span The location of this object in the source code. |
96 | */ |
97 | TVM_DLL explicit Var(String name_hint = "v" , DataType dtype = DataType::Int(32), |
98 | Span span = Span()); |
99 | /*! |
100 | * \brief Constructor which provides a more detailed type annotation. |
101 | * \param name_hint variable name. |
102 | * \param type_annotation The type annotation. |
103 | * \param span The location of this object in the source code. |
104 | */ |
105 | TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span()); |
106 | /*! |
107 | * \brief Make a new copy of var with same type, append suffix |
108 | * \param suffix The suffix to be appended. |
109 | * \return the new Var copy |
110 | */ |
111 | TVM_DLL Var copy_with_suffix(const String& suffix) const; |
112 | /*! |
113 | * \brief Make a new copy of the variable with specified dtype |
114 | * \param dtype The specified dtype |
115 | * \return The new variable |
116 | */ |
117 | TVM_DLL Var copy_with_dtype(DataType dtype) const; |
118 | |
119 | /*! |
120 | * \brief Get pointer to the internal value. |
121 | * \return the corresponding Variable. |
122 | */ |
123 | const VarNode* operator->() const { return get(); } |
124 | /*! |
125 | * \brief Get pointer to the internal value. |
126 | * \return the corresponding Variable. |
127 | */ |
128 | const VarNode* get() const { return static_cast<const VarNode*>(data_.get()); } |
129 | /*! \brief type indicate the container type */ |
130 | using ContainerType = VarNode; |
131 | }; |
132 | |
133 | /*! |
134 | * \brief A variable node represent a tensor index size, |
135 | * whose value must be non-negative. |
136 | */ |
137 | class SizeVarNode : public VarNode { |
138 | public: |
139 | static constexpr const char* _type_key = "tir.SizeVar" ; |
140 | TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode); |
141 | }; |
142 | |
143 | /*! \brief a named variable represents a tensor index size */ |
144 | class SizeVar : public Var { |
145 | public: |
146 | explicit SizeVar(ObjectPtr<Object> n) : Var(n) {} |
147 | /*! |
148 | * \brief constructor |
149 | * \param name_hint variable name |
150 | * \param t data type |
151 | * \param span The location of this object in the source code. |
152 | */ |
153 | TVM_DLL explicit SizeVar(String name_hint = "s" , DataType t = DataType::Int(32), |
154 | Span span = Span()); |
155 | /*! |
156 | * \brief Get pointer to the internal value. |
157 | * \return the corresponding Variable. |
158 | */ |
159 | const SizeVarNode* operator->() const { return get(); } |
160 | /*! |
161 | * \brief Get pointer to the internal value. |
162 | * \return the corresponding Variable. |
163 | */ |
164 | const SizeVarNode* get() const { return static_cast<const SizeVarNode*>(data_.get()); } |
165 | /*! \brief type indicate the container type */ |
166 | using ContainerType = SizeVarNode; |
167 | }; |
168 | |
169 | using Region = Array<Range>; |
170 | |
/*!
 * \brief Type of an iteration variable.
 *
 * Each IterVar has a specific type. The type of an iter var can be
 * overridden via stage.iter_var_attrs, provided the types are compatible.
 */
enum IterVarType : int {
  /*!
   * \brief Data parallel iteration.
   *
   * This normally corresponds to an axis of a Tensor.
   * Allows all IterVar manipulations.
   *
   * \note This does not mean the loop has to be executed in a parallel fashion.
   */
  kDataPar = 0,
  /*!
   * \brief The IterVar itself is a thread index of a fixed thread launching group.
   *
   * Note that this is already assumed to be parallelized.
   *
   * Disallow: split/fuse/vectorize/parallel
   */
  kThreadIndex = 1,
  /*!
   * \brief Commutative reduction.
   *
   * Cannot be directly parallelized.
   *
   * Disallow: parallel/vectorize
   */
  kCommReduce = 2,
  /*!
   * \brief Serial loop with a loop-carried dependency.
   *
   * The iterations must execute in order, so the loop cannot be reordered.
   *
   * Disallow: reorder/parallel/vectorize
   */
  kOrdered = 3,
  /*!
   * \brief The IterVar is opaque.
   *
   * It may not correspond to any generated loop.
   * Disallow all IterVar manipulations and compute_at.
   *
   * \note This is usually used to implement composite ops or external ops.
   */
  kOpaque = 4,
  // The following are possible additional types that are provided during scheduling.
  /*!
   * \brief The execution is unrolled.
   */
  kUnrolled = 5,
  /*!
   * \brief The loop is vectorized.
   */
  kVectorized = 6,
  /*!
   * \brief The loop is parallelized.
   */
  kParallelized = 7,
  /*!
   * \brief Marks the boundary of a tensorization intrinsic.
   */
  kTensorized = 8
};
240 | |
241 | /*! |
242 | * \brief An iteration variable representing an iteration |
243 | * over a one dimensional interval. |
244 | * |
245 | * The dtype of the extent of the `dom` of the IterVar must match the dtype of the internal Var. |
246 | */ |
247 | class IterVarNode : public Object { |
248 | public: |
249 | /*! |
250 | * \brief the domain of iteration, if known, can be None |
251 | * For the intermediate schedule node, before schedule. |
252 | */ |
253 | Range dom; |
254 | /*! \brief The looping variable */ |
255 | Var var; |
256 | /*! \brief The type of the IterVar */ |
257 | IterVarType iter_type; |
258 | /*! |
259 | * \brief additional tag on the iteration variable, |
260 | * set this if this is binded already to a known thread tag. |
261 | */ |
262 | String thread_tag; |
263 | /*! |
264 | * \brief Span that points to the original source code. |
265 | * Reserved debug information. |
266 | */ |
267 | mutable Span span; |
268 | |
269 | void VisitAttrs(AttrVisitor* v) { |
270 | v->Visit("dom" , &dom); |
271 | v->Visit("var" , &var); |
272 | v->Visit("iter_type" , &iter_type); |
273 | v->Visit("thread_tag" , &thread_tag); |
274 | v->Visit("span" , &span); |
275 | } |
276 | |
277 | bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const { |
278 | return equal(dom, other->dom) && equal.DefEqual(var, other->var) && |
279 | equal(iter_type, other->iter_type) && equal(thread_tag, other->thread_tag); |
280 | } |
281 | |
282 | void SHashReduce(SHashReducer hash_reduce) const { |
283 | hash_reduce(dom); |
284 | hash_reduce.DefHash(var); |
285 | hash_reduce(iter_type); |
286 | hash_reduce(thread_tag); |
287 | } |
288 | |
289 | static constexpr const char* _type_key = "tir.IterVar" ; |
290 | static constexpr const bool _type_has_method_sequal_reduce = true; |
291 | static constexpr const bool _type_has_method_shash_reduce = true; |
292 | TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); |
293 | }; |
294 | |
295 | /*! |
296 | * \brief Iteration Variable, |
297 | * represents an iteration over an integer interval. |
298 | * |
299 | * The dtype of the extent of the `dom` of the IterVar must match the dtype of the internal Var. |
300 | */ |
301 | class IterVar : public ObjectRef { |
302 | public: |
303 | TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = "" , |
304 | Span span = Span()); |
305 | /*! |
306 | * \return the corresponding var in the IterVar. |
307 | */ |
308 | inline operator PrimExpr() const; |
309 | |
310 | TVM_DEFINE_OBJECT_REF_METHODS(IterVar, ObjectRef, IterVarNode); |
311 | TVM_DEFINE_OBJECT_REF_COW_METHOD(IterVarNode); |
312 | }; |
313 | |
314 | // inline implementations |
315 | inline IterVar::operator PrimExpr() const { return (*this)->var; } |
316 | |
317 | inline const char* IterVarType2String(IterVarType t) { |
318 | switch (t) { |
319 | case kDataPar: |
320 | return "DataPar" ; |
321 | case kThreadIndex: |
322 | return "ThreadIndex" ; |
323 | case kCommReduce: |
324 | return "CommReduce" ; |
325 | case kOrdered: |
326 | return "Ordered" ; |
327 | case kOpaque: |
328 | return "Opaque" ; |
329 | case kUnrolled: |
330 | return "Unrolled" ; |
331 | case kVectorized: |
332 | return "Vectorized" ; |
333 | case kParallelized: |
334 | return "Parallelized" ; |
335 | case kTensorized: |
336 | return "Tensorized" ; |
337 | } |
338 | return "Unknown" ; |
339 | } |
340 | } // namespace tir |
341 | } // namespace tvm |
342 | #endif // TVM_TIR_VAR_H_ |
343 | |