1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file lower_async_dma.cc
22 */
23
24#include <tvm/arith/analyzer.h>
25#include <tvm/arith/iter_affine_map.h>
26#include <tvm/tir/stmt_functor.h>
27#include <tvm/tir/transform.h>
28
29#include "ir_utils.h"
30
31namespace tvm {
32namespace tir {
33
34class AsyncDMALowerer : public StmtExprMutator {
35 public:
36 explicit AsyncDMALowerer(bool dma_bypass_cache) : dma_bypass_cache_(dma_bypass_cache) {}
37
38 // Create member statement to track a mapping from iter var to iter range
39 Stmt VisitStmt_(const ForNode* op) final {
40 input_iters.Set(op->loop_var, Range(op->min, op->extent));
41 return StmtExprMutator::VisitStmt_(op);
42 }
43
44 Stmt VisitStmt_(const AttrStmtNode* op) final {
45 // Convert this, for example:
46 // attr [0] "async_wait_queue_scope" = 0;
47 // attr [0] "async_wait_inflight_count" = 0;
48 //
49 // To this:
50 // @tir.dma_wait(
51 // 0, /* queue id */
52 // 0, /* in flight count */
53 // dtype=int32
54 // )
55 if (op->attr_key == tir::attr::async_wait_queue_scope) {
56 // get queue ID
57 auto queue_id_node = op->value.as<IntImmNode>();
58 ICHECK(queue_id_node);
59 int queue_id = queue_id_node->value;
60
61 // abort if we have not seen this queue ID in `copy` transform
62 if (queue_ids_.find(queue_id) == queue_ids_.end()) {
63 DLOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed in the "
64 "`async_wait_queue_scope` transform has not been previously observed in the "
65 "`async_commit_queue_scope` transform";
66 return StmtExprMutator::VisitStmt_(op);
67 }
68
69 auto async_wait = op->body.as<AttrStmtNode>();
70 if (!async_wait || async_wait->attr_key != tir::attr::async_wait_inflight_count) {
71 DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
72 "`async_wait_queue_scope` does not contain an `AttrStmtNode` with key "
73 "`async_wait_inflight_count`";
74 return StmtExprMutator::VisitStmt_(op);
75 }
76
77 auto call_dma_wait =
78 Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value}));
79
80 // concatenate the call with the body and return
81 return SeqStmt({call_dma_wait, StmtExprMutator::VisitStmt(async_wait->body)});
82
83 // Convert this, for example:
84 // attr [0] "async_commit_queue_scope" = 0;
85 // attr [0] "async_scope" = 1;
86 // for (ax0: int32, 0, 128) {
87 // A_global[ax0] = A[ax0]
88 // }
89 //
90 // To this:
91 // @tir.dma_copy(
92 // 0, /* queue id */
93 // @tir.address_of(A_global[0], dtype=handle),
94 // @tir.address_of(A[0], dtype=handle),
95 // 128, /* size */
96 // dtype=int32
97 // )
98 } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
99 // get queue ID
100 auto queue_id_node = op->value.as<IntImmNode>();
101 ICHECK(queue_id_node);
102 int queue_id = queue_id_node->value;
103
104 // walk the graph to verify this is a mem copy ...
105 // 1) async_commit_queue_scope contains async_scope
106 auto async_scope = op->body.as<AttrStmtNode>();
107 if (!async_scope || async_scope->attr_key != tir::attr::async_scope) {
108 DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
109 "`async_commit_queue_scope` does not contain an `AttrStmtNode` with key "
110 "`async_scope`";
111 return StmtExprMutator::VisitStmt_(op);
112 }
113
114 // 2) async_scope contains single for loop
115 auto for_loop = async_scope->body.as<ForNode>();
116 if (!for_loop) {
117 DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
118 "`async_scope` does not contain a single `ForNode`";
119 return StmtExprMutator::VisitStmt_(op);
120 }
121
122 // Add the current loop to the input iters mapping.
123 input_iters.Set(for_loop->loop_var, Range(for_loop->min, for_loop->extent));
124
125 // 3) for loop contains buffer store with single index
126 auto bufferstorenode = for_loop->body.as<BufferStoreNode>();
127 if (!bufferstorenode || bufferstorenode->indices.size() != 1) {
128 DLOG(INFO)
129 << "AsyncDMALowerer exiting because the body of the `ForNode` does not contain a "
130 "single `BufferStoreNode` with a single index variable";
131 return StmtExprMutator::VisitStmt_(op);
132 }
133
134 // 4) buffer store value is a buffer load with single index
135 auto bufferloadnode = bufferstorenode->value.as<BufferLoadNode>();
136 if (!bufferloadnode || bufferloadnode->indices.size() != 1) {
137 DLOG(INFO) << "AsyncDMALowerer exiting because the value of the `BufferStoreNode` is not a "
138 "single `BufferLoadNode` with a single index variable";
139 return StmtExprMutator::VisitStmt_(op);
140 }
141
142 // get store buffer; assert it exists and is contiguous given it uses a single index
143 auto bufferstore = bufferstorenode->buffer.as<BufferNode>();
144 ICHECK(bufferstore && bufferstore->strides.empty());
145
146 // get load buffer; assert it exists and is contiguous given it uses a single index
147 auto bufferload = bufferloadnode->buffer.as<BufferNode>();
148 ICHECK(bufferload && bufferload->strides.empty());
149
150 // we will be replacing the entire for loop including its index
151 // with a DMA copy instrinsic that spans the entire index space of the for loop
152 // so we will need to replace the for loop index with value zero in the buffer indices
153 // thus we eliminate the index from the expression so the DMA copy receives the buffer range
154 // base address
155 Map<Var, PrimExpr> loop_var_remap = {{for_loop->loop_var, IntImm(DataType::Int(32), 0)}};
156
157 // map loop variable to zero for the store index & simplify
158 Array<PrimExpr> store_index = bufferstorenode->indices;
159
160 // Use DetectIterMap to detect whether store index is non-contiguous.
161 arith::Analyzer analyzer;
162 auto store_iter_map = DetectIterMap(store_index, input_iters, 1,
163 arith::IterMapLevel::Surjective, &analyzer, false);
164 if (!store_iter_map->errors.empty()) {
165 LOG(FATAL)
166 << "Unable to lower async dma for non contiguous memory access with store index: "
167 << store_index;
168 }
169
170 store_index.MutateByApply([&](PrimExpr expr) {
171 arith::Analyzer analyzer;
172 return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
173 });
174
175 // map loop variable to zero for the load index & simplify
176 Array<PrimExpr> load_index = bufferloadnode->indices;
177
178 // Use DetectIterMap to detect whether load index is non-contiguous.
179 auto load_iter_map = DetectIterMap(load_index, input_iters, 1,
180 arith::IterMapLevel::Surjective, &analyzer, false);
181 if (!load_iter_map->errors.empty()) {
182 LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with load index: "
183 << load_index;
184 }
185
186 load_index.MutateByApply([&](PrimExpr expr) {
187 arith::Analyzer analyzer;
188 return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
189 });
190
191 // now that we are about to perform the `copy` transform
192 // save queue ID for inspection in `wait` transform
193 queue_ids_.insert(queue_id);
194
195 return Evaluate(Call(DataType::Int(32), builtin::dma_copy(),
196 {queue_id,
197 Call(DataType::Handle(), builtin::address_of(),
198 {BufferLoad(bufferstorenode->buffer, store_index)}),
199 Call(DataType::Handle(), builtin::address_of(),
200 {BufferLoad(bufferloadnode->buffer, load_index)}),
201 for_loop->extent * bufferloadnode->dtype.bytes(), dma_bypass_cache_}));
202 }
203 return StmtExprMutator::VisitStmt_(op);
204 }
205
206 private:
207 std::set<int> queue_ids_;
208 bool dma_bypass_cache_;
209 Map<Var, Range> input_iters = Map<Var, Range>();
210};
211
212namespace transform {
213
214Pass LowerAsyncDMA() {
215 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
216 auto fptr = f.CopyOnWrite();
217 bool dma_bypass_cache =
218 ctx->GetConfig<Bool>("tir.experimental_dma_bypass_cache", Bool(false)).value();
219 fptr->body = AsyncDMALowerer(dma_bypass_cache)(std::move(fptr->body));
220 return f;
221 };
222 return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {});
223}
224
225TVM_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA);
226} // namespace transform
227
228} // namespace tir
229} // namespace tvm
230