1#pragma once
2
3#include <ATen/ATen.h>
4#include <ATen/core/functional.h>
5#include <c10/core/TensorOptions.h>
6#include <torch/csrc/Export.h>
7#include <utility>
8
9namespace torch {
10namespace utils {
11
12/// Generate an ID for a combination of tensor backend + scalar type to be used
13/// when ordering tensors ('like' tensors are grouped by pulling out their
14/// backend + scalar type, so this function combines that into a single number)
15inline size_t type_id(const at::Tensor& tensor) {
16 return static_cast<size_t>(tensor.options().backend()) *
17 static_cast<size_t>(at::ScalarType::NumOptions) +
18 static_cast<size_t>(tensor.scalar_type());
19}
20
21inline at::Tensor flatten_dense_tensors(at::TensorList tensors) {
22 return at::flatten_dense_tensors(tensors);
23}
24
25inline std::vector<at::Tensor> unflatten_dense_tensors(
26 const at::Tensor& flat,
27 at::TensorList tensors) {
28 return at::unflatten_dense_tensors(flat, tensors);
29}
30
31// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
32struct TensorGroup {
33 std::vector<at::Tensor> tensors;
34 size_t size = 0;
35
36 size_t type_id() {
37 AT_ASSERT(!tensors.empty());
38 return ::torch::utils::type_id(tensors[0]);
39 }
40
41 const at::TensorOptions options() {
42 AT_ASSERT(!tensors.empty());
43 return tensors[0].options();
44 }
45};
46
// Helper function that takes a list of tensors and splits them into tensor
// groups by the size limit and outputs these tensor groups. If the input
// tensors are of different tensor types, they will be split into different
// groups as well.
//
// Two options of splitting are provided to the user:
//
// Imagine the size_limit is 256 and the list of input tensors is:
//   tensor_a (fp16 - 128 bytes),
//   tensor_b (fp32 - 256 bytes),
//   tensor_c (fp16 - 128 bytes)
//
// when fine_grained == false:
//   The function will read the list of tensors sequentially and accumulate
//   enough tensors for each data type until the size_limit; therefore
//   it will output: {{tensor_a, tensor_c}, {tensor_b}}
//
// when fine_grained == true:
//   The function will read the list of tensors sequentially and accumulate
//   enough tensors for all data types until the size_limit, and then split
//   the accumulated tensors into different groups by data type; therefore
//   it will output: {{tensor_a}, {tensor_b}, {tensor_c}}
TORCH_API std::vector<TensorGroup> take_tensors(
    at::TensorList tensors,
    size_t size_limit,
    bool fine_grained = false);
73
// Reorders `tensors` in place so they match the order of `order`
// (presumably matching by the type_id grouping above — confirm in the .cpp).
TORCH_API void reorder_tensors_like(
    std::vector<at::Tensor>& tensors,
    at::TensorList order);

// Flattens a list of sparse tensors into a pair of tensors
// (returned as {indices, values} given the unflatten counterpart below).
TORCH_API std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
    at::TensorList tensors);

// Inverse of flatten_sparse_tensors: rebuilds sparse tensors shaped like
// `tensors` from the flat indices/values pair.
TORCH_API std::vector<at::Tensor> unflatten_sparse_tensors(
    const at::Tensor& flat_indices,
    const at::Tensor& flat_values,
    at::TensorList tensors);
85
86} // namespace utils
87} // namespace torch
88