1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/meta_schedule/schedule/cuda/thread_bind.h> |
20 | |
21 | #include "../utils.h" |
22 | |
23 | namespace tvm { |
24 | namespace tir { |
25 | |
26 | /*! \brief Find all the blocks that are not bound */ |
27 | class UnboundBlockFinder : private StmtVisitor { |
28 | public: |
29 | static std::vector<std::pair<StmtSRef, String>> Find(const ScheduleState& self) { |
30 | UnboundBlockFinder finder(self); |
31 | for (const auto& kv : self->mod->functions) { |
32 | GlobalVar g_var = kv.first; |
33 | BaseFunc base_func = kv.second; |
34 | if (const auto* prim_func = base_func.as<PrimFuncNode>()) { |
35 | finder.global_var_name_ = g_var->name_hint; |
36 | finder(Downcast<BlockRealize>(prim_func->body)->block->body); |
37 | } |
38 | } |
39 | return std::move(finder.blocks_); |
40 | } |
41 | |
42 | private: |
43 | void VisitStmt_(const ForNode* loop) final { |
44 | runtime::ThreadScope thread_scope = GetThreadScope(loop); |
45 | if (IsBlockIdx(thread_scope)) { |
46 | ++n_block_idx_; |
47 | } else if (IsThreadIdx(thread_scope)) { |
48 | ++n_thread_idx_; |
49 | } |
50 | if (n_block_idx_ == 0 || n_thread_idx_ == 0) { |
51 | StmtVisitor::VisitStmt_(loop); |
52 | } |
53 | if (IsBlockIdx(thread_scope)) { |
54 | --n_block_idx_; |
55 | } else if (IsThreadIdx(thread_scope)) { |
56 | --n_thread_idx_; |
57 | } |
58 | } |
59 | |
60 | void VisitStmt_(const BlockNode* block) final { |
61 | blocks_.emplace_back(self_->stmt2ref.at(block), global_var_name_); |
62 | } |
63 | |
64 | explicit UnboundBlockFinder(const ScheduleState& self) |
65 | : self_{self}, blocks_{}, n_block_idx_{0}, n_thread_idx_{0} {} |
66 | |
67 | /*! \brief The schedule state */ |
68 | const ScheduleState& self_; |
69 | /*! \brief The list of unbound blocks */ |
70 | std::vector<std::pair<StmtSRef, String>> blocks_; |
71 | /*! \brief The number of blockIdx above the current stmt */ |
72 | int n_block_idx_; |
73 | /*! \brief The number of threadIdx above the current stmt */ |
74 | int n_thread_idx_; |
75 | /*! \brief The name of the global var */ |
76 | String global_var_name_; |
77 | }; |
78 | |
79 | } // namespace tir |
80 | } // namespace tvm |
81 | |
82 | namespace tvm { |
83 | namespace meta_schedule { |
84 | |
85 | /*! \brief Add thread binding to unbound blocks */ |
86 | class RewriteUnboundBlockNode : public PostprocNode { |
87 | public: |
88 | // Inherited from PostprocNode |
89 | void InitializeWithTuneContext(const TuneContext& context) final { |
90 | CHECK(context->target.defined()) << "ValueError: target is not defined" ; |
91 | Optional<Integer> max_threads_per_block = |
92 | context->target.value()->GetAttr<Integer>("max_threads_per_block" ); |
93 | CHECK(max_threads_per_block.defined()) |
94 | << "ValueError: missing attribute `max_threads_per_block` in the target" ; |
95 | this->max_threads_per_block_ = max_threads_per_block.value().IntValue(); |
96 | } |
97 | |
98 | // Inherited from PostprocNode |
99 | bool Apply(const tir::Schedule& sch) final; |
100 | |
101 | Postproc Clone() const { |
102 | ObjectPtr<RewriteUnboundBlockNode> n = make_object<RewriteUnboundBlockNode>(*this); |
103 | return Postproc(n); |
104 | } |
105 | |
106 | public: |
107 | /*! \brief The max number of threads per block from Target */ |
108 | int max_threads_per_block_ = -1; |
109 | /*! \brief The max number of threadblocks in the cuda device */ |
110 | int max_threadblocks_ = -1; |
111 | |
112 | void VisitAttrs(tvm::AttrVisitor* v) { |
113 | // `max_threads_per_block_` is not visited |
114 | // `max_threadblocks_` is not visited |
115 | } |
116 | |
117 | static constexpr const char* _type_key = "meta_schedule.RewriteUnboundBlock" ; |
118 | TVM_DECLARE_FINAL_OBJECT_INFO(RewriteUnboundBlockNode, PostprocNode); |
119 | }; |
120 | |
121 | bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { |
122 | using tir::BlockRV; |
123 | using tir::ExprRV; |
124 | using tir::LoopRV; |
125 | using tir::Schedule; |
126 | ICHECK_NE(this->max_threads_per_block_, -1); |
127 | auto get_factor = [t = this->max_threads_per_block_](int max_extent) -> ExprRV { |
128 | return Integer(std::min(t, max_extent)); |
129 | }; |
130 | std::vector<std::pair<tir::StmtSRef, String>> unbound_blocks = |
131 | tir::UnboundBlockFinder::Find(sch->state()); |
132 | for (const auto& kv : unbound_blocks) { |
133 | tir::StmtSRef block_sref = kv.first; |
134 | String global_var_name = kv.second; |
135 | BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); |
136 | BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); |
137 | } |
138 | return true; |
139 | } |
140 | |
141 | Postproc Postproc::RewriteUnboundBlock(int max_threadblocks) { |
142 | ObjectPtr<RewriteUnboundBlockNode> n = make_object<RewriteUnboundBlockNode>(); |
143 | n->max_threadblocks_ = max_threadblocks; |
144 | n->max_threads_per_block_ = -1; |
145 | return Postproc(n); |
146 | } |
147 | |
// Register the node type, and register `Postproc::RewriteUnboundBlock` as a global function
// under the name "meta_schedule.PostprocRewriteUnboundBlock".
TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock")
    .set_body_typed(Postproc::RewriteUnboundBlock);
151 | |
152 | } // namespace meta_schedule |
153 | } // namespace tvm |
154 | |