/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/io_ops.cc.

#include <limits>
#include <string>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/checkpoint_callback_manager.h"
#include "tensorflow/core/kernels/save_restore_tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"

namespace tensorflow {

namespace {

// Shared validations of the inputs to the SaveV2 and RestoreV2 ops.
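// On failure, each check records an InvalidArgument status on `context` via
// OP_REQUIRES, so callers must check context->status() before proceeding.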
void ValidateInputs(bool is_save_op, OpKernelContext* context,
                    const Tensor& prefix, const Tensor& tensor_names,
                    const Tensor& shape_and_slices) {
  const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
  const int num_tensors = static_cast<int>(tensor_names.NumElements());
  OP_REQUIRES(
      context, prefix.NumElements() == 1,
      errors::InvalidArgument("Input prefix should have a single element, got ",
                              prefix.NumElements(), " instead."));
  OP_REQUIRES(context,
              TensorShapeUtils::IsVector(tensor_names.shape()) &&
                  TensorShapeUtils::IsVector(shape_and_slices.shape()),
              errors::InvalidArgument(
                  "Input tensor_names and shape_and_slices "
                  "should be 1-D tensors, got ",
                  tensor_names.shape().DebugString(), " and ",
                  shape_and_slices.shape().DebugString(), " instead."));
  OP_REQUIRES(context,
              tensor_names.NumElements() == shape_and_slices.NumElements(),
              errors::InvalidArgument("tensor_names and shape_and_slices "
                                      "have different number of elements: ",
                                      tensor_names.NumElements(), " vs. ",
                                      shape_and_slices.NumElements()));
  OP_REQUIRES(context,
              FastBoundsCheck(tensor_names.NumElements() + kFixedInputs,
                              std::numeric_limits<int>::max()),
              errors::InvalidArgument("Too many inputs to the op"));
  OP_REQUIRES(
      context, shape_and_slices.NumElements() == num_tensors,
      errors::InvalidArgument("Expected ", num_tensors,
                              " elements in shapes_and_slices, but got ",
                              context->input(2).NumElements()));
  if (is_save_op) {
    OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
                errors::InvalidArgument(
                    "Got ", num_tensors, " tensor names but ",
                    context->num_inputs() - kFixedInputs, " tensors."));
    OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
                errors::InvalidArgument(
                    "Expected a total of ", num_tensors + kFixedInputs,
                    " inputs as input #1 (which is a string "
                    "tensor of saved names) contains ",
                    num_tensors, " names, but received ",
                    context->num_inputs(), " inputs"));
  }
}

}  // namespace

// Saves a list of named tensors using the tensor bundle library.
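// Inputs, in order: a scalar string `prefix`, a 1-D string `tensor_names`,
// a 1-D string `shape_and_slices`, followed by one tensor per name. An
// illustrative (hypothetical) invocation: prefix = "/tmp/model.ckpt",
// tensor_names = ["v0", "v1"], shape_and_slices = ["", ""] (save both in
// full), plus the two tensor values.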
class SaveV2 : public OpKernel {
 public:
  explicit SaveV2(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& prefix = context->input(0);
    const Tensor& tensor_names = context->input(1);
    const Tensor& shape_and_slices = context->input(2);
    ValidateInputs(true /* is save op */, context, prefix, tensor_names,
                   shape_and_slices);
    if (!context->status().ok()) return;

    const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
    const int num_tensors = static_cast<int>(tensor_names.NumElements());
    const string& prefix_string = prefix.scalar<tstring>()();
    const auto& tensor_names_flat = tensor_names.flat<tstring>();
    const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();

    BundleWriter writer(Env::Default(), prefix_string);
    OP_REQUIRES_OK(context, writer.status());
    VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;

    for (int i = 0; i < num_tensors; ++i) {
      const string& tensor_name = tensor_names_flat(i);
      const Tensor& tensor = context->input(i + kFixedInputs);
      VLOG(2) << "Starting save of " << tensor_name;

      if (!shape_and_slices_flat(i).empty()) {
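        // A non-empty entry means only a slice of the full tensor is being
        // saved. The spec lists the full shape followed by a per-dimension
        // slice, e.g. (illustrative) "10 20 0,5:-": a [10, 20] tensor of which
        // rows [0, 5) are saved, with the second dimension taken in full.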
        const string& shape_spec = shape_and_slices_flat(i);
        TensorShape shape;
        TensorSlice slice(tensor.dims());
        TensorShape slice_shape;

        OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                    shape_spec, &shape, &slice, &slice_shape));
        OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
                    errors::InvalidArgument("Slice in shape_and_slice "
                                            "specification does not match the "
                                            "shape of the tensor to save: ",
                                            shape_spec, ", tensor: ",
                                            tensor.shape().DebugString()));

        OP_REQUIRES_OK(context,
                       writer.AddSlice(tensor_name, shape, slice, tensor));
      } else {
        OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
      }

      if (VLOG_IS_ON(5)) {
        if (tensor.dtype() == DT_FLOAT) {
          const float* t_data = tensor.flat<float>().data();
          float min = std::numeric_limits<float>::infinity();
          float max = -std::numeric_limits<float>::infinity();
          double avg = 0.0;
          // Use a separate int64_t index to avoid shadowing the outer loop
          // counter and to handle tensors with more than 2^31 elements.
          for (int64_t j = 0; j < tensor.NumElements(); ++j) {
            if (t_data[j] < min) min = t_data[j];
            if (t_data[j] > max) max = t_data[j];
            avg += t_data[j];
          }
          VLOG(5) << " min " << min << " max " << max << " avg "
                  << avg / tensor.NumElements() << " total elts "
                  << tensor.NumElements();
        }
      }

      VLOG(2) << "Done save of " << tensor_name;
    }
    OP_REQUIRES_OK(context, writer.Finish());
    VLOG(1) << "Done BundleWriter, prefix_string: " << prefix_string;

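    // Notify any checkpoint callbacks registered in the resource manager that
    // a save to `prefix_string` has completed. The callback manager is created
    // lazily under a well-known resource name, and Unref'd after use because
    // LookupOrCreate returns an owned reference.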
    ResourceMgr* resource_manager = context->resource_manager();
    if (resource_manager != nullptr) {
      checkpoint::CheckpointCallbackManager* checkpoint_callback_manager;
      OP_REQUIRES_OK(
          context,
          resource_manager
              ->LookupOrCreate<checkpoint::CheckpointCallbackManager>(
                  resource_manager->default_container(),
                  std::string(
                      checkpoint::kCheckpointCallbackManagerResourceName),
                  &checkpoint_callback_manager,
                  [](checkpoint::CheckpointCallbackManager** out) {
                    *out = new checkpoint::CheckpointCallbackManager();
                    return OkStatus();
                  }));
      checkpoint_callback_manager->Save(prefix_string);
      checkpoint_callback_manager->Unref();
    }
  }
};
REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);

// Restores a list of named tensors from a tensor bundle (V2 checkpoint
// format).
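// Produces one output tensor per requested name; the expected output dtypes
// are supplied through the "dtypes" attr and must match the number of names.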
class RestoreV2 : public OpKernel {
 public:
  explicit RestoreV2(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dtypes", &dtypes_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& prefix = context->input(0);
    const Tensor& tensor_names = context->input(1);
    const Tensor& shape_and_slices = context->input(2);
    OP_REQUIRES(context, tensor_names.NumElements() == dtypes_.size(),
                errors::InvalidArgument("Got ", tensor_names.NumElements(),
                                        " tensor names, but ", dtypes_.size(),
                                        " expected dtypes."));
    ValidateInputs(false /* not save op */, context, prefix, tensor_names,
                   shape_and_slices);
    if (!context->status().ok()) return;

    const string& prefix_string = prefix.scalar<tstring>()();

    // Intention: we plan to use the RestoreV2 op as a backward-compatible
    // reader as we upgrade to the V2 format. This allows transparent upgrade.
    // Here we attempt to read a V1 checkpoint if "prefix_string" does not
    // refer to a V2 checkpoint.
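    // V2 checkpoints are detected by the presence of the bundle metadata file
    // (MetaFilename(prefix_string), the "<prefix>.index" file written by
    // BundleWriter).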
    Env* env = Env::Default();
    std::vector<string> paths;
    if (!env->GetMatchingPaths(MetaFilename(prefix_string), &paths).ok() ||
        paths.empty()) {
      // Cannot find V2's metadata file, so "prefix_string" does not point to a
      // V2 checkpoint. Invokes the V1 read path instead.
      for (size_t i = 0; i < tensor_names.NumElements(); ++i) {
        RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
                      /* preferred_shard */ -1, /* restore_slice */ true,
                      /* restore_index */ i);
        if (!context->status().ok()) {
          return;
        }
      }
      return;
    }
    // If found, invokes the V2 reader.
    OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names,
                                             shape_and_slices, dtypes_));

    ResourceMgr* resource_manager = context->resource_manager();
    if (resource_manager != nullptr) {
      checkpoint::CheckpointCallbackManager* checkpoint_callback_manager;
      OP_REQUIRES_OK(
          context,
          resource_manager
              ->LookupOrCreate<checkpoint::CheckpointCallbackManager>(
                  resource_manager->default_container(),
                  std::string(
                      checkpoint::kCheckpointCallbackManagerResourceName),
                  &checkpoint_callback_manager,
                  [](checkpoint::CheckpointCallbackManager** out) {
                    *out = new checkpoint::CheckpointCallbackManager();
                    return OkStatus();
                  }));
      checkpoint_callback_manager->Restore(prefix_string);
      checkpoint_callback_manager->Unref();
    }
  }

 private:
  // Expected dtypes of the tensors to restore.
  std::vector<DataType> dtypes_;
};
REGISTER_KERNEL_BUILDER(Name("RestoreV2").Device(DEVICE_CPU), RestoreV2);

// The final step in saving sharded V2 checkpoints: merges metadata files.
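// Given the per-shard prefixes produced by a sharded save (the
// checkpoint_prefixes input) and a destination_prefix, MergeBundles combines
// them into a single logical bundle rooted at the destination prefix; the
// temporary per-shard directories can then optionally be deleted (see
// delete_old_dirs_ below).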
class MergeV2Checkpoints : public OpKernel {
 public:
  explicit MergeV2Checkpoints(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("delete_old_dirs", &delete_old_dirs_));
    OP_REQUIRES_OK(context, context->GetAttr("allow_missing_files",
                                             &allow_missing_files_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& checkpoint_prefixes = context->input(0);
    const Tensor& destination_prefix = context->input(1);
    OP_REQUIRES(context,
                TensorShapeUtils::IsVector(checkpoint_prefixes.shape()),
                errors::InvalidArgument(
                    "Input checkpoint_prefixes should be a 1-D tensor, got ",
                    checkpoint_prefixes.shape().DebugString(), " instead."));
    OP_REQUIRES(context,
                TensorShapeUtils::IsScalar(destination_prefix.shape()),
                errors::InvalidArgument(
                    "Input destination_prefix should be a scalar tensor, got ",
                    destination_prefix.shape().DebugString(), " instead."));

    const gtl::ArraySlice<tstring> input_prefixes =
        gtl::ArraySlice<tstring>(checkpoint_prefixes.flat<tstring>());
    Env* env = Env::Default();
    const string& merged_prefix = destination_prefix.scalar<tstring>()();
    OP_REQUIRES_OK(context,
                   tensorflow::MergeBundles(env, input_prefixes, merged_prefix,
                                            allow_missing_files_));

    if (delete_old_dirs_) {
      const string merged_dir(io::Dirname(merged_prefix));
      for (const string& input_prefix : input_prefixes) {
        const string dirname(io::Dirname(input_prefix));
        if (dirname == merged_dir) continue;
        Status status = env->DeleteDir(dirname);
        // For sharded save, only the first delete will go through and all
        // others will hit NotFound. Use VLOG to be less verbose.
        if (!status.ok()) VLOG(1) << status;
      }
    }
  }

 private:
  // On merge, whether or not to delete the input (temporary) directories.
  bool delete_old_dirs_;

  // On merge, whether or not to relax the requirement that all input prefix
  // files exist.
  bool allow_missing_files_;
};
REGISTER_KERNEL_BUILDER(Name("MergeV2Checkpoints").Device(DEVICE_CPU),
                        MergeV2Checkpoints);

}  // namespace tensorflow