1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file storage_rewrite.cc |
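 * \brief Bind constant NDArrays to the trailing parameters of a PrimFunc.
 *  The bound parameters are dropped from the function signature, and each
 *  constant that the body reads is materialized with an AllocateConst node.
 *  NOTE(review): the file name given above does not match this content and
 *  looks copied from another pass - confirm and correct it.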
24 | */ |
25 | #include <tvm/arith/analyzer.h> |
26 | #include <tvm/ir/type.h> |
27 | #include <tvm/relay/expr.h> |
28 | #include <tvm/runtime/registry.h> |
29 | #include <tvm/target/target_info.h> |
30 | #include <tvm/tir/analysis.h> |
31 | #include <tvm/tir/builtin.h> |
32 | #include <tvm/tir/expr.h> |
33 | #include <tvm/tir/function.h> |
34 | #include <tvm/tir/stmt_functor.h> |
35 | #include <tvm/tir/transform.h> |
36 | |
37 | #include <map> |
38 | #include <unordered_map> |
39 | #include <unordered_set> |
40 | |
41 | #include "../../runtime/thread_storage_scope.h" |
42 | #include "ir_utils.h" |
43 | |
44 | namespace tvm { |
45 | namespace tir { |
46 | |
47 | class ParamsCollector : public StmtExprVisitor { |
48 | public: |
49 | explicit ParamsCollector(const Map<tir::Var, runtime::NDArray>& constant_map) |
50 | : constant_map_(constant_map) {} |
51 | std::vector<const tir::VarNode*> CollectParams(tir::Stmt body) { |
52 | this->VisitStmt(body); |
53 | return constant_list_; |
54 | } |
55 | |
56 | void VisitExpr_(const BufferLoadNode* ln) { |
57 | if (constant_map_.find(ln->buffer->data) != constant_map_.end()) { |
58 | auto it = std::find(constant_list_.begin(), constant_list_.end(), ln->buffer->data.get()); |
59 | if (it == constant_list_.end()) { |
60 | constant_list_.push_back(ln->buffer->data.get()); |
61 | } |
62 | } |
63 | StmtExprVisitor::VisitExpr_(ln); |
64 | } |
65 | |
66 | void VisitExpr_(const CallNode* cn) { |
67 | if (cn->op.same_as(builtin::tvm_access_ptr())) { |
68 | ICHECK_EQ(cn->args.size(), 5U); |
69 | const Var& var = Downcast<Var>(cn->args[1]); |
70 | const VarNode* buffer = cn->args[1].as<VarNode>(); |
71 | auto it = constant_map_.find(var); |
72 | if (it != constant_map_.end()) { |
73 | auto it = std::find(constant_list_.begin(), constant_list_.end(), buffer); |
74 | if (it == constant_list_.end()) { |
75 | constant_list_.push_back(buffer); |
76 | } |
77 | } |
78 | } |
79 | StmtExprVisitor::VisitExpr_(cn); |
80 | } |
81 | |
82 | private: |
83 | std::vector<const tir::VarNode*> constant_list_; |
84 | Map<tir::Var, runtime::NDArray> constant_map_; |
85 | }; |
86 | |
87 | PrimFunc BindParams(PrimFunc f, const Array<runtime::NDArray>& constants) { |
88 | Map<tir::Var, runtime::NDArray> constant_map; |
89 | |
90 | // Remove constants from the primfunc signature |
91 | size_t num_constants = constants.size(); |
92 | size_t start = f->params.size() - num_constants; |
93 | Array<tir::Var> params; |
94 | for (unsigned i = 0; i < start; i++) { |
95 | params.push_back(f->params[i]); |
96 | } |
97 | |
98 | auto* n = f.CopyOnWrite(); |
99 | for (unsigned i = start; i < f->params.size(); i++) { |
100 | tir::Var p = n->params[i]; |
101 | tir::Var b = n->buffer_map[p]->data; |
102 | n->buffer_map.erase(p); |
103 | constant_map.Set(b, constants[i - start]); |
104 | } |
105 | n->params = params; |
106 | auto constant_list = ParamsCollector(constant_map).CollectParams(n->body); |
107 | |
108 | // Allocate constants within the primfunc |
109 | for (auto i : constant_list) { |
110 | auto var = GetRef<Var>(i); |
111 | int ndim = constant_map[var]->ndim; |
112 | Array<PrimExpr> extents; |
113 | |
114 | for (int i = 0; i < ndim; i++) { |
115 | int shape = constant_map[var]->shape[i]; |
116 | extents.push_back(make_const(DataType::Int(32), shape)); |
117 | } |
118 | DataType dtype = DataType(constant_map[var]->dtype); |
119 | |
120 | if (n->body->IsInstance<BlockRealizeNode>()) { |
121 | auto* block_realize = n->body.as<BlockRealizeNode>(); |
122 | auto block = block_realize->block; |
123 | block.CopyOnWrite()->body = |
124 | tir::AllocateConst(var, dtype, extents, constant_map[var], block->body); |
125 | n->body = BlockRealize(block_realize->iter_values, block_realize->predicate, block); |
126 | } else { |
127 | n->body = tir::AllocateConst(var, dtype, extents, constant_map[var], n->body); |
128 | } |
129 | } |
130 | return f; |
131 | } |
132 | |
133 | namespace transform { |
134 | |
135 | Pass BindParams(const Array<runtime::NDArray>& constants) { |
136 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
137 | return BindParams(f, constants); |
138 | }; |
139 | return CreatePrimFuncPass(pass_func, 0, "tir.BindParams" , {}); |
140 | } |
141 | } // namespace transform |
142 | |
143 | } // namespace tir |
144 | } // namespace tvm |
145 | |