1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/util/saved_tensor_slice_util.h" |
17 | |
18 | #include "tensorflow/core/lib/core/errors.h" |
19 | #include "tensorflow/core/lib/strings/ordered_code.h" |
20 | #include "tensorflow/core/lib/strings/str_util.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | namespace checkpoint { |
25 | |
// Key constant for the saved tensor slices entry; its value is the empty
// string.
// NOTE(review): presumably this is the key of the checkpoint's metadata record
// and is empty so that it sorts ahead of every key produced by
// EncodeTensorNameSlice (which all begin with an encoded 0) -- confirm against
// the declaration in saved_tensor_slice_util.h.
const char kSavedTensorSlicesKey[] = "" ;
27 | |
28 | string EncodeTensorNameSlice(const string& name, const TensorSlice& slice) { |
29 | string buffer; |
30 | // All the tensor slice keys will start with a 0 |
31 | tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, 0); |
32 | tensorflow::strings::OrderedCode::WriteString(&buffer, name); |
33 | tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, slice.dims()); |
34 | for (int d = 0; d < slice.dims(); ++d) { |
35 | // A trivial extent (meaning we take EVERYTHING) will default to -1 for both |
36 | // start and end. These will be properly parsed. |
37 | tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer, |
38 | slice.start(d)); |
39 | tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer, |
40 | slice.length(d)); |
41 | } |
42 | return buffer; |
43 | } |
44 | |
45 | Status DecodeTensorNameSlice(const string& code, string* name, |
46 | tensorflow::TensorSlice* slice) { |
47 | StringPiece src(code); |
48 | uint64 x; |
49 | if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) { |
50 | return errors::Internal("Failed to parse the leading number: src = " , src); |
51 | } |
52 | if (x != 0) { |
53 | return errors::Internal( |
54 | "The leading number should always be 0 for any valid key: src = " , src); |
55 | } |
56 | if (!tensorflow::strings::OrderedCode::ReadString(&src, name)) { |
57 | return errors::Internal("Failed to parse the tensor name: src = " , src); |
58 | } |
59 | if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) { |
60 | return errors::Internal("Failed to parse the tensor rank: src = " , src); |
61 | } |
62 | if (x == 0) { |
63 | return errors::Internal("Expecting positive rank of the tensor, got " , x, |
64 | ", src = " , src); |
65 | } |
66 | if (x >= kint32max) { |
67 | return errors::Internal("Too many elements " , x); |
68 | } |
69 | slice->SetFullSlice(x); |
70 | for (int d = 0; d < static_cast<int32>(x); ++d) { |
71 | // We expected 2x integers |
72 | int64_t start, length; |
73 | if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src, |
74 | &start)) { |
75 | return errors::Internal("Failed to parse start: src = " , src); |
76 | } |
77 | if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src, |
78 | &length)) { |
79 | return errors::Internal("Failed to parse length: src = " , src); |
80 | } |
81 | if (length >= 0) { |
82 | // a non-trivial extent |
83 | slice->set_start(d, start); |
84 | slice->set_length(d, length); |
85 | } |
86 | } |
87 | return OkStatus(); |
88 | } |
89 | |
90 | Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape, |
91 | TensorSlice* slice, TensorShape* shape_slice) { |
92 | CHECK(!shape_and_slice.empty()); |
93 | // Syntax: dim0 dim1 dim2 ... <slice string> |
94 | // Where slice string is defined in core/framework/tensor_slice.h |
95 | std::vector<string> splits = str_util::Split(shape_and_slice, ' '); |
96 | |
97 | // Must have at least 2 strings. |
98 | if (splits.size() < 2) { |
99 | return errors::InvalidArgument( |
100 | "Need least two elements in shape_and_slice specification: " , |
101 | shape_and_slice); |
102 | } |
103 | |
104 | // The last split is the slice specification. |
105 | slice->Clear(); |
106 | auto status = slice->Parse(splits.back(), slice); |
107 | if (!status.ok()) return status; |
108 | |
109 | // The first n-1 are the shape specification. |
110 | splits.pop_back(); |
111 | shape->Clear(); |
112 | for (const auto& s : splits) { |
113 | int64_t dim; |
114 | if (!strings::safe_strto64(s, &dim)) { |
115 | return errors::InvalidArgument( |
116 | "Non numerical dimension in shape_and_slice: " , shape_and_slice); |
117 | } |
118 | shape->AddDim(dim); |
119 | } |
120 | |
121 | // The specified slice must be compatible with the specified shape. |
122 | return slice->SliceTensorShape(*shape, shape_slice); |
123 | } |
124 | |
125 | } // namespace checkpoint |
126 | |
127 | } // namespace tensorflow |
128 | |