1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /*! \brief Check if an IRModule has any async strided mem copies. */ |
25 | struct AsyncStridedMemCopyFinder : private StmtExprVisitor { |
26 | public: |
27 | static bool Find(const IRModule& mod) { |
28 | AsyncStridedMemCopyFinder finder; |
29 | for (const auto& kv : mod->functions) { |
30 | if (const auto* prim_func = kv.second.as<PrimFuncNode>()) { |
31 | finder(prim_func->body); |
32 | if (finder.found_) { |
33 | return true; |
34 | } |
35 | } |
36 | } |
37 | return false; |
38 | } |
39 | |
40 | private: |
41 | void VisitStmt_(const ForNode* loop) final { |
42 | if (!found_) { |
43 | input_iters.Set(loop->loop_var, Range(loop->min, loop->extent)); |
44 | StmtExprVisitor::VisitStmt_(loop); |
45 | } |
46 | } |
47 | |
48 | void VisitStmt_(const AttrStmtNode* attrStmt) final { |
49 | if (!found_) { |
50 | if (attrStmt->attr_key == tir::attr::async_commit_queue_scope) { |
51 | auto async_scope = attrStmt->body.as<AttrStmtNode>(); |
52 | if (!async_scope) { |
53 | StmtExprVisitor::VisitStmt_(attrStmt); |
54 | } |
55 | |
56 | auto for_loop = async_scope->body.as<ForNode>(); |
57 | if (!for_loop) { |
58 | StmtExprVisitor::VisitStmt_(attrStmt); |
59 | } |
60 | |
61 | input_iters.Set(for_loop->loop_var, Range(for_loop->min, for_loop->extent)); |
62 | |
63 | auto bufferstorenode = for_loop->body.as<BufferStoreNode>(); |
64 | if (!bufferstorenode) { |
65 | StmtExprVisitor::VisitStmt_(attrStmt); |
66 | } |
67 | |
68 | auto bufferloadnode = bufferstorenode->value.as<BufferLoadNode>(); |
69 | if (!bufferloadnode) { |
70 | StmtExprVisitor::VisitStmt_(attrStmt); |
71 | } |
72 | |
73 | // get store buffer; assert it exists and is contiguous given it uses a single index |
74 | auto bufferstore = bufferstorenode->buffer.as<BufferNode>(); |
75 | |
76 | // get load buffer; assert it exists and is contiguous given it uses a single index |
77 | auto bufferload = bufferloadnode->buffer.as<BufferNode>(); |
78 | |
79 | if (!bufferstore || !bufferload) { |
80 | StmtExprVisitor::VisitStmt_(attrStmt); |
81 | } |
82 | |
83 | // map loop variable to zero for the store index & simplify |
84 | Array<PrimExpr> store_index = bufferstorenode->indices; |
85 | |
86 | // Use DetectIterMap to detect whether store index is non-contiguous. |
87 | arith::Analyzer analyzer; |
88 | auto store_iter_map = DetectIterMap(store_index, input_iters, 1, |
89 | arith::IterMapLevel::Surjective, &analyzer, false); |
90 | if (!store_iter_map->errors.empty()) { |
91 | found_ = true; |
92 | } |
93 | |
94 | // map loop variable to zero for the load index & simplify |
95 | Array<PrimExpr> load_index = bufferloadnode->indices; |
96 | |
97 | // Use DetectIterMap to detect whether load index is non-contiguous. |
98 | auto load_iter_map = DetectIterMap(load_index, input_iters, 1, |
99 | arith::IterMapLevel::Surjective, &analyzer, false); |
100 | if (!load_iter_map->errors.empty()) { |
101 | found_ = true; |
102 | } |
103 | } |
104 | if (!found_) { |
105 | StmtExprVisitor::VisitStmt_(attrStmt); |
106 | } |
107 | } |
108 | } |
109 | |
110 | bool found_ = false; |
111 | Map<Var, Range> input_iters = Map<Var, Range>(); |
112 | }; |
113 | |
114 | } // namespace tir |
115 | |
116 | namespace meta_schedule { |
117 | |
118 | /*! \brief Check if the IRModule has any loop with non-constant extent. */ |
119 | class DisallowAsyncStridedMemCopyNode : public PostprocNode { |
120 | public: |
121 | // Inherited from PostprocNode |
122 | void InitializeWithTuneContext(const TuneContext& context) final {} |
123 | // Inherited from PostprocNode |
124 | bool Apply(const tir::Schedule& sch) final { |
125 | IRModule mod = sch->mod(); |
126 | for (const auto& kv : mod->functions) { |
127 | const GlobalVar& g_var = kv.first; |
128 | const BaseFunc& base_func = kv.second; |
129 | if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) { |
130 | IRModule lowered{nullptr}; |
131 | try { |
132 | auto pass_list = Array<tvm::transform::Pass>(); |
133 | pass_list.push_back(tir::transform::LowerInitBlock()); |
134 | pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); |
135 | pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); |
136 | pass_list.push_back(tir::transform::CompactBufferAllocation()); |
137 | pass_list.push_back(tir::transform::LowerMatchBuffer()); |
138 | pass_list.push_back(tir::transform::InjectSoftwarePipeline()); |
139 | pass_list.push_back(tir::transform::LowerOpaqueBlock()); |
140 | pass_list.push_back(tir::transform::FlattenBuffer()); |
141 | pass_list.push_back(tir::transform::BF16Legalize()); |
142 | pass_list.push_back(tir::transform::NarrowDataType(32)); |
143 | pass_list.push_back(tir::transform::Simplify()); |
144 | pass_list.push_back(tir::transform::InjectVirtualThread()); |
145 | pass_list.push_back(tir::transform::InjectDoubleBuffer()); |
146 | pass_list.push_back(tir::transform::VectorizeLoop(true)); |
147 | pass_list.push_back(tir::transform::StorageRewrite()); |
148 | transform::PassContext pass_ctx = transform::PassContext::Current(); |
149 | pass_ctx->config.Set("tir.merge_async_commit_queue_scope" , |
150 | Bool(merge_async_commit_queue_scope)); |
151 | tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol" , |
152 | runtime::String(g_var->name_hint)); |
153 | IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}})); |
154 | lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); |
155 | } catch (const dmlc::Error& e) { |
156 | return false; |
157 | } |
158 | if (tir::AsyncStridedMemCopyFinder::Find(lowered)) { |
159 | return false; |
160 | } |
161 | } |
162 | } |
163 | return true; |
164 | } |
165 | // Inherited from PostprocNode |
166 | Postproc Clone() const { |
167 | ObjectPtr<DisallowAsyncStridedMemCopyNode> n = |
168 | make_object<DisallowAsyncStridedMemCopyNode>(*this); |
169 | return Postproc(n); |
170 | } |
171 | |
172 | bool merge_async_commit_queue_scope = true; |
173 | |
174 | static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy" ; |
175 | TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode); |
176 | }; |
177 | |
178 | Postproc Postproc::DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope) { |
179 | ObjectPtr<DisallowAsyncStridedMemCopyNode> n = make_object<DisallowAsyncStridedMemCopyNode>(); |
180 | n->merge_async_commit_queue_scope = merge_async_commit_queue_scope; |
181 | return Postproc(n); |
182 | } |
183 | |
// Register the node type, and bind the factory function to the global name
// "meta_schedule.PostprocDisallowAsyncStridedMemCopy".
TVM_REGISTER_NODE_TYPE(DisallowAsyncStridedMemCopyNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowAsyncStridedMemCopy")
    .set_body_typed(Postproc::DisallowAsyncStridedMemCopy);
187 | |
188 | } // namespace meta_schedule |
189 | } // namespace tvm |
190 | |