1 | #include <gtest/gtest.h> |
2 | |
3 | #include <ATen/ATen.h> |
4 | #include <ATen/core/Reduction.h> |
5 | #include <torch/cuda.h> |
6 | #include <ATen/test/test_assert.h> |
7 | #include <c10/util/irange.h> |
8 | #include <c10/util/CallOnce.h> |
9 | |
10 | // for TH compat test only... |
11 | struct THFloatTensor; |
12 | |
13 | #include <iostream> |
14 | #include <chrono> |
15 | // NOLINTNEXTLINE(modernize-deprecated-headers) |
16 | #include <string.h> |
17 | #include <sstream> |
18 | #include <thread> |
19 | #include <mutex> |
20 | |
// Fatally asserts that X == Y. The comparison is evaluated into a plain bool
// first, so ASSERT_TRUE only ever receives a bool.
// Both operands are parenthesized so that an argument containing an operator
// of lower precedence than == (for example `a & b`) is compared as a whole
// instead of being re-associated with the == inside the macro.
#define ASSERT_EQ_RESOLVED(X, Y) \
  {                              \
    bool isEQ = (X) == (Y);      \
    ASSERT_TRUE(isEQ);           \
  }
26 | |
27 | using namespace at; |
28 | |
29 | void TestResize(DeprecatedTypeProperties& type) { |
30 | auto a = at::empty({0}, type.options()); |
31 | a.resize_({3, 4}); |
32 | ASSERT_EQ_RESOLVED(a.numel(), 12); |
33 | a.resize_({5, 7}); |
34 | ASSERT_EQ_RESOLVED(a.numel(), 35); |
35 | } |
36 | |
37 | void TestOnesAndDot(DeprecatedTypeProperties& type) { |
38 | Tensor b0 = ones({1, 1}, type); |
39 | ASSERT_EQ_RESOLVED((b0 + b0).sum().item<double>(), 2); |
40 | |
41 | Tensor b1 = ones({1, 2}, type); |
42 | ASSERT_EQ_RESOLVED((b1 + b1).sum().item<double>(), 4); |
43 | |
44 | Tensor b = ones({3, 4}, type); |
45 | ASSERT_EQ_RESOLVED((b + b).sum().item<double>(), 24); |
46 | ASSERT_EQ_RESOLVED(b.numel(), 12); |
47 | if (type.backend() != Backend::CPU || type.scalarType() != kHalf) { |
48 | ASSERT_EQ_RESOLVED(b.view(-1).dot(b.view(-1)).item<double>(), 12); |
49 | } |
50 | } |
51 | |
52 | void TestSort(DeprecatedTypeProperties& type) { |
53 | Tensor b = rand({3, 4}, type); |
54 | |
55 | auto z = b.sort(1); |
56 | auto z_sorted = std::get<0>(z); |
57 | |
58 | bool isLT = z_sorted[0][0].item<float>() < z_sorted[0][1].item<float>(); |
59 | ASSERT_TRUE(isLT); |
60 | } |
61 | |
62 | void TestRandperm(DeprecatedTypeProperties& type) { |
63 | if (type.backend() != Backend::CUDA) { |
64 | Tensor b = randperm(15, type); |
65 | Tensor rv, ri; |
66 | std::tie(rv, ri) = sort(b, 0); |
67 | bool isLE = (rv[0].item<float>() <= rv[1].item<float>()); |
68 | ASSERT_TRUE(isLE); |
69 | } |
70 | } |
71 | |
72 | void SendContext() { |
73 | std::stringstream ss; |
74 | ss << "context: " << std::hex << (int64_t)&globalContext() << std::endl; |
75 | } |
76 | |
77 | void TestAdd(DeprecatedTypeProperties& type) { |
78 | Tensor a = rand({3, 4}, type); |
79 | Tensor b = rand({3, 4}, type); |
80 | Tensor c = add(a, add(a, b)); |
81 | // TODO:0-dim Tensor d(3.f); |
82 | Scalar d = 3.f; |
83 | if (type.backend() == Backend::CPU && type.scalarType() == kHalf) { |
84 | ASSERT_TRUE(add(c, d).allclose(a + a + b + d, 1e-2)); |
85 | } else { |
86 | ASSERT_TRUE(add(c, d).allclose(a + a + b + d)); |
87 | } |
88 | } |
89 | |
90 | void TestZeros(DeprecatedTypeProperties& type) { |
91 | auto begin = std::chrono::high_resolution_clock::now(); |
92 | Tensor a = zeros({1024, 1024}, type); |
93 | for (const auto i : c10::irange(1, 1000)) { |
94 | (void)i; // Suppress unused variable warning |
95 | a = zeros({128, 128}, type); |
96 | } |
97 | auto end = std::chrono::high_resolution_clock::now(); |
98 | std::cout << std::dec << " " |
99 | << std::chrono::duration_cast<std::chrono::milliseconds>( |
100 | end - begin) |
101 | .count() |
102 | << " ms" << std::endl; |
103 | |
104 | std::srand(std::time(nullptr)); |
105 | ASSERT_EQ(norm(a).item<double>(), 0.0); |
106 | } |
107 | |
108 | void TestLoadsOfAdds(DeprecatedTypeProperties& type) { |
109 | auto begin = std::chrono::high_resolution_clock::now(); |
110 | Tensor d = ones({3, 4}, type); |
111 | Tensor r = zeros({3, 4}, type); |
112 | for (const auto i : c10::irange(100000)) { |
113 | (void)i; // Suppress unused variable warning |
114 | add_out(r, r, d); |
115 | } |
116 | auto end = std::chrono::high_resolution_clock::now(); |
117 | // TODO TEST PERF? |
118 | std::cout << std::dec << " " |
119 | << std::chrono::duration_cast<std::chrono::milliseconds>( |
120 | end - begin) |
121 | .count() |
122 | << " ms" << std::endl; |
123 | ASSERT_EQ_RESOLVED(norm(100000 * d).item<double>(), norm(r).item<double>()); |
124 | } |
125 | |
126 | void TestLoadOfAddsWithCopy(DeprecatedTypeProperties& type) { |
127 | auto begin = std::chrono::high_resolution_clock::now(); |
128 | Tensor d = ones({3, 4}, type); |
129 | Tensor r = zeros({3, 4}, type); |
130 | for (const auto i : c10::irange(100000)) { |
131 | (void)i; // Suppress unused variable warning |
132 | r = add(r, d); |
133 | } |
134 | auto end = std::chrono::high_resolution_clock::now(); |
135 | // TODO TEST PERF? |
136 | std::cout << std::dec << " " |
137 | << std::chrono::duration_cast<std::chrono::milliseconds>( |
138 | end - begin) |
139 | .count() |
140 | << " ms" << std::endl; |
141 | ASSERT_EQ_RESOLVED(norm(100000 * d).item<double>(), norm(r).item<double>()); |
142 | } |
143 | |
144 | void TestIsContiguous(DeprecatedTypeProperties& type) { |
145 | Tensor a = rand({3, 4}, type); |
146 | ASSERT_TRUE(a.is_contiguous()); |
147 | a = a.transpose(0, 1); |
148 | ASSERT_FALSE(a.is_contiguous()); |
149 | } |
150 | |
151 | void TestPermute(DeprecatedTypeProperties& type) { |
152 | Tensor a = rand({3, 4, 5}, type); |
153 | Tensor b = a.permute({1, 2, 0}); |
154 | ASSERT_TRUE(b.sizes().equals({4, 5, 3})); |
155 | ASSERT_TRUE(b.strides().equals({5, 1, 20})); |
156 | } |
157 | |
158 | void TestMm(DeprecatedTypeProperties& type) { |
159 | if (type.backend() != Backend::CPU || type.scalarType() != kHalf) { |
160 | Tensor a = rand({3, 4}, type); |
161 | Tensor b = rand({4}, type); |
162 | Tensor c = mv(a, b); |
163 | ASSERT_TRUE(c.equal(addmv(zeros({3}, type), a, b, 0, 1))); |
164 | } |
165 | } |
166 | |
167 | void TestSqueeze(DeprecatedTypeProperties& type) { |
168 | Tensor a = rand({2, 1}, type); |
169 | Tensor b = squeeze(a); |
170 | ASSERT_EQ_RESOLVED(b.dim(), 1); |
171 | a = rand({1}, type); |
172 | b = squeeze(a); |
173 | // TODO 0-dim squeeze |
174 | ASSERT_TRUE(a[0].equal(b)); |
175 | } |
176 | |
177 | void TestCopy(DeprecatedTypeProperties& type) { |
178 | Tensor a = zeros({4, 3}, type); |
179 | Tensor e = rand({4, 3}, type); |
180 | a.copy_(e); |
181 | ASSERT_TRUE(a.equal(e)); |
182 | } |
183 | |
184 | void TestCopyBroadcasting(DeprecatedTypeProperties& type) { |
185 | Tensor a = zeros({4, 3}, type); |
186 | Tensor e = rand({3}, type); |
187 | a.copy_(e); |
188 | for (const auto i : c10::irange(4)) { |
189 | ASSERT_TRUE(a[i].equal(e)); |
190 | } |
191 | } |
192 | void TestAbsValue(DeprecatedTypeProperties& type) { |
193 | Tensor r = at::abs(at::scalar_tensor(-3, type.options())); |
194 | ASSERT_EQ_RESOLVED(r.item<int32_t>(), 3); |
195 | } |
196 | /* |
197 | TODO(zach): operator overloads |
198 | #if 0 |
199 | { |
200 | std::cout << "eq (value):" << std::endl; |
201 | Tensor a = Tensor(10.f); |
202 | std::cout << (a == 11_i64) << " -- should be 0" << std::endl; |
203 | std::cout << (a == 10_i64) << " -- should be 1" << std::endl; |
204 | std::cout << (a == 10.) << " -- should be 1" << std::endl; |
205 | } |
206 | #endif |
207 | */ |
208 | |
209 | void TestAddingAValueWithScalar(DeprecatedTypeProperties& type) { |
210 | Tensor a = rand({4, 3}, type); |
211 | ASSERT_TRUE((ones({4, 3}, type) + a).equal(add(a, 1))); |
212 | } |
213 | |
214 | void TestSelect(DeprecatedTypeProperties& type) { |
215 | Tensor a = rand({3, 7}, type); |
216 | auto a_13 = select(a, 1, 3); |
217 | auto a_13_02 = select(select(a, 1, 3), 0, 2); |
218 | ASSERT_TRUE(a[0][3].equal(a_13[0])); |
219 | ASSERT_TRUE(a[2][3].equal(a_13_02)); |
220 | } |
221 | |
222 | void TestZeroDim(DeprecatedTypeProperties& type) { |
223 | Tensor a = at::scalar_tensor(4, type.options()); // rand(type, {1}); |
224 | |
225 | Tensor b = rand({3, 4}, type); |
226 | ASSERT_EQ_RESOLVED((a + a).dim(), 0); |
227 | ASSERT_EQ_RESOLVED((1 + a).dim(), 0); |
228 | ASSERT_EQ_RESOLVED((b + a).dim(), 2); |
229 | ASSERT_EQ_RESOLVED((a + b).dim(), 2); |
230 | auto c = rand({3, 4}, type); |
231 | ASSERT_EQ_RESOLVED(c[1][2].dim(), 0); |
232 | |
233 | auto f = rand({3, 4}, type); |
234 | f[2] = zeros({4}, type); |
235 | f[1][0] = -1; |
236 | ASSERT_EQ_RESOLVED(f[2][0].item<double>(), 0); |
237 | } |
238 | |
239 | void TestToCFloat() { |
240 | Tensor a = zeros({3, 4}); |
241 | Tensor b = ones({3, 7}); |
242 | Tensor c = cat({a, b}, 1); |
243 | ASSERT_EQ_RESOLVED(c.size(1), 11); |
244 | |
245 | Tensor e = rand({}); |
246 | ASSERT_EQ_RESOLVED(*e.data_ptr<float>(), e.sum().item<float>()); |
247 | } |
248 | void TestToString() { |
249 | Tensor b = ones({3, 7}) * .0000001f; |
250 | std::stringstream s; |
251 | s << b << "\n" ; |
252 | std::string expect = "1e-07 *" ; |
253 | ASSERT_EQ_RESOLVED(s.str().substr(0, expect.size()), expect); |
254 | } |
255 | |
256 | void TestIndexingByScalar() { |
257 | Tensor tensor = arange(0, 10, kInt); |
258 | Tensor one = ones({}, kInt); |
259 | for (const auto i : c10::irange(tensor.numel())) { |
260 | ASSERT_TRUE(tensor[i].equal(one * i)); |
261 | } |
262 | for (size_t i = 0; i < static_cast<uint64_t>(tensor.numel()); ++i) { |
263 | ASSERT_TRUE(tensor[i].equal(one * static_cast<int64_t>(i))); |
264 | } |
265 | for (const auto i : c10::irange(tensor.numel())) { |
266 | ASSERT_TRUE(tensor[i].equal(one * i)); |
267 | } |
268 | // NOLINTNEXTLINE(bugprone-too-small-loop-variable) |
269 | for (int16_t i = 0; i < tensor.numel(); ++i) { |
270 | ASSERT_TRUE(tensor[i].equal(one * i)); |
271 | } |
272 | // NOLINTNEXTLINE(bugprone-too-small-loop-variable) |
273 | for (int8_t i = 0; i < tensor.numel(); ++i) { |
274 | ASSERT_TRUE(tensor[i].equal(one * i)); |
275 | } |
276 | // Throw StartsWith("Can only index tensors with integral scalars") |
277 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-goto) |
278 | ASSERT_ANY_THROW(tensor[Scalar(3.14)].equal(one)); |
279 | } |
280 | |
281 | void TestIndexingByZerodimTensor() { |
282 | Tensor tensor = arange(0, 10, kInt); |
283 | Tensor one = ones({}, kInt); |
284 | for (const auto i : c10::irange(tensor.numel())) { |
285 | ASSERT_TRUE(tensor[one * i].equal(one * i)); |
286 | } |
287 | // Throw StartsWith( |
288 | // "Can only index tensors with integral scalars") |
289 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-goto) |
290 | ASSERT_ANY_THROW(tensor[ones({}) * 3.14].equal(one)); |
291 | // Throw StartsWith("Can only index with tensors that are defined") |
292 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
293 | ASSERT_ANY_THROW(tensor[Tensor()].equal(one)); |
294 | // Throw StartsWith("Can only index with tensors that are scalars (zero-dim)") |
295 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
296 | ASSERT_ANY_THROW(tensor[ones({2, 3, 4}, kInt)].equal(one)); |
297 | } |
298 | void TestIndexingMixedDevice(DeprecatedTypeProperties& type) { |
299 | Tensor tensor = randn({20, 20}, type); |
300 | Tensor index = arange(10, kLong).cpu(); |
301 | Tensor result = tensor.index({index}); |
302 | ASSERT_TRUE(result[0].equal(tensor[0])); |
303 | } |
304 | void TestDispatch() { |
305 | Tensor tensor = randn({20, 20}); |
306 | Tensor other = randn({20, 20}); |
307 | auto result = tensor.m(relu).m(mse_loss, other, at::Reduction::Mean); |
308 | ASSERT_TRUE(result.allclose(mse_loss(relu(tensor), other))); |
309 | } |
310 | |
311 | void TestNegativeDim(DeprecatedTypeProperties& type) { |
312 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
313 | ASSERT_ANY_THROW(empty({5, -5, 5}, type.options())); |
314 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
315 | ASSERT_ANY_THROW(empty({5, -5, -5}, type.options())); |
316 | Tensor tensor = empty({5, 5}, type.options()); |
317 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
318 | ASSERT_ANY_THROW(tensor.reshape({-5, -5})); |
319 | } |
320 | |
321 | void TestView(DeprecatedTypeProperties& type) { |
322 | // Testing the tensor view path, which is different from |
323 | // the Variable view path, see https://github.com/pytorch/pytorch/pull/23452 |
324 | // for details |
325 | Tensor tensor = randn({3, 4}, type);; |
326 | Tensor viewed = tensor.view({3, 4}); |
327 | tensor.resize_({6, 2}); |
328 | ASSERT_TRUE(tensor.sizes().equals({6, 2})); |
329 | ASSERT_TRUE(viewed.sizes().equals({3, 4})); |
330 | } |
331 | |
332 | void TestIntArrayRefExpansion(DeprecatedTypeProperties& type) { |
333 | if (type.backend() != Backend::CPU || type.scalarType() != kHalf) { |
334 | max_pool2d(randn({3, 3, 3, 3}, type.options()), 2, 1, 1, 1); |
335 | max_pool3d(randn({3, 3, 3, 3, 3}, type.options()), 2, 1, 1, 1); |
336 | avg_pool2d(randn({3, 3, 3, 3}, type.options()), 2, 1, 1); |
337 | avg_pool3d(randn({3, 3, 3, 3, 3}, type.options()), 2, 1, 1); |
338 | } |
339 | } |
340 | |
341 | void test(DeprecatedTypeProperties& type) { |
342 | TestResize(type); |
343 | TestOnesAndDot(type); |
344 | |
345 | TestSort(type); |
346 | TestRandperm(type); |
347 | TestAdd(type); |
348 | TestZeros(type); |
349 | TestLoadsOfAdds(type); |
350 | TestLoadOfAddsWithCopy(type); |
351 | TestIsContiguous(type); |
352 | TestPermute(type); |
353 | TestMm(type); |
354 | TestSqueeze(type); |
355 | TestCopy(type); |
356 | TestCopyBroadcasting(type); |
357 | TestAbsValue(type); |
358 | TestAddingAValueWithScalar(type); |
359 | TestSelect(type); |
360 | TestZeroDim(type); |
361 | TestToCFloat(); |
362 | TestToString(); |
363 | TestIndexingByScalar(); |
364 | TestIndexingByZerodimTensor(); |
365 | TestIndexingMixedDevice(type); |
366 | TestDispatch(); |
367 | TestNegativeDim(type); |
368 | TestView(type); |
369 | TestIntArrayRefExpansion(type); |
370 | } |
371 | |
372 | TEST(BasicTest, BasicTestCPU) { |
373 | manual_seed(123); |
374 | |
375 | test(CPU(kFloat)); |
376 | } |
377 | |
378 | TEST(BasicTest, BasicTestHalfCPU) { |
379 | manual_seed(234); |
380 | |
381 | test(CPU(kHalf)); |
382 | } |
383 | |
384 | TEST(BasicTest, BasicTestCUDA) { |
385 | manual_seed(123); |
386 | |
387 | if (at::hasCUDA()) { |
388 | test(CUDA(kFloat)); |
389 | } |
390 | } |
391 | |
392 | TEST(BasicTest, FactoryMethodsTest) { |
393 | // Test default values |
394 | at::Tensor tensor0 = at::empty({4}); |
395 | ASSERT_EQ(tensor0.dtype(), at::kFloat); |
396 | ASSERT_EQ(tensor0.layout(), at::kStrided); |
397 | ASSERT_EQ(tensor0.device(), at::kCPU); |
398 | ASSERT_FALSE(tensor0.requires_grad()); |
399 | ASSERT_FALSE(tensor0.is_pinned()); |
400 | |
401 | // Test setting requires_grad to false. |
402 | tensor0 = at::empty({4}, at::TensorOptions().requires_grad(false)); |
403 | ASSERT_EQ(tensor0.dtype(), at::kFloat); |
404 | ASSERT_EQ(tensor0.layout(), at::kStrided); |
405 | ASSERT_EQ(tensor0.device(), at::kCPU); |
406 | ASSERT_FALSE(tensor0.requires_grad()); |
407 | ASSERT_FALSE(tensor0.is_pinned()); |
408 | |
409 | // Test setting requires_grad to true. |
410 | // This is a bug. Requires_grad was set to TRUE but this is not implemented. |
411 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
412 | EXPECT_ANY_THROW(at::empty({4}, at::TensorOptions().requires_grad(true))); |
413 | |
414 | // Test setting dtype |
415 | at::Tensor tensor1 = at::empty({4}, at::TensorOptions().dtype(at::kHalf)); |
416 | ASSERT_EQ(tensor1.dtype(), at::kHalf); |
417 | ASSERT_EQ(tensor1.layout(), at::kStrided); |
418 | ASSERT_EQ(tensor1.device(), at::kCPU); |
419 | ASSERT_FALSE(tensor1.requires_grad()); |
420 | ASSERT_FALSE(tensor1.is_pinned()); |
421 | |
422 | // Sparse tensor CPU test to avoid requiring CUDA to catch simple bugs. |
423 | // Sparse tensors do not work with static CPU dispatch. |
424 | #ifndef ATEN_CPU_STATIC_DISPATCH |
425 | tensor1 = at::empty({4}, at::TensorOptions().dtype(at::kHalf).layout(at::kSparse)); |
426 | ASSERT_EQ(tensor1.dtype(), at::kHalf); |
427 | ASSERT_EQ(tensor1.layout(), at::kSparse); |
428 | ASSERT_EQ(tensor1.device(), at::kCPU); |
429 | ASSERT_FALSE(tensor1.requires_grad()); |
430 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
431 | ASSERT_FALSE(tensor1.is_pinned()); |
432 | #endif // ATEN_CPU_STATIC_DISPATCH |
433 | |
434 | if (torch::cuda::is_available()) { |
435 | // Test setting pin memory |
436 | tensor1 = at::empty({4}, at::TensorOptions().pinned_memory(true)); |
437 | ASSERT_EQ(tensor1.dtype(), at::kFloat); |
438 | ASSERT_EQ(tensor1.layout(), at::kStrided); |
439 | ASSERT_EQ(tensor1.device(), at::kCPU); |
440 | ASSERT_EQ(tensor1.requires_grad(), false); |
441 | ASSERT_FALSE(tensor1.device().is_cuda()); |
442 | ASSERT_TRUE(tensor1.is_pinned()); |
443 | |
444 | // Test setting device |
445 | tensor1 = at::empty({4}, at::TensorOptions().device(at::kCUDA)); |
446 | ASSERT_EQ(tensor1.dtype(), at::kFloat); |
447 | ASSERT_EQ(tensor1.layout(), at::kStrided); |
448 | ASSERT_TRUE(tensor1.device().is_cuda()); |
449 | ASSERT_FALSE(tensor1.requires_grad()); |
450 | ASSERT_FALSE(tensor1.is_pinned()); |
451 | |
452 | // Test set everything |
453 | tensor1 = at::empty({4}, at::TensorOptions().dtype(at::kHalf).device(at::kCUDA).layout(at::kSparse).requires_grad(false)); |
454 | ASSERT_EQ(tensor1.dtype(), at::kHalf); |
455 | ASSERT_EQ(tensor1.layout(), at::kSparse); |
456 | ASSERT_TRUE(tensor1.device().is_cuda()); |
457 | ASSERT_THROWS(tensor1.nbytes()); |
458 | |
459 | // This is a bug |
460 | // Issue https://github.com/pytorch/pytorch/issues/30405 |
461 | ASSERT_FALSE(tensor1.requires_grad()); |
462 | ASSERT_FALSE(tensor1.is_pinned()); |
463 | } |
464 | |
465 | // Test _like variants |
466 | if (torch::cuda::is_available()) { |
467 | // Issue https://github.com/pytorch/pytorch/issues/28093 |
468 | at::Tensor proto = at::empty({1}, at::kDouble); |
469 | tensor0 = at::empty_like(proto, at::kCUDA); |
470 | ASSERT_EQ(tensor0.dtype(), at::kDouble); |
471 | ASSERT_EQ(tensor0.layout(), at::kStrided); |
472 | ASSERT_TRUE(tensor0.device().is_cuda()); |
473 | ASSERT_FALSE(tensor0.requires_grad()); |
474 | ASSERT_FALSE(tensor0.is_pinned()); |
475 | } |
476 | } |
477 | |
478 | TEST(BasicTest, BasicStdTestCPU) { |
479 | c10::once_flag flag1, flag2; |
480 | |
481 | auto simple_do_once = [&]() |
482 | { |
483 | c10::call_once(flag1, [](){ std::cout << "Simple example: called once\n" ; }); |
484 | }; |
485 | |
486 | auto may_throw_function = [&](bool do_throw) |
487 | { |
488 | if (do_throw) { |
489 | std::cout << "throw: call_once will retry\n" ; // this may appear more than once |
490 | TORCH_CHECK(false, "throw exception" ); |
491 | } |
492 | std::cout << "Didn't throw, call_once will not attempt again\n" ; // guaranteed once |
493 | }; |
494 | |
495 | auto do_once = [&](bool do_throw) |
496 | { |
497 | try { |
498 | c10::call_once(flag2, may_throw_function, do_throw); |
499 | } |
500 | catch (...) { |
501 | } |
502 | }; |
503 | |
504 | std::thread st1(simple_do_once); |
505 | std::thread st2(simple_do_once); |
506 | std::thread st3(simple_do_once); |
507 | std::thread st4(simple_do_once); |
508 | st1.join(); |
509 | st2.join(); |
510 | st3.join(); |
511 | st4.join(); |
512 | |
513 | std::thread t1(do_once, true); |
514 | std::thread t2(do_once, true); |
515 | std::thread t3(do_once, false); |
516 | std::thread t4(do_once, true); |
517 | t1.join(); |
518 | t2.join(); |
519 | t3.join(); |
520 | t4.join(); |
521 | } |
522 | |