1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file remove_weight_layout_rewrite_block.cc
22 * \brief Remove weight layout rewrite block before benchmark
23 */
24
25#include <tvm/tir/index_map.h>
26#include <tvm/tir/op.h>
27#include <tvm/tir/stmt_functor.h>
28#include <tvm/tir/transform.h>
29
30#include <unordered_set>
31
32namespace tvm {
33namespace tir {
34
35class RemoveLayoutRewriteBlock : public StmtMutator {
36 public:
37 static std::tuple<PrimFunc, Map<Buffer, Buffer>, std::unordered_map<const VarNode*, IndexMap>,
38 std::unordered_map<const VarNode*, Array<PrimExpr>>>
39 Rewrite(PrimFunc f) {
40 RemoveLayoutRewriteBlock rewriter;
41
42 PrimFuncNode* n = f.CopyOnWrite();
43 n->body = rewriter(std::move(n->body));
44 return std::make_tuple(f, rewriter.buf_map_, rewriter.buffer_var_to_index_map_,
45 rewriter.buffer_var_to_rewritten_shape_);
46 }
47
48 private:
49 Stmt VisitStmt_(const BlockNode* op) final {
50 Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
51
52 auto it = block->annotations.find(attr::meta_schedule_layout_rewrite_preproc);
53 if (it == block->annotations.end() || !is_one(Downcast<PrimExpr>((*it).second))) {
54 // The block is not a weight layout block
55 // Remove allocates if needed
56 Array<Buffer> alloc_buffers;
57 for (const Buffer& buffer : block->alloc_buffers) {
58 if (!rewritten_buffers_.count(buffer)) {
59 alloc_buffers.push_back(buffer);
60 }
61 }
62 if (alloc_buffers.size() < block->alloc_buffers.size()) {
63 auto n = CopyOnWrite(block.get());
64 n->alloc_buffers = std::move(alloc_buffers);
65 return Stmt(n);
66 } else {
67 return std::move(block);
68 }
69 }
70
71 // Step 0. Checking block attrs
72 ICHECK(block->alloc_buffers.empty());
73 ICHECK(block->match_buffers.empty());
74
75 // Step 1. Checking the body is a BufferStore
76 const auto* store = block->body.as<BufferStoreNode>();
77 ICHECK(store);
78
79 // Step 2. Checking the rhs of buffer store is a BufferLoad
80 const auto* load = store->value.as<BufferLoadNode>();
81 ICHECK(load);
82
83 // Step 3. Update Buffer
84 buf_map_.Set(load->buffer, store->buffer);
85 rewritten_buffers_.insert(store->buffer);
86
87 // Step 4. Set block body as no_op
88 auto n = CopyOnWrite(block.get());
89 n->body = std::move(Evaluate(0));
90 n->reads = {};
91 n->writes = {};
92
93 Array<Var> load_indices;
94 for (auto ind : load->indices) {
95 ICHECK(ind->IsInstance<VarNode>());
96 load_indices.push_back(Downcast<Var>(ind));
97 }
98 buffer_var_to_index_map_[load->buffer->data.get()] = IndexMap(load_indices, store->indices);
99
100 buffer_var_to_rewritten_shape_[load->buffer->data.get()] = store->buffer->shape;
101
102 return Stmt(n);
103 }
104
105 private:
106 /*! \brief The buffer map from original layout buffer to rewritten buffer */
107 Map<Buffer, Buffer> buf_map_;
108 /*! \brief The buffer map from original layout buffer to rewritten buffer */
109 std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> rewritten_buffers_;
110 /*! \brief Maps a buffer load to an index map associated with the load / store
111 in a layout rewrite block. */
112 std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_;
113 /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */
114 std::unordered_map<const VarNode*, Array<PrimExpr>> buffer_var_to_rewritten_shape_;
115};
116
// After RemoveLayoutRewriteBlock, the body of a compute update block references a
// non-existent buffer. For example, fused_constant_2_global below is originally a
// cache_read buffer, whose allocation is removed by RemoveLayoutRewriteBlock:
120//
121// constant fused_constant_2[float32 * 3 * 3 * 64 * 64]
122// conv2d_nhwc[nn, yy, xx, ff] += ... * fused_constant_2_global[ry,
123// floordiv(rc, 32),
124// floordiv(ff, 16),
125// rx,
126// floormod(rc, 32),
127// floormod(ff, 16)]))
128//
// When cache_read is reading from an AllocateConst, we need to replace the reference
// to fused_constant_2_global with the corresponding transformed AllocateConst.
// To do that, we manually rewrite the original constant using the associated index map,
// and let the body of the compute block load from the rewritten constant.
133//
134// After this transformation, the example above looks like:
135//
136// constant fused_constant_2[float32 * 3 * 2 * 4 * 3 * 32 * 16]
137// conv2d_nhwc[nn, yy, xx, ff] += ... * fused_constant_2[ry,
138// floordiv(rc, 32),
139// floordiv(ff, 16),
140// rx,
141// floormod(rc, 32),
142// floormod(ff, 16)]))
143
/*!
 * \brief Map between buffer data variables. In this file it is filled with entries going from
 *        the data variable of a rewritten (store) buffer to the data variable of the constant
 *        (load) buffer it was derived from.
 */
using BufferVarMap = std::unordered_map<const tir::VarNode*, const tir::VarNode*>;
145
146class AllocateConstRewrite : public StmtExprMutator {
147 public:
148 AllocateConstRewrite(
149 const BufferVarMap& buffer_var_map,
150 const std::unordered_map<const VarNode*, IndexMap>& buffer_var_to_index_map,
151 const std::unordered_map<const VarNode*, Array<PrimExpr>>& buffer_var_to_rewritten_shape,
152 bool skip_ndarray_rewrite)
153 : buffer_var_map_(buffer_var_map),
154 buffer_var_to_index_map_(buffer_var_to_index_map),
155 buffer_var_to_rewritten_shape_(buffer_var_to_rewritten_shape),
156 skip_ndarray_rewrite_(skip_ndarray_rewrite) {}
157
158 private:
159 Stmt VisitStmt_(const BlockNode* op) final {
160 Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
161 auto n = CopyOnWrite(block.get());
162 Array<BufferRegion> new_reads;
163 for (auto read_region : op->reads) {
164 if (auto it = new_load_buf_.find(read_region->buffer->data.get());
165 it != new_load_buf_.end()) {
166 new_reads.push_back(BufferRegion(it->second, read_region->region));
167 } else {
168 new_reads.push_back(read_region);
169 }
170 }
171 n->reads = new_reads;
172 return Stmt(n);
173 }
174
175 Stmt VisitStmt_(const AllocateConstNode* alloc) final {
176 if (auto it = buffer_var_to_index_map_.find(alloc->buffer_var.get());
177 it != buffer_var_to_index_map_.end()) {
178 ICHECK(buffer_var_to_rewritten_shape_.count(alloc->buffer_var.get()));
179 auto new_body = StmtMutator::VisitStmt(alloc->body);
180 auto rewritten_ndarray = RewriteNDArray(
181 alloc->data.value(), it->second, buffer_var_to_rewritten_shape_[alloc->buffer_var.get()]);
182 Array<PrimExpr> rewritten_extents;
183 for (auto s : rewritten_ndarray.Shape()) {
184 rewritten_extents.push_back(PrimExpr(static_cast<int>(s)));
185 }
186 return AllocateConst(alloc->buffer_var, alloc->dtype, rewritten_extents, rewritten_ndarray,
187 new_body, alloc->annotations, alloc->span);
188 }
189 return StmtMutator::VisitStmt_(alloc);
190 }
191
192 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
193 if (auto it = buffer_var_map_.find(op->buffer->data.get()); it != buffer_var_map_.end()) {
194 auto new_buffer =
195 Buffer(GetRef<Var>(it->second), op->buffer->dtype, op->buffer->shape, op->buffer->strides,
196 op->buffer->elem_offset, it->second->name_hint, op->buffer->data_alignment,
197 op->buffer->offset_factor, op->buffer->buffer_type);
198 new_load_buf_[op->buffer->data.get()] = new_buffer;
199 return BufferLoad(new_buffer, op->indices);
200 }
201 return ExprMutator::VisitExpr_(op);
202 }
203
204 runtime::NDArray RewriteNDArray(runtime::NDArray src, const IndexMap& index_map,
205 const Array<PrimExpr>& dst_shape) {
206 if (skip_ndarray_rewrite_) {
207 // Only the shape of the destination array needs to be correct.
208 std::vector<int64_t> dst_shape_int;
209 for (auto s : dst_shape) {
210 ICHECK(s->IsInstance<IntImmNode>());
211 dst_shape_int.push_back(s.as<IntImmNode>()->value);
212 }
213 return src.CreateView(dst_shape_int, src.DataType());
214 } else {
215 return index_map->MapNDArray(src);
216 }
217 }
218
219 /*! \brief Maps a buffer store to a load in a layout rewrite block */
220 BufferVarMap buffer_var_map_;
221 /*! \brief Maps a buffer load to an index map associated with the load / store
222 in a layout rewrite block. */
223 std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_;
224 /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */
225 std::unordered_map<const VarNode*, Array<PrimExpr>> buffer_var_to_rewritten_shape_;
226 /*! \brief Maps load buffer variables to newly created buffers */
227 std::unordered_map<const VarNode*, Buffer> new_load_buf_;
228 /*! \brief Whether or not to skip rewriting of NDArray contents */
229 bool skip_ndarray_rewrite_;
230};
231
232class CollectAllocateConstBufferVars : public StmtVisitor {
233 public:
234 void VisitStmt_(const AllocateConstNode* alloc) final {
235 StmtVisitor::VisitStmt_(alloc);
236 constant_buf_var.insert(alloc->buffer_var.get());
237 }
238
239 std::unordered_set<const VarNode*> constant_buf_var;
240};
241
242class WeightLayoutRewriteBlockRemover : public StmtMutator {
243 public:
244 static PrimFunc Remove(PrimFunc f, bool skip_ndarray_rewrite) {
245 CollectAllocateConstBufferVars collector;
246 collector(f->body);
247
248 auto [f_, buf_map, buffer_var_to_index_map, buffer_var_to_rewritten_shape] =
249 RemoveLayoutRewriteBlock().Rewrite(f);
250
251 BufferVarMap buffer_var_map;
252 for (const auto& [load_buf, store_buf] : buf_map) {
253 if (collector.constant_buf_var.find(load_buf->data.get()) !=
254 collector.constant_buf_var.end()) {
255 buffer_var_map[store_buf->data.get()] = load_buf->data.get();
256 }
257 }
258
259 PrimFuncNode* n = f_.CopyOnWrite();
260
261 AllocateConstRewrite rewriter(buffer_var_map, buffer_var_to_index_map,
262 buffer_var_to_rewritten_shape, skip_ndarray_rewrite);
263 n->body = rewriter(std::move(n->body));
264
265 Map<tir::Var, Buffer> buffer_map;
266 for (const auto& [param, buffer] : f_->buffer_map) {
267 auto it = buf_map.find(buffer);
268 if (it != buf_map.end()) {
269 buffer_map.Set(param, (*it).second);
270 } else {
271 buffer_map.Set(param, buffer);
272 }
273 }
274 n->buffer_map = std::move(buffer_map);
275 return f_;
276 }
277};
278
279namespace transform {
280
281Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite) {
282 auto pass_func = [skip_ndarray_rewrite](PrimFunc f, IRModule m, PassContext ctx) {
283 return WeightLayoutRewriteBlockRemover::Remove(std::move(f), skip_ndarray_rewrite);
284 };
285 return CreatePrimFuncPass(pass_func, 0, "tir.RemoveWeightLayoutRewriteBlock", {});
286}
287
// Register the pass constructor in the global function registry under the name
// "tir.transform.RemoveWeightLayoutRewriteBlock".
TVM_REGISTER_GLOBAL("tir.transform.RemoveWeightLayoutRewriteBlock")
    .set_body_typed(RemoveWeightLayoutRewriteBlock);
290
291} // namespace transform
292
293} // namespace tir
294} // namespace tvm
295