1#include <gtest/gtest.h>
2
3#include <test/cpp/tensorexpr/test_base.h>
4#include <torch/csrc/jit/ir/ir.h>
5#include <torch/csrc/jit/ir/irparser.h>
6#include <torch/csrc/jit/passes/lower_tuples.h>
7#include <torch/csrc/jit/tensorexpr/graph_opt.h>
8#include <torch/csrc/jit/tensorexpr/kernel.h>
9#include <torch/csrc/jit/testing/file_check.h>
10#include <torch/torch.h>
11
12#include <limits>
13
14namespace torch {
15namespace jit {
16
17using namespace torch::jit::tensorexpr;
18
19class GraphOpt : public ::testing::Test {
20 public:
21 void SetUp() override {
22 old_cat_wo_conditionals_ = getCatWoConditionals();
23 getCatWoConditionals() = true;
24 }
25
26 void TearDown() override {
27 getCatWoConditionals() = old_cat_wo_conditionals_;
28 }
29
30 private:
31 bool old_cat_wo_conditionals_;
32};
33
34TEST_F(GraphOpt, OptimizeCat) {
35#ifdef TORCH_ENABLE_LLVM
36 const auto graph_string = R"IR(
37 graph(%x : Float(10, strides=[1], device=cpu),
38 %y : Float(20, strides=[1], device=cpu),
39 %z : Float(30, strides=[1], device=cpu)):
40 %dim : int = prim::Constant[value=0]()
41 %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
42 %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
43 %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
44 return (%5))IR";
45 auto g = std::make_shared<Graph>();
46 torch::jit::parseIR(graph_string, g.get());
47 g->lint();
48
49 TensorExprKernel kernel(g);
50
51 // The `aten::log` op must be moved to the inputs of `aten::cat`.
52 testing::FileCheck()
53 .check("aten::log")
54 ->check("aten::log")
55 ->check("aten::log")
56 ->check("aten::cat")
57 ->check_not("aten::log")
58 ->run(*kernel.graph());
59
60 auto x = at::rand({10}, at::kFloat);
61 auto y = at::rand({20}, at::kFloat);
62 auto z = at::rand({30}, at::kFloat);
63 auto ref = at::log(at::cat({x, y, z}, 0));
64
65 std::vector<at::Tensor> inputs = {x, y, z};
66 std::vector<IValue> stack = fmap<IValue>(inputs);
67 kernel.run(stack);
68 auto out = stack[0].toTensor();
69 ASSERT_EQ(out.sizes(), ref.sizes());
70 ASSERT_EQ(out.dtype(), ref.dtype());
71 ASSERT_TRUE(at::allclose(out, ref));
72#endif
73}
74
75TEST_F(GraphOpt, OptimizeCat2) {
76#ifdef TORCH_ENABLE_LLVM
77 const auto graph_string = R"IR(
78 graph(%x : Float(10, strides=[1], device=cpu),
79 %y : Float(20, strides=[1], device=cpu),
80 %z : Float(30, strides=[1], device=cpu)):
81 %dim : int = prim::Constant[value=0]()
82 %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
83 %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
84 %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
85 %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5)
86 return (%6))IR";
87 auto g = std::make_shared<Graph>();
88 torch::jit::parseIR(graph_string, g.get());
89 g->lint();
90
91 TensorExprKernel kernel(g);
92
93 // The `aten::log` and `aten::tanh` ops must be moved to the inputs of
94 // `aten::cat`.
95 testing::FileCheck()
96 .check("aten::log")
97 ->check("aten::log")
98 ->check("aten::log")
99 ->check("aten::tanh")
100 ->check("aten::tanh")
101 ->check("aten::tanh")
102 ->check("aten::cat")
103 ->check_not("aten::log")
104 ->check_not("aten::tanh")
105 ->run(*kernel.graph());
106
107 auto x = at::rand({10}, at::kFloat);
108 auto y = at::rand({20}, at::kFloat);
109 auto z = at::rand({30}, at::kFloat);
110 auto ref = at::tanh(at::log(at::cat({x, y, z}, 0)));
111
112 std::vector<at::Tensor> inputs = {x, y, z};
113 std::vector<IValue> stack = fmap<IValue>(inputs);
114 kernel.run(stack);
115 auto out = stack[0].toTensor();
116 ASSERT_EQ(out.sizes(), ref.sizes());
117 ASSERT_EQ(out.dtype(), ref.dtype());
118 ASSERT_TRUE(at::allclose(out, ref));
119#endif
120}
121
122TEST_F(GraphOpt, OptimizeCat3) {
123#ifdef TORCH_ENABLE_LLVM
124 const auto graph_string = R"IR(
125 graph(%a : Float(60, strides=[1], device=cpu),
126 %x : Float(10, strides=[1], device=cpu),
127 %y : Float(20, strides=[1], device=cpu),
128 %z : Float(30, strides=[1], device=cpu)):
129 %dim : int = prim::Constant[value=0]()
130 %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
131 %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
132 %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
133 %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5)
134 return (%6))IR";
135 auto g = std::make_shared<Graph>();
136 torch::jit::parseIR(graph_string, g.get());
137 g->lint();
138
139 TensorExprKernel kernel(g);
140
141 // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
142 // But the `aten::mul` op must not be moved since it is not a single-tensor
143 // op (it has 2 tensor inputs).
144 testing::FileCheck()
145 .check("aten::tanh")
146 ->check("aten::tanh")
147 ->check("aten::tanh")
148 ->check("aten::cat")
149 ->check("aten::mul")
150 ->check_not("aten::tanh")
151 ->run(*kernel.graph());
152
153 auto a = at::rand({60}, at::kFloat);
154 auto x = at::rand({10}, at::kFloat);
155 auto y = at::rand({20}, at::kFloat);
156 auto z = at::rand({30}, at::kFloat);
157 auto ref = at::tanh(at::cat({x, y, z}, 0)) * a;
158
159 std::vector<at::Tensor> inputs = {a, x, y, z};
160 std::vector<IValue> stack = fmap<IValue>(inputs);
161 kernel.run(stack);
162 auto out = stack[0].toTensor();
163 ASSERT_EQ(out.sizes(), ref.sizes());
164 ASSERT_EQ(out.dtype(), ref.dtype());
165 ASSERT_TRUE(at::allclose(out, ref));
166#endif
167}
168
169TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) {
170#ifdef TORCH_ENABLE_LLVM
171 const auto graph_string = R"IR(
172 graph(%x : Int(10, strides=[1], device=cpu),
173 %y : Int(20, strides=[1], device=cpu),
174 %z : Int(30, strides=[1], device=cpu)):
175 %dim : int = prim::Constant[value=0]()
176 %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
177 %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
178 %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
179 return (%5))IR";
180 auto g = std::make_shared<Graph>();
181 torch::jit::parseIR(graph_string, g.get());
182 g->lint();
183
184 TensorExprKernel kernel(g);
185
186 // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
187 // The scalar type of the inputs to `cat` should now be `Float` since they
188 // are the result of `tanh` which does the type promotion.
189 testing::FileCheck()
190 .check("aten::tanh")
191 ->check("aten::tanh")
192 ->check("aten::tanh")
193 ->check("aten::cat")
194 ->check_not("aten::tanh")
195 ->run(*kernel.graph());
196
197 auto x = at::randint(std::numeric_limits<int>::max(), {10}, at::kInt);
198 auto y = at::randint(std::numeric_limits<int>::max(), {20}, at::kInt);
199 auto z = at::randint(std::numeric_limits<int>::max(), {30}, at::kInt);
200 auto ref = at::tanh(at::cat({x, y, z}, 0));
201
202 std::vector<at::Tensor> inputs = {x, y, z};
203 std::vector<IValue> stack = fmap<IValue>(inputs);
204 kernel.run(stack);
205 auto out = stack[0].toTensor();
206 ASSERT_EQ(out.sizes(), ref.sizes());
207 ASSERT_EQ(out.dtype(), ref.dtype());
208 ASSERT_TRUE(at::allclose(out, ref));
209#endif
210}
211
212TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) {
213#ifdef TORCH_ENABLE_LLVM
214 const auto graph_string = R"IR(
215 graph(%x : Float(10, strides=[1], device=cpu),
216 %y : Float(20, strides=[1], device=cpu),
217 %z : Double(30, strides=[1], device=cpu)):
218 %dim : int = prim::Constant[value=0]()
219 %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
220 %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
221 %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat)
222 return (%5))IR";
223 auto g = std::make_shared<Graph>();
224 torch::jit::parseIR(graph_string, g.get());
225 g->lint();
226
227 TensorExprKernel kernel(g);
228
229 // No transformation should have happened because the `aten::cat` op performs
230 // type promotion. This case is currently not handled.
231 testing::FileCheck()
232 .check("aten::cat")
233 ->check("aten::log")
234 ->check_not("aten::cat")
235 ->check_not("aten::log")
236 ->run(*kernel.graph());
237#endif
238}
239
240TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) {
241#ifdef TORCH_ENABLE_LLVM
242 const auto graph_string = R"IR(
243 graph(%0 : Float(60, strides=[1], device=cpu),
244 %x : Float(10, strides=[1], device=cpu),
245 %y : Float(20, strides=[1], device=cpu),
246 %z : Float(30, strides=[1], device=cpu)):
247 %dim : int = prim::Constant[value=0]()
248 %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
249 %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
250 %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
251 return (%5))IR";
252 auto g = std::make_shared<Graph>();
253 torch::jit::parseIR(graph_string, g.get());
254 g->lint();
255
256 TensorExprKernel kernel(g);
257
258 // No transformation is expected since the consumers of cat are not
259 // single-tensor element-wise ops.
260 testing::FileCheck()
261 .check("aten::cat")
262 ->check("aten::mul")
263 ->check_not("aten::cat")
264 ->check_not("aten::mul")
265 ->run(*kernel.graph());
266#endif
267}
268
269TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) {
270#ifdef TORCH_ENABLE_LLVM
271 const auto graph_string = R"IR(
272 graph(%0 : Float(60, strides=[1], device=cpu),
273 %1 : Float(60, strides=[1], device=cpu),
274 %x : Float(10, strides=[1], device=cpu),
275 %y : Float(20, strides=[1], device=cpu),
276 %z : Float(30, strides=[1], device=cpu)):
277 %one : int = prim::Constant[value=1]()
278 %dim : int = prim::Constant[value=0]()
279 %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
280 %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
281 %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
282 %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one)
283 return (%6))IR";
284 auto g = std::make_shared<Graph>();
285 torch::jit::parseIR(graph_string, g.get());
286 g->lint();
287
288 TensorExprKernel kernel(g);
289
290 // No transformation is expected since the consumers of cat are not
291 // single-tensor element-wise ops.
292 testing::FileCheck()
293 .check("aten::cat")
294 ->check("aten::mul")
295 ->check("aten::add")
296 ->check_not("aten::cat")
297 ->check_not("aten::mul")
298 ->check_not("aten::add")
299 ->run(*kernel.graph());
300#endif
301}
302
303TEST_F(GraphOpt, AOTGraphPrepPasses) {
304 const auto graph_string = R"IR(
305 graph(%x, %y, %z, %i : int):
306 %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
307 return (%xyz_list, %i))IR";
308 auto g = std::make_shared<Graph>();
309 torch::jit::parseIR(graph_string, g.get());
310
311 removeGraphOutput(g, 1);
312 replaceListOutputWithTuple(g);
313 LowerAllTuples(g);
314
315 testing::FileCheck().check("return (%x, %y, %z)")->run(*g);
316}
317
318} // namespace jit
319} // namespace torch
320