1 | #pragma once |
2 | #include <ATen/core/Tensor.h> |
3 | #include <c10/core/ScalarType.h> |
4 | |
5 | namespace at { |
6 | |
7 | // These functions are defined in ATen/Utils.cpp. |
8 | #define TENSOR(T, S) \ |
9 | TORCH_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \ |
10 | inline Tensor tensor( \ |
11 | std::initializer_list<T> values, const TensorOptions& options) { \ |
12 | return at::tensor(ArrayRef<T>(values), options); \ |
13 | } \ |
14 | inline Tensor tensor(T value, const TensorOptions& options) { \ |
15 | return at::tensor(ArrayRef<T>(value), options); \ |
16 | } \ |
17 | inline Tensor tensor(ArrayRef<T> values) { \ |
18 | return at::tensor(std::move(values), at::dtype(k##S)); \ |
19 | } \ |
20 | inline Tensor tensor(std::initializer_list<T> values) { \ |
21 | return at::tensor(ArrayRef<T>(values)); \ |
22 | } \ |
23 | inline Tensor tensor(T value) { \ |
24 | return at::tensor(ArrayRef<T>(value)); \ |
25 | } |
26 | AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) |
27 | AT_FORALL_COMPLEX_TYPES(TENSOR) |
28 | #undef TENSOR |
29 | |
30 | } // namespace at |
31 | |