1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/torch.h> |
4 | |
5 | #include <torch/cuda.h> |
6 | |
7 | // NB: This file is compiled even in CPU build (for some reason), so |
8 | // make sure you don't include any CUDA only headers. |
9 | |
10 | using namespace at; |
11 | |
12 | // TODO: This might be generally helpful aliases elsewhere. |
13 | at::Device CPUDevice() { |
14 | return at::Device(at::kCPU); |
15 | } |
16 | at::Device CUDADevice(DeviceIndex index) { |
17 | return at::Device(at::kCUDA, index); |
18 | } |
19 | |
// A macro so we don't lose location information when an assertion fails.
// Verifies that a local `TensorOptions` named `options` (captured from the
// expansion site) matches the expected device type, device index, scalar
// type, and layout. NOTE(review): the index and layout checks use
// ASSERT_TRUE rather than ASSERT_EQ — presumably because those operand
// types lack a gtest-printable operator<<; confirm before "simplifying".
#define REQUIRE_OPTIONS(device_, index_, type_, layout_)                  \
  ASSERT_EQ(options.device().type(), Device((device_), (index_)).type()); \
  ASSERT_TRUE(                                                            \
      options.device().index() == Device((device_), (index_)).index());   \
  ASSERT_EQ(typeMetaToScalarType(options.dtype()), (type_));              \
  ASSERT_TRUE(options.layout() == (layout_))
27 | |
// Like REQUIRE_OPTIONS, but inspects a local `Tensor` named `tensor` at the
// expansion site instead of a TensorOptions.
// NOTE(review): unused in the visible portion of this file — may be used by
// code outside this view; confirm before removing.
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_)            \
  ASSERT_EQ(tensor.device().type(), Device((device_), (index_)).type());   \
  ASSERT_EQ(tensor.device().index(), Device((device_), (index_)).index()); \
  ASSERT_EQ(tensor.scalar_type(), (type_));                                \
  ASSERT_TRUE(tensor.options().layout() == (layout_))
33 | |
// Checks that TensorOptions derived from the deprecated CUDA "Type" objects
// (CUDA(...) and getDeprecatedTypeProperties(...)) report the expected
// device, dtype, and layout. The `_CUDA` suffix presumably restricts this
// test to CUDA-enabled runs — confirm against the test harness conventions.
TEST(TensorOptionsTest, ConstructsWellFromCUDATypes_CUDA) {
  // No explicit device index given: expect the sentinel index -1.
  auto options = CUDA(kFloat).options();
  REQUIRE_OPTIONS(kCUDA, -1, kFloat, kStrided);

  options = CUDA(kInt).options();
  REQUIRE_OPTIONS(kCUDA, -1, kInt, kStrided);

  // Sparse backends must surface kSparse in the layout.
  options = getDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options();
  REQUIRE_OPTIONS(kCUDA, -1, kFloat, kSparse);

  options = getDeprecatedTypeProperties(Backend::SparseCUDA, kByte).options();
  REQUIRE_OPTIONS(kCUDA, -1, kByte, kSparse);

  // Passing an explicit device index through options(...) must be preserved.
  // NOLINTNEXTLINE(bugprone-argument-comment,cppcoreguidelines-avoid-magic-numbers)
  options = CUDA(kFloat).options(/*device=*/5);
  REQUIRE_OPTIONS(kCUDA, 5, kFloat, kStrided);

  options =
      // NOLINTNEXTLINE(bugprone-argument-comment,cppcoreguidelines-avoid-magic-numbers)
      getDeprecatedTypeProperties(Backend::SparseCUDA, kFloat)
          .options(/*device=*/5);
  REQUIRE_OPTIONS(kCUDA, 5, kFloat, kSparse);
}
57 | |
// Checks that options() of tensors allocated on CUDA reflects the actual
// allocation device, dtype, and layout. The `_MultiCUDA` suffix presumably
// marks the test as wanting more than one GPU — confirm with the harness;
// the multi-device portion is additionally guarded at runtime below.
TEST(TensorOptionsTest, ConstructsWellFromCUDATensors_MultiCUDA) {
  // Tensor created on CUDA without an index: options should report index 0.
  auto options = empty(5, device(kCUDA).dtype(kDouble)).options();
  REQUIRE_OPTIONS(kCUDA, 0, kDouble, kStrided);

  options = empty(5, getDeprecatedTypeProperties(Backend::SparseCUDA, kByte))
                .options();
  REQUIRE_OPTIONS(kCUDA, 0, kByte, kSparse);

  // Only exercised when a second GPU is actually available.
  if (torch::cuda::device_count() > 1) {
    Tensor tensor;
    {
      // Allocate while device 1 is current; the tensor's options must still
      // record index 1 after the guard has been destroyed.
      DeviceGuard guard(CUDADevice(1));
      tensor = empty(5, device(kCUDA));
    }
    options = tensor.options();
    REQUIRE_OPTIONS(kCUDA, 1, kFloat, kStrided);

    {
      // Same check for a sparse allocation on device 1.
      DeviceGuard guard(CUDADevice(1));
      tensor = empty(5, device(kCUDA).layout(kSparse));
    }
    options = tensor.options();
    REQUIRE_OPTIONS(kCUDA, 1, kFloat, kSparse);
  }
}
83 | |