1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/tir/var.h
22 * \brief Variables in the TIR.
23 */
24#ifndef TVM_TIR_VAR_H_
25#define TVM_TIR_VAR_H_
26
#include <tvm/ir/expr.h>
#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>

#include <string>
#include <utility>
32
33namespace tvm {
34namespace tir {
35
36/*!
37 * \brief A variable node in the IR.
38 *
39 * A variable is uniquely identified by its address.
40 *
41 * Each variable is only bound once in the following nodes:
42 * - Allocate
43 * - For
44 * - Let
45 * - LetStmt
46 */
47class VarNode : public PrimExprNode {
48 public:
49 /*!
50 * \brief The hint to the variable name.
51 * \note Each variable is uniquely identified by its address.
52 */
53 String name_hint;
54 /*!
55 * \brief type annotation of the variable.
56 *
57 * It is an optional field that provides a refined type of the variable than dtype.
58 *
59 * \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type.
60 */
61 Type type_annotation;
62
63 void VisitAttrs(AttrVisitor* v) {
64 v->Visit("dtype", &dtype);
65 v->Visit("name", &name_hint);
66 v->Visit("type_annotation", &type_annotation);
67 v->Visit("span", &span);
68 }
69
70 bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
71 if (!equal(dtype, other->dtype)) return false;
72 if (!equal(type_annotation, other->type_annotation)) return false;
73 return equal.FreeVarEqualImpl(this, other);
74 }
75
76 void SHashReduce(SHashReducer hash_reduce) const {
77 hash_reduce(dtype);
78 hash_reduce(type_annotation);
79 hash_reduce.FreeVarHashImpl(this);
80 }
81
82 static constexpr const char* _type_key = "tir.Var";
83 static constexpr const uint32_t _type_child_slots = 1;
84 TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
85};
86
87/*! \brief a named variable in TIR */
88class Var : public PrimExpr {
89 public:
90 explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
91 /*!
92 * \brief Constructor
93 * \param name_hint variable name
94 * \param dtype data type
95 * \param span The location of this object in the source code.
96 */
97 TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32),
98 Span span = Span());
99 /*!
100 * \brief Constructor which provides a more detailed type annotation.
101 * \param name_hint variable name.
102 * \param type_annotation The type annotation.
103 * \param span The location of this object in the source code.
104 */
105 TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span());
106 /*!
107 * \brief Make a new copy of var with same type, append suffix
108 * \param suffix The suffix to be appended.
109 * \return the new Var copy
110 */
111 TVM_DLL Var copy_with_suffix(const String& suffix) const;
112 /*!
113 * \brief Make a new copy of the variable with specified dtype
114 * \param dtype The specified dtype
115 * \return The new variable
116 */
117 TVM_DLL Var copy_with_dtype(DataType dtype) const;
118
119 /*!
120 * \brief Get pointer to the internal value.
121 * \return the corresponding Variable.
122 */
123 const VarNode* operator->() const { return get(); }
124 /*!
125 * \brief Get pointer to the internal value.
126 * \return the corresponding Variable.
127 */
128 const VarNode* get() const { return static_cast<const VarNode*>(data_.get()); }
129 /*! \brief type indicate the container type */
130 using ContainerType = VarNode;
131};
132
133/*!
134 * \brief A variable node represent a tensor index size,
135 * whose value must be non-negative.
136 */
137class SizeVarNode : public VarNode {
138 public:
139 static constexpr const char* _type_key = "tir.SizeVar";
140 TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
141};
142
143/*! \brief a named variable represents a tensor index size */
144class SizeVar : public Var {
145 public:
146 explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
147 /*!
148 * \brief constructor
149 * \param name_hint variable name
150 * \param t data type
151 * \param span The location of this object in the source code.
152 */
153 TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32),
154 Span span = Span());
155 /*!
156 * \brief Get pointer to the internal value.
157 * \return the corresponding Variable.
158 */
159 const SizeVarNode* operator->() const { return get(); }
160 /*!
161 * \brief Get pointer to the internal value.
162 * \return the corresponding Variable.
163 */
164 const SizeVarNode* get() const { return static_cast<const SizeVarNode*>(data_.get()); }
165 /*! \brief type indicate the container type */
166 using ContainerType = SizeVarNode;
167};
168
/*! \brief A region, represented as an array of Range objects. */
using Region = Array<Range>;
170
/*!
 * \brief Type of iteration variable.
 * Each IterVar has a specific type.
 *
 * The type of an iter var can be overridden via
 * stage.iter_var_attrs, provided the types are compatible.
 */
