1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/csrc/jit/ir/irparser.h> |
4 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
5 | #include <torch/csrc/jit/testing/file_check.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | TEST(EliminateDeadCodeTest, Basic) { |
10 | auto graph = std::make_shared<Graph>(); |
11 | |
12 | // Consider the following loop: |
13 | // for i in range(3): |
14 | // tot += a[0][0] |
15 | // b = a[0] |
16 | // b[0] += 1 |
17 | // print(tot) |
18 | // We want to check that b[0] and b are properly marked as live and thus not |
19 | // DCE'd. |
20 | const std::string input = |
21 | R"IR( |
22 | graph(): |
23 | %48 : None = prim::Constant() |
24 | %50 : bool = prim::Constant[value=1]() |
25 | %0 : int = prim::Constant[value=2]() |
26 | %12 : int = prim::Constant[value=1]() |
27 | %24 : int = prim::Constant[value=3]() |
28 | %31 : int = prim::Constant[value=0]() |
29 | %2 : int[] = prim::ListConstruct(%0, %0) |
30 | %a.1 : Tensor = prim::MakeTestTensor() |
31 | %14 : int[] = prim::ListConstruct(%12) |
32 | %tot.1 : Tensor = prim::MakeTestTensor() |
33 | %tot : Tensor = prim::Loop(%24, %50, %tot.1) |
34 | block0(%i : int, %tot.6 : Tensor): |
35 | %33 : Tensor = aten::select(%a.1, %31, %31) |
36 | %35 : Tensor = aten::select(%33, %31, %31) |
37 | # CHECK: add_ |
38 | %tot.3 : Tensor = aten::add_(%tot.6, %35, %12) |
39 | %b.1 : Tensor = aten::select(%a.1, %31, %31) |
40 | %44 : Tensor = aten::select(%b.1, %31, %31) |
41 | # CHECK: add_ |
42 | %46 : Tensor = aten::add_(%44, %12, %12) |
43 | -> (%50, %tot.3) |
44 | return (%tot) |
45 | )IR" ; |
46 | parseIR(input, graph.get()); |
47 | EliminateDeadCode(graph); |
48 | // Check that dead code elimin |
49 | testing::FileCheck().run(input, *graph); |
50 | } |
51 | } // namespace jit |
52 | } // namespace torch |
53 | |