1 | #pragma once |
2 | |
#include <memory>
#include <vector>

#include <ATen/core/TensorBody.h>
#include <c10/core/TensorImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/strong_type.h>
#include <c10/util/variant.h>
10 | |
11 | namespace torch { |
12 | namespace profiler { |
13 | namespace impl { |
14 | |
// Identity is a complex concept in PyTorch. A Tensor might not have an
// associated storage, multiple Tensors might share the same underlying
17 | // storage, the storage of a Tensor might change over time, etc. |
18 | // |
19 | // For the purpose of profiling we're mostly interested in data flow |
20 | // analysis. As a result, we can take an expansive view of identity: |
21 | // Tensors share an ID if they share a TensorImpl or storage data. |
22 | // |
// This identity equality is transitive; if Tensors T0 and T1 share a storage
// S0 and T1 later points to a different storage S1 then all Tensors which
// point to either S0 or S1 are considered to have the same identity. (Since
// the profiler cannot reason beyond that.)
27 | // |
28 | // The profiler will handle lifetime analysis to ensure that identities do |
29 | // not run afoul of the ABA problem. This does, however, mean that identities |
30 | // can only be assigned when memory profiling is enabled. |
31 | using TensorID = strong::type<size_t, struct TensorID_, strong::regular>; |
32 | |
33 | // Uniquely identifies an allocation. (Generally a StorageImpl's data ptr.) |
34 | using AllocationID = strong::type< |
35 | size_t, |
36 | struct StorageID_, |
37 | strong::ordered, |
38 | strong::regular, |
39 | strong::hashable>; |
40 | |
// We use a Tensor's TensorImpl address and StorageImpl data start to build the
42 | // data flow graph. We do not hold an owning reference so we wrap them in strong |
43 | // types to prevent direct access. |
44 | using TensorImplAddress = strong::type< |
45 | const c10::TensorImpl*, |
46 | struct TensorImplAddress_, |
47 | strong::regular, |
48 | strong::hashable, |
49 | strong::boolean>; |
50 | |
51 | using StorageImplData = strong::type< |
52 | void*, |
53 | struct StorageImplData_, |
54 | strong::regular, |
55 | strong::hashable, |
56 | strong::boolean>; |
57 | |
58 | // ============================================================================ |
59 | // == weak_intrusive_ptr and the ABA problem for TensorImpl* ================== |
60 | // ============================================================================ |
61 | // Tracking `TensorImpl`s is an important part of identity tracking, because |
62 | // a Tensor might change storage; however when it does we want to retain the |
63 | // fact that the old and new storage belong to the same logical Tensor. We |
64 | // cannot take an owning reference to the Tensor because that would change |
65 | // program semantics by extending the lifetime of the Tensor. However if we |
66 | // store a raw TensorImpl* pointer the TensorImpl might be deleted and a new |
67 | // TensorImpl might be created that reuses the address. (ABA problem) |
68 | // |
69 | // Fortunately, there is a feature of `c10::intrusive_ptr` that we can use to |
70 | // prevent address reuse for the duration of profiling: the weak intrusive ptr. |
71 | // When a Tensor's refcount reaches zero but there are outstanding weak |
72 | // references (`weakcount_ > 0`) it will free the underlying managed resources |
73 | // by calling `target_->release_resources()`, but it will not call `delete`. |
74 | // (Instead, `delete` is called when the last weak reference is destroyed.) |
75 | // This means that we can safely use address identity to track `TensorImpls`. |
76 | class WeakTensor { |
77 | public: |
78 | explicit WeakTensor(const at::Tensor& t) : weak_self_(t.getIntrusivePtr()) {} |
79 | |
80 | auto get() const { |
81 | return TensorImplAddress{weak_self_._unsafe_get_target()}; |
82 | } |
83 | |
84 | private: |
85 | c10::weak_intrusive_ptr<c10::TensorImpl> weak_self_; |
86 | }; |
87 | |
// Forward declaration. This header only names `std::shared_ptr<Result>`, so
// the complete type is not needed here.
struct Result;

// Declaration only; no definition is visible in this header.
//
// `sorted_results` is a non-const reference to a vector of shared `Result`
// pointers, so the callee is able to modify the vector.
// NOTE(review): judging by the names, this presumably assigns the `TensorID`s
// declared in this header and expects the results to already be sorted --
// confirm against the definition.
void calculateUniqueTensorIDs(
    std::vector<std::shared_ptr<Result>>& sorted_results);
92 | |
93 | } // namespace impl |
94 | } // namespace profiler |
95 | } // namespace torch |
96 | |