1#include <gtest/gtest.h>
2#include "test/cpp/tensorexpr/test_base.h"
3
4#include "test/cpp/tensorexpr/test_utils.h"
5#include "torch/csrc/jit/tensorexpr/ir_simplifier.h"
6#include "torch/csrc/jit/tensorexpr/registerizer.h"
7
8#include <iostream>
9
10namespace torch {
11namespace jit {
12using namespace torch::jit::tensorexpr;
13
14// Can replace a simple scalar access with a local variable.
15TEST(Registerizer, RegisterizerSimple) {
16 BufHandle a("A", {1}, kInt);
17 VarHandle x("x", kInt);
18 StmtPtr stmt = Block::make(
19 {Store::make(a, {0}, 0),
20 For::make(
21 x,
22 0,
23 10,
24 Block::make(
25 {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});
26
27 /*
28 * A[0] = 0;
29 * for (int x = 0; x < 10; x++) {
30 * A[0] = (A[0]) + x;
31 * }
32 */
33
34 stmt = registerize(stmt);
35
36 /*
37 * int A_1 = 0;
38 * for (int x = 0; x < 10; x++) {
39 * A_1 = x + A_1;
40 * }
41 * A[0] = A_1;
42 */
43
44 std::ostringstream oss;
45 oss << *stmt;
46
47 const std::string& verification_pattern =
48 R"IR(
49# CHECK: int A_1 = 0;
50# CHECK: for (int x = 0; x < 10; x++)
51# CHECK-NOT: A[
52# CHECK: A_1 =
53# CHECK: A[0] = A_1;)IR";
54
55 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
56}
57
58// Won't do replacement of a loop access.
59TEST(Registerizer, RegisterizerLoop) {
60 BufHandle a("A", {10}, kInt);
61 VarHandle x("x", kInt);
62 StmtPtr stmt = Block::make(
63 {Store::make(a, {0}, 0),
64 For::make(
65 x,
66 0,
67 10,
68 Block::make(
69 {Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
70
71 /*
72 * A[0] = 0;
73 * for (int x = 0; x < 10; x++) {
74 * A[x] = (A[x]) + x;
75 * }
76 */
77
78 // No change.
79 stmt = registerize(stmt);
80
81 /*
82 * A[0] = 0;
83 * for (int x = 0; x < 10; x++) {
84 * A[x] = (A[x]) + x;
85 * }
86 */
87
88 std::ostringstream oss;
89 oss << *stmt;
90
91 const std::string& verification_pattern =
92 R"IR(
93# CHECK-NOT: int
94# CHECK: A[0] = 0;
95# CHECK: for (int x = 0; x < 10; x++)
96# CHECK-NOT: A_
97# CHECK: A[x] =
98# CHECK-NOT: A[0] = A_1;)IR";
99
100 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
101}
102
103// Won't replace even if the load is a fixed scalar, since the store could
104// invalidate it.
105TEST(Registerizer, RegisterizerLoopFixedLoad) {
106 BufHandle a("A", {1}, kInt);
107 VarHandle x("x", kInt);
108 StmtPtr stmt = Block::make(
109 {Store::make(a, {0}, 0),
110 For::make(
111 x,
112 0,
113 10,
114 Block::make(
115 {Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))});
116
117 /*
118 * A[0] = 0;
119 * for (int x = 0; x < 10; x++) {
120 * A[x] = (A[0]) + x;
121 * }
122 */
123
124 // No change.
125 stmt = registerize(stmt);
126
127 /*
128 * A[0] = 0;
129 * for (int x = 0; x < 10; x++) {
130 * A[x] = (A[0]) + x;
131 * }
132 */
133
134 std::ostringstream oss;
135 oss << *stmt;
136
137 const std::string& verification_pattern =
138 R"IR(
139# CHECK-NOT: int
140# CHECK: A[0] = 0;
141# CHECK: for (int x = 0; x < 10; x++)
142# CHECK-NOT: A_
143# CHECK: A[x] =
144# CHECK-NOT: A[0] = A_1;)IR";
145
146 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
147}
148
149// We can registerize accesses that occur entirely within inner scopes, even if
150// they depend on the loop var.
151TEST(Registerizer, RegisterizerLoopInternal) {
152 BufHandle a("A", {1}, kInt);
153 VarHandle x("x", kInt);
154 StmtPtr stmt = Block::make({For::make(
155 x,
156 0,
157 10,
158 Block::make(
159 {Store::make(a, {x}, Add::make(Load::make(a, {x}), x)),
160 Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
161
162 /*
163 * for (int x = 0; x < 10; x++) {
164 * A[x] = (A[x]) + x;
165 * A[x] = (A[x]) + x;
166 * }
167 */
168
169 stmt = registerize(stmt);
170
171 // TODO: the order of terms in addition changes and in general depends on
172 // some hash value. This results in unpredictable swaps of the operands from
173 // random changes, which is not great. Ideally, we should ensure some
174 // specific order (ideally, the original one).
175 /*
176 * for (int x = 0; x < 10; x++) {
177 * int A_1 = A[x];
178 * A_1 = x + A_1;
179 * A_1 = x + A_1;
180 * A[x] = A_1;
181 * }
182 */
183
184 std::ostringstream oss;
185 oss << *stmt;
186
187 const std::string& verification_pattern =
188 R"IR(
189# CHECK: for (int x = 0; x < 10; x++)
190# CHECK: int A_1 = A[x];
191# CHECK: A_1 = A_1 + x;
192# CHECK: A_1 = A_1 + x;
193# CHECK: A[x] = A_1;
194# CHECK: })IR";
195
196 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
197}
198
199// An access can be overlapped by another read in the same Expr. In this case
200// B[z] and B[y] overlap and prevent registerization of both accesses.
201TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) {
202 BufHandle a("A", {10}, kInt);
203 BufHandle b("B", {10}, kInt);
204 VarHandle x("x", kInt);
205 VarHandle y("y", kInt);
206 VarHandle z("z", kInt);
207 StmtPtr stmt = Block::make({For::make(
208 x,
209 0,
210 10,
211 Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))});
212 stmt = IRSimplifier::simplify(stmt);
213
214 /*
215 * for (int x = 0; x < 10; x++) {
216 * A[x] = (B[y]) + (B[z]);
217 * }
218 */
219
220 std::ostringstream before;
221 before << *stmt;
222
223 // No change.
224 stmt = registerize(stmt);
225
226 std::ostringstream after;
227 after << *stmt;
228
229 ASSERT_EQ(before.str(), after.str());
230}
231
232TEST(Registerizer, RegisterizerLoopInternalRepeated) {
233 BufHandle a("A", {1}, kInt);
234 VarHandle x("x", kInt);
235 StmtPtr stmt = Block::make(
236 {For::make(
237 x,
238 0,
239 10,
240 Block::make(
241 {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),
242 Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})),
243 For::make(
244 x,
245 0,
246 10,
247 Block::make(
248 {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),
249 Store::make(a, {0}, Add::make(Load::make(a, {1}), x))}))
250
251 });
252
253 /*
254 * for (int x = 0; x < 10; x++) {
255 * A[0] = x + (A[1]);
256 * A[0] = x + (A[1]);
257 * }
258 * for (int x = 0; x < 10; x++) {
259 * A[0] = x + (A[1]);
260 * A[0] = x + (A[1]);
261 * }
262 */
263
264 stmt = registerize(stmt);
265
266 /*
267 * int A_1 = A[1];
268 * int A_2 = A[0];
269 * for (int x = 0; x < 10; x++) {
270 * A_2 = A_1 + x;
271 * A_2 = A_1 + x;
272 * }
273 * for (int x = 0; x < 10; x++) {
274 * A_2 = A_1 + x;
275 * A_2 = A_1 + x;
276 * }
277 * A[0] = A_2;
278 */
279
280 std::ostringstream oss;
281 oss << *stmt;
282
283 const std::string& verification_pattern =
284 R"IR(
285# CHECK: int A_1 = A[1];
286# CHECK: int A_2 = A[0];
287# CHECK: for (int x = 0; x < 10; x++)
288# CHECK: A_2 = A_1 + x;
289# CHECK: A_2 = A_1 + x;
290# CHECK: }
291# CHECK: for (int x = 0; x < 10; x++)
292# CHECK: A_2 = A_1 + x;
293# CHECK: A_2 = A_1 + x;
294# CHECK: }
295# CHECK-NOT: A[1]
296# CHECK: A[0] = A_2;
297# CHECK-NOT: A[1]
298# CHECK: })IR";
299
300 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
301}
302
303TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) {
304 BufHandle a("A", {1}, kInt);
305 VarHandle x("x", kInt);
306 StmtPtr stmt = Block::make(
307 {For::make(
308 x,
309 0,
310 10,
311 Block::make(
312 {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)),
313 Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})),
314 For::make(
315 x,
316 0,
317 10,
318 Block::make(
319 {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)),
320 Store::make(a, {0}, Add::make(Load::make(a, {x}), x))}))
321
322 });
323 stmt = IRSimplifier::simplify(stmt);
324
325 /*
326 * for (int x = 0; x < 10; x++) {
327 * A[0] = (A[x]) + x;
328 * A[0] = (A[x]) + x;
329 * }
330 * for (int x = 0; x < 10; x++) {
331 * A[0] = (A[x]) + x;
332 * A[0] = (A[x]) + x;
333 * }
334 */
335
336 std::ostringstream before;
337 before << *stmt;
338
339 // No change.
340 stmt = registerize(stmt);
341
342 std::ostringstream after;
343 after << *stmt;
344
345 ASSERT_EQ(before.str(), after.str());
346}
347
348TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) {
349 BufHandle a("A", {1}, kInt);
350 VarHandle x("x", kInt);
351 VarHandle y("y", kInt);
352 StmtPtr stmt = IRSimplifier::simplify(Block::make(
353 {For::make(
354 x,
355 0,
356 10,
357 Block::make(
358 {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),
359 Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})),
360 For::make(
361 x,
362 0,
363 10,
364 Block::make(
365 {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),
366 Store::make(a, {0}, Add::make(x, Load::make(a, {y})))}))
367
368 }));
369
370 /*
371 * for (int x = 0; x < 10; x++) {
372 * A[0] = (A[x]) + x;
373 * A[0] = (A[x]) + x;
374 * }
375 * for (int x = 0; x < 10; x++) {
376 * A[0] = (A[x]) + x;
377 * A[0] = (A[x]) + x;
378 * }
379 */
380
381 std::ostringstream before;
382 before << *stmt;
383
384 // No change.
385 stmt = registerize(stmt);
386
387 std::ostringstream after;
388 after << *stmt;
389
390 ASSERT_EQ(before.str(), after.str());
391}
392
393// Will registerize multiple accesses of different items of the same buffer.
394TEST(Registerizer, RegisterizerMultiVar) {
395 BufHandle a("A", {2}, kInt);
396 VarHandle x("x", kInt);
397 StmtPtr stmt = Block::make({
398 Store::make(a, {0}, 0),
399 Store::make(a, {1}, 0),
400 For::make(
401 x,
402 0,
403 10,
404 Block::make(
405 {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
406 Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
407 });
408
409 /*
410 * A[0] = 0;
411 * A[1] = 0;
412 * for (int x = 0; x < 10; x++) {
413 * A[0] = (A[0]) + x;
414 * A[1] = (A[1]) - x;
415 * }
416 */
417
418 stmt = registerize(stmt);
419
420 /*
421 * int A_1 = 0;
422 * int A_2 = 0;
423 * for (int x = 0; x < 10; x++) {
424 * A_2 = x + A_2;
425 * A_1 = A_1 - x;
426 * }
427 * A[1] = A_2;
428 * A[0] = A_1;
429 */
430
431 std::ostringstream oss;
432 oss << *stmt;
433
434 const std::string& verification_pattern =
435 R"IR(
436# CHECK: int A_1 = 0;
437# CHECK: int A_2 = 0;
438# CHECK: for (int x = 0; x < 10; x++)
439# CHECK-NOT: A[
440# CHECK: A_1 =
441# CHECK: A_2 =
442# CHECK: A[1] = A_2
443# CHECK: A[0] = A_1;)IR";
444
445 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
446}
447
448// Will registerize the valid accesses while skipping invalid replacements.
449TEST(Registerizer, RegisterizerVariableLoad) {
450 BufHandle a("A", {1}, kInt);
451 BufHandle b("B", {10}, kInt);
452 VarHandle x("x", kInt);
453 VarHandle x2("x", kInt);
454 StmtPtr stmt = Block::make(
455 {Store::make(a, {0}, 0),
456 For::make(x, 0, 10, Store::make(b, {x}, x)),
457 For::make(
458 x2,
459 0,
460 10,
461 Block::make({Store::make(
462 a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))});
463
464 /*
465 * A[0] = 0;
466 * for (int x = 0; x < 10; x++) {
467 * B[x] = x;
468 * }
469 * for (int x_1 = 0; x_1 < 10; x_1++) {
470 * A[0] = (A[0]) + (B[x_1]);
471 * }
472 */
473
474 stmt = registerize(stmt);
475
476 /*
477 * int A_1 = 0;
478 * for (int x = 0; x < 10; x++) {
479 * B[x] = x;
480 * }
481 * for (int x_1 = 0; x_1 < 10; x_1++) {
482 * A_1 = A_1 + (B[x_1]);
483 * }
484 * A[0] = A_1;
485 */
486
487 std::ostringstream oss;
488 oss << *stmt;
489
490 const std::string& verification_pattern =
491 R"IR(
492# CHECK: int A_1 = 0;
493# CHECK: for (int x = 0; x < 10; x++)
494# CHECK: B[x] = x
495# CHECK: for (int x_1 = 0; x_1 < 10; x_1++)
496# CHECK-NOT: A[
497# CHECK: A_1 =
498# CHECK: A[0] = A_1;)IR";
499
500 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
501}
502
503// Can registerize variable accesses so long as the variable does not change.
504TEST(Registerizer, RegisterizerSymbolicIndices) {
505 VarHandle i("i", kInt);
506 VarHandle N("N", kInt);
507 BufHandle a("A", {N}, kInt);
508 VarHandle x("x", kInt);
509 StmtPtr stmt = Block::make(
510 {Store::make(a, {i}, 0),
511 For::make(
512 x,
513 0,
514 10,
515 Block::make(
516 {Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))});
517
518 /*
519 * A[i] = 0;
520 * for (int x = 0; x < 10; x++) {
521 * A[i] = (A[i]) + x;
522 * }
523 */
524
525 stmt = registerize(stmt);
526
527 /*
528 * int A_1 = 0;
529 * for (int x = 0; x < 10; x++) {
530 * A_1 = x + A_1;
531 * }
532 * A[i] = A_1;
533 */
534
535 std::ostringstream oss;
536 oss << *stmt;
537
538 const std::string& verification_pattern =
539 R"IR(
540# CHECK: int A_1 = 0;
541# CHECK: for (int x = 0; x < 10; x++)
542# CHECK-NOT: A[
543# CHECK: A_1 =
544# CHECK: A[i] = A_1;)IR";
545
546 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
547}
548
549// Can registerize accesses dependent on multiple loop vars.
550TEST(Registerizer, RegisterizerMultiLoop) {
551 BufHandle a("A", {1}, kInt);
552 VarHandle x("x", kInt);
553 VarHandle y("y", kInt);
554 StmtPtr stmt = Block::make(
555 {Store::make(a, {0}, 0),
556 For::make(
557 x,
558 0,
559 10,
560 For::make(
561 y,
562 0,
563 10,
564 Block::make({Store::make(
565 a,
566 {0},
567 Mul::make(Add::make(Load::make(a, {0}), x), y))})))});
568
569 /*
570 * A[0] = 0;
571 * for (int x = 0; x < 10; x++) {
572 * for (int y = 0; y < 10; y++) {
573 * A[0] = x * y + (A[0]) * y;
574 * }
575 * }
576 */
577
578 stmt = registerize(stmt);
579
580 /*
581 * int A_1 = 0;
582 * for (int x = 0; x < 10; x++) {
583 * for (int y = 0; y < 10; y++) {
584 * A_1 = x * y + y * A_1;
585 * }
586 * }
587 * A[0] = A_1;
588 */
589
590 std::ostringstream oss;
591 oss << *stmt;
592
593 const std::string& verification_pattern =
594 R"IR(
595# CHECK: int A_1 = 0;
596# CHECK: for (int x = 0; x < 10; x++)
597# CHECK: for (int y = 0; y < 10; y++)
598# CHECK-NOT: A[
599# CHECK: A_1 =
600# CHECK: A[0] = A_1;)IR";
601
602 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
603}
604
605// Can registerize correctly if scalars already exist in the program.
606TEST(Registerizer, RegisterizerRepeated) {
607 BufHandle a("A", {2}, kInt);
608 VarHandle x("x", kInt);
609 StmtPtr stmt = Block::make({
610 Store::make(a, {0}, 0),
611 Store::make(a, {1}, 0),
612 For::make(
613 x,
614 0,
615 10,
616 Block::make(
617 {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
618 Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
619 });
620
621 // Registerize manually to make sure we only replace a single target.
622 {
623 registerizer::RegisterizerAnalysis analysis;
624 stmt->accept(&analysis);
625 auto candidates = analysis.getCandidates();
626 ASSERT_EQ(candidates.size(), 2);
627
628 candidates.pop_back();
629 registerizer::RegisterizerReplacer replacer(candidates);
630 stmt = stmt->accept_mutator(&replacer);
631 }
632
633 // Re-analyze and replace the second target.
634 {
635 registerizer::RegisterizerAnalysis analysis;
636 stmt->accept(&analysis);
637 auto candidates = analysis.getCandidates();
638 ASSERT_EQ(candidates.size(), 1);
639
640 registerizer::RegisterizerReplacer replacer(candidates);
641 stmt = stmt->accept_mutator(&replacer);
642 }
643
644 std::ostringstream oss;
645 oss << *stmt;
646
647 const std::string& verification_pattern =
648 R"IR(
649# CHECK: int A_1 = 0;
650# CHECK: int A_1_1 = 0;
651# CHECK: for (int x = 0; x < 10; x++)
652# CHECK-NOT: A[
653# CHECK: A_1 =
654# CHECK: A_1_1 =
655# CHECK: A[1] = A_1_1;
656# CHECK: A[0] = A_1;)IR";
657
658 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
659}
660
661// Can registerize the load of A.
662TEST(Registerizer, RegisterizerNoLoads) {
663 BufHandle a("A", {1}, kInt);
664 VarHandle x("x", kInt);
665 StmtPtr stmt = Block::make(
666 {Store::make(a, {0}, 0),
667 For::make(
668 x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))});
669
670 /*
671 * A[0] = 0;
672 * for (int x = 0; x < 10; x++) {
673 * A[0] = x + 1;
674 * }
675 */
676
677 stmt = registerize(stmt);
678
679 /*
680 * int A_1 = 0;
681 * for (int x = 0; x < 10; x++) {
682 * A_1 = x + 1;
683 * }
684 * A[0] = A_1;
685 */
686
687 std::ostringstream oss;
688 oss << *stmt;
689
690 const std::string& verification_pattern =
691 R"IR(
692# CHECK: int A_1 = 0;
693# CHECK: for (int x = 0; x < 10; x++)
694# CHECK-NOT: A[
695# CHECK: A_1 =
696# CHECK: A[0] = A_1;)IR";
697
698 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
699}
700
701// Can registerize the load of A but not the store of B.
702TEST(Registerizer, RegisterizerNoRepeatedStores) {
703 BufHandle a("A", {1}, kInt);
704 BufHandle b("B", {10}, kInt);
705 VarHandle x("x", kInt);
706 StmtPtr stmt = Block::make(
707 {Store::make(a, {0}, 0),
708 For::make(
709 x,
710 0,
711 10,
712 Block::make(
713 {Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))});
714
715 /*
716 * A[0] = 0;
717 * for (int x = 0; x < 10; x++) {
718 * B[x] = (A[0]) + x;
719 * }
720 */
721
722 stmt = registerize(stmt);
723
724 // TODO: its unnecessary to reorder the initializer of A[0], but it's not
725 // actually worse so lets not worry for now.
726
727 /*
728 * int A_1 = 0;
729 * for (int x = 0; x < 10; x++) {
730 * B[x] = x + A_1;
731 * }
732 * A[0] = A_1;
733 */
734
735 std::ostringstream oss;
736 oss << *stmt;
737
738 const std::string& verification_pattern =
739 R"IR(
740# CHECK: int A_1 = 0;
741# CHECK: for (int x = 0; x < 10; x++)
742# CHECK-NOT: A_
743# CHECK: B[x] =
744# CHECK: A[0] = A_1;)IR";
745
746 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
747}
748
749// Won't registerize if there are multiple accesses which may overlap.
750TEST(Registerizer, RegisterizerMultiVarOverlap) {
751 BufHandle a("A", {2}, kInt);
752 VarHandle x("x", kInt);
753 StmtPtr stmt = Block::make({
754 Store::make(a, {0}, 0),
755 Store::make(a, {1}, 0),
756 For::make(
757 x,
758 0,
759 10,
760 Block::make(
761 {Store::make(a, {x}, Add::make(Load::make(a, {0}), x)),
762 Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})),
763 });
764 stmt = IRSimplifier::simplify(stmt);
765
766 std::ostringstream before;
767 before << *stmt;
768
769 // No change.
770 stmt = registerize(stmt);
771
772 std::ostringstream after;
773 after << *stmt;
774
775 ASSERT_EQ(before.str(), after.str());
776}
777
778TEST(Registerizer, RegisterizerAllocs) {
779 BufHandle a("A", {2}, kInt);
780 BufHandle c("C", {1}, kInt);
781 VarHandle x("x", kInt);
782
783 BufHandle b("B", {Load::make(c, {0})}, kInt);
784
785 StmtPtr stmt = Block::make(
786 {Allocate::make(b),
787 Store::make(a, {0}, Load::make(c, {0})),
788 Store::make(b, {0}, 0),
789 For::make(
790 x,
791 0,
792 10,
793 Block::make(
794 {Store::make(b, {0}, Add::make(Load::make(b, {0}), x)),
795 Store::make(a, {0}, Load::make(c, {0}))})),
796 Free::make(b)});
797
798 /*
799 * Allocate(B, int, {C[0]});
800 * A[0] = C[0];
801 * B[0] = 0;
802 * for (int x = 0; x < 10; x++) {
803 * B[0] = (B[0]) + x;
804 * A[0] = C[0];
805 * }
806 * Free(B);
807 */
808
809 stmt = registerize(stmt);
810
811 /*
812 * int C_1 = C[0];
813 * Allocate(B, int, {C_});
814 * int A_1 = C_1;
815 * int B_1 = 0;
816 * for (int x = 0; x < 10; x++) {
817 * B_1 = B_1 + x;
818 * A_1 = C_1;
819 * }
820 * B[0] = B_1;
821 * A[0] = A_1;
822 * Free(B);
823 */
824
825 std::ostringstream oss;
826 oss << *stmt;
827
828 const std::string& verification_pattern =
829 R"IR(
830# CHECK: int C_1 = C[0];
831# CHECK: Allocate(B
832# CHECK: int A_1 = C_1;
833# CHECK: int B_1 = 0;
834# CHECK: for (int x = 0; x < 10; x++)
835# CHECK: B_1 =
836# CHECK: A_1 = C_
837# CHECK: B[0] = B_1;
838# CHECK: A[0] = A_1;
839# CHECK: Free(B)IR";
840
841 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
842}
843
844TEST(Registerizer, RegisterizerNoInitializer) {
845 BufHandle a("A", {1}, kInt);
846 VarHandle x("x", kInt);
847 StmtPtr stmt = Block::make({For::make(
848 x,
849 0,
850 10,
851 Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});
852
853 /*
854 * for (int x = 0; x < 10; x++) {
855 * A[0] = (A[0]) + x;
856 * }
857 */
858
859 stmt = registerize(stmt);
860
861 /*
862 * int A_1 = A[0];
863 * for (int x = 0; x < 10; x++) {
864 * A_1 = x + A_1;
865 * }
866 * A[0] = A_1;
867 */
868
869 std::ostringstream oss;
870 oss << *stmt;
871
872 const std::string& verification_pattern =
873 R"IR(
874# CHECK: int A_1 = A[0];
875# CHECK: for (int x = 0; x < 10; x++)
876# CHECK-NOT: A[
877# CHECK: A_1 =
878# CHECK: A[0] = A_1;)IR";
879
880 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
881}
882
883TEST(Registerizer, RegisterizerNoInitializerLoopVar) {
884 BufHandle a("A", {1}, kInt);
885 VarHandle x("x", kInt);
886 StmtPtr stmt = Block::make({For::make(
887 x,
888 0,
889 10,
890 Block::make({Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
891 stmt = IRSimplifier::simplify(stmt);
892
893 /*
894 * for (int x = 0; x < 10; x++) {
895 * A[x] = (A[x]) + x;
896 * }
897 */
898
899 std::ostringstream before;
900 before << *stmt;
901
902 // No change.
903 stmt = registerize(stmt);
904
905 std::ostringstream after;
906 after << *stmt;
907
908 ASSERT_EQ(before.str(), after.str());
909}
910
911TEST(Registerizer, RegisterizerLoadThenStore) {
912 BufHandle a("A", {1}, kInt);
913 BufHandle b("B", {1}, kInt);
914 VarHandle x("x", kInt);
915 StmtPtr stmt = Block::make({For::make(
916 x,
917 0,
918 10,
919 Block::make(
920 {Store::make(b, {0}, Add::make(Load::make(a, {0}), x)),
921 Store::make(a, {0}, Load::make(b, {0}))}))});
922
923 /*
924 * for (int x = 0; x < 10; x++) {
925 * B[0] = (A[0]) + x;
926 * A[0] = B[0];
927 * }
928 */
929
930 stmt = registerize(stmt);
931
932 /*
933 * int A_1 = A[0];
934 * int B_1 = B[0];
935 * for (int x = 0; x < 10; x++) {
936 * B_1 = x + A_1;
937 * A_1 = B_1;
938 * }
939 * B[0] = B_1;
940 * A[0] = A_1;
941 */
942
943 std::ostringstream oss;
944 oss << *stmt;
945
946 const std::string& verification_pattern =
947 R"IR(
948# CHECK: int A_1 = A[0];
949# CHECK: int B_1 = B[0];
950# CHECK: for (int x = 0; x < 10; x++)
951# CHECK-NOT: B[
952# CHECK: B_1 =
953# CHECK-NOT: A[
954# CHECK: A_1 = B_
955# CHECK: B[0] = B_
956# CHECK: A[0] = A_1;)IR";
957
958 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
959}
960
961TEST(Registerizer, RegisterizerParallelized) {
962 BufHandle a("A", {1}, kInt);
963 VarHandle x("x", kInt);
964 LoopOptions loopOpts;
965 loopOpts.set_gpu_block_index(0);
966 StmtPtr stmt = Block::make(
967 {Store::make(a, {0}, 0),
968 For::make(
969 x,
970 0,
971 10,
972 Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}),
973 loopOpts)});
974
975 /*
976 * A[0] = 0;
977 * for (int x = 0; x < 10; x++) {
978 * A[0] = (A[0]) + x;
979 * }
980 */
981
982 ASSERT_THROWS_WITH(
983 registerize(stmt),
984 "Registerization must occur after parallelism flattening");
985}
986
987// Should be able to registerize this since the scalar would exist before the
988// branch.
989TEST(Registerizer, RegisterizerConditionAfter) {
990 BufHandle a("A", {5}, kInt);
991 BufHandle b("B", {5}, kInt);
992 BufHandle c("C", {5}, kInt);
993 VarHandle x("x", kInt);
994
995 StmtPtr stmt = Block::make(
996 {Store::make(a, {x}, Load::make(b, {x})),
997 Store::make(c, {x}, Load::make(a, {x})),
998 Cond::make(
999 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1000 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1001 nullptr)});
1002
1003 /*
1004 * A[x] = B[x];
1005 * C[x] = A[x];
1006 * if (x<5 ? 1 : 0) {
1007 * A[x] = (A[x]) + 1;
1008 * }
1009 */
1010
1011 stmt = registerize(stmt);
1012
1013 /*
1014 * int A_1 = B[x];
1015 * C[x] = A_1;
1016 * if (x<5 ? 1 : 0) {
1017 * A_1 = A_1 + 1;
1018 * }
1019 * A[x] = A_1;
1020 */
1021
1022 std::ostringstream oss;
1023 oss << *stmt;
1024
1025 const std::string& verification_pattern =
1026 R"IR(
1027# CHECK: int A_1 = B[x];
1028# CHECK: C[x] = A_1;
1029# CHECK: if (
1030# CHECK: A_1 = A_1 + 1;
1031# CHECK: A[x] = A_1;)IR";
1032
1033 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1034}
1035
1036// Should be able to registerize this since the scalar exists in the same form
1037// after the branch and there is no overlap.
1038TEST(Registerizer, RegisterizerConditionBefore) {
1039 BufHandle a("A", {5}, kInt);
1040 BufHandle b("B", {5}, kInt);
1041 BufHandle c("C", {5}, kInt);
1042 VarHandle x("x", kInt);
1043
1044 StmtPtr stmt = Block::make(
1045 {Cond::make(
1046 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1047 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1048 nullptr),
1049 Store::make(a, {x}, Load::make(b, {x})),
1050 Store::make(c, {x}, Load::make(a, {x}))});
1051
1052 /*
1053 * if (x<5 ? 1 : 0) {
1054 * A[x] = (A[x]) + 1;
1055 * }
1056 * A[x] = B[x];
1057 * C[x] = A[x];
1058 */
1059
1060 stmt = registerize(stmt);
1061
1062 /*
1063 * int A_ 1 = A[x];
1064 * if (x<5 ? 1 : 0) {
1065 * A_1 = A_1 + 1;
1066 * }
1067 * A_1 = B[x];
1068 * C[x] = A_1;
1069 * A[x] = A_1;
1070 */
1071
1072 std::ostringstream oss;
1073 oss << *stmt;
1074
1075 const std::string& verification_pattern =
1076 R"IR(
1077# CHECK: int A_1 = A[x];
1078# CHECK: if (
1079# CHECK: A_1 = A_1 + 1;
1080# CHECK: }
1081# CHECK: A_1 = B[x];
1082# CHECK: C[x] = A_1;
1083# CHECK: A[x] = A_1;)IR";
1084
1085 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1086}
1087
1088// Should be able to registerize this as the combination of the two above rules.
1089TEST(Registerizer, RegisterizerConditionInside) {
1090 BufHandle a("A", {5}, kInt);
1091 BufHandle b("B", {5}, kInt);
1092 BufHandle c("C", {5}, kInt);
1093 VarHandle x("x", kInt);
1094
1095 StmtPtr stmt = Block::make(
1096 {Store::make(a, {x}, Load::make(b, {x})),
1097 Store::make(c, {x}, Load::make(a, {x})),
1098 Cond::make(
1099 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1100 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1101 nullptr),
1102 Store::make(b, {x}, Load::make(a, {x})),
1103 Store::make(a, {x}, Load::make(c, {x}))});
1104
1105 /*
1106 * A[x] = B[x];
1107 * C[x] = A[x];
1108 * if (x<5 ? 1 : 0) {
1109 * A[x] = (A[x]) + 1;
1110 * }
1111 * B[x] = A[x];
1112 * A[x] = C[x];
1113 */
1114
1115 stmt = registerize(stmt);
1116
1117 /*
1118 * int A_1 = B[x];
1119 * C[x] = A_1;
1120 * if (x<5 ? 1 : 0) {
1121 * A_1 = A_1 + 1;
1122 * }
1123 * B[x] = A_1;
1124 * A_1 = C[x];
1125 * A[x] = A_1;
1126 */
1127
1128 std::ostringstream oss;
1129 oss << *stmt;
1130
1131 const std::string& verification_pattern =
1132 R"IR(
1133# CHECK: int A_1 = B[x];
1134# CHECK: C[x] = A_1;
1135# CHECK: if (
1136# CHECK: A_1 = A_1 + 1;
1137# CHECK: }
1138# CHECK: B[x] = A_1;
1139# CHECK: A_1 = C[x];
1140# CHECK: A[x] = A_1;)IR";
1141
1142 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1143}
1144
1145// An example where an access is cut by an overlapping access inside a
1146// condition, and both sides are large enough to be registerized but cannot be
1147// because there is no safe place to put the initializer or finalizer.
1148TEST(Registerizer, RegisterizerConditionInsideOverlap1) {
1149 BufHandle a("A", {5}, kInt);
1150 BufHandle b("B", {5}, kInt);
1151 BufHandle c("C", {5}, kInt);
1152 VarHandle x("x", kInt);
1153 VarHandle y("y", kInt);
1154
1155 StmtPtr stmt = Block::make(
1156 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
1157 {Store::make(a, {x}, Load::make(b, {x})),
1158 Store::make(c, {x}, Load::make(a, {x})),
1159 Cond::make(
1160 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1161 Block::make({
1162 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1163 Store::make(a, {0}, 3),
1164 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1165 }),
1166 nullptr),
1167 Store::make(b, {x}, Load::make(a, {x})),
1168 Store::make(a, {x}, Load::make(c, {x}))});
1169
1170 /*
1171 * A[x] = B[x];
1172 * C[x] = A[x];
1173 * if (x<5 ? 1 : 0) {
1174 * A[x] = (A[x]) + 1;
1175 * A[0] = 3;
1176 * A[x] = (A[x]) + 1;
1177 * }
1178 * B[x] = A[x];
1179 * A[x] = C[x];
1180 */
1181
1182 // The A[0] store overlaps, A[x] cutting the region that can be registerized
1183 // into two groups.
1184 // Each group has 2 loads and 2 stores however, so we could registerize it,
1185 // but the first group would need to be finalized inside the condition block,
1186 // the second would need to be initialized inside the condition block. There's
1187 // no safe place to put these that's visible to the other uses in the group
1188 // and so neither registerization is possible.
1189
1190 std::ostringstream before;
1191 before << *stmt;
1192
1193 // No change.
1194 stmt = registerize(stmt);
1195
1196 std::ostringstream after;
1197 after << *stmt;
1198
1199 ASSERT_EQ(before.str(), after.str());
1200}
1201
1202// Same as the above, but the access group before the condition (and after the
1203// condition) are large enough to be registerized without needing the access
1204// from the loop. Registerization occurs but does not include any accesses in
1205// the condition, and the first group must be finalized before the Cond, the
1206// second initialized after it.
1207TEST(Registerizer, RegisterizerConditionInsideOverlap2) {
1208 BufHandle a("A", {5}, kInt);
1209 BufHandle b("B", {5}, kInt);
1210 BufHandle c("C", {5}, kInt);
1211 VarHandle x("x", kInt);
1212 VarHandle y("y", kInt);
1213
1214 StmtPtr stmt = Block::make(
1215 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
1216 {Store::make(a, {x}, Load::make(b, {x})),
1217 Store::make(a, {x}, Load::make(b, {x + 1})),
1218 Store::make(c, {x}, Load::make(a, {x})),
1219 Cond::make(
1220 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1221 Block::make({
1222 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1223 Store::make(a, {0}, 3),
1224 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1225 }),
1226 nullptr),
1227 Store::make(b, {x}, Load::make(a, {x})),
1228 Store::make(b, {x + 1}, Load::make(a, {x})),
1229 Store::make(a, {x}, Load::make(c, {x}))});
1230
1231 /*
1232 * A[x] = B[x];
1233 * A[x] = B[x + 1];
1234 * C[x] = A[x];
1235 * if (x<5 ? 1 : 0) {
1236 * A[x] = (A[x]) + 1;
1237 * A[0] = 3;
1238 * A[x] = (A[x]) + 1;
1239 * }
1240 * B[x] = A[x];
1241 * B[x + 1] = A[x];
1242 * A[x] = C[x];
1243 */
1244
1245 stmt = registerize(stmt);
1246
1247 /*
1248 * int A_1 = B[x]; // A_1 initializer
1249 * A_1 = B[x + 1]; //
1250 * C[x] = A_1; //
1251 * A[x] = A_1; // A_1 finalizer
1252 * if (x<5 ? 1 : 0) {
1253 * A[x] = (A[x]) + 1;
1254 * A[0] = 3;
1255 * A[x] = (A[x]) + 1;
1256 * }
1257 * int A_2 = A[x]; // A_2 initialier
1258 * B[x] = A_2; //
1259 * B[x + 1] = A_2; //
1260 * A_2 = C[x]; //
1261 * A[x] = A_2; // A_2 finalizer
1262 */
1263
1264 std::ostringstream oss;
1265 oss << *stmt;
1266
1267 const std::string& verification_pattern =
1268 R"IR(
1269# CHECK: int A_1 = B[x];
1270# CHECK: A_1 = B[x + 1];
1271# CHECK: C[x] = A_1;
1272# CHECK: A[x] = A_1;
1273# CHECK: if (
1274# CHECK-NOT: A_1 = A_1 + 1;
1275# CHECK: A[x] = (A[x]
1276# CHECK: A[0] =
1277# CHECK: A[x] = (A[x]
1278# CHECK: }
1279# CHECK: int A_2 = A[x];
1280# CHECK: B[x] = A_2;
1281# CHECK: B[x + 1] = A_2;
1282# CHECK: A_2 = C[x];
1283# CHECK: A[x] = A_2;)IR";
1284
1285 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1286}
1287
1288// When accesses are within conditional blocks they are not visible to the wider
1289// program, because we don't know if the branch would be taken and if it isn't
1290// the accesses in it don't need to be valid (think size checks on the index).
1291// In this case the accesses cannot be registerized.
1292TEST(Registerizer, RegisterizerConditionHidden) {
1293 BufHandle a("A", {5}, kInt);
1294 BufHandle b("B", {5}, kInt);
1295 BufHandle c("C", {5}, kInt);
1296 VarHandle x("x", kInt);
1297
1298 StmtPtr stmt = Block::make(
1299 {Cond::make(
1300 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1301 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1302 nullptr),
1303 Cond::make(
1304 CompareSelect::make(x, 5, CompareSelectOperation::kGT),
1305 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1306 nullptr)});
1307
1308 /*
1309 * if (x<5 ? 1 : 0) {
1310 * A[x] = (A[x]) + 1;
1311 * }
1312 * if (x>5 ? 1 : 0) {
1313 * A[x] = (A[x]) + 1;
1314 * }
1315 */
1316
1317 std::ostringstream before;
1318 before << *stmt;
1319
1320 // No change.
1321 stmt = registerize(stmt);
1322
1323 std::ostringstream after;
1324 after << *stmt;
1325
1326 ASSERT_EQ(before.str(), after.str());
1327}
1328
1329// But... if the same access is found in a non conditional scope, that means
1330// that that access is valid in the higher scope (or at least if its not it's
1331// the user's fault). It "unhides" the conditional accesses, allowing
1332// registerization to occur.
1333TEST(Registerizer, RegisterizerConditionUnhidden) {
1334 BufHandle a("A", {5}, kInt);
1335 BufHandle b("B", {5}, kInt);
1336 BufHandle c("C", {5}, kInt);
1337 VarHandle x("x", kInt);
1338
1339 StmtPtr stmt = Block::make(
1340 {Cond::make(
1341 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1342 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1343 nullptr),
1344 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1345 Cond::make(
1346 CompareSelect::make(x, 5, CompareSelectOperation::kGT),
1347 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1348 nullptr)});
1349
1350 /*
1351 * if (x<5 ? 1 : 0) {
1352 * A[x] = (A[x]) + 1;
1353 * }
1354 * A[x] = (A[x]) + 1; <-- this is doing the unhiding.
1355 * if (x>5 ? 1 : 0) {
1356 * A[x] = (A[x]) + 1;
1357 * }
1358 */
1359
1360 stmt = registerize(stmt);
1361
1362 /*
1363 * int A_1 = A[x];
1364 * if (x<5 ? 1 : 0) {
1365 * A_1 = A_1 + 1;
1366 * }
1367 * A_1 = A_1 + 1;
1368 * if (x>5 ? 1 : 0) {
1369 * A_1 = A_1 + 1;
1370 * }
1371 * A[x] = A_1;
1372 */
1373
1374 std::ostringstream oss;
1375 oss << *stmt;
1376
1377 const std::string& verification_pattern =
1378 R"IR(
1379# CHECK: int A_1 = A[x];
1380# CHECK: if (x<5
1381# CHECK: A_1 = A_1 + 1;
1382# CHECK: }
1383# CHECK: A_1 = A_1 + 1;
1384# CHECK: if (x>5
1385# CHECK: A_1 = A_1 + 1;
1386# CHECK: }
1387# CHECK: A[x] = A_1;)IR";
1388
1389 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1390}
1391
1392// Can registerize a load that occurs in the condition of a Cond.
1393TEST(Registerizer, RegisterizerCondCondition) {
1394 BufHandle a("A", {5}, kInt);
1395 BufHandle b("B", {5}, kInt);
1396 BufHandle c("C", {5}, kInt);
1397 VarHandle x("x", kInt);
1398
1399 StmtPtr stmt = Block::make(
1400 {Store::make(a, {x}, Load::make(b, {x})),
1401 Store::make(c, {x}, Load::make(a, {x})),
1402 Cond::make(
1403 CompareSelect::make(
1404 Load::make(a, {x}), 5, CompareSelectOperation::kLT),
1405 Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),
1406 nullptr)});
1407
1408 /*
1409 * A[x] = B[x];
1410 * C[x] = A[x];
1411 * if ((A[x])<5 ? 1 : 0) {
1412 * C[x] = (C[x]) + 1;
1413 * }
1414 */
1415
1416 stmt = registerize(stmt);
1417
1418 /*
1419 * int A_1 = B[x];
1420 * int C_1 = A_1;
1421 * if (A_1<5 ? 1 : 0) {
1422 * C_1 = C_1 + 1;
1423 * }
1424 * C[x] = C_1;
1425 */
1426
1427 std::ostringstream oss;
1428 oss << *stmt;
1429
1430 const std::string& verification_pattern =
1431 R"IR(
1432# CHECK: int A_1 = B[x];
1433# CHECK: int C_1 = A_1;
1434# CHECK: if (A_1<5
1435# CHECK: C_1 = C_1 + 1;
1436# CHECK: C[x] = C_1;)IR";
1437
1438 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1439}
1440
1441// Appearing in the condition of a Cond makes it visible to the enclosing scope,
1442// and so we can registerize internal usages.
1443TEST(Registerizer, RegisterizerCondConditionUnhidden) {
1444 BufHandle a("A", {5}, kInt);
1445 BufHandle b("B", {5}, kInt);
1446 BufHandle c("C", {5}, kInt);
1447 VarHandle x("x", kInt);
1448
1449 StmtPtr stmt = Block::make({Cond::make(
1450 CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT),
1451 Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1452 Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))});
1453
1454 /*
1455 * if ((A[x])<5 ? 1 : 0) {
1456 * A[x] = (A[x]) + 1;
1457 * } else {
1458 * A[x] = (A[x]) + 10;
1459 * }
1460 */
1461
1462 stmt = registerize(stmt);
1463
1464 /*
1465 * int A_1 = A[x];
1466 * if (A_1<5 ? 1 : 0) {
1467 * A_1 = A_1 + 1;
1468 * } else {
1469 * A_1 = A_1 + 10;
1470 * }
1471 * A[x] = A_1;
1472 */
1473
1474 std::ostringstream oss;
1475 oss << *stmt;
1476
1477 const std::string& verification_pattern =
1478 R"IR(
1479# CHECK: int A_1 = A[x];
1480# CHECK: if (A_1<5
1481# CHECK: A_1 = A_1 + 1;
1482# CHECK: } else {
1483# CHECK: A_1 = A_1 + 10;
1484# CHECK: }
1485# CHECK: A[x] = A_1;)IR";
1486
1487 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1488}
1489
1490// Conditional hiding also works for IfThenElse exprs.
1491TEST(Registerizer, RegisterizerIfThenElseHidden) {
1492 BufHandle a("A", {5}, kInt);
1493 BufHandle b("B", {5}, kInt);
1494 BufHandle c("C", {5}, kInt);
1495 VarHandle x("x", kInt);
1496 VarHandle y("y", kInt);
1497
1498 StmtPtr stmt = Block::make(
1499 {Store::make(
1500 b,
1501 {y},
1502 IfThenElse::make(
1503 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1504 Add::make(Load::make(a, {x}), 1),
1505 Add::make(Load::make(a, {x + 1}), 2))),
1506 Store::make(
1507 b,
1508 {y + 1},
1509 IfThenElse::make(
1510 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1511 Add::make(Load::make(a, {x}), 1),
1512 Add::make(Load::make(a, {x + 1}), 2)))});
1513
1514 /*
1515 * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1516 * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1517 */
1518
1519 std::ostringstream before;
1520 before << *stmt;
1521
1522 // No change.
1523 stmt = registerize(stmt);
1524
1525 std::ostringstream after;
1526 after << *stmt;
1527
1528 ASSERT_EQ(before.str(), after.str());
1529}
1530
1531// Conditional unhiding also works for IfThenElse exprs.
1532TEST(Registerizer, RegisterizerIfThenElseUnhidden) {
1533 BufHandle a("A", {5}, kInt);
1534 BufHandle b("B", {5}, kInt);
1535 BufHandle c("C", {5}, kInt);
1536 VarHandle x("x", kInt);
1537 VarHandle y("y", kInt);
1538
1539 StmtPtr stmt = Block::make({
1540 Store::make(a, {x}, 0),
1541 Store::make(
1542 b,
1543 {y},
1544 IfThenElse::make(
1545 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1546 Add::make(Load::make(a, {x}), 1),
1547 Add::make(Load::make(a, {x + 1}), 2))),
1548 Store::make(
1549 b,
1550 {y + 1},
1551 IfThenElse::make(
1552 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1553 Add::make(Load::make(a, {x}), 1),
1554 Add::make(Load::make(a, {x + 1}), 2))),
1555 });
1556
1557 /*
1558 * A[x] = 0;
1559 * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1560 * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1561 */
1562
1563 stmt = registerize(stmt);
1564
1565 /*
1566 * int A_1 = 0;
1567 * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1568 * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1569 * A[x] = A_1;
1570 */
1571
1572 std::ostringstream oss;
1573 oss << *stmt;
1574
1575 const std::string& verification_pattern =
1576 R"IR(
1577# CHECK: int A_1 = 0;
1578# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1579# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1580# CHECK: A[x] = A_1;)IR";
1581
1582 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1583}
1584
1585// Nested IfThenElse exprs can't promote to higher level scopes.
1586TEST(Registerizer, RegisterizerIfThenElseNested) {
1587 BufHandle a("A", {5}, kInt);
1588 BufHandle b("B", {5}, kInt);
1589 BufHandle c("C", {5}, kInt);
1590 BufHandle d("D", {5}, kInt);
1591 VarHandle x("x", kInt);
1592
1593 StmtPtr stmt = Block::make({Store::make(
1594 a,
1595 {x},
1596 IfThenElse::make(
1597 CompareSelect::make(x, 3, CompareSelectOperation::kLT),
1598 IfThenElse::make(
1599 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
1600 Load::make(d, {x}),
1601 Load::make(b, {x})),
1602 IfThenElse::make(
1603 CompareSelect::make(x, 5, CompareSelectOperation::kEQ),
1604 Load::make(c, {x}),
1605 Load::make(d, {x}))))});
1606
1607 /*
1608 * A[x] = IfThenElse(x<3 ? 1 : 0,
1609 * IfThenElse(x==2 ? 1 : 0, D[x], B[x]),
1610 * IfThenElse(x==5 ? 1 : 0, C[x], D[x]));
1611 */
1612
1613 std::ostringstream before;
1614 before << *stmt;
1615
1616 // No change.
1617 stmt = registerize(stmt);
1618
1619 std::ostringstream after;
1620 after << *stmt;
1621
1622 ASSERT_EQ(before.str(), after.str());
1623}
1624
1625// Cannot registerize an access completely contained within an IfThenElse
1626// branch, since it is not a Stmt and cannot hold variable definitions. We need
1627// to check that we don't promote the initializer/finalizer to the enclosing
1628// Block.
1629TEST(Registerizer, RegisterizerIfThenElseInternal) {
1630 // Making these floats so they don't get simplified to a single access.
1631 BufHandle a("A", {5}, kFloat);
1632 BufHandle b("B", {5}, kFloat);
1633 VarHandle x("x", kInt);
1634
1635 StmtPtr stmt = Block::make({Store::make(
1636 a,
1637 {x},
1638 IfThenElse::make(
1639 CompareSelect::make(x, 3, CompareSelectOperation::kLT),
1640 Add::make(Load::make(b, {x}), Load::make(b, {x})),
1641 Load::make(b, {x})))});
1642
1643 /*
1644 * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]);
1645 */
1646
1647 std::ostringstream before;
1648 before << *stmt;
1649
1650 // No change.
1651 stmt = registerize(stmt);
1652
1653 std::ostringstream after;
1654 after << *stmt;
1655
1656 ASSERT_EQ(before.str(), after.str());
1657
1658 // If this was a Cond instead of an IfThenElse then we could registerize the
1659 // two accesses to B[x] in the True branch.
1660
1661 // Actually lets verify that.
1662
1663 stmt = Block::make({Cond::make(
1664 CompareSelect::make(x, 3, CompareSelectOperation::kLT),
1665 Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))),
1666 Store::make(a, {x}, Load::make(b, {x})))});
1667
1668 /*
1669 * if (x<3 ? 1 : 0) {
1670 * A[x] = (B[x]) + (B[x]);
1671 * } else {
1672 * A[x] = B[x];
1673 * }
1674 */
1675
1676 stmt = registerize(stmt);
1677
1678 /*
1679 * if (x<3 ? 1 : 0) {
1680 * float B_1 = B[x];
1681 * A[x] = B_1 + B_1;
1682 * } else {
1683 * A[x] = B[x];
1684 * }
1685 */
1686
1687 std::ostringstream oss;
1688 oss << *stmt;
1689
1690 const std::string& verification_pattern =
1691 R"IR(
1692# CHECK-NOT: int
1693# CHECK-NOT: float
1694# CHECK: if (x<3
1695# CHECK: float B_1 =
1696# CHECK: A[x] = B_1 + B_1
1697# CHECK: } else {
1698# CHECK: A[x] = B[x]
1699# CHECK: }
1700# CHECK-NOT: A[x]
1701# CHECK-NOT: B[x])IR";
1702
1703 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1704}
1705
1706// Can registerize a load that occurs in the condition of an IfThenElse;
1707TEST(Registerizer, RegisterizerIfThenElseCondition) {
1708 BufHandle a("A", {5}, kInt);
1709 BufHandle b("B", {5}, kInt);
1710 BufHandle c("C", {5}, kInt);
1711 VarHandle x("x", kInt);
1712
1713 StmtPtr stmt = Block::make(
1714 {Store::make(a, {x}, Load::make(a, {x})),
1715 Store::make(
1716 a,
1717 {x},
1718 IfThenElse::make(
1719 CompareSelect::make(
1720 Load::make(a, {x}), 5, CompareSelectOperation::kLT),
1721 Load::make(b, {0}),
1722 Load::make(c, {0})))});
1723
1724 /*
1725 * A[x] = A[x]; <---- just here so there are enough accesses to combine.
1726 * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]);
1727 */
1728
1729 stmt = registerize(stmt);
1730
1731 /*
1732 * int A_1 = A[x];
1733 * A_1 = A_1;
1734 * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
1735 * A[x] = A_1;
1736 */
1737
1738 std::ostringstream oss;
1739 oss << *stmt;
1740
1741 const std::string& verification_pattern =
1742 R"IR(
1743# CHECK: int A_1 = A[x];
1744# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
1745# CHECK: A[x] = A_1;)IR";
1746
1747 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1748}
1749
1750// Appearing in the condition of a Cond makes it visible to the enclosing scope,
1751// and so we can registerize internal usages.
1752TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) {
1753 BufHandle a("A", {5}, kInt);
1754 BufHandle b("B", {5}, kInt);
1755 BufHandle c("C", {5}, kInt);
1756 VarHandle x("x", kInt);
1757
1758 StmtPtr stmt = Block::make({Store::make(
1759 b,
1760 {x},
1761 IfThenElse::make(
1762 CompareSelect::make(
1763 Load::make(a, {x}), 5, CompareSelectOperation::kLT),
1764 Add::make(Load::make(a, {x}), 1),
1765 Add::make(Load::make(a, {x}), 10)))});
1766
1767 /*
1768 * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10);
1769 */
1770
1771 stmt = registerize(stmt);
1772
1773 /*
1774 * int A_1 = A[x];
1775 * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);
1776 */
1777
1778 std::ostringstream oss;
1779 oss << *stmt;
1780
1781 const std::string& verification_pattern =
1782 R"IR(
1783# CHECK: int A_1 = A[x];
1784# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR";
1785
1786 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1787}
1788
1789// Cannot promote accesses internal to IfThenElse branches even if the enclosing
1790// scope if conditional.
1791TEST(Registerizer, RegisterizerConditionBranchOnly) {
1792 BufHandle a("A", {5}, kInt);
1793 VarHandle x("x", kInt);
1794 StmtPtr stmt = Block::make({For::make(
1795 x,
1796 0,
1797 10,
1798 Block::make({
1799 Cond::make(
1800 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1801 Store::make(
1802 a,
1803 {x},
1804 IfThenElse::make(
1805 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1806 Add::make(Load::make(a, {x}), x),
1807 Add::make(Load::make(a, {x - 5}), x))),
1808 Store::make(
1809 a,
1810 {x - 5},
1811 IfThenElse::make(
1812 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1813 Add::make(Load::make(a, {x}), x),
1814 Add::make(Load::make(a, {x - 5}), x)))),
1815 }))});
1816 stmt = IRSimplifier::simplify(stmt);
1817
1818 std::ostringstream before;
1819 before << *stmt;
1820
1821 /* for (int x = 0; x < 10; x++) {
1822 * if (x<5 ? 1 : 0) {
1823 * A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
1824 * } else {
1825 * A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
1826 * }
1827 * }
1828 */
1829
1830 // No change.
1831 stmt = registerize(stmt);
1832
1833 std::ostringstream after;
1834 after << *stmt;
1835
1836 ASSERT_EQ(before.str(), after.str());
1837}
1838
1839// We can registerize an IfThenElse that appears in the condition branch of a
1840// Cond. This is a weird but valid thing to do.
1841TEST(Registerizer, RegisterizerCondIfThenElse) {
1842 BufHandle a("A", {5}, kInt);
1843 BufHandle b("B", {5}, kInt);
1844 BufHandle c("C", {5}, kInt);
1845 VarHandle x("x", kInt);
1846
1847 StmtPtr stmt = Block::make({Cond::make(
1848 CompareSelect::make(
1849 IfThenElse::make(
1850 CompareSelect::make(
1851 Load::make(a, {x}), 5, CompareSelectOperation::kLT),
1852 Load::make(a, {x}),
1853 Load::make(b, {x})),
1854 x,
1855 CompareSelectOperation::kEQ),
1856 Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),
1857 nullptr)});
1858
1859 /*
1860 * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) {
1861 * C[x] = (C[x]) + 1;
1862 * }
1863 */
1864
1865 stmt = registerize(stmt);
1866
1867 // access to A can be registerized, but not B or C
1868
1869 /*
1870 * int A_1 = A[x];
1871 * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) {
1872 * C[x] = (C[x]) + 1;
1873 * }
1874 */
1875
1876 std::ostringstream oss;
1877 oss << *stmt;
1878
1879 const std::string& verification_pattern =
1880 R"IR(
1881# CHECK: int A_1 = A[x];
1882# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]
1883# CHECK: C[x] = (C[x]) + 1;)IR";
1884
1885 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1886}
1887
1888// Can registerize a conditional access in the RHS of a store unhidden by it's
1889// LHS, and hoist it out of a loop.
1890TEST(Registerizer, RegisterizerIfThenElseLoop) {
1891 BufHandle a("A", {5}, kInt);
1892 BufHandle b("B", {5}, kInt);
1893 VarHandle x("x", kInt);
1894 VarHandle y("y", kInt);
1895
1896 StmtPtr stmt = For::make(
1897 y,
1898 0,
1899 10,
1900 Store::make(
1901 a,
1902 {x},
1903 IfThenElse::make(
1904 CompareSelect::make(x, 3, CompareSelectOperation::kLT),
1905 Load::make(a, {x}),
1906 Load::make(b, {y}))));
1907
1908 /*
1909 * for (int y = 0; y < 10; y++) {
1910 * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]);
1911 * }
1912 */
1913
1914 stmt = registerize(stmt);
1915
1916 /*
1917 * int A_1 = A[x];
1918 * for (int y = 0; y < 10; y++) {
1919 * A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
1920 * }
1921 * A[x] = A_1;
1922 */
1923
1924 std::ostringstream oss;
1925 oss << *stmt;
1926
1927 const std::string& verification_pattern =
1928 R"IR(
1929# CHECK: int A_1 = A[x];
1930# CHECK: for (
1931# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
1932# CHECK: }
1933# CHECK: A[x] = A_1;)IR";
1934
1935 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1936}
1937
1938// Cannot registerize if the RHS overlaps the access creating visibility.
1939TEST(Registerizer, RegisterizerIfThenElseLoopCut) {
1940 BufHandle a("A", {5}, kInt);
1941 BufHandle b("B", {5}, kInt);
1942 VarHandle x("x", kInt);
1943 VarHandle y("y", kInt);
1944
1945 StmtPtr stmt = Block::make({For::make(
1946 y,
1947 0,
1948 10,
1949 Store::make(
1950 a,
1951 {x},
1952 IfThenElse::make(
1953 CompareSelect::make(x, 3, CompareSelectOperation::kLT),
1954 Load::make(a, {x}),
1955 Load::make(a, {y}))))});
1956
1957 /*
1958 * for (int y = 0; y < 10; y++) {
1959 * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]);
1960 * }
1961 */
1962
1963 std::ostringstream before;
1964 before << *stmt;
1965
1966 // No change.
1967 stmt = registerize(stmt);
1968
1969 std::ostringstream after;
1970 after << *stmt;
1971
1972 ASSERT_EQ(before.str(), after.str());
1973}
1974
1975// Simple case where an access is cut by an overlapping access later in the
1976// program, we can registerize up until the overlap.
1977TEST(Registerizer, RegisterizerPartialAfter) {
1978 BufHandle a("A", {1}, kInt);
1979 VarHandle x("x", kInt);
1980 StmtPtr stmt = Block::make(
1981 {Store::make(a, {0}, 0),
1982 For::make(
1983 x,
1984 0,
1985 10,
1986 Block::make(
1987 {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})),
1988 For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))});
1989
1990 /*
1991 * A[0] = 0;
1992 * for (int x = 0; x < 10; x++) {
1993 * A[0] = (A[0]) + x;
1994 * }
1995 * for (int x = 1; x < 10; x++) {
1996 * A[x] = A[x - 1];
1997 * }
1998 */
1999
2000 stmt = registerize(stmt);
2001
2002 /*
2003 * int A_1 = 0;
2004 * for (int x = 0; x < 10; x++) {
2005 * A_1 = A_1 + x;
2006 * }
2007 * A[0] = A_1;
2008 * for (int x = 1; x < 10; x++) {
2009 * A[x] = A[x - 1];
2010 * }
2011 */
2012
2013 std::ostringstream oss;
2014 oss << *stmt;
2015
2016 const std::string& verification_pattern =
2017 R"IR(
2018# CHECK: int A_1 = 0;
2019# CHECK: for (
2020# CHECK: A_1 = A_1 + x;
2021# CHECK: }
2022# CHECK: A[0] = A_1;
2023# CHECK: for (
2024# CHECK: A[x] = A[x - 1];
2025# CHECK: }
2026# CHECK-NOT: A)IR";
2027
2028 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2029}
2030
2031// We can registerize an access which overlaps a previous access, the
2032// initializer must be inserted after the previous access.
2033TEST(Registerizer, RegisterizerPartialBefore) {
2034 BufHandle a("A", {1}, kInt);
2035 VarHandle x("x", kInt);
2036 StmtPtr stmt = Block::make(
2037 {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
2038 Store::make(a, {0}, 0),
2039 For::make(
2040 x,
2041 0,
2042 10,
2043 Block::make(
2044 {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});
2045
2046 /*
2047 * for (int x = 1; x < 10; x++) {
2048 * A[x] = A[x - 1];
2049 * }
2050 * A[0] = 0;
2051 * for (int x = 0; x < 10; x++) {
2052 * A[0] = (A[0]) + x;
2053 * }
2054 */
2055
2056 stmt = registerize(stmt);
2057
2058 /*
2059 * for (int x = 1; x < 10; x++) {
2060 * A[x] = A[x - 1];
2061 * }
2062 * int A_1 = 0;
2063 * for (int x = 0; x < 10; x++) {
2064 * A_1 = A_1 + x;
2065 * }
2066 * A[0] = A_1;
2067 */
2068
2069 std::ostringstream oss;
2070 oss << *stmt;
2071
2072 const std::string& verification_pattern =
2073 R"IR(
2074# CHECK-NOT: int
2075# CHECK: for (
2076# CHECK: A[x] = A[x - 1];
2077# CHECK: }
2078# CHECK: int A_1 = 0;
2079# CHECK: for (
2080# CHECK: A_1 = A_1 + x;
2081# CHECK: }
2082# CHECK: A[0] = A_1;)IR";
2083
2084 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2085}
2086
2087// The combination of the previous two tests, an access is cut by an overlapping
2088// access in both directions.
2089TEST(Registerizer, RegisterizerPartialInside) {
2090 BufHandle a("A", {1}, kInt);
2091 VarHandle x1("x1", kInt);
2092 VarHandle x2("x2", kInt);
2093 VarHandle x3("x3", kInt);
2094 StmtPtr stmt = Block::make(
2095 {Store::make(a, {0}, 2),
2096 For::make(
2097 x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))),
2098 For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))),
2099 For::make(
2100 x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))});
2101
2102 /*
2103 * A[0] = 2;
2104 * for (int x1 = 0; x1 < 10; x1++) {
2105 * A[0] = (A[0]) + x1;
2106 * }
2107 * for (int x2 = 1; x2 < 10; x2++) {
2108 * A[x2] = A[x2 - 1];
2109 * }
2110 * for (int x3 = 0; x3 < 10; x3++) {
2111 * A[0] = (A[0]) + x3;
2112 * }
2113 */
2114
2115 stmt = registerize(stmt);
2116
2117 /*
2118 * int A_1 = 2;
2119 * for (int x1 = 0; x1 < 10; x1++) {
2120 * A_1 = A_1 + x1;
2121 * }
2122 * A[0] = A_1;
2123 * for (int x2 = 1; x2 < 10; x2++) {
2124 * A[x2] = A[x2 - 1];
2125 * }
2126 * int A_2 = A[0];
2127 * for (int x3 = 0; x3 < 10; x3++) {
2128 * A_2 = A_2 + x3;
2129 * }
2130 * A[0] = A_2;
2131 */
2132
2133 std::ostringstream oss;
2134 oss << *stmt;
2135
2136 const std::string& verification_pattern =
2137 R"IR(
2138# CHECK: int A_1 = 2;
2139# CHECK: for (
2140# CHECK: A_1 = A_1 + x1;
2141# CHECK: }
2142# CHECK: A[0] = A_1;
2143# CHECK: for (
2144# CHECK: A[x2] =
2145# CHECK: }
2146# CHECK: int A_2 = A[0];
2147# CHECK: for (
2148# CHECK: A_2 = A_2 + x3;
2149# CHECK: }
2150# CHECK: A[0] = A_2;)IR";
2151
2152 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2153}
2154
2155// An element could be registerized program wide but is cut by a conditional
2156// access, we should break this into two scalars and write back to the buffer
2157// before the condition.
2158TEST(Registerizer, RegisterizerPartialCondition) {
2159 BufHandle a("A", {1}, kInt);
2160 VarHandle x("x", kInt);
2161 StmtPtr stmt = Block::make(
2162 {Store::make(a, {0}, 2),
2163 For::make(
2164 x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))),
2165 Cond::make(
2166 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2167 Store::make(a, {x}, Load::make(a, {x - 1})),
2168 nullptr),
2169 For::make(
2170 x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))});
2171
2172 /*
2173 * A[0] = 2;
2174 * for (int x = 0; x < 10; x++) {
2175 * A[0] = (A[0]) + x;
2176 * }
2177 * if (x<5 ? 1 : 0) {
2178 * A[x] = A[x - 1];
2179 * }
2180 * for (int x = 0; x < 10; x++) {
2181 * A[0] = (A[0]) + x;
2182 * }
2183 */
2184
2185 stmt = registerize(stmt);
2186
2187 /*
2188 * int A_1 = 2;
2189 * for (int x = 0; x < 10; x++) {
2190 * A_1 = A_1 + x;
2191 * }
2192 * A[0] = A_1;
2193 * if (x<5 ? 1 : 0) {
2194 * A[x] = A[x - 1];
2195 * }
2196 * int A_2 = A[0];
2197 * for (int x = 0; x < 10; x++) {
2198 * A_2 = A_2 + x;
2199 * }
2200 * A[0] = A_2;
2201 */
2202
2203 std::ostringstream oss;
2204 oss << *stmt;
2205
2206 const std::string& verification_pattern =
2207 R"IR(
2208# CHECK: int A_1 = 2;
2209# CHECK: for (
2210# CHECK: A_1 = A_1 + x;
2211# CHECK: }
2212# CHECK: A[0] = A_1;
2213# CHECK: if (
2214# CHECK: A[x] =
2215# CHECK: }
2216# CHECK: int A_2 = A[0];
2217# CHECK: for (
2218# CHECK: A_2 = A_2 + x;
2219# CHECK: }
2220# CHECK: A[0] = A_2;)IR";
2221
2222 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2223}
2224
2225// Tests case where an access is cut by an internal conditional access which
2226// itself is registerized.
2227TEST(Registerizer, RegisterizerPartialConditionInternalCut) {
2228 BufHandle a("A", {1}, kInt);
2229 VarHandle x("x", kInt);
2230 StmtPtr stmt = Block::make(
2231 {Store::make(a, {0}, 1),
2232 Store::make(a, {0}, 3),
2233 Cond::make(
2234 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2235 Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
2236 nullptr),
2237 Store::make(a, {0}, 4),
2238 Store::make(a, {0}, 6)});
2239
2240 /*
2241 * A[0] = 1;
2242 * A[0] = 3;
2243 * if (x<5 ? 1 : 0) {
2244 * A[x] = 1;
2245 * A[x] = 3;
2246 * }
2247 * A[0] = 4;
2248 * A[0] = 6;
2249 */
2250
2251 stmt = registerize(stmt);
2252
2253 /*
2254 * int A_1 = 1;
2255 * A_1 = 3;
2256 * A[0] = A_1;
2257 * if (x<5 ? 1 : 0) {
2258 * int A_2 = 1;
2259 * A_2 = 3;
2260 * A[x] = A_2;
2261 * }
2262 * int A_3 = 4;
2263 * A_3 = 6;
2264 * A[0] = A_3;
2265 */
2266
2267 std::ostringstream oss;
2268 oss << *stmt;
2269
2270 const std::string& verification_pattern =
2271 R"IR(
2272# CHECK: int A_1 = 1;
2273# CHECK: A_1 = 3
2274# CHECK: A[0] = A_1;
2275# CHECK: if (
2276# CHECK: int A_2 = 1;
2277# CHECK: A_2 = 3;
2278# CHECK: A[x] = A_2;
2279# CHECK: }
2280# CHECK: int A_3 = 4;
2281# CHECK: A_3 = 6;
2282# CHECK: A[0] = A_3;)IR";
2283
2284 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2285}
2286
2287// First statment in condition closes outer access, but can be registerized with
2288// later statements.
2289TEST(Registerizer, RegisterizerPartialConditionInternalStart) {
2290 BufHandle a("A", {1}, kInt);
2291 VarHandle x("x", kInt);
2292 StmtPtr stmt = Block::make(
2293 {Store::make(a, {0}, 1),
2294 Store::make(a, {0}, 3),
2295 Cond::make(
2296 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2297 Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
2298 nullptr),
2299 Store::make(a, {x}, 4),
2300 Store::make(a, {x}, 6)});
2301
2302 /*
2303 * A[0] = 1;
2304 * A[0] = 3;
2305 * if (x<5 ? 1 : 0) {
2306 * A[x] = 1;
2307 * A[x] = 3;
2308 * }
2309 * A[x] = 4;
2310 * A[x] = 6;
2311 */
2312
2313 stmt = registerize(stmt);
2314
2315 /*
2316 * int A_1 = 1;
2317 * A_1 = 3;
2318 * A[0] = A_1;
2319 * int A_2 = A[x]; <--- must read from the input here.
2320 * if (x<5 ? 1 : 0) {
2321 * A_2 = 1;
2322 * A_2 = 3;
2323 * }
2324 * A_2 = 4;
2325 * A_2 = 6;
2326 * A[x] = A_2;
2327 */
2328
2329 // TODO: I suppose we could refactor with a conditional initializier?
2330
2331 std::ostringstream oss;
2332 oss << *stmt;
2333
2334 const std::string& verification_pattern =
2335 R"IR(
2336# CHECK: int A_1 = 1;
2337# CHECK: A_1 = 3
2338# CHECK: A[0] = A_1;
2339# CHECK: int A_2 = A[x];
2340# CHECK: if (
2341# CHECK: A_2 = 1;
2342# CHECK: A_2 = 3;
2343# CHECK: }
2344# CHECK: A_2 = 4;
2345# CHECK: A_2 = 6;
2346# CHECK: A[x] = A_2;)IR";
2347
2348 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2349}
2350
2351// An access cuts two open overlaps and creates four scalar variables.
2352TEST(Registerizer, RegisterizerPartialOverlapsTwo) {
2353 BufHandle a("A", {1}, kInt);
2354 VarHandle x("x", kInt);
2355 StmtPtr stmt = Block::make(
2356 {Store::make(a, {1}, Load::make(a, {0})),
2357 Store::make(a, {0}, Load::make(a, {1})),
2358 Store::make(a, {0}, Load::make(a, {1})),
2359 For::make(x, 1, 10, Store::make(a, {x}, x)),
2360 Store::make(a, {1}, Load::make(a, {0})),
2361 Store::make(a, {0}, Load::make(a, {1})),
2362 Store::make(a, {0}, Load::make(a, {1}))});
2363
2364 /*
2365 * A[1] = A[0];
2366 * A[0] = A[1];
2367 * A[0] = A[1];
2368 * for (int x = 1; x < 10; x++) {
2369 * A[x] = x;
2370 * }
2371 * A[1] = A[0];
2372 * A[0] = A[1];
2373 * A[0] = A[1];
2374 */
2375
2376 stmt = registerize(stmt);
2377
2378 /*
2379 * int A_1 = A[0];
2380 * int A_2 = A_1;
2381 * A_1 = A_2;
2382 * A_1 = A_2;
2383 * A[1] = A_2;
2384 * A[0] = A_1;
2385 * for (int x = 1; x < 10; x++) {
2386 * A[x] = x;
2387 * }
2388 * int A_3 = A[0];
2389 * int A_4 = A_3;
2390 * A_3 = A_4;
2391 * A_3 = A_4;
2392 * A[1] = A_4;
2393 * A[0] = A_3;
2394 */
2395
2396 std::ostringstream oss;
2397 oss << *stmt;
2398
2399 const std::string& verification_pattern =
2400 R"IR(
2401# CHECK: int A_1 = A[0];
2402# CHECK: int A_2 = A_1;
2403# CHECK: A_1 = A_2;
2404# CHECK: A_1 = A_2;
2405# CHECK: A[1] = A_2;
2406# CHECK: A[0] = A_1;
2407# CHECK: for (
2408# CHECK: A[x] = x;
2409# CHECK: }
2410# CHECK: int A_3 = A[0];
2411# CHECK: int A_4 = A_3;
2412# CHECK: A_3 = A_4;
2413# CHECK: A_3 = A_4;
2414# CHECK: A[1] = A_4;
2415# CHECK: A[0] = A_3;)IR";
2416
2417 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2418}
2419
// Nested blocks will automatically be flattened and do not prevent
// registerization of enclosed accesses: all four read-modify-writes of A[0]
// below end up sharing a single scalar.
TEST(Registerizer, RegisterizerNestedBlocks) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
      {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
       Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}),
       Block::make(
           {Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)),
            Block::make(
                {Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})});

  /*
   * A[0] = (A[0]) + 1;
   * {
   *   A[0] = (A[0]) + 2;
   * }
   * {
   *   A[0] = (A[0]) + 3;
   *   {
   *     A[0] = (A[0]) + 4;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * Expected result (note the inner blocks no longer appear):
   *
   * int A_1 = A[0];
   * A_1 = A_1 + 1;
   * A_1 = A_1 + 2;
   * A_1 = A_1 + 3;
   * A_1 = A_1 + 4;
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: A_1 = A_1 + 1;
# CHECK: A_1 = A_1 + 2;
# CHECK: A_1 = A_1 + 3;
# CHECK: A_1 = A_1 + 4;
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
2472
2473// The access can be registerized internally to a condition, but must ensure
2474// that both initializer and finalizer are within the same condition.
2475TEST(Registerizer, RegisterizerNestedConditions) {
2476 BufHandle a("A", {1}, kInt);
2477 VarHandle x("x", kInt);
2478 StmtPtr stmt = Block::make({Cond::make(
2479 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2480 Block::make(
2481 {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2482 Cond::make(
2483 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2484 Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2485 nullptr)}),
2486 nullptr)});
2487
2488 /*
2489 * if (x<5 ? 1 : 0) {
2490 * A[0] = (A[0]) + 1;
2491 * if (x==2 ? 1 : 0) {
2492 *
2493 * A[0] = (A[0]) + 1;
2494 * }
2495 * }
2496 */
2497
2498 stmt = registerize(stmt);
2499
2500 /*
2501 * if (x<5 ? 1 : 0) {
2502 * int A_1 = A[0];
2503 * A_1 = A_1 + 1;
2504 * if (x==2 ? 1 : 0) {
2505 * A_1 = A_1 + 1;
2506 * }
2507 * A[0] = A_1;
2508 * }
2509 */
2510
2511 std::ostringstream oss;
2512 oss << *stmt;
2513
2514 const std::string& verification_pattern =
2515 R"IR(
2516# CHECK: if (x<5
2517# CHECK: int A_1 = A[0];
2518# CHECK: A_1 = A_1 + 1;
2519# CHECK: if (x==2
2520# CHECK: A_1 = A_1 + 1;
2521# CHECK: }
2522# CHECK: A[0] = A_1;
2523# CHECK: })IR";
2524
2525 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2526}
2527
2528// If an access exists outside the scope of the condition then we can lift
2529// nested conditional usages into the same scalar.
2530TEST(Registerizer, RegisterizerNestedConditionsUnhidden) {
2531 BufHandle a("A", {1}, kInt);
2532 VarHandle x("x", kInt);
2533 StmtPtr stmt = Block::make(
2534 {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2535 Cond::make(
2536 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2537 Block::make(
2538 {Store::make(a, {1}, 1),
2539 Cond::make(
2540 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2541 Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2542 nullptr)}),
2543 nullptr)});
2544
2545 /*
2546 * A[0] = (A[0]) + 1;
2547 * if (x<5 ? 1 : 0) {
2548 * A[1] = 1;
2549 * if (x==2 ? 1 : 0) {
2550 * A[0] = (A[0]) + 1;
2551 * }
2552 * }
2553 */
2554
2555 stmt = registerize(stmt);
2556
2557 /*
2558 * int A_1 = A[0];
2559 * A_1 = A_1 + 1;
2560 * if (x<5 ? 1 : 0) {
2561 * A[1] = 1;
2562 * if (x==2 ? 1 : 0) {
2563 * A_1 = A_1 + 1;
2564 * }
2565 * }
2566 * A[0] = A_1;
2567 */
2568
2569 std::ostringstream oss;
2570 oss << *stmt;
2571
2572 const std::string& verification_pattern =
2573 R"IR(
2574# CHECK: int A_1 = A[0];
2575# CHECK: A_1 = A_1 + 1;
2576# CHECK: if (x<5
2577# CHECK: A[1] = 1;
2578# CHECK: if (x==2
2579# CHECK: A_1 = A_1 + 1;
2580# CHECK: A[0] = A_1;)IR";
2581
2582 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2583}
2584
2585TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) {
2586 BufHandle a("A", {1}, kInt);
2587 VarHandle x("x", kInt);
2588 StmtPtr stmt = Block::make(
2589 {Cond::make(
2590 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2591 Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2592 nullptr),
2593 Cond::make(
2594 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2595 Block::make({Cond::make(
2596 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2597 Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2598 nullptr)}),
2599 nullptr)});
2600
2601 /*
2602 * if (x==2 ? 1 : 0) {
2603 * A[0] = (A[0]) + 1;
2604 * }
2605 * if (x<5 ? 1 : 0) {
2606 * if (x==2 ? 1 : 0) {
2607 * A[0] = (A[0]) + 1;
2608 * }
2609 * }
2610 */
2611
2612 std::ostringstream before;
2613 before << *stmt;
2614
2615 // No change.
2616 stmt = registerize(stmt);
2617
2618 std::ostringstream after;
2619 after << *stmt;
2620
2621 ASSERT_EQ(before.str(), after.str());
2622
2623 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
2624 stmt = registerize(stmt);
2625}
2626
2627TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) {
2628 BufHandle a("A", {1}, kInt);
2629 VarHandle x("x", kInt);
2630 StmtPtr stmt = Block::make(
2631 {Cond::make(
2632 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2633 Block::make({Cond::make(
2634 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2635 Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2636 nullptr)}),
2637 nullptr),
2638 Cond::make(
2639 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2640 Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2641 nullptr)});
2642
2643 /*
2644 * if (x<5 ? 1 : 0) {
2645 * if (x==2 ? 1 : 0) {
2646 * A[0] = (A[0]) + 1;
2647 * }
2648 * }
2649 * if (x==2 ? 1 : 0) {
2650 * A[0] = (A[0]) + 1;
2651 * }
2652 */
2653
2654 std::ostringstream before;
2655 before << *stmt;
2656
2657 // No change.
2658 stmt = registerize(stmt);
2659
2660 std::ostringstream after;
2661 after << *stmt;
2662
2663 ASSERT_EQ(before.str(), after.str());
2664
2665 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
2666 stmt = registerize(stmt);
2667}
2668
2669// If an access is cut by another access internal to a condition block, it still
2670// cuts the access.
2671TEST(Registerizer, RegisterizerNestedConditionsCut) {
2672 BufHandle a("A", {1}, kInt);
2673 VarHandle x("x", kInt);
2674 StmtPtr stmt = Block::make(
2675 {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2676 Cond::make(
2677 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2678 Block::make(
2679 {Store::make(a, {x}, 1),
2680 Cond::make(
2681 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2682 Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2683 nullptr)}),
2684 nullptr)});
2685
2686 /*
2687 * A[0] = (A[0]) + 1;
2688 * if (x<5 ? 1 : 0) {
2689 * A[x] = 1;
2690 * if (x==2 ? 1 : 0) {
2691 *
2692 * A[0] = (A[0]) + 1;
2693 * }
2694 * }
2695 */
2696
2697 std::ostringstream before;
2698 before << *stmt;
2699
2700 // No change.
2701 stmt = registerize(stmt);
2702
2703 std::ostringstream after;
2704 after << *stmt;
2705
2706 ASSERT_EQ(before.str(), after.str());
2707}
2708
2709TEST(Registerizer, RegisterizerNestedConditionLoopHidden) {
2710 BufHandle a("A", {10}, kInt);
2711 BufHandle b("B", {10}, kInt);
2712 VarHandle x("x", kInt);
2713 StmtPtr stmt = Block::make(
2714 {Cond::make(
2715 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2716 Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2717 nullptr),
2718 For::make(
2719 x,
2720 0,
2721 10,
2722 Block::make(
2723 {Store::make(b, {x}, 0),
2724 Cond::make(
2725 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2726 Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2727 nullptr)}))});
2728
2729 /*
2730 * if (x==2 ? 1 : 0) {
2731 * A[0] = (A[0]) + 1;
2732 * }
2733 * for (int x = 0; x < 10; x++) {
2734 * B[x] = 0; <-- this is only here to prevent Loop/Cond reordering.
2735 * if (x==2 ? 1 : 0) {
2736 * A[0] = (A[0]) + 1;
2737 * }
2738 * }
2739 */
2740
2741 std::ostringstream before;
2742 before << *stmt;
2743
2744 // No change.
2745 stmt = registerize(stmt);
2746
2747 std::ostringstream after;
2748 after << *stmt;
2749
2750 ASSERT_EQ(before.str(), after.str());
2751}
2752
2753// Three loops and four element regions, three of which should be registerized
2754// at different levels of the IR.
2755TEST(Registerizer, RegisterizerNestedConditionThreeDeep) {
2756 BufHandle a("A", {10}, kInt);
2757 BufHandle b("B", {10}, kInt);
2758 VarHandle x("x", kInt);
2759 StmtPtr stmt = Block::make(
2760 {Store::make(a, {4}, 0),
2761 Cond::make(
2762 CompareSelect::make(x, 2, CompareSelectOperation::kGT),
2763 Cond::make(
2764 CompareSelect::make(x, 3, CompareSelectOperation::kGT),
2765 Block::make({
2766 Cond::make(
2767 CompareSelect::make(x, 4, CompareSelectOperation::kGT),
2768 Block::make({
2769 Store::make(
2770 a, {1}, Add::make(Load::make(a, {1}), 1)),
2771 Store::make(
2772 a, {2}, Add::make(Load::make(a, {2}), 1)),
2773 Store::make(
2774 a, {3}, Add::make(Load::make(a, {3}), 1)),
2775 Store::make(
2776 a, {4}, Add::make(Load::make(a, {4}), 1)),
2777 Store::make(
2778 a, {1}, Add::make(Load::make(a, {1}), 1)),
2779 }),
2780 nullptr),
2781 Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)),
2782 }),
2783 nullptr),
2784 nullptr)});
2785
2786 /*
2787 * A[4] = 0;
2788 * if (x>2 ? 1 : 0) {
2789 * if (x>3 ? 1 : 0) {
2790 * if (x>4 ? 1 : 0) {
2791 * A[1] = (A[1]) + 1;
2792 * A[2] = (A[2]) + 1;
2793 * A[3] = (A[3]) + 1;
2794 * A[4] = (A[4]) + 1;
2795 * A[1] = (A[1]) + 1;
2796 * }
2797 * A[2] = (A[2]) + 1;
2798 * }
2799 * }
2800 */
2801
2802 stmt = registerize(stmt);
2803
2804 /*
2805 * int A_1 = 0;
2806 * if (x>2 ? 1 : 0) {
2807 * if (x>3 ? 1 : 0) {
2808 * int A_3 = A[2];
2809 * if (x>4 ? 1 : 0) {
2810 * int A_2 = A[1];
2811 * A_2 = A_2 + 1;
2812 * A_3 = A_3 + 1;
2813 * A[3] = (A[3]) + 1;
2814 * A_1 = A_1 + 1;
2815 * A_2 = A_2 + 1;
2816 * A[1] = A_2;
2817 * }
2818 * A_3 = A_3 + 1;
2819 * A[2] = A_3;
2820 * }
2821 * }
2822 * A[4] = A_1;
2823 */
2824
2825 std::ostringstream oss;
2826 oss << *stmt;
2827
2828 const std::string& verification_pattern =
2829 R"IR(
2830# CHECK: int A_1 = 0;
2831# CHECK: if (x>2 ? 1 : 0) {
2832# CHECK: if (x>3 ? 1 : 0) {
2833# CHECK: int A_3 = A[2];
2834# CHECK: if (x>4 ? 1 : 0) {
2835# CHECK: int A_2 = A[1];
2836# CHECK: A_2 = A_2 + 1;
2837# CHECK: A_3 = A_3 + 1;
2838# CHECK: A[3] = (A[3]) + 1;
2839# CHECK: A_1 = A_1 + 1;
2840# CHECK: A_2 = A_2 + 1;
2841# CHECK: A[1] = A_2;
2842# CHECK: }
2843# CHECK: A_3 = A_3 + 1;
2844# CHECK: A[2] = A_3;
2845# CHECK: }
2846# CHECK: }
2847# CHECK: A[4] = A_1;)IR";
2848
2849 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2850}
2851
2852// Can replace a simple scalar access with a local variable even when that
2853// variable is an outer loop var.
2854TEST(Registerizer, RegisterizerNestedLoopSimple) {
2855 BufHandle a("A", {1}, kInt);
2856 VarHandle x("x", kInt);
2857 VarHandle y("y", kInt);
2858 StmtPtr stmt = Block::make({For::make(
2859 y,
2860 0,
2861 10,
2862 For::make(
2863 x,
2864 0,
2865 10,
2866 Block::make(
2867 {Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))});
2868
2869 /*
2870 * for (int y = 0; y < 10; y++) {
2871 * for (int x = 0; x < 10; x++) {
2872 * A[y] = (A[y]) + x;
2873 * }
2874 * }
2875 */
2876
2877 stmt = registerize(stmt);
2878
2879 /*
2880 * for (int y = 0; y < 10; y++) {
2881 * int A_1 = A[y];
2882 * for (int x = 0; x < 10; x++) {
2883 * A_1 = A_1 + x;
2884 * }
2885 * A[y] = A_1;
2886 * }
2887 */
2888
2889 std::ostringstream oss;
2890 oss << *stmt;
2891
2892 const std::string& verification_pattern =
2893 R"IR(
2894# CHECK: for (int y
2895# CHECK: int A_1 = A[y];
2896# CHECK: for (int x
2897# CHECK: A_1 = A_1 + x;
2898# CHECK: }
2899# CHECK: A[y] = A_1;
2900# CHECK: })IR";
2901
2902 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2903}
2904
2905// Test the positive case of the hiddenAccess split, where an internal
2906// conditional access can be hoisted up through a loop to match an existing
2907// access in a higher scope and the two can be registerized.
2908TEST(Registerizer, RegisterizerHiddenAccessYes) {
2909 BufHandle a("A", {10}, kInt);
2910 BufHandle b("B", {10}, kInt);
2911 VarHandle x("x", kInt);
2912 VarHandle y("y", kInt);
2913 StmtPtr stmt = Block::make({Cond::make(
2914 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2915 Block::make(
2916 {Store::make(a, {0}, 0),
2917 For::make(
2918 x,
2919 0,
2920 10,
2921 Block::make(
2922 {Store::make(b, {x}, 0),
2923 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
2924 Cond::make(
2925 CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
2926 For::make(
2927 y,
2928 0,
2929 10,
2930 Store::make(
2931 a, {0}, Add::make(Load::make(a, {0}), 1))),
2932 nullptr)}))}),
2933 nullptr)});
2934
2935 /*
2936 * if (x==2 ? 1 : 0) {
2937 * A[0] = 0;
2938 * for (int x = 0; x < 10; x++) {
2939 * B[x] = 0;
2940 * if (x==3 ? 1 : 0) {
2941 * for (int y = 0; y < 10; y++) {
2942 * A[0] = (A[0]) + 1;
2943 * }
2944 * }
2945 * }
2946 * }
2947 */
2948
2949 stmt = registerize(stmt);
2950
2951 /*
2952 * if (x==2 ? 1 : 0) {
2953 * int A_1 = 0;
2954 * for (int x = 0; x < 10; x++) {
2955 * B[x] = 0;
2956 * if (x==3 ? 1 : 0) {
2957 * for (int y = 0; y < 10; y++) {
2958 * A_1 = A_1 + 1;
2959 * }
2960 * }
2961 * }
2962 * A[0] = A_1;
2963 * }
2964 */
2965
2966 std::ostringstream oss;
2967 oss << *stmt;
2968
2969 const std::string& verification_pattern =
2970 R"IR(
2971# CHECK: if (x==2
2972# CHECK: int A_1 = 0;
2973# CHECK: for (int x
2974# CHECK: B[x] = 0;
2975# CHECK: if (x==3
2976# CHECK: for (int y
2977# CHECK: A_1 = A_1 + 1;
2978# CHECK: }
2979# CHECK: }
2980# CHECK: }
2981# CHECK: A[0] = A_1;
2982# CHECK: })IR";
2983
2984 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2985}
2986
2987// Test the negative case of the hiddenAccess split, where the hoisted access is
2988// never unhidden at a higher scope and registerization occurs at the lower
2989// scope.
2990TEST(Registerizer, RegisterizerHiddenAccessNo) {
2991 BufHandle a("A", {10}, kInt);
2992 BufHandle b("B", {10}, kInt);
2993 VarHandle x("x", kInt);
2994 VarHandle y("y", kInt);
2995 StmtPtr stmt = Block::make({Cond::make(
2996 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2997 Block::make({For::make(
2998 x,
2999 0,
3000 10,
3001 Block::make(
3002 {Store::make(b, {x}, 0),
3003 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
3004 Cond::make(
3005 CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
3006 For::make(
3007 y,
3008 0,
3009 10,
3010 Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3011 nullptr)}))}),
3012 nullptr)});
3013
3014 /*
3015 * if (x==2 ? 1 : 0) {
3016 * A[0] = 0;
3017 * for (int x = 0; x < 10; x++) {
3018 * B[x] = 0;
3019 * if (x==3 ? 1 : 0) {
3020 * for (int y = 0; y < 10; y++) {
3021 * A[0] = (A[0]) + 1;
3022 * }
3023 * }
3024 * }
3025 * }
3026 */
3027
3028 stmt = registerize(stmt);
3029
3030 /*
3031 * if (x==2 ? 1 : 0) {
3032 * for (int x = 0; x < 10; x++) {
3033 * B[x] = 0;
3034 * if (x==3 ? 1 : 0) {
3035 * int A_1 = A[0];
3036 * for (int y = 0; y < 10; y++) {
3037 * A_1 = A_1 + 1;
3038 * }
3039 * A[0] = A_1;
3040 * }
3041 * }
3042 * }
3043 */
3044
3045 std::ostringstream oss;
3046 oss << *stmt;
3047
3048 const std::string& verification_pattern =
3049 R"IR(
3050# CHECK: if (x==2
3051# CHECK: for (int x
3052# CHECK: B[x] = 0;
3053# CHECK: if (x==3
3054# CHECK: int A_1 = A[0];
3055# CHECK: for (int y
3056# CHECK: A_1 = A_1 + 1;
3057# CHECK: }
3058# CHECK: A[0] = A_1;
3059# CHECK: }
3060# CHECK: }
3061# CHECK: })IR";
3062
3063 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3064}
3065
3066// In this case the conditional access must be hoisted by two loops, there are
3067// two accesses here one is unhidden and the other isnt. A[0] can be
3068// registerized but B[0] cannot.
3069TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) {
3070 BufHandle a("A", {10}, kInt);
3071 BufHandle b("B", {10}, kInt);
3072 VarHandle x("x", kInt);
3073 VarHandle y("y", kInt);
3074 StmtPtr stmt = Block::make({Cond::make(
3075 CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
3076 Block::make(
3077 {Store::make(a, {0}, 0),
3078 For::make(
3079 x,
3080 0,
3081 10,
3082 For::make(
3083 y,
3084 0,
3085 10,
3086 Block::make({Cond::make(
3087 CompareSelect::make(y, 3, CompareSelectOperation::kEQ),
3088 Block::make(
3089 {Store::make(
3090 a, {0}, Add::make(Load::make(a, {0}), 1)),
3091 Store::make(
3092 b, {0}, Add::make(Load::make(b, {0}), 1))}),
3093 nullptr)})))}),
3094 nullptr)});
3095
3096 /*
3097 * if (x==2 ? 1 : 0) {
3098 * A[0] = 0;
3099 * for (int x = 0; x < 10; x++) {
3100 * for (int y = 0; y < 10; y++) {
3101 * if (y==3 ? 1 : 0) {
3102 * A[0] = (A[0]) + 1;
3103 * B[0] = (B[0]) + 1;
3104 * }
3105 * }
3106 * }
3107 * }
3108 */
3109
3110 stmt = registerize(stmt);
3111
3112 /*
3113 * if (x==2 ? 1 : 0) {
3114 * int A_1 = 0;
3115 * for (int x = 0; x < 10; x++) {
3116 * for (int y = 0; y < 10; y++) {
3117 * if (y==3 ? 1 : 0) {
3118 * A_1 = A_1 + 1;
3119 * B[0] = (B[0]) + 1;
3120 * }
3121 * }
3122 * }
3123 * A[0] = A_1;
3124 * }
3125 */
3126
3127 std::ostringstream oss;
3128 oss << *stmt;
3129
3130 const std::string& verification_pattern =
3131 R"IR(
3132# CHECK: if (x==2
3133# CHECK: int A_1 = 0;
3134# CHECK: for (int x
3135# CHECK: for (int y
3136# CHECK: if (y==3
3137# CHECK: A_1 = A_1 + 1;
3138# CHECK: B[0] = (B[0]) + 1;
3139# CHECK: }
3140# CHECK: }
3141# CHECK: }
3142# CHECK: A[0] = A_1;
3143# CHECK: })IR";
3144
3145 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3146}
3147
3148// Accesses are registerized inside two conditions, but the immeidate parent is
3149// not a condition.
3150TEST(Registerizer, RegisterizerTwoConditionalLoops) {
3151 BufHandle a("A", {1}, kInt);
3152 VarHandle x("x", kInt);
3153 StmtPtr stmt = Block::make(
3154 {Cond::make(
3155 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
3156 For::make(
3157 x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3158 nullptr),
3159 Cond::make(
3160 CompareSelect::make(x, 5, CompareSelectOperation::kGT),
3161 For::make(
3162 x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3163 nullptr)});
3164
3165 /*
3166 * if (x<5 ? 1 : 0) {
3167 * for (int x = 0; x < 10; x++) {
3168 * A[0] = (A[0]) + 1;
3169 * }
3170 * }
3171 * if (x>5 ? 1 : 0) {
3172 * for (int x = 0; x < 10; x++) {
3173 * A[0] = (A[0]) + 1;
3174 * }
3175 * }
3176 */
3177
3178 stmt = registerize(stmt);
3179
3180 /*
3181 * if (x<5 ? 1 : 0) {
3182 * int A_1 = A[0];
3183 * for (int x = 0; x < 10; x++) {
3184 * A_1 = A_1 + 1;
3185 * }
3186 * A[0] = A_1;
3187 * }
3188 * if (x>5 ? 1 : 0) {
3189 * int A_2 = A[0];
3190 * for (int x = 0; x < 10; x++) {
3191 * A_2 = A_2 + 1;
3192 * }
3193 * A[0] = A_2;
3194 * }
3195 */
3196
3197 std::ostringstream oss;
3198 oss << *stmt;
3199
3200 const std::string& verification_pattern =
3201 R"IR(
3202# CHECK: if (x<5
3203# CHECK: int A_1 = A[0];
3204# CHECK: for (int x
3205# CHECK: A_1 = A_1 + 1;
3206# CHECK: }
3207# CHECK: A[0] = A_1;
3208# CHECK: }
3209# CHECK: if (x>5
3210# CHECK: int A_2 = A[0];
3211# CHECK: for (int x
3212# CHECK: A_2 = A_2 + 1;
3213# CHECK: }
3214# CHECK: A[0] = A_2;
3215# CHECK: })IR";
3216
3217 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3218}
3219
3220// Accesses are registerized inside two conditions, cut in the middle.
3221TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) {
3222 BufHandle a("A", {1}, kInt);
3223 VarHandle x("x", kInt);
3224 StmtPtr stmt = Block::make(
3225 {Cond::make(
3226 CompareSelect::make(x, 5, CompareSelectOperation::kLT),
3227 For::make(
3228 x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3229 nullptr),
3230 For::make(x, 0, 10, Store::make(a, {x}, 1)),
3231 Cond::make(
3232 CompareSelect::make(x, 5, CompareSelectOperation::kGT),
3233 For::make(
3234 x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3235 nullptr)});
3236
3237 /*
3238 * if (x<5 ? 1 : 0) {
3239 * for (int x = 0; x < 10; x++) {
3240 * A[0] = (A[0]) + 1;
3241 * }
3242 * }
3243 * for (int x = 0; x < 10; x++) {
3244 * A[x] = 1;
3245 * }
3246 * if (x>5 ? 1 : 0) {
3247 * for (int x = 0; x < 10; x++) {
3248 * A[0] = (A[0]) + 1;
3249 * }
3250 * }
3251 */
3252
3253 stmt = registerize(stmt);
3254
3255 /*
3256 * if (x<5 ? 1 : 0) {
3257 * int A_1 = A[0];
3258 * for (int x = 0; x < 10; x++) {
3259 * A_1 = A_1 + 1;
3260 * }
3261 * A[0] = A_1;
3262 * }
3263 * for (int x = 0; x < 10; x++) {
3264 * A[x] = 1;
3265 * }
3266 * if (x>5 ? 1 : 0) {
3267 * int A_2 = A[0];
3268 * for (int x = 0; x < 10; x++) {
3269 * A_2 = A_2 + 1;
3270 * }
3271 * A[0] = A_2;
3272 * }
3273 */
3274
3275 std::ostringstream oss;
3276 oss << *stmt;
3277
3278 const std::string& verification_pattern =
3279 R"IR(
3280# CHECK: if (x<5
3281# CHECK: int A_1 = A[0];
3282# CHECK: for (int x
3283# CHECK: A_1 = A_1 + 1;
3284# CHECK: }
3285# CHECK: A[0] = A_1;
3286# CHECK: }
3287# CHECK: for (int x
3288# CHECK: A[x] = 1;
3289# CHECK: if (x>5
3290# CHECK: int A_2 = A[0];
3291# CHECK: for (int x
3292# CHECK: A_2 = A_2 + 1;
3293# CHECK: }
3294# CHECK: A[0] = A_2;
3295# CHECK: })IR";
3296
3297 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3298}
3299
3300// references a Let var in a local scope which cannot be hoisted out of the
3301// loop.
3302TEST(Registerizer, RegisterizerLoopLetVar) {
3303 BufHandle a("A", {10}, kInt);
3304 VarHandle x("x", kInt);
3305 VarHandle y("y", kInt);
3306 StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make(
3307 x,
3308 0,
3309 10,
3310 Block::make(
3311 {Let::make(y, 30),
3312 Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}));
3313
3314 /*
3315 * for (int x = 0; x < 10; x++) {
3316 * int y = 30;
3317 * A[y] = x + (A[y]);
3318 * }
3319 */
3320
3321 std::ostringstream before;
3322 before << *stmt;
3323
3324 // No change.
3325 stmt = registerize(stmt);
3326
3327 std::ostringstream after;
3328 after << *stmt;
3329
3330 ASSERT_EQ(before.str(), after.str());
3331}
3332
3333// references a Let var in an outer scope that does not prevent hoisting the
3334// initializer.
3335TEST(Registerizer, RegisterizerLoopLetVarOuter) {
3336 BufHandle a("A", {10}, kInt);
3337 VarHandle x("x", kInt);
3338 VarHandle y("y", kInt);
3339 StmtPtr stmt = Block::make(
3340 {Let::make(y, 30),
3341 For::make(
3342 x,
3343 0,
3344 10,
3345 Block::make(
3346 {Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))});
3347
3348 /*
3349 * int y = 30;
3350 * for (int x = 0; x < 10; x++) {
3351 * A[y] = x + (A[y]);
3352 * }
3353 */
3354
3355 stmt = registerize(stmt);
3356
3357 /*
3358 * int y = 30;
3359 * int A_1 = A[y];
3360 * for (int x = 0; x < 10; x++) {
3361 * A_1 = A_1 + x;
3362 * }
3363 * A[y] = A_1;
3364 */
3365
3366 std::ostringstream oss;
3367 oss << *stmt;
3368
3369 const std::string& verification_pattern =
3370 R"IR(
3371# CHECK: int y = 30;
3372# CHECK: int A_1 = A[y];
3373# CHECK: for (int x
3374# CHECK: A_1 = A_1 + x;
3375# CHECK: A[y] = A_1;)IR";
3376
3377 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3378}
3379
3380// Okay so the registerizer generally goes after index flattening, but just in
3381// case. Test multi index registerization.
3382TEST(Registerizer, RegisterizerMultiDim) {
3383 BufHandle a("A", {3, 4, 5}, kInt);
3384 VarHandle x("x", kInt);
3385 StmtPtr stmt = Block::make(
3386 {Store::make(a, {0, 1, 2}, 0),
3387 For::make(
3388 x,
3389 0,
3390 10,
3391 Block::make({Store::make(
3392 a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))});
3393
3394 /*
3395 * A[0, 1, 2] = 0;
3396 * for (int x = 0; x < 10; x++) {
3397 * A[0, 1, 2] = (A[0, 1, 2]) + x;
3398 * }
3399 */
3400
3401 stmt = registerize(stmt);
3402
3403 /*
3404 * int A_1 = 0;
3405 * for (int x = 0; x < 10; x++) {
3406 * A_1 = x + A_1;
3407 * }
3408 * A[0, 1, 2] = A_1;
3409 */
3410
3411 std::ostringstream oss;
3412 oss << *stmt;
3413
3414 const std::string& verification_pattern =
3415 R"IR(
3416# CHECK: int A_1 = 0;
3417# CHECK: for (int x = 0; x < 10; x++)
3418# CHECK-NOT: A[
3419# CHECK: A_1 =
3420# CHECK: A[0, 1, 2] = A_1;)IR";
3421
3422 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3423}
3424
3425// Wont registerize if only some dims match, but will still registerize distinct
3426// elements.
3427TEST(Registerizer, RegisterizerMultiDimPartial) {
3428 BufHandle a("A", {3, 4, 5}, kInt);
3429 VarHandle x("x", kInt);
3430 StmtPtr stmt = Block::make(
3431 {Store::make(a, {0, 1, 2}, 0),
3432 For::make(
3433 x,
3434 0,
3435 10,
3436 Block::make({Store::make(
3437 a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))});
3438
3439 /*
3440 * A[0, 1, 2] = 0;
3441 * for (int x = 0; x < 10; x++) {
3442 * A[0, 2, 2] = (A[0, 1, 4]) + x;
3443 * }
3444 */
3445
3446 stmt = registerize(stmt);
3447
3448 /*
3449 * A[0, 1, 2] = 0;
3450 * int A_1 = A[0, 1, 4];
3451 * int A_2 = A[0, 2, 2];
3452 * for (int x = 0; x < 10; x++) {
3453 * A_2 = A_1 + x;
3454 * }
3455 * A[0, 2, 2] = A_2;
3456 */
3457
3458 std::ostringstream oss;
3459 oss << *stmt;
3460
3461 const std::string& verification_pattern =
3462 R"IR(
3463# CHECK: A[0, 1, 2] = 0;
3464# CHECK: int A_1 = A[0, 1, 4];
3465# CHECK: int A_2 = A[0, 2, 2];
3466# CHECK: for (
3467# CHECK: A_2 = A_1 + x;
3468# CHECK: A[0, 2, 2] = A_2;)IR";
3469
3470 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3471}
3472
3473// If they could overlap across all dimensions we cannot registerize.
3474TEST(Registerizer, RegisterizerMultiDimOverlap) {
3475 BufHandle a("A", {3, 4, 5}, kInt);
3476 VarHandle x("x", kInt);
3477 VarHandle y("y", kInt);
3478 StmtPtr stmt = Block::make(
3479 {Store::make(a, {0, 1, 2}, 0),
3480 For::make(
3481 x,
3482 0,
3483 10,
3484 Block::make({Store::make(
3485 a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))});
3486 stmt = IRSimplifier::simplify(stmt);
3487
3488 /*
3489 * A[0, 1, 2] = 0;
3490 * for (int x = 0; x < 10; x++) {
3491 * A[0, x, 2] = (A[y, 2, 2]) + x;
3492 * }
3493 */
3494
3495 std::ostringstream before;
3496 before << *stmt;
3497
3498 // No change.
3499 stmt = registerize(stmt);
3500
3501 std::ostringstream after;
3502 after << *stmt;
3503
3504 ASSERT_EQ(before.str(), after.str());
3505}
3506
3507// But, if one dimension is known to be distinct they do not overlap.
3508TEST(Registerizer, RegisterizerMultiDimPartialOverlap) {
3509 BufHandle a("A", {3, 4, 5}, kInt);
3510 VarHandle x("x", kInt);
3511 VarHandle y("y", kInt);
3512 StmtPtr stmt = Block::make(
3513 {Store::make(a, {0, 1, 2}, 0),
3514 For::make(
3515 x,
3516 0,
3517 10,
3518 Block::make({Store::make(
3519 a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))});
3520
3521 /*
3522 * A[0, 1, 2] = 0; <---- 2nd dim overlaps with store.
3523 * for (int x = 0; x < 10; x++) {
3524 * A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff.
3525 * }
3526 */
3527
3528 stmt = registerize(stmt);
3529
3530 /*
3531 * A[0, 1, 2] = 0;
3532 * int A_1 = A[y, 2, 4];
3533 * for (int x = 0; x < 10; x++) {
3534 * A[0, x, 2] = A_1 + x;
3535 * }
3536 */
3537
3538 std::ostringstream oss;
3539 oss << *stmt;
3540
3541 const std::string& verification_pattern =
3542 R"IR(
3543# CHECK: A[0, 1, 2] = 0;
3544# CHECK: int A_1 = A[y, 2, 4];
3545# CHECK: for (
3546# CHECK: A[0, x, 2] = A_1 + x;
3547# CHECK: })IR";
3548
3549 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3550}
3551
3552// A 3D reduction with different input dimensionality.
3553TEST(Registerizer, RegisterizerMultiDim3DReduction1) {
3554 BufHandle a("A", {10}, kInt);
3555 BufHandle b("B", {10, 10}, kInt);
3556 BufHandle c("C", {10, 10, 10}, kInt);
3557 VarHandle x("x", kInt);
3558 VarHandle y("y", kInt);
3559 VarHandle z("z", kInt);
3560 StmtPtr stmt = For::make(
3561 x,
3562 0,
3563 10,
3564 For::make(
3565 y,
3566 0,
3567 10,
3568 For::make(
3569 z,
3570 0,
3571 10,
3572 Store::make(
3573 c,
3574 {x, y, z},
3575 Add::make(
3576 Load::make(c, {x, y, z}),
3577 Mul::make(Load::make(b, {x, y}), Load::make(a, {x})))))));
3578
3579 /*
3580 * for (int x = 0; x < 10; x++) {
3581 * for (int y = 0; y < 10; y++) {
3582 * for (int z = 0; z < 10; z++) {
3583 * C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]);
3584 * }
3585 * }
3586 * }
3587 */
3588
3589 // We can registerize the A and B access since they can be hoisted before
3590 // hitting a dependent loop var.
3591
3592 stmt = registerize(stmt);
3593
3594 /*
3595 * for (int x = 0; x < 10; x++) {
3596 * int A_1 = A[x];
3597 * for (int y = 0; y < 10; y++) {
3598 * int B_1 = B[x, y];
3599 * for (int z = 0; z < 10; z++) {
3600 * C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
3601 * }
3602 * }
3603 * }
3604 */
3605
3606 std::ostringstream oss;
3607 oss << *stmt;
3608
3609 const std::string& verification_pattern =
3610 R"IR(
3611# CHECK: for (int x
3612# CHECK: int A_1 = A[x];
3613# CHECK: for (int y
3614# CHECK: int B_1 = B[x, y];
3615# CHECK: for (int z
3616# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
3617# CHECK: })IR";
3618
3619 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3620}
3621
3622// A 3D reduction with the same smaller dimensionality using different loop
3623// vars.
3624TEST(Registerizer, RegisterizerMultiDim3DReduction2) {
3625 BufHandle a("A", {10}, kInt);
3626 BufHandle b("B", {10}, kInt);
3627 BufHandle c("C", {10}, kInt);
3628 VarHandle x("x", kInt);
3629 VarHandle y("y", kInt);
3630 VarHandle z("z", kInt);
3631 StmtPtr stmt = For::make(
3632 x,
3633 0,
3634 10,
3635 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
3636 For::make(
3637 y,
3638 0,
3639 10,
3640 For::make(
3641 z,
3642 0,
3643 10,
3644 Store::make(
3645 c,
3646 {x},
3647 Add::make(
3648 Load::make(c, {x}),
3649 Mul::make(Load::make(b, {y}), Load::make(a, {x})))))));
3650
3651 /*
3652 * for (int x = 0; x < 10; x++) {
3653 * for (int y = 0; y < 10; y++) {
3654 * for (int z = 0; z < 10; z++) {
3655 * C[x] = (C[x]) + (B[y]) * (A[x]);
3656 * }
3657 * }
3658 * }
3659 */
3660
3661 // We can registerize all accesses, the A and C access can be hoisted to the
3662 // outer loop since they depend only on it's loop var while the B can only be
3663 // raised to the loop of y.
3664
3665 stmt = registerize(stmt);
3666
3667 /*
3668 * for (int x = 0; x < 10; x++) {
3669 * int A_1 = A[x];
3670 * int C_1 = C[x];
3671 * for (int y = 0; y < 10; y++) {
3672 * int B_1 = B[y];
3673 * for (int z = 0; z < 10; z++) {
3674 * C_1 = A_1 * B_1 + C_1;
3675 * }
3676 * }
3677 * C[x] = C_1;
3678 * }
3679 */
3680
3681 std::ostringstream oss;
3682 oss << *stmt;
3683
3684 const std::string& verification_pattern =
3685 R"IR(
3686# CHECK: for (int x
3687# CHECK: int A_1 = A[x];
3688# CHECK: int C_1 = C[x];
3689# CHECK: for (int y
3690# CHECK: int B_1 = B[y];
3691# CHECK: for (int z
3692# CHECK: C_1 = A_1 * B_1 + C_1;
3693# CHECK: }
3694# CHECK: }
3695# CHECK: C[x] = C_1;
3696# CHECK: })IR";
3697
3698 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3699}
3700
3701} // namespace jit
3702} // namespace torch
3703