1#include <gtest/gtest.h>
2#include <test/cpp/tensorexpr/test_base.h>
3
4#include <c10/util/irange.h>
5#include <test/cpp/tensorexpr/test_utils.h>
6#include <torch/csrc/jit/tensorexpr/hash_provider.h>
7#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
8#include <torch/csrc/jit/tensorexpr/loopnest.h>
9
10#include <cmath>
11
12namespace torch {
13namespace jit {
14using namespace torch::jit::tensorexpr;
15using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
16
17TEST(Simplify, ConstantFoldSimple) {
18 ExprHandle a(2.0f);
19 ExprHandle b(3.0f);
20 ExprHandle f = (a + b);
21
22 ExprHandle newF = IRSimplifier::simplify(f);
23 ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
24 ASSERT_EQ(newF.AsNode<FloatImm>()->value(), 5);
25
26 SimpleIRExprEval eval(newF);
27 ASSERT_EQ(eval.value<float>(), 5.f);
28}
29
30TEST(Simplify, ConstantFoldTwoLayer) {
31 ExprHandle a(2.0f);
32 ExprHandle b(3.0f);
33 ExprHandle c(4.0f);
34 ExprHandle d(5.0f);
35 ExprHandle f = (a + b) - (c + d);
36
37 ExprHandle newF = IRSimplifier::simplify(f);
38 ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
39 ASSERT_EQ(newF.AsNode<FloatImm>()->value(), -4);
40
41 SimpleIRExprEval eval(newF);
42 ASSERT_EQ(eval.value<float>(), -4.f);
43}
44
45TEST(Simplify, ConstantFoldShifts) {
46 ExprHandle a(7);
47 ExprHandle b(2);
48 ExprHandle c(3);
49 ExprHandle f = ((a << b) << b) >> c;
50
51 ExprHandle newF = IRSimplifier::simplify(f);
52 ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
53 ASSERT_EQ(newF.AsNode<IntImm>()->value(), 14);
54
55 SimpleIRExprEval eval(newF);
56 ASSERT_EQ(eval.value<int>(), 7 << (4 - 3));
57}
58
59TEST(Simplify, ConstantFoldBitwise) {
60 ExprHandle a(59);
61 ExprHandle b(22);
62 ExprHandle c(101);
63 ExprHandle f = (a ^ b) & c;
64
65 ExprHandle newF = IRSimplifier::simplify(f);
66 ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
67 ASSERT_EQ(newF.AsNode<IntImm>()->value(), 37);
68
69 SimpleIRExprEval eval(newF);
70 ASSERT_EQ(eval.value<int>(), (59 ^ 22) & 101);
71}
72
73TEST(Simplify, ConstantFoldMultiOp) {
74 ExprHandle a(2.0f);
75 ExprHandle b(3.0f);
76 ExprHandle c(4.0f);
77 ExprHandle d(5.0f);
78 ExprHandle e(6.0f);
79 ExprHandle f(7.0f);
80 ExprHandle fn = ((a / e) - (c + d)) * (f / b);
81
82 ExprHandle newF = IRSimplifier::simplify(fn);
83 ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
84
85 SimpleIRExprEval eval(newF);
86 SimpleIRExprEval ref(fn);
87
88 ASSERT_EQ(eval.value<float>(), ref.value<float>());
89}
90
91TEST(Simplify, ConstantFoldMinMax) {
92 ExprHandle a(12.0f);
93 ExprHandle b(15.0f);
94 ExprHandle c(17.0f);
95
96 // x = max(12, min(15, 17)).
97 ExprHandle minHandle = Min::make(b, c, true);
98 ExprHandle fn = Max::make(a, minHandle, false);
99
100 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
101 ASSERT_EQ(fn.dtype().scalar_type(), ScalarType::Float);
102
103 ExprHandle newF = IRSimplifier::simplify(fn);
104 ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
105
106 SimpleIRExprEval eval(newF);
107 ASSERT_EQ(eval.value<float>(), 15.f);
108}
109
110TEST(Simplify, ConstantFoldIntrinsics) {
111 ExprHandle a(2.0f);
112 ExprHandle b(3.0f);
113 ExprHandle c(4.0f);
114 ExprHandle powHandle = Intrinsics::make(kPow, a, b);
115 ExprHandle sinHandle = Intrinsics::make(kSin, powHandle);
116 ExprHandle modHandle = Intrinsics::make(kFmod, c, sinHandle);
117 ExprHandle logHandle = Intrinsics::make(kLog10, modHandle);
118 ExprHandle rndHandle = Intrinsics::make(kRound, logHandle);
119 ExprHandle fn = Intrinsics::make(kAbs, rndHandle);
120
121 ExprHandle newF = IRSimplifier::simplify(fn);
122 ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
123 ASSERT_EQ(newF.AsNode<FloatImm>()->value(), 1);
124
125 SimpleIRExprEval eval(newF);
126 SimpleIRExprEval ref(fn);
127
128 ASSERT_EQ(eval.value<float>(), ref.value<float>());
129}
130
131TEST(Simplify, ConstantFoldCastToBool) {
132 ExprHandle f = Cast::make(kBool, IntImm::make(0));
133 ExprHandle newF = IRSimplifier::simplify(f);
134 SimpleIRExprEval eval(newF);
135 ASSERT_EQ(eval.value<bool>(), false);
136}
137
138TEST(Simplify, ConstantFoldWithVar) {
139 {
140 VarHandle x("x", kInt);
141 ExprHandle body = x * (ExprHandle(2) + ExprHandle(4));
142
143 ExprHandle newF = IRSimplifier::simplify(body);
144 MulPtr root = newF.AsNode<Mul>();
145 ASSERT_NE(root, nullptr);
146 ASSERT_NE(to<IntImm>(root->lhs()), nullptr);
147
148 SimpleIRExprEval eval(newF);
149 eval.bindVar(x, ExprHandle(3));
150 ASSERT_EQ(eval.value<int>(), 3 * (2 + 4));
151 }
152
153 {
154 VarHandle x("x", kFloat);
155 ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f));
156
157 ExprHandle newF = IRSimplifier::simplify(body);
158 MulPtr root = newF.AsNode<Mul>();
159 ASSERT_NE(root, nullptr);
160 ASSERT_NE(to<FloatImm>(root->rhs()), nullptr);
161
162 SimpleIRExprEval eval(newF);
163 eval.bindVar(x, ExprHandle(3.f));
164 ASSERT_EQ(eval.value<float>(), 3 * (2 + 4));
165 }
166}
167
168TEST(Simplify, ConditionalSelectFoldSimple) {
169 ExprHandle a(3.0f);
170 ExprHandle b(4.0f);
171 ExprHandle c(3.0f);
172 {
173 ExprHandle f = (a > b);
174
175 ExprHandle newF = IRSimplifier::simplify(f);
176 ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
177 ASSERT_EQ(newF.AsNode<IntImm>()->value(), 0);
178
179 SimpleIRExprEval eval(newF);
180 ASSERT_EQ(eval.value<int>(), 0);
181 }
182 {
183 ExprHandle f = (a < b);
184
185 ExprHandle newF = IRSimplifier::simplify(f);
186 ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
187 ASSERT_EQ(newF.AsNode<IntImm>()->value(), 1);
188
189 SimpleIRExprEval eval(newF);
190 ASSERT_EQ(eval.value<int>(), 1);
191 }
192 {
193 ExprHandle f = (a == c);
194
195 ExprHandle newF = IRSimplifier::simplify(f);
196 ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
197 ASSERT_EQ(newF.AsNode<IntImm>()->value(), 1);
198
199 SimpleIRExprEval eval(newF);
200 ASSERT_EQ(eval.value<int>(), 1);
201 }
202 {
203 ExprHandle f = (a != c);
204
205 ExprHandle newF = IRSimplifier::simplify(f);
206 ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
207 ASSERT_EQ(newF.AsNode<IntImm>()->value(), 0);
208
209 SimpleIRExprEval eval(newF);
210 ASSERT_EQ(eval.value<int>(), 0);
211 }
212}
213
214TEST(Simplify, ConditionalSelectFoldTwoLayer) {
215 ExprHandle a(3.0f);
216 ExprHandle b(2.0f);
217 ExprHandle c(2.0f);
218 ExprHandle d(1.0f);
219 {
220 ExprHandle f = (a + b < c + d);
221
222 ExprHandle newF = IRSimplifier::simplify(f);
223 ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
224 ASSERT_EQ(newF.AsNode<IntImm>()->value(), 0);
225
226 SimpleIRExprEval eval(newF);
227 ASSERT_EQ(eval.value<int>(), 0);
228 }
229 {
230 ExprHandle f = (a + b > c + d);
231
232 ExprHandle newF = IRSimplifier::simplify(f);
233 ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
234 ASSERT_EQ(newF.AsNode<IntImm>()->value(), 1);
235
236 SimpleIRExprEval eval(newF);
237 ASSERT_EQ(eval.value<int>(), 1);
238 }
239 {
240 ExprHandle f = (a + d == b + c);
241
242 ExprHandle newF = IRSimplifier::simplify(f);
243 ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
244 ASSERT_EQ(newF.AsNode<IntImm>()->value(), 1);
245
246 SimpleIRExprEval eval(newF);
247 ASSERT_EQ(eval.value<int>(), 1);
248 }
249 {
250 ExprHandle f = (a + d != b + c);
251
252 ExprHandle newF = IRSimplifier::simplify(f);
253 ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
254 ASSERT_EQ(newF.AsNode<IntImm>()->value(), 0);
255
256 SimpleIRExprEval eval(newF);
257 ASSERT_EQ(eval.value<int>(), 0);
258 }
259}
260
261TEST(Simplify, ConditionalSelectFoldWithVar) {
262 VarHandle x("x", kFloat);
263 ExprHandle f = x < 4.f;
264
265 ExprHandle newF = IRSimplifier::simplify(f);
266 IntImmPtr folded = newF.AsNode<IntImm>();
267 ASSERT_EQ(folded, nullptr);
268
269 {
270 SimpleIRExprEval eval(newF);
271 eval.bindVar(x, ExprHandle(3.f));
272 ASSERT_EQ(eval.value<int>(), 1);
273 }
274 {
275 SimpleIRExprEval eval(newF);
276 eval.bindVar(x, ExprHandle(5.f));
277 ASSERT_EQ(eval.value<int>(), 0);
278 }
279}
280
281TEST(Simplify, UnFoldableExpr) {
282 VarHandle x("x", kFloat);
283 VarHandle y("y", kFloat);
284 ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y);
285
286 ExprHandle newF = IRSimplifier::simplify(body);
287 AddPtr root = newF.AsNode<Add>();
288 ASSERT_NE(root, nullptr);
289 ASSERT_EQ(to<FloatImm>(root->lhs()), nullptr);
290 ASSERT_EQ(to<FloatImm>(root->rhs()), nullptr);
291
292 SimpleIRExprEval eval(newF);
293 eval.bindVar(x, ExprHandle(3.f));
294 eval.bindVar(y, ExprHandle(2.f));
295 ASSERT_EQ(eval.value<float>(), 9 + 10);
296}
297
298TEST(Simplify, HashSimple) {
299 VarHandle x("x", kFloat);
300 ExprHandle a(2.0f);
301 ExprHandle b(3.0f);
302 ExprHandle f = a + b * x;
303
304 HashProvider hasher;
305
306 auto hash_x = hasher.hash(x.node());
307 auto hash_a = hasher.hash(a.node());
308 auto hash_f = hasher.hash(f.node());
309
310 ASSERT_NE(hash_x, (size_t)0);
311 ASSERT_NE(hash_a, (size_t)0);
312 ASSERT_NE(hash_f, (size_t)0);
313 ASSERT_NE(hash_x, hash_a);
314 ASSERT_NE(hash_x, hash_f);
315 ASSERT_NE(hash_a, hash_f);
316}
317
318TEST(Simplify, HashEquivalence) {
319 VarHandle x("x", kFloat);
320 VarHandle y("y", kFloat);
321 ExprHandle f = (x * y) + (x * y);
322
323 AddPtr root = f.AsNode<Add>();
324 ASSERT_NE(root, nullptr);
325
326 HashProvider hasher;
327 auto hash_f = hasher.hash(f.node());
328 auto hash_l = hasher.hash(root->lhs());
329 auto hash_r = hasher.hash(root->rhs());
330
331 // Root not equal to either branch.
332 ASSERT_NE(hash_f, hash_l);
333 ASSERT_NE(hash_f, hash_r);
334 // but branches are equal.
335 ASSERT_EQ(hash_l, hash_r);
336
337 // Still equivalent if separate.
338 ExprHandle a(2);
339 ExprHandle f2 = x + a / y;
340 ExprHandle b(2);
341 ExprHandle f3 = x + b / y;
342 ASSERT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node()));
343
344 // Not equivalent if different vars (even with same name).
345 VarHandle z("x", kFloat);
346 ExprHandle f4 = z + b / y;
347 ASSERT_NE(hasher.hash(f2.node()), hasher.hash(f4.node()));
348
349 // Intrinsics sanity check.
350 ExprHandle f5 = Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x);
351 ASSERT_NE(hasher.hash(f5.node()), (size_t)0);
352}
353
354TEST(Simplify, HashEquivalenceRand) {
355 ExprHandle f =
356 Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt);
357
358 AddPtr root = f.AsNode<Add>();
359 ASSERT_NE(root, nullptr);
360
361 HashProvider hasher;
362 auto hash_f = hasher.hash(f.node());
363 auto hash_l = hasher.hash(root->lhs());
364 auto hash_r = hasher.hash(root->rhs());
365
366 // Root not equal to either branch.
367 ASSERT_NE(hash_f, hash_l);
368 ASSERT_NE(hash_f, hash_r);
369 // and branches are NOT equal.
370 ASSERT_NE(hash_l, hash_r);
371}
372
373TEST(Simplify, HashEquivalenceAfterFolding) {
374 VarHandle x("x", kFloat);
375 ExprHandle a(2.0f);
376 ExprHandle b(3.0f);
377 ExprHandle c(5.0f);
378
379 ExprHandle f1 = ((a + b) * x);
380 ExprHandle f2 = (c * x);
381
382 HashProvider hasher;
383 auto hash_l = hasher.hash(f1.node());
384 auto hash_r = hasher.hash(f2.node());
385
386 // Root not equal to either branch, and branches not equal.
387 ASSERT_NE(hash_l, hash_r);
388
389 ExprHandle ff1 = IRSimplifier::simplify(f1);
390 ExprHandle ff2 = IRSimplifier::simplify(f2);
391
392 auto hash_l_n = hasher.hash(ff1.node());
393 auto hash_r_n = hasher.hash(ff2.node());
394 // but branches are now equal.
395 ASSERT_EQ(hash_l_n, hash_r_n);
396}
397
398TEST(Simplify, HashDifferenceTypes) {
399 HashProvider hasher;
400 std::vector<ExprPtr> immediates;
401
402 immediates.push_back(alloc<DoubleImm>(1));
403 immediates.push_back(alloc<FloatImm>(1));
404 immediates.push_back(alloc<HalfImm>(1));
405 // NOLINTNEXTLINE(modernize-use-bool-literals)
406 immediates.push_back(alloc<BoolImm>(1));
407 immediates.push_back(alloc<CharImm>(1));
408 immediates.push_back(alloc<ByteImm>(1));
409 immediates.push_back(alloc<ShortImm>(1));
410 immediates.push_back(alloc<IntImm>(1));
411 immediates.push_back(alloc<LongImm>(1));
412
413 // Immediates of different types are not equal.
414 for (unsigned int i = 0; i < immediates.size(); ++i) {
415 for (unsigned int j = i + 1; j < immediates.size(); ++j) {
416 ASSERT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j]));
417 }
418 }
419
420 // But coerced immediates are if they are the same type:
421 ExprHandle f1 = ExprHandle(2.f) + CharImm::make(1);
422 ExprHandle f2 = Cast::make(kFloat, IntImm::make(3));
423
424 ExprHandle ff1 = IRSimplifier::simplify(f1);
425 ExprHandle ff2 = IRSimplifier::simplify(f2);
426
427 ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node()));
428}
429
430TEST(Simplify, HashLargeExpression) {
431 constexpr int N = 1024;
432 BufHandle a("A", {N}, kInt);
433 BufHandle b("B", {N}, kInt);
434 BufHandle c("C", {N}, kInt);
435 VarHandle i("i", kInt);
436 auto memcpy_stmt = For::make(
437 i,
438 0,
439 N,
440 Store::make(
441 c,
442 {i},
443 CompareSelect::make(
444 Load::make(a, {i}),
445 Load::make(b, {i}),
446 CompareSelectOperation::kEQ)));
447
448 BufHandle d("D", {1}, kInt);
449 BufHandle e("E", {1}, kInt);
450 auto store_ramp_stmt = Store::make(
451 e, {Ramp::make(0, 1, 4)}, Load::make(d, {Ramp::make(0, 1, 4)}));
452
453 auto if_stmt = Cond::make(
454 CompareSelect::make(
455 Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kGE),
456 memcpy_stmt,
457 store_ramp_stmt);
458
459 HashProvider hasher;
460 auto hash_r = hasher.hash(if_stmt);
461 // We should not have to do any more work.
462 ASSERT_TRUE(hasher.cachedHash(memcpy_stmt));
463 auto hash_t = hasher.hash(memcpy_stmt);
464 ASSERT_TRUE(hasher.cachedHash(store_ramp_stmt));
465 auto hash_f = hasher.hash(store_ramp_stmt);
466
467 // Root not equal to either branch, and branches not equal.
468 ASSERT_NE(hash_r, hash_t);
469 ASSERT_NE(hash_r, hash_f);
470 ASSERT_NE(hash_t, hash_f);
471}
472
473TEST(Simplify, HashForLoopOptions) {
474 constexpr int N = 1024;
475 BufHandle a("A", {N}, kInt);
476 BufHandle b("B", {N}, kInt);
477 BufHandle c("C", {N}, kInt);
478 VarHandle i("i", kInt);
479 auto for_stmt = For::make(
480 i,
481 0,
482 N,
483 Store::make(
484 c,
485 {i},
486 CompareSelect::make(
487 Load::make(a, {i}),
488 Load::make(b, {i}),
489 CompareSelectOperation::kEQ)));
490
491 HashProvider hasher;
492 auto hash_before = hasher.hash(for_stmt);
493 hasher.clearCache();
494
495 for_stmt->set_gpu_block_index(LoopOptions::IDX_X);
496 auto hash_block_idx = hasher.hash(for_stmt);
497 hasher.clearCache();
498
499 ASSERT_NE(hash_before, hash_block_idx);
500
501 for_stmt->set_gpu_block_index(LoopOptions::IDX_UNSET);
502 auto hash_reset = hasher.hash(for_stmt);
503 hasher.clearCache();
504
505 ASSERT_EQ(hash_before, hash_reset);
506 for_stmt->set_gpu_thread_index(LoopOptions::IDX_X);
507 auto hash_thread_idx = hasher.hash(for_stmt);
508
509 ASSERT_NE(hash_before, hash_thread_idx);
510 ASSERT_NE(hash_block_idx, hash_thread_idx);
511}
512
513/// (2 + x) + 4 => x + 6
514TEST(Simplify, SimplifyAdd) {
515 VarHandle x("x", kInt);
516 VarHandle y("y", kInt);
517
518 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
519 VarHandle m("m", kInt);
520 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
521 VarHandle n("n", kInt);
522 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
523 VarHandle n_1("n_1", kInt);
524 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
525 ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4);
526
527 ExprHandle simplified = IRSimplifier::simplify(body);
528 AddPtr root = simplified.AsNode<Add>();
529 ASSERT_NE(root, nullptr);
530 VarPtr lhs = to<Var>(root->lhs());
531 ASSERT_NE(lhs, nullptr);
532 ASSERT_EQ(lhs->name_hint(), "x");
533 IntImmPtr rhs = to<IntImm>(root->rhs());
534 ASSERT_NE(rhs, nullptr);
535 ASSERT_EQ(rhs->value(), 6.f);
536}
537
538/// (2 - x) - 4 => -2 - x
539TEST(Simplify, SimplifySub) {
540 VarHandle x("x", kInt);
541 ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4);
542
543 ExprHandle simplified = IRSimplifier::simplify(body);
544 SubPtr root = simplified.AsNode<Sub>();
545 ASSERT_NE(root, nullptr);
546 IntImmPtr lhs = to<IntImm>(root->lhs());
547 ASSERT_NE(lhs, nullptr);
548 ASSERT_EQ(lhs->value(), -2.f);
549 VarPtr rhs = to<Var>(root->rhs());
550 ASSERT_NE(rhs, nullptr);
551 ASSERT_EQ(rhs->name_hint(), "x");
552}
553
554/// 2 * (1 - x) - 4 => 2 * (-3 - x)
555TEST(Simplify, SimplifyMultiLayer) {
556 VarHandle x("x", kInt);
557 ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4));
558 ExprHandle simplified = IRSimplifier::simplify(body);
559 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
560 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
561 IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
562 IS_IMM_WITH_VAL(Int, sub->lhs(), -3);
563 IS_VAR_WITH_NAME(sub->rhs(), "x");
564}
565
566/// 2 * (3 * x) - (x * 4) => 2 * x
567TEST(Simplify, SimplifyMultiTerm) {
568 VarHandle x("x", kInt);
569 ExprHandle body =
570 (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
571
572 ExprHandle simplified = IRSimplifier::simplify(body);
573 MulPtr root = simplified.AsNode<Mul>();
574 ASSERT_NE(root, nullptr);
575 IntImmPtr lhs = to<IntImm>(root->lhs());
576 ASSERT_NE(lhs, nullptr);
577 ASSERT_EQ(lhs->value(), 2);
578 VarPtr rhs = to<Var>(root->rhs());
579 ASSERT_NE(rhs, nullptr);
580 ASSERT_EQ(rhs->name_hint(), "x");
581}
582
583/// 2 * (3 * (long)x) - (x * 4) => 2 * x
584TEST(Simplify, SimplifyCasts) {
585 VarHandle x("x", kLong);
586 ExprHandle body =
587 (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
588
589 ExprHandle simplified = IRSimplifier::simplify(body);
590 MulPtr root = simplified.AsNode<Mul>();
591 ASSERT_NE(root, nullptr);
592 LongImmPtr lhs = to<LongImm>(root->lhs());
593 ASSERT_NE(lhs, nullptr);
594 ASSERT_EQ(lhs->value(), 2);
595 VarPtr rhs = to<Var>(root->rhs());
596 ASSERT_NE(rhs, nullptr);
597 ASSERT_EQ(rhs->name_hint(), "x");
598}
599
600/// (x + 0) * 1 => x
601TEST(Simplify, SimplifyEliminatesNoOps) {
602 VarHandle x("x", kInt);
603 ExprHandle body = (x + ExprHandle(0)) * 1;
604
605 ExprHandle simplified = IRSimplifier::simplify(body);
606 VarPtr root = simplified.AsNode<Var>();
607 ASSERT_NE(root, nullptr);
608 ASSERT_EQ(root->name_hint(), "x");
609}
610
611/// Cannot simplify this.
612TEST(Simplify, SimplifyMultiVar) {
613 VarHandle x("x", kInt);
614 VarHandle y("y", kInt);
615 ExprHandle body = x * 24 + y * 34;
616
617 ExprHandle simplified = IRSimplifier::simplify(body);
618
619 AddPtr root = simplified.AsNode<Add>();
620 ASSERT_NE(root, nullptr);
621 MulPtr lhs = to<Mul>(root->lhs());
622 ASSERT_NE(lhs, nullptr);
623 VarPtr varX = to<Var>(lhs->rhs());
624 ASSERT_NE(varX, nullptr);
625 ASSERT_EQ(varX->name_hint(), "x");
626 MulPtr rhs = to<Mul>(root->rhs());
627 ASSERT_NE(rhs, nullptr);
628 VarPtr varY = to<Var>(rhs->rhs());
629 ASSERT_NE(varY, nullptr);
630 ASSERT_EQ(varY->name_hint(), "y");
631}
632
633// x + 2 + y => x + y + 2
634TEST(Simplify, DISABLED_SimplifyReorderings) {
635 VarHandle x("x", kInt);
636 VarHandle y("y", kInt);
637 ExprHandle body = x + 2 + y;
638 ExprHandle simplified = IRSimplifier::simplify(body);
639
640 AddPtr root = simplified.AsNode<Add>();
641 ASSERT_NE(root, nullptr);
642
643 IS_NODE_WITH_NAME(Add, root->lhs(), rhs);
644 IS_VAR_WITH_NAME(rhs->lhs(), "x");
645 IS_VAR_WITH_NAME(rhs->rhs(), "y");
646 IS_IMM_WITH_VAL(Int, root->rhs(), 2);
647}
648
649/// y + x * 0 => y
650TEST(Simplify, SimplifyEliminatesVar) {
651 VarHandle x("x", kInt);
652 VarHandle y("y", kInt);
653 ExprHandle body = y + x * ExprHandle(0);
654
655 ExprHandle simplified = IRSimplifier::simplify(body);
656 IS_VAR_WITH_NAME(simplified.node(), "y");
657}
658
659TEST(Simplify, SimplifyAdds) {
660 VarHandle x("x", kInt);
661 VarHandle y("y", kInt);
662
663 {
664 // (x + y) + (x + y) => 2 * (x + y)
665 ExprHandle body = (x + y) + (x + y);
666 ExprHandle simplified = IRSimplifier::simplify(body);
667
668 IS_NODE_WITH_NAME(Mul, simplified.node(), root);
669 IS_IMM_WITH_VAL(Int, root->lhs(), 2);
670 IS_NODE_WITH_NAME(Add, root->rhs(), add);
671 IS_VAR_WITH_NAME(add->lhs(), "x");
672 IS_VAR_WITH_NAME(add->rhs(), "y");
673 }
674
675 {
676 // (x * y) + (x * y) => 2 * (x * y)
677 ExprHandle body = (x * y) + (x * y);
678 ExprHandle simplified = IRSimplifier::simplify(body);
679
680 IS_NODE_WITH_NAME(Mul, simplified.node(), root);
681 IS_IMM_WITH_VAL(Int, root->lhs(), 2);
682 IS_NODE_WITH_NAME(Mul, root->rhs(), mul);
683 IS_VAR_WITH_NAME(mul->lhs(), "x");
684 IS_VAR_WITH_NAME(mul->rhs(), "y");
685 }
686
687 {
688 // (x - y) + (x - y) => 2 * (x - y)
689 ExprHandle body = (x - y) + (x - y);
690 ExprHandle simplified = IRSimplifier::simplify(body);
691
692 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
693 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
694
695 IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
696 IS_VAR_WITH_NAME(rhs->lhs(), "x");
697 IS_VAR_WITH_NAME(rhs->rhs(), "y");
698 }
699
700 {
701 // (x + x + x + x) => 4 * x
702 ExprHandle body = (x + x + x + x);
703 ExprHandle simplified = IRSimplifier::simplify(body);
704
705 IS_NODE_WITH_NAME(Mul, simplified.node(), root);
706 IS_IMM_WITH_VAL(Int, root->lhs(), 4);
707 IS_VAR_WITH_NAME(root->rhs(), "x");
708 }
709
710 {
711 // (x + 0) => x.
712 ExprHandle body = x + 0;
713 ExprHandle simplified = IRSimplifier::simplify(body);
714
715 IS_VAR_WITH_NAME(simplified.node(), "x");
716 }
717
718 {
719 // (x + 0.f) => float(x).
720 ExprHandle body = x + 0.f;
721 ExprHandle simplified = IRSimplifier::simplify(body);
722
723 IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
724 ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
725 IS_VAR_WITH_NAME(cast->src_value(), "x");
726 }
727}
728
729TEST(Simplify, SimplifyMuls) {
730 VarHandle x("x", kInt);
731 VarHandle y("y", kInt);
732
733 {
734 // (x + y) * (x + y) => (x + y) * (x + y)
735 // We don't attempt to simplify mulitplication of polynomials since the
736 // result is only very rarely more efficient.
737 ExprHandle body = (x + y) * (x + y);
738 ExprHandle simplified = IRSimplifier::simplify(body);
739
740 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
741 IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
742 IS_VAR_WITH_NAME(lhs->lhs(), "x");
743 IS_VAR_WITH_NAME(lhs->rhs(), "y");
744 IS_NODE_WITH_NAME(Add, mul->rhs(), rhs);
745 IS_VAR_WITH_NAME(rhs->lhs(), "x");
746 IS_VAR_WITH_NAME(rhs->rhs(), "y");
747 }
748
749 {
750 // x * y * x * y => x * x * y * y
751 // These get reordered only.
752 ExprHandle body = x * y * x * y;
753 ExprHandle simplified = IRSimplifier::simplify(body);
754
755 IS_NODE_WITH_NAME(Mul, simplified.node(), mul1);
756 IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2);
757 IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3);
758 IS_VAR_WITH_NAME(mul1->rhs(), "y");
759 IS_VAR_WITH_NAME(mul2->rhs(), "y");
760 IS_VAR_WITH_NAME(mul3->lhs(), "x");
761 IS_VAR_WITH_NAME(mul3->rhs(), "x");
762 }
763
764 {
765 // 1 * (x * 1) => x
766 // Ones cancel cleanly.
767 ExprHandle body = ExprHandle(1) * (x * ExprHandle(1));
768 ExprHandle simplified = IRSimplifier::simplify(body);
769
770 IS_VAR_WITH_NAME(simplified.node(), "x");
771 }
772
773 {
774 // 1.f * (x * 1.f) => x
775 // Even float ones cancel cleanly, but carry their type.
776 ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(1.f));
777 ExprHandle simplified = IRSimplifier::simplify(body);
778
779 IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
780 ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
781 IS_VAR_WITH_NAME(cast->src_value(), "x");
782 }
783
784 {
785 // 1 * (x * 1.f) => x
786 // One float is enough to cast the expr.
787 ExprHandle body = ExprHandle(1) * (x * ExprHandle(1.f));
788 ExprHandle simplified = IRSimplifier::simplify(body);
789
790 IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
791 ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
792 IS_VAR_WITH_NAME(cast->src_value(), "x");
793 }
794
795 {
796 // 1 * (x * 0) => 0
797 // Zeroes are eliminated.
798 ExprHandle body = ExprHandle(1) * (x * ExprHandle(0));
799 ExprHandle simplified = IRSimplifier::simplify(body);
800
801 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
802 }
803
804 {
805 // 1 * (x * 0) => 0
806 // But not for Float since nan * 0 = nan.
807 ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(0.f));
808 ExprHandle simplified = IRSimplifier::simplify(body);
809
810 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
811 IS_NODE_WITH_NAME(Cast, mul->lhs(), cast);
812 ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
813 IS_VAR_WITH_NAME(cast->src_value(), "x");
814 IS_IMM_WITH_VAL(Float, mul->rhs(), 0.0);
815 }
816
817 {
818 // (x - y) * (x - y) => (x - y) * (x - y)
819 // As with Add we don't attempt simplification of this.
820 ExprHandle body = (x - y) * (x - y);
821 ExprHandle simplified = IRSimplifier::simplify(body);
822
823 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
824 IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs);
825 IS_VAR_WITH_NAME(lhs->lhs(), "x");
826 IS_VAR_WITH_NAME(lhs->rhs(), "y");
827 IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
828 IS_VAR_WITH_NAME(rhs->lhs(), "x");
829 IS_VAR_WITH_NAME(rhs->rhs(), "y");
830 }
831
832 {
833 // (x + y) * (x - y) => (x + y) * (x - y)
834 // Don't simplify with different ops on each side.
835 ExprHandle body = (x + y) * (x - y);
836 ExprHandle simplified = IRSimplifier::simplify(body);
837 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
838 IS_NODE_WITH_NAME(Add, mul->lhs(), lhs);
839 IS_VAR_WITH_NAME(lhs->lhs(), "x");
840 IS_VAR_WITH_NAME(lhs->rhs(), "y");
841 IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs);
842 IS_VAR_WITH_NAME(rhs->lhs(), "x");
843 IS_VAR_WITH_NAME(rhs->rhs(), "y");
844 }
845
846 {
847 // Multiply a polynomial by a term.
848 // - term with no scalar, poly with non-identity scalar.
849 // x * (y + 1) => x + x * y
850 ExprHandle body = x * (y + ExprHandle(1));
851 ExprHandle simplified = IRSimplifier::simplify(body);
852
853 IS_NODE_WITH_NAME(Add, simplified.node(), add);
854 IS_VAR_WITH_NAME(add->lhs(), "x");
855 IS_NODE_WITH_NAME(Mul, add->rhs(), mul);
856 IS_VAR_WITH_NAME(mul->lhs(), "x");
857 IS_VAR_WITH_NAME(mul->rhs(), "y");
858 }
859
860 {
861 // Multiply a polynomial by a term.
862 // - term with identity scalar, poly with non-identity scalar.
863 // (x * 1) * (y + 1) => x + x * y
864 ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(1));
865 ExprHandle simplified = IRSimplifier::simplify(body);
866
867 IS_NODE_WITH_NAME(Add, simplified.node(), add);
868 IS_VAR_WITH_NAME(add->lhs(), "x");
869 IS_NODE_WITH_NAME(Mul, add->rhs(), mul);
870 IS_VAR_WITH_NAME(mul->lhs(), "x");
871 IS_VAR_WITH_NAME(mul->rhs(), "y");
872 }
873
874 {
875 // Multiply a polynomial by a term.
876 // - term with non-identity scalar, poly with non-identity scalar.
877 // (x * 2) * (y + 1) => 2 * (x + x * y)
878 ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(1));
879 ExprHandle simplified = IRSimplifier::simplify(body);
880
881 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
882 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
883 IS_NODE_WITH_NAME(Add, mul->rhs(), add);
884 IS_VAR_WITH_NAME(add->lhs(), "x");
885 IS_NODE_WITH_NAME(Mul, add->rhs(), mul2);
886 IS_VAR_WITH_NAME(mul2->lhs(), "x");
887 IS_VAR_WITH_NAME(mul2->rhs(), "y");
888 }
889
890 {
891 // Multiply a polynomial by a term.
892 // - term with non-identity scalar, poly with identity scalar.
893 // (x * 2) * (y + 0) => 2 * (x * y)
894 ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(0));
895 ExprHandle simplified = IRSimplifier::simplify(body);
896
897 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
898 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
899 IS_NODE_WITH_NAME(Mul, mul->rhs(), mul2);
900 IS_VAR_WITH_NAME(mul2->lhs(), "x");
901 IS_VAR_WITH_NAME(mul2->rhs(), "y");
902 }
903
904 {
905 // Multiply a polynomial by a term.
906 // - term with identity scalar, poly with identity scalar.
907 // (x * 1) * (y + 0) => x * y
908 ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(0));
909 ExprHandle simplified = IRSimplifier::simplify(body);
910
911 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
912 IS_VAR_WITH_NAME(mul->lhs(), "x");
913 IS_VAR_WITH_NAME(mul->rhs(), "y");
914 }
915
916 {
917 // Multiply a polynomial by a term.
918 // - term with no scalar, poly with identity scalar.
919 // x * (y + 0) => x * y
920 ExprHandle body = x * (y + ExprHandle(0));
921 ExprHandle simplified = IRSimplifier::simplify(body);
922
923 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
924 IS_VAR_WITH_NAME(mul->lhs(), "x");
925 IS_VAR_WITH_NAME(mul->rhs(), "y");
926 }
927}
928
929// Sub an expr from itself will result in zero.
930TEST(Simplify, SimplifySubs) {
931 VarHandle x("x", kInt);
932 VarHandle y("y", kInt);
933
934 {
935 // (x + y) - (x + y) => 0
936 ExprHandle body = (x + y) - (x + y);
937 ExprHandle simplified = IRSimplifier::simplify(body);
938 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
939 }
940
941 {
942 // (x * y) - (x * y) => 0
943 ExprHandle body = (x * y) - (x * y);
944 ExprHandle simplified = IRSimplifier::simplify(body);
945 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
946 }
947
948 {
949 // (x - y) - (x - y) => 0
950 ExprHandle body = (x - y) - (x - y);
951 ExprHandle simplified = IRSimplifier::simplify(body);
952 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
953 }
954
955 {
956 // (x + y) - 2 * (x + y) => -1 * x - y
957 ExprHandle body = (x + y) - ExprHandle(2) * (x + y);
958 ExprHandle simplified = IRSimplifier::simplify(body);
959
960 IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
961 IS_NODE_WITH_NAME(Mul, sub->lhs(), mul);
962 IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
963 IS_VAR_WITH_NAME(mul->rhs(), "x");
964 IS_VAR_WITH_NAME(sub->rhs(), "y");
965 }
966
967 {
968 // (x + y) - y => x
969 ExprHandle body = (x + y) - y;
970 ExprHandle simplified = IRSimplifier::simplify(body);
971 IS_VAR_WITH_NAME(simplified.node(), "x");
972 }
973
974 {
975 // (x - 0) => x.
976 ExprHandle body = x - 0;
977 ExprHandle simplified = IRSimplifier::simplify(body);
978 IS_VAR_WITH_NAME(simplified.node(), "x");
979 }
980
981 {
982 // (x - 0.f) => x.
983 // Simple enough to cancel in float.
984 ExprHandle body = x - ExprHandle(0.f);
985 ExprHandle simplified = IRSimplifier::simplify(body);
986 IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
987 ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
988 IS_VAR_WITH_NAME(cast->src_value(), "x");
989 }
990
991 {
992 // (x - (float)(y - y)) => x.
993 ExprHandle body = x - Cast::make(kFloat, y - y);
994 ExprHandle simplified = IRSimplifier::simplify(body);
995 IS_NODE_WITH_NAME(Cast, simplified.node(), cast);
996 ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float);
997 IS_VAR_WITH_NAME(cast->src_value(), "x");
998 }
999
1000 {
1001 // (x - y) - y => x - 2 * y
1002 ExprHandle body = (x - y) - y;
1003 ExprHandle simplified = IRSimplifier::simplify(body);
1004
1005 IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
1006 IS_VAR_WITH_NAME(sub->lhs(), "x");
1007 IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
1008 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
1009 IS_VAR_WITH_NAME(mul->rhs(), "y");
1010 }
1011
1012 {
1013 // 2 * x - x => x
1014 ExprHandle body = (ExprHandle(2) * x) - x;
1015 ExprHandle simplified = IRSimplifier::simplify(body);
1016 IS_VAR_WITH_NAME(simplified.node(), "x");
1017 }
1018
1019 {
1020 // x - 2 * x = -1 * x
1021 // We don't have a unary negate, but this could be 0 -x I guess?
1022 ExprHandle body = x - (ExprHandle(2) * x);
1023 ExprHandle simplified = IRSimplifier::simplify(body);
1024 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1025
1026 IS_IMM_WITH_VAL(Int, mul->lhs(), -1);
1027 IS_VAR_WITH_NAME(mul->rhs(), "x");
1028 }
1029
1030 {
1031 // (x + y + 5) * (x - x) => 0
1032 // Cancelling out one side of Mul cancels both.
1033 ExprHandle body = (x + y + 5) * (x - x);
1034 ExprHandle simplified = IRSimplifier::simplify(body);
1035
1036 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1037 }
1038
1039 {
1040 // Cancel out opaque modulus.
1041 ExprHandle body = (x % y + 2) - (x % y);
1042 ExprHandle simplified = IRSimplifier::simplify(body);
1043 IS_IMM_WITH_VAL(Int, simplified.node(), 2);
1044 }
1045
1046 {
1047 // Cancel out opaque modulus with a bit more going on.
1048 ExprHandle body = (x % y + (x * 2 - x - y * 0) - x + 2) - (x % y);
1049 ExprHandle simplified = IRSimplifier::simplify(body);
1050 IS_IMM_WITH_VAL(Int, simplified.node(), 2);
1051 }
1052
1053 {
1054 // Sub where result is negative.
1055 ExprHandle body = x - (x + 1);
1056 ExprHandle simplified = IRSimplifier::simplify(body);
1057 IS_IMM_WITH_VAL(Int, simplified.node(), -1);
1058 }
1059
1060 {
1061 // Sub where result is positive due to negative scalar on RHS.
1062 ExprHandle body = x - (x - 1);
1063 ExprHandle simplified = IRSimplifier::simplify(body);
1064 IS_IMM_WITH_VAL(Int, simplified.node(), 1);
1065 }
1066
1067 {
1068 // Term - Polynomial sub where RHS must be negated.
1069 ExprHandle body = (x * 2) - (x * 2 + 1);
1070 ExprHandle simplified = IRSimplifier::simplify(body);
1071 IS_IMM_WITH_VAL(Int, simplified.node(), -1);
1072 }
1073
1074 {
1075 // Term - Polynomial sub where the result is a Term.
1076 ExprHandle body = (y * x * 2) - (x * y);
1077 ExprHandle simplified = IRSimplifier::simplify(body);
1078 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1079
1080 IS_VAR_WITH_NAME(mul->lhs(), "x");
1081 IS_VAR_WITH_NAME(mul->rhs(), "y");
1082 }
1083
1084 {
1085 // Term - Polynomial sub where the result is a Polynomial.
1086 ExprHandle body = (x * 2) - (x + 1);
1087 ExprHandle simplified = IRSimplifier::simplify(body);
1088 IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
1089
1090 IS_VAR_WITH_NAME(sub->lhs(), "x");
1091 IS_IMM_WITH_VAL(Int, sub->rhs(), 1);
1092 }
1093}
1094
1095TEST(Simplify, SimplifyDiv) {
1096 VarHandle x("x", kInt);
1097
1098 {
1099 ExprHandle body = ExprHandle(0) / x;
1100 ExprHandle simplified = IRSimplifier::simplify(body);
1101
1102 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1103 }
1104
1105 {
1106 ExprHandle body = x / 1;
1107 ExprHandle simplified = IRSimplifier::simplify(body);
1108
1109 IS_VAR_WITH_NAME(simplified.node(), "x");
1110 }
1111}
1112
1113TEST(Simplify, SimplifyDivWithLoopContext0) {
1114 // Stmt to simplify:
1115 // for (int i = 0; i < 100; i++) {
1116 // A[i] = i / 100;
1117 //}
1118 VarHandle i("i", kInt);
1119 BufHandle a_buf("A", {100}, kInt);
1120 auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i / 100)));
1121
1122 const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
1123
1124 std::ostringstream oss;
1125 oss << *(simplified);
1126 const std::string& verification_pattern =
1127 R"IR(
1128# CHECK: for (int i
1129# CHECK-NEXT: A[i] = 0;
1130 )IR";
1131 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1132}
1133
1134TEST(Simplify, SimplifyDivWithLoopContext1) {
1135 // Stmt to simplify:
1136 // for (const auto i : c10::irange(6)) {
1137 // A[i] = (i + 24) / 6;
1138 //}
1139 VarHandle i("i", kInt);
1140 BufHandle a_buf("A", {6}, kInt);
1141 auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / 6));
1142
1143 const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
1144
1145 std::ostringstream oss;
1146 oss << *(simplified);
1147 const std::string& verification_pattern =
1148 R"IR(
1149# CHECK: for (int i
1150# CHECK-NEXT: A[i] = 4;
1151 )IR";
1152 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1153}
1154
1155TEST(Simplify, SimplifyDivWithLoopContext2) {
1156 // Stmt to simplify:
1157 // for (const auto i : c10::irange(5)) {
1158 // A[i] = (i + 25) / 6;
1159 //}
1160 VarHandle i("i", kInt);
1161 BufHandle a_buf("A", {5}, kInt);
1162 auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) / 6));
1163
1164 const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
1165
1166 std::ostringstream oss;
1167 oss << *(simplified);
1168 const std::string& verification_pattern =
1169 R"IR(
1170# CHECK: for (int i
1171# CHECK-NEXT: A[i] = 4;
1172 )IR";
1173 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1174}
1175
1176TEST(Simplify, SimplifyDivWithLoopContext3) {
1177 // Stmt to simplify:
1178 // for (const auto i : c10::irange(6)) {
1179 // A[i] = (i + 24) / (-6);
1180 //}
1181 VarHandle i("i", kInt);
1182 BufHandle a_buf("A", {6}, kInt);
1183 auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / (-6)));
1184
1185 const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
1186
1187 std::ostringstream oss;
1188 oss << *(simplified);
1189 const std::string& verification_pattern =
1190 R"IR(
1191# CHECK: for (int i
1192# CHECK-NOT: A[i] = -4;
1193 )IR";
1194 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1195}
1196
1197TEST(Simplify, SimplifyDivWithLoopContext4) {
1198 // Stmt to simplify:
1199 // for (const auto i : c10::irange(5)) {
1200 // A[i] = (i - 5) / 6;
1201 //}
1202 VarHandle i("i", kInt);
1203 BufHandle a_buf("A", {5}, kInt);
1204 auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) / 6));
1205
1206 const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
1207
1208 std::ostringstream oss;
1209 oss << *(simplified);
1210 const std::string& verification_pattern =
1211 R"IR(
1212# CHECK: for (int i
1213# CHECK-NOT: A[i] = 0;
1214 )IR";
1215 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1216}
1217
1218TEST(Simplify, SimplifyDivWithLoopContext5) {
1219 // Stmt to simplify:
1220 // for (const auto i : c10::irange(6)) {
1221 // for (const auto j : c10::irange(10)) {
1222 // A[i, j] = (i + 6*j) / 6;
1223 // }
1224 //}
1225 VarHandle i("i", kInt);
1226 VarHandle j("j", kInt);
1227 BufHandle a_buf("A", {6, 10}, kInt);
1228 auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / 6));
1229 auto for_i = For::make(i, 0, 6, for_j);
1230
1231 const StmtPtr simplified = IRSimplifier::simplify(for_i);
1232
1233 std::ostringstream oss;
1234 oss << *(simplified);
1235 const std::string& verification_pattern =
1236 R"IR(
1237# CHECK: for (int i
1238# CHECK: for (int j
1239# CHECK-NEXT: A[i, j] = j;
1240 )IR";
1241 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1242}
1243
1244TEST(Simplify, SimplifyDivWithLoopContext6) {
1245 // Stmt to simplify:
1246 // for (const auto i : c10::irange(6)) {
1247 // for (int j = -1; j < 9; j++) {
1248 // A[i, j+1] = (i + 6*j) / 6;
1249 // }
1250 //}
1251 VarHandle i("i", kInt);
1252 VarHandle j("j", kInt);
1253 BufHandle a_buf("A", {6, 10}, kInt);
1254 auto for_j =
1255 For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6));
1256 auto for_i = For::make(i, 0, 6, for_j);
1257
1258 const StmtPtr simplified = IRSimplifier::simplify(for_i);
1259
1260 std::ostringstream oss;
1261 oss << *(simplified);
1262 const std::string& verification_pattern =
1263 R"IR(
1264# CHECK: for (int i
1265# CHECK: for (int j
1266# CHECK-NOT: A[i, j] = j;
1267 )IR";
1268 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1269}
1270
1271TEST(Simplify, SimplifyDivWithLoopContext7) {
1272 // Stmt to simplify:
1273 // for (const auto i : c10::irange(6)) {
1274 // for (const auto j : c10::irange(10)) {
1275 // A[i, j] = (i + 6*j) / (-6);
1276 // }
1277 //}
1278 VarHandle i("i", kInt);
1279 VarHandle j("j", kInt);
1280 BufHandle a_buf("A", {6, 10}, kInt);
1281 auto for_j =
1282 For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / (-6)));
1283 auto for_i = For::make(i, 0, 6, for_j);
1284
1285 const StmtPtr simplified = IRSimplifier::simplify(for_i);
1286
1287 std::ostringstream oss;
1288 oss << *(simplified);
1289 const std::string& verification_pattern =
1290 R"IR(
1291# CHECK: for (int i
1292# CHECK: for (int j
1293# CHECK-NOT: A[i, j] = -j;
1294 )IR";
1295 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1296}
1297
1298TEST(Simplify, SimplifyModWithLoopContext0) {
1299 // Stmt to simplify:
1300 // for (const auto i : c10::irange(100)) {
1301 // A[i] = i % 100;
1302 //}
1303 VarHandle i("i", kInt);
1304 BufHandle a_buf("A", {100}, kInt);
1305 auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i % 100)));
1306
1307 const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
1308
1309 std::ostringstream oss;
1310 oss << *(simplified);
1311 const std::string& verification_pattern =
1312 R"IR(
1313# CHECK: for (int i
1314# CHECK-NEXT: A[i] = i;
1315 )IR";
1316 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1317}
1318
1319TEST(Simplify, SimplifyModWithLoopContext1) {
1320 // Stmt to simplify:
1321 // for (const auto i : c10::irange(6)) {
1322 // A[i] = (i + 24) % 6;
1323 //}
1324 VarHandle i("i", kInt);
1325 BufHandle a_buf("A", {6}, kInt);
1326 auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % 6));
1327
1328 const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
1329
1330 std::ostringstream oss;
1331 oss << *(simplified);
1332 const std::string& verification_pattern =
1333 R"IR(
1334# CHECK: for (int i
1335# CHECK-NEXT: A[i] = i;
1336 )IR";
1337 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1338}
1339
1340TEST(Simplify, SimplifyModWithLoopContext2) {
1341 // Stmt to simplify:
1342 // for (const auto i : c10::irange(5)) {
1343 // A[i] = (i + 25) % 6;
1344 //}
1345 VarHandle i("i", kInt);
1346 BufHandle a_buf("A", {5}, kInt);
1347 auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) % 6));
1348
1349 const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
1350
1351 std::ostringstream oss;
1352 oss << *(simplified);
1353 const std::string& verification_pattern =
1354 R"IR(
1355# CHECK: for (int i
1356# CHECK-NEXT: A[i] = i + 1;
1357 )IR";
1358 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1359}
1360
1361TEST(Simplify, SimplifyModWithLoopContext3) {
1362 // Stmt to simplify:
1363 // for (const auto i : c10::irange(6)) {
1364 // A[i] = (i + 24) % (-6);
1365 //}
1366 VarHandle i("i", kInt);
1367 BufHandle a_buf("A", {6}, kInt);
1368 auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % (-6)));
1369
1370 const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
1371
1372 std::ostringstream oss;
1373 oss << *(simplified);
1374 const std::string& verification_pattern =
1375 R"IR(
1376# CHECK: for (int i
1377# CHECK-NOT: A[i] = i;
1378 )IR";
1379 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1380}
1381
1382TEST(Simplify, SimplifyModWithLoopContext4) {
1383 // Stmt to simplify:
1384 // for (const auto i : c10::irange(5)) {
1385 // A[i] = (i - 5) % 6;
1386 //}
1387 VarHandle i("i", kInt);
1388 BufHandle a_buf("A", {5}, kInt);
1389 auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) % 6));
1390
1391 const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
1392
1393 std::ostringstream oss;
1394 oss << *(simplified);
1395 const std::string& verification_pattern =
1396 R"IR(
1397# CHECK: for (int i
1398# CHECK-NOT: A[i] = i - 5;
1399 )IR";
1400 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1401}
1402
1403TEST(Simplify, SimplifyModWithLoopContext5) {
1404 // Stmt to simplify:
1405 // for (const auto i : c10::irange(6)) {
1406 // for (const auto j : c10::irange(10)) {
1407 // A[i, j] = (i + 6*j) % 6;
1408 // }
1409 //}
1410 VarHandle i("i", kInt);
1411 VarHandle j("j", kInt);
1412 BufHandle a_buf("A", {6, 10}, kInt);
1413 auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % 6));
1414 auto for_i = For::make(i, 0, 6, for_j);
1415
1416 const StmtPtr simplified = IRSimplifier::simplify(for_i);
1417
1418 std::ostringstream oss;
1419 oss << *(simplified);
1420 const std::string& verification_pattern =
1421 R"IR(
1422# CHECK: for (int i
1423# CHECK: for (int j
1424# CHECK-NEXT: A[i, j] = i;
1425 )IR";
1426 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1427}
1428
1429TEST(Simplify, SimplifyModWithLoopContext6) {
1430 // Stmt to simplify:
1431 // for (const auto i : c10::irange(6)) {
1432 // for (int j = -1; j < 9; j++) {
1433 // A[i, j+1] = (i + 6*j) % 6;
1434 // }
1435 //}
1436 VarHandle i("i", kInt);
1437 VarHandle j("j", kInt);
1438 BufHandle a_buf("A", {6, 10}, kInt);
1439 auto for_j =
1440 For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6));
1441 auto for_i = For::make(i, 0, 6, for_j);
1442
1443 const StmtPtr simplified = IRSimplifier::simplify(for_i);
1444
1445 std::ostringstream oss;
1446 oss << *(simplified);
1447 const std::string& verification_pattern =
1448 R"IR(
1449# CHECK: for (int i
1450# CHECK: for (int j
1451# CHECK-NOT: A[i, j] = i;
1452 )IR";
1453 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1454}
1455
1456TEST(Simplify, SimplifyModWithLoopContext7) {
1457 // Stmt to simplify:
1458 // for (const auto i : c10::irange(6)) {
1459 // for (const auto j : c10::irange(10)) {
1460 // A[i, j] = (i + 6*j) % (-6);
1461 // }
1462 //}
1463 VarHandle i("i", kInt);
1464 VarHandle j("j", kInt);
1465 BufHandle a_buf("A", {6, 10}, kInt);
1466 auto for_j =
1467 For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % (-6)));
1468 auto for_i = For::make(i, 0, 6, for_j);
1469
1470 const StmtPtr simplified = IRSimplifier::simplify(for_i);
1471
1472 std::ostringstream oss;
1473 oss << *(simplified);
1474 const std::string& verification_pattern =
1475 R"IR(
1476# CHECK: for (int i
1477# CHECK: for (int j
1478# CHECK-NOT: A[i, j] = i;
1479 )IR";
1480 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1481}
1482
1483TEST(Simplify, SimplifyMod) {
1484 VarHandle x("x", kInt);
1485 VarHandle y("y", kInt);
1486 VarHandle z("z", kInt);
1487
1488 {
1489 // Constant folding works.
1490 ExprHandle body = ExprHandle(10) % 8;
1491 ExprHandle simplified = IRSimplifier::simplify(body);
1492 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
1493 IS_IMM_WITH_VAL(Int, simplified.node(), 2);
1494 }
1495
1496 {
1497 // x % x => 0
1498 ExprHandle body = x % x;
1499 ExprHandle simplified = IRSimplifier::simplify(body);
1500 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1501 }
1502
1503 {
1504 // 0 % x => 0
1505 ExprHandle body = ExprHandle(0) % x;
1506 ExprHandle simplified = IRSimplifier::simplify(body);
1507 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1508 }
1509
1510 {
1511 // x % 1 => 0
1512 ExprHandle body = x % 1;
1513 ExprHandle simplified = IRSimplifier::simplify(body);
1514 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1515 }
1516
1517 {
1518 // Doesn't change unknown mods.
1519 // x % y => x % y
1520 ExprHandle body = x % y;
1521 ExprHandle simplified = IRSimplifier::simplify(body);
1522 IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1523 IS_VAR_WITH_NAME(mod->lhs(), "x");
1524 IS_VAR_WITH_NAME(mod->rhs(), "y");
1525 }
1526
1527 {
1528 // don't touch if RHS is unknown.
1529 // 4 % x => 4 % x
1530 ExprHandle body = ExprHandle(4) % x;
1531 ExprHandle simplified = IRSimplifier::simplify(body);
1532 IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1533 IS_IMM_WITH_VAL(Int, mod->lhs(), 4);
1534 IS_VAR_WITH_NAME(mod->rhs(), "x");
1535 }
1536
1537 {
1538 // don't touch if LHS is unknown.
1539 // x % 4 => x % 4
1540 ExprHandle body = x % 4;
1541 ExprHandle simplified = IRSimplifier::simplify(body);
1542 IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1543 IS_VAR_WITH_NAME(mod->lhs(), "x");
1544 IS_IMM_WITH_VAL(Int, mod->rhs(), 4);
1545 }
1546
1547 {
1548 // if LHS is a multiple of RHS, mod is zero.
1549 // 2 * x % x => 0
1550 ExprHandle body = (x * 2) % x;
1551 ExprHandle simplified = IRSimplifier::simplify(body);
1552 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1553 }
1554
1555 {
1556 // true even if the multiple is not constant.
1557 // x * y % x => 0
1558 ExprHandle body = (x * y) % x;
1559 ExprHandle simplified = IRSimplifier::simplify(body);
1560 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1561 }
1562
1563 {
1564 // true with multiple unknown values in LHS.
1565 // x * y * z % x => 0
1566 ExprHandle body = (x * y * z) % x;
1567 ExprHandle simplified = IRSimplifier::simplify(body);
1568 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1569 }
1570
1571 {
1572 // true if the denom is compound.
1573 // x * y * z % y * z => 0
1574 ExprHandle body = (x * y * z) % (y * z);
1575 ExprHandle simplified = IRSimplifier::simplify(body);
1576 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1577 }
1578
1579 {
1580 // Sanity check true with scalars that are multiples.
1581 // 12 * x % 4 => 0
1582 ExprHandle body = (x * 12) % 4;
1583 ExprHandle simplified = IRSimplifier::simplify(body);
1584 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1585 }
1586
1587 {
1588 // Sanity check not true if the smaller scalar is on LHS.
1589 // 4 * x % 12 => 4 * x % 12
1590 ExprHandle body = (x * 4) % 12;
1591 ExprHandle simplified = IRSimplifier::simplify(body);
1592 IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1593 IS_NODE_WITH_NAME(Mul, mod->lhs(), mul);
1594 IS_IMM_WITH_VAL(Int, mul->lhs(), 4);
1595 IS_VAR_WITH_NAME(mul->rhs(), "x");
1596 IS_IMM_WITH_VAL(Int, mod->rhs(), 12);
1597 }
1598
1599 {
1600 // Both scalar and symbolic in multiple.
1601 // (6 * x * y) % (3 * x * y) => 0
1602 ExprHandle body = (ExprHandle(6) * x * y) % (x * y * 3);
1603 ExprHandle simplified = IRSimplifier::simplify(body);
1604 IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1605 }
1606}
1607
1608// Test that mixing ops together simplifies as expected.
1609TEST(Simplify, SimplifyMultiOp) {
1610 VarHandle x("x", kInt);
1611 VarHandle y("y", kInt);
1612
1613 {
1614 // (x * y) + (x - y) => (x + x * y) - y
1615 ExprHandle body = (x * y) + (x - y);
1616 ExprHandle simplified = IRSimplifier::simplify(body);
1617
1618 IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
1619 IS_NODE_WITH_NAME(Add, sub->lhs(), add);
1620 IS_VAR_WITH_NAME(add->lhs(), "x");
1621 IS_NODE_WITH_NAME(Mul, add->rhs(), mul);
1622 IS_VAR_WITH_NAME(mul->lhs(), "x");
1623 IS_VAR_WITH_NAME(mul->rhs(), "y");
1624 IS_VAR_WITH_NAME(sub->rhs(), "y");
1625 }
1626
1627 {
1628 // (x + y) - x * y => (x + y) - x * y
1629 ExprHandle body = (x + y) - x * y;
1630 ExprHandle simplified = IRSimplifier::simplify(body);
1631 IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
1632 IS_NODE_WITH_NAME(Add, sub->lhs(), add);
1633 IS_NODE_WITH_NAME(Mul, sub->rhs(), mul);
1634 IS_VAR_WITH_NAME(add->lhs(), "x");
1635 IS_VAR_WITH_NAME(add->rhs(), "y");
1636 IS_VAR_WITH_NAME(mul->lhs(), "x");
1637 IS_VAR_WITH_NAME(mul->rhs(), "y");
1638 }
1639
1640 {
1641 // (x - y) - (x + y) => -2 * y
1642 ExprHandle body = (x - y) - (x + y);
1643 ExprHandle simplified = IRSimplifier::simplify(body);
1644
1645 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1646 IS_IMM_WITH_VAL(Int, mul->lhs(), -2);
1647 IS_VAR_WITH_NAME(mul->rhs(), "y");
1648 }
1649
1650 {
1651 // (x - 0) + (x * 1) - (x + 0) => x
1652 ExprHandle body = (x - 0) + (x * 1) - (x + 0);
1653 ExprHandle simplified = IRSimplifier::simplify(body);
1654
1655 IS_VAR_WITH_NAME(simplified.node(), "x");
1656 }
1657
1658 {
1659 // (x - 0.f) + (x * 1.f) - (x + 0.f) => float(x) + float(x) - float(x)
1660 // Even in Float simple terms cancel out, but the variable ones cannot.
1661 ExprHandle body =
1662 (x - ExprHandle(0.f)) + (x * ExprHandle(1.f)) - (x + ExprHandle(0.f));
1663 ExprHandle simplified = IRSimplifier::simplify(body);
1664
1665 IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
1666 IS_NODE_WITH_NAME(Add, sub->lhs(), add);
1667 IS_NODE_WITH_NAME(Cast, add->lhs(), cast1);
1668 IS_VAR_WITH_NAME(cast1->src_value(), "x");
1669 IS_NODE_WITH_NAME(Cast, add->rhs(), cast2);
1670 IS_VAR_WITH_NAME(cast2->src_value(), "x");
1671 IS_NODE_WITH_NAME(Cast, sub->rhs(), cast3);
1672 IS_VAR_WITH_NAME(cast3->src_value(), "x");
1673 }
1674}
1675
1676// Test that chaining many ops together works as expected.
1677TEST(Simplify, SimplifyManyOps) {
1678 VarHandle x("x", kInt);
1679 VarHandle y("y", kInt);
1680
1681 {
1682 // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x
1683 ExprHandle body = x + y + x + x + y + y + x + y + x;
1684 ExprHandle simplified = IRSimplifier::simplify(body);
1685
1686 IS_NODE_WITH_NAME(Add, simplified.node(), add);
1687
1688 IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
1689 IS_IMM_WITH_VAL(Int, lhs->lhs(), 4);
1690 IS_VAR_WITH_NAME(lhs->rhs(), "y");
1691
1692 IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
1693 IS_IMM_WITH_VAL(Int, rhs->lhs(), 5);
1694 IS_VAR_WITH_NAME(rhs->rhs(), "x");
1695 }
1696
1697 {
1698 // x - y + x + x - y - y + x - y + x = 5 * x - 4 * y
1699 ExprHandle body = x - y + x + x - y - y + x - y + x;
1700 ExprHandle simplified = IRSimplifier::simplify(body);
1701
1702 IS_NODE_WITH_NAME(Sub, simplified.node(), add);
1703
1704 IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
1705 IS_IMM_WITH_VAL(Int, lhs->lhs(), 5);
1706 IS_VAR_WITH_NAME(lhs->rhs(), "x");
1707
1708 IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
1709 IS_IMM_WITH_VAL(Int, rhs->lhs(), 4);
1710 IS_VAR_WITH_NAME(rhs->rhs(), "y");
1711 }
1712
1713 {
1714 // x + y + x - x - y - y + x + y + x = 3 * x
1715 ExprHandle body = x + y + x - x - y - y + x + y + x;
1716 ExprHandle simplified = IRSimplifier::simplify(body);
1717
1718 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1719 IS_IMM_WITH_VAL(Int, mul->lhs(), 3);
1720 IS_VAR_WITH_NAME(mul->rhs(), "x");
1721 }
1722}
1723
1724TEST(Simplify, SimplifyFactorization) {
1725 VarHandle x("x", kInt);
1726 VarHandle y("y", kInt);
1727
1728 {
1729 // (2 * x) + (2 * y) => 2 * (x + y)
1730 ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y);
1731 ExprHandle simplified = IRSimplifier::simplify(body);
1732
1733 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1734 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
1735
1736 IS_NODE_WITH_NAME(Add, mul->rhs(), add);
1737 IS_VAR_WITH_NAME(add->lhs(), "x");
1738 IS_VAR_WITH_NAME(add->rhs(), "y");
1739 }
1740
1741 {
1742 // Factorization when scalars have common divider.
1743 // (2 * x) + (4 * y) => 2 * (2 * y + x)
1744 ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y);
1745 ExprHandle simplified = IRSimplifier::simplify(body);
1746
1747 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1748 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
1749
1750 IS_NODE_WITH_NAME(Add, mul->rhs(), add);
1751 IS_VAR_WITH_NAME(add->lhs(), "x");
1752 IS_NODE_WITH_NAME(Mul, add->rhs(), mul2);
1753 IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
1754 IS_VAR_WITH_NAME(mul2->rhs(), "y");
1755 }
1756
1757 {
1758 // Factorization attempt without a common divider.
1759 // (2 * x) + (5 * y) => (5 * y) + (2 * x)
1760 ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y);
1761 ExprHandle simplified = IRSimplifier::simplify(body);
1762
1763 IS_NODE_WITH_NAME(Add, simplified.node(), add);
1764
1765 IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
1766 IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
1767 IS_VAR_WITH_NAME(lhs->rhs(), "x");
1768
1769 IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
1770 IS_IMM_WITH_VAL(Int, rhs->lhs(), 5);
1771 IS_VAR_WITH_NAME(rhs->rhs(), "y");
1772 }
1773
1774 {
1775 // Factorization after merging.
1776 // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y)
1777 ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) +
1778 (ExprHandle(8) * x + ExprHandle(6) * y);
1779 ExprHandle simplified = IRSimplifier::simplify(body);
1780
1781 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1782 IS_IMM_WITH_VAL(Int, mul->lhs(), 10);
1783
1784 IS_NODE_WITH_NAME(Add, mul->rhs(), add);
1785 IS_VAR_WITH_NAME(add->lhs(), "x");
1786 IS_VAR_WITH_NAME(add->rhs(), "y");
1787 }
1788
1789 {
1790 // Factorization with common divider but different signs.
1791 // (2 * x) + (-4 * y) => 2 * (x - 2 * y)
1792 ExprHandle body = (ExprHandle(2) * x + ExprHandle(-4) * y);
1793 ExprHandle simplified = IRSimplifier::simplify(body);
1794
1795 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1796 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
1797
1798 IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
1799 IS_VAR_WITH_NAME(sub->lhs(), "x");
1800 IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2);
1801 IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
1802 IS_VAR_WITH_NAME(mul2->rhs(), "y");
1803 }
1804
1805 {
1806 // Factorization with all negative numbers.
1807 // (-2 * x) + (-4 * y) => 2 * (-1 * x - 2 * y)
1808 ExprHandle body = ExprHandle(-2) * x + ExprHandle(-4) * y;
1809 ExprHandle simplified = IRSimplifier::simplify(body);
1810
1811 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1812 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
1813
1814 IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
1815 IS_NODE_WITH_NAME(Mul, sub->lhs(), mul2);
1816 IS_IMM_WITH_VAL(Int, mul2->lhs(), -1);
1817 IS_VAR_WITH_NAME(mul2->rhs(), "x");
1818 IS_NODE_WITH_NAME(Mul, sub->rhs(), mul3);
1819 IS_IMM_WITH_VAL(Int, mul3->lhs(), 2);
1820 IS_VAR_WITH_NAME(mul3->rhs(), "y");
1821 }
1822
1823 {
1824 // The following test ensures that there in no infinite recursion during
1825 // factorization when negative numbers are involved.
1826 VarHandle a("a", kInt);
1827 VarHandle b("b", kInt);
1828 VarHandle c("c", kInt);
1829 VarHandle d("d", kInt);
1830 VarHandle e("e", kInt);
1831 VarHandle f("f", kInt);
1832 VarHandle g("g", kInt);
1833 VarHandle h("h", kInt);
1834
1835 ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 +
1836 f * 32 + g * (-1024) + h * (-32);
1837 ExprHandle simplified = IRSimplifier::simplify(body);
1838 checkExprIR(
1839 simplified,
1840 "((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h");
1841 }
1842}
1843
1844// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x)
1845TEST(Simplify, SimplifyFactorizeUneven) {
1846 VarHandle x("x", kInt);
1847 VarHandle y("y", kInt);
1848 VarHandle z("z", kInt);
1849 ExprHandle body =
1850 (ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4);
1851 ExprHandle simplified = IRSimplifier::simplify(body);
1852
1853 IS_NODE_WITH_NAME(Mul, simplified.node(), root);
1854 IS_IMM_WITH_VAL(Int, root->lhs(), 2);
1855 IS_NODE_WITH_NAME(Add, root->rhs(), add1);
1856 IS_NODE_WITH_NAME(Add, add1->lhs(), add2);
1857
1858 IS_VAR_WITH_NAME(add2->lhs(), "y");
1859 IS_NODE_WITH_NAME(Mul, add2->rhs(), zmul);
1860 IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul);
1861
1862 IS_IMM_WITH_VAL(Int, xmul->lhs(), 4);
1863 IS_VAR_WITH_NAME(xmul->rhs(), "x");
1864
1865 IS_IMM_WITH_VAL(Int, zmul->lhs(), 3);
1866 IS_VAR_WITH_NAME(zmul->rhs(), "z");
1867}
1868
1869// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y)
1870// This is kind of a placeholder test for variable factorization.
1871TEST(Simplify, SimplifyDeeperTerms) {
1872 VarHandle x("x", kInt);
1873 VarHandle y("y", kInt);
1874 ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y);
1875 ExprHandle simplified = IRSimplifier::simplify(body);
1876
1877 IS_NODE_WITH_NAME(Add, simplified.node(), add);
1878
1879 IS_NODE_WITH_NAME(Mul, add->lhs(), lhs);
1880 IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
1881 IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm);
1882 IS_VAR_WITH_NAME(xxTerm->lhs(), "x");
1883 IS_VAR_WITH_NAME(xxTerm->rhs(), "x");
1884
1885 IS_NODE_WITH_NAME(Mul, add->rhs(), rhs);
1886 IS_IMM_WITH_VAL(Int, rhs->lhs(), 3);
1887 IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm);
1888 IS_VAR_WITH_NAME(xyTerm->lhs(), "x");
1889 IS_VAR_WITH_NAME(xyTerm->rhs(), "y");
1890}
1891
1892// Tests the difference between two less trivial expressions.
1893// (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1
1894TEST(Simplify, SimplifyDeeperDifference) {
1895 VarHandle n("n", kInt);
1896 VarHandle n_1("n_1", kInt);
1897 VarHandle m("m", kInt);
1898 ExprHandle body =
1899 (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n);
1900 ExprHandle simplified = IRSimplifier::simplify(body);
1901
1902 IS_IMM_WITH_VAL(Int, simplified.node(), 1);
1903}
1904
1905// Test constant folding into the difference between expressions.
1906// 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3
1907TEST(Simplify, SimplifyFoldComplexDifference) {
1908 VarHandle n("n", kInt);
1909 VarHandle n_1("n_1", kInt);
1910 VarHandle m("m", kInt);
1911 ExprHandle body =
1912 (IntImm::make(2) +
1913 (Cast::make(
1914 kChar,
1915 (m * (ExprHandle(1) * n_1) + (n + 1)) -
1916 (m * (ExprHandle(1) * n_1) + n))));
1917 ExprHandle simplified = IRSimplifier::simplify(body);
1918 IS_IMM_WITH_VAL(Int, simplified.node(), 3);
1919}
1920
1921TEST(Simplify, SimplifyIfComponents) {
1922 VarHandle x("x", kInt);
1923 VarHandle y("y", kInt);
1924 ExprHandle body = IfThenElse::make(
1925 ((ExprHandle(5) - ExprHandle(4)) * x) > y,
1926 ExprHandle(2) * x - x,
1927 ExprHandle(2) * y - y);
1928
1929 ExprHandle simplified = IRSimplifier::simplify(body);
1930
1931 IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr);
1932
1933 IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp);
1934 ASSERT_EQ(cmp->compare_select_op(), kGT);
1935 IS_VAR_WITH_NAME(cmp->lhs(), "x");
1936 IS_VAR_WITH_NAME(cmp->rhs(), "y");
1937
1938 IS_VAR_WITH_NAME(ifexpr->true_value(), "x");
1939 IS_VAR_WITH_NAME(ifexpr->false_value(), "y");
1940}
1941
1942TEST(Simplify, SimplifyOpaqueTerms) {
1943 VarHandle x("x", kInt);
1944 VarHandle y("y", kInt);
1945
1946 {
1947 // 2 * x/y * y - x/y * y => x/y * y
1948 ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y);
1949 ExprHandle simplified = IRSimplifier::simplify(body);
1950
1951 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
1952 IS_NODE_WITH_NAME(Div, mul->lhs(), div);
1953 IS_VAR_WITH_NAME(div->lhs(), "x");
1954 IS_VAR_WITH_NAME(div->rhs(), "y");
1955 IS_VAR_WITH_NAME(mul->rhs(), "y");
1956 }
1957
1958 {
1959 // x%y - (x%y - 1) => 1
1960 ExprHandle body = (x % y) - ((x % y) - 1);
1961 ExprHandle simplified = IRSimplifier::simplify(body);
1962
1963 IS_IMM_WITH_VAL(Int, simplified.node(), 1);
1964 }
1965}
1966
1967TEST(Simplify, SimplifySymbolicMinMax) {
1968 {
1969 // Minimum with constant difference between terms.
1970 VarHandle x("x", kInt);
1971 ExprHandle body = Min::make(x + 3, x + 7, true);
1972 ExprHandle simplified = IRSimplifier::simplify(body);
1973
1974 IS_NODE_WITH_NAME(Add, simplified.node(), add);
1975 IS_VAR_WITH_NAME(add->lhs(), "x");
1976 IS_IMM_WITH_VAL(Int, add->rhs(), 3);
1977 }
1978
1979 {
1980 // Maximum with constant difference between terms.
1981 VarHandle x("x", kInt);
1982 ExprHandle body = Max::make(x + 3, x + 7, true);
1983 ExprHandle simplified = IRSimplifier::simplify(body);
1984
1985 IS_NODE_WITH_NAME(Add, simplified.node(), add);
1986 IS_VAR_WITH_NAME(add->lhs(), "x");
1987 IS_IMM_WITH_VAL(Int, add->rhs(), 7);
1988 }
1989
1990 {
1991 // Can't simplify multiples because of signedness of variable component.
1992 // TODO: maybe we could for unsigned types?
1993 VarHandle x("x", kInt);
1994 ExprHandle body = Max::make(x * 3, x * 7, true);
1995 ExprHandle simplified = IRSimplifier::simplify(body);
1996
1997 IS_NODE(Max, simplified.node());
1998 }
1999}
2000
// Nested Max expressions: repeated operands are removed, constant operands are
// merged, nested chains are flattened into a canonical order, and
// Max(Min(..), Min(..)) sharing an operand is factored -- but only when every
// Max/Min in the pattern has the same propagate_nans flag.
TEST(Simplify, SimplifyNestedMax) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);

  {
    // Max(x + y, x + y) => x + y
    ExprHandle body = Max::make(x + y, x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y");
  }

  {
    // Max(x + y, Max(x + y, z)) => Max(x + y, z)
    ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(max->rhs(), "z");
  }

  {
    // Max(x + y, Max(z, x + y)) => Max(x + y, z)
    ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(max->rhs(), "z");
  }

  {
    // Max(Max(x + y, z), x + y) => Max(x + y, z)
    ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(max->rhs(), "z");
  }

  {
    // Max(Max(z, x + y), x + y) => Max(x + y, z)
    ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(max->rhs(), "z");
  }

  {
    // Max(Max(x, y), x) => Max(Max(x, y), x)
    // Nested Max ops with different propagate_nans should not be simplified.
    ExprHandle body = Max::make(Max::make(x, y, true), x, false);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Max, max->lhs(), max1, "x", "y");
    ASSERT_TRUE(max1->propagate_nans());
    IS_VAR_WITH_NAME(max->rhs(), "x");
    ASSERT_FALSE(max->propagate_nans());
  }

  {
    // Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x)
    ExprHandle body =
        Max::make(Min::make(x, y, true), Min::make(x, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
  }

  {
    // Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x)
    ExprHandle body =
        Max::make(Min::make(x, y, true), Min::make(z, x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
  }

  {
    // Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x)
    ExprHandle body =
        Max::make(Min::make(y, x, true), Min::make(x, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
  }

  {
    // Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x)
    ExprHandle body =
        Max::make(Min::make(y, x, true), Min::make(z, x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)");
  }

  {
    // Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z))
    // When the ops in the pattern do not all have the same propagate_nans,
    // the factoring is not applied; only the operands of each Min are
    // reordered.
    ExprHandle body =
        Max::make(Min::make(y, x, true), Min::make(z, x, false), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max);
    IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "y");
    ASSERT_TRUE(min1->propagate_nans());
    IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "z");
    ASSERT_FALSE(min2->propagate_nans());
    ASSERT_TRUE(max->propagate_nans());
  }

  {
    // Max(5, Max(x, 8)) => Max(x, 8)
    ExprHandle body = Max::make(5, Max::make(x, 8, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
    ASSERT_TRUE(max->propagate_nans());
  }

  {
    // Max(8, Max(x, 5)) => Max(x, 8)
    ExprHandle body = Max::make(8, Max::make(x, 5, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
    ASSERT_TRUE(max->propagate_nans());
  }

  {
    // Max(Max(x, 8), 5) => Max(x, 8)
    ExprHandle body = Max::make(Max::make(x, 8, true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
    ASSERT_TRUE(max->propagate_nans());
  }

  {
    // Max(Max(x, 5), 8) => Max(x, 8)
    ExprHandle body = Max::make(Max::make(x, 5, true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8);
    ASSERT_TRUE(max->propagate_nans());
  }

  {
    // Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        5, Max::make(x, Max::make(y, Max::make(z, 8, true), true), true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(8, Max(Max(y, Max(z, 5)), x)) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        8, Max::make(Max::make(y, Max::make(z, 5, true), true), x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(5, Max(Max(Max(z, 8), y), x)) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        5, Max::make(Max::make(Max::make(z, 8, true), y, true), x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(Max(x, Max(y, Max(5, z))), 8) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        Max::make(x, Max::make(y, Max::make(5, z, true), true), true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(Max(Max(y, Max(z, 8)), x), 5) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        Max::make(Max::make(y, Max::make(z, 8, true), true), x, true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        Max::make(Max::make(Max::make(z, 5, true), y, true), x, true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(Max(Max(z, 5), y), x), 8)
    // Do not simplify when the Max ops do not all have the same
    // propagate_nans. In the printed IR the third argument of each Max is
    // its propagate_nans flag (1 = true, 0 = false).
    ExprHandle body = Max::make(
        Max::make(Max::make(Max::make(z, 5, true), y, false), x, true),
        8,
        false);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)");
  }

  {
    // Max(8, Max(Max(x, 5), Max(y, z))) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        8, Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }

  {
    // Max(Max(Max(x, 5), Max(y, z)), 8) => Max(Max(Max(x, 8), y), z)
    ExprHandle body = Max::make(
        Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Max, simplified.node(), max1);
    IS_NODE_WITH_NAME(Max, max1->lhs(), max2);
    IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8);
    ASSERT_TRUE(max3->propagate_nans());
    IS_VAR_WITH_NAME(max2->rhs(), "y");
    IS_VAR_WITH_NAME(max1->rhs(), "z");
  }
}
2276
// Nested Min expressions: repeated operands are removed, constant operands are
// merged (keeping the smaller constant), nested chains are flattened into a
// canonical order, and Min(Max(..), Max(..)) sharing an operand is factored --
// but only when every Min/Max in the pattern has the same propagate_nans flag.
TEST(Simplify, SimplifyNestedMin) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);

  {
    // Min(x + y, x + y) => x + y
    ExprHandle body = Min::make(x + y, x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y");
  }

  {
    // Min(x + y, Min(x + y, z)) => Min(x + y, z)
    ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min);
    IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(min->rhs(), "z");
  }

  {
    // Min(x + y, Min(z, x + y)) => Min(x + y, z)
    ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min);
    IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(min->rhs(), "z");
  }

  {
    // Min(Min(x + y, z), x + y) => Min(x + y, z)
    ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min);
    IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(min->rhs(), "z");
  }

  {
    // Min(Min(z, x + y), x + y) => Min(x + y, z)
    ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min);
    IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y");
    IS_VAR_WITH_NAME(min->rhs(), "z");
  }

  {
    // Min(Min(x, y), x) => Min(Min(x, y), x)
    // Nested Min ops with different propagate_nans should not be simplified.
    ExprHandle body = Min::make(Min::make(x, y, true), x, false);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_BINOP_W_VARS(Min, min1->lhs(), min2, "x", "y");
    ASSERT_TRUE(min2->propagate_nans());
    IS_VAR_WITH_NAME(min1->rhs(), "x");
    ASSERT_FALSE(min1->propagate_nans());
  }

  {
    // Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x)
    ExprHandle body =
        Min::make(Max::make(x, y, true), Max::make(x, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
  }

  {
    // Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x)
    ExprHandle body =
        Min::make(Max::make(x, y, true), Max::make(z, x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
  }

  {
    // Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x)
    ExprHandle body =
        Min::make(Max::make(y, x, true), Max::make(x, z, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
  }

  {
    // Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x)
    ExprHandle body =
        Min::make(Max::make(y, x, true), Max::make(z, x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
  }

  {
    // Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z))
    // When the ops in the pattern do not all have the same propagate_nans,
    // the factoring is not applied; only the operands of each Max are
    // reordered.
    ExprHandle body =
        Min::make(Max::make(y, x, true), Max::make(z, x, false), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min);
    IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "y");
    ASSERT_TRUE(max1->propagate_nans());
    IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "z");
    ASSERT_FALSE(max2->propagate_nans());
    ASSERT_TRUE(min->propagate_nans());
  }

  {
    // Min(5, Min(x, 8)) => Min(x, 5)
    ExprHandle body = Min::make(5, Min::make(x, 8, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
    ASSERT_TRUE(min->propagate_nans());
  }

  {
    // Min(8, Min(x, 5)) => Min(x, 5)
    ExprHandle body = Min::make(8, Min::make(x, 5, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
    ASSERT_TRUE(min->propagate_nans());
  }

  {
    // Min(Min(x, 8), 5) => Min(x, 5)
    ExprHandle body = Min::make(Min::make(x, 8, true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
    ASSERT_TRUE(min->propagate_nans());
  }

  {
    // Min(Min(x, 5), 8) => Min(x, 5)
    ExprHandle body = Min::make(Min::make(x, 5, true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
    ASSERT_TRUE(min->propagate_nans());
  }

  {
    // Min(5, Min(x, Min(y, Min(z, 8)))) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        5, Min::make(x, Min::make(y, Min::make(z, 8, true), true), true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(5, Min(Min(y, Min(z, 8)), x)) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        5, Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(5, Min(Min(Min(z, 8), y), x)) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        5, Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(Min(x, Min(y, Min(8, z))), 5) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        Min::make(x, Min::make(y, Min::make(8, z, true), true), true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(Min(Min(y, Min(z, 8)), x), 5) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(Min(Min(Min(z, 8), y), x), 5) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), 5, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8)
    // Do not simplify when the Min ops do not all have the same
    // propagate_nans. In the printed IR the third argument of each Min is
    // its propagate_nans flag (1 = true, 0 = false).
    ExprHandle body = Min::make(
        Min::make(Min::make(Min::make(z, 5, true), y, false), x, true),
        8,
        false);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)");
  }

  {
    // Min(8, Min(Min(x, 5), Min(y, z))) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        8, Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }

  {
    // Min(Min(Min(x, 5), Min(y, z)), 8) => Min(Min(Min(x, 5), y), z)
    ExprHandle body = Min::make(
        Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), 8, true);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Min, simplified.node(), min1);
    IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
    IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
    ASSERT_TRUE(min3->propagate_nans());
    IS_VAR_WITH_NAME(min2->rhs(), "y");
    IS_VAR_WITH_NAME(min1->rhs(), "z");
  }
}
2552
// Integer arithmetic may be reassociated and constant-folded, but any
// subexpression whose type is floating point must keep its original order.
TEST(Simplify, SimplifyWontReorderFloat) {
  {
    // 3 * (3 * x) - 3 * (3 * y) => 9 * (x - y)
    // This is an expression we can simplify.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);

    ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
        ExprHandle(3) * (ExprHandle(3) * y);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 9);
    IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
    IS_VAR_WITH_NAME(sub->lhs(), "x");
    IS_VAR_WITH_NAME(sub->rhs(), "y");
  }

  {
    // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y).
    // If the vars are floating point, ops are not associative and we can't
    // reorder.
    VarHandle x("x", kFloat);
    VarHandle y("y", kFloat);

    ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
        ExprHandle(3) * (ExprHandle(3) * y);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
    IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
    IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
    IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
    IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3);
    IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");

    IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
    IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
    IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
    IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
    IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y");
  }

  {
    // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y).
    // We will simplify subexprs if they don't reorder floating point ops.
    // Here x is Double and y is Int: only the all-Int product involving y is
    // folded to 9 * y, which then appears under a cast to Double.
    VarHandle x("x", kDouble);
    VarHandle y("y", kInt);

    ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
        ExprHandle(3) * (ExprHandle(3) * y);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
    IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
    IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3);
    IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
    IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3);
    IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");

    IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double);
    IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9);
    IS_VAR_WITH_NAME(rhsMul->rhs(), "y");
  }

  {
    // Prevent reordering if FP propagated from dtypes.
    // 3.f * (3 * x) - 3 * (3.f * y): the float immediates make both outer
    // products Float, so neither side is reassociated. The inner Int product
    // 3 * x is kept as-is under a cast to Float.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);

    ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) -
        ExprHandle(3) * (ExprHandle(3.f) * y);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
    IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
    IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
    IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float);
    IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3);
    IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");

    IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
    IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
    IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
    IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
    IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast);
    IS_VAR_WITH_NAME(yCast->src_value(), "y");
  }

  {
    VarHandle x("x", kFloat);
    VarHandle y("y", kFloat);
    // x%y - (x%y - 1) => x%y - (x%y - 1).
    // We won't reorder opaque ops if they are FP.
    ExprHandle body = (x % y) - ((x % y) - 1);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
    IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod);
    IS_VAR_WITH_NAME(lhsMod->lhs(), "x");
    IS_VAR_WITH_NAME(lhsMod->rhs(), "y");

    IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub);
    IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod);
    IS_VAR_WITH_NAME(rhsMod->lhs(), "x");
    IS_VAR_WITH_NAME(rhsMod->rhs(), "y");
    IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1);
  }
}
2662
// The "round mod" identity (x / y) * y + x % y == x is recognized for integer
// operands in either term order, with compound or opaque numerators and
// denominators, and in a negated form. It must not fire for floats or when the
// div, mul, and mod terms do not all refer to the same operands.
TEST(Simplify, SimplifyRoundModPattern) {
  {
    // (x/y)*y + x%y => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = ((x / y) * y) + (x % y);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_VAR_WITH_NAME(simplified.node(), "x");
  }

  {
    // Reverse order.
    // x%y + (x/y)*y => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = (x % y) + ((x / y) * y);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_VAR_WITH_NAME(simplified.node(), "x");
  }

  {
    // Non opaque denominator.
    // ((x / (4 + y)) * (4 + y)) + (x % (y + 4)) => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)) +
        (x % (y + ExprHandle(4)));
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_VAR_WITH_NAME(simplified.node(), "x");
  }

  {
    // Reverse order.
    // (x % (y + 4)) + ((x / (4 + y)) * (4 + y)) => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = (x % (y + ExprHandle(4))) +
        ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y));
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_VAR_WITH_NAME(simplified.node(), "x");
  }

  {
    // Opaque denominator.
    // ((x / (2 / y)) * (2 / y)) + (x % (2 / y)) => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = ((x / (ExprHandle(2) / y)) * (ExprHandle(2) / y)) +
        (x % (ExprHandle(2) / y));
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_VAR_WITH_NAME(simplified.node(), "x");
  }

  {
    // Non opaque numerator.
    // ((2*x)/y * y) + ((2*x) % y) => 2 * x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body =
        (((ExprHandle(2) * x) / y) * y) + ((ExprHandle(2) * x) % y);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }

  {
    // Opaque numerator.
    // ((x/2) / y * y) + (x/2 % y) => x / 2.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body =
        (((x / ExprHandle(2)) / y) * y) + ((x / ExprHandle(2)) % y);
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Div, simplified.node(), div);
    IS_VAR_WITH_NAME(div->lhs(), "x");
    IS_IMM_WITH_VAL(Int, div->rhs(), 2);
  }

  {
    // Numerator and denominator.
    // ((2*x)/(2*y) * (2*y)) + ((2*x) % (2*y)) => 2 * x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body =
        (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)) +
        ((ExprHandle(2) * x) % (ExprHandle(2) * y));
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }

  {
    // Reverse order.
    // ((2*x) % (2*y)) + ((2*x)/(2*y) * (2*y)) => 2 * x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = ((ExprHandle(2) * x) % (ExprHandle(2) * y)) +
        (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y));
    ExprHandle simplified = IRSimplifier::simplify(body);

    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
    IS_VAR_WITH_NAME(mul->rhs(), "x");
  }

  {
    // Negated Subtraction of Round Mod.
    // (x/y) * y - (0 - x%y) => x.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = ((x / y) * y) - (ExprHandle(0) - (x % y));
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_VAR_WITH_NAME(simplified.node(), "x");
  }

  {
    // Other terms are preserved.
    // (x/y)*y + x%y + (y * x) => x + (x * y).
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = ((x / y) * y) + (x % y) + (y * x);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Add, simplified.node(), add);
    IS_VAR_WITH_NAME(add->lhs(), "x");
    IS_NODE_WITH_NAME(Mul, add->rhs(), mul);
    IS_VAR_WITH_NAME(mul->lhs(), "x");
    IS_VAR_WITH_NAME(mul->rhs(), "y");
  }

  {
    // Sanity checking we won't do the optimization on floats.
    // (x/y)*y + x%y is left structurally unchanged.
    VarHandle x("x", kFloat);
    VarHandle y("y", kFloat);
    ExprHandle body = ((x / y) * y) + (x % y);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Add, simplified.node(), add);
    IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul);
    IS_NODE_WITH_NAME(Div, roundMul->lhs(), roundDiv);
    IS_VAR_WITH_NAME(roundDiv->lhs(), "x");
    IS_VAR_WITH_NAME(roundDiv->rhs(), "y");
    IS_VAR_WITH_NAME(roundMul->rhs(), "y");
    IS_NODE_WITH_NAME(Mod, add->rhs(), mod);
    IS_VAR_WITH_NAME(mod->lhs(), "x");
    IS_VAR_WITH_NAME(mod->rhs(), "y");
  }

  {
    // Sanity check we won't do it if the mod term doesn't match.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);
    ExprHandle body = ((x / y) * y) + (x % z);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "(x / y) * y + x % z");
  }

  {
    // Sanity check we won't do it if the div term doesn't match.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);
    ExprHandle body = (y * (x / z)) + (x % y);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "x % y + (x / z) * y");
  }

  {
    // Sanity check we won't do it if the mul term doesn't match.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);
    ExprHandle body = ((x / y) * z) + (x % y);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "x % y + (x / y) * z");
  }
}
2843
2844TEST(Simplify, SimplifyRoundModPatternFactorization) {
2845 {
2846 // Full factorization.
2847 // 2 * (x/y * y) + 2 * (x%y) => 2 * x.
2848 VarHandle x("x", kInt);
2849 VarHandle y("y", kInt);
2850 ExprHandle body = ExprHandle(2) * ((x / y) * y) + ExprHandle(2) * (x % y);
2851 ExprHandle simplified = IRSimplifier::simplify(body);
2852 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
2853 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
2854 IS_VAR_WITH_NAME(mul->rhs(), "x");
2855 }
2856
2857 {
2858 // Partial Factorization.
2859 // 32 * (x/8) + 4 * (x % 8) => 4 * x.
2860 VarHandle x("x", kInt);
2861 VarHandle y("y", kInt);
2862 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
2863 ExprHandle body = ExprHandle(32) * (x / 8) + ExprHandle(4) * (x % 8);
2864 ExprHandle simplified = IRSimplifier::simplify(body);
2865
2866 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
2867 IS_IMM_WITH_VAL(Int, mul->lhs(), 4);
2868 IS_VAR_WITH_NAME(mul->rhs(), "x");
2869 }
2870
2871 {
2872 // Factorization requiring constant folding.
2873 // 20 * (x / (16 / 2)) * 2 + (11 % 6) * (x % (7+1)) => 5 * x.
2874 VarHandle x("x", kInt);
2875 ExprHandle body = ExprHandle(40) * (x / (ExprHandle(16) / 2)) +
2876 (ExprHandle(11) % 6) * (x % (ExprHandle(7) + 1));
2877 ExprHandle simplified = IRSimplifier::simplify(body);
2878 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
2879 IS_IMM_WITH_VAL(Int, mul->lhs(), 5);
2880 IS_VAR_WITH_NAME(mul->rhs(), "x");
2881 }
2882
2883 {
2884 VarHandle x("x", kInt);
2885 ExprHandle body = (x / 5) * 10 + ExprHandle(2) * (x % 5);
2886 ExprHandle simplified = IRSimplifier::simplify(body);
2887 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
2888 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
2889 IS_VAR_WITH_NAME(mul->rhs(), "x");
2890 }
2891
2892 {
2893 VarHandle x("x", kInt);
2894 ExprHandle body = (x / 10) * 0 + x % 5;
2895 ExprHandle simplified = IRSimplifier::simplify(body);
2896 IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
2897 IS_VAR_WITH_NAME(mod->lhs(), "x");
2898 IS_IMM_WITH_VAL(Int, mod->rhs(), 5);
2899 }
2900}
2901
// The round-mod identity is applied per variable inside sums that mix several
// variables. It only pairs a div term with the mod term of the same operand,
// and it also applies when that operand is a compound expression.
TEST(Simplify, SimplifyRoundModPatternMultivar) {
  {
    // Multivar.
    // (x/8) * 8 + (y/5)*5 + x%8 + y%5 => x + y.
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    ExprHandle body = (x / ExprHandle(8) * ExprHandle(8)) +
        (y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Add, simplified.node(), add);
    IS_VAR_WITH_NAME(add->lhs(), "x");
    IS_VAR_WITH_NAME(add->rhs(), "y");
  }

  {
    // Find the right var.
    // (y / 8) * 8 + x % 8 + y % 8 + z % 8 => x % 8 + y + z % 8
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);
    ExprHandle body =
        (y / ExprHandle(8) * ExprHandle(8)) + (x % 8) + (y % 8) + (z % 8);
    ExprHandle simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Add, simplified.node(), add);
    IS_NODE_WITH_NAME(Add, add->lhs(), add2);
    IS_NODE_WITH_NAME(Mod, add2->lhs(), xMod);
    IS_VAR_WITH_NAME(xMod->lhs(), "x");
    IS_IMM_WITH_VAL(Int, xMod->rhs(), 8);
    IS_VAR_WITH_NAME(add2->rhs(), "y");
    IS_NODE_WITH_NAME(Mod, add->rhs(), zMod);
    IS_VAR_WITH_NAME(zMod->lhs(), "z");
    IS_IMM_WITH_VAL(Int, zMod->rhs(), 8);
  }

  {
    // Compound.
    // x + (z + 512 * y) % 16 + ((z + 512 * y) / 16) * 16
    // => x + (z + 512 * y)
    VarHandle x("x", kInt);
    VarHandle y("y", kInt);
    VarHandle z("z", kInt);

    ExprHandle body = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16);
    ExprHandle simplified = IRSimplifier::simplify(body);
    checkExprIR(simplified, "x + (z + 512 * y)");
  }
}
2949
2950TEST(Simplify, SimplifyModRoundModPattern) {
2951 {
2952 // t/7 % 9 * 7 + t % 7 => t%63
2953 VarHandle t("t", kInt);
2954 ExprHandle body = (t / 7 % 9) * 7 + t % 7;
2955 ExprHandle simplified = IRSimplifier::simplify(body);
2956 IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
2957 IS_VAR_WITH_NAME(mod->lhs(), "t");
2958 IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
2959 }
2960
2961 {
2962 // 2*t/7 % 9 * 7 + 2*t % 7 => 2*t % 63
2963 VarHandle t("t", kInt);
2964 ExprHandle body = (ExprHandle(2) * t / 7 % 9) * 7 + ExprHandle(2) * t % 7;
2965 ExprHandle simplified = IRSimplifier::simplify(body);
2966 IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
2967 IS_NODE_WITH_NAME(Mul, mod->lhs(), mul);
2968 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
2969 IS_VAR_WITH_NAME(mul->rhs(), "t");
2970 IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
2971 }
2972
2973 {
2974 // t/x % y * x + t % x => t%(x*y)
2975 VarHandle t("t", kInt);
2976 VarHandle x("x", kInt);
2977 VarHandle y("y", kInt);
2978 ExprHandle body = (t / x % y) * x + t % x;
2979 ExprHandle simplified = IRSimplifier::simplify(body);
2980 IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
2981 IS_VAR_WITH_NAME(mod->lhs(), "t");
2982 IS_NODE_WITH_NAME(Mul, mod->rhs(), mul);
2983 IS_VAR_WITH_NAME(mul->lhs(), "x");
2984 IS_VAR_WITH_NAME(mul->rhs(), "y");
2985 }
2986
2987 {
2988 // k*t/x % y * x + k*t % x => k*t%(x*y)
2989 VarHandle t("t", kInt);
2990 VarHandle x("x", kInt);
2991 VarHandle y("y", kInt);
2992 VarHandle k("k", kInt);
2993 ExprHandle body = (k * t / x % y) * x + k * t % x;
2994 ExprHandle simplified = IRSimplifier::simplify(body);
2995 checkExprIR(simplified, "(k * t) % (x * y)");
2996 }
2997
2998 {
2999 // t/k/x % y * x + t/k % x => t/k%(x*y)
3000 VarHandle t("t", kInt);
3001 VarHandle x("x", kInt);
3002 VarHandle y("y", kInt);
3003 VarHandle k("k", kInt);
3004 ExprHandle body = (t / k / x % y) * x + t / k % x;
3005 ExprHandle simplified = IRSimplifier::simplify(body);
3006 IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
3007 IS_NODE_WITH_NAME(Div, mod->lhs(), div);
3008 IS_VAR_WITH_NAME(div->lhs(), "t");
3009 IS_VAR_WITH_NAME(div->rhs(), "k");
3010 IS_NODE_WITH_NAME(Mul, mod->rhs(), mul);
3011 IS_VAR_WITH_NAME(mul->lhs(), "x");
3012 IS_VAR_WITH_NAME(mul->rhs(), "y");
3013 }
3014
3015 {
3016 // Sanity checking we wont do the optimization on floats.
3017 VarHandle x("x", kFloat);
3018 VarHandle y("y", kFloat);
3019 VarHandle z("z", kFloat);
3020 ExprHandle body = ((x / y % z) * y) + (x % y);
3021 ExprHandle simplified = IRSimplifier::simplify(body);
3022 IS_NODE_WITH_NAME(Add, simplified.node(), add);
3023 IS_NODE_WITH_NAME(Mul, add->lhs(), mul);
3024 IS_NODE_WITH_NAME(Mod, mul->lhs(), mod);
3025 IS_NODE_WITH_NAME(Div, mod->lhs(), div);
3026 IS_VAR_WITH_NAME(div->lhs(), "x");
3027 IS_VAR_WITH_NAME(div->rhs(), "y");
3028 IS_VAR_WITH_NAME(mod->rhs(), "z");
3029 IS_VAR_WITH_NAME(mul->rhs(), "y");
3030 IS_NODE_WITH_NAME(Mod, add->rhs(), mod2);
3031 IS_VAR_WITH_NAME(mod2->lhs(), "x");
3032 IS_VAR_WITH_NAME(mod2->rhs(), "y");
3033 }
3034}
3035
3036TEST(Simplify, SimplifyModRoundModPatternFactorization) {
3037 {
3038 // 2 * (t /7 % 9 * 7) + 2 * (t % 7) => 2 * (t % 63)
3039 VarHandle t("t", kInt);
3040 ExprHandle body =
3041 ExprHandle(2) * ((t / 7 % 9) * 7) + ExprHandle(2) * (t % 7);
3042 ExprHandle simplified = IRSimplifier::simplify(body);
3043 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
3044 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
3045 IS_NODE_WITH_NAME(Mod, mul->rhs(), mod);
3046 IS_VAR_WITH_NAME(mod->lhs(), "t");
3047 IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
3048 }
3049
3050 {
3051 // t /7 % 9 * 14 + 2* (t % 7) => 2* (t % 63)
3052 VarHandle t("t", kInt);
3053 ExprHandle body = (t / 7 % 9) * 14 + ExprHandle(2) * (t % 7);
3054 ExprHandle simplified = IRSimplifier::simplify(body);
3055 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
3056 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
3057 IS_NODE_WITH_NAME(Mod, mul->rhs(), mod);
3058 IS_VAR_WITH_NAME(mod->lhs(), "t");
3059 IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
3060 }
3061
3062 {
3063 // t/14 % 9 * 7 + t/2 % 7 => t/2 % 63
3064 VarHandle t("t", kInt);
3065 ExprHandle body = (t / 14 % 9) * 7 + t / 2 % 7;
3066 ExprHandle simplified = IRSimplifier::simplify(body);
3067 IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
3068 IS_NODE_WITH_NAME(Div, mod->lhs(), div);
3069 IS_VAR_WITH_NAME(div->lhs(), "t");
3070 IS_IMM_WITH_VAL(Int, div->rhs(), 2);
3071 IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
3072 }
3073
3074 {
3075 // t/(7*3) % 9 * 7*3 + t % (7*3) => t % 189
3076 VarHandle t("t", kInt);
3077 ExprHandle body = (t / (ExprHandle(7) * ExprHandle(3)) % 9) * 7 * 3 +
3078 t % (ExprHandle(7) * ExprHandle(3));
3079 ExprHandle simplified = IRSimplifier::simplify(body);
3080 IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
3081 IS_VAR_WITH_NAME(mod->lhs(), "t");
3082 IS_IMM_WITH_VAL(Int, mod->rhs(), 189);
3083 }
3084
3085 {
3086 // 2*(t/x % y * x) + 2*(t % x) => 2*(t%(x*y))
3087 VarHandle t("t", kInt);
3088 VarHandle x("x", kInt);
3089 VarHandle y("y", kInt);
3090 ExprHandle body =
3091 ExprHandle(2) * ((t / x % y) * x) + ExprHandle(2) * (t % x);
3092 ExprHandle simplified = IRSimplifier::simplify(body);
3093 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
3094 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
3095 IS_NODE_WITH_NAME(Mod, mul->rhs(), mod);
3096 IS_VAR_WITH_NAME(mod->lhs(), "t");
3097 IS_NODE_WITH_NAME(Mul, mod->rhs(), mul2);
3098 IS_VAR_WITH_NAME(mul2->lhs(), "x");
3099 IS_VAR_WITH_NAME(mul2->rhs(), "y");
3100 }
3101}
3102
3103TEST(Simplify, SimplifyModRoundModPatternMultivar) {
3104 {
3105 // t/7 % 9 * 7 + t % 7 + t => t % 63 + t
3106 VarHandle t("t", kInt);
3107 ExprHandle body = (t / 7 % 9) * 7 + t % 7 + t;
3108 ExprHandle simplified = IRSimplifier::simplify(body);
3109 checkExprIR(simplified, "t % 63 + t");
3110 }
3111
3112 {
3113 // t/7 % 9 * 7 + t/8 % 9 * 8 + t % 7 + t % 8 => t % 63 + t % 72
3114 VarHandle t("t", kInt);
3115 ExprHandle body = (t / 7 % 9) * 7 + (t / 8 % 9) * 8 + t % 7 + t % 8;
3116 ExprHandle simplified = IRSimplifier::simplify(body);
3117 IS_NODE_WITH_NAME(Add, simplified.node(), add);
3118 IS_NODE_WITH_NAME(Mod, add->lhs(), mod1);
3119 IS_VAR_WITH_NAME(mod1->lhs(), "t");
3120 IS_IMM_WITH_VAL(Int, mod1->rhs(), 63);
3121 IS_NODE_WITH_NAME(Mod, add->rhs(), mod2);
3122 IS_VAR_WITH_NAME(mod2->lhs(), "t");
3123 IS_IMM_WITH_VAL(Int, mod2->rhs(), 72);
3124 }
3125
3126 {
3127 // k + t/x % y * x + t % x => k + t%(x*y)
3128 VarHandle t("t", kInt);
3129 VarHandle x("x", kInt);
3130 VarHandle y("y", kInt);
3131 VarHandle k("k", kInt);
3132 ExprHandle body = k + (t / x % y) * x + t % x;
3133 ExprHandle simplified = IRSimplifier::simplify(body);
3134 IS_NODE_WITH_NAME(Add, simplified.node(), add);
3135 IS_VAR_WITH_NAME(add->lhs(), "k");
3136 IS_NODE_WITH_NAME(Mod, add->rhs(), mod);
3137 IS_VAR_WITH_NAME(mod->lhs(), "t");
3138 IS_NODE_WITH_NAME(Mul, mod->rhs(), mul);
3139 IS_VAR_WITH_NAME(mul->lhs(), "x");
3140 IS_VAR_WITH_NAME(mul->rhs(), "y");
3141 }
3142
3143 {
3144 // t/x % y * x + t % x + (t/k / x % y) * x + t/k % x
3145 // => t%(x*y) + t/k % (x*y)
3146 VarHandle t("t", kInt);
3147 VarHandle x("x", kInt);
3148 VarHandle y("y", kInt);
3149 VarHandle k("k", kInt);
3150 ExprHandle body = (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x;
3151 ExprHandle simplified = IRSimplifier::simplify(body);
3152 checkExprIR(simplified, "(t / k) % (x * y) + t % (x * y)");
3153 }
3154
3155 {
3156 // 3D: (7 * ((i0_flat / 7) % 9) + i0_flat % 7) + 63 * (i0_flat / 63)
3157 // => io_flat
3158 VarHandle t("io_flat", kInt);
3159 ExprHandle body =
3160 ExprHandle(7) * (t / 7 % 9) + t % 7 + ExprHandle(63) * (t / 63);
3161 ExprHandle simplified = IRSimplifier::simplify(body);
3162 IS_VAR_WITH_NAME(simplified.node(), "io_flat");
3163 }
3164
3165 { // 5D: i0_flat / (11 * 10 * 9 * 7) * (7 * 9 * 10 * 11) +
3166 // (i0_flat / (10 * 9 * 7) % 11) * 7 * 9 * 10 +
3167 // (i0_flat / (9 * 7) % 10) * 7 * 9 +
3168 // (i0_flat / 7 % 9) * 7 +
3169 // i0_flat % 7 => io_flat
3170 VarHandle t("io_flat", kInt);
3171 ExprHandle body = (t / (ExprHandle(11) * 10 * 9 * 7)) * (7 * 9 * 10 * 11) +
3172 (t / (ExprHandle(10) * 9 * 7) % 11) * 7 * 9 * 10 +
3173 (t / (ExprHandle(9) * 7) % 10) * 7 * 9 + (t / 7 % 9) * 7 + t % 7;
3174 ExprHandle simplified = IRSimplifier::simplify(body);
3175 IS_VAR_WITH_NAME(simplified.node(), "io_flat");
3176 }
3177
3178 {
3179 // 3D: (m * ((i0_flat / m) % n) + i0_flat % m) + (m * n) *
3180 // (i0_flat / (m * n)) => io_flat
3181 VarHandle t("io_flat", kInt);
3182 VarHandle m("m", kInt);
3183 VarHandle n("n", kInt);
3184 ExprHandle body = m * (t / m % n) + t % m + (m * n) * (t / (m * n));
3185 ExprHandle simplified = IRSimplifier::simplify(body);
3186 IS_VAR_WITH_NAME(simplified.node(), "io_flat");
3187 }
3188
3189 { // 5D: i0_flat / (k * l * n * m) * (m * n * l * k) +
3190 // (i0_flat / (l * n * m) % k) * m * n * l +
3191 // (i0_flat / (n * m) % l) * m * n +
3192 // (i0_flat / m % n) * m +
3193 // i0_flat % m => io_flat
3194 VarHandle t("io_flat", kInt);
3195 VarHandle m("m", kInt);
3196 VarHandle n("n", kInt);
3197 VarHandle l("l", kInt);
3198 VarHandle k("k", kInt);
3199 ExprHandle body = (t / (k * l * n * m)) * (m * n * l * k) +
3200 (t / (l * n * m) % k) * m * n * l + (t / (n * m) % l) * m * n +
3201 (t / m % n) * m + t % m;
3202 ExprHandle simplified = IRSimplifier::simplify(body);
3203 IS_VAR_WITH_NAME(simplified.node(), "io_flat");
3204 }
3205}
3206
3207TEST(Simplify, SimplifyDivisionScalarFactorization) {
3208 {
3209 // Simple factorization of numerator and denominator.
3210 // 8x / 4y => 2x / y.
3211 VarHandle x("x", kInt);
3212 VarHandle y("y", kInt);
3213 ExprHandle body = (x * 8) / (y * 4);
3214 ExprHandle simplified = IRSimplifier::simplify(body);
3215 IS_NODE_WITH_NAME(Div, simplified.node(), div);
3216 IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
3217 IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
3218 IS_VAR_WITH_NAME(lhs->rhs(), "x");
3219 IS_VAR_WITH_NAME(div->rhs(), "y");
3220 }
3221
3222 {
3223 // Don't change anything if we can't factorize.
3224 VarHandle x("x", kInt);
3225 VarHandle y("y", kInt);
3226 ExprHandle body = (x * 7) / (y * 4);
3227 ExprHandle simplified = IRSimplifier::simplify(body);
3228 IS_NODE_WITH_NAME(Div, simplified.node(), div);
3229 IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
3230 IS_IMM_WITH_VAL(Int, lhs->lhs(), 7);
3231 IS_VAR_WITH_NAME(lhs->rhs(), "x");
3232 IS_NODE_WITH_NAME(Mul, div->rhs(), rhs);
3233 IS_IMM_WITH_VAL(Int, rhs->lhs(), 4);
3234 IS_VAR_WITH_NAME(rhs->rhs(), "y");
3235 }
3236
3237 {
3238 // Don't reorder floats.
3239 VarHandle x("x", kFloat);
3240 VarHandle y("y", kFloat);
3241 ExprHandle body = (x * 8) / (y * 4);
3242 ExprHandle simplified = IRSimplifier::simplify(body);
3243 IS_NODE_WITH_NAME(Div, simplified.node(), div);
3244 IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
3245 IS_VAR_WITH_NAME(lhs->lhs(), "x");
3246 IS_IMM_WITH_VAL(Float, lhs->rhs(), 8.f);
3247 IS_NODE_WITH_NAME(Mul, div->rhs(), rhs);
3248 IS_VAR_WITH_NAME(rhs->lhs(), "y");
3249 IS_IMM_WITH_VAL(Float, rhs->rhs(), 4.f);
3250 }
3251
3252 {
3253 // Sanity check we do nothing if there are only scalar parts.
3254 VarHandle x("x", kInt);
3255 VarHandle y("y", kInt);
3256 ExprHandle body = (x * 1) / (y * 1);
3257 ExprHandle simplified = IRSimplifier::simplify(body);
3258 IS_NODE_WITH_NAME(Div, simplified.node(), div);
3259 IS_VAR_WITH_NAME(div->lhs(), "x");
3260 IS_VAR_WITH_NAME(div->rhs(), "y");
3261 }
3262
3263 {
3264 // Can factorize amounts of variables.
3265 VarHandle x("x", kInt);
3266 VarHandle y("y", kInt);
3267 ExprHandle body = (x + x + x + x) / (y + y);
3268 ExprHandle simplified = IRSimplifier::simplify(body);
3269 IS_NODE_WITH_NAME(Div, simplified.node(), div);
3270 IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
3271 IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
3272 IS_VAR_WITH_NAME(lhs->rhs(), "x");
3273 IS_VAR_WITH_NAME(div->rhs(), "y");
3274 }
3275}
3276
3277TEST(Simplify, SimplifyConstantBranches) {
3278 {
3279 // If the condition is constant true then take the true_value.
3280 // 1 ? x : y => x
3281 VarHandle x("x", kInt);
3282 VarHandle y("y", kInt);
3283 ExprHandle t(1);
3284 ExprHandle body = IfThenElse::make(t, x, y);
3285 ExprHandle simplified = IRSimplifier::simplify(body);
3286 IS_VAR_WITH_NAME(simplified.node(), "x");
3287 }
3288
3289 {
3290 // If the condition is constant false then take the false_value.
3291 // 0 ? x : y => y
3292 VarHandle x("x", kInt);
3293 VarHandle y("y", kInt);
3294 ExprHandle t(0);
3295 ExprHandle body = IfThenElse::make(t, x, y);
3296 ExprHandle simplified = IRSimplifier::simplify(body);
3297 IS_VAR_WITH_NAME(simplified.node(), "y");
3298 }
3299
3300 {
3301 // condition is simplified before checking.
3302 // (x-x) ? x : y => y
3303 VarHandle x("x", kInt);
3304 VarHandle y("y", kInt);
3305 ExprHandle body = IfThenElse::make(x - x, x, y);
3306 ExprHandle simplified = IRSimplifier::simplify(body);
3307 IS_VAR_WITH_NAME(simplified.node(), "y");
3308 }
3309
3310 {
3311 // If both branches are the same then don't do the condition.
3312 // y ? x : x => x
3313 VarHandle x("x", kInt);
3314 VarHandle y("y", kInt);
3315 ExprHandle body = IfThenElse::make(y, x, x);
3316 ExprHandle simplified = IRSimplifier::simplify(body);
3317 IS_VAR_WITH_NAME(simplified.node(), "x");
3318 }
3319
3320 {
3321 // If both branches simplify to the same thing it still works.
3322 // y ? (x + x) : (2 * x) => x
3323 VarHandle x("x", kInt);
3324 VarHandle y("y", kInt);
3325 ExprHandle body = IfThenElse::make(y, x + x, ExprHandle(2) * x);
3326 ExprHandle simplified = IRSimplifier::simplify(body);
3327 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
3328 IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
3329 IS_VAR_WITH_NAME(mul->rhs(), "x");
3330 }
3331}
3332
3333TEST(Simplify, SimplifyConstantCond) {
3334 {
3335 // If the condition is constant true then take the true_value.
3336 // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1
3337 BufHandle a("A", {1}, kInt);
3338 BufHandle b("B", {1}, kInt);
3339 ExprHandle condition(1);
3340 StmtPtr true_val = Store::make(a, {0}, 1);
3341 StmtPtr false_val = Store::make(b, {0}, 1);
3342
3343 CondPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3344 StmtPtr simplified = IRSimplifier::simplify(body);
3345 BlockPtr block = to<Block>(simplified);
3346 IS_NODE_WITH_NAME(Store, block->front(), store);
3347 IS_VAR_WITH_NAME(store->base_handle(), "A");
3348 }
3349
3350 {
3351 // If the condition is constant false then take the false_value.
3352 // 0 ? A[0] = 1 : B[0] = 1 => B[0] = 1
3353 BufHandle a("A", {1}, kInt);
3354 BufHandle b("B", {1}, kInt);
3355 ExprHandle condition(0);
3356 StmtPtr true_val = Store::make(a, {0}, 1);
3357 StmtPtr false_val = Store::make(b, {0}, 1);
3358
3359 StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3360 StmtPtr simplified = IRSimplifier::simplify(body);
3361 BlockPtr block = to<Block>(simplified);
3362 IS_NODE_WITH_NAME(Store, block->front(), store);
3363 IS_VAR_WITH_NAME(store->base_handle(), "B");
3364 }
3365
3366 {
3367 // condition is simplified before checking.
3368 // (x-x) ? A[0] = 1 : B[0] = 1 => B[0] = 1
3369 VarHandle x("x", kInt);
3370 BufHandle a("A", {1}, kInt);
3371 BufHandle b("B", {1}, kInt);
3372 ExprHandle condition(x - x);
3373 StmtPtr true_val = Store::make(a, {0}, 1);
3374 StmtPtr false_val = Store::make(b, {0}, 1);
3375
3376 StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3377 StmtPtr simplified = IRSimplifier::simplify(body);
3378 BlockPtr block = to<Block>(simplified);
3379 IS_NODE_WITH_NAME(Store, block->front(), store);
3380 IS_VAR_WITH_NAME(store->base_handle(), "B");
3381 }
3382
3383 {
3384 // If both branches are the same then don't do the condition.
3385 // x ? A[0] = x : A[0] = x => A[0] = x
3386 VarHandle x("x", kInt);
3387 BufHandle a("A", {1}, kInt);
3388 ExprHandle condition(x - x);
3389 StmtPtr true_val = Store::make(a, {0}, x);
3390 StmtPtr false_val = Store::make(a, {0}, x);
3391
3392 StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3393 StmtPtr simplified = IRSimplifier::simplify(body);
3394 BlockPtr block = to<Block>(simplified);
3395 IS_NODE_WITH_NAME(Store, block->front(), store);
3396 IS_VAR_WITH_NAME(store->base_handle(), "A");
3397 }
3398
3399 {
3400 // If both branches simplify to the same thing it still works.
3401 // x ? (x + x) : (2 * x) => x
3402 VarHandle x("x", kInt);
3403 BufHandle a("A", {1}, kInt);
3404 ExprHandle condition(x - x);
3405 StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x);
3406 StmtPtr false_val = Store::make(a, {0}, x + x);
3407
3408 StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3409 StmtPtr simplified = IRSimplifier::simplify(body);
3410 BlockPtr block = to<Block>(simplified);
3411 IS_NODE_WITH_NAME(Store, block->front(), store);
3412 IS_VAR_WITH_NAME(store->base_handle(), "A");
3413 }
3414
3415 {
3416 // But not if they dont
3417 // x ? x : (2 * x) => x ? x : (2 * x)
3418 VarHandle x("x", kInt);
3419 BufHandle a("A", {1}, kInt);
3420 ExprHandle condition(x);
3421 StmtPtr true_val = Store::make(a, {0}, x);
3422 StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x);
3423
3424 StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
3425 StmtPtr simplified = IRSimplifier::simplify(body);
3426 BlockPtr block = to<Block>(simplified);
3427 ASSERT_EQ(block, nullptr);
3428 }
3429
3430 {
3431 StmtPtr cond = alloc<Cond>(
3432 ExprHandle(false).node(),
3433 alloc<Block>(std::vector<StmtPtr>({})),
3434 nullptr);
3435 StmtPtr simplified = IRSimplifier::simplify(cond);
3436 ASSERT_EQ(simplified, nullptr);
3437 }
3438
3439 {
3440 StmtPtr cond = alloc<Cond>(
3441 ExprHandle(true).node(),
3442 nullptr,
3443 alloc<Block>(std::vector<StmtPtr>({})));
3444 StmtPtr simplified = IRSimplifier::simplify(cond);
3445 ASSERT_EQ(simplified, nullptr);
3446 }
3447}
3448
3449TEST(Simplify, SimplifyEliminateEmptyCond) {
3450 // If the branches are empty in different ways, eliminate.
3451 {
3452 VarHandle x("x", kInt);
3453 ExprHandle condition(x);
3454 StmtPtr true_val = alloc<Block>(std::vector<StmtPtr>({}));
3455
3456 StmtPtr body = alloc<Cond>(condition.node(), true_val, nullptr);
3457 StmtPtr simplified = IRSimplifier::simplify(body);
3458 BlockPtr block = to<Block>(simplified);
3459 ASSERT_NE(block, nullptr);
3460 ASSERT_EQ(block->nstmts(), 0);
3461 }
3462
3463 {
3464 VarHandle x("x", kInt);
3465 ExprHandle condition(x);
3466 StmtPtr false_val = alloc<Block>(std::vector<StmtPtr>({}));
3467
3468 StmtPtr body = alloc<Cond>(condition.node(), nullptr, false_val);
3469 StmtPtr simplified = IRSimplifier::simplify(body);
3470 BlockPtr block = to<Block>(simplified);
3471 ASSERT_NE(block, nullptr);
3472 ASSERT_EQ(block->nstmts(), 0);
3473 }
3474}
3475
3476TEST(Simplify, SimplifyConstantComparisons) {
3477 auto ComparisonTest =
3478 [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) {
3479 ExprHandle body = CompareSelect::make(a, b, op);
3480 ExprHandle simplified = IRSimplifier::simplify(body);
3481 IS_IMM_WITH_VAL(Int, simplified.node(), result);
3482 };
3483
3484 // Equals.
3485 ComparisonTest(2, 2, kEQ, 1);
3486 ComparisonTest(1, 2, kEQ, 0);
3487 ComparisonTest(2, 1, kEQ, 0);
3488
3489 // Greater than.
3490 ComparisonTest(2, 2, kGT, 0);
3491 ComparisonTest(1, 2, kGT, 0);
3492 ComparisonTest(2, 1, kGT, 1);
3493
3494 // Greater or Equal.
3495 ComparisonTest(2, 2, kGE, 1);
3496 ComparisonTest(1, 2, kGE, 0);
3497 ComparisonTest(2, 1, kGE, 1);
3498
3499 // Less Than.
3500 ComparisonTest(2, 2, kLT, 0);
3501 ComparisonTest(1, 2, kLT, 1);
3502 ComparisonTest(2, 1, kLT, 0);
3503
3504 // Less or Equal.
3505 ComparisonTest(2, 2, kLE, 1);
3506 ComparisonTest(1, 2, kLE, 1);
3507 ComparisonTest(2, 1, kLE, 0);
3508
3509 // Not equal.
3510 ComparisonTest(2, 2, kNE, 0);
3511 ComparisonTest(1, 2, kNE, 1);
3512 ComparisonTest(2, 1, kNE, 1);
3513
3514 // With specified results:
3515 ExprHandle body = CompareSelect::make(2, 2, 5, 42, kNE);
3516 ExprHandle simplified = IRSimplifier::simplify(body);
3517 IS_IMM_WITH_VAL(Int, simplified.node(), 42);
3518}
3519
3520TEST(Simplify, SimplifySymbolicComparisons) {
3521 VarHandle x("x", kInt);
3522 VarHandle y("y", kInt);
3523
3524 auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 1); };
3525 auto TookFalseBranch = [](ExprHandle a) {
3526 IS_IMM_WITH_VAL(Int, a.node(), 0);
3527 };
3528
3529 // EQ
3530
3531 // x == x => 1
3532 ExprHandle body = CompareSelect::make(x, x, kEQ);
3533 TookTrueBranch(IRSimplifier::simplify(body));
3534
3535 // x == x+1 => 0
3536 body = CompareSelect::make(x, x + 1, kEQ);
3537 TookFalseBranch(IRSimplifier::simplify(body));
3538
3539 // x == x * 2 cannot simplify since we don't know x is nonzero.
3540 body = CompareSelect::make(x, x * 2, kEQ);
3541 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
3542 IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
3543
3544 // x == x * 1 => 1
3545 body = CompareSelect::make(x, x * 1, kEQ);
3546 TookTrueBranch(IRSimplifier::simplify(body));
3547
3548 {
3549 // x == y => x == y
3550 body = CompareSelect::make(x, y, kEQ);
3551 ExprHandle simplified = IRSimplifier::simplify(body);
3552 IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp);
3553 ASSERT_EQ(cmp->compare_select_op(), kEQ);
3554 IS_VAR_WITH_NAME(cmp->lhs(), "x");
3555 IS_VAR_WITH_NAME(cmp->rhs(), "y");
3556 }
3557
3558 {
3559 // x == 5 => x == 5
3560 body = CompareSelect::make(x, 5, kEQ);
3561 ExprHandle simplified = IRSimplifier::simplify(body);
3562 IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp);
3563 ASSERT_EQ(cmp->compare_select_op(), kEQ);
3564 IS_VAR_WITH_NAME(cmp->lhs(), "x");
3565 IS_IMM_WITH_VAL(Int, cmp->rhs(), 5);
3566 }
3567
3568 // GT
3569
3570 // x+1 > x => 1
3571 body = CompareSelect::make(x + 1, x, kGT);
3572 TookTrueBranch(IRSimplifier::simplify(body));
3573
3574 // x > x + 1 => 0
3575 body = CompareSelect::make(x, x + 1, kGT);
3576 TookFalseBranch(IRSimplifier::simplify(body));
3577
3578 // x > x - 1 => 1
3579 body = CompareSelect::make(x, x - 1, kGT);
3580 TookTrueBranch(IRSimplifier::simplify(body));
3581
3582 // x - 1 > x => 0
3583 body = CompareSelect::make(x - 1, x, kGT);
3584 TookFalseBranch(IRSimplifier::simplify(body));
3585
3586 // x > x => 0
3587 body = CompareSelect::make(x, x, kGT);
3588 TookFalseBranch(IRSimplifier::simplify(body));
3589
3590 // x * 2 > x => x * 2 > x
3591 // since we don't know the sign of x.
3592 body = CompareSelect::make(x * 2, x, kGT);
3593 IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
3594
3595 // GE
3596
3597 // x+1 >= x => 1
3598 body = CompareSelect::make(x + 1, x, kGE);
3599 TookTrueBranch(IRSimplifier::simplify(body));
3600
3601 // x >= x + 1 => 0
3602 body = CompareSelect::make(x, x + 1, kGE);
3603 TookFalseBranch(IRSimplifier::simplify(body));
3604
3605 // x >= x => 1
3606 body = CompareSelect::make(x, x, kGE);
3607 TookTrueBranch(IRSimplifier::simplify(body));
3608
3609 // x * 2 >= x => x * 2 >= x
3610 // since we don't know the sign of x.
3611 body = CompareSelect::make(x * 2, x, kGE);
3612 IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
3613
3614 // LT
3615
3616 // x+1 < x => 0
3617 body = CompareSelect::make(x + 1, x, kLT);
3618 TookFalseBranch(IRSimplifier::simplify(body));
3619
3620 // x < x + 1 => 1
3621 body = CompareSelect::make(x, x + 1, kLT);
3622 TookTrueBranch(IRSimplifier::simplify(body));
3623
3624 // x < x => 0
3625 body = CompareSelect::make(x, x, kLT);
3626 TookFalseBranch(IRSimplifier::simplify(body));
3627
3628 // LE
3629
3630 // x+1 <= x => 0
3631 body = CompareSelect::make(x + 1, x, kLE);
3632 TookFalseBranch(IRSimplifier::simplify(body));
3633
3634 // x <= x + 1 => 1
3635 body = CompareSelect::make(x, x + 1, kLE);
3636 TookTrueBranch(IRSimplifier::simplify(body));
3637
3638 // x <= x => 1
3639 body = CompareSelect::make(x, x, kLE);
3640 TookTrueBranch(IRSimplifier::simplify(body));
3641
3642 // NE
3643
3644 // x+1 != x => 1
3645 body = CompareSelect::make(x + 1, x, kNE);
3646 TookTrueBranch(IRSimplifier::simplify(body));
3647
3648 // x != x + 1 => 1
3649 body = CompareSelect::make(x, x + 1, kNE);
3650 TookTrueBranch(IRSimplifier::simplify(body));
3651
3652 // x != x => 0
3653 body = CompareSelect::make(x, x, kNE);
3654 TookFalseBranch(IRSimplifier::simplify(body));
3655}
3656
3657TEST(Simplify, SimplifyEliminateZeroLengthFor) {
3658 {
3659 // Will eliminate zero loop For.
3660 BufHandle a("A", {4}, kInt);
3661 BufHandle c("C", {4}, kInt);
3662 VarHandle i("i", kInt);
3663 auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i})));
3664 StmtPtr simplified = IRSimplifier::simplify(body);
3665 BlockPtr block = to<Block>(simplified);
3666 ASSERT_EQ(block->nstmts(), 0);
3667 }
3668
3669 {
3670 // still works if start is not zero.
3671 BufHandle a("A", {4}, kInt);
3672 BufHandle c("C", {4}, kInt);
3673 VarHandle i("i", kInt);
3674 auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i})));
3675 StmtPtr simplified = IRSimplifier::simplify(body);
3676 BlockPtr block = to<Block>(simplified);
3677 ASSERT_EQ(block->nstmts(), 0);
3678 }
3679
3680 {
3681 // works if both terms are variable.
3682 VarHandle x("x", kInt);
3683 BufHandle a("A", {4}, kInt);
3684 BufHandle c("C", {4}, kInt);
3685 VarHandle i("i", kInt);
3686 auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i})));
3687 StmtPtr simplified = IRSimplifier::simplify(body);
3688 BlockPtr block = to<Block>(simplified);
3689 ASSERT_EQ(block->nstmts(), 0);
3690 }
3691
3692 {
3693 // works if one term simplifies down.
3694 VarHandle x("x", kInt);
3695 BufHandle a("A", {4}, kInt);
3696 BufHandle c("C", {4}, kInt);
3697 VarHandle i("i", kInt);
3698 auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i})));
3699 StmtPtr simplified = IRSimplifier::simplify(body);
3700 BlockPtr block = to<Block>(simplified);
3701 ASSERT_EQ(block->nstmts(), 0);
3702 }
3703
3704 {
3705 // Sanity check does nothing if the condition is not met.
3706 BufHandle a("A", {4}, kInt);
3707 BufHandle c("C", {4}, kInt);
3708 VarHandle i("i", kInt);
3709 auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i})));
3710 StmtPtr simplified = IRSimplifier::simplify(body);
3711 IS_NODE(For, simplified);
3712 }
3713}
3714
3715TEST(Simplify, SimplifyOneLoopFor) {
3716 {
3717 // Will remove the loop if the body is run once.
3718 BufHandle a("A", {4}, kInt);
3719 BufHandle c("C", {4}, kInt);
3720 VarHandle i("i", kInt);
3721 auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
3722 StmtPtr simplified = IRSimplifier::simplify(body);
3723 BlockPtr block = to<Block>(simplified);
3724 IS_NODE_WITH_NAME(Store, block->front(), store);
3725 IS_VAR_WITH_NAME(store->base_handle(), "C");
3726 IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
3727 }
3728
3729 {
3730 // still works if start is not zero.
3731 BufHandle a("A", {4}, kInt);
3732 BufHandle c("C", {4}, kInt);
3733 VarHandle i("i", kInt);
3734 auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i})));
3735 StmtPtr simplified = IRSimplifier::simplify(body);
3736 BlockPtr block = to<Block>(simplified);
3737 IS_NODE_WITH_NAME(Store, block->front(), store);
3738 IS_VAR_WITH_NAME(store->base_handle(), "C");
3739 IS_IMM_WITH_VAL(Int, store->flat_index(), 2);
3740 }
3741
3742 {
3743 // works if both terms are variable.
3744 VarHandle x("x", kInt);
3745 BufHandle a("A", {4}, kInt);
3746 BufHandle c("C", {4}, kInt);
3747 VarHandle i("i", kInt);
3748 auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i})));
3749 StmtPtr simplified = IRSimplifier::simplify(body);
3750 BlockPtr block = to<Block>(simplified);
3751 IS_NODE_WITH_NAME(Store, block->front(), store);
3752 IS_VAR_WITH_NAME(store->base_handle(), "C");
3753 IS_VAR_WITH_NAME(store->flat_index(), "x");
3754 }
3755
3756 {
3757 // works if one term simplifies down.
3758 VarHandle x("x", kInt);
3759 BufHandle a("A", {4}, kInt);
3760 BufHandle c("C", {4}, kInt);
3761 VarHandle i("i", kInt);
3762 auto body =
3763 For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i})));
3764 StmtPtr simplified = IRSimplifier::simplify(body);
3765 BlockPtr block = to<Block>(simplified);
3766 IS_NODE_WITH_NAME(Store, block->front(), store);
3767 IS_VAR_WITH_NAME(store->base_handle(), "C");
3768 IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
3769 }
3770
3771 {
3772 // Sanity check does nothing if the condition is not met.
3773 BufHandle a("A", {4}, kInt);
3774 BufHandle c("C", {4}, kInt);
3775 VarHandle i("i", kInt);
3776 auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i})));
3777 StmtPtr simplified = IRSimplifier::simplify(body);
3778 IS_NODE(For, simplified);
3779 }
3780}
3781
3782TEST(Simplify, SimplifyForWontLoseLoopOptions) {
3783 {
3784 // Sanity check does nothing if the condition is not met.
3785 BufHandle a("A", {4}, kInt);
3786 BufHandle c("C", {4}, kInt);
3787 VarHandle i("i", kInt);
3788 LoopOptions options;
3789 options.set_gpu_block_index(LoopOptions::IDX_W);
3790 auto body =
3791 For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options);
3792 StmtPtr simplified = IRSimplifier::simplify(body);
3793 IS_NODE_WITH_NAME(For, simplified, for_);
3794 LoopOptions options2 = for_->loop_options();
3795 ASSERT_EQ(options.gpu_block_index(), options2.gpu_block_index());
3796 }
3797}
3798
3799TEST(Simplify, SimplifyMultilevelFor) {
3800 {
3801 // Multiple layers of For will be simplified out.
3802 BufHandle a("A", {4}, kInt);
3803 BufHandle c("C", {4}, kInt);
3804 VarHandle i("i", kInt);
3805 VarHandle j("j", kInt);
3806 auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
3807 auto outer = For::make(j, 0, 1, body);
3808 StmtPtr simplified = IRSimplifier::simplify(outer);
3809 BlockPtr block = to<Block>(simplified);
3810 IS_NODE_WITH_NAME(Store, block->front(), store);
3811 IS_VAR_WITH_NAME(store->base_handle(), "C");
3812 IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
3813 }
3814
3815 {
3816 // Will maintain an outer loop if the inner loop is eliminated.
3817 BufHandle a("A", {4}, kInt);
3818 BufHandle c("C", {4}, kInt);
3819 VarHandle i("i", kInt);
3820 VarHandle j("j", kInt);
3821 auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
3822 auto outer = For::make(j, 0, 2, body);
3823 StmtPtr simplified = IRSimplifier::simplify(outer);
3824 ForPtr for__ = static_to<For>(simplified);
3825 IS_NODE_WITH_NAME(For, for__, for_);
3826 IS_VAR_WITH_NAME(for_->var(), "j");
3827 IS_IMM_WITH_VAL(Int, for_->start(), 0);
3828 IS_IMM_WITH_VAL(Int, for_->stop(), 2);
3829 BlockPtr block = to<Block>(for_->body());
3830 ASSERT_NE(block, nullptr);
3831 IS_NODE_WITH_NAME(Store, block->front(), store);
3832 IS_VAR_WITH_NAME(store->base_handle(), "C");
3833 IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
3834 }
3835
3836 {
3837 // Will maintain inner loop if outer loops is eliminated.
3838 BufHandle a("A", {4}, kInt);
3839 BufHandle c("C", {4}, kInt);
3840 VarHandle i("i", kInt);
3841 VarHandle j("j", kInt);
3842 auto body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i})));
3843 auto outer = For::make(j, 0, 1, body);
3844 StmtPtr simplified = IRSimplifier::simplify(outer);
3845 BlockPtr block = to<Block>(simplified);
3846 IS_NODE_WITH_NAME(For, block->front(), for_);
3847 IS_VAR_WITH_NAME(for_->var(), "i");
3848 IS_IMM_WITH_VAL(Int, for_->start(), 0);
3849 IS_IMM_WITH_VAL(Int, for_->stop(), 2);
3850 IS_NODE_WITH_NAME(Store, for_->body()->front(), store);
3851 IS_VAR_WITH_NAME(store->base_handle(), "C");
3852 IS_VAR_WITH_NAME(store->flat_index(), "i");
3853 }
3854}
3855
3856TEST(Simplify, SimplifyForCleansUp) {
3857 {
3858 BufHandle a("a", {1, 12, 1}, kFloat);
3859 VarHandle x("x", kInt);
3860 Tensor b = Compute(
3861 "x",
3862 {1, 12, 1},
3863 [](const VarHandle& i, const VarHandle& m, const VarHandle& n) {
3864 return i + m + n;
3865 });
3866 LoopNest l({b});
3867 l.prepareForCodegen();
3868
3869 StmtPtr body = LoopNest::sanitizeNames(l.root_stmt());
3870 StmtPtr simplified = IRSimplifier::simplify(body);
3871
3872 BlockPtr block = to<Block>(simplified);
3873 IS_NODE_WITH_NAME(For, block->front(), for_);
3874 // for is over "m".
3875 IS_VAR_WITH_NAME(for_->var(), "j");
3876 // x[m] = m;
3877 IS_NODE_WITH_NAME(Store, for_->body()->front(), store);
3878 IS_VAR_WITH_NAME(store->flat_index(), "j");
3879 IS_VAR_WITH_NAME(store->value(), "j");
3880 }
3881}
3882
3883TEST(Simplify, SimplifyEliminateEmptyFor) {
3884 {
3885 // Flatten many layers around an empty block to an empty block.
3886 StmtPtr last = alloc<Block>(std::vector<StmtPtr>({}));
3887 for (const auto i : c10::irange(11)) {
3888 (void)i; // Suppress unused variable warning
3889 VarHandle loopVar("loopVar", kInt);
3890 last = For::make(loopVar, 0, 10, last);
3891 }
3892
3893 StmtPtr simplified = IRSimplifier::simplify(last);
3894 IS_NODE_WITH_NAME(Block, simplified, block);
3895 ASSERT_EQ(block->nstmts(), 0);
3896 }
3897}
3898
TEST(Simplify, SimplifyFlattenBlock) {
  // The simplifier should splice the statements of nested Blocks into their
  // enclosing Block, keeping statement order, and should reduce a nest of
  // empty Blocks to a single empty Block.
  {
    // Flatten multiple blocks down to one.
    // { { { stmt1, stmt2 } } } => { stmt1, stmt2 }
    BufHandle a("A", {1}, kInt);
    StorePtr store1 = Store::make(a, {0}, 1);
    StorePtr store2 = Store::make(a, {0}, 0);

    BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store1, store2}));
    BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({block1}));

    BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({block2}));
    StmtPtr simplified = IRSimplifier::simplify(enclosing);

    IS_NODE_WITH_NAME(Block, simplified, block);
    ASSERT_EQ(block->nstmts(), 2);

    IS_NODE_WITH_NAME(Store, block->front(), store1_);
    IS_NODE_WITH_NAME(Store, block->back(), store2_);

    // Compares the value node pointers: the flattened Stores are expected to
    // reference the same value nodes as the originals, in the same order.
    ASSERT_EQ(store1->value(), store1_->value());
    ASSERT_EQ(store2->value(), store2_->value());
  }

  {
    // Flatten multiple sub blocks containing statements.
    // { { stmt1 }, { stmt2 } } => { stmt1, stmt2 }
    BufHandle a("A", {1}, kInt);
    StorePtr store1 = Store::make(a, {0}, 1);
    StorePtr store2 = Store::make(a, {0}, 0);

    BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store1}));
    BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({store2}));

    BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({block1, block2}));
    StmtPtr simplified = IRSimplifier::simplify(enclosing);

    IS_NODE_WITH_NAME(Block, simplified, block);
    ASSERT_EQ(block->nstmts(), 2);

    IS_NODE_WITH_NAME(Store, block->front(), store1_);
    IS_NODE_WITH_NAME(Store, block->back(), store2_);

    ASSERT_EQ(store1->value(), store1_->value());
    ASSERT_EQ(store2->value(), store2_->value());
  }

  {
    // Flatten sub blocks with different depths.
    // { stmt1 , { { stmt2 } } } => { stmt1, stmt2 }
    BufHandle a("A", {1}, kInt);
    StorePtr store1 = Store::make(a, {0}, 1);
    StorePtr store2 = Store::make(a, {0}, 0);

    BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store2}));
    BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({block1}));

    BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({store1, block2}));
    StmtPtr simplified = IRSimplifier::simplify(enclosing);

    IS_NODE_WITH_NAME(Block, simplified, block);
    ASSERT_EQ(block->nstmts(), 2);

    IS_NODE_WITH_NAME(Store, block->front(), store1_);
    IS_NODE_WITH_NAME(Store, block->back(), store2_);

    ASSERT_EQ(store1->value(), store1_->value());
    ASSERT_EQ(store2->value(), store2_->value());
  }

  {
    // Flatten many layers around an empty block to an empty block.
    // Builds twelve nested Blocks (the innermost empty one plus eleven
    // wrappers) that should all collapse.
    StmtPtr last = alloc<Block>(std::vector<StmtPtr>({}));
    for (const auto i : c10::irange(11)) {
      (void)i; // Suppress unused variable warning
      last = alloc<Block>(std::vector<StmtPtr>({last}));
    }

    StmtPtr simplified = IRSimplifier::simplify(last);
    IS_NODE_WITH_NAME(Block, simplified, block);
    ASSERT_EQ(block->nstmts(), 0);
  }
}
3982
TEST(Simplify, SimplifyEliminateZeroLengthAlloc) {
  // An Allocate/Free pair for a buffer with a zero-sized dimension should be
  // removed by the simplifier. Allocate/Free for other buffers must survive.
  {
    // Simple positive case.
    BufHandle b("x", {0}, kInt);

    AllocatePtr alloc_ = Allocate::make(b);
    FreePtr free_ = Free::make(b);

    BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({alloc_, free_}));
    ASSERT_EQ(block1->nstmts(), 2);

    StmtPtr simplified = IRSimplifier::simplify(block1);
    IS_NODE_WITH_NAME(Block, simplified, block2);
    ASSERT_EQ(block2->nstmts(), 0);
  }

  {
    // Simple negative case: a non-zero size must be kept.
    BufHandle b("x", {2}, kInt);

    AllocatePtr alloc_ = Allocate::make(b);
    FreePtr free_ = Free::make(b);

    BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({alloc_, free_}));
    ASSERT_EQ(block1->nstmts(), 2);

    StmtPtr simplified = IRSimplifier::simplify(block1);
    IS_NODE_WITH_NAME(Block, simplified, block2);
    ASSERT_EQ(block2->nstmts(), 2);
  }

  {
    // Finds right Alloc/Free: only x (size 0) is removed, and the nested
    // Allocate/Free of y (size 2) remain, in order.
    BufHandle b1("x", {0}, kInt);
    BufHandle b2("y", {2}, kInt);

    AllocatePtr alloc1 = Allocate::make(b1);
    AllocatePtr alloc2 = Allocate::make(b2);
    FreePtr free2_ = Free::make(b2);
    FreePtr free1_ = Free::make(b1);

    BlockPtr block1 =
        alloc<Block>(std::vector<StmtPtr>({alloc1, alloc2, free2_, free1_}));
    ASSERT_EQ(block1->nstmts(), 4);

    StmtPtr simplified = IRSimplifier::simplify(block1);
    IS_NODE_WITH_NAME(Block, simplified, block2);
    ASSERT_EQ(block2->nstmts(), 2);
    IS_NODE_WITH_NAME(Allocate, block2->stmts().front(), simplified_alloc);
    IS_VAR_WITH_NAME(simplified_alloc->buffer_var(), "y");
    IS_NODE_WITH_NAME(Free, block2->stmts().back(), simplified_free);
    ASSERT_EQ(simplified_alloc->buffer_var(), simplified_free->buffer_var());
  }

  {
    // Dynamic shape: y's size depends on the symbolic z, so it cannot be
    // proven zero. Only two of the four statements should remain (presumably
    // y's Allocate/Free; which ones is not asserted here).
    VarHandle z("z", kInt);
    BufHandle b1("x", {0}, kInt);
    BufHandle b2("y", {z}, kInt);

    AllocatePtr alloc1 = Allocate::make(b1);
    AllocatePtr alloc2 = Allocate::make(b2);
    FreePtr free2_ = Free::make(b2);
    FreePtr free1_ = Free::make(b1);

    BlockPtr block1 =
        alloc<Block>(std::vector<StmtPtr>({alloc1, alloc2, free2_, free1_}));
    ASSERT_EQ(block1->nstmts(), 4);
    StmtPtr simplified = IRSimplifier::simplify(block1);
    IS_NODE_WITH_NAME(Block, simplified, block2);
    ASSERT_EQ(block2->nstmts(), 2);
  }
}
4056
4057TEST(Simplify, DontSimplifyRand) {
4058 {
4059 // rand() + rand() = rand() + rand() NOT 2 * rand().
4060 ExprHandle body =
4061 Intrinsics::make(kRand, kInt) + Intrinsics::make(kRand, kInt);
4062 ExprHandle simplified = IRSimplifier::simplify(body);
4063 IS_NODE_WITH_NAME(Add, simplified.node(), add);
4064 IS_RAND(add->lhs());
4065 IS_RAND(add->rhs());
4066 }
4067
4068 {
4069 // rand() - rand() = rand() - rand() NOT 0.
4070 ExprHandle body =
4071 Intrinsics::make(kRand, kFloat) - Intrinsics::make(kRand, kFloat);
4072 ExprHandle simplified = IRSimplifier::simplify(body);
4073 IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
4074 IS_RAND(sub->lhs());
4075 IS_RAND(sub->rhs());
4076 }
4077
4078 {
4079 // rand() * rand() = rand() * rand().
4080 ExprHandle body =
4081 Intrinsics::make(kRand, kInt) * Intrinsics::make(kRand, kInt);
4082 ExprHandle simplified = IRSimplifier::simplify(body);
4083 IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
4084 IS_RAND(mul->lhs());
4085 IS_RAND(mul->rhs());
4086 }
4087}
4088
TEST(Simplify, SimplifyReorderForCond) {
  // The simplifier may hoist a Cond out of an enclosing For when the condition
  // cannot change across iterations:
  //   for (...) { if (c) { S } }  =>  if (c) { for (...) { S } }
  // The cases below cover when that reordering is done and when it must not
  // be (condition uses the loop var, a buffer written in the loop, a Let
  // inside the loop, or the Cond has an else branch).
  BufHandle a("A", {4}, kInt);
  BufHandle b("B", {1}, kInt);
  BufHandle c("C", {4}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);

  {
    // for ( if ( ... ) ) => if ( for ( ... ) ).
    auto body = For::make(
        i,
        0,
        4,
        Cond::make(
            CompareSelect::make(j, 10, CompareSelectOperation::kLT),
            Store::make(c, {i}, Load::make(a, {i})),
            nullptr));

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Cond, simplified, cond);
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
    IS_NODE_WITH_NAME(For, true_block->front(), loop);
  }

  {
    // Can't reorder if condition is dependent on the loop var.
    auto body = For::make(
        i,
        0,
        4,
        Cond::make(
            CompareSelect::make(i, 2, CompareSelectOperation::kEQ),
            Store::make(c, {i}, Load::make(a, {i})),
            nullptr));

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(For, simplified, loop);
    IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
  }

  {
    // Can't reorder if condition is dependent on a var that is modified inside
    // the loop.
    auto body = For::make(
        i,
        0,
        4,
        Cond::make(
            CompareSelect::make(
                Load::make(c, {0}), 10, CompareSelectOperation::kLT),
            Store::make(c, {0}, Load::make(a, {i})),
            nullptr));

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(For, simplified, loop);
    IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
  }

  {
    // Condition based on buffer not referenced in body. Can reorder here.
    auto body = For::make(
        i,
        0,
        4,
        Cond::make(
            CompareSelect::make(
                Load::make(b, {0}), 10, CompareSelectOperation::kLT),
            Store::make(c, {0}, Load::make(a, {i})),
            nullptr));

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Cond, simplified, cond);
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
    IS_NODE_WITH_NAME(For, true_block->front(), loop);
  }

  {
    // Condition based on buffer read only in body. Can reorder here.
    auto body = For::make(
        i,
        0,
        4,
        Cond::make(
            CompareSelect::make(
                Load::make(a, {0}), 10, CompareSelectOperation::kLT),
            Store::make(c, {0}, Load::make(a, {i})),
            nullptr));

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Cond, simplified, cond);
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
    IS_NODE_WITH_NAME(For, true_block->front(), loop);
  }

  {
    // Condition depends on Let in the loop. Cannot reorder.
    auto body = For::make(
        i,
        0,
        4,
        Block::make(
            {Let::make(j, 3),
             Cond::make(
                 CompareSelect::make(j, 10, CompareSelectOperation::kLT),
                 Store::make(c, {0}, Load::make(a, {i})),
                 nullptr)}));

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(For, simplified, loop);
    IS_NODE_WITH_NAME(Let, loop->body()->front(), let);
    IS_NODE_WITH_NAME(Cond, loop->body()->back(), cond);
  }

  {
    // Multi level Ifs where all conditions are distinct. Move BOTH Cond
    // statements outside the loop.
    auto body = For::make(
        i,
        0,
        4,
        Cond::make(
            CompareSelect::make(
                Load::make(a, {0}), 10, CompareSelectOperation::kLT),
            Cond::make(
                CompareSelect::make(j, 10, CompareSelectOperation::kEQ),
                Store::make(c, {0}, Load::make(a, {i})),
                nullptr),
            nullptr));

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Cond, simplified, cond);
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
    IS_NODE_WITH_NAME(Cond, true_block->front(), cond2);
    IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_block2);
    IS_NODE_WITH_NAME(For, true_block2->front(), loop);
  }

  {
    // Multi level Ifs where the inner condition does depend on a loop var,
    // reorder only the first Cond.
    auto body = For::make(
        i,
        0,
        4,
        Cond::make(
            CompareSelect::make(
                Load::make(a, {0}), 10, CompareSelectOperation::kLT),
            Cond::make(
                CompareSelect::make(i, 3, CompareSelectOperation::kEQ),
                Store::make(c, {0}, Load::make(a, {i})),
                nullptr),
            nullptr));

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Cond, simplified, cond);
    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
    IS_NODE_WITH_NAME(For, true_block->front(), loop);
    IS_NODE_WITH_NAME(Block, loop->body(), loop_body);
    IS_NODE_WITH_NAME(Cond, loop_body->front(), cond2);
  }

  {
    // Don't reorder if there's an else block of the Cond.
    // We could, but is it much better?
    auto body = For::make(
        i,
        0,
        4,
        Cond::make(
            CompareSelect::make(j, 10, CompareSelectOperation::kLT),
            Store::make(c, {0}, Load::make(a, {i})),
            Store::make(c, {0}, 0)));

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(For, simplified, loop);
    IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
  }

  {
    // Condition uses distinct region of Tensor.
    // We could reorder here with better analysis, but we don't. Included for
    // completeness.
    auto body = For::make(
        i,
        0,
        4,
        Cond::make(
            CompareSelect::make(
                Load::make(c, {0}), 10, CompareSelectOperation::kLT),
            Store::make(c, {1}, Load::make(a, {i})),
            nullptr));

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(For, simplified, loop);
    IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
  }
}
4286
4287TEST(Simplify, SimplifyFuseConditions) {
4288 BufHandle a("A", {2}, kInt);
4289 BufHandle b("B", {2}, kInt);
4290 VarHandle i("i", kInt);
4291 VarHandle j("j", kInt);
4292
4293 {
4294 // Can fuse since the conditions are identical.
4295 // if (A) { X }; if (A) { Y }; => if (A) { X; Y }
4296 auto body = Block::make(
4297 {Cond::make(
4298 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4299 Store::make(a, {0}, i),
4300 nullptr),
4301 Cond::make(
4302 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4303 Store::make(a, {1}, i),
4304 nullptr)});
4305
4306 StmtPtr simplified = IRSimplifier::simplify(body);
4307 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
4308 IS_NODE_WITH_NAME(Block, simplified, block);
4309 ASSERT_EQ(block->nstmts(), 1);
4310 IS_NODE_WITH_NAME(Cond, block->front(), cond);
4311 IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4312 ASSERT_EQ(true_stmt->nstmts(), 2);
4313 ASSERT_EQ(cond->false_stmt(), nullptr);
4314 }
4315
4316 {
4317 // Can't fuse, conditions are not identical in lhs (i != j).
4318 auto body = Block::make(
4319 {Cond::make(
4320 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4321 Store::make(a, {0}, i),
4322 nullptr),
4323 Cond::make(
4324 CompareSelect::make(j, 10, CompareSelectOperation::kLT),
4325 Store::make(a, {1}, i),
4326 nullptr)});
4327
4328 StmtPtr simplified = IRSimplifier::simplify(body);
4329 IS_NODE_WITH_NAME(Block, simplified, block);
4330 ASSERT_EQ(block->nstmts(), 2);
4331 IS_NODE_WITH_NAME(Cond, block->front(), cond1);
4332 IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4333
4334 IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
4335 IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
4336 ASSERT_EQ(true_stmt1->nstmts(), 1);
4337 ASSERT_EQ(true_stmt2->nstmts(), 1);
4338
4339 ASSERT_EQ(cond1->false_stmt(), nullptr);
4340 ASSERT_EQ(cond2->false_stmt(), nullptr);
4341 }
4342 {
4343 // Can't fuse, conditions are not identical in rhs (10 != 11).
4344 auto body = Block::make(
4345 {Cond::make(
4346 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4347 Store::make(a, {0}, i),
4348 nullptr),
4349 Cond::make(
4350 CompareSelect::make(i, 11, CompareSelectOperation::kLT),
4351 Store::make(a, {1}, i),
4352 nullptr)});
4353
4354 StmtPtr simplified = IRSimplifier::simplify(body);
4355 IS_NODE_WITH_NAME(Block, simplified, block);
4356 ASSERT_EQ(block->nstmts(), 2);
4357 IS_NODE_WITH_NAME(Cond, block->front(), cond1);
4358 IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4359
4360 IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
4361 IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
4362 ASSERT_EQ(true_stmt1->nstmts(), 1);
4363 ASSERT_EQ(true_stmt2->nstmts(), 1);
4364
4365 ASSERT_EQ(cond1->false_stmt(), nullptr);
4366 ASSERT_EQ(cond2->false_stmt(), nullptr);
4367 }
4368
4369 {
4370 // Can't fuse, conditions are not identical in operation (LT vs GT).
4371 auto body = Block::make(
4372 {Cond::make(
4373 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4374 Store::make(a, {0}, i),
4375 nullptr),
4376 Cond::make(
4377 CompareSelect::make(i, 10, CompareSelectOperation::kGT),
4378 Store::make(a, {1}, i),
4379 nullptr)});
4380
4381 StmtPtr simplified = IRSimplifier::simplify(body);
4382 IS_NODE_WITH_NAME(Block, simplified, block);
4383 ASSERT_EQ(block->nstmts(), 2);
4384 IS_NODE_WITH_NAME(Cond, block->front(), cond1);
4385 IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4386
4387 IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
4388 IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
4389 ASSERT_EQ(true_stmt1->nstmts(), 1);
4390 ASSERT_EQ(true_stmt2->nstmts(), 1);
4391
4392 ASSERT_EQ(cond1->false_stmt(), nullptr);
4393 ASSERT_EQ(cond2->false_stmt(), nullptr);
4394 }
4395
4396 {
4397 // Can't fuse, CompareSelect results are different.
4398 // Actually we totally could if we normalized CompareSelect results, but
4399 // TODO for later.
4400 auto body = Block::make(
4401 {Cond::make(
4402 CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT),
4403 Store::make(a, {0}, i),
4404 nullptr),
4405 Cond::make(
4406 CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT),
4407 Store::make(a, {1}, i),
4408 nullptr)});
4409
4410 StmtPtr simplified = IRSimplifier::simplify(body);
4411 IS_NODE_WITH_NAME(Block, simplified, block);
4412 ASSERT_EQ(block->nstmts(), 2);
4413 IS_NODE_WITH_NAME(Cond, block->front(), cond1);
4414 IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4415
4416 IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
4417 IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
4418 ASSERT_EQ(true_stmt1->nstmts(), 1);
4419 ASSERT_EQ(true_stmt2->nstmts(), 1);
4420
4421 ASSERT_EQ(cond1->false_stmt(), nullptr);
4422 ASSERT_EQ(cond2->false_stmt(), nullptr);
4423 }
4424
4425 {
4426 // Can fuse with false stmt only.
4427 auto body = Block::make(
4428 {Cond::make(
4429 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4430 nullptr,
4431 Store::make(a, {0}, i)),
4432 Cond::make(
4433 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4434 nullptr,
4435 Store::make(a, {1}, i))});
4436
4437 StmtPtr simplified = IRSimplifier::simplify(body);
4438 IS_NODE_WITH_NAME(Block, simplified, block);
4439 ASSERT_EQ(block->nstmts(), 1);
4440 IS_NODE_WITH_NAME(Cond, block->front(), cond);
4441 IS_NODE_WITH_NAME(Block, cond->false_stmt(), false_stmt);
4442 ASSERT_EQ(false_stmt->nstmts(), 2);
4443 ASSERT_EQ(cond->true_stmt(), nullptr);
4444 }
4445
4446 {
4447 // Can fuse with both true and false stmt.
4448 auto body = Block::make(
4449 {Cond::make(
4450 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4451 Store::make(a, {0}, i),
4452 Store::make(b, {0}, i)),
4453 Cond::make(
4454 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4455 Store::make(a, {1}, i),
4456 Store::make(b, {1}, i))});
4457
4458 StmtPtr simplified = IRSimplifier::simplify(body);
4459 IS_NODE_WITH_NAME(Block, simplified, block);
4460 ASSERT_EQ(block->nstmts(), 1);
4461 IS_NODE_WITH_NAME(Cond, block->front(), cond);
4462 IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4463 ASSERT_EQ(true_stmt->nstmts(), 2);
4464 IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt);
4465 ASSERT_EQ(false_stmt->nstmts(), 2);
4466 }
4467
4468 {
4469 // Can fuse with mismatched true / false stmt existing
4470 auto body = Block::make(
4471 {Cond::make(
4472 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4473 Store::make(a, {0}, i),
4474 nullptr),
4475 Cond::make(
4476 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4477 nullptr,
4478 Store::make(b, {1}, i))});
4479
4480 StmtPtr simplified = IRSimplifier::simplify(body);
4481 IS_NODE_WITH_NAME(Block, simplified, block);
4482 ASSERT_EQ(block->nstmts(), 1);
4483 IS_NODE_WITH_NAME(Cond, block->front(), cond);
4484 IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4485 ASSERT_EQ(true_stmt->nstmts(), 1);
4486 IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt);
4487 ASSERT_EQ(false_stmt->nstmts(), 1);
4488 }
4489
4490 {
4491 // Can fuse partial block contents, ie when there are non fused stmts before
4492 // and after.
4493 // before:
4494 // if (j < 10) { A[0] = j; }
4495 // if (i < 10) { A[0] = i; }
4496 // if (i < 10) { A[1] = i; }
4497 // if (i < 11) { A[1] = j; }
4498 //
4499 // after:
4500 //
4501 // if (j < 10) { A[0] = j; }
4502 // if (i < 10) {
4503 // A[0] = i;
4504 // A[1] = i;
4505 // }
4506 // if (i < 11) { A[1] = j; }
4507
4508 auto body = Block::make({
4509 Cond::make(
4510 CompareSelect::make(j, 10, CompareSelectOperation::kLT),
4511 Store::make(a, {0}, j),
4512 nullptr),
4513 Cond::make(
4514 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4515 Store::make(a, {0}, i),
4516 nullptr),
4517 Cond::make(
4518 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4519 Store::make(a, {1}, i),
4520 nullptr),
4521 Cond::make(
4522 CompareSelect::make(i, 11, CompareSelectOperation::kLT),
4523 Store::make(a, {1}, j),
4524 nullptr),
4525 });
4526 StmtPtr simplified = IRSimplifier::simplify(body);
4527 IS_NODE_WITH_NAME(Block, simplified, block);
4528 ASSERT_EQ(block->nstmts(), 3);
4529 auto it = block->begin();
4530 it++;
4531 IS_NODE_WITH_NAME(Cond, *it, cond);
4532 IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4533 ASSERT_EQ(true_stmt->nstmts(), 2);
4534 ASSERT_EQ(cond->false_stmt(), nullptr);
4535 }
4536
4537 {
4538 // Can fuse longer sequences of identical conditions.
4539 auto body = Block::make({
4540 Cond::make(
4541 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4542 Store::make(a, {0}, j),
4543 nullptr),
4544 Cond::make(
4545 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4546 Store::make(a, {0}, i),
4547 nullptr),
4548 Cond::make(
4549 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4550 Store::make(a, {1}, i),
4551 nullptr),
4552 Cond::make(
4553 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4554 Store::make(a, {1}, j),
4555 nullptr),
4556 });
4557 StmtPtr simplified = IRSimplifier::simplify(body);
4558 IS_NODE_WITH_NAME(Block, simplified, block);
4559 ASSERT_EQ(block->nstmts(), 1);
4560 IS_NODE_WITH_NAME(Cond, block->front(), cond);
4561 IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4562 ASSERT_EQ(true_stmt->nstmts(), 4);
4563 ASSERT_EQ(cond->false_stmt(), nullptr);
4564 }
4565
4566 {
4567 // Can't fuse through a non condition.
4568 auto body = Block::make({
4569 Cond::make(
4570 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4571 Store::make(a, {0}, j),
4572 nullptr),
4573 Cond::make(
4574 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4575 Store::make(a, {0}, i),
4576 nullptr),
4577 Store::make(b, {1}, i + j),
4578 Cond::make(
4579 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4580 Store::make(a, {1}, i),
4581 nullptr),
4582 Cond::make(
4583 CompareSelect::make(i, 10, CompareSelectOperation::kLT),
4584 Store::make(a, {1}, j),
4585 nullptr),
4586 });
4587 StmtPtr simplified = IRSimplifier::simplify(body);
4588 IS_NODE_WITH_NAME(Block, simplified, block);
4589 ASSERT_EQ(block->nstmts(), 3);
4590 IS_NODE_WITH_NAME(Cond, block->front(), cond);
4591 IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4592 ASSERT_EQ(true_stmt->nstmts(), 2);
4593 ASSERT_EQ(cond->false_stmt(), nullptr);
4594
4595 IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4596 IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt2);
4597 ASSERT_EQ(true_stmt2->nstmts(), 2);
4598 ASSERT_EQ(cond2->false_stmt(), nullptr);
4599
4600 auto it = block->begin();
4601 it++;
4602 IS_NODE_WITH_NAME(Store, *it, middle);
4603 }
4604
4605 {
4606 // Can fuse if the conditions simplify to the same thing.
4607 auto body = Block::make(
4608 {Cond::make(
4609 CompareSelect::make(
4610 i * 2,
4611 ExprHandle(87) % ExprHandle(11),
4612 CompareSelectOperation::kLT),
4613 Store::make(a, {0}, i),
4614 nullptr),
4615 Cond::make(
4616 CompareSelect::make(
4617 i * 2,
4618 ExprHandle(300) / ExprHandle(30),
4619 CompareSelectOperation::kLT),
4620 Store::make(a, {1}, i),
4621 nullptr)});
4622 StmtPtr simplified = IRSimplifier::simplify(body);
4623 IS_NODE_WITH_NAME(Block, simplified, block);
4624 ASSERT_EQ(block->nstmts(), 1);
4625 IS_NODE_WITH_NAME(Cond, block->front(), cond);
4626 IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4627 ASSERT_EQ(true_stmt->nstmts(), 2);
4628 ASSERT_EQ(cond->false_stmt(), nullptr);
4629 }
4630
4631 {
4632 // Can fuse non-CompareSelects.
4633 // if (i) { X } if (i) { Y } => if (i) { X; Y }
4634 auto body = Block::make(
4635 {Cond::make(i, Store::make(a, {0}, i), nullptr),
4636 Cond::make(i, Store::make(a, {1}, i), nullptr)});
4637
4638 StmtPtr simplified = IRSimplifier::simplify(body);
4639 IS_NODE_WITH_NAME(Block, simplified, block);
4640 ASSERT_EQ(block->nstmts(), 1);
4641 IS_NODE_WITH_NAME(Cond, block->front(), cond);
4642 IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
4643 ASSERT_EQ(true_stmt->nstmts(), 2);
4644 ASSERT_EQ(cond->false_stmt(), nullptr);
4645 }
4646
4647 {
4648 // Sanity check wont fuse different non-CompareSelects.
4649 auto body = Block::make(
4650 {Cond::make(i, Store::make(a, {0}, i), nullptr),
4651 Cond::make(j, Store::make(a, {1}, i), nullptr)});
4652
4653 StmtPtr simplified = IRSimplifier::simplify(body);
4654 IS_NODE_WITH_NAME(Block, simplified, block);
4655 ASSERT_EQ(block->nstmts(), 2);
4656 IS_NODE_WITH_NAME(Cond, block->front(), cond1);
4657 IS_NODE_WITH_NAME(Cond, block->back(), cond2);
4658 }
4659
4660 {
4661 // Sanity check constant condition elimination still occurs when merging is
4662 // possible.
4663 auto body = Block::make(
4664 {Cond::make(1, Store::make(a, {0}, i), nullptr),
4665 Cond::make(1, Store::make(a, {1}, i), nullptr)});
4666 StmtPtr simplified = IRSimplifier::simplify(body);
4667 IS_NODE_WITH_NAME(Block, simplified, block);
4668 ASSERT_EQ(block->nstmts(), 2);
4669 IS_NODE_WITH_NAME(Store, block->front(), store1);
4670 IS_NODE_WITH_NAME(Store, block->back(), store2);
4671 }
4672
4673 {
4674 // Sanity check for-cond reordering occurs after fusing.
4675 auto body = For::make(
4676 i,
4677 0,
4678 4,
4679 Block::make(
4680 {Cond::make(
4681 CompareSelect::make(j, 10, CompareSelectOperation::kLT),
4682 Store::make(a, {1}, Load::make(b, {0})),
4683 nullptr),
4684 Cond::make(
4685 CompareSelect::make(j, 10, CompareSelectOperation::kLT),
4686 Store::make(a, {2}, Load::make(b, {0})),
4687 nullptr)}));
4688
4689 StmtPtr simplified = IRSimplifier::simplify(body);
4690 IS_NODE_WITH_NAME(Cond, simplified, cond);
4691 IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
4692 IS_NODE_WITH_NAME(For, true_block->front(), loop);
4693 }
4694}
4695
TEST(Simplify, SimplifySyncThreads) {
  // Expected SyncThreads clean-up in a Block:
  //  - runs of consecutive SyncThreads between statements collapse to one;
  //  - SyncThreads at the very start or end of the Block are removed.
  BufHandle a("A", {4}, kInt);
  VarHandle i("i", kInt);

  {
    // Merge two inner SyncThreads.
    auto body = Block::make(
        // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
        {Store::make(a, {0}, 1),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         Store::make(a, {1}, 0)});
    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Block, simplified, block);
    ASSERT_EQ(block->nstmts(), 3);
    // Walk the Block in order: Store, SyncThreads, Store.
    auto it = block->begin();
    IS_NODE(Store, *it++);
    IS_NODE(SyncThreads, *it++);
    IS_NODE(Store, *it++);
  }

  {
    // Eliminate outer SyncThreads.
    auto body = Block::make(
        {alloc<SyncThreads>(), Store::make(a, {1}, 0), alloc<SyncThreads>()});

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Block, simplified, block);
    ASSERT_EQ(block->nstmts(), 1);
    auto it = block->begin();
    IS_NODE(Store, *it);
  }

  {
    // Merge many inner SyncThreads.
    auto body = Block::make(
        {Store::make(a, {0}, 1),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         Store::make(a, {1}, 0)});

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Block, simplified, block);
    ASSERT_EQ(block->nstmts(), 3);
    auto it = block->begin();
    IS_NODE(Store, *it++);
    IS_NODE(SyncThreads, *it++);
    IS_NODE(Store, *it++);
  }

  {
    // Merge multiple outer SyncThreads.
    auto body = Block::make(
        {alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         Store::make(a, {1}, 0),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>()});

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Block, simplified, block);
    ASSERT_EQ(block->nstmts(), 1);
    auto it = block->begin();
    IS_NODE(Store, *it);
  }

  {
    // Merge multiple sections: each inner run collapses independently.
    auto body = Block::make(
        {Store::make(a, {0}, 1),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         Store::make(a, {1}, 0),
         Store::make(a, {2}, 0),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         alloc<SyncThreads>(),
         Store::make(a, {3}, 0)});

    StmtPtr simplified = IRSimplifier::simplify(body);
    IS_NODE_WITH_NAME(Block, simplified, block);
    ASSERT_EQ(block->nstmts(), 6);
    auto it = block->begin();
    IS_NODE(Store, *it++);
    IS_NODE(SyncThreads, *it++);
    IS_NODE(Store, *it++);
    IS_NODE(Store, *it++);
    IS_NODE(SyncThreads, *it++);
    IS_NODE(Store, *it++);
  }
}
4792
4793TEST(Simplify, SimplifyRampSubBroadcast) {
4794 int num_lanes = 4;
4795 ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes);
4796 ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes);
4797 ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast);
4798 RampPtr newRamp = simplified.AsNode<Ramp>();
4799 IS_NODE_WITH_NAME(IntImm, newRamp->base(), base);
4800 ASSERT_EQ(base->value(), 5);
4801 IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride);
4802 ASSERT_EQ(stride->value(), 6);
4803 ASSERT_EQ(newRamp->lanes(), num_lanes);
4804}
4805
4806TEST(Simplify, SimplifyBroadcastTermExpander) {
4807 int num_lanes = 8;
4808 ExprHandle bc0 = Broadcast::make(ExprHandle(0), num_lanes);
4809 ExprHandle bc1 = Broadcast::make(ExprHandle(1), num_lanes);
4810 ExprHandle bc2 = Broadcast::make(ExprHandle(2), num_lanes);
4811 // NB: We need a term in the middle which isn't simplified to trigger the
4812 // relevant path in TermExpander::mutate. The two bc1 terms are brought
4813 // together and simplified to 2 * bc1, which then needs to make 2 multi-lane.
4814 ExprHandle simplified = IRSimplifier::simplify(bc1 + (bc0 / bc2) + bc1);
4815 BufHandle buf("buf", {num_lanes}, kInt);
4816 // The result isn't fully simplified currently and thus would be brittle to
4817 // match. Observe its value instead.
4818 auto store = Store::make(buf, {Ramp::make(0, 1, num_lanes)}, simplified);
4819 SimpleIREvaluator eval(store, {buf});
4820 std::vector<int> output(num_lanes);
4821 eval(output);
4822 for (const auto i : c10::irange(num_lanes)) {
4823 ASSERT_EQ(output[i], 2);
4824 }
4825}
4826
4827TEST(Simplify, CompareSelectLoopBounds) {
4828 constexpr int N = 8;
4829 BufHandle b("b", {N}, kFloat);
4830 VarHandle n("n", kInt);
4831 VarHandle m("m", kInt);
4832 VarHandle var_N("var_N", kInt);
4833 VarHandle var_M("var_M", kInt);
4834
4835 auto test_case_fn = [](const VarHandle& n,
4836 const BufHandle& b,
4837 const ExprHandle& start,
4838 const ExprHandle& stop,
4839 const int& cmp_val,
4840 const CompareSelectOperation& cmp_op,
4841 const std::string& check_string) {
4842 StmtPtr s = For::make(
4843 n,
4844 start,
4845 stop,
4846 b.store({n}, CompareSelect::make(n, cmp_val, 0.f, 1.0f, cmp_op)));
4847 s = IRSimplifier::simplify(s);
4848 std::ostringstream oss;
4849 oss << *s;
4850 std::string target_string = "# CHECK: ";
4851 target_string += check_string;
4852 torch::jit::testing::FileCheck().run(target_string, oss.str());
4853 };
4854
4855 auto test_case_nest_loops_fn = [](const VarHandle& n,
4856 const VarHandle& m,
4857 const BufHandle& b,
4858 const ExprHandle& n_start,
4859 const ExprHandle& n_stop,
4860 const ExprHandle& m_start,
4861 const ExprHandle& m_stop,
4862 const CompareSelectOperation& cmp_op,
4863 const std::string& check_string) {
4864 StmtPtr s = For::make(
4865 m,
4866 m_start,
4867 m_stop,
4868 b.store({n, m}, CompareSelect::make(n, m, 0.f, 1.0f, cmp_op)));
4869 StmtPtr root_s = For::make(n, n_start, n_stop, s);
4870 root_s = IRSimplifier::simplify(root_s);
4871 std::ostringstream oss;
4872 oss << *root_s;
4873 std::string target_string = "# CHECK: ";
4874 target_string += check_string;
4875 torch::jit::testing::FileCheck().run(target_string, oss.str());
4876 };
4877
4878 // Before:
4879 // for (const auto n : c10::irange(1, N)) {
4880 // b[n] = n < 1 ? 0.f : 1.f;
4881 // }
4882 // After:
4883 // for (const auto n : c10::irange(1, N)) {
4884 // b[n] = 1.f;
4885 // }
4886 test_case_fn(n, b, 1, N, 1, kLT, "b[n] = 1.f;");
4887
4888 // Before:
4889 // for (const auto n : c10::irange(1, N)) {
4890 // b[n] = n <= 1 ? 0.f : 1.f;
4891 // }
4892 // After:
4893 // for (const auto n : c10::irange(1, N)) {
4894 // b[n] = n <= 1 ? 0.f : 1.f;
4895 // }
4896 test_case_fn(n, b, 1, N, 1, kLE, "b[n] = n<=1 ? 0.f : 1.f;");
4897
4898 // Before:
4899 // for (const auto n : c10::irange(1, N)) {
4900 // b[n] = n <= 0 ? 0.f : 1.f;
4901 // }
4902 // After:
4903 // for (const auto n : c10::irange(1, N)) {
4904 // b[n] = 1.f;
4905 // }
4906 test_case_fn(n, b, 1, N, 0, kLE, "b[n] = 1.f;");
4907
4908 // Before:
4909 // for (const auto n : c10::irange(1, N)) {
4910 // b[n] = n < 0 ? 0.f : 1.f;
4911 // }
4912 // After:
4913 // for (const auto n : c10::irange(1, N)) {
4914 // b[n] = 1.f;
4915 // }
4916 test_case_fn(n, b, 1, N, 0, kLT, "b[n] = 1.f;");
4917
4918 // Before:
4919 // for (const auto n : c10::irange(1, N)) {
4920 // b[n] = n < 8 ? 0.f : 1.f;
4921 // }
4922 // After:
4923 // for (const auto n : c10::irange(1, N)) {
4924 // b[n] = 0.f;
4925 // }
4926 test_case_fn(n, b, 1, N, N, kLT, "b[n] = 0.f;");
4927
4928 // Before:
4929 // for (const auto n : c10::irange(1, N)) {
4930 // b[n] = n <= 7 ? 0.f : 1.f;
4931 // }
4932 // After:
4933 // for (const auto n : c10::irange(1, N)) {
4934 // b[n] = 0.f;
4935 // }
4936 test_case_fn(n, b, 1, N, N - 1, kLE, "b[n] = 0.f;");
4937
4938 // Before:
4939 // for (const auto n : c10::irange(1, N)) {
4940 // b[n] = n <= 8 ? 0.f : 1.f;
4941 // }
4942 // After:
4943 // for (const auto n : c10::irange(1, N)) {
4944 // b[n] = 0.f;
4945 // }
4946 test_case_fn(n, b, 1, N, N, kLE, "b[n] = 0.f;");
4947
4948 // Before:
4949 // for (const auto n : c10::irange(1, N)) {
4950 // b[n] = n < 7 ? 0.f : 1.f;
4951 // }
4952 // After:
4953 // for (const auto n : c10::irange(1, N)) {
4954 // b[n] = n < 7 ? 0.f : 1.f;
4955 // }
4956 test_case_fn(n, b, 1, N, N - 1, kLT, "b[n] = n<7 ? 0.f : 1.f;");
4957
4958 // Before:
4959 // for (const auto n : c10::irange(1, N)) {
4960 // b[n] = n > 0 ? 0.f : 1.f;
4961 // }
4962 // After:
4963 // for (const auto n : c10::irange(1, N)) {
4964 // b[n] = 0.f;
4965 // }
4966 test_case_fn(n, b, 1, N, 0, kGT, "b[n] = 0.f;");
4967
4968 // Before:
4969 // for (const auto n : c10::irange(1, N)) {
4970 // b[n] = n > 1 ? 0.f : 1.f;
4971 // }
4972 // After:
4973 // for (const auto n : c10::irange(1, N)) {
4974 // b[n] = n > 1 ? 0.f : 1.f;
4975 // }
4976 test_case_fn(n, b, 1, N, 1, kGT, "b[n] = n>1 ? 0.f : 1.f;");
4977
4978 // Before:
4979 // for (const auto n : c10::irange(1, N)) {
4980 // b[n] = n >= 1 ? 0.f : 1.f;
4981 // }
4982 // After:
4983 // for (const auto n : c10::irange(1, N)) {
4984 // b[n] = 0.f;
4985 // }
4986 test_case_fn(n, b, 1, N, 1, kGE, "b[n] = 0.f;");
4987
4988 // Before:
4989 // for (const auto n : c10::irange(1, N)) {
4990 // b[n] = n > 7 ? 0.f : 1.f;
4991 // }
4992 // After:
4993 // for (const auto n : c10::irange(1, N)) {
4994 // b[n] = 1.f;
4995 // }
4996 test_case_fn(n, b, 1, N, N - 1, kGT, "b[n] = 1.f;");
4997
4998 // Before:
4999 // for (const auto n : c10::irange(1, N)) {
5000 // b[n] = n >= 7 ? 0.f : 1.f;
5001 // }
5002 // After:
5003 // for (const auto n : c10::irange(1, N)) {
5004 // b[n] = n >= 7 ? 0.f : 1.f;
5005 // }
5006 test_case_fn(n, b, 1, N, N - 1, kGE, "b[n] = n>=7 ? 0.f : 1.f;");
5007
5008 // Before:
5009 // for (const auto n : c10::irange(1, N)) {
5010 // b[n] = n > 5 ? 0.f : 1.f;
5011 // }
5012 // After:
5013 // for (const auto n : c10::irange(1, N)) {
5014 // b[n] = n > 5 ? 0.f : 1.f;
5015 // }
5016 test_case_fn(n, b, 1, N, 5, kGT, "b[n] = n>5 ? 0.f : 1.f;");
5017
5018 // Before:
5019 // for (const auto n : c10::irange(1, N)) {
5020 // b[n] = n >= 5 ? 0.f : 1.f;
5021 // }
5022 // After:
5023 // for (const auto n : c10::irange(1, N)) {
5024 // b[n] = n >= 5 ? 0.f : 1.f;
5025 // }
5026 test_case_fn(n, b, 1, N, 5, kGE, "b[n] = n>=5 ? 0.f : 1.f;");
5027
5028 // Before:
5029 // for (const auto n : c10::irange(1, N)) {
5030 // b[n] = n > 8 ? 0.f : 1.f;
5031 // }
5032 // After:
5033 // for (const auto n : c10::irange(1, N)) {
5034 // b[n] = 1.f;
5035 // }
5036 test_case_fn(n, b, 1, N, N, kGT, "b[n] = 1.f;");
5037
5038 // Before:
5039 // for (const auto n : c10::irange(1, N)) {
5040 // b[n] = n >= 8 ? 0.f : 1.f;
5041 // }
5042 // After:
5043 // for (const auto n : c10::irange(1, N)) {
5044 // b[n] = 1.f;
5045 // }
5046 test_case_fn(n, b, 1, N, N, kGE, "b[n] = 1.f;");
5047
5048 // Before:
5049 // for (const auto n : c10::irange(1, 2)) {
5050 // b[n] = n == 1 ? 0.f : 1.f;
5051 // }
5052 // After:
5053 // for (const auto n : c10::irange(1, 2)) {
5054 // b[1] = 0.f;
5055 // }
5056 test_case_fn(n, b, 1, 2, 1, kEQ, "b[1] = 0.f;");
5057
5058 // Before:
5059 // for (const auto n : c10::irange(1, N)) {
5060 // b[n] = n == 1 ? 0.f : 1.f;
5061 // }
5062 // After:
5063 // for (const auto n : c10::irange(1, N)) {
5064 // b[n] = n == 1 ? 0.f : 1.f;
5065 // }
5066 test_case_fn(n, b, 1, N, 1, kEQ, "b[n] = n==1 ? 0.f : 1.f;");
5067
5068 // Before:
5069 // for (const auto n : c10::irange(1, N)) {
5070 // b[n] = n == 0 ? 0.f : 1.f;
5071 // }
5072 // After:
5073 // for (const auto n : c10::irange(1, N)) {
5074 // b[n] = 1.f;
5075 // }
5076 test_case_fn(n, b, 1, N, 0, kEQ, "b[n] = 1.f;");
5077
5078 // Before:
5079 // for (const auto n : c10::irange(1, N)) {
5080 // b[n] = n == 7 ? 0.f : 1.f;
5081 // }
5082 // After:
5083 // for (const auto n : c10::irange(1, N)) {
5084 // b[n] = n == 7 ? 0.f : 1.f;
5085 // }
5086 test_case_fn(n, b, 1, N, N - 1, kEQ, "b[n] = n==7 ? 0.f : 1.f;");
5087
5088 // Before:
5089 // for (const auto n : c10::irange(1, N)) {
5090 // b[n] = n == 8 ? 0.f : 1.f;
5091 // }
5092 // After:
5093 // for (const auto n : c10::irange(1, N)) {
5094 // b[n] = 1.f;
5095 // }
5096 test_case_fn(n, b, 1, N, N, kEQ, "b[n] = 1.f;");
5097
5098 // Before:
5099 // for (const auto n : c10::irange(1, N)) {
5100 // b[n] = n != 1 ? 0.f : 1.f;
5101 // }
5102 // After:
5103 // for (const auto n : c10::irange(1, N)) {
5104 // b[n] = n != 1 ? 0.f : 1.f;
5105 // }
5106 test_case_fn(n, b, 1, N, 1, kNE, "b[n] = n!=1 ? 0.f : 1.f;");
5107
5108 // Before:
5109 // for (const auto n : c10::irange(1, N)) {
5110 // b[n] = n != 7 ? 0.f : 1.f;
5111 // }
5112 // After:
5113 // for (const auto n : c10::irange(1, N)) {
5114 // b[n] = n != 7 ? 0.f : 1.f;
5115 // }
5116 test_case_fn(n, b, 1, N, N - 1, kNE, "b[n] = n!=7 ? 0.f : 1.f;");
5117
5118 // Before:
5119 // for (const auto n : c10::irange(1, N)) {
5120 // b[n] = n != 5 ? 0.f : 1.f;
5121 // }
5122 // After:
5123 // for (const auto n : c10::irange(1, N)) {
5124 // b[n] = n != 5 ? 0.f : 1.f;
5125 // }
5126 test_case_fn(n, b, 1, N, 5, kNE, "b[n] = n!=5 ? 0.f : 1.f;");
5127
5128 // Before:
5129 // for (const auto n : c10::irange(1, N)) {
5130 // b[n] = n != 0 ? 0.f : 1.f;
5131 // }
5132 // After:
5133 // for (const auto n : c10::irange(1, N)) {
5134 // b[n] = 0.f;
5135 // }
5136 test_case_fn(n, b, 1, N, 0, kNE, "b[n] = 0.f;");
5137
5138 // Before:
5139 // for (const auto n : c10::irange(1, N)) {
5140 // b[n] = n != 8 ? 0.f : 1.f;
5141 // }
5142 // After:
5143 // for (const auto n : c10::irange(1, N)) {
5144 // b[n] = 0.f;
5145 // }
5146 test_case_fn(n, b, 1, N, N, kNE, "b[n] = 0.f;");
5147
5148 // Before:
5149 // for (const auto n : c10::irange(10, 20)) {
5150 // for(const auto m : c10::irange(30, 40)) {
5151 // b[n, m] = (n != m) ? 0.f : 1.f;
5152 // }
5153 // }
5154 // After:
5155 // for (const auto n : c10::irange(10, 20)) {
5156 // for(const auto m : c10::irange(30, 40)) {
5157 // b[n, m] = 0.f;
5158 // }
5159 // }
5160 test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kNE, "b[n, m] = 0.f;");
5161 test_case_nest_loops_fn(
5162 n,
5163 m,
5164 b,
5165 var_N + 10,
5166 var_N + 20,
5167 var_N + 30,
5168 var_N + 40,
5169 kNE,
5170 "b[n, m] = 0.f;");
5171 test_case_nest_loops_fn(
5172 n,
5173 m,
5174 b,
5175 var_N + 10,
5176 var_N + 20,
5177 var_M + 30,
5178 var_M + 40,
5179 kNE,
5180 "b[n, m] = n!=m ? 0.f : 1.f;");
5181
5182 // Before:
5183 // for (const auto n : c10::irange(30, 40)) {
5184 // for(const auto m : c10::irange(10, 20)) {
5185 // b[n, m] = (n != m) ? 0.f : 1.f;
5186 // }
5187 // }
5188 // After:
5189 // for (const auto n : c10::irange(30, 40)) {
5190 // for(const auto m : c10::irange(10, 20)) {
5191 // b[n, m] = 0.f;
5192 // }
5193 // }
5194 test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kNE, "b[n, m] = 0.f;");
5195 test_case_nest_loops_fn(
5196 n,
5197 m,
5198 b,
5199 var_N + 30,
5200 var_N + 40,
5201 var_N + 10,
5202 var_N + 20,
5203 kNE,
5204 "b[n, m] = 0.f;");
5205 test_case_nest_loops_fn(
5206 n,
5207 m,
5208 b,
5209 var_N + 30,
5210 var_N + 40,
5211 var_M + 10,
5212 var_M + 20,
5213 kNE,
5214 "b[n, m] = n!=m ? 0.f : 1.f;");
5215
5216 // Before:
5217 // for (const auto n : c10::irange(30, 40)) {
5218 // for(const auto m : c10::irange(10, 31)) {
5219 // b[n, m] = (n != m) ? 0.f : 1.f;
5220 // }
5221 // }
5222 // After:
5223 // for (const auto n : c10::irange(30, 40)) {
5224 // for(const auto m : c10::irange(10, 31)) {
5225 // b[n, m] = (n != m) ? 0.f : 1.f;
5226 // }
5227 // }
5228 test_case_nest_loops_fn(
5229 n, m, b, 30, 40, 10, 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;");
5230 test_case_nest_loops_fn(
5231 n,
5232 m,
5233 b,
5234 var_N + 30,
5235 var_N + 40,
5236 var_N + 10,
5237 var_N + 31,
5238 kNE,
5239 "b[n, m] = n!=m ? 0.f : 1.f;");
5240 test_case_nest_loops_fn(
5241 n,
5242 m,
5243 b,
5244 var_N + 30,
5245 var_N + 40,
5246 var_M + 10,
5247 var_M + 31,
5248 kNE,
5249 "b[n, m] = n!=m ? 0.f : 1.f;");
5250
5251 // Before:
5252 // for (const auto n : c10::irange(10, 31)) {
5253 // for(const auto m : c10::irange(30, 40)) {
5254 // b[n, m] = (n != m) ? 0.f : 1.f;
5255 // }
5256 // }
5257 // After:
5258 // for (const auto n : c10::irange(10, 31)) {
5259 // for(const auto m : c10::irange(30, 40)) {
5260 // b[n, m] = (n != m) ? 0.f : 1.f;
5261 // }
5262 // }
5263 test_case_nest_loops_fn(
5264 n, m, b, 10, 31, 30, 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;");
5265 test_case_nest_loops_fn(
5266 n,
5267 m,
5268 b,
5269 var_N + 10,
5270 var_N + 31,
5271 var_N + 30,
5272 var_N + 40,
5273 kNE,
5274 "b[n, m] = n!=m ? 0.f : 1.f;");
5275 test_case_nest_loops_fn(
5276 n,
5277 m,
5278 b,
5279 var_N + 10,
5280 var_N + 31,
5281 var_M + 30,
5282 var_M + 40,
5283 kNE,
5284 "b[n, m] = n!=m ? 0.f : 1.f;");
5285
5286 // Before:
5287 // for (const auto n : c10::irange(10, 20)) {
5288 // for(const auto m : c10::irange(30, 40)) {
5289 // b[n, m] = (n < m) ? 0.f : 1.f;
5290 // }
5291 // }
5292 // After:
5293 // for (const auto n : c10::irange(10, 20)) {
5294 // for(const auto m : c10::irange(30, 40)) {
5295 // b[n, m] = 0.f;
5296 // }
5297 // }
5298 test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kLT, "b[n, m] = 0.f;");
5299 test_case_nest_loops_fn(
5300 n,
5301 m,
5302 b,
5303 var_N + 10,
5304 var_N + 20,
5305 var_N + 30,
5306 var_N + 40,
5307 kLT,
5308 "b[n, m] = 0.f;");
5309 test_case_nest_loops_fn(
5310 n,
5311 m,
5312 b,
5313 var_N + 10,
5314 var_N + 20,
5315 var_M + 30,
5316 var_M + 40,
5317 kLT,
5318 "b[n, m] = n<m ? 0.f : 1.f;");
5319
5320 // Before:
5321 // for (const auto n : c10::irange(30, 40)) {
5322 // for(const auto m : c10::irange(10, 31)) {
5323 // b[n, m] = (n < m) ? 0.f : 1.f;
5324 // }
5325 // }
5326 // After:
5327 // for (const auto n : c10::irange(30, 40)) {
5328 // for(const auto m : c10::irange(10, 31)) {
5329 // b[n, m] = 1.f;
5330 // }
5331 // }
5332 test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kLT, "b[n, m] = 1.f;");
5333 test_case_nest_loops_fn(
5334 n,
5335 m,
5336 b,
5337 var_N + 30,
5338 var_N + 40,
5339 var_N + 10,
5340 var_N + 31,
5341 kLT,
5342 "b[n, m] = 1.f;");
5343 test_case_nest_loops_fn(
5344 n,
5345 m,
5346 b,
5347 var_N + 30,
5348 var_N + 40,
5349 var_M + 10,
5350 var_M + 31,
5351 kLT,
5352 "b[n, m] = n<m ? 0.f : 1.f;");
5353
5354 // Before:
5355 // for (const auto n : c10::irange(30, 40)) {
5356 // for(const auto m : c10::irange(10, 20)) {
5357 // b[n, m] = (n > m) ? 0.f : 1.f;
5358 // }
5359 // }
5360 // After:
5361 // for (const auto n : c10::irange(30, 40)) {
5362 // for(const auto m : c10::irange(10, 20)) {
5363 // b[n, m] = 0.f;
5364 // }
5365 // }
5366 test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kGT, "b[n, m] = 0.f;");
5367 test_case_nest_loops_fn(
5368 n,
5369 m,
5370 b,
5371 var_N + 30,
5372 var_N + 40,
5373 var_N + 10,
5374 var_N + 20,
5375 kGT,
5376 "b[n, m] = 0.f;");
5377 test_case_nest_loops_fn(
5378 n,
5379 m,
5380 b,
5381 var_N + 30,
5382 var_N + 40,
5383 var_M + 10,
5384 var_M + 20,
5385 kGT,
5386 "b[n, m] = n>m ? 0.f : 1.f;");
5387
5388 // Before:
5389 // for (const auto n : c10::irange(10, 31)) {
5390 // for(const auto m : c10::irange(30, 40)) {
5391 // b[n, m] = (n > m) ? 0.f : 1.f;
5392 // }
5393 // }
5394 // After:
5395 // for (const auto n : c10::irange(10, 31)) {
5396 // for(const auto m : c10::irange(30, 40)) {
5397 // b[n, m] = 1.f;
5398 // }
5399 // }
5400 test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kGT, "b[n, m] = 1.f;");
5401 test_case_nest_loops_fn(
5402 n,
5403 m,
5404 b,
5405 var_N + 10,
5406 var_N + 31,
5407 var_N + 30,
5408 var_N + 40,
5409 kGT,
5410 "b[n, m] = 1.f;");
5411 test_case_nest_loops_fn(
5412 n,
5413 m,
5414 b,
5415 var_N + 10,
5416 var_N + 31,
5417 var_M + 30,
5418 var_M + 40,
5419 kGT,
5420 "b[n, m] = n>m ? 0.f : 1.f;");
5421
5422 // Before:
5423 // for (const auto n : c10::irange(30, 40)) {
5424 // for(const auto m : c10::irange(10, 31)) {
5425 // b[n, m] = (n >= m) ? 0.f : 1.f;
5426 // }
5427 // }
5428 // After:
5429 // for (const auto n : c10::irange(30, 40)) {
5430 // for(const auto m : c10::irange(10, 31)) {
5431 // b[n, m] = 0.f;
5432 // }
5433 // }
5434 test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kGE, "b[n, m] = 0.f;");
5435 test_case_nest_loops_fn(
5436 n,
5437 m,
5438 b,
5439 var_N + 30,
5440 var_N + 40,
5441 var_N + 10,
5442 var_N + 31,
5443 kGE,
5444 "b[n, m] = 0.f;");
5445 test_case_nest_loops_fn(
5446 n,
5447 m,
5448 b,
5449 var_N + 30,
5450 var_N + 40,
5451 var_M + 10,
5452 var_M + 31,
5453 kGE,
5454 "b[n, m] = n>=m ? 0.f : 1.f;");
5455
5456 // Before:
5457 // for (const auto n : c10::irange(10, 20)) {
5458 // for(const auto m : c10::irange(30, 40)) {
5459 // b[n, m] = (n >= m) ? 0.f : 1.f;
5460 // }
5461 // }
5462 // After:
5463 // for (const auto n : c10::irange(10, 20)) {
5464 // for(const auto m : c10::irange(30, 40)) {
5465 // b[n, m] = 1.f;
5466 // }
5467 // }
5468 test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kGE, "b[n, m] = 1.f;");
5469 test_case_nest_loops_fn(
5470 n,
5471 m,
5472 b,
5473 var_N + 10,
5474 var_N + 20,
5475 var_N + 30,
5476 var_N + 40,
5477 kGE,
5478 "b[n, m] = 1.f;");
5479 test_case_nest_loops_fn(
5480 n,
5481 m,
5482 b,
5483 var_N + 10,
5484 var_N + 20,
5485 var_M + 30,
5486 var_M + 40,
5487 kGE,
5488 "b[n, m] = n>=m ? 0.f : 1.f;");
5489
5490 // Before:
5491 // for (const auto n : c10::irange(10, 31)) {
5492 // for(const auto m : c10::irange(30, 40)) {
5493 // b[n, m] = (n <= m) ? 0.f : 1.f;
5494 // }
5495 // }
5496 // After:
5497 // for (const auto n : c10::irange(10, 31)) {
5498 // for(const auto m : c10::irange(30, 40)) {
5499 // b[n, m] = 0.f;
5500 // }
5501 // }
5502 test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kLE, "b[n, m] = 0.f;");
5503 test_case_nest_loops_fn(
5504 n,
5505 m,
5506 b,
5507 var_N + 10,
5508 var_N + 31,
5509 var_N + 30,
5510 var_N + 40,
5511 kLE,
5512 "b[n, m] = 0.f;");
5513 test_case_nest_loops_fn(
5514 n,
5515 m,
5516 b,
5517 var_N + 10,
5518 var_N + 31,
5519 var_M + 30,
5520 var_M + 40,
5521 kLE,
5522 "b[n, m] = n<=m ? 0.f : 1.f;");
5523
5524 // Before:
5525 // for (const auto n : c10::irange(30, 40)) {
5526 // for(const auto m : c10::irange(10, 20)) {
5527 // b[n, m] = (n <= m) ? 0.f : 1.f;
5528 // }
5529 // }
5530 // After:
5531 // for (const auto n : c10::irange(30, 40)) {
5532 // for(const auto m : c10::irange(10, 20)) {
5533 // b[n, m] = 0.f;
5534 // }
5535 // }
5536 test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kLE, "b[n, m] = 1.f;");
5537 test_case_nest_loops_fn(
5538 n,
5539 m,
5540 b,
5541 var_N + 30,
5542 var_N + 40,
5543 var_N + 10,
5544 var_N + 20,
5545 kLE,
5546 "b[n, m] = 1.f;");
5547 test_case_nest_loops_fn(
5548 n,
5549 m,
5550 b,
5551 var_N + 30,
5552 var_N + 40,
5553 var_M + 10,
5554 var_M + 20,
5555 kLE,
5556 "b[n, m] = n<=m ? 0.f : 1.f;");
5557}
5558
5559TEST(Simplify, CompareSelectCondAlwaysInLoopBounds) {
5560 // Before:
5561 // for (const auto n : c10::irange(1, N)) {
5562 // b[n] = n < 1 ? 0.f : 1.f;
5563 // }
5564 // After:
5565 // for (const auto n : c10::irange(1, N)) {
5566 // b[n] = 1.f;
5567 // }
5568 constexpr int N = 8;
5569 BufHandle b("b", {N}, kFloat);
5570 VarHandle n("n", kInt);
5571 StmtPtr s = For::make(
5572 n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT)));
5573 s = IRSimplifier::simplify(s);
5574 std::ostringstream oss;
5575 oss << *s;
5576 torch::jit::testing::FileCheck().run(
5577 R"IR(
5578# CHECK: b[n] = 1.f;
5579)IR",
5580 oss.str());
5581}
5582
5583TEST(Simplify, IfThenCondAlwaysInLoopBounds) {
5584 // Before:
5585 // for (const auto n : c10::irange(1, N)) {
5586 // b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f);
5587 // }
5588 // After:
5589 // for (const auto n : c10::irange(1, N)) {
5590 // b[n] = 1.f;
5591 // }
5592 constexpr int N = 8;
5593 BufHandle b("b", {N}, kFloat);
5594 VarHandle n("n", kInt);
5595 StmtPtr s =
5596 For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f)));
5597 s = IRSimplifier::simplify(s);
5598 std::ostringstream oss;
5599 oss << *s;
5600 torch::jit::testing::FileCheck().run(
5601 R"IR(
5602# CHECK: b[n] = 1.f;
5603)IR",
5604 oss.str());
5605}
5606
5607TEST(Simplify, MultiClauseCondAlwaysInLoopBounds) {
5608 // This test mimics the unpadded region of a conv2d. We want to remove any
5609 // conditional that is provably satisfied (or unsatisfied) by the entire loop
5610 // range.
5611 // Before:
5612 // for (const auto i : c10::irange(1, 7)) {
5613 // for (const auto j : c10::irange(1, 7)) {
5614 // b[i, j] = IfThenElse(
5615 // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, 1.f);
5616 // After:
5617 // for (const auto i : c10::irange(1, 7)) {
5618 // for (const auto j : c10::irange(1, 7)) {
5619 // b[i, j] = 1.f;
5620 constexpr int N = 8;
5621 BufHandle b("b", {N, N}, kFloat);
5622 VarHandle i("i", kInt);
5623 VarHandle j("j", kInt);
5624 auto csel = CompareSelect::make(i, 1, kLT);
5625 csel = CompareSelect::make(j, 1, 1, csel, kLT);
5626 csel = CompareSelect::make(i, N - 1, 1, csel, kGE);
5627 csel = CompareSelect::make(j, N - 1, 1, csel, kGE);
5628 StmtPtr s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f));
5629 s = For::make(j, 1, N - 1, s);
5630 s = For::make(i, 1, N - 1, s);
5631 s = IRSimplifier::simplify(s);
5632 std::ostringstream oss;
5633 oss << *s;
5634 torch::jit::testing::FileCheck().run(
5635 R"IR(
5636# CHECK: b[i, j] = 1.f;
5637)IR",
5638 oss.str());
5639}
5640
5641TEST(Simplify, DISABLED_SimplifyLoopBounds) {
5642 // This test mimics the padded region of a conv2d. We want to adjust the
5643 // loop bounds such that the condition will be always met. Note that this
5644 // could be solved by peeling, and applying the range-based conditional
5645 // simplification in the previous tests.
5646 // Before:
5647 // for (const auto i : c10::irange(3)) {
5648 // for (const auto j : c10::irange(3)) {
5649 // b[i, j] = (b[i, j]) + (IfThenElse(
5650 // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j]));
5651 // After:
5652 // for (const auto i : c10::irange(1, 3)) {
5653 // for (const auto j : c10::irange(1, 3)) {
5654 // b[i, j] = (b[i, j]) + 1.f;
5655 constexpr int N = 8;
5656 constexpr int K = 3;
5657 BufHandle a("a", {N, N}, kFloat);
5658 BufHandle b("b", {N, N}, kFloat);
5659 VarHandle i("i", kInt);
5660 VarHandle j("j", kInt);
5661 auto csel = CompareSelect::make(i, 1, kLT);
5662 csel = CompareSelect::make(j, 1, 1, csel, kLT);
5663 csel = CompareSelect::make(i, N - 1, 1, csel, kGE);
5664 csel = CompareSelect::make(j, N - 1, 1, csel, kGE);
5665 StmtPtr s = b.store(
5666 {i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j})));
5667 s = For::make(j, 0, K, s);
5668 s = For::make(i, 0, K, s);
5669 s = IRSimplifier::simplify(s);
5670 std::ostringstream oss;
5671 oss << *s;
5672 torch::jit::testing::FileCheck().run(
5673 R"IR(
5674# CHECK: for (const auto i : c10::irange(1, 3)) {
5675# CHECK: for (const auto j : c10::irange(1, 3)) {
5676# CHECK-NOT: IfThenElse
5677)IR",
5678 oss.str());
5679}
5680
5681} // namespace jit
5682} // namespace torch
5683