#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/lazy/core/cache.h>

#include <utility>

// SHAPE CACHING CODE
namespace torch {
namespace jit {
namespace {
using CanonicalArg = c10::variant<CanonicalizedSymbolicShape, IValue>;
using CanonicalArgVec = std::vector<CanonicalArg>;
using CanonicalRet = std::vector<CanonicalizedSymbolicShape>;
using ShapeCacheKey = std::tuple<c10::OperatorName, CanonicalArgVec>;

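// Canonicalizes the argument vector: IValue arguments are kept (optionally
// deep-copied), while SymbolicShape arguments are converted to
// CanonicalizedSymbolicShape using the shared symbol map `ss_map`.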
CanonicalArgVec cannonicalizeVec(
    const std::vector<SSAInput>& arg_vec,
    std::unordered_map<int64_t, int64_t>& ss_map,
    bool deep_copy = true) {
  CanonicalArgVec canonical_args;
  canonical_args.reserve(arg_vec.size());
  for (auto& arg : arg_vec) {
    if (const IValue* iv = c10::get_if<IValue>(&arg)) {
      if (deep_copy) {
        canonical_args.emplace_back(iv->deepcopy());
      } else {
        canonical_args.emplace_back(*iv);
      }
    } else {
      auto& ss = c10::get<at::SymbolicShape>(arg);
      canonical_args.emplace_back(CanonicalizedSymbolicShape(ss, ss_map));
    }
  }
  return canonical_args;
}

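// Canonicalizes the return shapes using the same symbol map that was
// populated while canonicalizing the arguments, so aliasing between input
// and output dimensions is preserved.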
std::vector<CanonicalizedSymbolicShape> cannonicalizeVec(
    const std::vector<at::SymbolicShape>& ret_vec,
    std::unordered_map<int64_t, int64_t>& ss_map) {
  std::vector<CanonicalizedSymbolicShape> canonical_rets;
  canonical_rets.reserve(ret_vec.size());
  for (auto& ss : ret_vec) {
    canonical_rets.emplace_back(ss, ss_map);
  }
  return canonical_rets;
}

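// Hasher for the cache key: combines the operator name hash with a hash of
// each canonicalized argument. Int lists are hashed element-wise because
// IValue does not hash lists.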
struct ArgumentsHasher {
  size_t operator()(const ShapeCacheKey& cacheKey) const {
    // TODO: ignore arguments that are not used in shape function (not needed
    // initially)
    auto& op_name = std::get<0>(cacheKey);
    auto& arg_vec = std::get<1>(cacheKey);

    size_t hash_val = c10::hash<c10::OperatorName>()(op_name);

    hash_val = at::hash_combine(std::hash<size_t>{}(arg_vec.size()), hash_val);
    for (const CanonicalArg& arg : arg_vec) {
      size_t cur_arg = 0;
      if (const IValue* ival = c10::get_if<IValue>(&arg)) {
        // IValue doesn't hash List (as Python doesn't), so we will do a custom
        // list hash
        if (ival->isList()) {
          TORCH_INTERNAL_ASSERT(ival->isIntList(), "Unexpected Args in List");
          cur_arg = ival->toListRef().size();
          for (const IValue& elem_ival : ival->toListRef()) {
            cur_arg = at::hash_combine(cur_arg, IValue::hash(elem_ival));
          }
        } else {
          cur_arg = IValue::hash(*ival);
        }
      } else {
        cur_arg = c10::get<CanonicalizedSymbolicShape>(arg).hash();
      }
      hash_val = at::hash_combine(hash_val, cur_arg);
    }
    return hash_val;
  }
};

using ShapeCache = lazy::Cache<
    ShapeCacheKey,
    std::vector<CanonicalizedSymbolicShape>,
    ArgumentsHasher>;

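// Process-wide cache from canonicalized (operator name, arguments) keys to
// canonicalized output shapes, bounded at kShapeCacheSize entries.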
constexpr size_t kShapeCacheSize = 1024;
ShapeCache shapeCache(kShapeCacheSize);

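// Builds the cache key for a call: the operator name plus the canonicalized
// arguments. `ss_map` is filled with the symbolic-dim -> canonical-index
// mapping so callers can later decanonicalize cached results.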
ShapeCacheKey get_cache_key(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec,
    std::unordered_map<int64_t, int64_t>& ss_map,
    bool deep_copy = true) {
  CanonicalArgVec canonical_args = cannonicalizeVec(arg_vec, ss_map, deep_copy);
  return std::make_tuple(schema->operator_name(), std::move(canonical_args));
}

} // namespace

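// Stores the result of a shape-function evaluation. Arguments are
// deep-copied so the cached key cannot be mutated by the caller afterwards.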
TORCH_API void cache_shape_function(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec,
    const std::vector<at::SymbolicShape>& ret_vec) {
  // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>>
  auto ss_map = std::unordered_map<int64_t, int64_t>();
  auto cache_key = get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ true);
  auto can_ret_vec = std::make_shared<std::vector<CanonicalizedSymbolicShape>>(
      cannonicalizeVec(ret_vec, ss_map));
  shapeCache.Add(std::move(cache_key), std::move(can_ret_vec));
}

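// Looks up a previously cached shape-function result. Returns c10::nullopt on
// a cache miss; on a hit, the canonicalized shapes are mapped back to
// SymbolicShapes consistent with the current arguments via inverse_ss_map.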
TORCH_API c10::optional<std::vector<at::SymbolicShape>>
get_cached_shape_function(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& arg_vec) {
  // TODO: compare perf using std::vector<std::tuple<int64_t, int64_t>> for both
  // ss_map and inverse_ss_map
  auto ss_map = std::unordered_map<int64_t, int64_t>();
  auto cache_key =
      get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ false);
  auto cached_ret_vec = shapeCache.Get(cache_key);
  if (cached_ret_vec == nullptr) {
    return c10::nullopt;
  }
  // Decanonicalize the return values
  auto inverse_ss_map = std::unordered_map<int64_t, int64_t>();
  for (auto& ss_val : ss_map) {
    inverse_ss_map[ss_val.second] = ss_val.first;
  }
  std::vector<at::SymbolicShape> ret_vec;
  for (auto& css : *cached_ret_vec) {
    ret_vec.emplace_back(css.toSymbolicShape(inverse_ss_map));
  }
  return ret_vec;
}

// The following functions only access the cache; they are used for testing.
TORCH_API void clear_shape_cache() {
  shapeCache.Clear();
}

TORCH_API size_t get_shape_cache_size() {
  return shapeCache.Numel();
}

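// Canonicalization scheme: static dimensions are stored as-is (>= 0), while
// each distinct symbolic dimension is replaced by a negative index (-1, -2,
// ...) assigned in order of first appearance and shared through `ss_map`, so
// repeated symbols map to the same index across arguments.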
void CanonicalizedSymbolicShape::init(
    const c10::SymbolicShape& orig_shape,
    std::unordered_map<int64_t, int64_t>& ss_map) {
  auto sizes = orig_shape.sizes();
  if (!sizes) {
    values_ = c10::nullopt;
    return;
  }
  values_ = std::vector<int64_t>();
  int64_t cur_symbolic_index = -static_cast<int64_t>(ss_map.size()) - 1;
  for (auto& cur_shape : *sizes) {
    if (cur_shape.is_static()) {
      values_->push_back(cur_shape.static_size());
    } else {
      // Check for aliasing
      auto it = ss_map.find(cur_shape.value());

      if (it == ss_map.end()) {
        values_->push_back(cur_symbolic_index);
        ss_map.insert({cur_shape.value(), cur_symbolic_index});
        cur_symbolic_index--;
      } else {
        values_->push_back(it->second);
      }
    }
  }
}

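// Reverses the canonicalization: negative indices are mapped back to shape
// symbols via `inverse_ss_map`, creating fresh symbols for indices that did
// not appear in the current arguments.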
c10::SymbolicShape CanonicalizedSymbolicShape::toSymbolicShape(
    std::unordered_map<int64_t, int64_t>& inverse_ss_map) const {
  if (!values_.has_value()) {
    return c10::SymbolicShape();
  }
  std::vector<at::ShapeSymbol> sizes;
  for (int64_t cur_val : *values_) {
    if (cur_val >= 0) {
      sizes.push_back(at::ShapeSymbol::fromStaticSize(cur_val));
      continue;
    }
    auto res = inverse_ss_map.find(cur_val);
    if (res != inverse_ss_map.end()) {
      sizes.push_back(at::ShapeSymbol::fromStaticSize(res->second));
    } else {
      auto new_symbol = at::ShapeSymbol::newSymbol();
      inverse_ss_map.insert({cur_val, new_symbol.value()});
      sizes.push_back(new_symbol);
    }
  }
  return c10::SymbolicShape(std::move(sizes));
}

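// Unranked shapes hash to a fixed sentinel; ranked shapes hash their
// canonical value vector.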
size_t CanonicalizedSymbolicShape::hash() const {
  if (!values_.has_value()) {
    return 0x8cc80c80; // random value to prevent hash collisions
  }
  return c10::hash<std::vector<int64_t>>()(values_.value());
}

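// Equality compares the canonical value vectors (two unranked shapes compare
// equal).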
bool operator==(
    const CanonicalizedSymbolicShape& a,
    const CanonicalizedSymbolicShape& b) {
  return a.values_ == b.values_;
}
} // namespace jit
} // namespace torch