1#pragma once
2
3#include <ATen/Context.h>
4#include <ATen/native/TypeProperties.h>
5#include <c10/core/ScalarType.h>
6#include <ir_interface_nodes.h>
7#include <torch/csrc/jit/ir/ir.h>
8
9namespace torch {
10namespace jit {
11namespace fuser {
12namespace cuda {
13
//!
//! TypePromotionConfig mirrors a subset of the type-promotion flags of
//! ATen's TensorIterator (see ATen/TensorIterator.h). For reference, the
//! relevant TensorIterator flags are:
//!
//! 1) check_all_same_dtype: checks that all inputs and defined outputs have
//!    the same dtype. Default = false. Not represented in this struct.
//!
//! 2) promote_inputs_to_common_dtype: casts the inputs to the common dtype.
//!    Default = true. Not configurable here: this struct has no field for it.
//!
//! 3) promote_integer_inputs_to_float: casts the common dtype to the default
//!    float scalar type if it is an integral type (including bool).
//!    Default = false. This is the only configurable field.
//!
struct TypePromotionConfig {
  //! When true, an integral (or bool) common dtype is promoted to the
  //! default floating-point scalar type.
  bool promote_integer_inputs_to_float = false;
};

//! Predefined configurations.
//!
//! These are constexpr so they are guaranteed to be constant-initialized and
//! usable in constant expressions. At namespace scope they keep internal
//! linkage (one copy per translation unit), just like `static const`, so
//! existing references to them are unaffected.
namespace TypePromotion {

//! Configuration named for comparison ops: no integer-to-float promotion.
//! Currently identical to default_op_config.
static constexpr TypePromotionConfig comparison_op_config{};

//! Default configuration: no integer-to-float promotion.
static constexpr TypePromotionConfig default_op_config{};

//! Configuration that promotes an integral (or bool) common dtype to the
//! default floating-point scalar type.
static constexpr TypePromotionConfig float_op_config{
    /* promote_integer_inputs_to_float */ true};

} // namespace TypePromotion
38
// Computes the common dtype of the given operands. Implements the behavior of
// the following TensorIterator-style flags:
// - promote_inputs_to_common_dtype (not configurable: TypePromotionConfig has
//   no field for it)
// - promote_integer_inputs_to_float (controlled by
//   config.promote_integer_inputs_to_float)
//
// Overload for JIT graph types: operands are TypePtrs and the result is an
// ATen c10::ScalarType.
c10::ScalarType computeTypes(
    const TypePromotionConfig& config,
    const std::vector<TypePtr>& operands);

// Overload for fuser IR values: operands are Val*s and the result is the
// fuser's DataType.
// NOTE(review): behavior for an empty operand list is not visible from this
// header — confirm against the definition.
DataType computeTypes(
    const TypePromotionConfig& config,
    const std::vector<Val*>& operands);
49
// Computes the common dtype for the given operands according to config and
// casts operands to that dtype where necessary.
// Automatically casts FP16/BF16 dtypes to Float.
// NOTE(review): the returned vector presumably holds one (possibly cast)
// value per operand, in input order — confirm against the definition.
std::vector<Val*> promoteValues(
    const TypePromotionConfig& config,
    const std::vector<Val*>& operands);

// Overload that takes the common dtype from the caller (common_type) instead
// of computing it from a TypePromotionConfig.
// NOTE(review): presumably each operand is cast to common_type only when
// necessary, as in the overload above — confirm against the definition.
std::vector<Val*> promoteValues(
    const std::vector<Val*>& operands,
    DataType common_type);
60
// Casts v to dtype if necessary and returns the resulting value.
// The cast is skipped when v's dtype already matches the dtype class of dtype.
// NOTE(review): "dtype class" presumably means a category such as
// floating-point / integral / boolean, and v itself is presumably returned
// when no cast is needed — confirm against the definition.
Val* optionalCast(DataType dtype, Val* v);
64
// Casts v to dtype if necessary and returns the resulting value.
// NOTE(review): unlike optionalCast (which skips the cast on a matching dtype
// class), this "Strict" variant presumably skips the cast only when v's dtype
// is exactly dtype — confirm against the definition.
Val* optionalCastStrict(DataType dtype, Val* v);
67
68} // namespace cuda
69} // namespace fuser
70} // namespace jit
71} // namespace torch
72