1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/tensorexpr/test_base.h> |
4 | |
5 | #include <c10/util/irange.h> |
6 | #include <test/cpp/tensorexpr/padded_buffer.h> |
7 | #include <test/cpp/tensorexpr/test_utils.h> |
8 | #include <torch/csrc/jit/tensorexpr/eval.h> |
9 | #include <torch/csrc/jit/tensorexpr/ir.h> |
10 | #include <torch/csrc/jit/tensorexpr/ir_printer.h> |
11 | #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
12 | #include <torch/csrc/jit/tensorexpr/ir_verifier.h> |
13 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
14 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
15 | |
16 | #include <cmath> |
17 | #include <sstream> |
18 | #include <stdexcept> |
19 | #include <string> |
20 | #include <vector> |
21 | |
22 | namespace torch { |
23 | namespace jit { |
24 | using namespace torch::jit::tensorexpr; |
25 | |
// Shorthand for ExprEval instantiated with SimpleIREvaluator. The tests in
// this file construct it from a single ExprHandle and read the result back
// with value<T>().
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
27 | |
28 | TEST(Expr, BasicValueTest) { |
29 | ExprHandle a = IntImm::make(2), b = IntImm::make(3); |
30 | ExprHandle c = Add::make(a, b); |
31 | SimpleIRExprEval eval(c); |
32 | ASSERT_EQ(eval.value<int>(), 5); |
33 | } |
34 | |
35 | TEST(Expr, BasicValueTest02) { |
36 | ExprHandle a(2.0f); |
37 | ExprHandle b(3.0f); |
38 | ExprHandle c(4.0f); |
39 | ExprHandle d(5.0f); |
40 | ExprHandle f = (a + b) - (c + d); |
41 | SimpleIRExprEval eval(f); |
42 | ASSERT_EQ(eval.value<float>(), -4.0f); |
43 | } |
44 | |
// Intended to exercise BufHandle's contiguity predicates (is_contiguous and
// is_channels_last_1d_contiguous) on buffers whose sizes and strides are
// symbolic expressions built from kLong variables.
//
// NOTE(review): as written this test asserts nothing. Inside `shape_gen_fn`
// the vector `strides_expr_vec` is default-constructed (empty) and never
// grown, so both loops over it execute zero times and the function always
// returns an empty list of strides. Every `j < shape_info.second.size()` loop
// below therefore has no iterations and no ASSERT_EQ is ever reached.
// Simply sizing the vector would expose further issues that must be checked
// against Buf's implementation before enabling the checks:
//   * `dims_expr_vec_conf` stores ndims - 1 dimension expressions per entry,
//     while the stride orders can index dimension ndims - 1;
//   * the stride order {1, 4, 3, 2, 0} is listed both in the channels-last
//     contiguous table and in the channels-last non-contiguous table;
//   * `cont_shape_conf` gives unit stride to dimension 0 - TODO confirm that
//     this is what is_contiguous() treats as contiguous.
TEST(Expr, IsChannelsLastContiguous) {
  // Pool of symbolic variables used as dimension sizes.
  std::vector<VarHandle> vars = {
      VarHandle("var1" , kLong),
      VarHandle("var2" , kLong),
      VarHandle("var3" , kLong),
      VarHandle("var4" , kLong),
      VarHandle("var5" , kLong)};

  // {
  //   key: ndims,
  //   value: [
  //     ...
  //     [dim_2, dim_1, ..., dim_n]
  //   ]
  // }
  // Each inner list is a stride order: its first element is the dimension
  // that gets stride 1, and each following dimension gets the product of the
  // previous dimension's size and stride (see shape_gen_fn below).
  using shapGenInfo = std::unordered_map<int, std::vector<std::vector<int>>>;

  // {
  //   size: [ExprHandle_1, ExprHandle_2, ..., ExprHandle_n],
  //   strides: [
  //     ...
  //     [ExprHandle_x, ExprHandle_y, ..., ExprHandle_z]
  //   ]
  // }
  using shapeInfo =
      std::pair<std::vector<ExprHandle>, std::vector<std::vector<ExprHandle>>>;

  // Ranks under test.
  std::vector<int> dims = {3, 4, 5};

  // Dimension-size expressions per rank.
  // NOTE(review): each entry takes the first ndims - 1 variables, not ndims.
  std::unordered_map<int, std::vector<ExprHandle>> dims_expr_vec_conf = {
      {3, std::vector<ExprHandle>(vars.begin(), vars.begin() + 2)},
      {4, std::vector<ExprHandle>(vars.begin(), vars.begin() + 3)},
      {5, std::vector<ExprHandle>(vars.begin(), vars.begin() + 4)},
  };

  // Stride orders expected to be channels-last contiguous.
  shapGenInfo channels_last_cont_shape_conf = {
      {3, {{1, 2, 0}}}, {4, {{1, 3, 2, 0}}}, {5, {{1, 4, 3, 2, 0}}}};
  // Stride orders expected NOT to be channels-last contiguous.
  shapGenInfo channels_last_non_cont_shape_conf = {
      {3, {{2, 1, 0}, {1, 0, 2}}},
      {4, {{3, 1, 2, 0}, {1, 2, 3, 0}, {1, 0, 2, 3}}},
      {5, {{4, 3, 2, 1, 0}, {1, 3, 2, 4, 0}, {1, 4, 3, 2, 0}}}};

  // Stride orders expected to satisfy is_contiguous().
  shapGenInfo cont_shape_conf = {
      {3, {{0, 1, 2}}}, {4, {{0, 1, 2, 3}}}, {5, {{0, 1, 2, 3, 4}}}};

  // Produces the dimension sizes for `ndims` together with one stride vector
  // per stride order found in `shape_gen_info` for that rank.
  auto shape_gen_fn = [dims_expr_vec_conf](
                          int ndims, shapGenInfo shape_gen_info) -> shapeInfo {
    auto dims_expr_vec = dims_expr_vec_conf.at(ndims);
    std::vector<std::vector<ExprHandle>> strides_expr_vec;
    // NOTE(review): strides_expr_vec is empty here, so this loop (and the one
    // further down) never runs; see the note at the top of the test.
    for (size_t i = 0; i < strides_expr_vec.size(); i++) {
      strides_expr_vec[i].resize(ndims);
    }

    // Returns the product of `a` and `b`; the operand order is swapped for odd
    // indicators (presumably so that both `dim * stride` and `stride * dim`
    // forms are covered - TODO confirm).
    auto stride_gen_fn = [](int indicator, ExprHandle a, ExprHandle b) {
      if (indicator % 2 == 0) {
        return a * b;
      } else {
        return b * a;
      }
    };

    auto stride_order_vec = shape_gen_info.at(ndims);
    for (size_t i = 0; i < strides_expr_vec.size(); i++) {
      auto stride_order = stride_order_vec[i];

      // The first dimension in the order gets unit stride.
      strides_expr_vec[i][stride_order[0]] = 1;
      for (size_t j = 1; j < stride_order.size(); j++) {
        auto cur_dim_idx = stride_order[j];
        auto adjacent_dim_idx = stride_order[j - 1];

        // stride[cur] = size[adjacent] * stride[adjacent]
        strides_expr_vec[i][cur_dim_idx] = stride_gen_fn(
            i,
            dims_expr_vec[adjacent_dim_idx],
            strides_expr_vec[i][adjacent_dim_idx]);
      }
    }

    return {dims_expr_vec, strides_expr_vec};
  };

  // Selects the channels-last predicate by rank: the 1d variant for rank 3,
  // MemoryFormat::ChannelsLast for rank 4 and ChannelsLast3d otherwise.
  auto check_channels_last_fn = [](int ndims, BufHandle buf_handle) -> bool {
    if (ndims == 3) {
      return buf_handle.is_channels_last_1d_contiguous();
    } else if (ndims == 4) {
      return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast);
    } else {
      return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast3d);
    }
  };

  // channels-last contiguous
  for (size_t i = 0; i < dims.size(); i++) {
    auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf);
    for (size_t j = 0; j < shape_info.second.size(); j++) {
      BufHandle buf_handle("a" , shape_info.first, shape_info.second[j], kFloat);
      ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), true);
    }
  }

  // channels-last non-contiguous
  for (size_t i = 0; i < dims.size(); i++) {
    auto shape_info = shape_gen_fn(dims[i], channels_last_non_cont_shape_conf);
    for (size_t j = 0; j < shape_info.second.size(); j++) {
      BufHandle buf_handle("a" , shape_info.first, shape_info.second[j], kFloat);
      ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), false);
    }
  }

  // contiguous
  for (size_t i = 0; i < dims.size(); i++) {
    auto shape_info = shape_gen_fn(dims[i], cont_shape_conf);
    for (size_t j = 0; j < shape_info.second.size(); j++) {
      BufHandle buf_handle("a" , shape_info.first, shape_info.second[j], kFloat);
      ASSERT_EQ(buf_handle.is_contiguous(), true);
    }
  }

  // non-contiguous: channels-last strides must not count as plain contiguous
  for (size_t i = 0; i < dims.size(); i++) {
    auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf);
    for (size_t j = 0; j < shape_info.second.size(); j++) {
      BufHandle buf_handle("a" , shape_info.first, shape_info.second[j], kFloat);
      ASSERT_EQ(buf_handle.is_contiguous(), false);
    }
  }
}
171 | |
172 | TEST(Expr, LetTest01) { |
173 | VarHandle x("x" , kFloat); |
174 | ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); |
175 | SimpleIRExprEval eval(body); |
176 | eval.bindVar(x, ExprHandle(3.f)); |
177 | ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4)); |
178 | } |
179 | |
180 | TEST(Expr, LetTest02) { |
181 | VarHandle x("x" , kFloat); |
182 | VarHandle y("y" , kFloat); |
183 | ExprHandle body = |
184 | ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y); |
185 | SimpleIRExprEval eval(body); |
186 | eval.bindVar(x, ExprHandle(3.f)); |
187 | eval.bindVar(y, ExprHandle(6.f)); |
188 | ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4 * 6)); |
189 | } |
190 | |
191 | TEST(Expr, LetStmtTest01) { |
192 | BufHandle a_buf("a" , {1}, kFloat); |
193 | BufHandle b_buf("b" , {1}, kFloat); |
194 | |
195 | ExprHandle load_a = a_buf.load(0); |
196 | VarHandle var = VarHandle("v" , kFloat); |
197 | StmtPtr let_store = Let::make(var, load_a); |
198 | StmtPtr store_b = b_buf.store({0}, var); |
199 | BlockPtr block = Block::make({let_store, store_b}); |
200 | |
201 | SimpleIREvaluator eval(block, {a_buf, b_buf}); |
202 | |
203 | PaddedBuffer<float> a_v(1); |
204 | PaddedBuffer<float> b_v(1); |
205 | PaddedBuffer<float> b_ref(1); |
206 | |
207 | a_v(0) = 23; |
208 | b_ref(0) = a_v(0); |
209 | eval(a_v, b_v); |
210 | |
211 | ExpectAllNear(b_v, b_ref, 1e-5); |
212 | } |
213 | |
214 | TEST(Expr, IntTest) { |
215 | VarHandle x("x" , kInt); |
216 | ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4)); |
217 | SimpleIRExprEval eval(body); |
218 | eval.bindVar(x, ExprHandle(3)); |
219 | ASSERT_EQ(eval.value<int>(), 2 + (3 * 3 + 4)); |
220 | } |
221 | |
222 | TEST(Expr, FloatTest) { |
223 | VarHandle x("x" , kFloat); |
224 | ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); |
225 | SimpleIRExprEval eval(body); |
226 | eval.bindVar(x, ExprHandle(3.f)); |
227 | ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4)); |
228 | } |
229 | |
230 | TEST(Expr, ByteTest) { |
231 | VarHandle x("x" , kByte); |
232 | ExprHandle body = ExprHandle((uint8_t)2) + |
233 | (x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4)); |
234 | SimpleIRExprEval eval(body); |
235 | eval.bindVar(x, ExprHandle((uint8_t)3)); |
236 | ASSERT_EQ(eval.value<uint8_t>(), 2 + (3 * 3 + 4)); |
237 | } |
238 | |
239 | TEST(Expr, CharTest) { |
240 | VarHandle x("x" , kChar); |
241 | ExprHandle body = ExprHandle((int8_t)2) + |
242 | (x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4)); |
243 | SimpleIRExprEval eval(body); |
244 | eval.bindVar(x, ExprHandle((int8_t)3)); |
245 | ASSERT_EQ(eval.value<int8_t>(), 2 + (3 * 3 + 4)); |
246 | } |
247 | |
248 | TEST(Expr, ShortTest) { |
249 | VarHandle x("x" , kShort); |
250 | ExprHandle body = ExprHandle((int16_t)2) + |
251 | (x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4)); |
252 | SimpleIRExprEval eval(body); |
253 | eval.bindVar(x, ExprHandle((int16_t)3)); |
254 | ASSERT_EQ(eval.value<int16_t>(), 2 + (3 * 3 + 4)); |
255 | } |
256 | |
257 | TEST(Expr, LongTest) { |
258 | VarHandle x("x" , kLong); |
259 | ExprHandle body = ExprHandle((int64_t)2) + |
260 | (x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4)); |
261 | SimpleIRExprEval eval(body); |
262 | eval.bindVar(x, ExprHandle((int64_t)3)); |
263 | ASSERT_EQ(eval.value<int64_t>(), 2 + (3 * 3 + 4)); |
264 | } |
265 | |
266 | TEST(Expr, HalfTest) { |
267 | VarHandle x("x" , kHalf); |
268 | ExprHandle body = ExprHandle((at::Half)2) + |
269 | (x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4)); |
270 | SimpleIRExprEval eval(body); |
271 | eval.bindVar(x, ExprHandle((at::Half)3)); |
272 | ASSERT_EQ(eval.value<at::Half>(), 2 + (3 * 3 + 4)); |
273 | } |
274 | |
275 | TEST(Expr, DoubleTest) { |
276 | VarHandle x("x" , kDouble); |
277 | ExprHandle body = ExprHandle((double)2) + |
278 | (x * ExprHandle((double)3) + ExprHandle((double)4)); |
279 | SimpleIRExprEval eval(body); |
280 | eval.bindVar(x, ExprHandle((double)3)); |
281 | ASSERT_EQ(eval.value<double>(), 2 + (3 * 3 + 4)); |
282 | } |
283 | |
284 | TEST(Expr, VectorAdd01) { |
285 | const int kVectorSize = 8; |
286 | const int kVectorCount = 128; |
287 | const int kTotalSize = kVectorSize * kVectorCount; |
288 | |
289 | BufHandle a_buf("A" , {kTotalSize}, kFloat); |
290 | BufHandle b_buf("B" , {kTotalSize}, kFloat); |
291 | BufHandle c_buf("C" , {kTotalSize}, kFloat); |
292 | |
293 | /* |
294 | Build the following: |
295 | for (const auto index : c10::irange(kVectorCount)) { |
296 | store(c_buf, ramp(index * 8, 1, 8), |
297 | load(a_buf, ramp(index * 8, 1, 8) + |
298 | load(b_buf, ramp(index * 8, 1, 8)))) |
299 | } |
300 | */ |
301 | VarHandle index = VarHandle("index" , kInt); |
302 | ExprHandle load_a = |
303 | a_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); |
304 | ExprHandle load_b = |
305 | b_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); |
306 | ExprHandle value = load_a + load_b; |
307 | StmtPtr store_c = |
308 | c_buf.store({Ramp::make(index * kVectorSize, 1, kVectorSize)}, value); |
309 | StmtPtr stmt = For::make(index, 0, kVectorCount, store_c); |
310 | |
311 | ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize)); |
312 | ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize)); |
313 | ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize)); |
314 | |
315 | PaddedBuffer<float> a_v(kTotalSize); |
316 | PaddedBuffer<float> b_v(kTotalSize); |
317 | PaddedBuffer<float> c_v(kTotalSize); |
318 | PaddedBuffer<float> c_ref(kTotalSize); |
319 | for (const auto i : c10::irange(kTotalSize)) { |
320 | a_v(i) = i * i; |
321 | b_v(i) = i * i * 4; |
322 | c_ref(i) = a_v(i) + b_v(i); |
323 | } |
324 | SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); |
325 | ir_eval(a_v, b_v, c_v); |
326 | ExpectAllNear(c_v, c_ref, 1e-5); |
327 | } |
328 | |
329 | TEST(Expr, CompareSelectEQ) { |
330 | constexpr int N = 1024; |
331 | BufHandle a("A" , {N}, kInt); |
332 | BufHandle b("B" , {N}, kInt); |
333 | BufHandle c("C" , {N}, kInt); |
334 | std::vector<int> a_buffer(N, 1); |
335 | std::vector<int> b_buffer(N, 1); |
336 | std::vector<int> c_buffer(N, 0); |
337 | std::vector<int> c_ref(N, 0); |
338 | |
339 | VarHandle i("i" , kInt); |
340 | auto memcpy_expr = For::make( |
341 | i, |
342 | 0, |
343 | N, |
344 | c.store( |
345 | {i}, |
346 | CompareSelect::make( |
347 | a.load(i), b.load(i), CompareSelectOperation::kEQ))); |
348 | |
349 | SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); |
350 | ir_eval(a_buffer, b_buffer, c_buffer); |
351 | |
352 | ASSERT_EQ(a_buffer.size(), N); |
353 | ASSERT_EQ(b_buffer.size(), N); |
354 | ASSERT_EQ(c_buffer.size(), N); |
355 | |
356 | assertAllEqual(a_buffer, 1); |
357 | assertAllEqual(b_buffer, 1); |
358 | assertAllEqual(c_buffer, 1); |
359 | } |
360 | |
361 | TEST(Expr, CompareSelectDtypes) { |
362 | // LHS and RHS expressions should have the same dtype, but this dtype could |
363 | // differ from the dtype of the return values (but dtypes of true and false |
364 | // return values should be the same). |
365 | // This test constructs a CompareSelect expression where the input dtype is |
366 | // different from the output dtype and verifies that it works correctly: |
367 | // result = ((int)lhs == (int)rhs) ? (float)retval1 : (float)retval2 |
368 | constexpr int N = 1024; |
369 | BufHandle a("A" , {N}, kInt); |
370 | BufHandle b("B" , {N}, kInt); |
371 | BufHandle c("C" , {N}, kFloat); |
372 | std::vector<int> a_buffer(N, 1); |
373 | std::vector<int> b_buffer(N, 1); |
374 | std::vector<float> c_buffer(N, 0.0f); |
375 | std::vector<float> c_ref(N, 3.14f); |
376 | |
377 | VarHandle i("i" , kInt); |
378 | // C[i] = (A[i] == B[i]) ? 3.14f : 2.78f |
379 | // A and B are int, C is float. |
380 | auto select_expr = For::make( |
381 | i, |
382 | 0, |
383 | N, |
384 | c.store( |
385 | {i}, |
386 | CompareSelect::make( |
387 | a.load(i), |
388 | b.load(i), |
389 | FloatImm::make(3.14f), |
390 | FloatImm::make(2.78f), |
391 | CompareSelectOperation::kEQ))); |
392 | |
393 | SimpleIREvaluator ir_eval(select_expr, {a, b, c}); |
394 | ir_eval(a_buffer, b_buffer, c_buffer); |
395 | |
396 | ASSERT_EQ(a_buffer.size(), N); |
397 | ASSERT_EQ(b_buffer.size(), N); |
398 | ASSERT_EQ(c_buffer.size(), N); |
399 | |
400 | assertAllEqual(a_buffer, 1); |
401 | assertAllEqual(b_buffer, 1); |
402 | ExpectAllNear(c_buffer, c_ref, 1e-7); |
403 | } |
404 | |
405 | TEST(Expr, IntrinsicsDtypes) { |
406 | constexpr int N = 256; |
407 | BufHandle a("A" , {N}, kDouble); |
408 | BufHandle b("B" , {N}, kDouble); |
409 | std::vector<double> a_buffer(N, -10.0); |
410 | std::vector<double> b_buffer(N, 0.0); |
411 | std::vector<double> b_ref(N, 10.0); |
412 | |
413 | VarHandle i("i" , kInt); |
414 | auto abs_expr = For::make(i, 0, N, b.store({i}, tensorexpr::abs(a.load(i)))); |
415 | |
416 | SimpleIREvaluator ir_eval(abs_expr, {a, b}); |
417 | ir_eval(a_buffer, b_buffer); |
418 | |
419 | ASSERT_EQ(a_buffer.size(), N); |
420 | ASSERT_EQ(b_buffer.size(), N); |
421 | |
422 | assertAllEqual(a_buffer, -10.0); |
423 | ExpectAllNear(b_buffer, b_ref, 1e-7); |
424 | } |
425 | |
426 | TEST(Expr, Substitute01) { |
427 | VarPtr x = alloc<Var>("x" , kFloat); |
428 | VarPtr y = alloc<Var>("y" , kFloat); |
429 | ExprPtr e = |
430 | alloc<Mul>(alloc<Sub>(x, alloc<FloatImm>(1.0f)), alloc<Add>(x, y)); |
431 | |
432 | VarPtr z = alloc<Var>("z" , kFloat); |
433 | ExprPtr e2 = Substitute(e, {{x, alloc<Add>(z, alloc<FloatImm>(5.0f))}}); |
434 | ExprPtr e2_ref = alloc<Mul>( |
435 | alloc<Sub>(alloc<Add>(z, alloc<FloatImm>(5.0f)), alloc<FloatImm>(1.0f)), |
436 | alloc<Add>(alloc<Add>(z, alloc<FloatImm>(5.0f)), y)); |
437 | std::ostringstream oss; |
438 | oss << *e2; |
439 | std::string e2_str = oss.str(); |
440 | |
441 | oss.str("" ); |
442 | oss << *e2_ref; |
443 | std::string e2_ref_str = oss.str(); |
444 | ASSERT_EQ(e2_str, e2_ref_str); |
445 | } |
446 | |
447 | TEST(Expr, Math01) { |
448 | ExprHandle v = sin(ExprHandle(1.0f)); |
449 | |
450 | std::ostringstream oss; |
451 | oss << v; |
452 | ASSERT_EQ(oss.str(), "sin(1.f)" ); |
453 | |
454 | SimpleIRExprEval eval(v); |
455 | float v_ref = std::sin(1.0f); |
456 | float res = eval.value<float>(); |
457 | ASSERT_NEAR(res, v_ref, 1e-6); |
458 | } |
459 | |
460 | TEST(Expr, UnaryMath01) { |
461 | struct TestConfig { |
462 | std::function<ExprHandle(const ExprHandle&)> func; |
463 | std::function<float(float)> ref_func; |
464 | }; |
465 | |
466 | std::vector<TestConfig> test_configs = { |
467 | {[](const ExprHandle& v) { return sin(v); }, |
468 | [](float v) { return std::sin(v); }}, |
469 | {[](const ExprHandle& v) { return sin(v); }, |
470 | [](float v) { return std::sin(v); }}, |
471 | {[](const ExprHandle& v) { return tan(v); }, |
472 | [](float v) { return std::tan(v); }}, |
473 | {[](const ExprHandle& v) { return asin(v); }, |
474 | [](float v) { return std::asin(v); }}, |
475 | {[](const ExprHandle& v) { return acos(v); }, |
476 | [](float v) { return std::acos(v); }}, |
477 | {[](const ExprHandle& v) { return atan(v); }, |
478 | [](float v) { return std::atan(v); }}, |
479 | {[](const ExprHandle& v) { return sinh(v); }, |
480 | [](float v) { return std::sinh(v); }}, |
481 | {[](const ExprHandle& v) { return cosh(v); }, |
482 | [](float v) { return std::cosh(v); }}, |
483 | {[](const ExprHandle& v) { return tanh(v); }, |
484 | [](float v) { return std::tanh(v); }}, |
485 | {[](const ExprHandle& v) { return exp(v); }, |
486 | [](float v) { return std::exp(v); }}, |
487 | {[](const ExprHandle& v) { return tensorexpr::abs(v); }, |
488 | [](float v) { return std::fabs(v); }}, |
489 | {[](const ExprHandle& v) { return log(v); }, |
490 | [](float v) { return std::log(v); }}, |
491 | {[](const ExprHandle& v) { return log2(v); }, |
492 | [](float v) { return std::log2(v); }}, |
493 | {[](const ExprHandle& v) { return log10(v); }, |
494 | [](float v) { return std::log10(v); }}, |
495 | {[](const ExprHandle& v) { return erf(v); }, |
496 | [](float v) { return std::erf(v); }}, |
497 | {[](const ExprHandle& v) { return sqrt(v); }, |
498 | [](float v) { return std::sqrt(v); }}, |
499 | {[](const ExprHandle& v) { return rsqrt(v); }, |
500 | [](float v) { return 1.0f / std::sqrt(v); }}, |
501 | {[](const ExprHandle& v) { return ceil(v); }, |
502 | [](float v) { return std::ceil(v); }}, |
503 | {[](const ExprHandle& v) { return floor(v); }, |
504 | [](float v) { return std::floor(v); }}, |
505 | {[](const ExprHandle& v) { return round(v); }, |
506 | [](float v) { return std::round(v); }}, |
507 | {[](const ExprHandle& v) { return trunc(v); }, |
508 | [](float v) { return std::trunc(v); }}, |
509 | }; |
510 | |
511 | for (const TestConfig& test_config : test_configs) { |
512 | const float input_v = 0.8765f; |
513 | ExprHandle v = test_config.func(ExprHandle(input_v)); |
514 | float v_ref = test_config.ref_func(input_v); |
515 | SimpleIRExprEval eval(v); |
516 | ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6); |
517 | } |
518 | |
519 | // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) |
520 | for (float input_v : {std::nan("1" ), 0., .5}) { |
521 | ExprHandle v = FloatImm::make(input_v); |
522 | SimpleIRExprEval eval(Intrinsics::make(kIsNan, v)); |
523 | ASSERT_NEAR(eval.value<int>(), std::isnan(input_v), 0); |
524 | } |
525 | } |
526 | |
527 | TEST(Expr, BinaryMath01) { |
528 | struct TestConfig { |
529 | std::function<ExprHandle(const ExprHandle&, const ExprHandle&)> func; |
530 | std::function<float(float, float)> ref_func; |
531 | }; |
532 | |
533 | std::vector<TestConfig> test_configs = { |
534 | {[](const ExprHandle& v1, const ExprHandle& v2) { return pow(v1, v2); }, |
535 | [](float v1, float v2) { return std::pow(v1, v2); }}, |
536 | {[](const ExprHandle& v1, const ExprHandle& v2) { return fmod(v1, v2); }, |
537 | [](float v1, float v2) { return std::fmod(v1, v2); }}, |
538 | }; |
539 | |
540 | for (const TestConfig& test_config : test_configs) { |
541 | const float v1 = 0.8765f; |
542 | float v2 = 1.2345f; |
543 | ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2)); |
544 | float v_ref = test_config.ref_func(v1, v2); |
545 | SimpleIRExprEval eval(v_expr); |
546 | ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6); |
547 | } |
548 | } |
549 | |
550 | TEST(Expr, LogicalOps01) { |
551 | ExprHandle a(23); |
552 | ExprHandle b(11); |
553 | ExprHandle c(0.72f); |
554 | ExprHandle d(0.69f); |
555 | ExprHandle f1 = (a > b) && (c > d); |
556 | ExprHandle f2 = (a > b) && (c < d); |
557 | ExprHandle f3 = (a < b) && (c > d); |
558 | ExprHandle f4 = (a < b) && (c < d); |
559 | ExprHandle f5 = (a < b) || (c > d); |
560 | ExprHandle f6 = (a < b) || (c < d); |
561 | ExprHandle f7 = (a > b) || (c < d); |
562 | ExprHandle f8 = (a > b) || (c > d); |
563 | |
564 | SimpleIRExprEval eval1(f1); |
565 | SimpleIRExprEval eval2(f2); |
566 | SimpleIRExprEval eval3(f3); |
567 | SimpleIRExprEval eval4(f4); |
568 | SimpleIRExprEval eval5(f5); |
569 | SimpleIRExprEval eval6(f6); |
570 | SimpleIRExprEval eval7(f7); |
571 | SimpleIRExprEval eval8(f8); |
572 | ASSERT_EQ(eval1.value<int>(), 1); |
573 | ASSERT_EQ(eval2.value<int>(), 0); |
574 | ASSERT_EQ(eval3.value<int>(), 0); |
575 | ASSERT_EQ(eval4.value<int>(), 0); |
576 | ASSERT_EQ(eval5.value<int>(), 1); |
577 | ASSERT_EQ(eval6.value<int>(), 0); |
578 | ASSERT_EQ(eval7.value<int>(), 1); |
579 | ASSERT_EQ(eval8.value<int>(), 1); |
580 | } |
581 | |
582 | TEST(Expr, LogicalOps02) { |
583 | ExprHandle a(23); |
584 | ExprHandle b(11); |
585 | ExprHandle c(0.72f); |
586 | ExprHandle d(0.72f); |
587 | |
588 | ExprHandle f1 = (a > b) || (c > d); |
589 | ExprHandle f2 = (a > b) && (c <= d); |
590 | ExprHandle f3 = (a > b) && (c > d); |
591 | ExprHandle ff1 = f1 && f2; |
592 | ExprHandle ff2 = f2 || f3; |
593 | |
594 | SimpleIRExprEval eval1(ff1); |
595 | SimpleIRExprEval eval2(ff2); |
596 | ASSERT_EQ(eval1.value<int>(), 1); |
597 | ASSERT_EQ(eval2.value<int>(), 1); |
598 | } |
599 | |
600 | TEST(Expr, LogicalOps03) { |
601 | ExprHandle a(23); |
602 | ExprHandle b(11); |
603 | ExprHandle c(0.72f); |
604 | ExprHandle d(0.69f); |
605 | |
606 | // Bool types |
607 | ExprHandle bool_f1 = (a > b) && BoolImm::make(true); |
608 | ExprHandle bool_f2 = (c <= d) || BoolImm::make(true); |
609 | |
610 | // Int types |
611 | ExprHandle int_f1 = (a > b) && IntImm::make(1); |
612 | ExprHandle int_f2 = (c <= d) || IntImm::make(1); |
613 | |
614 | // Short types |
615 | ExprHandle short_f1 = (a > b) && ShortImm::make(1); |
616 | ExprHandle short_f2 = (c <= d) || ShortImm::make(1); |
617 | |
618 | // Long types |
619 | ExprHandle long_f1 = (a > b) && LongImm::make(1); |
620 | ExprHandle long_f2 = (c <= d) || LongImm::make(1); |
621 | |
622 | // Char types |
623 | ExprHandle char_f1 = (a > b) && CharImm::make(1); |
624 | ExprHandle char_f2 = (c <= d) || CharImm::make(1); |
625 | |
626 | // Byte types |
627 | ExprHandle byte_f1 = (a > b) && ByteImm::make(1); |
628 | ExprHandle byte_f2 = (c <= d) || ByteImm::make(1); |
629 | |
630 | SimpleIRExprEval eval1(bool_f1); |
631 | SimpleIRExprEval eval2(bool_f2); |
632 | SimpleIRExprEval eval3(int_f1); |
633 | SimpleIRExprEval eval4(int_f2); |
634 | SimpleIRExprEval eval5(short_f1); |
635 | SimpleIRExprEval eval6(short_f2); |
636 | SimpleIRExprEval eval7(long_f1); |
637 | SimpleIRExprEval eval8(long_f2); |
638 | SimpleIRExprEval eval9(char_f1); |
639 | SimpleIRExprEval eval10(char_f2); |
640 | SimpleIRExprEval eval11(byte_f1); |
641 | SimpleIRExprEval eval12(byte_f2); |
642 | |
643 | ASSERT_EQ(eval1.value<bool>(), true); |
644 | ASSERT_EQ(eval2.value<bool>(), true); |
645 | ASSERT_EQ(eval3.value<int>(), 1); |
646 | ASSERT_EQ(eval4.value<int>(), 1); |
647 | ASSERT_EQ(eval5.value<int16_t>(), 1); |
648 | ASSERT_EQ(eval6.value<int16_t>(), 1); |
649 | ASSERT_EQ(eval7.value<int64_t>(), 1); |
650 | ASSERT_EQ(eval8.value<int64_t>(), 1); |
651 | ASSERT_EQ(eval9.value<int8_t>(), 1); |
652 | ASSERT_EQ(eval10.value<int8_t>(), 1); |
653 | ASSERT_EQ(eval11.value<uint8_t>(), 1); |
654 | ASSERT_EQ(eval12.value<uint8_t>(), 1); |
655 | } |
656 | |
657 | TEST(Expr, BitwiseOps) { |
658 | ExprHandle a(59); |
659 | ExprHandle b(11); |
660 | ExprHandle c(101); |
661 | ExprHandle d(2); |
662 | ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d; |
663 | |
664 | SimpleIRExprEval eval(f); |
665 | ASSERT_EQ(eval.value<int>(), 11); |
666 | } |
667 | |
668 | TEST(Expr, DynamicShapeAdd) { |
669 | auto testWithSize = [](int32_t size) { |
670 | VarHandle n("n" , kInt); |
671 | BufHandle a("a" , {n}, kFloat); |
672 | BufHandle b("b" , {n}, kFloat); |
673 | BufHandle c("c" , {n}, kFloat); |
674 | VarHandle i("i" , kInt); |
675 | StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); |
676 | std::vector<float> aData(size, 1.0f); |
677 | std::vector<float> bData(size, 2.0f); |
678 | std::vector<float> cData(size, 0.0f); |
679 | SimpleIREvaluator(s, {a, b, c, n})(aData, bData, cData, size); |
680 | ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7); |
681 | }; |
682 | testWithSize(1); |
683 | testWithSize(16); |
684 | testWithSize(37); |
685 | } |
686 | |
687 | TEST(Expr, OutOfBounds) { |
688 | ExprHandle N(10); |
689 | ExprHandle start(0); |
690 | ExprHandle stop(15); |
691 | VarHandle i("i" , kInt); |
692 | |
693 | BufHandle X("X" , {N}, kInt); |
694 | |
695 | auto body = Store::make(X, {i}, i); |
696 | auto stmt = For::make(i, start, stop, body); |
697 | |
698 | PaddedBuffer<int> data(20); |
699 | |
700 | EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); |
701 | } |
702 | |
703 | TEST(Expr, OutOfBounds2d) { |
704 | std::vector<std::pair<int, int>> size_options = {{10, 15}, {15, 10}}; |
705 | for (auto sizes : size_options) { |
706 | ExprHandle N(sizes.first); |
707 | ExprHandle M(sizes.second); |
708 | ExprHandle start(0); |
709 | ExprHandle stopInner(15); |
710 | ExprHandle stopOuter(15); |
711 | VarHandle i("i" , kInt); |
712 | VarHandle j("j" , kInt); |
713 | |
714 | BufHandle X("X" , {N, M}, kInt); |
715 | |
716 | auto body = Store::make(X, {i, j}, i); |
717 | auto inner = For::make(j, start, stopInner, body); |
718 | auto stmt = For::make(i, start, stopOuter, inner); |
719 | |
720 | PaddedBuffer<int> data(400); |
721 | |
722 | EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); |
723 | } |
724 | } |
725 | |
726 | TEST(Expr, OutOfBounds2dFlattenedIndex) { |
727 | ExprHandle buf_size(149); |
728 | ExprHandle start(0); |
729 | ExprHandle stopInner(15); |
730 | ExprHandle stopOuter(10); |
731 | VarHandle i("i" , kInt); |
732 | VarHandle j("j" , kInt); |
733 | |
734 | BufHandle X("X" , {buf_size}, kInt); |
735 | |
736 | auto idx = Add::make(Mul::make(i, stopInner), j); |
737 | auto body = Store::make(X, {idx}, i); |
738 | auto inner = For::make(j, start, stopInner, body); |
739 | auto stmt = For::make(i, start, stopOuter, inner); |
740 | |
741 | PaddedBuffer<int> data(400); |
742 | |
743 | EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); |
744 | } |
745 | |
746 | void testCond01() { |
747 | const int N = 16; |
748 | PaddedBuffer<float> a_v(N); |
749 | BufHandle a_buf("a" , {N}, kFloat); |
750 | VarHandle index = VarHandle("index" , kInt); |
751 | StmtPtr assign_x2 = a_buf.store({index}, cast<float>(index) * 2); |
752 | StmtPtr assign_x3 = a_buf.store({index}, cast<float>(index) * 3); |
753 | ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ); |
754 | StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3); |
755 | StmtPtr for_stmt = For::make(index, 0, N, assign); |
756 | SimpleIREvaluator(for_stmt, {a_buf})(a_v); |
757 | |
758 | PaddedBuffer<float> a_ref(N); |
759 | for (const auto i : c10::irange(N)) { |
760 | if (i % 2 == 0) { |
761 | a_ref(i) = i * 2; |
762 | } else { |
763 | a_ref(i) = i * 3; |
764 | } |
765 | } |
766 | ExpectAllNear(a_v, a_ref, 1e-5); |
767 | } |
768 | |
769 | void testIfThenElse01() { |
770 | ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f)); |
771 | |
772 | std::ostringstream oss; |
773 | oss << v; |
774 | ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)" ); |
775 | |
776 | SimpleIRExprEval eval(v); |
777 | ASSERT_EQ(eval.value<float>(), 1.0f); |
778 | } |
779 | |
780 | void testIfThenElse02() { |
781 | ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f)); |
782 | |
783 | std::ostringstream oss; |
784 | oss << v; |
785 | ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)" ); |
786 | |
787 | SimpleIRExprEval eval(v); |
788 | ASSERT_EQ(eval.value<float>(), 2.0f); |
789 | } |
790 | |
791 | void testIfThenElse03() { |
792 | ExprHandle v = |
793 | ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f)); |
794 | |
795 | std::ostringstream oss; |
796 | oss << v; |
797 | ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)" ); |
798 | |
799 | SimpleIRExprEval eval(v); |
800 | ASSERT_EQ(eval.value<float>(), 2.0f); |
801 | } |
802 | |
803 | void testStmtClone() { |
804 | const int N = 16; |
805 | |
806 | BufHandle a_buf("a" , {N}, kInt); |
807 | VarHandle index = VarHandle("index" , kInt); |
808 | StmtPtr body = a_buf.store({index}, 5); |
809 | StmtPtr loop = For::make(index, 0, N, body); |
810 | |
811 | StmtPtr cloned_loop = Stmt::clone(loop); |
812 | std::vector<int> orig_loop_results(N); |
813 | std::vector<int> cloned_loop_results(N); |
814 | SimpleIREvaluator(loop, {a_buf})(orig_loop_results); |
815 | SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results); |
816 | |
817 | assertAllEqual(orig_loop_results, 5); |
818 | assertAllEqual(cloned_loop_results, 5); |
819 | |
820 | // Let's add another assign to the body in the cloned loop and verify that the |
821 | // original statement hasn't changed while the cloned one has. |
822 | StmtPtr body_addition = a_buf.store({index}, 33); |
823 | BlockPtr cloned_body = static_to<Block>(static_to<For>(cloned_loop)->body()); |
824 | cloned_body->append_stmt(body_addition); |
825 | |
826 | std::vector<int> orig_loop_results_after_mutation(N); |
827 | std::vector<int> cloned_loop_results_after_mutation(N); |
828 | SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation); |
829 | SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation); |
830 | |
831 | assertAllEqual(orig_loop_results_after_mutation, 5); |
832 | assertAllEqual(cloned_loop_results_after_mutation, 33); |
833 | } |
834 | |
835 | } // namespace jit |
836 | } // namespace torch |
837 | |