1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /*! |
25 | * \brief Parse instruction: sch.bind(..., axis) |
26 | * \param sch The schedule |
27 | * \param inst The instruction to be parsed |
28 | * \param axis The axis name expected |
29 | * \return NullOpt if parsing fails; Otherwise, the extent of thread axis |
30 | */ |
31 | Optional<Integer> ParseThreadBinding(const Schedule& sch, const Instruction& inst, String axis) { |
32 | static InstructionKind inst_kind_bind = InstructionKind::Get("Bind" ); |
33 | if (!inst->kind.same_as(inst_kind_bind)) { |
34 | return NullOpt; |
35 | } |
36 | ICHECK_EQ(inst->inputs.size(), 1); |
37 | ICHECK_EQ(inst->attrs.size(), 1); |
38 | String thread_axis = Downcast<String>(inst->attrs[0]); |
39 | if (thread_axis != axis) { |
40 | return NullOpt; |
41 | } |
42 | return Downcast<Integer>(sch->Get(Downcast<LoopRV>(inst->inputs[0]))->extent); |
43 | } |
44 | |
45 | /*! |
46 | * \brief Parse instruction: sch.annotate(..., attr::meta_schedule_cooperative_fetch) |
47 | * \param sch The schedule |
48 | * \param inst The instruction to be parsed |
49 | * \param vector_lane The number of vector lane in vectorized cooperative fetching |
50 | * \return NullOpt if parsing fails; Otherwise, the annotated block |
51 | */ |
52 | Optional<BlockRV> ParseAnnotate(const Schedule& sch, const Instruction& inst, |
53 | int64_t* vector_lane) { |
54 | static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate" ); |
55 | if (!inst->kind.same_as(inst_kind_annotate)) { |
56 | return NullOpt; |
57 | } |
58 | ICHECK_EQ(inst->inputs.size(), 2); |
59 | ICHECK_EQ(inst->attrs.size(), 1); |
60 | String ann_key = Downcast<String>(inst->attrs[0]); |
61 | if (ann_key != attr::meta_schedule_cooperative_fetch) { |
62 | return NullOpt; |
63 | } |
64 | *vector_lane = Downcast<Integer>(sch->Get(Downcast<ExprRV>(inst->inputs[1])))->value; |
65 | return Downcast<BlockRV>(inst->inputs[0]); |
66 | } |
67 | |
68 | /*! |
69 | * \brief Parse instruction: sch.annotate(..., attr::warp_execution) |
70 | * \param sch The schedule |
71 | * \param inst The instruction to be parsed |
72 | * \return Whether ths parsing is successful |
73 | */ |
74 | bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) { |
75 | static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate" ); |
76 | if (!inst->kind.same_as(inst_kind_annotate)) { |
77 | return false; |
78 | } |
79 | ICHECK_EQ(inst->inputs.size(), 2); |
80 | ICHECK_EQ(inst->attrs.size(), 1); |
81 | String ann_key = Downcast<String>(inst->attrs[0]); |
82 | return ann_key == attr::warp_execution; |
83 | } |
84 | |
85 | size_t GetMaxUsedDtypeBytes(Block block) { |
86 | size_t max_bytes = 1; |
87 | static auto q_multiply_shift_per_axis = Op::Get("tir.q_multiply_shift_per_axis" ); |
88 | static auto q_multiply_shift = Op::Get("tir.q_multiply_shift" ); |
89 | |
90 | tir::PostOrderVisit(block->body, [&](const ObjectRef& obj) { |
91 | if (const auto* store = obj.as<tir::BufferStoreNode>()) { |
92 | max_bytes = std::max(max_bytes, static_cast<size_t>(store->value->dtype.bytes())); |
93 | } else if (const auto* load = obj.as<tir::BufferLoadNode>()) { |
94 | max_bytes = std::max(max_bytes, static_cast<size_t>(load->dtype.bytes())); |
95 | } else if (const auto* call = obj.as<tir::CallNode>()) { |
96 | if (call->op.same_as(q_multiply_shift_per_axis) || call->op.same_as(q_multiply_shift)) { |
97 | // q_multiply_shift uses 64 bit multiply |
98 | max_bytes = std::max<size_t>(max_bytes, 8); |
99 | } |
100 | } else if (const auto* cast = obj.as<tir::CastNode>()) { |
101 | max_bytes = std::max<size_t>(max_bytes, cast->dtype.bytes()); |
102 | } |
103 | }); |
104 | |
105 | return max_bytes; |
106 | } |
107 | |
108 | } // namespace tir |
109 | |
110 | namespace meta_schedule { |
111 | |
/*!
 * \brief Rewrite the cooperative fetch annotation to actual vectorized cooperative fetching
 * in loop bindings.
 */
class RewriteCooperativeFetchNode : public PostprocNode {
 public:
  // Inherited from PostprocNode
  // Caches the target's "thread_warp_size" attribute for later use in Apply();
  // if the target does not define it, thread_warp_size_ stays at its sentinel (-1)
  // and a log message is emitted.
  void InitializeWithTuneContext(const TuneContext& context) final {
    if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("thread_warp_size")) {
      this->thread_warp_size_ = v.value()->value;
    } else {
      TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target";
    }
  }

  // Inherited from PostprocNode
  bool Apply(const tir::Schedule& sch) final;

  // Shallow-copies this node into a fresh Postproc handle.
  Postproc Clone() const {
    ObjectPtr<RewriteCooperativeFetchNode> n = make_object<RewriteCooperativeFetchNode>(*this);
    return Postproc(n);
  }

  // No attributes to visit: thread_warp_size_ is runtime-only state.
  void VisitAttrs(tvm::AttrVisitor* v) {}

  static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch";
  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode);

 private:
  // Warp size read from the target; -1 means "not provided by the target".
  int thread_warp_size_ = -1;
};
143 | |
// Walk the schedule's trace once to collect thread extents and cooperative-fetch
// annotations, then replay the collected rewrite tasks. Each annotated block's
// fused fetch loop is split and re-bound to threadIdx.{x,y} with optional
// vectorization according to the annotated vector lane count.
bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
  tir::Trace trace = sch->trace().value();
  // Most recently seen thread extents; -1 means "not bound yet".
  int64_t thread_extent_x = -1;
  int64_t thread_extent_y = -1;
  int64_t vector_lane = 1;
  // Rewrites are deferred into tasks and run after the scan. NOTE(review): this
  // appears to be because applying schedule primitives mutates the trace being
  // iterated — confirm before restructuring.
  std::vector<std::function<void()>> tasks;
  for (const tir::Instruction& inst : trace->insts) {
    // Track the latest threadIdx.x / threadIdx.y extents seen so far.
    if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.x")) {
      thread_extent_x = new_thread_extent.value()->value;
      continue;
    }
    if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.y")) {
      thread_extent_y = new_thread_extent.value()->value;
      continue;
    }
    // Warp execution implies threadIdx.x spans one warp.
    if (tir::ParseWarpExecutionAnn(sch, inst)) {
      thread_extent_x = thread_warp_size_;
      continue;
    }
    // Only cooperative-fetch annotations generate rewrite tasks.
    Optional<tir::BlockRV> opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane);
    if (!opt_block_rv.defined()) {
      continue;
    }
    // Capture the extents by value: the task must use the extents that were
    // current at this point of the trace, not the final ones.
    auto task = [thread_extent_x, thread_extent_y, vector_lane, sch,
                 block = opt_block_rv.value()]() mutable -> void {
      // The annotation is consumed here; remove it so it does not survive the rewrite.
      sch->Unannotate(block, tir::attr::meta_schedule_cooperative_fetch);
      tir::LoopRV fused = sch->GetLoops(block).back();
      int64_t fused_extent = -1;
      if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(fused).get())) {
        fused_extent = *extent;
      } else {
        // Non-constant loop extent: leave this block untouched.
        return;
      }
      // Vectorization requires the lane count to evenly divide the loop.
      if (fused_extent % vector_lane != 0) {
        vector_lane = 1;
      }
      // If the block involves 64 bit values, disable vectorization for now since
      // vectorization of 64 bit values does not work well on CUDA.
      // TODO(masahi, vinx13): Decouple epilogue fusion computation and shared to global store, so
      // that we can always vectorize the latter.
      if (tir::GetMaxUsedDtypeBytes(sch->Get(block)) > 4) {
        vector_lane = 1;
      }
      // Split the fused loop as [outer, (threadIdx.y,) threadIdx.x, (vector)]
      // and bind/vectorize the inner parts accordingly.
      if (thread_extent_y != -1) {
        if (vector_lane > 1) {
          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
                                                        Integer(thread_extent_y),  //
                                                        Integer(thread_extent_x),  //
                                                        Integer(vector_lane)});
          sch->Vectorize(split[3]);
          sch->Bind(split[2], "threadIdx.x");
          sch->Bind(split[1], "threadIdx.y");
        } else {
          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
                                                        Integer(thread_extent_y),  //
                                                        Integer(thread_extent_x)});
          sch->Bind(split[2], "threadIdx.x");
          sch->Bind(split[1], "threadIdx.y");
        }
      } else {
        if (vector_lane > 1) {
          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
                                                        Integer(thread_extent_x),  //
                                                        Integer(vector_lane)});
          sch->Vectorize(split[2]);
          sch->Bind(split[1], "threadIdx.x");
        } else {
          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)});
          sch->Bind(split[1], "threadIdx.x");
        }
      }
    };
    tasks.push_back(task);
  }
  // Replay the deferred rewrites.
  for (auto&& task : tasks) {
    task();
  }
  return true;
}
223 | |
224 | Postproc Postproc::RewriteCooperativeFetch() { |
225 | ObjectPtr<RewriteCooperativeFetchNode> n = make_object<RewriteCooperativeFetchNode>(); |
226 | return Postproc(n); |
227 | } |
228 | |
// Register the node type and expose the constructor to the FFI/Python frontend.
TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch")
    .set_body_typed(Postproc::RewriteCooperativeFetch);
232 | |
233 | } // namespace meta_schedule |
234 | } // namespace tvm |
235 | |