1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/tensorexpr/test_base.h> |
4 | #include <memory> |
5 | #include <sstream> |
6 | #include <stdexcept> |
7 | #include <unordered_map> |
8 | |
9 | #include <test/cpp/tensorexpr/padded_buffer.h> |
10 | #include <test/cpp/tensorexpr/test_utils.h> |
11 | #include <torch/csrc/jit/tensorexpr/analysis.h> |
12 | #include <torch/csrc/jit/tensorexpr/bounds_inference.h> |
13 | #include <torch/csrc/jit/tensorexpr/eval.h> |
14 | #include <torch/csrc/jit/tensorexpr/ir.h> |
15 | #include <torch/csrc/jit/tensorexpr/ir_printer.h> |
16 | #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
17 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
18 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
19 | #include <torch/csrc/jit/testing/file_check.h> |
20 | |
21 | namespace torch { |
22 | namespace jit { |
23 | |
24 | using namespace torch::jit::tensorexpr; |
25 | |
26 | void checkIR(StmtPtr s, const std::string& pattern) { |
27 | std::ostringstream oss; |
28 | oss << *s; |
29 | torch::jit::testing::FileCheck().run(pattern, oss.str()); |
30 | } |
31 | |
32 | void checkExprIR(ExprPtr e, const std::string& pattern) { |
33 | std::string prefixed_pattern = "# CHECK: " + pattern + "\n" ; |
34 | std::ostringstream oss; |
35 | oss << *e << "\n" ; |
36 | torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str()); |
37 | } |
38 | |
39 | void checkExprIR(const ExprHandle& e, const std::string& pattern) { |
40 | checkExprIR(e.node(), pattern); |
41 | } |
42 | |
43 | TEST(LoopNest, ExprSimple01) { |
44 | Tensor tensor = |
45 | Compute("f" , {16, 5}, [](const VarHandle& x, const VarHandle& y) { |
46 | return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
47 | }); |
48 | LoopNest l({tensor}); |
49 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
50 | |
51 | LoopNest::splitWithTail(loops[0], 2); |
52 | LoopNest::splitWithTail(loops[0], 2); |
53 | } |
54 | |
55 | TEST(LoopNest, ExprLower01) { |
56 | Tensor tensor = |
57 | Compute("f" , {16, 5}, [](const VarHandle& x, const VarHandle& y) { |
58 | return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
59 | }); |
60 | LoopNest l({tensor}); |
61 | StmtPtr stmt = l.root_stmt(); |
62 | std::ostringstream oss; |
63 | oss << *stmt; |
64 | ASSERT_GT(oss.str().size(), 20); |
65 | ASSERT_LT(oss.str().size(), 200); |
66 | } |
67 | |
68 | TEST(LoopNest, ExprSimple02) { |
69 | auto func = [](const ExprHandle& x, const ExprHandle& y) { |
70 | return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
71 | }; |
72 | Tensor tensor = Compute("f" , {26, 5}, func); |
73 | LoopNest l({tensor}); |
74 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
75 | |
76 | LoopNest::splitWithTail(loops[0], 4); |
77 | |
78 | StmtPtr stmt = l.root_stmt(); |
79 | std::ostringstream oss; |
80 | oss << *stmt; |
81 | ASSERT_GT(oss.str().size(), 200); |
82 | ASSERT_LT(oss.str().size(), 600); |
83 | |
84 | { |
85 | // Compare to a reference loop structure structure. |
86 | VarHandle x_outer("i_outer" , kInt); |
87 | VarHandle x_inner("i_inner" , kInt); |
88 | VarHandle y("i" , kInt); |
89 | VarHandle x_tail("i_tail" , kInt); |
90 | BufHandle f("f" , {26, 5}, kFloat); |
91 | ExprHandle x_1 = x_outer * 4 + x_inner; |
92 | ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4; |
93 | ForPtr stmt1 = For::make( |
94 | x_outer, |
95 | 0, |
96 | x_outer_end, |
97 | For::make( |
98 | x_inner, |
99 | 0, |
100 | 4, |
101 | For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y))))); |
102 | ExprHandle x_2 = x_tail + x_outer_end * 4; |
103 | ForPtr stmt2 = For::make( |
104 | x_tail, |
105 | 0, |
106 | (ExprHandle(26) - 0) % 4, |
107 | For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y)))); |
108 | StmtPtr stmt = Block::make({stmt1, stmt2}); |
109 | |
110 | std::ostringstream oss_ref; |
111 | oss_ref << *stmt; |
112 | ASSERT_EQ(oss.str(), oss_ref.str()); |
113 | } |
114 | |
115 | { |
116 | PaddedBuffer<float> f_v(26, 5, "f_v" ); |
117 | PaddedBuffer<float> f_ref(26, 5, "f_res" ); |
118 | |
119 | stmt = FlattenIndexes(stmt); |
120 | SimpleIREvaluator ir_eval(stmt, {tensor}); |
121 | ir_eval(f_v); |
122 | |
123 | for (int x = 0; x < 26; x++) { |
124 | for (int y = 0; y < 5; y++) { |
125 | f_ref(x, y) = 1 + x * x + y * y; |
126 | } |
127 | } |
128 | |
129 | ExpectAllNear(f_v, f_ref, 1e-5); |
130 | } |
131 | } |
132 | |
133 | BlockPtr getSimplifiedBody(const LoopNest& l) { |
134 | StmtPtr stmt = l.root_stmt(); |
135 | StmtPtr simplified = IRSimplifier::simplify(stmt); |
136 | return to<Block>(simplified); |
137 | } |
138 | |
139 | void assertForRange(ForPtr f, int expected_start, int expected_stop) { |
140 | ASSERT_NE(f, nullptr); |
141 | IntImmPtr start = to<IntImm>(f->start()); |
142 | ASSERT_NE(start, nullptr); |
143 | ASSERT_EQ(start->value(), expected_start); |
144 | IntImmPtr stop = to<IntImm>(f->stop()); |
145 | ASSERT_NE(stop, nullptr); |
146 | ASSERT_EQ(stop->value(), expected_stop); |
147 | } |
148 | |
149 | void assertForRanges( |
150 | BlockPtr body, |
151 | const std::vector<std::pair<int, int>>& start_stops) { |
152 | ASSERT_EQ(body->nstmts(), start_stops.size()); |
153 | |
154 | auto it = body->begin(); |
155 | for (size_t i = 0; i < start_stops.size(); i++, it++) { |
156 | ForPtr loop = to<For>(*it); |
157 | assertForRange(loop, start_stops[i].first, start_stops[i].second); |
158 | } |
159 | } |
160 | |
161 | TEST(LoopNest, ExprSliceHeadWithLoopOptions) { |
162 | auto func = [](const ExprHandle& x) { |
163 | return ExprHandle(1.0f) + cast<float>(x); |
164 | }; |
165 | Tensor tensor = Compute("f" , {10}, func); |
166 | LoopNest l({tensor}); |
167 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
168 | ForPtr head; |
169 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
170 | ForPtr tail; |
171 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
172 | loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); |
173 | LoopNest::sliceHead(loops[0], 2, &head, &tail); |
174 | |
175 | BlockPtr body = getSimplifiedBody(l); |
176 | assertForRanges(body, {{0, 2}, {0, 8}}); |
177 | |
178 | ASSERT_TRUE(tail->loop_options().is_gpu_block_index()); |
179 | ASSERT_EQ(tail->loop_options().gpu_block_index(), LoopOptions::IDX_Y); |
180 | |
181 | ASSERT_TRUE(head->loop_options().isDefault()); |
182 | } |
183 | |
184 | TEST(LoopNest, ExprSliceTailWithLoopOptions) { |
185 | auto func = [](const ExprHandle& x) { |
186 | return ExprHandle(1.0f) + cast<float>(x); |
187 | }; |
188 | Tensor tensor = Compute("f" , {10}, func); |
189 | LoopNest l({tensor}); |
190 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
191 | ForPtr head; |
192 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
193 | ForPtr tail; |
194 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
195 | LoopNest::sliceTail(loops[0], 4, &head, &tail); |
196 | |
197 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
198 | ForPtr tail_head; |
199 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
200 | ForPtr tail_tail; |
201 | tail->set_gpu_block_index(LoopOptions::IDX_Y); |
202 | LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail); |
203 | |
204 | BlockPtr body = getSimplifiedBody(l); |
205 | assertForRanges(body, {{0, 6}, {0, 2}, {8, 10}}); |
206 | |
207 | ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index()); |
208 | ASSERT_EQ(tail_head->loop_options().gpu_block_index(), LoopOptions::IDX_Y); |
209 | |
210 | ASSERT_TRUE(head->loop_options().isDefault()); |
211 | ASSERT_TRUE(tail_tail->loop_options().isDefault()); |
212 | } |
213 | |
214 | TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { |
215 | // When factor equals the For loop's original size, keep using the original |
216 | // For loop. |
217 | auto func = [](const ExprHandle& x) { |
218 | return ExprHandle(1.0f) + cast<float>(x); |
219 | }; |
220 | Tensor tensor = Compute("f" , {10}, func); |
221 | LoopNest l({tensor}); |
222 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
223 | ForPtr head; |
224 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
225 | ForPtr tail; |
226 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
227 | LoopNest::sliceHead(loops[0], 10, &head, &tail); |
228 | |
229 | ASSERT_EQ(head, loops[0]); |
230 | ASSERT_EQ(tail, nullptr); |
231 | |
232 | BlockPtr body = getSimplifiedBody(l); |
233 | assertForRanges(body, {{0, 10}}); |
234 | } |
235 | |
236 | TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { |
237 | auto func = [](const ExprHandle& x) { |
238 | return ExprHandle(1.0f) + cast<float>(x); |
239 | }; |
240 | Tensor tensor = Compute("f" , {10}, func); |
241 | LoopNest l({tensor}); |
242 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
243 | ForPtr head; |
244 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
245 | ForPtr tail; |
246 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
247 | LoopNest::sliceHead(loops[0], 100, &head, &tail); |
248 | |
249 | ASSERT_EQ(head, loops[0]); |
250 | ASSERT_EQ(tail, nullptr); |
251 | |
252 | BlockPtr body = getSimplifiedBody(l); |
253 | assertForRanges(body, {{0, 10}}); |
254 | } |
255 | |
256 | TEST(LoopNest, ExprSliceHead) { |
257 | auto func = [](const ExprHandle& x) { |
258 | return ExprHandle(1.0f) + cast<float>(x); |
259 | }; |
260 | Tensor tensor = Compute("f" , {10}, func); |
261 | LoopNest l({tensor}); |
262 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
263 | ForPtr head; |
264 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
265 | ForPtr tail; |
266 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
267 | LoopNest::sliceHead(loops[0], 4, &head, &tail); |
268 | |
269 | ASSERT_NE(head, nullptr); |
270 | ASSERT_NE(head, loops[0]); |
271 | ASSERT_NE(tail, nullptr); |
272 | ASSERT_EQ(tail, loops[0]); |
273 | |
274 | BlockPtr body = getSimplifiedBody(l); |
275 | assertForRanges(body, {{0, 4}, {4, 10}}); |
276 | } |
277 | |
278 | TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { |
279 | auto func = [](const ExprHandle& x) { |
280 | return ExprHandle(1.0f) + cast<float>(x); |
281 | }; |
282 | Tensor tensor = Compute("f" , {10}, func); |
283 | LoopNest l({tensor}); |
284 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
285 | |
286 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
287 | ForPtr head; |
288 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
289 | ForPtr tail; |
290 | LoopNest::sliceTail(loops[0], 4, &head, &tail); |
291 | // head: [0, 6) |
292 | // tail: [6, 10) |
293 | |
294 | LoopNest::sliceHead(tail, 2); |
295 | // tail_head: [6, 8) |
296 | // tail_tail: [8, 10) |
297 | |
298 | BlockPtr body = getSimplifiedBody(l); |
299 | assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}}); |
300 | } |
301 | |
302 | TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { |
303 | // When factor equals the For loop's original size, keep using the original |
304 | // For loop. |
305 | auto func = [](const ExprHandle& x) { |
306 | return ExprHandle(1.0f) + cast<float>(x); |
307 | }; |
308 | Tensor tensor = Compute("f" , {10}, func); |
309 | LoopNest l({tensor}); |
310 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
311 | ForPtr head; |
312 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
313 | ForPtr tail; |
314 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
315 | LoopNest::sliceTail(loops[0], 10, &head, &tail); |
316 | |
317 | ASSERT_EQ(head, nullptr); |
318 | ASSERT_EQ(tail, loops[0]); |
319 | |
320 | BlockPtr body = getSimplifiedBody(l); |
321 | assertForRanges(body, {{0, 10}}); |
322 | } |
323 | |
324 | TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { |
325 | // When factor equals the For loop's original size, keep using the original |
326 | // For loop. |
327 | auto func = [](const ExprHandle& x) { |
328 | return ExprHandle(1.0f) + cast<float>(x); |
329 | }; |
330 | Tensor tensor = Compute("f" , {10}, func); |
331 | LoopNest l({tensor}); |
332 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
333 | ForPtr head; |
334 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
335 | ForPtr tail; |
336 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
337 | LoopNest::sliceTail(loops[0], 100, &head, &tail); |
338 | |
339 | ASSERT_EQ(head, nullptr); |
340 | ASSERT_EQ(tail, loops[0]); |
341 | |
342 | BlockPtr body = getSimplifiedBody(l); |
343 | assertForRanges(body, {{0, 10}}); |
344 | } |
345 | |
346 | TEST(LoopNest, ExprSliceTail) { |
347 | auto func = [](const ExprHandle& x) { |
348 | return ExprHandle(1.0f) + cast<float>(x); |
349 | }; |
350 | Tensor tensor = Compute("f" , {10}, func); |
351 | LoopNest l({tensor}); |
352 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
353 | ForPtr head; |
354 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
355 | ForPtr tail; |
356 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
357 | LoopNest::sliceTail(loops[0], 4, &head, &tail); |
358 | |
359 | ASSERT_NE(head, nullptr); |
360 | ASSERT_EQ(head, loops[0]); |
361 | ASSERT_NE(tail, nullptr); |
362 | ASSERT_NE(tail, loops[0]); |
363 | |
364 | BlockPtr body = getSimplifiedBody(l); |
365 | assertForRanges(body, {{0, 6}, {6, 10}}); |
366 | } |
367 | |
368 | TEST(LoopNest, ExprSplitAndSlice) { |
369 | // 0: splitWithTail |
370 | // 1: sliceTail on inner loop |
371 | // 2: sliceHead on outer loop |
372 | auto func = [](const ExprHandle& x) { |
373 | return ExprHandle(1.0f) + cast<float>(x); |
374 | }; |
375 | Tensor tensor = Compute("f" , {100}, func); |
376 | LoopNest l({tensor}); |
377 | |
378 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
379 | ForPtr inner; |
380 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
381 | ForPtr tail; |
382 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
383 | // outer: [0, 4) |
384 | // inner: [0, 21) |
385 | // tail: [84, 100) |
386 | LoopNest::splitWithTail(loops[0], 21, &inner, &tail); |
387 | LoopNest::sliceTail(inner, 2); |
388 | LoopNest::sliceHead(loops[0], 2); |
389 | |
390 | // for (int x_outer = 0; x_outer < 2; x_outer++) { |
391 | // for (int x_inner = 0; x_inner < 19; x_inner++) { |
392 | // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); |
393 | // } |
394 | // for (int x_inner = 19; x_inner < 21; x_inner++) { |
395 | // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); |
396 | // } |
397 | // } |
398 | // for (int x_outer = 2; x_outer < 4; x_outer++) { |
399 | // for (int x_inner = 0; x_inner < 19; x_inner++) { |
400 | // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); |
401 | // } |
402 | // for (int x_inner = 19; x_inner < 21; x_inner++) { |
403 | // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); |
404 | // } |
405 | // } |
406 | // for (int x_tail = 0; x_tail < 16; x_tail++) { |
407 | // f[x_tail + 84] = 1.f + float(x_tail + 84); |
408 | // } |
409 | BlockPtr body = getSimplifiedBody(l); |
410 | assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}}); |
411 | |
412 | auto biter = body->begin(); |
413 | |
414 | ForPtr loop = to<For>(*biter++); |
415 | assertForRanges(loop->body(), {{0, 19}, {19, 21}}); |
416 | |
417 | loop = to<For>(*biter); |
418 | assertForRanges(loop->body(), {{0, 19}, {19, 21}}); |
419 | } |
420 | |
421 | TEST(LoopNest, ExprSliceAndNormalize) { |
422 | // 0: sliceHead |
423 | // 1: normalize tail |
424 | auto func = [](const ExprHandle& x) { |
425 | return ExprHandle(1.0f) + cast<float>(x); |
426 | }; |
427 | Tensor tensor = Compute("f" , {10}, func); |
428 | LoopNest l({tensor}); |
429 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
430 | |
431 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
432 | ForPtr head; |
433 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
434 | ForPtr tail; |
435 | LoopNest::sliceHead(loops[0], 2, &head, &tail); |
436 | // head: [0, 2) |
437 | // tail: [2, 10) |
438 | |
439 | LoopNest::normalize(tail); |
440 | // normalized_tail: [0, 8) |
441 | |
442 | BlockPtr body = getSimplifiedBody(l); |
443 | assertForRanges(body, {{0, 2}, {0, 8}}); |
444 | } |
445 | |
446 | template <typename T> |
447 | T evalExpr(const ExprHandle& expr, const VarHandle& var, T value) { |
448 | ExprEval<SimpleIREvaluator> eval(expr, {var}); |
449 | return eval.value<T>(value); |
450 | } |
451 | |
452 | TEST(LoopNest, ExprSliceWithVariableDimension) { |
453 | auto testWithDimension = |
454 | [](int dimension, |
455 | const std::vector<std::pair<int, int>>& expected_for_ranges) { |
456 | VarHandle dim("dim" , kInt); |
457 | Tensor tensor = |
458 | Compute("f" , {dim}, [](const ExprHandle& x) { return x; }); |
459 | LoopNest l({tensor}); |
460 | std::vector<ForPtr> loops = |
461 | l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
462 | |
463 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
464 | ForPtr head; |
465 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
466 | ForPtr tail; |
467 | LoopNest::sliceHead(loops[0], 2, &head, &tail); |
468 | |
469 | LoopNest::sliceTail(tail, 2); |
470 | |
471 | BlockPtr body = getSimplifiedBody(l); |
472 | ASSERT_EQ(expected_for_ranges.size(), 3); |
473 | auto it = body->begin(); |
474 | for (auto& start_stop : expected_for_ranges) { |
475 | ForPtr loop = to<For>(*it++); |
476 | int start = evalExpr<int>(ExprHandle(loop->start()), dim, dimension); |
477 | int stop = evalExpr<int>(ExprHandle(loop->stop()), dim, dimension); |
478 | ASSERT_EQ(start, start_stop.first); |
479 | ASSERT_EQ(stop, start_stop.second); |
480 | } |
481 | }; |
482 | |
483 | testWithDimension(1, {{0, 1}, {1, 1}, {1, 1}}); |
484 | testWithDimension(2, {{0, 2}, {2, 2}, {2, 2}}); |
485 | testWithDimension(3, {{0, 2}, {2, 2}, {2, 3}}); |
486 | testWithDimension(4, {{0, 2}, {2, 2}, {2, 4}}); |
487 | testWithDimension(5, {{0, 2}, {2, 3}, {3, 5}}); |
488 | testWithDimension(10, {{0, 2}, {2, 8}, {8, 10}}); |
489 | } |
490 | |
491 | TEST(LoopNest, ExprSplitWithTail) { |
492 | auto func = [](const ExprHandle& x) { |
493 | return ExprHandle(1.0f) + cast<float>(x); |
494 | }; |
495 | Tensor tensor = Compute("f" , {199}, func); |
496 | LoopNest l({tensor}); |
497 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
498 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
499 | LoopNest::splitWithTail(loops[0], 17); |
500 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
501 | LoopNest::splitWithTail(loops[0], 7); |
502 | |
503 | StmtPtr stmt = l.root_stmt(); |
504 | StmtPtr simplified = IRSimplifier::simplify(stmt); |
505 | BlockPtr body = to<Block>(simplified); |
506 | ASSERT_EQ(body->nstmts(), 3); |
507 | auto biter = body->begin(); |
508 | |
509 | // Verify that the split loops are ordered correctly. |
510 | ForPtr loop = to<For>(*biter++); |
511 | assertForRange(loop, 0, 7); |
512 | |
513 | loop = to<For>(*biter++); |
514 | assertForRange(loop, 0, 4); |
515 | |
516 | loop = to<For>(*biter); |
517 | assertForRange(loop, 0, 12); |
518 | } |
519 | |
520 | TEST(LoopNest, ExprSplitWithTailNone) { |
521 | auto func = [](const ExprHandle& x, const ExprHandle& y) { |
522 | return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
523 | }; |
524 | Tensor tensor = Compute("f" , {24, 5}, func); |
525 | LoopNest l({tensor}); |
526 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
527 | LoopNest::splitWithTail(loops[0], 4); |
528 | |
529 | StmtPtr stmt = l.root_stmt(); |
530 | std::ostringstream oss; |
531 | oss << *stmt; |
532 | ASSERT_GT(oss.str().size(), 200); |
533 | ASSERT_LT(oss.str().size(), 600); |
534 | |
535 | { |
536 | // Compare to a reference loop structure structure. |
537 | VarHandle x_outer("i_outer" , kInt); |
538 | VarHandle x_inner("i_inner" , kInt); |
539 | VarHandle y("i" , kInt); |
540 | VarHandle x_tail("i_tail" , kInt); |
541 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) |
542 | BufHandle f("f" , {24, 5}, kFloat); |
543 | ExprHandle x_1 = x_outer * 4 + x_inner; |
544 | ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4; |
545 | StmtPtr stmt = alloc<Block>(std::vector<StmtPtr>({For::make( |
546 | x_outer, |
547 | 0, |
548 | x_outer_end, |
549 | For::make( |
550 | x_inner, |
551 | 0, |
552 | 4, |
553 | For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))})); |
554 | |
555 | std::ostringstream oss_ref; |
556 | oss_ref << *stmt; |
557 | ASSERT_EQ(oss.str(), oss_ref.str()); |
558 | } |
559 | |
560 | { |
561 | PaddedBuffer<float> f_v(24, 5, "f_v" ); |
562 | PaddedBuffer<float> f_ref(24, 5, "f_res" ); |
563 | |
564 | SimpleIREvaluator ir_eval(stmt, {tensor}); |
565 | ir_eval(f_v); |
566 | |
567 | for (int x = 0; x < 24; x++) { |
568 | for (int y = 0; y < 5; y++) { |
569 | f_ref(x, y) = 1 + x * x + y * y; |
570 | } |
571 | } |
572 | |
573 | ExpectAllNear(f_v, f_ref, 1e-5); |
574 | } |
575 | } |
576 | |
577 | TEST(LoopNest, ExprSplitWithMask01) { |
578 | const int M = 26; |
579 | const int N = 5; |
580 | BufHandle a_buf("a" , {M, N}, kFloat); |
581 | BufHandle b_buf("b" , {M, N}, kFloat); |
582 | Tensor tensor = |
583 | Compute("f" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
584 | return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f; |
585 | }); |
586 | |
587 | LoopNest l({tensor}); |
588 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
589 | LoopNest::splitWithMask(loops[1], 4); |
590 | |
591 | StmtPtr stmt = l.root_stmt(); |
592 | |
593 | PaddedBuffer<float> a_v(M, N, "a" ); |
594 | PaddedBuffer<float> b_v(M, N, "b" ); |
595 | PaddedBuffer<float> c_v(M, N, "c" ); |
596 | PaddedBuffer<float> c_ref(M, N, "c_ref" ); |
597 | for (int m = 0; m < M; m++) { |
598 | for (int n = 0; n < N; n++) { |
599 | a_v(m, n) = 2 * m; |
600 | b_v(m, n) = 3 * n; |
601 | c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; |
602 | } |
603 | } |
604 | |
605 | SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); |
606 | |
607 | ExpectAllNear(c_v, c_ref, 1e-5); |
608 | } |
609 | |
610 | // Tests the case where we split a loop cleanly multiple times, we should not |
611 | // insert any masks. |
612 | TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { |
613 | const int M = 64; |
614 | BufHandle a_buf("a" , {M}, kFloat); |
615 | BufHandle b_buf("b" , {M}, kFloat); |
616 | Tensor tensor = Compute("f" , {M}, [&](const ExprHandle& m) { |
617 | return a_buf.load(m) + b_buf.load(m) + 1.0f; |
618 | }); |
619 | |
620 | LoopNest l({tensor}); |
621 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
622 | LoopNest::splitWithMask(loops[0], 4); |
623 | LoopNest::splitWithMask(loops[0], 4); |
624 | |
625 | StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt()); |
626 | |
627 | // Two splits mean 3 loops, but should need no masks in this case. |
628 | checkIR(stmt1, R"IR( |
629 | # CHECK: for ( |
630 | # CHECK-NOT: if ( |
631 | # CHECK: for ( |
632 | # CHECK-NOT: if ( |
633 | # CHECK: for ( |
634 | # CHECK-NOT: if ( |
635 | # CHECK: f[)IR" ); |
636 | } |
637 | |
638 | TEST(LoopNest, getLoopAt) { |
639 | // Input IR: |
640 | // for (int i = 0; i < 100; i++) { |
641 | // for (int j = 0; j < 100; j++) { |
642 | // A[i, j] = sin(i * j); |
643 | // for (int k1 = 0; k1 < 200; k1++) { |
644 | // B[i, j, k1] = (A[i, j]) / (k1 + 1); |
645 | // } |
646 | // for (int k2 = 0; k2 < 300; k2++) { |
647 | // C[i, j, k2] = (A[i, j]) * (k2 + 1); |
648 | // } |
649 | // } |
650 | // } |
651 | BufPtr A = alloc<Buf>( |
652 | "A" , |
653 | std::vector<ExprPtr>({alloc<IntImm>(100), alloc<IntImm>(100)}), |
654 | kInt); |
655 | BufPtr B = alloc<Buf>( |
656 | "B" , |
657 | std::vector<ExprPtr>( |
658 | {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(200)}), |
659 | kInt); |
660 | BufPtr C = alloc<Buf>( |
661 | "C" , |
662 | std::vector<ExprPtr>( |
663 | {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(300)}), |
664 | kInt); |
665 | BufHandle a_buf(A); |
666 | BufHandle b_buf(B); |
667 | BufHandle c_buf(C); |
668 | VarHandle i("i" , kInt); |
669 | VarHandle j("j" , kInt); |
670 | VarHandle k1("k1" , kInt); |
671 | VarHandle k2("k2" , kInt); |
672 | auto store1 = Store::make(a_buf, {i, j}, sin(i * j)); |
673 | auto store2 = Store::make( |
674 | b_buf, {i, j, k1}, Div::make(Load::make(a_buf, {i, j}), (k1 + 1))); |
675 | auto store3 = Store::make( |
676 | c_buf, {i, j, k2}, Mul::make(Load::make(a_buf, {i, j}), (k2 + 1))); |
677 | auto for_k2 = For::make(k2, 0, 300, Block::make({store3})); |
678 | auto for_k1 = For::make(k1, 0, 200, Block::make({store2})); |
679 | auto for_j = For::make(j, 0, 100, Block::make({store1, for_k1, for_k2})); |
680 | auto for_i = For::make(i, 0, 100, for_j); |
681 | LoopNest l(Block::make({for_i}), {B, C}); |
682 | auto ret_k2 = l.getLoopAt(for_i, {0, 2}); |
683 | TORCH_CHECK(ret_k2 == for_k2); |
684 | |
685 | std::ostringstream oss; |
686 | oss << *ret_k2; |
687 | const std::string& verification_pattern = |
688 | R"IR( |
689 | # CHECK: for (int k2 |
690 | # CHECK-NEXT: C[i, j, k2] = |
691 | )IR" ; |
692 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
693 | } |
694 | |
695 | TEST(LoopNest, TileSimple) { |
696 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
697 | const int M = 64, N = 64; |
698 | BufHandle a_buf("a" , {M, N}, kFloat); |
699 | BufHandle b_buf("b" , {M, N}, kFloat); |
700 | Tensor tensor = |
701 | Compute("f" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
702 | return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; |
703 | }); |
704 | |
705 | LoopNest l({tensor}); |
706 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
707 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
708 | l.tile(loops[0], loops[1], 4, 8); |
709 | |
710 | // IR check |
711 | StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); |
712 | checkIR(stmt, R"IR( |
713 | # CHECK: for (int i_outer |
714 | # CHECK: for (int i_outer_1 |
715 | # CHECK: for (int i_inner |
716 | # CHECK: for (int i_inner_1 |
717 | # CHECK: f[ |
718 | # CHECK-NOT: for (int i_tail |
719 | # CHECK-NOT: for (int i_tail)IR" ); |
720 | |
721 | // Correctness check |
722 | PaddedBuffer<float> a_v(M, N, "a" ); |
723 | PaddedBuffer<float> b_v(M, N, "b" ); |
724 | PaddedBuffer<float> c_v(M, N, "c" ); |
725 | PaddedBuffer<float> c_ref(M, N, "c_ref" ); |
726 | for (int m = 0; m < M; m++) { |
727 | for (int n = 0; n < N; n++) { |
728 | a_v(m, n) = 2 * m; |
729 | b_v(m, n) = 3 * n; |
730 | c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; |
731 | } |
732 | } |
733 | |
734 | SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); |
735 | |
736 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
737 | ExpectAllNear(c_v, c_ref, 1e-5); |
738 | } |
739 | |
740 | TEST(LoopNest, TileWithTails) { |
741 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
742 | const int M = 64, N = 64; |
743 | BufHandle a_buf("a" , {M, N}, kFloat); |
744 | BufHandle b_buf("b" , {M, N}, kFloat); |
745 | Tensor tensor = |
746 | Compute("f" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
747 | return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; |
748 | }); |
749 | |
750 | LoopNest l({tensor}); |
751 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
752 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
753 | l.tile(loops[0], loops[1], 5, 9); |
754 | |
755 | // IR check |
756 | StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); |
757 | checkIR(stmt, R"IR( |
758 | # CHECK: for (int i_outer |
759 | # CHECK: for (int i_outer_1 |
760 | # CHECK: for (int i_inner |
761 | # CHECK: for (int i_inner_1 |
762 | # CHECK: f[ |
763 | # CHECK: for (int i_inner |
764 | # CHECK: f[ |
765 | # CHECK: for (int i_tail)IR" ); |
766 | |
767 | // Correctness check |
768 | PaddedBuffer<float> a_v(M, N, "a" ); |
769 | PaddedBuffer<float> b_v(M, N, "b" ); |
770 | PaddedBuffer<float> c_v(M, N, "c" ); |
771 | PaddedBuffer<float> c_ref(M, N, "c_ref" ); |
772 | for (int m = 0; m < M; m++) { |
773 | for (int n = 0; n < N; n++) { |
774 | a_v(m, n) = 2 * m; |
775 | b_v(m, n) = 3 * n; |
776 | c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; |
777 | } |
778 | } |
779 | |
780 | SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); |
781 | |
782 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
783 | ExpectAllNear(c_v, c_ref, 1e-5); |
784 | } |
785 | |
786 | TEST(LoopNest, TileInMiddle) { |
787 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
788 | const int M = 8, N = 8, L = 8, K = 8; |
789 | BufHandle a_buf("a" , {M, N, L, K}, kFloat); |
790 | BufHandle b_buf("b" , {M, N, L, K}, kFloat); |
791 | Tensor tensor = Compute( |
792 | "f" , |
793 | {M, N, L, K}, |
794 | [&](const ExprHandle& m, |
795 | const ExprHandle& n, |
796 | const ExprHandle& l, |
797 | const ExprHandle& k) { |
798 | return a_buf.load({m, n, l, k}) + b_buf.load({m, n, l, k}) + 1.0f; |
799 | }); |
800 | |
801 | LoopNest nest({tensor}); |
802 | std::vector<ForPtr> loops = |
803 | nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
804 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
805 | nest.tile(loops[1], loops[2], 3, 3); |
806 | |
807 | // IR check |
808 | StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt()); |
809 | checkIR(stmt, R"IR( |
810 | # CHECK: for (int i |
811 | # CHECK: for (int i_outer |
812 | # CHECK: for (int i_outer_1 |
813 | # CHECK: for (int i_inner |
814 | # CHECK: for (int i_inner_1 |
815 | # CHECK: for (int i_1 |
816 | # CHECK: f[ |
817 | # CHECK: for (int i_tail_1 |
818 | # CHECK: for (int i_inner_1 |
819 | # CHECK: for (int i_1 |
820 | # CHECK: f[ |
821 | # CHECK: for (int i_tail)IR" ); |
822 | |
823 | // Correctness check |
824 | PaddedBuffer<float> a_v(M, N, L, K, "a" ); |
825 | PaddedBuffer<float> b_v(M, N, L, K, "b" ); |
826 | PaddedBuffer<float> c_v(M, N, L, K, "c" ); |
827 | PaddedBuffer<float> c_ref(M, N, L, K, "c_ref" ); |
828 | for (int m = 0; m < M; m++) { |
829 | for (int n = 0; n < N; n++) { |
830 | for (int l = 0; l < L; l++) { |
831 | for (int k = 0; k < K; k++) { |
832 | a_v(m, n, l, k) = 2 * (m + l); |
833 | b_v(m, n, l, k) = 3 * (n + k); |
834 | c_ref(m, n, l, k) = a_v(m, n, l, k) + b_v(m, n, l, k) + 1.0f; |
835 | } |
836 | } |
837 | } |
838 | } |
839 | |
840 | SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); |
841 | |
842 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
843 | ExpectAllNear(c_v, c_ref, 1e-5); |
844 | } |
845 | |
846 | TEST(LoopNest, SplitWithTailWithLoopOptions) { |
847 | const int M = 21; |
848 | BufHandle a_buf("a" , {M}, kFloat); |
849 | BufHandle b_buf("b" , {M}, kFloat); |
850 | Tensor tensor = Compute("f" , {M}, [&](const ExprHandle& m) { |
851 | return a_buf.load(m) + b_buf.load(m) + 1.0f; |
852 | }); |
853 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
854 | ForPtr inner, tail; |
855 | |
856 | LoopNest l({tensor}); |
857 | auto loops = NodeFinder<For>::find(l.root_stmt()); |
858 | ASSERT_GT(loops.size(), 0); |
859 | loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); |
860 | LoopNest::splitWithTail(loops[0], 4, &inner, &tail); |
861 | ASSERT_NE(inner, nullptr); |
862 | ASSERT_NE(tail, nullptr); |
863 | ForPtr outer = loops[0]; |
864 | |
865 | // Outer loop carries loop axis bindings. |
866 | ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); |
867 | ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); |
868 | |
869 | // Inner loop has none. |
870 | ASSERT_TRUE(inner->loop_options().isDefault()); |
871 | |
872 | // Tail loop has none. |
873 | ASSERT_TRUE(tail->loop_options().isDefault()); |
874 | } |
875 | |
876 | TEST(LoopNest, SplitWithMaskWithLoopOptions) { |
877 | const int M = 21; |
878 | BufHandle a_buf("a" , {M}, kFloat); |
879 | BufHandle b_buf("b" , {M}, kFloat); |
880 | Tensor tensor = Compute("f" , {M}, [&](const ExprHandle& m) { |
881 | return a_buf.load(m) + b_buf.load(m) + 1.0f; |
882 | }); |
883 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
884 | ForPtr inner; |
885 | |
886 | LoopNest l({tensor}); |
887 | auto loops = NodeFinder<For>::find(l.root_stmt()); |
888 | loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); |
889 | LoopNest::splitWithMask(loops[0], 4, &inner); |
890 | ForPtr outer = loops[0]; |
891 | |
892 | // Outer loop carries loop axis bindings. |
893 | ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); |
894 | ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); |
895 | |
896 | // Inner loop has none. |
897 | ASSERT_TRUE(inner->loop_options().isDefault()); |
898 | } |
899 | |
900 | TEST(LoopNest, ScheduleBroadcastAddBuffer) { |
901 | const int M = 4; |
902 | const int N = 5; |
903 | const int K = 6; |
904 | BufHandle a_buf("a" , {M, N}, kFloat); |
905 | BufHandle b_buf("b" , {N, K}, kFloat); |
906 | Tensor c = Compute( |
907 | "broadcast_add" , |
908 | {M, N, K}, |
909 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
910 | return a_buf.load(m, n) + b_buf.load(n, k); |
911 | }); |
912 | LoopNest l({c}); |
913 | StmtPtr stmt = l.root_stmt(); |
914 | |
915 | PaddedBuffer<float> a_v(M, N, "a_v" ); |
916 | for (int m = 0; m < M; m++) { |
917 | for (int n = 0; n < N; n++) { |
918 | a_v(m, n) = 7 * m * n; |
919 | } |
920 | } |
921 | a_v.Backup(); |
922 | |
923 | PaddedBuffer<float> b_v(N, K, "b_v" ); |
924 | for (int n = 0; n < N; n++) { |
925 | for (int k = 0; k < K; k++) { |
926 | b_v(n, k) = 11 * n * k; |
927 | } |
928 | } |
929 | b_v.Backup(); |
930 | |
931 | PaddedBuffer<float> c_v(M, N, K, "c_buf" ); |
932 | SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c}); |
933 | ir_eval(a_v, b_v, c_v); |
934 | |
935 | a_v.CheckBackup(); |
936 | b_v.CheckBackup(); |
937 | PaddedBuffer<float> c_ref(M, N, K, "c_ref" ); |
938 | for (int m = 0; m < M; m++) { |
939 | for (int n = 0; n < N; n++) { |
940 | for (int k = 0; k < K; k++) { |
941 | c_ref(m, n, k) = 7 * m * n + 11 * n * k; |
942 | } |
943 | } |
944 | } |
945 | ExpectAllNear(c_v, c_ref, 1e-5); |
946 | } |
947 | |
948 | TEST(LoopNest, ScheduleFunctionCall01) { |
949 | const int M = 4; |
950 | const int N = 5; |
951 | const int K = 6; |
952 | BufHandle a_buf("a" , {M, N}, kFloat); |
953 | BufHandle b_buf("b" , {N, K}, kFloat); |
954 | Tensor c = Compute( |
955 | "broadcast_add" , |
956 | {M, N, K}, |
957 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
958 | return a_buf.load(m, n) + b_buf.load(n, k); |
959 | }); |
960 | Tensor d = Compute( |
961 | "d" , |
962 | {M, N, K}, |
963 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
964 | return c.load(m, n, k) + 1; |
965 | }); |
966 | |
967 | LoopNest l({d}, {c, d}); |
968 | l.prepareForCodegen(); |
969 | StmtPtr stmt = l.root_stmt(); |
970 | std::ostringstream oss; |
971 | oss << *stmt; |
972 | ASSERT_GT(oss.str().size(), 100); |
973 | |
974 | PaddedBuffer<float> a_v(M, N); |
975 | PaddedBuffer<float> b_v(N, K); |
976 | PaddedBuffer<float> c_v(M, N, K); |
977 | PaddedBuffer<float> d_v(M, N, K); |
978 | PaddedBuffer<float> d_ref(M, N, K); |
979 | |
980 | for (int i = 0; i < M; i++) { |
981 | for (int j = 0; j < N; j++) { |
982 | a_v(i, j) = i * i; |
983 | } |
984 | } |
985 | for (int i = 0; i < N; i++) { |
986 | for (int j = 0; j < K; j++) { |
987 | b_v(i, j) = j * j; |
988 | } |
989 | } |
990 | for (int i = 0; i < M; i++) { |
991 | for (int j = 0; j < N; j++) { |
992 | for (int k = 0; k < K; k++) { |
993 | d_ref(i, j, k) = a_v(i, j) + b_v(j, k) + 1; |
994 | } |
995 | } |
996 | } |
997 | |
998 | SimpleIREvaluator eval(stmt, {a_buf, b_buf, d}); |
999 | eval(a_v, b_v, d_v); |
1000 | |
1001 | ExpectAllNear(d_v, d_ref, 1e-5); |
1002 | } |
1003 | |
1004 | TEST(LoopNest, ScheduleInlineSimple) { |
1005 | const int M = 4; |
1006 | const int N = 5; |
1007 | const int K = 6; |
1008 | BufHandle a_buf("a" , {M, N}, kFloat); |
1009 | BufHandle b_buf("b" , {N, K}, kFloat); |
1010 | BufHandle c_buf("c" , {M, N}, kFloat); |
1011 | BufHandle d_buf("d" , {M, K}, kFloat); |
1012 | |
1013 | Tensor x = Compute( |
1014 | "x" , |
1015 | {M, N, K}, |
1016 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1017 | return a_buf.load(m, n) * b_buf.load(n, k); |
1018 | }); |
1019 | Tensor y = Compute( |
1020 | "y" , |
1021 | {M, N, K}, |
1022 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1023 | return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); |
1024 | }); |
1025 | |
1026 | LoopNest l1({y}, {x, y}); |
1027 | LoopNest l2(l1); |
1028 | l2.computeInline(x.buf()); |
1029 | |
1030 | l1.prepareForCodegen(); |
1031 | l2.prepareForCodegen(); |
1032 | |
1033 | StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
1034 | StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); |
1035 | |
1036 | SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y}); |
1037 | SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y}); |
1038 | |
1039 | PaddedBuffer<float> a_v(M, N); |
1040 | PaddedBuffer<float> b_v(N, K); |
1041 | PaddedBuffer<float> c_v(M, N); |
1042 | PaddedBuffer<float> d_v(M, K); |
1043 | |
1044 | for (int i = 0; i < M; i++) { |
1045 | for (int j = 0; j < N; j++) { |
1046 | a_v(i, j) = i * i; |
1047 | } |
1048 | } |
1049 | for (int i = 0; i < N; i++) { |
1050 | for (int j = 0; j < K; j++) { |
1051 | b_v(i, j) = j * j; |
1052 | } |
1053 | } |
1054 | for (int i = 0; i < M; i++) { |
1055 | for (int j = 0; j < N; j++) { |
1056 | c_v(i, j) = i + j; |
1057 | } |
1058 | } |
1059 | for (int i = 0; i < M; i++) { |
1060 | for (int j = 0; j < K; j++) { |
1061 | d_v(i, j) = i * j; |
1062 | } |
1063 | } |
1064 | |
1065 | PaddedBuffer<float> y_1(M, N, K); |
1066 | PaddedBuffer<float> y_2(M, N, K); |
1067 | |
1068 | eval1(a_v, b_v, c_v, d_v, y_1); |
1069 | eval2(a_v, b_v, c_v, d_v, y_2); |
1070 | ExpectAllNear(y_1, y_2, 1e-5); |
1071 | std::ostringstream oss1, oss2; |
1072 | oss1 << *stmt1; |
1073 | oss2 << *stmt2; |
1074 | ASSERT_GT(oss1.str().size(), oss2.str().size()); |
1075 | } |
1076 | |
// Returns a copy of `str` with every whitespace character removed. Used below
// to compare two printed IR statements independently of their formatting.
static std::string remove_space(const std::string& str) {
  std::string str_new = str;
  // isspace() has undefined behavior for negative values other than EOF, which
  // a plain (signed) char takes for bytes >= 0x80; taking the element as
  // unsigned char avoids that.
  str_new.erase(
      std::remove_if(
          str_new.begin(),
          str_new.end(),
          [](unsigned char ch) { return isspace(ch) != 0; }),
      str_new.end());
  return str_new;
}
1083 | |
1084 | void InlineFunc01Helper(const std::vector<std::string>& inline_order) { |
1085 | const int M = 4; |
1086 | const int N = 5; |
1087 | const int K = 6; |
1088 | BufHandle a_buf("a" , {M, N}, kFloat); |
1089 | BufHandle b_buf("b" , {N, K}, kFloat); |
1090 | BufHandle c_buf("c" , {M, N}, kFloat); |
1091 | BufHandle d_buf("d" , {M, K}, kFloat); |
1092 | |
1093 | Tensor x = Compute( |
1094 | "x" , |
1095 | {M, N, K}, |
1096 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1097 | return a_buf.load(m, n) * b_buf.load(n, k); |
1098 | }); |
1099 | Tensor y = Compute( |
1100 | "y" , |
1101 | {M, N, K}, |
1102 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1103 | return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); |
1104 | }); |
1105 | Tensor z = Compute( |
1106 | "z" , |
1107 | {M, N, K}, |
1108 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1109 | return x.load(m, n, k) + y.load(m, n, k); |
1110 | }); |
1111 | |
1112 | LoopNest l({z}, {x, y, z}); |
1113 | for (const std::string& order : inline_order) { |
1114 | if (order == "x" ) { |
1115 | l.computeInline(x.buf()); |
1116 | } else if (order == "y" ) { |
1117 | l.computeInline(y.buf()); |
1118 | } else { |
1119 | throw std::runtime_error("Invalid order: " + order); |
1120 | } |
1121 | } |
1122 | l.prepareForCodegen(); |
1123 | StmtPtr stmt = l.root_stmt(); |
1124 | |
1125 | std::ostringstream oss; |
1126 | oss << *stmt; |
1127 | std::string str1 = remove_space(oss.str()); |
1128 | |
1129 | { |
1130 | PaddedBuffer<float> a_v(M, N); |
1131 | PaddedBuffer<float> b_v(N, K); |
1132 | PaddedBuffer<float> c_v(M, N); |
1133 | PaddedBuffer<float> d_v(M, K); |
1134 | |
1135 | for (int i = 0; i < M; i++) { |
1136 | for (int j = 0; j < N; j++) { |
1137 | a_v(i, j) = i * i; |
1138 | } |
1139 | } |
1140 | for (int i = 0; i < N; i++) { |
1141 | for (int j = 0; j < K; j++) { |
1142 | b_v(i, j) = j * j; |
1143 | } |
1144 | } |
1145 | for (int i = 0; i < M; i++) { |
1146 | for (int j = 0; j < N; j++) { |
1147 | c_v(i, j) = i + j; |
1148 | } |
1149 | } |
1150 | for (int i = 0; i < M; i++) { |
1151 | for (int j = 0; j < K; j++) { |
1152 | d_v(i, j) = i * j; |
1153 | } |
1154 | } |
1155 | |
1156 | PaddedBuffer<float> z_v(M, N, K); |
1157 | PaddedBuffer<float> z_ref(M, N, K); |
1158 | for (int m = 0; m < M; m++) { |
1159 | for (int n = 0; n < N; n++) { |
1160 | for (int k = 0; k < K; k++) { |
1161 | z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); |
1162 | } |
1163 | } |
1164 | } |
1165 | |
1166 | SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); |
1167 | eval(a_v, b_v, c_v, d_v, z_v); |
1168 | ExpectAllNear(z_v, z_ref, 1e-5); |
1169 | } |
1170 | |
1171 | if (inline_order.size() == 2) { |
1172 | Tensor z2 = Compute( |
1173 | "z" , |
1174 | {M, N, K}, |
1175 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1176 | return a_buf.load(m, n) * b_buf.load(n, k) + |
1177 | (c_buf.load(m, n) * d_buf.load(m, k) + |
1178 | a_buf.load(m, n) * b_buf.load(n, k)); |
1179 | }); |
1180 | LoopNest l2({z2}); |
1181 | l2.prepareForCodegen(); |
1182 | StmtPtr stmt2 = l2.root_stmt(); |
1183 | |
1184 | std::ostringstream oss2; |
1185 | oss2 << *stmt2; |
1186 | std::string str2 = remove_space(oss2.str()); |
1187 | |
1188 | ASSERT_EQ(str1, str2); |
1189 | ASSERT_GT(str1.size(), 100); |
1190 | } |
1191 | } |
1192 | |
1193 | TEST(LoopNest, ScheduleInlineFunc01) { |
1194 | InlineFunc01Helper({"x" , "y" }); |
1195 | InlineFunc01Helper({"y" , "x" }); |
1196 | InlineFunc01Helper({"x" }); |
1197 | InlineFunc01Helper({"y" }); |
1198 | InlineFunc01Helper({}); |
1199 | } |
1200 | |
1201 | // Make sure we cache random vars if we should. |
1202 | TEST(LoopNest, ScheduleInlineRandom) { |
1203 | const int M = 4; |
1204 | const int N = 5; |
1205 | const int K = 6; |
1206 | |
1207 | Tensor x = Compute( |
1208 | "x" , |
1209 | {M, N, K}, |
1210 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1211 | return Mod::make(Intrinsics::make(kRand, kInt), 5); |
1212 | }); |
1213 | Tensor y = Compute( |
1214 | "y" , |
1215 | {M, N, K}, |
1216 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1217 | return x.load(m, n, k) + x.load(m, n, k); |
1218 | }); |
1219 | |
1220 | LoopNest l1({y}, {x, y}); |
1221 | l1.computeInline(x.buf()); |
1222 | |
1223 | // would normally compare results but Rand isn't implemented in the |
1224 | // SimpleIREvaluator, even if we could seed it. |
1225 | StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
1226 | |
1227 | // Check the IR we produced |
1228 | checkIR(stmt1, R"IR( |
1229 | # CHECK: for (int i = 0; i < 4; i++) |
1230 | # CHECK: for (int i_1 = 0; i_1 < 5; i_1++) |
1231 | # CHECK: for (int i_2 = 0; i_2 < 6; i_2++) |
1232 | # CHECK: int x = rand(); |
1233 | # CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR" ); |
1234 | } |
1235 | |
1236 | // Make sure we don't cache random vars that are not being inlined. |
1237 | TEST(LoopNest, ScheduleInlineRandomUnrelated) { |
1238 | const int M = 4; |
1239 | const int N = 5; |
1240 | const int K = 6; |
1241 | |
1242 | Tensor x = Compute( |
1243 | "x" , |
1244 | {M, N, K}, |
1245 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1246 | return m * n * k; |
1247 | }); |
1248 | Tensor y = Compute( |
1249 | "y" , |
1250 | {M, N, K}, |
1251 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1252 | return x.load(m, n, k) + Intrinsics::make(kRand, kInt) + |
1253 | Intrinsics::make(kRand, kInt); |
1254 | }); |
1255 | |
1256 | LoopNest l1({y}, {x, y}); |
1257 | l1.computeInline(x.buf()); |
1258 | |
1259 | // would normally compare results but Rand isn't implemented in the |
1260 | // SimpleIREvaluator, even if we could seed it. |
1261 | StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
1262 | |
1263 | // Check the IR we produced |
1264 | checkIR(stmt1, R"IR( |
1265 | # CHECK: for (int i = 0; i < 4; i++) |
1266 | # CHECK: for (int i_1 = 0; i_1 < 5; i_1++) |
1267 | # CHECK: for (int i_2 = 0; i_2 < 6; i_2++) |
1268 | # CHECK: y[i, i_1, i_2] = ((i * i_1) * i_2 + (rand())) + (rand());)IR" ); |
1269 | } |
1270 | |
1271 | // Make sure we generate the right number of random values == the dimensionality |
1272 | // of the production tensor. |
1273 | TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { |
1274 | const int M = 4; |
1275 | const int N = 5; |
1276 | const int K = 6; |
1277 | |
1278 | Tensor x = Compute("x" , {M}, [&](const VarHandle& m) { |
1279 | return Mod::make(Intrinsics::make(kRand, kInt), 5); |
1280 | }); |
1281 | Tensor y = Compute( |
1282 | "y" , |
1283 | {M, N, K}, |
1284 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1285 | return x.load(m) + x.load(m); |
1286 | }); |
1287 | |
1288 | LoopNest l1({y}, {x, y}); |
1289 | l1.computeInline(x.buf()); |
1290 | |
1291 | // would normally compare results but Rand isn't implemented in the |
1292 | // SimpleIREvaluator, even if we could seed it. |
1293 | StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
1294 | |
1295 | // Check the IR we produced |
1296 | checkIR(stmt1, R"IR( |
1297 | # CHECK: for (int i = 0; i < 4; i++) |
1298 | # CHECK: int x = rand(); |
1299 | # CHECK: for (int i_1 = 0; i_1 < 5; i_1++) |
1300 | # CHECK: for (int i_2 = 0; i_2 < 6; i_2++) |
1301 | # CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR" ); |
1302 | } |
1303 | |
1304 | // Make sure we don't screw up intrinsics thinking they're rand. |
1305 | TEST(LoopNest, ScheduleInlineIntrinsics) { |
1306 | const int M = 4; |
1307 | const int N = 5; |
1308 | const int K = 6; |
1309 | BufHandle a_buf("a" , {M, N}, kFloat); |
1310 | BufHandle b_buf("b" , {N, K}, kFloat); |
1311 | |
1312 | Tensor x = Compute( |
1313 | "x" , |
1314 | {M, N, K}, |
1315 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1316 | return a_buf.load(m, n) * b_buf.load(n, k); |
1317 | }); |
1318 | Tensor y = Compute( |
1319 | "y" , |
1320 | {M, N, K}, |
1321 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1322 | return Intrinsics::make(kSqrt, x.load(m, n, k)); |
1323 | }); |
1324 | |
1325 | PaddedBuffer<float> a_v(M, N); |
1326 | PaddedBuffer<float> b_v(N, K); |
1327 | |
1328 | for (int i = 0; i < M; i++) { |
1329 | for (int j = 0; j < N; j++) { |
1330 | a_v(i, j) = i * i; |
1331 | } |
1332 | } |
1333 | for (int i = 0; i < N; i++) { |
1334 | for (int j = 0; j < K; j++) { |
1335 | b_v(i, j) = j * j; |
1336 | } |
1337 | } |
1338 | |
1339 | LoopNest l1({y}, {x, y}); |
1340 | LoopNest l2(l1); |
1341 | l2.computeInline(x.buf()); |
1342 | |
1343 | l1.prepareForCodegen(); |
1344 | l2.prepareForCodegen(); |
1345 | |
1346 | StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
1347 | StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); |
1348 | |
1349 | SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); |
1350 | SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); |
1351 | |
1352 | PaddedBuffer<float> y_1(M, N, K); |
1353 | PaddedBuffer<float> y_2(M, N, K); |
1354 | |
1355 | eval1(a_v, b_v, y_1); |
1356 | eval2(a_v, b_v, y_2); |
1357 | ExpectAllNear(y_1, y_2, 1e-5); |
1358 | std::ostringstream oss1, oss2; |
1359 | oss1 << *stmt1; |
1360 | oss2 << *stmt2; |
1361 | ASSERT_GT(oss1.str().size(), oss2.str().size()); |
1362 | } |
1363 | |
1364 | // Make sure we can handle rand and non-rand intrinsics. |
1365 | TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { |
1366 | const int M = 4; |
1367 | const int N = 5; |
1368 | const int K = 6; |
1369 | |
1370 | Tensor x = Compute( |
1371 | "x" , |
1372 | {M, N, K}, |
1373 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1374 | return Intrinsics::make(kRand, kFloat); |
1375 | }); |
1376 | Tensor y = Compute( |
1377 | "y" , |
1378 | {M, N, K}, |
1379 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1380 | return Intrinsics::make(kSqrt, x.load(m, n, k)); |
1381 | }); |
1382 | |
1383 | LoopNest l1({y}, {x, y}); |
1384 | l1.computeInline(x.buf()); |
1385 | |
1386 | StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
1387 | |
1388 | // Check the IR we produced |
1389 | checkIR(stmt1, R"IR( |
1390 | # CHECK: for (int i = 0; i < 4; i++) |
1391 | # CHECK: for (int i_1 = 0; i_1 < 5; i_1++) |
1392 | # CHECK: for (int i_2 = 0; i_2 < 6; i_2++) |
1393 | # CHECK: float x = rand(); |
1394 | # CHECK: y[i, i_1, i_2] = sqrt(x);)IR" ); |
1395 | } |
1396 | |
1397 | // Split a Compute then inline it into another compute. |
1398 | TEST(LoopNest, ScheduleSplitAThenInline) { |
1399 | Tensor a = Compute("a" , {18}, [&](const VarHandle& i) { return i * i; }); |
1400 | Tensor b = Compute( |
1401 | "b" , {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
1402 | |
1403 | LoopNest l({b}, {a, b}); |
1404 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
1405 | LoopNest::splitWithMask(loops[0], 4); |
1406 | ASSERT_FALSE(l.computeInline(a.buf())); |
1407 | } |
1408 | |
1409 | // Split a Compute then inline another Compute into it. |
1410 | TEST(LoopNest, ScheduleSplitBThenInline) { |
1411 | Tensor a = Compute("a" , {18}, [&](const VarHandle& i) { return i * i; }); |
1412 | Tensor b = Compute( |
1413 | "b" , {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
1414 | |
1415 | LoopNest l({b}, {a, b}); |
1416 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); |
1417 | LoopNest::splitWithMask(loops[0], 3); |
1418 | l.computeInline(a.buf()); |
1419 | l.prepareForCodegen(); |
1420 | StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
1421 | |
1422 | std::vector<int> output(6, 0); |
1423 | SimpleIREvaluator eval(s, {b}); |
1424 | eval(output); |
1425 | |
1426 | for (int i = 0; i < 6; ++i) { |
1427 | ASSERT_EQ(output[i], (i + 8) * (i + 8)); |
1428 | } |
1429 | } |
1430 | |
1431 | // Split a Compute twice then inline it. |
1432 | TEST(LoopNest, ScheduleSplitTwiceThenInline) { |
1433 | Tensor a = Compute("a" , {18}, [&](const VarHandle& i) { return i * i; }); |
1434 | Tensor b = Compute( |
1435 | "b" , {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
1436 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1437 | ForPtr i_inner; |
1438 | |
1439 | LoopNest l({b}, {a, b}); |
1440 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
1441 | LoopNest::splitWithMask(loops[0], 4, &i_inner); |
1442 | LoopNest::splitWithMask(i_inner, 2); |
1443 | ASSERT_FALSE(l.computeInline(a.buf())); |
1444 | } |
1445 | |
1446 | // Inline a Compute, then split. |
1447 | TEST(LoopNest, ScheduleInlineThenSplit) { |
1448 | Tensor a = Compute("a" , {18}, [&](const VarHandle& i) { return i * i; }); |
1449 | Tensor b = Compute( |
1450 | "b" , {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
1451 | |
1452 | LoopNest l({b}, {a, b}); |
1453 | l.computeInline(a.buf()); |
1454 | |
1455 | std::vector<ForPtr> loops = NodeFinder<For>::find(l.root_stmt()); |
1456 | LoopNest::splitWithMask(loops.back(), 3); |
1457 | l.prepareForCodegen(); |
1458 | StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
1459 | std::vector<int> output(6, 0); |
1460 | SimpleIREvaluator eval(s, {b}); |
1461 | eval(output); |
1462 | |
1463 | for (int i = 0; i < 6; ++i) { |
1464 | ASSERT_EQ(output[i], (i + 8) * (i + 8)); |
1465 | } |
1466 | } |
1467 | |
1468 | // Split a Compute, inline it, then split the result. |
1469 | TEST(LoopNest, ScheduleSplitInlineThenSplit) { |
1470 | Tensor a = Compute("a" , {18}, [&](const VarHandle& i) { return i * i; }); |
1471 | Tensor b = Compute( |
1472 | "b" , {16}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
1473 | |
1474 | LoopNest l({b}, {a, b}); |
1475 | auto loops = NodeFinder<For>::find(l.root_stmt()); |
1476 | LoopNest::splitWithMask(loops.back(), 2); |
1477 | l.computeInline(a.buf()); |
1478 | |
1479 | loops = NodeFinder<For>::find(l.root_stmt()); |
1480 | LoopNest::splitWithMask(loops.front(), 2); |
1481 | l.prepareForCodegen(); |
1482 | StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
1483 | std::vector<int> output(16, 0); |
1484 | SimpleIREvaluator eval(s, {b}); |
1485 | eval(output); |
1486 | |
1487 | for (int i = 0; i < 16; ++i) { |
1488 | ASSERT_EQ(output[i], (i + 8) * (i + 8)); |
1489 | } |
1490 | } |
1491 | |
1492 | // Oversplit a loop that is simplified out after inlining. |
1493 | TEST(LoopNest, ScheduleSplitInlineSimplify) { |
1494 | Tensor a = Compute("a" , {18}, [&](const VarHandle& i) { |
1495 | return ExprHandle(4) * i - ExprHandle(2) * i; |
1496 | }); |
1497 | Tensor b = Compute( |
1498 | "b" , {2}, [&](const VarHandle& j) { return a.load(j) - ExprHandle(1); }); |
1499 | |
1500 | LoopNest l({b}, {a, b}); |
1501 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
1502 | LoopNest::splitWithMask(loops[0], 4); |
1503 | ASSERT_FALSE(l.computeInline(a.buf())); |
1504 | } |
1505 | |
1506 | // Inline a Compute with two consumers. |
1507 | TEST(LoopNest, ScheduleInlineThreeMixedOnce) { |
1508 | Tensor a = Compute("a" , {18}, [&](const VarHandle& i) { return i * i; }); |
1509 | Tensor b = Compute( |
1510 | "b" , {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
1511 | Tensor c = Compute("c" , {4, 3}, [&](const VarHandle& k, const VarHandle& l) { |
1512 | return a.load(k) * b.load(l); |
1513 | }); |
1514 | |
1515 | LoopNest l({c}, {a, b, c}); |
1516 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
1517 | l.computeInline(a.buf()); |
1518 | l.prepareForCodegen(); |
1519 | |
1520 | StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
1521 | std::vector<int> output(4 * 3, 0); |
1522 | SimpleIREvaluator eval(s, {c}); |
1523 | eval(output); |
1524 | |
1525 | for (int k = 0; k < 4; ++k) { |
1526 | for (int l = 0; l < 3; ++l) { |
1527 | ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); |
1528 | } |
1529 | } |
1530 | } |
1531 | |
1532 | // Inline Compute A into B, then inline B into C. |
1533 | TEST(LoopNest, ScheduleInlineThreeMixedTwice) { |
1534 | Tensor a = Compute("a" , {18}, [&](const VarHandle& i) { return i * i; }); |
1535 | Tensor b = Compute( |
1536 | "b" , {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
1537 | Tensor c = Compute("c" , {4, 3}, [&](const VarHandle& k, const VarHandle& l) { |
1538 | return a.load(k) * b.load(l); |
1539 | }); |
1540 | |
1541 | LoopNest l({c}, {a, b, c}); |
1542 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
1543 | l.computeInline(a.buf()); |
1544 | l.computeInline(b.buf()); |
1545 | l.prepareForCodegen(); |
1546 | |
1547 | StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
1548 | std::vector<int> output(4 * 3, 0); |
1549 | SimpleIREvaluator eval(s, {c}); |
1550 | eval(output); |
1551 | |
1552 | for (int k = 0; k < 4; ++k) { |
1553 | for (int l = 0; l < 3; ++l) { |
1554 | ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); |
1555 | } |
1556 | } |
1557 | } |
1558 | |
1559 | // Inline a Compute that is both a producer and consumer. |
1560 | TEST(LoopNest, ScheduleInlineThreeMixedInner) { |
1561 | Tensor a = Compute("a" , {18}, [&](const VarHandle& i) { return i * i; }); |
1562 | Tensor b = Compute( |
1563 | "b" , {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
1564 | Tensor c = Compute("c" , {4, 3}, [&](const VarHandle& k, const VarHandle& l) { |
1565 | return a.load(k) * b.load(l); |
1566 | }); |
1567 | |
1568 | LoopNest l({c}, {a, b, c}); |
1569 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
1570 | l.computeInline(b.buf()); |
1571 | l.prepareForCodegen(); |
1572 | |
1573 | StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
1574 | std::vector<int> output(4 * 3, 0); |
1575 | SimpleIREvaluator eval(s, {c}); |
1576 | eval(output); |
1577 | |
1578 | for (int k = 0; k < 4; ++k) { |
1579 | for (int l = 0; l < 3; ++l) { |
1580 | ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); |
1581 | } |
1582 | } |
1583 | } |
1584 | |
1585 | // Split 3 Computes, then inline the first two into the last. |
1586 | TEST(LoopNest, ScheduleInlineThreeMixedSplit) { |
1587 | Tensor a = Compute("a" , {18}, [&](const VarHandle& i) { return i * i; }); |
1588 | Tensor b = Compute( |
1589 | "b" , {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
1590 | Tensor c = Compute("c" , {4, 3}, [&](const VarHandle& k, const VarHandle& l) { |
1591 | return a.load(k) * b.load(l); |
1592 | }); |
1593 | |
1594 | LoopNest l({c}, {a, b, c}); |
1595 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
1596 | LoopNest::splitWithMask(loops[0], 4); |
1597 | loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); |
1598 | LoopNest::splitWithMask(loops[0], 3); |
1599 | loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
1600 | LoopNest::splitWithMask(loops[0], 2); |
1601 | |
1602 | ASSERT_FALSE(l.computeInline(a.buf())); |
1603 | } |
1604 | |
1605 | // Check that inlining works for output tensors too |
1606 | TEST(LoopNest, ScheduleInlineOutputTensors) { |
1607 | const int M = 4; |
1608 | const int N = 5; |
1609 | const int K = 6; |
1610 | |
1611 | Tensor x = Compute( |
1612 | "x" , |
1613 | {M, N, K}, |
1614 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1615 | return m * n * k; |
1616 | }); |
1617 | Tensor y = Compute( |
1618 | "y" , |
1619 | {M, N, K}, |
1620 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1621 | return x.load(m, n, k) + m; |
1622 | }); |
1623 | |
1624 | LoopNest l1({x, y}); |
1625 | l1.computeInline(x.buf()); |
1626 | |
1627 | // would normally compare results but Rand isn't implemented in the |
1628 | // SimpleIREvaluator, even if we could seed it. |
1629 | StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
1630 | |
1631 | // Check the IR we produced |
1632 | checkIR(stmt1, R"IR( |
1633 | # CHECK: for (int i = 0; i < 4; i++) |
1634 | # CHECK: for (int i_1 = 0; i_1 < 5; i_1++) |
1635 | # CHECK: for (int i_2 = 0; i_2 < 6; i_2++) |
1636 | # CHECK: x[i, i_1, i_2] = (i * i_1) * i_2; |
1637 | # CHECK: for (int i_3 = 0; i_3 < 4; i_3++) |
1638 | # CHECK: for (int i_4 = 0; i_4 < 5; i_4++) |
1639 | # CHECK: for (int i_5 = 0; i_5 < 6; i_5++) |
1640 | # CHECK: y[i_3, i_4, i_5] = i_3 + (i_3 * i_4) * i_5;)IR" ); |
1641 | } |
1642 | |
1643 | TEST(LoopNest, ScheduleInlineWithCompoundIndices) { |
1644 | // Input IR: |
1645 | // for (int64_t i = 0; i < 100; i++) { |
1646 | // A[i*2,i] = i * 500ll; |
1647 | // } |
1648 | // for (int64_t j = 0; j < 100; j++) { |
1649 | // B[0ll,j] = A[0, j] + j * 100ll; |
1650 | // } |
1651 | BufHandle a_buf("A" , {20, 100}, kLong); |
1652 | BufHandle b_buf("B" , {20, 100}, kLong); |
1653 | VarHandle i("i" , kLong); |
1654 | VarHandle j("j" , kLong); |
1655 | auto forI = For::make( |
1656 | i, |
1657 | 0, |
1658 | 100, |
1659 | Store::make(a_buf, {i * 2, i}, Mul::make(i, static_cast<int64_t>(500)))); |
1660 | auto forJ = For::make( |
1661 | j, |
1662 | 0, |
1663 | 100, |
1664 | Store::make( |
1665 | b_buf, |
1666 | {static_cast<int64_t>(0), j}, |
1667 | Add::make( |
1668 | Load::make(a_buf, {static_cast<int64_t>(0), j}), |
1669 | Mul::make(j, static_cast<int64_t>(100))))); |
1670 | auto par = Block::make({forI, forJ}); |
1671 | |
1672 | LoopNest l(par, {b_buf.node()}); |
1673 | // Inlining should fail since the producer has compound expr as index. |
1674 | ASSERT_FALSE(l.computeInline(a_buf.node())); |
1675 | |
1676 | // The input statement must remain as is. |
1677 | checkIR(l.root_stmt(), R"IR( |
1678 | # CHECK: for (int64_t i = 0; |
1679 | # CHECK-NEXT: A[ |
1680 | # CHECK: for (int64_t j = 0; |
1681 | # CHECK-NEXT: B[)IR" ); |
1682 | } |
1683 | |
1684 | TEST(LoopNest, ScheduleInlineConsumerIndicesWithCast) { |
1685 | // Input IR: |
1686 | // for (int64_t i = 0; i < 100; i++) { |
1687 | // A[0ll,i] = i * 500ll; |
1688 | // } |
1689 | // for (int64_t j = 0; j < 100; j++) { |
1690 | // B[0ll,j] = A[(int64_t)0, j] + j * 100ll; |
1691 | // } |
1692 | BufHandle a_buf("A" , {20, 100}, kLong); |
1693 | BufHandle b_buf("B" , {20, 100}, kLong); |
1694 | VarHandle i("i" , kLong); |
1695 | VarHandle j("j" , kLong); |
1696 | auto forI = For::make( |
1697 | i, |
1698 | 0, |
1699 | 100, |
1700 | Store::make( |
1701 | a_buf, |
1702 | {static_cast<int64_t>(0), i}, |
1703 | Mul::make(i, static_cast<int64_t>(500)))); |
1704 | auto forJ = For::make( |
1705 | j, |
1706 | 0, |
1707 | 100, |
1708 | Store::make( |
1709 | b_buf, |
1710 | {static_cast<int64_t>(0), j}, |
1711 | Add::make( |
1712 | Load::make(a_buf, {0, j}), |
1713 | Mul::make(j, static_cast<int64_t>(100))))); |
1714 | auto par = Block::make({forI, forJ}); |
1715 | |
1716 | LoopNest l(par, {b_buf.node()}); |
1717 | ASSERT_TRUE(l.computeInline(a_buf.node())); |
1718 | |
1719 | checkIR(l.root_stmt(), R"IR( |
1720 | # CHECK: for (int64_t j = 0; j < 100; j++) { |
1721 | # CHECK: B[0ll, j] = j * 500ll + j * 100ll; |
1722 | # CHECK: })IR" ); |
1723 | } |
1724 | |
1725 | TEST(LoopNest, ScheduleInlineProducerIndicesWithCast) { |
1726 | // Input IR: |
1727 | // for (int64_t i = 0; i < 100; i++) { |
1728 | // A[(int64_t)0,i] = i * 500ll; |
1729 | // } |
1730 | // for (int64_t j = 0; j < 100; j++) { |
1731 | // B[0ll,j] = A[0ll, j] + j * 100ll; |
1732 | // } |
1733 | BufHandle a_buf("A" , {20, 100}, kLong); |
1734 | BufHandle b_buf("B" , {20, 100}, kLong); |
1735 | VarHandle i("i" , kLong); |
1736 | VarHandle j("j" , kLong); |
1737 | auto forI = For::make( |
1738 | i, |
1739 | 0, |
1740 | 100, |
1741 | Store::make(a_buf, {0, i}, Mul::make(i, static_cast<int64_t>(500)))); |
1742 | auto forJ = For::make( |
1743 | j, |
1744 | 0, |
1745 | 100, |
1746 | Store::make( |
1747 | b_buf, |
1748 | {static_cast<int64_t>(0), j}, |
1749 | Add::make( |
1750 | Load::make(a_buf, {static_cast<int64_t>(0), j}), |
1751 | Mul::make(j, static_cast<int64_t>(100))))); |
1752 | auto par = Block::make({forI, forJ}); |
1753 | |
1754 | LoopNest l(par, {b_buf.node()}); |
1755 | ASSERT_TRUE(l.computeInline(a_buf.node())); |
1756 | |
1757 | checkIR(l.root_stmt(), R"IR( |
1758 | # CHECK: for (int64_t j = 0; j < 100; j++) { |
1759 | # CHECK: B[0ll, j] = j * 500ll + j * 100ll; |
1760 | # CHECK: })IR" ); |
1761 | } |
1762 | |
1763 | TEST(LoopNest, ScheduleFuserStyle) { |
1764 | const int kVectorSize = 8; |
1765 | const int kVectorCount = 128; |
1766 | const int kTotalSize = kVectorSize * kVectorCount; |
1767 | |
1768 | BufHandle a_buf("A" , {ExprHandle(kTotalSize)}, kFloat); |
1769 | |
1770 | Tensor b = |
1771 | Compute("f" , {kTotalSize}, [&](const std::vector<VarHandle>& axes) { |
1772 | return a_buf.load(axes[0]) + 11.0f; |
1773 | }); |
1774 | |
1775 | Tensor c = |
1776 | Compute("g" , {kTotalSize}, [&](const std::vector<VarHandle>& axes) { |
1777 | return b.load(axes[0]) + 1.0f; |
1778 | }); |
1779 | |
1780 | LoopNest l({b, c}); |
1781 | l.prepareForCodegen(); |
1782 | StmtPtr s = l.root_stmt(); |
1783 | |
1784 | std::vector<float> a_data(kTotalSize, 7.0f); |
1785 | std::vector<float> b_data(kTotalSize, 0.0f); |
1786 | std::vector<float> c_data(kTotalSize, 0.0f); |
1787 | SimpleIREvaluator(s, {a_buf, b, c})(a_data, b_data, c_data); |
1788 | |
1789 | for (int i = 0; i < kTotalSize; i++) { |
1790 | ASSERT_EQ(b_data[i], 18.0f); |
1791 | ASSERT_EQ(c_data[i], 19.0f); |
1792 | } |
1793 | } |
1794 | |
1795 | TEST(LoopNest, ScheduleFuserThreeArg) { |
1796 | const int kVectorSize = 8; |
1797 | const int kVectorCount = 128; |
1798 | const int kTotalSize = kVectorSize * kVectorCount; |
1799 | |
1800 | BufHandle a("A" , {ExprHandle(kTotalSize)}, kFloat); |
1801 | BufHandle b("B" , {ExprHandle(kTotalSize)}, kFloat); |
1802 | BufHandle c("C" , {ExprHandle(kTotalSize)}, kFloat); |
1803 | BufHandle d("D" , {ExprHandle(kTotalSize)}, kFloat); |
1804 | |
1805 | Tensor e = Compute("e" , {kTotalSize}, [&](const VarHandle& i) { |
1806 | return a.load(i) + b.load(i); |
1807 | }); |
1808 | Tensor f = Compute("f" , {kTotalSize}, [&](const VarHandle& i) { |
1809 | return e.load(i) + c.load(i); |
1810 | }); |
1811 | Tensor g = Compute("g" , {kTotalSize}, [&](const VarHandle& i) { |
1812 | return f.load(i) + d.load(i); |
1813 | }); |
1814 | |
1815 | LoopNest l({g}, {e, f, g}); |
1816 | l.computeInline(l.getLoopBodyFor(e)); |
1817 | l.computeInline(l.getLoopBodyFor(f)); |
1818 | l.prepareForCodegen(); |
1819 | StmtPtr s = l.root_stmt(); |
1820 | |
1821 | std::vector<float> a_data(kTotalSize, 1.0f); |
1822 | std::vector<float> b_data(kTotalSize, 2.0f); |
1823 | std::vector<float> c_data(kTotalSize, 3.0f); |
1824 | std::vector<float> d_data(kTotalSize, 4.0f); |
1825 | std::vector<float> g_data(kTotalSize, 0.0f); |
1826 | SimpleIREvaluator(s, {a, b, c, d, g})(a_data, b_data, c_data, d_data, g_data); |
1827 | |
1828 | for (int i = 0; i < kTotalSize; i++) { |
1829 | ASSERT_EQ(g_data[i], 10.0f); |
1830 | } |
1831 | } |
1832 | |
1833 | TEST(LoopNest, ScheduleDynamicShape2D) { |
1834 | auto testWithSize = [](int32_t M, int32_t N) { |
1835 | VarHandle m("m" , kInt); |
1836 | VarHandle n("n" , kInt); |
1837 | BufHandle a("a" , {m, n}, kFloat); |
1838 | BufHandle b("b" , {m, n}, kFloat); |
1839 | Tensor c = |
1840 | Compute("c" , {m, n}, [&](const VarHandle& i, const VarHandle& j) { |
1841 | return a.load(i, j) + b.load(i, j); |
1842 | }); |
1843 | LoopNest l({c}); |
1844 | StmtPtr s = l.root_stmt(); |
1845 | SimpleIREvaluator cg(s, {a, b, c, m, n}); |
1846 | std::vector<float> aData(M * N, 1.0f); |
1847 | std::vector<float> bData(M * N, 2.0f); |
1848 | std::vector<float> cData(M * N, 0.0f); |
1849 | cg.call({aData, bData, cData, M, N}); |
1850 | ExpectAllNear(cData, std::vector<float>(M * N, 3.0f), 1e-7); |
1851 | }; |
1852 | testWithSize(1, 8); |
1853 | testWithSize(16, 32); |
1854 | testWithSize(37, 11); |
1855 | } |
1856 | |
1857 | TEST(LoopNest, LoopNestComputeAt_1) { |
1858 | // Verify that compute_at works on the following example: |
1859 | // |
1860 | // for (int i_a = 0; i_a < N; i_a++) { |
1861 | // A[i_a] = i_a * i_a |
1862 | // } |
1863 | // for (int i_b = 0; i_b < N; i_b++) { |
1864 | // B[i_b] = A[i_b] |
1865 | // } |
1866 | // |
1867 | // After the transformation the i_b loop should have an allocation for a temp |
1868 | // buffer and that buffer should be used in computation of B. No use of A |
1869 | // should be in that loop after the transformation. Also, computation of A |
1870 | // should not be inlined into B. Instead, it should be computed into the temp, |
1871 | // and the temp should be used in B. |
1872 | VarHandle N("N" , kInt); |
1873 | Tensor A = Compute("A" , {N}, [&](const VarHandle& i_a) { return i_a * i_a; }); |
1874 | Tensor B = |
1875 | Compute("B" , {N}, [&](const VarHandle& i_b) { return A.load(i_b); }); |
1876 | LoopNest l({B}, {A, B}); |
1877 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); |
1878 | LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); |
1879 | l.prepareForCodegen(); |
1880 | SimpleIREvaluator cg(l.root_stmt(), {B, N}); |
1881 | StmtPtr s = cg.stmt(); |
1882 | |
1883 | checkIR(s, R"IR( |
1884 | # CHECK: Allocate(temp); // dtype=int, dims=[1] |
1885 | # CHECK: for (int i = 0; i < N; i++) |
1886 | # CHECK: temp[ |
1887 | # CHECK-NOT: A[ |
1888 | # CHECK: B[i_1] = temp[0] |
1889 | # CHECK: Free(temp))IR" ); |
1890 | |
1891 | // Now check that the loop still produces the correct result. |
1892 | std::vector<int> b_data(100, 0); |
1893 | cg.call({b_data, 100}); |
1894 | |
1895 | std::vector<int> b_ref(100, 0); |
1896 | for (int i = 0; i < 100; i++) { |
1897 | b_ref[i] = i * i; |
1898 | } |
1899 | assertAllEqual(b_data, b_ref); |
1900 | } |
1901 | |
1902 | TEST(LoopNest, LoopNestComputeAt_2) { |
1903 | // Verify that compute_at works on the following example: |
1904 | // |
1905 | // for (int py = 0; py < H+1; py++) { |
1906 | // for (int px = 0; px < W+1; px++) { |
1907 | // p[py, px] = py*px |
1908 | // } |
1909 | // } |
1910 | // for (int cy = 0; cy < H; cy++) { |
1911 | // for (int cx = 0; cx < W; cx++) { |
1912 | // c[py, px] = p[cy,cx] + p[cy+1,cx] + |
1913 | // p[cy,cx+1] + p[cy+1,cx+1] |
1914 | // } |
1915 | // } |
1916 | |
1917 | const int kW = 16, kH = 16; |
1918 | VarHandle W("W" , kInt); |
1919 | VarHandle H("H" , kInt); |
1920 | Tensor p = Compute( |
1921 | "prod" , {H + 1, W + 1}, [&](const VarHandle& py, const VarHandle& px) { |
1922 | return px * py; |
1923 | }); |
1924 | Tensor c = |
1925 | Compute("cons" , {H, W}, [&](const VarHandle& y, const VarHandle& x) { |
1926 | return p.load(y, x) + p.load(y + 1, x) + p.load(y, x + 1) + |
1927 | p.load(y + 1, x + 1); |
1928 | }); |
1929 | |
1930 | std::vector<int> c_ref(kW * kH, 0); |
1931 | for (int y = 0; y < kH; y++) { |
1932 | for (int x = 0; x < kW; x++) { |
1933 | c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); |
1934 | } |
1935 | } |
1936 | LoopNest orig_loopnest({c}, {p, c}); |
1937 | |
1938 | { |
1939 | // First let's try to compute P at axis cy (the outer loop) |
1940 | LoopNest l(orig_loopnest); |
1941 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
1942 | LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); |
1943 | l.prepareForCodegen(); |
1944 | SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); |
1945 | StmtPtr s = cg.stmt(); |
1946 | |
1947 | // Check the IR we produced |
1948 | checkIR(s, R"IR( |
1949 | # CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] |
1950 | # CHECK: for (int i_2 = 0; i_2 < H; i_2++) |
1951 | # CHECK: for |
1952 | # CHECK: for |
1953 | # CHECK: for (int i_3 = 0; i_3 < W; i_3++) |
1954 | # CHECK-NOT: prod[ |
1955 | # CHECK: cons[ |
1956 | # CHECK: Free(temp))IR" ); |
1957 | |
1958 | // Now check that the loop still produces the correct result. |
1959 | std::vector<int> c_data(kW * kH, 0); |
1960 | cg.call({c_data, kW, kH}); |
1961 | |
1962 | assertAllEqual(c_data, c_ref); |
1963 | } |
1964 | { |
1965 | // Now let's try to compute P at axis cx (the inner loop) |
1966 | LoopNest l(orig_loopnest); |
1967 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
1968 | LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); |
1969 | l.prepareForCodegen(); |
1970 | SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); |
1971 | StmtPtr s = cg.stmt(); |
1972 | |
1973 | // Check the IR we produced |
1974 | checkIR(s, R"IR( |
1975 | # CHECK: Allocate(temp); // dtype=int, dims=[2, 2] |
1976 | # CHECK: for (int i_2 = 0; i_2 < H; i_2++) |
1977 | # CHECK: for (int i_3 = 0; i_3 < W; i_3++) |
1978 | # CHECK: for |
1979 | # CHECK: for |
1980 | # CHECK-NOT: prod[ |
1981 | # CHECK: cons[ |
1982 | # CHECK: Free(temp))IR" ); |
1983 | |
1984 | // Now check that the loop still produces the correct result. |
1985 | std::vector<int> c_data(kW * kH, 0); |
1986 | cg.call({c_data, kW, kH}); |
1987 | |
1988 | assertAllEqual(c_data, c_ref); |
1989 | } |
1990 | } |
1991 | |
1992 | TEST(LoopNest, LoopNestComputeAt_3) { |
1993 | // Verify that compute_at works on the following example: |
1994 | // |
1995 | // A(x,y) = x*y |
1996 | // B(x,y) = A(x, y) |
1997 | // C(x,y) = B(x+1, y) |
1998 | // D(x,y) = A(x, y+1) + C(x, y) |
1999 | // |
2000 | // i.e. when 'A' comes to 'D' directly and indirectly through 'C'. |
2001 | |
2002 | const int kW = 16, kH = 16; |
2003 | VarHandle W("W" , kInt); |
2004 | VarHandle H("H" , kInt); |
2005 | Tensor A = Compute( |
2006 | "A" , {H + 1, W + 1}, [&](const VarHandle& ay, const VarHandle& ax) { |
2007 | return ax * ay; |
2008 | }); |
2009 | Tensor B = Compute( |
2010 | "B" , {H + 1, W + 1}, [&](const VarHandle& by, const VarHandle& bx) { |
2011 | return A.load(by, bx); |
2012 | }); |
2013 | Tensor C = |
2014 | Compute("C" , {H, W}, [&](const VarHandle& cy, const VarHandle& cx) { |
2015 | return B.load(cy, cx + 1); |
2016 | }); |
2017 | Tensor D = |
2018 | Compute("D" , {H, W}, [&](const VarHandle& dy, const VarHandle& dx) { |
2019 | return A.load(dy + 1, dx) + C.load(dy, dx); |
2020 | }); |
2021 | |
2022 | std::vector<int> c_ref(kW * kH, 0); |
2023 | for (int y = 0; y < kH; y++) { |
2024 | for (int x = 0; x < kW; x++) { |
2025 | c_ref[y * kW + x] = (y + 1) * x + y * (x + 1); |
2026 | } |
2027 | } |
2028 | |
2029 | LoopNest orig_loopnest({D}, {A, B, C, D}); |
2030 | { |
2031 | // First let's try to compute A at axis dy (the outer loop) |
2032 | LoopNest l(orig_loopnest); |
2033 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); |
2034 | LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); |
2035 | l.prepareForCodegen(); |
2036 | SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); |
2037 | StmtPtr s = cg.stmt(); |
2038 | |
2039 | // Check the IR we produced |
2040 | checkIR(s, R"IR( |
2041 | # CHECK: Allocate(temp); // dtype=int, dims=[1, W] |
2042 | # CHECK: for (int i = 0; i < H + 1; i++) |
2043 | # CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) |
2044 | # CHECK: A[ |
2045 | # CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) |
2046 | # CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) |
2047 | # CHECK: B[ |
2048 | # CHECK: for (int i_4 = 0; i_4 < H; i_4++) |
2049 | # CHECK: for (int i_5 = 0; i_5 < W; i_5++) |
2050 | # CHECK: C[ |
2051 | # CHECK: for (int i_6 = 0; i_6 < H; i_6++) |
2052 | # CHECK: for (int i_7 = 0; i_7 < W; i_7++) |
2053 | # CHECK-NOT: A[)IR" ); |
2054 | |
2055 | // Now check that the loop still produces the correct result. |
2056 | std::vector<int> c_data(kW * kH, 0); |
2057 | cg.call({c_data, kW, kH}); |
2058 | |
2059 | assertAllEqual(c_data, c_ref); |
2060 | } |
2061 | { |
2062 | // Now let's try to compute A at axis dx (the inner loop) |
2063 | LoopNest l(orig_loopnest); |
2064 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); |
2065 | LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]); |
2066 | l.prepareForCodegen(); |
2067 | SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); |
2068 | StmtPtr s = cg.stmt(); |
2069 | |
2070 | // Check the IR we produced |
2071 | checkIR(s, R"IR( |
2072 | # CHECK: Allocate(temp); // dtype=int, dims=[1, 1] |
2073 | # CHECK: for (int i = 0; i < H + 1; i++) |
2074 | # CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) |
2075 | # CHECK: A[ |
2076 | # CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) |
2077 | # CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) |
2078 | # CHECK: B[ |
2079 | # CHECK: for (int i_4 = 0; i_4 < H; i_4++) |
2080 | # CHECK: for (int i_5 = 0; i_5 < W; i_5++) |
2081 | # CHECK: C[ |
2082 | # CHECK: for (int i_6 = 0; i_6 < H; i_6++) |
2083 | # CHECK: for (int i_7 = 0; i_7 < W; i_7++) |
2084 | # CHECK-NOT: A[)IR" ); |
2085 | |
2086 | // Now check that the loop still produces the correct result. |
2087 | std::vector<int> c_data(kW * kH, 0); |
2088 | cg.call({c_data, kW, kH}); |
2089 | |
2090 | assertAllEqual(c_data, c_ref); |
2091 | } |
2092 | } |
2093 | |
2094 | using Axis = const VarHandle&; |
2095 | |
2096 | TEST(LoopNest, Reduce2dComputeAt) { |
2097 | const int kW = 16, kH = 16; |
2098 | VarHandle W("W" , kInt); |
2099 | VarHandle H("H" , kInt); |
2100 | |
2101 | Tensor p = Compute( |
2102 | "prod" , {H + 1, W + 1}, [&](Axis py, Axis px) { return px * py; }); |
2103 | Tensor c = Reduce( |
2104 | "cons" , |
2105 | {H, W}, |
2106 | Sum(), |
2107 | [&](Axis y, Axis x, Axis r, Axis s) { return p.load(y + r, x + s); }, |
2108 | {2, 2}); |
2109 | |
2110 | std::vector<int> c_ref(kW * kH, 0); |
2111 | for (int y = 0; y < kH; y++) { |
2112 | for (int x = 0; x < kW; x++) { |
2113 | c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); |
2114 | } |
2115 | } |
2116 | LoopNest orig_loopnest({c}, {p, c}); |
2117 | checkIR(orig_loopnest.root_stmt(), R"IR( |
2118 | # CHECK: for (int i = 0; i < H + 1; i++) { |
2119 | # CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) { |
2120 | # CHECK: prod[i, i_1] = i_1 * i; |
2121 | # CHECK: } |
2122 | # CHECK: } |
2123 | # CHECK: for (int i_2 = 0; i_2 < H; i_2++) { |
2124 | # CHECK: for (int i_3 = 0; i_3 < W; i_3++) { |
2125 | # CHECK: cons[i_2, i_3] = int(0); |
2126 | # CHECK: for (int i_4 = 0; i_4 < 2; i_4++) { |
2127 | # CHECK: for (int i_5 = 0; i_5 < 2; i_5++) { |
2128 | # CHECK: cons[i_2, i_3] = ReduceOp((cons[i_2, i_3]) + (prod[i_2 + i_4, i_3 + i_5]), reduce_args={i_4, i_5}); |
2129 | # CHECK: } |
2130 | # CHECK: } |
2131 | # CHECK: } |
2132 | # CHECK: } |
2133 | )IR" ); |
2134 | |
2135 | { |
2136 | // First let's try to compute P at axis cy (the outer loop) |
2137 | LoopNest l(orig_loopnest); |
2138 | auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
2139 | LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); |
2140 | // FIXME: Calling simplify here breaks the IR: |
2141 | // MALFORMED INPUT: could not find base node in Load - temp[...] |
2142 | // l.simplify(); |
2143 | l.eliminateDeadStores(); |
2144 | l.prepareForCodegen(); |
2145 | SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); |
2146 | checkIR(cg.stmt(), R"IR( |
2147 | # CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] |
2148 | # CHECK: for (int i = 0; i < H; i++) { |
2149 | # CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { |
2150 | # CHECK: for (int idx1 = 0; idx1 < W + 1; idx1++) { |
2151 | # CHECK: temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + i) * (idx1 + 0); |
2152 | # CHECK: } |
2153 | # CHECK: } |
2154 | # CHECK: for (int i_1 = 0; i_1 < W; i_1++) { |
2155 | # CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = int(0); |
2156 | # CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { |
2157 | # CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { |
2158 | # CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * (W + 1))) + (i_1 + i_3) * 1]); |
2159 | # CHECK: } |
2160 | # CHECK: } |
2161 | # CHECK: } |
2162 | # CHECK: } |
2163 | # CHECK: Free(temp); |
2164 | )IR" ); |
2165 | |
2166 | // Now check that the loop still produces the correct result. |
2167 | std::vector<int> c_data(kW * kH, 0); |
2168 | cg.call({c_data, kW, kH}); |
2169 | assertAllEqual(c_data, c_ref); |
2170 | } |
2171 | { |
2172 | // Now let's try to compute P at axis cx (the inner loop) |
2173 | LoopNest l(orig_loopnest); |
2174 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
2175 | LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); |
2176 | l.simplify(); |
2177 | l.eliminateDeadStores(); |
2178 | l.prepareForCodegen(); |
2179 | SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); |
2180 | checkIR(cg.stmt(), R"IR( |
2181 | # CHECK: Allocate(temp); // dtype=int, dims=[2, 2] |
2182 | # CHECK: for (int i = 0; i < H; i++) { |
2183 | # CHECK: for (int i_1 = 0; i_1 < W; i_1++) { |
2184 | # CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { |
2185 | # CHECK: for (int idx1 = 0; idx1 < 2; idx1++) { |
2186 | # CHECK: temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (i + idx0) * (i_1 + idx1); |
2187 | # CHECK: } |
2188 | # CHECK: } |
2189 | # CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = 0; |
2190 | # CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { |
2191 | # CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { |
2192 | # CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * 2)) + i_3 * 1]); |
2193 | # CHECK: } |
2194 | # CHECK: } |
2195 | # CHECK: } |
2196 | # CHECK: } |
2197 | # CHECK: Free(temp); |
2198 | )IR" ); |
2199 | |
2200 | // Now check that the loop still produces the correct result. |
2201 | std::vector<int> c_data(kW * kH, 0); |
2202 | cg.call({c_data, kW, kH}); |
2203 | assertAllEqual(c_data, c_ref); |
2204 | } |
2205 | } |
2206 | |
// Disabled: intended to compute a padded 1-D input A at the outer axis of a
// width-3 sum-reduction B and compare the result against at::conv1d.
TEST(LoopNest, DISABLED_Conv1d_NH) {
  // Lots of stuff is broken here. The computeAt swaps the axes for some odd
  // reason. Even without that, the index flattener fails due to "dimensions
  // mismatch in flatten index".

  int N = 4;
  int H = 256;
  int R = 3;
  int Pad = 1;
  // NOTE(review): `IP` is declared with a single dimension {H}, but the lambda
  // below loads it with two indices (n, h - Pad). That rank mismatch may be
  // the source of the "dimensions mismatch in flatten index" failure described
  // above. Confirm before re-enabling.
  BufHandle IP("input", {H}, kFloat);

  // A is the input padded by `Pad` on each side along h. cond becomes 1 when
  // h < Pad or h >= H + Pad, and those positions read as 0.f.
  Tensor A = Compute("A", {N, H + 2 * Pad}, [&](Axis n, Axis h) {
    auto cond = CompareSelect::make(h, Pad, 1, 0, kLT);
    cond = CompareSelect::make(h, H + Pad, 1, cond, kGE);
    return ifThenElse(cond, 0.f, IP.load(n, h - Pad));
  });
  // B[n, h] = sum over r in [0, R) of A[n, h + r].
  Tensor B = Reduce(
      "B",
      {N, H},
      Sum(),
      [&](Axis n, Axis h, Axis r) { return A.load(n, h + r); },
      {R});
  LoopNest l({B});
  checkIR(l.root_stmt(), R"IR(
# CHECK: for (int np = 0; np < 4; np++) {
# CHECK: for (int hp = 0; hp < 258; hp++) {
# CHECK: A[np, hp] = IfThenElse(hp>=257 ? 1 : (hp<1 ? 1 : 0), 0.f, input[np, hp - 1]);
# CHECK: }
# CHECK: }
# CHECK: for (int n = 0; n < 4; n++) {
# CHECK: for (int h = 0; h < 256; h++) {
# CHECK: B[n, h] = float(0);
# CHECK: for (int r = 0; r < 3; r++) {
# CHECK: B[n, h] = ReduceOp((B[n, h]) + (A(n, h + r)), reduce_args={r});
# CHECK: }
# CHECK: }
# CHECK: }
)IR");
  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0);
  LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
  // FIXME: The current IR is totally broken. The body of the inlined loop is:

  // temp[idx0, idx1] = IfThenElse(idx0 + n>=257 ? 1 : (idx0 + n<1 ? 1 : 0),
  // 0.f, input[idx1 + 0, (idx0 + n) - 1]);

  // Which seems to mix up the axes. The CHECK below is my best guess at what
  // the input "should" look like

  // NOTE(review): the `temp[idx0, idx1] = ...` line in the pattern below has
  // no `# CHECK:` prefix, so FileCheck presumably ignores it. Add the prefix
  // if that store is meant to be verified once the test is re-enabled.
  checkIR(l.root_stmt(), R"IR(
# CHECK: for (int n = 0; n < 4; n++) {
# CHECK: for (int idx0 = 0; idx0 < 1; idx0++) {
# CHECK: for (int idx1 = 0; idx1 < 258; idx1++) {
temp[idx0, idx1] = IfThenElse(idx1>=257 ? 1 : (idx1<1 ? 1 : 0), 0.f, input[n, idx1 - 1]);
# CHECK: }
# CHECK: }
# CHECK: for (int h = 0; h < 256; h++) {
# CHECK: B[n, h] = float(0);
# CHECK: for (int r = 0; r < 3; r++) {
# CHECK: B[n, h] = ReduceOp((B[n, h]) + (temp[0, r + h]), reduce_args={r});
# CHECK: }
# CHECK: }
# CHECK: }
)IR");

  l.simplify();
  l.prepareForCodegen();
  StmtPtr s = l.root_stmt();

  SimpleIREvaluator cg(s, {IP, B});
  // auto At = at::ones({N, H}, at::kFloat);
  auto At = at::arange(N * H, at::kFloat).reshape({N, H});
  // NOTE(review): this reference passes padding=3 while the IR pads by `Pad`
  // (1). `At` is 2-D {N, H`} while the weight is 3-D {1, 1, 3}. Confirm the
  // intended reference computation before re-enabling.
  auto Rt = at::conv1d(
      At, at::ones({1, 1, 3}), at::Tensor(), /*stride=*/1, /*padding=*/3);
  auto Bt = at::empty_like(Rt);
  cg.call({At.data_ptr<float>(), Bt.data_ptr<float>()});
  ASSERT_TRUE(at::allclose(Rt, Bt));
}
2284 | |
2285 | class LoopOrderHelper : public IRVisitor { |
2286 | std::stringstream ordering; |
2287 | |
2288 | public: |
2289 | std::string getOrder(StmtPtr s) { |
2290 | ordering.str("" ); |
2291 | s->accept(this); |
2292 | return ordering.str(); |
2293 | } |
2294 | |
2295 | void visit(ForPtr v) final { |
2296 | ordering << v->var()->name_hint() << "," ; |
2297 | IRVisitor::visit(v); |
2298 | } |
2299 | }; |
2300 | |
2301 | TEST(LoopNest, LoopNestReorderAxis1) { |
2302 | Tensor tensor = |
2303 | Compute("f" , {2, 3}, [](const VarHandle& x, const VarHandle& y) { |
2304 | return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
2305 | }); |
2306 | LoopNest l({tensor}); |
2307 | StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
2308 | |
2309 | std::vector<int> stmt1_output(6, 0); |
2310 | SimpleIREvaluator cg(stmt1, {tensor}); |
2311 | cg.call({stmt1_output}); |
2312 | |
2313 | auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
2314 | LoopNest::reorderAxis(loops[0], loops[1]); |
2315 | StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
2316 | |
2317 | ASSERT_NE(stmt1, stmt2); |
2318 | LoopOrderHelper loopOrderHelper; |
2319 | std::string order1 = loopOrderHelper.getOrder(stmt1); |
2320 | std::string order2 = loopOrderHelper.getOrder(stmt2); |
2321 | |
2322 | ASSERT_EQ(order1, "j,i," ); |
2323 | ASSERT_EQ(order2, "i,j," ); |
2324 | |
2325 | std::vector<int> stmt2_output(6, 0); |
2326 | SimpleIREvaluator cg2(stmt2, {tensor}); |
2327 | cg.call({stmt2_output}); |
2328 | |
2329 | for (int i = 0; i < 6; ++i) { |
2330 | ASSERT_EQ(stmt1_output[i], stmt2_output[i]); |
2331 | } |
2332 | |
2333 | // Reorder them back. |
2334 | loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
2335 | LoopNest::reorderAxis(loops[0], loops[1]); |
2336 | StmtPtr stmt3 = l.root_stmt(); |
2337 | |
2338 | std::string order3 = loopOrderHelper.getOrder(stmt3); |
2339 | ASSERT_EQ(order3, order1); |
2340 | |
2341 | std::ostringstream oss1, oss2; |
2342 | oss1 << *stmt1; |
2343 | oss2 << *stmt3; |
2344 | |
2345 | // Should be identical to the unreordered statement. |
2346 | ASSERT_EQ(oss1.str(), oss2.str()); |
2347 | } |
2348 | |
2349 | TEST(LoopNest, LoopNestReorderPartialAxes) { |
2350 | Tensor tensor = Compute( |
2351 | "f" , |
2352 | {2, 3, 4}, |
2353 | [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { |
2354 | return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y + |
2355 | cast<float>(z) * z; |
2356 | }); |
2357 | LoopNest l({tensor}); |
2358 | |
2359 | LoopOrderHelper loopOrderHelper; |
2360 | StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
2361 | ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k," ); |
2362 | |
2363 | std::vector<int> stmt1_output(24, 0); |
2364 | SimpleIREvaluator cg(stmt1, {tensor}); |
2365 | cg.call({stmt1_output}); |
2366 | |
2367 | auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
2368 | LoopNest::reorderAxis(loops[0], loops[1]); |
2369 | ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,i,k," ); |
2370 | |
2371 | StmtPtr stmt2 = Stmt::clone(l.root_stmt()); |
2372 | |
2373 | std::vector<int> stmt2_output(24, 0); |
2374 | SimpleIREvaluator cg2(stmt2, {tensor}); |
2375 | cg2.call({stmt2_output}); |
2376 | |
2377 | for (int i = 0; i < 24; ++i) { |
2378 | ASSERT_EQ(stmt1_output[i], stmt2_output[i]); |
2379 | } |
2380 | |
2381 | loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
2382 | LoopNest::reorderAxis(loops[1], loops[2]); |
2383 | ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,k,i," ); |
2384 | |
2385 | StmtPtr stmt3 = Stmt::clone(l.root_stmt()); |
2386 | |
2387 | std::vector<int> stmt3_output(24, 0); |
2388 | SimpleIREvaluator cg3(stmt3, {tensor}); |
2389 | cg3.call({stmt3_output}); |
2390 | |
2391 | for (int i = 0; i < 24; ++i) { |
2392 | ASSERT_EQ(stmt1_output[i], stmt3_output[i]); |
2393 | } |
2394 | } |
2395 | |
2396 | TEST(LoopNest, LoopNestReorderInternalAxis) { |
2397 | Tensor tensor = Compute( |
2398 | "f" , |
2399 | {1, 2, 3, 4}, |
2400 | [](const VarHandle& w, |
2401 | const VarHandle& x, |
2402 | const VarHandle& y, |
2403 | const VarHandle& z) { |
2404 | return ExprHandle(1.0f) + w + cast<float>(x) * x + cast<float>(y) * y + |
2405 | cast<float>(z) * z; |
2406 | }); |
2407 | LoopNest l({tensor}); |
2408 | |
2409 | LoopOrderHelper loopOrderHelper; |
2410 | StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
2411 | ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,l," ); |
2412 | |
2413 | std::vector<int> stmt1_output(24, 0); |
2414 | SimpleIREvaluator cg(stmt1, {tensor}); |
2415 | cg.call({stmt1_output}); |
2416 | |
2417 | auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
2418 | LoopNest::reorderAxis(loops[2], loops[1]); |
2419 | ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "i,k,j,l," ); |
2420 | |
2421 | StmtPtr stmt2 = l.root_stmt(); |
2422 | |
2423 | std::vector<int> stmt2_output(24, 0); |
2424 | SimpleIREvaluator cg2(stmt2, {tensor}); |
2425 | cg2.call({stmt2_output}); |
2426 | |
2427 | for (int i = 0; i < 24; ++i) { |
2428 | ASSERT_EQ(stmt1_output[i], stmt2_output[i]); |
2429 | } |
2430 | } |
2431 | |
2432 | TEST(LoopNest, LoopNestReorderEnclosingAxis) { |
2433 | Tensor tensor = Compute( |
2434 | "f" , |
2435 | {1, 2, 3, 4}, |
2436 | [](const VarHandle& w, |
2437 | const VarHandle& x, |
2438 | const VarHandle& y, |
2439 | const VarHandle& z) { |
2440 | return ExprHandle(1.0f) + w + cast<float>(x) * x + cast<float>(y) * y + |
2441 | cast<float>(z) * z; |
2442 | }); |
2443 | LoopNest l({tensor}); |
2444 | |
2445 | LoopOrderHelper loopOrderHelper; |
2446 | StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
2447 | |
2448 | std::vector<int> stmt1_output(24, 0); |
2449 | SimpleIREvaluator cg(stmt1, {tensor}); |
2450 | cg.call({stmt1_output}); |
2451 | |
2452 | auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
2453 | LoopNest::reorderAxis(loops[0], loops[3]); |
2454 | ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "l,j,k,i," ); |
2455 | |
2456 | StmtPtr stmt2 = l.root_stmt(); |
2457 | |
2458 | std::vector<int> stmt2_output(24, 0); |
2459 | SimpleIREvaluator cg2(stmt2, {tensor}); |
2460 | cg2.call({stmt2_output}); |
2461 | |
2462 | for (int i = 0; i < 24; ++i) { |
2463 | ASSERT_EQ(stmt1_output[i], stmt2_output[i]); |
2464 | } |
2465 | } |
2466 | |
2467 | TEST(LoopNest, LoopNestReorderSameAxis) { |
2468 | Tensor tensor = |
2469 | Compute("f" , {2, 3}, [](const VarHandle& x, const VarHandle& y) { |
2470 | return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
2471 | }); |
2472 | LoopNest l({tensor}); |
2473 | StmtPtr stmt1 = Stmt::clone(l.root_stmt()); |
2474 | |
2475 | auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
2476 | LoopNest::reorderAxis(loops[1], loops[1]); |
2477 | StmtPtr stmt2 = Stmt::clone(l.root_stmt()); |
2478 | |
2479 | std::ostringstream oss, oss2; |
2480 | oss << *stmt1; |
2481 | oss2 << *stmt2; |
2482 | ASSERT_EQ(oss.str(), oss2.str()); |
2483 | } |
2484 | |
2485 | TEST(LoopNest, LoopNestReorderExtraStatements) { |
2486 | /* We're going for a structure like this: |
2487 | * for i in ... |
2488 | * Stmt 1 |
2489 | * for j in ... |
2490 | * Stmt 2 |
2491 | * for k in ... |
2492 | * Stmt 3 |
2493 | * Stmt 4 |
2494 | */ |
2495 | |
2496 | Tensor tensor = Compute( |
2497 | "f" , |
2498 | {2, 3, 4}, |
2499 | [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { |
2500 | return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y + |
2501 | cast<float>(z) * z; |
2502 | }); |
2503 | LoopNest l({tensor}); |
2504 | |
2505 | BufHandle ("res" , {6, 3}, kFloat); |
2506 | |
2507 | auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
2508 | |
2509 | VarHandle i = VarHandle(loops[0]->var()); |
2510 | |
2511 | StmtPtr store_1 = Store::make(extra, {i, 0}, 1.f); |
2512 | StmtPtr store_2 = Store::make(extra, {i, 1}, 2.f); |
2513 | // stmt 3 is the Function body. |
2514 | StmtPtr store_3 = Store::make(extra, {i, 2}, 4.f); |
2515 | |
2516 | loops[0]->body()->prepend_stmt(store_1); |
2517 | loops[1]->body()->prepend_stmt(store_2); |
2518 | loops[1]->body()->append_stmt(store_3); |
2519 | StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
2520 | |
2521 | std::vector<int> (6, 0); |
2522 | std::vector<int> res1(24, 0); |
2523 | SimpleIREvaluator cg(stmt1, {tensor, extra}); |
2524 | cg.call({res1, extra1}); |
2525 | |
2526 | /* Then we reorder loop y and z, we want it to look like: |
2527 | * |
2528 | * for i in ... |
2529 | * Stmt 1 |
2530 | * for j in ... |
2531 | * Stmt 2 |
2532 | * for j_1 in ... |
2533 | * for k in ... |
2534 | * Stmt 3 |
2535 | * for j_2 in ... |
2536 | * Stmt 4 |
2537 | * |
2538 | * We need extra loops because we don't have dependency info about stmt 3 |
2539 | * and 4. |
2540 | * |
2541 | */ |
2542 | |
2543 | LoopNest::reorderAxis(loops[1], loops[2]); |
2544 | StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
2545 | |
2546 | // Check the IR we produced |
2547 | checkIR(stmt2, R"IR( |
2548 | # CHECK: for |
2549 | # CHECK: res[i, 0] = 1 |
2550 | # CHECK: for |
2551 | # CHECK: res[i, 1] = 2 |
2552 | # CHECK: for |
2553 | # CHECK: for |
2554 | # CHECK: f[ |
2555 | # CHECK: for |
2556 | # CHECK: res[i, 2] = 4 |
2557 | )IR" ); |
2558 | |
2559 | std::vector<int> (6, 0); |
2560 | std::vector<int> res2(24, 0); |
2561 | SimpleIREvaluator cg2(stmt2, {tensor, extra}); |
2562 | cg2.call({res2, extra2}); |
2563 | |
2564 | for (int i = 0; i < 24; ++i) { |
2565 | ASSERT_EQ(res1[i], res2[i]); |
2566 | } |
2567 | for (int i = 0; i < 6; ++i) { |
2568 | ASSERT_EQ(extra1[i], extra2[i]); |
2569 | } |
2570 | |
2571 | /* Now reorder x and the y above stmt 3: |
2572 | * |
2573 | * |
2574 | * for x in ... |
2575 | * Stmt 1 |
2576 | * for y in ... |
2577 | * Stmt 2 |
2578 | * |
2579 | * for y in ... |
2580 | * for z in ... |
2581 | * for x in ... |
2582 | * Stmt 3 |
2583 | * |
2584 | * for x in ... |
2585 | * for y in ... |
2586 | * Stmt 4 |
2587 | * |
2588 | * |
2589 | */ |
2590 | loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
2591 | LoopNest::reorderAxis(loops[0], loops[2]); |
2592 | StmtPtr stmt3 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
2593 | |
2594 | // Check the IR we produced |
2595 | checkIR(stmt3, R"IR( |
2596 | # CHECK: for |
2597 | # CHECK: res[i, 0] = 1 |
2598 | # CHECK: for |
2599 | # CHECK: res[i, 1] = 2 |
2600 | # CHECK: for |
2601 | # CHECK: for |
2602 | # CHECK: for |
2603 | # CHECK: f[ |
2604 | # CHECK: for |
2605 | # CHECK: for |
2606 | # CHECK: res[i_2, 2] = 4 |
2607 | )IR" ); |
2608 | |
2609 | std::vector<int> (6, 0); |
2610 | std::vector<int> res3(24, 0); |
2611 | SimpleIREvaluator cg3(stmt3, {tensor, extra}); |
2612 | cg3.call({res3, extra3}); |
2613 | |
2614 | for (int i = 0; i < 24; ++i) { |
2615 | ASSERT_EQ(res1[i], res3[i]); |
2616 | } |
2617 | for (int i = 0; i < 6; ++i) { |
2618 | ASSERT_EQ(extra1[i], extra3[i]); |
2619 | } |
2620 | } |
2621 | |
2622 | void LoopNestReorderTestHelper( |
2623 | bool prepend, |
2624 | bool append, |
2625 | int index1, |
2626 | int index2) { |
2627 | Tensor c = Compute( |
2628 | "5d" , {2, 3, 2, 3, 2}, [](const std::vector<VarHandle>&) { return -1; }); |
2629 | LoopNest l({c}); |
2630 | |
2631 | BufHandle ("extra" , {5}, kInt); |
2632 | |
2633 | auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
2634 | int j = 0; |
2635 | for (auto l : loops) { |
2636 | // Add an increment at each layer of the loop which counts the number of |
2637 | // times the loop executes. |
2638 | LoadPtr load = |
2639 | alloc<Load>(extra.node(), std::vector<ExprPtr>({alloc<IntImm>(j)})); |
2640 | AddPtr add = alloc<Add>(load, alloc<IntImm>(1)); |
2641 | StmtPtr store = alloc<Store>( |
2642 | extra.node(), std::vector<ExprPtr>({alloc<IntImm>(j)}), add); |
2643 | if (prepend) { |
2644 | l->body()->prepend_stmt(store); |
2645 | } |
2646 | if (append) { |
2647 | l->body()->append_stmt(Stmt::clone(store)); |
2648 | } |
2649 | |
2650 | j++; |
2651 | } |
2652 | |
2653 | StmtPtr stmt1 = Stmt::clone(l.root_stmt()); |
2654 | |
2655 | std::vector<int> (5, 0); |
2656 | std::vector<int> res1(2 * 3 * 2 * 3 * 2, 0); |
2657 | SimpleIREvaluator cg(stmt1, {c, extra}); |
2658 | cg.call({res1, extra1}); |
2659 | |
2660 | std::vector<int> loopExtents = {2, 3, 2, 3, 2}; |
2661 | |
2662 | int expected_loops = 0; |
2663 | if (prepend) { |
2664 | expected_loops++; |
2665 | } |
2666 | if (append) { |
2667 | expected_loops++; |
2668 | } |
2669 | for (int i = 0; i < 5; ++i) { |
2670 | expected_loops *= loopExtents[i]; |
2671 | ASSERT_EQ(extra1[i], expected_loops); |
2672 | } |
2673 | |
2674 | loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
2675 | LoopNest::reorderAxis(loops[index1], loops[index2]); |
2676 | StmtPtr stmt2 = Stmt::clone(l.root_stmt()); |
2677 | |
2678 | std::ostringstream oss, oss2; |
2679 | oss << *stmt1; |
2680 | oss2 << *stmt2; |
2681 | ASSERT_NE(oss.str(), oss2.str()); |
2682 | |
2683 | std::vector<int> (5, 0); |
2684 | std::vector<int> res2(2 * 3 * 2 * 3 * 2, 0); |
2685 | SimpleIREvaluator cg2(stmt2, {c, extra}); |
2686 | cg2.call({res2, extra2}); |
2687 | |
2688 | expected_loops = 0; |
2689 | if (prepend) { |
2690 | expected_loops++; |
2691 | } |
2692 | if (append) { |
2693 | expected_loops++; |
2694 | } |
2695 | |
2696 | for (int i = 0; i < 5; ++i) { |
2697 | expected_loops *= loopExtents[i]; |
2698 | ASSERT_EQ(extra2[i], expected_loops); |
2699 | } |
2700 | |
2701 | for (int i = 0; i < 2 * 3 * 2 * 3 * 2; ++i) { |
2702 | ASSERT_EQ(res2[i], res1[i]); |
2703 | } |
2704 | } |
2705 | |
2706 | TEST(LoopNest, LoopNestReorderLongStringOfPreOrphans) { |
2707 | for (int i = 0; i < 5; ++i) { |
2708 | for (int j = 0; j < 5; ++j) { |
2709 | // skip noops, since we check the loop isn't the same after reordering. |
2710 | if (i != j) { |
2711 | LoopNestReorderTestHelper(true, false, i, j); |
2712 | } |
2713 | } |
2714 | } |
2715 | } |
2716 | |
2717 | TEST(LoopNest, LoopNestReorderLongStringOfPostOrphans) { |
2718 | for (int i = 0; i < 5; ++i) { |
2719 | for (int j = 0; j < 5; ++j) { |
2720 | // skip noops, since we check the loop isn't the same after reordering. |
2721 | if (i != j) { |
2722 | LoopNestReorderTestHelper(false, true, i, j); |
2723 | } |
2724 | } |
2725 | } |
2726 | } |
2727 | |
2728 | TEST(LoopNest, LoopNestReorderLongStringFull) { |
2729 | for (int i = 0; i < 5; ++i) { |
2730 | for (int j = 0; j < 5; ++j) { |
2731 | // skip noops, since we check the loop isn't the same after reordering. |
2732 | if (i != j) { |
2733 | LoopNestReorderTestHelper(true, true, i, j); |
2734 | } |
2735 | } |
2736 | } |
2737 | } |
2738 | |
2739 | TEST(LoopNest, LoopNestReorderInternalLoopNest) { |
2740 | const int M = 4; |
2741 | const int N = 5; |
2742 | const int K = 6; |
2743 | BufHandle a_buf("a" , {M, N}, kFloat); |
2744 | BufHandle b_buf("b" , {N, K}, kFloat); |
2745 | BufHandle c_buf("c" , {M, N}, kFloat); |
2746 | BufHandle d_buf("d" , {M, K}, kFloat); |
2747 | |
2748 | Tensor x = Compute( |
2749 | "x" , |
2750 | {M, N, K}, |
2751 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
2752 | return a_buf.load(m, n) * b_buf.load(n, k); |
2753 | }); |
2754 | Tensor y = Compute( |
2755 | "y" , |
2756 | {M, N, K}, |
2757 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
2758 | return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); |
2759 | }); |
2760 | Tensor z = Compute( |
2761 | "z" , |
2762 | {M, N, K}, |
2763 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
2764 | return x.load(m, n, k) + y.load(m, n, k); |
2765 | }); |
2766 | |
2767 | LoopNest l({z}, {x, y, z}); |
2768 | ForPtr a = l.getAllLoopNestsWritingToBuf(y.buf())[0][2]; |
2769 | ForPtr b = l.getAllLoopNestsWritingToBuf(y.buf())[0][0]; |
2770 | LoopNest::reorderAxis(a, b); |
2771 | |
2772 | l.prepareForCodegen(); |
2773 | StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); |
2774 | |
2775 | // Check the IR we produced has the 3 nests in the right order, but k and m |
2776 | // swapped in the middle. |
2777 | checkIR(stmt, R"IR( |
2778 | # CHECK: < 4 |
2779 | # CHECK: < 5 |
2780 | # CHECK: < 6 |
2781 | # CHECK: < 6 |
2782 | # CHECK: < 5 |
2783 | # CHECK: < 4 |
2784 | # CHECK: < 4 |
2785 | # CHECK: < 5 |
2786 | # CHECK: < 6)IR" ); |
2787 | |
2788 | { |
2789 | PaddedBuffer<float> a_v(M, N); |
2790 | PaddedBuffer<float> b_v(N, K); |
2791 | PaddedBuffer<float> c_v(M, N); |
2792 | PaddedBuffer<float> d_v(M, K); |
2793 | |
2794 | for (int i = 0; i < M; i++) { |
2795 | for (int j = 0; j < N; j++) { |
2796 | a_v(i, j) = i * i; |
2797 | } |
2798 | } |
2799 | for (int i = 0; i < N; i++) { |
2800 | for (int j = 0; j < K; j++) { |
2801 | b_v(i, j) = j * j; |
2802 | } |
2803 | } |
2804 | for (int i = 0; i < M; i++) { |
2805 | for (int j = 0; j < N; j++) { |
2806 | c_v(i, j) = i + j; |
2807 | } |
2808 | } |
2809 | for (int i = 0; i < M; i++) { |
2810 | for (int j = 0; j < K; j++) { |
2811 | d_v(i, j) = i * j; |
2812 | } |
2813 | } |
2814 | |
2815 | PaddedBuffer<float> z_v(M, N, K); |
2816 | PaddedBuffer<float> z_ref(M, N, K); |
2817 | for (int m = 0; m < M; m++) { |
2818 | for (int n = 0; n < N; n++) { |
2819 | for (int k = 0; k < K; k++) { |
2820 | z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); |
2821 | } |
2822 | } |
2823 | } |
2824 | |
2825 | SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); |
2826 | eval(a_v, b_v, c_v, d_v, z_v); |
2827 | ExpectAllNear(z_v, z_ref, 1e-5); |
2828 | } |
2829 | } |
2830 | |
2831 | TEST(LoopNest, OuterLoopVectorization) { |
2832 | Tensor tensor = |
2833 | Compute("f" , {8, 8}, [](const VarHandle& x, const VarHandle& y) { |
2834 | return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
2835 | }); |
2836 | LoopNest l({tensor}); |
2837 | |
2838 | ASSERT_TRUE( |
2839 | LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor.buf())[0][0])); |
2840 | |
2841 | StmtPtr root_stmt = l.root_stmt(); |
2842 | BlockPtr outer_block = to<Block>(root_stmt); |
2843 | ASSERT_NE(outer_block, nullptr); |
2844 | while (BlockPtr inner_block = to<Block>(outer_block->front())) { |
2845 | outer_block = inner_block; |
2846 | } |
2847 | |
2848 | // Verify that we have only a single loop level remaining after |
2849 | // vectorization. |
2850 | ASSERT_EQ(outer_block->nstmts(), 1); |
2851 | ForPtr for_loop = to<For>(outer_block->front()); |
2852 | ASSERT_NE(for_loop, nullptr); |
2853 | BlockPtr for_body = for_loop->body(); |
2854 | ASSERT_EQ(for_body->nstmts(), 1); |
2855 | ASSERT_EQ(to<For>(for_body->front()), nullptr); |
2856 | } |
2857 | |
2858 | TEST(LoopNest, VectorizeLoopNotNormalized) { |
2859 | // Input IR: |
2860 | // for (int i = 0; i < 10; i++) { |
2861 | // for (int j = 1; j < 5; j++) { |
2862 | // A[i,j] = i * j; |
2863 | // } |
2864 | // } |
2865 | BufHandle a_buf("A" , {10, 5}, kInt); |
2866 | VarHandle i("i" , kInt); |
2867 | VarHandle j("j" , kInt); |
2868 | auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
2869 | auto inner_for = For::make(j, 1, 5, for_body); |
2870 | auto outer_for = For::make(i, 0, 10, inner_for); |
2871 | auto block = Block::make({outer_for}); |
2872 | LoopNest l(block, {a_buf.node()}); |
2873 | |
2874 | ASSERT_TRUE(LoopNest::vectorize(inner_for)); |
2875 | ASSERT_EQ(outer_for->body()->nstmts(), 1); |
2876 | ASSERT_EQ(to<For>(outer_for->body()->front()), nullptr); |
2877 | } |
2878 | |
2879 | namespace { |
2880 | |
2881 | std::string constantUpperBoundLoopIR(int upper_bound_val) { |
2882 | ExprHandle upper_bound(upper_bound_val); |
2883 | Tensor A = |
2884 | Compute("A" , {upper_bound}, [&](const VarHandle& x) { return x * 2; }); |
2885 | LoopNest l({A}); |
2886 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; |
2887 | StmtPtr unrolled = nullptr; |
2888 | LoopNest::fullUnroll(loops[0], &unrolled); |
2889 | std::ostringstream oss; |
2890 | oss << *unrolled; |
2891 | return oss.str(); |
2892 | } |
2893 | |
2894 | } // namespace |
2895 | |
2896 | TEST(LoopNest, Unroll) { |
2897 | const std::string actual = constantUpperBoundLoopIR(3); |
2898 | const std::string& verification_pattern = |
2899 | R"IR( |
2900 | # CHECK: A[0] = 0; |
2901 | # CHECK: A[1] = 2; |
2902 | # CHECK: A[2] = 4)IR" ; |
2903 | |
2904 | torch::jit::testing::FileCheck().run(verification_pattern, actual); |
2905 | } |
2906 | |
2907 | TEST(LoopNest, UnrollOuter) { |
2908 | ExprHandle outer_bound(3); |
2909 | ExprHandle inner_bound(4); |
2910 | Tensor A = Compute( |
2911 | "A" , |
2912 | {outer_bound, inner_bound}, |
2913 | [&](const VarHandle& x, const VarHandle& y) { return x + y; }); |
2914 | LoopNest l({A}); |
2915 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; |
2916 | StmtPtr unrolled = nullptr; |
2917 | LoopNest::fullUnroll(loops[0], &unrolled); |
2918 | checkIR(unrolled, R"IR( |
2919 | # CHECK: for (int i = 0; i < 4; i++) { |
2920 | # CHECK: A[0, i] = i; |
2921 | # CHECK: } |
2922 | # CHECK: for (int i = 0; i < 4; i++) { |
2923 | # CHECK: A[1, i] = i + 1; |
2924 | # CHECK: } |
2925 | # CHECK: for (int i = 0; i < 4; i++) { |
2926 | # CHECK: A[2, i] = i + 2; |
2927 | # CHECK: })IR" ); |
2928 | } |
2929 | |
2930 | TEST(LoopNest, UnrollInner) { |
2931 | ExprHandle outer_bound(3); |
2932 | ExprHandle inner_bound(4); |
2933 | Tensor A = Compute( |
2934 | "A" , |
2935 | {outer_bound, inner_bound}, |
2936 | [&](const VarHandle& x, const VarHandle& y) { return x + y; }); |
2937 | LoopNest l({A}); |
2938 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; |
2939 | StmtPtr unrolled = nullptr; |
2940 | LoopNest::fullUnroll( |
2941 | static_to<For>(loops[0]->body()->stmts().front()), &unrolled); |
2942 | checkIR(loops[0], R"IR( |
2943 | # CHECK: for (int i = 0; i < 3; i++) { |
2944 | # CHECK: A[i, 0] = i; |
2945 | # CHECK: A[i, 1] = i + 1; |
2946 | # CHECK: A[i, 2] = i + 2; |
2947 | # CHECK: A[i, 3] = i + 3; |
2948 | # CHECK: })IR" ); |
2949 | } |
2950 | |
2951 | TEST(LoopNest, UnrollMultipleStatements) { |
2952 | const int kTotalSize = 3; |
2953 | BufHandle a_buf("A" , {ExprHandle(kTotalSize)}, kInt); |
2954 | BufHandle b_buf("B" , {ExprHandle(kTotalSize)}, kInt); |
2955 | |
2956 | VarHandle x("x" , kInt); |
2957 | auto f = For::make( |
2958 | x, |
2959 | 0, |
2960 | kTotalSize, |
2961 | Block::make( |
2962 | {Store::make(a_buf, {x}, x * 2), |
2963 | Store::make(b_buf, {x}, Load::make(a_buf, {x}))})); |
2964 | auto parent_block = Block::make({f}); |
2965 | StmtPtr unrolled = nullptr; |
2966 | LoopNest::fullUnroll(f, &unrolled); |
2967 | checkIR(unrolled, R"IR( |
2968 | # CHECK: A[0] = 0; |
2969 | # CHECK: B[0] = A[0]; |
2970 | # CHECK: A[1] = 2; |
2971 | # CHECK: B[1] = A[1]; |
2972 | # CHECK: A[2] = 4 |
2973 | # CHECK: B[2] = A[2];)IR" ); |
2974 | } |
2975 | |
2976 | TEST(LoopNest, UnrollNonLiteralConstantBounds) { |
2977 | // Input IR: |
2978 | // for (int i = 2 - 1; i < 12 / 3; i++) { |
2979 | // for (int j = 0; j < 4; j++) { |
2980 | // A[i,j] = i * j; |
2981 | // } |
2982 | // } |
2983 | BufHandle a_buf("A" , {3, 4}, kInt); |
2984 | VarHandle i("i" , kInt); |
2985 | VarHandle j("j" , kInt); |
2986 | auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
2987 | auto inner_for = For::make(j, 0, 4, for_body); |
2988 | auto outer_for = For::make( |
2989 | i, |
2990 | IntImm::make(2) - IntImm::make(1), |
2991 | IntImm::make(12) / IntImm::make(3), |
2992 | inner_for); |
2993 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
2994 | auto b = Block::make({outer_for}); |
2995 | |
2996 | std::vector<ForPtr> loops = {outer_for, inner_for}; |
2997 | StmtPtr unrolled = nullptr; |
2998 | LoopNest::fullUnroll(loops[0], &unrolled); |
2999 | checkIR(unrolled, R"IR( |
3000 | # CHECK: for (int j = 0; j < 4; j++) { |
3001 | # CHECK: A[1, j] = j; |
3002 | # CHECK: } |
3003 | # CHECK: for (int j = 0; j < 4; j++) { |
3004 | # CHECK: A[2, j] = 2 * j; |
3005 | # CHECK: } |
3006 | # CHECK: for (int j = 0; j < 4; j++) { |
3007 | # CHECK: A[3, j] = 3 * j; |
3008 | # CHECK: })IR" ); |
3009 | } |
3010 | |
3011 | TEST(LoopNest, UnrollNonConstantBounds) { |
3012 | // Input IR: |
3013 | // for (int i = 0; i < M; i++) { |
3014 | // for (int j = 0; j < N; j++) { |
3015 | // A[i, j] = i * j; |
3016 | // } |
3017 | // } |
3018 | VarHandle M("M" , kInt); |
3019 | VarHandle N("N" , kInt); |
3020 | BufHandle a_buf("A" , {M, N}, kInt); |
3021 | VarHandle i("i" , kInt); |
3022 | VarHandle j("j" , kInt); |
3023 | auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
3024 | auto inner_for = For::make(j, 0, N, for_body); |
3025 | auto outer_for = For::make(i, 0, M, inner_for); |
3026 | auto block = Block::make({outer_for}); |
3027 | LoopNest l(block, {a_buf.node()}); |
3028 | |
3029 | LoopNest::unroll(inner_for, 8); |
3030 | l.simplify(); |
3031 | checkIR(l.root_stmt(), R"IR( |
3032 | # CHECK: for (int i = 0; i < M; i++) { |
3033 | # CHECK: for (int j_outer = 0; j_outer < N / 8; j_outer++) { |
3034 | # CHECK: A[i, 8 * j_outer] = |
3035 | # CHECK: A[i, 8 * j_outer + 1] = |
3036 | # CHECK: A[i, 2 * (4 * j_outer + 1)] = |
3037 | # CHECK: A[i, 8 * j_outer + 3] = |
3038 | # CHECK: A[i, 4 * (2 * j_outer + 1)] = |
3039 | # CHECK: A[i, 8 * j_outer + 5] = |
3040 | # CHECK: A[i, 8 * j_outer + 6] = |
3041 | # CHECK: A[i, 8 * j_outer + 7] = |
3042 | # CHECK: } |
3043 | # CHECK: for (int j_tail = 0; j_tail < N % 8; j_tail++) { |
3044 | # CHECK: A[i, 8 * (N / 8) + j_tail] = |
3045 | # CHECK: } |
3046 | # CHECK: } |
3047 | )IR" ); |
3048 | } |
3049 | |
3050 | TEST(LoopNest, UnrollByFactorsLessThan2) { |
3051 | // Input IR: |
3052 | // for (int i = 0; i < M; i++) { |
3053 | // for (int j = 0; j < N; j++) { |
3054 | // A[i, j] = i * j; |
3055 | // } |
3056 | // } |
3057 | VarHandle M("M" , kInt); |
3058 | VarHandle N("N" , kInt); |
3059 | BufHandle a_buf("A" , {M, N}, kInt); |
3060 | VarHandle i("i" , kInt); |
3061 | VarHandle j("j" , kInt); |
3062 | auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
3063 | auto inner_for = For::make(j, 0, N, for_body); |
3064 | auto outer_for = For::make(i, 0, M, inner_for); |
3065 | auto block = Block::make({outer_for}); |
3066 | LoopNest l(block, {a_buf.node()}); |
3067 | |
3068 | // Unrolling by factor = 1 should do nothing. |
3069 | LoopNest::unroll(inner_for, 1); |
3070 | checkIR(l.root_stmt(), R"IR( |
3071 | # CHECK: for (int i = 0; i < M; i++) { |
3072 | # CHECK: for (int j = 0; j < N; j++) { |
3073 | # CHECK: A[i, j] = |
3074 | # CHECK: } |
3075 | # CHECK: } |
3076 | )IR" ); |
3077 | |
3078 | // Unrolling by factor = 0 should do nothing. |
3079 | LoopNest::unroll(inner_for, 0); |
3080 | checkIR(l.root_stmt(), R"IR( |
3081 | # CHECK: for (int i = 0; i < M; i++) { |
3082 | # CHECK: for (int j = 0; j < N; j++) { |
3083 | # CHECK: A[i, j] = |
3084 | # CHECK: } |
3085 | # CHECK: } |
3086 | )IR" ); |
3087 | |
3088 | // Unrolling by negative factor should do nothing. |
3089 | LoopNest::unroll(inner_for, -2); |
3090 | checkIR(l.root_stmt(), R"IR( |
3091 | # CHECK: for (int i = 0; i < M; i++) { |
3092 | # CHECK: for (int j = 0; j < N; j++) { |
3093 | # CHECK: A[i, j] = |
3094 | # CHECK: } |
3095 | # CHECK: } |
3096 | )IR" ); |
3097 | } |
3098 | |
3099 | TEST(LoopNest, UnrollByFactorEqualToIters) { |
3100 | // Input IR: |
3101 | // for (int i = 0; i < 5; i++) { |
3102 | // A[i] = i * i; |
3103 | // } |
3104 | BufHandle a_buf("A" , {5}, kInt); |
3105 | VarHandle i("i" , kInt); |
3106 | auto for_body = Block::make({Store::make(a_buf, {i}, i * i)}); |
3107 | auto for_loop = For::make(i, 0, 5, for_body); |
3108 | auto block = Block::make({for_loop}); |
3109 | LoopNest l(block, {a_buf.node()}); |
3110 | |
3111 | LoopNest::unroll(for_loop, 5); |
3112 | checkIR(l.root_stmt(), R"IR( |
3113 | # CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++) |
3114 | # CHECK: A[5 * i_outer] |
3115 | # CHECK: A[5 * i_outer + 1] |
3116 | # CHECK: A[5 * i_outer + 2] |
3117 | # CHECK: A[5 * i_outer + 3] |
3118 | # CHECK: A[5 * i_outer + 4] |
3119 | )IR" ); |
3120 | } |
3121 | |
3122 | TEST(LoopNest, UnrollEmpty) { |
3123 | const std::string actual = constantUpperBoundLoopIR(0); |
3124 | const std::string& verification_pattern = R"IR( |
3125 | # CHECK-NOT: A[ |
3126 | )IR" ; |
3127 | |
3128 | torch::jit::testing::FileCheck().run(verification_pattern, actual); |
3129 | } |
3130 | |
3131 | TEST(LoopNest, NoUnroll) { |
3132 | VarHandle upper_bound("N" , kInt); |
3133 | Tensor A = |
3134 | Compute("A" , {upper_bound}, [&](const VarHandle& x) { return x * 2; }); |
3135 | LoopNest l({A}); |
3136 | std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; |
3137 | StmtPtr unrolled = nullptr; |
3138 | ASSERT_THROWS_WITH( |
3139 | LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop" ); |
3140 | } |
3141 | |
3142 | TEST(LoopNest, UnrollWithLet) { |
3143 | const int kTotalSize = 3; |
3144 | BufHandle a_buf("A" , {ExprHandle(kTotalSize)}, kInt); |
3145 | BufHandle b_buf("B" , {ExprHandle(kTotalSize)}, kInt); |
3146 | |
3147 | VarHandle e("e" , kInt); |
3148 | VarHandle x("x" , kInt); |
3149 | auto f = For::make( |
3150 | x, |
3151 | 0, |
3152 | kTotalSize, |
3153 | Block::make( |
3154 | {Let::make(e, 7), |
3155 | Store::make(a_buf, {x}, e), |
3156 | Store::make(b_buf, {x}, e + 1)})); |
3157 | auto parent_block = Block::make({f}); |
3158 | StmtPtr unrolled = nullptr; |
3159 | LoopNest::fullUnroll(f, &unrolled); |
3160 | std::ostringstream oss; |
3161 | oss << *unrolled; |
3162 | const std::string& verification_pattern = |
3163 | R"IR( |
3164 | # CHECK: int e = 7; |
3165 | # CHECK: A[0] = e; |
3166 | # CHECK: B[0] = e + 1; |
3167 | # CHECK: A[1] = e; |
3168 | # CHECK: B[1] = e + 1; |
3169 | # CHECK: A[2] = e; |
3170 | # CHECK: B[2] = e + 1;)IR" ; |
3171 | |
3172 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
3173 | |
3174 | std::vector<int> a_v(kTotalSize, 0); |
3175 | std::vector<int> b_v(kTotalSize, 0); |
3176 | SimpleIREvaluator eval(unrolled, {a_buf, b_buf}); |
3177 | eval(a_v, b_v); |
3178 | for (int i = 0; i < kTotalSize; ++i) { |
3179 | ASSERT_EQ(a_v[i], 7); |
3180 | ASSERT_EQ(b_v[i], 8); |
3181 | } |
3182 | } |
3183 | |
3184 | TEST(LoopNest, IsNormalized) { |
3185 | // Input IR: |
3186 | // for (int i = 50; i < 100; i++) { |
3187 | // A[i] = B[i]; |
3188 | // } |
3189 | BufHandle a_buf("A" , {ExprHandle(100)}, kInt); |
3190 | BufHandle b_buf("B" , {ExprHandle(100)}, kInt); |
3191 | VarHandle i("i" , kInt); |
3192 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
3193 | auto for_stmt = |
3194 | For::make(i, 50, 100, Store::make(a_buf, {i}, Load::make(b_buf, {i}))); |
3195 | Block::make({for_stmt}); |
3196 | ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); |
3197 | |
3198 | for_stmt->set_start(alloc<IntImm>(0)); |
3199 | ASSERT_TRUE(LoopNest::isNormalized(for_stmt)); |
3200 | |
3201 | VarHandle N("N" , kInt); |
3202 | for_stmt->set_start(N.node()); |
3203 | ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); |
3204 | } |
3205 | |
3206 | TEST(LoopNest, NormalizeStartPositive) { |
3207 | // Input IR: |
3208 | // for (int x = 50; x < 100; x++) { |
3209 | // A[x] = B[x]; |
3210 | // B[x] = x * 2; |
3211 | // } |
3212 | const int kTotalSize = 50; |
3213 | BufHandle a_buf("A" , {ExprHandle(kTotalSize)}, kInt); |
3214 | BufHandle b_buf("B" , {ExprHandle(kTotalSize)}, kInt); |
3215 | VarHandle x("x" , kInt); |
3216 | auto for_body = Block::make( |
3217 | {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), |
3218 | Store::make(b_buf, {x}, x * 2)}); |
3219 | auto for_stmt = For::make(x, 50, 100, for_body); |
3220 | Block::make({for_stmt}); |
3221 | |
3222 | LoopNest::normalize(for_stmt); |
3223 | |
3224 | auto result = IRSimplifier::simplify(for_stmt); |
3225 | std::ostringstream oss; |
3226 | oss << *result; |
3227 | const std::string& expected_ir = |
3228 | R"IR( |
3229 | # CHECK: for (int x = 0; x < 50; x++) { |
3230 | # CHECK: A[x + 50] = B[x + 50]; |
3231 | # CHECK: B[x + 50] = 2 * (x + 50); |
3232 | )IR" ; |
3233 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
3234 | } |
3235 | |
3236 | TEST(LoopNest, NormalizeStartNegative) { |
3237 | // Input IR: |
3238 | // for (int x = -50; x < 100; x++) { |
3239 | // A[x + 50] = B[x + 50]; |
3240 | // B[x + 50] = x * 2; |
3241 | // } |
3242 | const int kTotalSize = 150; |
3243 | BufHandle a_buf("A" , {ExprHandle(kTotalSize)}, kInt); |
3244 | BufHandle b_buf("B" , {ExprHandle(kTotalSize)}, kInt); |
3245 | VarHandle x("x" , kInt); |
3246 | auto for_body = Block::make( |
3247 | {Store::make(a_buf, {x + 50}, Load::make(kInt, b_buf, {x + 50})), |
3248 | Store::make(b_buf, {x + 50}, x * 2)}); |
3249 | auto for_stmt = For::make(x, -50, 100, for_body); |
3250 | Block::make({for_stmt}); |
3251 | |
3252 | LoopNest::normalize(for_stmt); |
3253 | |
3254 | auto result = IRSimplifier::simplify(for_stmt); |
3255 | std::ostringstream oss; |
3256 | oss << *result; |
3257 | const std::string& expected_ir = |
3258 | R"IR( |
3259 | # CHECK: for (int x = 0; x < 150; x++) { |
3260 | # CHECK: A[x] = B[x]; |
3261 | # CHECK: B[x] = 2 * (x - 50); |
3262 | )IR" ; |
3263 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
3264 | } |
3265 | |
3266 | TEST(LoopNest, NormalizeStartZero) { |
3267 | // Input IR: |
3268 | // for (int x = 0; x < 100; x++) { |
3269 | // A[x] = B[x]; |
3270 | // B[x] = x * 2; |
3271 | // } |
3272 | // Should not be modified. |
3273 | |
3274 | const int kTotalSize = 100; |
3275 | BufHandle a_buf("A" , {ExprHandle(kTotalSize)}, kInt); |
3276 | BufHandle b_buf("B" , {ExprHandle(kTotalSize)}, kInt); |
3277 | VarHandle x("x" , kInt); |
3278 | auto for_body = Block::make( |
3279 | {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), |
3280 | Store::make(b_buf, {x}, x * 2)}); |
3281 | auto for_stmt = For::make(x, 0, 100, for_body); |
3282 | Block::make({for_stmt}); |
3283 | |
3284 | LoopNest::normalize(for_stmt); |
3285 | |
3286 | auto result = IRSimplifier::simplify(for_stmt); |
3287 | std::ostringstream oss; |
3288 | oss << *result; |
3289 | const std::string& expected_ir = |
3290 | R"IR( |
3291 | # CHECK: for (int x = 0; x < 100; x++) { |
3292 | # CHECK: A[x] = B[x]; |
3293 | # CHECK: B[x] = 2 * x; |
3294 | )IR" ; |
3295 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
3296 | } |
3297 | |
3298 | TEST(LoopNest, NormalizeStartVariable) { |
3299 | // Input IR: |
3300 | // for (int x = y; x < 100; x++) { |
3301 | // A[x] = B[x]; |
3302 | // B[x] = x * 2; |
3303 | // } |
3304 | |
3305 | const int kTotalSize = 100; |
3306 | BufHandle a_buf("A" , {ExprHandle(kTotalSize)}, kInt); |
3307 | BufHandle b_buf("B" , {ExprHandle(kTotalSize)}, kInt); |
3308 | VarHandle x("x" , kInt); |
3309 | VarHandle y("y" , kInt); |
3310 | auto for_body = Block::make( |
3311 | {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), |
3312 | Store::make(b_buf, {x}, x * 2)}); |
3313 | auto for_stmt = For::make(x, y, 100, for_body); |
3314 | auto parent_block = Block::make({for_stmt}); |
3315 | |
3316 | LoopNest::normalize(for_stmt); |
3317 | |
3318 | auto result = IRSimplifier::simplify(for_stmt); |
3319 | std::ostringstream oss; |
3320 | oss << *result; |
3321 | const std::string& expected_ir = |
3322 | R"IR( |
3323 | # CHECK: for (int x = 0; x < 100 - y; x++) { |
3324 | # CHECK: A[x + y] = B[x + y]; |
3325 | # CHECK: B[x + y] = 2 * (x + y); |
3326 | )IR" ; |
3327 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
3328 | } |
3329 | |
3330 | TEST(LoopNest, NormalizeOnNestedOuterLoop) { |
3331 | // Input IR: |
3332 | // for (int x = 50; x < 100; x++) { |
3333 | // for (int y = 10; y < 100; y++) { |
3334 | // A[x] = A[x] + B[y] + y * 2; |
3335 | // } |
3336 | // } |
3337 | |
3338 | BufHandle a_buf("A" , {ExprHandle(50)}, kInt); |
3339 | BufHandle b_buf("B" , {ExprHandle(100)}, kInt); |
3340 | VarHandle x("x" , kInt); |
3341 | VarHandle y("y" , kInt); |
3342 | auto inner_for_body = Store::make( |
3343 | a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); |
3344 | auto inner_for = For::make(y, 10, 100, inner_for_body); |
3345 | auto for_stmt = For::make(x, 50, 100, inner_for); |
3346 | Block::make({for_stmt}); |
3347 | |
3348 | LoopNest::normalize(for_stmt); |
3349 | |
3350 | auto result = IRSimplifier::simplify(for_stmt); |
3351 | std::ostringstream oss; |
3352 | oss << *result; |
3353 | const std::string& expected_ir = |
3354 | R"IR( |
3355 | # CHECK: for (int x = 0; x < 50; x++) { |
3356 | # CHECK: for (int y = 10; y < 100; y++) { |
3357 | # CHECK: A[x + 50] = ((A[x + 50]) + (B[y])) + 2 * y; |
3358 | )IR" ; |
3359 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
3360 | } |
3361 | |
3362 | TEST(LoopNest, NormalizeOnNestedInnerLoop) { |
3363 | // Input IR: |
3364 | // for (int x = 50; x < 100; x++) { |
3365 | // for (int y = 10; y < 100; y++) { |
3366 | // A[x] = A[x] + B[y] + y * 2; |
3367 | // } |
3368 | // } |
3369 | |
3370 | BufHandle a_buf("A" , {ExprHandle(50)}, kInt); |
3371 | BufHandle b_buf("B" , {ExprHandle(100)}, kInt); |
3372 | VarHandle x("x" , kInt); |
3373 | VarHandle y("y" , kInt); |
3374 | auto inner_for_body = Store::make( |
3375 | a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); |
3376 | auto inner_for = For::make(y, 10, 100, inner_for_body); |
3377 | auto for_stmt = For::make(x, 50, 100, inner_for); |
3378 | Block::make({for_stmt}); |
3379 | |
3380 | LoopNest::normalize(inner_for); |
3381 | |
3382 | auto result = IRSimplifier::simplify(for_stmt); |
3383 | std::ostringstream oss; |
3384 | oss << *result; |
3385 | const std::string& expected_ir = |
3386 | R"IR( |
3387 | # CHECK: for (int x = 50; x < 100; x++) { |
3388 | # CHECK: for (int y = 0; y < 90; y++) { |
3389 | # CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20; |
3390 | )IR" ; |
3391 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
3392 | } |
3393 | |
3394 | TEST(LoopNest, NormalizeAndSplitWithTail) { |
3395 | // Create a dummy tensor to construct LoopNest. |
3396 | ExprHandle n(100); |
3397 | BufHandle a("a" , {n}, kFloat); |
3398 | Tensor b = Compute("b" , {n}, [&](const VarHandle& i) { return a.load(i); }); |
3399 | LoopNest l({b}); |
3400 | |
3401 | // Input IR: |
3402 | // for (int x = 5; x < 10; x++) { |
3403 | // A[x] = x * 2; |
3404 | // } |
3405 | const int kTotalSize = 5; |
3406 | BufHandle a_buf("A" , {ExprHandle(kTotalSize)}, kInt); |
3407 | VarHandle x("x" , kInt); |
3408 | auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2)); |
3409 | auto parent_block = Block::make({for_stmt}); |
3410 | |
3411 | LoopNest::normalize(for_stmt); |
3412 | |
3413 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
3414 | ForPtr x_inner; |
3415 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
3416 | ForPtr x_tail; |
3417 | LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail); |
3418 | |
3419 | auto x_outer_result = IRSimplifier::simplify(for_stmt); |
3420 | std::ostringstream oss_outer; |
3421 | oss_outer << *x_outer_result; |
3422 | const std::string& expected_outer_ir = |
3423 | R"IR( |
3424 | # CHECK: { |
3425 | # CHECK: } |
3426 | )IR" ; |
3427 | torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); |
3428 | |
3429 | auto x_tail_result = IRSimplifier::simplify(x_tail); |
3430 | std::ostringstream oss_tail; |
3431 | oss_tail << *x_tail_result; |
3432 | const std::string& expected_tail_ir = |
3433 | R"IR( |
3434 | # CHECK: for (int x_tail = 0; x_tail < 5; x_tail++) { |
3435 | # CHECK: A[x_tail + 5] = 2 * (x_tail + 5); |
3436 | )IR" ; |
3437 | torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); |
3438 | } |
3439 | |
3440 | TEST(LoopNest, NotNormalizeAndSplitWithTail) { |
3441 | // Create a dummy tensor to construct LoopNest. |
3442 | ExprHandle n(100); |
3443 | BufHandle a("a" , {n}, kFloat); |
3444 | Tensor b = Compute("b" , {n}, [&](const VarHandle& i) { return a.load(i); }); |
3445 | LoopNest l({b}); |
3446 | |
3447 | // Input IR: |
3448 | // for (int x = 5; x < 15; x++) { |
3449 | // A[x] = x * 2; |
3450 | // } |
3451 | const int kTotalSize = 10; |
3452 | BufHandle a_buf("A" , {kTotalSize}, kInt); |
3453 | VarHandle x("x" , kInt); |
3454 | auto for_stmt = For::make(x, 5, 15, Store::make(a_buf, {x}, x * 2)); |
3455 | auto parent_block = Block::make({for_stmt}); |
3456 | |
3457 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
3458 | ForPtr x_inner; |
3459 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
3460 | ForPtr x_tail; |
3461 | LoopNest::splitWithTail(for_stmt, 8, &x_inner, &x_tail); |
3462 | |
3463 | auto x_outer_result = IRSimplifier::simplify(for_stmt); |
3464 | std::ostringstream oss_outer; |
3465 | oss_outer << *x_outer_result; |
3466 | const std::string& expected_outer_ir = |
3467 | R"IR( |
3468 | # CHECK: { |
3469 | # CHECK: } |
3470 | )IR" ; |
3471 | torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); |
3472 | |
3473 | auto x_tail_result = IRSimplifier::simplify(x_tail); |
3474 | std::ostringstream oss_tail; |
3475 | oss_tail << *x_tail_result; |
3476 | const std::string& expected_tail_ir = |
3477 | R"IR( |
3478 | # CHECK: for (int x_tail = 0; x_tail < 2; x_tail++) { |
3479 | # CHECK: A[x_tail + 13] = 2 * (x_tail + 13); |
3480 | )IR" ; |
3481 | torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); |
3482 | } |
3483 | |
3484 | TEST(LoopNest, FlattenSimpleLoopNest2D) { |
3485 | // Input IR: |
3486 | // for (int i = 0; i < 10; i++) { |
3487 | // for (int j = 0; j < 5; j++) { |
3488 | // A[i,j] = i * j; |
3489 | // } |
3490 | // } |
3491 | BufHandle a_buf("A" , {10, 5}, kInt); |
3492 | VarHandle i("i" , kInt); |
3493 | VarHandle j("j" , kInt); |
3494 | auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
3495 | auto inner_for = For::make(j, 0, 5, for_body); |
3496 | auto outer_for = For::make(i, 0, 10, inner_for); |
3497 | auto parent_block = Block::make({outer_for}); |
3498 | |
3499 | std::vector<ForPtr> loops = {outer_for, inner_for}; |
3500 | ForPtr flattened = nullptr; |
3501 | ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); |
3502 | ASSERT_EQ(flattened, loops.front()); |
3503 | |
3504 | auto result = IRSimplifier::simplify(flattened); |
3505 | std::ostringstream oss; |
3506 | oss << *result; |
3507 | const std::string& expected_ir = |
3508 | R"IR( |
3509 | # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { |
3510 | # CHECK: A[i_flat / 5, i_flat % 5] = |
3511 | )IR" ; |
3512 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
3513 | |
3514 | { |
3515 | SimpleIREvaluator eval1(loops[0], {a_buf}); |
3516 | PaddedBuffer<int> inp1(10, 5); |
3517 | eval1(inp1); |
3518 | SimpleIREvaluator eval2(flattened, {a_buf}); |
3519 | PaddedBuffer<int> inp2(10, 5); |
3520 | eval2(inp2); |
3521 | ExpectAllNear(inp1, inp2, 1e-5); |
3522 | } |
3523 | } |
3524 | |
3525 | TEST(LoopNest, FlattenSimpleLoopNest3D) { |
3526 | // Input IR: |
3527 | // for (int i = 0; i < 10; i++) { |
3528 | // for (int j = 0; j < 5; j++) { |
3529 | // for (int k = 0; k < 7; k++) { |
3530 | // A[i,j,k] = i + j * k; |
3531 | // } |
3532 | // } |
3533 | // } |
3534 | BufHandle a_buf("A" , {10, 5, 7}, kInt); |
3535 | VarHandle i("i" , kInt); |
3536 | VarHandle j("j" , kInt); |
3537 | VarHandle k("k" , kInt); |
3538 | auto for_body = Block::make({Store::make(a_buf, {i, j, k}, i + j * k)}); |
3539 | auto for1 = For::make(k, 0, 7, for_body); |
3540 | auto for2 = For::make(j, 0, 5, for1); |
3541 | auto for3 = For::make(i, 0, 10, for2); |
3542 | auto parent_block = Block::make({for3}); |
3543 | |
3544 | std::vector<ForPtr> loops = {for3, for2, for1}; |
3545 | ForPtr flattened = nullptr; |
3546 | ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); |
3547 | ASSERT_EQ(flattened, loops.front()); |
3548 | |
3549 | auto result = IRSimplifier::simplify(flattened); |
3550 | std::ostringstream oss; |
3551 | oss << *result; |
3552 | const std::string& expected_ir = |
3553 | R"IR( |
3554 | # CHECK: for (int i_flat = 0; i_flat < 350; i_flat++) { |
3555 | # CHECK: A[i_flat / 35, (i_flat / 7) % 5, i_flat % 7] = |
3556 | )IR" ; |
3557 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
3558 | |
3559 | { |
3560 | SimpleIREvaluator eval1(loops[0], {a_buf}); |
3561 | PaddedBuffer<int> inp1(10, 5, 7); |
3562 | eval1(inp1); |
3563 | SimpleIREvaluator eval2(flattened, {a_buf}); |
3564 | PaddedBuffer<int> inp2(10, 5, 7); |
3565 | eval2(inp2); |
3566 | ExpectAllNear(inp1, inp2, 1e-5); |
3567 | } |
3568 | } |
3569 | |
3570 | TEST(LoopNest, FlattenLoopNestAfterNormalize) { |
3571 | // Input IR: |
3572 | // for (int i = 2; i < 10; i++) { |
3573 | // for (int j = 3; j < 15; j++) { |
3574 | // A[i - 2,j - 3] = i * j; |
3575 | // } |
3576 | // } |
3577 | BufHandle a_buf("A" , {8, 12}, kInt); |
3578 | VarHandle i("i" , kInt); |
3579 | VarHandle j("j" , kInt); |
3580 | auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)}); |
3581 | auto inner_for = For::make(j, 3, 15, for_body); |
3582 | auto outer_for = For::make(i, 2, 10, inner_for); |
3583 | auto parent_block = Block::make({outer_for}); |
3584 | |
3585 | std::vector<ForPtr> loops = {outer_for, inner_for}; |
3586 | ForPtr flattened = nullptr; |
3587 | ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); |
3588 | ASSERT_EQ(flattened, loops.front()); |
3589 | |
3590 | auto result = IRSimplifier::simplify(flattened); |
3591 | std::ostringstream oss; |
3592 | oss << *result; |
3593 | const std::string& expected_ir = |
3594 | R"IR( |
3595 | # CHECK: for (int i_flat = 0; i_flat < 96; i_flat++) { |
3596 | # CHECK: A[i_flat / 12, i_flat % 12] = |
3597 | )IR" ; |
3598 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
3599 | |
3600 | { |
3601 | SimpleIREvaluator eval1(loops[0], {a_buf}); |
3602 | PaddedBuffer<int> inp1(8, 12); |
3603 | eval1(inp1); |
3604 | SimpleIREvaluator eval2(flattened, {a_buf}); |
3605 | PaddedBuffer<int> inp2(8, 12); |
3606 | eval2(inp2); |
3607 | ExpectAllNear(inp1, inp2, 1e-5); |
3608 | } |
3609 | } |
3610 | |
3611 | TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) { |
3612 | // Input IR: |
3613 | // for (int i = 0; i < 15-5; i++) { |
3614 | // for (int j = 0; j < 20/4; j++) { |
3615 | // A[i,j] = i * j; |
3616 | // } |
3617 | // } |
3618 | BufHandle a_buf("A" , {10, 5}, kInt); |
3619 | VarHandle i("i" , kInt); |
3620 | VarHandle j("j" , kInt); |
3621 | auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
3622 | auto inner_for = |
3623 | For::make(j, 0, IntImm::make(20) / IntImm::make(4), for_body); |
3624 | auto outer_for = |
3625 | For::make(i, 0, IntImm::make(15) - IntImm::make(5), inner_for); |
3626 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
3627 | auto b = Block::make({outer_for}); |
3628 | |
3629 | std::vector<ForPtr> loops = {outer_for, inner_for}; |
3630 | ForPtr flattened = nullptr; |
3631 | ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); |
3632 | ASSERT_EQ(flattened, loops.front()); |
3633 | |
3634 | auto result = IRSimplifier::simplify(flattened); |
3635 | checkIR(result, R"IR( |
3636 | # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { |
3637 | # CHECK: A[i_flat / 5, i_flat % 5] = |
3638 | )IR" ); |
3639 | |
3640 | { |
3641 | SimpleIREvaluator eval1(loops[0], {a_buf}); |
3642 | PaddedBuffer<int> inp1(10, 5); |
3643 | eval1(inp1); |
3644 | SimpleIREvaluator eval2(flattened, {a_buf}); |
3645 | PaddedBuffer<int> inp2(10, 5); |
3646 | eval2(inp2); |
3647 | ExpectAllNear(inp1, inp2, 1e-5); |
3648 | } |
3649 | } |
3650 | |
3651 | TEST(LoopNest, FlattenImperfectLoopNest) { |
3652 | // Input IR: |
3653 | // for (int i = 0; i < 10; i++) { |
3654 | // A[i, i] = 0; |
3655 | // for (int j = 0; j < 15; j++) { |
3656 | // A[i,j] = i * j; |
3657 | // } |
3658 | // } |
3659 | // Do not flatten. |
3660 | |
3661 | BufHandle a_buf("A" , {10, 15}, kInt); |
3662 | VarHandle i("i" , kInt); |
3663 | VarHandle j("j" , kInt); |
3664 | auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
3665 | auto inner_for = For::make(j, 0, 15, for_body); |
3666 | auto outer_for = For::make( |
3667 | i, 0, 10, Block::make({Store::make(a_buf, {i, i}, 0), inner_for})); |
3668 | auto par = Block::make({outer_for}); |
3669 | HashProvider hasher; |
3670 | auto hash_before = hasher.hash(par); |
3671 | |
3672 | std::vector<ForPtr> loops = {outer_for, inner_for}; |
3673 | ForPtr flattened = nullptr; |
3674 | ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); |
3675 | ASSERT_EQ(flattened, nullptr); |
3676 | auto hash_after = hasher.hash(par); |
3677 | ASSERT_EQ(hash_before, hash_after); |
3678 | } |
3679 | |
3680 | TEST(LoopNest, FlattenReductionLoopNest) { |
3681 | // Input IR: |
3682 | // for (int i = 0; i < 10; i++) { |
3683 | // S[i] = 0; |
3684 | // for (int j = 0; j < 15; j++) { |
3685 | // S[i] = S[i] + A[i,j]; |
3686 | // } |
3687 | // } |
3688 | // Do not flatten. |
3689 | |
3690 | BufHandle a_buf("A" , {10, 15}, kInt); |
3691 | BufHandle s_buf("S" , {10}, kInt); |
3692 | VarHandle i("i" , kInt); |
3693 | VarHandle j("j" , kInt); |
3694 | auto for_body = Block::make({Store::make( |
3695 | s_buf, {i}, Load::make(s_buf, {i}) + Load::make(a_buf, {i, j}))}); |
3696 | auto inner_for = For::make(j, 0, 15, for_body); |
3697 | auto outer_for = |
3698 | For::make(i, 0, 10, Block::make({Store::make(s_buf, {i}, 0), inner_for})); |
3699 | auto par = Block::make({outer_for}); |
3700 | HashProvider hasher; |
3701 | auto hash_before = hasher.hash(par); |
3702 | |
3703 | std::vector<ForPtr> loops = {outer_for, inner_for}; |
3704 | ForPtr flattened = nullptr; |
3705 | ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); |
3706 | ASSERT_EQ(flattened, nullptr); |
3707 | auto hash_after = hasher.hash(par); |
3708 | ASSERT_EQ(hash_before, hash_after); |
3709 | } |
3710 | |
3711 | TEST(LoopNest, FlattenReductionLoopNestFromTensor) { |
3712 | const int M = 3; |
3713 | const int N = 7; |
3714 | VarHandle m("m" , kInt); |
3715 | VarHandle n("n" , kInt); |
3716 | BufHandle b("b" , {m, n}, kFloat); |
3717 | Tensor c = Reduce("sum" , {M}, Sum(), b, {N}); |
3718 | LoopNest loop({c}); |
3719 | HashProvider hasher; |
3720 | auto hash_before = hasher.hash(loop.root_stmt()); |
3721 | |
3722 | auto loops = loop.getAllLoopNestsWritingToBuf(c.buf())[1]; |
3723 | ForPtr flattened = nullptr; |
3724 | ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); |
3725 | ASSERT_EQ(flattened, nullptr); |
3726 | auto hash_after = hasher.hash(loop.root_stmt()); |
3727 | ASSERT_EQ(hash_before, hash_after); |
3728 | } |
3729 | |
3730 | TEST(LoopNest, FlattenIncorrectLoopsAsInput) { |
3731 | // Input IR: |
3732 | // for (int i = 0; i < 10; i++) { |
3733 | // for (int j = 0; j < 5; j++) { |
3734 | // A[i,j] = i * j; |
3735 | // } |
3736 | // } |
3737 | // for (int x = 0; x < 10; x++) { |
3738 | // for (int y = 0; y < 5; y++) { |
3739 | // A[x,y] = A[x,y] + x + y; |
3740 | // } |
3741 | // } |
3742 | // Flatten({For_i, For_y}) => should not succeed |
3743 | |
3744 | BufHandle a_buf("A" , {10, 5}, kInt); |
3745 | VarHandle i("i" , kInt); |
3746 | VarHandle j("j" , kInt); |
3747 | VarHandle x("x" , kInt); |
3748 | VarHandle y("y" , kInt); |
3749 | auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
3750 | auto inner_for1 = For::make(j, 0, 5, for_body1); |
3751 | auto outer_for1 = For::make(i, 0, 10, inner_for1); |
3752 | auto for_body2 = Block::make( |
3753 | {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); |
3754 | auto inner_for2 = For::make(y, 0, 5, for_body2); |
3755 | auto outer_for2 = For::make(x, 0, 10, inner_for2); |
3756 | auto par = Block::make({outer_for1, outer_for2}); |
3757 | HashProvider hasher; |
3758 | auto hash_before = hasher.hash(par); |
3759 | |
3760 | std::vector<ForPtr> loops = {outer_for1, inner_for2}; |
3761 | ForPtr flattened = nullptr; |
3762 | ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); |
3763 | ASSERT_EQ(flattened, nullptr); |
3764 | auto hash_after = hasher.hash(par); |
3765 | ASSERT_EQ(hash_before, hash_after); |
3766 | } |
3767 | |
3768 | TEST(LoopNest, DetectInlineRankMismatch) { |
3769 | const int kTotalSize = 8; |
3770 | |
3771 | BufHandle a_buf("A" , {ExprHandle(kTotalSize)}, kFloat); |
3772 | Tensor a = Compute( |
3773 | "a" , {kTotalSize}, [&](const VarHandle& i) { return a_buf.load(i); }); |
3774 | Tensor reshape = Compute( |
3775 | "reshape" , |
3776 | {kTotalSize / 2, 2}, |
3777 | [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j); }); |
3778 | LoopNest l({reshape}, {a, reshape}); |
3779 | ASSERT_FALSE(l.computeInline(l.getLoopBodyFor(a))); |
3780 | } |
3781 | |
3782 | TEST(LoopNest, CacheReadsSimple) { |
3783 | Tensor A = Compute("A" , {64, 64}, [](const VarHandle& i, const VarHandle& j) { |
3784 | return i * j; |
3785 | }); |
3786 | Tensor B = |
3787 | Compute("B" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
3788 | return A.load(i + 30, j + 3); |
3789 | }); |
3790 | Tensor C = |
3791 | Compute("C" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
3792 | return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); |
3793 | }); |
3794 | |
3795 | LoopNest l({B, C}, {A, B, C}); |
3796 | StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; |
3797 | LoopNest::cacheAccesses(A.buf(), "A_local" , j_loop); |
3798 | |
3799 | l.prepareForCodegen(); |
3800 | StmtPtr result = |
3801 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
3802 | SimpleIREvaluator cg(result, {B, C}); |
3803 | result = cg.stmt(); |
3804 | |
3805 | // just this once: verify the whole thing. |
3806 | checkIR(result, R"IR( |
3807 | #CHECK: Allocate(A); // dtype=int, dims=[64, 64] |
3808 | #CHECK: Allocate(A_local); // dtype=int, dims=[1, 10] |
3809 | #CHECK: for (int i |
3810 | #CHECK: for (int j |
3811 | #CHECK: A[ |
3812 | #CHECK: } |
3813 | #CHECK: } |
3814 | #CHECK: for (int i_1 |
3815 | #CHECK: for (int j_1 |
3816 | #CHECK: A_local[j_1] = A[ |
3817 | #CHECK: } |
3818 | #CHECK: for (int j_2 |
3819 | #CHECK: B[j_2 + 10 * i_1] = A_local[j_2]; |
3820 | #CHECK: } |
3821 | #CHECK: } |
3822 | #CHECK: for (int i_2 |
3823 | #CHECK: for (int j_3 |
3824 | #CHECK: C[ |
3825 | #CHECK: } |
3826 | #CHECK: } |
3827 | #CHECK: Free(A_local); |
3828 | #CHECK: Free(A); |
3829 | )IR" ); |
3830 | |
3831 | std::vector<int> b_data(200, 0); |
3832 | std::vector<int> c_data(200, 0); |
3833 | cg.call({b_data, c_data}); |
3834 | |
3835 | std::vector<int> b_ref(200, 0); |
3836 | std::vector<int> c_ref(200, 0); |
3837 | |
3838 | for (int i = 0; i < 20; ++i) { |
3839 | for (int j = 0; j < 10; ++j) { |
3840 | b_ref[i * 10 + j] = (i + 30) * (j + 3); |
3841 | c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); |
3842 | } |
3843 | } |
3844 | |
3845 | assertAllEqual(b_data, b_ref); |
3846 | assertAllEqual(c_data, c_ref); |
3847 | } |
3848 | |
3849 | TEST(LoopNest, CacheReadsOuter) { |
3850 | Tensor A = Compute("A" , {64, 64}, [](const VarHandle& i, const VarHandle& j) { |
3851 | return i * j; |
3852 | }); |
3853 | Tensor B = |
3854 | Compute("B" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
3855 | return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); |
3856 | }); |
3857 | Tensor C = |
3858 | Compute("C" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
3859 | return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); |
3860 | }); |
3861 | |
3862 | LoopNest l({B, C}, {A, B, C}); |
3863 | StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][0]; |
3864 | LoopNest::cacheAccesses(A.buf(), "A_local" , i_loop); |
3865 | |
3866 | l.prepareForCodegen(); |
3867 | StmtPtr result = |
3868 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
3869 | SimpleIREvaluator cg(result, {B, C}); |
3870 | result = cg.stmt(); |
3871 | |
3872 | checkIR(result, R"IR( |
3873 | #CHECK: Allocate(A_local); // dtype=int, dims=[21, 11] |
3874 | #CHECK: A_local[j_1 + 11 * i_1] = |
3875 | #CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]); |
3876 | )IR" ); |
3877 | |
3878 | std::vector<int> b_data(200, 0); |
3879 | std::vector<int> c_data(200, 0); |
3880 | cg.call({b_data, c_data}); |
3881 | |
3882 | std::vector<int> b_ref(200, 0); |
3883 | std::vector<int> c_ref(200, 0); |
3884 | |
3885 | for (int i = 0; i < 20; ++i) { |
3886 | for (int j = 0; j < 10; ++j) { |
3887 | b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); |
3888 | c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); |
3889 | } |
3890 | } |
3891 | |
3892 | assertAllEqual(b_data, b_ref); |
3893 | assertAllEqual(c_data, c_ref); |
3894 | } |
3895 | |
3896 | TEST(LoopNest, CacheReadsInternal) { |
3897 | Tensor A = Compute("A" , {64, 64}, [](const VarHandle& i, const VarHandle& j) { |
3898 | return i * j; |
3899 | }); |
3900 | Tensor B = |
3901 | Compute("B" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
3902 | return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); |
3903 | }); |
3904 | Tensor C = |
3905 | Compute("C" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
3906 | return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); |
3907 | }); |
3908 | |
3909 | LoopNest l({B, C}, {A, B, C}); |
3910 | StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; |
3911 | LoopNest::cacheAccesses(A.buf(), "A_local" , j_loop); |
3912 | l.prepareForCodegen(); |
3913 | StmtPtr result = |
3914 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
3915 | SimpleIREvaluator cg(result, {B, C}); |
3916 | result = cg.stmt(); |
3917 | |
3918 | checkIR(result, R"IR( |
3919 | #CHECK: Allocate(A_local); // dtype=int, dims=[2, 11] |
3920 | #CHECK: A_local[k + 11 * j_1] = |
3921 | #CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]); |
3922 | )IR" ); |
3923 | |
3924 | std::vector<int> b_data(200, 0); |
3925 | std::vector<int> c_data(200, 0); |
3926 | cg.call({b_data, c_data}); |
3927 | |
3928 | std::vector<int> b_ref(200, 0); |
3929 | std::vector<int> c_ref(200, 0); |
3930 | |
3931 | for (int i = 0; i < 20; ++i) { |
3932 | for (int j = 0; j < 10; ++j) { |
3933 | b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); |
3934 | c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); |
3935 | } |
3936 | } |
3937 | |
3938 | assertAllEqual(b_data, b_ref); |
3939 | assertAllEqual(c_data, c_ref); |
3940 | } |
3941 | |
3942 | TEST(LoopNest, CacheReadsInner) { |
3943 | Tensor A = Compute("A" , {64, 64}, [](const VarHandle& i, const VarHandle& j) { |
3944 | return i * j; |
3945 | }); |
3946 | // note im changing the offset of the first arg of the first call to A. |
3947 | Tensor B = |
3948 | Compute("B" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
3949 | return A.load(i + 34, j + 40) + A.load(i + 30, j + 41); |
3950 | }); |
3951 | Tensor C = |
3952 | Compute("C" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
3953 | return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); |
3954 | }); |
3955 | |
3956 | LoopNest l({B, C}, {A, B, C}); |
3957 | StmtPtr body = l.getLoopBodyFor(B); |
3958 | LoopNest::cacheAccesses(A.buf(), "A_local" , body); |
3959 | l.prepareForCodegen(); |
3960 | StmtPtr result = |
3961 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
3962 | SimpleIREvaluator cg(result, {B, C}); |
3963 | result = cg.stmt(); |
3964 | |
3965 | checkIR(result, R"IR( |
3966 | #CHECK: Allocate(A_local); // dtype=int, dims=[5, 2] |
3967 | #CHECK: A_local[l + 2 * k] = |
3968 | #CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]); |
3969 | )IR" ); |
3970 | |
3971 | std::vector<int> b_data(200, 0); |
3972 | std::vector<int> c_data(200, 0); |
3973 | cg.call({b_data, c_data}); |
3974 | |
3975 | std::vector<int> b_ref(200, 0); |
3976 | std::vector<int> c_ref(200, 0); |
3977 | |
3978 | for (int i = 0; i < 20; ++i) { |
3979 | for (int j = 0; j < 10; ++j) { |
3980 | b_ref[i * 10 + j] = (i + 34) * (j + 40) + (i + 30) * (j + 41); |
3981 | c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); |
3982 | } |
3983 | } |
3984 | |
3985 | assertAllEqual(b_data, b_ref); |
3986 | assertAllEqual(c_data, c_ref); |
3987 | } |
3988 | |
3989 | TEST(LoopNest, CacheWritesSimple) { |
3990 | Tensor A = Compute("A" , {64, 64}, [](const VarHandle& i, const VarHandle& j) { |
3991 | return i * j; |
3992 | }); |
3993 | Tensor B = |
3994 | Compute("B" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
3995 | return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); |
3996 | }); |
3997 | Tensor C = |
3998 | Compute("C" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
3999 | return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); |
4000 | }); |
4001 | |
4002 | LoopNest l({B, C}, {A, B, C}); |
4003 | StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A.buf())[0][1]; |
4004 | LoopNest::cacheAccesses(A.buf(), "A_local" , a_loop); |
4005 | |
4006 | l.prepareForCodegen(); |
4007 | StmtPtr result = |
4008 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
4009 | SimpleIREvaluator cg(result, {B, C}); |
4010 | result = cg.stmt(); |
4011 | |
4012 | checkIR(result, R"IR( |
4013 | #CHECK: Allocate(A_local); // dtype=int, dims=[1, 64] |
4014 | #CHECK: for (int j = 0; j < 64 |
4015 | #CHECK: A_local[j] = i * j; |
4016 | #CHECK: for (int j_1 = 0; j_1 < 64 |
4017 | #CHECK: A[j_1 + 64 * i] = A_local[ |
4018 | #CHECK: Free(A_local); |
4019 | #CHECK-NOT: A_local |
4020 | )IR" ); |
4021 | |
4022 | std::vector<int> b_data(200, 0); |
4023 | std::vector<int> c_data(200, 0); |
4024 | cg.call({b_data, c_data}); |
4025 | |
4026 | std::vector<int> b_ref(200, 0); |
4027 | std::vector<int> c_ref(200, 0); |
4028 | |
4029 | for (int i = 0; i < 20; ++i) { |
4030 | for (int j = 0; j < 10; ++j) { |
4031 | b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); |
4032 | c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); |
4033 | } |
4034 | } |
4035 | |
4036 | assertAllEqual(b_data, b_ref); |
4037 | assertAllEqual(c_data, c_ref); |
4038 | } |
4039 | |
4040 | TEST(LoopNest, DeadStoreElimination) { |
4041 | VarHandle y("y" , kInt); |
4042 | VarHandle x("x_tail" , kInt); |
4043 | BufHandle f("f" , {26, 5}, kInt); |
4044 | BufHandle g("g" , {26, 5}, kInt); |
4045 | ExprHandle x_outer_end = 5; |
4046 | ExprHandle x_2 = x + x_outer_end * 4; |
4047 | ForPtr stmt1 = For::make( |
4048 | x, |
4049 | 0, |
4050 | 5, |
4051 | For::make( |
4052 | y, |
4053 | 0, |
4054 | 5, |
4055 | Block::make({ |
4056 | Store::make(f, {x_2, y}, (x_2 + y)), |
4057 | Store::make(g, {x_2, y}, (x_2 * y)), |
4058 | }))); |
4059 | StmtPtr stmt = Block::make({stmt1}); |
4060 | |
4061 | // Will eliminate if not used by an output. |
4062 | LoopNest loop(Stmt::clone(stmt), {f.node()}); |
4063 | loop.eliminateDeadStores(); |
4064 | |
4065 | checkIR(loop.root_stmt(), R"IR( |
4066 | #CHECK: f[x_tail + 5 * 4, y] |
4067 | #CHECK-NOT: g[x_tail + 5 * 4, y] |
4068 | )IR" ); |
4069 | |
4070 | // But won't eliminate if used by different outputs. |
4071 | LoopNest loop2(stmt, {f.node(), g.node()}); |
4072 | loop2.eliminateDeadStores(); |
4073 | |
4074 | checkIR(loop2.root_stmt(), R"IR( |
4075 | #CHECK: f[x_tail + 5 * 4, y] |
4076 | #CHECK: g[x_tail + 5 * 4, y] |
4077 | )IR" ); |
4078 | } |
4079 | |
4080 | TEST(LoopNest, DeadStoreEliminationWithIntermediates) { |
4081 | VarHandle x("x" , kInt); |
4082 | VarHandle y("y" , kInt); |
4083 | VarHandle z("z" , kInt); |
4084 | BufHandle f("f" , {26 * 5}, kInt); |
4085 | BufHandle g("g" , {26 * 5}, kInt); |
4086 | BufHandle h("h" , {26, 5}, kInt); |
4087 | ExprHandle x_outer_end = 5; |
4088 | ExprHandle x_2 = x + x_outer_end * 4; |
4089 | ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x)); |
4090 | ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1)); |
4091 | ForPtr stmt3 = For::make( |
4092 | x, |
4093 | 0, |
4094 | 5, |
4095 | For::make( |
4096 | y, |
4097 | 0, |
4098 | 5, |
4099 | Block::make({ |
4100 | Store::make(h, {x, y}, Load::make(f, {x * y})), |
4101 | }))); |
4102 | StmtPtr stmt = Block::make({stmt1, stmt2, stmt3}); |
4103 | |
4104 | // Will eliminate the write to g, but not f since it used by the producer of |
4105 | // h. |
4106 | LoopNest loop(Stmt::clone(stmt), {h.node()}); |
4107 | loop.eliminateDeadStores(); |
4108 | |
4109 | checkIR(loop.root_stmt(), R"IR( |
4110 | #CHECK: f[x] = x; |
4111 | #CHECK-NOT: g[z] = |
4112 | #CHECK: h[x, y] = f[x * y]; |
4113 | )IR" ); |
4114 | |
4115 | // Sanity check won't eliminate if g is an output. |
4116 | LoopNest loop2(stmt, {h.node(), g.node()}); |
4117 | loop2.eliminateDeadStores(); |
4118 | |
4119 | checkIR(loop2.root_stmt(), R"IR( |
4120 | #CHECK: f[x] = x; |
4121 | #CHECK: g[z] = z + 1; |
4122 | #CHECK: h[x, y] = f[x * y]; |
4123 | )IR" ); |
4124 | } |
4125 | |
4126 | TEST(LoopNest, CompoundTensorSimple) { |
4127 | BufHandle a_buf("A" , {10, 5}, kInt); |
4128 | VarHandle i("i" , kInt); |
4129 | VarHandle j("j" , kInt); |
4130 | VarHandle x("x" , kInt); |
4131 | VarHandle y("y" , kInt); |
4132 | auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
4133 | auto inner_for1 = For::make(j, 0, 5, for_body1); |
4134 | auto outer_for1 = For::make(i, 0, 10, inner_for1); |
4135 | auto for_body2 = Block::make( |
4136 | {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); |
4137 | auto inner_for2 = For::make(y, 0, 5, for_body2); |
4138 | auto outer_for2 = For::make(x, 0, 10, inner_for2); |
4139 | BlockPtr body = Block::make({outer_for1, outer_for2}); |
4140 | |
4141 | Tensor A = Tensor(a_buf.node(), body); |
4142 | |
4143 | LoopNest l({A}); |
4144 | l.prepareForCodegen(); |
4145 | |
4146 | std::vector<int> a_data(50, 0); |
4147 | |
4148 | StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
4149 | SimpleIREvaluator cg(s, {A}); |
4150 | |
4151 | std::vector<int> a_ref(50, 0); |
4152 | |
4153 | for (int i = 0; i < 10; ++i) { |
4154 | for (int j = 0; j < 5; ++j) { |
4155 | a_ref[i * 5 + j] = (i * j) + i + j; |
4156 | } |
4157 | } |
4158 | cg.call({a_data}); |
4159 | |
4160 | assertAllEqual(a_data, a_ref); |
4161 | } |
4162 | |
4163 | TEST(LoopNest, InlineConstantIndex) { |
4164 | const int N = 10; |
4165 | BufHandle x_buf("a" , {1, N, 1}, kFloat); |
4166 | Tensor y = Compute( |
4167 | "f" , |
4168 | {1, N, 1}, |
4169 | [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { |
4170 | return x_buf.load(m, n, o); |
4171 | }); |
4172 | Tensor z = Compute( |
4173 | "f" , |
4174 | {1, N, 1}, |
4175 | [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { |
4176 | return y.load(m, n, o); |
4177 | }); |
4178 | |
4179 | LoopNest l({z}, {y, z}); |
4180 | l.simplify(); |
4181 | ASSERT_TRUE(l.computeInline(y.buf())); |
4182 | } |
4183 | |
4184 | TEST(LoopNest, CompoundTensorUsed) { |
4185 | BufHandle a_buf("A" , {10, 5}, kInt); |
4186 | VarHandle i("i" , kInt); |
4187 | VarHandle j("j" , kInt); |
4188 | VarHandle x("x" , kInt); |
4189 | VarHandle y("y" , kInt); |
4190 | auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
4191 | auto inner_for1 = For::make(j, 0, 5, for_body1); |
4192 | auto outer_for1 = For::make(i, 0, 10, inner_for1); |
4193 | auto for_body2 = Block::make( |
4194 | {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); |
4195 | auto inner_for2 = For::make(y, 0, 5, for_body2); |
4196 | auto outer_for2 = For::make(x, 0, 10, inner_for2); |
4197 | BlockPtr body = Block::make({outer_for1, outer_for2}); |
4198 | |
4199 | Tensor A = Tensor(a_buf.node(), body); |
4200 | Tensor B = Compute("B" , {10, 3}, [&](const VarHandle& i, const VarHandle& j) { |
4201 | return A.load(i, j + 1) + A.load(i, j + 2); |
4202 | }); |
4203 | |
4204 | LoopNest l({B}, {A, B}); |
4205 | ASSERT_FALSE(l.computeInline(A.buf())); |
4206 | l.prepareForCodegen(); |
4207 | |
4208 | std::vector<int> a_data(50, 0); |
4209 | std::vector<int> b_data(50, 0); |
4210 | |
4211 | StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
4212 | SimpleIREvaluator cg(s, {B}); |
4213 | |
4214 | std::vector<int> b_ref(50, 0); |
4215 | |
4216 | auto AT = [](int i, int j) { return i * j + i + j; }; |
4217 | for (int i = 0; i < 10; ++i) { |
4218 | for (int j = 0; j < 3; ++j) { |
4219 | b_ref[i * 3 + j] = AT(i, j + 1) + AT(i, j + 2); |
4220 | } |
4221 | } |
4222 | cg.call({b_data}); |
4223 | |
4224 | assertAllEqual(b_data, b_ref); |
4225 | } |
4226 | |
4227 | TEST(LoopNest, InlineFromLoad) { |
4228 | constexpr int N = 1024; |
4229 | BufHandle a("A" , {N}, kInt); |
4230 | BufHandle b("B" , {N}, kInt); |
4231 | VarHandle i("i" , kInt); |
4232 | VarHandle j("j" , kInt); |
4233 | auto store_a = For::make(i, 0, N, Store::make(a, {i}, i)); |
4234 | auto store_b = For::make(j, 0, N, Store::make(b, {j}, Load::make(a, {j}))); |
4235 | LoopNest l(Block::make({store_a, store_b}), {b.node()}); |
4236 | |
4237 | l.computeInline(a.node()); |
4238 | |
4239 | // Check that A[j] is replaced with j after inlining |
4240 | std::ostringstream oss; |
4241 | oss << *l.root_stmt(); |
4242 | torch::jit::testing::FileCheck().run( |
4243 | R"IR( |
4244 | # CHECK: for (int j |
4245 | # CHECK-NOT: B[j] = A[j] |
4246 | # CHECK-NEXT: B[j] = j |
4247 | )IR" , |
4248 | oss.str()); |
4249 | } |
4250 | |
4251 | TEST(LoopNest, OptimizeConditionalsSimple) { |
4252 | // Input IR: |
4253 | // for (int i = 0; i < 20; i++) { |
4254 | // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) |
4255 | // } |
4256 | |
4257 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4258 | BufHandle a_buf("A" , {20}, kInt); |
4259 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4260 | BufHandle b_buf("B" , {5}, kInt); |
4261 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4262 | BufHandle c_buf("C" , {15}, kInt); |
4263 | VarHandle i("i" , kInt); |
4264 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4265 | auto store = Store::make( |
4266 | a_buf, |
4267 | {i}, |
4268 | IfThenElse::make( |
4269 | CompareSelect::make(i, 5, kLT), |
4270 | Load::make(b_buf, {i}), |
4271 | Load::make(c_buf, {i - 5}))); |
4272 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4273 | auto forI = For::make(i, 0, 20, store); |
4274 | auto par = Block::make({forI}); |
4275 | |
4276 | LoopNest nest(par, {a_buf.node()}); |
4277 | nest.optimizeConditionals(); |
4278 | |
4279 | std::ostringstream oss; |
4280 | oss << *nest.root_stmt(); |
4281 | const std::string& verification_pattern = |
4282 | R"IR( |
4283 | # CHECK: for (int i = 0; i < 5 |
4284 | # CHECK-NEXT: A[i] = B[i] |
4285 | # CHECK: for (int i = 0; i < 15 |
4286 | # CHECK-NEXT: A[i + 5] = C[i] |
4287 | )IR" ; |
4288 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
4289 | } |
4290 | |
4291 | TEST(LoopNest, OptimizeConditionalsNestedConditions) { |
4292 | // Input IR: |
4293 | // for (int i = 0; i < 20; i++) { |
4294 | // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) |
4295 | // } |
4296 | |
4297 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4298 | BufHandle a_buf("A" , {20}, kInt); |
4299 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4300 | BufHandle b_buf("B" , {5}, kInt); |
4301 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4302 | BufHandle c_buf("C" , {5}, kInt); |
4303 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4304 | BufHandle d_buf("D" , {10}, kInt); |
4305 | VarHandle i("i" , kInt); |
4306 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4307 | auto store = Store::make( |
4308 | a_buf, |
4309 | {i}, |
4310 | IfThenElse::make( |
4311 | CompareSelect::make(i, 10, kLT), |
4312 | IfThenElse::make( |
4313 | CompareSelect::make(i, 5, kLT), |
4314 | Load::make(b_buf, {i}), |
4315 | Load::make(c_buf, {i - 5})), |
4316 | Load::make(d_buf, {i - 10}))); |
4317 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4318 | auto forI = For::make(i, 0, 20, store); |
4319 | auto par = Block::make({forI}); |
4320 | |
4321 | LoopNest nest(par, {a_buf.node()}); |
4322 | nest.optimizeConditionals(); |
4323 | |
4324 | std::ostringstream oss; |
4325 | oss << *nest.root_stmt(); |
4326 | const std::string& verification_pattern = |
4327 | R"IR( |
4328 | # CHECK: for (int i = 0; i < 5 |
4329 | # CHECK-NEXT: A[i] = B[i] |
4330 | # CHECK: for (int i = 0; i < 5 |
4331 | # CHECK-NEXT: A[i + 5] = C[i] |
4332 | # CHECK: for (int i = 0; i < 10 |
4333 | # CHECK-NEXT: A[i + 10] = D[i] |
4334 | )IR" ; |
4335 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
4336 | } |
4337 | |
4338 | TEST(LoopNest, OptimizeConditionalsMultipleStores) { |
4339 | // Input IR: |
4340 | // for (int i = 0; i < 20; i++) { |
4341 | // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) |
4342 | // } |
4343 | // for (int j = 0; j < 100; j++) { |
4344 | // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) |
4345 | // } |
4346 | |
4347 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4348 | BufHandle a_buf("A" , {20}, kInt); |
4349 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4350 | BufHandle b_buf("B" , {5}, kInt); |
4351 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4352 | BufHandle c_buf("C" , {100}, kInt); |
4353 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4354 | BufHandle d_buf("D" , {100}, kInt); |
4355 | VarHandle i("i" , kInt); |
4356 | VarHandle j("j" , kInt); |
4357 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4358 | auto storeA = Store::make( |
4359 | a_buf, |
4360 | {i}, |
4361 | IfThenElse::make( |
4362 | CompareSelect::make(i, 5, kLT), |
4363 | Load::make(b_buf, {i}), |
4364 | Load::make(c_buf, {i - 5}))); |
4365 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4366 | auto forI = For::make(i, 0, 20, storeA); |
4367 | auto storeB = Store::make( |
4368 | b_buf, |
4369 | {j}, |
4370 | IfThenElse::make( |
4371 | CompareSelect::make(j, 30, kLT), |
4372 | Load::make(c_buf, {j}), |
4373 | Load::make(d_buf, {j}))); |
4374 | auto forJ = For::make(j, 0, 100, storeB); |
4375 | auto par = Block::make({forI, forJ}); |
4376 | |
4377 | LoopNest nest(par, {a_buf.node()}); |
4378 | nest.optimizeConditionals(); |
4379 | |
4380 | std::ostringstream oss; |
4381 | oss << *nest.root_stmt(); |
4382 | const std::string& verification_pattern = |
4383 | R"IR( |
4384 | # CHECK: for (int i = 0; i < 5 |
4385 | # CHECK-NEXT: A[i] = B[i] |
4386 | # CHECK: for (int i = 0; i < 15 |
4387 | # CHECK-NEXT: A[i + 5] = C[i] |
4388 | # CHECK: for (int j = 0; j < 30 |
4389 | # CHECK-NEXT: B[j] = C[j] |
4390 | # CHECK: for (int j = 0; j < 70 |
4391 | # CHECK-NEXT: B[j + 30] = D[j + 30] |
4392 | )IR" ; |
4393 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
4394 | } |
4395 | |
4396 | TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) { |
4397 | // Input IR: |
4398 | // for (int i = 0; i < 50; i++) { |
4399 | // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) |
4400 | // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) |
4401 | // } |
4402 | // Only the first conditional, in the write to A, will be optimized. |
4403 | |
4404 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4405 | BufHandle a_buf("A" , {100}, kInt); |
4406 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4407 | BufHandle b_buf("B" , {100}, kInt); |
4408 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4409 | BufHandle c_buf("C" , {100}, kInt); |
4410 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4411 | BufHandle d_buf("D" , {100}, kInt); |
4412 | VarHandle i("i" , kInt); |
4413 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4414 | auto storeA = Store::make( |
4415 | a_buf, |
4416 | {i}, |
4417 | IfThenElse::make( |
4418 | CompareSelect::make(i, 5, kLT), |
4419 | Load::make(b_buf, {i}), |
4420 | Load::make(c_buf, {i - 5}))); |
4421 | auto storeB = Store::make( |
4422 | b_buf, |
4423 | {i}, |
4424 | IfThenElse::make( |
4425 | CompareSelect::make(i, 30, kLT), |
4426 | Load::make(c_buf, {i}), |
4427 | Load::make(d_buf, {i}))); |
4428 | auto forI = For::make(i, 0, 50, Block::make({storeA, storeB})); |
4429 | auto par = Block::make({forI}); |
4430 | |
4431 | LoopNest nest(par, {a_buf.node()}); |
4432 | nest.optimizeConditionals(); |
4433 | |
4434 | std::ostringstream oss; |
4435 | oss << *nest.root_stmt(); |
4436 | const std::string& verification_pattern = |
4437 | R"IR( |
4438 | # CHECK: for (int i = 0; i < 5 |
4439 | # CHECK-NEXT: A[i] = B[i] |
4440 | # CHECK-NEXT: B[i] = C[i] |
4441 | # CHECK: for (int i = 0; i < 45 |
4442 | # CHECK-NEXT: A[i + 5] = C[i] |
4443 | # CHECK-NEXT: B[i + 5] = IfThenElse(i + 5<30 ? 1 : 0, C[i + 5], D[i + 5]) |
4444 | )IR" ; |
4445 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
4446 | } |
4447 | |
4448 | TEST(LoopNest, OptimizeConditionalsOuterLoopVar) { |
4449 | // Input IR: |
4450 | // for (int i = 0; i < 20; i++) { |
4451 | // for (int j = 0; j < 100; j++) { |
4452 | // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) |
4453 | // } |
4454 | // } |
4455 | // Currently, this case where the condition variable `i` is not the |
4456 | // inner-most loop variable, is not optimized. |
4457 | |
4458 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4459 | BufHandle a_buf("A" , {20}, kInt); |
4460 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4461 | BufHandle b_buf("B" , {5}, kInt); |
4462 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4463 | BufHandle c_buf("C" , {5}, kInt); |
4464 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4465 | BufHandle d_buf("D" , {10}, kInt); |
4466 | VarHandle i("i" , kInt); |
4467 | VarHandle j("j" , kInt); |
4468 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4469 | auto store = Store::make( |
4470 | a_buf, |
4471 | {i}, |
4472 | IfThenElse::make( |
4473 | CompareSelect::make(i, 10, kLT), |
4474 | IfThenElse::make( |
4475 | CompareSelect::make(i, 5, kLT), |
4476 | Load::make(b_buf, {i}), |
4477 | Load::make(c_buf, {i - 5})), |
4478 | Load::make(d_buf, {i - 10}))); |
4479 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4480 | auto forI = For::make(i, 0, 20, For::make(j, 0, 100, store)); |
4481 | auto par = Block::make({forI}); |
4482 | LoopNest nest(par, {a_buf.node()}); |
4483 | |
4484 | HashProvider hasher; |
4485 | auto hash_before = hasher.hash(nest.root_stmt()); |
4486 | nest.optimizeConditionals(); |
4487 | auto hash_after = hasher.hash(nest.root_stmt()); |
4488 | ASSERT_EQ(hash_before, hash_after); |
4489 | } |
4490 | |
4491 | TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) { |
4492 | // Input IR: |
4493 | // for (int i = 0; i < 20; i++) { |
4494 | // A[i] = IfThenElse(i<5, IfThenElse(i<10, B[i], C[i-5]), D[i-10]) |
4495 | // } |
4496 | // No optimization should be done here because one of the conditions use '>'. |
4497 | |
4498 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4499 | BufHandle a_buf("A" , {20}, kInt); |
4500 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4501 | BufHandle b_buf("B" , {5}, kInt); |
4502 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4503 | BufHandle c_buf("C" , {5}, kInt); |
4504 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4505 | BufHandle d_buf("D" , {10}, kInt); |
4506 | VarHandle i("i" , kInt); |
4507 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4508 | auto store = Store::make( |
4509 | a_buf, |
4510 | {i}, |
4511 | IfThenElse::make( |
4512 | CompareSelect::make(i, 5, kLT), |
4513 | IfThenElse::make( |
4514 | CompareSelect::make(i, 10, kLT), |
4515 | Load::make(b_buf, {i}), |
4516 | Load::make(c_buf, {i - 5})), |
4517 | Load::make(d_buf, {i - 10}))); |
4518 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4519 | auto forI = For::make(i, 0, 20, store); |
4520 | auto par = Block::make({forI}); |
4521 | LoopNest nest(par, {a_buf.node()}); |
4522 | |
4523 | HashProvider hasher; |
4524 | auto hash_before = hasher.hash(nest.root_stmt()); |
4525 | nest.optimizeConditionals(); |
4526 | auto hash_after = hasher.hash(nest.root_stmt()); |
4527 | ASSERT_EQ(hash_before, hash_after); |
4528 | } |
4529 | |
4530 | TEST(LoopNest, OptimizeConditionalsCompValuesNotConstants) { |
4531 | // Input IR: |
4532 | // for (int i = 0; i < 20; i++) { |
4533 | // A[i] = IfThenElse(i<N, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) |
4534 | // } |
4535 | // No optimization should be done here because one of the conditions use '>'. |
4536 | |
4537 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4538 | BufHandle a_buf("A" , {20}, kInt); |
4539 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4540 | BufHandle b_buf("B" , {5}, kInt); |
4541 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4542 | BufHandle c_buf("C" , {5}, kInt); |
4543 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4544 | BufHandle d_buf("D" , {10}, kInt); |
4545 | VarHandle i("i" , kInt); |
4546 | VarHandle N("N" , kInt); |
4547 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4548 | auto store = Store::make( |
4549 | a_buf, |
4550 | {i}, |
4551 | IfThenElse::make( |
4552 | CompareSelect::make(i, N, kLT), |
4553 | IfThenElse::make( |
4554 | CompareSelect::make(i, 5, kLT), |
4555 | Load::make(b_buf, {i}), |
4556 | Load::make(c_buf, {i - 5})), |
4557 | Load::make(d_buf, {i - 10}))); |
4558 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4559 | auto forI = For::make(i, 0, 20, store); |
4560 | auto par = Block::make({forI}); |
4561 | LoopNest nest(par, {a_buf.node()}); |
4562 | |
4563 | HashProvider hasher; |
4564 | auto hash_before = hasher.hash(nest.root_stmt()); |
4565 | nest.optimizeConditionals(); |
4566 | auto hash_after = hasher.hash(nest.root_stmt()); |
4567 | ASSERT_EQ(hash_before, hash_after); |
4568 | } |
4569 | |
4570 | TEST(LoopNest, OptimizeConditionalsInvalidCondition) { |
4571 | // Input IR: |
4572 | // for (int i = 0; i < 20; i++) { |
4573 | // A[i] = IfThenElse(i<10, IfThenElse(i>5, B[i], C[i-5]), D[i-10]) |
4574 | // } |
4575 | // No optimization should be done here because one of the conditions use '>'. |
4576 | |
4577 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4578 | BufHandle a_buf("A" , {20}, kInt); |
4579 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4580 | BufHandle b_buf("B" , {5}, kInt); |
4581 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4582 | BufHandle c_buf("C" , {5}, kInt); |
4583 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4584 | BufHandle d_buf("D" , {10}, kInt); |
4585 | VarHandle i("i" , kInt); |
4586 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4587 | auto store = Store::make( |
4588 | a_buf, |
4589 | {i}, |
4590 | IfThenElse::make( |
4591 | CompareSelect::make(i, 10, kLT), |
4592 | IfThenElse::make( |
4593 | CompareSelect::make(i, 5, kGT), |
4594 | Load::make(b_buf, {i}), |
4595 | Load::make(c_buf, {i - 5})), |
4596 | Load::make(d_buf, {i - 10}))); |
4597 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4598 | auto forI = For::make(i, 0, 20, store); |
4599 | auto par = Block::make({forI}); |
4600 | LoopNest nest(par, {a_buf.node()}); |
4601 | |
4602 | HashProvider hasher; |
4603 | auto hash_before = hasher.hash(nest.root_stmt()); |
4604 | nest.optimizeConditionals(); |
4605 | auto hash_after = hasher.hash(nest.root_stmt()); |
4606 | ASSERT_EQ(hash_before, hash_after); |
4607 | } |
4608 | |
4609 | TEST(LoopNest, OptimizeConditionalsInvalidCondition2) { |
4610 | // Input IR: |
4611 | // for (int i = 0; i < 20; i++) { |
4612 | // A[i] = IfThenElse(10<i, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) |
4613 | // } |
4614 | // No optimization should be done here because of the invalid condition: |
4615 | // "10 < i". |
4616 | |
4617 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4618 | BufHandle a_buf("A" , {20}, kInt); |
4619 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4620 | BufHandle b_buf("B" , {5}, kInt); |
4621 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4622 | BufHandle c_buf("C" , {5}, kInt); |
4623 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4624 | BufHandle d_buf("D" , {10}, kInt); |
4625 | VarHandle i("i" , kInt); |
4626 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4627 | auto store = Store::make( |
4628 | a_buf, |
4629 | {i}, |
4630 | IfThenElse::make( |
4631 | CompareSelect::make(10, i, kLT), |
4632 | IfThenElse::make( |
4633 | CompareSelect::make(i, 5, kLT), |
4634 | Load::make(b_buf, {i}), |
4635 | Load::make(c_buf, {i - 5})), |
4636 | Load::make(d_buf, {i - 10}))); |
4637 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4638 | auto forI = For::make(i, 0, 20, store); |
4639 | auto par = Block::make({forI}); |
4640 | LoopNest nest(par, {a_buf.node()}); |
4641 | |
4642 | HashProvider hasher; |
4643 | auto hash_before = hasher.hash(nest.root_stmt()); |
4644 | nest.optimizeConditionals(); |
4645 | auto hash_after = hasher.hash(nest.root_stmt()); |
4646 | ASSERT_EQ(hash_before, hash_after); |
4647 | } |
4648 | |
4649 | TEST(LoopNest, OptimizeConditionalsInvalidCondition3) { |
4650 | // Input IR: |
4651 | // for (int i = 0; i < 20; i++) { |
4652 | // A[i] = IfThenElse(i<10, IfThenElse(k<5, B[i], C[i-5]), D[i-10]) |
4653 | // } |
4654 | // No optimization should be done here because the conditions use different |
4655 | // variables: "i < 10" and "k < 5" |
4656 | |
4657 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4658 | BufHandle a_buf("A" , {20}, kInt); |
4659 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4660 | BufHandle b_buf("B" , {5}, kInt); |
4661 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4662 | BufHandle c_buf("C" , {5}, kInt); |
4663 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4664 | BufHandle d_buf("D" , {10}, kInt); |
4665 | VarHandle i("i" , kInt); |
4666 | VarHandle k("k" , kInt); |
4667 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4668 | auto store = Store::make( |
4669 | a_buf, |
4670 | {i}, |
4671 | IfThenElse::make( |
4672 | CompareSelect::make(i, 10, kLT), |
4673 | IfThenElse::make( |
4674 | CompareSelect::make(k, 5, kLT), |
4675 | Load::make(b_buf, {i}), |
4676 | Load::make(c_buf, {i - 5})), |
4677 | Load::make(d_buf, {i - 10}))); |
4678 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4679 | auto forI = For::make(i, 0, 20, store); |
4680 | auto par = Block::make({forI}); |
4681 | LoopNest nest(par, {a_buf.node()}); |
4682 | |
4683 | HashProvider hasher; |
4684 | auto hash_before = hasher.hash(nest.root_stmt()); |
4685 | nest.optimizeConditionals(); |
4686 | auto hash_after = hasher.hash(nest.root_stmt()); |
4687 | ASSERT_EQ(hash_before, hash_after); |
4688 | } |
4689 | |
4690 | TEST(LoopNest, OptimizeConditionalsInvalidCondition4) { |
4691 | // Input IR: |
4692 | // for (int i = 0; i < 20; i++) { |
4693 | // A[i] = IfThenElse(k<10, IfThenElse(k<5, B[i], C[i-5]), D[i-10]) |
4694 | // } |
4695 | // No optimization should be done here because the conditions use the |
4696 | // variable 'k' which is not a loop variable. |
4697 | |
4698 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4699 | BufHandle a_buf("A" , {20}, kInt); |
4700 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4701 | BufHandle b_buf("B" , {5}, kInt); |
4702 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4703 | BufHandle c_buf("C" , {5}, kInt); |
4704 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4705 | BufHandle d_buf("D" , {10}, kInt); |
4706 | VarHandle i("i" , kInt); |
4707 | VarHandle k("k" , kInt); |
4708 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4709 | auto store = Store::make( |
4710 | a_buf, |
4711 | {i}, |
4712 | IfThenElse::make( |
4713 | CompareSelect::make(k, 10, kLT), |
4714 | IfThenElse::make( |
4715 | CompareSelect::make(k, 5, kLT), |
4716 | Load::make(b_buf, {i}), |
4717 | Load::make(c_buf, {i - 5})), |
4718 | Load::make(d_buf, {i - 10}))); |
4719 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4720 | auto forI = For::make(i, 0, 20, store); |
4721 | auto par = Block::make({forI}); |
4722 | LoopNest nest(par, {a_buf.node()}); |
4723 | |
4724 | HashProvider hasher; |
4725 | auto hash_before = hasher.hash(nest.root_stmt()); |
4726 | nest.optimizeConditionals(); |
4727 | auto hash_after = hasher.hash(nest.root_stmt()); |
4728 | ASSERT_EQ(hash_before, hash_after); |
4729 | } |
4730 | |
4731 | TEST(LoopNest, OptimizeConditionalsNotNormalized) { |
4732 | // Input IR: |
4733 | // for (int i = 2; i < 20; i++) { |
4734 | // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) |
4735 | // } |
4736 | |
4737 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4738 | BufHandle a_buf("A" , {20}, kInt); |
4739 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4740 | BufHandle b_buf("B" , {5}, kInt); |
4741 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4742 | BufHandle c_buf("C" , {15}, kInt); |
4743 | VarHandle i("i" , kInt); |
4744 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4745 | auto store = Store::make( |
4746 | a_buf, |
4747 | {i}, |
4748 | IfThenElse::make( |
4749 | CompareSelect::make(i, 5, kLT), |
4750 | Load::make(b_buf, {i}), |
4751 | Load::make(c_buf, {i - 5}))); |
4752 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4753 | auto forI = For::make(i, 2, 20, store); |
4754 | auto par = Block::make({forI}); |
4755 | LoopNest nest(par, {a_buf.node()}); |
4756 | |
4757 | HashProvider hasher; |
4758 | auto hash_before = hasher.hash(nest.root_stmt()); |
4759 | nest.optimizeConditionals(); |
4760 | auto hash_after = hasher.hash(nest.root_stmt()); |
4761 | ASSERT_EQ(hash_before, hash_after); |
4762 | } |
4763 | |
4764 | static std::pair<BufHandle, Tensor> colReduce(int M, int N) { |
4765 | BufHandle a("a" , {M, N}, kFloat); |
4766 | Tensor t = Reduce( |
4767 | "b" , |
4768 | {N}, |
4769 | Sum(), |
4770 | [&](const VarHandle& n, const VarHandle& m) { return a.load(m, n); }, |
4771 | {M}); |
4772 | return {a, Tensor(t.buf(), LoopNest::sanitizeNames(t.stmt()))}; |
4773 | } |
4774 | |
4775 | static StmtPtr splitTailReorder(Tensor b) { |
4776 | constexpr int kVectorWidth = 8; |
4777 | LoopNest nest({b}); |
4778 | auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; |
4779 | nest.splitWithTail(loops[0], kVectorWidth); |
4780 | // Now the loopnests will look like: |
4781 | // |
4782 | // for (int i_outer = 0; ... |
4783 | // for (int i_inner = 0; ... |
4784 | // b[i_outer * 8 + i_inner] = float(0); |
4785 | // for (int j = 0; ... |
4786 | // b[i_outer * 8 + i_inner] = ReduceOp(...); |
4787 | // |
4788 | // for (int i_tail = 0; ... |
4789 | // b[i_tail + ((100 - 0) / 8) * 8] = float(0); |
4790 | // for (int j = 0; ... |
4791 | // b[i_tail + ((100 - 0) / 8) * 8] = ReduceOp(...); |
4792 | // |
4793 | // Since there are 4 writes to b, we will get 4 loopnests from the |
4794 | // call to `getAllLoopNestsWritingToBuf` below. |
4795 | // |
4796 | // Write #2: "b[i_outer * 8 + i_inner] = ReduceOp(...)" |
4797 | // Loopnest #2: {i_outer, i_inner, j}; |
4798 | // We will have to reorder i_inner and j. |
4799 | auto loopnests = nest.getAllLoopNestsWritingToBuf(b.buf()); |
4800 | LoopNest::reorderAxis(loopnests[1][1], loopnests[1][2]); |
4801 | nest.prepareForCodegen(); |
4802 | return nest.root_stmt(); |
4803 | } |
4804 | |
4805 | static StmtPtr splitMaskReorder(Tensor b) { |
4806 | constexpr int kVectorWidth = 8; |
4807 | LoopNest nest({b}); |
4808 | auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; |
4809 | nest.splitWithMask(loops[0], kVectorWidth); |
4810 | loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; |
4811 | LoopNest::reorderAxis(loops[1], loops[2]); |
4812 | nest.prepareForCodegen(); |
4813 | return nest.root_stmt(); |
4814 | } |
4815 | |
4816 | static void checkColReduce(StmtPtr s, BufHandle p, Tensor t) { |
4817 | int M = immediateAs<int>(p.dim(0)); |
4818 | int N = immediateAs<int>(p.dim(1)); |
4819 | PaddedBuffer<float> a(M, N); |
4820 | PaddedBuffer<float> b(N); |
4821 | PaddedBuffer<float> ref(N); |
4822 | for (int i = 0; i < M; i++) { |
4823 | for (int j = 0; j < N; j++) { |
4824 | a(i, j) = 1.0f; |
4825 | } |
4826 | } |
4827 | for (int i = 0; i < N; i++) { |
4828 | b(i) = 0.0f; |
4829 | } |
4830 | for (int i = 0; i < N; i++) { |
4831 | ref(i) = 76.0f; |
4832 | } |
4833 | SimpleIREvaluator(s, {p, t}).call({a, b}); |
4834 | ExpectAllNear(b, ref, 1e-5); |
4835 | } |
4836 | |
4837 | TEST(LoopNest, ColReduceSplitTailEvenReorder) { |
4838 | constexpr int M = 76, N = 128; |
4839 | auto p = colReduce(M, N); |
4840 | StmtPtr s = splitTailReorder(p.second); |
4841 | |
4842 | std::ostringstream oss; |
4843 | oss << *s; |
4844 | const std::string& verification_pattern = |
4845 | R"IR( |
4846 | # CHECK: for (int i_outer |
4847 | # CHECK-NEXT: for (int i_inner |
4848 | # CHECK-NEXT: b[ |
4849 | # CHECK: for (int j |
4850 | # CHECK-NEXT: for (int i_inner |
4851 | # CHECK-NEXT: b[ |
4852 | # CHECK-NOT: for ( |
4853 | )IR" ; |
4854 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
4855 | |
4856 | checkColReduce(s, p.first, p.second); |
4857 | } |
4858 | |
4859 | TEST(LoopNest, ColReduceSplitTailUnevenReorder) { |
4860 | constexpr int M = 76, N = 100; |
4861 | auto p = colReduce(M, N); |
4862 | StmtPtr s = splitTailReorder(p.second); |
4863 | |
4864 | std::ostringstream oss; |
4865 | oss << *s; |
4866 | const std::string& verification_pattern = |
4867 | R"IR( |
4868 | # CHECK: for (int i_outer |
4869 | # CHECK-NEXT: for (int i_inner |
4870 | # CHECK-NEXT: b[ |
4871 | # CHECK: for (int j |
4872 | # CHECK-NEXT: for (int i_inner |
4873 | # CHECK-NEXT: b[ |
4874 | # CHECK: for (int i_tail |
4875 | # CHECK-NEXT: b[ |
4876 | # CHECK-NEXT: for (int j |
4877 | # CHECK-NEXT: b[ |
4878 | )IR" ; |
4879 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
4880 | |
4881 | checkColReduce(s, p.first, p.second); |
4882 | } |
4883 | |
4884 | TEST(LoopNest, ColReduceSplitMaskEvenReorder) { |
4885 | constexpr int M = 76, N = 128; |
4886 | auto p = colReduce(M, N); |
4887 | StmtPtr s = splitMaskReorder(p.second); |
4888 | checkColReduce(s, p.first, p.second); |
4889 | } |
4890 | |
4891 | TEST(LoopNest, ColReduceSplitMaskUnevenReorder) { |
4892 | constexpr int M = 76, N = 100; |
4893 | auto p = colReduce(M, N); |
4894 | StmtPtr s = splitMaskReorder(p.second); |
4895 | checkColReduce(s, p.first, p.second); |
4896 | } |
4897 | |
4898 | TEST(LoopNest, ReorderAxisWithMultipleConds) { |
4899 | // Input IR: |
4900 | // for (int i = 0; i < 20; i++) { |
4901 | // if i > 5 { |
4902 | // if i < 10 { |
4903 | // for (int j = 0; j < 100; j++) { |
4904 | // A[i] = i * j; |
4905 | // } |
4906 | // } |
4907 | // } |
4908 | // } |
4909 | BufHandle a_buf("A" , {20}, kInt); |
4910 | VarHandle i("i" , kInt); |
4911 | VarHandle j("j" , kInt); |
4912 | auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i}, Mul::make(i, j))); |
4913 | auto inner_cond = Cond::make(CompareSelect::make(i, 10, kLT), forJ, nullptr); |
4914 | auto outer_cond = |
4915 | Cond::make(CompareSelect::make(i, 5, kGT), inner_cond, nullptr); |
4916 | auto forI = For::make(i, 0, 20, outer_cond); |
4917 | StmtPtr par = Block::make({forI}); |
4918 | LoopNest l(par, {a_buf.node()}); |
4919 | LoopNest::reorderAxis(forI, forJ); |
4920 | ASSERT_EQ(par, l.root_stmt()); |
4921 | par = IRSimplifier::simplify(par); |
4922 | |
4923 | const std::string& verification_pattern = |
4924 | R"IR( |
4925 | # CHECK: for (int j |
4926 | # CHECK-NEXT: for (int i |
4927 | # CHECK-NEXT: if (i>5 |
4928 | # CHECK-NEXT: if (i<10 |
4929 | # CHECK-NEXT: A[i] = i * j |
4930 | # CHECK-NOT: for ( |
4931 | )IR" ; |
4932 | std::ostringstream oss; |
4933 | oss << *par; |
4934 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
4935 | } |
4936 | |
4937 | TEST(LoopNest, VectorizeUse) { |
4938 | constexpr int N = 8; |
4939 | BufHandle a("a" , {N}, kFloat); |
4940 | Tensor b = |
4941 | Compute("b" , {N}, [&](const VarHandle& n) { return a.load(n) + 1.0f; }); |
4942 | Tensor c = |
4943 | Compute("c" , {N}, [&](const VarHandle& n) { return b.load(n) + 2.0f; }); |
4944 | LoopNest nest({c}, {b, c}); |
4945 | auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; |
4946 | ASSERT_TRUE(LoopNest::vectorize(loops[0])); |
4947 | loops = nest.getAllLoopNestsWritingToBuf(c.buf())[0]; |
4948 | ASSERT_TRUE(LoopNest::vectorize(loops[0])); |
4949 | nest.prepareForCodegen(); |
4950 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
4951 | StmtPtr s = nest.root_stmt(); |
4952 | std::ostringstream oss; |
4953 | oss << *nest.root_stmt(); |
4954 | torch::jit::testing::FileCheck().run( |
4955 | R"IR( |
4956 | # CHECK: c[Ramp |
4957 | )IR" , |
4958 | oss.str()); |
4959 | } |
4960 | |
// FileCheck pattern used by Int64Direct and Int64Compute. The printed IR must
// contain a loop over `int64_t i` from 0ll to 12ll whose body is
// `b[i] = (a[i]) + 1ll;`. The `ll` suffixes show that the bounds and the added
// constant are 64-bit literals.
const char* int64Loop = R"IR(
# CHECK: for (int64_t i = 0ll; i < 12ll; i++) {
# CHECK: b[i] = (a[i]) + 1ll;
# CHECK: }
)IR";
4966 | |
4967 | TEST(LoopNest, Int64Direct) { |
4968 | constexpr int64_t N = 12; |
4969 | BufHandle a("a" , {N}, kLong); |
4970 | BufHandle b("b" , {N}, kLong); |
4971 | VarHandle n("i" , kLong); |
4972 | StmtPtr s = For::make( |
4973 | n, LongImm::make(0l), N, b.store({n}, a.load({n}) + LongImm::make(1l))); |
4974 | s = IRSimplifier::simplify(s); |
4975 | std::ostringstream oss; |
4976 | oss << *s; |
4977 | torch::jit::testing::FileCheck().run(int64Loop, oss.str()); |
4978 | } |
4979 | |
4980 | TEST(LoopNest, Int64Compute) { |
4981 | constexpr int64_t N = 12; |
4982 | BufHandle a("a" , {N}, kLong); |
4983 | Tensor b = Compute("b" , {N}, [&](const VarHandle& n) { |
4984 | return a.load(n) + LongImm::make(1l); |
4985 | }); |
4986 | LoopNest nest({b}); |
4987 | nest.prepareForCodegen(); |
4988 | nest.simplify(); |
4989 | std::ostringstream oss; |
4990 | oss << *nest.root_stmt(); |
4991 | torch::jit::testing::FileCheck().run(int64Loop, oss.str()); |
4992 | } |
4993 | |
4994 | TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) { |
4995 | // Input IR: |
4996 | // for (int i = 0; i < 20; i++) { |
4997 | // A[i] = 0; |
4998 | // for (int j = 0; j < 100; j++) { |
4999 | // A[i] = A[i] + i * j; |
5000 | // } |
5001 | // B[i] = A[i]; |
5002 | // for (int k = 0; k < 50; k++) { |
5003 | // B[i] = B[i] + i * k; |
5004 | // } |
5005 | // } |
5006 | BufHandle a_buf("A" , {20}, kInt); |
5007 | BufHandle b_buf("B" , {20}, kInt); |
5008 | VarHandle i("i" , kInt); |
5009 | VarHandle j("j" , kInt); |
5010 | VarHandle k("k" , kInt); |
5011 | auto initA = Store::make(a_buf, {i}, 0); |
5012 | auto forJ = For::make( |
5013 | j, |
5014 | 0, |
5015 | 100, |
5016 | Store::make( |
5017 | a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); |
5018 | auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); |
5019 | auto forK = For::make( |
5020 | k, |
5021 | 0, |
5022 | 50, |
5023 | Store::make( |
5024 | b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); |
5025 | auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); |
5026 | auto par = Block::make({forI}); |
5027 | |
5028 | const std::string& verification_pattern = |
5029 | R"IR( |
5030 | # CHECK: for (int i |
5031 | # CHECK-NEXT: A[i] = 0 |
5032 | # CHECK: for (int i |
5033 | # CHECK-NEXT: for (int j |
5034 | # CHECK-NEXT: A[i] = |
5035 | # CHECK: for (int i |
5036 | # CHECK-NEXT: B[i] = A[i] |
5037 | # CHECK: for (int i |
5038 | # CHECK-NEXT: for (int k |
5039 | # CHECK-NEXT: B[i] = |
5040 | # CHECK-NOT: for ( |
5041 | )IR" ; |
5042 | |
5043 | LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
5044 | auto new_loops = LoopNest::distributeLoop(forI, {initA, forJ, initB}); |
5045 | |
5046 | std::ostringstream oss; |
5047 | oss << *par; |
5048 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5049 | |
5050 | // The first loop after distribution must be same as the original For. |
5051 | ASSERT_EQ(new_loops.front(), forI); |
5052 | } |
5053 | |
5054 | TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) { |
5055 | // Input IR: |
5056 | // for (int i = 0; i < 20; i++) { |
5057 | // A[i] = 0; |
5058 | // for (int j = 0; j < 100; j++) { |
5059 | // A[i] = A[i] + i * j; |
5060 | // } |
5061 | // B[i] = A[i]; |
5062 | // for (int k = 0; k < 50; k++) { |
5063 | // B[i] = B[i] + i * k; |
5064 | // } |
5065 | // } |
5066 | BufHandle a_buf("A" , {20}, kInt); |
5067 | BufHandle b_buf("B" , {20}, kInt); |
5068 | VarHandle i("i" , kInt); |
5069 | VarHandle j("j" , kInt); |
5070 | VarHandle k("k" , kInt); |
5071 | auto initA = Store::make(a_buf, {i}, 0); |
5072 | auto forJ = For::make( |
5073 | j, |
5074 | 0, |
5075 | 100, |
5076 | Store::make( |
5077 | a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); |
5078 | auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); |
5079 | auto forK = For::make( |
5080 | k, |
5081 | 0, |
5082 | 50, |
5083 | Store::make( |
5084 | b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); |
5085 | auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); |
5086 | auto par = Block::make({forI}); |
5087 | |
5088 | LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
5089 | auto new_loops = LoopNest::distributeLoop(forI, {forJ}); |
5090 | |
5091 | std::ostringstream oss; |
5092 | oss << *par; |
5093 | const std::string& verification_pattern = |
5094 | R"IR( |
5095 | # CHECK: for (int i |
5096 | # CHECK-NEXT: A[i] = 0 |
5097 | # CHECK-NEXT: for (int j |
5098 | # CHECK-NEXT: A[i] = |
5099 | # CHECK: for (int i |
5100 | # CHECK-NEXT: B[i] = A[i] |
5101 | # CHECK-NEXT: for (int k |
5102 | # CHECK-NEXT: B[i] = |
5103 | # CHECK-NOT: for ( |
5104 | )IR" ; |
5105 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5106 | |
5107 | // The first loop after distribution must be same as the original For. |
5108 | ASSERT_EQ(new_loops.front(), forI); |
5109 | } |
5110 | |
5111 | TEST(LoopNest, DistributeLoopWithoutAnyPivot) { |
5112 | // Input IR: |
5113 | // for (int i = 0; i < 20; i++) { |
5114 | // A[i] = 0; |
5115 | // for (int j = 0; j < 100; j++) { |
5116 | // A[i] = A[i] + i * j; |
5117 | // } |
5118 | // B[i] = A[i]; |
5119 | // for (int k = 0; k < 50; k++) { |
5120 | // B[i] = B[i] + i * k; |
5121 | // } |
5122 | // } |
5123 | BufHandle a_buf("A" , {20}, kInt); |
5124 | BufHandle b_buf("B" , {20}, kInt); |
5125 | VarHandle i("i" , kInt); |
5126 | VarHandle j("j" , kInt); |
5127 | VarHandle k("k" , kInt); |
5128 | auto initA = Store::make(a_buf, {i}, 0); |
5129 | auto forJ = For::make( |
5130 | j, |
5131 | 0, |
5132 | 100, |
5133 | Store::make( |
5134 | a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); |
5135 | auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); |
5136 | auto forK = For::make( |
5137 | k, |
5138 | 0, |
5139 | 50, |
5140 | Store::make( |
5141 | b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); |
5142 | auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); |
5143 | auto par = Block::make({forI}); |
5144 | |
5145 | const std::string& verification_pattern = |
5146 | R"IR( |
5147 | # CHECK: for (int i |
5148 | # CHECK-NEXT: A[i] = 0 |
5149 | # CHECK: for (int i |
5150 | # CHECK-NEXT: for (int j |
5151 | # CHECK-NEXT: A[i] = |
5152 | # CHECK: for (int i |
5153 | # CHECK-NEXT: B[i] = A[i] |
5154 | # CHECK: for (int i |
5155 | # CHECK-NEXT: for (int k |
5156 | # CHECK-NEXT: B[i] = |
5157 | # CHECK-NOT: for ( |
5158 | )IR" ; |
5159 | |
5160 | LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
5161 | auto new_loops = LoopNest::distributeLoop(forI); |
5162 | |
5163 | std::ostringstream oss; |
5164 | oss << *par; |
5165 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5166 | |
5167 | // The first loop after distribution must be same as the original For. |
5168 | ASSERT_EQ(new_loops.front(), forI); |
5169 | } |
5170 | |
5171 | TEST(LoopNest, DistributeLoopOverInnerLoops) { |
5172 | // Input IR: |
5173 | // for (int i = 0; i < 20; i++) { |
5174 | // A[i] = 0; |
5175 | // for (int j = 0; j < 100; j++) { |
5176 | // A[i] = A[i] + i * j; |
5177 | // } |
5178 | // B[i] = A[i]; |
5179 | // for (int k = 0; k < 50; k++) { |
5180 | // B[i] = B[i] + i * k; |
5181 | // } |
5182 | // } |
5183 | BufHandle a_buf("A" , {20}, kInt); |
5184 | BufHandle b_buf("B" , {20}, kInt); |
5185 | VarHandle i("i" , kInt); |
5186 | VarHandle j("j" , kInt); |
5187 | VarHandle k("k" , kInt); |
5188 | auto initA = Store::make(a_buf, {i}, 0); |
5189 | auto forJ = For::make( |
5190 | j, |
5191 | 0, |
5192 | 100, |
5193 | Store::make( |
5194 | a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); |
5195 | auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); |
5196 | auto forK = For::make( |
5197 | k, |
5198 | 0, |
5199 | 50, |
5200 | Store::make( |
5201 | b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); |
5202 | auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); |
5203 | auto par = Block::make({forI}); |
5204 | |
5205 | LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
5206 | auto new_loops = LoopNest::distributeLoopOverInnerLoops(forI); |
5207 | |
5208 | std::ostringstream oss; |
5209 | oss << *par; |
5210 | const std::string& verification_pattern = |
5211 | R"IR( |
5212 | # CHECK: for (int i |
5213 | # CHECK-NEXT: A[i] = 0 |
5214 | # CHECK-NEXT: for (int j |
5215 | # CHECK-NEXT: A[i] = |
5216 | # CHECK: for (int i |
5217 | # CHECK-NEXT: B[i] = A[i] |
5218 | # CHECK-NEXT: for (int k |
5219 | # CHECK-NEXT: B[i] = |
5220 | # CHECK-NOT: for ( |
5221 | )IR" ; |
5222 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5223 | |
5224 | // The first loop after distribution must be same as the original For. |
5225 | ASSERT_EQ(new_loops.front(), forI); |
5226 | } |
5227 | |
5228 | TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) { |
5229 | // Input IR: |
5230 | // for (int m = 0; m < 50; m++) { |
5231 | // for (int i = 0; i < 20; i++) { |
5232 | // A[m,i] = 0; |
5233 | // for (int j = 0; j < 100; j++) { |
5234 | // A[m,i] = A[m,i] + i * j; |
5235 | // } |
5236 | // B[m,i] = A[m,i]; |
5237 | // for (int k = 0; k < 50; k++) { |
5238 | // B[m,i] = B[m,i] + i * k; |
5239 | // } |
5240 | // } |
5241 | // } |
5242 | BufHandle a_buf("A" , {100, 100}, kInt); |
5243 | BufHandle b_buf("B" , {100, 100}, kInt); |
5244 | VarHandle m("m" , kInt); |
5245 | VarHandle i("i" , kInt); |
5246 | VarHandle j("j" , kInt); |
5247 | VarHandle k("k" , kInt); |
5248 | auto initA = Store::make(a_buf, {m, i}, 0); |
5249 | auto forJ = For::make( |
5250 | j, |
5251 | 0, |
5252 | 100, |
5253 | Store::make( |
5254 | a_buf, |
5255 | {m, i}, |
5256 | Add::make(Load::make(a_buf, {m, i}), Mul::make(i, j)))); |
5257 | auto initB = Store::make(b_buf, {m, i}, Load::make(a_buf, {m, i})); |
5258 | auto forK = For::make( |
5259 | k, |
5260 | 0, |
5261 | 50, |
5262 | Store::make( |
5263 | b_buf, |
5264 | {m, i}, |
5265 | Add::make(Load::make(b_buf, {m, i}), Mul::make(i, k)))); |
5266 | auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); |
5267 | |
5268 | { |
5269 | // Check the case of distributing loop and its parents over all the |
5270 | // statements in the loop. |
5271 | const std::string& verification_pattern = |
5272 | R"IR( |
5273 | # CHECK: for (int m |
5274 | # CHECK-NEXT: for (int i |
5275 | # CHECK-NEXT: A[m, i] = 0 |
5276 | # CHECK: for (int m |
5277 | # CHECK-NEXT: for (int i |
5278 | # CHECK-NEXT: for (int j |
5279 | # CHECK-NEXT: A[m, i] = |
5280 | # CHECK: for (int m |
5281 | # CHECK-NEXT: for (int i |
5282 | # CHECK-NEXT: B[m, i] = A[m, i] |
5283 | # CHECK: for (int m |
5284 | # CHECK-NEXT: for (int i |
5285 | # CHECK-NEXT: for (int k |
5286 | # CHECK-NEXT: B[m, i] = |
5287 | # CHECK-NOT: for ( |
5288 | )IR" ; |
5289 | |
5290 | auto newForI = to<For>(Stmt::clone(forI)); |
5291 | auto forM = For::make(m, 0, 50, newForI); |
5292 | auto par = Block::make({forM}); |
5293 | LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
5294 | auto newLoops = LoopNest::distributeLoopAndParents(newForI); |
5295 | |
5296 | std::ostringstream oss; |
5297 | oss << *par; |
5298 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5299 | |
5300 | // The first loop after distribution must be same as the original For. |
5301 | ASSERT_EQ(newLoops.front(), forM); |
5302 | } |
5303 | |
5304 | { |
5305 | // Check the case of distributing loop and its parents over all the inner |
5306 | // loops. |
5307 | const std::string& verification_pattern = |
5308 | R"IR( |
5309 | # CHECK: for (int m |
5310 | # CHECK-NEXT: for (int i |
5311 | # CHECK-NEXT: A[m, i] = 0 |
5312 | # CHECK-NEXT: for (int j |
5313 | # CHECK-NEXT: A[m, i] = |
5314 | # CHECK: for (int m |
5315 | # CHECK-NEXT: for (int i |
5316 | # CHECK-NEXT: B[m, i] = A[m, i] |
5317 | # CHECK-NEXT: for (int k |
5318 | # CHECK-NEXT: B[m, i] = |
5319 | # CHECK-NOT: for ( |
5320 | )IR" ; |
5321 | |
5322 | auto newForI = to<For>(Stmt::clone(forI)); |
5323 | auto forM = For::make(m, 0, 50, newForI); |
5324 | auto par = Block::make({forM}); |
5325 | LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
5326 | auto newLoops = LoopNest::distributeLoopAndParentsOverInnerLoops(newForI); |
5327 | |
5328 | std::ostringstream oss; |
5329 | oss << *par; |
5330 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5331 | |
5332 | // The first loop after distribution must be same as the original For. |
5333 | ASSERT_EQ(newLoops.front(), forM); |
5334 | } |
5335 | } |
5336 | |
5337 | TEST(LoopNest, fuseLoopsSimple) { |
5338 | // Input IR: |
5339 | // for (int j = 0; j < 100; j++) { |
5340 | // A[j] = 10 * j; |
5341 | // } |
5342 | // for (int k = 0; k < 100; k++) { |
5343 | // B[k] = 20 * k; |
5344 | // } |
5345 | BufHandle a_buf("A" , {100}, kInt); |
5346 | BufHandle b_buf("B" , {100}, kInt); |
5347 | VarHandle j("j" , kInt); |
5348 | VarHandle k("k" , kInt); |
5349 | auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
5350 | auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); |
5351 | auto par = Block::make({forJ, forK}); |
5352 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5353 | ForPtr fused_loop; |
5354 | ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
5355 | |
5356 | std::ostringstream oss; |
5357 | oss << *par; |
5358 | const std::string& verification_pattern = |
5359 | R"IR( |
5360 | # CHECK: for (int j |
5361 | # CHECK-NEXT: A[j] = |
5362 | # CHECK-NEXT: B[j] = |
5363 | # CHECK-NOT: for ( |
5364 | )IR" ; |
5365 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5366 | |
5367 | // The fused loop must be the same as the first loop. |
5368 | ASSERT_EQ(fused_loop, forJ); |
5369 | } |
5370 | |
5371 | TEST(LoopNest, fuseLoopsMultiple) { |
5372 | // Input IR: |
5373 | // for (int i = 0; i < 100; i++) { |
5374 | // A[i+100] = 20 + i; |
5375 | // } |
5376 | // for (int j = 0; j < 100; j++) { |
5377 | // A[j] = 10 * j; |
5378 | // } |
5379 | // for (int k = 0; k < 100; k++) { |
5380 | // B[k] = 20 * k; |
5381 | // } |
5382 | BufHandle a_buf("A" , {200}, kInt); |
5383 | BufHandle b_buf("B" , {100}, kInt); |
5384 | VarHandle i("i" , kInt); |
5385 | VarHandle j("j" , kInt); |
5386 | VarHandle k("k" , kInt); |
5387 | auto forI = |
5388 | For::make(i, 0, 100, Store::make(a_buf, {i + 100}, Add::make(20, i))); |
5389 | auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
5390 | auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); |
5391 | auto par = Block::make({forI, forJ, forK}); |
5392 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5393 | ForPtr fused_loop; |
5394 | ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop)); |
5395 | |
5396 | std::ostringstream oss; |
5397 | oss << *par; |
5398 | const std::string& verification_pattern = |
5399 | R"IR( |
5400 | # CHECK: for (int i |
5401 | # CHECK-NEXT: A[i + 100] = |
5402 | # CHECK-NEXT: A[i] = |
5403 | # CHECK-NEXT: B[i] = |
5404 | # CHECK-NOT: for ( |
5405 | )IR" ; |
5406 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5407 | |
5408 | // The fused loop must be the same as the first loop. |
5409 | ASSERT_EQ(fused_loop, forI); |
5410 | } |
5411 | |
5412 | TEST(LoopNest, fuseLoopsNested) { |
5413 | // Input IR: |
5414 | // for (int m = 0; m < 20; m++) { |
5415 | // A[m] = 0; |
5416 | // for (int j = 0; j < 100; j++) { |
5417 | // A[m] = A[m] + m * j; |
5418 | // } |
5419 | // } |
5420 | // for (int n = 0; n < 20; n++) { |
5421 | // B[n] = A[n]; |
5422 | // for (int k = 0; k < 50; k++) { |
5423 | // B[n] = B[n] + n * k; |
5424 | // } |
5425 | // } |
5426 | BufHandle a_buf("A" , {20, 100}, kInt); |
5427 | BufHandle b_buf("B" , {20, 100}, kInt); |
5428 | VarHandle m("m" , kInt); |
5429 | VarHandle n("n" , kInt); |
5430 | VarHandle j("j" , kInt); |
5431 | VarHandle k("k" , kInt); |
5432 | auto initA = Store::make(a_buf, {m}, 0); |
5433 | auto forJ = For::make( |
5434 | j, |
5435 | 0, |
5436 | 100, |
5437 | Store::make( |
5438 | a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); |
5439 | auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n})); |
5440 | auto forK = For::make( |
5441 | k, |
5442 | 0, |
5443 | 50, |
5444 | Store::make( |
5445 | b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); |
5446 | auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); |
5447 | auto forN = For::make(n, 0, 20, Block::make({initB, forK})); |
5448 | auto par = Block::make({forM, forN}); |
5449 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5450 | ForPtr fused_loop; |
5451 | ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); |
5452 | |
5453 | std::ostringstream oss; |
5454 | oss << *par; |
5455 | const std::string& verification_pattern = |
5456 | R"IR( |
5457 | # CHECK: for (int m |
5458 | # CHECK-NEXT: A[m] = 0 |
5459 | # CHECK-NEXT: for (int j |
5460 | # CHECK-NEXT: A[m] = |
5461 | # CHECK: B[m] = A[m] |
5462 | # CHECK-NEXT: for (int k |
5463 | # CHECK-NEXT: B[m] = |
5464 | # CHECK-NOT: for ( |
5465 | )IR" ; |
5466 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5467 | |
5468 | // The fused loop must be the same as the first loop. |
5469 | ASSERT_EQ(fused_loop, forM); |
5470 | } |
5471 | |
5472 | TEST(LoopNest, fuseLoopsNested2D) { |
5473 | // Input IR: |
5474 | // for (int i = 0; i < 20; i++) { |
5475 | // for (int j = 0; j < 100; j++) { |
5476 | // A[i,j] = i * j * 500; |
5477 | // } |
5478 | // } |
5479 | // for (int m = 0; m < 20; m++) { |
5480 | // for (int n = 0; n < 50; n++) { |
5481 | // B[m,n] = m + n * 100; |
5482 | // } |
5483 | // } |
5484 | BufHandle a_buf("A" , {20, 100}, kInt); |
5485 | BufHandle b_buf("B" , {20, 100}, kInt); |
5486 | VarHandle i("i" , kInt); |
5487 | VarHandle j("j" , kInt); |
5488 | VarHandle m("m" , kInt); |
5489 | VarHandle n("n" , kInt); |
5490 | auto forI = For::make( |
5491 | i, |
5492 | 0, |
5493 | 20, |
5494 | For::make( |
5495 | j, |
5496 | 0, |
5497 | 100, |
5498 | Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); |
5499 | auto forM = For::make( |
5500 | m, |
5501 | 0, |
5502 | 20, |
5503 | For::make( |
5504 | n, |
5505 | 0, |
5506 | 50, |
5507 | Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))))); |
5508 | auto par = Block::make({forI, forM}); |
5509 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5510 | ForPtr fused_loop; |
5511 | ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
5512 | |
5513 | std::ostringstream oss; |
5514 | oss << *par; |
5515 | const std::string& verification_pattern = |
5516 | R"IR( |
5517 | # CHECK: for (int i |
5518 | # CHECK-NEXT: for (int j |
5519 | # CHECK-NEXT: A[i, j] = |
5520 | # CHECK: for (int n |
5521 | # CHECK-NEXT: B[i, n] = |
5522 | # CHECK-NOT: for ( |
5523 | )IR" ; |
5524 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5525 | |
5526 | // The fused loop must be the same as the first loop. |
5527 | ASSERT_EQ(fused_loop, forI); |
5528 | } |
5529 | |
5530 | TEST(LoopNest, fuseLoopsNested2DInner) { |
5531 | // Input IR: |
5532 | // for (int i = 0; i < 20; i++) { |
5533 | // for (int j = 0; j < 100; j++) { |
5534 | // A[i,j] = i * j * 500; |
5535 | // } |
5536 | // for (int n = 0; n < 100; n++) { |
5537 | // B[i,n] = m + n * 100; |
5538 | // } |
5539 | // } |
5540 | BufHandle a_buf("A" , {20, 100}, kInt); |
5541 | BufHandle b_buf("B" , {20, 100}, kInt); |
5542 | VarHandle i("i" , kInt); |
5543 | VarHandle j("j" , kInt); |
5544 | VarHandle n("n" , kInt); |
5545 | auto forJ = For::make( |
5546 | j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); |
5547 | auto forN = For::make( |
5548 | n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100)))); |
5549 | auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); |
5550 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5551 | ForPtr fused_loop; |
5552 | ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); |
5553 | |
5554 | std::ostringstream oss; |
5555 | oss << *forI; |
5556 | const std::string& verification_pattern = |
5557 | R"IR( |
5558 | # CHECK: for (int i |
5559 | # CHECK-NEXT: for (int j |
5560 | # CHECK-NEXT: A[i, j] = |
5561 | # CHECK-NEXT: B[i, j] = |
5562 | # CHECK-NOT: for ( |
5563 | )IR" ; |
5564 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5565 | |
5566 | // The fused loop must be the same as the first loop. |
5567 | ASSERT_EQ(fused_loop, forJ); |
5568 | } |
5569 | |
5570 | TEST(LoopNest, fuseLoopsDifferentStopBounds) { |
5571 | // Input IR: |
5572 | // for (int j = 0; j < 100; j++) { |
5573 | // A[j] = 10 * j; |
5574 | // } |
5575 | // for (int k = 0; k < 50; k++) { |
5576 | // B[k] = 20 * k; |
5577 | // } |
5578 | BufHandle a_buf("A" , {100}, kInt); |
5579 | BufHandle b_buf("B" , {100}, kInt); |
5580 | VarHandle j("j" , kInt); |
5581 | VarHandle k("k" , kInt); |
5582 | auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
5583 | auto forK = For::make(k, 0, 50, Store::make(b_buf, {j}, Mul::make(20, k))); |
5584 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
5585 | auto par = Block::make({forJ, forK}); |
5586 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5587 | ForPtr fused_loop; |
5588 | ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
5589 | } |
5590 | |
5591 | TEST(LoopNest, fuseLoopsDifferentStartBounds) { |
5592 | // Input IR: |
5593 | // for (int j = 0; j < 100; j++) { |
5594 | // A[j] = 10 * j; |
5595 | // } |
5596 | // for (int k = 50; k < 100; k++) { |
5597 | // B[k] = 20 * k; |
5598 | // } |
5599 | BufHandle a_buf("A" , {100}, kInt); |
5600 | BufHandle b_buf("B" , {100}, kInt); |
5601 | VarHandle j("j" , kInt); |
5602 | VarHandle k("k" , kInt); |
5603 | auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
5604 | auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); |
5605 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
5606 | auto par = Block::make({forJ, forK}); |
5607 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5608 | ForPtr fused_loop; |
5609 | ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
5610 | } |
5611 | |
5612 | TEST(LoopNest, fuseLoopsNotContiguous) { |
5613 | // Input IR: |
5614 | // for (int j = 0; j < 100; j++) { |
5615 | // A[j] = 10 * j; |
5616 | // } |
5617 | // B[0] = 0; |
5618 | // for (int k = 0; k < 100; k++) { |
5619 | // B[k] = 20 * k; |
5620 | // } |
5621 | BufHandle a_buf("A" , {100}, kInt); |
5622 | BufHandle b_buf("B" , {100}, kInt); |
5623 | VarHandle j("j" , kInt); |
5624 | VarHandle k("k" , kInt); |
5625 | auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
5626 | auto initB = Store::make(b_buf, {0}, 0); |
5627 | auto forK = For::make(k, 0, 100, Store::make(b_buf, {j}, Mul::make(20, k))); |
5628 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
5629 | auto par = Block::make({forJ, initB, forK}); |
5630 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5631 | ForPtr fused_loop; |
5632 | ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
5633 | } |
5634 | |
5635 | TEST(LoopNest, fuseLoopsWithDifferentParents) { |
5636 | // Input IR: |
5637 | // for (int i = 0; i < 50; i++) { |
5638 | // for (int j = 0; j < 100; j++) { |
5639 | // A[i,j] = i * j; |
5640 | // } |
5641 | // } |
5642 | // B[0] = 0; |
5643 | // for (int k = 50; k < 100; k++) { |
5644 | // B[k] = 20 * k; |
5645 | // } |
5646 | BufHandle a_buf("A" , {50, 100}, kInt); |
5647 | BufHandle b_buf("B" , {100}, kInt); |
5648 | VarHandle i("i" , kInt); |
5649 | VarHandle j("j" , kInt); |
5650 | VarHandle k("k" , kInt); |
5651 | auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(i, j))); |
5652 | auto forI = For::make(i, 0, 50, forJ); |
5653 | auto initB = Store::make(b_buf, {0}, 0); |
5654 | auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); |
5655 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
5656 | auto par = Block::make({forI, initB, forK}); |
5657 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5658 | ForPtr fused_loop; |
5659 | ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
5660 | } |
5661 | |
5662 | TEST(LoopNest, fuseLoopsWithVariableBounds) { |
5663 | // Input IR: |
5664 | // for (int j = 0; j < N; j++) { |
5665 | // A[j] = 10 * j; |
5666 | // } |
5667 | // for (int k = 0; k < N; k++) { |
5668 | // B[k] = 20 * k; |
5669 | // } |
5670 | BufHandle a_buf("A" , {20}, kInt); |
5671 | BufHandle b_buf("B" , {20}, kInt); |
5672 | VarHandle j("j" , kInt); |
5673 | VarHandle k("k" , kInt); |
5674 | VarHandle N("N" , kInt); |
5675 | auto forJ = For::make(j, 0, N, Store::make(a_buf, {j}, Mul::make(10, j))); |
5676 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) |
5677 | auto forK = For::make(k, 0, N, Store::make(b_buf, {j}, Mul::make(20, k))); |
5678 | auto par = Block::make({forJ, forK}); |
5679 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5680 | ForPtr fused_loop; |
5681 | ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
5682 | |
5683 | std::ostringstream oss; |
5684 | oss << *par; |
5685 | const std::string& verification_pattern = |
5686 | R"IR( |
5687 | # CHECK: for (int j |
5688 | # CHECK-NEXT: A[j] = |
5689 | # CHECK-NEXT: B[j] = |
5690 | # CHECK-NOT: for ( |
5691 | )IR" ; |
5692 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5693 | |
5694 | // The fused loop must be the same as the first loop. |
5695 | ASSERT_EQ(fused_loop, forJ); |
5696 | } |
5697 | |
5698 | TEST(LoopNest, fuseLoopsWithExprBounds) { |
5699 | // Input IR: |
5700 | // for (int j = 0; j < M + N; j++) { |
5701 | // A[j] = 10 * j; |
5702 | // } |
5703 | // for (int k = 0; k < M + N; k++) { |
5704 | // B[k] = 20 * k; |
5705 | // } |
5706 | BufHandle a_buf("A" , {20}, kInt); |
5707 | BufHandle b_buf("B" , {20}, kInt); |
5708 | VarHandle j("j" , kInt); |
5709 | VarHandle k("k" , kInt); |
5710 | VarHandle M("M" , kInt); |
5711 | VarHandle N("N" , kInt); |
5712 | auto forJ = For::make(j, 0, M + N, Store::make(a_buf, {j}, Mul::make(10, j))); |
5713 | auto forK = For::make(k, 0, M + N, Store::make(b_buf, {j}, Mul::make(20, k))); |
5714 | auto par = Block::make({forJ, forK}); |
5715 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5716 | ForPtr fused_loop; |
5717 | ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
5718 | |
5719 | std::ostringstream oss; |
5720 | oss << *par; |
5721 | const std::string& verification_pattern = |
5722 | R"IR( |
5723 | # CHECK: for (int j |
5724 | # CHECK-NEXT: A[j] = |
5725 | # CHECK-NEXT: B[j] = |
5726 | # CHECK-NOT: for ( |
5727 | )IR" ; |
5728 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5729 | |
5730 | // The fused loop must be the same as the first loop. |
5731 | ASSERT_EQ(fused_loop, forJ); |
5732 | } |
5733 | |
5734 | TEST(LoopNest, fuseLoopsWithDifferentExprBounds) { |
5735 | // Input IR: |
5736 | // for (int j = M; j < N * 2; j++) { |
5737 | // A[j] = 10 * j; |
5738 | // } |
5739 | // for (int k = M; k < N + N; k++) { |
5740 | // B[k] = 20 * k; |
5741 | // } |
5742 | BufHandle a_buf("A" , {20}, kInt); |
5743 | BufHandle b_buf("B" , {20}, kInt); |
5744 | VarHandle j("j" , kInt); |
5745 | VarHandle k("k" , kInt); |
5746 | VarHandle M("M" , kInt); |
5747 | VarHandle N("N" , kInt); |
5748 | auto forJ = For::make(j, M, N * 2, Store::make(a_buf, {j}, Mul::make(10, j))); |
5749 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) |
5750 | auto forK = For::make(k, M, N + N, Store::make(b_buf, {j}, Mul::make(20, k))); |
5751 | auto par = Block::make({forJ, forK}); |
5752 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5753 | ForPtr fused_loop; |
5754 | ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
5755 | |
5756 | std::ostringstream oss; |
5757 | oss << *par; |
5758 | const std::string& verification_pattern = |
5759 | R"IR( |
5760 | # CHECK: for (int j |
5761 | # CHECK-NEXT: A[j] = |
5762 | # CHECK-NEXT: B[j] = |
5763 | # CHECK-NOT: for ( |
5764 | )IR" ; |
5765 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5766 | |
5767 | // The fused loop must be the same as the first loop. |
5768 | ASSERT_EQ(fused_loop, forJ); |
5769 | } |
5770 | |
5771 | TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) { |
5772 | // Input IR: |
5773 | // for (int j = 10; j < 100; j++) { |
5774 | // A[j] = 10 * j; |
5775 | // } |
5776 | // for (int k = 10; k < 100; k++) { |
5777 | // A[k+100] = 30 * k |
5778 | // } |
5779 | BufHandle a_buf("A" , {200}, kInt); |
5780 | VarHandle j("j" , kInt); |
5781 | VarHandle k("k" , kInt); |
5782 | auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
5783 | auto forK = |
5784 | For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(30, k))); |
5785 | auto par = Block::make({forJ, forK}); |
5786 | |
5787 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5788 | ForPtr fused_loop; |
5789 | ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
5790 | |
5791 | std::ostringstream oss; |
5792 | oss << *par; |
5793 | const std::string& verification_pattern = |
5794 | R"IR( |
5795 | # CHECK: for (int j |
5796 | # CHECK-NEXT: A[j] = |
5797 | # CHECK-NEXT: A[j + 100] = |
5798 | # CHECK-NOT: for ( |
5799 | )IR" ; |
5800 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5801 | |
5802 | // The fused loop must be the same as the first loop. |
5803 | ASSERT_EQ(fused_loop, forJ); |
5804 | } |
5805 | |
5806 | TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) { |
5807 | // Input IR: |
5808 | // for (int i = 0; i < 20; i++) { |
5809 | // for (int j = 0; j < 100; j++) { |
5810 | // A[i,j] = i * j * 500; |
5811 | // } |
5812 | // } |
5813 | // for (int m = 0; m < 20; m++) { |
5814 | // for (int n = 0; n < 50; n++) { |
5815 | // A[m+20,n+100] = m + n * 100; |
5816 | // } |
5817 | // } |
5818 | BufHandle a_buf("A" , {20, 100}, kInt); |
5819 | BufHandle b_buf("B" , {20, 50}, kInt); |
5820 | VarHandle i("i" , kInt); |
5821 | VarHandle j("j" , kInt); |
5822 | VarHandle m("m" , kInt); |
5823 | VarHandle n("n" , kInt); |
5824 | auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); |
5825 | auto forJ = For::make(j, 0, 100, storeA1); |
5826 | auto forI = For::make(i, 0, 20, forJ); |
5827 | auto storeA2 = |
5828 | Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); |
5829 | auto forN = For::make(n, 0, 50, storeA2); |
5830 | auto forM = For::make(m, 0, 20, forN); |
5831 | auto par = Block::make({forI, forM}); |
5832 | |
5833 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5834 | ForPtr fused_loop; |
5835 | ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
5836 | |
5837 | std::ostringstream oss; |
5838 | oss << *par; |
5839 | const std::string& verification_pattern = |
5840 | R"IR( |
5841 | # CHECK: for (int i |
5842 | # CHECK-NEXT: for (int j |
5843 | # CHECK-NEXT: A[i, j] = |
5844 | # CHECK: for (int n |
5845 | # CHECK-NEXT: A[i + 20, n + 100] = |
5846 | # CHECK-NOT: for ( |
5847 | )IR" ; |
5848 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5849 | |
5850 | // The fused loop must be the same as the first loop. |
5851 | ASSERT_EQ(fused_loop, forI); |
5852 | } |
5853 | |
5854 | TEST(LoopNest, fuseLoopsWithReductions) { |
5855 | // Input IR: |
5856 | // for (int i = 0; i < 20; i++) { |
5857 | // A[i] = 0 |
5858 | // for (int j = 0; j < 100; j++) { |
5859 | // A[i] = A[i] + B[i,j]; |
5860 | // } |
5861 | // } |
5862 | // for (int m = 0; m < 20; m++) { |
5863 | // C[m] = A[m]; |
5864 | // } |
5865 | BufHandle a_buf("A" , {20}, kInt); |
5866 | BufHandle b_buf("B" , {20, 100}, kInt); |
5867 | BufHandle c_buf("C" , {20}, kInt); |
5868 | VarHandle i("i" , kInt); |
5869 | VarHandle j("j" , kInt); |
5870 | VarHandle m("m" , kInt); |
5871 | auto initA = Store::make(a_buf, {i}, 0); |
5872 | auto sumA = Store::make( |
5873 | a_buf, {i}, Add::make(Load::make(a_buf, {i}), Load::make(b_buf, {i, j}))); |
5874 | auto forJ = For::make(j, 0, 100, sumA); |
5875 | auto forI = For::make(i, 0, 20, Block::make({initA, forJ})); |
5876 | auto forM = |
5877 | For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m}))); |
5878 | auto par = Block::make({forI, forM}); |
5879 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5880 | ForPtr fused_loop; |
5881 | ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
5882 | |
5883 | std::ostringstream oss; |
5884 | oss << *par; |
5885 | const std::string& verification_pattern = |
5886 | R"IR( |
5887 | # CHECK: for (int i |
5888 | # CHECK-NEXT: A[i] = |
5889 | # CHECK-NEXT: for (int j |
5890 | # CHECK-NEXT: A[i] = (A[i]) + |
5891 | # CHECK-NOT: for ( |
5892 | # CHECK: C[i] = A[i] |
5893 | )IR" ; |
5894 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5895 | |
5896 | // The fused loop must be the same as the first loop. |
5897 | ASSERT_EQ(fused_loop, forI); |
5898 | } |
5899 | |
5900 | TEST(LoopNest, fuseLoopsWith2DReductions) { |
5901 | // Input IR: |
5902 | // for (int i = 0; i < 20; i++) { |
5903 | // for (int j = 0; j < 50; j++) { |
5904 | // A[i,j] = 0 |
5905 | // for (int k = 0; k < 100; k++) { |
5906 | // A[i,j] = A[i,j] + B[i,j,k]; |
5907 | // } |
5908 | // } |
5909 | // } |
5910 | // for (int m = 0; m < 20; m++) { |
5911 | // for (int n = 0; n < 40; n++) { |
5912 | // C[m,n] = A[m,n]; |
5913 | // } |
5914 | // } |
5915 | BufHandle a_buf("A" , {20, 50}, kInt); |
5916 | BufHandle b_buf("B" , {20, 50, 100}, kInt); |
5917 | BufHandle c_buf("C" , {20, 40}, kInt); |
5918 | VarHandle i("i" , kInt); |
5919 | VarHandle j("j" , kInt); |
5920 | VarHandle k("k" , kInt); |
5921 | VarHandle m("m" , kInt); |
5922 | VarHandle n("n" , kInt); |
5923 | auto initA = Store::make(a_buf, {i, j}, 0); |
5924 | auto sumA = Store::make( |
5925 | a_buf, |
5926 | {i, j}, |
5927 | Add::make(Load::make(a_buf, {i, j}), Load::make(b_buf, {i, j, k}))); |
5928 | auto forK = For::make(k, 0, 100, sumA); |
5929 | auto forJ = For::make(j, 0, 50, Block::make({initA, forK})); |
5930 | auto forI = For::make(i, 0, 20, forJ); |
5931 | auto storeC = Store::make(c_buf, {m, n}, Load::make(a_buf, {m, n})); |
5932 | auto forM = For::make(m, 0, 20, For::make(n, 0, 40, storeC)); |
5933 | auto par = Block::make({forI, forM}); |
5934 | |
5935 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5936 | ForPtr fused_loop; |
5937 | ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
5938 | |
5939 | std::ostringstream oss; |
5940 | oss << *par; |
5941 | const std::string& verification_pattern = |
5942 | R"IR( |
5943 | # CHECK: for (int i |
5944 | # CHECK-NEXT: for (int j |
5945 | # CHECK-NEXT: A[i, j] = |
5946 | # CHECK-NEXT: for (int k |
5947 | # CHECK-NEXT: A[i, j] = (A[i, j]) + |
5948 | # CHECK: for (int n |
5949 | # CHECK-NEXT: C[i, n] = A[i, n] |
5950 | # CHECK-NOT: for ( |
5951 | )IR" ; |
5952 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5953 | |
5954 | // The fused loop must be the same as the first loop. |
5955 | ASSERT_EQ(fused_loop, forI); |
5956 | } |
5957 | |
5958 | TEST(LoopNest, fuseLoopsWithComplexIndices) { |
5959 | // Input IR: |
5960 | // for (int i = 0; i < 20; i++) { |
5961 | // for (int j = 0; j < 20; j++) { |
5962 | // A[i,j*20+j+2] = i + j; |
5963 | // } |
5964 | // } |
5965 | // for (int m = 0; m < 20; m++) { |
5966 | // for (int n = 0; n < 20; n++) { |
5967 | // B[m,n] = A[m,n*20+n+2]; |
5968 | // } |
5969 | // } |
5970 | BufHandle a_buf("A" , {20, 400}, kInt); |
5971 | BufHandle b_buf("B" , {20, 400}, kInt); |
5972 | VarHandle i("i" , kInt); |
5973 | VarHandle j("j" , kInt); |
5974 | VarHandle m("m" , kInt); |
5975 | VarHandle n("n" , kInt); |
5976 | auto writeA = Store::make(a_buf, {i, j * 20 + j + 2}, i + j); |
5977 | auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); |
5978 | auto storeB = |
5979 | Store::make(b_buf, {m, n}, Load::make(a_buf, {m, n * 20 + n + 2})); |
5980 | auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); |
5981 | auto par = Block::make({forI, forM}); |
5982 | |
5983 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
5984 | ForPtr fused_loop; |
5985 | ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
5986 | |
5987 | std::ostringstream oss; |
5988 | oss << *par; |
5989 | const std::string& verification_pattern = |
5990 | R"IR( |
5991 | # CHECK: for (int i |
5992 | # CHECK-NEXT: for (int j |
5993 | # CHECK-NEXT: A[i, (j * 20 + j) + 2] = i + j |
5994 | # CHECK: for (int n |
5995 | # CHECK-NEXT: B[i, n] = A[i, (n * 20 + n) + 2] |
5996 | # CHECK-NOT: for ( |
5997 | )IR" ; |
5998 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
5999 | |
6000 | // The fused loop must be the same as the first loop. |
6001 | ASSERT_EQ(fused_loop, forI); |
6002 | } |
6003 | |
6004 | TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) { |
6005 | // Input IR: |
6006 | // for (int i = 0; i < 20; i++) { |
6007 | // for (int j = 0; j < 20; j++) { |
6008 | // A[i,i*20+j] = i + j; |
6009 | // } |
6010 | // } |
6011 | // for (int m = 0; m < 20; m++) { |
6012 | // for (int n = 0; n < 20; n++) { |
6013 | // B[m,n] = A[m,m*20+n]; // Both indices of A use m |
6014 | // } |
6015 | // } |
6016 | BufHandle a_buf("A" , {20, 500}, kInt); |
6017 | BufHandle b_buf("B" , {20, 500}, kInt); |
6018 | VarHandle i("i" , kInt); |
6019 | VarHandle j("j" , kInt); |
6020 | VarHandle m("m" , kInt); |
6021 | VarHandle n("n" , kInt); |
6022 | auto writeA = Store::make(a_buf, {i, i * 20 + j}, i + j); |
6023 | auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); |
6024 | auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {m, m * 20 + n})); |
6025 | auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); |
6026 | auto par = Block::make({forI, forM}); |
6027 | |
6028 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
6029 | ForPtr fused_loop; |
6030 | ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
6031 | } |
6032 | |
6033 | TEST(LoopNest, fuseLoopsWithTranspose) { |
6034 | // Input IR: |
6035 | // for (int i = 0; i < 20; i++) { |
6036 | // for (int j = 0; j < 20; j++) { |
6037 | // A[i,j] = i + j; |
6038 | // } |
6039 | // } |
6040 | // for (int m = 0; m < 20; m++) { |
6041 | // for (int n = 0; n < 20; n++) { |
6042 | // B[m,n] = A[n,m]; // Transpose |
6043 | // } |
6044 | // } |
6045 | BufHandle a_buf("A" , {20, 20}, kInt); |
6046 | BufHandle b_buf("B" , {20, 20}, kInt); |
6047 | VarHandle i("i" , kInt); |
6048 | VarHandle j("j" , kInt); |
6049 | VarHandle m("m" , kInt); |
6050 | VarHandle n("n" , kInt); |
6051 | auto writeA = Store::make(a_buf, {i, j}, i + j); |
6052 | auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); |
6053 | auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {n, m})); |
6054 | auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); |
6055 | auto par = Block::make({forI, forM}); |
6056 | |
6057 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
6058 | ForPtr fused_loop; |
6059 | ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
6060 | } |
6061 | |
6062 | TEST(LoopNest, fuseLoopsThatViolateDependencies1) { |
6063 | // Input IR: |
6064 | // for (int j = 10; j < 100; j++) { |
6065 | // A[j] = 10 * j; |
6066 | // } |
6067 | // for (int k = 10; k < 100; k++) { |
6068 | // A[k-1] = 20 * k; |
6069 | // } |
6070 | BufHandle a_buf("A" , {100}, kInt); |
6071 | VarHandle j("j" , kInt); |
6072 | VarHandle k("k" , kInt); |
6073 | auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
6074 | auto forK = |
6075 | For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); |
6076 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
6077 | auto par = Block::make({forJ, forK}); |
6078 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
6079 | ForPtr fused_loop; |
6080 | ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
6081 | } |
6082 | |
6083 | TEST(LoopNest, fuseLoopsThatViolateDependencies2) { |
6084 | // Input IR: |
6085 | // for (int j = 10; j < 100; j++) { |
6086 | // A[j] = 10 * j; |
6087 | // } |
6088 | // for (int k = 10; k < 100; k++) { |
6089 | // A[k+50] = 20 * k; |
6090 | // } |
6091 | BufHandle a_buf("A" , {150}, kInt); |
6092 | VarHandle j("j" , kInt); |
6093 | VarHandle k("k" , kInt); |
6094 | auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
6095 | auto forK = |
6096 | For::make(k, 10, 100, Store::make(a_buf, {k + 50}, Mul::make(20, k))); |
6097 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
6098 | auto par = Block::make({forJ, forK}); |
6099 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
6100 | ForPtr fused_loop; |
6101 | ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
6102 | } |
6103 | |
6104 | TEST(LoopNest, fuseLoopsThatViolateDependencies3) { |
6105 | // Input IR: |
6106 | // for (int m = 0; m < 20; m++) { |
6107 | // A[m] = 0; |
6108 | // for (int j = 0; j < 100; j++) { |
6109 | // A[m] = A[m] + m * j; |
6110 | // } |
6111 | // } |
6112 | // for (int n = 0; n < 20; n++) { |
6113 | // B[n] = A[n+1]; |
6114 | // for (int k = 0; k < 50; k++) { |
6115 | // B[n] = B[n] + n * k; |
6116 | // } |
6117 | // } |
6118 | BufHandle a_buf("A" , {25, 100}, kInt); |
6119 | BufHandle b_buf("B" , {20, 50}, kInt); |
6120 | VarHandle m("m" , kInt); |
6121 | VarHandle n("n" , kInt); |
6122 | VarHandle j("j" , kInt); |
6123 | VarHandle k("k" , kInt); |
6124 | auto initA = Store::make(a_buf, {m}, 0); |
6125 | auto forJ = For::make( |
6126 | j, |
6127 | 0, |
6128 | 100, |
6129 | Store::make( |
6130 | a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); |
6131 | auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n + 1})); |
6132 | auto forK = For::make( |
6133 | k, |
6134 | 0, |
6135 | 50, |
6136 | Store::make( |
6137 | b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); |
6138 | auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); |
6139 | auto forN = For::make(n, 0, 20, Block::make({initB, forK})); |
6140 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
6141 | auto par = Block::make({forM, forN}); |
6142 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
6143 | ForPtr fused_loop; |
6144 | ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); |
6145 | } |
6146 | |
6147 | TEST(LoopNest, fuseLoopsThatViolateDependencies4) { |
6148 | // Input IR: |
6149 | // for (int i = 0; i < 20; i++) { |
6150 | // for (int j = 0; j < 100; j++) { |
6151 | // A[i,j] = i * j * 500; |
6152 | // } |
6153 | // } |
6154 | // for (int m = 0; m < 20; m++) { |
6155 | // for (int n = 0; n < 50; n++) { |
6156 | // A[m+1,n] = m + n * 100; |
6157 | // } |
6158 | // } |
6159 | BufHandle a_buf("A" , {30, 100}, kInt); |
6160 | VarHandle i("i" , kInt); |
6161 | VarHandle j("j" , kInt); |
6162 | VarHandle m("m" , kInt); |
6163 | VarHandle n("n" , kInt); |
6164 | auto forI = For::make( |
6165 | i, |
6166 | 0, |
6167 | 20, |
6168 | For::make( |
6169 | j, |
6170 | 0, |
6171 | 100, |
6172 | Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); |
6173 | auto forM = For::make( |
6174 | m, |
6175 | 0, |
6176 | 20, |
6177 | For::make( |
6178 | n, |
6179 | 0, |
6180 | 50, |
6181 | Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))))); |
6182 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
6183 | auto par = Block::make({forI, forM}); |
6184 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
6185 | ForPtr fused_loop; |
6186 | ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
6187 | } |
6188 | |
6189 | TEST(LoopNest, fuseLoopsThatViolateDependencies5) { |
6190 | // Input IR: |
6191 | // for (int i = 0; i < 20; i++) { |
6192 | // for (int j = 0; j < 100; j++) { |
6193 | // A[i,j] = i * j * 500; |
6194 | // } |
6195 | // for (int n = 0; n < 100; n++) { |
6196 | // A[i,n+1] = m + n * 100; |
6197 | // } |
6198 | // } |
6199 | BufHandle a_buf("A" , {20, 200}, kInt); |
6200 | VarHandle i("i" , kInt); |
6201 | VarHandle j("j" , kInt); |
6202 | VarHandle n("n" , kInt); |
6203 | auto forJ = For::make( |
6204 | j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); |
6205 | auto forN = For::make( |
6206 | n, |
6207 | 0, |
6208 | 100, |
6209 | Store::make(a_buf, {i, n + 1}, Add::make(i, Mul::make(n, 100)))); |
6210 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers) |
6211 | auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); |
6212 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
6213 | ForPtr fused_loop; |
6214 | ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); |
6215 | } |
6216 | |
6217 | TEST(LoopNest, fuseLoopsThatViolateDependencies6) { |
6218 | // Input IR: |
6219 | // for (int j = 0; j < 100; j++) { |
6220 | // A[j] = 10 * j; |
6221 | // } |
6222 | // for (int k = 0; k < 100; k++) { |
6223 | // B[k] = 20 * A[99-k]; |
6224 | // } |
6225 | BufHandle a_buf("A" , {100}, kInt); |
6226 | BufHandle b_buf("B" , {100}, kInt); |
6227 | VarHandle j("j" , kInt); |
6228 | VarHandle k("k" , kInt); |
6229 | auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
6230 | auto forK = For::make( |
6231 | k, |
6232 | 0, |
6233 | 100, |
6234 | Store::make( |
6235 | b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); |
6236 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
6237 | auto par = Block::make({forJ, forK}); |
6238 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
6239 | ForPtr fused_loop; |
6240 | ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
6241 | } |
6242 | |
6243 | TEST(LoopNest, fuseLoopsThatViolateDependencies7) { |
6244 | // Input IR: |
6245 | // for (int k = 0; k < 100; k++) { |
6246 | // B[k] = 20 * A[99-k]; |
6247 | // } |
6248 | // for (int j = 0; j < 100; j++) { |
6249 | // A[j] = 10 * j; |
6250 | // } |
6251 | BufHandle a_buf("A" , {100}, kInt); |
6252 | BufHandle b_buf("B" , {100}, kInt); |
6253 | VarHandle j("j" , kInt); |
6254 | VarHandle k("k" , kInt); |
6255 | auto forK = For::make( |
6256 | k, |
6257 | 0, |
6258 | 100, |
6259 | Store::make( |
6260 | b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); |
6261 | auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
6262 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
6263 | auto par = Block::make({forK, forJ}); |
6264 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
6265 | ForPtr fused_loop; |
6266 | ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop)); |
6267 | } |
6268 | |
6269 | TEST(LoopNest, areLoopsPerfectlyNested) { |
6270 | // Input IR: |
6271 | // for (int i = 0; i < 20; i++) { |
6272 | // for (int j = 0; j < 30; j++) { |
6273 | // for (int k = 0; k < 40; k++) { |
6274 | // A[i,j,k] = i * j * k; |
6275 | // } |
6276 | // } |
6277 | // } |
6278 | BufHandle a_buf("A" , {20, 30, 40}, kInt); |
6279 | VarHandle i("i" , kInt); |
6280 | VarHandle j("j" , kInt); |
6281 | VarHandle k("k" , kInt); |
6282 | auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); |
6283 | auto forK = For::make(k, 0, 40, store); |
6284 | auto forJ = For::make(j, 0, 30, forK); |
6285 | auto forI = For::make(i, 0, 20, forJ); |
6286 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
6287 | auto par = Block::make({forI}); |
6288 | ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); |
6289 | |
6290 | // Specifying the loops in any other order fails. |
6291 | ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forJ, forI, forK})); |
6292 | ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forK, forJ})); |
6293 | ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forK, forJ, forI})); |
6294 | |
6295 | // Adding a statment to forK body should be OK. |
6296 | auto init = Store::make(a_buf, {i, j}, 0); |
6297 | forK->body()->insert_stmt_before(init, store); |
6298 | ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); |
6299 | |
6300 | // Adding a statement in forJ body should fail this test. |
6301 | forK->body()->remove_stmt(init); |
6302 | forJ->body()->insert_stmt_before(init, forK); |
6303 | ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); |
6304 | |
6305 | // Similarly, adding a statement in forI body should fail this test. |
6306 | forJ->body()->remove_stmt(init); |
6307 | forI->body()->insert_stmt_before(init, forJ); |
6308 | ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); |
6309 | } |
6310 | |
6311 | TEST(LoopNest, reorderNestedLoops2D) { |
6312 | // Input IR: |
6313 | // for (int i = 0; i < 20; i++) { |
6314 | // for (int j = 0; j < 30; j++) { |
6315 | // A[i,j] = i * j; |
6316 | // } |
6317 | // } |
6318 | BufHandle a_buf("A" , {20, 30, 40}, kInt); |
6319 | VarHandle i("i" , kInt); |
6320 | VarHandle j("j" , kInt); |
6321 | auto store = Store::make(a_buf, {i, j}, Mul::make(i, j)); |
6322 | auto forJ = For::make(j, 0, 30, store); |
6323 | auto forI = For::make(i, 0, 20, forJ); |
6324 | auto par = Block::make({forI}); |
6325 | |
6326 | auto reordered = LoopNest::reorder({forI, forJ}, {1, 0}); |
6327 | |
6328 | ASSERT_EQ(reordered[0], forJ); |
6329 | ASSERT_EQ(reordered[1], forI); |
6330 | ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forJ, forI})); |
6331 | ASSERT_EQ(forJ->get_parent(), par); |
6332 | ASSERT_EQ(store->get_parent(), forI->body()); |
6333 | } |
6334 | |
6335 | TEST(LoopNest, reorderNestedLoops3D) { |
6336 | // Input IR: |
6337 | // for (int i = 0; i < 20; i++) { |
6338 | // for (int j = 0; j < 30; j++) { |
6339 | // for (int k = 0; k < 40; k++) { |
6340 | // A[i,j,k] = i * j * k; |
6341 | // } |
6342 | // } |
6343 | // } |
6344 | BufHandle a_buf("A" , {20, 30, 40}, kInt); |
6345 | VarHandle i("i" , kInt); |
6346 | VarHandle j("j" , kInt); |
6347 | VarHandle k("k" , kInt); |
6348 | auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); |
6349 | auto forK = For::make(k, 0, 40, store); |
6350 | auto forJ = For::make(j, 0, 30, forK); |
6351 | auto forI = For::make(i, 0, 20, forJ); |
6352 | auto par = Block::make({forI}); |
6353 | |
6354 | auto reordered = LoopNest::reorder({forI, forJ, forK}, {2, 0, 1}); |
6355 | |
6356 | ASSERT_EQ(reordered[0], forK); |
6357 | ASSERT_EQ(reordered[1], forI); |
6358 | ASSERT_EQ(reordered[2], forJ); |
6359 | ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forJ})); |
6360 | ASSERT_EQ(forK->get_parent(), par); |
6361 | ASSERT_EQ(store->get_parent(), forJ->body()); |
6362 | } |
6363 | |
6364 | TEST(LoopNest, reorderNestedLoops4D) { |
6365 | // Input IR: |
6366 | // for (int i = 0; i < 20; i++) { |
6367 | // for (int j = 0; j < 30; j++) { |
6368 | // for (int k = 0; k < 40; k++) { |
6369 | // for (int l = 0; l < 50; l++) { |
6370 | // A[i,j,k,l] = i * j * k * l * 500; |
6371 | // } |
6372 | // } |
6373 | // } |
6374 | // } |
6375 | BufHandle a_buf("A" , {20, 30, 40, 50}, kInt); |
6376 | VarHandle i("i" , kInt); |
6377 | VarHandle j("j" , kInt); |
6378 | VarHandle k("k" , kInt); |
6379 | VarHandle l("l" , kInt); |
6380 | auto store = Store::make( |
6381 | a_buf, |
6382 | {i, j, k, l}, |
6383 | Mul::make(Mul::make(Mul::make(Mul::make(i, j), k), l), 500)); |
6384 | auto forL = For::make(l, 0, 50, store); |
6385 | auto forK = For::make(k, 0, 40, forL); |
6386 | auto forJ = For::make(j, 0, 30, forK); |
6387 | auto forI = For::make(i, 0, 20, forJ); |
6388 | auto par = Block::make({forI}); |
6389 | |
6390 | auto reordered = LoopNest::reorder({forI, forJ, forK, forL}, {2, 0, 3, 1}); |
6391 | |
6392 | ASSERT_EQ(reordered[0], forK); |
6393 | ASSERT_EQ(reordered[1], forI); |
6394 | ASSERT_EQ(reordered[2], forL); |
6395 | ASSERT_EQ(reordered[3], forJ); |
6396 | ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forL, forJ})); |
6397 | ASSERT_EQ(forK->get_parent(), par); |
6398 | ASSERT_EQ(store->get_parent(), forJ->body()); |
6399 | } |
6400 | |
6401 | TEST(LoopNest, reorderTrivialPermutation) { |
6402 | // Input IR: |
6403 | // for (int i = 0; i < 20; i++) { |
6404 | // for (int j = 0; j < 30; j++) { |
6405 | // for (int k = 0; k < 40; k++) { |
6406 | // A[i,j,k] = i * j * k; |
6407 | // } |
6408 | // } |
6409 | // } |
6410 | BufHandle a_buf("A" , {20, 30, 40}, kInt); |
6411 | VarHandle i("i" , kInt); |
6412 | VarHandle j("j" , kInt); |
6413 | VarHandle k("k" , kInt); |
6414 | auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); |
6415 | auto forK = For::make(k, 0, 40, store); |
6416 | auto forJ = For::make(j, 0, 30, forK); |
6417 | auto forI = For::make(i, 0, 20, forJ); |
6418 | auto par = Block::make({forI}); |
6419 | |
6420 | auto reordered = LoopNest::reorder({forI, forJ, forK}, {0, 1, 2}); |
6421 | |
6422 | ASSERT_EQ(reordered[0], forI); |
6423 | ASSERT_EQ(reordered[1], forJ); |
6424 | ASSERT_EQ(reordered[2], forK); |
6425 | ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); |
6426 | ASSERT_EQ(forI->get_parent(), par); |
6427 | ASSERT_EQ(store->get_parent(), forK->body()); |
6428 | } |
6429 | |
6430 | TEST(LoopNest, reorderInvalidPermutations) { |
6431 | // Input IR: |
6432 | // for (int i = 0; i < 20; i++) { |
6433 | // for (int j = 0; j < 30; j++) { |
6434 | // for (int k = 0; k < 40; k++) { |
6435 | // A[i,j,k] = i * j * k; |
6436 | // } |
6437 | // } |
6438 | // } |
6439 | BufHandle a_buf("A" , {20, 30, 40}, kInt); |
6440 | VarHandle i("i" , kInt); |
6441 | VarHandle j("j" , kInt); |
6442 | VarHandle k("k" , kInt); |
6443 | auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); |
6444 | auto forK = For::make(k, 0, 40, store); |
6445 | auto forJ = For::make(j, 0, 30, forK); |
6446 | auto forI = For::make(i, 0, 20, forJ); |
6447 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
6448 | auto par = Block::make({forI}); |
6449 | |
6450 | ASSERT_THROWS_WITH( |
6451 | LoopNest::reorder({forI, forJ, forK}, {0, 1, 2, 3}), |
6452 | "invalid permutation size" ); |
6453 | ASSERT_THROWS_WITH( |
6454 | LoopNest::reorder({forI, forJ, forK}, {1, 2}), |
6455 | "invalid permutation size" ); |
6456 | ASSERT_THROWS_WITH( |
6457 | LoopNest::reorder({forI, forJ, forK}, {2, 1, 3}), |
6458 | "invalid permutation for reorder" ); |
6459 | ASSERT_THROWS_WITH( |
6460 | LoopNest::reorder({forI, forJ, forK}, {1, 1, 0}), |
6461 | "invalid permutation for reorder" ); |
6462 | ASSERT_THROWS_WITH( |
6463 | LoopNest::reorder({forI, forJ, forK}, {0, 0, 0}), |
6464 | "invalid permutation for reorder" ); |
6465 | } |
6466 | |
6467 | TEST(LoopNest, reorderInvalidLoopNest) { |
6468 | // Input IR: |
6469 | // for (int i = 0; i < 20; i++) { |
6470 | // for (int j = 0; j < 30; j++) { |
6471 | // A[i,j] = 0 |
6472 | // for (int k = 0; k < 40; k++) { |
6473 | // A[i,j,k] = i * j * k; |
6474 | // } |
6475 | // } |
6476 | // } |
6477 | BufHandle a_buf("A" , {20, 30, 40}, kInt); |
6478 | VarHandle i("i" , kInt); |
6479 | VarHandle j("j" , kInt); |
6480 | VarHandle k("k" , kInt); |
6481 | auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); |
6482 | auto forK = For::make(k, 0, 40, store); |
6483 | auto forJ = For::make(j, 0, 30, forK); |
6484 | auto forI = For::make(i, 0, 20, forJ); |
6485 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
6486 | auto par = Block::make({forI}); |
6487 | |
6488 | // Specifying the loops in incorrect order fails. |
6489 | ASSERT_THROWS_WITH( |
6490 | LoopNest::reorder({forK, forI, forJ}, {1, 0, 2}), |
6491 | "reorder is only allowed on perfectly nested loops" ); |
6492 | |
6493 | // Adding a statement to forJ loop fails. |
6494 | auto init = Store::make(a_buf, {i}, 0); |
6495 | forJ->body()->insert_stmt_before(init, forK); |
6496 | ASSERT_THROWS_WITH( |
6497 | LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), |
6498 | "reorder is only allowed on perfectly nested loops" ); |
6499 | |
6500 | // Moving that statement to forI loop also fails. |
6501 | forJ->body()->remove_stmt(init); |
6502 | forI->body()->insert_stmt_before(init, forJ); |
6503 | ASSERT_THROWS_WITH( |
6504 | LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), |
6505 | "reorder is only allowed on perfectly nested loops" ); |
6506 | } |
6507 | |
6508 | TEST(LoopNest, compressBufferSimple) { |
6509 | // Input IR: |
6510 | // for (int i = 0; i < 100; ++i) { |
6511 | // for (int j = 0; j < 200; ++j) { |
6512 | // A[i,j] = sin(i*j) |
6513 | // } |
6514 | // for (int j = 0; j < 199; ++j) { |
6515 | // B[i,j] = A[i,j] + A[i, j+1] |
6516 | // } |
6517 | // } |
6518 | BufHandle aBuf("A" , {100, 200}, kInt); |
6519 | BufHandle bBuf("B" , {100, 200}, kInt); |
6520 | VarHandle i("i" , kInt); |
6521 | VarHandle j("j" , kInt); |
6522 | auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); |
6523 | auto forJ2 = For::make( |
6524 | j, |
6525 | 0, |
6526 | 199, |
6527 | Store::make( |
6528 | bBuf, |
6529 | {i, j}, |
6530 | Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); |
6531 | auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); |
6532 | auto par = Block::make({forI}); |
6533 | LoopNest::compressBuffer(aBuf.node(), par); |
6534 | |
6535 | std::ostringstream oss; |
6536 | oss << *par; |
6537 | const std::string& verification_pattern = |
6538 | R"IR( |
6539 | # CHECK: for (int i |
6540 | # CHECK-NEXT: for (int j |
6541 | # CHECK-NEXT: A[0, j] = |
6542 | # CHECK: for (int j |
6543 | # CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) |
6544 | )IR" ; |
6545 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
6546 | |
6547 | ASSERT_EQ(aBuf.node()->ndim(), 2); |
6548 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); |
6549 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); |
6550 | } |
6551 | |
6552 | TEST(LoopNest, compressBufferMultipleDims) { |
6553 | // Input IR: |
6554 | // for (int i = 0; i < 100; ++i) { |
6555 | // for (int j = 0; j < 200; ++j) { |
6556 | // A[i,j] = sin(i*j) |
6557 | // B[i,j] = A[i,j] + A[i,j] |
6558 | // } |
6559 | // } |
6560 | BufHandle aBuf("A" , {100, 200}, kInt); |
6561 | BufHandle bBuf("B" , {100, 200}, kInt); |
6562 | VarHandle i("i" , kInt); |
6563 | VarHandle j("j" , kInt); |
6564 | auto store1 = Store::make(aBuf, {i, j}, sin(i * j)); |
6565 | auto store2 = Store::make( |
6566 | bBuf, |
6567 | {i, j}, |
6568 | Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j}))); |
6569 | auto forJ = For::make(j, 0, 200, Block::make({store1, store2})); |
6570 | auto forI = For::make(i, 0, 100, forJ); |
6571 | auto par = Block::make({forI}); |
6572 | LoopNest::compressBuffer(aBuf.node(), par); |
6573 | |
6574 | std::ostringstream oss; |
6575 | oss << *par; |
6576 | const std::string& verification_pattern = |
6577 | R"IR( |
6578 | # CHECK: for (int i |
6579 | # CHECK-NEXT: for (int j |
6580 | # CHECK-NEXT: A[0, 0] = |
6581 | # CHECK-NEXT: B[i, j] = (A[0, 0]) + (A[0, 0]) |
6582 | )IR" ; |
6583 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
6584 | |
6585 | ASSERT_EQ(aBuf.node()->ndim(), 2); |
6586 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); |
6587 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); |
6588 | } |
6589 | |
6590 | TEST(LoopNest, compressBufferMultipleDims2) { |
6591 | // Input IR: |
6592 | // for (int i = 0; i < 100; ++i) { |
6593 | // for (int j = 0; j < 200; ++j) { |
6594 | // for (int k = 0; k < 300; ++k) { |
6595 | // A[i,j,k] = sin(i*j*k) |
6596 | // } |
6597 | // for (int k = 0; k < 299; ++j) { |
6598 | // B[i,j,k] = A[i,j,k] + A[i,j,k+1] |
6599 | // } |
6600 | // } |
6601 | // } |
6602 | BufHandle aBuf("A" , {100, 200, 300}, kInt); |
6603 | BufHandle bBuf("B" , {100, 200, 300}, kInt); |
6604 | VarHandle i("i" , kInt); |
6605 | VarHandle j("j" , kInt); |
6606 | VarHandle k("k" , kInt); |
6607 | auto store1 = Store::make(aBuf, {i, j, k}, sin(i * j * k)); |
6608 | auto forK1 = For::make(k, 0, 300, store1); |
6609 | auto store2 = Store::make( |
6610 | bBuf, |
6611 | {i, j, k}, |
6612 | Add::make(Load::make(aBuf, {i, j, k}), Load::make(aBuf, {i, j, k + 1}))); |
6613 | auto forK2 = For::make(k, 0, 299, store2); |
6614 | auto forJ = For::make(j, 0, 200, Block::make({forK1, forK2})); |
6615 | auto forI = For::make(i, 0, 100, forJ); |
6616 | auto par = Block::make({forI}); |
6617 | LoopNest::compressBuffer(aBuf.node(), par); |
6618 | |
6619 | std::ostringstream oss; |
6620 | oss << *par; |
6621 | const std::string& verification_pattern = |
6622 | R"IR( |
6623 | # CHECK: for (int i |
6624 | # CHECK-NEXT: for (int j |
6625 | # CHECK-NEXT: for (int k |
6626 | # CHECK-NEXT: A[0, 0, k] = |
6627 | # CHECK: for (int k |
6628 | # CHECK-NEXT: B[i, j, k] = (A[0, 0, k]) + (A[0, 0, k + 1]) |
6629 | )IR" ; |
6630 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
6631 | |
6632 | ASSERT_EQ(aBuf.node()->ndim(), 3); |
6633 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); |
6634 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); |
6635 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(2), 300); |
6636 | } |
6637 | |
6638 | TEST(LoopNest, compressBufferDifferentOrderIndices) { |
6639 | // Input IR: |
6640 | // for (int i = 0; i < 100; ++i) { |
6641 | // for (int j = 0; j < 200; ++j) { |
6642 | // A[j, i] = sin(i*j) |
6643 | // } |
6644 | // for (int j = 0; j < 99; ++j) { |
6645 | // B[i, j] = A[j, i] + A[j+1, 0] |
6646 | // } |
6647 | // } |
6648 | BufHandle aBuf("A" , {100, 200}, kInt); |
6649 | BufHandle bBuf("B" , {100, 200}, kInt); |
6650 | VarHandle i("i" , kInt); |
6651 | VarHandle j("j" , kInt); |
6652 | auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {j, i}, sin(i * j))); |
6653 | auto forJ2 = For::make( |
6654 | j, |
6655 | 0, |
6656 | 99, |
6657 | Store::make( |
6658 | bBuf, |
6659 | {i, j}, |
6660 | Add::make(Load::make(aBuf, {j, i}), Load::make(aBuf, {j + 1, i})))); |
6661 | auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); |
6662 | auto par = Block::make({forI}); |
6663 | LoopNest::compressBuffer(aBuf.node(), par); |
6664 | |
6665 | std::ostringstream oss; |
6666 | oss << *par; |
6667 | const std::string& verification_pattern = |
6668 | R"IR( |
6669 | # CHECK: for (int i |
6670 | # CHECK-NEXT: for (int j |
6671 | # CHECK-NEXT: A[j, 0] = |
6672 | # CHECK: for (int j |
6673 | # CHECK-NEXT: B[i, j] = (A[j, 0]) + (A[j + 1, 0]) |
6674 | )IR" ; |
6675 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
6676 | |
6677 | ASSERT_EQ(aBuf.node()->ndim(), 2); |
6678 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); |
6679 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); |
6680 | } |
6681 | |
6682 | TEST(LoopNest, compressBufferVariableBounds) { |
6683 | // Input IR: |
6684 | // for (int i = 0; i < M; ++i) { |
6685 | // for (int j = 0; j < N; ++j) { |
6686 | // A[i,j] = sin(i*j) |
6687 | // } |
6688 | // for (int j = 0; j < N-1; ++j) { |
6689 | // B[i,j] = A[i,j] + A[i, j+1] |
6690 | // } |
6691 | // } |
6692 | BufHandle aBuf("A" , {100, 200}, kInt); |
6693 | BufHandle bBuf("B" , {100, 200}, kInt); |
6694 | VarHandle i("i" , kInt); |
6695 | VarHandle j("j" , kInt); |
6696 | VarHandle M("M" , kInt); |
6697 | VarHandle N("N" , kInt); |
6698 | auto forJ1 = For::make(j, 0, N, Store::make(aBuf, {i, j}, sin(i * j))); |
6699 | auto forJ2 = For::make( |
6700 | j, |
6701 | 0, |
6702 | N - 1, |
6703 | Store::make( |
6704 | bBuf, |
6705 | {i, j}, |
6706 | Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); |
6707 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
6708 | auto forI = For::make(i, 0, M, Block::make({forJ1, forJ2})); |
6709 | auto par = Block::make({forI}); |
6710 | LoopNest::compressBuffer(aBuf.node(), par); |
6711 | |
6712 | std::ostringstream oss; |
6713 | oss << *par; |
6714 | const std::string& verification_pattern = |
6715 | R"IR( |
6716 | # CHECK: for (int i |
6717 | # CHECK-NEXT: for (int j |
6718 | # CHECK-NEXT: A[0, j] = |
6719 | # CHECK: for (int j |
6720 | # CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) |
6721 | )IR" ; |
6722 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
6723 | |
6724 | ASSERT_EQ(aBuf.node()->ndim(), 2); |
6725 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); |
6726 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); |
6727 | } |
6728 | |
6729 | TEST(LoopNest, compressBufferNoCommonParentLoops) { |
6730 | // Input IR: |
6731 | // for (int i = 0; i < 100; ++i) { |
6732 | // for (int j = 0; j < 200; ++j) { |
6733 | // A[i,j] = sin(i*j) |
6734 | // } |
6735 | // } |
6736 | // for (int i = 0; i < 100; ++i) { |
6737 | // for (int j = 0; j < 199; ++j) { |
6738 | // B[i,j] = A[i,j] + A[i, j+1] |
6739 | // } |
6740 | // } |
6741 | BufHandle aBuf("A" , {100, 200}, kInt); |
6742 | BufHandle bBuf("B" , {100, 200}, kInt); |
6743 | VarHandle i("i" , kInt); |
6744 | VarHandle j("j" , kInt); |
6745 | auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); |
6746 | auto forJ2 = For::make( |
6747 | j, |
6748 | 0, |
6749 | 199, |
6750 | Store::make( |
6751 | bBuf, |
6752 | {i, j}, |
6753 | Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); |
6754 | auto forI1 = For::make(i, 0, 100, forJ1); |
6755 | auto forI2 = For::make(i, 0, 100, forJ2); |
6756 | auto par = Block::make({forI1, forI2}); |
6757 | LoopNest::compressBuffer(aBuf.node(), par); |
6758 | |
6759 | // There should be no change in the buffer or code. |
6760 | std::ostringstream oss; |
6761 | oss << *par; |
6762 | const std::string& verification_pattern = |
6763 | R"IR( |
6764 | # CHECK: for (int i |
6765 | # CHECK-NEXT: for (int j |
6766 | # CHECK-NEXT: A[i, j] = |
6767 | # CHECK: for (int i |
6768 | # CHECK-NEXT: for (int j |
6769 | # CHECK-NEXT: B[i, j] = (A[i, j]) + (A[i, j + 1]) |
6770 | )IR" ; |
6771 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
6772 | |
6773 | ASSERT_EQ(aBuf.node()->ndim(), 2); |
6774 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); |
6775 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); |
6776 | } |
6777 | |
6778 | TEST(LoopNest, compressBufferIndicesMixed) { |
6779 | // Input IR: |
6780 | // for (int i = 0; i < 100; ++i) { |
6781 | // for (int j = 0; j < 200; ++j) { |
6782 | // A[i + j, j] = sin(i*j) |
6783 | // } |
6784 | // for (int j = 0; j < 199; ++j) { |
6785 | // B[i,j] = A[i + j, j] + A[i + j, j+1] |
6786 | // } |
6787 | // } |
6788 | BufHandle aBuf("A" , {300, 200}, kInt); |
6789 | BufHandle bBuf("B" , {100, 200}, kInt); |
6790 | VarHandle i("i" , kInt); |
6791 | VarHandle j("j" , kInt); |
6792 | auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i + j, j}, sin(i * j))); |
6793 | auto forJ2 = For::make( |
6794 | j, |
6795 | 0, |
6796 | 199, |
6797 | Store::make( |
6798 | bBuf, |
6799 | {i, j}, |
6800 | Add::make( |
6801 | Load::make(aBuf, {i + j, j}), Load::make(aBuf, {i + j, j + 1})))); |
6802 | auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); |
6803 | auto par = Block::make({forI}); |
6804 | LoopNest::compressBuffer(aBuf.node(), par); |
6805 | |
6806 | // There should be no change in the buffer or code. |
6807 | std::ostringstream oss; |
6808 | oss << *par; |
6809 | const std::string& verification_pattern = |
6810 | R"IR( |
6811 | # CHECK: for (int i |
6812 | # CHECK-NEXT: for (int j |
6813 | # CHECK-NEXT: A[i + j, j] = |
6814 | # CHECK: for (int j |
6815 | # CHECK-NEXT: B[i, j] = (A[i + j, j]) + (A[i + j, j + 1]) |
6816 | )IR" ; |
6817 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
6818 | |
6819 | ASSERT_EQ(aBuf.node()->ndim(), 2); |
6820 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 300); |
6821 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); |
6822 | } |
6823 | |
6824 | TEST(LoopNest, compressMultipleBuffers) { |
6825 | // Input IR: |
6826 | // for (int i = 0; i < 100; ++i) { |
6827 | // for (int j = 0; j < 200; ++j) { |
6828 | // A[i,j] = sin(i*j) |
6829 | // } |
6830 | // for (int k = 0; k < 199; ++k) { |
6831 | // B[i,k] = A[i,k] + A[i, k+1] |
6832 | // } |
6833 | // for (int m = 0; m < 50; ++m) { |
6834 | // C[i,m] = B[i,m] |
6835 | // } |
6836 | // } |
6837 | BufHandle aBuf("A" , {100, 200}, kInt); |
6838 | BufHandle bBuf("B" , {100, 200}, kInt); |
6839 | BufHandle cBuf("C" , {100, 200}, kInt); |
6840 | VarHandle i("i" , kInt); |
6841 | VarHandle j("j" , kInt); |
6842 | VarHandle k("k" , kInt); |
6843 | VarHandle m("m" , kInt); |
6844 | auto forJ = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); |
6845 | auto forK = For::make( |
6846 | k, |
6847 | 0, |
6848 | 199, |
6849 | Store::make( |
6850 | bBuf, |
6851 | {i, k}, |
6852 | Add::make(Load::make(aBuf, {i, k}), Load::make(aBuf, {i, k + 1})))); |
6853 | auto forM = |
6854 | For::make(m, 0, 50, Store::make(cBuf, {i, m}, Load::make(bBuf, {i, m}))); |
6855 | auto forI = For::make(i, 0, 100, Block::make({forJ, forK, forM})); |
6856 | auto par = Block::make({forI}); |
6857 | |
6858 | // This should compress all buffers A, B, and C as follows: |
6859 | // A[100, 200] -> A[1, 200] |
6860 | // B[100, 200] -> B[1, 200] |
6861 | // C[100, 200] -> C[1, 1] |
6862 | LoopNest::compressAllBuffers(par); |
6863 | |
6864 | std::ostringstream oss; |
6865 | oss << *par; |
6866 | const std::string& verification_pattern = |
6867 | R"IR( |
6868 | # CHECK: for (int i |
6869 | # CHECK-NEXT: for (int j |
6870 | # CHECK-NEXT: A[0, j] = |
6871 | # CHECK: for (int k |
6872 | # CHECK-NEXT: B[0, k] = (A[0, k]) + (A[0, k + 1]) |
6873 | # CHECK: for (int m |
6874 | # CHECK-NEXT: C[0, 0] = B[0, m] |
6875 | )IR" ; |
6876 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
6877 | |
6878 | ASSERT_EQ(aBuf.node()->ndim(), 2); |
6879 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); |
6880 | IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); |
6881 | ASSERT_EQ(bBuf.node()->ndim(), 2); |
6882 | IS_IMM_WITH_VAL(Int, bBuf.node()->dim(0), 1); |
6883 | IS_IMM_WITH_VAL(Int, bBuf.node()->dim(1), 200); |
6884 | ASSERT_EQ(cBuf.node()->ndim(), 2); |
6885 | IS_IMM_WITH_VAL(Int, cBuf.node()->dim(0), 1); |
6886 | IS_IMM_WITH_VAL(Int, cBuf.node()->dim(1), 1); |
6887 | } |
6888 | |
6889 | TEST(LoopNest, sanitizeNames) { |
6890 | std::vector<ExprHandle> dim_args; |
6891 | // Let's pick names that would overlap with default index names if not |
6892 | // sanitized properly: |
6893 | dim_args.emplace_back(ExprHandle(alloc<Var>("i" , kInt))); |
6894 | dim_args.emplace_back(ExprHandle(alloc<Var>("N:2" , kInt))); |
6895 | // Now let's create a many dimensions so that we had to use the same letter |
6896 | // for different loops |
6897 | for (int i = 0; i < 10; i++) { |
6898 | dim_args.emplace_back(ExprHandle(alloc<Var>("N" , kInt))); |
6899 | } |
6900 | |
6901 | // Now create two Computes with conflicting after sanitization names: |
6902 | Tensor X = Compute("$X:!" , dim_args, [&](const std::vector<VarHandle>& v) { |
6903 | return v[0] + v[1] + v[9] + 1; |
6904 | }); |
6905 | Tensor Y = Reduce( |
6906 | "%X\"+" , |
6907 | {}, |
6908 | Sum(), |
6909 | [&](const std::vector<VarHandle>& v) { return X.load(v); }, |
6910 | dim_args); |
6911 | |
6912 | // Finally, let's verify what we got after sanitization: |
6913 | LoopNest l({X, Y}); |
6914 | StmtPtr s = l.root_stmt(); |
6915 | LoopNest::sanitizeNames(s); |
6916 | |
6917 | std::ostringstream oss; |
6918 | oss << *s; |
6919 | const std::string& verification_pattern = |
6920 | R"IR( |
6921 | # CHECK: for (int i = 0; i < i_1; i++) { |
6922 | # CHECK-NEXT: for (int j = 0; j < N_2_1; j++) { |
6923 | # CHECK-NEXT: for (int k = 0; k < N_9; k++) { |
6924 | # CHECK-NEXT: for (int l = 0; l < N_8; l++) { |
6925 | # CHECK-NEXT: for (int m = 0; m < N_7; m++) { |
6926 | # CHECK-NEXT: for (int n = 0; n < N_6; n++) { |
6927 | # CHECK-NEXT: for (int o = 0; o < N_5; o++) { |
6928 | # CHECK-NEXT: for (int p = 0; p < N_4; p++) { |
6929 | # CHECK-NEXT: for (int i1 = 0; i1 < N_3; i1++) { |
6930 | # CHECK-NEXT: for (int j1 = 0; j1 < N_2; j1++) { |
6931 | # CHECK-NEXT: for (int k1 = 0; k1 < N_1; k1++) { |
6932 | # CHECK-NEXT: for (int l1 = 0; l1 < N; l1++) { |
6933 | # CHECK-NEXT: v_X__[i, j, k, l, m, n, o, p, i1, j1, k1, l1] = ((i + j) + j1) + 1; |
6934 | # CHECK: v_X___1 = int(0); |
6935 | # CHECK-NEXT: for (int i_2 = 0; i_2 < i_1; i_2++) { |
6936 | # CHECK-NEXT: for (int j_1 = 0; j_1 < N_2_1; j_1++) { |
6937 | # CHECK-NEXT: for (int k_1 = 0; k_1 < N_9; k_1++) { |
6938 | # CHECK-NEXT: for (int l_1 = 0; l_1 < N_8; l_1++) { |
6939 | # CHECK-NEXT: for (int m_1 = 0; m_1 < N_7; m_1++) { |
6940 | # CHECK-NEXT: for (int n_1 = 0; n_1 < N_6; n_1++) { |
6941 | # CHECK-NEXT: for (int o_1 = 0; o_1 < N_5; o_1++) { |
6942 | # CHECK-NEXT: for (int p_1 = 0; p_1 < N_4; p_1++) { |
6943 | # CHECK-NEXT: for (int i1_1 = 0; i1_1 < N_3; i1_1++) { |
6944 | # CHECK-NEXT: for (int j1_1 = 0; j1_1 < N_2; j1_1++) { |
6945 | # CHECK-NEXT: for (int k1_1 = 0; k1_1 < N_1; k1_1++) { |
6946 | # CHECK-NEXT: for (int l1_1 = 0; l1_1 < N; l1_1++) { |
6947 | # CHECK-NEXT: v_X___1 = ReduceOp((v_X___1) + (v_X__[i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1]), reduce_args={i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1}); |
6948 | )IR" ; |
6949 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
6950 | } |
6951 | |
6952 | } // namespace jit |
6953 | } // namespace torch |
6954 | |