1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/util/saved_tensor_slice_util.h"
17
18#include "tensorflow/core/lib/core/errors.h"
19#include "tensorflow/core/lib/strings/ordered_code.h"
20#include "tensorflow/core/lib/strings/str_util.h"
21
22namespace tensorflow {
23
24namespace checkpoint {
25
// Key of the special entry in a SavedTensorSlices table. Presumably this is the
// entry holding the table's metadata; confirm against saved_tensor_slice_util.h.
// NOTE(review): every key produced by EncodeTensorNameSlice() begins with the
// ordered-code encoding of the number 0. Assuming that encoding is non-empty,
// the empty string sorts before all of them, so this entry would come first in
// a sorted table.
const char kSavedTensorSlicesKey[] = "";
27
28string EncodeTensorNameSlice(const string& name, const TensorSlice& slice) {
29 string buffer;
30 // All the tensor slice keys will start with a 0
31 tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, 0);
32 tensorflow::strings::OrderedCode::WriteString(&buffer, name);
33 tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, slice.dims());
34 for (int d = 0; d < slice.dims(); ++d) {
35 // A trivial extent (meaning we take EVERYTHING) will default to -1 for both
36 // start and end. These will be properly parsed.
37 tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer,
38 slice.start(d));
39 tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer,
40 slice.length(d));
41 }
42 return buffer;
43}
44
45Status DecodeTensorNameSlice(const string& code, string* name,
46 tensorflow::TensorSlice* slice) {
47 StringPiece src(code);
48 uint64 x;
49 if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) {
50 return errors::Internal("Failed to parse the leading number: src = ", src);
51 }
52 if (x != 0) {
53 return errors::Internal(
54 "The leading number should always be 0 for any valid key: src = ", src);
55 }
56 if (!tensorflow::strings::OrderedCode::ReadString(&src, name)) {
57 return errors::Internal("Failed to parse the tensor name: src = ", src);
58 }
59 if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) {
60 return errors::Internal("Failed to parse the tensor rank: src = ", src);
61 }
62 if (x == 0) {
63 return errors::Internal("Expecting positive rank of the tensor, got ", x,
64 ", src = ", src);
65 }
66 if (x >= kint32max) {
67 return errors::Internal("Too many elements ", x);
68 }
69 slice->SetFullSlice(x);
70 for (int d = 0; d < static_cast<int32>(x); ++d) {
71 // We expected 2x integers
72 int64_t start, length;
73 if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src,
74 &start)) {
75 return errors::Internal("Failed to parse start: src = ", src);
76 }
77 if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src,
78 &length)) {
79 return errors::Internal("Failed to parse length: src = ", src);
80 }
81 if (length >= 0) {
82 // a non-trivial extent
83 slice->set_start(d, start);
84 slice->set_length(d, length);
85 }
86 }
87 return OkStatus();
88}
89
90Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
91 TensorSlice* slice, TensorShape* shape_slice) {
92 CHECK(!shape_and_slice.empty());
93 // Syntax: dim0 dim1 dim2 ... <slice string>
94 // Where slice string is defined in core/framework/tensor_slice.h
95 std::vector<string> splits = str_util::Split(shape_and_slice, ' ');
96
97 // Must have at least 2 strings.
98 if (splits.size() < 2) {
99 return errors::InvalidArgument(
100 "Need least two elements in shape_and_slice specification: ",
101 shape_and_slice);
102 }
103
104 // The last split is the slice specification.
105 slice->Clear();
106 auto status = slice->Parse(splits.back(), slice);
107 if (!status.ok()) return status;
108
109 // The first n-1 are the shape specification.
110 splits.pop_back();
111 shape->Clear();
112 for (const auto& s : splits) {
113 int64_t dim;
114 if (!strings::safe_strto64(s, &dim)) {
115 return errors::InvalidArgument(
116 "Non numerical dimension in shape_and_slice: ", shape_and_slice);
117 }
118 shape->AddDim(dim);
119 }
120
121 // The specified slice must be compatible with the specified shape.
122 return slice->SliceTensorShape(*shape, shape_slice);
123}
124
125} // namespace checkpoint
126
127} // namespace tensorflow
128