1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/ir/module.h |
22 | * \brief IRModule that holds the functions and type definitions. |
23 | */ |
24 | #ifndef TVM_IR_MODULE_H_ |
25 | #define TVM_IR_MODULE_H_ |
26 | |
27 | #include <tvm/ir/adt.h> |
28 | #include <tvm/ir/expr.h> |
29 | #include <tvm/ir/function.h> |
30 | #include <tvm/ir/source_map.h> |
31 | #include <tvm/ir/type.h> |
32 | #include <tvm/runtime/container/array.h> |
33 | #include <tvm/runtime/container/map.h> |
34 | #include <tvm/runtime/container/string.h> |
35 | |
36 | #include <string> |
37 | #include <unordered_map> |
38 | #include <unordered_set> |
39 | #include <utility> |
40 | #include <vector> |
41 | |
42 | namespace tvm { |
43 | |
// Forward declaration: IRModuleNode below names IRModule (in Update, ShallowCopy and its
// friend declaration) before the IRModule reference class itself is defined.
class IRModule;
45 | |
46 | /*! |
47 | * \brief IRModule that holds functions and type definitions. |
48 | * |
49 | * IRModule is the basic unit for all IR transformations across the stack. |
50 | * |
51 | * Many operations require access to the global IRModule. |
52 | * We pass the IRModule by value in a functional style as an explicit argument, |
53 | * but we mutate the Module while optimizing programs. |
54 | * \sa IRModule |
55 | */ |
56 | class IRModuleNode : public Object { |
57 | public: |
58 | /*! \brief A map from ids to all global functions. */ |
59 | Map<GlobalVar, BaseFunc> functions; |
60 | /*! \brief A map from global type vars to ADT type data. */ |
61 | Map<GlobalTypeVar, TypeData> type_definitions; |
62 | /*! \brief The source map for the module. */ |
63 | SourceMap source_map; |
64 | /* \brief Additional attributes storing meta-data about the module. */ |
65 | DictAttrs attrs; |
66 | /*! |
67 | * \brief A map from string names to global variables that |
68 | * ensures global uniqueness. |
69 | */ |
70 | Map<String, GlobalVar> global_var_map_; |
71 | |
72 | /*! \brief A map from string names to global type variables (ADT names) |
73 | * that ensures global uniqueness. |
74 | */ |
75 | Map<String, GlobalTypeVar> global_type_var_map_; |
76 | |
77 | /*! \brief A map from constructor tags to constructor objects |
78 | * for convenient access |
79 | */ |
80 | std::unordered_map<int32_t, Constructor> constructor_tag_map_; |
81 | |
82 | /*! \brief The files previously imported, required to ensure |
83 | importing is idempotent for each module. |
84 | */ |
85 | std::unordered_set<String> import_set_; |
86 | |
87 | /*! |
88 | * \brief Get a module attribute. |
89 | * |
90 | * \param attr_key The attribute key. |
91 | * \param default_value The default value if the key does not exist, defaults to nullptr. |
92 | * |
93 | * \return The result |
94 | * |
95 | * \tparam TOBjectRef the expected object type. |
96 | * \throw Error if the key exists but the value does not match TObjectRef |
97 | * |
98 | * \code |
99 | * |
100 | * void GetAttrExample(const IRModule& mod) { |
101 | * auto value = f->GetAttr<Integer>("AttrKey", 0); |
102 | * } |
103 | * |
104 | * \endcode |
105 | */ |
106 | template <typename TObjectRef> |
107 | Optional<TObjectRef> GetAttr( |
108 | const std::string& attr_key, |
109 | Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const { |
110 | return attrs.GetAttr(attr_key, default_value); |
111 | } |
112 | // variant that uses TObjectRef to enable implicit conversion to default value. |
113 | template <typename TObjectRef> |
114 | Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const { |
115 | return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value)); |
116 | } |
117 | |
118 | /*! |
119 | * \brief Check whether the module has an non-zero integer attr. |
120 | * |
121 | * This function can be used to check whether an optional |
122 | * attribute mark(e.g. inline) exists. |
123 | * |
124 | * \param attr_key The key to the attribute. |
125 | * \return The check result. |
126 | * |
127 | * \code |
128 | * |
129 | * void HasNonzeroAttrExample(const IRModule& mod) { |
130 | * if (mod->HasNonzeroAttr(attr::kInline)) { |
131 | * // inline the function. |
132 | * } |
133 | * } |
134 | * |
135 | * \endcode |
136 | */ |
137 | bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); } |
138 | |
139 | IRModuleNode() : source_map() {} |
140 | |
141 | void VisitAttrs(AttrVisitor* v) { |
142 | v->Visit("functions" , &functions); |
143 | v->Visit("type_definitions" , &type_definitions); |
144 | v->Visit("global_var_map_" , &global_var_map_); |
145 | v->Visit("global_type_var_map_" , &global_type_var_map_); |
146 | v->Visit("source_map" , &source_map); |
147 | v->Visit("attrs" , &attrs); |
148 | } |
149 | |
150 | TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const; |
151 | |
152 | TVM_DLL void SHashReduce(SHashReducer hash_reduce) const; |
153 | |
154 | /*! |
155 | * \brief Add a function to the global environment. |
156 | * \param var The var of the global function. |
157 | * \param func The function. |
158 | * \param update Controls whether you can replace a definition in the |
159 | * environment. |
160 | */ |
161 | TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false); |
162 | |
163 | /*! |
164 | * \brief Add a function to the global environment. |
165 | * \param var The name of the global function. |
166 | * \param func The function. |
167 | * |
168 | * It does not do type inference as Add does. |
169 | */ |
170 | TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func); |
171 | |
172 | /*! |
173 | * \brief Add a type-level definition to the global environment. |
174 | * \param var The var of the global type definition. |
175 | * \param type The ADT. |
176 | * \param update Controls whether you can replace a definition in the |
177 | * environment. |
178 | */ |
179 | TVM_DLL void AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update = false); |
180 | |
181 | /*! |
182 | * \brief Add a type-level definition to the global environment. |
183 | * \param var The var of the global type definition. |
184 | * \param type The ADT. |
185 | * \param update Controls whether you can replace a definition in the |
186 | * environment. |
187 | * |
188 | * It does not do type checking as AddTypeDef does. |
189 | */ |
190 | TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, |
191 | bool update = false); |
192 | |
193 | /*! |
194 | * \brief Update a function in the global environment. |
195 | * \param var The name of the global function to update. |
196 | * \param func The new function. |
197 | */ |
198 | TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func); |
199 | |
200 | /*! |
201 | * \brief Update a type definition in the global environment. |
202 | * \param var The name of the global type definition to update. |
203 | * \param type The new ADT. |
204 | */ |
205 | TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type); |
206 | |
207 | /*! |
208 | * \brief Remove a function from the global environment. |
209 | * \param var The name of the global function to update. |
210 | */ |
211 | TVM_DLL void Remove(const GlobalVar& var); |
212 | |
213 | /*! |
214 | * \brief Check if the global_var_map_ contains a global variable. |
215 | * \param name The variable name. |
216 | * \returns true if contains, otherise false. |
217 | */ |
218 | TVM_DLL bool ContainGlobalVar(const String& name) const; |
219 | |
220 | /*! |
221 | * \brief Check if the global_type_var_map_ contains a global type variable. |
222 | * \param name The variable name. |
223 | * \returns true if contains, otherise false. |
224 | */ |
225 | TVM_DLL bool ContainGlobalTypeVar(const String& name) const; |
226 | |
227 | /*! |
228 | * \brief Lookup a global function by its variable. |
229 | * \param str The unique string specifying the global variable. |
230 | * \returns The global variable. |
231 | */ |
232 | TVM_DLL GlobalVar GetGlobalVar(const String& str) const; |
233 | |
234 | /*! |
235 | * \brief Collect all global vars defined in this module. |
236 | * \returns An array of global vars |
237 | */ |
238 | TVM_DLL Array<GlobalVar> GetGlobalVars() const; |
239 | |
240 | /*! |
241 | * \brief Look up a global function by its name. |
242 | * \param str The unique string specifying the global variable. |
243 | * \returns The global variable. |
244 | */ |
245 | TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const; |
246 | |
247 | /*! |
248 | * \brief Collect all global type vars defined in this module. |
249 | * \returns An array of global type vars |
250 | */ |
251 | TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const; |
252 | |
253 | /*! |
254 | * \brief Find constructor of ADT using name |
255 | * \param adt name of the ADT the constructor belongs to |
256 | * \param cons name of the constructor |
257 | * \returns Constructor of ADT, error if not found |
258 | */ |
259 | TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const; |
260 | |
261 | /*! |
262 | * \brief Look up a global function by its variable. |
263 | * \param var The global var to lookup. |
264 | * \returns The function named by the variable argument. |
265 | */ |
266 | TVM_DLL BaseFunc Lookup(const GlobalVar& var) const; |
267 | |
268 | /*! |
269 | * \brief Look up a global function by its string name |
270 | * \param name The name of the function. |
271 | * \returns The function named by the argument. |
272 | */ |
273 | TVM_DLL BaseFunc Lookup(const String& name) const; |
274 | |
275 | /*! |
276 | * \brief Look up a global type definition by its variable. |
277 | * \param var The var of the global type definition. |
278 | * \return The type definition. |
279 | */ |
280 | TVM_DLL TypeData LookupTypeDef(const GlobalTypeVar& var) const; |
281 | |
282 | /*! |
283 | * \brief Look up a global type definition by its name. |
284 | * \param var The name of the global type definition. |
285 | * \return The type definition. |
286 | */ |
287 | TVM_DLL TypeData LookupTypeDef(const String& var) const; |
288 | |
289 | /*! |
290 | * \brief Look up a constructor by its tag. |
291 | * \param tag The tag for the constructor. |
292 | * \return The constructor object. |
293 | */ |
294 | TVM_DLL Constructor LookupTag(const int32_t tag); |
295 | |
296 | /*! |
297 | * \brief Update the functions inside this environment by |
298 | * functions in another environment. |
299 | * \param other The other environment. |
300 | */ |
301 | TVM_DLL void Update(const IRModule& other); |
302 | |
303 | /*! |
304 | * \brief Create a shallow copy of this IRModule. |
305 | * \returns The shallow copy of the IRModule. |
306 | */ |
307 | TVM_DLL IRModule ShallowCopy(); |
308 | |
309 | /*! |
310 | * \brief Import Relay code from the file at path. |
311 | * \param path The path of the Relay code to import. |
312 | * |
313 | * \note The path resolution behavior is standard, |
314 | * if abosolute will be the absolute file, if |
315 | * relative it will be resovled against the current |
316 | * working directory. |
317 | */ |
318 | TVM_DLL void Import(const String& path); |
319 | |
320 | /*! |
321 | * \brief Import Relay code from the file at path, relative to the standard library. |
322 | * \param path The path of the Relay code to import. |
323 | */ |
324 | TVM_DLL void ImportFromStd(const String& path); |
325 | |
326 | /*! |
327 | * \brief The set of imported files. |
328 | */ |
329 | TVM_DLL std::unordered_set<String> Imports() const; |
330 | |
331 | TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); |
332 | |
333 | static constexpr const char* _type_key = "IRModule" ; |
334 | static constexpr const bool _type_has_method_sequal_reduce = true; |
335 | static constexpr const bool _type_has_method_shash_reduce = true; |
336 | TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object); |
337 | |
338 | private: |
339 | /*! \brief Helper function for registering a typedef's constructors */ |
340 | void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type); |
341 | friend class IRModule; |
342 | }; |
343 | |
344 | /*! |
345 | * \brief Managed reference class to IRModuleNode. |
346 | * \sa IRModuleNode |
347 | */ |
348 | class IRModule : public ObjectRef { |
349 | public: |
350 | /*! |
351 | * \brief constructor |
352 | * \param functions Functions in the module. |
353 | * \param type_definitions Type definitions in the module. |
354 | * \param import_set Set of imported files in the module. |
355 | * \param map The module source map. |
356 | * \param attrs The module attributes. |
357 | */ |
358 | TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions, |
359 | Map<GlobalTypeVar, TypeData> type_definitions = {}, |
360 | std::unordered_set<String> import_set = {}, SourceMap map = {}, |
361 | DictAttrs attrs = {}); |
362 | |
363 | /*! \brief default constructor */ |
364 | IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {} |
365 | /*! |
366 | * \brief constructor |
367 | * \param n The object pointer. |
368 | */ |
369 | explicit IRModule(ObjectPtr<Object> n) : ObjectRef(n) {} |
370 | /*! \return mutable pointers to the node. */ |
371 | IRModuleNode* operator->() const { |
372 | auto* ptr = get_mutable(); |
373 | ICHECK(ptr != nullptr); |
374 | return static_cast<IRModuleNode*>(ptr); |
375 | } |
376 | |
377 | /*! |
378 | * \brief Constructs a module from a standalone expression \p expr. |
379 | * |
380 | * If \p expr is a function it will be bound directly. Otherwise a function over the free |
381 | * variables of \p expr (possibly none) with \p expr as body is created and bound. |
382 | * |
383 | * The function is bound to, in preference order: |
384 | * - The "global_symbol" attribute of \p expr, if it is a function with that attribute. |
385 | * - 'main' |
386 | * - A unique name derived from 'main' if 'main' is already bound in \p global_funcs. |
387 | * |
388 | * Additional global functions and type definitions may be included in the result module. |
389 | * |
390 | * See also \p FromExpr. |
391 | * |
392 | * \param expr The expression to set as the main function to the module. |
393 | * \param global_funcs The global function map. Default empty. |
394 | * \param type_definitions The global type definition map. Default empty. |
395 | * \param import_set Set of external modules already imported. Default empty. |
396 | * |
397 | * \returns A module with \p expr set as the main function, and the global var to which |
398 | * \p expr was bound (typcially 'main'). |
399 | * |
400 | * TODO(mbs): Does import_set and the bound global var need to be exposed via ffi? |
401 | */ |
402 | static std::pair<IRModule, GlobalVar> FromExprInContext( |
403 | const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {}, |
404 | const Map<GlobalTypeVar, TypeData>& type_definitions = {}, |
405 | std::unordered_set<String> import_set = {}); |
406 | |
407 | /*! |
408 | * \brief As for \p FromExprInContext, but assuming \p expr is bound to 'main' and no |
409 | * imports. |
410 | */ |
411 | TVM_DLL static IRModule FromExpr(const RelayExpr& expr, |
412 | const Map<GlobalVar, BaseFunc>& global_funcs = {}, |
413 | const Map<GlobalTypeVar, TypeData>& type_definitions = {}); |
414 | |
415 | /*! |
416 | * \brief Parse text format source file into an IRModule. |
417 | * \param text A string of Relay source code. |
418 | * \param source_path The path to the source file. |
419 | * \return A Relay module. |
420 | */ |
421 | TVM_DLL static IRModule FromText(const String& text, const String& source_path); |
422 | |
423 | /*! |
424 | * \brief Create a shallow copy of an IRModule. |
425 | * \param mod The module to copy. |
426 | * \return The copied module. |
427 | */ |
428 | IRModule ShallowCopyIRModule(IRModule mod); |
429 | |
430 | /*! \brief Declare the container type. */ |
431 | using ContainerType = IRModuleNode; |
432 | |
433 | /*! \brief Declare whether Ref is nullable. */ |
434 | static constexpr bool _type_is_nullable = false; |
435 | |
436 | // allow copy on write. |
437 | TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode); |
438 | }; |
439 | |
namespace attr {

// Following are attributes for IRModule only.

/*!
 * \brief Name of the module
 *
 * Type: String
 *
 * \sa tvm::runtime::String
 */
constexpr const char* kModuleName = "mod_name";

/*!
 * \brief Executor targeted by the module
 *
 * Type: Executor
 *
 * \sa tvm::relay::Executor
 */
constexpr const char* kExecutor = "executor";

/*!
 * \brief Runtime target of the module
 *
 * Type: Runtime
 *
 * \sa tvm::relay::Runtime
 */
constexpr const char* kRuntime = "runtime";

/*!
 * \brief Workspace memory pools of the module
 *
 * Type: WorkspaceMemoryPools
 *
 * \sa tvm::WorkspaceMemoryPools
 */
constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools";

/*!
 * \brief Constant memory pools of the module
 *
 * Type: ConstantMemoryPools
 *
 * \sa tvm::ConstantMemoryPools
 */
constexpr const char* kConstantMemoryPools = "constant_memory_pools";

/*!
 * \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The
 * node will record the index into this array. See also kConstNameToConstant below, which is
 * the analog for Relay Functions.
 *
 * Type: Array<runtime::NDArray>
 */
constexpr const char* kConstants = "constants";

/*!
 * \brief All the runtime::Modules accumulated during compilation by external codegen. These
 * modules must be either directly linked or captured in the final compilation artifact.
 *
 * Type: Array<runtime::Module>
 */
constexpr const char* kExternalMods = "external_mods";

/*!
 * \brief All the named runtime::NDArrays accumulated during compilation by external codegen.
 * Generally the associated runtime::Module will indicate it requires bindings for these names,
 * and during module initialization these bindings will be recovered from a ConstLoaderModule.
 * See also kConstants above, which is the analog for PrimFuncs.
 *
 * Type: Map<String, runtime::NDArray>
 */
constexpr const char* kConstNameToConstant = "const_name_to_constant";

}  // namespace attr
517 | } // namespace tvm |
518 | #endif // TVM_IR_MODULE_H_ |
519 | |