1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file bind_params.cc
 * \brief Bind constant NDArray parameters into a PrimFunc.
 * Trailing PrimFunc parameters are removed from the signature and the constants
 * they referred to are materialized in the body as AllocateConst nodes.
24 */
25#include <tvm/arith/analyzer.h>
26#include <tvm/ir/type.h>
27#include <tvm/relay/expr.h>
28#include <tvm/runtime/registry.h>
29#include <tvm/target/target_info.h>
30#include <tvm/tir/analysis.h>
31#include <tvm/tir/builtin.h>
32#include <tvm/tir/expr.h>
33#include <tvm/tir/function.h>
34#include <tvm/tir/stmt_functor.h>
35#include <tvm/tir/transform.h>
36
#include <algorithm>
#include <cstdint>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
40
41#include "../../runtime/thread_storage_scope.h"
42#include "ir_utils.h"
43
44namespace tvm {
45namespace tir {
46
47class ParamsCollector : public StmtExprVisitor {
48 public:
49 explicit ParamsCollector(const Map<tir::Var, runtime::NDArray>& constant_map)
50 : constant_map_(constant_map) {}
51 std::vector<const tir::VarNode*> CollectParams(tir::Stmt body) {
52 this->VisitStmt(body);
53 return constant_list_;
54 }
55
56 void VisitExpr_(const BufferLoadNode* ln) {
57 if (constant_map_.find(ln->buffer->data) != constant_map_.end()) {
58 auto it = std::find(constant_list_.begin(), constant_list_.end(), ln->buffer->data.get());
59 if (it == constant_list_.end()) {
60 constant_list_.push_back(ln->buffer->data.get());
61 }
62 }
63 StmtExprVisitor::VisitExpr_(ln);
64 }
65
66 void VisitExpr_(const CallNode* cn) {
67 if (cn->op.same_as(builtin::tvm_access_ptr())) {
68 ICHECK_EQ(cn->args.size(), 5U);
69 const Var& var = Downcast<Var>(cn->args[1]);
70 const VarNode* buffer = cn->args[1].as<VarNode>();
71 auto it = constant_map_.find(var);
72 if (it != constant_map_.end()) {
73 auto it = std::find(constant_list_.begin(), constant_list_.end(), buffer);
74 if (it == constant_list_.end()) {
75 constant_list_.push_back(buffer);
76 }
77 }
78 }
79 StmtExprVisitor::VisitExpr_(cn);
80 }
81
82 private:
83 std::vector<const tir::VarNode*> constant_list_;
84 Map<tir::Var, runtime::NDArray> constant_map_;
85};
86
87PrimFunc BindParams(PrimFunc f, const Array<runtime::NDArray>& constants) {
88 Map<tir::Var, runtime::NDArray> constant_map;
89
90 // Remove constants from the primfunc signature
91 size_t num_constants = constants.size();
92 size_t start = f->params.size() - num_constants;
93 Array<tir::Var> params;
94 for (unsigned i = 0; i < start; i++) {
95 params.push_back(f->params[i]);
96 }
97
98 auto* n = f.CopyOnWrite();
99 for (unsigned i = start; i < f->params.size(); i++) {
100 tir::Var p = n->params[i];
101 tir::Var b = n->buffer_map[p]->data;
102 n->buffer_map.erase(p);
103 constant_map.Set(b, constants[i - start]);
104 }
105 n->params = params;
106 auto constant_list = ParamsCollector(constant_map).CollectParams(n->body);
107
108 // Allocate constants within the primfunc
109 for (auto i : constant_list) {
110 auto var = GetRef<Var>(i);
111 int ndim = constant_map[var]->ndim;
112 Array<PrimExpr> extents;
113
114 for (int i = 0; i < ndim; i++) {
115 int shape = constant_map[var]->shape[i];
116 extents.push_back(make_const(DataType::Int(32), shape));
117 }
118 DataType dtype = DataType(constant_map[var]->dtype);
119
120 if (n->body->IsInstance<BlockRealizeNode>()) {
121 auto* block_realize = n->body.as<BlockRealizeNode>();
122 auto block = block_realize->block;
123 block.CopyOnWrite()->body =
124 tir::AllocateConst(var, dtype, extents, constant_map[var], block->body);
125 n->body = BlockRealize(block_realize->iter_values, block_realize->predicate, block);
126 } else {
127 n->body = tir::AllocateConst(var, dtype, extents, constant_map[var], n->body);
128 }
129 }
130 return f;
131}
132
133namespace transform {
134
135Pass BindParams(const Array<runtime::NDArray>& constants) {
136 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
137 return BindParams(f, constants);
138 };
139 return CreatePrimFuncPass(pass_func, 0, "tir.BindParams", {});
140}
141} // namespace transform
142
143} // namespace tir
144} // namespace tvm
145