1 | #pragma once |
2 | |
#include <ATen/Context.h>
#include <ATen/native/TypeProperties.h>
#include <c10/core/ScalarType.h>
#include <ir_interface_nodes.h>
#include <torch/csrc/jit/ir/ir.h>

#include <vector>
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | namespace fuser { |
12 | namespace cuda { |
13 | |
//!
//! Options that control how operand dtypes are promoted.
//!
//! The flag names are derived from ATen/TensorIterator.h:
//!
//!   1) check_all_same_dtype_: all inputs and all defined outputs must have
//!      the same dtype. Default = False
//!
//!   2) promote_inputs_to_common_dtype: the inputs are cast to the common
//!      dtype. Default = True
//!
//!   3) promote_integer_inputs_to_float: if the common dtype is an integral
//!      type (bool included), it is cast to the default float scalar type.
//!
//! Only flag 3 is stored as a member of this struct.
//!
struct TypePromotionConfig {
  //! Flag 3 from the list above; off unless explicitly enabled.
  bool promote_integer_inputs_to_float{false};
};
29 | |
30 | namespace TypePromotion { |
31 | |
32 | static const TypePromotionConfig comparison_op_config; |
33 | static const TypePromotionConfig default_op_config; |
34 | static const TypePromotionConfig float_op_config{ |
35 | /* promote_integer_inputs_to_float */ true}; |
36 | |
37 | } // namespace TypePromotion |
38 | |
39 | // Implements the the behavior of the following flags: |
40 | // - promote_inputs_to_common_dtype |
41 | // - promote_integer_inputs_to_float |
42 | c10::ScalarType computeTypes( |
43 | const TypePromotionConfig& config, |
44 | const std::vector<TypePtr>& operands); |
45 | |
46 | DataType computeTypes( |
47 | const TypePromotionConfig& config, |
48 | const std::vector<Val*>& operands); |
49 | |
50 | // Computes the common dtype for the given operands |
51 | // Casts operands to common dtype if necessary |
52 | // Automatically cast FP16/BF16 dtype to Float |
53 | std::vector<Val*> promoteValues( |
54 | const TypePromotionConfig& config, |
55 | const std::vector<Val*>& operands); |
56 | |
57 | std::vector<Val*> promoteValues( |
58 | const std::vector<Val*>& operands, |
59 | DataType common_type); |
60 | |
61 | // Casts value to common dtype if necessary |
62 | // Avoid cast if value's dtype matches its dtype class |
63 | Val* optionalCast(DataType dtype, Val* v); |
64 | |
65 | // Casts value to common dtype if necessary |
66 | Val* optionalCastStrict(DataType dtype, Val* v); |
67 | |
68 | } // namespace cuda |
69 | } // namespace fuser |
70 | } // namespace jit |
71 | } // namespace torch |
72 | |