1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file global_var_supply.cc
22 * \brief GlobalVarSupply that can be used to generate unique GlobalVars.
23 */
24#include "tvm/ir/global_var_supply.h"
25
26#include <tvm/runtime/registry.h>
27
28#include <utility>
29
30#include "tvm/ir/expr.h"
31
32namespace tvm {
33GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply,
34 std::unordered_map<std::string, GlobalVar> name_to_var_map) {
35 auto n = make_object<GlobalVarSupplyNode>(name_supply, name_to_var_map);
36 data_ = std::move(n);
37}
38
39std::string GetModuleName(const IRModule& module) {
40 return module->GetAttr<String>(tvm::attr::kModuleName).value_or("tvmgen_default");
41}
42
43GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) : GlobalVarSupply(NameSupply("")) {
44 if (!modules.empty()) {
45 IRModule first_mod = modules.front();
46 this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod);
47 }
48 for (auto& mod : modules) {
49 for (auto kv : mod->functions) {
50 this->operator->()->ReserveGlobalVar(kv.first);
51 }
52 }
53}
54
/*!
 * \brief Construct a supply seeded from a single module.
 *
 * Delegates to the Array<IRModule> constructor with a one-element array. The name
 * prefix therefore comes from \p module, and all of its global function names are reserved.
 * \param module The module whose global names should be reserved.
 */
GlobalVarSupply::GlobalVarSupply(const IRModule module)
    : GlobalVarSupply(Array<IRModule>{module}) {}
57
58void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) {
59 name_supply_->ReserveName(var->name_hint, false);
60 if (!allow_conflict) {
61 ICHECK(name_to_var_map_.count(var->name_hint) == 0)
62 << "GlobalVar " << var << " conflicts by name in this supply.";
63 }
64 name_to_var_map_[var->name_hint] = var;
65}
66
/*!
 * \brief Node constructor.
 * \param name_supply The NameSupply used for name generation; moved into `name_supply_`.
 * \param name_to_var_map Initial name -> GlobalVar mapping; moved into `name_to_var_map_`.
 */
GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply,
                                         std::unordered_map<std::string, GlobalVar> name_to_var_map)
    : name_supply_(std::move(name_supply)), name_to_var_map_(std::move(name_to_var_map)) {}
70
71GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const String& name, bool add_prefix) {
72 String final_name = name_supply_->ReserveName(name, add_prefix);
73
74 auto it = name_to_var_map_.find(final_name);
75 if (it != name_to_var_map_.end()) {
76 return it->second;
77 } else {
78 GlobalVar var = GlobalVar(final_name);
79 name_to_var_map_.emplace(final_name, var);
80 return var;
81 }
82}
83
84GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) {
85 String final_name = name_supply_->FreshName(name, add_prefix);
86 ICHECK(name_to_var_map_.find(final_name) == name_to_var_map_.end())
87 << "GlobalVar already exists for name " << final_name;
88 GlobalVar var = GlobalVar(final_name);
89 name_to_var_map_.emplace(final_name, var);
90 return var;
91}
92
// Register the node type and expose, via TVM_REGISTER_GLOBAL, global functions that wrap
// the GlobalVarSupply constructors and GlobalVarSupplyNode methods defined above.
TVM_REGISTER_NODE_TYPE(GlobalVarSupplyNode);

// Constructors: from a NameSupply, a single IRModule, or an array of IRModules.
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply")
    .set_body_typed([](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); });

TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule mod) {
  return GlobalVarSupply(std::move(mod));
});

TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules").set_body_typed([](const Array<IRModule>& mods) {
  return GlobalVarSupply(mods);
});

// Methods: dispatched on a GlobalVarSupply reference to the corresponding node member.
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal")
    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::FreshGlobal);

TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor")
    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::UniqueGlobalFor);

TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar")
    .set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::ReserveGlobalVar);
114
115} // namespace tvm
116