1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/module.h
 * \brief Runtime container of the functions generated by TVM.
 * This is used to support dynamically linking, loading and saving
 * functions from different conventions under a unified API.
25 */
26#ifndef TVM_RUNTIME_MODULE_H_
27#define TVM_RUNTIME_MODULE_H_
28
#include <dmlc/io.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>

#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
40
41namespace tvm {
42namespace runtime {
43
// Forward declarations. This header includes <tvm/runtime/packed_func.h> only
// at its very end; PackedFunc is presumably defined there, and the late include
// presumably breaks the mutual dependency between Module and PackedFunc — confirm.
class ModuleNode;
class PackedFunc;
46
47/*!
48 * \brief Module container of TVM.
49 */
50class Module : public ObjectRef {
51 public:
52 Module() {}
53 // constructor from container.
54 explicit Module(ObjectPtr<Object> n) : ObjectRef(n) {}
55 /*!
56 * \brief Get packed function from current module by name.
57 *
58 * \param name The name of the function.
59 * \param query_imports Whether also query dependency modules.
60 * \return The result function.
61 * This function will return PackedFunc(nullptr) if function do not exist.
62 * \note Implemented in packed_func.cc
63 */
64 inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
65 // The following functions requires link with runtime.
66 /*!
67 * \brief Import another module into this module.
68 * \param other The module to be imported.
69 *
70 * \note Cyclic dependency is not allowed among modules,
71 * An error will be thrown when cyclic dependency is detected.
72 */
73 inline void Import(Module other);
74 /*! \return internal container */
75 inline ModuleNode* operator->();
76 /*! \return internal container */
77 inline const ModuleNode* operator->() const;
78 /*!
79 * \brief Load a module from file.
80 * \param file_name The name of the host function module.
81 * \param format The format of the file.
82 * \note This function won't load the import relationship.
83 * Re-create import relationship by calling Import.
84 */
85 TVM_DLL static Module LoadFromFile(const std::string& file_name, const std::string& format = "");
86 // refer to the corresponding container.
87 using ContainerType = ModuleNode;
88 friend class ModuleNode;
89};
90
/*!
 * \brief Base container of module.
 *
 * Please subclass ModuleNode to create a specific runtime module.
 *
 * \code
 *
 * class MyModuleNode : public ModuleNode {
 *  public:
 *   // implement the interface
 * };
 *
 * // use make_object to create a specific
 * // instance of MyModuleNode.
 * Module CreateMyModule() {
 *   ObjectPtr<MyModuleNode> n =
 *     tvm::runtime::make_object<MyModuleNode>();
 *   return Module(n);
 * }
 *
 * \endcode
 */
class TVM_DLL ModuleNode : public Object {
 public:
  /*! \brief virtual destructor */
  virtual ~ModuleNode() = default;
  /*!
   * \return The per module type key.
   * \note This key is used for serializing custom modules.
   */
  virtual const char* type_key() const = 0;
  /*!
   * \brief Get a PackedFunc from module.
   *
   * The PackedFunc may not be fully initialized;
   * there might still be first-time running overhead when
   * executing the function on certain devices.
   * For benchmarking, warm the function up (e.g. call it once) before
   * timing so that this first-time overhead is excluded.
   *
   * \param name the name of the function.
   * \param sptr_to_self The ObjectPtr that points to this module node.
   *
   * \return PackedFunc(nullptr) when it is not available.
   *
   * \note The function will always remain valid.
   * If the function needs resources from the module (e.g. late linking),
   * it should capture sptr_to_self.
   */
  virtual PackedFunc GetFunction(const std::string& name,
                                 const ObjectPtr<Object>& sptr_to_self) = 0;
  /*!
   * \brief Save the module to file.
   * \param file_name The file to be saved to.
   * \param format The format of the file.
   */
  virtual void SaveToFile(const std::string& file_name, const std::string& format);
  /*!
   * \brief Save the module to binary stream.
   * \param stream The binary stream to save to.
   * \note It is recommended to implement this for device modules,
   *   but not necessarily host modules.
   *   We can use this to do AOT loading of bundled device functions.
   */
  virtual void SaveToBinary(dmlc::Stream* stream);
  /*!
   * \brief Get the source code of module, when available.
   * \param format Format of the source code, can be empty by default.
   * \return Possible source code when available.
   */
  virtual std::string GetSource(const std::string& format = "");
  /*!
   * \brief Get the format of the module, when available.
   * \return Possible format when available.
   */
  virtual std::string GetFormat();
  /*!
   * \brief Get packed function from current module by name.
   *
   * \param name The name of the function.
   * \param query_imports Whether to also query dependency (imported) modules.
   * \return The result function.
   *  This function will return PackedFunc(nullptr) if the function does not exist.
   * \note Implemented in packed_func.cc
   */
  PackedFunc GetFunction(const std::string& name, bool query_imports = false);
  /*!
   * \brief Import another module into this module.
   * \param other The module to be imported.
   *
   * \note Cyclic dependency is not allowed among modules;
   *  an error will be thrown when a cyclic dependency is detected.
   */
  void Import(Module other);
  /*!
   * \brief Get a function from the current environment.
   *  The environment includes all the imports as well as global functions.
   *
   * \param name name of the function.
   * \return The corresponding function.
   */
  const PackedFunc* GetFuncFromEnv(const std::string& name);
  /*! \return The modules this module imports from. */
  const std::vector<Module>& imports() const { return imports_; }

  /*!
   * \brief Returns true if this module is 'DSO exportable'.
   *
   * A DSO exportable module (e.g. a CSourceModuleNode of type_key 'c') can be incorporated into the
   * final runtime artifact (i.e. shared library) by compilation and/or linking using the external
   * compiler (llvm, nvcc, etc). DSO exportable modules must implement SaveToFile.
   *
   * By contrast, non-DSO exportable modules (e.g. CUDAModuleNode of type_key 'cuda') typically must
   * be incorporated into the final runtime artifact by being serialized as data into the
   * artifact, then deserialized at runtime. Non-DSO exportable modules must implement SaveToBinary,
   * and have a matching deserializer registered as 'runtime.module.loadbinary_<type_key>'.
   *
   * The default implementation returns false.
   */
  virtual bool IsDSOExportable() const;

  /*!
   * \brief Returns true if this module has a definition for a function of \p name. If
   * \p query_imports is true, also search in any imported modules.
   *
   * Note that even if this function returns true the corresponding \p GetFunction result may be
   * nullptr if the function is not yet callable without further compilation.
   *
   * The default implementation just checks if \p GetFunction is non-null.
   */
  virtual bool ImplementsFunction(const String& name, bool query_imports = false);

  // integration with the existing components.
  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
  static constexpr const char* _type_key = "runtime.Module";
  // NOTE: ModuleNode can still be sub-classed in C++ (see the example above)
  // despite the FINAL object-info declaration below; presumably "final" only
  // concerns the runtime type-index hierarchy, so subclasses report
  // ModuleNode's type index — confirm against object.h.
  TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);

 protected:
  friend class Module;
  friend class ModuleInternal;
  /*! \brief The modules this module depends on */
  std::vector<Module> imports_;

 private:
  /*!
   * \brief Cache of resolved functions, keyed by string.
   *  Presumably populated by GetFuncFromEnv with functions found in imports or
   *  the global registry — confirm in the implementation.
   */
  std::unordered_map<std::string, std::shared_ptr<PackedFunc>> import_cache_;
  // NOTE(review): presumably guards concurrent access to import_cache_ — confirm.
  std::mutex mutex_;
};
240
/*!
 * \brief Check if the runtime module is enabled for a target.
 * \param target The target module name (presumably a backend/device name such as
 *        a module type key — confirm against the implementation).
 * \return Whether the runtime is enabled for \p target.
 */
TVM_DLL bool RuntimeEnabled(const std::string& target);
247
/*!
 * \brief Namespace for constant symbol names.
 *
 * Each entry is a C-string naming a global variable, a function,
 * or a symbol prefix/suffix.
 */
namespace symbol {
// ---- Global variables ----
/*! \brief Name of the global variable that stores the module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Name of the global variable that stores the device module blob. */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
/*! \brief Name of the symbol holding the number of bytes of the device module blob. */
constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes";
/*! \brief Name of the auxiliary counter used by the global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";

// ---- Functions ----
/*! \brief Name of the global function that sets the device. */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Name of the function that prepares the global barrier before kernels that use it. */
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
/*! \brief Placeholder name for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
/*! \brief Name of the PackedFunc that retrieves exported metadata. */
constexpr const char* tvm_get_c_metadata = "get_c_metadata";
/*! \brief Name of the PackedFunc that looks up linked parameters by storage_id. */
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";

// ---- Prefixes and suffixes ----
/*! \brief Prefix for parameter symbols emitted into the main program. */
constexpr const char* tvm_param_prefix = "__tvm_param__";
/*! \brief Model entrypoint suffix, generated as an interface to the AOT function outside of TIR. */
constexpr const char* tvm_entrypoint_suffix = "run";
}  // namespace symbol
273
274// implementations of inline functions.
275
276inline void Module::Import(Module other) { return (*this)->Import(other); }
277
278inline ModuleNode* Module::operator->() { return static_cast<ModuleNode*>(get_mutable()); }
279
280inline const ModuleNode* Module::operator->() const {
281 return static_cast<const ModuleNode*>(get());
282}
283
284inline std::ostream& operator<<(std::ostream& out, const Module& module) {
285 out << "Module(type_key= ";
286 out << module->type_key();
287 out << ")";
288
289 return out;
290}
291
292} // namespace runtime
293} // namespace tvm
294
295#include <tvm/runtime/packed_func.h> // NOLINT(*)
296#endif // TVM_RUNTIME_MODULE_H_
297