1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/backend/aot/create_executor_metadata.cc |
22 | * \brief Create the ExecutorCodegenMetadata from a compiled IRModule. |
23 | */ |
24 | |
25 | #include "./create_executor_metadata.h" |
26 | |
27 | #include "../utils.h" |
28 | |
29 | namespace tvm { |
30 | namespace relay { |
31 | namespace backend { |
32 | namespace aot { |
33 | |
34 | ExecutorCodegenMetadata CreateExecutorMetadata(const IRModule& mod, String mod_name, |
35 | Executor executor, Integer workspace_byte_alignment, |
36 | Integer constant_byte_alignment) { |
37 | // Get relevant executor config information |
38 | std::string interface_api = executor->GetAttr<String>("interface-api" ).value_or("packed" ); |
39 | bool unpacked_api = executor->GetAttr<Bool>("unpacked-api" ).value_or(Bool(false)); |
40 | // Get the input vars |
41 | auto tir_main_func = Downcast<tir::PrimFunc>(mod->Lookup(runtime::symbol::tvm_module_main)); |
42 | Array<tir::Var> inputs = tir_main_func->GetAttr<Array<tir::Var>>("input_vars" ).value(); |
43 | Array<TensorType> input_tensor_types; |
44 | for (const auto& input : inputs) { |
45 | auto buffer = tir_main_func->buffer_map.Get(input).value(); |
46 | input_tensor_types.push_back(TensorType(buffer->shape, buffer->dtype)); |
47 | } |
48 | // Extract USMP metadata to pass onto metadata sources |
49 | Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info; |
50 | std::vector<tir::Var> pool_vars; |
51 | Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos = |
52 | tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs); |
53 | if (allocated_pool_infos) { |
54 | for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { |
55 | int pool_var_index = allocated_pool_info->pool_var_idx.value()->value; |
56 | pool_vars.push_back(tir_main_func->params[pool_var_index]); |
57 | pool_var_info.Set(tir_main_func->params[pool_var_index], allocated_pool_info); |
58 | } |
59 | } |
60 | Map<String, tir::usmp::PoolAllocation> io_pool_allocations = |
61 | mod->GetAttr<Map<String, tir::usmp::PoolAllocation>>(tvm::attr::kIOTensorPoolAllocations) |
62 | .value_or({}); |
63 | |
64 | Array<tir::Var> outputs = tir_main_func->GetAttr<Array<tir::Var>>("output_vars" ).value(); |
65 | Array<TensorType> output_tensor_types; |
66 | std::vector<String> output_var_names; |
67 | for (const auto& output : outputs) { |
68 | auto buffer = tir_main_func->buffer_map.Get(output).value(); |
69 | output_tensor_types.push_back(TensorType(buffer->shape, buffer->dtype)); |
70 | output_var_names.push_back(output->name_hint); |
71 | } |
72 | auto devices = tir_main_func->GetAttr<Array<String>>("devices" ).value_or({}); |
73 | |
74 | return ExecutorCodegenMetadata(inputs, input_tensor_types, output_var_names, output_tensor_types, |
75 | pool_vars, devices, runtime::kTvmExecutorAot, mod_name, |
76 | interface_api, unpacked_api, workspace_byte_alignment, |
77 | constant_byte_alignment, pool_var_info, io_pool_allocations); |
78 | } |
79 | |
// Register CreateExecutorMetadata as the global function
// "relay.backend.aot.CreateExecutorMetadata", with its typed C++ signature.
// NOTE(review): presumably this is what the Python AOT flow calls; confirm against callers.
TVM_REGISTER_GLOBAL("relay.backend.aot.CreateExecutorMetadata")
    .set_body_typed(CreateExecutorMetadata);
82 | |
83 | } // namespace aot |
84 | } // namespace backend |
85 | } // namespace relay |
86 | } // namespace tvm |
87 | |