1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file module.cc |
21 | * \brief The global module in TVM. |
22 | */ |
23 | #include <tvm/ir/global_var_supply.h> |
24 | #include <tvm/ir/module.h> |
25 | #include <tvm/ir/type_functor.h> |
26 | #include <tvm/node/structural_equal.h> |
27 | #include <tvm/runtime/registry.h> |
28 | |
29 | #include <fstream> |
30 | #include <sstream> |
31 | #include <unordered_set> |
32 | |
33 | namespace tvm { |
34 | |
35 | IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions, |
36 | tvm::Map<GlobalTypeVar, TypeData> type_definitions, |
37 | std::unordered_set<String> import_set, SourceMap source_map, DictAttrs attrs) { |
38 | auto n = make_object<IRModuleNode>(); |
39 | n->functions = std::move(functions); |
40 | n->type_definitions = std::move(type_definitions); |
41 | n->global_type_var_map_ = {}; |
42 | n->global_var_map_ = {}; |
43 | n->constructor_tag_map_ = {}; |
44 | n->import_set_ = std::move(import_set); |
45 | n->source_map = source_map; |
46 | n->attrs = std::move(attrs); |
47 | |
48 | for (const auto& kv : n->functions) { |
49 | // set global var map |
50 | ICHECK(n->global_var_map_.count(kv.first->name_hint) == 0) |
51 | << "Duplicate global function name " << kv.first->name_hint; |
52 | n->global_var_map_.Set(kv.first->name_hint, kv.first); |
53 | } |
54 | |
55 | for (const auto& kv : n->type_definitions) { |
56 | // set global typevar map |
57 | ICHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0) |
58 | << "Duplicate global type definition name " << kv.first->name_hint; |
59 | n->global_type_var_map_.Set(kv.first->name_hint, kv.first); |
60 | n->RegisterConstructors(kv.first, kv.second); |
61 | } |
62 | data_ = std::move(n); |
63 | } |
64 | |
65 | bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { |
66 | if (functions.size() != other->functions.size()) return false; |
67 | if (!equal(this->attrs, other->attrs)) return false; |
68 | for (const auto& kv : this->functions) { |
69 | if (!other->ContainGlobalVar(kv.first->name_hint)) return false; |
70 | if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; |
71 | } |
72 | if (type_definitions.size() != other->type_definitions.size()) return false; |
73 | for (const auto& kv : this->type_definitions) { |
74 | if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false; |
75 | if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false; |
76 | } |
77 | return true; |
78 | } |
79 | |
80 | void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { |
81 | using KV = std::pair<std::string, ObjectRef>; |
82 | // hash the functions. |
83 | std::vector<KV> temp; |
84 | |
85 | auto reduce_temp = [&]() { |
86 | // sort by the hash key of the keys. |
87 | std::sort(temp.begin(), temp.end(), |
88 | [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); |
89 | |
90 | hash_reduce(static_cast<uint64_t>(temp.size())); |
91 | // hash the content |
92 | for (size_t i = 0; i < temp.size(); ++i) { |
93 | hash_reduce(temp[i].first); |
94 | hash_reduce(temp[i].second); |
95 | } |
96 | }; |
97 | |
98 | for (const auto& kv : this->functions) { |
99 | temp.emplace_back(kv.first->name_hint, kv.second); |
100 | } |
101 | reduce_temp(); |
102 | |
103 | temp.clear(); |
104 | for (const auto& kv : this->type_definitions) { |
105 | temp.emplace_back(kv.first->name_hint, kv.second); |
106 | } |
107 | reduce_temp(); |
108 | hash_reduce(this->attrs); |
109 | } |
110 | |
111 | bool IRModuleNode::ContainGlobalVar(const String& name) const { |
112 | return global_var_map_.find(name) != global_var_map_.end(); |
113 | } |
114 | |
115 | bool IRModuleNode::ContainGlobalTypeVar(const String& name) const { |
116 | return global_type_var_map_.find(name) != global_type_var_map_.end(); |
117 | } |
118 | |
119 | GlobalVar IRModuleNode::GetGlobalVar(const String& name) const { |
120 | auto it = global_var_map_.find(name); |
121 | if (it == global_var_map_.end()) { |
122 | std::ostringstream msg; |
123 | msg << "ValueError: Cannot find global var \"" << name << "\" in the Module\n" |
124 | << "candidates are: [" ; |
125 | int counter = 0; |
126 | for (auto kv : global_var_map_) { |
127 | if (counter++ != 0) { |
128 | msg << ", " ; |
129 | } |
130 | msg << "\"" << kv.first << "\"" ; |
131 | } |
132 | msg << "]" ; |
133 | LOG(FATAL) << msg.str(); |
134 | } |
135 | return (*it).second; |
136 | } |
137 | |
138 | tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const { |
139 | std::vector<GlobalVar> global_vars; |
140 | for (const auto& pair : global_var_map_) { |
141 | global_vars.push_back(pair.second); |
142 | } |
143 | return tvm::Array<GlobalVar>(global_vars); |
144 | } |
145 | |
146 | GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const { |
147 | ICHECK(global_type_var_map_.defined()); |
148 | auto it = global_type_var_map_.find(name); |
149 | ICHECK(it != global_type_var_map_.end()) |
150 | << "Cannot find global type var " << name << " in the Module" ; |
151 | return (*it).second; |
152 | } |
153 | |
154 | Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const { |
155 | TypeData typeDef = this->LookupTypeDef(adt); |
156 | for (Constructor c : typeDef->constructors) { |
157 | if (cons.compare(c->name_hint) == 0) { |
158 | return c; |
159 | } |
160 | } |
161 | |
162 | LOG(FATAL) << adt << " does not contain constructor " << cons; |
163 | } |
164 | |
165 | tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const { |
166 | std::vector<GlobalTypeVar> global_type_vars; |
167 | for (const auto& pair : global_type_var_map_) { |
168 | global_type_vars.push_back(pair.second); |
169 | } |
170 | return tvm::Array<GlobalTypeVar>(global_type_vars); |
171 | } |
172 | |
173 | void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { |
174 | BaseFunc checked_func = f; |
175 | if (const auto* f = runtime::Registry::Get("relay.ir.WarnIfMalformed" )) { |
176 | (*f)(GetRef<IRModule>(this), checked_func); |
177 | } |
178 | AddUnchecked(var, checked_func); |
179 | } |
180 | |
181 | void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { |
182 | this->functions.Set(var, func); |
183 | |
184 | auto it = global_var_map_.find(var->name_hint); |
185 | if (it != global_var_map_.end()) { |
186 | ICHECK_EQ((*it).second, var); |
187 | } else { |
188 | ICHECK(global_var_map_.count(var->name_hint) == 0) << "Duplicate global function name " << var; |
189 | } |
190 | |
191 | global_var_map_.Set(var->name_hint, var); |
192 | } |
193 | |
194 | void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) { |
195 | // We hash the global type var name to use as a globally unique prefix for tags. |
196 | // The hash will be used as the most significant byte of the tag, with the index of |
197 | // the constructor in the less significant bytes |
198 | size_t hash = std::hash<std::string>()(var->name_hint); |
199 | int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24; |
200 | for (size_t i = 0; i < type->constructors.size(); ++i) { |
201 | type->constructors[i]->tag = prefix | static_cast<int32_t>(i); |
202 | constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i]; |
203 | } |
204 | } |
205 | |
206 | void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) { |
207 | // TODO(@jroesch): we have temporarily removed kind checking here, and will consolidate |
208 | // to the type checker in follow up PR. |
209 | AddTypeDefUnchecked(var, type, update); |
210 | } |
211 | |
212 | void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, |
213 | bool update) { |
214 | this->type_definitions.Set(var, type); |
215 | if (!update) { |
216 | // set global type var map |
217 | ICHECK(global_type_var_map_.count(var->name_hint) == 0) |
218 | << "Duplicate global type definition name " << var; |
219 | } |
220 | global_type_var_map_.Set(var->name_hint, var); |
221 | RegisterConstructors(var, type); |
222 | } |
223 | |
224 | void IRModuleNode::Update(const GlobalVar& var, const BaseFunc& func) { |
225 | this->Add(var, func, true); |
226 | } |
227 | |
228 | void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) { |
229 | this->AddTypeDef(var, type, true); |
230 | } |
231 | |
232 | void IRModuleNode::Remove(const GlobalVar& var) { |
233 | auto functions_node = this->functions.CopyOnWrite(); |
234 | functions_node->erase(var); |
235 | auto gvar_node = global_var_map_.CopyOnWrite(); |
236 | gvar_node->erase(var->name_hint); |
237 | } |
238 | |
239 | BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { |
240 | auto it = functions.find(var); |
241 | ICHECK(it != functions.end()) << "There is no definition of " << var; |
242 | return (*it).second; |
243 | } |
244 | |
245 | BaseFunc IRModuleNode::Lookup(const String& name) const { |
246 | GlobalVar id = this->GetGlobalVar(name); |
247 | return this->Lookup(id); |
248 | } |
249 | |
250 | TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { |
251 | auto it = type_definitions.find(var); |
252 | ICHECK(it != type_definitions.end()) << "There is no definition of " << var; |
253 | return (*it).second; |
254 | } |
255 | |
256 | TypeData IRModuleNode::LookupTypeDef(const String& name) const { |
257 | GlobalTypeVar id = this->GetGlobalTypeVar(name); |
258 | return this->LookupTypeDef(id); |
259 | } |
260 | |
261 | Constructor IRModuleNode::LookupTag(const int32_t tag) { |
262 | auto it = constructor_tag_map_.find(tag); |
263 | ICHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag; |
264 | return (*it).second; |
265 | } |
266 | |
267 | void IRModuleNode::Update(const IRModule& mod) { |
268 | if (const auto* f = runtime::Registry::Get("relay.ir.IRModuleUpdateWithRenamer" )) { |
269 | (*f)(GetRef<IRModule>(this), mod); |
270 | return; |
271 | } |
272 | for (auto pair : mod->functions) { |
273 | // TODO(@jroesch): rename into IRModule. |
274 | this->AddUnchecked(pair.first, pair.second); |
275 | } |
276 | } |
277 | |
278 | IRModule IRModuleNode::ShallowCopy() { |
279 | return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map, |
280 | this->attrs); |
281 | } |
282 | |
283 | std::pair<IRModule, GlobalVar> IRModule::FromExprInContext( |
284 | const RelayExpr& expr, const tvm::Map<GlobalVar, BaseFunc>& global_funcs, |
285 | const tvm::Map<GlobalTypeVar, TypeData>& type_definitions, |
286 | std::unordered_set<String> import_set) { |
287 | auto mod = IRModule(global_funcs, type_definitions, std::move(import_set)); |
288 | String gv_name; |
289 | |
290 | // All global definitions must be functions. |
291 | BaseFunc func; |
292 | if (auto* func_node = expr.as<BaseFuncNode>()) { |
293 | func = GetRef<BaseFunc>(func_node); |
294 | if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) { |
295 | // Function literal has been annotated with it's required global symbol. |
296 | gv_name = opt.value(); |
297 | } |
298 | } else if (const auto* f = runtime::Registry::Get("relay.ir.FunctionFromExprInContext" )) { |
299 | func = (*f)(expr, mod); |
300 | } else { |
301 | LOG(FATAL) << "`relay.ir.FunctionFromExprInContext` is not registered" ; |
302 | } |
303 | |
304 | GlobalVar main_gv; |
305 | auto global_var_supply = GlobalVarSupply(mod); |
306 | if (gv_name.empty()) { |
307 | // Bind function to 'main' (though rename if would clash with existing 'main'). |
308 | main_gv = global_var_supply->FreshGlobal("main" , false); |
309 | } else { |
310 | main_gv = global_var_supply->UniqueGlobalFor(gv_name, false); |
311 | } |
312 | mod->Add(main_gv, func); |
313 | return {mod, main_gv}; |
314 | } |
315 | |
316 | IRModule IRModule::FromExpr(const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs, |
317 | const Map<GlobalTypeVar, TypeData>& type_definitions) { |
318 | return FromExprInContext(expr, global_funcs, type_definitions).first; |
319 | } |
320 | |
321 | void IRModuleNode::Import(const String& path) { |
322 | static const auto* f = runtime::Registry::Get("relay.parser.ParseModule" ); |
323 | ICHECK(f != nullptr) << "ValueError: Relay parser is not available" ; |
324 | if (this->import_set_.count(path) == 0) { |
325 | this->import_set_.insert(path); |
326 | std::fstream src_file(path, std::fstream::in); |
327 | std::string file_contents{std::istreambuf_iterator<char>(src_file), |
328 | std::istreambuf_iterator<char>()}; |
329 | auto mod_to_import = (*f)(path, file_contents, GetRef<IRModule>(this)); |
330 | Update(mod_to_import); |
331 | } |
332 | } |
333 | |
334 | void IRModuleNode::ImportFromStd(const String& path) { |
335 | auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path" ); |
336 | ICHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path." ; |
337 | std::string std_path = (*f)(); |
338 | this->Import(std_path + "/" + path); |
339 | } |
340 | |
341 | std::unordered_set<String> IRModuleNode::Imports() const { return this->import_set_; } |
342 | |
343 | IRModule IRModule::FromText(const String& text, const String& source_path) { |
344 | static const auto* f = runtime::Registry::Get("relay.parser.ParseModule" ); |
345 | ICHECK(f != nullptr) << "ValueError: Relay parser is not available" ; |
346 | return (*f)(source_path, text, Optional<IRModule>()); |
347 | } |
348 | |
349 | TVM_REGISTER_NODE_TYPE(IRModuleNode); |
350 | |
351 | TVM_REGISTER_GLOBAL("ir.IRModule" ) |
352 | .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, |
353 | tvm::Map<GlobalTypeVar, TypeData> types) { |
354 | return IRModule(funcs, types, {}); |
355 | }); |
356 | |
357 | TVM_REGISTER_GLOBAL("ir.Module_Add" ) |
358 | .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { |
359 | ICHECK(val->IsInstance<RelayExprNode>()); |
360 | if (const auto* f = runtime::Registry::Get("relay.ir.IRModuleAdd" )) { |
361 | return (*f)(mod, var, val, update); |
362 | } |
363 | mod->Add(var, Downcast<BaseFunc>(val), update); |
364 | return mod; |
365 | }); |
366 | |
367 | TVM_REGISTER_GLOBAL("ir.Module_AddDef" ).set_body_method<IRModule>(&IRModuleNode::AddTypeDef); |
368 | |
369 | TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar" ) |
370 | .set_body_method<IRModule>(&IRModuleNode::GetGlobalVar); |
371 | |
372 | TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars" ) |
373 | .set_body_method<IRModule>(&IRModuleNode::GetGlobalVars); |
374 | |
375 | TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars" ) |
376 | .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars); |
377 | |
378 | TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar" ) |
379 | .set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar); |
380 | |
381 | TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalTypeVar" ) |
382 | .set_body_method<IRModule>(&IRModuleNode::ContainGlobalTypeVar); |
383 | |
384 | TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar" ) |
385 | .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar); |
386 | |
387 | TVM_REGISTER_GLOBAL("ir.Module_Lookup" ).set_body_typed([](IRModule mod, GlobalVar var) { |
388 | return mod->Lookup(var); |
389 | }); |
390 | |
391 | TVM_REGISTER_GLOBAL("ir.Module_Lookup_str" ).set_body_typed([](IRModule mod, String var) { |
392 | return mod->Lookup(var); |
393 | }); |
394 | |
395 | TVM_REGISTER_GLOBAL("ir.Module_LookupDef" ).set_body_typed([](IRModule mod, GlobalTypeVar var) { |
396 | return mod->LookupTypeDef(var); |
397 | }); |
398 | |
399 | TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str" ).set_body_typed([](IRModule mod, String var) { |
400 | return mod->LookupTypeDef(var); |
401 | }); |
402 | |
403 | TVM_REGISTER_GLOBAL("ir.Module_LookupTag" ).set_body_typed([](IRModule mod, int32_t tag) { |
404 | return mod->LookupTag(tag); |
405 | }); |
406 | |
407 | TVM_REGISTER_GLOBAL("ir.Module_FromExpr" ).set_body_typed(&IRModule::FromExpr); |
408 | |
409 | TVM_REGISTER_GLOBAL("ir.Module_Update" ).set_body_typed([](IRModule mod, IRModule from) { |
410 | mod->Update(from); |
411 | }); |
412 | |
413 | TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction" ) |
414 | .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); |
415 | |
416 | TVM_REGISTER_GLOBAL("ir.Module_Import" ).set_body_typed([](IRModule mod, String path) { |
417 | mod->Import(path); |
418 | }); |
419 | |
420 | TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd" ).set_body_typed([](IRModule mod, String path) { |
421 | mod->ImportFromStd(path); |
422 | }); |
423 | |
424 | TVM_REGISTER_GLOBAL("ir.Module_WithAttr" ) |
425 | .set_body_typed([](IRModule mod, String key, ObjectRef value) -> IRModule { |
426 | return WithAttr(mod, key, value); |
427 | }); |
428 | |
429 | TVM_REGISTER_GLOBAL("ir.Module_GetAttr" ).set_body_typed([](IRModule mod, String key) -> ObjectRef { |
430 | return mod->GetAttr<ObjectRef>(key); |
431 | }); |
432 | |
433 | } // namespace tvm |
434 | |