1 | #include <gtest/gtest.h> |
2 | #include <test/cpp/tensorexpr/test_base.h> |
3 | |
4 | #include <c10/util/irange.h> |
5 | #include <test/cpp/tensorexpr/test_utils.h> |
6 | #include <torch/csrc/jit/tensorexpr/hash_provider.h> |
7 | #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
8 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
9 | |
10 | #include <cmath> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | using namespace torch::jit::tensorexpr; |
15 | using SimpleIRExprEval = ExprEval<SimpleIREvaluator>; |
16 | |
17 | TEST(Simplify, ConstantFoldSimple) { |
18 | ExprHandle a(2.0f); |
19 | ExprHandle b(3.0f); |
20 | ExprHandle f = (a + b); |
21 | |
22 | ExprHandle newF = IRSimplifier::simplify(f); |
23 | ASSERT_NE(newF.AsNode<FloatImm>(), nullptr); |
24 | ASSERT_EQ(newF.AsNode<FloatImm>()->value(), 5); |
25 | |
26 | SimpleIRExprEval eval(newF); |
27 | ASSERT_EQ(eval.value<float>(), 5.f); |
28 | } |
29 | |
30 | TEST(Simplify, ConstantFoldTwoLayer) { |
31 | ExprHandle a(2.0f); |
32 | ExprHandle b(3.0f); |
33 | ExprHandle c(4.0f); |
34 | ExprHandle d(5.0f); |
35 | ExprHandle f = (a + b) - (c + d); |
36 | |
37 | ExprHandle newF = IRSimplifier::simplify(f); |
38 | ASSERT_NE(newF.AsNode<FloatImm>(), nullptr); |
39 | ASSERT_EQ(newF.AsNode<FloatImm>()->value(), -4); |
40 | |
41 | SimpleIRExprEval eval(newF); |
42 | ASSERT_EQ(eval.value<float>(), -4.f); |
43 | } |
44 | |
45 | TEST(Simplify, ConstantFoldShifts) { |
46 | ExprHandle a(7); |
47 | ExprHandle b(2); |
48 | ExprHandle c(3); |
49 | ExprHandle f = ((a << b) << b) >> c; |
50 | |
51 | ExprHandle newF = IRSimplifier::simplify(f); |
52 | ASSERT_NE(newF.AsNode<IntImm>(), nullptr); |
53 | ASSERT_EQ(newF.AsNode<IntImm>()->value(), 14); |
54 | |
55 | SimpleIRExprEval eval(newF); |
56 | ASSERT_EQ(eval.value<int>(), 7 << (4 - 3)); |
57 | } |
58 | |
59 | TEST(Simplify, ConstantFoldBitwise) { |
60 | ExprHandle a(59); |
61 | ExprHandle b(22); |
62 | ExprHandle c(101); |
63 | ExprHandle f = (a ^ b) & c; |
64 | |
65 | ExprHandle newF = IRSimplifier::simplify(f); |
66 | ASSERT_NE(newF.AsNode<IntImm>(), nullptr); |
67 | ASSERT_EQ(newF.AsNode<IntImm>()->value(), 37); |
68 | |
69 | SimpleIRExprEval eval(newF); |
70 | ASSERT_EQ(eval.value<int>(), (59 ^ 22) & 101); |
71 | } |
72 | |
73 | TEST(Simplify, ConstantFoldMultiOp) { |
74 | ExprHandle a(2.0f); |
75 | ExprHandle b(3.0f); |
76 | ExprHandle c(4.0f); |
77 | ExprHandle d(5.0f); |
78 | ExprHandle e(6.0f); |
79 | ExprHandle f(7.0f); |
80 | ExprHandle fn = ((a / e) - (c + d)) * (f / b); |
81 | |
82 | ExprHandle newF = IRSimplifier::simplify(fn); |
83 | ASSERT_NE(newF.AsNode<FloatImm>(), nullptr); |
84 | |
85 | SimpleIRExprEval eval(newF); |
86 | SimpleIRExprEval ref(fn); |
87 | |
88 | ASSERT_EQ(eval.value<float>(), ref.value<float>()); |
89 | } |
90 | |
91 | TEST(Simplify, ConstantFoldMinMax) { |
92 | ExprHandle a(12.0f); |
93 | ExprHandle b(15.0f); |
94 | ExprHandle c(17.0f); |
95 | |
96 | // x = max(12, min(15, 17)). |
97 | ExprHandle minHandle = Min::make(b, c, true); |
98 | ExprHandle fn = Max::make(a, minHandle, false); |
99 | |
100 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
101 | ASSERT_EQ(fn.dtype().scalar_type(), ScalarType::Float); |
102 | |
103 | ExprHandle newF = IRSimplifier::simplify(fn); |
104 | ASSERT_NE(newF.AsNode<FloatImm>(), nullptr); |
105 | |
106 | SimpleIRExprEval eval(newF); |
107 | ASSERT_EQ(eval.value<float>(), 15.f); |
108 | } |
109 | |
110 | TEST(Simplify, ConstantFoldIntrinsics) { |
111 | ExprHandle a(2.0f); |
112 | ExprHandle b(3.0f); |
113 | ExprHandle c(4.0f); |
114 | ExprHandle powHandle = Intrinsics::make(kPow, a, b); |
115 | ExprHandle sinHandle = Intrinsics::make(kSin, powHandle); |
116 | ExprHandle modHandle = Intrinsics::make(kFmod, c, sinHandle); |
117 | ExprHandle logHandle = Intrinsics::make(kLog10, modHandle); |
118 | ExprHandle rndHandle = Intrinsics::make(kRound, logHandle); |
119 | ExprHandle fn = Intrinsics::make(kAbs, rndHandle); |
120 | |
121 | ExprHandle newF = IRSimplifier::simplify(fn); |
122 | ASSERT_NE(newF.AsNode<FloatImm>(), nullptr); |
123 | ASSERT_EQ(newF.AsNode<FloatImm>()->value(), 1); |
124 | |
125 | SimpleIRExprEval eval(newF); |
126 | SimpleIRExprEval ref(fn); |
127 | |
128 | ASSERT_EQ(eval.value<float>(), ref.value<float>()); |
129 | } |
130 | |
131 | TEST(Simplify, ConstantFoldCastToBool) { |
132 | ExprHandle f = Cast::make(kBool, IntImm::make(0)); |
133 | ExprHandle newF = IRSimplifier::simplify(f); |
134 | SimpleIRExprEval eval(newF); |
135 | ASSERT_EQ(eval.value<bool>(), false); |
136 | } |
137 | |
138 | TEST(Simplify, ConstantFoldWithVar) { |
139 | { |
140 | VarHandle x("x" , kInt); |
141 | ExprHandle body = x * (ExprHandle(2) + ExprHandle(4)); |
142 | |
143 | ExprHandle newF = IRSimplifier::simplify(body); |
144 | MulPtr root = newF.AsNode<Mul>(); |
145 | ASSERT_NE(root, nullptr); |
146 | ASSERT_NE(to<IntImm>(root->lhs()), nullptr); |
147 | |
148 | SimpleIRExprEval eval(newF); |
149 | eval.bindVar(x, ExprHandle(3)); |
150 | ASSERT_EQ(eval.value<int>(), 3 * (2 + 4)); |
151 | } |
152 | |
153 | { |
154 | VarHandle x("x" , kFloat); |
155 | ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); |
156 | |
157 | ExprHandle newF = IRSimplifier::simplify(body); |
158 | MulPtr root = newF.AsNode<Mul>(); |
159 | ASSERT_NE(root, nullptr); |
160 | ASSERT_NE(to<FloatImm>(root->rhs()), nullptr); |
161 | |
162 | SimpleIRExprEval eval(newF); |
163 | eval.bindVar(x, ExprHandle(3.f)); |
164 | ASSERT_EQ(eval.value<float>(), 3 * (2 + 4)); |
165 | } |
166 | } |
167 | |
168 | TEST(Simplify, ConditionalSelectFoldSimple) { |
169 | ExprHandle a(3.0f); |
170 | ExprHandle b(4.0f); |
171 | ExprHandle c(3.0f); |
172 | { |
173 | ExprHandle f = (a > b); |
174 | |
175 | ExprHandle newF = IRSimplifier::simplify(f); |
176 | ASSERT_NE(newF.AsNode<IntImm>(), nullptr); |
177 | ASSERT_EQ(newF.AsNode<IntImm>()->value(), 0); |
178 | |
179 | SimpleIRExprEval eval(newF); |
180 | ASSERT_EQ(eval.value<int>(), 0); |
181 | } |
182 | { |
183 | ExprHandle f = (a < b); |
184 | |
185 | ExprHandle newF = IRSimplifier::simplify(f); |
186 | ASSERT_NE(newF.AsNode<IntImm>(), nullptr); |
187 | ASSERT_EQ(newF.AsNode<IntImm>()->value(), 1); |
188 | |
189 | SimpleIRExprEval eval(newF); |
190 | ASSERT_EQ(eval.value<int>(), 1); |
191 | } |
192 | { |
193 | ExprHandle f = (a == c); |
194 | |
195 | ExprHandle newF = IRSimplifier::simplify(f); |
196 | ASSERT_NE(newF.AsNode<IntImm>(), nullptr); |
197 | ASSERT_EQ(newF.AsNode<IntImm>()->value(), 1); |
198 | |
199 | SimpleIRExprEval eval(newF); |
200 | ASSERT_EQ(eval.value<int>(), 1); |
201 | } |
202 | { |
203 | ExprHandle f = (a != c); |
204 | |
205 | ExprHandle newF = IRSimplifier::simplify(f); |
206 | ASSERT_NE(newF.AsNode<IntImm>(), nullptr); |
207 | ASSERT_EQ(newF.AsNode<IntImm>()->value(), 0); |
208 | |
209 | SimpleIRExprEval eval(newF); |
210 | ASSERT_EQ(eval.value<int>(), 0); |
211 | } |
212 | } |
213 | |
214 | TEST(Simplify, ConditionalSelectFoldTwoLayer) { |
215 | ExprHandle a(3.0f); |
216 | ExprHandle b(2.0f); |
217 | ExprHandle c(2.0f); |
218 | ExprHandle d(1.0f); |
219 | { |
220 | ExprHandle f = (a + b < c + d); |
221 | |
222 | ExprHandle newF = IRSimplifier::simplify(f); |
223 | ASSERT_NE(newF.AsNode<IntImm>(), nullptr); |
224 | ASSERT_EQ(newF.AsNode<IntImm>()->value(), 0); |
225 | |
226 | SimpleIRExprEval eval(newF); |
227 | ASSERT_EQ(eval.value<int>(), 0); |
228 | } |
229 | { |
230 | ExprHandle f = (a + b > c + d); |
231 | |
232 | ExprHandle newF = IRSimplifier::simplify(f); |
233 | ASSERT_NE(newF.AsNode<IntImm>(), nullptr); |
234 | ASSERT_EQ(newF.AsNode<IntImm>()->value(), 1); |
235 | |
236 | SimpleIRExprEval eval(newF); |
237 | ASSERT_EQ(eval.value<int>(), 1); |
238 | } |
239 | { |
240 | ExprHandle f = (a + d == b + c); |
241 | |
242 | ExprHandle newF = IRSimplifier::simplify(f); |
243 | ASSERT_NE(newF.AsNode<IntImm>(), nullptr); |
244 | ASSERT_EQ(newF.AsNode<IntImm>()->value(), 1); |
245 | |
246 | SimpleIRExprEval eval(newF); |
247 | ASSERT_EQ(eval.value<int>(), 1); |
248 | } |
249 | { |
250 | ExprHandle f = (a + d != b + c); |
251 | |
252 | ExprHandle newF = IRSimplifier::simplify(f); |
253 | ASSERT_NE(newF.AsNode<IntImm>(), nullptr); |
254 | ASSERT_EQ(newF.AsNode<IntImm>()->value(), 0); |
255 | |
256 | SimpleIRExprEval eval(newF); |
257 | ASSERT_EQ(eval.value<int>(), 0); |
258 | } |
259 | } |
260 | |
261 | TEST(Simplify, ConditionalSelectFoldWithVar) { |
262 | VarHandle x("x" , kFloat); |
263 | ExprHandle f = x < 4.f; |
264 | |
265 | ExprHandle newF = IRSimplifier::simplify(f); |
266 | IntImmPtr folded = newF.AsNode<IntImm>(); |
267 | ASSERT_EQ(folded, nullptr); |
268 | |
269 | { |
270 | SimpleIRExprEval eval(newF); |
271 | eval.bindVar(x, ExprHandle(3.f)); |
272 | ASSERT_EQ(eval.value<int>(), 1); |
273 | } |
274 | { |
275 | SimpleIRExprEval eval(newF); |
276 | eval.bindVar(x, ExprHandle(5.f)); |
277 | ASSERT_EQ(eval.value<int>(), 0); |
278 | } |
279 | } |
280 | |
281 | TEST(Simplify, UnFoldableExpr) { |
282 | VarHandle x("x" , kFloat); |
283 | VarHandle y("y" , kFloat); |
284 | ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y); |
285 | |
286 | ExprHandle newF = IRSimplifier::simplify(body); |
287 | AddPtr root = newF.AsNode<Add>(); |
288 | ASSERT_NE(root, nullptr); |
289 | ASSERT_EQ(to<FloatImm>(root->lhs()), nullptr); |
290 | ASSERT_EQ(to<FloatImm>(root->rhs()), nullptr); |
291 | |
292 | SimpleIRExprEval eval(newF); |
293 | eval.bindVar(x, ExprHandle(3.f)); |
294 | eval.bindVar(y, ExprHandle(2.f)); |
295 | ASSERT_EQ(eval.value<float>(), 9 + 10); |
296 | } |
297 | |
298 | TEST(Simplify, HashSimple) { |
299 | VarHandle x("x" , kFloat); |
300 | ExprHandle a(2.0f); |
301 | ExprHandle b(3.0f); |
302 | ExprHandle f = a + b * x; |
303 | |
304 | HashProvider hasher; |
305 | |
306 | auto hash_x = hasher.hash(x.node()); |
307 | auto hash_a = hasher.hash(a.node()); |
308 | auto hash_f = hasher.hash(f.node()); |
309 | |
310 | ASSERT_NE(hash_x, (size_t)0); |
311 | ASSERT_NE(hash_a, (size_t)0); |
312 | ASSERT_NE(hash_f, (size_t)0); |
313 | ASSERT_NE(hash_x, hash_a); |
314 | ASSERT_NE(hash_x, hash_f); |
315 | ASSERT_NE(hash_a, hash_f); |
316 | } |
317 | |
318 | TEST(Simplify, HashEquivalence) { |
319 | VarHandle x("x" , kFloat); |
320 | VarHandle y("y" , kFloat); |
321 | ExprHandle f = (x * y) + (x * y); |
322 | |
323 | AddPtr root = f.AsNode<Add>(); |
324 | ASSERT_NE(root, nullptr); |
325 | |
326 | HashProvider hasher; |
327 | auto hash_f = hasher.hash(f.node()); |
328 | auto hash_l = hasher.hash(root->lhs()); |
329 | auto hash_r = hasher.hash(root->rhs()); |
330 | |
331 | // Root not equal to either branch. |
332 | ASSERT_NE(hash_f, hash_l); |
333 | ASSERT_NE(hash_f, hash_r); |
334 | // but branches are equal. |
335 | ASSERT_EQ(hash_l, hash_r); |
336 | |
337 | // Still equivalent if separate. |
338 | ExprHandle a(2); |
339 | ExprHandle f2 = x + a / y; |
340 | ExprHandle b(2); |
341 | ExprHandle f3 = x + b / y; |
342 | ASSERT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node())); |
343 | |
344 | // Not equivalent if different vars (even with same name). |
345 | VarHandle z("x" , kFloat); |
346 | ExprHandle f4 = z + b / y; |
347 | ASSERT_NE(hasher.hash(f2.node()), hasher.hash(f4.node())); |
348 | |
349 | // Intrinsics sanity check. |
350 | ExprHandle f5 = Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x); |
351 | ASSERT_NE(hasher.hash(f5.node()), (size_t)0); |
352 | } |
353 | |
354 | TEST(Simplify, HashEquivalenceRand) { |
355 | ExprHandle f = |
356 | Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt); |
357 | |
358 | AddPtr root = f.AsNode<Add>(); |
359 | ASSERT_NE(root, nullptr); |
360 | |
361 | HashProvider hasher; |
362 | auto hash_f = hasher.hash(f.node()); |
363 | auto hash_l = hasher.hash(root->lhs()); |
364 | auto hash_r = hasher.hash(root->rhs()); |
365 | |
366 | // Root not equal to either branch. |
367 | ASSERT_NE(hash_f, hash_l); |
368 | ASSERT_NE(hash_f, hash_r); |
369 | // and branches are NOT equal. |
370 | ASSERT_NE(hash_l, hash_r); |
371 | } |
372 | |
373 | TEST(Simplify, HashEquivalenceAfterFolding) { |
374 | VarHandle x("x" , kFloat); |
375 | ExprHandle a(2.0f); |
376 | ExprHandle b(3.0f); |
377 | ExprHandle c(5.0f); |
378 | |
379 | ExprHandle f1 = ((a + b) * x); |
380 | ExprHandle f2 = (c * x); |
381 | |
382 | HashProvider hasher; |
383 | auto hash_l = hasher.hash(f1.node()); |
384 | auto hash_r = hasher.hash(f2.node()); |
385 | |
386 | // Root not equal to either branch, and branches not equal. |
387 | ASSERT_NE(hash_l, hash_r); |
388 | |
389 | ExprHandle ff1 = IRSimplifier::simplify(f1); |
390 | ExprHandle ff2 = IRSimplifier::simplify(f2); |
391 | |
392 | auto hash_l_n = hasher.hash(ff1.node()); |
393 | auto hash_r_n = hasher.hash(ff2.node()); |
394 | // but branches are now equal. |
395 | ASSERT_EQ(hash_l_n, hash_r_n); |
396 | } |
397 | |
398 | TEST(Simplify, HashDifferenceTypes) { |
399 | HashProvider hasher; |
400 | std::vector<ExprPtr> immediates; |
401 | |
402 | immediates.push_back(alloc<DoubleImm>(1)); |
403 | immediates.push_back(alloc<FloatImm>(1)); |
404 | immediates.push_back(alloc<HalfImm>(1)); |
405 | // NOLINTNEXTLINE(modernize-use-bool-literals) |
406 | immediates.push_back(alloc<BoolImm>(1)); |
407 | immediates.push_back(alloc<CharImm>(1)); |
408 | immediates.push_back(alloc<ByteImm>(1)); |
409 | immediates.push_back(alloc<ShortImm>(1)); |
410 | immediates.push_back(alloc<IntImm>(1)); |
411 | immediates.push_back(alloc<LongImm>(1)); |
412 | |
413 | // Immediates of different types are not equal. |
414 | for (unsigned int i = 0; i < immediates.size(); ++i) { |
415 | for (unsigned int j = i + 1; j < immediates.size(); ++j) { |
416 | ASSERT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j])); |
417 | } |
418 | } |
419 | |
420 | // But coerced immediates are if they are the same type: |
421 | ExprHandle f1 = ExprHandle(2.f) + CharImm::make(1); |
422 | ExprHandle f2 = Cast::make(kFloat, IntImm::make(3)); |
423 | |
424 | ExprHandle ff1 = IRSimplifier::simplify(f1); |
425 | ExprHandle ff2 = IRSimplifier::simplify(f2); |
426 | |
427 | ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node())); |
428 | } |
429 | |
430 | TEST(Simplify, HashLargeExpression) { |
431 | constexpr int N = 1024; |
432 | BufHandle a("A" , {N}, kInt); |
433 | BufHandle b("B" , {N}, kInt); |
434 | BufHandle c("C" , {N}, kInt); |
435 | VarHandle i("i" , kInt); |
436 | auto memcpy_stmt = For::make( |
437 | i, |
438 | 0, |
439 | N, |
440 | Store::make( |
441 | c, |
442 | {i}, |
443 | CompareSelect::make( |
444 | Load::make(a, {i}), |
445 | Load::make(b, {i}), |
446 | CompareSelectOperation::kEQ))); |
447 | |
448 | BufHandle d("D" , {1}, kInt); |
449 | BufHandle e("E" , {1}, kInt); |
450 | auto store_ramp_stmt = Store::make( |
451 | e, {Ramp::make(0, 1, 4)}, Load::make(d, {Ramp::make(0, 1, 4)})); |
452 | |
453 | auto if_stmt = Cond::make( |
454 | CompareSelect::make( |
455 | Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kGE), |
456 | memcpy_stmt, |
457 | store_ramp_stmt); |
458 | |
459 | HashProvider hasher; |
460 | auto hash_r = hasher.hash(if_stmt); |
461 | // We should not have to do any more work. |
462 | ASSERT_TRUE(hasher.cachedHash(memcpy_stmt)); |
463 | auto hash_t = hasher.hash(memcpy_stmt); |
464 | ASSERT_TRUE(hasher.cachedHash(store_ramp_stmt)); |
465 | auto hash_f = hasher.hash(store_ramp_stmt); |
466 | |
467 | // Root not equal to either branch, and branches not equal. |
468 | ASSERT_NE(hash_r, hash_t); |
469 | ASSERT_NE(hash_r, hash_f); |
470 | ASSERT_NE(hash_t, hash_f); |
471 | } |
472 | |
473 | TEST(Simplify, HashForLoopOptions) { |
474 | constexpr int N = 1024; |
475 | BufHandle a("A" , {N}, kInt); |
476 | BufHandle b("B" , {N}, kInt); |
477 | BufHandle c("C" , {N}, kInt); |
478 | VarHandle i("i" , kInt); |
479 | auto for_stmt = For::make( |
480 | i, |
481 | 0, |
482 | N, |
483 | Store::make( |
484 | c, |
485 | {i}, |
486 | CompareSelect::make( |
487 | Load::make(a, {i}), |
488 | Load::make(b, {i}), |
489 | CompareSelectOperation::kEQ))); |
490 | |
491 | HashProvider hasher; |
492 | auto hash_before = hasher.hash(for_stmt); |
493 | hasher.clearCache(); |
494 | |
495 | for_stmt->set_gpu_block_index(LoopOptions::IDX_X); |
496 | auto hash_block_idx = hasher.hash(for_stmt); |
497 | hasher.clearCache(); |
498 | |
499 | ASSERT_NE(hash_before, hash_block_idx); |
500 | |
501 | for_stmt->set_gpu_block_index(LoopOptions::IDX_UNSET); |
502 | auto hash_reset = hasher.hash(for_stmt); |
503 | hasher.clearCache(); |
504 | |
505 | ASSERT_EQ(hash_before, hash_reset); |
506 | for_stmt->set_gpu_thread_index(LoopOptions::IDX_X); |
507 | auto hash_thread_idx = hasher.hash(for_stmt); |
508 | |
509 | ASSERT_NE(hash_before, hash_thread_idx); |
510 | ASSERT_NE(hash_block_idx, hash_thread_idx); |
511 | } |
512 | |
513 | /// (2 + x) + 4 => x + 6 |
514 | TEST(Simplify, SimplifyAdd) { |
515 | VarHandle x("x" , kInt); |
516 | VarHandle y("y" , kInt); |
517 | |
518 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
519 | VarHandle m("m" , kInt); |
520 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
521 | VarHandle n("n" , kInt); |
522 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
523 | VarHandle n_1("n_1" , kInt); |
524 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
525 | ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4); |
526 | |
527 | ExprHandle simplified = IRSimplifier::simplify(body); |
528 | AddPtr root = simplified.AsNode<Add>(); |
529 | ASSERT_NE(root, nullptr); |
530 | VarPtr lhs = to<Var>(root->lhs()); |
531 | ASSERT_NE(lhs, nullptr); |
532 | ASSERT_EQ(lhs->name_hint(), "x" ); |
533 | IntImmPtr rhs = to<IntImm>(root->rhs()); |
534 | ASSERT_NE(rhs, nullptr); |
535 | ASSERT_EQ(rhs->value(), 6.f); |
536 | } |
537 | |
538 | /// (2 - x) - 4 => -2 - x |
539 | TEST(Simplify, SimplifySub) { |
540 | VarHandle x("x" , kInt); |
541 | ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); |
542 | |
543 | ExprHandle simplified = IRSimplifier::simplify(body); |
544 | SubPtr root = simplified.AsNode<Sub>(); |
545 | ASSERT_NE(root, nullptr); |
546 | IntImmPtr lhs = to<IntImm>(root->lhs()); |
547 | ASSERT_NE(lhs, nullptr); |
548 | ASSERT_EQ(lhs->value(), -2.f); |
549 | VarPtr rhs = to<Var>(root->rhs()); |
550 | ASSERT_NE(rhs, nullptr); |
551 | ASSERT_EQ(rhs->name_hint(), "x" ); |
552 | } |
553 | |
554 | /// 2 * (1 - x) - 4 => 2 * (-3 - x) |
555 | TEST(Simplify, SimplifyMultiLayer) { |
556 | VarHandle x("x" , kInt); |
557 | ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4)); |
558 | ExprHandle simplified = IRSimplifier::simplify(body); |
559 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
560 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
561 | IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); |
562 | IS_IMM_WITH_VAL(Int, sub->lhs(), -3); |
563 | IS_VAR_WITH_NAME(sub->rhs(), "x" ); |
564 | } |
565 | |
566 | /// 2 * (3 * x) - (x * 4) => 2 * x |
567 | TEST(Simplify, SimplifyMultiTerm) { |
568 | VarHandle x("x" , kInt); |
569 | ExprHandle body = |
570 | (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); |
571 | |
572 | ExprHandle simplified = IRSimplifier::simplify(body); |
573 | MulPtr root = simplified.AsNode<Mul>(); |
574 | ASSERT_NE(root, nullptr); |
575 | IntImmPtr lhs = to<IntImm>(root->lhs()); |
576 | ASSERT_NE(lhs, nullptr); |
577 | ASSERT_EQ(lhs->value(), 2); |
578 | VarPtr rhs = to<Var>(root->rhs()); |
579 | ASSERT_NE(rhs, nullptr); |
580 | ASSERT_EQ(rhs->name_hint(), "x" ); |
581 | } |
582 | |
583 | /// 2 * (3 * (long)x) - (x * 4) => 2 * x |
584 | TEST(Simplify, SimplifyCasts) { |
585 | VarHandle x("x" , kLong); |
586 | ExprHandle body = |
587 | (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); |
588 | |
589 | ExprHandle simplified = IRSimplifier::simplify(body); |
590 | MulPtr root = simplified.AsNode<Mul>(); |
591 | ASSERT_NE(root, nullptr); |
592 | LongImmPtr lhs = to<LongImm>(root->lhs()); |
593 | ASSERT_NE(lhs, nullptr); |
594 | ASSERT_EQ(lhs->value(), 2); |
595 | VarPtr rhs = to<Var>(root->rhs()); |
596 | ASSERT_NE(rhs, nullptr); |
597 | ASSERT_EQ(rhs->name_hint(), "x" ); |
598 | } |
599 | |
600 | /// (x + 0) * 1 => x |
601 | TEST(Simplify, SimplifyEliminatesNoOps) { |
602 | VarHandle x("x" , kInt); |
603 | ExprHandle body = (x + ExprHandle(0)) * 1; |
604 | |
605 | ExprHandle simplified = IRSimplifier::simplify(body); |
606 | VarPtr root = simplified.AsNode<Var>(); |
607 | ASSERT_NE(root, nullptr); |
608 | ASSERT_EQ(root->name_hint(), "x" ); |
609 | } |
610 | |
611 | /// Cannot simplify this. |
612 | TEST(Simplify, SimplifyMultiVar) { |
613 | VarHandle x("x" , kInt); |
614 | VarHandle y("y" , kInt); |
615 | ExprHandle body = x * 24 + y * 34; |
616 | |
617 | ExprHandle simplified = IRSimplifier::simplify(body); |
618 | |
619 | AddPtr root = simplified.AsNode<Add>(); |
620 | ASSERT_NE(root, nullptr); |
621 | MulPtr lhs = to<Mul>(root->lhs()); |
622 | ASSERT_NE(lhs, nullptr); |
623 | VarPtr varX = to<Var>(lhs->rhs()); |
624 | ASSERT_NE(varX, nullptr); |
625 | ASSERT_EQ(varX->name_hint(), "x" ); |
626 | MulPtr rhs = to<Mul>(root->rhs()); |
627 | ASSERT_NE(rhs, nullptr); |
628 | VarPtr varY = to<Var>(rhs->rhs()); |
629 | ASSERT_NE(varY, nullptr); |
630 | ASSERT_EQ(varY->name_hint(), "y" ); |
631 | } |
632 | |
633 | // x + 2 + y => x + y + 2 |
634 | TEST(Simplify, DISABLED_SimplifyReorderings) { |
635 | VarHandle x("x" , kInt); |
636 | VarHandle y("y" , kInt); |
637 | ExprHandle body = x + 2 + y; |
638 | ExprHandle simplified = IRSimplifier::simplify(body); |
639 | |
640 | AddPtr root = simplified.AsNode<Add>(); |
641 | ASSERT_NE(root, nullptr); |
642 | |
643 | IS_NODE_WITH_NAME(Add, root->lhs(), rhs); |
644 | IS_VAR_WITH_NAME(rhs->lhs(), "x" ); |
645 | IS_VAR_WITH_NAME(rhs->rhs(), "y" ); |
646 | IS_IMM_WITH_VAL(Int, root->rhs(), 2); |
647 | } |
648 | |
649 | /// y + x * 0 => y |
650 | TEST(Simplify, SimplifyEliminatesVar) { |
651 | VarHandle x("x" , kInt); |
652 | VarHandle y("y" , kInt); |
653 | ExprHandle body = y + x * ExprHandle(0); |
654 | |
655 | ExprHandle simplified = IRSimplifier::simplify(body); |
656 | IS_VAR_WITH_NAME(simplified.node(), "y" ); |
657 | } |
658 | |
659 | TEST(Simplify, SimplifyAdds) { |
660 | VarHandle x("x" , kInt); |
661 | VarHandle y("y" , kInt); |
662 | |
663 | { |
664 | // (x + y) + (x + y) => 2 * (x + y) |
665 | ExprHandle body = (x + y) + (x + y); |
666 | ExprHandle simplified = IRSimplifier::simplify(body); |
667 | |
668 | IS_NODE_WITH_NAME(Mul, simplified.node(), root); |
669 | IS_IMM_WITH_VAL(Int, root->lhs(), 2); |
670 | IS_NODE_WITH_NAME(Add, root->rhs(), add); |
671 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
672 | IS_VAR_WITH_NAME(add->rhs(), "y" ); |
673 | } |
674 | |
675 | { |
676 | // (x * y) + (x * y) => 2 * (x * y) |
677 | ExprHandle body = (x * y) + (x * y); |
678 | ExprHandle simplified = IRSimplifier::simplify(body); |
679 | |
680 | IS_NODE_WITH_NAME(Mul, simplified.node(), root); |
681 | IS_IMM_WITH_VAL(Int, root->lhs(), 2); |
682 | IS_NODE_WITH_NAME(Mul, root->rhs(), mul); |
683 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
684 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
685 | } |
686 | |
687 | { |
688 | // (x - y) + (x - y) => 2 * (x - y) |
689 | ExprHandle body = (x - y) + (x - y); |
690 | ExprHandle simplified = IRSimplifier::simplify(body); |
691 | |
692 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
693 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
694 | |
695 | IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); |
696 | IS_VAR_WITH_NAME(rhs->lhs(), "x" ); |
697 | IS_VAR_WITH_NAME(rhs->rhs(), "y" ); |
698 | } |
699 | |
700 | { |
701 | // (x + x + x + x) => 4 * x |
702 | ExprHandle body = (x + x + x + x); |
703 | ExprHandle simplified = IRSimplifier::simplify(body); |
704 | |
705 | IS_NODE_WITH_NAME(Mul, simplified.node(), root); |
706 | IS_IMM_WITH_VAL(Int, root->lhs(), 4); |
707 | IS_VAR_WITH_NAME(root->rhs(), "x" ); |
708 | } |
709 | |
710 | { |
711 | // (x + 0) => x. |
712 | ExprHandle body = x + 0; |
713 | ExprHandle simplified = IRSimplifier::simplify(body); |
714 | |
715 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
716 | } |
717 | |
718 | { |
719 | // (x + 0.f) => float(x). |
720 | ExprHandle body = x + 0.f; |
721 | ExprHandle simplified = IRSimplifier::simplify(body); |
722 | |
723 | IS_NODE_WITH_NAME(Cast, simplified.node(), cast); |
724 | ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); |
725 | IS_VAR_WITH_NAME(cast->src_value(), "x" ); |
726 | } |
727 | } |
728 | |
729 | TEST(Simplify, SimplifyMuls) { |
730 | VarHandle x("x" , kInt); |
731 | VarHandle y("y" , kInt); |
732 | |
733 | { |
734 | // (x + y) * (x + y) => (x + y) * (x + y) |
735 | // We don't attempt to simplify mulitplication of polynomials since the |
736 | // result is only very rarely more efficient. |
737 | ExprHandle body = (x + y) * (x + y); |
738 | ExprHandle simplified = IRSimplifier::simplify(body); |
739 | |
740 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
741 | IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); |
742 | IS_VAR_WITH_NAME(lhs->lhs(), "x" ); |
743 | IS_VAR_WITH_NAME(lhs->rhs(), "y" ); |
744 | IS_NODE_WITH_NAME(Add, mul->rhs(), rhs); |
745 | IS_VAR_WITH_NAME(rhs->lhs(), "x" ); |
746 | IS_VAR_WITH_NAME(rhs->rhs(), "y" ); |
747 | } |
748 | |
749 | { |
750 | // x * y * x * y => x * x * y * y |
751 | // These get reordered only. |
752 | ExprHandle body = x * y * x * y; |
753 | ExprHandle simplified = IRSimplifier::simplify(body); |
754 | |
755 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul1); |
756 | IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2); |
757 | IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3); |
758 | IS_VAR_WITH_NAME(mul1->rhs(), "y" ); |
759 | IS_VAR_WITH_NAME(mul2->rhs(), "y" ); |
760 | IS_VAR_WITH_NAME(mul3->lhs(), "x" ); |
761 | IS_VAR_WITH_NAME(mul3->rhs(), "x" ); |
762 | } |
763 | |
764 | { |
765 | // 1 * (x * 1) => x |
766 | // Ones cancel cleanly. |
767 | ExprHandle body = ExprHandle(1) * (x * ExprHandle(1)); |
768 | ExprHandle simplified = IRSimplifier::simplify(body); |
769 | |
770 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
771 | } |
772 | |
773 | { |
774 | // 1.f * (x * 1.f) => x |
775 | // Even float ones cancel cleanly, but carry their type. |
776 | ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(1.f)); |
777 | ExprHandle simplified = IRSimplifier::simplify(body); |
778 | |
779 | IS_NODE_WITH_NAME(Cast, simplified.node(), cast); |
780 | ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); |
781 | IS_VAR_WITH_NAME(cast->src_value(), "x" ); |
782 | } |
783 | |
784 | { |
785 | // 1 * (x * 1.f) => x |
786 | // One float is enough to cast the expr. |
787 | ExprHandle body = ExprHandle(1) * (x * ExprHandle(1.f)); |
788 | ExprHandle simplified = IRSimplifier::simplify(body); |
789 | |
790 | IS_NODE_WITH_NAME(Cast, simplified.node(), cast); |
791 | ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); |
792 | IS_VAR_WITH_NAME(cast->src_value(), "x" ); |
793 | } |
794 | |
795 | { |
796 | // 1 * (x * 0) => 0 |
797 | // Zeroes are eliminated. |
798 | ExprHandle body = ExprHandle(1) * (x * ExprHandle(0)); |
799 | ExprHandle simplified = IRSimplifier::simplify(body); |
800 | |
801 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
802 | } |
803 | |
804 | { |
805 | // 1 * (x * 0) => 0 |
806 | // But not for Float since nan * 0 = nan. |
807 | ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(0.f)); |
808 | ExprHandle simplified = IRSimplifier::simplify(body); |
809 | |
810 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
811 | IS_NODE_WITH_NAME(Cast, mul->lhs(), cast); |
812 | ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); |
813 | IS_VAR_WITH_NAME(cast->src_value(), "x" ); |
814 | IS_IMM_WITH_VAL(Float, mul->rhs(), 0.0); |
815 | } |
816 | |
817 | { |
818 | // (x - y) * (x - y) => (x - y) * (x - y) |
819 | // As with Add we don't attempt simplification of this. |
820 | ExprHandle body = (x - y) * (x - y); |
821 | ExprHandle simplified = IRSimplifier::simplify(body); |
822 | |
823 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
824 | IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs); |
825 | IS_VAR_WITH_NAME(lhs->lhs(), "x" ); |
826 | IS_VAR_WITH_NAME(lhs->rhs(), "y" ); |
827 | IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); |
828 | IS_VAR_WITH_NAME(rhs->lhs(), "x" ); |
829 | IS_VAR_WITH_NAME(rhs->rhs(), "y" ); |
830 | } |
831 | |
832 | { |
833 | // (x + y) * (x - y) => (x + y) * (x - y) |
834 | // Don't simplify with different ops on each side. |
835 | ExprHandle body = (x + y) * (x - y); |
836 | ExprHandle simplified = IRSimplifier::simplify(body); |
837 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
838 | IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); |
839 | IS_VAR_WITH_NAME(lhs->lhs(), "x" ); |
840 | IS_VAR_WITH_NAME(lhs->rhs(), "y" ); |
841 | IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); |
842 | IS_VAR_WITH_NAME(rhs->lhs(), "x" ); |
843 | IS_VAR_WITH_NAME(rhs->rhs(), "y" ); |
844 | } |
845 | |
846 | { |
847 | // Multiply a polynomial by a term. |
848 | // - term with no scalar, poly with non-identity scalar. |
849 | // x * (y + 1) => x + x * y |
850 | ExprHandle body = x * (y + ExprHandle(1)); |
851 | ExprHandle simplified = IRSimplifier::simplify(body); |
852 | |
853 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
854 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
855 | IS_NODE_WITH_NAME(Mul, add->rhs(), mul); |
856 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
857 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
858 | } |
859 | |
860 | { |
861 | // Multiply a polynomial by a term. |
862 | // - term with identity scalar, poly with non-identity scalar. |
863 | // (x * 1) * (y + 1) => x + x * y |
864 | ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(1)); |
865 | ExprHandle simplified = IRSimplifier::simplify(body); |
866 | |
867 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
868 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
869 | IS_NODE_WITH_NAME(Mul, add->rhs(), mul); |
870 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
871 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
872 | } |
873 | |
874 | { |
875 | // Multiply a polynomial by a term. |
876 | // - term with non-identity scalar, poly with non-identity scalar. |
877 | // (x * 2) * (y + 1) => 2 * (x + x * y) |
878 | ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(1)); |
879 | ExprHandle simplified = IRSimplifier::simplify(body); |
880 | |
881 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
882 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
883 | IS_NODE_WITH_NAME(Add, mul->rhs(), add); |
884 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
885 | IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); |
886 | IS_VAR_WITH_NAME(mul2->lhs(), "x" ); |
887 | IS_VAR_WITH_NAME(mul2->rhs(), "y" ); |
888 | } |
889 | |
890 | { |
891 | // Multiply a polynomial by a term. |
892 | // - term with non-identity scalar, poly with identity scalar. |
893 | // (x * 2) * (y + 0) => 2 * (x * y) |
894 | ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(0)); |
895 | ExprHandle simplified = IRSimplifier::simplify(body); |
896 | |
897 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
898 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
899 | IS_NODE_WITH_NAME(Mul, mul->rhs(), mul2); |
900 | IS_VAR_WITH_NAME(mul2->lhs(), "x" ); |
901 | IS_VAR_WITH_NAME(mul2->rhs(), "y" ); |
902 | } |
903 | |
904 | { |
905 | // Multiply a polynomial by a term. |
906 | // - term with identity scalar, poly with identity scalar. |
907 | // (x * 1) * (y + 0) => x * y |
908 | ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(0)); |
909 | ExprHandle simplified = IRSimplifier::simplify(body); |
910 | |
911 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
912 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
913 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
914 | } |
915 | |
916 | { |
917 | // Multiply a polynomial by a term. |
918 | // - term with no scalar, poly with identity scalar. |
919 | // x * (y + 0) => x * y |
920 | ExprHandle body = x * (y + ExprHandle(0)); |
921 | ExprHandle simplified = IRSimplifier::simplify(body); |
922 | |
923 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
924 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
925 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
926 | } |
927 | } |
928 | |
929 | // Sub an expr from itself will result in zero. |
930 | TEST(Simplify, SimplifySubs) { |
931 | VarHandle x("x" , kInt); |
932 | VarHandle y("y" , kInt); |
933 | |
934 | { |
935 | // (x + y) - (x + y) => 0 |
936 | ExprHandle body = (x + y) - (x + y); |
937 | ExprHandle simplified = IRSimplifier::simplify(body); |
938 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
939 | } |
940 | |
941 | { |
942 | // (x * y) - (x * y) => 0 |
943 | ExprHandle body = (x * y) - (x * y); |
944 | ExprHandle simplified = IRSimplifier::simplify(body); |
945 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
946 | } |
947 | |
948 | { |
949 | // (x - y) - (x - y) => 0 |
950 | ExprHandle body = (x - y) - (x - y); |
951 | ExprHandle simplified = IRSimplifier::simplify(body); |
952 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
953 | } |
954 | |
955 | { |
956 | // (x + y) - 2 * (x + y) => -1 * x - y |
957 | ExprHandle body = (x + y) - ExprHandle(2) * (x + y); |
958 | ExprHandle simplified = IRSimplifier::simplify(body); |
959 | |
960 | IS_NODE_WITH_NAME(Sub, simplified.node(), sub); |
961 | IS_NODE_WITH_NAME(Mul, sub->lhs(), mul); |
962 | IS_IMM_WITH_VAL(Int, mul->lhs(), -1); |
963 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
964 | IS_VAR_WITH_NAME(sub->rhs(), "y" ); |
965 | } |
966 | |
967 | { |
968 | // (x + y) - y => x |
969 | ExprHandle body = (x + y) - y; |
970 | ExprHandle simplified = IRSimplifier::simplify(body); |
971 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
972 | } |
973 | |
974 | { |
975 | // (x - 0) => x. |
976 | ExprHandle body = x - 0; |
977 | ExprHandle simplified = IRSimplifier::simplify(body); |
978 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
979 | } |
980 | |
981 | { |
982 | // (x - 0.f) => x. |
983 | // Simple enough to cancel in float. |
984 | ExprHandle body = x - ExprHandle(0.f); |
985 | ExprHandle simplified = IRSimplifier::simplify(body); |
986 | IS_NODE_WITH_NAME(Cast, simplified.node(), cast); |
987 | ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); |
988 | IS_VAR_WITH_NAME(cast->src_value(), "x" ); |
989 | } |
990 | |
991 | { |
992 | // (x - (float)(y - y)) => x. |
993 | ExprHandle body = x - Cast::make(kFloat, y - y); |
994 | ExprHandle simplified = IRSimplifier::simplify(body); |
995 | IS_NODE_WITH_NAME(Cast, simplified.node(), cast); |
996 | ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); |
997 | IS_VAR_WITH_NAME(cast->src_value(), "x" ); |
998 | } |
999 | |
1000 | { |
1001 | // (x - y) - y => x - 2 * y |
1002 | ExprHandle body = (x - y) - y; |
1003 | ExprHandle simplified = IRSimplifier::simplify(body); |
1004 | |
1005 | IS_NODE_WITH_NAME(Sub, simplified.node(), sub); |
1006 | IS_VAR_WITH_NAME(sub->lhs(), "x" ); |
1007 | IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); |
1008 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
1009 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
1010 | } |
1011 | |
1012 | { |
1013 | // 2 * x - x => x |
1014 | ExprHandle body = (ExprHandle(2) * x) - x; |
1015 | ExprHandle simplified = IRSimplifier::simplify(body); |
1016 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
1017 | } |
1018 | |
1019 | { |
1020 | // x - 2 * x = -1 * x |
1021 | // We don't have a unary negate, but this could be 0 -x I guess? |
1022 | ExprHandle body = x - (ExprHandle(2) * x); |
1023 | ExprHandle simplified = IRSimplifier::simplify(body); |
1024 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
1025 | |
1026 | IS_IMM_WITH_VAL(Int, mul->lhs(), -1); |
1027 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
1028 | } |
1029 | |
1030 | { |
1031 | // (x + y + 5) * (x - x) => 0 |
1032 | // Cancelling out one side of Mul cancels both. |
1033 | ExprHandle body = (x + y + 5) * (x - x); |
1034 | ExprHandle simplified = IRSimplifier::simplify(body); |
1035 | |
1036 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
1037 | } |
1038 | |
1039 | { |
1040 | // Cancel out opaque modulus. |
1041 | ExprHandle body = (x % y + 2) - (x % y); |
1042 | ExprHandle simplified = IRSimplifier::simplify(body); |
1043 | IS_IMM_WITH_VAL(Int, simplified.node(), 2); |
1044 | } |
1045 | |
1046 | { |
1047 | // Cancel out opaque modulus with a bit more going on. |
1048 | ExprHandle body = (x % y + (x * 2 - x - y * 0) - x + 2) - (x % y); |
1049 | ExprHandle simplified = IRSimplifier::simplify(body); |
1050 | IS_IMM_WITH_VAL(Int, simplified.node(), 2); |
1051 | } |
1052 | |
1053 | { |
1054 | // Sub where result is negative. |
1055 | ExprHandle body = x - (x + 1); |
1056 | ExprHandle simplified = IRSimplifier::simplify(body); |
1057 | IS_IMM_WITH_VAL(Int, simplified.node(), -1); |
1058 | } |
1059 | |
1060 | { |
1061 | // Sub where result is positive due to negative scalar on RHS. |
1062 | ExprHandle body = x - (x - 1); |
1063 | ExprHandle simplified = IRSimplifier::simplify(body); |
1064 | IS_IMM_WITH_VAL(Int, simplified.node(), 1); |
1065 | } |
1066 | |
1067 | { |
1068 | // Term - Polynomial sub where RHS must be negated. |
1069 | ExprHandle body = (x * 2) - (x * 2 + 1); |
1070 | ExprHandle simplified = IRSimplifier::simplify(body); |
1071 | IS_IMM_WITH_VAL(Int, simplified.node(), -1); |
1072 | } |
1073 | |
1074 | { |
1075 | // Term - Polynomial sub where the result is a Term. |
1076 | ExprHandle body = (y * x * 2) - (x * y); |
1077 | ExprHandle simplified = IRSimplifier::simplify(body); |
1078 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
1079 | |
1080 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
1081 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
1082 | } |
1083 | |
1084 | { |
1085 | // Term - Polynomial sub where the result is a Polynomial. |
1086 | ExprHandle body = (x * 2) - (x + 1); |
1087 | ExprHandle simplified = IRSimplifier::simplify(body); |
1088 | IS_NODE_WITH_NAME(Sub, simplified.node(), sub); |
1089 | |
1090 | IS_VAR_WITH_NAME(sub->lhs(), "x" ); |
1091 | IS_IMM_WITH_VAL(Int, sub->rhs(), 1); |
1092 | } |
1093 | } |
1094 | |
1095 | TEST(Simplify, SimplifyDiv) { |
1096 | VarHandle x("x" , kInt); |
1097 | |
1098 | { |
1099 | ExprHandle body = ExprHandle(0) / x; |
1100 | ExprHandle simplified = IRSimplifier::simplify(body); |
1101 | |
1102 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
1103 | } |
1104 | |
1105 | { |
1106 | ExprHandle body = x / 1; |
1107 | ExprHandle simplified = IRSimplifier::simplify(body); |
1108 | |
1109 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
1110 | } |
1111 | } |
1112 | |
1113 | TEST(Simplify, SimplifyDivWithLoopContext0) { |
1114 | // Stmt to simplify: |
1115 | // for (int i = 0; i < 100; i++) { |
1116 | // A[i] = i / 100; |
1117 | //} |
1118 | VarHandle i("i" , kInt); |
1119 | BufHandle a_buf("A" , {100}, kInt); |
1120 | auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i / 100))); |
1121 | |
1122 | const StmtPtr simplified = IRSimplifier::simplify(for_stmt); |
1123 | |
1124 | std::ostringstream oss; |
1125 | oss << *(simplified); |
1126 | const std::string& verification_pattern = |
1127 | R"IR( |
1128 | # CHECK: for (int i |
1129 | # CHECK-NEXT: A[i] = 0; |
1130 | )IR" ; |
1131 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1132 | } |
1133 | |
1134 | TEST(Simplify, SimplifyDivWithLoopContext1) { |
1135 | // Stmt to simplify: |
1136 | // for (const auto i : c10::irange(6)) { |
1137 | // A[i] = (i + 24) / 6; |
1138 | //} |
1139 | VarHandle i("i" , kInt); |
1140 | BufHandle a_buf("A" , {6}, kInt); |
1141 | auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / 6)); |
1142 | |
1143 | const StmtPtr simplified = IRSimplifier::simplify(for_stmt); |
1144 | |
1145 | std::ostringstream oss; |
1146 | oss << *(simplified); |
1147 | const std::string& verification_pattern = |
1148 | R"IR( |
1149 | # CHECK: for (int i |
1150 | # CHECK-NEXT: A[i] = 4; |
1151 | )IR" ; |
1152 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1153 | } |
1154 | |
1155 | TEST(Simplify, SimplifyDivWithLoopContext2) { |
1156 | // Stmt to simplify: |
1157 | // for (const auto i : c10::irange(5)) { |
1158 | // A[i] = (i + 25) / 6; |
1159 | //} |
1160 | VarHandle i("i" , kInt); |
1161 | BufHandle a_buf("A" , {5}, kInt); |
1162 | auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) / 6)); |
1163 | |
1164 | const StmtPtr simplified = IRSimplifier::simplify(for_stmt); |
1165 | |
1166 | std::ostringstream oss; |
1167 | oss << *(simplified); |
1168 | const std::string& verification_pattern = |
1169 | R"IR( |
1170 | # CHECK: for (int i |
1171 | # CHECK-NEXT: A[i] = 4; |
1172 | )IR" ; |
1173 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1174 | } |
1175 | |
1176 | TEST(Simplify, SimplifyDivWithLoopContext3) { |
1177 | // Stmt to simplify: |
1178 | // for (const auto i : c10::irange(6)) { |
1179 | // A[i] = (i + 24) / (-6); |
1180 | //} |
1181 | VarHandle i("i" , kInt); |
1182 | BufHandle a_buf("A" , {6}, kInt); |
1183 | auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / (-6))); |
1184 | |
1185 | const StmtPtr simplified = IRSimplifier::simplify(for_stmt); |
1186 | |
1187 | std::ostringstream oss; |
1188 | oss << *(simplified); |
1189 | const std::string& verification_pattern = |
1190 | R"IR( |
1191 | # CHECK: for (int i |
1192 | # CHECK-NOT: A[i] = -4; |
1193 | )IR" ; |
1194 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1195 | } |
1196 | |
1197 | TEST(Simplify, SimplifyDivWithLoopContext4) { |
1198 | // Stmt to simplify: |
1199 | // for (const auto i : c10::irange(5)) { |
1200 | // A[i] = (i - 5) / 6; |
1201 | //} |
1202 | VarHandle i("i" , kInt); |
1203 | BufHandle a_buf("A" , {5}, kInt); |
1204 | auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) / 6)); |
1205 | |
1206 | const StmtPtr simplified = IRSimplifier::simplify(for_stmt); |
1207 | |
1208 | std::ostringstream oss; |
1209 | oss << *(simplified); |
1210 | const std::string& verification_pattern = |
1211 | R"IR( |
1212 | # CHECK: for (int i |
1213 | # CHECK-NOT: A[i] = 0; |
1214 | )IR" ; |
1215 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1216 | } |
1217 | |
1218 | TEST(Simplify, SimplifyDivWithLoopContext5) { |
1219 | // Stmt to simplify: |
1220 | // for (const auto i : c10::irange(6)) { |
1221 | // for (const auto j : c10::irange(10)) { |
1222 | // A[i, j] = (i + 6*j) / 6; |
1223 | // } |
1224 | //} |
1225 | VarHandle i("i" , kInt); |
1226 | VarHandle j("j" , kInt); |
1227 | BufHandle a_buf("A" , {6, 10}, kInt); |
1228 | auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / 6)); |
1229 | auto for_i = For::make(i, 0, 6, for_j); |
1230 | |
1231 | const StmtPtr simplified = IRSimplifier::simplify(for_i); |
1232 | |
1233 | std::ostringstream oss; |
1234 | oss << *(simplified); |
1235 | const std::string& verification_pattern = |
1236 | R"IR( |
1237 | # CHECK: for (int i |
1238 | # CHECK: for (int j |
1239 | # CHECK-NEXT: A[i, j] = j; |
1240 | )IR" ; |
1241 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1242 | } |
1243 | |
1244 | TEST(Simplify, SimplifyDivWithLoopContext6) { |
1245 | // Stmt to simplify: |
1246 | // for (const auto i : c10::irange(6)) { |
1247 | // for (int j = -1; j < 9; j++) { |
1248 | // A[i, j+1] = (i + 6*j) / 6; |
1249 | // } |
1250 | //} |
1251 | VarHandle i("i" , kInt); |
1252 | VarHandle j("j" , kInt); |
1253 | BufHandle a_buf("A" , {6, 10}, kInt); |
1254 | auto for_j = |
1255 | For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6)); |
1256 | auto for_i = For::make(i, 0, 6, for_j); |
1257 | |
1258 | const StmtPtr simplified = IRSimplifier::simplify(for_i); |
1259 | |
1260 | std::ostringstream oss; |
1261 | oss << *(simplified); |
1262 | const std::string& verification_pattern = |
1263 | R"IR( |
1264 | # CHECK: for (int i |
1265 | # CHECK: for (int j |
1266 | # CHECK-NOT: A[i, j] = j; |
1267 | )IR" ; |
1268 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1269 | } |
1270 | |
1271 | TEST(Simplify, SimplifyDivWithLoopContext7) { |
1272 | // Stmt to simplify: |
1273 | // for (const auto i : c10::irange(6)) { |
1274 | // for (const auto j : c10::irange(10)) { |
1275 | // A[i, j] = (i + 6*j) / (-6); |
1276 | // } |
1277 | //} |
1278 | VarHandle i("i" , kInt); |
1279 | VarHandle j("j" , kInt); |
1280 | BufHandle a_buf("A" , {6, 10}, kInt); |
1281 | auto for_j = |
1282 | For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / (-6))); |
1283 | auto for_i = For::make(i, 0, 6, for_j); |
1284 | |
1285 | const StmtPtr simplified = IRSimplifier::simplify(for_i); |
1286 | |
1287 | std::ostringstream oss; |
1288 | oss << *(simplified); |
1289 | const std::string& verification_pattern = |
1290 | R"IR( |
1291 | # CHECK: for (int i |
1292 | # CHECK: for (int j |
1293 | # CHECK-NOT: A[i, j] = -j; |
1294 | )IR" ; |
1295 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1296 | } |
1297 | |
1298 | TEST(Simplify, SimplifyModWithLoopContext0) { |
1299 | // Stmt to simplify: |
1300 | // for (const auto i : c10::irange(100)) { |
1301 | // A[i] = i % 100; |
1302 | //} |
1303 | VarHandle i("i" , kInt); |
1304 | BufHandle a_buf("A" , {100}, kInt); |
1305 | auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i % 100))); |
1306 | |
1307 | const StmtPtr simplified = IRSimplifier::simplify(for_stmt); |
1308 | |
1309 | std::ostringstream oss; |
1310 | oss << *(simplified); |
1311 | const std::string& verification_pattern = |
1312 | R"IR( |
1313 | # CHECK: for (int i |
1314 | # CHECK-NEXT: A[i] = i; |
1315 | )IR" ; |
1316 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1317 | } |
1318 | |
1319 | TEST(Simplify, SimplifyModWithLoopContext1) { |
1320 | // Stmt to simplify: |
1321 | // for (const auto i : c10::irange(6)) { |
1322 | // A[i] = (i + 24) % 6; |
1323 | //} |
1324 | VarHandle i("i" , kInt); |
1325 | BufHandle a_buf("A" , {6}, kInt); |
1326 | auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % 6)); |
1327 | |
1328 | const StmtPtr simplified = IRSimplifier::simplify(for_stmt); |
1329 | |
1330 | std::ostringstream oss; |
1331 | oss << *(simplified); |
1332 | const std::string& verification_pattern = |
1333 | R"IR( |
1334 | # CHECK: for (int i |
1335 | # CHECK-NEXT: A[i] = i; |
1336 | )IR" ; |
1337 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1338 | } |
1339 | |
1340 | TEST(Simplify, SimplifyModWithLoopContext2) { |
1341 | // Stmt to simplify: |
1342 | // for (const auto i : c10::irange(5)) { |
1343 | // A[i] = (i + 25) % 6; |
1344 | //} |
1345 | VarHandle i("i" , kInt); |
1346 | BufHandle a_buf("A" , {5}, kInt); |
1347 | auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) % 6)); |
1348 | |
1349 | const StmtPtr simplified = IRSimplifier::simplify(for_stmt); |
1350 | |
1351 | std::ostringstream oss; |
1352 | oss << *(simplified); |
1353 | const std::string& verification_pattern = |
1354 | R"IR( |
1355 | # CHECK: for (int i |
1356 | # CHECK-NEXT: A[i] = i + 1; |
1357 | )IR" ; |
1358 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1359 | } |
1360 | |
1361 | TEST(Simplify, SimplifyModWithLoopContext3) { |
1362 | // Stmt to simplify: |
1363 | // for (const auto i : c10::irange(6)) { |
1364 | // A[i] = (i + 24) % (-6); |
1365 | //} |
1366 | VarHandle i("i" , kInt); |
1367 | BufHandle a_buf("A" , {6}, kInt); |
1368 | auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % (-6))); |
1369 | |
1370 | const StmtPtr simplified = IRSimplifier::simplify(for_stmt); |
1371 | |
1372 | std::ostringstream oss; |
1373 | oss << *(simplified); |
1374 | const std::string& verification_pattern = |
1375 | R"IR( |
1376 | # CHECK: for (int i |
1377 | # CHECK-NOT: A[i] = i; |
1378 | )IR" ; |
1379 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1380 | } |
1381 | |
1382 | TEST(Simplify, SimplifyModWithLoopContext4) { |
1383 | // Stmt to simplify: |
1384 | // for (const auto i : c10::irange(5)) { |
1385 | // A[i] = (i - 5) % 6; |
1386 | //} |
1387 | VarHandle i("i" , kInt); |
1388 | BufHandle a_buf("A" , {5}, kInt); |
1389 | auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) % 6)); |
1390 | |
1391 | const StmtPtr simplified = IRSimplifier::simplify(for_stmt); |
1392 | |
1393 | std::ostringstream oss; |
1394 | oss << *(simplified); |
1395 | const std::string& verification_pattern = |
1396 | R"IR( |
1397 | # CHECK: for (int i |
1398 | # CHECK-NOT: A[i] = i - 5; |
1399 | )IR" ; |
1400 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1401 | } |
1402 | |
1403 | TEST(Simplify, SimplifyModWithLoopContext5) { |
1404 | // Stmt to simplify: |
1405 | // for (const auto i : c10::irange(6)) { |
1406 | // for (const auto j : c10::irange(10)) { |
1407 | // A[i, j] = (i + 6*j) % 6; |
1408 | // } |
1409 | //} |
1410 | VarHandle i("i" , kInt); |
1411 | VarHandle j("j" , kInt); |
1412 | BufHandle a_buf("A" , {6, 10}, kInt); |
1413 | auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % 6)); |
1414 | auto for_i = For::make(i, 0, 6, for_j); |
1415 | |
1416 | const StmtPtr simplified = IRSimplifier::simplify(for_i); |
1417 | |
1418 | std::ostringstream oss; |
1419 | oss << *(simplified); |
1420 | const std::string& verification_pattern = |
1421 | R"IR( |
1422 | # CHECK: for (int i |
1423 | # CHECK: for (int j |
1424 | # CHECK-NEXT: A[i, j] = i; |
1425 | )IR" ; |
1426 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1427 | } |
1428 | |
1429 | TEST(Simplify, SimplifyModWithLoopContext6) { |
1430 | // Stmt to simplify: |
1431 | // for (const auto i : c10::irange(6)) { |
1432 | // for (int j = -1; j < 9; j++) { |
1433 | // A[i, j+1] = (i + 6*j) % 6; |
1434 | // } |
1435 | //} |
1436 | VarHandle i("i" , kInt); |
1437 | VarHandle j("j" , kInt); |
1438 | BufHandle a_buf("A" , {6, 10}, kInt); |
1439 | auto for_j = |
1440 | For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6)); |
1441 | auto for_i = For::make(i, 0, 6, for_j); |
1442 | |
1443 | const StmtPtr simplified = IRSimplifier::simplify(for_i); |
1444 | |
1445 | std::ostringstream oss; |
1446 | oss << *(simplified); |
1447 | const std::string& verification_pattern = |
1448 | R"IR( |
1449 | # CHECK: for (int i |
1450 | # CHECK: for (int j |
1451 | # CHECK-NOT: A[i, j] = i; |
1452 | )IR" ; |
1453 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1454 | } |
1455 | |
1456 | TEST(Simplify, SimplifyModWithLoopContext7) { |
1457 | // Stmt to simplify: |
1458 | // for (const auto i : c10::irange(6)) { |
1459 | // for (const auto j : c10::irange(10)) { |
1460 | // A[i, j] = (i + 6*j) % (-6); |
1461 | // } |
1462 | //} |
1463 | VarHandle i("i" , kInt); |
1464 | VarHandle j("j" , kInt); |
1465 | BufHandle a_buf("A" , {6, 10}, kInt); |
1466 | auto for_j = |
1467 | For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % (-6))); |
1468 | auto for_i = For::make(i, 0, 6, for_j); |
1469 | |
1470 | const StmtPtr simplified = IRSimplifier::simplify(for_i); |
1471 | |
1472 | std::ostringstream oss; |
1473 | oss << *(simplified); |
1474 | const std::string& verification_pattern = |
1475 | R"IR( |
1476 | # CHECK: for (int i |
1477 | # CHECK: for (int j |
1478 | # CHECK-NOT: A[i, j] = i; |
1479 | )IR" ; |
1480 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1481 | } |
1482 | |
1483 | TEST(Simplify, SimplifyMod) { |
1484 | VarHandle x("x" , kInt); |
1485 | VarHandle y("y" , kInt); |
1486 | VarHandle z("z" , kInt); |
1487 | |
1488 | { |
1489 | // Constant folding works. |
1490 | ExprHandle body = ExprHandle(10) % 8; |
1491 | ExprHandle simplified = IRSimplifier::simplify(body); |
1492 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
1493 | IS_IMM_WITH_VAL(Int, simplified.node(), 2); |
1494 | } |
1495 | |
1496 | { |
1497 | // x % x => 0 |
1498 | ExprHandle body = x % x; |
1499 | ExprHandle simplified = IRSimplifier::simplify(body); |
1500 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
1501 | } |
1502 | |
1503 | { |
1504 | // 0 % x => 0 |
1505 | ExprHandle body = ExprHandle(0) % x; |
1506 | ExprHandle simplified = IRSimplifier::simplify(body); |
1507 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
1508 | } |
1509 | |
1510 | { |
1511 | // x % 1 => 0 |
1512 | ExprHandle body = x % 1; |
1513 | ExprHandle simplified = IRSimplifier::simplify(body); |
1514 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
1515 | } |
1516 | |
1517 | { |
1518 | // Doesn't change unknown mods. |
1519 | // x % y => x % y |
1520 | ExprHandle body = x % y; |
1521 | ExprHandle simplified = IRSimplifier::simplify(body); |
1522 | IS_NODE_WITH_NAME(Mod, simplified.node(), mod); |
1523 | IS_VAR_WITH_NAME(mod->lhs(), "x" ); |
1524 | IS_VAR_WITH_NAME(mod->rhs(), "y" ); |
1525 | } |
1526 | |
1527 | { |
1528 | // don't touch if RHS is unknown. |
1529 | // 4 % x => 4 % x |
1530 | ExprHandle body = ExprHandle(4) % x; |
1531 | ExprHandle simplified = IRSimplifier::simplify(body); |
1532 | IS_NODE_WITH_NAME(Mod, simplified.node(), mod); |
1533 | IS_IMM_WITH_VAL(Int, mod->lhs(), 4); |
1534 | IS_VAR_WITH_NAME(mod->rhs(), "x" ); |
1535 | } |
1536 | |
1537 | { |
1538 | // don't touch if LHS is unknown. |
1539 | // x % 4 => x % 4 |
1540 | ExprHandle body = x % 4; |
1541 | ExprHandle simplified = IRSimplifier::simplify(body); |
1542 | IS_NODE_WITH_NAME(Mod, simplified.node(), mod); |
1543 | IS_VAR_WITH_NAME(mod->lhs(), "x" ); |
1544 | IS_IMM_WITH_VAL(Int, mod->rhs(), 4); |
1545 | } |
1546 | |
1547 | { |
1548 | // if LHS is a multiple of RHS, mod is zero. |
1549 | // 2 * x % x => 0 |
1550 | ExprHandle body = (x * 2) % x; |
1551 | ExprHandle simplified = IRSimplifier::simplify(body); |
1552 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
1553 | } |
1554 | |
1555 | { |
1556 | // true even if the multiple is not constant. |
1557 | // x * y % x => 0 |
1558 | ExprHandle body = (x * y) % x; |
1559 | ExprHandle simplified = IRSimplifier::simplify(body); |
1560 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
1561 | } |
1562 | |
1563 | { |
1564 | // true with multiple unknown values in LHS. |
1565 | // x * y * z % x => 0 |
1566 | ExprHandle body = (x * y * z) % x; |
1567 | ExprHandle simplified = IRSimplifier::simplify(body); |
1568 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
1569 | } |
1570 | |
1571 | { |
1572 | // true if the denom is compound. |
1573 | // x * y * z % y * z => 0 |
1574 | ExprHandle body = (x * y * z) % (y * z); |
1575 | ExprHandle simplified = IRSimplifier::simplify(body); |
1576 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
1577 | } |
1578 | |
1579 | { |
1580 | // Sanity check true with scalars that are multiples. |
1581 | // 12 * x % 4 => 0 |
1582 | ExprHandle body = (x * 12) % 4; |
1583 | ExprHandle simplified = IRSimplifier::simplify(body); |
1584 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
1585 | } |
1586 | |
1587 | { |
1588 | // Sanity check not true if the smaller scalar is on LHS. |
1589 | // 4 * x % 12 => 4 * x % 12 |
1590 | ExprHandle body = (x * 4) % 12; |
1591 | ExprHandle simplified = IRSimplifier::simplify(body); |
1592 | IS_NODE_WITH_NAME(Mod, simplified.node(), mod); |
1593 | IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); |
1594 | IS_IMM_WITH_VAL(Int, mul->lhs(), 4); |
1595 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
1596 | IS_IMM_WITH_VAL(Int, mod->rhs(), 12); |
1597 | } |
1598 | |
1599 | { |
1600 | // Both scalar and symbolic in multiple. |
1601 | // (6 * x * y) % (3 * x * y) => 0 |
1602 | ExprHandle body = (ExprHandle(6) * x * y) % (x * y * 3); |
1603 | ExprHandle simplified = IRSimplifier::simplify(body); |
1604 | IS_IMM_WITH_VAL(Int, simplified.node(), 0); |
1605 | } |
1606 | } |
1607 | |
1608 | // Test that mixing ops together simplifies as expected. |
1609 | TEST(Simplify, SimplifyMultiOp) { |
1610 | VarHandle x("x" , kInt); |
1611 | VarHandle y("y" , kInt); |
1612 | |
1613 | { |
1614 | // (x * y) + (x - y) => (x + x * y) - y |
1615 | ExprHandle body = (x * y) + (x - y); |
1616 | ExprHandle simplified = IRSimplifier::simplify(body); |
1617 | |
1618 | IS_NODE_WITH_NAME(Sub, simplified.node(), sub); |
1619 | IS_NODE_WITH_NAME(Add, sub->lhs(), add); |
1620 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
1621 | IS_NODE_WITH_NAME(Mul, add->rhs(), mul); |
1622 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
1623 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
1624 | IS_VAR_WITH_NAME(sub->rhs(), "y" ); |
1625 | } |
1626 | |
1627 | { |
1628 | // (x + y) - x * y => (x + y) - x * y |
1629 | ExprHandle body = (x + y) - x * y; |
1630 | ExprHandle simplified = IRSimplifier::simplify(body); |
1631 | IS_NODE_WITH_NAME(Sub, simplified.node(), sub); |
1632 | IS_NODE_WITH_NAME(Add, sub->lhs(), add); |
1633 | IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); |
1634 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
1635 | IS_VAR_WITH_NAME(add->rhs(), "y" ); |
1636 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
1637 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
1638 | } |
1639 | |
1640 | { |
1641 | // (x - y) - (x + y) => -2 * y |
1642 | ExprHandle body = (x - y) - (x + y); |
1643 | ExprHandle simplified = IRSimplifier::simplify(body); |
1644 | |
1645 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
1646 | IS_IMM_WITH_VAL(Int, mul->lhs(), -2); |
1647 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
1648 | } |
1649 | |
1650 | { |
1651 | // (x - 0) + (x * 1) - (x + 0) => x |
1652 | ExprHandle body = (x - 0) + (x * 1) - (x + 0); |
1653 | ExprHandle simplified = IRSimplifier::simplify(body); |
1654 | |
1655 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
1656 | } |
1657 | |
1658 | { |
1659 | // (x - 0.f) + (x * 1.f) - (x + 0.f) => float(x) + float(x) - float(x) |
1660 | // Even in Float simple terms cancel out, but the variable ones cannot. |
1661 | ExprHandle body = |
1662 | (x - ExprHandle(0.f)) + (x * ExprHandle(1.f)) - (x + ExprHandle(0.f)); |
1663 | ExprHandle simplified = IRSimplifier::simplify(body); |
1664 | |
1665 | IS_NODE_WITH_NAME(Sub, simplified.node(), sub); |
1666 | IS_NODE_WITH_NAME(Add, sub->lhs(), add); |
1667 | IS_NODE_WITH_NAME(Cast, add->lhs(), cast1); |
1668 | IS_VAR_WITH_NAME(cast1->src_value(), "x" ); |
1669 | IS_NODE_WITH_NAME(Cast, add->rhs(), cast2); |
1670 | IS_VAR_WITH_NAME(cast2->src_value(), "x" ); |
1671 | IS_NODE_WITH_NAME(Cast, sub->rhs(), cast3); |
1672 | IS_VAR_WITH_NAME(cast3->src_value(), "x" ); |
1673 | } |
1674 | } |
1675 | |
1676 | // Test that chaining many ops together works as expected. |
1677 | TEST(Simplify, SimplifyManyOps) { |
1678 | VarHandle x("x" , kInt); |
1679 | VarHandle y("y" , kInt); |
1680 | |
1681 | { |
1682 | // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x |
1683 | ExprHandle body = x + y + x + x + y + y + x + y + x; |
1684 | ExprHandle simplified = IRSimplifier::simplify(body); |
1685 | |
1686 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
1687 | |
1688 | IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); |
1689 | IS_IMM_WITH_VAL(Int, lhs->lhs(), 4); |
1690 | IS_VAR_WITH_NAME(lhs->rhs(), "y" ); |
1691 | |
1692 | IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); |
1693 | IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); |
1694 | IS_VAR_WITH_NAME(rhs->rhs(), "x" ); |
1695 | } |
1696 | |
1697 | { |
1698 | // x - y + x + x - y - y + x - y + x = 5 * x - 4 * y |
1699 | ExprHandle body = x - y + x + x - y - y + x - y + x; |
1700 | ExprHandle simplified = IRSimplifier::simplify(body); |
1701 | |
1702 | IS_NODE_WITH_NAME(Sub, simplified.node(), add); |
1703 | |
1704 | IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); |
1705 | IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); |
1706 | IS_VAR_WITH_NAME(lhs->rhs(), "x" ); |
1707 | |
1708 | IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); |
1709 | IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); |
1710 | IS_VAR_WITH_NAME(rhs->rhs(), "y" ); |
1711 | } |
1712 | |
1713 | { |
1714 | // x + y + x - x - y - y + x + y + x = 3 * x |
1715 | ExprHandle body = x + y + x - x - y - y + x + y + x; |
1716 | ExprHandle simplified = IRSimplifier::simplify(body); |
1717 | |
1718 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
1719 | IS_IMM_WITH_VAL(Int, mul->lhs(), 3); |
1720 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
1721 | } |
1722 | } |
1723 | |
1724 | TEST(Simplify, SimplifyFactorization) { |
1725 | VarHandle x("x" , kInt); |
1726 | VarHandle y("y" , kInt); |
1727 | |
1728 | { |
1729 | // (2 * x) + (2 * y) => 2 * (x + y) |
1730 | ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y); |
1731 | ExprHandle simplified = IRSimplifier::simplify(body); |
1732 | |
1733 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
1734 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
1735 | |
1736 | IS_NODE_WITH_NAME(Add, mul->rhs(), add); |
1737 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
1738 | IS_VAR_WITH_NAME(add->rhs(), "y" ); |
1739 | } |
1740 | |
1741 | { |
1742 | // Factorization when scalars have common divider. |
1743 | // (2 * x) + (4 * y) => 2 * (2 * y + x) |
1744 | ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y); |
1745 | ExprHandle simplified = IRSimplifier::simplify(body); |
1746 | |
1747 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
1748 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
1749 | |
1750 | IS_NODE_WITH_NAME(Add, mul->rhs(), add); |
1751 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
1752 | IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); |
1753 | IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); |
1754 | IS_VAR_WITH_NAME(mul2->rhs(), "y" ); |
1755 | } |
1756 | |
1757 | { |
1758 | // Factorization attempt without a common divider. |
1759 | // (2 * x) + (5 * y) => (5 * y) + (2 * x) |
1760 | ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y); |
1761 | ExprHandle simplified = IRSimplifier::simplify(body); |
1762 | |
1763 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
1764 | |
1765 | IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); |
1766 | IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); |
1767 | IS_VAR_WITH_NAME(lhs->rhs(), "x" ); |
1768 | |
1769 | IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); |
1770 | IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); |
1771 | IS_VAR_WITH_NAME(rhs->rhs(), "y" ); |
1772 | } |
1773 | |
1774 | { |
1775 | // Factorization after merging. |
1776 | // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y) |
1777 | ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) + |
1778 | (ExprHandle(8) * x + ExprHandle(6) * y); |
1779 | ExprHandle simplified = IRSimplifier::simplify(body); |
1780 | |
1781 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
1782 | IS_IMM_WITH_VAL(Int, mul->lhs(), 10); |
1783 | |
1784 | IS_NODE_WITH_NAME(Add, mul->rhs(), add); |
1785 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
1786 | IS_VAR_WITH_NAME(add->rhs(), "y" ); |
1787 | } |
1788 | |
1789 | { |
1790 | // Factorization with common divider but different signs. |
1791 | // (2 * x) + (-4 * y) => 2 * (x - 2 * y) |
1792 | ExprHandle body = (ExprHandle(2) * x + ExprHandle(-4) * y); |
1793 | ExprHandle simplified = IRSimplifier::simplify(body); |
1794 | |
1795 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
1796 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
1797 | |
1798 | IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); |
1799 | IS_VAR_WITH_NAME(sub->lhs(), "x" ); |
1800 | IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2); |
1801 | IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); |
1802 | IS_VAR_WITH_NAME(mul2->rhs(), "y" ); |
1803 | } |
1804 | |
1805 | { |
1806 | // Factorization with all negative numbers. |
1807 | // (-2 * x) + (-4 * y) => 2 * (-1 * x - 2 * y) |
1808 | ExprHandle body = ExprHandle(-2) * x + ExprHandle(-4) * y; |
1809 | ExprHandle simplified = IRSimplifier::simplify(body); |
1810 | |
1811 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
1812 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
1813 | |
1814 | IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); |
1815 | IS_NODE_WITH_NAME(Mul, sub->lhs(), mul2); |
1816 | IS_IMM_WITH_VAL(Int, mul2->lhs(), -1); |
1817 | IS_VAR_WITH_NAME(mul2->rhs(), "x" ); |
1818 | IS_NODE_WITH_NAME(Mul, sub->rhs(), mul3); |
1819 | IS_IMM_WITH_VAL(Int, mul3->lhs(), 2); |
1820 | IS_VAR_WITH_NAME(mul3->rhs(), "y" ); |
1821 | } |
1822 | |
1823 | { |
1824 | // The following test ensures that there in no infinite recursion during |
1825 | // factorization when negative numbers are involved. |
1826 | VarHandle a("a" , kInt); |
1827 | VarHandle b("b" , kInt); |
1828 | VarHandle c("c" , kInt); |
1829 | VarHandle d("d" , kInt); |
1830 | VarHandle e("e" , kInt); |
1831 | VarHandle f("f" , kInt); |
1832 | VarHandle g("g" , kInt); |
1833 | VarHandle h("h" , kInt); |
1834 | |
1835 | ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 + |
1836 | f * 32 + g * (-1024) + h * (-32); |
1837 | ExprHandle simplified = IRSimplifier::simplify(body); |
1838 | checkExprIR( |
1839 | simplified, |
1840 | "((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h" ); |
1841 | } |
1842 | } |
1843 | |
1844 | // (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x) |
1845 | TEST(Simplify, SimplifyFactorizeUneven) { |
1846 | VarHandle x("x" , kInt); |
1847 | VarHandle y("y" , kInt); |
1848 | VarHandle z("z" , kInt); |
1849 | ExprHandle body = |
1850 | (ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4); |
1851 | ExprHandle simplified = IRSimplifier::simplify(body); |
1852 | |
1853 | IS_NODE_WITH_NAME(Mul, simplified.node(), root); |
1854 | IS_IMM_WITH_VAL(Int, root->lhs(), 2); |
1855 | IS_NODE_WITH_NAME(Add, root->rhs(), add1); |
1856 | IS_NODE_WITH_NAME(Add, add1->lhs(), add2); |
1857 | |
1858 | IS_VAR_WITH_NAME(add2->lhs(), "y" ); |
1859 | IS_NODE_WITH_NAME(Mul, add2->rhs(), zmul); |
1860 | IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul); |
1861 | |
1862 | IS_IMM_WITH_VAL(Int, xmul->lhs(), 4); |
1863 | IS_VAR_WITH_NAME(xmul->rhs(), "x" ); |
1864 | |
1865 | IS_IMM_WITH_VAL(Int, zmul->lhs(), 3); |
1866 | IS_VAR_WITH_NAME(zmul->rhs(), "z" ); |
1867 | } |
1868 | |
1869 | // (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y) |
1870 | // This is kind of a placeholder test for variable factorization. |
1871 | TEST(Simplify, SimplifyDeeperTerms) { |
1872 | VarHandle x("x" , kInt); |
1873 | VarHandle y("y" , kInt); |
1874 | ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y); |
1875 | ExprHandle simplified = IRSimplifier::simplify(body); |
1876 | |
1877 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
1878 | |
1879 | IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); |
1880 | IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); |
1881 | IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm); |
1882 | IS_VAR_WITH_NAME(xxTerm->lhs(), "x" ); |
1883 | IS_VAR_WITH_NAME(xxTerm->rhs(), "x" ); |
1884 | |
1885 | IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); |
1886 | IS_IMM_WITH_VAL(Int, rhs->lhs(), 3); |
1887 | IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm); |
1888 | IS_VAR_WITH_NAME(xyTerm->lhs(), "x" ); |
1889 | IS_VAR_WITH_NAME(xyTerm->rhs(), "y" ); |
1890 | } |
1891 | |
1892 | // Tests the difference between two less trivial expressions. |
1893 | // (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1 |
1894 | TEST(Simplify, SimplifyDeeperDifference) { |
1895 | VarHandle n("n" , kInt); |
1896 | VarHandle n_1("n_1" , kInt); |
1897 | VarHandle m("m" , kInt); |
1898 | ExprHandle body = |
1899 | (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n); |
1900 | ExprHandle simplified = IRSimplifier::simplify(body); |
1901 | |
1902 | IS_IMM_WITH_VAL(Int, simplified.node(), 1); |
1903 | } |
1904 | |
1905 | // Test constant folding into the difference between expressions. |
1906 | // 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3 |
1907 | TEST(Simplify, SimplifyFoldComplexDifference) { |
1908 | VarHandle n("n" , kInt); |
1909 | VarHandle n_1("n_1" , kInt); |
1910 | VarHandle m("m" , kInt); |
1911 | ExprHandle body = |
1912 | (IntImm::make(2) + |
1913 | (Cast::make( |
1914 | kChar, |
1915 | (m * (ExprHandle(1) * n_1) + (n + 1)) - |
1916 | (m * (ExprHandle(1) * n_1) + n)))); |
1917 | ExprHandle simplified = IRSimplifier::simplify(body); |
1918 | IS_IMM_WITH_VAL(Int, simplified.node(), 3); |
1919 | } |
1920 | |
1921 | TEST(Simplify, SimplifyIfComponents) { |
1922 | VarHandle x("x" , kInt); |
1923 | VarHandle y("y" , kInt); |
1924 | ExprHandle body = IfThenElse::make( |
1925 | ((ExprHandle(5) - ExprHandle(4)) * x) > y, |
1926 | ExprHandle(2) * x - x, |
1927 | ExprHandle(2) * y - y); |
1928 | |
1929 | ExprHandle simplified = IRSimplifier::simplify(body); |
1930 | |
1931 | IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr); |
1932 | |
1933 | IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp); |
1934 | ASSERT_EQ(cmp->compare_select_op(), kGT); |
1935 | IS_VAR_WITH_NAME(cmp->lhs(), "x" ); |
1936 | IS_VAR_WITH_NAME(cmp->rhs(), "y" ); |
1937 | |
1938 | IS_VAR_WITH_NAME(ifexpr->true_value(), "x" ); |
1939 | IS_VAR_WITH_NAME(ifexpr->false_value(), "y" ); |
1940 | } |
1941 | |
1942 | TEST(Simplify, SimplifyOpaqueTerms) { |
1943 | VarHandle x("x" , kInt); |
1944 | VarHandle y("y" , kInt); |
1945 | |
1946 | { |
1947 | // 2 * x/y * y - x/y * y => x/y * y |
1948 | ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y); |
1949 | ExprHandle simplified = IRSimplifier::simplify(body); |
1950 | |
1951 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
1952 | IS_NODE_WITH_NAME(Div, mul->lhs(), div); |
1953 | IS_VAR_WITH_NAME(div->lhs(), "x" ); |
1954 | IS_VAR_WITH_NAME(div->rhs(), "y" ); |
1955 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
1956 | } |
1957 | |
1958 | { |
1959 | // x%y - (x%y - 1) => 1 |
1960 | ExprHandle body = (x % y) - ((x % y) - 1); |
1961 | ExprHandle simplified = IRSimplifier::simplify(body); |
1962 | |
1963 | IS_IMM_WITH_VAL(Int, simplified.node(), 1); |
1964 | } |
1965 | } |
1966 | |
1967 | TEST(Simplify, SimplifySymbolicMinMax) { |
1968 | { |
1969 | // Minimum with constant difference between terms. |
1970 | VarHandle x("x" , kInt); |
1971 | ExprHandle body = Min::make(x + 3, x + 7, true); |
1972 | ExprHandle simplified = IRSimplifier::simplify(body); |
1973 | |
1974 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
1975 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
1976 | IS_IMM_WITH_VAL(Int, add->rhs(), 3); |
1977 | } |
1978 | |
1979 | { |
1980 | // Maximum with constant difference between terms. |
1981 | VarHandle x("x" , kInt); |
1982 | ExprHandle body = Max::make(x + 3, x + 7, true); |
1983 | ExprHandle simplified = IRSimplifier::simplify(body); |
1984 | |
1985 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
1986 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
1987 | IS_IMM_WITH_VAL(Int, add->rhs(), 7); |
1988 | } |
1989 | |
1990 | { |
1991 | // Can't simplify multiples because of signedness of variable component. |
1992 | // TODO: maybe we could for unsigned types? |
1993 | VarHandle x("x" , kInt); |
1994 | ExprHandle body = Max::make(x * 3, x * 7, true); |
1995 | ExprHandle simplified = IRSimplifier::simplify(body); |
1996 | |
1997 | IS_NODE(Max, simplified.node()); |
1998 | } |
1999 | } |
2000 | |
2001 | TEST(Simplify, SimplifyNestedMax) { |
2002 | VarHandle x("x" , kInt); |
2003 | VarHandle y("y" , kInt); |
2004 | VarHandle z("z" , kInt); |
2005 | |
2006 | { |
2007 | // Max(x + y, x + y) => x + y |
2008 | ExprHandle body = Max::make(x + y, x + y, true); |
2009 | ExprHandle simplified = IRSimplifier::simplify(body); |
2010 | |
2011 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
2012 | IS_BINOP_W_VARS(Add, simplified.node(), add, "x" , "y" ); |
2013 | } |
2014 | |
2015 | { |
2016 | // Max(x + y, Max(x + y, z)) => Max(x + y, z) |
2017 | ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true); |
2018 | ExprHandle simplified = IRSimplifier::simplify(body); |
2019 | |
2020 | IS_NODE_WITH_NAME(Max, simplified.node(), max); |
2021 | IS_BINOP_W_VARS(Add, max->lhs(), add, "x" , "y" ); |
2022 | IS_VAR_WITH_NAME(max->rhs(), "z" ); |
2023 | } |
2024 | |
2025 | { |
2026 | // Max(x + y, Max(z, x + y)) => Max(x + y, z) |
2027 | ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true); |
2028 | ExprHandle simplified = IRSimplifier::simplify(body); |
2029 | |
2030 | IS_NODE_WITH_NAME(Max, simplified.node(), max); |
2031 | IS_BINOP_W_VARS(Add, max->lhs(), add, "x" , "y" ); |
2032 | IS_VAR_WITH_NAME(max->rhs(), "z" ); |
2033 | } |
2034 | |
2035 | { |
2036 | // Max(Max(x + y, z), x + y) => Max(x + y, z) |
2037 | ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true); |
2038 | ExprHandle simplified = IRSimplifier::simplify(body); |
2039 | |
2040 | IS_NODE_WITH_NAME(Max, simplified.node(), max); |
2041 | IS_BINOP_W_VARS(Add, max->lhs(), add, "x" , "y" ); |
2042 | IS_VAR_WITH_NAME(max->rhs(), "z" ); |
2043 | } |
2044 | |
2045 | { |
2046 | // Max(Max(z, x + y), x + y) => Max(x + y, z) |
2047 | ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true); |
2048 | ExprHandle simplified = IRSimplifier::simplify(body); |
2049 | |
2050 | IS_NODE_WITH_NAME(Max, simplified.node(), max); |
2051 | IS_BINOP_W_VARS(Add, max->lhs(), add, "x" , "y" ); |
2052 | IS_VAR_WITH_NAME(max->rhs(), "z" ); |
2053 | } |
2054 | |
2055 | { |
2056 | // Max(Max(x, y), x) => Max(Max(x, y), x) |
2057 | // Nested Max ops with different propagate_nans should not be simplified. |
2058 | ExprHandle body = Max::make(Max::make(x, y, true), x, false); |
2059 | ExprHandle simplified = IRSimplifier::simplify(body); |
2060 | |
2061 | IS_NODE_WITH_NAME(Max, simplified.node(), max); |
2062 | IS_BINOP_W_VARS(Max, max->lhs(), max1, "x" , "y" ); |
2063 | ASSERT_TRUE(max1->propagate_nans()); |
2064 | IS_VAR_WITH_NAME(max->rhs(), "x" ); |
2065 | ASSERT_FALSE(max->propagate_nans()); |
2066 | } |
2067 | |
2068 | { |
2069 | // Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x) |
2070 | ExprHandle body = |
2071 | Max::make(Min::make(x, y, true), Min::make(x, z, true), true); |
2072 | ExprHandle simplified = IRSimplifier::simplify(body); |
2073 | checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)" ); |
2074 | } |
2075 | |
2076 | { |
2077 | // Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x) |
2078 | ExprHandle body = |
2079 | Max::make(Min::make(x, y, true), Min::make(z, x, true), true); |
2080 | ExprHandle simplified = IRSimplifier::simplify(body); |
2081 | checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)" ); |
2082 | } |
2083 | |
2084 | { |
2085 | // Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x) |
2086 | ExprHandle body = |
2087 | Max::make(Min::make(y, x, true), Min::make(x, z, true), true); |
2088 | ExprHandle simplified = IRSimplifier::simplify(body); |
2089 | checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)" ); |
2090 | } |
2091 | |
2092 | { |
2093 | // Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x) |
2094 | ExprHandle body = |
2095 | Max::make(Min::make(y, x, true), Min::make(z, x, true), true); |
2096 | ExprHandle simplified = IRSimplifier::simplify(body); |
2097 | checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)" ); |
2098 | } |
2099 | |
2100 | { |
2101 | // Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z)) |
2102 | // When all the ops in the pattern do not have the same propagate_nans, |
2103 | // it should not be simplified. |
2104 | ExprHandle body = |
2105 | Max::make(Min::make(y, x, true), Min::make(z, x, false), true); |
2106 | ExprHandle simplified = IRSimplifier::simplify(body); |
2107 | |
2108 | IS_NODE_WITH_NAME(Max, simplified.node(), max); |
2109 | IS_BINOP_W_VARS(Min, max->lhs(), min1, "x" , "y" ); |
2110 | ASSERT_TRUE(min1->propagate_nans()); |
2111 | IS_BINOP_W_VARS(Min, max->rhs(), min2, "x" , "z" ); |
2112 | ASSERT_FALSE(min2->propagate_nans()); |
2113 | ASSERT_TRUE(max->propagate_nans()); |
2114 | } |
2115 | |
2116 | { |
2117 | // Max(5, Max(x, 8)) => Max(x, 8) |
2118 | ExprHandle body = Max::make(5, Max::make(x, 8, true), true); |
2119 | ExprHandle simplified = IRSimplifier::simplify(body); |
2120 | |
2121 | IS_BINOP_W_CONST(Max, simplified.node(), max, "x" , 8); |
2122 | ASSERT_TRUE(max->propagate_nans()); |
2123 | } |
2124 | |
2125 | { |
2126 | // Max(8, Max(x, 5)) => Max(x, 8) |
2127 | ExprHandle body = Max::make(8, Max::make(x, 5, true), true); |
2128 | ExprHandle simplified = IRSimplifier::simplify(body); |
2129 | |
2130 | IS_BINOP_W_CONST(Max, simplified.node(), max, "x" , 8); |
2131 | ASSERT_TRUE(max->propagate_nans()); |
2132 | } |
2133 | |
2134 | { |
2135 | // Max(Max(x, 8), 5) => Max(x, 8) |
2136 | ExprHandle body = Max::make(Max::make(x, 8, true), 5, true); |
2137 | ExprHandle simplified = IRSimplifier::simplify(body); |
2138 | |
2139 | IS_BINOP_W_CONST(Max, simplified.node(), max, "x" , 8); |
2140 | ASSERT_TRUE(max->propagate_nans()); |
2141 | } |
2142 | |
2143 | { |
2144 | // Max(Max(x, 5), 8) => Max(x, 8) |
2145 | ExprHandle body = Max::make(Max::make(x, 5, true), 8, true); |
2146 | ExprHandle simplified = IRSimplifier::simplify(body); |
2147 | |
2148 | IS_BINOP_W_CONST(Max, simplified.node(), max, "x" , 8); |
2149 | ASSERT_TRUE(max->propagate_nans()); |
2150 | } |
2151 | |
2152 | { |
2153 | // Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z) |
2154 | ExprHandle body = Max::make( |
2155 | 5, Max::make(x, Max::make(y, Max::make(z, 8, true), true), true), true); |
2156 | ExprHandle simplified = IRSimplifier::simplify(body); |
2157 | |
2158 | IS_NODE_WITH_NAME(Max, simplified.node(), max1); |
2159 | IS_NODE_WITH_NAME(Max, max1->lhs(), max2); |
2160 | IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x" , 8); |
2161 | ASSERT_TRUE(max3->propagate_nans()); |
2162 | IS_VAR_WITH_NAME(max2->rhs(), "y" ); |
2163 | IS_VAR_WITH_NAME(max1->rhs(), "z" ); |
2164 | } |
2165 | |
2166 | { |
2167 | // Max(8, Max(Max(y, Max(z, 5)), x)) => Max(Max(Max(x, 8), y), z) |
2168 | ExprHandle body = Max::make( |
2169 | 8, Max::make(Max::make(y, Max::make(z, 5, true), true), x, true), true); |
2170 | ExprHandle simplified = IRSimplifier::simplify(body); |
2171 | |
2172 | IS_NODE_WITH_NAME(Max, simplified.node(), max1); |
2173 | IS_NODE_WITH_NAME(Max, max1->lhs(), max2); |
2174 | IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x" , 8); |
2175 | ASSERT_TRUE(max3->propagate_nans()); |
2176 | IS_VAR_WITH_NAME(max2->rhs(), "y" ); |
2177 | IS_VAR_WITH_NAME(max1->rhs(), "z" ); |
2178 | } |
2179 | |
2180 | { |
2181 | // Max(5, Max(Max(Max(z, 8), y), x)) => Max(Max(Max(x, 8), y), z) |
2182 | ExprHandle body = Max::make( |
2183 | 5, Max::make(Max::make(Max::make(z, 8, true), y, true), x, true), true); |
2184 | ExprHandle simplified = IRSimplifier::simplify(body); |
2185 | |
2186 | IS_NODE_WITH_NAME(Max, simplified.node(), max1); |
2187 | IS_NODE_WITH_NAME(Max, max1->lhs(), max2); |
2188 | IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x" , 8); |
2189 | ASSERT_TRUE(max3->propagate_nans()); |
2190 | IS_VAR_WITH_NAME(max2->rhs(), "y" ); |
2191 | IS_VAR_WITH_NAME(max1->rhs(), "z" ); |
2192 | } |
2193 | |
2194 | { |
2195 | // Max(Max(x, Max(y, Max(5, z))), 8) => Max(Max(Max(x, 8), y), z) |
2196 | ExprHandle body = Max::make( |
2197 | Max::make(x, Max::make(y, Max::make(5, z, true), true), true), 8, true); |
2198 | ExprHandle simplified = IRSimplifier::simplify(body); |
2199 | |
2200 | IS_NODE_WITH_NAME(Max, simplified.node(), max1); |
2201 | IS_NODE_WITH_NAME(Max, max1->lhs(), max2); |
2202 | IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x" , 8); |
2203 | ASSERT_TRUE(max3->propagate_nans()); |
2204 | IS_VAR_WITH_NAME(max2->rhs(), "y" ); |
2205 | IS_VAR_WITH_NAME(max1->rhs(), "z" ); |
2206 | } |
2207 | |
2208 | { |
2209 | // Max(Max(Max(y, Max(8, z)), x), 5) => Max(Max(Max(x, 8), y), z) |
2210 | ExprHandle body = Max::make( |
2211 | Max::make(Max::make(y, Max::make(z, 8, true), true), x, true), 5, true); |
2212 | ExprHandle simplified = IRSimplifier::simplify(body); |
2213 | |
2214 | IS_NODE_WITH_NAME(Max, simplified.node(), max1); |
2215 | IS_NODE_WITH_NAME(Max, max1->lhs(), max2); |
2216 | IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x" , 8); |
2217 | ASSERT_TRUE(max3->propagate_nans()); |
2218 | IS_VAR_WITH_NAME(max2->rhs(), "y" ); |
2219 | IS_VAR_WITH_NAME(max1->rhs(), "z" ); |
2220 | } |
2221 | |
2222 | { |
2223 | // Max(Max(Max(Max(5, z), y), x), 8) => Max(Max(Max(x, 8), y), z) |
2224 | ExprHandle body = Max::make( |
2225 | Max::make(Max::make(Max::make(z, 5, true), y, true), x, true), 8, true); |
2226 | ExprHandle simplified = IRSimplifier::simplify(body); |
2227 | |
2228 | IS_NODE_WITH_NAME(Max, simplified.node(), max1); |
2229 | IS_NODE_WITH_NAME(Max, max1->lhs(), max2); |
2230 | IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x" , 8); |
2231 | ASSERT_TRUE(max3->propagate_nans()); |
2232 | IS_VAR_WITH_NAME(max2->rhs(), "y" ); |
2233 | IS_VAR_WITH_NAME(max1->rhs(), "z" ); |
2234 | } |
2235 | |
2236 | { |
2237 | // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(x, Max(Max(z, 5), y)), 8) |
2238 | // Do not simplify when all the Max ops do not have the same |
2239 | // propagate_nans. |
2240 | ExprHandle body = Max::make( |
2241 | Max::make(Max::make(Max::make(z, 5, true), y, false), x, true), |
2242 | 8, |
2243 | false); |
2244 | ExprHandle simplified = IRSimplifier::simplify(body); |
2245 | checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)" ); |
2246 | } |
2247 | |
2248 | { |
2249 | // Max(8, Max(Max(x, 5), Max(y, z))) => Max(Max(Max(x, 8), y), z) |
2250 | ExprHandle body = Max::make( |
2251 | 8, Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), true); |
2252 | ExprHandle simplified = IRSimplifier::simplify(body); |
2253 | |
2254 | IS_NODE_WITH_NAME(Max, simplified.node(), max1); |
2255 | IS_NODE_WITH_NAME(Max, max1->lhs(), max2); |
2256 | IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x" , 8); |
2257 | ASSERT_TRUE(max3->propagate_nans()); |
2258 | IS_VAR_WITH_NAME(max2->rhs(), "y" ); |
2259 | IS_VAR_WITH_NAME(max1->rhs(), "z" ); |
2260 | } |
2261 | |
2262 | { |
2263 | // Max(Max(Max(x, 5), Max(y, z)), 8) => Max(Max(Max(x, 8), y), z) |
2264 | ExprHandle body = Max::make( |
2265 | Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), 8, true); |
2266 | ExprHandle simplified = IRSimplifier::simplify(body); |
2267 | |
2268 | IS_NODE_WITH_NAME(Max, simplified.node(), max1); |
2269 | IS_NODE_WITH_NAME(Max, max1->lhs(), max2); |
2270 | IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x" , 8); |
2271 | ASSERT_TRUE(max3->propagate_nans()); |
2272 | IS_VAR_WITH_NAME(max2->rhs(), "y" ); |
2273 | IS_VAR_WITH_NAME(max1->rhs(), "z" ); |
2274 | } |
2275 | } |
2276 | |
2277 | TEST(Simplify, SimplifyNestedMin) { |
2278 | VarHandle x("x" , kInt); |
2279 | VarHandle y("y" , kInt); |
2280 | VarHandle z("z" , kInt); |
2281 | |
2282 | { |
2283 | // Min(x + y, x + y) => x + y |
2284 | ExprHandle body = Min::make(x + y, x + y, true); |
2285 | ExprHandle simplified = IRSimplifier::simplify(body); |
2286 | |
2287 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
2288 | IS_BINOP_W_VARS(Add, simplified.node(), add, "x" , "y" ); |
2289 | } |
2290 | |
2291 | { |
2292 | // Min(x + y, Min(x + y, z)) => Min(x + y, z) |
2293 | ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true); |
2294 | ExprHandle simplified = IRSimplifier::simplify(body); |
2295 | |
2296 | IS_NODE_WITH_NAME(Min, simplified.node(), min); |
2297 | IS_BINOP_W_VARS(Add, min->lhs(), add, "x" , "y" ); |
2298 | IS_VAR_WITH_NAME(min->rhs(), "z" ); |
2299 | } |
2300 | |
2301 | { |
2302 | // Min(x + y, Min(z, x + y)) => Min(x + y, z) |
2303 | ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true); |
2304 | ExprHandle simplified = IRSimplifier::simplify(body); |
2305 | |
2306 | IS_NODE_WITH_NAME(Min, simplified.node(), min); |
2307 | IS_BINOP_W_VARS(Add, min->lhs(), add, "x" , "y" ); |
2308 | IS_VAR_WITH_NAME(min->rhs(), "z" ); |
2309 | } |
2310 | |
2311 | { |
2312 | // Min(Min(x + y, z), x + y) => Min(x + y, z) |
2313 | ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true); |
2314 | ExprHandle simplified = IRSimplifier::simplify(body); |
2315 | |
2316 | IS_NODE_WITH_NAME(Min, simplified.node(), min); |
2317 | IS_BINOP_W_VARS(Add, min->lhs(), add, "x" , "y" ); |
2318 | IS_VAR_WITH_NAME(min->rhs(), "z" ); |
2319 | } |
2320 | |
2321 | { |
2322 | // Min(Min(z, x + y), x + y) => Min(x + y, z) |
2323 | ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true); |
2324 | ExprHandle simplified = IRSimplifier::simplify(body); |
2325 | |
2326 | IS_NODE_WITH_NAME(Min, simplified.node(), min); |
2327 | IS_BINOP_W_VARS(Add, min->lhs(), add, "x" , "y" ); |
2328 | IS_VAR_WITH_NAME(min->rhs(), "z" ); |
2329 | } |
2330 | |
2331 | { |
2332 | // Min(Min(x, y), x) => Min(Min(x, y), x) |
2333 | // Nested Min ops with different propagate_nans should not be simplified. |
2334 | ExprHandle body = Min::make(Min::make(x, y, true), x, false); |
2335 | ExprHandle simplified = IRSimplifier::simplify(body); |
2336 | |
2337 | IS_NODE_WITH_NAME(Min, simplified.node(), min1); |
2338 | IS_BINOP_W_VARS(Min, min1->lhs(), min2, "x" , "y" ); |
2339 | ASSERT_TRUE(min2->propagate_nans()); |
2340 | IS_VAR_WITH_NAME(min1->rhs(), "x" ); |
2341 | ASSERT_FALSE(min1->propagate_nans()); |
2342 | } |
2343 | |
2344 | { |
2345 | // Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x) |
2346 | ExprHandle body = |
2347 | Min::make(Max::make(x, y, true), Max::make(x, z, true), true); |
2348 | ExprHandle simplified = IRSimplifier::simplify(body); |
2349 | checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)" ); |
2350 | } |
2351 | |
2352 | { |
2353 | // Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x) |
2354 | ExprHandle body = |
2355 | Min::make(Max::make(x, y, true), Max::make(z, x, true), true); |
2356 | ExprHandle simplified = IRSimplifier::simplify(body); |
2357 | checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)" ); |
2358 | } |
2359 | |
2360 | { |
2361 | // Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x) |
2362 | ExprHandle body = |
2363 | Min::make(Max::make(y, x, true), Max::make(x, z, true), true); |
2364 | ExprHandle simplified = IRSimplifier::simplify(body); |
2365 | checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)" ); |
2366 | } |
2367 | |
2368 | { |
2369 | // Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x) |
2370 | ExprHandle body = |
2371 | Min::make(Max::make(y, x, true), Max::make(z, x, true), true); |
2372 | ExprHandle simplified = IRSimplifier::simplify(body); |
2373 | checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)" ); |
2374 | } |
2375 | |
2376 | { |
2377 | // Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z)) |
2378 | // When all the ops in the pattern do not have the same propagate_nans, |
2379 | // it should not be simplified. |
2380 | ExprHandle body = |
2381 | Min::make(Max::make(y, x, true), Max::make(z, x, false), true); |
2382 | ExprHandle simplified = IRSimplifier::simplify(body); |
2383 | |
2384 | IS_NODE_WITH_NAME(Min, simplified.node(), min); |
2385 | IS_BINOP_W_VARS(Max, min->lhs(), max1, "x" , "y" ); |
2386 | ASSERT_TRUE(max1->propagate_nans()); |
2387 | IS_BINOP_W_VARS(Max, min->rhs(), max2, "x" , "z" ); |
2388 | ASSERT_FALSE(max2->propagate_nans()); |
2389 | ASSERT_TRUE(min->propagate_nans()); |
2390 | } |
2391 | |
2392 | { |
2393 | // Min(5, Min(x, 8)) => Min(x, 8) |
2394 | ExprHandle body = Min::make(5, Min::make(x, 8, true), true); |
2395 | ExprHandle simplified = IRSimplifier::simplify(body); |
2396 | |
2397 | IS_BINOP_W_CONST(Min, simplified.node(), min, "x" , 5); |
2398 | ASSERT_TRUE(min->propagate_nans()); |
2399 | } |
2400 | |
2401 | { |
2402 | // Min(8, Min(x, 5)) => Min(x, 8) |
2403 | ExprHandle body = Min::make(8, Min::make(x, 5, true), true); |
2404 | ExprHandle simplified = IRSimplifier::simplify(body); |
2405 | |
2406 | IS_BINOP_W_CONST(Min, simplified.node(), min, "x" , 5); |
2407 | ASSERT_TRUE(min->propagate_nans()); |
2408 | } |
2409 | |
2410 | { |
2411 | // Min(Min(x, 8), 5) => Min(x, 8) |
2412 | ExprHandle body = Min::make(Min::make(x, 8, true), 5, true); |
2413 | ExprHandle simplified = IRSimplifier::simplify(body); |
2414 | |
2415 | IS_BINOP_W_CONST(Min, simplified.node(), min, "x" , 5); |
2416 | ASSERT_TRUE(min->propagate_nans()); |
2417 | } |
2418 | |
2419 | { |
2420 | // Min(Min(x, 5), 8) => Min(x, 8) |
2421 | ExprHandle body = Min::make(Min::make(x, 5, true), 8, true); |
2422 | ExprHandle simplified = IRSimplifier::simplify(body); |
2423 | |
2424 | IS_BINOP_W_CONST(Min, simplified.node(), min, "x" , 5); |
2425 | ASSERT_TRUE(min->propagate_nans()); |
2426 | } |
2427 | |
2428 | { |
2429 | // Min(5, Min(x, Min(y, Min(z, 8)))) => Min(Min(Min(x, 5), y), z) |
2430 | ExprHandle body = Min::make( |
2431 | 5, Min::make(x, Min::make(y, Min::make(z, 8, true), true), true), true); |
2432 | ExprHandle simplified = IRSimplifier::simplify(body); |
2433 | |
2434 | IS_NODE_WITH_NAME(Min, simplified.node(), min1); |
2435 | IS_NODE_WITH_NAME(Min, min1->lhs(), min2); |
2436 | IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x" , 5); |
2437 | ASSERT_TRUE(min3->propagate_nans()); |
2438 | IS_VAR_WITH_NAME(min2->rhs(), "y" ); |
2439 | IS_VAR_WITH_NAME(min1->rhs(), "z" ); |
2440 | } |
2441 | |
2442 | { |
2443 | // Min(5, Min(Min(y, Min(z, 8)), x)) => Min(Min(Min(x, 5), y), z) |
2444 | ExprHandle body = Min::make( |
2445 | 5, Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), true); |
2446 | ExprHandle simplified = IRSimplifier::simplify(body); |
2447 | |
2448 | IS_NODE_WITH_NAME(Min, simplified.node(), min1); |
2449 | IS_NODE_WITH_NAME(Min, min1->lhs(), min2); |
2450 | IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x" , 5); |
2451 | ASSERT_TRUE(min3->propagate_nans()); |
2452 | IS_VAR_WITH_NAME(min2->rhs(), "y" ); |
2453 | IS_VAR_WITH_NAME(min1->rhs(), "z" ); |
2454 | } |
2455 | |
2456 | { |
2457 | // Min(5, Min(Min(Min(z, 8), y), x)) => Min(Min(Min(x, 5), y), z) |
2458 | ExprHandle body = Min::make( |
2459 | 5, Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), true); |
2460 | ExprHandle simplified = IRSimplifier::simplify(body); |
2461 | |
2462 | IS_NODE_WITH_NAME(Min, simplified.node(), min1); |
2463 | IS_NODE_WITH_NAME(Min, min1->lhs(), min2); |
2464 | IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x" , 5); |
2465 | ASSERT_TRUE(min3->propagate_nans()); |
2466 | IS_VAR_WITH_NAME(min2->rhs(), "y" ); |
2467 | IS_VAR_WITH_NAME(min1->rhs(), "z" ); |
2468 | } |
2469 | |
2470 | { |
2471 | // Min(Min(x, Min(y, Min(8, z))), 5) => Min(Min(Min(x, 5), y), z) |
2472 | ExprHandle body = Min::make( |
2473 | Min::make(x, Min::make(y, Min::make(8, z, true), true), true), 5, true); |
2474 | ExprHandle simplified = IRSimplifier::simplify(body); |
2475 | |
2476 | IS_NODE_WITH_NAME(Min, simplified.node(), min1); |
2477 | IS_NODE_WITH_NAME(Min, min1->lhs(), min2); |
2478 | IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x" , 5); |
2479 | ASSERT_TRUE(min3->propagate_nans()); |
2480 | IS_VAR_WITH_NAME(min2->rhs(), "y" ); |
2481 | IS_VAR_WITH_NAME(min1->rhs(), "z" ); |
2482 | } |
2483 | |
2484 | { |
2485 | // Min(Min(Min(y, Min(8, z)), x), 5) => Min(Min(Min(x, 5), y), z) |
2486 | ExprHandle body = Min::make( |
2487 | Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), 5, true); |
2488 | ExprHandle simplified = IRSimplifier::simplify(body); |
2489 | |
2490 | IS_NODE_WITH_NAME(Min, simplified.node(), min1); |
2491 | IS_NODE_WITH_NAME(Min, min1->lhs(), min2); |
2492 | IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x" , 5); |
2493 | ASSERT_TRUE(min3->propagate_nans()); |
2494 | IS_VAR_WITH_NAME(min2->rhs(), "y" ); |
2495 | IS_VAR_WITH_NAME(min1->rhs(), "z" ); |
2496 | } |
2497 | |
2498 | { |
2499 | // Min(Min(Min(Min(8, z), y), x), 5) => Min(Min(Min(x, 5), y), z) |
2500 | ExprHandle body = Min::make( |
2501 | Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), 5, true); |
2502 | ExprHandle simplified = IRSimplifier::simplify(body); |
2503 | |
2504 | IS_NODE_WITH_NAME(Min, simplified.node(), min1); |
2505 | IS_NODE_WITH_NAME(Min, min1->lhs(), min2); |
2506 | IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x" , 5); |
2507 | ASSERT_TRUE(min3->propagate_nans()); |
2508 | IS_VAR_WITH_NAME(min2->rhs(), "y" ); |
2509 | IS_VAR_WITH_NAME(min1->rhs(), "z" ); |
2510 | } |
2511 | |
2512 | { |
2513 | // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8) |
2514 | // Do not simplify when all the Min ops do not have the same |
2515 | // propagate_nans. |
2516 | ExprHandle body = Min::make( |
2517 | Min::make(Min::make(Min::make(z, 5, true), y, false), x, true), |
2518 | 8, |
2519 | false); |
2520 | ExprHandle simplified = IRSimplifier::simplify(body); |
2521 | checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)" ); |
2522 | } |
2523 | |
2524 | { |
2525 | // Min(8, Min(Min(x, 5), Min(y, z))) => Min(Min(Min(x, 5), y), z) |
2526 | ExprHandle body = Min::make( |
2527 | 8, Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), true); |
2528 | ExprHandle simplified = IRSimplifier::simplify(body); |
2529 | |
2530 | IS_NODE_WITH_NAME(Min, simplified.node(), min1); |
2531 | IS_NODE_WITH_NAME(Min, min1->lhs(), min2); |
2532 | IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x" , 5); |
2533 | ASSERT_TRUE(min3->propagate_nans()); |
2534 | IS_VAR_WITH_NAME(min2->rhs(), "y" ); |
2535 | IS_VAR_WITH_NAME(min1->rhs(), "z" ); |
2536 | } |
2537 | |
2538 | { |
2539 | // Min(Min(Min(x, 5), Min(y, z)), 8) => Min(Min(Min(x, 5), y), z) |
2540 | ExprHandle body = Min::make( |
2541 | Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), 8, true); |
2542 | ExprHandle simplified = IRSimplifier::simplify(body); |
2543 | |
2544 | IS_NODE_WITH_NAME(Min, simplified.node(), min1); |
2545 | IS_NODE_WITH_NAME(Min, min1->lhs(), min2); |
2546 | IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x" , 5); |
2547 | ASSERT_TRUE(min3->propagate_nans()); |
2548 | IS_VAR_WITH_NAME(min2->rhs(), "y" ); |
2549 | IS_VAR_WITH_NAME(min1->rhs(), "z" ); |
2550 | } |
2551 | } |
2552 | |
2553 | TEST(Simplify, SimplifyWontReorderFloat) { |
2554 | { |
2555 | // 3 * (3 * x) - 3 * (3 * y) => 9 * (x - y) |
2556 | // This is an expression we can simplify. |
2557 | VarHandle x("x" , kInt); |
2558 | VarHandle y("y" , kInt); |
2559 | |
2560 | ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - |
2561 | ExprHandle(3) * (ExprHandle(3) * y); |
2562 | ExprHandle simplified = IRSimplifier::simplify(body); |
2563 | |
2564 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
2565 | IS_IMM_WITH_VAL(Int, mul->lhs(), 9); |
2566 | IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); |
2567 | IS_VAR_WITH_NAME(sub->lhs(), "x" ); |
2568 | IS_VAR_WITH_NAME(sub->rhs(), "y" ); |
2569 | } |
2570 | |
2571 | { |
2572 | // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y). |
2573 | // If the vars are floating point, ops are not associative and we can't |
2574 | // reorder. |
2575 | VarHandle x("x" , kFloat); |
2576 | VarHandle y("y" , kFloat); |
2577 | |
2578 | ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - |
2579 | ExprHandle(3) * (ExprHandle(3) * y); |
2580 | ExprHandle simplified = IRSimplifier::simplify(body); |
2581 | |
2582 | IS_NODE_WITH_NAME(Sub, simplified.node(), sub); |
2583 | IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); |
2584 | IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); |
2585 | IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); |
2586 | IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3); |
2587 | IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x" ); |
2588 | |
2589 | IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); |
2590 | IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); |
2591 | IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); |
2592 | IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); |
2593 | IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y" ); |
2594 | } |
2595 | |
2596 | { |
2597 | // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y). |
2598 | // We will simplify subexprs if they dont reorder floating point ops. |
2599 | VarHandle x("x" , kDouble); |
2600 | VarHandle y("y" , kInt); |
2601 | |
2602 | ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - |
2603 | ExprHandle(3) * (ExprHandle(3) * y); |
2604 | ExprHandle simplified = IRSimplifier::simplify(body); |
2605 | |
2606 | IS_NODE_WITH_NAME(Sub, simplified.node(), sub); |
2607 | IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); |
2608 | IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3); |
2609 | IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); |
2610 | IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3); |
2611 | IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x" ); |
2612 | |
2613 | IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double); |
2614 | IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9); |
2615 | IS_VAR_WITH_NAME(rhsMul->rhs(), "y" ); |
2616 | } |
2617 | |
2618 | { |
2619 | // Prevent reordering if FP propagated from dtypes. |
2620 | VarHandle x("x" , kInt); |
2621 | VarHandle y("y" , kInt); |
2622 | |
2623 | ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) - |
2624 | ExprHandle(3) * (ExprHandle(3.f) * y); |
2625 | ExprHandle simplified = IRSimplifier::simplify(body); |
2626 | |
2627 | IS_NODE_WITH_NAME(Sub, simplified.node(), sub); |
2628 | IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); |
2629 | IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); |
2630 | IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float); |
2631 | IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3); |
2632 | IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x" ); |
2633 | |
2634 | IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); |
2635 | IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); |
2636 | IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); |
2637 | IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); |
2638 | IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast); |
2639 | IS_VAR_WITH_NAME(yCast->src_value(), "y" ); |
2640 | } |
2641 | |
2642 | { |
2643 | VarHandle x("x" , kFloat); |
2644 | VarHandle y("y" , kFloat); |
2645 | // x%y - (x%y - 1) => x%y - (x%y - 1). |
2646 | // We wont reorder opaque ops if they are FP. |
2647 | ExprHandle body = (x % y) - ((x % y) - 1); |
2648 | ExprHandle simplified = IRSimplifier::simplify(body); |
2649 | |
2650 | IS_NODE_WITH_NAME(Sub, simplified.node(), sub); |
2651 | IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod); |
2652 | IS_VAR_WITH_NAME(lhsMod->lhs(), "x" ); |
2653 | IS_VAR_WITH_NAME(lhsMod->rhs(), "y" ); |
2654 | |
2655 | IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub); |
2656 | IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod); |
2657 | IS_VAR_WITH_NAME(rhsMod->lhs(), "x" ); |
2658 | IS_VAR_WITH_NAME(rhsMod->rhs(), "y" ); |
2659 | IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1); |
2660 | } |
2661 | } |
2662 | |
2663 | TEST(Simplify, SimplifyRoundModPattern) { |
2664 | { |
2665 | // (x/y)*y + x%y => x. |
2666 | VarHandle x("x" , kInt); |
2667 | VarHandle y("y" , kInt); |
2668 | ExprHandle body = ((x / y) * y) + (x % y); |
2669 | ExprHandle simplified = IRSimplifier::simplify(body); |
2670 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
2671 | } |
2672 | |
2673 | { |
2674 | // Reverse order. |
2675 | // x%y + (x/y)*y => x. |
2676 | VarHandle x("x" , kInt); |
2677 | VarHandle y("y" , kInt); |
2678 | ExprHandle body = (x % y) + ((x / y) * y); |
2679 | ExprHandle simplified = IRSimplifier::simplify(body); |
2680 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
2681 | } |
2682 | |
2683 | { |
2684 | // Non opaque denominator. |
2685 | // (x / (4+y)) * (4+y)) + (x % (y + 4)) => x. |
2686 | VarHandle x("x" , kInt); |
2687 | VarHandle y("y" , kInt); |
2688 | ExprHandle body = ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)) + |
2689 | (x % (y + ExprHandle(4))); |
2690 | ExprHandle simplified = IRSimplifier::simplify(body); |
2691 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
2692 | } |
2693 | |
2694 | { |
2695 | // Reverse order. |
2696 | // (x % (y + 4)) + (x / (4+y)) * (4+y)) => x. |
2697 | VarHandle x("x" , kInt); |
2698 | VarHandle y("y" , kInt); |
2699 | ExprHandle body = (x % (y + ExprHandle(4))) + |
2700 | ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)); |
2701 | ExprHandle simplified = IRSimplifier::simplify(body); |
2702 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
2703 | } |
2704 | |
2705 | { |
2706 | // Opaque denominator. |
2707 | // (x / (2/y)) * (2/y)) + (x % (2/y)) => x. |
2708 | VarHandle x("x" , kInt); |
2709 | VarHandle y("y" , kInt); |
2710 | ExprHandle body = ((x / (ExprHandle(2) / y)) * (ExprHandle(2) / y)) + |
2711 | (x % (ExprHandle(2) / y)); |
2712 | ExprHandle simplified = IRSimplifier::simplify(body); |
2713 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
2714 | } |
2715 | |
2716 | { |
2717 | // Non opaque numerator |
2718 | // ((2*x)/y * y) + ((2*x) % y) => 2 * x. |
2719 | VarHandle x("x" , kInt); |
2720 | VarHandle y("y" , kInt); |
2721 | ExprHandle body = |
2722 | (((ExprHandle(2) * x) / y) * y) + ((ExprHandle(2) * x) % y); |
2723 | ExprHandle simplified = IRSimplifier::simplify(body); |
2724 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
2725 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
2726 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
2727 | } |
2728 | |
2729 | { |
2730 | // Opaque numerator. |
2731 | // ((x/2) / y * y) + (x/2 % y) => x / 2. |
2732 | VarHandle x("x" , kInt); |
2733 | VarHandle y("y" , kInt); |
2734 | ExprHandle body = |
2735 | (((x / ExprHandle(2)) / y) * y) + ((x / ExprHandle(2)) % y); |
2736 | ExprHandle simplified = IRSimplifier::simplify(body); |
2737 | |
2738 | IS_NODE_WITH_NAME(Div, simplified.node(), div); |
2739 | IS_VAR_WITH_NAME(div->lhs(), "x" ); |
2740 | IS_IMM_WITH_VAL(Int, div->rhs(), 2); |
2741 | } |
2742 | |
2743 | { |
2744 | // Numerator and denominator. |
2745 | // ((2*x)/(2*y) * (2*y)) + ((2*x) % (2*y)) => 2 * x. |
2746 | VarHandle x("x" , kInt); |
2747 | VarHandle y("y" , kInt); |
2748 | ExprHandle body = |
2749 | (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)) + |
2750 | ((ExprHandle(2) * x) % (ExprHandle(2) * y)); |
2751 | ExprHandle simplified = IRSimplifier::simplify(body); |
2752 | |
2753 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
2754 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
2755 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
2756 | } |
2757 | |
2758 | { |
2759 | // Reverse order. |
2760 | // ((2*x) % (2*y)) + ((2*x)/(2*y) * (2*y)) => 2 * x. |
2761 | VarHandle x("x" , kInt); |
2762 | VarHandle y("y" , kInt); |
2763 | ExprHandle body = ((ExprHandle(2) * x) % (ExprHandle(2) * y)) + |
2764 | (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)); |
2765 | ExprHandle simplified = IRSimplifier::simplify(body); |
2766 | |
2767 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
2768 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
2769 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
2770 | } |
2771 | |
2772 | { |
2773 | // Negated Subtraction of Round Mod. |
2774 | // (x/y) * y - (0 - x%y) => x. |
2775 | VarHandle x("x" , kInt); |
2776 | VarHandle y("y" , kInt); |
2777 | ExprHandle body = ((x / y) * y) - (ExprHandle(0) - (x % y)); |
2778 | ExprHandle simplified = IRSimplifier::simplify(body); |
2779 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
2780 | } |
2781 | |
2782 | { |
2783 | // Other terms are preserved. |
2784 | // (x/y)*y + x%y + (y * x) => x + (y * x). |
2785 | VarHandle x("x" , kInt); |
2786 | VarHandle y("y" , kInt); |
2787 | ExprHandle body = ((x / y) * y) + (x % y) + (y * x); |
2788 | ExprHandle simplified = IRSimplifier::simplify(body); |
2789 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
2790 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
2791 | IS_NODE_WITH_NAME(Mul, add->rhs(), mul); |
2792 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
2793 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
2794 | } |
2795 | |
2796 | { |
2797 | // Sanity checking we wont do the optimization on floats. |
2798 | VarHandle x("x" , kFloat); |
2799 | VarHandle y("y" , kFloat); |
2800 | ExprHandle body = ((x / y) * y) + (x % y); |
2801 | ExprHandle simplified = IRSimplifier::simplify(body); |
2802 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
2803 | IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul); |
2804 | IS_NODE_WITH_NAME(Div, roundMul->lhs(), roundDiv); |
2805 | IS_VAR_WITH_NAME(roundDiv->lhs(), "x" ); |
2806 | IS_VAR_WITH_NAME(roundDiv->rhs(), "y" ); |
2807 | IS_VAR_WITH_NAME(roundMul->rhs(), "y" ); |
2808 | IS_NODE_WITH_NAME(Mod, add->rhs(), mod); |
2809 | IS_VAR_WITH_NAME(mod->lhs(), "x" ); |
2810 | IS_VAR_WITH_NAME(mod->rhs(), "y" ); |
2811 | } |
2812 | |
2813 | { |
2814 | // Sanity check we wont do it if the mod term doesn't match. |
2815 | VarHandle x("x" , kInt); |
2816 | VarHandle y("y" , kInt); |
2817 | VarHandle z("z" , kInt); |
2818 | ExprHandle body = ((x / y) * y) + (x % z); |
2819 | ExprHandle simplified = IRSimplifier::simplify(body); |
2820 | checkExprIR(simplified, "(x / y) * y + x % z" ); |
2821 | } |
2822 | |
2823 | { |
2824 | // Sanity check we wont do it if the div term doesn't match. |
2825 | VarHandle x("x" , kInt); |
2826 | VarHandle y("y" , kInt); |
2827 | VarHandle z("z" , kInt); |
2828 | ExprHandle body = (y * (x / z)) + (x % y); |
2829 | ExprHandle simplified = IRSimplifier::simplify(body); |
2830 | checkExprIR(simplified, "x % y + (x / z) * y" ); |
2831 | } |
2832 | |
2833 | { |
2834 | // Sanity check we wont do it if the mul term doesn't match. |
2835 | VarHandle x("x" , kInt); |
2836 | VarHandle y("y" , kInt); |
2837 | VarHandle z("z" , kInt); |
2838 | ExprHandle body = ((x / y) * z) + (x % y); |
2839 | ExprHandle simplified = IRSimplifier::simplify(body); |
2840 | checkExprIR(simplified, "x % y + (x / y) * z" ); |
2841 | } |
2842 | } |
2843 | |
2844 | TEST(Simplify, SimplifyRoundModPatternFactorization) { |
2845 | { |
2846 | // Full factorization. |
2847 | // 2 * (x/y * y) + 2 * (x%y) => 2 * x. |
2848 | VarHandle x("x" , kInt); |
2849 | VarHandle y("y" , kInt); |
2850 | ExprHandle body = ExprHandle(2) * ((x / y) * y) + ExprHandle(2) * (x % y); |
2851 | ExprHandle simplified = IRSimplifier::simplify(body); |
2852 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
2853 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
2854 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
2855 | } |
2856 | |
2857 | { |
2858 | // Partial Factorization. |
2859 | // 32 * (x/8) + 4 * (x % 8) => 4 * x. |
2860 | VarHandle x("x" , kInt); |
2861 | VarHandle y("y" , kInt); |
2862 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) |
2863 | ExprHandle body = ExprHandle(32) * (x / 8) + ExprHandle(4) * (x % 8); |
2864 | ExprHandle simplified = IRSimplifier::simplify(body); |
2865 | |
2866 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
2867 | IS_IMM_WITH_VAL(Int, mul->lhs(), 4); |
2868 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
2869 | } |
2870 | |
2871 | { |
2872 | // Factorization requiring constant folding. |
2873 | // 20 * (x / (16 / 2)) * 2 + (11 % 6) * (x % (7+1)) => 5 * x. |
2874 | VarHandle x("x" , kInt); |
2875 | ExprHandle body = ExprHandle(40) * (x / (ExprHandle(16) / 2)) + |
2876 | (ExprHandle(11) % 6) * (x % (ExprHandle(7) + 1)); |
2877 | ExprHandle simplified = IRSimplifier::simplify(body); |
2878 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
2879 | IS_IMM_WITH_VAL(Int, mul->lhs(), 5); |
2880 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
2881 | } |
2882 | |
2883 | { |
2884 | VarHandle x("x" , kInt); |
2885 | ExprHandle body = (x / 5) * 10 + ExprHandle(2) * (x % 5); |
2886 | ExprHandle simplified = IRSimplifier::simplify(body); |
2887 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
2888 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
2889 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
2890 | } |
2891 | |
2892 | { |
2893 | VarHandle x("x" , kInt); |
2894 | ExprHandle body = (x / 10) * 0 + x % 5; |
2895 | ExprHandle simplified = IRSimplifier::simplify(body); |
2896 | IS_NODE_WITH_NAME(Mod, simplified.node(), mod); |
2897 | IS_VAR_WITH_NAME(mod->lhs(), "x" ); |
2898 | IS_IMM_WITH_VAL(Int, mod->rhs(), 5); |
2899 | } |
2900 | } |
2901 | |
2902 | TEST(Simplify, SimplifyRoundModPatternMultivar) { |
2903 | { |
2904 | // Multivar. |
2905 | // (x/8) * 8 + (y/5)*5 + x%8 + y%5 => x + y. |
2906 | VarHandle x("x" , kInt); |
2907 | VarHandle y("y" , kInt); |
2908 | ExprHandle body = (x / ExprHandle(8) * ExprHandle(8)) + |
2909 | (y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5); |
2910 | ExprHandle simplified = IRSimplifier::simplify(body); |
2911 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
2912 | IS_VAR_WITH_NAME(add->lhs(), "x" ); |
2913 | IS_VAR_WITH_NAME(add->rhs(), "y" ); |
2914 | } |
2915 | |
2916 | { |
2917 | // Find the right var. |
2918 | // (y/8) * 8 x%8 + y%8 + z%8 => x%8 + y + z%8 |
2919 | VarHandle x("x" , kInt); |
2920 | VarHandle y("y" , kInt); |
2921 | VarHandle z("z" , kInt); |
2922 | ExprHandle body = |
2923 | (y / ExprHandle(8) * ExprHandle(8)) + (x % 8) + (y % 8) + (z % 8); |
2924 | ExprHandle simplified = IRSimplifier::simplify(body); |
2925 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
2926 | IS_NODE_WITH_NAME(Add, add->lhs(), add2); |
2927 | IS_NODE_WITH_NAME(Mod, add2->lhs(), xMod); |
2928 | IS_VAR_WITH_NAME(xMod->lhs(), "x" ); |
2929 | IS_IMM_WITH_VAL(Int, xMod->rhs(), 8); |
2930 | IS_VAR_WITH_NAME(add2->rhs(), "y" ); |
2931 | IS_NODE_WITH_NAME(Mod, add->rhs(), zMod); |
2932 | IS_VAR_WITH_NAME(zMod->lhs(), "z" ); |
2933 | IS_IMM_WITH_VAL(Int, zMod->rhs(), 8); |
2934 | } |
2935 | |
2936 | { |
2937 | // Compound. |
2938 | // (x + (z + 512 * y) % 16) + 16 * ((z + 512 * y) / 16) |
2939 | // => (z + 512 * y) + x |
2940 | VarHandle x("x" , kInt); |
2941 | VarHandle y("y" , kInt); |
2942 | VarHandle z("z" , kInt); |
2943 | |
2944 | ExprHandle body = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16); |
2945 | ExprHandle simplified = IRSimplifier::simplify(body); |
2946 | checkExprIR(simplified, "x + (z + 512 * y)" ); |
2947 | } |
2948 | } |
2949 | |
2950 | TEST(Simplify, SimplifyModRoundModPattern) { |
2951 | { |
2952 | // t/7 % 9 * 7 + t % 7 => t%63 |
2953 | VarHandle t("t" , kInt); |
2954 | ExprHandle body = (t / 7 % 9) * 7 + t % 7; |
2955 | ExprHandle simplified = IRSimplifier::simplify(body); |
2956 | IS_NODE_WITH_NAME(Mod, simplified.node(), mod); |
2957 | IS_VAR_WITH_NAME(mod->lhs(), "t" ); |
2958 | IS_IMM_WITH_VAL(Int, mod->rhs(), 63); |
2959 | } |
2960 | |
2961 | { |
2962 | // 2*t/7 % 9 * 7 + 2*t % 7 => 2*t % 63 |
2963 | VarHandle t("t" , kInt); |
2964 | ExprHandle body = (ExprHandle(2) * t / 7 % 9) * 7 + ExprHandle(2) * t % 7; |
2965 | ExprHandle simplified = IRSimplifier::simplify(body); |
2966 | IS_NODE_WITH_NAME(Mod, simplified.node(), mod); |
2967 | IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); |
2968 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
2969 | IS_VAR_WITH_NAME(mul->rhs(), "t" ); |
2970 | IS_IMM_WITH_VAL(Int, mod->rhs(), 63); |
2971 | } |
2972 | |
2973 | { |
2974 | // t/x % y * x + t % x => t%(x*y) |
2975 | VarHandle t("t" , kInt); |
2976 | VarHandle x("x" , kInt); |
2977 | VarHandle y("y" , kInt); |
2978 | ExprHandle body = (t / x % y) * x + t % x; |
2979 | ExprHandle simplified = IRSimplifier::simplify(body); |
2980 | IS_NODE_WITH_NAME(Mod, simplified.node(), mod); |
2981 | IS_VAR_WITH_NAME(mod->lhs(), "t" ); |
2982 | IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); |
2983 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
2984 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
2985 | } |
2986 | |
2987 | { |
2988 | // k*t/x % y * x + k*t % x => k*t%(x*y) |
2989 | VarHandle t("t" , kInt); |
2990 | VarHandle x("x" , kInt); |
2991 | VarHandle y("y" , kInt); |
2992 | VarHandle k("k" , kInt); |
2993 | ExprHandle body = (k * t / x % y) * x + k * t % x; |
2994 | ExprHandle simplified = IRSimplifier::simplify(body); |
2995 | checkExprIR(simplified, "(k * t) % (x * y)" ); |
2996 | } |
2997 | |
2998 | { |
2999 | // t/k/x % y * x + t/k % x => t/k%(x*y) |
3000 | VarHandle t("t" , kInt); |
3001 | VarHandle x("x" , kInt); |
3002 | VarHandle y("y" , kInt); |
3003 | VarHandle k("k" , kInt); |
3004 | ExprHandle body = (t / k / x % y) * x + t / k % x; |
3005 | ExprHandle simplified = IRSimplifier::simplify(body); |
3006 | IS_NODE_WITH_NAME(Mod, simplified.node(), mod); |
3007 | IS_NODE_WITH_NAME(Div, mod->lhs(), div); |
3008 | IS_VAR_WITH_NAME(div->lhs(), "t" ); |
3009 | IS_VAR_WITH_NAME(div->rhs(), "k" ); |
3010 | IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); |
3011 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
3012 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
3013 | } |
3014 | |
3015 | { |
3016 | // Sanity checking we wont do the optimization on floats. |
3017 | VarHandle x("x" , kFloat); |
3018 | VarHandle y("y" , kFloat); |
3019 | VarHandle z("z" , kFloat); |
3020 | ExprHandle body = ((x / y % z) * y) + (x % y); |
3021 | ExprHandle simplified = IRSimplifier::simplify(body); |
3022 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
3023 | IS_NODE_WITH_NAME(Mul, add->lhs(), mul); |
3024 | IS_NODE_WITH_NAME(Mod, mul->lhs(), mod); |
3025 | IS_NODE_WITH_NAME(Div, mod->lhs(), div); |
3026 | IS_VAR_WITH_NAME(div->lhs(), "x" ); |
3027 | IS_VAR_WITH_NAME(div->rhs(), "y" ); |
3028 | IS_VAR_WITH_NAME(mod->rhs(), "z" ); |
3029 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
3030 | IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); |
3031 | IS_VAR_WITH_NAME(mod2->lhs(), "x" ); |
3032 | IS_VAR_WITH_NAME(mod2->rhs(), "y" ); |
3033 | } |
3034 | } |
3035 | |
3036 | TEST(Simplify, SimplifyModRoundModPatternFactorization) { |
3037 | { |
3038 | // 2 * (t /7 % 9 * 7) + 2 * (t % 7) => 2 * (t % 63) |
3039 | VarHandle t("t" , kInt); |
3040 | ExprHandle body = |
3041 | ExprHandle(2) * ((t / 7 % 9) * 7) + ExprHandle(2) * (t % 7); |
3042 | ExprHandle simplified = IRSimplifier::simplify(body); |
3043 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
3044 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
3045 | IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); |
3046 | IS_VAR_WITH_NAME(mod->lhs(), "t" ); |
3047 | IS_IMM_WITH_VAL(Int, mod->rhs(), 63); |
3048 | } |
3049 | |
3050 | { |
3051 | // t /7 % 9 * 14 + 2* (t % 7) => 2* (t % 63) |
3052 | VarHandle t("t" , kInt); |
3053 | ExprHandle body = (t / 7 % 9) * 14 + ExprHandle(2) * (t % 7); |
3054 | ExprHandle simplified = IRSimplifier::simplify(body); |
3055 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
3056 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
3057 | IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); |
3058 | IS_VAR_WITH_NAME(mod->lhs(), "t" ); |
3059 | IS_IMM_WITH_VAL(Int, mod->rhs(), 63); |
3060 | } |
3061 | |
3062 | { |
3063 | // t/14 % 9 * 7 + t/2 % 7 => t/2 % 63 |
3064 | VarHandle t("t" , kInt); |
3065 | ExprHandle body = (t / 14 % 9) * 7 + t / 2 % 7; |
3066 | ExprHandle simplified = IRSimplifier::simplify(body); |
3067 | IS_NODE_WITH_NAME(Mod, simplified.node(), mod); |
3068 | IS_NODE_WITH_NAME(Div, mod->lhs(), div); |
3069 | IS_VAR_WITH_NAME(div->lhs(), "t" ); |
3070 | IS_IMM_WITH_VAL(Int, div->rhs(), 2); |
3071 | IS_IMM_WITH_VAL(Int, mod->rhs(), 63); |
3072 | } |
3073 | |
3074 | { |
3075 | // t/(7*3) % 9 * 7*3 + t % (7*3) => t % 189 |
3076 | VarHandle t("t" , kInt); |
3077 | ExprHandle body = (t / (ExprHandle(7) * ExprHandle(3)) % 9) * 7 * 3 + |
3078 | t % (ExprHandle(7) * ExprHandle(3)); |
3079 | ExprHandle simplified = IRSimplifier::simplify(body); |
3080 | IS_NODE_WITH_NAME(Mod, simplified.node(), mod); |
3081 | IS_VAR_WITH_NAME(mod->lhs(), "t" ); |
3082 | IS_IMM_WITH_VAL(Int, mod->rhs(), 189); |
3083 | } |
3084 | |
3085 | { |
3086 | // 2*(t/x % y * x) + 2*(t % x) => 2*(t%(x*y)) |
3087 | VarHandle t("t" , kInt); |
3088 | VarHandle x("x" , kInt); |
3089 | VarHandle y("y" , kInt); |
3090 | ExprHandle body = |
3091 | ExprHandle(2) * ((t / x % y) * x) + ExprHandle(2) * (t % x); |
3092 | ExprHandle simplified = IRSimplifier::simplify(body); |
3093 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
3094 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
3095 | IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); |
3096 | IS_VAR_WITH_NAME(mod->lhs(), "t" ); |
3097 | IS_NODE_WITH_NAME(Mul, mod->rhs(), mul2); |
3098 | IS_VAR_WITH_NAME(mul2->lhs(), "x" ); |
3099 | IS_VAR_WITH_NAME(mul2->rhs(), "y" ); |
3100 | } |
3101 | } |
3102 | |
3103 | TEST(Simplify, SimplifyModRoundModPatternMultivar) { |
3104 | { |
3105 | // t/7 % 9 * 7 + t % 7 + t => t % 63 + t |
3106 | VarHandle t("t" , kInt); |
3107 | ExprHandle body = (t / 7 % 9) * 7 + t % 7 + t; |
3108 | ExprHandle simplified = IRSimplifier::simplify(body); |
3109 | checkExprIR(simplified, "t % 63 + t" ); |
3110 | } |
3111 | |
3112 | { |
3113 | // t/7 % 9 * 7 + t/8 % 9 * 8 + t % 7 + t % 8 => t % 63 + t % 72 |
3114 | VarHandle t("t" , kInt); |
3115 | ExprHandle body = (t / 7 % 9) * 7 + (t / 8 % 9) * 8 + t % 7 + t % 8; |
3116 | ExprHandle simplified = IRSimplifier::simplify(body); |
3117 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
3118 | IS_NODE_WITH_NAME(Mod, add->lhs(), mod1); |
3119 | IS_VAR_WITH_NAME(mod1->lhs(), "t" ); |
3120 | IS_IMM_WITH_VAL(Int, mod1->rhs(), 63); |
3121 | IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); |
3122 | IS_VAR_WITH_NAME(mod2->lhs(), "t" ); |
3123 | IS_IMM_WITH_VAL(Int, mod2->rhs(), 72); |
3124 | } |
3125 | |
3126 | { |
3127 | // k + t/x % y * x + t % x => k + t%(x*y) |
3128 | VarHandle t("t" , kInt); |
3129 | VarHandle x("x" , kInt); |
3130 | VarHandle y("y" , kInt); |
3131 | VarHandle k("k" , kInt); |
3132 | ExprHandle body = k + (t / x % y) * x + t % x; |
3133 | ExprHandle simplified = IRSimplifier::simplify(body); |
3134 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
3135 | IS_VAR_WITH_NAME(add->lhs(), "k" ); |
3136 | IS_NODE_WITH_NAME(Mod, add->rhs(), mod); |
3137 | IS_VAR_WITH_NAME(mod->lhs(), "t" ); |
3138 | IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); |
3139 | IS_VAR_WITH_NAME(mul->lhs(), "x" ); |
3140 | IS_VAR_WITH_NAME(mul->rhs(), "y" ); |
3141 | } |
3142 | |
3143 | { |
3144 | // t/x % y * x + t % x + (t/k / x % y) * x + t/k % x |
3145 | // => t%(x*y) + t/k % (x*y) |
3146 | VarHandle t("t" , kInt); |
3147 | VarHandle x("x" , kInt); |
3148 | VarHandle y("y" , kInt); |
3149 | VarHandle k("k" , kInt); |
3150 | ExprHandle body = (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x; |
3151 | ExprHandle simplified = IRSimplifier::simplify(body); |
3152 | checkExprIR(simplified, "(t / k) % (x * y) + t % (x * y)" ); |
3153 | } |
3154 | |
3155 | { |
3156 | // 3D: (7 * ((i0_flat / 7) % 9) + i0_flat % 7) + 63 * (i0_flat / 63) |
3157 | // => io_flat |
3158 | VarHandle t("io_flat" , kInt); |
3159 | ExprHandle body = |
3160 | ExprHandle(7) * (t / 7 % 9) + t % 7 + ExprHandle(63) * (t / 63); |
3161 | ExprHandle simplified = IRSimplifier::simplify(body); |
3162 | IS_VAR_WITH_NAME(simplified.node(), "io_flat" ); |
3163 | } |
3164 | |
3165 | { // 5D: i0_flat / (11 * 10 * 9 * 7) * (7 * 9 * 10 * 11) + |
3166 | // (i0_flat / (10 * 9 * 7) % 11) * 7 * 9 * 10 + |
3167 | // (i0_flat / (9 * 7) % 10) * 7 * 9 + |
3168 | // (i0_flat / 7 % 9) * 7 + |
3169 | // i0_flat % 7 => io_flat |
3170 | VarHandle t("io_flat" , kInt); |
3171 | ExprHandle body = (t / (ExprHandle(11) * 10 * 9 * 7)) * (7 * 9 * 10 * 11) + |
3172 | (t / (ExprHandle(10) * 9 * 7) % 11) * 7 * 9 * 10 + |
3173 | (t / (ExprHandle(9) * 7) % 10) * 7 * 9 + (t / 7 % 9) * 7 + t % 7; |
3174 | ExprHandle simplified = IRSimplifier::simplify(body); |
3175 | IS_VAR_WITH_NAME(simplified.node(), "io_flat" ); |
3176 | } |
3177 | |
3178 | { |
3179 | // 3D: (m * ((i0_flat / m) % n) + i0_flat % m) + (m * n) * |
3180 | // (i0_flat / (m * n)) => io_flat |
3181 | VarHandle t("io_flat" , kInt); |
3182 | VarHandle m("m" , kInt); |
3183 | VarHandle n("n" , kInt); |
3184 | ExprHandle body = m * (t / m % n) + t % m + (m * n) * (t / (m * n)); |
3185 | ExprHandle simplified = IRSimplifier::simplify(body); |
3186 | IS_VAR_WITH_NAME(simplified.node(), "io_flat" ); |
3187 | } |
3188 | |
3189 | { // 5D: i0_flat / (k * l * n * m) * (m * n * l * k) + |
3190 | // (i0_flat / (l * n * m) % k) * m * n * l + |
3191 | // (i0_flat / (n * m) % l) * m * n + |
3192 | // (i0_flat / m % n) * m + |
3193 | // i0_flat % m => io_flat |
3194 | VarHandle t("io_flat" , kInt); |
3195 | VarHandle m("m" , kInt); |
3196 | VarHandle n("n" , kInt); |
3197 | VarHandle l("l" , kInt); |
3198 | VarHandle k("k" , kInt); |
3199 | ExprHandle body = (t / (k * l * n * m)) * (m * n * l * k) + |
3200 | (t / (l * n * m) % k) * m * n * l + (t / (n * m) % l) * m * n + |
3201 | (t / m % n) * m + t % m; |
3202 | ExprHandle simplified = IRSimplifier::simplify(body); |
3203 | IS_VAR_WITH_NAME(simplified.node(), "io_flat" ); |
3204 | } |
3205 | } |
3206 | |
3207 | TEST(Simplify, SimplifyDivisionScalarFactorization) { |
3208 | { |
3209 | // Simple factorization of numerator and denominator. |
3210 | // 8x / 4y => 2x / y. |
3211 | VarHandle x("x" , kInt); |
3212 | VarHandle y("y" , kInt); |
3213 | ExprHandle body = (x * 8) / (y * 4); |
3214 | ExprHandle simplified = IRSimplifier::simplify(body); |
3215 | IS_NODE_WITH_NAME(Div, simplified.node(), div); |
3216 | IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); |
3217 | IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); |
3218 | IS_VAR_WITH_NAME(lhs->rhs(), "x" ); |
3219 | IS_VAR_WITH_NAME(div->rhs(), "y" ); |
3220 | } |
3221 | |
3222 | { |
3223 | // Don't change anything if we can't factorize. |
3224 | VarHandle x("x" , kInt); |
3225 | VarHandle y("y" , kInt); |
3226 | ExprHandle body = (x * 7) / (y * 4); |
3227 | ExprHandle simplified = IRSimplifier::simplify(body); |
3228 | IS_NODE_WITH_NAME(Div, simplified.node(), div); |
3229 | IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); |
3230 | IS_IMM_WITH_VAL(Int, lhs->lhs(), 7); |
3231 | IS_VAR_WITH_NAME(lhs->rhs(), "x" ); |
3232 | IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); |
3233 | IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); |
3234 | IS_VAR_WITH_NAME(rhs->rhs(), "y" ); |
3235 | } |
3236 | |
3237 | { |
3238 | // Don't reorder floats. |
3239 | VarHandle x("x" , kFloat); |
3240 | VarHandle y("y" , kFloat); |
3241 | ExprHandle body = (x * 8) / (y * 4); |
3242 | ExprHandle simplified = IRSimplifier::simplify(body); |
3243 | IS_NODE_WITH_NAME(Div, simplified.node(), div); |
3244 | IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); |
3245 | IS_VAR_WITH_NAME(lhs->lhs(), "x" ); |
3246 | IS_IMM_WITH_VAL(Float, lhs->rhs(), 8.f); |
3247 | IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); |
3248 | IS_VAR_WITH_NAME(rhs->lhs(), "y" ); |
3249 | IS_IMM_WITH_VAL(Float, rhs->rhs(), 4.f); |
3250 | } |
3251 | |
3252 | { |
3253 | // Sanity check we do nothing if there are only scalar parts. |
3254 | VarHandle x("x" , kInt); |
3255 | VarHandle y("y" , kInt); |
3256 | ExprHandle body = (x * 1) / (y * 1); |
3257 | ExprHandle simplified = IRSimplifier::simplify(body); |
3258 | IS_NODE_WITH_NAME(Div, simplified.node(), div); |
3259 | IS_VAR_WITH_NAME(div->lhs(), "x" ); |
3260 | IS_VAR_WITH_NAME(div->rhs(), "y" ); |
3261 | } |
3262 | |
3263 | { |
3264 | // Can factorize amounts of variables. |
3265 | VarHandle x("x" , kInt); |
3266 | VarHandle y("y" , kInt); |
3267 | ExprHandle body = (x + x + x + x) / (y + y); |
3268 | ExprHandle simplified = IRSimplifier::simplify(body); |
3269 | IS_NODE_WITH_NAME(Div, simplified.node(), div); |
3270 | IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); |
3271 | IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); |
3272 | IS_VAR_WITH_NAME(lhs->rhs(), "x" ); |
3273 | IS_VAR_WITH_NAME(div->rhs(), "y" ); |
3274 | } |
3275 | } |
3276 | |
3277 | TEST(Simplify, SimplifyConstantBranches) { |
3278 | { |
3279 | // If the condition is constant true then take the true_value. |
3280 | // 1 ? x : y => x |
3281 | VarHandle x("x" , kInt); |
3282 | VarHandle y("y" , kInt); |
3283 | ExprHandle t(1); |
3284 | ExprHandle body = IfThenElse::make(t, x, y); |
3285 | ExprHandle simplified = IRSimplifier::simplify(body); |
3286 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
3287 | } |
3288 | |
3289 | { |
3290 | // If the condition is constant false then take the false_value. |
3291 | // 0 ? x : y => y |
3292 | VarHandle x("x" , kInt); |
3293 | VarHandle y("y" , kInt); |
3294 | ExprHandle t(0); |
3295 | ExprHandle body = IfThenElse::make(t, x, y); |
3296 | ExprHandle simplified = IRSimplifier::simplify(body); |
3297 | IS_VAR_WITH_NAME(simplified.node(), "y" ); |
3298 | } |
3299 | |
3300 | { |
3301 | // condition is simplified before checking. |
3302 | // (x-x) ? x : y => y |
3303 | VarHandle x("x" , kInt); |
3304 | VarHandle y("y" , kInt); |
3305 | ExprHandle body = IfThenElse::make(x - x, x, y); |
3306 | ExprHandle simplified = IRSimplifier::simplify(body); |
3307 | IS_VAR_WITH_NAME(simplified.node(), "y" ); |
3308 | } |
3309 | |
3310 | { |
3311 | // If both branches are the same then don't do the condition. |
3312 | // y ? x : x => x |
3313 | VarHandle x("x" , kInt); |
3314 | VarHandle y("y" , kInt); |
3315 | ExprHandle body = IfThenElse::make(y, x, x); |
3316 | ExprHandle simplified = IRSimplifier::simplify(body); |
3317 | IS_VAR_WITH_NAME(simplified.node(), "x" ); |
3318 | } |
3319 | |
3320 | { |
3321 | // If both branches simplify to the same thing it still works. |
3322 | // y ? (x + x) : (2 * x) => x |
3323 | VarHandle x("x" , kInt); |
3324 | VarHandle y("y" , kInt); |
3325 | ExprHandle body = IfThenElse::make(y, x + x, ExprHandle(2) * x); |
3326 | ExprHandle simplified = IRSimplifier::simplify(body); |
3327 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
3328 | IS_IMM_WITH_VAL(Int, mul->lhs(), 2); |
3329 | IS_VAR_WITH_NAME(mul->rhs(), "x" ); |
3330 | } |
3331 | } |
3332 | |
3333 | TEST(Simplify, SimplifyConstantCond) { |
3334 | { |
3335 | // If the condition is constant true then take the true_value. |
3336 | // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1 |
3337 | BufHandle a("A" , {1}, kInt); |
3338 | BufHandle b("B" , {1}, kInt); |
3339 | ExprHandle condition(1); |
3340 | StmtPtr true_val = Store::make(a, {0}, 1); |
3341 | StmtPtr false_val = Store::make(b, {0}, 1); |
3342 | |
3343 | CondPtr body = alloc<Cond>(condition.node(), true_val, false_val); |
3344 | StmtPtr simplified = IRSimplifier::simplify(body); |
3345 | BlockPtr block = to<Block>(simplified); |
3346 | IS_NODE_WITH_NAME(Store, block->front(), store); |
3347 | IS_VAR_WITH_NAME(store->base_handle(), "A" ); |
3348 | } |
3349 | |
3350 | { |
3351 | // If the condition is constant false then take the false_value. |
3352 | // 0 ? A[0] = 1 : B[0] = 1 => B[0] = 1 |
3353 | BufHandle a("A" , {1}, kInt); |
3354 | BufHandle b("B" , {1}, kInt); |
3355 | ExprHandle condition(0); |
3356 | StmtPtr true_val = Store::make(a, {0}, 1); |
3357 | StmtPtr false_val = Store::make(b, {0}, 1); |
3358 | |
3359 | StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val); |
3360 | StmtPtr simplified = IRSimplifier::simplify(body); |
3361 | BlockPtr block = to<Block>(simplified); |
3362 | IS_NODE_WITH_NAME(Store, block->front(), store); |
3363 | IS_VAR_WITH_NAME(store->base_handle(), "B" ); |
3364 | } |
3365 | |
3366 | { |
3367 | // condition is simplified before checking. |
3368 | // (x-x) ? A[0] = 1 : B[0] = 1 => B[0] = 1 |
3369 | VarHandle x("x" , kInt); |
3370 | BufHandle a("A" , {1}, kInt); |
3371 | BufHandle b("B" , {1}, kInt); |
3372 | ExprHandle condition(x - x); |
3373 | StmtPtr true_val = Store::make(a, {0}, 1); |
3374 | StmtPtr false_val = Store::make(b, {0}, 1); |
3375 | |
3376 | StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val); |
3377 | StmtPtr simplified = IRSimplifier::simplify(body); |
3378 | BlockPtr block = to<Block>(simplified); |
3379 | IS_NODE_WITH_NAME(Store, block->front(), store); |
3380 | IS_VAR_WITH_NAME(store->base_handle(), "B" ); |
3381 | } |
3382 | |
3383 | { |
3384 | // If both branches are the same then don't do the condition. |
3385 | // x ? A[0] = x : A[0] = x => A[0] = x |
3386 | VarHandle x("x" , kInt); |
3387 | BufHandle a("A" , {1}, kInt); |
3388 | ExprHandle condition(x - x); |
3389 | StmtPtr true_val = Store::make(a, {0}, x); |
3390 | StmtPtr false_val = Store::make(a, {0}, x); |
3391 | |
3392 | StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val); |
3393 | StmtPtr simplified = IRSimplifier::simplify(body); |
3394 | BlockPtr block = to<Block>(simplified); |
3395 | IS_NODE_WITH_NAME(Store, block->front(), store); |
3396 | IS_VAR_WITH_NAME(store->base_handle(), "A" ); |
3397 | } |
3398 | |
3399 | { |
3400 | // If both branches simplify to the same thing it still works. |
3401 | // x ? (x + x) : (2 * x) => x |
3402 | VarHandle x("x" , kInt); |
3403 | BufHandle a("A" , {1}, kInt); |
3404 | ExprHandle condition(x - x); |
3405 | StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x); |
3406 | StmtPtr false_val = Store::make(a, {0}, x + x); |
3407 | |
3408 | StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val); |
3409 | StmtPtr simplified = IRSimplifier::simplify(body); |
3410 | BlockPtr block = to<Block>(simplified); |
3411 | IS_NODE_WITH_NAME(Store, block->front(), store); |
3412 | IS_VAR_WITH_NAME(store->base_handle(), "A" ); |
3413 | } |
3414 | |
3415 | { |
3416 | // But not if they dont |
3417 | // x ? x : (2 * x) => x ? x : (2 * x) |
3418 | VarHandle x("x" , kInt); |
3419 | BufHandle a("A" , {1}, kInt); |
3420 | ExprHandle condition(x); |
3421 | StmtPtr true_val = Store::make(a, {0}, x); |
3422 | StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x); |
3423 | |
3424 | StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val); |
3425 | StmtPtr simplified = IRSimplifier::simplify(body); |
3426 | BlockPtr block = to<Block>(simplified); |
3427 | ASSERT_EQ(block, nullptr); |
3428 | } |
3429 | |
3430 | { |
3431 | StmtPtr cond = alloc<Cond>( |
3432 | ExprHandle(false).node(), |
3433 | alloc<Block>(std::vector<StmtPtr>({})), |
3434 | nullptr); |
3435 | StmtPtr simplified = IRSimplifier::simplify(cond); |
3436 | ASSERT_EQ(simplified, nullptr); |
3437 | } |
3438 | |
3439 | { |
3440 | StmtPtr cond = alloc<Cond>( |
3441 | ExprHandle(true).node(), |
3442 | nullptr, |
3443 | alloc<Block>(std::vector<StmtPtr>({}))); |
3444 | StmtPtr simplified = IRSimplifier::simplify(cond); |
3445 | ASSERT_EQ(simplified, nullptr); |
3446 | } |
3447 | } |
3448 | |
3449 | TEST(Simplify, SimplifyEliminateEmptyCond) { |
3450 | // If the branches are empty in different ways, eliminate. |
3451 | { |
3452 | VarHandle x("x" , kInt); |
3453 | ExprHandle condition(x); |
3454 | StmtPtr true_val = alloc<Block>(std::vector<StmtPtr>({})); |
3455 | |
3456 | StmtPtr body = alloc<Cond>(condition.node(), true_val, nullptr); |
3457 | StmtPtr simplified = IRSimplifier::simplify(body); |
3458 | BlockPtr block = to<Block>(simplified); |
3459 | ASSERT_NE(block, nullptr); |
3460 | ASSERT_EQ(block->nstmts(), 0); |
3461 | } |
3462 | |
3463 | { |
3464 | VarHandle x("x" , kInt); |
3465 | ExprHandle condition(x); |
3466 | StmtPtr false_val = alloc<Block>(std::vector<StmtPtr>({})); |
3467 | |
3468 | StmtPtr body = alloc<Cond>(condition.node(), nullptr, false_val); |
3469 | StmtPtr simplified = IRSimplifier::simplify(body); |
3470 | BlockPtr block = to<Block>(simplified); |
3471 | ASSERT_NE(block, nullptr); |
3472 | ASSERT_EQ(block->nstmts(), 0); |
3473 | } |
3474 | } |
3475 | |
3476 | TEST(Simplify, SimplifyConstantComparisons) { |
3477 | auto ComparisonTest = |
3478 | [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) { |
3479 | ExprHandle body = CompareSelect::make(a, b, op); |
3480 | ExprHandle simplified = IRSimplifier::simplify(body); |
3481 | IS_IMM_WITH_VAL(Int, simplified.node(), result); |
3482 | }; |
3483 | |
3484 | // Equals. |
3485 | ComparisonTest(2, 2, kEQ, 1); |
3486 | ComparisonTest(1, 2, kEQ, 0); |
3487 | ComparisonTest(2, 1, kEQ, 0); |
3488 | |
3489 | // Greater than. |
3490 | ComparisonTest(2, 2, kGT, 0); |
3491 | ComparisonTest(1, 2, kGT, 0); |
3492 | ComparisonTest(2, 1, kGT, 1); |
3493 | |
3494 | // Greater or Equal. |
3495 | ComparisonTest(2, 2, kGE, 1); |
3496 | ComparisonTest(1, 2, kGE, 0); |
3497 | ComparisonTest(2, 1, kGE, 1); |
3498 | |
3499 | // Less Than. |
3500 | ComparisonTest(2, 2, kLT, 0); |
3501 | ComparisonTest(1, 2, kLT, 1); |
3502 | ComparisonTest(2, 1, kLT, 0); |
3503 | |
3504 | // Less or Equal. |
3505 | ComparisonTest(2, 2, kLE, 1); |
3506 | ComparisonTest(1, 2, kLE, 1); |
3507 | ComparisonTest(2, 1, kLE, 0); |
3508 | |
3509 | // Not equal. |
3510 | ComparisonTest(2, 2, kNE, 0); |
3511 | ComparisonTest(1, 2, kNE, 1); |
3512 | ComparisonTest(2, 1, kNE, 1); |
3513 | |
3514 | // With specified results: |
3515 | ExprHandle body = CompareSelect::make(2, 2, 5, 42, kNE); |
3516 | ExprHandle simplified = IRSimplifier::simplify(body); |
3517 | IS_IMM_WITH_VAL(Int, simplified.node(), 42); |
3518 | } |
3519 | |
3520 | TEST(Simplify, SimplifySymbolicComparisons) { |
3521 | VarHandle x("x" , kInt); |
3522 | VarHandle y("y" , kInt); |
3523 | |
3524 | auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 1); }; |
3525 | auto TookFalseBranch = [](ExprHandle a) { |
3526 | IS_IMM_WITH_VAL(Int, a.node(), 0); |
3527 | }; |
3528 | |
3529 | // EQ |
3530 | |
3531 | // x == x => 1 |
3532 | ExprHandle body = CompareSelect::make(x, x, kEQ); |
3533 | TookTrueBranch(IRSimplifier::simplify(body)); |
3534 | |
3535 | // x == x+1 => 0 |
3536 | body = CompareSelect::make(x, x + 1, kEQ); |
3537 | TookFalseBranch(IRSimplifier::simplify(body)); |
3538 | |
3539 | // x == x * 2 cannot simplify since we don't know x is nonzero. |
3540 | body = CompareSelect::make(x, x * 2, kEQ); |
3541 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
3542 | IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); |
3543 | |
3544 | // x == x * 1 => 1 |
3545 | body = CompareSelect::make(x, x * 1, kEQ); |
3546 | TookTrueBranch(IRSimplifier::simplify(body)); |
3547 | |
3548 | { |
3549 | // x == y => x == y |
3550 | body = CompareSelect::make(x, y, kEQ); |
3551 | ExprHandle simplified = IRSimplifier::simplify(body); |
3552 | IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); |
3553 | ASSERT_EQ(cmp->compare_select_op(), kEQ); |
3554 | IS_VAR_WITH_NAME(cmp->lhs(), "x" ); |
3555 | IS_VAR_WITH_NAME(cmp->rhs(), "y" ); |
3556 | } |
3557 | |
3558 | { |
3559 | // x == 5 => x == 5 |
3560 | body = CompareSelect::make(x, 5, kEQ); |
3561 | ExprHandle simplified = IRSimplifier::simplify(body); |
3562 | IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); |
3563 | ASSERT_EQ(cmp->compare_select_op(), kEQ); |
3564 | IS_VAR_WITH_NAME(cmp->lhs(), "x" ); |
3565 | IS_IMM_WITH_VAL(Int, cmp->rhs(), 5); |
3566 | } |
3567 | |
3568 | // GT |
3569 | |
3570 | // x+1 > x => 1 |
3571 | body = CompareSelect::make(x + 1, x, kGT); |
3572 | TookTrueBranch(IRSimplifier::simplify(body)); |
3573 | |
3574 | // x > x + 1 => 0 |
3575 | body = CompareSelect::make(x, x + 1, kGT); |
3576 | TookFalseBranch(IRSimplifier::simplify(body)); |
3577 | |
3578 | // x > x - 1 => 1 |
3579 | body = CompareSelect::make(x, x - 1, kGT); |
3580 | TookTrueBranch(IRSimplifier::simplify(body)); |
3581 | |
3582 | // x - 1 > x => 0 |
3583 | body = CompareSelect::make(x - 1, x, kGT); |
3584 | TookFalseBranch(IRSimplifier::simplify(body)); |
3585 | |
3586 | // x > x => 0 |
3587 | body = CompareSelect::make(x, x, kGT); |
3588 | TookFalseBranch(IRSimplifier::simplify(body)); |
3589 | |
3590 | // x * 2 > x => x * 2 > x |
3591 | // since we don't know the sign of x. |
3592 | body = CompareSelect::make(x * 2, x, kGT); |
3593 | IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); |
3594 | |
3595 | // GE |
3596 | |
3597 | // x+1 >= x => 1 |
3598 | body = CompareSelect::make(x + 1, x, kGE); |
3599 | TookTrueBranch(IRSimplifier::simplify(body)); |
3600 | |
3601 | // x >= x + 1 => 0 |
3602 | body = CompareSelect::make(x, x + 1, kGE); |
3603 | TookFalseBranch(IRSimplifier::simplify(body)); |
3604 | |
3605 | // x >= x => 1 |
3606 | body = CompareSelect::make(x, x, kGE); |
3607 | TookTrueBranch(IRSimplifier::simplify(body)); |
3608 | |
3609 | // x * 2 >= x => x * 2 >= x |
3610 | // since we don't know the sign of x. |
3611 | body = CompareSelect::make(x * 2, x, kGE); |
3612 | IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); |
3613 | |
3614 | // LT |
3615 | |
3616 | // x+1 < x => 0 |
3617 | body = CompareSelect::make(x + 1, x, kLT); |
3618 | TookFalseBranch(IRSimplifier::simplify(body)); |
3619 | |
3620 | // x < x + 1 => 1 |
3621 | body = CompareSelect::make(x, x + 1, kLT); |
3622 | TookTrueBranch(IRSimplifier::simplify(body)); |
3623 | |
3624 | // x < x => 0 |
3625 | body = CompareSelect::make(x, x, kLT); |
3626 | TookFalseBranch(IRSimplifier::simplify(body)); |
3627 | |
3628 | // LE |
3629 | |
3630 | // x+1 <= x => 0 |
3631 | body = CompareSelect::make(x + 1, x, kLE); |
3632 | TookFalseBranch(IRSimplifier::simplify(body)); |
3633 | |
3634 | // x <= x + 1 => 1 |
3635 | body = CompareSelect::make(x, x + 1, kLE); |
3636 | TookTrueBranch(IRSimplifier::simplify(body)); |
3637 | |
3638 | // x <= x => 1 |
3639 | body = CompareSelect::make(x, x, kLE); |
3640 | TookTrueBranch(IRSimplifier::simplify(body)); |
3641 | |
3642 | // NE |
3643 | |
3644 | // x+1 != x => 1 |
3645 | body = CompareSelect::make(x + 1, x, kNE); |
3646 | TookTrueBranch(IRSimplifier::simplify(body)); |
3647 | |
3648 | // x != x + 1 => 1 |
3649 | body = CompareSelect::make(x, x + 1, kNE); |
3650 | TookTrueBranch(IRSimplifier::simplify(body)); |
3651 | |
3652 | // x != x => 0 |
3653 | body = CompareSelect::make(x, x, kNE); |
3654 | TookFalseBranch(IRSimplifier::simplify(body)); |
3655 | } |
3656 | |
3657 | TEST(Simplify, SimplifyEliminateZeroLengthFor) { |
3658 | { |
3659 | // Will eliminate zero loop For. |
3660 | BufHandle a("A" , {4}, kInt); |
3661 | BufHandle c("C" , {4}, kInt); |
3662 | VarHandle i("i" , kInt); |
3663 | auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i}))); |
3664 | StmtPtr simplified = IRSimplifier::simplify(body); |
3665 | BlockPtr block = to<Block>(simplified); |
3666 | ASSERT_EQ(block->nstmts(), 0); |
3667 | } |
3668 | |
3669 | { |
3670 | // still works if start is not zero. |
3671 | BufHandle a("A" , {4}, kInt); |
3672 | BufHandle c("C" , {4}, kInt); |
3673 | VarHandle i("i" , kInt); |
3674 | auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i}))); |
3675 | StmtPtr simplified = IRSimplifier::simplify(body); |
3676 | BlockPtr block = to<Block>(simplified); |
3677 | ASSERT_EQ(block->nstmts(), 0); |
3678 | } |
3679 | |
3680 | { |
3681 | // works if both terms are variable. |
3682 | VarHandle x("x" , kInt); |
3683 | BufHandle a("A" , {4}, kInt); |
3684 | BufHandle c("C" , {4}, kInt); |
3685 | VarHandle i("i" , kInt); |
3686 | auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i}))); |
3687 | StmtPtr simplified = IRSimplifier::simplify(body); |
3688 | BlockPtr block = to<Block>(simplified); |
3689 | ASSERT_EQ(block->nstmts(), 0); |
3690 | } |
3691 | |
3692 | { |
3693 | // works if one term simplifies down. |
3694 | VarHandle x("x" , kInt); |
3695 | BufHandle a("A" , {4}, kInt); |
3696 | BufHandle c("C" , {4}, kInt); |
3697 | VarHandle i("i" , kInt); |
3698 | auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i}))); |
3699 | StmtPtr simplified = IRSimplifier::simplify(body); |
3700 | BlockPtr block = to<Block>(simplified); |
3701 | ASSERT_EQ(block->nstmts(), 0); |
3702 | } |
3703 | |
3704 | { |
3705 | // Sanity check does nothing if the condition is not met. |
3706 | BufHandle a("A" , {4}, kInt); |
3707 | BufHandle c("C" , {4}, kInt); |
3708 | VarHandle i("i" , kInt); |
3709 | auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); |
3710 | StmtPtr simplified = IRSimplifier::simplify(body); |
3711 | IS_NODE(For, simplified); |
3712 | } |
3713 | } |
3714 | |
3715 | TEST(Simplify, SimplifyOneLoopFor) { |
3716 | { |
3717 | // Will remove the loop if the body is run once. |
3718 | BufHandle a("A" , {4}, kInt); |
3719 | BufHandle c("C" , {4}, kInt); |
3720 | VarHandle i("i" , kInt); |
3721 | auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); |
3722 | StmtPtr simplified = IRSimplifier::simplify(body); |
3723 | BlockPtr block = to<Block>(simplified); |
3724 | IS_NODE_WITH_NAME(Store, block->front(), store); |
3725 | IS_VAR_WITH_NAME(store->base_handle(), "C" ); |
3726 | IS_IMM_WITH_VAL(Int, store->flat_index(), 0); |
3727 | } |
3728 | |
3729 | { |
3730 | // still works if start is not zero. |
3731 | BufHandle a("A" , {4}, kInt); |
3732 | BufHandle c("C" , {4}, kInt); |
3733 | VarHandle i("i" , kInt); |
3734 | auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i}))); |
3735 | StmtPtr simplified = IRSimplifier::simplify(body); |
3736 | BlockPtr block = to<Block>(simplified); |
3737 | IS_NODE_WITH_NAME(Store, block->front(), store); |
3738 | IS_VAR_WITH_NAME(store->base_handle(), "C" ); |
3739 | IS_IMM_WITH_VAL(Int, store->flat_index(), 2); |
3740 | } |
3741 | |
3742 | { |
3743 | // works if both terms are variable. |
3744 | VarHandle x("x" , kInt); |
3745 | BufHandle a("A" , {4}, kInt); |
3746 | BufHandle c("C" , {4}, kInt); |
3747 | VarHandle i("i" , kInt); |
3748 | auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i}))); |
3749 | StmtPtr simplified = IRSimplifier::simplify(body); |
3750 | BlockPtr block = to<Block>(simplified); |
3751 | IS_NODE_WITH_NAME(Store, block->front(), store); |
3752 | IS_VAR_WITH_NAME(store->base_handle(), "C" ); |
3753 | IS_VAR_WITH_NAME(store->flat_index(), "x" ); |
3754 | } |
3755 | |
3756 | { |
3757 | // works if one term simplifies down. |
3758 | VarHandle x("x" , kInt); |
3759 | BufHandle a("A" , {4}, kInt); |
3760 | BufHandle c("C" , {4}, kInt); |
3761 | VarHandle i("i" , kInt); |
3762 | auto body = |
3763 | For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i}))); |
3764 | StmtPtr simplified = IRSimplifier::simplify(body); |
3765 | BlockPtr block = to<Block>(simplified); |
3766 | IS_NODE_WITH_NAME(Store, block->front(), store); |
3767 | IS_VAR_WITH_NAME(store->base_handle(), "C" ); |
3768 | IS_IMM_WITH_VAL(Int, store->flat_index(), 0); |
3769 | } |
3770 | |
3771 | { |
3772 | // Sanity check does nothing if the condition is not met. |
3773 | BufHandle a("A" , {4}, kInt); |
3774 | BufHandle c("C" , {4}, kInt); |
3775 | VarHandle i("i" , kInt); |
3776 | auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); |
3777 | StmtPtr simplified = IRSimplifier::simplify(body); |
3778 | IS_NODE(For, simplified); |
3779 | } |
3780 | } |
3781 | |
3782 | TEST(Simplify, SimplifyForWontLoseLoopOptions) { |
3783 | { |
3784 | // Sanity check does nothing if the condition is not met. |
3785 | BufHandle a("A" , {4}, kInt); |
3786 | BufHandle c("C" , {4}, kInt); |
3787 | VarHandle i("i" , kInt); |
3788 | LoopOptions options; |
3789 | options.set_gpu_block_index(LoopOptions::IDX_W); |
3790 | auto body = |
3791 | For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options); |
3792 | StmtPtr simplified = IRSimplifier::simplify(body); |
3793 | IS_NODE_WITH_NAME(For, simplified, for_); |
3794 | LoopOptions options2 = for_->loop_options(); |
3795 | ASSERT_EQ(options.gpu_block_index(), options2.gpu_block_index()); |
3796 | } |
3797 | } |
3798 | |
3799 | TEST(Simplify, SimplifyMultilevelFor) { |
3800 | { |
3801 | // Multiple layers of For will be simplified out. |
3802 | BufHandle a("A" , {4}, kInt); |
3803 | BufHandle c("C" , {4}, kInt); |
3804 | VarHandle i("i" , kInt); |
3805 | VarHandle j("j" , kInt); |
3806 | auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); |
3807 | auto outer = For::make(j, 0, 1, body); |
3808 | StmtPtr simplified = IRSimplifier::simplify(outer); |
3809 | BlockPtr block = to<Block>(simplified); |
3810 | IS_NODE_WITH_NAME(Store, block->front(), store); |
3811 | IS_VAR_WITH_NAME(store->base_handle(), "C" ); |
3812 | IS_IMM_WITH_VAL(Int, store->flat_index(), 0); |
3813 | } |
3814 | |
3815 | { |
3816 | // Will maintain an outer loop if the inner loop is eliminated. |
3817 | BufHandle a("A" , {4}, kInt); |
3818 | BufHandle c("C" , {4}, kInt); |
3819 | VarHandle i("i" , kInt); |
3820 | VarHandle j("j" , kInt); |
3821 | auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); |
3822 | auto outer = For::make(j, 0, 2, body); |
3823 | StmtPtr simplified = IRSimplifier::simplify(outer); |
3824 | ForPtr for__ = static_to<For>(simplified); |
3825 | IS_NODE_WITH_NAME(For, for__, for_); |
3826 | IS_VAR_WITH_NAME(for_->var(), "j" ); |
3827 | IS_IMM_WITH_VAL(Int, for_->start(), 0); |
3828 | IS_IMM_WITH_VAL(Int, for_->stop(), 2); |
3829 | BlockPtr block = to<Block>(for_->body()); |
3830 | ASSERT_NE(block, nullptr); |
3831 | IS_NODE_WITH_NAME(Store, block->front(), store); |
3832 | IS_VAR_WITH_NAME(store->base_handle(), "C" ); |
3833 | IS_IMM_WITH_VAL(Int, store->flat_index(), 0); |
3834 | } |
3835 | |
3836 | { |
3837 | // Will maintain inner loop if outer loops is eliminated. |
3838 | BufHandle a("A" , {4}, kInt); |
3839 | BufHandle c("C" , {4}, kInt); |
3840 | VarHandle i("i" , kInt); |
3841 | VarHandle j("j" , kInt); |
3842 | auto body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i}))); |
3843 | auto outer = For::make(j, 0, 1, body); |
3844 | StmtPtr simplified = IRSimplifier::simplify(outer); |
3845 | BlockPtr block = to<Block>(simplified); |
3846 | IS_NODE_WITH_NAME(For, block->front(), for_); |
3847 | IS_VAR_WITH_NAME(for_->var(), "i" ); |
3848 | IS_IMM_WITH_VAL(Int, for_->start(), 0); |
3849 | IS_IMM_WITH_VAL(Int, for_->stop(), 2); |
3850 | IS_NODE_WITH_NAME(Store, for_->body()->front(), store); |
3851 | IS_VAR_WITH_NAME(store->base_handle(), "C" ); |
3852 | IS_VAR_WITH_NAME(store->flat_index(), "i" ); |
3853 | } |
3854 | } |
3855 | |
3856 | TEST(Simplify, SimplifyForCleansUp) { |
3857 | { |
3858 | BufHandle a("a" , {1, 12, 1}, kFloat); |
3859 | VarHandle x("x" , kInt); |
3860 | Tensor b = Compute( |
3861 | "x" , |
3862 | {1, 12, 1}, |
3863 | [](const VarHandle& i, const VarHandle& m, const VarHandle& n) { |
3864 | return i + m + n; |
3865 | }); |
3866 | LoopNest l({b}); |
3867 | l.prepareForCodegen(); |
3868 | |
3869 | StmtPtr body = LoopNest::sanitizeNames(l.root_stmt()); |
3870 | StmtPtr simplified = IRSimplifier::simplify(body); |
3871 | |
3872 | BlockPtr block = to<Block>(simplified); |
3873 | IS_NODE_WITH_NAME(For, block->front(), for_); |
3874 | // for is over "m". |
3875 | IS_VAR_WITH_NAME(for_->var(), "j" ); |
3876 | // x[m] = m; |
3877 | IS_NODE_WITH_NAME(Store, for_->body()->front(), store); |
3878 | IS_VAR_WITH_NAME(store->flat_index(), "j" ); |
3879 | IS_VAR_WITH_NAME(store->value(), "j" ); |
3880 | } |
3881 | } |
3882 | |
3883 | TEST(Simplify, SimplifyEliminateEmptyFor) { |
3884 | { |
3885 | // Flatten many layers around an empty block to an empty block. |
3886 | StmtPtr last = alloc<Block>(std::vector<StmtPtr>({})); |
3887 | for (const auto i : c10::irange(11)) { |
3888 | (void)i; // Suppress unused variable warning |
3889 | VarHandle loopVar("loopVar" , kInt); |
3890 | last = For::make(loopVar, 0, 10, last); |
3891 | } |
3892 | |
3893 | StmtPtr simplified = IRSimplifier::simplify(last); |
3894 | IS_NODE_WITH_NAME(Block, simplified, block); |
3895 | ASSERT_EQ(block->nstmts(), 0); |
3896 | } |
3897 | } |
3898 | |
3899 | TEST(Simplify, SimplifyFlattenBlock) { |
3900 | { |
3901 | // Flatten multiple blocks down to one. |
3902 | // { { { stmt1, stmt2 } } } => { stmt1, stmt2 } |
3903 | BufHandle a("A" , {1}, kInt); |
3904 | StorePtr store1 = Store::make(a, {0}, 1); |
3905 | StorePtr store2 = Store::make(a, {0}, 0); |
3906 | |
3907 | BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store1, store2})); |
3908 | BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({block1})); |
3909 | |
3910 | BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({block2})); |
3911 | StmtPtr simplified = IRSimplifier::simplify(enclosing); |
3912 | |
3913 | IS_NODE_WITH_NAME(Block, simplified, block); |
3914 | ASSERT_EQ(block->nstmts(), 2); |
3915 | |
3916 | IS_NODE_WITH_NAME(Store, block->front(), store1_); |
3917 | IS_NODE_WITH_NAME(Store, block->back(), store2_); |
3918 | |
3919 | ASSERT_EQ(store1->value(), store1_->value()); |
3920 | ASSERT_EQ(store2->value(), store2_->value()); |
3921 | } |
3922 | |
3923 | { |
3924 | // Flatten multiple sub blocks containing statements. |
3925 | // { { stmt1 }, { stmt2 } } => { stmt1, stmt2 } |
3926 | BufHandle a("A" , {1}, kInt); |
3927 | StorePtr store1 = Store::make(a, {0}, 1); |
3928 | StorePtr store2 = Store::make(a, {0}, 0); |
3929 | |
3930 | BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store1})); |
3931 | BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({store2})); |
3932 | |
3933 | BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({block1, block2})); |
3934 | StmtPtr simplified = IRSimplifier::simplify(enclosing); |
3935 | |
3936 | IS_NODE_WITH_NAME(Block, simplified, block); |
3937 | ASSERT_EQ(block->nstmts(), 2); |
3938 | |
3939 | IS_NODE_WITH_NAME(Store, block->front(), store1_); |
3940 | IS_NODE_WITH_NAME(Store, block->back(), store2_); |
3941 | |
3942 | ASSERT_EQ(store1->value(), store1_->value()); |
3943 | ASSERT_EQ(store2->value(), store2_->value()); |
3944 | } |
3945 | |
3946 | { |
3947 | // Flatten sub blocks with different depths. |
3948 | // { stmt1 , { { stmt2 } } } => { stmt1, stmt2 } |
3949 | BufHandle a("A" , {1}, kInt); |
3950 | StorePtr store1 = Store::make(a, {0}, 1); |
3951 | StorePtr store2 = Store::make(a, {0}, 0); |
3952 | |
3953 | BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store2})); |
3954 | BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({block1})); |
3955 | |
3956 | BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({store1, block2})); |
3957 | StmtPtr simplified = IRSimplifier::simplify(enclosing); |
3958 | |
3959 | IS_NODE_WITH_NAME(Block, simplified, block); |
3960 | ASSERT_EQ(block->nstmts(), 2); |
3961 | |
3962 | IS_NODE_WITH_NAME(Store, block->front(), store1_); |
3963 | IS_NODE_WITH_NAME(Store, block->back(), store2_); |
3964 | |
3965 | ASSERT_EQ(store1->value(), store1_->value()); |
3966 | ASSERT_EQ(store2->value(), store2_->value()); |
3967 | } |
3968 | |
3969 | { |
3970 | // Flatten many layers around an empty block to an empty block. |
3971 | StmtPtr last = alloc<Block>(std::vector<StmtPtr>({})); |
3972 | for (const auto i : c10::irange(11)) { |
3973 | (void)i; // Suppress unused variable warning |
3974 | last = alloc<Block>(std::vector<StmtPtr>({last})); |
3975 | } |
3976 | |
3977 | StmtPtr simplified = IRSimplifier::simplify(last); |
3978 | IS_NODE_WITH_NAME(Block, simplified, block); |
3979 | ASSERT_EQ(block->nstmts(), 0); |
3980 | } |
3981 | } |
3982 | |
3983 | TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { |
3984 | { |
3985 | // Simple positive case. |
3986 | BufHandle b("x" , {0}, kInt); |
3987 | |
3988 | AllocatePtr alloc_ = Allocate::make(b); |
3989 | FreePtr free_ = Free::make(b); |
3990 | |
3991 | BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({alloc_, free_})); |
3992 | ASSERT_EQ(block1->nstmts(), 2); |
3993 | |
3994 | StmtPtr simplified = IRSimplifier::simplify(block1); |
3995 | IS_NODE_WITH_NAME(Block, simplified, block2); |
3996 | ASSERT_EQ(block2->nstmts(), 0); |
3997 | } |
3998 | |
3999 | { |
4000 | // Simple negative case. |
4001 | BufHandle b("x" , {2}, kInt); |
4002 | |
4003 | AllocatePtr alloc_ = Allocate::make(b); |
4004 | FreePtr free_ = Free::make(b); |
4005 | |
4006 | BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({alloc_, free_})); |
4007 | ASSERT_EQ(block1->nstmts(), 2); |
4008 | |
4009 | StmtPtr simplified = IRSimplifier::simplify(block1); |
4010 | IS_NODE_WITH_NAME(Block, simplified, block2); |
4011 | ASSERT_EQ(block2->nstmts(), 2); |
4012 | } |
4013 | |
4014 | { |
4015 | // Finds right Alloc/Free. |
4016 | BufHandle b1("x" , {0}, kInt); |
4017 | BufHandle b2("y" , {2}, kInt); |
4018 | |
4019 | AllocatePtr alloc1 = Allocate::make(b1); |
4020 | AllocatePtr alloc2 = Allocate::make(b2); |
4021 | FreePtr free2_ = Free::make(b2); |
4022 | FreePtr free1_ = Free::make(b1); |
4023 | |
4024 | BlockPtr block1 = |
4025 | alloc<Block>(std::vector<StmtPtr>({alloc1, alloc2, free2_, free1_})); |
4026 | ASSERT_EQ(block1->nstmts(), 4); |
4027 | |
4028 | StmtPtr simplified = IRSimplifier::simplify(block1); |
4029 | IS_NODE_WITH_NAME(Block, simplified, block2); |
4030 | ASSERT_EQ(block2->nstmts(), 2); |
4031 | IS_NODE_WITH_NAME(Allocate, block2->stmts().front(), simplified_alloc); |
4032 | IS_VAR_WITH_NAME(simplified_alloc->buffer_var(), "y" ); |
4033 | IS_NODE_WITH_NAME(Free, block2->stmts().back(), simplified_free); |
4034 | ASSERT_EQ(simplified_alloc->buffer_var(), simplified_free->buffer_var()); |
4035 | } |
4036 | |
4037 | { |
4038 | // Dynamic shape. |
4039 | VarHandle z("z" , kInt); |
4040 | BufHandle b1("x" , {0}, kInt); |
4041 | BufHandle b2("y" , {z}, kInt); |
4042 | |
4043 | AllocatePtr alloc1 = Allocate::make(b1); |
4044 | AllocatePtr alloc2 = Allocate::make(b2); |
4045 | FreePtr free2_ = Free::make(b2); |
4046 | FreePtr free1_ = Free::make(b1); |
4047 | |
4048 | BlockPtr block1 = |
4049 | alloc<Block>(std::vector<StmtPtr>({alloc1, alloc2, free2_, free1_})); |
4050 | ASSERT_EQ(block1->nstmts(), 4); |
4051 | StmtPtr simplified = IRSimplifier::simplify(block1); |
4052 | IS_NODE_WITH_NAME(Block, simplified, block2); |
4053 | ASSERT_EQ(block2->nstmts(), 2); |
4054 | } |
4055 | } |
4056 | |
4057 | TEST(Simplify, DontSimplifyRand) { |
4058 | { |
4059 | // rand() + rand() = rand() + rand() NOT 2 * rand(). |
4060 | ExprHandle body = |
4061 | Intrinsics::make(kRand, kInt) + Intrinsics::make(kRand, kInt); |
4062 | ExprHandle simplified = IRSimplifier::simplify(body); |
4063 | IS_NODE_WITH_NAME(Add, simplified.node(), add); |
4064 | IS_RAND(add->lhs()); |
4065 | IS_RAND(add->rhs()); |
4066 | } |
4067 | |
4068 | { |
4069 | // rand() - rand() = rand() - rand() NOT 0. |
4070 | ExprHandle body = |
4071 | Intrinsics::make(kRand, kFloat) - Intrinsics::make(kRand, kFloat); |
4072 | ExprHandle simplified = IRSimplifier::simplify(body); |
4073 | IS_NODE_WITH_NAME(Sub, simplified.node(), sub); |
4074 | IS_RAND(sub->lhs()); |
4075 | IS_RAND(sub->rhs()); |
4076 | } |
4077 | |
4078 | { |
4079 | // rand() * rand() = rand() * rand(). |
4080 | ExprHandle body = |
4081 | Intrinsics::make(kRand, kInt) * Intrinsics::make(kRand, kInt); |
4082 | ExprHandle simplified = IRSimplifier::simplify(body); |
4083 | IS_NODE_WITH_NAME(Mul, simplified.node(), mul); |
4084 | IS_RAND(mul->lhs()); |
4085 | IS_RAND(mul->rhs()); |
4086 | } |
4087 | } |
4088 | |
4089 | TEST(Simplify, SimplifyReorderForCond) { |
4090 | BufHandle a("A" , {4}, kInt); |
4091 | BufHandle b("B" , {1}, kInt); |
4092 | BufHandle c("C" , {4}, kInt); |
4093 | VarHandle i("i" , kInt); |
4094 | VarHandle j("j" , kInt); |
4095 | |
4096 | { |
4097 | // for ( if ( ... ) ) => if ( for ( ... ) ). |
4098 | auto body = For::make( |
4099 | i, |
4100 | 0, |
4101 | 4, |
4102 | Cond::make( |
4103 | CompareSelect::make(j, 10, CompareSelectOperation::kLT), |
4104 | Store::make(c, {i}, Load::make(a, {i})), |
4105 | nullptr)); |
4106 | |
4107 | StmtPtr simplified = IRSimplifier::simplify(body); |
4108 | IS_NODE_WITH_NAME(Cond, simplified, cond); |
4109 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); |
4110 | IS_NODE_WITH_NAME(For, true_block->front(), loop); |
4111 | } |
4112 | |
4113 | { |
4114 | // Can't reorder if condition is dependent on the loop var. |
4115 | auto body = For::make( |
4116 | i, |
4117 | 0, |
4118 | 4, |
4119 | Cond::make( |
4120 | CompareSelect::make(i, 2, CompareSelectOperation::kEQ), |
4121 | Store::make(c, {i}, Load::make(a, {i})), |
4122 | nullptr)); |
4123 | |
4124 | StmtPtr simplified = IRSimplifier::simplify(body); |
4125 | IS_NODE_WITH_NAME(For, simplified, loop); |
4126 | IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); |
4127 | } |
4128 | |
4129 | { |
4130 | // Can't reorder if condition is dependent on a var that is modified inside |
4131 | // the loop. |
4132 | auto body = For::make( |
4133 | i, |
4134 | 0, |
4135 | 4, |
4136 | Cond::make( |
4137 | CompareSelect::make( |
4138 | Load::make(c, {0}), 10, CompareSelectOperation::kLT), |
4139 | Store::make(c, {0}, Load::make(a, {i})), |
4140 | nullptr)); |
4141 | |
4142 | StmtPtr simplified = IRSimplifier::simplify(body); |
4143 | IS_NODE_WITH_NAME(For, simplified, loop); |
4144 | IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); |
4145 | } |
4146 | |
4147 | { |
4148 | // Condition based on buffer not referenced in body. Can reorder here. |
4149 | auto body = For::make( |
4150 | i, |
4151 | 0, |
4152 | 4, |
4153 | Cond::make( |
4154 | CompareSelect::make( |
4155 | Load::make(b, {0}), 10, CompareSelectOperation::kLT), |
4156 | Store::make(c, {0}, Load::make(a, {i})), |
4157 | nullptr)); |
4158 | |
4159 | StmtPtr simplified = IRSimplifier::simplify(body); |
4160 | IS_NODE_WITH_NAME(Cond, simplified, cond); |
4161 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); |
4162 | IS_NODE_WITH_NAME(For, true_block->front(), loop); |
4163 | } |
4164 | |
4165 | { |
4166 | // Condition based on buffer read only in body. Can reorder here. |
4167 | auto body = For::make( |
4168 | i, |
4169 | 0, |
4170 | 4, |
4171 | Cond::make( |
4172 | CompareSelect::make( |
4173 | Load::make(a, {0}), 10, CompareSelectOperation::kLT), |
4174 | Store::make(c, {0}, Load::make(a, {i})), |
4175 | nullptr)); |
4176 | |
4177 | StmtPtr simplified = IRSimplifier::simplify(body); |
4178 | IS_NODE_WITH_NAME(Cond, simplified, cond); |
4179 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); |
4180 | IS_NODE_WITH_NAME(For, true_block->front(), loop); |
4181 | } |
4182 | |
4183 | { |
4184 | // Condition depends on Let in the loop. Cannot reorder. |
4185 | auto body = For::make( |
4186 | i, |
4187 | 0, |
4188 | 4, |
4189 | Block::make( |
4190 | {Let::make(j, 3), |
4191 | Cond::make( |
4192 | CompareSelect::make(j, 10, CompareSelectOperation::kLT), |
4193 | Store::make(c, {0}, Load::make(a, {i})), |
4194 | nullptr)})); |
4195 | |
4196 | StmtPtr simplified = IRSimplifier::simplify(body); |
4197 | IS_NODE_WITH_NAME(For, simplified, loop); |
4198 | IS_NODE_WITH_NAME(Let, loop->body()->front(), let); |
4199 | IS_NODE_WITH_NAME(Cond, loop->body()->back(), cond); |
4200 | } |
4201 | |
4202 | { |
4203 | // Multi level Ifs where all conditions are distinct. Move BOTH Cond |
4204 | // statements outside the loop. |
4205 | auto body = For::make( |
4206 | i, |
4207 | 0, |
4208 | 4, |
4209 | Cond::make( |
4210 | CompareSelect::make( |
4211 | Load::make(a, {0}), 10, CompareSelectOperation::kLT), |
4212 | Cond::make( |
4213 | CompareSelect::make(j, 10, CompareSelectOperation::kEQ), |
4214 | Store::make(c, {0}, Load::make(a, {i})), |
4215 | nullptr), |
4216 | nullptr)); |
4217 | |
4218 | StmtPtr simplified = IRSimplifier::simplify(body); |
4219 | IS_NODE_WITH_NAME(Cond, simplified, cond); |
4220 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); |
4221 | IS_NODE_WITH_NAME(Cond, true_block->front(), cond2); |
4222 | IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_block2); |
4223 | IS_NODE_WITH_NAME(For, true_block2->front(), loop); |
4224 | } |
4225 | |
4226 | { |
4227 | // Multi level Ifs where the inner condition does depend on a loop var, |
4228 | // reorder only the first Cond. |
4229 | auto body = For::make( |
4230 | i, |
4231 | 0, |
4232 | 4, |
4233 | Cond::make( |
4234 | CompareSelect::make( |
4235 | Load::make(a, {0}), 10, CompareSelectOperation::kLT), |
4236 | Cond::make( |
4237 | CompareSelect::make(i, 3, CompareSelectOperation::kEQ), |
4238 | Store::make(c, {0}, Load::make(a, {i})), |
4239 | nullptr), |
4240 | nullptr)); |
4241 | |
4242 | StmtPtr simplified = IRSimplifier::simplify(body); |
4243 | IS_NODE_WITH_NAME(Cond, simplified, cond); |
4244 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); |
4245 | IS_NODE_WITH_NAME(For, true_block->front(), loop); |
4246 | IS_NODE_WITH_NAME(Block, loop->body(), loop_body); |
4247 | IS_NODE_WITH_NAME(Cond, loop_body->front(), cond2); |
4248 | } |
4249 | |
4250 | { |
4251 | // Don't reorder if there's an else block of the Cond. |
4252 | // We could, but is it much better? |
4253 | auto body = For::make( |
4254 | i, |
4255 | 0, |
4256 | 4, |
4257 | Cond::make( |
4258 | CompareSelect::make(j, 10, CompareSelectOperation::kLT), |
4259 | Store::make(c, {0}, Load::make(a, {i})), |
4260 | Store::make(c, {0}, 0))); |
4261 | |
4262 | StmtPtr simplified = IRSimplifier::simplify(body); |
4263 | IS_NODE_WITH_NAME(For, simplified, loop); |
4264 | IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); |
4265 | } |
4266 | |
4267 | { |
4268 | // Condition uses distinct region of Tensor. |
4269 | // We could reorder here wih better analysis, but we don't. Included for |
4270 | // completeness. |
4271 | auto body = For::make( |
4272 | i, |
4273 | 0, |
4274 | 4, |
4275 | Cond::make( |
4276 | CompareSelect::make( |
4277 | Load::make(c, {0}), 10, CompareSelectOperation::kLT), |
4278 | Store::make(c, {1}, Load::make(a, {i})), |
4279 | nullptr)); |
4280 | |
4281 | StmtPtr simplified = IRSimplifier::simplify(body); |
4282 | IS_NODE_WITH_NAME(For, simplified, loop); |
4283 | IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); |
4284 | } |
4285 | } |
4286 | |
4287 | TEST(Simplify, SimplifyFuseConditions) { |
4288 | BufHandle a("A" , {2}, kInt); |
4289 | BufHandle b("B" , {2}, kInt); |
4290 | VarHandle i("i" , kInt); |
4291 | VarHandle j("j" , kInt); |
4292 | |
4293 | { |
4294 | // Can fuse since the conditions are identical. |
4295 | // if (A) { X }; if (A) { Y }; => if (A) { X; Y } |
4296 | auto body = Block::make( |
4297 | {Cond::make( |
4298 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4299 | Store::make(a, {0}, i), |
4300 | nullptr), |
4301 | Cond::make( |
4302 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4303 | Store::make(a, {1}, i), |
4304 | nullptr)}); |
4305 | |
4306 | StmtPtr simplified = IRSimplifier::simplify(body); |
4307 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
4308 | IS_NODE_WITH_NAME(Block, simplified, block); |
4309 | ASSERT_EQ(block->nstmts(), 1); |
4310 | IS_NODE_WITH_NAME(Cond, block->front(), cond); |
4311 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); |
4312 | ASSERT_EQ(true_stmt->nstmts(), 2); |
4313 | ASSERT_EQ(cond->false_stmt(), nullptr); |
4314 | } |
4315 | |
4316 | { |
4317 | // Can't fuse, conditions are not identical in lhs (i != j). |
4318 | auto body = Block::make( |
4319 | {Cond::make( |
4320 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4321 | Store::make(a, {0}, i), |
4322 | nullptr), |
4323 | Cond::make( |
4324 | CompareSelect::make(j, 10, CompareSelectOperation::kLT), |
4325 | Store::make(a, {1}, i), |
4326 | nullptr)}); |
4327 | |
4328 | StmtPtr simplified = IRSimplifier::simplify(body); |
4329 | IS_NODE_WITH_NAME(Block, simplified, block); |
4330 | ASSERT_EQ(block->nstmts(), 2); |
4331 | IS_NODE_WITH_NAME(Cond, block->front(), cond1); |
4332 | IS_NODE_WITH_NAME(Cond, block->back(), cond2); |
4333 | |
4334 | IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); |
4335 | IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); |
4336 | ASSERT_EQ(true_stmt1->nstmts(), 1); |
4337 | ASSERT_EQ(true_stmt2->nstmts(), 1); |
4338 | |
4339 | ASSERT_EQ(cond1->false_stmt(), nullptr); |
4340 | ASSERT_EQ(cond2->false_stmt(), nullptr); |
4341 | } |
4342 | { |
4343 | // Can't fuse, conditions are not identical in rhs (10 != 11). |
4344 | auto body = Block::make( |
4345 | {Cond::make( |
4346 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4347 | Store::make(a, {0}, i), |
4348 | nullptr), |
4349 | Cond::make( |
4350 | CompareSelect::make(i, 11, CompareSelectOperation::kLT), |
4351 | Store::make(a, {1}, i), |
4352 | nullptr)}); |
4353 | |
4354 | StmtPtr simplified = IRSimplifier::simplify(body); |
4355 | IS_NODE_WITH_NAME(Block, simplified, block); |
4356 | ASSERT_EQ(block->nstmts(), 2); |
4357 | IS_NODE_WITH_NAME(Cond, block->front(), cond1); |
4358 | IS_NODE_WITH_NAME(Cond, block->back(), cond2); |
4359 | |
4360 | IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); |
4361 | IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); |
4362 | ASSERT_EQ(true_stmt1->nstmts(), 1); |
4363 | ASSERT_EQ(true_stmt2->nstmts(), 1); |
4364 | |
4365 | ASSERT_EQ(cond1->false_stmt(), nullptr); |
4366 | ASSERT_EQ(cond2->false_stmt(), nullptr); |
4367 | } |
4368 | |
4369 | { |
4370 | // Can't fuse, conditions are not identical in operation (LT vs GT). |
4371 | auto body = Block::make( |
4372 | {Cond::make( |
4373 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4374 | Store::make(a, {0}, i), |
4375 | nullptr), |
4376 | Cond::make( |
4377 | CompareSelect::make(i, 10, CompareSelectOperation::kGT), |
4378 | Store::make(a, {1}, i), |
4379 | nullptr)}); |
4380 | |
4381 | StmtPtr simplified = IRSimplifier::simplify(body); |
4382 | IS_NODE_WITH_NAME(Block, simplified, block); |
4383 | ASSERT_EQ(block->nstmts(), 2); |
4384 | IS_NODE_WITH_NAME(Cond, block->front(), cond1); |
4385 | IS_NODE_WITH_NAME(Cond, block->back(), cond2); |
4386 | |
4387 | IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); |
4388 | IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); |
4389 | ASSERT_EQ(true_stmt1->nstmts(), 1); |
4390 | ASSERT_EQ(true_stmt2->nstmts(), 1); |
4391 | |
4392 | ASSERT_EQ(cond1->false_stmt(), nullptr); |
4393 | ASSERT_EQ(cond2->false_stmt(), nullptr); |
4394 | } |
4395 | |
4396 | { |
4397 | // Can't fuse, CompareSelect results are different. |
4398 | // Actually we totally could if we normalized CompareSelect results, but |
4399 | // TODO for later. |
4400 | auto body = Block::make( |
4401 | {Cond::make( |
4402 | CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT), |
4403 | Store::make(a, {0}, i), |
4404 | nullptr), |
4405 | Cond::make( |
4406 | CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT), |
4407 | Store::make(a, {1}, i), |
4408 | nullptr)}); |
4409 | |
4410 | StmtPtr simplified = IRSimplifier::simplify(body); |
4411 | IS_NODE_WITH_NAME(Block, simplified, block); |
4412 | ASSERT_EQ(block->nstmts(), 2); |
4413 | IS_NODE_WITH_NAME(Cond, block->front(), cond1); |
4414 | IS_NODE_WITH_NAME(Cond, block->back(), cond2); |
4415 | |
4416 | IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); |
4417 | IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); |
4418 | ASSERT_EQ(true_stmt1->nstmts(), 1); |
4419 | ASSERT_EQ(true_stmt2->nstmts(), 1); |
4420 | |
4421 | ASSERT_EQ(cond1->false_stmt(), nullptr); |
4422 | ASSERT_EQ(cond2->false_stmt(), nullptr); |
4423 | } |
4424 | |
4425 | { |
4426 | // Can fuse with false stmt only. |
4427 | auto body = Block::make( |
4428 | {Cond::make( |
4429 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4430 | nullptr, |
4431 | Store::make(a, {0}, i)), |
4432 | Cond::make( |
4433 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4434 | nullptr, |
4435 | Store::make(a, {1}, i))}); |
4436 | |
4437 | StmtPtr simplified = IRSimplifier::simplify(body); |
4438 | IS_NODE_WITH_NAME(Block, simplified, block); |
4439 | ASSERT_EQ(block->nstmts(), 1); |
4440 | IS_NODE_WITH_NAME(Cond, block->front(), cond); |
4441 | IS_NODE_WITH_NAME(Block, cond->false_stmt(), false_stmt); |
4442 | ASSERT_EQ(false_stmt->nstmts(), 2); |
4443 | ASSERT_EQ(cond->true_stmt(), nullptr); |
4444 | } |
4445 | |
4446 | { |
4447 | // Can fuse with both true and false stmt. |
4448 | auto body = Block::make( |
4449 | {Cond::make( |
4450 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4451 | Store::make(a, {0}, i), |
4452 | Store::make(b, {0}, i)), |
4453 | Cond::make( |
4454 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4455 | Store::make(a, {1}, i), |
4456 | Store::make(b, {1}, i))}); |
4457 | |
4458 | StmtPtr simplified = IRSimplifier::simplify(body); |
4459 | IS_NODE_WITH_NAME(Block, simplified, block); |
4460 | ASSERT_EQ(block->nstmts(), 1); |
4461 | IS_NODE_WITH_NAME(Cond, block->front(), cond); |
4462 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); |
4463 | ASSERT_EQ(true_stmt->nstmts(), 2); |
4464 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); |
4465 | ASSERT_EQ(false_stmt->nstmts(), 2); |
4466 | } |
4467 | |
4468 | { |
4469 | // Can fuse with mismatched true / false stmt existing |
4470 | auto body = Block::make( |
4471 | {Cond::make( |
4472 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4473 | Store::make(a, {0}, i), |
4474 | nullptr), |
4475 | Cond::make( |
4476 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4477 | nullptr, |
4478 | Store::make(b, {1}, i))}); |
4479 | |
4480 | StmtPtr simplified = IRSimplifier::simplify(body); |
4481 | IS_NODE_WITH_NAME(Block, simplified, block); |
4482 | ASSERT_EQ(block->nstmts(), 1); |
4483 | IS_NODE_WITH_NAME(Cond, block->front(), cond); |
4484 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); |
4485 | ASSERT_EQ(true_stmt->nstmts(), 1); |
4486 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); |
4487 | ASSERT_EQ(false_stmt->nstmts(), 1); |
4488 | } |
4489 | |
4490 | { |
4491 | // Can fuse partial block contents, ie when there are non fused stmts before |
4492 | // and after. |
4493 | // before: |
4494 | // if (j < 10) { A[0] = j; } |
4495 | // if (i < 10) { A[0] = i; } |
4496 | // if (i < 10) { A[1] = i; } |
4497 | // if (i < 11) { A[1] = j; } |
4498 | // |
4499 | // after: |
4500 | // |
4501 | // if (j < 10) { A[0] = j; } |
4502 | // if (i < 10) { |
4503 | // A[0] = i; |
4504 | // A[1] = i; |
4505 | // } |
4506 | // if (i < 11) { A[1] = j; } |
4507 | |
4508 | auto body = Block::make({ |
4509 | Cond::make( |
4510 | CompareSelect::make(j, 10, CompareSelectOperation::kLT), |
4511 | Store::make(a, {0}, j), |
4512 | nullptr), |
4513 | Cond::make( |
4514 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4515 | Store::make(a, {0}, i), |
4516 | nullptr), |
4517 | Cond::make( |
4518 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4519 | Store::make(a, {1}, i), |
4520 | nullptr), |
4521 | Cond::make( |
4522 | CompareSelect::make(i, 11, CompareSelectOperation::kLT), |
4523 | Store::make(a, {1}, j), |
4524 | nullptr), |
4525 | }); |
4526 | StmtPtr simplified = IRSimplifier::simplify(body); |
4527 | IS_NODE_WITH_NAME(Block, simplified, block); |
4528 | ASSERT_EQ(block->nstmts(), 3); |
4529 | auto it = block->begin(); |
4530 | it++; |
4531 | IS_NODE_WITH_NAME(Cond, *it, cond); |
4532 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); |
4533 | ASSERT_EQ(true_stmt->nstmts(), 2); |
4534 | ASSERT_EQ(cond->false_stmt(), nullptr); |
4535 | } |
4536 | |
4537 | { |
4538 | // Can fuse longer sequences of identical conditions. |
4539 | auto body = Block::make({ |
4540 | Cond::make( |
4541 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4542 | Store::make(a, {0}, j), |
4543 | nullptr), |
4544 | Cond::make( |
4545 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4546 | Store::make(a, {0}, i), |
4547 | nullptr), |
4548 | Cond::make( |
4549 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4550 | Store::make(a, {1}, i), |
4551 | nullptr), |
4552 | Cond::make( |
4553 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4554 | Store::make(a, {1}, j), |
4555 | nullptr), |
4556 | }); |
4557 | StmtPtr simplified = IRSimplifier::simplify(body); |
4558 | IS_NODE_WITH_NAME(Block, simplified, block); |
4559 | ASSERT_EQ(block->nstmts(), 1); |
4560 | IS_NODE_WITH_NAME(Cond, block->front(), cond); |
4561 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); |
4562 | ASSERT_EQ(true_stmt->nstmts(), 4); |
4563 | ASSERT_EQ(cond->false_stmt(), nullptr); |
4564 | } |
4565 | |
4566 | { |
4567 | // Can't fuse through a non condition. |
4568 | auto body = Block::make({ |
4569 | Cond::make( |
4570 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4571 | Store::make(a, {0}, j), |
4572 | nullptr), |
4573 | Cond::make( |
4574 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4575 | Store::make(a, {0}, i), |
4576 | nullptr), |
4577 | Store::make(b, {1}, i + j), |
4578 | Cond::make( |
4579 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4580 | Store::make(a, {1}, i), |
4581 | nullptr), |
4582 | Cond::make( |
4583 | CompareSelect::make(i, 10, CompareSelectOperation::kLT), |
4584 | Store::make(a, {1}, j), |
4585 | nullptr), |
4586 | }); |
4587 | StmtPtr simplified = IRSimplifier::simplify(body); |
4588 | IS_NODE_WITH_NAME(Block, simplified, block); |
4589 | ASSERT_EQ(block->nstmts(), 3); |
4590 | IS_NODE_WITH_NAME(Cond, block->front(), cond); |
4591 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); |
4592 | ASSERT_EQ(true_stmt->nstmts(), 2); |
4593 | ASSERT_EQ(cond->false_stmt(), nullptr); |
4594 | |
4595 | IS_NODE_WITH_NAME(Cond, block->back(), cond2); |
4596 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt2); |
4597 | ASSERT_EQ(true_stmt2->nstmts(), 2); |
4598 | ASSERT_EQ(cond2->false_stmt(), nullptr); |
4599 | |
4600 | auto it = block->begin(); |
4601 | it++; |
4602 | IS_NODE_WITH_NAME(Store, *it, middle); |
4603 | } |
4604 | |
4605 | { |
4606 | // Can fuse if the conditions simplify to the same thing. |
4607 | auto body = Block::make( |
4608 | {Cond::make( |
4609 | CompareSelect::make( |
4610 | i * 2, |
4611 | ExprHandle(87) % ExprHandle(11), |
4612 | CompareSelectOperation::kLT), |
4613 | Store::make(a, {0}, i), |
4614 | nullptr), |
4615 | Cond::make( |
4616 | CompareSelect::make( |
4617 | i * 2, |
4618 | ExprHandle(300) / ExprHandle(30), |
4619 | CompareSelectOperation::kLT), |
4620 | Store::make(a, {1}, i), |
4621 | nullptr)}); |
4622 | StmtPtr simplified = IRSimplifier::simplify(body); |
4623 | IS_NODE_WITH_NAME(Block, simplified, block); |
4624 | ASSERT_EQ(block->nstmts(), 1); |
4625 | IS_NODE_WITH_NAME(Cond, block->front(), cond); |
4626 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); |
4627 | ASSERT_EQ(true_stmt->nstmts(), 2); |
4628 | ASSERT_EQ(cond->false_stmt(), nullptr); |
4629 | } |
4630 | |
4631 | { |
4632 | // Can fuse non-CompareSelects. |
4633 | // if (i) { X } if (i) { Y } => if (i) { X; Y } |
4634 | auto body = Block::make( |
4635 | {Cond::make(i, Store::make(a, {0}, i), nullptr), |
4636 | Cond::make(i, Store::make(a, {1}, i), nullptr)}); |
4637 | |
4638 | StmtPtr simplified = IRSimplifier::simplify(body); |
4639 | IS_NODE_WITH_NAME(Block, simplified, block); |
4640 | ASSERT_EQ(block->nstmts(), 1); |
4641 | IS_NODE_WITH_NAME(Cond, block->front(), cond); |
4642 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); |
4643 | ASSERT_EQ(true_stmt->nstmts(), 2); |
4644 | ASSERT_EQ(cond->false_stmt(), nullptr); |
4645 | } |
4646 | |
4647 | { |
4648 | // Sanity check wont fuse different non-CompareSelects. |
4649 | auto body = Block::make( |
4650 | {Cond::make(i, Store::make(a, {0}, i), nullptr), |
4651 | Cond::make(j, Store::make(a, {1}, i), nullptr)}); |
4652 | |
4653 | StmtPtr simplified = IRSimplifier::simplify(body); |
4654 | IS_NODE_WITH_NAME(Block, simplified, block); |
4655 | ASSERT_EQ(block->nstmts(), 2); |
4656 | IS_NODE_WITH_NAME(Cond, block->front(), cond1); |
4657 | IS_NODE_WITH_NAME(Cond, block->back(), cond2); |
4658 | } |
4659 | |
4660 | { |
4661 | // Sanity check constant condition elimination still occurs when merging is |
4662 | // possible. |
4663 | auto body = Block::make( |
4664 | {Cond::make(1, Store::make(a, {0}, i), nullptr), |
4665 | Cond::make(1, Store::make(a, {1}, i), nullptr)}); |
4666 | StmtPtr simplified = IRSimplifier::simplify(body); |
4667 | IS_NODE_WITH_NAME(Block, simplified, block); |
4668 | ASSERT_EQ(block->nstmts(), 2); |
4669 | IS_NODE_WITH_NAME(Store, block->front(), store1); |
4670 | IS_NODE_WITH_NAME(Store, block->back(), store2); |
4671 | } |
4672 | |
4673 | { |
4674 | // Sanity check for-cond reordering occurs after fusing. |
4675 | auto body = For::make( |
4676 | i, |
4677 | 0, |
4678 | 4, |
4679 | Block::make( |
4680 | {Cond::make( |
4681 | CompareSelect::make(j, 10, CompareSelectOperation::kLT), |
4682 | Store::make(a, {1}, Load::make(b, {0})), |
4683 | nullptr), |
4684 | Cond::make( |
4685 | CompareSelect::make(j, 10, CompareSelectOperation::kLT), |
4686 | Store::make(a, {2}, Load::make(b, {0})), |
4687 | nullptr)})); |
4688 | |
4689 | StmtPtr simplified = IRSimplifier::simplify(body); |
4690 | IS_NODE_WITH_NAME(Cond, simplified, cond); |
4691 | IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); |
4692 | IS_NODE_WITH_NAME(For, true_block->front(), loop); |
4693 | } |
4694 | } |
4695 | |
4696 | TEST(Simplify, SimplifySyncThreads) { |
4697 | BufHandle a("A" , {4}, kInt); |
4698 | VarHandle i("i" , kInt); |
4699 | |
4700 | { |
4701 | // Merge two inner SyncThreads. |
4702 | auto body = Block::make( |
4703 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
4704 | {Store::make(a, {0}, 1), |
4705 | alloc<SyncThreads>(), |
4706 | alloc<SyncThreads>(), |
4707 | Store::make(a, {1}, 0)}); |
4708 | StmtPtr simplified = IRSimplifier::simplify(body); |
4709 | IS_NODE_WITH_NAME(Block, simplified, block); |
4710 | ASSERT_EQ(block->nstmts(), 3); |
4711 | auto it = block->begin(); |
4712 | IS_NODE(Store, *it++); |
4713 | IS_NODE(SyncThreads, *it++); |
4714 | IS_NODE(Store, *it++); |
4715 | } |
4716 | |
4717 | { |
4718 | // Eliminate outer SyncThreads. |
4719 | auto body = Block::make( |
4720 | {alloc<SyncThreads>(), Store::make(a, {1}, 0), alloc<SyncThreads>()}); |
4721 | |
4722 | StmtPtr simplified = IRSimplifier::simplify(body); |
4723 | IS_NODE_WITH_NAME(Block, simplified, block); |
4724 | ASSERT_EQ(block->nstmts(), 1); |
4725 | auto it = block->begin(); |
4726 | IS_NODE(Store, *it); |
4727 | } |
4728 | |
4729 | { |
4730 | // Merge many inner SyncThreads. |
4731 | auto body = Block::make( |
4732 | {Store::make(a, {0}, 1), |
4733 | alloc<SyncThreads>(), |
4734 | alloc<SyncThreads>(), |
4735 | alloc<SyncThreads>(), |
4736 | alloc<SyncThreads>(), |
4737 | alloc<SyncThreads>(), |
4738 | Store::make(a, {1}, 0)}); |
4739 | |
4740 | StmtPtr simplified = IRSimplifier::simplify(body); |
4741 | IS_NODE_WITH_NAME(Block, simplified, block); |
4742 | ASSERT_EQ(block->nstmts(), 3); |
4743 | auto it = block->begin(); |
4744 | IS_NODE(Store, *it++); |
4745 | IS_NODE(SyncThreads, *it++); |
4746 | IS_NODE(Store, *it++); |
4747 | } |
4748 | |
4749 | { |
4750 | // Merge multiple outer SyncThreads. |
4751 | auto body = Block::make( |
4752 | {alloc<SyncThreads>(), |
4753 | alloc<SyncThreads>(), |
4754 | Store::make(a, {1}, 0), |
4755 | alloc<SyncThreads>(), |
4756 | alloc<SyncThreads>(), |
4757 | alloc<SyncThreads>(), |
4758 | alloc<SyncThreads>()}); |
4759 | |
4760 | StmtPtr simplified = IRSimplifier::simplify(body); |
4761 | IS_NODE_WITH_NAME(Block, simplified, block); |
4762 | ASSERT_EQ(block->nstmts(), 1); |
4763 | auto it = block->begin(); |
4764 | IS_NODE(Store, *it); |
4765 | } |
4766 | |
4767 | { |
4768 | // Merge multiple sections; |
4769 | auto body = Block::make( |
4770 | {Store::make(a, {0}, 1), |
4771 | alloc<SyncThreads>(), |
4772 | alloc<SyncThreads>(), |
4773 | Store::make(a, {1}, 0), |
4774 | Store::make(a, {2}, 0), |
4775 | alloc<SyncThreads>(), |
4776 | alloc<SyncThreads>(), |
4777 | alloc<SyncThreads>(), |
4778 | Store::make(a, {3}, 0)}); |
4779 | |
4780 | StmtPtr simplified = IRSimplifier::simplify(body); |
4781 | IS_NODE_WITH_NAME(Block, simplified, block); |
4782 | ASSERT_EQ(block->nstmts(), 6); |
4783 | auto it = block->begin(); |
4784 | IS_NODE(Store, *it++); |
4785 | IS_NODE(SyncThreads, *it++); |
4786 | IS_NODE(Store, *it++); |
4787 | IS_NODE(Store, *it++); |
4788 | IS_NODE(SyncThreads, *it++); |
4789 | IS_NODE(Store, *it++); |
4790 | } |
4791 | } |
4792 | |
4793 | TEST(Simplify, SimplifyRampSubBroadcast) { |
4794 | int num_lanes = 4; |
4795 | ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes); |
4796 | ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes); |
4797 | ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast); |
4798 | RampPtr newRamp = simplified.AsNode<Ramp>(); |
4799 | IS_NODE_WITH_NAME(IntImm, newRamp->base(), base); |
4800 | ASSERT_EQ(base->value(), 5); |
4801 | IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride); |
4802 | ASSERT_EQ(stride->value(), 6); |
4803 | ASSERT_EQ(newRamp->lanes(), num_lanes); |
4804 | } |
4805 | |
4806 | TEST(Simplify, SimplifyBroadcastTermExpander) { |
4807 | int num_lanes = 8; |
4808 | ExprHandle bc0 = Broadcast::make(ExprHandle(0), num_lanes); |
4809 | ExprHandle bc1 = Broadcast::make(ExprHandle(1), num_lanes); |
4810 | ExprHandle bc2 = Broadcast::make(ExprHandle(2), num_lanes); |
4811 | // NB: We need a term in the middle which isn't simplified to trigger the |
4812 | // relevant path in TermExpander::mutate. The two bc1 terms are brought |
4813 | // together and simplified to 2 * bc1, which then needs to make 2 multi-lane. |
4814 | ExprHandle simplified = IRSimplifier::simplify(bc1 + (bc0 / bc2) + bc1); |
4815 | BufHandle buf("buf" , {num_lanes}, kInt); |
4816 | // The result isn't fully simplified currently and thus would be brittle to |
4817 | // match. Observe its value instead. |
4818 | auto store = Store::make(buf, {Ramp::make(0, 1, num_lanes)}, simplified); |
4819 | SimpleIREvaluator eval(store, {buf}); |
4820 | std::vector<int> output(num_lanes); |
4821 | eval(output); |
4822 | for (const auto i : c10::irange(num_lanes)) { |
4823 | ASSERT_EQ(output[i], 2); |
4824 | } |
4825 | } |
4826 | |
4827 | TEST(Simplify, CompareSelectLoopBounds) { |
4828 | constexpr int N = 8; |
4829 | BufHandle b("b" , {N}, kFloat); |
4830 | VarHandle n("n" , kInt); |
4831 | VarHandle m("m" , kInt); |
4832 | VarHandle var_N("var_N" , kInt); |
4833 | VarHandle var_M("var_M" , kInt); |
4834 | |
4835 | auto test_case_fn = [](const VarHandle& n, |
4836 | const BufHandle& b, |
4837 | const ExprHandle& start, |
4838 | const ExprHandle& stop, |
4839 | const int& cmp_val, |
4840 | const CompareSelectOperation& cmp_op, |
4841 | const std::string& check_string) { |
4842 | StmtPtr s = For::make( |
4843 | n, |
4844 | start, |
4845 | stop, |
4846 | b.store({n}, CompareSelect::make(n, cmp_val, 0.f, 1.0f, cmp_op))); |
4847 | s = IRSimplifier::simplify(s); |
4848 | std::ostringstream oss; |
4849 | oss << *s; |
4850 | std::string target_string = "# CHECK: " ; |
4851 | target_string += check_string; |
4852 | torch::jit::testing::FileCheck().run(target_string, oss.str()); |
4853 | }; |
4854 | |
4855 | auto test_case_nest_loops_fn = [](const VarHandle& n, |
4856 | const VarHandle& m, |
4857 | const BufHandle& b, |
4858 | const ExprHandle& n_start, |
4859 | const ExprHandle& n_stop, |
4860 | const ExprHandle& m_start, |
4861 | const ExprHandle& m_stop, |
4862 | const CompareSelectOperation& cmp_op, |
4863 | const std::string& check_string) { |
4864 | StmtPtr s = For::make( |
4865 | m, |
4866 | m_start, |
4867 | m_stop, |
4868 | b.store({n, m}, CompareSelect::make(n, m, 0.f, 1.0f, cmp_op))); |
4869 | StmtPtr root_s = For::make(n, n_start, n_stop, s); |
4870 | root_s = IRSimplifier::simplify(root_s); |
4871 | std::ostringstream oss; |
4872 | oss << *root_s; |
4873 | std::string target_string = "# CHECK: " ; |
4874 | target_string += check_string; |
4875 | torch::jit::testing::FileCheck().run(target_string, oss.str()); |
4876 | }; |
4877 | |
4878 | // Before: |
4879 | // for (const auto n : c10::irange(1, N)) { |
4880 | // b[n] = n < 1 ? 0.f : 1.f; |
4881 | // } |
4882 | // After: |
4883 | // for (const auto n : c10::irange(1, N)) { |
4884 | // b[n] = 1.f; |
4885 | // } |
4886 | test_case_fn(n, b, 1, N, 1, kLT, "b[n] = 1.f;" ); |
4887 | |
4888 | // Before: |
4889 | // for (const auto n : c10::irange(1, N)) { |
4890 | // b[n] = n <= 1 ? 0.f : 1.f; |
4891 | // } |
4892 | // After: |
4893 | // for (const auto n : c10::irange(1, N)) { |
4894 | // b[n] = n <= 1 ? 0.f : 1.f; |
4895 | // } |
4896 | test_case_fn(n, b, 1, N, 1, kLE, "b[n] = n<=1 ? 0.f : 1.f;" ); |
4897 | |
4898 | // Before: |
4899 | // for (const auto n : c10::irange(1, N)) { |
4900 | // b[n] = n <= 0 ? 0.f : 1.f; |
4901 | // } |
4902 | // After: |
4903 | // for (const auto n : c10::irange(1, N)) { |
4904 | // b[n] = 1.f; |
4905 | // } |
4906 | test_case_fn(n, b, 1, N, 0, kLE, "b[n] = 1.f;" ); |
4907 | |
4908 | // Before: |
4909 | // for (const auto n : c10::irange(1, N)) { |
4910 | // b[n] = n < 0 ? 0.f : 1.f; |
4911 | // } |
4912 | // After: |
4913 | // for (const auto n : c10::irange(1, N)) { |
4914 | // b[n] = 1.f; |
4915 | // } |
4916 | test_case_fn(n, b, 1, N, 0, kLT, "b[n] = 1.f;" ); |
4917 | |
4918 | // Before: |
4919 | // for (const auto n : c10::irange(1, N)) { |
4920 | // b[n] = n < 8 ? 0.f : 1.f; |
4921 | // } |
4922 | // After: |
4923 | // for (const auto n : c10::irange(1, N)) { |
4924 | // b[n] = 0.f; |
4925 | // } |
4926 | test_case_fn(n, b, 1, N, N, kLT, "b[n] = 0.f;" ); |
4927 | |
4928 | // Before: |
4929 | // for (const auto n : c10::irange(1, N)) { |
4930 | // b[n] = n <= 7 ? 0.f : 1.f; |
4931 | // } |
4932 | // After: |
4933 | // for (const auto n : c10::irange(1, N)) { |
4934 | // b[n] = 0.f; |
4935 | // } |
4936 | test_case_fn(n, b, 1, N, N - 1, kLE, "b[n] = 0.f;" ); |
4937 | |
4938 | // Before: |
4939 | // for (const auto n : c10::irange(1, N)) { |
4940 | // b[n] = n <= 8 ? 0.f : 1.f; |
4941 | // } |
4942 | // After: |
4943 | // for (const auto n : c10::irange(1, N)) { |
4944 | // b[n] = 0.f; |
4945 | // } |
4946 | test_case_fn(n, b, 1, N, N, kLE, "b[n] = 0.f;" ); |
4947 | |
4948 | // Before: |
4949 | // for (const auto n : c10::irange(1, N)) { |
4950 | // b[n] = n < 7 ? 0.f : 1.f; |
4951 | // } |
4952 | // After: |
4953 | // for (const auto n : c10::irange(1, N)) { |
4954 | // b[n] = n < 7 ? 0.f : 1.f; |
4955 | // } |
4956 | test_case_fn(n, b, 1, N, N - 1, kLT, "b[n] = n<7 ? 0.f : 1.f;" ); |
4957 | |
4958 | // Before: |
4959 | // for (const auto n : c10::irange(1, N)) { |
4960 | // b[n] = n > 0 ? 0.f : 1.f; |
4961 | // } |
4962 | // After: |
4963 | // for (const auto n : c10::irange(1, N)) { |
4964 | // b[n] = 0.f; |
4965 | // } |
4966 | test_case_fn(n, b, 1, N, 0, kGT, "b[n] = 0.f;" ); |
4967 | |
4968 | // Before: |
4969 | // for (const auto n : c10::irange(1, N)) { |
4970 | // b[n] = n > 1 ? 0.f : 1.f; |
4971 | // } |
4972 | // After: |
4973 | // for (const auto n : c10::irange(1, N)) { |
4974 | // b[n] = n > 1 ? 0.f : 1.f; |
4975 | // } |
4976 | test_case_fn(n, b, 1, N, 1, kGT, "b[n] = n>1 ? 0.f : 1.f;" ); |
4977 | |
4978 | // Before: |
4979 | // for (const auto n : c10::irange(1, N)) { |
4980 | // b[n] = n >= 1 ? 0.f : 1.f; |
4981 | // } |
4982 | // After: |
4983 | // for (const auto n : c10::irange(1, N)) { |
4984 | // b[n] = 0.f; |
4985 | // } |
4986 | test_case_fn(n, b, 1, N, 1, kGE, "b[n] = 0.f;" ); |
4987 | |
4988 | // Before: |
4989 | // for (const auto n : c10::irange(1, N)) { |
4990 | // b[n] = n > 7 ? 0.f : 1.f; |
4991 | // } |
4992 | // After: |
4993 | // for (const auto n : c10::irange(1, N)) { |
4994 | // b[n] = 1.f; |
4995 | // } |
4996 | test_case_fn(n, b, 1, N, N - 1, kGT, "b[n] = 1.f;" ); |
4997 | |
4998 | // Before: |
4999 | // for (const auto n : c10::irange(1, N)) { |
5000 | // b[n] = n >= 7 ? 0.f : 1.f; |
5001 | // } |
5002 | // After: |
5003 | // for (const auto n : c10::irange(1, N)) { |
5004 | // b[n] = n >= 7 ? 0.f : 1.f; |
5005 | // } |
5006 | test_case_fn(n, b, 1, N, N - 1, kGE, "b[n] = n>=7 ? 0.f : 1.f;" ); |
5007 | |
5008 | // Before: |
5009 | // for (const auto n : c10::irange(1, N)) { |
5010 | // b[n] = n > 5 ? 0.f : 1.f; |
5011 | // } |
5012 | // After: |
5013 | // for (const auto n : c10::irange(1, N)) { |
5014 | // b[n] = n > 5 ? 0.f : 1.f; |
5015 | // } |
5016 | test_case_fn(n, b, 1, N, 5, kGT, "b[n] = n>5 ? 0.f : 1.f;" ); |
5017 | |
5018 | // Before: |
5019 | // for (const auto n : c10::irange(1, N)) { |
5020 | // b[n] = n >= 5 ? 0.f : 1.f; |
5021 | // } |
5022 | // After: |
5023 | // for (const auto n : c10::irange(1, N)) { |
5024 | // b[n] = n >= 5 ? 0.f : 1.f; |
5025 | // } |
5026 | test_case_fn(n, b, 1, N, 5, kGE, "b[n] = n>=5 ? 0.f : 1.f;" ); |
5027 | |
5028 | // Before: |
5029 | // for (const auto n : c10::irange(1, N)) { |
5030 | // b[n] = n > 8 ? 0.f : 1.f; |
5031 | // } |
5032 | // After: |
5033 | // for (const auto n : c10::irange(1, N)) { |
5034 | // b[n] = 1.f; |
5035 | // } |
5036 | test_case_fn(n, b, 1, N, N, kGT, "b[n] = 1.f;" ); |
5037 | |
5038 | // Before: |
5039 | // for (const auto n : c10::irange(1, N)) { |
5040 | // b[n] = n >= 8 ? 0.f : 1.f; |
5041 | // } |
5042 | // After: |
5043 | // for (const auto n : c10::irange(1, N)) { |
5044 | // b[n] = 1.f; |
5045 | // } |
5046 | test_case_fn(n, b, 1, N, N, kGE, "b[n] = 1.f;" ); |
5047 | |
5048 | // Before: |
5049 | // for (const auto n : c10::irange(1, 2)) { |
5050 | // b[n] = n == 1 ? 0.f : 1.f; |
5051 | // } |
5052 | // After: |
5053 | // for (const auto n : c10::irange(1, 2)) { |
5054 | // b[1] = 0.f; |
5055 | // } |
5056 | test_case_fn(n, b, 1, 2, 1, kEQ, "b[1] = 0.f;" ); |
5057 | |
5058 | // Before: |
5059 | // for (const auto n : c10::irange(1, N)) { |
5060 | // b[n] = n == 1 ? 0.f : 1.f; |
5061 | // } |
5062 | // After: |
5063 | // for (const auto n : c10::irange(1, N)) { |
5064 | // b[n] = n == 1 ? 0.f : 1.f; |
5065 | // } |
5066 | test_case_fn(n, b, 1, N, 1, kEQ, "b[n] = n==1 ? 0.f : 1.f;" ); |
5067 | |
5068 | // Before: |
5069 | // for (const auto n : c10::irange(1, N)) { |
5070 | // b[n] = n == 0 ? 0.f : 1.f; |
5071 | // } |
5072 | // After: |
5073 | // for (const auto n : c10::irange(1, N)) { |
5074 | // b[n] = 1.f; |
5075 | // } |
5076 | test_case_fn(n, b, 1, N, 0, kEQ, "b[n] = 1.f;" ); |
5077 | |
5078 | // Before: |
5079 | // for (const auto n : c10::irange(1, N)) { |
5080 | // b[n] = n == 7 ? 0.f : 1.f; |
5081 | // } |
5082 | // After: |
5083 | // for (const auto n : c10::irange(1, N)) { |
5084 | // b[n] = n == 7 ? 0.f : 1.f; |
5085 | // } |
5086 | test_case_fn(n, b, 1, N, N - 1, kEQ, "b[n] = n==7 ? 0.f : 1.f;" ); |
5087 | |
5088 | // Before: |
5089 | // for (const auto n : c10::irange(1, N)) { |
5090 | // b[n] = n == 8 ? 0.f : 1.f; |
5091 | // } |
5092 | // After: |
5093 | // for (const auto n : c10::irange(1, N)) { |
5094 | // b[n] = 1.f; |
5095 | // } |
5096 | test_case_fn(n, b, 1, N, N, kEQ, "b[n] = 1.f;" ); |
5097 | |
5098 | // Before: |
5099 | // for (const auto n : c10::irange(1, N)) { |
5100 | // b[n] = n != 1 ? 0.f : 1.f; |
5101 | // } |
5102 | // After: |
5103 | // for (const auto n : c10::irange(1, N)) { |
5104 | // b[n] = n != 1 ? 0.f : 1.f; |
5105 | // } |
5106 | test_case_fn(n, b, 1, N, 1, kNE, "b[n] = n!=1 ? 0.f : 1.f;" ); |
5107 | |
5108 | // Before: |
5109 | // for (const auto n : c10::irange(1, N)) { |
5110 | // b[n] = n != 7 ? 0.f : 1.f; |
5111 | // } |
5112 | // After: |
5113 | // for (const auto n : c10::irange(1, N)) { |
5114 | // b[n] = n != 7 ? 0.f : 1.f; |
5115 | // } |
5116 | test_case_fn(n, b, 1, N, N - 1, kNE, "b[n] = n!=7 ? 0.f : 1.f;" ); |
5117 | |
5118 | // Before: |
5119 | // for (const auto n : c10::irange(1, N)) { |
5120 | // b[n] = n != 5 ? 0.f : 1.f; |
5121 | // } |
5122 | // After: |
5123 | // for (const auto n : c10::irange(1, N)) { |
5124 | // b[n] = n != 5 ? 0.f : 1.f; |
5125 | // } |
5126 | test_case_fn(n, b, 1, N, 5, kNE, "b[n] = n!=5 ? 0.f : 1.f;" ); |
5127 | |
5128 | // Before: |
5129 | // for (const auto n : c10::irange(1, N)) { |
5130 | // b[n] = n != 0 ? 0.f : 1.f; |
5131 | // } |
5132 | // After: |
5133 | // for (const auto n : c10::irange(1, N)) { |
5134 | // b[n] = 0.f; |
5135 | // } |
5136 | test_case_fn(n, b, 1, N, 0, kNE, "b[n] = 0.f;" ); |
5137 | |
5138 | // Before: |
5139 | // for (const auto n : c10::irange(1, N)) { |
5140 | // b[n] = n != 8 ? 0.f : 1.f; |
5141 | // } |
5142 | // After: |
5143 | // for (const auto n : c10::irange(1, N)) { |
5144 | // b[n] = 0.f; |
5145 | // } |
5146 | test_case_fn(n, b, 1, N, N, kNE, "b[n] = 0.f;" ); |
5147 | |
5148 | // Before: |
5149 | // for (const auto n : c10::irange(10, 20)) { |
5150 | // for(const auto m : c10::irange(30, 40)) { |
5151 | // b[n, m] = (n != m) ? 0.f : 1.f; |
5152 | // } |
5153 | // } |
5154 | // After: |
5155 | // for (const auto n : c10::irange(10, 20)) { |
5156 | // for(const auto m : c10::irange(30, 40)) { |
5157 | // b[n, m] = 0.f; |
5158 | // } |
5159 | // } |
5160 | test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kNE, "b[n, m] = 0.f;" ); |
5161 | test_case_nest_loops_fn( |
5162 | n, |
5163 | m, |
5164 | b, |
5165 | var_N + 10, |
5166 | var_N + 20, |
5167 | var_N + 30, |
5168 | var_N + 40, |
5169 | kNE, |
5170 | "b[n, m] = 0.f;" ); |
5171 | test_case_nest_loops_fn( |
5172 | n, |
5173 | m, |
5174 | b, |
5175 | var_N + 10, |
5176 | var_N + 20, |
5177 | var_M + 30, |
5178 | var_M + 40, |
5179 | kNE, |
5180 | "b[n, m] = n!=m ? 0.f : 1.f;" ); |
5181 | |
5182 | // Before: |
5183 | // for (const auto n : c10::irange(30, 40)) { |
5184 | // for(const auto m : c10::irange(10, 20)) { |
5185 | // b[n, m] = (n != m) ? 0.f : 1.f; |
5186 | // } |
5187 | // } |
5188 | // After: |
5189 | // for (const auto n : c10::irange(30, 40)) { |
5190 | // for(const auto m : c10::irange(10, 20)) { |
5191 | // b[n, m] = 0.f; |
5192 | // } |
5193 | // } |
5194 | test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kNE, "b[n, m] = 0.f;" ); |
5195 | test_case_nest_loops_fn( |
5196 | n, |
5197 | m, |
5198 | b, |
5199 | var_N + 30, |
5200 | var_N + 40, |
5201 | var_N + 10, |
5202 | var_N + 20, |
5203 | kNE, |
5204 | "b[n, m] = 0.f;" ); |
5205 | test_case_nest_loops_fn( |
5206 | n, |
5207 | m, |
5208 | b, |
5209 | var_N + 30, |
5210 | var_N + 40, |
5211 | var_M + 10, |
5212 | var_M + 20, |
5213 | kNE, |
5214 | "b[n, m] = n!=m ? 0.f : 1.f;" ); |
5215 | |
5216 | // Before: |
5217 | // for (const auto n : c10::irange(30, 40)) { |
5218 | // for(const auto m : c10::irange(10, 31)) { |
5219 | // b[n, m] = (n != m) ? 0.f : 1.f; |
5220 | // } |
5221 | // } |
5222 | // After: |
5223 | // for (const auto n : c10::irange(30, 40)) { |
5224 | // for(const auto m : c10::irange(10, 31)) { |
5225 | // b[n, m] = (n != m) ? 0.f : 1.f; |
5226 | // } |
5227 | // } |
5228 | test_case_nest_loops_fn( |
5229 | n, m, b, 30, 40, 10, 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;" ); |
5230 | test_case_nest_loops_fn( |
5231 | n, |
5232 | m, |
5233 | b, |
5234 | var_N + 30, |
5235 | var_N + 40, |
5236 | var_N + 10, |
5237 | var_N + 31, |
5238 | kNE, |
5239 | "b[n, m] = n!=m ? 0.f : 1.f;" ); |
5240 | test_case_nest_loops_fn( |
5241 | n, |
5242 | m, |
5243 | b, |
5244 | var_N + 30, |
5245 | var_N + 40, |
5246 | var_M + 10, |
5247 | var_M + 31, |
5248 | kNE, |
5249 | "b[n, m] = n!=m ? 0.f : 1.f;" ); |
5250 | |
5251 | // Before: |
5252 | // for (const auto n : c10::irange(10, 31)) { |
5253 | // for(const auto m : c10::irange(30, 40)) { |
5254 | // b[n, m] = (n != m) ? 0.f : 1.f; |
5255 | // } |
5256 | // } |
5257 | // After: |
5258 | // for (const auto n : c10::irange(10, 31)) { |
5259 | // for(const auto m : c10::irange(30, 40)) { |
5260 | // b[n, m] = (n != m) ? 0.f : 1.f; |
5261 | // } |
5262 | // } |
5263 | test_case_nest_loops_fn( |
5264 | n, m, b, 10, 31, 30, 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;" ); |
5265 | test_case_nest_loops_fn( |
5266 | n, |
5267 | m, |
5268 | b, |
5269 | var_N + 10, |
5270 | var_N + 31, |
5271 | var_N + 30, |
5272 | var_N + 40, |
5273 | kNE, |
5274 | "b[n, m] = n!=m ? 0.f : 1.f;" ); |
5275 | test_case_nest_loops_fn( |
5276 | n, |
5277 | m, |
5278 | b, |
5279 | var_N + 10, |
5280 | var_N + 31, |
5281 | var_M + 30, |
5282 | var_M + 40, |
5283 | kNE, |
5284 | "b[n, m] = n!=m ? 0.f : 1.f;" ); |
5285 | |
5286 | // Before: |
5287 | // for (const auto n : c10::irange(10, 20)) { |
5288 | // for(const auto m : c10::irange(30, 40)) { |
5289 | // b[n, m] = (n < m) ? 0.f : 1.f; |
5290 | // } |
5291 | // } |
5292 | // After: |
5293 | // for (const auto n : c10::irange(10, 20)) { |
5294 | // for(const auto m : c10::irange(30, 40)) { |
5295 | // b[n, m] = 0.f; |
5296 | // } |
5297 | // } |
5298 | test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kLT, "b[n, m] = 0.f;" ); |
5299 | test_case_nest_loops_fn( |
5300 | n, |
5301 | m, |
5302 | b, |
5303 | var_N + 10, |
5304 | var_N + 20, |
5305 | var_N + 30, |
5306 | var_N + 40, |
5307 | kLT, |
5308 | "b[n, m] = 0.f;" ); |
5309 | test_case_nest_loops_fn( |
5310 | n, |
5311 | m, |
5312 | b, |
5313 | var_N + 10, |
5314 | var_N + 20, |
5315 | var_M + 30, |
5316 | var_M + 40, |
5317 | kLT, |
5318 | "b[n, m] = n<m ? 0.f : 1.f;" ); |
5319 | |
5320 | // Before: |
5321 | // for (const auto n : c10::irange(30, 40)) { |
5322 | // for(const auto m : c10::irange(10, 31)) { |
5323 | // b[n, m] = (n < m) ? 0.f : 1.f; |
5324 | // } |
5325 | // } |
5326 | // After: |
5327 | // for (const auto n : c10::irange(30, 40)) { |
5328 | // for(const auto m : c10::irange(10, 31)) { |
5329 | // b[n, m] = 1.f; |
5330 | // } |
5331 | // } |
5332 | test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kLT, "b[n, m] = 1.f;" ); |
5333 | test_case_nest_loops_fn( |
5334 | n, |
5335 | m, |
5336 | b, |
5337 | var_N + 30, |
5338 | var_N + 40, |
5339 | var_N + 10, |
5340 | var_N + 31, |
5341 | kLT, |
5342 | "b[n, m] = 1.f;" ); |
5343 | test_case_nest_loops_fn( |
5344 | n, |
5345 | m, |
5346 | b, |
5347 | var_N + 30, |
5348 | var_N + 40, |
5349 | var_M + 10, |
5350 | var_M + 31, |
5351 | kLT, |
5352 | "b[n, m] = n<m ? 0.f : 1.f;" ); |
5353 | |
5354 | // Before: |
5355 | // for (const auto n : c10::irange(30, 40)) { |
5356 | // for(const auto m : c10::irange(10, 20)) { |
5357 | // b[n, m] = (n > m) ? 0.f : 1.f; |
5358 | // } |
5359 | // } |
5360 | // After: |
5361 | // for (const auto n : c10::irange(30, 40)) { |
5362 | // for(const auto m : c10::irange(10, 20)) { |
5363 | // b[n, m] = 0.f; |
5364 | // } |
5365 | // } |
5366 | test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kGT, "b[n, m] = 0.f;" ); |
5367 | test_case_nest_loops_fn( |
5368 | n, |
5369 | m, |
5370 | b, |
5371 | var_N + 30, |
5372 | var_N + 40, |
5373 | var_N + 10, |
5374 | var_N + 20, |
5375 | kGT, |
5376 | "b[n, m] = 0.f;" ); |
5377 | test_case_nest_loops_fn( |
5378 | n, |
5379 | m, |
5380 | b, |
5381 | var_N + 30, |
5382 | var_N + 40, |
5383 | var_M + 10, |
5384 | var_M + 20, |
5385 | kGT, |
5386 | "b[n, m] = n>m ? 0.f : 1.f;" ); |
5387 | |
5388 | // Before: |
5389 | // for (const auto n : c10::irange(10, 31)) { |
5390 | // for(const auto m : c10::irange(30, 40)) { |
5391 | // b[n, m] = (n > m) ? 0.f : 1.f; |
5392 | // } |
5393 | // } |
5394 | // After: |
5395 | // for (const auto n : c10::irange(10, 31)) { |
5396 | // for(const auto m : c10::irange(30, 40)) { |
5397 | // b[n, m] = 1.f; |
5398 | // } |
5399 | // } |
5400 | test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kGT, "b[n, m] = 1.f;" ); |
5401 | test_case_nest_loops_fn( |
5402 | n, |
5403 | m, |
5404 | b, |
5405 | var_N + 10, |
5406 | var_N + 31, |
5407 | var_N + 30, |
5408 | var_N + 40, |
5409 | kGT, |
5410 | "b[n, m] = 1.f;" ); |
5411 | test_case_nest_loops_fn( |
5412 | n, |
5413 | m, |
5414 | b, |
5415 | var_N + 10, |
5416 | var_N + 31, |
5417 | var_M + 30, |
5418 | var_M + 40, |
5419 | kGT, |
5420 | "b[n, m] = n>m ? 0.f : 1.f;" ); |
5421 | |
5422 | // Before: |
5423 | // for (const auto n : c10::irange(30, 40)) { |
5424 | // for(const auto m : c10::irange(10, 31)) { |
5425 | // b[n, m] = (n >= m) ? 0.f : 1.f; |
5426 | // } |
5427 | // } |
5428 | // After: |
5429 | // for (const auto n : c10::irange(30, 40)) { |
5430 | // for(const auto m : c10::irange(10, 31)) { |
5431 | // b[n, m] = 0.f; |
5432 | // } |
5433 | // } |
5434 | test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kGE, "b[n, m] = 0.f;" ); |
5435 | test_case_nest_loops_fn( |
5436 | n, |
5437 | m, |
5438 | b, |
5439 | var_N + 30, |
5440 | var_N + 40, |
5441 | var_N + 10, |
5442 | var_N + 31, |
5443 | kGE, |
5444 | "b[n, m] = 0.f;" ); |
5445 | test_case_nest_loops_fn( |
5446 | n, |
5447 | m, |
5448 | b, |
5449 | var_N + 30, |
5450 | var_N + 40, |
5451 | var_M + 10, |
5452 | var_M + 31, |
5453 | kGE, |
5454 | "b[n, m] = n>=m ? 0.f : 1.f;" ); |
5455 | |
5456 | // Before: |
5457 | // for (const auto n : c10::irange(10, 20)) { |
5458 | // for(const auto m : c10::irange(30, 40)) { |
5459 | // b[n, m] = (n >= m) ? 0.f : 1.f; |
5460 | // } |
5461 | // } |
5462 | // After: |
5463 | // for (const auto n : c10::irange(10, 20)) { |
5464 | // for(const auto m : c10::irange(30, 40)) { |
5465 | // b[n, m] = 1.f; |
5466 | // } |
5467 | // } |
5468 | test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kGE, "b[n, m] = 1.f;" ); |
5469 | test_case_nest_loops_fn( |
5470 | n, |
5471 | m, |
5472 | b, |
5473 | var_N + 10, |
5474 | var_N + 20, |
5475 | var_N + 30, |
5476 | var_N + 40, |
5477 | kGE, |
5478 | "b[n, m] = 1.f;" ); |
5479 | test_case_nest_loops_fn( |
5480 | n, |
5481 | m, |
5482 | b, |
5483 | var_N + 10, |
5484 | var_N + 20, |
5485 | var_M + 30, |
5486 | var_M + 40, |
5487 | kGE, |
5488 | "b[n, m] = n>=m ? 0.f : 1.f;" ); |
5489 | |
5490 | // Before: |
5491 | // for (const auto n : c10::irange(10, 31)) { |
5492 | // for(const auto m : c10::irange(30, 40)) { |
5493 | // b[n, m] = (n <= m) ? 0.f : 1.f; |
5494 | // } |
5495 | // } |
5496 | // After: |
5497 | // for (const auto n : c10::irange(10, 31)) { |
5498 | // for(const auto m : c10::irange(30, 40)) { |
5499 | // b[n, m] = 0.f; |
5500 | // } |
5501 | // } |
5502 | test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kLE, "b[n, m] = 0.f;" ); |
5503 | test_case_nest_loops_fn( |
5504 | n, |
5505 | m, |
5506 | b, |
5507 | var_N + 10, |
5508 | var_N + 31, |
5509 | var_N + 30, |
5510 | var_N + 40, |
5511 | kLE, |
5512 | "b[n, m] = 0.f;" ); |
5513 | test_case_nest_loops_fn( |
5514 | n, |
5515 | m, |
5516 | b, |
5517 | var_N + 10, |
5518 | var_N + 31, |
5519 | var_M + 30, |
5520 | var_M + 40, |
5521 | kLE, |
5522 | "b[n, m] = n<=m ? 0.f : 1.f;" ); |
5523 | |
5524 | // Before: |
5525 | // for (const auto n : c10::irange(30, 40)) { |
5526 | // for(const auto m : c10::irange(10, 20)) { |
5527 | // b[n, m] = (n <= m) ? 0.f : 1.f; |
5528 | // } |
5529 | // } |
5530 | // After: |
5531 | // for (const auto n : c10::irange(30, 40)) { |
5532 | // for(const auto m : c10::irange(10, 20)) { |
5533 | // b[n, m] = 1.f; |
5534 | // } |
5535 | // } |
5536 | test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kLE, "b[n, m] = 1.f;" ); |
5537 | test_case_nest_loops_fn( |
5538 | n, |
5539 | m, |
5540 | b, |
5541 | var_N + 30, |
5542 | var_N + 40, |
5543 | var_N + 10, |
5544 | var_N + 20, |
5545 | kLE, |
5546 | "b[n, m] = 1.f;" ); |
5547 | test_case_nest_loops_fn( |
5548 | n, |
5549 | m, |
5550 | b, |
5551 | var_N + 30, |
5552 | var_N + 40, |
5553 | var_M + 10, |
5554 | var_M + 20, |
5555 | kLE, |
5556 | "b[n, m] = n<=m ? 0.f : 1.f;" ); |
5557 | } |
5558 | |
5559 | TEST(Simplify, CompareSelectCondAlwaysInLoopBounds) { |
5560 | // Before: |
5561 | // for (const auto n : c10::irange(1, N)) { |
5562 | // b[n] = n < 1 ? 0.f : 1.f; |
5563 | // } |
5564 | // After: |
5565 | // for (const auto n : c10::irange(1, N)) { |
5566 | // b[n] = 1.f; |
5567 | // } |
5568 | constexpr int N = 8; |
5569 | BufHandle b("b" , {N}, kFloat); |
5570 | VarHandle n("n" , kInt); |
5571 | StmtPtr s = For::make( |
5572 | n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT))); |
5573 | s = IRSimplifier::simplify(s); |
5574 | std::ostringstream oss; |
5575 | oss << *s; |
5576 | torch::jit::testing::FileCheck().run( |
5577 | R"IR( |
5578 | # CHECK: b[n] = 1.f; |
5579 | )IR" , |
5580 | oss.str()); |
5581 | } |
5582 | |
5583 | TEST(Simplify, IfThenCondAlwaysInLoopBounds) { |
5584 | // Before: |
5585 | // for (const auto n : c10::irange(1, N)) { |
5586 | // b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f); |
5587 | // } |
5588 | // After: |
5589 | // for (const auto n : c10::irange(1, N)) { |
5590 | // b[n] = 1.f; |
5591 | // } |
5592 | constexpr int N = 8; |
5593 | BufHandle b("b" , {N}, kFloat); |
5594 | VarHandle n("n" , kInt); |
5595 | StmtPtr s = |
5596 | For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f))); |
5597 | s = IRSimplifier::simplify(s); |
5598 | std::ostringstream oss; |
5599 | oss << *s; |
5600 | torch::jit::testing::FileCheck().run( |
5601 | R"IR( |
5602 | # CHECK: b[n] = 1.f; |
5603 | )IR" , |
5604 | oss.str()); |
5605 | } |
5606 | |
5607 | TEST(Simplify, MultiClauseCondAlwaysInLoopBounds) { |
5608 | // This test mimics the unpadded region of a conv2d. We want to remove any |
5609 | // conditional that is provably satisfied (or unsatisfied) by the entire loop |
5610 | // range. |
5611 | // Before: |
5612 | // for (const auto i : c10::irange(1, 7)) { |
5613 | // for (const auto j : c10::irange(1, 7)) { |
5614 | // b[i, j] = IfThenElse( |
5615 | // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, 1.f); |
5616 | // After: |
5617 | // for (const auto i : c10::irange(1, 7)) { |
5618 | // for (const auto j : c10::irange(1, 7)) { |
5619 | // b[i, j] = 1.f; |
5620 | constexpr int N = 8; |
5621 | BufHandle b("b" , {N, N}, kFloat); |
5622 | VarHandle i("i" , kInt); |
5623 | VarHandle j("j" , kInt); |
5624 | auto csel = CompareSelect::make(i, 1, kLT); |
5625 | csel = CompareSelect::make(j, 1, 1, csel, kLT); |
5626 | csel = CompareSelect::make(i, N - 1, 1, csel, kGE); |
5627 | csel = CompareSelect::make(j, N - 1, 1, csel, kGE); |
5628 | StmtPtr s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f)); |
5629 | s = For::make(j, 1, N - 1, s); |
5630 | s = For::make(i, 1, N - 1, s); |
5631 | s = IRSimplifier::simplify(s); |
5632 | std::ostringstream oss; |
5633 | oss << *s; |
5634 | torch::jit::testing::FileCheck().run( |
5635 | R"IR( |
5636 | # CHECK: b[i, j] = 1.f; |
5637 | )IR" , |
5638 | oss.str()); |
5639 | } |
5640 | |
5641 | TEST(Simplify, DISABLED_SimplifyLoopBounds) { |
5642 | // This test mimics the padded region of a conv2d. We want to adjust the |
5643 | // loop bounds such that the condition will be always met. Note that this |
5644 | // could be solved by peeling, and applying the range-based conditional |
5645 | // simplification in the previous tests. |
5646 | // Before: |
5647 | // for (const auto i : c10::irange(3)) { |
5648 | // for (const auto j : c10::irange(3)) { |
5649 | // b[i, j] = (b[i, j]) + (IfThenElse( |
5650 | // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j])); |
5651 | // After: |
5652 | // for (const auto i : c10::irange(1, 3)) { |
5653 | // for (const auto j : c10::irange(1, 3)) { |
5654 | // b[i, j] = (b[i, j]) + 1.f; |
5655 | constexpr int N = 8; |
5656 | constexpr int K = 3; |
5657 | BufHandle a("a" , {N, N}, kFloat); |
5658 | BufHandle b("b" , {N, N}, kFloat); |
5659 | VarHandle i("i" , kInt); |
5660 | VarHandle j("j" , kInt); |
5661 | auto csel = CompareSelect::make(i, 1, kLT); |
5662 | csel = CompareSelect::make(j, 1, 1, csel, kLT); |
5663 | csel = CompareSelect::make(i, N - 1, 1, csel, kGE); |
5664 | csel = CompareSelect::make(j, N - 1, 1, csel, kGE); |
5665 | StmtPtr s = b.store( |
5666 | {i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j}))); |
5667 | s = For::make(j, 0, K, s); |
5668 | s = For::make(i, 0, K, s); |
5669 | s = IRSimplifier::simplify(s); |
5670 | std::ostringstream oss; |
5671 | oss << *s; |
5672 | torch::jit::testing::FileCheck().run( |
5673 | R"IR( |
5674 | # CHECK: for (const auto i : c10::irange(1, 3)) { |
5675 | # CHECK: for (const auto j : c10::irange(1, 3)) { |
5676 | # CHECK-NOT: IfThenElse |
5677 | )IR" , |
5678 | oss.str()); |
5679 | } |
5680 | |
5681 | } // namespace jit |
5682 | } // namespace torch |
5683 | |