1#include <gtest/gtest.h>
2#include <test/cpp/api/support.h>
3#include <torch/script.h>
4
5using namespace torch::autograd;
6using namespace torch::test;
7
8namespace {
9torch::Tensor functional_op(torch::Tensor& x) {
10 return x * x;
11}
12
13void inplace_op(torch::Tensor& x) {
14 x.mul_(1);
15}
16
17torch::Tensor view_op(torch::Tensor& x) {
18 return x.view({2, 3});
19}
20
21/*
22 Only the following combos of Autograd & ADInplaceOrView keys on tensors are
23 valid:
24 - Autograd=true, ADInplaceOrView=true (normal tensor)
25 - Autograd=false, ADInplaceOrView=false (inference tensor)
26 Tensors created in InferenceMode are mostly inference tensors. The only
27 exception is that view of normal tensors created in InferenceMode still
28 produce normal tensor.
29*/
30void assert_TLS_states(bool inference_mode) {
31 ASSERT_EQ(InferenceMode::is_enabled(), inference_mode);
32 ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(
33 c10::DispatchKey::ADInplaceOrView));
34 ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(
35 c10::autograd_dispatch_keyset));
36 ASSERT_EQ(
37 c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset),
38 inference_mode);
39 ASSERT_EQ(
40 c10::impl::tls_is_dispatch_key_included(
41 c10::DispatchKey::ADInplaceOrView),
42 !inference_mode);
43 ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
44}
45} // namespace
46
47TEST(InferenceModeTest, TestTLSState) {
48 assert_TLS_states(false);
49 {
50 InferenceMode guard;
51 assert_TLS_states(true);
52 {
53 InferenceMode guard(false);
54 assert_TLS_states(false);
55 }
56 assert_TLS_states(true);
57 }
58 assert_TLS_states(false);
59}
60
61TEST(InferenceModeTest, TestInferenceTensorCreation) {
62 {
63 InferenceMode guard;
64 // New tensor created through constructors are inference tensors.
65 torch::Tensor c = torch::ones({1, 2, 3});
66 ASSERT_FALSE(c.requires_grad());
67 ASSERT_TRUE(c.is_inference());
68
69 // requires_grad doesn't change inference tensor behavior inside
70 // InferenceMode.
71 torch::Tensor tmp = torch::ones({1, 2, 3}).set_requires_grad(true);
72 ASSERT_TRUE(tmp.requires_grad());
73 ASSERT_TRUE(tmp.is_inference());
74
75 tmp = torch::ones({1, 2, 3}).set_requires_grad(false);
76 ASSERT_FALSE(tmp.requires_grad());
77 ASSERT_TRUE(tmp.is_inference());
78 }
79}
80
81TEST(InferenceModeTest, TestExistingAutogradSession) {
82 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
83 torch::Tensor a = s.clone();
84
85 // Save `a` in an existing autograd session
86 torch::Tensor out = a * a;
87 {
88 InferenceMode guard;
89 inplace_op(a);
90 }
91 // Performing backward should trigger error since `a`'s version has been
92 // bumped.
93 ASSERT_THROWS_WITH(
94 out.backward(torch::ones_like(out)),
95 "one of the variables needed for gradient computation has been modified by an inplace operation")
96}
97
98TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) {
99 c10::InferenceMode guard;
100 for (bool requires_grad : {true, false}) {
101 torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
102
103 torch::Tensor func_out = functional_op(c); // go through kernels: CPU
104 ASSERT_TRUE(func_out.is_inference());
105 ASSERT_FALSE(func_out.requires_grad());
106 }
107}
108
109TEST(InferenceModeTest, TestInferenceTensorInInferenceModeInplaceOp) {
110 c10::InferenceMode guard;
111 for (bool requires_grad : {true, false}) {
112 torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
113
114 inplace_op(c); // go through kernels: CPU
115 ASSERT_TRUE(c.is_inference());
116 ASSERT_EQ(c.requires_grad(), requires_grad);
117 }
118}
119
120TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) {
121 c10::InferenceMode guard;
122 for (bool requires_grad : {true, false}) {
123 torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
124
125 torch::Tensor view_out = view_op(c); // go through kernels: CPU
126 ASSERT_TRUE(view_out.is_inference());
127 // Note this is different from NoGradMode but makes sense.
128 ASSERT_FALSE(view_out.requires_grad());
129 ASSERT_FALSE(view_out.is_view());
130 }
131}
132
133TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) {
134 torch::Tensor inference_tensor;
135 for (bool requires_grad : {true, false}) {
136 {
137 InferenceMode guard;
138 inference_tensor =
139 torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
140 }
141
142 // Due to issue #54614, this might run slower compared to InferenceMode
143 // since intermediate tensors are normal tensors, and they might dispatch to
144 // VariableType kernels. This is fine since users can easily fix it by
145 // moving it inside InferenceMode block.
146 torch::Tensor tmp =
147 functional_op(inference_tensor); // go through kernels:
148 // ADInplaceOrView(fallthrough), CPU
149 ASSERT_FALSE(tmp.is_inference());
150 ASSERT_FALSE(tmp.requires_grad());
151 }
152}
153
154TEST(InferenceModeTest, TestInferenceTensorInNormalModeInplaceOp) {
155 torch::Tensor inference_tensor;
156 for (bool requires_grad : {true, false}) {
157 {
158 InferenceMode guard;
159 inference_tensor =
160 torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
161 }
162 ASSERT_THROWS_WITH(
163 inplace_op(
164 inference_tensor), // go through kernels: ADInplaceOrView, CPU
165 "Inplace update to inference tensor outside InferenceMode is not allowed");
166 }
167}
168
169TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) {
170 torch::Tensor inference_tensor;
171 for (bool requires_grad : {true, false}) {
172 {
173 InferenceMode guard;
174 inference_tensor =
175 torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
176 }
177 torch::Tensor out =
178 view_op(inference_tensor); // go through kernels: ADInplaceOrView, CPU
179 ASSERT_TRUE(out.is_inference());
180 ASSERT_FALSE(out.requires_grad());
181 ASSERT_FALSE(out.is_view());
182 ASSERT_TRUE(out.is_leaf());
183 }
184}
185
186TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) {
187 for (bool requires_grad : {true, false}) {
188 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
189 torch::Tensor a = s.clone();
190
191 {
192 c10::InferenceMode guard;
193
194 inplace_op(a); // go through kernels: ADInplaceOrView, CPU
195 ASSERT_FALSE(a.is_inference());
196 ASSERT_EQ(a.requires_grad(), requires_grad);
197
198 // inplace -> inplace
199 inplace_op(a); // go through kernels: ADInplaceOrView, CPU
200 ASSERT_FALSE(a.is_inference());
201 ASSERT_EQ(a.requires_grad(), requires_grad);
202
203 // inplace -> inplace -> view
204 torch::Tensor view_out =
205 view_op(a); // go through kernels: ADInplaceOrView, CPU
206 ASSERT_FALSE(view_out.is_inference());
207 ASSERT_EQ(view_out.requires_grad(), requires_grad);
208 }
209 }
210}
211
212TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) {
213 for (bool requires_grad : {true, false}) {
214 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
215 torch::Tensor a = s.clone();
216
217 {
218 c10::InferenceMode guard;
219
220 inplace_op(a); // go through kernels: ADInplaceOrView, CPU
221 ASSERT_FALSE(a.is_inference());
222 ASSERT_EQ(a.requires_grad(), requires_grad);
223 }
224
225 torch::Tensor tmp = functional_op(a); // go through kernels: VariableType,
226 // ADInplaceOrView(fallthrough), CPU
227 ASSERT_FALSE(tmp.is_inference());
228 ASSERT_EQ(tmp.requires_grad(), requires_grad);
229
230 inplace_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
231 ASSERT_FALSE(a.is_inference());
232 ASSERT_EQ(a.requires_grad(), requires_grad);
233
234 tmp = view_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
235 ASSERT_FALSE(tmp.is_inference());
236 ASSERT_EQ(tmp.requires_grad(), requires_grad);
237 }
238}
239
240TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
241 for (bool requires_grad : {true, false}) {
242 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
243 torch::Tensor a = s.clone();
244 torch::Tensor view_out, tmp;
245
246 {
247 c10::InferenceMode guard;
248 // View ops on normal tensor produce normal tensors as output.
249 // - For view ops it has both dispatch keys since due to the way we create
250 // view Tensors in alias_with_sizes_and_strides:
251 // ```
252 // auto impl = c10::make_intrusive<TensorImpl>(
253 // Storage(self.storage()), self.key_set(), self.dtype());
254 // ```
255 // In addition, these view output tensors are normal in the sense they
256 // have both Autograd and ADInplaceOrView keys. But they're still
257 // special since they'll have CreationMeta::INFERENCE_MODE. In other
258 // words they behave exactly the same as a view tensor created in
259 // no_grad mode.
260
261 view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
262 ASSERT_FALSE(view_out.is_inference());
263 assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
264 ASSERT_EQ(view_out.requires_grad(), requires_grad);
265 ASSERT_TRUE(view_out.is_leaf());
266
267 // view -> view
268 tmp = view_op(view_out); // go through kernels: ADInplaceOrView, CPU
269 ASSERT_FALSE(tmp.is_inference());
270 assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
271 ASSERT_EQ(tmp.requires_grad(), requires_grad);
272 ASSERT_TRUE(tmp.is_leaf());
273
274 // view -> view -> inplace
275 inplace_op(tmp); // kernels: ADInplaceOrView, CPU
276 assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
277 ASSERT_FALSE(tmp.is_inference());
278 ASSERT_EQ(tmp.requires_grad(), requires_grad);
279 ASSERT_TRUE(tmp.is_leaf());
280 ASSERT_EQ(a._version(), tmp._version());
281 }
282 }
283}
284
285TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
286 for (bool requires_grad : {true, false}) {
287 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
288 torch::Tensor a = s.clone();
289 torch::Tensor view_out, tmp;
290
291 {
292 c10::InferenceMode guard;
293 view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
294 ASSERT_FALSE(view_out.is_inference());
295 assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
296 ASSERT_EQ(view_out.requires_grad(), requires_grad);
297 ASSERT_TRUE(view_out.is_leaf());
298 }
299
300 tmp = functional_op(view_out);
301 ASSERT_FALSE(view_out.is_inference());
302 ASSERT_EQ(tmp.requires_grad(), requires_grad);
303
304 if (requires_grad) {
305 ASSERT_THROWS_WITH(
306 inplace_op(view_out), // go through kernels: VariableType,
307 // ADInplaceOrView, CPU
308 "A view was created in inference mode and is being modified inplace")
309 } else {
310 inplace_op(view_out);
311 }
312
313 tmp = view_op(view_out);
314 ASSERT_FALSE(view_out.is_inference());
315 ASSERT_EQ(tmp.requires_grad(), requires_grad);
316 }
317}
318
319TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
320 for (bool requires_grad : {true, false}) {
321 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
322 torch::Tensor c;
323 {
324 InferenceMode guard;
325 c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
326 }
327
328 // add(Tensor, Tensor) is safe with inference tensor since it doesn't save
329 // any variable for backward.
330 torch::Tensor out = c.add(s); // go through kernels: VariableType,
331 // ADInplaceOrView(fallthrough), CPU
332 ASSERT_FALSE(out.is_inference());
333 ASSERT_EQ(out.requires_grad(), requires_grad);
334 if (requires_grad) {
335 // leaf inference tensor with requires_grad=true can still have gradient.
336 // Note this behavior is different from NoGradMode which has empty grad.
337 out.backward(torch::ones_like(out));
338 assert_tensor_equal(c.grad(), torch::ones_like(c));
339 }
340
341 if (requires_grad) {
342 // mul(self, other) saves variable when requires_grad=true
343 ASSERT_THROWS_WITH(
344 c.mul(s), "Inference tensors cannot be saved for backward.");
345
346 // Inference tensor in TensorList input
347 // stack does not capture anymore, so disabled
348 // TODO: find alternative Function that captures a list (maybe custom fn)
349 /*
350 std::vector<torch::Tensor> inputs = {s, c};
351 ASSERT_THROWS_WITH(
352 torch::stack(inputs), // go through kernels: VariableType(ERROR)!,
353 // ADInplaceOrView(fallthrough), CPU
354 "Inference tensors cannot be saved for backward.")
355 */
356 }
357 }
358}
359
360TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) {
361 for (bool requires_grad : {true, false}) {
362 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
363 torch::Tensor a = s.clone();
364 torch::Tensor c;
365 {
366 InferenceMode guard;
367 c = torch::ones({1, 2, 3});
368 }
369
370 if (requires_grad) {
371 ASSERT_THROWS_WITH(
372 a.mul_(c), // go through kernels: VariableType(ERROR!), InferenceMode,
373 // CPU
374 "Inference tensors cannot be saved for backward.");
375
376 ASSERT_THROWS_WITH(
377 torch::mul_out(
378 /*out=*/c, s, s), // go through kernels: VariableType(ERROR!),
379 // ADInplaceOrView, CPU
380 "out=... arguments don't support automatic differentiation, but one of the arguments requires grad")
381 } else {
382 a.mul_(c);
383
384 ASSERT_THROWS_WITH(
385 torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType,
386 // ADInplaceOrView(ERROR!), CPU
387 "Inplace update to inference tensor outside InferenceMode is not allowed");
388 }
389 }
390}
391
392TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) {
393 for (bool requires_grad : {true, false}) {
394 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
395 torch::Tensor c;
396 {
397 InferenceMode guard;
398 c = torch::ones({1, 2, 3});
399 }
400
401 // view_as is a composite op which calls view() with only one tensor
402 // argument. So there isn't a mixed inference tensor and normal tensor
403 // inputs for view ops.
404 torch::Tensor tmp1 =
405 c.view_as(s); // go through kernels: ADInplaceOrView, CPU
406 ASSERT_TRUE(tmp1.is_inference());
407 ASSERT_FALSE(tmp1.requires_grad());
408
409 // This is fine since it's equivalent as s.view(c.sizes()) which
410 // isn't a mixed input scenario.
411 torch::Tensor tmp2 =
412 s.view_as(c); // go through kernels: VariableType, ADInplaceOrView, CPU
413 ASSERT_FALSE(tmp2.is_inference());
414 ASSERT_EQ(tmp2.requires_grad(), requires_grad);
415 }
416}
417
418TEST(InferenceModeTest, TestHandleDirectViewOnRebase) {
419 for (bool requires_grad : {true, false}) {
420 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
421 torch::Tensor a = s.clone();
422 torch::Tensor view_out;
423 {
424 InferenceMode guard;
425 view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
426 }
427 if (requires_grad) {
428 ASSERT_THROWS_WITH(
429 inplace_op(view_out),
430 "A view was created in inference mode and is being modified inplace")
431 } else {
432 inplace_op(view_out);
433 }
434 }
435}
436
437TEST(InferenceModeTest, TestHandleInDirectViewOnRebase) {
438 for (bool requires_grad : {true, false}) {
439 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
440 torch::Tensor a = s.clone();
441 torch::Tensor view_out;
442 {
443 InferenceMode guard;
444 view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
445 }
446 inplace_op(a);
447 if (requires_grad) {
448 ASSERT_THROWS_WITH(
449 view_out.grad_fn(),
450 "A view was created in inference mode and its base or another view of its base has been modified inplace");
451 } else {
452 view_out.grad_fn();
453 }
454 }
455}
456
457TEST(InferenceModeTest, TestCreationMetaPropagation) {
458 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
459 torch::Tensor b, c;
460 {
461 InferenceMode guard;
462 b = s.view_as(s);
463 }
464 ASSERT_THROWS_WITH(
465 b.add_(1),
466 "A view was created in inference mode and is being modified inplace");
467 {
468 AutoGradMode mode(false);
469 c = b.view_as(b);
470 }
471 ASSERT_THROWS_WITH(
472 c.add_(1),
473 "A view was created in inference mode and is being modified inplace");
474}
475
476TEST(InferenceModeTest, TestCreationMetaPropagationInput) {
477 torch::Tensor s = torch::ones({2, 2, 3}).set_requires_grad(true);
478 auto s_view = s.view_as(s);
479 std::vector<at::Tensor> b, c;
480 {
481 InferenceMode guard;
482 b = s_view.split_with_sizes({1, 1});
483
484 s = s.view_as(s);
485 c = s.split_with_sizes({1, 1});
486 }
487 for (auto& b_el : b) {
488 assert_tensor_creation_meta(b_el, CreationMeta::INFERENCE_MODE);
489 ASSERT_THROWS_WITH(
490 b_el.add_(1),
491 "A view was created in inference mode and is being modified inplace");
492 }
493 for (auto& c_el : c) {
494 assert_tensor_creation_meta(c_el, CreationMeta::INFERENCE_MODE);
495 ASSERT_THROWS_WITH(
496 c_el.add_(1),
497 "A view was created in inference mode and is being modified inplace");
498 }
499}
500
501TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) {
502 for (bool requires_grad : {true, false}) {
503 torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
504 torch::Tensor t;
505 {
506 InferenceMode guard;
507 t = torch::ones({1, 2, 3});
508 t.copy_(s);
509 ASSERT_TRUE(t.is_inference());
510 ASSERT_FALSE(t.requires_grad());
511 }
512
513 ASSERT_THROWS_WITH(
514 t.copy_(s),
515 "Inplace update to inference tensor outside InferenceMode is not allowed");
516 }
517}
518
519TEST(InferenceModeTest, TestSetRequiresGradInNormalMode) {
520 torch::Tensor t;
521 {
522 InferenceMode guard;
523 t = torch::ones({1, 2, 3});
524 }
525 t.set_requires_grad(false);
526 ASSERT_THROWS_WITH(
527 t.set_requires_grad(true),
528 "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
529}
530
531TEST(InferenceModeTest, TestAccessVersionCounter) {
532 torch::Tensor t;
533 {
534 InferenceMode guard;
535 t = torch::ones({1, 2, 3});
536 ASSERT_THROWS_WITH(
537 t.unsafeGetTensorImpl()->version_counter().current_version(),
538 "Inference tensors do not track version counter.");
539 t.unsafeGetTensorImpl()->bump_version();
540 }
541 ASSERT_THROWS_WITH(
542 t.unsafeGetTensorImpl()->version_counter().current_version(),
543 "Inference tensors do not track version counter.");
544 ASSERT_THROWS_WITH(
545 t.unsafeGetTensorImpl()->bump_version(),
546 "Inplace update to inference tensor outside InferenceMode is not allowed.");
547 // Suggested workaround
548 torch::Tensor c = t.clone();
549 uint32_t v = c.unsafeGetTensorImpl()->version_counter().current_version();
550 c.unsafeGetTensorImpl()->bump_version();
551 ASSERT_EQ(
552 c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1);
553}
554
555TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) {
556 torch::Tensor s = torch::ones({1, 2, 3});
557 torch::Tensor t;
558 {
559 InferenceMode guard;
560 t = torch::ones({1, 2, 3});
561 // Testing both copy_ from VariableTypeManual and add_ from generated code.
562 s.copy_(t);
563 s.add_(t);
564 t.add_(s);
565 t.copy_(s);
566 }
567 s.copy_(t);
568 s.add_(t);
569 ASSERT_THROWS_WITH(
570 t.copy_(s),
571 "Inplace update to inference tensor outside InferenceMode is not allowed");
572
573 ASSERT_THROWS_WITH(
574 t.add_(s),
575 "Inplace update to inference tensor outside InferenceMode is not allowed");
576}
577
578TEST(InferenceModeTest, TestComplexViewInInferenceMode) {
579 torch::Tensor s = torch::ones({3, 3, 2});
580 torch::Tensor t = torch::view_as_complex(s);
581 {
582 InferenceMode guard;
583 torch::Tensor tmp;
584
585 tmp = torch::view_as_real(t);
586 ASSERT_FALSE(tmp.is_inference());
587 tmp = torch::view_as_complex(s);
588 ASSERT_FALSE(tmp.is_inference());
589
590 torch::Tensor e = torch::ones({3, 3, 2});
591 tmp = torch::view_as_complex(e);
592 ASSERT_TRUE(tmp.is_inference());
593 tmp = torch::view_as_real(tmp);
594 ASSERT_TRUE(tmp.is_inference());
595 }
596}
597
598TEST(InferenceModeTest, TestComplexViewInNormalMode) {
599 torch::Tensor s;
600 {
601 InferenceMode guard;
602 s = torch::ones({3, 3, 2});
603 }
604 torch::Tensor tmp = torch::view_as_complex(s);
605 ASSERT_TRUE(tmp.is_inference());
606 tmp = torch::view_as_real(tmp);
607 ASSERT_TRUE(tmp.is_inference());
608}
609
610TEST(InferenceModeTest, TestCustomFunction) {
611 struct MyFunction : public Function<MyFunction> {
612 static Variable forward(
613 AutogradContext* ctx,
614 Variable var1,
615 int mul,
616 Variable var2) {
617 ctx->saved_data["mul"] = mul;
618 ctx->save_for_backward({var1, var2});
619 return var1 + mul * var2 + var1 * var2;
620 }
621
622 static variable_list backward(
623 AutogradContext* ctx,
624 variable_list grad_output) {
625 int mul = ctx->saved_data["mul"].toInt();
626 auto saved = ctx->get_saved_variables();
627 auto var1 = saved[0];
628 auto var2 = saved[1];
629 variable_list output = {
630 grad_output[0] + grad_output[0] * var2,
631 Variable(),
632 grad_output[0] * mul + grad_output[0] * var1};
633 return output;
634 }
635 };
636
637 {
638 InferenceMode guard;
639 torch::Tensor var1 = torch::ones({3, 3}).set_requires_grad(true);
640 auto var2 = var1.clone();
641 int mul = 2;
642 // If InferenceMode didn't set NoGradGuard automatically, this line
643 // would error out when trying to save `var1` and `var2` for backward.
644 auto y = MyFunction::apply(var1, mul, var2);
645 torch::Tensor expected = var1 + mul * var2 + var1 * var2;
646 assert_tensor_equal(y, expected);
647 }
648}
649
650TEST(InferenceModeTest, TestLegacyAutoNonVariableTypeModeWarning) {
651 c10::WarningUtils::WarnAlways warn_always(true);
652 WarningCapture warnings;
653 at::AutoNonVariableTypeMode guard;
654 ASSERT_TRUE(
655 warnings.str().find("AutoNonVariableTypeMode is deprecated") !=
656 std::string::npos);
657}
658