1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/ir/module.h
22 * \brief IRModule that holds the functions and type definitions.
23 */
24#ifndef TVM_IR_MODULE_H_
25#define TVM_IR_MODULE_H_
26
27#include <tvm/ir/adt.h>
28#include <tvm/ir/expr.h>
29#include <tvm/ir/function.h>
30#include <tvm/ir/source_map.h>
31#include <tvm/ir/type.h>
32#include <tvm/runtime/container/array.h>
33#include <tvm/runtime/container/map.h>
34#include <tvm/runtime/container/string.h>
35
36#include <string>
37#include <unordered_map>
38#include <unordered_set>
39#include <utility>
40#include <vector>
41
namespace tvm {

// Forward declaration: IRModuleNode (defined below) already accepts and returns
// IRModule values, e.g. Update(const IRModule&) and ShallowCopy().
class IRModule;
45
/*!
 * \brief IRModule that holds functions and type definitions.
 *
 * IRModule is the basic unit for all IR transformations across the stack.
 *
 * Many operations require access to the global IRModule.
 * We pass the IRModule by value in a functional style as an explicit argument,
 * but we mutate the Module while optimizing programs.
 * \sa IRModule
 */
class IRModuleNode : public Object {
 public:
  /*! \brief A map from ids to all global functions. */
  Map<GlobalVar, BaseFunc> functions;
  /*! \brief A map from global type vars to ADT type data. */
  Map<GlobalTypeVar, TypeData> type_definitions;
  /*! \brief The source map for the module. */
  SourceMap source_map;
  /*! \brief Additional attributes storing meta-data about the module. */
  DictAttrs attrs;
  /*!
   * \brief A map from string names to global variables that
   * ensures global uniqueness.
   */
  Map<String, GlobalVar> global_var_map_;

  /*!
   * \brief A map from string names to global type variables (ADT names)
   * that ensures global uniqueness.
   */
  Map<String, GlobalTypeVar> global_type_var_map_;

  /*!
   * \brief A map from constructor tags to constructor objects
   * for convenient access.
   */
  std::unordered_map<int32_t, Constructor> constructor_tag_map_;

  /*!
   * \brief The files previously imported, required to ensure
   * importing is idempotent for each module.
   */
  std::unordered_set<String> import_set_;

  /*!
   * \brief Get a module attribute.
   *
   * \param attr_key The attribute key.
   * \param default_value The default value if the key does not exist, defaults to nullptr.
   *
   * \return The result
   *
   * \tparam TObjectRef the expected object type.
   * \throw Error if the key exists but the value does not match TObjectRef
   *
   * \code
   *
   *  void GetAttrExample(const IRModule& mod) {
   *    auto value = mod->GetAttr<Integer>("AttrKey", 0);
   *  }
   *
   * \endcode
   */
  template <typename TObjectRef>
  Optional<TObjectRef> GetAttr(
      const std::string& attr_key,
      Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
    // Delegates the lookup to the module's DictAttrs.
    return attrs.GetAttr(attr_key, default_value);
  }
  // Variant that takes a plain TObjectRef so that a default value which is only
  // implicitly convertible to TObjectRef (e.g. the literal 0 for Integer) is accepted.
  template <typename TObjectRef>
  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
    return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
  }

