1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | #include <torch/csrc/jit/passes/symbolic_shape_analysis.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | |
9 | struct TORCH_API CanonicalizedSymbolicShape { |
10 | // TODO: Consider in the future if it is reasonable to |
11 | // merge code with SymbolicShape or VaryingShape while keeping |
12 | // the two not implicitly convertable (and cause bugs). |
13 | CanonicalizedSymbolicShape( |
14 | const c10::SymbolicShape& orig_shape, |
15 | std::unordered_map<int64_t, int64_t>& ss_map) { |
16 | init(orig_shape, ss_map); |
17 | } |
18 | |
19 | CanonicalizedSymbolicShape(c10::SymbolicShape& orig_shape) { |
20 | std::unordered_map<int64_t, int64_t> new_ssmap; |
21 | init(orig_shape, new_ssmap); |
22 | } |
23 | |
24 | size_t hash() const; |
25 | |
26 | c10::SymbolicShape toSymbolicShape( |
27 | std::unordered_map<int64_t, int64_t>& inverse_ss_map) const; |
28 | |
29 | TORCH_API friend bool operator==( |
30 | const CanonicalizedSymbolicShape& a, |
31 | const CanonicalizedSymbolicShape& b); |
32 | |
33 | private: |
34 | c10::optional<std::vector<int64_t>> values_; |
35 | |
36 | void init( |
37 | const c10::SymbolicShape& orig_shape, |
38 | std::unordered_map<int64_t, int64_t>& ss_map); |
39 | }; |
40 | |
41 | // SHAPE CACHE API |
42 | TORCH_API c10::optional<std::vector<at::SymbolicShape>> |
43 | get_cached_shape_function( |
44 | const FunctionSchema* schema, |
45 | const std::vector<SSAInput>& arg_vec); |
46 | |
47 | TORCH_API void cache_shape_function( |
48 | const FunctionSchema* schema, |
49 | const std::vector<SSAInput>& arg_vec, |
50 | const std::vector<at::SymbolicShape>& ret_vec); |
51 | |
52 | // For use in test code |
53 | TORCH_API void clear_shape_cache(); |
54 | TORCH_API size_t get_shape_cache_size(); |
55 | |
56 | } // namespace jit |
57 | } // namespace torch |
58 | |