1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file aot_executor_factory.cc |
22 | * \brief AOT executor factory implementations |
23 | */ |
24 | |
#include "./aot_executor_factory.h"

#include <tvm/runtime/container/string.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include <iterator>
#include <utility>
#include <vector>
33 | |
34 | namespace tvm { |
35 | namespace runtime { |
36 | |
37 | AotExecutorFactory::AotExecutorFactory( |
38 | const std::unordered_map<std::string, tvm::runtime::NDArray>& params, |
39 | const std::string& module_name) { |
40 | params_ = params; |
41 | module_name_ = module_name; |
42 | } |
43 | |
44 | PackedFunc AotExecutorFactory::GetFunction( |
45 | const std::string& name, const tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) { |
46 | if (name == module_name_) { |
47 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
48 | ICHECK_GT(args.num_args, 0) << "Must supply at least one device argument" ; |
49 | std::vector<Device> devices; |
50 | for (int i = 0; i < args.num_args; ++i) { |
51 | devices.emplace_back(args[i].operator Device()); |
52 | } |
53 | *rv = this->ExecutorCreate(devices); |
54 | }); |
55 | } else if (name == "list_module_names" ) { |
56 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
57 | Array<String> names = {module_name_}; |
58 | *rv = names; |
59 | }); |
60 | } else if (name == "remove_params" ) { |
61 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
62 | std::unordered_map<std::string, tvm::runtime::NDArray> empty_params{}; |
63 | auto exec = make_object<AotExecutorFactory>(empty_params, this->module_name_); |
64 | exec->Import(this->imports_[0]); |
65 | *rv = Module(exec); |
66 | }); |
67 | } else { |
68 | return PackedFunc(); |
69 | } |
70 | } |
71 | |
72 | void AotExecutorFactory::SaveToBinary(dmlc::Stream* stream) { |
73 | std::vector<std::string> names; |
74 | std::vector<DLTensor*> arrays; |
75 | for (const auto& v : params_) { |
76 | names.emplace_back(v.first); |
77 | arrays.emplace_back(const_cast<DLTensor*>(v.second.operator->())); |
78 | } |
79 | uint64_t sz = arrays.size(); |
80 | ICHECK(sz == names.size()); |
81 | stream->Write(sz); |
82 | stream->Write(names); |
83 | for (size_t i = 0; i < sz; ++i) { |
84 | tvm::runtime::SaveDLTensor(stream, arrays[i]); |
85 | } |
86 | stream->Write(module_name_); |
87 | } |
88 | |
89 | Module AotExecutorFactory::ExecutorCreate(const std::vector<Device>& devs) { |
90 | auto exec = make_object<AotExecutor>(this->imports_[0], devs); |
91 | // set params |
92 | SetParams(exec.get(), this->params_); |
93 | return Module(exec); |
94 | } |
95 | |
96 | Module AotExecutorFactoryModuleLoadBinary(void* strm) { |
97 | dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm); |
98 | std::unordered_map<std::string, tvm::runtime::NDArray> params; |
99 | std::string module_name; |
100 | uint64_t sz; |
101 | ICHECK(stream->Read(&sz)); |
102 | std::vector<std::string> names; |
103 | ICHECK(stream->Read(&names)); |
104 | ICHECK(sz == names.size()); |
105 | for (size_t i = 0; i < sz; ++i) { |
106 | tvm::runtime::NDArray temp; |
107 | temp.Load(stream); |
108 | params[names[i]] = temp; |
109 | } |
110 | ICHECK(stream->Read(&module_name)); |
111 | auto exec = make_object<AotExecutorFactory>(params, module_name); |
112 | return Module(exec); |
113 | } |
114 | |
115 | TVM_REGISTER_GLOBAL("tvm.aot_executor_factory.create" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
116 | ICHECK_GE(args.num_args, 2) << "The expected number of arguments for " |
117 | "aot_executor_factory.create needs at least 2, " |
118 | "but it has " |
119 | << args.num_args; |
120 | // The argument order is module, module_name, param0_name, param0_tensor, |
121 | // [param1_name, param1_tensor], ... |
122 | ICHECK_EQ((args.size() - 2) % 2, 0); |
123 | std::unordered_map<std::string, tvm::runtime::NDArray> params; |
124 | for (size_t i = 2; i < static_cast<size_t>(args.size()); i += 2) { |
125 | std::string name = args[i].operator String(); |
126 | params[name] = args[i + 1].operator tvm::runtime::NDArray(); |
127 | } |
128 | auto exec = make_object<AotExecutorFactory>(params, args[1]); |
129 | exec->Import(args[0]); |
130 | *rv = Module(exec); |
131 | }); |
132 | |
133 | TVM_REGISTER_GLOBAL("runtime.module.loadbinary_AotExecutorFactory" ) |
134 | .set_body_typed(AotExecutorFactoryModuleLoadBinary); |
135 | |
136 | } // namespace runtime |
137 | } // namespace tvm |
138 | |