1#pragma once
2#include <unordered_set>
3#include <vector>
4#include <ATen/core/symbol.h>
5#include <c10/util/Exception.h>
6#include <c10/util/hash.h>
7
8namespace c10 {
9/**
10 * class AliasInfo
11 *
12 * Data structure to hold aliasing information for an `Argument`. They can be
13 * nested to represent aliasing information on contained types.
14 *
15 * There is a `beforeSet` which describes the aliasing information before the
16 * operator executes, and an `afterSet` that describes aliasing info
17 * after execution.
18 */
19class AliasInfo {
20 public:
21 // Symbol for the set that can alias anything
22 static Symbol wildcardSet() {
23 static const Symbol wc = Symbol::fromQualString("alias::*");
24 return wc;
25 }
26
27 void setIsWrite(bool isWrite) {
28 isWrite_ = isWrite;
29 }
30
31 bool isWrite() const {
32 return isWrite_;
33 }
34
35 void addBeforeSet(Symbol aliasSet) {
36 beforeSets_.insert(aliasSet);
37 }
38
39 void addAfterSet(Symbol aliasSet) {
40 afterSets_.insert(aliasSet);
41 }
42
43 const std::unordered_set<Symbol>& beforeSets() const {
44 return beforeSets_;
45 }
46
47 const std::unordered_set<Symbol>& afterSets() const {
48 return afterSets_;
49 }
50
51 Symbol beforeSet() const {
52 AT_ASSERT(beforeSets_.size() == 1);
53 return *beforeSets_.begin();
54 }
55
56 bool isWildcardBefore() const {
57 return beforeSets_.count(wildcardSet()) != 0;
58 }
59
60 bool isWildcardAfter() const {
61 return afterSets_.count(wildcardSet()) != 0;
62 }
63
64 // the alias info for the contained types of the type
65 // e.g. if this is an annotation on List[T], `sets` refers to
66 // the alias sets that the list may be in
67 // while containedTypes()[0] refers to the sets that members of the list
68 // may be in
69 void addContainedType(AliasInfo aliasInfo) {
70 containedTypes_.push_back(std::move(aliasInfo));
71 }
72 const std::vector<AliasInfo>& containedTypes() const {
73 return containedTypes_;
74 }
75
76 private:
77 std::unordered_set<Symbol> beforeSets_;
78 std::unordered_set<Symbol> afterSets_;
79 std::vector<AliasInfo> containedTypes_;
80 bool isWrite_ = false;
81};
82
83inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
84 return lhs.isWrite() == rhs.isWrite()
85 && lhs.beforeSets() == rhs.beforeSets()
86 && lhs.afterSets() == rhs.afterSets()
87 && lhs.containedTypes() == rhs.containedTypes();
88}
89
90// this does match the way things are represented in the schema
91inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
92 out << "(";
93 bool first = true;
94 for (const auto& set : aliasInfo.beforeSets()) {
95 if (first) {
96 first = false;
97 } else {
98 out << "|";
99 }
100 out << set.toUnqualString();
101 }
102 if (aliasInfo.isWrite()) {
103 out << "!";
104 }
105 if (aliasInfo.beforeSets() != aliasInfo.afterSets()) {
106 out << " -> ";
107 first = true;
108 for (const auto& set : aliasInfo.afterSets()) {
109 if (first) {
110 first = false;
111 } else {
112 out << "|";
113 }
114 out << set.toUnqualString();
115 }
116 }
117 out << ")";
118 return out;
119}
120} // namespace c10
121
122namespace std {
123template <>
124 struct hash<c10::AliasInfo> {
125 size_t operator()(const c10::AliasInfo& aliasInfo) const {
126 auto hash = std::hash<bool>()(aliasInfo.isWrite());
127
128 // NOTE: for unordered_set hashes, we couldn't use hash_combine
129 // because hash_combine is order dependent. Instead, we choose to
130 // use XOR as the combining function as XOR is commutative.
131 size_t before_set_hash_seed = 0;
132 for (auto &e: aliasInfo.beforeSets()) {
133 auto symbol_hash = std::hash<c10::Symbol>()(e);
134 before_set_hash_seed = before_set_hash_seed ^ symbol_hash;
135 }
136 size_t after_set_hash_seed = 0;
137 for (auto &e: aliasInfo.afterSets()) {
138 auto symbol_hash = std::hash<c10::Symbol>()(e);
139 after_set_hash_seed = after_set_hash_seed ^ symbol_hash;
140 }
141
142 hash = c10::hash_combine(hash, before_set_hash_seed);
143 hash = c10::hash_combine(hash, after_set_hash_seed);
144 for (auto &e: aliasInfo.containedTypes()) {
145 auto contained_type_hash = std::hash<c10::AliasInfo>()(e);
146 hash = c10::hash_combine(hash, contained_type_hash);
147 }
148 return hash;
149 }
150 };
151}
152