1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file graph_executor_factory.cc |
22 | * \brief Graph executor factory implementations |
23 | */ |
24 | |
25 | #include "./graph_executor_factory.h" |
26 | |
27 | #include <tvm/runtime/container/map.h> |
28 | #include <tvm/runtime/container/string.h> |
29 | #include <tvm/runtime/device_api.h> |
30 | #include <tvm/runtime/registry.h> |
31 | |
32 | #include <iterator> |
33 | #include <vector> |
34 | |
35 | namespace tvm { |
36 | namespace runtime { |
37 | |
38 | GraphExecutorFactory::GraphExecutorFactory( |
39 | const std::string& graph_json, |
40 | const std::unordered_map<std::string, tvm::runtime::NDArray>& params, |
41 | const std::string& module_name) { |
42 | graph_json_ = graph_json; |
43 | params_ = params; |
44 | module_name_ = module_name; |
45 | } |
46 | |
47 | PackedFunc GraphExecutorFactory::GetFunction( |
48 | const std::string& name, const tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) { |
49 | if (name == module_name_) { |
50 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
51 | std::vector<Device> devices; |
52 | for (int i = 0; i < args.num_args; ++i) { |
53 | devices.emplace_back(args[i].operator Device()); |
54 | } |
55 | *rv = this->ExecutorCreate(devices); |
56 | }); |
57 | } else if (name == "get_graph_json" ) { |
58 | return PackedFunc( |
59 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->graph_json_; }); |
60 | |
61 | } else if (name == "get_graph_params" ) { |
62 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
63 | Map<String, tvm::runtime::NDArray> params; |
64 | for (const auto& kv : params_) { |
65 | params.Set(kv.first, kv.second); |
66 | } |
67 | *rv = params; |
68 | }); |
69 | } else if (name == "debug_create" ) { |
70 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
71 | ICHECK_GE(args.size(), 2); |
72 | std::string module_name = args[0].operator String(); |
73 | ICHECK(module_name == module_name_) << "Currently we only support single model for now." ; |
74 | std::vector<Device> devices; |
75 | for (int i = 1; i < args.num_args; ++i) { |
76 | devices.emplace_back(args[i].operator Device()); |
77 | } |
78 | *rv = this->DebugExecutorCreate(devices); |
79 | }); |
80 | } else if (name == "remove_params" ) { |
81 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
82 | std::unordered_map<std::string, tvm::runtime::NDArray> empty_params{}; |
83 | auto exec = |
84 | make_object<GraphExecutorFactory>(this->graph_json_, empty_params, this->module_name_); |
85 | exec->Import(this->imports_[0]); |
86 | *rv = Module(exec); |
87 | }); |
88 | } else if (name == "cuda_graph_create" ) { |
89 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
90 | std::vector<Device> devices; |
91 | for (int i = 0; i < args.num_args; ++i) { |
92 | devices.emplace_back(args[i].operator Device()); |
93 | } |
94 | *rv = this->CudaGraphExecutorCreate(devices); |
95 | }); |
96 | } else { |
97 | return PackedFunc(); |
98 | } |
99 | } |
100 | |
101 | void GraphExecutorFactory::SaveToBinary(dmlc::Stream* stream) { |
102 | stream->Write(graph_json_); |
103 | std::vector<std::string> names; |
104 | std::vector<DLTensor*> arrays; |
105 | for (const auto& v : params_) { |
106 | names.emplace_back(v.first); |
107 | arrays.emplace_back(const_cast<DLTensor*>(v.second.operator->())); |
108 | } |
109 | uint64_t sz = arrays.size(); |
110 | ICHECK(sz == names.size()); |
111 | stream->Write(sz); |
112 | stream->Write(names); |
113 | for (size_t i = 0; i < sz; ++i) { |
114 | tvm::runtime::SaveDLTensor(stream, arrays[i]); |
115 | } |
116 | stream->Write(module_name_); |
117 | } |
118 | |
119 | Module GraphExecutorFactory::ExecutorCreate(const std::vector<Device>& devs) { |
120 | auto exec = make_object<GraphExecutor>(); |
121 | exec->Init(this->graph_json_, this->imports_[0], devs, PackedFunc()); |
122 | // set params |
123 | SetParams(exec.get(), this->params_); |
124 | return Module(exec); |
125 | } |
126 | |
127 | Module GraphExecutorFactory::DebugExecutorCreate(const std::vector<Device>& devs) { |
128 | const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_executor_debug.create" ); |
129 | ICHECK(pf != nullptr) << "Cannot find function tvm.graph_executor_debug.create in registry. " |
130 | "Do you enable debug graph executor build?" ; |
131 | // Debug executor create packed function will call GetAllContexs, so we unpack the devs. |
132 | std::vector<int> unpacked_devs; |
133 | for (const auto& dev : devs) { |
134 | unpacked_devs.emplace_back(dev.device_type); |
135 | unpacked_devs.emplace_back(dev.device_id); |
136 | } |
137 | size_t args_size = unpacked_devs.size() + 2; |
138 | std::vector<TVMValue> values(args_size); |
139 | std::vector<int> codes(args_size); |
140 | runtime::TVMArgsSetter setter(values.data(), codes.data()); |
141 | setter(0, this->graph_json_); |
142 | setter(1, this->imports_[0]); |
143 | for (size_t i = 0; i < unpacked_devs.size(); ++i) { |
144 | setter(i + 2, unpacked_devs[i]); |
145 | } |
146 | TVMRetValue rv; |
147 | pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv); |
148 | Module mod = rv.operator Module(); |
149 | // debug graph executor is one child class of graph executor. |
150 | SetParams(const_cast<GraphExecutor*>(mod.as<GraphExecutor>()), this->params_); |
151 | return mod; |
152 | } |
153 | |
154 | Module GraphExecutorFactory::CudaGraphExecutorCreate(const std::vector<Device>& devs) { |
155 | const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_executor_cuda_graph.create" ); |
156 | ICHECK(pf != nullptr) << "Cannot find function tvm.graph_executor_cuda_graph.create in registry. " |
157 | "Did you set(USE_GRAPH_EXECUTOR_CUGRAPH=ON)?" ; |
158 | std::vector<int> unpacked_devs; |
159 | for (const auto& dev : devs) { |
160 | unpacked_devs.emplace_back(dev.device_type); |
161 | unpacked_devs.emplace_back(dev.device_id); |
162 | } |
163 | size_t args_size = unpacked_devs.size() + 2; |
164 | std::vector<TVMValue> values(args_size); |
165 | std::vector<int> codes(args_size); |
166 | runtime::TVMArgsSetter setter(values.data(), codes.data()); |
167 | setter(0, this->graph_json_); |
168 | setter(1, this->imports_[0]); |
169 | for (size_t i = 0; i < unpacked_devs.size(); ++i) { |
170 | setter(i + 2, unpacked_devs[i]); |
171 | } |
172 | TVMRetValue rv; |
173 | pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv); |
174 | Module mod = rv.operator Module(); |
175 | SetParams(const_cast<GraphExecutor*>(mod.as<GraphExecutor>()), this->params_); |
176 | return mod; |
177 | } |
178 | |
179 | Module GraphExecutorFactoryModuleLoadBinary(void* strm) { |
180 | dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm); |
181 | std::string graph_json; |
182 | std::unordered_map<std::string, tvm::runtime::NDArray> params; |
183 | std::string module_name; |
184 | ICHECK(stream->Read(&graph_json)); |
185 | uint64_t sz; |
186 | ICHECK(stream->Read(&sz)); |
187 | std::vector<std::string> names; |
188 | ICHECK(stream->Read(&names)); |
189 | ICHECK(sz == names.size()); |
190 | for (size_t i = 0; i < sz; ++i) { |
191 | tvm::runtime::NDArray temp; |
192 | temp.Load(stream); |
193 | params[names[i]] = temp; |
194 | } |
195 | ICHECK(stream->Read(&module_name)); |
196 | auto exec = make_object<GraphExecutorFactory>(graph_json, params, module_name); |
197 | return Module(exec); |
198 | } |
199 | |
200 | TVM_REGISTER_GLOBAL("tvm.graph_executor_factory.create" ) |
201 | .set_body([](TVMArgs args, TVMRetValue* rv) { |
202 | ICHECK_GE(args.num_args, 3) << "The expected number of arguments for " |
203 | "graph_executor_factory.create needs at least 3, " |
204 | "but it has " |
205 | << args.num_args; |
206 | // The argument order is graph_json, module, module_name, param0_name, param0_tensor, |
207 | // [param1_name, param1_tensor], ... |
208 | ICHECK_EQ((args.size() - 3) % 2, 0); |
209 | std::unordered_map<std::string, tvm::runtime::NDArray> params; |
210 | for (size_t i = 3; i < static_cast<size_t>(args.size()); i += 2) { |
211 | std::string name = args[i].operator String(); |
212 | params[name] = args[i + 1].operator tvm::runtime::NDArray(); |
213 | } |
214 | auto exec = make_object<GraphExecutorFactory>(args[0], params, args[2]); |
215 | exec->Import(args[1]); |
216 | *rv = Module(exec); |
217 | }); |
218 | |
// Deserializer hook invoked by the module loader for archives written by
// GraphExecutorFactory::SaveToBinary (type key "GraphExecutorFactory").
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_GraphExecutorFactory")
    .set_body_typed(GraphExecutorFactoryModuleLoadBinary);
221 | |
/*!
 * \brief Legacy deserializer for modules saved under the old name
 *        "GraphRuntimeFactory"; warns and forwards to the current loader.
 * \param strm Opaque pointer to a dmlc::Stream positioned at the payload.
 * \return The deserialized factory module.
 */
Module GraphRuntimeFactoryModuleLoadBinary(void* strm) {
  LOG(WARNING) << "You are loading a module which was built with GraphRuntimeFactory. "
               << "GraphRuntime has been renamed to GraphExecutor, and support for loading "
               << "GraphRuntimeFactory modules will be removed after the next TVM release. "
               << "Please rebuild the module before then to avoid breakage.";
  return GraphExecutorFactoryModuleLoadBinary(strm);
}
229 | |
// Backward-compatibility hook: lets archives with the pre-rename type key
// "GraphRuntimeFactory" still load (with a deprecation warning).
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_GraphRuntimeFactory")
    .set_body_typed(GraphRuntimeFactoryModuleLoadBinary);
232 | |
233 | } // namespace runtime |
234 | } // namespace tvm |
235 | |