1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Defines an implementation of Module-based Model Runtime Interface that works with
22 * Ahead-of-Time compilation.
23 * \file aot_executor.cc
24 */
25
26#include "aot_executor.h"
27
28#include <tvm/runtime/c_runtime_api.h>
29#include <tvm/runtime/data_type.h>
30#include <tvm/runtime/name_transforms.h>
31
32#include <limits>
33#include <memory>
34
35#include "../meta_data.h"
36
37namespace tvm {
38namespace runtime {
39
40AotExecutor::AotExecutor(tvm::runtime::Module module, const std::vector<Device>& devs)
41 : module_{module}, devices_{devs} {
42 auto fmetadata = module->GetFunction("get_metadata");
43 CHECK(fmetadata != nullptr) << "Expected a module with PackedFunc get_metadata";
44 auto ret_value = fmetadata();
45 metadata_ = ret_value.AsObjectRef<tvm::runtime::metadata::Metadata>();
46
47 ICHECK_EQ(devices_.size(), 1) << "Expect exactly 1 device passed.";
48 DLDevice expected_device{kDLCPU, 0};
49 ICHECK_EQ(devices_[0].device_id, expected_device.device_id)
50 << "At this time, AOTExecutor supports only execution on kDLCPU 0";
51 // TODO(tvm-team): Temporary hack since Hexagon is defined different than kDLCPU.
52 bool is_valid_device =
53 (devices_[0].device_type == kDLHexagon) || (devices_[0].device_type == kDLCPU);
54 CHECK(is_valid_device)
55 << "At this time, AOTExecutor supports only execution on kDLCPU 0 or kDLHexagon 0";
56
57 for (auto input : metadata_->inputs()) {
58 // TODO(areusch): Encode device information in Metadata.
59 args_.emplace_back(NDArray::Empty(ShapeTuple(input->shape().begin(), input->shape().end()),
60 input->dtype(), devices_[0]));
61 }
62
63 for (auto output : metadata_->outputs()) {
64 args_.emplace_back(NDArray::Empty(ShapeTuple(output->shape().begin(), output->shape().end()),
65 output->dtype(), devices_[0]));
66 }
67
68 // USMP is used
69 if (metadata_->num_workspace_pools()) {
70 // merge all constants into one ndarray
71 int64_t blob_len = 0;
72 for (const auto& c : metadata_->constant_pools()) {
73 auto data = c->data();
74 int64_t byte_size = GetDataSize(*data.operator->()) + c->byte_offset();
75 blob_len = blob_len > byte_size ? blob_len : byte_size;
76 }
77 ICHECK(blob_len < std::numeric_limits<int32_t>::max());
78 NDArray ci = NDArray::Empty({blob_len}, DataType::UInt(8), devices_[0]);
79 for (const auto& c : metadata_->constant_pools()) {
80 auto data = c->data();
81 data.CopyToBytes(static_cast<uint8_t*>(ci->data) + c->byte_offset(),
82 GetDataSize(*data.operator->()));
83 }
84 // Emplace constant node pool only if workspace pools supplied
85 args_.emplace_back(ci);
86
87 int32_t pool_len = 0;
88 for (auto pool : metadata_->workspace_pools()) {
89 pool_len =
90 GetDataSize(*NDArray::Empty({pool->shape()}, pool->dtype(), devices_[0]).operator->());
91 args_.emplace_back(NDArray::Empty({pool_len}, DataType::UInt(8), devices_[0]));
92 }
93 }
94}
95
96PackedFunc AotExecutor::GetFunction(const std::string& name,
97 const ObjectPtr<Object>& sptr_to_self) {
98 // Return member functions during query.
99 if (name == "set_input") {
100 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
101 if (String::CanConvertFrom(args[0])) {
102 int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
103 if (in_idx >= 0) this->SetInput(in_idx, args[1]);
104 } else {
105 this->SetInput(args[0], args[1]);
106 }
107 });
108 } else if (name == "set_input_zero_copy") {
109 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
110 if (String::CanConvertFrom(args[0])) {
111 int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
112 if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
113 } else {
114 this->SetInputZeroCopy(args[0], args[1]);
115 }
116 });
117 } else if (name == "set_output_zero_copy") {
118 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
119 if (String::CanConvertFrom(args[0])) {
120 int out_idx = this->GetOutputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
121 if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]);
122 } else {
123 this->SetOutputZeroCopy(args[0], args[1]);
124 }
125 });
126 } else if (name == "get_output") {
127 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
128 if (args.num_args == 2) {
129 this->CopyOutputTo(args[0], args[1]);
130 } else {
131 *rv = this->GetOutput(args[0]);
132 }
133 });
134 } else if (name == "get_input") {
135 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
136 int in_idx = 0;
137 if (String::CanConvertFrom(args[0])) {
138 in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
139 } else {
140 in_idx = args[0];
141 }
142 if (in_idx >= 0) {
143 *rv = this->GetInput(in_idx);
144 }
145 });
146 } else if (name == "get_num_outputs") {
147 return PackedFunc(
148 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); });
149 } else if (name == "get_num_inputs") {
150 return PackedFunc(
151 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); });
152 } else if (name == "run") {
153 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); });
154 } else if (name == "get_input_index") {
155 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
156 CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
157 *rv = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
158 });
159 } else {
160 return PackedFunc();
161 }
162}
163
164void AotExecutor::Run() {
165 auto pf = module_.GetFunction(
166 get_name_mangled(metadata_->mod_name(), ::tvm::runtime::symbol::tvm_module_main),
167 true /* query_imports */);
168 ICHECK(pf != nullptr) << "Module entrypoint is not defined";
169
170 const int num_args = args_.size();
171 auto call_values = ::std::make_unique<TVMValue[]>(num_args);
172 auto call_type_codes = ::std::make_unique<int[]>(num_args);
173 for (int i = 0; i < num_args; ++i) {
174 auto managed = args_[i].ToDLPack();
175 call_values.get()[i].v_handle = &managed->dl_tensor;
176 call_type_codes.get()[i] = kTVMDLTensorHandle;
177 }
178
179 TVMArgs args{call_values.get(), call_type_codes.get(), num_args};
180 TVMRetValue rv;
181 pf.CallPacked(args, &rv);
182}
183
184int AotExecutor::GetInputIndex(const std::string& name) {
185 auto inputs = metadata_->inputs();
186 for (unsigned int i = 0; i < inputs.size(); i++) {
187 if (inputs[i]->name() == name) {
188 return i;
189 }
190 }
191 return -1;
192}
193
194int AotExecutor::GetOutputIndex(const std::string& name) {
195 auto outputs = metadata_->outputs();
196 for (unsigned int i = 0; i < outputs.size(); i++) {
197 if (outputs[i]->name() == name) {
198 return i;
199 }
200 }
201 return -1;
202}
203
204void AotExecutor::SetInput(int index, DLTensor* data_ref) { args_[index].CopyFrom(data_ref); }
205
206void AotExecutor::SetInputZeroCopy(int index, DLTensor* data_ref) {
207 ICHECK(false) << "not implemented";
208}
209
210void AotExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) {
211 ICHECK(false) << "not implemented";
212}
213
214int AotExecutor::NumOutputs() const { return metadata_->num_outputs(); }
215
216int AotExecutor::NumInputs() const { return metadata_->num_inputs(); }
217
218NDArray AotExecutor::GetInput(int index) const { return args_[index]; }
219
220NDArray AotExecutor::GetOutput(int index) const { return args_[metadata_->num_inputs() + index]; }
221
222void AotExecutor::CopyOutputTo(int index, DLTensor* data_out) { GetOutput(index).CopyTo(data_out); }
223
224} // namespace runtime
225} // namespace tvm
226