1#pragma once
2
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
5
6namespace torch {
7namespace jit {
8
9struct TORCH_API CanonicalizedSymbolicShape {
10 // TODO: Consider in the future if it is reasonable to
11 // merge code with SymbolicShape or VaryingShape while keeping
12 // the two not implicitly convertable (and cause bugs).
13 CanonicalizedSymbolicShape(
14 const c10::SymbolicShape& orig_shape,
15 std::unordered_map<int64_t, int64_t>& ss_map) {
16 init(orig_shape, ss_map);
17 }
18
19 CanonicalizedSymbolicShape(c10::SymbolicShape& orig_shape) {
20 std::unordered_map<int64_t, int64_t> new_ssmap;
21 init(orig_shape, new_ssmap);
22 }
23
24 size_t hash() const;
25
26 c10::SymbolicShape toSymbolicShape(
27 std::unordered_map<int64_t, int64_t>& inverse_ss_map) const;
28
29 TORCH_API friend bool operator==(
30 const CanonicalizedSymbolicShape& a,
31 const CanonicalizedSymbolicShape& b);
32
33 private:
34 c10::optional<std::vector<int64_t>> values_;
35
36 void init(
37 const c10::SymbolicShape& orig_shape,
38 std::unordered_map<int64_t, int64_t>& ss_map);
39};
40
41// SHAPE CACHE API
42TORCH_API c10::optional<std::vector<at::SymbolicShape>>
43get_cached_shape_function(
44 const FunctionSchema* schema,
45 const std::vector<SSAInput>& arg_vec);
46
47TORCH_API void cache_shape_function(
48 const FunctionSchema* schema,
49 const std::vector<SSAInput>& arg_vec,
50 const std::vector<at::SymbolicShape>& ret_vec);
51
52// For use in test code
53TORCH_API void clear_shape_cache();
54TORCH_API size_t get_shape_cache_size();
55
56} // namespace jit
57} // namespace torch
58