1#include <c10/util/irange.h>
2#include <torch/csrc/lazy/core/shape.h>
3#include <torch/csrc/lazy/core/tensor.h>
4
// Process-wide boolean switch, read as FLAGS_ltc_enable_symbolic_shapes.
// Defaults to off. symbolicShapeEnabled() ORs it with the presence of the
// LTC_ENABLE_SYMBOLIC_SHAPES environment variable.
C10_DEFINE_bool(
    ltc_enable_symbolic_shapes,
    false,
    "Enables calculation of if dims are symbolic");
9
10namespace torch {
11namespace lazy {
12
13Shape::Shape(
14 at::ScalarType scalar_type,
15 c10::ArrayRef<int64_t> sizes,
16 c10::optional<std::vector<bool>> is_symbolic)
17 : scalar_type_(scalar_type),
18 sizes_(sizes.begin(), sizes.end()),
19 is_symbolic_(std::move(is_symbolic)) {}
20
21std::string Shape::to_string() const {
22 return c10::str(toString(scalar_type_), "[", c10::Join(",", sizes_), "]");
23}
24
25bool Shape::operator==(const Shape& other) const {
26 return scalar_type_ == other.scalar_type_ && sizes_ == other.sizes_;
27}
28
29std::ostream& operator<<(std::ostream& out, const Shape& shape) {
30 return out << shape.to_string();
31}
32
33size_t Shape::numel() const {
34 size_t elts = 1;
35 for (auto size : sizes_) {
36 elts *= size;
37 }
38 return elts;
39}
40
41hash_t Shape::hash(bool bakeInSizes) const {
42 if (bakeInSizes) {
43 return HashCombine(
44 Hash(scalar_type_),
45 DataHash(sizes_.data(), sizes_.size() * sizeof(int64_t)));
46 } else {
47 return HashCombine(Hash(scalar_type_), Hash(sizes_.size()));
48 }
49}
50
51Shape Shape::with_symbolic_dims(
52 c10::optional<std::vector<bool>> symbolic_dims) const {
53 Shape copy = *this;
54 copy.is_symbolic_ = symbolic_dims;
55 return copy;
56}
57
58bool symbolicShapeEnabled() {
59 static bool enabled = std::getenv("LTC_ENABLE_SYMBOLIC_SHAPES") != nullptr;
60 return enabled || FLAGS_ltc_enable_symbolic_shapes;
61}
62
63c10::SymbolicShape get_symbolic_shape(at::Tensor& tensor) {
64 auto ltc_tensor = TryGetLtcTensor(tensor);
65 if (!ltc_tensor) {
66 // Set Concrete sizes for Concrete tensors
67 return c10::SymbolicShape(tensor.sizes());
68 }
69 const Shape& input_shape = ltc_tensor->GetIrValue()->shape();
70 auto& is_symbolic = input_shape.is_symbolic();
71 if (!is_symbolic.has_value()) {
72 return c10::SymbolicShape();
73 }
74 auto sizes = input_shape.sizes();
75 TORCH_INTERNAL_ASSERT(
76 sizes.size() == is_symbolic->size(),
77 "Dims of two values are not consistent");
78 std::vector<c10::optional<int64_t>> symbolic_dims;
79 for (int64_t i = 0; i < sizes.size(); i++) {
80 if (is_symbolic->at(i)) {
81 symbolic_dims.emplace_back(c10::nullopt);
82 } else {
83 symbolic_dims.emplace_back(sizes.at(i));
84 }
85 }
86 return c10::SymbolicShape(symbolic_dims);
87}
88
89void applySymbolicShapesOnLT(
90 const char* schema_str,
91 std::vector<c10::IValue> args,
92 std::vector<Shape>& result_shapes) {
93 std::vector<jit::SSAInput> converted_args;
94 // TODO: Determine if there are any unknown values in LazyTensor
95 const c10::FunctionSchema& schema =
96 jit::getOperatorForLiteral(schema_str)->schema();
97
98 for (auto& arg : args) {
99 // Handle list of tensors
100 if (arg.isTensorList()) {
101 at::List<at::Tensor> tensor_list = arg.toTensorList();
102 for (at::Tensor tensor : tensor_list) {
103 converted_args.emplace_back(get_symbolic_shape(tensor));
104 }
105 } else if (arg.isTensor()) {
106 auto ss = get_symbolic_shape(arg.toTensor());
107 converted_args.emplace_back(ss);
108 } else {
109 // If we need to support symbolic ints, here is the place
110 // to add it.
111 converted_args.emplace_back(arg);
112 }
113 }
114 auto res_symbolic = jit::calculateSymbolicShapesOnOp(&schema, converted_args);
115 if (!res_symbolic) {
116 for (auto& result_shape : result_shapes) {
117 result_shape = result_shape.with_symbolic_dims(c10::nullopt);
118 }
119 } else {
120 TORCH_INTERNAL_ASSERT(
121 res_symbolic->size() == result_shapes.size(),
122 "Result shape size is not consistent");
123 for (int64_t i = 0; i < res_symbolic->size(); i++) {
124 auto sym_dims = res_symbolic->at(i).symbolicDims();
125 if (sym_dims.has_value()) {
126 result_shapes[i] = result_shapes[i].with_symbolic_dims(*sym_dims);
127 }
128 }
129 }
130}
131
132} // namespace lazy
133} // namespace torch
134