#include <torch/csrc/profiler/data_flow.h>

#include <c10/util/overloaded.h>
#include <c10/util/variant.h>
#include <torch/csrc/profiler/collection.h>

namespace torch {
namespace profiler {
namespace impl {

namespace {
static constexpr TensorImplAddress NoTensorImpl{nullptr};

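// Uniform record for anything that owns (or frees) Tensor storage. Tensor
// metadata from op inputs / Python calls and raw allocation events are all
// flattened into this shape so the ID assignment below can treat them alike.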
struct RawTensorInfo {
  TensorImplAddress impl_;
  StorageImplData storage_;
  c10::Device device_;
  bool is_free_;

  // Used to assign back to the original structs.
  std::reference_wrapper<c10::optional<AllocationID>> allocation_id_ref_;
  std::reference_wrapper<c10::optional<TensorID>> id_ref_;
};

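// Visitor which collects RawTensorInfo from the various event payloads. The
// unconstrained template overload makes it a no-op for inputs that carry no
// Tensor data.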
struct RawTensors {
  std::vector<RawTensorInfo>& get() {
    return tensors_;
  }

  void operator()(TensorMetadata& t) {
    tensors_.emplace_back(RawTensorInfo{
        t.impl(), t.data_, t.device_, false, t.allocation_id_, t.id_});
  }

  void operator()(c10::optional<TensorMetadata>& t) {
    if (t.has_value()) {
      (*this)(*t);
    }
  }

  void operator()(ExtraFields<EventType::Allocation>& a) {
    const StorageImplData ptr{a.ptr_};
    const auto is_free = a.alloc_size_ < 0;
    tensors_.emplace_back(RawTensorInfo{
        NoTensorImpl, ptr, a.device(), is_free, a.allocation_id_, a.id_});
  }

  void operator()(std::vector<TensorMetadata>& t) {
    for (auto& ti : t) {
      (*this)(ti);
    }
  }

  template <typename T>
  void operator()(T&) {}

  std::vector<RawTensorInfo> tensors_;
};
} // namespace

void calculateUniqueTensorIDs(
    std::vector<std::shared_ptr<Result>>& sorted_results) {
  // This task is equivalent to https://leetcode.com/problems/number-of-islands/
  // We first cluster events with a greedy index assignment, and then merge
  // groups that overlap.
  std::vector<RawTensorInfo> tensors;

  // Flatten results to a uniform representation.
  // --------------------------------------------------------------------------
  {
    RawTensors raw_tensors;

    // The python tracer caches values, so it is only safe to use the first
    // occurrence of each module / optimizer.
    ska::flat_hash_set<PyModuleSelf> seen_modules;
    ska::flat_hash_set<PyOptimizerSelf> seen_optimizers;
    for (auto& result : sorted_results) {
      result->visit(c10::overloaded(
          [&](ExtraFields<EventType::TorchOp>& torch_op) {
            for (auto& i : torch_op.inputs_) {
              c10::visit(raw_tensors, i);
            }
          },
          [&](ExtraFields<EventType::PyCall>& py_call) {
            // torch.nn.Module
            if (py_call.module_.has_value() &&
                seen_modules.insert(py_call.module_->self_).second) {
              for (auto& p : py_call.module_->parameters_) {
                raw_tensors(p.metadata_);
                raw_tensors(p.grad_metadata_);
              }
            }

            // torch.optim.Optimizer
            if (py_call.optimizer_.has_value() &&
                seen_optimizers.insert(py_call.optimizer_->self_).second) {
              for (auto& p : py_call.optimizer_->parameters_) {
                raw_tensors(p.metadata_);
                raw_tensors(p.grad_metadata_);
                for (auto& state_i : p.state_) {
                  raw_tensors(state_i.second);
                }
              }
            }
          },
          [&](auto& i) { raw_tensors(i); }));
    }
    tensors = std::move(raw_tensors.tensors_);
  }

  // Assign IDs to solve ABA for Storage.
  // --------------------------------------------------------------------------
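  // The allocator may reuse an address for a new Storage after the old one is
  // freed, so a raw (pointer, device) pair is ambiguous over the course of a
  // trace. Each pair is therefore mapped to a versioned AllocationID, and the
  // live version is retired when the matching free is observed.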
  {
    size_t counter{1};
    using key_t = std::pair<StorageImplData, c10::Device>;
    ska::flat_hash_map<key_t, size_t, HashCombine> versions;
    for (auto& t : tensors) {
      auto inserted = versions.insert({{t.storage_, t.device_}, counter});
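      // `inserted.second` is only true when a new version was minted, so the
      // counter advances exactly once per distinct (storage, device) version.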
      counter += inserted.second;
      t.allocation_id_ref_.get().emplace(AllocationID(inserted.first->second));
      if (t.is_free_) {
        versions.erase(inserted.first);
      }
    }
  }

  // Handle any allocation events which we cannot prove are for Tensor storage.
  // --------------------------------------------------------------------------
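  // Keep only allocation IDs that were observed through at least one
  // TensorImpl; allocations we cannot tie to a Tensor are dropped rather than
  // given a Tensor ID.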
  {
    ska::flat_hash_set<AllocationID> tensor_set;
    for (const auto& t : tensors) {
      if (t.impl_ != NoTensorImpl) {
        tensor_set.insert(*t.allocation_id_ref_.get());
      }
    }
    tensors.erase(
        std::remove_if(
            tensors.begin(),
            tensors.end(),
            [&tensor_set](const auto& i) {
              auto it = tensor_set.find(*i.allocation_id_ref_.get());
              return it == tensor_set.end();
            }),
        tensors.end());
  }

  // Handle the case that the storage of a TensorImpl changed.
  // --------------------------------------------------------------------------
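  // If the same TensorImpl address shows up with more than one AllocationID
  // (i.e. the Tensor's storage was replaced mid-trace), those allocations
  // describe the same logical Tensor and must land in the same group.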
  using storage_id_pair_t = std::pair<AllocationID, AllocationID>;
  ska::flat_hash_set<storage_id_pair_t, HashCombine> same_group_set;
  {
    ska::flat_hash_map<TensorImplAddress, AllocationID> impl_map;
    for (const auto& t : tensors) {
      // Storage allocations / frees don't have an associated TensorImpl, so
      // we don't want all storages to merge through nullptr.
      if (!t.impl_) {
        continue;
      }

      const auto allocation_id = *t.allocation_id_ref_.get();
      const auto it = impl_map.insert({t.impl_, allocation_id}).first;

      // The pair needs to be sorted for the coalesce step to work properly.
      it->second < allocation_id
          ? same_group_set.insert({it->second, allocation_id})
          : same_group_set.insert({allocation_id, it->second});
    }
  }

  // Coalesce groups and assign final IDs.
  // --------------------------------------------------------------------------
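  // Pairs were stored (smaller, larger) and are visited in sorted order, so a
  // pair's first member always has (or receives) its group ID before the
  // second member inherits it; a single greedy pass merges overlapping groups.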
  ska::flat_hash_map<AllocationID, size_t> id_map;
  {
    std::vector<storage_id_pair_t> unique_pairs;
    for (const auto& i : same_group_set) {
      unique_pairs.push_back(i);
    }
    std::sort(unique_pairs.begin(), unique_pairs.end());

    size_t current_id{0};
    for (const auto& i : unique_pairs) {
      auto inserted = id_map.insert({i.first, current_id});
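      // As with the version counter above, only advance when `i.first` was not
      // already assigned to a group.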
      current_id += inserted.second;
      id_map.insert({i.second, inserted.first->second});
    }
  }

  // Write back to Tensor IDs.
  // --------------------------------------------------------------------------
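  // Propagate the final group IDs back into the original event structs via the
  // reference_wrappers captured during flattening.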
  for (const auto& t : tensors) {
    const auto id = id_map.at(*t.allocation_id_ref_.get());
    t.id_ref_.get().emplace(TensorID(id));
  }
}

} // namespace impl
} // namespace profiler
} // namespace torch