1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file expr_functor.cc
21 */
22#include <tvm/tir/expr_functor.h>
23
24#include "functor_common.h"
25
26namespace tvm {
27namespace tir {
28
29void ExprVisitor::VisitExpr_(const VarNode* op) {}
30
31void ExprVisitor::VisitExpr_(const SizeVarNode* op) {
32 this->VisitExpr_(static_cast<const VarNode*>(op));
33}
34
35void ExprVisitor::VisitExpr_(const AnyNode* op) {}
36
37void ExprVisitor::VisitExpr_(const LoadNode* op) {
38 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
39}
40
41void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
42 VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
43}
44
45void ExprVisitor::VisitExpr_(const ProducerLoadNode* op) {
46 VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
47}
48
49void ExprVisitor::VisitExpr_(const LetNode* op) {
50 this->VisitExpr(op->value);
51 this->VisitExpr(op->body);
52}
53
54void ExprVisitor::VisitExpr_(const CallNode* op) {
55 VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); });
56}
57
58#define DEFINE_BINOP_VISIT_(OP) \
59 void ExprVisitor::VisitExpr_(const OP* op) { \
60 this->VisitExpr(op->a); \
61 this->VisitExpr(op->b); \
62 }
63
64DEFINE_BINOP_VISIT_(AddNode);
65DEFINE_BINOP_VISIT_(SubNode);
66DEFINE_BINOP_VISIT_(MulNode);
67DEFINE_BINOP_VISIT_(DivNode);
68DEFINE_BINOP_VISIT_(ModNode);
69DEFINE_BINOP_VISIT_(FloorDivNode);
70DEFINE_BINOP_VISIT_(FloorModNode);
71DEFINE_BINOP_VISIT_(MinNode);
72DEFINE_BINOP_VISIT_(MaxNode);
73DEFINE_BINOP_VISIT_(EQNode);
74DEFINE_BINOP_VISIT_(NENode);
75DEFINE_BINOP_VISIT_(LTNode);
76DEFINE_BINOP_VISIT_(LENode);
77DEFINE_BINOP_VISIT_(GTNode);
78DEFINE_BINOP_VISIT_(GENode);
79DEFINE_BINOP_VISIT_(AndNode);
80DEFINE_BINOP_VISIT_(OrNode);
81
82void ExprVisitor::VisitExpr_(const IntImmNode* op) {}
83void ExprVisitor::VisitExpr_(const FloatImmNode* op) {}
84void ExprVisitor::VisitExpr_(const StringImmNode* op) {}
85
86void ExprVisitor::VisitExpr_(const ReduceNode* op) {
87 VisitArray(op->axis, [this](const IterVar& r) {
88 this->VisitExpr(r->dom->min);
89 this->VisitExpr(r->dom->extent);
90 });
91 VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); });
92 if (!op->init.empty()) {
93 VisitArray(op->init, [this](const PrimExpr& e) { this->VisitExpr(e); });
94 }
95 this->VisitExpr(op->condition);
96}
97
98void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->value); }
99
100void ExprVisitor::VisitExpr_(const NotNode* op) { this->VisitExpr(op->a); }
101
102void ExprVisitor::VisitExpr_(const SelectNode* op) {
103 this->VisitExpr(op->condition);
104 this->VisitExpr(op->true_value);
105 this->VisitExpr(op->false_value);
106}
107
108void ExprVisitor::VisitExpr_(const RampNode* op) {
109 this->VisitExpr(op->base);
110 this->VisitExpr(op->stride);
111}
112
113void ExprVisitor::VisitExpr_(const ShuffleNode* op) {
114 VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
115 VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); });
116}
117
118void ExprVisitor::VisitExpr_(const BroadcastNode* op) { this->VisitExpr(op->value); }
119
120PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef<PrimExpr>(op); }
121
122PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) {
123 return this->VisitExpr_(static_cast<const VarNode*>(op));
124}
125
126PrimExpr ExprMutator::VisitExpr_(const AnyNode* op) { return GetRef<PrimExpr>(op); }
127
128PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
129 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
130}
131
132PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
133 auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
134 Array<PrimExpr> indices = op->indices.Map(fmutate);
135 if (indices.same_as(op->indices)) {
136 return GetRef<PrimExpr>(op);
137 } else {
138 return BufferLoad(op->buffer, indices);
139 }
140}
141
142PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) {
143 auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
144 Array<PrimExpr> indices = op->indices.Map(fmutate);
145 if (indices.same_as(op->indices)) {
146 return GetRef<PrimExpr>(op);
147 } else {
148 return ProducerLoad(op->producer, indices);
149 }
150}
151
152PrimExpr ExprMutator::VisitExpr_(const LetNode* op) {
153 PrimExpr value = this->VisitExpr(op->value);
154 PrimExpr body = this->VisitExpr(op->body);
155 if (value.same_as(op->value) && body.same_as(op->body)) {
156 return GetRef<PrimExpr>(op);
157 } else {
158 return Let(op->var, value, body);
159 }
160}
161
162PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
163 auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
164 Array<PrimExpr> args = op->args.Map(fmutate);
165
166 if (args.same_as(op->args)) {
167 return GetRef<PrimExpr>(op);
168 } else {
169 return Call(op->dtype, op->op, args);
170 }
171}
172
173#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
174 PrimExpr ExprMutator::VisitExpr_(const OP* op) { return GetRef<PrimExpr>(op); }
175
176DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode)
177DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode)
178DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode)
179
180#define DEFINE_BIOP_EXPR_MUTATE_(OP) \
181 PrimExpr ExprMutator::VisitExpr_(const OP##Node* op) { \
182 PrimExpr a = this->VisitExpr(op->a); \
183 PrimExpr b = this->VisitExpr(op->b); \
184 if (a.same_as(op->a) && b.same_as(op->b)) { \
185 return GetRef<PrimExpr>(op); \
186 } else { \
187 return OP(a, b); \
188 } \
189 }
190
191DEFINE_BIOP_EXPR_MUTATE_(Add);
192DEFINE_BIOP_EXPR_MUTATE_(Sub);
193DEFINE_BIOP_EXPR_MUTATE_(Mul);
194DEFINE_BIOP_EXPR_MUTATE_(Div);
195DEFINE_BIOP_EXPR_MUTATE_(Mod);
196DEFINE_BIOP_EXPR_MUTATE_(FloorDiv);
197DEFINE_BIOP_EXPR_MUTATE_(FloorMod);
198DEFINE_BIOP_EXPR_MUTATE_(Min);
199DEFINE_BIOP_EXPR_MUTATE_(Max);
200DEFINE_BIOP_EXPR_MUTATE_(EQ);
201DEFINE_BIOP_EXPR_MUTATE_(NE);
202DEFINE_BIOP_EXPR_MUTATE_(LT);
203DEFINE_BIOP_EXPR_MUTATE_(LE);
204DEFINE_BIOP_EXPR_MUTATE_(GT);
205DEFINE_BIOP_EXPR_MUTATE_(GE);
206DEFINE_BIOP_EXPR_MUTATE_(And);
207DEFINE_BIOP_EXPR_MUTATE_(Or);
208
209PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) {
210 auto fitervar = [this](const IterVar& v) {
211 Range r = v->dom;
212 PrimExpr min = this->VisitExpr(r->min);
213 PrimExpr extent = this->VisitExpr(r->extent);
214 if (min.same_as(r->min) && extent.same_as(r->extent)) {
215 return v;
216 } else {
217 return IterVar(Range::FromMinExtent(min, extent), v->var, v->iter_type, v->thread_tag);
218 }
219 };
220 Array<IterVar> axis = op->axis.Map(fitervar);
221
222 auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
223 Array<PrimExpr> source = op->source.Map(fexpr);
224 Array<PrimExpr> init = op->init.Map(fexpr);
225
226 PrimExpr condition = this->VisitExpr(op->condition);
227
228 if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition) &&
229 init.same_as(op->init)) {
230 return GetRef<PrimExpr>(op);
231 } else {
232 return Reduce(op->combiner, source, axis, condition, op->value_index, init);
233 }
234}
235
236PrimExpr ExprMutator::VisitExpr_(const CastNode* op) {
237 PrimExpr value = this->VisitExpr(op->value);
238 if (value.same_as(op->value)) {
239 return GetRef<PrimExpr>(op);
240 } else {
241 return Cast(op->dtype, value);
242 }
243}
244
245PrimExpr ExprMutator::VisitExpr_(const NotNode* op) {
246 PrimExpr a = this->VisitExpr(op->a);
247 if (a.same_as(op->a)) {
248 return GetRef<PrimExpr>(op);
249 } else {
250 return Not(a);
251 }
252}
253
254PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) {
255 PrimExpr condition = this->VisitExpr(op->condition);
256 PrimExpr true_value = this->VisitExpr(op->true_value);
257 PrimExpr false_value = this->VisitExpr(op->false_value);
258 if (condition.same_as(op->condition) && true_value.same_as(op->true_value) &&
259 false_value.same_as(op->false_value)) {
260 return GetRef<PrimExpr>(op);
261 } else {
262 return Select(condition, true_value, false_value);
263 }
264}
265
266PrimExpr ExprMutator::VisitExpr_(const RampNode* op) {
267 PrimExpr base = this->VisitExpr(op->base);
268 PrimExpr stride = this->VisitExpr(op->stride);
269 if (base.same_as(op->base) && stride.same_as(op->stride)) {
270 return GetRef<PrimExpr>(op);
271 } else {
272 return Ramp(base, stride, op->lanes);
273 }
274}
275
276PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) {
277 PrimExpr value = this->VisitExpr(op->value);
278 if (value.same_as(op->value)) {
279 return GetRef<PrimExpr>(op);
280 } else {
281 return Broadcast(value, op->lanes);
282 }
283}
284
285PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) {
286 auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
287 auto vectors = op->vectors.Map(fexpr);
288 if (vectors.same_as(op->vectors)) {
289 return GetRef<PrimExpr>(op);
290 } else {
291 return Shuffle(vectors, op->indices);
292 }
293}
294
295} // namespace tir
296} // namespace tvm
297