1 | #include <gtest/gtest.h> |
2 | #include <test/cpp/api/support.h> |
3 | #include <torch/script.h> |
4 | |
5 | using namespace torch::autograd; |
6 | using namespace torch::test; |
7 | |
8 | namespace { |
9 | torch::Tensor functional_op(torch::Tensor& x) { |
10 | return x * x; |
11 | } |
12 | |
13 | void inplace_op(torch::Tensor& x) { |
14 | x.mul_(1); |
15 | } |
16 | |
17 | torch::Tensor view_op(torch::Tensor& x) { |
18 | return x.view({2, 3}); |
19 | } |
20 | |
21 | /* |
22 | Only the following combos of Autograd & ADInplaceOrView keys on tensors are |
23 | valid: |
24 | - Autograd=true, ADInplaceOrView=true (normal tensor) |
25 | - Autograd=false, ADInplaceOrView=false (inference tensor) |
26 | Tensors created in InferenceMode are mostly inference tensors. The only |
27 | exception is that view of normal tensors created in InferenceMode still |
28 | produce normal tensor. |
29 | */ |
30 | void assert_TLS_states(bool inference_mode) { |
31 | ASSERT_EQ(InferenceMode::is_enabled(), inference_mode); |
32 | ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded( |
33 | c10::DispatchKey::ADInplaceOrView)); |
34 | ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included( |
35 | c10::autograd_dispatch_keyset)); |
36 | ASSERT_EQ( |
37 | c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset), |
38 | inference_mode); |
39 | ASSERT_EQ( |
40 | c10::impl::tls_is_dispatch_key_included( |
41 | c10::DispatchKey::ADInplaceOrView), |
42 | !inference_mode); |
43 | ASSERT_EQ(GradMode::is_enabled(), !inference_mode); |
44 | } |
45 | } // namespace |
46 | |
47 | TEST(InferenceModeTest, TestTLSState) { |
48 | assert_TLS_states(false); |
49 | { |
50 | InferenceMode guard; |
51 | assert_TLS_states(true); |
52 | { |
53 | InferenceMode guard(false); |
54 | assert_TLS_states(false); |
55 | } |
56 | assert_TLS_states(true); |
57 | } |
58 | assert_TLS_states(false); |
59 | } |
60 | |
61 | TEST(InferenceModeTest, TestInferenceTensorCreation) { |
62 | { |
63 | InferenceMode guard; |
64 | // New tensor created through constructors are inference tensors. |
65 | torch::Tensor c = torch::ones({1, 2, 3}); |
66 | ASSERT_FALSE(c.requires_grad()); |
67 | ASSERT_TRUE(c.is_inference()); |
68 | |
69 | // requires_grad doesn't change inference tensor behavior inside |
70 | // InferenceMode. |
71 | torch::Tensor tmp = torch::ones({1, 2, 3}).set_requires_grad(true); |
72 | ASSERT_TRUE(tmp.requires_grad()); |
73 | ASSERT_TRUE(tmp.is_inference()); |
74 | |
75 | tmp = torch::ones({1, 2, 3}).set_requires_grad(false); |
76 | ASSERT_FALSE(tmp.requires_grad()); |
77 | ASSERT_TRUE(tmp.is_inference()); |
78 | } |
79 | } |
80 | |
81 | TEST(InferenceModeTest, TestExistingAutogradSession) { |
82 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true); |
83 | torch::Tensor a = s.clone(); |
84 | |
85 | // Save `a` in an existing autograd session |
86 | torch::Tensor out = a * a; |
87 | { |
88 | InferenceMode guard; |
89 | inplace_op(a); |
90 | } |
91 | // Performing backward should trigger error since `a`'s version has been |
92 | // bumped. |
93 | ASSERT_THROWS_WITH( |
94 | out.backward(torch::ones_like(out)), |
95 | "one of the variables needed for gradient computation has been modified by an inplace operation" ) |
96 | } |
97 | |
98 | TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) { |
99 | c10::InferenceMode guard; |
100 | for (bool requires_grad : {true, false}) { |
101 | torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
102 | |
103 | torch::Tensor func_out = functional_op(c); // go through kernels: CPU |
104 | ASSERT_TRUE(func_out.is_inference()); |
105 | ASSERT_FALSE(func_out.requires_grad()); |
106 | } |
107 | } |
108 | |
109 | TEST(InferenceModeTest, TestInferenceTensorInInferenceModeInplaceOp) { |
110 | c10::InferenceMode guard; |
111 | for (bool requires_grad : {true, false}) { |
112 | torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
113 | |
114 | inplace_op(c); // go through kernels: CPU |
115 | ASSERT_TRUE(c.is_inference()); |
116 | ASSERT_EQ(c.requires_grad(), requires_grad); |
117 | } |
118 | } |
119 | |
120 | TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) { |
121 | c10::InferenceMode guard; |
122 | for (bool requires_grad : {true, false}) { |
123 | torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
124 | |
125 | torch::Tensor view_out = view_op(c); // go through kernels: CPU |
126 | ASSERT_TRUE(view_out.is_inference()); |
127 | // Note this is different from NoGradMode but makes sense. |
128 | ASSERT_FALSE(view_out.requires_grad()); |
129 | ASSERT_FALSE(view_out.is_view()); |
130 | } |
131 | } |
132 | |
133 | TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) { |
134 | torch::Tensor inference_tensor; |
135 | for (bool requires_grad : {true, false}) { |
136 | { |
137 | InferenceMode guard; |
138 | inference_tensor = |
139 | torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
140 | } |
141 | |
142 | // Due to issue #54614, this might run slower compared to InferenceMode |
143 | // since intermediate tensors are normal tensors, and they might dispatch to |
144 | // VariableType kernels. This is fine since users can easily fix it by |
145 | // moving it inside InferenceMode block. |
146 | torch::Tensor tmp = |
147 | functional_op(inference_tensor); // go through kernels: |
148 | // ADInplaceOrView(fallthrough), CPU |
149 | ASSERT_FALSE(tmp.is_inference()); |
150 | ASSERT_FALSE(tmp.requires_grad()); |
151 | } |
152 | } |
153 | |
154 | TEST(InferenceModeTest, TestInferenceTensorInNormalModeInplaceOp) { |
155 | torch::Tensor inference_tensor; |
156 | for (bool requires_grad : {true, false}) { |
157 | { |
158 | InferenceMode guard; |
159 | inference_tensor = |
160 | torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
161 | } |
162 | ASSERT_THROWS_WITH( |
163 | inplace_op( |
164 | inference_tensor), // go through kernels: ADInplaceOrView, CPU |
165 | "Inplace update to inference tensor outside InferenceMode is not allowed" ); |
166 | } |
167 | } |
168 | |
169 | TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) { |
170 | torch::Tensor inference_tensor; |
171 | for (bool requires_grad : {true, false}) { |
172 | { |
173 | InferenceMode guard; |
174 | inference_tensor = |
175 | torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
176 | } |
177 | torch::Tensor out = |
178 | view_op(inference_tensor); // go through kernels: ADInplaceOrView, CPU |
179 | ASSERT_TRUE(out.is_inference()); |
180 | ASSERT_FALSE(out.requires_grad()); |
181 | ASSERT_FALSE(out.is_view()); |
182 | ASSERT_TRUE(out.is_leaf()); |
183 | } |
184 | } |
185 | |
186 | TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) { |
187 | for (bool requires_grad : {true, false}) { |
188 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
189 | torch::Tensor a = s.clone(); |
190 | |
191 | { |
192 | c10::InferenceMode guard; |
193 | |
194 | inplace_op(a); // go through kernels: ADInplaceOrView, CPU |
195 | ASSERT_FALSE(a.is_inference()); |
196 | ASSERT_EQ(a.requires_grad(), requires_grad); |
197 | |
198 | // inplace -> inplace |
199 | inplace_op(a); // go through kernels: ADInplaceOrView, CPU |
200 | ASSERT_FALSE(a.is_inference()); |
201 | ASSERT_EQ(a.requires_grad(), requires_grad); |
202 | |
203 | // inplace -> inplace -> view |
204 | torch::Tensor view_out = |
205 | view_op(a); // go through kernels: ADInplaceOrView, CPU |
206 | ASSERT_FALSE(view_out.is_inference()); |
207 | ASSERT_EQ(view_out.requires_grad(), requires_grad); |
208 | } |
209 | } |
210 | } |
211 | |
212 | TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) { |
213 | for (bool requires_grad : {true, false}) { |
214 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
215 | torch::Tensor a = s.clone(); |
216 | |
217 | { |
218 | c10::InferenceMode guard; |
219 | |
220 | inplace_op(a); // go through kernels: ADInplaceOrView, CPU |
221 | ASSERT_FALSE(a.is_inference()); |
222 | ASSERT_EQ(a.requires_grad(), requires_grad); |
223 | } |
224 | |
225 | torch::Tensor tmp = functional_op(a); // go through kernels: VariableType, |
226 | // ADInplaceOrView(fallthrough), CPU |
227 | ASSERT_FALSE(tmp.is_inference()); |
228 | ASSERT_EQ(tmp.requires_grad(), requires_grad); |
229 | |
230 | inplace_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU |
231 | ASSERT_FALSE(a.is_inference()); |
232 | ASSERT_EQ(a.requires_grad(), requires_grad); |
233 | |
234 | tmp = view_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU |
235 | ASSERT_FALSE(tmp.is_inference()); |
236 | ASSERT_EQ(tmp.requires_grad(), requires_grad); |
237 | } |
238 | } |
239 | |
240 | TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) { |
241 | for (bool requires_grad : {true, false}) { |
242 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
243 | torch::Tensor a = s.clone(); |
244 | torch::Tensor view_out, tmp; |
245 | |
246 | { |
247 | c10::InferenceMode guard; |
248 | // View ops on normal tensor produce normal tensors as output. |
249 | // - For view ops it has both dispatch keys since due to the way we create |
250 | // view Tensors in alias_with_sizes_and_strides: |
251 | // ``` |
252 | // auto impl = c10::make_intrusive<TensorImpl>( |
253 | // Storage(self.storage()), self.key_set(), self.dtype()); |
254 | // ``` |
255 | // In addition, these view output tensors are normal in the sense they |
256 | // have both Autograd and ADInplaceOrView keys. But they're still |
257 | // special since they'll have CreationMeta::INFERENCE_MODE. In other |
258 | // words they behave exactly the same as a view tensor created in |
259 | // no_grad mode. |
260 | |
261 | view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU |
262 | ASSERT_FALSE(view_out.is_inference()); |
263 | assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE); |
264 | ASSERT_EQ(view_out.requires_grad(), requires_grad); |
265 | ASSERT_TRUE(view_out.is_leaf()); |
266 | |
267 | // view -> view |
268 | tmp = view_op(view_out); // go through kernels: ADInplaceOrView, CPU |
269 | ASSERT_FALSE(tmp.is_inference()); |
270 | assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE); |
271 | ASSERT_EQ(tmp.requires_grad(), requires_grad); |
272 | ASSERT_TRUE(tmp.is_leaf()); |
273 | |
274 | // view -> view -> inplace |
275 | inplace_op(tmp); // kernels: ADInplaceOrView, CPU |
276 | assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE); |
277 | ASSERT_FALSE(tmp.is_inference()); |
278 | ASSERT_EQ(tmp.requires_grad(), requires_grad); |
279 | ASSERT_TRUE(tmp.is_leaf()); |
280 | ASSERT_EQ(a._version(), tmp._version()); |
281 | } |
282 | } |
283 | } |
284 | |
285 | TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) { |
286 | for (bool requires_grad : {true, false}) { |
287 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
288 | torch::Tensor a = s.clone(); |
289 | torch::Tensor view_out, tmp; |
290 | |
291 | { |
292 | c10::InferenceMode guard; |
293 | view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU |
294 | ASSERT_FALSE(view_out.is_inference()); |
295 | assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE); |
296 | ASSERT_EQ(view_out.requires_grad(), requires_grad); |
297 | ASSERT_TRUE(view_out.is_leaf()); |
298 | } |
299 | |
300 | tmp = functional_op(view_out); |
301 | ASSERT_FALSE(view_out.is_inference()); |
302 | ASSERT_EQ(tmp.requires_grad(), requires_grad); |
303 | |
304 | if (requires_grad) { |
305 | ASSERT_THROWS_WITH( |
306 | inplace_op(view_out), // go through kernels: VariableType, |
307 | // ADInplaceOrView, CPU |
308 | "A view was created in inference mode and is being modified inplace" ) |
309 | } else { |
310 | inplace_op(view_out); |
311 | } |
312 | |
313 | tmp = view_op(view_out); |
314 | ASSERT_FALSE(view_out.is_inference()); |
315 | ASSERT_EQ(tmp.requires_grad(), requires_grad); |
316 | } |
317 | } |
318 | |
319 | TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) { |
320 | for (bool requires_grad : {true, false}) { |
321 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
322 | torch::Tensor c; |
323 | { |
324 | InferenceMode guard; |
325 | c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
326 | } |
327 | |
328 | // add(Tensor, Tensor) is safe with inference tensor since it doesn't save |
329 | // any variable for backward. |
330 | torch::Tensor out = c.add(s); // go through kernels: VariableType, |
331 | // ADInplaceOrView(fallthrough), CPU |
332 | ASSERT_FALSE(out.is_inference()); |
333 | ASSERT_EQ(out.requires_grad(), requires_grad); |
334 | if (requires_grad) { |
335 | // leaf inference tensor with requires_grad=true can still have gradient. |
336 | // Note this behavior is different from NoGradMode which has empty grad. |
337 | out.backward(torch::ones_like(out)); |
338 | assert_tensor_equal(c.grad(), torch::ones_like(c)); |
339 | } |
340 | |
341 | if (requires_grad) { |
342 | // mul(self, other) saves variable when requires_grad=true |
343 | ASSERT_THROWS_WITH( |
344 | c.mul(s), "Inference tensors cannot be saved for backward." ); |
345 | |
346 | // Inference tensor in TensorList input |
347 | // stack does not capture anymore, so disabled |
348 | // TODO: find alternative Function that captures a list (maybe custom fn) |
349 | /* |
350 | std::vector<torch::Tensor> inputs = {s, c}; |
351 | ASSERT_THROWS_WITH( |
352 | torch::stack(inputs), // go through kernels: VariableType(ERROR)!, |
353 | // ADInplaceOrView(fallthrough), CPU |
354 | "Inference tensors cannot be saved for backward.") |
355 | */ |
356 | } |
357 | } |
358 | } |
359 | |
360 | TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) { |
361 | for (bool requires_grad : {true, false}) { |
362 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
363 | torch::Tensor a = s.clone(); |
364 | torch::Tensor c; |
365 | { |
366 | InferenceMode guard; |
367 | c = torch::ones({1, 2, 3}); |
368 | } |
369 | |
370 | if (requires_grad) { |
371 | ASSERT_THROWS_WITH( |
372 | a.mul_(c), // go through kernels: VariableType(ERROR!), InferenceMode, |
373 | // CPU |
374 | "Inference tensors cannot be saved for backward." ); |
375 | |
376 | ASSERT_THROWS_WITH( |
377 | torch::mul_out( |
378 | /*out=*/c, s, s), // go through kernels: VariableType(ERROR!), |
379 | // ADInplaceOrView, CPU |
380 | "out=... arguments don't support automatic differentiation, but one of the arguments requires grad" ) |
381 | } else { |
382 | a.mul_(c); |
383 | |
384 | ASSERT_THROWS_WITH( |
385 | torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType, |
386 | // ADInplaceOrView(ERROR!), CPU |
387 | "Inplace update to inference tensor outside InferenceMode is not allowed" ); |
388 | } |
389 | } |
390 | } |
391 | |
392 | TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) { |
393 | for (bool requires_grad : {true, false}) { |
394 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
395 | torch::Tensor c; |
396 | { |
397 | InferenceMode guard; |
398 | c = torch::ones({1, 2, 3}); |
399 | } |
400 | |
401 | // view_as is a composite op which calls view() with only one tensor |
402 | // argument. So there isn't a mixed inference tensor and normal tensor |
403 | // inputs for view ops. |
404 | torch::Tensor tmp1 = |
405 | c.view_as(s); // go through kernels: ADInplaceOrView, CPU |
406 | ASSERT_TRUE(tmp1.is_inference()); |
407 | ASSERT_FALSE(tmp1.requires_grad()); |
408 | |
409 | // This is fine since it's equivalent as s.view(c.sizes()) which |
410 | // isn't a mixed input scenario. |
411 | torch::Tensor tmp2 = |
412 | s.view_as(c); // go through kernels: VariableType, ADInplaceOrView, CPU |
413 | ASSERT_FALSE(tmp2.is_inference()); |
414 | ASSERT_EQ(tmp2.requires_grad(), requires_grad); |
415 | } |
416 | } |
417 | |
418 | TEST(InferenceModeTest, TestHandleDirectViewOnRebase) { |
419 | for (bool requires_grad : {true, false}) { |
420 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
421 | torch::Tensor a = s.clone(); |
422 | torch::Tensor view_out; |
423 | { |
424 | InferenceMode guard; |
425 | view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU |
426 | } |
427 | if (requires_grad) { |
428 | ASSERT_THROWS_WITH( |
429 | inplace_op(view_out), |
430 | "A view was created in inference mode and is being modified inplace" ) |
431 | } else { |
432 | inplace_op(view_out); |
433 | } |
434 | } |
435 | } |
436 | |
437 | TEST(InferenceModeTest, TestHandleInDirectViewOnRebase) { |
438 | for (bool requires_grad : {true, false}) { |
439 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
440 | torch::Tensor a = s.clone(); |
441 | torch::Tensor view_out; |
442 | { |
443 | InferenceMode guard; |
444 | view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU |
445 | } |
446 | inplace_op(a); |
447 | if (requires_grad) { |
448 | ASSERT_THROWS_WITH( |
449 | view_out.grad_fn(), |
450 | "A view was created in inference mode and its base or another view of its base has been modified inplace" ); |
451 | } else { |
452 | view_out.grad_fn(); |
453 | } |
454 | } |
455 | } |
456 | |
457 | TEST(InferenceModeTest, TestCreationMetaPropagation) { |
458 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true); |
459 | torch::Tensor b, c; |
460 | { |
461 | InferenceMode guard; |
462 | b = s.view_as(s); |
463 | } |
464 | ASSERT_THROWS_WITH( |
465 | b.add_(1), |
466 | "A view was created in inference mode and is being modified inplace" ); |
467 | { |
468 | AutoGradMode mode(false); |
469 | c = b.view_as(b); |
470 | } |
471 | ASSERT_THROWS_WITH( |
472 | c.add_(1), |
473 | "A view was created in inference mode and is being modified inplace" ); |
474 | } |
475 | |
476 | TEST(InferenceModeTest, TestCreationMetaPropagationInput) { |
477 | torch::Tensor s = torch::ones({2, 2, 3}).set_requires_grad(true); |
478 | auto s_view = s.view_as(s); |
479 | std::vector<at::Tensor> b, c; |
480 | { |
481 | InferenceMode guard; |
482 | b = s_view.split_with_sizes({1, 1}); |
483 | |
484 | s = s.view_as(s); |
485 | c = s.split_with_sizes({1, 1}); |
486 | } |
487 | for (auto& b_el : b) { |
488 | assert_tensor_creation_meta(b_el, CreationMeta::INFERENCE_MODE); |
489 | ASSERT_THROWS_WITH( |
490 | b_el.add_(1), |
491 | "A view was created in inference mode and is being modified inplace" ); |
492 | } |
493 | for (auto& c_el : c) { |
494 | assert_tensor_creation_meta(c_el, CreationMeta::INFERENCE_MODE); |
495 | ASSERT_THROWS_WITH( |
496 | c_el.add_(1), |
497 | "A view was created in inference mode and is being modified inplace" ); |
498 | } |
499 | } |
500 | |
501 | TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) { |
502 | for (bool requires_grad : {true, false}) { |
503 | torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); |
504 | torch::Tensor t; |
505 | { |
506 | InferenceMode guard; |
507 | t = torch::ones({1, 2, 3}); |
508 | t.copy_(s); |
509 | ASSERT_TRUE(t.is_inference()); |
510 | ASSERT_FALSE(t.requires_grad()); |
511 | } |
512 | |
513 | ASSERT_THROWS_WITH( |
514 | t.copy_(s), |
515 | "Inplace update to inference tensor outside InferenceMode is not allowed" ); |
516 | } |
517 | } |
518 | |
519 | TEST(InferenceModeTest, TestSetRequiresGradInNormalMode) { |
520 | torch::Tensor t; |
521 | { |
522 | InferenceMode guard; |
523 | t = torch::ones({1, 2, 3}); |
524 | } |
525 | t.set_requires_grad(false); |
526 | ASSERT_THROWS_WITH( |
527 | t.set_requires_grad(true), |
528 | "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed." ); |
529 | } |
530 | |
531 | TEST(InferenceModeTest, TestAccessVersionCounter) { |
532 | torch::Tensor t; |
533 | { |
534 | InferenceMode guard; |
535 | t = torch::ones({1, 2, 3}); |
536 | ASSERT_THROWS_WITH( |
537 | t.unsafeGetTensorImpl()->version_counter().current_version(), |
538 | "Inference tensors do not track version counter." ); |
539 | t.unsafeGetTensorImpl()->bump_version(); |
540 | } |
541 | ASSERT_THROWS_WITH( |
542 | t.unsafeGetTensorImpl()->version_counter().current_version(), |
543 | "Inference tensors do not track version counter." ); |
544 | ASSERT_THROWS_WITH( |
545 | t.unsafeGetTensorImpl()->bump_version(), |
546 | "Inplace update to inference tensor outside InferenceMode is not allowed." ); |
547 | // Suggested workaround |
548 | torch::Tensor c = t.clone(); |
549 | uint32_t v = c.unsafeGetTensorImpl()->version_counter().current_version(); |
550 | c.unsafeGetTensorImpl()->bump_version(); |
551 | ASSERT_EQ( |
552 | c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1); |
553 | } |
554 | |
555 | TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) { |
556 | torch::Tensor s = torch::ones({1, 2, 3}); |
557 | torch::Tensor t; |
558 | { |
559 | InferenceMode guard; |
560 | t = torch::ones({1, 2, 3}); |
561 | // Testing both copy_ from VariableTypeManual and add_ from generated code. |
562 | s.copy_(t); |
563 | s.add_(t); |
564 | t.add_(s); |
565 | t.copy_(s); |
566 | } |
567 | s.copy_(t); |
568 | s.add_(t); |
569 | ASSERT_THROWS_WITH( |
570 | t.copy_(s), |
571 | "Inplace update to inference tensor outside InferenceMode is not allowed" ); |
572 | |
573 | ASSERT_THROWS_WITH( |
574 | t.add_(s), |
575 | "Inplace update to inference tensor outside InferenceMode is not allowed" ); |
576 | } |
577 | |
578 | TEST(InferenceModeTest, TestComplexViewInInferenceMode) { |
579 | torch::Tensor s = torch::ones({3, 3, 2}); |
580 | torch::Tensor t = torch::view_as_complex(s); |
581 | { |
582 | InferenceMode guard; |
583 | torch::Tensor tmp; |
584 | |
585 | tmp = torch::view_as_real(t); |
586 | ASSERT_FALSE(tmp.is_inference()); |
587 | tmp = torch::view_as_complex(s); |
588 | ASSERT_FALSE(tmp.is_inference()); |
589 | |
590 | torch::Tensor e = torch::ones({3, 3, 2}); |
591 | tmp = torch::view_as_complex(e); |
592 | ASSERT_TRUE(tmp.is_inference()); |
593 | tmp = torch::view_as_real(tmp); |
594 | ASSERT_TRUE(tmp.is_inference()); |
595 | } |
596 | } |
597 | |
598 | TEST(InferenceModeTest, TestComplexViewInNormalMode) { |
599 | torch::Tensor s; |
600 | { |
601 | InferenceMode guard; |
602 | s = torch::ones({3, 3, 2}); |
603 | } |
604 | torch::Tensor tmp = torch::view_as_complex(s); |
605 | ASSERT_TRUE(tmp.is_inference()); |
606 | tmp = torch::view_as_real(tmp); |
607 | ASSERT_TRUE(tmp.is_inference()); |
608 | } |
609 | |
610 | TEST(InferenceModeTest, TestCustomFunction) { |
611 | struct MyFunction : public Function<MyFunction> { |
612 | static Variable forward( |
613 | AutogradContext* ctx, |
614 | Variable var1, |
615 | int mul, |
616 | Variable var2) { |
617 | ctx->saved_data["mul" ] = mul; |
618 | ctx->save_for_backward({var1, var2}); |
619 | return var1 + mul * var2 + var1 * var2; |
620 | } |
621 | |
622 | static variable_list backward( |
623 | AutogradContext* ctx, |
624 | variable_list grad_output) { |
625 | int mul = ctx->saved_data["mul" ].toInt(); |
626 | auto saved = ctx->get_saved_variables(); |
627 | auto var1 = saved[0]; |
628 | auto var2 = saved[1]; |
629 | variable_list output = { |
630 | grad_output[0] + grad_output[0] * var2, |
631 | Variable(), |
632 | grad_output[0] * mul + grad_output[0] * var1}; |
633 | return output; |
634 | } |
635 | }; |
636 | |
637 | { |
638 | InferenceMode guard; |
639 | torch::Tensor var1 = torch::ones({3, 3}).set_requires_grad(true); |
640 | auto var2 = var1.clone(); |
641 | int mul = 2; |
642 | // If InferenceMode didn't set NoGradGuard automatically, this line |
643 | // would error out when trying to save `var1` and `var2` for backward. |
644 | auto y = MyFunction::apply(var1, mul, var2); |
645 | torch::Tensor expected = var1 + mul * var2 + var1 * var2; |
646 | assert_tensor_equal(y, expected); |
647 | } |
648 | } |
649 | |
650 | TEST(InferenceModeTest, TestLegacyAutoNonVariableTypeModeWarning) { |
651 | c10::WarningUtils::WarnAlways warn_always(true); |
652 | WarningCapture warnings; |
653 | at::AutoNonVariableTypeMode guard; |
654 | ASSERT_TRUE( |
655 | warnings.str().find("AutoNonVariableTypeMode is deprecated" ) != |
656 | std::string::npos); |
657 | } |
658 | |