1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include <tvm/target/target.h>
21#include <tvm/tir/stmt_functor.h>
22#include <tvm/tir/transform.h>
23#include <tvm/tir/usmp/algorithms.h>
24#include <tvm/tir/usmp/analysis.h>
25#include <tvm/tir/usmp/transform.h>
26#include <tvm/tir/usmp/utils.h>
27
28#include <stack>
29#include <string>
30
31namespace tvm {
32namespace tir {
33namespace usmp {
34
35/*! \brief Assign PoolInfo objects to allocate that does not have any.
36 * The schedulers have the oppurtunity to assign PoolInfo objects to
37 * allocate nodes. However, each allocate node is expected to have
38 * at least one PoolInfo node assigned to it. If it was not the case,
39 * this Pass will assign all PoolInfo objects that the target could
40 * access.*/
41class PoolInfoAssigner : public StmtExprMutator {
42 public:
43 explicit PoolInfoAssigner(const IRModule& module) {
44 PrimFunc main_func =
45 Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
46 ICHECK(main_func.defined()) << "main function is not in the module";
47 Optional<Target> target_host = main_func->GetAttr<Target>(tvm::attr::kTarget);
48 ICHECK(target_host) << "main function does not have a target attr";
49 WorkspaceMemoryPools workspace_pools =
50 module->GetAttr<WorkspaceMemoryPools>(tvm::attr::kWorkspaceMemoryPools)
51 .value_or(WorkspaceMemoryPools({CreateDefaultWorkspaceMemoryPool(module)}));
52 // make default ConstantPoolInfo if no constant and no workspace pool infos supplied
53 ConstantMemoryPools constant_pools =
54 module->GetAttr<ConstantMemoryPools>(tvm::attr::kConstantMemoryPools)
55 .value_or(
56 module->GetAttr<WorkspaceMemoryPools>(tvm::attr::kWorkspaceMemoryPools).defined()
57 ? ConstantMemoryPools()
58 : ConstantMemoryPools({CreateDefaultConstantMemoryPool(module)}));
59 auto to_map = [](auto pool_infos) {
60 Map<String, Array<PoolInfo>> pool_map;
61 for (const PoolInfo& pool_info : pool_infos) {
62 for (const auto& tgt : pool_info->targets) {
63 if (pool_map.find(tgt->str()) == pool_map.end()) {
64 pool_map.Set(tgt->str(), Array<PoolInfo>());
65 }
66 Array<PoolInfo> pool_info_arr = pool_map[tgt->str()];
67 pool_info_arr.push_back(pool_info);
68 pool_map.Set(tgt->str(), pool_info_arr);
69 }
70 }
71 return pool_map;
72 };
73
74 target_pool_infos_ = to_map(workspace_pools->pools);
75 if (constant_pools.defined()) {
76 target_const_pool_infos_ = to_map(constant_pools->pools);
77 }
78 mod_ = module->ShallowCopy();
79 }
80
81 IRModule operator()();
82
83 private:
84 Stmt VisitStmt_(const AllocateNode* op) override;
85 Stmt VisitStmt_(const AllocateConstNode* op) override;
86
87 IRModule mod_;
88 Map<String, Array<PoolInfo>> target_pool_infos_;
89 Map<String, Array<PoolInfo>> target_const_pool_infos_;
90 PrimFunc func_;
91 WorkspacePoolInfo CreateDefaultWorkspaceMemoryPool(const IRModule& module);
92 ConstantPoolInfo CreateDefaultConstantMemoryPool(const IRModule& module) {
93 auto p = CreateDefaultWorkspaceMemoryPool(module);
94 return ConstantPoolInfo(
95 "global_const_workspace", {p->targets}, {},
96 PoolInfoProperties(kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth,
97 kUnknownWriteBandwidth, 0, 0, {p->target_burst_bytes}, Bool(true)));
98 }
99};
100
101WorkspacePoolInfo PoolInfoAssigner::CreateDefaultWorkspaceMemoryPool(const tvm::IRModule& module) {
102 VLOG(1) << "Creating default memory pool for:" << std::endl << module;
103 Map<Target, String> target_access;
104 tir::PrimFunc tir_main_func =
105 Downcast<tir::PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
106 Target target_host = tir_main_func->GetAttr<Target>(tvm::attr::kTarget).value();
107 for (const auto& kv : module->functions) {
108 BaseFunc func = kv.second;
109 Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
110 target_access.Set(target.value_or(target_host), kTargetPoolReadWriteAccess);
111 }
112 Array<Target> targets;
113 for (const auto& kv : target_access) {
114 bool exist = false;
115 // Exclude targets with the same string representation
116 for (const auto& t : targets) {
117 if (t->str() == kv.first->str()) {
118 exist = true;
119 }
120 }
121 if (!exist) {
122 targets.push_back(kv.first);
123 }
124 }
125 return WorkspacePoolInfo(
126 "global_workspace", targets,
127 PoolInfoProperties(kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth,
128 kUnknownWriteBandwidth, 0, 0, {{target_host, 1}}, Bool(true)));
129}
130
131Stmt PoolInfoAssigner::VisitStmt_(const AllocateNode* op) {
132 Optional<Target> tgt = func_->GetAttr<Target>(tvm::attr::kTarget).value();
133 ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_;
134 Map<String, ObjectRef> annotations = Map<String, ObjectRef>(op->annotations);
135 if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) {
136 ICHECK(target_pool_infos_.count(tgt.value()->str()) > 0)
137 << "Target " << tgt << " not found among " << target_pool_infos_;
138 annotations.Set(kPoolCandidatesAllocateAttr, target_pool_infos_[tgt.value()->str()]);
139 }
140 Stmt body = VisitStmt(op->body);
141 auto allocate =
142 Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body, annotations);
143 return std::move(allocate);
144}
145
146Stmt PoolInfoAssigner::VisitStmt_(const AllocateConstNode* op) {
147 if (!target_const_pool_infos_.size()) {
148 return StmtExprMutator::VisitStmt_(op);
149 }
150 Optional<Target> tgt = func_->GetAttr<Target>(tvm::attr::kTarget).value();
151 ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_;
152 Map<String, ObjectRef> annotations = Map<String, ObjectRef>(op->annotations);
153 if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) {
154 annotations.Set(kPoolCandidatesAllocateAttr, target_const_pool_infos_[tgt.value()->str()]);
155 annotations.Set(kTargetPoolReadOnlyAccess, Integer(1));
156 }
157 Stmt body = VisitStmt(op->body);
158 auto allocate_const =
159 AllocateConst(op->buffer_var, op->dtype, op->extents, op->data, body, annotations);
160 return std::move(allocate_const);
161}
162
163IRModule PoolInfoAssigner::operator()() {
164 for (const auto& kv : mod_->functions) {
165 GlobalVar gv = kv.first;
166 if (kv.second->IsInstance<PrimFuncNode>()) {
167 func_ = Downcast<PrimFunc>(kv.second);
168 Stmt body = this->VisitStmt(func_->body);
169 PrimFunc new_prim_func =
170 PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map, func_->attrs);
171 mod_->Update(gv, new_prim_func);
172 }
173 }
174 return mod_;
175}
176
177namespace transform {
178
179tvm::transform::Pass AssignPoolInfo() {
180 auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
181 return PoolInfoAssigner(m)();
182 };
183 return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.AssignPoolInfo", {});
184}
185
186TVM_REGISTER_GLOBAL("tir.usmp.transform.AssignPoolInfo").set_body_typed(AssignPoolInfo);
187
188} // namespace transform
189
190} // namespace usmp
191} // namespace tir
192} // namespace tvm
193