1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24class StorageAlignAxisOutOfRangeError : public ScheduleError {
25 public:
26 explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis)
27 : mod_(std::move(mod)), buffer_(std::move(buffer)), axis_(axis) {}
28
29 String FastErrorString() const final {
30 return "ScheduleError: The input `axis` is out of range. It is required to be in range "
31 "[-ndim, ndim) where `ndim` is the number of dimensions of the buffer to set "
32 "storage alignment.";
33 }
34
35 String DetailRenderTemplate() const final {
36 std::ostringstream os;
37 int ndim = static_cast<int>(buffer_->shape.size());
38 os << "The buffer to set storage alignment of, " << buffer_->name << ", has " << ndim
39 << " dimension(s), so `axis` is required to be in [" << -(ndim) << ", " << ndim
40 << ") for storage_align. However, the input `axis` is " << axis_
41 << ", which is out of the expected range.";
42 return os.str();
43 }
44
45 IRModule mod() const final { return mod_; }
46 Array<ObjectRef> LocationsOfInterest() const final { return {}; }
47
48 static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int axis) {
49 int ndim = static_cast<int>(buffer->shape.size());
50 if (axis < -ndim || axis >= ndim) {
51 throw StorageAlignAxisOutOfRangeError(mod, buffer, axis);
52 }
53 // If axis is negative, convert it to a non-negative one.
54 if (axis < 0) {
55 axis += ndim;
56 }
57 return axis;
58 }
59
60 private:
61 IRModule mod_;
62 Buffer buffer_;
63 int axis_;
64};
65
66class NonAllocatedBufferError : public ScheduleError {
67 public:
68 explicit NonAllocatedBufferError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {}
69
70 String FastErrorString() const final {
71 return "ScheduleError: The input buffer is not allocated by a block. This means the buffer is "
72 " either a function parameter or defined in `match_buffer` of a block.";
73 }
74
75 String DetailRenderTemplate() const final {
76 std::ostringstream os;
77 os << "The input buffer " << buffer_->name
78 << " is not allocated by a block. This means the buffer is either a function parameter or "
79 "defined in `match_buffer` of a block.";
80 return os.str();
81 }
82
83 static StmtSRef CheckAndGetBufferAllocationSite(const IRModule& mod, const StmtSRef& block_sref,
84 const Buffer& buffer) {
85 auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, buffer);
86 if (!defining_site_sref.defined() || !is_alloc) {
87 throw NonAllocatedBufferError(mod, buffer);
88 }
89
90 return defining_site_sref.value();
91 }
92
93 Array<ObjectRef> LocationsOfInterest() const final { return {}; }
94 IRModule mod() const final { return mod_; }
95
96 private:
97 IRModule mod_;
98 Buffer buffer_;
99};
100
101class StorageAlignInvalidFactorError : public ScheduleError {
102 public:
103 explicit StorageAlignInvalidFactorError(IRModule mod, int factor)
104 : mod_(std::move(mod)), factor_(factor) {}
105
106 String FastErrorString() const final {
107 return "ScheduleError: The input `factor` of storage_align is expected to be a positive "
108 "number.";
109 }
110
111 String DetailRenderTemplate() const final {
112 std::ostringstream os;
113 os << "The input `factor` of storage_align is expected to be a positive number. However, the "
114 "input `factor` is "
115 << factor_ << ", which is out of the expected range.";
116 return os.str();
117 }
118
119 static void Check(const IRModule& mod, int factor) {
120 if (factor <= 0) {
121 throw StorageAlignInvalidFactorError(mod, factor);
122 }
123 }
124
125 Array<ObjectRef> LocationsOfInterest() const final { return {}; }
126 IRModule mod() const final { return mod_; }
127
128 private:
129 IRModule mod_;
130 int factor_;
131};
132
133class StorageAlignInvalidAnnotationError : public ScheduleError {
134 public:
135 explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block)
136 : mod_(std::move(mod)), block_(std::move(block)) {}
137
138 String FastErrorString() const final {
139 return "ScheduleError: The block annotation for storage align is expected to be an array of "
140 "4-integer-tuples (buffer_index, axis, factor, offset).";
141 }
142
143 String DetailRenderTemplate() const final {
144 std::ostringstream os;
145 os << "The block annotation for storage align is expected to be an array of 4-integer-tuples "
146 "(buffer_index, axis, factor, offset). However, the block annotation with key "
147 << attr::buffer_dim_align << " of the block {0} is "
148 << block_->annotations.at(attr::buffer_dim_align) << ", which is unexpected.";
149 return os.str();
150 }
151
152 static StorageAlignAnnotation CheckAndGetAnnotation(const IRModule& mod, const Block& block) {
153 // Get existing annotation value.
154 auto it = block->annotations.find(attr::buffer_dim_align);
155 if (it != block->annotations.end()) {
156 if (!IsValidAnnotation(block, (*it).second)) {
157 throw StorageAlignInvalidAnnotationError(mod, block);
158 }
159 return Downcast<StorageAlignAnnotation>((*it).second);
160 }
161
162 // Create new annotation value
163 StorageAlignAnnotation storage_align_annotation;
164 return storage_align_annotation;
165 }
166
167 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
168 IRModule mod() const final { return mod_; }
169
170 private:
171 static bool IsValidAnnotation(const Block& block, const ObjectRef& anno_value) {
172 if (!anno_value->IsInstance<ArrayNode>()) {
173 return false;
174 }
175 auto storage_align_annotations = Downcast<Array<ObjectRef>>(anno_value);
176 for (const ObjectRef& storage_align_annotation : storage_align_annotations) {
177 if (!storage_align_annotation->IsInstance<ArrayNode>()) {
178 return false;
179 }
180 auto storage_align_tuple = Downcast<Array<ObjectRef>>(storage_align_annotation);
181 // Check if the annotation is a 4-tuple.
182 if (storage_align_tuple.size() != 4) {
183 return false;
184 }
185 for (const ObjectRef& tuple_element : storage_align_tuple) {
186 if (!tuple_element->IsInstance<IntImmNode>()) {
187 return false;
188 }
189 }
190 }
191 return true;
192 }
193
194 IRModule mod_;
195 Block block_;
196};
197
198/*!
199 * \brief A helper mutator which recursively mutates the old buffer's storage scope and collects
200 * the block sref reuse information for the following replacement.
201 */
202class StorageScopeMutator : private ReplaceBufferMutator {
203 public:
204 /*!
205 * \param allocate_site The block where `old_buffer` was allocated.
206 * \param old_buffer The old buffer
207 * \param storage_scope The storage scope to be set
208 * \param block_sref_reuse The block sref reuse map to be updated
209 * \return The new block after the mutation
210 */
211 static Block Mutate(const Block& allocate_site, const Buffer& old_buffer,
212 const String& storage_scope, Map<Block, Block>* block_sref_reuse) {
213 Buffer new_buffer = WithScope(old_buffer, storage_scope);
214 StorageScopeMutator mutator(old_buffer, new_buffer, storage_scope, block_sref_reuse);
215 Stmt new_block = mutator.VisitStmt(allocate_site);
216 return Downcast<Block>(new_block);
217 }
218
219 private:
220 StorageScopeMutator(const Buffer& old_buffer, Buffer new_buffer, String storage_scope,
221 Map<Block, Block>* block_sref_reuse)
222 : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse) {}
223
224 MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final {
225 auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
226 if (it != buffer_var_map_.end()) {
227 Buffer new_target_buffer = WithScope(match_buffer->buffer, it->second.scope());
228 buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
229 return MatchBufferRegion(new_target_buffer,
230 BufferRegion(it->second, match_buffer->source->region));
231 } else {
232 return match_buffer;
233 }
234 }
235};
236
237void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis,
238 int factor, int offset) {
239 const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
240 Buffer buffer =
241 GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, BufferIndexType::kWrite);
242 StorageAlignInvalidFactorError::Check(self->mod, factor);
243 axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis);
244 NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer);
245
246 // Step 1: Get existing or create new annotation value.
247 StorageAlignAnnotation storage_align_annotation =
248 StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod,
249 GetRef<Block>(block_ptr));
250
251 // Step 2: Update the annotation value
252 bool found = false;
253 StorageAlignTuple new_storage_align_tuple{Integer(buffer_index), Integer(axis), Integer(factor),
254 Integer(offset)};
255 for (size_t j = 0; j < storage_align_annotation.size(); ++j) {
256 const auto& storage_align_tuple = storage_align_annotation[j];
257 ICHECK(storage_align_tuple.size() == 4);
258 if (storage_align_tuple[0] == buffer_index && storage_align_tuple[1] == axis) {
259 storage_align_annotation.Set(j, std::move(new_storage_align_tuple));
260 found = true;
261 break;
262 }
263 }
264 if (!found) {
265 storage_align_annotation.push_back(std::move(new_storage_align_tuple));
266 }
267
268 // Step 3: Replace the block with the new annotation
269 Block new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation);
270 self->Replace(block_sref, new_block, {{GetRef<Block>(block_ptr), new_block}});
271}
272
273void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
274 const String& storage_scope) {
275 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
276 Buffer buffer =
277 GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kWrite);
278
279 // Step 1. If `storage_scope` equals the original storage scope of the buffer, just return.
280 if (buffer.scope() == storage_scope) {
281 return;
282 }
283
284 // Step 2. Throw an error if the input storage scope is invalid.
285 CheckStorageScope(self, storage_scope);
286
287 // Step 3. Get the allocation site of the target buffer.
288 StmtSRef alloc_site_sref =
289 NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer);
290 const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref);
291
292 // Step 4. Recursively replace the old buffer to a new buffer, where the new buffer has the given
293 // storage scope. In the meanwhile, collect the block sref reuse information.
294 Map<Block, Block> block_reuse_map;
295 Block new_block = StorageScopeMutator::Mutate(GetRef<Block>(alloc_site), buffer, storage_scope,
296 &block_reuse_map);
297 self->Replace(alloc_site_sref, new_block, block_reuse_map);
298}
299
300/******** InstructionKind Registration ********/
301
302struct StorageAlignTraits : public UnpackedInstTraits<StorageAlignTraits> {
303 static constexpr const char* kName = "StorageAlign";
304 static constexpr bool kIsPure = false;
305
306 private:
307 static constexpr size_t kNumInputs = 1;
308 static constexpr size_t kNumAttrs = 4;
309 static constexpr size_t kNumDecisions = 0;
310
311 static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index,
312 Integer axis, Integer factor, Integer offset) {
313 return sch->StorageAlign(block_rv, buffer_index->value, axis->value, factor->value,
314 offset->value);
315 }
316
317 static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index,
318 Integer axis, Integer factor, Integer offset) {
319 PythonAPICall py("storage_align");
320 py.Input("block", block_rv);
321 py.Input("buffer_index", buffer_index);
322 py.Input("axis", axis);
323 py.Input("factor", factor);
324 py.Input("offset", offset);
325 return py.Str();
326 }
327
328 template <typename>
329 friend struct ::tvm::tir::UnpackedInstTraits;
330};
331
332struct SetScopeTraits : public UnpackedInstTraits<SetScopeTraits> {
333 static constexpr const char* kName = "SetScope";
334 static constexpr bool kIsPure = false;
335
336 private:
337 static constexpr size_t kNumInputs = 1;
338 static constexpr size_t kNumAttrs = 2;
339 static constexpr size_t kNumDecisions = 0;
340
341 static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index,
342 String storage_scope) {
343 return sch->SetScope(block_rv, buffer_index->value, storage_scope);
344 }
345
346 static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index,
347 String storage_scope) {
348 PythonAPICall py("set_scope");
349 py.Input("block", block_rv);
350 py.Input("buffer_index", buffer_index);
351 py.Input("storage_scope", storage_scope);
352 return py.Str();
353 }
354
355 template <typename>
356 friend struct ::tvm::tir::UnpackedInstTraits;
357};
358
359TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits);
360TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits);
361
362} // namespace tir
363} // namespace tvm
364