1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | class StorageAlignAxisOutOfRangeError : public ScheduleError { |
25 | public: |
26 | explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis) |
27 | : mod_(std::move(mod)), buffer_(std::move(buffer)), axis_(axis) {} |
28 | |
29 | String FastErrorString() const final { |
30 | return "ScheduleError: The input `axis` is out of range. It is required to be in range " |
31 | "[-ndim, ndim) where `ndim` is the number of dimensions of the buffer to set " |
32 | "storage alignment." ; |
33 | } |
34 | |
35 | String DetailRenderTemplate() const final { |
36 | std::ostringstream os; |
37 | int ndim = static_cast<int>(buffer_->shape.size()); |
38 | os << "The buffer to set storage alignment of, " << buffer_->name << ", has " << ndim |
39 | << " dimension(s), so `axis` is required to be in [" << -(ndim) << ", " << ndim |
40 | << ") for storage_align. However, the input `axis` is " << axis_ |
41 | << ", which is out of the expected range." ; |
42 | return os.str(); |
43 | } |
44 | |
45 | IRModule mod() const final { return mod_; } |
46 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
47 | |
48 | static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int axis) { |
49 | int ndim = static_cast<int>(buffer->shape.size()); |
50 | if (axis < -ndim || axis >= ndim) { |
51 | throw StorageAlignAxisOutOfRangeError(mod, buffer, axis); |
52 | } |
53 | // If axis is negative, convert it to a non-negative one. |
54 | if (axis < 0) { |
55 | axis += ndim; |
56 | } |
57 | return axis; |
58 | } |
59 | |
60 | private: |
61 | IRModule mod_; |
62 | Buffer buffer_; |
63 | int axis_; |
64 | }; |
65 | |
66 | class NonAllocatedBufferError : public ScheduleError { |
67 | public: |
68 | explicit NonAllocatedBufferError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} |
69 | |
70 | String FastErrorString() const final { |
71 | return "ScheduleError: The input buffer is not allocated by a block. This means the buffer is " |
72 | " either a function parameter or defined in `match_buffer` of a block." ; |
73 | } |
74 | |
75 | String DetailRenderTemplate() const final { |
76 | std::ostringstream os; |
77 | os << "The input buffer " << buffer_->name |
78 | << " is not allocated by a block. This means the buffer is either a function parameter or " |
79 | "defined in `match_buffer` of a block." ; |
80 | return os.str(); |
81 | } |
82 | |
83 | static StmtSRef CheckAndGetBufferAllocationSite(const IRModule& mod, const StmtSRef& block_sref, |
84 | const Buffer& buffer) { |
85 | auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, buffer); |
86 | if (!defining_site_sref.defined() || !is_alloc) { |
87 | throw NonAllocatedBufferError(mod, buffer); |
88 | } |
89 | |
90 | return defining_site_sref.value(); |
91 | } |
92 | |
93 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
94 | IRModule mod() const final { return mod_; } |
95 | |
96 | private: |
97 | IRModule mod_; |
98 | Buffer buffer_; |
99 | }; |
100 | |
101 | class StorageAlignInvalidFactorError : public ScheduleError { |
102 | public: |
103 | explicit StorageAlignInvalidFactorError(IRModule mod, int factor) |
104 | : mod_(std::move(mod)), factor_(factor) {} |
105 | |
106 | String FastErrorString() const final { |
107 | return "ScheduleError: The input `factor` of storage_align is expected to be a positive " |
108 | "number." ; |
109 | } |
110 | |
111 | String DetailRenderTemplate() const final { |
112 | std::ostringstream os; |
113 | os << "The input `factor` of storage_align is expected to be a positive number. However, the " |
114 | "input `factor` is " |
115 | << factor_ << ", which is out of the expected range." ; |
116 | return os.str(); |
117 | } |
118 | |
119 | static void Check(const IRModule& mod, int factor) { |
120 | if (factor <= 0) { |
121 | throw StorageAlignInvalidFactorError(mod, factor); |
122 | } |
123 | } |
124 | |
125 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
126 | IRModule mod() const final { return mod_; } |
127 | |
128 | private: |
129 | IRModule mod_; |
130 | int factor_; |
131 | }; |
132 | |
133 | class StorageAlignInvalidAnnotationError : public ScheduleError { |
134 | public: |
135 | explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block) |
136 | : mod_(std::move(mod)), block_(std::move(block)) {} |
137 | |
138 | String FastErrorString() const final { |
139 | return "ScheduleError: The block annotation for storage align is expected to be an array of " |
140 | "4-integer-tuples (buffer_index, axis, factor, offset)." ; |
141 | } |
142 | |
143 | String DetailRenderTemplate() const final { |
144 | std::ostringstream os; |
145 | os << "The block annotation for storage align is expected to be an array of 4-integer-tuples " |
146 | "(buffer_index, axis, factor, offset). However, the block annotation with key " |
147 | << attr::buffer_dim_align << " of the block {0} is " |
148 | << block_->annotations.at(attr::buffer_dim_align) << ", which is unexpected." ; |
149 | return os.str(); |
150 | } |
151 | |
152 | static StorageAlignAnnotation CheckAndGetAnnotation(const IRModule& mod, const Block& block) { |
153 | // Get existing annotation value. |
154 | auto it = block->annotations.find(attr::buffer_dim_align); |
155 | if (it != block->annotations.end()) { |
156 | if (!IsValidAnnotation(block, (*it).second)) { |
157 | throw StorageAlignInvalidAnnotationError(mod, block); |
158 | } |
159 | return Downcast<StorageAlignAnnotation>((*it).second); |
160 | } |
161 | |
162 | // Create new annotation value |
163 | StorageAlignAnnotation storage_align_annotation; |
164 | return storage_align_annotation; |
165 | } |
166 | |
167 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
168 | IRModule mod() const final { return mod_; } |
169 | |
170 | private: |
171 | static bool IsValidAnnotation(const Block& block, const ObjectRef& anno_value) { |
172 | if (!anno_value->IsInstance<ArrayNode>()) { |
173 | return false; |
174 | } |
175 | auto storage_align_annotations = Downcast<Array<ObjectRef>>(anno_value); |
176 | for (const ObjectRef& storage_align_annotation : storage_align_annotations) { |
177 | if (!storage_align_annotation->IsInstance<ArrayNode>()) { |
178 | return false; |
179 | } |
180 | auto storage_align_tuple = Downcast<Array<ObjectRef>>(storage_align_annotation); |
181 | // Check if the annotation is a 4-tuple. |
182 | if (storage_align_tuple.size() != 4) { |
183 | return false; |
184 | } |
185 | for (const ObjectRef& tuple_element : storage_align_tuple) { |
186 | if (!tuple_element->IsInstance<IntImmNode>()) { |
187 | return false; |
188 | } |
189 | } |
190 | } |
191 | return true; |
192 | } |
193 | |
194 | IRModule mod_; |
195 | Block block_; |
196 | }; |
197 | |
198 | /*! |
199 | * \brief A helper mutator which recursively mutates the old buffer's storage scope and collects |
200 | * the block sref reuse information for the following replacement. |
201 | */ |
202 | class StorageScopeMutator : private ReplaceBufferMutator { |
203 | public: |
204 | /*! |
205 | * \param allocate_site The block where `old_buffer` was allocated. |
206 | * \param old_buffer The old buffer |
207 | * \param storage_scope The storage scope to be set |
208 | * \param block_sref_reuse The block sref reuse map to be updated |
209 | * \return The new block after the mutation |
210 | */ |
211 | static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, |
212 | const String& storage_scope, Map<Block, Block>* block_sref_reuse) { |
213 | Buffer new_buffer = WithScope(old_buffer, storage_scope); |
214 | StorageScopeMutator mutator(old_buffer, new_buffer, storage_scope, block_sref_reuse); |
215 | Stmt new_block = mutator.VisitStmt(allocate_site); |
216 | return Downcast<Block>(new_block); |
217 | } |
218 | |
219 | private: |
220 | StorageScopeMutator(const Buffer& old_buffer, Buffer new_buffer, String storage_scope, |
221 | Map<Block, Block>* block_sref_reuse) |
222 | : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse) {} |
223 | |
224 | MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final { |
225 | auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get()); |
226 | if (it != buffer_var_map_.end()) { |
227 | Buffer new_target_buffer = WithScope(match_buffer->buffer, it->second.scope()); |
228 | buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer; |
229 | return MatchBufferRegion(new_target_buffer, |
230 | BufferRegion(it->second, match_buffer->source->region)); |
231 | } else { |
232 | return match_buffer; |
233 | } |
234 | } |
235 | }; |
236 | |
237 | void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, |
238 | int factor, int offset) { |
239 | const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); |
240 | Buffer buffer = |
241 | GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, BufferIndexType::kWrite); |
242 | StorageAlignInvalidFactorError::Check(self->mod, factor); |
243 | axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis); |
244 | NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); |
245 | |
246 | // Step 1: Get existing or create new annotation value. |
247 | StorageAlignAnnotation storage_align_annotation = |
248 | StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod, |
249 | GetRef<Block>(block_ptr)); |
250 | |
251 | // Step 2: Update the annotation value |
252 | bool found = false; |
253 | StorageAlignTuple new_storage_align_tuple{Integer(buffer_index), Integer(axis), Integer(factor), |
254 | Integer(offset)}; |
255 | for (size_t j = 0; j < storage_align_annotation.size(); ++j) { |
256 | const auto& storage_align_tuple = storage_align_annotation[j]; |
257 | ICHECK(storage_align_tuple.size() == 4); |
258 | if (storage_align_tuple[0] == buffer_index && storage_align_tuple[1] == axis) { |
259 | storage_align_annotation.Set(j, std::move(new_storage_align_tuple)); |
260 | found = true; |
261 | break; |
262 | } |
263 | } |
264 | if (!found) { |
265 | storage_align_annotation.push_back(std::move(new_storage_align_tuple)); |
266 | } |
267 | |
268 | // Step 3: Replace the block with the new annotation |
269 | Block new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation); |
270 | self->Replace(block_sref, new_block, {{GetRef<Block>(block_ptr), new_block}}); |
271 | } |
272 | |
273 | void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, |
274 | const String& storage_scope) { |
275 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
276 | Buffer buffer = |
277 | GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kWrite); |
278 | |
279 | // Step 1. If `storage_scope` equals the original storage scope of the buffer, just return. |
280 | if (buffer.scope() == storage_scope) { |
281 | return; |
282 | } |
283 | |
284 | // Step 2. Throw an error if the input storage scope is invalid. |
285 | CheckStorageScope(self, storage_scope); |
286 | |
287 | // Step 3. Get the allocation site of the target buffer. |
288 | StmtSRef alloc_site_sref = |
289 | NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); |
290 | const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref); |
291 | |
292 | // Step 4. Recursively replace the old buffer to a new buffer, where the new buffer has the given |
293 | // storage scope. In the meanwhile, collect the block sref reuse information. |
294 | Map<Block, Block> block_reuse_map; |
295 | Block new_block = StorageScopeMutator::Mutate(GetRef<Block>(alloc_site), buffer, storage_scope, |
296 | &block_reuse_map); |
297 | self->Replace(alloc_site_sref, new_block, block_reuse_map); |
298 | } |
299 | |
300 | /******** InstructionKind Registration ********/ |
301 | |
302 | struct StorageAlignTraits : public UnpackedInstTraits<StorageAlignTraits> { |
303 | static constexpr const char* kName = "StorageAlign" ; |
304 | static constexpr bool kIsPure = false; |
305 | |
306 | private: |
307 | static constexpr size_t kNumInputs = 1; |
308 | static constexpr size_t kNumAttrs = 4; |
309 | static constexpr size_t kNumDecisions = 0; |
310 | |
311 | static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, |
312 | Integer axis, Integer factor, Integer offset) { |
313 | return sch->StorageAlign(block_rv, buffer_index->value, axis->value, factor->value, |
314 | offset->value); |
315 | } |
316 | |
317 | static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index, |
318 | Integer axis, Integer factor, Integer offset) { |
319 | PythonAPICall py("storage_align" ); |
320 | py.Input("block" , block_rv); |
321 | py.Input("buffer_index" , buffer_index); |
322 | py.Input("axis" , axis); |
323 | py.Input("factor" , factor); |
324 | py.Input("offset" , offset); |
325 | return py.Str(); |
326 | } |
327 | |
328 | template <typename> |
329 | friend struct ::tvm::tir::UnpackedInstTraits; |
330 | }; |
331 | |
332 | struct SetScopeTraits : public UnpackedInstTraits<SetScopeTraits> { |
333 | static constexpr const char* kName = "SetScope" ; |
334 | static constexpr bool kIsPure = false; |
335 | |
336 | private: |
337 | static constexpr size_t kNumInputs = 1; |
338 | static constexpr size_t kNumAttrs = 2; |
339 | static constexpr size_t kNumDecisions = 0; |
340 | |
341 | static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, |
342 | String storage_scope) { |
343 | return sch->SetScope(block_rv, buffer_index->value, storage_scope); |
344 | } |
345 | |
346 | static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index, |
347 | String storage_scope) { |
348 | PythonAPICall py("set_scope" ); |
349 | py.Input("block" , block_rv); |
350 | py.Input("buffer_index" , buffer_index); |
351 | py.Input("storage_scope" , storage_scope); |
352 | return py.Str(); |
353 | } |
354 | |
355 | template <typename> |
356 | friend struct ::tvm::tir::UnpackedInstTraits; |
357 | }; |
358 | |
// Hand both traits types to TVM_REGISTER_INST_KIND_TRAITS. The macro is defined outside this
// file; presumably it makes the "StorageAlign" and "SetScope" instruction kinds available to the
// instruction registry -- confirm against the macro definition.
TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits);
TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits);
361 | |
362 | } // namespace tir |
363 | } // namespace tvm |
364 | |