1 | #include <gtest/gtest.h> |
2 | #include "test/cpp/tensorexpr/test_base.h" |
3 | |
4 | #include "test/cpp/tensorexpr/test_utils.h" |
5 | #include "torch/csrc/jit/tensorexpr/ir_simplifier.h" |
6 | #include "torch/csrc/jit/tensorexpr/registerizer.h" |
7 | |
8 | #include <iostream> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | using namespace torch::jit::tensorexpr; |
13 | |
14 | // Can replace a simple scalar access with a local variable. |
15 | TEST(Registerizer, RegisterizerSimple) { |
16 | BufHandle a("A" , {1}, kInt); |
17 | VarHandle x("x" , kInt); |
18 | StmtPtr stmt = Block::make( |
19 | {Store::make(a, {0}, 0), |
20 | For::make( |
21 | x, |
22 | 0, |
23 | 10, |
24 | Block::make( |
25 | {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); |
26 | |
27 | /* |
28 | * A[0] = 0; |
29 | * for (int x = 0; x < 10; x++) { |
30 | * A[0] = (A[0]) + x; |
31 | * } |
32 | */ |
33 | |
34 | stmt = registerize(stmt); |
35 | |
36 | /* |
37 | * int A_1 = 0; |
38 | * for (int x = 0; x < 10; x++) { |
39 | * A_1 = x + A_1; |
40 | * } |
41 | * A[0] = A_1; |
42 | */ |
43 | |
44 | std::ostringstream oss; |
45 | oss << *stmt; |
46 | |
47 | const std::string& verification_pattern = |
48 | R"IR( |
49 | # CHECK: int A_1 = 0; |
50 | # CHECK: for (int x = 0; x < 10; x++) |
51 | # CHECK-NOT: A[ |
52 | # CHECK: A_1 = |
53 | # CHECK: A[0] = A_1;)IR" ; |
54 | |
55 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
56 | } |
57 | |
58 | // Won't do replacement of a loop access. |
59 | TEST(Registerizer, RegisterizerLoop) { |
60 | BufHandle a("A" , {10}, kInt); |
61 | VarHandle x("x" , kInt); |
62 | StmtPtr stmt = Block::make( |
63 | {Store::make(a, {0}, 0), |
64 | For::make( |
65 | x, |
66 | 0, |
67 | 10, |
68 | Block::make( |
69 | {Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); |
70 | |
71 | /* |
72 | * A[0] = 0; |
73 | * for (int x = 0; x < 10; x++) { |
74 | * A[x] = (A[x]) + x; |
75 | * } |
76 | */ |
77 | |
78 | // No change. |
79 | stmt = registerize(stmt); |
80 | |
81 | /* |
82 | * A[0] = 0; |
83 | * for (int x = 0; x < 10; x++) { |
84 | * A[x] = (A[x]) + x; |
85 | * } |
86 | */ |
87 | |
88 | std::ostringstream oss; |
89 | oss << *stmt; |
90 | |
91 | const std::string& verification_pattern = |
92 | R"IR( |
93 | # CHECK-NOT: int |
94 | # CHECK: A[0] = 0; |
95 | # CHECK: for (int x = 0; x < 10; x++) |
96 | # CHECK-NOT: A_ |
97 | # CHECK: A[x] = |
98 | # CHECK-NOT: A[0] = A_1;)IR" ; |
99 | |
100 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
101 | } |
102 | |
103 | // Won't replace even if the load is a fixed scalar, since the store could |
104 | // invalidate it. |
105 | TEST(Registerizer, RegisterizerLoopFixedLoad) { |
106 | BufHandle a("A" , {1}, kInt); |
107 | VarHandle x("x" , kInt); |
108 | StmtPtr stmt = Block::make( |
109 | {Store::make(a, {0}, 0), |
110 | For::make( |
111 | x, |
112 | 0, |
113 | 10, |
114 | Block::make( |
115 | {Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))}); |
116 | |
117 | /* |
118 | * A[0] = 0; |
119 | * for (int x = 0; x < 10; x++) { |
120 | * A[x] = (A[0]) + x; |
121 | * } |
122 | */ |
123 | |
124 | // No change. |
125 | stmt = registerize(stmt); |
126 | |
127 | /* |
128 | * A[0] = 0; |
129 | * for (int x = 0; x < 10; x++) { |
130 | * A[x] = (A[0]) + x; |
131 | * } |
132 | */ |
133 | |
134 | std::ostringstream oss; |
135 | oss << *stmt; |
136 | |
137 | const std::string& verification_pattern = |
138 | R"IR( |
139 | # CHECK-NOT: int |
140 | # CHECK: A[0] = 0; |
141 | # CHECK: for (int x = 0; x < 10; x++) |
142 | # CHECK-NOT: A_ |
143 | # CHECK: A[x] = |
144 | # CHECK-NOT: A[0] = A_1;)IR" ; |
145 | |
146 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
147 | } |
148 | |
149 | // We can registerize accesses that occur entirely within inner scopes, even if |
150 | // they depend on the loop var. |
151 | TEST(Registerizer, RegisterizerLoopInternal) { |
152 | BufHandle a("A" , {1}, kInt); |
153 | VarHandle x("x" , kInt); |
154 | StmtPtr stmt = Block::make({For::make( |
155 | x, |
156 | 0, |
157 | 10, |
158 | Block::make( |
159 | {Store::make(a, {x}, Add::make(Load::make(a, {x}), x)), |
160 | Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); |
161 | |
162 | /* |
163 | * for (int x = 0; x < 10; x++) { |
164 | * A[x] = (A[x]) + x; |
165 | * A[x] = (A[x]) + x; |
166 | * } |
167 | */ |
168 | |
169 | stmt = registerize(stmt); |
170 | |
171 | // TODO: the order of terms in addition changes and in general depends on |
172 | // some hash value. This results in unpredictable swaps of the operands from |
173 | // random changes, which is not great. Ideally, we should ensure some |
174 | // specific order (ideally, the original one). |
175 | /* |
176 | * for (int x = 0; x < 10; x++) { |
177 | * int A_1 = A[x]; |
178 | * A_1 = x + A_1; |
179 | * A_1 = x + A_1; |
180 | * A[x] = A_1; |
181 | * } |
182 | */ |
183 | |
184 | std::ostringstream oss; |
185 | oss << *stmt; |
186 | |
187 | const std::string& verification_pattern = |
188 | R"IR( |
189 | # CHECK: for (int x = 0; x < 10; x++) |
190 | # CHECK: int A_1 = A[x]; |
191 | # CHECK: A_1 = A_1 + x; |
192 | # CHECK: A_1 = A_1 + x; |
193 | # CHECK: A[x] = A_1; |
194 | # CHECK: })IR" ; |
195 | |
196 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
197 | } |
198 | |
199 | // An access can be overlapped by another read in the same Expr. In this case |
200 | // B[z] and B[y] overlap and prevent registerization of both accesses. |
201 | TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) { |
202 | BufHandle a("A" , {10}, kInt); |
203 | BufHandle b("B" , {10}, kInt); |
204 | VarHandle x("x" , kInt); |
205 | VarHandle y("y" , kInt); |
206 | VarHandle z("z" , kInt); |
207 | StmtPtr stmt = Block::make({For::make( |
208 | x, |
209 | 0, |
210 | 10, |
211 | Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))}); |
212 | stmt = IRSimplifier::simplify(stmt); |
213 | |
214 | /* |
215 | * for (int x = 0; x < 10; x++) { |
216 | * A[x] = (B[y]) + (B[z]); |
217 | * } |
218 | */ |
219 | |
220 | std::ostringstream before; |
221 | before << *stmt; |
222 | |
223 | // No change. |
224 | stmt = registerize(stmt); |
225 | |
226 | std::ostringstream after; |
227 | after << *stmt; |
228 | |
229 | ASSERT_EQ(before.str(), after.str()); |
230 | } |
231 | |
232 | TEST(Registerizer, RegisterizerLoopInternalRepeated) { |
233 | BufHandle a("A" , {1}, kInt); |
234 | VarHandle x("x" , kInt); |
235 | StmtPtr stmt = Block::make( |
236 | {For::make( |
237 | x, |
238 | 0, |
239 | 10, |
240 | Block::make( |
241 | {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), |
242 | Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})), |
243 | For::make( |
244 | x, |
245 | 0, |
246 | 10, |
247 | Block::make( |
248 | {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), |
249 | Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})) |
250 | |
251 | }); |
252 | |
253 | /* |
254 | * for (int x = 0; x < 10; x++) { |
255 | * A[0] = x + (A[1]); |
256 | * A[0] = x + (A[1]); |
257 | * } |
258 | * for (int x = 0; x < 10; x++) { |
259 | * A[0] = x + (A[1]); |
260 | * A[0] = x + (A[1]); |
261 | * } |
262 | */ |
263 | |
264 | stmt = registerize(stmt); |
265 | |
266 | /* |
267 | * int A_1 = A[1]; |
268 | * int A_2 = A[0]; |
269 | * for (int x = 0; x < 10; x++) { |
270 | * A_2 = A_1 + x; |
271 | * A_2 = A_1 + x; |
272 | * } |
273 | * for (int x = 0; x < 10; x++) { |
274 | * A_2 = A_1 + x; |
275 | * A_2 = A_1 + x; |
276 | * } |
277 | * A[0] = A_2; |
278 | */ |
279 | |
280 | std::ostringstream oss; |
281 | oss << *stmt; |
282 | |
283 | const std::string& verification_pattern = |
284 | R"IR( |
285 | # CHECK: int A_1 = A[1]; |
286 | # CHECK: int A_2 = A[0]; |
287 | # CHECK: for (int x = 0; x < 10; x++) |
288 | # CHECK: A_2 = A_1 + x; |
289 | # CHECK: A_2 = A_1 + x; |
290 | # CHECK: } |
291 | # CHECK: for (int x = 0; x < 10; x++) |
292 | # CHECK: A_2 = A_1 + x; |
293 | # CHECK: A_2 = A_1 + x; |
294 | # CHECK: } |
295 | # CHECK-NOT: A[1] |
296 | # CHECK: A[0] = A_2; |
297 | # CHECK-NOT: A[1] |
298 | # CHECK: })IR" ; |
299 | |
300 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
301 | } |
302 | |
303 | TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) { |
304 | BufHandle a("A" , {1}, kInt); |
305 | VarHandle x("x" , kInt); |
306 | StmtPtr stmt = Block::make( |
307 | {For::make( |
308 | x, |
309 | 0, |
310 | 10, |
311 | Block::make( |
312 | {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), |
313 | Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})), |
314 | For::make( |
315 | x, |
316 | 0, |
317 | 10, |
318 | Block::make( |
319 | {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), |
320 | Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})) |
321 | |
322 | }); |
323 | stmt = IRSimplifier::simplify(stmt); |
324 | |
325 | /* |
326 | * for (int x = 0; x < 10; x++) { |
327 | * A[0] = (A[x]) + x; |
328 | * A[0] = (A[x]) + x; |
329 | * } |
330 | * for (int x = 0; x < 10; x++) { |
331 | * A[0] = (A[x]) + x; |
332 | * A[0] = (A[x]) + x; |
333 | * } |
334 | */ |
335 | |
336 | std::ostringstream before; |
337 | before << *stmt; |
338 | |
339 | // No change. |
340 | stmt = registerize(stmt); |
341 | |
342 | std::ostringstream after; |
343 | after << *stmt; |
344 | |
345 | ASSERT_EQ(before.str(), after.str()); |
346 | } |
347 | |
348 | TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) { |
349 | BufHandle a("A" , {1}, kInt); |
350 | VarHandle x("x" , kInt); |
351 | VarHandle y("y" , kInt); |
352 | StmtPtr stmt = IRSimplifier::simplify(Block::make( |
353 | {For::make( |
354 | x, |
355 | 0, |
356 | 10, |
357 | Block::make( |
358 | {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), |
359 | Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})), |
360 | For::make( |
361 | x, |
362 | 0, |
363 | 10, |
364 | Block::make( |
365 | {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), |
366 | Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})) |
367 | |
368 | })); |
369 | |
370 | /* |
371 | * for (int x = 0; x < 10; x++) { |
372 | * A[0] = (A[x]) + x; |
373 | * A[0] = (A[x]) + x; |
374 | * } |
375 | * for (int x = 0; x < 10; x++) { |
376 | * A[0] = (A[x]) + x; |
377 | * A[0] = (A[x]) + x; |
378 | * } |
379 | */ |
380 | |
381 | std::ostringstream before; |
382 | before << *stmt; |
383 | |
384 | // No change. |
385 | stmt = registerize(stmt); |
386 | |
387 | std::ostringstream after; |
388 | after << *stmt; |
389 | |
390 | ASSERT_EQ(before.str(), after.str()); |
391 | } |
392 | |
393 | // Will registerize multiple accesses of different items of the same buffer. |
394 | TEST(Registerizer, RegisterizerMultiVar) { |
395 | BufHandle a("A" , {2}, kInt); |
396 | VarHandle x("x" , kInt); |
397 | StmtPtr stmt = Block::make({ |
398 | Store::make(a, {0}, 0), |
399 | Store::make(a, {1}, 0), |
400 | For::make( |
401 | x, |
402 | 0, |
403 | 10, |
404 | Block::make( |
405 | {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), |
406 | Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), |
407 | }); |
408 | |
409 | /* |
410 | * A[0] = 0; |
411 | * A[1] = 0; |
412 | * for (int x = 0; x < 10; x++) { |
413 | * A[0] = (A[0]) + x; |
414 | * A[1] = (A[1]) - x; |
415 | * } |
416 | */ |
417 | |
418 | stmt = registerize(stmt); |
419 | |
420 | /* |
421 | * int A_1 = 0; |
422 | * int A_2 = 0; |
423 | * for (int x = 0; x < 10; x++) { |
424 | * A_2 = x + A_2; |
425 | * A_1 = A_1 - x; |
426 | * } |
427 | * A[1] = A_2; |
428 | * A[0] = A_1; |
429 | */ |
430 | |
431 | std::ostringstream oss; |
432 | oss << *stmt; |
433 | |
434 | const std::string& verification_pattern = |
435 | R"IR( |
436 | # CHECK: int A_1 = 0; |
437 | # CHECK: int A_2 = 0; |
438 | # CHECK: for (int x = 0; x < 10; x++) |
439 | # CHECK-NOT: A[ |
440 | # CHECK: A_1 = |
441 | # CHECK: A_2 = |
442 | # CHECK: A[1] = A_2 |
443 | # CHECK: A[0] = A_1;)IR" ; |
444 | |
445 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
446 | } |
447 | |
448 | // Will registerize the valid accesses while skipping invalid replacements. |
449 | TEST(Registerizer, RegisterizerVariableLoad) { |
450 | BufHandle a("A" , {1}, kInt); |
451 | BufHandle b("B" , {10}, kInt); |
452 | VarHandle x("x" , kInt); |
453 | VarHandle x2("x" , kInt); |
454 | StmtPtr stmt = Block::make( |
455 | {Store::make(a, {0}, 0), |
456 | For::make(x, 0, 10, Store::make(b, {x}, x)), |
457 | For::make( |
458 | x2, |
459 | 0, |
460 | 10, |
461 | Block::make({Store::make( |
462 | a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))}); |
463 | |
464 | /* |
465 | * A[0] = 0; |
466 | * for (int x = 0; x < 10; x++) { |
467 | * B[x] = x; |
468 | * } |
469 | * for (int x_1 = 0; x_1 < 10; x_1++) { |
470 | * A[0] = (A[0]) + (B[x_1]); |
471 | * } |
472 | */ |
473 | |
474 | stmt = registerize(stmt); |
475 | |
476 | /* |
477 | * int A_1 = 0; |
478 | * for (int x = 0; x < 10; x++) { |
479 | * B[x] = x; |
480 | * } |
481 | * for (int x_1 = 0; x_1 < 10; x_1++) { |
482 | * A_1 = A_1 + (B[x_1]); |
483 | * } |
484 | * A[0] = A_1; |
485 | */ |
486 | |
487 | std::ostringstream oss; |
488 | oss << *stmt; |
489 | |
490 | const std::string& verification_pattern = |
491 | R"IR( |
492 | # CHECK: int A_1 = 0; |
493 | # CHECK: for (int x = 0; x < 10; x++) |
494 | # CHECK: B[x] = x |
495 | # CHECK: for (int x_1 = 0; x_1 < 10; x_1++) |
496 | # CHECK-NOT: A[ |
497 | # CHECK: A_1 = |
498 | # CHECK: A[0] = A_1;)IR" ; |
499 | |
500 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
501 | } |
502 | |
503 | // Can registerize variable accesses so long as the variable does not change. |
504 | TEST(Registerizer, RegisterizerSymbolicIndices) { |
505 | VarHandle i("i" , kInt); |
506 | VarHandle N("N" , kInt); |
507 | BufHandle a("A" , {N}, kInt); |
508 | VarHandle x("x" , kInt); |
509 | StmtPtr stmt = Block::make( |
510 | {Store::make(a, {i}, 0), |
511 | For::make( |
512 | x, |
513 | 0, |
514 | 10, |
515 | Block::make( |
516 | {Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))}); |
517 | |
518 | /* |
519 | * A[i] = 0; |
520 | * for (int x = 0; x < 10; x++) { |
521 | * A[i] = (A[i]) + x; |
522 | * } |
523 | */ |
524 | |
525 | stmt = registerize(stmt); |
526 | |
527 | /* |
528 | * int A_1 = 0; |
529 | * for (int x = 0; x < 10; x++) { |
530 | * A_1 = x + A_1; |
531 | * } |
532 | * A[i] = A_1; |
533 | */ |
534 | |
535 | std::ostringstream oss; |
536 | oss << *stmt; |
537 | |
538 | const std::string& verification_pattern = |
539 | R"IR( |
540 | # CHECK: int A_1 = 0; |
541 | # CHECK: for (int x = 0; x < 10; x++) |
542 | # CHECK-NOT: A[ |
543 | # CHECK: A_1 = |
544 | # CHECK: A[i] = A_1;)IR" ; |
545 | |
546 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
547 | } |
548 | |
549 | // Can registerize accesses dependent on multiple loop vars. |
550 | TEST(Registerizer, RegisterizerMultiLoop) { |
551 | BufHandle a("A" , {1}, kInt); |
552 | VarHandle x("x" , kInt); |
553 | VarHandle y("y" , kInt); |
554 | StmtPtr stmt = Block::make( |
555 | {Store::make(a, {0}, 0), |
556 | For::make( |
557 | x, |
558 | 0, |
559 | 10, |
560 | For::make( |
561 | y, |
562 | 0, |
563 | 10, |
564 | Block::make({Store::make( |
565 | a, |
566 | {0}, |
567 | Mul::make(Add::make(Load::make(a, {0}), x), y))})))}); |
568 | |
569 | /* |
570 | * A[0] = 0; |
571 | * for (int x = 0; x < 10; x++) { |
572 | * for (int y = 0; y < 10; y++) { |
573 | * A[0] = x * y + (A[0]) * y; |
574 | * } |
575 | * } |
576 | */ |
577 | |
578 | stmt = registerize(stmt); |
579 | |
580 | /* |
581 | * int A_1 = 0; |
582 | * for (int x = 0; x < 10; x++) { |
583 | * for (int y = 0; y < 10; y++) { |
584 | * A_1 = x * y + y * A_1; |
585 | * } |
586 | * } |
587 | * A[0] = A_1; |
588 | */ |
589 | |
590 | std::ostringstream oss; |
591 | oss << *stmt; |
592 | |
593 | const std::string& verification_pattern = |
594 | R"IR( |
595 | # CHECK: int A_1 = 0; |
596 | # CHECK: for (int x = 0; x < 10; x++) |
597 | # CHECK: for (int y = 0; y < 10; y++) |
598 | # CHECK-NOT: A[ |
599 | # CHECK: A_1 = |
600 | # CHECK: A[0] = A_1;)IR" ; |
601 | |
602 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
603 | } |
604 | |
605 | // Can registerize correctly if scalars already exist in the program. |
606 | TEST(Registerizer, RegisterizerRepeated) { |
607 | BufHandle a("A" , {2}, kInt); |
608 | VarHandle x("x" , kInt); |
609 | StmtPtr stmt = Block::make({ |
610 | Store::make(a, {0}, 0), |
611 | Store::make(a, {1}, 0), |
612 | For::make( |
613 | x, |
614 | 0, |
615 | 10, |
616 | Block::make( |
617 | {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), |
618 | Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), |
619 | }); |
620 | |
621 | // Registerize manually to make sure we only replace a single target. |
622 | { |
623 | registerizer::RegisterizerAnalysis analysis; |
624 | stmt->accept(&analysis); |
625 | auto candidates = analysis.getCandidates(); |
626 | ASSERT_EQ(candidates.size(), 2); |
627 | |
628 | candidates.pop_back(); |
629 | registerizer::RegisterizerReplacer replacer(candidates); |
630 | stmt = stmt->accept_mutator(&replacer); |
631 | } |
632 | |
633 | // Re-analyze and replace the second target. |
634 | { |
635 | registerizer::RegisterizerAnalysis analysis; |
636 | stmt->accept(&analysis); |
637 | auto candidates = analysis.getCandidates(); |
638 | ASSERT_EQ(candidates.size(), 1); |
639 | |
640 | registerizer::RegisterizerReplacer replacer(candidates); |
641 | stmt = stmt->accept_mutator(&replacer); |
642 | } |
643 | |
644 | std::ostringstream oss; |
645 | oss << *stmt; |
646 | |
647 | const std::string& verification_pattern = |
648 | R"IR( |
649 | # CHECK: int A_1 = 0; |
650 | # CHECK: int A_1_1 = 0; |
651 | # CHECK: for (int x = 0; x < 10; x++) |
652 | # CHECK-NOT: A[ |
653 | # CHECK: A_1 = |
654 | # CHECK: A_1_1 = |
655 | # CHECK: A[1] = A_1_1; |
656 | # CHECK: A[0] = A_1;)IR" ; |
657 | |
658 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
659 | } |
660 | |
661 | // Can registerize the load of A. |
662 | TEST(Registerizer, RegisterizerNoLoads) { |
663 | BufHandle a("A" , {1}, kInt); |
664 | VarHandle x("x" , kInt); |
665 | StmtPtr stmt = Block::make( |
666 | {Store::make(a, {0}, 0), |
667 | For::make( |
668 | x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))}); |
669 | |
670 | /* |
671 | * A[0] = 0; |
672 | * for (int x = 0; x < 10; x++) { |
673 | * A[0] = x + 1; |
674 | * } |
675 | */ |
676 | |
677 | stmt = registerize(stmt); |
678 | |
679 | /* |
680 | * int A_1 = 0; |
681 | * for (int x = 0; x < 10; x++) { |
682 | * A_1 = x + 1; |
683 | * } |
684 | * A[0] = A_1; |
685 | */ |
686 | |
687 | std::ostringstream oss; |
688 | oss << *stmt; |
689 | |
690 | const std::string& verification_pattern = |
691 | R"IR( |
692 | # CHECK: int A_1 = 0; |
693 | # CHECK: for (int x = 0; x < 10; x++) |
694 | # CHECK-NOT: A[ |
695 | # CHECK: A_1 = |
696 | # CHECK: A[0] = A_1;)IR" ; |
697 | |
698 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
699 | } |
700 | |
701 | // Can registerize the load of A but not the store of B. |
702 | TEST(Registerizer, RegisterizerNoRepeatedStores) { |
703 | BufHandle a("A" , {1}, kInt); |
704 | BufHandle b("B" , {10}, kInt); |
705 | VarHandle x("x" , kInt); |
706 | StmtPtr stmt = Block::make( |
707 | {Store::make(a, {0}, 0), |
708 | For::make( |
709 | x, |
710 | 0, |
711 | 10, |
712 | Block::make( |
713 | {Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))}); |
714 | |
715 | /* |
716 | * A[0] = 0; |
717 | * for (int x = 0; x < 10; x++) { |
718 | * B[x] = (A[0]) + x; |
719 | * } |
720 | */ |
721 | |
722 | stmt = registerize(stmt); |
723 | |
724 | // TODO: its unnecessary to reorder the initializer of A[0], but it's not |
725 | // actually worse so lets not worry for now. |
726 | |
727 | /* |
728 | * int A_1 = 0; |
729 | * for (int x = 0; x < 10; x++) { |
730 | * B[x] = x + A_1; |
731 | * } |
732 | * A[0] = A_1; |
733 | */ |
734 | |
735 | std::ostringstream oss; |
736 | oss << *stmt; |
737 | |
738 | const std::string& verification_pattern = |
739 | R"IR( |
740 | # CHECK: int A_1 = 0; |
741 | # CHECK: for (int x = 0; x < 10; x++) |
742 | # CHECK-NOT: A_ |
743 | # CHECK: B[x] = |
744 | # CHECK: A[0] = A_1;)IR" ; |
745 | |
746 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
747 | } |
748 | |
749 | // Won't registerize if there are multiple accesses which may overlap. |
750 | TEST(Registerizer, RegisterizerMultiVarOverlap) { |
751 | BufHandle a("A" , {2}, kInt); |
752 | VarHandle x("x" , kInt); |
753 | StmtPtr stmt = Block::make({ |
754 | Store::make(a, {0}, 0), |
755 | Store::make(a, {1}, 0), |
756 | For::make( |
757 | x, |
758 | 0, |
759 | 10, |
760 | Block::make( |
761 | {Store::make(a, {x}, Add::make(Load::make(a, {0}), x)), |
762 | Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})), |
763 | }); |
764 | stmt = IRSimplifier::simplify(stmt); |
765 | |
766 | std::ostringstream before; |
767 | before << *stmt; |
768 | |
769 | // No change. |
770 | stmt = registerize(stmt); |
771 | |
772 | std::ostringstream after; |
773 | after << *stmt; |
774 | |
775 | ASSERT_EQ(before.str(), after.str()); |
776 | } |
777 | |
778 | TEST(Registerizer, RegisterizerAllocs) { |
779 | BufHandle a("A" , {2}, kInt); |
780 | BufHandle c("C" , {1}, kInt); |
781 | VarHandle x("x" , kInt); |
782 | |
783 | BufHandle b("B" , {Load::make(c, {0})}, kInt); |
784 | |
785 | StmtPtr stmt = Block::make( |
786 | {Allocate::make(b), |
787 | Store::make(a, {0}, Load::make(c, {0})), |
788 | Store::make(b, {0}, 0), |
789 | For::make( |
790 | x, |
791 | 0, |
792 | 10, |
793 | Block::make( |
794 | {Store::make(b, {0}, Add::make(Load::make(b, {0}), x)), |
795 | Store::make(a, {0}, Load::make(c, {0}))})), |
796 | Free::make(b)}); |
797 | |
798 | /* |
799 | * Allocate(B, int, {C[0]}); |
800 | * A[0] = C[0]; |
801 | * B[0] = 0; |
802 | * for (int x = 0; x < 10; x++) { |
803 | * B[0] = (B[0]) + x; |
804 | * A[0] = C[0]; |
805 | * } |
806 | * Free(B); |
807 | */ |
808 | |
809 | stmt = registerize(stmt); |
810 | |
811 | /* |
812 | * int C_1 = C[0]; |
813 | * Allocate(B, int, {C_}); |
814 | * int A_1 = C_1; |
815 | * int B_1 = 0; |
816 | * for (int x = 0; x < 10; x++) { |
817 | * B_1 = B_1 + x; |
818 | * A_1 = C_1; |
819 | * } |
820 | * B[0] = B_1; |
821 | * A[0] = A_1; |
822 | * Free(B); |
823 | */ |
824 | |
825 | std::ostringstream oss; |
826 | oss << *stmt; |
827 | |
828 | const std::string& verification_pattern = |
829 | R"IR( |
830 | # CHECK: int C_1 = C[0]; |
831 | # CHECK: Allocate(B |
832 | # CHECK: int A_1 = C_1; |
833 | # CHECK: int B_1 = 0; |
834 | # CHECK: for (int x = 0; x < 10; x++) |
835 | # CHECK: B_1 = |
836 | # CHECK: A_1 = C_ |
837 | # CHECK: B[0] = B_1; |
838 | # CHECK: A[0] = A_1; |
839 | # CHECK: Free(B)IR" ; |
840 | |
841 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
842 | } |
843 | |
844 | TEST(Registerizer, RegisterizerNoInitializer) { |
845 | BufHandle a("A" , {1}, kInt); |
846 | VarHandle x("x" , kInt); |
847 | StmtPtr stmt = Block::make({For::make( |
848 | x, |
849 | 0, |
850 | 10, |
851 | Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); |
852 | |
853 | /* |
854 | * for (int x = 0; x < 10; x++) { |
855 | * A[0] = (A[0]) + x; |
856 | * } |
857 | */ |
858 | |
859 | stmt = registerize(stmt); |
860 | |
861 | /* |
862 | * int A_1 = A[0]; |
863 | * for (int x = 0; x < 10; x++) { |
864 | * A_1 = x + A_1; |
865 | * } |
866 | * A[0] = A_1; |
867 | */ |
868 | |
869 | std::ostringstream oss; |
870 | oss << *stmt; |
871 | |
872 | const std::string& verification_pattern = |
873 | R"IR( |
874 | # CHECK: int A_1 = A[0]; |
875 | # CHECK: for (int x = 0; x < 10; x++) |
876 | # CHECK-NOT: A[ |
877 | # CHECK: A_1 = |
878 | # CHECK: A[0] = A_1;)IR" ; |
879 | |
880 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
881 | } |
882 | |
883 | TEST(Registerizer, RegisterizerNoInitializerLoopVar) { |
884 | BufHandle a("A" , {1}, kInt); |
885 | VarHandle x("x" , kInt); |
886 | StmtPtr stmt = Block::make({For::make( |
887 | x, |
888 | 0, |
889 | 10, |
890 | Block::make({Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); |
891 | stmt = IRSimplifier::simplify(stmt); |
892 | |
893 | /* |
894 | * for (int x = 0; x < 10; x++) { |
895 | * A[x] = (A[x]) + x; |
896 | * } |
897 | */ |
898 | |
899 | std::ostringstream before; |
900 | before << *stmt; |
901 | |
902 | // No change. |
903 | stmt = registerize(stmt); |
904 | |
905 | std::ostringstream after; |
906 | after << *stmt; |
907 | |
908 | ASSERT_EQ(before.str(), after.str()); |
909 | } |
910 | |
911 | TEST(Registerizer, RegisterizerLoadThenStore) { |
912 | BufHandle a("A" , {1}, kInt); |
913 | BufHandle b("B" , {1}, kInt); |
914 | VarHandle x("x" , kInt); |
915 | StmtPtr stmt = Block::make({For::make( |
916 | x, |
917 | 0, |
918 | 10, |
919 | Block::make( |
920 | {Store::make(b, {0}, Add::make(Load::make(a, {0}), x)), |
921 | Store::make(a, {0}, Load::make(b, {0}))}))}); |
922 | |
923 | /* |
924 | * for (int x = 0; x < 10; x++) { |
925 | * B[0] = (A[0]) + x; |
926 | * A[0] = B[0]; |
927 | * } |
928 | */ |
929 | |
930 | stmt = registerize(stmt); |
931 | |
932 | /* |
933 | * int A_1 = A[0]; |
934 | * int B_1 = B[0]; |
935 | * for (int x = 0; x < 10; x++) { |
936 | * B_1 = x + A_1; |
937 | * A_1 = B_1; |
938 | * } |
939 | * B[0] = B_1; |
940 | * A[0] = A_1; |
941 | */ |
942 | |
943 | std::ostringstream oss; |
944 | oss << *stmt; |
945 | |
946 | const std::string& verification_pattern = |
947 | R"IR( |
948 | # CHECK: int A_1 = A[0]; |
949 | # CHECK: int B_1 = B[0]; |
950 | # CHECK: for (int x = 0; x < 10; x++) |
951 | # CHECK-NOT: B[ |
952 | # CHECK: B_1 = |
953 | # CHECK-NOT: A[ |
954 | # CHECK: A_1 = B_ |
955 | # CHECK: B[0] = B_ |
956 | # CHECK: A[0] = A_1;)IR" ; |
957 | |
958 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
959 | } |
960 | |
961 | TEST(Registerizer, RegisterizerParallelized) { |
962 | BufHandle a("A" , {1}, kInt); |
963 | VarHandle x("x" , kInt); |
964 | LoopOptions loopOpts; |
965 | loopOpts.set_gpu_block_index(0); |
966 | StmtPtr stmt = Block::make( |
967 | {Store::make(a, {0}, 0), |
968 | For::make( |
969 | x, |
970 | 0, |
971 | 10, |
972 | Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}), |
973 | loopOpts)}); |
974 | |
975 | /* |
976 | * A[0] = 0; |
977 | * for (int x = 0; x < 10; x++) { |
978 | * A[0] = (A[0]) + x; |
979 | * } |
980 | */ |
981 | |
982 | ASSERT_THROWS_WITH( |
983 | registerize(stmt), |
984 | "Registerization must occur after parallelism flattening" ); |
985 | } |
986 | |
987 | // Should be able to registerize this since the scalar would exist before the |
988 | // branch. |
989 | TEST(Registerizer, RegisterizerConditionAfter) { |
990 | BufHandle a("A" , {5}, kInt); |
991 | BufHandle b("B" , {5}, kInt); |
992 | BufHandle c("C" , {5}, kInt); |
993 | VarHandle x("x" , kInt); |
994 | |
995 | StmtPtr stmt = Block::make( |
996 | {Store::make(a, {x}, Load::make(b, {x})), |
997 | Store::make(c, {x}, Load::make(a, {x})), |
998 | Cond::make( |
999 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1000 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1001 | nullptr)}); |
1002 | |
1003 | /* |
1004 | * A[x] = B[x]; |
1005 | * C[x] = A[x]; |
1006 | * if (x<5 ? 1 : 0) { |
1007 | * A[x] = (A[x]) + 1; |
1008 | * } |
1009 | */ |
1010 | |
1011 | stmt = registerize(stmt); |
1012 | |
1013 | /* |
1014 | * int A_1 = B[x]; |
1015 | * C[x] = A_1; |
1016 | * if (x<5 ? 1 : 0) { |
1017 | * A_1 = A_1 + 1; |
1018 | * } |
1019 | * A[x] = A_1; |
1020 | */ |
1021 | |
1022 | std::ostringstream oss; |
1023 | oss << *stmt; |
1024 | |
1025 | const std::string& verification_pattern = |
1026 | R"IR( |
1027 | # CHECK: int A_1 = B[x]; |
1028 | # CHECK: C[x] = A_1; |
1029 | # CHECK: if ( |
1030 | # CHECK: A_1 = A_1 + 1; |
1031 | # CHECK: A[x] = A_1;)IR" ; |
1032 | |
1033 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1034 | } |
1035 | |
1036 | // Should be able to registerize this since the scalar exists in the same form |
1037 | // after the branch and there is no overlap. |
1038 | TEST(Registerizer, RegisterizerConditionBefore) { |
1039 | BufHandle a("A" , {5}, kInt); |
1040 | BufHandle b("B" , {5}, kInt); |
1041 | BufHandle c("C" , {5}, kInt); |
1042 | VarHandle x("x" , kInt); |
1043 | |
1044 | StmtPtr stmt = Block::make( |
1045 | {Cond::make( |
1046 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1047 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1048 | nullptr), |
1049 | Store::make(a, {x}, Load::make(b, {x})), |
1050 | Store::make(c, {x}, Load::make(a, {x}))}); |
1051 | |
1052 | /* |
1053 | * if (x<5 ? 1 : 0) { |
1054 | * A[x] = (A[x]) + 1; |
1055 | * } |
1056 | * A[x] = B[x]; |
1057 | * C[x] = A[x]; |
1058 | */ |
1059 | |
1060 | stmt = registerize(stmt); |
1061 | |
1062 | /* |
1063 | * int A_ 1 = A[x]; |
1064 | * if (x<5 ? 1 : 0) { |
1065 | * A_1 = A_1 + 1; |
1066 | * } |
1067 | * A_1 = B[x]; |
1068 | * C[x] = A_1; |
1069 | * A[x] = A_1; |
1070 | */ |
1071 | |
1072 | std::ostringstream oss; |
1073 | oss << *stmt; |
1074 | |
1075 | const std::string& verification_pattern = |
1076 | R"IR( |
1077 | # CHECK: int A_1 = A[x]; |
1078 | # CHECK: if ( |
1079 | # CHECK: A_1 = A_1 + 1; |
1080 | # CHECK: } |
1081 | # CHECK: A_1 = B[x]; |
1082 | # CHECK: C[x] = A_1; |
1083 | # CHECK: A[x] = A_1;)IR" ; |
1084 | |
1085 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1086 | } |
1087 | |
1088 | // Should be able to registerize this as the combination of the two above rules. |
1089 | TEST(Registerizer, RegisterizerConditionInside) { |
1090 | BufHandle a("A" , {5}, kInt); |
1091 | BufHandle b("B" , {5}, kInt); |
1092 | BufHandle c("C" , {5}, kInt); |
1093 | VarHandle x("x" , kInt); |
1094 | |
1095 | StmtPtr stmt = Block::make( |
1096 | {Store::make(a, {x}, Load::make(b, {x})), |
1097 | Store::make(c, {x}, Load::make(a, {x})), |
1098 | Cond::make( |
1099 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1100 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1101 | nullptr), |
1102 | Store::make(b, {x}, Load::make(a, {x})), |
1103 | Store::make(a, {x}, Load::make(c, {x}))}); |
1104 | |
1105 | /* |
1106 | * A[x] = B[x]; |
1107 | * C[x] = A[x]; |
1108 | * if (x<5 ? 1 : 0) { |
1109 | * A[x] = (A[x]) + 1; |
1110 | * } |
1111 | * B[x] = A[x]; |
1112 | * A[x] = C[x]; |
1113 | */ |
1114 | |
1115 | stmt = registerize(stmt); |
1116 | |
1117 | /* |
1118 | * int A_1 = B[x]; |
1119 | * C[x] = A_1; |
1120 | * if (x<5 ? 1 : 0) { |
1121 | * A_1 = A_1 + 1; |
1122 | * } |
1123 | * B[x] = A_1; |
1124 | * A_1 = C[x]; |
1125 | * A[x] = A_1; |
1126 | */ |
1127 | |
1128 | std::ostringstream oss; |
1129 | oss << *stmt; |
1130 | |
1131 | const std::string& verification_pattern = |
1132 | R"IR( |
1133 | # CHECK: int A_1 = B[x]; |
1134 | # CHECK: C[x] = A_1; |
1135 | # CHECK: if ( |
1136 | # CHECK: A_1 = A_1 + 1; |
1137 | # CHECK: } |
1138 | # CHECK: B[x] = A_1; |
1139 | # CHECK: A_1 = C[x]; |
1140 | # CHECK: A[x] = A_1;)IR" ; |
1141 | |
1142 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1143 | } |
1144 | |
1145 | // An example where an access is cut by an overlapping access inside a |
1146 | // condition, and both sides are large enough to be registerized but cannot be |
1147 | // because there is no safe place to put the initializer or finalizer. |
1148 | TEST(Registerizer, RegisterizerConditionInsideOverlap1) { |
1149 | BufHandle a("A" , {5}, kInt); |
1150 | BufHandle b("B" , {5}, kInt); |
1151 | BufHandle c("C" , {5}, kInt); |
1152 | VarHandle x("x" , kInt); |
1153 | VarHandle y("y" , kInt); |
1154 | |
1155 | StmtPtr stmt = Block::make( |
1156 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
1157 | {Store::make(a, {x}, Load::make(b, {x})), |
1158 | Store::make(c, {x}, Load::make(a, {x})), |
1159 | Cond::make( |
1160 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1161 | Block::make({ |
1162 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1163 | Store::make(a, {0}, 3), |
1164 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1165 | }), |
1166 | nullptr), |
1167 | Store::make(b, {x}, Load::make(a, {x})), |
1168 | Store::make(a, {x}, Load::make(c, {x}))}); |
1169 | |
1170 | /* |
1171 | * A[x] = B[x]; |
1172 | * C[x] = A[x]; |
1173 | * if (x<5 ? 1 : 0) { |
1174 | * A[x] = (A[x]) + 1; |
1175 | * A[0] = 3; |
1176 | * A[x] = (A[x]) + 1; |
1177 | * } |
1178 | * B[x] = A[x]; |
1179 | * A[x] = C[x]; |
1180 | */ |
1181 | |
1182 | // The A[0] store overlaps, A[x] cutting the region that can be registerized |
1183 | // into two groups. |
1184 | // Each group has 2 loads and 2 stores however, so we could registerize it, |
1185 | // but the first group would need to be finalized inside the condition block, |
1186 | // the second would need to be initialized inside the condition block. There's |
1187 | // no safe place to put these that's visible to the other uses in the group |
1188 | // and so neither registerization is possible. |
1189 | |
1190 | std::ostringstream before; |
1191 | before << *stmt; |
1192 | |
1193 | // No change. |
1194 | stmt = registerize(stmt); |
1195 | |
1196 | std::ostringstream after; |
1197 | after << *stmt; |
1198 | |
1199 | ASSERT_EQ(before.str(), after.str()); |
1200 | } |
1201 | |
1202 | // Same as the above, but the access group before the condition (and after the |
1203 | // condition) are large enough to be registerized without needing the access |
1204 | // from the loop. Registerization occurs but does not include any accesses in |
1205 | // the condition, and the first group must be finalized before the Cond, the |
1206 | // second initialized after it. |
1207 | TEST(Registerizer, RegisterizerConditionInsideOverlap2) { |
1208 | BufHandle a("A" , {5}, kInt); |
1209 | BufHandle b("B" , {5}, kInt); |
1210 | BufHandle c("C" , {5}, kInt); |
1211 | VarHandle x("x" , kInt); |
1212 | VarHandle y("y" , kInt); |
1213 | |
1214 | StmtPtr stmt = Block::make( |
1215 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
1216 | {Store::make(a, {x}, Load::make(b, {x})), |
1217 | Store::make(a, {x}, Load::make(b, {x + 1})), |
1218 | Store::make(c, {x}, Load::make(a, {x})), |
1219 | Cond::make( |
1220 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1221 | Block::make({ |
1222 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1223 | Store::make(a, {0}, 3), |
1224 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1225 | }), |
1226 | nullptr), |
1227 | Store::make(b, {x}, Load::make(a, {x})), |
1228 | Store::make(b, {x + 1}, Load::make(a, {x})), |
1229 | Store::make(a, {x}, Load::make(c, {x}))}); |
1230 | |
1231 | /* |
1232 | * A[x] = B[x]; |
1233 | * A[x] = B[x + 1]; |
1234 | * C[x] = A[x]; |
1235 | * if (x<5 ? 1 : 0) { |
1236 | * A[x] = (A[x]) + 1; |
1237 | * A[0] = 3; |
1238 | * A[x] = (A[x]) + 1; |
1239 | * } |
1240 | * B[x] = A[x]; |
1241 | * B[x + 1] = A[x]; |
1242 | * A[x] = C[x]; |
1243 | */ |
1244 | |
1245 | stmt = registerize(stmt); |
1246 | |
1247 | /* |
1248 | * int A_1 = B[x]; // A_1 initializer |
1249 | * A_1 = B[x + 1]; // |
1250 | * C[x] = A_1; // |
1251 | * A[x] = A_1; // A_1 finalizer |
1252 | * if (x<5 ? 1 : 0) { |
1253 | * A[x] = (A[x]) + 1; |
1254 | * A[0] = 3; |
1255 | * A[x] = (A[x]) + 1; |
1256 | * } |
1257 | * int A_2 = A[x]; // A_2 initialier |
1258 | * B[x] = A_2; // |
1259 | * B[x + 1] = A_2; // |
1260 | * A_2 = C[x]; // |
1261 | * A[x] = A_2; // A_2 finalizer |
1262 | */ |
1263 | |
1264 | std::ostringstream oss; |
1265 | oss << *stmt; |
1266 | |
1267 | const std::string& verification_pattern = |
1268 | R"IR( |
1269 | # CHECK: int A_1 = B[x]; |
1270 | # CHECK: A_1 = B[x + 1]; |
1271 | # CHECK: C[x] = A_1; |
1272 | # CHECK: A[x] = A_1; |
1273 | # CHECK: if ( |
1274 | # CHECK-NOT: A_1 = A_1 + 1; |
1275 | # CHECK: A[x] = (A[x] |
1276 | # CHECK: A[0] = |
1277 | # CHECK: A[x] = (A[x] |
1278 | # CHECK: } |
1279 | # CHECK: int A_2 = A[x]; |
1280 | # CHECK: B[x] = A_2; |
1281 | # CHECK: B[x + 1] = A_2; |
1282 | # CHECK: A_2 = C[x]; |
1283 | # CHECK: A[x] = A_2;)IR" ; |
1284 | |
1285 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1286 | } |
1287 | |
1288 | // When accesses are within conditional blocks they are not visible to the wider |
1289 | // program, because we don't know if the branch would be taken and if it isn't |
1290 | // the accesses in it don't need to be valid (think size checks on the index). |
1291 | // In this case the accesses cannot be registerized. |
1292 | TEST(Registerizer, RegisterizerConditionHidden) { |
1293 | BufHandle a("A" , {5}, kInt); |
1294 | BufHandle b("B" , {5}, kInt); |
1295 | BufHandle c("C" , {5}, kInt); |
1296 | VarHandle x("x" , kInt); |
1297 | |
1298 | StmtPtr stmt = Block::make( |
1299 | {Cond::make( |
1300 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1301 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1302 | nullptr), |
1303 | Cond::make( |
1304 | CompareSelect::make(x, 5, CompareSelectOperation::kGT), |
1305 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1306 | nullptr)}); |
1307 | |
1308 | /* |
1309 | * if (x<5 ? 1 : 0) { |
1310 | * A[x] = (A[x]) + 1; |
1311 | * } |
1312 | * if (x>5 ? 1 : 0) { |
1313 | * A[x] = (A[x]) + 1; |
1314 | * } |
1315 | */ |
1316 | |
1317 | std::ostringstream before; |
1318 | before << *stmt; |
1319 | |
1320 | // No change. |
1321 | stmt = registerize(stmt); |
1322 | |
1323 | std::ostringstream after; |
1324 | after << *stmt; |
1325 | |
1326 | ASSERT_EQ(before.str(), after.str()); |
1327 | } |
1328 | |
1329 | // But... if the same access is found in a non conditional scope, that means |
1330 | // that that access is valid in the higher scope (or at least if its not it's |
1331 | // the user's fault). It "unhides" the conditional accesses, allowing |
1332 | // registerization to occur. |
1333 | TEST(Registerizer, RegisterizerConditionUnhidden) { |
1334 | BufHandle a("A" , {5}, kInt); |
1335 | BufHandle b("B" , {5}, kInt); |
1336 | BufHandle c("C" , {5}, kInt); |
1337 | VarHandle x("x" , kInt); |
1338 | |
1339 | StmtPtr stmt = Block::make( |
1340 | {Cond::make( |
1341 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1342 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1343 | nullptr), |
1344 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1345 | Cond::make( |
1346 | CompareSelect::make(x, 5, CompareSelectOperation::kGT), |
1347 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1348 | nullptr)}); |
1349 | |
1350 | /* |
1351 | * if (x<5 ? 1 : 0) { |
1352 | * A[x] = (A[x]) + 1; |
1353 | * } |
1354 | * A[x] = (A[x]) + 1; <-- this is doing the unhiding. |
1355 | * if (x>5 ? 1 : 0) { |
1356 | * A[x] = (A[x]) + 1; |
1357 | * } |
1358 | */ |
1359 | |
1360 | stmt = registerize(stmt); |
1361 | |
1362 | /* |
1363 | * int A_1 = A[x]; |
1364 | * if (x<5 ? 1 : 0) { |
1365 | * A_1 = A_1 + 1; |
1366 | * } |
1367 | * A_1 = A_1 + 1; |
1368 | * if (x>5 ? 1 : 0) { |
1369 | * A_1 = A_1 + 1; |
1370 | * } |
1371 | * A[x] = A_1; |
1372 | */ |
1373 | |
1374 | std::ostringstream oss; |
1375 | oss << *stmt; |
1376 | |
1377 | const std::string& verification_pattern = |
1378 | R"IR( |
1379 | # CHECK: int A_1 = A[x]; |
1380 | # CHECK: if (x<5 |
1381 | # CHECK: A_1 = A_1 + 1; |
1382 | # CHECK: } |
1383 | # CHECK: A_1 = A_1 + 1; |
1384 | # CHECK: if (x>5 |
1385 | # CHECK: A_1 = A_1 + 1; |
1386 | # CHECK: } |
1387 | # CHECK: A[x] = A_1;)IR" ; |
1388 | |
1389 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1390 | } |
1391 | |
1392 | // Can registerize a load that occurs in the condition of a Cond. |
1393 | TEST(Registerizer, RegisterizerCondCondition) { |
1394 | BufHandle a("A" , {5}, kInt); |
1395 | BufHandle b("B" , {5}, kInt); |
1396 | BufHandle c("C" , {5}, kInt); |
1397 | VarHandle x("x" , kInt); |
1398 | |
1399 | StmtPtr stmt = Block::make( |
1400 | {Store::make(a, {x}, Load::make(b, {x})), |
1401 | Store::make(c, {x}, Load::make(a, {x})), |
1402 | Cond::make( |
1403 | CompareSelect::make( |
1404 | Load::make(a, {x}), 5, CompareSelectOperation::kLT), |
1405 | Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)), |
1406 | nullptr)}); |
1407 | |
1408 | /* |
1409 | * A[x] = B[x]; |
1410 | * C[x] = A[x]; |
1411 | * if ((A[x])<5 ? 1 : 0) { |
1412 | * C[x] = (C[x]) + 1; |
1413 | * } |
1414 | */ |
1415 | |
1416 | stmt = registerize(stmt); |
1417 | |
1418 | /* |
1419 | * int A_1 = B[x]; |
1420 | * int C_1 = A_1; |
1421 | * if (A_1<5 ? 1 : 0) { |
1422 | * C_1 = C_1 + 1; |
1423 | * } |
1424 | * C[x] = C_1; |
1425 | */ |
1426 | |
1427 | std::ostringstream oss; |
1428 | oss << *stmt; |
1429 | |
1430 | const std::string& verification_pattern = |
1431 | R"IR( |
1432 | # CHECK: int A_1 = B[x]; |
1433 | # CHECK: int C_1 = A_1; |
1434 | # CHECK: if (A_1<5 |
1435 | # CHECK: C_1 = C_1 + 1; |
1436 | # CHECK: C[x] = C_1;)IR" ; |
1437 | |
1438 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1439 | } |
1440 | |
1441 | // Appearing in the condition of a Cond makes it visible to the enclosing scope, |
1442 | // and so we can registerize internal usages. |
1443 | TEST(Registerizer, RegisterizerCondConditionUnhidden) { |
1444 | BufHandle a("A" , {5}, kInt); |
1445 | BufHandle b("B" , {5}, kInt); |
1446 | BufHandle c("C" , {5}, kInt); |
1447 | VarHandle x("x" , kInt); |
1448 | |
1449 | StmtPtr stmt = Block::make({Cond::make( |
1450 | CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT), |
1451 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), |
1452 | Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))}); |
1453 | |
1454 | /* |
1455 | * if ((A[x])<5 ? 1 : 0) { |
1456 | * A[x] = (A[x]) + 1; |
1457 | * } else { |
1458 | * A[x] = (A[x]) + 10; |
1459 | * } |
1460 | */ |
1461 | |
1462 | stmt = registerize(stmt); |
1463 | |
1464 | /* |
1465 | * int A_1 = A[x]; |
1466 | * if (A_1<5 ? 1 : 0) { |
1467 | * A_1 = A_1 + 1; |
1468 | * } else { |
1469 | * A_1 = A_1 + 10; |
1470 | * } |
1471 | * A[x] = A_1; |
1472 | */ |
1473 | |
1474 | std::ostringstream oss; |
1475 | oss << *stmt; |
1476 | |
1477 | const std::string& verification_pattern = |
1478 | R"IR( |
1479 | # CHECK: int A_1 = A[x]; |
1480 | # CHECK: if (A_1<5 |
1481 | # CHECK: A_1 = A_1 + 1; |
1482 | # CHECK: } else { |
1483 | # CHECK: A_1 = A_1 + 10; |
1484 | # CHECK: } |
1485 | # CHECK: A[x] = A_1;)IR" ; |
1486 | |
1487 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1488 | } |
1489 | |
1490 | // Conditional hiding also works for IfThenElse exprs. |
1491 | TEST(Registerizer, RegisterizerIfThenElseHidden) { |
1492 | BufHandle a("A" , {5}, kInt); |
1493 | BufHandle b("B" , {5}, kInt); |
1494 | BufHandle c("C" , {5}, kInt); |
1495 | VarHandle x("x" , kInt); |
1496 | VarHandle y("y" , kInt); |
1497 | |
1498 | StmtPtr stmt = Block::make( |
1499 | {Store::make( |
1500 | b, |
1501 | {y}, |
1502 | IfThenElse::make( |
1503 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1504 | Add::make(Load::make(a, {x}), 1), |
1505 | Add::make(Load::make(a, {x + 1}), 2))), |
1506 | Store::make( |
1507 | b, |
1508 | {y + 1}, |
1509 | IfThenElse::make( |
1510 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1511 | Add::make(Load::make(a, {x}), 1), |
1512 | Add::make(Load::make(a, {x + 1}), 2)))}); |
1513 | |
1514 | /* |
1515 | * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); |
1516 | * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); |
1517 | */ |
1518 | |
1519 | std::ostringstream before; |
1520 | before << *stmt; |
1521 | |
1522 | // No change. |
1523 | stmt = registerize(stmt); |
1524 | |
1525 | std::ostringstream after; |
1526 | after << *stmt; |
1527 | |
1528 | ASSERT_EQ(before.str(), after.str()); |
1529 | } |
1530 | |
1531 | // Conditional unhiding also works for IfThenElse exprs. |
1532 | TEST(Registerizer, RegisterizerIfThenElseUnhidden) { |
1533 | BufHandle a("A" , {5}, kInt); |
1534 | BufHandle b("B" , {5}, kInt); |
1535 | BufHandle c("C" , {5}, kInt); |
1536 | VarHandle x("x" , kInt); |
1537 | VarHandle y("y" , kInt); |
1538 | |
1539 | StmtPtr stmt = Block::make({ |
1540 | Store::make(a, {x}, 0), |
1541 | Store::make( |
1542 | b, |
1543 | {y}, |
1544 | IfThenElse::make( |
1545 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1546 | Add::make(Load::make(a, {x}), 1), |
1547 | Add::make(Load::make(a, {x + 1}), 2))), |
1548 | Store::make( |
1549 | b, |
1550 | {y + 1}, |
1551 | IfThenElse::make( |
1552 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1553 | Add::make(Load::make(a, {x}), 1), |
1554 | Add::make(Load::make(a, {x + 1}), 2))), |
1555 | }); |
1556 | |
1557 | /* |
1558 | * A[x] = 0; |
1559 | * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); |
1560 | * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); |
1561 | */ |
1562 | |
1563 | stmt = registerize(stmt); |
1564 | |
1565 | /* |
1566 | * int A_1 = 0; |
1567 | * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); |
1568 | * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); |
1569 | * A[x] = A_1; |
1570 | */ |
1571 | |
1572 | std::ostringstream oss; |
1573 | oss << *stmt; |
1574 | |
1575 | const std::string& verification_pattern = |
1576 | R"IR( |
1577 | # CHECK: int A_1 = 0; |
1578 | # CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); |
1579 | # CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); |
1580 | # CHECK: A[x] = A_1;)IR" ; |
1581 | |
1582 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1583 | } |
1584 | |
1585 | // Nested IfThenElse exprs can't promote to higher level scopes. |
1586 | TEST(Registerizer, RegisterizerIfThenElseNested) { |
1587 | BufHandle a("A" , {5}, kInt); |
1588 | BufHandle b("B" , {5}, kInt); |
1589 | BufHandle c("C" , {5}, kInt); |
1590 | BufHandle d("D" , {5}, kInt); |
1591 | VarHandle x("x" , kInt); |
1592 | |
1593 | StmtPtr stmt = Block::make({Store::make( |
1594 | a, |
1595 | {x}, |
1596 | IfThenElse::make( |
1597 | CompareSelect::make(x, 3, CompareSelectOperation::kLT), |
1598 | IfThenElse::make( |
1599 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
1600 | Load::make(d, {x}), |
1601 | Load::make(b, {x})), |
1602 | IfThenElse::make( |
1603 | CompareSelect::make(x, 5, CompareSelectOperation::kEQ), |
1604 | Load::make(c, {x}), |
1605 | Load::make(d, {x}))))}); |
1606 | |
1607 | /* |
1608 | * A[x] = IfThenElse(x<3 ? 1 : 0, |
1609 | * IfThenElse(x==2 ? 1 : 0, D[x], B[x]), |
1610 | * IfThenElse(x==5 ? 1 : 0, C[x], D[x])); |
1611 | */ |
1612 | |
1613 | std::ostringstream before; |
1614 | before << *stmt; |
1615 | |
1616 | // No change. |
1617 | stmt = registerize(stmt); |
1618 | |
1619 | std::ostringstream after; |
1620 | after << *stmt; |
1621 | |
1622 | ASSERT_EQ(before.str(), after.str()); |
1623 | } |
1624 | |
1625 | // Cannot registerize an access completely contained within an IfThenElse |
1626 | // branch, since it is not a Stmt and cannot hold variable definitions. We need |
1627 | // to check that we don't promote the initializer/finalizer to the enclosing |
1628 | // Block. |
1629 | TEST(Registerizer, RegisterizerIfThenElseInternal) { |
1630 | // Making these floats so they don't get simplified to a single access. |
1631 | BufHandle a("A" , {5}, kFloat); |
1632 | BufHandle b("B" , {5}, kFloat); |
1633 | VarHandle x("x" , kInt); |
1634 | |
1635 | StmtPtr stmt = Block::make({Store::make( |
1636 | a, |
1637 | {x}, |
1638 | IfThenElse::make( |
1639 | CompareSelect::make(x, 3, CompareSelectOperation::kLT), |
1640 | Add::make(Load::make(b, {x}), Load::make(b, {x})), |
1641 | Load::make(b, {x})))}); |
1642 | |
1643 | /* |
1644 | * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]); |
1645 | */ |
1646 | |
1647 | std::ostringstream before; |
1648 | before << *stmt; |
1649 | |
1650 | // No change. |
1651 | stmt = registerize(stmt); |
1652 | |
1653 | std::ostringstream after; |
1654 | after << *stmt; |
1655 | |
1656 | ASSERT_EQ(before.str(), after.str()); |
1657 | |
1658 | // If this was a Cond instead of an IfThenElse then we could registerize the |
1659 | // two accesses to B[x] in the True branch. |
1660 | |
1661 | // Actually lets verify that. |
1662 | |
1663 | stmt = Block::make({Cond::make( |
1664 | CompareSelect::make(x, 3, CompareSelectOperation::kLT), |
1665 | Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))), |
1666 | Store::make(a, {x}, Load::make(b, {x})))}); |
1667 | |
1668 | /* |
1669 | * if (x<3 ? 1 : 0) { |
1670 | * A[x] = (B[x]) + (B[x]); |
1671 | * } else { |
1672 | * A[x] = B[x]; |
1673 | * } |
1674 | */ |
1675 | |
1676 | stmt = registerize(stmt); |
1677 | |
1678 | /* |
1679 | * if (x<3 ? 1 : 0) { |
1680 | * float B_1 = B[x]; |
1681 | * A[x] = B_1 + B_1; |
1682 | * } else { |
1683 | * A[x] = B[x]; |
1684 | * } |
1685 | */ |
1686 | |
1687 | std::ostringstream oss; |
1688 | oss << *stmt; |
1689 | |
1690 | const std::string& verification_pattern = |
1691 | R"IR( |
1692 | # CHECK-NOT: int |
1693 | # CHECK-NOT: float |
1694 | # CHECK: if (x<3 |
1695 | # CHECK: float B_1 = |
1696 | # CHECK: A[x] = B_1 + B_1 |
1697 | # CHECK: } else { |
1698 | # CHECK: A[x] = B[x] |
1699 | # CHECK: } |
1700 | # CHECK-NOT: A[x] |
1701 | # CHECK-NOT: B[x])IR" ; |
1702 | |
1703 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1704 | } |
1705 | |
1706 | // Can registerize a load that occurs in the condition of an IfThenElse; |
1707 | TEST(Registerizer, RegisterizerIfThenElseCondition) { |
1708 | BufHandle a("A" , {5}, kInt); |
1709 | BufHandle b("B" , {5}, kInt); |
1710 | BufHandle c("C" , {5}, kInt); |
1711 | VarHandle x("x" , kInt); |
1712 | |
1713 | StmtPtr stmt = Block::make( |
1714 | {Store::make(a, {x}, Load::make(a, {x})), |
1715 | Store::make( |
1716 | a, |
1717 | {x}, |
1718 | IfThenElse::make( |
1719 | CompareSelect::make( |
1720 | Load::make(a, {x}), 5, CompareSelectOperation::kLT), |
1721 | Load::make(b, {0}), |
1722 | Load::make(c, {0})))}); |
1723 | |
1724 | /* |
1725 | * A[x] = A[x]; <---- just here so there are enough accesses to combine. |
1726 | * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]); |
1727 | */ |
1728 | |
1729 | stmt = registerize(stmt); |
1730 | |
1731 | /* |
1732 | * int A_1 = A[x]; |
1733 | * A_1 = A_1; |
1734 | * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); |
1735 | * A[x] = A_1; |
1736 | */ |
1737 | |
1738 | std::ostringstream oss; |
1739 | oss << *stmt; |
1740 | |
1741 | const std::string& verification_pattern = |
1742 | R"IR( |
1743 | # CHECK: int A_1 = A[x]; |
1744 | # CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); |
1745 | # CHECK: A[x] = A_1;)IR" ; |
1746 | |
1747 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1748 | } |
1749 | |
1750 | // Appearing in the condition of a Cond makes it visible to the enclosing scope, |
1751 | // and so we can registerize internal usages. |
1752 | TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) { |
1753 | BufHandle a("A" , {5}, kInt); |
1754 | BufHandle b("B" , {5}, kInt); |
1755 | BufHandle c("C" , {5}, kInt); |
1756 | VarHandle x("x" , kInt); |
1757 | |
1758 | StmtPtr stmt = Block::make({Store::make( |
1759 | b, |
1760 | {x}, |
1761 | IfThenElse::make( |
1762 | CompareSelect::make( |
1763 | Load::make(a, {x}), 5, CompareSelectOperation::kLT), |
1764 | Add::make(Load::make(a, {x}), 1), |
1765 | Add::make(Load::make(a, {x}), 10)))}); |
1766 | |
1767 | /* |
1768 | * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10); |
1769 | */ |
1770 | |
1771 | stmt = registerize(stmt); |
1772 | |
1773 | /* |
1774 | * int A_1 = A[x]; |
1775 | * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10); |
1776 | */ |
1777 | |
1778 | std::ostringstream oss; |
1779 | oss << *stmt; |
1780 | |
1781 | const std::string& verification_pattern = |
1782 | R"IR( |
1783 | # CHECK: int A_1 = A[x]; |
1784 | # CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR" ; |
1785 | |
1786 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1787 | } |
1788 | |
1789 | // Cannot promote accesses internal to IfThenElse branches even if the enclosing |
1790 | // scope if conditional. |
1791 | TEST(Registerizer, RegisterizerConditionBranchOnly) { |
1792 | BufHandle a("A" , {5}, kInt); |
1793 | VarHandle x("x" , kInt); |
1794 | StmtPtr stmt = Block::make({For::make( |
1795 | x, |
1796 | 0, |
1797 | 10, |
1798 | Block::make({ |
1799 | Cond::make( |
1800 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1801 | Store::make( |
1802 | a, |
1803 | {x}, |
1804 | IfThenElse::make( |
1805 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1806 | Add::make(Load::make(a, {x}), x), |
1807 | Add::make(Load::make(a, {x - 5}), x))), |
1808 | Store::make( |
1809 | a, |
1810 | {x - 5}, |
1811 | IfThenElse::make( |
1812 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
1813 | Add::make(Load::make(a, {x}), x), |
1814 | Add::make(Load::make(a, {x - 5}), x)))), |
1815 | }))}); |
1816 | stmt = IRSimplifier::simplify(stmt); |
1817 | |
1818 | std::ostringstream before; |
1819 | before << *stmt; |
1820 | |
1821 | /* for (int x = 0; x < 10; x++) { |
1822 | * if (x<5 ? 1 : 0) { |
1823 | * A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); |
1824 | * } else { |
1825 | * A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); |
1826 | * } |
1827 | * } |
1828 | */ |
1829 | |
1830 | // No change. |
1831 | stmt = registerize(stmt); |
1832 | |
1833 | std::ostringstream after; |
1834 | after << *stmt; |
1835 | |
1836 | ASSERT_EQ(before.str(), after.str()); |
1837 | } |
1838 | |
1839 | // We can registerize an IfThenElse that appears in the condition branch of a |
1840 | // Cond. This is a weird but valid thing to do. |
1841 | TEST(Registerizer, RegisterizerCondIfThenElse) { |
1842 | BufHandle a("A" , {5}, kInt); |
1843 | BufHandle b("B" , {5}, kInt); |
1844 | BufHandle c("C" , {5}, kInt); |
1845 | VarHandle x("x" , kInt); |
1846 | |
1847 | StmtPtr stmt = Block::make({Cond::make( |
1848 | CompareSelect::make( |
1849 | IfThenElse::make( |
1850 | CompareSelect::make( |
1851 | Load::make(a, {x}), 5, CompareSelectOperation::kLT), |
1852 | Load::make(a, {x}), |
1853 | Load::make(b, {x})), |
1854 | x, |
1855 | CompareSelectOperation::kEQ), |
1856 | Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)), |
1857 | nullptr)}); |
1858 | |
1859 | /* |
1860 | * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) { |
1861 | * C[x] = (C[x]) + 1; |
1862 | * } |
1863 | */ |
1864 | |
1865 | stmt = registerize(stmt); |
1866 | |
1867 | // access to A can be registerized, but not B or C |
1868 | |
1869 | /* |
1870 | * int A_1 = A[x]; |
1871 | * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) { |
1872 | * C[x] = (C[x]) + 1; |
1873 | * } |
1874 | */ |
1875 | |
1876 | std::ostringstream oss; |
1877 | oss << *stmt; |
1878 | |
1879 | const std::string& verification_pattern = |
1880 | R"IR( |
1881 | # CHECK: int A_1 = A[x]; |
1882 | # CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x] |
1883 | # CHECK: C[x] = (C[x]) + 1;)IR" ; |
1884 | |
1885 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1886 | } |
1887 | |
1888 | // Can registerize a conditional access in the RHS of a store unhidden by it's |
1889 | // LHS, and hoist it out of a loop. |
1890 | TEST(Registerizer, RegisterizerIfThenElseLoop) { |
1891 | BufHandle a("A" , {5}, kInt); |
1892 | BufHandle b("B" , {5}, kInt); |
1893 | VarHandle x("x" , kInt); |
1894 | VarHandle y("y" , kInt); |
1895 | |
1896 | StmtPtr stmt = For::make( |
1897 | y, |
1898 | 0, |
1899 | 10, |
1900 | Store::make( |
1901 | a, |
1902 | {x}, |
1903 | IfThenElse::make( |
1904 | CompareSelect::make(x, 3, CompareSelectOperation::kLT), |
1905 | Load::make(a, {x}), |
1906 | Load::make(b, {y})))); |
1907 | |
1908 | /* |
1909 | * for (int y = 0; y < 10; y++) { |
1910 | * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]); |
1911 | * } |
1912 | */ |
1913 | |
1914 | stmt = registerize(stmt); |
1915 | |
1916 | /* |
1917 | * int A_1 = A[x]; |
1918 | * for (int y = 0; y < 10; y++) { |
1919 | * A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); |
1920 | * } |
1921 | * A[x] = A_1; |
1922 | */ |
1923 | |
1924 | std::ostringstream oss; |
1925 | oss << *stmt; |
1926 | |
1927 | const std::string& verification_pattern = |
1928 | R"IR( |
1929 | # CHECK: int A_1 = A[x]; |
1930 | # CHECK: for ( |
1931 | # CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); |
1932 | # CHECK: } |
1933 | # CHECK: A[x] = A_1;)IR" ; |
1934 | |
1935 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1936 | } |
1937 | |
1938 | // Cannot registerize if the RHS overlaps the access creating visibility. |
1939 | TEST(Registerizer, RegisterizerIfThenElseLoopCut) { |
1940 | BufHandle a("A" , {5}, kInt); |
1941 | BufHandle b("B" , {5}, kInt); |
1942 | VarHandle x("x" , kInt); |
1943 | VarHandle y("y" , kInt); |
1944 | |
1945 | StmtPtr stmt = Block::make({For::make( |
1946 | y, |
1947 | 0, |
1948 | 10, |
1949 | Store::make( |
1950 | a, |
1951 | {x}, |
1952 | IfThenElse::make( |
1953 | CompareSelect::make(x, 3, CompareSelectOperation::kLT), |
1954 | Load::make(a, {x}), |
1955 | Load::make(a, {y}))))}); |
1956 | |
1957 | /* |
1958 | * for (int y = 0; y < 10; y++) { |
1959 | * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]); |
1960 | * } |
1961 | */ |
1962 | |
1963 | std::ostringstream before; |
1964 | before << *stmt; |
1965 | |
1966 | // No change. |
1967 | stmt = registerize(stmt); |
1968 | |
1969 | std::ostringstream after; |
1970 | after << *stmt; |
1971 | |
1972 | ASSERT_EQ(before.str(), after.str()); |
1973 | } |
1974 | |
1975 | // Simple case where an access is cut by an overlapping access later in the |
1976 | // program, we can registerize up until the overlap. |
1977 | TEST(Registerizer, RegisterizerPartialAfter) { |
1978 | BufHandle a("A" , {1}, kInt); |
1979 | VarHandle x("x" , kInt); |
1980 | StmtPtr stmt = Block::make( |
1981 | {Store::make(a, {0}, 0), |
1982 | For::make( |
1983 | x, |
1984 | 0, |
1985 | 10, |
1986 | Block::make( |
1987 | {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})), |
1988 | For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))}); |
1989 | |
1990 | /* |
1991 | * A[0] = 0; |
1992 | * for (int x = 0; x < 10; x++) { |
1993 | * A[0] = (A[0]) + x; |
1994 | * } |
1995 | * for (int x = 1; x < 10; x++) { |
1996 | * A[x] = A[x - 1]; |
1997 | * } |
1998 | */ |
1999 | |
2000 | stmt = registerize(stmt); |
2001 | |
2002 | /* |
2003 | * int A_1 = 0; |
2004 | * for (int x = 0; x < 10; x++) { |
2005 | * A_1 = A_1 + x; |
2006 | * } |
2007 | * A[0] = A_1; |
2008 | * for (int x = 1; x < 10; x++) { |
2009 | * A[x] = A[x - 1]; |
2010 | * } |
2011 | */ |
2012 | |
2013 | std::ostringstream oss; |
2014 | oss << *stmt; |
2015 | |
2016 | const std::string& verification_pattern = |
2017 | R"IR( |
2018 | # CHECK: int A_1 = 0; |
2019 | # CHECK: for ( |
2020 | # CHECK: A_1 = A_1 + x; |
2021 | # CHECK: } |
2022 | # CHECK: A[0] = A_1; |
2023 | # CHECK: for ( |
2024 | # CHECK: A[x] = A[x - 1]; |
2025 | # CHECK: } |
2026 | # CHECK-NOT: A)IR" ; |
2027 | |
2028 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2029 | } |
2030 | |
2031 | // We can registerize an access which overlaps a previous access, the |
2032 | // initializer must be inserted after the previous access. |
2033 | TEST(Registerizer, RegisterizerPartialBefore) { |
2034 | BufHandle a("A" , {1}, kInt); |
2035 | VarHandle x("x" , kInt); |
2036 | StmtPtr stmt = Block::make( |
2037 | {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), |
2038 | Store::make(a, {0}, 0), |
2039 | For::make( |
2040 | x, |
2041 | 0, |
2042 | 10, |
2043 | Block::make( |
2044 | {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); |
2045 | |
2046 | /* |
2047 | * for (int x = 1; x < 10; x++) { |
2048 | * A[x] = A[x - 1]; |
2049 | * } |
2050 | * A[0] = 0; |
2051 | * for (int x = 0; x < 10; x++) { |
2052 | * A[0] = (A[0]) + x; |
2053 | * } |
2054 | */ |
2055 | |
2056 | stmt = registerize(stmt); |
2057 | |
2058 | /* |
2059 | * for (int x = 1; x < 10; x++) { |
2060 | * A[x] = A[x - 1]; |
2061 | * } |
2062 | * int A_1 = 0; |
2063 | * for (int x = 0; x < 10; x++) { |
2064 | * A_1 = A_1 + x; |
2065 | * } |
2066 | * A[0] = A_1; |
2067 | */ |
2068 | |
2069 | std::ostringstream oss; |
2070 | oss << *stmt; |
2071 | |
2072 | const std::string& verification_pattern = |
2073 | R"IR( |
2074 | # CHECK-NOT: int |
2075 | # CHECK: for ( |
2076 | # CHECK: A[x] = A[x - 1]; |
2077 | # CHECK: } |
2078 | # CHECK: int A_1 = 0; |
2079 | # CHECK: for ( |
2080 | # CHECK: A_1 = A_1 + x; |
2081 | # CHECK: } |
2082 | # CHECK: A[0] = A_1;)IR" ; |
2083 | |
2084 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2085 | } |
2086 | |
2087 | // The combination of the previous two tests, an access is cut by an overlapping |
2088 | // access in both directions. |
2089 | TEST(Registerizer, RegisterizerPartialInside) { |
2090 | BufHandle a("A" , {1}, kInt); |
2091 | VarHandle x1("x1" , kInt); |
2092 | VarHandle x2("x2" , kInt); |
2093 | VarHandle x3("x3" , kInt); |
2094 | StmtPtr stmt = Block::make( |
2095 | {Store::make(a, {0}, 2), |
2096 | For::make( |
2097 | x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))), |
2098 | For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))), |
2099 | For::make( |
2100 | x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))}); |
2101 | |
2102 | /* |
2103 | * A[0] = 2; |
2104 | * for (int x1 = 0; x1 < 10; x1++) { |
2105 | * A[0] = (A[0]) + x1; |
2106 | * } |
2107 | * for (int x2 = 1; x2 < 10; x2++) { |
2108 | * A[x2] = A[x2 - 1]; |
2109 | * } |
2110 | * for (int x3 = 0; x3 < 10; x3++) { |
2111 | * A[0] = (A[0]) + x3; |
2112 | * } |
2113 | */ |
2114 | |
2115 | stmt = registerize(stmt); |
2116 | |
2117 | /* |
2118 | * int A_1 = 2; |
2119 | * for (int x1 = 0; x1 < 10; x1++) { |
2120 | * A_1 = A_1 + x1; |
2121 | * } |
2122 | * A[0] = A_1; |
2123 | * for (int x2 = 1; x2 < 10; x2++) { |
2124 | * A[x2] = A[x2 - 1]; |
2125 | * } |
2126 | * int A_2 = A[0]; |
2127 | * for (int x3 = 0; x3 < 10; x3++) { |
2128 | * A_2 = A_2 + x3; |
2129 | * } |
2130 | * A[0] = A_2; |
2131 | */ |
2132 | |
2133 | std::ostringstream oss; |
2134 | oss << *stmt; |
2135 | |
2136 | const std::string& verification_pattern = |
2137 | R"IR( |
2138 | # CHECK: int A_1 = 2; |
2139 | # CHECK: for ( |
2140 | # CHECK: A_1 = A_1 + x1; |
2141 | # CHECK: } |
2142 | # CHECK: A[0] = A_1; |
2143 | # CHECK: for ( |
2144 | # CHECK: A[x2] = |
2145 | # CHECK: } |
2146 | # CHECK: int A_2 = A[0]; |
2147 | # CHECK: for ( |
2148 | # CHECK: A_2 = A_2 + x3; |
2149 | # CHECK: } |
2150 | # CHECK: A[0] = A_2;)IR" ; |
2151 | |
2152 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2153 | } |
2154 | |
2155 | // An element could be registerized program wide but is cut by a conditional |
2156 | // access, we should break this into two scalars and write back to the buffer |
2157 | // before the condition. |
2158 | TEST(Registerizer, RegisterizerPartialCondition) { |
2159 | BufHandle a("A" , {1}, kInt); |
2160 | VarHandle x("x" , kInt); |
2161 | StmtPtr stmt = Block::make( |
2162 | {Store::make(a, {0}, 2), |
2163 | For::make( |
2164 | x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))), |
2165 | Cond::make( |
2166 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
2167 | Store::make(a, {x}, Load::make(a, {x - 1})), |
2168 | nullptr), |
2169 | For::make( |
2170 | x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))}); |
2171 | |
2172 | /* |
2173 | * A[0] = 2; |
2174 | * for (int x = 0; x < 10; x++) { |
2175 | * A[0] = (A[0]) + x; |
2176 | * } |
2177 | * if (x<5 ? 1 : 0) { |
2178 | * A[x] = A[x - 1]; |
2179 | * } |
2180 | * for (int x = 0; x < 10; x++) { |
2181 | * A[0] = (A[0]) + x; |
2182 | * } |
2183 | */ |
2184 | |
2185 | stmt = registerize(stmt); |
2186 | |
2187 | /* |
2188 | * int A_1 = 2; |
2189 | * for (int x = 0; x < 10; x++) { |
2190 | * A_1 = A_1 + x; |
2191 | * } |
2192 | * A[0] = A_1; |
2193 | * if (x<5 ? 1 : 0) { |
2194 | * A[x] = A[x - 1]; |
2195 | * } |
2196 | * int A_2 = A[0]; |
2197 | * for (int x = 0; x < 10; x++) { |
2198 | * A_2 = A_2 + x; |
2199 | * } |
2200 | * A[0] = A_2; |
2201 | */ |
2202 | |
2203 | std::ostringstream oss; |
2204 | oss << *stmt; |
2205 | |
2206 | const std::string& verification_pattern = |
2207 | R"IR( |
2208 | # CHECK: int A_1 = 2; |
2209 | # CHECK: for ( |
2210 | # CHECK: A_1 = A_1 + x; |
2211 | # CHECK: } |
2212 | # CHECK: A[0] = A_1; |
2213 | # CHECK: if ( |
2214 | # CHECK: A[x] = |
2215 | # CHECK: } |
2216 | # CHECK: int A_2 = A[0]; |
2217 | # CHECK: for ( |
2218 | # CHECK: A_2 = A_2 + x; |
2219 | # CHECK: } |
2220 | # CHECK: A[0] = A_2;)IR" ; |
2221 | |
2222 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2223 | } |
2224 | |
2225 | // Tests case where an access is cut by an internal conditional access which |
2226 | // itself is registerized. |
2227 | TEST(Registerizer, RegisterizerPartialConditionInternalCut) { |
2228 | BufHandle a("A" , {1}, kInt); |
2229 | VarHandle x("x" , kInt); |
2230 | StmtPtr stmt = Block::make( |
2231 | {Store::make(a, {0}, 1), |
2232 | Store::make(a, {0}, 3), |
2233 | Cond::make( |
2234 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
2235 | Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), |
2236 | nullptr), |
2237 | Store::make(a, {0}, 4), |
2238 | Store::make(a, {0}, 6)}); |
2239 | |
2240 | /* |
2241 | * A[0] = 1; |
2242 | * A[0] = 3; |
2243 | * if (x<5 ? 1 : 0) { |
2244 | * A[x] = 1; |
2245 | * A[x] = 3; |
2246 | * } |
2247 | * A[0] = 4; |
2248 | * A[0] = 6; |
2249 | */ |
2250 | |
2251 | stmt = registerize(stmt); |
2252 | |
2253 | /* |
2254 | * int A_1 = 1; |
2255 | * A_1 = 3; |
2256 | * A[0] = A_1; |
2257 | * if (x<5 ? 1 : 0) { |
2258 | * int A_2 = 1; |
2259 | * A_2 = 3; |
2260 | * A[x] = A_2; |
2261 | * } |
2262 | * int A_3 = 4; |
2263 | * A_3 = 6; |
2264 | * A[0] = A_3; |
2265 | */ |
2266 | |
2267 | std::ostringstream oss; |
2268 | oss << *stmt; |
2269 | |
2270 | const std::string& verification_pattern = |
2271 | R"IR( |
2272 | # CHECK: int A_1 = 1; |
2273 | # CHECK: A_1 = 3 |
2274 | # CHECK: A[0] = A_1; |
2275 | # CHECK: if ( |
2276 | # CHECK: int A_2 = 1; |
2277 | # CHECK: A_2 = 3; |
2278 | # CHECK: A[x] = A_2; |
2279 | # CHECK: } |
2280 | # CHECK: int A_3 = 4; |
2281 | # CHECK: A_3 = 6; |
2282 | # CHECK: A[0] = A_3;)IR" ; |
2283 | |
2284 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2285 | } |
2286 | |
// The first statement in the condition closes the outer access, but can be
// registerized with later statements.
2289 | TEST(Registerizer, RegisterizerPartialConditionInternalStart) { |
2290 | BufHandle a("A" , {1}, kInt); |
2291 | VarHandle x("x" , kInt); |
2292 | StmtPtr stmt = Block::make( |
2293 | {Store::make(a, {0}, 1), |
2294 | Store::make(a, {0}, 3), |
2295 | Cond::make( |
2296 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
2297 | Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), |
2298 | nullptr), |
2299 | Store::make(a, {x}, 4), |
2300 | Store::make(a, {x}, 6)}); |
2301 | |
2302 | /* |
2303 | * A[0] = 1; |
2304 | * A[0] = 3; |
2305 | * if (x<5 ? 1 : 0) { |
2306 | * A[x] = 1; |
2307 | * A[x] = 3; |
2308 | * } |
2309 | * A[x] = 4; |
2310 | * A[x] = 6; |
2311 | */ |
2312 | |
2313 | stmt = registerize(stmt); |
2314 | |
2315 | /* |
2316 | * int A_1 = 1; |
2317 | * A_1 = 3; |
2318 | * A[0] = A_1; |
2319 | * int A_2 = A[x]; <--- must read from the input here. |
2320 | * if (x<5 ? 1 : 0) { |
2321 | * A_2 = 1; |
2322 | * A_2 = 3; |
2323 | * } |
2324 | * A_2 = 4; |
2325 | * A_2 = 6; |
2326 | * A[x] = A_2; |
2327 | */ |
2328 | |
2329 | // TODO: I suppose we could refactor with a conditional initializier? |
2330 | |
2331 | std::ostringstream oss; |
2332 | oss << *stmt; |
2333 | |
2334 | const std::string& verification_pattern = |
2335 | R"IR( |
2336 | # CHECK: int A_1 = 1; |
2337 | # CHECK: A_1 = 3 |
2338 | # CHECK: A[0] = A_1; |
2339 | # CHECK: int A_2 = A[x]; |
2340 | # CHECK: if ( |
2341 | # CHECK: A_2 = 1; |
2342 | # CHECK: A_2 = 3; |
2343 | # CHECK: } |
2344 | # CHECK: A_2 = 4; |
2345 | # CHECK: A_2 = 6; |
2346 | # CHECK: A[x] = A_2;)IR" ; |
2347 | |
2348 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2349 | } |
2350 | |
2351 | // An access cuts two open overlaps and creates four scalar variables. |
2352 | TEST(Registerizer, RegisterizerPartialOverlapsTwo) { |
2353 | BufHandle a("A" , {1}, kInt); |
2354 | VarHandle x("x" , kInt); |
2355 | StmtPtr stmt = Block::make( |
2356 | {Store::make(a, {1}, Load::make(a, {0})), |
2357 | Store::make(a, {0}, Load::make(a, {1})), |
2358 | Store::make(a, {0}, Load::make(a, {1})), |
2359 | For::make(x, 1, 10, Store::make(a, {x}, x)), |
2360 | Store::make(a, {1}, Load::make(a, {0})), |
2361 | Store::make(a, {0}, Load::make(a, {1})), |
2362 | Store::make(a, {0}, Load::make(a, {1}))}); |
2363 | |
2364 | /* |
2365 | * A[1] = A[0]; |
2366 | * A[0] = A[1]; |
2367 | * A[0] = A[1]; |
2368 | * for (int x = 1; x < 10; x++) { |
2369 | * A[x] = x; |
2370 | * } |
2371 | * A[1] = A[0]; |
2372 | * A[0] = A[1]; |
2373 | * A[0] = A[1]; |
2374 | */ |
2375 | |
2376 | stmt = registerize(stmt); |
2377 | |
2378 | /* |
2379 | * int A_1 = A[0]; |
2380 | * int A_2 = A_1; |
2381 | * A_1 = A_2; |
2382 | * A_1 = A_2; |
2383 | * A[1] = A_2; |
2384 | * A[0] = A_1; |
2385 | * for (int x = 1; x < 10; x++) { |
2386 | * A[x] = x; |
2387 | * } |
2388 | * int A_3 = A[0]; |
2389 | * int A_4 = A_3; |
2390 | * A_3 = A_4; |
2391 | * A_3 = A_4; |
2392 | * A[1] = A_4; |
2393 | * A[0] = A_3; |
2394 | */ |
2395 | |
2396 | std::ostringstream oss; |
2397 | oss << *stmt; |
2398 | |
2399 | const std::string& verification_pattern = |
2400 | R"IR( |
2401 | # CHECK: int A_1 = A[0]; |
2402 | # CHECK: int A_2 = A_1; |
2403 | # CHECK: A_1 = A_2; |
2404 | # CHECK: A_1 = A_2; |
2405 | # CHECK: A[1] = A_2; |
2406 | # CHECK: A[0] = A_1; |
2407 | # CHECK: for ( |
2408 | # CHECK: A[x] = x; |
2409 | # CHECK: } |
2410 | # CHECK: int A_3 = A[0]; |
2411 | # CHECK: int A_4 = A_3; |
2412 | # CHECK: A_3 = A_4; |
2413 | # CHECK: A_3 = A_4; |
2414 | # CHECK: A[1] = A_4; |
2415 | # CHECK: A[0] = A_3;)IR" ; |
2416 | |
2417 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2418 | } |
2419 | |
// Nested blocks will automatically be flattened and do not prevent
// registerization of enclosed accesses.
2422 | TEST(Registerizer, RegisterizerNestedBlocks) { |
2423 | BufHandle a("A" , {1}, kInt); |
2424 | VarHandle x("x" , kInt); |
2425 | StmtPtr stmt = Block::make( |
2426 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
2427 | {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2428 | Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}), |
2429 | Block::make( |
2430 | {Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)), |
2431 | Block::make( |
2432 | {Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})}); |
2433 | |
2434 | /* |
2435 | * A[0] = (A[0]) + 1; |
2436 | * { |
2437 | * A[0] = (A[0]) + 2; |
2438 | * } |
2439 | * { |
2440 | * A[0] = (A[0]) + 3; |
2441 | * { |
2442 | * A[0] = (A[0]) + 4; |
2443 | * } |
2444 | * } |
2445 | */ |
2446 | |
2447 | stmt = registerize(stmt); |
2448 | |
2449 | /* |
2450 | * int A_1 = A[0]; |
2451 | * A_1 = A_1 + 1; |
2452 | * A_1 = A_1 + 2; |
2453 | * A_1 = A_1 + 3; |
2454 | * A_1 = A_1 + 4; |
2455 | * A[0] = A_1; |
2456 | */ |
2457 | |
2458 | std::ostringstream oss; |
2459 | oss << *stmt; |
2460 | |
2461 | const std::string& verification_pattern = |
2462 | R"IR( |
2463 | # CHECK: int A_1 = A[0]; |
2464 | # CHECK: A_1 = A_1 + 1; |
2465 | # CHECK: A_1 = A_1 + 2; |
2466 | # CHECK: A_1 = A_1 + 3; |
2467 | # CHECK: A_1 = A_1 + 4; |
2468 | # CHECK: A[0] = A_1;)IR" ; |
2469 | |
2470 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2471 | } |
2472 | |
2473 | // The access can be registerized internally to a condition, but must ensure |
2474 | // that both initializer and finalizer are within the same condition. |
2475 | TEST(Registerizer, RegisterizerNestedConditions) { |
2476 | BufHandle a("A" , {1}, kInt); |
2477 | VarHandle x("x" , kInt); |
2478 | StmtPtr stmt = Block::make({Cond::make( |
2479 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
2480 | Block::make( |
2481 | {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2482 | Cond::make( |
2483 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
2484 | Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2485 | nullptr)}), |
2486 | nullptr)}); |
2487 | |
2488 | /* |
2489 | * if (x<5 ? 1 : 0) { |
2490 | * A[0] = (A[0]) + 1; |
2491 | * if (x==2 ? 1 : 0) { |
2492 | * |
2493 | * A[0] = (A[0]) + 1; |
2494 | * } |
2495 | * } |
2496 | */ |
2497 | |
2498 | stmt = registerize(stmt); |
2499 | |
2500 | /* |
2501 | * if (x<5 ? 1 : 0) { |
2502 | * int A_1 = A[0]; |
2503 | * A_1 = A_1 + 1; |
2504 | * if (x==2 ? 1 : 0) { |
2505 | * A_1 = A_1 + 1; |
2506 | * } |
2507 | * A[0] = A_1; |
2508 | * } |
2509 | */ |
2510 | |
2511 | std::ostringstream oss; |
2512 | oss << *stmt; |
2513 | |
2514 | const std::string& verification_pattern = |
2515 | R"IR( |
2516 | # CHECK: if (x<5 |
2517 | # CHECK: int A_1 = A[0]; |
2518 | # CHECK: A_1 = A_1 + 1; |
2519 | # CHECK: if (x==2 |
2520 | # CHECK: A_1 = A_1 + 1; |
2521 | # CHECK: } |
2522 | # CHECK: A[0] = A_1; |
2523 | # CHECK: })IR" ; |
2524 | |
2525 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2526 | } |
2527 | |
2528 | // If an access exists outside the scope of the condition then we can lift |
2529 | // nested conditional usages into the same scalar. |
2530 | TEST(Registerizer, RegisterizerNestedConditionsUnhidden) { |
2531 | BufHandle a("A" , {1}, kInt); |
2532 | VarHandle x("x" , kInt); |
2533 | StmtPtr stmt = Block::make( |
2534 | {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2535 | Cond::make( |
2536 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
2537 | Block::make( |
2538 | {Store::make(a, {1}, 1), |
2539 | Cond::make( |
2540 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
2541 | Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2542 | nullptr)}), |
2543 | nullptr)}); |
2544 | |
2545 | /* |
2546 | * A[0] = (A[0]) + 1; |
2547 | * if (x<5 ? 1 : 0) { |
2548 | * A[1] = 1; |
2549 | * if (x==2 ? 1 : 0) { |
2550 | * A[0] = (A[0]) + 1; |
2551 | * } |
2552 | * } |
2553 | */ |
2554 | |
2555 | stmt = registerize(stmt); |
2556 | |
2557 | /* |
2558 | * int A_1 = A[0]; |
2559 | * A_1 = A_1 + 1; |
2560 | * if (x<5 ? 1 : 0) { |
2561 | * A[1] = 1; |
2562 | * if (x==2 ? 1 : 0) { |
2563 | * A_1 = A_1 + 1; |
2564 | * } |
2565 | * } |
2566 | * A[0] = A_1; |
2567 | */ |
2568 | |
2569 | std::ostringstream oss; |
2570 | oss << *stmt; |
2571 | |
2572 | const std::string& verification_pattern = |
2573 | R"IR( |
2574 | # CHECK: int A_1 = A[0]; |
2575 | # CHECK: A_1 = A_1 + 1; |
2576 | # CHECK: if (x<5 |
2577 | # CHECK: A[1] = 1; |
2578 | # CHECK: if (x==2 |
2579 | # CHECK: A_1 = A_1 + 1; |
2580 | # CHECK: A[0] = A_1;)IR" ; |
2581 | |
2582 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2583 | } |
2584 | |
2585 | TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) { |
2586 | BufHandle a("A" , {1}, kInt); |
2587 | VarHandle x("x" , kInt); |
2588 | StmtPtr stmt = Block::make( |
2589 | {Cond::make( |
2590 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
2591 | Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2592 | nullptr), |
2593 | Cond::make( |
2594 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
2595 | Block::make({Cond::make( |
2596 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
2597 | Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2598 | nullptr)}), |
2599 | nullptr)}); |
2600 | |
2601 | /* |
2602 | * if (x==2 ? 1 : 0) { |
2603 | * A[0] = (A[0]) + 1; |
2604 | * } |
2605 | * if (x<5 ? 1 : 0) { |
2606 | * if (x==2 ? 1 : 0) { |
2607 | * A[0] = (A[0]) + 1; |
2608 | * } |
2609 | * } |
2610 | */ |
2611 | |
2612 | std::ostringstream before; |
2613 | before << *stmt; |
2614 | |
2615 | // No change. |
2616 | stmt = registerize(stmt); |
2617 | |
2618 | std::ostringstream after; |
2619 | after << *stmt; |
2620 | |
2621 | ASSERT_EQ(before.str(), after.str()); |
2622 | |
2623 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
2624 | stmt = registerize(stmt); |
2625 | } |
2626 | |
2627 | TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) { |
2628 | BufHandle a("A" , {1}, kInt); |
2629 | VarHandle x("x" , kInt); |
2630 | StmtPtr stmt = Block::make( |
2631 | {Cond::make( |
2632 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
2633 | Block::make({Cond::make( |
2634 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
2635 | Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2636 | nullptr)}), |
2637 | nullptr), |
2638 | Cond::make( |
2639 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
2640 | Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2641 | nullptr)}); |
2642 | |
2643 | /* |
2644 | * if (x<5 ? 1 : 0) { |
2645 | * if (x==2 ? 1 : 0) { |
2646 | * A[0] = (A[0]) + 1; |
2647 | * } |
2648 | * } |
2649 | * if (x==2 ? 1 : 0) { |
2650 | * A[0] = (A[0]) + 1; |
2651 | * } |
2652 | */ |
2653 | |
2654 | std::ostringstream before; |
2655 | before << *stmt; |
2656 | |
2657 | // No change. |
2658 | stmt = registerize(stmt); |
2659 | |
2660 | std::ostringstream after; |
2661 | after << *stmt; |
2662 | |
2663 | ASSERT_EQ(before.str(), after.str()); |
2664 | |
2665 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
2666 | stmt = registerize(stmt); |
2667 | } |
2668 | |
2669 | // If an access is cut by another access internal to a condition block, it still |
2670 | // cuts the access. |
2671 | TEST(Registerizer, RegisterizerNestedConditionsCut) { |
2672 | BufHandle a("A" , {1}, kInt); |
2673 | VarHandle x("x" , kInt); |
2674 | StmtPtr stmt = Block::make( |
2675 | {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2676 | Cond::make( |
2677 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
2678 | Block::make( |
2679 | {Store::make(a, {x}, 1), |
2680 | Cond::make( |
2681 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
2682 | Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2683 | nullptr)}), |
2684 | nullptr)}); |
2685 | |
2686 | /* |
2687 | * A[0] = (A[0]) + 1; |
2688 | * if (x<5 ? 1 : 0) { |
2689 | * A[x] = 1; |
2690 | * if (x==2 ? 1 : 0) { |
2691 | * |
2692 | * A[0] = (A[0]) + 1; |
2693 | * } |
2694 | * } |
2695 | */ |
2696 | |
2697 | std::ostringstream before; |
2698 | before << *stmt; |
2699 | |
2700 | // No change. |
2701 | stmt = registerize(stmt); |
2702 | |
2703 | std::ostringstream after; |
2704 | after << *stmt; |
2705 | |
2706 | ASSERT_EQ(before.str(), after.str()); |
2707 | } |
2708 | |
2709 | TEST(Registerizer, RegisterizerNestedConditionLoopHidden) { |
2710 | BufHandle a("A" , {10}, kInt); |
2711 | BufHandle b("B" , {10}, kInt); |
2712 | VarHandle x("x" , kInt); |
2713 | StmtPtr stmt = Block::make( |
2714 | {Cond::make( |
2715 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
2716 | Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2717 | nullptr), |
2718 | For::make( |
2719 | x, |
2720 | 0, |
2721 | 10, |
2722 | Block::make( |
2723 | {Store::make(b, {x}, 0), |
2724 | Cond::make( |
2725 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
2726 | Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), |
2727 | nullptr)}))}); |
2728 | |
2729 | /* |
2730 | * if (x==2 ? 1 : 0) { |
2731 | * A[0] = (A[0]) + 1; |
2732 | * } |
2733 | * for (int x = 0; x < 10; x++) { |
2734 | * B[x] = 0; <-- this is only here to prevent Loop/Cond reordering. |
2735 | * if (x==2 ? 1 : 0) { |
2736 | * A[0] = (A[0]) + 1; |
2737 | * } |
2738 | * } |
2739 | */ |
2740 | |
2741 | std::ostringstream before; |
2742 | before << *stmt; |
2743 | |
2744 | // No change. |
2745 | stmt = registerize(stmt); |
2746 | |
2747 | std::ostringstream after; |
2748 | after << *stmt; |
2749 | |
2750 | ASSERT_EQ(before.str(), after.str()); |
2751 | } |
2752 | |
// Three nested conditions and four element regions, three of which should be
// registerized at different levels of the IR.
2755 | TEST(Registerizer, RegisterizerNestedConditionThreeDeep) { |
2756 | BufHandle a("A" , {10}, kInt); |
2757 | BufHandle b("B" , {10}, kInt); |
2758 | VarHandle x("x" , kInt); |
2759 | StmtPtr stmt = Block::make( |
2760 | {Store::make(a, {4}, 0), |
2761 | Cond::make( |
2762 | CompareSelect::make(x, 2, CompareSelectOperation::kGT), |
2763 | Cond::make( |
2764 | CompareSelect::make(x, 3, CompareSelectOperation::kGT), |
2765 | Block::make({ |
2766 | Cond::make( |
2767 | CompareSelect::make(x, 4, CompareSelectOperation::kGT), |
2768 | Block::make({ |
2769 | Store::make( |
2770 | a, {1}, Add::make(Load::make(a, {1}), 1)), |
2771 | Store::make( |
2772 | a, {2}, Add::make(Load::make(a, {2}), 1)), |
2773 | Store::make( |
2774 | a, {3}, Add::make(Load::make(a, {3}), 1)), |
2775 | Store::make( |
2776 | a, {4}, Add::make(Load::make(a, {4}), 1)), |
2777 | Store::make( |
2778 | a, {1}, Add::make(Load::make(a, {1}), 1)), |
2779 | }), |
2780 | nullptr), |
2781 | Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)), |
2782 | }), |
2783 | nullptr), |
2784 | nullptr)}); |
2785 | |
2786 | /* |
2787 | * A[4] = 0; |
2788 | * if (x>2 ? 1 : 0) { |
2789 | * if (x>3 ? 1 : 0) { |
2790 | * if (x>4 ? 1 : 0) { |
2791 | * A[1] = (A[1]) + 1; |
2792 | * A[2] = (A[2]) + 1; |
2793 | * A[3] = (A[3]) + 1; |
2794 | * A[4] = (A[4]) + 1; |
2795 | * A[1] = (A[1]) + 1; |
2796 | * } |
2797 | * A[2] = (A[2]) + 1; |
2798 | * } |
2799 | * } |
2800 | */ |
2801 | |
2802 | stmt = registerize(stmt); |
2803 | |
2804 | /* |
2805 | * int A_1 = 0; |
2806 | * if (x>2 ? 1 : 0) { |
2807 | * if (x>3 ? 1 : 0) { |
2808 | * int A_3 = A[2]; |
2809 | * if (x>4 ? 1 : 0) { |
2810 | * int A_2 = A[1]; |
2811 | * A_2 = A_2 + 1; |
2812 | * A_3 = A_3 + 1; |
2813 | * A[3] = (A[3]) + 1; |
2814 | * A_1 = A_1 + 1; |
2815 | * A_2 = A_2 + 1; |
2816 | * A[1] = A_2; |
2817 | * } |
2818 | * A_3 = A_3 + 1; |
2819 | * A[2] = A_3; |
2820 | * } |
2821 | * } |
2822 | * A[4] = A_1; |
2823 | */ |
2824 | |
2825 | std::ostringstream oss; |
2826 | oss << *stmt; |
2827 | |
2828 | const std::string& verification_pattern = |
2829 | R"IR( |
2830 | # CHECK: int A_1 = 0; |
2831 | # CHECK: if (x>2 ? 1 : 0) { |
2832 | # CHECK: if (x>3 ? 1 : 0) { |
2833 | # CHECK: int A_3 = A[2]; |
2834 | # CHECK: if (x>4 ? 1 : 0) { |
2835 | # CHECK: int A_2 = A[1]; |
2836 | # CHECK: A_2 = A_2 + 1; |
2837 | # CHECK: A_3 = A_3 + 1; |
2838 | # CHECK: A[3] = (A[3]) + 1; |
2839 | # CHECK: A_1 = A_1 + 1; |
2840 | # CHECK: A_2 = A_2 + 1; |
2841 | # CHECK: A[1] = A_2; |
2842 | # CHECK: } |
2843 | # CHECK: A_3 = A_3 + 1; |
2844 | # CHECK: A[2] = A_3; |
2845 | # CHECK: } |
2846 | # CHECK: } |
2847 | # CHECK: A[4] = A_1;)IR" ; |
2848 | |
2849 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2850 | } |
2851 | |
2852 | // Can replace a simple scalar access with a local variable even when that |
2853 | // variable is an outer loop var. |
2854 | TEST(Registerizer, RegisterizerNestedLoopSimple) { |
2855 | BufHandle a("A" , {1}, kInt); |
2856 | VarHandle x("x" , kInt); |
2857 | VarHandle y("y" , kInt); |
2858 | StmtPtr stmt = Block::make({For::make( |
2859 | y, |
2860 | 0, |
2861 | 10, |
2862 | For::make( |
2863 | x, |
2864 | 0, |
2865 | 10, |
2866 | Block::make( |
2867 | {Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))}); |
2868 | |
2869 | /* |
2870 | * for (int y = 0; y < 10; y++) { |
2871 | * for (int x = 0; x < 10; x++) { |
2872 | * A[y] = (A[y]) + x; |
2873 | * } |
2874 | * } |
2875 | */ |
2876 | |
2877 | stmt = registerize(stmt); |
2878 | |
2879 | /* |
2880 | * for (int y = 0; y < 10; y++) { |
2881 | * int A_1 = A[y]; |
2882 | * for (int x = 0; x < 10; x++) { |
2883 | * A_1 = A_1 + x; |
2884 | * } |
2885 | * A[y] = A_1; |
2886 | * } |
2887 | */ |
2888 | |
2889 | std::ostringstream oss; |
2890 | oss << *stmt; |
2891 | |
2892 | const std::string& verification_pattern = |
2893 | R"IR( |
2894 | # CHECK: for (int y |
2895 | # CHECK: int A_1 = A[y]; |
2896 | # CHECK: for (int x |
2897 | # CHECK: A_1 = A_1 + x; |
2898 | # CHECK: } |
2899 | # CHECK: A[y] = A_1; |
2900 | # CHECK: })IR" ; |
2901 | |
2902 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2903 | } |
2904 | |
2905 | // Test the positive case of the hiddenAccess split, where an internal |
2906 | // conditional access can be hoisted up through a loop to match an existing |
2907 | // access in a higher scope and the two can be registerized. |
2908 | TEST(Registerizer, RegisterizerHiddenAccessYes) { |
2909 | BufHandle a("A" , {10}, kInt); |
2910 | BufHandle b("B" , {10}, kInt); |
2911 | VarHandle x("x" , kInt); |
2912 | VarHandle y("y" , kInt); |
2913 | StmtPtr stmt = Block::make({Cond::make( |
2914 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
2915 | Block::make( |
2916 | {Store::make(a, {0}, 0), |
2917 | For::make( |
2918 | x, |
2919 | 0, |
2920 | 10, |
2921 | Block::make( |
2922 | {Store::make(b, {x}, 0), |
2923 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
2924 | Cond::make( |
2925 | CompareSelect::make(x, 3, CompareSelectOperation::kEQ), |
2926 | For::make( |
2927 | y, |
2928 | 0, |
2929 | 10, |
2930 | Store::make( |
2931 | a, {0}, Add::make(Load::make(a, {0}), 1))), |
2932 | nullptr)}))}), |
2933 | nullptr)}); |
2934 | |
2935 | /* |
2936 | * if (x==2 ? 1 : 0) { |
2937 | * A[0] = 0; |
2938 | * for (int x = 0; x < 10; x++) { |
2939 | * B[x] = 0; |
2940 | * if (x==3 ? 1 : 0) { |
2941 | * for (int y = 0; y < 10; y++) { |
2942 | * A[0] = (A[0]) + 1; |
2943 | * } |
2944 | * } |
2945 | * } |
2946 | * } |
2947 | */ |
2948 | |
2949 | stmt = registerize(stmt); |
2950 | |
2951 | /* |
2952 | * if (x==2 ? 1 : 0) { |
2953 | * int A_1 = 0; |
2954 | * for (int x = 0; x < 10; x++) { |
2955 | * B[x] = 0; |
2956 | * if (x==3 ? 1 : 0) { |
2957 | * for (int y = 0; y < 10; y++) { |
2958 | * A_1 = A_1 + 1; |
2959 | * } |
2960 | * } |
2961 | * } |
2962 | * A[0] = A_1; |
2963 | * } |
2964 | */ |
2965 | |
2966 | std::ostringstream oss; |
2967 | oss << *stmt; |
2968 | |
2969 | const std::string& verification_pattern = |
2970 | R"IR( |
2971 | # CHECK: if (x==2 |
2972 | # CHECK: int A_1 = 0; |
2973 | # CHECK: for (int x |
2974 | # CHECK: B[x] = 0; |
2975 | # CHECK: if (x==3 |
2976 | # CHECK: for (int y |
2977 | # CHECK: A_1 = A_1 + 1; |
2978 | # CHECK: } |
2979 | # CHECK: } |
2980 | # CHECK: } |
2981 | # CHECK: A[0] = A_1; |
2982 | # CHECK: })IR" ; |
2983 | |
2984 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2985 | } |
2986 | |
2987 | // Test the negative case of the hiddenAccess split, where the hoisted access is |
2988 | // never unhidden at a higher scope and registerization occurs at the lower |
2989 | // scope. |
2990 | TEST(Registerizer, RegisterizerHiddenAccessNo) { |
2991 | BufHandle a("A" , {10}, kInt); |
2992 | BufHandle b("B" , {10}, kInt); |
2993 | VarHandle x("x" , kInt); |
2994 | VarHandle y("y" , kInt); |
2995 | StmtPtr stmt = Block::make({Cond::make( |
2996 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
2997 | Block::make({For::make( |
2998 | x, |
2999 | 0, |
3000 | 10, |
3001 | Block::make( |
3002 | {Store::make(b, {x}, 0), |
3003 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
3004 | Cond::make( |
3005 | CompareSelect::make(x, 3, CompareSelectOperation::kEQ), |
3006 | For::make( |
3007 | y, |
3008 | 0, |
3009 | 10, |
3010 | Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), |
3011 | nullptr)}))}), |
3012 | nullptr)}); |
3013 | |
3014 | /* |
3015 | * if (x==2 ? 1 : 0) { |
3016 | * A[0] = 0; |
3017 | * for (int x = 0; x < 10; x++) { |
3018 | * B[x] = 0; |
3019 | * if (x==3 ? 1 : 0) { |
3020 | * for (int y = 0; y < 10; y++) { |
3021 | * A[0] = (A[0]) + 1; |
3022 | * } |
3023 | * } |
3024 | * } |
3025 | * } |
3026 | */ |
3027 | |
3028 | stmt = registerize(stmt); |
3029 | |
3030 | /* |
3031 | * if (x==2 ? 1 : 0) { |
3032 | * for (int x = 0; x < 10; x++) { |
3033 | * B[x] = 0; |
3034 | * if (x==3 ? 1 : 0) { |
3035 | * int A_1 = A[0]; |
3036 | * for (int y = 0; y < 10; y++) { |
3037 | * A_1 = A_1 + 1; |
3038 | * } |
3039 | * A[0] = A_1; |
3040 | * } |
3041 | * } |
3042 | * } |
3043 | */ |
3044 | |
3045 | std::ostringstream oss; |
3046 | oss << *stmt; |
3047 | |
3048 | const std::string& verification_pattern = |
3049 | R"IR( |
3050 | # CHECK: if (x==2 |
3051 | # CHECK: for (int x |
3052 | # CHECK: B[x] = 0; |
3053 | # CHECK: if (x==3 |
3054 | # CHECK: int A_1 = A[0]; |
3055 | # CHECK: for (int y |
3056 | # CHECK: A_1 = A_1 + 1; |
3057 | # CHECK: } |
3058 | # CHECK: A[0] = A_1; |
3059 | # CHECK: } |
3060 | # CHECK: } |
3061 | # CHECK: })IR" ; |
3062 | |
3063 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
3064 | } |
3065 | |
// In this case the conditional access must be hoisted by two loops. There are
// two accesses here: one is unhidden and the other isn't. A[0] can be
// registerized but B[0] cannot.
3069 | TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { |
3070 | BufHandle a("A" , {10}, kInt); |
3071 | BufHandle b("B" , {10}, kInt); |
3072 | VarHandle x("x" , kInt); |
3073 | VarHandle y("y" , kInt); |
3074 | StmtPtr stmt = Block::make({Cond::make( |
3075 | CompareSelect::make(x, 2, CompareSelectOperation::kEQ), |
3076 | Block::make( |
3077 | {Store::make(a, {0}, 0), |
3078 | For::make( |
3079 | x, |
3080 | 0, |
3081 | 10, |
3082 | For::make( |
3083 | y, |
3084 | 0, |
3085 | 10, |
3086 | Block::make({Cond::make( |
3087 | CompareSelect::make(y, 3, CompareSelectOperation::kEQ), |
3088 | Block::make( |
3089 | {Store::make( |
3090 | a, {0}, Add::make(Load::make(a, {0}), 1)), |
3091 | Store::make( |
3092 | b, {0}, Add::make(Load::make(b, {0}), 1))}), |
3093 | nullptr)})))}), |
3094 | nullptr)}); |
3095 | |
3096 | /* |
3097 | * if (x==2 ? 1 : 0) { |
3098 | * A[0] = 0; |
3099 | * for (int x = 0; x < 10; x++) { |
3100 | * for (int y = 0; y < 10; y++) { |
3101 | * if (y==3 ? 1 : 0) { |
3102 | * A[0] = (A[0]) + 1; |
3103 | * B[0] = (B[0]) + 1; |
3104 | * } |
3105 | * } |
3106 | * } |
3107 | * } |
3108 | */ |
3109 | |
3110 | stmt = registerize(stmt); |
3111 | |
3112 | /* |
3113 | * if (x==2 ? 1 : 0) { |
3114 | * int A_1 = 0; |
3115 | * for (int x = 0; x < 10; x++) { |
3116 | * for (int y = 0; y < 10; y++) { |
3117 | * if (y==3 ? 1 : 0) { |
3118 | * A_1 = A_1 + 1; |
3119 | * B[0] = (B[0]) + 1; |
3120 | * } |
3121 | * } |
3122 | * } |
3123 | * A[0] = A_1; |
3124 | * } |
3125 | */ |
3126 | |
3127 | std::ostringstream oss; |
3128 | oss << *stmt; |
3129 | |
3130 | const std::string& verification_pattern = |
3131 | R"IR( |
3132 | # CHECK: if (x==2 |
3133 | # CHECK: int A_1 = 0; |
3134 | # CHECK: for (int x |
3135 | # CHECK: for (int y |
3136 | # CHECK: if (y==3 |
3137 | # CHECK: A_1 = A_1 + 1; |
3138 | # CHECK: B[0] = (B[0]) + 1; |
3139 | # CHECK: } |
3140 | # CHECK: } |
3141 | # CHECK: } |
3142 | # CHECK: A[0] = A_1; |
3143 | # CHECK: })IR" ; |
3144 | |
3145 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
3146 | } |
3147 | |
// Accesses are registerized inside two conditions, but the immediate parent is
// not a condition.
3150 | TEST(Registerizer, RegisterizerTwoConditionalLoops) { |
3151 | BufHandle a("A" , {1}, kInt); |
3152 | VarHandle x("x" , kInt); |
3153 | StmtPtr stmt = Block::make( |
3154 | {Cond::make( |
3155 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
3156 | For::make( |
3157 | x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), |
3158 | nullptr), |
3159 | Cond::make( |
3160 | CompareSelect::make(x, 5, CompareSelectOperation::kGT), |
3161 | For::make( |
3162 | x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), |
3163 | nullptr)}); |
3164 | |
3165 | /* |
3166 | * if (x<5 ? 1 : 0) { |
3167 | * for (int x = 0; x < 10; x++) { |
3168 | * A[0] = (A[0]) + 1; |
3169 | * } |
3170 | * } |
3171 | * if (x>5 ? 1 : 0) { |
3172 | * for (int x = 0; x < 10; x++) { |
3173 | * A[0] = (A[0]) + 1; |
3174 | * } |
3175 | * } |
3176 | */ |
3177 | |
3178 | stmt = registerize(stmt); |
3179 | |
3180 | /* |
3181 | * if (x<5 ? 1 : 0) { |
3182 | * int A_1 = A[0]; |
3183 | * for (int x = 0; x < 10; x++) { |
3184 | * A_1 = A_1 + 1; |
3185 | * } |
3186 | * A[0] = A_1; |
3187 | * } |
3188 | * if (x>5 ? 1 : 0) { |
3189 | * int A_2 = A[0]; |
3190 | * for (int x = 0; x < 10; x++) { |
3191 | * A_2 = A_2 + 1; |
3192 | * } |
3193 | * A[0] = A_2; |
3194 | * } |
3195 | */ |
3196 | |
3197 | std::ostringstream oss; |
3198 | oss << *stmt; |
3199 | |
3200 | const std::string& verification_pattern = |
3201 | R"IR( |
3202 | # CHECK: if (x<5 |
3203 | # CHECK: int A_1 = A[0]; |
3204 | # CHECK: for (int x |
3205 | # CHECK: A_1 = A_1 + 1; |
3206 | # CHECK: } |
3207 | # CHECK: A[0] = A_1; |
3208 | # CHECK: } |
3209 | # CHECK: if (x>5 |
3210 | # CHECK: int A_2 = A[0]; |
3211 | # CHECK: for (int x |
3212 | # CHECK: A_2 = A_2 + 1; |
3213 | # CHECK: } |
3214 | # CHECK: A[0] = A_2; |
3215 | # CHECK: })IR" ; |
3216 | |
3217 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
3218 | } |
3219 | |
3220 | // Accesses are registerized inside two conditions, cut in the middle. |
3221 | TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) { |
3222 | BufHandle a("A" , {1}, kInt); |
3223 | VarHandle x("x" , kInt); |
3224 | StmtPtr stmt = Block::make( |
3225 | {Cond::make( |
3226 | CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
3227 | For::make( |
3228 | x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), |
3229 | nullptr), |
3230 | For::make(x, 0, 10, Store::make(a, {x}, 1)), |
3231 | Cond::make( |
3232 | CompareSelect::make(x, 5, CompareSelectOperation::kGT), |
3233 | For::make( |
3234 | x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), |
3235 | nullptr)}); |
3236 | |
3237 | /* |
3238 | * if (x<5 ? 1 : 0) { |
3239 | * for (int x = 0; x < 10; x++) { |
3240 | * A[0] = (A[0]) + 1; |
3241 | * } |
3242 | * } |
3243 | * for (int x = 0; x < 10; x++) { |
3244 | * A[x] = 1; |
3245 | * } |
3246 | * if (x>5 ? 1 : 0) { |
3247 | * for (int x = 0; x < 10; x++) { |
3248 | * A[0] = (A[0]) + 1; |
3249 | * } |
3250 | * } |
3251 | */ |
3252 | |
3253 | stmt = registerize(stmt); |
3254 | |
3255 | /* |
3256 | * if (x<5 ? 1 : 0) { |
3257 | * int A_1 = A[0]; |
3258 | * for (int x = 0; x < 10; x++) { |
3259 | * A_1 = A_1 + 1; |
3260 | * } |
3261 | * A[0] = A_1; |
3262 | * } |
3263 | * for (int x = 0; x < 10; x++) { |
3264 | * A[x] = 1; |
3265 | * } |
3266 | * if (x>5 ? 1 : 0) { |
3267 | * int A_2 = A[0]; |
3268 | * for (int x = 0; x < 10; x++) { |
3269 | * A_2 = A_2 + 1; |
3270 | * } |
3271 | * A[0] = A_2; |
3272 | * } |
3273 | */ |
3274 | |
3275 | std::ostringstream oss; |
3276 | oss << *stmt; |
3277 | |
3278 | const std::string& verification_pattern = |
3279 | R"IR( |
3280 | # CHECK: if (x<5 |
3281 | # CHECK: int A_1 = A[0]; |
3282 | # CHECK: for (int x |
3283 | # CHECK: A_1 = A_1 + 1; |
3284 | # CHECK: } |
3285 | # CHECK: A[0] = A_1; |
3286 | # CHECK: } |
3287 | # CHECK: for (int x |
3288 | # CHECK: A[x] = 1; |
3289 | # CHECK: if (x>5 |
3290 | # CHECK: int A_2 = A[0]; |
3291 | # CHECK: for (int x |
3292 | # CHECK: A_2 = A_2 + 1; |
3293 | # CHECK: } |
3294 | # CHECK: A[0] = A_2; |
3295 | # CHECK: })IR" ; |
3296 | |
3297 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
3298 | } |
3299 | |
// The access references a Let var in a local scope, which cannot be hoisted out
// of the loop.
3302 | TEST(Registerizer, RegisterizerLoopLetVar) { |
3303 | BufHandle a("A" , {10}, kInt); |
3304 | VarHandle x("x" , kInt); |
3305 | VarHandle y("y" , kInt); |
3306 | StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make( |
3307 | x, |
3308 | 0, |
3309 | 10, |
3310 | Block::make( |
3311 | {Let::make(y, 30), |
3312 | Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))})); |
3313 | |
3314 | /* |
3315 | * for (int x = 0; x < 10; x++) { |
3316 | * int y = 30; |
3317 | * A[y] = x + (A[y]); |
3318 | * } |
3319 | */ |
3320 | |
3321 | std::ostringstream before; |
3322 | before << *stmt; |
3323 | |
3324 | // No change. |
3325 | stmt = registerize(stmt); |
3326 | |
3327 | std::ostringstream after; |
3328 | after << *stmt; |
3329 | |
3330 | ASSERT_EQ(before.str(), after.str()); |
3331 | } |
3332 | |
3333 | // references a Let var in an outer scope that does not prevent hoisting the |
3334 | // initializer. |
3335 | TEST(Registerizer, RegisterizerLoopLetVarOuter) { |
3336 | BufHandle a("A" , {10}, kInt); |
3337 | VarHandle x("x" , kInt); |
3338 | VarHandle y("y" , kInt); |
3339 | StmtPtr stmt = Block::make( |
3340 | {Let::make(y, 30), |
3341 | For::make( |
3342 | x, |
3343 | 0, |
3344 | 10, |
3345 | Block::make( |
3346 | {Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}); |
3347 | |
3348 | /* |
3349 | * int y = 30; |
3350 | * for (int x = 0; x < 10; x++) { |
3351 | * A[y] = x + (A[y]); |
3352 | * } |
3353 | */ |
3354 | |
3355 | stmt = registerize(stmt); |
3356 | |
3357 | /* |
3358 | * int y = 30; |
3359 | * int A_1 = A[y]; |
3360 | * for (int x = 0; x < 10; x++) { |
3361 | * A_1 = A_1 + x; |
3362 | * } |
3363 | * A[y] = A_1; |
3364 | */ |
3365 | |
3366 | std::ostringstream oss; |
3367 | oss << *stmt; |
3368 | |
3369 | const std::string& verification_pattern = |
3370 | R"IR( |
3371 | # CHECK: int y = 30; |
3372 | # CHECK: int A_1 = A[y]; |
3373 | # CHECK: for (int x |
3374 | # CHECK: A_1 = A_1 + x; |
3375 | # CHECK: A[y] = A_1;)IR" ; |
3376 | |
3377 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
3378 | } |
3379 | |
3380 | // Okay so the registerizer generally goes after index flattening, but just in |
3381 | // case. Test multi index registerization. |
3382 | TEST(Registerizer, RegisterizerMultiDim) { |
3383 | BufHandle a("A" , {3, 4, 5}, kInt); |
3384 | VarHandle x("x" , kInt); |
3385 | StmtPtr stmt = Block::make( |
3386 | {Store::make(a, {0, 1, 2}, 0), |
3387 | For::make( |
3388 | x, |
3389 | 0, |
3390 | 10, |
3391 | Block::make({Store::make( |
3392 | a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))}); |
3393 | |
3394 | /* |
3395 | * A[0, 1, 2] = 0; |
3396 | * for (int x = 0; x < 10; x++) { |
3397 | * A[0, 1, 2] = (A[0, 1, 2]) + x; |
3398 | * } |
3399 | */ |
3400 | |
3401 | stmt = registerize(stmt); |
3402 | |
3403 | /* |
3404 | * int A_1 = 0; |
3405 | * for (int x = 0; x < 10; x++) { |
3406 | * A_1 = x + A_1; |
3407 | * } |
3408 | * A[0, 1, 2] = A_1; |
3409 | */ |
3410 | |
3411 | std::ostringstream oss; |
3412 | oss << *stmt; |
3413 | |
3414 | const std::string& verification_pattern = |
3415 | R"IR( |
3416 | # CHECK: int A_1 = 0; |
3417 | # CHECK: for (int x = 0; x < 10; x++) |
3418 | # CHECK-NOT: A[ |
3419 | # CHECK: A_1 = |
3420 | # CHECK: A[0, 1, 2] = A_1;)IR" ; |
3421 | |
3422 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
3423 | } |
3424 | |
3425 | // Wont registerize if only some dims match, but will still registerize distinct |
3426 | // elements. |
3427 | TEST(Registerizer, RegisterizerMultiDimPartial) { |
3428 | BufHandle a("A" , {3, 4, 5}, kInt); |
3429 | VarHandle x("x" , kInt); |
3430 | StmtPtr stmt = Block::make( |
3431 | {Store::make(a, {0, 1, 2}, 0), |
3432 | For::make( |
3433 | x, |
3434 | 0, |
3435 | 10, |
3436 | Block::make({Store::make( |
3437 | a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))}); |
3438 | |
3439 | /* |
3440 | * A[0, 1, 2] = 0; |
3441 | * for (int x = 0; x < 10; x++) { |
3442 | * A[0, 2, 2] = (A[0, 1, 4]) + x; |
3443 | * } |
3444 | */ |
3445 | |
3446 | stmt = registerize(stmt); |
3447 | |
3448 | /* |
3449 | * A[0, 1, 2] = 0; |
3450 | * int A_1 = A[0, 1, 4]; |
3451 | * int A_2 = A[0, 2, 2]; |
3452 | * for (int x = 0; x < 10; x++) { |
3453 | * A_2 = A_1 + x; |
3454 | * } |
3455 | * A[0, 2, 2] = A_2; |
3456 | */ |
3457 | |
3458 | std::ostringstream oss; |
3459 | oss << *stmt; |
3460 | |
3461 | const std::string& verification_pattern = |
3462 | R"IR( |
3463 | # CHECK: A[0, 1, 2] = 0; |
3464 | # CHECK: int A_1 = A[0, 1, 4]; |
3465 | # CHECK: int A_2 = A[0, 2, 2]; |
3466 | # CHECK: for ( |
3467 | # CHECK: A_2 = A_1 + x; |
3468 | # CHECK: A[0, 2, 2] = A_2;)IR" ; |
3469 | |
3470 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
3471 | } |
3472 | |
3473 | // If they could overlap across all dimensions we cannot registerize. |
3474 | TEST(Registerizer, RegisterizerMultiDimOverlap) { |
3475 | BufHandle a("A" , {3, 4, 5}, kInt); |
3476 | VarHandle x("x" , kInt); |
3477 | VarHandle y("y" , kInt); |
3478 | StmtPtr stmt = Block::make( |
3479 | {Store::make(a, {0, 1, 2}, 0), |
3480 | For::make( |
3481 | x, |
3482 | 0, |
3483 | 10, |
3484 | Block::make({Store::make( |
3485 | a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))}); |
3486 | stmt = IRSimplifier::simplify(stmt); |
3487 | |
3488 | /* |
3489 | * A[0, 1, 2] = 0; |
3490 | * for (int x = 0; x < 10; x++) { |
3491 | * A[0, x, 2] = (A[y, 2, 2]) + x; |
3492 | * } |
3493 | */ |
3494 | |
3495 | std::ostringstream before; |
3496 | before << *stmt; |
3497 | |
3498 | // No change. |
3499 | stmt = registerize(stmt); |
3500 | |
3501 | std::ostringstream after; |
3502 | after << *stmt; |
3503 | |
3504 | ASSERT_EQ(before.str(), after.str()); |
3505 | } |
3506 | |
3507 | // But, if one dimension is known to be distinct they do not overlap. |
3508 | TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { |
3509 | BufHandle a("A" , {3, 4, 5}, kInt); |
3510 | VarHandle x("x" , kInt); |
3511 | VarHandle y("y" , kInt); |
3512 | StmtPtr stmt = Block::make( |
3513 | {Store::make(a, {0, 1, 2}, 0), |
3514 | For::make( |
3515 | x, |
3516 | 0, |
3517 | 10, |
3518 | Block::make({Store::make( |
3519 | a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))}); |
3520 | |
3521 | /* |
3522 | * A[0, 1, 2] = 0; <---- 2nd dim overlaps with store. |
3523 | * for (int x = 0; x < 10; x++) { |
3524 | * A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff. |
3525 | * } |
3526 | */ |
3527 | |
3528 | stmt = registerize(stmt); |
3529 | |
3530 | /* |
3531 | * A[0, 1, 2] = 0; |
3532 | * int A_1 = A[y, 2, 4]; |
3533 | * for (int x = 0; x < 10; x++) { |
3534 | * A[0, x, 2] = A_1 + x; |
3535 | * } |
3536 | */ |
3537 | |
3538 | std::ostringstream oss; |
3539 | oss << *stmt; |
3540 | |
3541 | const std::string& verification_pattern = |
3542 | R"IR( |
3543 | # CHECK: A[0, 1, 2] = 0; |
3544 | # CHECK: int A_1 = A[y, 2, 4]; |
3545 | # CHECK: for ( |
3546 | # CHECK: A[0, x, 2] = A_1 + x; |
3547 | # CHECK: })IR" ; |
3548 | |
3549 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
3550 | } |
3551 | |
3552 | // A 3D reduction with different input dimensionality. |
3553 | TEST(Registerizer, RegisterizerMultiDim3DReduction1) { |
3554 | BufHandle a("A" , {10}, kInt); |
3555 | BufHandle b("B" , {10, 10}, kInt); |
3556 | BufHandle c("C" , {10, 10, 10}, kInt); |
3557 | VarHandle x("x" , kInt); |
3558 | VarHandle y("y" , kInt); |
3559 | VarHandle z("z" , kInt); |
3560 | StmtPtr stmt = For::make( |
3561 | x, |
3562 | 0, |
3563 | 10, |
3564 | For::make( |
3565 | y, |
3566 | 0, |
3567 | 10, |
3568 | For::make( |
3569 | z, |
3570 | 0, |
3571 | 10, |
3572 | Store::make( |
3573 | c, |
3574 | {x, y, z}, |
3575 | Add::make( |
3576 | Load::make(c, {x, y, z}), |
3577 | Mul::make(Load::make(b, {x, y}), Load::make(a, {x}))))))); |
3578 | |
3579 | /* |
3580 | * for (int x = 0; x < 10; x++) { |
3581 | * for (int y = 0; y < 10; y++) { |
3582 | * for (int z = 0; z < 10; z++) { |
3583 | * C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]); |
3584 | * } |
3585 | * } |
3586 | * } |
3587 | */ |
3588 | |
3589 | // We can registerize the A and B access since they can be hoisted before |
3590 | // hitting a dependent loop var. |
3591 | |
3592 | stmt = registerize(stmt); |
3593 | |
3594 | /* |
3595 | * for (int x = 0; x < 10; x++) { |
3596 | * int A_1 = A[x]; |
3597 | * for (int y = 0; y < 10; y++) { |
3598 | * int B_1 = B[x, y]; |
3599 | * for (int z = 0; z < 10; z++) { |
3600 | * C[x, y, z] = A_1 * B_1 + (C[x, y, z]); |
3601 | * } |
3602 | * } |
3603 | * } |
3604 | */ |
3605 | |
3606 | std::ostringstream oss; |
3607 | oss << *stmt; |
3608 | |
3609 | const std::string& verification_pattern = |
3610 | R"IR( |
3611 | # CHECK: for (int x |
3612 | # CHECK: int A_1 = A[x]; |
3613 | # CHECK: for (int y |
3614 | # CHECK: int B_1 = B[x, y]; |
3615 | # CHECK: for (int z |
3616 | # CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]); |
3617 | # CHECK: })IR" ; |
3618 | |
3619 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
3620 | } |
3621 | |
3622 | // A 3D reduction with the same smaller dimensionality using different loop |
3623 | // vars. |
3624 | TEST(Registerizer, RegisterizerMultiDim3DReduction2) { |
3625 | BufHandle a("A" , {10}, kInt); |
3626 | BufHandle b("B" , {10}, kInt); |
3627 | BufHandle c("C" , {10}, kInt); |
3628 | VarHandle x("x" , kInt); |
3629 | VarHandle y("y" , kInt); |
3630 | VarHandle z("z" , kInt); |
3631 | StmtPtr stmt = For::make( |
3632 | x, |
3633 | 0, |
3634 | 10, |
3635 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
3636 | For::make( |
3637 | y, |
3638 | 0, |
3639 | 10, |
3640 | For::make( |
3641 | z, |
3642 | 0, |
3643 | 10, |
3644 | Store::make( |
3645 | c, |
3646 | {x}, |
3647 | Add::make( |
3648 | Load::make(c, {x}), |
3649 | Mul::make(Load::make(b, {y}), Load::make(a, {x}))))))); |
3650 | |
3651 | /* |
3652 | * for (int x = 0; x < 10; x++) { |
3653 | * for (int y = 0; y < 10; y++) { |
3654 | * for (int z = 0; z < 10; z++) { |
3655 | * C[x] = (C[x]) + (B[y]) * (A[x]); |
3656 | * } |
3657 | * } |
3658 | * } |
3659 | */ |
3660 | |
3661 | // We can registerize all accesses, the A and C access can be hoisted to the |
3662 | // outer loop since they depend only on it's loop var while the B can only be |
3663 | // raised to the loop of y. |
3664 | |
3665 | stmt = registerize(stmt); |
3666 | |
3667 | /* |
3668 | * for (int x = 0; x < 10; x++) { |
3669 | * int A_1 = A[x]; |
3670 | * int C_1 = C[x]; |
3671 | * for (int y = 0; y < 10; y++) { |
3672 | * int B_1 = B[y]; |
3673 | * for (int z = 0; z < 10; z++) { |
3674 | * C_1 = A_1 * B_1 + C_1; |
3675 | * } |
3676 | * } |
3677 | * C[x] = C_1; |
3678 | * } |
3679 | */ |
3680 | |
3681 | std::ostringstream oss; |
3682 | oss << *stmt; |
3683 | |
3684 | const std::string& verification_pattern = |
3685 | R"IR( |
3686 | # CHECK: for (int x |
3687 | # CHECK: int A_1 = A[x]; |
3688 | # CHECK: int C_1 = C[x]; |
3689 | # CHECK: for (int y |
3690 | # CHECK: int B_1 = B[y]; |
3691 | # CHECK: for (int z |
3692 | # CHECK: C_1 = A_1 * B_1 + C_1; |
3693 | # CHECK: } |
3694 | # CHECK: } |
3695 | # CHECK: C[x] = C_1; |
3696 | # CHECK: })IR" ; |
3697 | |
3698 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
3699 | } |
3700 | |
3701 | } // namespace jit |
3702 | } // namespace torch |
3703 | |