1 | #include <memory> |
2 | #include <sstream> |
3 | #include <stdexcept> |
4 | #include <unordered_map> |
5 | |
6 | #include <gtest/gtest.h> |
7 | |
8 | #include <c10/util/irange.h> |
9 | #include <test/cpp/tensorexpr/padded_buffer.h> |
10 | #include <torch/csrc/jit/tensorexpr/analysis.h> |
11 | #include <torch/csrc/jit/tensorexpr/bounds_inference.h> |
12 | #include <torch/csrc/jit/tensorexpr/eval.h> |
13 | #include <torch/csrc/jit/tensorexpr/ir.h> |
14 | #include <torch/csrc/jit/tensorexpr/ir_printer.h> |
15 | #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
16 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
17 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
18 | |
19 | namespace torch { |
20 | namespace jit { |
21 | |
22 | using namespace torch::jit::tensorexpr; |
23 | |
24 | static void verifyConstBounds( |
25 | const TensorAccessBoundsInfo& access_info, |
26 | const std::vector<std::pair<int, int>>& ref) { |
27 | size_t ndim = ref.size(); |
28 | ASSERT_EQ(access_info.start.size(), ndim); |
29 | ASSERT_EQ(access_info.stop.size(), ndim); |
30 | for (const auto i : c10::irange(ndim)) { |
31 | if (ref[i].first >= 0) { // Negative values are used to skip the check |
32 | ASSERT_TRUE(access_info.start[i]->isConstant()); |
33 | int start_i = immediateAs<int>(access_info.start[i]); |
34 | ASSERT_EQ(start_i, ref[i].first); |
35 | } |
36 | if (ref[i].second >= 0) { |
37 | ASSERT_TRUE(access_info.stop[i]->isConstant()); |
38 | int stop_i = immediateAs<int>(access_info.stop[i]); |
39 | ASSERT_EQ(stop_i, ref[i].second); |
40 | } |
41 | } |
42 | } |
43 | |
44 | TEST(BoundsInference, _1) { |
45 | // Verify that bounds inference works for the following example: |
46 | // for i in 0..100: |
47 | // b[i] = a[i] |
48 | // For this loop bounds inference should yield the following: |
49 | // {{b, kStore, 0, 99}, {a, kLoad, 0, 99}} |
50 | ExprHandle n(100); |
51 | BufHandle a("a" , {n}, kFloat); |
52 | Tensor b = Compute("b" , {n}, [&](const VarHandle& i) { return a.load(i); }); |
53 | LoopNest l({b}); |
54 | auto bounds_info = inferBounds(l.root_stmt()); |
55 | |
56 | // We should have two entries: one for 'b' and one for 'a'. |
57 | ASSERT_EQ(bounds_info.size(), 2); |
58 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
59 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
60 | verifyConstBounds(bounds_info.at(a.node())[0], {{0, 99}}); |
61 | |
62 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
63 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); |
64 | verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); |
65 | } |
66 | |
67 | TEST(BoundsInference, _2) { |
68 | // Verify that bounds inference works for the following example: |
69 | // for i in 0..n: |
70 | // b[i] = a[i] |
71 | // For this loop bounds inference should yield the following: |
72 | // {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}} |
73 | VarHandle n("n" , kInt); |
74 | BufHandle a("a" , {n}, kFloat); |
75 | Tensor b = Compute("b" , {n}, [&](const VarHandle& i) { return a.load(i); }); |
76 | LoopNest l({b}); |
77 | auto bounds_info = inferBounds(l.root_stmt()); |
78 | |
79 | // We should have two entries: one for 'b' and one for 'a'. |
80 | ASSERT_EQ(bounds_info.size(), 2); |
81 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
82 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
83 | verifyConstBounds(bounds_info.at(a.node())[0], {{0, -1}}); |
84 | |
85 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
86 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); |
87 | verifyConstBounds(bounds_info.at(b.buf())[0], {{0, -1}}); |
88 | } |
89 | |
90 | TEST(BoundsInference, _3) { |
91 | // Verify that bounds inference works for the following example: |
92 | // for i in 0..100: |
93 | // b[i] = a[i] * a[i+10] |
94 | // For this loop bounds inference should yield the following: |
95 | // {{b, kStore, 0, 99}, {a, kLoad, 0, 109}} |
96 | ExprHandle n(100); |
97 | BufHandle a("a" , {n + 10}, kFloat); |
98 | Tensor b = Compute( |
99 | "b" , {n}, [&](const VarHandle& i) { return a.load(i) * a.load(i + 10); }); |
100 | LoopNest l({b}); |
101 | auto bounds_info = inferBounds(l.root_stmt()); |
102 | |
103 | // We should have two entries: one for 'b' and one for 'a'. |
104 | ASSERT_EQ(bounds_info.size(), 2); |
105 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
106 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
107 | verifyConstBounds(bounds_info.at(a.node())[0], {{0, 109}}); |
108 | |
109 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
110 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); |
111 | verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); |
112 | } |
113 | |
114 | TEST(BoundsInference, _4) { |
115 | // Verify that bounds inference works for the following example: |
116 | // |
117 | // for y in 0..200: |
118 | // for x in 0..320: |
119 | // b[y,x] = x*y |
120 | // for y in 0..200: |
121 | // for x in 0..320: |
122 | // c[y,x] = a[y,x] * b[y,x] |
123 | ExprHandle W(320); |
124 | ExprHandle H(200); |
125 | BufHandle a("a" , {H, W}, kFloat); |
126 | Tensor b = Compute("b" , {H, W}, [&](const VarHandle& y, const VarHandle& x) { |
127 | return x * y; |
128 | }); |
129 | Tensor c = Compute("c" , {H, W}, [&](const VarHandle& y, const VarHandle& x) { |
130 | return a.load(y, x) * b.load(y, x); |
131 | }); |
132 | LoopNest l({c}); |
133 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
134 | StmtPtr body = l.getLoopBodyFor(c); |
135 | { |
136 | // Infer bounds on the top-level loop scope |
137 | auto bounds_info = inferBounds(loops[0]); |
138 | ASSERT_EQ(bounds_info.size(), 3); |
139 | |
140 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
141 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
142 | verifyConstBounds(bounds_info.at(a.node())[0], {{0, 199}, {0, 319}}); |
143 | |
144 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
145 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); |
146 | verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 199}, {0, 319}}); |
147 | |
148 | ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); |
149 | ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); |
150 | verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 199}, {0, 319}}); |
151 | } |
152 | { |
153 | // Infer bounds on the inner loop scope |
154 | auto bounds_info = inferBounds(loops[1]); |
155 | ASSERT_EQ(bounds_info.size(), 3); |
156 | |
157 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
158 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
159 | verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {0, 319}}); |
160 | |
161 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
162 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); |
163 | verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 319}}); |
164 | |
165 | ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); |
166 | ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); |
167 | verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 319}}); |
168 | } |
169 | { |
170 | // Infer bounds on the inner loop body's scope |
171 | auto bounds_info = inferBounds(body); |
172 | ASSERT_EQ(bounds_info.size(), 3); |
173 | |
174 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
175 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
176 | verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); |
177 | |
178 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
179 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); |
180 | verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); |
181 | |
182 | ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); |
183 | ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); |
184 | verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); |
185 | } |
186 | } |
187 | |
188 | TEST(BoundsInference, _5) { |
189 | // Verify that bounds inference works for the following example: |
190 | // for i in 0..100: |
191 | // b[i] = a[i] |
192 | // |
193 | // ==> split ==> |
194 | // |
195 | // for i_outer in 0..100/16: |
196 | // for i_inner in 0..16: |
197 | // b[i_outer * 16 + i_inner] = a[i_outer * 16 + i_inner] |
198 | // for i_tail in 0..100%16: |
199 | // b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16]; |
200 | ExprHandle n(100); |
201 | BufHandle a("a" , {n}, kFloat); |
202 | Tensor b = Compute("b" , {n}, [&](const VarHandle& i) { return a.load(i); }); |
203 | LoopNest l({b}); |
204 | |
205 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
206 | ForPtr inner; |
207 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
208 | ForPtr tail; |
209 | std::vector<ForPtr> loops = l.getLoopStmtsFor(b); |
210 | LoopNest::splitWithTail(loops[0], 16, &inner, &tail); |
211 | ForPtr outer = loops[0]; |
212 | |
213 | { |
214 | // Verify inferred bounds for the outer loop |
215 | auto bounds_info = inferBounds(outer); |
216 | ASSERT_EQ(bounds_info.size(), 2); |
217 | |
218 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
219 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
220 | verifyConstBounds(bounds_info.at(a.node())[0], {{0, 95}}); |
221 | |
222 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
223 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); |
224 | verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 95}}); |
225 | } |
226 | { |
227 | // Verify inferred bounds for the tail loop |
228 | auto bounds_info = inferBounds(tail); |
229 | ASSERT_EQ(bounds_info.size(), 2); |
230 | |
231 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
232 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
233 | verifyConstBounds(bounds_info.at(a.node())[0], {{96, 99}}); |
234 | |
235 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
236 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); |
237 | verifyConstBounds(bounds_info.at(b.buf())[0], {{96, 99}}); |
238 | } |
239 | } |
240 | |
241 | TEST(BoundsInference, _6) { |
242 | // Verify that bounds inference works for the following example: |
243 | // |
244 | // for y in 0..200: |
245 | // for x in 0..320: |
246 | // b[y,x] = x*y |
247 | // for y in 0..20: |
248 | // for x in 0..32: |
249 | // c[y,x] = a[y+100,x+100] * b[y*2,x*5] |
250 | ExprHandle W(320); |
251 | ExprHandle H(200); |
252 | ExprHandle CW(32); |
253 | ExprHandle CH(20); |
254 | BufHandle a("a" , {H, W}, kFloat); |
255 | Tensor b = Compute("b" , {H, W}, [&](const VarHandle& y, const VarHandle& x) { |
256 | return x * y; |
257 | }); |
258 | Tensor c = |
259 | Compute("c" , {CH, CW}, [&](const VarHandle& y, const VarHandle& x) { |
260 | return a.load(y + 100, x + 100) * b.load(y * 2, x * 5); |
261 | }); |
262 | LoopNest l({c}); |
263 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
264 | StmtPtr body = l.getLoopBodyFor(c); |
265 | { |
266 | // Infer bounds on the top-level loop scope |
267 | auto bounds_info = inferBounds(loops[0]); |
268 | ASSERT_EQ(bounds_info.size(), 3); |
269 | |
270 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
271 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
272 | verifyConstBounds(bounds_info.at(a.node())[0], {{100, 119}, {100, 131}}); |
273 | |
274 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
275 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); |
276 | verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 38}, {0, 155}}); |
277 | |
278 | ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); |
279 | ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); |
280 | verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 19}, {0, 31}}); |
281 | } |
282 | { |
283 | // Infer bounds on the inner loop scope |
284 | auto bounds_info = inferBounds(loops[1]); |
285 | ASSERT_EQ(bounds_info.size(), 3); |
286 | |
287 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
288 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
289 | verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {100, 131}}); |
290 | |
291 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
292 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); |
293 | verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 155}}); |
294 | |
295 | ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); |
296 | ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); |
297 | verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 31}}); |
298 | } |
299 | { |
300 | // Infer bounds on the inner loop body's scope |
301 | auto bounds_info = inferBounds(body); |
302 | ASSERT_EQ(bounds_info.size(), 3); |
303 | |
304 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
305 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
306 | verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); |
307 | |
308 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
309 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); |
310 | verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); |
311 | |
312 | ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); |
313 | ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); |
314 | verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); |
315 | } |
316 | } |
317 | |
318 | TEST(BoundsInference, Adjacent) { |
319 | ExprHandle H(6); |
320 | BufHandle a("a" , {20}, kFloat); |
321 | Tensor b = Compute("b" , {H}, [&](const VarHandle& x) { return a.load(x); }); |
322 | Tensor c = |
323 | Compute("c" , {H}, [&](const VarHandle& x) { return a.load(x + H); }); |
324 | LoopNest l({b, c}); |
325 | std::vector<ForPtr> loops = NodeFinder<For>::find(l.root_stmt()); |
326 | |
327 | { |
328 | // Infer bounds on the top-level loop scope |
329 | auto bounds_info = inferBounds(loops[0]); |
330 | ASSERT_EQ(bounds_info.size(), 2); |
331 | |
332 | // reads from a[0:5], writes to b[0:5] |
333 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
334 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
335 | verifyConstBounds(bounds_info.at(a.node())[0], {{0, 5}}); |
336 | |
337 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
338 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); |
339 | verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); |
340 | } |
341 | { |
342 | // Infer bounds on the inner loop scope |
343 | auto bounds_info = inferBounds(loops[1]); |
344 | ASSERT_EQ(bounds_info.size(), 2); |
345 | |
346 | // reads from a[0+6:5+6], writes to c[0:5] |
347 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
348 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
349 | verifyConstBounds(bounds_info.at(a.node())[0], {{6, 11}}); |
350 | |
351 | ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); |
352 | ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); |
353 | verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); |
354 | } |
355 | { |
356 | // Infer bounds on the high level program. |
357 | auto bounds_info = inferBounds(l.root_stmt()); |
358 | ASSERT_EQ(bounds_info.size(), 3); |
359 | |
360 | // Should be union of above 2 bounds, but this time the bounds of A can be |
361 | // merged. |
362 | ASSERT_EQ(bounds_info.at(a.node()).size(), 1); |
363 | ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); |
364 | verifyConstBounds(bounds_info.at(a.node())[0], {{0, 11}}); |
365 | |
366 | ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); |
367 | ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); |
368 | verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); |
369 | |
370 | ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); |
371 | ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); |
372 | verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); |
373 | } |
374 | } |
375 | |
376 | TEST(BoundsInference, MultipleTopLoopLoad) { |
377 | BufHandle a("a" , {100}, kFloat); |
378 | Tensor b = Compute("b" , {64}, [&](const VarHandle& x) { return a.load(x); }); |
379 | Tensor c = |
380 | Compute("c" , {32}, [&](const VarHandle& x) { return a.load(x + 10); }); |
381 | Tensor d = |
382 | Compute("d" , {96}, [&](const VarHandle& x) { return a.load(x + 2); }); |
383 | LoopNest l({b, c, d}); |
384 | |
385 | auto bounds_info = inferBounds(l.root_stmt()); |
386 | |
387 | ASSERT_EQ(bounds_info.size(), 4); |
388 | |
389 | // a only read. |
390 | { |
391 | auto bounds = bounds_info[a.node()]; |
392 | ASSERT_EQ(bounds.size(), 1); |
393 | // One dimension. |
394 | auto bound = bounds[0]; |
395 | ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); |
396 | // Bounds: |
397 | // start: Min of the 3 load bounds = Min of loop starts + offset = 0+0 (b). |
398 | // stop: Max of the 3 load bounds = Max of loop stops + offset - 1 = |
399 | // 96 + 2 - 1 (d). |
400 | verifyConstBounds(bound, {{0, 97}}); |
401 | } |
402 | |
403 | // b, c, d only written. |
404 | { |
405 | auto bounds = bounds_info[b.buf()]; |
406 | ASSERT_EQ(bounds.size(), 1); |
407 | auto bound = bounds[0]; |
408 | ASSERT_EQ(bound.kind, TensorAccessKind::kStore); |
409 | // Just the loop extents for b. |
410 | verifyConstBounds(bound, {{0, 63}}); |
411 | } |
412 | { |
413 | auto bounds = bounds_info[c.buf()]; |
414 | ASSERT_EQ(bounds.size(), 1); |
415 | auto bound = bounds[0]; |
416 | ASSERT_EQ(bound.kind, TensorAccessKind::kStore); |
417 | // Just the loop extents for c. |
418 | verifyConstBounds(bound, {{0, 31}}); |
419 | } |
420 | { |
421 | auto bounds = bounds_info[d.buf()]; |
422 | ASSERT_EQ(bounds.size(), 1); |
423 | auto bound = bounds[0]; |
424 | ASSERT_EQ(bound.kind, TensorAccessKind::kStore); |
425 | // Just the loop extents for d. |
426 | verifyConstBounds(bound, {{0, 95}}); |
427 | } |
428 | } |
429 | |
430 | TEST(BoundsInference, MultipleTopLoopStore) { |
431 | BufHandle a("a" , {100}, kFloat); |
432 | BufHandle b("b" , {100}, kFloat); |
433 | BufHandle c("c" , {100}, kFloat); |
434 | BufHandle d("d" , {100}, kFloat); |
435 | VarHandle x("x" , kInt); |
436 | |
437 | // Same as above but the offsets are on the Store now. |
438 | // Can't do this through ComputeAPI without transforms we don't have yet. |
439 | StmtPtr stmt = Block::make( |
440 | {For::make(x, 0, 64, Store::make(b, {x}, Load::make(a, {x}))), |
441 | For::make(x, 0, 32, Store::make(c, {x + 10}, Load::make(a, {x}))), |
442 | For::make(x, 0, 96, Store::make(d, {x + 2}, Load::make(a, {x})))}); |
443 | |
444 | auto bounds_info = inferBounds(stmt); |
445 | |
446 | ASSERT_EQ(bounds_info.size(), 4); |
447 | |
448 | // a only read. |
449 | { |
450 | auto bounds = bounds_info[a.node()]; |
451 | ASSERT_EQ(bounds.size(), 1); |
452 | // One dimension. |
453 | auto bound = bounds[0]; |
454 | ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); |
455 | // Bounds: there are no offsets, so this is just the max loop bounds. |
456 | verifyConstBounds(bound, {{0, 95}}); |
457 | } |
458 | |
459 | // b, c, d only written. |
460 | { |
461 | auto bounds = bounds_info[b.node()]; |
462 | ASSERT_EQ(bounds.size(), 1); |
463 | auto bound = bounds[0]; |
464 | ASSERT_EQ(bound.kind, TensorAccessKind::kStore); |
465 | // This should be equivalent to {offset, extent + offset} for the b loop. |
466 | // b loop has no offset, so just the loop extents. |
467 | verifyConstBounds(bound, {{0, 63}}); |
468 | } |
469 | { |
470 | auto bounds = bounds_info[c.node()]; |
471 | ASSERT_EQ(bounds.size(), 1); |
472 | auto bound = bounds[0]; |
473 | ASSERT_EQ(bound.kind, TensorAccessKind::kStore); |
474 | // This should be equivalent to {offset, extent + offset} for the c loop. |
475 | // Offset is 10, extent is 32-1. |
476 | verifyConstBounds(bound, {{10, 41}}); |
477 | } |
478 | { |
479 | auto bounds = bounds_info[d.node()]; |
480 | ASSERT_EQ(bounds.size(), 1); |
481 | auto bound = bounds[0]; |
482 | ASSERT_EQ(bound.kind, TensorAccessKind::kStore); |
483 | // This should be equivalent to {offset, extent + offset} for the d loop. |
484 | // Offset is 2, extent is 96-1. |
485 | verifyConstBounds(bound, {{2, 97}}); |
486 | } |
487 | } |
488 | |
489 | TEST(BoundsInference, CacheReads) { |
490 | Tensor A = Compute("A" , {64, 64}, [](const VarHandle& i, const VarHandle& j) { |
491 | return i * j; |
492 | }); |
493 | Tensor B = |
494 | Compute("B" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
495 | return A.load(i + 30, j + 3); |
496 | }); |
497 | Tensor C = |
498 | Compute("C" , {20, 10}, [&](const VarHandle& i, const VarHandle& j) { |
499 | return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); |
500 | }); |
501 | |
502 | LoopNest l({B, C}); |
503 | auto bounds_info_before = inferBounds(l.root_stmt()); |
504 | |
505 | StmtPtr j_loop = l.getLoopStmtsFor(B)[1]; |
506 | LoopNest::cacheAccesses(A.buf(), "A_local" , j_loop); |
507 | |
508 | auto bounds_info_after = inferBounds(l.root_stmt()); |
509 | |
510 | // CacheAccesses should not change existing bounds, but add a new one for the |
511 | // cache. |
512 | for (auto& pair : bounds_info_after) { |
513 | auto beforeIt = bounds_info_before.find(pair.first); |
514 | if (beforeIt != bounds_info_before.end()) { |
515 | // Same number of TensorAccessBoundInfos. |
516 | ASSERT_EQ(pair.second.size(), beforeIt->second.size()); |
517 | |
518 | for (const auto i : c10::irange(pair.second.size())) { |
519 | TensorAccessBoundsInfo& after = pair.second[i]; |
520 | TensorAccessBoundsInfo& before = beforeIt->second[i]; |
521 | // Same number of dimensions. |
522 | ASSERT_EQ(before.start.size(), after.start.size()); |
523 | |
524 | // Bounds are equal. |
525 | for (const auto j : c10::irange(before.start.size())) { |
526 | ASSERT_TRUE(exprEquals(before.start[j], after.start[j])); |
527 | ASSERT_TRUE(exprEquals(before.stop[j], after.stop[j])); |
528 | } |
529 | } |
530 | } else { |
531 | // This should be the cache. |
532 | ASSERT_EQ(pair.first->name_hint(), "A_local" ); |
533 | // Should have both a load and a store. |
534 | ASSERT_EQ(pair.second.size(), 2); |
535 | TensorAccessBoundsInfo& first = pair.second[0]; |
536 | TensorAccessBoundsInfo& second = pair.second[1]; |
537 | |
538 | ASSERT_NE(first.kind, second.kind); |
539 | // 2 dimensions. |
540 | ASSERT_EQ(first.start.size(), second.start.size()); |
541 | ASSERT_EQ(first.start.size(), 2); |
542 | |
543 | // bounds for load and store are equal. |
544 | for (const auto j : c10::irange(first.start.size())) { |
545 | ASSERT_TRUE(exprEquals(first.start[j], second.start[j])); |
546 | ASSERT_TRUE(exprEquals(first.stop[j], second.stop[j])); |
547 | } |
548 | } |
549 | } |
550 | } |
551 | |
552 | TEST(BoundsInference, Flattened) { |
553 | Tensor b = Compute( |
554 | "b" , |
555 | {3, 4, 5}, |
556 | [&](const VarHandle& z, const VarHandle& y, const VarHandle& x) { |
557 | return x * y + z; |
558 | }); |
559 | |
560 | LoopNest l({b}); |
561 | // Flatten indices. |
562 | l.prepareForCodegen(); |
563 | auto bounds_info = inferBounds(l.root_stmt()); |
564 | |
565 | // There's only one buffer. |
566 | ASSERT_EQ(bounds_info.size(), 1); |
567 | auto& TABI = bounds_info[b.buf()][0]; |
568 | ASSERT_EQ(TABI.kind, TensorAccessKind::kStore); |
569 | // Flattened bounds should have a single dimension. |
570 | ASSERT_EQ(TABI.start.size(), 1); |
571 | ASSERT_EQ(TABI.stop.size(), 1); |
572 | |
573 | // Bounds should be 0 -> (3*4*5)-1 |
574 | ASSERT_TRUE(exprEquals(TABI.start[0], alloc<IntImm>(0))); |
575 | ASSERT_TRUE(exprEquals(TABI.stop[0], alloc<IntImm>(3 * 4 * 5 - 1))); |
576 | } |
577 | |
578 | TEST(BoundsInference, GetPotentialHazards) { |
579 | BufHandle a("A" , {5}, kInt); |
580 | BufHandle b("B" , {5}, kInt); |
581 | BufHandle c("C" , {5}, kInt); |
582 | VarHandle x("x" , kInt); |
583 | VarHandle y("y" , kInt); |
584 | |
585 | using namespace analysis; |
586 | |
587 | { |
588 | /* |
589 | * A[0] = B[0]; |
590 | * B[0] = 3; WAR on B |
591 | * A[0] = B[0]; WAW on A, RAW on B |
592 | * C[0] = 5; |
593 | */ |
594 | |
595 | StorePtr store1 = Store::make(a, {0}, Load::make(b, {0})); |
596 | StorePtr store2 = Store::make(b, {0}, 3); |
597 | StorePtr store3 = Store::make(a, {0}, Load::make(b, {0})); |
598 | StorePtr store4 = Store::make(c, {0}, 5); |
599 | StmtPtr stmt = Block::make({store1, store2, store3, store4}); |
600 | |
601 | MemDependencyChecker analyzer; |
602 | stmt->accept(&analyzer); |
603 | |
604 | ASSERT_EQ( |
605 | HazardKind::WriteAfterRead, |
606 | getPotentialHazards(analyzer, store1, store2)); |
607 | |
608 | ASSERT_EQ( |
609 | HazardKind::ReadAfterWrite, |
610 | getPotentialHazards(analyzer, store2, store3)); |
611 | |
612 | ASSERT_EQ( |
613 | HazardKind::WriteAfterWrite, |
614 | getPotentialHazards(analyzer, store1, store3)); |
615 | |
616 | // Fourth store has no dependencies |
617 | ASSERT_EQ( |
618 | HazardKind::NoDependency, |
619 | getPotentialHazards(analyzer, store1, store4)); |
620 | ASSERT_EQ( |
621 | HazardKind::NoDependency, |
622 | getPotentialHazards(analyzer, store2, store4)); |
623 | ASSERT_EQ( |
624 | HazardKind::NoDependency, |
625 | getPotentialHazards(analyzer, store3, store4)); |
626 | } |
627 | } |
628 | |
629 | TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) { |
630 | Tensor A = Compute("A" , {64, 64}, [](const VarHandle& i, const VarHandle& j) { |
631 | return i * j; |
632 | }); |
633 | Tensor B = Compute("B" , {64, 64}, [](const VarHandle& i, const VarHandle& j) { |
634 | return (i + 1) * (j + 1); |
635 | }); |
636 | |
637 | LoopNest l({A, B}); |
638 | |
639 | using namespace analysis; |
640 | |
641 | MemDependencyChecker analyzer; |
642 | l.root_stmt()->accept(&analyzer); |
643 | |
644 | ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; |
645 | ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; |
646 | |
647 | // No dependencies between loops. |
648 | ASSERT_EQ( |
649 | HazardKind::NoDependency, |
650 | getPotentialHazards(analyzer, loopRootA, loopRootB)); |
651 | } |
652 | |
653 | TEST(BoundsInference, GetPotentialHazardsLoopCall) { |
654 | Tensor A = Compute("A" , {64, 64}, [](const VarHandle& i, const VarHandle& j) { |
655 | return i * j; |
656 | }); |
657 | Tensor B = |
658 | Compute("B" , {64, 64}, [&](const VarHandle& i, const VarHandle& j) { |
659 | return A.load(i, j) + 5; |
660 | }); |
661 | |
662 | LoopNest l({A, B}); |
663 | |
664 | using namespace analysis; |
665 | |
666 | MemDependencyChecker analyzer; |
667 | l.root_stmt()->accept(&analyzer); |
668 | |
669 | ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; |
670 | ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; |
671 | |
672 | ASSERT_EQ( |
673 | HazardKind::ReadAfterWrite, |
674 | getPotentialHazards(analyzer, loopRootA, loopRootB)); |
675 | } |
676 | |
677 | TEST(BoundsInference, GetPotentialHazardsLoopSplit) { |
678 | Tensor A = Compute("A" , {64, 64}, [](const VarHandle& i, const VarHandle& j) { |
679 | return i * j; |
680 | }); |
681 | |
682 | LoopNest l({A}); |
683 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
684 | ForPtr inner, tail; |
685 | |
686 | // Splitting with tail by something offset creates a tail which also writes to |
687 | // A. |
688 | ForPtr outer = l.getLoopStmtsFor(A)[0]; |
689 | // `outer` loop get transformed to the outer loop after splitting. |
690 | LoopNest::splitWithTail(outer, 5, &inner, &tail); |
691 | |
692 | using namespace analysis; |
693 | |
694 | MemDependencyChecker analyzer; |
695 | l.root_stmt()->accept(&analyzer); |
696 | |
697 | ASSERT_EQ( |
698 | HazardKind::WriteAfterWrite, getPotentialHazards(analyzer, outer, tail)); |
699 | } |
700 | |
701 | TEST(BoundsInference, HasConflictingOverlapSameBufferWithPartialOverlap) { |
702 | // Input IR: |
703 | // for (const auto j : c10::irange(10, 100)) { |
704 | // A[j] = 10 * j; |
705 | // } |
706 | // for (const auto k : c10::irange(10, 100)) { |
707 | // A[k-1] = 20 * k; |
708 | // } |
709 | BufHandle a_buf("A" , {200}, kInt); |
710 | VarHandle j("j" , kInt); |
711 | VarHandle k("k" , kInt); |
712 | auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
713 | auto forK = |
714 | For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); |
715 | auto par = Block::make({forJ, forK}); |
716 | |
717 | tensorexpr::analysis::MemDependencyChecker analyzer; |
718 | par->accept(&analyzer); |
719 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); |
720 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); |
721 | } |
722 | |
723 | TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlap) { |
724 | // Input IR: |
725 | // for (const auto j : c10::irange(10, 100)) { |
726 | // A[j] = 10 * j; |
727 | // } |
728 | // for (const auto k : c10::irange(10, 100)) { |
729 | // A[k] = 20 * k; |
730 | // } |
731 | BufHandle a_buf("A" , {200}, kInt); |
732 | VarHandle j("j" , kInt); |
733 | VarHandle k("k" , kInt); |
734 | auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
735 | auto forK = For::make(k, 10, 100, Store::make(a_buf, {k}, Mul::make(20, k))); |
736 | auto par = Block::make({forJ, forK}); |
737 | |
738 | tensorexpr::analysis::MemDependencyChecker analyzer; |
739 | par->accept(&analyzer); |
740 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); |
741 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); |
742 | } |
743 | |
744 | TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlapRAW) { |
745 | // Input IR: |
746 | // for (const auto j : c10::irange(10, 100)) { |
747 | // A[j] = 10 * j; |
748 | // } |
749 | // for (const auto k : c10::irange(10, 100)) { |
750 | // B[k] = A[k]; |
751 | // } |
752 | BufHandle a_buf("A" , {200}, kInt); |
753 | BufHandle b_buf("B" , {200}, kInt); |
754 | VarHandle j("j" , kInt); |
755 | VarHandle k("k" , kInt); |
756 | auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
757 | auto forK = |
758 | For::make(k, 10, 100, Store::make(b_buf, {k}, Load::make(a_buf, {k}))); |
759 | auto par = Block::make({forJ, forK}); |
760 | |
761 | tensorexpr::analysis::MemDependencyChecker analyzer; |
762 | par->accept(&analyzer); |
763 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); |
764 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); |
765 | } |
766 | |
767 | TEST(BoundsInference, HasConflictingOverlapSameBufferNotOverlapping) { |
768 | // Input IR: |
769 | // for (const auto j : c10::irange(10, 100)) { |
770 | // A[j] = 10 * j; |
771 | // } |
772 | // for (const auto k : c10::irange(10, 100)) { |
773 | // A[k+100] = 20 * k; |
774 | // } |
775 | BufHandle a_buf("A" , {200}, kInt); |
776 | VarHandle j("j" , kInt); |
777 | VarHandle k("k" , kInt); |
778 | auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
779 | auto forK = |
780 | For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(20, k))); |
781 | auto par = Block::make({forJ, forK}); |
782 | |
783 | tensorexpr::analysis::MemDependencyChecker analyzer; |
784 | par->accept(&analyzer); |
785 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); |
786 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); |
787 | } |
788 | |
789 | TEST(BoundsInference, HasConflictingOverlap2DBufferWithOverlap) { |
790 | // Input IR: |
791 | // for (const auto i : c10::irange(20)) { |
792 | // for (const auto j : c10::irange(100)) { |
793 | // A[i,j] = i * j * 500; |
794 | // } |
795 | // } |
796 | // for (const auto m : c10::irange(20)) { |
797 | // for (const auto n : c10::irange(50)) { |
798 | // A[m+1,n] = m + n * 100; |
799 | // } |
800 | // } |
801 | BufHandle a_buf("A" , {20, 100}, kInt); |
802 | BufHandle b_buf("B" , {20, 50}, kInt); |
803 | VarHandle i("i" , kInt); |
804 | VarHandle j("j" , kInt); |
805 | VarHandle m("m" , kInt); |
806 | VarHandle n("n" , kInt); |
807 | auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); |
808 | auto forJ = For::make(j, 0, 100, storeA1); |
809 | auto forI = For::make(i, 0, 20, forJ); |
810 | auto storeA2 = |
811 | Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))); |
812 | auto forN = For::make(n, 0, 50, storeA2); |
813 | auto forM = For::make(m, 0, 20, forN); |
814 | auto par = Block::make({forI, forM}); |
815 | |
816 | tensorexpr::analysis::MemDependencyChecker analyzer; |
817 | par->accept(&analyzer); |
818 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forI, forM)); |
819 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forM, forI)); |
820 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forN)); |
821 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forN, forJ)); |
822 | ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, storeA2)); |
823 | ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA2, storeA1)); |
824 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, storeA2)); |
825 | ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, forM)); |
826 | } |
827 | |
828 | TEST(BoundsInference, HasConflictingOverlap2DBufferWithNoOverlap) { |
829 | // Input IR: |
830 | // for (const auto i : c10::irange(20)) { |
831 | // for (const auto j : c10::irange(100)) { |
832 | // A[i,j] = i * j * 500; |
833 | // } |
834 | // } |
835 | // for (const auto m : c10::irange(20)) { |
836 | // for (const auto n : c10::irange(50)) { |
837 | // A[m+20,n+100] = m + n * 100; |
838 | // } |
839 | // } |
840 | BufHandle a_buf("A" , {20, 100}, kInt); |
841 | BufHandle b_buf("B" , {20, 50}, kInt); |
842 | VarHandle i("i" , kInt); |
843 | VarHandle j("j" , kInt); |
844 | VarHandle m("m" , kInt); |
845 | VarHandle n("n" , kInt); |
846 | auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); |
847 | auto forJ = For::make(j, 0, 100, storeA1); |
848 | auto forI = For::make(i, 0, 20, forJ); |
849 | auto storeA2 = |
850 | Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); |
851 | auto forN = For::make(n, 0, 50, storeA2); |
852 | auto forM = For::make(m, 0, 20, forN); |
853 | auto par = Block::make({forI, forM}); |
854 | |
855 | tensorexpr::analysis::MemDependencyChecker analyzer; |
856 | par->accept(&analyzer); |
857 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); |
858 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); |
859 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); |
860 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); |
861 | ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); |
862 | ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); |
863 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); |
864 | ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); |
865 | } |
866 | |
867 | TEST(BoundsInference, HasConflictingOverlapDifferentBuffers) { |
868 | // Input IR: |
869 | // for (const auto i : c10::irange(20)) { |
870 | // for (const auto j : c10::irange(100)) { |
871 | // A[i,j] = i * j * 500; |
872 | // } |
873 | // } |
874 | // for (const auto m : c10::irange(20)) { |
875 | // for (const auto n : c10::irange(50)) { |
876 | // B[m,n] = m + n * 100; |
877 | // } |
878 | // } |
879 | BufHandle a_buf("A" , {20, 100}, kInt); |
880 | BufHandle b_buf("B" , {20, 50}, kInt); |
881 | VarHandle i("i" , kInt); |
882 | VarHandle j("j" , kInt); |
883 | VarHandle m("m" , kInt); |
884 | VarHandle n("n" , kInt); |
885 | auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); |
886 | auto forJ = For::make(j, 0, 100, storeA1); |
887 | auto forI = For::make(i, 0, 20, forJ); |
888 | auto storeA2 = Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))); |
889 | auto forN = For::make(n, 0, 50, storeA2); |
890 | auto forM = For::make(m, 0, 20, forN); |
891 | auto par = Block::make({forI, forM}); |
892 | |
893 | tensorexpr::analysis::MemDependencyChecker analyzer; |
894 | par->accept(&analyzer); |
895 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); |
896 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); |
897 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); |
898 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); |
899 | ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); |
900 | ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); |
901 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); |
902 | ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); |
903 | } |
904 | |
905 | TEST(BoundsInference, HasConflictingOverlapDueToRAWDependence) { |
906 | // Input IR: |
907 | // for (const auto j : c10::irange(100)) { |
908 | // A[j] = 10 * j; |
909 | // } |
910 | // for (const auto k : c10::irange(100)) { |
911 | // B[k] = 20 * A[99-k]; |
912 | // } |
913 | BufHandle a_buf("A" , {100}, kInt); |
914 | BufHandle b_buf("B" , {100}, kInt); |
915 | VarHandle j("j" , kInt); |
916 | VarHandle k("k" , kInt); |
917 | auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
918 | auto forK = For::make( |
919 | k, |
920 | 0, |
921 | 100, |
922 | Store::make( |
923 | b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); |
924 | auto par = Block::make({forJ, forK}); |
925 | |
926 | tensorexpr::analysis::MemDependencyChecker analyzer; |
927 | par->accept(&analyzer); |
928 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); |
929 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); |
930 | } |
931 | |
932 | TEST(BoundsInference, HasConflictingOverlapDueToWARDependence) { |
933 | // Input IR: |
934 | // for (const auto k : c10::irange(100)) { |
935 | // B[k] = 20 * A[99-k]; |
936 | // } |
937 | // for (const auto j : c10::irange(100)) { |
938 | // A[j] = 10 * j; |
939 | // } |
940 | BufHandle a_buf("A" , {100}, kInt); |
941 | BufHandle b_buf("B" , {100}, kInt); |
942 | VarHandle j("j" , kInt); |
943 | VarHandle k("k" , kInt); |
944 | auto forK = For::make( |
945 | k, |
946 | 0, |
947 | 100, |
948 | Store::make( |
949 | b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); |
950 | auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
951 | auto par = Block::make({forK, forJ}); |
952 | |
953 | tensorexpr::analysis::MemDependencyChecker analyzer; |
954 | par->accept(&analyzer); |
955 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); |
956 | ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); |
957 | } |
958 | |
959 | TEST(BoundsInference, HasConflictingOverlapWithLoads) { |
960 | // Input IR: |
961 | // for (const auto k : c10::irange(10, 100)) { |
962 | // B[k] = 20 * A[99-k]; |
963 | // } |
964 | // for (const auto j : c10::irange(10, 100)) { |
965 | // C[j] = 10 * A[j]; |
966 | // } |
967 | BufHandle a_buf("A" , {100}, kInt); |
968 | BufHandle b_buf("B" , {100}, kInt); |
969 | BufHandle c_buf("C" , {100}, kInt); |
970 | VarHandle j("j" , kInt); |
971 | VarHandle k("k" , kInt); |
972 | auto forK = For::make( |
973 | k, |
974 | 10, |
975 | 100, |
976 | Store::make( |
977 | b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); |
978 | auto forJ = For::make( |
979 | j, |
980 | 10, |
981 | 100, |
982 | Store::make(c_buf, {j}, Mul::make(10, Load::make(a_buf, {j})))); |
983 | auto par = Block::make({forK, forJ}); |
984 | |
985 | tensorexpr::analysis::MemDependencyChecker analyzer; |
986 | par->accept(&analyzer); |
987 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); |
988 | ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); |
989 | } |
990 | |
991 | TEST(BoundsInference, IsOverlapping) { |
992 | // Input IR: |
993 | // for (const auto i : c10::irange(100)) { |
994 | // A[i] = i * 10; // storeA1 |
995 | // B[i] = A[99-i] * 20; // loadA1 |
996 | // C[i] = A[i + 100] * 10; // loadA2 |
997 | // A[i + 50] = i * 50; // storeA2 |
998 | // A[i + 150] = i * 150; // storeA3 |
999 | // } |
1000 | BufHandle a_buf("A" , {300}, kInt); |
1001 | BufHandle b_buf("B" , {100}, kInt); |
1002 | BufHandle c_buf("C" , {100}, kInt); |
1003 | VarHandle i("i" , kInt); |
1004 | auto storeA1 = Store::make(a_buf, {i}, i * 10); |
1005 | auto loadA1 = Load::make(a_buf, {ExprHandle(99) - i}); |
1006 | auto storeB = Store::make(b_buf, {i}, Mul::make(loadA1, 20)); |
1007 | auto loadA2 = Load::make(a_buf, {i + 100}); |
1008 | auto storeC = Store::make(c_buf, {i}, Mul::make(loadA2, 10)); |
1009 | auto storeA2 = Store::make(a_buf, {i + 50}, i * 50); |
1010 | auto storeA3 = Store::make(a_buf, {i + 150}, i * 150); |
1011 | auto forI = For::make( |
1012 | i, 0, 100, Block::make({storeA1, storeB, storeC, storeA2, storeA3})); |
1013 | tensorexpr::analysis::MemDependencyChecker analyzer; |
1014 | forI->accept(&analyzer); |
1015 | ASSERT_TRUE(isOverlapping(analyzer, storeA1, to<Load>(loadA1.node()))); |
1016 | ASSERT_FALSE(isOverlapping(analyzer, storeA1, to<Load>(loadA2.node()))); |
1017 | ASSERT_TRUE(isOverlapping(analyzer, storeA1, storeA2)); |
1018 | ASSERT_FALSE(isOverlapping(analyzer, storeA1, storeA3)); |
1019 | } |
1020 | |
1021 | } // namespace jit |
1022 | } // namespace torch |
1023 | |