#include <memory>
#include <sstream>
#include <stdexcept>
#include <unordered_map>

#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

static void verifyConstBounds(
    const TensorAccessBoundsInfo& access_info,
    const std::vector<std::pair<int, int>>& ref) {
  size_t ndim = ref.size();
  ASSERT_EQ(access_info.start.size(), ndim);
  ASSERT_EQ(access_info.stop.size(), ndim);
  for (const auto i : c10::irange(ndim)) {
    if (ref[i].first >= 0) { // Negative values are used to skip the check
      ASSERT_TRUE(access_info.start[i]->isConstant());
      int start_i = immediateAs<int>(access_info.start[i]);
      ASSERT_EQ(start_i, ref[i].first);
    }
    if (ref[i].second >= 0) {
      ASSERT_TRUE(access_info.stop[i]->isConstant());
      int stop_i = immediateAs<int>(access_info.stop[i]);
      ASSERT_EQ(stop_i, ref[i].second);
    }
  }
}
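// For example, verifyConstBounds(info, {{0, 99}}) checks that a 1-D access
// spans exactly [0, 99], while verifyConstBounds(info, {{0, -1}}) checks only
// the start bound and skips the stop bound (useful when the stop expression is
// not a constant).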

TEST(BoundsInference, _1) {
  // Verify that bounds inference works for the following example:
  //     for i in 0..100:
  //       b[i] = a[i]
  // For this loop bounds inference should yield the following:
  //   {{b, kStore, 0, 99}, {a, kLoad, 0, 99}}
  ExprHandle n(100);
  BufHandle a("a", {n}, kFloat);
  Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
  LoopNest l({b});
  auto bounds_info = inferBounds(l.root_stmt());

  // We should have two entries: one for 'b' and one for 'a'.
  ASSERT_EQ(bounds_info.size(), 2);
  ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
  ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
  verifyConstBounds(bounds_info.at(a.node())[0], {{0, 99}});

  ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
  ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
  verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}});
}

TEST(BoundsInference, _2) {
  // Verify that bounds inference works for the following example:
  //     for i in 0..n:
  //       b[i] = a[i]
  // For this loop bounds inference should yield the following:
  //   {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}}
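  // Note: since n is symbolic here, the stop bounds are n-1 and are not
  // constant expressions, so the checks below pass -1 to skip the stop-bound
  // comparison and only verify the constant start bound.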
  VarHandle n("n", kInt);
  BufHandle a("a", {n}, kFloat);
  Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
  LoopNest l({b});
  auto bounds_info = inferBounds(l.root_stmt());

  // We should have two entries: one for 'b' and one for 'a'.
  ASSERT_EQ(bounds_info.size(), 2);
  ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
  ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
  verifyConstBounds(bounds_info.at(a.node())[0], {{0, -1}});

  ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
  ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
  verifyConstBounds(bounds_info.at(b.buf())[0], {{0, -1}});
}

TEST(BoundsInference, _3) {
  // Verify that bounds inference works for the following example:
  //     for i in 0..100:
  //       b[i] = a[i] * a[i+10]
  // For this loop bounds inference should yield the following:
  //   {{b, kStore, 0, 99}, {a, kLoad, 0, 109}}
  ExprHandle n(100);
  BufHandle a("a", {n + 10}, kFloat);
  Tensor b = Compute(
      "b", {n}, [&](const VarHandle& i) { return a.load(i) * a.load(i + 10); });
  LoopNest l({b});
  auto bounds_info = inferBounds(l.root_stmt());

  // We should have two entries: one for 'b' and one for 'a'.
  ASSERT_EQ(bounds_info.size(), 2);
  ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
  ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
  verifyConstBounds(bounds_info.at(a.node())[0], {{0, 109}});

  ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
  ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
  verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}});
}

TEST(BoundsInference, _4) {
  // Verify that bounds inference works for the following example:
  //
  //     for y in 0..200:
  //       for x in 0..320:
  //         b[y,x] = x*y
  //     for y in 0..200:
  //       for x in 0..320:
  //         c[y,x] = a[y,x] * b[y,x]
  ExprHandle W(320);
  ExprHandle H(200);
  BufHandle a("a", {H, W}, kFloat);
  Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
    return x * y;
  });
  Tensor c = Compute("c", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
    return a.load(y, x) * b.load(y, x);
  });
  LoopNest l({c});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
  StmtPtr body = l.getLoopBodyFor(c);
  {
    // Infer bounds on the top-level loop scope
    auto bounds_info = inferBounds(loops[0]);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{0, 199}, {0, 319}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 199}, {0, 319}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 199}, {0, 319}});
  }
  {
    // Infer bounds on the inner loop scope
    auto bounds_info = inferBounds(loops[1]);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {0, 319}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 319}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 319}});
  }
  {
    // Infer bounds on the inner loop body's scope
    auto bounds_info = inferBounds(body);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}});
  }
}

TEST(BoundsInference, _5) {
  // Verify that bounds inference works for the following example:
  //     for i in 0..100:
  //       b[i] = a[i]
  //
  // ==> split ==>
  //
  //     for i_outer in 0..100/16:
  //       for i_inner in 0..16:
  //         b[i_outer * 16 + i_inner] = a[i_outer * 16 + i_inner]
  //     for i_tail in 0..100%16:
  //       b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16];
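  // With a split factor of 16 and an extent of 100, the main loop covers
  // indices [0, 95] (6 full iterations of 16) and the tail loop covers the
  // remaining indices [96, 99], which is what the bounds checked below expect.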
  ExprHandle n(100);
  BufHandle a("a", {n}, kFloat);
  Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
  LoopNest l({b});

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr inner;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr tail;
  std::vector<ForPtr> loops = l.getLoopStmtsFor(b);
  LoopNest::splitWithTail(loops[0], 16, &inner, &tail);
  ForPtr outer = loops[0];

  {
    // Verify inferred bounds for the outer loop
    auto bounds_info = inferBounds(outer);
    ASSERT_EQ(bounds_info.size(), 2);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{0, 95}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 95}});
  }
  {
    // Verify inferred bounds for the tail loop
    auto bounds_info = inferBounds(tail);
    ASSERT_EQ(bounds_info.size(), 2);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{96, 99}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{96, 99}});
  }
}

TEST(BoundsInference, _6) {
  // Verify that bounds inference works for the following example:
  //
  //     for y in 0..200:
  //       for x in 0..320:
  //         b[y,x] = x*y
  //     for y in 0..20:
  //       for x in 0..32:
  //         c[y,x] = a[y+100,x+100] * b[y*2,x*5]
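  // Expected access ranges: a is read at [y+100, x+100] for y in 0..19 and
  // x in 0..31, i.e. rows [100, 119] and columns [100, 131]; b is read at
  // [y*2, x*5], i.e. rows [0, 38] and columns [0, 155].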
  ExprHandle W(320);
  ExprHandle H(200);
  ExprHandle CW(32);
  ExprHandle CH(20);
  BufHandle a("a", {H, W}, kFloat);
  Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
    return x * y;
  });
  Tensor c =
      Compute("c", {CH, CW}, [&](const VarHandle& y, const VarHandle& x) {
        return a.load(y + 100, x + 100) * b.load(y * 2, x * 5);
      });
  LoopNest l({c});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
  StmtPtr body = l.getLoopBodyFor(c);
  {
    // Infer bounds on the top-level loop scope
    auto bounds_info = inferBounds(loops[0]);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{100, 119}, {100, 131}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 38}, {0, 155}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 19}, {0, 31}});
  }
  {
    // Infer bounds on the inner loop scope
    auto bounds_info = inferBounds(loops[1]);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {100, 131}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 155}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 31}});
  }
  {
    // Infer bounds on the inner loop body's scope
    auto bounds_info = inferBounds(body);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}});
  }
}

TEST(BoundsInference, Adjacent) {
  ExprHandle H(6);
  BufHandle a("a", {20}, kFloat);
  Tensor b = Compute("b", {H}, [&](const VarHandle& x) { return a.load(x); });
  Tensor c =
      Compute("c", {H}, [&](const VarHandle& x) { return a.load(x + H); });
  LoopNest l({b, c});
  std::vector<ForPtr> loops = NodeFinder<For>::find(l.root_stmt());

  {
    // Infer bounds on the top-level loop scope
    auto bounds_info = inferBounds(loops[0]);
    ASSERT_EQ(bounds_info.size(), 2);

    // reads from a[0:5], writes to b[0:5]
    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{0, 5}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}});
  }
  {
    // Infer bounds on the inner loop scope
    auto bounds_info = inferBounds(loops[1]);
    ASSERT_EQ(bounds_info.size(), 2);

    // reads from a[0+6:5+6], writes to c[0:5]
    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{6, 11}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}});
  }
  {
    // Infer bounds on the high level program.
    auto bounds_info = inferBounds(l.root_stmt());
    ASSERT_EQ(bounds_info.size(), 3);

    // Should be the union of the above two bounds, but this time the bounds of
    // A can be merged.
    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{0, 11}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}});
  }
}

TEST(BoundsInference, MultipleTopLoopLoad) {
  BufHandle a("a", {100}, kFloat);
  Tensor b = Compute("b", {64}, [&](const VarHandle& x) { return a.load(x); });
  Tensor c =
      Compute("c", {32}, [&](const VarHandle& x) { return a.load(x + 10); });
  Tensor d =
      Compute("d", {96}, [&](const VarHandle& x) { return a.load(x + 2); });
  LoopNest l({b, c, d});

  auto bounds_info = inferBounds(l.root_stmt());

  ASSERT_EQ(bounds_info.size(), 4);

  // a only read.
  {
    auto bounds = bounds_info[a.node()];
    ASSERT_EQ(bounds.size(), 1);
    // One dimension.
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kLoad);
    // Bounds:
    // start: Min of the 3 load bounds = Min of loop starts + offset = 0+0 (b).
    // stop: Max of the 3 load bounds = Max of loop stops + offset - 1 =
    //       96 + 2 - 1 (d).
    verifyConstBounds(bound, {{0, 97}});
  }

  // b, c, d only written.
  {
    auto bounds = bounds_info[b.buf()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // Just the loop extents for b.
    verifyConstBounds(bound, {{0, 63}});
  }
  {
    auto bounds = bounds_info[c.buf()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // Just the loop extents for c.
    verifyConstBounds(bound, {{0, 31}});
  }
  {
    auto bounds = bounds_info[d.buf()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // Just the loop extents for d.
    verifyConstBounds(bound, {{0, 95}});
  }
}

TEST(BoundsInference, MultipleTopLoopStore) {
  BufHandle a("a", {100}, kFloat);
  BufHandle b("b", {100}, kFloat);
  BufHandle c("c", {100}, kFloat);
  BufHandle d("d", {100}, kFloat);
  VarHandle x("x", kInt);

  // Same as above, but the offsets are on the Store now.
  // Can't do this through the Compute API without transforms we don't have
  // yet.
  StmtPtr stmt = Block::make(
      {For::make(x, 0, 64, Store::make(b, {x}, Load::make(a, {x}))),
       For::make(x, 0, 32, Store::make(c, {x + 10}, Load::make(a, {x}))),
       For::make(x, 0, 96, Store::make(d, {x + 2}, Load::make(a, {x})))});

  auto bounds_info = inferBounds(stmt);

  ASSERT_EQ(bounds_info.size(), 4);

  // a only read.
  {
    auto bounds = bounds_info[a.node()];
    ASSERT_EQ(bounds.size(), 1);
    // One dimension.
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kLoad);
    // Bounds: there are no offsets, so this is just the max loop bounds.
    verifyConstBounds(bound, {{0, 95}});
  }

  // b, c, d only written.
  {
    auto bounds = bounds_info[b.node()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // This should be equivalent to {offset, extent + offset} for the b loop.
    // The b loop has no offset, so just the loop extents.
    verifyConstBounds(bound, {{0, 63}});
  }
  {
    auto bounds = bounds_info[c.node()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // This should be equivalent to {offset, extent + offset} for the c loop.
    // Offset is 10, extent is 32-1.
    verifyConstBounds(bound, {{10, 41}});
  }
  {
    auto bounds = bounds_info[d.node()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // This should be equivalent to {offset, extent + offset} for the d loop.
    // Offset is 2, extent is 96-1.
    verifyConstBounds(bound, {{2, 97}});
  }
}

TEST(BoundsInference, CacheReads) {
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 30, j + 3);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
      });

  LoopNest l({B, C});
  auto bounds_info_before = inferBounds(l.root_stmt());

  StmtPtr j_loop = l.getLoopStmtsFor(B)[1];
  LoopNest::cacheAccesses(A.buf(), "A_local", j_loop);
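  // cacheAccesses introduces a new "A_local" buffer scoped to the j loop: it
  // should be filled from A over the region that B reads, with the loads in
  // B's body redirected to it. The original bounds for A, B, and C should
  // therefore stay unchanged, while A_local should appear with both a store
  // (the fill) and a load (the use).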

  auto bounds_info_after = inferBounds(l.root_stmt());

  // cacheAccesses should not change the existing bounds, but it should add a
  // new entry for the cache.
  for (auto& pair : bounds_info_after) {
    auto beforeIt = bounds_info_before.find(pair.first);
    if (beforeIt != bounds_info_before.end()) {
      // Same number of TensorAccessBoundsInfos.
      ASSERT_EQ(pair.second.size(), beforeIt->second.size());

      for (const auto i : c10::irange(pair.second.size())) {
        TensorAccessBoundsInfo& after = pair.second[i];
        TensorAccessBoundsInfo& before = beforeIt->second[i];
        // Same number of dimensions.
        ASSERT_EQ(before.start.size(), after.start.size());

        // Bounds are equal.
        for (const auto j : c10::irange(before.start.size())) {
          ASSERT_TRUE(exprEquals(before.start[j], after.start[j]));
          ASSERT_TRUE(exprEquals(before.stop[j], after.stop[j]));
        }
      }
    } else {
      // This should be the cache.
      ASSERT_EQ(pair.first->name_hint(), "A_local");
      // Should have both a load and a store.
      ASSERT_EQ(pair.second.size(), 2);
      TensorAccessBoundsInfo& first = pair.second[0];
      TensorAccessBoundsInfo& second = pair.second[1];

      ASSERT_NE(first.kind, second.kind);
      // 2 dimensions.
      ASSERT_EQ(first.start.size(), second.start.size());
      ASSERT_EQ(first.start.size(), 2);

      // Bounds for the load and the store are equal.
      for (const auto j : c10::irange(first.start.size())) {
        ASSERT_TRUE(exprEquals(first.start[j], second.start[j]));
        ASSERT_TRUE(exprEquals(first.stop[j], second.stop[j]));
      }
    }
  }
}

TEST(BoundsInference, Flattened) {
  Tensor b = Compute(
      "b",
      {3, 4, 5},
      [&](const VarHandle& z, const VarHandle& y, const VarHandle& x) {
        return x * y + z;
      });

  LoopNest l({b});
  // Flatten indices.
  l.prepareForCodegen();
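  // prepareForCodegen flattens the {3, 4, 5} buffer into a single dimension of
  // 3 * 4 * 5 = 60 elements (row-major), so the store below should cover the
  // flattened index range [0, 59].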
  auto bounds_info = inferBounds(l.root_stmt());

  // There's only one buffer.
  ASSERT_EQ(bounds_info.size(), 1);
  auto& TABI = bounds_info[b.buf()][0];
  ASSERT_EQ(TABI.kind, TensorAccessKind::kStore);
  // Flattened bounds should have a single dimension.
  ASSERT_EQ(TABI.start.size(), 1);
  ASSERT_EQ(TABI.stop.size(), 1);

  // Bounds should be 0 -> (3*4*5)-1
  ASSERT_TRUE(exprEquals(TABI.start[0], alloc<IntImm>(0)));
  ASSERT_TRUE(exprEquals(TABI.stop[0], alloc<IntImm>(3 * 4 * 5 - 1)));
}

TEST(BoundsInference, GetPotentialHazards) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  using namespace analysis;

  {
    /*
     * A[0] = B[0];
     * B[0] = 3;       WAR on B
     * A[0] = B[0];    WAW on A, RAW on B
     * C[0] = 5;
     */
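    // Terminology: WAR (write-after-read) is a later write to a location that
    // was read earlier, RAW (read-after-write) is a later read of a location
    // that was written earlier, and WAW (write-after-write) is two writes to
    // the same location.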

    StorePtr store1 = Store::make(a, {0}, Load::make(b, {0}));
    StorePtr store2 = Store::make(b, {0}, 3);
    StorePtr store3 = Store::make(a, {0}, Load::make(b, {0}));
    StorePtr store4 = Store::make(c, {0}, 5);
    StmtPtr stmt = Block::make({store1, store2, store3, store4});

    MemDependencyChecker analyzer;
    stmt->accept(&analyzer);

    ASSERT_EQ(
        HazardKind::WriteAfterRead,
        getPotentialHazards(analyzer, store1, store2));

    ASSERT_EQ(
        HazardKind::ReadAfterWrite,
        getPotentialHazards(analyzer, store2, store3));

    ASSERT_EQ(
        HazardKind::WriteAfterWrite,
        getPotentialHazards(analyzer, store1, store3));

    // The fourth store has no dependencies.
    ASSERT_EQ(
        HazardKind::NoDependency,
        getPotentialHazards(analyzer, store1, store4));
    ASSERT_EQ(
        HazardKind::NoDependency,
        getPotentialHazards(analyzer, store2, store4));
    ASSERT_EQ(
        HazardKind::NoDependency,
        getPotentialHazards(analyzer, store3, store4));
  }
}

TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) {
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
  Tensor B = Compute("B", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return (i + 1) * (j + 1);
  });

  LoopNest l({A, B});

  using namespace analysis;

  MemDependencyChecker analyzer;
  l.root_stmt()->accept(&analyzer);

  ForPtr loopRootA = l.getLoopStmtsFor(A)[0];
  ForPtr loopRootB = l.getLoopStmtsFor(B)[0];

  // No dependencies between the loops.
  ASSERT_EQ(
      HazardKind::NoDependency,
      getPotentialHazards(analyzer, loopRootA, loopRootB));
}

TEST(BoundsInference, GetPotentialHazardsLoopCall) {
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
  Tensor B =
      Compute("B", {64, 64}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i, j) + 5;
      });

  LoopNest l({A, B});

  using namespace analysis;

  MemDependencyChecker analyzer;
  l.root_stmt()->accept(&analyzer);

  ForPtr loopRootA = l.getLoopStmtsFor(A)[0];
  ForPtr loopRootB = l.getLoopStmtsFor(B)[0];

  ASSERT_EQ(
      HazardKind::ReadAfterWrite,
      getPotentialHazards(analyzer, loopRootA, loopRootB));
}

TEST(BoundsInference, GetPotentialHazardsLoopSplit) {
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });

  LoopNest l({A});
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr inner, tail;

  // Splitting with a tail factor that does not evenly divide the loop extent
  // creates a tail loop which also writes to A.
  ForPtr outer = l.getLoopStmtsFor(A)[0];
  // `outer` is transformed in place into the outer loop of the split.
  LoopNest::splitWithTail(outer, 5, &inner, &tail);
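  // With an extent of 64 and a split factor of 5, the main (outer/inner) loops
  // cover i in [0, 59] and the tail loop covers i in [60, 63]. Both pieces
  // write to A, so a write-after-write hazard is expected between them.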

  using namespace analysis;

  MemDependencyChecker analyzer;
  l.root_stmt()->accept(&analyzer);

  ASSERT_EQ(
      HazardKind::WriteAfterWrite, getPotentialHazards(analyzer, outer, tail));
}

TEST(BoundsInference, HasConflictingOverlapSameBufferWithPartialOverlap) {
  // Input IR:
  //   for (const auto j : c10::irange(10, 100)) {
  //     A[j] = 10 * j;
  //   }
  //   for (const auto k : c10::irange(10, 100)) {
  //     A[k-1] = 20 * k;
  //   }
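  // The first loop writes A[10]..A[99] and the second writes A[9]..A[98];
  // the two write ranges overlap on [10, 98], so the loops conflict in both
  // directions.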
  BufHandle a_buf("A", {200}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK =
      For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlap) {
  // Input IR:
  //   for (const auto j : c10::irange(10, 100)) {
  //     A[j] = 10 * j;
  //   }
  //   for (const auto k : c10::irange(10, 100)) {
  //     A[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {200}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK = For::make(k, 10, 100, Store::make(a_buf, {k}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlapRAW) {
  // Input IR:
  //   for (const auto j : c10::irange(10, 100)) {
  //     A[j] = 10 * j;
  //   }
  //   for (const auto k : c10::irange(10, 100)) {
  //     B[k] = A[k];
  //   }
  BufHandle a_buf("A", {200}, kInt);
  BufHandle b_buf("B", {200}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK =
      For::make(k, 10, 100, Store::make(b_buf, {k}, Load::make(a_buf, {k})));
  auto par = Block::make({forJ, forK});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlapSameBufferNotOverlapping) {
  // Input IR:
  //   for (const auto j : c10::irange(10, 100)) {
  //     A[j] = 10 * j;
  //   }
  //   for (const auto k : c10::irange(10, 100)) {
  //     A[k+100] = 20 * k;
  //   }
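  // Here the first loop writes A[10]..A[99] while the second writes
  // A[110]..A[199]; the ranges are disjoint, so no conflicting overlap is
  // reported.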
  BufHandle a_buf("A", {200}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK =
      For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlap2DBufferWithOverlap) {
  // Input IR:
  //   for (const auto i : c10::irange(20)) {
  //     for (const auto j : c10::irange(100)) {
  //       A[i,j] = i * j * 500;
  //     }
  //   }
  //   for (const auto m : c10::irange(20)) {
  //     for (const auto n : c10::irange(50)) {
  //       A[m+1,n] = m + n * 100;
  //     }
  //   }
  BufHandle a_buf("A", {20, 100}, kInt);
  BufHandle b_buf("B", {20, 50}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto forJ = For::make(j, 0, 100, storeA1);
  auto forI = For::make(i, 0, 20, forJ);
  auto storeA2 =
      Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100)));
  auto forN = For::make(n, 0, 50, storeA2);
  auto forM = For::make(m, 0, 20, forN);
  auto par = Block::make({forI, forM});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forI, forM));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forM, forI));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forN));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forN, forJ));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, storeA2));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA2, storeA1));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, storeA2));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, forM));
}

TEST(BoundsInference, HasConflictingOverlap2DBufferWithNoOverlap) {
  // Input IR:
  //   for (const auto i : c10::irange(20)) {
  //     for (const auto j : c10::irange(100)) {
  //       A[i,j] = i * j * 500;
  //     }
  //   }
  //   for (const auto m : c10::irange(20)) {
  //     for (const auto n : c10::irange(50)) {
  //       A[m+20,n+100] = m + n * 100;
  //     }
  //   }
  BufHandle a_buf("A", {20, 100}, kInt);
  BufHandle b_buf("B", {20, 50}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto forJ = For::make(j, 0, 100, storeA1);
  auto forI = For::make(i, 0, 20, forJ);
  auto storeA2 =
      Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100)));
  auto forN = For::make(n, 0, 50, storeA2);
  auto forM = For::make(m, 0, 20, forN);
  auto par = Block::make({forI, forM});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM));
}

TEST(BoundsInference, HasConflictingOverlapDifferentBuffers) {
  // Input IR:
  //   for (const auto i : c10::irange(20)) {
  //     for (const auto j : c10::irange(100)) {
  //       A[i,j] = i * j * 500;
  //     }
  //   }
  //   for (const auto m : c10::irange(20)) {
  //     for (const auto n : c10::irange(50)) {
  //       B[m,n] = m + n * 100;
  //     }
  //   }
  BufHandle a_buf("A", {20, 100}, kInt);
  BufHandle b_buf("B", {20, 50}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto forJ = For::make(j, 0, 100, storeA1);
  auto forI = For::make(i, 0, 20, forJ);
  auto storeA2 = Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100)));
  auto forN = For::make(n, 0, 50, storeA2);
  auto forM = For::make(m, 0, 20, forN);
  auto par = Block::make({forI, forM});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM));
}

TEST(BoundsInference, HasConflictingOverlapDueToRAWDependence) {
  // Input IR:
  //   for (const auto j : c10::irange(100)) {
  //     A[j] = 10 * j;
  //   }
  //   for (const auto k : c10::irange(100)) {
  //     B[k] = 20 * A[99-k];
  //   }
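  // forJ writes A[0]..A[99] and forK reads A[99-k] for k in 0..99, i.e. the
  // same range A[0]..A[99], so the loops have a read-after-write conflict.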
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK = For::make(
      k,
      0,
      100,
      Store::make(
          b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k}))));
  auto par = Block::make({forJ, forK});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlapDueToWARDependence) {
  // Input IR:
  //   for (const auto k : c10::irange(100)) {
  //     B[k] = 20 * A[99-k];
  //   }
  //   for (const auto j : c10::irange(100)) {
  //     A[j] = 10 * j;
  //   }
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forK = For::make(
      k,
      0,
      100,
      Store::make(
          b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k}))));
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto par = Block::make({forK, forJ});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlapWithLoads) {
  // Input IR:
  //   for (const auto k : c10::irange(10, 100)) {
  //     B[k] = 20 * A[99-k];
  //   }
  //   for (const auto j : c10::irange(10, 100)) {
  //     C[j] = 10 * A[j];
  //   }
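  // Both loops only read from A (their stores go to B and C), and two reads
  // can never conflict, so no conflicting overlap is reported even though the
  // read ranges of A overlap.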
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  BufHandle c_buf("C", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forK = For::make(
      k,
      10,
      100,
      Store::make(
          b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k}))));
  auto forJ = For::make(
      j,
      10,
      100,
      Store::make(c_buf, {j}, Mul::make(10, Load::make(a_buf, {j}))));
  auto par = Block::make({forK, forJ});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, IsOverlapping) {
  // Input IR:
  //   for (const auto i : c10::irange(100)) {
  //     A[i] = i * 10;           // storeA1
  //     B[i] = A[99-i] * 20;     // loadA1
  //     C[i] = A[i + 100] * 10;  // loadA2
  //     A[i + 50] = i * 50;      // storeA2
  //     A[i + 150] = i * 150;    // storeA3
  //   }
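  // Access ranges over the whole loop: storeA1 writes A[0..99]; loadA1 reads
  // A[0..99] (overlaps storeA1); loadA2 reads A[100..199] (disjoint from
  // storeA1); storeA2 writes A[50..149] (overlaps storeA1); storeA3 writes
  // A[150..249] (disjoint from storeA1).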
  BufHandle a_buf("A", {300}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  BufHandle c_buf("C", {100}, kInt);
  VarHandle i("i", kInt);
  auto storeA1 = Store::make(a_buf, {i}, i * 10);
  auto loadA1 = Load::make(a_buf, {ExprHandle(99) - i});
  auto storeB = Store::make(b_buf, {i}, Mul::make(loadA1, 20));
  auto loadA2 = Load::make(a_buf, {i + 100});
  auto storeC = Store::make(c_buf, {i}, Mul::make(loadA2, 10));
  auto storeA2 = Store::make(a_buf, {i + 50}, i * 50);
  auto storeA3 = Store::make(a_buf, {i + 150}, i * 150);
  auto forI = For::make(
      i, 0, 100, Block::make({storeA1, storeB, storeC, storeA2, storeA3}));
  tensorexpr::analysis::MemDependencyChecker analyzer;
  forI->accept(&analyzer);
  ASSERT_TRUE(isOverlapping(analyzer, storeA1, to<Load>(loadA1.node())));
  ASSERT_FALSE(isOverlapping(analyzer, storeA1, to<Load>(loadA2.node())));
  ASSERT_TRUE(isOverlapping(analyzer, storeA1, storeA2));
  ASSERT_FALSE(isOverlapping(analyzer, storeA1, storeA3));
}

} // namespace jit
} // namespace torch