1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/kernels/save_restore_tensor.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
40 | |
41 | namespace tensorflow { |
42 | |
43 | void SaveTensors( |
44 | OpKernelContext* context, |
45 | checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func, |
46 | bool save_slices) { |
47 | const Tensor& filename_t = context->input(0); |
48 | { |
49 | const int64_t size = filename_t.NumElements(); |
50 | OP_REQUIRES( |
51 | context, size == 1, |
52 | errors::InvalidArgument( |
53 | "Input 0 (filename) must be a string scalar; got a tensor of " , |
54 | size, "elements" )); |
55 | } |
56 | |
57 | // Path, names, and slices if save_slices is true. |
58 | const int kFixedInputs = save_slices ? 3 : 2; |
59 | const Tensor& tensor_names_t = context->input(1); |
60 | OP_REQUIRES(context, |
61 | FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs, |
62 | std::numeric_limits<int>::max()), |
63 | errors::InvalidArgument("Too many inputs to SaveTensors" )); |
64 | const int N = static_cast<int>(tensor_names_t.NumElements()); |
65 | const tstring* tensor_shapes_and_slices_ptr = nullptr; |
66 | if (save_slices) { |
67 | const Tensor& tensor_shapes_and_slices_t = context->input(2); |
68 | OP_REQUIRES( |
69 | context, |
70 | tensor_shapes_and_slices_t.NumElements() == static_cast<int64_t>(N), |
71 | errors::InvalidArgument("Expected " , N, |
72 | " elements for the tensor " |
73 | "shapes and slices but got " , |
74 | tensor_shapes_and_slices_t.NumElements())); |
75 | tensor_shapes_and_slices_ptr = |
76 | tensor_shapes_and_slices_t.flat<tstring>().data(); |
77 | } |
78 | OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs, |
79 | errors::InvalidArgument("Expected totally " , N + kFixedInputs, |
80 | " inputs as input #1 (which is a string " |
81 | "tensor of saved names) contains " , |
82 | N, " names, but received " , |
83 | context->num_inputs(), " inputs" )); |
84 | |
85 | VLOG(1) << "About to save tensors to file " << filename_t.flat<tstring>()(0) |
86 | << "..." ; |
87 | checkpoint::TensorSliceWriter writer(filename_t.flat<tstring>()(0), |
88 | std::move(builder_func)); |
89 | |
90 | Status s; |
91 | auto tensor_names_flat = tensor_names_t.flat<tstring>(); |
92 | |
93 | // Process tensors in sorted name order. This allows us to avoid seeking |
94 | // during restoration in the common case where we are restoring a full |
95 | // checkpoint. |
96 | // RestoreTensorsV2 was changed to sort by file offset, so this sorting isn't |
97 | // strictly necessary anymore. However, restores with TF version <= 2.7 will |
98 | // still benefit. |
99 | std::vector<size_t> sorted_name_idx(tensor_names_flat.size()); |
100 | std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0); |
101 | std::sort(sorted_name_idx.begin(), sorted_name_idx.end(), |
102 | [&tensor_names_flat](size_t a, size_t b) { |
103 | return tensor_names_flat(a) < tensor_names_flat(b); |
104 | }); |
105 | |
106 | for (const size_t i : sorted_name_idx) { |
107 | const string& name = tensor_names_flat(i); |
108 | const Tensor& input = context->input(i + kFixedInputs); |
109 | TensorShape shape(input.shape()); |
110 | TensorSlice slice(input.dims()); |
111 | if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) { |
112 | const tstring& shape_spec = tensor_shapes_and_slices_ptr[i]; |
113 | TensorShape slice_shape; |
114 | OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice( |
115 | shape_spec, &shape, &slice, &slice_shape)); |
116 | OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()), |
117 | errors::InvalidArgument( |
118 | "Slice in shape_and_slice " |
119 | "specification does not match the " |
120 | "shape of the tensor to save: " , |
121 | shape_spec, ", tensor: " , input.shape().DebugString())); |
122 | } |
123 | |
124 | #define WRITER_ADD(T) \ |
125 | case DataTypeToEnum<T>::value: \ |
126 | s = writer.Add(name, shape, slice, input.flat<T>().data()); \ |
127 | break; |
128 | |
129 | switch (input.dtype()) { |
130 | TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD) |
131 | default: |
132 | context->SetStatus(errors::Unimplemented("Saving data type " , |
133 | DataTypeString(input.dtype()), |
134 | " not yet supported" )); |
135 | return; |
136 | } |
137 | #undef WRITER_ADD |
138 | if (!s.ok()) { |
139 | context->SetStatus(s); |
140 | return; |
141 | } |
142 | } |
143 | |
144 | s = writer.Finish(); |
145 | if (!s.ok()) { |
146 | context->SetStatus(s); |
147 | } |
148 | } |
149 | |
150 | void RestoreTensor(OpKernelContext* context, |
151 | checkpoint::TensorSliceReader::OpenTableFunction open_func, |
152 | int preferred_shard, bool restore_slice, int restore_index) { |
153 | const Tensor& file_pattern_t = context->input(0); |
154 | { |
155 | const int64_t size = file_pattern_t.NumElements(); |
156 | OP_REQUIRES( |
157 | context, size == 1, |
158 | errors::InvalidArgument( |
159 | "Input 0 (file_pattern) must be a string scalar; got a tensor of " , |
160 | size, " elements" )); |
161 | } |
162 | const string& file_pattern = file_pattern_t.flat<tstring>()(0); |
163 | |
164 | const Tensor& tensor_name_t = context->input(1); |
165 | { |
166 | const int64_t size = tensor_name_t.NumElements(); |
167 | OP_REQUIRES(context, size > restore_index, |
168 | errors::InvalidArgument( |
169 | "Input 1 (file_pattern) must be a have at least " , |
170 | restore_index + 1, " elements" )); |
171 | } |
172 | const string& tensor_name = tensor_name_t.flat<tstring>()(restore_index); |
173 | |
174 | // If we cannot find a cached reader we will allocate our own. |
175 | std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader; |
176 | |
177 | const checkpoint::TensorSliceReader* reader = nullptr; |
178 | |
179 | if (context->slice_reader_cache()) { |
180 | reader = context->slice_reader_cache()->GetReader(file_pattern, open_func, |
181 | preferred_shard); |
182 | } |
183 | if (!reader) { |
184 | allocated_reader.reset(new checkpoint::TensorSliceReader( |
185 | file_pattern, open_func, preferred_shard)); |
186 | reader = allocated_reader.get(); |
187 | } |
188 | OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status()); |
189 | |
190 | // Get the shape and type from the save file. |
191 | DataType type; |
192 | TensorShape saved_shape; |
193 | OP_REQUIRES( |
194 | context, reader->HasTensor(tensor_name, &saved_shape, &type), |
195 | errors::NotFound("Tensor name \"" , tensor_name, |
196 | "\" not found in checkpoint files " , file_pattern)); |
197 | OP_REQUIRES( |
198 | context, type == context->expected_output_dtype(restore_index), |
199 | errors::InvalidArgument("Expected to restore a tensor of type " , |
200 | DataTypeString(context->expected_output_dtype(0)), |
201 | ", got a tensor of type " , DataTypeString(type), |
202 | " instead: tensor_name = " , tensor_name)); |
203 | |
204 | // Shape of the output and slice to load. |
205 | TensorShape output_shape(saved_shape); |
206 | TensorSlice slice_to_load(saved_shape.dims()); |
207 | if (restore_slice) { |
208 | const tstring& shape_spec = |
209 | context->input(2).flat<tstring>()(restore_index); |
210 | if (!shape_spec.empty()) { |
211 | TensorShape parsed_shape; |
212 | OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice( |
213 | shape_spec, &parsed_shape, &slice_to_load, |
214 | &output_shape)); |
215 | OP_REQUIRES( |
216 | context, parsed_shape.IsSameSize(saved_shape), |
217 | errors::InvalidArgument( |
218 | "Shape in shape_and_slice spec does not match the shape in the " |
219 | "save file: " , |
220 | parsed_shape.DebugString(), |
221 | ", save file shape: " , saved_shape.DebugString())); |
222 | } |
223 | } |
224 | |
225 | Tensor* t = nullptr; |
226 | OP_REQUIRES_OK(context, |
227 | context->allocate_output(restore_index, output_shape, &t)); |
228 | |
229 | if (output_shape.num_elements() == 0) return; |
230 | |
231 | #define READER_COPY(T) \ |
232 | case DataTypeToEnum<T>::value: \ |
233 | OP_REQUIRES(context, \ |
234 | reader->CopySliceData(tensor_name, slice_to_load, \ |
235 | t->flat<T>().data()), \ |
236 | errors::InvalidArgument("Error copying slice data")); \ |
237 | break; |
238 | |
239 | switch (type) { |
240 | TF_CALL_SAVE_RESTORE_TYPES(READER_COPY) |
241 | default: |
242 | context->SetStatus(errors::Unimplemented( |
243 | "Restoring data type " , DataTypeString(type), " not yet supported" )); |
244 | } |
245 | #undef READER_COPY |
246 | } |
247 | |
namespace {

// Tensors larger than this threshold will be restored from a thread-pool.
// NOTE(review): the comparison below is against num_elements(), so this is
// 16M *elements*, not bytes.
const int64_t kLargeShapeThreshold = 16 << 20;  // 16M

// A restore operation for a single tensor. Small tensors may be restored
// directly from the op thread to improve read locality. Large tensors can be
// restored from a thread pool: this requires creating a separate BundleReader
// for each restore.
//
// Protocol (see RestoreTensorsV2): either run() is called synchronously with
// a shared reader and its Status is checked by the caller, or
// run_with_new_reader() is scheduled on a pool and the result is stashed in
// the `status` member, to be inspected after the pool has shut down.
struct RestoreOp {
  RestoreOp(OpKernelContext* context, int idx, const string& tensor_name,
            const string& shape_and_slice, const string& reader_prefix,
            DataType dtype)
      : context(context),
        idx(idx),
        tensor_name(tensor_name),
        shape_and_slice(shape_and_slice),
        reader_prefix(reader_prefix),
        dtype(dtype) {}

