/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Compile executable modules.
 * \file driver_api.cc
 */
#include <dmlc/thread_local.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/transform.h>
#include <tvm/relay/executor.h>
#include <tvm/relay/runtime.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/codegen.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <mutex>
#include <stack>

namespace tvm {

// Register build pipeline related options
TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);

// WARNING: may cause coherency issues resulting in data miscompares.
// Experimental feature that, when enabled by the runtime, bypasses the cache when using DMA.
// When bypassing the cache, TVM must manage cache coherency in software. Software-managed
// cache coherency can be tricky, e.g. it has yet to be proven out in the Hexagon runtime,
// hence the warning above and the "experimental" label for this option.
TVM_REGISTER_PASS_CONFIG_OPTION("tir.experimental_dma_bypass_cache", Bool);
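
// A minimal sketch of how the options registered above are consumed (this mirrors the
// GetConfig calls used throughout this file; the surrounding code is illustrative only):
//
//   transform::PassContext pass_ctx = transform::PassContext::Current();
//   bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
//
// Registering a key here is what allows callers to set it in a PassContext config.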

using tvm::Array;
using tvm::transform::Pass;

bool LLVMEnabled() {
  const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
  return pf != nullptr;
}

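/*!
 * \brief Decide whether the AnnotateEntryFunc pass should be added to the mixed-module
 *        pass list.
 * \return True if \p mod contains exactly one function and the module's executor is not "aot".
 */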
bool ShouldAnnotateEntryFunc(const IRModule mod) {
  Optional<tvm::relay::Executor> executor = mod->GetAttr<tvm::relay::Executor>("executor");
  const bool aot_executor = executor.defined() && executor.value()->name == "aot";
  const bool single_entry_func = (mod->functions.size() == 1);
  return single_entry_func && !aot_executor;
}

/*! \return The default host target for a given device target */
Target DefaultTargetHost(Target target) {
  if (target.defined() && target->GetTargetDeviceType() == kDLCPU) {
    return target;
  } else {
    if (LLVMEnabled()) {
      return Target("llvm");
    } else {
      return Target("stackvm");
    }
  }
}

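/*!
 * \brief Collect the tensor-to-buffer bindings and the final argument list for lowering.
 * \param args The lowering arguments; each element must be a te::Tensor, te::Buffer, or tir::Var.
 * \param compact Whether the bound buffers may assume a compact layout.
 * \param binds User-supplied bindings; tensors not present here get a default buffer created
 *        via BufferWithOffsetAlignment.
 * \param out_binds The completed binding map.
 * \param out_arg_list The argument list with every tensor replaced by its buffer.
 */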
void GetBinds(const Array<ObjectRef>& args, bool compact,
              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
  *out_binds = binds;

  for (const ObjectRef& x : args) {
    if (const te::TensorNode* tensor_node = x.as<te::TensorNode>()) {
      te::Tensor x_ref = GetRef<te::Tensor>(tensor_node);
      if (out_binds->find(x_ref) == out_binds->end()) {
        tir::Buffer buf = tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype,
                                                         x_ref->op->name, -1, 0, compact);
        out_binds->Set(x_ref, buf);
        out_arg_list->push_back(buf);
      } else {
        out_arg_list->push_back((*out_binds)[x_ref]);
      }
    } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
      out_arg_list->push_back(x);
    } else {
      LOG(FATAL)
          << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var, "
          << "but got a " << x->GetTypeKey();
    }
  }
}

void GetBinds(const Array<te::Tensor>& args, bool compact,
              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
  Array<ObjectRef> ref_args;
  for (ObjectRef x : args) {
    ref_args.push_back(x);
  }
  GetBinds(ref_args, compact, binds, out_binds, out_arg_list);
}

TVM_REGISTER_GLOBAL("driver.get_binds")
    .set_body_typed([](const Array<ObjectRef>& args, bool compact,
                       const Map<te::Tensor, tir::Buffer>& binds) {
      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
      // Check that binds is not null before doing the conversion.
      if (binds.get() != nullptr) {
        for (auto kv : binds) {
          c_binds.insert({kv.first, kv.second});
        }
      }
      Map<te::Tensor, tir::Buffer> out_binds;
      Array<ObjectRef> out_arg_list;
      GetBinds(args, compact, c_binds, &out_binds, &out_arg_list);

      // The TVM object system doesn't have a pair object, so we put both return values in an
      // array and return that.
      Array<ObjectRef> out_arr = {out_binds, out_arg_list};
      return out_arr;
    });

Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
  transform::PassContext pass_ctx = transform::PassContext::Current();

  bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
  bool disable_storage_rewrite =
      pass_ctx->GetConfig<Bool>("tir.disable_storage_rewrite", Bool(false)).value();
  bool instrument_bound_checkers =
      pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
  bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
  bool enable_equiv_terms_in_cse_tir =
      pass_ctx->GetConfig<Bool>("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value();

  // Get any user-added passes
  Array<Array<ObjectRef>> add_lower_pass =
      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
          .value();

  bool instrument_lwp = pass_ctx->GetConfig<Bool>("tir.instrument_lwp", Bool(false)).value();

  Array<transform::Pass> user_lower_phase0 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase1 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase2 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase3 = Array<transform::Pass>();

  // The tir.add_lower_pass entries are of the form
  // [[phase_number, pass], [phase_number, pass], ...]
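  //
  // A hedged illustration of the expected shape (the pass chosen here is arbitrary):
  //
  //   Array<Array<ObjectRef>> user_passes = {{Integer(2), tir::transform::RemoveNoOp()}};
  //
  // Entries with phase_number 0, 1, or 2 run at the end of the corresponding phase below;
  // any phase_number >= 3 is treated as phase 3.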
  for (Array<ObjectRef> phase_pass : add_lower_pass) {
    const IntImmNode* phase_num = phase_pass[0].as<IntImmNode>();
    ICHECK(phase_num)
        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
    int phase_num_val = phase_num->value;

    CHECK_GE(phase_num_val, 0);

    const tvm::transform::PassNode* pass_node = phase_pass[1].as<tvm::transform::PassNode>();
    ICHECK(pass_node)
        << "Expected the second entry in the inner Array of tir.add_lower_pass to be a Pass";
    tvm::transform::Pass pass = GetRef<tvm::transform::Pass>(pass_node);
    // Copy the pass into the correct phase
    if (phase_num_val == 0) {
      user_lower_phase0.push_back(pass);
    } else if (phase_num_val == 1) {
      user_lower_phase1.push_back(pass);
    } else if (phase_num_val == 2) {
      user_lower_phase2.push_back(pass);
    } else if (phase_num_val >= 3) {
      user_lower_phase3.push_back(pass);
    }
  }

  // Construct the pass list, inserting the user-provided passes at the end of each phase

  // PHASE 0
  Array<tvm::transform::Pass> pass_list = user_lower_phase0;

  // PHASE 1
  pass_list.push_back(tir::transform::InjectPrefetch());
  pass_list.push_back(tir::transform::TextureFlatten());
  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
  pass_list.push_back(tir::transform::LowerCrossThreadReduction());
  pass_list.push_back(tir::transform::LowerInitBlock());
  pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
  pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
  pass_list.push_back(tir::transform::UnifyThreadBinding());
  pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
  pass_list.push_back(tir::transform::CompactBufferAllocation());
  pass_list.push_back(tir::transform::LowerMatchBuffer());
  pass_list.push_back(tir::transform::InjectSoftwarePipeline());
  pass_list.push_back(tir::transform::LowerOpaqueBlock());
  pass_list.push_back(tir::transform::FlattenBuffer());
  pass_list.push_back(tir::transform::BF16Legalize());
  pass_list.push_back(tir::transform::NarrowDataType(32));
  pass_list.push_back(tir::transform::Simplify());

  // Add user-defined phase-1 passes
  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());

  // PHASE 2
  if (!disable_loop_partition) {
    pass_list.push_back(tir::transform::LoopPartition());
  }

  pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
  pass_list.push_back(tir::transform::InjectVirtualThread());
  pass_list.push_back(tir::transform::InjectDoubleBuffer());
  if (!disable_storage_rewrite) {
    pass_list.push_back(tir::transform::StorageRewrite());
  }
  bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();

  if (use_async_copy) {
    pass_list.push_back(tir::transform::LowerAsyncDMA());
  }
  pass_list.push_back(tir::transform::UnrollLoop());

  // Add user-defined phase-2 passes
  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());

  // PHASE 3
  pass_list.push_back(tir::transform::RenormalizeSplitPattern());
  pass_list.push_back(tir::transform::Simplify());
  pass_list.push_back(tir::transform::RemoveNoOp());
  pass_list.push_back(tir::transform::RewriteUnsafeSelect());
  pass_list.push_back(tir::transform::HoistIfThenElse());

  // Add user-defined phase-3 passes
  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());

  if (instrument_bound_checkers) {
    pass_list.push_back(tir::transform::InstrumentBoundCheckers());
  }

  pass_list.push_back(
      tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir));

  // This pass instruments the loops with profile builtin calls to capture the runtime
  // performance data (only enabled for Hexagon at the moment). To ensure that no other
  // optimizations are performed on the instrumented code, this pass must be added at the end
  // of the list.
  if (instrument_lwp) {
    pass_list.push_back(tir::transform::InstrumentProfileIntrinsics());
  }

  return pass_list;
}

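/*! \brief Run \p pass_list over \p mod as a single Sequential pass. */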
IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
  auto optimize = tvm::transform::Sequential(pass_list);
  mod = optimize(std::move(mod));
  return mod;
}

IRModule ApplyPasses(IRModule mod, transform::Sequential seq) {
  mod = seq(std::move(mod));
  return mod;
}

// Convert a TE schedule to an IRModule
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
                          const std::unordered_map<te::Tensor, tir::Buffer>& binds,
                          GlobalVarSupply global_var_supply) {
  sch = sch.normalize();

  transform::PassContext pass_ctx = transform::PassContext::Current();
  bool debug_keep_trivial_loop =
      pass_ctx->GetConfig<Bool>("tir.debug_keep_trivial_loop", Bool(false)).value();

  // Before TIR transformation.
  tir::Stmt stmt = te::ScheduleOps(sch, te::InferBound(sch), debug_keep_trivial_loop);
  bool compact = te::VerifyCompactBuffer(stmt);

  Map<te::Tensor, tir::Buffer> out_binds;
  Array<ObjectRef> out_arg_list;
  GetBinds(args, compact, binds, &out_binds, &out_arg_list);

  // Build the function, converting from te::Tensor to tir::Buffer
  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));

  // Mark this function as converted from a TE schedule, so that the correct legacy TE passes
  // are run on it.
  f = WithAttr(std::move(f), "from_legacy_te_schedule", Bool(true));

  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();

  if (noalias) {
    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
  }
  GlobalVar global_var = global_var_supply->UniqueGlobalFor(name, false);
  return IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}));
}

TVM_REGISTER_GLOBAL("driver.schedule_to_module")
    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
                       const Map<te::Tensor, tir::Buffer>& binds) {
      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
      // Check that binds is not null before doing the conversion.
      if (binds.defined()) {
        for (auto kv : binds) {
          c_binds.insert({kv.first, kv.second});
        }
      }
      IRModule mod =
          ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")));
      return mod;
    });

IRModule LowerModule(IRModule mod, bool simple_mode) {
  Array<transform::Pass> pass_list = CreatePassList(simple_mode);
  return LowerWithPassList(std::move(mod), pass_list);
}

TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) {
  return LowerModule(std::move(mod), simple_mode);
});

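/*!
 * \brief Lower a single tir::PrimFunc into an IRModule using the standard TIR lowering
 *        pipeline produced by CreatePassList.
 */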
IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_mode) {
  transform::PassContext pass_ctx = transform::PassContext::Current();
  tir::PrimFunc f = WithAttr(std::move(func), "global_symbol", runtime::String(name));

  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();

  if (noalias) {
    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
  }
  IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));

  // Get the pass list
  Array<transform::Pass> pass_list = CreatePassList(simple_mode);
  return LowerWithPassList(std::move(mod), pass_list);
}

TVM_REGISTER_GLOBAL("driver.lower_primfunc")
    .set_body_typed([](tir::PrimFunc func, const String& name, bool simple_mode) {
      return LowerPrimFunc(std::move(func), name, simple_mode);
    });

IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
                       const std::unordered_map<te::Tensor, tir::Buffer>& binds,
                       GlobalVarSupply global_var_supply, bool simple_mode) {
  Array<ObjectRef> ref_args;
  for (ObjectRef x : args) {
    ref_args.push_back(x);
  }
  return LowerSchedule(std::move(sch), ref_args, name, binds, global_var_supply, simple_mode);
}

IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
                       const std::unordered_map<te::Tensor, tir::Buffer>& binds,
                       GlobalVarSupply global_var_supply, bool simple_mode) {
  IRModule mod = ScheduleToModule(std::move(sch), args, name, binds, global_var_supply);
  // Get the legacy TE pass list
  Array<transform::Pass> pass_list = CreatePassList(simple_mode);
  return LowerWithPassList(mod, pass_list);
}

TVM_REGISTER_GLOBAL("driver.lower_schedule")
    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
                       const Map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
      // Check that binds is not null before doing the conversion.
      if (binds.get() != nullptr) {
        for (auto kv : binds) {
          c_binds.insert({kv.first, kv.second});
        }
      }
      return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")),
                           simple_mode);
    });

/*!
 * \brief Split an input module containing both device and host code.
 *
 * Applies the mixed-module transformations to the original module, splits it into separate
 * host and device modules, and then applies the host- and device-specific transformations
 * to each of the new modules.
 */
std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target& target_arg,
                                               const Target& target_host_arg) {
  Target target = target_arg, target_host = target_host_arg;
  CheckAndUpdateHostConsistency(&target, &target_host);

  ICHECK(mod_mixed.defined()) << "This module must be defined";

  mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));

  IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host));

  IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target));

  auto keys = target->GetKeys();

  CheckAndUpdateHostConsistency(&target, &target_host);

  bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
  if (target_is_gpu && device_mod->functions.size() == 0) {
    DLOG(WARNING) << "Specified target " << target->str()
                  << " but cannot find device code. Did you forget to bind?";
  }

  return {host_mod, device_mod};
}

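/*!
 * \brief Build runtime modules from per-target IRModules: each input module is split into
 *        host and device parts, code is generated for each part, and the resulting device
 *        modules are imported into the single host module that is returned.
 */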
runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
                             const Target& target_host_arg) {
  std::vector<runtime::Module> device_modules;
  Map<Target, IRModule> inputs = inputs_arg;
  Target target_host = target_host_arg;

  // Fetch the previously defined target host from the targets
  CheckAndUpdateHostConsistency(&inputs, &target_host);

  if (!target_host.defined()) {
    for (const auto& it : inputs) {
      if (it.first->GetTargetDeviceType() == kDLCPU ||
          it.first->GetTargetDeviceType() == kDLMicroDev) {
        target_host = it.first;
        break;
      }
    }
  }

  if (!target_host.defined()) {
    target_host = DefaultTargetHost(target_host);
  }

  // Update target host for all targets
  CheckAndUpdateHostConsistency(&inputs, &target_host);

  // Take the attrs from the first module so the eventual modules have them.
  // Ideally this would just be one unified module all the way through.
  IRModule first_module = (*inputs.begin()).second;
  IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);

  ICHECK(mhost_all.defined()) << "The host module must be defined";

  for (const auto& it : inputs) {
    if (it.second.defined()) {
      const Target& target = it.first;
      const IRModule& ir_module = it.second;
      auto pair = SplitMixedModule(ir_module, target, target_host);
      auto& host_mod = pair.first;
      auto& device_mod = pair.second;

      ICHECK(host_mod.defined()) << "The split host module must be defined";

      ICHECK(mhost_all.defined()) << "The host module must be defined";

      // We don't want library modules going back into host codegen unless they're supposed to.
      // If the target host was overridden earlier to allow lowering, check here whether the
      // module is meant to be placed back into the host module.
      bool overrides_host_target =
          target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
      bool non_host_target_kind = target->kind != target_host->kind;
      if (overrides_host_target && non_host_target_kind) {
        device_modules.push_back(codegen::Build(host_mod, it.first));
      } else {
        mhost_all->Update(host_mod);
      }

      if (device_mod->functions.size() != 0) {
        device_modules.push_back(codegen::Build(device_mod, it.first));
      }
    }
  }

  runtime::Module mhost = codegen::Build(mhost_all, target_host);
  for (const auto& it : device_modules) {
    if (it.operator->()) {
      mhost.Import(it);
    }
  }

  return mhost;
}

TVM_REGISTER_GLOBAL("driver.tir_to_runtime")
    .set_body_typed([](const Map<Target, IRModule>& inputs_arg, Target host_target) {
      return TIRToRuntime(inputs_arg, host_target);
    });

// Build for heterogeneous execution when targets are specified as
// objects. This wrapper around the internal API is maintained for
// backwards compatibility.
runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host) {
  return TIRToRuntime(input, target_host);
}

// Build for heterogeneous execution when target is a string.
runtime::Module build(const Map<String, IRModule>& inputs_arg, const Target& target_host_arg) {
  Map<Target, IRModule> updated_inputs;
  Target target_host = target_host_arg;
  for (const auto& it : inputs_arg) {
    Target target = Target(it.first);
    CheckAndUpdateHostConsistency(&target, &target_host);
    Optional<String> device = target->GetAttr<String>("device");
    if (device.defined() && device.value() == "vta") {
      target = Target("ext_dev");
    }
    updated_inputs.Set(target, it.second);
  }
  return TIRToRuntime(updated_inputs, target_host);
}

// Build for homogeneous execution.
runtime::Module build(const IRModule& funcs, const Target& target_arg,
                      const Target& target_host_arg) {
  auto target = target_arg, target_host = target_host_arg;
  CheckAndUpdateHostConsistency(&target, &target_host);
  // Wrap the single target and module in the map expected by TIRToRuntime.
  Map<Target, IRModule> inputs = {{target, funcs}};
  return TIRToRuntime(inputs, target_host);
}

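/*!
 * \brief Resolve the VTCM capacity used by VerifyVTCMLimit.
 *        Uses the "vtcm-capacity" attribute of a Hexagon target when it is positive,
 *        otherwise falls back to the "tir.vtcm_capacity" pass config option.
 */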
int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) {
  if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true);
  if (target.defined() && target->kind->name == "hexagon") {
    auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value;
    if (value > 0) return value;
  }
  return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", Integer(0)).value()->value;
}

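/*!
 * \brief Create the Sequential of passes applied to the mixed (host + device) module
 *        before SplitHostDevice separates it into host and device parts.
 */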
transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
  transform::PassContext pass_ctx = transform::PassContext::Current();

  Array<Pass> mixed_pass_list;

  // VerifyVTCMLimit must occur before LowerVtcmAlloc
  mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(GetVTCMCapacity(target, pass_ctx)));
  // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
  mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());

  mixed_pass_list.push_back(tir::transform::BindTarget(target));

  mixed_pass_list.push_back(tir::transform::VerifyMemory());

  if (ShouldAnnotateEntryFunc(mixed_mod)) {
    mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
  }

  bool detect_global_barrier =
      pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value();
  if (detect_global_barrier) {
    mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
  }

  mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
  mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
  mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
  mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
  mixed_pass_list.push_back(tir::transform::InferFragment());
  mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());

  bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();

  if (use_async_copy) {
    mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
  }

  bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
                          .value_or(relay::Executor::Create("graph", {}))
                          ->GetAttr<Bool>("unpacked-api")
                          .value_or(Bool(false));
  if (unpacked_api) {
    mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
  } else {
    mixed_pass_list.push_back(tir::transform::MakePackedAPI());
  }
  mixed_pass_list.push_back(tir::transform::SplitHostDevice());

  return transform::Sequential(mixed_pass_list);
}

TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
    .set_body_typed([](IRModule mixed_mod, Target target) {
      return MixedModulePassManager(mixed_mod, target);
    });

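/*!
 * \brief Create the Sequential of passes applied to the host-side functions, i.e. those
 *        whose calling convention is not kDeviceKernelLaunch.
 */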
transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
  transform::PassContext pass_ctx = transform::PassContext::Current();
  bool enable_debug = pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();

  Array<tvm::transform::Pass> host_pass_list;

  runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
    return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
           CallingConv::kDeviceKernelLaunch;
  };
  host_pass_list.push_back(tir::transform::Filter(fcond));

  ICHECK(mixed_mod.defined()) << "This module must be defined";

  host_pass_list.push_back(tir::transform::BindTarget(target_host));

  host_pass_list.push_back(tir::transform::LowerTVMBuiltin());
  host_pass_list.push_back(tir::transform::LowerCustomDatatypes());
  host_pass_list.push_back(tir::transform::LowerIntrin());
  host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
  host_pass_list.push_back(tir::transform::CombineContextCall());

  if (enable_debug) {
    host_pass_list.push_back(tir::transform::InstallDebugSpans());
  }

  return transform::Sequential(host_pass_list);
}

TVM_REGISTER_GLOBAL("driver.host_mod_passes")
    .set_body_typed([](IRModule mixed_mod, Target target_host) {
      return HostModulePassManager(mixed_mod, target_host);
    });

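/*!
 * \brief Create the Sequential of passes applied to the device-side functions, i.e. those
 *        whose calling convention is kDeviceKernelLaunch.
 */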
transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) {
  Array<Pass> device_pass_list;
  runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
    return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
           CallingConv::kDeviceKernelLaunch;
  };
  device_pass_list.push_back(tir::transform::Filter(fcond));

  device_pass_list.push_back(tir::transform::BindTarget(target));

  device_pass_list.push_back(tir::transform::LowerWarpMemory());
  device_pass_list.push_back(tir::transform::Simplify());
  device_pass_list.push_back(tir::transform::LowerCustomDatatypes());
  device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
  device_pass_list.push_back(tir::transform::LowerIntrin());

  return transform::Sequential(device_pass_list);
}

TVM_REGISTER_GLOBAL("driver.device_mod_passes")
    .set_body_typed([](IRModule mixed_mod, Target target) {
      return DeviceModulePassManager(mixed_mod, target);
    });

}  // namespace tvm