1 | #include <gtest/gtest.h> |
2 | |
3 | #include "torch/csrc/jit/tensorexpr/eval.h" |
4 | #include "torch/csrc/jit/tensorexpr/ir.h" |
5 | #include "torch/csrc/jit/tensorexpr/tensor.h" |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | using namespace torch::jit::tensorexpr; |
10 | |
11 | TEST(Type, Test01) { |
12 | { |
13 | Dtype dt1 = kInt; |
14 | ASSERT_EQ(dt1, kInt); |
15 | } |
16 | { |
17 | Dtype dt2_a(kInt, 8); |
18 | Dtype dt2_b(kInt, 4); |
19 | Dtype dt2_c(ScalarType::Int, 8); |
20 | ASSERT_EQ(dt2_a, dt2_c); |
21 | ASSERT_NE(dt2_a, dt2_b); |
22 | } |
23 | { |
24 | ASSERT_EQ(kInt, ToDtype<int>()); |
25 | ASSERT_EQ(kFloat, ToDtype<float>()); |
26 | ASSERT_EQ(kByte, ToDtype<uint8_t>()); |
27 | ASSERT_EQ(kChar, ToDtype<int8_t>()); |
28 | ASSERT_EQ(kShort, ToDtype<int16_t>()); |
29 | ASSERT_EQ(kLong, ToDtype<int64_t>()); |
30 | ASSERT_EQ(kHalf, ToDtype<at::Half>()); |
31 | ASSERT_EQ(kDouble, ToDtype<double>()); |
32 | ASSERT_EQ(kBool, ToDtype<bool>()); |
33 | } |
34 | { |
35 | Dtype int32x8(kInt, 8); |
36 | Dtype float32x8(kFloat, 8); |
37 | ASSERT_NE(int32x8, float32x8); |
38 | ASSERT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8)); |
39 | ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8)); |
40 | ASSERT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8)); |
41 | ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8)); |
42 | } |
43 | } |
44 | |
45 | TEST(Type, BitCasting) { |
46 | { |
47 | VarHandle x("x" , kFloat); |
48 | ExprHandle y = bitcast<int32_t>(x); |
49 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
50 | ASSERT_EQ(y.dtype(), kInt); |
51 | } |
52 | { |
53 | VarHandle x("x" , kInt); |
54 | ExprHandle y = bitcast<float>(x); |
55 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
56 | ASSERT_EQ(y.dtype(), kFloat); |
57 | } |
58 | { |
59 | VarHandle x("x" , kShort); |
60 | ExprHandle y = bitcast<at::Half>(x); |
61 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
62 | ASSERT_EQ(y.dtype(), kHalf); |
63 | } |
64 | { |
65 | VarHandle x("x" , kHalf); |
66 | ExprHandle y = bitcast<int16_t>(x); |
67 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
68 | ASSERT_EQ(y.dtype(), kShort); |
69 | } |
70 | |
71 | constexpr int32_t ref32 = 1337; |
72 | constexpr int64_t ref64 = 1337; |
73 | constexpr float reff32 = 1337.0f; |
74 | constexpr double reff64 = 1337.0f; |
75 | using SimpleIRExprEval = ExprEval<SimpleIREvaluator>; |
76 | // this is broken |
77 | /*{ |
78 | constexpr int16_t ref16 = 1337; |
79 | at::Half k_; |
80 | at::Half* k = &k_; |
81 | *reinterpret_cast<int16_t*>(k) = ref16; |
82 | auto a = HalfImm::make(*k); |
83 | auto b = BitCast::make(kShort, a); |
84 | SimpleIRExprEval cg(b); |
85 | ASSERT_EQ(cg.value<int16_t>(), ref16); |
86 | }*/ |
87 | |
88 | { |
89 | float k = raw_bitcast<float>(ref32); |
90 | auto a = FloatImm::make(k); |
91 | auto b = BitCast::make(kInt, a); |
92 | SimpleIRExprEval cg(b); |
93 | ASSERT_EQ(cg.value<int32_t>(), ref32); |
94 | } |
95 | |
96 | { |
97 | double k = raw_bitcast<double>(ref64); |
98 | auto a = DoubleImm::make(k); |
99 | auto b = BitCast::make(kLong, a); |
100 | SimpleIRExprEval cg(b); |
101 | ASSERT_EQ(cg.value<int64_t>(), ref64); |
102 | } |
103 | |
104 | { |
105 | int64_t k = raw_bitcast<int64_t>(reff64); |
106 | auto a = LongImm::make(k); |
107 | auto b = BitCast::make(kDouble, a); |
108 | SimpleIRExprEval cg(b); |
109 | ASSERT_EQ(cg.value<double>(), reff64); |
110 | } |
111 | |
112 | { |
113 | int32_t k = raw_bitcast<int32_t>(reff32); |
114 | auto a = IntImm::make(k); |
115 | auto b = BitCast::make(kFloat, a); |
116 | SimpleIRExprEval cg(b); |
117 | ASSERT_EQ(cg.value<float>(), reff32); |
118 | } |
119 | |
120 | // This segfaults :( |
121 | /*{ |
122 | VarHandle x("x", kDouble); |
123 | ASSERT_ANY_THROW(ExprHandle y = bitcast<int32_t>(x)); |
124 | } |
125 | { |
126 | VarHandle x("x", kFloat); |
127 | ASSERT_ANY_THROW(ExprHandle y = bitcast<int64_t>(x)); |
128 | } |
129 | { |
130 | VarHandle x("x", kLong); |
131 | ASSERT_ANY_THROW(ExprHandle y = bitcast<float>(x)); |
132 | } |
133 | { |
134 | VarHandle x("x", kShort); |
135 | ASSERT_ANY_THROW(ExprHandle y = bitcast<float>(x)); |
136 | } |
137 | { |
138 | VarHandle x("x", kInt); |
139 | ASSERT_ANY_THROW(ExprHandle y = bitcast<at::Half>(x)); |
140 | }*/ |
141 | } |
142 | |
143 | TEST(Type, Propagation) { |
144 | // Same types: |
145 | { |
146 | VarHandle x("x" , kFloat); |
147 | VarHandle y("y" , kFloat); |
148 | ExprHandle body = FloatImm::make(2.f) + |
149 | (x * FloatImm::make(3.f) + FloatImm::make(4.f) * y); |
150 | ASSERT_EQ(body.dtype(), kFloat); |
151 | } |
152 | // Int to bigger int: |
153 | { |
154 | VarHandle x("x" , kShort); |
155 | VarHandle y("y" , kLong); |
156 | ExprHandle body = |
157 | ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y); |
158 | ASSERT_EQ(body.dtype(), kLong); |
159 | } |
160 | // Float to bigger float: |
161 | { |
162 | VarHandle x("x" , kHalf); |
163 | VarHandle y("y" , kDouble); |
164 | ExprHandle body = |
165 | HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y); |
166 | ASSERT_EQ(body.dtype(), kDouble); |
167 | } |
168 | // Int to Float: |
169 | { |
170 | VarHandle x("x" , kFloat); |
171 | VarHandle y("y" , kInt); |
172 | ExprHandle body = |
173 | IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y); |
174 | ASSERT_EQ(body.dtype(), kFloat); |
175 | } |
176 | // Smaller float, bigger Int: |
177 | { |
178 | VarHandle x("x" , kHalf); |
179 | VarHandle y("y" , kLong); |
180 | ExprHandle body = |
181 | HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y); |
182 | ASSERT_EQ(body.dtype(), kHalf); |
183 | } |
184 | // Bigger float, smaller Int: |
185 | { |
186 | VarHandle x("x" , kChar); |
187 | VarHandle y("y" , kDouble); |
188 | ExprHandle body = |
189 | CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); |
190 | ASSERT_EQ(body.dtype(), kDouble); |
191 | } |
192 | // Sign change char/byte upgrades to short: |
193 | { |
194 | VarHandle x("x" , kChar); |
195 | VarHandle y("y" , kByte); |
196 | ExprHandle body = |
197 | CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); |
198 | ASSERT_EQ(body.dtype(), kShort); |
199 | } |
200 | } |
201 | } // namespace jit |
202 | } // namespace torch |
203 | |