  // Move-only. It does not make sense to "run()" a copied RestoreOp.
  RestoreOp(const RestoreOp&) = delete;
  RestoreOp& operator=(const RestoreOp&) = delete;
  RestoreOp(RestoreOp&&) = default;
  RestoreOp& operator=(RestoreOp&&) = default;

  // True when the saved tensor is large enough (> kLargeShapeThreshold
  // elements) to be worth restoring on the thread pool.
  bool should_run_in_pool(BundleReader* reader) const {
    TensorShape restored_full_shape;

    // Ignore status here; we'll catch the error later.
    if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) {
      return false;
    }

    return restored_full_shape.num_elements() > kLargeShapeThreshold;
  }

  // Run this restore operation using a new BundleReader.
  // Any failure (including reader construction) lands in `status`, which the
  // scheduling caller must check after the pool has finished.
  void run_with_new_reader() {
    BundleReader reader(Env::Default(), reader_prefix);
    if (!reader.status().ok()) {
      status = reader.status();
      return;
    }

    status = run(&reader);
  }

  // Restores the tensor into context output `idx`, either as the full saved
  // tensor or (when `shape_and_slice` is non-empty) as the requested slice.
  Status run(BundleReader* reader) {
    TensorShape restored_full_shape;
    TF_RETURN_IF_ERROR(
        reader->LookupTensorShape(tensor_name, &restored_full_shape));

    VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : "
            << restored_full_shape.num_elements();
    Tensor* restored_tensor;
    if (shape_and_slice.empty()) {
      // Lookup the full tensor.
      TF_RETURN_IF_ERROR(
          context->allocate_output(idx, restored_full_shape, &restored_tensor));
      TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor));
    } else {
      // Lookup the slice.
      TensorShape parsed_full_shape;
      TensorSlice parsed_slice;
      TensorShape parsed_slice_shape;

      TF_RETURN_IF_ERROR(
          checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
                                         &parsed_slice, &parsed_slice_shape));

      // The spec's full shape must match the checkpoint; the output takes the
      // slice shape, not the full shape.
      if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
        return errors::InvalidArgument(
            "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
            parsed_full_shape.DebugString(),
            " does not match the shape stored in checkpoint: ",
            restored_full_shape.DebugString());
      }
      TF_RETURN_IF_ERROR(
          context->allocate_output(idx, parsed_slice_shape, &restored_tensor));
      TF_RETURN_IF_ERROR(
          reader->LookupSlice(tensor_name, parsed_slice, restored_tensor));
    }
    // Debug-only statistics (min/max/avg) for float tensors; skipped unless
    // verbose logging is at level 5.
    if (VLOG_IS_ON(5)) {
      if (restored_tensor->dtype() == DT_FLOAT) {
        const float* t_data = restored_tensor->flat<float>().data();
        float min = std::numeric_limits<float>::infinity();
        float max = -std::numeric_limits<float>::infinity();
        double avg = 0.0;
        for (int i = 0; i < restored_tensor->NumElements(); ++i) {
          if (t_data[i] < min) min = t_data[i];
          if (t_data[i] > max) max = t_data[i];
          avg += t_data[i];
        }
        VLOG(5) << " min " << min << " max " << max << " avg "
                << avg / restored_tensor->NumElements() << " total elts "
                << restored_tensor->NumElements();
      }
    }
    VLOG(1) << "Done restoring tensor " << idx << " : " << tensor_name << " : "
            << restored_full_shape.num_elements();
    return OkStatus();
  }

  OpKernelContext* context;    // Not owned; outlives the op.
  int idx;                     // Output index this op restores into.
  string tensor_name;          // Name of the tensor in the checkpoint.
  string shape_and_slice;      // Empty => restore the full tensor.
  string reader_prefix;        // Checkpoint prefix for run_with_new_reader().
  DataType dtype;              // Expected dtype, verified by the caller.

  // Result of run_with_new_reader(); only meaningful for pool-scheduled ops.
  ::tensorflow::Status status;
};

}  // namespace
363 | |
364 | Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, |
365 | const Tensor& tensor_names, |
366 | const Tensor& shape_and_slices, |
367 | gtl::ArraySlice<DataType> dtypes) { |
368 | const string& prefix_string = prefix.scalar<tstring>()(); |
369 | |
370 | const auto& tensor_names_flat = tensor_names.flat<tstring>(); |
371 | const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>(); |
372 | |
373 | std::vector<RestoreOp> restore_ops; |
374 | restore_ops.reserve(tensor_names_flat.size()); |
375 | for (int i = 0; i < tensor_names_flat.size(); ++i) { |
376 | restore_ops.push_back({context, i, tensor_names_flat(i), |
377 | shape_and_slices_flat(i), prefix_string, dtypes[i]}); |
378 | } |
379 | |
380 | BundleReader default_reader(Env::Default(), prefix_string); |
381 | TF_RETURN_IF_ERROR(default_reader.status()); |
382 | |
383 | TF_RETURN_IF_ERROR(default_reader.SortForSequentialAccess<RestoreOp>( |
384 | restore_ops, [](const RestoreOp& op) { return op.tensor_name; })); |
385 | |
386 | std::vector<string> mismatched_errors; |
387 | for (const RestoreOp& restore_op : restore_ops) { |
388 | TensorShape restored_full_shape; |
389 | DataType original_dtype; |
390 | TF_RETURN_IF_ERROR(default_reader.LookupDtypeAndShape( |
391 | restore_op.tensor_name, &original_dtype, &restored_full_shape)); |
392 | if (restore_op.dtype != original_dtype) { |
393 | string error_msg = strings::StrCat( |
394 | "tensor_name = " , restore_op.tensor_name, "; expected dtype " , |
395 | DataTypeString(restore_op.dtype), " does not equal original dtype " , |
396 | DataTypeString(original_dtype)); |
397 | mismatched_errors.emplace_back(error_msg); |
398 | } |
399 | } |
400 | if (!mismatched_errors.empty()) { |
401 | const string error_msg = absl::StrJoin(mismatched_errors, "\n" ); |
402 | return errors::InvalidArgument(error_msg); |
403 | } |
404 | |
405 | std::vector<RestoreOp*> pool_restore_ops; |
406 | std::vector<RestoreOp*> direct_restore_ops; |
407 | for (RestoreOp& restore_op : restore_ops) { |
408 | if (restore_op.should_run_in_pool(&default_reader)) { |
409 | pool_restore_ops.push_back(&restore_op); |
410 | } else { |
411 | direct_restore_ops.push_back(&restore_op); |
412 | } |
413 | } |
414 | |
415 | { |
416 | // Schedule any threaded operations first, skipping thread pool creation if |
417 | // we don't have any expensive operations. |
418 | std::unique_ptr<thread::ThreadPool> reader_pool; |
419 | if (!pool_restore_ops.empty()) { |
420 | reader_pool.reset( |
421 | new thread::ThreadPool(Env::Default(), "restore_tensors" , 8)); |
422 | for (auto* op : pool_restore_ops) { |
423 | reader_pool->Schedule([op]() { op->run_with_new_reader(); }); |
424 | } |
425 | } |
426 | |
427 | // Read small tensors from the op thread |
428 | for (auto* op : direct_restore_ops) { |
429 | TF_RETURN_IF_ERROR(op->run(&default_reader)); |
430 | } |
431 | } |
432 | |
433 | // Check status of pool ops; this must come after the pool shuts down. |
434 | for (auto* op : pool_restore_ops) { |
435 | TF_RETURN_IF_ERROR(op->status); |
436 | } |
437 | |
438 | for (const RestoreOp& restore_op : restore_ops) { |
439 | if (restore_op.dtype != context->mutable_output(restore_op.idx)->dtype()) { |
440 | return errors::InvalidArgument( |
441 | "tensor_name = " , restore_op.tensor_name, "; expected dtype " , |
442 | DataTypeString(restore_op.dtype), " does not equal restored dtype " , |
443 | DataTypeString(context->mutable_output(restore_op.idx)->dtype())); |
444 | } |
445 | } |
446 | |
447 | return OkStatus(); |
448 | } |
449 | |
450 | } // namespace tensorflow |
451 | |