1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24/*!
25 * \brief Parse instruction: sch.bind(..., axis)
26 * \param sch The schedule
27 * \param inst The instruction to be parsed
28 * \param axis The axis name expected
29 * \return NullOpt if parsing fails; Otherwise, the extent of thread axis
30 */
31Optional<Integer> ParseThreadBinding(const Schedule& sch, const Instruction& inst, String axis) {
32 static InstructionKind inst_kind_bind = InstructionKind::Get("Bind");
33 if (!inst->kind.same_as(inst_kind_bind)) {
34 return NullOpt;
35 }
36 ICHECK_EQ(inst->inputs.size(), 1);
37 ICHECK_EQ(inst->attrs.size(), 1);
38 String thread_axis = Downcast<String>(inst->attrs[0]);
39 if (thread_axis != axis) {
40 return NullOpt;
41 }
42 return Downcast<Integer>(sch->Get(Downcast<LoopRV>(inst->inputs[0]))->extent);
43}
44
45/*!
46 * \brief Parse instruction: sch.annotate(..., attr::meta_schedule_cooperative_fetch)
47 * \param sch The schedule
48 * \param inst The instruction to be parsed
49 * \param vector_lane The number of vector lane in vectorized cooperative fetching
50 * \return NullOpt if parsing fails; Otherwise, the annotated block
51 */
52Optional<BlockRV> ParseAnnotate(const Schedule& sch, const Instruction& inst,
53 int64_t* vector_lane) {
54 static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate");
55 if (!inst->kind.same_as(inst_kind_annotate)) {
56 return NullOpt;
57 }
58 ICHECK_EQ(inst->inputs.size(), 2);
59 ICHECK_EQ(inst->attrs.size(), 1);
60 String ann_key = Downcast<String>(inst->attrs[0]);
61 if (ann_key != attr::meta_schedule_cooperative_fetch) {
62 return NullOpt;
63 }
64 *vector_lane = Downcast<Integer>(sch->Get(Downcast<ExprRV>(inst->inputs[1])))->value;
65 return Downcast<BlockRV>(inst->inputs[0]);
66}
67
68/*!
69 * \brief Parse instruction: sch.annotate(..., attr::warp_execution)
70 * \param sch The schedule
71 * \param inst The instruction to be parsed
72 * \return Whether ths parsing is successful
73 */
74bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) {
75 static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate");
76 if (!inst->kind.same_as(inst_kind_annotate)) {
77 return false;
78 }
79 ICHECK_EQ(inst->inputs.size(), 2);
80 ICHECK_EQ(inst->attrs.size(), 1);
81 String ann_key = Downcast<String>(inst->attrs[0]);
82 return ann_key == attr::warp_execution;
83}
84
85size_t GetMaxUsedDtypeBytes(Block block) {
86 size_t max_bytes = 1;
87 static auto q_multiply_shift_per_axis = Op::Get("tir.q_multiply_shift_per_axis");
88 static auto q_multiply_shift = Op::Get("tir.q_multiply_shift");
89
90 tir::PostOrderVisit(block->body, [&](const ObjectRef& obj) {
91 if (const auto* store = obj.as<tir::BufferStoreNode>()) {
92 max_bytes = std::max(max_bytes, static_cast<size_t>(store->value->dtype.bytes()));
93 } else if (const auto* load = obj.as<tir::BufferLoadNode>()) {
94 max_bytes = std::max(max_bytes, static_cast<size_t>(load->dtype.bytes()));
95 } else if (const auto* call = obj.as<tir::CallNode>()) {
96 if (call->op.same_as(q_multiply_shift_per_axis) || call->op.same_as(q_multiply_shift)) {
97 // q_multiply_shift uses 64 bit multiply
98 max_bytes = std::max<size_t>(max_bytes, 8);
99 }
100 } else if (const auto* cast = obj.as<tir::CastNode>()) {
101 max_bytes = std::max<size_t>(max_bytes, cast->dtype.bytes());
102 }
103 });
104
105 return max_bytes;
106}
107
108} // namespace tir
109
110namespace meta_schedule {
111
/*!
 * \brief Rewrite the cooperative fetch annotation to actual vectorized cooperative fetching
 * in loop bindings.
 */
class RewriteCooperativeFetchNode : public PostprocNode {
 public:
  // Inherited from PostprocNode
  // Caches the target's "thread_warp_size" attribute; a missing attribute is only logged,
  // leaving thread_warp_size_ at its -1 sentinel.
  // NOTE(review): `context->target.value()` assumes the target is defined on the tune
  // context — confirm all callers set it before initialization.
  void InitializeWithTuneContext(const TuneContext& context) final {
    if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("thread_warp_size")) {
      this->thread_warp_size_ = v.value()->value;
    } else {
      TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target";
    }
  }

  // Inherited from PostprocNode
  bool Apply(const tir::Schedule& sch) final;

  // Produce a copy of this postprocessor sharing no mutable state with the original.
  // NOTE(review): unlike the two overrides above, this carries no override specifier —
  // presumably it overrides a virtual in PostprocNode; verify against the base class.
  Postproc Clone() const {
    ObjectPtr<RewriteCooperativeFetchNode> n = make_object<RewriteCooperativeFetchNode>(*this);
    return Postproc(n);
  }

  // No reflectable attributes; thread_warp_size_ is runtime-only configuration.
  void VisitAttrs(tvm::AttrVisitor* v) {}

  static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch";
  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode);

 private:
  // Warp size read from the target; -1 means "not initialized / not available".
  int thread_warp_size_ = -1;
};
143
// Scan the schedule's trace once, tracking the most recent threadIdx.x/y extents, and for
// every cooperative-fetch annotation queue a rewrite task. Tasks are deferred and run after
// the scan completes — presumably so the trace is not mutated while being iterated; confirm
// against Schedule::trace() semantics.
bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
  tir::Trace trace = sch->trace().value();
  // Latest observed thread extents; -1 means "not seen yet".
  int64_t thread_extent_x = -1;
  int64_t thread_extent_y = -1;
  int64_t vector_lane = 1;
  std::vector<std::function<void()>> tasks;
  for (const tir::Instruction& inst : trace->insts) {
    // A Bind on threadIdx.x / threadIdx.y updates the extent used by later annotations.
    if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.x")) {
      thread_extent_x = new_thread_extent.value()->value;
      continue;
    }
    if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.y")) {
      thread_extent_y = new_thread_extent.value()->value;
      continue;
    }
    // Warp execution implies threadIdx.x spans one warp.
    if (tir::ParseWarpExecutionAnn(sch, inst)) {
      thread_extent_x = thread_warp_size_;
      continue;
    }
    // Only cooperative-fetch annotations produce a rewrite task; anything else is skipped.
    Optional<tir::BlockRV> opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane);
    if (!opt_block_rv.defined()) {
      continue;
    }
    // Capture by value (mutable) so each task snapshots the extents/lane seen at this point
    // in the trace, and may locally reset vector_lane without affecting other tasks.
    auto task = [thread_extent_x, thread_extent_y, vector_lane, sch,
                 block = opt_block_rv.value()]() mutable -> void {
      sch->Unannotate(block, tir::attr::meta_schedule_cooperative_fetch);
      // The innermost loop of the block is the fused loop to split and bind.
      tir::LoopRV fused = sch->GetLoops(block).back();
      int64_t fused_extent = -1;
      if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(fused).get())) {
        fused_extent = *extent;
      } else {
        // Symbolic extent: cannot split deterministically, leave the loop untouched.
        return;
      }
      // Vectorization requires the lane count to evenly divide the fused extent.
      if (fused_extent % vector_lane != 0) {
        vector_lane = 1;
      }
      // If the block involves 64 bit values, disable vectorization for now since
      // vectorization of 64 bit values does not work well on CUDA.
      // TODO(masahi, vinx13): Decouple epilogue fusion computation and shared to global store, so
      // that we can always vectorize the latter.
      if (tir::GetMaxUsedDtypeBytes(sch->Get(block)) > 4) {
        vector_lane = 1;
      }
      // Split the fused loop into [outer, (threadIdx.y,) threadIdx.x, (vector)] and bind/
      // vectorize the inner parts; NullOpt lets Split infer the outer extent.
      if (thread_extent_y != -1) {
        if (vector_lane > 1) {
          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
                                                        Integer(thread_extent_y),  //
                                                        Integer(thread_extent_x),  //
                                                        Integer(vector_lane)});
          sch->Vectorize(split[3]);
          sch->Bind(split[2], "threadIdx.x");
          sch->Bind(split[1], "threadIdx.y");
        } else {
          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
                                                        Integer(thread_extent_y),  //
                                                        Integer(thread_extent_x)});
          sch->Bind(split[2], "threadIdx.x");
          sch->Bind(split[1], "threadIdx.y");
        }
      } else {
        if (vector_lane > 1) {
          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
                                                        Integer(thread_extent_x),  //
                                                        Integer(vector_lane)});
          sch->Vectorize(split[2]);
          sch->Bind(split[1], "threadIdx.x");
        } else {
          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)});
          sch->Bind(split[1], "threadIdx.x");
        }
      }
    };
    tasks.push_back(task);
  }
  // Execute all queued rewrites now that the trace scan is finished.
  for (auto&& task : tasks) {
    task();
  }
  return true;
}
223
224Postproc Postproc::RewriteCooperativeFetch() {
225 ObjectPtr<RewriteCooperativeFetchNode> n = make_object<RewriteCooperativeFetchNode>();
226 return Postproc(n);
227}
228
// Register the node type and expose the factory to Python as
// "meta_schedule.PostprocRewriteCooperativeFetch".
TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch")
    .set_body_typed(Postproc::RewriteCooperativeFetch);
232
233} // namespace meta_schedule
234} // namespace tvm
235