1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file name_supply.cc |
22 | * \brief NameSupply that can be used to generate unique variable names. |
23 | */ |
24 | #include "tvm/ir/name_supply.h" |
25 | |
26 | #include <tvm/runtime/registry.h> |
27 | |
28 | #include <utility> |
29 | |
30 | namespace tvm { |
31 | |
32 | NameSupply::NameSupply(const String& prefix, std::unordered_map<std::string, int> name_map) { |
33 | auto n = make_object<NameSupplyNode>(prefix, std::move(name_map)); |
34 | data_ = std::move(n); |
35 | } |
36 | |
37 | String NameSupplyNode::ReserveName(const String& name, bool add_prefix) { |
38 | String final_name = name; |
39 | if (add_prefix) { |
40 | final_name = add_prefix_to_name(name); |
41 | } |
42 | name_map[final_name] = 0; |
43 | return final_name; |
44 | } |
45 | |
46 | String NameSupplyNode::FreshName(const String& name, bool add_prefix) { |
47 | String unique_name = name; |
48 | if (add_prefix) { |
49 | unique_name = add_prefix_to_name(name); |
50 | } |
51 | unique_name = GetUniqueName(unique_name); |
52 | return unique_name; |
53 | } |
54 | |
55 | bool NameSupplyNode::ContainsName(const String& name, bool add_prefix) { |
56 | String unique_name = name; |
57 | if (add_prefix) { |
58 | unique_name = add_prefix_to_name(name); |
59 | } |
60 | |
61 | return name_map.count(unique_name); |
62 | } |
63 | |
64 | String NameSupplyNode::add_prefix_to_name(const String& name) { |
65 | if (prefix_.empty()) { |
66 | return name; |
67 | } |
68 | |
69 | std::ostringstream ss; |
70 | ICHECK(name.defined()); |
71 | ss << prefix_ << "_" << name; |
72 | return ss.str(); |
73 | } |
74 | |
75 | std::string NameSupplyNode::GetUniqueName(std::string name) { |
76 | for (size_t i = 0; i < name.size(); ++i) { |
77 | if (name[i] == '.') name[i] = '_'; |
78 | } |
79 | auto it = name_map.find(name); |
80 | if (it != name_map.end()) { |
81 | auto new_name = name; |
82 | while (!name_map.insert({new_name, 0}).second) { |
83 | std::ostringstream os; |
84 | os << name << "_" << (++it->second); |
85 | new_name = os.str(); |
86 | } |
87 | return new_name; |
88 | } |
89 | name_map[name] = 0; |
90 | return name; |
91 | } |
92 | |
93 | TVM_REGISTER_NODE_TYPE(NameSupplyNode); |
94 | |
95 | TVM_REGISTER_GLOBAL("ir.NameSupply" ).set_body_typed([](String prefix) { |
96 | return NameSupply(prefix); |
97 | }); |
98 | |
99 | TVM_REGISTER_GLOBAL("ir.NameSupply_FreshName" ) |
100 | .set_body_method<NameSupply>(&NameSupplyNode::FreshName); |
101 | |
102 | TVM_REGISTER_GLOBAL("ir.NameSupply_ReserveName" ) |
103 | .set_body_method<NameSupply>(&NameSupplyNode::ReserveName); |
104 | |
105 | TVM_REGISTER_GLOBAL("ir.NameSupply_ContainsName" ) |
106 | .set_body_method<NameSupply>(&NameSupplyNode::ContainsName); |
107 | |
108 | } // namespace tvm |
109 | |