#include <type_promotion.h>

#include <arith.h>
#include <ir_interface_nodes.h>

#include <ATen/native/TypeProperties.h>
#include <c10/core/ScalarType.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

enum ValueType { Tensor, Scalar, None };

struct OperandType {
  ValueType value_type = ValueType::Tensor;
  c10::ScalarType scalar_type = c10::ScalarType::Undefined;
  size_t dim = 0;
};

// Promotes two scalar types, treating Undefined as "no constraint" so that a
// partially filled promotion state does not poison the result.
c10::ScalarType promoteTypesSkipUndefined(
    c10::ScalarType a,
    c10::ScalarType b) {
  if (a == c10::ScalarType::Undefined) {
    return b;
  }
  if (b == c10::ScalarType::Undefined) {
    return a;
  }
  return c10::promoteTypes(a, b);
}
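
// The two updateResultTypeState overloads below fold each operand into
// ATen's ResultTypeState, which tracks three promotion "buckets": dimResult
// for tensors with dim > 0, zeroResult for zero-dim tensors, and
// wrappedResult for plain scalars. at::native::result_type() later combines
// the buckets so that tensor operands take precedence over scalars of the
// same kind. Note that a floating-point scalar contributes the default
// dtype (typically Float) rather than its own dtype, matching eager mode.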
at::native::ResultTypeState updateResultTypeState(
    OperandType tensor,
    const at::native::ResultTypeState& in_state) {
  at::native::ResultTypeState new_state = in_state;
  c10::ScalarType current = tensor.scalar_type;

  if (tensor.dim > 0) {
    new_state.dimResult =
        promoteTypesSkipUndefined(in_state.dimResult, current);
  } else {
    new_state.zeroResult =
        promoteTypesSkipUndefined(in_state.zeroResult, current);
  }
  return new_state;
}

at::native::ResultTypeState updateResultTypeState(
    const c10::ScalarType scalar,
    const at::native::ResultTypeState& in_state) {
  at::native::ResultTypeState new_state = in_state;
  c10::ScalarType current = scalar;
  if (c10::isFloatingType(scalar)) {
    current = c10::typeMetaToScalarType(at::get_default_dtype());
  }
  new_state.wrappedResult =
      promoteTypesSkipUndefined(in_state.wrappedResult, current);
  return new_state;
}

// Computes a common dtype using type promotion
c10::ScalarType computeCommonDtype(const std::vector<OperandType>& operands) {
  at::native::ResultTypeState state = {};
  for (const auto& op : operands) {
    if (op.value_type == ValueType::Tensor) {
      state = updateResultTypeState(op, state);
    } else {
      state = updateResultTypeState(op.scalar_type, state);
    }
  }
  auto common_dtype = at::native::result_type(state);
  TORCH_INTERNAL_ASSERT(common_dtype != c10::ScalarType::Undefined);
  return common_dtype;
}
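// For reference, some outcomes of the promotion above (assuming the usual
// ATen rules and a Float default dtype):
//   Half tensor + Double scalar -> Half  (the tensor wins within the
//                                         floating-point kind)
//   Int tensor + Double scalar  -> Float (the scalar raises the kind, but
//                                         only to the default float dtype)
//   Int tensor + Long tensor    -> Long  (ordinary integer promotion)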

c10::ScalarType computeTypes(
    const TypePromotionConfig& config,
    const std::vector<OperandType>& operands) {
  auto common_dtype = c10::ScalarType::Undefined;

  // Check whether the operands already agree on a single dtype; a full
  // promotion pass is only needed when they do not.
  bool has_different_input_dtypes = false;
  for (auto& op : operands) {
    if (op.scalar_type != common_dtype) {
      if (common_dtype == c10::ScalarType::Undefined) {
        common_dtype = op.scalar_type;
      } else {
        has_different_input_dtypes = true;
      }
    }
  }

  // Computes a common dtype, if needed
  if (has_different_input_dtypes) {
    common_dtype = computeCommonDtype(operands);
  }

  // Promotes common dtype to the default float scalar type, if needed
  if (config.promote_integer_inputs_to_float &&
      c10::isIntegralType(common_dtype, /*includeBool=*/true)) {
    common_dtype = c10::get_default_dtype_as_scalartype();
  }
  return common_dtype;
}
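// For example, with two Int tensors and a config that sets
// promote_integer_inputs_to_float (used for ops that must produce a
// floating-point result, such as true division), the common dtype becomes
// the default float type, typically Float; without the flag it stays Int.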

OperandType getValueType(TypePtr type) {
  if (auto tensor_type = type->cast<TensorType>()) {
    TORCH_INTERNAL_ASSERT(
        tensor_type->scalarType().has_value(),
        "Missing Scalar Type information");
    // TODO: Type Inference does not propagate Shape Information
    return {
        ValueType::Tensor,
        tensor_type->scalarType().value(),
        tensor_type->dim().has_value() ? tensor_type->dim().value() : 1};
  } else if (auto scalar_type = tryScalarTypeFromJitType(*type)) {
    return {ValueType::Scalar, scalar_type.value()};
  } else {
    return {ValueType::None, c10::ScalarType::Undefined};
  }
}

OperandType getValueType(Val* type) {
  TORCH_INTERNAL_ASSERT(type->getDataType().has_value());

  if (type->isA<TensorView>()) {
    auto tensor_view = type->as<TensorView>();
    return {
        ValueType::Tensor,
        data_type_to_aten(tensor_view->getDataType().value()),
        tensor_view->getMaybeRFactorDomain().size()};
  } else if (type->getDataType().has_value()) {
    return {ValueType::Scalar, data_type_to_aten(type->getDataType().value())};
  } else {
    return {ValueType::None, c10::ScalarType::Undefined};
  }
}

} // namespace

c10::ScalarType computeTypes(
    const TypePromotionConfig& config,
    const std::vector<TypePtr>& operands) {
  std::vector<OperandType> vt_operands;
  vt_operands.reserve(operands.size());
  for (const auto& op : operands) {
    vt_operands.emplace_back(getValueType(op));
  }
  return computeTypes(config, vt_operands);
}

DataType computeTypes(
    const TypePromotionConfig& config,
    const std::vector<Val*>& operands) {
  std::vector<OperandType> vt_operands;
  vt_operands.reserve(operands.size());
  for (const auto& op : operands) {
    vt_operands.push_back(getValueType(op));
  }

  auto common_type = aten_to_data_type(computeTypes(config, vt_operands));

  // Cast FP16 / BFloat16 to Float: reduced-precision values are computed in
  // Float inside the fusion, and callers cast the result back to the
  // reduced-precision type where the output requires it.
  if (common_type == DataType::Half || common_type == DataType::BFloat16) {
    common_type = DataType::Float;
  }

  return common_type;
}

std::vector<Val*> promoteValues(
    const std::vector<Val*>& operands,
    DataType common_type) {
  std::vector<Val*> promoted_operands;
  promoted_operands.reserve(operands.size());
  for (auto op : operands) {
    promoted_operands.push_back(optionalCast(common_type, op));
  }

  TORCH_INTERNAL_ASSERT(operands.size() == promoted_operands.size());
  return promoted_operands;
}

std::vector<Val*> promoteValues(
    const TypePromotionConfig& config,
    const std::vector<Val*>& operands) {
  return promoteValues(operands, computeTypes(config, operands));
}
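// Typical usage, as a sketch (the real call sites are the arithmetic helpers
// that build fusion ops from promoted operands):
//   auto common_type = computeTypes(config, {lhs, rhs});
//   auto promoted = promoteValues({lhs, rhs}, common_type);
//   // ... build the op from promoted[0] and promoted[1], then cast the
//   // result back to the output dtype if the op requires it.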

Val* optionalCast(DataType dtype, Val* v) {
  TORCH_INTERNAL_ASSERT(v->getDataType().has_value());
  // Skip the cast when the value already has the desired dtype, or when it is
  // a scalar whose dtype falls in the same category (floating point, integral,
  // or complex) as the desired dtype; scalars keep their full-precision fusion
  // types. The exception is Bool, which is always cast to the desired type.
  const bool kSameDtype = v->getDataType().value() == dtype;
  const bool kIsScalarFloat =
      !v->isA<TensorView>() && isFloatingPointType(dtype);
  const bool kIsScalarInt = !v->isA<TensorView>() && isIntegralType(dtype);
  const bool kIsScalarComplex = !v->isA<TensorView>() && isComplexType(dtype);
  if (kSameDtype ||
      (kIsScalarFloat && isFloatingPointType(v->getDataType().value())) ||
      (kIsScalarInt && isIntegralType(v->getDataType().value())) ||
      (kIsScalarComplex && isComplexType(v->getDataType().value()))) {
    return v;
  } else {
    return castOp(dtype, v);
  }
}
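// A few concrete cases of the rule above:
//   optionalCast(Float, Double scalar) -> returned unchanged (same
//       floating-point category, so the scalar keeps its full precision)
//   optionalCast(Int, Bool scalar)     -> castOp(Int, v) (Bool is not an
//       integral type here, so it is always cast)
//   optionalCast(Float, Int scalar)    -> castOp(Float, v) (the category
//       changes, so a cast is inserted)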

Val* optionalCastStrict(DataType dtype, Val* v) {
  TORCH_INTERNAL_ASSERT(v->getDataType().has_value());
  const bool kSameDtype = v->getDataType().value() == dtype;
  return (kSameDtype) ? v : castOp(dtype, v);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch