1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file graph_executor_factory.cc
22 * \brief Graph executor factory implementations
23 */
24
25#include "./graph_executor_factory.h"
26
27#include <tvm/runtime/container/map.h>
28#include <tvm/runtime/container/string.h>
29#include <tvm/runtime/device_api.h>
30#include <tvm/runtime/registry.h>
31
32#include <iterator>
33#include <vector>
34
35namespace tvm {
36namespace runtime {
37
38GraphExecutorFactory::GraphExecutorFactory(
39 const std::string& graph_json,
40 const std::unordered_map<std::string, tvm::runtime::NDArray>& params,
41 const std::string& module_name) {
42 graph_json_ = graph_json;
43 params_ = params;
44 module_name_ = module_name;
45}
46
47PackedFunc GraphExecutorFactory::GetFunction(
48 const std::string& name, const tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) {
49 if (name == module_name_) {
50 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
51 std::vector<Device> devices;
52 for (int i = 0; i < args.num_args; ++i) {
53 devices.emplace_back(args[i].operator Device());
54 }
55 *rv = this->ExecutorCreate(devices);
56 });
57 } else if (name == "get_graph_json") {
58 return PackedFunc(
59 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->graph_json_; });
60
61 } else if (name == "get_graph_params") {
62 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
63 Map<String, tvm::runtime::NDArray> params;
64 for (const auto& kv : params_) {
65 params.Set(kv.first, kv.second);
66 }
67 *rv = params;
68 });
69 } else if (name == "debug_create") {
70 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
71 ICHECK_GE(args.size(), 2);
72 std::string module_name = args[0].operator String();
73 ICHECK(module_name == module_name_) << "Currently we only support single model for now.";
74 std::vector<Device> devices;
75 for (int i = 1; i < args.num_args; ++i) {
76 devices.emplace_back(args[i].operator Device());
77 }
78 *rv = this->DebugExecutorCreate(devices);
79 });
80 } else if (name == "remove_params") {
81 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
82 std::unordered_map<std::string, tvm::runtime::NDArray> empty_params{};
83 auto exec =
84 make_object<GraphExecutorFactory>(this->graph_json_, empty_params, this->module_name_);
85 exec->Import(this->imports_[0]);
86 *rv = Module(exec);
87 });
88 } else if (name == "cuda_graph_create") {
89 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
90 std::vector<Device> devices;
91 for (int i = 0; i < args.num_args; ++i) {
92 devices.emplace_back(args[i].operator Device());
93 }
94 *rv = this->CudaGraphExecutorCreate(devices);
95 });
96 } else {
97 return PackedFunc();
98 }
99}
100
101void GraphExecutorFactory::SaveToBinary(dmlc::Stream* stream) {
102 stream->Write(graph_json_);
103 std::vector<std::string> names;
104 std::vector<DLTensor*> arrays;
105 for (const auto& v : params_) {
106 names.emplace_back(v.first);
107 arrays.emplace_back(const_cast<DLTensor*>(v.second.operator->()));
108 }
109 uint64_t sz = arrays.size();
110 ICHECK(sz == names.size());
111 stream->Write(sz);
112 stream->Write(names);
113 for (size_t i = 0; i < sz; ++i) {
114 tvm::runtime::SaveDLTensor(stream, arrays[i]);
115 }
116 stream->Write(module_name_);
117}
118
119Module GraphExecutorFactory::ExecutorCreate(const std::vector<Device>& devs) {
120 auto exec = make_object<GraphExecutor>();
121 exec->Init(this->graph_json_, this->imports_[0], devs, PackedFunc());
122 // set params
123 SetParams(exec.get(), this->params_);
124 return Module(exec);
125}
126
127Module GraphExecutorFactory::DebugExecutorCreate(const std::vector<Device>& devs) {
128 const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_executor_debug.create");
129 ICHECK(pf != nullptr) << "Cannot find function tvm.graph_executor_debug.create in registry. "
130 "Do you enable debug graph executor build?";
131 // Debug executor create packed function will call GetAllContexs, so we unpack the devs.
132 std::vector<int> unpacked_devs;
133 for (const auto& dev : devs) {
134 unpacked_devs.emplace_back(dev.device_type);
135 unpacked_devs.emplace_back(dev.device_id);
136 }
137 size_t args_size = unpacked_devs.size() + 2;
138 std::vector<TVMValue> values(args_size);
139 std::vector<int> codes(args_size);
140 runtime::TVMArgsSetter setter(values.data(), codes.data());
141 setter(0, this->graph_json_);
142 setter(1, this->imports_[0]);
143 for (size_t i = 0; i < unpacked_devs.size(); ++i) {
144 setter(i + 2, unpacked_devs[i]);
145 }
146 TVMRetValue rv;
147 pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv);
148 Module mod = rv.operator Module();
149 // debug graph executor is one child class of graph executor.
150 SetParams(const_cast<GraphExecutor*>(mod.as<GraphExecutor>()), this->params_);
151 return mod;
152}
153
154Module GraphExecutorFactory::CudaGraphExecutorCreate(const std::vector<Device>& devs) {
155 const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_executor_cuda_graph.create");
156 ICHECK(pf != nullptr) << "Cannot find function tvm.graph_executor_cuda_graph.create in registry. "
157 "Did you set(USE_GRAPH_EXECUTOR_CUGRAPH=ON)?";
158 std::vector<int> unpacked_devs;
159 for (const auto& dev : devs) {
160 unpacked_devs.emplace_back(dev.device_type);
161 unpacked_devs.emplace_back(dev.device_id);
162 }
163 size_t args_size = unpacked_devs.size() + 2;
164 std::vector<TVMValue> values(args_size);
165 std::vector<int> codes(args_size);
166 runtime::TVMArgsSetter setter(values.data(), codes.data());
167 setter(0, this->graph_json_);
168 setter(1, this->imports_[0]);
169 for (size_t i = 0; i < unpacked_devs.size(); ++i) {
170 setter(i + 2, unpacked_devs[i]);
171 }
172 TVMRetValue rv;
173 pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv);
174 Module mod = rv.operator Module();
175 SetParams(const_cast<GraphExecutor*>(mod.as<GraphExecutor>()), this->params_);
176 return mod;
177}
178
179Module GraphExecutorFactoryModuleLoadBinary(void* strm) {
180 dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
181 std::string graph_json;
182 std::unordered_map<std::string, tvm::runtime::NDArray> params;
183 std::string module_name;
184 ICHECK(stream->Read(&graph_json));
185 uint64_t sz;
186 ICHECK(stream->Read(&sz));
187 std::vector<std::string> names;
188 ICHECK(stream->Read(&names));
189 ICHECK(sz == names.size());
190 for (size_t i = 0; i < sz; ++i) {
191 tvm::runtime::NDArray temp;
192 temp.Load(stream);
193 params[names[i]] = temp;
194 }
195 ICHECK(stream->Read(&module_name));
196 auto exec = make_object<GraphExecutorFactory>(graph_json, params, module_name);
197 return Module(exec);
198}
199
200TVM_REGISTER_GLOBAL("tvm.graph_executor_factory.create")
201 .set_body([](TVMArgs args, TVMRetValue* rv) {
202 ICHECK_GE(args.num_args, 3) << "The expected number of arguments for "
203 "graph_executor_factory.create needs at least 3, "
204 "but it has "
205 << args.num_args;
206 // The argument order is graph_json, module, module_name, param0_name, param0_tensor,
207 // [param1_name, param1_tensor], ...
208 ICHECK_EQ((args.size() - 3) % 2, 0);
209 std::unordered_map<std::string, tvm::runtime::NDArray> params;
210 for (size_t i = 3; i < static_cast<size_t>(args.size()); i += 2) {
211 std::string name = args[i].operator String();
212 params[name] = args[i + 1].operator tvm::runtime::NDArray();
213 }
214 auto exec = make_object<GraphExecutorFactory>(args[0], params, args[2]);
215 exec->Import(args[1]);
216 *rv = Module(exec);
217 });
218
// Loader for modules serialized with the "GraphExecutorFactory" type key.
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_GraphExecutorFactory")
    .set_body_typed(GraphExecutorFactoryModuleLoadBinary);

/*!
 * \brief Backwards-compatibility loader for the legacy "GraphRuntimeFactory"
 *        type key (the pre-rename name of GraphExecutorFactory).
 *
 * Emits a deprecation warning, then delegates to the regular loader since the
 * on-disk binary layout is identical.
 *
 * \param strm The input stream (a dmlc::Stream*).
 * \return A Module wrapping the reconstructed factory.
 */
Module GraphRuntimeFactoryModuleLoadBinary(void* strm) {
  LOG(WARNING) << "You are loading a module which was built with GraphRuntimeFactory. "
               << "GraphRuntime has been renamed to GraphExecutor, and support for loading "
               << "GraphRuntimeFactory modules will be removed after the next TVM release. "
               << "Please rebuild the module before then to avoid breakage.";
  return GraphExecutorFactoryModuleLoadBinary(strm);
}

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_GraphRuntimeFactory")
    .set_body_typed(GraphRuntimeFactoryModuleLoadBinary);
232
233} // namespace runtime
234} // namespace tvm
235