#include <gtest/gtest.h>

#include <ATen/code_template.h>
#include <c10/core/DeviceType.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <cmath>
#include <sstream>
#include <stdexcept>
#include <thread>

namespace torch {
namespace jit {

using namespace torch::indexing;
using namespace torch::jit::tensorexpr;

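// These tests build TensorExprKernels from graphs whose tensor dims are
// symbolic: each dim is a ShapeSymbol (printed as SS(-n)) and its concrete
// value is supplied through the trailing integer inputs (%SS_*) at run time.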
TEST(DynamicShapes, SimpleGraph) {
#ifdef TORCH_ENABLE_LLVM
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Tensor,
            %SS_2 : int,
            %SS_3 : int):
        %3 : Tensor = aten::tanh(%x)
        %4 : Tensor = aten::erf(%3)
        return (%4))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto x_type = TensorType::create(at::rand({10, 5}));
  std::vector<ShapeSymbol> x_sym_dims(
      {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()});
  auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims);
  graph->inputs().at(0)->setType(x_sym_type);
  for (const auto n : graph->nodes()) {
    n->output()->setType(x_sym_type);
  }

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(SS(-2), SS(-3)),
  //       %SS_2 : int,
  //       %SS_3 : int):
  //   %3 : Float(SS(-2), SS(-3)) = aten::tanh(%x)
  //   %4 : Float(SS(-2), SS(-3)) = aten::erf(%3)
  //   return (%4)

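  // Both the input and the output are described as TENSOR_CONT, i.e.
  // contiguous in memory, so no stride values have to be passed at run time.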
  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;
  std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
      x_sym_dims,
      [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });

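  // Construct the kernel: no custom lowerings, the symbol values collected in
  // symbolic_shape_inputs are bound to the trailing int inputs of the graph,
  // pre-allocation is turned off (the `false` argument), and symbolic_strides
  // carries the stride descriptors set up above.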
  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);
  // Run with the same static dims as the ones used to initialize the graph.
  {
    auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::erf(at::tanh(a));

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a}));
    stack.push_back(10);
    stack.push_back(5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  // Run with inputs having different dims.
  {
    auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::erf(at::tanh(a));

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a}));
    stack.push_back(50);
    stack.push_back(100);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }
#endif
}

TEST(DynamicShapes, GraphWith2InputsSameDims) {
#ifdef TORCH_ENABLE_LLVM
  // The two inputs in this graph must have the same dims.
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Tensor,
            %y : Tensor,
            %SS_2 : int,
            %SS_3 : int):
        %3 : Tensor = aten::tanh(%x)
        %4 : Tensor = aten::erf(%3)
        %5 : Tensor = aten::mul(%4, %y)
        return (%5))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto y_inp = graph->inputs()[1];
  auto x_type = TensorType::create(at::rand({10, 5}));
  std::vector<ShapeSymbol> x_sym_dims(
      {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()});
  auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims);
  graph->inputs().at(0)->setType(x_sym_type);
  graph->inputs().at(1)->setType(x_sym_type);
  for (const auto n : graph->nodes()) {
    n->output()->setType(x_sym_type);
  }

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(SS(-4), SS(-5)),
  //       %y : Float(SS(-4), SS(-5)),
  //       %SS_2 : int,
  //       %SS_3 : int):
  //   %4 : Float(SS(-4), SS(-5)) = aten::tanh(%x)
  //   %5 : Float(SS(-4), SS(-5)) = aten::erf(%4)
  //   %6 : Float(SS(-4), SS(-5)) = aten::mul(%5, %y)
  //   return (%6)

  std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
      x_sym_dims,
      [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[y_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Run with the same static dims as the ones used to initialize the graph.
  {
    auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::erf(at::tanh(a)), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(10);
    stack.push_back(5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  // Run with inputs having different dims.
  {
    auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::erf(at::tanh(a)), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(50);
    stack.push_back(100);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }
#endif
}

TEST(DynamicShapes, GraphWith2InputsAndBroadcast) {
#ifdef TORCH_ENABLE_LLVM
  // The second input to the graph has a dim of size 1 which should be
  // broadcast in the at::mul op.
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Float(10, 5, requires_grad=0, device=cpu),
            %y : Float(1, 5, requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int):
        %3 : Tensor = aten::tanh(%x)
        %4 : Tensor = aten::erf(%3)
        %5 : Tensor = aten::mul(%4, %y)
        return (%5))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto y_inp = graph->inputs()[1];
  auto x_type = TensorType::create(at::rand({10, 5}));
  auto y_type = TensorType::create(at::rand({1, 5}));
  auto x_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
  auto x_sym_type = x_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({x_dim0_sym, x_dim1_sym}));
  auto y_sym_type = y_type->withSymbolicShapes(std::vector<ShapeSymbol>(
      {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym}));
  graph->inputs().at(0)->setType(x_sym_type);
  graph->inputs().at(1)->setType(y_sym_type);
  for (const auto n : graph->nodes()) {
    n->output()->setType(x_sym_type);
  }

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(SS(-6), SS(-7)),
  //       %y : Float(1, SS(-7)),
  //       %SS_2 : int,
  //       %SS_3 : int):
  //   %4 : Float(SS(-6), SS(-7)) = aten::tanh(%x)
  //   %5 : Float(SS(-6), SS(-7)) = aten::erf(%4)
  //   %6 : Float(SS(-6), SS(-7)) = aten::mul(%5, %y)
  //   return (%6)

  std::vector<int64_t> symbolic_shape_inputs(
      {x_dim0_sym.value(), x_dim1_sym.value()});

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[y_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Run with the same static dims as the ones used to initialize the graph.
  {
    auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::erf(at::tanh(a)), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(10);
    stack.push_back(5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  // Run with inputs having different dims.
  {
    auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::erf(at::tanh(a)), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(50);
    stack.push_back(100);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }
#endif
}

TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) {
#ifdef TORCH_ENABLE_LLVM
  // The first dim of the inputs and the output is a static size of 1, while
  // the second dim is symbolic, so the output shape is only partially
  // symbolic.
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Float(1, 5, requires_grad=0, device=cpu),
            %y : Float(1, 5, requires_grad=0, device=cpu),
            %SS_2 : int):
        %4 : Tensor = aten::tanh(%x)
        %5 : Tensor = aten::mul(%4, %y)
        return (%5))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto y_inp = graph->inputs()[1];
  auto x_type = TensorType::create(at::rand({1, 5}));
  auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
  auto x_sym_type = x_type->withSymbolicShapes(std::vector<ShapeSymbol>(
      {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym}));
  graph->inputs().at(0)->setType(x_sym_type);
  graph->inputs().at(1)->setType(x_sym_type);
  for (const auto n : graph->nodes()) {
    n->output()->setType(x_sym_type);
  }

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(1, SS(-2)),
  //       %y : Float(1, SS(-2)),
  //       %SS_2 : int):
  //   %3 : Float(1, SS(-2)) = aten::tanh(%x)
  //   %4 : Float(1, SS(-2)) = aten::mul(%3, %y)
  //   return (%4)

  std::vector<int64_t> symbolic_shape_inputs({x_dim1_sym.value()});

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[y_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Run with the same static dims as the ones used to initialize the graph.
  {
    auto a = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::tanh(a), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  // Run with inputs having different dims.
  {
    auto a = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::tanh(a), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(100);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }
#endif
}

TEST(DynamicShapes, GraphWithSymbolicStrides) {
#ifdef TORCH_ENABLE_LLVM
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
    graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
          %1 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
          %SS_3 : int,
          %SS_2 : int):
      %15 : int = prim::Constant[value=1]()
      %21 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::add(%0, %1, %15)
      %22 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::mul(%21, %0)
      return (%22))IR";
  parseIR(graph_string, &*graph);

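  // For the two inputs, the stride of the outer dim is only known at run time
  // and is taken from the incoming tensor (S_AS_ARG), while the inner dim has
  // a stride of 1 (S_ONE). The output is produced contiguously (TENSOR_CONT).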
  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::S_AS_ARG, torch::jit::StrideInput::S_ONE};
  std::vector<torch::jit::StrideInput> output_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = output_desc;
  std::vector<int64_t> symbolic_shape_inputs = {-3, -2};
  TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  {
    auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::add(x0, x1, 1), x0);

    std::vector<at::Tensor> inputs = {x0, x1};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.push_back(32);
    stack.push_back(10);
    k.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

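  // runWithAllocatedOutputs expects the caller to pre-allocate the output
  // tensor and place it on the stack ahead of the inputs; the kernel then
  // writes the result into it in place.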
  {
    auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto out =
        at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::add(x0, x1, 1), x0);

    std::vector<at::Tensor> inputs = {out, x0, x1};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.push_back(32);
    stack.push_back(10);
    k.runWithAllocatedOutputs(stack);

    ASSERT_TRUE(at::allclose(out, ref));
  }
#endif
}

TEST(DynamicShapes, GraphWithCatAndBroadcast) {
#ifdef TORCH_ENABLE_LLVM
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Float(10, 5, requires_grad=0, device=cpu),
            %y : Float(4, 5, requires_grad=0, device=cpu),
            %z : Float(1, 1, requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int,
            %SS_4 : int,
            %SS_5 : int):
        %11 : int = prim::Constant[value=0]()
        %3 : Tensor = aten::tanh(%x)
        %out1 : Tensor = aten::erf(%3)
        %out2 : Tensor = aten::relu(%y)
        %10 : Tensor[] = prim::ListConstruct(%out1, %out2)
        %25 : Tensor = aten::cat(%10, %11)
        %28 : Tensor = aten::hardswish(%25)
        %29 : Tensor = aten::mul(%28, %z)
        return (%29))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto y_inp = graph->inputs()[1];
  auto z_inp = graph->inputs()[2];
  auto x_type = TensorType::create(at::rand({10, 5}));
  auto y_type = TensorType::create(at::rand({4, 5}));
  auto z_type = TensorType::create(at::rand({1, 1}));
  auto x_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
  auto x_sym_type = x_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({x_dim0_sym, x_dim1_sym}));
  auto y_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto y_sym_type = y_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({y_dim0_sym, x_dim1_sym}));
  graph->inputs().at(0)->setType(x_sym_type);
  graph->inputs().at(1)->setType(y_sym_type);
  auto cat_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto cat_out_type = x_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({cat_dim0_sym, x_dim1_sym}));
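  // Annotate the node outputs in graph order. The prim::Constant (the cat
  // dim) and the prim::ListConstruct nodes are skipped since their outputs
  // are not tensors with symbolic shapes.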
  auto nodeIt = graph->nodes().begin();
  ++nodeIt;
  nodeIt->output()->setType(x_sym_type); // aten::tanh
  ++nodeIt;
  nodeIt->output()->setType(x_sym_type); // aten::erf
  ++nodeIt;
  nodeIt->output()->setType(y_sym_type); // aten::relu
  ++nodeIt;
  ++nodeIt;
  nodeIt->output()->setType(cat_out_type); // aten::cat
  ++nodeIt;
  nodeIt->output()->setType(cat_out_type); // aten::hardswish
  ++nodeIt;
  nodeIt->output()->setType(cat_out_type); // aten::mul

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(SS(-2), SS(-3)),
  //       %y : Float(SS(-4), SS(-3)),
  //       %z : Float(1, 1),
  //       %SS_2 : int,
  //       %SS_3 : int,
  //       %SS_4 : int,
  //       %SS_5 : int):
  //   %7 : int = prim::Constant[value=0]()
  //   %8 : Float(SS(-2), SS(-3)) = aten::tanh(%x)
  //   %9 : Float(SS(-2), SS(-3)) = aten::erf(%8)
  //   %10 : Float(SS(-4), SS(-3)) = aten::relu(%y)
  //   %11 : Tensor[] = prim::ListConstruct(%9, %10)
  //   %12 : Float(SS(-5), SS(-3)) = aten::cat(%11, %7)
  //   %13 : Float(SS(-5), SS(-3)) = aten::hardswish(%12)
  //   %14 : Float(SS(-5), SS(-3)) = aten::mul(%13, %z)
  //   return (%14)

  std::vector<int64_t> symbolic_shape_inputs(
      {x_dim0_sym.value(),
       x_dim1_sym.value(),
       y_dim0_sym.value(),
       cat_dim0_sym.value()});

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[y_inp] = input_desc;
  symbolic_strides[z_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto b = at::rand({4, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto c = at::rand({1, 1}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto ref = at::mul(
      at::hardswish(at::cat({at::erf(at::tanh(a)), at::relu(b)}, 0)), c);

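  // Runtime values for the symbolic dims, in the order of
  // symbolic_shape_inputs: SS(-2) = 10, SS(-3) = 5, SS(-4) = 4, and
  // SS(-5) = 10 + 4 = 14, the first dim of the cat output.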
  std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
  stack.push_back(10);
  stack.push_back(5);
  stack.push_back(4);
  stack.push_back(14);
  kernel.run(stack);

  auto o = stack[0].toTensor();
  ASSERT_TRUE(at::allclose(o, ref));
#endif
}

TEST(DynamicShapes, GraphFromModel) {
#ifdef TORCH_ENABLE_LLVM
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
    graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
          %1 : Float(SS(-2), SS(-4), requires_grad=0, device=cpu),
          %2 : Float(SS(-2), SS(-5), requires_grad=0, device=cpu),
          %input.4 : Long(SS(-2), SS(-6), requires_grad=0, device=cpu),
          %4 : Float(SS(-7), requires_grad=0, device=cpu),
          %5 : Float(SS(-7), requires_grad=0, device=cpu),
          %SS_10 : int,
          %SS_9 : int,
          %SS_8 : int,
          %SS_7 : int,
          %SS_6 : int,
          %SS_5 : int,
          %SS_4 : int,
          %SS_3 : int,
          %SS_2 : int):
      %15 : int = prim::Constant[value=1]()
      %16 : bool = prim::Constant[value=0]()
      %17 : int = prim::Constant[value=6]()
      %18 : Float(SS(-2), SS(-6), strides=[139, 1], requires_grad=0, device=cpu) = aten::to(%input.4, %17, %16, %16)
      %19 : Tensor[] = prim::ListConstruct(%0, %1, %18, %2)
      %20 : Float(SS(-2), SS(-8), strides=[261, 1], requires_grad=0, device=cpu) = aten::cat(%19, %15)
      %21 : Float(SS(-2), SS(-9), strides=[261, 1], requires_grad=0, device=cpu) = aten::add(%20, %5, %15)
      %22 : Float(SS(-2), SS(-10), requires_grad=0, device=cpu) = aten::mul(%21, %4)
      return (%22))IR";
  parseIR(graph_string, &*graph);

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->inputs().at(2)] = input_desc;
  symbolic_strides[graph->inputs().at(3)] = input_desc;
  symbolic_strides[graph->inputs().at(4)] = input_desc;
  symbolic_strides[graph->inputs().at(5)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;
  std::vector<int64_t> symbolic_shape_inputs = {
      -10, -9, -8, -7, -6, -5, -4, -3, -2};
  TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides);

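  // Concrete sizes for this run: the tensors concatenated along dim 1 have
  // widths 32, 19, 139, and 71, so the cat output width (and the length of
  // the add/mul operands %4 and %5) is 32 + 19 + 139 + 71 = 261.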
  int64_t i2 = 10;
  int64_t i3 = 32;
  int64_t i4 = 19;
  int64_t i5 = 71;
  int64_t i6 = 139;
  int64_t i7 = 261;
  int64_t i8 = 261;
  int64_t i9 = 261;
  int64_t i10 = 261;
  auto x0 = at::rand({i2, i3}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x1 = at::rand({i2, i4}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x2 = at::rand({i2, i5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x3 = at::ones({i2, i6}, at::TensorOptions(at::kCPU).dtype(at::kLong));
  auto x4 = at::rand({i7}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x5 = at::rand({i8}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto ref = at::mul(at::add(at::cat({x0, x1, x3, x2}, 1), x5), x4);

  {
    std::vector<at::Tensor> inputs = {x0, x1, x2, x3, x4, x5};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.emplace_back(i10);
    stack.emplace_back(i9);
    stack.emplace_back(i8);
    stack.emplace_back(i7);
    stack.emplace_back(i6);
    stack.emplace_back(i5);
    stack.emplace_back(i4);
    stack.emplace_back(i3);
    stack.emplace_back(i2);
    k.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  {
    auto out =
        at::rand({i2, i10}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    std::vector<at::Tensor> inputs = {out, x0, x1, x2, x3, x4, x5};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.emplace_back(i10);
    stack.emplace_back(i9);
    stack.emplace_back(i8);
    stack.emplace_back(i7);
    stack.emplace_back(i6);
    stack.emplace_back(i5);
    stack.emplace_back(i4);
    stack.emplace_back(i3);
    stack.emplace_back(i2);
    k.runWithAllocatedOutputs(stack);

    ASSERT_TRUE(at::allclose(out, ref));
  }
#endif
}

TEST(DynamicShapes, MultiThreadedExecution) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_template = R"IR(
      graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=${device}),
            %y : Float(SS(-2), SS(-3), requires_grad=0, device=${device}),
            %SS_2 : int,
            %SS_3 : int):
        %3 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::tanh(%x)
        %4 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::erf(%3)
        %5 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::mul(%4, %y)
        return (%5))IR";
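  // Instantiate the IR template once per device: the ${device} placeholder is
  // filled in via TemplateEnv for CPU and, when CUDA is available, for cuda:0.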
  for (bool use_cuda : {false, true}) {
    if (!torch::cuda::is_available() && use_cuda) {
      continue;
    }
    auto device = use_cuda ? at::kCUDA : at::kCPU;
    at::jit::TemplateEnv env;
    env.s("device", use_cuda ? "cuda:0" : "cpu");
    const auto graph_string = format(graph_template, env);
    std::shared_ptr<Graph> graph = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, graph.get());

    std::vector<int64_t> symbolic_shape_inputs = {-2, -3};

    std::vector<torch::jit::StrideInput> input_desc = {
        torch::jit::StrideInput::TENSOR_CONT};
    std::unordered_map<
        const torch::jit::Value*,
        std::vector<torch::jit::StrideInput>>
        symbolic_strides;
    symbolic_strides[graph->inputs().at(0)] = input_desc;
    symbolic_strides[graph->inputs().at(1)] = input_desc;
    symbolic_strides[graph->outputs().at(0)] = input_desc;

    TensorExprKernel kernel(
        graph, {}, symbolic_shape_inputs, false, symbolic_strides);

    auto run_kernel = [&](int dim1, int dim2) {
      auto a =
          at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat));
      auto b =
          at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat));

      auto ref = at::mul(at::erf(at::tanh(a)), b);

      std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
      stack.emplace_back(dim1);
      stack.emplace_back(dim2);
      kernel.run(stack);

      auto o = stack[0].toTensor();
      ASSERT_TRUE(at::allclose(o, ref));
    };

    // Run the kernel from several threads concurrently to ensure that calls
    // to TensorExprKernel::run() do not mutate any shared state.
    constexpr size_t kNumThreads = 4;
    std::vector<std::thread> threads;
    for (size_t id = 0; id < kNumThreads; ++id) {
      threads.emplace_back(run_kernel, id + 5, id + 20);
    }
    for (auto& t : threads) {
      t.join();
    }
  }
#endif
}

} // namespace jit
} // namespace torch