1 | #pragma once |
2 | |
3 | #include <ATen/Tensor.h> |
4 | #include <c10/core/Scalar.h> |
5 | |
6 | #ifndef AT_PER_OPERATOR_HEADERS |
7 | #include <ATen/Functions.h> |
8 | #else |
9 | #include <ATen/ops/scalar_tensor.h> |
10 | #endif |
11 | |
12 | namespace at { |
13 | namespace detail { |
// When filling a 1-element CPU tensor with a number, we want to skip
// everything else and manipulate the data pointer directly.
// Ideally this fast path would be implemented in TensorIterator,
// but we also want to skip compute_types, which is not avoidable
// in TensorIterator for now.
TORCH_API Tensor& scalar_fill(Tensor& self, const Scalar& value);
20 | TORCH_API Tensor scalar_tensor_static( |
21 | const Scalar& s, |
22 | c10::optional<ScalarType> dtype_opt, |
23 | c10::optional<Device> device_opt); |
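
// A minimal usage sketch (assumes an ATen translation unit where these
// internal symbols are visible; not a public API guarantee):
//
//   at::Tensor t = at::detail::scalar_tensor_static(
//       /*s=*/2.5, /*dtype_opt=*/at::kDouble, /*device_opt=*/at::kCPU);
//   at::detail::scalar_fill(t, 3.0); // in-place fill, no dispatcher round-trip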
24 | } // namespace detail |
25 | } // namespace at |
26 | |
27 | // This is in the c10 namespace because we use ADL to find the functions in it. |
28 | namespace c10 { |
29 | |
30 | // FIXME: this should be (and was) Scalar::toTensor, but there is currently no |
31 | // way to implement this without going through Derived Types (which are not part |
32 | // of core). |
33 | inline at::Tensor scalar_to_tensor( |
34 | const Scalar& s, |
35 | const Device device = at::kCPU) { |
36 | // This is the fast track we have for CPU scalar tensors. |
37 | if (device == at::kCPU) { |
38 | if (s.isFloatingPoint()) { |
39 | return at::detail::scalar_tensor_static(s, at::kDouble, at::kCPU); |
40 | } else if (s.isComplex()) { |
41 | return at::detail::scalar_tensor_static(s, at::kComplexDouble, at::kCPU); |
42 | } else if (s.isBoolean()) { |
43 | return at::detail::scalar_tensor_static(s, at::kBool, at::kCPU); |
44 | } else { |
45 | AT_ASSERT(s.isIntegral(false)); |
46 | return at::detail::scalar_tensor_static(s, at::kLong, at::kCPU); |
47 | } |
48 | } |
  if (s.isFloatingPoint()) {
    return at::scalar_tensor(s, at::device(device).dtype(at::kDouble));
  } else if (s.isComplex()) {
    return at::scalar_tensor(s, at::device(device).dtype(at::kComplexDouble));
  } else if (s.isBoolean()) {
    return at::scalar_tensor(s, at::device(device).dtype(at::kBool));
  } else {
    AT_ASSERT(s.isIntegral(false));
    return at::scalar_tensor(s, at::device(device).dtype(at::kLong));
  }
59 | } |
60 | |
61 | } // namespace c10 |
62 | |
63 | namespace at { |
64 | namespace native { |
65 | |
66 | inline Tensor wrapped_scalar_tensor( |
67 | const Scalar& scalar, |
68 | const Device device = at::kCPU) { |
69 | auto tensor = scalar_to_tensor(scalar, device); |
70 | tensor.unsafeGetTensorImpl()->set_wrapped_number(true); |
71 | return tensor; |
72 | } |
73 | |
74 | } // namespace native |
75 | } // namespace at |
76 | |