1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file texture_flatten.cc
22 * \brief Flattens texture storage from multi-dimensional array
23 * to 2D (width, height) buffer access
24 */
25
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/transform.h>

#include <string>
#include <unordered_map>
#include <unordered_set>

#include "../../arith/ir_visitor_with_analyzer.h"
#include "../../runtime/texture.h"
#include "../../runtime/thread_storage_scope.h"
38
39namespace tvm {
40namespace tir {
41using arith::IRVisitorWithAnalyzer;
42using runtime::ApplyTexture2DFlattening;
43using runtime::DefaultTextureLayoutSeparator;
44using runtime::IsTextureStorage;
45
46class TextureLoweringBase : public StmtExprMutator {
47 public:
48 explicit TextureLoweringBase(const Map<Var, Buffer>& extern_buffer_map,
49 IRVisitorWithAnalyzer* bound_analyzer)
50 : bound_analyzer_{bound_analyzer} {
51 for (auto kv : extern_buffer_map) {
52 extern_buf_.insert(kv.second);
53 }
54 }
55
56 inline PrimExpr SimplifyOffset(const Array<PrimExpr>& shape, const Array<PrimExpr>& index) const {
57 PrimExpr base = make_const(DataType::Int(32), 0);
58 ICHECK_EQ(shape.size(), index.size());
59 if (index.size() > 0) {
60 PrimExpr offset = index[0];
61 for (size_t i = 1; i < index.size(); ++i) {
62 offset = bound_analyzer_->Simplify(offset * shape[i] + index[i]);
63 }
64 base = base + offset;
65 }
66 return base;
67 }
68
69 protected:
70 std::string GetStorageScope(const Buffer& buffer) {
71 auto* ptr = buffer->data->type_annotation.as<PointerTypeNode>();
72 ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType";
73 return ptr->storage_scope;
74 }
75
76 // Set of all external input and output buffers
77 std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buf_;
78 // Bound analzer
79 IRVisitorWithAnalyzer* bound_analyzer_;
80};
81
// Lower Nd storage access to 2d texture access using the lowering convention
// specified by the buffer's storage scope.
class TextureFlattener : public TextureLoweringBase {
 public:
  using StmtExprMutator::VisitStmt_;
  explicit TextureFlattener(const Map<Var, Buffer>& extern_buffer_map,
                            IRVisitorWithAnalyzer* bound_analyzer)
      : TextureLoweringBase(extern_buffer_map, bound_analyzer) {}

  // Rewrites realizations of texture-scoped buffers into a LetStmt binding a
  // new texture handle var to an nd_mem_alloc_with_scope allocation.
  Stmt VisitStmt_(const BufferRealizeNode* op) final {
    // External buffers are provided by the caller: drop the realize node and
    // keep only its body.
    if (extern_buf_.count(op->buffer)) {
      return this->VisitStmt(op->body);
    }

    std::string storage_scope = GetStorageScope(op->buffer);
    // Fresh handle var carrying the texture storage scope in its pointer
    // type; loads/stores are redirected to it through let_binding_.
    Var buffer_var(op->buffer->data->name_hint,
                   PointerType(PrimType(op->buffer->dtype), String(storage_scope)));
    let_binding_.insert({op->buffer->data, buffer_var});

    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<BufferRealizeNode>();

    // Rewrite any buffer realizations with storage scope to 2d texture allocations
    if (IsTextureStorage(storage_scope)) {
      // NOTE(review): op->body here is the already-mutated body produced by
      // the recursive visit above, so it is visited a second time; this
      // assumes the mutation is idempotent -- confirm.
      Stmt body = this->VisitStmt(op->body);
      ICHECK(op->bounds.size() >= 3) << "Only 2d RGBA texture is currently supported";
      // NOTE(review): assumes the innermost extent is a constant IntImm;
      // as<IntImmNode>() returns nullptr otherwise -- confirm callers
      // guarantee this.
      int vec_length = static_cast<int>(op->bounds.back()->extent.as<IntImmNode>()->value);
      ICHECK(vec_length == 4 || vec_length == 1)
          << "Inner dimension of texture must be vector of length 1 or 4 (RGBA), was: "
          << vec_length;

      // Adapter giving ApplyTexture2DFlattening shape-like operator[] access
      // over the realize bounds' extents.
      struct ShapeFromRange {
        const Array<Range>& bounds;
        PrimExpr operator[](size_t i) const { return bounds[i]->extent; }
      };
      size_t axis = DefaultTextureLayoutSeparator(op->bounds.size(), storage_scope);
      auto texture =
          ApplyTexture2DFlattening<PrimExpr>(ShapeFromRange{op->bounds}, op->bounds.size(), axis);
      // Allocation call args: scope string, rank (2d), and a packed
      // (width, height) shape.
      Array<PrimExpr> args;
      args.push_back(StringImm(storage_scope));
      args.push_back(IntImm(DataType::Int(64), 2));  // 2d
      args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(),
                          {texture.width, texture.height}));
      stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), args),
                     body);
    }

    return stmt;
  }

  // Lowers stores into texture-scoped buffers to builtin::texture2d_store.
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<BufferStoreNode>();
    std::string storage_scope = GetStorageScope(op->buffer);
    // Lower to two dimensional access
    if (IsTextureStorage(storage_scope)) {
      Array<PrimExpr> args = GetTextureAccessArgs(op, op->buffer);
      args.push_back(op->value);
      stmt = Evaluate(Call(args[0]->dtype, builtin::texture2d_store(), args));
    }

    return stmt;
  }

  // Lowers loads from texture-scoped buffers to builtin::texture2d_load.
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<BufferLoadNode>();
    // Lower to two dimensional access
    std::string storage_scope = GetStorageScope(op->buffer);
    if (IsTextureStorage(storage_scope)) {
      Array<PrimExpr> args = GetTextureAccessArgs(op, op->buffer);
      // The innermost index selects the lane within the texel vector.
      args.push_back(op->indices.back());
      expr = Call(op->buffer->dtype, builtin::texture2d_load(), args);
    }

    return expr;
  }

 protected:
  // Builds {handle, row_offset, col_offset} for a texture access: every axis
  // except the innermost (vector) one is assigned to either the column group
  // (before DefaultTextureLayoutSeparator) or the row group, and each group
  // is flattened with SimplifyOffset.
  template <typename T>
  Array<PrimExpr> GetTextureAccessArgs(const T* op, const Buffer& buffer) {
    Array<PrimExpr> args;
    // Prefer the texture-scoped var bound in VisitStmt_(BufferRealizeNode);
    // buffers without a binding (e.g. external ones) use their original var.
    if (let_binding_.count(op->buffer->data)) {
      args.push_back(let_binding_[op->buffer->data]);
    } else {
      args.push_back(buffer->data);
    }
    Array<PrimExpr> row_dims, row_indices, col_dims, col_indices;
    // size() - 1: the innermost axis is excluded from the 2d flattening.
    for (size_t i = 0; i < op->buffer->shape.size() - 1; i++) {
      if (i < DefaultTextureLayoutSeparator(op->buffer->shape.size(), GetStorageScope(buffer))) {
        col_dims.push_back(op->buffer->shape[i]);
        col_indices.push_back(op->indices[i]);
      } else {
        row_dims.push_back(op->buffer->shape[i]);
        row_indices.push_back(op->indices[i]);
      }
    }
    PrimExpr row_offset = SimplifyOffset(row_dims, row_indices);
    PrimExpr col_offset = SimplifyOffset(col_dims, col_indices);
    args.push_back(row_offset);
    args.push_back(col_offset);
    return args;
  }

  // Bindings to new texture vars with texture pointer scope
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};
189
190PrimFunc TextureFlatten(PrimFunc func) {
191 auto fptr = func.CopyOnWrite();
192 IRVisitorWithAnalyzer bound_analyzer;
193 bound_analyzer(fptr->body);
194 fptr->body = TextureFlattener(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
195 return func;
196}
197
198namespace transform {
199
200Pass TextureFlatten() {
201 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
202 return TextureFlatten(std::move(f));
203 };
204 return CreatePrimFuncPass(pass_func, 0, "tir.TextureFlatten", {});
205}
206
// Expose the pass constructor to the FFI/Python layer.
TVM_REGISTER_GLOBAL("tir.transform.TextureFlatten").set_body_typed(TextureFlatten);
208
209} // namespace transform
210
211} // namespace tir
212} // namespace tvm
213