1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_utils.h> |
4 | |
5 | #include <torch/csrc/jit/ir/ir.h> |
6 | #include <torch/csrc/jit/ir/irparser.h> |
7 | #include <torch/csrc/jit/passes/peephole.h> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | TEST(PeepholeOptimizeTest, IsAndIsNot) |
13 | // test is / is not none optimization |
14 | { |
15 | auto graph = std::make_shared<Graph>(); |
16 | parseIR( |
17 | R"IR( |
18 | graph(%0 : int): |
19 | %1 : None = prim::Constant() |
20 | %2 : bool = aten::__is__(%0, %1) |
21 | %3 : bool = aten::__isnot__(%0, %1) |
22 | return (%2, %3) |
23 | )IR" , |
24 | graph.get()); |
25 | PeepholeOptimize(graph); |
26 | testing::FileCheck() |
27 | .check_not("aten::__is__" ) |
28 | ->check_not("aten::__isnot__" ) |
29 | ->run(*graph); |
30 | } |
31 | |
32 | TEST(PeepholeOptimizeTest, IsAndIsNot2) { |
33 | auto graph = std::make_shared<Graph>(); |
34 | parseIR( |
35 | R"IR( |
36 | graph(%0: int?): |
37 | %1 : None = prim::Constant() |
38 | %2 : bool = aten::__is__(%0, %1) |
39 | %3 : bool = aten::__isnot__(%0, %1) |
40 | return (%2, %3) |
41 | )IR" , |
42 | graph.get()); |
43 | PeepholeOptimize(graph); |
44 | testing::FileCheck() |
45 | .check("aten::__is__" ) |
46 | ->check("aten::__isnot__" ) |
47 | ->run(*graph); |
48 | } |
49 | |
50 | TEST(PeepholeOptimizeTest, IsAndIsNot3) { |
51 | auto graph = std::make_shared<Graph>(); |
52 | parseIR( |
53 | R"IR( |
54 | graph(%0: int?): |
55 | %1 : Tensor = prim::AutogradZero() |
56 | %2 : None = prim::Constant() |
57 | %4 : bool = aten::__is__(%0, %1) |
58 | %5 : bool = aten::__isnot__(%1, %2) |
59 | return (%4, %5) |
60 | )IR" , |
61 | graph.get()); |
62 | PeepholeOptimize(graph); |
63 | testing::FileCheck() |
64 | .check("aten::__is__" ) |
65 | ->check_not("aten::__isnot__" ) |
66 | ->run(*graph); |
67 | } |
68 | |
69 | TEST(PeepholeOptimizeTest, UnwrapOptional) |
70 | // test unwrap optional |
71 | { |
72 | auto graph = std::make_shared<Graph>(); |
73 | parseIR( |
74 | R"IR( |
75 | graph(): |
76 | %1 : Float(*, *, *) = prim::Constant() |
77 | %2 : bool = aten::_unwrap_optional(%1) |
78 | %3 : bool = prim::unchecked_unwrap_optional(%1) |
79 | return (%2, %3) |
80 | )IR" , |
81 | graph.get()); |
82 | PeepholeOptimize(graph); |
83 | testing::FileCheck().check_not("unwrap" )->run(*graph); |
84 | } |
85 | |
86 | TEST(PeepholeOptimizeTest, UnwrapOptional2) { |
87 | auto graph = std::make_shared<Graph>(); |
88 | parseIR( |
89 | R"IR( |
90 | graph(%1 : Float(*, *, *)?): |
91 | %2 : bool = aten::_unwrap_optional(%1) |
92 | %3 : bool = prim::unchecked_unwrap_optional(%1) |
93 | return (%2, %3) |
94 | )IR" , |
95 | graph.get()); |
96 | PeepholeOptimize(graph); |
97 | testing::FileCheck().check_count("unwrap" , 2)->run(*graph); |
98 | } |
99 | |
100 | TEST(PeepholeOptimizeTest, AddMMFusion) { |
101 | auto graph = std::make_shared<Graph>(); |
102 | parseIR( |
103 | R"IR( |
104 | graph( |
105 | %0 : Float(2, 3, 4), |
106 | %1 : Float(2, 3, 4), |
107 | %2 : Float(1, 1, 1)): |
108 | %3 : int = prim::Constant[value=1]() |
109 | %4 : Tensor = aten::mm(%0, %1) |
110 | %5 : Tensor = aten::add(%4, %2, %3) |
111 | %6 : Tensor = aten::add(%5, %2, %3) |
112 | return (%6) |
113 | )IR" , |
114 | graph.get()); |
115 | FuseAddMM(graph); |
116 | testing::FileCheck().check("addmm" )->run(*graph); |
117 | } |
118 | } // namespace jit |
119 | } // namespace torch |
120 | |