1#include <torch/csrc/lazy/core/tensor_util.h>
2
3#include <c10/util/BFloat16.h>
4#include <c10/util/Half.h>
5#include <c10/util/complex.h>
6#include <c10/util/irange.h>
7#include <torch/csrc/lazy/backend/backend_device.h>
8#include <torch/csrc/lazy/backend/backend_interface.h>
9#include <torch/csrc/lazy/core/config.h>
10#include <torch/csrc/lazy/core/helpers.h>
11
12#include <algorithm>
13#include <cstring>
14#include <functional>
15#include <list>
16#include <numeric>
17#include <thread>
18
19namespace torch {
20namespace lazy {
21
22std::vector<int64_t> ComputeArrayStrides(c10::ArrayRef<int64_t> sizes) {
23 std::vector<int64_t> strides(sizes.size(), 1);
24 for (int64_t i = sizes.size(); i > 1; --i) {
25 strides[i - 2] = strides[i - 1] * sizes[i - 1];
26 }
27 return strides;
28}
29
30std::vector<at::Tensor> DataHandlesToTensors(
31 c10::ArrayRef<BackendDataPtr> data_handles,
32 at::ScalarType dest_element_type) {
33 std::vector<at::Tensor> tensors;
34 for (const auto& handle : data_handles) {
35 tensors.push_back(
36 getBackend()->MakeTensorFromComputationData(handle, dest_element_type));
37 }
38 return tensors;
39}
40
41BackendDataPtr TensorToDataHandle(
42 const at::Tensor& tensor,
43 const BackendDevice& device) {
44 return getBackend()->MakeComputationDataFromTensor(
45 tensor, Shape(tensor.scalar_type(), tensor.sizes()), device);
46}
47
48std::vector<BackendDataPtr> CreateTensorsData(
49 const std::vector<at::Tensor>& tensors,
50 const std::vector<BackendDevice>& devices) {
51 TORCH_CHECK(tensors.size() == devices.size());
52 std::vector<BackendDataPtr> result;
53 result.reserve(tensors.size());
54 for (const auto i : c10::irange(tensors.size())) {
55 result.push_back(TensorToDataHandle(tensors[i], devices[i]));
56 }
57 return result;
58}
59
60bool IsSpecialScalar(const at::Scalar& value) {
61 if (FLAGS_torch_lazy_handle_special_scalars &&
62 (value.isIntegral(false) || value.isFloatingPoint())) {
63 if (FLAGS_torch_lazy_all_numbers_special_scalars) {
64 return true;
65 }
66 double scalar_value = value.toDouble();
67 return scalar_value == 0.0 || std::fabs(scalar_value) == 1.0;
68 }
69 return false;
70}
71
72} // namespace lazy
73} // namespace torch
74