1 | #include <gtest/gtest.h> |
---|---|
2 | |
3 | #include <torch/csrc/jit/frontend/ir_emitter.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | #include <torch/csrc/jit/ir/irparser.h> |
6 | #include <torch/csrc/jit/testing/file_check.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | TEST(CleanupPassTest, Basic) { |
12 | // Tests stability of clean up passes when dealing with constant pooling |
13 | // and constant propagation. |
14 | auto graph = std::make_shared<Graph>(); |
15 | parseIR( |
16 | R"IR( |
17 | graph(%cond.1 : Tensor, |
18 | %suffix.1 : str): |
19 | %3 : bool = aten::Bool(%cond.1) # o.py:6:7 |
20 | %25 : str = prim::If(%3) # o.py:6:4 |
21 | block0(): |
22 | %a.1 : str = prim::Constant[value="same string"]() |
23 | %b.1 : str = prim::Constant[value=" with a twist"]() |
24 | %7 : str = aten::add(%a.1, %b.1) |
25 | %11 : str = aten::add(%7, %suffix.1) # o.py:10:15 |
26 | -> (%11) |
27 | block1(): |
28 | %c.1 : str = prim::Constant[value="same string"]() |
29 | %d.1 : str = prim::Constant[value=" with a twist"]() |
30 | %12 : str = aten::add(%c.1, %d.1) |
31 | -> (%12) |
32 | return (%25) |
33 | )IR", |
34 | &*graph); |
35 | runCleanupPasses(graph); |
36 | testing::FileCheck() |
37 | .check_count( |
38 | "prim::Constant[value=\"same string with a twist\"]", |
39 | 1, |
40 | /*exactly=*/true) |
41 | ->run(*graph); |
42 | |
43 | auto graph_after_pass_once = graph->toString(); |
44 | runCleanupPasses(graph); |
45 | auto graph_after_pass_twice = graph->toString(); |
46 | ASSERT_EQ(graph_after_pass_once, graph_after_pass_twice); |
47 | } |
48 | } // namespace jit |
49 | } // namespace torch |
50 |