1#pragma once
2
3#include <c10/core/Device.h>
4#include <c10/core/DispatchKey.h>
5#include <c10/core/ScalarType.h>
6#include <torch/csrc/python_headers.h>
7
// Forward declaration of at::Tensor. No declaration in this header needs
// the complete type, so the full Tensor header is deliberately not included.
namespace at {
class Tensor;
} // namespace at
11
12namespace torch {
13namespace tensors {
14
15// Initializes the Python tensor type objects: torch.FloatTensor,
16// torch.DoubleTensor, etc. and binds them in their containing modules.
17void initialize_python_bindings();
18
19// Same as set_default_tensor_type() but takes a PyObject*
20void py_set_default_tensor_type(PyObject* type_obj);
21
22// Same as py_set_default_tensor_type, but only changes the dtype (ScalarType).
23void py_set_default_dtype(PyObject* dtype_obj);
24
25// Gets the DispatchKey for the default tensor type.
26//
27// TODO: This is nuts! There is no reason to let the default tensor type id
28// change. Probably only store ScalarType, as that's the only flex point
29// we support.
30c10::DispatchKey get_default_dispatch_key();
31at::Device get_default_device();
32
33// Gets the ScalarType for the default tensor type.
34at::ScalarType get_default_scalar_type();
35} // namespace tensors
36} // namespace torch
37