1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/util/tensor_slice_set.h" |
17 | |
#include <memory>
#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_slice_util.h"
23 | |
24 | namespace tensorflow { |
25 | |
26 | namespace checkpoint { |
27 | |
28 | TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type) |
29 | : shape_(shape), type_(type) {} |
30 | |
31 | TensorSliceSet::~TensorSliceSet() {} |
32 | |
33 | Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag) { |
34 | TensorShape result_shape; |
35 | TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape)); |
36 | string str = slice.DebugString(); |
37 | |
38 | if (slices_.empty()) { |
39 | slices_hull_ = slice; |
40 | } else { |
41 | // We check if there is any intersection between this slice and any of the |
42 | // registered slices. |
43 | if (slices_hull_.Overlaps(slice)) { |
44 | for (const auto& x : slices_) { |
45 | if (slice.Overlaps(x.second.slice)) { |
46 | return errors::Internal("Overlapping slices: existing slice = " , |
47 | x.first, ", new slice = " , str); |
48 | } |
49 | } |
50 | } |
51 | // No overlap: we can now insert the slice |
52 | slices_hull_.UpdateToCover(slice); |
53 | } |
54 | |
55 | TensorSliceSet::SliceInfo info = {slice, tag, result_shape.num_elements()}; |
56 | slices_.insert(std::make_pair(str, info)); |
57 | return OkStatus(); |
58 | } |
59 | |
60 | bool TensorSliceSet::QueryMeta( |
61 | const TensorSlice& slice, |
62 | std::vector<std::pair<TensorSlice, string>>* results) const { |
63 | results->clear(); |
64 | Status s; |
65 | string str = slice.DebugString(); |
66 | // First we check if there is an exactly match (this is the dominant case). |
67 | const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str); |
68 | if (info) { |
69 | results->emplace_back(std::make_pair(info->slice, info->tag)); |
70 | return true; |
71 | } else { |
72 | // We didn't find any exact match but there is still a possibility that |
73 | // multiple existing slices can be patched together to output the slice. |
74 | // We figure this out by computing the intersection of each of the existing |
75 | // slices with the query slice, and check if the union of all these |
76 | // intersections cover the entire slice. We rely on the fact that the |
77 | // existing slices don't have any intersection among themselves. |
78 | TensorShape target_shape; |
79 | Status s; |
80 | s = slice.SliceTensorShape(shape_, &target_shape); |
81 | if (!s.ok()) { |
82 | LOG(WARNING) << s; |
83 | return false; |
84 | } |
85 | int64_t total_size = target_shape.num_elements(); |
86 | |
87 | int64_t overlap_size = 0; |
88 | TensorSlice intersection; |
89 | TensorShape inter_shape; |
90 | for (const auto& x : slices_) { |
91 | if (slice.Intersect(x.second.slice, &intersection)) { |
92 | s = intersection.SliceTensorShape(shape_, &inter_shape); |
93 | if (!s.ok()) { |
94 | LOG(WARNING) << s; |
95 | return false; |
96 | } |
97 | overlap_size += inter_shape.num_elements(); |
98 | results->emplace_back(std::make_pair(x.second.slice, x.second.tag)); |
99 | } |
100 | } |
101 | if (total_size == overlap_size) { |
102 | // We have it! |
103 | return true; |
104 | } else { |
105 | // We don't have all the data for the asked tensor slice |
106 | results->clear(); |
107 | return false; |
108 | } |
109 | } |
110 | } |
111 | |
112 | Status RegisterTensorSlice( |
113 | const string& name, const TensorShape& shape, DataType type, |
114 | const string& tag, const TensorSlice& slice, |
115 | std::unordered_map<string, TensorSliceSet*>* tensor_slices) { |
116 | DCHECK_NE(tensor_slices, nullptr); |
117 | TensorSliceSet* tss = gtl::FindPtrOrNull(*tensor_slices, name); |
118 | // Create a tensor slice set if needed |
119 | if (!tss) { |
120 | tss = new TensorSliceSet(shape, type); |
121 | tensor_slices->insert(std::make_pair(name, tss)); |
122 | } else { |
123 | // Check if the shapes match |
124 | const TensorShape& tss_shape(tss->shape()); |
125 | if (!shape.IsSameSize(tss_shape)) { |
126 | return errors::Internal("Incompatible tensor shapes detected for tensor " , |
127 | name, ": existing = " , tss_shape.DebugString(), |
128 | ", new = " , shape.DebugString()); |
129 | } |
130 | if (type != tss->type()) { |
131 | return errors::Internal("Incompatible tensor types detected for tensor " , |
132 | name, |
133 | ": existing = " , DataTypeString(tss->type()), |
134 | ", new = " , DataTypeString(type)); |
135 | } |
136 | } |
137 | // Register the tensor slices without the actual data. |
138 | return tss->Register(slice, tag); |
139 | } |
140 | |
141 | } // namespace checkpoint |
142 | |
143 | } // namespace tensorflow |
144 | |