1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file global_var_supply.cc |
22 | * \brief GlobalVarSupply that can be used to generate unique GlobalVars. |
23 | */ |
24 | #include "tvm/ir/global_var_supply.h" |
25 | |
26 | #include <tvm/runtime/registry.h> |
27 | |
28 | #include <utility> |
29 | |
30 | #include "tvm/ir/expr.h" |
31 | |
32 | namespace tvm { |
33 | GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply, |
34 | std::unordered_map<std::string, GlobalVar> name_to_var_map) { |
35 | auto n = make_object<GlobalVarSupplyNode>(name_supply, name_to_var_map); |
36 | data_ = std::move(n); |
37 | } |
38 | |
39 | std::string GetModuleName(const IRModule& module) { |
40 | return module->GetAttr<String>(tvm::attr::kModuleName).value_or("tvmgen_default" ); |
41 | } |
42 | |
43 | GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) : GlobalVarSupply(NameSupply("" )) { |
44 | if (!modules.empty()) { |
45 | IRModule first_mod = modules.front(); |
46 | this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod); |
47 | } |
48 | for (auto& mod : modules) { |
49 | for (auto kv : mod->functions) { |
50 | this->operator->()->ReserveGlobalVar(kv.first); |
51 | } |
52 | } |
53 | } |
54 | |
55 | GlobalVarSupply::GlobalVarSupply(const IRModule module) |
56 | : GlobalVarSupply(Array<IRModule>{module}) {} |
57 | |
58 | void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) { |
59 | name_supply_->ReserveName(var->name_hint, false); |
60 | if (!allow_conflict) { |
61 | ICHECK(name_to_var_map_.count(var->name_hint) == 0) |
62 | << "GlobalVar " << var << " conflicts by name in this supply." ; |
63 | } |
64 | name_to_var_map_[var->name_hint] = var; |
65 | } |
66 | |
67 | GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply, |
68 | std::unordered_map<std::string, GlobalVar> name_to_var_map) |
69 | : name_supply_(std::move(name_supply)), name_to_var_map_(std::move(name_to_var_map)) {} |
70 | |
71 | GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const String& name, bool add_prefix) { |
72 | String final_name = name_supply_->ReserveName(name, add_prefix); |
73 | |
74 | auto it = name_to_var_map_.find(final_name); |
75 | if (it != name_to_var_map_.end()) { |
76 | return it->second; |
77 | } else { |
78 | GlobalVar var = GlobalVar(final_name); |
79 | name_to_var_map_.emplace(final_name, var); |
80 | return var; |
81 | } |
82 | } |
83 | |
84 | GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) { |
85 | String final_name = name_supply_->FreshName(name, add_prefix); |
86 | ICHECK(name_to_var_map_.find(final_name) == name_to_var_map_.end()) |
87 | << "GlobalVar already exists for name " << final_name; |
88 | GlobalVar var = GlobalVar(final_name); |
89 | name_to_var_map_.emplace(final_name, var); |
90 | return var; |
91 | } |
92 | |
// Register GlobalVarSupplyNode with the TVM object/node system.
TVM_REGISTER_NODE_TYPE(GlobalVarSupplyNode);

// FFI constructor: build a supply from an existing NameSupply. Only one argument is passed,
// so the map argument presumably takes its default from the header declaration — verify there.
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply")
    .set_body_typed([](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); });

// FFI constructor: build a supply seeded from a single IRModule.
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule mod) {
  return GlobalVarSupply(std::move(mod));
});

// FFI constructor: build a supply seeded from several IRModules (prefix from the first one).
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules").set_body_typed([](const Array<IRModule>& mods) {
  return GlobalVarSupply(mods);
});

// The registrations below expose GlobalVarSupplyNode methods through set_body_method, which
// dispatches on a GlobalVarSupply object supplied by the caller.
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal")
    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::FreshGlobal);

TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor")
    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::UniqueGlobalFor);

TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar")
    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::ReserveGlobalVar);
114 | |
115 | } // namespace tvm |
116 | |