1 | #include <gtest/gtest.h> |
2 | #include <test/cpp/tensorexpr/test_base.h> |
3 | |
4 | #include <c10/util/irange.h> |
5 | #include <test/cpp/tensorexpr/padded_buffer.h> |
6 | #include <torch/csrc/jit/tensorexpr/ir.h> |
7 | #include <torch/csrc/jit/tensorexpr/ir_printer.h> |
8 | #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
9 | #include <torch/csrc/jit/tensorexpr/llvm_codegen.h> |
10 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
11 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | |
16 | using namespace torch::jit::tensorexpr; |
17 | |
18 | extern void checkIR(StmtPtr s, const std::string& pattern); |
19 | |
20 | TEST(BufLiveRange, SingleRangeLine) { |
21 | VarHandle i("i" , kInt), j("j" , kInt); |
22 | BufHandle a("a" , {32}, kFloat); |
23 | BufHandle b("b" , {32, 32}, kFloat); |
24 | |
25 | // Construct Stmt: |
26 | // { |
27 | // for (int i = 0; i < 32; i++) { |
28 | // a[i] = 0; |
29 | // for (int j = 0; j < 32; j++) { |
30 | // a[i] = (a[i]) + (b[i, j]); |
31 | // } |
32 | // } |
33 | // } |
34 | |
35 | StorePtr aInit = Store::make(a, {i}, 0); |
36 | ExprHandle reduce = a.load({i}) + b.load({i, j}); |
37 | StorePtr aReduce = Store::make(a, {i}, reduce); |
38 | StmtPtr loop = |
39 | For::make(i, 0, 32, Block::make({aInit, For::make(j, 0, 32, aReduce)})); |
40 | |
41 | StmtPtr stmt = Block::make({loop}); |
42 | |
43 | auto range = BufLiveRange::liveRange(stmt, a.node()); |
44 | ASSERT_TRUE(std::get<0>(range) == 0); |
45 | ASSERT_TRUE(std::get<1>(range) == 0); |
46 | } |
47 | |
48 | TEST(BufLiveRange, MulRangeLine) { |
49 | VarHandle i("i" , kInt); |
50 | BufHandle a("a" , {32}, kFloat); |
51 | BufHandle b("b" , {32}, kFloat); |
52 | |
53 | // Construct Stmt: |
54 | // { |
55 | // for (int i = 0; i < 32; i++) { |
56 | // if (i<10 ? 1 : 0) { |
57 | // a[i] = i + i; |
58 | // b[i] = i * i; |
59 | // } |
60 | // } |
61 | // for (int i = 0; i < 32; i++) { |
62 | // if (i>10 ? 1 : 0) { |
63 | // a[i] = i * i; |
64 | // b[i] = i + i; |
65 | // } |
66 | // } |
67 | // } |
68 | |
69 | StorePtr aStore_1 = Store::make(a, {i}, i + i); |
70 | StorePtr bStore_1 = Store::make(b, {i}, i * i); |
71 | StmtPtr loop_1 = For::make( |
72 | i, 0, 32, Cond::make(i < 10, Block::make({aStore_1, bStore_1}), NULL)); |
73 | |
74 | StorePtr aStore_2 = Store::make(a, {i}, i * i); |
75 | StorePtr bStore_2 = Store::make(b, {i}, i + i); |
76 | StmtPtr loop_2 = For::make( |
77 | i, 0, 32, Cond::make(i > 10, Block::make({aStore_2, bStore_2}), NULL)); |
78 | |
79 | StmtPtr stmt = Block::make({loop_1, loop_2}); |
80 | |
81 | auto range_a = BufLiveRange::liveRange(stmt, a.node()); |
82 | ASSERT_TRUE(std::get<0>(range_a) == 0); |
83 | ASSERT_TRUE(std::get<1>(range_a) == 1); |
84 | |
85 | auto range_b = BufLiveRange::liveRange(stmt, b.node()); |
86 | ASSERT_TRUE(std::get<0>(range_b) == 0); |
87 | ASSERT_TRUE(std::get<1>(range_b) == 1); |
88 | } |
89 | |
TEST(MemPlanning, MemReuseWithTypeCast) {
  // Verifies that memory planning reuses an earlier intermediate buffer for a
  // later intermediate with the same dimensions but a *smaller* element type:
  // 'E' (quint8) is aliased onto 'gemm' (float) via a type-cast alias.
  int M = 4;
  int N = 4;
  int K = 4;

  BufHandle AP("A" , {M, K}, kFloat);
  BufHandle BP("B" , {K, N}, kFloat);

  // gemm[m, n] = sum_k A[m, k] * B[k, n]
  Tensor CT = Reduce(
      "gemm" ,
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  // relu[m, n] = gemm[m, n] < 0 ? 0 : gemm[m, n]
  Tensor DT =
      Compute("relu" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return CompareSelect::make(
            CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT);
      });
  // E[m, n] = quint8(relu + relu) -- same dims as 'gemm', narrower dtype.
  Tensor ET =
      Compute("E" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n));
      });
  // F is the external output; it simply copies E.
  Tensor FT =
      Compute("F" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n);
      });
  StmtPtr stmt =
      tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are
  // different: 'E' type quint8 < 'gemm' type float. We'll reuse 'gemm' for 'E'
  // with typecasting.
  //{
  //  for (int i = 0; i < 4; i++) {
  //    for (int i_1 = 0; i_1 < 4; i_1++) {
  //      gemm[i, i_1] = float(0);
  //      for (int i_2 = 0; i_2 < 4; i_2++) {
  //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2,
  //        i_1]), reduce_args={i_2});
  //      }
  //    }
  //  }
  //  for (int i_3 = 0; i_3 < 4; i_3++) {
  //    for (int i_4 = 0; i_4 < 4; i_4++) {
  //      relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ? 0.f : (gemm[i_3, i_4]);
  //    }
  //  }
  //  for (int i_5 = 0; i_5 < 4; i_5++) {
  //    for (int i_6 = 0; i_6 < 4; i_6++) {
  //      E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6]));
  //    }
  //  }
  //  for (int i_7 = 0; i_7 < 4; i_7++) {
  //    for (int i_8 = 0; i_8 < 4; i_8++) {
  //      F[i_7, i_8] = E[i_7, i_8];
  //    }
  //  }
  //}

  LoopNest l(stmt, {FT.buf()});
  l.prepareForCodegen();
  SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  // Only two allocations are expected: 'E' is aliased onto 'gemm' instead of
  // getting its own Allocate/Free pair.
  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
