1#include <gtest/gtest.h>
2
3#include <test/cpp/tensorexpr/test_base.h>
4
5#include <c10/util/irange.h>
6#include <test/cpp/tensorexpr/padded_buffer.h>
7#include <test/cpp/tensorexpr/test_utils.h>
8#include <torch/csrc/jit/tensorexpr/eval.h>
9#include <torch/csrc/jit/tensorexpr/ir.h>
10#include <torch/csrc/jit/tensorexpr/ir_printer.h>
11#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
12#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
13#include <torch/csrc/jit/tensorexpr/loopnest.h>
14#include <torch/csrc/jit/tensorexpr/tensor.h>
15
#include <cmath>
#include <functional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
21
22namespace torch {
23namespace jit {
24using namespace torch::jit::tensorexpr;
25
26using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
27
28TEST(Expr, BasicValueTest) {
29 ExprHandle a = IntImm::make(2), b = IntImm::make(3);
30 ExprHandle c = Add::make(a, b);
31 SimpleIRExprEval eval(c);
32 ASSERT_EQ(eval.value<int>(), 5);
33}
34
35TEST(Expr, BasicValueTest02) {
36 ExprHandle a(2.0f);
37 ExprHandle b(3.0f);
38 ExprHandle c(4.0f);
39 ExprHandle d(5.0f);
40 ExprHandle f = (a + b) - (c + d);
41 SimpleIRExprEval eval(f);
42 ASSERT_EQ(eval.value<float>(), -4.0f);
43}
44
45TEST(Expr, IsChannelsLastContiguous) {
46 std::vector<VarHandle> vars = {
47 VarHandle("var1", kLong),
48 VarHandle("var2", kLong),
49 VarHandle("var3", kLong),
50 VarHandle("var4", kLong),
51 VarHandle("var5", kLong)};
52
53 // {
54 // key: ndims,
55 // value: [
56 // ...
57 // [dim_2, dim_1, ..., dim_n]
58 // ]
59 // }
60 using shapGenInfo = std::unordered_map<int, std::vector<std::vector<int>>>;
61
62 // {
63 // size: [ExprHandle_1, ExprHandle_2, ..., ExprHandle_n],
64 // strides: [
65 // ...
66 // [ExprHandle_x, ExprHandle_y, ..., ExprHandle_z]
67 // ]
68 // }
69 using shapeInfo =
70 std::pair<std::vector<ExprHandle>, std::vector<std::vector<ExprHandle>>>;
71
72 std::vector<int> dims = {3, 4, 5};
73
74 std::unordered_map<int, std::vector<ExprHandle>> dims_expr_vec_conf = {
75 {3, std::vector<ExprHandle>(vars.begin(), vars.begin() + 2)},
76 {4, std::vector<ExprHandle>(vars.begin(), vars.begin() + 3)},
77 {5, std::vector<ExprHandle>(vars.begin(), vars.begin() + 4)},
78 };
79
80 shapGenInfo channels_last_cont_shape_conf = {
81 {3, {{1, 2, 0}}}, {4, {{1, 3, 2, 0}}}, {5, {{1, 4, 3, 2, 0}}}};
82 shapGenInfo channels_last_non_cont_shape_conf = {
83 {3, {{2, 1, 0}, {1, 0, 2}}},
84 {4, {{3, 1, 2, 0}, {1, 2, 3, 0}, {1, 0, 2, 3}}},
85 {5, {{4, 3, 2, 1, 0}, {1, 3, 2, 4, 0}, {1, 4, 3, 2, 0}}}};
86
87 shapGenInfo cont_shape_conf = {
88 {3, {{0, 1, 2}}}, {4, {{0, 1, 2, 3}}}, {5, {{0, 1, 2, 3, 4}}}};
89
90 auto shape_gen_fn = [dims_expr_vec_conf](
91 int ndims, shapGenInfo shape_gen_info) -> shapeInfo {
92 auto dims_expr_vec = dims_expr_vec_conf.at(ndims);
93 std::vector<std::vector<ExprHandle>> strides_expr_vec;
94 for (size_t i = 0; i < strides_expr_vec.size(); i++) {
95 strides_expr_vec[i].resize(ndims);
96 }
97
98 auto stride_gen_fn = [](int indicator, ExprHandle a, ExprHandle b) {
99 if (indicator % 2 == 0) {
100 return a * b;
101 } else {
102 return b * a;
103 }
104 };
105
106 auto stride_order_vec = shape_gen_info.at(ndims);
107 for (size_t i = 0; i < strides_expr_vec.size(); i++) {
108 auto stride_order = stride_order_vec[i];
109
110 strides_expr_vec[i][stride_order[0]] = 1;
111 for (size_t j = 1; j < stride_order.size(); j++) {
112 auto cur_dim_idx = stride_order[j];
113 auto adjacent_dim_idx = stride_order[j - 1];
114
115 strides_expr_vec[i][cur_dim_idx] = stride_gen_fn(
116 i,
117 dims_expr_vec[adjacent_dim_idx],
118 strides_expr_vec[i][adjacent_dim_idx]);
119 }
120 }
121
122 return {dims_expr_vec, strides_expr_vec};
123 };
124
125 auto check_channels_last_fn = [](int ndims, BufHandle buf_handle) -> bool {
126 if (ndims == 3) {
127 return buf_handle.is_channels_last_1d_contiguous();
128 } else if (ndims == 4) {
129 return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast);
130 } else {
131 return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast3d);
132 }
133 };
134
135 // channels-last contigous
136 for (size_t i = 0; i < dims.size(); i++) {
137 auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf);
138 for (size_t j = 0; j < shape_info.second.size(); j++) {
139 BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
140 ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), true);
141 }
142 }
143
144 // channels-last non-contigous
145 for (size_t i = 0; i < dims.size(); i++) {
146 auto shape_info = shape_gen_fn(dims[i], channels_last_non_cont_shape_conf);
147 for (size_t j = 0; j < shape_info.second.size(); j++) {
148 BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
149 ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), false);
150 }
151 }
152
153 // contiguous
154 for (size_t i = 0; i < dims.size(); i++) {
155 auto shape_info = shape_gen_fn(dims[i], cont_shape_conf);
156 for (size_t j = 0; j < shape_info.second.size(); j++) {
157 BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
158 ASSERT_EQ(buf_handle.is_contiguous(), true);
159 }
160 }
161
162 // non-contiguous
163 for (size_t i = 0; i < dims.size(); i++) {
164 auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf);
165 for (size_t j = 0; j < shape_info.second.size(); j++) {
166 BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
167 ASSERT_EQ(buf_handle.is_contiguous(), false);
168 }
169 }
170}
171
172TEST(Expr, LetTest01) {
173 VarHandle x("x", kFloat);
174 ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f));
175 SimpleIRExprEval eval(body);
176 eval.bindVar(x, ExprHandle(3.f));
177 ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
178}
179
180TEST(Expr, LetTest02) {
181 VarHandle x("x", kFloat);
182 VarHandle y("y", kFloat);
183 ExprHandle body =
184 ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y);
185 SimpleIRExprEval eval(body);
186 eval.bindVar(x, ExprHandle(3.f));
187 eval.bindVar(y, ExprHandle(6.f));
188 ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4 * 6));
189}
190
191TEST(Expr, LetStmtTest01) {
192 BufHandle a_buf("a", {1}, kFloat);
193 BufHandle b_buf("b", {1}, kFloat);
194
195 ExprHandle load_a = a_buf.load(0);
196 VarHandle var = VarHandle("v", kFloat);
197 StmtPtr let_store = Let::make(var, load_a);
198 StmtPtr store_b = b_buf.store({0}, var);
199 BlockPtr block = Block::make({let_store, store_b});
200
201 SimpleIREvaluator eval(block, {a_buf, b_buf});
202
203 PaddedBuffer<float> a_v(1);
204 PaddedBuffer<float> b_v(1);
205 PaddedBuffer<float> b_ref(1);
206
207 a_v(0) = 23;
208 b_ref(0) = a_v(0);
209 eval(a_v, b_v);
210
211 ExpectAllNear(b_v, b_ref, 1e-5);
212}
213
214TEST(Expr, IntTest) {
215 VarHandle x("x", kInt);
216 ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4));
217 SimpleIRExprEval eval(body);
218 eval.bindVar(x, ExprHandle(3));
219 ASSERT_EQ(eval.value<int>(), 2 + (3 * 3 + 4));
220}
221
222TEST(Expr, FloatTest) {
223 VarHandle x("x", kFloat);
224 ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f));
225 SimpleIRExprEval eval(body);
226 eval.bindVar(x, ExprHandle(3.f));
227 ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
228}
229
230TEST(Expr, ByteTest) {
231 VarHandle x("x", kByte);
232 ExprHandle body = ExprHandle((uint8_t)2) +
233 (x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4));
234 SimpleIRExprEval eval(body);
235 eval.bindVar(x, ExprHandle((uint8_t)3));
236 ASSERT_EQ(eval.value<uint8_t>(), 2 + (3 * 3 + 4));
237}
238
239TEST(Expr, CharTest) {
240 VarHandle x("x", kChar);
241 ExprHandle body = ExprHandle((int8_t)2) +
242 (x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4));
243 SimpleIRExprEval eval(body);
244 eval.bindVar(x, ExprHandle((int8_t)3));
245 ASSERT_EQ(eval.value<int8_t>(), 2 + (3 * 3 + 4));
246}
247
248TEST(Expr, ShortTest) {
249 VarHandle x("x", kShort);
250 ExprHandle body = ExprHandle((int16_t)2) +
251 (x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4));
252 SimpleIRExprEval eval(body);
253 eval.bindVar(x, ExprHandle((int16_t)3));
254 ASSERT_EQ(eval.value<int16_t>(), 2 + (3 * 3 + 4));
255}
256
257TEST(Expr, LongTest) {
258 VarHandle x("x", kLong);
259 ExprHandle body = ExprHandle((int64_t)2) +
260 (x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4));
261 SimpleIRExprEval eval(body);
262 eval.bindVar(x, ExprHandle((int64_t)3));
263 ASSERT_EQ(eval.value<int64_t>(), 2 + (3 * 3 + 4));
264}
265
266TEST(Expr, HalfTest) {
267 VarHandle x("x", kHalf);
268 ExprHandle body = ExprHandle((at::Half)2) +
269 (x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4));
270 SimpleIRExprEval eval(body);
271 eval.bindVar(x, ExprHandle((at::Half)3));
272 ASSERT_EQ(eval.value<at::Half>(), 2 + (3 * 3 + 4));
273}
274
275TEST(Expr, DoubleTest) {
276 VarHandle x("x", kDouble);
277 ExprHandle body = ExprHandle((double)2) +
278 (x * ExprHandle((double)3) + ExprHandle((double)4));
279 SimpleIRExprEval eval(body);
280 eval.bindVar(x, ExprHandle((double)3));
281 ASSERT_EQ(eval.value<double>(), 2 + (3 * 3 + 4));
282}
283
284TEST(Expr, VectorAdd01) {
285 const int kVectorSize = 8;
286 const int kVectorCount = 128;
287 const int kTotalSize = kVectorSize * kVectorCount;
288
289 BufHandle a_buf("A", {kTotalSize}, kFloat);
290 BufHandle b_buf("B", {kTotalSize}, kFloat);
291 BufHandle c_buf("C", {kTotalSize}, kFloat);
292
293 /*
294 Build the following:
295 for (const auto index : c10::irange(kVectorCount)) {
296 store(c_buf, ramp(index * 8, 1, 8),
297 load(a_buf, ramp(index * 8, 1, 8) +
298 load(b_buf, ramp(index * 8, 1, 8))))
299 }
300 */
301 VarHandle index = VarHandle("index", kInt);
302 ExprHandle load_a =
303 a_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)});
304 ExprHandle load_b =
305 b_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)});
306 ExprHandle value = load_a + load_b;
307 StmtPtr store_c =
308 c_buf.store({Ramp::make(index * kVectorSize, 1, kVectorSize)}, value);
309 StmtPtr stmt = For::make(index, 0, kVectorCount, store_c);
310
311 ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize));
312 ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize));
313 ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize));
314
315 PaddedBuffer<float> a_v(kTotalSize);
316 PaddedBuffer<float> b_v(kTotalSize);
317 PaddedBuffer<float> c_v(kTotalSize);
318 PaddedBuffer<float> c_ref(kTotalSize);
319 for (const auto i : c10::irange(kTotalSize)) {
320 a_v(i) = i * i;
321 b_v(i) = i * i * 4;
322 c_ref(i) = a_v(i) + b_v(i);
323 }
324 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
325 ir_eval(a_v, b_v, c_v);
326 ExpectAllNear(c_v, c_ref, 1e-5);
327}
328
329TEST(Expr, CompareSelectEQ) {
330 constexpr int N = 1024;
331 BufHandle a("A", {N}, kInt);
332 BufHandle b("B", {N}, kInt);
333 BufHandle c("C", {N}, kInt);
334 std::vector<int> a_buffer(N, 1);
335 std::vector<int> b_buffer(N, 1);
336 std::vector<int> c_buffer(N, 0);
337 std::vector<int> c_ref(N, 0);
338
339 VarHandle i("i", kInt);
340 auto memcpy_expr = For::make(
341 i,
342 0,
343 N,
344 c.store(
345 {i},
346 CompareSelect::make(
347 a.load(i), b.load(i), CompareSelectOperation::kEQ)));
348
349 SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c});
350 ir_eval(a_buffer, b_buffer, c_buffer);
351
352 ASSERT_EQ(a_buffer.size(), N);
353 ASSERT_EQ(b_buffer.size(), N);
354 ASSERT_EQ(c_buffer.size(), N);
355
356 assertAllEqual(a_buffer, 1);
357 assertAllEqual(b_buffer, 1);
358 assertAllEqual(c_buffer, 1);
359}
360
361TEST(Expr, CompareSelectDtypes) {
362 // LHS and RHS expressions should have the same dtype, but this dtype could
363 // differ from the dtype of the return values (but dtypes of true and false
364 // return values should be the same).
365 // This test constructs a CompareSelect expression where the input dtype is
366 // different from the output dtype and verifies that it works correctly:
367 // result = ((int)lhs == (int)rhs) ? (float)retval1 : (float)retval2
368 constexpr int N = 1024;
369 BufHandle a("A", {N}, kInt);
370 BufHandle b("B", {N}, kInt);
371 BufHandle c("C", {N}, kFloat);
372 std::vector<int> a_buffer(N, 1);
373 std::vector<int> b_buffer(N, 1);
374 std::vector<float> c_buffer(N, 0.0f);
375 std::vector<float> c_ref(N, 3.14f);
376
377 VarHandle i("i", kInt);
378 // C[i] = (A[i] == B[i]) ? 3.14f : 2.78f
379 // A and B are int, C is float.
380 auto select_expr = For::make(
381 i,
382 0,
383 N,
384 c.store(
385 {i},
386 CompareSelect::make(
387 a.load(i),
388 b.load(i),
389 FloatImm::make(3.14f),
390 FloatImm::make(2.78f),
391 CompareSelectOperation::kEQ)));
392
393 SimpleIREvaluator ir_eval(select_expr, {a, b, c});
394 ir_eval(a_buffer, b_buffer, c_buffer);
395
396 ASSERT_EQ(a_buffer.size(), N);
397 ASSERT_EQ(b_buffer.size(), N);
398 ASSERT_EQ(c_buffer.size(), N);
399
400 assertAllEqual(a_buffer, 1);
401 assertAllEqual(b_buffer, 1);
402 ExpectAllNear(c_buffer, c_ref, 1e-7);
403}
404
405TEST(Expr, IntrinsicsDtypes) {
406 constexpr int N = 256;
407 BufHandle a("A", {N}, kDouble);
408 BufHandle b("B", {N}, kDouble);
409 std::vector<double> a_buffer(N, -10.0);
410 std::vector<double> b_buffer(N, 0.0);
411 std::vector<double> b_ref(N, 10.0);
412
413 VarHandle i("i", kInt);
414 auto abs_expr = For::make(i, 0, N, b.store({i}, tensorexpr::abs(a.load(i))));
415
416 SimpleIREvaluator ir_eval(abs_expr, {a, b});
417 ir_eval(a_buffer, b_buffer);
418
419 ASSERT_EQ(a_buffer.size(), N);
420 ASSERT_EQ(b_buffer.size(), N);
421
422 assertAllEqual(a_buffer, -10.0);
423 ExpectAllNear(b_buffer, b_ref, 1e-7);
424}
425
426TEST(Expr, Substitute01) {
427 VarPtr x = alloc<Var>("x", kFloat);
428 VarPtr y = alloc<Var>("y", kFloat);
429 ExprPtr e =
430 alloc<Mul>(alloc<Sub>(x, alloc<FloatImm>(1.0f)), alloc<Add>(x, y));
431
432 VarPtr z = alloc<Var>("z", kFloat);
433 ExprPtr e2 = Substitute(e, {{x, alloc<Add>(z, alloc<FloatImm>(5.0f))}});
434 ExprPtr e2_ref = alloc<Mul>(
435 alloc<Sub>(alloc<Add>(z, alloc<FloatImm>(5.0f)), alloc<FloatImm>(1.0f)),
436 alloc<Add>(alloc<Add>(z, alloc<FloatImm>(5.0f)), y));
437 std::ostringstream oss;
438 oss << *e2;
439 std::string e2_str = oss.str();
440
441 oss.str("");
442 oss << *e2_ref;
443 std::string e2_ref_str = oss.str();
444 ASSERT_EQ(e2_str, e2_ref_str);
445}
446
447TEST(Expr, Math01) {
448 ExprHandle v = sin(ExprHandle(1.0f));
449
450 std::ostringstream oss;
451 oss << v;
452 ASSERT_EQ(oss.str(), "sin(1.f)");
453
454 SimpleIRExprEval eval(v);
455 float v_ref = std::sin(1.0f);
456 float res = eval.value<float>();
457 ASSERT_NEAR(res, v_ref, 1e-6);
458}
459
460TEST(Expr, UnaryMath01) {
461 struct TestConfig {
462 std::function<ExprHandle(const ExprHandle&)> func;
463 std::function<float(float)> ref_func;
464 };
465
466 std::vector<TestConfig> test_configs = {
467 {[](const ExprHandle& v) { return sin(v); },
468 [](float v) { return std::sin(v); }},
469 {[](const ExprHandle& v) { return sin(v); },
470 [](float v) { return std::sin(v); }},
471 {[](const ExprHandle& v) { return tan(v); },
472 [](float v) { return std::tan(v); }},
473 {[](const ExprHandle& v) { return asin(v); },
474 [](float v) { return std::asin(v); }},
475 {[](const ExprHandle& v) { return acos(v); },
476 [](float v) { return std::acos(v); }},
477 {[](const ExprHandle& v) { return atan(v); },
478 [](float v) { return std::atan(v); }},
479 {[](const ExprHandle& v) { return sinh(v); },
480 [](float v) { return std::sinh(v); }},
481 {[](const ExprHandle& v) { return cosh(v); },
482 [](float v) { return std::cosh(v); }},
483 {[](const ExprHandle& v) { return tanh(v); },
484 [](float v) { return std::tanh(v); }},
485 {[](const ExprHandle& v) { return exp(v); },
486 [](float v) { return std::exp(v); }},
487 {[](const ExprHandle& v) { return tensorexpr::abs(v); },
488 [](float v) { return std::fabs(v); }},
489 {[](const ExprHandle& v) { return log(v); },
490 [](float v) { return std::log(v); }},
491 {[](const ExprHandle& v) { return log2(v); },
492 [](float v) { return std::log2(v); }},
493 {[](const ExprHandle& v) { return log10(v); },
494 [](float v) { return std::log10(v); }},
495 {[](const ExprHandle& v) { return erf(v); },
496 [](float v) { return std::erf(v); }},
497 {[](const ExprHandle& v) { return sqrt(v); },
498 [](float v) { return std::sqrt(v); }},
499 {[](const ExprHandle& v) { return rsqrt(v); },
500 [](float v) { return 1.0f / std::sqrt(v); }},
501 {[](const ExprHandle& v) { return ceil(v); },
502 [](float v) { return std::ceil(v); }},
503 {[](const ExprHandle& v) { return floor(v); },
504 [](float v) { return std::floor(v); }},
505 {[](const ExprHandle& v) { return round(v); },
506 [](float v) { return std::round(v); }},
507 {[](const ExprHandle& v) { return trunc(v); },
508 [](float v) { return std::trunc(v); }},
509 };
510
511 for (const TestConfig& test_config : test_configs) {
512 const float input_v = 0.8765f;
513 ExprHandle v = test_config.func(ExprHandle(input_v));
514 float v_ref = test_config.ref_func(input_v);
515 SimpleIRExprEval eval(v);
516 ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6);
517 }
518
519 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
520 for (float input_v : {std::nan("1"), 0., .5}) {
521 ExprHandle v = FloatImm::make(input_v);
522 SimpleIRExprEval eval(Intrinsics::make(kIsNan, v));
523 ASSERT_NEAR(eval.value<int>(), std::isnan(input_v), 0);
524 }
525}
526
527TEST(Expr, BinaryMath01) {
528 struct TestConfig {
529 std::function<ExprHandle(const ExprHandle&, const ExprHandle&)> func;
530 std::function<float(float, float)> ref_func;
531 };
532
533 std::vector<TestConfig> test_configs = {
534 {[](const ExprHandle& v1, const ExprHandle& v2) { return pow(v1, v2); },
535 [](float v1, float v2) { return std::pow(v1, v2); }},
536 {[](const ExprHandle& v1, const ExprHandle& v2) { return fmod(v1, v2); },
537 [](float v1, float v2) { return std::fmod(v1, v2); }},
538 };
539
540 for (const TestConfig& test_config : test_configs) {
541 const float v1 = 0.8765f;
542 float v2 = 1.2345f;
543 ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2));
544 float v_ref = test_config.ref_func(v1, v2);
545 SimpleIRExprEval eval(v_expr);
546 ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6);
547 }
548}
549
550TEST(Expr, LogicalOps01) {
551 ExprHandle a(23);
552 ExprHandle b(11);
553 ExprHandle c(0.72f);
554 ExprHandle d(0.69f);
555 ExprHandle f1 = (a > b) && (c > d);
556 ExprHandle f2 = (a > b) && (c < d);
557 ExprHandle f3 = (a < b) && (c > d);
558 ExprHandle f4 = (a < b) && (c < d);
559 ExprHandle f5 = (a < b) || (c > d);
560 ExprHandle f6 = (a < b) || (c < d);
561 ExprHandle f7 = (a > b) || (c < d);
562 ExprHandle f8 = (a > b) || (c > d);
563
564 SimpleIRExprEval eval1(f1);
565 SimpleIRExprEval eval2(f2);
566 SimpleIRExprEval eval3(f3);
567 SimpleIRExprEval eval4(f4);
568 SimpleIRExprEval eval5(f5);
569 SimpleIRExprEval eval6(f6);
570 SimpleIRExprEval eval7(f7);
571 SimpleIRExprEval eval8(f8);
572 ASSERT_EQ(eval1.value<int>(), 1);
573 ASSERT_EQ(eval2.value<int>(), 0);
574 ASSERT_EQ(eval3.value<int>(), 0);
575 ASSERT_EQ(eval4.value<int>(), 0);
576 ASSERT_EQ(eval5.value<int>(), 1);
577 ASSERT_EQ(eval6.value<int>(), 0);
578 ASSERT_EQ(eval7.value<int>(), 1);
579 ASSERT_EQ(eval8.value<int>(), 1);
580}
581
582TEST(Expr, LogicalOps02) {
583 ExprHandle a(23);
584 ExprHandle b(11);
585 ExprHandle c(0.72f);
586 ExprHandle d(0.72f);
587
588 ExprHandle f1 = (a > b) || (c > d);
589 ExprHandle f2 = (a > b) && (c <= d);
590 ExprHandle f3 = (a > b) && (c > d);
591 ExprHandle ff1 = f1 && f2;
592 ExprHandle ff2 = f2 || f3;
593
594 SimpleIRExprEval eval1(ff1);
595 SimpleIRExprEval eval2(ff2);
596 ASSERT_EQ(eval1.value<int>(), 1);
597 ASSERT_EQ(eval2.value<int>(), 1);
598}
599
600TEST(Expr, LogicalOps03) {
601 ExprHandle a(23);
602 ExprHandle b(11);
603 ExprHandle c(0.72f);
604 ExprHandle d(0.69f);
605
606 // Bool types
607 ExprHandle bool_f1 = (a > b) && BoolImm::make(true);
608 ExprHandle bool_f2 = (c <= d) || BoolImm::make(true);
609
610 // Int types
611 ExprHandle int_f1 = (a > b) && IntImm::make(1);
612 ExprHandle int_f2 = (c <= d) || IntImm::make(1);
613
614 // Short types
615 ExprHandle short_f1 = (a > b) && ShortImm::make(1);
616 ExprHandle short_f2 = (c <= d) || ShortImm::make(1);
617
618 // Long types
619 ExprHandle long_f1 = (a > b) && LongImm::make(1);
620 ExprHandle long_f2 = (c <= d) || LongImm::make(1);
621
622 // Char types
623 ExprHandle char_f1 = (a > b) && CharImm::make(1);
624 ExprHandle char_f2 = (c <= d) || CharImm::make(1);
625
626 // Byte types
627 ExprHandle byte_f1 = (a > b) && ByteImm::make(1);
628 ExprHandle byte_f2 = (c <= d) || ByteImm::make(1);
629
630 SimpleIRExprEval eval1(bool_f1);
631 SimpleIRExprEval eval2(bool_f2);
632 SimpleIRExprEval eval3(int_f1);
633 SimpleIRExprEval eval4(int_f2);
634 SimpleIRExprEval eval5(short_f1);
635 SimpleIRExprEval eval6(short_f2);
636 SimpleIRExprEval eval7(long_f1);
637 SimpleIRExprEval eval8(long_f2);
638 SimpleIRExprEval eval9(char_f1);
639 SimpleIRExprEval eval10(char_f2);
640 SimpleIRExprEval eval11(byte_f1);
641 SimpleIRExprEval eval12(byte_f2);
642
643 ASSERT_EQ(eval1.value<bool>(), true);
644 ASSERT_EQ(eval2.value<bool>(), true);
645 ASSERT_EQ(eval3.value<int>(), 1);
646 ASSERT_EQ(eval4.value<int>(), 1);
647 ASSERT_EQ(eval5.value<int16_t>(), 1);
648 ASSERT_EQ(eval6.value<int16_t>(), 1);
649 ASSERT_EQ(eval7.value<int64_t>(), 1);
650 ASSERT_EQ(eval8.value<int64_t>(), 1);
651 ASSERT_EQ(eval9.value<int8_t>(), 1);
652 ASSERT_EQ(eval10.value<int8_t>(), 1);
653 ASSERT_EQ(eval11.value<uint8_t>(), 1);
654 ASSERT_EQ(eval12.value<uint8_t>(), 1);
655}
656
657TEST(Expr, BitwiseOps) {
658 ExprHandle a(59);
659 ExprHandle b(11);
660 ExprHandle c(101);
661 ExprHandle d(2);
662 ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d;
663
664 SimpleIRExprEval eval(f);
665 ASSERT_EQ(eval.value<int>(), 11);
666}
667
668TEST(Expr, DynamicShapeAdd) {
669 auto testWithSize = [](int32_t size) {
670 VarHandle n("n", kInt);
671 BufHandle a("a", {n}, kFloat);
672 BufHandle b("b", {n}, kFloat);
673 BufHandle c("c", {n}, kFloat);
674 VarHandle i("i", kInt);
675 StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
676 std::vector<float> aData(size, 1.0f);
677 std::vector<float> bData(size, 2.0f);
678 std::vector<float> cData(size, 0.0f);
679 SimpleIREvaluator(s, {a, b, c, n})(aData, bData, cData, size);
680 ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
681 };
682 testWithSize(1);
683 testWithSize(16);
684 testWithSize(37);
685}
686
687TEST(Expr, OutOfBounds) {
688 ExprHandle N(10);
689 ExprHandle start(0);
690 ExprHandle stop(15);
691 VarHandle i("i", kInt);
692
693 BufHandle X("X", {N}, kInt);
694
695 auto body = Store::make(X, {i}, i);
696 auto stmt = For::make(i, start, stop, body);
697
698 PaddedBuffer<int> data(20);
699
700 EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
701}
702
703TEST(Expr, OutOfBounds2d) {
704 std::vector<std::pair<int, int>> size_options = {{10, 15}, {15, 10}};
705 for (auto sizes : size_options) {
706 ExprHandle N(sizes.first);
707 ExprHandle M(sizes.second);
708 ExprHandle start(0);
709 ExprHandle stopInner(15);
710 ExprHandle stopOuter(15);
711 VarHandle i("i", kInt);
712 VarHandle j("j", kInt);
713
714 BufHandle X("X", {N, M}, kInt);
715
716 auto body = Store::make(X, {i, j}, i);
717 auto inner = For::make(j, start, stopInner, body);
718 auto stmt = For::make(i, start, stopOuter, inner);
719
720 PaddedBuffer<int> data(400);
721
722 EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
723 }
724}
725
726TEST(Expr, OutOfBounds2dFlattenedIndex) {
727 ExprHandle buf_size(149);
728 ExprHandle start(0);
729 ExprHandle stopInner(15);
730 ExprHandle stopOuter(10);
731 VarHandle i("i", kInt);
732 VarHandle j("j", kInt);
733
734 BufHandle X("X", {buf_size}, kInt);
735
736 auto idx = Add::make(Mul::make(i, stopInner), j);
737 auto body = Store::make(X, {idx}, i);
738 auto inner = For::make(j, start, stopInner, body);
739 auto stmt = For::make(i, start, stopOuter, inner);
740
741 PaddedBuffer<int> data(400);
742
743 EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
744}
745
746void testCond01() {
747 const int N = 16;
748 PaddedBuffer<float> a_v(N);
749 BufHandle a_buf("a", {N}, kFloat);
750 VarHandle index = VarHandle("index", kInt);
751 StmtPtr assign_x2 = a_buf.store({index}, cast<float>(index) * 2);
752 StmtPtr assign_x3 = a_buf.store({index}, cast<float>(index) * 3);
753 ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ);
754 StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3);
755 StmtPtr for_stmt = For::make(index, 0, N, assign);
756 SimpleIREvaluator(for_stmt, {a_buf})(a_v);
757
758 PaddedBuffer<float> a_ref(N);
759 for (const auto i : c10::irange(N)) {
760 if (i % 2 == 0) {
761 a_ref(i) = i * 2;
762 } else {
763 a_ref(i) = i * 3;
764 }
765 }
766 ExpectAllNear(a_v, a_ref, 1e-5);
767}
768
769void testIfThenElse01() {
770 ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f));
771
772 std::ostringstream oss;
773 oss << v;
774 ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)");
775
776 SimpleIRExprEval eval(v);
777 ASSERT_EQ(eval.value<float>(), 1.0f);
778}
779
780void testIfThenElse02() {
781 ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f));
782
783 std::ostringstream oss;
784 oss << v;
785 ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");
786
787 SimpleIRExprEval eval(v);
788 ASSERT_EQ(eval.value<float>(), 2.0f);
789}
790
791void testIfThenElse03() {
792 ExprHandle v =
793 ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f));
794
795 std::ostringstream oss;
796 oss << v;
797 ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");
798
799 SimpleIRExprEval eval(v);
800 ASSERT_EQ(eval.value<float>(), 2.0f);
801}
802
803void testStmtClone() {
804 const int N = 16;
805
806 BufHandle a_buf("a", {N}, kInt);
807 VarHandle index = VarHandle("index", kInt);
808 StmtPtr body = a_buf.store({index}, 5);
809 StmtPtr loop = For::make(index, 0, N, body);
810
811 StmtPtr cloned_loop = Stmt::clone(loop);
812 std::vector<int> orig_loop_results(N);
813 std::vector<int> cloned_loop_results(N);
814 SimpleIREvaluator(loop, {a_buf})(orig_loop_results);
815 SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results);
816
817 assertAllEqual(orig_loop_results, 5);
818 assertAllEqual(cloned_loop_results, 5);
819
820 // Let's add another assign to the body in the cloned loop and verify that the
821 // original statement hasn't changed while the cloned one has.
822 StmtPtr body_addition = a_buf.store({index}, 33);
823 BlockPtr cloned_body = static_to<Block>(static_to<For>(cloned_loop)->body());
824 cloned_body->append_stmt(body_addition);
825
826 std::vector<int> orig_loop_results_after_mutation(N);
827 std::vector<int> cloned_loop_results_after_mutation(N);
828 SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation);
829 SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation);
830
831 assertAllEqual(orig_loop_results_after_mutation, 5);
832 assertAllEqual(cloned_loop_results_after_mutation, 33);
833}
834
835} // namespace jit
836} // namespace torch
837