#pragma once

#include <memory>
#include <vector>

#include <ATen/core/TensorBody.h>
#include <c10/core/TensorImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/strong_type.h>
#include <c10/util/variant.h>

namespace torch {
namespace profiler {
namespace impl {

// Identity is a complex concept in PyTorch. A Tensor might not have an
// associated storage, multiple Tensors might share the same underlying
// storage, the storage of a Tensor might change over time, etc.
//
// For the purpose of profiling we're mostly interested in data flow
// analysis. As a result, we can take an expansive view of identity:
// Tensors share an ID if they share a TensorImpl or storage data.
//
// This identity equality is transitive: if Tensors T0 and T1 share a storage
// S0, and T1 later points to a different storage S1, then all Tensors which
// point to either S0 or S1 are considered to have the same identity. (The
// profiler cannot reason beyond that.)
//
// The profiler will handle lifetime analysis to ensure that identities do
// not run afoul of the ABA problem. This does, however, mean that identities
// can only be assigned when memory profiling is enabled.
using TensorID = strong::type<size_t, struct TensorID_, strong::regular>;
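
// For illustration only, a minimal sketch of the identity rule above (not
// part of this header; `t0`/`t1` are hypothetical Tensors built with
// ordinary ATen factory and aliasing ops):
//
//   at::Tensor t0 = at::ones({2, 3});
//   at::Tensor t1 = t0.view({6});  // t0 and t1 alias the same StorageImpl
//   // With memory profiling enabled, records for t0 and t1 resolve to a
//   // single TensorID. If t1 were later rebound to a different storage, the
//   // old and new storages would be merged into the same identity, as
//   // described above.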

// Uniquely identifies an allocation. (Generally a StorageImpl's data ptr.)
using AllocationID = strong::type<
    size_t,
    struct StorageID_,
    strong::ordered,
    strong::regular,
    strong::hashable>;

// We use a Tensor's TensorImpl address and StorageImpl data start to build
// the data flow graph. We do not hold an owning reference, so we wrap them in
// strong types to prevent direct access.
using TensorImplAddress = strong::type<
    const c10::TensorImpl*,
    struct TensorImplAddress_,
    strong::regular,
    strong::hashable,
    strong::boolean>;

using StorageImplData = strong::type<
    void*,
    struct StorageImplData_,
    strong::regular,
    strong::hashable,
    strong::boolean>;
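
// A minimal sketch of how these wrappers might be populated from a live
// Tensor (illustrative only; the profiler's actual collection path lives
// elsewhere, and the exact storage accessors vary slightly across versions):
//
//   const at::Tensor& t = ...;
//   TensorImplAddress impl_address{t.unsafeGetTensorImpl()};
//   StorageImplData storage_data{
//       t.has_storage() ? t.storage().data() : nullptr};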

// ============================================================================
// == weak_intrusive_ptr and the ABA problem for TensorImpl* ==================
// ============================================================================
// Tracking `TensorImpl`s is an important part of identity tracking, because
// a Tensor might change storage; however, when it does, we want to retain the
// fact that the old and new storages belong to the same logical Tensor. We
// cannot take an owning reference to the Tensor because that would change
// program semantics by extending the lifetime of the Tensor. However, if we
// store a raw TensorImpl* pointer the TensorImpl might be deleted and a new
// TensorImpl might be created that reuses the address. (The ABA problem.)
//
// Fortunately, there is a feature of `c10::intrusive_ptr` that we can use to
// prevent address reuse for the duration of profiling: the weak intrusive ptr.
// When a Tensor's refcount reaches zero but there are outstanding weak
// references (`weakcount_ > 0`) it will free the underlying managed resources
// by calling `target_->release_resources()`, but it will not call `delete`.
// (Instead, `delete` is called when the last weak reference is destroyed.)
// This means that we can safely use address identity to track `TensorImpl`s.
class WeakTensor {
 public:
  explicit WeakTensor(const at::Tensor& t) : weak_self_(t.getIntrusivePtr()) {}

  auto get() const {
    return TensorImplAddress{weak_self_._unsafe_get_target()};
  }

 private:
  c10::weak_intrusive_ptr<c10::TensorImpl> weak_self_;
};
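
// A minimal usage sketch for WeakTensor (illustrative only; `t`,
// `weak_tensors`, and `ids` are hypothetical): hold a WeakTensor for the
// duration of profiling so the TensorImpl address cannot be reused, then use
// that address as a key when assigning TensorIDs.
//
//   std::vector<WeakTensor> weak_tensors;
//   std::unordered_map<TensorImplAddress, TensorID> ids;
//   weak_tensors.emplace_back(t);  // pins the address without strong ownership
//   ids.emplace(weak_tensors.back().get(), TensorID{ids.size()});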

struct Result;

void calculateUniqueTensorIDs(
    std::vector<std::shared_ptr<Result>>& sorted_results);

} // namespace impl
} // namespace profiler
} // namespace torch