# CHECK: Alias(E,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR" );

  // Random inputs; o1/o2 hold the interpreter and LLVM results respectively.
  PaddedBuffer<float> a_v(M, K, "a" );
  PaddedBuffer<float> b_v(K, N, "b" );
  PaddedBuffer<uint8_t> o1(M, N, "e_before" );
  PaddedBuffer<uint8_t> o2(M, N, "e_after" );

  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      a_v(m, k) = at::randn({1}).item().to<float>();
    }
  }

  for (const auto k : c10::irange(K)) {
    for (const auto n : c10::irange(N)) {
      b_v(k, n) = at::randn({1}).item().to<float>();
    }
  }

  cg.call({a_v, b_v, o1});

#ifdef TORCH_ENABLE_LLVM
  // The LLVM backend must apply the same planning and compute the same values.
  LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
# CHECK: Alias(E,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR" );

  cg_llvm.call({a_v, b_v, o2});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(o1, o2, 1e-5);
#endif
}
200 | |
201 | TEST(MemPlanning, NoMemReuseForLargerType) { |
202 | int M = 4; |
203 | int N = 4; |
204 | int K = 4; |
205 | |
206 | BufHandle AP("A" , {M, K}, kShort); |
207 | BufHandle BP("B" , {K, N}, kShort); |
208 | |
209 | Tensor CT = Reduce( |
210 | "gemm" , |
211 | {M, N}, |
212 | Sum(), |
213 | [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { |
214 | return AP.load(m, k) * BP.load(k, n); |
215 | }, |
216 | {K}); |
217 | auto zero = Cast::make(CT.buf()->dtype(), 0); |
218 | Tensor DT = |
219 | Compute("relu" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
220 | return CompareSelect::make( |
221 | CT.load(m, n), zero, zero, CT.load(m, n), kLT); |
222 | }); |
223 | Tensor ET = |
224 | Compute("E" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
225 | return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n)); |
226 | }); |
227 | Tensor FT = |
228 | Compute("F" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
229 | return ET.load(m, n); |
230 | }); |
231 | StmtPtr stmt = |
232 | tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); |
233 | |
234 | // Constructed stmt: |
235 | // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], |
236 | // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are |
237 | // different: 'E' type float > 'gemm' type int16. We won't reuse 'gemm' for |
238 | // 'E'. |
239 | //{ |
240 | // for (int i = 0; i < 4; i++) { |
241 | // for (int i_1 = 0; i_1 < 4; i_1++) { |
242 | // gemm[i, i_1] = int16_t(0); |
243 | // for (int i_2 = 0; i_2 < 4; i_2++) { |
244 | // gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, |
245 | // i_1]), reduce_args={i_2}); |
246 | // } |
247 | // } |
248 | // } |
249 | // for (int i_3 = 0; i_3 < 4; i_3++) { |
250 | // for (int i_4 = 0; i_4 < 4; i_4++) { |
251 | // relu[i_3, i_4] = (gemm[i_3, i_4])<int16_t(0) ? int16_t(0) : (gemm[i_3, |
252 | // i_4]); |
253 | // } |
254 | // } |
255 | // for (int i_5 = 0; i_5 < 4; i_5++) { |
256 | // for (int i_6 = 0; i_6 < 4; i_6++) { |
257 | // E[i_5, i_6] = float((relu[i_5, i_6]) + (relu[i_5, i_6])); |
258 | // } |
259 | // } |
260 | // for (int i_7 = 0; i_7 < 4; i_7++) { |
261 | // for (int i_8 = 0; i_8 < 4; i_8++) { |
262 | // F[i_7, i_8] = E[i_7, i_8]; |
263 | // } |
264 | // } |
265 | //} |
266 | |
267 | LoopNest l(stmt, {FT.buf()}); |
268 | l.prepareForCodegen(); |
269 | SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT.buf()}); |
270 | |
271 | checkIR(cg.stmt(), R"IR( |
272 | # CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4] |
273 | # CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4] |
274 | # CHECK: Allocate(E); // dtype=float, dims=[4, 4] |
275 | # CHECK: Free(E); |
276 | # CHECK: Free(relu); |
277 | # CHECK: Free(gemm))IR" ); |
278 | |
279 | PaddedBuffer<short> a_v(M, K, "a" ); |
280 | PaddedBuffer<short> b_v(K, N, "b" ); |
281 | PaddedBuffer<float> o1(M, N, "e_before" ); |
282 | PaddedBuffer<float> o2(M, N, "e_after" ); |
283 | |
284 | for (const auto m : c10::irange(M)) { |
285 | for (const auto k : c10::irange(K)) { |
286 | a_v(m, k) = at::randn({1}).item().to<float>(); |
287 | } |
288 | } |
289 | |
290 | for (const auto k : c10::irange(K)) { |
291 | for (const auto n : c10::irange(N)) { |
292 | b_v(k, n) = at::randn({1}).item().to<float>(); |
293 | } |
294 | } |
295 | |
296 | cg.call({a_v, b_v, o1}); |
297 | |
298 | #ifdef TORCH_ENABLE_LLVM |
299 | LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT}); |
300 | |
301 | checkIR(cg_llvm.stmt(), R"IR( |
302 | # CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4] |
303 | # CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4] |
304 | # CHECK: Allocate(E); // dtype=float, dims=[4, 4] |
305 | # CHECK: Free(E); |
306 | # CHECK: Free(relu); |
307 | # CHECK: Free(gemm))IR" ); |
308 | |
309 | cg_llvm.call({a_v, b_v, o2}); |
310 | |
311 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
312 | ExpectAllNear(o1, o2, 1e-5); |
313 | #endif |
314 | } |
315 | |
316 | TEST(MemPlanning, SameBufSizeMemReuse) { |
317 | int M = 1024; |
318 | int N = 1024; |
319 | int K = 2048; |
320 | |
321 | BufHandle AP("A" , {M, K}, kFloat); |
322 | BufHandle BP("B" , {K, N}, kFloat); |
323 | |
324 | Tensor CT = Reduce( |
325 | "gemm" , |
326 | {M, N}, |
327 | Sum(), |
328 | [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { |
329 | return AP.load(m, k) * BP.load(k, n); |
330 | }, |
331 | {K}); |
332 | Tensor DT = |
333 | Compute("relu" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
334 | auto zero = Cast::make(CT.buf()->dtype(), 0); |
335 | return CompareSelect::make( |
336 | CT.load(m, n), zero, zero, CT.load(m, n), kLT); |
337 | }); |
338 | Tensor ET = |
339 | Compute("add" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
340 | return DT.load(m, n) + DT.load(m, n); |
341 | }); |
342 | Tensor FT = |
343 | Compute("mul" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
344 | return ET.load(m, n) * ET.load(m, n); |
345 | }); |
346 | auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); |
347 | |
348 | // Constructed stmt: |
349 | // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], |
350 | // add [2, 3] Buffer 'gemm' and 'add' are the same size; we'll reuse 'gemm' |
351 | // for 'add'. |
352 | //{ |
353 | // for (int M = 0; M < 1024; M++) { |
354 | // for (int N = 0; N < 1024; N++) { |
355 | // gemm[M, N] = float(0); |
356 | // for (int K = 0; K < 2048; K++) { |
357 | // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), |
358 | // reduce_args={K}); |
359 | // } |
360 | // } |
361 | // } |
362 | // for (int M_1 = 0; M_1 < 1024; M_1++) { |
363 | // for (int N_1 = 0; N_1 < 1024; N_1++) { |
364 | // relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1, |
365 | // N_1]); |
366 | // } |
367 | // } |
368 | // for (int M_2 = 0; M_2 < 1024; M_2++) { |
369 | // for (int N_2 = 0; N_2 < 1024; N_2++) { |
370 | // add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]); |
371 | // } |
372 | // } |
373 | // for (int M_3 = 0; M_3 < 1024; M_3++) { |
374 | // for (int N_3 = 0; N_3 < 1024; N_3++) { |
375 | // mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]); |
376 | // } |
377 | // } |
378 | //} |
379 | |
380 | SimpleIREvaluator cg(stmt, {AP, BP, FT}); |
381 | |
382 | checkIR(cg.stmt(), R"IR( |
383 | # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024] |
384 | # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024] |
385 | # CHECK: Alias(add,gemm); |
386 | # CHECK: Free(relu); |
387 | # CHECK: Free(gemm))IR" ); |
388 | |
389 | #ifdef TORCH_ENABLE_LLVM |
390 | LoopNest loop(Stmt::clone(stmt), {FT.buf()}); |
391 | loop.prepareForCodegen(); |
392 | LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT}); |
393 | |
394 | checkIR(cg_llvm.stmt(), R"IR( |
395 | # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024] |
396 | # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024] |
397 | # CHECK: Alias(add,gemm); |
398 | # CHECK: Free(relu); |
399 | # CHECK: Free(gemm))IR" ); |
400 | #endif |
401 | } |
402 | |
403 | TEST(MemPlanning, SameBufSizeMultiMemReuses) { |
404 | int M = 1024; |
405 | int N = 1024; |
406 | int K = 2048; |
407 | |
408 | BufHandle AP("A" , {M, K}, kFloat); |
409 | BufHandle BP("B" , {K, N}, kFloat); |
410 | |
411 | Tensor CT = Reduce( |
412 | "gemm" , |
413 | {M, N}, |
414 | Sum(), |
415 | [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { |
416 | return AP.load(m, k) * BP.load(k, n); |
417 | }, |
418 | {K}); |
419 | Tensor DT = |
420 | Compute("relu" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
421 | auto zero = Cast::make(CT.buf()->dtype(), 0); |
422 | return CompareSelect::make( |
423 | CT.load(m, n), zero, zero, CT.load(m, n), kLT); |
424 | }); |
425 | Tensor ET = |
426 | Compute("add" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
427 | return DT.load(m, n) + DT.load(m, n); |
428 | }); |
429 | Tensor FT = |
430 | Compute("mul" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
431 | return ET.load(m, n) * ET.load(m, n); |
432 | }); |
433 | Tensor GT = |
434 | Compute("sub" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
435 | return FT.load(m, n) - ET.load(m, n); |
436 | }); |
437 | |
438 | auto stmt = |
439 | Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt()}); |
440 | |
441 | // Constructed stmt: |
442 | // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], |
443 | // add [2, 3], mul [3, 4] Buffer 'gemm', 'relu, ''add' and 'mul' are the same |
444 | // size; we'll reuse 'gemm' for 'add', and reuse 'relu' for 'mul' |
445 | //{ |
446 | // for (int M = 0; M < 1024; M++) { |
447 | // for (int N = 0; N < 1024; N++) { |
448 | // gemm[M, N] = float(0); |
449 | // for (int K = 0; K < 2048; K++) { |
450 | // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), |
451 | // reduce_args={K}); |
452 | // } |
453 | // } |
454 | // } |
455 | // for (int M_1 = 0; M_1 < 1024; M_1++) { |
456 | // for (int N_1 = 0; N_1 < 1024; N_1++) { |
457 | // relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1, |
458 | // N_1]); |
459 | // } |
460 | // } |
461 | // for (int M_2 = 0; M_2 < 1024; M_2++) { |
462 | // for (int N_2 = 0; N_2 < 1024; N_2++) { |
463 | // add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]); |
464 | // } |
465 | // } |
466 | // for (int M_3 = 0; M_3 < 1024; M_3++) { |
467 | // for (int N_3 = 0; N_3 < 1024; N_3++) { |
468 | // mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]); |
469 | // } |
470 | // } |
471 | // for (int M_4 = 0; M_4 < 1024; M_4++) { |
472 | // for (int N_4 = 0; N_4 < 1024; N_4++) { |
473 | // sub[M_4, N_4] = (mul[M_4, N_4]) - (add[M_4, N_4]); |
474 | // } |
475 | // } |
476 | //} |
477 | |
478 | SimpleIREvaluator cg(stmt, {AP, BP, GT}); |
479 | |
480 | checkIR(cg.stmt(), R"IR( |
481 | # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024] |
482 | # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024] |
483 | # CHECK: Alias(add,gemm); |
484 | # CHECK: Alias(mul,relu); |
485 | # CHECK: Free(relu); |
486 | # CHECK: Free(gemm))IR" ); |
487 | |
488 | #ifdef TORCH_ENABLE_LLVM |
489 | LoopNest loop(Stmt::clone(stmt), {FT.buf()}); |
490 | loop.prepareForCodegen(); |
491 | LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT}); |
492 | |
493 | checkIR(cg_llvm.stmt(), R"IR( |
494 | # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024] |
495 | # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024] |
496 | # CHECK: Alias(add,gemm); |
497 | # CHECK: Alias(mul,relu); |
498 | # CHECK: Free(relu); |
499 | # CHECK: Free(gemm))IR" ); |
500 | #endif |
501 | } |
502 | |
503 | TEST(MemPlanning, SameBufSizeMultiMemReusesOfOneBuf) { |
504 | int M = 1024; |
505 | int N = 1024; |
506 | int K = 2048; |
507 | |
508 | BufHandle AP("A" , {M, K}, kFloat); |
509 | BufHandle BP("B" , {K, N}, kFloat); |
510 | |
511 | Tensor CT = Reduce( |
512 | "gemm" , |
513 | {M, N}, |
514 | Sum(), |
515 | [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { |
516 | return AP.load(m, k) * BP.load(k, n); |
517 | }, |
518 | {K}); |
519 | Tensor DT = |
520 | Compute("relu" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
521 | auto zero = Cast::make(CT.buf()->dtype(), 0); |
522 | return CompareSelect::make( |
523 | CT.load(m, n), zero, zero, CT.load(m, n), kLT); |
524 | }); |
525 | Tensor ET = |
526 | Compute("add" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
527 | return DT.load(m, n) + DT.load(m, n); |
528 | }); |
529 | Tensor FT = |
530 | Compute("mul" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
531 | return ET.load(m, n) * ET.load(m, n); |
532 | }); |
533 | Tensor GT = |
534 | Compute("sub" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
535 | return FT.load(m, n) - 1; |
536 | }); |
537 | Tensor HT = |
538 | Compute("div" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
539 | return GT.load(m, n) / 2; |
540 | }); |
541 | |
542 | auto stmt = Block::make( |
543 | {CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt(), HT.stmt()}); |
544 | |
545 | // Constructed stmt: |
546 | // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], |
547 | // add [2, 3], mul [3, 4], sub [4, 5] Buffer 'gemm', 'relu, ''add', 'mul' and |
548 | // 'sub' are the same size; we'll reuse 'gemm' for 'add', reuse 'relu' for |
549 | // 'mul', and reuse 'gemm' for 'sub'. |
550 | //{ |
551 | // for (int M = 0; M < 1024; M++) { |
552 | // for (int N = 0; N < 1024; N++) { |
553 | // gemm[M, N] = float(0); |
554 | // for (int K = 0; K < 2048; K++) { |
555 | // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), |
556 | // reduce_args={K}); |
557 | // } |
558 | // } |
559 | // } |
560 | // for (int M_1 = 0; M_1 < 1024; M_1++) { |
561 | // for (int N_1 = 0; N_1 < 1024; N_1++) { |
562 | // relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1, |
563 | // N_1]); |
564 | // } |
565 | // } |
566 | // for (int M_2 = 0; M_2 < 1024; M_2++) { |
567 | // for (int N_2 = 0; N_2 < 1024; N_2++) { |
568 | // add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]); |
569 | // } |
570 | // } |
571 | // for (int M_3 = 0; M_3 < 1024; M_3++) { |
572 | // for (int N_3 = 0; N_3 < 1024; N_3++) { |
573 | // mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]); |
574 | // } |
575 | // } |
576 | // for (int M_4 = 0; M_4 < 1024; M_4++) { |
577 | // for (int N_4 = 0; N_4 < 1024; N_4++) { |
578 | // sub[M_4, N_4] = (mul[M_4, N_4]) - float(1); |
579 | // } |
580 | // } |
581 | // for (int M_5 = 0; M_5 < 1024; M_5++) { |
582 | // for (int N_5 = 0; N_5 < 1024; N_5++) { |
583 | // div[M_5, N_5] = (sub[M_5, N_5]) / float(2); |
584 | // } |
585 | // } |
586 | //} |
587 | |
588 | SimpleIREvaluator cg(stmt, {AP, BP, HT}); |
589 | |
590 | checkIR(cg.stmt(), R"IR( |
591 | # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024] |
592 | # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024] |
593 | # CHECK: Alias(add,gemm); |
594 | # CHECK: Alias(mul,relu); |
595 | # CHECK: Alias(sub,gemm); |
596 | # CHECK: Free(relu); |
597 | # CHECK: Free(gemm))IR" ); |
598 | |
599 | #ifdef TORCH_ENABLE_LLVM |
600 | LoopNest loop(Stmt::clone(stmt), {FT.buf()}); |
601 | loop.prepareForCodegen(); |
602 | LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT}); |
603 | |
604 | checkIR(cg_llvm.stmt(), R"IR( |
605 | # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024] |
606 | # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024] |
607 | # CHECK: Alias(add,gemm); |
608 | # CHECK: Alias(mul,relu); |
609 | # CHECK: Alias(sub,gemm); |
610 | # CHECK: Free(relu); |
611 | # CHECK: Free(gemm))IR" ); |
612 | #endif |
613 | } |
614 | |
TEST(MemPlanning, SmallerBufSizeNonMemReuse) {
  // Verifies that a dead buffer is NOT reused for a later intermediate that is
  // strictly larger: 'add' is 2048x2048 while 'gemm' is only 1024x1024, so
  // 'add' must get its own allocation instead of aliasing 'gemm'.
  int M = 1024;
  int N = 1024;
  int K = 2048;

  BufHandle AP("A" , {M, K}, kFloat);
  BufHandle BP("B" , {K, N}, kFloat);

  // gemm[m, n] = sum_k A[m, k] * B[k, n]
  Tensor CT = Reduce(
      "gemm" ,
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  // relu = max(gemm, 0), comparing against a zero of gemm's dtype.
  Tensor DT =
      Compute("relu" , {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  // add is 2x larger in each dimension; each relu element is read 4 times.
  Tensor ET = Compute(
      "add" , {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) {
        return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2);
      });
  // mul = add * add  (external output, also 2048x2048)
  Tensor FT = Compute(
      "mul" , {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) {
        return ET.load(fm, fn) * ET.load(fm, fn);
      });
  auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3] We do not reuse buffer 'gemm' for 'add' because the size of
  // buffer 'gemm' is smaller.
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //        reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1,
  //      N_1]);
  //    }
  //  }
  //  for (int EM = 0; EM < 2048; EM++) {
  //    for (int EN = 0; EN < 2048; EN++) {
  //      add[EM, EN] = (relu[EM / 2, EN / 2]) + (relu[EM / 2, EN / 2]);
  //    }
  //  }
  //  for (int FM = 0; FM < 2048; FM++) {
  //    for (int FN = 0; FN < 2048; FN++) {
  //      mul[FM, FN] = (add[FM, FN]) * (add[FM, FN]);
  //    }
  //  }
  //}
  //

  SimpleIREvaluator cg(stmt, {AP, BP, FT});

  // 'add' must be allocated on its own; the CHECK-NOT guards against aliasing.
  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK-NOT: Alias(add,gemm);
# CHECK: Allocate(add); // dtype=float, dims=[2048, 2048]
# CHECK: Free(add);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR" );

#ifdef TORCH_ENABLE_LLVM
  // The LLVM backend is expected to plan memory identically.
  LoopNest loop(Stmt::clone(stmt), {FT.buf()});
  loop.prepareForCodegen();
  LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK-NOT: Alias(add,gemm);
# CHECK: Allocate(add); // dtype=float, dims=[2048, 2048]
# CHECK: Free(add);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR" );
#endif
}
706 | |
707 | } // namespace jit |
708 | } // namespace torch |
709 | |