1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file module.cc |
22 | * \brief TVM module system |
23 | */ |
24 | #include <tvm/runtime/module.h> |
25 | #include <tvm/runtime/packed_func.h> |
26 | #include <tvm/runtime/registry.h> |
27 | |
28 | #include <cstring> |
29 | #include <unordered_set> |
30 | |
31 | #include "file_utils.h" |
32 | |
33 | namespace tvm { |
34 | namespace runtime { |
35 | |
36 | void ModuleNode::Import(Module other) { |
37 | // specially handle rpc |
38 | if (!std::strcmp(this->type_key(), "rpc" )) { |
39 | static const PackedFunc* fimport_ = nullptr; |
40 | if (fimport_ == nullptr) { |
41 | fimport_ = runtime::Registry::Get("rpc.ImportRemoteModule" ); |
42 | ICHECK(fimport_ != nullptr); |
43 | } |
44 | (*fimport_)(GetRef<Module>(this), other); |
45 | return; |
46 | } |
47 | // cyclic detection. |
48 | std::unordered_set<const ModuleNode*> visited{other.operator->()}; |
49 | std::vector<const ModuleNode*> stack{other.operator->()}; |
50 | while (!stack.empty()) { |
51 | const ModuleNode* n = stack.back(); |
52 | stack.pop_back(); |
53 | for (const Module& m : n->imports_) { |
54 | const ModuleNode* next = m.operator->(); |
55 | if (visited.count(next)) continue; |
56 | visited.insert(next); |
57 | stack.push_back(next); |
58 | } |
59 | } |
60 | ICHECK(!visited.count(this)) << "Cyclic dependency detected during import" ; |
61 | this->imports_.emplace_back(std::move(other)); |
62 | } |
63 | |
64 | PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) { |
65 | ModuleNode* self = this; |
66 | PackedFunc pf = self->GetFunction(name, GetObjectPtr<Object>(this)); |
67 | if (pf != nullptr) return pf; |
68 | if (query_imports) { |
69 | for (Module& m : self->imports_) { |
70 | pf = m.operator->()->GetFunction(name, query_imports); |
71 | if (pf != nullptr) { |
72 | return pf; |
73 | } |
74 | } |
75 | } |
76 | return pf; |
77 | } |
78 | |
79 | Module Module::LoadFromFile(const std::string& file_name, const std::string& format) { |
80 | std::string fmt = GetFileFormat(file_name, format); |
81 | ICHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name; |
82 | if (fmt == "dll" || fmt == "dylib" || fmt == "dso" ) { |
83 | fmt = "so" ; |
84 | } |
85 | std::string load_f_name = "runtime.module.loadfile_" + fmt; |
86 | VLOG(1) << "Loading module from '" << file_name << "' of format '" << fmt << "'" ; |
87 | const PackedFunc* f = Registry::Get(load_f_name); |
88 | ICHECK(f != nullptr) << "Loader for `." << format << "` files is not registered," |
89 | << " resolved to (" << load_f_name << ") in the global registry." |
90 | << "Ensure that you have loaded the correct runtime code, and" |
91 | << "that you are on the correct hardware architecture." ; |
92 | Module m = (*f)(file_name, format); |
93 | return m; |
94 | } |
95 | |
96 | void ModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { |
97 | LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile" ; |
98 | } |
99 | |
100 | void ModuleNode::SaveToBinary(dmlc::Stream* stream) { |
101 | LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToBinary" ; |
102 | } |
103 | |
104 | std::string ModuleNode::GetSource(const std::string& format) { |
105 | LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource" ; |
106 | } |
107 | |
108 | const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { |
109 | std::lock_guard<std::mutex> lock(mutex_); |
110 | auto it = import_cache_.find(name); |
111 | if (it != import_cache_.end()) return it->second.get(); |
112 | PackedFunc pf; |
113 | for (Module& m : this->imports_) { |
114 | pf = m.GetFunction(name, true); |
115 | if (pf != nullptr) break; |
116 | } |
117 | if (pf == nullptr) { |
118 | const PackedFunc* f = Registry::Get(name); |
119 | ICHECK(f != nullptr) << "Cannot find function " << name |
120 | << " in the imported modules or global registry." |
121 | << " If this involves ops from a contrib library like" |
122 | << " cuDNN, ensure TVM was built with the relevant" |
123 | << " library." ; |
124 | return f; |
125 | } else { |
126 | import_cache_.insert(std::make_pair(name, std::make_shared<PackedFunc>(pf))); |
127 | return import_cache_.at(name).get(); |
128 | } |
129 | } |
130 | |
131 | std::string ModuleNode::GetFormat() { |
132 | LOG(FATAL) << "Module[" << type_key() << "] does not support GetFormat" ; |
133 | } |
134 | |
135 | bool ModuleNode::IsDSOExportable() const { return false; } |
136 | |
137 | bool ModuleNode::ImplementsFunction(const String& name, bool query_imports) { |
138 | return GetFunction(name, query_imports) != nullptr; |
139 | } |
140 | |
141 | bool RuntimeEnabled(const std::string& target) { |
142 | std::string f_name; |
143 | if (target == "cpu" ) { |
144 | return true; |
145 | } else if (target == "cuda" || target == "gpu" ) { |
146 | f_name = "device_api.cuda" ; |
147 | } else if (target == "cl" || target == "opencl" || target == "sdaccel" ) { |
148 | f_name = "device_api.opencl" ; |
149 | } else if (target == "mtl" || target == "metal" ) { |
150 | f_name = "device_api.metal" ; |
151 | } else if (target == "tflite" ) { |
152 | f_name = "target.runtime.tflite" ; |
153 | } else if (target == "vulkan" ) { |
154 | f_name = "device_api.vulkan" ; |
155 | } else if (target == "stackvm" ) { |
156 | f_name = "target.build.stackvm" ; |
157 | } else if (target == "rpc" ) { |
158 | f_name = "device_api.rpc" ; |
159 | } else if (target == "hexagon" ) { |
160 | f_name = "device_api.hexagon" ; |
161 | } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx" ) { |
162 | f_name = "device_api.cuda" ; |
163 | } else if (target.length() >= 4 && target.substr(0, 4) == "rocm" ) { |
164 | f_name = "device_api.rocm" ; |
165 | } else if (target.length() >= 4 && target.substr(0, 4) == "llvm" ) { |
166 | const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled" ); |
167 | if (pf == nullptr) return false; |
168 | return (*pf)(target); |
169 | } else { |
170 | LOG(FATAL) << "Unknown optional runtime " << target; |
171 | } |
172 | return runtime::Registry::Get(f_name) != nullptr; |
173 | } |
174 | |
175 | TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled" ).set_body_typed(RuntimeEnabled); |
176 | |
177 | TVM_REGISTER_GLOBAL("runtime.ModuleGetSource" ).set_body_typed([](Module mod, std::string fmt) { |
178 | return mod->GetSource(fmt); |
179 | }); |
180 | |
181 | TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize" ).set_body_typed([](Module mod) { |
182 | return static_cast<int64_t>(mod->imports().size()); |
183 | }); |
184 | |
185 | TVM_REGISTER_GLOBAL("runtime.ModuleGetImport" ).set_body_typed([](Module mod, int index) { |
186 | return mod->imports().at(index); |
187 | }); |
188 | |
189 | TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey" ).set_body_typed([](Module mod) { |
190 | return std::string(mod->type_key()); |
191 | }); |
192 | |
193 | TVM_REGISTER_GLOBAL("runtime.ModuleGetFormat" ).set_body_typed([](Module mod) { |
194 | return mod->GetFormat(); |
195 | }); |
196 | |
197 | TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile" ).set_body_typed(Module::LoadFromFile); |
198 | |
199 | TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile" ) |
200 | .set_body_typed([](Module mod, String name, tvm::String fmt) { mod->SaveToFile(name, fmt); }); |
201 | |
202 | TVM_REGISTER_GLOBAL("runtime.ModuleIsDSOExportable" ).set_body_typed([](Module mod) { |
203 | return mod->IsDSOExportable(); |
204 | }); |
205 | |
206 | TVM_REGISTER_GLOBAL("runtime.ModuleImplementsFunction" ) |
207 | .set_body_typed([](Module mod, String name, bool query_imports) { |
208 | return mod->ImplementsFunction(std::move(name), query_imports); |
209 | }); |
210 | |
211 | TVM_REGISTER_OBJECT_TYPE(ModuleNode); |
212 | } // namespace runtime |
213 | } // namespace tvm |
214 | |