1#pragma once
2
3#include <c10/util/Exception.h>
4
#include <algorithm>
#include <initializer_list>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
10
11// For printing of the set when using a Statement as the type for the set
12#include <ir_base_nodes.h>
13
14namespace torch {
15namespace jit {
16namespace fuser {
17namespace cuda {
18
// Helpers used to print container contents.
//
// The pointer overload is chosen for pointer element types (calls
// ptr->toString()); the by-value overload handles element types that expose
// toString() directly (calls ref.toString() on a copy, so toString() need not
// be const).
//
// These are deliberately NOT wrapped in an unnamed namespace: this is a header,
// and internal linkage would give every translation unit its own copy that the
// externally-linked container templates in this header then refer to (an ODR
// hazard). Function templates can be defined in headers as-is.
template <typename T>
std::string abstractToString(T* ptr) {
  return ptr->toString();
}

template <typename T>
std::string abstractToString(T ref) {
  return ref.toString();
}
32
// Vector-like container that prevents adding duplicate entries by also
// maintaining a set. Entries are kept in insertion order in a std::vector so
// iteration is deterministic, while a std::unordered_set provides membership
// checks. Invariant: vector_ and set_ always hold exactly the same elements.
template <typename T, typename Hash = std::hash<T>>
class VectorOfUniqueEntries {
 public:
  VectorOfUniqueEntries() = default;

  // Constructs from a list of entries. Repeated entries are dropped (the first
  // occurrence wins), so {a, a, b} yields a container of size 2. Initializing
  // vector_ and set_ directly from the list would keep the duplicate in
  // vector_ only and break the invariant above.
  VectorOfUniqueEntries(const std::initializer_list<T>& x) {
    vector_.reserve(x.size());
    insert(x.begin(), x.end());
  }

  // Appends entry if it is not already present.
  // Returns true if the entry was actually added.
  bool pushBack(T entry) {
    if (set_.emplace(entry).second) {
      vector_.push_back(entry);
      return true;
    }
    return false;
  }

  // Appends, in order, every entry of other that is not already present.
  // Returns true if any entry was added.
  bool pushBack(const VectorOfUniqueEntries<T, Hash>& other) {
    bool any_added = false;
    for (const auto& entry : other) {
      // Intentionally non-short-circuiting: every entry must be attempted.
      any_added |= pushBack(entry);
    }
    return any_added;
  }

  // Returns a const vector useful for iterating on.
  const std::vector<T>& vector() const {
    return vector_;
  }

  // Returns the first element. Precondition: !empty().
  T front() const {
    return vector_.front();
  }

  // Returns the last element. Precondition: !empty().
  T back() const {
    return vector_.back();
  }

  // Removes and returns the last element. Precondition: !empty().
  T popBack() {
    T v = vector_.back();
    set_.erase(v);
    vector_.pop_back();
    return v;
  }

  // Returns if this container is empty.
  bool empty() const {
    return vector_.empty();
  }

  // Returns the number of elements in this container.
  size_t size() const {
    return vector_.size();
  }

  // Returns if entry is in this container.
  bool has(T entry) const {
    return set_.find(entry) != set_.end();
  }

  // Erases entry from the container if it is present; otherwise does nothing.
  void erase(T entry) {
    // The set is the authority on membership: skip the linear vector scan
    // entirely when the entry is absent.
    if (set_.erase(entry) == 0) {
      return;
    }
    // Entries are unique, so at most one element of vector_ matches.
    auto it = std::find(vector_.begin(), vector_.end(), entry);
    if (it != vector_.end()) {
      vector_.erase(it);
    }
  }

  // Appends the elements of [begin, end) in order, skipping duplicates.
  template <typename InputIt>
  void insert(InputIt begin, InputIt end) {
    for (auto it = begin; it != end; ++it) {
      pushBack(*it);
    }
  }

  // Iteration is read-only even on a non-const container: mutating an element
  // in place would desynchronize vector_ from set_.
  auto begin() const {
    return vector_.begin();
  }

  auto end() const {
    return vector_.end();
  }

  auto begin() {
    return vector_.cbegin();
  }

  auto end() {
    return vector_.cend();
  }

  // Formats the entries as "{ e0; e1; ...; en }" in insertion order.
  std::string toString() const {
    std::stringstream ss;
    ss << "{ ";
    bool first = true;
    for (const auto& entry : vector_) {
      if (!first) {
        ss << "; ";
      }
      first = false;
      ss << abstractToString(entry);
    }
    ss << " }";
    return ss.str();
  }

 private:
  std::vector<T> vector_;
  std::unordered_set<T, Hash> set_;
};
157
158//! Container class DisjointSet models equivalence relationships
159//!
160//! Each instance of this class keeps equivalence sets
161//! DisjointSet::mapEntries(a,b) makes the full set of a and b equivalent
162//! DisjointSet::*AreMapped(a,b) checks if a and b belong to the same disjoint
163//! set
164template <typename T, typename Hash = std::hash<T>>
165class DisjointSets {
166 public:
167 DisjointSets() = default;
168
169 // Warning: returned values should never be modified. This accessor isn't
170 // strictly safe as VectorOfUniqueEntries is not returned as a const.
171 const std::
172 unordered_map<T, std::shared_ptr<VectorOfUniqueEntries<T, Hash>>, Hash>&
173 disjointSetMap() const {
174 return disjoint_set_maps_;
175 }
176
177 // Warning: returned values should never be modified. This accessor isn't
178 // strictly safe as VectorOfUniqueEntries is not returned as a const.
179 const std::vector<std::shared_ptr<VectorOfUniqueEntries<T, Hash>>>&
180 disjointSets() const {
181 return disjoint_sets_;
182 }
183
184 // Return the entire disjoint set of provided entry
185 const VectorOfUniqueEntries<T, Hash>& getDisjointSetOf(T entry) const {
186 auto set_it = disjoint_set_maps_.find(entry);
187 TORCH_INTERNAL_ASSERT(
188 set_it != disjoint_set_maps_.end(),
189 "Could not find entry for ",
190 entry->toString());
191 return *(set_it->second);
192 }
193
194 // Initializes a new set for provided entry
195 //
196 // TODO: Return iterator
197 void initializeSet(T entry) {
198 if (disjoint_set_maps_.find(entry) != disjoint_set_maps_.end()) {
199 return;
200 }
201
202 disjoint_sets_.push_back(
203 std::make_shared<VectorOfUniqueEntries<T, Hash>>());
204 disjoint_sets_.back()->pushBack(entry);
205 disjoint_set_maps_.emplace(std::make_pair(entry, disjoint_sets_.back()));
206 }
207
208 // Adds all of the disjoint set belonging to entry1 to the disjoint set
209 // belonging to entry0, maps all entries of disjoint set belonging to entry1
210 // to entry0, removes original disjoint set belonging to entry1.
211 void mapEntries(T entry0, T entry1) {
212 auto set_it_0 = disjoint_set_maps_.find(entry0);
213 auto set_it_1 = disjoint_set_maps_.find(entry1);
214
215 // Track if we need to reset iterators, optimize for case where both entries
216 // exist
217 bool invalid_iterators = false;
218 if (set_it_0 == disjoint_set_maps_.end()) {
219 initializeSet(entry0);
220 invalid_iterators = true;
221 }
222
223 if (set_it_1 == disjoint_set_maps_.end()) {
224 initializeSet(entry1);
225 invalid_iterators = true;
226 }
227
228 // TODO: We can avoid refinding one iterator if initialize set returns an
229 // iterator, though if we insert entry1 we'd have to refind entry0 as it
230 // could invalidate all iterators
231 if (invalid_iterators) {
232 set_it_0 = disjoint_set_maps_.find(entry0);
233 set_it_1 = disjoint_set_maps_.find(entry1);
234 }
235
236 auto set0_shared_ptr = set_it_0->second;
237 auto set1_shared_ptr = set_it_1->second;
238
239 // If the sets are already the same, do nothing
240 if (set0_shared_ptr == set1_shared_ptr) {
241 return;
242 }
243
244 // Place everything in set1 into set0 and remap all entries in set1 to set0
245 for (auto entry : set1_shared_ptr->vector()) {
246 set0_shared_ptr->pushBack(entry);
247 disjoint_set_maps_[entry] = set0_shared_ptr;
248 }
249
250 // set1 no longer needed as its entries are copied into set0
251 disjoint_sets_.erase(std::find(
252 disjoint_sets_.begin(), disjoint_sets_.end(), set1_shared_ptr));
253 }
254
255 // Will assert if provided entry0 is not in any disjoint set, otherwise
256 // returns if entry0 and entry1 are in the same disjoint set.
257 bool strictAreMapped(T entry0, T entry1) const {
258 auto entry_it = disjointSetMap().find(entry0);
259 TORCH_INTERNAL_ASSERT(
260 entry_it != disjointSetMap().end(),
261 "Strict mapping failed on element: ",
262 abstractToString(entry0),
263 " either an error occurred, or non strict mapping should have been used.");
264 return entry_it->second->has(entry1);
265 }
266
267 // If entry0 doesn't have a disjoint set returns false, otherwise returns if
268 // entry0 and entry1 are in the same disjoint set.
269 bool permissiveAreMapped(T entry0, T entry1) const {
270 auto entry_it = disjointSetMap().find(entry0);
271 if (entry_it == disjointSetMap().end()) {
272 return false;
273 }
274 return entry_it->second->has(entry1);
275 }
276
277 // Returns if a set exists with provided entry
278 bool mappingExists(T entry) const {
279 return disjoint_set_maps_.find(entry) != disjoint_set_maps_.end();
280 }
281
282 // Returns a deterministic list of all entries that have been added to any
283 // disjoint set.
284 //
285 // Warning: constructed on every call, consider caching result.
286 VectorOfUniqueEntries<T, Hash> getAllElements() const {
287 VectorOfUniqueEntries<T, Hash> all_elements;
288 for (auto set : disjoint_sets_) {
289 for (auto entry : set->vector()) {
290 all_elements.pushBack(entry);
291 }
292 }
293 return all_elements;
294 }
295
296 // Completely clears all disjoint sets
297 void clear() {
298 disjoint_set_maps_.clear();
299 disjoint_sets_.clear();
300 }
301
302 std::string toString() const {
303 std::stringstream ss;
304 ss << "disjoint sets{\n";
305 const std::string sep(" ");
306 for (auto s_ptr : disjoint_sets_) {
307 auto& set = *s_ptr;
308 ss << sep << "{\n";
309 for (auto entry : set.vector()) {
310 ss << sep << sep << abstractToString(entry) << "\n";
311 }
312 ss << sep << "}\n";
313 }
314 ss << "}";
315 return ss.str();
316 }
317
318 private:
319 // Disjoint sets
320 std::unordered_map<T, std::shared_ptr<VectorOfUniqueEntries<T, Hash>>, Hash>
321 disjoint_set_maps_;
322
323 // Keep a list of disjoint_sets that's deterministic to iterate over
324 std::vector<std::shared_ptr<VectorOfUniqueEntries<T, Hash>>> disjoint_sets_;
325};
326
327} // namespace cuda
328} // namespace fuser
329} // namespace jit
330} // namespace torch
331