1 | #pragma once |
2 | |
3 | #include <c10/core/Device.h> |
4 | #include <c10/core/DispatchKey.h> |
5 | #include <c10/core/ScalarType.h> |
6 | #include <torch/csrc/python_headers.h> |
7 | |
// Forward declaration of at::Tensor (incomplete type only).
// NOTE(review): no declaration in this header currently refers to it.
namespace at { class Tensor; } // namespace at
11 | |
12 | namespace torch { |
13 | namespace tensors { |
14 | |
15 | // Initializes the Python tensor type objects: torch.FloatTensor, |
16 | // torch.DoubleTensor, etc. and binds them in their containing modules. |
17 | void initialize_python_bindings(); |
18 | |
19 | // Same as set_default_tensor_type() but takes a PyObject* |
20 | void py_set_default_tensor_type(PyObject* type_obj); |
21 | |
22 | // Same as py_set_default_tensor_type, but only changes the dtype (ScalarType). |
23 | void py_set_default_dtype(PyObject* dtype_obj); |
24 | |
25 | // Gets the DispatchKey for the default tensor type. |
26 | // |
27 | // TODO: This is nuts! There is no reason to let the default tensor type id |
28 | // change. Probably only store ScalarType, as that's the only flex point |
29 | // we support. |
30 | c10::DispatchKey get_default_dispatch_key(); |
31 | at::Device get_default_device(); |
32 | |
33 | // Gets the ScalarType for the default tensor type. |
34 | at::ScalarType get_default_scalar_type(); |
35 | } // namespace tensors |
36 | } // namespace torch |
37 | |