/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <tvm/tir/transform.h>

#include "../utils.h"

namespace tvm {
namespace tir {

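/*!
 * \brief Check that every thread-bound loop has a constant integer extent, and that the product
 *  of the thread extents satisfies the bounds annotated on the blocks it encloses.
 */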
class ThreadExtentChecker : private StmtVisitor {
 public:
  static bool Check(const Stmt& stmt, int thread_warp_size) {
    try {
      ICHECK(thread_warp_size > 0);
      ThreadExtentChecker checker(thread_warp_size);
      checker.VisitStmt(stmt);
      return true;
    } catch (const dmlc::Error& e) {
      return false;
    }
  }

 private:
  explicit ThreadExtentChecker(int thread_warp_size) : thread_warp_size_(thread_warp_size) {}

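  // A loop bound to a `threadIdx` axis must have a constant integer extent. Record the extent
  // of that axis while visiting the loop body so nested blocks see the current thread extents.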
  void VisitStmt_(const ForNode* loop) {
    runtime::ThreadScope thread_scope = GetThreadScope(loop);
    if (IsThreadIdx(thread_scope)) {
      if (const int64_t* p_ext = GetLoopIntExtent(loop)) {
        int64_t ext = *p_ext;
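        // Save the extent of the corresponding axis, visit the loop body, then restore the
        // previous value on the way back up.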
        if (thread_scope.dim_index == 0) {
          std::swap(thread_idx_x, ext);
          StmtVisitor::VisitStmt_(loop);
          std::swap(thread_idx_x, ext);
        } else if (thread_scope.dim_index == 1) {
          std::swap(thread_idx_y, ext);
          StmtVisitor::VisitStmt_(loop);
          std::swap(thread_idx_y, ext);
        } else if (thread_scope.dim_index == 2) {
          std::swap(thread_idx_z, ext);
          StmtVisitor::VisitStmt_(loop);
          std::swap(thread_idx_z, ext);
        } else {
          StmtVisitor::VisitStmt_(loop);
        }
        return;
      } else {
        throw dmlc::Error("Dynamic thread extent");
      }
    }
    StmtVisitor::VisitStmt_(loop);
  }

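  // A block annotated with `warp_execution` is executed by a full warp along `threadIdx.x`.
  // If the block also carries thread-extent bound annotations, the product of the current
  // thread extents must fall within [low_inclusive, high_inclusive].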
  void VisitStmt_(const BlockNode* block) {
    int64_t old_thread_idx_x = thread_idx_x;
    if (block->annotations.count(attr::warp_execution)) {
      thread_idx_x = thread_warp_size_;
    }
    if (Optional<Integer> low_inclusive =
            GetAnn<Integer>(block, attr::meta_schedule_thread_extent_low_inclusive)) {
      if (Optional<Integer> high_inclusive =
              GetAnn<Integer>(block, attr::meta_schedule_thread_extent_high_inclusive)) {
        int64_t low = low_inclusive.value()->value;
        int64_t high = high_inclusive.value()->value;
        int64_t thread_extent_product = thread_idx_x * thread_idx_y * thread_idx_z;
        if (!(low <= thread_extent_product && thread_extent_product <= high)) {
          throw dmlc::Error("Thread extent");
        }
      }
    }
    StmtVisitor::VisitStmt_(block);
    thread_idx_x = old_thread_idx_x;
  }

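  // Current extents of threadIdx.x/y/z (1 when an axis is not bound) and the warp size of the
  // target.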
  int64_t thread_idx_x = 1;
  int64_t thread_idx_y = 1;
  int64_t thread_idx_z = 1;
  int thread_warp_size_ = -1;
};

}  // namespace tir
}  // namespace tvm

namespace tvm {
namespace meta_schedule {

/*! \brief Extract attribute from a target. */
Integer Extract(const Target& target, const char* name) {
  ICHECK(target.defined());
  if (Optional<Integer> v = target->GetAttr<Integer>(name)) {
    return v.value();
  }
  LOG(FATAL) << "AttributeError: \"" << name << "\" is not defined in the target";
  throw;
}

/*! \brief Verify the correctness of the generated GPU code. */
class VerifyGPUCodeNode : public PostprocNode {
 public:
  Target target_{nullptr};
  Map<String, PrimExpr> target_constraints_{nullptr};
  int thread_warp_size_ = -1;

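  // Read the resource limits checked by `tir::VerifyGPUCode` from the target; `max_vthread`
  // and `max_vector_bytes` are fixed constants rather than target attributes.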
  void InitializeWithTuneContext(const TuneContext& context) final {
    ICHECK(context->target.defined());
    this->target_ = context->target.value();
    this->target_constraints_ = Map<String, PrimExpr>{
        {"max_shared_memory_per_block", Extract(this->target_, "max_shared_memory_per_block")},
        {"max_threads_per_block", Extract(this->target_, "max_threads_per_block")},
        {"max_vthread", Integer(8)},
        {"max_vector_bytes", Integer(16)},
    };
    thread_warp_size_ = Extract(this->target_, "thread_warp_size").IntValue();
  }

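  /*! \brief Run `tir::VerifyGPUCode` on every PrimFunc in the lowered module. */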
  bool Verify(const IRModule& mod) const {
    for (const auto& kv : mod->functions) {
      if (const auto* prim_func = kv.second.as<tir::PrimFuncNode>()) {
        if (!tir::VerifyGPUCode(GetRef<tir::PrimFunc>(prim_func), this->target_constraints_)) {
          return false;
        }
      }
    }
    return true;
  }

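  // For each PrimFunc: reject dynamic or out-of-bound thread extents up front, then lower the
  // function through the GPU-relevant passes and verify the result against the target
  // constraints. Any error thrown during lowering also rejects the schedule.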
  bool Apply(const tir::Schedule& sch) final {
    IRModule mod = sch->mod();
    for (const auto& kv : mod->functions) {
      const GlobalVar& g_var = kv.first;
      const BaseFunc& base_func = kv.second;
      if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
        if (!tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) {
          return false;
        }
        IRModule lowered{nullptr};
        try {
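          // Replay the passes that the GPU build pipeline would run on this function, so that
          // `Verify` checks the post-lowering IR rather than the schedulable form.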
          auto pass_list = Array<tvm::transform::Pass>();
          // Phase 1
          // The first three passes are not needed for a TIR schedule:
          // pass_list.push_back(tir::transform::InjectPrefetch());
          // pass_list.push_back(tir::transform::TextureFlatten());
          // pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
          pass_list.push_back(tir::transform::LowerCrossThreadReduction());
          pass_list.push_back(tir::transform::LowerInitBlock());
          pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
          pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
          pass_list.push_back(tir::transform::UnifyThreadBinding());
          pass_list.push_back(tir::transform::CompactBufferAllocation());
          pass_list.push_back(tir::transform::LowerMatchBuffer());
          pass_list.push_back(tir::transform::InjectSoftwarePipeline());
          pass_list.push_back(tir::transform::LowerOpaqueBlock());
          pass_list.push_back(tir::transform::FlattenBuffer());
          pass_list.push_back(tir::transform::BF16Legalize());
          pass_list.push_back(tir::transform::NarrowDataType(32));
          pass_list.push_back(tir::transform::Simplify());
          // Phase 2
          pass_list.push_back(tir::transform::VectorizeLoop(true));
          pass_list.push_back(tir::transform::InjectVirtualThread());
          pass_list.push_back(tir::transform::InjectDoubleBuffer());
          pass_list.push_back(tir::transform::StorageRewrite());
          pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
          pass_list.push_back(tir::transform::LowerIntrin());
          // Convert the PrimFunc into an IRModule before running the passes.
          transform::PassContext pass_ctx = transform::PassContext::Current();
          tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
                                     runtime::String(g_var->name_hint));
          f = WithAttr(f, tvm::attr::kTarget, this->target_);  // Required for LowerIntrin
          bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
          if (noalias) {
            f = WithAttr(std::move(f), "tir.noalias", Bool(true));
          }
          IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
          lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
        } catch (const dmlc::Error& e) {
          return false;
        }
        if (!Verify(lowered)) {
          return false;
        }
      }
    }
    return true;
  }

  Postproc Clone() const {
    ObjectPtr<VerifyGPUCodeNode> n = make_object<VerifyGPUCodeNode>(*this);
    n->target_constraints_ = this->target_constraints_;
    return Postproc(n);
  }

  static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode";
  TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode);
};

Postproc Postproc::VerifyGPUCode() {
  ObjectPtr<VerifyGPUCodeNode> n = make_object<VerifyGPUCodeNode>();
  return Postproc(n);
}

TVM_REGISTER_NODE_TYPE(VerifyGPUCodeNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode").set_body_typed(Postproc::VerifyGPUCode);

}  // namespace meta_schedule
}  // namespace tvm