#include <type_promotion.h>

#include <arith.h>
#include <ir_interface_nodes.h>

#include <ATen/native/TypeProperties.h>
#include <c10/core/ScalarType.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

enum ValueType { Tensor, Scalar, None };

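// Describes an operand for type promotion: whether it is a tensor or a
// scalar, its scalar type, and (for tensors) its dimensionality.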
struct OperandType {
  ValueType value_type = ValueType::Tensor;
  c10::ScalarType scalar_type = c10::ScalarType::Undefined;
  size_t dim = 0;
};

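// Promotes two scalar types, treating Undefined as an identity element so
// that an uninitialized promotion state does not poison the result.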
c10::ScalarType promoteTypesSkipUndefined(
    c10::ScalarType a,
    c10::ScalarType b) {
  if (a == c10::ScalarType::Undefined) {
    return b;
  }
  if (b == c10::ScalarType::Undefined) {
    return a;
  }
  return c10::promoteTypes(a, b);
}

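// Folds a tensor operand into the running result-type state. Dimensioned
// and zero-dim tensors participate in separate promotion lattices,
// mirroring at::native::result_type.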
at::native::ResultTypeState updateResultTypeState(
    OperandType tensor,
    const at::native::ResultTypeState& in_state) {
  at::native::ResultTypeState new_state = in_state;
  c10::ScalarType current = tensor.scalar_type;

  if (tensor.dim > 0) {
    new_state.dimResult =
        promoteTypesSkipUndefined(in_state.dimResult, current);
  } else {
    new_state.zeroResult =
        promoteTypesSkipUndefined(in_state.zeroResult, current);
  }
  return new_state;
}

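// Folds a wrapped scalar into the running state. Floating-point scalars
// participate with the default dtype (typically Float) rather than their
// own type, matching eager-mode scalar promotion.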
at::native::ResultTypeState updateResultTypeState(
    const c10::ScalarType scalar,
    const at::native::ResultTypeState& in_state) {
  at::native::ResultTypeState new_state = in_state;
  c10::ScalarType current = scalar;
  if (c10::isFloatingType(scalar)) {
    current = c10::typeMetaToScalarType(at::get_default_dtype());
  }
  new_state.wrappedResult =
      promoteTypesSkipUndefined(in_state.wrappedResult, current);
  return new_state;
}

// Computes a common dtype using type promotion
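// For example, {Half tensor, double scalar} promotes to Half: a
// dimensioned tensor dominates a wrapped scalar of the same category.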
c10::ScalarType computeCommonDtype(const std::vector<OperandType>& operands) {
  at::native::ResultTypeState state = {};
  for (const auto& op : operands) {
    if (op.value_type == ValueType::Tensor) {
      state = updateResultTypeState(op, state);
    } else {
      state = updateResultTypeState(op.scalar_type, state);
    }
  }
  auto common_dtype = at::native::result_type(state);
  TORCH_INTERNAL_ASSERT(common_dtype != c10::ScalarType::Undefined);
  return common_dtype;
}

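// Determines the common dtype for a set of operands. The full promotion
// pass runs only when the input dtypes actually differ; integer and bool
// results may then be promoted to the default float type, as ops like
// div require.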
c10::ScalarType computeTypes(
    const TypePromotionConfig& config,
    const std::vector<OperandType>& operands) {
  auto common_dtype = c10::ScalarType::Undefined;

  bool has_different_input_dtypes = false;
  for (auto& op : operands) {
    if (op.scalar_type != common_dtype) {
      if (common_dtype == c10::ScalarType::Undefined) {
        common_dtype = op.scalar_type;
      } else {
        has_different_input_dtypes = true;
      }
    }
  }

  // Computes a common dtype, if needed
  if (has_different_input_dtypes) {
    common_dtype = computeCommonDtype(operands);
  }

  // Promotes common dtype to the default float scalar type, if needed
  if (config.promote_integer_inputs_to_float &&
      c10::isIntegralType(common_dtype, /*includeBool=*/true)) {
    common_dtype = c10::get_default_dtype_as_scalartype();
  }
  return common_dtype;
}

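// Classifies a JIT type as a tensor or scalar operand. Tensor types must
// carry scalar type information; a missing rank is treated as 1 so the
// operand still promotes as a dimensioned tensor.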
OperandType getValueType(TypePtr type) {
  if (auto tensor_type = type->cast<TensorType>()) {
    TORCH_INTERNAL_ASSERT(
        tensor_type->scalarType().has_value(),
        "Missing Scalar Type information");
    // TODO: Type Inference does not propagate Shape Information
    return {
        ValueType::Tensor,
        tensor_type->scalarType().value(),
        tensor_type->dim().has_value() ? tensor_type->dim().value() : 1};
  } else if (auto scalar_type = tryScalarTypeFromJitType(*type)) {
    return {ValueType::Scalar, scalar_type.value()};
  } else {
    return {ValueType::None, c10::ScalarType::Undefined};
  }
}

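// Classifies a fusion IR value. TensorViews report the rank of their
// rfactor domain; any other value with a known data type is a scalar.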
OperandType getValueType(Val* type) {
  TORCH_INTERNAL_ASSERT(type->getDataType().has_value());

  if (type->isA<TensorView>()) {
    auto tensor_view = type->as<TensorView>();
    return {
        ValueType::Tensor,
        data_type_to_aten(tensor_view->getDataType().value()),
        tensor_view->getMaybeRFactorDomain().size()};
  } else if (type->getDataType().has_value()) {
    return {ValueType::Scalar, data_type_to_aten(type->getDataType().value())};
  } else {
    return {ValueType::None, c10::ScalarType::Undefined};
  }
}

} // namespace

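// Computes the promoted scalar type for operands taken from a JIT graph.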
c10::ScalarType computeTypes(
    const TypePromotionConfig& config,
    const std::vector<TypePtr>& operands) {
  std::vector<OperandType> vt_operands;
  vt_operands.reserve(operands.size());
  for (const auto& op : operands) {
    vt_operands.emplace_back(getValueType(op));
  }
  return computeTypes(config, vt_operands);
}

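// Computes the promoted data type for fusion IR operands. Half and
// BFloat16 are widened to Float so that intermediate computation in the
// fusion runs in single precision.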
DataType computeTypes(
    const TypePromotionConfig& config,
    const std::vector<Val*>& operands) {
  std::vector<OperandType> vt_operands;
  vt_operands.reserve(operands.size());
  for (const auto& op : operands) {
    vt_operands.push_back(getValueType(op));
  }

  auto common_type = aten_to_data_type(computeTypes(config, vt_operands));

  // Cast FP16 / BFloat16 to Float
  if (common_type == DataType::Half || common_type == DataType::BFloat16) {
    common_type = DataType::Float;
  }

  return common_type;
}

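// Casts each operand to common_type where needed; values that already
// conform are passed through unchanged.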
std::vector<Val*> promoteValues(
    const std::vector<Val*>& operands,
    DataType common_type) {
  std::vector<Val*> promoted_operands;
  promoted_operands.reserve(operands.size());
  for (auto op : operands) {
    promoted_operands.push_back(optionalCast(common_type, op));
  }

  TORCH_INTERNAL_ASSERT(operands.size() == promoted_operands.size());
  return promoted_operands;
}

std::vector<Val*> promoteValues(
    const TypePromotionConfig& config,
    const std::vector<Val*>& operands) {
  return promoteValues(operands, computeTypes(config, operands));
}

Val* optionalCast(DataType dtype, Val* v) {
  TORCH_INTERNAL_ASSERT(v->getDataType().has_value());
  // Avoid casting a Float/Int/Complex scalar when the target dtype falls in
  // the same category (floating point / integral / complex); such scalars
  // are used directly. The exception is Bool, which is always cast to the
  // desired type.
  const bool kSameDtype = v->getDataType().value() == dtype;
  const bool kIsScalarFloat =
      !v->isA<TensorView>() && isFloatingPointType(dtype);
  const bool kIsScalarInt = !v->isA<TensorView>() && isIntegralType(dtype);
  const bool kIsScalarComplex = !v->isA<TensorView>() && isComplexType(dtype);
  if (kSameDtype ||
      (kIsScalarFloat && isFloatingPointType(v->getDataType().value())) ||
      (kIsScalarInt && isIntegralType(v->getDataType().value())) ||
      (kIsScalarComplex && isComplexType(v->getDataType().value()))) {
    return v;
  } else {
    return castOp(dtype, v);
  }
}

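// Unlike optionalCast, inserts a cast whenever the dtype differs, even
// for scalars whose type category already matches.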
Val* optionalCastStrict(DataType dtype, Val* v) {
  TORCH_INTERNAL_ASSERT(v->getDataType().has_value());
  const bool kSameDtype = v->getDataType().value() == dtype;
  return (kSameDtype) ? v : castOp(dtype, v);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch