1#pragma once
2#include <ATen/core/Tensor.h>
3#include <c10/core/ScalarType.h>
4
5namespace at {
6
7// These functions are defined in ATen/Utils.cpp.
8#define TENSOR(T, S) \
9 TORCH_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \
10 inline Tensor tensor( \
11 std::initializer_list<T> values, const TensorOptions& options) { \
12 return at::tensor(ArrayRef<T>(values), options); \
13 } \
14 inline Tensor tensor(T value, const TensorOptions& options) { \
15 return at::tensor(ArrayRef<T>(value), options); \
16 } \
17 inline Tensor tensor(ArrayRef<T> values) { \
18 return at::tensor(std::move(values), at::dtype(k##S)); \
19 } \
20 inline Tensor tensor(std::initializer_list<T> values) { \
21 return at::tensor(ArrayRef<T>(values)); \
22 } \
23 inline Tensor tensor(T value) { \
24 return at::tensor(ArrayRef<T>(value)); \
25 }
26AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
27AT_FORALL_COMPLEX_TYPES(TENSOR)
28#undef TENSOR
29
30} // namespace at
31