1#include <gtest/gtest.h>
2
3#include <test/cpp/jit/test_utils.h>
4
5#include <torch/csrc/jit/ir/ir.h>
6#include <torch/csrc/jit/ir/irparser.h>
7#include <torch/csrc/jit/passes/peephole.h>
8
9namespace torch {
10namespace jit {
11
12TEST(PeepholeOptimizeTest, IsAndIsNot)
13// test is / is not none optimization
14{
15 auto graph = std::make_shared<Graph>();
16 parseIR(
17 R"IR(
18graph(%0 : int):
19 %1 : None = prim::Constant()
20 %2 : bool = aten::__is__(%0, %1)
21 %3 : bool = aten::__isnot__(%0, %1)
22 return (%2, %3)
23 )IR",
24 graph.get());
25 PeepholeOptimize(graph);
26 testing::FileCheck()
27 .check_not("aten::__is__")
28 ->check_not("aten::__isnot__")
29 ->run(*graph);
30}
31
32TEST(PeepholeOptimizeTest, IsAndIsNot2) {
33 auto graph = std::make_shared<Graph>();
34 parseIR(
35 R"IR(
36graph(%0: int?):
37 %1 : None = prim::Constant()
38 %2 : bool = aten::__is__(%0, %1)
39 %3 : bool = aten::__isnot__(%0, %1)
40 return (%2, %3)
41 )IR",
42 graph.get());
43 PeepholeOptimize(graph);
44 testing::FileCheck()
45 .check("aten::__is__")
46 ->check("aten::__isnot__")
47 ->run(*graph);
48}
49
50TEST(PeepholeOptimizeTest, IsAndIsNot3) {
51 auto graph = std::make_shared<Graph>();
52 parseIR(
53 R"IR(
54graph(%0: int?):
55 %1 : Tensor = prim::AutogradZero()
56 %2 : None = prim::Constant()
57 %4 : bool = aten::__is__(%0, %1)
58 %5 : bool = aten::__isnot__(%1, %2)
59 return (%4, %5)
60 )IR",
61 graph.get());
62 PeepholeOptimize(graph);
63 testing::FileCheck()
64 .check("aten::__is__")
65 ->check_not("aten::__isnot__")
66 ->run(*graph);
67}
68
69TEST(PeepholeOptimizeTest, UnwrapOptional)
70// test unwrap optional
71{
72 auto graph = std::make_shared<Graph>();
73 parseIR(
74 R"IR(
75graph():
76 %1 : Float(*, *, *) = prim::Constant()
77 %2 : bool = aten::_unwrap_optional(%1)
78 %3 : bool = prim::unchecked_unwrap_optional(%1)
79 return (%2, %3)
80 )IR",
81 graph.get());
82 PeepholeOptimize(graph);
83 testing::FileCheck().check_not("unwrap")->run(*graph);
84}
85
86TEST(PeepholeOptimizeTest, UnwrapOptional2) {
87 auto graph = std::make_shared<Graph>();
88 parseIR(
89 R"IR(
90graph(%1 : Float(*, *, *)?):
91 %2 : bool = aten::_unwrap_optional(%1)
92 %3 : bool = prim::unchecked_unwrap_optional(%1)
93 return (%2, %3)
94 )IR",
95 graph.get());
96 PeepholeOptimize(graph);
97 testing::FileCheck().check_count("unwrap", 2)->run(*graph);
98}
99
100TEST(PeepholeOptimizeTest, AddMMFusion) {
101 auto graph = std::make_shared<Graph>();
102 parseIR(
103 R"IR(
104 graph(
105 %0 : Float(2, 3, 4),
106 %1 : Float(2, 3, 4),
107 %2 : Float(1, 1, 1)):
108 %3 : int = prim::Constant[value=1]()
109 %4 : Tensor = aten::mm(%0, %1)
110 %5 : Tensor = aten::add(%4, %2, %3)
111 %6 : Tensor = aten::add(%5, %2, %3)
112 return (%6)
113 )IR",
114 graph.get());
115 FuseAddMM(graph);
116 testing::FileCheck().check("addmm")->run(*graph);
117}
118} // namespace jit
119} // namespace torch
120