1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Replace copy from global to shared with async copy |
22 | * \file inject_ptx_async_copy.cc |
23 | */ |
24 | #include <tvm/tir/analysis.h> |
25 | #include <tvm/tir/builtin.h> |
26 | #include <tvm/tir/expr.h> |
27 | #include <tvm/tir/stmt_functor.h> |
28 | #include <tvm/tir/transform.h> |
29 | |
30 | #include "../ir/buffer_common.h" |
31 | #include "storage_access.h" |
32 | #include "tvm/tir/stmt.h" |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
37 | class PTXAsyncCopyInjector : public StmtMutator { |
38 | public: |
39 | Stmt VisitStmt_(const AttrStmtNode* attr) { |
40 | if (attr->attr_key == tir::attr::async_scope) { |
41 | ICHECK(in_async == false) << "Nested async scopes not supported" ; |
42 | in_async = true; |
43 | auto body = this->VisitStmt(attr->body); |
44 | in_async = false; |
45 | return body; |
46 | } |
47 | return StmtMutator::VisitStmt_(attr); |
48 | } |
49 | |
50 | Stmt VisitStmt_(const BufferStoreNode* store) { |
51 | if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn" )) { |
52 | if (auto* load = store->value.as<BufferLoadNode>()) { |
53 | if (load->buffer.scope() == "global" ) { |
54 | ICHECK(load->indices.size() == 1 && store->indices.size() == 1); |
55 | ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes()); |
56 | |
57 | const int indices_lanes = load->indices[0]->dtype.lanes(); |
58 | const int bytes = indices_lanes * load->buffer->dtype.bytes(); |
59 | |
60 | if (bytes == 4 || bytes == 8 || bytes == 16) { |
61 | auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation); |
62 | auto src_elem_type = GetPointerType(load->buffer->data->type_annotation); |
63 | ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) |
64 | << "Both store and load buffer should have a pointer type annotation." ; |
65 | |
66 | int index_factor = 1; |
67 | if (dst_elem_type.value() != src_elem_type.value()) { |
68 | // The only case where src and dst have different dtypes is when the dst shared memory |
69 | // is a byte buffer generated by merging dynamic shared memory. |
70 | ICHECK(store->buffer.scope() == "shared.dyn" ); |
71 | ICHECK(dst_elem_type.value() == DataType::UInt(8)); |
72 | // BufferStore/Load have the "pointer reinterpret" semantics according to their |
73 | // "value" dtype. Their "indices" are supposed to be applied after such pointer cast, |
74 | // for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value; |
75 | // To replace BufferStore/Load with cp.async, we need to multiply the store index by |
76 | // the byte size of the "value" dtype, to get the correct offset into the byte buffer. |
77 | index_factor = src_elem_type->bytes(); |
78 | } |
79 | |
80 | if (indices_lanes == 1) { |
81 | auto src_offset = load->indices[0]; |
82 | auto dst_offset = store->indices[0]; |
83 | return Evaluate( |
84 | Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), |
85 | {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), |
86 | load->buffer->data, src_offset, PrimExpr(bytes)})); |
87 | } |
88 | |
89 | // Only some vectorized indexing patterns are supported for now. |
90 | auto src_offset = [=]() -> PrimExpr { |
91 | if (load->indices[0]->IsInstance<RampNode>()) { |
92 | return load->indices[0].as<RampNode>()->base; |
93 | } |
94 | return PrimExpr(); |
95 | }(); |
96 | |
97 | auto dst_offset = [=]() -> PrimExpr { |
98 | if (store->indices[0].as<RampNode>()) { |
99 | return store->indices[0].as<RampNode>()->base; |
100 | } else if (store->indices[0].as<AddNode>()) { |
101 | // The case where the dst buffer is a byte buffer generated by merging dynamic |
102 | // shared memory. |
103 | // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)] |
104 | auto* add = store->indices[0].as<AddNode>(); |
105 | if (!add->a->IsInstance<RampNode>()) return PrimExpr(); |
106 | if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr(); |
107 | return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value); |
108 | } |
109 | return PrimExpr(); |
110 | }(); |
111 | |
112 | if (src_offset.defined() && dst_offset.defined()) { |
113 | return Evaluate( |
114 | Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), |
115 | {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), |
116 | load->buffer->data, src_offset, PrimExpr(bytes)})); |
117 | } |
118 | } |
119 | } |
120 | } |
121 | } |
122 | return StmtMutator::VisitStmt_(store); |
123 | } |
124 | |
125 | private: |
126 | bool in_async{false}; |
127 | }; |
128 | |
129 | namespace transform { |
130 | |
131 | Pass InjectPTXAsyncCopy() { |
132 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
133 | auto* n = f.CopyOnWrite(); |
134 | n->body = PTXAsyncCopyInjector()(n->body); |
135 | return f; |
136 | }; |
137 | return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy" , {}); |
138 | } |
139 | |
140 | TVM_REGISTER_GLOBAL("tir.transform.InjectPTXAsyncCopy" ).set_body_typed(InjectPTXAsyncCopy); |
141 | |
142 | } // namespace transform |
143 | |
144 | } // namespace tir |
145 | } // namespace tvm |
146 | |