#include <gtest/gtest.h>
#include <test/cpp/tensorexpr/test_base.h>

#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

extern void checkIR(StmtPtr s, const std::string& pattern);
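// checkIR is defined in another tensorexpr test file; it prints the statement
// and matches the FileCheck-style "# CHECK" directives against that output.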

TEST(BufLiveRange, SingleRangeLine) {
  VarHandle i("i", kInt), j("j", kInt);
  BufHandle a("a", {32}, kFloat);
  BufHandle b("b", {32, 32}, kFloat);

  // Construct Stmt:
  // {
  //   for (int i = 0; i < 32; i++) {
  //     a[i] = 0;
  //     for (int j = 0; j < 32; j++) {
  //       a[i] = (a[i]) + (b[i, j]);
  //     }
  //   }
  // }

  StorePtr aInit = Store::make(a, {i}, 0);
  ExprHandle reduce = a.load({i}) + b.load({i, j});
  StorePtr aReduce = Store::make(a, {i}, reduce);
  StmtPtr loop =
      For::make(i, 0, 32, Block::make({aInit, For::make(j, 0, 32, aReduce)}));

  StmtPtr stmt = Block::make({loop});

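  // BufLiveRange::liveRange reports the first and last top-level statement
  // positions in the given block at which the buffer is accessed. All accesses
  // to 'a' happen inside the single top-level loop, so the range is [0, 0].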
  auto range = BufLiveRange::liveRange(stmt, a.node());
  ASSERT_TRUE(std::get<0>(range) == 0);
  ASSERT_TRUE(std::get<1>(range) == 0);
}

TEST(BufLiveRange, MulRangeLine) {
  VarHandle i("i", kInt);
  BufHandle a("a", {32}, kFloat);
  BufHandle b("b", {32}, kFloat);

  // Construct Stmt:
  // {
  //   for (int i = 0; i < 32; i++) {
  //     if (i<10 ? 1 : 0) {
  //       a[i] = i + i;
  //       b[i] = i * i;
  //     }
  //   }
  //   for (int i = 0; i < 32; i++) {
  //     if (i>10 ? 1 : 0) {
  //       a[i] = i * i;
  //       b[i] = i + i;
  //     }
  //   }
  // }

  StorePtr aStore_1 = Store::make(a, {i}, i + i);
  StorePtr bStore_1 = Store::make(b, {i}, i * i);
  StmtPtr loop_1 = For::make(
      i, 0, 32, Cond::make(i < 10, Block::make({aStore_1, bStore_1}), nullptr));

  StorePtr aStore_2 = Store::make(a, {i}, i * i);
  StorePtr bStore_2 = Store::make(b, {i}, i + i);
  StmtPtr loop_2 = For::make(
      i, 0, 32, Cond::make(i > 10, Block::make({aStore_2, bStore_2}), nullptr));

  StmtPtr stmt = Block::make({loop_1, loop_2});

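  // Both 'a' and 'b' are written in each of the two top-level loops, so their
  // live ranges cover top-level statements 0 through 1.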
  auto range_a = BufLiveRange::liveRange(stmt, a.node());
  ASSERT_TRUE(std::get<0>(range_a) == 0);
  ASSERT_TRUE(std::get<1>(range_a) == 1);

  auto range_b = BufLiveRange::liveRange(stmt, b.node());
  ASSERT_TRUE(std::get<0>(range_b) == 0);
  ASSERT_TRUE(std::get<1>(range_b) == 1);
}

TEST(MemPlanning, MemReuseWithTypeCast) {
  int M = 4;
  int N = 4;
  int K = 4;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return CompareSelect::make(
            CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n));
      });
  Tensor FT =
      Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n);
      });
  StmtPtr stmt =
      tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // E [2, 3]. 'gemm' and 'E' have the same dimensions but different types:
  // 'E' is quint8, which is smaller than the float type of 'gemm', so 'gemm'
  // can be reused for 'E' with a type cast.
  //{
  //  for (int i = 0; i < 4; i++) {
  //    for (int i_1 = 0; i_1 < 4; i_1++) {
  //      gemm[i, i_1] = float(0);
  //      for (int i_2 = 0; i_2 < 4; i_2++) {
  //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, i_1]),
  //          reduce_args={i_2});
  //      }
  //    }
  //  }
  //  for (int i_3 = 0; i_3 < 4; i_3++) {
  //    for (int i_4 = 0; i_4 < 4; i_4++) {
  //      relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ? 0.f : (gemm[i_3, i_4]);
  //    }
  //  }
  //  for (int i_5 = 0; i_5 < 4; i_5++) {
  //    for (int i_6 = 0; i_6 < 4; i_6++) {
  //      E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6]));
  //    }
  //  }
  //  for (int i_7 = 0; i_7 < 4; i_7++) {
  //    for (int i_8 = 0; i_8 < 4; i_8++) {
  //      F[i_7, i_8] = E[i_7, i_8];
  //    }
  //  }
  //}

  LoopNest l(stmt, {FT.buf()});
  l.prepareForCodegen();
  SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
# CHECK: Alias(E,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
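  // Note there is no Allocate(E): 'E' is materialized into the storage that
  // was allocated for 'gemm'.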

  PaddedBuffer<float> a_v(M, K, "a");
  PaddedBuffer<float> b_v(K, N, "b");
  PaddedBuffer<uint8_t> o1(M, N, "e_before");
  PaddedBuffer<uint8_t> o2(M, N, "e_after");

  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      a_v(m, k) = at::randn({1}).item().to<float>();
    }
  }

  for (const auto k : c10::irange(K)) {
    for (const auto n : c10::irange(N)) {
      b_v(k, n) = at::randn({1}).item().to<float>();
    }
  }

  cg.call({a_v, b_v, o1});

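  // When LLVM is available, compile the same statement natively and check that
  // both the planned IR and the computed results match the interpreter run.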
#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
# CHECK: Alias(E,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

  cg_llvm.call({a_v, b_v, o2});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(o1, o2, 1e-5);
#endif
}

TEST(MemPlanning, NoMemReuseForLargerType) {
  int M = 4;
  int N = 4;
  int K = 4;

  BufHandle AP("A", {M, K}, kShort);
  BufHandle BP("B", {K, N}, kShort);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  auto zero = Cast::make(CT.buf()->dtype(), 0);
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n));
      });
  Tensor FT =
      Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n);
      });
  StmtPtr stmt =
      tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // E [2, 3]. 'gemm' and 'E' have the same dimensions but different types:
  // 'E' is float, which is larger than the int16 type of 'gemm', so 'gemm' is
  // not reused for 'E'.
  //{
  //  for (int i = 0; i < 4; i++) {
  //    for (int i_1 = 0; i_1 < 4; i_1++) {
  //      gemm[i, i_1] = int16_t(0);
  //      for (int i_2 = 0; i_2 < 4; i_2++) {
  //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, i_1]),
  //          reduce_args={i_2});
  //      }
  //    }
  //  }
  //  for (int i_3 = 0; i_3 < 4; i_3++) {
  //    for (int i_4 = 0; i_4 < 4; i_4++) {
  //      relu[i_3, i_4] = (gemm[i_3, i_4])<int16_t(0) ? int16_t(0)
  //        : (gemm[i_3, i_4]);
  //    }
  //  }
  //  for (int i_5 = 0; i_5 < 4; i_5++) {
  //    for (int i_6 = 0; i_6 < 4; i_6++) {
  //      E[i_5, i_6] = float((relu[i_5, i_6]) + (relu[i_5, i_6]));
  //    }
  //  }
  //  for (int i_7 = 0; i_7 < 4; i_7++) {
  //    for (int i_8 = 0; i_8 < 4; i_8++) {
  //      F[i_7, i_8] = E[i_7, i_8];
  //    }
  //  }
  //}

  LoopNest l(stmt, {FT.buf()});
  l.prepareForCodegen();
  SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(E); // dtype=float, dims=[4, 4]
# CHECK: Free(E);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
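  // 'E' is wider than the dead 'gemm' buffer, so it gets its own allocation
  // instead of an Alias directive.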

  PaddedBuffer<short> a_v(M, K, "a");
  PaddedBuffer<short> b_v(K, N, "b");
  PaddedBuffer<float> o1(M, N, "e_before");
  PaddedBuffer<float> o2(M, N, "e_after");

  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      a_v(m, k) = at::randn({1}).item().to<float>();
    }
  }

  for (const auto k : c10::irange(K)) {
    for (const auto n : c10::irange(N)) {
      b_v(k, n) = at::randn({1}).item().to<float>();
    }
  }

  cg.call({a_v, b_v, o1});

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(E); // dtype=float, dims=[4, 4]
# CHECK: Free(E);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

  cg_llvm.call({a_v, b_v, o2});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(o1, o2, 1e-5);
#endif
}

TEST(MemPlanning, SameBufSizeMemReuse) {
  int M = 1024;
  int N = 1024;
  int K = 2048;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return DT.load(m, n) + DT.load(m, n);
      });
  Tensor FT =
      Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n) * ET.load(m, n);
      });
  auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3]. Buffers 'gemm' and 'add' are the same size; we reuse 'gemm'
  // for 'add'.
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //          reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0)
  //        : (gemm[M_1, N_1]);
  //    }
  //  }
  //  for (int M_2 = 0; M_2 < 1024; M_2++) {
  //    for (int N_2 = 0; N_2 < 1024; N_2++) {
  //      add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]);
  //    }
  //  }
  //  for (int M_3 = 0; M_3 < 1024; M_3++) {
  //    for (int N_3 = 0; N_3 < 1024; N_3++) {
  //      mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]);
  //    }
  //  }
  //}

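  // The interpreter is built directly from the raw statement, so the codegen
  // itself inserts the Allocate/Alias/Free directives checked below; the LLVM
  // path at the end first runs the statement through prepareForCodegen.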
  SimpleIREvaluator cg(stmt, {AP, BP, FT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
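  // 'add' becomes live only after 'gemm' is dead, and the two buffers have the
  // same size and dtype, so 'add' aliases 'gemm' instead of being allocated.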

#ifdef TORCH_ENABLE_LLVM
  LoopNest loop(Stmt::clone(stmt), {FT.buf()});
  loop.prepareForCodegen();
  LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
#endif
}

TEST(MemPlanning, SameBufSizeMultiMemReuses) {
  int M = 1024;
  int N = 1024;
  int K = 2048;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return DT.load(m, n) + DT.load(m, n);
      });
  Tensor FT =
      Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n) * ET.load(m, n);
      });
  Tensor GT =
      Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return FT.load(m, n) - ET.load(m, n);
      });

  auto stmt =
      Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3], mul [3, 4]. Buffers 'gemm', 'relu', 'add' and 'mul' are the
  // same size; we reuse 'gemm' for 'add' and reuse 'relu' for 'mul'.
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //          reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0)
  //        : (gemm[M_1, N_1]);
  //    }
  //  }
  //  for (int M_2 = 0; M_2 < 1024; M_2++) {
  //    for (int N_2 = 0; N_2 < 1024; N_2++) {
  //      add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]);
  //    }
  //  }
  //  for (int M_3 = 0; M_3 < 1024; M_3++) {
  //    for (int N_3 = 0; N_3 < 1024; N_3++) {
  //      mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]);
  //    }
  //  }
  //  for (int M_4 = 0; M_4 < 1024; M_4++) {
  //    for (int N_4 = 0; N_4 < 1024; N_4++) {
  //      sub[M_4, N_4] = (mul[M_4, N_4]) - (add[M_4, N_4]);
  //    }
  //  }
  //}

  SimpleIREvaluator cg(stmt, {AP, BP, GT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Alias(mul,relu);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
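  // Four intermediates share two allocations: 'add' takes over the storage of
  // 'gemm', and 'mul' takes over the storage of 'relu'.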

#ifdef TORCH_ENABLE_LLVM
  LoopNest loop(Stmt::clone(stmt), {GT.buf()});
  loop.prepareForCodegen();
  LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, GT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Alias(mul,relu);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
#endif
}

TEST(MemPlanning, SameBufSizeMultiMemReusesOfOneBuf) {
  int M = 1024;
  int N = 1024;
  int K = 2048;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return DT.load(m, n) + DT.load(m, n);
      });
  Tensor FT =
      Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n) * ET.load(m, n);
      });
  Tensor GT =
      Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return FT.load(m, n) - 1;
      });
  Tensor HT =
      Compute("div", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return GT.load(m, n) / 2;
      });

  auto stmt = Block::make(
      {CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt(), HT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3], mul [3, 4], sub [4, 5]. Buffers 'gemm', 'relu', 'add', 'mul'
  // and 'sub' are the same size; we reuse 'gemm' for 'add', reuse 'relu' for
  // 'mul', and reuse 'gemm' again for 'sub'.
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //          reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0)
  //        : (gemm[M_1, N_1]);
  //    }
  //  }
  //  for (int M_2 = 0; M_2 < 1024; M_2++) {
  //    for (int N_2 = 0; N_2 < 1024; N_2++) {
  //      add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]);
  //    }
  //  }
  //  for (int M_3 = 0; M_3 < 1024; M_3++) {
  //    for (int N_3 = 0; N_3 < 1024; N_3++) {
  //      mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]);
  //    }
  //  }
  //  for (int M_4 = 0; M_4 < 1024; M_4++) {
  //    for (int N_4 = 0; N_4 < 1024; N_4++) {
  //      sub[M_4, N_4] = (mul[M_4, N_4]) - float(1);
  //    }
  //  }
  //  for (int M_5 = 0; M_5 < 1024; M_5++) {
  //    for (int N_5 = 0; N_5 < 1024; N_5++) {
  //      div[M_5, N_5] = (sub[M_5, N_5]) / float(2);
  //    }
  //  }
  //}

  SimpleIREvaluator cg(stmt, {AP, BP, HT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Alias(mul,relu);
# CHECK: Alias(sub,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
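  // The 'gemm' allocation is reused twice: first for 'add', and again for
  // 'sub' once 'add' is no longer live.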

#ifdef TORCH_ENABLE_LLVM
  LoopNest loop(Stmt::clone(stmt), {HT.buf()});
  loop.prepareForCodegen();
  LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, HT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Alias(mul,relu);
# CHECK: Alias(sub,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
#endif
}

TEST(MemPlanning, SmallerBufSizeNonMemReuse) {
  int M = 1024;
  int N = 1024;
  int K = 2048;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET = Compute(
      "add", {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) {
        return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2);
      });
  Tensor FT = Compute(
      "mul", {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) {
        return ET.load(fm, fn) * ET.load(fm, fn);
      });
  auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3]. We do not reuse buffer 'gemm' for 'add' because 'gemm' is
  // smaller than 'add'.
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //          reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0)
  //        : (gemm[M_1, N_1]);
  //    }
  //  }
  //  for (int EM = 0; EM < 2048; EM++) {
  //    for (int EN = 0; EN < 2048; EN++) {
  //      add[EM, EN] = (relu[EM / 2, EN / 2]) + (relu[EM / 2, EN / 2]);
  //    }
  //  }
  //  for (int FM = 0; FM < 2048; FM++) {
  //    for (int FN = 0; FN < 2048; FN++) {
  //      mul[FM, FN] = (add[FM, FN]) * (add[FM, FN]);
  //    }
  //  }
  //}

  SimpleIREvaluator cg(stmt, {AP, BP, FT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK-NOT: Alias(add,gemm);
# CHECK: Allocate(add); // dtype=float, dims=[2048, 2048]
# CHECK: Free(add);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
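  // 'add' is [2048, 2048] while the dead buffers are only [1024, 1024], so no
  // Alias is emitted and 'add' gets its own allocation.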

#ifdef TORCH_ENABLE_LLVM
  LoopNest loop(Stmt::clone(stmt), {FT.buf()});
  loop.prepareForCodegen();
  LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK-NOT: Alias(add,gemm);
# CHECK: Allocate(add); // dtype=float, dims=[2048, 2048]
# CHECK: Free(add);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
#endif
}

} // namespace jit
} // namespace torch