1#include <gtest/gtest.h>
2
3#include <test/cpp/jit/test_utils.h>
4#include <torch/csrc/jit/ir/irparser.h>
5#include <torch/csrc/jit/passes/add_if_then_else.h>
6
7namespace torch {
8namespace jit {
9
10TEST(AddIfThenElseOpTest, AddIfThenElseOpSimple) {
11 const auto src = R"IR(
12 graph(%cond: bool, %a: Tensor, %b: Tensor):
13 %result: Tensor = prim::If(%cond)
14 block0():
15 -> (%a)
16 block1():
17 -> (%b)
18 return (%result)
19 )IR";
20
21 auto graph = std::make_shared<Graph>();
22 parseIR(src, graph.get());
23 EXPECT_TRUE(AddIfThenElseOp(graph));
24
25 testing::FileCheck()
26 .check_count("= prim::IfThenElse", 1, /*exactly*/ true)
27 ->check_count("= prim::If", 0, /*exactly*/ true)
28 ->run(*graph);
29}
30
31TEST(AddIfThenElseOpTest, NoIfThenElseOpMultipleOutputs) {
32 const auto src = R"IR(
33 graph(%cond: bool, %a: Tensor, %b: Tensor):
34 %result1: Tensor, %result2: Tensor = prim::If(%cond)
35 block0():
36 -> (%a, %b)
37 block1():
38 -> (%b, %a)
39 return (%result1, %result2)
40 )IR";
41
42 auto graph = std::make_shared<Graph>();
43 parseIR(src, graph.get());
44 EXPECT_FALSE(AddIfThenElseOp(graph));
45
46 testing::FileCheck()
47 .check_count("= prim::IfThenElse", 0, /*exactly*/ true)
48 ->check_count("= prim::If", 1, /*exactly*/ true)
49 ->run(*graph);
50}
51
52} // namespace jit
53} // namespace torch
54