1 | #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
2 | #include <ATen/Dispatch.h> |
3 | #include <ATen/EmptyTensor.h> |
4 | #include <ATen/ScalarOps.h> |
5 | |
6 | namespace at { |
7 | namespace { |
8 | template <typename scalar_t> |
9 | inline void fill_inplace(Tensor& self, const Scalar& value_scalar) { |
10 | auto value = value_scalar.to<scalar_t>(); |
11 | scalar_t* dptr = static_cast<scalar_t*>(self.data_ptr()); |
12 | *dptr = value; |
13 | } |
14 | } |
15 | |
16 | namespace detail { |
17 | Tensor& scalar_fill(Tensor& self, const Scalar& value) { |
18 | AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( |
19 | kComplexHalf, kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out" , [&]() { |
20 | fill_inplace<scalar_t>(self, value); |
21 | }); |
22 | return self; |
23 | } |
24 | |
25 | Tensor scalar_tensor_static(const Scalar& s, c10::optional<ScalarType> dtype_opt, c10::optional<Device> device_opt) { |
26 | at::tracer::impl::NoTracerDispatchMode tracer_guard; |
27 | at::AutoDispatchBelowAutograd mode; |
28 | Tensor result = at::detail::empty_cpu( |
29 | {}, dtype_opt, c10::nullopt, device_opt, c10::nullopt, c10::nullopt); |
30 | scalar_fill(result, s); |
31 | return result; |
32 | } |
33 | } // namespace detail |
34 | } // namespace at |
35 | |