#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>

#include <limits>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

class GraphOpt : public ::testing::Test {
 public:
  void SetUp() override {
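    // Save the current value of the global "cat without conditionals" flag
    // and force it on, so that these tests exercise the conditional-free
    // lowering of aten::cat. TearDown restores the saved value.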
    old_cat_wo_conditionals_ = getCatWoConditionals();
    getCatWoConditionals() = true;
  }

  void TearDown() override {
    getCatWoConditionals() = old_cat_wo_conditionals_;
  }

 private:
  bool old_cat_wo_conditionals_;
};

TEST_F(GraphOpt, OptimizeCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::log` op must be moved to the inputs of `aten::cat`.
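  // Rough sketch of the expected optimized graph (value names are
  // illustrative, not checked verbatim):
  //   %x.log = aten::log(%x)
  //   %y.log = aten::log(%y)
  //   %z.log = aten::log(%z)
  //   %logs : Tensor[] = prim::ListConstruct(%x.log, %y.log, %z.log)
  //   %res = aten::cat(%logs, %dim)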
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());

  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::log(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCat2) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::log` and `aten::tanh` ops must be moved to the inputs of
  // `aten::cat`.
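  // Roughly (illustrative value names), each input should get its own
  // log-then-tanh chain before the concatenation:
  //   %x.log = aten::log(%x)        (and likewise for %y, %z)
  //   %x.tanh = aten::tanh(%x.log)  (and likewise for %y, %z)
  //   %res = aten::cat(<the three tanh results>, %dim)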
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::log(at::cat({x, y, z}, 0)));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCat3) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%a : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
  // But the `aten::mul` op must not be moved since it is not a single-tensor
  // op (it has 2 tensor inputs).
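  // Rough expected shape of the optimized graph (illustrative names):
  //   %x.t = aten::tanh(%x)   (and likewise for %y, %z)
  //   %cat = aten::cat(<the three tanh results>, %dim)
  //   %res = aten::mul(%a, %cat)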
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto a = at::rand({60}, at::kFloat);
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::cat({x, y, z}, 0)) * a;

  std::vector<at::Tensor> inputs = {a, x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Int(10, strides=[1], device=cpu),
          %y : Int(20, strides=[1], device=cpu),
          %z : Int(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
  // The scalar type of the inputs to `cat` should now be `Float` since they
  // are the result of `tanh`, which performs the type promotion.
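  // Rough expected form (illustrative names): each `aten::tanh` takes an Int
  // input and produces a Float output, so the concatenation now operates on
  // Float tensors:
  //   %x.t : Float(10) = aten::tanh(%x)   (and likewise for %y, %z)
  //   %res : Float(60) = aten::cat(<the three tanh results>, %dim)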
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto x = at::randint(std::numeric_limits<int>::max(), {10}, at::kInt);
  auto y = at::randint(std::numeric_limits<int>::max(), {20}, at::kInt);
  auto z = at::randint(std::numeric_limits<int>::max(), {30}, at::kInt);
  auto ref = at::tanh(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Double(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation should have happened because the `aten::cat` op performs
  // type promotion. This case is currently not handled.
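  // The promotion happens inside `aten::cat` itself: its inputs are Float,
  // Float, and Double, while its output is Double. Applying `aten::log` per
  // input would compute the Float inputs' logs in single precision instead of
  // double, which is presumably why this case is left untouched.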
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::log")
      ->check_not("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation is expected since the consumer of `aten::cat` is not a
  // single-tensor element-wise op.
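  // `aten::mul` here consumes two tensors (%0 and %cat), and the optimization
  // only distributes single-tensor element-wise ops over the inputs of
  // `aten::cat`, so the graph should be left as is.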
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %1 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %one : int = prim::Constant[value=1]()
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation is expected since the ops downstream of `aten::cat` are
  // not single-tensor element-wise ops.
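  // As in the previous test, `aten::mul` and `aten::add` each consume two
  // tensors, so neither qualifies as a single-tensor element-wise op and the
  // cat is not split.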
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check("aten::add")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->check_not("aten::add")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, AOTGraphPrepPasses) {
  const auto graph_string = R"IR(
    graph(%x, %y, %z, %i : int):
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      return (%xyz_list, %i))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

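  // Prepare the graph the way the AOT pipeline does: drop the second output
  // (%i), wrap the remaining list output in a tuple, and lower that tuple so
  // the graph returns %x, %y, %z directly.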
  removeGraphOutput(g, 1);
  replaceListOutputWithTuple(g);
  LowerAllTuples(g);

  testing::FileCheck().check("return (%x, %y, %z)")->run(*g);
}

} // namespace jit
} // namespace torch
320 | |