#pragma once

#include <ATen/Tensor.h>
#include <c10/core/Scalar.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at {
namespace detail {
// When filling a number into a 1-element CPU tensor, we want to skip
// everything else and manipulate the data pointer directly.
// Ideally this fast path would be implemented in TensorIterator, but
// we also want to skip compute_types, which TensorIterator cannot
// currently avoid.
Tensor& scalar_fill(Tensor& self, const Scalar& value);
TORCH_API Tensor scalar_tensor_static(
    const Scalar& s,
    c10::optional<ScalarType> dtype_opt,
    c10::optional<Device> device_opt);
} // namespace detail
} // namespace at
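
// A minimal usage sketch (illustrative only, not part of this header):
// scalar_tensor_static builds a 0-dim CPU tensor on the fast path
// described above, and scalar_fill overwrites it in place.
//
//   at::Tensor t = at::detail::scalar_tensor_static(
//       /*s=*/1.5, /*dtype_opt=*/at::kDouble, /*device_opt=*/at::kCPU);
//   at::detail::scalar_fill(t, /*value=*/2.5); // writes via the data ptr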

// This is in the c10 namespace because we use ADL to find the functions in it.
namespace c10 {

// FIXME: this should be (and was) Scalar::toTensor, but there is currently no
// way to implement this without going through Derived Types (which are not
// part of core).
inline at::Tensor scalar_to_tensor(
    const Scalar& s,
    const Device device = at::kCPU) {
  // This is the fast track we have for CPU scalar tensors.
  if (device == at::kCPU) {
    if (s.isFloatingPoint()) {
      return at::detail::scalar_tensor_static(s, at::kDouble, at::kCPU);
    } else if (s.isComplex()) {
      return at::detail::scalar_tensor_static(s, at::kComplexDouble, at::kCPU);
    } else if (s.isBoolean()) {
      return at::detail::scalar_tensor_static(s, at::kBool, at::kCPU);
    } else {
      AT_ASSERT(s.isIntegral(false));
      return at::detail::scalar_tensor_static(s, at::kLong, at::kCPU);
    }
  }
  if (s.isFloatingPoint()) {
    return at::scalar_tensor(s, at::device(device).dtype(at::kDouble));
  } else if (s.isBoolean()) {
    return at::scalar_tensor(s, at::device(device).dtype(at::kBool));
  } else if (s.isComplex()) {
    return at::scalar_tensor(s, at::device(device).dtype(at::kComplexDouble));
  } else {
    AT_ASSERT(s.isIntegral(false));
    return at::scalar_tensor(s, at::device(device).dtype(at::kLong));
  }
}
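
// Example (illustrative): because scalar_to_tensor lives in c10, it is
// found via argument-dependent lookup on c10::Scalar, so unqualified
// calls work from other namespaces:
//
//   c10::Scalar s(3.14);
//   at::Tensor t = scalar_to_tensor(s); // 0-dim kDouble tensor on CPU
//   at::Tensor b = scalar_to_tensor(c10::Scalar(true)); // 0-dim kBool on CPU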

} // namespace c10

namespace at {
namespace native {

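// Wraps a Scalar into a 0-dim tensor and marks it as a wrapped number.
// (Background, not stated in this header: wrapped numbers promote like
// C++/Python literals, so their dtype does not dominate the dtypes of
// ordinary tensor operands during type promotion.)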
inline Tensor wrapped_scalar_tensor(
    const Scalar& scalar,
    const Device device = at::kCPU) {
  auto tensor = scalar_to_tensor(scalar, device);
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  return tensor;
}
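
// Usage sketch (illustrative, assuming the wrapped-number promotion
// behavior noted above): adding a wrapped double to a float tensor
// keeps the result float, just as adding the literal would:
//
//   at::Tensor t = at::ones({2}, at::kFloat);
//   at::Tensor w = at::native::wrapped_scalar_tensor(2.5);
//   at::Tensor r = t + w; // r.dtype() == at::kFloat, as with `t + 2.5`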

} // namespace native
} // namespace at