#pragma once

#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/shape.h>

#include <ATen/FunctionalTensorWrapper.h>

#include <string>
#include <vector>

namespace torch {
namespace lazy {

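// Computes dense, row-major (contiguous) strides for the given sizes, e.g.
// sizes {2, 3, 4} yield strides {12, 4, 1}.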
TORCH_API std::vector<int64_t> ComputeArrayStrides(
    c10::ArrayRef<int64_t> sizes);

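// Materializes the device data behind each handle as an ATen tensor with the
// given destination element type.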
TORCH_API std::vector<at::Tensor> DataHandlesToTensors(
    c10::ArrayRef<BackendDataPtr> data_handles,
    at::ScalarType dest_element_type);

// Uploads an ATen tensor's data to the device and fetches the corresponding
// device data handle.
TORCH_API BackendDataPtr
TensorToDataHandle(const at::Tensor& tensor, const BackendDevice& device);
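//
// Usage sketch (illustrative; `device` is a placeholder for an already
// constructed BackendDevice, and a lazy backend must be registered):
//
//   at::Tensor cpu_tensor = at::ones({2, 3});
//   BackendDataPtr handle = TensorToDataHandle(cpu_tensor, device);
//   std::vector<at::Tensor> round_trip =
//       DataHandlesToTensors({handle}, at::ScalarType::Float);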

// Retrieves the device data handles by uploading the tensor data onto the
// corresponding devices in parallel.
TORCH_API std::vector<BackendDataPtr> CreateTensorsData(
    const std::vector<at::Tensor>& tensors,
    const std::vector<BackendDevice>& devices);
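//
// Usage sketch (illustrative; `tensors` and `devices` are placeholders and are
// expected to have matching sizes):
//
//   std::vector<BackendDataPtr> handles = CreateTensorsData(tensors, devices);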

// Makes a deep copy of an ATen tensor.
inline at::Tensor CopyTensor(const at::Tensor& ref) {
  return ref.to(ref.options(), /*non_blocking=*/false, /*copy=*/true);
}

// Same as above, with an additional cast to `dest_type`.
inline at::Tensor CopyTensor(
    const at::Tensor& ref,
    at::ScalarType dest_type,
    bool copy = true) {
  return ref.to(ref.options().dtype(dest_type), /*non_blocking=*/false, copy);
}
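
// Illustrative example of the two CopyTensor overloads above (variable names
// are placeholders):
//
//   at::Tensor src = at::ones({2, 2});                        // float32
//   at::Tensor dup = CopyTensor(src);                         // deep copy
//   at::Tensor dbl = CopyTensor(src, at::ScalarType::Double); // copy + cast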

template <typename T, typename S>
T OptionalOr(const c10::optional<S>& value, T defval) {
  return value ? static_cast<T>(*value) : defval;
}
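
// Example use of OptionalOr (illustrative):
//
//   c10::optional<int64_t> opt_dim;                 // empty optional
//   int64_t dim = OptionalOr<int64_t>(opt_dim, 0);  // falls back to 0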

// If the tensor is a wrapped number, casts it to the target dtype; otherwise
// returns it unchanged.
inline at::Tensor UnwrapNumber(const at::Tensor& tensor, at::ScalarType dtype) {
  return tensor.unsafeGetTensorImpl()->is_wrapped_number() ? tensor.to(dtype)
                                                           : tensor;
}
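
// Background (informational): a "wrapped number" is a 0-dim tensor that was
// auto-wrapped from a plain scalar (e.g. the `1` in `x + 1`); UnwrapNumber
// lets callers coerce such a value to the dtype expected by the surrounding
// operation.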

template <typename T>
at::Scalar MakeIntScalar(T value) {
  return at::Scalar(static_cast<int64_t>(value));
}
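
// Example (illustrative): MakeIntScalar(size_t{3}) yields an at::Scalar
// holding the 64-bit integer value 3, regardless of the width of `T`.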

// Routing values to device data maximizes the chances of compilation cache
// hits, but it can prevent the compiler from performing optimizations. So
// tensor values which fall within a given special set are instead routed to
// constant scalars if this API returns true.
TORCH_API bool IsSpecialScalar(const at::Scalar& value);
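//
// Usage sketch (illustrative; the exact set of special values is determined by
// the backend/configuration):
//
//   if (IsSpecialScalar(value)) {
//     // Embed `value` in the IR as a constant scalar instead of uploading it
//     // as device data.
//   }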

// Returns the tensor wrapped inside a FunctionalTensorWrapper if `tensor` is a
// functional tensor; otherwise returns `tensor` itself.
// Note: returns a reference instead of a fresh tensor to avoid refcount bumps.
inline const at::Tensor& maybe_unwrap_functional(const at::Tensor& tensor) {
  if (at::functionalization::impl::isFunctionalTensor(tensor)) {
    return at::functionalization::impl::unsafeGetFunctionalWrapper(tensor)
        ->value();
  } else {
    return tensor;
  }
}

} // namespace lazy
} // namespace torch