#pragma once

#include <ATen/ATen.h>
#include <ATen/core/functional.h>
#include <c10/core/TensorOptions.h>
#include <torch/csrc/Export.h>
#include <utility>

namespace torch {
namespace utils {

12 | /// Generate an ID for a combination of tensor backend + scalar type to be used |
13 | /// when ordering tensors ('like' tensors are grouped by pulling out their |
14 | /// backend + scalar type, so this function combines that into a single number) |
15 | inline size_t type_id(const at::Tensor& tensor) { |
16 | return static_cast<size_t>(tensor.options().backend()) * |
17 | static_cast<size_t>(at::ScalarType::NumOptions) + |
18 | static_cast<size_t>(tensor.scalar_type()); |
19 | } |
20 | |
21 | inline at::Tensor flatten_dense_tensors(at::TensorList tensors) { |
22 | return at::flatten_dense_tensors(tensors); |
23 | } |
24 | |
25 | inline std::vector<at::Tensor> unflatten_dense_tensors( |
26 | const at::Tensor& flat, |
27 | at::TensorList tensors) { |
28 | return at::unflatten_dense_tensors(flat, tensors); |
29 | } |
30 | |
31 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
32 | struct TensorGroup { |
33 | std::vector<at::Tensor> tensors; |
34 | size_t size = 0; |
35 | |
36 | size_t type_id() { |
37 | AT_ASSERT(!tensors.empty()); |
38 | return ::torch::utils::type_id(tensors[0]); |
39 | } |
40 | |
41 | const at::TensorOptions options() { |
42 | AT_ASSERT(!tensors.empty()); |
43 | return tensors[0].options(); |
44 | } |
45 | }; |
46 | |
// Helper function that takes a list of tensors and splits them into tensor
// groups by the size limit and outputs these tensor groups. If the input
// tensors are of different tensor types, they will be split into different
// groups as well.
//
// Two options of splitting are provided to the user:
//
// Imagine the size_limit is 256 and the list of input tensors are:
// tensor_a(fp16 - 128 bytes),
// tensor_b(fp32 - 256 bytes),
// tensor_c(fp16 - 128 bytes),
//
// when fine_grained == false:
// The function will read the list of tensors sequentially and accumulate
// enough tensors for each data type until the size_limit, therefore:
// it will output: {{tensor_a, tensor_c}, {tensor_b}}
//
// when fine_grained == true:
// The function will read the list of tensors sequentially and accumulate
// enough tensors for all data types until the size_limit, and then split
// the accumulated tensors into different groups by data types, therefore:
// it will output: {{tensor_a}, {tensor_b}, {tensor_c}}
//
// (Declaration only — the implementation lives in the corresponding .cpp.)
TORCH_API std::vector<TensorGroup> take_tensors(
    at::TensorList tensors,
    size_t size_limit,
    bool fine_grained = false);

// Reorders `tensors` in place so they line up with the tensors in `order`.
// NOTE(review): declaration only — matching presumably pairs tensors by
// backend + scalar type (see type_id above); confirm the exact semantics
// against the .cpp implementation.
TORCH_API void reorder_tensors_like(
    std::vector<at::Tensor>& tensors,
    at::TensorList order);

// Flattens a list of sparse tensors into a pair of tensors.
// NOTE(review): declaration only — presumably the pair is (flattened
// indices, flattened values), mirroring unflatten_sparse_tensors' parameter
// names below; confirm in the .cpp implementation.
TORCH_API std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
    at::TensorList tensors);

// Inverse of flatten_sparse_tensors: splits `flat_indices` and `flat_values`
// back into one sparse tensor per entry of `tensors`.
// NOTE(review): declaration only — `tensors` presumably supplies the
// per-tensor shapes/sizes for the split; confirm in the .cpp implementation.
TORCH_API std::vector<at::Tensor> unflatten_sparse_tensors(
    const at::Tensor& flat_indices,
    const at::Tensor& flat_values,
    at::TensorList tensors);

} // namespace utils
} // namespace torch