enum IterVarType : int {
  /*!
   * \brief Data parallel iteration.
   * This normally corresponds to an axis of a Tensor.
   * Allows all IterVar manipulations.
   *
   * \note This does not mean the loop
   *  has to be executed in a parallel fashion.
   */
  kDataPar = 0,
  /*!
   * \brief The IterVar itself is a thread-index
   *  of a fixed thread launching group.
   *  Note that this is already assumed to be parallelized.
   *
   *  Disallow: split/fuse/vectorize/parallel
   */
  kThreadIndex = 1,
  /*!
   * \brief Commutative reduction.
   *  Cannot be directly parallelized.
   *
   *  Disallow: parallel/vectorize
   */
  kCommReduce = 2,
  /*!
   * \brief Serial loops with loop carry dependency;
   *  the iterations must execute in order.
   *  Cannot be re-ordered.
   *
   *  Disallow: reorder/parallel/vectorize
   */
  kOrdered = 3,
  /*!
   * \brief IterVar is opaque.
   *
   *  It may not correspond to any generated loop.
   *  Disallow all IterVar manipulations and compute_at.
   *
   * \note This is usually used to implement composite ops
   *  or external ops.
   */
  kOpaque = 4,
  // The following are possible additional
  // types that are provided during schedule.
  /*!
   * \brief The execution is unrolled.
   */
  kUnrolled = 5,
  /*!
   * \brief The loop is vectorized.
   */
  kVectorized = 6,
  /*!
   * \brief The loop is parallelized.
   */
  kParallelized = 7,
  /*!
   * \brief Marks the boundary of a tensorization intrinsic.
   */
  kTensorized = 8
};
240
241/*!
242 * \brief An iteration variable representing an iteration
243 * over a one dimensional interval.
244 *
245 * The dtype of the extent of the `dom` of the IterVar must match the dtype of the internal Var.
246 */
247class IterVarNode : public Object {
248 public:
249 /*!
250 * \brief the domain of iteration, if known, can be None
251 * For the intermediate schedule node, before schedule.
252 */
253 Range dom;
254 /*! \brief The looping variable */
255 Var var;
256 /*! \brief The type of the IterVar */
257 IterVarType iter_type;
258 /*!
259 * \brief additional tag on the iteration variable,
260 * set this if this is binded already to a known thread tag.
261 */
262 String thread_tag;
263 /*!
264 * \brief Span that points to the original source code.
265 * Reserved debug information.
266 */
267 mutable Span span;
268
269 void VisitAttrs(AttrVisitor* v) {
270 v->Visit("dom", &dom);
271 v->Visit("var", &var);
272 v->Visit("iter_type", &iter_type);
273 v->Visit("thread_tag", &thread_tag);
274 v->Visit("span", &span);
275 }
276
277 bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
278 return equal(dom, other->dom) && equal.DefEqual(var, other->var) &&
279 equal(iter_type, other->iter_type) && equal(thread_tag, other->thread_tag);
280 }
281
282 void SHashReduce(SHashReducer hash_reduce) const {
283 hash_reduce(dom);
284 hash_reduce.DefHash(var);
285 hash_reduce(iter_type);
286 hash_reduce(thread_tag);
287 }
288
289 static constexpr const char* _type_key = "tir.IterVar";
290 static constexpr const bool _type_has_method_sequal_reduce = true;
291 static constexpr const bool _type_has_method_shash_reduce = true;
292 TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
293};
294
295/*!
296 * \brief Iteration Variable,
297 * represents an iteration over an integer interval.
298 *
299 * The dtype of the extent of the `dom` of the IterVar must match the dtype of the internal Var.
300 */
301class IterVar : public ObjectRef {
302 public:
303 TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = "",
304 Span span = Span());
305 /*!
306 * \return the corresponding var in the IterVar.
307 */
308 inline operator PrimExpr() const;
309
310 TVM_DEFINE_OBJECT_REF_METHODS(IterVar, ObjectRef, IterVarNode);
311 TVM_DEFINE_OBJECT_REF_COW_METHOD(IterVarNode);
312};
313
314// inline implementations
315inline IterVar::operator PrimExpr() const { return (*this)->var; }
316
317inline const char* IterVarType2String(IterVarType t) {
318 switch (t) {
319 case kDataPar:
320 return "DataPar";
321 case kThreadIndex:
322 return "ThreadIndex";
323 case kCommReduce:
324 return "CommReduce";
325 case kOrdered:
326 return "Ordered";
327 case kOpaque:
328 return "Opaque";
329 case kUnrolled:
330 return "Unrolled";
331 case kVectorized:
332 return "Vectorized";
333 case kParallelized:
334 return "Parallelized";
335 case kTensorized:
336 return "Tensorized";
337 }
338 return "Unknown";
339}
340} // namespace tir
341} // namespace tvm
342#endif // TVM_TIR_VAR_H_
343