1#pragma once
2
3#include <ATen/core/IListRef.h>
4#include <ATen/core/Tensor.h>
5#include <c10/core/DeviceGuard.h>
6#include <c10/core/ScalarType.h> // TensorList whyyyyy
7
8namespace at {
9
10// Are you here because you're wondering why DeviceGuard(tensor) no
11// longer works? For code organization reasons, we have temporarily(?)
12// removed this constructor from DeviceGuard. The new way to
13// spell it is:
14//
15// OptionalDeviceGuard guard(device_of(tensor));
16
17/// Return the Device of a Tensor, if the Tensor is defined.
18inline c10::optional<Device> device_of(const Tensor& t) {
19 if (t.defined()) {
20 return c10::make_optional(t.device());
21 } else {
22 return c10::nullopt;
23 }
24}
25
26inline c10::optional<Device> device_of(const c10::optional<Tensor>& t) {
27 return t.has_value() ? device_of(t.value()) : nullopt;
28}
29
30/// Return the Device of a TensorList, if the list is non-empty and
31/// the first Tensor is defined. (This function implicitly assumes
32/// that all tensors in the list have the same device.)
33inline c10::optional<Device> device_of(ITensorListRef t) {
34 if (!t.empty()) {
35 return device_of(t.front());
36 } else {
37 return c10::nullopt;
38 }
39}
40
41} // namespace at
42