/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"

#include "tensorflow/core/common_runtime/scoped_allocator.h"
#include "tensorflow/core/framework/allocator.h"

namespace tensorflow {

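// Registers a ScopedAllocator backed by `backing_tensor` under `scope_id`,
// plus one ScopedAllocatorInstance per entry in `fields`, keyed by that
// field's scope_id. Fails with an Internal error if any of the requested
// scope_ids is already in use.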
Status ScopedAllocatorContainer::AddScopedAllocator(
    const Tensor& backing_tensor, int32_t scope_id, const string& scope_name,
    const gtl::ArraySlice<ScopedAllocator::Field>& fields,
    int32_t expected_call_count) {
  VLOG(1) << "AddScopedAllocator " << mgr_->device_name()
          << " step_id_=" << step_id_ << " scope_id=" << scope_id;
  mutex_lock l(mu_);
  // Ensure none of the new scope_ids are in use.
  auto it = allocators_.find(scope_id);
  if (it != allocators_.end()) {
    return errors::Internal("Cannot create ScopedAllocator because scope_id ",
                            scope_id, " for name ", scope_name,
                            " already exists");
  }
  for (auto& f : fields) {
    if (allocators_.find(f.scope_id) != allocators_.end()) {
      return errors::Internal(
          "Cannot create ScopedAllocator because field scope_id ", f.scope_id,
          " for name ", scope_name, " already exists");
    }
  }
  VLOG(2) << " container " << this << " step_id " << step_id_;
  ScopedAllocator* sa = new ScopedAllocator(
      backing_tensor, scope_id, scope_name, fields, expected_call_count, this);
  allocators_[scope_id] =
      ScopedAllocatorContainer::SAField(ScopedAllocator::kBackingIndex, sa);
  VLOG(2) << "#fields " << fields.size();
  for (int i = 0; i < fields.size(); ++i) {
    const ScopedAllocator::Field& f = fields[i];
    VLOG(2) << "Adding instance for " << mgr_->device_name()
            << " scope_id=" << f.scope_id;
    allocators_[f.scope_id] = ScopedAllocatorContainer::SAField(
        i, new ScopedAllocatorInstance(sa, i));
  }
  return OkStatus();
}

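// Returns the backing ScopedAllocator registered under `scope_id`, or nullptr
// (after logging an error) if there is no such entry. CHECK-fails if
// `scope_id` names a field instance rather than a backing allocator.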
ScopedAllocator* ScopedAllocatorContainer::GetAllocator(int32_t scope_id) {
  mutex_lock l(mu_);
  auto it = allocators_.find(scope_id);
  if (it != allocators_.end()) {
    CHECK_EQ(ScopedAllocator::kBackingIndex, it->second.field_index);
    return it->second.scoped_allocator;
  } else {
    LOG(ERROR) << "Failed to find ScopedAllocator for " << scope_id
               << " in container for step " << step_id_ << " on "
               << mgr_->device_name();
    return nullptr;
  }
}

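// Returns the ScopedAllocatorInstance registered for the field `scope_id`.
// A missing entry is a fatal error.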
ScopedAllocatorInstance* ScopedAllocatorContainer::GetInstance(
    int32_t scope_id) {
  VLOG(2) << "GetInstance " << scope_id << " step " << step_id_ << " on "
          << mgr_->device_name();
  mutex_lock l(mu_);
  auto it = allocators_.find(scope_id);
  if (it != allocators_.end()) {
    return it->second.instance;
  }
  LOG(FATAL) << "Failed to find instance " << scope_id << " in container "
             << step_id_ << " on " << mgr_->device_name();
  return nullptr;
}

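// Removes the entry for `scope_id` from this container's table. Field entries
// are first notified via DropFromTable that the table no longer holds a
// pointer to them.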
void ScopedAllocatorContainer::Drop(int32_t scope_id, ScopedAllocator* sa) {
  VLOG(2) << "Drop " << scope_id << " from container " << this << " step "
          << step_id_ << " on " << mgr_->device_name();
  mutex_lock l(mu_);
  auto it = allocators_.find(scope_id);
  if (it != allocators_.end()) {
    if (it->second.field_index != ScopedAllocator::kBackingIndex) {
      it->second.instance->DropFromTable();
    }
    allocators_.erase(it);
  }
}

ScopedAllocatorContainer::~ScopedAllocatorContainer() {
  VLOG(2) << "~ScopedAllocatorContainer " << this << " step " << step_id_
          << " on " << mgr_->device_name();
  mutex_lock l(mu_);
  // In normal execution the table should be empty and all of its contents
  // deleted via Drop. When a step ends early (e.g. through abnormal
  // termination) we need to clean up explicitly. So long as graph execution
  // of the associated step has completely terminated this should be safe.
  for (auto& it : allocators_) {
    if (it.second.field_index == ScopedAllocator::kBackingIndex) {
      delete it.second.scoped_allocator;
    } else {
      it.second.instance->DropFromTable();
    }
  }
}

ScopedAllocatorMgr::~ScopedAllocatorMgr() {
  mutex_lock l(mu_);
  for (auto it : per_step_map_) {
    // In normal execution the associated ScopedAllocatorContainer is
    // empty and gone by the end of the step. But in abnormal termination,
    // such as when an error has interrupted execution or in a unittest,
    // we need to remove all of its Refs here to avoid memory leaks.
    // This is safe so long as graph execution has ceased.
    while (!it.second->Unref()) {
    }
  }
}

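// Removes the per-step container for `step_id`, if present, and releases the
// manager's reference to it.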
void ScopedAllocatorMgr::Cleanup(int64_t step_id) {
  mutex_lock l(mu_);
  auto it = per_step_map_.find(step_id);
  if (it != per_step_map_.end()) {
    it->second->Unref();
    per_step_map_.erase(it);
  }
}

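// Returns the ScopedAllocatorContainer for `step_id`, creating one on first
// use. The reference held in per_step_map_ is released by Cleanup() or by
// this manager's destructor.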
ScopedAllocatorContainer* ScopedAllocatorMgr::GetContainer(int64_t step_id) {
  VLOG(2) << "GetContainer " << step_id << " on " << device_name();
  ScopedAllocatorContainer* sac = nullptr;
  mutex_lock l(mu_);
  auto it = per_step_map_.find(step_id);
  if (it == per_step_map_.end()) {
    sac = new ScopedAllocatorContainer(this, step_id);
    per_step_map_[step_id] = sac;
  } else {
    sac = it->second;
  }
  return sac;
}

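// Convenience wrapper that forwards to the ScopedAllocatorContainer for
// `step_id`.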
Status ScopedAllocatorMgr::AddScopedAllocator(
    const Tensor& backing_tensor, int64_t step_id, int32_t scope_id,
    const string& scope_name,
    const gtl::ArraySlice<ScopedAllocator::Field>& fields,
    int32_t expected_call_count) {
  ScopedAllocatorContainer* sac = GetContainer(step_id);
  return sac->AddScopedAllocator(backing_tensor, scope_id, scope_name, fields,
                                 expected_call_count);
}

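// Computes one Field per shape: assigns consecutive scope_ids starting at
// scope_id + 1 and lays the fields out back to back, padding each field up to
// Allocator::kAllocatorAlignment. Returns the total number of bytes the
// backing tensor must provide.
//
// Illustrative example (not taken from the codebase), assuming float32
// elements (4 bytes each) and Allocator::kAllocatorAlignment == 64: for
// shapes {3} and {5} with scope_id 10, the result is
//   fields[0]: scope_id=11, offset=0,  bytes_requested=12, bytes_allocated=64
//   fields[1]: scope_id=12, offset=64, bytes_requested=20, bytes_allocated=64
// and the returned total is 128.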
/*static*/
size_t ScopedAllocatorMgr::PopulateFields(
    int32_t scope_id, const gtl::ArraySlice<TensorShape>& shapes,
    const DataType dtype, std::vector<ScopedAllocator::Field>* fields) {
  const int32_t num_fields = static_cast<int32_t>(shapes.size());
  fields->resize(num_fields);
  // At the end of iteration `i`, `offset` points to the offset from the start
  // of the backing buffer until the end of `field[i].bytes_allocated`. This
  // is aligned to `kAllocatorAlignment`.
  size_t offset = 0;
  for (int32_t i = 0; i < num_fields; ++i) {
    size_t bytes_requested = shapes[i].num_elements() * DataTypeSize(dtype);
    auto* field = &((*fields)[i]);
    field->scope_id = scope_id + 1 + i;
    field->bytes_requested = bytes_requested;
    field->offset = offset;
    offset += bytes_requested;

    // Compute actual #bytes allocated, which may include padding due to
    // alignment.
    size_t bytes_allocated = bytes_requested;
    size_t overshoot = offset % Allocator::kAllocatorAlignment;
    if (overshoot > 0) {
      size_t alignment_bytes = Allocator::kAllocatorAlignment - overshoot;
      bytes_allocated += alignment_bytes;
      offset += alignment_bytes;
    }
    field->bytes_allocated = bytes_allocated;

    VLOG(1) << "field=" << i << " scope_id=" << field->scope_id
            << " bytes_requested=" << field->bytes_requested
            << " offset=" << field->offset
            << " bytes_allocated=" << field->bytes_allocated;
  }

  return offset;
}

}  // namespace tensorflow