1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file module.cc
22 * \brief TVM module system
23 */
24#include <tvm/runtime/module.h>
25#include <tvm/runtime/packed_func.h>
26#include <tvm/runtime/registry.h>
27
28#include <cstring>
29#include <unordered_set>
30
31#include "file_utils.h"
32
33namespace tvm {
34namespace runtime {
35
36void ModuleNode::Import(Module other) {
37 // specially handle rpc
38 if (!std::strcmp(this->type_key(), "rpc")) {
39 static const PackedFunc* fimport_ = nullptr;
40 if (fimport_ == nullptr) {
41 fimport_ = runtime::Registry::Get("rpc.ImportRemoteModule");
42 ICHECK(fimport_ != nullptr);
43 }
44 (*fimport_)(GetRef<Module>(this), other);
45 return;
46 }
47 // cyclic detection.
48 std::unordered_set<const ModuleNode*> visited{other.operator->()};
49 std::vector<const ModuleNode*> stack{other.operator->()};
50 while (!stack.empty()) {
51 const ModuleNode* n = stack.back();
52 stack.pop_back();
53 for (const Module& m : n->imports_) {
54 const ModuleNode* next = m.operator->();
55 if (visited.count(next)) continue;
56 visited.insert(next);
57 stack.push_back(next);
58 }
59 }
60 ICHECK(!visited.count(this)) << "Cyclic dependency detected during import";
61 this->imports_.emplace_back(std::move(other));
62}
63
64PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) {
65 ModuleNode* self = this;
66 PackedFunc pf = self->GetFunction(name, GetObjectPtr<Object>(this));
67 if (pf != nullptr) return pf;
68 if (query_imports) {
69 for (Module& m : self->imports_) {
70 pf = m.operator->()->GetFunction(name, query_imports);
71 if (pf != nullptr) {
72 return pf;
73 }
74 }
75 }
76 return pf;
77}
78
79Module Module::LoadFromFile(const std::string& file_name, const std::string& format) {
80 std::string fmt = GetFileFormat(file_name, format);
81 ICHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name;
82 if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
83 fmt = "so";
84 }
85 std::string load_f_name = "runtime.module.loadfile_" + fmt;
86 VLOG(1) << "Loading module from '" << file_name << "' of format '" << fmt << "'";
87 const PackedFunc* f = Registry::Get(load_f_name);
88 ICHECK(f != nullptr) << "Loader for `." << format << "` files is not registered,"
89 << " resolved to (" << load_f_name << ") in the global registry."
90 << "Ensure that you have loaded the correct runtime code, and"
91 << "that you are on the correct hardware architecture.";
92 Module m = (*f)(file_name, format);
93 return m;
94}
95
/*!
 * \brief Default SaveToFile: this base implementation supports no format.
 *
 * Always reports a fatal error naming the module's type key. Module types
 * that can be written to disk are presumably expected to override this.
 *
 * \param file_name Unused by the default implementation.
 * \param format Unused by the default implementation.
 */
void ModuleNode::SaveToFile(const std::string& file_name, const std::string& format) {
  LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
}
99
/*!
 * \brief Default SaveToBinary: serialization is not supported by the base class.
 *
 * Always reports a fatal error naming the module's type key. Serializable
 * module types presumably override this.
 *
 * \param stream Unused by the default implementation.
 */
void ModuleNode::SaveToBinary(dmlc::Stream* stream) {
  LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToBinary";
}
103
/*!
 * \brief Default GetSource: the base class has no source to return.
 *
 * Always reports a fatal error naming the module's type key. No return
 * statement follows because LOG(FATAL) is relied on not to return.
 *
 * \param format Unused by the default implementation.
 */
