1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file split_host_device.cc |
22 | * \brief Split device function from host. |
23 | */ |
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <string>
#include <unordered_map>
#include <utility>

#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
39 | |
40 | namespace tvm { |
41 | namespace tir { |
42 | |
43 | // use/def analysis, also delete unreferenced lets |
44 | class VarUseDefAnalysis : public StmtExprMutator { |
45 | public: |
46 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
47 | if (op->attr_key == attr::thread_extent) { |
48 | IterVar iv = Downcast<IterVar>(op->node); |
49 | ICHECK_NE(iv->thread_tag.length(), 0U); |
50 | // thread_extent can appear multiple times |
51 | // use the first appearance as def. |
52 | if (!use_count_.count(iv->var.get())) { |
53 | this->HandleDef(iv->var.get()); |
54 | thread_axis_.push_back(iv); |
55 | thread_extent_.push_back(op->value); |
56 | } |
57 | |
58 | PrimExpr value = op->value; |
59 | if (visit_thread_extent_) { |
60 | value = this->VisitExpr(value); |
61 | } |
62 | Stmt body = this->VisitStmt(op->body); |
63 | if (value.same_as(op->value) && body.same_as(op->body)) { |
64 | return GetRef<Stmt>(op); |
65 | } |
66 | return AttrStmt(op->node, op->attr_key, value, body); |
67 | } else { |
68 | return StmtExprMutator::VisitStmt_(op); |
69 | } |
70 | } |
71 | |
72 | Stmt VisitStmt_(const LetStmtNode* op) final { |
73 | this->HandleDef(op->var.get()); |
74 | Stmt body = this->VisitStmt(op->body); |
75 | // eliminate unreferenced let |
76 | if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && |
77 | simplify_let_) { |
78 | return body; |
79 | } else { |
80 | PrimExpr value = this->VisitExpr(op->value); |
81 | if (body.same_as(op->body) && value.same_as(op->value)) { |
82 | return GetRef<Stmt>(op); |
83 | } else { |
84 | return LetStmt(op->var, value, body); |
85 | } |
86 | } |
87 | } |
88 | |
89 | Stmt VisitStmt_(const ForNode* op) final { |
90 | this->HandleDef(op->loop_var.get()); |
91 | return StmtExprMutator::VisitStmt_(op); |
92 | } |
93 | |
94 | Stmt VisitStmt_(const AllocateNode* op) final { |
95 | this->HandleDef(op->buffer_var.get()); |
96 | auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); |
97 | if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn" ) { |
98 | ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed." ; |
99 | ICHECK_GT(op->extents.size(), 0); |
100 | dyn_shmem_size_ = op->extents[0]; |
101 | for (size_t i = 1; i < op->extents.size(); ++i) { |
102 | dyn_shmem_size_ *= op->extents[i]; |
103 | } |
104 | dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); |
105 | use_dyn_shmem_ = true; |
106 | } |
107 | return StmtExprMutator::VisitStmt_(op); |
108 | } |
109 | |
110 | Stmt VisitStmt_(const AllocateConstNode* op) final { |
111 | this->HandleDef(op->buffer_var.get()); |
112 | return StmtExprMutator::VisitStmt_(op); |
113 | } |
114 | |
115 | Stmt VisitStmt_(const StoreNode* op) final { |
116 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
117 | } |
118 | |
119 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
120 | VisitBuffer(op->buffer); |
121 | return StmtExprMutator::VisitStmt_(op); |
122 | } |
123 | |
124 | PrimExpr VisitExpr_(const LetNode* op) final { |
125 | // Weaker SSA condition |
126 | // A single var can be binded in multiple lets |
127 | // but they have to bind to the same value. |
128 | // This is used to allow cases when we reuse a single let |
129 | // expression to construct a nested expr. |
130 | // (let x = 1 in x + 1) * (let x = 1 in x + 1) |
131 | auto it = let_binding_.find(op->var); |
132 | PrimExpr value = this->VisitExpr(op->value); |
133 | if (it != let_binding_.end()) { |
134 | ICHECK(deep_equal_(it->second->value, value)) |
135 | << "Let cannot bind the same var to two different values" ; |
136 | return GetRef<PrimExpr>(it->second); |
137 | } else { |
138 | this->HandleDef(op->var.get()); |
139 | let_binding_[op->var] = op; |
140 | } |
141 | PrimExpr body = this->VisitExpr(op->body); |
142 | // eliminate unreferenced let |
143 | if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && |
144 | simplify_let_) { |
145 | return body; |
146 | } else { |
147 | if (body.same_as(op->body) && value.same_as(op->value)) { |
148 | return GetRef<PrimExpr>(op); |
149 | } else { |
150 | return Let(op->var, value, body); |
151 | } |
152 | } |
153 | } |
154 | |
155 | PrimExpr VisitExpr_(const VarNode* op) final { |
156 | this->HandleUse(GetRef<PrimExpr>(op)); |
157 | return StmtExprMutator::VisitExpr_(op); |
158 | } |
159 | |
160 | PrimExpr VisitExpr_(const ReduceNode* op) final { |
161 | for (const auto& iv : op->axis) { |
162 | this->HandleDef(iv->var.get()); |
163 | } |
164 | return StmtExprMutator::VisitExpr_(op); |
165 | } |
166 | |
167 | PrimExpr VisitExpr_(const LoadNode* op) final { |
168 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
169 | } |
170 | |
171 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
172 | VisitBuffer(op->buffer); |
173 | return StmtExprMutator::VisitExpr_(op); |
174 | } |
175 | |
176 | void VisitBuffer(Buffer buffer) { |
177 | this->HandleUse(buffer->data); |
178 | auto visit_arr = [&](Array<PrimExpr> arr) { |
179 | for (const auto& element : arr) { |
180 | this->VisitExpr(element); |
181 | } |
182 | }; |
183 | |
184 | visit_arr(buffer->shape); |
185 | visit_arr(buffer->strides); |
186 | } |
187 | |
188 | void HandleDef(const VarNode* v) { |
189 | ICHECK(!def_count_.count(v)) << "variable " << v->name_hint |
190 | << " has already been defined, the Stmt is not SSA" ; |
191 | ICHECK(!use_count_.count(v)) << "variable " << v->name_hint |
192 | << " has been used before definition!" ; |
193 | use_count_[v] = 0; |
194 | def_count_[v] = 1; |
195 | } |
196 | |
197 | void HandleUse(const PrimExpr& v) { |
198 | ICHECK(v.as<VarNode>()); |
199 | Var var = Downcast<Var>(v); |
200 | auto it = use_count_.find(var.get()); |
201 | if (it != use_count_.end()) { |
202 | if (it->second >= 0) { |
203 | ++it->second; |
204 | } |
205 | } else { |
206 | undefined_.push_back(var); |
207 | use_count_[var.get()] = -1; |
208 | } |
209 | } |
210 | |
211 | // The fields are publically readible to |
212 | // be accessible to the users. |
213 | bool visit_thread_extent_{true}; |
214 | bool simplify_let_{true}; |
215 | Array<Var> undefined_; |
216 | Array<IterVar> thread_axis_; |
217 | Array<PrimExpr> thread_extent_; |
218 | PrimExpr dyn_shmem_size_{0}; |
219 | bool use_dyn_shmem_{false}; |
220 | std::unordered_map<const VarNode*, int> use_count_; |
221 | std::unordered_map<const VarNode*, int> def_count_; |
222 | |
223 | private: |
224 | ExprDeepEqual deep_equal_; |
225 | std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_; |
226 | }; |
227 | |
228 | Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) { |
229 | VarUseDefAnalysis m; |
230 | m.simplify_let_ = false; |
231 | for (Var arg : args) { |
232 | m.use_count_[arg.get()] = 0; |
233 | } |
234 | m(stmt); |
235 | return m.undefined_; |
236 | } |
237 | |
238 | Array<Var> UndefinedVars(const PrimExpr& expr) { |
239 | VarUseDefAnalysis m; |
240 | m.simplify_let_ = false; |
241 | m(expr); |
242 | return m.undefined_; |
243 | } |
244 | |
245 | class HostDeviceSplitter : public StmtMutator { |
246 | public: |
247 | explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) |
248 | : device_mod_(device_mod), device_target_(device_target), name_prefix_(name_prefix) {} |
249 | |
250 | Stmt VisitStmt_(const AllocateNode* op) final { |
251 | handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0); |
252 | return StmtMutator::VisitStmt_(op); |
253 | } |
254 | |
255 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
256 | if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || |
257 | op->attr_key == attr::device_scope) { |
258 | return SplitDeviceFunc(GetRef<Stmt>(op)); |
259 | } |
260 | return StmtMutator::VisitStmt_(op); |
261 | } |
262 | |
263 | private: |
264 | Stmt SplitDeviceFunc(Stmt body) { |
265 | std::ostringstream os; |
266 | os << name_prefix_ << "_kernel" << device_func_counter_++; |
267 | std::string kernel_symbol = os.str(); |
268 | // isolate the device function. |
269 | VarUseDefAnalysis m; |
270 | m.visit_thread_extent_ = false; |
271 | body = m(std::move(body)); |
272 | |
273 | Array<Var> params; |
274 | Array<PrimExpr> arguments; |
275 | Map<tir::Var, PrimExpr> remap_vars; |
276 | |
277 | // Strictly order the arguments: Var pointers, positional arguments. |
278 | for (Var var : m.undefined_) { |
279 | if (var.dtype().is_handle()) { |
280 | // Create a new version of v. |
281 | auto it = handle_data_type_.find(var.get()); |
282 | if (it != handle_data_type_.end()) { |
283 | String storage_scope; |
284 | if (auto* ptr_type = var->type_annotation.as<PointerTypeNode>()) { |
285 | storage_scope = ptr_type->storage_scope; |
286 | } |
287 | tir::Var new_var(var->name_hint, |
288 | PointerType(PrimType((*it).second->dtype), storage_scope)); |
289 | params.push_back(new_var); |
290 | remap_vars.Set(var, new_var); |
291 | } else { |
292 | params.push_back(var); |
293 | } |
294 | arguments.push_back(var); |
295 | } |
296 | } |
297 | // positional arguments |
298 | for (Var var : m.undefined_) { |
299 | if (!var.dtype().is_handle()) { |
300 | params.push_back(var); |
301 | arguments.push_back(var); |
302 | } |
303 | } |
304 | GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_); |
305 | GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false); |
306 | |
307 | PrimFunc device_func(params, Substitute(body, remap_vars)); |
308 | device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_); |
309 | device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, |
310 | Integer(CallingConv::kDeviceKernelLaunch)); |
311 | device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, |
312 | runtime::String(kernel_symbol_global->name_hint)); |
313 | device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); |
314 | device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); |
315 | device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); |
316 | if (m.use_dyn_shmem_) { |
317 | device_func = |
318 | WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); |
319 | } |
320 | (*device_mod_)->Add(kernel_symbol_global, device_func); |
321 | |
322 | // generate calls to the device function |
323 | Array<PrimExpr> call_args; |
324 | call_args.push_back(StringImm(kernel_symbol_global->name_hint)); |
325 | for (PrimExpr arg : arguments) { |
326 | call_args.push_back(arg); |
327 | } |
328 | for (PrimExpr ext : m.thread_extent_) { |
329 | call_args.push_back(ext); |
330 | } |
331 | if (m.use_dyn_shmem_) { |
332 | call_args.push_back(m.dyn_shmem_size_); |
333 | } |
334 | return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); |
335 | } |
336 | |
337 | // target ir module |
338 | IRModule* device_mod_; |
339 | // Device target |
340 | Target device_target_; |
341 | // function name hint |
342 | std::string name_prefix_; |
343 | // Number of device functions. |
344 | int device_func_counter_{0}; |
345 | std::unordered_map<const VarNode*, PrimExpr> handle_data_type_; |
346 | }; |
347 | |
348 | PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { |
349 | auto target = func->GetAttr<Target>(tvm::attr::kTarget); |
350 | ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute" ; |
351 | auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol); |
352 | ICHECK(global_symbol.defined()) |
353 | << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute" ; |
354 | |
355 | HostDeviceSplitter splitter(device_mod, target.value(), |
356 | static_cast<std::string>(global_symbol.value())); |
357 | |
358 | auto* n = func.CopyOnWrite(); |
359 | n->body = splitter(std::move(n->body)); |
360 | // set the host target to None. |
361 | func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr)); |
362 | return std::move(func); |
363 | } |
364 | |
365 | namespace transform { |
366 | |
367 | Pass SplitHostDevice() { |
368 | auto pass_func = [](IRModule mod, PassContext ctx) { |
369 | IRModuleNode* mod_ptr = mod.CopyOnWrite(); |
370 | auto* func_dict = mod_ptr->functions.CopyOnWrite(); |
371 | IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({})); |
372 | |
373 | for (auto& kv : *func_dict) { |
374 | if (kv.second->IsInstance<PrimFuncNode>()) { |
375 | PrimFunc func = Downcast<PrimFunc>(std::move(kv.second)); |
376 | ICHECK(device_mod.defined()) << "The device module must be defined." ; |
377 | kv.second = SplitHostDevice(std::move(func), &device_mod); |
378 | } |
379 | } |
380 | mod->Update(device_mod); |
381 | return mod; |
382 | }; |
383 | |
384 | return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice" , {}); |
385 | } |
386 | |
387 | TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice" ).set_body_typed(SplitHostDevice); |
388 | |
389 | } // namespace transform |
390 | } // namespace tir |
391 | } // namespace tvm |
392 | |