1#include <gtest/gtest.h>
2
3#include <c10/core/DeviceGuard.h>
4#include <c10/core/impl/FakeGuardImpl.h>
5
6using namespace c10;
7using namespace c10::impl;
8
// Most of the behavior exercised here is already covered by
// InlineDeviceGuard_test; these tests cover the DeviceGuard-specific
// functionality that is not.
11
12// -- DeviceGuard -------------------------------------------------------
13
14TEST(DeviceGuard, ResetDeviceDifferentDeviceType) {
15 FakeGuardImpl<DeviceType::CUDA> cuda_impl;
16 FakeGuardImpl<DeviceType::HIP> hip_impl;
17 FakeGuardImpl<DeviceType::CUDA>::setDeviceIndex(0);
18 FakeGuardImpl<DeviceType::HIP>::setDeviceIndex(0);
19 DeviceGuard g(Device(DeviceType::CUDA, 1), &cuda_impl);
20 g.reset_device(Device(DeviceType::HIP, 2), &hip_impl);
21 ASSERT_EQ(FakeGuardImpl<DeviceType::CUDA>::getDeviceIndex(), 0);
22 ASSERT_EQ(FakeGuardImpl<DeviceType::HIP>::getDeviceIndex(), 2);
23 ASSERT_EQ(g.current_device(), Device(DeviceType::HIP, 2));
24 ASSERT_EQ(g.original_device(), Device(DeviceType::HIP, 0));
25}
26
27// -- OptionalDeviceGuard -----------------------------------------------
28
29TEST(OptionalDeviceGuard, ResetDeviceDifferentDeviceType) {
30 FakeGuardImpl<DeviceType::CUDA> cuda_impl;
31 FakeGuardImpl<DeviceType::HIP> hip_impl;
32 FakeGuardImpl<DeviceType::CUDA>::setDeviceIndex(0);
33 FakeGuardImpl<DeviceType::HIP>::setDeviceIndex(0);
34 OptionalDeviceGuard g;
35 g.reset_device(Device(DeviceType::CUDA, 1), &cuda_impl);
36 g.reset_device(Device(DeviceType::HIP, 2), &hip_impl);
37 ASSERT_EQ(FakeGuardImpl<DeviceType::CUDA>::getDeviceIndex(), 0);
38 ASSERT_EQ(FakeGuardImpl<DeviceType::HIP>::getDeviceIndex(), 2);
39 ASSERT_EQ(g.current_device(), make_optional(Device(DeviceType::HIP, 2)));
40 ASSERT_EQ(g.original_device(), make_optional(Device(DeviceType::HIP, 0)));
41}
42