std::string ModuleNode::GetSource(const std::string& format) {
  LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource";
}
107
108const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
109 std::lock_guard<std::mutex> lock(mutex_);
110 auto it = import_cache_.find(name);
111 if (it != import_cache_.end()) return it->second.get();
112 PackedFunc pf;
113 for (Module& m : this->imports_) {
114 pf = m.GetFunction(name, true);
115 if (pf != nullptr) break;
116 }
117 if (pf == nullptr) {
118 const PackedFunc* f = Registry::Get(name);
119 ICHECK(f != nullptr) << "Cannot find function " << name
120 << " in the imported modules or global registry."
121 << " If this involves ops from a contrib library like"
122 << " cuDNN, ensure TVM was built with the relevant"
123 << " library.";
124 return f;
125 } else {
126 import_cache_.insert(std::make_pair(name, std::make_shared<PackedFunc>(pf)));
127 return import_cache_.at(name).get();
128 }
129}
130
/*!
 * \brief Default GetFormat: the base class has no associated file format.
 *
 * Always reports a fatal error naming the module's type key. No return
 * statement follows because LOG(FATAL) is relied on not to return.
 */
std::string ModuleNode::GetFormat() {
  LOG(FATAL) << "Module[" << type_key() << "] does not support GetFormat";
}
134
/*!
 * \brief Default: a module is not DSO-exportable.
 * \note Presumably overridden by module types that can be exported into a
 *       shared library (TODO confirm against subclasses).
 */
bool ModuleNode::IsDSOExportable() const { return false; }
136
137bool ModuleNode::ImplementsFunction(const String& name, bool query_imports) {
138 return GetFunction(name, query_imports) != nullptr;
139}
140
141bool RuntimeEnabled(const std::string& target) {
142 std::string f_name;
143 if (target == "cpu") {
144 return true;
145 } else if (target == "cuda" || target == "gpu") {
146 f_name = "device_api.cuda";
147 } else if (target == "cl" || target == "opencl" || target == "sdaccel") {
148 f_name = "device_api.opencl";
149 } else if (target == "mtl" || target == "metal") {
150 f_name = "device_api.metal";
151 } else if (target == "tflite") {
152 f_name = "target.runtime.tflite";
153 } else if (target == "vulkan") {
154 f_name = "device_api.vulkan";
155 } else if (target == "stackvm") {
156 f_name = "target.build.stackvm";
157 } else if (target == "rpc") {
158 f_name = "device_api.rpc";
159 } else if (target == "hexagon") {
160 f_name = "device_api.hexagon";
161 } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
162 f_name = "device_api.cuda";
163 } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
164 f_name = "device_api.rocm";
165 } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
166 const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
167 if (pf == nullptr) return false;
168 return (*pf)(target);
169 } else {
170 LOG(FATAL) << "Unknown optional runtime " << target;
171 }
172 return runtime::Registry::Get(f_name) != nullptr;
173}
174
// FFI entry points for the module system. Each TVM_REGISTER_GLOBAL below binds
// a global name to a typed body.

// Exposes RuntimeEnabled(target).
TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled);

// Forwards to mod->GetSource(fmt).
TVM_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) {
  return mod->GetSource(fmt);
});

// Number of modules directly imported by `mod`, as a 64-bit signed integer.
TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) {
  return static_cast<int64_t>(mod->imports().size());
});

// Returns the `index`-th direct import. The access uses .at(), which is
// bounds-checked if imports() is a std::vector (NOTE(review): confirm container type).
TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) {
  return mod->imports().at(index);
});

// Returns the module's type key as a string.
TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) {
  return std::string(mod->type_key());
});

// Forwards to mod->GetFormat().
TVM_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) {
  return mod->GetFormat();
});

// Exposes Module::LoadFromFile(file_name, format).
TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);

// Forwards to mod->SaveToFile(name, fmt).
TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
    .set_body_typed([](Module mod, String name, tvm::String fmt) { mod->SaveToFile(name, fmt); });

// Forwards to mod->IsDSOExportable().
TVM_REGISTER_GLOBAL("runtime.ModuleIsDSOExportable").set_body_typed([](Module mod) {
  return mod->IsDSOExportable();
});

// Forwards to mod->ImplementsFunction(name, query_imports).
TVM_REGISTER_GLOBAL("runtime.ModuleImplementsFunction")
    .set_body_typed([](Module mod, String name, bool query_imports) {
      return mod->ImplementsFunction(std::move(name), query_imports);
    });

// Registers ModuleNode with the object type system.
TVM_REGISTER_OBJECT_TYPE(ModuleNode);
212} // namespace runtime
213} // namespace tvm
214