1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include <tvm/target/target.h> |
21 | #include <tvm/tir/stmt_functor.h> |
22 | #include <tvm/tir/transform.h> |
23 | #include <tvm/tir/usmp/algorithms.h> |
24 | #include <tvm/tir/usmp/analysis.h> |
25 | #include <tvm/tir/usmp/transform.h> |
26 | #include <tvm/tir/usmp/utils.h> |
27 | |
28 | #include <stack> |
29 | #include <string> |
30 | |
31 | namespace tvm { |
32 | namespace tir { |
33 | namespace usmp { |
34 | |
35 | /*! \brief Assign PoolInfo objects to allocate that does not have any. |
36 | * The schedulers have the oppurtunity to assign PoolInfo objects to |
37 | * allocate nodes. However, each allocate node is expected to have |
38 | * at least one PoolInfo node assigned to it. If it was not the case, |
39 | * this Pass will assign all PoolInfo objects that the target could |
40 | * access.*/ |
41 | class PoolInfoAssigner : public StmtExprMutator { |
42 | public: |
43 | explicit PoolInfoAssigner(const IRModule& module) { |
44 | PrimFunc main_func = |
45 | Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); |
46 | ICHECK(main_func.defined()) << "main function is not in the module" ; |
47 | Optional<Target> target_host = main_func->GetAttr<Target>(tvm::attr::kTarget); |
48 | ICHECK(target_host) << "main function does not have a target attr" ; |
49 | WorkspaceMemoryPools workspace_pools = |
50 | module->GetAttr<WorkspaceMemoryPools>(tvm::attr::kWorkspaceMemoryPools) |
51 | .value_or(WorkspaceMemoryPools({CreateDefaultWorkspaceMemoryPool(module)})); |
52 | // make default ConstantPoolInfo if no constant and no workspace pool infos supplied |
53 | ConstantMemoryPools constant_pools = |
54 | module->GetAttr<ConstantMemoryPools>(tvm::attr::kConstantMemoryPools) |
55 | .value_or( |
56 | module->GetAttr<WorkspaceMemoryPools>(tvm::attr::kWorkspaceMemoryPools).defined() |
57 | ? ConstantMemoryPools() |
58 | : ConstantMemoryPools({CreateDefaultConstantMemoryPool(module)})); |
59 | auto to_map = [](auto pool_infos) { |
60 | Map<String, Array<PoolInfo>> pool_map; |
61 | for (const PoolInfo& pool_info : pool_infos) { |
62 | for (const auto& tgt : pool_info->targets) { |
63 | if (pool_map.find(tgt->str()) == pool_map.end()) { |
64 | pool_map.Set(tgt->str(), Array<PoolInfo>()); |
65 | } |
66 | Array<PoolInfo> pool_info_arr = pool_map[tgt->str()]; |
67 | pool_info_arr.push_back(pool_info); |
68 | pool_map.Set(tgt->str(), pool_info_arr); |
69 | } |
70 | } |
71 | return pool_map; |
72 | }; |
73 | |
74 | target_pool_infos_ = to_map(workspace_pools->pools); |
75 | if (constant_pools.defined()) { |
76 | target_const_pool_infos_ = to_map(constant_pools->pools); |
77 | } |
78 | mod_ = module->ShallowCopy(); |
79 | } |
80 | |
81 | IRModule operator()(); |
82 | |
83 | private: |
84 | Stmt VisitStmt_(const AllocateNode* op) override; |
85 | Stmt VisitStmt_(const AllocateConstNode* op) override; |
86 | |
87 | IRModule mod_; |
88 | Map<String, Array<PoolInfo>> target_pool_infos_; |
89 | Map<String, Array<PoolInfo>> target_const_pool_infos_; |
90 | PrimFunc func_; |
91 | WorkspacePoolInfo CreateDefaultWorkspaceMemoryPool(const IRModule& module); |
92 | ConstantPoolInfo CreateDefaultConstantMemoryPool(const IRModule& module) { |
93 | auto p = CreateDefaultWorkspaceMemoryPool(module); |
94 | return ConstantPoolInfo( |
95 | "global_const_workspace" , {p->targets}, {}, |
96 | PoolInfoProperties(kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth, |
97 | kUnknownWriteBandwidth, 0, 0, {p->target_burst_bytes}, Bool(true))); |
98 | } |
99 | }; |
100 | |
101 | WorkspacePoolInfo PoolInfoAssigner::CreateDefaultWorkspaceMemoryPool(const tvm::IRModule& module) { |
102 | VLOG(1) << "Creating default memory pool for:" << std::endl << module; |
103 | Map<Target, String> target_access; |
104 | tir::PrimFunc tir_main_func = |
105 | Downcast<tir::PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); |
106 | Target target_host = tir_main_func->GetAttr<Target>(tvm::attr::kTarget).value(); |
107 | for (const auto& kv : module->functions) { |
108 | BaseFunc func = kv.second; |
109 | Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget); |
110 | target_access.Set(target.value_or(target_host), kTargetPoolReadWriteAccess); |
111 | } |
112 | Array<Target> targets; |
113 | for (const auto& kv : target_access) { |
114 | bool exist = false; |
115 | // Exclude targets with the same string representation |
116 | for (const auto& t : targets) { |
117 | if (t->str() == kv.first->str()) { |
118 | exist = true; |
119 | } |
120 | } |
121 | if (!exist) { |
122 | targets.push_back(kv.first); |
123 | } |
124 | } |
125 | return WorkspacePoolInfo( |
126 | "global_workspace" , targets, |
127 | PoolInfoProperties(kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth, |
128 | kUnknownWriteBandwidth, 0, 0, {{target_host, 1}}, Bool(true))); |
129 | } |
130 | |
131 | Stmt PoolInfoAssigner::VisitStmt_(const AllocateNode* op) { |
132 | Optional<Target> tgt = func_->GetAttr<Target>(tvm::attr::kTarget).value(); |
133 | ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_; |
134 | Map<String, ObjectRef> annotations = Map<String, ObjectRef>(op->annotations); |
135 | if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) { |
136 | ICHECK(target_pool_infos_.count(tgt.value()->str()) > 0) |
137 | << "Target " << tgt << " not found among " << target_pool_infos_; |
138 | annotations.Set(kPoolCandidatesAllocateAttr, target_pool_infos_[tgt.value()->str()]); |
139 | } |
140 | Stmt body = VisitStmt(op->body); |
141 | auto allocate = |
142 | Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body, annotations); |
143 | return std::move(allocate); |
144 | } |
145 | |
146 | Stmt PoolInfoAssigner::VisitStmt_(const AllocateConstNode* op) { |
147 | if (!target_const_pool_infos_.size()) { |
148 | return StmtExprMutator::VisitStmt_(op); |
149 | } |
150 | Optional<Target> tgt = func_->GetAttr<Target>(tvm::attr::kTarget).value(); |
151 | ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_; |
152 | Map<String, ObjectRef> annotations = Map<String, ObjectRef>(op->annotations); |
153 | if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) { |
154 | annotations.Set(kPoolCandidatesAllocateAttr, target_const_pool_infos_[tgt.value()->str()]); |
155 | annotations.Set(kTargetPoolReadOnlyAccess, Integer(1)); |
156 | } |
157 | Stmt body = VisitStmt(op->body); |
158 | auto allocate_const = |
159 | AllocateConst(op->buffer_var, op->dtype, op->extents, op->data, body, annotations); |
160 | return std::move(allocate_const); |
161 | } |
162 | |
163 | IRModule PoolInfoAssigner::operator()() { |
164 | for (const auto& kv : mod_->functions) { |
165 | GlobalVar gv = kv.first; |
166 | if (kv.second->IsInstance<PrimFuncNode>()) { |
167 | func_ = Downcast<PrimFunc>(kv.second); |
168 | Stmt body = this->VisitStmt(func_->body); |
169 | PrimFunc new_prim_func = |
170 | PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map, func_->attrs); |
171 | mod_->Update(gv, new_prim_func); |
172 | } |
173 | } |
174 | return mod_; |
175 | } |
176 | |
177 | namespace transform { |
178 | |
179 | tvm::transform::Pass AssignPoolInfo() { |
180 | auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { |
181 | return PoolInfoAssigner(m)(); |
182 | }; |
183 | return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.AssignPoolInfo" , {}); |
184 | } |
185 | |
186 | TVM_REGISTER_GLOBAL("tir.usmp.transform.AssignPoolInfo" ).set_body_typed(AssignPoolInfo); |
187 | |
188 | } // namespace transform |
189 | |
190 | } // namespace usmp |
191 | } // namespace tir |
192 | } // namespace tvm |
193 | |