1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file lower_async_dma.cc
 * \brief Lower async copy / wait annotations into explicit `dma_copy` / `dma_wait` builtin calls.
 */
23 | |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/arith/iter_affine_map.h> |
26 | #include <tvm/tir/stmt_functor.h> |
27 | #include <tvm/tir/transform.h> |
28 | |
29 | #include "ir_utils.h" |
30 | |
31 | namespace tvm { |
32 | namespace tir { |
33 | |
34 | class AsyncDMALowerer : public StmtExprMutator { |
35 | public: |
36 | explicit AsyncDMALowerer(bool dma_bypass_cache) : dma_bypass_cache_(dma_bypass_cache) {} |
37 | |
38 | // Create member statement to track a mapping from iter var to iter range |
39 | Stmt VisitStmt_(const ForNode* op) final { |
40 | input_iters.Set(op->loop_var, Range(op->min, op->extent)); |
41 | return StmtExprMutator::VisitStmt_(op); |
42 | } |
43 | |
44 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
45 | // Convert this, for example: |
46 | // attr [0] "async_wait_queue_scope" = 0; |
47 | // attr [0] "async_wait_inflight_count" = 0; |
48 | // |
49 | // To this: |
50 | // @tir.dma_wait( |
51 | // 0, /* queue id */ |
52 | // 0, /* in flight count */ |
53 | // dtype=int32 |
54 | // ) |
55 | if (op->attr_key == tir::attr::async_wait_queue_scope) { |
56 | // get queue ID |
57 | auto queue_id_node = op->value.as<IntImmNode>(); |
58 | ICHECK(queue_id_node); |
59 | int queue_id = queue_id_node->value; |
60 | |
61 | // abort if we have not seen this queue ID in `copy` transform |
62 | if (queue_ids_.find(queue_id) == queue_ids_.end()) { |
63 | DLOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed in the " |
64 | "`async_wait_queue_scope` transform has not been previously observed in the " |
65 | "`async_commit_queue_scope` transform" ; |
66 | return StmtExprMutator::VisitStmt_(op); |
67 | } |
68 | |
69 | auto async_wait = op->body.as<AttrStmtNode>(); |
70 | if (!async_wait || async_wait->attr_key != tir::attr::async_wait_inflight_count) { |
71 | DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " |
72 | "`async_wait_queue_scope` does not contain an `AttrStmtNode` with key " |
73 | "`async_wait_inflight_count`" ; |
74 | return StmtExprMutator::VisitStmt_(op); |
75 | } |
76 | |
77 | auto call_dma_wait = |
78 | Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value})); |
79 | |
80 | // concatenate the call with the body and return |
81 | return SeqStmt({call_dma_wait, StmtExprMutator::VisitStmt(async_wait->body)}); |
82 | |
83 | // Convert this, for example: |
84 | // attr [0] "async_commit_queue_scope" = 0; |
85 | // attr [0] "async_scope" = 1; |
86 | // for (ax0: int32, 0, 128) { |
87 | // A_global[ax0] = A[ax0] |
88 | // } |
89 | // |
90 | // To this: |
91 | // @tir.dma_copy( |
92 | // 0, /* queue id */ |
93 | // @tir.address_of(A_global[0], dtype=handle), |
94 | // @tir.address_of(A[0], dtype=handle), |
95 | // 128, /* size */ |
96 | // dtype=int32 |
97 | // ) |
98 | } else if (op->attr_key == tir::attr::async_commit_queue_scope) { |
99 | // get queue ID |
100 | auto queue_id_node = op->value.as<IntImmNode>(); |
101 | ICHECK(queue_id_node); |
102 | int queue_id = queue_id_node->value; |
103 | |
104 | // walk the graph to verify this is a mem copy ... |
105 | // 1) async_commit_queue_scope contains async_scope |
106 | auto async_scope = op->body.as<AttrStmtNode>(); |
107 | if (!async_scope || async_scope->attr_key != tir::attr::async_scope) { |
108 | DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " |
109 | "`async_commit_queue_scope` does not contain an `AttrStmtNode` with key " |
110 | "`async_scope`" ; |
111 | return StmtExprMutator::VisitStmt_(op); |
112 | } |
113 | |
114 | // 2) async_scope contains single for loop |
115 | auto for_loop = async_scope->body.as<ForNode>(); |
116 | if (!for_loop) { |
117 | DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " |
118 | "`async_scope` does not contain a single `ForNode`" ; |
119 | return StmtExprMutator::VisitStmt_(op); |
120 | } |
121 | |
122 | // Add the current loop to the input iters mapping. |
123 | input_iters.Set(for_loop->loop_var, Range(for_loop->min, for_loop->extent)); |
124 | |
125 | // 3) for loop contains buffer store with single index |
126 | auto bufferstorenode = for_loop->body.as<BufferStoreNode>(); |
127 | if (!bufferstorenode || bufferstorenode->indices.size() != 1) { |
128 | DLOG(INFO) |
129 | << "AsyncDMALowerer exiting because the body of the `ForNode` does not contain a " |
130 | "single `BufferStoreNode` with a single index variable" ; |
131 | return StmtExprMutator::VisitStmt_(op); |
132 | } |
133 | |
134 | // 4) buffer store value is a buffer load with single index |
135 | auto bufferloadnode = bufferstorenode->value.as<BufferLoadNode>(); |
136 | if (!bufferloadnode || bufferloadnode->indices.size() != 1) { |
137 | DLOG(INFO) << "AsyncDMALowerer exiting because the value of the `BufferStoreNode` is not a " |
138 | "single `BufferLoadNode` with a single index variable" ; |
139 | return StmtExprMutator::VisitStmt_(op); |
140 | } |
141 | |
142 | // get store buffer; assert it exists and is contiguous given it uses a single index |
143 | auto bufferstore = bufferstorenode->buffer.as<BufferNode>(); |
144 | ICHECK(bufferstore && bufferstore->strides.empty()); |
145 | |
146 | // get load buffer; assert it exists and is contiguous given it uses a single index |
147 | auto bufferload = bufferloadnode->buffer.as<BufferNode>(); |
148 | ICHECK(bufferload && bufferload->strides.empty()); |
149 | |
150 | // we will be replacing the entire for loop including its index |
151 | // with a DMA copy instrinsic that spans the entire index space of the for loop |
152 | // so we will need to replace the for loop index with value zero in the buffer indices |
153 | // thus we eliminate the index from the expression so the DMA copy receives the buffer range |
154 | // base address |
155 | Map<Var, PrimExpr> loop_var_remap = {{for_loop->loop_var, IntImm(DataType::Int(32), 0)}}; |
156 | |
157 | // map loop variable to zero for the store index & simplify |
158 | Array<PrimExpr> store_index = bufferstorenode->indices; |
159 | |
160 | // Use DetectIterMap to detect whether store index is non-contiguous. |
161 | arith::Analyzer analyzer; |
162 | auto store_iter_map = DetectIterMap(store_index, input_iters, 1, |
163 | arith::IterMapLevel::Surjective, &analyzer, false); |
164 | if (!store_iter_map->errors.empty()) { |
165 | LOG(FATAL) |
166 | << "Unable to lower async dma for non contiguous memory access with store index: " |
167 | << store_index; |
168 | } |
169 | |
170 | store_index.MutateByApply([&](PrimExpr expr) { |
171 | arith::Analyzer analyzer; |
172 | return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap)); |
173 | }); |
174 | |
175 | // map loop variable to zero for the load index & simplify |
176 | Array<PrimExpr> load_index = bufferloadnode->indices; |
177 | |
178 | // Use DetectIterMap to detect whether load index is non-contiguous. |
179 | auto load_iter_map = DetectIterMap(load_index, input_iters, 1, |
180 | arith::IterMapLevel::Surjective, &analyzer, false); |
181 | if (!load_iter_map->errors.empty()) { |
182 | LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with load index: " |
183 | << load_index; |
184 | } |
185 | |
186 | load_index.MutateByApply([&](PrimExpr expr) { |
187 | arith::Analyzer analyzer; |
188 | return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap)); |
189 | }); |
190 | |
191 | // now that we are about to perform the `copy` transform |
192 | // save queue ID for inspection in `wait` transform |
193 | queue_ids_.insert(queue_id); |
194 | |
195 | return Evaluate(Call(DataType::Int(32), builtin::dma_copy(), |
196 | {queue_id, |
197 | Call(DataType::Handle(), builtin::address_of(), |
198 | {BufferLoad(bufferstorenode->buffer, store_index)}), |
199 | Call(DataType::Handle(), builtin::address_of(), |
200 | {BufferLoad(bufferloadnode->buffer, load_index)}), |
201 | for_loop->extent * bufferloadnode->dtype.bytes(), dma_bypass_cache_})); |
202 | } |
203 | return StmtExprMutator::VisitStmt_(op); |
204 | } |
205 | |
206 | private: |
207 | std::set<int> queue_ids_; |
208 | bool dma_bypass_cache_; |
209 | Map<Var, Range> input_iters = Map<Var, Range>(); |
210 | }; |
211 | |
212 | namespace transform { |
213 | |
214 | Pass LowerAsyncDMA() { |
215 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
216 | auto fptr = f.CopyOnWrite(); |
217 | bool dma_bypass_cache = |
218 | ctx->GetConfig<Bool>("tir.experimental_dma_bypass_cache" , Bool(false)).value(); |
219 | fptr->body = AsyncDMALowerer(dma_bypass_cache)(std::move(fptr->body)); |
220 | return f; |
221 | }; |
222 | return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA" , {}); |
223 | } |
224 | |
225 | TVM_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA" ).set_body_typed(LowerAsyncDMA); |
226 | } // namespace transform |
227 | |
228 | } // namespace tir |
229 | } // namespace tvm |
230 | |