1 | #include <c10/util/irange.h> |
2 | #include <torch/csrc/lazy/core/shape.h> |
3 | #include <torch/csrc/lazy/core/tensor.h> |
4 | |
5 | C10_DEFINE_bool( |
6 | ltc_enable_symbolic_shapes, |
7 | false, |
8 | "Enables calculation of if dims are symbolic" ); |
9 | |
10 | namespace torch { |
11 | namespace lazy { |
12 | |
13 | Shape::Shape( |
14 | at::ScalarType scalar_type, |
15 | c10::ArrayRef<int64_t> sizes, |
16 | c10::optional<std::vector<bool>> is_symbolic) |
17 | : scalar_type_(scalar_type), |
18 | sizes_(sizes.begin(), sizes.end()), |
19 | is_symbolic_(std::move(is_symbolic)) {} |
20 | |
21 | std::string Shape::to_string() const { |
22 | return c10::str(toString(scalar_type_), "[" , c10::Join("," , sizes_), "]" ); |
23 | } |
24 | |
25 | bool Shape::operator==(const Shape& other) const { |
26 | return scalar_type_ == other.scalar_type_ && sizes_ == other.sizes_; |
27 | } |
28 | |
29 | std::ostream& operator<<(std::ostream& out, const Shape& shape) { |
30 | return out << shape.to_string(); |
31 | } |
32 | |
33 | size_t Shape::numel() const { |
34 | size_t elts = 1; |
35 | for (auto size : sizes_) { |
36 | elts *= size; |
37 | } |
38 | return elts; |
39 | } |
40 | |
41 | hash_t Shape::hash(bool bakeInSizes) const { |
42 | if (bakeInSizes) { |
43 | return HashCombine( |
44 | Hash(scalar_type_), |
45 | DataHash(sizes_.data(), sizes_.size() * sizeof(int64_t))); |
46 | } else { |
47 | return HashCombine(Hash(scalar_type_), Hash(sizes_.size())); |
48 | } |
49 | } |
50 | |
51 | Shape Shape::with_symbolic_dims( |
52 | c10::optional<std::vector<bool>> symbolic_dims) const { |
53 | Shape copy = *this; |
54 | copy.is_symbolic_ = symbolic_dims; |
55 | return copy; |
56 | } |
57 | |
58 | bool symbolicShapeEnabled() { |
59 | static bool enabled = std::getenv("LTC_ENABLE_SYMBOLIC_SHAPES" ) != nullptr; |
60 | return enabled || FLAGS_ltc_enable_symbolic_shapes; |
61 | } |
62 | |
63 | c10::SymbolicShape get_symbolic_shape(at::Tensor& tensor) { |
64 | auto ltc_tensor = TryGetLtcTensor(tensor); |
65 | if (!ltc_tensor) { |
66 | // Set Concrete sizes for Concrete tensors |
67 | return c10::SymbolicShape(tensor.sizes()); |
68 | } |
69 | const Shape& input_shape = ltc_tensor->GetIrValue()->shape(); |
70 | auto& is_symbolic = input_shape.is_symbolic(); |
71 | if (!is_symbolic.has_value()) { |
72 | return c10::SymbolicShape(); |
73 | } |
74 | auto sizes = input_shape.sizes(); |
75 | TORCH_INTERNAL_ASSERT( |
76 | sizes.size() == is_symbolic->size(), |
77 | "Dims of two values are not consistent" ); |
78 | std::vector<c10::optional<int64_t>> symbolic_dims; |
79 | for (int64_t i = 0; i < sizes.size(); i++) { |
80 | if (is_symbolic->at(i)) { |
81 | symbolic_dims.emplace_back(c10::nullopt); |
82 | } else { |
83 | symbolic_dims.emplace_back(sizes.at(i)); |
84 | } |
85 | } |
86 | return c10::SymbolicShape(symbolic_dims); |
87 | } |
88 | |
89 | void applySymbolicShapesOnLT( |
90 | const char* schema_str, |
91 | std::vector<c10::IValue> args, |
92 | std::vector<Shape>& result_shapes) { |
93 | std::vector<jit::SSAInput> converted_args; |
94 | // TODO: Determine if there are any unknown values in LazyTensor |
95 | const c10::FunctionSchema& schema = |
96 | jit::getOperatorForLiteral(schema_str)->schema(); |
97 | |
98 | for (auto& arg : args) { |
99 | // Handle list of tensors |
100 | if (arg.isTensorList()) { |
101 | at::List<at::Tensor> tensor_list = arg.toTensorList(); |
102 | for (at::Tensor tensor : tensor_list) { |
103 | converted_args.emplace_back(get_symbolic_shape(tensor)); |
104 | } |
105 | } else if (arg.isTensor()) { |
106 | auto ss = get_symbolic_shape(arg.toTensor()); |
107 | converted_args.emplace_back(ss); |
108 | } else { |
109 | // If we need to support symbolic ints, here is the place |
110 | // to add it. |
111 | converted_args.emplace_back(arg); |
112 | } |
113 | } |
114 | auto res_symbolic = jit::calculateSymbolicShapesOnOp(&schema, converted_args); |
115 | if (!res_symbolic) { |
116 | for (auto& result_shape : result_shapes) { |
117 | result_shape = result_shape.with_symbolic_dims(c10::nullopt); |
118 | } |
119 | } else { |
120 | TORCH_INTERNAL_ASSERT( |
121 | res_symbolic->size() == result_shapes.size(), |
122 | "Result shape size is not consistent" ); |
123 | for (int64_t i = 0; i < res_symbolic->size(); i++) { |
124 | auto sym_dims = res_symbolic->at(i).symbolicDims(); |
125 | if (sym_dims.has_value()) { |
126 | result_shapes[i] = result_shapes[i].with_symbolic_dims(*sym_dims); |
127 | } |
128 | } |
129 | } |
130 | } |
131 | |
132 | } // namespace lazy |
133 | } // namespace torch |
134 | |