1#include <gtest/gtest.h>
2
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/jit/ir/irparser.h>
5#include <torch/csrc/jit/passes/constant_pooling.h>
6#include <torch/csrc/jit/passes/constant_propagation.h>
7#include <torch/csrc/jit/testing/file_check.h>
8
9#include <sstream>
10#include <string>
11
12namespace torch {
13namespace jit {
14
15TEST(ConstantPoolingTest, Int) {
16 auto graph = std::make_shared<Graph>();
17 parseIR(
18 R"IR(
19graph():
20 %8 : int = prim::Constant[value=1]()
21 %10 : int = prim::Constant[value=1]()
22 return (%8, %10)
23 )IR",
24 &*graph);
25 ConstantPooling(graph);
26 testing::FileCheck()
27 .check_count("prim::Constant", 1, /*exactly*/ true)
28 ->run(*graph);
29}
30
31TEST(ConstantPoolingTest, PoolingAcrossBlocks) {
32 auto graph = std::make_shared<Graph>();
33 parseIR(
34 R"IR(
35graph(%cond : Tensor):
36 %a : str = prim::Constant[value="bcd"]()
37 %3 : bool = aten::Bool(%cond)
38 %b : str = prim::If(%3)
39 block0():
40 %b.1 : str = prim::Constant[value="abc"]()
41 -> (%b.1)
42 block1():
43 %b.2 : str = prim::Constant[value="abc"]()
44 -> (%b.2)
45 %7 : (str, str) = prim::TupleConstruct(%a, %b)
46 return (%7)
47 )IR",
48 &*graph);
49 ConstantPooling(graph);
50 testing::FileCheck()
51 .check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
52 ->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
53 ->run(*graph);
54}
55
56TEST(ConstantPoolingTest, PoolingDifferentDevices) {
57 auto graph = std::make_shared<Graph>();
58 parseIR(
59 R"IR(
60graph():
61 %2 : int = prim::Constant[value=2]()
62 %1 : int = prim::Constant[value=1]()
63 %5 : int? = prim::Constant()
64 %7 : Device? = prim::Constant()
65 %15: bool = prim::Constant[value=0]()
66 %10 : int = prim::Constant[value=6]()
67 %3 : int[] = prim::ListConstruct(%1, %2)
68 %x : Tensor = aten::tensor(%3, %5, %7, %15)
69 %y : Tensor = aten::tensor(%3, %10, %7, %15)
70 %9 : int[] = prim::ListConstruct(%1, %2)
71 %z : Tensor = aten::tensor(%9, %10, %7, %15)
72 prim::Print(%x, %y, %z)
73 return (%1)
74 )IR",
75 &*graph);
76 // three tensors created - two different devices among the three
77 // don't have good support for parsing tensor constants
78 ConstantPropagation(graph);
79 ConstantPooling(graph);
80 testing::FileCheck()
81 .check_count(
82 "Float(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant",
83 1,
84 /*exactly*/ true)
85 ->check_count(
86 "Long(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant",
87 1,
88 /*exactly*/ true)
89 ->run(*graph);
90}
91
92TEST(ConstantPoolingTest, DictConstantPooling) {
93 auto graph = std::make_shared<Graph>();
94 parseIR(
95 R"IR(
96graph():
97 %0 : int = prim::Constant[value=1]() # test/elias.py:6:9
98 %1 : int = prim::Constant[value=2]() # test/elias.py:6:12
99 %a.1 : Dict(int, int) = prim::DictConstruct(%0, %1)
100 %b.1 : Dict(int, int) = prim::DictConstruct(%1, %1)
101 return (%a.1, %b.1)
102 )IR",
103 &*graph);
104 ConstantPropagation(graph);
105 ConstantPooling(graph);
106 testing::FileCheck()
107 .check_count(
108 "Dict(int, int) = prim::Constant",
109 2,
110 /*exactly*/ true)
111 ->run(*graph);
112}
113} // namespace jit
114} // namespace torch
115