#include <torch/csrc/profiler/data_flow.h>

#include <c10/util/overloaded.h>
#include <c10/util/variant.h>
#include <torch/csrc/profiler/collection.h>

namespace torch {
namespace profiler {
namespace impl {

namespace {
static constexpr TensorImplAddress NoTensorImpl{nullptr};

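// Uniform, flattened record of one tensor (or raw allocation) observation.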
struct RawTensorInfo {
  TensorImplAddress impl_;
  StorageImplData storage_;
  c10::Device device_;
  bool is_free_;

  // Used to assign back to the original structs.
  std::reference_wrapper<c10::optional<AllocationID>> allocation_id_ref_;
  std::reference_wrapper<c10::optional<TensorID>> id_ref_;
};

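// Visitor that flattens the profiler's various tensor representations into
// RawTensorInfo records.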
struct RawTensors {
  std::vector<RawTensorInfo>& get() {
    return tensors_;
  }

  void operator()(TensorMetadata& t) {
    tensors_.emplace_back(RawTensorInfo{
        t.impl(), t.data_, t.device_, false, t.allocation_id_, t.id_});
  }

  void operator()(c10::optional<TensorMetadata>& t) {
    if (t.has_value()) {
      (*this)(*t);
    }
  }

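  // Allocation events have no associated TensorImpl; a negative alloc_size_
  // marks a free.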
  void operator()(ExtraFields<EventType::Allocation>& a) {
    const StorageImplData ptr{a.ptr_};
    const auto is_free = a.alloc_size_ < 0;
    tensors_.emplace_back(RawTensorInfo{
        NoTensorImpl, ptr, a.device(), is_free, a.allocation_id_, a.id_});
  }

  void operator()(std::vector<TensorMetadata>& t) {
    for (auto& ti : t) {
      (*this)(ti);
    }
  }

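  // Deliberate no-op: input types not handled above contribute no tensors.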
  template <typename T>
  void operator()(T&) {}

  std::vector<RawTensorInfo> tensors_;
};
} // namespace

void calculateUniqueTensorIDs(
    std::vector<std::shared_ptr<Result>>& sorted_results) {
  // This task is equivalent to https://leetcode.com/problems/number-of-islands/
  // We first cluster events with a greedy index assignment, and then merge
  // groups that overlap.
  std::vector<RawTensorInfo> tensors;

  // Flatten results to a uniform representation.
  // --------------------------------------------------------------------------
  {
    RawTensors raw_tensors;

    // The python tracer caches values, so it's only safe to use the first
    // occurrence of each module / optimizer.
    ska::flat_hash_set<PyModuleSelf> seen_modules;
    ska::flat_hash_set<PyOptimizerSelf> seen_optimizers;
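    // Events that don't match the TorchOp or PyCall lambdas are handled by the
    // trailing generic lambda, which forwards them to RawTensors; that records
    // Allocation events and ignores everything else.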
    for (auto& result : sorted_results) {
      result->visit(c10::overloaded(
          [&](ExtraFields<EventType::TorchOp>& torch_op) {
            for (auto& i : torch_op.inputs_) {
              c10::visit(raw_tensors, i);
            }
          },
          [&](ExtraFields<EventType::PyCall>& py_call) {
            // torch.nn.Module
            if (py_call.module_.has_value() &&
                seen_modules.insert(py_call.module_->self_).second) {
              for (auto& p : py_call.module_->parameters_) {
                raw_tensors(p.metadata_);
                raw_tensors(p.grad_metadata_);
              }
            }

            // torch.optim.Optimizer
            if (py_call.optimizer_.has_value() &&
                seen_optimizers.insert(py_call.optimizer_->self_).second) {
              for (auto& p : py_call.optimizer_->parameters_) {
                raw_tensors(p.metadata_);
                raw_tensors(p.grad_metadata_);
                for (auto& state_i : p.state_) {
                  raw_tensors(state_i.second);
                }
              }
            }
          },
          [&](auto& i) { raw_tensors(i); }));
    }
    tensors = std::move(raw_tensors.tensors_);
  }

  // Assign IDs to solve ABA for Storage.
  // --------------------------------------------------------------------------
  {
    size_t counter{1};
    using key_t = std::pair<StorageImplData, c10::Device>;
    ska::flat_hash_map<key_t, size_t, HashCombine> versions;
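    // insert() is a no-op while a (storage, device) key is still live, and
    // inserted.second reports whether a new entry was created, so counter only
    // advances when a new version begins. Erasing the key on free means a
    // later reuse of the same address gets a fresh AllocationID.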
    for (auto& t : tensors) {
      auto inserted = versions.insert({{t.storage_, t.device_}, counter});
      counter += inserted.second;
      t.allocation_id_ref_.get().emplace(AllocationID(inserted.first->second));
      if (t.is_free_) {
        versions.erase(inserted.first);
      }
    }
  }

  // Handle any allocation events which we cannot prove are for Tensor storage.
  // --------------------------------------------------------------------------
  {
    ska::flat_hash_set<AllocationID> tensor_set;
    for (const auto& t : tensors) {
      if (t.impl_ != NoTensorImpl) {
        tensor_set.insert(*t.allocation_id_ref_.get());
      }
    }
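    // Discard any event whose AllocationID never co-occurs with a TensorImpl;
    // such allocations cannot be tied to Tensor storage.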
    tensors.erase(
        std::remove_if(
            tensors.begin(),
            tensors.end(),
            [&tensor_set](const auto& i) {
              auto it = tensor_set.find(*i.allocation_id_ref_.get());
              return it == tensor_set.end();
            }),
        tensors.end());
  }

  // Handle the case that the storage of a TensorImpl changed.
  // --------------------------------------------------------------------------
  using storage_id_pair_t = std::pair<AllocationID, AllocationID>;
  ska::flat_hash_set<storage_id_pair_t, HashCombine> same_group_set;
  {
    ska::flat_hash_map<TensorImplAddress, AllocationID> impl_map;
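    // impl_map remembers the first AllocationID seen for each TensorImpl; any
    // later AllocationID observed with the same impl is paired with it so the
    // two storages end up in the same group.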
    for (const auto& t : tensors) {
      // Storage allocations / frees don't have an associated TensorImpl, so
      // we don't want all storages to merge through nullptr.
      if (!t.impl_) {
        continue;
      }

      const auto allocation_id = *t.allocation_id_ref_.get();
      const auto it = impl_map.insert({t.impl_, allocation_id}).first;

      // The pair needs to be sorted for the coalesce step to work properly.
      it->second < allocation_id
          ? same_group_set.insert({it->second, allocation_id})
          : same_group_set.insert({allocation_id, it->second});
    }
  }

  // Coalesce groups and assign final IDs.
  // --------------------------------------------------------------------------
  ska::flat_hash_map<AllocationID, size_t> id_map;
  {
    std::vector<storage_id_pair_t> unique_pairs;
    for (const auto& i : same_group_set) {
      unique_pairs.push_back(i);
    }
    std::sort(unique_pairs.begin(), unique_pairs.end());

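    // Greedy merge: id_map.insert only succeeds for a key without an ID, so
    // current_id advances only when i.first opens a new group, and i.second
    // adopts i.first's ID unless it was already assigned one earlier.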
    size_t current_id{0};
    for (const auto& i : unique_pairs) {
      auto inserted = id_map.insert({i.first, current_id});
      current_id += inserted.second;
      id_map.insert({i.second, inserted.first->second});
    }
  }

  // Write back to Tensor IDs.
  // --------------------------------------------------------------------------
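  // The reference_wrapper members point back into the original event structs,
  // so this writes the final TensorID onto each recorded event.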
  for (const auto& t : tensors) {
    const auto id = id_map.at(*t.allocation_id_ref_.get());
    t.id_ref_.get().emplace(TensorID(id));
  }
}

} // namespace impl
} // namespace profiler
} // namespace torch