1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * Compile executable modules. |
22 | * \file driver_api.cc |
23 | */ |
24 | #include <dmlc/thread_local.h> |
25 | #include <tvm/driver/driver_api.h> |
26 | #include <tvm/ir/transform.h> |
27 | #include <tvm/relay/executor.h> |
28 | #include <tvm/relay/runtime.h> |
29 | #include <tvm/runtime/registry.h> |
30 | #include <tvm/target/codegen.h> |
31 | #include <tvm/te/operation.h> |
32 | #include <tvm/tir/analysis.h> |
33 | #include <tvm/tir/transform.h> |
34 | |
35 | #include <algorithm> |
36 | #include <mutex> |
37 | #include <stack> |
38 | |
39 | namespace tvm { |
40 | |
41 | // Register build pipeline related options |
42 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias" , Bool); |
43 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier" , Bool); |
44 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers" , Bool); |
45 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert" , Bool); |
46 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize" , Bool); |
47 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir" , Bool); |
48 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug" , Bool); |
49 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir" , Bool); |
50 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite" , Bool); |
51 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func" , Bool); |
52 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass" , Array<Array<ObjectRef>>); |
53 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop" , Bool); |
54 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy" , Bool); |
55 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope" , Bool); |
56 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp" , Bool); |
57 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity" , Integer); |
58 | |
// WARNING: May cause coherency issues, resulting in data miscompares.
// Experimental feature that, when enabled by the runtime, bypasses the cache when using DMA. When
// bypassing the cache, TVM must manage cache coherency in software. Software-managed cache
// coherency can be tricky, e.g. it has yet to be proven out in the Hexagon runtime. Hence the
// warning above and the "experimental" notation for this feature.
TVM_REGISTER_PASS_CONFIG_OPTION("tir.experimental_dma_bypass_cache", Bool);
65 | |
66 | using tvm::Array; |
67 | using tvm::transform::Pass; |
68 | |
69 | bool LLVMEnabled() { |
const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
71 | return pf != nullptr; |
72 | } |
73 | |
74 | bool ShouldAnnotateEntryFunc(const IRModule mod) { |
75 | Optional<tvm::relay::Executor> executor = mod->GetAttr<tvm::relay::Executor>("executor" ); |
76 | const bool aot_executor = executor.defined() && executor.value()->name == "aot" ; |
77 | const bool single_entry_func = (mod->functions.size() == 1); |
78 | return single_entry_func && !aot_executor; |
79 | } |
80 | |
81 | /*! \return The default host target for a given device target */ |
82 | Target DefaultTargetHost(Target target) { |
83 | if (target.defined() && target->GetTargetDeviceType() == kDLCPU) { |
84 | return target; |
85 | } else { |
86 | if (LLVMEnabled()) { |
87 | return Target("llvm" ); |
88 | } else { |
89 | return Target("stackvm" ); |
90 | } |
91 | } |
92 | } |
93 | |
94 | void GetBinds(const Array<ObjectRef>& args, bool compact, |
95 | const std::unordered_map<te::Tensor, tir::Buffer>& binds, |
96 | Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) { |
97 | *out_binds = binds; |
98 | |
99 | for (const ObjectRef& x : args) { |
100 | if (const te::TensorNode* tensor_node = x.as<te::TensorNode>()) { |
101 | te::Tensor x_ref = GetRef<te::Tensor>(tensor_node); |
102 | if (out_binds->find(x_ref) == out_binds->end()) { |
103 | tir::Buffer buf = tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, |
104 | x_ref->op->name, -1, 0, compact); |
105 | out_binds->Set(x_ref, buf); |
106 | out_arg_list->push_back(buf); |
107 | } else { |
108 | out_arg_list->push_back((*out_binds)[x_ref]); |
109 | } |
110 | } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) { |
111 | out_arg_list->push_back(x); |
112 | } else { |
113 | LOG(FATAL) |
114 | << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var, " |
115 | << "but got a " << x->GetTypeKey(); |
116 | } |
117 | } |
118 | } |
119 | |
120 | void GetBinds(const Array<te::Tensor>& args, bool compact, |
121 | const std::unordered_map<te::Tensor, tir::Buffer>& binds, |
122 | Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) { |
123 | Array<ObjectRef> ref_args; |
124 | for (ObjectRef x : args) { |
125 | ref_args.push_back(x); |
126 | } |
127 | GetBinds(ref_args, compact, binds, out_binds, out_arg_list); |
128 | } |
129 | |
130 | TVM_REGISTER_GLOBAL("driver.get_binds" ) |
131 | .set_body_typed([](const Array<ObjectRef>& args, bool compact, |
132 | const Map<te::Tensor, tir::Buffer>& binds) { |
133 | std::unordered_map<te::Tensor, tir::Buffer> c_binds; |
// Check that binds is not null before doing the conversion.
135 | if (binds.get() != nullptr) { |
136 | for (auto kv : binds) { |
137 | c_binds.insert({kv.first, kv.second}); |
138 | } |
139 | } |
140 | Map<te::Tensor, tir::Buffer> out_binds; |
141 | Array<ObjectRef> out_arg_list; |
142 | GetBinds(args, compact, c_binds, &out_binds, &out_arg_list); |
143 | |
// The TVM object system doesn't have a pair object, so we put both return values
// in an array and return that.
146 | Array<ObjectRef> out_arr = {out_binds, out_arg_list}; |
147 | return out_arr; |
148 | }); |
149 | |
150 | Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) { |
151 | transform::PassContext pass_ctx = transform::PassContext::Current(); |
152 | |
bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
bool disable_storage_rewrite =
pass_ctx->GetConfig<Bool>("tir.disable_storage_rewrite", Bool(false)).value();
bool instrument_bound_checkers =
pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
bool enable_equiv_terms_in_cse_tir =
pass_ctx->GetConfig<Bool>("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value();
161 | |
162 | // Get any user-added passes |
163 | Array<Array<ObjectRef>> add_lower_pass = |
164 | pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass" , Array<Array<ObjectRef>>()) |
165 | .value(); |
166 | |
bool instrument_lwp = pass_ctx->GetConfig<Bool>("tir.instrument_lwp", Bool(false)).value();
168 | |
169 | Array<transform::Pass> user_lower_phase0 = Array<transform::Pass>(); |
170 | Array<transform::Pass> user_lower_phase1 = Array<transform::Pass>(); |
171 | Array<transform::Pass> user_lower_phase2 = Array<transform::Pass>(); |
172 | Array<transform::Pass> user_lower_phase3 = Array<transform::Pass>(); |
173 | |
// The phase passes are of the form
// [[phase_number, pass], [phase_number, pass], ...]
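// For example (sketch; my_pass is a hypothetical user pass), a PassContext config of
//   {"tir.add_lower_pass": [[1, tir::transform::Simplify()], [3, my_pass]]}
// appends Simplify() at the end of phase 1 and my_pass at the end of phase 3.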
176 | for (Array<ObjectRef> phase_pass : add_lower_pass) { |
177 | const IntImmNode* phase_num = phase_pass[0].as<IntImmNode>(); |
178 | ICHECK(phase_num) |
179 | << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer" ; |
180 | int phase_num_val = phase_num->value; |
181 | |
182 | CHECK_GE(phase_num_val, 0); |
183 | |
const tvm::transform::PassNode* pass_node = phase_pass[1].as<tvm::transform::PassNode>();
ICHECK(pass_node)
<< "Expected the second entry in the inner Array of tir.add_lower_pass to be a Pass";
tvm::transform::Pass pass = GetRef<tvm::transform::Pass>(pass_node);
186 | // Copy the pass into the correct phase |
187 | if (phase_num_val == 0) { |
188 | user_lower_phase0.push_back(pass); |
189 | } else if (phase_num_val == 1) { |
190 | user_lower_phase1.push_back(pass); |
191 | } else if (phase_num_val == 2) { |
192 | user_lower_phase2.push_back(pass); |
193 | } else if (phase_num_val >= 3) { |
194 | user_lower_phase3.push_back(pass); |
195 | } |
196 | } |
197 | |
198 | // Construct the pass list, inserting the user provided passes at the end of the phase |
199 | |
200 | // PHASE 0 |
201 | Array<tvm::transform::Pass> pass_list = user_lower_phase0; |
202 | |
203 | // PHASE 1 |
204 | pass_list.push_back(tir::transform::InjectPrefetch()); |
205 | pass_list.push_back(tir::transform::TextureFlatten()); |
206 | pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); |
207 | pass_list.push_back(tir::transform::LowerCrossThreadReduction()); |
208 | pass_list.push_back(tir::transform::LowerInitBlock()); |
209 | pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); |
210 | pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); |
211 | pass_list.push_back(tir::transform::UnifyThreadBinding()); |
212 | pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); |
213 | pass_list.push_back(tir::transform::CompactBufferAllocation()); |
214 | pass_list.push_back(tir::transform::LowerMatchBuffer()); |
215 | pass_list.push_back(tir::transform::InjectSoftwarePipeline()); |
216 | pass_list.push_back(tir::transform::LowerOpaqueBlock()); |
217 | pass_list.push_back(tir::transform::FlattenBuffer()); |
218 | pass_list.push_back(tir::transform::BF16Legalize()); |
219 | pass_list.push_back(tir::transform::NarrowDataType(32)); |
220 | pass_list.push_back(tir::transform::Simplify()); |
221 | |
222 | // Add user-defined phase-1 passes |
223 | pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end()); |
224 | |
225 | // PHASE 2 |
226 | if (!disable_loop_partition) { |
227 | pass_list.push_back(tir::transform::LoopPartition()); |
228 | } |
229 | |
230 | pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize)); |
231 | pass_list.push_back(tir::transform::InjectVirtualThread()); |
232 | pass_list.push_back(tir::transform::InjectDoubleBuffer()); |
233 | if (!disable_storage_rewrite) { |
234 | pass_list.push_back(tir::transform::StorageRewrite()); |
235 | } |
bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();
237 | |
238 | if (use_async_copy) { |
239 | pass_list.push_back(tir::transform::LowerAsyncDMA()); |
240 | } |
241 | pass_list.push_back(tir::transform::UnrollLoop()); |
242 | |
243 | // Add user-defined phase-2 passes |
244 | pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end()); |
245 | |
246 | // PHASE 3 |
247 | pass_list.push_back(tir::transform::RenormalizeSplitPattern()); |
248 | pass_list.push_back(tir::transform::Simplify()); |
249 | pass_list.push_back(tir::transform::RemoveNoOp()); |
250 | pass_list.push_back(tir::transform::RewriteUnsafeSelect()); |
251 | pass_list.push_back(tir::transform::HoistIfThenElse()); |
252 | |
253 | // Add user-defined phase-3 passes |
254 | pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end()); |
255 | |
256 | if (instrument_bound_checkers) { |
257 | pass_list.push_back(tir::transform::InstrumentBoundCheckers()); |
258 | } |
259 | |
260 | pass_list.push_back( |
261 | tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir)); |
262 | |
263 | // This pass instruments the loops with the profile builtin calls to capture the runtime |
264 | // performance data (only enabled for Hexagon at the moment). To ensure that no other |
265 | // optimizations are performed on the instrumented code, this pass must be added at the end |
266 | // of the list. |
267 | if (instrument_lwp) { |
268 | pass_list.push_back(tir::transform::InstrumentProfileIntrinsics()); |
269 | } |
270 | |
271 | return pass_list; |
272 | } |
273 | |
274 | IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) { |
275 | auto optimize = tvm::transform::Sequential(pass_list); |
276 | mod = optimize(std::move(mod)); |
277 | return mod; |
278 | } |
279 | |
280 | IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { |
281 | mod = seq(std::move(mod)); |
282 | return mod; |
283 | } |
284 | |
285 | // Convert te schedule to IRModule |
286 | IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name, |
287 | const std::unordered_map<te::Tensor, tir::Buffer>& binds, |
288 | GlobalVarSupply global_var_supply) { |
289 | sch = sch.normalize(); |
290 | |
291 | transform::PassContext pass_ctx = transform::PassContext::Current(); |
292 | bool debug_keep_trivial_loop = |
293 | pass_ctx->GetConfig<Bool>("tir.debug_keep_trivial_loop" , Bool(false)).value(); |
294 | |
295 | // Before TIR transformation. |
296 | tir::Stmt stmt = te::ScheduleOps(sch, te::InferBound(sch), debug_keep_trivial_loop); |
297 | bool compact = te::VerifyCompactBuffer(stmt); |
298 | |
299 | Map<te::Tensor, tir::Buffer> out_binds; |
300 | Array<ObjectRef> out_arg_list; |
301 | GetBinds(args, compact, binds, &out_binds, &out_arg_list); |
302 | |
303 | // Build the function, converting from te::Tensor to tir::Buffer |
304 | tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); |
305 | f = WithAttr(std::move(f), "global_symbol" , runtime::String(name)); |
306 | |
// Mark this function as having been converted from a TE schedule. This ensures
// that the correct TE passes are run.
f = WithAttr(std::move(f), "from_legacy_te_schedule", Bool(true));
310 | |
bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();

if (noalias) {
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
315 | } |
316 | GlobalVar global_var = global_var_supply->UniqueGlobalFor(name, false); |
317 | return IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}})); |
318 | } |
319 | |
320 | TVM_REGISTER_GLOBAL("driver.schedule_to_module" ) |
321 | .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name, |
322 | const Map<te::Tensor, tir::Buffer>& binds) { |
323 | std::unordered_map<te::Tensor, tir::Buffer> c_binds; |
// Check that binds is not null before doing the conversion.
325 | if (binds.defined()) { |
326 | for (auto kv : binds) { |
327 | c_binds.insert({kv.first, kv.second}); |
328 | } |
329 | } |
330 | IRModule mod = |
ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")));
332 | return mod; |
333 | }); |
334 | |
335 | IRModule LowerModule(IRModule mod, bool simple_mode) { |
336 | Array<transform::Pass> pass_list = CreatePassList(simple_mode); |
337 | return LowerWithPassList(std::move(mod), pass_list); |
338 | } |
339 | |
340 | TVM_REGISTER_GLOBAL("driver.lower_module" ).set_body_typed([](IRModule mod, bool simple_mode) { |
341 | return LowerModule(std::move(mod), simple_mode); |
342 | }); |
343 | |
344 | IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_mode) { |
345 | transform::PassContext pass_ctx = transform::PassContext::Current(); |
tir::PrimFunc f = WithAttr(std::move(func), "global_symbol", runtime::String(name));

bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();

if (noalias) {
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
352 | } |
353 | IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}})); |
354 | |
355 | // Get the pass list |
356 | Array<transform::Pass> pass_list = CreatePassList(simple_mode); |
357 | return LowerWithPassList(std::move(mod), pass_list); |
358 | } |
359 | |
360 | TVM_REGISTER_GLOBAL("driver.lower_primfunc" ) |
361 | .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) { |
362 | return LowerPrimFunc(std::move(func), name, simple_mode); |
363 | }); |
364 | |
365 | IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name, |
366 | const std::unordered_map<te::Tensor, tir::Buffer>& binds, |
367 | GlobalVarSupply global_var_supply, bool simple_mode) { |
368 | Array<ObjectRef> ref_args; |
369 | for (ObjectRef x : args) { |
370 | ref_args.push_back(x); |
371 | } |
372 | return LowerSchedule(std::move(sch), ref_args, name, binds, global_var_supply, simple_mode); |
373 | } |
374 | |
375 | IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name, |
376 | const std::unordered_map<te::Tensor, tir::Buffer>& binds, |
377 | GlobalVarSupply global_var_supply, bool simple_mode) { |
378 | IRModule mod = ScheduleToModule(std::move(sch), args, name, binds, global_var_supply); |
379 | // Get the legacy TE pass list |
380 | Array<transform::Pass> pass_list = CreatePassList(simple_mode); |
381 | return LowerWithPassList(mod, pass_list); |
382 | } |
383 | |
384 | TVM_REGISTER_GLOBAL("driver.lower_schedule" ) |
385 | .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name, |
386 | const Map<te::Tensor, tir::Buffer>& binds, bool simple_mode) { |
387 | std::unordered_map<te::Tensor, tir::Buffer> c_binds; |
// Check that binds is not null before doing the conversion.
389 | if (binds.get() != nullptr) { |
390 | for (auto kv : binds) { |
391 | c_binds.insert({kv.first, kv.second}); |
392 | } |
393 | } |
return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")),
395 | simple_mode); |
396 | }); |
397 | |
398 | /** |
399 | * This function takes the input module that contains both the device and host opts. |
400 | * Then, it applies transformation on the original module before splitting into separate modules for |
401 | * device and host. Then it also applies transformations on the new splitted modules. |
402 | */ |
403 | std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target& target_arg, |
404 | const Target& target_host_arg) { |
405 | Target target = target_arg, target_host = target_host_arg; |
406 | CheckAndUpdateHostConsistency(&target, &target_host); |
407 | |
ICHECK(mod_mixed.defined()) << "This module must be defined";
409 | |
410 | mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); |
411 | |
412 | IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); |
413 | |
414 | IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target)); |
415 | |
416 | auto keys = target->GetKeys(); |
417 | |
418 | CheckAndUpdateHostConsistency(&target, &target_host); |
419 | |
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && device_mod->functions.size() == 0) {
DLOG(WARNING) << "Specified target " << target->str()
<< " but cannot find device code. Did you forget to bind?";
424 | } |
425 | |
426 | return {host_mod, device_mod}; |
427 | } |
428 | |
429 | runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg, |
430 | const Target& target_host_arg) { |
431 | std::vector<runtime::Module> device_modules; |
432 | Map<Target, IRModule> inputs = inputs_arg; |
433 | Target target_host = target_host_arg; |
434 | |
// Fetch any target host previously defined in the targets
436 | CheckAndUpdateHostConsistency(&inputs, &target_host); |
437 | |
438 | if (!target_host.defined()) { |
439 | for (const auto& it : inputs) { |
440 | if (it.first->GetTargetDeviceType() == kDLCPU || |
441 | it.first->GetTargetDeviceType() == kDLMicroDev) { |
442 | target_host = it.first; |
443 | break; |
444 | } |
445 | } |
446 | } |
447 | |
448 | if (!target_host.defined()) { |
449 | target_host = DefaultTargetHost(target_host); |
450 | } |
451 | |
452 | // Update target host for all targets |
453 | CheckAndUpdateHostConsistency(&inputs, &target_host); |
454 | |
// Take the attrs from the first module so the eventual modules have them.
// Ideally this would just be one unified module all the way through.
457 | IRModule first_module = (*inputs.begin()).second; |
458 | IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs); |
459 | |
ICHECK(mhost_all.defined()) << "The host module must be defined";
461 | |
462 | for (const auto& it : inputs) { |
463 | if (it.second.defined()) { |
464 | const Target& target = it.first; |
465 | const IRModule& ir_module = it.second; |
466 | auto pair = SplitMixedModule(ir_module, target, target_host); |
467 | auto& host_mod = pair.first; |
468 | auto& device_mod = pair.second; |
469 | |
ICHECK(host_mod.defined()) << "The split host module must be defined";

ICHECK(mhost_all.defined()) << "The host module must be defined";
473 | |
// We don't want library modules going back into host codegen
// unless they're supposed to. If we previously overrode the target
// host to allow lowering, check here whether the module is actually
// meant to be placed back into the host module.
478 | bool overrides_host_target = |
479 | target->GetTargetDeviceType() == target_host->GetTargetDeviceType(); |
480 | bool non_host_target_kind = target->kind != target_host->kind; |
481 | if (overrides_host_target && non_host_target_kind) { |
482 | device_modules.push_back(codegen::Build(host_mod, it.first)); |
483 | } else { |
484 | mhost_all->Update(host_mod); |
485 | } |
486 | |
487 | if (device_mod->functions.size() != 0) { |
488 | device_modules.push_back(codegen::Build(device_mod, it.first)); |
489 | } |
490 | } |
491 | } |
492 | |
493 | runtime::Module mhost = codegen::Build(mhost_all, target_host); |
494 | for (const auto& it : device_modules) { |
495 | if (it.operator->()) { |
496 | mhost.Import(it); |
497 | } |
498 | } |
499 | |
500 | return mhost; |
501 | } |
502 | |
503 | TVM_REGISTER_GLOBAL("driver.tir_to_runtime" ) |
504 | .set_body_typed([](const Map<Target, IRModule>& inputs_arg, Target host_target) { |
505 | return TIRToRuntime(inputs_arg, host_target); |
506 | }); |
507 | |
508 | // Build for heterogeneous execution when targets are specified as |
509 | // objects. This wrapper around the internal API is maintained for |
510 | // backwards compatibility. |
511 | runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host) { |
512 | return TIRToRuntime(input, target_host); |
513 | } |
514 | |
515 | // Build for heterogeneous execution when target is a string. |
516 | runtime::Module build(const Map<String, IRModule>& inputs_arg, const Target& target_host_arg) { |
517 | Map<Target, IRModule> updated_inputs; |
518 | Target target_host = target_host_arg; |
519 | for (const auto& it : inputs_arg) { |
520 | Target target = Target(it.first); |
521 | CheckAndUpdateHostConsistency(&target, &target_host); |
522 | Optional<String> device = target->GetAttr<String>("device" ); |
523 | if (device.defined() && device.value() == "vta" ) { |
524 | target = Target("ext_dev" ); |
525 | } |
526 | updated_inputs.Set(target, it.second); |
527 | } |
528 | return TIRToRuntime(updated_inputs, target_host); |
529 | } |
530 | |
531 | // Build for homogeneous execution. |
532 | runtime::Module build(const IRModule& funcs, const Target& target_arg, |
533 | const Target& target_host_arg) { |
534 | auto target = target_arg, target_host = target_host_arg; |
535 | CheckAndUpdateHostConsistency(&target, &target_host); |
// Wrap the single target/module pair in the map form expected by TIRToRuntime
537 | Map<Target, IRModule> inputs = {{target, funcs}}; |
538 | return TIRToRuntime(inputs, target_host); |
539 | } |
540 | |
541 | int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { |
542 | if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true); |
543 | if (target.defined() && target->kind->name == "hexagon" ) { |
544 | auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity" ))->value; |
545 | if (value > 0) return value; |
546 | } |
547 | return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity" , Integer(0)).value()->value; |
548 | } |
549 | |
550 | transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { |
551 | transform::PassContext pass_ctx = transform::PassContext::Current(); |
552 | |
553 | Array<Pass> mixed_pass_list; |
554 | |
555 | // VerifyVTCMLimit must occur before LowerVtcmAlloc |
556 | mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(GetVTCMCapacity(target, pass_ctx))); |
557 | // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations |
558 | mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc()); |
559 | |
560 | mixed_pass_list.push_back(tir::transform::BindTarget(target)); |
561 | |
562 | mixed_pass_list.push_back(tir::transform::VerifyMemory()); |
563 | |
564 | if (ShouldAnnotateEntryFunc(mixed_mod)) { |
565 | mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); |
566 | } |
567 | |
568 | bool detect_global_barrier = |
569 | pass_ctx->GetConfig<Bool>("tir.detect_global_barrier" , Bool(false)).value(); |
570 | if (detect_global_barrier) { |
571 | mixed_pass_list.push_back(tir::transform::ThreadSync("global" )); |
572 | } |
573 | |
574 | mixed_pass_list.push_back(tir::transform::ThreadSync("shared" )); |
575 | mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn" )); |
576 | mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); |
577 | mixed_pass_list.push_back(tir::transform::ThreadSync("warp" )); |
578 | mixed_pass_list.push_back(tir::transform::InferFragment()); |
579 | mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); |
580 | |
bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();
582 | |
583 | if (use_async_copy) { |
584 | mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy()); |
585 | } |
586 | |
587 | bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor) |
588 | .value_or(relay::Executor::Create("graph" , {})) |
589 | ->GetAttr<Bool>("unpacked-api" ) |
590 | .value_or(Bool(false)); |
591 | if (unpacked_api) { |
592 | mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); |
593 | } else { |
594 | mixed_pass_list.push_back(tir::transform::MakePackedAPI()); |
595 | } |
596 | mixed_pass_list.push_back(tir::transform::SplitHostDevice()); |
597 | |
598 | return transform::Sequential(mixed_pass_list); |
599 | } |
600 | |
601 | TVM_REGISTER_GLOBAL("driver.mixed_mod_passes" ) |
602 | .set_body_typed([](IRModule mixed_mod, Target target) { |
603 | return MixedModulePassManager(mixed_mod, target); |
604 | }); |
605 | |
606 | transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { |
607 | transform::PassContext pass_ctx = transform::PassContext::Current(); |
bool enable_debug = pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();
609 | |
610 | Array<tvm::transform::Pass> host_pass_list; |
611 | |
612 | runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) { |
613 | return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != |
614 | CallingConv::kDeviceKernelLaunch; |
615 | }; |
616 | host_pass_list.push_back(tir::transform::Filter(fcond)); |
617 | |
ICHECK(mixed_mod.defined()) << "This module must be defined";
619 | |
620 | host_pass_list.push_back(tir::transform::BindTarget(target_host)); |
621 | |
622 | host_pass_list.push_back(tir::transform::LowerTVMBuiltin()); |
623 | host_pass_list.push_back(tir::transform::LowerCustomDatatypes()); |
624 | host_pass_list.push_back(tir::transform::LowerIntrin()); |
625 | host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); |
626 | host_pass_list.push_back(tir::transform::CombineContextCall()); |
627 | |
628 | if (enable_debug) { |
629 | host_pass_list.push_back(tir::transform::InstallDebugSpans()); |
630 | } |
631 | |
632 | return transform::Sequential(host_pass_list); |
633 | } |
634 | |
635 | TVM_REGISTER_GLOBAL("driver.host_mod_passes" ) |
636 | .set_body_typed([](IRModule mixed_mod, Target target_host) { |
637 | return HostModulePassManager(mixed_mod, target_host); |
638 | }); |
639 | |
640 | transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) { |
641 | Array<Pass> device_pass_list; |
642 | runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) { |
643 | return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == |
644 | CallingConv::kDeviceKernelLaunch; |
645 | }; |
646 | device_pass_list.push_back(tir::transform::Filter(fcond)); |
647 | |
648 | device_pass_list.push_back(tir::transform::BindTarget(target)); |
649 | |
650 | device_pass_list.push_back(tir::transform::LowerWarpMemory()); |
651 | device_pass_list.push_back(tir::transform::Simplify()); |
652 | device_pass_list.push_back(tir::transform::LowerCustomDatatypes()); |
653 | device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); |
654 | device_pass_list.push_back(tir::transform::LowerIntrin()); |
655 | |
656 | return transform::Sequential(device_pass_list); |
657 | } |
658 | |
659 | TVM_REGISTER_GLOBAL("driver.device_mod_passes" ) |
660 | .set_body_typed([](IRModule mixed_mod, Target target_host) { |
661 | return DeviceModulePassManager(mixed_mod, target_host); |
662 | }); |
663 | |
664 | } // namespace tvm |
665 | |