1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/core/IListRef.h> |
4 | #include <ATen/core/Tensor.h> |
5 | #include <c10/core/DeviceGuard.h> |
6 | #include <c10/core/ScalarType.h> // TensorList whyyyyy |
7 | |
8 | namespace at { |
9 | |
10 | // Are you here because you're wondering why DeviceGuard(tensor) no |
11 | // longer works? For code organization reasons, we have temporarily(?) |
12 | // removed this constructor from DeviceGuard. The new way to |
13 | // spell it is: |
14 | // |
15 | // OptionalDeviceGuard guard(device_of(tensor)); |
16 | |
17 | /// Return the Device of a Tensor, if the Tensor is defined. |
18 | inline c10::optional<Device> device_of(const Tensor& t) { |
19 | if (t.defined()) { |
20 | return c10::make_optional(t.device()); |
21 | } else { |
22 | return c10::nullopt; |
23 | } |
24 | } |
25 | |
26 | inline c10::optional<Device> device_of(const c10::optional<Tensor>& t) { |
27 | return t.has_value() ? device_of(t.value()) : nullopt; |
28 | } |
29 | |
30 | /// Return the Device of a TensorList, if the list is non-empty and |
31 | /// the first Tensor is defined. (This function implicitly assumes |
32 | /// that all tensors in the list have the same device.) |
33 | inline c10::optional<Device> device_of(ITensorListRef t) { |
34 | if (!t.empty()) { |
35 | return device_of(t.front()); |
36 | } else { |
37 | return c10::nullopt; |
38 | } |
39 | } |
40 | |
41 | } // namespace at |
42 |