1#include <gtest/gtest.h>
2
3#include "torch/csrc/jit/tensorexpr/eval.h"
4#include "torch/csrc/jit/tensorexpr/ir.h"
5#include "torch/csrc/jit/tensorexpr/tensor.h"
6
7namespace torch {
8namespace jit {
9using namespace torch::jit::tensorexpr;
10
11TEST(Type, Test01) {
12 {
13 Dtype dt1 = kInt;
14 ASSERT_EQ(dt1, kInt);
15 }
16 {
17 Dtype dt2_a(kInt, 8);
18 Dtype dt2_b(kInt, 4);
19 Dtype dt2_c(ScalarType::Int, 8);
20 ASSERT_EQ(dt2_a, dt2_c);
21 ASSERT_NE(dt2_a, dt2_b);
22 }
23 {
24 ASSERT_EQ(kInt, ToDtype<int>());
25 ASSERT_EQ(kFloat, ToDtype<float>());
26 ASSERT_EQ(kByte, ToDtype<uint8_t>());
27 ASSERT_EQ(kChar, ToDtype<int8_t>());
28 ASSERT_EQ(kShort, ToDtype<int16_t>());
29 ASSERT_EQ(kLong, ToDtype<int64_t>());
30 ASSERT_EQ(kHalf, ToDtype<at::Half>());
31 ASSERT_EQ(kDouble, ToDtype<double>());
32 ASSERT_EQ(kBool, ToDtype<bool>());
33 }
34 {
35 Dtype int32x8(kInt, 8);
36 Dtype float32x8(kFloat, 8);
37 ASSERT_NE(int32x8, float32x8);
38 ASSERT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8));
39 ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8));
40 ASSERT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8));
41 ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8));
42 }
43}
44
45TEST(Type, BitCasting) {
46 {
47 VarHandle x("x", kFloat);
48 ExprHandle y = bitcast<int32_t>(x);
49 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
50 ASSERT_EQ(y.dtype(), kInt);
51 }
52 {
53 VarHandle x("x", kInt);
54 ExprHandle y = bitcast<float>(x);
55 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
56 ASSERT_EQ(y.dtype(), kFloat);
57 }
58 {
59 VarHandle x("x", kShort);
60 ExprHandle y = bitcast<at::Half>(x);
61 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
62 ASSERT_EQ(y.dtype(), kHalf);
63 }
64 {
65 VarHandle x("x", kHalf);
66 ExprHandle y = bitcast<int16_t>(x);
67 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
68 ASSERT_EQ(y.dtype(), kShort);
69 }
70
71 constexpr int32_t ref32 = 1337;
72 constexpr int64_t ref64 = 1337;
73 constexpr float reff32 = 1337.0f;
74 constexpr double reff64 = 1337.0f;
75 using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
76 // this is broken
77 /*{
78 constexpr int16_t ref16 = 1337;
79 at::Half k_;
80 at::Half* k = &k_;
81 *reinterpret_cast<int16_t*>(k) = ref16;
82 auto a = HalfImm::make(*k);
83 auto b = BitCast::make(kShort, a);
84 SimpleIRExprEval cg(b);
85 ASSERT_EQ(cg.value<int16_t>(), ref16);
86 }*/
87
88 {
89 float k = raw_bitcast<float>(ref32);
90 auto a = FloatImm::make(k);
91 auto b = BitCast::make(kInt, a);
92 SimpleIRExprEval cg(b);
93 ASSERT_EQ(cg.value<int32_t>(), ref32);
94 }
95
96 {
97 double k = raw_bitcast<double>(ref64);
98 auto a = DoubleImm::make(k);
99 auto b = BitCast::make(kLong, a);
100 SimpleIRExprEval cg(b);
101 ASSERT_EQ(cg.value<int64_t>(), ref64);
102 }
103
104 {
105 int64_t k = raw_bitcast<int64_t>(reff64);
106 auto a = LongImm::make(k);
107 auto b = BitCast::make(kDouble, a);
108 SimpleIRExprEval cg(b);
109 ASSERT_EQ(cg.value<double>(), reff64);
110 }
111
112 {
113 int32_t k = raw_bitcast<int32_t>(reff32);
114 auto a = IntImm::make(k);
115 auto b = BitCast::make(kFloat, a);
116 SimpleIRExprEval cg(b);
117 ASSERT_EQ(cg.value<float>(), reff32);
118 }
119
120 // This segfaults :(
121 /*{
122 VarHandle x("x", kDouble);
123 ASSERT_ANY_THROW(ExprHandle y = bitcast<int32_t>(x));
124 }
125 {
126 VarHandle x("x", kFloat);
127 ASSERT_ANY_THROW(ExprHandle y = bitcast<int64_t>(x));
128 }
129 {
130 VarHandle x("x", kLong);
131 ASSERT_ANY_THROW(ExprHandle y = bitcast<float>(x));
132 }
133 {
134 VarHandle x("x", kShort);
135 ASSERT_ANY_THROW(ExprHandle y = bitcast<float>(x));
136 }
137 {
138 VarHandle x("x", kInt);
139 ASSERT_ANY_THROW(ExprHandle y = bitcast<at::Half>(x));
140 }*/
141}
142
143TEST(Type, Propagation) {
144 // Same types:
145 {
146 VarHandle x("x", kFloat);
147 VarHandle y("y", kFloat);
148 ExprHandle body = FloatImm::make(2.f) +
149 (x * FloatImm::make(3.f) + FloatImm::make(4.f) * y);
150 ASSERT_EQ(body.dtype(), kFloat);
151 }
152 // Int to bigger int:
153 {
154 VarHandle x("x", kShort);
155 VarHandle y("y", kLong);
156 ExprHandle body =
157 ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y);
158 ASSERT_EQ(body.dtype(), kLong);
159 }
160 // Float to bigger float:
161 {
162 VarHandle x("x", kHalf);
163 VarHandle y("y", kDouble);
164 ExprHandle body =
165 HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y);
166 ASSERT_EQ(body.dtype(), kDouble);
167 }
168 // Int to Float:
169 {
170 VarHandle x("x", kFloat);
171 VarHandle y("y", kInt);
172 ExprHandle body =
173 IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y);
174 ASSERT_EQ(body.dtype(), kFloat);
175 }
176 // Smaller float, bigger Int:
177 {
178 VarHandle x("x", kHalf);
179 VarHandle y("y", kLong);
180 ExprHandle body =
181 HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y);
182 ASSERT_EQ(body.dtype(), kHalf);
183 }
184 // Bigger float, smaller Int:
185 {
186 VarHandle x("x", kChar);
187 VarHandle y("y", kDouble);
188 ExprHandle body =
189 CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y);
190 ASSERT_EQ(body.dtype(), kDouble);
191 }
192 // Sign change char/byte upgrades to short:
193 {
194 VarHandle x("x", kChar);
195 VarHandle y("y", kByte);
196 ExprHandle body =
197 CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y);
198 ASSERT_EQ(body.dtype(), kShort);
199 }
200}
201} // namespace jit
202} // namespace torch
203