1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file aot_executor_factory.cc
22 * \brief AOT executor factory implementations
23 */
24
25#include "./aot_executor_factory.h"
26
27#include <tvm/runtime/container/string.h>
28#include <tvm/runtime/device_api.h>
29#include <tvm/runtime/registry.h>
30
31#include <iterator>
32#include <vector>
33
34namespace tvm {
35namespace runtime {
36
37AotExecutorFactory::AotExecutorFactory(
38 const std::unordered_map<std::string, tvm::runtime::NDArray>& params,
39 const std::string& module_name) {
40 params_ = params;
41 module_name_ = module_name;
42}
43
44PackedFunc AotExecutorFactory::GetFunction(
45 const std::string& name, const tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) {
46 if (name == module_name_) {
47 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
48 ICHECK_GT(args.num_args, 0) << "Must supply at least one device argument";
49 std::vector<Device> devices;
50 for (int i = 0; i < args.num_args; ++i) {
51 devices.emplace_back(args[i].operator Device());
52 }
53 *rv = this->ExecutorCreate(devices);
54 });
55 } else if (name == "list_module_names") {
56 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
57 Array<String> names = {module_name_};
58 *rv = names;
59 });
60 } else if (name == "remove_params") {
61 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
62 std::unordered_map<std::string, tvm::runtime::NDArray> empty_params{};
63 auto exec = make_object<AotExecutorFactory>(empty_params, this->module_name_);
64 exec->Import(this->imports_[0]);
65 *rv = Module(exec);
66 });
67 } else {
68 return PackedFunc();
69 }
70}
71
72void AotExecutorFactory::SaveToBinary(dmlc::Stream* stream) {
73 std::vector<std::string> names;
74 std::vector<DLTensor*> arrays;
75 for (const auto& v : params_) {
76 names.emplace_back(v.first);
77 arrays.emplace_back(const_cast<DLTensor*>(v.second.operator->()));
78 }
79 uint64_t sz = arrays.size();
80 ICHECK(sz == names.size());
81 stream->Write(sz);
82 stream->Write(names);
83 for (size_t i = 0; i < sz; ++i) {
84 tvm::runtime::SaveDLTensor(stream, arrays[i]);
85 }
86 stream->Write(module_name_);
87}
88
89Module AotExecutorFactory::ExecutorCreate(const std::vector<Device>& devs) {
90 auto exec = make_object<AotExecutor>(this->imports_[0], devs);
91 // set params
92 SetParams(exec.get(), this->params_);
93 return Module(exec);
94}
95
96Module AotExecutorFactoryModuleLoadBinary(void* strm) {
97 dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
98 std::unordered_map<std::string, tvm::runtime::NDArray> params;
99 std::string module_name;
100 uint64_t sz;
101 ICHECK(stream->Read(&sz));
102 std::vector<std::string> names;
103 ICHECK(stream->Read(&names));
104 ICHECK(sz == names.size());
105 for (size_t i = 0; i < sz; ++i) {
106 tvm::runtime::NDArray temp;
107 temp.Load(stream);
108 params[names[i]] = temp;
109 }
110 ICHECK(stream->Read(&module_name));
111 auto exec = make_object<AotExecutorFactory>(params, module_name);
112 return Module(exec);
113}
114
115TVM_REGISTER_GLOBAL("tvm.aot_executor_factory.create").set_body([](TVMArgs args, TVMRetValue* rv) {
116 ICHECK_GE(args.num_args, 2) << "The expected number of arguments for "
117 "aot_executor_factory.create needs at least 2, "
118 "but it has "
119 << args.num_args;
120 // The argument order is module, module_name, param0_name, param0_tensor,
121 // [param1_name, param1_tensor], ...
122 ICHECK_EQ((args.size() - 2) % 2, 0);
123 std::unordered_map<std::string, tvm::runtime::NDArray> params;
124 for (size_t i = 2; i < static_cast<size_t>(args.size()); i += 2) {
125 std::string name = args[i].operator String();
126 params[name] = args[i + 1].operator tvm::runtime::NDArray();
127 }
128 auto exec = make_object<AotExecutorFactory>(params, args[1]);
129 exec->Import(args[0]);
130 *rv = Module(exec);
131});
132
133TVM_REGISTER_GLOBAL("runtime.module.loadbinary_AotExecutorFactory")
134 .set_body_typed(AotExecutorFactoryModuleLoadBinary);
135
136} // namespace runtime
137} // namespace tvm
138