1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file name_supply.cc
22 * \brief NameSupply that can be used to generate unique variable names.
23 */
24#include "tvm/ir/name_supply.h"
25
26#include <tvm/runtime/registry.h>
27
28#include <utility>
29
30namespace tvm {
31
32NameSupply::NameSupply(const String& prefix, std::unordered_map<std::string, int> name_map) {
33 auto n = make_object<NameSupplyNode>(prefix, std::move(name_map));
34 data_ = std::move(n);
35}
36
37String NameSupplyNode::ReserveName(const String& name, bool add_prefix) {
38 String final_name = name;
39 if (add_prefix) {
40 final_name = add_prefix_to_name(name);
41 }
42 name_map[final_name] = 0;
43 return final_name;
44}
45
46String NameSupplyNode::FreshName(const String& name, bool add_prefix) {
47 String unique_name = name;
48 if (add_prefix) {
49 unique_name = add_prefix_to_name(name);
50 }
51 unique_name = GetUniqueName(unique_name);
52 return unique_name;
53}
54
55bool NameSupplyNode::ContainsName(const String& name, bool add_prefix) {
56 String unique_name = name;
57 if (add_prefix) {
58 unique_name = add_prefix_to_name(name);
59 }
60
61 return name_map.count(unique_name);
62}
63
64String NameSupplyNode::add_prefix_to_name(const String& name) {
65 if (prefix_.empty()) {
66 return name;
67 }
68
69 std::ostringstream ss;
70 ICHECK(name.defined());
71 ss << prefix_ << "_" << name;
72 return ss.str();
73}
74
75std::string NameSupplyNode::GetUniqueName(std::string name) {
76 for (size_t i = 0; i < name.size(); ++i) {
77 if (name[i] == '.') name[i] = '_';
78 }
79 auto it = name_map.find(name);
80 if (it != name_map.end()) {
81 auto new_name = name;
82 while (!name_map.insert({new_name, 0}).second) {
83 std::ostringstream os;
84 os << name << "_" << (++it->second);
85 new_name = os.str();
86 }
87 return new_name;
88 }
89 name_map[name] = 0;
90 return name;
91}
92
// Register NameSupplyNode as a node type with TVM's object system.
TVM_REGISTER_NODE_TYPE(NameSupplyNode);

// Global "ir.NameSupply": construct a NameSupply from a prefix. Any further
// constructor arguments presumably take the defaults declared in the header
// (not visible here).
TVM_REGISTER_GLOBAL("ir.NameSupply").set_body_typed([](String prefix) {
  return NameSupply(prefix);
});

// Global "ir.NameSupply_FreshName": expose NameSupplyNode::FreshName, invoked on
// the NameSupply passed as the first argument.
TVM_REGISTER_GLOBAL("ir.NameSupply_FreshName")
    .set_body_method<NameSupply>(&NameSupplyNode::FreshName);

// Global "ir.NameSupply_ReserveName": expose NameSupplyNode::ReserveName.
TVM_REGISTER_GLOBAL("ir.NameSupply_ReserveName")
    .set_body_method<NameSupply>(&NameSupplyNode::ReserveName);

// Global "ir.NameSupply_ContainsName": expose NameSupplyNode::ContainsName.
TVM_REGISTER_GLOBAL("ir.NameSupply_ContainsName")
    .set_body_method<NameSupply>(&NameSupplyNode::ContainsName);
107
108} // namespace tvm
109