#include <gtest/gtest.h>

#include <memory>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/add_if_then_else.h>

7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | TEST(AddIfThenElseOpTest, AddIfThenElseOpSimple) { |
11 | const auto src = R"IR( |
12 | graph(%cond: bool, %a: Tensor, %b: Tensor): |
13 | %result: Tensor = prim::If(%cond) |
14 | block0(): |
15 | -> (%a) |
16 | block1(): |
17 | -> (%b) |
18 | return (%result) |
19 | )IR" ; |
20 | |
21 | auto graph = std::make_shared<Graph>(); |
22 | parseIR(src, graph.get()); |
23 | EXPECT_TRUE(AddIfThenElseOp(graph)); |
24 | |
25 | testing::FileCheck() |
26 | .check_count("= prim::IfThenElse" , 1, /*exactly*/ true) |
27 | ->check_count("= prim::If" , 0, /*exactly*/ true) |
28 | ->run(*graph); |
29 | } |
30 | |
31 | TEST(AddIfThenElseOpTest, NoIfThenElseOpMultipleOutputs) { |
32 | const auto src = R"IR( |
33 | graph(%cond: bool, %a: Tensor, %b: Tensor): |
34 | %result1: Tensor, %result2: Tensor = prim::If(%cond) |
35 | block0(): |
36 | -> (%a, %b) |
37 | block1(): |
38 | -> (%b, %a) |
39 | return (%result1, %result2) |
40 | )IR" ; |
41 | |
42 | auto graph = std::make_shared<Graph>(); |
43 | parseIR(src, graph.get()); |
44 | EXPECT_FALSE(AddIfThenElseOp(graph)); |
45 | |
46 | testing::FileCheck() |
47 | .check_count("= prim::IfThenElse" , 0, /*exactly*/ true) |
48 | ->check_count("= prim::If" , 1, /*exactly*/ true) |
49 | ->run(*graph); |
50 | } |
51 | |
52 | } // namespace jit |
53 | } // namespace torch |
54 | |