1#include <gtest/gtest.h>
2#include <test/cpp/tensorexpr/test_base.h>
3
4#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
5#include <torch/csrc/jit/tensorexpr/ir.h>
6#include <torch/csrc/jit/tensorexpr/ir_printer.h>
7#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
8#include <torch/csrc/jit/tensorexpr/loopnest.h>
9#include <torch/csrc/jit/tensorexpr/mem_dependency_checker.h>
10#include <torch/csrc/jit/tensorexpr/tensor.h>
11
12namespace torch {
13namespace jit {
14
15using namespace torch::jit::tensorexpr;
16
// Test helper function used to determine if two regions of a buffer have an
// overlap. NoOverlap and PartialOverlap are self-explanatory. Contains means A
// is larger and fully encloses B, while ContainedOrEqual is the reverse. Equal
// ranges are ContainedOrEqual.
21TEST(MemDependency, BoundOverlap) {
22 using namespace analysis;
23
24 auto CB = [](int s, int e) {
25 return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
26 };
27
28 // Sanity check 3 overlap cases.
29 ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0)));
30 ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5)));
31 ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 0), CB(1, 1)));
32
33 // Partial overlap works in either order.
34 ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14)));
35 ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10)));
36
37 // Total Overlap works when one bound encloses the other, and returns which.
38 ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(7, 9)));
39 ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16)));
40
41 // Total overlap works when the bounds are an identical range, returns
42 // ContainedOrEqual.
43 ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15)));
44
45 // Total overlap when only one end of the bound matches.
46 ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 10)));
47 ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(3, 15)));
48 ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(0, 10), CB(0, 9)));
49 ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15)));
50 ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15)));
51
52 // No overlap when a < b.
53 ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 2), CB(5, 10)));
54 ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 2), CB(3, 3)));
55 ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(100, 120), CB(130, 130)));
56
57 // No overlap when a > b.
58 ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(5, 10), CB(0, 2)));
59 ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(3, 3), CB(2, 2)));
60 ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(130, 130), CB(100, 120)));
61
62 // No overlap when adjacent.
63 ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 100), CB(101, 120)));
64 ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 3), CB(0, 1)));
65
66 // Partial overlap when middle bounds match.
67 ASSERT_EQ(
68 OverlapKind::PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120)));
69 ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4)));
70 ASSERT_EQ(
71 OverlapKind::PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100)));
72 ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2)));
73
74 // Total overlap when one bound is single length over one end of the other.
75 ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(15, 15)));
76 ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 2)));
77 ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15)));
78 ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15)));
79}
80
81TEST(MemDependency, BoundComparison) {
82 using namespace analysis;
83
84 auto CB = [](int s, int e) {
85 return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
86 };
87
88 ASSERT_EQ(
89 CmpEvalResult::NotDetermined,
90 compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kEQ));
91 ASSERT_EQ(
92 CmpEvalResult::True,
93 compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kEQ));
94 ASSERT_EQ(
95 CmpEvalResult::False,
96 compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kEQ));
97 ASSERT_EQ(
98 CmpEvalResult::False,
99 compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kEQ));
100 ASSERT_EQ(
101 CmpEvalResult::NotDetermined,
102 compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kEQ));
103 ASSERT_EQ(
104 CmpEvalResult::NotDetermined,
105 compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ));
106 ASSERT_EQ(
107 CmpEvalResult::NotDetermined,
108 compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kEQ));
109
110 ASSERT_EQ(
111 CmpEvalResult::NotDetermined,
112 compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kNE));
113 ASSERT_EQ(
114 CmpEvalResult::False,
115 compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kNE));
116 ASSERT_EQ(
117 CmpEvalResult::True,
118 compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kNE));
119 ASSERT_EQ(
120 CmpEvalResult::True,
121 compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kNE));
122 ASSERT_EQ(
123 CmpEvalResult::NotDetermined,
124 compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kNE));
125 ASSERT_EQ(
126 CmpEvalResult::NotDetermined,
127 compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ));
128 ASSERT_EQ(
129 CmpEvalResult::NotDetermined,
130 compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kNE));
131
132 ASSERT_EQ(
133 CmpEvalResult::True,
134 compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLT));
135 ASSERT_EQ(
136 CmpEvalResult::False,
137 compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLT));
138 ASSERT_EQ(
139 CmpEvalResult::False,
140 compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLT));
141 ASSERT_EQ(
142 CmpEvalResult::NotDetermined,
143 compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLT));
144 ASSERT_EQ(
145 CmpEvalResult::NotDetermined,
146 compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLT));
147 ASSERT_EQ(
148 CmpEvalResult::NotDetermined,
149 compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLT));
150
151 ASSERT_EQ(
152 CmpEvalResult::False,
153 compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGE));
154 ASSERT_EQ(
155 CmpEvalResult::True,
156 compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGE));
157 ASSERT_EQ(
158 CmpEvalResult::True,
159 compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGE));
160 ASSERT_EQ(
161 CmpEvalResult::NotDetermined,
162 compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGE));
163 ASSERT_EQ(
164 CmpEvalResult::NotDetermined,
165 compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGE));
166 ASSERT_EQ(
167 CmpEvalResult::NotDetermined,
168 compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGE));
169
170 ASSERT_EQ(
171 CmpEvalResult::False,
172 compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGT));
173 ASSERT_EQ(
174 CmpEvalResult::False,
175 compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGT));
176 ASSERT_EQ(
177 CmpEvalResult::NotDetermined,
178 compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGT));
179 ASSERT_EQ(
180 CmpEvalResult::True,
181 compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGT));
182 ASSERT_EQ(
183 CmpEvalResult::NotDetermined,
184 compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGT));
185 ASSERT_EQ(
186 CmpEvalResult::NotDetermined,
187 compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGT));
188
189 ASSERT_EQ(
190 CmpEvalResult::True,
191 compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLE));
192 ASSERT_EQ(
193 CmpEvalResult::True,
194 compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLE));
195 ASSERT_EQ(
196 CmpEvalResult::NotDetermined,
197 compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLE));
198 ASSERT_EQ(
199 CmpEvalResult::False,
200 compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLE));
201 ASSERT_EQ(
202 CmpEvalResult::NotDetermined,
203 compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLE));
204 ASSERT_EQ(
205 CmpEvalResult::NotDetermined,
206 compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLE));
207}
208
209TEST(MemDependency, BoundOverlapSymbolic) {
210 VarHandle x("x", kInt);
211 VarHandle y("y", kInt);
212 VarHandle z("z", kInt);
213 VarHandle w("w", kInt);
214
215 using namespace analysis;
216
217 auto CB = [](ExprHandle s, ExprHandle e) {
218 return Bound(s.node(), e.node());
219 };
220
221 // Sanity check cases where the start and end is symbolic but the diff is
222 // constant.
223 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
224 ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(x, x), CB(x, x)));
225 ASSERT_EQ(
226 OverlapKind::PartialOverlap,
227 boundOverlap(CB(x, x + 3), CB(x + 2, x + 5)));
228 ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(x, x), CB(x + 1, x + 1)));
229
230 // We can't infer the sign of y, so cannot tell whether adding y is larger or
231 // smaller than y/2.
232 ASSERT_EQ(
233 OverlapKind::PartialOverlap,
234 boundOverlap(CB(x, x + y), CB(x, x + y / 2)));
235
236 // No information about this bound, have to take the most conservative option:
237 // there may be an overlap.
238 ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(x, y), CB(z, w)));
239
240 // Math on opaque terms works.
241 ASSERT_EQ(
242 OverlapKind::ContainedOrEqual,
243 boundOverlap(CB(x + w, y - z), CB(x + w, y - z)));
244 // Even requiring simplification.
245 ASSERT_EQ(
246 OverlapKind::ContainedOrEqual,
247 boundOverlap(CB(x - w - w, y), CB(x - w * 2, y)));
248}
249
// Tests the helper function for overlap of multi-dimensional index bounds.
// This uses boundOverlap on each dimension and returns the "lowest" kind of
// overlap.
253TEST(MemDependency, BoundOverlapMultiDim) {
254 using namespace analysis;
255
256 auto CB = [](int s, int e) {
257 return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
258 };
259
260 // Sanity check one dimensional cases.
261 ASSERT_EQ(OverlapKind::ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)}));
262 ASSERT_EQ(OverlapKind::NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)}));
263 ASSERT_EQ(
264 OverlapKind::PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)}));
265
266 // Total overlap in 3 dims.
267 ASSERT_EQ(
268 OverlapKind::ContainedOrEqual,
269 overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)}));
270 ASSERT_EQ(
271 OverlapKind::ContainedOrEqual,
272 overlaps(
273 {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)}));
274
275 // Total overlap in 2 dims, no overlap in another.
276 ASSERT_EQ(
277 OverlapKind::NoOverlap,
278 overlaps(
279 {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)}));
280
281 // Total overlap in 2 dims, partial overlap in another.
282 ASSERT_EQ(
283 OverlapKind::PartialOverlap,
284 overlaps(
285 {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)}));
286 // This case is most important, so verify the overlap in any dim. (dim 2)
287 ASSERT_EQ(
288 OverlapKind::PartialOverlap,
289 overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)}));
290 // Dim 1.
291 ASSERT_EQ(
292 OverlapKind::PartialOverlap,
293 overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)}));
294 // Total overlap in 1 dim, partial in 2.
295 ASSERT_EQ(
296 OverlapKind::PartialOverlap,
297 overlaps(
298 {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)}));
299 // Total overlap, partial overlap, no overlap.
300 ASSERT_EQ(
301 OverlapKind::NoOverlap,
302 overlaps(
303 {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)}));
304
305 // Total overlap (B) in 2 dims, total overlap (A) in another.
306 ASSERT_EQ(
307 OverlapKind::Contains,
308 overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)}));
309
310 // Total overlap (A) in 2 dims, total overlap (B) in another.
311 ASSERT_EQ(
312 OverlapKind::Contains,
313 overlaps(
314 {CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)}));
315
316 // Total (B), No Overlap, Total (A).
317 ASSERT_EQ(
318 OverlapKind::NoOverlap,
319 overlaps(
320 {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)}));
321}
322
// Test the helper we use to subtract bounds: returns the region(s) of A which
// remain after removing the region of B.
325TEST(MemDependency, BoundSubtract) {
326 using namespace analysis;
327
328 auto CB = [](int s, int e) {
329 return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
330 };
331 auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
332 return indexBoundsEquals(x, y);
333 };
334
335 // One element subtract.
336 ASSERT_EQ(subtractBound(CB(0, 0), CB(0, 0)).size(), 0);
337 ASSERT_EQ(subtractBound(CB(5, 5), CB(5, 5)).size(), 0);
338
339 // No Overlap.
340 ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(2, 2)), {CB(5, 5)}));
341 ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(0, 4)), {CB(5, 5)}));
342
343 // one side overlap.
344 ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(4, 7)), {CB(1, 3)}));
345 ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(5, 7)), {CB(0, 4)}));
346 ASSERT_TRUE(EQ(subtractBound(CB(4, 5), CB(1, 4)), {CB(5, 5)}));
347 ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 4)), {CB(5, 5)}));
348
349 // both sides overlap.
350 ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 7)), {}));
351 ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(5, 7)), {}));
352
353 // internal overlap.
354 ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(2, 3)), {CB(1, 1), CB(4, 5)}));
355 ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(2, 4)), {CB(0, 1), CB(5, 5)}));
356}
357
358TEST(MemDependency, BoundSubtractSymbolic) {
359 VarHandle x("x", kInt);
360 VarHandle y("y", kInt);
361 VarHandle z("z", kInt);
362 VarHandle w("w", kInt);
363
364 using namespace analysis;
365
366 auto CB = [](ExprHandle s, ExprHandle e) {
367 return Bound(s.node(), e.node());
368 };
369 auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
370 return indexBoundsEquals(x, y);
371 };
372
373 // One element subtract.
374 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
375 ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(x, x)), {}));
376 ASSERT_TRUE(EQ(subtractBound(CB(x + 1, x + 1), CB(x + 1, x + 1)), {}));
377 ASSERT_TRUE(EQ(subtractBound(CB(x * 2, x * 2), CB(x * 2, x * 2)), {}));
378
379 // Subtract constant range low.
380 ASSERT_TRUE(
381 EQ(subtractBound(CB(x, x + 10), CB(x, x + 4)), {CB(x + 5, x + 10)}));
382 // Subtract constant range high.
383 ASSERT_TRUE(
384 EQ(subtractBound(CB(x, x + 10), CB(x + 6, x + 12)), {CB(x, x + 5)}));
385 // Subtract constant range total overlap.
386 ASSERT_TRUE(EQ(subtractBound(CB(x, x + 10), CB(x, x + 10)), {}));
387 ASSERT_TRUE(EQ(subtractBound(CB(x + 2, x + 10), CB(x, x + 12)), {}));
388 // Subtract constant range internal.
389 ASSERT_TRUE(
390 EQ(subtractBound(CB(x, x + 10), CB(x + 3, x + 7)),
391 {CB(x, x + 2), CB(x + 8, x + 10)}));
392
393 // Size is inferable but not constant, only works with a single var.
394 ASSERT_TRUE(EQ(subtractBound(CB(0, x), CB(0, x * 2)), {}));
395 ASSERT_TRUE(EQ(subtractBound(CB(0, x * 2), CB(0, x - 1)), {CB(x, x * 2)}));
396
397 // Size is not inferable.
398 ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(z, w)), {CB(x, y)}));
399 ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(x, z)), {CB(x, y)}));
400 ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(0, x)), {CB(x, y)}));
401 ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(0, 0)), {CB(x, x)}));
402}
403
// Tests the helper function that does subtraction, but for multi-dimensional
// index bounds.
406TEST(MemDependency, BoundSubtractMultiDim) {
407 using namespace analysis;
408
409 auto CB = [](int s, int e) {
410 return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
411 };
412 auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) {
413 if (x.size() != y.size()) {
414 return false;
415 }
416 for (auto i = 0U; i < x.size(); ++i) {
417 if (!indexBoundsEquals(x[i], y[i])) {
418 return false;
419 }
420 }
421 return true;
422 };
423
424 // sanity check one dimension.
425 ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {}));
426 ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {}));
427 ASSERT_TRUE(
428 EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}}));
429 ASSERT_TRUE(
430 EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}}));
431 ASSERT_TRUE(EQ(
432 subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}}));
433
434 // Multi dim total overlap.
435 ASSERT_TRUE(EQ(
436 subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {}));
437 ASSERT_TRUE(EQ(
438 subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {}));
439
440 // Mutli dim one way partial in dim 1.
441 ASSERT_TRUE(
442 EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}),
443 {{CB(4, 9), CB(0, 2)}}));
444
445 // Mutli dim one way partial in dim 2.
446 ASSERT_TRUE(
447 EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}),
448 {{CB(0, 9), CB(11, 20)}}));
449
450 // Partial overlap in 2 dims.
451 ASSERT_TRUE(
452 EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}),
453 {{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}}));
454
455 // Partial overlap in 3 dims.
456 ASSERT_TRUE(
457 EQ(subtractIndicesBounds(
458 {CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}),
459 {{CB(0, 1), CB(0, 5), CB(0, 5)},
460 {CB(2, 5), CB(0, 1), CB(0, 5)},
461 {CB(2, 5), CB(2, 5), CB(0, 1)}}));
462}
463
464// Tests the multi dimensional subtraction code for bounds that cannot be fully
465// materialized.
466TEST(MemDependency, BoundSubtractMultiDimSymbolic) {
467 VarHandle x("x", kInt);
468 VarHandle y("y", kInt);
469
470 using namespace analysis;
471
472 auto CB = [](ExprHandle s, ExprHandle e) {
473 return Bound(s.node(), e.node());
474 };
475
476 auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) {
477 if (x.size() != y.size()) {
478 return false;
479 }
480 for (auto i = 0U; i < x.size(); ++i) {
481 if (!indexBoundsEquals(x[i], y[i])) {
482 return false;
483 }
484 }
485 return true;
486 };
487
488 // Cannot determine overlaps.
489 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
490 ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}}));
491
492 // Various total Overlaps.
493 ASSERT_TRUE(EQ(
494 subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {}));
495 ASSERT_TRUE(EQ(
496 subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {}));
497 ASSERT_TRUE(EQ(
498 subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {}));
499 ASSERT_TRUE(EQ(
500 subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {}));
501
502 // one-way overlap in first dim.
503 ASSERT_TRUE(
504 EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}),
505 {{CB(x - 4, x), CB(0, y)}}));
506 // second dim.
507 ASSERT_TRUE(
508 EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}),
509 {{CB(0, x), CB(0, 4)}}));
510
511 // Internal overlap in first dim.
512 ASSERT_TRUE(
513 EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}),
514 {{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}}));
515 // second dim.
516 ASSERT_TRUE(EQ(
517 subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}),
518 {{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}}));
519
520 // Overlap in both dimensions.
521 ASSERT_TRUE(
522 EQ(subtractIndicesBounds(
523 {CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}),
524 {
525 {CB(0, 4), CB(0, y)},
526 {CB(x - 4, x), CB(0, y)},
527 {CB(0, x), CB(0, 9)},
528 {CB(0, x), CB(y - 9, y)},
529 }));
530}
531
532// Simple check that the analyzer does anything at all...
533TEST(MemDependency, MemDependencyCheckerSimple) {
534 BufHandle a("A", {1}, kInt);
535 BufHandle b("B", {1}, kInt);
536
537 analysis::MemDependencyChecker analyzer;
538
539 /*
540 * A[0] = 3;
541 * B[0] = A[0] + 1;
542 */
543
544 StorePtr aStore = Store::make(a, {0}, 3);
545 StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
546
547 StmtPtr stmt = Block::make({aStore, bStore});
548
549 stmt->accept(&analyzer);
550
551 ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
552 ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore));
553 // sanity check, but anything that depends directly must depend indirectly.
554 ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aStore));
555}
556
557// Check that there is a difference between direct and indirect dependence.
558TEST(MemDependency, MemDependencyCheckerMultiStmt) {
559 BufHandle a("A", {1}, kInt);
560 BufHandle b("B", {1}, kInt);
561 BufHandle c("C", {1}, kInt);
562
563 analysis::MemDependencyChecker analyzer;
564
565 /*
566 * A[0] = 3;
567 * B[0] = A[0];
568 * C[0] = B[0] + 1;
569 */
570
571 StorePtr aStore = Store::make(a, {0}, 3);
572 StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
573 StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1));
574
575 StmtPtr stmt = Block::make({aStore, bStore, cStore});
576
577 stmt->accept(&analyzer);
578
579 // C depends on A indirectly.
580 ASSERT_FALSE(analyzer.dependsDirectly(cStore, aStore));
581 ASSERT_TRUE(analyzer.dependsIndirectly(cStore, aStore));
582
583 // C depends on B directly, which depends on A directly.
584 ASSERT_TRUE(analyzer.dependsDirectly(cStore, bStore));
585 ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
586
587 // Dependency goes top to bottom only.
588 ASSERT_FALSE(analyzer.dependsIndirectly(bStore, cStore));
589 ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore));
590 ASSERT_FALSE(analyzer.dependsIndirectly(aStore, cStore));
591}
592
593// Verify that we do filter writes that are totally overlapped by later writes.
594TEST(MemDependency, MemDependencyCheckerOverlap) {
595 BufHandle a("A", {1}, kInt);
596 BufHandle b("B", {1}, kInt);
597
598 analysis::MemDependencyChecker analyzer;
599
600 /*
601 * A[0] = 3;
602 * A[0] = 6;
603 * B[0] = A[0] + 1;
604 */
605
606 StorePtr aStore = Store::make(a, {0}, 3);
607 StorePtr a2Store = Store::make(a, {0}, 6);
608 StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
609
610 StmtPtr stmt = Block::make({aStore, a2Store, bStore});
611
612 stmt->accept(&analyzer);
613
614 // B store depends on second A store but not first since it is completely
615 // overlapped.
616 ASSERT_TRUE(analyzer.dependsIndirectly(bStore, a2Store));
617 ASSERT_FALSE(analyzer.dependsIndirectly(bStore, aStore));
618
619 // No dependency between either A store.
620 ASSERT_FALSE(analyzer.dependsIndirectly(aStore, a2Store));
621 ASSERT_FALSE(analyzer.dependsIndirectly(a2Store, aStore));
622}
623
624// Verify that bounds match loop iterations, and that dependencies progress
625// across loop scopes.
626TEST(MemDependency, MemDependencyCheckerLoop) {
627 BufHandle a("A", {1}, kInt);
628 BufHandle b("B", {1}, kInt);
629 VarHandle x("x", kInt);
630
631 using namespace analysis;
632
633 MemDependencyChecker analyzer;
634
635 /*
636 * for (int x = 0; x < 10; ++x) {
637 * A[x] = x;
638 * }
639 * B[0] = A[0] + 1;
640 */
641
642 StorePtr aStore = Store::make(a, {x}, x);
643 StmtPtr loop = For::make(x, 0, 10, aStore);
644 StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1));
645
646 StmtPtr stmt = Block::make({loop, bStore});
647
648 stmt->accept(&analyzer);
649
650 // Same A->B dependency.
651 ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
652
653 // B depends on the loop.
654 ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
655 // A is in the loop but does not depend on any loop iteration.
656 ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop));
657
658 auto aStoreAccess = analyzer.accessFor(aStore);
659 ASSERT_NE(aStoreAccess, nullptr);
660
661 // It should have bounds covering the range of x: 0 <= x < 10.
662 ASSERT_TRUE(indexBoundsEquals(
663 aStoreAccess->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
664}
665
666// Reductions should promote dependencies as well.
667TEST(MemDependency, MemDependencyCheckerLoopReduce) {
668 BufHandle a("A", {10}, kInt);
669 BufHandle b("B", {10}, kInt);
670 VarHandle x("x", kInt);
671
672 using namespace analysis;
673
674 MemDependencyChecker analyzer;
675
676 /*
677 * A[0] = 0;
678 * for (int x = 0; x < 10; ++x) {
679 * A[0] = A[x] + 1;
680 * }
681 * B[0] = A[0];
682 */
683
684 StorePtr aInit = Store::make(a, {0}, 0);
685 ExprHandle reduce = Sum()(a, 1, {x}, {x});
686 StorePtr aReduce = Store::make(a, {0}, reduce);
687 StmtPtr loop = For::make(x, 0, 10, aReduce);
688 StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
689
690 StmtPtr stmt = Block::make({aInit, loop, bStore});
691
692 stmt->accept(&analyzer);
693
694 // B -> A.
695 ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce));
696
697 // B depends indirectly on the intializer of A, since the reduction depends
698 // on it.
699 ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit));
700 ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit));
701
702 ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit));
703
704 // B depends on the loop.
705 ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
706 // A is in the loop and depends on other iterations.
707 ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop));
708
709 // The loop contents depend on the initializer too.
710 ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit));
711
712 // Find loads within the reduction:
713 auto reduceLoads = NodeFinder<Load>::find(reduce.node());
714 // Pull out the access for the load inside the loop.
715 for (auto load : reduceLoads) {
716 auto loopLoad = analyzer.accessFor(load);
717 // It should have 10 element long bounds.
718 ASSERT_TRUE(indexBoundsEquals(
719 loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
720 }
721}
722
723// Lowering a reduction doesn't affect dependency analysis.
724TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) {
725 BufHandle a("A", {10}, kInt);
726 BufHandle b("B", {10}, kInt);
727 VarHandle x("x", kInt);
728
729 using namespace analysis;
730
731 MemDependencyChecker analyzer;
732
733 /*
734 * A[0] = 0;
735 * for (int x = 0; x < 10; ++x) {
736 * A[0] = A[x] + 1;
737 * }
738 * B[0] = A[0];
739 */
740
741 StorePtr aInit = Store::make(a, {0}, 0);
742 ExprHandle aLoad = Load::make(a, {x});
743 StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1));
744 StmtPtr loop = For::make(x, 0, 10, aReduce);
745 StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
746
747 StmtPtr stmt = Block::make({aInit, loop, bStore});
748
749 stmt->accept(&analyzer);
750
751 // B -> A.
752 ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce));
753
754 // B depends indirectly on the intializer of A, since the reduction depends
755 // on it.
756 ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit));
757 ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit));
758
759 ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit));
760
761 // B depends on the loop.
762 ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
763 // A is in the loop and depends on other iterations.
764 ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop));
765
766 // The loop contents depend on the initializer too.
767 ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit));
768
769 // Pull out the access for the store inside the loop.
770 auto loopLoad = analyzer.accessFor(aLoad.node());
771 // It should have 10 element long bounds.
772 ASSERT_TRUE(indexBoundsEquals(
773 loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
774}
775
776// Can determine dependencies of outputs, through to inputs.
777TEST(MemDependency, MemDependencyCheckerInputsOutputs) {
778 BufHandle a("A", {10}, kInt);
779 BufHandle b("B", {10}, kInt);
780 VarHandle x("x", kInt);
781
782 // initialize analyzer with inputs and outputs.
783 analysis::MemDependencyChecker analyzer({a}, {b});
784
785 // Here's a Relu.
786 /*
787 * for (int x = 0; x < 10; ++x) {
788 * B[x] = Max(A[x], 0);
789 * }
790 */
791
792 ExprHandle aLoad = Load::make(a, {x});
793 StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true));
794 StmtPtr loop = For::make(x, 0, 10, bStore);
795
796 StmtPtr stmt = Block::make({loop});
797
798 stmt->accept(&analyzer);
799
800 // Output depends indirectly on input.
801 ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
802 // aLoad depends directly on the input A.
803 ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node()));
804 // bStore therefore depends directly on the input A.
805 ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node()));
806 // The output depends directly on the store.
807 ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore));
808
809 // Check AccessInfo based overloads.
810 auto input = analyzer.input(a.node());
811 auto output = analyzer.output(b.node());
812
813 // Output depends indirectly on input.
814 ASSERT_TRUE(analyzer.dependsIndirectly(output, input));
815 // Not directly.
816 ASSERT_FALSE(analyzer.dependsDirectly(output, input));
817 // Not in reverse order.
818 ASSERT_FALSE(analyzer.dependsIndirectly(input, output));
819
820 // output -> bStore -> bLoad -> input.
821 auto storeAccess = analyzer.accessFor(bStore);
822 auto loadAccess = analyzer.accessFor(aLoad.node());
823
824 ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess));
825 ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input));
826}
827
828// Can tell if an output does not depend on an input.
829TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) {
830 BufHandle a("A", {10}, kInt);
831 BufHandle b("B", {10}, kInt);
832 VarHandle x("x", kInt);
833
834 // initialize analyzer with inputs and outputs.
835 analysis::MemDependencyChecker analyzer({a}, {b});
836
837 // Here's a dumb Relu.
838 /*
839 * for (int x = 0; x < 10; ++x) {
840 * B[x] = Max(x, 0);
841 * }
842 */
843
844 StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true));
845 StmtPtr loop = For::make(x, 0, 10, bStore);
846
847 StmtPtr stmt = Block::make({loop});
848
849 stmt->accept(&analyzer);
850
851 // Output does not depend indirectly on input.
852 ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node()));
853
854 // The output still depends directly on the store.
855 ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore));
856
857 // Check AccessInfo based overloads.
858 auto input = analyzer.input(a.node());
859 auto output = analyzer.output(b.node());
860
861 // Output does not depend indirectly on input.
862 ASSERT_FALSE(analyzer.dependsIndirectly(output, input));
863}
864
865// Verify different loop extents produce accesses with different bounds, and
866// that later accesses find dependencies that overlap their entire bound range.
867TEST(MemDependency, MemDependencyCheckerLoopBounds) {
868 BufHandle a("A", {10}, kInt);
869 BufHandle b("B", {10}, kInt);
870 BufHandle c("C", {10}, kInt);
871 VarHandle x("x", kInt);
872 using namespace analysis;
873
874 MemDependencyChecker analyzer({a}, {c});
875
876 // This enables using the execution order of the loops to determine if some
877 // loops are self dependent or not.
878 analyzer.allowLoopExecutionOrderAnalysis();
879
880 /*
881 * for (int x = 1; x < 10; ++x) {
882 * B[x] = A[x];
883 * }
884 * for (int x = 1; x < 9; ++x) {
885 * B[x] = B[x] * 2;
886 * }
887 * for (int x = 3; x < 4; ++x) {
888 * C[x] = A[x];
889 * }
890 * for (int x = 0; x < 10; ++x) {
891 * C[x] = B[x];
892 * }
893 */
894
895 std::vector<StmtPtr> stmts(
896 {For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))),
897 For::make(
898 x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))),
899 For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))),
900 For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))});
901
902 StmtPtr stmt = Block::make(stmts);
903
904 stmt->accept(&analyzer);
905
906 auto input = analyzer.input(a.node());
907 auto output = analyzer.output(c.node());
908
909 // sanity check Output -> Input.
910 ASSERT_TRUE(analyzer.dependsIndirectly(output, input));
911
912 // Check the For loop dependencies:
913
914 // Last write to C depends on both writes to B since they contain the last
915 // write to at least one element.
916 ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1]));
917 ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0]));
918
919 // The last write to C does not depend on the other write to C.
920 ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2]));
921
922 auto CB = [](int s, int e) {
923 return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
924 };
925 auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
926 return indexBoundsEquals(x, y);
927 };
928
929 /* 0. Input: A[(0, 9)] - dependents: 1 5
930 * 1. Load: A[(1, 9)] - depends on: 0 - dependents: 2
931 * 2. Store: B[(1, 9)] - depends on: 1 - dependents: 3 7
932 * 3. Load: B[(1, 8)] - depends on: 2 - dependents: 4
933 * 4. Store: B[(1, 8)] - depends on: 3 - dependents: 7
934 * 5. Load: A[(3, 3)] - depends on: 0 - dependents: 6
935 * 6. Store: C[(3, 3)] - depends on: 5
936 * 7. Load: B[(0, 9)] - depends on: 2 4 - dependents: 8
937 * 8. Store: C[(0, 9)] - depends on: 7 - dependents: 9
938 * 9. Output: C[(0, 9)] - depends on: 8
939 */
940
941 // Now let's look at the bounds of each access.
942 // There are 9 accesses in this Stmt, so this is exhaustive, we wont do this
943 // much.
944 auto history = analyzer.getHistory();
945 ASSERT_EQ(history.size(), 10);
946 VarPtr aVar = a.node()->base_handle();
947 VarPtr bVar = b.node()->base_handle();
948 VarPtr cVar = c.node()->base_handle();
949
950 // The first access is the input A.
951 ASSERT_EQ(history[0]->type(), AccessType::Input);
952 ASSERT_EQ(history[0]->var(), aVar);
953 // It has the bounds of the producing Input.
954 ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)}));
955 // sanity check the input we retrieved earlier matches.
956 ASSERT_EQ(history[0], input);
957
958 // The second access is the load of A in the first loop.
959 ASSERT_EQ(history[1]->type(), AccessType::Load);
960 ASSERT_EQ(history[1]->var(), aVar);
961 // It has the bounds of the loop, i.e. start == 1.
962 ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)}));
963 // It reads from A, so it should have a dependency on the last write to this
964 // range - with is the input.
965 ASSERT_EQ(history[1]->dependencies().size(), 1);
966 ASSERT_TRUE(history[1]->hasDependency(history[0]));
967
968 // The third access is the store into B in the first loop.
969 ASSERT_EQ(history[2]->type(), AccessType::Store);
970 ASSERT_EQ(history[2]->var(), bVar);
971 // It also has the bounds of the loop, i.e. start == 1.
972 ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)}));
973 // The previous load is in its RHS, so it depends on it.
974 ASSERT_EQ(history[2]->dependencies().size(), 1);
975 ASSERT_TRUE(history[2]->hasDependency(history[1]));
976
977 // The third access is the load from B in the second loop.
978 ASSERT_EQ(history[3]->type(), AccessType::Load);
979 ASSERT_EQ(history[3]->var(), bVar);
980 // It has the bounds of the second loop, i.e. >= 1 < 9.
981 ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)}));
982 // It reads from B in a smaller range, so should depend on the previous
983 // store.
984 ASSERT_EQ(history[3]->dependencies().size(), 1);
985 ASSERT_TRUE(history[3]->hasDependency(history[2]));
986
987 // The fourth: the store to B in the second loop.
988 ASSERT_EQ(history[4]->type(), AccessType::Store);
989 ASSERT_EQ(history[4]->var(), bVar);
990 // It also has the bounds of the second loop.
991 ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)}));
992 // The previous load is in its RHS, so it depends on it as before.
993 ASSERT_EQ(history[4]->dependencies().size(), 1);
994 ASSERT_TRUE(history[4]->hasDependency(history[3]));
995
996 // The fifth access is the load is from the 3rd loop, and skips previous B
997 // accesses.
998 ASSERT_EQ(history[5]->type(), AccessType::Load);
999 ASSERT_EQ(history[5]->var(), aVar);
1000 // It has the bounds of the third loop: >= 3 < 4.
1001 ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)}));
1002 // It depends on the last thing to write to A, which is the A input.
1003 ASSERT_EQ(history[5]->dependencies().size(), 1);
1004 ASSERT_TRUE(history[5]->hasDependency(history[0]));
1005
1006 // Sixth: the store into the output C.
1007 ASSERT_EQ(history[6]->type(), AccessType::Store);
1008 ASSERT_EQ(history[6]->var(), cVar);
1009 // It also has the bounds of the third loop.
1010 ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)}));
1011 // The previous load is in its RHS, so it depends on it as always.
1012 ASSERT_EQ(history[6]->dependencies().size(), 1);
1013 ASSERT_TRUE(history[6]->hasDependency(history[5]));
1014
1015 // The seventh access is the load of B in the fourth loop.
1016 ASSERT_EQ(history[7]->type(), AccessType::Load);
1017 ASSERT_EQ(history[7]->var(), bVar);
1018 // It has the bounds of the final loop, >= 0 < 10
1019 ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)}));
1020 // The bounds of this read are larger than the bounds of the previous write,
1021 // so it depends on both previous Stores to B.
1022 ASSERT_EQ(history[7]->dependencies().size(), 2);
1023 ASSERT_TRUE(history[7]->hasDependency(history[2]));
1024 ASSERT_TRUE(history[7]->hasDependency(history[4]));
1025
1026 // Eight: the final store into the output C.
1027 ASSERT_EQ(history[8]->type(), AccessType::Store);
1028 ASSERT_EQ(history[8]->var(), cVar);
1029 // It also has the bounds of the final loop.
1030 ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)}));
1031 // The previous load is in its RHS, so it depends on it as always.
1032 ASSERT_EQ(history[8]->dependencies().size(), 1);
1033 ASSERT_TRUE(history[8]->hasDependency(history[7]));
1034
1035 // The last access represents the output Buf.
1036 ASSERT_EQ(history[9]->type(), AccessType::Output);
1037 ASSERT_EQ(history[9]->var(), cVar);
1038 // It has the bounds of the output Buf.
1039 ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)}));
1040 // sanity check the input we retrieved earlier matches.
1041 ASSERT_EQ(history[9], output);
1042 // It depends on the last write to C only.
1043 ASSERT_EQ(history[9]->dependencies().size(), 1);
1044 ASSERT_TRUE(history[9]->hasDependency(history[8]));
1045}
1046
1047// Verify that we can still infer bounds when the loop var is offset.
1048TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) {
1049 BufHandle a("A", {10}, kInt);
1050 BufHandle b("B", {10}, kInt);
1051 VarHandle x("x", kInt);
1052
1053 using namespace analysis;
1054
1055 MemDependencyChecker analyzer({a}, {b});
1056
1057 // This enables using the execution order of the loops to determine if some
1058 // loops are self dependent or not.
1059 analyzer.allowLoopExecutionOrderAnalysis();
1060
1061 /*
1062 * for (int x = 1; x < 10; x++) {
1063 * A[x] = A[x - 1];
1064 * }
1065 * for (int x = 0; x < 9; x++) {
1066 * A[x] = A[x + 1];
1067 * }
1068 * for (int x = 0; x < 9; x++) {
1069 * A[9 - x] = A[8 - x];
1070 * }
1071 * for (int x = 0; x < 10; x++) {
1072 * A[x] = A[9 - x];
1073 * }
1074 * for (int x = 0; x < 10; x++) {
1075 * B[x] = A[x];
1076 * }
1077 */
1078
1079 StmtPtr stmt = Block::make(
1080 {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
1081 For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))),
1082 For::make(
1083 x,
1084 0,
1085 9,
1086 Store::make(
1087 a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))),
1088 For::make(
1089 x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))),
1090 For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})))});
1091
1092 stmt->accept(&analyzer);
1093
1094 // Sanity check output depends on Input.
1095 ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
1096
1097 auto CB = [](int s, int e) {
1098 return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
1099 };
1100 auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
1101 return indexBoundsEquals(x, y);
1102 };
1103
1104 /* 0. Input: A[(0, 9)] - dependents: 1
1105 * 1. Load: A[(0, 8)] - depends on: 0 2 - dependents: 2
1106 * 2. Store: A[(1, 9)] - depends on: 1 - dependents: 1 3
1107 * 3. Load: A[(1, 9)] - depends on: 2 - dependents: 4
1108 * 4. Store: A[(0, 8)] - depends on: 3 - dependents: 5 7
1109 * 5. Load: A[(0, 8)] - depends on: 4 - dependents: 6
1110 * 6. Store: A[(1, 9)] - depends on: 5 - dependents: 7
1111 * 7. Load: A[(0, 9)] - depends on: 4 6 8 - dependents: 8
1112 * 8. Store: A[(0, 9)] - depends on: 7 - dependents: 7 9
1113 * 9. Load: A[(0, 9)] - depends on: 8 - dependents: 10
1114 * 10. Store: B[(0, 9)] - depends on: 9 - dependents: 11
1115 * 11. Output: B[(0, 9)] - depends on: 10
1116 */
1117
1118 // Now let's look at the bounds of each access.
1119 auto history = analyzer.getHistory();
1120 ASSERT_EQ(history.size(), 12);
1121 VarPtr aVar = a.node()->base_handle();
1122 VarPtr bVar = b.node()->base_handle();
1123
1124 // The first access is the input A.
1125 ASSERT_EQ(history[0]->type(), AccessType::Input);
1126 ASSERT_EQ(history[0]->var(), aVar);
1127 // It has the bounds of the producing Input.
1128 ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)}));
1129
1130 // The second access is the load A[x-1].
1131 ASSERT_EQ(history[1]->type(), AccessType::Load);
1132 ASSERT_EQ(history[1]->var(), aVar);
1133 // It has the bounds of the loop modified by the offset of each index, in
1134 // this case -1.
1135 ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)}));
1136 // It depends on the input, but also the store in the same loop, since
1137 // different interations of the loop depend on each other.
1138 ASSERT_EQ(history[1]->dependencies().size(), 2);
1139 ASSERT_TRUE(history[1]->hasDependency(history[0]));
1140 ASSERT_TRUE(history[1]->hasDependency(history[2]));
1141
1142 // The third access is the Store to A[x] in the first loop.
1143 ASSERT_EQ(history[2]->type(), AccessType::Store);
1144 ASSERT_EQ(history[2]->var(), aVar);
1145 // It has no offset on x, so should have the same bounds as the loop.
1146 ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)}));
1147
1148 // The fourth access is the load A[x+1] in the second loop.
1149 ASSERT_EQ(history[3]->type(), AccessType::Load);
1150 ASSERT_EQ(history[3]->var(), aVar);
1151 // It has the bounds of the loop (0 <= x < 9) modified by the offset of each
1152 // index, in this case 1.
1153 ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)}));
1154 // This load totally overlaps the previous write to A, so it depends only on
1155 // it and not the input.
1156 ASSERT_EQ(history[3]->dependencies().size(), 1);
1157 ASSERT_TRUE(history[3]->hasDependency(history[2]));
1158
1159 // The fifth access is the store to A[x] in the second loop.
1160 ASSERT_EQ(history[4]->type(), AccessType::Store);
1161 ASSERT_EQ(history[4]->var(), aVar);
1162 // It has no offset on x, so should have the same bounds as the loop.
1163 ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)}));
1164
1165 // The sixth access is the load to A[8 - x] in the third loop.
1166 ASSERT_EQ(history[5]->type(), AccessType::Load);
1167 ASSERT_EQ(history[5]->var(), aVar);
1168 // It has the bounds of the loop (0 <= x < 9) modified by the offset of each
1169 // index, in this case 8 - x.
1170 // This access has a negative stride, which will be normalized.
1171 ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)}));
1172 // This load totally overlaps the most recent write to A, so it depends only
1173 // on it and not the input or the first write to A.
1174 ASSERT_EQ(history[5]->dependencies().size(), 1);
1175 ASSERT_TRUE(history[5]->hasDependency(history[4]));
1176
1177 // The seventh access is the store to A[9 - x] in the third loop.
1178 ASSERT_EQ(history[6]->type(), AccessType::Store);
1179 ASSERT_EQ(history[6]->var(), aVar);
1180 // This store has a negative stride on it's indices, but is notmalized
1181 // internally.
1182 ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)}));
1183
1184 // The eighth access is the load A[9-x] in the second loop.
1185 ASSERT_EQ(history[7]->type(), AccessType::Load);
1186 ASSERT_EQ(history[7]->var(), aVar);
1187 // It has the bounds of the loop (0 <= x < 9), modified by the offset 9 - x,
1188 // which esstentially traverses the loop backwards.
1189 ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)}));
1190 // This Load has three write dependencies:
1191 ASSERT_EQ(history[7]->dependencies().size(), 3);
1192 // * The previous store (#6) for elements 1-9
1193 ASSERT_TRUE(history[7]->hasDependency(history[6]));
1194 // * An earlier store (#4) covering element 0
1195 ASSERT_TRUE(history[7]->hasDependency(history[4]));
1196 // * A future store inside this loop, since this loop modifies the buffer
1197 // in a non distinct way (due to the load and store having different access
1198 // strides).
1199 ASSERT_TRUE(history[7]->hasDependency(history[8]));
1200
1201 // The ninth access is the store to A[x] in the fourth loop.
1202 ASSERT_EQ(history[8]->type(), AccessType::Store);
1203 ASSERT_EQ(history[8]->var(), aVar);
1204 // This store has a negative stride on it's indices, but is notmalized
1205 // internally.
1206 ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)}));
1207
1208 // The tenth and 11th acceses are the copy from A[x] to B[x].
1209 ASSERT_EQ(history[9]->type(), AccessType::Load);
1210 ASSERT_EQ(history[9]->var(), aVar);
1211 ASSERT_EQ(history[10]->type(), AccessType::Store);
1212 ASSERT_EQ(history[10]->var(), bVar);
1213
1214 // The last access represents the output Buf.
1215 ASSERT_EQ(history[11]->type(), AccessType::Output);
1216 ASSERT_EQ(history[11]->var(), bVar);
1217 // It has the bounds of the output Buf.
1218 ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)}));
1219 // It depends on the last write to B only.
1220 ASSERT_EQ(history[11]->dependencies().size(), 1);
1221 ASSERT_TRUE(history[11]->hasDependency(history[10]));
1222
1223 // ok that's enough of that.
1224}
1225
1226// Check many different cases of loop self dependency - when a load within a
1227// loop is dependent on a Store later in the same loop but in different
1228// iteration. This is affected by whether or not we can trust the execution
1229// order of the loop.
1230TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) {
1231 BufHandle a("A", {5}, kInt);
1232 BufHandle b("B", {5}, kInt);
1233 VarHandle x("x", kInt);
1234 VarHandle y("y", kInt);
1235 VarHandle z("z", kInt);
1236
1237 using namespace analysis;
1238
1239 // This check assumes that the Stmt has a single Store with a single Load on
1240 // the RHS.
1241 auto isSelfDependent =
1242 [](const std::vector<std::shared_ptr<AccessInfo>>& history) -> bool {
1243 return history.front()->hasDependency(history.back());
1244 };
1245
1246 {
1247 /* for (int y = 0; y < 10; y++) {
1248 * A[y] = (A[y]) + 1;
1249 * } */
1250
1251 // Not self dependent since all loop iterations use a different y.
1252
1253 MemDependencyChecker analyzer;
1254 StmtPtr stmt = For::make(
1255 y,
1256 0,
1257 10,
1258 Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), 1))}));
1259
1260 stmt->accept(&analyzer);
1261
1262 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1263 }
1264
1265 {
1266 /* for (int y = 0; y < 10; y++) {
1267 * A[y + 1] = (A[y + 1]) + 1;
1268 * }
1269 */
1270
1271 // Not self dependent due to different y (with offset).
1272
1273 MemDependencyChecker analyzer;
1274 StmtPtr stmt = For::make(
1275 y,
1276 0,
1277 10,
1278 Block::make(
1279 {Store::make(a, {y + 1}, Add::make(Load::make(a, {y + 1}), 1))}));
1280
1281 stmt->accept(&analyzer);
1282
1283 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1284 }
1285
1286 {
1287 /* for (int x = 0; x < 10; x++) {
1288 * A[0] = (A[0]) + x;
1289 * }
1290 */
1291
1292 // Is self dependent since all loops use a common constant element of A.
1293
1294 MemDependencyChecker analyzer;
1295 StmtPtr stmt = For::make(
1296 x,
1297 0,
1298 10,
1299 Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}));
1300 stmt->accept(&analyzer);
1301
1302 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1303 }
1304
1305 {
1306 /* for (int x = 0; x < 10; x++) {
1307 * A[0] = (B[0]) + x;
1308 * }
1309 */
1310
1311 // Is not self dependent beacause there is no store to the buffer that is
1312 // read.
1313
1314 MemDependencyChecker analyzer;
1315 StmtPtr stmt = For::make(
1316 x,
1317 0,
1318 10,
1319 Block::make({Store::make(a, {0}, Add::make(Load::make(b, {0}), x))}));
1320 stmt->accept(&analyzer);
1321
1322 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1323 }
1324
1325 {
1326 /* for (int x = 0; x < 10; x++) {
1327 * A[y] = (A[y]) + x;
1328 * }
1329 */
1330
1331 // Is self dependent since all loops use a common symbolic element of A.
1332
1333 MemDependencyChecker analyzer;
1334 StmtPtr stmt = For::make(
1335 x,
1336 0,
1337 10,
1338 Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), x))}));
1339 stmt->accept(&analyzer);
1340
1341 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1342 }
1343
1344 {
1345 /* for (int x = 0; x < 10; x++) {
1346 * A[x] = A[x + 1];
1347 * }
1348 */
1349
1350 // In this case it depends if we are considering execution order.
1351
1352 MemDependencyChecker analyzer;
1353
1354 StmtPtr stmt =
1355 For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
1356 stmt->accept(&analyzer);
1357
1358 // With analysis of order disabled, this is self dependent since the read
1359 // from X+1 and the write to X+1 could be in reverse order.
1360 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1361 }
1362
1363 {
1364 /* for (int x = 0; x < 10; x++) {
1365 * A[x] = A[x + 1];
1366 * }
1367 */
1368
1369 MemDependencyChecker analyzer;
1370 analyzer.allowLoopExecutionOrderAnalysis();
1371
1372 StmtPtr stmt =
1373 For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
1374 stmt->accept(&analyzer);
1375
1376 // If order analysis is enabled, this is not dependent since the read for
1377 // each element occurs before the write to that element.
1378 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1379 }
1380
1381 {
1382 /* for (int x = 1; x < 10; x++) {
1383 * A[x] = A[x - 1];
1384 * }
1385 */
1386
1387 MemDependencyChecker analyzer;
1388
1389 StmtPtr stmt =
1390 For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
1391 stmt->accept(&analyzer);
1392
1393 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1394 }
1395
1396 {
1397 /* for (int x = 1; x < 10; x++) {
1398 * A[x] = A[x - 1];
1399 * }
1400 */
1401
1402 MemDependencyChecker analyzer;
1403 analyzer.allowLoopExecutionOrderAnalysis();
1404
1405 StmtPtr stmt =
1406 For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
1407 stmt->accept(&analyzer);
1408
1409 // In this case, even with order analysis the Load is dependent on the
1410 // Store, since the write to X occurs before the read from X.
1411 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1412 }
1413
1414 {
1415 /* for (int x = 0; x < 9; x++) {
1416 * A[9 - x] = A[8 - x];
1417 * }
1418 */
1419
1420 // Still works if the execution order is reversed, so long as the read
1421 // comes before the write.
1422
1423 MemDependencyChecker analyzer;
1424 analyzer.allowLoopExecutionOrderAnalysis();
1425
1426 StmtPtr stmt = For::make(
1427 x,
1428 3,
1429 10,
1430 Store::make(
1431 a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x})));
1432 stmt->accept(&analyzer);
1433
1434 // However here was can determine the A store is earlier in the order than
1435 // the load.
1436 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1437 }
1438
1439 {
1440 /* for (int x = 0; x < 9; x++) {
1441 * A[8 - x] = A[9 - x];
1442 * }
1443 */
1444
1445 // But not if it doesn't.
1446
1447 MemDependencyChecker analyzer;
1448 analyzer.allowLoopExecutionOrderAnalysis();
1449
1450 StmtPtr stmt = For::make(
1451 x,
1452 3,
1453 10,
1454 Store::make(
1455 a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x})));
1456 stmt->accept(&analyzer);
1457
1458 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1459 }
1460
1461 {
1462 /* for (int x = 0; x < 9; x++) {
1463 * A[9 - x] = A[8 - x];
1464 * }
1465 */
1466
1467 // And not if we're not relying on execution order.
1468
1469 MemDependencyChecker analyzer;
1470
1471 StmtPtr stmt = For::make(
1472 x,
1473 3,
1474 10,
1475 Store::make(
1476 a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x})));
1477 stmt->accept(&analyzer);
1478
1479 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1480 }
1481
1482 {
1483 /* for (int x = 3; x < 10; x++) {
1484 * A[x - 2] = A[x - 1];
1485 * }
1486 */
1487
1488 // Forward order but negative indices.
1489
1490 MemDependencyChecker analyzer;
1491 analyzer.allowLoopExecutionOrderAnalysis();
1492
1493 StmtPtr stmt =
1494 For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1})));
1495 stmt->accept(&analyzer);
1496
1497 // However here was can determine the A store is earlier in the order than
1498 // the load.
1499 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1500 }
1501
1502 {
1503 /* for (int x = 0; x < 10; x++) {
1504 * A[x * 2] = A[x * 2];
1505 * }
1506 */
1507
1508 // With an access stride.
1509
1510 MemDependencyChecker analyzer;
1511 // Execution order doesn't matter since the read and the write are totally
1512 // distinct.
1513
1514 StmtPtr stmt =
1515 For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2})));
1516 stmt->accept(&analyzer);
1517
1518 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1519 }
1520
1521 {
1522 /* for (int x = 0; x < 10; x++) {
1523 * A[x * 2] = A[x * 2 + 1];
1524 * }
1525 */
1526
1527 // Here we can use the common stride of the accesses to determine they are
1528 // distinct.
1529 // Note, this is the only place (loop self depedency) we use this stride
1530 // to avoid unnecessary depedence.
1531
1532 MemDependencyChecker analyzer;
1533 // Execution order doesn't matter since the read and the write are totally
1534 // distinct.
1535
1536 StmtPtr stmt = For::make(
1537 x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1})));
1538 stmt->accept(&analyzer);
1539
1540 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1541 }
1542
1543 {
1544 /* for (int x = 0; x < 10; x++) {
1545 * A[x * 2] = A[x * 2 - 1];
1546 * }
1547 */
1548
1549 // same if the read is behind the write so long as they are distinct.
1550
1551 MemDependencyChecker analyzer;
1552 StmtPtr stmt = For::make(
1553 x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1})));
1554 stmt->accept(&analyzer);
1555
1556 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1557 }
1558
1559 {
1560 /* for (int x = 0; x < 10; x++) {
1561 * A[x * 2] = A[x * 2 + 2];
1562 * }
1563 */
1564
1565 // But not if the offset is in the stride.
1566
1567 MemDependencyChecker analyzer;
1568 StmtPtr stmt = For::make(
1569 x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2})));
1570 stmt->accept(&analyzer);
1571
1572 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1573 }
1574
1575 {
1576 /* for (int x = 0; x < 10; x++) {
1577 * A[x * 2] = A[x * 2 - 2];
1578 * }
1579 */
1580
1581 // Works with negative offsets too.
1582
1583 MemDependencyChecker analyzer;
1584 StmtPtr stmt = For::make(
1585 x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2})));
1586 stmt->accept(&analyzer);
1587
1588 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1589 }
1590
1591 {
1592 /* for (int x = 0; x < 10; x++) {
1593 * A[x * 2] = A[x * 2 + 7];
1594 * }
1595 */
1596
1597 // Detects accesses are distinct when offset is large but not a multiple
1598 // of stride.
1599 MemDependencyChecker analyzer;
1600 StmtPtr stmt = For::make(
1601 x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7})));
1602 stmt->accept(&analyzer);
1603
1604 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1605 }
1606
1607 {
1608 /* for (int x = 0; x < 10; x++) {
1609 * A[x * 2] = A[x * 2 + 4];
1610 * }
1611 */
1612
1613 // Works with offsets which are multiples of the stride.
1614 MemDependencyChecker analyzer;
1615 StmtPtr stmt = For::make(
1616 x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4})));
1617 stmt->accept(&analyzer);
1618
1619 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1620 }
1621
1622 {
1623 /* for (int x = 0; x < 10; x++) {
1624 * A[x * 6] = A[x * 6 + 5];
1625 * }
1626 */
1627
1628 // detects accesses are distinct with large strides when the offset is
1629 // within.
1630
1631 MemDependencyChecker analyzer;
1632 StmtPtr stmt = For::make(
1633 x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5})));
1634 stmt->accept(&analyzer);
1635
1636 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1637 }
1638
1639 {
1640 /* for (int x = 0; x < 10; x++) {
1641 * A[x * 2] = A[x * 6];
1642 * }
1643 */
1644
1645 // detects accesses are overlapping when stride is different but a
1646 // multiple.
1647
1648 MemDependencyChecker analyzer;
1649 StmtPtr stmt =
1650 For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6})));
1651 stmt->accept(&analyzer);
1652
1653 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1654 }
1655
1656 {
1657 /* for (int x = 0; x < 10; x++) {
1658 * A[x * 4] = A[x * 2];
1659 * }
1660 */
1661
1662 // still works when the read axis is the smaller stride.
1663
1664 MemDependencyChecker analyzer;
1665 StmtPtr stmt =
1666 For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2})));
1667 stmt->accept(&analyzer);
1668
1669 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1670 }
1671
1672 {
1673 /* for (int x = 0; x < 10; x++) {
1674 * A[x * 2] = A[x * 6 + 1];
1675 * }
1676 */
1677
1678 // detects accesses are distinct when stride is different but a multiple
1679 // and there is an offset.
1680
1681 MemDependencyChecker analyzer;
1682 StmtPtr stmt = For::make(
1683 x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1})));
1684 stmt->accept(&analyzer);
1685
1686 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1687 }
1688
1689 {
1690 /* for (int x = 0; x < 10; x++) {
1691 * A[x * 2] = A[x * 6 + 4];
1692 * }
1693 */
1694
1695 // The smaller stride determines whether there is overlap.
1696
1697 MemDependencyChecker analyzer;
1698 StmtPtr stmt = For::make(
1699 x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4})));
1700 stmt->accept(&analyzer);
1701
1702 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1703 }
1704
1705 {
1706 /* for (int x = 0; x < 10; x++) {
1707 * A[x * 2 + 3] = A[x * 6];
1708 * }
1709 */
1710
1711 // The smaller stride determines whether there is overlap, not the larger.
1712
1713 MemDependencyChecker analyzer;
1714 StmtPtr stmt = For::make(
1715 x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6})));
1716 stmt->accept(&analyzer);
1717
1718 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1719 }
1720
1721 {
1722 /* for (int x = 0; x < 10; x++) {
1723 * A[x * 2] = A[x * 3 + 1];
1724 * }
1725 */
1726
1727 // If they have strides with no common muliple > 1, they overlap.
1728 MemDependencyChecker analyzer;
1729 StmtPtr stmt = For::make(
1730 x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1})));
1731 stmt->accept(&analyzer);
1732
1733 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1734 }
1735
1736 {
1737 /* for (int x = 0; x < 10; x++) {
1738 * A[x] = A[x + 10];
1739 * }
1740 */
1741
1742 // If the offset is greater than the size of the loop, they can't overlap.
1743
1744 MemDependencyChecker analyzer;
1745 StmtPtr stmt =
1746 For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10})));
1747 stmt->accept(&analyzer);
1748
1749 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1750 }
1751
1752 {
1753 /* for (int x = 0; x < 10; x++) {
1754 * A[x] = A[9 - x];
1755 * }
1756 */
1757
1758 // If they have different execution orders they may overlap.
1759 MemDependencyChecker analyzer;
1760 StmtPtr stmt = For::make(
1761 x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x})));
1762 stmt->accept(&analyzer);
1763
1764 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1765 }
1766
1767 {
1768 /* for (int x = 0; x < 10; x++) {
1769 * A[x * 2] = A[19 - x * 2];
1770 * }
1771 */
1772
1773 // Or they may not, depending on their start offset and strides.
1774 MemDependencyChecker analyzer;
1775 StmtPtr stmt = For::make(
1776 x,
1777 0,
1778 10,
1779 Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2})));
1780 stmt->accept(&analyzer);
1781
1782 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1783 }
1784
1785 {
1786 /* for (int x = 0; x < 10; x++) {
1787 * A[x / 2] = A[x / 2];
1788 * }
1789 */
1790
1791 // If the stride is not monotonic, they overlap.
1792
1793 MemDependencyChecker analyzer;
1794 StmtPtr stmt =
1795 For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2})));
1796 stmt->accept(&analyzer);
1797
1798 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1799 }
1800
1801 {
1802 /* for (int x = 0; x < 10; x++) {
1803 * A[x / 2] = A[x / 2] + 1;
1804 * }
1805 */
1806
1807 // If the stride is not monotonic, they overlap - even with an offset.
1808 MemDependencyChecker analyzer;
1809 StmtPtr stmt = For::make(
1810 x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1})));
1811 stmt->accept(&analyzer);
1812
1813 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1814 }
1815
1816 {
1817 /* for (int x = 0; x < 10; x++) {
1818 * A[x % 2] = A[x % 2];
1819 * }
1820 */
1821
1822 // Mod too...
1823
1824 analysis::MemDependencyChecker analyzer;
1825 StmtPtr stmt = For::make(
1826 x,
1827 0,
1828 10,
1829 Store::make(a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)})));
1830 stmt->accept(&analyzer);
1831
1832 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1833 }
1834
1835 {
1836 /* for (int x = y; x < z; x++) {
1837 * A[x] = A[x + 1];
1838 * }
1839 */
1840
1841 // Still works with symbolic loop extents.
1842
1843 {
1844 MemDependencyChecker analyzer;
1845 StmtPtr stmt =
1846 For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
1847 stmt->accept(&analyzer);
1848
1849 ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1850 }
1851
1852 {
1853 MemDependencyChecker analyzer;
1854 analyzer.allowLoopExecutionOrderAnalysis();
1855 StmtPtr stmt =
1856 For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
1857 stmt->accept(&analyzer);
1858
1859 ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1860 }
1861 }
1862}
1863
1864// Verify that a strided access still works.
1865// TODO: actually this only works because of the size of the ranges, revist this
1866// test after strided overlap is implemented.
1867TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) {
1868 BufHandle a("A", {20}, kInt);
1869 BufHandle b("B", {20}, kInt);
1870 VarHandle x("x", kInt);
1871 VarHandle y("y", kInt);
1872
1873 using namespace analysis;
1874 MemDependencyChecker analyzer({a.node()}, {b.node()});
1875 StmtPtr stmt = Block::make(
1876 {For::make(
1877 x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))),
1878 For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2})))
1879
1880 });
1881 stmt->accept(&analyzer);
1882
1883 // Sanity check output depends on input.
1884 ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
1885
1886 // Output has 2 dependencies... the store in each loop.
1887 auto outputAccess = analyzer.output(b.node());
1888 ASSERT_EQ(outputAccess->dependencies().size(), 2);
1889}
1890
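/* TODO(nickg) - this test will fail due to the lack of stride math in Bound.
 * NOTE: rename it before re-enabling; it reuses the name of the active
 * MemDependencyCheckerLoopDistinctStrides test and would not compile as-is.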
1892TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) {
1893 BufHandle a("A", {20}, kInt);
1894 BufHandle b("B", {20}, kInt);
1895 BufHandle c("C", {10}, kInt);
1896 VarHandle x("x", kInt);
1897 VarHandle y("y", kInt);
1898
1899 {
1900 analysis::MemDependencyChecker analyzer({a.node()}, {c.node()});
1901 StmtPtr stmt = Block::make(
1902 {For::make(
1903 x,
1904 0,
1905 10,
1906 Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))),
1907 For::make(
1908 x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))),
1909 For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))
1910
1911 });
1912 stmt->accept(&analyzer);
1913
1914 std::cout << *stmt << "\n";
1915 for (auto& wi : analyzer.getHistory()) {
1916 wi->print();
1917 }
1918 }
1919}*/
1920
1921// analysis on Stmts using Cond.
1922TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) {
1923 BufHandle a("A", {10}, kInt);
1924 BufHandle b("B", {10}, kInt);
1925 BufHandle c("C", {10}, kInt);
1926 VarHandle x("x", kInt);
1927 VarHandle y("y", kInt);
1928
1929 using namespace analysis;
1930
1931 {
1932 /* for (int x = 0; x < 10; x++) {
1933 * C[x] = A[x];
1934 * }
1935 * if (y<5 ? 1 : 0) {
1936 * C[0] = (B[0]) + 1;
1937 * } else {
1938 * C[0] = (B[1]) + 1;
1939 * }
1940 */
1941
1942 // Future usages may depend on accesses in both branches of a condition.
1943
1944 MemDependencyChecker analyzer({a, b}, {c});
1945 StmtPtr stmt = Block::make(
1946 {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
1947 Cond::make(
1948 CompareSelect::make(y, 5, CompareSelectOperation::kLT),
1949 Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)),
1950 Store::make(c, {0}, Add::make(Load::make(b, {1}), 1)))});
1951
1952 stmt->accept(&analyzer);
1953
1954 // Output C should have 3 dependencies, each of the three stores.
1955 auto outputAccess = analyzer.output(c.node());
1956 ASSERT_NE(outputAccess, nullptr);
1957 ASSERT_EQ(outputAccess->dependencies().size(), 3);
1958
1959 // C depends indirectly on A and B.
1960 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
1961 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
1962 }
1963
1964 {
1965 /* for (int x = 0; x < 10; x++) {
1966 * C[x] = A[x];
1967 * }
1968 * if (y<5 ? 1 : 0) {
1969 * for (int x = 0; x < 10; x++) {
1970 * C[x] = B[x];
1971 * }
1972 * } else {
1973 * for (int x = 0; x < 10; x++) {
1974 * C[x] = (B[x]) + 1;
1975 * }
1976 * }
1977 */
1978
1979 // Future usages may depend on accesses in both branches of a condition.
1980
1981 MemDependencyChecker analyzer({a, b}, {c});
1982 StmtPtr stmt = Block::make(
1983 {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
1984 Cond::make(
1985 CompareSelect::make(y, 5, CompareSelectOperation::kLT),
1986 For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))),
1987 For::make(
1988 x,
1989 0,
1990 10,
1991 Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))});
1992
1993 stmt->accept(&analyzer);
1994
1995 // Output C should have 3 dependencies, each of the three stores.
1996 auto outputAccess = analyzer.output(c.node());
1997 ASSERT_NE(outputAccess, nullptr);
1998 ASSERT_EQ(outputAccess->dependencies().size(), 3);
1999
2000 // TODO(nickg): actually since the true and false branch cover the total
2001 // range of the first store this should have 2 dependencies, but we don't
2002 // do that yet.
2003
2004 // C depends indirectly on A and B.
2005 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2006 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2007 }
2008
2009 {
2010 /* for (int x = 0; x < 10; x++) {
2011 * C[x] = A[x];
2012 * }
2013 * if (y<5 ? 1 : 0) {
2014 * for (int x = 0; x < 10; x++) {
2015 * C[x] = (B[x]) + 1;
2016 * }
2017 * }
2018 */
2019
2020 // Only has true branch.
2021
2022 MemDependencyChecker analyzer({a, b}, {c});
2023 StmtPtr stmt = Block::make(
2024 {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2025 Cond::make(
2026 CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2027 For::make(
2028 x,
2029 0,
2030 10,
2031 Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))),
2032 nullptr)});
2033
2034 stmt->accept(&analyzer);
2035
2036 // Output C should have 3 dependencies, each of the three stores.
2037 auto outputAccess = analyzer.output(c.node());
2038 ASSERT_NE(outputAccess, nullptr);
2039 ASSERT_EQ(outputAccess->dependencies().size(), 2);
2040
2041 // C depends indirectly on A and B.
2042 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2043 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2044 }
2045
2046 {
2047 /* for (int x = 0; x < 10; x++) {
2048 * C[x] = A[x];
2049 * }
2050 * if (y<5 ? 1 : 0) {
2051 * } else {
2052 * for (int x = 0; x < 10; x++) {
2053 * C[x] = (B[x]) + 1;
2054 * }
2055 * }
2056 */
2057
2058 // Only has false branch.
2059
2060 MemDependencyChecker analyzer({a, b}, {c});
2061 StmtPtr stmt = Block::make(
2062 {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2063 Cond::make(
2064 CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2065 nullptr,
2066 For::make(
2067 x,
2068 0,
2069 10,
2070 Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))});
2071
2072 stmt->accept(&analyzer);
2073
2074 // Output C should have 3 dependencies, each of the three stores.
2075 auto outputAccess = analyzer.output(c.node());
2076 ASSERT_NE(outputAccess, nullptr);
2077 ASSERT_EQ(outputAccess->dependencies().size(), 2);
2078
2079 // C depends indirectly on A and B.
2080 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2081 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2082 }
2083
2084 {
2085 /* for (int x = 0; x < 10; x++) {
2086 * C[x] = A[x];
2087 * }
2088 * if (C[0]<5 ? 1 : 0) {
2089 * C[0] = 5;
2090 * }
2091 */
2092
2093 // Cond's Condition depends on a previous access.
2094
2095 MemDependencyChecker analyzer({a}, {c});
2096 StorePtr initStore = Store::make(c, {x}, Load::make(a, {x}));
2097 ExprHandle conditionalLoad = Load::make(c, {0});
2098 StmtPtr stmt = Block::make(
2099 {For::make(x, 0, 10, initStore),
2100 Cond::make(
2101 CompareSelect::make(
2102 conditionalLoad, 5, CompareSelectOperation::kLT),
2103 Store::make(c, {0}, 5),
2104 nullptr)});
2105
2106 stmt->accept(&analyzer);
2107
2108 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2109
2110 ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore));
2111 ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node()));
2112 ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node()));
2113 }
2114}
2115
2116// Stmts using IfThenElse.
2117TEST(MemDependency, MemDependencyCheckerIfThenElse) {
2118 BufHandle a("A", {10}, kInt);
2119 BufHandle b("B", {10}, kInt);
2120 BufHandle c("C", {10}, kInt);
2121 VarHandle x("x", kInt);
2122 VarHandle y("y", kInt);
2123
2124 using namespace analysis;
2125
2126 {
2127 /* for (int x = 0; x < 10; x++) {
2128 * C[x] = A[x];
2129 * }
2130 * C[0] = (y < 5 ? (B[0]) + 1 : (B[1]) + 1;
2131 */
2132
2133 // Future usages may depend on accesses in both branches of a condition.
2134
2135 MemDependencyChecker analyzer({a, b}, {c});
2136 StorePtr ifStore = Store::make(
2137 c,
2138 {0},
2139 IfThenElse::make(
2140 CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2141 Add::make(Load::make(b, {0}), 1),
2142 Add::make(Load::make(b, {1}), 1)));
2143 StmtPtr stmt = Block::make(
2144 {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2145 ifStore});
2146
2147 stmt->accept(&analyzer);
2148
2149 // Output C should have 2 dependencies, each of the two stores.
2150 auto outputAccess = analyzer.output(c.node());
2151 ASSERT_NE(outputAccess, nullptr);
2152 ASSERT_EQ(outputAccess->dependencies().size(), 2);
2153
2154 // Now we need to check the Store containing the IfThenElse.
2155 auto ifStoreAccess = analyzer.accessFor(ifStore);
2156
2157 // It should have 2 dependencies.
2158 ASSERT_EQ(ifStoreAccess->dependencies().size(), 2);
2159
2160 // C depends indirectly on A and B.
2161 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2162 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2163 }
2164
2165 {
2166 /* for (int x = 0; x < 10; x++) {
2167 * C[x] = A[x];
2168 * }
2169 * C[0] = (y < 5 ? (B[0]) + 1 : 42;
2170 */
2171
2172 // If the load appears in only one side of an IfThenElse the output may be
2173 // dependent on it.
2174
2175 MemDependencyChecker analyzer({a, b}, {c});
2176 StorePtr ifStore = Store::make(
2177 c,
2178 {0},
2179 IfThenElse::make(
2180 CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2181 Add::make(Load::make(b, {0}), 1),
2182 42));
2183 StmtPtr stmt = Block::make(
2184 {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2185 ifStore});
2186
2187 stmt->accept(&analyzer);
2188
2189 // C depends indirectly on A and B.
2190 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2191 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2192 }
2193
2194 {
2195 /* for (int x = 0; x < 10; x++) {
2196 * C[x] = (x < 5 ? B[x] : A[x];
2197 * }
2198 */
2199
2200 // In this case C is dependent on both A and B.
2201
2202 // TODO: in cases like this it would be possible to split the range of B
2203 // into two bounds, one dependent on A and one depenent on B. We'd need to
2204 // examine conditions relative to previously encountered loop variables. I'm
2205 // uncertain if this would be helpful.
2206
2207 MemDependencyChecker analyzer({a, b}, {c});
2208 StorePtr ifStore = Store::make(
2209 c,
2210 {0},
2211 IfThenElse::make(
2212 CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2213 Load::make(b, {x}),
2214 Load::make(a, {x})));
2215 StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)});
2216
2217 stmt->accept(&analyzer);
2218
2219 // C depends indirectly on A and B.
2220 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2221 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2222 }
2223}
2224
2225// Cutting a loop with single elem writes
2226TEST(MemDependency, MemDependencyCheckerCutLoop) {
2227 BufHandle a("A", {10}, kInt);
2228 BufHandle b("B", {10}, kInt);
2229 VarHandle x("x", kInt);
2230
2231 using namespace analysis;
2232
2233 {
2234 /* for (int x = 0; x < 10; x++) {
2235 * B[x] = A[x];
2236 * }
2237 * B[5] = 100;
2238 */
2239
2240 // Cutting a loop with single element writes.
2241
2242 MemDependencyChecker analyzer({a}, {b});
2243 StmtPtr stmt = Block::make(
2244 {For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))),
2245 Store::make(b, {5}, 100)});
2246
2247 stmt->accept(&analyzer);
2248
2249 // Output depends on input.
2250 ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2251
2252 // Output has 2 depdenencies.
2253 auto outputAccess = analyzer.output(b.node());
2254 ASSERT_NE(outputAccess, nullptr);
2255 ASSERT_EQ(outputAccess->dependencies().size(), 2);
2256 }
2257
2258 {
2259 /* for (int x = 0; x < 10; x++) {
2260 * B[x] = A[x];
2261 * }
2262 * for (int x = 4; x < 7; x++) {
2263 * B[x] = B[x] + 3;
2264 * }
2265 * B[5] = 100;
2266 * B[6] = 101;
2267 * B[7] = 102;
2268 */
2269
2270 // Cutting a loop with a smaller loop but then totally overlap that second
2271 // loop with one element writes.
2272
2273 MemDependencyChecker analyzer({a}, {b});
2274 ForPtr firstLoop =
2275 For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})));
2276 StorePtr secondStore =
2277 Store::make(b, {x}, Add::make(Load::make(b, {x}), 1));
2278 ForPtr secondLoop = For::make(x, 4, 7, secondStore);
2279
2280 StmtPtr stmt = Block::make(
2281 {firstLoop,
2282 secondLoop,
2283 Store::make(b, {4}, 100),
2284 Store::make(b, {5}, 101),
2285 Store::make(b, {6}, 102)});
2286
2287 stmt->accept(&analyzer);
2288
2289 // Output depends on input.
2290 ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2291
2292 // Output has 4 depdenencies.
2293 auto outputAccess = analyzer.output(b.node());
2294 ASSERT_NE(outputAccess, nullptr);
2295 ASSERT_EQ(outputAccess->dependencies().size(), 4);
2296
2297 // Second loop depends on first loop.
2298 ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop));
2299
2300 // Output does not depend on second loop or store.
2301 ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop));
2302 ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore));
2303 }
2304}
2305
2306// Dynamic shapes (load in indices).
2307TEST(MemDependency, MemDependencyCheckerDynamicShapes) {
2308 BufHandle a("A", {100}, kInt);
2309 BufHandle b("B", {100}, kInt);
2310 BufHandle c("C", {100}, kInt);
2311 VarHandle x("x", kInt);
2312
2313 using namespace analysis;
2314
2315 auto CB = [](ExprHandle s, ExprHandle e) {
2316 return Bound(s.node(), e.node());
2317 };
2318
2319 auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
2320 return indexBoundsEquals(x, y);
2321 };
2322
2323 {
2324 /* for (int x = 0; x < B[0]; x++) {
2325 * C[x] = A[x];
2326 * }
2327 */
2328 MemDependencyChecker analyzer({a, b}, {c});
2329 StmtPtr stmt = Block::make({For::make(
2330 x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))});
2331
2332 stmt->accept(&analyzer);
2333
2334 /* 0. Input: B[(0, 99)] - dependents: 2
2335 * 1. Input: A[(0, 99)] - dependents: 3
2336 * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 3 4
2337 * 3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2 - dependents: 4
2338 * 4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3 - dependents: 5
2339 * 5. Output: C[(0, 99)] - depends on: 4
2340 */
2341
2342 // Output dependent on A input.
2343 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2344 // Also dependent on B input to determine the size of the region written.
2345 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2346
2347 auto history = analyzer.getHistory();
2348 ASSERT_EQ(history.size(), 6);
2349
2350 // The accesses in the loop depend on the load in the stop condition.
2351 ASSERT_TRUE(history[4]->hasDependency(history[2]));
2352 ASSERT_TRUE(history[3]->hasDependency(history[2]));
2353
2354 // Make a load from B to compare against.
2355 ExprHandle loadFromB = Load::make(b, {0});
2356
2357 ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)}));
2358 ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)}));
2359 }
2360
2361 {
2362 /* for (int x = B[0]; x < B[1]; x++) {
2363 * C[x] = A[x];
2364 * }
2365 */
2366 MemDependencyChecker analyzer({a, b}, {c});
2367 StmtPtr stmt = Block::make({For::make(
2368 x,
2369 Load::make(b, {0}),
2370 Load::make(b, {1}),
2371 Store::make(c, {x}, Load::make(a, {x})))});
2372
2373 stmt->accept(&analyzer);
2374
2375 /* 0. Input: B[(0, 99)] - dependents: 2 3
2376 * 1. Input: A[(0, 99)] - dependents: 4
2377 * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 4 5
2378 * 3. Load: B[(1, 1)] - depends on: 0 - dependents: 4 5
2379 * 4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3 - dependents: 5
2380 * 5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4 - dependents: 6
2381 * 6. Output: C[(0, 99)] - depends on: 5
2382 */
2383
2384 // Sanity check output depends on input.
2385 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2386 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2387
2388 auto history = analyzer.getHistory();
2389 ASSERT_EQ(history.size(), 7);
2390
2391 // The accesses in the loop depend on the load in the start condition.
2392 ASSERT_TRUE(history[5]->hasDependency(history[2]));
2393 ASSERT_TRUE(history[4]->hasDependency(history[2]));
2394
2395 // also the stop condition.
2396 ASSERT_TRUE(history[5]->hasDependency(history[3]));
2397 ASSERT_TRUE(history[4]->hasDependency(history[3]));
2398
2399 // Make loads from B to compare against.
2400 ExprHandle loadFromB0 = Load::make(b, {0});
2401 ExprHandle loadFromB1 = Load::make(b, {1});
2402 ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)}));
2403 ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)}));
2404 }
2405
2406 {
2407 /* for (int x = 0; x < 10; x++) {
2408 * C[x] = A[B[x]];
2409 * }
2410 */
2411 MemDependencyChecker analyzer({a, b}, {c});
2412 StmtPtr stmt = Block::make({For::make(
2413 x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))});
2414
2415 stmt->accept(&analyzer);
2416
2417 /* 0. Input: B[(0, 99)] - dependents: 2
2418 * 1. Input: A[(0, 99)] - dependents: 3
2419 * 2. Load: B[(0, 9)] - depends on: 0 - dependents: 3 4
2420 * 3. Load: A[(B[0], B[9])] - depends on: 1 2 - dependents: 4
2421 * 4. Store: C[(0, 9)] - depends on: 2 3 - dependents: 5
2422 * 5. Output: C[(0, 99)] - depends on: 4
2423 */
2424
2425 // Sanity check output depends on input.
2426 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2427 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2428
2429 auto history = analyzer.getHistory();
2430 ASSERT_EQ(history.size(), 6);
2431
2432 // The store depends on both loads, the load of A depends on the load of B.
2433 ASSERT_TRUE(history[4]->hasDependency(history[2]));
2434 ASSERT_TRUE(history[4]->hasDependency(history[3]));
2435
2436 ASSERT_TRUE(history[3]->hasDependency(history[2]));
2437
2438 // The loads in the indices depend on the relevant input buffer.
2439 ASSERT_TRUE(history[3]->hasDependency(history[1]));
2440 ASSERT_TRUE(history[2]->hasDependency(history[0]));
2441
2442 // The load from B has the loop bounds.
2443 ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2444
2445 // The load from A has bounds B[0] to B[9].
2446 ExprHandle loadFromB0 = Load::make(b, {0});
2447 ExprHandle loadFromB9 = Load::make(b, {9});
2448 ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)}));
2449 }
2450
2451 {
2452 /* for (int x = 0; x < 10; x++) {
2453 * C[B[x]] = A[x];
2454 * }
2455 */
2456 MemDependencyChecker analyzer({a, b}, {c});
2457 StmtPtr stmt = Block::make({For::make(
2458 x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))});
2459
2460 stmt->accept(&analyzer);
2461
2462 /* 0. Input: B[(0, 99)] - dependents: 3
2463 * 1. Input: A[(0, 99)] - dependents: 2
2464 * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 4
2465 * 3. Load: B[(0, 9)] - depends on: 0 - dependents: 4
2466 * 4. Store: C[(B[0], B[9])] - depends on: 2 3 - dependents: 5
2467 * 5. Output: C[(0, 99)] - depends on: 4
2468 */
2469 // Sanity check output depends on input.
2470 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2471 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2472
2473 auto history = analyzer.getHistory();
2474 ASSERT_EQ(history.size(), 6);
2475
2476 // The store depends on both loads, neither load is dependent.
2477 ASSERT_TRUE(history[4]->hasDependency(history[2]));
2478 ASSERT_TRUE(history[4]->hasDependency(history[3]));
2479
2480 ASSERT_FALSE(history[3]->hasDependency(history[2]));
2481 ASSERT_FALSE(history[2]->hasDependency(history[3]));
2482
2483 // The loads each depend on their relevant input. (but accesses are in a
2484 // different order than the last case).
2485 ASSERT_TRUE(history[3]->hasDependency(history[0]));
2486 ASSERT_TRUE(history[2]->hasDependency(history[1]));
2487
2488 // The load from B has the loop bounds.
2489 ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)}));
2490
2491 // And so does the load from A.
2492 ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2493 }
2494
2495 {
2496 /* for (int x = 0; x < 10; x++) {
2497 * C[B[A[x]]] = x;
2498 * }
2499 */
2500 MemDependencyChecker analyzer({a, b}, {c});
2501 StmtPtr stmt = Block::make({For::make(
2502 x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))});
2503
2504 stmt->accept(&analyzer);
2505
2506 /* 0. Input: B[(0, 99)] - dependents: 3
2507 * 1. Input: A[(0, 99)] - dependents: 2
2508 * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 3 4
2509 * 3. Load: B[(A[0], A[9])] - depends on: 0 2 - dependents: 4
2510 * 4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3 - dependents: 5
2511 * 5. Output: C[(0, 99)] - depends on: 4
2512 */
2513
2514 // Sanity check output depends on input.
2515 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2516 ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2517
2518 auto history = analyzer.getHistory();
2519 ASSERT_EQ(history.size(), 6);
2520
2521 // The store depends on both loads.
2522 ASSERT_TRUE(history[4]->hasDependency(history[2]));
2523 ASSERT_TRUE(history[4]->hasDependency(history[3]));
2524
2525 // The outer load depends on the inner.
2526 ASSERT_TRUE(history[3]->hasDependency(history[2]));
2527
2528 // The loads each depend on their relevant input. (but accesses are in a
2529 // different order than the last case).
2530 ASSERT_TRUE(history[3]->hasDependency(history[0]));
2531 ASSERT_TRUE(history[2]->hasDependency(history[1]));
2532
2533 // The load from A has the loop bounds.
2534 ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2535 // The load from B as bounds A[0] to A[9].
2536 ExprHandle loadFromA0 = Load::make(a, {0});
2537 ExprHandle loadFromA9 = Load::make(a, {9});
2538 ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)}));
2539
2540 // The store has bounds of B[A[0]] to B[A[9]].
2541 ExprHandle loadFromBA0 = Load::make(b, {loadFromA0});
2542 ExprHandle loadFromBA9 = Load::make(b, {loadFromA9});
2543 ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)}));
2544 }
2545}
2546
2547// Verify multi dimensional bounds work.
2548TEST(MemDependency, MemDependencyCheckerMultiDim) {
2549 int M = 10, N = 9, K = 12;
2550 BufHandle a("A", {M, N, K}, kInt);
2551 BufHandle b("B", {M, N, K}, kInt);
2552 BufHandle c("C", {M, K}, kInt);
2553 VarHandle x("x", kInt);
2554 VarHandle y("y", kInt);
2555 VarHandle z("z", kInt);
2556
2557 using namespace analysis;
2558
2559 auto CB = [](ExprHandle s, ExprHandle e) {
2560 return Bound(s.node(), e.node());
2561 };
2562
2563 auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
2564 return indexBoundsEquals(x, y);
2565 };
2566
2567 {
2568 /* for (int x = 0; x < 10; x++) {
2569 * for (int y = 0; y < 9; y++) {
2570 * for (int z = 0; z < 12; z++) {
2571 * B[x, y, z] = A[x, y, z];
2572 * }
2573 * }
2574 * }
2575 */
2576 // Full range.
2577
2578 MemDependencyChecker analyzer({a}, {b});
2579 StmtPtr stmt = Block::make({For::make(
2580 x,
2581 0,
2582 M,
2583 For::make(
2584 y,
2585 0,
2586 N,
2587 For::make(
2588 z,
2589 0,
2590 K,
2591 Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))});
2592
2593 stmt->accept(&analyzer);
2594
2595 // Sanity test: Output depends on input.
2596 ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2597
2598 // 4 accesses: input, load, store, output.
2599 auto history = analyzer.getHistory();
2600 ASSERT_EQ(history.size(), 4);
2601
2602 // Simple chain from input to output.
2603 ASSERT_TRUE(history[3]->hasDependency(history[2]));
2604 ASSERT_TRUE(history[2]->hasDependency(history[1]));
2605 ASSERT_TRUE(history[1]->hasDependency(history[0]));
2606
2607 ASSERT_TRUE(
2608 EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2609 ASSERT_TRUE(
2610 EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2611 }
2612
2613 {
2614 /* for (int x = 0; x < 5; x++) {
2615 * for (int y = 0; y < 5; y++) {
2616 * for (int z = 0; z < 5; z++) {
2617 * B[x, y, z] = A[x, y, z];
2618 * }
2619 * }
2620 * }
2621 */
2622 // Partial range.
2623
2624 MemDependencyChecker analyzer({a}, {b});
2625 StmtPtr stmt = Block::make({For::make(
2626 x,
2627 0,
2628 5,
2629 For::make(
2630 y,
2631 0,
2632 5,
2633 For::make(
2634 z,
2635 0,
2636 5,
2637 Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))});
2638
2639 stmt->accept(&analyzer);
2640
2641 // Sanity test: Output depends on input.
2642 ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2643
2644 // 4 accesses: input, load, store, output.
2645 auto history = analyzer.getHistory();
2646 ASSERT_EQ(history.size(), 4);
2647
2648 // Simple chain from input to output.
2649 ASSERT_TRUE(history[3]->hasDependency(history[2]));
2650 ASSERT_TRUE(history[2]->hasDependency(history[1]));
2651 ASSERT_TRUE(history[1]->hasDependency(history[0]));
2652
2653 ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)}));
2654 ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)}));
2655 }
2656
2657 {
2658 /* for (int x = 0; x < 10; x++) {
2659 * for (int y = 0; y < 12; y++) {
2660 * B[x, 0, y] = A[x, 0, y];
2661 * }
2662 * }
2663 */
2664
2665 // Partial loops.
2666
2667 MemDependencyChecker analyzer({a}, {b});
2668 StmtPtr stmt = Block::make({For::make(
2669 x,
2670 0,
2671 N,
2672 For::make(
2673 y, 0, K, Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}))))});
2674
2675 stmt->accept(&analyzer);
2676
2677 // Sanity test: Output depends on input.
2678 ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2679
2680 // 4 accesses: input, load, store, output.
2681 auto history = analyzer.getHistory();
2682 ASSERT_EQ(history.size(), 4);
2683
2684 // Simple chain from input to output.
2685 ASSERT_TRUE(history[3]->hasDependency(history[2]));
2686 ASSERT_TRUE(history[2]->hasDependency(history[1]));
2687 ASSERT_TRUE(history[1]->hasDependency(history[0]));
2688
2689 ASSERT_TRUE(
2690 EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)}));
2691 ASSERT_TRUE(
2692 EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)}));
2693 }
2694
2695 {
2696 /* for (int x = 0; x < 10; x++) {
2697 * for (int y = 0; y < 100; y++) {
2698 * for (int z = 0; z < 12; z++) {
2699 * B[x, 0, z] = (A[x, 0, z]) + (C[x, z]);
2700 * }
2701 * }
2702 * }
2703 */
2704
2705 // Loops that don't correspond to an index, bufs with different
2706 // dimensionality.
2707
2708 MemDependencyChecker analyzer({a, c}, {b});
2709 StmtPtr stmt = Block::make({For::make(
2710 x,
2711 0,
2712 M,
2713 For::make(
2714 y,
2715 0,
2716 100,
2717 For::make(
2718 z,
2719 0,
2720 K,
2721 Store::make(
2722 b,
2723 {x, 0, z},
2724 Add::make(
2725 Load::make(a, {x, 0, z}), Load::make(c, {x, z}))))))});
2726
2727 stmt->accept(&analyzer);
2728
2729 // Sanity test: Output depends on both inputs.
2730 ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2731 ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node()));
2732
2733 // 6 accesses: 2 inputs, 2 loads, store, output.
2734 auto history = analyzer.getHistory();
2735 ASSERT_EQ(history.size(), 6);
2736
2737 // Simple chain from input to output over the A buf.
2738 // history[0] is the C input, history[3] is the load from C.
2739 ASSERT_TRUE(history[5]->hasDependency(history[4]));
2740 ASSERT_TRUE(history[4]->hasDependency(history[2]));
2741 ASSERT_TRUE(history[2]->hasDependency(history[1]));
2742 // The store also depends on the load from the C input.
2743 ASSERT_TRUE(history[4]->hasDependency(history[3]));
2744 ASSERT_TRUE(history[3]->hasDependency(history[0]));
2745
2746 // A Buf accesses.
2747 ASSERT_TRUE(
2748 EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
2749 ASSERT_TRUE(
2750 EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
2751
2752 // C buf access.
2753 ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)}));
2754 }
2755
2756 {
2757 /* for (int x = 0; x < 9; x++) {
2758 * for (int y = 0; y < 10; y++) {
2759 * for (int z = 0; z < 12; z++) {
2760 * B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]);
2761 * }
2762 * }
2763 * }
2764 */
2765 // Multi-dim reductions.
2766
2767 MemDependencyChecker analyzer({a}, {b});
2768 StmtPtr stmt = Block::make({For::make(
2769 x,
2770 0,
2771 M,
2772 For::make(
2773 y,
2774 0,
2775 N,
2776 For::make(
2777 z,
2778 0,
2779 K,
2780 Store::make(
2781 b,
2782 {x, 0, 0},
2783 Add::make(
2784 Load::make(b, {x, y, z}),
2785 Load::make(a, {x, y, z}))))))});
2786
2787 stmt->accept(&analyzer);
2788
2789 // Sanity test: Output depends on input.
2790 ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2791
2792 // 4 accesses: input, 2 loads, store, output.
2793 auto history = analyzer.getHistory();
2794 ASSERT_EQ(history.size(), 5);
2795
2796 // Simple chain from input to output.
2797 ASSERT_TRUE(history[4]->hasDependency(history[3]));
2798 ASSERT_TRUE(history[3]->hasDependency(history[2]));
2799 ASSERT_TRUE(history[3]->hasDependency(history[1]));
2800 ASSERT_TRUE(history[2]->hasDependency(history[0]));
2801
2802 // The load from B depends on the store to B.
2803 ASSERT_TRUE(history[1]->hasDependency(history[3]));
2804
2805 ASSERT_TRUE(
2806 EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2807 ASSERT_TRUE(
2808 EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2809 ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)}));
2810 }
2811}
2812
2813// Various tests using the external Compute/Reduce API.
2814TEST(MemDependency, MemDependencyCheckerComputeAPI) {
2815 using namespace analysis;
2816
2817 /* for (int m = 0; m < 4; m++) {
2818 * for (int n = 0; n < 5; n++) {
2819 * for (int k = 0; k < 6; k++) {
2820 * broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]);
2821 * }
2822 * }
2823 * }
2824 * for (int m_1 = 0; m_1 < 4; m_1++) {
2825 * for (int n_1 = 0; n_1 < 5; n_1++) {
2826 * for (int k_1 = 0; k_1 < 6; k_1++) {
2827 * d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1);
2828 * }
2829 * }
2830 * }
2831 */
2832
2833 // Can determine if 2 loops created by Compute are dependent.
2834 BufHandle a_buf("a", {4, 5}, kFloat);
2835 BufHandle b_buf("b", {5, 6}, kFloat);
2836 Tensor c = Compute(
2837 "broadcast_add",
2838 {4, 5, 6},
2839 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2840 return a_buf.load(m, n) + b_buf.load(n, k);
2841 });
2842 Tensor d = Compute(
2843 "d",
2844 {4, 5, 6},
2845 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2846 return c.load(m, n, k) + 1;
2847 });
2848
2849 LoopNest l({d}, {c, d});
2850
2851 MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()});
2852
2853 l.root_stmt()->accept(&analyzer);
2854
2855 // Sanity test: Output depends on input.
2856 ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node()));
2857 ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node()));
2858
2859 // Second loop depends on first loop.
2860 auto c_loop = l.getLoopStmtsFor(c)[0];
2861 auto d_loop = l.getLoopStmtsFor(d)[0];
2862 ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop));
2863}
2864
2865TEST(MemDependency, MemDependencyCheckerComputeInline) {
2866 using namespace analysis;
2867
2868 /* for (int m = 0; m < 4; m++) {
2869 * for (int n = 0; n < 5; n++) {
2870 * for (int k = 0; k < 6; k++) {
2871 * d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1);
2872 * }
2873 * }
2874 * }
2875 */
2876
2877 // Check inlining affects the number of accesses returned.
2878
2879 BufHandle a_buf("a", {4, 5}, kFloat);
2880 BufHandle b_buf("b", {5, 6}, kFloat);
2881 Tensor c = Compute(
2882 "broadcast_add",
2883 {4, 5, 6},
2884 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2885 return a_buf.load(m, n) + b_buf.load(n, k);
2886 });
2887 Tensor d = Compute(
2888 "d",
2889 {4, 5, 6},
2890 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2891 return c.load(m, n, k) + 1;
2892 });
2893
2894 LoopNest l({d}, {c, d});
2895 l.computeInline(c.buf());
2896
2897 MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()});
2898 l.root_stmt()->accept(&analyzer);
2899
2900 // Sanity test: Output depends on input.
2901 ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node()));
2902 ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node()));
2903
2904 // broadcast_add tensor should not appear in trace at all.
2905 for (auto& wi : analyzer.getHistory()) {
2906 ASSERT_NE(wi->var(), c.buf()->base_handle());
2907 }
2908}
2909
2910TEST(MemDependency, MemDependencyCheckerComputeSplit) {
2911 using namespace analysis;
2912 // Split an axis, so the number of loops != the number of dimensions.
2913
2914 BufHandle a_buf("a", {4, 5}, kFloat);
2915 BufHandle b_buf("b", {5, 6}, kFloat);
2916 Tensor c = Compute(
2917 "broadcast_add",
2918 {4, 5, 6},
2919 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2920 return a_buf.load(m, n) + b_buf.load(n, k);
2921 });
2922
2923 LoopNest l({c});
2924
2925 MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()});
2926 l.root_stmt()->accept(&analyzer_before);
2927
2928 l.splitWithTail(l.getLoopStmtsFor(c)[0], 2);
2929
2930 MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()});
2931 StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
2932 stmt->accept(&analyzer_after);
2933
2934 // Splitting should not change accesses at all.
2935 auto history_before = analyzer_before.getHistory();
2936 auto history_after = analyzer_after.getHistory();
2937
2938 ASSERT_EQ(history_before.size(), history_after.size());
2939
2940 for (size_t i = 0; i < history_before.size(); ++i) {
2941 ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
2942 ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
2943 ASSERT_EQ(
2944 history_before[i]->bounds().size(), history_after[i]->bounds().size());
2945 ASSERT_TRUE(indexBoundsEquals(
2946 history_before[i]->bounds(), history_after[i]->bounds()));
2947 ASSERT_EQ(
2948 history_before[i]->dependencies().size(),
2949 history_after[i]->dependencies().size());
2950 ASSERT_EQ(
2951 history_before[i]->dependents().size(),
2952 history_after[i]->dependents().size());
2953 }
2954}
2955
2956TEST(MemDependency, MemDependencyCheckerComputeReorder) {
2957 using namespace analysis;
2958 // Reorder an axis, so the loop order doesn't match the indexing order.
2959
2960 BufHandle a_buf("a", {4, 5}, kFloat);
2961 BufHandle b_buf("b", {5, 6}, kFloat);
2962 Tensor c = Compute(
2963 "broadcast_add",
2964 {4, 5, 6},
2965 [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2966 return a_buf.load(m, n) + b_buf.load(n, k);
2967 });
2968
2969 LoopNest l({c});
2970
2971 MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()});
2972 l.root_stmt()->accept(&analyzer_before);
2973
2974 auto loops = l.getLoopStmtsFor(c);
2975 l.reorderAxis(loops[0], loops[1]);
2976
2977 MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()});
2978 StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
2979 stmt->accept(&analyzer_after);
2980
2981 // Reordering should not change accesses at all.
2982 auto history_before = analyzer_before.getHistory();
2983 auto history_after = analyzer_after.getHistory();
2984
2985 ASSERT_EQ(history_before.size(), history_after.size());
2986
2987 for (size_t i = 0; i < history_before.size(); ++i) {
2988 ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
2989 ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
2990 ASSERT_EQ(
2991 history_before[i]->bounds().size(), history_after[i]->bounds().size());
2992 ASSERT_TRUE(indexBoundsEquals(
2993 history_before[i]->bounds(), history_after[i]->bounds()));
2994 ASSERT_EQ(
2995 history_before[i]->dependencies().size(),
2996 history_after[i]->dependencies().size());
2997 ASSERT_EQ(
2998 history_before[i]->dependents().size(),
2999 history_after[i]->dependents().size());
3000 }
3001}
3002
3003TEST(MemDependency, MemDependencyCheckerComputeReduce) {
3004 using namespace analysis;
3005 /* for (int l2 = 0; l2 < 2; l2++) {
3006 * for (int n1 = 0; n1 < 3; n1++) {
3007 * for (int m1 = 0; m1 < 6; m1++) {
3008 * scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]);
3009 * }
3010 * }
3011 * }
3012 * for (int l1 = 0; l1 < 2; l1++) {
3013 * sum[l1] = float(0);
3014 * for (int n1_1 = 0; n1_1 < 3; n1_1++) {
3015 * for (int m1_1 = 0; m1_1 < 6; m1_1++) {
3016 * sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)),
3017 * out_args={l1}, reduce_args={n1, m1});
3018 * }
3019 * }
3020 * }
3021 */
3022
3023 // Can determine dependencies of a Reduction.
3024
3025 BufHandle a("a", {2, 3, 6}, kFloat);
3026 BufHandle b("b", {2, 3, 6}, kFloat);
3027
3028 Tensor c = Compute(
3029 "scale",
3030 {2, 3, 6},
3031 [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
3032 return b.load(l, n, m) * a.load(l, n, m);
3033 });
3034 Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6});
3035 LoopNest l({d}, {c, d});
3036
3037 MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()});
3038
3039 l.root_stmt()->accept(&analyzer);
3040
3041 // Sanity test: Output depends on input.
3042 ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a.node()));
3043 ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b.node()));
3044
3045 // Second loop depends on first loop.
3046 auto c_loop = l.getLoopStmtsFor(c)[0];
3047 auto d_loop = l.getLoopStmtsFor(d)[0];
3048 ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop));
3049
3050 // Reduction depends on both inputs.
3051 auto reduces = NodeFinder<ReduceOp>::find(l.root_stmt());
3052 ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], a.node()));
3053 ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], b.node()));
3054}
3055
3056TEST(MemDependency, MemDependencyCheckerComputeGEMM) {
3057 int M = 1024;
3058 int N = 1024;
3059 int K = 2048;
3060 using namespace analysis;
3061
3062 BufHandle AP("A", {M, K}, kFloat);
3063 BufHandle BP("B", {K, N}, kFloat);
3064 Tensor CT = Reduce(
3065 "gemm",
3066 {M, N},
3067 Sum(),
3068 [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
3069 return AP.load(m, k) * BP.load(k, n);
3070 },
3071 {K});
3072 LoopNest loop({CT});
3073
3074 {
3075 auto const& loops = loop.getLoopStmtsFor(CT);
3076 ForPtr m = loops[0];
3077 loop.splitWithMask(m, 4);
3078 }
3079 {
3080 auto const& loops = loop.getLoopStmtsFor(CT);
3081 ForPtr n = loops[2];
3082 loop.splitWithMask(n, 16);
3083 }
3084 // mo, mi, no, ni, k ->
3085 // mo, no, mi, ni, k
3086 {
3087 auto const& loops = loop.getLoopStmtsFor(CT);
3088 ForPtr mi = loops[1];
3089 ForPtr no = loops[2];
3090 loop.reorderAxis(mi, no);
3091 }
3092 // mo, no, mi, ni, k ->
3093 // mo, no, mi, k, ni
3094 {
3095 auto const& loops = loop.getLoopStmtsFor(CT);
3096 ForPtr ni = loops[3];
3097 ForPtr k = loops[4];
3098 loop.reorderAxis(ni, k);
3099 }
3100 // mo, no, mi, k, ni ->
3101 // mo, no, k, mi, ni
3102 {
3103 auto const& loops = loop.getLoopStmtsFor(CT);
3104 ForPtr mi = loops[2];
3105 ForPtr k = loops[3];
3106 loop.reorderAxis(mi, k);
3107 }
3108 {
3109 auto const& loops = loop.getLoopStmtsFor(CT);
3110 loop.cacheAccesses(CT.buf(), "C_regs", loops[2]);
3111 }
3112
3113 MemDependencyChecker analyzer_unlowered(
3114 loop.getInputBufs(), loop.getOutputBufs());
3115
3116 MemDependencyChecker analyzer_lowered(
3117 loop.getInputBufs(), loop.getOutputBufs());
3118
3119 // Test both unlowered and lowered form.
3120 {
3121 StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt());
3122 stmt->accept(&analyzer_unlowered);
3123
3124 // Outputs depend on inputs.
3125 ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.node()));
3126 ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.node()));
3127
3128 // The last write to gemm should cover the total bound of the output.
3129 std::shared_ptr<AccessInfo> outputAccess =
3130 analyzer_unlowered.output(CT.buf());
3131 // A single dependency.
3132 ASSERT_EQ(outputAccess->dependencies().size(), 1);
3133
3134 // dependencies is a set with 1 element, so can just deref begin().
3135 std::shared_ptr<AccessInfo> gemmStore =
3136 outputAccess->dependencies().begin()->second;
3137 // Check its a store.
3138 ASSERT_EQ(gemmStore->type(), AccessType::Store);
3139
3140 ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds()));
3141
3142 // Likewise the first read from each input cover the entire range of the
3143 // input.
3144 auto aInput = analyzer_unlowered.input(AP.node());
3145 auto bInput = analyzer_unlowered.input(BP.node());
3146
3147 // A single dependent each.
3148 ASSERT_EQ(aInput->dependents().size(), 1);
3149 ASSERT_EQ(bInput->dependents().size(), 1);
3150
3151 // They're both loads.
3152 std::shared_ptr<AccessInfo> aLoad = aInput->dependents().begin()->second;
3153 std::shared_ptr<AccessInfo> bLoad = bInput->dependents().begin()->second;
3154 ASSERT_EQ(aLoad->type(), AccessType::Load);
3155 ASSERT_EQ(bLoad->type(), AccessType::Load);
3156
3157 ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds()));
3158 ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds()));
3159 }
3160
3161 loop.prepareForCodegen();
3162 SimpleIREvaluator cg(loop.root_stmt(), {AP, BP, CT});
3163
3164 // now check lowered dependency graph.
3165 {
3166 StmtPtr stmt = IRSimplifier::simplify(cg.stmt());
3167 stmt->accept(&analyzer_lowered);
3168
3169 // Lowering will change the dimensionality of all bounds due to index
3170 // flattening and will insert Allocates and Frees.
3171
3172 auto history_before = analyzer_unlowered.getHistory();
3173 auto history_after = analyzer_lowered.getHistory();
3174
3175 ASSERT_EQ(history_before.size() + 2, history_after.size());
3176
3177 // Filter out the alloc/free;
3178 auto isAllocFree = [](const auto& info) {
3179 return info->type() == AccessType::Alloc ||
3180 info->type() == AccessType::Free;
3181 };
3182 history_after.erase(
3183 std::remove_if(history_after.begin(), history_after.end(), isAllocFree),
3184 history_after.end());
3185
3186 ASSERT_EQ(history_before.size(), history_after.size());
3187
3188 for (size_t i = 0; i < history_before.size(); ++i) {
3189 ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
3190 ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
3191
3192 if (history_before[i]->dependencies().size() !=
3193 history_after[i]->dependencies().size()) {
3194 // Must depend on an Alloc.
3195 ASSERT_TRUE(std::any_of(
3196 history_after[i]->dependencies().begin(),
3197 history_after[i]->dependencies().end(),
3198 [](const auto& pair) {
3199 return pair.second->type() == AccessType::Alloc;
3200 }));
3201
3202 ASSERT_EQ(
3203 history_before[i]->dependencies().size() + 1,
3204 history_after[i]->dependencies().size());
3205 }
3206
3207 if (history_before[i]->dependents().size() !=
3208 history_after[i]->dependents().size()) {
3209 // Must depend on an Free.
3210 ASSERT_TRUE(std::any_of(
3211 history_after[i]->dependents().begin(),
3212 history_after[i]->dependents().end(),
3213 [](const auto& pair) {
3214 return pair.second->type() == AccessType::Free;
3215 }));
3216
3217 ASSERT_EQ(
3218 history_before[i]->dependents().size() + 1,
3219 history_after[i]->dependents().size());
3220 }
3221
3222 // Inputs and outputs are not flattened, only accesses.
3223 if (history_before[i]->type() == AccessType::Input ||
3224 history_before[i]->type() == AccessType::Output) {
3225 ASSERT_EQ(
3226 history_before[i]->bounds().size(),
3227 history_after[i]->bounds().size());
3228 ASSERT_TRUE(indexBoundsEquals(
3229 history_before[i]->bounds(), history_after[i]->bounds()));
3230 } else {
3231 ASSERT_EQ(history_after[i]->bounds().size(), 1);
3232 ExprPtr flat_bounds = alloc<IntImm>(1);
3233
3234 for (auto& b : history_before[i]->bounds()) {
3235 flat_bounds =
3236 alloc<Mul>(flat_bounds, alloc<Add>(b.end, alloc<IntImm>(1)));
3237
3238 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
3239 ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start));
3240 }
3241
3242 flat_bounds = IRSimplifier::simplify(flat_bounds);
3243 ExprPtr after_bounds = IRSimplifier::simplify(
3244 alloc<Add>(history_after[i]->bounds()[0].end, alloc<IntImm>(1)));
3245 ASSERT_TRUE(exprEquals(flat_bounds, after_bounds));
3246 }
3247 }
3248 }
3249}
3250
3251} // namespace jit
3252} // namespace torch
3253