/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/backend/aot/create_executor_metadata.cc
 * \brief Create the ExecutorCodegenMetadata from a compiled IRModule.
 */

#include "./create_executor_metadata.h"

#include "../utils.h"

namespace tvm {
namespace relay {
namespace backend {
namespace aot {

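/*!
 * \brief Collect the metadata the AOT executor codegen needs from a compiled IRModule.
 * \param mod The compiled IRModule; its TIR main function carries the attributes read here.
 * \param mod_name The name of the module.
 * \param executor The executor configuration, which supplies the interface API options.
 * \param workspace_byte_alignment The byte alignment required for workspace allocations.
 * \param constant_byte_alignment The byte alignment required for constant allocations.
 * \return The populated ExecutorCodegenMetadata.
 */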
ExecutorCodegenMetadata CreateExecutorMetadata(const IRModule& mod, String mod_name,
                                               Executor executor, Integer workspace_byte_alignment,
                                               Integer constant_byte_alignment) {
  // Get relevant executor config information
  std::string interface_api = executor->GetAttr<String>("interface-api").value_or("packed");
  bool unpacked_api = executor->GetAttr<Bool>("unpacked-api").value_or(Bool(false));
  // Get the input vars
  auto tir_main_func = Downcast<tir::PrimFunc>(mod->Lookup(runtime::symbol::tvm_module_main));
  Array<tir::Var> inputs = tir_main_func->GetAttr<Array<tir::Var>>("input_vars").value();
  Array<TensorType> input_tensor_types;
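  // The shape and dtype of each input are recovered from its entry in the
  // main function's buffer map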
  for (const auto& input : inputs) {
    auto buffer = tir_main_func->buffer_map.Get(input).value();
    input_tensor_types.push_back(TensorType(buffer->shape, buffer->dtype));
  }
  // Extract USMP metadata to pass onto metadata sources
  Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
  std::vector<tir::Var> pool_vars;
  Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
      tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
  if (allocated_pool_infos) {
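    // Each AllocatedPoolInfo records the index of the main function parameter
    // that carries its pool, so that parameter can be mapped back to the pool info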
    for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) {
      int pool_var_index = allocated_pool_info->pool_var_idx.value()->value;
      pool_vars.push_back(tir_main_func->params[pool_var_index]);
      pool_var_info.Set(tir_main_func->params[pool_var_index], allocated_pool_info);
    }
  }
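  // Collect any USMP placements of the I/O tensors themselves into memory pools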
  Map<String, tir::usmp::PoolAllocation> io_pool_allocations =
      mod->GetAttr<Map<String, tir::usmp::PoolAllocation>>(tvm::attr::kIOTensorPoolAllocations)
          .value_or({});

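  // Get the output vars and their tensor types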
  Array<tir::Var> outputs = tir_main_func->GetAttr<Array<tir::Var>>("output_vars").value();
  Array<TensorType> output_tensor_types;
  std::vector<String> output_var_names;
  for (const auto& output : outputs) {
    auto buffer = tir_main_func->buffer_map.Get(output).value();
    output_tensor_types.push_back(TensorType(buffer->shape, buffer->dtype));
    output_var_names.push_back(output->name_hint);
  }
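  // Get the devices the main function expects, if any were annotated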
  auto devices = tir_main_func->GetAttr<Array<String>>("devices").value_or({});

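  // Bundle everything into the metadata object consumed by the AOT codegen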
  return ExecutorCodegenMetadata(inputs, input_tensor_types, output_var_names, output_tensor_types,
                                 pool_vars, devices, runtime::kTvmExecutorAot, mod_name,
                                 interface_api, unpacked_api, workspace_byte_alignment,
                                 constant_byte_alignment, pool_var_info, io_pool_allocations);
}

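// Expose CreateExecutorMetadata over the FFI so it can be invoked from the Python side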
TVM_REGISTER_GLOBAL("relay.backend.aot.CreateExecutorMetadata")
    .set_body_typed(CreateExecutorMetadata);
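// A minimal usage sketch from Python (assuming `mod` has already been lowered so it
// contains the TIR main function, and using 16-byte alignments as example values):
//   f = tvm.get_global_func("relay.backend.aot.CreateExecutorMetadata")
//   metadata = f(mod, "default", executor, 16, 16)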

}  // namespace aot
}  // namespace backend
}  // namespace relay
}  // namespace tvm