1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file relay/backend/build_module.cc
 * \brief Optimization and code generation of Relay modules for TVM's graph and AOT executors.
23 */
24#include <tvm/driver/driver_api.h>
25#include <tvm/ir/expr.h>
26#include <tvm/ir/memory_pools.h>
27#include <tvm/relay/analysis.h>
28#include <tvm/relay/executor.h>
29#include <tvm/relay/expr.h>
30#include <tvm/relay/qnn/transform.h>
31#include <tvm/relay/runtime.h>
32#include <tvm/relay/transform.h>
33#include <tvm/runtime/device_api.h>
34#include <tvm/target/compilation_config.h>
35
36#include <memory>
37
38#include "../../driver/internal_driver_api.h"
39#include "../../target/func_registry_generator.h"
40#include "../../target/metadata_module.h"
41#include "../../target/source/codegen_source_base.h"
42#include "te_compiler.h"
43#include "utils.h"
44
45namespace tvm {
46namespace relay {
namespace transform {
/*!
 * \brief Forward declaration of the LabelOps pass.
 *
 * Only declared here; it is applied near the end of RelayBuildModule::OptimizeImpl.
 */
Pass LabelOps();
}  // namespace transform
50namespace backend {
51
52using namespace tvm::relay::transform;
53
/*!
 * \brief Output of building a module.
 *
 * Populated by RelayBuildModule::BuildRelay and read back through the module's packed functions.
 */
struct BuildOutput {
  /*! \brief Graph JSON from the executor codegen; AOTCodegen leaves it empty. */
  std::string graph_json;
  /*! \brief The compiled runtime module (wrapped by the metadata module after codegen). */
  runtime::Module mod;
  /*!
   * \brief Parameters keyed by name, as reported by the executor codegen.
   *
   * Entries captured as constants by external modules are removed after codegen.
   */
  std::unordered_map<std::string, tvm::runtime::NDArray> params;
};
62
63struct ExecutorCodegen {
64 void Init(runtime::Module* m, const Array<Target>& raw_targets) {
65 CallFunc("init", m, raw_targets);
66 }
67
68 void Codegen(IRModule mod, const Function& func, String mod_name) {
69 CallFunc("codegen", mod, func, mod_name);
70 }
71
72 virtual void UpdateOutput(BuildOutput* ret) = 0;
73
74 Map<String, FunctionInfo> GetFunctionMetadata() {
75 return CallFunc<Map<String, FunctionInfo>>("get_function_metadata", nullptr);
76 }
77
78 std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
79 std::unordered_map<std::string, tvm::runtime::NDArray> ret;
80 auto names = CallFunc<Array<runtime::String>>("list_params_name", nullptr);
81 for (const auto& expr : names) {
82 // Implicit cast from runtime::String to std::string
83 std::string key = expr;
84 ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
85 }
86 return ret;
87 }
88
89 Array<tvm::runtime::Module> GetExternalModules() {
90 return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
91 }
92
93 Map<Target, IRModule> GetIRModule() {
94 return CallFunc<Map<Target, IRModule>>("get_irmodule", nullptr);
95 }
96
97 Array<String> ListDevices() { return CallFunc<Array<String>>("get_devices"); }
98
99 relay::backend::ExecutorCodegenMetadata GetExecutorCodegenMetadata() {
100 return CallFunc<relay::backend::ExecutorCodegenMetadata>("get_executor_codegen_metadata");
101 }
102 virtual ~ExecutorCodegen() {}
103
104 protected:
105 tvm::runtime::Module mod;
106 template <typename R, typename... Args>
107 R CallFunc(const std::string& name, Args... args) {
108 auto pf = mod.GetFunction(name, false);
109 return pf(std::forward<Args>(args)...);
110 }
111 template <typename... Args>
112 void CallFunc(const std::string& name, Args... args) {
113 auto pf = mod.GetFunction(name, false);
114 pf(std::forward<Args>(args)...);
115 return;
116 }
117};
118
119struct AOTCodegen : ExecutorCodegen {
120 AOTCodegen() {
121 auto pf = GetPackedFunc("relay.build_module._AOTExecutorCodegen");
122 mod = (*pf)();
123 }
124
125 void UpdateOutput(BuildOutput* ret) override { ret->graph_json = ""; }
126
127 ~AOTCodegen() {}
128};
129
130/*!
131 * \brief GraphCodegen module wrapper
132 *
133 */
134struct GraphCodegen : ExecutorCodegen {
135 GraphCodegen() {
136 auto pf = GetPackedFunc("relay.build_module._GraphExecutorCodegen");
137 mod = (*pf)();
138 }
139 void UpdateOutput(BuildOutput* ret) override { ret->graph_json = GetGraphJSON(); }
140
141 std::string GetGraphJSON() { return CallFunc<std::string>("get_graph_json", nullptr); }
142
143 ~GraphCodegen() {}
144};
145
146/*!
147 * \brief Executor codegen factory function
148 */
149std::unique_ptr<ExecutorCodegen> MakeExecutorCodegen(String executor_str) {
150 std::unique_ptr<ExecutorCodegen> ret;
151 if (executor_str == runtime::kTvmExecutorGraph) {
152 ret = std::make_unique<GraphCodegen>();
153 } else if (executor_str == runtime::kTvmExecutorAot) {
154 ret = std::make_unique<AOTCodegen>();
155 } else {
156 CHECK(false) << "Executor " << executor_str << " not supported";
157 }
158 return ret;
159}
160
161/*!
162 * \brief Relay build module
163 *
164 */
165class RelayBuildModule : public runtime::ModuleNode {
166 public:
167 RelayBuildModule() = default;
168
169 /*!
170 * \brief Get member function to front-end
171 * \param name The name of the function.
172 * \param sptr_to_self The pointer to the module node.
173 * \return The corresponding member function.
174 */
175 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
176 if (name == "get_graph_json") {
177 return PackedFunc(
178 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); });
179 } else if (name == "get_module") {
180 return PackedFunc(
181 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); });
182 } else if (name == "build") {
183 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
184 ICHECK_EQ(args.num_args, 8);
185 this->Build(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
186 });
187 } else if (name == "list_params") {
188 return PackedFunc(
189 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->ListParamNames(); });
190 } else if (name == "get_params") {
191 return PackedFunc(
192 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); });
193 } else if (name == "set_params") {
194 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
195 Map<String, Constant> params = args[0];
196 for (const auto& kv : params) {
197 this->SetParam(kv.first, kv.second->data);
198 }
199 });
200 } else if (name == "get_devices") {
201 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
202 *rv = this->executor_codegen_->ListDevices();
203 });
204 } else if (name == "get_irmodule") {
205 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
206 *rv = this->executor_codegen_->GetIRModule();
207 });
208 } else if (name == "get_external_modules") {
209 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
210 *rv = this->executor_codegen_->GetExternalModules();
211 });
212 } else if (name == "get_function_metadata") {
213 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
214 *rv = this->executor_codegen_->GetFunctionMetadata();
215 });
216 } else if (name == "get_executor_codegen_metadata") {
217 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
218 *rv = this->executor_codegen_->GetExecutorCodegenMetadata();
219 });
220 } else if (name == "optimize") {
221 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
222 ICHECK_EQ(args.num_args, 2);
223 *rv = this->Optimize(args[0], args[1]);
224 });
225 } else {
226 LOG(FATAL) << "Unknown packed function: " << name;
227 return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
228 }
229 }
230
231 /*!
232 * \brief Get the GraphJSON for runtime
233 *
234 * \return const std::string graph_json
235 */
236 const std::string& GetGraphJSON() { return ret_.graph_json; }
237
238 /*!
239 * \brief Get the Module object
240 *
241 * \return runtime::Module
242 */
243 runtime::Module GetModule() { return ret_.mod; }
244
245 /*!
246 * \brief List all paramter names
247 *
248 * \return Array<runtime::String> names of params
249 */
250 Array<runtime::String> ListParamNames() {
251 Array<runtime::String> ret;
252 for (const auto& kv : params_) {
253 ret.push_back(kv.first);
254 }
255 return ret;
256 }
257
258 /*!
259 * \brief Get params dictionary
260 *
261 * \return Map<String, Constant> params dictionary
262 */
263 Map<String, Constant> GetParams() {
264 Map<String, Constant> ret;
265 for (const auto& kv : ret_.params) {
266 ret.Set(kv.first, Constant(kv.second));
267 }
268 return ret;
269 }
270
271 /*!
272 * \brief Set the parameters
273 *
274 * \param name name of parameter
275 * \param data_in input DLTensor
276 */
277 void SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; }
278
279 /*!
280 * \brief type key
281 *
282 * \return const char*
283 */
284 const char* type_key() const final { return "RelayBuildModule"; }
285
286 /*!
287 * \brief Build relay IRModule for graph executor
288 *
289 * \param mod Relay IRModule
290 * \param raw_targets List of available targets for kernels.
291 * \param executor Executor to target
292 * \param runtime Runtime to codegen for
293 * \param mod_name Name of the module
294 */
295 void Build(IRModule mod, const Array<Target>& raw_targets, const tvm::Target& target_host,
296 const Executor& executor, const Runtime& runtime,
297 const WorkspaceMemoryPools& workspace_memory_pools,
298 const ConstantMemoryPools& constant_memory_pools, const String mod_name) {
299 VLOG_CONTEXT << "Build";
300 executor_ = executor;
301 runtime_ = runtime;
302 workspace_memory_pools_ = workspace_memory_pools;
303 constant_memory_pools_ = constant_memory_pools;
304 config_ = CompilationConfig(PassContext::Current(), raw_targets);
305 VLOG(1) << "Using compilation config:" << std::endl << config_;
306 BuildRelay(std::move(mod), mod_name);
307 }
308
309 protected:
310 /*!
311 * \brief Optimize a Relay IRModule.
312 *
313 * \param relay_module The input IRModule where optmization will be applied on.
314 * \param raw_targets List of available targets for kernels.
315 *
316 * \return relay::IRModule The updated Relay IR module after optimization.
317 */
318 IRModule Optimize(IRModule relay_module, const Array<Target>& raw_targets) {
319 VLOG_CONTEXT << "Optimize";
320 config_ = CompilationConfig(PassContext ::Current(), raw_targets);
321 VLOG(1) << "Using compilation config:" << std::endl << config_;
322 return OptimizeImpl(std::move(relay_module));
323 }
324
325 IRModule OptimizeImpl(IRModule relay_module) {
326 ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler.";
327
328 backend::BindParamsInModule(relay_module, params_);
329
330 Array<Pass> pass_seqs =
331 GetPassPrefix(/*is_homogenous=*/config_->primitive_targets.size() == 1, /*is_vm=*/false);
332 transform::PassContext pass_ctx = PassContext::Current();
333
334 if (config_->optional_homogeneous_target.defined()) {
335 // This pass currently only supports the homogeneous case.
336 pass_seqs.push_back(transform::SplitArgs(
337 config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", -1)
338 .value()
339 .IntValue()));
340 }
341
342 // Always plan devices so the remaining passes don't need to distinguish homogeneous vs
343 // hetrogenous execution.
344 pass_seqs.push_back(transform::PlanDevices(config_));
345
346 // Fuse the operations if it is needed.
347 pass_seqs.push_back(transform::FuseOps());
348
349 // Create a sequential pass and perform optimizations.
350 transform::Pass seq = transform::Sequential(pass_seqs);
351 if (config_->optional_homogeneous_target.defined()) {
352 With<Target> tctx(config_->optional_homogeneous_target);
353 relay_module = seq(relay_module);
354 } else {
355 relay_module = seq(relay_module);
356 }
357
358 // Do layout rewrite for auto-scheduler.
359 if (backend::IsAutoSchedulerEnabled() && config_->optional_homogeneous_target.defined()) {
360 Pass major_pass = transform::AutoSchedulerLayoutRewrite();
361 bool enable_layout_rewrite_targets =
362 config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU ||
363 config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
364 if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
365 With<Target> tctx(config_->optional_homogeneous_target);
366 relay_module = major_pass(relay_module);
367 // Defuse ops to fold constants, then fuse them again
368 relay_module = transform::DefuseOps()(relay_module);
369 relay_module = transform::FoldConstant()(relay_module);
370 relay_module = transform::FuseOps()(relay_module);
371 }
372 }
373 if (backend::IsMetaScheduleEnabled() && config_->optional_homogeneous_target.defined()) {
374 Pass major_pass = transform::MetaScheduleLayoutRewrite();
375 bool enable_layout_rewrite_targets =
376 config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU ||
377 config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
378 if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
379 With<Target> tctx(config_->optional_homogeneous_target);
380 relay_module = major_pass(relay_module);
381 // Defuse ops to fold constants, then fuse them again
382 relay_module = transform::DefuseOps()(relay_module);
383 relay_module = transform::FoldConstant()(relay_module);
384 relay_module = transform::FuseOps()(relay_module);
385 }
386 }
387
388 relay_module = transform::InferType()(relay_module);
389
390 // Inline the functions that have been lifted by the module scope.
391 //
392 // TODO(@zhiics) Note that we need to be careful about the subgraphs with
393 // global function calls. We should make sure that these callees are also
394 // inline functions. However, this should be very unlikely for accelerators
395 // and vendor-provided libraries. So we don't handle for now.
396 relay_module = transform::Inline()(relay_module);
397 relay_module = transform::InferType()(relay_module);
398 relay_module = transform::LabelOps()(relay_module);
399 relay_module = transform::AnnotateMemoryScope()(relay_module);
400
401 ICHECK(relay_module.defined());
402
403 return relay_module;
404 }
405
406 /*!
407 * \brief Compile a Relay IR module to runtime module.
408 *
409 * \param relay_module The Relay IR module.
410 * \param params The parameters.
411 */
412 void BuildRelay(IRModule relay_module, const String& mod_name) {
413 // Relay IRModule -> IRModule optimizations.
414 IRModule module = WithAttrs(
415 relay_module, {{tvm::attr::kExecutor, executor_}, {tvm::attr::kRuntime, runtime_}});
416 relay_module = OptimizeImpl(std::move(module));
417
418 // Get the updated function and new IRModule to build.
419 // Instead of recreating the IRModule, we should look at the differences between this and the
420 // incoming IRModule to see if we can just pass (IRModule, Function) to the code generator.
421 Function func = Downcast<Function>(relay_module->Lookup("main"));
422 IRModule func_module = WithAttrs(IRModule::FromExpr(func),
423 {{tvm::attr::kExecutor, executor_},
424 {tvm::attr::kRuntime, runtime_},
425 {tvm::attr::kWorkspaceMemoryPools, workspace_memory_pools_},
426 {tvm::attr::kConstantMemoryPools, constant_memory_pools_}});
427
428 // Generate code for the updated function.
429 executor_codegen_ = MakeExecutorCodegen(executor_->name);
430 executor_codegen_->Init(nullptr, config_->primitive_targets);
431 executor_codegen_->Codegen(func_module, func, mod_name);
432 executor_codegen_->UpdateOutput(&ret_);
433 ret_.params = executor_codegen_->GetParams();
434
435 auto lowered_funcs = executor_codegen_->GetIRModule();
436
437 // No need to build for external functions.
438 Target ext_dev("ext_dev");
439 if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) {
440 lowered_funcs.Set(ext_dev, IRModule());
441 }
442
443 const Target& host_target = config_->host_virtual_device->target;
444 const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
445 // When there is no lowered_funcs due to reasons such as optimization.
446 if (lowered_funcs.size() == 0) {
447 if (host_target->kind->name == "llvm") {
448 CHECK(pf != nullptr) << "Unable to create empty module for llvm without llvm codegen.";
449 // If we can decide the target is LLVM, we then create an empty LLVM module.
450 ret_.mod = (*pf)(host_target->str(), "empty_module");
451 } else {
452 // If we cannot decide the target is LLVM, we create an empty CSourceModule.
453 // The code content is initialized with ";" to prevent complaining
454 // from CSourceModuleNode::SaveToFile.
455 ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array<String>{});
456 }
457 } else {
458 ret_.mod = tvm::TIRToRuntime(lowered_funcs, host_target);
459 }
460
461 auto ext_mods = executor_codegen_->GetExternalModules();
462 ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target,
463 runtime_, executor_,
464 executor_codegen_->GetExecutorCodegenMetadata());
465 // Remove external params which were stored in metadata module.
466 for (tvm::runtime::Module mod : ext_mods) {
467 auto pf_var = mod.GetFunction("get_const_vars");
468 if (pf_var != nullptr) {
469 Array<String> variables = pf_var();
470 for (size_t i = 0; i < variables.size(); i++) {
471 auto it = ret_.params.find(variables[i].operator std::string());
472 if (it != ret_.params.end()) {
473 VLOG(1) << "constant '" << variables[i] << "' has been captured in external module";
474 ret_.params.erase(it);
475 }
476 }
477 }
478 }
479 }
480
481 protected:
482 std::unique_ptr<ExecutorCodegen> executor_codegen_;
483 /*! \brief Executor to build for */
484 Executor executor_;
485 /*! \brief Runtime to codegen for */
486 Runtime runtime_;
487 /*! \brief Workspace memory pools to codegen for */
488 WorkspaceMemoryPools workspace_memory_pools_;
489 /*! \brief Constant memory pools to codegen for */
490 ConstantMemoryPools constant_memory_pools_;
491 /*! \brief parameters */
492 std::unordered_map<std::string, runtime::NDArray> params_;
493 /*! \brief building output */
494 BuildOutput ret_;
495 /*! \brief Collects all the targets and scopes we need during compilation. */
496 CompilationConfig config_;
497};
498
499runtime::Module RelayBuildCreate() {
500 auto exec = make_object<RelayBuildModule>();
501 return runtime::Module(exec);
502}
503
504TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
505 *rv = RelayBuildCreate();
506});
507
508TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName")
509 .set_body([](TVMArgs args, TVMRetValue* rv) {
510 Map<String, Constant> params = args[1];
511 std::unordered_map<std::string, runtime::NDArray> params_;
512 for (const auto& kv : params) {
513 params_[kv.first] = kv.second->data;
514 }
515 *rv = relay::backend::BindParamsByName(args[0], params_);
516 });
517
518} // namespace backend
519} // namespace relay
520} // namespace tvm
521