1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file split_host_device.cc
22 * \brief Split device function from host.
23 */
24#include <tvm/ir/global_var_supply.h>
25#include <tvm/ir/transform.h>
26#include <tvm/runtime/registry.h>
27#include <tvm/target/target.h>
28#include <tvm/tir/analysis.h>
29#include <tvm/tir/builtin.h>
30#include <tvm/tir/expr.h>
31#include <tvm/tir/op.h>
32#include <tvm/tir/stmt_functor.h>
33#include <tvm/tir/transform.h>
34
35#include <unordered_map>
36
37#include "../../runtime/thread_storage_scope.h"
38#include "ir_utils.h"
39
40namespace tvm {
41namespace tir {
42
43// use/def analysis, also delete unreferenced lets
class VarUseDefAnalysis : public StmtExprMutator {
 public:
  // Visit an AttrStmt. For thread_extent attributes, the bound IterVar's
  // first appearance is treated as its definition and recorded in
  // thread_axis_ / thread_extent_; the same thread var may legally be
  // re-bound by later thread_extent attributes.
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      // thread_extent can appear multiple times
      // use the first appearance as def.
      if (!use_count_.count(iv->var.get())) {
        this->HandleDef(iv->var.get());
        thread_axis_.push_back(iv);
        thread_extent_.push_back(op->value);
      }

      PrimExpr value = op->value;
      // Skipping the extent expression (visit_thread_extent_ == false) keeps
      // vars that only appear in extents from being counted as uses, so they
      // stay host-side arguments rather than kernel parameters.
      if (visit_thread_extent_) {
        value = this->VisitExpr(value);
      }
      Stmt body = this->VisitStmt(op->body);
      // Return the original node unchanged when nothing was rewritten, to
      // preserve reference equality (copy-on-write friendly).
      if (value.same_as(op->value) && body.same_as(op->body)) {
        return GetRef<Stmt>(op);
      }
      return AttrStmt(op->node, op->attr_key, value, body);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  // Visit a LetStmt: defines op->var, then drops the whole binding if the
  // body never referenced it and the bound value has no observable side
  // effects (and simplification is enabled).
  Stmt VisitStmt_(const LetStmtNode* op) final {
    this->HandleDef(op->var.get());
    Stmt body = this->VisitStmt(op->body);
    // eliminate unreferenced let
    // NOTE: the value is deliberately visited only when the let is kept, so
    // uses inside a dead binding do not inflate use counts.
    if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
        simplify_let_) {
      return body;
    } else {
      PrimExpr value = this->VisitExpr(op->value);
      if (body.same_as(op->body) && value.same_as(op->value)) {
        return GetRef<Stmt>(op);
      } else {
        return LetStmt(op->var, value, body);
      }
    }
  }

  // A For loop defines its loop variable.
  Stmt VisitStmt_(const ForNode* op) final {
    this->HandleDef(op->loop_var.get());
    return StmtExprMutator::VisitStmt_(op);
  }

  // An Allocate defines its buffer var. Additionally, a "shared.dyn" scoped
  // allocation is recorded as THE dynamic shared memory request: its total
  // byte size (product of extents times element size) is accumulated into
  // dyn_shmem_size_. Only a single such allocation is permitted.
  Stmt VisitStmt_(const AllocateNode* op) final {
    this->HandleDef(op->buffer_var.get());
    auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
    if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") {
      ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed.";
      ICHECK_GT(op->extents.size(), 0);
      dyn_shmem_size_ = op->extents[0];
      for (size_t i = 1; i < op->extents.size(); ++i) {
        dyn_shmem_size_ *= op->extents[i];
      }
      dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes());
      use_dyn_shmem_ = true;
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  // An AllocateConst also defines its buffer var.
  Stmt VisitStmt_(const AllocateConstNode* op) final {
    this->HandleDef(op->buffer_var.get());
    return StmtExprMutator::VisitStmt_(op);
  }

  // StoreNode is deprecated in TVM; reject it outright.
  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  // A BufferStore uses its buffer's data var, shape and strides.
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    VisitBuffer(op->buffer);
    return StmtExprMutator::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const LetNode* op) final {
    // Weaker SSA condition
    // A single var can be binded in multiple lets
    // but they have to bind to the same value.
    // This is used to allow cases when we reuse a single let
    // expression to construct a nested expr.
    // (let x = 1 in x + 1) * (let x = 1 in x + 1)
    auto it = let_binding_.find(op->var);
    PrimExpr value = this->VisitExpr(op->value);
    if (it != let_binding_.end()) {
      ICHECK(deep_equal_(it->second->value, value))
          << "Let cannot bind the same var to two different values";
      return GetRef<PrimExpr>(it->second);
    } else {
      this->HandleDef(op->var.get());
      let_binding_[op->var] = op;
    }
    PrimExpr body = this->VisitExpr(op->body);
    // eliminate unreferenced let
    if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
        simplify_let_) {
      return body;
    } else {
      if (body.same_as(op->body) && value.same_as(op->value)) {
        return GetRef<PrimExpr>(op);
      } else {
        return Let(op->var, value, body);
      }
    }
  }

  // Every Var occurrence in expression position is a use.
  PrimExpr VisitExpr_(const VarNode* op) final {
    this->HandleUse(GetRef<PrimExpr>(op));
    return StmtExprMutator::VisitExpr_(op);
  }

  // Reduce axes define their iteration variables.
  PrimExpr VisitExpr_(const ReduceNode* op) final {
    for (const auto& iv : op->axis) {
      this->HandleDef(iv->var.get());
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  // LoadNode is deprecated in TVM; reject it outright.
  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  // A BufferLoad uses its buffer's data var, shape and strides.
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    VisitBuffer(op->buffer);
    return StmtExprMutator::VisitExpr_(op);
  }

  // Record uses implied by a buffer: the backing data var plus any vars
  // appearing in the shape/stride expressions.
  void VisitBuffer(Buffer buffer) {
    this->HandleUse(buffer->data);
    auto visit_arr = [&](Array<PrimExpr> arr) {
      for (const auto& element : arr) {
        this->VisitExpr(element);
      }
    };

    visit_arr(buffer->shape);
    visit_arr(buffer->strides);
  }

  // Register a definition of v. Enforces SSA: each var may be defined at
  // most once, and must not have been used (or marked undefined) earlier.
  void HandleDef(const VarNode* v) {
    ICHECK(!def_count_.count(v)) << "variable " << v->name_hint
                                 << " has already been defined, the Stmt is not SSA";
    ICHECK(!use_count_.count(v)) << "variable " << v->name_hint
                                 << " has been used before definition!";
    use_count_[v] = 0;
    def_count_[v] = 1;
  }

  // Register a use of v. A var seen before any definition is recorded in
  // undefined_ with sentinel count -1 (subsequent uses of an undefined var
  // are not counted, hence the >= 0 guard).
  void HandleUse(const PrimExpr& v) {
    ICHECK(v.as<VarNode>());
    Var var = Downcast<Var>(v);
    auto it = use_count_.find(var.get());
    if (it != use_count_.end()) {
      if (it->second >= 0) {
        ++it->second;
      }
    } else {
      undefined_.push_back(var);
      use_count_[var.get()] = -1;
    }
  }

  // The fields are publicly readable to
  // be accessible to the users.
  bool visit_thread_extent_{true};    // whether thread_extent values are visited as uses
  bool simplify_let_{true};           // whether unreferenced lets are eliminated
  Array<Var> undefined_;              // vars used before/without definition (free vars)
  Array<IterVar> thread_axis_;        // thread IterVars, in order of first definition
  Array<PrimExpr> thread_extent_;     // extents matching thread_axis_
  PrimExpr dyn_shmem_size_{0};        // total dynamic shared memory, in bytes
  bool use_dyn_shmem_{false};         // whether a shared.dyn allocation was seen
  std::unordered_map<const VarNode*, int> use_count_;
  std::unordered_map<const VarNode*, int> def_count_;

 private:
  ExprDeepEqual deep_equal_;
  std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};
