1 | #include <torch/csrc/jit/passes/symbolic_shape_analysis.h> |
2 | #include <torch/csrc/jit/passes/symbolic_shape_cache.h> |
3 | #include <torch/csrc/lazy/core/cache.h> |
4 | |
5 | #include <utility> |
6 | |
// SHAPE CACHING CODE
8 | namespace torch { |
9 | namespace jit { |
10 | namespace { |
// One canonicalized operator argument: holds either a
// CanonicalizedSymbolicShape or an IValue.
using CanonicalArg = c10::variant<CanonicalizedSymbolicShape, IValue>;
// Sequence of canonicalized arguments; this is the second element of
// ShapeCacheKey.
using CanonicalArgVec = std::vector<CanonicalArg>;
// Vector of canonicalized shapes.
// NOTE(review): this alias is not referenced by name in the code visible
// here; the cache value type and the return-shape overload of
// cannonicalizeVec spell the same vector type out explicitly.
using CanonicalRet = std::vector<CanonicalizedSymbolicShape>;
// Key used by the shape cache: (operator name, canonicalized arguments).
using ShapeCacheKey = std::tuple<c10::OperatorName, CanonicalArgVec>;
15 | |
16 | CanonicalArgVec cannonicalizeVec( |
17 | const std::vector<SSAInput>& arg_vec, |
18 | std::unordered_map<int64_t, int64_t>& ss_map, |
19 | bool deep_copy = true) { |
20 | CanonicalArgVec canonical_args; |
21 | canonical_args.reserve(arg_vec.size()); |
22 | for (auto& arg : arg_vec) { |
23 | if (const IValue* iv = c10::get_if<IValue>(&arg)) { |
24 | if (deep_copy) { |
25 | canonical_args.emplace_back(iv->deepcopy()); |
26 | } else { |
27 | canonical_args.emplace_back(*iv); |
28 | } |
29 | } else { |
30 | auto& ss = c10::get<at::SymbolicShape>(arg); |
31 | canonical_args.emplace_back(CanonicalizedSymbolicShape(ss, ss_map)); |
32 | } |
33 | } |
34 | return canonical_args; |
35 | } |
36 | |
37 | std::vector<CanonicalizedSymbolicShape> cannonicalizeVec( |
38 | const std::vector<at::SymbolicShape>& ret_vec, |
39 | std::unordered_map<int64_t, int64_t>& ss_map) { |
40 | std::vector<CanonicalizedSymbolicShape> canonical_rets; |
41 | canonical_rets.reserve(ret_vec.size()); |
42 | for (auto& ss : ret_vec) { |
43 | canonical_rets.emplace_back(ss, ss_map); |
44 | } |
45 | return canonical_rets; |
46 | } |
47 | |
48 | struct ArgumentsHasher { |
49 | size_t operator()(const ShapeCacheKey& cacheKey) const { |
50 | // TODO: ignore arguments that are not used in shape function (not needed |
51 | // initially) |
52 | auto& op_name = std::get<0>(cacheKey); |
53 | auto& arg_vec = std::get<1>(cacheKey); |
54 | |
55 | size_t hash_val = c10::hash<c10::OperatorName>()(op_name); |
56 | |
57 | hash_val = at::hash_combine(std::hash<size_t>{}(arg_vec.size()), hash_val); |
58 | for (const CanonicalArg& arg : arg_vec) { |
59 | size_t cur_arg = 0; |
60 | if (const IValue* ival = c10::get_if<IValue>(&arg)) { |
61 | // IValue doesn't hash List (as Python doesn't), so we will do a custom |
62 | // list hash |
63 | if (ival->isList()) { |
64 | TORCH_INTERNAL_ASSERT(ival->isIntList(), "Unexpected Args in List" ); |
65 | cur_arg = ival->toListRef().size(); |
66 | for (const IValue& elem_ival : ival->toListRef()) { |
67 | cur_arg = at::hash_combine(cur_arg, IValue::hash(elem_ival)); |
68 | } |
69 | } else { |
70 | cur_arg = IValue::hash(ival); |
71 | } |
72 | } else { |
73 | cur_arg = c10::get<CanonicalizedSymbolicShape>(arg).hash(); |
74 | } |
75 | hash_val = at::hash_combine(hash_val, cur_arg); |
76 | } |
77 | return hash_val; |
78 | } |
79 | }; |
80 | |
// Cache type used for shape-function results: keyed by ShapeCacheKey, storing
// vectors of canonicalized output shapes, hashed with ArgumentsHasher.
using ShapeCache = lazy::Cache<
    ShapeCacheKey,
    std::vector<CanonicalizedSymbolicShape>,
    ArgumentsHasher>;

// Size argument passed to the cache constructor below.
// NOTE(review): presumably the maximum number of entries the cache keeps;
// confirm against lazy::Cache.
constexpr size_t kShapeCacheSize = 1024;
// The single shape cache instance of this translation unit (it lives in the
// anonymous namespace, so it is not visible to other files).
ShapeCache shapeCache(kShapeCacheSize);
88 | |
89 | ShapeCacheKey get_cache_key( |
90 | const FunctionSchema* schema, |
91 | const std::vector<SSAInput>& arg_vec, |
92 | std::unordered_map<int64_t, int64_t>& ss_map, |
93 | bool deep_copy = true) { |
94 | CanonicalArgVec canonical_args = cannonicalizeVec(arg_vec, ss_map, deep_copy); |
95 | return std::make_tuple(schema->operator_name(), canonical_args); |
96 | } |
97 | |
98 | } // namespace |
99 | |
100 | TORCH_API void cache_shape_function( |
101 | const FunctionSchema* schema, |
102 | const std::vector<SSAInput>& arg_vec, |
103 | const std::vector<at::SymbolicShape>& ret_vec) { |
104 | // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>> |
105 | auto ss_map = std::unordered_map<int64_t, int64_t>(); |
106 | auto cache_key = get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ true); |
107 | auto can_ret_vec = std::make_shared<std::vector<CanonicalizedSymbolicShape>>( |
108 | cannonicalizeVec(ret_vec, ss_map)); |
109 | shapeCache.Add(std::move(cache_key), std::move(can_ret_vec)); |
110 | } |
111 | |
112 | TORCH_API c10::optional<std::vector<at::SymbolicShape>> |
113 | get_cached_shape_function( |
114 | const FunctionSchema* schema, |
115 | const std::vector<SSAInput>& arg_vec) { |
116 | // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>> for both |
117 | // ss_map and inverse_ss_map |
118 | auto ss_map = std::unordered_map<int64_t, int64_t>(); |
119 | auto cache_key = |
120 | get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ false); |
121 | auto cached_ret_vec = shapeCache.Get(cache_key); |
122 | if (cached_ret_vec == nullptr) { |
123 | return c10::nullopt; |
124 | } |
125 | // Decanonicalize the return values |
126 | auto inverse_ss_map = std::unordered_map<int64_t, int64_t>(); |
127 | for (auto& ss_val : ss_map) { |
128 | inverse_ss_map[ss_val.second] = ss_val.first; |
129 | } |
130 | std::vector<at::SymbolicShape> ret_vec; |
131 | for (auto& css : *cached_ret_vec) { |
132 | ret_vec.emplace_back(css.toSymbolicShape(inverse_ss_map)); |
133 | } |
134 | return ret_vec; |
135 | } |
136 | |
// Testing helper: gives access to the file-local shape cache by calling
// Clear() on it.
TORCH_API void clear_shape_cache() {
  shapeCache.Clear();
}
141 | |
// Returns shapeCache.Numel().
// NOTE(review): presumably the number of entries currently cached, and
// presumably intended for tests only; confirm against lazy::Cache and the
// callers.
TORCH_API size_t get_shape_cache_size() {
  return shapeCache.Numel();
}
145 | |
146 | void CanonicalizedSymbolicShape::init( |
147 | const c10::SymbolicShape& orig_shape, |
148 | std::unordered_map<int64_t, int64_t>& ss_map) { |
149 | auto sizes = orig_shape.sizes(); |
150 | if (!sizes) { |
151 | values_ = c10::nullopt; |
152 | return; |
153 | } |
154 | values_ = std::vector<int64_t>(); |
155 | int64_t cur_symbolic_index = -static_cast<int64_t>(ss_map.size()) - 1; |
156 | for (auto& cur_shape : *sizes) { |
157 | if (cur_shape.is_static()) { |
158 | values_->push_back(cur_shape.static_size()); |
159 | } else { |
160 | // Check for aliasing |
161 | auto it = ss_map.find(cur_shape.value()); |
162 | |
163 | if (it == ss_map.end()) { |
164 | values_->push_back(cur_symbolic_index); |
165 | ss_map.insert({cur_shape.value(), cur_symbolic_index}); |
166 | cur_symbolic_index--; |
167 | } else { |
168 | values_->push_back(it->second); |
169 | } |
170 | } |
171 | } |
172 | } |
173 | |
174 | c10::SymbolicShape CanonicalizedSymbolicShape::toSymbolicShape( |
175 | std::unordered_map<int64_t, int64_t>& inverse_ss_map) const { |
176 | if (!values_.has_value()) { |
177 | return c10::SymbolicShape(); |
178 | } |
179 | std::vector<at::ShapeSymbol> sizes; |
180 | for (long long cur_val : *values_) { |
181 | if (cur_val >= 0) { |
182 | sizes.push_back(at::ShapeSymbol::fromStaticSize(cur_val)); |
183 | continue; |
184 | } |
185 | auto res = inverse_ss_map.find(cur_val); |
186 | if (res != inverse_ss_map.end()) { |
187 | sizes.push_back(at::ShapeSymbol::fromStaticSize(res->second)); |
188 | } else { |
189 | auto new_symbol = at::ShapeSymbol::newSymbol(); |
190 | inverse_ss_map.insert({cur_val, new_symbol.value()}); |
191 | sizes.push_back(new_symbol); |
192 | } |
193 | } |
194 | return c10::SymbolicShape(std::move(sizes)); |
195 | } |
196 | |
197 | size_t CanonicalizedSymbolicShape::hash() const { |
198 | if (!values_.has_value()) { |
199 | return 0x8cc80c80; // random value to prevent hash collisions |
200 | } |
201 | return c10::hash<std::vector<int64_t>>()(values_.value()); |
202 | } |
203 | |
204 | bool operator==( |
205 | const CanonicalizedSymbolicShape& a, |
206 | const CanonicalizedSymbolicShape& b) { |
207 | return a.values_ == b.values_; |
208 | }; |
209 | } // namespace jit |
210 | } // namespace torch |
211 | |