#include <gtest/gtest.h>

#include <memory>
#include <string>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/testing/file_check.h>
6
7namespace torch {
8namespace jit {
9TEST(EliminateDeadCodeTest, Basic) {
10 auto graph = std::make_shared<Graph>();
11
12 // Consider the following loop:
13 // for i in range(3):
14 // tot += a[0][0]
15 // b = a[0]
16 // b[0] += 1
17 // print(tot)
18 // We want to check that b[0] and b are properly marked as live and thus not
19 // DCE'd.
20 const std::string input =
21 R"IR(
22graph():
23 %48 : None = prim::Constant()
24 %50 : bool = prim::Constant[value=1]()
25 %0 : int = prim::Constant[value=2]()
26 %12 : int = prim::Constant[value=1]()
27 %24 : int = prim::Constant[value=3]()
28 %31 : int = prim::Constant[value=0]()
29 %2 : int[] = prim::ListConstruct(%0, %0)
30 %a.1 : Tensor = prim::MakeTestTensor()
31 %14 : int[] = prim::ListConstruct(%12)
32 %tot.1 : Tensor = prim::MakeTestTensor()
33 %tot : Tensor = prim::Loop(%24, %50, %tot.1)
34 block0(%i : int, %tot.6 : Tensor):
35 %33 : Tensor = aten::select(%a.1, %31, %31)
36 %35 : Tensor = aten::select(%33, %31, %31)
37 # CHECK: add_
38 %tot.3 : Tensor = aten::add_(%tot.6, %35, %12)
39 %b.1 : Tensor = aten::select(%a.1, %31, %31)
40 %44 : Tensor = aten::select(%b.1, %31, %31)
41 # CHECK: add_
42 %46 : Tensor = aten::add_(%44, %12, %12)
43 -> (%50, %tot.3)
44 return (%tot)
45)IR";
46 parseIR(input, graph.get());
47 EliminateDeadCode(graph);
48 // Check that dead code elimin
49 testing::FileCheck().run(input, *graph);
50}
51} // namespace jit
52} // namespace torch
53