1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/core/DeviceGuard.h> |
4 | #include <c10/core/impl/FakeGuardImpl.h> |
5 | |
6 | using namespace c10; |
7 | using namespace c10::impl; |
8 | |
// Most guard behavior is already covered by InlineDeviceGuard_test, but there
// is some DeviceGuard-specific functionality we must test here.
11 | |
12 | // -- DeviceGuard ------------------------------------------------------- |
13 | |
14 | TEST(DeviceGuard, ResetDeviceDifferentDeviceType) { |
15 | FakeGuardImpl<DeviceType::CUDA> cuda_impl; |
16 | FakeGuardImpl<DeviceType::HIP> hip_impl; |
17 | FakeGuardImpl<DeviceType::CUDA>::setDeviceIndex(0); |
18 | FakeGuardImpl<DeviceType::HIP>::setDeviceIndex(0); |
19 | DeviceGuard g(Device(DeviceType::CUDA, 1), &cuda_impl); |
20 | g.reset_device(Device(DeviceType::HIP, 2), &hip_impl); |
21 | ASSERT_EQ(FakeGuardImpl<DeviceType::CUDA>::getDeviceIndex(), 0); |
22 | ASSERT_EQ(FakeGuardImpl<DeviceType::HIP>::getDeviceIndex(), 2); |
23 | ASSERT_EQ(g.current_device(), Device(DeviceType::HIP, 2)); |
24 | ASSERT_EQ(g.original_device(), Device(DeviceType::HIP, 0)); |
25 | } |
26 | |
27 | // -- OptionalDeviceGuard ----------------------------------------------- |
28 | |
29 | TEST(OptionalDeviceGuard, ResetDeviceDifferentDeviceType) { |
30 | FakeGuardImpl<DeviceType::CUDA> cuda_impl; |
31 | FakeGuardImpl<DeviceType::HIP> hip_impl; |
32 | FakeGuardImpl<DeviceType::CUDA>::setDeviceIndex(0); |
33 | FakeGuardImpl<DeviceType::HIP>::setDeviceIndex(0); |
34 | OptionalDeviceGuard g; |
35 | g.reset_device(Device(DeviceType::CUDA, 1), &cuda_impl); |
36 | g.reset_device(Device(DeviceType::HIP, 2), &hip_impl); |
37 | ASSERT_EQ(FakeGuardImpl<DeviceType::CUDA>::getDeviceIndex(), 0); |
38 | ASSERT_EQ(FakeGuardImpl<DeviceType::HIP>::getDeviceIndex(), 2); |
39 | ASSERT_EQ(g.current_device(), make_optional(Device(DeviceType::HIP, 2))); |
40 | ASSERT_EQ(g.original_device(), make_optional(Device(DeviceType::HIP, 0))); |
41 | } |
42 | |