1 | #include <torch/csrc/lazy/core/tensor_util.h> |
---|---|
2 | |
3 | #include <c10/util/BFloat16.h> |
4 | #include <c10/util/Half.h> |
5 | #include <c10/util/complex.h> |
6 | #include <c10/util/irange.h> |
7 | #include <torch/csrc/lazy/backend/backend_device.h> |
8 | #include <torch/csrc/lazy/backend/backend_interface.h> |
9 | #include <torch/csrc/lazy/core/config.h> |
10 | #include <torch/csrc/lazy/core/helpers.h> |
11 | |
12 | #include <algorithm> |
13 | #include <cstring> |
14 | #include <functional> |
15 | #include <list> |
16 | #include <numeric> |
17 | #include <thread> |
18 | |
19 | namespace torch { |
20 | namespace lazy { |
21 | |
22 | std::vector<int64_t> ComputeArrayStrides(c10::ArrayRef<int64_t> sizes) { |
23 | std::vector<int64_t> strides(sizes.size(), 1); |
24 | for (int64_t i = sizes.size(); i > 1; --i) { |
25 | strides[i - 2] = strides[i - 1] * sizes[i - 1]; |
26 | } |
27 | return strides; |
28 | } |
29 | |
30 | std::vector<at::Tensor> DataHandlesToTensors( |
31 | c10::ArrayRef<BackendDataPtr> data_handles, |
32 | at::ScalarType dest_element_type) { |
33 | std::vector<at::Tensor> tensors; |
34 | for (const auto& handle : data_handles) { |
35 | tensors.push_back( |
36 | getBackend()->MakeTensorFromComputationData(handle, dest_element_type)); |
37 | } |
38 | return tensors; |
39 | } |
40 | |
41 | BackendDataPtr TensorToDataHandle( |
42 | const at::Tensor& tensor, |
43 | const BackendDevice& device) { |
44 | return getBackend()->MakeComputationDataFromTensor( |
45 | tensor, Shape(tensor.scalar_type(), tensor.sizes()), device); |
46 | } |
47 | |
48 | std::vector<BackendDataPtr> CreateTensorsData( |
49 | const std::vector<at::Tensor>& tensors, |
50 | const std::vector<BackendDevice>& devices) { |
51 | TORCH_CHECK(tensors.size() == devices.size()); |
52 | std::vector<BackendDataPtr> result; |
53 | result.reserve(tensors.size()); |
54 | for (const auto i : c10::irange(tensors.size())) { |
55 | result.push_back(TensorToDataHandle(tensors[i], devices[i])); |
56 | } |
57 | return result; |
58 | } |
59 | |
60 | bool IsSpecialScalar(const at::Scalar& value) { |
61 | if (FLAGS_torch_lazy_handle_special_scalars && |
62 | (value.isIntegral(false) || value.isFloatingPoint())) { |
63 | if (FLAGS_torch_lazy_all_numbers_special_scalars) { |
64 | return true; |
65 | } |
66 | double scalar_value = value.toDouble(); |
67 | return scalar_value == 0.0 || std::fabs(scalar_value) == 1.0; |
68 | } |
69 | return false; |
70 | } |
71 | |
72 | } // namespace lazy |
73 | } // namespace torch |
74 |