1 | #include <gtest/gtest.h> |
2 | |
3 | #include <ATen/core/interned_strings.h> |
4 | #include <c10/util/Exception.h> |
5 | #include <c10/util/Optional.h> |
6 | #include <test/cpp/jit/test_utils.h> |
7 | #include <torch/csrc/jit/ir/ir.h> |
8 | #include <torch/csrc/jit/ir/ir_views.h> |
9 | #include <torch/csrc/jit/ir/irparser.h> |
10 | #include <torch/csrc/jit/passes/constant_propagation.h> |
11 | #include <torch/csrc/jit/passes/symbolic_shape_analysis.h> |
12 | #include <torch/csrc/jit/passes/symbolic_shape_cache.h> |
13 | #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h> |
14 | #include <torch/csrc/jit/passes/utils/subgraph_utils.h> |
15 | #include <torch/csrc/jit/runtime/graph_iterator.h> |
16 | #include <torch/csrc/jit/runtime/interpreter.h> |
17 | #include <torch/csrc/jit/testing/file_check.h> |
18 | #include <torch/cuda.h> |
19 | #include <unordered_map> |
20 | |
21 | namespace torch { |
22 | namespace jit { |
23 | |
24 | namespace { |
25 | |
26 | Node* findNode(std::shared_ptr<Graph>& g, Symbol k) { |
27 | DepthFirstGraphNodeIterator graph_it(g); |
28 | for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) { |
29 | if (node->kind() == k) { |
30 | return node; |
31 | } |
32 | } |
33 | TORCH_INTERNAL_ASSERT(false, "Couldn't find node" ); |
34 | } |
35 | } // namespace |
36 | |
37 | TEST(ShapeAnalysisTest, DynamicShapesFusion) { |
38 | // Test Generalizing shapes to symbolic dimensions, guarding those symbolic |
39 | // dimensions and passing in runtime computed symbolic dimensions via inlined |
40 | // shape functions |
41 | std::shared_ptr<Graph> subgraph = std::make_shared<Graph>(); |
42 | const auto graph_string = R"IR( |
43 | graph(%x.1 : Tensor, %y.1 : Tensor, %z: Tensor): |
44 | %11 : int = prim::Constant[value=0]() |
45 | %3 : Tensor = aten::tanh(%x.1) |
46 | %out1.1 : Tensor = aten::erf(%3) |
47 | %out2.1 : Tensor = aten::relu(%y.1) |
48 | %10 : Tensor[] = prim::ListConstruct(%out1.1, %out2.1) |
49 | %25 : Tensor = aten::cat(%10, %11) |
50 | %28 : Tensor = aten::hardswish(%25) |
51 | %29 : Tensor = aten::mul(%28, %z) |
52 | return (%28))IR" ; |
53 | torch::jit::parseIR(graph_string, subgraph.get()); |
54 | |
55 | /* |
56 | set up fused TensorExprGroup |
57 | */ |
58 | |
59 | std::shared_ptr<Graph> g = std::make_shared<Graph>(); |
60 | auto x_inp = g->addInput("x_inp" ); |
61 | auto y_inp = g->addInput("y_inp" ); |
62 | auto z_inp = g->addInput("z_inp" ); |
63 | auto x_type = TensorType::create(at::rand({10, 5})); |
64 | auto y_type = TensorType::create(at::rand({4, 5})); |
65 | auto z_type = TensorType::create(at::rand({1, 1})); |
66 | x_inp->setType(x_type); |
67 | y_inp->setType(y_type); |
68 | z_inp->setType(z_type); |
69 | subgraph->inputs().at(0)->setType(x_type); |
70 | subgraph->inputs().at(1)->setType(y_type); |
71 | subgraph->inputs().at(2)->setType(z_type); |
72 | subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5}))); |
73 | auto output = g->insertNode(g->create(prim::TensorExprGroup))->output(); |
74 | subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5}))); |
75 | output->node()->addInput(x_inp); |
76 | output->node()->addInput(y_inp); |
77 | output->node()->addInput(z_inp); |
78 | output->node()->g_(attr::Subgraph, subgraph); |
79 | |
80 | auto success = GenerateGuard(output->node()); |
81 | TORCH_INTERNAL_ASSERT(success); |
82 | testing::FileCheck() |
83 | .check("TensorExprDynamicGuard" ) |
84 | ->check_next("prim::If" ) |
85 | ->check("aten::add" ) |
86 | ->check("TensorExprGroup" ) |
87 | ->check_same("symbolic_shape_inputs" ) |
88 | ->check("block1" ) |
89 | ->check("aten::cat" ) |
90 | ->run(*g); |
91 | |
92 | // clang-format off |
93 | /* Graph Should Look Something like: (note: strides not yet handled) |
94 | graph(%x_inp : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu), |
95 | %y_inp : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu), |
96 | %z_inp : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)): |
97 | %4 : bool = prim::TensorExprDynamicGuard[types=[Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)]](%x_inp, %y_inp, %z_inp) |
98 | %5 : Tensor = prim::If(%4) |
99 | block0(): |
100 | %15 : int[] = aten::size(%x_inp) |
101 | %16 : int[] = aten::size(%y_inp) |
102 | %17 : int = prim::Constant[value=1]() |
103 | %18 : int = prim::Constant[value=0]() |
104 | %elem.3 : int = aten::__getitem__(%15, %18) # <string>:40:10 |
105 | %elem.5 : int = aten::__getitem__(%15, %17) # <string>:40:10 |
106 | %elem.11 : int = aten::__getitem__(%16, %18) # <string>:40:10 |
107 | %cat_dim_size.48 : int = aten::add(%elem.3, %elem.11) # <string>:321:29 |
108 | %3 : Tensor = prim::TensorExprGroup_0[symbolic_shape_inputs=[-5, -4, -3, -2]](%x_inp, %y_inp, %z_inp, %cat_dim_size.48, %elem.11, %elem.5, %elem.3) |
109 | -> (%3) |
110 | block1(): |
111 | // FallbackGraph is inlined |
112 | %14 : Tensor = prim::FallbackGraph_1(%x_inp, %y_inp, %z_inp) |
113 | -> (%14) |
114 | return () |
115 | with prim::TensorExprGroup_0 = graph(%x.1 : Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), |
116 | %y.1 : Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), |
117 | %z : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu), |
118 | %SS_5 : int, |
119 | %SS_4 : int, |
120 | %SS_3 : int, |
121 | %SS_2 : int): |
122 | %3 : int = prim::Constant[value=0]() |
123 | %4 : Tensor(SS(-2), SS(-3)) = aten::tanh(%x.1) |
124 | %5 : Tensor(SS(-2), SS(-3)) = aten::erf(%4) |
125 | %6 : Tensor(SS(-4), SS(-3)) = aten::relu(%y.1) |
126 | %7 : Tensor[] = prim::ListConstruct(%5, %6) |
127 | %8 : Tensor(SS(-5), SS(-3)) = aten::cat(%7, %3) |
128 | %9 : Tensor(SS(-5), SS(-3)) = aten::hardswish(%8) |
129 | %10 : Tensor(SS(-5), SS(-3)) = aten::mul(%9, %z) |
130 | return (%9) |
131 | */ |
132 | // clang-format on |
133 | |
134 | DepthFirstGraphNodeIterator graph_it(g); |
135 | Node* te_group = findNode(g, prim::TensorExprGroup); |
136 | |
137 | /* |
138 | Test that input to the kernel - (10, 5), (4, 5), (1, 1) - are correctly |
139 | generalized to sym dimensions, and that the output - (10 + 4, 5) |
140 | correctly preserves non-catted dim as sym shape and catted dim as new sym |
141 | shape |
142 | */ |
143 | |
144 | auto tensorexpr_graph = te_group->g(attr::Subgraph); |
145 | auto inp1 = tensorexpr_graph->inputs().at(0)->type()->expect<TensorType>(); |
146 | auto inp2 = tensorexpr_graph->inputs().at(1)->type()->expect<TensorType>(); |
147 | auto inp3 = tensorexpr_graph->inputs().at(2)->type()->expect<TensorType>(); |
148 | auto out = tensorexpr_graph->outputs().at(0)->type()->expect<TensorType>(); |
149 | |
150 | // 1 dims are preserved |
151 | auto inp3_sizes = inp3->sizes().concrete_sizes(); |
152 | TORCH_INTERNAL_ASSERT(inp3_sizes); |
153 | TORCH_INTERNAL_ASSERT( |
154 | inp3_sizes->size() == 2 && inp3_sizes->at(0) == 1 && |
155 | inp3_sizes->at(1) == 1); |
156 | |
157 | // 5 made into sym shape |
158 | ASSERT_EQ( |
159 | inp1->symbolic_sizes()[1].value(), inp2->symbolic_sizes()[1].value()); |
160 | ASSERT_EQ( |
161 | out->symbolic_sizes()[1].value(), inp2->symbolic_sizes()[1].value()); |
162 | |
163 | // 4, 10, 14 are different sym shapes |
164 | ASSERT_NE( |
165 | inp1->symbolic_sizes()[0].value(), inp2->symbolic_sizes()[0].value()); |
166 | ASSERT_NE( |
167 | out->symbolic_sizes()[0].value(), inp1->symbolic_sizes()[0].value()); |
168 | ASSERT_NE( |
169 | out->symbolic_sizes()[0].value(), inp2->symbolic_sizes()[0].value()); |
170 | |
171 | /* |
172 | Test guard behaves correctly at runtime and symbolic shapes are computed |
173 | correctly. As we don't have TE Kernel support for dynamic shapes we're |
174 | going to return all of the computed runtime symbolic dimensions as outputs |
175 | of the graph on guard success, and return None on guard failure |
176 | */ |
177 | |
178 | // Setting up guard to return sym shapes on guard success and None on failure |
179 | Node* if_node = findNode(g, prim::If); |
180 | IfView if_v(if_node); |
181 | if_node->eraseOutput(0); |
182 | if_v.thenBlock()->eraseOutput(0); |
183 | if_v.elseBlock()->eraseOutput(0); |
184 | WithInsertPoint guard(if_node); |
185 | auto none_val = g->insertConstant(IValue()); |
186 | |
187 | auto sym_shapes = te_group->is(Symbol::attr("symbolic_shape_inputs" )); |
188 | auto offset = te_group->inputs().size() - sym_shapes.size(); |
189 | for (size_t i = 0; i < sym_shapes.size(); ++i) { |
190 | if_v.thenBlock()->insertOutput(i, te_group->inputs().at(offset + i)); |
191 | if_v.elseBlock()->insertOutput(i, none_val); |
192 | if_node->insertOutput(i)->setType(OptionalType::create(IntType::get())); |
193 | } |
194 | |
195 | auto new_outputs = g->createTuple(if_node->outputs())->insertAfter(if_node); |
196 | |
197 | g->registerOutput(new_outputs->output()); |
198 | te_group->destroy(); |
199 | findNode(g, prim::FallbackGraph)->destroy(); |
200 | |
201 | // Testing bad inputs |
202 | |
203 | auto first_inp = at::rand({2, 5}); |
204 | std::vector<std::vector<at::Tensor>> second_inps = { |
205 | {at::rand({3, 4}), at::rand({1, 1})}, // sym shape mismatch |
206 | {at::rand({5, 2}).transpose(0, 1), at::rand({1, 1})}, // discontiguous |
207 | {at::zeros({2, 5}).to(at::ScalarType::Int), |
208 | at::rand({1, 1})}, // wrong dtype |
209 | {at::rand({2, 5, 1}), at::rand({1, 1})}, // wrong # dims |
210 | {at::rand({2, 5}).requires_grad_(true), |
211 | at::rand({1, 1})}, // requires grad |
212 | {at::rand({2, 5}), at::rand({1, 12})}, // concrete dim mismatch (1) |
213 | }; |
214 | if (torch::cuda::is_available()) { |
215 | second_inps.push_back({at::rand({2, 5}).cuda(), at::rand({1, 1})}); |
216 | } |
217 | for (const auto& last_inps : second_inps) { |
218 | // todo - reusing interpreter across iters gave error |
219 | Code code(g, "" ); |
220 | InterpreterState interp(code); |
221 | auto stack = createStack({at::rand({2, 5}), last_inps[0], last_inps[1]}); |
222 | interp.run(stack); |
223 | TORCH_INTERNAL_ASSERT(pop(stack).toTuple()->elements().at(0).isNone()); |
224 | } |
225 | |
226 | // Test good inputs |
227 | Code code(g, "" ); |
228 | InterpreterState interp(code); |
229 | std::vector<at::Tensor> inps = { |
230 | at::rand({2, 5}), at::rand({4, 5}), at::rand({1, 1})}; |
231 | Stack stack(inps.begin(), inps.end()); |
232 | interp.run(stack); |
233 | auto tuple = pop(stack).toTuple(); |
234 | TORCH_INTERNAL_ASSERT(tuple->elements().at(0).isInt()); |
235 | |
236 | // Testing that the sym shape calculation was correct |
237 | for (size_t i = 0; i < sym_shapes.size(); ++i) { |
238 | auto sym_shape = sym_shapes[i]; |
239 | auto computed_value = tuple->elements().at(i).toInt(); |
240 | if (sym_shape == inp1->symbolic_sizes().at(0).value()) { |
241 | ASSERT_EQ(computed_value, 2); |
242 | } else if (sym_shape == inp1->symbolic_sizes().at(1).value()) { |
243 | ASSERT_EQ(computed_value, 5); |
244 | } else if (sym_shape == inp2->symbolic_sizes().at(0).value()) { |
245 | ASSERT_EQ(computed_value, 4); |
246 | } else if (sym_shape == out->symbolic_sizes().at(0).value()) { |
247 | ASSERT_EQ(computed_value, 6); |
248 | } else { |
249 | TORCH_INTERNAL_ASSERT(false); |
250 | } |
251 | } |
252 | } |
253 | |
254 | TEST(ShapeAnalysisTest, MovingConstantOutOfFusionGroups) { |
255 | std::shared_ptr<Graph> subgraph = std::make_shared<Graph>(); |
256 | const auto graph_string = R"IR( |
257 | graph(%x.1 : Tensor): |
258 | %none : NoneType = prim::Constant() |
259 | %size1 : int = prim::Constant[value=1]() |
260 | %size10 : int = prim::Constant[value=10]() |
261 | %sizes : int[] = prim::ListConstruct(%size10, %size1) |
262 | %device : Device = prim::Constant[value="cpu"]() |
263 | %10 : Tensor = aten::ones(%sizes, %none, %none, %device, %none) |
264 | %3 : Tensor = aten::tanh(%x.1) |
265 | %29 : Tensor = aten::mul(%3, %10) |
266 | return (%29))IR" ; |
267 | torch::jit::parseIR(graph_string, subgraph.get()); |
268 | ConstantPropagation(subgraph); |
269 | |
270 | std::shared_ptr<Graph> g = std::make_shared<Graph>(); |
271 | auto x_inp = g->addInput("x_inp" ); |
272 | auto x_type = TensorType::create(at::rand({10, 5})); |
273 | x_inp->setType(x_type); |
274 | subgraph->inputs().at(0)->setType(x_type); |
275 | subgraph->outputs().at(0)->setType(x_type); |
276 | auto output = g->insertNode(g->create(prim::TensorExprGroup))->output(); |
277 | output->node()->addInput(x_inp); |
278 | output->node()->g_(attr::Subgraph, subgraph); |
279 | |
280 | auto success = GenerateGuard(output->node()); |
281 | TORCH_INTERNAL_ASSERT(success); |
282 | |
283 | // Check that the constants have been moved out of the fused graph. |
284 | // This should result in not have any conditionals other than the one |
285 | // checking the result of TensorExprDynamicGuard. |
286 | testing::FileCheck() |
287 | .check("TensorExprDynamicGuard" ) |
288 | ->check_next("prim::If" ) |
289 | ->check_not("prim::If" ) // no other IFs due to constants. |
290 | ->check("TensorExprGroup" ) |
291 | ->check("block1" ) |
292 | ->check("FallbackGraph" ) |
293 | ->run(*g); |
294 | } |
295 | |
296 | namespace { |
297 | |
// Empty optional used in the shape lists of the tests below in place of a
// concrete size; presumably it denotes an unknown (symbolic) dimension.
c10::optional<int64_t> sym_dim = c10::nullopt;
299 | |
300 | // NOLINTNEXTLINE(bugprone-easily-swappable-parameters) |
301 | void assertShapeEqual(c10::SymbolicShape& a, c10::SymbolicShape& e) { |
302 | auto a_canonical = CanonicalizedSymbolicShape(a); |
303 | auto e_canonical = CanonicalizedSymbolicShape(e); |
304 | EXPECT_EQ(a_canonical, e_canonical); |
305 | } |
306 | |
307 | void assertShapeEqual( |
308 | c10::optional<std::vector<c10::SymbolicShape>>& actual, |
309 | std::vector<c10::optional<int64_t>> expected) { |
310 | ASSERT_TRUE(actual.has_value()); |
311 | ASSERT_EQ(actual->size(), 1); |
312 | |
313 | auto symb_expected = c10::SymbolicShape(expected); |
314 | assertShapeEqual(actual->at(0), symb_expected); |
315 | } |
316 | |
317 | const FunctionSchema* getSchema(const char* name) { |
318 | return &(getOperatorForLiteral(name)->schema()); |
319 | } |
320 | } // namespace |
321 | |
322 | TEST(ShapeAnalysisTest, SymbolicShapeAPI) { |
323 | // Figure out how to fetch a function schema |
324 | |
325 | // Ask someone else how to create a function schema / operator in C++ |
326 | auto schema = getSchema( |
327 | "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" ); |
328 | |
329 | c10::IValue const_size_1 = std::vector<int64_t>{64, 56, 56}; |
330 | c10::IValue const_size_2 = std::vector<int64_t>{1, 56, 56}; |
331 | |
332 | // Check vector initializer list syntax |
333 | c10::SymbolicShape ss_concrete = |
334 | std::vector<c10::optional<int64_t>>{1, 56, 56}; |
335 | c10::SymbolicShape ss1 = std::vector<c10::optional<int64_t>>{sym_dim, 56, 56}; |
336 | c10::SymbolicShape ss2 = |
337 | std::vector<c10::optional<int64_t>>{64, sym_dim, sym_dim}; |
338 | c10::SymbolicShape ss3 = |
339 | std::vector<c10::optional<int64_t>>{sym_dim, sym_dim, sym_dim, sym_dim}; |
340 | |
341 | auto res = calculateSymbolicShapesOnOp( |
342 | schema, std::vector<SSAInput>{const_size_1, const_size_1}); |
343 | assertShapeEqual(res, {64, 56, 56}); |
344 | |
345 | res = calculateSymbolicShapesOnOp( |
346 | schema, std::vector<SSAInput>{const_size_1, const_size_2}); |
347 | assertShapeEqual(res, {64, 56, 56}); |
348 | |
349 | res = calculateSymbolicShapesOnOp( |
350 | schema, std::vector<SSAInput>{const_size_1, ss1}); |
351 | assertShapeEqual(res, {64, 56, 56}); |
352 | |
353 | res = calculateSymbolicShapesOnOp( |
354 | schema, std::vector<SSAInput>{const_size_2, ss1}); |
355 | assertShapeEqual(res, {sym_dim, 56, 56}); |
356 | |
357 | res = calculateSymbolicShapesOnOp( |
358 | schema, std::vector<SSAInput>{ss_concrete, ss2}); |
359 | assertShapeEqual(res, {64, 56, 56}); |
360 | |
361 | res = calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{ss2, ss3}); |
362 | assertShapeEqual(res, {sym_dim, 64, sym_dim, sym_dim}); |
363 | } |
364 | |
365 | TEST(ShapeAnalysisTest, BoundedSymbolicShapes) { |
366 | auto schema = getSchema("aten::nonzero(Tensor self) -> (Tensor)" ); |
367 | |
368 | // Test that we generate symbolic shapes for the output of a nonzero op |
369 | c10::IValue const_size_1 = std::vector<int64_t>{5, 10}; |
370 | auto res = |
371 | calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_1}); |
372 | assertShapeEqual(res, {sym_dim, 2}); |
373 | |
374 | // Test that nonzero can also create concrete shapes |
375 | c10::IValue const_size_2 = std::vector<int64_t>({1, 0}); |
376 | res = |
377 | calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_2}); |
378 | assertShapeEqual(res, {0, 2}); |
379 | } |
380 | |
381 | TEST(ShapeAnalysisTest, SymbolicShapeCaching) { |
382 | clear_shape_cache(); |
383 | auto schema = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor" ); |
384 | |
385 | c10::IValue const_size_1 = std::vector<int64_t>{64, 56}; |
386 | c10::IValue const_size_2 = std::vector<int64_t>{64, 56}; |
387 | c10::IValue const_size_3 = std::vector<int64_t>{64, 20}; |
388 | |
389 | c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64}); |
390 | c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64}); |
391 | c10::SymbolicShape ss3 = c10::SymbolicShape({sym_dim, sym_dim}); |
392 | |
393 | auto res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1}); |
394 | assertShapeEqual(res, {sym_dim, 56}); |
395 | auto res1_val = res->at(0); |
396 | |
397 | // The exact same arguments should return the exact same result |
398 | res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1}); |
399 | auto res2_val = res->at(0); |
400 | EXPECT_EQ(res1_val, res2_val); |
401 | EXPECT_EQ(get_shape_cache_size(), 1); |
402 | |
403 | // Same shape but different symbols should return same shape |
404 | // but different symbolic indicies |
405 | res = calculateSymbolicShapesOnOp(schema, {ss2, const_size_2}); |
406 | auto res3_val = res->at(0); |
407 | |
408 | assertShapeEqual(res3_val, res2_val); |
409 | EXPECT_NE(res3_val, res2_val); |
410 | EXPECT_EQ(get_shape_cache_size(), 1); |
411 | |
412 | // Different concrete shape should be cached separately |
413 | res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_3}); |
414 | assertShapeEqual(res, {sym_dim, 20}); |
415 | EXPECT_EQ(get_shape_cache_size(), 2); |
416 | |
417 | res = calculateSymbolicShapesOnOp(schema, {ss3, const_size_3}); |
418 | assertShapeEqual(res, {sym_dim, 20}); |
419 | EXPECT_EQ(get_shape_cache_size(), 3); |
420 | |
421 | res = calculateSymbolicShapesOnOp(schema, {ss3, ss3}); |
422 | assertShapeEqual(res, {sym_dim, sym_dim}); |
423 | EXPECT_EQ(get_shape_cache_size(), 4); |
424 | } |
425 | |
426 | TEST(ShapeAnalysisTest, ShapeCacheMultipleFns) { |
427 | clear_shape_cache(); |
428 | |
429 | auto squeeze_op = |
430 | getSchema("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)" ); |
431 | auto mul_tensor = |
432 | getSchema("aten::mul.Tensor(Tensor self, Tensor other) -> Tensor" ); |
433 | auto mul_scalar = |
434 | getSchema("aten::mul.Scalar(Tensor self, Scalar other) -> Tensor" ); |
435 | auto div_tensor = |
436 | getSchema("aten::div.Tensor(Tensor self, Tensor other) -> Tensor" ); |
437 | auto matmul = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor" ); |
438 | |
439 | c10::IValue const_int = 1; |
440 | |
441 | c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64}); |
442 | |
443 | auto res = calculateSymbolicShapesOnOp(squeeze_op, {ss1, const_int}); |
444 | assertShapeEqual(res, {sym_dim, 64}); |
445 | |
446 | // Show that cache can handle multiple functions |
447 | res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int}); |
448 | assertShapeEqual(res, {sym_dim, 64}); |
449 | EXPECT_EQ(get_shape_cache_size(), 2); |
450 | |
451 | res = calculateSymbolicShapesOnOp(mul_tensor, {ss1, ss1}); |
452 | assertShapeEqual(res, {sym_dim, 64}); |
453 | EXPECT_EQ(get_shape_cache_size(), 3); |
454 | |
455 | // Even when the expected outcome is the same, should not collide |
456 | res = calculateSymbolicShapesOnOp(div_tensor, {ss1, ss1}); |
457 | assertShapeEqual(res, {sym_dim, 64}); |
458 | EXPECT_EQ(get_shape_cache_size(), 4); |
459 | |
460 | // Don't lose cached objects |
461 | res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int}); |
462 | assertShapeEqual(res, {sym_dim, 64}); |
463 | EXPECT_EQ(get_shape_cache_size(), 4); |
464 | |
465 | res = calculateSymbolicShapesOnOp(matmul, {ss1, ss1}); |
466 | // SSA can infer that sym_dim is 64 as both tensors |
467 | // use the same sym_dim |
468 | assertShapeEqual(res, {64, 64}); |
469 | EXPECT_EQ(get_shape_cache_size(), 5); |
470 | } |
471 | |
472 | TEST(ShapeAnalysisTest, TestShapeMultipleReturns) { |
473 | clear_shape_cache(); |
474 | |
475 | auto max_dim_op = getSchema( |
476 | "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)" ); |
477 | c10::IValue const_int = 1; |
478 | c10::IValue false_ival = false; |
479 | |
480 | c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64}); |
481 | c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64}); |
482 | |
483 | auto res = |
484 | calculateSymbolicShapesOnOp(max_dim_op, {ss1, const_int, false_ival}); |
485 | c10::SymbolicShape expected_res = c10::SymbolicShape({sym_dim}); |
486 | assertShapeEqual(res->at(0), expected_res); |
487 | // res0 and res1 should share the same symbolic symbol |
488 | EXPECT_EQ(res->at(0), res->at(1)); |
489 | |
490 | // Also test that the shape cache also returns consistent result shapes |
491 | res = calculateSymbolicShapesOnOp(max_dim_op, {ss2, const_int, false_ival}); |
492 | assertShapeEqual(res->at(0), expected_res); |
493 | EXPECT_EQ(res->at(0), res->at(1)); |
494 | EXPECT_EQ(get_shape_cache_size(), 1); |
495 | } |
496 | } // namespace jit |
497 | } // namespace torch |
498 | |