/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <optional>

#include "../utils.h"

namespace tvm {
namespace tir {

/*! \brief The schedule error class when the padding size is invalid. */
class InvalidPaddingError : public ScheduleError {
 public:
  InvalidPaddingError(IRModule mod, Block block, Array<Integer> padding)
      : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {}
  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
  String FastErrorString() const final {
    return "ScheduleError: The padding size for the block is invalid.";
  }
  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "The padding for the block {0} is invalid. It should be a list of "
       << block_->iter_vars.size() << " non-negative integers. Got " << padding_;
    return os.str();
  }

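  // Validates the padding: it must contain one entry per block iter, and every entry must be
  // non-negative. For instance (illustrative), a block with iters (i, j, k) accepts a padding
  // such as [0, 0, 32], while [0, 32] or [0, -1, 32] would be rejected.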
  static void Check(const ScheduleState& self, const Block& block, Array<Integer> padding) {
    if (padding.size() != block->iter_vars.size()) {
      throw InvalidPaddingError(self->mod, block, padding);
    }
    for (const auto& pad : padding) {
      if (pad->value < 0) {
        throw InvalidPaddingError(self->mod, block, padding);
      }
    }
  }

 private:
  IRModule mod_;
  Block block_;
  Array<Integer> padding_;
};

/*! \brief The schedule error class when the block body is not an Einsum pattern. */
class NonEinsumError : public ScheduleError {
 public:
  explicit NonEinsumError(IRModule mod, Block block)
      : mod_(std::move(mod)), block_(std::move(block)) {}

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
  String FastErrorString() const final {
    return "ScheduleError: The block is not a computation of the Einsum pattern.";
  }
  String DetailRenderTemplate() const final {
    return "The block {0} is not a computation of the Einsum pattern.";
  }

 private:
  IRModule mod_;
  Block block_;
};

/*! \brief Data structure that represents an Einsum computation. */
struct Einsum {
  // The output buffer
  Buffer output_buffer;
  // The indices of the output buffer
  Array<Var> output_indices;
  // The indices of the input buffers
  Map<Buffer, Array<Var>> input_indices;
};
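
// For illustration (a hypothetical matmul block, not taken from this file): given
//   C[i, j] = C[i, j] + A[i, k] * B[k, j]
// the extracted Einsum would be:
//   output_buffer = C, output_indices = [i, j],
//   input_indices = {A: [i, k], B: [k, j]}.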

class EinsumExtractor : public ExprVisitor {
 public:
  EinsumExtractor() = default;

  std::optional<Einsum> Extract(const Block& block) {
    const BufferStoreNode* update = block->body.as<BufferStoreNode>();
    // Step 1: Check that the body is a BufferStore, that the block has an init statement, and
    // that the BufferStore and the init statement write to the same output buffer indices.
    if (update == nullptr || !block->init.defined()) {
      return std::nullopt;
    }

    if (Optional<Array<Var>> opt_indices = CheckTrivialBufferIndices(update);
        opt_indices.defined()) {
      ein_sum_.output_indices = std::move(opt_indices.value());
    } else {
      return std::nullopt;
    }
    ein_sum_.output_buffer = update->buffer;

    const BufferStoreNode* init = block->init.value().as<BufferStoreNode>();
    ICHECK(init != nullptr);
    if (!CompareBufferIndices(init->indices, ein_sum_.output_indices)) {
      return std::nullopt;
    }
    // Step 2: Check that the BufferStore updates the output buffer and that the indices of the
    // input buffers are block iter variables.
    CheckStoreValue(update->value);
    if (fail_) {
      return std::nullopt;
    }
    return std::move(ein_sum_);
  }

 private:
  void CheckStoreValue(const PrimExpr& update) {
    // Check that the update part has the form:
    //   Output[output_indices] += Input_0[input_indices_0] op_0 Input_1[input_indices_1] op_1 ...
    // where output_indices and input_indices_i are arrays whose elements are block iter
    // variables (rather than composite PrimExprs), and op_i are binary operations.
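    // For example (illustrative), C[i, j] = C[i, j] + A[i, k] * B[k, j] matches this form,
    // whereas C[i, j] = C[i, j] + A[i, k + 1] * B[k, j] does not: the index `k + 1` is a
    // composite expression rather than a plain block iter variable.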

    // Check that the value is an Add and that either the LHS or the RHS is the BufferLoad from
    // the output buffer.
    const AddNode* add = update.as<AddNode>();
    if (add == nullptr) {
      fail_ = true;
      return;
    }
    const BufferLoadNode* lhs = add->a.as<BufferLoadNode>();
    const BufferLoadNode* rhs = add->b.as<BufferLoadNode>();
    // `value` is the operand that computes the update; keep it in sync with the swap below so
    // that the output BufferLoad is never visited as if it were an input.
    PrimExpr value = add->b;
    if (lhs == nullptr && rhs != nullptr) {
      std::swap(lhs, rhs);
      value = add->a;
    }
    if (lhs == nullptr || !lhs->buffer.same_as(ein_sum_.output_buffer) ||
        !CompareBufferIndices(lhs->indices, ein_sum_.output_indices)) {
      fail_ = true;
      return;
    }
    VisitExpr(value);
  }

  void VisitExpr(const PrimExpr& n) final {
    if (n->IsInstance<BufferLoadNode>() || n->IsInstance<MulNode>() || n->IsInstance<CastNode>()) {
      ExprVisitor::VisitExpr(n);
    } else {
      fail_ = true;
    }
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    if (auto it = ein_sum_.input_indices.find(op->buffer);
        it != ein_sum_.input_indices.end() && !CompareBufferIndices(op->indices, (*it).second)) {
      fail_ = true;
      return;
    }
    if (Optional<Array<Var>> opt_indices = CheckTrivialBufferIndices(op); opt_indices.defined()) {
      ein_sum_.input_indices.Set(op->buffer, std::move(opt_indices.value()));
    } else {
      fail_ = true;
    }
  }

  void VisitExpr_(const CastNode* op) final { VisitExpr(op->value); }

  bool Fail() const { return fail_; }

  bool CompareBufferIndices(const Array<PrimExpr>& indices, const Array<Var>& other) {
    return std::equal(indices.begin(), indices.end(), other.begin(), other.end(),
                      [](const PrimExpr& a, const Var& b) { return a.same_as(b); });
  }

  Einsum ein_sum_;
  bool fail_{false};
};

Einsum ExtractEinsum(const ScheduleState& self, const Block& block) {
  EinsumExtractor extractor;
  std::optional<Einsum> einsum = extractor.Extract(block);
  if (!einsum.has_value()) {
    throw NonEinsumError(self->mod, block);
  }
  return einsum.value();
}

class BufferNotAllocatedInScopeError : public ScheduleError {
 public:
  explicit BufferNotAllocatedInScopeError(IRModule mod, Buffer buffer)
      : mod_(std::move(mod)), buffer_(std::move(buffer)) {}

  String FastErrorString() const final {
    return "ScheduleError: The buffer is not allocated as an intermediate buffer in the current "
           "PrimFunc.";
  }

  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "The buffer " << buffer_->name
       << " is not allocated as an intermediate buffer in the current PrimFunc.";
    return os.str();
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {}; }

 private:
  IRModule mod_;
  Buffer buffer_;
};

class PadEinsumRewriter : public ReplaceBufferMutator {
 public:
  PadEinsumRewriter(const std::unordered_map<const BlockNode*, PrimExpr>& producer_predicate,
                    Map<Var, PrimExpr> padded_iter_extents, const Map<Buffer, Buffer>& buffer_remap,
                    Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer)
      : ReplaceBufferMutator(buffer_remap, block_sref_reuse),
        producer_predicate_(producer_predicate),
        padded_iter_extents_(std::move(padded_iter_extents)),
        analyzer_(analyzer) {}
  using ReplaceBufferMutator::VisitExpr_;
  using ReplaceBufferMutator::VisitStmt_;

  Stmt VisitStmt_(const ForNode* op) final {
    For new_for = Downcast<For>(ReplaceBufferMutator::VisitStmt_(op));
    if (padded_iter_extents_.count(new_for->loop_var)) {
      // Enlarge the loop extent to the padded extent.
      new_for.CopyOnWrite()->extent = padded_iter_extents_.at(new_for->loop_var);
    }
    return std::move(new_for);
  }

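  // Wraps the producer's stored value in a predicate so that the padded region is zero-filled.
  // For example (illustrative), a producer B[i] = A[i] whose extent is padded from 127 to 128
  // becomes B[i] = if_then_else(i < 127, A[i], 0).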
  Block PadProducerBlock(Block block, const PrimExpr& predicate) {
    BufferStore store = Downcast<BufferStore>(block->body);
    store.CopyOnWrite()->value =
        analyzer_->Simplify(if_then_else(predicate, store->value, make_zero(store->value.dtype())));
    block.CopyOnWrite()->body = std::move(store);
    return block;
  }

  Stmt VisitStmt_(const BlockNode* op) final {
    Block old_block = GetRef<Block>(op);
    Block new_block = Downcast<Block>(ReplaceBufferMutator::VisitStmt_(op));
    if (auto it = producer_predicate_.find(op); it != producer_predicate_.end()) {
      new_block = PadProducerBlock(std::move(new_block), (*it).second);
    }

    // Mutate the block iters whose extents are padded.
    Array<IterVar> new_iters;
    bool changed = false;
    for (const IterVar& iter : new_block->iter_vars) {
      if (auto it = padded_iter_extents_.find(iter->var); it != padded_iter_extents_.end()) {
        changed = true;
        new_iters.push_back(
            IterVar(Range::FromMinExtent(0, (*it).second), iter->var, iter->iter_type));
      } else {
        new_iters.push_back(iter);
      }
    }
    if (changed) {
      new_block.CopyOnWrite()->iter_vars = std::move(new_iters);
    }
    if (!old_block.same_as(new_block)) {
      block_sref_reuse_->Set(old_block, new_block);
    }
    return std::move(new_block);
  }

 private:
  const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate_;
  const Map<Var, PrimExpr> padded_iter_extents_;
  arith::Analyzer* analyzer_;
};

/*! \brief The schedule error class when the producer block cannot be padded. */
class InvalidProducerError : public ScheduleError {
 public:
  explicit InvalidProducerError(IRModule mod, Block producer)
      : mod_(std::move(mod)), producer_(std::move(producer)) {}

  String FastErrorString() const final {
    return "ScheduleError: The producer block cannot be padded.";
  }

  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "The producer block {0} cannot be padded. It should write to a single buffer and its "
          "body should be a BufferStore.";
    return os.str();
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; }

 private:
  IRModule mod_;
  Block producer_;
};

void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integer>& padding) {
  arith::Analyzer analyzer;
  // Step 1: Input checking and error handling
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  BlockRealize realize = GetBlockRealize(self, block_sref);

  const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
  InvalidPaddingError::Check(self, GetRef<Block>(block), padding);

  const Array<StmtSRef> producers = GetProducers(self, block_sref);
  {
    auto f_check_block_properties = [&](const StmtSRef& block_sref, bool is_producer) {
      CheckBlockHasTrivialBinding(self, block_sref);
      if (is_producer) {
        CheckCompleteBlock(self, block_sref, scope_sref);
      } else {
        CheckReductionBlock(self, block_sref, scope_sref);
      }
      Array<StmtSRef> loops = GetLoops(block_sref);
      ICHECK(!loops.empty());
      CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front());
    };

    // Check block properties of the computation block
    f_check_block_properties(block_sref, false);

    // Check block properties of the producer blocks
    for (const StmtSRef& producer_sref : producers) {
      f_check_block_properties(producer_sref, true);
    }
  }

  Einsum einsum = ExtractEinsum(self, GetRef<Block>(block));

  // Check that the input and output buffers are all allocated in the current scope.
  {
    auto f_check_buffer_allocated = [&](const Buffer& buffer) {
      auto [defining_site_sref, is_allocate] = GetBufferDefiningSite(block_sref, buffer);
      if (!defining_site_sref.defined() || !is_allocate) {
        throw BufferNotAllocatedInScopeError(self->mod, buffer);
      }
    };
    f_check_buffer_allocated(einsum.output_buffer);
    for (const auto& buffer_indices_pair : einsum.input_indices) {
      f_check_buffer_allocated(buffer_indices_pair.first);
    }
  }

  // Step 2: Prepare the buffer and variable remapping. Infer the new shapes of the input and
  // output buffers, and the new extents of the block iters of the computation block and the
  // producer blocks.

  Map<Var, PrimExpr> padded_iter_extents;  // The new extents of both the block iters and loop vars

  // Convert the input padding array to a map from variables to the padded extents
  for (int i = 0, n = padding.size(); i < n; ++i) {
    const IterVar& iter = block->iter_vars[i];
    PrimExpr new_extent =
        IntImm(iter->var->dtype, Downcast<Integer>(iter->dom->extent)->value + padding[i]->value);
    padded_iter_extents.Set(iter->var, new_extent);
    padded_iter_extents.Set(Downcast<Var>(realize->iter_values[i]), new_extent);
  }
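
  // Illustrative example: for block iters (i, j, k) with extents (127, 127, 127) and padding
  // [1, 1, 1], padded_iter_extents now maps each iter var, and the loop var bound to it, to
  // the padded extent 128.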

  Map<Buffer, Buffer> buffer_remap;  // mapping from buffers to new buffers with padded shapes

  // Utility function to pad a buffer to the new shape
  auto f_pad_buffer = [&padded_iter_extents](Buffer buffer, const Array<Var>& indices) -> Buffer {
    Array<PrimExpr> new_shape;
    for (const Var& index : indices) {
      new_shape.push_back(padded_iter_extents.at(index));
    }
    ICHECK_EQ(buffer->shape.size(), new_shape.size());
    buffer.CopyOnWrite()->shape = std::move(new_shape);
    return buffer;
  };

  buffer_remap.Set(einsum.output_buffer, f_pad_buffer(einsum.output_buffer, einsum.output_indices));

  std::unordered_map<const BlockNode*, PrimExpr> producer_predicate;

  // Unlike the output block, the padding for a producer block is not directly specified as an
  // input argument. Instead, it is inferred from the indices with which the producer buffer is
  // accessed in the output block: we find the indices (which are block iters) in the BufferStore
  // to the producer buffer and infer the new extents of the block iters and the corresponding
  // loop vars.
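  // Illustrative example: if the einsum block reads A[i, k] and both i and k are padded from
  // 127 to 128, the producer writing A is padded to produce a 128x128 buffer, and the predicate
  // (i < 127 && k < 127) later guards its stores so that the padded region is zero-filled.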
  for (const StmtSRef& producer_sref : producers) {
    const BlockNode* producer_block = TVM_SREF_TO_BLOCK(producer_sref);
    const BufferStoreNode* buffer_store = producer_block->body.as<BufferStoreNode>();
    Optional<Array<Var>> producer_store_indices;
    if (!buffer_store || producer_block->writes.size() != 1 ||
        !(producer_store_indices = CheckTrivialBufferIndices(buffer_store)).defined()) {
      throw InvalidProducerError(self->mod, GetRef<Block>(producer_block));
    }
    BlockRealize producer_realize = GetBlockRealize(self, producer_sref);

    const Buffer& old_buffer = producer_block->writes[0]->buffer;
    Buffer new_buffer = f_pad_buffer(old_buffer, einsum.input_indices.at(old_buffer));
    buffer_remap.Set(old_buffer, new_buffer);

    // The predicate that keeps the producer block writing within the original (unpadded) bounds
    PrimExpr predicate = Bool(true);
    Map<Var, PrimExpr> indices_to_padded_extents;  // buffer indices to padded extents
    for (int i = 0, n = producer_store_indices.value().size(); i < n; ++i) {
      const Var& index = producer_store_indices.value()[i];
      PrimExpr padded_extent = new_buffer->shape[i];
      if (!analyzer.CanProveEqual(padded_extent, old_buffer->shape[i])) {
        predicate = predicate && (index < old_buffer->shape[i]);
      }
      indices_to_padded_extents.Set(index, padded_extent);
    }

    for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) {
      const IterVar& iter = producer_block->iter_vars[i];
      if (auto it = indices_to_padded_extents.find(iter->var);
          it != indices_to_padded_extents.end()) {
        const PrimExpr& padded_extent = (*it).second;
        padded_iter_extents.Set(iter->var, padded_extent);
        padded_iter_extents.Set(Downcast<Var>(producer_realize->iter_values[i]), padded_extent);
      } else if (!is_one(iter->dom->extent)) {
        throw InvalidProducerError(self->mod, GetRef<Block>(producer_block));
      }
    }
    producer_predicate[producer_block] = predicate;
  }

  // Step 3: Mutate the AST subtree with the new buffers and the new block iter extents.
  Map<Block, Block> block_sref_reuse;
  PadEinsumRewriter rewriter(producer_predicate, padded_iter_extents, buffer_remap,
                             &block_sref_reuse, &analyzer);
  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
  Stmt new_scope_block = rewriter(GetRef<Block>(scope_block));

  // Step 4: Do the actual replacement.
  self->Replace(scope_sref, new_scope_block, block_sref_reuse);
}
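
// End-to-end sketch (illustrative, assuming a 127x127x127 matmul block padded by [1, 1, 1]):
// the loop nests of the matmul block and its producers are extended to extent 128, the input
// and output buffers are rewritten to 128x128 shapes, and the producers zero-fill the padded
// region, so the values in the original 127x127 region are unchanged.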

/******** Instruction Registration ********/

struct PadEinsumTraits : public UnpackedInstTraits<PadEinsumTraits> {
  static constexpr const char* kName = "PadEinsum";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 1;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Array<Integer> padding) {
    sch->PadEinsum(block, padding);
  }

  static String UnpackedAsPython(Array<String> outputs, String block, Array<Integer> padding) {
    PythonAPICall py("pad_einsum");
    py.Input("block", block);
    py.Input("padding", padding);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(PadEinsumTraits);
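
// Example usage from the Python side (illustrative; `block` is a BlockRV obtained e.g. via
// sch.get_block, and `padding` has one entry per block iter):
//   sch.pad_einsum(block, padding=[0, 1, 1])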

}  // namespace tir
}  // namespace tvm