1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file expr_functor.cc |
21 | */ |
22 | #include <tvm/tir/expr_functor.h> |
23 | |
24 | #include "functor_common.h" |
25 | |
26 | namespace tvm { |
27 | namespace tir { |
28 | |
29 | void ExprVisitor::VisitExpr_(const VarNode* op) {} |
30 | |
31 | void ExprVisitor::VisitExpr_(const SizeVarNode* op) { |
32 | this->VisitExpr_(static_cast<const VarNode*>(op)); |
33 | } |
34 | |
35 | void ExprVisitor::VisitExpr_(const AnyNode* op) {} |
36 | |
37 | void ExprVisitor::VisitExpr_(const LoadNode* op) { |
38 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
39 | } |
40 | |
41 | void ExprVisitor::VisitExpr_(const BufferLoadNode* op) { |
42 | VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
43 | } |
44 | |
45 | void ExprVisitor::VisitExpr_(const ProducerLoadNode* op) { |
46 | VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
47 | } |
48 | |
49 | void ExprVisitor::VisitExpr_(const LetNode* op) { |
50 | this->VisitExpr(op->value); |
51 | this->VisitExpr(op->body); |
52 | } |
53 | |
54 | void ExprVisitor::VisitExpr_(const CallNode* op) { |
55 | VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
56 | } |
57 | |
58 | #define DEFINE_BINOP_VISIT_(OP) \ |
59 | void ExprVisitor::VisitExpr_(const OP* op) { \ |
60 | this->VisitExpr(op->a); \ |
61 | this->VisitExpr(op->b); \ |
62 | } |
63 | |
64 | DEFINE_BINOP_VISIT_(AddNode); |
65 | DEFINE_BINOP_VISIT_(SubNode); |
66 | DEFINE_BINOP_VISIT_(MulNode); |
67 | DEFINE_BINOP_VISIT_(DivNode); |
68 | DEFINE_BINOP_VISIT_(ModNode); |
69 | DEFINE_BINOP_VISIT_(FloorDivNode); |
70 | DEFINE_BINOP_VISIT_(FloorModNode); |
71 | DEFINE_BINOP_VISIT_(MinNode); |
72 | DEFINE_BINOP_VISIT_(MaxNode); |
73 | DEFINE_BINOP_VISIT_(EQNode); |
74 | DEFINE_BINOP_VISIT_(NENode); |
75 | DEFINE_BINOP_VISIT_(LTNode); |
76 | DEFINE_BINOP_VISIT_(LENode); |
77 | DEFINE_BINOP_VISIT_(GTNode); |
78 | DEFINE_BINOP_VISIT_(GENode); |
79 | DEFINE_BINOP_VISIT_(AndNode); |
80 | DEFINE_BINOP_VISIT_(OrNode); |
81 | |
82 | void ExprVisitor::VisitExpr_(const IntImmNode* op) {} |
83 | void ExprVisitor::VisitExpr_(const FloatImmNode* op) {} |
84 | void ExprVisitor::VisitExpr_(const StringImmNode* op) {} |
85 | |
86 | void ExprVisitor::VisitExpr_(const ReduceNode* op) { |
87 | VisitArray(op->axis, [this](const IterVar& r) { |
88 | this->VisitExpr(r->dom->min); |
89 | this->VisitExpr(r->dom->extent); |
90 | }); |
91 | VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
92 | if (!op->init.empty()) { |
93 | VisitArray(op->init, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
94 | } |
95 | this->VisitExpr(op->condition); |
96 | } |
97 | |
98 | void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->value); } |
99 | |
100 | void ExprVisitor::VisitExpr_(const NotNode* op) { this->VisitExpr(op->a); } |
101 | |
102 | void ExprVisitor::VisitExpr_(const SelectNode* op) { |
103 | this->VisitExpr(op->condition); |
104 | this->VisitExpr(op->true_value); |
105 | this->VisitExpr(op->false_value); |
106 | } |
107 | |
108 | void ExprVisitor::VisitExpr_(const RampNode* op) { |
109 | this->VisitExpr(op->base); |
110 | this->VisitExpr(op->stride); |
111 | } |
112 | |
113 | void ExprVisitor::VisitExpr_(const ShuffleNode* op) { |
114 | VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
115 | VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
116 | } |
117 | |
118 | void ExprVisitor::VisitExpr_(const BroadcastNode* op) { this->VisitExpr(op->value); } |
119 | |
120 | PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef<PrimExpr>(op); } |
121 | |
122 | PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { |
123 | return this->VisitExpr_(static_cast<const VarNode*>(op)); |
124 | } |
125 | |
126 | PrimExpr ExprMutator::VisitExpr_(const AnyNode* op) { return GetRef<PrimExpr>(op); } |
127 | |
128 | PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) { |
129 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
130 | } |
131 | |
132 | PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { |
133 | auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; |
134 | Array<PrimExpr> indices = op->indices.Map(fmutate); |
135 | if (indices.same_as(op->indices)) { |
136 | return GetRef<PrimExpr>(op); |
137 | } else { |
138 | return BufferLoad(op->buffer, indices); |
139 | } |
140 | } |
141 | |
142 | PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) { |
143 | auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; |
144 | Array<PrimExpr> indices = op->indices.Map(fmutate); |
145 | if (indices.same_as(op->indices)) { |
146 | return GetRef<PrimExpr>(op); |
147 | } else { |
148 | return ProducerLoad(op->producer, indices); |
149 | } |
150 | } |
151 | |
152 | PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { |
153 | PrimExpr value = this->VisitExpr(op->value); |
154 | PrimExpr body = this->VisitExpr(op->body); |
155 | if (value.same_as(op->value) && body.same_as(op->body)) { |
156 | return GetRef<PrimExpr>(op); |
157 | } else { |
158 | return Let(op->var, value, body); |
159 | } |
160 | } |
161 | |
162 | PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { |
163 | auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; |
164 | Array<PrimExpr> args = op->args.Map(fmutate); |
165 | |
166 | if (args.same_as(op->args)) { |
167 | return GetRef<PrimExpr>(op); |
168 | } else { |
169 | return Call(op->dtype, op->op, args); |
170 | } |
171 | } |
172 | |
173 | #define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ |
174 | PrimExpr ExprMutator::VisitExpr_(const OP* op) { return GetRef<PrimExpr>(op); } |
175 | |
176 | DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) |
177 | DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) |
178 | DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) |
179 | |
180 | #define DEFINE_BIOP_EXPR_MUTATE_(OP) \ |
181 | PrimExpr ExprMutator::VisitExpr_(const OP##Node* op) { \ |
182 | PrimExpr a = this->VisitExpr(op->a); \ |
183 | PrimExpr b = this->VisitExpr(op->b); \ |
184 | if (a.same_as(op->a) && b.same_as(op->b)) { \ |
185 | return GetRef<PrimExpr>(op); \ |
186 | } else { \ |
187 | return OP(a, b); \ |
188 | } \ |
189 | } |
190 | |
191 | DEFINE_BIOP_EXPR_MUTATE_(Add); |
192 | DEFINE_BIOP_EXPR_MUTATE_(Sub); |
193 | DEFINE_BIOP_EXPR_MUTATE_(Mul); |
194 | DEFINE_BIOP_EXPR_MUTATE_(Div); |
195 | DEFINE_BIOP_EXPR_MUTATE_(Mod); |
196 | DEFINE_BIOP_EXPR_MUTATE_(FloorDiv); |
197 | DEFINE_BIOP_EXPR_MUTATE_(FloorMod); |
198 | DEFINE_BIOP_EXPR_MUTATE_(Min); |
199 | DEFINE_BIOP_EXPR_MUTATE_(Max); |
200 | DEFINE_BIOP_EXPR_MUTATE_(EQ); |
201 | DEFINE_BIOP_EXPR_MUTATE_(NE); |
202 | DEFINE_BIOP_EXPR_MUTATE_(LT); |
203 | DEFINE_BIOP_EXPR_MUTATE_(LE); |
204 | DEFINE_BIOP_EXPR_MUTATE_(GT); |
205 | DEFINE_BIOP_EXPR_MUTATE_(GE); |
206 | DEFINE_BIOP_EXPR_MUTATE_(And); |
207 | DEFINE_BIOP_EXPR_MUTATE_(Or); |
208 | |
209 | PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { |
210 | auto fitervar = [this](const IterVar& v) { |
211 | Range r = v->dom; |
212 | PrimExpr min = this->VisitExpr(r->min); |
213 | PrimExpr extent = this->VisitExpr(r->extent); |
214 | if (min.same_as(r->min) && extent.same_as(r->extent)) { |
215 | return v; |
216 | } else { |
217 | return IterVar(Range::FromMinExtent(min, extent), v->var, v->iter_type, v->thread_tag); |
218 | } |
219 | }; |
220 | Array<IterVar> axis = op->axis.Map(fitervar); |
221 | |
222 | auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; |
223 | Array<PrimExpr> source = op->source.Map(fexpr); |
224 | Array<PrimExpr> init = op->init.Map(fexpr); |
225 | |
226 | PrimExpr condition = this->VisitExpr(op->condition); |
227 | |
228 | if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition) && |
229 | init.same_as(op->init)) { |
230 | return GetRef<PrimExpr>(op); |
231 | } else { |
232 | return Reduce(op->combiner, source, axis, condition, op->value_index, init); |
233 | } |
234 | } |
235 | |
236 | PrimExpr ExprMutator::VisitExpr_(const CastNode* op) { |
237 | PrimExpr value = this->VisitExpr(op->value); |
238 | if (value.same_as(op->value)) { |
239 | return GetRef<PrimExpr>(op); |
240 | } else { |
241 | return Cast(op->dtype, value); |
242 | } |
243 | } |
244 | |
245 | PrimExpr ExprMutator::VisitExpr_(const NotNode* op) { |
246 | PrimExpr a = this->VisitExpr(op->a); |
247 | if (a.same_as(op->a)) { |
248 | return GetRef<PrimExpr>(op); |
249 | } else { |
250 | return Not(a); |
251 | } |
252 | } |
253 | |
254 | PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { |
255 | PrimExpr condition = this->VisitExpr(op->condition); |
256 | PrimExpr true_value = this->VisitExpr(op->true_value); |
257 | PrimExpr false_value = this->VisitExpr(op->false_value); |
258 | if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && |
259 | false_value.same_as(op->false_value)) { |
260 | return GetRef<PrimExpr>(op); |
261 | } else { |
262 | return Select(condition, true_value, false_value); |
263 | } |
264 | } |
265 | |
266 | PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { |
267 | PrimExpr base = this->VisitExpr(op->base); |
268 | PrimExpr stride = this->VisitExpr(op->stride); |
269 | if (base.same_as(op->base) && stride.same_as(op->stride)) { |
270 | return GetRef<PrimExpr>(op); |
271 | } else { |
272 | return Ramp(base, stride, op->lanes); |
273 | } |
274 | } |
275 | |
276 | PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { |
277 | PrimExpr value = this->VisitExpr(op->value); |
278 | if (value.same_as(op->value)) { |
279 | return GetRef<PrimExpr>(op); |
280 | } else { |
281 | return Broadcast(value, op->lanes); |
282 | } |
283 | } |
284 | |
285 | PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { |
286 | auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; |
287 | auto vectors = op->vectors.Map(fexpr); |
288 | if (vectors.same_as(op->vectors)) { |
289 | return GetRef<PrimExpr>(op); |
290 | } else { |
291 | return Shuffle(vectors, op->indices); |
292 | } |
293 | } |
294 | |
295 | } // namespace tir |
296 | } // namespace tvm |
297 | |