1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | #include <torch/csrc/jit/ir/irparser.h> |
5 | #include <torch/csrc/jit/passes/constant_pooling.h> |
6 | #include <torch/csrc/jit/passes/constant_propagation.h> |
7 | #include <torch/csrc/jit/testing/file_check.h> |
8 | |
9 | #include <sstream> |
10 | #include <string> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
15 | TEST(ConstantPoolingTest, Int) { |
16 | auto graph = std::make_shared<Graph>(); |
17 | parseIR( |
18 | R"IR( |
19 | graph(): |
20 | %8 : int = prim::Constant[value=1]() |
21 | %10 : int = prim::Constant[value=1]() |
22 | return (%8, %10) |
23 | )IR" , |
24 | &*graph); |
25 | ConstantPooling(graph); |
26 | testing::FileCheck() |
27 | .check_count("prim::Constant" , 1, /*exactly*/ true) |
28 | ->run(*graph); |
29 | } |
30 | |
31 | TEST(ConstantPoolingTest, PoolingAcrossBlocks) { |
32 | auto graph = std::make_shared<Graph>(); |
33 | parseIR( |
34 | R"IR( |
35 | graph(%cond : Tensor): |
36 | %a : str = prim::Constant[value="bcd"]() |
37 | %3 : bool = aten::Bool(%cond) |
38 | %b : str = prim::If(%3) |
39 | block0(): |
40 | %b.1 : str = prim::Constant[value="abc"]() |
41 | -> (%b.1) |
42 | block1(): |
43 | %b.2 : str = prim::Constant[value="abc"]() |
44 | -> (%b.2) |
45 | %7 : (str, str) = prim::TupleConstruct(%a, %b) |
46 | return (%7) |
47 | )IR" , |
48 | &*graph); |
49 | ConstantPooling(graph); |
50 | testing::FileCheck() |
51 | .check_count("prim::Constant[value=\"abc\"]" , 1, /*exactly*/ true) |
52 | ->check_count("prim::Constant[value=\"bcd\"]" , 1, /*exactly*/ true) |
53 | ->run(*graph); |
54 | } |
55 | |
56 | TEST(ConstantPoolingTest, PoolingDifferentDevices) { |
57 | auto graph = std::make_shared<Graph>(); |
58 | parseIR( |
59 | R"IR( |
60 | graph(): |
61 | %2 : int = prim::Constant[value=2]() |
62 | %1 : int = prim::Constant[value=1]() |
63 | %5 : int? = prim::Constant() |
64 | %7 : Device? = prim::Constant() |
65 | %15: bool = prim::Constant[value=0]() |
66 | %10 : int = prim::Constant[value=6]() |
67 | %3 : int[] = prim::ListConstruct(%1, %2) |
68 | %x : Tensor = aten::tensor(%3, %5, %7, %15) |
69 | %y : Tensor = aten::tensor(%3, %10, %7, %15) |
70 | %9 : int[] = prim::ListConstruct(%1, %2) |
71 | %z : Tensor = aten::tensor(%9, %10, %7, %15) |
72 | prim::Print(%x, %y, %z) |
73 | return (%1) |
74 | )IR" , |
75 | &*graph); |
76 | // three tensors created - two different devices among the three |
77 | // don't have good support for parsing tensor constants |
78 | ConstantPropagation(graph); |
79 | ConstantPooling(graph); |
80 | testing::FileCheck() |
81 | .check_count( |
82 | "Float(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant" , |
83 | 1, |
84 | /*exactly*/ true) |
85 | ->check_count( |
86 | "Long(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant" , |
87 | 1, |
88 | /*exactly*/ true) |
89 | ->run(*graph); |
90 | } |
91 | |
92 | TEST(ConstantPoolingTest, DictConstantPooling) { |
93 | auto graph = std::make_shared<Graph>(); |
94 | parseIR( |
95 | R"IR( |
96 | graph(): |
97 | %0 : int = prim::Constant[value=1]() # test/elias.py:6:9 |
98 | %1 : int = prim::Constant[value=2]() # test/elias.py:6:12 |
99 | %a.1 : Dict(int, int) = prim::DictConstruct(%0, %1) |
100 | %b.1 : Dict(int, int) = prim::DictConstruct(%1, %1) |
101 | return (%a.1, %b.1) |
102 | )IR" , |
103 | &*graph); |
104 | ConstantPropagation(graph); |
105 | ConstantPooling(graph); |
106 | testing::FileCheck() |
107 | .check_count( |
108 | "Dict(int, int) = prim::Constant" , |
109 | 2, |
110 | /*exactly*/ true) |
111 | ->run(*graph); |
112 | } |
113 | } // namespace jit |
114 | } // namespace torch |
115 | |