1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
 * Lower the init statement of each block into a conditional branch placed at the start of the block body.
22 | * \file lower_reduction.cc |
23 | */ |
24 | #include <tvm/tir/op.h> |
25 | #include <tvm/tir/stmt_functor.h> |
26 | #include <tvm/tir/transform.h> |
27 | |
28 | #include "ir_utils.h" |
29 | |
30 | namespace tvm { |
31 | namespace tir { |
32 | |
33 | class InitBlockLower : public StmtMutator { |
34 | private: |
35 | Stmt VisitStmt_(const BlockNode* block) final { |
36 | if (!block->init.defined()) { |
37 | return StmtMutator::VisitStmt_(block); |
38 | } |
39 | Stmt init = DoLowering(block->init.value(), block->iter_vars); |
40 | Stmt body = VisitStmt(block->body); |
41 | auto n = CopyOnWrite(block); |
42 | n->init = NullOpt; |
43 | n->body = SeqStmt::Flatten(init, body); |
44 | return Block(n); |
45 | } |
46 | |
47 | static Stmt DoLowering(const Stmt& init, const Array<IterVar>& iter_vars) { |
48 | std::vector<PrimExpr> conditions; |
49 | for (const IterVar& var : iter_vars) { |
50 | if (var->iter_type == IterVarType::kCommReduce) { |
51 | conditions.push_back(equal(var->var, var->dom->min)); |
52 | } |
53 | } |
54 | // Handle the case where there is no condition |
55 | if (conditions.empty()) { |
56 | return init; |
57 | } |
58 | // Concat the conditions with logical and (&&) |
59 | PrimExpr cond = conditions[0]; |
60 | for (size_t i = 1; i < conditions.size(); ++i) { |
61 | cond = logical_and(cond, conditions[i]); |
62 | } |
63 | return IfThenElse(cond, init); |
64 | } |
65 | }; |
66 | |
67 | PrimFunc LowerInitBlock(PrimFunc func) { |
68 | // Only apply this pass to TIR that is not from TE schedules |
69 | if (!IsFromLegacyTESchedule(func)) { |
70 | auto fptr = func.CopyOnWrite(); |
71 | fptr->body = InitBlockLower()(std::move(fptr->body)); |
72 | return func; |
73 | } else { |
74 | return func; |
75 | } |
76 | } |
77 | |
78 | namespace transform { |
79 | |
80 | Pass LowerInitBlock() { |
81 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
82 | return LowerInitBlock(std::move(f)); |
83 | }; |
84 | return CreatePrimFuncPass(pass_func, 0, "tir.LowerInitBlock" , {}); |
85 | } |
86 | |
87 | TVM_REGISTER_GLOBAL("tir.transform.LowerInitBlock" ).set_body_typed(LowerInitBlock); |
88 | |
89 | } // namespace transform |
90 | |
91 | } // namespace tir |
92 | } // namespace tvm |
93 | |