1 | #pragma once |
2 | |
#include <c10/util/Exception.h>

#include <algorithm>
#include <functional>
#include <initializer_list>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// For printing of the set when using a Statement as the type for the set
#include <ir_base_nodes.h>
13 | |
14 | namespace torch { |
15 | namespace jit { |
16 | namespace fuser { |
17 | namespace cuda { |
18 | |
// Helpers used by the containers below to print their entries. Overload
// resolution picks the pointer overload for pointer entries (calls
// ptr->toString()) and the by-value overload otherwise (calls ref.toString()).
//
// These are deliberately NOT wrapped in an unnamed namespace: this is a
// header, and the class templates below (which have external linkage) call
// these helpers. Internal linkage would give every translation unit its own
// distinct helper, which is an ODR violation. Function templates may be
// defined in several translation units without further annotation.
template <typename T>
std::string abstractToString(T* ptr) {
  return ptr->toString();
}

// Takes the entry by value (a copy), so T::toString() does not have to be a
// const member function.
template <typename T>
std::string abstractToString(T ref) {
  return ref.toString();
}
32 | |
// Vector-like container that rejects duplicate entries. Insertion order is
// kept in a std::vector (deterministic iteration), while a parallel
// std::unordered_set provides O(1) membership checks.
//
// Invariant: vector_ and set_ hold exactly the same elements, and vector_
// never contains the same element twice.
template <typename T, typename Hash = std::hash<T>>
class VectorOfUniqueEntries {
 public:
  VectorOfUniqueEntries() = default;

  // Builds the container from a list. Only the first occurrence of each entry
  // is kept, so vector_ and set_ stay in sync even if the list has duplicates.
  VectorOfUniqueEntries(const std::initializer_list<T>& initial_entries) {
    vector_.reserve(initial_entries.size());
    set_.reserve(initial_entries.size());
    for (const auto& entry : initial_entries) {
      pushBack(entry);
    }
  }

  // Appends entry if it is not already present.
  // Returns true if the entry was actually added.
  bool pushBack(T entry) {
    if (set_.emplace(entry).second) {
      vector_.push_back(entry);
      return true;
    }
    return false;
  }

  // Appends every entry of other that is not already present, preserving
  // other's order. Returns true if at least one entry was added.
  bool pushBack(const VectorOfUniqueEntries<T, Hash>& other) {
    bool any_added = false;
    for (const auto& entry : other) {
      // pushBack is evaluated first so every entry is attempted.
      any_added = pushBack(entry) || any_added;
    }
    return any_added;
  }

  // Returns the entries in insertion order; useful for iterating.
  const std::vector<T>& vector() const {
    return vector_;
  }

  // Returns the first element. Precondition: !empty().
  T front() const {
    return vector_.front();
  }

  // Returns the last element. Precondition: !empty().
  T back() const {
    return vector_.back();
  }

  // Removes and returns the last element. Precondition: !empty().
  T popBack() {
    T last = vector_.back();
    set_.erase(last);
    vector_.pop_back();
    return last;
  }

  // Returns true if this container holds no entries.
  bool empty() const {
    return vector_.empty();
  }

  // Returns the number of entries in this container.
  size_t size() const {
    return vector_.size();
  }

  // Returns true if entry is in this container (O(1) average).
  bool has(T entry) const {
    return set_.find(entry) != set_.end();
  }

  // Removes entry from the container if present; no-op otherwise.
  void erase(T entry) {
    // The set answers membership in O(1); skip the linear vector scan when the
    // entry is absent.
    if (set_.erase(entry) == 0) {
      return;
    }
    vector_.erase(
        std::remove(vector_.begin(), vector_.end(), entry), vector_.end());
  }

  // Appends the elements of [first, last) in order, skipping duplicates.
  template <typename InputIt>
  void insert(InputIt first, InputIt last) {
    for (auto it = first; it != last; ++it) {
      pushBack(*it);
    }
  }

  // Returns a const iterator to the first entry.
  auto begin() const {
    return vector_.cbegin();
  }

  // Returns a const iterator past the last entry.
  auto end() const {
    return vector_.cend();
  }

  // The non-const overloads also return const iterators: mutating an entry in
  // place would desynchronize vector_ from set_.
  auto begin() {
    return vector_.cbegin();
  }

  auto end() {
    return vector_.cend();
  }

  // Formats the entries as "{ a; b; c }" using abstractToString.
  std::string toString() const {
    std::stringstream ss;
    ss << "{ ";
    for (size_t i = 0; i < vector_.size(); ++i) {
      if (i > 0) {
        ss << "; ";
      }
      ss << abstractToString(vector_[i]);
    }
    ss << " }";
    return ss.str();
  }

 private:
  std::vector<T> vector_;
  std::unordered_set<T, Hash> set_;
};
157 | |
158 | //! Container class DisjointSet models equivalence relationships |
159 | //! |
160 | //! Each instance of this class keeps equivalence sets |
161 | //! DisjointSet::mapEntries(a,b) makes the full set of a and b equivalent |
162 | //! DisjointSet::*AreMapped(a,b) checks if a and b belong to the same disjoint |
163 | //! set |
164 | template <typename T, typename Hash = std::hash<T>> |
165 | class DisjointSets { |
166 | public: |
167 | DisjointSets() = default; |
168 | |
169 | // Warning: returned values should never be modified. This accessor isn't |
170 | // strictly safe as VectorOfUniqueEntries is not returned as a const. |
171 | const std:: |
172 | unordered_map<T, std::shared_ptr<VectorOfUniqueEntries<T, Hash>>, Hash>& |
173 | disjointSetMap() const { |
174 | return disjoint_set_maps_; |
175 | } |
176 | |
177 | // Warning: returned values should never be modified. This accessor isn't |
178 | // strictly safe as VectorOfUniqueEntries is not returned as a const. |
179 | const std::vector<std::shared_ptr<VectorOfUniqueEntries<T, Hash>>>& |
180 | disjointSets() const { |
181 | return disjoint_sets_; |
182 | } |
183 | |
184 | // Return the entire disjoint set of provided entry |
185 | const VectorOfUniqueEntries<T, Hash>& getDisjointSetOf(T entry) const { |
186 | auto set_it = disjoint_set_maps_.find(entry); |
187 | TORCH_INTERNAL_ASSERT( |
188 | set_it != disjoint_set_maps_.end(), |
189 | "Could not find entry for " , |
190 | entry->toString()); |
191 | return *(set_it->second); |
192 | } |
193 | |
194 | // Initializes a new set for provided entry |
195 | // |
196 | // TODO: Return iterator |
197 | void initializeSet(T entry) { |
198 | if (disjoint_set_maps_.find(entry) != disjoint_set_maps_.end()) { |
199 | return; |
200 | } |
201 | |
202 | disjoint_sets_.push_back( |
203 | std::make_shared<VectorOfUniqueEntries<T, Hash>>()); |
204 | disjoint_sets_.back()->pushBack(entry); |
205 | disjoint_set_maps_.emplace(std::make_pair(entry, disjoint_sets_.back())); |
206 | } |
207 | |
208 | // Adds all of the disjoint set belonging to entry1 to the disjoint set |
209 | // belonging to entry0, maps all entries of disjoint set belonging to entry1 |
210 | // to entry0, removes original disjoint set belonging to entry1. |
211 | void mapEntries(T entry0, T entry1) { |
212 | auto set_it_0 = disjoint_set_maps_.find(entry0); |
213 | auto set_it_1 = disjoint_set_maps_.find(entry1); |
214 | |
215 | // Track if we need to reset iterators, optimize for case where both entries |
216 | // exist |
217 | bool invalid_iterators = false; |
218 | if (set_it_0 == disjoint_set_maps_.end()) { |
219 | initializeSet(entry0); |
220 | invalid_iterators = true; |
221 | } |
222 | |
223 | if (set_it_1 == disjoint_set_maps_.end()) { |
224 | initializeSet(entry1); |
225 | invalid_iterators = true; |
226 | } |
227 | |
228 | // TODO: We can avoid refinding one iterator if initialize set returns an |
229 | // iterator, though if we insert entry1 we'd have to refind entry0 as it |
230 | // could invalidate all iterators |
231 | if (invalid_iterators) { |
232 | set_it_0 = disjoint_set_maps_.find(entry0); |
233 | set_it_1 = disjoint_set_maps_.find(entry1); |
234 | } |
235 | |
236 | auto set0_shared_ptr = set_it_0->second; |
237 | auto set1_shared_ptr = set_it_1->second; |
238 | |
239 | // If the sets are already the same, do nothing |
240 | if (set0_shared_ptr == set1_shared_ptr) { |
241 | return; |
242 | } |
243 | |
244 | // Place everything in set1 into set0 and remap all entries in set1 to set0 |
245 | for (auto entry : set1_shared_ptr->vector()) { |
246 | set0_shared_ptr->pushBack(entry); |
247 | disjoint_set_maps_[entry] = set0_shared_ptr; |
248 | } |
249 | |
250 | // set1 no longer needed as its entries are copied into set0 |
251 | disjoint_sets_.erase(std::find( |
252 | disjoint_sets_.begin(), disjoint_sets_.end(), set1_shared_ptr)); |
253 | } |
254 | |
255 | // Will assert if provided entry0 is not in any disjoint set, otherwise |
256 | // returns if entry0 and entry1 are in the same disjoint set. |
257 | bool strictAreMapped(T entry0, T entry1) const { |
258 | auto entry_it = disjointSetMap().find(entry0); |
259 | TORCH_INTERNAL_ASSERT( |
260 | entry_it != disjointSetMap().end(), |
261 | "Strict mapping failed on element: " , |
262 | abstractToString(entry0), |
263 | " either an error occurred, or non strict mapping should have been used." ); |
264 | return entry_it->second->has(entry1); |
265 | } |
266 | |
267 | // If entry0 doesn't have a disjoint set returns false, otherwise returns if |
268 | // entry0 and entry1 are in the same disjoint set. |
269 | bool permissiveAreMapped(T entry0, T entry1) const { |
270 | auto entry_it = disjointSetMap().find(entry0); |
271 | if (entry_it == disjointSetMap().end()) { |
272 | return false; |
273 | } |
274 | return entry_it->second->has(entry1); |
275 | } |
276 | |
277 | // Returns if a set exists with provided entry |
278 | bool mappingExists(T entry) const { |
279 | return disjoint_set_maps_.find(entry) != disjoint_set_maps_.end(); |
280 | } |
281 | |
282 | // Returns a deterministic list of all entries that have been added to any |
283 | // disjoint set. |
284 | // |
285 | // Warning: constructed on every call, consider caching result. |
286 | VectorOfUniqueEntries<T, Hash> getAllElements() const { |
287 | VectorOfUniqueEntries<T, Hash> all_elements; |
288 | for (auto set : disjoint_sets_) { |
289 | for (auto entry : set->vector()) { |
290 | all_elements.pushBack(entry); |
291 | } |
292 | } |
293 | return all_elements; |
294 | } |
295 | |
296 | // Completely clears all disjoint sets |
297 | void clear() { |
298 | disjoint_set_maps_.clear(); |
299 | disjoint_sets_.clear(); |
300 | } |
301 | |
302 | std::string toString() const { |
303 | std::stringstream ss; |
304 | ss << "disjoint sets{\n" ; |
305 | const std::string sep(" " ); |
306 | for (auto s_ptr : disjoint_sets_) { |
307 | auto& set = *s_ptr; |
308 | ss << sep << "{\n" ; |
309 | for (auto entry : set.vector()) { |
310 | ss << sep << sep << abstractToString(entry) << "\n" ; |
311 | } |
312 | ss << sep << "}\n" ; |
313 | } |
314 | ss << "}" ; |
315 | return ss.str(); |
316 | } |
317 | |
318 | private: |
319 | // Disjoint sets |
320 | std::unordered_map<T, std::shared_ptr<VectorOfUniqueEntries<T, Hash>>, Hash> |
321 | disjoint_set_maps_; |
322 | |
323 | // Keep a list of disjoint_sets that's deterministic to iterate over |
324 | std::vector<std::shared_ptr<VectorOfUniqueEntries<T, Hash>>> disjoint_sets_; |
325 | }; |
326 | |
327 | } // namespace cuda |
328 | } // namespace fuser |
329 | } // namespace jit |
330 | } // namespace torch |
331 | |