1#include <gtest/gtest.h>
2
3#include <test/cpp/tensorexpr/test_base.h>
4#include <memory>
5#include <sstream>
6#include <stdexcept>
7#include <unordered_map>
8
9#include <test/cpp/tensorexpr/padded_buffer.h>
10#include <test/cpp/tensorexpr/test_utils.h>
11#include <torch/csrc/jit/tensorexpr/analysis.h>
12#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
13#include <torch/csrc/jit/tensorexpr/eval.h>
14#include <torch/csrc/jit/tensorexpr/ir.h>
15#include <torch/csrc/jit/tensorexpr/ir_printer.h>
16#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
17#include <torch/csrc/jit/tensorexpr/loopnest.h>
18#include <torch/csrc/jit/tensorexpr/tensor.h>
19#include <torch/csrc/jit/testing/file_check.h>
20
21namespace torch {
22namespace jit {
23
24using namespace torch::jit::tensorexpr;
25
26void checkIR(StmtPtr s, const std::string& pattern) {
27 std::ostringstream oss;
28 oss << *s;
29 torch::jit::testing::FileCheck().run(pattern, oss.str());
30}
31
32void checkExprIR(ExprPtr e, const std::string& pattern) {
33 std::string prefixed_pattern = "# CHECK: " + pattern + "\n";
34 std::ostringstream oss;
35 oss << *e << "\n";
36 torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str());
37}
38
39void checkExprIR(const ExprHandle& e, const std::string& pattern) {
40 checkExprIR(e.node(), pattern);
41}
42
43TEST(LoopNest, ExprSimple01) {
44 Tensor tensor =
45 Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) {
46 return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
47 });
48 LoopNest l({tensor});
49 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
50
51 LoopNest::splitWithTail(loops[0], 2);
52 LoopNest::splitWithTail(loops[0], 2);
53}
54
55TEST(LoopNest, ExprLower01) {
56 Tensor tensor =
57 Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) {
58 return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
59 });
60 LoopNest l({tensor});
61 StmtPtr stmt = l.root_stmt();
62 std::ostringstream oss;
63 oss << *stmt;
64 ASSERT_GT(oss.str().size(), 20);
65 ASSERT_LT(oss.str().size(), 200);
66}
67
68TEST(LoopNest, ExprSimple02) {
69 auto func = [](const ExprHandle& x, const ExprHandle& y) {
70 return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
71 };
72 Tensor tensor = Compute("f", {26, 5}, func);
73 LoopNest l({tensor});
74 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
75
76 LoopNest::splitWithTail(loops[0], 4);
77
78 StmtPtr stmt = l.root_stmt();
79 std::ostringstream oss;
80 oss << *stmt;
81 ASSERT_GT(oss.str().size(), 200);
82 ASSERT_LT(oss.str().size(), 600);
83
84 {
    // Compare against a reference loop structure.
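    // splitWithTail(loops[0], 4) on a 26-iteration loop produces an outer loop
    // of 26 / 4 = 6 iterations over an inner loop of 4, plus a tail loop of
    // 26 % 4 = 2 iterations.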
86 VarHandle x_outer("i_outer", kInt);
87 VarHandle x_inner("i_inner", kInt);
88 VarHandle y("i", kInt);
89 VarHandle x_tail("i_tail", kInt);
90 BufHandle f("f", {26, 5}, kFloat);
91 ExprHandle x_1 = x_outer * 4 + x_inner;
92 ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4;
93 ForPtr stmt1 = For::make(
94 x_outer,
95 0,
96 x_outer_end,
97 For::make(
98 x_inner,
99 0,
100 4,
101 For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))));
102 ExprHandle x_2 = x_tail + x_outer_end * 4;
103 ForPtr stmt2 = For::make(
104 x_tail,
105 0,
106 (ExprHandle(26) - 0) % 4,
107 For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y))));
108 StmtPtr stmt = Block::make({stmt1, stmt2});
109
110 std::ostringstream oss_ref;
111 oss_ref << *stmt;
112 ASSERT_EQ(oss.str(), oss_ref.str());
113 }
114
115 {
116 PaddedBuffer<float> f_v(26, 5, "f_v");
117 PaddedBuffer<float> f_ref(26, 5, "f_res");
118
119 stmt = FlattenIndexes(stmt);
120 SimpleIREvaluator ir_eval(stmt, {tensor});
121 ir_eval(f_v);
122
123 for (int x = 0; x < 26; x++) {
124 for (int y = 0; y < 5; y++) {
125 f_ref(x, y) = 1 + x * x + y * y;
126 }
127 }
128
129 ExpectAllNear(f_v, f_ref, 1e-5);
130 }
131}
132
133BlockPtr getSimplifiedBody(const LoopNest& l) {
134 StmtPtr stmt = l.root_stmt();
135 StmtPtr simplified = IRSimplifier::simplify(stmt);
136 return to<Block>(simplified);
137}
138
139void assertForRange(ForPtr f, int expected_start, int expected_stop) {
140 ASSERT_NE(f, nullptr);
141 IntImmPtr start = to<IntImm>(f->start());
142 ASSERT_NE(start, nullptr);
143 ASSERT_EQ(start->value(), expected_start);
144 IntImmPtr stop = to<IntImm>(f->stop());
145 ASSERT_NE(stop, nullptr);
146 ASSERT_EQ(stop->value(), expected_stop);
147}
148
149void assertForRanges(
150 BlockPtr body,
151 const std::vector<std::pair<int, int>>& start_stops) {
152 ASSERT_EQ(body->nstmts(), start_stops.size());
153
154 auto it = body->begin();
155 for (size_t i = 0; i < start_stops.size(); i++, it++) {
156 ForPtr loop = to<For>(*it);
157 assertForRange(loop, start_stops[i].first, start_stops[i].second);
158 }
159}
160
161TEST(LoopNest, ExprSliceHeadWithLoopOptions) {
162 auto func = [](const ExprHandle& x) {
163 return ExprHandle(1.0f) + cast<float>(x);
164 };
165 Tensor tensor = Compute("f", {10}, func);
166 LoopNest l({tensor});
167 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
168 ForPtr head;
169 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
170 ForPtr tail;
171 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
172 loops[0]->set_gpu_block_index(LoopOptions::IDX_Y);
173 LoopNest::sliceHead(loops[0], 2, &head, &tail);
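  // After slicing the head, the tail should keep the GPU block-index binding
  // that was set on the original loop, while the newly created head loop gets
  // default loop options (asserted below).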
174
175 BlockPtr body = getSimplifiedBody(l);
176 assertForRanges(body, {{0, 2}, {0, 8}});
177
178 ASSERT_TRUE(tail->loop_options().is_gpu_block_index());
179 ASSERT_EQ(tail->loop_options().gpu_block_index(), LoopOptions::IDX_Y);
180
181 ASSERT_TRUE(head->loop_options().isDefault());
182}
183
184TEST(LoopNest, ExprSliceTailWithLoopOptions) {
185 auto func = [](const ExprHandle& x) {
186 return ExprHandle(1.0f) + cast<float>(x);
187 };
188 Tensor tensor = Compute("f", {10}, func);
189 LoopNest l({tensor});
190 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
191 ForPtr head;
192 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
193 ForPtr tail;
194 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
195 LoopNest::sliceTail(loops[0], 4, &head, &tail);
196
197 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
198 ForPtr tail_head;
199 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
200 ForPtr tail_tail;
201 tail->set_gpu_block_index(LoopOptions::IDX_Y);
202 LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail);
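  // The GPU block-index binding set on `tail` should remain on tail_head after
  // the second slice; the newly created tail_tail (and the original head) keep
  // default loop options.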
203
204 BlockPtr body = getSimplifiedBody(l);
205 assertForRanges(body, {{0, 6}, {0, 2}, {8, 10}});
206
207 ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index());
208 ASSERT_EQ(tail_head->loop_options().gpu_block_index(), LoopOptions::IDX_Y);
209
210 ASSERT_TRUE(head->loop_options().isDefault());
211 ASSERT_TRUE(tail_tail->loop_options().isDefault());
212}
213
214TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) {
215 // When factor equals the For loop's original size, keep using the original
216 // For loop.
217 auto func = [](const ExprHandle& x) {
218 return ExprHandle(1.0f) + cast<float>(x);
219 };
220 Tensor tensor = Compute("f", {10}, func);
221 LoopNest l({tensor});
222 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
223 ForPtr head;
224 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
225 ForPtr tail;
226 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
227 LoopNest::sliceHead(loops[0], 10, &head, &tail);
228
229 ASSERT_EQ(head, loops[0]);
230 ASSERT_EQ(tail, nullptr);
231
232 BlockPtr body = getSimplifiedBody(l);
233 assertForRanges(body, {{0, 10}});
234}
235
236TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) {
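  // When the factor is larger than the For loop's original size, keep using
  // the original For loop and create no tail.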
237 auto func = [](const ExprHandle& x) {
238 return ExprHandle(1.0f) + cast<float>(x);
239 };
240 Tensor tensor = Compute("f", {10}, func);
241 LoopNest l({tensor});
242 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
243 ForPtr head;
244 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
245 ForPtr tail;
246 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
247 LoopNest::sliceHead(loops[0], 100, &head, &tail);
248
249 ASSERT_EQ(head, loops[0]);
250 ASSERT_EQ(tail, nullptr);
251
252 BlockPtr body = getSimplifiedBody(l);
253 assertForRanges(body, {{0, 10}});
254}
255
256TEST(LoopNest, ExprSliceHead) {
257 auto func = [](const ExprHandle& x) {
258 return ExprHandle(1.0f) + cast<float>(x);
259 };
260 Tensor tensor = Compute("f", {10}, func);
261 LoopNest l({tensor});
262 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
263 ForPtr head;
264 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
265 ForPtr tail;
266 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
267 LoopNest::sliceHead(loops[0], 4, &head, &tail);
268
269 ASSERT_NE(head, nullptr);
270 ASSERT_NE(head, loops[0]);
271 ASSERT_NE(tail, nullptr);
272 ASSERT_EQ(tail, loops[0]);
273
274 BlockPtr body = getSimplifiedBody(l);
275 assertForRanges(body, {{0, 4}, {4, 10}});
276}
277
278TEST(LoopNest, ExprSliceHeadWithNonZeroStart) {
279 auto func = [](const ExprHandle& x) {
280 return ExprHandle(1.0f) + cast<float>(x);
281 };
282 Tensor tensor = Compute("f", {10}, func);
283 LoopNest l({tensor});
284 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
285
286 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
287 ForPtr head;
288 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
289 ForPtr tail;
290 LoopNest::sliceTail(loops[0], 4, &head, &tail);
291 // head: [0, 6)
292 // tail: [6, 10)
293
294 LoopNest::sliceHead(tail, 2);
295 // tail_head: [6, 8)
296 // tail_tail: [8, 10)
297
298 BlockPtr body = getSimplifiedBody(l);
299 assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}});
300}
301
302TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) {
303 // When factor equals the For loop's original size, keep using the original
304 // For loop.
305 auto func = [](const ExprHandle& x) {
306 return ExprHandle(1.0f) + cast<float>(x);
307 };
308 Tensor tensor = Compute("f", {10}, func);
309 LoopNest l({tensor});
310 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
311 ForPtr head;
312 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
313 ForPtr tail;
314 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
315 LoopNest::sliceTail(loops[0], 10, &head, &tail);
316
317 ASSERT_EQ(head, nullptr);
318 ASSERT_EQ(tail, loops[0]);
319
320 BlockPtr body = getSimplifiedBody(l);
321 assertForRanges(body, {{0, 10}});
322}
323
324TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) {
  // When the factor is larger than the For loop's original size, keep using
  // the original For loop.
327 auto func = [](const ExprHandle& x) {
328 return ExprHandle(1.0f) + cast<float>(x);
329 };
330 Tensor tensor = Compute("f", {10}, func);
331 LoopNest l({tensor});
332 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
333 ForPtr head;
334 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
335 ForPtr tail;
336 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
337 LoopNest::sliceTail(loops[0], 100, &head, &tail);
338
339 ASSERT_EQ(head, nullptr);
340 ASSERT_EQ(tail, loops[0]);
341
342 BlockPtr body = getSimplifiedBody(l);
343 assertForRanges(body, {{0, 10}});
344}
345
346TEST(LoopNest, ExprSliceTail) {
347 auto func = [](const ExprHandle& x) {
348 return ExprHandle(1.0f) + cast<float>(x);
349 };
350 Tensor tensor = Compute("f", {10}, func);
351 LoopNest l({tensor});
352 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
353 ForPtr head;
354 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
355 ForPtr tail;
356 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
357 LoopNest::sliceTail(loops[0], 4, &head, &tail);
358
359 ASSERT_NE(head, nullptr);
360 ASSERT_EQ(head, loops[0]);
361 ASSERT_NE(tail, nullptr);
362 ASSERT_NE(tail, loops[0]);
363
364 BlockPtr body = getSimplifiedBody(l);
365 assertForRanges(body, {{0, 6}, {6, 10}});
366}
367
368TEST(LoopNest, ExprSplitAndSlice) {
369 // 0: splitWithTail
370 // 1: sliceTail on inner loop
371 // 2: sliceHead on outer loop
372 auto func = [](const ExprHandle& x) {
373 return ExprHandle(1.0f) + cast<float>(x);
374 };
375 Tensor tensor = Compute("f", {100}, func);
376 LoopNest l({tensor});
377
378 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
379 ForPtr inner;
380 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
381 ForPtr tail;
382 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
383 // outer: [0, 4)
384 // inner: [0, 21)
385 // tail: [84, 100)
386 LoopNest::splitWithTail(loops[0], 21, &inner, &tail);
387 LoopNest::sliceTail(inner, 2);
388 LoopNest::sliceHead(loops[0], 2);
389
390 // for (int x_outer = 0; x_outer < 2; x_outer++) {
391 // for (int x_inner = 0; x_inner < 19; x_inner++) {
392 // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
393 // }
394 // for (int x_inner = 19; x_inner < 21; x_inner++) {
395 // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
396 // }
397 // }
398 // for (int x_outer = 2; x_outer < 4; x_outer++) {
399 // for (int x_inner = 0; x_inner < 19; x_inner++) {
400 // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
401 // }
402 // for (int x_inner = 19; x_inner < 21; x_inner++) {
403 // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner);
404 // }
405 // }
406 // for (int x_tail = 0; x_tail < 16; x_tail++) {
407 // f[x_tail + 84] = 1.f + float(x_tail + 84);
408 // }
409 BlockPtr body = getSimplifiedBody(l);
410 assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}});
411
412 auto biter = body->begin();
413
414 ForPtr loop = to<For>(*biter++);
415 assertForRanges(loop->body(), {{0, 19}, {19, 21}});
416
417 loop = to<For>(*biter);
418 assertForRanges(loop->body(), {{0, 19}, {19, 21}});
419}
420
421TEST(LoopNest, ExprSliceAndNormalize) {
422 // 0: sliceHead
423 // 1: normalize tail
424 auto func = [](const ExprHandle& x) {
425 return ExprHandle(1.0f) + cast<float>(x);
426 };
427 Tensor tensor = Compute("f", {10}, func);
428 LoopNest l({tensor});
429 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
430
431 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
432 ForPtr head;
433 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
434 ForPtr tail;
435 LoopNest::sliceHead(loops[0], 2, &head, &tail);
436 // head: [0, 2)
437 // tail: [2, 10)
438
439 LoopNest::normalize(tail);
440 // normalized_tail: [0, 8)
441
442 BlockPtr body = getSimplifiedBody(l);
443 assertForRanges(body, {{0, 2}, {0, 8}});
444}
445
446template <typename T>
447T evalExpr(const ExprHandle& expr, const VarHandle& var, T value) {
448 ExprEval<SimpleIREvaluator> eval(expr, {var});
449 return eval.value<T>(value);
450}
451
452TEST(LoopNest, ExprSliceWithVariableDimension) {
453 auto testWithDimension =
454 [](int dimension,
455 const std::vector<std::pair<int, int>>& expected_for_ranges) {
456 VarHandle dim("dim", kInt);
457 Tensor tensor =
458 Compute("f", {dim}, [](const ExprHandle& x) { return x; });
459 LoopNest l({tensor});
460 std::vector<ForPtr> loops =
461 l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
462
463 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
464 ForPtr head;
465 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
466 ForPtr tail;
467 LoopNest::sliceHead(loops[0], 2, &head, &tail);
468
469 LoopNest::sliceTail(tail, 2);
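        // With a symbolic dimension, the sliced loops' bounds are expressions
        // in `dim`; evaluate them below for several concrete dimension values.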
470
471 BlockPtr body = getSimplifiedBody(l);
472 ASSERT_EQ(expected_for_ranges.size(), 3);
473 auto it = body->begin();
474 for (auto& start_stop : expected_for_ranges) {
475 ForPtr loop = to<For>(*it++);
476 int start = evalExpr<int>(ExprHandle(loop->start()), dim, dimension);
477 int stop = evalExpr<int>(ExprHandle(loop->stop()), dim, dimension);
478 ASSERT_EQ(start, start_stop.first);
479 ASSERT_EQ(stop, start_stop.second);
480 }
481 };
482
483 testWithDimension(1, {{0, 1}, {1, 1}, {1, 1}});
484 testWithDimension(2, {{0, 2}, {2, 2}, {2, 2}});
485 testWithDimension(3, {{0, 2}, {2, 2}, {2, 3}});
486 testWithDimension(4, {{0, 2}, {2, 2}, {2, 4}});
487 testWithDimension(5, {{0, 2}, {2, 3}, {3, 5}});
488 testWithDimension(10, {{0, 2}, {2, 8}, {8, 10}});
489}
490
491TEST(LoopNest, ExprSplitWithTail) {
492 auto func = [](const ExprHandle& x) {
493 return ExprHandle(1.0f) + cast<float>(x);
494 };
495 Tensor tensor = Compute("f", {199}, func);
496 LoopNest l({tensor});
497 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
498 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
499 LoopNest::splitWithTail(loops[0], 17);
500 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
501 LoopNest::splitWithTail(loops[0], 7);
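  // 199 = 11 * 17 + 12, so the first split yields an outer loop of 11 and a
  // tail of 12. Splitting the 11-iteration outer loop by 7 yields
  // 11 = 1 * 7 + 4. After simplification the single-iteration outermost loop
  // vanishes, leaving three top-level loops of 7, 4, and 12 iterations.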
502
503 StmtPtr stmt = l.root_stmt();
504 StmtPtr simplified = IRSimplifier::simplify(stmt);
505 BlockPtr body = to<Block>(simplified);
506 ASSERT_EQ(body->nstmts(), 3);
507 auto biter = body->begin();
508
509 // Verify that the split loops are ordered correctly.
510 ForPtr loop = to<For>(*biter++);
511 assertForRange(loop, 0, 7);
512
513 loop = to<For>(*biter++);
514 assertForRange(loop, 0, 4);
515
516 loop = to<For>(*biter);
517 assertForRange(loop, 0, 12);
518}
519
520TEST(LoopNest, ExprSplitWithTailNone) {
521 auto func = [](const ExprHandle& x, const ExprHandle& y) {
522 return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
523 };
524 Tensor tensor = Compute("f", {24, 5}, func);
525 LoopNest l({tensor});
526 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
527 LoopNest::splitWithTail(loops[0], 4);
528
529 StmtPtr stmt = l.root_stmt();
530 std::ostringstream oss;
531 oss << *stmt;
532 ASSERT_GT(oss.str().size(), 200);
533 ASSERT_LT(oss.str().size(), 600);
534
535 {
    // Compare against a reference loop structure.
537 VarHandle x_outer("i_outer", kInt);
538 VarHandle x_inner("i_inner", kInt);
539 VarHandle y("i", kInt);
540 VarHandle x_tail("i_tail", kInt);
541 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
542 BufHandle f("f", {24, 5}, kFloat);
543 ExprHandle x_1 = x_outer * 4 + x_inner;
544 ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4;
545 StmtPtr stmt = alloc<Block>(std::vector<StmtPtr>({For::make(
546 x_outer,
547 0,
548 x_outer_end,
549 For::make(
550 x_inner,
551 0,
552 4,
553 For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))}));
554
555 std::ostringstream oss_ref;
556 oss_ref << *stmt;
557 ASSERT_EQ(oss.str(), oss_ref.str());
558 }
559
560 {
561 PaddedBuffer<float> f_v(24, 5, "f_v");
562 PaddedBuffer<float> f_ref(24, 5, "f_res");
563
564 SimpleIREvaluator ir_eval(stmt, {tensor});
565 ir_eval(f_v);
566
567 for (int x = 0; x < 24; x++) {
568 for (int y = 0; y < 5; y++) {
569 f_ref(x, y) = 1 + x * x + y * y;
570 }
571 }
572
573 ExpectAllNear(f_v, f_ref, 1e-5);
574 }
575}
576
577TEST(LoopNest, ExprSplitWithMask01) {
578 const int M = 26;
579 const int N = 5;
580 BufHandle a_buf("a", {M, N}, kFloat);
581 BufHandle b_buf("b", {M, N}, kFloat);
582 Tensor tensor =
583 Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
584 return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f;
585 });
586
587 LoopNest l({tensor});
588 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
589 LoopNest::splitWithMask(loops[1], 4);
590
591 StmtPtr stmt = l.root_stmt();
592
593 PaddedBuffer<float> a_v(M, N, "a");
594 PaddedBuffer<float> b_v(M, N, "b");
595 PaddedBuffer<float> c_v(M, N, "c");
596 PaddedBuffer<float> c_ref(M, N, "c_ref");
597 for (int m = 0; m < M; m++) {
598 for (int n = 0; n < N; n++) {
599 a_v(m, n) = 2 * m;
600 b_v(m, n) = 3 * n;
601 c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f;
602 }
603 }
604
605 SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v);
606
607 ExpectAllNear(c_v, c_ref, 1e-5);
608}
609
// Tests the case where we split a loop cleanly multiple times; no masks should
// be inserted.
612TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) {
613 const int M = 64;
614 BufHandle a_buf("a", {M}, kFloat);
615 BufHandle b_buf("b", {M}, kFloat);
616 Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) {
617 return a_buf.load(m) + b_buf.load(m) + 1.0f;
618 });
619
620 LoopNest l({tensor});
621 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
622 LoopNest::splitWithMask(loops[0], 4);
623 LoopNest::splitWithMask(loops[0], 4);
624
625 StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt());
626
627 // Two splits mean 3 loops, but should need no masks in this case.
628 checkIR(stmt1, R"IR(
629# CHECK: for (
630# CHECK-NOT: if (
631# CHECK: for (
632# CHECK-NOT: if (
633# CHECK: for (
634# CHECK-NOT: if (
635# CHECK: f[)IR");
636}
637
638TEST(LoopNest, getLoopAt) {
639 // Input IR:
640 // for (int i = 0; i < 100; i++) {
641 // for (int j = 0; j < 100; j++) {
642 // A[i, j] = sin(i * j);
643 // for (int k1 = 0; k1 < 200; k1++) {
644 // B[i, j, k1] = (A[i, j]) / (k1 + 1);
645 // }
646 // for (int k2 = 0; k2 < 300; k2++) {
647 // C[i, j, k2] = (A[i, j]) * (k2 + 1);
648 // }
649 // }
650 // }
651 BufPtr A = alloc<Buf>(
652 "A",
653 std::vector<ExprPtr>({alloc<IntImm>(100), alloc<IntImm>(100)}),
654 kInt);
655 BufPtr B = alloc<Buf>(
656 "B",
657 std::vector<ExprPtr>(
658 {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(200)}),
659 kInt);
660 BufPtr C = alloc<Buf>(
661 "C",
662 std::vector<ExprPtr>(
663 {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(300)}),
664 kInt);
665 BufHandle a_buf(A);
666 BufHandle b_buf(B);
667 BufHandle c_buf(C);
668 VarHandle i("i", kInt);
669 VarHandle j("j", kInt);
670 VarHandle k1("k1", kInt);
671 VarHandle k2("k2", kInt);
672 auto store1 = Store::make(a_buf, {i, j}, sin(i * j));
673 auto store2 = Store::make(
674 b_buf, {i, j, k1}, Div::make(Load::make(a_buf, {i, j}), (k1 + 1)));
675 auto store3 = Store::make(
676 c_buf, {i, j, k2}, Mul::make(Load::make(a_buf, {i, j}), (k2 + 1)));
677 auto for_k2 = For::make(k2, 0, 300, Block::make({store3}));
678 auto for_k1 = For::make(k1, 0, 200, Block::make({store2}));
679 auto for_j = For::make(j, 0, 100, Block::make({store1, for_k1, for_k2}));
680 auto for_i = For::make(i, 0, 100, for_j);
681 LoopNest l(Block::make({for_i}), {B, C});
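  // getLoopAt descends into the nest using the given indices: child 0 of for_i
  // is the j loop, and child 2 of the j loop's body is the k2 loop.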
682 auto ret_k2 = l.getLoopAt(for_i, {0, 2});
683 TORCH_CHECK(ret_k2 == for_k2);
684
685 std::ostringstream oss;
686 oss << *ret_k2;
687 const std::string& verification_pattern =
688 R"IR(
689# CHECK: for (int k2
690# CHECK-NEXT: C[i, j, k2] =
691 )IR";
692 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
693}
694
695TEST(LoopNest, TileSimple) {
696 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
697 const int M = 64, N = 64;
698 BufHandle a_buf("a", {M, N}, kFloat);
699 BufHandle b_buf("b", {M, N}, kFloat);
700 Tensor tensor =
701 Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
702 return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f;
703 });
704
705 LoopNest l({tensor});
706 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
707 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
708 l.tile(loops[0], loops[1], 4, 8);
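  // 64 is divisible by both 4 and 8, so tiling should produce no tail loops;
  // the IR check below asserts their absence.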
709
710 // IR check
711 StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
712 checkIR(stmt, R"IR(
713# CHECK: for (int i_outer
714# CHECK: for (int i_outer_1
715# CHECK: for (int i_inner
716# CHECK: for (int i_inner_1
717# CHECK: f[
718# CHECK-NOT: for (int i_tail
719# CHECK-NOT: for (int i_tail)IR");
720
721 // Correctness check
722 PaddedBuffer<float> a_v(M, N, "a");
723 PaddedBuffer<float> b_v(M, N, "b");
724 PaddedBuffer<float> c_v(M, N, "c");
725 PaddedBuffer<float> c_ref(M, N, "c_ref");
726 for (int m = 0; m < M; m++) {
727 for (int n = 0; n < N; n++) {
728 a_v(m, n) = 2 * m;
729 b_v(m, n) = 3 * n;
730 c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f;
731 }
732 }
733
734 SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v);
735
736 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
737 ExpectAllNear(c_v, c_ref, 1e-5);
738}
739
740TEST(LoopNest, TileWithTails) {
741 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
742 const int M = 64, N = 64;
743 BufHandle a_buf("a", {M, N}, kFloat);
744 BufHandle b_buf("b", {M, N}, kFloat);
745 Tensor tensor =
746 Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
747 return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f;
748 });
749
750 LoopNest l({tensor});
751 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
752 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
753 l.tile(loops[0], loops[1], 5, 9);
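  // 64 is not evenly divisible by 5 or 9 (64 = 12 * 5 + 4 = 7 * 9 + 1), so
  // tail loops are expected in the tiled IR.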
754
755 // IR check
756 StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
757 checkIR(stmt, R"IR(
758# CHECK: for (int i_outer
759# CHECK: for (int i_outer_1
760# CHECK: for (int i_inner
761# CHECK: for (int i_inner_1
762# CHECK: f[
763# CHECK: for (int i_inner
764# CHECK: f[
765# CHECK: for (int i_tail)IR");
766
767 // Correctness check
768 PaddedBuffer<float> a_v(M, N, "a");
769 PaddedBuffer<float> b_v(M, N, "b");
770 PaddedBuffer<float> c_v(M, N, "c");
771 PaddedBuffer<float> c_ref(M, N, "c_ref");
772 for (int m = 0; m < M; m++) {
773 for (int n = 0; n < N; n++) {
774 a_v(m, n) = 2 * m;
775 b_v(m, n) = 3 * n;
776 c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f;
777 }
778 }
779
780 SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v);
781
782 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
783 ExpectAllNear(c_v, c_ref, 1e-5);
784}
785
786TEST(LoopNest, TileInMiddle) {
787 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
788 const int M = 8, N = 8, L = 8, K = 8;
789 BufHandle a_buf("a", {M, N, L, K}, kFloat);
790 BufHandle b_buf("b", {M, N, L, K}, kFloat);
791 Tensor tensor = Compute(
792 "f",
793 {M, N, L, K},
794 [&](const ExprHandle& m,
795 const ExprHandle& n,
796 const ExprHandle& l,
797 const ExprHandle& k) {
798 return a_buf.load({m, n, l, k}) + b_buf.load({m, n, l, k}) + 1.0f;
799 });
800
801 LoopNest nest({tensor});
802 std::vector<ForPtr> loops =
803 nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
804 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
805 nest.tile(loops[1], loops[2], 3, 3);
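  // Tile the two middle loops (the n and l dimensions); the outer m loop and
  // the innermost k loop are left untouched.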
806
807 // IR check
808 StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt());
809 checkIR(stmt, R"IR(
810# CHECK: for (int i
811# CHECK: for (int i_outer
812# CHECK: for (int i_outer_1
813# CHECK: for (int i_inner
814# CHECK: for (int i_inner_1
815# CHECK: for (int i_1
816# CHECK: f[
817# CHECK: for (int i_tail_1
818# CHECK: for (int i_inner_1
819# CHECK: for (int i_1
820# CHECK: f[
821# CHECK: for (int i_tail)IR");
822
823 // Correctness check
824 PaddedBuffer<float> a_v(M, N, L, K, "a");
825 PaddedBuffer<float> b_v(M, N, L, K, "b");
826 PaddedBuffer<float> c_v(M, N, L, K, "c");
827 PaddedBuffer<float> c_ref(M, N, L, K, "c_ref");
828 for (int m = 0; m < M; m++) {
829 for (int n = 0; n < N; n++) {
830 for (int l = 0; l < L; l++) {
831 for (int k = 0; k < K; k++) {
832 a_v(m, n, l, k) = 2 * (m + l);
833 b_v(m, n, l, k) = 3 * (n + k);
834 c_ref(m, n, l, k) = a_v(m, n, l, k) + b_v(m, n, l, k) + 1.0f;
835 }
836 }
837 }
838 }
839
840 SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v);
841
842 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
843 ExpectAllNear(c_v, c_ref, 1e-5);
844}
845
846TEST(LoopNest, SplitWithTailWithLoopOptions) {
847 const int M = 21;
848 BufHandle a_buf("a", {M}, kFloat);
849 BufHandle b_buf("b", {M}, kFloat);
850 Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) {
851 return a_buf.load(m) + b_buf.load(m) + 1.0f;
852 });
853 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
854 ForPtr inner, tail;
855
856 LoopNest l({tensor});
857 auto loops = NodeFinder<For>::find(l.root_stmt());
858 ASSERT_GT(loops.size(), 0);
859 loops[0]->set_gpu_block_index(LoopOptions::IDX_Y);
860 LoopNest::splitWithTail(loops[0], 4, &inner, &tail);
861 ASSERT_NE(inner, nullptr);
862 ASSERT_NE(tail, nullptr);
863 ForPtr outer = loops[0];
864
865 // Outer loop carries loop axis bindings.
866 ASSERT_TRUE(outer->loop_options().is_gpu_block_index());
867 ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y);
868
869 // Inner loop has none.
870 ASSERT_TRUE(inner->loop_options().isDefault());
871
872 // Tail loop has none.
873 ASSERT_TRUE(tail->loop_options().isDefault());
874}
875
876TEST(LoopNest, SplitWithMaskWithLoopOptions) {
877 const int M = 21;
878 BufHandle a_buf("a", {M}, kFloat);
879 BufHandle b_buf("b", {M}, kFloat);
880 Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) {
881 return a_buf.load(m) + b_buf.load(m) + 1.0f;
882 });
883 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
884 ForPtr inner;
885
886 LoopNest l({tensor});
887 auto loops = NodeFinder<For>::find(l.root_stmt());
888 loops[0]->set_gpu_block_index(LoopOptions::IDX_Y);
889 LoopNest::splitWithMask(loops[0], 4, &inner);
890 ForPtr outer = loops[0];
891
892 // Outer loop carries loop axis bindings.
893 ASSERT_TRUE(outer->loop_options().is_gpu_block_index());
894 ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y);
895
896 // Inner loop has none.
897 ASSERT_TRUE(inner->loop_options().isDefault());
898}
899
900TEST(LoopNest, ScheduleBroadcastAddBuffer) {
901 const int M = 4;
902 const int N = 5;
903 const int K = 6;
904 BufHandle a_buf("a", {M, N}, kFloat);
905 BufHandle b_buf("b", {N, K}, kFloat);
906 Tensor c = Compute(
907 "broadcast_add",
908 {M, N, K},
909 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
910 return a_buf.load(m, n) + b_buf.load(n, k);
911 });
912 LoopNest l({c});
913 StmtPtr stmt = l.root_stmt();
914
915 PaddedBuffer<float> a_v(M, N, "a_v");
916 for (int m = 0; m < M; m++) {
917 for (int n = 0; n < N; n++) {
918 a_v(m, n) = 7 * m * n;
919 }
920 }
921 a_v.Backup();
922
923 PaddedBuffer<float> b_v(N, K, "b_v");
924 for (int n = 0; n < N; n++) {
925 for (int k = 0; k < K; k++) {
926 b_v(n, k) = 11 * n * k;
927 }
928 }
929 b_v.Backup();
930
931 PaddedBuffer<float> c_v(M, N, K, "c_buf");
932 SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c});
933 ir_eval(a_v, b_v, c_v);
934
935 a_v.CheckBackup();
936 b_v.CheckBackup();
937 PaddedBuffer<float> c_ref(M, N, K, "c_ref");
938 for (int m = 0; m < M; m++) {
939 for (int n = 0; n < N; n++) {
940 for (int k = 0; k < K; k++) {
941 c_ref(m, n, k) = 7 * m * n + 11 * n * k;
942 }
943 }
944 }
945 ExpectAllNear(c_v, c_ref, 1e-5);
946}
947
948TEST(LoopNest, ScheduleFunctionCall01) {
949 const int M = 4;
950 const int N = 5;
951 const int K = 6;
952 BufHandle a_buf("a", {M, N}, kFloat);
953 BufHandle b_buf("b", {N, K}, kFloat);
954 Tensor c = Compute(
955 "broadcast_add",
956 {M, N, K},
957 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
958 return a_buf.load(m, n) + b_buf.load(n, k);
959 });
960 Tensor d = Compute(
961 "d",
962 {M, N, K},
963 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
964 return c.load(m, n, k) + 1;
965 });
966
967 LoopNest l({d}, {c, d});
968 l.prepareForCodegen();
969 StmtPtr stmt = l.root_stmt();
970 std::ostringstream oss;
971 oss << *stmt;
972 ASSERT_GT(oss.str().size(), 100);
973
974 PaddedBuffer<float> a_v(M, N);
975 PaddedBuffer<float> b_v(N, K);
976 PaddedBuffer<float> c_v(M, N, K);
977 PaddedBuffer<float> d_v(M, N, K);
978 PaddedBuffer<float> d_ref(M, N, K);
979
980 for (int i = 0; i < M; i++) {
981 for (int j = 0; j < N; j++) {
982 a_v(i, j) = i * i;
983 }
984 }
985 for (int i = 0; i < N; i++) {
986 for (int j = 0; j < K; j++) {
987 b_v(i, j) = j * j;
988 }
989 }
990 for (int i = 0; i < M; i++) {
991 for (int j = 0; j < N; j++) {
992 for (int k = 0; k < K; k++) {
993 d_ref(i, j, k) = a_v(i, j) + b_v(j, k) + 1;
994 }
995 }
996 }
997
998 SimpleIREvaluator eval(stmt, {a_buf, b_buf, d});
999 eval(a_v, b_v, d_v);
1000
1001 ExpectAllNear(d_v, d_ref, 1e-5);
1002}
1003
1004TEST(LoopNest, ScheduleInlineSimple) {
1005 const int M = 4;
1006 const int N = 5;
1007 const int K = 6;
1008 BufHandle a_buf("a", {M, N}, kFloat);
1009 BufHandle b_buf("b", {N, K}, kFloat);
1010 BufHandle c_buf("c", {M, N}, kFloat);
1011 BufHandle d_buf("d", {M, K}, kFloat);
1012
1013 Tensor x = Compute(
1014 "x",
1015 {M, N, K},
1016 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1017 return a_buf.load(m, n) * b_buf.load(n, k);
1018 });
1019 Tensor y = Compute(
1020 "y",
1021 {M, N, K},
1022 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1023 return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k);
1024 });
1025
1026 LoopNest l1({y}, {x, y});
1027 LoopNest l2(l1);
1028 l2.computeInline(x.buf());
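  // l1 still materializes x into its own buffer; l2 computes x's expression
  // directly inside y. Both must produce the same values, and the inlined
  // version should print to a shorter statement (checked at the end).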
1029
1030 l1.prepareForCodegen();
1031 l2.prepareForCodegen();
1032
1033 StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
1034 StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());
1035
1036 SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y});
1037 SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y});
1038
1039 PaddedBuffer<float> a_v(M, N);
1040 PaddedBuffer<float> b_v(N, K);
1041 PaddedBuffer<float> c_v(M, N);
1042 PaddedBuffer<float> d_v(M, K);
1043
1044 for (int i = 0; i < M; i++) {
1045 for (int j = 0; j < N; j++) {
1046 a_v(i, j) = i * i;
1047 }
1048 }
1049 for (int i = 0; i < N; i++) {
1050 for (int j = 0; j < K; j++) {
1051 b_v(i, j) = j * j;
1052 }
1053 }
1054 for (int i = 0; i < M; i++) {
1055 for (int j = 0; j < N; j++) {
1056 c_v(i, j) = i + j;
1057 }
1058 }
1059 for (int i = 0; i < M; i++) {
1060 for (int j = 0; j < K; j++) {
1061 d_v(i, j) = i * j;
1062 }
1063 }
1064
1065 PaddedBuffer<float> y_1(M, N, K);
1066 PaddedBuffer<float> y_2(M, N, K);
1067
1068 eval1(a_v, b_v, c_v, d_v, y_1);
1069 eval2(a_v, b_v, c_v, d_v, y_2);
1070 ExpectAllNear(y_1, y_2, 1e-5);
1071 std::ostringstream oss1, oss2;
1072 oss1 << *stmt1;
1073 oss2 << *stmt2;
1074 ASSERT_GT(oss1.str().size(), oss2.str().size());
1075}
1076
1077static std::string remove_space(const std::string& str) {
1078 std::string str_new = str;
1079 str_new.erase(
1080 remove_if(str_new.begin(), str_new.end(), isspace), str_new.end());
1081 return str_new;
1082}
1083
1084void InlineFunc01Helper(const std::vector<std::string>& inline_order) {
1085 const int M = 4;
1086 const int N = 5;
1087 const int K = 6;
1088 BufHandle a_buf("a", {M, N}, kFloat);
1089 BufHandle b_buf("b", {N, K}, kFloat);
1090 BufHandle c_buf("c", {M, N}, kFloat);
1091 BufHandle d_buf("d", {M, K}, kFloat);
1092
1093 Tensor x = Compute(
1094 "x",
1095 {M, N, K},
1096 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1097 return a_buf.load(m, n) * b_buf.load(n, k);
1098 });
1099 Tensor y = Compute(
1100 "y",
1101 {M, N, K},
1102 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1103 return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k);
1104 });
1105 Tensor z = Compute(
1106 "z",
1107 {M, N, K},
1108 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1109 return x.load(m, n, k) + y.load(m, n, k);
1110 });
1111
1112 LoopNest l({z}, {x, y, z});
1113 for (const std::string& order : inline_order) {
1114 if (order == "x") {
1115 l.computeInline(x.buf());
1116 } else if (order == "y") {
1117 l.computeInline(y.buf());
1118 } else {
1119 throw std::runtime_error("Invalid order: " + order);
1120 }
1121 }
1122 l.prepareForCodegen();
1123 StmtPtr stmt = l.root_stmt();
1124
1125 std::ostringstream oss;
1126 oss << *stmt;
1127 std::string str1 = remove_space(oss.str());
1128
1129 {
1130 PaddedBuffer<float> a_v(M, N);
1131 PaddedBuffer<float> b_v(N, K);
1132 PaddedBuffer<float> c_v(M, N);
1133 PaddedBuffer<float> d_v(M, K);
1134
1135 for (int i = 0; i < M; i++) {
1136 for (int j = 0; j < N; j++) {
1137 a_v(i, j) = i * i;
1138 }
1139 }
1140 for (int i = 0; i < N; i++) {
1141 for (int j = 0; j < K; j++) {
1142 b_v(i, j) = j * j;
1143 }
1144 }
1145 for (int i = 0; i < M; i++) {
1146 for (int j = 0; j < N; j++) {
1147 c_v(i, j) = i + j;
1148 }
1149 }
1150 for (int i = 0; i < M; i++) {
1151 for (int j = 0; j < K; j++) {
1152 d_v(i, j) = i * j;
1153 }
1154 }
1155
1156 PaddedBuffer<float> z_v(M, N, K);
1157 PaddedBuffer<float> z_ref(M, N, K);
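    // z = x + y = a*b + (c*d + a*b), which simplifies to a*b*2 + c*d.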
1158 for (int m = 0; m < M; m++) {
1159 for (int n = 0; n < N; n++) {
1160 for (int k = 0; k < K; k++) {
1161 z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k);
1162 }
1163 }
1164 }
1165
1166 SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z});
1167 eval(a_v, b_v, c_v, d_v, z_v);
1168 ExpectAllNear(z_v, z_ref, 1e-5);
1169 }
1170
1171 if (inline_order.size() == 2) {
1172 Tensor z2 = Compute(
1173 "z",
1174 {M, N, K},
1175 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1176 return a_buf.load(m, n) * b_buf.load(n, k) +
1177 (c_buf.load(m, n) * d_buf.load(m, k) +
1178 a_buf.load(m, n) * b_buf.load(n, k));
1179 });
1180 LoopNest l2({z2});
1181 l2.prepareForCodegen();
1182 StmtPtr stmt2 = l2.root_stmt();
1183
1184 std::ostringstream oss2;
1185 oss2 << *stmt2;
1186 std::string str2 = remove_space(oss2.str());
1187
1188 ASSERT_EQ(str1, str2);
1189 ASSERT_GT(str1.size(), 100);
1190 }
1191}
1192
1193TEST(LoopNest, ScheduleInlineFunc01) {
1194 InlineFunc01Helper({"x", "y"});
1195 InlineFunc01Helper({"y", "x"});
1196 InlineFunc01Helper({"x"});
1197 InlineFunc01Helper({"y"});
1198 InlineFunc01Helper({});
1199}
1200
1201// Make sure we cache random vars if we should.
1202TEST(LoopNest, ScheduleInlineRandom) {
1203 const int M = 4;
1204 const int N = 5;
1205 const int K = 6;
1206
1207 Tensor x = Compute(
1208 "x",
1209 {M, N, K},
1210 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1211 return Mod::make(Intrinsics::make(kRand, kInt), 5);
1212 });
1213 Tensor y = Compute(
1214 "y",
1215 {M, N, K},
1216 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1217 return x.load(m, n, k) + x.load(m, n, k);
1218 });
1219
1220 LoopNest l1({y}, {x, y});
1221 l1.computeInline(x.buf());
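  // x calls rand() and y loads x twice, so inlining must cache the random
  // value in a single variable; both uses then see the same value, giving the
  // `2 * (x % 5)` in the IR check below.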
1222
1223 // would normally compare results but Rand isn't implemented in the
1224 // SimpleIREvaluator, even if we could seed it.
1225 StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
1226
1227 // Check the IR we produced
1228 checkIR(stmt1, R"IR(
1229# CHECK: for (int i = 0; i < 4; i++)
1230# CHECK: for (int i_1 = 0; i_1 < 5; i_1++)
1231# CHECK: for (int i_2 = 0; i_2 < 6; i_2++)
1232# CHECK: int x = rand();
1233# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR");
1234}
1235
1236// Make sure we don't cache random vars that are not being inlined.
1237TEST(LoopNest, ScheduleInlineRandomUnrelated) {
1238 const int M = 4;
1239 const int N = 5;
1240 const int K = 6;
1241
1242 Tensor x = Compute(
1243 "x",
1244 {M, N, K},
1245 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1246 return m * n * k;
1247 });
1248 Tensor y = Compute(
1249 "y",
1250 {M, N, K},
1251 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1252 return x.load(m, n, k) + Intrinsics::make(kRand, kInt) +
1253 Intrinsics::make(kRand, kInt);
1254 });
1255
1256 LoopNest l1({y}, {x, y});
1257 l1.computeInline(x.buf());
1258
1259 // would normally compare results but Rand isn't implemented in the
1260 // SimpleIREvaluator, even if we could seed it.
1261 StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
1262
1263 // Check the IR we produced
1264 checkIR(stmt1, R"IR(
1265# CHECK: for (int i = 0; i < 4; i++)
1266# CHECK: for (int i_1 = 0; i_1 < 5; i_1++)
1267# CHECK: for (int i_2 = 0; i_2 < 6; i_2++)
1268# CHECK: y[i, i_1, i_2] = ((i * i_1) * i_2 + (rand())) + (rand());)IR");
1269}
1270
// Make sure the number of random values generated matches the dimensionality
// of the producer tensor.
1273TEST(LoopNest, ScheduleInlineRandomLowerDimensions) {
1274 const int M = 4;
1275 const int N = 5;
1276 const int K = 6;
1277
1278 Tensor x = Compute("x", {M}, [&](const VarHandle& m) {
1279 return Mod::make(Intrinsics::make(kRand, kInt), 5);
1280 });
1281 Tensor y = Compute(
1282 "y",
1283 {M, N, K},
1284 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1285 return x.load(m) + x.load(m);
1286 });
1287
1288 LoopNest l1({y}, {x, y});
1289 l1.computeInline(x.buf());
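  // x depends only on m, so the cached rand() value should be hoisted to the
  // outermost loop rather than regenerated in the inner loops.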
1290
1291 // would normally compare results but Rand isn't implemented in the
1292 // SimpleIREvaluator, even if we could seed it.
1293 StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
1294
1295 // Check the IR we produced
1296 checkIR(stmt1, R"IR(
1297# CHECK: for (int i = 0; i < 4; i++)
1298# CHECK: int x = rand();
1299# CHECK: for (int i_1 = 0; i_1 < 5; i_1++)
1300# CHECK: for (int i_2 = 0; i_2 < 6; i_2++)
1301# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR");
1302}
1303
1304// Make sure we don't screw up intrinsics thinking they're rand.
1305TEST(LoopNest, ScheduleInlineIntrinsics) {
1306 const int M = 4;
1307 const int N = 5;
1308 const int K = 6;
1309 BufHandle a_buf("a", {M, N}, kFloat);
1310 BufHandle b_buf("b", {N, K}, kFloat);
1311
1312 Tensor x = Compute(
1313 "x",
1314 {M, N, K},
1315 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1316 return a_buf.load(m, n) * b_buf.load(n, k);
1317 });
1318 Tensor y = Compute(
1319 "y",
1320 {M, N, K},
1321 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1322 return Intrinsics::make(kSqrt, x.load(m, n, k));
1323 });
1324
1325 PaddedBuffer<float> a_v(M, N);
1326 PaddedBuffer<float> b_v(N, K);
1327
1328 for (int i = 0; i < M; i++) {
1329 for (int j = 0; j < N; j++) {
1330 a_v(i, j) = i * i;
1331 }
1332 }
1333 for (int i = 0; i < N; i++) {
1334 for (int j = 0; j < K; j++) {
1335 b_v(i, j) = j * j;
1336 }
1337 }
1338
1339 LoopNest l1({y}, {x, y});
1340 LoopNest l2(l1);
1341 l2.computeInline(x.buf());
1342
1343 l1.prepareForCodegen();
1344 l2.prepareForCodegen();
1345
1346 StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
1347 StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());
1348
1349 SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
1350 SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});
1351
1352 PaddedBuffer<float> y_1(M, N, K);
1353 PaddedBuffer<float> y_2(M, N, K);
1354
1355 eval1(a_v, b_v, y_1);
1356 eval2(a_v, b_v, y_2);
1357 ExpectAllNear(y_1, y_2, 1e-5);
1358 std::ostringstream oss1, oss2;
1359 oss1 << *stmt1;
1360 oss2 << *stmt2;
1361 ASSERT_GT(oss1.str().size(), oss2.str().size());
1362}
1363
1364// Make sure we can handle rand and non-rand intrinsics.
1365TEST(LoopNest, ScheduleInlineRandWithIntrinsics) {
1366 const int M = 4;
1367 const int N = 5;
1368 const int K = 6;
1369
1370 Tensor x = Compute(
1371 "x",
1372 {M, N, K},
1373 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1374 return Intrinsics::make(kRand, kFloat);
1375 });
1376 Tensor y = Compute(
1377 "y",
1378 {M, N, K},
1379 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1380 return Intrinsics::make(kSqrt, x.load(m, n, k));
1381 });
1382
1383 LoopNest l1({y}, {x, y});
1384 l1.computeInline(x.buf());
1385
1386 StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
1387
1388 // Check the IR we produced
1389 checkIR(stmt1, R"IR(
1390# CHECK: for (int i = 0; i < 4; i++)
1391# CHECK: for (int i_1 = 0; i_1 < 5; i_1++)
1392# CHECK: for (int i_2 = 0; i_2 < 6; i_2++)
1393# CHECK: float x = rand();
1394# CHECK: y[i, i_1, i_2] = sqrt(x);)IR");
1395}
1396
1397// Split a Compute then inline it into another compute.
1398TEST(LoopNest, ScheduleSplitAThenInline) {
1399 Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
1400 Tensor b = Compute(
1401 "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
1402
1403 LoopNest l({b}, {a, b});
1404 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
1405 LoopNest::splitWithMask(loops[0], 4);
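  // After the split, a's store is indexed by a compound expression of the
  // outer and inner loop variables, so inlining a is expected to fail.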
1406 ASSERT_FALSE(l.computeInline(a.buf()));
1407}
1408
1409// Split a Compute then inline another Compute into it.
1410TEST(LoopNest, ScheduleSplitBThenInline) {
1411 Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
1412 Tensor b = Compute(
1413 "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
1414
1415 LoopNest l({b}, {a, b});
1416 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0);
1417 LoopNest::splitWithMask(loops[0], 3);
1418 l.computeInline(a.buf());
1419 l.prepareForCodegen();
1420 StmtPtr s = IRSimplifier::simplify(l.root_stmt());
1421
1422 std::vector<int> output(6, 0);
1423 SimpleIREvaluator eval(s, {b});
1424 eval(output);
1425
1426 for (int i = 0; i < 6; ++i) {
1427 ASSERT_EQ(output[i], (i + 8) * (i + 8));
1428 }
1429}
1430
1431// Split a Compute twice then inline it.
1432TEST(LoopNest, ScheduleSplitTwiceThenInline) {
1433 Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
1434 Tensor b = Compute(
1435 "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
1436 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1437 ForPtr i_inner;
1438
1439 LoopNest l({b}, {a, b});
1440 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
1441 LoopNest::splitWithMask(loops[0], 4, &i_inner);
1442 LoopNest::splitWithMask(i_inner, 2);
1443 ASSERT_FALSE(l.computeInline(a.buf()));
1444}
1445
1446// Inline a Compute, then split.
1447TEST(LoopNest, ScheduleInlineThenSplit) {
1448 Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
1449 Tensor b = Compute(
1450 "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
1451
1452 LoopNest l({b}, {a, b});
1453 l.computeInline(a.buf());
1454
1455 std::vector<ForPtr> loops = NodeFinder<For>::find(l.root_stmt());
1456 LoopNest::splitWithMask(loops.back(), 3);
1457 l.prepareForCodegen();
1458 StmtPtr s = IRSimplifier::simplify(l.root_stmt());
1459 std::vector<int> output(6, 0);
1460 SimpleIREvaluator eval(s, {b});
1461 eval(output);
1462
1463 for (int i = 0; i < 6; ++i) {
1464 ASSERT_EQ(output[i], (i + 8) * (i + 8));
1465 }
1466}
1467
1468// Split a Compute, inline it, then split the result.
1469TEST(LoopNest, ScheduleSplitInlineThenSplit) {
1470 Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
1471 Tensor b = Compute(
1472 "b", {16}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
1473
1474 LoopNest l({b}, {a, b});
1475 auto loops = NodeFinder<For>::find(l.root_stmt());
1476 LoopNest::splitWithMask(loops.back(), 2);
1477 l.computeInline(a.buf());
1478
1479 loops = NodeFinder<For>::find(l.root_stmt());
1480 LoopNest::splitWithMask(loops.front(), 2);
1481 l.prepareForCodegen();
1482 StmtPtr s = IRSimplifier::simplify(l.root_stmt());
1483 std::vector<int> output(16, 0);
1484 SimpleIREvaluator eval(s, {b});
1485 eval(output);
1486
1487 for (int i = 0; i < 16; ++i) {
1488 ASSERT_EQ(output[i], (i + 8) * (i + 8));
1489 }
1490}
1491
1492// Oversplit a loop that is simplified out after inlining.
1493TEST(LoopNest, ScheduleSplitInlineSimplify) {
1494 Tensor a = Compute("a", {18}, [&](const VarHandle& i) {
1495 return ExprHandle(4) * i - ExprHandle(2) * i;
1496 });
1497 Tensor b = Compute(
1498 "b", {2}, [&](const VarHandle& j) { return a.load(j) - ExprHandle(1); });
1499
1500 LoopNest l({b}, {a, b});
1501 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
1502 LoopNest::splitWithMask(loops[0], 4);
1503 ASSERT_FALSE(l.computeInline(a.buf()));
1504}
1505
1506// Inline a Compute with two consumers.
1507TEST(LoopNest, ScheduleInlineThreeMixedOnce) {
1508 Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
1509 Tensor b = Compute(
1510 "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
1511 Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
1512 return a.load(k) * b.load(l);
1513 });
1514
1515 LoopNest l({c}, {a, b, c});
1516 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
1517 l.computeInline(a.buf());
1518 l.prepareForCodegen();
1519
1520 StmtPtr s = IRSimplifier::simplify(l.root_stmt());
1521 std::vector<int> output(4 * 3, 0);
1522 SimpleIREvaluator eval(s, {c});
1523 eval(output);
1524
1525 for (int k = 0; k < 4; ++k) {
1526 for (int l = 0; l < 3; ++l) {
1527 ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8));
1528 }
1529 }
1530}
1531
1532// Inline Compute A into B, then inline B into C.
1533TEST(LoopNest, ScheduleInlineThreeMixedTwice) {
1534 Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
1535 Tensor b = Compute(
1536 "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
1537 Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
1538 return a.load(k) * b.load(l);
1539 });
1540
1541 LoopNest l({c}, {a, b, c});
1542 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
1543 l.computeInline(a.buf());
1544 l.computeInline(b.buf());
1545 l.prepareForCodegen();
1546
1547 StmtPtr s = IRSimplifier::simplify(l.root_stmt());
1548 std::vector<int> output(4 * 3, 0);
1549 SimpleIREvaluator eval(s, {c});
1550 eval(output);
1551
1552 for (int k = 0; k < 4; ++k) {
1553 for (int l = 0; l < 3; ++l) {
1554 ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8));
1555 }
1556 }
1557}
1558
1559// Inline a Compute that is both a producer and consumer.
1560TEST(LoopNest, ScheduleInlineThreeMixedInner) {
1561 Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
1562 Tensor b = Compute(
1563 "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
1564 Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
1565 return a.load(k) * b.load(l);
1566 });
1567
1568 LoopNest l({c}, {a, b, c});
1569 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
1570 l.computeInline(b.buf());
1571 l.prepareForCodegen();
1572
1573 StmtPtr s = IRSimplifier::simplify(l.root_stmt());
1574 std::vector<int> output(4 * 3, 0);
1575 SimpleIREvaluator eval(s, {c});
1576 eval(output);
1577
1578 for (int k = 0; k < 4; ++k) {
1579 for (int l = 0; l < 3; ++l) {
1580 ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8));
1581 }
1582 }
1583}
1584
1585// Split 3 Computes, then inline the first two into the last.
1586TEST(LoopNest, ScheduleInlineThreeMixedSplit) {
1587 Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
1588 Tensor b = Compute(
1589 "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
1590 Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
1591 return a.load(k) * b.load(l);
1592 });
1593
1594 LoopNest l({c}, {a, b, c});
1595 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
1596 LoopNest::splitWithMask(loops[0], 4);
1597 loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0);
1598 LoopNest::splitWithMask(loops[0], 3);
1599 loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
1600 LoopNest::splitWithMask(loops[0], 2);
1601
1602 ASSERT_FALSE(l.computeInline(a.buf()));
1603}
1604
1605// Check that inlining works for output tensors too
1606TEST(LoopNest, ScheduleInlineOutputTensors) {
1607 const int M = 4;
1608 const int N = 5;
1609 const int K = 6;
1610
1611 Tensor x = Compute(
1612 "x",
1613 {M, N, K},
1614 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1615 return m * n * k;
1616 });
1617 Tensor y = Compute(
1618 "y",
1619 {M, N, K},
1620 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
1621 return x.load(m, n, k) + m;
1622 });
1623
1624 LoopNest l1({x, y});
1625 l1.computeInline(x.buf());
1626
  // x is an output buffer, so it must still be written even though its
  // expression also gets inlined into y (see the IR check below).
1629 StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
1630
1631 // Check the IR we produced
1632 checkIR(stmt1, R"IR(
1633# CHECK: for (int i = 0; i < 4; i++)
1634# CHECK: for (int i_1 = 0; i_1 < 5; i_1++)
1635# CHECK: for (int i_2 = 0; i_2 < 6; i_2++)
1636# CHECK: x[i, i_1, i_2] = (i * i_1) * i_2;
1637# CHECK: for (int i_3 = 0; i_3 < 4; i_3++)
1638# CHECK: for (int i_4 = 0; i_4 < 5; i_4++)
1639# CHECK: for (int i_5 = 0; i_5 < 6; i_5++)
1640# CHECK: y[i_3, i_4, i_5] = i_3 + (i_3 * i_4) * i_5;)IR");
1641}
1642
1643TEST(LoopNest, ScheduleInlineWithCompoundIndices) {
1644 // Input IR:
1645 // for (int64_t i = 0; i < 100; i++) {
1646 // A[i*2,i] = i * 500ll;
1647 // }
1648 // for (int64_t j = 0; j < 100; j++) {
1649 // B[0ll,j] = A[0, j] + j * 100ll;
1650 // }
1651 BufHandle a_buf("A", {20, 100}, kLong);
1652 BufHandle b_buf("B", {20, 100}, kLong);
1653 VarHandle i("i", kLong);
1654 VarHandle j("j", kLong);
1655 auto forI = For::make(
1656 i,
1657 0,
1658 100,
1659 Store::make(a_buf, {i * 2, i}, Mul::make(i, static_cast<int64_t>(500))));
1660 auto forJ = For::make(
1661 j,
1662 0,
1663 100,
1664 Store::make(
1665 b_buf,
1666 {static_cast<int64_t>(0), j},
1667 Add::make(
1668 Load::make(a_buf, {static_cast<int64_t>(0), j}),
1669 Mul::make(j, static_cast<int64_t>(100)))));
1670 auto par = Block::make({forI, forJ});
1671
1672 LoopNest l(par, {b_buf.node()});
1673 // Inlining should fail since the producer has compound expr as index.
1674 ASSERT_FALSE(l.computeInline(a_buf.node()));
1675
1676 // The input statement must remain as is.
1677 checkIR(l.root_stmt(), R"IR(
1678 # CHECK: for (int64_t i = 0;
1679 # CHECK-NEXT: A[
1680 # CHECK: for (int64_t j = 0;
1681 # CHECK-NEXT: B[)IR");
1682}
1683
1684TEST(LoopNest, ScheduleInlineConsumerIndicesWithCast) {
1685 // Input IR:
1686 // for (int64_t i = 0; i < 100; i++) {
1687 // A[0ll,i] = i * 500ll;
1688 // }
1689 // for (int64_t j = 0; j < 100; j++) {
1690 // B[0ll,j] = A[(int64_t)0, j] + j * 100ll;
1691 // }
1692 BufHandle a_buf("A", {20, 100}, kLong);
1693 BufHandle b_buf("B", {20, 100}, kLong);
1694 VarHandle i("i", kLong);
1695 VarHandle j("j", kLong);
1696 auto forI = For::make(
1697 i,
1698 0,
1699 100,
1700 Store::make(
1701 a_buf,
1702 {static_cast<int64_t>(0), i},
1703 Mul::make(i, static_cast<int64_t>(500))));
1704 auto forJ = For::make(
1705 j,
1706 0,
1707 100,
1708 Store::make(
1709 b_buf,
1710 {static_cast<int64_t>(0), j},
1711 Add::make(
1712 Load::make(a_buf, {0, j}),
1713 Mul::make(j, static_cast<int64_t>(100)))));
1714 auto par = Block::make({forI, forJ});
1715
1716 LoopNest l(par, {b_buf.node()});
1717 ASSERT_TRUE(l.computeInline(a_buf.node()));
1718
1719 checkIR(l.root_stmt(), R"IR(
1720 # CHECK: for (int64_t j = 0; j < 100; j++) {
1721 # CHECK: B[0ll, j] = j * 500ll + j * 100ll;
1722 # CHECK: })IR");
1723}
1724
1725TEST(LoopNest, ScheduleInlineProducerIndicesWithCast) {
1726 // Input IR:
1727 // for (int64_t i = 0; i < 100; i++) {
1728 // A[(int64_t)0,i] = i * 500ll;
1729 // }
1730 // for (int64_t j = 0; j < 100; j++) {
1731 // B[0ll,j] = A[0ll, j] + j * 100ll;
1732 // }
1733 BufHandle a_buf("A", {20, 100}, kLong);
1734 BufHandle b_buf("B", {20, 100}, kLong);
1735 VarHandle i("i", kLong);
1736 VarHandle j("j", kLong);
1737 auto forI = For::make(
1738 i,
1739 0,
1740 100,
1741 Store::make(a_buf, {0, i}, Mul::make(i, static_cast<int64_t>(500))));
1742 auto forJ = For::make(
1743 j,
1744 0,
1745 100,
1746 Store::make(
1747 b_buf,
1748 {static_cast<int64_t>(0), j},
1749 Add::make(
1750 Load::make(a_buf, {static_cast<int64_t>(0), j}),
1751 Mul::make(j, static_cast<int64_t>(100)))));
1752 auto par = Block::make({forI, forJ});
1753
1754 LoopNest l(par, {b_buf.node()});
1755 ASSERT_TRUE(l.computeInline(a_buf.node()));
1756
1757 checkIR(l.root_stmt(), R"IR(
1758 # CHECK: for (int64_t j = 0; j < 100; j++) {
1759 # CHECK: B[0ll, j] = j * 500ll + j * 100ll;
1760 # CHECK: })IR");
1761}
1762
1763TEST(LoopNest, ScheduleFuserStyle) {
1764 const int kVectorSize = 8;
1765 const int kVectorCount = 128;
1766 const int kTotalSize = kVectorSize * kVectorCount;
1767
1768 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
1769
1770 Tensor b =
1771 Compute("f", {kTotalSize}, [&](const std::vector<VarHandle>& axes) {
1772 return a_buf.load(axes[0]) + 11.0f;
1773 });
1774
1775 Tensor c =
1776 Compute("g", {kTotalSize}, [&](const std::vector<VarHandle>& axes) {
1777 return b.load(axes[0]) + 1.0f;
1778 });
1779
1780 LoopNest l({b, c});
1781 l.prepareForCodegen();
1782 StmtPtr s = l.root_stmt();
1783
1784 std::vector<float> a_data(kTotalSize, 7.0f);
1785 std::vector<float> b_data(kTotalSize, 0.0f);
1786 std::vector<float> c_data(kTotalSize, 0.0f);
1787 SimpleIREvaluator(s, {a_buf, b, c})(a_data, b_data, c_data);
1788
1789 for (int i = 0; i < kTotalSize; i++) {
1790 ASSERT_EQ(b_data[i], 18.0f);
1791 ASSERT_EQ(c_data[i], 19.0f);
1792 }
1793}
1794
1795TEST(LoopNest, ScheduleFuserThreeArg) {
1796 const int kVectorSize = 8;
1797 const int kVectorCount = 128;
1798 const int kTotalSize = kVectorSize * kVectorCount;
1799
1800 BufHandle a("A", {ExprHandle(kTotalSize)}, kFloat);
1801 BufHandle b("B", {ExprHandle(kTotalSize)}, kFloat);
1802 BufHandle c("C", {ExprHandle(kTotalSize)}, kFloat);
1803 BufHandle d("D", {ExprHandle(kTotalSize)}, kFloat);
1804
1805 Tensor e = Compute("e", {kTotalSize}, [&](const VarHandle& i) {
1806 return a.load(i) + b.load(i);
1807 });
1808 Tensor f = Compute("f", {kTotalSize}, [&](const VarHandle& i) {
1809 return e.load(i) + c.load(i);
1810 });
1811 Tensor g = Compute("g", {kTotalSize}, [&](const VarHandle& i) {
1812 return f.load(i) + d.load(i);
1813 });
1814
1815 LoopNest l({g}, {e, f, g});
1816 l.computeInline(l.getLoopBodyFor(e));
1817 l.computeInline(l.getLoopBodyFor(f));
1818 l.prepareForCodegen();
1819 StmtPtr s = l.root_stmt();
1820
1821 std::vector<float> a_data(kTotalSize, 1.0f);
1822 std::vector<float> b_data(kTotalSize, 2.0f);
1823 std::vector<float> c_data(kTotalSize, 3.0f);
1824 std::vector<float> d_data(kTotalSize, 4.0f);
1825 std::vector<float> g_data(kTotalSize, 0.0f);
1826 SimpleIREvaluator(s, {a, b, c, d, g})(a_data, b_data, c_data, d_data, g_data);
1827
1828 for (int i = 0; i < kTotalSize; i++) {
1829 ASSERT_EQ(g_data[i], 10.0f);
1830 }
1831}
1832
1833TEST(LoopNest, ScheduleDynamicShape2D) {
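  // Computes c[i,j] = a[i,j] + b[i,j] over runtime extents m and n, then
  // evaluates the same statement with several concrete sizes to make sure
  // dynamic shapes work end to end.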
1834 auto testWithSize = [](int32_t M, int32_t N) {
1835 VarHandle m("m", kInt);
1836 VarHandle n("n", kInt);
1837 BufHandle a("a", {m, n}, kFloat);
1838 BufHandle b("b", {m, n}, kFloat);
1839 Tensor c =
1840 Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) {
1841 return a.load(i, j) + b.load(i, j);
1842 });
1843 LoopNest l({c});
1844 StmtPtr s = l.root_stmt();
1845 SimpleIREvaluator cg(s, {a, b, c, m, n});
1846 std::vector<float> aData(M * N, 1.0f);
1847 std::vector<float> bData(M * N, 2.0f);
1848 std::vector<float> cData(M * N, 0.0f);
1849 cg.call({aData, bData, cData, M, N});
1850 ExpectAllNear(cData, std::vector<float>(M * N, 3.0f), 1e-7);
1851 };
1852 testWithSize(1, 8);
1853 testWithSize(16, 32);
1854 testWithSize(37, 11);
1855}
1856
1857TEST(LoopNest, LoopNestComputeAt_1) {
1858 // Verify that compute_at works on the following example:
1859 //
1860 // for (int i_a = 0; i_a < N; i_a++) {
1861 // A[i_a] = i_a * i_a
1862 // }
1863 // for (int i_b = 0; i_b < N; i_b++) {
1864 // B[i_b] = A[i_b]
1865 // }
1866 //
1867 // After the transformation the i_b loop should have an allocation for a temp
1868 // buffer and that buffer should be used in computation of B. No use of A
1869 // should be in that loop after the transformation. Also, computation of A
1870 // should not be inlined into B. Instead, it should be computed into the temp,
1871 // and the temp should be used in B.
1872 VarHandle N("N", kInt);
1873 Tensor A = Compute("A", {N}, [&](const VarHandle& i_a) { return i_a * i_a; });
1874 Tensor B =
1875 Compute("B", {N}, [&](const VarHandle& i_b) { return A.load(i_b); });
1876 LoopNest l({B}, {A, B});
1877 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0);
1878 LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
1879 l.prepareForCodegen();
1880 SimpleIREvaluator cg(l.root_stmt(), {B, N});
1881 StmtPtr s = cg.stmt();
1882
1883 checkIR(s, R"IR(
1884# CHECK: Allocate(temp); // dtype=int, dims=[1]
1885# CHECK: for (int i = 0; i < N; i++)
1886# CHECK: temp[
1887# CHECK-NOT: A[
1888# CHECK: B[i_1] = temp[0]
1889# CHECK: Free(temp))IR");
1890
1891 // Now check that the loop still produces the correct result.
1892 std::vector<int> b_data(100, 0);
1893 cg.call({b_data, 100});
1894
1895 std::vector<int> b_ref(100, 0);
1896 for (int i = 0; i < 100; i++) {
1897 b_ref[i] = i * i;
1898 }
1899 assertAllEqual(b_data, b_ref);
1900}
1901
1902TEST(LoopNest, LoopNestComputeAt_2) {
1903 // Verify that compute_at works on the following example:
1904 //
1905 // for (int py = 0; py < H+1; py++) {
1906 // for (int px = 0; px < W+1; px++) {
1907 // p[py, px] = py*px
1908 // }
1909 // }
1910 // for (int cy = 0; cy < H; cy++) {
1911 // for (int cx = 0; cx < W; cx++) {
1912  //     c[cy, cx] = p[cy,cx] + p[cy+1,cx] +
1913 // p[cy,cx+1] + p[cy+1,cx+1]
1914 // }
1915 // }
1916
1917 const int kW = 16, kH = 16;
1918 VarHandle W("W", kInt);
1919 VarHandle H("H", kInt);
1920 Tensor p = Compute(
1921 "prod", {H + 1, W + 1}, [&](const VarHandle& py, const VarHandle& px) {
1922 return px * py;
1923 });
1924 Tensor c =
1925 Compute("cons", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
1926 return p.load(y, x) + p.load(y + 1, x) + p.load(y, x + 1) +
1927 p.load(y + 1, x + 1);
1928 });
1929
1930 std::vector<int> c_ref(kW * kH, 0);
1931 for (int y = 0; y < kH; y++) {
1932 for (int x = 0; x < kW; x++) {
1933 c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1);
1934 }
1935 }
1936 LoopNest orig_loopnest({c}, {p, c});
1937
1938 {
1939 // First let's try to compute P at axis cy (the outer loop)
1940 LoopNest l(orig_loopnest);
1941 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
1942 LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]);
1943 l.prepareForCodegen();
1944 SimpleIREvaluator cg(l.root_stmt(), {c, W, H});
1945 StmtPtr s = cg.stmt();
1946
1947 // Check the IR we produced
1948 checkIR(s, R"IR(
1949# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
1950# CHECK: for (int i_2 = 0; i_2 < H; i_2++)
1951# CHECK: for
1952# CHECK: for
1953# CHECK: for (int i_3 = 0; i_3 < W; i_3++)
1954# CHECK-NOT: prod[
1955# CHECK: cons[
1956# CHECK: Free(temp))IR");
1957
1958 // Now check that the loop still produces the correct result.
1959 std::vector<int> c_data(kW * kH, 0);
1960 cg.call({c_data, kW, kH});
1961
1962 assertAllEqual(c_data, c_ref);
1963 }
1964 {
1965 // Now let's try to compute P at axis cx (the inner loop)
1966 LoopNest l(orig_loopnest);
1967 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
1968 LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]);
1969 l.prepareForCodegen();
1970 SimpleIREvaluator cg(l.root_stmt(), {c, W, H});
1971 StmtPtr s = cg.stmt();
1972
1973 // Check the IR we produced
1974 checkIR(s, R"IR(
1975# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
1976# CHECK: for (int i_2 = 0; i_2 < H; i_2++)
1977# CHECK: for (int i_3 = 0; i_3 < W; i_3++)
1978# CHECK: for
1979# CHECK: for
1980# CHECK-NOT: prod[
1981# CHECK: cons[
1982# CHECK: Free(temp))IR");
1983
1984 // Now check that the loop still produces the correct result.
1985 std::vector<int> c_data(kW * kH, 0);
1986 cg.call({c_data, kW, kH});
1987
1988 assertAllEqual(c_data, c_ref);
1989 }
1990}
1991
1992TEST(LoopNest, LoopNestComputeAt_3) {
1993 // Verify that compute_at works on the following example:
1994 //
1995 // A(x,y) = x*y
1996 // B(x,y) = A(x, y)
1997 // C(x,y) = B(x+1, y)
1998 // D(x,y) = A(x, y+1) + C(x, y)
1999 //
2000 // i.e. when 'A' comes to 'D' directly and indirectly through 'C'.
2001
2002 const int kW = 16, kH = 16;
2003 VarHandle W("W", kInt);
2004 VarHandle H("H", kInt);
2005 Tensor A = Compute(
2006 "A", {H + 1, W + 1}, [&](const VarHandle& ay, const VarHandle& ax) {
2007 return ax * ay;
2008 });
2009 Tensor B = Compute(
2010 "B", {H + 1, W + 1}, [&](const VarHandle& by, const VarHandle& bx) {
2011 return A.load(by, bx);
2012 });
2013 Tensor C =
2014 Compute("C", {H, W}, [&](const VarHandle& cy, const VarHandle& cx) {
2015 return B.load(cy, cx + 1);
2016 });
2017 Tensor D =
2018 Compute("D", {H, W}, [&](const VarHandle& dy, const VarHandle& dx) {
2019 return A.load(dy + 1, dx) + C.load(dy, dx);
2020 });
2021
2022 std::vector<int> c_ref(kW * kH, 0);
2023 for (int y = 0; y < kH; y++) {
2024 for (int x = 0; x < kW; x++) {
2025 c_ref[y * kW + x] = (y + 1) * x + y * (x + 1);
2026 }
2027 }
2028
2029 LoopNest orig_loopnest({D}, {A, B, C, D});
2030 {
2031 // First let's try to compute A at axis dy (the outer loop)
2032 LoopNest l(orig_loopnest);
2033 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0);
2034 LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
2035 l.prepareForCodegen();
2036 SimpleIREvaluator cg(l.root_stmt(), {D, W, H});
2037 StmtPtr s = cg.stmt();
2038
2039 // Check the IR we produced
2040 checkIR(s, R"IR(
2041# CHECK: Allocate(temp); // dtype=int, dims=[1, W]
2042# CHECK: for (int i = 0; i < H + 1; i++)
2043# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++)
2044# CHECK: A[
2045# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++)
2046# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++)
2047# CHECK: B[
2048# CHECK: for (int i_4 = 0; i_4 < H; i_4++)
2049# CHECK: for (int i_5 = 0; i_5 < W; i_5++)
2050# CHECK: C[
2051# CHECK: for (int i_6 = 0; i_6 < H; i_6++)
2052# CHECK: for (int i_7 = 0; i_7 < W; i_7++)
2053# CHECK-NOT: A[)IR");
2054
2055 // Now check that the loop still produces the correct result.
2056 std::vector<int> c_data(kW * kH, 0);
2057 cg.call({c_data, kW, kH});
2058
2059 assertAllEqual(c_data, c_ref);
2060 }
2061 {
2062 // Now let's try to compute A at axis dx (the inner loop)
2063 LoopNest l(orig_loopnest);
2064 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0);
2065 LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]);
2066 l.prepareForCodegen();
2067 SimpleIREvaluator cg(l.root_stmt(), {D, W, H});
2068 StmtPtr s = cg.stmt();
2069
2070 // Check the IR we produced
2071 checkIR(s, R"IR(
2072# CHECK: Allocate(temp); // dtype=int, dims=[1, 1]
2073# CHECK: for (int i = 0; i < H + 1; i++)
2074# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++)
2075# CHECK: A[
2076# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++)
2077# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++)
2078# CHECK: B[
2079# CHECK: for (int i_4 = 0; i_4 < H; i_4++)
2080# CHECK: for (int i_5 = 0; i_5 < W; i_5++)
2081# CHECK: C[
2082# CHECK: for (int i_6 = 0; i_6 < H; i_6++)
2083# CHECK: for (int i_7 = 0; i_7 < W; i_7++)
2084# CHECK-NOT: A[)IR");
2085
2086 // Now check that the loop still produces the correct result.
2087 std::vector<int> c_data(kW * kH, 0);
2088 cg.call({c_data, kW, kH});
2089
2090 assertAllEqual(c_data, c_ref);
2091 }
2092}
2093
2094using Axis = const VarHandle&;
2095
2096TEST(LoopNest, Reduce2dComputeAt) {
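  // "cons" is a 2x2 box sum over "prod":
  //   cons[y, x] = sum over r, s in [0, 2) of prod[y + r, x + s]
  // The test computes "prod" at the outer and at the inner consumer loop and
  // checks both the produced IR and the numeric result.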
2097 const int kW = 16, kH = 16;
2098 VarHandle W("W", kInt);
2099 VarHandle H("H", kInt);
2100
2101 Tensor p = Compute(
2102 "prod", {H + 1, W + 1}, [&](Axis py, Axis px) { return px * py; });
2103 Tensor c = Reduce(
2104 "cons",
2105 {H, W},
2106 Sum(),
2107 [&](Axis y, Axis x, Axis r, Axis s) { return p.load(y + r, x + s); },
2108 {2, 2});
2109
2110 std::vector<int> c_ref(kW * kH, 0);
2111 for (int y = 0; y < kH; y++) {
2112 for (int x = 0; x < kW; x++) {
2113 c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1);
2114 }
2115 }
2116 LoopNest orig_loopnest({c}, {p, c});
2117 checkIR(orig_loopnest.root_stmt(), R"IR(
2118# CHECK: for (int i = 0; i < H + 1; i++) {
2119# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) {
2120# CHECK: prod[i, i_1] = i_1 * i;
2121# CHECK: }
2122# CHECK: }
2123# CHECK: for (int i_2 = 0; i_2 < H; i_2++) {
2124# CHECK: for (int i_3 = 0; i_3 < W; i_3++) {
2125# CHECK: cons[i_2, i_3] = int(0);
2126# CHECK: for (int i_4 = 0; i_4 < 2; i_4++) {
2127# CHECK: for (int i_5 = 0; i_5 < 2; i_5++) {
2128# CHECK: cons[i_2, i_3] = ReduceOp((cons[i_2, i_3]) + (prod[i_2 + i_4, i_3 + i_5]), reduce_args={i_4, i_5});
2129# CHECK: }
2130# CHECK: }
2131# CHECK: }
2132# CHECK: }
2133)IR");
2134
2135 {
2136 // First let's try to compute P at axis cy (the outer loop)
2137 LoopNest l(orig_loopnest);
2138 auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
2139 LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]);
2140 // FIXME: Calling simplify here breaks the IR:
2141 // MALFORMED INPUT: could not find base node in Load - temp[...]
2142 // l.simplify();
2143 l.eliminateDeadStores();
2144 l.prepareForCodegen();
2145 SimpleIREvaluator cg(l.root_stmt(), {c, W, H});
2146 checkIR(cg.stmt(), R"IR(
2147# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1]
2148# CHECK: for (int i = 0; i < H; i++) {
2149# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) {
2150# CHECK: for (int idx1 = 0; idx1 < W + 1; idx1++) {
2151# CHECK: temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + i) * (idx1 + 0);
2152# CHECK: }
2153# CHECK: }
2154# CHECK: for (int i_1 = 0; i_1 < W; i_1++) {
2155# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = int(0);
2156# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) {
2157# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) {
2158# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * (W + 1))) + (i_1 + i_3) * 1]);
2159# CHECK: }
2160# CHECK: }
2161# CHECK: }
2162# CHECK: }
2163# CHECK: Free(temp);
2164)IR");
2165
2166 // Now check that the loop still produces the correct result.
2167 std::vector<int> c_data(kW * kH, 0);
2168 cg.call({c_data, kW, kH});
2169 assertAllEqual(c_data, c_ref);
2170 }
2171 {
2172 // Now let's try to compute P at axis cx (the inner loop)
2173 LoopNest l(orig_loopnest);
2174 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
2175 LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]);
2176 l.simplify();
2177 l.eliminateDeadStores();
2178 l.prepareForCodegen();
2179 SimpleIREvaluator cg(l.root_stmt(), {c, W, H});
2180 checkIR(cg.stmt(), R"IR(
2181# CHECK: Allocate(temp); // dtype=int, dims=[2, 2]
2182# CHECK: for (int i = 0; i < H; i++) {
2183# CHECK: for (int i_1 = 0; i_1 < W; i_1++) {
2184# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) {
2185# CHECK: for (int idx1 = 0; idx1 < 2; idx1++) {
2186# CHECK: temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (i + idx0) * (i_1 + idx1);
2187# CHECK: }
2188# CHECK: }
2189# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = 0;
2190# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) {
2191# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) {
2192# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * 2)) + i_3 * 1]);
2193# CHECK: }
2194# CHECK: }
2195# CHECK: }
2196# CHECK: }
2197# CHECK: Free(temp);
2198)IR");
2199
2200 // Now check that the loop still produces the correct result.
2201 std::vector<int> c_data(kW * kH, 0);
2202 cg.call({c_data, kW, kH});
2203 assertAllEqual(c_data, c_ref);
2204 }
2205}
2206
2207TEST(LoopNest, DISABLED_Conv1d_NH) {
2208 // Lots of stuff is broken here. The computeAt swaps the axes for some odd
2209 // reason. Even without that, the index flattener fails due to "dimensions
2210 // mismatch in flatten index".
2211
2212 int N = 4;
2213 int H = 256;
2214 int R = 3;
2215 int Pad = 1;
2216 BufHandle IP("input", {H}, kFloat);
2217
2218 Tensor A = Compute("A", {N, H + 2 * Pad}, [&](Axis n, Axis h) {
2219 auto cond = CompareSelect::make(h, Pad, 1, 0, kLT);
2220 cond = CompareSelect::make(h, H + Pad, 1, cond, kGE);
2221 return ifThenElse(cond, 0.f, IP.load(n, h - Pad));
2222 });
2223 Tensor B = Reduce(
2224 "B",
2225 {N, H},
2226 Sum(),
2227 [&](Axis n, Axis h, Axis r) { return A.load(n, h + r); },
2228 {R});
2229 LoopNest l({B});
2230 checkIR(l.root_stmt(), R"IR(
2231# CHECK: for (int np = 0; np < 4; np++) {
2232# CHECK: for (int hp = 0; hp < 258; hp++) {
2233# CHECK: A[np, hp] = IfThenElse(hp>=257 ? 1 : (hp<1 ? 1 : 0), 0.f, input[np, hp - 1]);
2234# CHECK: }
2235# CHECK: }
2236# CHECK: for (int n = 0; n < 4; n++) {
2237# CHECK: for (int h = 0; h < 256; h++) {
2238# CHECK: B[n, h] = float(0);
2239# CHECK: for (int r = 0; r < 3; r++) {
2240# CHECK: B[n, h] = ReduceOp((B[n, h]) + (A(n, h + r)), reduce_args={r});
2241# CHECK: }
2242# CHECK: }
2243# CHECK: }
2244)IR");
2245 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0);
2246 LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
2247 // FIXME: The current IR is totally broken. The body of the inlined loop is:
2248
2249 // temp[idx0, idx1] = IfThenElse(idx0 + n>=257 ? 1 : (idx0 + n<1 ? 1 : 0),
2250 // 0.f, input[idx1 + 0, (idx0 + n) - 1]);
2251
2252 // Which seems to mix up the axes. The CHECK below is my best guess at what
2253  // the transformed IR "should" look like.
2254
2255 checkIR(l.root_stmt(), R"IR(
2256# CHECK: for (int n = 0; n < 4; n++) {
2257# CHECK: for (int idx0 = 0; idx0 < 1; idx0++) {
2258# CHECK: for (int idx1 = 0; idx1 < 258; idx1++) {
2259 temp[idx0, idx1] = IfThenElse(idx1>=257 ? 1 : (idx1<1 ? 1 : 0), 0.f, input[n, idx1 - 1]);
2260# CHECK: }
2261# CHECK: }
2262# CHECK: for (int h = 0; h < 256; h++) {
2263# CHECK: B[n, h] = float(0);
2264# CHECK: for (int r = 0; r < 3; r++) {
2265# CHECK: B[n, h] = ReduceOp((B[n, h]) + (temp[0, r + h]), reduce_args={r});
2266# CHECK: }
2267# CHECK: }
2268# CHECK: }
2269)IR");
2270
2271 l.simplify();
2272 l.prepareForCodegen();
2273 StmtPtr s = l.root_stmt();
2274
2275 SimpleIREvaluator cg(s, {IP, B});
2276 // auto At = at::ones({N, H}, at::kFloat);
2277 auto At = at::arange(N * H, at::kFloat).reshape({N, H});
2278 auto Rt = at::conv1d(
2279 At, at::ones({1, 1, 3}), at::Tensor(), /*stride=*/1, /*padding=*/3);
2280 auto Bt = at::empty_like(Rt);
2281 cg.call({At.data_ptr<float>(), Bt.data_ptr<float>()});
2282 ASSERT_TRUE(at::allclose(Rt, Bt));
2283}
2284
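// Records the order in which loop index variables are encountered, e.g.
// "i,j,k," for a nest with loops i, j, k from outermost to innermost.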
2285class LoopOrderHelper : public IRVisitor {
2286 std::stringstream ordering;
2287
2288 public:
2289 std::string getOrder(StmtPtr s) {
2290 ordering.str("");
2291 s->accept(this);
2292 return ordering.str();
2293 }
2294
2295 void visit(ForPtr v) final {
2296 ordering << v->var()->name_hint() << ",";
2297 IRVisitor::visit(v);
2298 }
2299};
2300
2301TEST(LoopNest, LoopNestReorderAxis1) {
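  // Reorders the two loops of a 2x3 compute, checks that the loop order
  // changes from "i,j" to "j,i" while the results stay the same, and that
  // reordering back reproduces the original statement.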
2302 Tensor tensor =
2303 Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) {
2304 return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
2305 });
2306 LoopNest l({tensor});
2307 StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
2308
2309 std::vector<int> stmt1_output(6, 0);
2310 SimpleIREvaluator cg(stmt1, {tensor});
2311 cg.call({stmt1_output});
2312
2313 auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
2314 LoopNest::reorderAxis(loops[0], loops[1]);
2315 StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
2316
2317 ASSERT_NE(stmt1, stmt2);
2318 LoopOrderHelper loopOrderHelper;
2319 std::string order1 = loopOrderHelper.getOrder(stmt1);
2320 std::string order2 = loopOrderHelper.getOrder(stmt2);
2321
2322  ASSERT_EQ(order1, "i,j,");
2323  ASSERT_EQ(order2, "j,i,");
2324
2325 std::vector<int> stmt2_output(6, 0);
2326 SimpleIREvaluator cg2(stmt2, {tensor});
2327  cg2.call({stmt2_output});
2328
2329 for (int i = 0; i < 6; ++i) {
2330 ASSERT_EQ(stmt1_output[i], stmt2_output[i]);
2331 }
2332
2333 // Reorder them back.
2334 loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
2335 LoopNest::reorderAxis(loops[0], loops[1]);
2336 StmtPtr stmt3 = l.root_stmt();
2337
2338 std::string order3 = loopOrderHelper.getOrder(stmt3);
2339 ASSERT_EQ(order3, order1);
2340
2341 std::ostringstream oss1, oss2;
2342 oss1 << *stmt1;
2343 oss2 << *stmt3;
2344
2345 // Should be identical to the unreordered statement.
2346 ASSERT_EQ(oss1.str(), oss2.str());
2347}
2348
2349TEST(LoopNest, LoopNestReorderPartialAxes) {
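  // Reorders adjacent pairs of axes in a 3-deep nest ("i,j,k" -> "j,i,k" ->
  // "j,k,i") and verifies the results are unchanged after each reorder.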
2350 Tensor tensor = Compute(
2351 "f",
2352 {2, 3, 4},
2353 [](const VarHandle& x, const VarHandle& y, const VarHandle& z) {
2354 return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y +
2355 cast<float>(z) * z;
2356 });
2357 LoopNest l({tensor});
2358
2359 LoopOrderHelper loopOrderHelper;
2360 StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
2361 ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,");
2362
2363 std::vector<int> stmt1_output(24, 0);
2364 SimpleIREvaluator cg(stmt1, {tensor});
2365 cg.call({stmt1_output});
2366
2367 auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
2368 LoopNest::reorderAxis(loops[0], loops[1]);
2369 ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,i,k,");
2370
2371 StmtPtr stmt2 = Stmt::clone(l.root_stmt());
2372
2373 std::vector<int> stmt2_output(24, 0);
2374 SimpleIREvaluator cg2(stmt2, {tensor});
2375 cg2.call({stmt2_output});
2376
2377 for (int i = 0; i < 24; ++i) {
2378 ASSERT_EQ(stmt1_output[i], stmt2_output[i]);
2379 }
2380
2381 loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
2382 LoopNest::reorderAxis(loops[1], loops[2]);
2383 ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,k,i,");
2384
2385 StmtPtr stmt3 = Stmt::clone(l.root_stmt());
2386
2387 std::vector<int> stmt3_output(24, 0);
2388 SimpleIREvaluator cg3(stmt3, {tensor});
2389 cg3.call({stmt3_output});
2390
2391 for (int i = 0; i < 24; ++i) {
2392 ASSERT_EQ(stmt1_output[i], stmt3_output[i]);
2393 }
2394}
2395
2396TEST(LoopNest, LoopNestReorderInternalAxis) {
2397 Tensor tensor = Compute(
2398 "f",
2399 {1, 2, 3, 4},
2400 [](const VarHandle& w,
2401 const VarHandle& x,
2402 const VarHandle& y,
2403 const VarHandle& z) {
2404 return ExprHandle(1.0f) + w + cast<float>(x) * x + cast<float>(y) * y +
2405 cast<float>(z) * z;
2406 });
2407 LoopNest l({tensor});
2408
2409 LoopOrderHelper loopOrderHelper;
2410 StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
2411 ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,l,");
2412
2413 std::vector<int> stmt1_output(24, 0);
2414 SimpleIREvaluator cg(stmt1, {tensor});
2415 cg.call({stmt1_output});
2416
2417 auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
2418 LoopNest::reorderAxis(loops[2], loops[1]);
2419 ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "i,k,j,l,");
2420
2421 StmtPtr stmt2 = l.root_stmt();
2422
2423 std::vector<int> stmt2_output(24, 0);
2424 SimpleIREvaluator cg2(stmt2, {tensor});
2425 cg2.call({stmt2_output});
2426
2427 for (int i = 0; i < 24; ++i) {
2428 ASSERT_EQ(stmt1_output[i], stmt2_output[i]);
2429 }
2430}
2431
2432TEST(LoopNest, LoopNestReorderEnclosingAxis) {
2433 Tensor tensor = Compute(
2434 "f",
2435 {1, 2, 3, 4},
2436 [](const VarHandle& w,
2437 const VarHandle& x,
2438 const VarHandle& y,
2439 const VarHandle& z) {
2440 return ExprHandle(1.0f) + w + cast<float>(x) * x + cast<float>(y) * y +
2441 cast<float>(z) * z;
2442 });
2443 LoopNest l({tensor});
2444
2445 LoopOrderHelper loopOrderHelper;
2446 StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
2447
2448 std::vector<int> stmt1_output(24, 0);
2449 SimpleIREvaluator cg(stmt1, {tensor});
2450 cg.call({stmt1_output});
2451
2452 auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
2453 LoopNest::reorderAxis(loops[0], loops[3]);
2454 ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "l,j,k,i,");
2455
2456 StmtPtr stmt2 = l.root_stmt();
2457
2458 std::vector<int> stmt2_output(24, 0);
2459 SimpleIREvaluator cg2(stmt2, {tensor});
2460 cg2.call({stmt2_output});
2461
2462 for (int i = 0; i < 24; ++i) {
2463 ASSERT_EQ(stmt1_output[i], stmt2_output[i]);
2464 }
2465}
2466
2467TEST(LoopNest, LoopNestReorderSameAxis) {
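  // Reordering an axis with itself should be a no-op: the statement must be
  // unchanged.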
2468 Tensor tensor =
2469 Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) {
2470 return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
2471 });
2472 LoopNest l({tensor});
2473 StmtPtr stmt1 = Stmt::clone(l.root_stmt());
2474
2475 auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
2476 LoopNest::reorderAxis(loops[1], loops[1]);
2477 StmtPtr stmt2 = Stmt::clone(l.root_stmt());
2478
2479 std::ostringstream oss, oss2;
2480 oss << *stmt1;
2481 oss2 << *stmt2;
2482 ASSERT_EQ(oss.str(), oss2.str());
2483}
2484
2485TEST(LoopNest, LoopNestReorderExtraStatements) {
2486 /* We're going for a structure like this:
2487 * for i in ...
2488 * Stmt 1
2489 * for j in ...
2490 * Stmt 2
2491 * for k in ...
2492 * Stmt 3
2493 * Stmt 4
2494 */
2495
2496 Tensor tensor = Compute(
2497 "f",
2498 {2, 3, 4},
2499 [](const VarHandle& x, const VarHandle& y, const VarHandle& z) {
2500 return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y +
2501 cast<float>(z) * z;
2502 });
2503 LoopNest l({tensor});
2504
2505 BufHandle extra("res", {6, 3}, kFloat);
2506
2507 auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
2508
2509 VarHandle i = VarHandle(loops[0]->var());
2510
2511 StmtPtr store_1 = Store::make(extra, {i, 0}, 1.f);
2512 StmtPtr store_2 = Store::make(extra, {i, 1}, 2.f);
2513 // stmt 3 is the Function body.
2514 StmtPtr store_3 = Store::make(extra, {i, 2}, 4.f);
2515
2516 loops[0]->body()->prepend_stmt(store_1);
2517 loops[1]->body()->prepend_stmt(store_2);
2518 loops[1]->body()->append_stmt(store_3);
2519 StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
2520
2521 std::vector<int> extra1(6, 0);
2522 std::vector<int> res1(24, 0);
2523 SimpleIREvaluator cg(stmt1, {tensor, extra});
2524 cg.call({res1, extra1});
2525
2526  /* Then we reorder loops y and z; we want it to look like:
2527 *
2528 * for i in ...
2529 * Stmt 1
2530 * for j in ...
2531 * Stmt 2
2532 * for j_1 in ...
2533 * for k in ...
2534 * Stmt 3
2535 * for j_2 in ...
2536 * Stmt 4
2537 *
2538   * We need extra loops because we don't have dependency info about
2539   * stmts 3 and 4.
2540 *
2541 */
2542
2543 LoopNest::reorderAxis(loops[1], loops[2]);
2544 StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
2545
2546 // Check the IR we produced
2547 checkIR(stmt2, R"IR(
2548# CHECK: for
2549# CHECK: res[i, 0] = 1
2550# CHECK: for
2551# CHECK: res[i, 1] = 2
2552# CHECK: for
2553# CHECK: for
2554# CHECK: f[
2555# CHECK: for
2556# CHECK: res[i, 2] = 4
2557)IR");
2558
2559 std::vector<int> extra2(6, 0);
2560 std::vector<int> res2(24, 0);
2561 SimpleIREvaluator cg2(stmt2, {tensor, extra});
2562 cg2.call({res2, extra2});
2563
2564 for (int i = 0; i < 24; ++i) {
2565 ASSERT_EQ(res1[i], res2[i]);
2566 }
2567 for (int i = 0; i < 6; ++i) {
2568 ASSERT_EQ(extra1[i], extra2[i]);
2569 }
2570
2571 /* Now reorder x and the y above stmt 3:
2572 *
2573 *
2574 * for x in ...
2575 * Stmt 1
2576 * for y in ...
2577 * Stmt 2
2578 *
2579 * for y in ...
2580 * for z in ...
2581 * for x in ...
2582 * Stmt 3
2583 *
2584 * for x in ...
2585 * for y in ...
2586 * Stmt 4
2587 *
2588 *
2589 */
2590 loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0);
2591 LoopNest::reorderAxis(loops[0], loops[2]);
2592 StmtPtr stmt3 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt()));
2593
2594 // Check the IR we produced
2595 checkIR(stmt3, R"IR(
2596# CHECK: for
2597# CHECK: res[i, 0] = 1
2598# CHECK: for
2599# CHECK: res[i, 1] = 2
2600# CHECK: for
2601# CHECK: for
2602# CHECK: for
2603# CHECK: f[
2604# CHECK: for
2605# CHECK: for
2606# CHECK: res[i_2, 2] = 4
2607)IR");
2608
2609 std::vector<int> extra3(6, 0);
2610 std::vector<int> res3(24, 0);
2611 SimpleIREvaluator cg3(stmt3, {tensor, extra});
2612 cg3.call({res3, extra3});
2613
2614 for (int i = 0; i < 24; ++i) {
2615 ASSERT_EQ(res1[i], res3[i]);
2616 }
2617 for (int i = 0; i < 6; ++i) {
2618 ASSERT_EQ(extra1[i], extra3[i]);
2619 }
2620}
2621
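// Helper for the reorder tests below: builds a 5-deep loop nest, optionally
// prepends and/or appends a counting store at every loop level, reorders the
// loops at index1 and index2, and checks that both the tensor contents and
// the per-level execution counts are unchanged.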
2622void LoopNestReorderTestHelper(
2623 bool prepend,
2624 bool append,
2625 int index1,
2626 int index2) {
2627 Tensor c = Compute(
2628 "5d", {2, 3, 2, 3, 2}, [](const std::vector<VarHandle>&) { return -1; });
2629 LoopNest l({c});
2630
2631 BufHandle extra("extra", {5}, kInt);
2632
2633 auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
2634 int j = 0;
2635 for (auto l : loops) {
2636 // Add an increment at each layer of the loop which counts the number of
2637 // times the loop executes.
2638 LoadPtr load =
2639 alloc<Load>(extra.node(), std::vector<ExprPtr>({alloc<IntImm>(j)}));
2640 AddPtr add = alloc<Add>(load, alloc<IntImm>(1));
2641 StmtPtr store = alloc<Store>(
2642 extra.node(), std::vector<ExprPtr>({alloc<IntImm>(j)}), add);
2643 if (prepend) {
2644 l->body()->prepend_stmt(store);
2645 }
2646 if (append) {
2647 l->body()->append_stmt(Stmt::clone(store));
2648 }
2649
2650 j++;
2651 }
2652
2653 StmtPtr stmt1 = Stmt::clone(l.root_stmt());
2654
2655 std::vector<int> extra1(5, 0);
2656 std::vector<int> res1(2 * 3 * 2 * 3 * 2, 0);
2657 SimpleIREvaluator cg(stmt1, {c, extra});
2658 cg.call({res1, extra1});
2659
2660 std::vector<int> loopExtents = {2, 3, 2, 3, 2};
2661
2662 int expected_loops = 0;
2663 if (prepend) {
2664 expected_loops++;
2665 }
2666 if (append) {
2667 expected_loops++;
2668 }
2669 for (int i = 0; i < 5; ++i) {
2670 expected_loops *= loopExtents[i];
2671 ASSERT_EQ(extra1[i], expected_loops);
2672 }
2673
2674 loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
2675 LoopNest::reorderAxis(loops[index1], loops[index2]);
2676 StmtPtr stmt2 = Stmt::clone(l.root_stmt());
2677
2678 std::ostringstream oss, oss2;
2679 oss << *stmt1;
2680 oss2 << *stmt2;
2681 ASSERT_NE(oss.str(), oss2.str());
2682
2683 std::vector<int> extra2(5, 0);
2684 std::vector<int> res2(2 * 3 * 2 * 3 * 2, 0);
2685 SimpleIREvaluator cg2(stmt2, {c, extra});
2686 cg2.call({res2, extra2});
2687
2688 expected_loops = 0;
2689 if (prepend) {
2690 expected_loops++;
2691 }
2692 if (append) {
2693 expected_loops++;
2694 }
2695
2696 for (int i = 0; i < 5; ++i) {
2697 expected_loops *= loopExtents[i];
2698 ASSERT_EQ(extra2[i], expected_loops);
2699 }
2700
2701 for (int i = 0; i < 2 * 3 * 2 * 3 * 2; ++i) {
2702 ASSERT_EQ(res2[i], res1[i]);
2703 }
2704}
2705
2706TEST(LoopNest, LoopNestReorderLongStringOfPreOrphans) {
2707 for (int i = 0; i < 5; ++i) {
2708 for (int j = 0; j < 5; ++j) {
2709 // skip noops, since we check the loop isn't the same after reordering.
2710 if (i != j) {
2711 LoopNestReorderTestHelper(true, false, i, j);
2712 }
2713 }
2714 }
2715}
2716
2717TEST(LoopNest, LoopNestReorderLongStringOfPostOrphans) {
2718 for (int i = 0; i < 5; ++i) {
2719 for (int j = 0; j < 5; ++j) {
2720 // skip noops, since we check the loop isn't the same after reordering.
2721 if (i != j) {
2722 LoopNestReorderTestHelper(false, true, i, j);
2723 }
2724 }
2725 }
2726}
2727
2728TEST(LoopNest, LoopNestReorderLongStringFull) {
2729 for (int i = 0; i < 5; ++i) {
2730 for (int j = 0; j < 5; ++j) {
2731 // skip noops, since we check the loop isn't the same after reordering.
2732 if (i != j) {
2733 LoopNestReorderTestHelper(true, true, i, j);
2734 }
2735 }
2736 }
2737}
2738
2739TEST(LoopNest, LoopNestReorderInternalLoopNest) {
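  // Three chained 3-deep computes x, y, z; only the middle one (y) has its
  // outermost and innermost axes swapped. The checkIR below verifies that the
  // middle nest now iterates 6, 5, 4 while x and z keep 4, 5, 6, and the
  // numeric result still matches the reference.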
2740 const int M = 4;
2741 const int N = 5;
2742 const int K = 6;
2743 BufHandle a_buf("a", {M, N}, kFloat);
2744 BufHandle b_buf("b", {N, K}, kFloat);
2745 BufHandle c_buf("c", {M, N}, kFloat);
2746 BufHandle d_buf("d", {M, K}, kFloat);
2747
2748 Tensor x = Compute(
2749 "x",
2750 {M, N, K},
2751 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2752 return a_buf.load(m, n) * b_buf.load(n, k);
2753 });
2754 Tensor y = Compute(
2755 "y",
2756 {M, N, K},
2757 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2758 return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k);
2759 });
2760 Tensor z = Compute(
2761 "z",
2762 {M, N, K},
2763 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2764 return x.load(m, n, k) + y.load(m, n, k);
2765 });
2766
2767 LoopNest l({z}, {x, y, z});
2768 ForPtr a = l.getAllLoopNestsWritingToBuf(y.buf())[0][2];
2769 ForPtr b = l.getAllLoopNestsWritingToBuf(y.buf())[0][0];
2770 LoopNest::reorderAxis(a, b);
2771
2772 l.prepareForCodegen();
2773 StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
2774
2775 // Check the IR we produced has the 3 nests in the right order, but k and m
2776 // swapped in the middle.
2777 checkIR(stmt, R"IR(
2778# CHECK: < 4
2779# CHECK: < 5
2780# CHECK: < 6
2781# CHECK: < 6
2782# CHECK: < 5
2783# CHECK: < 4
2784# CHECK: < 4
2785# CHECK: < 5
2786# CHECK: < 6)IR");
2787
2788 {
2789 PaddedBuffer<float> a_v(M, N);
2790 PaddedBuffer<float> b_v(N, K);
2791 PaddedBuffer<float> c_v(M, N);
2792 PaddedBuffer<float> d_v(M, K);
2793
2794 for (int i = 0; i < M; i++) {
2795 for (int j = 0; j < N; j++) {
2796 a_v(i, j) = i * i;
2797 }
2798 }
2799 for (int i = 0; i < N; i++) {
2800 for (int j = 0; j < K; j++) {
2801 b_v(i, j) = j * j;
2802 }
2803 }
2804 for (int i = 0; i < M; i++) {
2805 for (int j = 0; j < N; j++) {
2806 c_v(i, j) = i + j;
2807 }
2808 }
2809 for (int i = 0; i < M; i++) {
2810 for (int j = 0; j < K; j++) {
2811 d_v(i, j) = i * j;
2812 }
2813 }
2814
2815 PaddedBuffer<float> z_v(M, N, K);
2816 PaddedBuffer<float> z_ref(M, N, K);
2817 for (int m = 0; m < M; m++) {
2818 for (int n = 0; n < N; n++) {
2819 for (int k = 0; k < K; k++) {
2820 z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k);
2821 }
2822 }
2823 }
2824
2825 SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z});
2826 eval(a_v, b_v, c_v, d_v, z_v);
2827 ExpectAllNear(z_v, z_ref, 1e-5);
2828 }
2829}
2830
2831TEST(LoopNest, OuterLoopVectorization) {
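  // Vectorize the outermost loop of an 8x8 compute and verify that only a
  // single loop level remains afterwards.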
2832 Tensor tensor =
2833 Compute("f", {8, 8}, [](const VarHandle& x, const VarHandle& y) {
2834 return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
2835 });
2836 LoopNest l({tensor});
2837
2838 ASSERT_TRUE(
2839 LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor.buf())[0][0]));
2840
2841 StmtPtr root_stmt = l.root_stmt();
2842 BlockPtr outer_block = to<Block>(root_stmt);
2843 ASSERT_NE(outer_block, nullptr);
2844 while (BlockPtr inner_block = to<Block>(outer_block->front())) {
2845 outer_block = inner_block;
2846 }
2847
2848 // Verify that we have only a single loop level remaining after
2849 // vectorization.
2850 ASSERT_EQ(outer_block->nstmts(), 1);
2851 ForPtr for_loop = to<For>(outer_block->front());
2852 ASSERT_NE(for_loop, nullptr);
2853 BlockPtr for_body = for_loop->body();
2854 ASSERT_EQ(for_body->nstmts(), 1);
2855 ASSERT_EQ(to<For>(for_body->front()), nullptr);
2856}
2857
2858TEST(LoopNest, VectorizeLoopNotNormalized) {
2859 // Input IR:
2860 // for (int i = 0; i < 10; i++) {
2861 // for (int j = 1; j < 5; j++) {
2862 // A[i,j] = i * j;
2863 // }
2864 // }
2865 BufHandle a_buf("A", {10, 5}, kInt);
2866 VarHandle i("i", kInt);
2867 VarHandle j("j", kInt);
2868 auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
2869 auto inner_for = For::make(j, 1, 5, for_body);
2870 auto outer_for = For::make(i, 0, 10, inner_for);
2871 auto block = Block::make({outer_for});
2872 LoopNest l(block, {a_buf.node()});
2873
2874 ASSERT_TRUE(LoopNest::vectorize(inner_for));
2875 ASSERT_EQ(outer_for->body()->nstmts(), 1);
2876 ASSERT_EQ(to<For>(outer_for->body()->front()), nullptr);
2877}
2878
2879namespace {
2880
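// Builds A[x] = x * 2 over a constant extent, fully unrolls the loop, and
// returns the printed IR of the unrolled statement.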
2881std::string constantUpperBoundLoopIR(int upper_bound_val) {
2882 ExprHandle upper_bound(upper_bound_val);
2883 Tensor A =
2884 Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; });
2885 LoopNest l({A});
2886 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
2887 StmtPtr unrolled = nullptr;
2888 LoopNest::fullUnroll(loops[0], &unrolled);
2889 std::ostringstream oss;
2890 oss << *unrolled;
2891 return oss.str();
2892}
2893
2894} // namespace
2895
2896TEST(LoopNest, Unroll) {
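  // Fully unrolling a loop with a constant trip count of 3 should produce
  // three explicit stores: A[0], A[1], A[2].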
2897 const std::string actual = constantUpperBoundLoopIR(3);
2898 const std::string& verification_pattern =
2899 R"IR(
2900# CHECK: A[0] = 0;
2901# CHECK: A[1] = 2;
2902# CHECK: A[2] = 4)IR";
2903
2904 torch::jit::testing::FileCheck().run(verification_pattern, actual);
2905}
2906
2907TEST(LoopNest, UnrollOuter) {
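  // Fully unroll only the outer loop; each unrolled copy keeps its own inner
  // loop.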
2908 ExprHandle outer_bound(3);
2909 ExprHandle inner_bound(4);
2910 Tensor A = Compute(
2911 "A",
2912 {outer_bound, inner_bound},
2913 [&](const VarHandle& x, const VarHandle& y) { return x + y; });
2914 LoopNest l({A});
2915 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
2916 StmtPtr unrolled = nullptr;
2917 LoopNest::fullUnroll(loops[0], &unrolled);
2918 checkIR(unrolled, R"IR(
2919# CHECK: for (int i = 0; i < 4; i++) {
2920# CHECK: A[0, i] = i;
2921# CHECK: }
2922# CHECK: for (int i = 0; i < 4; i++) {
2923# CHECK: A[1, i] = i + 1;
2924# CHECK: }
2925# CHECK: for (int i = 0; i < 4; i++) {
2926# CHECK: A[2, i] = i + 2;
2927# CHECK: })IR");
2928}
2929
2930TEST(LoopNest, UnrollInner) {
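  // Fully unroll the inner loop while keeping the outer loop intact.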
2931 ExprHandle outer_bound(3);
2932 ExprHandle inner_bound(4);
2933 Tensor A = Compute(
2934 "A",
2935 {outer_bound, inner_bound},
2936 [&](const VarHandle& x, const VarHandle& y) { return x + y; });
2937 LoopNest l({A});
2938 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
2939 StmtPtr unrolled = nullptr;
2940 LoopNest::fullUnroll(
2941 static_to<For>(loops[0]->body()->stmts().front()), &unrolled);
2942 checkIR(loops[0], R"IR(
2943# CHECK: for (int i = 0; i < 3; i++) {
2944# CHECK: A[i, 0] = i;
2945# CHECK: A[i, 1] = i + 1;
2946# CHECK: A[i, 2] = i + 2;
2947# CHECK: A[i, 3] = i + 3;
2948# CHECK: })IR");
2949}
2950
2951TEST(LoopNest, UnrollMultipleStatements) {
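  // Input IR:
  //   for (int x = 0; x < 3; x++) {
  //     A[x] = x * 2;
  //     B[x] = A[x];
  //   }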
2952 const int kTotalSize = 3;
2953 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
2954 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
2955
2956 VarHandle x("x", kInt);
2957 auto f = For::make(
2958 x,
2959 0,
2960 kTotalSize,
2961 Block::make(
2962 {Store::make(a_buf, {x}, x * 2),
2963 Store::make(b_buf, {x}, Load::make(a_buf, {x}))}));
2964 auto parent_block = Block::make({f});
2965 StmtPtr unrolled = nullptr;
2966 LoopNest::fullUnroll(f, &unrolled);
2967 checkIR(unrolled, R"IR(
2968# CHECK: A[0] = 0;
2969# CHECK: B[0] = A[0];
2970# CHECK: A[1] = 2;
2971# CHECK: B[1] = A[1];
2972# CHECK: A[2] = 4
2973# CHECK: B[2] = A[2];)IR");
2974}
2975
2976TEST(LoopNest, UnrollNonLiteralConstantBounds) {
2977 // Input IR:
2978 // for (int i = 2 - 1; i < 12 / 3; i++) {
2979 // for (int j = 0; j < 4; j++) {
2980 // A[i,j] = i * j;
2981 // }
2982 // }
2983 BufHandle a_buf("A", {3, 4}, kInt);
2984 VarHandle i("i", kInt);
2985 VarHandle j("j", kInt);
2986 auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
2987 auto inner_for = For::make(j, 0, 4, for_body);
2988 auto outer_for = For::make(
2989 i,
2990 IntImm::make(2) - IntImm::make(1),
2991 IntImm::make(12) / IntImm::make(3),
2992 inner_for);
2993 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
2994 auto b = Block::make({outer_for});
2995
2996 std::vector<ForPtr> loops = {outer_for, inner_for};
2997 StmtPtr unrolled = nullptr;
2998 LoopNest::fullUnroll(loops[0], &unrolled);
2999 checkIR(unrolled, R"IR(
3000# CHECK: for (int j = 0; j < 4; j++) {
3001# CHECK: A[1, j] = j;
3002# CHECK: }
3003# CHECK: for (int j = 0; j < 4; j++) {
3004# CHECK: A[2, j] = 2 * j;
3005# CHECK: }
3006# CHECK: for (int j = 0; j < 4; j++) {
3007# CHECK: A[3, j] = 3 * j;
3008# CHECK: })IR");
3009}
3010
3011TEST(LoopNest, UnrollNonConstantBounds) {
3012 // Input IR:
3013 // for (int i = 0; i < M; i++) {
3014 // for (int j = 0; j < N; j++) {
3015 // A[i, j] = i * j;
3016 // }
3017 // }
3018 VarHandle M("M", kInt);
3019 VarHandle N("N", kInt);
3020 BufHandle a_buf("A", {M, N}, kInt);
3021 VarHandle i("i", kInt);
3022 VarHandle j("j", kInt);
3023 auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
3024 auto inner_for = For::make(j, 0, N, for_body);
3025 auto outer_for = For::make(i, 0, M, inner_for);
3026 auto block = Block::make({outer_for});
3027 LoopNest l(block, {a_buf.node()});
3028
3029 LoopNest::unroll(inner_for, 8);
3030 l.simplify();
3031 checkIR(l.root_stmt(), R"IR(
3032 # CHECK: for (int i = 0; i < M; i++) {
3033 # CHECK: for (int j_outer = 0; j_outer < N / 8; j_outer++) {
3034 # CHECK: A[i, 8 * j_outer] =
3035 # CHECK: A[i, 8 * j_outer + 1] =
3036 # CHECK: A[i, 2 * (4 * j_outer + 1)] =
3037 # CHECK: A[i, 8 * j_outer + 3] =
3038 # CHECK: A[i, 4 * (2 * j_outer + 1)] =
3039 # CHECK: A[i, 8 * j_outer + 5] =
3040 # CHECK: A[i, 8 * j_outer + 6] =
3041 # CHECK: A[i, 8 * j_outer + 7] =
3042 # CHECK: }
3043 # CHECK: for (int j_tail = 0; j_tail < N % 8; j_tail++) {
3044 # CHECK: A[i, 8 * (N / 8) + j_tail] =
3045 # CHECK: }
3046 # CHECK: }
3047 )IR");
3048}
3049
3050TEST(LoopNest, UnrollByFactorsLessThan2) {
3051 // Input IR:
3052 // for (int i = 0; i < M; i++) {
3053 // for (int j = 0; j < N; j++) {
3054 // A[i, j] = i * j;
3055 // }
3056 // }
3057 VarHandle M("M", kInt);
3058 VarHandle N("N", kInt);
3059 BufHandle a_buf("A", {M, N}, kInt);
3060 VarHandle i("i", kInt);
3061 VarHandle j("j", kInt);
3062 auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
3063 auto inner_for = For::make(j, 0, N, for_body);
3064 auto outer_for = For::make(i, 0, M, inner_for);
3065 auto block = Block::make({outer_for});
3066 LoopNest l(block, {a_buf.node()});
3067
3068 // Unrolling by factor = 1 should do nothing.
3069 LoopNest::unroll(inner_for, 1);
3070 checkIR(l.root_stmt(), R"IR(
3071 # CHECK: for (int i = 0; i < M; i++) {
3072 # CHECK: for (int j = 0; j < N; j++) {
3073 # CHECK: A[i, j] =
3074 # CHECK: }
3075 # CHECK: }
3076 )IR");
3077
3078 // Unrolling by factor = 0 should do nothing.
3079 LoopNest::unroll(inner_for, 0);
3080 checkIR(l.root_stmt(), R"IR(
3081 # CHECK: for (int i = 0; i < M; i++) {
3082 # CHECK: for (int j = 0; j < N; j++) {
3083 # CHECK: A[i, j] =
3084 # CHECK: }
3085 # CHECK: }
3086 )IR");
3087
3088 // Unrolling by negative factor should do nothing.
3089 LoopNest::unroll(inner_for, -2);
3090 checkIR(l.root_stmt(), R"IR(
3091 # CHECK: for (int i = 0; i < M; i++) {
3092 # CHECK: for (int j = 0; j < N; j++) {
3093 # CHECK: A[i, j] =
3094 # CHECK: }
3095 # CHECK: }
3096 )IR");
3097}
3098
3099TEST(LoopNest, UnrollByFactorEqualToIters) {
3100 // Input IR:
3101 // for (int i = 0; i < 5; i++) {
3102 // A[i] = i * i;
3103 // }
3104 BufHandle a_buf("A", {5}, kInt);
3105 VarHandle i("i", kInt);
3106 auto for_body = Block::make({Store::make(a_buf, {i}, i * i)});
3107 auto for_loop = For::make(i, 0, 5, for_body);
3108 auto block = Block::make({for_loop});
3109 LoopNest l(block, {a_buf.node()});
3110
3111 LoopNest::unroll(for_loop, 5);
3112 checkIR(l.root_stmt(), R"IR(
3113 # CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++)
3114 # CHECK: A[5 * i_outer]
3115 # CHECK: A[5 * i_outer + 1]
3116 # CHECK: A[5 * i_outer + 2]
3117 # CHECK: A[5 * i_outer + 3]
3118 # CHECK: A[5 * i_outer + 4]
3119 )IR");
3120}
3121
3122TEST(LoopNest, UnrollEmpty) {
3123 const std::string actual = constantUpperBoundLoopIR(0);
3124 const std::string& verification_pattern = R"IR(
3125# CHECK-NOT: A[
3126 )IR";
3127
3128 torch::jit::testing::FileCheck().run(verification_pattern, actual);
3129}
3130
3131TEST(LoopNest, NoUnroll) {
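  // Fully unrolling a loop with a symbolic (non-constant) extent is not
  // possible; fullUnroll should throw a "non-constant loop" error.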
3132 VarHandle upper_bound("N", kInt);
3133 Tensor A =
3134 Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; });
3135 LoopNest l({A});
3136 std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0];
3137 StmtPtr unrolled = nullptr;
3138 ASSERT_THROWS_WITH(
3139 LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop");
3140}
3141
3142TEST(LoopNest, UnrollWithLet) {
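  // Input IR:
  //   for (int x = 0; x < 3; x++) {
  //     int e = 7;
  //     A[x] = e;
  //     B[x] = e + 1;
  //   }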
3143 const int kTotalSize = 3;
3144 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
3145 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
3146
3147 VarHandle e("e", kInt);
3148 VarHandle x("x", kInt);
3149 auto f = For::make(
3150 x,
3151 0,
3152 kTotalSize,
3153 Block::make(
3154 {Let::make(e, 7),
3155 Store::make(a_buf, {x}, e),
3156 Store::make(b_buf, {x}, e + 1)}));
3157 auto parent_block = Block::make({f});
3158 StmtPtr unrolled = nullptr;
3159 LoopNest::fullUnroll(f, &unrolled);
3160 std::ostringstream oss;
3161 oss << *unrolled;
3162 const std::string& verification_pattern =
3163 R"IR(
3164# CHECK: int e = 7;
3165# CHECK: A[0] = e;
3166# CHECK: B[0] = e + 1;
3167# CHECK: A[1] = e;
3168# CHECK: B[1] = e + 1;
3169# CHECK: A[2] = e;
3170# CHECK: B[2] = e + 1;)IR";
3171
3172 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3173
3174 std::vector<int> a_v(kTotalSize, 0);
3175 std::vector<int> b_v(kTotalSize, 0);
3176 SimpleIREvaluator eval(unrolled, {a_buf, b_buf});
3177 eval(a_v, b_v);
3178 for (int i = 0; i < kTotalSize; ++i) {
3179 ASSERT_EQ(a_v[i], 7);
3180 ASSERT_EQ(b_v[i], 8);
3181 }
3182}
3183
3184TEST(LoopNest, IsNormalized) {
3185 // Input IR:
3186 // for (int i = 50; i < 100; i++) {
3187 // A[i] = B[i];
3188 // }
3189 BufHandle a_buf("A", {ExprHandle(100)}, kInt);
3190 BufHandle b_buf("B", {ExprHandle(100)}, kInt);
3191 VarHandle i("i", kInt);
3192 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
3193 auto for_stmt =
3194 For::make(i, 50, 100, Store::make(a_buf, {i}, Load::make(b_buf, {i})));
3195 Block::make({for_stmt});
3196 ASSERT_FALSE(LoopNest::isNormalized(for_stmt));
3197
3198 for_stmt->set_start(alloc<IntImm>(0));
3199 ASSERT_TRUE(LoopNest::isNormalized(for_stmt));
3200
3201 VarHandle N("N", kInt);
3202 for_stmt->set_start(N.node());
3203 ASSERT_FALSE(LoopNest::isNormalized(for_stmt));
3204}
3205
3206TEST(LoopNest, NormalizeStartPositive) {
3207 // Input IR:
3208 // for (int x = 50; x < 100; x++) {
3209 // A[x] = B[x];
3210 // B[x] = x * 2;
3211 // }
3212 const int kTotalSize = 50;
3213 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
3214 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
3215 VarHandle x("x", kInt);
3216 auto for_body = Block::make(
3217 {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})),
3218 Store::make(b_buf, {x}, x * 2)});
3219 auto for_stmt = For::make(x, 50, 100, for_body);
3220 Block::make({for_stmt});
3221
3222 LoopNest::normalize(for_stmt);
3223
3224 auto result = IRSimplifier::simplify(for_stmt);
3225 std::ostringstream oss;
3226 oss << *result;
3227 const std::string& expected_ir =
3228 R"IR(
3229 # CHECK: for (int x = 0; x < 50; x++) {
3230 # CHECK: A[x + 50] = B[x + 50];
3231 # CHECK: B[x + 50] = 2 * (x + 50);
3232 )IR";
3233 torch::jit::testing::FileCheck().run(expected_ir, oss.str());
3234}
3235
3236TEST(LoopNest, NormalizeStartNegative) {
3237 // Input IR:
3238 // for (int x = -50; x < 100; x++) {
3239 // A[x + 50] = B[x + 50];
3240 // B[x + 50] = x * 2;
3241 // }
3242 const int kTotalSize = 150;
3243 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
3244 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
3245 VarHandle x("x", kInt);
3246 auto for_body = Block::make(
3247 {Store::make(a_buf, {x + 50}, Load::make(kInt, b_buf, {x + 50})),
3248 Store::make(b_buf, {x + 50}, x * 2)});
3249 auto for_stmt = For::make(x, -50, 100, for_body);
3250 Block::make({for_stmt});
3251
3252 LoopNest::normalize(for_stmt);
3253
3254 auto result = IRSimplifier::simplify(for_stmt);
3255 std::ostringstream oss;
3256 oss << *result;
3257 const std::string& expected_ir =
3258 R"IR(
3259 # CHECK: for (int x = 0; x < 150; x++) {
3260 # CHECK: A[x] = B[x];
3261 # CHECK: B[x] = 2 * (x - 50);
3262 )IR";
3263 torch::jit::testing::FileCheck().run(expected_ir, oss.str());
3264}
3265
3266TEST(LoopNest, NormalizeStartZero) {
3267 // Input IR:
3268 // for (int x = 0; x < 100; x++) {
3269 // A[x] = B[x];
3270 // B[x] = x * 2;
3271 // }
3272 // Should not be modified.
3273
3274 const int kTotalSize = 100;
3275 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
3276 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
3277 VarHandle x("x", kInt);
3278 auto for_body = Block::make(
3279 {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})),
3280 Store::make(b_buf, {x}, x * 2)});
3281 auto for_stmt = For::make(x, 0, 100, for_body);
3282 Block::make({for_stmt});
3283
3284 LoopNest::normalize(for_stmt);
3285
3286 auto result = IRSimplifier::simplify(for_stmt);
3287 std::ostringstream oss;
3288 oss << *result;
3289 const std::string& expected_ir =
3290 R"IR(
3291 # CHECK: for (int x = 0; x < 100; x++) {
3292 # CHECK: A[x] = B[x];
3293 # CHECK: B[x] = 2 * x;
3294 )IR";
3295 torch::jit::testing::FileCheck().run(expected_ir, oss.str());
3296}
3297
3298TEST(LoopNest, NormalizeStartVariable) {
3299 // Input IR:
3300 // for (int x = y; x < 100; x++) {
3301 // A[x] = B[x];
3302 // B[x] = x * 2;
3303 // }
3304
3305 const int kTotalSize = 100;
3306 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
3307 BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);
3308 VarHandle x("x", kInt);
3309 VarHandle y("y", kInt);
3310 auto for_body = Block::make(
3311 {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})),
3312 Store::make(b_buf, {x}, x * 2)});
3313 auto for_stmt = For::make(x, y, 100, for_body);
3314 auto parent_block = Block::make({for_stmt});
3315
3316 LoopNest::normalize(for_stmt);
3317
3318 auto result = IRSimplifier::simplify(for_stmt);
3319 std::ostringstream oss;
3320 oss << *result;
3321 const std::string& expected_ir =
3322 R"IR(
3323 # CHECK: for (int x = 0; x < 100 - y; x++) {
3324 # CHECK: A[x + y] = B[x + y];
3325 # CHECK: B[x + y] = 2 * (x + y);
3326 )IR";
3327 torch::jit::testing::FileCheck().run(expected_ir, oss.str());
3328}
3329
3330TEST(LoopNest, NormalizeOnNestedOuterLoop) {
3331 // Input IR:
3332 // for (int x = 50; x < 100; x++) {
3333 // for (int y = 10; y < 100; y++) {
3334 // A[x] = A[x] + B[y] + y * 2;
3335 // }
3336 // }
3337
3338 BufHandle a_buf("A", {ExprHandle(50)}, kInt);
3339 BufHandle b_buf("B", {ExprHandle(100)}, kInt);
3340 VarHandle x("x", kInt);
3341 VarHandle y("y", kInt);
3342 auto inner_for_body = Store::make(
3343 a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2);
3344 auto inner_for = For::make(y, 10, 100, inner_for_body);
3345 auto for_stmt = For::make(x, 50, 100, inner_for);
3346 Block::make({for_stmt});
3347
3348 LoopNest::normalize(for_stmt);
3349
3350 auto result = IRSimplifier::simplify(for_stmt);
3351 std::ostringstream oss;
3352 oss << *result;
3353 const std::string& expected_ir =
3354 R"IR(
3355 # CHECK: for (int x = 0; x < 50; x++) {
3356 # CHECK: for (int y = 10; y < 100; y++) {
3357 # CHECK: A[x + 50] = ((A[x + 50]) + (B[y])) + 2 * y;
3358 )IR";
3359 torch::jit::testing::FileCheck().run(expected_ir, oss.str());
3360}
3361
3362TEST(LoopNest, NormalizeOnNestedInnerLoop) {
3363 // Input IR:
3364 // for (int x = 50; x < 100; x++) {
3365 // for (int y = 10; y < 100; y++) {
3366 // A[x] = A[x] + B[y] + y * 2;
3367 // }
3368 // }
3369
3370 BufHandle a_buf("A", {ExprHandle(50)}, kInt);
3371 BufHandle b_buf("B", {ExprHandle(100)}, kInt);
3372 VarHandle x("x", kInt);
3373 VarHandle y("y", kInt);
3374 auto inner_for_body = Store::make(
3375 a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2);
3376 auto inner_for = For::make(y, 10, 100, inner_for_body);
3377 auto for_stmt = For::make(x, 50, 100, inner_for);
3378 Block::make({for_stmt});
3379
3380 LoopNest::normalize(inner_for);
3381
3382 auto result = IRSimplifier::simplify(for_stmt);
3383 std::ostringstream oss;
3384 oss << *result;
3385 const std::string& expected_ir =
3386 R"IR(
3387 # CHECK: for (int x = 50; x < 100; x++) {
3388 # CHECK: for (int y = 0; y < 90; y++) {
3389 # CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20;
3390 )IR";
3391 torch::jit::testing::FileCheck().run(expected_ir, oss.str());
3392}
3393
3394TEST(LoopNest, NormalizeAndSplitWithTail) {
3395 // Create a dummy tensor to construct LoopNest.
3396 ExprHandle n(100);
3397 BufHandle a("a", {n}, kFloat);
3398 Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
3399 LoopNest l({b});
3400
3401 // Input IR:
3402 // for (int x = 5; x < 10; x++) {
3403 // A[x] = x * 2;
3404 // }
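  // After normalization the loop runs from 0 to 5; splitting by 10 leaves an
  // empty main loop and puts all 5 iterations in the tail.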
3405 const int kTotalSize = 5;
3406 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
3407 VarHandle x("x", kInt);
3408 auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2));
3409 auto parent_block = Block::make({for_stmt});
3410
3411 LoopNest::normalize(for_stmt);
3412
3413 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3414 ForPtr x_inner;
3415 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3416 ForPtr x_tail;
3417 LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail);
3418
3419 auto x_outer_result = IRSimplifier::simplify(for_stmt);
3420 std::ostringstream oss_outer;
3421 oss_outer << *x_outer_result;
3422 const std::string& expected_outer_ir =
3423 R"IR(
3424 # CHECK: {
3425 # CHECK: }
3426 )IR";
3427 torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str());
3428
3429 auto x_tail_result = IRSimplifier::simplify(x_tail);
3430 std::ostringstream oss_tail;
3431 oss_tail << *x_tail_result;
3432 const std::string& expected_tail_ir =
3433 R"IR(
3434 # CHECK: for (int x_tail = 0; x_tail < 5; x_tail++) {
3435 # CHECK: A[x_tail + 5] = 2 * (x_tail + 5);
3436 )IR";
3437 torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str());
3438}
3439
3440TEST(LoopNest, NotNormalizeAndSplitWithTail) {
3441 // Create a dummy tensor to construct LoopNest.
3442 ExprHandle n(100);
3443 BufHandle a("a", {n}, kFloat);
3444 Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
3445 LoopNest l({b});
3446
3447 // Input IR:
3448 // for (int x = 5; x < 15; x++) {
3449 // A[x] = x * 2;
3450 // }
3451 const int kTotalSize = 10;
3452 BufHandle a_buf("A", {kTotalSize}, kInt);
3453 VarHandle x("x", kInt);
3454 auto for_stmt = For::make(x, 5, 15, Store::make(a_buf, {x}, x * 2));
3455 auto parent_block = Block::make({for_stmt});
3456
3457 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3458 ForPtr x_inner;
3459 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3460 ForPtr x_tail;
3461 LoopNest::splitWithTail(for_stmt, 8, &x_inner, &x_tail);
3462
3463 auto x_outer_result = IRSimplifier::simplify(for_stmt);
3464 std::ostringstream oss_outer;
3465 oss_outer << *x_outer_result;
3466 const std::string& expected_outer_ir =
3467 R"IR(
3468 # CHECK: {
3469 # CHECK: }
3470 )IR";
3471 torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str());
3472
3473 auto x_tail_result = IRSimplifier::simplify(x_tail);
3474 std::ostringstream oss_tail;
3475 oss_tail << *x_tail_result;
3476 const std::string& expected_tail_ir =
3477 R"IR(
3478 # CHECK: for (int x_tail = 0; x_tail < 2; x_tail++) {
3479 # CHECK: A[x_tail + 13] = 2 * (x_tail + 13);
3480 )IR";
3481 torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str());
3482}
3483
3484TEST(LoopNest, FlattenSimpleLoopNest2D) {
3485 // Input IR:
3486 // for (int i = 0; i < 10; i++) {
3487 // for (int j = 0; j < 5; j++) {
3488 // A[i,j] = i * j;
3489 // }
3490 // }
3491 BufHandle a_buf("A", {10, 5}, kInt);
3492 VarHandle i("i", kInt);
3493 VarHandle j("j", kInt);
3494 auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
3495 auto inner_for = For::make(j, 0, 5, for_body);
3496 auto outer_for = For::make(i, 0, 10, inner_for);
3497 auto parent_block = Block::make({outer_for});
3498
3499 std::vector<ForPtr> loops = {outer_for, inner_for};
3500 ForPtr flattened = nullptr;
3501 ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
3502 ASSERT_EQ(flattened, loops.front());
3503
3504 auto result = IRSimplifier::simplify(flattened);
3505 std::ostringstream oss;
3506 oss << *result;
3507 const std::string& expected_ir =
3508 R"IR(
3509 # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) {
3510 # CHECK: A[i_flat / 5, i_flat % 5] =
3511 )IR";
3512 torch::jit::testing::FileCheck().run(expected_ir, oss.str());
3513
3514 {
3515 SimpleIREvaluator eval1(loops[0], {a_buf});
3516 PaddedBuffer<int> inp1(10, 5);
3517 eval1(inp1);
3518 SimpleIREvaluator eval2(flattened, {a_buf});
3519 PaddedBuffer<int> inp2(10, 5);
3520 eval2(inp2);
3521 ExpectAllNear(inp1, inp2, 1e-5);
3522 }
3523}
3524
3525TEST(LoopNest, FlattenSimpleLoopNest3D) {
3526 // Input IR:
3527 // for (int i = 0; i < 10; i++) {
3528 // for (int j = 0; j < 5; j++) {
3529 // for (int k = 0; k < 7; k++) {
3530 // A[i,j,k] = i + j * k;
3531 // }
3532 // }
3533 // }
3534 BufHandle a_buf("A", {10, 5, 7}, kInt);
3535 VarHandle i("i", kInt);
3536 VarHandle j("j", kInt);
3537 VarHandle k("k", kInt);
3538 auto for_body = Block::make({Store::make(a_buf, {i, j, k}, i + j * k)});
3539 auto for1 = For::make(k, 0, 7, for_body);
3540 auto for2 = For::make(j, 0, 5, for1);
3541 auto for3 = For::make(i, 0, 10, for2);
3542 auto parent_block = Block::make({for3});
3543
3544 std::vector<ForPtr> loops = {for3, for2, for1};
3545 ForPtr flattened = nullptr;
3546 ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
3547 ASSERT_EQ(flattened, loops.front());
3548
3549 auto result = IRSimplifier::simplify(flattened);
3550 std::ostringstream oss;
3551 oss << *result;
3552 const std::string& expected_ir =
3553 R"IR(
3554 # CHECK: for (int i_flat = 0; i_flat < 350; i_flat++) {
3555 # CHECK: A[i_flat / 35, (i_flat / 7) % 5, i_flat % 7] =
3556 )IR";
3557 torch::jit::testing::FileCheck().run(expected_ir, oss.str());
3558
3559 {
3560 SimpleIREvaluator eval1(loops[0], {a_buf});
3561 PaddedBuffer<int> inp1(10, 5, 7);
3562 eval1(inp1);
3563 SimpleIREvaluator eval2(flattened, {a_buf});
3564 PaddedBuffer<int> inp2(10, 5, 7);
3565 eval2(inp2);
3566 ExpectAllNear(inp1, inp2, 1e-5);
3567 }
3568}
3569
3570TEST(LoopNest, FlattenLoopNestAfterNormalize) {
3571 // Input IR:
3572 // for (int i = 2; i < 10; i++) {
3573 // for (int j = 3; j < 15; j++) {
3574 // A[i - 2,j - 3] = i * j;
3575 // }
3576 // }
3577 BufHandle a_buf("A", {8, 12}, kInt);
3578 VarHandle i("i", kInt);
3579 VarHandle j("j", kInt);
3580 auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)});
3581 auto inner_for = For::make(j, 3, 15, for_body);
3582 auto outer_for = For::make(i, 2, 10, inner_for);
3583 auto parent_block = Block::make({outer_for});
3584
3585 std::vector<ForPtr> loops = {outer_for, inner_for};
3586 ForPtr flattened = nullptr;
3587 ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
3588 ASSERT_EQ(flattened, loops.front());
3589
3590 auto result = IRSimplifier::simplify(flattened);
3591 std::ostringstream oss;
3592 oss << *result;
3593 const std::string& expected_ir =
3594 R"IR(
3595 # CHECK: for (int i_flat = 0; i_flat < 96; i_flat++) {
3596 # CHECK: A[i_flat / 12, i_flat % 12] =
3597 )IR";
3598 torch::jit::testing::FileCheck().run(expected_ir, oss.str());
3599
3600 {
3601 SimpleIREvaluator eval1(loops[0], {a_buf});
3602 PaddedBuffer<int> inp1(8, 12);
3603 eval1(inp1);
3604 SimpleIREvaluator eval2(flattened, {a_buf});
3605 PaddedBuffer<int> inp2(8, 12);
3606 eval2(inp2);
3607 ExpectAllNear(inp1, inp2, 1e-5);
3608 }
3609}
3610
3611TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) {
3612 // Input IR:
3613 // for (int i = 0; i < 15-5; i++) {
3614 // for (int j = 0; j < 20/4; j++) {
3615 // A[i,j] = i * j;
3616 // }
3617 // }
3618 BufHandle a_buf("A", {10, 5}, kInt);
3619 VarHandle i("i", kInt);
3620 VarHandle j("j", kInt);
3621 auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
3622 auto inner_for =
3623 For::make(j, 0, IntImm::make(20) / IntImm::make(4), for_body);
3624 auto outer_for =
3625 For::make(i, 0, IntImm::make(15) - IntImm::make(5), inner_for);
3626 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
3627 auto b = Block::make({outer_for});
3628
3629 std::vector<ForPtr> loops = {outer_for, inner_for};
3630 ForPtr flattened = nullptr;
3631 ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
3632 ASSERT_EQ(flattened, loops.front());
3633
3634 auto result = IRSimplifier::simplify(flattened);
3635 checkIR(result, R"IR(
3636 # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) {
3637 # CHECK: A[i_flat / 5, i_flat % 5] =
3638 )IR");
3639
3640 {
3641 SimpleIREvaluator eval1(loops[0], {a_buf});
3642 PaddedBuffer<int> inp1(10, 5);
3643 eval1(inp1);
3644 SimpleIREvaluator eval2(flattened, {a_buf});
3645 PaddedBuffer<int> inp2(10, 5);
3646 eval2(inp2);
3647 ExpectAllNear(inp1, inp2, 1e-5);
3648 }
3649}
3650
3651TEST(LoopNest, FlattenImperfectLoopNest) {
3652 // Input IR:
3653 // for (int i = 0; i < 10; i++) {
3654 // A[i, i] = 0;
3655 // for (int j = 0; j < 15; j++) {
3656 // A[i,j] = i * j;
3657 // }
3658 // }
3659 // Do not flatten.
3660
3661 BufHandle a_buf("A", {10, 15}, kInt);
3662 VarHandle i("i", kInt);
3663 VarHandle j("j", kInt);
3664 auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
3665 auto inner_for = For::make(j, 0, 15, for_body);
3666 auto outer_for = For::make(
3667 i, 0, 10, Block::make({Store::make(a_buf, {i, i}, 0), inner_for}));
3668 auto par = Block::make({outer_for});
3669 HashProvider hasher;
3670 auto hash_before = hasher.hash(par);
3671
3672 std::vector<ForPtr> loops = {outer_for, inner_for};
3673 ForPtr flattened = nullptr;
3674 ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
3675 ASSERT_EQ(flattened, nullptr);
3676 auto hash_after = hasher.hash(par);
3677 ASSERT_EQ(hash_before, hash_after);
3678}
3679
3680TEST(LoopNest, FlattenReductionLoopNest) {
3681 // Input IR:
3682 // for (int i = 0; i < 10; i++) {
3683 // S[i] = 0;
3684 // for (int j = 0; j < 15; j++) {
3685 // S[i] = S[i] + A[i,j];
3686 // }
3687 // }
3688 // Do not flatten.
3689
3690 BufHandle a_buf("A", {10, 15}, kInt);
3691 BufHandle s_buf("S", {10}, kInt);
3692 VarHandle i("i", kInt);
3693 VarHandle j("j", kInt);
3694 auto for_body = Block::make({Store::make(
3695 s_buf, {i}, Load::make(s_buf, {i}) + Load::make(a_buf, {i, j}))});
3696 auto inner_for = For::make(j, 0, 15, for_body);
3697 auto outer_for =
3698 For::make(i, 0, 10, Block::make({Store::make(s_buf, {i}, 0), inner_for}));
3699 auto par = Block::make({outer_for});
3700 HashProvider hasher;
3701 auto hash_before = hasher.hash(par);
3702
3703 std::vector<ForPtr> loops = {outer_for, inner_for};
3704 ForPtr flattened = nullptr;
3705 ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
3706 ASSERT_EQ(flattened, nullptr);
3707 auto hash_after = hasher.hash(par);
3708 ASSERT_EQ(hash_before, hash_after);
3709}
3710
3711TEST(LoopNest, FlattenReductionLoopNestFromTensor) {
3712 const int M = 3;
3713 const int N = 7;
3714 VarHandle m("m", kInt);
3715 VarHandle n("n", kInt);
3716 BufHandle b("b", {m, n}, kFloat);
3717 Tensor c = Reduce("sum", {M}, Sum(), b, {N});
3718 LoopNest loop({c});
3719 HashProvider hasher;
3720 auto hash_before = hasher.hash(loop.root_stmt());
3721
3722 auto loops = loop.getAllLoopNestsWritingToBuf(c.buf())[1];
3723 ForPtr flattened = nullptr;
3724 ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
3725 ASSERT_EQ(flattened, nullptr);
3726 auto hash_after = hasher.hash(loop.root_stmt());
3727 ASSERT_EQ(hash_before, hash_after);
3728}
3729
3730TEST(LoopNest, FlattenIncorrectLoopsAsInput) {
3731 // Input IR:
3732 // for (int i = 0; i < 10; i++) {
3733 // for (int j = 0; j < 5; j++) {
3734 // A[i,j] = i * j;
3735 // }
3736 // }
3737 // for (int x = 0; x < 10; x++) {
3738 // for (int y = 0; y < 5; y++) {
3739 // A[x,y] = A[x,y] + x + y;
3740 // }
3741 // }
3742 // Flatten({For_i, For_y}) => should not succeed
3743
3744 BufHandle a_buf("A", {10, 5}, kInt);
3745 VarHandle i("i", kInt);
3746 VarHandle j("j", kInt);
3747 VarHandle x("x", kInt);
3748 VarHandle y("y", kInt);
3749 auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)});
3750 auto inner_for1 = For::make(j, 0, 5, for_body1);
3751 auto outer_for1 = For::make(i, 0, 10, inner_for1);
3752 auto for_body2 = Block::make(
3753 {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)});
3754 auto inner_for2 = For::make(y, 0, 5, for_body2);
3755 auto outer_for2 = For::make(x, 0, 10, inner_for2);
3756 auto par = Block::make({outer_for1, outer_for2});
3757 HashProvider hasher;
3758 auto hash_before = hasher.hash(par);
3759
3760 std::vector<ForPtr> loops = {outer_for1, inner_for2};
3761 ForPtr flattened = nullptr;
3762 ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
3763 ASSERT_EQ(flattened, nullptr);
3764 auto hash_after = hasher.hash(par);
3765 ASSERT_EQ(hash_before, hash_after);
3766}
3767
3768TEST(LoopNest, DetectInlineRankMismatch) {
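  // `a` is a 1-D tensor, but `reshape` loads from it with two indices, so
  // inlining `a` must be rejected.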
3769 const int kTotalSize = 8;
3770
3771 BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
3772 Tensor a = Compute(
3773 "a", {kTotalSize}, [&](const VarHandle& i) { return a_buf.load(i); });
3774 Tensor reshape = Compute(
3775 "reshape",
3776 {kTotalSize / 2, 2},
3777 [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j); });
3778 LoopNest l({reshape}, {a, reshape});
3779 ASSERT_FALSE(l.computeInline(l.getLoopBodyFor(a)));
3780}
3781
3782TEST(LoopNest, CacheReadsSimple) {
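  // Caches the reads of A in the inner (j) loop that writes to B. At a fixed i,
  // that loop reads a 1x10 slice of A, so A_local gets dims [1, 10].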
3783 Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
3784 return i * j;
3785 });
3786 Tensor B =
3787 Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
3788 return A.load(i + 30, j + 3);
3789 });
3790 Tensor C =
3791 Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
3792 return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
3793 });
3794
3795 LoopNest l({B, C}, {A, B, C});
3796 StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1];
3797 LoopNest::cacheAccesses(A.buf(), "A_local", j_loop);
3798
3799 l.prepareForCodegen();
3800 StmtPtr result =
3801 LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
3802 SimpleIREvaluator cg(result, {B, C});
3803 result = cg.stmt();
3804
3805 // Just this once, verify the whole thing.
3806 checkIR(result, R"IR(
3807#CHECK: Allocate(A); // dtype=int, dims=[64, 64]
3808#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10]
3809#CHECK: for (int i
3810#CHECK: for (int j
3811#CHECK: A[
3812#CHECK: }
3813#CHECK: }
3814#CHECK: for (int i_1
3815#CHECK: for (int j_1
3816#CHECK: A_local[j_1] = A[
3817#CHECK: }
3818#CHECK: for (int j_2
3819#CHECK: B[j_2 + 10 * i_1] = A_local[j_2];
3820#CHECK: }
3821#CHECK: }
3822#CHECK: for (int i_2
3823#CHECK: for (int j_3
3824#CHECK: C[
3825#CHECK: }
3826#CHECK: }
3827#CHECK: Free(A_local);
3828#CHECK: Free(A);
3829 )IR");
3830
3831 std::vector<int> b_data(200, 0);
3832 std::vector<int> c_data(200, 0);
3833 cg.call({b_data, c_data});
3834
3835 std::vector<int> b_ref(200, 0);
3836 std::vector<int> c_ref(200, 0);
3837
3838 for (int i = 0; i < 20; ++i) {
3839 for (int j = 0; j < 10; ++j) {
3840 b_ref[i * 10 + j] = (i + 30) * (j + 3);
3841 c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
3842 }
3843 }
3844
3845 assertAllEqual(b_data, b_ref);
3846 assertAllEqual(c_data, c_ref);
3847}
3848
3849TEST(LoopNest, CacheReadsOuter) {
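  // Caches the reads of A at the outer (i) loop that writes to B, so A_local
  // must cover the footprint of all of B's iterations: dims [21, 11].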
3850 Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
3851 return i * j;
3852 });
3853 Tensor B =
3854 Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
3855 return A.load(i + 30, j + 40) + A.load(i + 31, j + 41);
3856 });
3857 Tensor C =
3858 Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
3859 return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
3860 });
3861
3862 LoopNest l({B, C}, {A, B, C});
3863 StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][0];
3864 LoopNest::cacheAccesses(A.buf(), "A_local", i_loop);
3865
3866 l.prepareForCodegen();
3867 StmtPtr result =
3868 LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
3869 SimpleIREvaluator cg(result, {B, C});
3870 result = cg.stmt();
3871
3872 checkIR(result, R"IR(
3873#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11]
3874#CHECK: A_local[j_1 + 11 * i_1] =
3875#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]);
3876 )IR");
3877
3878 std::vector<int> b_data(200, 0);
3879 std::vector<int> c_data(200, 0);
3880 cg.call({b_data, c_data});
3881
3882 std::vector<int> b_ref(200, 0);
3883 std::vector<int> c_ref(200, 0);
3884
3885 for (int i = 0; i < 20; ++i) {
3886 for (int j = 0; j < 10; ++j) {
3887 b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41);
3888 c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
3889 }
3890 }
3891
3892 assertAllEqual(b_data, b_ref);
3893 assertAllEqual(c_data, c_ref);
3894}
3895
3896TEST(LoopNest, CacheReadsInternal) {
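  // Caches the reads of A in the inner (j) loop that writes to B. At a fixed i,
  // B reads from two adjacent rows of A, so A_local gets dims [2, 11].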
3897 Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
3898 return i * j;
3899 });
3900 Tensor B =
3901 Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
3902 return A.load(i + 30, j + 40) + A.load(i + 31, j + 41);
3903 });
3904 Tensor C =
3905 Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
3906 return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
3907 });
3908
3909 LoopNest l({B, C}, {A, B, C});
3910 StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1];
3911 LoopNest::cacheAccesses(A.buf(), "A_local", j_loop);
3912 l.prepareForCodegen();
3913 StmtPtr result =
3914 LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
3915 SimpleIREvaluator cg(result, {B, C});
3916 result = cg.stmt();
3917
3918 checkIR(result, R"IR(
3919#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11]
3920#CHECK: A_local[k + 11 * j_1] =
3921#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]);
3922 )IR");
3923
3924 std::vector<int> b_data(200, 0);
3925 std::vector<int> c_data(200, 0);
3926 cg.call({b_data, c_data});
3927
3928 std::vector<int> b_ref(200, 0);
3929 std::vector<int> c_ref(200, 0);
3930
3931 for (int i = 0; i < 20; ++i) {
3932 for (int j = 0; j < 10; ++j) {
3933 b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41);
3934 c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
3935 }
3936 }
3937
3938 assertAllEqual(b_data, b_ref);
3939 assertAllEqual(c_data, c_ref);
3940}
3941
3942TEST(LoopNest, CacheReadsInner) {
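  // Caches the reads of A inside the innermost statement that writes to B, so
  // A_local only covers the footprint of the two loads: dims [5, 2].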
3943 Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
3944 return i * j;
3945 });
3946 // Note: the offset in the first index of the first load of A is changed here (i + 34 instead of i + 30).
3947 Tensor B =
3948 Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
3949 return A.load(i + 34, j + 40) + A.load(i + 30, j + 41);
3950 });
3951 Tensor C =
3952 Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
3953 return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
3954 });
3955
3956 LoopNest l({B, C}, {A, B, C});
3957 StmtPtr body = l.getLoopBodyFor(B);
3958 LoopNest::cacheAccesses(A.buf(), "A_local", body);
3959 l.prepareForCodegen();
3960 StmtPtr result =
3961 LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
3962 SimpleIREvaluator cg(result, {B, C});
3963 result = cg.stmt();
3964
3965 checkIR(result, R"IR(
3966#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2]
3967#CHECK: A_local[l + 2 * k] =
3968#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]);
3969 )IR");
3970
3971 std::vector<int> b_data(200, 0);
3972 std::vector<int> c_data(200, 0);
3973 cg.call({b_data, c_data});
3974
3975 std::vector<int> b_ref(200, 0);
3976 std::vector<int> c_ref(200, 0);
3977
3978 for (int i = 0; i < 20; ++i) {
3979 for (int j = 0; j < 10; ++j) {
3980 b_ref[i * 10 + j] = (i + 34) * (j + 40) + (i + 30) * (j + 41);
3981 c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
3982 }
3983 }
3984
3985 assertAllEqual(b_data, b_ref);
3986 assertAllEqual(c_data, c_ref);
3987}
3988
3989TEST(LoopNest, CacheWritesSimple) {
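  // Caches the writes to A in its inner (j) loop: each row of A is computed
  // into A_local (dims [1, 64]) and then copied back into A.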
3990 Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
3991 return i * j;
3992 });
3993 Tensor B =
3994 Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
3995 return A.load(i + 30, j + 40) + A.load(i + 31, j + 41);
3996 });
3997 Tensor C =
3998 Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
3999 return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
4000 });
4001
4002 LoopNest l({B, C}, {A, B, C});
4003 StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A.buf())[0][1];
4004 LoopNest::cacheAccesses(A.buf(), "A_local", a_loop);
4005
4006 l.prepareForCodegen();
4007 StmtPtr result =
4008 LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
4009 SimpleIREvaluator cg(result, {B, C});
4010 result = cg.stmt();
4011
4012 checkIR(result, R"IR(
4013#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64]
4014#CHECK: for (int j = 0; j < 64
4015#CHECK: A_local[j] = i * j;
4016#CHECK: for (int j_1 = 0; j_1 < 64
4017#CHECK: A[j_1 + 64 * i] = A_local[
4018#CHECK: Free(A_local);
4019#CHECK-NOT: A_local
4020 )IR");
4021
4022 std::vector<int> b_data(200, 0);
4023 std::vector<int> c_data(200, 0);
4024 cg.call({b_data, c_data});
4025
4026 std::vector<int> b_ref(200, 0);
4027 std::vector<int> c_ref(200, 0);
4028
4029 for (int i = 0; i < 20; ++i) {
4030 for (int j = 0; j < 10; ++j) {
4031 b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41);
4032 c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
4033 }
4034 }
4035
4036 assertAllEqual(b_data, b_ref);
4037 assertAllEqual(c_data, c_ref);
4038}
4039
4040TEST(LoopNest, DeadStoreElimination) {
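  // Input IR:
  // for (int x_tail = 0; x_tail < 5; x_tail++) {
  //   for (int y = 0; y < 5; y++) {
  //     f[x_tail + 5 * 4, y] = (x_tail + 5 * 4) + y;
  //     g[x_tail + 5 * 4, y] = (x_tail + 5 * 4) * y;
  //   }
  // }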
4041 VarHandle y("y", kInt);
4042 VarHandle x("x_tail", kInt);
4043 BufHandle f("f", {26, 5}, kInt);
4044 BufHandle g("g", {26, 5}, kInt);
4045 ExprHandle x_outer_end = 5;
4046 ExprHandle x_2 = x + x_outer_end * 4;
4047 ForPtr stmt1 = For::make(
4048 x,
4049 0,
4050 5,
4051 For::make(
4052 y,
4053 0,
4054 5,
4055 Block::make({
4056 Store::make(f, {x_2, y}, (x_2 + y)),
4057 Store::make(g, {x_2, y}, (x_2 * y)),
4058 })));
4059 StmtPtr stmt = Block::make({stmt1});
4060
4061 // The store to g will be eliminated since g is not used by any output.
4062 LoopNest loop(Stmt::clone(stmt), {f.node()});
4063 loop.eliminateDeadStores();
4064
4065 checkIR(loop.root_stmt(), R"IR(
4066#CHECK: f[x_tail + 5 * 4, y]
4067#CHECK-NOT: g[x_tail + 5 * 4, y]
4068 )IR");
4069
4070 // But the store to g won't be eliminated when g is itself an output.
4071 LoopNest loop2(stmt, {f.node(), g.node()});
4072 loop2.eliminateDeadStores();
4073
4074 checkIR(loop2.root_stmt(), R"IR(
4075#CHECK: f[x_tail + 5 * 4, y]
4076#CHECK: g[x_tail + 5 * 4, y]
4077 )IR");
4078}
4079
4080TEST(LoopNest, DeadStoreEliminationWithIntermediates) {
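  // Input IR:
  // for (int x = 0; x < 130; x++) {
  //   f[x] = x;
  // }
  // for (int z = 0; z < 130; z++) {
  //   g[z] = z + 1;
  // }
  // for (int x = 0; x < 5; x++) {
  //   for (int y = 0; y < 5; y++) {
  //     h[x, y] = f[x * y];
  //   }
  // }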
4081 VarHandle x("x", kInt);
4082 VarHandle y("y", kInt);
4083 VarHandle z("z", kInt);
4084 BufHandle f("f", {26 * 5}, kInt);
4085 BufHandle g("g", {26 * 5}, kInt);
4086 BufHandle h("h", {26, 5}, kInt);
4087 ExprHandle x_outer_end = 5;
4088 ExprHandle x_2 = x + x_outer_end * 4;
4089 ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x));
4090 ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1));
4091 ForPtr stmt3 = For::make(
4092 x,
4093 0,
4094 5,
4095 For::make(
4096 y,
4097 0,
4098 5,
4099 Block::make({
4100 Store::make(h, {x, y}, Load::make(f, {x * y})),
4101 })));
4102 StmtPtr stmt = Block::make({stmt1, stmt2, stmt3});
4103
4104 // Will eliminate the write to g, but not the write to f, since f is used by
4105 // the producer of h.
4106 LoopNest loop(Stmt::clone(stmt), {h.node()});
4107 loop.eliminateDeadStores();
4108
4109 checkIR(loop.root_stmt(), R"IR(
4110 #CHECK: f[x] = x;
4111 #CHECK-NOT: g[z] =
4112 #CHECK: h[x, y] = f[x * y];
4113 )IR");
4114
4115 // Sanity check: the store to g won't be eliminated when g is an output.
4116 LoopNest loop2(stmt, {h.node(), g.node()});
4117 loop2.eliminateDeadStores();
4118
4119 checkIR(loop2.root_stmt(), R"IR(
4120 #CHECK: f[x] = x;
4121 #CHECK: g[z] = z + 1;
4122 #CHECK: h[x, y] = f[x * y];
4123 )IR");
4124}
4125
4126TEST(LoopNest, CompoundTensorSimple) {
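  // A is a compound tensor: its statement is a block with two loop nests that
  // both write to A's buffer.
  // Input IR:
  // for (int i = 0; i < 10; i++) {
  //   for (int j = 0; j < 5; j++) {
  //     A[i,j] = i * j;
  //   }
  // }
  // for (int x = 0; x < 10; x++) {
  //   for (int y = 0; y < 5; y++) {
  //     A[x,y] = A[x,y] + x + y;
  //   }
  // }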
4127 BufHandle a_buf("A", {10, 5}, kInt);
4128 VarHandle i("i", kInt);
4129 VarHandle j("j", kInt);
4130 VarHandle x("x", kInt);
4131 VarHandle y("y", kInt);
4132 auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)});
4133 auto inner_for1 = For::make(j, 0, 5, for_body1);
4134 auto outer_for1 = For::make(i, 0, 10, inner_for1);
4135 auto for_body2 = Block::make(
4136 {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)});
4137 auto inner_for2 = For::make(y, 0, 5, for_body2);
4138 auto outer_for2 = For::make(x, 0, 10, inner_for2);
4139 BlockPtr body = Block::make({outer_for1, outer_for2});
4140
4141 Tensor A = Tensor(a_buf.node(), body);
4142
4143 LoopNest l({A});
4144 l.prepareForCodegen();
4145
4146 std::vector<int> a_data(50, 0);
4147
4148 StmtPtr s = IRSimplifier::simplify(l.root_stmt());
4149 SimpleIREvaluator cg(s, {A});
4150
4151 std::vector<int> a_ref(50, 0);
4152
4153 for (int i = 0; i < 10; ++i) {
4154 for (int j = 0; j < 5; ++j) {
4155 a_ref[i * 5 + j] = (i * j) + i + j;
4156 }
4157 }
4158 cg.call({a_data});
4159
4160 assertAllEqual(a_data, a_ref);
4161}
4162
4163TEST(LoopNest, InlineConstantIndex) {
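  // The producer `y` has extent-1 outer and inner dimensions, so after
  // simplification its loads and stores use constant indices; inlining it
  // should still succeed.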
4164 const int N = 10;
4165 BufHandle x_buf("a", {1, N, 1}, kFloat);
4166 Tensor y = Compute(
4167 "f",
4168 {1, N, 1},
4169 [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
4170 return x_buf.load(m, n, o);
4171 });
4172 Tensor z = Compute(
4173 "f",
4174 {1, N, 1},
4175 [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
4176 return y.load(m, n, o);
4177 });
4178
4179 LoopNest l({z}, {y, z});
4180 l.simplify();
4181 ASSERT_TRUE(l.computeInline(y.buf()));
4182}
4183
4184TEST(LoopNest, CompoundTensorUsed) {
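  // Same compound tensor A as in CompoundTensorSimple, but here it is consumed
  // by B. A compound tensor cannot be inlined, so computeInline must fail and
  // B must read the final values of A.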
4185 BufHandle a_buf("A", {10, 5}, kInt);
4186 VarHandle i("i", kInt);
4187 VarHandle j("j", kInt);
4188 VarHandle x("x", kInt);
4189 VarHandle y("y", kInt);
4190 auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)});
4191 auto inner_for1 = For::make(j, 0, 5, for_body1);
4192 auto outer_for1 = For::make(i, 0, 10, inner_for1);
4193 auto for_body2 = Block::make(
4194 {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)});
4195 auto inner_for2 = For::make(y, 0, 5, for_body2);
4196 auto outer_for2 = For::make(x, 0, 10, inner_for2);
4197 BlockPtr body = Block::make({outer_for1, outer_for2});
4198
4199 Tensor A = Tensor(a_buf.node(), body);
4200 Tensor B = Compute("B", {10, 3}, [&](const VarHandle& i, const VarHandle& j) {
4201 return A.load(i, j + 1) + A.load(i, j + 2);
4202 });
4203
4204 LoopNest l({B}, {A, B});
4205 ASSERT_FALSE(l.computeInline(A.buf()));
4206 l.prepareForCodegen();
4207
4208 std::vector<int> a_data(50, 0);
4209 std::vector<int> b_data(50, 0);
4210
4211 StmtPtr s = IRSimplifier::simplify(l.root_stmt());
4212 SimpleIREvaluator cg(s, {B});
4213
4214 std::vector<int> b_ref(50, 0);
4215
4216 auto AT = [](int i, int j) { return i * j + i + j; };
4217 for (int i = 0; i < 10; ++i) {
4218 for (int j = 0; j < 3; ++j) {
4219 b_ref[i * 3 + j] = AT(i, j + 1) + AT(i, j + 2);
4220 }
4221 }
4222 cg.call({b_data});
4223
4224 assertAllEqual(b_data, b_ref);
4225}
4226
4227TEST(LoopNest, InlineFromLoad) {
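  // Input IR:
  // for (int i = 0; i < 1024; i++) {
  //   A[i] = i;
  // }
  // for (int j = 0; j < 1024; j++) {
  //   B[j] = A[j];
  // }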
4228 constexpr int N = 1024;
4229 BufHandle a("A", {N}, kInt);
4230 BufHandle b("B", {N}, kInt);
4231 VarHandle i("i", kInt);
4232 VarHandle j("j", kInt);
4233 auto store_a = For::make(i, 0, N, Store::make(a, {i}, i));
4234 auto store_b = For::make(j, 0, N, Store::make(b, {j}, Load::make(a, {j})));
4235 LoopNest l(Block::make({store_a, store_b}), {b.node()});
4236
4237 l.computeInline(a.node());
4238
4239 // Check that A[j] is replaced with j after inlining
4240 std::ostringstream oss;
4241 oss << *l.root_stmt();
4242 torch::jit::testing::FileCheck().run(
4243 R"IR(
4244# CHECK: for (int j
4245# CHECK-NOT: B[j] = A[j]
4246# CHECK-NEXT: B[j] = j
4247)IR",
4248 oss.str());
4249}
4250
4251TEST(LoopNest, OptimizeConditionalsSimple) {
4252 // Input IR:
4253 // for (int i = 0; i < 20; i++) {
4254 // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
4255 // }
4256
4257 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4258 BufHandle a_buf("A", {20}, kInt);
4259 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4260 BufHandle b_buf("B", {5}, kInt);
4261 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4262 BufHandle c_buf("C", {15}, kInt);
4263 VarHandle i("i", kInt);
4264 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4265 auto store = Store::make(
4266 a_buf,
4267 {i},
4268 IfThenElse::make(
4269 CompareSelect::make(i, 5, kLT),
4270 Load::make(b_buf, {i}),
4271 Load::make(c_buf, {i - 5})));
4272 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4273 auto forI = For::make(i, 0, 20, store);
4274 auto par = Block::make({forI});
4275
4276 LoopNest nest(par, {a_buf.node()});
4277 nest.optimizeConditionals();
4278
4279 std::ostringstream oss;
4280 oss << *nest.root_stmt();
4281 const std::string& verification_pattern =
4282 R"IR(
4283# CHECK: for (int i = 0; i < 5
4284# CHECK-NEXT: A[i] = B[i]
4285# CHECK: for (int i = 0; i < 15
4286# CHECK-NEXT: A[i + 5] = C[i]
4287 )IR";
4288 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
4289}
4290
4291TEST(LoopNest, OptimizeConditionalsNestedConditions) {
4292 // Input IR:
4293 // for (int i = 0; i < 20; i++) {
4294 // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
4295 // }
4296
4297 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4298 BufHandle a_buf("A", {20}, kInt);
4299 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4300 BufHandle b_buf("B", {5}, kInt);
4301 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4302 BufHandle c_buf("C", {5}, kInt);
4303 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4304 BufHandle d_buf("D", {10}, kInt);
4305 VarHandle i("i", kInt);
4306 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4307 auto store = Store::make(
4308 a_buf,
4309 {i},
4310 IfThenElse::make(
4311 CompareSelect::make(i, 10, kLT),
4312 IfThenElse::make(
4313 CompareSelect::make(i, 5, kLT),
4314 Load::make(b_buf, {i}),
4315 Load::make(c_buf, {i - 5})),
4316 Load::make(d_buf, {i - 10})));
4317 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4318 auto forI = For::make(i, 0, 20, store);
4319 auto par = Block::make({forI});
4320
4321 LoopNest nest(par, {a_buf.node()});
4322 nest.optimizeConditionals();
4323
4324 std::ostringstream oss;
4325 oss << *nest.root_stmt();
4326 const std::string& verification_pattern =
4327 R"IR(
4328# CHECK: for (int i = 0; i < 5
4329# CHECK-NEXT: A[i] = B[i]
4330# CHECK: for (int i = 0; i < 5
4331# CHECK-NEXT: A[i + 5] = C[i]
4332# CHECK: for (int i = 0; i < 10
4333# CHECK-NEXT: A[i + 10] = D[i]
4334 )IR";
4335 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
4336}
4337
4338TEST(LoopNest, OptimizeConditionalsMultipleStores) {
4339 // Input IR:
4340 // for (int i = 0; i < 20; i++) {
4341 // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
4342 // }
4343 // for (int j = 0; j < 100; j++) {
4344 // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j])
4345 // }
4346
4347 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4348 BufHandle a_buf("A", {20}, kInt);
4349 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4350 BufHandle b_buf("B", {5}, kInt);
4351 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4352 BufHandle c_buf("C", {100}, kInt);
4353 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4354 BufHandle d_buf("D", {100}, kInt);
4355 VarHandle i("i", kInt);
4356 VarHandle j("j", kInt);
4357 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4358 auto storeA = Store::make(
4359 a_buf,
4360 {i},
4361 IfThenElse::make(
4362 CompareSelect::make(i, 5, kLT),
4363 Load::make(b_buf, {i}),
4364 Load::make(c_buf, {i - 5})));
4365 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4366 auto forI = For::make(i, 0, 20, storeA);
4367 auto storeB = Store::make(
4368 b_buf,
4369 {j},
4370 IfThenElse::make(
4371 CompareSelect::make(j, 30, kLT),
4372 Load::make(c_buf, {j}),
4373 Load::make(d_buf, {j})));
4374 auto forJ = For::make(j, 0, 100, storeB);
4375 auto par = Block::make({forI, forJ});
4376
4377 LoopNest nest(par, {a_buf.node()});
4378 nest.optimizeConditionals();
4379
4380 std::ostringstream oss;
4381 oss << *nest.root_stmt();
4382 const std::string& verification_pattern =
4383 R"IR(
4384# CHECK: for (int i = 0; i < 5
4385# CHECK-NEXT: A[i] = B[i]
4386# CHECK: for (int i = 0; i < 15
4387# CHECK-NEXT: A[i + 5] = C[i]
4388# CHECK: for (int j = 0; j < 30
4389# CHECK-NEXT: B[j] = C[j]
4390# CHECK: for (int j = 0; j < 70
4391# CHECK-NEXT: B[j + 30] = D[j + 30]
4392 )IR";
4393 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
4394}
4395
4396TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) {
4397 // Input IR:
4398 // for (int i = 0; i < 50; i++) {
4399 // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
4400 //     B[i] = IfThenElse(i<30 ? 1 : 0, C[i], D[i])
4401 // }
4402 // Only the first conditional, in the write to A, will be optimized.
4403
4404 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4405 BufHandle a_buf("A", {100}, kInt);
4406 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4407 BufHandle b_buf("B", {100}, kInt);
4408 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4409 BufHandle c_buf("C", {100}, kInt);
4410 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4411 BufHandle d_buf("D", {100}, kInt);
4412 VarHandle i("i", kInt);
4413 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4414 auto storeA = Store::make(
4415 a_buf,
4416 {i},
4417 IfThenElse::make(
4418 CompareSelect::make(i, 5, kLT),
4419 Load::make(b_buf, {i}),
4420 Load::make(c_buf, {i - 5})));
4421 auto storeB = Store::make(
4422 b_buf,
4423 {i},
4424 IfThenElse::make(
4425 CompareSelect::make(i, 30, kLT),
4426 Load::make(c_buf, {i}),
4427 Load::make(d_buf, {i})));
4428 auto forI = For::make(i, 0, 50, Block::make({storeA, storeB}));
4429 auto par = Block::make({forI});
4430
4431 LoopNest nest(par, {a_buf.node()});
4432 nest.optimizeConditionals();
4433
4434 std::ostringstream oss;
4435 oss << *nest.root_stmt();
4436 const std::string& verification_pattern =
4437 R"IR(
4438# CHECK: for (int i = 0; i < 5
4439# CHECK-NEXT: A[i] = B[i]
4440# CHECK-NEXT: B[i] = C[i]
4441# CHECK: for (int i = 0; i < 45
4442# CHECK-NEXT: A[i + 5] = C[i]
4443# CHECK-NEXT: B[i + 5] = IfThenElse(i + 5<30 ? 1 : 0, C[i + 5], D[i + 5])
4444 )IR";
4445 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
4446}
4447
4448TEST(LoopNest, OptimizeConditionalsOuterLoopVar) {
4449 // Input IR:
4450 // for (int i = 0; i < 20; i++) {
4451 // for (int j = 0; j < 100; j++) {
4452 // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
4453 // }
4454 // }
4455 // This case, where the condition variable `i` is not the innermost loop
4456 // variable, is currently not optimized.
4457
4458 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4459 BufHandle a_buf("A", {20}, kInt);
4460 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4461 BufHandle b_buf("B", {5}, kInt);
4462 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4463 BufHandle c_buf("C", {5}, kInt);
4464 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4465 BufHandle d_buf("D", {10}, kInt);
4466 VarHandle i("i", kInt);
4467 VarHandle j("j", kInt);
4468 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4469 auto store = Store::make(
4470 a_buf,
4471 {i},
4472 IfThenElse::make(
4473 CompareSelect::make(i, 10, kLT),
4474 IfThenElse::make(
4475 CompareSelect::make(i, 5, kLT),
4476 Load::make(b_buf, {i}),
4477 Load::make(c_buf, {i - 5})),
4478 Load::make(d_buf, {i - 10})));
4479 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4480 auto forI = For::make(i, 0, 20, For::make(j, 0, 100, store));
4481 auto par = Block::make({forI});
4482 LoopNest nest(par, {a_buf.node()});
4483
4484 HashProvider hasher;
4485 auto hash_before = hasher.hash(nest.root_stmt());
4486 nest.optimizeConditionals();
4487 auto hash_after = hasher.hash(nest.root_stmt());
4488 ASSERT_EQ(hash_before, hash_after);
4489}
4490
4491TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) {
4492 // Input IR:
4493 // for (int i = 0; i < 20; i++) {
4494 // A[i] = IfThenElse(i<5, IfThenElse(i<10, B[i], C[i-5]), D[i-10])
4495 // }
4496 // No optimization should be done here because the comparison values in the conditions are not in the required order (outer is 5, inner is 10).
4497
4498 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4499 BufHandle a_buf("A", {20}, kInt);
4500 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4501 BufHandle b_buf("B", {5}, kInt);
4502 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4503 BufHandle c_buf("C", {5}, kInt);
4504 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4505 BufHandle d_buf("D", {10}, kInt);
4506 VarHandle i("i", kInt);
4507 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4508 auto store = Store::make(
4509 a_buf,
4510 {i},
4511 IfThenElse::make(
4512 CompareSelect::make(i, 5, kLT),
4513 IfThenElse::make(
4514 CompareSelect::make(i, 10, kLT),
4515 Load::make(b_buf, {i}),
4516 Load::make(c_buf, {i - 5})),
4517 Load::make(d_buf, {i - 10})));
4518 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4519 auto forI = For::make(i, 0, 20, store);
4520 auto par = Block::make({forI});
4521 LoopNest nest(par, {a_buf.node()});
4522
4523 HashProvider hasher;
4524 auto hash_before = hasher.hash(nest.root_stmt());
4525 nest.optimizeConditionals();
4526 auto hash_after = hasher.hash(nest.root_stmt());
4527 ASSERT_EQ(hash_before, hash_after);
4528}
4529
4530TEST(LoopNest, OptimizeConditionalsCompValuesNotConstants) {
4531 // Input IR:
4532 // for (int i = 0; i < 20; i++) {
4533 // A[i] = IfThenElse(i<N, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
4534 // }
4535 // No optimization should be done here because one of the comparison values, N, is not a constant.
4536
4537 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4538 BufHandle a_buf("A", {20}, kInt);
4539 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4540 BufHandle b_buf("B", {5}, kInt);
4541 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4542 BufHandle c_buf("C", {5}, kInt);
4543 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4544 BufHandle d_buf("D", {10}, kInt);
4545 VarHandle i("i", kInt);
4546 VarHandle N("N", kInt);
4547 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4548 auto store = Store::make(
4549 a_buf,
4550 {i},
4551 IfThenElse::make(
4552 CompareSelect::make(i, N, kLT),
4553 IfThenElse::make(
4554 CompareSelect::make(i, 5, kLT),
4555 Load::make(b_buf, {i}),
4556 Load::make(c_buf, {i - 5})),
4557 Load::make(d_buf, {i - 10})));
4558 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4559 auto forI = For::make(i, 0, 20, store);
4560 auto par = Block::make({forI});
4561 LoopNest nest(par, {a_buf.node()});
4562
4563 HashProvider hasher;
4564 auto hash_before = hasher.hash(nest.root_stmt());
4565 nest.optimizeConditionals();
4566 auto hash_after = hasher.hash(nest.root_stmt());
4567 ASSERT_EQ(hash_before, hash_after);
4568}
4569
4570TEST(LoopNest, OptimizeConditionalsInvalidCondition) {
4571 // Input IR:
4572 // for (int i = 0; i < 20; i++) {
4573 // A[i] = IfThenElse(i<10, IfThenElse(i>5, B[i], C[i-5]), D[i-10])
4574 // }
4575 // No optimization should be done here because one of the conditions uses '>'.
4576
4577 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4578 BufHandle a_buf("A", {20}, kInt);
4579 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4580 BufHandle b_buf("B", {5}, kInt);
4581 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4582 BufHandle c_buf("C", {5}, kInt);
4583 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4584 BufHandle d_buf("D", {10}, kInt);
4585 VarHandle i("i", kInt);
4586 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4587 auto store = Store::make(
4588 a_buf,
4589 {i},
4590 IfThenElse::make(
4591 CompareSelect::make(i, 10, kLT),
4592 IfThenElse::make(
4593 CompareSelect::make(i, 5, kGT),
4594 Load::make(b_buf, {i}),
4595 Load::make(c_buf, {i - 5})),
4596 Load::make(d_buf, {i - 10})));
4597 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4598 auto forI = For::make(i, 0, 20, store);
4599 auto par = Block::make({forI});
4600 LoopNest nest(par, {a_buf.node()});
4601
4602 HashProvider hasher;
4603 auto hash_before = hasher.hash(nest.root_stmt());
4604 nest.optimizeConditionals();
4605 auto hash_after = hasher.hash(nest.root_stmt());
4606 ASSERT_EQ(hash_before, hash_after);
4607}
4608
4609TEST(LoopNest, OptimizeConditionalsInvalidCondition2) {
4610 // Input IR:
4611 // for (int i = 0; i < 20; i++) {
4612 // A[i] = IfThenElse(10<i, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
4613 // }
4614 // No optimization should be done here because of the invalid condition:
4615 // "10 < i".
4616
4617 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4618 BufHandle a_buf("A", {20}, kInt);
4619 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4620 BufHandle b_buf("B", {5}, kInt);
4621 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4622 BufHandle c_buf("C", {5}, kInt);
4623 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4624 BufHandle d_buf("D", {10}, kInt);
4625 VarHandle i("i", kInt);
4626 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4627 auto store = Store::make(
4628 a_buf,
4629 {i},
4630 IfThenElse::make(
4631 CompareSelect::make(10, i, kLT),
4632 IfThenElse::make(
4633 CompareSelect::make(i, 5, kLT),
4634 Load::make(b_buf, {i}),
4635 Load::make(c_buf, {i - 5})),
4636 Load::make(d_buf, {i - 10})));
4637 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4638 auto forI = For::make(i, 0, 20, store);
4639 auto par = Block::make({forI});
4640 LoopNest nest(par, {a_buf.node()});
4641
4642 HashProvider hasher;
4643 auto hash_before = hasher.hash(nest.root_stmt());
4644 nest.optimizeConditionals();
4645 auto hash_after = hasher.hash(nest.root_stmt());
4646 ASSERT_EQ(hash_before, hash_after);
4647}
4648
4649TEST(LoopNest, OptimizeConditionalsInvalidCondition3) {
4650 // Input IR:
4651 // for (int i = 0; i < 20; i++) {
4652 // A[i] = IfThenElse(i<10, IfThenElse(k<5, B[i], C[i-5]), D[i-10])
4653 // }
4654 // No optimization should be done here because the conditions use different
4655 // variables: "i < 10" and "k < 5"
4656
4657 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4658 BufHandle a_buf("A", {20}, kInt);
4659 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4660 BufHandle b_buf("B", {5}, kInt);
4661 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4662 BufHandle c_buf("C", {5}, kInt);
4663 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4664 BufHandle d_buf("D", {10}, kInt);
4665 VarHandle i("i", kInt);
4666 VarHandle k("k", kInt);
4667 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4668 auto store = Store::make(
4669 a_buf,
4670 {i},
4671 IfThenElse::make(
4672 CompareSelect::make(i, 10, kLT),
4673 IfThenElse::make(
4674 CompareSelect::make(k, 5, kLT),
4675 Load::make(b_buf, {i}),
4676 Load::make(c_buf, {i - 5})),
4677 Load::make(d_buf, {i - 10})));
4678 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4679 auto forI = For::make(i, 0, 20, store);
4680 auto par = Block::make({forI});
4681 LoopNest nest(par, {a_buf.node()});
4682
4683 HashProvider hasher;
4684 auto hash_before = hasher.hash(nest.root_stmt());
4685 nest.optimizeConditionals();
4686 auto hash_after = hasher.hash(nest.root_stmt());
4687 ASSERT_EQ(hash_before, hash_after);
4688}
4689
4690TEST(LoopNest, OptimizeConditionalsInvalidCondition4) {
4691 // Input IR:
4692 // for (int i = 0; i < 20; i++) {
4693 // A[i] = IfThenElse(k<10, IfThenElse(k<5, B[i], C[i-5]), D[i-10])
4694 // }
4695 // No optimization should be done here because the conditions use the
4696 // variable 'k' which is not a loop variable.
4697
4698 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4699 BufHandle a_buf("A", {20}, kInt);
4700 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4701 BufHandle b_buf("B", {5}, kInt);
4702 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4703 BufHandle c_buf("C", {5}, kInt);
4704 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4705 BufHandle d_buf("D", {10}, kInt);
4706 VarHandle i("i", kInt);
4707 VarHandle k("k", kInt);
4708 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4709 auto store = Store::make(
4710 a_buf,
4711 {i},
4712 IfThenElse::make(
4713 CompareSelect::make(k, 10, kLT),
4714 IfThenElse::make(
4715 CompareSelect::make(k, 5, kLT),
4716 Load::make(b_buf, {i}),
4717 Load::make(c_buf, {i - 5})),
4718 Load::make(d_buf, {i - 10})));
4719 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4720 auto forI = For::make(i, 0, 20, store);
4721 auto par = Block::make({forI});
4722 LoopNest nest(par, {a_buf.node()});
4723
4724 HashProvider hasher;
4725 auto hash_before = hasher.hash(nest.root_stmt());
4726 nest.optimizeConditionals();
4727 auto hash_after = hasher.hash(nest.root_stmt());
4728 ASSERT_EQ(hash_before, hash_after);
4729}
4730
4731TEST(LoopNest, OptimizeConditionalsNotNormalized) {
4732 // Input IR:
4733 // for (int i = 2; i < 20; i++) {
4734 // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
4735 //   }
  // No optimization should be done here because the loop is not normalized
  // (its start is 2, not 0).
4736
4737 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4738 BufHandle a_buf("A", {20}, kInt);
4739 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4740 BufHandle b_buf("B", {5}, kInt);
4741 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4742 BufHandle c_buf("C", {15}, kInt);
4743 VarHandle i("i", kInt);
4744 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4745 auto store = Store::make(
4746 a_buf,
4747 {i},
4748 IfThenElse::make(
4749 CompareSelect::make(i, 5, kLT),
4750 Load::make(b_buf, {i}),
4751 Load::make(c_buf, {i - 5})));
4752 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4753 auto forI = For::make(i, 2, 20, store);
4754 auto par = Block::make({forI});
4755 LoopNest nest(par, {a_buf.node()});
4756
4757 HashProvider hasher;
4758 auto hash_before = hasher.hash(nest.root_stmt());
4759 nest.optimizeConditionals();
4760 auto hash_after = hasher.hash(nest.root_stmt());
4761 ASSERT_EQ(hash_before, hash_after);
4762}
4763
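// Builds a column-wise reduction: b[n] = sum over m of a[m, n], where a has
// shape {M, N} and b has shape {N}.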
4764static std::pair<BufHandle, Tensor> colReduce(int M, int N) {
4765 BufHandle a("a", {M, N}, kFloat);
4766 Tensor t = Reduce(
4767 "b",
4768 {N},
4769 Sum(),
4770 [&](const VarHandle& n, const VarHandle& m) { return a.load(m, n); },
4771 {M});
4772 return {a, Tensor(t.buf(), LoopNest::sanitizeNames(t.stmt()))};
4773}
4774
4775static StmtPtr splitTailReorder(Tensor b) {
4776 constexpr int kVectorWidth = 8;
4777 LoopNest nest({b});
4778 auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0];
4779 nest.splitWithTail(loops[0], kVectorWidth);
4780 // Now the loopnests will look like:
4781 //
4782 // for (int i_outer = 0; ...
4783 // for (int i_inner = 0; ...
4784 // b[i_outer * 8 + i_inner] = float(0);
4785 // for (int j = 0; ...
4786 // b[i_outer * 8 + i_inner] = ReduceOp(...);
4787 //
4788 // for (int i_tail = 0; ...
4789 // b[i_tail + ((100 - 0) / 8) * 8] = float(0);
4790 // for (int j = 0; ...
4791 // b[i_tail + ((100 - 0) / 8) * 8] = ReduceOp(...);
4792 //
4793 // Since there are 4 writes to b, we will get 4 loopnests from the
4794 // call to `getAllLoopNestsWritingToBuf` below.
4795 //
4796 // Write #2: "b[i_outer * 8 + i_inner] = ReduceOp(...)"
4797 // Loopnest #2: {i_outer, i_inner, j};
4798 // We will have to reorder i_inner and j.
4799 auto loopnests = nest.getAllLoopNestsWritingToBuf(b.buf());
4800 LoopNest::reorderAxis(loopnests[1][1], loopnests[1][2]);
4801 nest.prepareForCodegen();
4802 return nest.root_stmt();
4803}
4804
4805static StmtPtr splitMaskReorder(Tensor b) {
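  // Split the outer loop of the reduction with a mask (splitWithMask does not
  // create a tail loop), then reorder the split inner loop with the reduction
  // axis.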
4806 constexpr int kVectorWidth = 8;
4807 LoopNest nest({b});
4808 auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1];
4809 nest.splitWithMask(loops[0], kVectorWidth);
4810 loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1];
4811 LoopNest::reorderAxis(loops[1], loops[2]);
4812 nest.prepareForCodegen();
4813 return nest.root_stmt();
4814}
4815
4816static void checkColReduce(StmtPtr s, BufHandle p, Tensor t) {
4817 int M = immediateAs<int>(p.dim(0));
4818 int N = immediateAs<int>(p.dim(1));
4819 PaddedBuffer<float> a(M, N);
4820 PaddedBuffer<float> b(N);
4821 PaddedBuffer<float> ref(N);
4822 for (int i = 0; i < M; i++) {
4823 for (int j = 0; j < N; j++) {
4824 a(i, j) = 1.0f;
4825 }
4826 }
4827 for (int i = 0; i < N; i++) {
4828 b(i) = 0.0f;
4829 }
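  // All callers use M == 76 and fill `a` with ones, so every column sums to 76.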
4830 for (int i = 0; i < N; i++) {
4831 ref(i) = 76.0f;
4832 }
4833 SimpleIREvaluator(s, {p, t}).call({a, b});
4834 ExpectAllNear(b, ref, 1e-5);
4835}
4836
4837TEST(LoopNest, ColReduceSplitTailEvenReorder) {
4838 constexpr int M = 76, N = 128;
4839 auto p = colReduce(M, N);
4840 StmtPtr s = splitTailReorder(p.second);
4841
4842 std::ostringstream oss;
4843 oss << *s;
4844 const std::string& verification_pattern =
4845 R"IR(
4846# CHECK: for (int i_outer
4847# CHECK-NEXT: for (int i_inner
4848# CHECK-NEXT: b[
4849# CHECK: for (int j
4850# CHECK-NEXT: for (int i_inner
4851# CHECK-NEXT: b[
4852# CHECK-NOT: for (
4853 )IR";
4854 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
4855
4856 checkColReduce(s, p.first, p.second);
4857}
4858
4859TEST(LoopNest, ColReduceSplitTailUnevenReorder) {
4860 constexpr int M = 76, N = 100;
4861 auto p = colReduce(M, N);
4862 StmtPtr s = splitTailReorder(p.second);
4863
4864 std::ostringstream oss;
4865 oss << *s;
4866 const std::string& verification_pattern =
4867 R"IR(
4868# CHECK: for (int i_outer
4869# CHECK-NEXT: for (int i_inner
4870# CHECK-NEXT: b[
4871# CHECK: for (int j
4872# CHECK-NEXT: for (int i_inner
4873# CHECK-NEXT: b[
4874# CHECK: for (int i_tail
4875# CHECK-NEXT: b[
4876# CHECK-NEXT: for (int j
4877# CHECK-NEXT: b[
4878 )IR";
4879 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
4880
4881 checkColReduce(s, p.first, p.second);
4882}
4883
4884TEST(LoopNest, ColReduceSplitMaskEvenReorder) {
4885 constexpr int M = 76, N = 128;
4886 auto p = colReduce(M, N);
4887 StmtPtr s = splitMaskReorder(p.second);
4888 checkColReduce(s, p.first, p.second);
4889}
4890
4891TEST(LoopNest, ColReduceSplitMaskUnevenReorder) {
4892 constexpr int M = 76, N = 100;
4893 auto p = colReduce(M, N);
4894 StmtPtr s = splitMaskReorder(p.second);
4895 checkColReduce(s, p.first, p.second);
4896}
4897
4898TEST(LoopNest, ReorderAxisWithMultipleConds) {
4899 // Input IR:
4900 // for (int i = 0; i < 20; i++) {
4901 // if i > 5 {
4902 // if i < 10 {
4903 // for (int j = 0; j < 100; j++) {
4904 // A[i] = i * j;
4905 // }
4906 // }
4907 // }
4908 // }
4909 BufHandle a_buf("A", {20}, kInt);
4910 VarHandle i("i", kInt);
4911 VarHandle j("j", kInt);
4912 auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i}, Mul::make(i, j)));
4913 auto inner_cond = Cond::make(CompareSelect::make(i, 10, kLT), forJ, nullptr);
4914 auto outer_cond =
4915 Cond::make(CompareSelect::make(i, 5, kGT), inner_cond, nullptr);
4916 auto forI = For::make(i, 0, 20, outer_cond);
4917 StmtPtr par = Block::make({forI});
4918 LoopNest l(par, {a_buf.node()});
4919 LoopNest::reorderAxis(forI, forJ);
4920 ASSERT_EQ(par, l.root_stmt());
4921 par = IRSimplifier::simplify(par);
4922
4923 const std::string& verification_pattern =
4924 R"IR(
4925# CHECK: for (int j
4926# CHECK-NEXT: for (int i
4927# CHECK-NEXT: if (i>5
4928# CHECK-NEXT: if (i<10
4929# CHECK-NEXT: A[i] = i * j
4930# CHECK-NOT: for (
4931 )IR";
4932 std::ostringstream oss;
4933 oss << *par;
4934 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
4935}
4936
4937TEST(LoopNest, VectorizeUse) {
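  // Vectorizes the loops that produce b and c; the resulting stores should use
  // Ramp indices.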
4938 constexpr int N = 8;
4939 BufHandle a("a", {N}, kFloat);
4940 Tensor b =
4941 Compute("b", {N}, [&](const VarHandle& n) { return a.load(n) + 1.0f; });
4942 Tensor c =
4943 Compute("c", {N}, [&](const VarHandle& n) { return b.load(n) + 2.0f; });
4944 LoopNest nest({c}, {b, c});
4945 auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0];
4946 ASSERT_TRUE(LoopNest::vectorize(loops[0]));
4947 loops = nest.getAllLoopNestsWritingToBuf(c.buf())[0];
4948 ASSERT_TRUE(LoopNest::vectorize(loops[0]));
4949 nest.prepareForCodegen();
4950 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
4951 StmtPtr s = nest.root_stmt();
4952 std::ostringstream oss;
4953 oss << *nest.root_stmt();
4954 torch::jit::testing::FileCheck().run(
4955 R"IR(
4956# CHECK: c[Ramp
4957)IR",
4958 oss.str());
4959}
4960
4961const char* int64Loop = R"IR(
4962# CHECK: for (int64_t i = 0ll; i < 12ll; i++) {
4963# CHECK: b[i] = (a[i]) + 1ll;
4964# CHECK: }
4965)IR";
4966
4967TEST(LoopNest, Int64Direct) {
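  // A loop built directly from int64 expressions should print with an int64_t
  // induction variable and `ll` literal suffixes (see int64Loop above).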
4968 constexpr int64_t N = 12;
4969 BufHandle a("a", {N}, kLong);
4970 BufHandle b("b", {N}, kLong);
4971 VarHandle n("i", kLong);
4972 StmtPtr s = For::make(
4973 n, LongImm::make(0l), N, b.store({n}, a.load({n}) + LongImm::make(1l)));
4974 s = IRSimplifier::simplify(s);
4975 std::ostringstream oss;
4976 oss << *s;
4977 torch::jit::testing::FileCheck().run(int64Loop, oss.str());
4978}
4979
4980TEST(LoopNest, Int64Compute) {
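  // Same as Int64Direct, but the int64 loop is constructed via Compute.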
4981 constexpr int64_t N = 12;
4982 BufHandle a("a", {N}, kLong);
4983 Tensor b = Compute("b", {N}, [&](const VarHandle& n) {
4984 return a.load(n) + LongImm::make(1l);
4985 });
4986 LoopNest nest({b});
4987 nest.prepareForCodegen();
4988 nest.simplify();
4989 std::ostringstream oss;
4990 oss << *nest.root_stmt();
4991 torch::jit::testing::FileCheck().run(int64Loop, oss.str());
4992}
4993
4994TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) {
4995 // Input IR:
4996 // for (int i = 0; i < 20; i++) {
4997 // A[i] = 0;
4998 // for (int j = 0; j < 100; j++) {
4999 // A[i] = A[i] + i * j;
5000 // }
5001 // B[i] = A[i];
5002 // for (int k = 0; k < 50; k++) {
5003 // B[i] = B[i] + i * k;
5004 // }
5005 // }
5006 BufHandle a_buf("A", {20}, kInt);
5007 BufHandle b_buf("B", {20}, kInt);
5008 VarHandle i("i", kInt);
5009 VarHandle j("j", kInt);
5010 VarHandle k("k", kInt);
5011 auto initA = Store::make(a_buf, {i}, 0);
5012 auto forJ = For::make(
5013 j,
5014 0,
5015 100,
5016 Store::make(
5017 a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j))));
5018 auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
5019 auto forK = For::make(
5020 k,
5021 0,
5022 50,
5023 Store::make(
5024 b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k))));
5025 auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK}));
5026 auto par = Block::make({forI});
5027
5028 const std::string& verification_pattern =
5029 R"IR(
5030# CHECK: for (int i
5031# CHECK-NEXT: A[i] = 0
5032# CHECK: for (int i
5033# CHECK-NEXT: for (int j
5034# CHECK-NEXT: A[i] =
5035# CHECK: for (int i
5036# CHECK-NEXT: B[i] = A[i]
5037# CHECK: for (int i
5038# CHECK-NEXT: for (int k
5039# CHECK-NEXT: B[i] =
5040# CHECK-NOT: for (
5041 )IR";
5042
5043 LoopNest nest(par, {a_buf.node(), b_buf.node()});
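  // Distribute the i loop over the given pivot statements: the loop body is cut
  // after each pivot, so this produces four separate i loops (see the pattern
  // above).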
5044 auto new_loops = LoopNest::distributeLoop(forI, {initA, forJ, initB});
5045
5046 std::ostringstream oss;
5047 oss << *par;
5048 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5049
5050 // The first loop after distribution must be same as the original For.
5051 ASSERT_EQ(new_loops.front(), forI);
5052}
5053
5054TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) {
5055 // Input IR:
5056 // for (int i = 0; i < 20; i++) {
5057 // A[i] = 0;
5058 // for (int j = 0; j < 100; j++) {
5059 // A[i] = A[i] + i * j;
5060 // }
5061 // B[i] = A[i];
5062 // for (int k = 0; k < 50; k++) {
5063 // B[i] = B[i] + i * k;
5064 // }
5065 // }
5066 BufHandle a_buf("A", {20}, kInt);
5067 BufHandle b_buf("B", {20}, kInt);
5068 VarHandle i("i", kInt);
5069 VarHandle j("j", kInt);
5070 VarHandle k("k", kInt);
5071 auto initA = Store::make(a_buf, {i}, 0);
5072 auto forJ = For::make(
5073 j,
5074 0,
5075 100,
5076 Store::make(
5077 a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j))));
5078 auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
5079 auto forK = For::make(
5080 k,
5081 0,
5082 50,
5083 Store::make(
5084 b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k))));
5085 auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK}));
5086 auto par = Block::make({forI});
5087
5088 LoopNest nest(par, {a_buf.node(), b_buf.node()});
5089 auto new_loops = LoopNest::distributeLoop(forI, {forJ});
5090
5091 std::ostringstream oss;
5092 oss << *par;
5093 const std::string& verification_pattern =
5094 R"IR(
5095# CHECK: for (int i
5096# CHECK-NEXT: A[i] = 0
5097# CHECK-NEXT: for (int j
5098# CHECK-NEXT: A[i] =
5099# CHECK: for (int i
5100# CHECK-NEXT: B[i] = A[i]
5101# CHECK-NEXT: for (int k
5102# CHECK-NEXT: B[i] =
5103# CHECK-NOT: for (
5104 )IR";
5105 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5106
5107 // The first loop after distribution must be same as the original For.
5108 ASSERT_EQ(new_loops.front(), forI);
5109}
5110
5111TEST(LoopNest, DistributeLoopWithoutAnyPivot) {
5112 // Input IR:
5113 // for (int i = 0; i < 20; i++) {
5114 // A[i] = 0;
5115 // for (int j = 0; j < 100; j++) {
5116 // A[i] = A[i] + i * j;
5117 // }
5118 // B[i] = A[i];
5119 // for (int k = 0; k < 50; k++) {
5120 // B[i] = B[i] + i * k;
5121 // }
5122 // }
5123 BufHandle a_buf("A", {20}, kInt);
5124 BufHandle b_buf("B", {20}, kInt);
5125 VarHandle i("i", kInt);
5126 VarHandle j("j", kInt);
5127 VarHandle k("k", kInt);
5128 auto initA = Store::make(a_buf, {i}, 0);
5129 auto forJ = For::make(
5130 j,
5131 0,
5132 100,
5133 Store::make(
5134 a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j))));
5135 auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
5136 auto forK = For::make(
5137 k,
5138 0,
5139 50,
5140 Store::make(
5141 b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k))));
5142 auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK}));
5143 auto par = Block::make({forI});
5144
5145 const std::string& verification_pattern =
5146 R"IR(
5147# CHECK: for (int i
5148# CHECK-NEXT: A[i] = 0
5149# CHECK: for (int i
5150# CHECK-NEXT: for (int j
5151# CHECK-NEXT: A[i] =
5152# CHECK: for (int i
5153# CHECK-NEXT: B[i] = A[i]
5154# CHECK: for (int i
5155# CHECK-NEXT: for (int k
5156# CHECK-NEXT: B[i] =
5157# CHECK-NOT: for (
5158 )IR";
5159
5160 LoopNest nest(par, {a_buf.node(), b_buf.node()});
5161 auto new_loops = LoopNest::distributeLoop(forI);
5162
5163 std::ostringstream oss;
5164 oss << *par;
5165 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5166
5167 // The first loop after distribution must be same as the original For.
5168 ASSERT_EQ(new_loops.front(), forI);
5169}
5170
5171TEST(LoopNest, DistributeLoopOverInnerLoops) {
5172 // Input IR:
5173 // for (int i = 0; i < 20; i++) {
5174 // A[i] = 0;
5175 // for (int j = 0; j < 100; j++) {
5176 // A[i] = A[i] + i * j;
5177 // }
5178 // B[i] = A[i];
5179 // for (int k = 0; k < 50; k++) {
5180 // B[i] = B[i] + i * k;
5181 // }
5182 // }
5183 BufHandle a_buf("A", {20}, kInt);
5184 BufHandle b_buf("B", {20}, kInt);
5185 VarHandle i("i", kInt);
5186 VarHandle j("j", kInt);
5187 VarHandle k("k", kInt);
5188 auto initA = Store::make(a_buf, {i}, 0);
5189 auto forJ = For::make(
5190 j,
5191 0,
5192 100,
5193 Store::make(
5194 a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j))));
5195 auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i}));
5196 auto forK = For::make(
5197 k,
5198 0,
5199 50,
5200 Store::make(
5201 b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k))));
5202 auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK}));
5203 auto par = Block::make({forI});
5204
5205 LoopNest nest(par, {a_buf.node(), b_buf.node()});
5206 auto new_loops = LoopNest::distributeLoopOverInnerLoops(forI);
5207
5208 std::ostringstream oss;
5209 oss << *par;
5210 const std::string& verification_pattern =
5211 R"IR(
5212# CHECK: for (int i
5213# CHECK-NEXT: A[i] = 0
5214# CHECK-NEXT: for (int j
5215# CHECK-NEXT: A[i] =
5216# CHECK: for (int i
5217# CHECK-NEXT: B[i] = A[i]
5218# CHECK-NEXT: for (int k
5219# CHECK-NEXT: B[i] =
5220# CHECK-NOT: for (
5221 )IR";
5222 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5223
5224 // The first loop after distribution must be same as the original For.
5225 ASSERT_EQ(new_loops.front(), forI);
5226}
5227
5228TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) {
5229 // Input IR:
5230 // for (int m = 0; m < 50; m++) {
5231 // for (int i = 0; i < 20; i++) {
5232 // A[m,i] = 0;
5233 // for (int j = 0; j < 100; j++) {
5234 // A[m,i] = A[m,i] + i * j;
5235 // }
5236 // B[m,i] = A[m,i];
5237 // for (int k = 0; k < 50; k++) {
5238 // B[m,i] = B[m,i] + i * k;
5239 // }
5240 // }
5241 // }
5242 BufHandle a_buf("A", {100, 100}, kInt);
5243 BufHandle b_buf("B", {100, 100}, kInt);
5244 VarHandle m("m", kInt);
5245 VarHandle i("i", kInt);
5246 VarHandle j("j", kInt);
5247 VarHandle k("k", kInt);
5248 auto initA = Store::make(a_buf, {m, i}, 0);
5249 auto forJ = For::make(
5250 j,
5251 0,
5252 100,
5253 Store::make(
5254 a_buf,
5255 {m, i},
5256 Add::make(Load::make(a_buf, {m, i}), Mul::make(i, j))));
5257 auto initB = Store::make(b_buf, {m, i}, Load::make(a_buf, {m, i}));
5258 auto forK = For::make(
5259 k,
5260 0,
5261 50,
5262 Store::make(
5263 b_buf,
5264 {m, i},
5265 Add::make(Load::make(b_buf, {m, i}), Mul::make(i, k))));
5266 auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK}));
5267
5268 {
5269 // Check the case of distributing loop and its parents over all the
5270 // statements in the loop.
5271 const std::string& verification_pattern =
5272 R"IR(
5273# CHECK: for (int m
5274# CHECK-NEXT: for (int i
5275# CHECK-NEXT: A[m, i] = 0
5276# CHECK: for (int m
5277# CHECK-NEXT: for (int i
5278# CHECK-NEXT: for (int j
5279# CHECK-NEXT: A[m, i] =
5280# CHECK: for (int m
5281# CHECK-NEXT: for (int i
5282# CHECK-NEXT: B[m, i] = A[m, i]
5283# CHECK: for (int m
5284# CHECK-NEXT: for (int i
5285# CHECK-NEXT: for (int k
5286# CHECK-NEXT: B[m, i] =
5287# CHECK-NOT: for (
5288 )IR";
5289
5290 auto newForI = to<For>(Stmt::clone(forI));
5291 auto forM = For::make(m, 0, 50, newForI);
5292 auto par = Block::make({forM});
5293 LoopNest nest(par, {a_buf.node(), b_buf.node()});
5294 auto newLoops = LoopNest::distributeLoopAndParents(newForI);
5295
5296 std::ostringstream oss;
5297 oss << *par;
5298 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5299
5300 // The first loop after distribution must be same as the original For.
5301 ASSERT_EQ(newLoops.front(), forM);
5302 }
5303
5304 {
5305 // Check the case of distributing loop and its parents over all the inner
5306 // loops.
5307 const std::string& verification_pattern =
5308 R"IR(
5309# CHECK: for (int m
5310# CHECK-NEXT: for (int i
5311# CHECK-NEXT: A[m, i] = 0
5312# CHECK-NEXT: for (int j
5313# CHECK-NEXT: A[m, i] =
5314# CHECK: for (int m
5315# CHECK-NEXT: for (int i
5316# CHECK-NEXT: B[m, i] = A[m, i]
5317# CHECK-NEXT: for (int k
5318# CHECK-NEXT: B[m, i] =
5319# CHECK-NOT: for (
5320 )IR";
5321
5322 auto newForI = to<For>(Stmt::clone(forI));
5323 auto forM = For::make(m, 0, 50, newForI);
5324 auto par = Block::make({forM});
5325 LoopNest nest(par, {a_buf.node(), b_buf.node()});
5326 auto newLoops = LoopNest::distributeLoopAndParentsOverInnerLoops(newForI);
5327
5328 std::ostringstream oss;
5329 oss << *par;
5330 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5331
5332 // The first loop after distribution must be same as the original For.
5333 ASSERT_EQ(newLoops.front(), forM);
5334 }
5335}
5336
5337TEST(LoopNest, fuseLoopsSimple) {
5338 // Input IR:
5339 // for (int j = 0; j < 100; j++) {
5340 // A[j] = 10 * j;
5341 // }
5342 // for (int k = 0; k < 100; k++) {
5343 // B[k] = 20 * k;
5344 // }
5345 BufHandle a_buf("A", {100}, kInt);
5346 BufHandle b_buf("B", {100}, kInt);
5347 VarHandle j("j", kInt);
5348 VarHandle k("k", kInt);
5349 auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
5350 auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
5351 auto par = Block::make({forJ, forK});
5352 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5353 ForPtr fused_loop;
5354 ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
5355
5356 std::ostringstream oss;
5357 oss << *par;
5358 const std::string& verification_pattern =
5359 R"IR(
5360# CHECK: for (int j
5361# CHECK-NEXT: A[j] =
5362# CHECK-NEXT: B[j] =
5363# CHECK-NOT: for (
5364 )IR";
5365 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5366
5367 // The fused loop must be the same as the first loop.
5368 ASSERT_EQ(fused_loop, forJ);
5369}
5370
5371TEST(LoopNest, fuseLoopsMultiple) {
5372 // Input IR:
5373 // for (int i = 0; i < 100; i++) {
5374 // A[i+100] = 20 + i;
5375 // }
5376 // for (int j = 0; j < 100; j++) {
5377 // A[j] = 10 * j;
5378 // }
5379 // for (int k = 0; k < 100; k++) {
5380 // B[k] = 20 * k;
5381 // }
5382 BufHandle a_buf("A", {200}, kInt);
5383 BufHandle b_buf("B", {100}, kInt);
5384 VarHandle i("i", kInt);
5385 VarHandle j("j", kInt);
5386 VarHandle k("k", kInt);
5387 auto forI =
5388 For::make(i, 0, 100, Store::make(a_buf, {i + 100}, Add::make(20, i)));
5389 auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
5390 auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
5391 auto par = Block::make({forI, forJ, forK});
5392 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5393 ForPtr fused_loop;
5394 ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop));
5395
5396 std::ostringstream oss;
5397 oss << *par;
5398 const std::string& verification_pattern =
5399 R"IR(
5400# CHECK: for (int i
5401# CHECK-NEXT: A[i + 100] =
5402# CHECK-NEXT: A[i] =
5403# CHECK-NEXT: B[i] =
5404# CHECK-NOT: for (
5405 )IR";
5406 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5407
5408 // The fused loop must be the same as the first loop.
5409 ASSERT_EQ(fused_loop, forI);
5410}
5411
5412TEST(LoopNest, fuseLoopsNested) {
5413 // Input IR:
5414 // for (int m = 0; m < 20; m++) {
5415 // A[m] = 0;
5416 // for (int j = 0; j < 100; j++) {
5417 // A[m] = A[m] + m * j;
5418 // }
5419 // }
5420 // for (int n = 0; n < 20; n++) {
5421 // B[n] = A[n];
5422 // for (int k = 0; k < 50; k++) {
5423 // B[n] = B[n] + n * k;
5424 // }
5425 // }
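  //
  // Only the outer loops (m and n) are fused; the inner j and k loops are
  // expected to remain intact inside the fused loop.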
5426 BufHandle a_buf("A", {20, 100}, kInt);
5427 BufHandle b_buf("B", {20, 100}, kInt);
5428 VarHandle m("m", kInt);
5429 VarHandle n("n", kInt);
5430 VarHandle j("j", kInt);
5431 VarHandle k("k", kInt);
5432 auto initA = Store::make(a_buf, {m}, 0);
5433 auto forJ = For::make(
5434 j,
5435 0,
5436 100,
5437 Store::make(
5438 a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j))));
5439 auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n}));
5440 auto forK = For::make(
5441 k,
5442 0,
5443 50,
5444 Store::make(
5445 b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k))));
5446 auto forM = For::make(m, 0, 20, Block::make({initA, forJ}));
5447 auto forN = For::make(n, 0, 20, Block::make({initB, forK}));
5448 auto par = Block::make({forM, forN});
5449 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5450 ForPtr fused_loop;
5451 ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop));
5452
5453 std::ostringstream oss;
5454 oss << *par;
5455 const std::string& verification_pattern =
5456 R"IR(
5457# CHECK: for (int m
5458# CHECK-NEXT: A[m] = 0
5459# CHECK-NEXT: for (int j
5460# CHECK-NEXT: A[m] =
5461# CHECK: B[m] = A[m]
5462# CHECK-NEXT: for (int k
5463# CHECK-NEXT: B[m] =
5464# CHECK-NOT: for (
5465 )IR";
5466 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5467
5468 // The fused loop must be the same as the first loop.
5469 ASSERT_EQ(fused_loop, forM);
5470}
5471
5472TEST(LoopNest, fuseLoopsNested2D) {
5473 // Input IR:
5474 // for (int i = 0; i < 20; i++) {
5475 // for (int j = 0; j < 100; j++) {
5476 // A[i,j] = i * j * 500;
5477 // }
5478 // }
5479 // for (int m = 0; m < 20; m++) {
5480 // for (int n = 0; n < 50; n++) {
5481 // B[m,n] = m + n * 100;
5482 // }
5483 // }
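  //
  // Only the outer i and m loops should be fused here; the inner loops have
  // different extents (100 vs 50) and therefore remain separate.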
5484 BufHandle a_buf("A", {20, 100}, kInt);
5485 BufHandle b_buf("B", {20, 100}, kInt);
5486 VarHandle i("i", kInt);
5487 VarHandle j("j", kInt);
5488 VarHandle m("m", kInt);
5489 VarHandle n("n", kInt);
5490 auto forI = For::make(
5491 i,
5492 0,
5493 20,
5494 For::make(
5495 j,
5496 0,
5497 100,
5498 Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))));
5499 auto forM = For::make(
5500 m,
5501 0,
5502 20,
5503 For::make(
5504 n,
5505 0,
5506 50,
5507 Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100)))));
5508 auto par = Block::make({forI, forM});
5509 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5510 ForPtr fused_loop;
5511 ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
5512
5513 std::ostringstream oss;
5514 oss << *par;
5515 const std::string& verification_pattern =
5516 R"IR(
5517# CHECK: for (int i
5518# CHECK-NEXT: for (int j
5519# CHECK-NEXT: A[i, j] =
5520# CHECK: for (int n
5521# CHECK-NEXT: B[i, n] =
5522# CHECK-NOT: for (
5523 )IR";
5524 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5525
5526 // The fused loop must be the same as the first loop.
5527 ASSERT_EQ(fused_loop, forI);
5528}
5529
5530TEST(LoopNest, fuseLoopsNested2DInner) {
5531 // Input IR:
5532 // for (int i = 0; i < 20; i++) {
5533 // for (int j = 0; j < 100; j++) {
5534 // A[i,j] = i * j * 500;
5535 // }
5536 // for (int n = 0; n < 100; n++) {
  //     B[i,n] = i + n * 100;
5538 // }
5539 // }
5540 BufHandle a_buf("A", {20, 100}, kInt);
5541 BufHandle b_buf("B", {20, 100}, kInt);
5542 VarHandle i("i", kInt);
5543 VarHandle j("j", kInt);
5544 VarHandle n("n", kInt);
5545 auto forJ = For::make(
5546 j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)));
5547 auto forN = For::make(
5548 n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100))));
5549 auto forI = For::make(i, 0, 20, Block::make({forJ, forN}));
5550 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5551 ForPtr fused_loop;
5552 ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop));
5553
5554 std::ostringstream oss;
5555 oss << *forI;
5556 const std::string& verification_pattern =
5557 R"IR(
5558# CHECK: for (int i
5559# CHECK-NEXT: for (int j
5560# CHECK-NEXT: A[i, j] =
5561# CHECK-NEXT: B[i, j] =
5562# CHECK-NOT: for (
5563 )IR";
5564 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5565
5566 // The fused loop must be the same as the first loop.
5567 ASSERT_EQ(fused_loop, forJ);
5568}
5569
5570TEST(LoopNest, fuseLoopsDifferentStopBounds) {
5571 // Input IR:
5572 // for (int j = 0; j < 100; j++) {
5573 // A[j] = 10 * j;
5574 // }
5575 // for (int k = 0; k < 50; k++) {
5576 // B[k] = 20 * k;
5577 // }
5578 BufHandle a_buf("A", {100}, kInt);
5579 BufHandle b_buf("B", {100}, kInt);
5580 VarHandle j("j", kInt);
5581 VarHandle k("k", kInt);
5582 auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK = For::make(k, 0, 50, Store::make(b_buf, {k}, Mul::make(20, k)));
5584 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
5585 auto par = Block::make({forJ, forK});
5586 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5587 ForPtr fused_loop;
5588 ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
5589}
5590
5591TEST(LoopNest, fuseLoopsDifferentStartBounds) {
5592 // Input IR:
5593 // for (int j = 0; j < 100; j++) {
5594 // A[j] = 10 * j;
5595 // }
5596 // for (int k = 50; k < 100; k++) {
5597 // B[k] = 20 * k;
5598 // }
5599 BufHandle a_buf("A", {100}, kInt);
5600 BufHandle b_buf("B", {100}, kInt);
5601 VarHandle j("j", kInt);
5602 VarHandle k("k", kInt);
5603 auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK = For::make(k, 50, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
5605 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
5606 auto par = Block::make({forJ, forK});
5607 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5608 ForPtr fused_loop;
5609 ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
5610}
5611
5612TEST(LoopNest, fuseLoopsNotContiguous) {
5613 // Input IR:
5614 // for (int j = 0; j < 100; j++) {
5615 // A[j] = 10 * j;
5616 // }
5617 // B[0] = 0;
5618 // for (int k = 0; k < 100; k++) {
5619 // B[k] = 20 * k;
5620 // }
5621 BufHandle a_buf("A", {100}, kInt);
5622 BufHandle b_buf("B", {100}, kInt);
5623 VarHandle j("j", kInt);
5624 VarHandle k("k", kInt);
5625 auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
5626 auto initB = Store::make(b_buf, {0}, 0);
  auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
5628 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
5629 auto par = Block::make({forJ, initB, forK});
5630 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5631 ForPtr fused_loop;
5632 ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
5633}
5634
5635TEST(LoopNest, fuseLoopsWithDifferentParents) {
5636 // Input IR:
5637 // for (int i = 0; i < 50; i++) {
5638 // for (int j = 0; j < 100; j++) {
5639 // A[i,j] = i * j;
5640 // }
5641 // }
5642 // B[0] = 0;
5643 // for (int k = 50; k < 100; k++) {
5644 // B[k] = 20 * k;
5645 // }
5646 BufHandle a_buf("A", {50, 100}, kInt);
5647 BufHandle b_buf("B", {100}, kInt);
5648 VarHandle i("i", kInt);
5649 VarHandle j("j", kInt);
5650 VarHandle k("k", kInt);
5651 auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(i, j)));
5652 auto forI = For::make(i, 0, 50, forJ);
5653 auto initB = Store::make(b_buf, {0}, 0);
  auto forK = For::make(k, 50, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
5655 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
5656 auto par = Block::make({forI, initB, forK});
5657 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5658 ForPtr fused_loop;
5659 ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
5660}
5661
5662TEST(LoopNest, fuseLoopsWithVariableBounds) {
5663 // Input IR:
5664 // for (int j = 0; j < N; j++) {
5665 // A[j] = 10 * j;
5666 // }
5667 // for (int k = 0; k < N; k++) {
5668 // B[k] = 20 * k;
5669 // }
5670 BufHandle a_buf("A", {20}, kInt);
5671 BufHandle b_buf("B", {20}, kInt);
5672 VarHandle j("j", kInt);
5673 VarHandle k("k", kInt);
5674 VarHandle N("N", kInt);
5675 auto forJ = For::make(j, 0, N, Store::make(a_buf, {j}, Mul::make(10, j)));
5676 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
  auto forK = For::make(k, 0, N, Store::make(b_buf, {k}, Mul::make(20, k)));
5678 auto par = Block::make({forJ, forK});
5679 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5680 ForPtr fused_loop;
5681 ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
5682
5683 std::ostringstream oss;
5684 oss << *par;
5685 const std::string& verification_pattern =
5686 R"IR(
5687# CHECK: for (int j
5688# CHECK-NEXT: A[j] =
5689# CHECK-NEXT: B[j] =
5690# CHECK-NOT: for (
5691 )IR";
5692 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5693
5694 // The fused loop must be the same as the first loop.
5695 ASSERT_EQ(fused_loop, forJ);
5696}
5697
5698TEST(LoopNest, fuseLoopsWithExprBounds) {
5699 // Input IR:
5700 // for (int j = 0; j < M + N; j++) {
5701 // A[j] = 10 * j;
5702 // }
5703 // for (int k = 0; k < M + N; k++) {
5704 // B[k] = 20 * k;
5705 // }
5706 BufHandle a_buf("A", {20}, kInt);
5707 BufHandle b_buf("B", {20}, kInt);
5708 VarHandle j("j", kInt);
5709 VarHandle k("k", kInt);
5710 VarHandle M("M", kInt);
5711 VarHandle N("N", kInt);
5712 auto forJ = For::make(j, 0, M + N, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK = For::make(k, 0, M + N, Store::make(b_buf, {k}, Mul::make(20, k)));
5714 auto par = Block::make({forJ, forK});
5715 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5716 ForPtr fused_loop;
5717 ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
5718
5719 std::ostringstream oss;
5720 oss << *par;
5721 const std::string& verification_pattern =
5722 R"IR(
5723# CHECK: for (int j
5724# CHECK-NEXT: A[j] =
5725# CHECK-NEXT: B[j] =
5726# CHECK-NOT: for (
5727 )IR";
5728 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5729
5730 // The fused loop must be the same as the first loop.
5731 ASSERT_EQ(fused_loop, forJ);
5732}
5733
5734TEST(LoopNest, fuseLoopsWithDifferentExprBounds) {
5735 // Input IR:
5736 // for (int j = M; j < N * 2; j++) {
5737 // A[j] = 10 * j;
5738 // }
5739 // for (int k = M; k < N + N; k++) {
5740 // B[k] = 20 * k;
5741 // }
5742 BufHandle a_buf("A", {20}, kInt);
5743 BufHandle b_buf("B", {20}, kInt);
5744 VarHandle j("j", kInt);
5745 VarHandle k("k", kInt);
5746 VarHandle M("M", kInt);
5747 VarHandle N("N", kInt);
5748 auto forJ = For::make(j, M, N * 2, Store::make(a_buf, {j}, Mul::make(10, j)));
5749 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
  auto forK = For::make(k, M, N + N, Store::make(b_buf, {k}, Mul::make(20, k)));
5751 auto par = Block::make({forJ, forK});
5752 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5753 ForPtr fused_loop;
5754 ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
5755
5756 std::ostringstream oss;
5757 oss << *par;
5758 const std::string& verification_pattern =
5759 R"IR(
5760# CHECK: for (int j
5761# CHECK-NEXT: A[j] =
5762# CHECK-NEXT: B[j] =
5763# CHECK-NOT: for (
5764 )IR";
5765 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5766
5767 // The fused loop must be the same as the first loop.
5768 ASSERT_EQ(fused_loop, forJ);
5769}
5770
5771TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) {
5772 // Input IR:
5773 // for (int j = 10; j < 100; j++) {
5774 // A[j] = 10 * j;
5775 // }
5776 // for (int k = 10; k < 100; k++) {
5777 // A[k+100] = 30 * k
5778 // }
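  //
  // The two loops write disjoint regions of A (A[10..99] vs A[110..199]), so
  // there is no dependence between them and fusion should succeed.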
5779 BufHandle a_buf("A", {200}, kInt);
5780 VarHandle j("j", kInt);
5781 VarHandle k("k", kInt);
5782 auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
5783 auto forK =
5784 For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(30, k)));
5785 auto par = Block::make({forJ, forK});
5786
5787 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5788 ForPtr fused_loop;
5789 ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
5790
5791 std::ostringstream oss;
5792 oss << *par;
5793 const std::string& verification_pattern =
5794 R"IR(
5795# CHECK: for (int j
5796# CHECK-NEXT: A[j] =
5797# CHECK-NEXT: A[j + 100] =
5798# CHECK-NOT: for (
5799 )IR";
5800 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5801
5802 // The fused loop must be the same as the first loop.
5803 ASSERT_EQ(fused_loop, forJ);
5804}
5805
5806TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) {
5807 // Input IR:
5808 // for (int i = 0; i < 20; i++) {
5809 // for (int j = 0; j < 100; j++) {
5810 // A[i,j] = i * j * 500;
5811 // }
5812 // }
5813 // for (int m = 0; m < 20; m++) {
5814 // for (int n = 0; n < 50; n++) {
5815 // A[m+20,n+100] = m + n * 100;
5816 // }
5817 // }
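  //
  // As above, the two nests write disjoint regions of A, so fusing the outer
  // loops should be legal; the inner loops keep their own extents.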
5818 BufHandle a_buf("A", {20, 100}, kInt);
5819 BufHandle b_buf("B", {20, 50}, kInt);
5820 VarHandle i("i", kInt);
5821 VarHandle j("j", kInt);
5822 VarHandle m("m", kInt);
5823 VarHandle n("n", kInt);
5824 auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
5825 auto forJ = For::make(j, 0, 100, storeA1);
5826 auto forI = For::make(i, 0, 20, forJ);
5827 auto storeA2 =
5828 Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100)));
5829 auto forN = For::make(n, 0, 50, storeA2);
5830 auto forM = For::make(m, 0, 20, forN);
5831 auto par = Block::make({forI, forM});
5832
5833 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5834 ForPtr fused_loop;
5835 ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
5836
5837 std::ostringstream oss;
5838 oss << *par;
5839 const std::string& verification_pattern =
5840 R"IR(
5841# CHECK: for (int i
5842# CHECK-NEXT: for (int j
5843# CHECK-NEXT: A[i, j] =
5844# CHECK: for (int n
5845# CHECK-NEXT: A[i + 20, n + 100] =
5846# CHECK-NOT: for (
5847 )IR";
5848 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5849
5850 // The fused loop must be the same as the first loop.
5851 ASSERT_EQ(fused_loop, forI);
5852}
5853
5854TEST(LoopNest, fuseLoopsWithReductions) {
5855 // Input IR:
5856 // for (int i = 0; i < 20; i++) {
5857 // A[i] = 0
5858 // for (int j = 0; j < 100; j++) {
5859 // A[i] = A[i] + B[i,j];
5860 // }
5861 // }
5862 // for (int m = 0; m < 20; m++) {
5863 // C[m] = A[m];
5864 // }
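  //
  // Fusion should be legal: iteration m of the second loop only reads A[m],
  // which iteration i == m of the first loop has already fully reduced. The
  // reduction loop over j stays nested inside the fused loop.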
5865 BufHandle a_buf("A", {20}, kInt);
5866 BufHandle b_buf("B", {20, 100}, kInt);
5867 BufHandle c_buf("C", {20}, kInt);
5868 VarHandle i("i", kInt);
5869 VarHandle j("j", kInt);
5870 VarHandle m("m", kInt);
5871 auto initA = Store::make(a_buf, {i}, 0);
5872 auto sumA = Store::make(
5873 a_buf, {i}, Add::make(Load::make(a_buf, {i}), Load::make(b_buf, {i, j})));
5874 auto forJ = For::make(j, 0, 100, sumA);
5875 auto forI = For::make(i, 0, 20, Block::make({initA, forJ}));
5876 auto forM =
5877 For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m})));
5878 auto par = Block::make({forI, forM});
5879 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5880 ForPtr fused_loop;
5881 ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
5882
5883 std::ostringstream oss;
5884 oss << *par;
5885 const std::string& verification_pattern =
5886 R"IR(
5887# CHECK: for (int i
5888# CHECK-NEXT: A[i] =
5889# CHECK-NEXT: for (int j
5890# CHECK-NEXT: A[i] = (A[i]) +
5891# CHECK-NOT: for (
5892# CHECK: C[i] = A[i]
5893 )IR";
5894 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5895
5896 // The fused loop must be the same as the first loop.
5897 ASSERT_EQ(fused_loop, forI);
5898}
5899
5900TEST(LoopNest, fuseLoopsWith2DReductions) {
5901 // Input IR:
5902 // for (int i = 0; i < 20; i++) {
5903 // for (int j = 0; j < 50; j++) {
5904 // A[i,j] = 0
5905 // for (int k = 0; k < 100; k++) {
5906 // A[i,j] = A[i,j] + B[i,j,k];
5907 // }
5908 // }
5909 // }
5910 // for (int m = 0; m < 20; m++) {
5911 // for (int n = 0; n < 40; n++) {
5912 // C[m,n] = A[m,n];
5913 // }
5914 // }
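  //
  // Same idea with a 2D reduction: C[i, n] is copied only after row i of A
  // has been fully reduced within the same iteration of the fused loop, so
  // fusion should be safe.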
5915 BufHandle a_buf("A", {20, 50}, kInt);
5916 BufHandle b_buf("B", {20, 50, 100}, kInt);
5917 BufHandle c_buf("C", {20, 40}, kInt);
5918 VarHandle i("i", kInt);
5919 VarHandle j("j", kInt);
5920 VarHandle k("k", kInt);
5921 VarHandle m("m", kInt);
5922 VarHandle n("n", kInt);
5923 auto initA = Store::make(a_buf, {i, j}, 0);
5924 auto sumA = Store::make(
5925 a_buf,
5926 {i, j},
5927 Add::make(Load::make(a_buf, {i, j}), Load::make(b_buf, {i, j, k})));
5928 auto forK = For::make(k, 0, 100, sumA);
5929 auto forJ = For::make(j, 0, 50, Block::make({initA, forK}));
5930 auto forI = For::make(i, 0, 20, forJ);
5931 auto storeC = Store::make(c_buf, {m, n}, Load::make(a_buf, {m, n}));
5932 auto forM = For::make(m, 0, 20, For::make(n, 0, 40, storeC));
5933 auto par = Block::make({forI, forM});
5934
5935 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5936 ForPtr fused_loop;
5937 ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
5938
5939 std::ostringstream oss;
5940 oss << *par;
5941 const std::string& verification_pattern =
5942 R"IR(
5943# CHECK: for (int i
5944# CHECK-NEXT: for (int j
5945# CHECK-NEXT: A[i, j] =
5946# CHECK-NEXT: for (int k
5947# CHECK-NEXT: A[i, j] = (A[i, j]) +
5948# CHECK: for (int n
5949# CHECK-NEXT: C[i, n] = A[i, n]
5950# CHECK-NOT: for (
5951 )IR";
5952 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5953
5954 // The fused loop must be the same as the first loop.
5955 ASSERT_EQ(fused_loop, forI);
5956}
5957
5958TEST(LoopNest, fuseLoopsWithComplexIndices) {
5959 // Input IR:
5960 // for (int i = 0; i < 20; i++) {
5961 // for (int j = 0; j < 20; j++) {
5962 // A[i,j*20+j+2] = i + j;
5963 // }
5964 // }
5965 // for (int m = 0; m < 20; m++) {
5966 // for (int n = 0; n < 20; n++) {
5967 // B[m,n] = A[m,n*20+n+2];
5968 // }
5969 // }
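  //
  // The read index n*20+n+2 mirrors the write index j*20+j+2, so after fusing
  // i and m each iteration of the fused loop reads only elements it has
  // already written; fusion should succeed.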
5970 BufHandle a_buf("A", {20, 400}, kInt);
5971 BufHandle b_buf("B", {20, 400}, kInt);
5972 VarHandle i("i", kInt);
5973 VarHandle j("j", kInt);
5974 VarHandle m("m", kInt);
5975 VarHandle n("n", kInt);
5976 auto writeA = Store::make(a_buf, {i, j * 20 + j + 2}, i + j);
5977 auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA));
5978 auto storeB =
5979 Store::make(b_buf, {m, n}, Load::make(a_buf, {m, n * 20 + n + 2}));
5980 auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB));
5981 auto par = Block::make({forI, forM});
5982
5983 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
5984 ForPtr fused_loop;
5985 ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
5986
5987 std::ostringstream oss;
5988 oss << *par;
5989 const std::string& verification_pattern =
5990 R"IR(
5991# CHECK: for (int i
5992# CHECK-NEXT: for (int j
5993# CHECK-NEXT: A[i, (j * 20 + j) + 2] = i + j
5994# CHECK: for (int n
5995# CHECK-NEXT: B[i, n] = A[i, (n * 20 + n) + 2]
5996# CHECK-NOT: for (
5997 )IR";
5998 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
5999
6000 // The fused loop must be the same as the first loop.
6001 ASSERT_EQ(fused_loop, forI);
6002}
6003
6004TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) {
6005 // Input IR:
6006 // for (int i = 0; i < 20; i++) {
6007 // for (int j = 0; j < 20; j++) {
6008 // A[i,i*20+j] = i + j;
6009 // }
6010 // }
6011 // for (int m = 0; m < 20; m++) {
6012 // for (int n = 0; n < 20; n++) {
6013 // B[m,n] = A[m,m*20+n]; // Both indices of A use m
6014 // }
6015 // }
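  //
  // Because the outer loop variable appears in both index dimensions of A,
  // the dependence analysis cannot prove the accesses safe, so fusion is
  // expected to be rejected.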
6016 BufHandle a_buf("A", {20, 500}, kInt);
6017 BufHandle b_buf("B", {20, 500}, kInt);
6018 VarHandle i("i", kInt);
6019 VarHandle j("j", kInt);
6020 VarHandle m("m", kInt);
6021 VarHandle n("n", kInt);
6022 auto writeA = Store::make(a_buf, {i, i * 20 + j}, i + j);
6023 auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA));
6024 auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {m, m * 20 + n}));
6025 auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB));
6026 auto par = Block::make({forI, forM});
6027
6028 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
6029 ForPtr fused_loop;
6030 ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
6031}
6032
6033TEST(LoopNest, fuseLoopsWithTranspose) {
6034 // Input IR:
6035 // for (int i = 0; i < 20; i++) {
6036 // for (int j = 0; j < 20; j++) {
6037 // A[i,j] = i + j;
6038 // }
6039 // }
6040 // for (int m = 0; m < 20; m++) {
6041 // for (int n = 0; n < 20; n++) {
6042 // B[m,n] = A[n,m]; // Transpose
6043 // }
6044 // }
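  //
  // Due to the transposed read A[n, m], iteration (m, n) of the second nest
  // depends on iteration i == n of the first nest. Fusing the outer loops
  // would read rows of A that have not been written yet, so fusion must fail.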
6045 BufHandle a_buf("A", {20, 20}, kInt);
6046 BufHandle b_buf("B", {20, 20}, kInt);
6047 VarHandle i("i", kInt);
6048 VarHandle j("j", kInt);
6049 VarHandle m("m", kInt);
6050 VarHandle n("n", kInt);
6051 auto writeA = Store::make(a_buf, {i, j}, i + j);
6052 auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA));
6053 auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {n, m}));
6054 auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB));
6055 auto par = Block::make({forI, forM});
6056
6057 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
6058 ForPtr fused_loop;
6059 ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
6060}
6061
6062TEST(LoopNest, fuseLoopsThatViolateDependencies1) {
6063 // Input IR:
6064 // for (int j = 10; j < 100; j++) {
6065 // A[j] = 10 * j;
6066 // }
6067 // for (int k = 10; k < 100; k++) {
6068 // A[k-1] = 20 * k;
6069 // }
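  //
  // Fusing would introduce a loop-carried dependence: the second statement at
  // iteration k writes A[k-1], which the first statement already wrote at
  // iteration k-1, so fuseLoops is expected to reject the fusion.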
6070 BufHandle a_buf("A", {100}, kInt);
6071 VarHandle j("j", kInt);
6072 VarHandle k("k", kInt);
6073 auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
6074 auto forK =
6075 For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k)));
6076 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
6077 auto par = Block::make({forJ, forK});
6078 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
6079 ForPtr fused_loop;
6080 ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
6081}
6082
6083TEST(LoopNest, fuseLoopsThatViolateDependencies2) {
6084 // Input IR:
6085 // for (int j = 10; j < 100; j++) {
6086 // A[j] = 10 * j;
6087 // }
6088 // for (int k = 10; k < 100; k++) {
6089 // A[k+50] = 20 * k;
6090 // }
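  //
  // The writes overlap with a shift of 50: A[k+50] written at iteration k
  // would later be overwritten by A[j] at iteration j == k+50, reversing the
  // original write order, so fusion must be rejected.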
6091 BufHandle a_buf("A", {150}, kInt);
6092 VarHandle j("j", kInt);
6093 VarHandle k("k", kInt);
6094 auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
6095 auto forK =
6096 For::make(k, 10, 100, Store::make(a_buf, {k + 50}, Mul::make(20, k)));
6097 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
6098 auto par = Block::make({forJ, forK});
6099 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
6100 ForPtr fused_loop;
6101 ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
6102}
6103
6104TEST(LoopNest, fuseLoopsThatViolateDependencies3) {
6105 // Input IR:
6106 // for (int m = 0; m < 20; m++) {
6107 // A[m] = 0;
6108 // for (int j = 0; j < 100; j++) {
6109 // A[m] = A[m] + m * j;
6110 // }
6111 // }
6112 // for (int n = 0; n < 20; n++) {
6113 // B[n] = A[n+1];
6114 // for (int k = 0; k < 50; k++) {
6115 // B[n] = B[n] + n * k;
6116 // }
6117 // }
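  //
  // B[n] reads A[n+1], which the first loop nest only produces at iteration
  // m == n+1. Fusing would make iteration n read A[n+1] before it is written,
  // so fusion must fail.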
6118 BufHandle a_buf("A", {25, 100}, kInt);
6119 BufHandle b_buf("B", {20, 50}, kInt);
6120 VarHandle m("m", kInt);
6121 VarHandle n("n", kInt);
6122 VarHandle j("j", kInt);
6123 VarHandle k("k", kInt);
6124 auto initA = Store::make(a_buf, {m}, 0);
6125 auto forJ = For::make(
6126 j,
6127 0,
6128 100,
6129 Store::make(
6130 a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j))));
6131 auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n + 1}));
6132 auto forK = For::make(
6133 k,
6134 0,
6135 50,
6136 Store::make(
6137 b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k))));
6138 auto forM = For::make(m, 0, 20, Block::make({initA, forJ}));
6139 auto forN = For::make(n, 0, 20, Block::make({initB, forK}));
6140 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
6141 auto par = Block::make({forM, forN});
6142 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
6143 ForPtr fused_loop;
6144 ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop));
6145}
6146
6147TEST(LoopNest, fuseLoopsThatViolateDependencies4) {
6148 // Input IR:
6149 // for (int i = 0; i < 20; i++) {
6150 // for (int j = 0; j < 100; j++) {
6151 // A[i,j] = i * j * 500;
6152 // }
6153 // }
6154 // for (int m = 0; m < 20; m++) {
6155 // for (int n = 0; n < 50; n++) {
6156 // A[m+1,n] = m + n * 100;
6157 // }
6158 // }
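  //
  // The second nest writes row m+1 of A, which the first nest writes at
  // iteration i == m+1. Fusing would reorder these conflicting writes, so the
  // loops cannot be fused.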
6159 BufHandle a_buf("A", {30, 100}, kInt);
6160 VarHandle i("i", kInt);
6161 VarHandle j("j", kInt);
6162 VarHandle m("m", kInt);
6163 VarHandle n("n", kInt);
6164 auto forI = For::make(
6165 i,
6166 0,
6167 20,
6168 For::make(
6169 j,
6170 0,
6171 100,
6172 Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))));
6173 auto forM = For::make(
6174 m,
6175 0,
6176 20,
6177 For::make(
6178 n,
6179 0,
6180 50,
6181 Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100)))));
6182 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
6183 auto par = Block::make({forI, forM});
6184 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
6185 ForPtr fused_loop;
6186 ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
6187}
6188
6189TEST(LoopNest, fuseLoopsThatViolateDependencies5) {
6190 // Input IR:
6191 // for (int i = 0; i < 20; i++) {
6192 // for (int j = 0; j < 100; j++) {
6193 // A[i,j] = i * j * 500;
6194 // }
6195 // for (int n = 0; n < 100; n++) {
  //     A[i,n+1] = i + n * 100;
6197 // }
6198 // }
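  //
  // Within each iteration of i, the second inner loop writes A[i, n+1],
  // overlapping the columns written by the first inner loop at j == n+1.
  // Fusing j and n would reorder those writes, so fusion must fail.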
6199 BufHandle a_buf("A", {20, 200}, kInt);
6200 VarHandle i("i", kInt);
6201 VarHandle j("j", kInt);
6202 VarHandle n("n", kInt);
6203 auto forJ = For::make(
6204 j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)));
6205 auto forN = For::make(
6206 n,
6207 0,
6208 100,
6209 Store::make(a_buf, {i, n + 1}, Add::make(i, Mul::make(n, 100))));
6210 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers)
6211 auto forI = For::make(i, 0, 20, Block::make({forJ, forN}));
6212 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
6213 ForPtr fused_loop;
6214 ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop));
6215}
6216
6217TEST(LoopNest, fuseLoopsThatViolateDependencies6) {
6218 // Input IR:
6219 // for (int j = 0; j < 100; j++) {
6220 // A[j] = 10 * j;
6221 // }
6222 // for (int k = 0; k < 100; k++) {
6223 // B[k] = 20 * A[99-k];
6224 // }
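  //
  // The second loop reads A in reverse order (A[99-k]). Fusing would make
  // early iterations read elements of A that have not been written yet, so
  // the loops cannot be fused.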
6225 BufHandle a_buf("A", {100}, kInt);
6226 BufHandle b_buf("B", {100}, kInt);
6227 VarHandle j("j", kInt);
6228 VarHandle k("k", kInt);
6229 auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
6230 auto forK = For::make(
6231 k,
6232 0,
6233 100,
6234 Store::make(
6235 b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k}))));
6236 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
6237 auto par = Block::make({forJ, forK});
6238 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
6239 ForPtr fused_loop;
6240 ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
6241}
6242
6243TEST(LoopNest, fuseLoopsThatViolateDependencies7) {
6244 // Input IR:
6245 // for (int k = 0; k < 100; k++) {
6246 // B[k] = 20 * A[99-k];
6247 // }
6248 // for (int j = 0; j < 100; j++) {
6249 // A[j] = 10 * j;
6250 // }
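  //
  // Same accesses as above with the loops swapped: B must read the old
  // contents of A, but fusing would let later iterations read values already
  // overwritten by A[j] = 10 * j. Fusion must fail.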
6251 BufHandle a_buf("A", {100}, kInt);
6252 BufHandle b_buf("B", {100}, kInt);
6253 VarHandle j("j", kInt);
6254 VarHandle k("k", kInt);
6255 auto forK = For::make(
6256 k,
6257 0,
6258 100,
6259 Store::make(
6260 b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k}))));
6261 auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
6262 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
6263 auto par = Block::make({forK, forJ});
6264 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
6265 ForPtr fused_loop;
6266 ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop));
6267}
6268
6269TEST(LoopNest, areLoopsPerfectlyNested) {
6270 // Input IR:
6271 // for (int i = 0; i < 20; i++) {
6272 // for (int j = 0; j < 30; j++) {
6273 // for (int k = 0; k < 40; k++) {
6274 // A[i,j,k] = i * j * k;
6275 // }
6276 // }
6277 // }
6278 BufHandle a_buf("A", {20, 30, 40}, kInt);
6279 VarHandle i("i", kInt);
6280 VarHandle j("j", kInt);
6281 VarHandle k("k", kInt);
6282 auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
6283 auto forK = For::make(k, 0, 40, store);
6284 auto forJ = For::make(j, 0, 30, forK);
6285 auto forI = For::make(i, 0, 20, forJ);
6286 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
6287 auto par = Block::make({forI});
6288 ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));
6289
6290 // Specifying the loops in any other order fails.
6291 ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forJ, forI, forK}));
6292 ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forK, forJ}));
6293 ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forK, forJ, forI}));
6294
  // Adding a statement to the forK body should be OK.
6296 auto init = Store::make(a_buf, {i, j}, 0);
6297 forK->body()->insert_stmt_before(init, store);
6298 ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));
6299
  // Adding a statement to the forJ body should make this check fail.
6301 forK->body()->remove_stmt(init);
6302 forJ->body()->insert_stmt_before(init, forK);
6303 ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));
6304
  // Similarly, adding a statement to the forI body should make this check fail.
6306 forJ->body()->remove_stmt(init);
6307 forI->body()->insert_stmt_before(init, forJ);
6308 ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));
6309}
6310
6311TEST(LoopNest, reorderNestedLoops2D) {
6312 // Input IR:
6313 // for (int i = 0; i < 20; i++) {
6314 // for (int j = 0; j < 30; j++) {
6315 // A[i,j] = i * j;
6316 // }
6317 // }
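  //
  // The permutation {1, 0} swaps the two loops: entry p of the permutation
  // names which of the given loops becomes the p-th (outermost-first) loop in
  // the reordered nest.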
6318 BufHandle a_buf("A", {20, 30, 40}, kInt);
6319 VarHandle i("i", kInt);
6320 VarHandle j("j", kInt);
6321 auto store = Store::make(a_buf, {i, j}, Mul::make(i, j));
6322 auto forJ = For::make(j, 0, 30, store);
6323 auto forI = For::make(i, 0, 20, forJ);
6324 auto par = Block::make({forI});
6325
6326 auto reordered = LoopNest::reorder({forI, forJ}, {1, 0});
6327
6328 ASSERT_EQ(reordered[0], forJ);
6329 ASSERT_EQ(reordered[1], forI);
6330 ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forJ, forI}));
6331 ASSERT_EQ(forJ->get_parent(), par);
6332 ASSERT_EQ(store->get_parent(), forI->body());
6333}
6334
6335TEST(LoopNest, reorderNestedLoops3D) {
6336 // Input IR:
6337 // for (int i = 0; i < 20; i++) {
6338 // for (int j = 0; j < 30; j++) {
6339 // for (int k = 0; k < 40; k++) {
6340 // A[i,j,k] = i * j * k;
6341 // }
6342 // }
6343 // }
6344 BufHandle a_buf("A", {20, 30, 40}, kInt);
6345 VarHandle i("i", kInt);
6346 VarHandle j("j", kInt);
6347 VarHandle k("k", kInt);
6348 auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
6349 auto forK = For::make(k, 0, 40, store);
6350 auto forJ = For::make(j, 0, 30, forK);
6351 auto forI = For::make(i, 0, 20, forJ);
6352 auto par = Block::make({forI});
6353
6354 auto reordered = LoopNest::reorder({forI, forJ, forK}, {2, 0, 1});
6355
6356 ASSERT_EQ(reordered[0], forK);
6357 ASSERT_EQ(reordered[1], forI);
6358 ASSERT_EQ(reordered[2], forJ);
6359 ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forJ}));
6360 ASSERT_EQ(forK->get_parent(), par);
6361 ASSERT_EQ(store->get_parent(), forJ->body());
6362}
6363
6364TEST(LoopNest, reorderNestedLoops4D) {
6365 // Input IR:
6366 // for (int i = 0; i < 20; i++) {
6367 // for (int j = 0; j < 30; j++) {
6368 // for (int k = 0; k < 40; k++) {
6369 // for (int l = 0; l < 50; l++) {
6370 // A[i,j,k,l] = i * j * k * l * 500;
6371 // }
6372 // }
6373 // }
6374 // }
6375 BufHandle a_buf("A", {20, 30, 40, 50}, kInt);
6376 VarHandle i("i", kInt);
6377 VarHandle j("j", kInt);
6378 VarHandle k("k", kInt);
6379 VarHandle l("l", kInt);
6380 auto store = Store::make(
6381 a_buf,
6382 {i, j, k, l},
6383 Mul::make(Mul::make(Mul::make(Mul::make(i, j), k), l), 500));
6384 auto forL = For::make(l, 0, 50, store);
6385 auto forK = For::make(k, 0, 40, forL);
6386 auto forJ = For::make(j, 0, 30, forK);
6387 auto forI = For::make(i, 0, 20, forJ);
6388 auto par = Block::make({forI});
6389
6390 auto reordered = LoopNest::reorder({forI, forJ, forK, forL}, {2, 0, 3, 1});
6391
6392 ASSERT_EQ(reordered[0], forK);
6393 ASSERT_EQ(reordered[1], forI);
6394 ASSERT_EQ(reordered[2], forL);
6395 ASSERT_EQ(reordered[3], forJ);
6396 ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forL, forJ}));
6397 ASSERT_EQ(forK->get_parent(), par);
6398 ASSERT_EQ(store->get_parent(), forJ->body());
6399}
6400
6401TEST(LoopNest, reorderTrivialPermutation) {
6402 // Input IR:
6403 // for (int i = 0; i < 20; i++) {
6404 // for (int j = 0; j < 30; j++) {
6405 // for (int k = 0; k < 40; k++) {
6406 // A[i,j,k] = i * j * k;
6407 // }
6408 // }
6409 // }
6410 BufHandle a_buf("A", {20, 30, 40}, kInt);
6411 VarHandle i("i", kInt);
6412 VarHandle j("j", kInt);
6413 VarHandle k("k", kInt);
6414 auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
6415 auto forK = For::make(k, 0, 40, store);
6416 auto forJ = For::make(j, 0, 30, forK);
6417 auto forI = For::make(i, 0, 20, forJ);
6418 auto par = Block::make({forI});
6419
6420 auto reordered = LoopNest::reorder({forI, forJ, forK}, {0, 1, 2});
6421
6422 ASSERT_EQ(reordered[0], forI);
6423 ASSERT_EQ(reordered[1], forJ);
6424 ASSERT_EQ(reordered[2], forK);
6425 ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK}));
6426 ASSERT_EQ(forI->get_parent(), par);
6427 ASSERT_EQ(store->get_parent(), forK->body());
6428}
6429
6430TEST(LoopNest, reorderInvalidPermutations) {
6431 // Input IR:
6432 // for (int i = 0; i < 20; i++) {
6433 // for (int j = 0; j < 30; j++) {
6434 // for (int k = 0; k < 40; k++) {
6435 // A[i,j,k] = i * j * k;
6436 // }
6437 // }
6438 // }
6439 BufHandle a_buf("A", {20, 30, 40}, kInt);
6440 VarHandle i("i", kInt);
6441 VarHandle j("j", kInt);
6442 VarHandle k("k", kInt);
6443 auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
6444 auto forK = For::make(k, 0, 40, store);
6445 auto forJ = For::make(j, 0, 30, forK);
6446 auto forI = For::make(i, 0, 20, forJ);
6447 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
6448 auto par = Block::make({forI});
6449
6450 ASSERT_THROWS_WITH(
6451 LoopNest::reorder({forI, forJ, forK}, {0, 1, 2, 3}),
6452 "invalid permutation size");
6453 ASSERT_THROWS_WITH(
6454 LoopNest::reorder({forI, forJ, forK}, {1, 2}),
6455 "invalid permutation size");
6456 ASSERT_THROWS_WITH(
6457 LoopNest::reorder({forI, forJ, forK}, {2, 1, 3}),
6458 "invalid permutation for reorder");
6459 ASSERT_THROWS_WITH(
6460 LoopNest::reorder({forI, forJ, forK}, {1, 1, 0}),
6461 "invalid permutation for reorder");
6462 ASSERT_THROWS_WITH(
6463 LoopNest::reorder({forI, forJ, forK}, {0, 0, 0}),
6464 "invalid permutation for reorder");
6465}
6466
6467TEST(LoopNest, reorderInvalidLoopNest) {
6468 // Input IR:
6469 // for (int i = 0; i < 20; i++) {
6470 // for (int j = 0; j < 30; j++) {
  //       A[i] = 0
6472 // for (int k = 0; k < 40; k++) {
6473 // A[i,j,k] = i * j * k;
6474 // }
6475 // }
6476 // }
6477 BufHandle a_buf("A", {20, 30, 40}, kInt);
6478 VarHandle i("i", kInt);
6479 VarHandle j("j", kInt);
6480 VarHandle k("k", kInt);
6481 auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k));
6482 auto forK = For::make(k, 0, 40, store);
6483 auto forJ = For::make(j, 0, 30, forK);
6484 auto forI = For::make(i, 0, 20, forJ);
6485 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
6486 auto par = Block::make({forI});
6487
  // Specifying the loops in an incorrect order fails.
6489 ASSERT_THROWS_WITH(
6490 LoopNest::reorder({forK, forI, forJ}, {1, 0, 2}),
6491 "reorder is only allowed on perfectly nested loops");
6492
  // Adding a statement to the forJ body makes reorder fail.
6494 auto init = Store::make(a_buf, {i}, 0);
6495 forJ->body()->insert_stmt_before(init, forK);
6496 ASSERT_THROWS_WITH(
6497 LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}),
6498 "reorder is only allowed on perfectly nested loops");
6499
  // Moving that statement to the forI body also fails.
6501 forJ->body()->remove_stmt(init);
6502 forI->body()->insert_stmt_before(init, forJ);
6503 ASSERT_THROWS_WITH(
6504 LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}),
6505 "reorder is only allowed on perfectly nested loops");
6506}
6507
6508TEST(LoopNest, compressBufferSimple) {
6509 // Input IR:
6510 // for (int i = 0; i < 100; ++i) {
6511 // for (int j = 0; j < 200; ++j) {
6512 // A[i,j] = sin(i*j)
6513 // }
6514 // for (int j = 0; j < 199; ++j) {
6515 // B[i,j] = A[i,j] + A[i, j+1]
6516 // }
6517 // }
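  //
  // Every access to A inside the common parent loop uses the loop variable i
  // as its first index, so no value of A is live across iterations of i.
  // compressBuffer should shrink that dimension to 1 and rewrite all accesses
  // to A[0, j].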
6518 BufHandle aBuf("A", {100, 200}, kInt);
6519 BufHandle bBuf("B", {100, 200}, kInt);
6520 VarHandle i("i", kInt);
6521 VarHandle j("j", kInt);
6522 auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j)));
6523 auto forJ2 = For::make(
6524 j,
6525 0,
6526 199,
6527 Store::make(
6528 bBuf,
6529 {i, j},
6530 Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1}))));
6531 auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2}));
6532 auto par = Block::make({forI});
6533 LoopNest::compressBuffer(aBuf.node(), par);
6534
6535 std::ostringstream oss;
6536 oss << *par;
6537 const std::string& verification_pattern =
6538 R"IR(
6539# CHECK: for (int i
6540# CHECK-NEXT: for (int j
6541# CHECK-NEXT: A[0, j] =
6542# CHECK: for (int j
6543# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1])
6544 )IR";
6545 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
6546
6547 ASSERT_EQ(aBuf.node()->ndim(), 2);
6548 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
6549 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
6550}
6551
6552TEST(LoopNest, compressBufferMultipleDims) {
6553 // Input IR:
6554 // for (int i = 0; i < 100; ++i) {
6555 // for (int j = 0; j < 200; ++j) {
6556 // A[i,j] = sin(i*j)
6557 // B[i,j] = A[i,j] + A[i,j]
6558 // }
6559 // }
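  //
  // Both indices of A match the enclosing loop variables and A is consumed in
  // the same iteration it is produced, so both dimensions should compress to
  // size 1.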
6560 BufHandle aBuf("A", {100, 200}, kInt);
6561 BufHandle bBuf("B", {100, 200}, kInt);
6562 VarHandle i("i", kInt);
6563 VarHandle j("j", kInt);
6564 auto store1 = Store::make(aBuf, {i, j}, sin(i * j));
6565 auto store2 = Store::make(
6566 bBuf,
6567 {i, j},
6568 Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j})));
6569 auto forJ = For::make(j, 0, 200, Block::make({store1, store2}));
6570 auto forI = For::make(i, 0, 100, forJ);
6571 auto par = Block::make({forI});
6572 LoopNest::compressBuffer(aBuf.node(), par);
6573
6574 std::ostringstream oss;
6575 oss << *par;
6576 const std::string& verification_pattern =
6577 R"IR(
6578# CHECK: for (int i
6579# CHECK-NEXT: for (int j
6580# CHECK-NEXT: A[0, 0] =
6581# CHECK-NEXT: B[i, j] = (A[0, 0]) + (A[0, 0])
6582 )IR";
6583 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
6584
6585 ASSERT_EQ(aBuf.node()->ndim(), 2);
6586 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
6587 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1);
6588}
6589
6590TEST(LoopNest, compressBufferMultipleDims2) {
6591 // Input IR:
6592 // for (int i = 0; i < 100; ++i) {
6593 // for (int j = 0; j < 200; ++j) {
6594 // for (int k = 0; k < 300; ++k) {
6595 // A[i,j,k] = sin(i*j*k)
6596 // }
  //     for (int k = 0; k < 299; ++k) {
6598 // B[i,j,k] = A[i,j,k] + A[i,j,k+1]
6599 // }
6600 // }
6601 // }
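  //
  // Only the i and j dimensions should be compressed; the k dimension carries
  // data between the two inner loops and must keep its full extent of 300.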
6602 BufHandle aBuf("A", {100, 200, 300}, kInt);
6603 BufHandle bBuf("B", {100, 200, 300}, kInt);
6604 VarHandle i("i", kInt);
6605 VarHandle j("j", kInt);
6606 VarHandle k("k", kInt);
6607 auto store1 = Store::make(aBuf, {i, j, k}, sin(i * j * k));
6608 auto forK1 = For::make(k, 0, 300, store1);
6609 auto store2 = Store::make(
6610 bBuf,
6611 {i, j, k},
6612 Add::make(Load::make(aBuf, {i, j, k}), Load::make(aBuf, {i, j, k + 1})));
6613 auto forK2 = For::make(k, 0, 299, store2);
6614 auto forJ = For::make(j, 0, 200, Block::make({forK1, forK2}));
6615 auto forI = For::make(i, 0, 100, forJ);
6616 auto par = Block::make({forI});
6617 LoopNest::compressBuffer(aBuf.node(), par);
6618
6619 std::ostringstream oss;
6620 oss << *par;
6621 const std::string& verification_pattern =
6622 R"IR(
6623# CHECK: for (int i
6624# CHECK-NEXT: for (int j
6625# CHECK-NEXT: for (int k
6626# CHECK-NEXT: A[0, 0, k] =
6627# CHECK: for (int k
6628# CHECK-NEXT: B[i, j, k] = (A[0, 0, k]) + (A[0, 0, k + 1])
6629 )IR";
6630 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
6631
6632 ASSERT_EQ(aBuf.node()->ndim(), 3);
6633 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
6634 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1);
6635 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(2), 300);
6636}
6637
6638TEST(LoopNest, compressBufferDifferentOrderIndices) {
6639 // Input IR:
6640 // for (int i = 0; i < 100; ++i) {
6641 // for (int j = 0; j < 200; ++j) {
6642 // A[j, i] = sin(i*j)
6643 // }
6644 // for (int j = 0; j < 99; ++j) {
  //     B[i, j] = A[j, i] + A[j+1, i]
6646 // }
6647 // }
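  //
  // The common loop variable i appears as the second index of A, so it is the
  // second dimension that gets compressed to 1, leaving A with shape {100, 1}.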
6648 BufHandle aBuf("A", {100, 200}, kInt);
6649 BufHandle bBuf("B", {100, 200}, kInt);
6650 VarHandle i("i", kInt);
6651 VarHandle j("j", kInt);
6652 auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {j, i}, sin(i * j)));
6653 auto forJ2 = For::make(
6654 j,
6655 0,
6656 99,
6657 Store::make(
6658 bBuf,
6659 {i, j},
6660 Add::make(Load::make(aBuf, {j, i}), Load::make(aBuf, {j + 1, i}))));
6661 auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2}));
6662 auto par = Block::make({forI});
6663 LoopNest::compressBuffer(aBuf.node(), par);
6664
6665 std::ostringstream oss;
6666 oss << *par;
6667 const std::string& verification_pattern =
6668 R"IR(
6669# CHECK: for (int i
6670# CHECK-NEXT: for (int j
6671# CHECK-NEXT: A[j, 0] =
6672# CHECK: for (int j
6673# CHECK-NEXT: B[i, j] = (A[j, 0]) + (A[j + 1, 0])
6674 )IR";
6675 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
6676
6677 ASSERT_EQ(aBuf.node()->ndim(), 2);
6678 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100);
6679 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1);
6680}
6681
6682TEST(LoopNest, compressBufferVariableBounds) {
6683 // Input IR:
6684 // for (int i = 0; i < M; ++i) {
6685 // for (int j = 0; j < N; ++j) {
6686 // A[i,j] = sin(i*j)
6687 // }
6688 // for (int j = 0; j < N-1; ++j) {
6689 // B[i,j] = A[i,j] + A[i, j+1]
6690 // }
6691 // }
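  //
  // Variable loop bounds (M, N) should not affect the analysis: the first
  // dimension of A is still private to each iteration of i and is compressed
  // to 1.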
6692 BufHandle aBuf("A", {100, 200}, kInt);
6693 BufHandle bBuf("B", {100, 200}, kInt);
6694 VarHandle i("i", kInt);
6695 VarHandle j("j", kInt);
6696 VarHandle M("M", kInt);
6697 VarHandle N("N", kInt);
6698 auto forJ1 = For::make(j, 0, N, Store::make(aBuf, {i, j}, sin(i * j)));
6699 auto forJ2 = For::make(
6700 j,
6701 0,
6702 N - 1,
6703 Store::make(
6704 bBuf,
6705 {i, j},
6706 Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1}))));
6707 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
6708 auto forI = For::make(i, 0, M, Block::make({forJ1, forJ2}));
6709 auto par = Block::make({forI});
6710 LoopNest::compressBuffer(aBuf.node(), par);
6711
6712 std::ostringstream oss;
6713 oss << *par;
6714 const std::string& verification_pattern =
6715 R"IR(
6716# CHECK: for (int i
6717# CHECK-NEXT: for (int j
6718# CHECK-NEXT: A[0, j] =
6719# CHECK: for (int j
6720# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1])
6721 )IR";
6722 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
6723
6724 ASSERT_EQ(aBuf.node()->ndim(), 2);
6725 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
6726 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
6727}
6728
6729TEST(LoopNest, compressBufferNoCommonParentLoops) {
6730 // Input IR:
6731 // for (int i = 0; i < 100; ++i) {
6732 // for (int j = 0; j < 200; ++j) {
6733 // A[i,j] = sin(i*j)
6734 // }
6735 // }
6736 // for (int i = 0; i < 100; ++i) {
6737 // for (int j = 0; j < 199; ++j) {
6738 // B[i,j] = A[i,j] + A[i, j+1]
6739 // }
6740 // }
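  //
  // A is written in one i loop and read in another, so its values must stay
  // live across the entire first loop. There is no common parent loop over
  // which the i dimension is private, so nothing should be compressed.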
6741 BufHandle aBuf("A", {100, 200}, kInt);
6742 BufHandle bBuf("B", {100, 200}, kInt);
6743 VarHandle i("i", kInt);
6744 VarHandle j("j", kInt);
6745 auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j)));
6746 auto forJ2 = For::make(
6747 j,
6748 0,
6749 199,
6750 Store::make(
6751 bBuf,
6752 {i, j},
6753 Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1}))));
6754 auto forI1 = For::make(i, 0, 100, forJ1);
6755 auto forI2 = For::make(i, 0, 100, forJ2);
6756 auto par = Block::make({forI1, forI2});
6757 LoopNest::compressBuffer(aBuf.node(), par);
6758
6759 // There should be no change in the buffer or code.
6760 std::ostringstream oss;
6761 oss << *par;
6762 const std::string& verification_pattern =
6763 R"IR(
6764# CHECK: for (int i
6765# CHECK-NEXT: for (int j
6766# CHECK-NEXT: A[i, j] =
6767# CHECK: for (int i
6768# CHECK-NEXT: for (int j
6769# CHECK-NEXT: B[i, j] = (A[i, j]) + (A[i, j + 1])
6770 )IR";
6771 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
6772
6773 ASSERT_EQ(aBuf.node()->ndim(), 2);
6774 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100);
6775 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
6776}
6777
6778TEST(LoopNest, compressBufferIndicesMixed) {
6779 // Input IR:
6780 // for (int i = 0; i < 100; ++i) {
6781 // for (int j = 0; j < 200; ++j) {
6782 // A[i + j, j] = sin(i*j)
6783 // }
6784 // for (int j = 0; j < 199; ++j) {
6785 // B[i,j] = A[i + j, j] + A[i + j, j+1]
6786 // }
6787 // }
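  //
  // The first index of A mixes two loop variables (i + j), so the analysis
  // cannot prove that dimension is private to a single iteration of i; the
  // buffer must be left unchanged.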
6788 BufHandle aBuf("A", {300, 200}, kInt);
6789 BufHandle bBuf("B", {100, 200}, kInt);
6790 VarHandle i("i", kInt);
6791 VarHandle j("j", kInt);
6792 auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i + j, j}, sin(i * j)));
6793 auto forJ2 = For::make(
6794 j,
6795 0,
6796 199,
6797 Store::make(
6798 bBuf,
6799 {i, j},
6800 Add::make(
6801 Load::make(aBuf, {i + j, j}), Load::make(aBuf, {i + j, j + 1}))));
6802 auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2}));
6803 auto par = Block::make({forI});
6804 LoopNest::compressBuffer(aBuf.node(), par);
6805
6806 // There should be no change in the buffer or code.
6807 std::ostringstream oss;
6808 oss << *par;
6809 const std::string& verification_pattern =
6810 R"IR(
6811# CHECK: for (int i
6812# CHECK-NEXT: for (int j
6813# CHECK-NEXT: A[i + j, j] =
6814# CHECK: for (int j
6815# CHECK-NEXT: B[i, j] = (A[i + j, j]) + (A[i + j, j + 1])
6816 )IR";
6817 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
6818
6819 ASSERT_EQ(aBuf.node()->ndim(), 2);
6820 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 300);
6821 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
6822}
6823
6824TEST(LoopNest, compressMultipleBuffers) {
6825 // Input IR:
6826 // for (int i = 0; i < 100; ++i) {
6827 // for (int j = 0; j < 200; ++j) {
6828 // A[i,j] = sin(i*j)
6829 // }
6830 // for (int k = 0; k < 199; ++k) {
6831 // B[i,k] = A[i,k] + A[i, k+1]
6832 // }
6833 // for (int m = 0; m < 50; ++m) {
6834 // C[i,m] = B[i,m]
6835 // }
6836 // }
6837 BufHandle aBuf("A", {100, 200}, kInt);
6838 BufHandle bBuf("B", {100, 200}, kInt);
6839 BufHandle cBuf("C", {100, 200}, kInt);
6840 VarHandle i("i", kInt);
6841 VarHandle j("j", kInt);
6842 VarHandle k("k", kInt);
6843 VarHandle m("m", kInt);
6844 auto forJ = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j)));
6845 auto forK = For::make(
6846 k,
6847 0,
6848 199,
6849 Store::make(
6850 bBuf,
6851 {i, k},
6852 Add::make(Load::make(aBuf, {i, k}), Load::make(aBuf, {i, k + 1}))));
6853 auto forM =
6854 For::make(m, 0, 50, Store::make(cBuf, {i, m}, Load::make(bBuf, {i, m})));
6855 auto forI = For::make(i, 0, 100, Block::make({forJ, forK, forM}));
6856 auto par = Block::make({forI});
6857
6858 // This should compress all buffers A, B, and C as follows:
6859 // A[100, 200] -> A[1, 200]
6860 // B[100, 200] -> B[1, 200]
6861 // C[100, 200] -> C[1, 1]
6862 LoopNest::compressAllBuffers(par);
6863
6864 std::ostringstream oss;
6865 oss << *par;
6866 const std::string& verification_pattern =
6867 R"IR(
6868# CHECK: for (int i
6869# CHECK-NEXT: for (int j
6870# CHECK-NEXT: A[0, j] =
6871# CHECK: for (int k
6872# CHECK-NEXT: B[0, k] = (A[0, k]) + (A[0, k + 1])
6873# CHECK: for (int m
6874# CHECK-NEXT: C[0, 0] = B[0, m]
6875 )IR";
6876 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
6877
6878 ASSERT_EQ(aBuf.node()->ndim(), 2);
6879 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1);
6880 IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200);
6881 ASSERT_EQ(bBuf.node()->ndim(), 2);
6882 IS_IMM_WITH_VAL(Int, bBuf.node()->dim(0), 1);
6883 IS_IMM_WITH_VAL(Int, bBuf.node()->dim(1), 200);
6884 ASSERT_EQ(cBuf.node()->ndim(), 2);
6885 IS_IMM_WITH_VAL(Int, cBuf.node()->dim(0), 1);
6886 IS_IMM_WITH_VAL(Int, cBuf.node()->dim(1), 1);
6887}
6888
6889TEST(LoopNest, sanitizeNames) {
6890 std::vector<ExprHandle> dim_args;
6891 // Let's pick names that would overlap with default index names if not
6892 // sanitized properly:
6893 dim_args.emplace_back(ExprHandle(alloc<Var>("i", kInt)));
6894 dim_args.emplace_back(ExprHandle(alloc<Var>("N:2", kInt)));
  // Now let's create many dimensions so that the same letter has to be reused
  // for different loops.
6897 for (int i = 0; i < 10; i++) {
6898 dim_args.emplace_back(ExprHandle(alloc<Var>("N", kInt)));
6899 }
6900
  // Now create two Computes whose names conflict after sanitization:
6902 Tensor X = Compute("$X:!", dim_args, [&](const std::vector<VarHandle>& v) {
6903 return v[0] + v[1] + v[9] + 1;
6904 });
6905 Tensor Y = Reduce(
6906 "%X\"+",
6907 {},
6908 Sum(),
6909 [&](const std::vector<VarHandle>& v) { return X.load(v); },
6910 dim_args);
6911
6912 // Finally, let's verify what we got after sanitization:
6913 LoopNest l({X, Y});
6914 StmtPtr s = l.root_stmt();
6915 LoopNest::sanitizeNames(s);
6916
6917 std::ostringstream oss;
6918 oss << *s;
6919 const std::string& verification_pattern =
6920 R"IR(
6921# CHECK: for (int i = 0; i < i_1; i++) {
6922# CHECK-NEXT: for (int j = 0; j < N_2_1; j++) {
6923# CHECK-NEXT: for (int k = 0; k < N_9; k++) {
6924# CHECK-NEXT: for (int l = 0; l < N_8; l++) {
6925# CHECK-NEXT: for (int m = 0; m < N_7; m++) {
6926# CHECK-NEXT: for (int n = 0; n < N_6; n++) {
6927# CHECK-NEXT: for (int o = 0; o < N_5; o++) {
6928# CHECK-NEXT: for (int p = 0; p < N_4; p++) {
6929# CHECK-NEXT: for (int i1 = 0; i1 < N_3; i1++) {
6930# CHECK-NEXT: for (int j1 = 0; j1 < N_2; j1++) {
6931# CHECK-NEXT: for (int k1 = 0; k1 < N_1; k1++) {
6932# CHECK-NEXT: for (int l1 = 0; l1 < N; l1++) {
6933# CHECK-NEXT: v_X__[i, j, k, l, m, n, o, p, i1, j1, k1, l1] = ((i + j) + j1) + 1;
6934# CHECK: v_X___1 = int(0);
6935# CHECK-NEXT: for (int i_2 = 0; i_2 < i_1; i_2++) {
6936# CHECK-NEXT: for (int j_1 = 0; j_1 < N_2_1; j_1++) {
6937# CHECK-NEXT: for (int k_1 = 0; k_1 < N_9; k_1++) {
6938# CHECK-NEXT: for (int l_1 = 0; l_1 < N_8; l_1++) {
6939# CHECK-NEXT: for (int m_1 = 0; m_1 < N_7; m_1++) {
6940# CHECK-NEXT: for (int n_1 = 0; n_1 < N_6; n_1++) {
6941# CHECK-NEXT: for (int o_1 = 0; o_1 < N_5; o_1++) {
6942# CHECK-NEXT: for (int p_1 = 0; p_1 < N_4; p_1++) {
6943# CHECK-NEXT: for (int i1_1 = 0; i1_1 < N_3; i1_1++) {
6944# CHECK-NEXT: for (int j1_1 = 0; j1_1 < N_2; j1_1++) {
6945# CHECK-NEXT: for (int k1_1 = 0; k1_1 < N_1; k1_1++) {
6946# CHECK-NEXT: for (int l1_1 = 0; l1_1 < N; l1_1++) {
6947# CHECK-NEXT: v_X___1 = ReduceOp((v_X___1) + (v_X__[i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1]), reduce_args={i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1});
6948 )IR";
6949 torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
6950}
6951
6952} // namespace jit
6953} // namespace torch
6954