1 | #pragma once |
---|---|
2 | |
3 | #include <ostream> |
4 | #include <vector> |
5 | |
6 | #include <c10/core/Scalar.h> |
7 | #include <torch/csrc/jit/passes/symbolic_shape_analysis.h> |
8 | #include <torch/csrc/lazy/core/hash.h> |
9 | |
10 | C10_DECLARE_bool(ltc_enable_symbolic_shapes); |
11 | |
12 | namespace torch { |
13 | namespace lazy { |
14 | |
15 | class TORCH_API Shape { |
16 | public: |
17 | Shape() = default; |
18 | |
19 | Shape( |
20 | at::ScalarType scalar_type, |
21 | c10::ArrayRef<int64_t> sizes, |
22 | c10::optional<std::vector<bool>> is_symbolic = c10::nullopt); |
23 | |
24 | std::string to_string() const; |
25 | |
26 | c10::ScalarType scalar_type() const { |
27 | return scalar_type_; |
28 | } |
29 | void set_scalar_type(at::ScalarType value) { |
30 | scalar_type_ = value; |
31 | } |
32 | |
33 | int64_t dim() const { |
34 | return sizes_.size(); |
35 | } |
36 | c10::ArrayRef<int64_t> sizes() const { |
37 | return sizes_; |
38 | } |
39 | int64_t size(int64_t dim) const { |
40 | return sizes_.at(dim); |
41 | } |
42 | void set_size(int64_t dim, int64_t size) { |
43 | sizes_.at(dim) = size; |
44 | } |
45 | |
46 | const c10::optional<std::vector<bool>>& is_symbolic() const { |
47 | return is_symbolic_; |
48 | } |
49 | |
50 | // Makes a copy with symbolic dims applied |
51 | Shape with_symbolic_dims( |
52 | c10::optional<std::vector<bool>> symbolic_dims) const; |
53 | |
54 | size_t numel() const; |
55 | hash_t hash(bool bakeInSizes) const; |
56 | |
57 | bool operator==(const Shape& other) const; |
58 | |
59 | private: |
60 | c10::ScalarType scalar_type_{c10::ScalarType::Undefined}; |
61 | |
62 | // Sizes are the upper bound sizes for a tensor, used by XLA. |
63 | std::vector<int64_t> sizes_; |
64 | // Stores which dimmensions are symbolic |
65 | // If nullopt, either it hasn't been initialized or the symbolic |
66 | // dimmensions are not calculatable |
67 | c10::optional<std::vector<bool>> is_symbolic_ = c10::nullopt; |
68 | }; |
69 | |
70 | TORCH_API std::ostream& operator<<(std::ostream& out, const Shape& shape); |
71 | |
72 | TORCH_API bool symbolicShapeEnabled(); |
73 | // Calculate and applies symbolic shapes onto the |
74 | // Shape objects passed to result_shapes |
75 | TORCH_API void applySymbolicShapesOnLT( |
76 | const char* schema_str, |
77 | std::vector<c10::IValue> args, |
78 | std::vector<Shape>& result_shapes); |
79 | } // namespace lazy |
80 | } // namespace torch |
81 |