1 | #include <gtest/gtest.h> |
2 | |
3 | #include <limits> |
4 | #include <memory> |
5 | #include <sstream> |
6 | #include <stdexcept> |
7 | #include <unordered_map> |
8 | |
9 | #include <test/cpp/tensorexpr/test_base.h> |
10 | |
11 | #include <c10/util/irange.h> |
12 | #include <test/cpp/tensorexpr/padded_buffer.h> |
13 | #include <torch/csrc/jit/tensorexpr/analysis.h> |
14 | #include <torch/csrc/jit/tensorexpr/eval.h> |
15 | #include <torch/csrc/jit/tensorexpr/ir.h> |
16 | #include <torch/csrc/jit/tensorexpr/ir_printer.h> |
17 | #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
18 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
19 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
20 | #include <torch/csrc/jit/testing/file_check.h> |
21 | |
22 | namespace torch { |
23 | namespace jit { |
24 | |
25 | using namespace torch::jit::tensorexpr; |
26 | |
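| // Reduce with an empty set of reduction axes over a 1D buffer: the output should equal the input. |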
27 | TEST(Reductions, ReduceSum0D_1) { |
28 | const int M = 10; |
29 | |
30 | BufHandle b("b" , {M}, kFloat); |
31 | std::vector<float> in(M); |
32 | for (const auto j : c10::irange(M)) { |
33 | in[j] = j; |
34 | } |
35 | |
36 | std::vector<float> out(M, -1.f); |
37 | |
38 | Tensor c = Reduce("sum" , {M}, Sum(), b, {}); |
39 | LoopNest loop({c}); |
40 | loop.prepareForCodegen(); |
41 | StmtPtr s = loop.root_stmt(); |
42 | s = IRSimplifier::simplify(s); |
43 | |
44 | SimpleIREvaluator cg(s, {b, c}); |
45 | |
46 | cg.call({in, out}); |
47 | for (const auto i : c10::irange(M)) { |
48 | ASSERT_EQ(out[i], in[i]); |
49 | } |
50 | } |
51 | |
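| // Reduce a 0-D (scalar) buffer with no reduction axes: the output should equal the input. |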
52 | TEST(Reductions, ReduceSum0D_2) { |
53 | BufHandle b("b" , {}, kFloat); |
54 | std::vector<float> in(1); |
55 | in[0] = 77.7f; |
56 | |
57 | std::vector<float> out(1, -1.f); |
58 | |
59 | Tensor c = Reduce("sum" , {}, Sum(), b, {}); |
60 | LoopNest loop({c}); |
61 | loop.prepareForCodegen(); |
62 | StmtPtr s = loop.root_stmt(); |
63 | s = IRSimplifier::simplify(s); |
64 | |
65 | SimpleIREvaluator cg(s, {b, c}); |
66 | |
67 | cg.call({in, out}); |
68 | ASSERT_EQ(out[0], in[0]); |
69 | } |
70 | |
71 | // Sum an array to a single value. |
72 | TEST(Reductions, ReduceSum1D) { |
73 | BufHandle b("b" , {10}, kFloat); |
74 | std::vector<float> in(10); |
75 | for (const auto j : c10::irange(10)) { |
76 | in[j] = j; |
77 | } |
78 | |
79 | std::vector<float> out(1, -1.f); |
80 | |
81 | Tensor c = Reduce("sum" , {}, Sum(), b, {10}); |
82 | LoopNest loop({c}); |
83 | loop.prepareForCodegen(); |
84 | StmtPtr s = loop.root_stmt(); |
85 | s = IRSimplifier::simplify(s); |
86 | |
87 | SimpleIREvaluator cg(s, {b, c}); |
88 | |
89 | cg.call({in, out}); |
90 | ASSERT_EQ(out[0], 45); |
91 | } |
92 | // Sum a 2D tensor to a 1D tensor with dynamic shapes. |
93 | TEST(Reductions, ReduceSum2D) { |
94 | const int M = 3; |
95 | const int N = 7; |
96 | |
97 | VarHandle m("m" , kInt); |
98 | VarHandle n("n" , kInt); |
99 | |
100 | BufHandle b("b" , {m, n}, kFloat); |
101 | std::vector<float> in(M * N); |
102 | for (const auto i : c10::irange(M)) { |
103 | for (const auto j : c10::irange(N)) { |
104 | in[i * N + j] = j; |
105 | } |
106 | } |
107 | |
108 | std::vector<float> out(M, -1.f); |
109 | |
110 | Tensor c = Reduce("sum" , {M}, Sum(), b, {N}); |
111 | LoopNest loop({c}); |
112 | loop.prepareForCodegen(); |
113 | StmtPtr s = loop.root_stmt(); |
114 | s = IRSimplifier::simplify(s); |
115 | |
116 | SimpleIREvaluator cg(s, {b, c, n, m}); |
117 | |
118 | cg.call({in, out, 5, 7}); |
119 | |
120 | float expected = 0; |
121 | for (const auto i : c10::irange(N)) { |
122 | // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) |
123 | expected += i; |
124 | } |
125 | |
126 | for (const auto i : c10::irange(M)) { |
127 | ASSERT_EQ(out[i], expected); |
128 | } |
129 | } |
130 | |
131 | // Sum a 3D tensor to both a 2D and a 1D tensor, then reduce the 2D result across |
132 | // its remaining axis to cross-check the 1D result. |
133 | TEST(Reductions, ReduceSum3D) { |
134 | const int M = 10; |
135 | VarHandle m("m" , kInt); |
136 | |
137 | BufHandle b("b" , {2, 3, m}, kFloat); |
138 | |
139 | Tensor c = Reduce("sum" , {2, 3}, Sum(), b, {m}); |
140 | LoopNest loop({c}); |
141 | loop.prepareForCodegen(); |
142 | StmtPtr s = loop.root_stmt(); |
143 | s = IRSimplifier::simplify(s); |
144 | |
145 | SimpleIREvaluator cg(s, {b, c, m}); |
146 | |
147 | std::vector<float> bData(2 * 3 * M, 0); |
148 | std::vector<float> cData(2 * 3, 6.0f); |
149 | std::vector<float> dData(2, 1.0f); |
150 | std::vector<float> eData(2, 1.0f); |
151 | |
152 | for (int i = 0; i < 2 * 3; ++i) { |
153 | for (const auto j : c10::irange(M)) { |
154 | bData[i * M + j] = j; |
155 | } |
156 | } |
157 | |
158 | cg.call({bData, cData, M}); |
159 | float expected = 0; |
160 | for (const auto i : c10::irange(M)) { |
161 | // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) |
162 | expected += i; |
163 | } |
164 | |
165 | for (int i = 0; i < 2 * 3; ++i) { |
166 | ASSERT_EQ(cData[i], expected); |
167 | } |
168 | |
169 | Tensor d = Reduce("sum2" , {2}, Sum(), b, {3, m}); |
170 | LoopNest loop2({d}); |
171 | loop2.prepareForCodegen(); |
172 | StmtPtr s2 = loop2.root_stmt(); |
173 | s2 = IRSimplifier::simplify(s2); |
174 | |
175 | SimpleIREvaluator cg2(s2, {b, d, m}); |
176 | cg2.call({bData, dData, M}); |
177 | |
178 | // We're combining an additional dimension of 3, so the sum is 3x. |
179 | expected = expected * 3; |
180 | |
181 | for (const auto i : c10::irange(2)) { |
182 | ASSERT_EQ(dData[i], expected); |
183 | } |
184 | |
185 | // This is the same as just reducing the original result across that axis. |
186 | BufHandle c_buf(c.buf()); |
187 | Tensor e = Reduce("sum3" , {2}, Sum(), c_buf, {3}); |
188 | LoopNest loop3({e}); |
189 | loop3.prepareForCodegen(); |
190 | StmtPtr s3 = loop3.root_stmt(); |
191 | s3 = IRSimplifier::simplify(s3); |
192 | |
193 | SimpleIREvaluator cg3(s3, {c, e}); |
194 | cg3.call({cData, eData}); |
195 | |
196 | for (const auto i : c10::irange(2)) { |
197 | ASSERT_EQ(eData[i], expected); |
198 | } |
199 | } |
200 | |
201 | // Sum a large (10-D) tensor over its five innermost dimensions. |
202 | TEST(Reductions, ReduceSum10D) { |
203 | BufHandle in_("in_" , {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat); |
204 | const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3; |
205 | BufHandle out_("out_" , {2, 3, 2, 3, 2}, kFloat); |
206 | const int OutputSize = 2 * 3 * 2 * 3 * 2; |
207 | |
208 | std::vector<float> in(InputSize, 1.f); |
209 | std::vector<float> out(OutputSize, -1.f); |
210 | |
211 | Tensor c = Reduce("sum" , {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3}); |
212 | LoopNest loop({c}); |
213 | loop.prepareForCodegen(); |
214 | StmtPtr s = loop.root_stmt(); |
215 | s = IRSimplifier::simplify(s); |
216 | |
217 | SimpleIREvaluator cg(s, {in_, c}); |
218 | |
219 | cg.call({in, out}); |
220 | |
221 | // NOLINTNEXTLINE(bugprone-integer-division) |
222 | float expected = InputSize / OutputSize; |
223 | for (const auto i : c10::irange(OutputSize)) { |
224 | ASSERT_EQ(out[i], expected); |
225 | } |
226 | } |
227 | |
228 | // Reduce via Mul rather than Add using a custom Reducer. |
229 | TEST(Reductions, ReduceProduct) { |
230 | const int M = 4; |
231 | const int N = 4; |
232 | |
233 | BufHandle b("b" , {M, N}, kFloat); |
234 | std::vector<float> in(M * N); |
235 | for (const auto i : c10::irange(M)) { |
236 | for (const auto j : c10::irange(N)) { |
237 | in[i * N + j] = 2 + j; |
238 | } |
239 | } |
240 | |
241 | std::vector<float> out(M, -1.f); |
242 | |
243 | Reducer product( |
244 | ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; }); |
245 | |
246 | Tensor c = Reduce("product" , {M}, product, b, {N}); |
247 | LoopNest loop({c}); |
248 | loop.prepareForCodegen(); |
249 | StmtPtr s = loop.root_stmt(); |
250 | s = IRSimplifier::simplify(s); |
251 | |
252 | SimpleIREvaluator cg(s, {b, c}); |
253 | |
254 | cg.call({in, out}); |
255 | |
256 | float expected = 1; |
257 | for (const auto i : c10::irange(N)) { |
258 | // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) |
259 | expected *= 2 + i; |
260 | } |
261 | |
262 | for (const auto i : c10::irange(M)) { |
263 | ASSERT_EQ(out[i], expected); |
264 | } |
265 | } |
266 | |
267 | // Maximum reductions. |
268 | TEST(Reductions, ReduceMax) { |
269 | BufHandle in_("b" , {10}, kFloat); |
270 | |
271 | std::vector<float> in(10); |
272 | std::vector<float> out(1, -1.f); |
273 | for (const auto j : c10::irange(10)) { |
274 | in[j] = j; |
275 | } |
276 | |
277 | Tensor dm1 = Reduce("max" , {}, Maximum(kFloat), in_, {10}); |
278 | |
279 | LoopNest loop({dm1}); |
280 | loop.prepareForCodegen(); |
281 | StmtPtr s = loop.root_stmt(); |
282 | s = IRSimplifier::simplify(s); |
283 | SimpleIREvaluator cg(s, {in_, dm1}); |
284 | |
285 | cg.call({in, out}); |
286 | |
287 | ASSERT_EQ(out[0], 9); |
288 | |
289 | BufHandle in2_("b" , {2, 5}, kFloat); |
290 | std::vector<float> out2(2, -1.f); |
291 | |
292 | Tensor m2d = Reduce("max" , {2}, Maximum(kFloat), in2_, {5}); |
293 | |
294 | LoopNest loop2({m2d}); |
295 | loop2.prepareForCodegen(); |
296 | s = loop2.root_stmt(); |
297 | s = IRSimplifier::simplify(s); |
298 | |
299 | SimpleIREvaluator cg2(s, {in2_, m2d}); |
300 | cg2.call({in, out2}); |
301 | |
302 | ASSERT_EQ(out2[0], 4); |
303 | ASSERT_EQ(out2[1], 9); |
304 | } |
305 | |
306 | // Minimum reduction, with custom initialization. |
307 | TEST(Reductions, ReduceMinCustomInitializer) { |
308 | VarHandle minInit("minInit" , kFloat); |
309 | BufHandle in_("b" , {10}, kFloat); |
310 | |
311 | std::vector<float> in(10); |
312 | std::vector<float> out(1, -1.f); |
313 | for (const auto j : c10::irange(10)) { |
314 | in[j] = 10 + j; |
315 | } |
316 | |
317 | Tensor min = Reduce( |
318 | "min" , |
319 | {}, |
320 | Minimum(ExprHandle(minInit)), |
321 | [&](ParameterList& v) { return in_.load(v); }, |
322 | {10}); |
323 | |
324 | LoopNest loop({min}); |
325 | loop.prepareForCodegen(); |
326 | StmtPtr s = loop.root_stmt(); |
327 | s = IRSimplifier::simplify(s); |
328 | |
329 | SimpleIREvaluator cg(s, {in_, min, minInit}); |
330 | |
331 | // Works normally (note that out data starts lower than the correct |
332 | // minimum). |
333 | cg.call({in, out, std::numeric_limits<float>::max()}); |
334 | ASSERT_EQ(out[0], 10); |
335 | |
336 | // With an initializer lower than the true minimum, the initializer wins. |
337 | cg.call({in, out, 5.f}); |
338 | ASSERT_EQ(out[0], 5); |
339 | } |
340 | |
341 | // Example implementation of Any/All. |
342 | // TODO: this is very awkward without logical And/Or operators. |
343 | TEST(Reductions, ReduceAnyAll) { |
344 | VarHandle searchValue("searchValue" , kInt); |
345 | BufHandle b("b" , {4, 10}, kInt); |
346 | |
347 | Reducer anyEqSV(ExprHandle(0), [](ExprHandle a, ExprHandle b) { |
348 | return CompareSelect::make(a, 1, 1, b, kEQ); |
349 | }); |
350 | |
351 | Tensor any = Reduce( |
352 | "anyEqual" , |
353 | {4}, |
354 | anyEqSV, |
355 | [&](const auto& i, const auto& j) { |
356 | return CompareSelect::make(b.load(i, j), searchValue, kEQ); |
357 | }, |
358 | {10}); |
359 | |
360 | LoopNest loop({any}); |
361 | loop.prepareForCodegen(); |
362 | StmtPtr s = loop.root_stmt(); |
363 | s = IRSimplifier::simplify(s); |
364 | |
365 | SimpleIREvaluator cg(s, {b, any, searchValue}); |
366 | |
367 | std::vector<int> in(40, 0); |
368 | std::vector<int> out(4, 0); |
369 | |
370 | // input has 0-39 in 4 rows. |
371 | for (const auto i : c10::irange(40)) { |
372 | in[i] = i; |
373 | } |
374 | cg.call({in, out, 1}); |
375 | |
376 | // only the first row has 1 |
377 | ASSERT_EQ(out[0], 1); |
378 | ASSERT_EQ(out[1], 0); |
379 | ASSERT_EQ(out[2], 0); |
380 | ASSERT_EQ(out[3], 0); |
381 | |
382 | cg.call({in, out, 15}); |
383 | |
384 | // 15 is in the second row (elements 10-19). |
385 | ASSERT_EQ(out[0], 0); |
386 | ASSERT_EQ(out[1], 1); |
387 | ASSERT_EQ(out[2], 0); |
388 | ASSERT_EQ(out[3], 0); |
389 | |
390 | Reducer allGTSV(ExprHandle(1), [](ExprHandle a, ExprHandle b) { |
391 | return CompareSelect::make(a, 0, 0, b, kEQ); |
392 | }); |
393 | |
394 | Tensor allGreaterThan = Reduce( |
395 | "allGreaterThan" , |
396 | {4}, |
397 | allGTSV, |
398 | [&](const auto& i, const auto& j) { |
399 | return CompareSelect::make(b.load(i, j), searchValue, kGT); |
400 | }, |
401 | {10}); |
402 | |
403 | LoopNest loop2({allGreaterThan}); |
404 | loop2.prepareForCodegen(); |
405 | s = loop2.root_stmt(); |
406 | s = IRSimplifier::simplify(s); |
407 | |
408 | SimpleIREvaluator cg2(s, {b, allGreaterThan, searchValue}); |
409 | |
410 | cg2.call({in, out, 11}); |
411 | |
412 | // 11 is in row 1; only rows 2 and 3 are entirely greater than 11. |
413 | ASSERT_EQ(out[0], 0); |
414 | ASSERT_EQ(out[1], 0); |
415 | ASSERT_EQ(out[2], 1); |
416 | ASSERT_EQ(out[3], 1); |
417 | |
418 | cg2.call({in, out, -3}); |
419 | |
420 | // Every element is greater than -3, so all rows pass. |
421 | ASSERT_EQ(out[0], 1); |
422 | ASSERT_EQ(out[1], 1); |
423 | ASSERT_EQ(out[2], 1); |
424 | ASSERT_EQ(out[3], 1); |
425 | } |
426 | |
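| // Express a small (3x2 times 2x3) matrix multiply as a Sum reduction over the shared axis. |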
427 | TEST(Reductions, ReduceMatmul2D) { |
428 | BufHandle tA("tA" , {3, 2}, kFloat); |
429 | BufHandle tB("tB" , {2, 3}, kFloat); |
430 | |
431 | std::vector<float> tA_(6); |
432 | std::vector<float> tB_(6); |
433 | |
434 | std::vector<float> out(9, -1.f); |
435 | for (const auto i : c10::irange(3)) { |
436 | for (const auto j : c10::irange(2)) { |
437 | tA_[i * 2 + j] = i * 2 + j; |
438 | tB_[j * 3 + i] = i * 2 + j; |
439 | } |
440 | } |
441 | |
442 | Tensor mm = Reduce( |
443 | "mm" , |
444 | {3, 3}, |
445 | Sum(), |
446 | [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { |
447 | return tA.load(m, k) * tB.load(k, n); |
448 | }, |
449 | {2}); |
450 | |
451 | LoopNest loop({mm}); |
452 | loop.prepareForCodegen(); |
453 | StmtPtr s = loop.root_stmt(); |
454 | s = IRSimplifier::simplify(s); |
455 | |
456 | SimpleIREvaluator cg(s, {tA, tB, mm}); |
457 | cg.call({tA_, tB_, out}); |
458 | |
459 | std::vector<float> expected( |
460 | {1.f, 3.f, 5.f, 3.f, 13.f, 23.f, 5.f, 23.f, 41.f}); |
461 | |
462 | for (const auto i : c10::irange(9)) { |
463 | ASSERT_EQ(out[i], expected[i]); |
464 | } |
465 | } |
466 | |
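| // Mimic an rfactor by hand: reduce each row to a partial sum, then reduce the partials to a scalar. |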
467 | TEST(Reductions, ReduceRfactorLike) { |
468 | BufHandle in("in" , {10, 10}, kFloat); |
469 | std::vector<float> in_(100); |
470 | for (const auto i : c10::irange(100)) { |
471 | in_[i] = i; |
472 | } |
473 | std::vector<float> in_rf_(10, -2.f); |
474 | std::vector<float> out(1, -1.f); |
475 | |
476 | Tensor l1 = Reduce("l1" , {10}, Sum(), in, {10}); |
477 | BufHandle in_rf(l1.buf()); |
478 | |
479 | Tensor l2 = Reduce("l2" , {}, Sum(), in_rf, {10}); |
480 | |
481 | LoopNest loop({l1, l2}); |
482 | loop.prepareForCodegen(); |
483 | StmtPtr s = loop.root_stmt(); |
484 | s = IRSimplifier::simplify(s); |
485 | |
486 | SimpleIREvaluator cg(s, {in, l1, l2}); |
487 | cg.call({in_, in_rf_, out}); |
488 | |
489 | ASSERT_EQ(out[0], 99 * 50); |
490 | } |
491 | |
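| // Use a reduction as the producer of a following pointwise computation. |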
492 | TEST(Reductions, ReduceAsProducer) { |
493 | const int M = 10; |
494 | VarHandle m("m" , kInt); |
495 | |
496 | BufHandle a("a" , {2, 3}, kFloat); |
497 | BufHandle b("b" , {2, 3, m}, kFloat); |
498 | |
499 | Tensor c = Reduce("sum" , {2, 3}, Sum(), b, {m}); |
500 | Tensor d = |
501 | Compute("scale" , {2, 3}, [&](const VarHandle& l, const VarHandle& n) { |
502 | return c.load(l, n) * a.load(l, n); |
503 | }); |
504 | LoopNest loop({d}, {c, d}); |
505 | loop.prepareForCodegen(); |
506 | StmtPtr s = loop.root_stmt(); |
507 | s = IRSimplifier::simplify(s); |
508 | |
509 | SimpleIREvaluator cg(s, {a, b, d, m}); |
510 | |
511 | std::vector<float> aData(2 * 3, 0); |
512 | std::vector<float> bData(2 * 3 * M, 0); |
513 | std::vector<float> dData(2 * 3, 6.0f); |
514 | |
515 | for (int i = 0; i < 2 * 3; ++i) { |
516 | aData[i] = 6 - i; |
517 | for (const auto j : c10::irange(M)) { |
518 | bData[i * M + j] = j; |
519 | } |
520 | } |
521 | |
522 | cg.call({aData, bData, dData, M}); |
523 | float expected = 0; |
524 | for (const auto i : c10::irange(M)) { |
525 | // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) |
526 | expected += i; |
527 | } |
528 | for (int i = 0; i < 2 * 3; ++i) { |
529 | ASSERT_EQ(dData[i], expected * (6 - i)); |
530 | } |
531 | } |
532 | |
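| // Use a reduction as the consumer of a preceding pointwise computation. |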
533 | TEST(Reductions, ReduceAsConsumer) { |
534 | const int M = 10; |
535 | VarHandle m("m" , kInt); |
536 | |
537 | BufHandle a("a" , {2, 3, m}, kFloat); |
538 | BufHandle b("b" , {2, 3, m}, kFloat); |
539 | |
540 | Tensor c = Compute( |
541 | "scale" , |
542 | {2, 3, m}, |
543 | [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { |
544 | return b.load(l, n, m) * a.load(l, n, m); |
545 | }); |
546 | Tensor d = Reduce("sum" , {2}, Sum(), c, {3, m}); |
547 | LoopNest loop({d}, {c, d}); |
548 | loop.prepareForCodegen(); |
549 | StmtPtr s = loop.root_stmt(); |
550 | s = IRSimplifier::simplify(s); |
551 | |
552 | SimpleIREvaluator cg(s, {a, b, d, m}); |
553 | |
554 | std::vector<float> aData(2 * 3 * M, 0); |
555 | std::vector<float> bData(2 * 3 * M, 0); |
556 | std::vector<float> dData(2, 6.0f); |
557 | |
558 | for (int i = 0; i < 2 * 3; ++i) { |
559 | for (const auto j : c10::irange(M)) { |
560 | bData[i * M + j] = j + 1; |
561 | aData[i * M + j] = 6 - i; |
562 | } |
563 | } |
564 | |
565 | cg.call({aData, bData, dData, M}); |
566 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) |
567 | float expected[2] = {0, 0}; |
568 | for (const auto i : c10::irange(2)) { |
569 | for (const auto j : c10::irange(3)) { |
570 | for (const auto k : c10::irange(M)) { |
571 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) |
572 | expected[i] += (k + 1) * (6 - (i * 3 + j)); |
573 | } |
574 | } |
575 | } |
576 | |
577 | for (const auto i : c10::irange(2)) { |
578 | ASSERT_EQ(dData[i], expected[i]); |
579 | } |
580 | } |
581 | |
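| // Split the reduction axis with a tail loop; the result should be unchanged. |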
582 | TEST(Reductions, SplitReduceAxis) { |
583 | BufHandle in("in" , {16, 8}, kFloat); |
584 | |
585 | std::vector<float> in_(16 * 8); |
586 | for (const auto i : c10::irange(16)) { |
587 | for (const auto j : c10::irange(8)) { |
588 | in_[i * 8 + j] = i; |
589 | } |
590 | } |
591 | std::vector<float> out(16, -1.f); |
592 | |
593 | Tensor tensor = Reduce("sum" , {16}, Sum(), in, {8}); |
594 | LoopNest l({tensor}); |
595 | std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor); |
596 | LoopNest::splitWithTail(loops[1], 2); |
597 | |
598 | l.prepareForCodegen(); |
599 | |
600 | StmtPtr s = l.root_stmt(); |
601 | s = IRSimplifier::simplify(s); |
602 | |
603 | SimpleIREvaluator cg(s, {in, tensor}); |
604 | cg.call({in_, out}); |
605 | |
606 | for (const auto i : c10::irange(16)) { |
607 | ASSERT_EQ(out[i], i * 8); |
608 | } |
609 | } |
610 | |
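| // Split the non-reduction (output) axis twice; the result should be unchanged. |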
611 | TEST(Reductions, SplitNonReduceAxis) { |
612 | BufHandle in("in" , {16, 8}, kFloat); |
613 | |
614 | std::vector<float> in_(16 * 8); |
615 | for (const auto i : c10::irange(16)) { |
616 | for (const auto j : c10::irange(8)) { |
617 | in_[i * 8 + j] = i; |
618 | } |
619 | } |
620 | std::vector<float> out(16, -1.f); |
621 | Tensor tensor = Reduce("sum" , {16}, Sum(), in, {8}); |
622 | LoopNest l({tensor}); |
623 | std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor); |
624 | LoopNest::splitWithTail(loops[0], 2); |
625 | LoopNest::splitWithTail(loops[0], 2); |
626 | |
627 | l.prepareForCodegen(); |
628 | |
629 | StmtPtr s = l.root_stmt(); |
630 | s = IRSimplifier::simplify(s); |
631 | |
632 | SimpleIREvaluator cg(s, {in, tensor}); |
633 | cg.call({in_, out}); |
634 | |
635 | for (const auto i : c10::irange(16)) { |
636 | ASSERT_EQ(out[i], i * 8); |
637 | } |
638 | } |
639 | |
640 | TEST(Reductions, ReorderedReductionInitializer) { |
641 | /* From the quip: |
642 | for k in 0..1: // blockIdx |
643 | for m in 0..128: |
644 | for n in 0..64: // threadIdx |
645 | SumOp(c(k, n), 0, a(k, m, n), {m}) |
646 | */ |
647 | |
648 | BufHandle in("in" , {1, 12, 6}, kFloat); |
649 | std::vector<float> in_(12 * 6, 1.f); |
650 | |
651 | Tensor tensor_ = Reduce("sum" , {1, 12}, Sum(), in, {6}); |
652 | LoopNest l_({tensor_}); |
653 | |
654 | l_.prepareForCodegen(); |
655 | StmtPtr s_ = Stmt::clone(l_.root_stmt()); |
656 | s_ = IRSimplifier::simplify(s_); |
657 | |
658 | Tensor tensor = Reduce("sum" , {1, 12}, Sum(), in, {6}); |
659 | LoopNest l({tensor}); |
660 | |
661 | auto loops = l.getLoopStmtsFor(tensor); |
662 | loops[0]->set_gpu_block_index(0); |
663 | loops[1]->set_gpu_thread_index(0); |
664 | |
665 | LoopNest::reorderAxis(loops[1], loops[2]); |
666 | |
667 | StmtPtr s = l.root_stmt(); |
668 | // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
669 | s = IRSimplifier::simplify(s); |
670 | |
671 | l.prepareForCodegen(); |
672 | |
673 | s = l.root_stmt(); |
674 | s = IRSimplifier::simplify(s); |
675 | |
676 | std::vector<float> out1(16, -1.f); |
677 | SimpleIREvaluator cg(s_, {in, tensor_}); |
678 | cg.call({in_, out1}); |
679 | |
680 | std::vector<float> out2(16, -1.f); |
681 | SimpleIREvaluator cg2(s, {in, tensor}); |
682 | cg2.call({in_, out2}); |
683 | |
684 | for (const auto i : c10::irange(16)) { |
685 | ASSERT_EQ(out1[i], out2[i]); |
686 | } |
687 | } |
688 | |
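| // rfactor a 2D reduction at its outer loop, which introduces a second ReduceOp. |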
689 | TEST(Reductions, ReduceRfactor) { |
690 | const int M = 10; |
691 | const int N = 10; |
692 | VarHandle m("m" , kInt); |
693 | VarHandle n("n" , kInt); |
694 | |
695 | BufHandle b("b" , {m, n}, kFloat); |
696 | std::vector<float> in(M * N); |
697 | for (int j = 0; j < M * N; ++j) { |
698 | in[j] = j; |
699 | } |
700 | |
701 | std::vector<float> out(1, -1.f); |
702 | |
703 | Tensor c = Reduce("sum" , {}, Sum(), b, {m, n}); |
704 | LoopNest loop({c}); |
705 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
706 | auto c_body = loop.getAllWritesToBuf(c.buf())[1]; |
707 | ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); |
708 | auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt()); |
709 | ASSERT_EQ(rc.size(), 2); |
710 | loop.prepareForCodegen(); |
711 | StmtPtr s = loop.root_stmt(); |
712 | s = IRSimplifier::simplify(s); |
713 | |
714 | SimpleIREvaluator cg(s, {b, c, m, n}); |
715 | |
716 | cg.call({in, out, M, N}); |
717 | ASSERT_EQ(out[0], 4950); |
718 | } |
719 | |
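| // rfactor at the innermost reduction loop is rejected, leaving a single ReduceOp. |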
720 | TEST(Reductions, Reduce3DRfactorInner) { |
721 | const int M = 10; |
722 | const int N = 10; |
723 | const int K = 10; |
724 | VarHandle m("m" , kInt); |
725 | VarHandle n("n" , kInt); |
726 | VarHandle k("k" , kInt); |
727 | |
728 | BufHandle b("b" , {m, n, k}, kFloat); |
729 | std::vector<float> in(M * N * K); |
730 | for (int j = 0; j < M * N * K; ++j) { |
731 | in[j] = j; |
732 | } |
733 | |
734 | std::vector<float> out(1, -1.f); |
735 | |
736 | Tensor c = Reduce("sum" , {}, Sum(), b, {m, n, k}); |
737 | LoopNest loop({c}); |
738 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
739 | auto c_body = loop.getAllWritesToBuf(c.buf())[1]; |
740 | ASSERT_FALSE(loop.rfactor(c_body, loops.at(2))); |
741 | auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt()); |
742 | ASSERT_EQ(rc.size(), 1); |
743 | loop.prepareForCodegen(); |
744 | StmtPtr s = loop.root_stmt(); |
745 | s = IRSimplifier::simplify(s); |
746 | |
747 | SimpleIREvaluator cg(s, {b, c, m, n, k}); |
748 | |
749 | cg.call({in, out, M, N, K}); |
750 | ASSERT_EQ(out[0], 499500); |
751 | } |
752 | |
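| // rfactor a 3D reduction at its outermost loop. |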
753 | TEST(Reductions, Reduce3DRfactorOuter) { |
754 | const int M = 10; |
755 | const int N = 10; |
756 | const int K = 10; |
757 | VarHandle m("m" , kInt); |
758 | VarHandle n("n" , kInt); |
759 | VarHandle k("k" , kInt); |
760 | |
761 | BufHandle b("b" , {m, n, k}, kFloat); |
762 | std::vector<float> in(M * N * K); |
763 | for (int j = 0; j < M * N * K; ++j) { |
764 | in[j] = j; |
765 | } |
766 | |
767 | std::vector<float> out(1, -1.f); |
768 | |
769 | Tensor c = Reduce("sum" , {}, Sum(), b, {m, n, k}); |
770 | LoopNest loop({c}); |
771 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
772 | auto c_body = loop.getAllWritesToBuf(c.buf())[1]; |
773 | ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); |
774 | auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt()); |
775 | ASSERT_EQ(rc.size(), 2); |
776 | loop.prepareForCodegen(); |
777 | StmtPtr s = loop.root_stmt(); |
778 | s = IRSimplifier::simplify(s); |
779 | |
780 | SimpleIREvaluator cg(s, {b, c, m, n, k}); |
781 | cg.call({in, out, M, N, K}); |
782 | ASSERT_EQ(out[0], 499500); |
783 | } |
784 | |
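| // Repeatedly rfactor the intermediate buffer and compare against an unmodified reference. |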
785 | TEST(Reductions, ReduceRepeatedInternalRfactor) { |
786 | BufHandle in_("in_" , {2, 3, 4, 5, 6}, kFloat); |
787 | const int InputSize = 2 * 3 * 4 * 5 * 6; |
788 | |
789 | std::vector<float> in(InputSize, 1.f); |
790 | std::vector<float> out(1, -1.f); |
791 | std::vector<float> ref(1, -1.f); |
792 | |
793 | Tensor c = Reduce("sum" , {}, Sum(), in_, {2, 3, 4, 5, 6}); |
794 | LoopNest orig_loop({c}); |
795 | |
796 | // Try rfactoring N outer loops |
797 | for (const auto rfac_number : c10::irange(1, 5)) { |
798 | LoopNest refloop(orig_loop); |
799 | LoopNest loop(orig_loop); |
800 | refloop.prepareForCodegen(); |
801 | SimpleIREvaluator ref_cg( |
802 | IRSimplifier::simplify(refloop.root_stmt()), {in_, c}); |
803 | ref_cg.call({in, ref}); |
804 | |
805 | BufPtr tmp_buf = c.buf(); |
806 | |
807 | for (const auto idx : c10::irange(rfac_number)) { |
808 | auto reduce = loop.getAllWritesToBuf(tmp_buf)[1]; |
809 | ASSERT_TRUE(loop.rfactor( |
810 | reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf)); |
811 | } |
812 | |
813 | loop.prepareForCodegen(); |
814 | StmtPtr s = loop.root_stmt(); |
815 | s = IRSimplifier::simplify(s); |
816 | |
817 | SimpleIREvaluator cg(s, {in_, c}); |
818 | cg.call({in, out}); |
819 | |
820 | ASSERT_EQ(ref[0], out[0]); |
821 | } |
822 | } |
823 | |
824 | // Split a reduction axis with a tail loop. |
825 | TEST(Reductions, ReduceSplitTail) { |
826 | const int M = 10; |
827 | const int N = 10; |
828 | const int K = 10; |
829 | |
830 | BufHandle b("b" , {M, N, K}, kFloat); |
831 | std::vector<float> in(M * N * K); |
832 | for (int j = 0; j < M * N * K; ++j) { |
833 | in[j] = j; |
834 | } |
835 | |
836 | for (const auto i : c10::irange(3)) { |
837 | std::vector<float> out(M, -1.f); |
838 | |
839 | Tensor c = Reduce("sum" , {M}, Sum(), b, {N, K}); |
840 | LoopNest loop({c}); |
841 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
842 | LoopNest::splitWithTail(loops[i], 8); |
843 | |
844 | loop.prepareForCodegen(); |
845 | StmtPtr s = loop.root_stmt(); |
846 | s = IRSimplifier::simplify(s); |
847 | |
848 | SimpleIREvaluator cg(s, {b, c}); |
849 | |
850 | cg.call({in, out}); |
851 | ASSERT_EQ(out[0], 4950); |
852 | } |
853 | } |
854 | |
855 | // Split a reduction axis cleanly so there is no tail loop. |
856 | TEST(Reductions, ReduceSplitNoTail) { |
857 | const int M = 10; |
858 | const int N = 10; |
859 | const int K = 10; |
860 | BufHandle b("b" , {M, N, K}, kFloat); |
861 | std::vector<float> in(M * N * K); |
862 | for (int j = 0; j < M * N * K; ++j) { |
863 | in[j] = j; |
864 | } |
865 | |
866 | for (const auto i : c10::irange(3)) { |
867 | std::vector<float> out(M, -1.f); |
868 | |
869 | Tensor c = Reduce("sum" , {M}, Sum(), b, {N, K}); |
870 | LoopNest loop({c}); |
871 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
872 | LoopNest::splitWithTail(loops[i], 5); |
873 | |
874 | loop.prepareForCodegen(); |
875 | StmtPtr s = loop.root_stmt(); |
876 | s = IRSimplifier::simplify(s); |
877 | |
878 | SimpleIREvaluator cg(s, {b, c}); |
879 | |
880 | cg.call({in, out}); |
881 | ASSERT_EQ(out[0], 4950); |
882 | } |
883 | } |
884 | |
885 | // Split a reduction axis with only a tail loop (the split loop will have size 0 |
886 | // and be eliminated). |
887 | TEST(Reductions, ReduceOverSplitTail) { |
888 | const int M = 10; |
889 | const int N = 10; |
890 | const int K = 10; |
891 | |
892 | BufHandle b("b" , {M, N, K}, kFloat); |
893 | std::vector<float> in(M * N * K); |
894 | for (int j = 0; j < M * N * K; ++j) { |
895 | in[j] = j; |
896 | } |
897 | |
898 | for (const auto i : c10::irange(3)) { |
899 | std::vector<float> out(M, -1.f); |
900 | |
901 | Tensor c = Reduce("sum" , {M}, Sum(), b, {N, K}); |
902 | LoopNest loop({c}); |
903 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
904 | LoopNest::splitWithTail(loops[i], 16); |
905 | |
906 | loop.prepareForCodegen(); |
907 | StmtPtr s = loop.root_stmt(); |
908 | s = IRSimplifier::simplify(s); |
909 | |
910 | SimpleIREvaluator cg(s, {b, c}); |
911 | |
912 | cg.call({in, out}); |
913 | ASSERT_EQ(out[0], 4950); |
914 | } |
915 | } |
916 | |
917 | // Split a reduction axis with a mask. |
918 | TEST(Reductions, ReduceSplitMask) { |
919 | const int M = 10; |
920 | const int N = 10; |
921 | const int K = 10; |
922 | |
923 | BufHandle b("b" , {M, N, K}, kFloat); |
924 | std::vector<float> in(M * N * K); |
925 | for (int j = 0; j < M * N * K; ++j) { |
926 | in[j] = j; |
927 | } |
928 | |
929 | for (const auto i : c10::irange(3)) { |
930 | std::vector<float> out(M, -1.f); |
931 | |
932 | Tensor c = Reduce("sum" , {M}, Sum(), b, {N, K}); |
933 | LoopNest loop({c}); |
934 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
935 | LoopNest::splitWithMask(loops[i], 8); |
936 | |
937 | loop.prepareForCodegen(); |
938 | StmtPtr s = loop.root_stmt(); |
939 | s = IRSimplifier::simplify(s); |
940 | |
941 | SimpleIREvaluator cg(s, {b, c}); |
942 | |
943 | cg.call({in, out}); |
944 | ASSERT_EQ(out[0], 4950); |
945 | } |
946 | } |
947 | |
948 | // Split a reduction axis cleanly so that no mask is required. |
949 | TEST(Reductions, ReduceSplitNoMask) { |
950 | const int M = 10; |
951 | const int N = 10; |
952 | const int K = 10; |
953 | BufHandle b("b" , {M, N, K}, kFloat); |
954 | std::vector<float> in(M * N * K); |
955 | for (int j = 0; j < M * N * K; ++j) { |
956 | in[j] = j; |
957 | } |
958 | |
959 | for (const auto i : c10::irange(3)) { |
960 | std::vector<float> out(M, -1.f); |
961 | |
962 | Tensor c = Reduce("sum" , {M}, Sum(), b, {N, K}); |
963 | LoopNest loop({c}); |
964 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
965 | LoopNest::splitWithMask(loops[i], 5); |
966 | |
967 | loop.prepareForCodegen(); |
968 | StmtPtr s = loop.root_stmt(); |
969 | s = IRSimplifier::simplify(s); |
970 | |
971 | SimpleIREvaluator cg(s, {b, c}); |
972 | |
973 | cg.call({in, out}); |
974 | ASSERT_EQ(out[0], 4950); |
975 | } |
976 | } |
977 | |
978 | // Split a reduction axis with all logic in the mask. |
979 | TEST(Reductions, ReduceOverSplitMask) { |
980 | const int M = 10; |
981 | const int N = 10; |
982 | const int K = 10; |
983 | |
984 | BufHandle b("b" , {M, N, K}, kFloat); |
985 | std::vector<float> in(M * N * K); |
986 | for (int j = 0; j < M * N * K; ++j) { |
987 | in[j] = j; |
988 | } |
989 | |
990 | for (const auto i : c10::irange(3)) { |
991 | std::vector<float> out(M, -1.f); |
992 | |
993 | Tensor c = Reduce("sum" , {M}, Sum(), b, {N, K}); |
994 | LoopNest loop({c}); |
995 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
996 | LoopNest::splitWithMask(loops[i], 16); |
997 | |
998 | loop.prepareForCodegen(); |
999 | StmtPtr s = loop.root_stmt(); |
1000 | s = IRSimplifier::simplify(s); |
1001 | |
1002 | SimpleIREvaluator cg(s, {b, c}); |
1003 | |
1004 | cg.call({in, out}); |
1005 | ASSERT_EQ(out[0], 4950); |
1006 | } |
1007 | } |
1008 | |
1009 | // Test an rfactor when there are two ReduceOps in the graph due to a |
1010 | // splitWithTail. |
1011 | TEST(Reductions, ReduceSplitRfactor) { |
1012 | const int M = 2; |
1013 | const int N = 10; |
1014 | const int K = 10; |
1015 | const int SPLIT_FACTOR = 4; |
1016 | |
1017 | BufHandle b("b" , {M, N, K}, kFloat); |
1018 | std::vector<float> in(M * N * K); |
1019 | for (const auto m : c10::irange(M)) { |
1020 | for (int j = 0; j < N * K; ++j) { |
1021 | in[m * N * K + j] = j; |
1022 | } |
1023 | } |
1024 | |
1025 | std::vector<float> out(M, -1.f); |
1026 | |
1027 | Tensor c = Reduce("sum" , {M}, Sum(), b, {N, K}); |
1028 | LoopNest loop({c}); |
1029 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
1030 | LoopNest::splitWithTail(loops[2], SPLIT_FACTOR); |
1031 | |
1032 | auto c_body = loop.getAllWritesToBuf(c.buf())[2]; |
1033 | auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); |
1034 | ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); |
1035 | LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]); |
1036 | all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); |
1037 | ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); |
1038 | ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1])); |
1039 | loop.prepareForCodegen(); |
1040 | loop.simplify(); |
1041 | StmtPtr s = loop.root_stmt(); |
1042 | |
1043 | SimpleIREvaluator cg(s, {b, c}); |
1044 | |
1045 | cg.call({in, out}); |
1046 | for (const auto i : c10::irange(M)) { |
1047 | (void)i; // Suppress unused variable warning |
1048 | ASSERT_EQ(out[0], 4950); |
1049 | } |
1050 | } |
1051 | |
1052 | // Test an rfactor which ends up being eliminated since the total loop size is |
1053 | // smaller than the split factor. |
1054 | TEST(Reductions, ReduceOverSplitRfactor) { |
1055 | const int N = 10; |
1056 | const int K = 10; |
1057 | const int SPLIT_FACTOR = 16; |
1058 | |
1059 | BufHandle b("b" , {N, K}, kFloat); |
1060 | std::vector<float> in(N * K); |
1061 | for (int j = 0; j < N * K; ++j) { |
1062 | in[j] = j; |
1063 | } |
1064 | |
1065 | std::vector<float> out(1, -1.f); |
1066 | |
1067 | Tensor c = Reduce("sum" , {}, Sum(), b, {N, K}); |
1068 | LoopNest loop({c}); |
1069 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
1070 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1071 | ForPtr i, t; |
1072 | LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t); |
1073 | LoopNest::reorderAxis(loops[0], i); |
1074 | |
1075 | auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); |
1076 | ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3); |
1077 | auto c_body = loop.getAllWritesToBuf(c.buf())[1]; |
1078 | ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0])); |
1079 | LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]); |
1080 | |
1081 | loop.prepareForCodegen(); |
1082 | loop.simplify(); |
1083 | StmtPtr s = loop.root_stmt(); |
1084 | |
1085 | SimpleIREvaluator cg(s, {b, c}); |
1086 | |
1087 | cg.call({in, out}); |
1088 | ASSERT_EQ(out[0], 4950); |
1089 | |
1090 | std::ostringstream oss; |
1091 | oss << *cg.stmt(); |
1092 | |
1093 | // Check the IR to verify the rfactored reduce is eliminated. |
1094 | // TODO: The alloc/free pair should be eliminated here since the buffer is size 0. |
1095 | const std::string& verification_pattern = |
1096 | R"IR( |
1097 | # CHECK: Allocate(tmp_buf); // dtype=float, dims=[0] |
1098 | # CHECK: sum[0] = 0.f; |
1099 | # CHECK: for (int n = 0; n < 10; n++) { |
1100 | # CHECK: for (int k_tail = 0; k_tail < 10; k_tail++) { |
1101 | # CHECK: sum[0] = (sum[0]) + (b[k_tail + 10 * n]); |
1102 | # CHECK: } |
1103 | # CHECK: } |
1104 | # CHECK: Free(tmp_buf);)IR" ; |
1105 | // TODO: rfactor output is not consistent yet, will fix (@nickg). |
1106 | // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1107 | } |
1108 | |
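| // A reduction output cannot be inlined into its consumer. |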
1109 | TEST(Reductions, ReduceInlineReduction) { |
1110 | const int M = 4; |
1111 | const int N = 5; |
1112 | const int K = 6; |
1113 | |
1114 | BufHandle a_buf("a" , {M}, kFloat); |
1115 | BufHandle b_buf("b" , {M, N, K}, kFloat); |
1116 | |
1117 | Tensor x = Reduce("x" , {M}, Sum(), b_buf, {N, K}); |
1118 | Tensor y = Compute( |
1119 | "y" , {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); }); |
1120 | |
1121 | PaddedBuffer<float> a_v(M); |
1122 | PaddedBuffer<float> b_v(M, N, K); |
1123 | |
1124 | for (const auto i : c10::irange(M)) { |
1125 | a_v(i) = i * i; |
1126 | } |
1127 | for (const auto i : c10::irange(M)) { |
1128 | for (const auto j : c10::irange(N)) { |
1129 | for (const auto k : c10::irange(K)) { |
1130 | b_v(i, j, k) = j * j * k; |
1131 | } |
1132 | } |
1133 | } |
1134 | |
1135 | LoopNest l1({y}, {x, y}); |
1136 | // Cannot inline a reduction computation |
1137 | ASSERT_FALSE(l1.computeInline(x.buf())); |
1138 | } |
1139 | |
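| // Inline a pointwise producer into the reduction that consumes it; results must match and the IR shrinks. |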
1140 | TEST(Reductions, ReduceInlineConsumer) { |
1141 | const int M = 4; |
1142 | const int N = 5; |
1143 | const int K = 6; |
1144 | |
1145 | BufHandle a_buf("a" , {M, N, K}, kFloat); |
1146 | BufHandle b_buf("b" , {M, N, K}, kFloat); |
1147 | |
1148 | Tensor x = Compute( |
1149 | "x" , |
1150 | {M, N, K}, |
1151 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1152 | return a_buf.load(m, n, k) + b_buf.load(m, n, k); |
1153 | }); |
1154 | Tensor y = Reduce("y" , {M}, Sum(), x, {N, K}); |
1155 | |
1156 | PaddedBuffer<float> a_v(M, N, K); |
1157 | PaddedBuffer<float> b_v(M, N, K); |
1158 | |
1159 | for (const auto i : c10::irange(M)) { |
1160 | for (const auto j : c10::irange(N)) { |
1161 | for (const auto k : c10::irange(K)) { |
1162 | a_v(i, j, k) = i * i + k; |
1163 | b_v(i, j, k) = j * j + k; |
1164 | } |
1165 | } |
1166 | } |
1167 | |
1168 | LoopNest l1({y}, {x, y}); |
1169 | LoopNest l2(l1); |
1170 | l2.computeInline(x.buf()); |
1171 | |
1172 | l1.prepareForCodegen(); |
1173 | l2.prepareForCodegen(); |
1174 | |
1175 | StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
1176 | StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); |
1177 | |
1178 | SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); |
1179 | SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); |
1180 | |
1181 | PaddedBuffer<float> y_1(M); |
1182 | PaddedBuffer<float> y_2(M); |
1183 | |
1184 | eval1(a_v, b_v, y_1); |
1185 | eval2(a_v, b_v, y_2); |
1186 | ExpectAllNear(y_1, y_2, 1e-5); |
1187 | std::ostringstream oss1, oss2; |
1188 | oss1 << *stmt1; |
1189 | oss2 << *stmt2; |
1190 | ASSERT_GT(oss1.str().size(), oss2.str().size()); |
1191 | } |
1192 | |
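| // As above, but with a custom reducer whose interaction function wraps Min in an Add. |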
1193 | TEST(Reductions, ReduceInlineReducerInternal) { |
1194 | const int M = 4; |
1195 | const int N = 5; |
1196 | const int K = 6; |
1197 | |
1198 | BufHandle a_buf("a" , {M, N, K}, kFloat); |
1199 | BufHandle b_buf("b" , {M, N, K}, kFloat); |
1200 | |
1201 | Tensor x = Compute( |
1202 | "x" , |
1203 | {M, N, K}, |
1204 | [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
1205 | return a_buf.load(m, n, k) + b_buf.load(m, n, k); |
1206 | }); |
1207 | |
1208 | Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) { |
1209 | return Add::make(ExprHandle(1.f), Min::make(a, b, false)); |
1210 | }); |
1211 | Tensor y = Reduce("y" , {M}, minimum, x, {N, K}); |
1212 | |
1213 | PaddedBuffer<float> a_v(M, N, K); |
1214 | PaddedBuffer<float> b_v(M, N, K); |
1215 | |
1216 | for (const auto i : c10::irange(M)) { |
1217 | for (const auto j : c10::irange(N)) { |
1218 | for (const auto k : c10::irange(K)) { |
1219 | a_v(i, j, k) = i * i + k; |
1220 | b_v(i, j, k) = j * j + k; |
1221 | } |
1222 | } |
1223 | } |
1224 | |
1225 | LoopNest l1({y}, {x, y}); |
1226 | LoopNest l2(l1); |
1227 | l2.computeInline(x.buf()); |
1228 | |
1229 | l1.prepareForCodegen(); |
1230 | l2.prepareForCodegen(); |
1231 | |
1232 | StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
1233 | StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); |
1234 | |
1235 | SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); |
1236 | SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); |
1237 | |
1238 | PaddedBuffer<float> y_1(M); |
1239 | PaddedBuffer<float> y_2(M); |
1240 | |
1241 | eval1(a_v, b_v, y_1); |
1242 | eval2(a_v, b_v, y_2); |
1243 | ExpectAllNear(y_1, y_2, 1e-5); |
1244 | std::ostringstream oss1, oss2; |
1245 | oss1 << *stmt1; |
1246 | oss2 << *stmt2; |
1247 | ASSERT_GT(oss1.str().size(), oss2.str().size()); |
1248 | } |
1249 | |
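| // Cache the reduction's output at the output (operator) axis. |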
1250 | TEST(Reductions, ReductionCacheAccessesOperatorAxis) { |
1251 | int L = 4; |
1252 | int N = 3; |
1253 | int M = 2; |
1254 | |
1255 | BufHandle a("a" , {L, N, M}, kFloat); |
1256 | BufHandle b("b" , {L, N, M}, kFloat); |
1257 | |
1258 | Tensor c = Compute( |
1259 | "scale" , |
1260 | {L, N, M}, |
1261 | [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { |
1262 | return b.load(l, n, m) * a.load(l, n, m); |
1263 | }); |
1264 | Tensor d = Reduce("sum" , {L}, Sum(), c, {N, M}); |
1265 | |
1266 | Tensor e = Compute("scale" , {L}, [&](const VarHandle& l) { |
1267 | return b.load(0, 0, l) * d.load(l); |
1268 | }); |
1269 | |
1270 | LoopNest l({e}, {c, d, e}); |
1271 | LoopNest l_before(l); |
1272 | l_before.prepareForCodegen(); |
1273 | SimpleIREvaluator cg_before( |
1274 | LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e}); |
1275 | |
1276 | StmtPtr d_loop = l.getLoopStmtsFor(d)[0]; |
1277 | l.cacheAccesses(d.buf(), "d_local" , d_loop); |
1278 | l.prepareForCodegen(); |
1279 | |
1280 | StmtPtr result = |
1281 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
1282 | SimpleIREvaluator cg_after(result, {a, b, e}); |
1283 | |
1284 | std::ostringstream oss; |
1285 | oss << *cg_after.stmt(); |
1286 | const std::string& expected_ir = |
1287 | R"IR( |
1288 | #CHECK: Allocate(d_local); // dtype=float, dims=[4] |
1289 | #CHECK: for (int i_2 |
1290 | #CHECK: d_local[i_2] = 0.f |
1291 | #CHECK: for (int |
1292 | #CHECK: for (int |
1293 | #CHECK: d_local[i_2] = (d_local[i_2]) + (scale[ |
1294 | #CHECK: } |
1295 | #CHECK: } |
1296 | #CHECK: } |
1297 | #CHECK: for (int i_3 |
1298 | #CHECK: sum[i_3] = d_local[i_3] |
1299 | #CHECK: Free(d_local); |
1300 | #CHECK-NOT: d_local |
1301 | )IR" ; |
1302 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1303 | |
1304 | PaddedBuffer<float> a_v(L, M, N, "a" ); |
1305 | PaddedBuffer<float> b_v(L, M, N, "b" ); |
1306 | PaddedBuffer<float> c_v(L, M, N, "c" ); |
1307 | PaddedBuffer<float> d_v(L, "d" ); |
1308 | PaddedBuffer<float> e_before(L, "e_before" ); |
1309 | PaddedBuffer<float> e_after(L, "e_after" ); |
1310 | |
1311 | for (const auto l : c10::irange(L)) { |
1312 | for (const auto m : c10::irange(M)) { |
1313 | for (const auto n : c10::irange(N)) { |
1314 | a_v(l, m, n) = at::randn({1}).item().to<float>(); |
1315 | b_v(l, m, n) = at::randn({1}).item().to<float>(); |
1316 | } |
1317 | } |
1318 | } |
1319 | |
1320 | cg_before.call({a_v, b_v, e_before}); |
1321 | cg_after.call({a_v, b_v, e_after}); |
1322 | |
1323 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1324 | ExpectAllNear(e_before, e_after, 1e-5); |
1325 | } |
1326 | |
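| // Cache the reduction's output at the outer reduction axis. |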
1327 | TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { |
1328 | int L = 4; |
1329 | int N = 3; |
1330 | int M = 2; |
1331 | |
1332 | BufHandle a("a" , {L, N, M}, kFloat); |
1333 | BufHandle b("b" , {L, N, M}, kFloat); |
1334 | |
1335 | Tensor c = Compute( |
1336 | "scale" , |
1337 | {L, N, M}, |
1338 | [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { |
1339 | return b.load(l, n, m) * a.load(l, n, m); |
1340 | }); |
1341 | Tensor d = Reduce("sum" , {L}, Sum(), c, {N, M}); |
1342 | |
1343 | Tensor e = Compute("scale" , {L}, [&](const VarHandle& l) { |
1344 | return b.load(0, 0, l) * d.load(l); |
1345 | }); |
1346 | |
1347 | LoopNest l({e}, {c, d, e}); |
1348 | LoopNest l_before(l); |
1349 | l_before.prepareForCodegen(); |
1350 | SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); |
1351 | |
1352 | StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; |
1353 | l.cacheAccesses(d.buf(), "d_local" , d_loop); |
1354 | l.prepareForCodegen(); |
1355 | |
1356 | StmtPtr result = |
1357 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
1358 | SimpleIREvaluator cg_after(result, {a, b, e}); |
1359 | |
1360 | std::ostringstream oss; |
1361 | oss << *cg_after.stmt(); |
1362 | const std::string& expected_ir = |
1363 | R"IR( |
1364 | #CHECK: Allocate(d_local); // dtype=float, dims=[1] |
1365 | #CHECK: sum[i_1] = 0 |
1366 | #CHECK: d_local[0] = sum[i_1] |
1367 | #CHECK: for (int j_1 |
1368 | #CHECK: for (int k_1 |
1369 | #CHECK: d_local[0] = (d_local[0]) + (scale[ |
1370 | #CHECK: } |
1371 | #CHECK: } |
1372 | #CHECK: sum[i_1] = d_local[0] |
1373 | #CHECK: Free(d_local); |
1374 | #CHECK-NOT: d_local |
1375 | )IR" ; |
1376 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1377 | |
1378 | PaddedBuffer<float> a_v(L, M, N, "a" ); |
1379 | PaddedBuffer<float> b_v(L, M, N, "b" ); |
1380 | PaddedBuffer<float> c_v(L, M, N, "c" ); |
1381 | PaddedBuffer<float> d_v(L, "d" ); |
1382 | PaddedBuffer<float> e_before(L, "e_before" ); |
1383 | PaddedBuffer<float> e_after(L, "e_after" ); |
1384 | |
1385 | for (const auto l : c10::irange(L)) { |
1386 | for (const auto m : c10::irange(M)) { |
1387 | for (const auto n : c10::irange(N)) { |
1388 | a_v(l, m, n) = at::randn({1}).item().to<float>(); |
1389 | b_v(l, m, n) = at::randn({1}).item().to<float>(); |
1390 | } |
1391 | } |
1392 | } |
1393 | |
1394 | cg_before.call({a_v, b_v, e_before}); |
1395 | cg_after.call({a_v, b_v, e_after}); |
1396 | |
1397 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1398 | ExpectAllNear(e_before, e_after, 1e-5); |
1399 | } |
1400 | |
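| // Cache the reduction's output at the inner reduction axis. |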
1401 | TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { |
1402 | int L = 4; |
1403 | int N = 3; |
1404 | int M = 2; |
1405 | |
1406 | BufHandle a("a" , {L, N, M}, kFloat); |
1407 | BufHandle b("b" , {L, N, M}, kFloat); |
1408 | |
1409 | Tensor c = Compute( |
1410 | "scale" , |
1411 | {L, N, M}, |
1412 | [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { |
1413 | return b.load(l, n, m) * a.load(l, n, m); |
1414 | }); |
1415 | Tensor d = Reduce("sum" , {L}, Sum(), c, {N, M}); |
1416 | |
1417 | Tensor e = Compute("scale" , {L}, [&](const VarHandle& l) { |
1418 | return b.load(0, 0, l) * d.load(l); |
1419 | }); |
1420 | |
1421 | LoopNest l({e}, {c, d, e}); |
1422 | LoopNest l_before(l); |
1423 | l_before.prepareForCodegen(); |
1424 | SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); |
1425 | |
1426 | StmtPtr d_loop = l.getLoopStmtsFor(d)[2]; |
1427 | l.cacheAccesses(d.buf(), "d_local" , d_loop); |
1428 | l.prepareForCodegen(); |
1429 | |
1430 | StmtPtr result = |
1431 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
1432 | SimpleIREvaluator cg_after(result, {a, b, e}); |
1433 | |
1434 | std::ostringstream oss; |
1435 | oss << *cg_after.stmt(); |
1436 | const std::string& expected_ir = |
1437 | R"IR( |
1438 | #CHECK: Allocate(d_local); // dtype=float, dims=[1] |
1439 | #CHECK: sum[i_1] = 0 |
1440 | #CHECK: for (int |
1441 | #CHECK: d_local[0] = 0 |
1442 | #CHECK: for (int |
1443 | #CHECK: d_local[0] = (d_local[0]) + (scale[ |
1444 | #CHECK: } |
1445 | #CHECK: sum[i_1] = (sum[i_1]) + (d_local[0]) |
1446 | #CHECK: } |
1447 | #CHECK: Free(d_local); |
1448 | #CHECK-NOT: d_local |
1449 | )IR" ; |
1450 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1451 | |
1452 | PaddedBuffer<float> a_v(L, M, N, "a" ); |
1453 | PaddedBuffer<float> b_v(L, M, N, "b" ); |
1454 | PaddedBuffer<float> c_v(L, M, N, "c" ); |
1455 | PaddedBuffer<float> d_v(L, "d" ); |
1456 | PaddedBuffer<float> e_before(L, "e_before" ); |
1457 | PaddedBuffer<float> e_after(L, "e_after" ); |
1458 | |
1459 | for (const auto l : c10::irange(L)) { |
1460 | for (const auto m : c10::irange(M)) { |
1461 | for (const auto n : c10::irange(N)) { |
1462 | a_v(l, m, n) = at::randn({1}).item().to<float>(); |
1463 | b_v(l, m, n) = at::randn({1}).item().to<float>(); |
1464 | } |
1465 | } |
1466 | } |
1467 | |
1468 | cg_before.call({a_v, b_v, e_before}); |
1469 | cg_after.call({a_v, b_v, e_after}); |
1470 | |
1471 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1472 | ExpectAllNear(e_before, e_after, 1e-5); |
1473 | } |
1474 | |
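| // Cache the reduction's input (the "scale" producer) inside the reduction loop nest. |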
1475 | TEST(Reductions, ReductionCacheBodyAccess) { |
1476 | BufHandle a("a" , {24, 32, 12}, kFloat); |
1477 | BufHandle b("b" , {24, 32, 12}, kFloat); |
1478 | |
1479 | Tensor c = Compute( |
1480 | "scale" , |
1481 | {24, 32, 12}, |
1482 | [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { |
1483 | return b.load(l, n, m) * a.load(l, n, m); |
1484 | }); |
1485 | Tensor d = Reduce("sum" , {24}, Sum(), c, {32, 12}); |
1486 | |
1487 | Tensor e = Compute("scale" , {24}, [&](const VarHandle& l) { |
1488 | return b.load(0, 0, l) * d.load(l); |
1489 | }); |
1490 | |
1491 | LoopNest l({e}, {c, d, e}); |
1492 | |
1493 | StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; |
1494 | l.cacheAccesses(c.buf(), "scale_local" , d_loop); |
1495 | |
1496 | l.prepareForCodegen(); |
1497 | StmtPtr result = |
1498 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
1499 | SimpleIREvaluator cg(result, {a, b, e}); |
1500 | |
1501 | std::ostringstream oss; |
1502 | oss << *cg.stmt(); |
1503 | const std::string& expected_ir = |
1504 | R"IR( |
1505 | #CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12] |
1506 | #CHECK: for (int j_1 = 0; j_1 < 32; j_1++) { |
1507 | #CHECK: for (int k_1 = 0; k_1 < 12; k_1++) { |
1508 | #CHECK: scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1]; |
1509 | #CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]); |
1510 | #CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]); |
1511 | #CHECK: Free(scale_local); |
1512 | )IR" ; |
1513 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1514 | } |
1515 | |
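| // Split the consumer loop and cache the reduction's output at the consumer. |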
1516 | TEST(Reductions, ReductionCacheConsumerAccess) { |
1517 | BufHandle a("a" , {24, 32, 12}, kFloat); |
1518 | BufHandle b("b" , {24, 32, 12}, kFloat); |
1519 | |
1520 | Tensor c = Compute( |
1521 | "scale" , |
1522 | {24, 32, 12}, |
1523 | [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { |
1524 | return b.load(l, n, m) * a.load(l, n, m); |
1525 | }); |
1526 | Tensor d = Reduce("sum" , {24}, Sum(), c, {32, 12}); |
1527 | |
1528 | Tensor e = Compute("scale" , {24}, [&](const VarHandle& l) { |
1529 | return b.load(0, 0, l) * d.load(l); |
1530 | }); |
1531 | |
1532 | LoopNest l({e}, {c, d, e}); |
1533 | |
1534 | LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4); |
1535 | |
1536 | StmtPtr e_loop = l.getLoopStmtsFor(e)[1]; |
1537 | l.cacheAccesses(d.buf(), "sum_local" , e_loop); |
1538 | l.prepareForCodegen(); |
1539 | |
1540 | StmtPtr result = |
1541 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
1542 | SimpleIREvaluator cg(result, {a, b, e}); |
1543 | |
1544 | std::ostringstream oss; |
1545 | oss << *cg.stmt(); |
1546 | const std::string& expected_ir = |
1547 | R"IR( |
1548 | #CHECK: Alias(sum_local,scale); |
1549 | #CHECK: sum[i_1] = (sum[i_1]) + (scale[ |
1550 | #CHECK: for (int j_2 = 0; j_2 < 4 |
1551 | #CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2]; |
1552 | #CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]); |
1553 | )IR" ; |
1554 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1555 | } |
1556 | |
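| // Split both the reduction's output axis and the consumer loop, then cache the output at the consumer. |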
1557 | TEST(Reductions, ReductionSplitCacheConsumerAccess) { |
1558 | BufHandle a("a" , {24, 32, 12}, kFloat); |
1559 | BufHandle b("b" , {24, 32, 12}, kFloat); |
1560 | |
1561 | Tensor c = Compute( |
1562 | "scale" , |
1563 | {24, 32, 12}, |
1564 | [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { |
1565 | return b.load(l, n, m) * a.load(l, n, m); |
1566 | }); |
1567 | Tensor d = Reduce("sum" , {24}, Sum(), c, {32, 12}); |
1568 | |
1569 | Tensor e = Compute("scale" , {24}, [&](const VarHandle& l) { |
1570 | return b.load(0, 0, l) * d.load(l); |
1571 | }); |
1572 | |
1573 | LoopNest l({e}, {c, d, e}); |
1574 | |
1575 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1576 | ForPtr inner; |
1577 | |
1578 | // Split the reduction's outermost (output) axis. |
1579 | LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner); |
1580 | |
1581 | // Split reduction consumer. |
1582 | LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); |
1583 | |
1584 | l.cacheAccesses(d.buf(), "sum_local" , inner); |
1585 | l.prepareForCodegen(); |
1586 | |
1587 | StmtPtr result = |
1588 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
1589 | SimpleIREvaluator cg(result, {a, b, e}); |
1590 | |
1591 | // The reduction body changes but the cached access does not. |
1592 | std::ostringstream oss; |
1593 | oss << *cg.stmt(); |
1594 | const std::string& expected_ir = |
1595 | R"IR( |
1596 | #CHECK: Alias(sum_local,scale); |
1597 | #CHECK: sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]); |
1598 | #CHECK: for (int i_2 = 0; i_2 < 6 |
1599 | #CHECK: for (int j_2 = 0; j_2 < 4 |
1600 | #CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2]; |
1601 | #CHECK: for (int j_3 = 0; j_3 < 4 |
1602 | #CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]); |
1603 | )IR" ; |
1604 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1605 | } |
1606 | |
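| // Reorder the reduction's loops, split the consumer loop, then cache the output at the consumer. |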
1607 | TEST(Reductions, ReductionReorderCacheConsumerAccess) { |
1608 | BufHandle a("a" , {24, 32, 12}, kFloat); |
1609 | BufHandle b("b" , {24, 32, 12}, kFloat); |
1610 | |
1611 | Tensor c = Compute( |
1612 | "scale" , |
1613 | {24, 32, 12}, |
1614 | [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { |
1615 | return b.load(l, n, m) * a.load(l, n, m); |
1616 | }); |
1617 | Tensor d = Reduce("sum" , {24}, Sum(), c, {32, 12}); |
1618 | |
1619 | Tensor e = Compute("scale" , {24}, [&](const VarHandle& l) { |
1620 | return b.load(0, 0, l) * d.load(l); |
1621 | }); |
1622 | |
1623 | LoopNest l({e}, {c, d, e}); |
1624 | |
1625 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1626 | ForPtr inner; |
1627 | |
1628 | // reorder outer reduction axes. |
1629 | auto loops = l.getLoopStmtsFor(d); |
1630 | LoopNest::reorderAxis(loops[0], loops[1]); |
1631 | |
1632 | // Split reduction consumer. |
1633 | LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); |
1634 | |
1635 | l.cacheAccesses(d.buf(), "sum_local" , inner); |
1636 | l.prepareForCodegen(); |
1637 | |
1638 | StmtPtr result = |
1639 | LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); |
1640 | SimpleIREvaluator cg(result, {a, b, e}); |
1641 | |
1642 | // Neither the reduction body nor the cached access changes. |
1643 | std::ostringstream oss; |
1644 | oss << *cg.stmt(); |
1645 | const std::string& expected_ir = |
1646 | R"IR( |
1647 | #CHECK: sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]); |
1648 | #CHECK: for (int i_3 = 0; i_3 < 6; |
1649 | #CHECK: for (int j_2 = 0; j_2 < 4; |
1650 | #CHECK: sum_local[j_2] = sum[j_2 + 4 * i_3]; |
1651 | #CHECK: for (int j_3 = 0; j_3 < 4; |
1652 | #CHECK: scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]); |
1653 | )IR" ; |
1654 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1655 | } |
1656 | |
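| // rfactor the reduction, then cache the rfactor buffer at an outer loop (one row of partials at a time). |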
1657 | TEST(Reductions, ReductionRfactorCacheTempOuter) { |
1658 | const int M = 10; |
1659 | const int N = 10; |
1660 | const int K = 10; |
1661 | VarHandle m("m" , kInt); |
1662 | VarHandle n("n" , kInt); |
1663 | VarHandle k("k" , kInt); |
1664 | |
1665 | BufHandle b("B" , {m, n, k}, kFloat); |
1666 | std::vector<float> in(M * N * K); |
1667 | for (int j = 0; j < M * N * K; ++j) { |
1668 | in[j] = j; |
1669 | } |
1670 | |
1671 | std::vector<float> out(1, -1.f); |
1672 | |
1673 | Tensor c = Reduce("sum" , {}, Sum(), b, {m, n, k}); |
1674 | LoopNest loop({c}); |
1675 | |
1676 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
1677 | LoopNest::reorderAxis(loops.at(0), loops.at(1)); |
1678 | loops = loop.getLoopStmtsFor(c); |
1679 | auto c_body = loop.getAllWritesToBuf(c.buf())[1]; |
1680 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1681 | BufPtr rfac_buf; |
1682 | ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); |
1683 | loop.distributeLoop(loops.at(0)); |
1684 | |
1685 | auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); |
1686 | ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); |
1687 | LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); |
1688 | |
1689 | all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); |
1690 | LoopNest::cacheAccesses(rfac_buf, "tmp" , all_loops[1][1]); |
1691 | loop.simplify(); |
1692 | loop.prepareForCodegen(); |
1693 | StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); |
1694 | SimpleIREvaluator cg(s, {b, c, m, n, k}); |
1695 | |
1696 | std::ostringstream oss; |
1697 | oss << *cg.stmt(); |
1698 | const std::string& expected_ir = |
1699 | R"IR( |
1700 | #CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] |
1701 | #CHECK: Allocate(tmp); // dtype=float, dims=[n] |
1702 | #CHECK: for (int i_1 = 0; i_1 < m |
1703 | #CHECK: for (int j = 0; j < n |
1704 | #CHECK: tmp[j] = 0 |
1705 | #CHECK: } |
1706 | #CHECK: for (int j_1 = 0; j_1 < n |
1707 | #CHECK: for (int k |
1708 | #CHECK: tmp[j_1] = (tmp[j_1]) + (B[ |
1709 | #CHECK: } |
1710 | #CHECK: } |
1711 | #CHECK: for (int j_2 = 0; j_2 < n |
1712 | #CHECK: sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]); |
1713 | #CHECK: } |
1714 | #CHECK: Free(tmp); |
1715 | #CHECK-NOT: tmp |
1716 | )IR" ; |
1717 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1718 | |
1719 | cg.call({in, out, M, N, K}); |
1720 | ASSERT_EQ(out[0], 499500); |
1721 | } |
1722 | |
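| // rfactor the reduction, then cache the rfactor buffer at the innermost loop (one scalar at a time). |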
1723 | TEST(Reductions, ReductionRfactorCacheTempInner) { |
1724 | const int M = 10; |
1725 | const int N = 10; |
1726 | const int K = 10; |
1727 | VarHandle m("m" , kInt); |
1728 | VarHandle n("n" , kInt); |
1729 | VarHandle k("k" , kInt); |
1730 | |
1731 | BufHandle b("B" , {m, n, k}, kFloat); |
1732 | std::vector<float> in(M * N * K); |
1733 | for (int j = 0; j < M * N * K; ++j) { |
1734 | in[j] = j; |
1735 | } |
1736 | |
1737 | std::vector<float> out(1, -1.f); |
1738 | |
1739 | Tensor c = Reduce("sum" , {}, Sum(), b, {m, n, k}); |
1740 | LoopNest loop({c}); |
1741 | std::vector<ForPtr> loops = loop.getLoopStmtsFor(c); |
1742 | auto c_body = loop.getAllWritesToBuf(c.buf())[1]; |
1743 | |
1744 | LoopNest::reorderAxis(loops.at(0), loops.at(1)); |
1745 | loops = loop.getLoopStmtsFor(c); |
1746 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1747 | BufPtr rfac_buf; |
1748 | ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); |
1749 | loop.distributeLoop(loops.at(0)); |
1750 | auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); |
1751 | ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); |
1752 | LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); |
1753 | |
1754 | all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); |
1755 | ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); |
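  // Cache at the innermost loop; the temporary shrinks to one element
  // (dims=[1]) since only a single rfactor slot is live at a time.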
1756 | LoopNest::cacheAccesses(rfac_buf, "tmp" , all_loops[1][2]); |
1757 | loop.prepareForCodegen(); |
1758 | loop.simplify(); |
1759 | StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); |
1760 | SimpleIREvaluator cg(s, {b, c, m, n, k}); |
1761 | |
1762 | std::ostringstream oss; |
1763 | oss << *cg.stmt(); |
1764 | const std::string& expected_ir = |
1765 | R"IR( |
1766 | #CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] |
1767 | #CHECK: Allocate(tmp); // dtype=float, dims=[1] |
1768 | #CHECK: for (int i_1 = 0; i_1 < m |
1769 | #CHECK: for (int j = 0; j < n |
1770 | #CHECK: tmp[0] = 0 |
1771 | #CHECK: for (int k |
1772 | #CHECK: tmp[0] = (tmp[0]) + (B[ |
1773 | #CHECK: } |
1774 | #CHECK: sum_rfac[j] = (sum_rfac[j]) + (tmp[0]); |
1775 | #CHECK: Free(tmp); |
1776 | #CHECK-NOT: tmp |
1777 | )IR" ; |
1778 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1779 | |
1780 | cg.call({in, out, M, N, K}); |
1781 | ASSERT_EQ(out[0], 499500); |
1782 | } |
1783 | |
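// Vectorize the output axis of a 2D reduction and check that the vectorized
// IR computes the same result as the unvectorized version.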
1784 | TEST(Reductions, ReductionVectorize) { |
1785 | std::vector<float> in_(8 * 8); |
1786 | for (const auto i : c10::irange(8)) { |
1787 | for (const auto j : c10::irange(8)) { |
1788 | in_[i * 8 + j] = i; |
1789 | } |
1790 | } |
1791 | std::vector<float> out_before(8, -1.f); |
1792 | std::vector<float> out_after(8, -1.f); |
1793 | |
1794 | BufHandle in("in" , {8, 8}, kFloat); |
1795 | |
1796 | Tensor tensor = Reduce("sum" , {8}, Sum(), in, {8}); |
1797 | LoopNest l_before({tensor}); |
1798 | LoopNest l(l_before); |
1799 | l_before.prepareForCodegen(); |
1800 | SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); |
1801 | cg_before.call({in_, out_before}); |
1802 | |
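  // Vectorize the output axis (loop 0); the reduce axis remains a serial loop.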
1803 | ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0])); |
1804 | |
1805 | StmtPtr s = l.root_stmt(); |
1806 | s = LoopNest::sanitizeNames(IRSimplifier::simplify(s)); |
1807 | |
1808 | std::ostringstream oss; |
1809 | oss << *s; |
1810 | const std::string& expected_ir = |
1811 | R"IR( |
1812 | #CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8); |
1813 | #CHECK: for (int i = 0; i < 8; i++) { |
1814 | #CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i}); |
1815 | #CHECK: } |
1816 | )IR" ; |
1817 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1818 | |
  // Vectorizing should not change the result.
1820 | l.prepareForCodegen(); |
1821 | s = IRSimplifier::simplify(l.root_stmt()); |
1822 | SimpleIREvaluator cg_after(s, {in, tensor}); |
1823 | cg_after.call({in_, out_after}); |
1824 | for (const auto i : c10::irange(8)) { |
1825 | ASSERT_EQ(out_before[i], out_after[i]); |
1826 | } |
1827 | } |
1828 | |
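// Vectorizing a reduction axis directly is not supported, so vectorize() must
// fail.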
1829 | TEST(Reductions, ReductionVectorizeInner) { |
1830 | BufHandle in("in" , {8, 8}, kFloat); |
1831 | |
1832 | Tensor tensor = Reduce("sum" , {8}, Sum(), in, {8}); |
1833 | LoopNest l({tensor}); |
1834 | |
1835 | ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); |
1836 | } |
1837 | |
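// The reduce axis of a full reduction cannot be vectorized, but after
// rfactoring it into a separate buffer the rfactored loop can be.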
1838 | TEST(Reductions, ReductionVectorizeRfactor) { |
1839 | std::vector<float> in_(8 * 8); |
1840 | for (const auto i : c10::irange(8)) { |
1841 | for (const auto j : c10::irange(8)) { |
1842 | in_[i * 8 + j] = i; |
1843 | } |
1844 | } |
1845 | std::vector<float> out_before(1, -1.f); |
1846 | std::vector<float> out_after(1, -1.f); |
1847 | |
1848 | BufHandle in("in" , {8, 8}, kFloat); |
1849 | |
1850 | Tensor tensor = Reduce("sum" , {}, Sum(), in, {8, 8}); |
1851 | |
1852 | LoopNest l_before({tensor}); |
1853 | LoopNest l(l_before); |
1854 | l_before.prepareForCodegen(); |
1855 | SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); |
1856 | cg_before.call({in_, out_before}); |
1857 | |
1858 | ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); |
1859 | |
  // But if we rfactor this so it's not a reduce axis, we can vectorize that
  // loop.
1862 | std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor); |
1863 | LoopNest::reorderAxis(loops[0], loops[1]); |
1864 | loops = l.getLoopStmtsFor(tensor); |
1865 | auto tensor_body = l.getAllWritesToBuf(tensor.buf())[1]; |
1866 | BufPtr rfac_buf = nullptr; |
1867 | ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf)); |
1868 | |
1869 | LoopNest::distributeLoop(loops.at(0)); |
1870 | auto rfac_loops = l.getAllLoopNestsWritingToBuf(rfac_buf); |
1871 | |
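  // The rfactored axis is now the output axis of sum_rfac, so its loop can be
  // vectorized.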
1872 | ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0])); |
1873 | l.simplify(); |
1874 | |
1875 | StmtPtr s = LoopNest::sanitizeNames(l.root_stmt()); |
1876 | |
1877 | std::ostringstream oss; |
1878 | oss << *s; |
1879 | const std::string& expected_ir = |
1880 | R"IR( |
1881 | #CHECK: sum = 0.f; |
1882 | #CHECK: for (int i = 0; i < 8; i++) { |
1883 | #CHECK: sum_rfac[i] = 0.f; |
1884 | #CHECK: } |
1885 | #CHECK: for (int i_1 = 0; i_1 < 8; i_1++) { |
1886 | #CHECK: sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1}); |
1887 | #CHECK: } |
1888 | #CHECK: for (int i_2 = 0; i_2 < 8; i_2++) { |
1889 | #CHECK: sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2}); |
1890 | #CHECK: } |
1891 | )IR" ; |
1892 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1893 | |
  // Vectorizing should not change the result.
1895 | l.prepareForCodegen(); |
1896 | s = IRSimplifier::simplify(l.root_stmt()); |
1897 | SimpleIREvaluator cg_after(s, {in, tensor}); |
1898 | cg_after.call({in_, out_after}); |
1899 | |
1900 | ASSERT_EQ(out_before[0], out_after[0]); |
1901 | } |
1902 | |
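// A reduction with a custom init expression: each output element starts at
// B[i] rather than zero before accumulating A along the reduce axis.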
1903 | TEST(Reductions, InitFunction) { |
1904 | constexpr int M = 32; |
1905 | constexpr int N = 16; |
1906 | BufHandle A("A" , {M, N}, kFloat); |
1907 | BufHandle B("B" , {N}, kFloat); |
1908 | Tensor C = Reduce( |
1909 | "C" , |
1910 | {N}, |
1911 | Sum(), |
1912 | [&](const std::vector<VarHandle>& v) { return B.load(v[0]); }, |
1913 | [&](const std::vector<VarHandle>& v) { return A.load(v[1], v[0]); }, |
1914 | {M}); |
1915 | LoopNest nest({C}); |
1916 | nest.prepareForCodegen(); |
1917 | StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt())); |
1918 | std::ostringstream oss; |
1919 | oss << *s << "\n" ; |
1920 | const std::string& expected_ir = |
1921 | R"IR( |
1922 | #CHECK: for (int i = 0; i < 16; i++) { |
1923 | #CHECK: C[i] = B[i]; |
1924 | #CHECK: for (int j = 0; j < 32; j++) { |
1925 | #CHECK: C[i] = (C[i]) + (A[i + 16 * j]); |
1926 | #CHECK: } |
1927 | #CHECK: } |
1928 | )IR" ; |
1929 | torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
1930 | } |
1931 | } // namespace jit |
1932 | } // namespace torch |
1933 | |