1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file vectorize_loop.cc |
22 | */ |
23 | // Loop vectorizer as in Halide pipeline. |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/analysis.h> |
27 | #include <tvm/tir/builtin.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/op.h> |
30 | #include <tvm/tir/op_attr_types.h> |
31 | #include <tvm/tir/stmt_functor.h> |
32 | #include <tvm/tir/transform.h> |
33 | |
34 | #include <unordered_map> |
35 | #include <unordered_set> |
36 | #include <vector> |
37 | |
38 | namespace tvm { |
39 | namespace tir { |
40 | |
41 | inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { |
42 | if (e.dtype().lanes() == lanes) return e; |
43 | if (const BroadcastNode* op = e.as<BroadcastNode>()) { |
44 | if (lanes % op->lanes == 0) { |
45 | return Broadcast(op->value, lanes); |
46 | } |
47 | } |
48 | ICHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to " |
49 | << lanes; |
50 | return Broadcast(e, lanes); |
51 | } |
52 | |
53 | // Rewrite vectorized allocation access |
54 | // This is necessary for making each vector component containing its own workspace. |
55 | // Originates from Halide's loop vectorizer |
56 | // |
57 | // s[i] = s[i * lanes + var] |
58 | // |
59 | // The same principle applies when using one thread to simulate multiple context. |
60 | // |
61 | class VecAllocAccess : public StmtExprMutator { |
62 | public: |
63 | VecAllocAccess(const VarNode* buf, Var var, int var_lanes) |
64 | : buf_(buf), var_(var), var_lanes_(var_lanes) {} |
65 | |
66 | PrimExpr VisitExpr_(const LoadNode* op) final { |
67 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
68 | } |
69 | |
70 | Stmt VisitStmt_(const StoreNode* op) final { |
71 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
72 | } |
73 | |
74 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
75 | auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
76 | return UpdateBufferAccess(load); |
77 | } |
78 | |
79 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
80 | auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
81 | return UpdateBufferAccess(store); |
82 | } |
83 | |
84 | private: |
85 | template <typename Node> |
86 | Node UpdateBufferAccess(Node node) { |
87 | // Only update the buffer that's being replaced. |
88 | if (node->buffer->data.get() != buf_) { |
89 | return node; |
90 | } |
91 | |
92 | // Find/make a Buffer object with the correct updated shape. |
93 | Buffer buf; |
94 | auto it = buffer_map_.find(node->buffer.get()); |
95 | if (it != buffer_map_.end()) { |
96 | buf = it->second; |
97 | } else { |
98 | // Extend the least significant dimension by a factor of |
99 | // var_lanes_. Typically, this will be a 1-d index into a flat |
100 | // memory space. |
101 | Array<PrimExpr> shape = node->buffer->shape; |
102 | shape.Set(shape.size() - 1, analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_)); |
103 | |
104 | // TODO(Lunderberg): Move this pass to be prior to |
105 | // StorageFlatten/FlattenBuffer, implement by appending a |
106 | // dimension to the buffer. Since it is currently after the |
107 | // flattening, the strides are not technically necessary, but |
108 | // are updated for consistency. |
109 | |
110 | // Update strides if defined. |
111 | Array<PrimExpr> strides; |
112 | for (size_t i = 0; i < strides.size(); i++) { |
113 | PrimExpr stride = strides[i]; |
114 | if (i != strides.size() - 1) { |
115 | stride *= var_lanes_; |
116 | } |
117 | strides.push_back(analyzer_.Simplify(stride)); |
118 | } |
119 | |
120 | // Copy everything into the new buffer. |
121 | buf = node->buffer; |
122 | auto buf_writer = buf.CopyOnWrite(); |
123 | buf_writer->shape = shape; |
124 | buf_writer->strides = strides; |
125 | buffer_map_[buf.get()] = buf; |
126 | } |
127 | |
128 | // Extend the last index by the number of lanes in the vectorized |
129 | // variable. |
130 | Array<PrimExpr> indices = node->indices; |
131 | indices.Set(indices.size() - 1, |
132 | analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_)); |
133 | |
134 | auto writer = node.CopyOnWrite(); |
135 | writer->buffer = buf; |
136 | writer->indices = indices; |
137 | return node; |
138 | } |
139 | |
140 | // buffer var |
141 | const VarNode* buf_; |
142 | // Updated buffer objects. |
143 | std::unordered_map<const BufferNode*, Buffer> buffer_map_; |
144 | // variable to be replaced |
145 | Var var_; |
146 | // the lanes. |
147 | int var_lanes_; |
148 | // Analyzer for simplifications |
149 | arith::Analyzer analyzer_; |
150 | }; |
151 | |
152 | // We use ExprFunctor directly instead of StmtExprMutator |
153 | // This is because the transformation can change the dtype of the Expr |
154 | // The existing ExprMutator transformation rules may not be well defined. |
155 | class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> { |
156 | public: |
157 | using ExprFunctor::VisitExpr; |
158 | using StmtMutator::operator(); |
159 | |
160 | Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) { |
161 | ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes); |
162 | } |
163 | |
164 | Stmt VisitStmt(const Stmt& stmt) final { |
165 | ICHECK(!need_scalarize_); |
166 | Stmt ret = StmtMutator::VisitStmt(stmt); |
167 | if (need_scalarize_) { |
168 | need_scalarize_ = false; |
169 | return Scalarize(stmt); |
170 | } else { |
171 | return ret; |
172 | } |
173 | } |
174 | |
175 | PrimExpr VisitExpr(const PrimExpr& e) final { return ExprFunctor::VisitExpr(e); } |
176 | |
177 | PrimExpr VisitExpr_(const AddNode* op) final { |
178 | return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; }); |
179 | } |
180 | |
181 | PrimExpr VisitExpr_(const SubNode* op) final { |
182 | return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; }); |
183 | } |
184 | |
185 | PrimExpr VisitExpr_(const MulNode* op) final { |
186 | PrimExpr a = this->VisitExpr(op->a); |
187 | PrimExpr b = this->VisitExpr(op->b); |
188 | if (a.same_as(op->a) && b.same_as(op->b)) { |
189 | return GetRef<PrimExpr>(op); |
190 | } else { |
191 | int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); |
192 | if (lanes != 1) { |
193 | const RampNode* b_ramp = b.as<RampNode>(); |
194 | const RampNode* a_ramp = a.as<RampNode>(); |
195 | if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) { |
196 | return Ramp(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); |
197 | } |
198 | if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) { |
199 | return Ramp(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); |
200 | } |
201 | } |
202 | return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); |
203 | } |
204 | return BinaryVec<Mul>(op); |
205 | } |
206 | PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec<Div>(op); } |
207 | PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec<Mod>(op); } |
208 | PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec<FloorDiv>(op); } |
209 | PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec<FloorMod>(op); } |
210 | PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec<Min>(op); } |
211 | PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec<Max>(op); } |
212 | PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec<EQ>(op); } |
213 | PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec<NE>(op); } |
214 | PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec<LT>(op); } |
215 | PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec<LE>(op); } |
216 | PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec<GT>(op); } |
217 | PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec<GE>(op); } |
218 | PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec<And>(op); } |
219 | PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec<Or>(op); } |
220 | |
221 | PrimExpr VisitExpr_(const NotNode* op) final { |
222 | PrimExpr a = this->VisitExpr(op->a); |
223 | if (a.same_as(op->a)) { |
224 | return GetRef<PrimExpr>(op); |
225 | } else { |
226 | return !(a); |
227 | } |
228 | } |
229 | |
230 | PrimExpr VisitExpr_(const RampNode* op) final { |
231 | PrimExpr base = this->VisitExpr(op->base); |
232 | PrimExpr stride = this->VisitExpr(op->stride); |
233 | if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) { |
234 | const RampNode* base_ramp = base.as<RampNode>(); |
235 | if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) { |
236 | return Ramp(base_ramp->base, stride, op->lanes * base_ramp->lanes); |
237 | } |
238 | } |
239 | int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes()); |
240 | base = BroadcastTo(base, lanes); |
241 | stride = BroadcastTo(stride, lanes); |
242 | Array<PrimExpr> elems; |
243 | for (int i = 0; i < lanes; ++i) { |
244 | elems.push_back( |
245 | Ramp(Shuffle::ExtractElement(base, i), Shuffle::ExtractElement(stride, i), op->lanes)); |
246 | } |
247 | return Shuffle::Concat(elems); |
248 | } |
249 | |
250 | PrimExpr VisitExpr_(const BroadcastNode* op) final { |
251 | PrimExpr value = this->VisitExpr(op->value); |
252 | if (value.dtype().lanes() != 1) { |
253 | need_scalarize_ = true; |
254 | return GetRef<PrimExpr>(op); |
255 | } |
256 | if (value.same_as(op->value)) { |
257 | return GetRef<PrimExpr>(op); |
258 | } else { |
259 | return Broadcast(op->value, op->lanes); |
260 | } |
261 | } |
262 | |
263 | PrimExpr VisitExpr_(const SelectNode* op) final { |
264 | PrimExpr cond = this->VisitExpr(op->condition); |
265 | PrimExpr t = this->VisitExpr(op->true_value); |
266 | PrimExpr f = this->VisitExpr(op->false_value); |
267 | if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { |
268 | return GetRef<PrimExpr>(op); |
269 | } else { |
270 | int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes()); |
271 | return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); |
272 | } |
273 | } |
274 | PrimExpr VisitExpr_(const CastNode* op) final { |
275 | PrimExpr value = this->VisitExpr(op->value); |
276 | if (value.same_as(op->value)) { |
277 | return GetRef<PrimExpr>(op); |
278 | } else { |
279 | return Cast(op->dtype.with_lanes(value.dtype().lanes()), value); |
280 | } |
281 | } |
282 | |
283 | PrimExpr VisitExpr_(const FloatImmNode* op) final { return GetRef<PrimExpr>(op); } |
284 | |
285 | PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef<PrimExpr>(op); } |
286 | |
287 | PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef<PrimExpr>(op); } |
288 | |
289 | // Variable |
290 | PrimExpr VisitExpr_(const VarNode* op) final { |
291 | Var var = GetRef<Var>(op); |
292 | |
293 | if (var.same_as(var_)) { |
294 | return ramp_; |
295 | } |
296 | auto it = let_binding_.find(var); |
297 | if (it != let_binding_.end()) { |
298 | return it->second; |
299 | } else { |
300 | return std::move(var); |
301 | } |
302 | } |
303 | // IfThenElse expr |
304 | PrimExpr MutateIfThenElseExpr_(const CallNode* op) { |
305 | PrimExpr cond = this->VisitExpr(op->args[0]); |
306 | if (cond.dtype().is_vector()) { |
307 | need_scalarize_ = true; |
308 | return GetRef<PrimExpr>(op); |
309 | } |
310 | PrimExpr t = this->VisitExpr(op->args[1]); |
311 | PrimExpr f = this->VisitExpr(op->args[2]); |
312 | if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { |
313 | return GetRef<PrimExpr>(op); |
314 | } else { |
315 | int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); |
316 | t = BroadcastTo(t, lanes); |
317 | f = BroadcastTo(f, lanes); |
318 | return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}); |
319 | } |
320 | } |
321 | // Call |
322 | PrimExpr VisitExpr_(const CallNode* op) final { |
323 | if (op->op.same_as(builtin::if_then_else())) { |
324 | return MutateIfThenElseExpr_(op); |
325 | } else if (op->op.same_as(builtin::texture2d_load())) { |
326 | int lane = 0; |
327 | Array<PrimExpr> fcd = MutateArray({op->args.back()}, &lane); |
328 | auto new_args = op->args; |
329 | new_args.pop_back(); |
330 | new_args.push_back(fcd[0]); |
331 | return Call(op->dtype.with_lanes(4), op->op, new_args); |
332 | } else if (op->op.same_as(builtin::texture2d_store())) { |
333 | int lane = 0; |
334 | // Vectorize the value to store |
335 | Array<PrimExpr> value{op->args.back()}; |
336 | Array<PrimExpr> mutated_value = MutateArray(value, &lane); |
337 | Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]}; |
338 | return Call(op->dtype.with_lanes(lane), op->op, new_args); |
339 | } |
340 | auto* op_ptr = op->op.as<OpNode>(); |
341 | bool vectorizable = op_ptr && op_vectorizable_.get(GetRef<Op>(op_ptr), false); |
342 | |
343 | if (!vectorizable) { |
344 | // Cannot vectorize this op |
345 | Array<PrimExpr> new_args; |
346 | for (auto arg : op->args) { |
347 | auto new_arg = this->VisitExpr(arg); |
348 | if (new_arg.dtype().is_vector()) { |
349 | need_scalarize_ = true; |
350 | return GetRef<PrimExpr>(op); |
351 | } |
352 | new_args.push_back(new_arg); |
353 | } |
354 | if (op->args.same_as(new_args)) { |
355 | return GetRef<PrimExpr>(op); |
356 | } else { |
357 | return Call(op->dtype, op->op, new_args); |
358 | } |
359 | } else { |
360 | int lane = 0; |
361 | Array<PrimExpr> new_args = MutateArray(op->args, &lane); |
362 | // normal code path. |
363 | if (op->args.same_as(new_args)) { |
364 | return GetRef<PrimExpr>(op); |
365 | } else { |
366 | return Call(op->dtype.with_lanes(lane), op->op, new_args); |
367 | } |
368 | } |
369 | } |
370 | // Load |
371 | PrimExpr VisitExpr_(const LoadNode* op) final { |
372 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
373 | } |
374 | // BufferLoad |
375 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
376 | auto load = GetRef<BufferLoad>(op); |
377 | |
378 | auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; |
379 | Array<PrimExpr> indices = op->indices.Map(fmutate); |
380 | |
381 | if (!indices.same_as(op->indices)) { |
382 | auto writer = load.CopyOnWrite(); |
383 | writer->indices = indices; |
384 | writer->LegalizeDType(); |
385 | } |
386 | |
387 | return std::move(load); |
388 | } |
389 | // Let |
390 | PrimExpr VisitExpr_(const LetNode* op) final { |
391 | PrimExpr value = this->VisitExpr(op->value); |
392 | // Weaker SSA condition |
393 | // A single var can be binded in multiple lets |
394 | // but they have to bind to the same value. |
395 | // This is used to allow cases when we reuse a single let |
396 | // expression to cosntruct a nested expr. |
397 | // (let x = 1 in x + 1) * (let x = 1 in x + 1) |
398 | auto it = let_binding_.find(op->var); |
399 | if (it != let_binding_.end()) { |
400 | ICHECK(deep_equal_(it->second, value)) |
401 | << "Let cannot bind the same var to two different values" ; |
402 | } |
403 | if (value.dtype().lanes() != op->value.dtype().lanes()) { |
404 | Var new_var(op->var->name_hint, value.dtype()); |
405 | let_binding_[op->var] = new_var; |
406 | return Let(new_var, value, this->VisitExpr(op->body)); |
407 | } else { |
408 | let_binding_[op->var] = op->var; |
409 | PrimExpr body = this->VisitExpr(op->body); |
410 | if (value.same_as(op->value) && body.same_as(op->body)) { |
411 | return GetRef<PrimExpr>(op); |
412 | } else { |
413 | return Let(op->var, value, body); |
414 | } |
415 | } |
416 | } |
417 | // Store |
418 | Stmt VisitStmt_(const StoreNode* op) final { |
419 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
420 | } |
421 | // BufferStore |
422 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
423 | auto store = GetRef<BufferStore>(op); |
424 | |
425 | auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; |
426 | Array<PrimExpr> indices = op->indices.Map(fmutate); |
427 | |
428 | PrimExpr value = this->VisitExpr(op->value); |
429 | |
430 | if (!indices.same_as(op->indices) || !value.same_as(op->value)) { |
431 | // How many lanes of indexing are present in the index and |
432 | // buffer element type, excluding the last index. T |
433 | int other_index_lanes = op->buffer->dtype.lanes(); |
434 | for (size_t i = 0; i < indices.size() - 1; i++) { |
435 | other_index_lanes *= indices[i].dtype().lanes(); |
436 | } |
437 | |
438 | // The total number of lanes of indexing, including the last index. |
439 | int index_lanes = other_index_lanes * indices[indices.size() - 1].dtype().lanes(); |
440 | |
441 | // The total number of lanes in this store operation. Either |
442 | // the index or the value will be broadcast out to this number |
443 | // of lanes, depending on which has more lanes. |
444 | int total_lanes = std::max(index_lanes, value.dtype().lanes()); |
445 | |
446 | ICHECK_EQ(total_lanes % other_index_lanes, 0) |
447 | << "When storing to buffer " << op->buffer->name << ", cannot produce " << total_lanes |
448 | << " lanes of storage location by changing the last index." ; |
449 | int last_index_lanes = total_lanes / other_index_lanes; |
450 | |
451 | // Broadcast the last index such that the total number of index |
452 | // lanes matches the desired number. |
453 | indices.Set(indices.size() - 1, BroadcastTo(indices[indices.size() - 1], last_index_lanes)); |
454 | |
455 | auto writer = store.CopyOnWrite(); |
456 | writer->indices = indices; |
457 | writer->value = BroadcastTo(value, total_lanes); |
458 | } |
459 | |
460 | return std::move(store); |
461 | } |
462 | // For |
463 | Stmt VisitStmt_(const ForNode* op) final { |
464 | if (op->kind == ForKind::kVectorized) { |
465 | LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring..." ; |
466 | } |
467 | ICHECK(is_zero(op->min)); |
468 | ICHECK(!op->extent.dtype().is_vector()); |
469 | PrimExpr extent = this->VisitExpr(op->extent); |
470 | if (extent.dtype().is_vector()) { |
471 | return Scalarize(GetRef<Stmt>(op)); |
472 | } |
473 | Stmt body = this->VisitStmt(op->body); |
474 | if (extent.same_as(op->extent) && body.same_as(op->body)) { |
475 | return GetRef<Stmt>(op); |
476 | } else { |
477 | return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding, |
478 | op->annotations); |
479 | } |
480 | } |
481 | // IfThenElse |
482 | Stmt VisitStmt_(const IfThenElseNode* op) final { |
483 | ICHECK(!op->condition.dtype().is_vector()); |
484 | PrimExpr condition = this->VisitExpr(op->condition); |
485 | if (condition.dtype().is_vector()) { |
486 | return Scalarize(GetRef<Stmt>(op)); |
487 | } |
488 | Stmt then_case = this->VisitStmt(op->then_case); |
489 | Optional<Stmt> else_case = NullOpt; |
490 | if (op->else_case) { |
491 | else_case = this->VisitStmt(op->else_case.value()); |
492 | } |
493 | if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && |
494 | else_case.same_as(op->else_case)) { |
495 | return GetRef<Stmt>(op); |
496 | } else { |
497 | return IfThenElse(condition, then_case, else_case); |
498 | } |
499 | } |
500 | // While |
501 | Stmt VisitStmt_(const WhileNode* op) final { |
502 | LOG(FATAL) << "A while loop inside a vectorized loop not supported." ; |
503 | } |
504 | // LetStmt |
505 | Stmt VisitStmt_(const LetStmtNode* op) final { |
506 | PrimExpr value = this->VisitExpr(op->value); |
507 | ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice" ; |
508 | let_binding_[op->var] = value; |
509 | |
510 | if (value.dtype().lanes() != op->value.dtype().lanes()) { |
511 | Var new_var(op->var->name_hint, value.dtype()); |
512 | let_binding_[op->var] = new_var; |
513 | return LetStmt(new_var, value, this->VisitStmt(op->body)); |
514 | } else { |
515 | let_binding_[op->var] = op->var; |
516 | Stmt body = this->VisitStmt(op->body); |
517 | if (value.same_as(op->value) && body.same_as(op->body)) { |
518 | return GetRef<Stmt>(op); |
519 | } else { |
520 | return LetStmt(op->var, value, body); |
521 | } |
522 | } |
523 | } |
524 | // Allocate |
525 | Stmt VisitStmt_(const AllocateNode* op) final { |
526 | // Mutate the condition |
527 | PrimExpr condition = this->VisitExpr(op->condition); |
528 | if (condition.dtype().is_vector()) { |
529 | LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint; |
530 | return Scalarize(GetRef<Stmt>(op)); |
531 | } |
532 | |
533 | // Mutate the extents |
534 | Array<PrimExpr> extents; |
535 | for (const auto& extent : op->extents) { |
536 | PrimExpr new_ext = this->VisitExpr(extent); |
537 | if (new_ext.dtype().is_vector()) { |
538 | LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint; |
539 | return Scalarize(GetRef<Stmt>(op)); |
540 | } |
541 | extents.push_back(new_ext); |
542 | } |
543 | |
544 | // TODO(Lunderberg): Move this pass to be prior to |
545 | // StorageFlatten/FlattenBuffer. That will allow this pass to be |
546 | // implemented as adding a new buffer dimension, which is later |
547 | // flattened. |
548 | |
549 | // Extend the least significant dimension by a factor of |
550 | // var_lanes_. Typically, this will be a 1-d index into a flat |
551 | // memory space. |
552 | extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_); |
553 | |
554 | // Rewrite access to the buffer in the body. |
555 | Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); |
556 | body = this->VisitStmt(body); |
557 | return Allocate(op->buffer_var, op->dtype, extents, condition, body); |
558 | } |
559 | |
560 | // scalarize the statment |
561 | Stmt Scalarize(Stmt stmt) { |
562 | Var idx(var_->name_hint + ".s" , var_->dtype); |
563 | Map<Var, PrimExpr> values{{var_, idx}}; |
564 | stmt = Substitute(stmt, values); |
565 | return For(idx, IntImm(var_->dtype, 0), IntImm(var_->dtype, var_lanes_), ForKind::kSerial, |
566 | stmt); |
567 | } |
568 | // ProducerStore |
569 | Stmt VisitStmt_(const ProducerStoreNode* op) final { |
570 | LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc" ; |
571 | } |
572 | |
573 | private: |
574 | // analyzer |
575 | arith::Analyzer analyzer_; |
576 | // deep equal |
577 | ExprDeepEqual deep_equal_; |
578 | // variable to be replaced |
579 | Var var_; |
580 | // the lanes. |
581 | int var_lanes_; |
582 | // ramp representing the var. |
583 | PrimExpr ramp_; |
584 | // flag to mark requirment of scalarization. |
585 | bool need_scalarize_{false}; |
586 | // Let binding |
587 | std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_; |
588 | // vectorizable property |
589 | OpAttrMap<TVectorizable> op_vectorizable_ = Op::GetAttrMap<TVectorizable>("TVectorizable" ); |
590 | |
591 | // mutate array, with given lane requirement |
592 | // when finished, p_lane updates the lane requirement. |
593 | Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int* p_lanes) { |
594 | if (arr.size() == 0) return arr; |
595 | int& lanes = *p_lanes; |
596 | bool changed = false; |
597 | std::vector<PrimExpr> new_arr(arr.size()); |
598 | for (size_t i = 0; i < arr.size(); i++) { |
599 | PrimExpr old_elem = arr[i]; |
600 | PrimExpr new_elem = this->VisitExpr(old_elem); |
601 | if (!new_elem.same_as(old_elem)) changed = true; |
602 | new_arr[i] = new_elem; |
603 | lanes = std::max(lanes, new_elem.dtype().lanes()); |
604 | } |
605 | |
606 | for (size_t i = 0; i < arr.size(); ++i) { |
607 | if (new_arr[i].dtype().lanes() != lanes) { |
608 | new_arr[i] = BroadcastTo(new_arr[i], lanes); |
609 | changed = true; |
610 | } |
611 | } |
612 | if (!changed) return arr; |
613 | return Array<PrimExpr>(new_arr); |
614 | } |
615 | template <typename TOp, typename T> |
616 | PrimExpr BinaryVec(const T* op) { |
617 | static_assert(std::is_same<typename TOp::ContainerType, T>::value, "constraint" ); |
618 | PrimExpr a = this->VisitExpr(op->a); |
619 | PrimExpr b = this->VisitExpr(op->b); |
620 | if (a.same_as(op->a) && b.same_as(op->b)) { |
621 | return GetRef<PrimExpr>(op); |
622 | } else { |
623 | int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); |
624 | return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); |
625 | } |
626 | } |
627 | template <typename T, typename FCompute> |
628 | PrimExpr AddSubVec(const T* op, FCompute fcompute) { |
629 | PrimExpr a = this->VisitExpr(op->a); |
630 | PrimExpr b = this->VisitExpr(op->b); |
631 | if (a.same_as(op->a) && b.same_as(op->b)) { |
632 | return GetRef<PrimExpr>(op); |
633 | } else { |
634 | int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); |
635 | if (lanes != 1) { |
636 | const RampNode* b_ramp = b.as<RampNode>(); |
637 | const RampNode* a_ramp = a.as<RampNode>(); |
638 | if (a.dtype().lanes() == 1 && b_ramp) { |
639 | return Ramp(fcompute(a, b_ramp->base), |
640 | fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes); |
641 | } |
642 | if (b.dtype().lanes() == 1 && a_ramp) { |
643 | return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); |
644 | } |
645 | } |
646 | return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); |
647 | } |
648 | } |
649 | }; |
650 | |
651 | class LoopVectorizer : public StmtMutator { |
652 | public: |
653 | Stmt VisitStmt_(const ForNode* op) final { |
654 | if (op->kind == ForKind::kVectorized) { |
655 | ICHECK(is_zero(op->min)); |
656 | auto* extent_as_int = op->extent.as<IntImmNode>(); |
657 | if (!extent_as_int || extent_as_int->value < 1) { |
658 | LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent; |
659 | } |
660 | return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body); |
661 | } else { |
662 | return StmtMutator::VisitStmt_(op); |
663 | } |
664 | } |
665 | }; |
666 | |
667 | Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); } |
668 | |
669 | class VectorizeSkipper : public StmtMutator { |
670 | public: |
671 | Stmt VisitStmt_(const ForNode* op) final { |
672 | Stmt stmt = StmtMutator::VisitStmt_(op); |
673 | op = stmt.as<ForNode>(); |
674 | if (op->kind == ForKind::kVectorized) { |
675 | return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body); |
676 | } else { |
677 | return stmt; |
678 | } |
679 | } |
680 | }; |
681 | |
682 | Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); } |
683 | |
684 | namespace transform { |
685 | |
686 | // TODO(tvm-team): Make it as a target property. |
687 | Pass VectorizeLoop(bool enable_vectorize) { |
688 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
689 | auto* n = f.CopyOnWrite(); |
690 | if (enable_vectorize) { |
691 | n->body = LoopVectorizer()(std::move(n->body)); |
692 | } else { |
693 | n->body = VectorizeSkipper()(std::move(n->body)); |
694 | } |
695 | return f; |
696 | }; |
697 | return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop" , {}); |
698 | } |
699 | |
700 | TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop" ).set_body_typed(VectorizeLoop); |
701 | |
702 | } // namespace transform |
703 | |
704 | } // namespace tir |
705 | } // namespace tvm |
706 | |