1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24/*! \brief Check if an IRModule has any async strided mem copies. */
25struct AsyncStridedMemCopyFinder : private StmtExprVisitor {
26 public:
27 static bool Find(const IRModule& mod) {
28 AsyncStridedMemCopyFinder finder;
29 for (const auto& kv : mod->functions) {
30 if (const auto* prim_func = kv.second.as<PrimFuncNode>()) {
31 finder(prim_func->body);
32 if (finder.found_) {
33 return true;
34 }
35 }
36 }
37 return false;
38 }
39
40 private:
41 void VisitStmt_(const ForNode* loop) final {
42 if (!found_) {
43 input_iters.Set(loop->loop_var, Range(loop->min, loop->extent));
44 StmtExprVisitor::VisitStmt_(loop);
45 }
46 }
47
48 void VisitStmt_(const AttrStmtNode* attrStmt) final {
49 if (!found_) {
50 if (attrStmt->attr_key == tir::attr::async_commit_queue_scope) {
51 auto async_scope = attrStmt->body.as<AttrStmtNode>();
52 if (!async_scope) {
53 StmtExprVisitor::VisitStmt_(attrStmt);
54 }
55
56 auto for_loop = async_scope->body.as<ForNode>();
57 if (!for_loop) {
58 StmtExprVisitor::VisitStmt_(attrStmt);
59 }
60
61 input_iters.Set(for_loop->loop_var, Range(for_loop->min, for_loop->extent));
62
63 auto bufferstorenode = for_loop->body.as<BufferStoreNode>();
64 if (!bufferstorenode) {
65 StmtExprVisitor::VisitStmt_(attrStmt);
66 }
67
68 auto bufferloadnode = bufferstorenode->value.as<BufferLoadNode>();
69 if (!bufferloadnode) {
70 StmtExprVisitor::VisitStmt_(attrStmt);
71 }
72
73 // get store buffer; assert it exists and is contiguous given it uses a single index
74 auto bufferstore = bufferstorenode->buffer.as<BufferNode>();
75
76 // get load buffer; assert it exists and is contiguous given it uses a single index
77 auto bufferload = bufferloadnode->buffer.as<BufferNode>();
78
79 if (!bufferstore || !bufferload) {
80 StmtExprVisitor::VisitStmt_(attrStmt);
81 }
82
83 // map loop variable to zero for the store index & simplify
84 Array<PrimExpr> store_index = bufferstorenode->indices;
85
86 // Use DetectIterMap to detect whether store index is non-contiguous.
87 arith::Analyzer analyzer;
88 auto store_iter_map = DetectIterMap(store_index, input_iters, 1,
89 arith::IterMapLevel::Surjective, &analyzer, false);
90 if (!store_iter_map->errors.empty()) {
91 found_ = true;
92 }
93
94 // map loop variable to zero for the load index & simplify
95 Array<PrimExpr> load_index = bufferloadnode->indices;
96
97 // Use DetectIterMap to detect whether load index is non-contiguous.
98 auto load_iter_map = DetectIterMap(load_index, input_iters, 1,
99 arith::IterMapLevel::Surjective, &analyzer, false);
100 if (!load_iter_map->errors.empty()) {
101 found_ = true;
102 }
103 }
104 if (!found_) {
105 StmtExprVisitor::VisitStmt_(attrStmt);
106 }
107 }
108 }
109
110 bool found_ = false;
111 Map<Var, Range> input_iters = Map<Var, Range>();
112};
113
114} // namespace tir
115
116namespace meta_schedule {
117
118/*! \brief Check if the IRModule has any loop with non-constant extent. */
119class DisallowAsyncStridedMemCopyNode : public PostprocNode {
120 public:
121 // Inherited from PostprocNode
122 void InitializeWithTuneContext(const TuneContext& context) final {}
123 // Inherited from PostprocNode
124 bool Apply(const tir::Schedule& sch) final {
125 IRModule mod = sch->mod();
126 for (const auto& kv : mod->functions) {
127 const GlobalVar& g_var = kv.first;
128 const BaseFunc& base_func = kv.second;
129 if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
130 IRModule lowered{nullptr};
131 try {
132 auto pass_list = Array<tvm::transform::Pass>();
133 pass_list.push_back(tir::transform::LowerInitBlock());
134 pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
135 pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
136 pass_list.push_back(tir::transform::CompactBufferAllocation());
137 pass_list.push_back(tir::transform::LowerMatchBuffer());
138 pass_list.push_back(tir::transform::InjectSoftwarePipeline());
139 pass_list.push_back(tir::transform::LowerOpaqueBlock());
140 pass_list.push_back(tir::transform::FlattenBuffer());
141 pass_list.push_back(tir::transform::BF16Legalize());
142 pass_list.push_back(tir::transform::NarrowDataType(32));
143 pass_list.push_back(tir::transform::Simplify());
144 pass_list.push_back(tir::transform::InjectVirtualThread());
145 pass_list.push_back(tir::transform::InjectDoubleBuffer());
146 pass_list.push_back(tir::transform::VectorizeLoop(true));
147 pass_list.push_back(tir::transform::StorageRewrite());
148 transform::PassContext pass_ctx = transform::PassContext::Current();
149 pass_ctx->config.Set("tir.merge_async_commit_queue_scope",
150 Bool(merge_async_commit_queue_scope));
151 tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
152 runtime::String(g_var->name_hint));
153 IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
154 lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
155 } catch (const dmlc::Error& e) {
156 return false;
157 }
158 if (tir::AsyncStridedMemCopyFinder::Find(lowered)) {
159 return false;
160 }
161 }
162 }
163 return true;
164 }
165 // Inherited from PostprocNode
166 Postproc Clone() const {
167 ObjectPtr<DisallowAsyncStridedMemCopyNode> n =
168 make_object<DisallowAsyncStridedMemCopyNode>(*this);
169 return Postproc(n);
170 }
171
172 bool merge_async_commit_queue_scope = true;
173
174 static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy";
175 TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
176};
177
178Postproc Postproc::DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope) {
179 ObjectPtr<DisallowAsyncStridedMemCopyNode> n = make_object<DisallowAsyncStridedMemCopyNode>();
180 n->merge_async_commit_queue_scope = merge_async_commit_queue_scope;
181 return Postproc(n);
182}
183
// Register the node type and expose the constructor function under the global name
// "meta_schedule.PostprocDisallowAsyncStridedMemCopy".
TVM_REGISTER_NODE_TYPE(DisallowAsyncStridedMemCopyNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowAsyncStridedMemCopy")
    .set_body_typed(Postproc::DisallowAsyncStridedMemCopy);
187
188} // namespace meta_schedule
189} // namespace tvm
190