1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file remove_weight_layout_rewrite_block.cc |
22 | * \brief Remove weight layout rewrite block before benchmark |
23 | */ |
24 | |
25 | #include <tvm/tir/index_map.h> |
26 | #include <tvm/tir/op.h> |
27 | #include <tvm/tir/stmt_functor.h> |
28 | #include <tvm/tir/transform.h> |
29 | |
30 | #include <unordered_set> |
31 | |
32 | namespace tvm { |
33 | namespace tir { |
34 | |
35 | class RemoveLayoutRewriteBlock : public StmtMutator { |
36 | public: |
37 | static std::tuple<PrimFunc, Map<Buffer, Buffer>, std::unordered_map<const VarNode*, IndexMap>, |
38 | std::unordered_map<const VarNode*, Array<PrimExpr>>> |
39 | Rewrite(PrimFunc f) { |
40 | RemoveLayoutRewriteBlock rewriter; |
41 | |
42 | PrimFuncNode* n = f.CopyOnWrite(); |
43 | n->body = rewriter(std::move(n->body)); |
44 | return std::make_tuple(f, rewriter.buf_map_, rewriter.buffer_var_to_index_map_, |
45 | rewriter.buffer_var_to_rewritten_shape_); |
46 | } |
47 | |
48 | private: |
49 | Stmt VisitStmt_(const BlockNode* op) final { |
50 | Block block = Downcast<Block>(StmtMutator::VisitStmt_(op)); |
51 | |
52 | auto it = block->annotations.find(attr::meta_schedule_layout_rewrite_preproc); |
53 | if (it == block->annotations.end() || !is_one(Downcast<PrimExpr>((*it).second))) { |
54 | // The block is not a weight layout block |
55 | // Remove allocates if needed |
56 | Array<Buffer> alloc_buffers; |
57 | for (const Buffer& buffer : block->alloc_buffers) { |
58 | if (!rewritten_buffers_.count(buffer)) { |
59 | alloc_buffers.push_back(buffer); |
60 | } |
61 | } |
62 | if (alloc_buffers.size() < block->alloc_buffers.size()) { |
63 | auto n = CopyOnWrite(block.get()); |
64 | n->alloc_buffers = std::move(alloc_buffers); |
65 | return Stmt(n); |
66 | } else { |
67 | return std::move(block); |
68 | } |
69 | } |
70 | |
71 | // Step 0. Checking block attrs |
72 | ICHECK(block->alloc_buffers.empty()); |
73 | ICHECK(block->match_buffers.empty()); |
74 | |
75 | // Step 1. Checking the body is a BufferStore |
76 | const auto* store = block->body.as<BufferStoreNode>(); |
77 | ICHECK(store); |
78 | |
79 | // Step 2. Checking the rhs of buffer store is a BufferLoad |
80 | const auto* load = store->value.as<BufferLoadNode>(); |
81 | ICHECK(load); |
82 | |
83 | // Step 3. Update Buffer |
84 | buf_map_.Set(load->buffer, store->buffer); |
85 | rewritten_buffers_.insert(store->buffer); |
86 | |
87 | // Step 4. Set block body as no_op |
88 | auto n = CopyOnWrite(block.get()); |
89 | n->body = std::move(Evaluate(0)); |
90 | n->reads = {}; |
91 | n->writes = {}; |
92 | |
93 | Array<Var> load_indices; |
94 | for (auto ind : load->indices) { |
95 | ICHECK(ind->IsInstance<VarNode>()); |
96 | load_indices.push_back(Downcast<Var>(ind)); |
97 | } |
98 | buffer_var_to_index_map_[load->buffer->data.get()] = IndexMap(load_indices, store->indices); |
99 | |
100 | buffer_var_to_rewritten_shape_[load->buffer->data.get()] = store->buffer->shape; |
101 | |
102 | return Stmt(n); |
103 | } |
104 | |
105 | private: |
106 | /*! \brief The buffer map from original layout buffer to rewritten buffer */ |
107 | Map<Buffer, Buffer> buf_map_; |
108 | /*! \brief The buffer map from original layout buffer to rewritten buffer */ |
109 | std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> rewritten_buffers_; |
110 | /*! \brief Maps a buffer load to an index map associated with the load / store |
111 | in a layout rewrite block. */ |
112 | std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_; |
113 | /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */ |
114 | std::unordered_map<const VarNode*, Array<PrimExpr>> buffer_var_to_rewritten_shape_; |
115 | }; |
116 | |
// After RemoveLayoutRewriteBlock, the body of a compute update block references a
// non-existent buffer. For example, fused_constant_2_global below was originally a
// cache_read buffer, whose allocation is removed by RemoveLayoutRewriteBlock:
120 | // |
121 | // constant fused_constant_2[float32 * 3 * 3 * 64 * 64] |
122 | // conv2d_nhwc[nn, yy, xx, ff] += ... * fused_constant_2_global[ry, |
123 | // floordiv(rc, 32), |
124 | // floordiv(ff, 16), |
125 | // rx, |
126 | // floormod(rc, 32), |
127 | // floormod(ff, 16)])) |
128 | // |
// When the cache_read reads from an AllocateConst, we need to replace the reference
// to fused_constant_2_global with the corresponding transformed AllocateConst.
// To do that, we manually rewrite the original constant using the associated index map,
// and make the body of the compute block load from the rewritten constant.
133 | // |
134 | // After this transformation, the example above looks like: |
135 | // |
136 | // constant fused_constant_2[float32 * 3 * 2 * 4 * 3 * 32 * 16] |
137 | // conv2d_nhwc[nn, yy, xx, ff] += ... * fused_constant_2[ry, |
138 | // floordiv(rc, 32), |
139 | // floordiv(ff, 16), |
140 | // rx, |
141 | // floormod(rc, 32), |
142 | // floormod(ff, 16)])) |
143 | |
/*!
 * \brief Maps the data var of a layout-rewritten buffer (the store target of a layout rewrite
 *        block) to the data var of the AllocateConst that the rewrite block loaded from.
 */
using BufferVarMap = std::unordered_map<const tir::VarNode*, const tir::VarNode*>;
145 | |
146 | class AllocateConstRewrite : public StmtExprMutator { |
147 | public: |
148 | AllocateConstRewrite( |
149 | const BufferVarMap& buffer_var_map, |
150 | const std::unordered_map<const VarNode*, IndexMap>& buffer_var_to_index_map, |
151 | const std::unordered_map<const VarNode*, Array<PrimExpr>>& buffer_var_to_rewritten_shape, |
152 | bool skip_ndarray_rewrite) |
153 | : buffer_var_map_(buffer_var_map), |
154 | buffer_var_to_index_map_(buffer_var_to_index_map), |
155 | buffer_var_to_rewritten_shape_(buffer_var_to_rewritten_shape), |
156 | skip_ndarray_rewrite_(skip_ndarray_rewrite) {} |
157 | |
158 | private: |
159 | Stmt VisitStmt_(const BlockNode* op) final { |
160 | Block block = Downcast<Block>(StmtMutator::VisitStmt_(op)); |
161 | auto n = CopyOnWrite(block.get()); |
162 | Array<BufferRegion> new_reads; |
163 | for (auto read_region : op->reads) { |
164 | if (auto it = new_load_buf_.find(read_region->buffer->data.get()); |
165 | it != new_load_buf_.end()) { |
166 | new_reads.push_back(BufferRegion(it->second, read_region->region)); |
167 | } else { |
168 | new_reads.push_back(read_region); |
169 | } |
170 | } |
171 | n->reads = new_reads; |
172 | return Stmt(n); |
173 | } |
174 | |
175 | Stmt VisitStmt_(const AllocateConstNode* alloc) final { |
176 | if (auto it = buffer_var_to_index_map_.find(alloc->buffer_var.get()); |
177 | it != buffer_var_to_index_map_.end()) { |
178 | ICHECK(buffer_var_to_rewritten_shape_.count(alloc->buffer_var.get())); |
179 | auto new_body = StmtMutator::VisitStmt(alloc->body); |
180 | auto rewritten_ndarray = RewriteNDArray( |
181 | alloc->data.value(), it->second, buffer_var_to_rewritten_shape_[alloc->buffer_var.get()]); |
182 | Array<PrimExpr> rewritten_extents; |
183 | for (auto s : rewritten_ndarray.Shape()) { |
184 | rewritten_extents.push_back(PrimExpr(static_cast<int>(s))); |
185 | } |
186 | return AllocateConst(alloc->buffer_var, alloc->dtype, rewritten_extents, rewritten_ndarray, |
187 | new_body, alloc->annotations, alloc->span); |
188 | } |
189 | return StmtMutator::VisitStmt_(alloc); |
190 | } |
191 | |
192 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
193 | if (auto it = buffer_var_map_.find(op->buffer->data.get()); it != buffer_var_map_.end()) { |
194 | auto new_buffer = |
195 | Buffer(GetRef<Var>(it->second), op->buffer->dtype, op->buffer->shape, op->buffer->strides, |
196 | op->buffer->elem_offset, it->second->name_hint, op->buffer->data_alignment, |
197 | op->buffer->offset_factor, op->buffer->buffer_type); |
198 | new_load_buf_[op->buffer->data.get()] = new_buffer; |
199 | return BufferLoad(new_buffer, op->indices); |
200 | } |
201 | return ExprMutator::VisitExpr_(op); |
202 | } |
203 | |
204 | runtime::NDArray RewriteNDArray(runtime::NDArray src, const IndexMap& index_map, |
205 | const Array<PrimExpr>& dst_shape) { |
206 | if (skip_ndarray_rewrite_) { |
207 | // Only the shape of the destination array needs to be correct. |
208 | std::vector<int64_t> dst_shape_int; |
209 | for (auto s : dst_shape) { |
210 | ICHECK(s->IsInstance<IntImmNode>()); |
211 | dst_shape_int.push_back(s.as<IntImmNode>()->value); |
212 | } |
213 | return src.CreateView(dst_shape_int, src.DataType()); |
214 | } else { |
215 | return index_map->MapNDArray(src); |
216 | } |
217 | } |
218 | |
219 | /*! \brief Maps a buffer store to a load in a layout rewrite block */ |
220 | BufferVarMap buffer_var_map_; |
221 | /*! \brief Maps a buffer load to an index map associated with the load / store |
222 | in a layout rewrite block. */ |
223 | std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_; |
224 | /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */ |
225 | std::unordered_map<const VarNode*, Array<PrimExpr>> buffer_var_to_rewritten_shape_; |
226 | /*! \brief Maps load buffer variables to newly created buffers */ |
227 | std::unordered_map<const VarNode*, Buffer> new_load_buf_; |
228 | /*! \brief Whether or not to skip rewriting of NDArray contents */ |
229 | bool skip_ndarray_rewrite_; |
230 | }; |
231 | |
232 | class CollectAllocateConstBufferVars : public StmtVisitor { |
233 | public: |
234 | void VisitStmt_(const AllocateConstNode* alloc) final { |
235 | StmtVisitor::VisitStmt_(alloc); |
236 | constant_buf_var.insert(alloc->buffer_var.get()); |
237 | } |
238 | |
239 | std::unordered_set<const VarNode*> constant_buf_var; |
240 | }; |
241 | |
242 | class WeightLayoutRewriteBlockRemover : public StmtMutator { |
243 | public: |
244 | static PrimFunc Remove(PrimFunc f, bool skip_ndarray_rewrite) { |
245 | CollectAllocateConstBufferVars collector; |
246 | collector(f->body); |
247 | |
248 | auto [f_, buf_map, buffer_var_to_index_map, buffer_var_to_rewritten_shape] = |
249 | RemoveLayoutRewriteBlock().Rewrite(f); |
250 | |
251 | BufferVarMap buffer_var_map; |
252 | for (const auto& [load_buf, store_buf] : buf_map) { |
253 | if (collector.constant_buf_var.find(load_buf->data.get()) != |
254 | collector.constant_buf_var.end()) { |
255 | buffer_var_map[store_buf->data.get()] = load_buf->data.get(); |
256 | } |
257 | } |
258 | |
259 | PrimFuncNode* n = f_.CopyOnWrite(); |
260 | |
261 | AllocateConstRewrite rewriter(buffer_var_map, buffer_var_to_index_map, |
262 | buffer_var_to_rewritten_shape, skip_ndarray_rewrite); |
263 | n->body = rewriter(std::move(n->body)); |
264 | |
265 | Map<tir::Var, Buffer> buffer_map; |
266 | for (const auto& [param, buffer] : f_->buffer_map) { |
267 | auto it = buf_map.find(buffer); |
268 | if (it != buf_map.end()) { |
269 | buffer_map.Set(param, (*it).second); |
270 | } else { |
271 | buffer_map.Set(param, buffer); |
272 | } |
273 | } |
274 | n->buffer_map = std::move(buffer_map); |
275 | return f_; |
276 | } |
277 | }; |
278 | |
279 | namespace transform { |
280 | |
281 | Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite) { |
282 | auto pass_func = [skip_ndarray_rewrite](PrimFunc f, IRModule m, PassContext ctx) { |
283 | return WeightLayoutRewriteBlockRemover::Remove(std::move(f), skip_ndarray_rewrite); |
284 | }; |
285 | return CreatePrimFuncPass(pass_func, 0, "tir.RemoveWeightLayoutRewriteBlock" , {}); |
286 | } |
287 | |
288 | TVM_REGISTER_GLOBAL("tir.transform.RemoveWeightLayoutRewriteBlock" ) |
289 | .set_body_typed(RemoveWeightLayoutRewriteBlock); |
290 | |
291 | } // namespace transform |
292 | |
293 | } // namespace tir |
294 | } // namespace tvm |
295 | |