1 | #pragma once |
2 | |
3 | #include <torch/csrc/lazy/backend/backend_interface.h> |
4 | #include <torch/csrc/lazy/core/shape.h> |
5 | |
6 | #include <ATen/FunctionalTensorWrapper.h> |
7 | |
8 | #include <string> |
9 | #include <vector> |
10 | |
11 | namespace torch { |
12 | namespace lazy { |
13 | |
// Computes the strides of a contiguous (row-major) array with the given
// sizes.
TORCH_API std::vector<int64_t> ComputeArrayStrides(
    c10::ArrayRef<int64_t> sizes);
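// Usage sketch (assumes the row-major contiguous semantics noted above):
//   std::vector<int64_t> strides = ComputeArrayStrides({2, 3, 4});
//   // strides == {12, 4, 1}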
16 | |
// Fetches the tensors behind the given device data handles, converting them
// to the given destination element type.
TORCH_API std::vector<at::Tensor> DataHandlesToTensors(
    c10::ArrayRef<BackendDataPtr> data_handles,
    at::ScalarType dest_element_type);
20 | |
// Uploads an ATen tensor's data to the device and fetches the corresponding
// device data handle.
23 | TORCH_API BackendDataPtr |
24 | TensorToDataHandle(const at::Tensor& tensor, const BackendDevice& device); |
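// Usage sketch (assumes `device` is a valid BackendDevice for the active
// backend):
//   at::Tensor cpu_tensor = at::ones({2, 2});
//   BackendDataPtr handle = TensorToDataHandle(cpu_tensor, device);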
25 | |
// Uploads the given tensors to the corresponding devices in parallel and
// returns the resulting device data handles.
28 | TORCH_API std::vector<BackendDataPtr> CreateTensorsData( |
29 | const std::vector<at::Tensor>& tensors, |
30 | const std::vector<BackendDevice>& devices); |
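// Usage sketch (assumes `device` is a valid BackendDevice; `devices` needs
// one entry per tensor):
//   std::vector<at::Tensor> tensors = {at::ones({2}), at::zeros({3})};
//   std::vector<BackendDevice> devices = {device, device};
//   std::vector<BackendDataPtr> handles = CreateTensorsData(tensors, devices);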
31 | |
// Makes a deep copy of an ATen tensor.
33 | inline at::Tensor CopyTensor(const at::Tensor& ref) { |
34 | return ref.to(ref.options(), /*non_blocking=*/false, /*copy=*/true); |
35 | } |
36 | |
37 | // Same as above, with an additional cast. |
38 | inline at::Tensor CopyTensor( |
39 | const at::Tensor& ref, |
40 | at::ScalarType dest_type, |
41 | bool copy = true) { |
42 | return ref.to(ref.options().dtype(dest_type), /*non_blocking=*/false, copy); |
43 | } |
44 | |
// Returns the optional's value cast to T if it is set, otherwise `defval`.
template <typename T, typename S>
46 | T OptionalOr(const c10::optional<S>& value, T defval) { |
47 | return value ? static_cast<T>(*value) : defval; |
48 | } |
49 | |
// Casts the tensor to the target dtype if it is a wrapped number (a scalar
// lifted to a tensor, e.g. for type promotion); otherwise returns it as is.
51 | inline at::Tensor UnwrapNumber(const at::Tensor& tensor, at::ScalarType dtype) { |
52 | return tensor.unsafeGetTensorImpl()->is_wrapped_number() ? tensor.to(dtype) |
53 | : tensor; |
54 | } |
55 | |
// Wraps an integral value into an at::Scalar holding an int64_t.
template <typename T>
57 | at::Scalar MakeIntScalar(T value) { |
58 | return at::Scalar(static_cast<int64_t>(value)); |
59 | } |
60 | |
// Routing values to device data maximizes the chances of compilation cache
// hits, but it can prevent the compiler from performing optimizations. So
// tensor values which fall within a given special set are routed to constant
// scalars when this API returns true.
65 | TORCH_API bool IsSpecialScalar(const at::Scalar& value); |
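// Usage sketch (the special set is implementation-defined; small values such
// as 0 and 1 are typical members):
//   if (IsSpecialScalar(at::Scalar(1))) {
//     // Inline the value as a constant rather than uploading device data.
//   }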
66 | |
// If the tensor is a functional tensor wrapper, returns the inner tensor it
// wraps; otherwise returns the tensor itself.
// Note: returns a reference instead of a fresh tensor to avoid refcount bumps.
68 | inline const at::Tensor& maybe_unwrap_functional(const at::Tensor& tensor) { |
69 | if (at::functionalization::impl::isFunctionalTensor(tensor)) { |
70 | return at::functionalization::impl::unsafeGetFunctionalWrapper(tensor) |
71 | ->value(); |
72 | } else { |
73 | return tensor; |
74 | } |
75 | } |
76 | |
77 | } // namespace lazy |
78 | } // namespace torch |
79 | |