#include <gtest/gtest.h>

#include <ATen/core/interned_strings.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/cuda.h>
#include <unordered_map>

namespace torch {
namespace jit {

namespace {

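// Walk the graph depth-first and return the first node of kind `k`; the
// assert fails the test if no such node exists.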
Node* findNode(std::shared_ptr<Graph>& g, Symbol k) {
  DepthFirstGraphNodeIterator graph_it(g);
  for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    if (node->kind() == k) {
      return node;
    }
  }
  TORCH_INTERNAL_ASSERT(false, "Couldn't find node");
}
} // namespace

TEST(ShapeAnalysisTest, DynamicShapesFusion) {
  // Test generalizing shapes to symbolic dimensions, guarding those symbolic
  // dimensions, and passing in runtime-computed symbolic dimensions via
  // inlined shape functions.
  std::shared_ptr<Graph> subgraph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x.1 : Tensor, %y.1 : Tensor, %z: Tensor):
        %11 : int = prim::Constant[value=0]()
        %3 : Tensor = aten::tanh(%x.1)
        %out1.1 : Tensor = aten::erf(%3)
        %out2.1 : Tensor = aten::relu(%y.1)
        %10 : Tensor[] = prim::ListConstruct(%out1.1, %out2.1)
        %25 : Tensor = aten::cat(%10, %11)
        %28 : Tensor = aten::hardswish(%25)
        %29 : Tensor = aten::mul(%28, %z)
        return (%28))IR";
  torch::jit::parseIR(graph_string, subgraph.get());

  /*
    set up fused TensorExprGroup
  */

  std::shared_ptr<Graph> g = std::make_shared<Graph>();
  auto x_inp = g->addInput("x_inp");
  auto y_inp = g->addInput("y_inp");
  auto z_inp = g->addInput("z_inp");
  auto x_type = TensorType::create(at::rand({10, 5}));
  auto y_type = TensorType::create(at::rand({4, 5}));
  auto z_type = TensorType::create(at::rand({1, 1}));
  x_inp->setType(x_type);
  y_inp->setType(y_type);
  z_inp->setType(z_type);
  subgraph->inputs().at(0)->setType(x_type);
  subgraph->inputs().at(1)->setType(y_type);
  subgraph->inputs().at(2)->setType(z_type);
  subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5})));
  auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
  output->node()->addInput(x_inp);
  output->node()->addInput(y_inp);
  output->node()->addInput(z_inp);
  output->node()->g_(attr::Subgraph, subgraph);

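  // GenerateGuard generalizes the group's tensor inputs to symbolic shapes,
  // wraps the TensorExprGroup in a prim::If guarded by a
  // prim::TensorExprDynamicGuard check, inlines the shape compute that feeds
  // the runtime symbolic dimension values into the group (note the aten::add
  // computing the cat dim), and puts an unfused fallback in the else branch.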
  auto success = GenerateGuard(output->node());
  TORCH_INTERNAL_ASSERT(success);
  testing::FileCheck()
      .check("TensorExprDynamicGuard")
      ->check_next("prim::If")
      ->check("aten::add")
      ->check("TensorExprGroup")
      ->check_same("symbolic_shape_inputs")
      ->check("block1")
      ->check("aten::cat")
      ->run(*g);

  // clang-format off
  /* Graph Should Look Something like: (note: strides not yet handled)
  graph(%x_inp : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
        %y_inp : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
        %z_inp : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
    %4 : bool = prim::TensorExprDynamicGuard[types=[Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)]](%x_inp, %y_inp, %z_inp)
    %5 : Tensor = prim::If(%4)
      block0():
        %15 : int[] = aten::size(%x_inp)
        %16 : int[] = aten::size(%y_inp)
        %17 : int = prim::Constant[value=1]()
        %18 : int = prim::Constant[value=0]()
        %elem.3 : int = aten::__getitem__(%15, %18) # <string>:40:10
        %elem.5 : int = aten::__getitem__(%15, %17) # <string>:40:10
        %elem.11 : int = aten::__getitem__(%16, %18) # <string>:40:10
        %cat_dim_size.48 : int = aten::add(%elem.3, %elem.11) # <string>:321:29
        %3 : Tensor = prim::TensorExprGroup_0[symbolic_shape_inputs=[-5, -4, -3, -2]](%x_inp, %y_inp, %z_inp, %cat_dim_size.48, %elem.11, %elem.5, %elem.3)
        -> (%3)
      block1():
        // FallbackGraph is inlined
        %14 : Tensor = prim::FallbackGraph_1(%x_inp, %y_inp, %z_inp)
        -> (%14)
    return ()
  with prim::TensorExprGroup_0 = graph(%x.1 : Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
        %y.1 : Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
        %z : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
        %SS_5 : int,
        %SS_4 : int,
        %SS_3 : int,
        %SS_2 : int):
    %3 : int = prim::Constant[value=0]()
    %4 : Tensor(SS(-2), SS(-3)) = aten::tanh(%x.1)
    %5 : Tensor(SS(-2), SS(-3)) = aten::erf(%4)
    %6 : Tensor(SS(-4), SS(-3)) = aten::relu(%y.1)
    %7 : Tensor[] = prim::ListConstruct(%5, %6)
    %8 : Tensor(SS(-5), SS(-3)) = aten::cat(%7, %3)
    %9 : Tensor(SS(-5), SS(-3)) = aten::hardswish(%8)
    %10 : Tensor(SS(-5), SS(-3)) = aten::mul(%9, %z)
    return (%9)
  */
  // clang-format on

  DepthFirstGraphNodeIterator graph_it(g);
  Node* te_group = findNode(g, prim::TensorExprGroup);

  /*
    Test that the inputs to the kernel - (10, 5), (4, 5), (1, 1) - are
    correctly generalized to sym dimensions, and that the output - (10 + 4, 5)
    - correctly preserves the non-catted dim as an existing sym shape and the
    catted dim as a new sym shape.
  */

  auto tensorexpr_graph = te_group->g(attr::Subgraph);
  auto inp1 = tensorexpr_graph->inputs().at(0)->type()->expect<TensorType>();
  auto inp2 = tensorexpr_graph->inputs().at(1)->type()->expect<TensorType>();
  auto inp3 = tensorexpr_graph->inputs().at(2)->type()->expect<TensorType>();
  auto out = tensorexpr_graph->outputs().at(0)->type()->expect<TensorType>();

  // Dims of size 1 are preserved as concrete sizes
  auto inp3_sizes = inp3->sizes().concrete_sizes();
  TORCH_INTERNAL_ASSERT(inp3_sizes);
  TORCH_INTERNAL_ASSERT(
      inp3_sizes->size() == 2 && inp3_sizes->at(0) == 1 &&
      inp3_sizes->at(1) == 1);

  // The shared size-5 dim is turned into a single sym shape
  ASSERT_EQ(
      inp1->symbolic_sizes()[1].value(), inp2->symbolic_sizes()[1].value());
  ASSERT_EQ(
      out->symbolic_sizes()[1].value(), inp2->symbolic_sizes()[1].value());

  // 4, 10, 14 are different sym shapes
  ASSERT_NE(
      inp1->symbolic_sizes()[0].value(), inp2->symbolic_sizes()[0].value());
  ASSERT_NE(
      out->symbolic_sizes()[0].value(), inp1->symbolic_sizes()[0].value());
  ASSERT_NE(
      out->symbolic_sizes()[0].value(), inp2->symbolic_sizes()[0].value());

  /*
    Test that the guard behaves correctly at runtime and that symbolic shapes
    are computed correctly. Since we don't have TE kernel support for dynamic
    shapes yet, on guard success we return all of the computed runtime
    symbolic dimensions as outputs of the graph, and on guard failure we
    return None.
  */

  // Setting up guard to return sym shapes on guard success and None on failure
  Node* if_node = findNode(g, prim::If);
  IfView if_v(if_node);
  if_node->eraseOutput(0);
  if_v.thenBlock()->eraseOutput(0);
  if_v.elseBlock()->eraseOutput(0);
  WithInsertPoint guard(if_node);
  auto none_val = g->insertConstant(IValue());

  auto sym_shapes = te_group->is(Symbol::attr("symbolic_shape_inputs"));
  auto offset = te_group->inputs().size() - sym_shapes.size();
  for (size_t i = 0; i < sym_shapes.size(); ++i) {
    if_v.thenBlock()->insertOutput(i, te_group->inputs().at(offset + i));
    if_v.elseBlock()->insertOutput(i, none_val);
    if_node->insertOutput(i)->setType(OptionalType::create(IntType::get()));
  }

  auto new_outputs = g->createTuple(if_node->outputs())->insertAfter(if_node);

  g->registerOutput(new_outputs->output());
  te_group->destroy();
  findNode(g, prim::FallbackGraph)->destroy();
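  // The graph now returns a tuple of Optional[int]: the runtime-computed
  // symbolic dimension values when the guard passes, or Nones when it fails.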

  // Testing bad inputs

  auto first_inp = at::rand({2, 5});
  std::vector<std::vector<at::Tensor>> second_inps = {
      {at::rand({3, 4}), at::rand({1, 1})}, // sym shape mismatch
      {at::rand({5, 2}).transpose(0, 1), at::rand({1, 1})}, // discontiguous
      {at::zeros({2, 5}).to(at::ScalarType::Int),
       at::rand({1, 1})}, // wrong dtype
      {at::rand({2, 5, 1}), at::rand({1, 1})}, // wrong # dims
      {at::rand({2, 5}).requires_grad_(true),
       at::rand({1, 1})}, // requires grad
      {at::rand({2, 5}), at::rand({1, 12})}, // concrete dim mismatch (1)
  };
  if (torch::cuda::is_available()) {
    second_inps.push_back({at::rand({2, 5}).cuda(), at::rand({1, 1})});
  }
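  // Each of these second/third inputs should fail the guard, which checks
  // rank, dtype, device, contiguity, requires_grad, and any pinned concrete
  // dims in addition to the symbolic dimension constraints.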
  for (const auto& last_inps : second_inps) {
    // todo - reusing the interpreter across iterations gave an error
    Code code(g, "");
    InterpreterState interp(code);
    auto stack = createStack({first_inp, last_inps[0], last_inps[1]});
    interp.run(stack);
    TORCH_INTERNAL_ASSERT(pop(stack).toTuple()->elements().at(0).isNone());
  }

  // Test good inputs
  Code code(g, "");
  InterpreterState interp(code);
  std::vector<at::Tensor> inps = {
      at::rand({2, 5}), at::rand({4, 5}), at::rand({1, 1})};
  Stack stack(inps.begin(), inps.end());
  interp.run(stack);
  auto tuple = pop(stack).toTuple();
  TORCH_INTERNAL_ASSERT(tuple->elements().at(0).isInt());

  // Testing that the sym shape calculation was correct
  for (size_t i = 0; i < sym_shapes.size(); ++i) {
    auto sym_shape = sym_shapes[i];
    auto computed_value = tuple->elements().at(i).toInt();
    if (sym_shape == inp1->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 2);
    } else if (sym_shape == inp1->symbolic_sizes().at(1).value()) {
      ASSERT_EQ(computed_value, 5);
    } else if (sym_shape == inp2->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 4);
    } else if (sym_shape == out->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 6);
    } else {
      TORCH_INTERNAL_ASSERT(false);
    }
  }
}

TEST(ShapeAnalysisTest, MovingConstantOutOfFusionGroups) {
  std::shared_ptr<Graph> subgraph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x.1 : Tensor):
        %none : NoneType = prim::Constant()
        %size1 : int = prim::Constant[value=1]()
        %size10 : int = prim::Constant[value=10]()
        %sizes : int[] = prim::ListConstruct(%size10, %size1)
        %device : Device = prim::Constant[value="cpu"]()
        %10 : Tensor = aten::ones(%sizes, %none, %none, %device, %none)
        %3 : Tensor = aten::tanh(%x.1)
        %29 : Tensor = aten::mul(%3, %10)
        return (%29))IR";
  torch::jit::parseIR(graph_string, subgraph.get());
  ConstantPropagation(subgraph);

  std::shared_ptr<Graph> g = std::make_shared<Graph>();
  auto x_inp = g->addInput("x_inp");
  auto x_type = TensorType::create(at::rand({10, 5}));
  x_inp->setType(x_type);
  subgraph->inputs().at(0)->setType(x_type);
  subgraph->outputs().at(0)->setType(x_type);
  auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
  output->node()->addInput(x_inp);
  output->node()->g_(attr::Subgraph, subgraph);

  auto success = GenerateGuard(output->node());
  TORCH_INTERNAL_ASSERT(success);

  // Check that the constants have been moved out of the fused graph. This
  // should result in there being no conditionals other than the one checking
  // the result of TensorExprDynamicGuard.
  testing::FileCheck()
      .check("TensorExprDynamicGuard")
      ->check_next("prim::If")
      ->check_not("prim::If") // no other IFs due to constants.
      ->check("TensorExprGroup")
      ->check("block1")
      ->check("FallbackGraph")
      ->run(*g);
}

namespace {

c10::optional<int64_t> sym_dim = c10::nullopt;

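// Compare shapes up to renaming of symbolic dimensions:
// CanonicalizedSymbolicShape maps each distinct symbol to a stable index, so
// two shapes with the same structure compare equal even if their symbols
// differ.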
// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
void assertShapeEqual(c10::SymbolicShape& a, c10::SymbolicShape& e) {
  auto a_canonical = CanonicalizedSymbolicShape(a);
  auto e_canonical = CanonicalizedSymbolicShape(e);
  EXPECT_EQ(a_canonical, e_canonical);
}

void assertShapeEqual(
    c10::optional<std::vector<c10::SymbolicShape>>& actual,
    std::vector<c10::optional<int64_t>> expected) {
  ASSERT_TRUE(actual.has_value());
  ASSERT_EQ(actual->size(), 1);

  auto symb_expected = c10::SymbolicShape(expected);
  assertShapeEqual(actual->at(0), symb_expected);
}

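// Look up an operator by its schema literal and return a pointer to its
// FunctionSchema.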
const FunctionSchema* getSchema(const char* name) {
  return &(getOperatorForLiteral(name)->schema());
}
} // namespace

TEST(ShapeAnalysisTest, SymbolicShapeAPI) {
  // Fetch the schema of the operator we want to run shape inference on.
  auto schema = getSchema(
      "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");

  c10::IValue const_size_1 = std::vector<int64_t>{64, 56, 56};
  c10::IValue const_size_2 = std::vector<int64_t>{1, 56, 56};

  // Check vector initializer list syntax
  c10::SymbolicShape ss_concrete =
      std::vector<c10::optional<int64_t>>{1, 56, 56};
  c10::SymbolicShape ss1 = std::vector<c10::optional<int64_t>>{sym_dim, 56, 56};
  c10::SymbolicShape ss2 =
      std::vector<c10::optional<int64_t>>{64, sym_dim, sym_dim};
  c10::SymbolicShape ss3 =
      std::vector<c10::optional<int64_t>>{sym_dim, sym_dim, sym_dim, sym_dim};
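
  // aten::sub.Tensor broadcasts its two tensor arguments, so each expected
  // shape below is the broadcast of the argument shapes: a result dim is
  // concrete wherever either argument provides a concrete size greater than
  // one, and stays symbolic otherwise.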

  auto res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_1, const_size_1});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_1, const_size_2});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_1, ss1});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_2, ss1});
  assertShapeEqual(res, {sym_dim, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{ss_concrete, ss2});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{ss2, ss3});
  assertShapeEqual(res, {sym_dim, 64, sym_dim, sym_dim});
}

TEST(ShapeAnalysisTest, BoundedSymbolicShapes) {
  auto schema = getSchema("aten::nonzero(Tensor self) -> (Tensor)");

  // Test that we generate symbolic shapes for the output of a nonzero op
  c10::IValue const_size_1 = std::vector<int64_t>{5, 10};
  auto res =
      calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_1});
  assertShapeEqual(res, {sym_dim, 2});

  // Test that nonzero can also create concrete shapes
  c10::IValue const_size_2 = std::vector<int64_t>({1, 0});
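  // A {1, 0} input has zero elements, so nonzero must return exactly zero
  // rows and the output shape is fully concrete.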
  res =
      calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_2});
  assertShapeEqual(res, {0, 2});
}

TEST(ShapeAnalysisTest, SymbolicShapeCaching) {
  clear_shape_cache();
  auto schema = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");

  c10::IValue const_size_1 = std::vector<int64_t>{64, 56};
  c10::IValue const_size_2 = std::vector<int64_t>{64, 56};
  c10::IValue const_size_3 = std::vector<int64_t>{64, 20};

  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss3 = c10::SymbolicShape({sym_dim, sym_dim});
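
  // ss1 and ss2 describe the same structure, (sym_dim, 64), but carry
  // distinct symbols; ss3 is fully symbolic. The cache is keyed on
  // canonicalized shapes, so ss1 and ss2 should hit the same cache entry.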

  auto res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
  assertShapeEqual(res, {sym_dim, 56});
  auto res1_val = res->at(0);

  // The exact same arguments should return the exact same result
  res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
  auto res2_val = res->at(0);
  EXPECT_EQ(res1_val, res2_val);
  EXPECT_EQ(get_shape_cache_size(), 1);

  // The same shape with different symbols should return the same shape
  // but different symbolic indices
  res = calculateSymbolicShapesOnOp(schema, {ss2, const_size_2});
  auto res3_val = res->at(0);

  assertShapeEqual(res3_val, res2_val);
  EXPECT_NE(res3_val, res2_val);
  EXPECT_EQ(get_shape_cache_size(), 1);

  // Different concrete shape should be cached separately
  res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_3});
  assertShapeEqual(res, {sym_dim, 20});
  EXPECT_EQ(get_shape_cache_size(), 2);

  res = calculateSymbolicShapesOnOp(schema, {ss3, const_size_3});
  assertShapeEqual(res, {sym_dim, 20});
  EXPECT_EQ(get_shape_cache_size(), 3);

  res = calculateSymbolicShapesOnOp(schema, {ss3, ss3});
  assertShapeEqual(res, {sym_dim, sym_dim});
  EXPECT_EQ(get_shape_cache_size(), 4);
}

TEST(ShapeAnalysisTest, ShapeCacheMultipleFns) {
  clear_shape_cache();

  auto squeeze_op =
      getSchema("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)");
  auto mul_tensor =
      getSchema("aten::mul.Tensor(Tensor self, Tensor other) -> Tensor");
  auto mul_scalar =
      getSchema("aten::mul.Scalar(Tensor self, Scalar other) -> Tensor");
  auto div_tensor =
      getSchema("aten::div.Tensor(Tensor self, Tensor other) -> Tensor");
  auto matmul = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");

  c10::IValue const_int = 1;

  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});

  auto res = calculateSymbolicShapesOnOp(squeeze_op, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});

  // Show that cache can handle multiple functions
  res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 2);

  res = calculateSymbolicShapesOnOp(mul_tensor, {ss1, ss1});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 3);

  // Even when the expected outcome is the same, should not collide
  res = calculateSymbolicShapesOnOp(div_tensor, {ss1, ss1});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 4);

  // Don't lose cached objects
  res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 4);

  res = calculateSymbolicShapesOnOp(matmul, {ss1, ss1});
  // SSA can infer that sym_dim is 64 as both tensors
  // use the same sym_dim
  assertShapeEqual(res, {64, 64});
  EXPECT_EQ(get_shape_cache_size(), 5);
}

TEST(ShapeAnalysisTest, TestShapeMultipleReturns) {
  clear_shape_cache();

  auto max_dim_op = getSchema(
      "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)");
  c10::IValue const_int = 1;
  c10::IValue false_ival = false;

  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
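
  // Reducing a (sym_dim, 64) input over dim 1 with keepdim=False yields a
  // one-dimensional (sym_dim,) shape for both the values and indices outputs.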

  auto res =
      calculateSymbolicShapesOnOp(max_dim_op, {ss1, const_int, false_ival});
  c10::SymbolicShape expected_res = c10::SymbolicShape({sym_dim});
  assertShapeEqual(res->at(0), expected_res);
  // res0 and res1 should share the same symbolic symbol
  EXPECT_EQ(res->at(0), res->at(1));

  // Also test that the shape cache returns consistent result shapes
  res = calculateSymbolicShapesOnOp(max_dim_op, {ss2, const_int, false_ival});
  assertShapeEqual(res->at(0), expected_res);
  EXPECT_EQ(res->at(0), res->at(1));
  EXPECT_EQ(get_shape_cache_size(), 1);
}
} // namespace jit
} // namespace torch