227
228Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
229 VarUseDefAnalysis m;
230 m.simplify_let_ = false;
231 for (Var arg : args) {
232 m.use_count_[arg.get()] = 0;
233 }
234 m(stmt);
235 return m.undefined_;
236}
237
238Array<Var> UndefinedVars(const PrimExpr& expr) {
239 VarUseDefAnalysis m;
240 m.simplify_let_ = false;
241 m(expr);
242 return m.undefined_;
243}
244
// Splits device-side regions (marked by thread_extent / pipeline_exec_scope /
// device_scope attributes) out of a host function body, moving each region
// into a fresh kernel PrimFunc in *device_mod and replacing it on the host
// side with a packed call that launches the kernel.
class HostDeviceSplitter : public StmtMutator {
 public:
  // device_mod: output module receiving the extracted kernels.
  // device_target: target attached to each extracted kernel.
  // name_prefix: base for kernel symbol names (host function's global symbol).
  explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix)
      : device_mod_(device_mod), device_target_(device_target), name_prefix_(name_prefix) {}

  // Remember the element dtype of each allocated buffer var so that handle
  // parameters of the kernel can be given precise pointer types later.
  Stmt VisitStmt_(const AllocateNode* op) final {
    handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
    return StmtMutator::VisitStmt_(op);
  }

  // Any of these attribute keys marks the root of a device region; the whole
  // subtree (including the attribute itself) is split into a kernel.
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope ||
        op->attr_key == attr::device_scope) {
      return SplitDeviceFunc(GetRef<Stmt>(op));
    }
    return StmtMutator::VisitStmt_(op);
  }

 private:
  // Extract `body` into a new device kernel and return the host-side launch
  // statement that replaces it.
  Stmt SplitDeviceFunc(Stmt body) {
    std::ostringstream os;
    os << name_prefix_ << "_kernel" << device_func_counter_++;
    std::string kernel_symbol = os.str();
    // isolate the device function.
    // visit_thread_extent_ = false: vars appearing only in thread extents are
    // evaluated on the host at launch time, not passed as kernel parameters.
    VarUseDefAnalysis m;
    m.visit_thread_extent_ = false;
    body = m(std::move(body));

    Array<Var> params;       // kernel-side parameter vars
    Array<PrimExpr> arguments;  // host-side values passed at the call site
    Map<tir::Var, PrimExpr> remap_vars;  // old handle var -> retyped kernel var

    // Strictly order the arguments: Var pointers, positional arguments.
    for (Var var : m.undefined_) {
      if (var.dtype().is_handle()) {
        // Create a new version of v.
        // If the pointee dtype is known (from an Allocate), give the kernel
        // parameter a typed pointer, preserving the storage scope.
        auto it = handle_data_type_.find(var.get());
        if (it != handle_data_type_.end()) {
          String storage_scope;
          if (auto* ptr_type = var->type_annotation.as<PointerTypeNode>()) {
            storage_scope = ptr_type->storage_scope;
          }
          tir::Var new_var(var->name_hint,
                           PointerType(PrimType((*it).second->dtype), storage_scope));
          params.push_back(new_var);
          remap_vars.Set(var, new_var);
        } else {
          params.push_back(var);
        }
        arguments.push_back(var);
      }
    }
    // positional arguments
    for (Var var : m.undefined_) {
      if (!var.dtype().is_handle()) {
        params.push_back(var);
        arguments.push_back(var);
      }
    }
    // Use the supply to guarantee the kernel symbol is unique in device_mod.
    GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_);
    GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false);

    PrimFunc device_func(params, Substitute(body, remap_vars));
    device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_);
    device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
                           Integer(CallingConv::kDeviceKernelLaunch));
    device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
                           runtime::String(kernel_symbol_global->name_hint));
    device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
    device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
    device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1));
    if (m.use_dyn_shmem_) {
      device_func =
          WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1));
    }
    (*device_mod_)->Add(kernel_symbol_global, device_func);

    // generate calls to the device function
    // Packed call layout: symbol, data args, thread extents, then the dynamic
    // shared memory size (if any) — the runtime launch shim expects this order.
    Array<PrimExpr> call_args;
    call_args.push_back(StringImm(kernel_symbol_global->name_hint));
    for (PrimExpr arg : arguments) {
      call_args.push_back(arg);
    }
    for (PrimExpr ext : m.thread_extent_) {
      call_args.push_back(ext);
    }
    if (m.use_dyn_shmem_) {
      call_args.push_back(m.dyn_shmem_size_);
    }
    return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args));
  }

  // target ir module
  IRModule* device_mod_;
  // Device target
  Target device_target_;
  // function name hint
  std::string name_prefix_;
  // Number of device functions.
  int device_func_counter_{0};
  // buffer var -> zero-constant of the allocated dtype (dtype carrier only)
  std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
};
347
348PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) {
349 auto target = func->GetAttr<Target>(tvm::attr::kTarget);
350 ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute";
351 auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
352 ICHECK(global_symbol.defined())
353 << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
354
355 HostDeviceSplitter splitter(device_mod, target.value(),
356 static_cast<std::string>(global_symbol.value()));
357
358 auto* n = func.CopyOnWrite();
359 n->body = splitter(std::move(n->body));
360 // set the host target to None.
361 func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr));
362 return std::move(func);
363}
364
365namespace transform {
366
367Pass SplitHostDevice() {
368 auto pass_func = [](IRModule mod, PassContext ctx) {
369 IRModuleNode* mod_ptr = mod.CopyOnWrite();
370 auto* func_dict = mod_ptr->functions.CopyOnWrite();
371 IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({}));
372
373 for (auto& kv : *func_dict) {
374 if (kv.second->IsInstance<PrimFuncNode>()) {
375 PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
376 ICHECK(device_mod.defined()) << "The device module must be defined.";
377 kv.second = SplitHostDevice(std::move(func), &device_mod);
378 }
379 }
380 mod->Update(device_mod);
381 return mod;
382 };
383
384 return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {});
385}
386
// Expose the pass constructor to the FFI registry (tvm.tir.transform.SplitHostDevice).
TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice").set_body_typed(SplitHostDevice);
388
389} // namespace transform
390} // namespace tir
391} // namespace tvm
392