1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file module.cc
21 * \brief The global module in TVM.
22 */
23#include <tvm/ir/global_var_supply.h>
24#include <tvm/ir/module.h>
25#include <tvm/ir/type_functor.h>
26#include <tvm/node/structural_equal.h>
27#include <tvm/runtime/registry.h>
28
29#include <fstream>
30#include <sstream>
31#include <unordered_set>
32
33namespace tvm {
34
35IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
36 tvm::Map<GlobalTypeVar, TypeData> type_definitions,
37 std::unordered_set<String> import_set, SourceMap source_map, DictAttrs attrs) {
38 auto n = make_object<IRModuleNode>();
39 n->functions = std::move(functions);
40 n->type_definitions = std::move(type_definitions);
41 n->global_type_var_map_ = {};
42 n->global_var_map_ = {};
43 n->constructor_tag_map_ = {};
44 n->import_set_ = std::move(import_set);
45 n->source_map = source_map;
46 n->attrs = std::move(attrs);
47
48 for (const auto& kv : n->functions) {
49 // set global var map
50 ICHECK(n->global_var_map_.count(kv.first->name_hint) == 0)
51 << "Duplicate global function name " << kv.first->name_hint;
52 n->global_var_map_.Set(kv.first->name_hint, kv.first);
53 }
54
55 for (const auto& kv : n->type_definitions) {
56 // set global typevar map
57 ICHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0)
58 << "Duplicate global type definition name " << kv.first->name_hint;
59 n->global_type_var_map_.Set(kv.first->name_hint, kv.first);
60 n->RegisterConstructors(kv.first, kv.second);
61 }
62 data_ = std::move(n);
63}
64
65bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
66 if (functions.size() != other->functions.size()) return false;
67 if (!equal(this->attrs, other->attrs)) return false;
68 for (const auto& kv : this->functions) {
69 if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
70 if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
71 }
72 if (type_definitions.size() != other->type_definitions.size()) return false;
73 for (const auto& kv : this->type_definitions) {
74 if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
75 if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
76 }
77 return true;
78}
79
80void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
81 using KV = std::pair<std::string, ObjectRef>;
82 // hash the functions.
83 std::vector<KV> temp;
84
85 auto reduce_temp = [&]() {
86 // sort by the hash key of the keys.
87 std::sort(temp.begin(), temp.end(),
88 [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; });
89
90 hash_reduce(static_cast<uint64_t>(temp.size()));
91 // hash the content
92 for (size_t i = 0; i < temp.size(); ++i) {
93 hash_reduce(temp[i].first);
94 hash_reduce(temp[i].second);
95 }
96 };
97
98 for (const auto& kv : this->functions) {
99 temp.emplace_back(kv.first->name_hint, kv.second);
100 }
101 reduce_temp();
102
103 temp.clear();
104 for (const auto& kv : this->type_definitions) {
105 temp.emplace_back(kv.first->name_hint, kv.second);
106 }
107 reduce_temp();
108 hash_reduce(this->attrs);
109}
110
111bool IRModuleNode::ContainGlobalVar(const String& name) const {
112 return global_var_map_.find(name) != global_var_map_.end();
113}
114
115bool IRModuleNode::ContainGlobalTypeVar(const String& name) const {
116 return global_type_var_map_.find(name) != global_type_var_map_.end();
117}
118
119GlobalVar IRModuleNode::GetGlobalVar(const String& name) const {
120 auto it = global_var_map_.find(name);
121 if (it == global_var_map_.end()) {
122 std::ostringstream msg;
123 msg << "ValueError: Cannot find global var \"" << name << "\" in the Module\n"
124 << "candidates are: [";
125 int counter = 0;
126 for (auto kv : global_var_map_) {
127 if (counter++ != 0) {
128 msg << ", ";
129 }
130 msg << "\"" << kv.first << "\"";
131 }
132 msg << "]";
133 LOG(FATAL) << msg.str();
134 }
135 return (*it).second;
136}
137
138tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const {
139 std::vector<GlobalVar> global_vars;
140 for (const auto& pair : global_var_map_) {
141 global_vars.push_back(pair.second);
142 }
143 return tvm::Array<GlobalVar>(global_vars);
144}
145
146GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const {
147 ICHECK(global_type_var_map_.defined());
148 auto it = global_type_var_map_.find(name);
149 ICHECK(it != global_type_var_map_.end())
150 << "Cannot find global type var " << name << " in the Module";
151 return (*it).second;
152}
153
154Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const {
155 TypeData typeDef = this->LookupTypeDef(adt);
156 for (Constructor c : typeDef->constructors) {
157 if (cons.compare(c->name_hint) == 0) {
158 return c;
159 }
160 }
161
162 LOG(FATAL) << adt << " does not contain constructor " << cons;
163}
164
165tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const {
166 std::vector<GlobalTypeVar> global_type_vars;
167 for (const auto& pair : global_type_var_map_) {
168 global_type_vars.push_back(pair.second);
169 }
170 return tvm::Array<GlobalTypeVar>(global_type_vars);
171}
172
173void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
174 BaseFunc checked_func = f;
175 if (const auto* f = runtime::Registry::Get("relay.ir.WarnIfMalformed")) {
176 (*f)(GetRef<IRModule>(this), checked_func);
177 }
178 AddUnchecked(var, checked_func);
179}
180
181void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
182 this->functions.Set(var, func);
183
184 auto it = global_var_map_.find(var->name_hint);
185 if (it != global_var_map_.end()) {
186 ICHECK_EQ((*it).second, var);
187 } else {
188 ICHECK(global_var_map_.count(var->name_hint) == 0) << "Duplicate global function name " << var;
189 }
190
191 global_var_map_.Set(var->name_hint, var);
192}
193
194void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
195 // We hash the global type var name to use as a globally unique prefix for tags.
196 // The hash will be used as the most significant byte of the tag, with the index of
197 // the constructor in the less significant bytes
198 size_t hash = std::hash<std::string>()(var->name_hint);
199 int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
200 for (size_t i = 0; i < type->constructors.size(); ++i) {
201 type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
202 constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i];
203 }
204}
205
206void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
207 // TODO(@jroesch): we have temporarily removed kind checking here, and will consolidate
208 // to the type checker in follow up PR.
209 AddTypeDefUnchecked(var, type, update);
210}
211
212void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
213 bool update) {
214 this->type_definitions.Set(var, type);
215 if (!update) {
216 // set global type var map
217 ICHECK(global_type_var_map_.count(var->name_hint) == 0)
218 << "Duplicate global type definition name " << var;
219 }
220 global_type_var_map_.Set(var->name_hint, var);
221 RegisterConstructors(var, type);
222}
223
224void IRModuleNode::Update(const GlobalVar& var, const BaseFunc& func) {
225 this->Add(var, func, true);
226}
227
228void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) {
229 this->AddTypeDef(var, type, true);
230}
231
232void IRModuleNode::Remove(const GlobalVar& var) {
233 auto functions_node = this->functions.CopyOnWrite();
234 functions_node->erase(var);
235 auto gvar_node = global_var_map_.CopyOnWrite();
236 gvar_node->erase(var->name_hint);
237}
238
239BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
240 auto it = functions.find(var);
241 ICHECK(it != functions.end()) << "There is no definition of " << var;
242 return (*it).second;
243}
244
245BaseFunc IRModuleNode::Lookup(const String& name) const {
246 GlobalVar id = this->GetGlobalVar(name);
247 return this->Lookup(id);
248}
249
250TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
251 auto it = type_definitions.find(var);
252 ICHECK(it != type_definitions.end()) << "There is no definition of " << var;
253 return (*it).second;
254}
255
256TypeData IRModuleNode::LookupTypeDef(const String& name) const {
257 GlobalTypeVar id = this->GetGlobalTypeVar(name);
258 return this->LookupTypeDef(id);
259}
260
261Constructor IRModuleNode::LookupTag(const int32_t tag) {
262 auto it = constructor_tag_map_.find(tag);
263 ICHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag;
264 return (*it).second;
265}
266
267void IRModuleNode::Update(const IRModule& mod) {
268 if (const auto* f = runtime::Registry::Get("relay.ir.IRModuleUpdateWithRenamer")) {
269 (*f)(GetRef<IRModule>(this), mod);
270 return;
271 }
272 for (auto pair : mod->functions) {
273 // TODO(@jroesch): rename into IRModule.
274 this->AddUnchecked(pair.first, pair.second);
275 }
276}
277
278IRModule IRModuleNode::ShallowCopy() {
279 return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map,
280 this->attrs);
281}
282
283std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
284 const RelayExpr& expr, const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
285 const tvm::Map<GlobalTypeVar, TypeData>& type_definitions,
286 std::unordered_set<String> import_set) {
287 auto mod = IRModule(global_funcs, type_definitions, std::move(import_set));
288 String gv_name;
289
290 // All global definitions must be functions.
291 BaseFunc func;
292 if (auto* func_node = expr.as<BaseFuncNode>()) {
293 func = GetRef<BaseFunc>(func_node);
294 if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
295 // Function literal has been annotated with it's required global symbol.
296 gv_name = opt.value();
297 }
298 } else if (const auto* f = runtime::Registry::Get("relay.ir.FunctionFromExprInContext")) {
299 func = (*f)(expr, mod);
300 } else {
301 LOG(FATAL) << "`relay.ir.FunctionFromExprInContext` is not registered";
302 }
303
304 GlobalVar main_gv;
305 auto global_var_supply = GlobalVarSupply(mod);
306 if (gv_name.empty()) {
307 // Bind function to 'main' (though rename if would clash with existing 'main').
308 main_gv = global_var_supply->FreshGlobal("main", false);
309 } else {
310 main_gv = global_var_supply->UniqueGlobalFor(gv_name, false);
311 }
312 mod->Add(main_gv, func);
313 return {mod, main_gv};
314}
315
316IRModule IRModule::FromExpr(const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs,
317 const Map<GlobalTypeVar, TypeData>& type_definitions) {
318 return FromExprInContext(expr, global_funcs, type_definitions).first;
319}
320
321void IRModuleNode::Import(const String& path) {
322 static const auto* f = runtime::Registry::Get("relay.parser.ParseModule");
323 ICHECK(f != nullptr) << "ValueError: Relay parser is not available";
324 if (this->import_set_.count(path) == 0) {
325 this->import_set_.insert(path);
326 std::fstream src_file(path, std::fstream::in);
327 std::string file_contents{std::istreambuf_iterator<char>(src_file),
328 std::istreambuf_iterator<char>()};
329 auto mod_to_import = (*f)(path, file_contents, GetRef<IRModule>(this));
330 Update(mod_to_import);
331 }
332}
333
334void IRModuleNode::ImportFromStd(const String& path) {
335 auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
336 ICHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
337 std::string std_path = (*f)();
338 this->Import(std_path + "/" + path);
339}
340
341std::unordered_set<String> IRModuleNode::Imports() const { return this->import_set_; }
342
343IRModule IRModule::FromText(const String& text, const String& source_path) {
344 static const auto* f = runtime::Registry::Get("relay.parser.ParseModule");
345 ICHECK(f != nullptr) << "ValueError: Relay parser is not available";
346 return (*f)(source_path, text, Optional<IRModule>());
347}
348
349TVM_REGISTER_NODE_TYPE(IRModuleNode);
350
351TVM_REGISTER_GLOBAL("ir.IRModule")
352 .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
353 tvm::Map<GlobalTypeVar, TypeData> types) {
354 return IRModule(funcs, types, {});
355 });
356
357TVM_REGISTER_GLOBAL("ir.Module_Add")
358 .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {
359 ICHECK(val->IsInstance<RelayExprNode>());
360 if (const auto* f = runtime::Registry::Get("relay.ir.IRModuleAdd")) {
361 return (*f)(mod, var, val, update);
362 }
363 mod->Add(var, Downcast<BaseFunc>(val), update);
364 return mod;
365 });
366
367TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
368
369TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
370 .set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);
371
372TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars")
373 .set_body_method<IRModule>(&IRModuleNode::GetGlobalVars);
374
375TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
376 .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);
377
378TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
379 .set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);
380
381TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalTypeVar")
382 .set_body_method<IRModule>(&IRModuleNode::ContainGlobalTypeVar);
383
384TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
385 .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
386
387TVM_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule mod, GlobalVar var) {
388 return mod->Lookup(var);
389});
390
391TVM_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule mod, String var) {
392 return mod->Lookup(var);
393});
394
395TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) {
396 return mod->LookupTypeDef(var);
397});
398
399TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) {
400 return mod->LookupTypeDef(var);
401});
402
403TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) {
404 return mod->LookupTag(tag);
405});
406
407TVM_REGISTER_GLOBAL("ir.Module_FromExpr").set_body_typed(&IRModule::FromExpr);
408
409TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) {
410 mod->Update(from);
411});
412
413TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction")
414 .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); });
415
416TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) {
417 mod->Import(path);
418});
419
420TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) {
421 mod->ImportFromStd(path);
422});
423
424TVM_REGISTER_GLOBAL("ir.Module_WithAttr")
425 .set_body_typed([](IRModule mod, String key, ObjectRef value) -> IRModule {
426 return WithAttr(mod, key, value);
427 });
428
429TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String key) -> ObjectRef {
430 return mod->GetAttr<ObjectRef>(key);
431});
432
433} // namespace tvm
434