1 | #pragma once |
---|---|
#include <ostream>
#include <unordered_set>
#include <utility>
#include <vector>
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>
7 | |
8 | namespace c10 { |
9 | /** |
10 | * class AliasInfo |
11 | * |
12 | * Data structure to hold aliasing information for an `Argument`. They can be |
13 | * nested to represent aliasing information on contained types. |
14 | * |
15 | * There is a `beforeSet` which describes the aliasing information before the |
16 | * operator executes, and an `afterSet` that describes aliasing info |
17 | * after execution. |
18 | */ |
19 | class AliasInfo { |
20 | public: |
21 | // Symbol for the set that can alias anything |
22 | static Symbol wildcardSet() { |
23 | static const Symbol wc = Symbol::fromQualString("alias::*"); |
24 | return wc; |
25 | } |
26 | |
27 | void setIsWrite(bool isWrite) { |
28 | isWrite_ = isWrite; |
29 | } |
30 | |
31 | bool isWrite() const { |
32 | return isWrite_; |
33 | } |
34 | |
35 | void addBeforeSet(Symbol aliasSet) { |
36 | beforeSets_.insert(aliasSet); |
37 | } |
38 | |
39 | void addAfterSet(Symbol aliasSet) { |
40 | afterSets_.insert(aliasSet); |
41 | } |
42 | |
43 | const std::unordered_set<Symbol>& beforeSets() const { |
44 | return beforeSets_; |
45 | } |
46 | |
47 | const std::unordered_set<Symbol>& afterSets() const { |
48 | return afterSets_; |
49 | } |
50 | |
51 | Symbol beforeSet() const { |
52 | AT_ASSERT(beforeSets_.size() == 1); |
53 | return *beforeSets_.begin(); |
54 | } |
55 | |
56 | bool isWildcardBefore() const { |
57 | return beforeSets_.count(wildcardSet()) != 0; |
58 | } |
59 | |
60 | bool isWildcardAfter() const { |
61 | return afterSets_.count(wildcardSet()) != 0; |
62 | } |
63 | |
64 | // the alias info for the contained types of the type |
65 | // e.g. if this is an annotation on List[T], `sets` refers to |
66 | // the alias sets that the list may be in |
67 | // while containedTypes()[0] refers to the sets that members of the list |
68 | // may be in |
69 | void addContainedType(AliasInfo aliasInfo) { |
70 | containedTypes_.push_back(std::move(aliasInfo)); |
71 | } |
72 | const std::vector<AliasInfo>& containedTypes() const { |
73 | return containedTypes_; |
74 | } |
75 | |
76 | private: |
77 | std::unordered_set<Symbol> beforeSets_; |
78 | std::unordered_set<Symbol> afterSets_; |
79 | std::vector<AliasInfo> containedTypes_; |
80 | bool isWrite_ = false; |
81 | }; |
82 | |
83 | inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) { |
84 | return lhs.isWrite() == rhs.isWrite() |
85 | && lhs.beforeSets() == rhs.beforeSets() |
86 | && lhs.afterSets() == rhs.afterSets() |
87 | && lhs.containedTypes() == rhs.containedTypes(); |
88 | } |
89 | |
90 | // this does match the way things are represented in the schema |
91 | inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) { |
92 | out << "("; |
93 | bool first = true; |
94 | for (const auto& set : aliasInfo.beforeSets()) { |
95 | if (first) { |
96 | first = false; |
97 | } else { |
98 | out << "|"; |
99 | } |
100 | out << set.toUnqualString(); |
101 | } |
102 | if (aliasInfo.isWrite()) { |
103 | out << "!"; |
104 | } |
105 | if (aliasInfo.beforeSets() != aliasInfo.afterSets()) { |
106 | out << " -> "; |
107 | first = true; |
108 | for (const auto& set : aliasInfo.afterSets()) { |
109 | if (first) { |
110 | first = false; |
111 | } else { |
112 | out << "|"; |
113 | } |
114 | out << set.toUnqualString(); |
115 | } |
116 | } |
117 | out << ")"; |
118 | return out; |
119 | } |
120 | } // namespace c10 |
121 | |
122 | namespace std { |
123 | template <> |
124 | struct hash<c10::AliasInfo> { |
125 | size_t operator()(const c10::AliasInfo& aliasInfo) const { |
126 | auto hash = std::hash<bool>()(aliasInfo.isWrite()); |
127 | |
128 | // NOTE: for unordered_set hashes, we couldn't use hash_combine |
129 | // because hash_combine is order dependent. Instead, we choose to |
130 | // use XOR as the combining function as XOR is commutative. |
131 | size_t before_set_hash_seed = 0; |
132 | for (auto &e: aliasInfo.beforeSets()) { |
133 | auto symbol_hash = std::hash<c10::Symbol>()(e); |
134 | before_set_hash_seed = before_set_hash_seed ^ symbol_hash; |
135 | } |
136 | size_t after_set_hash_seed = 0; |
137 | for (auto &e: aliasInfo.afterSets()) { |
138 | auto symbol_hash = std::hash<c10::Symbol>()(e); |
139 | after_set_hash_seed = after_set_hash_seed ^ symbol_hash; |
140 | } |
141 | |
142 | hash = c10::hash_combine(hash, before_set_hash_seed); |
143 | hash = c10::hash_combine(hash, after_set_hash_seed); |
144 | for (auto &e: aliasInfo.containedTypes()) { |
145 | auto contained_type_hash = std::hash<c10::AliasInfo>()(e); |
146 | hash = c10::hash_combine(hash, contained_type_hash); |
147 | } |
148 | return hash; |
149 | } |
150 | }; |
151 | } |
152 |