1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file texture_flatten.cc |
22 | * \brief Flattens texture storage from multi-dimensional array |
23 | * to 2D (width, height) buffer access |
24 | */ |
25 | |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/te/operation.h> |
28 | #include <tvm/tir/builtin.h> |
29 | #include <tvm/tir/expr.h> |
30 | #include <tvm/tir/stmt.h> |
31 | #include <tvm/tir/transform.h> |
32 | |
33 | #include <unordered_map> |
34 | |
35 | #include "../../arith/ir_visitor_with_analyzer.h" |
36 | #include "../../runtime/texture.h" |
37 | #include "../../runtime/thread_storage_scope.h" |
38 | |
39 | namespace tvm { |
40 | namespace tir { |
41 | using arith::IRVisitorWithAnalyzer; |
42 | using runtime::ApplyTexture2DFlattening; |
43 | using runtime::DefaultTextureLayoutSeparator; |
44 | using runtime::IsTextureStorage; |
45 | |
46 | class TextureLoweringBase : public StmtExprMutator { |
47 | public: |
48 | explicit TextureLoweringBase(const Map<Var, Buffer>& extern_buffer_map, |
49 | IRVisitorWithAnalyzer* bound_analyzer) |
50 | : bound_analyzer_{bound_analyzer} { |
51 | for (auto kv : extern_buffer_map) { |
52 | extern_buf_.insert(kv.second); |
53 | } |
54 | } |
55 | |
56 | inline PrimExpr SimplifyOffset(const Array<PrimExpr>& shape, const Array<PrimExpr>& index) const { |
57 | PrimExpr base = make_const(DataType::Int(32), 0); |
58 | ICHECK_EQ(shape.size(), index.size()); |
59 | if (index.size() > 0) { |
60 | PrimExpr offset = index[0]; |
61 | for (size_t i = 1; i < index.size(); ++i) { |
62 | offset = bound_analyzer_->Simplify(offset * shape[i] + index[i]); |
63 | } |
64 | base = base + offset; |
65 | } |
66 | return base; |
67 | } |
68 | |
69 | protected: |
70 | std::string GetStorageScope(const Buffer& buffer) { |
71 | auto* ptr = buffer->data->type_annotation.as<PointerTypeNode>(); |
72 | ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType" ; |
73 | return ptr->storage_scope; |
74 | } |
75 | |
76 | // Set of all external input and output buffers |
77 | std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buf_; |
78 | // Bound analzer |
79 | IRVisitorWithAnalyzer* bound_analyzer_; |
80 | }; |
81 | |
// Lower N-d storage access to 2d texture access using the lowering convention
// specified by the buffer's storage scope.
84 | class TextureFlattener : public TextureLoweringBase { |
85 | public: |
86 | using StmtExprMutator::VisitStmt_; |
87 | explicit TextureFlattener(const Map<Var, Buffer>& extern_buffer_map, |
88 | IRVisitorWithAnalyzer* bound_analyzer) |
89 | : TextureLoweringBase(extern_buffer_map, bound_analyzer) {} |
90 | |
91 | Stmt VisitStmt_(const BufferRealizeNode* op) final { |
92 | if (extern_buf_.count(op->buffer)) { |
93 | return this->VisitStmt(op->body); |
94 | } |
95 | |
96 | std::string storage_scope = GetStorageScope(op->buffer); |
97 | Var buffer_var(op->buffer->data->name_hint, |
98 | PointerType(PrimType(op->buffer->dtype), String(storage_scope))); |
99 | let_binding_.insert({op->buffer->data, buffer_var}); |
100 | |
101 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
102 | op = stmt.as<BufferRealizeNode>(); |
103 | |
104 | // Rewrite any buffer realizations with storage scope to 2d texture allocations |
105 | if (IsTextureStorage(storage_scope)) { |
106 | Stmt body = this->VisitStmt(op->body); |
107 | ICHECK(op->bounds.size() >= 3) << "Only 2d RGBA texture is currently supported" ; |
108 | int vec_length = static_cast<int>(op->bounds.back()->extent.as<IntImmNode>()->value); |
109 | ICHECK(vec_length == 4 || vec_length == 1) |
110 | << "Inner dimension of texture must be vector of length 1 or 4 (RGBA), was: " |
111 | << vec_length; |
112 | |
113 | struct ShapeFromRange { |
114 | const Array<Range>& bounds; |
115 | PrimExpr operator[](size_t i) const { return bounds[i]->extent; } |
116 | }; |
117 | size_t axis = DefaultTextureLayoutSeparator(op->bounds.size(), storage_scope); |
118 | auto texture = |
119 | ApplyTexture2DFlattening<PrimExpr>(ShapeFromRange{op->bounds}, op->bounds.size(), axis); |
120 | Array<PrimExpr> args; |
121 | args.push_back(StringImm(storage_scope)); |
122 | args.push_back(IntImm(DataType::Int(64), 2)); // 2d |
123 | args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), |
124 | {texture.width, texture.height})); |
125 | stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), args), |
126 | body); |
127 | } |
128 | |
129 | return stmt; |
130 | } |
131 | |
132 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
133 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
134 | op = stmt.as<BufferStoreNode>(); |
135 | std::string storage_scope = GetStorageScope(op->buffer); |
136 | // Lower to two dimensional access |
137 | if (IsTextureStorage(storage_scope)) { |
138 | Array<PrimExpr> args = GetTextureAccessArgs(op, op->buffer); |
139 | args.push_back(op->value); |
140 | stmt = Evaluate(Call(args[0]->dtype, builtin::texture2d_store(), args)); |
141 | } |
142 | |
143 | return stmt; |
144 | } |
145 | |
146 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
147 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
148 | op = expr.as<BufferLoadNode>(); |
149 | // Lower to two dimensional access |
150 | std::string storage_scope = GetStorageScope(op->buffer); |
151 | if (IsTextureStorage(storage_scope)) { |
152 | Array<PrimExpr> args = GetTextureAccessArgs(op, op->buffer); |
153 | args.push_back(op->indices.back()); |
154 | expr = Call(op->buffer->dtype, builtin::texture2d_load(), args); |
155 | } |
156 | |
157 | return expr; |
158 | } |
159 | |
160 | protected: |
161 | template <typename T> |
162 | Array<PrimExpr> GetTextureAccessArgs(const T* op, const Buffer& buffer) { |
163 | Array<PrimExpr> args; |
164 | if (let_binding_.count(op->buffer->data)) { |
165 | args.push_back(let_binding_[op->buffer->data]); |
166 | } else { |
167 | args.push_back(buffer->data); |
168 | } |
169 | Array<PrimExpr> row_dims, row_indices, col_dims, col_indices; |
170 | for (size_t i = 0; i < op->buffer->shape.size() - 1; i++) { |
171 | if (i < DefaultTextureLayoutSeparator(op->buffer->shape.size(), GetStorageScope(buffer))) { |
172 | col_dims.push_back(op->buffer->shape[i]); |
173 | col_indices.push_back(op->indices[i]); |
174 | } else { |
175 | row_dims.push_back(op->buffer->shape[i]); |
176 | row_indices.push_back(op->indices[i]); |
177 | } |
178 | } |
179 | PrimExpr row_offset = SimplifyOffset(row_dims, row_indices); |
180 | PrimExpr col_offset = SimplifyOffset(col_dims, col_indices); |
181 | args.push_back(row_offset); |
182 | args.push_back(col_offset); |
183 | return args; |
184 | } |
185 | |
186 | // Bindings to new texture vars with texture pointer scope |
187 | std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_; |
188 | }; |
189 | |
190 | PrimFunc TextureFlatten(PrimFunc func) { |
191 | auto fptr = func.CopyOnWrite(); |
192 | IRVisitorWithAnalyzer bound_analyzer; |
193 | bound_analyzer(fptr->body); |
194 | fptr->body = TextureFlattener(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); |
195 | return func; |
196 | } |
197 | |
198 | namespace transform { |
199 | |
200 | Pass TextureFlatten() { |
201 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
202 | return TextureFlatten(std::move(f)); |
203 | }; |
204 | return CreatePrimFuncPass(pass_func, 0, "tir.TextureFlatten" , {}); |
205 | } |
206 | |
207 | TVM_REGISTER_GLOBAL("tir.transform.TextureFlatten" ).set_body_typed(TextureFlatten); |
208 | |
209 | } // namespace transform |
210 | |
211 | } // namespace tir |
212 | } // namespace tvm |
213 | |