1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file vectorize_loop.cc
 * \brief Lower loops marked ForKind::kVectorized into vector expressions.
 */
23// Loop vectorizer as in Halide pipeline.
24#include <tvm/arith/analyzer.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/analysis.h>
27#include <tvm/tir/builtin.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/op.h>
30#include <tvm/tir/op_attr_types.h>
31#include <tvm/tir/stmt_functor.h>
32#include <tvm/tir/transform.h>
33
34#include <unordered_map>
35#include <unordered_set>
36#include <vector>
37
38namespace tvm {
39namespace tir {
40
41inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
42 if (e.dtype().lanes() == lanes) return e;
43 if (const BroadcastNode* op = e.as<BroadcastNode>()) {
44 if (lanes % op->lanes == 0) {
45 return Broadcast(op->value, lanes);
46 }
47 }
48 ICHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to "
49 << lanes;
50 return Broadcast(e, lanes);
51}
52
53// Rewrite vectorized allocation access
54// This is necessary for making each vector component containing its own workspace.
55// Originates from Halide's loop vectorizer
56//
57// s[i] = s[i * lanes + var]
58//
59// The same principle applies when using one thread to simulate multiple context.
60//
61class VecAllocAccess : public StmtExprMutator {
62 public:
63 VecAllocAccess(const VarNode* buf, Var var, int var_lanes)
64 : buf_(buf), var_(var), var_lanes_(var_lanes) {}
65
66 PrimExpr VisitExpr_(const LoadNode* op) final {
67 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
68 }
69
70 Stmt VisitStmt_(const StoreNode* op) final {
71 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
72 }
73
74 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
75 auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
76 return UpdateBufferAccess(load);
77 }
78
79 Stmt VisitStmt_(const BufferStoreNode* op) final {
80 auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
81 return UpdateBufferAccess(store);
82 }
83
84 private:
85 template <typename Node>
86 Node UpdateBufferAccess(Node node) {
87 // Only update the buffer that's being replaced.
88 if (node->buffer->data.get() != buf_) {
89 return node;
90 }
91
92 // Find/make a Buffer object with the correct updated shape.
93 Buffer buf;
94 auto it = buffer_map_.find(node->buffer.get());
95 if (it != buffer_map_.end()) {
96 buf = it->second;
97 } else {
98 // Extend the least significant dimension by a factor of
99 // var_lanes_. Typically, this will be a 1-d index into a flat
100 // memory space.
101 Array<PrimExpr> shape = node->buffer->shape;
102 shape.Set(shape.size() - 1, analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));
103
104 // TODO(Lunderberg): Move this pass to be prior to
105 // StorageFlatten/FlattenBuffer, implement by appending a
106 // dimension to the buffer. Since it is currently after the
107 // flattening, the strides are not technically necessary, but
108 // are updated for consistency.
109
110 // Update strides if defined.
111 Array<PrimExpr> strides;
112 for (size_t i = 0; i < strides.size(); i++) {
113 PrimExpr stride = strides[i];
114 if (i != strides.size() - 1) {
115 stride *= var_lanes_;
116 }
117 strides.push_back(analyzer_.Simplify(stride));
118 }
119
120 // Copy everything into the new buffer.
121 buf = node->buffer;
122 auto buf_writer = buf.CopyOnWrite();
123 buf_writer->shape = shape;
124 buf_writer->strides = strides;
125 buffer_map_[buf.get()] = buf;
126 }
127
128 // Extend the last index by the number of lanes in the vectorized
129 // variable.
130 Array<PrimExpr> indices = node->indices;
131 indices.Set(indices.size() - 1,
132 analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_));
133
134 auto writer = node.CopyOnWrite();
135 writer->buffer = buf;
136 writer->indices = indices;
137 return node;
138 }
139
140 // buffer var
141 const VarNode* buf_;
142 // Updated buffer objects.
143 std::unordered_map<const BufferNode*, Buffer> buffer_map_;
144 // variable to be replaced
145 Var var_;
146 // the lanes.
147 int var_lanes_;
148 // Analyzer for simplifications
149 arith::Analyzer analyzer_;
150};
151
152// We use ExprFunctor directly instead of StmtExprMutator
153// This is because the transformation can change the dtype of the Expr
154// The existing ExprMutator transformation rules may not be well defined.
155class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
156 public:
157 using ExprFunctor::VisitExpr;
158 using StmtMutator::operator();
159
160 Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
161 ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
162 }
163
164 Stmt VisitStmt(const Stmt& stmt) final {
165 ICHECK(!need_scalarize_);
166 Stmt ret = StmtMutator::VisitStmt(stmt);
167 if (need_scalarize_) {
168 need_scalarize_ = false;
169 return Scalarize(stmt);
170 } else {
171 return ret;
172 }
173 }
174
175 PrimExpr VisitExpr(const PrimExpr& e) final { return ExprFunctor::VisitExpr(e); }
176
177 PrimExpr VisitExpr_(const AddNode* op) final {
178 return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
179 }
180
181 PrimExpr VisitExpr_(const SubNode* op) final {
182 return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; });
183 }
184
185 PrimExpr VisitExpr_(const MulNode* op) final {
186 PrimExpr a = this->VisitExpr(op->a);
187 PrimExpr b = this->VisitExpr(op->b);
188 if (a.same_as(op->a) && b.same_as(op->b)) {
189 return GetRef<PrimExpr>(op);
190 } else {
191 int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
192 if (lanes != 1) {
193 const RampNode* b_ramp = b.as<RampNode>();
194 const RampNode* a_ramp = a.as<RampNode>();
195 if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
196 return Ramp(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
197 }
198 if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
199 return Ramp(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
200 }
201 }
202 return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
203 }
204 return BinaryVec<Mul>(op);
205 }
206 PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec<Div>(op); }
207 PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec<Mod>(op); }
208 PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec<FloorDiv>(op); }
209 PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec<FloorMod>(op); }
210 PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec<Min>(op); }
211 PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec<Max>(op); }
212 PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec<EQ>(op); }
213 PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec<NE>(op); }
214 PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec<LT>(op); }
215 PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec<LE>(op); }
216 PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec<GT>(op); }
217 PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec<GE>(op); }
218 PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec<And>(op); }
219 PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec<Or>(op); }
220
221 PrimExpr VisitExpr_(const NotNode* op) final {
222 PrimExpr a = this->VisitExpr(op->a);
223 if (a.same_as(op->a)) {
224 return GetRef<PrimExpr>(op);
225 } else {
226 return !(a);
227 }
228 }
229
230 PrimExpr VisitExpr_(const RampNode* op) final {
231 PrimExpr base = this->VisitExpr(op->base);
232 PrimExpr stride = this->VisitExpr(op->stride);
233 if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) {
234 const RampNode* base_ramp = base.as<RampNode>();
235 if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) {
236 return Ramp(base_ramp->base, stride, op->lanes * base_ramp->lanes);
237 }
238 }
239 int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
240 base = BroadcastTo(base, lanes);
241 stride = BroadcastTo(stride, lanes);
242 Array<PrimExpr> elems;
243 for (int i = 0; i < lanes; ++i) {
244 elems.push_back(
245 Ramp(Shuffle::ExtractElement(base, i), Shuffle::ExtractElement(stride, i), op->lanes));
246 }
247 return Shuffle::Concat(elems);
248 }
249
250 PrimExpr VisitExpr_(const BroadcastNode* op) final {
251 PrimExpr value = this->VisitExpr(op->value);
252 if (value.dtype().lanes() != 1) {
253 need_scalarize_ = true;
254 return GetRef<PrimExpr>(op);
255 }
256 if (value.same_as(op->value)) {
257 return GetRef<PrimExpr>(op);
258 } else {
259 return Broadcast(op->value, op->lanes);
260 }
261 }
262
263 PrimExpr VisitExpr_(const SelectNode* op) final {
264 PrimExpr cond = this->VisitExpr(op->condition);
265 PrimExpr t = this->VisitExpr(op->true_value);
266 PrimExpr f = this->VisitExpr(op->false_value);
267 if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) {
268 return GetRef<PrimExpr>(op);
269 } else {
270 int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes());
271 return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
272 }
273 }
274 PrimExpr VisitExpr_(const CastNode* op) final {
275 PrimExpr value = this->VisitExpr(op->value);
276 if (value.same_as(op->value)) {
277 return GetRef<PrimExpr>(op);
278 } else {
279 return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
280 }
281 }
282
283 PrimExpr VisitExpr_(const FloatImmNode* op) final { return GetRef<PrimExpr>(op); }
284
285 PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef<PrimExpr>(op); }
286
287 PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef<PrimExpr>(op); }
288
289 // Variable
290 PrimExpr VisitExpr_(const VarNode* op) final {
291 Var var = GetRef<Var>(op);
292
293 if (var.same_as(var_)) {
294 return ramp_;
295 }
296 auto it = let_binding_.find(var);
297 if (it != let_binding_.end()) {
298 return it->second;
299 } else {
300 return std::move(var);
301 }
302 }
303 // IfThenElse expr
304 PrimExpr MutateIfThenElseExpr_(const CallNode* op) {
305 PrimExpr cond = this->VisitExpr(op->args[0]);
306 if (cond.dtype().is_vector()) {
307 need_scalarize_ = true;
308 return GetRef<PrimExpr>(op);
309 }
310 PrimExpr t = this->VisitExpr(op->args[1]);
311 PrimExpr f = this->VisitExpr(op->args[2]);
312 if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) {
313 return GetRef<PrimExpr>(op);
314 } else {
315 int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
316 t = BroadcastTo(t, lanes);
317 f = BroadcastTo(f, lanes);
318 return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
319 }
320 }
321 // Call
322 PrimExpr VisitExpr_(const CallNode* op) final {
323 if (op->op.same_as(builtin::if_then_else())) {
324 return MutateIfThenElseExpr_(op);
325 } else if (op->op.same_as(builtin::texture2d_load())) {
326 int lane = 0;
327 Array<PrimExpr> fcd = MutateArray({op->args.back()}, &lane);
328 auto new_args = op->args;
329 new_args.pop_back();
330 new_args.push_back(fcd[0]);
331 return Call(op->dtype.with_lanes(4), op->op, new_args);
332 } else if (op->op.same_as(builtin::texture2d_store())) {
333 int lane = 0;
334 // Vectorize the value to store
335 Array<PrimExpr> value{op->args.back()};
336 Array<PrimExpr> mutated_value = MutateArray(value, &lane);
337 Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]};
338 return Call(op->dtype.with_lanes(lane), op->op, new_args);
339 }
340 auto* op_ptr = op->op.as<OpNode>();
341 bool vectorizable = op_ptr && op_vectorizable_.get(GetRef<Op>(op_ptr), false);
342
343 if (!vectorizable) {
344 // Cannot vectorize this op
345 Array<PrimExpr> new_args;
346 for (auto arg : op->args) {
347 auto new_arg = this->VisitExpr(arg);
348 if (new_arg.dtype().is_vector()) {
349 need_scalarize_ = true;
350 return GetRef<PrimExpr>(op);
351 }
352 new_args.push_back(new_arg);
353 }
354 if (op->args.same_as(new_args)) {
355 return GetRef<PrimExpr>(op);
356 } else {
357 return Call(op->dtype, op->op, new_args);
358 }
359 } else {
360 int lane = 0;
361 Array<PrimExpr> new_args = MutateArray(op->args, &lane);
362 // normal code path.
363 if (op->args.same_as(new_args)) {
364 return GetRef<PrimExpr>(op);
365 } else {
366 return Call(op->dtype.with_lanes(lane), op->op, new_args);
367 }
368 }
369 }
370 // Load
371 PrimExpr VisitExpr_(const LoadNode* op) final {
372 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
373 }
374 // BufferLoad
375 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
376 auto load = GetRef<BufferLoad>(op);
377
378 auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
379 Array<PrimExpr> indices = op->indices.Map(fmutate);
380
381 if (!indices.same_as(op->indices)) {
382 auto writer = load.CopyOnWrite();
383 writer->indices = indices;
384 writer->LegalizeDType();
385 }
386
387 return std::move(load);
388 }
389 // Let
390 PrimExpr VisitExpr_(const LetNode* op) final {
391 PrimExpr value = this->VisitExpr(op->value);
392 // Weaker SSA condition
393 // A single var can be binded in multiple lets
394 // but they have to bind to the same value.
395 // This is used to allow cases when we reuse a single let
396 // expression to cosntruct a nested expr.
397 // (let x = 1 in x + 1) * (let x = 1 in x + 1)
398 auto it = let_binding_.find(op->var);
399 if (it != let_binding_.end()) {
400 ICHECK(deep_equal_(it->second, value))
401 << "Let cannot bind the same var to two different values";
402 }
403 if (value.dtype().lanes() != op->value.dtype().lanes()) {
404 Var new_var(op->var->name_hint, value.dtype());
405 let_binding_[op->var] = new_var;
406 return Let(new_var, value, this->VisitExpr(op->body));
407 } else {
408 let_binding_[op->var] = op->var;
409 PrimExpr body = this->VisitExpr(op->body);
410 if (value.same_as(op->value) && body.same_as(op->body)) {
411 return GetRef<PrimExpr>(op);
412 } else {
413 return Let(op->var, value, body);
414 }
415 }
416 }
417 // Store
418 Stmt VisitStmt_(const StoreNode* op) final {
419 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
420 }
421 // BufferStore
422 Stmt VisitStmt_(const BufferStoreNode* op) final {
423 auto store = GetRef<BufferStore>(op);
424
425 auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
426 Array<PrimExpr> indices = op->indices.Map(fmutate);
427
428 PrimExpr value = this->VisitExpr(op->value);
429
430 if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
431 // How many lanes of indexing are present in the index and
432 // buffer element type, excluding the last index. T
433 int other_index_lanes = op->buffer->dtype.lanes();
434 for (size_t i = 0; i < indices.size() - 1; i++) {
435 other_index_lanes *= indices[i].dtype().lanes();
436 }
437
438 // The total number of lanes of indexing, including the last index.
439 int index_lanes = other_index_lanes * indices[indices.size() - 1].dtype().lanes();
440
441 // The total number of lanes in this store operation. Either
442 // the index or the value will be broadcast out to this number
443 // of lanes, depending on which has more lanes.
444 int total_lanes = std::max(index_lanes, value.dtype().lanes());
445
446 ICHECK_EQ(total_lanes % other_index_lanes, 0)
447 << "When storing to buffer " << op->buffer->name << ", cannot produce " << total_lanes
448 << " lanes of storage location by changing the last index.";
449 int last_index_lanes = total_lanes / other_index_lanes;
450
451 // Broadcast the last index such that the total number of index
452 // lanes matches the desired number.
453 indices.Set(indices.size() - 1, BroadcastTo(indices[indices.size() - 1], last_index_lanes));
454
455 auto writer = store.CopyOnWrite();
456 writer->indices = indices;
457 writer->value = BroadcastTo(value, total_lanes);
458 }
459
460 return std::move(store);
461 }
462 // For
463 Stmt VisitStmt_(const ForNode* op) final {
464 if (op->kind == ForKind::kVectorized) {
465 LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
466 }
467 ICHECK(is_zero(op->min));
468 ICHECK(!op->extent.dtype().is_vector());
469 PrimExpr extent = this->VisitExpr(op->extent);
470 if (extent.dtype().is_vector()) {
471 return Scalarize(GetRef<Stmt>(op));
472 }
473 Stmt body = this->VisitStmt(op->body);
474 if (extent.same_as(op->extent) && body.same_as(op->body)) {
475 return GetRef<Stmt>(op);
476 } else {
477 return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding,
478 op->annotations);
479 }
480 }
481 // IfThenElse
482 Stmt VisitStmt_(const IfThenElseNode* op) final {
483 ICHECK(!op->condition.dtype().is_vector());
484 PrimExpr condition = this->VisitExpr(op->condition);
485 if (condition.dtype().is_vector()) {
486 return Scalarize(GetRef<Stmt>(op));
487 }
488 Stmt then_case = this->VisitStmt(op->then_case);
489 Optional<Stmt> else_case = NullOpt;
490 if (op->else_case) {
491 else_case = this->VisitStmt(op->else_case.value());
492 }
493 if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
494 else_case.same_as(op->else_case)) {
495 return GetRef<Stmt>(op);
496 } else {
497 return IfThenElse(condition, then_case, else_case);
498 }
499 }
500 // While
501 Stmt VisitStmt_(const WhileNode* op) final {
502 LOG(FATAL) << "A while loop inside a vectorized loop not supported.";
503 }
504 // LetStmt
505 Stmt VisitStmt_(const LetStmtNode* op) final {
506 PrimExpr value = this->VisitExpr(op->value);
507 ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice";
508 let_binding_[op->var] = value;
509
510 if (value.dtype().lanes() != op->value.dtype().lanes()) {
511 Var new_var(op->var->name_hint, value.dtype());
512 let_binding_[op->var] = new_var;
513 return LetStmt(new_var, value, this->VisitStmt(op->body));
514 } else {
515 let_binding_[op->var] = op->var;
516 Stmt body = this->VisitStmt(op->body);
517 if (value.same_as(op->value) && body.same_as(op->body)) {
518 return GetRef<Stmt>(op);
519 } else {
520 return LetStmt(op->var, value, body);
521 }
522 }
523 }
524 // Allocate
525 Stmt VisitStmt_(const AllocateNode* op) final {
526 // Mutate the condition
527 PrimExpr condition = this->VisitExpr(op->condition);
528 if (condition.dtype().is_vector()) {
529 LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint;
530 return Scalarize(GetRef<Stmt>(op));
531 }
532
533 // Mutate the extents
534 Array<PrimExpr> extents;
535 for (const auto& extent : op->extents) {
536 PrimExpr new_ext = this->VisitExpr(extent);
537 if (new_ext.dtype().is_vector()) {
538 LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint;
539 return Scalarize(GetRef<Stmt>(op));
540 }
541 extents.push_back(new_ext);
542 }
543
544 // TODO(Lunderberg): Move this pass to be prior to
545 // StorageFlatten/FlattenBuffer. That will allow this pass to be
546 // implemented as adding a new buffer dimension, which is later
547 // flattened.
548
549 // Extend the least significant dimension by a factor of
550 // var_lanes_. Typically, this will be a 1-d index into a flat
551 // memory space.
552 extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_);
553
554 // Rewrite access to the buffer in the body.
555 Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
556 body = this->VisitStmt(body);
557 return Allocate(op->buffer_var, op->dtype, extents, condition, body);
558 }
559
560 // scalarize the statment
561 Stmt Scalarize(Stmt stmt) {
562 Var idx(var_->name_hint + ".s", var_->dtype);
563 Map<Var, PrimExpr> values{{var_, idx}};
564 stmt = Substitute(stmt, values);
565 return For(idx, IntImm(var_->dtype, 0), IntImm(var_->dtype, var_lanes_), ForKind::kSerial,
566 stmt);
567 }
568 // ProducerStore
569 Stmt VisitStmt_(const ProducerStoreNode* op) final {
570 LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc";
571 }
572
573 private:
574 // analyzer
575 arith::Analyzer analyzer_;
576 // deep equal
577 ExprDeepEqual deep_equal_;
578 // variable to be replaced
579 Var var_;
580 // the lanes.
581 int var_lanes_;
582 // ramp representing the var.
583 PrimExpr ramp_;
584 // flag to mark requirment of scalarization.
585 bool need_scalarize_{false};
586 // Let binding
587 std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
588 // vectorizable property
589 OpAttrMap<TVectorizable> op_vectorizable_ = Op::GetAttrMap<TVectorizable>("TVectorizable");
590
591 // mutate array, with given lane requirement
592 // when finished, p_lane updates the lane requirement.
593 Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int* p_lanes) {
594 if (arr.size() == 0) return arr;
595 int& lanes = *p_lanes;
596 bool changed = false;
597 std::vector<PrimExpr> new_arr(arr.size());
598 for (size_t i = 0; i < arr.size(); i++) {
599 PrimExpr old_elem = arr[i];
600 PrimExpr new_elem = this->VisitExpr(old_elem);
601 if (!new_elem.same_as(old_elem)) changed = true;
602 new_arr[i] = new_elem;
603 lanes = std::max(lanes, new_elem.dtype().lanes());
604 }
605
606 for (size_t i = 0; i < arr.size(); ++i) {
607 if (new_arr[i].dtype().lanes() != lanes) {
608 new_arr[i] = BroadcastTo(new_arr[i], lanes);
609 changed = true;
610 }
611 }
612 if (!changed) return arr;
613 return Array<PrimExpr>(new_arr);
614 }
615 template <typename TOp, typename T>
616 PrimExpr BinaryVec(const T* op) {
617 static_assert(std::is_same<typename TOp::ContainerType, T>::value, "constraint");
618 PrimExpr a = this->VisitExpr(op->a);
619 PrimExpr b = this->VisitExpr(op->b);
620 if (a.same_as(op->a) && b.same_as(op->b)) {
621 return GetRef<PrimExpr>(op);
622 } else {
623 int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
624 return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
625 }
626 }
627 template <typename T, typename FCompute>
628 PrimExpr AddSubVec(const T* op, FCompute fcompute) {
629 PrimExpr a = this->VisitExpr(op->a);
630 PrimExpr b = this->VisitExpr(op->b);
631 if (a.same_as(op->a) && b.same_as(op->b)) {
632 return GetRef<PrimExpr>(op);
633 } else {
634 int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
635 if (lanes != 1) {
636 const RampNode* b_ramp = b.as<RampNode>();
637 const RampNode* a_ramp = a.as<RampNode>();
638 if (a.dtype().lanes() == 1 && b_ramp) {
639 return Ramp(fcompute(a, b_ramp->base),
640 fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes);
641 }
642 if (b.dtype().lanes() == 1 && a_ramp) {
643 return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
644 }
645 }
646 return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
647 }
648 }
649};
650
651class LoopVectorizer : public StmtMutator {
652 public:
653 Stmt VisitStmt_(const ForNode* op) final {
654 if (op->kind == ForKind::kVectorized) {
655 ICHECK(is_zero(op->min));
656 auto* extent_as_int = op->extent.as<IntImmNode>();
657 if (!extent_as_int || extent_as_int->value < 1) {
658 LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
659 }
660 return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body);
661 } else {
662 return StmtMutator::VisitStmt_(op);
663 }
664 }
665};
666
667Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); }
668
669class VectorizeSkipper : public StmtMutator {
670 public:
671 Stmt VisitStmt_(const ForNode* op) final {
672 Stmt stmt = StmtMutator::VisitStmt_(op);
673 op = stmt.as<ForNode>();
674 if (op->kind == ForKind::kVectorized) {
675 return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body);
676 } else {
677 return stmt;
678 }
679 }
680};
681
682Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); }
683
684namespace transform {
685
686// TODO(tvm-team): Make it as a target property.
687Pass VectorizeLoop(bool enable_vectorize) {
688 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
689 auto* n = f.CopyOnWrite();
690 if (enable_vectorize) {
691 n->body = LoopVectorizer()(std::move(n->body));
692 } else {
693 n->body = VectorizeSkipper()(std::move(n->body));
694 }
695 return f;
696 };
697 return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {});
698}
699
700TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop);
701
702} // namespace transform
703
704} // namespace tir
705} // namespace tvm
706