1#pragma once
2
3#include <ostream>
4#include <vector>
5
6#include <c10/core/Scalar.h>
7#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
8#include <torch/csrc/lazy/core/hash.h>
9
10C10_DECLARE_bool(ltc_enable_symbolic_shapes);
11
12namespace torch {
13namespace lazy {
14
15class TORCH_API Shape {
16 public:
17 Shape() = default;
18
19 Shape(
20 at::ScalarType scalar_type,
21 c10::ArrayRef<int64_t> sizes,
22 c10::optional<std::vector<bool>> is_symbolic = c10::nullopt);
23
24 std::string to_string() const;
25
26 c10::ScalarType scalar_type() const {
27 return scalar_type_;
28 }
29 void set_scalar_type(at::ScalarType value) {
30 scalar_type_ = value;
31 }
32
33 int64_t dim() const {
34 return sizes_.size();
35 }
36 c10::ArrayRef<int64_t> sizes() const {
37 return sizes_;
38 }
39 int64_t size(int64_t dim) const {
40 return sizes_.at(dim);
41 }
42 void set_size(int64_t dim, int64_t size) {
43 sizes_.at(dim) = size;
44 }
45
46 const c10::optional<std::vector<bool>>& is_symbolic() const {
47 return is_symbolic_;
48 }
49
50 // Makes a copy with symbolic dims applied
51 Shape with_symbolic_dims(
52 c10::optional<std::vector<bool>> symbolic_dims) const;
53
54 size_t numel() const;
55 hash_t hash(bool bakeInSizes) const;
56
57 bool operator==(const Shape& other) const;
58
59 private:
60 c10::ScalarType scalar_type_{c10::ScalarType::Undefined};
61
62 // Sizes are the upper bound sizes for a tensor, used by XLA.
63 std::vector<int64_t> sizes_;
64 // Stores which dimmensions are symbolic
65 // If nullopt, either it hasn't been initialized or the symbolic
66 // dimmensions are not calculatable
67 c10::optional<std::vector<bool>> is_symbolic_ = c10::nullopt;
68};
69
70TORCH_API std::ostream& operator<<(std::ostream& out, const Shape& shape);
71
72TORCH_API bool symbolicShapeEnabled();
73// Calculate and applies symbolic shapes onto the
74// Shape objects passed to result_shapes
75TORCH_API void applySymbolicShapesOnLT(
76 const char* schema_str,
77 std::vector<c10::IValue> args,
78 std::vector<Shape>& result_shapes);
79} // namespace lazy
80} // namespace torch
81