1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file relay/backend/build_module.cc |
22 | * \brief Code generation for TVM's graph executor. |
23 | */ |
24 | #include <tvm/driver/driver_api.h> |
25 | #include <tvm/ir/expr.h> |
26 | #include <tvm/ir/memory_pools.h> |
27 | #include <tvm/relay/analysis.h> |
28 | #include <tvm/relay/executor.h> |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/relay/qnn/transform.h> |
31 | #include <tvm/relay/runtime.h> |
32 | #include <tvm/relay/transform.h> |
33 | #include <tvm/runtime/device_api.h> |
34 | #include <tvm/target/compilation_config.h> |
35 | |
36 | #include <memory> |
37 | |
38 | #include "../../driver/internal_driver_api.h" |
39 | #include "../../target/func_registry_generator.h" |
40 | #include "../../target/metadata_module.h" |
41 | #include "../../target/source/codegen_source_base.h" |
42 | #include "te_compiler.h" |
43 | #include "utils.h" |
44 | |
45 | namespace tvm { |
46 | namespace relay { |
namespace transform {
// Forward declaration of the LabelOps pass, which OptimizeImpl applies after inlining.
// NOTE(review): presumably declared here because none of the included headers exposes it -
// confirm against the transform headers.
Pass LabelOps();
}
50 | namespace backend { |
51 | |
52 | using namespace tvm::relay::transform; |
53 | |
54 | /*! |
55 | * \brief Output of building module |
56 | */ |
57 | struct BuildOutput { |
58 | std::string graph_json; |
59 | runtime::Module mod; |
60 | std::unordered_map<std::string, tvm::runtime::NDArray> params; |
61 | }; |
62 | |
63 | struct ExecutorCodegen { |
64 | void Init(runtime::Module* m, const Array<Target>& raw_targets) { |
65 | CallFunc("init" , m, raw_targets); |
66 | } |
67 | |
68 | void Codegen(IRModule mod, const Function& func, String mod_name) { |
69 | CallFunc("codegen" , mod, func, mod_name); |
70 | } |
71 | |
72 | virtual void UpdateOutput(BuildOutput* ret) = 0; |
73 | |
74 | Map<String, FunctionInfo> GetFunctionMetadata() { |
75 | return CallFunc<Map<String, FunctionInfo>>("get_function_metadata" , nullptr); |
76 | } |
77 | |
78 | std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() { |
79 | std::unordered_map<std::string, tvm::runtime::NDArray> ret; |
80 | auto names = CallFunc<Array<runtime::String>>("list_params_name" , nullptr); |
81 | for (const auto& expr : names) { |
82 | // Implicit cast from runtime::String to std::string |
83 | std::string key = expr; |
84 | ret[key] = CallFunc<runtime::NDArray>("get_param_by_name" , key); |
85 | } |
86 | return ret; |
87 | } |
88 | |
89 | Array<tvm::runtime::Module> GetExternalModules() { |
90 | return CallFunc<Array<tvm::runtime::Module>>("get_external_modules" , nullptr); |
91 | } |
92 | |
93 | Map<Target, IRModule> GetIRModule() { |
94 | return CallFunc<Map<Target, IRModule>>("get_irmodule" , nullptr); |
95 | } |
96 | |
97 | Array<String> ListDevices() { return CallFunc<Array<String>>("get_devices" ); } |
98 | |
99 | relay::backend::ExecutorCodegenMetadata GetExecutorCodegenMetadata() { |
100 | return CallFunc<relay::backend::ExecutorCodegenMetadata>("get_executor_codegen_metadata" ); |
101 | } |
102 | virtual ~ExecutorCodegen() {} |
103 | |
104 | protected: |
105 | tvm::runtime::Module mod; |
106 | template <typename R, typename... Args> |
107 | R CallFunc(const std::string& name, Args... args) { |
108 | auto pf = mod.GetFunction(name, false); |
109 | return pf(std::forward<Args>(args)...); |
110 | } |
111 | template <typename... Args> |
112 | void CallFunc(const std::string& name, Args... args) { |
113 | auto pf = mod.GetFunction(name, false); |
114 | pf(std::forward<Args>(args)...); |
115 | return; |
116 | } |
117 | }; |
118 | |
119 | struct AOTCodegen : ExecutorCodegen { |
120 | AOTCodegen() { |
121 | auto pf = GetPackedFunc("relay.build_module._AOTExecutorCodegen" ); |
122 | mod = (*pf)(); |
123 | } |
124 | |
125 | void UpdateOutput(BuildOutput* ret) override { ret->graph_json = "" ; } |
126 | |
127 | ~AOTCodegen() {} |
128 | }; |
129 | |
130 | /*! |
131 | * \brief GraphCodegen module wrapper |
132 | * |
133 | */ |
134 | struct GraphCodegen : ExecutorCodegen { |
135 | GraphCodegen() { |
136 | auto pf = GetPackedFunc("relay.build_module._GraphExecutorCodegen" ); |
137 | mod = (*pf)(); |
138 | } |
139 | void UpdateOutput(BuildOutput* ret) override { ret->graph_json = GetGraphJSON(); } |
140 | |
141 | std::string GetGraphJSON() { return CallFunc<std::string>("get_graph_json" , nullptr); } |
142 | |
143 | ~GraphCodegen() {} |
144 | }; |
145 | |
146 | /*! |
147 | * \brief Executor codegen factory function |
148 | */ |
149 | std::unique_ptr<ExecutorCodegen> MakeExecutorCodegen(String executor_str) { |
150 | std::unique_ptr<ExecutorCodegen> ret; |
151 | if (executor_str == runtime::kTvmExecutorGraph) { |
152 | ret = std::make_unique<GraphCodegen>(); |
153 | } else if (executor_str == runtime::kTvmExecutorAot) { |
154 | ret = std::make_unique<AOTCodegen>(); |
155 | } else { |
156 | CHECK(false) << "Executor " << executor_str << " not supported" ; |
157 | } |
158 | return ret; |
159 | } |
160 | |
161 | /*! |
162 | * \brief Relay build module |
163 | * |
164 | */ |
165 | class RelayBuildModule : public runtime::ModuleNode { |
166 | public: |
167 | RelayBuildModule() = default; |
168 | |
169 | /*! |
170 | * \brief Get member function to front-end |
171 | * \param name The name of the function. |
172 | * \param sptr_to_self The pointer to the module node. |
173 | * \return The corresponding member function. |
174 | */ |
175 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { |
176 | if (name == "get_graph_json" ) { |
177 | return PackedFunc( |
178 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); }); |
179 | } else if (name == "get_module" ) { |
180 | return PackedFunc( |
181 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); |
182 | } else if (name == "build" ) { |
183 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
184 | ICHECK_EQ(args.num_args, 8); |
185 | this->Build(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]); |
186 | }); |
187 | } else if (name == "list_params" ) { |
188 | return PackedFunc( |
189 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->ListParamNames(); }); |
190 | } else if (name == "get_params" ) { |
191 | return PackedFunc( |
192 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); }); |
193 | } else if (name == "set_params" ) { |
194 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
195 | Map<String, Constant> params = args[0]; |
196 | for (const auto& kv : params) { |
197 | this->SetParam(kv.first, kv.second->data); |
198 | } |
199 | }); |
200 | } else if (name == "get_devices" ) { |
201 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
202 | *rv = this->executor_codegen_->ListDevices(); |
203 | }); |
204 | } else if (name == "get_irmodule" ) { |
205 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
206 | *rv = this->executor_codegen_->GetIRModule(); |
207 | }); |
208 | } else if (name == "get_external_modules" ) { |
209 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
210 | *rv = this->executor_codegen_->GetExternalModules(); |
211 | }); |
212 | } else if (name == "get_function_metadata" ) { |
213 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
214 | *rv = this->executor_codegen_->GetFunctionMetadata(); |
215 | }); |
216 | } else if (name == "get_executor_codegen_metadata" ) { |
217 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
218 | *rv = this->executor_codegen_->GetExecutorCodegenMetadata(); |
219 | }); |
220 | } else if (name == "optimize" ) { |
221 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
222 | ICHECK_EQ(args.num_args, 2); |
223 | *rv = this->Optimize(args[0], args[1]); |
224 | }); |
225 | } else { |
226 | LOG(FATAL) << "Unknown packed function: " << name; |
227 | return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); |
228 | } |
229 | } |
230 | |
231 | /*! |
232 | * \brief Get the GraphJSON for runtime |
233 | * |
234 | * \return const std::string graph_json |
235 | */ |
236 | const std::string& GetGraphJSON() { return ret_.graph_json; } |
237 | |
238 | /*! |
239 | * \brief Get the Module object |
240 | * |
241 | * \return runtime::Module |
242 | */ |
243 | runtime::Module GetModule() { return ret_.mod; } |
244 | |
245 | /*! |
246 | * \brief List all paramter names |
247 | * |
248 | * \return Array<runtime::String> names of params |
249 | */ |
250 | Array<runtime::String> ListParamNames() { |
251 | Array<runtime::String> ret; |
252 | for (const auto& kv : params_) { |
253 | ret.push_back(kv.first); |
254 | } |
255 | return ret; |
256 | } |
257 | |
258 | /*! |
259 | * \brief Get params dictionary |
260 | * |
261 | * \return Map<String, Constant> params dictionary |
262 | */ |
263 | Map<String, Constant> GetParams() { |
264 | Map<String, Constant> ret; |
265 | for (const auto& kv : ret_.params) { |
266 | ret.Set(kv.first, Constant(kv.second)); |
267 | } |
268 | return ret; |
269 | } |
270 | |
271 | /*! |
272 | * \brief Set the parameters |
273 | * |
274 | * \param name name of parameter |
275 | * \param data_in input DLTensor |
276 | */ |
277 | void SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } |
278 | |
279 | /*! |
280 | * \brief type key |
281 | * |
282 | * \return const char* |
283 | */ |
284 | const char* type_key() const final { return "RelayBuildModule" ; } |
285 | |
286 | /*! |
287 | * \brief Build relay IRModule for graph executor |
288 | * |
289 | * \param mod Relay IRModule |
290 | * \param raw_targets List of available targets for kernels. |
291 | * \param executor Executor to target |
292 | * \param runtime Runtime to codegen for |
293 | * \param mod_name Name of the module |
294 | */ |
295 | void Build(IRModule mod, const Array<Target>& raw_targets, const tvm::Target& target_host, |
296 | const Executor& executor, const Runtime& runtime, |
297 | const WorkspaceMemoryPools& workspace_memory_pools, |
298 | const ConstantMemoryPools& constant_memory_pools, const String mod_name) { |
299 | VLOG_CONTEXT << "Build" ; |
300 | executor_ = executor; |
301 | runtime_ = runtime; |
302 | workspace_memory_pools_ = workspace_memory_pools; |
303 | constant_memory_pools_ = constant_memory_pools; |
304 | config_ = CompilationConfig(PassContext::Current(), raw_targets); |
305 | VLOG(1) << "Using compilation config:" << std::endl << config_; |
306 | BuildRelay(std::move(mod), mod_name); |
307 | } |
308 | |
309 | protected: |
310 | /*! |
311 | * \brief Optimize a Relay IRModule. |
312 | * |
313 | * \param relay_module The input IRModule where optmization will be applied on. |
314 | * \param raw_targets List of available targets for kernels. |
315 | * |
316 | * \return relay::IRModule The updated Relay IR module after optimization. |
317 | */ |
318 | IRModule Optimize(IRModule relay_module, const Array<Target>& raw_targets) { |
319 | VLOG_CONTEXT << "Optimize" ; |
320 | config_ = CompilationConfig(PassContext ::Current(), raw_targets); |
321 | VLOG(1) << "Using compilation config:" << std::endl << config_; |
322 | return OptimizeImpl(std::move(relay_module)); |
323 | } |
324 | |
325 | IRModule OptimizeImpl(IRModule relay_module) { |
326 | ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler." ; |
327 | |
328 | backend::BindParamsInModule(relay_module, params_); |
329 | |
330 | Array<Pass> pass_seqs = |
331 | GetPassPrefix(/*is_homogenous=*/config_->primitive_targets.size() == 1, /*is_vm=*/false); |
332 | transform::PassContext pass_ctx = PassContext::Current(); |
333 | |
334 | if (config_->optional_homogeneous_target.defined()) { |
335 | // This pass currently only supports the homogeneous case. |
336 | pass_seqs.push_back(transform::SplitArgs( |
337 | config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args" , -1) |
338 | .value() |
339 | .IntValue())); |
340 | } |
341 | |
342 | // Always plan devices so the remaining passes don't need to distinguish homogeneous vs |
343 | // hetrogenous execution. |
344 | pass_seqs.push_back(transform::PlanDevices(config_)); |
345 | |
346 | // Fuse the operations if it is needed. |
347 | pass_seqs.push_back(transform::FuseOps()); |
348 | |
349 | // Create a sequential pass and perform optimizations. |
350 | transform::Pass seq = transform::Sequential(pass_seqs); |
351 | if (config_->optional_homogeneous_target.defined()) { |
352 | With<Target> tctx(config_->optional_homogeneous_target); |
353 | relay_module = seq(relay_module); |
354 | } else { |
355 | relay_module = seq(relay_module); |
356 | } |
357 | |
358 | // Do layout rewrite for auto-scheduler. |
359 | if (backend::IsAutoSchedulerEnabled() && config_->optional_homogeneous_target.defined()) { |
360 | Pass major_pass = transform::AutoSchedulerLayoutRewrite(); |
361 | bool enable_layout_rewrite_targets = |
362 | config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU || |
363 | config_->optional_homogeneous_target->GetAttr<String>("device" , "" ) == "mali" ; |
364 | if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) { |
365 | With<Target> tctx(config_->optional_homogeneous_target); |
366 | relay_module = major_pass(relay_module); |
367 | // Defuse ops to fold constants, then fuse them again |
368 | relay_module = transform::DefuseOps()(relay_module); |
369 | relay_module = transform::FoldConstant()(relay_module); |
370 | relay_module = transform::FuseOps()(relay_module); |
371 | } |
372 | } |
373 | if (backend::IsMetaScheduleEnabled() && config_->optional_homogeneous_target.defined()) { |
374 | Pass major_pass = transform::MetaScheduleLayoutRewrite(); |
375 | bool enable_layout_rewrite_targets = |
376 | config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU || |
377 | config_->optional_homogeneous_target->GetAttr<String>("device" , "" ) == "mali" ; |
378 | if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) { |
379 | With<Target> tctx(config_->optional_homogeneous_target); |
380 | relay_module = major_pass(relay_module); |
381 | // Defuse ops to fold constants, then fuse them again |
382 | relay_module = transform::DefuseOps()(relay_module); |
383 | relay_module = transform::FoldConstant()(relay_module); |
384 | relay_module = transform::FuseOps()(relay_module); |
385 | } |
386 | } |
387 | |
388 | relay_module = transform::InferType()(relay_module); |
389 | |
390 | // Inline the functions that have been lifted by the module scope. |
391 | // |
392 | // TODO(@zhiics) Note that we need to be careful about the subgraphs with |
393 | // global function calls. We should make sure that these callees are also |
394 | // inline functions. However, this should be very unlikely for accelerators |
395 | // and vendor-provided libraries. So we don't handle for now. |
396 | relay_module = transform::Inline()(relay_module); |
397 | relay_module = transform::InferType()(relay_module); |
398 | relay_module = transform::LabelOps()(relay_module); |
399 | relay_module = transform::AnnotateMemoryScope()(relay_module); |
400 | |
401 | ICHECK(relay_module.defined()); |
402 | |
403 | return relay_module; |
404 | } |
405 | |
406 | /*! |
407 | * \brief Compile a Relay IR module to runtime module. |
408 | * |
409 | * \param relay_module The Relay IR module. |
410 | * \param params The parameters. |
411 | */ |
412 | void BuildRelay(IRModule relay_module, const String& mod_name) { |
413 | // Relay IRModule -> IRModule optimizations. |
414 | IRModule module = WithAttrs( |
415 | relay_module, {{tvm::attr::kExecutor, executor_}, {tvm::attr::kRuntime, runtime_}}); |
416 | relay_module = OptimizeImpl(std::move(module)); |
417 | |
418 | // Get the updated function and new IRModule to build. |
419 | // Instead of recreating the IRModule, we should look at the differences between this and the |
420 | // incoming IRModule to see if we can just pass (IRModule, Function) to the code generator. |
421 | Function func = Downcast<Function>(relay_module->Lookup("main" )); |
422 | IRModule func_module = WithAttrs(IRModule::FromExpr(func), |
423 | {{tvm::attr::kExecutor, executor_}, |
424 | {tvm::attr::kRuntime, runtime_}, |
425 | {tvm::attr::kWorkspaceMemoryPools, workspace_memory_pools_}, |
426 | {tvm::attr::kConstantMemoryPools, constant_memory_pools_}}); |
427 | |
428 | // Generate code for the updated function. |
429 | executor_codegen_ = MakeExecutorCodegen(executor_->name); |
430 | executor_codegen_->Init(nullptr, config_->primitive_targets); |
431 | executor_codegen_->Codegen(func_module, func, mod_name); |
432 | executor_codegen_->UpdateOutput(&ret_); |
433 | ret_.params = executor_codegen_->GetParams(); |
434 | |
435 | auto lowered_funcs = executor_codegen_->GetIRModule(); |
436 | |
437 | // No need to build for external functions. |
438 | Target ext_dev("ext_dev" ); |
439 | if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) { |
440 | lowered_funcs.Set(ext_dev, IRModule()); |
441 | } |
442 | |
443 | const Target& host_target = config_->host_virtual_device->target; |
444 | const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate" ); |
445 | // When there is no lowered_funcs due to reasons such as optimization. |
446 | if (lowered_funcs.size() == 0) { |
447 | if (host_target->kind->name == "llvm" ) { |
448 | CHECK(pf != nullptr) << "Unable to create empty module for llvm without llvm codegen." ; |
449 | // If we can decide the target is LLVM, we then create an empty LLVM module. |
450 | ret_.mod = (*pf)(host_target->str(), "empty_module" ); |
451 | } else { |
452 | // If we cannot decide the target is LLVM, we create an empty CSourceModule. |
453 | // The code content is initialized with ";" to prevent complaining |
454 | // from CSourceModuleNode::SaveToFile. |
455 | ret_.mod = tvm::codegen::CSourceModuleCreate(";" , "" , Array<String>{}); |
456 | } |
457 | } else { |
458 | ret_.mod = tvm::TIRToRuntime(lowered_funcs, host_target); |
459 | } |
460 | |
461 | auto ext_mods = executor_codegen_->GetExternalModules(); |
462 | ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target, |
463 | runtime_, executor_, |
464 | executor_codegen_->GetExecutorCodegenMetadata()); |
465 | // Remove external params which were stored in metadata module. |
466 | for (tvm::runtime::Module mod : ext_mods) { |
467 | auto pf_var = mod.GetFunction("get_const_vars" ); |
468 | if (pf_var != nullptr) { |
469 | Array<String> variables = pf_var(); |
470 | for (size_t i = 0; i < variables.size(); i++) { |
471 | auto it = ret_.params.find(variables[i].operator std::string()); |
472 | if (it != ret_.params.end()) { |
473 | VLOG(1) << "constant '" << variables[i] << "' has been captured in external module" ; |
474 | ret_.params.erase(it); |
475 | } |
476 | } |
477 | } |
478 | } |
479 | } |
480 | |
481 | protected: |
482 | std::unique_ptr<ExecutorCodegen> executor_codegen_; |
483 | /*! \brief Executor to build for */ |
484 | Executor executor_; |
485 | /*! \brief Runtime to codegen for */ |
486 | Runtime runtime_; |
487 | /*! \brief Workspace memory pools to codegen for */ |
488 | WorkspaceMemoryPools workspace_memory_pools_; |
489 | /*! \brief Constant memory pools to codegen for */ |
490 | ConstantMemoryPools constant_memory_pools_; |
491 | /*! \brief parameters */ |
492 | std::unordered_map<std::string, runtime::NDArray> params_; |
493 | /*! \brief building output */ |
494 | BuildOutput ret_; |
495 | /*! \brief Collects all the targets and scopes we need during compilation. */ |
496 | CompilationConfig config_; |
497 | }; |
498 | |
499 | runtime::Module RelayBuildCreate() { |
500 | auto exec = make_object<RelayBuildModule>(); |
501 | return runtime::Module(exec); |
502 | } |
503 | |
504 | TVM_REGISTER_GLOBAL("relay.build_module._BuildModule" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
505 | *rv = RelayBuildCreate(); |
506 | }); |
507 | |
508 | TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName" ) |
509 | .set_body([](TVMArgs args, TVMRetValue* rv) { |
510 | Map<String, Constant> params = args[1]; |
511 | std::unordered_map<std::string, runtime::NDArray> params_; |
512 | for (const auto& kv : params) { |
513 | params_[kv.first] = kv.second->data; |
514 | } |
515 | *rv = relay::backend::BindParamsByName(args[0], params_); |
516 | }); |
517 | |
518 | } // namespace backend |
519 | } // namespace relay |
520 | } // namespace tvm |
521 | |