1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include <optional> |
21 | |
22 | #include "../utils.h" |
23 | |
24 | namespace tvm { |
25 | namespace tir { |
26 | |
27 | /*! \brief The schedule error class when the padding size is invalid. */ |
28 | class InvalidPaddingError : public ScheduleError { |
29 | public: |
30 | InvalidPaddingError(IRModule mod, Block block, Array<Integer> padding) |
31 | : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {} |
32 | IRModule mod() const final { return mod_; } |
33 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
34 | String FastErrorString() const final { |
35 | return "ScheduleError: The padding size for the block is invalid." ; |
36 | } |
37 | String DetailRenderTemplate() const final { |
38 | std::ostringstream os; |
39 | os << "The padding for the block {0} are invalid. It should be a list of " |
40 | << block_->iter_vars.size() << " non-negative integers. Got " << padding_; |
41 | return os.str(); |
42 | } |
43 | |
44 | static void Check(const ScheduleState& self, const Block& block, Array<Integer> padding) { |
45 | if (padding.size() != block->iter_vars.size()) { |
46 | throw InvalidPaddingError(self->mod, block, padding); |
47 | } |
48 | for (const auto& pad : padding) { |
49 | if (pad->value < 0) { |
50 | throw InvalidPaddingError(self->mod, block, padding); |
51 | } |
52 | } |
53 | } |
54 | |
55 | private: |
56 | IRModule mod_; |
57 | Block block_; |
58 | Array<Integer> padding_; |
59 | }; |
60 | |
61 | /*! \brief The schedule error class when the block body is not an Einsum pattern. */ |
62 | class NonEinsumError : public ScheduleError { |
63 | public: |
64 | explicit NonEinsumError(IRModule mod, Block block) |
65 | : mod_(std::move(mod)), block_(std::move(block)) {} |
66 | |
67 | IRModule mod() const final { return mod_; } |
68 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
69 | String FastErrorString() const final { |
70 | return "ScheduleError: The block is not a computation of Einsum pattern." ; |
71 | } |
72 | String DetailRenderTemplate() const final { |
73 | return "The block {0} not a computation of Einsum pattern." ; |
74 | } |
75 | |
76 | private: |
77 | IRModule mod_; |
78 | Block block_; |
79 | }; |
80 | |
81 | /*! \brief Data structure that represents a Einsum computation. */ |
/*! \brief Data structure that represents a Einsum computation. */
struct Einsum {
  // The output buffer
  Buffer output_buffer;
  // The indices of the output buffer
  Array<Var> output_indices;
  // The indices of the input buffers, keyed by each input buffer.
  // Each index list holds plain block iter variables (no composite expressions).
  Map<Buffer, Array<Var>> input_indices;
};
90 | |
91 | class : public ExprVisitor { |
92 | public: |
93 | () = default; |
94 | |
95 | std::optional<Einsum> (const Block& block) { |
96 | const BufferStoreNode* update = block->body.as<BufferStoreNode>(); |
97 | // Step 1: Check the body is a BufferStore and the block has the init statement, and the |
98 | // BufferStore and the init statement store have the same output buffer indices. |
99 | if (update == nullptr || !block->init.defined()) { |
100 | return std::nullopt; |
101 | } |
102 | |
103 | if (Optional<Array<Var>> opt_indices = CheckTrivialBufferIndices(update); |
104 | opt_indices.defined()) { |
105 | ein_sum_.output_indices = std::move(opt_indices.value()); |
106 | } else { |
107 | return std::nullopt; |
108 | } |
109 | ein_sum_.output_buffer = update->buffer; |
110 | |
111 | const BufferStoreNode* init = block->init.value().as<BufferStoreNode>(); |
112 | ICHECK(init != nullptr); |
113 | if (!CompareBufferIndices(init->indices, ein_sum_.output_indices)) { |
114 | return std::nullopt; |
115 | } |
116 | // Step 2: Check the BufferStore updates the output buffer and the input buffers indices are |
117 | // block iter variables. |
118 | CheckStoreValue(update->value); |
119 | if (fail_) { |
120 | return std::nullopt; |
121 | } |
122 | return std::move(ein_sum_); |
123 | } |
124 | |
125 | private: |
126 | void (const PrimExpr& update) { |
127 | // Check the update part has the form: |
128 | // Output[output_indices] += Input_0[input_indices_0] op_0 Input_1[input_indices_1] op_1 ... |
129 | // where output_indices and input_indices_i are the indices are arrays whose elements are the |
130 | // block iter variables instead of composite PrimExpr, and op_i are the binary operations. |
131 | |
132 | // Check the value is Add and eithe LHS or RHS is the BufferLoad from the output buffer. |
133 | const AddNode* add = update.as<AddNode>(); |
134 | if (add == nullptr) { |
135 | fail_ = true; |
136 | return; |
137 | } |
138 | const BufferLoadNode* lhs = add->a.as<BufferLoadNode>(); |
139 | const BufferLoadNode* rhs = add->b.as<BufferLoadNode>(); |
140 | if (lhs == nullptr && rhs != nullptr) { |
141 | std::swap(lhs, rhs); |
142 | } |
143 | if (lhs == nullptr || !lhs->buffer.same_as(ein_sum_.output_buffer) || |
144 | !CompareBufferIndices(lhs->indices, ein_sum_.output_indices)) { |
145 | fail_ = true; |
146 | return; |
147 | } |
148 | VisitExpr(add->b); |
149 | } |
150 | |
151 | void (const PrimExpr& n) final { |
152 | if (n->IsInstance<BufferLoadNode>() || n->IsInstance<MulNode>() || n->IsInstance<CastNode>()) { |
153 | ExprVisitor::VisitExpr(n); |
154 | } else { |
155 | fail_ = true; |
156 | return; |
157 | } |
158 | } |
159 | |
160 | void (const BufferLoadNode* op) final { |
161 | if (auto it = ein_sum_.input_indices.find(op->buffer); |
162 | it != ein_sum_.input_indices.end() && !CompareBufferIndices(op->indices, (*it).second)) { |
163 | fail_ = true; |
164 | return; |
165 | } |
166 | if (Optional<Array<Var>> opt_indices = CheckTrivialBufferIndices(op); opt_indices.defined()) { |
167 | ein_sum_.input_indices.Set(op->buffer, std::move(opt_indices.value())); |
168 | } else { |
169 | fail_ = true; |
170 | return; |
171 | } |
172 | } |
173 | |
174 | void (const CastNode* op) { VisitExpr(op->value); } |
175 | |
176 | bool () { return fail_; } |
177 | |
178 | bool (const Array<PrimExpr>& indices, const Array<Var>& other) { |
179 | return std::equal(indices.begin(), indices.end(), other.begin(), other.end(), |
180 | [](const PrimExpr& a, const Var& b) { return a.same_as(b); }); |
181 | } |
182 | |
183 | Einsum ; |
184 | bool {false}; |
185 | }; |
186 | |
187 | Einsum (const ScheduleState& self, const Block& block) { |
188 | EinsumExtractor ; |
189 | std::optional<Einsum> einsum = extractor.Extract(block); |
190 | if (!einsum.has_value()) { |
191 | throw NonEinsumError(self->mod, block); |
192 | } |
193 | return einsum.value(); |
194 | } |
195 | |
196 | class BufferNotAllocatedInScopeError : public ScheduleError { |
197 | public: |
198 | explicit BufferNotAllocatedInScopeError(IRModule mod, Buffer buffer) |
199 | : mod_(std::move(mod)), buffer_(std::move(buffer)) {} |
200 | |
201 | String FastErrorString() const final { |
202 | return "ScheduleError: The buffer is not allocated as an intermediate buffer in current " |
203 | "PrimFunc." ; |
204 | } |
205 | |
206 | String DetailRenderTemplate() const final { |
207 | std::ostringstream os; |
208 | os << "The buffer " << buffer_->name |
209 | << " is not allocated as an intermediate buffer in current PrimFunc." ; |
210 | return os.str(); |
211 | } |
212 | |
213 | IRModule mod() const final { return mod_; } |
214 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
215 | |
216 | private: |
217 | IRModule mod_; |
218 | Buffer buffer_; |
219 | }; |
220 | |
221 | class PadEinsumRewriter : public ReplaceBufferMutator { |
222 | public: |
223 | PadEinsumRewriter(const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate, |
224 | Map<Var, PrimExpr> padded_iter_extents, const Map<Buffer, Buffer>& buffer_remap, |
225 | Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer) |
226 | : ReplaceBufferMutator(buffer_remap, block_sref_reuse), |
227 | producer_predicate_(producer_predicate), |
228 | padded_iter_extents_(padded_iter_extents), |
229 | analyzer_(analyzer) {} |
230 | using ReplaceBufferMutator::VisitExpr_; |
231 | using ReplaceBufferMutator::VisitStmt_; |
232 | |
233 | Stmt VisitStmt_(const ForNode* op) final { |
234 | For new_for = Downcast<For>(ReplaceBufferMutator::VisitStmt_(op)); |
235 | if (padded_iter_extents_.count(new_for->loop_var)) { |
236 | new_for.CopyOnWrite()->extent = padded_iter_extents_.at(new_for->loop_var); |
237 | } |
238 | return std::move(new_for); |
239 | } |
240 | |
241 | Block PadProducerBlock(Block block, const PrimExpr& predicate) { |
242 | BufferStore store = Downcast<BufferStore>(block->body); |
243 | store.CopyOnWrite()->value = |
244 | analyzer_->Simplify(if_then_else(predicate, store->value, make_zero(store->value.dtype()))); |
245 | block.CopyOnWrite()->body = std::move(store); |
246 | return block; |
247 | } |
248 | |
249 | Stmt VisitStmt_(const BlockNode* op) final { |
250 | Block old_block = GetRef<Block>(op); |
251 | Block new_block = Downcast<Block>(ReplaceBufferMutator::VisitStmt_(op)); |
252 | if (auto it = producer_predicate_.find(op); it != producer_predicate_.end()) { |
253 | new_block = PadProducerBlock(std::move(new_block), (*it).second); |
254 | } |
255 | |
256 | // Mutate block iters |
257 | Array<IterVar> new_iters; |
258 | bool changed = false; |
259 | for (const IterVar& iter : new_block->iter_vars) { |
260 | if (auto it = padded_iter_extents_.find(iter->var); it != padded_iter_extents_.end()) { |
261 | changed = true; |
262 | new_iters.push_back( |
263 | IterVar(Range::FromMinExtent(0, (*it).second), iter->var, iter->iter_type)); |
264 | } else { |
265 | new_iters.push_back(iter); |
266 | } |
267 | } |
268 | if (changed) { |
269 | new_block.CopyOnWrite()->iter_vars = std::move(new_iters); |
270 | } |
271 | if (!old_block.same_as(new_block)) { |
272 | block_sref_reuse_->Set(old_block, new_block); |
273 | } |
274 | return std::move(new_block); |
275 | } |
276 | |
277 | private: |
278 | const std::unordered_set<const BlockNode*> producer_blocks_; |
279 | const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate_; |
280 | const Map<Var, PrimExpr> padded_iter_extents_; |
281 | arith::Analyzer* analyzer_; |
282 | }; |
283 | |
284 | /*! \brief The schedule error class when the producer block cannot be padded. */ |
285 | class InvalidProducerError : public ScheduleError { |
286 | public: |
287 | explicit InvalidProducerError(IRModule mod, Block producer) |
288 | : mod_(std::move(mod)), producer_(std::move(producer)) {} |
289 | |
290 | String FastErrorString() const final { |
291 | return "ScheduleError: The producer block cannot be padded." ; |
292 | } |
293 | |
294 | String DetailRenderTemplate() const final { |
295 | std::ostringstream os; |
296 | os << "The producer block {0} cannot be padded. It should write to a single buffer and the " |
297 | "body should be a BufferStore." ; |
298 | return os.str(); |
299 | } |
300 | |
301 | IRModule mod() const final { return mod_; } |
302 | Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; } |
303 | |
304 | private: |
305 | IRModule mod_; |
306 | Buffer buffer_; |
307 | Block producer_; |
308 | }; |
309 | |
310 | void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integer>& padding) { |
311 | arith::Analyzer analyzer; |
312 | // Step 1: Input checking and error handling |
313 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
314 | BlockRealize realize = GetBlockRealize(self, block_sref); |
315 | |
316 | const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); |
317 | InvalidPaddingError::Check(self, GetRef<Block>(block), padding); |
318 | |
319 | const Array<StmtSRef> producers = GetProducers(self, block_sref); |
320 | { |
321 | auto f_check_block_properties = [&](const StmtSRef& block_sref, bool is_producer) { |
322 | CheckBlockHasTrivialBinding(self, block_sref); |
323 | if (is_producer) { |
324 | CheckCompleteBlock(self, block_sref, scope_sref); |
325 | } else { |
326 | CheckReductionBlock(self, block_sref, scope_sref); |
327 | } |
328 | Array loops = GetLoops(block_sref); |
329 | ICHECK(!loops.empty()); |
330 | CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front()); |
331 | }; |
332 | |
333 | // Check block properties of the computation block |
334 | f_check_block_properties(block_sref, false); |
335 | |
336 | // Check block properties of the producer block |
337 | for (const StmtSRef& producer_sref : producers) { |
338 | f_check_block_properties(producer_sref, true); |
339 | } |
340 | } |
341 | |
342 | Einsum einsum = ExtractEinsum(self, GetRef<Block>(block)); |
343 | |
344 | // Check input and output buffers are all allocated in the current scope. |
345 | { |
346 | auto f_check_buffer_allocated = [&](const Buffer& buffer) { |
347 | auto [defining_site_sref, is_allocate] = GetBufferDefiningSite(block_sref, buffer); |
348 | if (!defining_site_sref.defined() || !is_allocate) { |
349 | throw BufferNotAllocatedInScopeError(self->mod, buffer); |
350 | } |
351 | }; |
352 | f_check_buffer_allocated(einsum.output_buffer); |
353 | for (const auto& buffer_indices_pair : einsum.input_indices) { |
354 | f_check_buffer_allocated(buffer_indices_pair.first); |
355 | } |
356 | } |
357 | |
358 | // Step 2: Prepare buffer and variable remapping. Infer the new shape of the input and the output |
359 | // buffers. Infer the new extent of the block iters of the computation block and the producer |
360 | // block. |
361 | |
362 | Map<Var, PrimExpr> padded_iter_extents; // The new extents of both the block iters and loop vars |
363 | |
364 | // Convert the input padding array to a map from variables to the padded extents |
365 | for (int i = 0, n = padding.size(); i < n; ++i) { |
366 | const IterVar& iter = block->iter_vars[i]; |
367 | PrimExpr new_extent = |
368 | IntImm(iter->var->dtype, Downcast<Integer>(iter->dom->extent)->value + padding[i]->value); |
369 | padded_iter_extents.Set(iter->var, new_extent); |
370 | padded_iter_extents.Set(Downcast<Var>(realize->iter_values[i]), new_extent); |
371 | } |
372 | |
373 | Map<Buffer, Buffer> buffer_remap; // mapping from buffers to new buffers with padded shapes |
374 | |
375 | // Utility function to pad a buffer with the new shape |
376 | auto f_pad_buffer = [&padded_iter_extents](Buffer buffer, const Array<Var>& indices) -> Buffer { |
377 | Array<PrimExpr> new_shape; |
378 | for (const Var& index : indices) { |
379 | new_shape.push_back(padded_iter_extents.at(index)); |
380 | } |
381 | ICHECK_EQ(buffer->shape.size(), new_shape.size()); |
382 | buffer.CopyOnWrite()->shape = std::move(new_shape); |
383 | return buffer; |
384 | }; |
385 | |
386 | buffer_remap.Set(einsum.output_buffer, f_pad_buffer(einsum.output_buffer, einsum.output_indices)); |
387 | |
388 | std::unordered_map<const BlockNode*, PrimExpr> producer_predicate; |
389 | |
390 | // Different from the output block, the padding for the producer block is not directly specified |
391 | // as the input argument. Instead, it is inferred from indices of the producer buffer accessed in |
392 | // the output block. |
393 | // We will find the indices (which are block iters) in BufferStore to the producer buffer |
394 | // and infer the new extents of the block iters and the corresponding loop vars. |
395 | for (const StmtSRef& producer_sref : producers) { |
396 | const BlockNode* producer_block = TVM_SREF_TO_BLOCK(producer_sref); |
397 | const BufferStoreNode* buffer_store = producer_block->body.as<BufferStoreNode>(); |
398 | Optional<Array<Var>> producer_store_indices; |
399 | if (!buffer_store || producer_block->writes.size() != 1 || |
400 | !(producer_store_indices = CheckTrivialBufferIndices(buffer_store)).defined()) { |
401 | throw InvalidProducerError(self->mod, GetRef<Block>(producer_block)); |
402 | } |
403 | BlockRealize producer_realize = GetBlockRealize(self, producer_sref); |
404 | |
405 | const Buffer& old_buffer = producer_block->writes[0]->buffer; |
406 | Buffer new_buffer = f_pad_buffer(old_buffer, einsum.input_indices.at(old_buffer)); |
407 | buffer_remap.Set(old_buffer, new_buffer); |
408 | |
409 | // The predicate to ensure the producer block is in the original bound before padding |
410 | PrimExpr predicate = Bool(true); |
411 | Map<Var, PrimExpr> indices_to_padded_extents; // buffer indices to padded extents |
412 | for (int i = 0, n = producer_store_indices.value().size(); i < n; ++i) { |
413 | const Var& index = producer_store_indices.value()[i]; |
414 | PrimExpr padded_extent = new_buffer->shape[i]; |
415 | if (!analyzer.CanProveEqual(padded_extent, old_buffer->shape[i])) { |
416 | predicate = predicate && (index < old_buffer->shape[i]); |
417 | } |
418 | indices_to_padded_extents.Set(index, padded_extent); |
419 | } |
420 | |
421 | for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) { |
422 | const IterVar& iter = producer_block->iter_vars[i]; |
423 | if (auto it = indices_to_padded_extents.find(iter->var); |
424 | it != indices_to_padded_extents.end()) { |
425 | const PrimExpr& padded_extent = (*it).second; |
426 | padded_iter_extents.Set(iter->var, padded_extent); |
427 | padded_iter_extents.Set(Downcast<Var>(producer_realize->iter_values[i]), padded_extent); |
428 | } else if (!is_one(iter->dom->extent)) { |
429 | throw InvalidProducerError(self->mod, GetRef<Block>(producer_block)); |
430 | } |
431 | } |
432 | producer_predicate[producer_block] = predicate; |
433 | } |
434 | |
435 | // Step 3: Mutate the AST subtree with the new buffers and the new block iter extents. |
436 | Map<Block, Block> block_sref_reuse; |
437 | PadEinsumRewriter rewriter(producer_predicate, padded_iter_extents, buffer_remap, |
438 | &block_sref_reuse, &analyzer); |
439 | const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); |
440 | Stmt new_scope_block = rewriter(GetRef<Block>(scope_block)); |
441 | |
442 | // Step 4: Do the actual replacement. |
443 | self->Replace(scope_sref, new_scope_block, block_sref_reuse); |
444 | } |
445 | |
446 | /******** Instruction Registration ********/ |
447 | |
/*!
 * \brief Instruction traits for the PadEinsum schedule primitive, used to replay and
 *        serialize the instruction in a schedule trace.
 */
struct PadEinsumTraits : public UnpackedInstTraits<PadEinsumTraits> {
  static constexpr const char* kName = "PadEinsum";
  // Not pure: applying the instruction mutates the schedule state.
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;    // the block RV
  static constexpr size_t kNumAttrs = 1;     // the padding array
  static constexpr size_t kNumDecisions = 0;

  // Replay the instruction on a concrete schedule.
  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Array<Integer> padding) {
    sch->PadEinsum(block, padding);
  }

  // Render the instruction as a Python statement for trace round-tripping.
  static String UnpackedAsPython(Array<String> outputs, String block, Array<Integer> padding) {
    PythonAPICall py("pad_einsum");
    py.Input("block", block);
    py.Input("padding", padding);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(PadEinsumTraits);
473 | |
474 | } // namespace tir |
475 | } // namespace tvm |
476 | |