1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/tsl/framework/tracking_allocator.h" |
17 | |
18 | #include "tensorflow/tsl/platform/env.h" |
19 | #include "tensorflow/tsl/platform/logging.h" |
20 | |
21 | namespace tsl { |
22 | |
23 | TrackingAllocator::TrackingAllocator(Allocator* allocator, bool track_sizes) |
24 | : allocator_(allocator), |
25 | ref_(1), |
26 | allocated_(0), |
27 | high_watermark_(0), |
28 | total_bytes_(0), |
29 | track_sizes_locally_(track_sizes && !allocator_->TracksAllocationSizes()), |
30 | next_allocation_id_(0) {} |
31 | |
32 | void* TrackingAllocator::AllocateRaw( |
33 | size_t alignment, size_t num_bytes, |
34 | const AllocationAttributes& allocation_attr) { |
35 | void* ptr = allocator_->AllocateRaw(alignment, num_bytes, allocation_attr); |
36 | // If memory is exhausted AllocateRaw returns nullptr, and we should |
37 | // pass this through to the caller |
38 | if (nullptr == ptr) { |
39 | return ptr; |
40 | } |
41 | if (allocator_->TracksAllocationSizes()) { |
42 | size_t allocated_bytes = allocator_->AllocatedSize(ptr); |
43 | { |
44 | mutex_lock lock(mu_); |
45 | allocated_ += allocated_bytes; |
46 | high_watermark_ = std::max(high_watermark_, allocated_); |
47 | total_bytes_ += allocated_bytes; |
48 | allocations_.emplace_back(allocated_bytes, Env::Default()->NowMicros()); |
49 | ++ref_; |
50 | } |
51 | } else if (track_sizes_locally_) { |
52 | // Call the underlying allocator to try to get the allocated size |
53 | // whenever possible, even when it might be slow. If this fails, |
54 | // use the requested size as an approximation. |
55 | size_t allocated_bytes = allocator_->AllocatedSizeSlow(ptr); |
56 | allocated_bytes = std::max(num_bytes, allocated_bytes); |
57 | mutex_lock lock(mu_); |
58 | next_allocation_id_ += 1; |
59 | Chunk chunk = {num_bytes, allocated_bytes, next_allocation_id_}; |
60 | in_use_.emplace(std::make_pair(ptr, chunk)); |
61 | allocated_ += allocated_bytes; |
62 | high_watermark_ = std::max(high_watermark_, allocated_); |
63 | total_bytes_ += allocated_bytes; |
64 | allocations_.emplace_back(allocated_bytes, Env::Default()->NowMicros()); |
65 | ++ref_; |
66 | } else { |
67 | mutex_lock lock(mu_); |
68 | total_bytes_ += num_bytes; |
69 | allocations_.emplace_back(num_bytes, Env::Default()->NowMicros()); |
70 | ++ref_; |
71 | } |
72 | return ptr; |
73 | } |
74 | |
75 | void TrackingAllocator::DeallocateRaw(void* ptr) { |
76 | // freeing a null ptr is a no-op |
77 | if (nullptr == ptr) { |
78 | return; |
79 | } |
80 | bool should_delete; |
81 | // fetch the following outside the lock in case the call to |
82 | // AllocatedSize is slow |
83 | bool tracks_allocation_sizes = allocator_->TracksAllocationSizes(); |
84 | size_t allocated_bytes = 0; |
85 | if (tracks_allocation_sizes) { |
86 | allocated_bytes = allocator_->AllocatedSize(ptr); |
87 | } else if (track_sizes_locally_) { |
88 | mutex_lock lock(mu_); |
89 | auto itr = in_use_.find(ptr); |
90 | if (itr != in_use_.end()) { |
91 | tracks_allocation_sizes = true; |
92 | allocated_bytes = (*itr).second.allocated_size; |
93 | in_use_.erase(itr); |
94 | } |
95 | } |
96 | Allocator* allocator = allocator_; |
97 | { |
98 | mutex_lock lock(mu_); |
99 | if (tracks_allocation_sizes) { |
100 | CHECK_GE(allocated_, allocated_bytes); |
101 | allocated_ -= allocated_bytes; |
102 | allocations_.emplace_back(-allocated_bytes, Env::Default()->NowMicros()); |
103 | } |
104 | should_delete = UnRef(); |
105 | } |
106 | allocator->DeallocateRaw(ptr); |
107 | if (should_delete) { |
108 | delete this; |
109 | } |
110 | } |
111 | |
112 | bool TrackingAllocator::TracksAllocationSizes() const { |
113 | return track_sizes_locally_ || allocator_->TracksAllocationSizes(); |
114 | } |
115 | |
116 | size_t TrackingAllocator::RequestedSize(const void* ptr) const { |
117 | if (track_sizes_locally_) { |
118 | mutex_lock lock(mu_); |
119 | auto it = in_use_.find(ptr); |
120 | if (it != in_use_.end()) { |
121 | return (*it).second.requested_size; |
122 | } |
123 | return 0; |
124 | } else { |
125 | return allocator_->RequestedSize(ptr); |
126 | } |
127 | } |
128 | |
129 | size_t TrackingAllocator::AllocatedSize(const void* ptr) const { |
130 | if (track_sizes_locally_) { |
131 | mutex_lock lock(mu_); |
132 | auto it = in_use_.find(ptr); |
133 | if (it != in_use_.end()) { |
134 | return (*it).second.allocated_size; |
135 | } |
136 | return 0; |
137 | } else { |
138 | return allocator_->AllocatedSize(ptr); |
139 | } |
140 | } |
141 | |
142 | int64_t TrackingAllocator::AllocationId(const void* ptr) const { |
143 | if (track_sizes_locally_) { |
144 | mutex_lock lock(mu_); |
145 | auto it = in_use_.find(ptr); |
146 | if (it != in_use_.end()) { |
147 | return (*it).second.allocation_id; |
148 | } |
149 | return 0; |
150 | } else { |
151 | return allocator_->AllocationId(ptr); |
152 | } |
153 | } |
154 | |
155 | absl::optional<AllocatorStats> TrackingAllocator::GetStats() { |
156 | return allocator_->GetStats(); |
157 | } |
158 | |
159 | bool TrackingAllocator::ClearStats() { return allocator_->ClearStats(); } |
160 | |
161 | std::tuple<size_t, size_t, size_t> TrackingAllocator::GetSizes() { |
162 | size_t high_watermark; |
163 | size_t total_bytes; |
164 | size_t still_live_bytes; |
165 | { |
166 | mutex_lock lock(mu_); |
167 | high_watermark = high_watermark_; |
168 | total_bytes = total_bytes_; |
169 | still_live_bytes = allocated_; |
170 | } |
171 | return std::make_tuple(total_bytes, high_watermark, still_live_bytes); |
172 | } |
173 | |
174 | gtl::InlinedVector<AllocRecord, 4> TrackingAllocator::GetRecordsAndUnRef() { |
175 | bool should_delete; |
176 | gtl::InlinedVector<AllocRecord, 4> allocations; |
177 | { |
178 | mutex_lock lock(mu_); |
179 | allocations.swap(allocations_); |
180 | should_delete = UnRef(); |
181 | } |
182 | if (should_delete) { |
183 | delete this; |
184 | } |
185 | return allocations; |
186 | } |
187 | |
188 | gtl::InlinedVector<AllocRecord, 4> TrackingAllocator::GetCurrentRecords() { |
189 | gtl::InlinedVector<AllocRecord, 4> allocations; |
190 | { |
191 | mutex_lock lock(mu_); |
192 | for (const AllocRecord& alloc : allocations_) { |
193 | allocations.push_back(alloc); |
194 | } |
195 | } |
196 | return allocations; |
197 | } |
198 | |
199 | bool TrackingAllocator::UnRef() { |
200 | CHECK_GE(ref_, 1); |
201 | --ref_; |
202 | return (ref_ == 0); |
203 | } |
204 | |
205 | } // end namespace tsl |
206 | |