1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/meta_schedule/schedule/cuda/thread_bind.h>
20
21#include "../utils.h"
22
23namespace tvm {
24namespace tir {
25
26/*! \brief Find all the blocks that are not bound */
27class UnboundBlockFinder : private StmtVisitor {
28 public:
29 static std::vector<std::pair<StmtSRef, String>> Find(const ScheduleState& self) {
30 UnboundBlockFinder finder(self);
31 for (const auto& kv : self->mod->functions) {
32 GlobalVar g_var = kv.first;
33 BaseFunc base_func = kv.second;
34 if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
35 finder.global_var_name_ = g_var->name_hint;
36 finder(Downcast<BlockRealize>(prim_func->body)->block->body);
37 }
38 }
39 return std::move(finder.blocks_);
40 }
41
42 private:
43 void VisitStmt_(const ForNode* loop) final {
44 runtime::ThreadScope thread_scope = GetThreadScope(loop);
45 if (IsBlockIdx(thread_scope)) {
46 ++n_block_idx_;
47 } else if (IsThreadIdx(thread_scope)) {
48 ++n_thread_idx_;
49 }
50 if (n_block_idx_ == 0 || n_thread_idx_ == 0) {
51 StmtVisitor::VisitStmt_(loop);
52 }
53 if (IsBlockIdx(thread_scope)) {
54 --n_block_idx_;
55 } else if (IsThreadIdx(thread_scope)) {
56 --n_thread_idx_;
57 }
58 }
59
60 void VisitStmt_(const BlockNode* block) final {
61 blocks_.emplace_back(self_->stmt2ref.at(block), global_var_name_);
62 }
63
64 explicit UnboundBlockFinder(const ScheduleState& self)
65 : self_{self}, blocks_{}, n_block_idx_{0}, n_thread_idx_{0} {}
66
67 /*! \brief The schedule state */
68 const ScheduleState& self_;
69 /*! \brief The list of unbound blocks */
70 std::vector<std::pair<StmtSRef, String>> blocks_;
71 /*! \brief The number of blockIdx above the current stmt */
72 int n_block_idx_;
73 /*! \brief The number of threadIdx above the current stmt */
74 int n_thread_idx_;
75 /*! \brief The name of the global var */
76 String global_var_name_;
77};
78
79} // namespace tir
80} // namespace tvm
81
82namespace tvm {
83namespace meta_schedule {
84
85/*! \brief Add thread binding to unbound blocks */
86class RewriteUnboundBlockNode : public PostprocNode {
87 public:
88 // Inherited from PostprocNode
89 void InitializeWithTuneContext(const TuneContext& context) final {
90 CHECK(context->target.defined()) << "ValueError: target is not defined";
91 Optional<Integer> max_threads_per_block =
92 context->target.value()->GetAttr<Integer>("max_threads_per_block");
93 CHECK(max_threads_per_block.defined())
94 << "ValueError: missing attribute `max_threads_per_block` in the target";
95 this->max_threads_per_block_ = max_threads_per_block.value().IntValue();
96 }
97
98 // Inherited from PostprocNode
99 bool Apply(const tir::Schedule& sch) final;
100
101 Postproc Clone() const {
102 ObjectPtr<RewriteUnboundBlockNode> n = make_object<RewriteUnboundBlockNode>(*this);
103 return Postproc(n);
104 }
105
106 public:
107 /*! \brief The max number of threads per block from Target */
108 int max_threads_per_block_ = -1;
109 /*! \brief The max number of threadblocks in the cuda device */
110 int max_threadblocks_ = -1;
111
112 void VisitAttrs(tvm::AttrVisitor* v) {
113 // `max_threads_per_block_` is not visited
114 // `max_threadblocks_` is not visited
115 }
116
117 static constexpr const char* _type_key = "meta_schedule.RewriteUnboundBlock";
118 TVM_DECLARE_FINAL_OBJECT_INFO(RewriteUnboundBlockNode, PostprocNode);
119};
120
121bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) {
122 using tir::BlockRV;
123 using tir::ExprRV;
124 using tir::LoopRV;
125 using tir::Schedule;
126 ICHECK_NE(this->max_threads_per_block_, -1);
127 auto get_factor = [t = this->max_threads_per_block_](int max_extent) -> ExprRV {
128 return Integer(std::min(t, max_extent));
129 };
130 std::vector<std::pair<tir::StmtSRef, String>> unbound_blocks =
131 tir::UnboundBlockFinder::Find(sch->state());
132 for (const auto& kv : unbound_blocks) {
133 tir::StmtSRef block_sref = kv.first;
134 String global_var_name = kv.second;
135 BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
136 BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor);
137 }
138 return true;
139}
140
141Postproc Postproc::RewriteUnboundBlock(int max_threadblocks) {
142 ObjectPtr<RewriteUnboundBlockNode> n = make_object<RewriteUnboundBlockNode>();
143 n->max_threadblocks_ = max_threadblocks;
144 n->max_threads_per_block_ = -1;
145 return Postproc(n);
146}
147
// Register the node type, and expose the factory function as the global
// "meta_schedule.PostprocRewriteUnboundBlock".
TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock")
    .set_body_typed(Postproc::RewriteUnboundBlock);
151
152} // namespace meta_schedule
153} // namespace tvm
154