  /*!
   * \brief Check whether the module has a non-zero integer attr.
   *
   * This function can be used to check whether an optional
   * attribute mark exists.
   *
   * \param attr_key The key to the attribute.
   * \return The check result.
   *
   * \code
   *
   *  void HasNonzeroAttrExample(const IRModule& mod) {
   *    if (mod->HasNonzeroAttr("AttrKey")) {
   *      // the attribute is set to a non-zero value.
   *    }
   *  }
   *
   * \endcode
   */
  bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }

  /*! \brief Default constructor; value-initializes the source map. */
  IRModuleNode() : source_map() {}

  /*!
   * \brief Reflection hook exposing the module's fields to an AttrVisitor.
   *
   * \note constructor_tag_map_ and import_set_ are not visited.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("functions", &functions);
    v->Visit("type_definitions", &type_definitions);
    v->Visit("global_var_map_", &global_var_map_);
    v->Visit("global_type_var_map_", &global_type_var_map_);
    v->Visit("source_map", &source_map);
    v->Visit("attrs", &attrs);
  }

  /*!
   * \brief Structural equality hook.
   *
   * Enabled by _type_has_method_sequal_reduce below.
   */
  TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;

  /*!
   * \brief Structural hash hook.
   *
   * Enabled by _type_has_method_shash_reduce below.
   */
  TVM_DLL void SHashReduce(SHashReducer hash_reduce) const;

  /*!
   * \brief Add a function to the global environment.
   * \param var The var of the global function.
   * \param func The function.
   * \param update Controls whether you can replace a definition in the
   * environment.
   */
  TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false);

  /*!
   * \brief Add a function to the global environment.
   * \param var The var of the global function.
   * \param func The function.
   *
   * It does not do type inference as Add does.
   */
  TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func);

  /*!
   * \brief Add a type-level definition to the global environment.
   * \param var The var of the global type definition.
   * \param type The ADT.
   * \param update Controls whether you can replace a definition in the
   * environment.
   */
  TVM_DLL void AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update = false);

  /*!
   * \brief Add a type-level definition to the global environment.
   * \param var The var of the global type definition.
   * \param type The ADT.
   * \param update Controls whether you can replace a definition in the
   * environment.
   *
   * It does not do type checking as AddTypeDef does.
   */
  TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
                                   bool update = false);

  /*!
   * \brief Update a function in the global environment.
   * \param var The name of the global function to update.
   * \param func The new function.
   */
  TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func);

  /*!
   * \brief Update a type definition in the global environment.
   * \param var The name of the global type definition to update.
   * \param type The new ADT.
   */
  TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type);

  /*!
   * \brief Remove a function from the global environment.
   * \param var The global var of the function to remove.
   */
  TVM_DLL void Remove(const GlobalVar& var);

  /*!
   * \brief Check if the global_var_map_ contains a global variable.
   * \param name The variable name.
   * \returns true if contains, otherwise false.
   */
  TVM_DLL bool ContainGlobalVar(const String& name) const;

  /*!
   * \brief Check if the global_type_var_map_ contains a global type variable.
   * \param name The variable name.
   * \returns true if contains, otherwise false.
   */
  TVM_DLL bool ContainGlobalTypeVar(const String& name) const;

  /*!
   * \brief Look up a global variable by its name.
   * \param str The unique string specifying the global variable.
   * \returns The global variable.
   */
  TVM_DLL GlobalVar GetGlobalVar(const String& str) const;

  /*!
   * \brief Collect all global vars defined in this module.
   * \returns An array of global vars
   */
  TVM_DLL Array<GlobalVar> GetGlobalVars() const;

  /*!
   * \brief Look up a global type variable by its name.
   * \param str The unique string specifying the global type variable.
   * \returns The global type variable.
   */
  TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const;

  /*!
   * \brief Collect all global type vars defined in this module.
   * \returns An array of global type vars
   */
  TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const;

  /*!
   * \brief Find constructor of ADT using name
   * \param adt name of the ADT the constructor belongs to
   * \param cons name of the constructor
   * \returns Constructor of ADT, error if not found
   */
  TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const;

  /*!
   * \brief Look up a global function by its variable.
   * \param var The global var to lookup.
   * \returns The function named by the variable argument.
   */
  TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;

  /*!
   * \brief Look up a global function by its string name
   * \param name The name of the function.
   * \returns The function named by the argument.
   */
  TVM_DLL BaseFunc Lookup(const String& name) const;

  /*!
   * \brief Look up a global type definition by its variable.
   * \param var The var of the global type definition.
   * \return The type definition.
   */
  TVM_DLL TypeData LookupTypeDef(const GlobalTypeVar& var) const;

  /*!
   * \brief Look up a global type definition by its name.
   * \param var The name of the global type definition.
   * \return The type definition.
   */
  TVM_DLL TypeData LookupTypeDef(const String& var) const;

  /*!
   * \brief Look up a constructor by its tag.
   * \param tag The tag for the constructor.
   * \return The constructor object.
   */
  TVM_DLL Constructor LookupTag(const int32_t tag);

  /*!
   * \brief Update the functions inside this environment by
   * functions in another environment.
   * \param other The other environment.
   */
  TVM_DLL void Update(const IRModule& other);

  /*!
   * \brief Create a shallow copy of this IRModule.
   * \returns The shallow copy of the IRModule.
   */
  TVM_DLL IRModule ShallowCopy();

  /*!
   * \brief Import Relay code from the file at path.
   * \param path The path of the Relay code to import.
   *
   * \note The path resolution behavior is standard:
   * if absolute, the absolute file is used; if
   * relative, it is resolved against the current
   * working directory.
   */
  TVM_DLL void Import(const String& path);

  /*!
   * \brief Import Relay code from the file at path, relative to the standard library.
   * \param path The path of the Relay code to import.
   */
  TVM_DLL void ImportFromStd(const String& path);

  /*!
   * \brief The set of imported files.
   */
  TVM_DLL std::unordered_set<String> Imports() const;

  TVM_OBJECT_ENABLE_SCRIPT_PRINTER();

  static constexpr const char* _type_key = "IRModule";
  static constexpr const bool _type_has_method_sequal_reduce = true;
  static constexpr const bool _type_has_method_shash_reduce = true;
  TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);

 private:
  /*! \brief Helper function for registering a typedef's constructors */
  void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);
  friend class IRModule;
};
343
344/*!
345 * \brief Managed reference class to IRModuleNode.
346 * \sa IRModuleNode
347 */
348class IRModule : public ObjectRef {
349 public:
350 /*!
351 * \brief constructor
352 * \param functions Functions in the module.
353 * \param type_definitions Type definitions in the module.
354 * \param import_set Set of imported files in the module.
355 * \param map The module source map.
356 * \param attrs The module attributes.
357 */
358 TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
359 Map<GlobalTypeVar, TypeData> type_definitions = {},
360 std::unordered_set<String> import_set = {}, SourceMap map = {},
361 DictAttrs attrs = {});
362
363 /*! \brief default constructor */
364 IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
365 /*!
366 * \brief constructor
367 * \param n The object pointer.
368 */
369 explicit IRModule(ObjectPtr<Object> n) : ObjectRef(n) {}
370 /*! \return mutable pointers to the node. */
371 IRModuleNode* operator->() const {
372 auto* ptr = get_mutable();
373 ICHECK(ptr != nullptr);
374 return static_cast<IRModuleNode*>(ptr);
375 }
376
377 /*!
378 * \brief Constructs a module from a standalone expression \p expr.
379 *
380 * If \p expr is a function it will be bound directly. Otherwise a function over the free
381 * variables of \p expr (possibly none) with \p expr as body is created and bound.
382 *
383 * The function is bound to, in preference order:
384 * - The "global_symbol" attribute of \p expr, if it is a function with that attribute.
385 * - 'main'
386 * - A unique name derived from 'main' if 'main' is already bound in \p global_funcs.
387 *
388 * Additional global functions and type definitions may be included in the result module.
389 *
390 * See also \p FromExpr.
391 *
392 * \param expr The expression to set as the main function to the module.
393 * \param global_funcs The global function map. Default empty.
394 * \param type_definitions The global type definition map. Default empty.
395 * \param import_set Set of external modules already imported. Default empty.
396 *
397 * \returns A module with \p expr set as the main function, and the global var to which
398 * \p expr was bound (typcially 'main').
399 *
400 * TODO(mbs): Does import_set and the bound global var need to be exposed via ffi?
401 */
402 static std::pair<IRModule, GlobalVar> FromExprInContext(
403 const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
404 const Map<GlobalTypeVar, TypeData>& type_definitions = {},
405 std::unordered_set<String> import_set = {});
406
407 /*!
408 * \brief As for \p FromExprInContext, but assuming \p expr is bound to 'main' and no
409 * imports.
410 */
411 TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
412 const Map<GlobalVar, BaseFunc>& global_funcs = {},
413 const Map<GlobalTypeVar, TypeData>& type_definitions = {});
414
415 /*!
416 * \brief Parse text format source file into an IRModule.
417 * \param text A string of Relay source code.
418 * \param source_path The path to the source file.
419 * \return A Relay module.
420 */
421 TVM_DLL static IRModule FromText(const String& text, const String& source_path);
422
423 /*!
424 * \brief Create a shallow copy of an IRModule.
425 * \param mod The module to copy.
426 * \return The copied module.
427 */
428 IRModule ShallowCopyIRModule(IRModule mod);
429
430 /*! \brief Declare the container type. */
431 using ContainerType = IRModuleNode;
432
433 /*! \brief Declare whether Ref is nullable. */
434 static constexpr bool _type_is_nullable = false;
435
436 // allow copy on write.
437 TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
438};
439
namespace attr {

// Following are attributes for IRModule only.

/*!
 * \brief Name of the module
 *
 * Type: String
 *
 * \sa tvm::runtime::String
 */
constexpr const char* kModuleName = "mod_name";

/*!
 * \brief Executor targeted by the module
 *
 * Type: Executor
 *
 * \sa tvm::relay::Executor
 */
constexpr const char* kExecutor = "executor";

/*!
 * \brief Runtime target of the module
 *
 * Type: Runtime
 *
 * \sa tvm::relay::Runtime
 */
constexpr const char* kRuntime = "runtime";

/*!
 * \brief Workspace memory pools of the module
 *
 * Type: WorkspaceMemoryPools
 *
 * \sa tvm::WorkspaceMemoryPools
 */
constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools";

/*!
 * \brief Constant memory pools of the module
 *
 * Type: ConstantMemoryPools
 *
 * \sa tvm::ConstantMemoryPools
 */
constexpr const char* kConstantMemoryPools = "constant_memory_pools";

/*!
 * \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The
 * node will record the index into this array. See also kConstNameToConstant below, which is
 * the analog for Relay Functions.
 *
 * Type: Array<runtime::NDArray>
 */
constexpr const char* kConstants = "constants";

/*!
 * \brief All the runtime::Modules accumulated during compilation by external codegen. These
 * modules must be either directly linked or captured in the final compilation artifact.
 *
 * Type: Array<runtime::Module>
 */
constexpr const char* kExternalMods = "external_mods";

/*!
 * \brief All the named runtime::NDArrays accumulated during compilation by external codegen.
 * Generally the associated runtime::Module will indicate it requires bindings for these names,
 * and during module initialization these bindings will be recovered from a ConstLoaderModule.
 * See also kConstants above, which is the analog for PrimFuncs.
 *
 * Type: Map<String, runtime::NDArray>
 */
constexpr const char* kConstNameToConstant = "const_name_to_constant";

}  // namespace attr
517} // namespace tvm
518#endif // TVM_IR_MODULE_H_
519