1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/tir/transform.h> |
20 | |
21 | #include "../utils.h" |
22 | |
23 | namespace tvm { |
24 | namespace tir { |
25 | |
26 | class ThreadExtentChecker : private StmtVisitor { |
27 | public: |
28 | static bool Check(const Stmt& stmt, int thread_warp_size) { |
29 | try { |
30 | ICHECK(thread_warp_size > 0); |
31 | ThreadExtentChecker checker(thread_warp_size); |
32 | checker.VisitStmt(stmt); |
33 | return true; |
34 | } catch (const dmlc::Error& e) { |
35 | return false; |
36 | } |
37 | } |
38 | |
39 | private: |
40 | explicit ThreadExtentChecker(int thread_warp_size) : thread_warp_size_(thread_warp_size) {} |
41 | |
42 | void VisitStmt_(const ForNode* loop) { |
43 | runtime::ThreadScope thread_scope = GetThreadScope(loop); |
44 | if (IsThreadIdx(thread_scope)) { |
45 | if (const int64_t* p_ext = GetLoopIntExtent(loop)) { |
46 | int64_t ext = *p_ext; |
47 | if (thread_scope.dim_index == 0) { |
48 | std::swap(thread_idx_x, ext); |
49 | StmtVisitor::VisitStmt_(loop); |
50 | std::swap(thread_idx_x, ext); |
51 | } else if (thread_scope.dim_index == 1) { |
52 | std::swap(thread_idx_y, ext); |
53 | StmtVisitor::VisitStmt_(loop); |
54 | std::swap(thread_idx_y, ext); |
55 | } else if (thread_scope.dim_index == 2) { |
56 | std::swap(thread_idx_z, ext); |
57 | StmtVisitor::VisitStmt_(loop); |
58 | std::swap(thread_idx_z, ext); |
59 | } else { |
60 | StmtVisitor::VisitStmt_(loop); |
61 | } |
62 | return; |
63 | } else { |
64 | throw dmlc::Error("Dynamic thread extent" ); |
65 | } |
66 | } |
67 | StmtVisitor::VisitStmt_(loop); |
68 | } |
69 | |
70 | void VisitStmt_(const BlockNode* block) { |
71 | int old_thread_idx_x = thread_idx_x; |
72 | if (block->annotations.count(attr::warp_execution)) { |
73 | thread_idx_x = thread_warp_size_; |
74 | } |
75 | if (Optional<Integer> low_inclusive = |
76 | GetAnn<Integer>(block, attr::meta_schedule_thread_extent_low_inclusive)) { |
77 | if (Optional<Integer> high_inclusive = |
78 | GetAnn<Integer>(block, attr::meta_schedule_thread_extent_high_inclusive)) { |
79 | int64_t low = low_inclusive.value()->value; |
80 | int64_t high = high_inclusive.value()->value; |
81 | int64_t thread_extent_product = thread_idx_x * thread_idx_y * thread_idx_z; |
82 | if (!(low <= thread_extent_product && thread_extent_product <= high)) { |
83 | throw dmlc::Error("Thread extent" ); |
84 | } |
85 | } |
86 | } |
87 | StmtVisitor::VisitStmt_(block); |
88 | thread_idx_x = old_thread_idx_x; |
89 | } |
90 | |
91 | int64_t thread_idx_x = 1; |
92 | int64_t thread_idx_y = 1; |
93 | int64_t thread_idx_z = 1; |
94 | int thread_warp_size_ = -1; |
95 | }; |
96 | |
97 | } // namespace tir |
98 | } // namespace tvm |
99 | |
100 | namespace tvm { |
101 | namespace meta_schedule { |
102 | |
103 | /*! \brief Extract attribute from a target. */ |
104 | Integer (const Target& target, const char* name) { |
105 | ICHECK(target.defined()); |
106 | if (Optional<Integer> v = target->GetAttr<Integer>(name)) { |
107 | return v.value(); |
108 | } |
109 | LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target" ; |
110 | throw; |
111 | } |
112 | |
/*! \brief Verify the correctness of the generated GPU code. */
class VerifyGPUCodeNode : public PostprocNode {
 public:
  /*! \brief The target to verify against; set in InitializeWithTuneContext. */
  Target target_{nullptr};
  /*! \brief Hardware resource limits forwarded to tir::VerifyGPUCode. */
  Map<String, PrimExpr> target_constraints_{nullptr};
  /*! \brief Warp size read from the target; -1 until initialized. */
  int thread_warp_size_ = -1;

  void InitializeWithTuneContext(const TuneContext& context) final {
    ICHECK(context->target.defined());
    this->target_ = context->target.value();
    // Constraint map checked by Verify() on each lowered PrimFunc.
    // max_vthread / max_vector_bytes are fixed here rather than read from the
    // target.
    this->target_constraints_ = Map<String, PrimExpr>{
        {"max_shared_memory_per_block", Extract(this->target_, "max_shared_memory_per_block")},
        {"max_threads_per_block", Extract(this->target_, "max_threads_per_block")},
        {"max_vthread", Integer(8)},
        {"max_vector_bytes", Integer(16)},
    };
    thread_warp_size_ = Extract(this->target_, "thread_warp_size").IntValue();
  }

  /*!
   * \brief Run tir::VerifyGPUCode on every PrimFunc of a (lowered) module.
   * \param mod The module to check.
   * \return True iff every PrimFunc satisfies the target constraints.
   */
  bool Verify(const IRModule& mod) const {
    for (const auto& kv : mod->functions) {
      // Non-PrimFunc entries (if any) are skipped.
      if (const auto* prim_func = kv.second.as<tir::PrimFuncNode>()) {
        if (!tir::VerifyGPUCode(GetRef<tir::PrimFunc>(prim_func), this->target_constraints_)) {
          return false;
        }
      }
    }
    return true;
  }

  bool Apply(const tir::Schedule& sch) final {
    IRModule mod = sch->mod();
    for (const auto& kv : mod->functions) {
      const GlobalVar& g_var = kv.first;
      const BaseFunc& base_func = kv.second;
      if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
        // Cheap static check first: reject dynamic or out-of-bound thread
        // extents before running the (expensive) lowering pipeline below.
        if (!tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) {
          return false;
        }
        IRModule lowered{nullptr};
        try {
          // The pass order below mirrors the standard TIR lowering flow;
          // it is order-sensitive and should not be rearranged casually.
          auto pass_list = Array<tvm::transform::Pass>();
          // Phase 1
          // First three passes are not needed in TIR schedule.
          // pass_list.push_back(tir::transform::InjectPrefetch());
          // pass_list.push_back(tir::transform::TextureFlatten());
          // pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
          pass_list.push_back(tir::transform::LowerCrossThreadReduction());
          pass_list.push_back(tir::transform::LowerInitBlock());
          pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
          pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
          pass_list.push_back(tir::transform::UnifyThreadBinding());
          pass_list.push_back(tir::transform::CompactBufferAllocation());
          pass_list.push_back(tir::transform::LowerMatchBuffer());
          pass_list.push_back(tir::transform::InjectSoftwarePipeline());
          pass_list.push_back(tir::transform::LowerOpaqueBlock());
          pass_list.push_back(tir::transform::FlattenBuffer());
          pass_list.push_back(tir::transform::BF16Legalize());
          pass_list.push_back(tir::transform::NarrowDataType(32));
          pass_list.push_back(tir::transform::Simplify());
          // Phase 2
          pass_list.push_back(tir::transform::VectorizeLoop(true));
          pass_list.push_back(tir::transform::InjectVirtualThread());
          pass_list.push_back(tir::transform::InjectDoubleBuffer());
          pass_list.push_back(tir::transform::StorageRewrite());
          pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
          pass_list.push_back(tir::transform::LowerIntrin());
          // Convert Function to IRModule
          transform::PassContext pass_ctx = transform::PassContext::Current();
          tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
                                     runtime::String(g_var->name_hint));
          f = WithAttr(f, tvm::attr::kTarget, this->target_);  // Required for LowerIntrin
          bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
          if (noalias) {
            f = WithAttr(std::move(f), "tir.noalias", Bool(true));
          }
          // Wrap the single function in a fresh module so Sequential can run.
          IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
          lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
        } catch (const dmlc::Error& e) {
          // A lowering failure means the schedule cannot produce valid GPU
          // code; treat it as a verification failure rather than an error.
          return false;
        }
        if (!Verify(lowered)) {
          return false;
        }
      }
    }
    return true;
  }

  Postproc Clone() const {
    ObjectPtr<VerifyGPUCodeNode> n = make_object<VerifyGPUCodeNode>(*this);
    // NOTE(review): `*this` already copies target_constraints_, so this
    // explicit assignment looks redundant — confirm before removing.
    n->target_constraints_ = this->target_constraints_;
    return Postproc(n);
  }

  static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode";
  TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode);
};
211 | |
212 | Postproc Postproc::VerifyGPUCode() { |
213 | ObjectPtr<VerifyGPUCodeNode> n = make_object<VerifyGPUCodeNode>(); |
214 | return Postproc(n); |
215 | } |
216 | |
// Register the node type with the object system and expose the constructor
// to the FFI under the global name "meta_schedule.PostprocVerifyGPUCode".
TVM_REGISTER_NODE_TYPE(VerifyGPUCodeNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode").set_body_typed(Postproc::VerifyGPUCode);
219 | |
220 | } // namespace meta_schedule |
221 | } // namespace tvm |
222 | |