/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_ptx_async_copy.cc
 * \brief Replace copies from global to shared memory with asynchronous (cp.async) copies.
 */
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../ir/buffer_common.h"
#include "storage_access.h"

namespace tvm {
namespace tir {

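/*!
 * \brief Rewrites global-to-shared memory copies inside an async_scope into calls to the
 * ptx_cp_async intrinsic, which codegen lowers to PTX cp.async instructions.
 *
 * For example (roughly), an 8-lane fp16 copy such as
 *   A_shared[ramp(dst, 1, 8)] = A_global[ramp(src, 1, 8)]
 * becomes
 *   @ptx_cp_async(A_shared_data, dst, A_global_data, src, 16)
 */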
class PTXAsyncCopyInjector : public StmtMutator {
 public:
  Stmt VisitStmt_(const AttrStmtNode* attr) final {
    if (attr->attr_key == tir::attr::async_scope) {
      ICHECK(!in_async) << "Nested async scopes are not supported";
      in_async = true;
      auto body = this->VisitStmt(attr->body);
      in_async = false;
      return body;
    }
    return StmtMutator::VisitStmt_(attr);
  }

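  /*!
   * \brief Rewrite a global-to-shared BufferStore inside an async scope into a ptx_cp_async
   * call when the copy size and indexing pattern are supported; otherwise leave it unchanged.
   */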
  Stmt VisitStmt_(const BufferStoreNode* store) final {
    if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) {
      if (auto* load = store->value.as<BufferLoadNode>()) {
        if (load->buffer.scope() == "global") {
          ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
          ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes());

          const int indices_lanes = load->indices[0]->dtype.lanes();
          const int bytes = indices_lanes * load->buffer->dtype.bytes();

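          // PTX cp.async supports only 4-, 8-, and 16-byte copies per instruction.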
          if (bytes == 4 || bytes == 8 || bytes == 16) {
            auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation);
            auto src_elem_type = GetPointerType(load->buffer->data->type_annotation);
            ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
                << "Both store and load buffers should have a pointer type annotation.";

            int index_factor = 1;
            if (dst_elem_type.value() != src_elem_type.value()) {
              // The only case where src and dst have different dtypes is when the dst shared
              // memory is a byte buffer generated by merging dynamic shared memory.
              ICHECK(store->buffer.scope() == "shared.dyn");
              ICHECK(dst_elem_type.value() == DataType::UInt(8));
              // BufferStore/Load have "pointer reinterpret" semantics according to their
              // "value" dtype, and their "indices" are applied after that pointer cast,
              // for example: ((float16*)byte_buffer)[store->indices[0]] = fp16_value;
              // To replace BufferStore/Load with cp.async, we need to multiply the store index
              // by the byte size of the "value" dtype to get the correct offset into the byte
              // buffer.
              index_factor = src_elem_type->bytes();
            }

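            // Scalar case: the indices address single elements, so they can be used directly
            // as the cp.async source and destination offsets.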
            if (indices_lanes == 1) {
              auto src_offset = load->indices[0];
              auto dst_offset = store->indices[0];
              return Evaluate(
                  Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
                       {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
                        load->buffer->data, src_offset, PrimExpr(bytes)}));
            }

            // Only some vectorized indexing patterns are supported for now.
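            // A vectorized access of the form A[ramp(base, 1, lanes)] touches `lanes`
            // contiguous elements starting at `base`, so the ramp base alone determines the
            // offset passed to cp.async.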
            auto src_offset = [=]() -> PrimExpr {
              if (auto* ramp = load->indices[0].as<RampNode>()) {
                return ramp->base;
              }
              return PrimExpr();
            }();

            auto dst_offset = [=]() -> PrimExpr {
              if (auto* ramp = store->indices[0].as<RampNode>()) {
                return ramp->base;
              } else if (auto* add = store->indices[0].as<AddNode>()) {
                // The case where the dst buffer is a byte buffer generated by merging dynamic
                // shared memory, e.g.
                //   A_shared.dyn[ramp(..., 1, 8) + x8(17408)] = A_global[ramp(..., 1, 8)]
                auto* ramp = add->a.as<RampNode>();
                auto* bcast = add->b.as<BroadcastNode>();
                if (ramp == nullptr || bcast == nullptr) return PrimExpr();
                return tir::Add(ramp->base, bcast->value);
              }
              return PrimExpr();
            }();

            if (src_offset.defined() && dst_offset.defined()) {
              return Evaluate(
                  Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
                       {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
                        load->buffer->data, src_offset, PrimExpr(bytes)}));
            }
          }
        }
      }
    }
    return StmtMutator::VisitStmt_(store);
  }

 private:
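  // True while visiting the body of an attr::async_scope annotation.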
  bool in_async{false};
};

namespace transform {

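// Creates a TIR pass that rewrites eligible global-to-shared copies inside async scopes
// into ptx_cp_async intrinsic calls.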
Pass InjectPTXAsyncCopy() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = PTXAsyncCopyInjector()(n->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {});
}

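// Register the pass so it is reachable from Python as tvm.tir.transform.InjectPTXAsyncCopy().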
TVM_REGISTER_GLOBAL("tir.transform.InjectPTXAsyncCopy").set_body_typed(InjectPTXAsyncCopy);

}  // namespace transform

}  // namespace tir
}  // namespace tvm