1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/runtime/module.h |
 * \brief Runtime container of the functions generated by TVM.
 *  This is used to support dynamic linking, loading and saving of
 *  functions from different conventions under a unified API.
25 | */ |
26 | #ifndef TVM_RUNTIME_MODULE_H_ |
27 | #define TVM_RUNTIME_MODULE_H_ |
28 | |
#include <dmlc/io.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>

#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
40 | |
41 | namespace tvm { |
42 | namespace runtime { |
43 | |
44 | class ModuleNode; |
45 | class PackedFunc; |
46 | |
47 | /*! |
48 | * \brief Module container of TVM. |
49 | */ |
50 | class Module : public ObjectRef { |
51 | public: |
52 | Module() {} |
53 | // constructor from container. |
54 | explicit Module(ObjectPtr<Object> n) : ObjectRef(n) {} |
55 | /*! |
56 | * \brief Get packed function from current module by name. |
57 | * |
58 | * \param name The name of the function. |
59 | * \param query_imports Whether also query dependency modules. |
60 | * \return The result function. |
61 | * This function will return PackedFunc(nullptr) if function do not exist. |
62 | * \note Implemented in packed_func.cc |
63 | */ |
64 | inline PackedFunc GetFunction(const std::string& name, bool query_imports = false); |
65 | // The following functions requires link with runtime. |
66 | /*! |
67 | * \brief Import another module into this module. |
68 | * \param other The module to be imported. |
69 | * |
70 | * \note Cyclic dependency is not allowed among modules, |
71 | * An error will be thrown when cyclic dependency is detected. |
72 | */ |
73 | inline void Import(Module other); |
74 | /*! \return internal container */ |
75 | inline ModuleNode* operator->(); |
76 | /*! \return internal container */ |
77 | inline const ModuleNode* operator->() const; |
78 | /*! |
79 | * \brief Load a module from file. |
80 | * \param file_name The name of the host function module. |
81 | * \param format The format of the file. |
82 | * \note This function won't load the import relationship. |
83 | * Re-create import relationship by calling Import. |
84 | */ |
85 | TVM_DLL static Module LoadFromFile(const std::string& file_name, const std::string& format = "" ); |
86 | // refer to the corresponding container. |
87 | using ContainerType = ModuleNode; |
88 | friend class ModuleNode; |
89 | }; |
90 | |
91 | /*! |
92 | * \brief Base container of module. |
93 | * |
94 | * Please subclass ModuleNode to create a specific runtime module. |
95 | * |
96 | * \code |
97 | * |
98 | * class MyModuleNode : public ModuleNode { |
99 | * public: |
100 | * // implement the interface |
101 | * }; |
102 | * |
103 | * // use make_object to create a specific |
104 | * // instace of MyModuleNode. |
105 | * Module CreateMyModule() { |
106 | * ObjectPtr<MyModuleNode> n = |
107 | * tvm::runtime::make_object<MyModuleNode>(); |
108 | * return Module(n); |
109 | * } |
110 | * |
111 | * \endcode |
112 | */ |
113 | class TVM_DLL ModuleNode : public Object { |
114 | public: |
115 | /*! \brief virtual destructor */ |
116 | virtual ~ModuleNode() = default; |
117 | /*! |
118 | * \return The per module type key. |
119 | * \note This key is used to for serializing custom modules. |
120 | */ |
121 | virtual const char* type_key() const = 0; |
122 | /*! |
123 | * \brief Get a PackedFunc from module. |
124 | * |
125 | * The PackedFunc may not be fully initialized, |
126 | * there might still be first time running overhead when |
127 | * executing the function on certain devices. |
128 | * For benchmarking, use prepare to eliminate |
129 | * |
130 | * \param name the name of the function. |
131 | * \param sptr_to_self The ObjectPtr that points to this module node. |
132 | * |
133 | * \return PackedFunc(nullptr) when it is not available. |
134 | * |
135 | * \note The function will always remain valid. |
136 | * If the function need resource from the module(e.g. late linking), |
137 | * it should capture sptr_to_self. |
138 | */ |
139 | virtual PackedFunc GetFunction(const std::string& name, |
140 | const ObjectPtr<Object>& sptr_to_self) = 0; |
141 | /*! |
142 | * \brief Save the module to file. |
143 | * \param file_name The file to be saved to. |
144 | * \param format The format of the file. |
145 | */ |
146 | virtual void SaveToFile(const std::string& file_name, const std::string& format); |
147 | /*! |
148 | * \brief Save the module to binary stream. |
149 | * \param stream The binary stream to save to. |
150 | * \note It is recommended to implement this for device modules, |
151 | * but not necessarily host modules. |
152 | * We can use this to do AOT loading of bundled device functions. |
153 | */ |
154 | virtual void SaveToBinary(dmlc::Stream* stream); |
155 | /*! |
156 | * \brief Get the source code of module, when available. |
157 | * \param format Format of the source code, can be empty by default. |
158 | * \return Possible source code when available. |
159 | */ |
160 | virtual std::string GetSource(const std::string& format = "" ); |
161 | /*! |
162 | * \brief Get the format of the module, when available. |
163 | * \return Possible format when available. |
164 | */ |
165 | virtual std::string GetFormat(); |
166 | /*! |
167 | * \brief Get packed function from current module by name. |
168 | * |
169 | * \param name The name of the function. |
170 | * \param query_imports Whether also query dependency modules. |
171 | * \return The result function. |
172 | * This function will return PackedFunc(nullptr) if function do not exist. |
173 | * \note Implemented in packed_func.cc |
174 | */ |
175 | PackedFunc GetFunction(const std::string& name, bool query_imports = false); |
176 | /*! |
177 | * \brief Import another module into this module. |
178 | * \param other The module to be imported. |
179 | * |
180 | * \note Cyclic dependency is not allowed among modules, |
181 | * An error will be thrown when cyclic dependency is detected. |
182 | */ |
183 | void Import(Module other); |
184 | /*! |
185 | * \brief Get a function from current environment |
186 | * The environment includes all the imports as well as Global functions. |
187 | * |
188 | * \param name name of the function. |
189 | * \return The corresponding function. |
190 | */ |
191 | const PackedFunc* GetFuncFromEnv(const std::string& name); |
192 | /*! \return The module it imports from */ |
193 | const std::vector<Module>& imports() const { return imports_; } |
194 | |
195 | /*! |
196 | * \brief Returns true if this module is 'DSO exportable'. |
197 | * |
198 | * A DSO exportable module (eg a CSourceModuleNode of type_key 'c') can be incorporated into the |
199 | * final runtime artifact (ie shared library) by compilation and/or linking using the external |
200 | * compiler (llvm, nvcc, etc). DSO exportable modules must implement SaveToFile. |
201 | * |
202 | * By contrast, non-DSO exportable modules (eg CUDAModuleNode of type_key 'cuda') typically must |
203 | * be incorporated into the final runtime artifact by being serialized as data into the |
204 | * artifact, then deserialized at runtime. Non-DSO exportable modules must implement SaveToBinary, |
205 | * and have a matching deserializer registered as 'runtime.module.loadbinary_<type_key>'. |
206 | * |
207 | * The default implementation returns false. |
208 | */ |
209 | virtual bool IsDSOExportable() const; |
210 | |
211 | /*! |
212 | * \brief Returns true if this module has a definition for a function of \p name. If |
213 | * \p query_imports is true, also search in any imported modules. |
214 | * |
215 | * Note that even if this function returns true the corresponding \p GetFunction result may be |
216 | * nullptr if the function is not yet callable without further compilation. |
217 | * |
218 | * The default implementation just checkis if \p GetFunction is non-null. |
219 | */ |
220 | virtual bool ImplementsFunction(const String& name, bool query_imports = false); |
221 | |
222 | // integration with the existing components. |
223 | static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule; |
224 | static constexpr const char* _type_key = "runtime.Module" ; |
225 | // NOTE: ModuleNode can still be sub-classed |
226 | // |
227 | TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object); |
228 | |
229 | protected: |
230 | friend class Module; |
231 | friend class ModuleInternal; |
232 | /*! \brief The modules this module depend on */ |
233 | std::vector<Module> imports_; |
234 | |
235 | private: |
236 | /*! \brief Cache used by GetImport */ |
237 | std::unordered_map<std::string, std::shared_ptr<PackedFunc>> import_cache_; |
238 | std::mutex mutex_; |
239 | }; |
240 | |
/*!
 * \brief Check if the runtime module is enabled for a target.
 * \param target The target module name.
 * \return Whether the runtime is enabled.
 */
TVM_DLL bool RuntimeEnabled(const std::string& target);
247 | |
/*! \brief Namespace holding the constant symbol names used by the runtime. */
namespace symbol {
/*! \brief Name of a PackedFunc that retrieves exported metadata. */
constexpr const char* tvm_get_c_metadata = "get_c_metadata";
/*! \brief Name of the global variable that stores the module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Name of the global variable that stores the device module blob. */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
/*! \brief Name of the symbol giving the number of bytes of the device module blob. */
constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes";
/*! \brief Name of the global function used to set the device. */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Name of the auxiliary counter for the global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
/*! \brief Name of the function that prepares the global barrier before kernels that use it. */
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
/*! \brief Placeholder name for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
/*! \brief Prefix for parameter symbols emitted into the main program. */
constexpr const char* tvm_param_prefix = "__tvm_param__";
/*! \brief Name of a PackedFunc that looks up linked parameters by storage_id. */
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
/*! \brief Suffix of the model entrypoint generated as an interface to the AOT function outside of TIR. */
constexpr const char* tvm_entrypoint_suffix = "run";
}  // namespace symbol
273 | |
274 | // implementations of inline functions. |
275 | |
276 | inline void Module::Import(Module other) { return (*this)->Import(other); } |
277 | |
278 | inline ModuleNode* Module::operator->() { return static_cast<ModuleNode*>(get_mutable()); } |
279 | |
280 | inline const ModuleNode* Module::operator->() const { |
281 | return static_cast<const ModuleNode*>(get()); |
282 | } |
283 | |
284 | inline std::ostream& operator<<(std::ostream& out, const Module& module) { |
285 | out << "Module(type_key= " ; |
286 | out << module->type_key(); |
287 | out << ")" ; |
288 | |
289 | return out; |
290 | } |
291 | |
292 | } // namespace runtime |
293 | } // namespace tvm |
294 | |
295 | #include <tvm/runtime/packed_func.h> // NOLINT(*) |
296 | #endif // TVM_RUNTIME_MODULE_H_ |
297 | |