#include <gtest/gtest.h>

#include <ATen/code_template.h>
#include <c10/core/DeviceType.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <cmath>
#include <sstream>
#include <stdexcept>
#include <thread>

namespace torch {
namespace jit {

using namespace torch::indexing;
using namespace torch::jit::tensorexpr;

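// Each test below follows the same recipe: annotate a graph's tensors with
// symbolic shapes (SS(-N)), describe how each input and output is strided,
// compile a single TensorExprKernel, and then run it with different concrete
// sizes. At run time the stack holds the input tensors followed by one int64
// per symbolic dim, in the same order as the trailing %SS_* graph inputs.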
TEST(DynamicShapes, SimpleGraph) {
#ifdef TORCH_ENABLE_LLVM
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Tensor,
            %SS_2 : int,
            %SS_3 : int):
        %3 : Tensor = aten::tanh(%x)
        %4 : Tensor = aten::erf(%3)
        return (%4))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto x_type = TensorType::create(at::rand({10, 5}));
  std::vector<ShapeSymbol> x_sym_dims(
      {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()});
  auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims);
  graph->inputs().at(0)->setType(x_sym_type);
  for (const auto n : graph->nodes()) {
    n->output()->setType(x_sym_type);
  }

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(SS(-2), SS(-3)),
  //       %SS_2 : int,
  //       %SS_3 : int):
  //   %3 : Float(SS(-2), SS(-3)) = aten::tanh(%x)
  //   %4 : Float(SS(-2), SS(-3)) = aten::erf(%3)
  //   return (%4)

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;
  std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
      x_sym_dims,
      [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });

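  // Compile the graph once. The `{}` argument passes no custom lowerings and
  // `false` leaves output pre-allocation disabled; `symbolic_strides` tells
  // the kernel how each annotated value is laid out in memory.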
  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);
  // Run with the same static dims as those used to initialize the graph.
  {
    auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::erf(at::tanh(a));

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a}));
    stack.push_back(10);
    stack.push_back(5);
    kernel.run(stack);

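    // run() consumes the inputs and leaves the outputs on the stack, so the
    // result tensor ends up at stack[0].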
    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  // Run with inputs having different dims.
  {
    auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::erf(at::tanh(a));

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a}));
    stack.push_back(50);
    stack.push_back(100);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }
#endif
}

TEST(DynamicShapes, GraphWith2InputsSameDims) {
#ifdef TORCH_ENABLE_LLVM
  // The two inputs in this graph must have the same dims.
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Tensor,
            %y : Tensor,
            %SS_2 : int,
            %SS_3 : int):
        %3 : Tensor = aten::tanh(%x)
        %4 : Tensor = aten::erf(%3)
        %5 : Tensor = aten::mul(%4, %y)
        return (%5))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto y_inp = graph->inputs()[1];
  auto x_type = TensorType::create(at::rand({10, 5}));
  std::vector<ShapeSymbol> x_sym_dims(
      {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()});
  auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims);
  graph->inputs().at(0)->setType(x_sym_type);
  graph->inputs().at(1)->setType(x_sym_type);
  for (const auto n : graph->nodes()) {
    n->output()->setType(x_sym_type);
  }

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(SS(-4), SS(-5)),
  //       %y : Float(SS(-4), SS(-5)),
  //       %SS_2 : int,
  //       %SS_3 : int):
  //   %4 : Float(SS(-4), SS(-5)) = aten::tanh(%x)
  //   %5 : Float(SS(-4), SS(-5)) = aten::erf(%4)
  //   %6 : Float(SS(-4), SS(-5)) = aten::mul(%5, %y)
  //   return (%6)

  std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
      x_sym_dims,
      [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[y_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Run with the same static dims as those used to initialize the graph.
  {
    auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::erf(at::tanh(a)), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(10);
    stack.push_back(5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  // Run with inputs having different dims.
  {
    auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::erf(at::tanh(a)), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(50);
    stack.push_back(100);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }
#endif
}

TEST(DynamicShapes, GraphWith2InputsAndBroadcast) {
#ifdef TORCH_ENABLE_LLVM
  // The second input to the graph has a dim of size 1 which should be
  // broadcasted in the at::mul op.
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Float(10, 5, requires_grad=0, device=cpu),
            %y : Float(1, 5, requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int):
        %3 : Tensor = aten::tanh(%x)
        %4 : Tensor = aten::erf(%3)
        %5 : Tensor = aten::mul(%4, %y)
        return (%5))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto y_inp = graph->inputs()[1];
  auto x_type = TensorType::create(at::rand({10, 5}));
  auto y_type = TensorType::create(at::rand({1, 5}));
  auto x_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
  auto x_sym_type = x_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({x_dim0_sym, x_dim1_sym}));
  auto y_sym_type = y_type->withSymbolicShapes(std::vector<ShapeSymbol>(
      {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym}));
  graph->inputs().at(0)->setType(x_sym_type);
  graph->inputs().at(1)->setType(y_sym_type);
  for (const auto n : graph->nodes()) {
    n->output()->setType(x_sym_type);
  }

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(SS(-6), SS(-7)),
  //       %y : Float(1, SS(-7)),
  //       %SS_2 : int,
  //       %SS_3 : int):
  //   %4 : Float(SS(-6), SS(-7)) = aten::tanh(%x)
  //   %5 : Float(SS(-6), SS(-7)) = aten::erf(%4)
  //   %6 : Float(SS(-6), SS(-7)) = aten::mul(%5, %y)
  //   return (%6)

  std::vector<int64_t> symbolic_shape_inputs(
      {x_dim0_sym.value(), x_dim1_sym.value()});

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[y_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Run with the same static dims as those used to initialize the graph.
  {
    auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::erf(at::tanh(a)), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(10);
    stack.push_back(5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  // Run with inputs having different dims.
  {
    auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::erf(at::tanh(a)), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(50);
    stack.push_back(100);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }
#endif
}

TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) {
#ifdef TORCH_ENABLE_LLVM
  // The inputs and the output have a partially symbolic shape: the first dim
  // is static (1) and only the second dim is symbolic.
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Float(1, 5, requires_grad=0, device=cpu),
            %y : Float(1, 5, requires_grad=0, device=cpu),
            %SS_2 : int):
        %4 : Tensor = aten::tanh(%x)
        %5 : Tensor = aten::mul(%4, %y)
        return (%5))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto y_inp = graph->inputs()[1];
  auto x_type = TensorType::create(at::rand({1, 5}));
  auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
  auto x_sym_type = x_type->withSymbolicShapes(std::vector<ShapeSymbol>(
      {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym}));
  graph->inputs().at(0)->setType(x_sym_type);
  graph->inputs().at(1)->setType(x_sym_type);
  for (const auto n : graph->nodes()) {
    n->output()->setType(x_sym_type);
  }

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(1, SS(-2)),
  //       %y : Float(1, SS(-2)),
  //       %SS_2 : int):
  //   %3 : Float(1, SS(-2)) = aten::tanh(%x)
  //   %4 : Float(1, SS(-2)) = aten::mul(%3, %y)
  //   return (%4)

  std::vector<int64_t> symbolic_shape_inputs({x_dim1_sym.value()});

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[y_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Run with the same static dims as those used to initialize the graph.
  {
    auto a = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::tanh(a), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  // Run with inputs having different dims.
  {
    auto a = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::tanh(a), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(100);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }
#endif
}

TEST(DynamicShapes, GraphWithSymbolicStrides) {
#ifdef TORCH_ENABLE_LLVM
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
            %1 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
            %SS_3 : int,
            %SS_2 : int):
        %15 : int = prim::Constant[value=1]()
        %21 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::add(%0, %1, %15)
        %22 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::mul(%21, %0)
        return (%22))IR";
  parseIR(graph_string, &*graph);

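  // Unlike TENSOR_CONT, which declares a whole tensor contiguous, the
  // per-dim descriptors below describe each stride individually: S_AS_ARG
  // means the stride for that dim is supplied to the kernel at run time
  // (taken from the incoming tensor), and S_ONE means it is statically 1.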
  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::S_AS_ARG, torch::jit::StrideInput::S_ONE};
  std::vector<torch::jit::StrideInput> output_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = output_desc;
  std::vector<int64_t> symbolic_shape_inputs = {-3, -2};
  TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  {
    auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::add(x0, x1, 1), x0);

    std::vector<at::Tensor> inputs = {x0, x1};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.push_back(32);
    stack.push_back(10);
    k.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

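  // runWithAllocatedOutputs() writes into preallocated output tensors instead
  // of allocating new ones; the outputs are placed on the stack ahead of the
  // inputs.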
  {
    auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto out =
        at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::add(x0, x1, 1), x0);

    std::vector<at::Tensor> inputs = {out, x0, x1};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.push_back(32);
    stack.push_back(10);
    k.runWithAllocatedOutputs(stack);

    ASSERT_TRUE(at::allclose(out, ref));
  }
#endif
}

TEST(DynamicShapes, GraphWithCatAndBroadcast) {
#ifdef TORCH_ENABLE_LLVM
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Float(10, 5, requires_grad=0, device=cpu),
            %y : Float(4, 5, requires_grad=0, device=cpu),
            %z : Float(1, 1, requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int,
            %SS_4 : int,
            %SS_5 : int):
        %11 : int = prim::Constant[value=0]()
        %3 : Tensor = aten::tanh(%x)
        %out1 : Tensor = aten::erf(%3)
        %out2 : Tensor = aten::relu(%y)
        %10 : Tensor[] = prim::ListConstruct(%out1, %out2)
        %25 : Tensor = aten::cat(%10, %11)
        %28 : Tensor = aten::hardswish(%25)
        %29 : Tensor = aten::mul(%28, %z)
        return (%29))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto y_inp = graph->inputs()[1];
  auto z_inp = graph->inputs()[2];
  auto x_type = TensorType::create(at::rand({10, 5}));
  auto y_type = TensorType::create(at::rand({4, 5}));
  auto z_type = TensorType::create(at::rand({1, 1}));
  auto x_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
  auto x_sym_type = x_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({x_dim0_sym, x_dim1_sym}));
  auto y_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto y_sym_type = y_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({y_dim0_sym, x_dim1_sym}));
  graph->inputs().at(0)->setType(x_sym_type);
  graph->inputs().at(1)->setType(y_sym_type);
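  // The first dim of the cat output is SS(-2) + SS(-4), which matches no
  // existing symbol, so it gets a fresh one.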
  auto cat_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto cat_out_type = x_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({cat_dim0_sym, x_dim1_sym}));
  auto nodeIt = graph->nodes().begin();
  ++nodeIt;
  nodeIt->output()->setType(x_sym_type); // aten::tanh
  ++nodeIt;
  nodeIt->output()->setType(x_sym_type); // aten::erf
  ++nodeIt;
  nodeIt->output()->setType(y_sym_type); // aten::relu
  ++nodeIt;
  ++nodeIt;
  nodeIt->output()->setType(cat_out_type); // aten::cat
  ++nodeIt;
  nodeIt->output()->setType(cat_out_type); // aten::hardswish
  ++nodeIt;
  nodeIt->output()->setType(cat_out_type); // aten::mul

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(SS(-2), SS(-3)),
  //       %y : Float(SS(-4), SS(-3)),
  //       %z : Float(1, 1),
  //       %SS_2 : int,
  //       %SS_3 : int,
  //       %SS_4 : int,
  //       %SS_5 : int):
  //   %7 : int = prim::Constant[value=0]()
  //   %8 : Float(SS(-2), SS(-3)) = aten::tanh(%x)
  //   %9 : Float(SS(-2), SS(-3)) = aten::erf(%8)
  //   %10 : Float(SS(-4), SS(-3)) = aten::relu(%y)
  //   %11 : Tensor[] = prim::ListConstruct(%9, %10)
  //   %12 : Float(SS(-5), SS(-3)) = aten::cat(%11, %7)
  //   %13 : Float(SS(-5), SS(-3)) = aten::hardswish(%12)
  //   %14 : Float(SS(-5), SS(-3)) = aten::mul(%13, %z)
  //   return (%14)

  std::vector<int64_t> symbolic_shape_inputs(
      {x_dim0_sym.value(),
       x_dim1_sym.value(),
       y_dim0_sym.value(),
       cat_dim0_sym.value()});

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[y_inp] = input_desc;
  symbolic_strides[z_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto b = at::rand({4, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto c = at::rand({1, 1}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto ref = at::mul(
      at::hardswish(at::cat({at::erf(at::tanh(a)), at::relu(b)}, 0)), c);

  std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
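  // Concrete values for SS(-2), SS(-3), SS(-4), and SS(-5); the cat output's
  // first dim is 10 + 4 = 14.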
  stack.push_back(10);
  stack.push_back(5);
  stack.push_back(4);
  stack.push_back(14);
  kernel.run(stack);

  auto o = stack[0].toTensor();
  ASSERT_TRUE(at::allclose(o, ref));
#endif
}

TEST(DynamicShapes, GraphFromModel) {
#ifdef TORCH_ENABLE_LLVM
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
            %1 : Float(SS(-2), SS(-4), requires_grad=0, device=cpu),
            %2 : Float(SS(-2), SS(-5), requires_grad=0, device=cpu),
            %input.4 : Long(SS(-2), SS(-6), requires_grad=0, device=cpu),
            %4 : Float(SS(-7), requires_grad=0, device=cpu),
            %5 : Float(SS(-7), requires_grad=0, device=cpu),
            %SS_10 : int,
            %SS_9 : int,
            %SS_8 : int,
            %SS_7 : int,
            %SS_6 : int,
            %SS_5 : int,
            %SS_4 : int,
            %SS_3 : int,
            %SS_2 : int):
        %15 : int = prim::Constant[value=1]()
        %16 : bool = prim::Constant[value=0]()
        %17 : int = prim::Constant[value=6]()
        %18 : Float(SS(-2), SS(-6), strides=[139, 1], requires_grad=0, device=cpu) = aten::to(%input.4, %17, %16, %16)
        %19 : Tensor[] = prim::ListConstruct(%0, %1, %18, %2)
        %20 : Float(SS(-2), SS(-8), strides=[261, 1], requires_grad=0, device=cpu) = aten::cat(%19, %15)
        %21 : Float(SS(-2), SS(-9), strides=[261, 1], requires_grad=0, device=cpu) = aten::add(%20, %5, %15)
        %22 : Float(SS(-2), SS(-10), requires_grad=0, device=cpu) = aten::mul(%21, %4)
        return (%22))IR";
  parseIR(graph_string, &*graph);

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->inputs().at(2)] = input_desc;
  symbolic_strides[graph->inputs().at(3)] = input_desc;
  symbolic_strides[graph->inputs().at(4)] = input_desc;
  symbolic_strides[graph->inputs().at(5)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;
  std::vector<int64_t> symbolic_shape_inputs = {
      -10, -9, -8, -7, -6, -5, -4, -3, -2};
  TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides);

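  // Concrete sizes for SS(-2) through SS(-10). aten::cat along dim 1 joins
  // widths 32, 19, 139, and 71, so SS(-8), SS(-9), and SS(-10) are all
  // 32 + 19 + 139 + 71 = 261, and the 1-D inputs %4 and %5 share that
  // length (SS(-7)).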
  int64_t i2 = 10;
  int64_t i3 = 32;
  int64_t i4 = 19;
  int64_t i5 = 71;
  int64_t i6 = 139;
  int64_t i7 = 261;
  int64_t i8 = 261;
  int64_t i9 = 261;
  int64_t i10 = 261;
  auto x0 = at::rand({i2, i3}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x1 = at::rand({i2, i4}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x2 = at::rand({i2, i5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x3 = at::ones({i2, i6}, at::TensorOptions(at::kCPU).dtype(at::kLong));
  auto x4 = at::rand({i7}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x5 = at::rand({i8}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto ref = at::mul(at::add(at::cat({x0, x1, x3, x2}, 1), x5), x4);

  {
    std::vector<at::Tensor> inputs = {x0, x1, x2, x3, x4, x5};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.emplace_back(i10);
    stack.emplace_back(i9);
    stack.emplace_back(i8);
    stack.emplace_back(i7);
    stack.emplace_back(i6);
    stack.emplace_back(i5);
    stack.emplace_back(i4);
    stack.emplace_back(i3);
    stack.emplace_back(i2);
    k.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  {
    auto out =
        at::rand({i2, i10}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    std::vector<at::Tensor> inputs = {out, x0, x1, x2, x3, x4, x5};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.emplace_back(i10);
    stack.emplace_back(i9);
    stack.emplace_back(i8);
    stack.emplace_back(i7);
    stack.emplace_back(i6);
    stack.emplace_back(i5);
    stack.emplace_back(i4);
    stack.emplace_back(i3);
    stack.emplace_back(i2);
    k.runWithAllocatedOutputs(stack);

    ASSERT_TRUE(at::allclose(out, ref));
  }
#endif
}

TEST(DynamicShapes, MultiThreadedExecution) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_template = R"IR(
      graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=${device}),
            %y : Float(SS(-2), SS(-3), requires_grad=0, device=${device}),
            %SS_2 : int,
            %SS_3 : int):
        %3 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::tanh(%x)
        %4 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::erf(%3)
        %5 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::mul(%4, %y)
        return (%5))IR";
  for (bool use_cuda : {false, true}) {
    if (!torch::cuda::is_available() && use_cuda) {
      continue;
    }
    auto device = use_cuda ? at::kCUDA : at::kCPU;
    at::jit::TemplateEnv env;
    env.s("device", use_cuda ? "cuda:0" : "cpu");
    const auto graph_string = format(graph_template, env);
    std::shared_ptr<Graph> graph = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, graph.get());

    std::vector<int64_t> symbolic_shape_inputs = {-2, -3};

    std::vector<torch::jit::StrideInput> input_desc = {
        torch::jit::StrideInput::TENSOR_CONT};
    std::unordered_map<
        const torch::jit::Value*,
        std::vector<torch::jit::StrideInput>>
        symbolic_strides;
    symbolic_strides[graph->inputs().at(0)] = input_desc;
    symbolic_strides[graph->inputs().at(1)] = input_desc;
    symbolic_strides[graph->outputs().at(0)] = input_desc;

    TensorExprKernel kernel(
        graph, {}, symbolic_shape_inputs, false, symbolic_strides);

    auto run_kernel = [&](int dim1, int dim2) {
      auto a =
          at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat));
      auto b =
          at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat));

      auto ref = at::mul(at::erf(at::tanh(a)), b);

      std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
      stack.emplace_back(dim1);
      stack.emplace_back(dim2);
      kernel.run(stack);

      auto o = stack[0].toTensor();
      ASSERT_TRUE(at::allclose(o, ref));
    };

    // Run the kernel in parallel to ensure that the run() method calls in
    // TensorExprKernel are not changing any state.
    constexpr size_t kNumThreads = 4;
    std::vector<std::thread> threads;
    for (size_t id = 0; id < kNumThreads; ++id) {
      threads.emplace_back(run_kernel, id + 5, id + 20);
    }
    for (auto& t : threads) {
      t.join();
    }
  }
#endif
}

} // namespace jit
} // namespace torch