1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Defines an implementation of Module-based Model Runtime Interface that works with |
22 | * Ahead-of-Time compilation. |
23 | * \file aot_executor.cc |
24 | */ |
25 | |
26 | #include "aot_executor.h" |
27 | |
28 | #include <tvm/runtime/c_runtime_api.h> |
29 | #include <tvm/runtime/data_type.h> |
30 | #include <tvm/runtime/name_transforms.h> |
31 | |
32 | #include <limits> |
33 | #include <memory> |
34 | |
35 | #include "../meta_data.h" |
36 | |
37 | namespace tvm { |
38 | namespace runtime { |
39 | |
40 | AotExecutor::AotExecutor(tvm::runtime::Module module, const std::vector<Device>& devs) |
41 | : module_{module}, devices_{devs} { |
42 | auto fmetadata = module->GetFunction("get_metadata" ); |
43 | CHECK(fmetadata != nullptr) << "Expected a module with PackedFunc get_metadata" ; |
44 | auto ret_value = fmetadata(); |
45 | metadata_ = ret_value.AsObjectRef<tvm::runtime::metadata::Metadata>(); |
46 | |
47 | ICHECK_EQ(devices_.size(), 1) << "Expect exactly 1 device passed." ; |
48 | DLDevice expected_device{kDLCPU, 0}; |
49 | ICHECK_EQ(devices_[0].device_id, expected_device.device_id) |
50 | << "At this time, AOTExecutor supports only execution on kDLCPU 0" ; |
51 | // TODO(tvm-team): Temporary hack since Hexagon is defined different than kDLCPU. |
52 | bool is_valid_device = |
53 | (devices_[0].device_type == kDLHexagon) || (devices_[0].device_type == kDLCPU); |
54 | CHECK(is_valid_device) |
55 | << "At this time, AOTExecutor supports only execution on kDLCPU 0 or kDLHexagon 0" ; |
56 | |
57 | for (auto input : metadata_->inputs()) { |
58 | // TODO(areusch): Encode device information in Metadata. |
59 | args_.emplace_back(NDArray::Empty(ShapeTuple(input->shape().begin(), input->shape().end()), |
60 | input->dtype(), devices_[0])); |
61 | } |
62 | |
63 | for (auto output : metadata_->outputs()) { |
64 | args_.emplace_back(NDArray::Empty(ShapeTuple(output->shape().begin(), output->shape().end()), |
65 | output->dtype(), devices_[0])); |
66 | } |
67 | |
68 | // USMP is used |
69 | if (metadata_->num_workspace_pools()) { |
70 | // merge all constants into one ndarray |
71 | int64_t blob_len = 0; |
72 | for (const auto& c : metadata_->constant_pools()) { |
73 | auto data = c->data(); |
74 | int64_t byte_size = GetDataSize(*data.operator->()) + c->byte_offset(); |
75 | blob_len = blob_len > byte_size ? blob_len : byte_size; |
76 | } |
77 | ICHECK(blob_len < std::numeric_limits<int32_t>::max()); |
78 | NDArray ci = NDArray::Empty({blob_len}, DataType::UInt(8), devices_[0]); |
79 | for (const auto& c : metadata_->constant_pools()) { |
80 | auto data = c->data(); |
81 | data.CopyToBytes(static_cast<uint8_t*>(ci->data) + c->byte_offset(), |
82 | GetDataSize(*data.operator->())); |
83 | } |
84 | // Emplace constant node pool only if workspace pools supplied |
85 | args_.emplace_back(ci); |
86 | |
87 | int32_t pool_len = 0; |
88 | for (auto pool : metadata_->workspace_pools()) { |
89 | pool_len = |
90 | GetDataSize(*NDArray::Empty({pool->shape()}, pool->dtype(), devices_[0]).operator->()); |
91 | args_.emplace_back(NDArray::Empty({pool_len}, DataType::UInt(8), devices_[0])); |
92 | } |
93 | } |
94 | } |
95 | |
96 | PackedFunc AotExecutor::GetFunction(const std::string& name, |
97 | const ObjectPtr<Object>& sptr_to_self) { |
98 | // Return member functions during query. |
99 | if (name == "set_input" ) { |
100 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
101 | if (String::CanConvertFrom(args[0])) { |
102 | int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String())); |
103 | if (in_idx >= 0) this->SetInput(in_idx, args[1]); |
104 | } else { |
105 | this->SetInput(args[0], args[1]); |
106 | } |
107 | }); |
108 | } else if (name == "set_input_zero_copy" ) { |
109 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
110 | if (String::CanConvertFrom(args[0])) { |
111 | int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String())); |
112 | if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]); |
113 | } else { |
114 | this->SetInputZeroCopy(args[0], args[1]); |
115 | } |
116 | }); |
117 | } else if (name == "set_output_zero_copy" ) { |
118 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
119 | if (String::CanConvertFrom(args[0])) { |
120 | int out_idx = this->GetOutputIndex(tvm::runtime::SanitizeName(args[0].operator String())); |
121 | if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]); |
122 | } else { |
123 | this->SetOutputZeroCopy(args[0], args[1]); |
124 | } |
125 | }); |
126 | } else if (name == "get_output" ) { |
127 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
128 | if (args.num_args == 2) { |
129 | this->CopyOutputTo(args[0], args[1]); |
130 | } else { |
131 | *rv = this->GetOutput(args[0]); |
132 | } |
133 | }); |
134 | } else if (name == "get_input" ) { |
135 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
136 | int in_idx = 0; |
137 | if (String::CanConvertFrom(args[0])) { |
138 | in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String())); |
139 | } else { |
140 | in_idx = args[0]; |
141 | } |
142 | if (in_idx >= 0) { |
143 | *rv = this->GetInput(in_idx); |
144 | } |
145 | }); |
146 | } else if (name == "get_num_outputs" ) { |
147 | return PackedFunc( |
148 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); |
149 | } else if (name == "get_num_inputs" ) { |
150 | return PackedFunc( |
151 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); }); |
152 | } else if (name == "run" ) { |
153 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); |
154 | } else if (name == "get_input_index" ) { |
155 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
156 | CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string" ; |
157 | *rv = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String())); |
158 | }); |
159 | } else { |
160 | return PackedFunc(); |
161 | } |
162 | } |
163 | |
164 | void AotExecutor::Run() { |
165 | auto pf = module_.GetFunction( |
166 | get_name_mangled(metadata_->mod_name(), ::tvm::runtime::symbol::tvm_module_main), |
167 | true /* query_imports */); |
168 | ICHECK(pf != nullptr) << "Module entrypoint is not defined" ; |
169 | |
170 | const int num_args = args_.size(); |
171 | auto call_values = ::std::make_unique<TVMValue[]>(num_args); |
172 | auto call_type_codes = ::std::make_unique<int[]>(num_args); |
173 | for (int i = 0; i < num_args; ++i) { |
174 | auto managed = args_[i].ToDLPack(); |
175 | call_values.get()[i].v_handle = &managed->dl_tensor; |
176 | call_type_codes.get()[i] = kTVMDLTensorHandle; |
177 | } |
178 | |
179 | TVMArgs args{call_values.get(), call_type_codes.get(), num_args}; |
180 | TVMRetValue rv; |
181 | pf.CallPacked(args, &rv); |
182 | } |
183 | |
184 | int AotExecutor::GetInputIndex(const std::string& name) { |
185 | auto inputs = metadata_->inputs(); |
186 | for (unsigned int i = 0; i < inputs.size(); i++) { |
187 | if (inputs[i]->name() == name) { |
188 | return i; |
189 | } |
190 | } |
191 | return -1; |
192 | } |
193 | |
194 | int AotExecutor::GetOutputIndex(const std::string& name) { |
195 | auto outputs = metadata_->outputs(); |
196 | for (unsigned int i = 0; i < outputs.size(); i++) { |
197 | if (outputs[i]->name() == name) { |
198 | return i; |
199 | } |
200 | } |
201 | return -1; |
202 | } |
203 | |
204 | void AotExecutor::SetInput(int index, DLTensor* data_ref) { args_[index].CopyFrom(data_ref); } |
205 | |
/*!
 * \brief Handler behind the "set_input_zero_copy" PackedFunc; currently unimplemented.
 *
 * The body is an unconditional failing ICHECK, so every call fails and neither argument is read.
 *
 * \param index Unused.
 * \param data_ref Unused.
 */
void AotExecutor::SetInputZeroCopy(int index, DLTensor* data_ref) {
  ICHECK(false) << "not implemented" ;
}
209 | |
/*!
 * \brief Handler behind the "set_output_zero_copy" PackedFunc; currently unimplemented.
 *
 * The body is an unconditional failing ICHECK, so every call fails and neither argument is read.
 *
 * \param index Unused.
 * \param data_ref Unused.
 */
void AotExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) {
  ICHECK(false) << "not implemented" ;
}
213 | |
214 | int AotExecutor::NumOutputs() const { return metadata_->num_outputs(); } |
215 | |
216 | int AotExecutor::NumInputs() const { return metadata_->num_inputs(); } |
217 | |
218 | NDArray AotExecutor::GetInput(int index) const { return args_[index]; } |
219 | |
220 | NDArray AotExecutor::GetOutput(int index) const { return args_[metadata_->num_inputs() + index]; } |
221 | |
222 | void AotExecutor::CopyOutputTo(int index, DLTensor* data_out) { GetOutput(index).CopyTo(data_out); } |
223 | |
224 | } // namespace runtime |
225 | } // namespace tvm |
226 | |